#!/usr/bin/env python3
"""Generate demo images for PixelDiT class-conditional checkpoints."""

from __future__ import annotations

import argparse
from pathlib import Path

import torch

REPO_ROOT = Path(__file__).resolve().parent

# Per-resolution sampling settings. Each "dir" is a local checkpoint directory
# next to this script; run_variant() loads it (including its own pipeline.py)
# and writes demo.png into it.
VARIANTS = {
    "256": {
        "dir": REPO_ROOT / "PixelDiT-XL-16-256",
        "height": 256,
        "width": 256,
        "num_inference_steps": 100,
        "guidance_scale": 3.25,
        "class_label": "golden retriever",
        "seed": 7,
    },
    "512": {
        "dir": REPO_ROOT / "PixelDiT-XL-16-512",
        "height": 512,
        "width": 512,
        "num_inference_steps": 100,
        "guidance_scale": 3.75,
        "class_label": "golden retriever",
        "seed": 7,
    },
}


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behavior.

    Returns:
        Namespace with ``variant`` (one of the VARIANTS keys) and ``all``
        (bool; when set, every variant is generated and ``variant`` is ignored).
    """
    parser = argparse.ArgumentParser(description="Run PixelDiT demo inference.")
    parser.add_argument(
        "--variant",
        choices=sorted(VARIANTS),
        default="256",
        help="Checkpoint resolution variant to sample.",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Generate demo.png for every supported variant.",
    )
    return parser.parse_args(argv)


def run_variant(name: str) -> Path:
    """Sample one demo image for checkpoint variant *name* and save it.

    The image is written to ``<checkpoint dir>/demo.png``.

    Args:
        name: Key into VARIANTS (e.g. "256" or "512").

    Returns:
        Path of the saved image.

    Raises:
        KeyError: if *name* is not a key of VARIANTS.
        FileNotFoundError: if the checkpoint's ``pipeline.py`` is missing.
    """
    settings = VARIANTS[name]
    model_dir = settings["dir"]
    pipeline_file = model_dir / "pipeline.py"
    output_path = model_dir / "demo.png"

    # Fail fast with an actionable message rather than an opaque loader error,
    # and before paying for the heavy diffusers import below.
    if not pipeline_file.is_file():
        raise FileNotFoundError(
            f"[{name}] custom pipeline not found at {pipeline_file}; "
            f"expected a local checkpoint in {model_dir}"
        )

    # Imported lazily so `--help` and argument errors stay fast.
    from diffusers import DiffusionPipeline

    # local_files_only: never hit the network; the checkpoint ships its own
    # pipeline.py, which is loaded as the pipeline implementation.
    pipe = DiffusionPipeline.from_pretrained(
        str(model_dir),
        local_files_only=True,
        custom_pipeline=str(pipeline_file),
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    # NOTE(review): get_label_ids() and scheduler.config.shift come from the
    # checkpoint's custom pipeline/scheduler config — not verifiable here.
    print(f"[{name}] {settings['class_label']} -> {pipe.get_label_ids(settings['class_label'])}")
    print(f"[{name}] scheduler shift={pipe.scheduler.config.shift}")

    # Seeded generator on the same device as the pipeline for reproducible output.
    generator = torch.Generator(device="cuda").manual_seed(settings["seed"])
    # The class label is passed as a name string; presumably the custom
    # pipeline maps it to class ids (as get_label_ids above suggests).
    image = pipe(
        class_labels=settings["class_label"],
        height=settings["height"],
        width=settings["width"],
        num_inference_steps=settings["num_inference_steps"],
        guidance_scale=settings["guidance_scale"],
        guidance_interval_min=0.1,
        guidance_interval_max=1.0,
        generator=generator,
    ).images[0]
    image.save(output_path)
    print(f"[{name}] Saved demo image to {output_path}")
    return output_path


def main(argv: list[str] | None = None) -> None:
    """CLI entry point: generate the selected variant, or all with --all.

    Args:
        argv: Optional argument list; ``None`` reads ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    names = list(VARIANTS) if args.all else [args.variant]
    for name in names:
        run_variant(name)


if __name__ == "__main__":
    main()