JiT-diffusers / demo_inference.py
BiliSakura's picture
Upload folder using huggingface_hub
5673750 verified
#!/usr/bin/env python3
"""Generate a demo image with JiT-H-32."""
from pathlib import Path
import torch
from diffusers import DiffusionPipeline, FlowMatchHeunDiscreteScheduler
# All demo paths are anchored to the directory holding this script
# (symlinks resolved), so the demo works regardless of the current
# working directory.
REPO_ROOT = Path(__file__).resolve().parent
# Local checkpoint folder; it also ships the custom pipeline.py loaded in main().
MODEL_DIR = REPO_ROOT.joinpath("JiT-H-32")
# Destination of the generated demo image.
OUTPUT_PATH = REPO_ROOT.joinpath("demo.png")
def select_device(requested: "str | None" = None) -> str:
    """Return the torch device string the demo should run on.

    Args:
        requested: Explicit device such as ``"cuda"``, ``"cuda:1"`` or
            ``"cpu"``. It is returned unchanged when given.

    Returns:
        ``requested`` if not ``None``. Otherwise ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    if requested is not None:
        return requested
    return "cuda" if torch.cuda.is_available() else "cpu"


def main(
    *,
    model_dir: "Path | str | None" = None,
    output_path: "Path | str | None" = None,
    class_label: str = "golden retriever",
    num_inference_steps: int = 50,
    guidance_scale: float = 2.3,
    seed: int = 42,
    device: "str | None" = None,
) -> None:
    """Generate one class-conditional demo image with JiT-H-32 and save it.

    The function loads the pipeline from ``model_dir`` using the checkpoint's
    own ``pipeline.py``. It swaps in a ``FlowMatchHeunDiscreteScheduler``
    built from the existing scheduler config with ``shift=4.0``, samples a
    single image and writes it to ``output_path``. Calling ``main()`` with no
    arguments reproduces the original demo; on a machine without CUDA it
    now runs on CPU instead of crashing.

    Args:
        model_dir: Local checkpoint directory. Defaults to ``MODEL_DIR``.
        output_path: File to save the image to. Defaults to ``OUTPUT_PATH``.
        class_label: Label passed to the pipeline as ``class_labels``.
        num_inference_steps: Number of sampler steps.
        guidance_scale: Guidance scale forwarded to the pipeline call.
        seed: Seed for the sampling ``torch.Generator``.
        device: Torch device. ``None`` picks CUDA when available, else CPU.

    Raises:
        FileNotFoundError: If ``model_dir`` is not an existing directory.
    """
    # Resolve defaults at call time so the module-level constants can still
    # be overridden before main() runs.
    model_dir = Path(MODEL_DIR if model_dir is None else model_dir)
    output_path = Path(OUTPUT_PATH if output_path is None else output_path)
    if not model_dir.is_dir():
        # Fail early with a clear message. A missing local path would
        # otherwise likely surface as a less obvious loading/Hub error.
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    device = select_device(device)

    pipe = DiffusionPipeline.from_pretrained(
        str(model_dir),
        # The pipeline class is loaded from the checkpoint's own pipeline.py.
        custom_pipeline=str(model_dir / "pipeline.py"),
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    # Use Heun sampling with the checkpoint's scheduler config, but override
    # the shift value.
    pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(
        pipe.scheduler.config, shift=4.0
    )
    pipe.to(device)
    pipe.set_progress_bar_config(disable=False)

    # Show the label lookups the custom pipeline exposes. Index 207 is
    # presumably "golden retriever" in an ImageNet-1k label map (confirm
    # against the checkpoint).
    print(pipe.id2label[207])
    print(pipe.get_label_ids(class_label))

    # Seed on the chosen device so GPU runs match the original demo's
    # CUDA generator.
    generator = torch.Generator(device=device).manual_seed(seed)
    image = pipe(
        class_labels=class_label,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    # A caller-supplied output path may point into a not-yet-existing folder.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    print(f"Saved demo image to {output_path}")


if __name__ == "__main__":
    main()