DiCo-diffusers / demo_inference.py
BiliSakura's picture
Upload folder using huggingface_hub
28463c6 verified
Raw
History Blame Contribute Delete
2.33 kB
#!/usr/bin/env python3
"""Generate a demo image with DiCo class-conditional checkpoints."""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from diffusers import DiffusionPipeline
REPO_ROOT = Path(__file__).resolve().parent

# Only the checkpoint directory and the guidance scale differ between the
# variants; every other sampling setting is shared.
# Each entry: (variant key, checkpoint directory name, guidance scale).
_VARIANT_SPECS = (
    ("xl", "DiCo-XL-256", 1.4),
    ("s", "DiCo-S-256", 1.0),
    ("b", "DiCo-B-256", 1.0),
    ("l", "DiCo-L-256", 1.0),
)

# Sampling settings per variant, keyed by the value accepted by ``--variant``.
# Checkpoint directories are resolved relative to this script's folder.
VARIANTS = {
    key: {
        "dir": REPO_ROOT / dirname,
        "class_label": "golden retriever",
        "num_inference_steps": 250,
        "guidance_scale": guidance,
        "seed": 0,
    }
    for key, dirname, guidance in _VARIANT_SPECS
}
def parse_args() -> argparse.Namespace:
    """Read the demo's command-line options from ``sys.argv``.

    Returns:
        A namespace whose ``variant`` attribute is one of the keys of
        ``VARIANTS`` (``"xl"`` when the flag is omitted).
    """
    cli = argparse.ArgumentParser(description="Run DiCo demo inference.")
    variant_names = sorted(VARIANTS)
    cli.add_argument(
        "--variant",
        default="xl",
        choices=variant_names,
        help="Checkpoint variant to sample (default: xl).",
    )
    namespace = cli.parse_args()
    return namespace
def main() -> None:
    """Sample one demo image for the variant selected on the command line.

    Loads the variant's local checkpoint directory together with its bundled
    ``pipeline.py`` custom pipeline. Draws a single 256x256 image for the
    variant's class label and writes it to ``<checkpoint dir>/demo.png``.

    Runs on CUDA in bfloat16 when a GPU is available. Otherwise it falls back
    to CPU in float32 instead of failing on a hard-coded ``"cuda"`` device.

    Raises:
        SystemExit: if the variant's checkpoint directory does not exist.
    """
    args = parse_args()
    settings = VARIANTS[args.variant]
    model_dir = settings["dir"]
    output_path = model_dir / "demo.png"

    # Checkpoints are expected next to this script (see REPO_ROOT); report a
    # missing download clearly before attempting to load anything.
    if not model_dir.is_dir():
        raise SystemExit(f"Checkpoint directory not found: {model_dir}")

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Keep bfloat16 on the GPU path (original behaviour); float32 is the safe
    # default for CPU execution.
    dtype = torch.bfloat16 if use_cuda else torch.float32
    if not use_cuda:
        # A CPU generator produces a different random stream than a CUDA one,
        # so the CPU image will not match the GPU demo even with the same seed.
        print(f"[{args.variant}] CUDA not available; running on CPU (float32), this will be slow.")

    pipe = DiffusionPipeline.from_pretrained(
        str(model_dir),
        local_files_only=True,
        custom_pipeline=str(model_dir / "pipeline.py"),
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    pipe.to(device)

    label = settings["class_label"]
    # get_label_ids is presumably defined by the custom pipeline.py; it is used
    # here only to show which class id(s) the text label maps to.
    print(f"[{args.variant}] {label} -> {pipe.get_label_ids(label)}")

    generator = torch.Generator(device=device).manual_seed(settings["seed"])
    image = pipe(
        class_labels=label,
        height=256,
        width=256,
        num_inference_steps=settings["num_inference_steps"],
        guidance_scale=settings["guidance_scale"],
        generator=generator,
    ).images[0]
    image.save(output_path)
    print(f"Saved demo image to {output_path}")
if __name__ == "__main__":
    # Sample only when run as a script, never as a side effect of importing.
    main()