BiliSakura commited on
Commit
9392475
·
verified ·
1 Parent(s): b590587

Delete demo_inference.py

Browse files
Files changed (1) hide show
  1. demo_inference.py +0 -161
demo_inference.py DELETED
@@ -1,161 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Smoke-test MVSplit-DiT inference from the converted Diffusers Hub folder."""
3
-
4
- from __future__ import annotations
5
-
6
- import argparse
7
- import importlib.util
8
- import sys
9
- from pathlib import Path
10
-
11
- import torch
12
- from diffusers import AutoencoderKLFlux2
13
- from transformers import AutoModel, AutoTokenizer
14
-
15
-
16
- def parse_args() -> argparse.Namespace:
17
- parser = argparse.ArgumentParser(description="Run MVSplit-DiT inference.")
18
- parser.add_argument(
19
- "--model",
20
- type=Path,
21
- default=Path(__file__).resolve().parent,
22
- help="Path to MVSplit-DiT-1000L pipeline directory.",
23
- )
24
- parser.add_argument(
25
- "--prompt",
26
- type=str,
27
- default="a red panda climbing a bamboo stalk",
28
- help="Text prompt for generation.",
29
- )
30
- parser.add_argument("--height", type=int, default=256)
31
- parser.add_argument("--width", type=int, default=256)
32
- parser.add_argument("--num-inference-steps", type=int, default=35)
33
- parser.add_argument("--guidance-scale", type=float, default=2.0)
34
- parser.add_argument("--time-shift-alpha", type=float, default=4.0)
35
- parser.add_argument("--seed", type=int, default=42)
36
- parser.add_argument(
37
- "--output",
38
- type=Path,
39
- default=Path(__file__).resolve().parent / "demo.png",
40
- help="Output image path. Ignored when --output-type=latent.",
41
- )
42
- parser.add_argument(
43
- "--output-type",
44
- choices=("pil", "latent"),
45
- default="pil",
46
- help="Return decoded image or raw latents.",
47
- )
48
- parser.add_argument(
49
- "--skip-vae",
50
- action="store_true",
51
- help="Skip VAE decode even when output-type=pil (saves memory).",
52
- )
53
- parser.add_argument(
54
- "--device",
55
- choices=("auto", "cuda", "cpu"),
56
- default="auto",
57
- help="Execution device. auto prefers CUDA when available.",
58
- )
59
- parser.add_argument(
60
- "--cpu-offload",
61
- action="store_true",
62
- help="Use sequential CPU offload instead of keeping the pipeline on GPU.",
63
- )
64
- return parser.parse_args()
65
-
66
-
67
- def _resolve_device(choice: str) -> torch.device:
68
- if choice == "auto":
69
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
- return torch.device(choice)
71
-
72
-
73
- def _load_pipeline_class(model_dir: Path):
74
- transformer_path = model_dir / "transformer" / "transformer_mvsplit_dit.py"
75
- spec = importlib.util.spec_from_file_location("transformer_mvsplit_dit", transformer_path)
76
- module = importlib.util.module_from_spec(spec)
77
- sys.modules[spec.name] = module
78
- spec.loader.exec_module(module)
79
-
80
- pipe_spec = importlib.util.spec_from_file_location("mvsplit_pipeline", model_dir / "pipeline.py")
81
- pipe_module = importlib.util.module_from_spec(pipe_spec)
82
- sys.modules[pipe_spec.name] = pipe_module
83
- pipe_spec.loader.exec_module(pipe_module)
84
- return module.MVSplitDiTTransformer2DModel, pipe_module.MVSplitDiTPipeline
85
-
86
-
87
- def main() -> None:
88
- args = parse_args()
89
- model_dir = args.model.resolve()
90
- device = _resolve_device(args.device)
91
- transformer_cls, pipeline_cls = _load_pipeline_class(model_dir)
92
-
93
- print(f"Loading components on {device}...", flush=True)
94
- transformer = transformer_cls.from_pretrained(
95
- model_dir / "transformer",
96
- torch_dtype=torch.bfloat16,
97
- local_files_only=True,
98
- )
99
- tokenizer = AutoTokenizer.from_pretrained(model_dir / "tokenizer", local_files_only=True)
100
- text_encoder = AutoModel.from_pretrained(
101
- model_dir / "text_encoder",
102
- torch_dtype=torch.bfloat16,
103
- local_files_only=True,
104
- )
105
-
106
- vae = None
107
- if not args.skip_vae and args.output_type == "pil":
108
- vae = AutoencoderKLFlux2.from_pretrained(
109
- model_dir / "vae",
110
- torch_dtype=torch.bfloat16,
111
- local_files_only=True,
112
- )
113
-
114
- pipe = pipeline_cls(
115
- transformer=transformer,
116
- scheduler=None,
117
- vae=vae,
118
- text_encoder=text_encoder,
119
- tokenizer=tokenizer,
120
- time_shift_alpha=args.time_shift_alpha,
121
- )
122
- if args.cpu_offload and device.type == "cuda":
123
- pipe.enable_sequential_cpu_offload(gpu_id=device.index or 0)
124
- else:
125
- pipe.to(device)
126
-
127
- print(
128
- f"Running inference ({args.num_inference_steps} steps, {args.height}x{args.width})...",
129
- flush=True,
130
- )
131
- generator_device = "cpu" if args.cpu_offload else device.type
132
- generator = torch.Generator(device=generator_device).manual_seed(args.seed)
133
- result = pipe(
134
- prompt=args.prompt,
135
- height=args.height,
136
- width=args.width,
137
- num_inference_steps=args.num_inference_steps,
138
- guidance_scale=args.guidance_scale,
139
- generator=generator,
140
- output_type=args.output_type,
141
- )
142
-
143
- if args.output_type == "latent":
144
- latents = result.images
145
- print(f"latent shape={tuple(latents.shape)} dtype={latents.dtype}")
146
- print(
147
- "latent stats:",
148
- f"min={float(latents.min()):.4f}",
149
- f"max={float(latents.max()):.4f}",
150
- f"mean={float(latents.mean()):.4f}",
151
- )
152
- return
153
-
154
- image = result.images[0]
155
- args.output.parent.mkdir(parents=True, exist_ok=True)
156
- image.save(args.output)
157
- print(f"Saved image to {args.output}")
158
-
159
-
160
- if __name__ == "__main__":
161
- main()