"""
Standalone inference script for the NiT-XL Diffusers checkpoint.

This script only uses code vendored in this model repository:
`custom_pipeline/` for NiT pipeline, transformer, and scheduler classes.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from diffusers import DiffusionPipeline
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for class-conditional NiT-XL sampling.

    Args:
        argv: Optional explicit argument list. When ``None`` (the default,
            preserving the original call signature) argparse falls back to
            ``sys.argv[1:]``; passing a list makes the parser unit-testable.

    Returns:
        The parsed options as an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Run class-conditional NiT-XL inference.")
    parser.add_argument(
        "--model-dir",
        type=Path,
        default=Path(__file__).resolve().parent,
        help="Path to model repository root.",
    )
    parser.add_argument("--class-label", type=int, default=207, help="ImageNet class label to sample.")
    parser.add_argument("--height", type=int, default=512, help="Output image height.")
    parser.add_argument("--width", type=int, default=512, help="Output image width.")
    parser.add_argument("--steps", type=int, default=250, help="Number of inference steps.")
    parser.add_argument("--mode", choices=["ode", "sde"], default="sde", help="Sampling mode.")
    parser.add_argument("--guidance-scale", type=float, default=2.05, help="Classifier-free guidance scale.")
    parser.add_argument("--guidance-low", type=float, default=0.0, help="Guidance start timestep fraction.")
    parser.add_argument("--guidance-high", type=float, default=0.7, help="Guidance end timestep fraction.")
    parser.add_argument("--heun", action="store_true", help="Enable Heun correction for ODE mode.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("demo_images/demo_sde250_class207_seed42.png"),
        help="Output image path relative to model dir, or absolute path.",
    )
    return parser.parse_args(argv)
|
|
|
|
def resolve_output_path(model_dir: Path, output: Path) -> Path:
    """Return *output* as-is when absolute, otherwise anchor it at *model_dir*."""
    return output if output.is_absolute() else model_dir / output
|
|
|
|
def main() -> None:
    """Load the vendored NiT pipeline, sample one image, and save it to disk.

    Raises:
        FileNotFoundError: If the vendored ``custom_pipeline/`` directory or
            the ``pipeline.py`` entrypoint is missing under the model dir.
    """
    args = parse_args()
    model_dir = args.model_dir.resolve()
    custom_dir = model_dir / "custom_pipeline"
    if not custom_dir.exists():
        raise FileNotFoundError(f"Missing custom pipeline dir: {custom_dir}")
    if not (model_dir / "pipeline.py").exists():
        raise FileNotFoundError(f"Missing custom entrypoint: {model_dir / 'pipeline.py'}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use bf16 only when the GPU actually supports it; otherwise stay in fp32.
    torch_dtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
    # Seeded generator on the sampling device. (The original
    # `device.type if device.type != "cpu" else "cpu"` was a no-op
    # conditional — it always evaluates to `device.type`.)
    generator = torch.Generator(device=device.type).manual_seed(args.seed)

    pipe = DiffusionPipeline.from_pretrained(
        model_dir,
        custom_pipeline=str(model_dir / "pipeline.py"),
        local_files_only=True,  # everything needed is vendored in this repo
    ).to(device=device)
    if device.type == "cuda":
        # Cast only on CUDA; CPU inference stays in the default fp32.
        pipe.transformer.to(dtype=torch_dtype)
        pipe.vae.to(dtype=torch_dtype)

    output = pipe(
        class_labels=[args.class_label],
        height=args.height,
        width=args.width,
        num_inference_steps=args.steps,
        mode=args.mode,
        guidance_scale=args.guidance_scale,
        guidance_interval=(args.guidance_low, args.guidance_high),
        heun=args.heun,  # Heun correction; per the CLI help, ODE mode only
        generator=generator,
        output_type="pil",
    )

    output_path = resolve_output_path(model_dir, args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output.images[0].save(output_path)

    print(f"Saved image to: {output_path}")
    print(f"Device: {device} | dtype: {torch_dtype}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|