#!/usr/bin/env python3
"""Generate a demo image with DiCo class-conditional checkpoints.

Each variant points at a local checkpoint directory that ships its own
``pipeline.py``. The script samples one 256x256 image for the variant's class
label and writes it to ``<checkpoint dir>/demo.png`` unless ``--output`` is
given. Seed, step count and guidance scale can be overridden from the CLI;
with no extra flags the variant presets below are used unchanged.
"""

from __future__ import annotations

import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Any

import torch

REPO_ROOT = Path(__file__).resolve().parent

# Settings shared by every variant; each entry in VARIANTS extends these.
_COMMON_SETTINGS: dict[str, Any] = {
    "class_label": "golden retriever",
    "num_inference_steps": 250,
    "seed": 0,
}

VARIANTS: dict[str, dict[str, Any]] = {
    "xl": {**_COMMON_SETTINGS, "dir": REPO_ROOT / "DiCo-XL-256", "guidance_scale": 1.4},
    "s": {**_COMMON_SETTINGS, "dir": REPO_ROOT / "DiCo-S-256", "guidance_scale": 1.0},
    "b": {**_COMMON_SETTINGS, "dir": REPO_ROOT / "DiCo-B-256", "guidance_scale": 1.0},
    "l": {**_COMMON_SETTINGS, "dir": REPO_ROOT / "DiCo-L-256", "guidance_scale": 1.0},
}


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Arguments to parse; ``None`` (the default) reads ``sys.argv``.

    Returns:
        Namespace with ``variant``, ``device``, ``output``, ``seed``,
        ``num_inference_steps`` and ``guidance_scale``. The last four default
        to ``None``, meaning "use the selected variant's preset".
    """
    parser = argparse.ArgumentParser(description="Run DiCo demo inference.")
    parser.add_argument(
        "--variant",
        choices=sorted(VARIANTS),
        default="xl",
        help="Checkpoint variant to sample (default: xl).",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        help="Torch device for the pipeline and RNG (default: cuda).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Output image path (default: <checkpoint dir>/demo.png).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Override the variant's RNG seed.",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=None,
        help="Override the variant's number of sampling steps.",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=None,
        help="Override the variant's classifier-free guidance scale.",
    )
    return parser.parse_args(argv)


def _resolve_settings(args: argparse.Namespace) -> dict[str, Any]:
    """Return a copy of the chosen variant's settings with CLI overrides applied."""
    # Copy so that overrides never mutate the shared VARIANTS table.
    settings = dict(VARIANTS[args.variant])
    for key in ("seed", "num_inference_steps", "guidance_scale"):
        value = getattr(args, key)
        if value is not None:
            settings[key] = value
    settings["output"] = (
        args.output if args.output is not None else settings["dir"] / "demo.png"
    )
    return settings


def main(argv: Sequence[str] | None = None) -> None:
    """Load the selected checkpoint, sample one image, and save it.

    Args:
        argv: Arguments to parse; ``None`` (the default) reads ``sys.argv``.

    Raises:
        SystemExit: if the checkpoint directory is missing, or a CUDA device
            is requested but CUDA is not available.
    """
    args = parse_args(argv)
    settings = _resolve_settings(args)
    model_dir: Path = settings["dir"]
    output_path: Path = settings["output"]

    # Fail fast with actionable messages instead of deep inside
    # from_pretrained() / .to().
    if not model_dir.is_dir():
        raise SystemExit(f"Checkpoint directory not found: {model_dir}")
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        raise SystemExit(
            "CUDA is not available; use --device to select another device."
        )

    # Imported lazily so that --help and argument errors stay fast and do not
    # require diffusers to be importable.
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        str(model_dir),
        local_files_only=True,
        custom_pipeline=str(model_dir / "pipeline.py"),
        # Executes the checkpoint's own pipeline.py: only use trusted checkpoints.
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    pipe.to(args.device)

    label = settings["class_label"]
    print(f"[{args.variant}] {label} -> {pipe.get_label_ids(label)}")

    generator = torch.Generator(device=args.device).manual_seed(settings["seed"])
    # NOTE(review): the human-readable label *string* is passed as class_labels.
    # This relies on the checkpoint's pipeline.py accepting names; the stock
    # diffusers DiTPipeline documents integer ids here. Confirm against pipeline.py.
    image = pipe(
        class_labels=label,
        height=256,
        width=256,
        num_inference_steps=settings["num_inference_steps"],
        guidance_scale=settings["guidance_scale"],
        generator=generator,
    ).images[0]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    print(f"Saved demo image to {output_path}")


if __name__ == "__main__":
    main()