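# Source: PixelDiT-diffusers / demo_inference.py
# Uploaded by BiliSakura ("Upload folder using huggingface_hub"), commit fbad450 (verified).
# File size: 2.55 kB.
# NOTE: the lines above are Hugging Face Hub page metadata, kept as comments; they are not part of the script.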
#!/usr/bin/env python3
"""Generate demo images for PixelDiT class-conditional checkpoints."""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from diffusers import DiffusionPipeline
# Directory that holds this script; checkpoint folders are resolved against it.
REPO_ROOT = Path(__file__).resolve().parent


def _preset(dir_name: str, side: int, guidance_scale: float) -> dict[str, object]:
    """Build the demo settings for one square checkpoint.

    Args:
        dir_name: Checkpoint folder name, relative to ``REPO_ROOT``.
        side: Value used for both ``height`` and ``width``.
        guidance_scale: Guidance scale used for this checkpoint.

    Returns:
        Settings mapping with the keys read by ``run_variant``.
    """
    return {
        "dir": REPO_ROOT / dir_name,
        "height": side,
        "width": side,
        "num_inference_steps": 100,
        "guidance_scale": guidance_scale,
        "class_label": "golden retriever",
        "seed": 7,
    }


# Sampling presets keyed by the variant name accepted on the command line.
VARIANTS = {
    "256": _preset("PixelDiT-XL-16-256", 256, guidance_scale=3.25),
    "512": _preset("PixelDiT-XL-16-512", 512, guidance_scale=3.75),
}
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the demo script.

    Returns:
        Namespace with ``variant`` (one of the ``VARIANTS`` keys, defaulting
        to ``"256"``) and ``all`` (``True`` when ``--all`` was given).
    """
    cli = argparse.ArgumentParser(description="Run PixelDiT demo inference.")
    variant_names = sorted(VARIANTS)
    cli.add_argument(
        "--variant",
        default="256",
        choices=variant_names,
        help="Checkpoint resolution variant to sample.",
    )
    cli.add_argument(
        "--all",
        help="Generate demo.png for every supported variant.",
        action="store_true",
    )
    namespace = cli.parse_args()
    return namespace
def run_variant(name: str) -> Path:
    """Sample one demo image for variant *name* and save it as ``demo.png``.

    Loads the pipeline from the variant's checkpoint directory, moves it to
    CUDA, prints the label lookup and the scheduler shift, runs one seeded
    generation and writes the first returned image into the checkpoint
    directory.

    Args:
        name: Key into ``VARIANTS`` selecting checkpoint and sampling settings.

    Returns:
        Path of the saved ``demo.png``.

    Raises:
        KeyError: If *name* is not a key of ``VARIANTS``.
    """
    cfg = VARIANTS[name]
    checkpoint_dir = cfg["dir"]
    label = cfg["class_label"]
    device = "cuda"

    # The custom pipeline is the pipeline.py stored inside the checkpoint
    # directory itself, so everything is loaded from local files.
    pipeline = DiffusionPipeline.from_pretrained(
        str(checkpoint_dir),
        local_files_only=True,
        custom_pipeline=str(checkpoint_dir / "pipeline.py"),
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to(device)

    label_ids = pipeline.get_label_ids(label)
    print(f"[{name}] {label} -> {label_ids}")
    shift = pipeline.scheduler.config.shift
    print(f"[{name}] scheduler shift={shift}")

    # Seed a generator on the sampling device and pass it to the pipeline.
    rng = torch.Generator(device=device)
    rng.manual_seed(cfg["seed"])

    sampling_kwargs = {
        "class_labels": label,
        "height": cfg["height"],
        "width": cfg["width"],
        "num_inference_steps": cfg["num_inference_steps"],
        "guidance_scale": cfg["guidance_scale"],
        "guidance_interval_min": 0.1,
        "guidance_interval_max": 1.0,
        "generator": rng,
    }
    result = pipeline(**sampling_kwargs)
    first_image = result.images[0]

    destination = checkpoint_dir / "demo.png"
    first_image.save(destination)
    print(f"[{name}] Saved demo image to {destination}")
    return destination
def main() -> None:
    """Run the demo for every variant with ``--all``, else for ``--variant`` only."""
    options = parse_args()
    # --all takes precedence over --variant; variants run in VARIANTS order.
    selected = list(VARIANTS) if options.all else [options.variant]
    for variant_name in selected:
        run_variant(variant_name)


# Execute the demo only when run as a script, not when imported.
if __name__ == "__main__":
    main()