File size: 3,719 Bytes
f0eba3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python3
"""
Standalone inference script for the NiT-XL Diffusers checkpoint.

This script relies only on code vendored in this model repository: the
`pipeline.py` entrypoint and the NiT pipeline, transformer, and scheduler
classes under `custom_pipeline/`.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from diffusers import DiffusionPipeline


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for a single NiT-XL sampling run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        The parsed namespace. ``output`` is always a ``Path``. When
        ``--output`` is omitted it is derived from the sampling settings, so
        runs with different seeds/classes/steps do not overwrite each other;
        with all defaults it is ``demo_images/demo_sde250_class207_seed42.png``.

    Exits with status 2 (via ``parser.error``) when a numeric option is out
    of range.
    """
    parser = argparse.ArgumentParser(description="Run class-conditional NiT-XL inference.")
    parser.add_argument(
        "--model-dir",
        type=Path,
        default=Path(__file__).resolve().parent,
        help="Path to model repository root.",
    )
    parser.add_argument("--class-label", type=int, default=207, help="ImageNet class label to sample.")
    parser.add_argument("--height", type=int, default=512, help="Output image height.")
    parser.add_argument("--width", type=int, default=512, help="Output image width.")
    parser.add_argument("--steps", type=int, default=250, help="Number of inference steps.")
    parser.add_argument("--mode", choices=["ode", "sde"], default="sde", help="Sampling mode.")
    parser.add_argument("--guidance-scale", type=float, default=2.05, help="Classifier-free guidance scale.")
    parser.add_argument("--guidance-low", type=float, default=0.0, help="Guidance start timestep fraction.")
    parser.add_argument("--guidance-high", type=float, default=0.7, help="Guidance end timestep fraction.")
    parser.add_argument("--heun", action="store_true", help="Enable Heun correction for ODE mode.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help=(
            "Output image path relative to model dir, or absolute path. "
            "Defaults to demo_images/demo_<mode><steps>_class<label>_seed<seed>.png."
        ),
    )
    args = parser.parse_args(argv)

    # Reject values that cannot describe a valid sampling run before any
    # (expensive) model loading happens.
    if args.steps < 1:
        parser.error(f"--steps must be >= 1, got {args.steps}")
    if args.height < 1 or args.width < 1:
        parser.error(f"--height and --width must be positive, got {args.height}x{args.width}")
    if args.guidance_low > args.guidance_high:
        parser.error(
            f"--guidance-low ({args.guidance_low}) must not exceed "
            f"--guidance-high ({args.guidance_high})"
        )

    if args.output is None:
        # Encode the sampling settings in the filename; with the defaults this
        # reproduces the previous hard-coded name exactly.
        args.output = Path("demo_images") / (
            f"demo_{args.mode}{args.steps}_class{args.class_label}_seed{args.seed}.png"
        )
    return args


def resolve_output_path(model_dir: Path, output: Path) -> Path:
    """Return where the generated image should be written.

    Absolute ``output`` paths are used verbatim; relative ones are anchored
    at ``model_dir`` rather than the current working directory.
    """
    return output if output.is_absolute() else model_dir / output


def _check_model_dir(model_dir: Path) -> Path:
    """Verify the vendored pipeline code is present and return the entrypoint path.

    Raises:
        FileNotFoundError: if ``custom_pipeline/`` is not a directory or
            ``pipeline.py`` is not a file under ``model_dir``.
    """
    custom_dir = model_dir / "custom_pipeline"
    if not custom_dir.is_dir():
        raise FileNotFoundError(f"Missing custom pipeline dir: {custom_dir}")
    entrypoint = model_dir / "pipeline.py"
    if not entrypoint.is_file():
        raise FileNotFoundError(f"Missing custom entrypoint: {entrypoint}")
    return entrypoint


def _select_device_and_dtype() -> tuple[torch.device, torch.dtype]:
    """Pick the compute device and the dtype used for the transformer and VAE.

    CUDA is preferred when available. bfloat16 is chosen only on CUDA devices
    that report bf16 support; everything else uses float32.
    """
    if torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        return torch.device("cuda"), dtype
    return torch.device("cpu"), torch.float32


def main() -> None:
    """Load the local NiT-XL pipeline, sample one image, and save it to disk.

    The image path and the device/dtype used are printed on completion.

    Raises:
        FileNotFoundError: if the model directory lacks ``custom_pipeline/``
            or ``pipeline.py``.
    """
    args = parse_args()
    model_dir = args.model_dir.resolve()
    entrypoint = _check_model_dir(model_dir)

    device, torch_dtype = _select_device_and_dtype()
    # Seed a generator on the compute device; it is handed to the pipeline below.
    generator = torch.Generator(device=device.type).manual_seed(args.seed)

    # local_files_only=True: load strictly from model_dir, never from the Hub.
    pipe = DiffusionPipeline.from_pretrained(
        model_dir,
        custom_pipeline=str(entrypoint),
        local_files_only=True,
    ).to(device=device)
    if device.type == "cuda":
        # Only the transformer and VAE are cast; other components keep the
        # dtype they were loaded with.
        pipe.transformer.to(dtype=torch_dtype)
        pipe.vae.to(dtype=torch_dtype)

    output = pipe(
        class_labels=[args.class_label],
        height=args.height,
        width=args.width,
        num_inference_steps=args.steps,
        mode=args.mode,
        guidance_scale=args.guidance_scale,
        guidance_interval=(args.guidance_low, args.guidance_high),
        heun=args.heun,
        generator=generator,
        output_type="pil",
    )

    output_path = resolve_output_path(model_dir, args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output.images[0].save(output_path)

    print(f"Saved image to: {output_path}")
    print(f"Device: {device} | dtype: {torch_dtype}")


if __name__ == "__main__":
    # Run inference only when executed as a script; importing this module
    # just defines the functions above.
    main()