BiliSakura committed on
Commit
dbc7cc8
·
verified ·
1 Parent(s): 05d8082

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -388
pipeline.py DELETED
@@ -1,388 +0,0 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
-
6
- """Hub custom pipeline: ADMPipeline.
7
-
8
- Load with native Hugging Face diffusers and `trust_remote_code=True`.
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- import importlib
14
- import sys
15
- from dataclasses import dataclass
16
- from pathlib import Path
17
- from typing import List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import torch
21
- from tqdm.auto import tqdm
22
-
23
- from diffusers.image_processor import VaeImageProcessor
24
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
- from diffusers.utils import BaseOutput, replace_example_docstring
26
- from diffusers.utils.torch_utils import randn_tensor
27
-
28
-
29
# Usage example for the pipeline. It is passed to `replace_example_docstring` on
# `ADMPipeline.__call__`, which substitutes it for the `Examples:` placeholder line
# in that method's docstring.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline

        >>> from pipeline import ADMPipeline

        >>> pipe = ADMPipeline.from_pretrained("./ADM-G-512", torch_dtype=torch.float16)
        >>> pipe.to("cuda")

        >>> # ADM-G (classifier guidance)
        >>> images = pipe(class_labels=207, classifier_guidance_scale=1.0, num_inference_steps=250).images
        ```
"""
44
-
45
-
46
@dataclass
class ADMPipelineOutput(BaseOutput):
    """
    Result container returned by the ADM pipeline when `return_dict=True`.

    Args:
        images (`torch.Tensor` or `list[PIL.Image.Image]` or `np.ndarray`):
            The generated batch. A tensor of shape `(batch_size, num_channels, height, width)` when
            `output_type="pt"`; a list of PIL images or a NumPy array for the post-processed formats.
    """

    images: Union[torch.Tensor, List, np.ndarray]
58
-
59
-
60
class ADMPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation with ADM (Ablated Diffusion Model).

    Supports class-conditional ADM (labels embedded in the UNet) and **ADM-G** (unconditional UNet + noisy
    classifier guidance). For ADM-G, pass `classifier_guidance_scale > 0` and provide `class_labels`; the
    optional `classifier` predicts `p(y | x_t)` and steers sampling.

    Args:
        unet ([`ADMUNet2DModel`]):
            A UNet model to denoise image samples (typically unconditional for ADM-G).
        scheduler ([`ADMScheduler`]):
            A scheduler used with the UNet to denoise image samples.
        classifier ([`ADMClassifierModel`], *optional*):
            Noisy ImageNet classifier for ADM-G guidance.
    """

    model_cpu_offload_seq = "classifier->unet"
    _optional_components = ["classifier"]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load a variant folder (e.g. `./ADM-G-512`) with `unet/`, `scheduler/`, `classifier/` subfolders.

        Each component lives in its own subfolder next to the Python module that defines its class
        (`unet_adm.py`, `scheduling_adm.py`, `classifier_adm.py`). The subfolder is temporarily put on
        `sys.path` so that module can be imported, and the class' own `from_pretrained` is then called.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Path to the variant folder. A relative path is resolved against the directory that
                contains this file, not against the current working directory.
            **kwargs:
                Forwarded unchanged to the `from_pretrained` call of every component that is loaded.

        Returns:
            `ADMPipeline`: A pipeline built from the components that were found. `classifier` is `None`
            when its folder is absent or incomplete; `unet` and `scheduler` can also be `None` in that
            case, and `__call__` then raises a `ValueError`.

        Raises:
            ValueError: If neither a UNet nor a classifier could be loaded from the variant folder.
        """
        repo_root = Path(__file__).resolve().parent
        variant = Path(pretrained_model_name_or_path)
        if not variant.is_absolute():
            variant = (repo_root / variant).resolve()

        model_kwargs = dict(kwargs)
        # Every entry added to `sys.path` is recorded here so it can be removed again in `finally`.
        inserted: List[str] = []

        def _load_component(folder: str, module_name: str, class_name: str):
            """Import `class_name` from `<variant>/<folder>/<module_name>.py` and load it, or return `None`."""
            comp_dir = variant / folder
            module_path = comp_dir / f"{module_name}.py"
            has_config = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
            if not module_path.exists() or not has_config:
                return None

            comp_path = str(comp_dir)
            if comp_path not in sys.path:
                sys.path.insert(0, comp_path)
                inserted.append(comp_path)

            module = importlib.import_module(module_name)
            component_cls = getattr(module, class_name)
            return component_cls.from_pretrained(str(comp_dir), **model_kwargs)

        try:
            unet = _load_component("unet", "unet_adm", "ADMUNet2DModel")
            scheduler = _load_component("scheduler", "scheduling_adm", "ADMScheduler")
            classifier = _load_component("classifier", "classifier_adm", "ADMClassifierModel")

            if scheduler is None:
                # No scheduler config on disk: fall back to a default-constructed scheduler as long as
                # the module that defines it is present.
                sched_dir = variant / "scheduler"
                if (sched_dir / "scheduling_adm.py").exists():
                    sched_path = str(sched_dir)
                    if sched_path not in sys.path:
                        sys.path.insert(0, sched_path)
                        inserted.append(sched_path)
                    scheduler = importlib.import_module("scheduling_adm").ADMScheduler()

            if unet is None and classifier is None:
                raise ValueError(f"No loadable components found under {variant}")

            return cls(unet=unet, scheduler=scheduler, classifier=classifier)
        finally:
            # Undo only the `sys.path` entries added above, whether loading succeeded or not.
            for comp_path in inserted:
                if comp_path in sys.path:
                    sys.path.remove(comp_path)

    def __init__(
        self,
        unet,
        scheduler,
        classifier=None,
    ):
        """Register the components and create the image post-processor.

        Args:
            unet: Denoising UNet.
            scheduler: Scheduler driving the sampling loop.
            classifier: Optional noisy classifier used for ADM-G guidance.
        """
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler, classifier=classifier)
        self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)

    @property
    def do_classifier_guidance(self) -> bool:
        """Whether a classifier is loaded and the scale stored by the last `__call__` is positive."""
        return self.classifier is not None and getattr(self, "_classifier_guidance_scale", 0.0) > 0

    def check_inputs(
        self,
        class_labels: Optional[Union[int, List[int], torch.Tensor]],
        height: Optional[int],
        width: Optional[int],
    ):
        """Validate the label and size arguments of `__call__`.

        Args:
            class_labels: Labels requested by the caller, or `None`.
            height: Requested image height, or `None`.
            width: Requested image width, or `None`.

        Raises:
            ValueError: If labels are missing for a class-conditional UNet, if labels are given but
                neither the UNet nor a classifier can consume them, or if `height` / `width` is not
                divisible by 8.
        """
        if class_labels is None and self.unet.config.class_cond:
            raise ValueError("`class_labels` are required for class-conditional ADM checkpoints.")

        if class_labels is not None and self.classifier is None and not self.unet.config.class_cond:
            raise ValueError(
                "This checkpoint is unconditional and has no classifier. Load an ADM-G repo with a "
                "`classifier/` subfolder, or use a class-conditional UNet."
            )

        if height is not None and height % 8 != 0:
            raise ValueError(f"`height` must be divisible by 8 but is {height}.")
        if width is not None and width % 8 != 0:
            raise ValueError(f"`width` must be divisible by 8 but is {width}.")

    @staticmethod
    def _to_label_tensor(class_labels, device: torch.device) -> torch.Tensor:
        """Convert labels (int, sequence, array or tensor) to a 1-D `torch.long` tensor on `device`."""
        if torch.is_tensor(class_labels):
            labels = class_labels.to(device=device, dtype=torch.long)
        else:
            labels = torch.tensor(class_labels, device=device, dtype=torch.long)
        # A bare scalar (Python int, NumPy integer, 0-dim tensor) stands for a batch of one image.
        if labels.ndim == 0:
            labels = labels.reshape(1)
        return labels

    def _prepare_class_labels(
        self,
        class_labels: Optional[Union[int, List[int], torch.Tensor]],
        batch_size: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Normalize `class_labels` to a long tensor on `device` and check it against `batch_size`.

        Args:
            class_labels: A single label, a sequence of labels, a tensor, or `None`.
            batch_size: Number of images that will be generated.
            device: Device the label tensor is placed on.

        Returns:
            `None` when `class_labels` is `None`, otherwise a 1-D `torch.long` tensor.

        Raises:
            ValueError: If the number of labels differs from `batch_size`.
        """
        if class_labels is None:
            return None

        labels = self._to_label_tensor(class_labels, device)
        if labels.shape[0] != batch_size:
            raise ValueError(
                f"`class_labels` batch ({labels.shape[0]}) must match requested batch size ({batch_size})."
            )
        return labels

    def _get_classifier_grad(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        classifier_scale: float,
    ) -> torch.Tensor:
        """Return the guidance term computed by `classifier.guidance_gradient` for the current sample.

        Args:
            sample: Current noisy sample.
            timestep: Per-sample timesteps (unscaled scheduler timesteps).
            class_labels: Target labels for guidance.
            classifier_scale: Guidance strength forwarded as `classifier_scale`.
        """
        # NOTE(review): `__call__` runs under `torch.no_grad()`; presumably `guidance_gradient`
        # re-enables autograd internally to differentiate w.r.t. `sample` -- confirm in classifier_adm.
        return self.classifier.guidance_gradient(
            sample,
            timestep,
            class_labels,
            classifier_scale=classifier_scale,
        )

    def prepare_latents(
        self,
        batch_size: int,
        num_channels: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Prepare initial Gaussian noise for pixel-space sampling.

        Args:
            batch_size (`int`):
                Number of images to generate.
            num_channels (`int`):
                Number of image channels (typically 3).
            height (`int`):
                Image height in pixels.
            width (`int`):
                Image width in pixels.
            dtype (`torch.dtype`):
                Data type for the latent tensor.
            device (`torch.device`):
                Target device.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                RNG for deterministic sampling.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noise tensor. It is only moved to `device` / cast to `dtype`; its shape is
                not checked against the requested one.

        Returns:
            `torch.Tensor`:
                Freshly sampled noise of shape `(batch_size, num_channels, height, width)`, or the
                supplied `latents` on the requested device and dtype.
        """
        shape = (batch_size, num_channels, height, width)
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)
        return latents

    # `replace_example_docstring` swaps the `Examples:` placeholder line in the docstring below for
    # `EXAMPLE_DOC_STRING`, so that line has to stay in place.
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        class_labels: Optional[Union[int, List[int], torch.Tensor]] = None,
        batch_size: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 250,
        use_ddim: bool = False,
        eta: float = 0.0,
        clip_denoised: bool = True,
        classifier_guidance_scale: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ADMPipelineOutput, Tuple]:
        r"""
        Generate images with ADM.

        Args:
            class_labels (`int` or `list[int]` or `torch.Tensor`, *optional*):
                ImageNet class indices. Required for class-conditional UNets and for ADM-G classifier guidance.
                Tuples, NumPy arrays and scalar tensors are accepted as well. When given, the number of
                labels determines the batch size.
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate when `class_labels` is not provided.
            height (`int`, *optional*):
                Height in pixels. Defaults to `unet.config.image_size`.
            width (`int`, *optional*):
                Width in pixels. Defaults to `unet.config.image_size`.
            num_inference_steps (`int`, *optional*, defaults to 250):
                Number of denoising steps.
            use_ddim (`bool`, *optional*, defaults to `False`):
                Use DDIM sampling instead of DDPM.
            eta (`float`, *optional*, defaults to 0.0):
                DDIM stochasticity parameter. Only used when `use_ddim=True`.
            clip_denoised (`bool`, *optional*, defaults to `True`):
                Clamp predicted `x_0` to `[-1, 1]` inside the scheduler.
            classifier_guidance_scale (`float`, *optional*, defaults to 0.0):
                ADM-G guidance strength. Values `> 0` require a loaded `classifier` (OpenAI `classifier_scale`).
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                RNG for reproducible generation.
            latents (`torch.Tensor`, *optional*):
                Pre-generated initial noise.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: `"pil"`, `"np"`, or `"pt"` (rescaled to `[0, 1]`), or `"latent"` for the raw
                final sample. Any other value also returns the raw final sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Return an [`ADMPipelineOutput`] instead of a tuple.

        Examples:

        Returns:
            [`ADMPipelineOutput`] or `tuple`:
                Generated images. With `return_dict=False` the result is the tuple `(images, None)`.

        Raises:
            ValueError: If the pipeline has no `unet` or `scheduler`, if the inputs fail `check_inputs`,
                or if classifier guidance is requested without a classifier or without `class_labels`.
        """
        # `from_pretrained` may build a pipeline with a missing component; fail with a clear message
        # instead of an `AttributeError` on `None` further down.
        if self.unet is None or self.scheduler is None:
            raise ValueError("`ADMPipeline` needs both a `unet` and a `scheduler` to generate images.")

        if height is None:
            height = int(self.unet.config.image_size)
        if width is None:
            width = int(self.unet.config.image_size)

        self.check_inputs(class_labels, height, width)

        if classifier_guidance_scale > 0 and self.classifier is None:
            raise ValueError("`classifier_guidance_scale > 0` requires a loaded `classifier` (ADM-G checkpoint).")
        if classifier_guidance_scale > 0 and class_labels is None:
            raise ValueError("`class_labels` are required when using classifier guidance.")

        # Read back by the `do_classifier_guidance` property inside the sampling loop.
        self._classifier_guidance_scale = classifier_guidance_scale
        device = self._execution_device
        model_dtype = self.unet.dtype

        if class_labels is not None:
            # Normalize first so every accepted label container yields a batch size the same way.
            class_labels = self._to_label_tensor(class_labels, device)
            batch_size = class_labels.shape[0]

        class_labels = self._prepare_class_labels(class_labels, batch_size, device)

        latents = self.prepare_latents(
            batch_size,
            3,
            height,
            width,
            model_dtype,
            device,
            generator,
            latents,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device, use_ddim=use_ddim)
        self.scheduler._eta = eta

        self._num_timesteps = len(self.scheduler.timesteps)

        # An unconditional UNet (ADM-G) must not receive labels; they only feed the classifier then.
        unet_class_labels = class_labels if self.unet.config.class_cond else None

        for t in tqdm(self.scheduler.timesteps, desc="Denoising"):
            timestep = torch.full((batch_size,), t, device=device, dtype=torch.long)
            model_timesteps = self.scheduler.scale_timesteps_for_model(timestep)

            model_output = self.unet(
                latents,
                model_timesteps,
                class_labels=unet_class_labels,
                return_dict=True,
            ).sample

            cond_grad = None
            if self.do_classifier_guidance:
                cond_grad = self._get_classifier_grad(
                    latents,
                    timestep,
                    class_labels,
                    classifier_guidance_scale,
                )

            latents = self.scheduler.step(
                model_output,
                t,
                latents,
                generator=generator,
                clip_denoised=clip_denoised,
                eta=eta,
                cond_grad=cond_grad,
            ).prev_sample

        image = latents
        # Always `None`; kept only so the tuple return keeps its two-element shape.
        has_nsfw_concept = None

        if output_type in ("pt", "pil", "np"):
            # Map the sample from [-1, 1] to [0, 1].
            image = (image / 2 + 0.5).clamp(0, 1)
            if output_type != "pt":
                image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return ADMPipelineOutput(images=image)