BiliSakura commited on
Commit
7e723ae
·
verified ·
1 Parent(s): deb2428

Delete test_inference.py

Browse files
Files changed (1) hide show
  1. test_inference.py +0 -96
test_inference.py DELETED
@@ -1,96 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Standalone inference script for the NiT-XL Diffusers checkpoint.
4
-
5
- This script only uses code vendored in this model repository:
6
- `custom_pipeline/` for NiT pipeline, transformer, and scheduler classes.
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import argparse
12
- from pathlib import Path
13
-
14
- import torch
15
- from diffusers import DiffusionPipeline
16
-
17
-
18
- def parse_args() -> argparse.Namespace:
19
- parser = argparse.ArgumentParser(description="Run class-conditional NiT-XL inference.")
20
- parser.add_argument(
21
- "--model-dir",
22
- type=Path,
23
- default=Path(__file__).resolve().parent,
24
- help="Path to model repository root.",
25
- )
26
- parser.add_argument("--class-label", type=int, default=207, help="ImageNet class label to sample.")
27
- parser.add_argument("--height", type=int, default=512, help="Output image height.")
28
- parser.add_argument("--width", type=int, default=512, help="Output image width.")
29
- parser.add_argument("--steps", type=int, default=250, help="Number of inference steps.")
30
- parser.add_argument("--mode", choices=["ode", "sde"], default="sde", help="Sampling mode.")
31
- parser.add_argument("--guidance-scale", type=float, default=2.05, help="Classifier-free guidance scale.")
32
- parser.add_argument("--guidance-low", type=float, default=0.0, help="Guidance start timestep fraction.")
33
- parser.add_argument("--guidance-high", type=float, default=0.7, help="Guidance end timestep fraction.")
34
- parser.add_argument("--heun", action="store_true", help="Enable Heun correction for ODE mode.")
35
- parser.add_argument("--seed", type=int, default=42, help="Random seed.")
36
- parser.add_argument(
37
- "--output",
38
- type=Path,
39
- default=Path("demo_images/demo_sde250_class207_seed42.png"),
40
- help="Output image path relative to model dir, or absolute path.",
41
- )
42
- return parser.parse_args()
43
-
44
-
45
- def resolve_output_path(model_dir: Path, output: Path) -> Path:
46
- if output.is_absolute():
47
- return output
48
- return model_dir / output
49
-
50
-
51
- def main() -> None:
52
- args = parse_args()
53
- model_dir = args.model_dir.resolve()
54
- custom_dir = model_dir / "custom_pipeline"
55
- if not custom_dir.exists():
56
- raise FileNotFoundError(f"Missing custom pipeline dir: {custom_dir}")
57
- if not (model_dir / "pipeline.py").exists():
58
- raise FileNotFoundError(f"Missing custom entrypoint: {model_dir / 'pipeline.py'}")
59
-
60
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
- torch_dtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
62
- generator_device = device.type if device.type != "cpu" else "cpu"
63
- generator = torch.Generator(device=generator_device).manual_seed(args.seed)
64
-
65
- pipe = DiffusionPipeline.from_pretrained(
66
- model_dir,
67
- custom_pipeline=str(model_dir / "pipeline.py"),
68
- local_files_only=True,
69
- ).to(device=device)
70
- if device.type == "cuda":
71
- pipe.transformer.to(dtype=torch_dtype)
72
- pipe.vae.to(dtype=torch_dtype)
73
-
74
- output = pipe(
75
- class_labels=[args.class_label],
76
- height=args.height,
77
- width=args.width,
78
- num_inference_steps=args.steps,
79
- mode=args.mode,
80
- guidance_scale=args.guidance_scale,
81
- guidance_interval=(args.guidance_low, args.guidance_high),
82
- heun=args.heun,
83
- generator=generator,
84
- output_type="pil",
85
- )
86
-
87
- output_path = resolve_output_path(model_dir, args.output)
88
- output_path.parent.mkdir(parents=True, exist_ok=True)
89
- output.images[0].save(output_path)
90
-
91
- print(f"Saved image to: {output_path}")
92
- print(f"Device: {device} | dtype: {torch_dtype}")
93
-
94
-
95
- if __name__ == "__main__":
96
- main()