BiliSakura commited on
Commit
7fc7e34
·
verified ·
1 Parent(s): 9b59af7

Upload folder using huggingface_hub

Browse files
ADM-G-256/README.md CHANGED
@@ -20,23 +20,27 @@ ADM-G-256/
20
  ## Load
21
 
22
  ```python
23
- import sys
24
  from pathlib import Path
25
- from huggingface_hub import snapshot_download
26
-
27
- repo_dir = Path(snapshot_download("BiliSakura/ADM-diffusers"))
28
- sys.path.insert(0, str(repo_dir / "ADM-G-256"))
29
- from pipeline import ADMPipeline
30
-
31
- pipe = ADMPipeline.from_pretrained(".")
32
- pipe.to("cuda")
33
- pipe.unet.float()
34
- pipe.classifier.float()
35
- pipe.classifier.model.dtype = torch.float32
36
-
37
- images = pipe(
38
- class_labels=207,
 
 
 
 
39
  num_inference_steps=250,
40
- classifier_guidance_scale=1.0,
41
- ).images
 
42
  ```
 
20
  ## Load
21
 
22
  ```python
 
23
  from pathlib import Path
24
+ import torch
25
+ from diffusers import DDPMScheduler, DiffusionPipeline
26
+
27
+ model_dir = Path("./BiliSakura/ADM-diffusers/ADM-G-256")
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ str(model_dir),
30
+ local_files_only=True,
31
+ custom_pipeline=str(model_dir / "pipeline.py"),
32
+ torch_dtype=torch.bfloat16,
33
+ )
34
+ pipe = pipe.to("cuda")
35
+ pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
36
+ class_id = pipe.get_label_ids("golden retriever")[0]
37
+ generator = torch.Generator(device="cuda").manual_seed(42)
38
+
39
+ out = pipe(
40
+ class_labels=class_id,
41
+ guidance_scale=1.0,
42
  num_inference_steps=250,
43
+ generator=generator,
44
+ ).images[0]
45
+ out
46
  ```
ADM-G-256/__pycache__/pipeline.cpython-312.pyc CHANGED
Binary files a/ADM-G-256/__pycache__/pipeline.cpython-312.pyc and b/ADM-G-256/__pycache__/pipeline.cpython-312.pyc differ
 
ADM-G-256/classifier/__pycache__/classifier_adm.cpython-312.pyc CHANGED
Binary files a/ADM-G-256/classifier/__pycache__/classifier_adm.cpython-312.pyc and b/ADM-G-256/classifier/__pycache__/classifier_adm.cpython-312.pyc differ
 
ADM-G-256/classifier/__pycache__/modeling_adm.cpython-312.pyc ADDED
Binary file (41.2 kB). View file
 
ADM-G-256/classifier/classifier_adm.py CHANGED
@@ -3,18 +3,524 @@
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
 
 
 
6
  from dataclasses import dataclass
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
 
10
  import torch.nn.functional as F
 
11
 
12
  from diffusers.configuration_utils import ConfigMixin, register_to_config
 
13
  from diffusers.models.modeling_utils import ModelMixin
14
  from diffusers.utils import BaseOutput
15
 
 
16
 
17
- from modeling_adm import create_adm_classifier_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  @dataclass
 
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
 
6
+ import math
7
+ from abc import abstractmethod
8
  from dataclasses import dataclass
9
  from typing import Optional, Tuple, Union
10
 
11
  import torch
12
+ import torch.nn as nn
13
  import torch.nn.functional as F
14
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
15
 
16
  from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+ from diffusers.models.embeddings import get_timestep_embedding
18
  from diffusers.models.modeling_utils import ModelMixin
19
  from diffusers.utils import BaseOutput
20
 
21
+ NUM_CLASSES = 1000
22
 
23
+
24
+ def conv_nd(dims: int, *args, **kwargs):
25
+ if dims == 1:
26
+ return nn.Conv1d(*args, **kwargs)
27
+ if dims == 2:
28
+ return nn.Conv2d(*args, **kwargs)
29
+ if dims == 3:
30
+ return nn.Conv3d(*args, **kwargs)
31
+ raise ValueError(f"unsupported dimensions: {dims}")
32
+
33
+
34
+ def linear(*args, **kwargs):
35
+ return nn.Linear(*args, **kwargs)
36
+
37
+
38
+ def avg_pool_nd(dims: int, *args, **kwargs):
39
+ if dims == 1:
40
+ return nn.AvgPool1d(*args, **kwargs)
41
+ if dims == 2:
42
+ return nn.AvgPool2d(*args, **kwargs)
43
+ if dims == 3:
44
+ return nn.AvgPool3d(*args, **kwargs)
45
+ raise ValueError(f"unsupported dimensions: {dims}")
46
+
47
+
48
+ class GroupNorm32(nn.GroupNorm):
49
+ def forward(self, x):
50
+ weight = self.weight.float() if self.weight is not None else None
51
+ bias = self.bias.float() if self.bias is not None else None
52
+ y = F.group_norm(x.float(), self.num_groups, weight, bias, self.eps)
53
+ return y.to(dtype=x.dtype)
54
+
55
+
56
+ def normalization(channels: int):
57
+ return GroupNorm32(32, channels)
58
+
59
+
60
+ def zero_module(module: nn.Module):
61
+ for p in module.parameters():
62
+ p.detach().zero_()
63
+ return module
64
+
65
+
66
+ def convert_module_to_f16(module: nn.Module):
67
+ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
68
+ module.weight.data = module.weight.data.half()
69
+ if module.bias is not None:
70
+ module.bias.data = module.bias.data.half()
71
+
72
+
73
+ def convert_module_to_f32(module: nn.Module):
74
+ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
75
+ module.weight.data = module.weight.data.float()
76
+ if module.bias is not None:
77
+ module.bias.data = module.bias.data.float()
78
+
79
+
80
+ class TimestepBlock(nn.Module):
81
+ @abstractmethod
82
+ def forward(self, x, emb):
83
+ raise NotImplementedError
84
+
85
+
86
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
87
+ def forward(self, x, emb):
88
+ for layer in self:
89
+ if isinstance(layer, TimestepBlock):
90
+ x = layer(x, emb)
91
+ else:
92
+ x = layer(x)
93
+ return x
94
+
95
+
96
+ class Upsample(nn.Module):
97
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
98
+ super().__init__()
99
+ self.channels = channels
100
+ self.out_channels = out_channels or channels
101
+ self.use_conv = use_conv
102
+ self.dims = dims
103
+ if use_conv:
104
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
105
+
106
+ def forward(self, x):
107
+ assert x.shape[1] == self.channels
108
+ if self.dims == 3:
109
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
110
+ else:
111
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
112
+ if self.use_conv:
113
+ x = self.conv(x)
114
+ return x
115
+
116
+
117
+ class Downsample(nn.Module):
118
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
119
+ super().__init__()
120
+ self.channels = channels
121
+ self.out_channels = out_channels or channels
122
+ self.use_conv = use_conv
123
+ stride = 2 if dims != 3 else (1, 2, 2)
124
+ if use_conv:
125
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
126
+ else:
127
+ assert self.channels == self.out_channels
128
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
129
+
130
+ def forward(self, x):
131
+ assert x.shape[1] == self.channels
132
+ return self.op(x)
133
+
134
+
135
+ class ResBlock(TimestepBlock):
136
+ def __init__(
137
+ self,
138
+ channels,
139
+ emb_channels,
140
+ dropout,
141
+ out_channels=None,
142
+ use_conv=False,
143
+ use_scale_shift_norm=False,
144
+ dims=2,
145
+ use_checkpoint=False,
146
+ up=False,
147
+ down=False,
148
+ ):
149
+ super().__init__()
150
+ self.channels = channels
151
+ self.out_channels = out_channels or channels
152
+ self.use_checkpoint = use_checkpoint
153
+ self.use_scale_shift_norm = use_scale_shift_norm
154
+ self.in_layers = nn.Sequential(
155
+ normalization(channels),
156
+ nn.SiLU(),
157
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
158
+ )
159
+
160
+ self.updown = up or down
161
+ if up:
162
+ self.h_upd = Upsample(channels, False, dims)
163
+ self.x_upd = Upsample(channels, False, dims)
164
+ elif down:
165
+ self.h_upd = Downsample(channels, False, dims)
166
+ self.x_upd = Downsample(channels, False, dims)
167
+ else:
168
+ self.h_upd = self.x_upd = nn.Identity()
169
+
170
+ self.emb_layers = nn.Sequential(
171
+ nn.SiLU(),
172
+ linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
173
+ )
174
+ self.out_layers = nn.Sequential(
175
+ normalization(self.out_channels),
176
+ nn.SiLU(),
177
+ nn.Dropout(p=dropout),
178
+ zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
179
+ )
180
+
181
+ if self.out_channels == channels:
182
+ self.skip_connection = nn.Identity()
183
+ elif use_conv:
184
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
185
+ else:
186
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
187
+
188
+ def forward(self, x, emb):
189
+ if self.use_checkpoint and x.requires_grad:
190
+ return torch_checkpoint(self._forward, x, emb, use_reentrant=False)
191
+ return self._forward(x, emb)
192
+
193
+ def _forward(self, x, emb):
194
+ if self.updown:
195
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
196
+ h = in_rest(x)
197
+ h = self.h_upd(h)
198
+ x = self.x_upd(x)
199
+ h = in_conv(h)
200
+ else:
201
+ h = self.in_layers(x)
202
+
203
+ emb_out = self.emb_layers(emb).type(h.dtype)
204
+ while len(emb_out.shape) < len(h.shape):
205
+ emb_out = emb_out[..., None]
206
+ if self.use_scale_shift_norm:
207
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
208
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
209
+ h = out_norm(h) * (1 + scale) + shift
210
+ h = out_rest(h)
211
+ else:
212
+ h = h + emb_out
213
+ h = self.out_layers(h)
214
+ return self.skip_connection(x) + h
215
+
216
+
217
+ class QKVAttentionLegacy(nn.Module):
218
+ def __init__(self, n_heads):
219
+ super().__init__()
220
+ self.n_heads = n_heads
221
+
222
+ def forward(self, qkv):
223
+ bs, width, length = qkv.shape
224
+ assert width % (3 * self.n_heads) == 0
225
+ ch = width // (3 * self.n_heads)
226
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
227
+ scale = 1 / math.sqrt(math.sqrt(ch))
228
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
229
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
230
+ a = torch.einsum("bts,bcs->bct", weight, v)
231
+ return a.reshape(bs, -1, length)
232
+
233
+
234
+ class QKVAttention(nn.Module):
235
+ def __init__(self, n_heads):
236
+ super().__init__()
237
+ self.n_heads = n_heads
238
+
239
+ def forward(self, qkv):
240
+ bs, width, length = qkv.shape
241
+ assert width % (3 * self.n_heads) == 0
242
+ ch = width // (3 * self.n_heads)
243
+ q, k, v = qkv.chunk(3, dim=1)
244
+ scale = 1 / math.sqrt(math.sqrt(ch))
245
+ weight = torch.einsum(
246
+ "bct,bcs->bts",
247
+ (q * scale).view(bs * self.n_heads, ch, length),
248
+ (k * scale).view(bs * self.n_heads, ch, length),
249
+ )
250
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
251
+ a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
252
+ return a.reshape(bs, -1, length)
253
+
254
+
255
+ class AttentionBlock(nn.Module):
256
+ def __init__(
257
+ self,
258
+ channels,
259
+ num_heads=1,
260
+ num_head_channels=-1,
261
+ use_checkpoint=False,
262
+ use_new_attention_order=False,
263
+ ):
264
+ super().__init__()
265
+ if num_head_channels == -1:
266
+ self.num_heads = num_heads
267
+ else:
268
+ assert channels % num_head_channels == 0
269
+ self.num_heads = channels // num_head_channels
270
+ self.use_checkpoint = use_checkpoint
271
+ self.norm = normalization(channels)
272
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
273
+ self.attention = QKVAttention(self.num_heads) if use_new_attention_order else QKVAttentionLegacy(self.num_heads)
274
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
275
+
276
+ def forward(self, x):
277
+ if self.use_checkpoint and x.requires_grad:
278
+ return torch_checkpoint(self._forward, x, use_reentrant=False)
279
+ return self._forward(x)
280
+
281
+ def _forward(self, x):
282
+ b, c, *spatial = x.shape
283
+ x = x.reshape(b, c, -1)
284
+ qkv = self.qkv(self.norm(x))
285
+ h = self.attention(qkv)
286
+ h = self.proj_out(h)
287
+ return (x + h).reshape(b, c, *spatial)
288
+
289
+
290
+ class AttentionPool2d(nn.Module):
291
+ """CLIP-style attention pooling used by ADM noisy classifiers."""
292
+
293
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None):
294
+ super().__init__()
295
+ self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
296
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
297
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
298
+ self.num_heads = embed_dim // num_heads_channels
299
+ self.attention = QKVAttention(self.num_heads)
300
+
301
+ def forward(self, x):
302
+ b, c, *_spatial = x.shape
303
+ x = x.reshape(b, c, -1)
304
+ x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
305
+ x = x + self.positional_embedding[None, :, :].to(x.dtype)
306
+ x = self.qkv_proj(x)
307
+ x = self.attention(x)
308
+ x = self.c_proj(x)
309
+ return x[:, :, 0]
310
+
311
+
312
+ class EncoderUNetModel(nn.Module):
313
+ """Noisy image classifier backbone for ADM-G (classifier guidance)."""
314
+
315
+ def __init__(
316
+ self,
317
+ image_size,
318
+ in_channels,
319
+ model_channels,
320
+ out_channels,
321
+ num_res_blocks,
322
+ attention_resolutions,
323
+ dropout=0,
324
+ channel_mult=(1, 2, 4, 8),
325
+ conv_resample=True,
326
+ dims=2,
327
+ use_checkpoint=False,
328
+ use_fp16=False,
329
+ num_heads=1,
330
+ num_head_channels=-1,
331
+ use_scale_shift_norm=False,
332
+ resblock_updown=False,
333
+ use_new_attention_order=False,
334
+ pool="adaptive",
335
+ ):
336
+ super().__init__()
337
+
338
+ self.model_channels = model_channels
339
+ self.use_checkpoint = use_checkpoint
340
+ self.dtype = torch.float16 if use_fp16 else torch.float32
341
+
342
+ time_embed_dim = model_channels * 4
343
+ self.time_embed = nn.Sequential(
344
+ linear(model_channels, time_embed_dim),
345
+ nn.SiLU(),
346
+ linear(time_embed_dim, time_embed_dim),
347
+ )
348
+
349
+ ch = int(channel_mult[0] * model_channels)
350
+ self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
351
+ self._feature_size = ch
352
+ ds = 1
353
+ for level, mult in enumerate(channel_mult):
354
+ for _ in range(num_res_blocks):
355
+ layers = [
356
+ ResBlock(
357
+ ch,
358
+ time_embed_dim,
359
+ dropout,
360
+ out_channels=int(mult * model_channels),
361
+ dims=dims,
362
+ use_checkpoint=use_checkpoint,
363
+ use_scale_shift_norm=use_scale_shift_norm,
364
+ )
365
+ ]
366
+ ch = int(mult * model_channels)
367
+ if ds in attention_resolutions:
368
+ layers.append(
369
+ AttentionBlock(
370
+ ch,
371
+ use_checkpoint=use_checkpoint,
372
+ num_heads=num_heads,
373
+ num_head_channels=num_head_channels,
374
+ use_new_attention_order=use_new_attention_order,
375
+ )
376
+ )
377
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
378
+ self._feature_size += ch
379
+ if level != len(channel_mult) - 1:
380
+ out_ch = ch
381
+ self.input_blocks.append(
382
+ TimestepEmbedSequential(
383
+ ResBlock(
384
+ ch,
385
+ time_embed_dim,
386
+ dropout,
387
+ out_channels=out_ch,
388
+ dims=dims,
389
+ use_checkpoint=use_checkpoint,
390
+ use_scale_shift_norm=use_scale_shift_norm,
391
+ down=True,
392
+ )
393
+ if resblock_updown
394
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
395
+ )
396
+ )
397
+ ch = out_ch
398
+ ds *= 2
399
+ self._feature_size += ch
400
+
401
+ self.middle_block = TimestepEmbedSequential(
402
+ ResBlock(
403
+ ch,
404
+ time_embed_dim,
405
+ dropout,
406
+ dims=dims,
407
+ use_checkpoint=use_checkpoint,
408
+ use_scale_shift_norm=use_scale_shift_norm,
409
+ ),
410
+ AttentionBlock(
411
+ ch,
412
+ use_checkpoint=use_checkpoint,
413
+ num_heads=num_heads,
414
+ num_head_channels=num_head_channels,
415
+ use_new_attention_order=use_new_attention_order,
416
+ ),
417
+ ResBlock(
418
+ ch,
419
+ time_embed_dim,
420
+ dropout,
421
+ dims=dims,
422
+ use_checkpoint=use_checkpoint,
423
+ use_scale_shift_norm=use_scale_shift_norm,
424
+ ),
425
+ )
426
+ self._feature_size += ch
427
+ self.pool = pool
428
+ if pool == "adaptive":
429
+ self.out = nn.Sequential(
430
+ normalization(ch),
431
+ nn.SiLU(),
432
+ nn.AdaptiveAvgPool2d((1, 1)),
433
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
434
+ nn.Flatten(),
435
+ )
436
+ elif pool == "attention":
437
+ assert num_head_channels != -1
438
+ self.out = nn.Sequential(
439
+ normalization(ch),
440
+ nn.SiLU(),
441
+ AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
442
+ )
443
+ elif pool == "spatial":
444
+ self.out = nn.Sequential(
445
+ nn.Linear(self._feature_size, 2048),
446
+ nn.ReLU(),
447
+ nn.Linear(2048, out_channels),
448
+ )
449
+ elif pool == "spatial_v2":
450
+ self.out = nn.Sequential(
451
+ nn.Linear(self._feature_size, 2048),
452
+ normalization(2048),
453
+ nn.SiLU(),
454
+ nn.Linear(2048, out_channels),
455
+ )
456
+ else:
457
+ raise NotImplementedError(f"Unexpected {pool} pooling")
458
+
459
+ def convert_to_fp16(self):
460
+ self.input_blocks.apply(convert_module_to_f16)
461
+ self.middle_block.apply(convert_module_to_f16)
462
+
463
+ def convert_to_fp32(self):
464
+ self.input_blocks.apply(convert_module_to_f32)
465
+ self.middle_block.apply(convert_module_to_f32)
466
+
467
+ def forward(self, x, timesteps):
468
+ emb = get_timestep_embedding(timesteps, self.model_channels).to(dtype=self.time_embed[0].weight.dtype)
469
+ emb = self.time_embed(emb)
470
+ results = []
471
+ h = x.to(dtype=self.time_embed[0].weight.dtype)
472
+ for module in self.input_blocks:
473
+ h = module(h, emb)
474
+ if self.pool.startswith("spatial"):
475
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
476
+ h = self.middle_block(h, emb)
477
+ if self.pool.startswith("spatial"):
478
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
479
+ h = torch.cat(results, dim=-1)
480
+ return self.out(h)
481
+ h = h.to(dtype=self.time_embed[0].weight.dtype)
482
+ return self.out(h)
483
+
484
+
485
+ def _default_channel_mult(image_size: int):
486
+ if image_size == 512:
487
+ return (0.5, 1, 1, 2, 2, 4, 4)
488
+ if image_size == 256:
489
+ return (1, 1, 2, 2, 4, 4)
490
+ if image_size == 128:
491
+ return (1, 1, 2, 3, 4)
492
+ if image_size == 64:
493
+ return (1, 2, 3, 4)
494
+ raise ValueError(f"unsupported image size: {image_size}")
495
+
496
+
497
+ def create_adm_classifier_model(
498
+ image_size: int,
499
+ classifier_width: int = 128,
500
+ classifier_depth: int = 2,
501
+ classifier_attention_resolutions: str = "32,16,8",
502
+ classifier_use_scale_shift_norm: bool = True,
503
+ classifier_resblock_updown: bool = True,
504
+ classifier_pool: str = "attention",
505
+ use_fp16: bool = False,
506
+ num_classes: int = NUM_CLASSES,
507
+ ):
508
+ channel_mult = _default_channel_mult(image_size)
509
+ attention_ds = tuple(image_size // int(res) for res in classifier_attention_resolutions.split(","))
510
+ return EncoderUNetModel(
511
+ image_size=image_size,
512
+ in_channels=3,
513
+ model_channels=classifier_width,
514
+ out_channels=num_classes,
515
+ num_res_blocks=classifier_depth,
516
+ attention_resolutions=attention_ds,
517
+ channel_mult=channel_mult,
518
+ use_fp16=use_fp16,
519
+ num_head_channels=64,
520
+ use_scale_shift_norm=classifier_use_scale_shift_norm,
521
+ resblock_updown=classifier_resblock_updown,
522
+ pool=classifier_pool,
523
+ )
524
 
525
 
526
  @dataclass
ADM-G-256/model_index.json CHANGED
@@ -2,8 +2,8 @@
2
  "_class_name": "ADMPipeline",
3
  "_diffusers_version": "0.36.0",
4
  "scheduler": [
5
- "scheduling_adm",
6
- "ADMScheduler"
7
  ],
8
  "unet": [
9
  "unet_adm",
 
2
  "_class_name": "ADMPipeline",
3
  "_diffusers_version": "0.36.0",
4
  "scheduler": [
5
+ "diffusers",
6
+ "DDPMScheduler"
7
  ],
8
  "unet": [
9
  "unet_adm",
ADM-G-256/pipeline.py CHANGED
@@ -2,208 +2,72 @@
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
 
 
 
 
 
 
 
 
5
 
6
- """Hub custom pipeline: ADMPipeline.
7
-
8
- Load with native Hugging Face diffusers and `trust_remote_code=True`.
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- import importlib
14
- import sys
15
- from dataclasses import dataclass
16
- from pathlib import Path
17
- from typing import Dict, List, Optional, Tuple, Union
18
 
19
- import numpy as np
20
  import torch
21
- from tqdm.auto import tqdm
22
 
23
  from diffusers.image_processor import VaeImageProcessor
24
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
- from diffusers.utils import BaseOutput, replace_example_docstring
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
-
29
  EXAMPLE_DOC_STRING = """
30
  Examples:
31
  ```py
 
32
  >>> import torch
33
  >>> from diffusers import DiffusionPipeline
34
 
35
- >>> from pipeline import ADMPipeline
36
-
37
- >>> pipe = ADMPipeline.from_pretrained(".")
38
- >>> pipe.to("cuda")
39
-
40
- >>> # ADM-G (classifier guidance) with numeric class id
41
- >>> images = pipe(class_labels=207, classifier_guidance_scale=1.0, num_inference_steps=250).images
42
-
43
- >>> # Or use human-readable ImageNet labels (English)
44
- >>> pipe.id2label[207]
45
- >>> class_ids = pipe.get_label_ids("golden retriever")
46
- >>> images = pipe(class_labels="golden retriever", classifier_guidance_scale=1.0).images
47
  ```
48
  """
49
 
50
 
51
- @dataclass
52
- class ADMPipelineOutput(BaseOutput):
53
- """
54
- Output class for ADM pipelines.
55
-
56
- Args:
57
- images (`torch.Tensor` or `list[PIL.Image.Image]` or `np.ndarray`):
58
- Generated images of shape `(batch_size, num_channels, height, width)` when `output_type="pt"`,
59
- or a list of PIL images / NumPy array when post-processed.
60
- """
61
-
62
- images: Union[torch.Tensor, List, np.ndarray]
63
-
64
-
65
  class ADMPipeline(DiffusionPipeline):
66
- r"""
67
- Pipeline for image generation with ADM (Ablated Diffusion Model).
68
-
69
- Supports class-conditional ADM (labels embedded in the UNet) and **ADM-G** (unconditional UNet + noisy
70
- classifier guidance). For ADM-G, pass `classifier_guidance_scale > 0` and provide `class_labels`; the
71
- optional `classifier` predicts `p(y | x_t)` and steers sampling.
72
-
73
- Args:
74
- unet ([`ADMUNet2DModel`]):
75
- A UNet model to denoise image samples (typically unconditional for ADM-G).
76
- scheduler ([`ADMScheduler`]):
77
- A scheduler used with the UNet to denoise image samples.
78
- classifier ([`ADMClassifierModel`], *optional*):
79
- Noisy ImageNet classifier for ADM-G guidance.
80
- id2label (`dict[int, str]`, *optional*):
81
- ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
82
- """
83
 
84
  model_cpu_offload_seq = "classifier->unet"
85
  _optional_components = ["classifier"]
86
 
87
- @classmethod
88
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
89
- """Load a self-contained variant folder locally or from the Hub.
90
-
91
- Examples:
92
- ADMPipeline.from_pretrained(".")
93
- ADMPipeline.from_pretrained("./ADM-G-256")
94
- ADMPipeline.from_pretrained("BiliSakura/ADM-diffusers", subfolder="ADM-G-512")
95
- """
96
- repo_root = Path(__file__).resolve().parent
97
-
98
- if pretrained_model_name_or_path in (None, "", "."):
99
- variant = repo_root
100
- elif (
101
- isinstance(pretrained_model_name_or_path, str)
102
- and "/" in pretrained_model_name_or_path
103
- and not Path(pretrained_model_name_or_path).exists()
104
- ):
105
- from huggingface_hub import snapshot_download
106
-
107
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
108
- if subfolder:
109
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
110
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
111
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
112
- else:
113
- variant = Path(pretrained_model_name_or_path)
114
- if not variant.is_absolute():
115
- candidate = (Path.cwd() / variant).resolve()
116
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
117
- if subfolder:
118
- variant = variant / subfolder
119
-
120
- id2label_override = kwargs.pop("id2label", None)
121
- model_kwargs = dict(kwargs)
122
- inserted: List[str] = []
123
-
124
- def _load_component(folder: str, module_name: str, class_name: str):
125
- comp_dir = variant / folder
126
- module_path = comp_dir / f"{module_name}.py"
127
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
128
- if not module_path.exists() or not has_weights:
129
- return None
130
-
131
- comp_path = str(comp_dir)
132
- if comp_path not in sys.path:
133
- sys.path.insert(0, comp_path)
134
- inserted.append(comp_path)
135
-
136
- module = importlib.import_module(module_name)
137
- component_cls = getattr(module, class_name)
138
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
139
-
140
- try:
141
- unet = _load_component("unet", "unet_adm", "ADMUNet2DModel")
142
- scheduler = _load_component("scheduler", "scheduling_adm", "ADMScheduler")
143
- classifier = _load_component("classifier", "classifier_adm", "ADMClassifierModel")
144
-
145
- if scheduler is None:
146
- sched_dir = variant / "scheduler"
147
- if (sched_dir / "scheduling_adm.py").exists():
148
- sched_path = str(sched_dir)
149
- if sched_path not in sys.path:
150
- sys.path.insert(0, sched_path)
151
- inserted.append(sched_path)
152
- scheduler = importlib.import_module("scheduling_adm").ADMScheduler()
153
-
154
- if unet is None and classifier is None:
155
- raise ValueError(f"No loadable components found under {variant}")
156
-
157
- id2label = id2label_override
158
- if id2label is None:
159
- model_index_path = variant / "model_index.json"
160
- if model_index_path.exists():
161
- id2label = cls._read_id2label_from_model_index(model_index_path)
162
-
163
- return cls(
164
- unet=unet,
165
- scheduler=scheduler,
166
- classifier=classifier,
167
- id2label=id2label,
168
- )
169
- finally:
170
- for comp_path in inserted:
171
- if comp_path in sys.path:
172
- sys.path.remove(comp_path)
173
-
174
  def __init__(
175
  self,
176
  unet,
177
- scheduler,
178
- classifier=None,
179
- id2label: Optional[Dict[Union[int, str], str]] = None,
180
- ):
 
181
  super().__init__()
182
  self.register_modules(unet=unet, scheduler=scheduler, classifier=classifier)
 
183
  self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
184
-
185
- self._id2label = self._normalize_id2label(id2label)
186
  self.labels = self._build_label2id(self._id2label)
187
 
188
  @staticmethod
189
- def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
190
- if not id2label:
191
- return {}
192
- return {int(key): value for key, value in id2label.items()}
193
-
194
- @staticmethod
195
- def _read_id2label_from_model_index(model_index_path: Path) -> Optional[Dict[int, str]]:
196
- import json
197
-
198
- raw = json.loads(model_index_path.read_text(encoding="utf-8"))
199
- id2label = raw.get("id2label")
200
- if not isinstance(id2label, dict):
201
- return None
202
- return {int(key): value for key, value in id2label.items()}
203
-
204
- @staticmethod
205
- def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
206
- label2id: dict[str, int] = {}
207
  for class_id, value in id2label.items():
208
  for synonym in value.split(","):
209
  synonym = synonym.strip()
@@ -212,153 +76,44 @@ class ADMPipeline(DiffusionPipeline):
212
  return dict(sorted(label2id.items()))
213
 
214
  @property
215
- def id2label(self) -> dict[int, str]:
216
- """ImageNet class id to English label string (comma-separated synonyms)."""
217
  return self._id2label
218
 
219
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
220
- r"""
221
- Map ImageNet label strings to class ids.
222
-
223
- Args:
224
- label (`str` or `list[str]`):
225
- One or more English ImageNet label strings matching a synonym in `id2label`.
226
-
227
- Returns:
228
- `list[int]`: Class ids for [`~ADMPipeline.__call__`].
229
- """
230
- label2id = self.labels
231
- if not label2id:
232
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
233
-
234
- if isinstance(label, str):
235
- label = [label]
236
-
237
- missing = [item for item in label if item not in label2id]
238
  if missing:
239
- preview = ", ".join(list(label2id.keys())[:8])
240
- raise ValueError(
241
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
242
- )
243
- return [label2id[item] for item in label]
244
-
245
- @property
246
- def do_classifier_guidance(self) -> bool:
247
- return self.classifier is not None and getattr(self, "_classifier_guidance_scale", 0.0) > 0
248
-
249
- def _normalize_class_labels(
250
- self,
251
- class_labels: Optional[Union[int, str, List[Union[int, str]], torch.Tensor]],
252
- ) -> Optional[Union[int, List[int], torch.Tensor]]:
253
- if class_labels is None:
254
- return None
255
-
256
- if isinstance(class_labels, str):
257
- return self.get_label_ids(class_labels)[0]
258
-
259
- if isinstance(class_labels, list) and class_labels and isinstance(class_labels[0], str):
260
- return self.get_label_ids(class_labels)
261
-
262
- return class_labels
263
 
264
- def check_inputs(
265
- self,
266
- class_labels: Optional[Union[int, str, List[Union[int, str]], torch.Tensor]],
267
- height: Optional[int],
268
- width: Optional[int],
269
- ):
270
- if class_labels is None and self.unet.config.class_cond:
271
- raise ValueError("`class_labels` are required for class-conditional ADM checkpoints.")
272
-
273
- if class_labels is not None and self.classifier is None and not self.unet.config.class_cond:
274
- raise ValueError(
275
- "This checkpoint is unconditional and has no classifier. Load an ADM-G repo with a "
276
- "`classifier/` subfolder, or use a class-conditional UNet."
277
- )
278
-
279
- if height is not None and height % 8 != 0:
280
- raise ValueError(f"`height` must be divisible by 8 but is {height}.")
281
- if width is not None and width % 8 != 0:
282
- raise ValueError(f"`width` must be divisible by 8 but is {width}.")
283
-
284
- def _prepare_class_labels(
285
- self,
286
- class_labels: Optional[Union[int, List[int], torch.Tensor]],
287
- batch_size: int,
288
- device: torch.device,
289
- ) -> Optional[torch.Tensor]:
290
- if class_labels is None:
291
- return None
292
-
293
- if isinstance(class_labels, int):
294
- class_labels = [class_labels]
295
- if not torch.is_tensor(class_labels):
296
- class_labels = torch.tensor(class_labels, device=device, dtype=torch.long)
297
- else:
298
- class_labels = class_labels.to(device=device, dtype=torch.long)
299
-
300
- if class_labels.shape[0] != batch_size:
301
- raise ValueError(
302
- f"`class_labels` batch ({class_labels.shape[0]}) must match requested batch size ({batch_size})."
303
- )
304
- return class_labels
305
 
306
- def _get_classifier_grad(
307
- self,
308
- sample: torch.Tensor,
309
- timestep: torch.Tensor,
310
- class_labels: torch.Tensor,
311
- classifier_scale: float,
312
- ) -> torch.Tensor:
313
- return self.classifier.guidance_gradient(
314
- sample,
315
- timestep,
316
- class_labels,
317
- classifier_scale=classifier_scale,
318
- )
319
 
320
- def prepare_latents(
321
- self,
322
- batch_size: int,
323
- num_channels: int,
324
- height: int,
325
- width: int,
326
- dtype: torch.dtype,
327
- device: torch.device,
328
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
329
- latents: Optional[torch.Tensor] = None,
330
- ) -> torch.Tensor:
331
- """
332
- Prepare initial Gaussian noise for pixel-space sampling.
333
-
334
- Args:
335
- batch_size (`int`):
336
- Number of images to generate.
337
- num_channels (`int`):
338
- Number of image channels (typically 3).
339
- height (`int`):
340
- Image height in pixels.
341
- width (`int`):
342
- Image width in pixels.
343
- dtype (`torch.dtype`):
344
- Data type for the latent tensor.
345
- device (`torch.device`):
346
- Target device.
347
- generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
348
- RNG for deterministic sampling.
349
- latents (`torch.Tensor`, *optional*):
350
- Pre-generated noise tensor.
351
-
352
- Returns:
353
- `torch.Tensor`:
354
- Initial noise of shape `(batch_size, num_channels, height, width)`.
355
- """
356
- shape = (batch_size, num_channels, height, width)
357
- if latents is None:
358
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
359
- else:
360
- latents = latents.to(device=device, dtype=dtype)
361
- return latents
362
 
363
  @torch.no_grad()
364
  @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -369,142 +124,130 @@ class ADMPipeline(DiffusionPipeline):
369
  height: Optional[int] = None,
370
  width: Optional[int] = None,
371
  num_inference_steps: int = 250,
372
- use_ddim: bool = False,
 
373
  eta: float = 0.0,
374
  clip_denoised: bool = True,
375
- classifier_guidance_scale: float = 0.0,
376
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
377
  latents: Optional[torch.Tensor] = None,
378
  output_type: str = "pil",
379
  return_dict: bool = True,
380
- ) -> Union[ADMPipelineOutput, Tuple]:
381
  r"""
382
- Generate images with ADM.
383
-
384
- Args:
385
- class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
386
- ImageNet class indices or English label strings. Required for class-conditional UNets and for ADM-G
387
- classifier guidance. Strings are resolved via [`~ADMPipeline.get_label_ids`].
388
- batch_size (`int`, *optional*, defaults to 1):
389
- Number of images to generate when `class_labels` is not provided.
390
- height (`int`, *optional*):
391
- Height in pixels. Defaults to `unet.config.image_size`.
392
- width (`int`, *optional*):
393
- Width in pixels. Defaults to `unet.config.image_size`.
394
- num_inference_steps (`int`, *optional*, defaults to 250):
395
- Number of denoising steps.
396
- use_ddim (`bool`, *optional*, defaults to `False`):
397
- Use DDIM sampling instead of DDPM.
398
- eta (`float`, *optional*, defaults to 0.0):
399
- DDIM stochasticity parameter. Only used when `use_ddim=True`.
400
- clip_denoised (`bool`, *optional*, defaults to `True`):
401
- Clamp predicted `x_0` to `[-1, 1]` inside the scheduler.
402
- classifier_guidance_scale (`float`, *optional*, defaults to 0.0):
403
- ADM-G guidance strength. Values `> 0` require a loaded `classifier` (OpenAI `classifier_scale`).
404
- generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
405
- RNG for reproducible generation.
406
- latents (`torch.Tensor`, *optional*):
407
- Pre-generated initial noise.
408
- output_type (`str`, *optional*, defaults to `"pil"`):
409
- Output format: `"pil"`, `"np"`, or `"pt"`.
410
- return_dict (`bool`, *optional*, defaults to `True`):
411
- Return an [`ADMPipelineOutput`] instead of a tuple.
412
 
413
  Examples:
414
-
415
- Returns:
416
- [`ADMPipelineOutput`] or `tuple`:
417
- Generated images.
418
  """
419
- if height is None:
420
- height = int(self.unet.config.image_size)
421
- if width is None:
422
- width = int(self.unet.config.image_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
- class_labels = self._normalize_class_labels(class_labels)
425
- self.check_inputs(class_labels, height, width)
426
 
427
- if classifier_guidance_scale > 0 and self.classifier is None:
428
- raise ValueError("`classifier_guidance_scale > 0` requires a loaded `classifier` (ADM-G checkpoint).")
429
- if classifier_guidance_scale > 0 and class_labels is None:
430
- raise ValueError("`class_labels` are required when using classifier guidance.")
 
 
 
431
 
432
- self._classifier_guidance_scale = classifier_guidance_scale
433
  device = self._execution_device
434
- model_dtype = self.unet.dtype
 
435
 
 
 
 
436
  if class_labels is not None:
437
- if isinstance(class_labels, int):
438
- batch_size = 1
439
- elif isinstance(class_labels, list):
440
- batch_size = len(class_labels)
441
- elif torch.is_tensor(class_labels):
442
- batch_size = class_labels.shape[0]
443
-
444
- class_labels = self._prepare_class_labels(class_labels, batch_size, device)
445
-
446
- latents = self.prepare_latents(
447
- batch_size,
448
- 3,
449
- height,
450
- width,
451
- model_dtype,
452
- device,
453
- generator,
454
- latents,
455
- )
456
-
457
- self.scheduler.set_timesteps(num_inference_steps, device=device, use_ddim=use_ddim)
458
- self.scheduler._eta = eta
459
-
460
- self._num_timesteps = len(self.scheduler.timesteps)
461
-
462
- unet_class_labels = class_labels if self.unet.config.class_cond else None
463
 
464
- for t in tqdm(self.scheduler.timesteps, desc="Denoising"):
465
- timestep = torch.full((batch_size,), t, device=device, dtype=torch.long)
466
- model_timesteps = self.scheduler.scale_timesteps_for_model(timestep)
467
 
468
- model_output = self.unet(
469
- latents,
470
- model_timesteps,
471
- class_labels=unet_class_labels,
472
- return_dict=True,
473
- ).sample
474
 
475
  cond_grad = None
476
- if self.do_classifier_guidance:
477
- cond_grad = self._get_classifier_grad(
478
- latents,
479
- timestep,
480
- class_labels,
481
- classifier_guidance_scale,
482
  )
483
 
484
- latents = self.scheduler.step(
485
- model_output,
486
- t,
487
- latents,
488
- generator=generator,
489
- clip_denoised=clip_denoised,
490
- eta=eta,
491
- cond_grad=cond_grad,
492
- ).prev_sample
493
-
494
- image = latents
495
- has_nsfw_concept = None
496
-
497
- if output_type == "latent":
498
- image = latents
499
- elif output_type == "pt":
500
- image = (image / 2 + 0.5).clamp(0, 1)
501
- elif output_type in ("pil", "np"):
502
- image = (image / 2 + 0.5).clamp(0, 1)
 
 
 
 
 
503
  image = self.image_processor.postprocess(image, output_type=output_type)
504
 
505
  self.maybe_free_model_hooks()
506
-
507
  if not return_dict:
508
- return (image, has_nsfw_concept)
509
-
510
- return ADMPipelineOutput(images=image)
 
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
 
15
+ import inspect
16
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 
 
 
 
 
 
 
 
 
17
 
 
18
  import torch
 
19
 
20
  from diffusers.image_processor import VaeImageProcessor
21
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
22
+ from diffusers.schedulers import KarrasDiffusionSchedulers
23
+ from diffusers.utils import replace_example_docstring
24
  from diffusers.utils.torch_utils import randn_tensor
25
 
 
26
  EXAMPLE_DOC_STRING = """
27
  Examples:
28
  ```py
29
+ >>> from pathlib import Path
30
  >>> import torch
31
  >>> from diffusers import DiffusionPipeline
32
 
33
+ >>> model_dir = Path("path/to/BiliSakura/ADM-diffusers/ADM-G-256")
34
+ >>> pipe = DiffusionPipeline.from_pretrained(
35
+ ... str(model_dir),
36
+ ... local_files_only=True,
37
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
38
+ ... torch_dtype=torch.bfloat16,
39
+ ... )
40
+ >>> pipe = pipe.to("cuda")
41
+ >>> class_id = pipe.get_label_ids("golden retriever")[0]
42
+ >>> image = pipe(class_labels=class_id, guidance_scale=1.0).images[0]
 
 
43
  ```
44
  """
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class ADMPipeline(DiffusionPipeline):
48
+ r"""ADM/ADM-G pipeline compatible with Diffusers custom pipeline loading."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  model_cpu_offload_seq = "classifier->unet"
51
  _optional_components = ["classifier"]
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def __init__(
54
  self,
55
  unet,
56
+ scheduler: KarrasDiffusionSchedulers,
57
+ classifier: Optional[Any] = None,
58
+ id2label: Optional[Dict[str, str]] = None,
59
+ null_class_id: int = 1000,
60
+ ) -> None:
61
  super().__init__()
62
  self.register_modules(unet=unet, scheduler=scheduler, classifier=classifier)
63
+ self.register_to_config(null_class_id=int(null_class_id))
64
  self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
65
+ self._id2label = {int(k): v for k, v in (id2label or {}).items()}
 
66
  self.labels = self._build_label2id(self._id2label)
67
 
68
  @staticmethod
69
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
70
+ label2id: Dict[str, int] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  for class_id, value in id2label.items():
72
  for synonym in value.split(","):
73
  synonym = synonym.strip()
 
76
  return dict(sorted(label2id.items()))
77
 
78
  @property
79
+ def id2label(self) -> Dict[int, str]:
 
80
  return self._id2label
81
 
82
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
83
+ if not self.labels:
84
+ raise ValueError("No id2label mapping is available in this checkpoint.")
85
+ labels = [label] if isinstance(label, str) else label
86
+ missing = [item for item in labels if item not in self.labels]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  if missing:
88
+ preview = ", ".join(list(self.labels.keys())[:8])
89
+ raise ValueError(f"Unknown labels: {missing}. Example valid labels: {preview}, ...")
90
+ return [self.labels[item] for item in labels]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ @staticmethod
93
+ def prepare_extra_step_kwargs(
94
+ scheduler: KarrasDiffusionSchedulers,
95
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
96
+ eta: float,
97
+ ) -> Dict[str, Any]:
98
+ kwargs: Dict[str, Any] = {}
99
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
100
+ if "eta" in step_params:
101
+ kwargs["eta"] = eta
102
+ if "generator" in step_params:
103
+ kwargs["generator"] = generator
104
+ return kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ @staticmethod
107
+ def _is_ddim_like(step_params: Set[str]) -> bool:
108
+ return "eta" in step_params
 
 
 
 
 
 
 
 
 
 
109
 
110
+ @staticmethod
111
+ def _expand_timestep(timestep, batch: int, device: torch.device) -> torch.Tensor:
112
+ if not torch.is_tensor(timestep):
113
+ timestep = torch.tensor([timestep], dtype=torch.long, device=device)
114
+ elif timestep.ndim == 0:
115
+ timestep = timestep[None].to(device=device)
116
+ return timestep.expand(batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  @torch.no_grad()
119
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
124
  height: Optional[int] = None,
125
  width: Optional[int] = None,
126
  num_inference_steps: int = 250,
127
+ guidance_scale: float = 1.0,
128
+ classifier_guidance_scale: float = 0.0,
129
  eta: float = 0.0,
130
  clip_denoised: bool = True,
 
131
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
132
  latents: Optional[torch.Tensor] = None,
133
  output_type: str = "pil",
134
  return_dict: bool = True,
135
+ ) -> Union[ImagePipelineOutput, Tuple]:
136
  r"""
137
+ Generate samples from the ADM/ADM-G checkpoint.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  Examples:
140
+ <!-- this section is replaced by replace_example_docstring -->
 
 
 
141
  """
142
+ # Stage 1: check inputs
143
+ if isinstance(class_labels, str):
144
+ class_labels = self.get_label_ids(class_labels)[0]
145
+ if isinstance(class_labels, list) and class_labels and isinstance(class_labels[0], str):
146
+ class_labels = self.get_label_ids(class_labels)
147
+
148
+ native_size = int(getattr(self.unet.config, "image_size", 256))
149
+ height = native_size if height is None else int(height)
150
+ width = native_size if width is None else int(width)
151
+
152
+ if height % 8 != 0 or width % 8 != 0:
153
+ raise ValueError(f"height and width must be divisible by 8, got ({height}, {width}).")
154
+ if output_type not in {"pil", "np", "pt", "latent"}:
155
+ raise ValueError(f"Unsupported output_type: {output_type}")
156
+ # This checkpoint does not use classifier-free guidance (CFG).
157
+ # Keep classifier_guidance_scale for compatibility, but treat guidance_scale
158
+ # as the primary classifier-guidance strength.
159
+ effective_classifier_guidance_scale = (
160
+ float(classifier_guidance_scale) if classifier_guidance_scale > 0 else float(guidance_scale)
161
+ )
162
 
163
+ if class_labels is None and (self.unet.config.class_cond or effective_classifier_guidance_scale > 0):
164
+ raise ValueError("class_labels are required for class-conditional sampling and ADM-G guidance.")
165
 
166
+ if isinstance(class_labels, int):
167
+ batch_size = 1
168
+ class_labels = [class_labels]
169
+ elif isinstance(class_labels, list):
170
+ batch_size = len(class_labels)
171
+ elif torch.is_tensor(class_labels):
172
+ batch_size = int(class_labels.shape[0])
173
 
174
+ # Stage 2: define call parameters
175
  device = self._execution_device
176
+ channels = int(getattr(self.unet.config, "in_channels", 3))
177
+ dtype = self.unet.dtype
178
 
179
+ # Stage 3: prepare class conditioning
180
+ class_tensor = None
181
+ class_input = None
182
  if class_labels is not None:
183
+ class_tensor = class_labels if torch.is_tensor(class_labels) else torch.tensor(class_labels, dtype=torch.long)
184
+ class_tensor = class_tensor.to(device=device, dtype=torch.long).reshape(-1)
185
+ if class_tensor.shape[0] != batch_size:
186
+ raise ValueError("class_labels batch must match requested batch_size")
187
+ if self.unet.config.class_cond:
188
+ class_input = class_tensor
189
+
190
+ # Stage 4: prepare timesteps
191
+ scheduler = self.scheduler
192
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
193
+ scheduler.set_timesteps(num_inference_steps, device=device)
194
+
195
+ # Stage 5: prepare latent variables
196
+ shape = (batch_size, channels, height, width)
197
+ if latents is None:
198
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
199
+ else:
200
+ if tuple(latents.shape) != shape:
201
+ raise ValueError(f"Unexpected latents shape {tuple(latents.shape)}; expected {shape}.")
202
+ latents = latents.to(device=device, dtype=dtype)
203
+ latents = latents * scheduler.init_noise_sigma
 
 
 
 
 
204
 
205
+ # Stage 6: prepare extra step kwargs
206
+ extra_step_kwargs = self.prepare_extra_step_kwargs(scheduler, generator, eta)
 
207
 
208
+ # Stage 7: denoising loop
209
+ for timestep in self.progress_bar(scheduler.timesteps):
210
+ model_input = latents
211
+ model_input = scheduler.scale_model_input(model_input, timestep)
212
+ timestep_input = self._expand_timestep(timestep, model_input.shape[0], model_input.device)
213
+ model_output = self.unet(model_input, timestep_input, class_labels=class_input, return_dict=True).sample
214
 
215
  cond_grad = None
216
+ if effective_classifier_guidance_scale > 0:
217
+ if self.classifier is None or class_tensor is None:
218
+ raise ValueError("guidance_scale requires both classifier and class_labels.")
219
+ grad_t = self._expand_timestep(timestep, batch_size, latents.device)
220
+ cond_grad = self.classifier.guidance_gradient(
221
+ latents, grad_t, class_tensor, classifier_scale=effective_classifier_guidance_scale
222
  )
223
 
224
+ step_model_output = model_output
225
+ if cond_grad is not None:
226
+ if self._is_ddim_like(step_params):
227
+ eps = model_output[:, :channels] if model_output.shape[1] == 2 * channels else model_output
228
+ alpha_bar_t = scheduler.alphas_cumprod[timestep].to(device=latents.device, dtype=latents.dtype)
229
+ step_model_output = eps - (1 - alpha_bar_t).sqrt() * cond_grad
230
+ elif hasattr(scheduler, "_get_variance"):
231
+ pred_var = None
232
+ if model_output.shape[1] == 2 * channels:
233
+ _, pred_var = torch.split(model_output, channels, dim=1)
234
+ variance = scheduler._get_variance(int(timestep), predicted_variance=pred_var)
235
+ if scheduler.config.variance_type == "learned_range":
236
+ variance = torch.exp(variance)
237
+ latents = latents + variance * cond_grad
238
+ else:
239
+ raise ValueError(
240
+ "guidance_scale is not supported for the current scheduler. "
241
+ "Use a DDPM/DDIM-compatible scheduler or disable classifier guidance."
242
+ )
243
+
244
+ latents = scheduler.step(step_model_output, timestep, latents, return_dict=True, **extra_step_kwargs).prev_sample
245
+
246
+ image = latents if output_type == "latent" else (latents / 2 + 0.5).clamp(0, 1)
247
+ if output_type in {"pil", "np"}:
248
  image = self.image_processor.postprocess(image, output_type=output_type)
249
 
250
  self.maybe_free_model_hooks()
 
251
  if not return_dict:
252
+ return (image,)
253
+ return ImagePipelineOutput(images=image)
 
ADM-G-256/scheduler/scheduler_config.json CHANGED
@@ -1,11 +1,12 @@
1
  {
2
- "_class_name": "ADMScheduler",
3
  "_diffusers_version": "0.36.0",
4
- "learn_sigma": true,
5
- "noise_schedule": "linear",
6
- "predict_xstart": false,
7
- "rescale_timesteps": false,
8
- "sigma_small": false,
9
- "steps": 1000,
10
- "timestep_respacing": ""
 
11
  }
 
1
  {
2
+ "_class_name": "DDPMScheduler",
3
  "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "beta_start": 0.0001,
6
+ "beta_end": 0.02,
7
+ "beta_schedule": "linear",
8
+ "prediction_type": "epsilon",
9
+ "variance_type": "learned_range",
10
+ "clip_sample": true,
11
+ "timestep_spacing": "leading"
12
  }
ADM-G-256/unet/__pycache__/unet_adm.cpython-312.pyc CHANGED
Binary files a/ADM-G-256/unet/__pycache__/unet_adm.cpython-312.pyc and b/ADM-G-256/unet/__pycache__/unet_adm.cpython-312.pyc differ
 
ADM-G-256/unet/modeling_adm.py CHANGED
@@ -37,7 +37,10 @@ def avg_pool_nd(dims: int, *args, **kwargs):
37
 
38
  class GroupNorm32(nn.GroupNorm):
39
  def forward(self, x):
40
- return super().forward(x.float()).type(x.dtype)
 
 
 
41
 
42
 
43
  def normalization(channels: int):
@@ -475,19 +478,20 @@ class EncoderUNetModel(nn.Module):
475
  self.middle_block.apply(convert_module_to_f32)
476
 
477
  def forward(self, x, timesteps):
478
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
479
  results = []
480
- h = x.type(self.dtype)
481
  for module in self.input_blocks:
482
  h = module(h, emb)
483
  if self.pool.startswith("spatial"):
484
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
485
  h = self.middle_block(h, emb)
486
  if self.pool.startswith("spatial"):
487
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
488
  h = torch.cat(results, dim=-1)
489
  return self.out(h)
490
- h = h.type(x.dtype)
491
  return self.out(h)
492
 
493
 
@@ -673,12 +677,13 @@ class UNetModel(nn.Module):
673
  def forward(self, x, timesteps, y: Optional[torch.Tensor] = None):
674
  assert (y is not None) == (self.num_classes is not None)
675
  hs = []
676
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
677
  if self.num_classes is not None:
678
  assert y.shape == (x.shape[0],)
679
  emb = emb + self.label_emb(y)
680
 
681
- h = x.type(self.dtype)
682
  for module in self.input_blocks:
683
  h = module(h, emb)
684
  hs.append(h)
@@ -686,7 +691,7 @@ class UNetModel(nn.Module):
686
  for module in self.output_blocks:
687
  h = torch.cat([h, hs.pop()], dim=1)
688
  h = module(h, emb)
689
- h = h.type(x.dtype)
690
  return self.out(h)
691
 
692
 
 
37
 
38
  class GroupNorm32(nn.GroupNorm):
39
  def forward(self, x):
40
+ weight = self.weight.float() if self.weight is not None else None
41
+ bias = self.bias.float() if self.bias is not None else None
42
+ y = F.group_norm(x.float(), self.num_groups, weight, bias, self.eps)
43
+ return y.to(dtype=x.dtype)
44
 
45
 
46
  def normalization(channels: int):
 
478
  self.middle_block.apply(convert_module_to_f32)
479
 
480
  def forward(self, x, timesteps):
481
+ emb = timestep_embedding(timesteps, self.model_channels).to(dtype=self.time_embed[0].weight.dtype)
482
+ emb = self.time_embed(emb)
483
  results = []
484
+ h = x.to(dtype=self.time_embed[0].weight.dtype)
485
  for module in self.input_blocks:
486
  h = module(h, emb)
487
  if self.pool.startswith("spatial"):
488
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
489
  h = self.middle_block(h, emb)
490
  if self.pool.startswith("spatial"):
491
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
492
  h = torch.cat(results, dim=-1)
493
  return self.out(h)
494
+ h = h.to(dtype=self.time_embed[0].weight.dtype)
495
  return self.out(h)
496
 
497
 
 
677
  def forward(self, x, timesteps, y: Optional[torch.Tensor] = None):
678
  assert (y is not None) == (self.num_classes is not None)
679
  hs = []
680
+ emb = timestep_embedding(timesteps, self.model_channels).to(dtype=self.time_embed[0].weight.dtype)
681
+ emb = self.time_embed(emb)
682
  if self.num_classes is not None:
683
  assert y.shape == (x.shape[0],)
684
  emb = emb + self.label_emb(y)
685
 
686
+ h = x.to(dtype=self.time_embed[0].weight.dtype)
687
  for module in self.input_blocks:
688
  h = module(h, emb)
689
  hs.append(h)
 
691
  for module in self.output_blocks:
692
  h = torch.cat([h, hs.pop()], dim=1)
693
  h = module(h, emb)
694
+ h = h.to(dtype=self.time_embed[0].weight.dtype)
695
  return self.out(h)
696
 
697
 
ADM-G-256/unet/unet_adm.py CHANGED
@@ -12,7 +12,12 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
  from diffusers.models.modeling_utils import ModelMixin
13
  from diffusers.utils import BaseOutput
14
 
15
- from modeling_adm import create_adm_unet_model
 
 
 
 
 
16
 
17
 
18
  @dataclass
 
12
  from diffusers.models.modeling_utils import ModelMixin
13
  from diffusers.utils import BaseOutput
14
 
15
+ try:
16
+ from .modeling_adm import create_adm_unet_model
17
+ except ImportError:
18
+ import importlib
19
+
20
+ create_adm_unet_model = importlib.import_module("modeling_adm").create_adm_unet_model
21
 
22
 
23
  @dataclass
ADM-G-512/README.md CHANGED
@@ -10,6 +10,8 @@ Self-contained ADM-G checkpoint inside [`BiliSakura/ADM-diffusers`](https://hugg
10
 
11
  ![ADM-G-512 demo](demo.png)
12
 
 
 
13
  ## Layout
14
 
15
  ```text
@@ -25,23 +27,27 @@ ADM-G-512/
25
  ## Load
26
 
27
  ```python
28
- import sys
29
  from pathlib import Path
30
- from huggingface_hub import snapshot_download
31
-
32
- repo_dir = Path(snapshot_download("BiliSakura/ADM-diffusers"))
33
- sys.path.insert(0, str(repo_dir / "ADM-G-512"))
34
- from pipeline import ADMPipeline
35
-
36
- pipe = ADMPipeline.from_pretrained(".")
37
- pipe.to("cuda")
38
- pipe.unet.float()
39
- pipe.classifier.float()
40
- pipe.classifier.model.dtype = torch.float32
41
-
42
- images = pipe(
43
- class_labels=207,
44
- num_inference_steps=250,
45
- classifier_guidance_scale=4.0,
46
- ).images
 
 
 
 
 
47
  ```
 
10
 
11
  ![ADM-G-512 demo](demo.png)
12
 
13
+ Settings used for this demo image: `ADM-G-512`, `DDIMScheduler`, `num_inference_steps=50`, `guidance_scale=4.0`, `seed=42`, class `"golden retriever"`.
14
+
15
  ## Layout
16
 
17
  ```text
 
27
  ## Load
28
 
29
  ```python
 
30
  from pathlib import Path
31
+ import torch
32
+ from diffusers import DDIMScheduler, DiffusionPipeline
33
+
34
+ model_dir = Path("./BiliSakura/ADM-diffusers/ADM-G-512")
35
+ pipe = DiffusionPipeline.from_pretrained(
36
+ str(model_dir),
37
+ local_files_only=True,
38
+ custom_pipeline=str(model_dir / "pipeline.py"),
39
+ torch_dtype=torch.bfloat16,
40
+ )
41
+ pipe = pipe.to("cuda")
42
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
43
+ class_id = pipe.get_label_ids("golden retriever")[0]
44
+ generator = torch.Generator(device="cuda").manual_seed(42)
45
+
46
+ out = pipe(
47
+ class_labels=class_id,
48
+ guidance_scale=4.0,
49
+ num_inference_steps=50,
50
+ generator=generator,
51
+ ).images[0]
52
+ out
53
  ```
ADM-G-512/__pycache__/pipeline.cpython-312.pyc CHANGED
Binary files a/ADM-G-512/__pycache__/pipeline.cpython-312.pyc and b/ADM-G-512/__pycache__/pipeline.cpython-312.pyc differ
 
ADM-G-512/classifier/__pycache__/classifier_adm.cpython-312.pyc CHANGED
Binary files a/ADM-G-512/classifier/__pycache__/classifier_adm.cpython-312.pyc and b/ADM-G-512/classifier/__pycache__/classifier_adm.cpython-312.pyc differ
 
ADM-G-512/classifier/classifier_adm.py CHANGED
@@ -3,18 +3,524 @@
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
 
 
 
6
  from dataclasses import dataclass
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
 
10
  import torch.nn.functional as F
 
11
 
12
  from diffusers.configuration_utils import ConfigMixin, register_to_config
 
13
  from diffusers.models.modeling_utils import ModelMixin
14
  from diffusers.utils import BaseOutput
15
 
 
16
 
17
- from modeling_adm import create_adm_classifier_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  @dataclass
 
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
 
6
+ import math
7
+ from abc import abstractmethod
8
  from dataclasses import dataclass
9
  from typing import Optional, Tuple, Union
10
 
11
  import torch
12
+ import torch.nn as nn
13
  import torch.nn.functional as F
14
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
15
 
16
  from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+ from diffusers.models.embeddings import get_timestep_embedding
18
  from diffusers.models.modeling_utils import ModelMixin
19
  from diffusers.utils import BaseOutput
20
 
21
+ NUM_CLASSES = 1000
22
 
23
+
24
+ def conv_nd(dims: int, *args, **kwargs):
25
+ if dims == 1:
26
+ return nn.Conv1d(*args, **kwargs)
27
+ if dims == 2:
28
+ return nn.Conv2d(*args, **kwargs)
29
+ if dims == 3:
30
+ return nn.Conv3d(*args, **kwargs)
31
+ raise ValueError(f"unsupported dimensions: {dims}")
32
+
33
+
34
+ def linear(*args, **kwargs):
35
+ return nn.Linear(*args, **kwargs)
36
+
37
+
38
+ def avg_pool_nd(dims: int, *args, **kwargs):
39
+ if dims == 1:
40
+ return nn.AvgPool1d(*args, **kwargs)
41
+ if dims == 2:
42
+ return nn.AvgPool2d(*args, **kwargs)
43
+ if dims == 3:
44
+ return nn.AvgPool3d(*args, **kwargs)
45
+ raise ValueError(f"unsupported dimensions: {dims}")
46
+
47
+
48
+ class GroupNorm32(nn.GroupNorm):
49
+ def forward(self, x):
50
+ weight = self.weight.float() if self.weight is not None else None
51
+ bias = self.bias.float() if self.bias is not None else None
52
+ y = F.group_norm(x.float(), self.num_groups, weight, bias, self.eps)
53
+ return y.to(dtype=x.dtype)
54
+
55
+
56
+ def normalization(channels: int):
57
+ return GroupNorm32(32, channels)
58
+
59
+
60
+ def zero_module(module: nn.Module):
61
+ for p in module.parameters():
62
+ p.detach().zero_()
63
+ return module
64
+
65
+
66
+ def convert_module_to_f16(module: nn.Module):
67
+ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
68
+ module.weight.data = module.weight.data.half()
69
+ if module.bias is not None:
70
+ module.bias.data = module.bias.data.half()
71
+
72
+
73
+ def convert_module_to_f32(module: nn.Module):
74
+ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
75
+ module.weight.data = module.weight.data.float()
76
+ if module.bias is not None:
77
+ module.bias.data = module.bias.data.float()
78
+
79
+
80
+ class TimestepBlock(nn.Module):
81
+ @abstractmethod
82
+ def forward(self, x, emb):
83
+ raise NotImplementedError
84
+
85
+
86
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
87
+ def forward(self, x, emb):
88
+ for layer in self:
89
+ if isinstance(layer, TimestepBlock):
90
+ x = layer(x, emb)
91
+ else:
92
+ x = layer(x)
93
+ return x
94
+
95
+
96
+ class Upsample(nn.Module):
97
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
98
+ super().__init__()
99
+ self.channels = channels
100
+ self.out_channels = out_channels or channels
101
+ self.use_conv = use_conv
102
+ self.dims = dims
103
+ if use_conv:
104
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
105
+
106
+ def forward(self, x):
107
+ assert x.shape[1] == self.channels
108
+ if self.dims == 3:
109
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
110
+ else:
111
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
112
+ if self.use_conv:
113
+ x = self.conv(x)
114
+ return x
115
+
116
+
117
+ class Downsample(nn.Module):
118
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
119
+ super().__init__()
120
+ self.channels = channels
121
+ self.out_channels = out_channels or channels
122
+ self.use_conv = use_conv
123
+ stride = 2 if dims != 3 else (1, 2, 2)
124
+ if use_conv:
125
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
126
+ else:
127
+ assert self.channels == self.out_channels
128
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
129
+
130
+ def forward(self, x):
131
+ assert x.shape[1] == self.channels
132
+ return self.op(x)
133
+
134
+
135
+ class ResBlock(TimestepBlock):
136
+ def __init__(
137
+ self,
138
+ channels,
139
+ emb_channels,
140
+ dropout,
141
+ out_channels=None,
142
+ use_conv=False,
143
+ use_scale_shift_norm=False,
144
+ dims=2,
145
+ use_checkpoint=False,
146
+ up=False,
147
+ down=False,
148
+ ):
149
+ super().__init__()
150
+ self.channels = channels
151
+ self.out_channels = out_channels or channels
152
+ self.use_checkpoint = use_checkpoint
153
+ self.use_scale_shift_norm = use_scale_shift_norm
154
+ self.in_layers = nn.Sequential(
155
+ normalization(channels),
156
+ nn.SiLU(),
157
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
158
+ )
159
+
160
+ self.updown = up or down
161
+ if up:
162
+ self.h_upd = Upsample(channels, False, dims)
163
+ self.x_upd = Upsample(channels, False, dims)
164
+ elif down:
165
+ self.h_upd = Downsample(channels, False, dims)
166
+ self.x_upd = Downsample(channels, False, dims)
167
+ else:
168
+ self.h_upd = self.x_upd = nn.Identity()
169
+
170
+ self.emb_layers = nn.Sequential(
171
+ nn.SiLU(),
172
+ linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
173
+ )
174
+ self.out_layers = nn.Sequential(
175
+ normalization(self.out_channels),
176
+ nn.SiLU(),
177
+ nn.Dropout(p=dropout),
178
+ zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
179
+ )
180
+
181
+ if self.out_channels == channels:
182
+ self.skip_connection = nn.Identity()
183
+ elif use_conv:
184
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
185
+ else:
186
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
187
+
188
+ def forward(self, x, emb):
189
+ if self.use_checkpoint and x.requires_grad:
190
+ return torch_checkpoint(self._forward, x, emb, use_reentrant=False)
191
+ return self._forward(x, emb)
192
+
193
+ def _forward(self, x, emb):
194
+ if self.updown:
195
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
196
+ h = in_rest(x)
197
+ h = self.h_upd(h)
198
+ x = self.x_upd(x)
199
+ h = in_conv(h)
200
+ else:
201
+ h = self.in_layers(x)
202
+
203
+ emb_out = self.emb_layers(emb).type(h.dtype)
204
+ while len(emb_out.shape) < len(h.shape):
205
+ emb_out = emb_out[..., None]
206
+ if self.use_scale_shift_norm:
207
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
208
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
209
+ h = out_norm(h) * (1 + scale) + shift
210
+ h = out_rest(h)
211
+ else:
212
+ h = h + emb_out
213
+ h = self.out_layers(h)
214
+ return self.skip_connection(x) + h
215
+
216
+
217
+ class QKVAttentionLegacy(nn.Module):
218
+ def __init__(self, n_heads):
219
+ super().__init__()
220
+ self.n_heads = n_heads
221
+
222
+ def forward(self, qkv):
223
+ bs, width, length = qkv.shape
224
+ assert width % (3 * self.n_heads) == 0
225
+ ch = width // (3 * self.n_heads)
226
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
227
+ scale = 1 / math.sqrt(math.sqrt(ch))
228
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
229
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
230
+ a = torch.einsum("bts,bcs->bct", weight, v)
231
+ return a.reshape(bs, -1, length)
232
+
233
+
234
+ class QKVAttention(nn.Module):
235
+ def __init__(self, n_heads):
236
+ super().__init__()
237
+ self.n_heads = n_heads
238
+
239
+ def forward(self, qkv):
240
+ bs, width, length = qkv.shape
241
+ assert width % (3 * self.n_heads) == 0
242
+ ch = width // (3 * self.n_heads)
243
+ q, k, v = qkv.chunk(3, dim=1)
244
+ scale = 1 / math.sqrt(math.sqrt(ch))
245
+ weight = torch.einsum(
246
+ "bct,bcs->bts",
247
+ (q * scale).view(bs * self.n_heads, ch, length),
248
+ (k * scale).view(bs * self.n_heads, ch, length),
249
+ )
250
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
251
+ a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
252
+ return a.reshape(bs, -1, length)
253
+
254
+
255
+ class AttentionBlock(nn.Module):
256
+ def __init__(
257
+ self,
258
+ channels,
259
+ num_heads=1,
260
+ num_head_channels=-1,
261
+ use_checkpoint=False,
262
+ use_new_attention_order=False,
263
+ ):
264
+ super().__init__()
265
+ if num_head_channels == -1:
266
+ self.num_heads = num_heads
267
+ else:
268
+ assert channels % num_head_channels == 0
269
+ self.num_heads = channels // num_head_channels
270
+ self.use_checkpoint = use_checkpoint
271
+ self.norm = normalization(channels)
272
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
273
+ self.attention = QKVAttention(self.num_heads) if use_new_attention_order else QKVAttentionLegacy(self.num_heads)
274
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
275
+
276
+ def forward(self, x):
277
+ if self.use_checkpoint and x.requires_grad:
278
+ return torch_checkpoint(self._forward, x, use_reentrant=False)
279
+ return self._forward(x)
280
+
281
+ def _forward(self, x):
282
+ b, c, *spatial = x.shape
283
+ x = x.reshape(b, c, -1)
284
+ qkv = self.qkv(self.norm(x))
285
+ h = self.attention(qkv)
286
+ h = self.proj_out(h)
287
+ return (x + h).reshape(b, c, *spatial)
288
+
289
+
290
+ class AttentionPool2d(nn.Module):
291
+ """CLIP-style attention pooling used by ADM noisy classifiers."""
292
+
293
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None):
294
+ super().__init__()
295
+ self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
296
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
297
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
298
+ self.num_heads = embed_dim // num_heads_channels
299
+ self.attention = QKVAttention(self.num_heads)
300
+
301
+ def forward(self, x):
302
+ b, c, *_spatial = x.shape
303
+ x = x.reshape(b, c, -1)
304
+ x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
305
+ x = x + self.positional_embedding[None, :, :].to(x.dtype)
306
+ x = self.qkv_proj(x)
307
+ x = self.attention(x)
308
+ x = self.c_proj(x)
309
+ return x[:, :, 0]
310
+
311
+
312
+ class EncoderUNetModel(nn.Module):
313
+ """Noisy image classifier backbone for ADM-G (classifier guidance)."""
314
+
315
+ def __init__(
316
+ self,
317
+ image_size,
318
+ in_channels,
319
+ model_channels,
320
+ out_channels,
321
+ num_res_blocks,
322
+ attention_resolutions,
323
+ dropout=0,
324
+ channel_mult=(1, 2, 4, 8),
325
+ conv_resample=True,
326
+ dims=2,
327
+ use_checkpoint=False,
328
+ use_fp16=False,
329
+ num_heads=1,
330
+ num_head_channels=-1,
331
+ use_scale_shift_norm=False,
332
+ resblock_updown=False,
333
+ use_new_attention_order=False,
334
+ pool="adaptive",
335
+ ):
336
+ super().__init__()
337
+
338
+ self.model_channels = model_channels
339
+ self.use_checkpoint = use_checkpoint
340
+ self.dtype = torch.float16 if use_fp16 else torch.float32
341
+
342
+ time_embed_dim = model_channels * 4
343
+ self.time_embed = nn.Sequential(
344
+ linear(model_channels, time_embed_dim),
345
+ nn.SiLU(),
346
+ linear(time_embed_dim, time_embed_dim),
347
+ )
348
+
349
+ ch = int(channel_mult[0] * model_channels)
350
+ self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
351
+ self._feature_size = ch
352
+ ds = 1
353
+ for level, mult in enumerate(channel_mult):
354
+ for _ in range(num_res_blocks):
355
+ layers = [
356
+ ResBlock(
357
+ ch,
358
+ time_embed_dim,
359
+ dropout,
360
+ out_channels=int(mult * model_channels),
361
+ dims=dims,
362
+ use_checkpoint=use_checkpoint,
363
+ use_scale_shift_norm=use_scale_shift_norm,
364
+ )
365
+ ]
366
+ ch = int(mult * model_channels)
367
+ if ds in attention_resolutions:
368
+ layers.append(
369
+ AttentionBlock(
370
+ ch,
371
+ use_checkpoint=use_checkpoint,
372
+ num_heads=num_heads,
373
+ num_head_channels=num_head_channels,
374
+ use_new_attention_order=use_new_attention_order,
375
+ )
376
+ )
377
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
378
+ self._feature_size += ch
379
+ if level != len(channel_mult) - 1:
380
+ out_ch = ch
381
+ self.input_blocks.append(
382
+ TimestepEmbedSequential(
383
+ ResBlock(
384
+ ch,
385
+ time_embed_dim,
386
+ dropout,
387
+ out_channels=out_ch,
388
+ dims=dims,
389
+ use_checkpoint=use_checkpoint,
390
+ use_scale_shift_norm=use_scale_shift_norm,
391
+ down=True,
392
+ )
393
+ if resblock_updown
394
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
395
+ )
396
+ )
397
+ ch = out_ch
398
+ ds *= 2
399
+ self._feature_size += ch
400
+
401
+ self.middle_block = TimestepEmbedSequential(
402
+ ResBlock(
403
+ ch,
404
+ time_embed_dim,
405
+ dropout,
406
+ dims=dims,
407
+ use_checkpoint=use_checkpoint,
408
+ use_scale_shift_norm=use_scale_shift_norm,
409
+ ),
410
+ AttentionBlock(
411
+ ch,
412
+ use_checkpoint=use_checkpoint,
413
+ num_heads=num_heads,
414
+ num_head_channels=num_head_channels,
415
+ use_new_attention_order=use_new_attention_order,
416
+ ),
417
+ ResBlock(
418
+ ch,
419
+ time_embed_dim,
420
+ dropout,
421
+ dims=dims,
422
+ use_checkpoint=use_checkpoint,
423
+ use_scale_shift_norm=use_scale_shift_norm,
424
+ ),
425
+ )
426
+ self._feature_size += ch
427
+ self.pool = pool
428
+ if pool == "adaptive":
429
+ self.out = nn.Sequential(
430
+ normalization(ch),
431
+ nn.SiLU(),
432
+ nn.AdaptiveAvgPool2d((1, 1)),
433
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
434
+ nn.Flatten(),
435
+ )
436
+ elif pool == "attention":
437
+ assert num_head_channels != -1
438
+ self.out = nn.Sequential(
439
+ normalization(ch),
440
+ nn.SiLU(),
441
+ AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
442
+ )
443
+ elif pool == "spatial":
444
+ self.out = nn.Sequential(
445
+ nn.Linear(self._feature_size, 2048),
446
+ nn.ReLU(),
447
+ nn.Linear(2048, out_channels),
448
+ )
449
+ elif pool == "spatial_v2":
450
+ self.out = nn.Sequential(
451
+ nn.Linear(self._feature_size, 2048),
452
+ normalization(2048),
453
+ nn.SiLU(),
454
+ nn.Linear(2048, out_channels),
455
+ )
456
+ else:
457
+ raise NotImplementedError(f"Unexpected {pool} pooling")
458
+
459
+ def convert_to_fp16(self):
460
+ self.input_blocks.apply(convert_module_to_f16)
461
+ self.middle_block.apply(convert_module_to_f16)
462
+
463
+ def convert_to_fp32(self):
464
+ self.input_blocks.apply(convert_module_to_f32)
465
+ self.middle_block.apply(convert_module_to_f32)
466
+
467
+ def forward(self, x, timesteps):
468
+ emb = get_timestep_embedding(timesteps, self.model_channels).to(dtype=self.time_embed[0].weight.dtype)
469
+ emb = self.time_embed(emb)
470
+ results = []
471
+ h = x.to(dtype=self.time_embed[0].weight.dtype)
472
+ for module in self.input_blocks:
473
+ h = module(h, emb)
474
+ if self.pool.startswith("spatial"):
475
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
476
+ h = self.middle_block(h, emb)
477
+ if self.pool.startswith("spatial"):
478
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
479
+ h = torch.cat(results, dim=-1)
480
+ return self.out(h)
481
+ h = h.to(dtype=self.time_embed[0].weight.dtype)
482
+ return self.out(h)
483
+
484
+
485
+ def _default_channel_mult(image_size: int):
486
+ if image_size == 512:
487
+ return (0.5, 1, 1, 2, 2, 4, 4)
488
+ if image_size == 256:
489
+ return (1, 1, 2, 2, 4, 4)
490
+ if image_size == 128:
491
+ return (1, 1, 2, 3, 4)
492
+ if image_size == 64:
493
+ return (1, 2, 3, 4)
494
+ raise ValueError(f"unsupported image size: {image_size}")
495
+
496
+
497
+ def create_adm_classifier_model(
498
+ image_size: int,
499
+ classifier_width: int = 128,
500
+ classifier_depth: int = 2,
501
+ classifier_attention_resolutions: str = "32,16,8",
502
+ classifier_use_scale_shift_norm: bool = True,
503
+ classifier_resblock_updown: bool = True,
504
+ classifier_pool: str = "attention",
505
+ use_fp16: bool = False,
506
+ num_classes: int = NUM_CLASSES,
507
+ ):
508
+ channel_mult = _default_channel_mult(image_size)
509
+ attention_ds = tuple(image_size // int(res) for res in classifier_attention_resolutions.split(","))
510
+ return EncoderUNetModel(
511
+ image_size=image_size,
512
+ in_channels=3,
513
+ model_channels=classifier_width,
514
+ out_channels=num_classes,
515
+ num_res_blocks=classifier_depth,
516
+ attention_resolutions=attention_ds,
517
+ channel_mult=channel_mult,
518
+ use_fp16=use_fp16,
519
+ num_head_channels=64,
520
+ use_scale_shift_norm=classifier_use_scale_shift_norm,
521
+ resblock_updown=classifier_resblock_updown,
522
+ pool=classifier_pool,
523
+ )
524
 
525
 
526
  @dataclass
ADM-G-512/demo.png CHANGED

Git LFS Details

  • SHA256: e6e3afc16ac17292f33ae8f13f28145d33a882ae781a7a42c137687c8f98dea8
  • Pointer size: 131 Bytes
  • Size of remote file: 326 kB

Git LFS Details

  • SHA256: 82ea34d28d5fe28f719a7142da3194e6cfc860db7ac51f0478dba6600e87bf56
  • Pointer size: 131 Bytes
  • Size of remote file: 300 kB
ADM-G-512/model_index.json CHANGED
@@ -2,8 +2,8 @@
2
  "_class_name": "ADMPipeline",
3
  "_diffusers_version": "0.36.0",
4
  "scheduler": [
5
- "scheduling_adm",
6
- "ADMScheduler"
7
  ],
8
  "unet": [
9
  "unet_adm",
 
2
  "_class_name": "ADMPipeline",
3
  "_diffusers_version": "0.36.0",
4
  "scheduler": [
5
+ "diffusers",
6
+ "DDPMScheduler"
7
  ],
8
  "unet": [
9
  "unet_adm",
ADM-G-512/pipeline.py CHANGED
@@ -2,208 +2,72 @@
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
 
 
 
 
 
 
 
 
5
 
6
- """Hub custom pipeline: ADMPipeline.
7
-
8
- Load with native Hugging Face diffusers and `trust_remote_code=True`.
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- import importlib
14
- import sys
15
- from dataclasses import dataclass
16
- from pathlib import Path
17
- from typing import Dict, List, Optional, Tuple, Union
18
 
19
- import numpy as np
20
  import torch
21
- from tqdm.auto import tqdm
22
 
23
  from diffusers.image_processor import VaeImageProcessor
24
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
- from diffusers.utils import BaseOutput, replace_example_docstring
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
-
29
  EXAMPLE_DOC_STRING = """
30
  Examples:
31
  ```py
 
32
  >>> import torch
33
  >>> from diffusers import DiffusionPipeline
34
 
35
- >>> from pipeline import ADMPipeline
36
-
37
- >>> pipe = ADMPipeline.from_pretrained(".")
38
- >>> pipe.to("cuda")
39
-
40
- >>> # ADM-G (classifier guidance) with numeric class id
41
- >>> images = pipe(class_labels=207, classifier_guidance_scale=1.0, num_inference_steps=250).images
42
-
43
- >>> # Or use human-readable ImageNet labels (English)
44
- >>> pipe.id2label[207]
45
- >>> class_ids = pipe.get_label_ids("golden retriever")
46
- >>> images = pipe(class_labels="golden retriever", classifier_guidance_scale=1.0).images
47
  ```
48
  """
49
 
50
 
51
- @dataclass
52
- class ADMPipelineOutput(BaseOutput):
53
- """
54
- Output class for ADM pipelines.
55
-
56
- Args:
57
- images (`torch.Tensor` or `list[PIL.Image.Image]` or `np.ndarray`):
58
- Generated images of shape `(batch_size, num_channels, height, width)` when `output_type="pt"`,
59
- or a list of PIL images / NumPy array when post-processed.
60
- """
61
-
62
- images: Union[torch.Tensor, List, np.ndarray]
63
-
64
-
65
  class ADMPipeline(DiffusionPipeline):
66
- r"""
67
- Pipeline for image generation with ADM (Ablated Diffusion Model).
68
-
69
- Supports class-conditional ADM (labels embedded in the UNet) and **ADM-G** (unconditional UNet + noisy
70
- classifier guidance). For ADM-G, pass `classifier_guidance_scale > 0` and provide `class_labels`; the
71
- optional `classifier` predicts `p(y | x_t)` and steers sampling.
72
-
73
- Args:
74
- unet ([`ADMUNet2DModel`]):
75
- A UNet model to denoise image samples (typically unconditional for ADM-G).
76
- scheduler ([`ADMScheduler`]):
77
- A scheduler used with the UNet to denoise image samples.
78
- classifier ([`ADMClassifierModel`], *optional*):
79
- Noisy ImageNet classifier for ADM-G guidance.
80
- id2label (`dict[int, str]`, *optional*):
81
- ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
82
- """
83
 
84
  model_cpu_offload_seq = "classifier->unet"
85
  _optional_components = ["classifier"]
86
 
87
- @classmethod
88
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
89
- """Load a self-contained variant folder locally or from the Hub.
90
-
91
- Examples:
92
- ADMPipeline.from_pretrained(".")
93
- ADMPipeline.from_pretrained("./ADM-G-256")
94
- ADMPipeline.from_pretrained("BiliSakura/ADM-diffusers", subfolder="ADM-G-512")
95
- """
96
- repo_root = Path(__file__).resolve().parent
97
-
98
- if pretrained_model_name_or_path in (None, "", "."):
99
- variant = repo_root
100
- elif (
101
- isinstance(pretrained_model_name_or_path, str)
102
- and "/" in pretrained_model_name_or_path
103
- and not Path(pretrained_model_name_or_path).exists()
104
- ):
105
- from huggingface_hub import snapshot_download
106
-
107
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
108
- if subfolder:
109
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
110
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
111
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
112
- else:
113
- variant = Path(pretrained_model_name_or_path)
114
- if not variant.is_absolute():
115
- candidate = (Path.cwd() / variant).resolve()
116
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
117
- if subfolder:
118
- variant = variant / subfolder
119
-
120
- id2label_override = kwargs.pop("id2label", None)
121
- model_kwargs = dict(kwargs)
122
- inserted: List[str] = []
123
-
124
- def _load_component(folder: str, module_name: str, class_name: str):
125
- comp_dir = variant / folder
126
- module_path = comp_dir / f"{module_name}.py"
127
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
128
- if not module_path.exists() or not has_weights:
129
- return None
130
-
131
- comp_path = str(comp_dir)
132
- if comp_path not in sys.path:
133
- sys.path.insert(0, comp_path)
134
- inserted.append(comp_path)
135
-
136
- module = importlib.import_module(module_name)
137
- component_cls = getattr(module, class_name)
138
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
139
-
140
- try:
141
- unet = _load_component("unet", "unet_adm", "ADMUNet2DModel")
142
- scheduler = _load_component("scheduler", "scheduling_adm", "ADMScheduler")
143
- classifier = _load_component("classifier", "classifier_adm", "ADMClassifierModel")
144
-
145
- if scheduler is None:
146
- sched_dir = variant / "scheduler"
147
- if (sched_dir / "scheduling_adm.py").exists():
148
- sched_path = str(sched_dir)
149
- if sched_path not in sys.path:
150
- sys.path.insert(0, sched_path)
151
- inserted.append(sched_path)
152
- scheduler = importlib.import_module("scheduling_adm").ADMScheduler()
153
-
154
- if unet is None and classifier is None:
155
- raise ValueError(f"No loadable components found under {variant}")
156
-
157
- id2label = id2label_override
158
- if id2label is None:
159
- model_index_path = variant / "model_index.json"
160
- if model_index_path.exists():
161
- id2label = cls._read_id2label_from_model_index(model_index_path)
162
-
163
- return cls(
164
- unet=unet,
165
- scheduler=scheduler,
166
- classifier=classifier,
167
- id2label=id2label,
168
- )
169
- finally:
170
- for comp_path in inserted:
171
- if comp_path in sys.path:
172
- sys.path.remove(comp_path)
173
-
174
  def __init__(
175
  self,
176
  unet,
177
- scheduler,
178
- classifier=None,
179
- id2label: Optional[Dict[Union[int, str], str]] = None,
180
- ):
 
181
  super().__init__()
182
  self.register_modules(unet=unet, scheduler=scheduler, classifier=classifier)
 
183
  self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
184
-
185
- self._id2label = self._normalize_id2label(id2label)
186
  self.labels = self._build_label2id(self._id2label)
187
 
188
  @staticmethod
189
- def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
190
- if not id2label:
191
- return {}
192
- return {int(key): value for key, value in id2label.items()}
193
-
194
- @staticmethod
195
- def _read_id2label_from_model_index(model_index_path: Path) -> Optional[Dict[int, str]]:
196
- import json
197
-
198
- raw = json.loads(model_index_path.read_text(encoding="utf-8"))
199
- id2label = raw.get("id2label")
200
- if not isinstance(id2label, dict):
201
- return None
202
- return {int(key): value for key, value in id2label.items()}
203
-
204
- @staticmethod
205
- def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
206
- label2id: dict[str, int] = {}
207
  for class_id, value in id2label.items():
208
  for synonym in value.split(","):
209
  synonym = synonym.strip()
@@ -212,153 +76,44 @@ class ADMPipeline(DiffusionPipeline):
212
  return dict(sorted(label2id.items()))
213
 
214
  @property
215
- def id2label(self) -> dict[int, str]:
216
- """ImageNet class id to English label string (comma-separated synonyms)."""
217
  return self._id2label
218
 
219
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
220
- r"""
221
- Map ImageNet label strings to class ids.
222
-
223
- Args:
224
- label (`str` or `list[str]`):
225
- One or more English ImageNet label strings matching a synonym in `id2label`.
226
-
227
- Returns:
228
- `list[int]`: Class ids for [`~ADMPipeline.__call__`].
229
- """
230
- label2id = self.labels
231
- if not label2id:
232
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
233
-
234
- if isinstance(label, str):
235
- label = [label]
236
-
237
- missing = [item for item in label if item not in label2id]
238
  if missing:
239
- preview = ", ".join(list(label2id.keys())[:8])
240
- raise ValueError(
241
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
242
- )
243
- return [label2id[item] for item in label]
244
-
245
- @property
246
- def do_classifier_guidance(self) -> bool:
247
- return self.classifier is not None and getattr(self, "_classifier_guidance_scale", 0.0) > 0
248
-
249
- def _normalize_class_labels(
250
- self,
251
- class_labels: Optional[Union[int, str, List[Union[int, str]], torch.Tensor]],
252
- ) -> Optional[Union[int, List[int], torch.Tensor]]:
253
- if class_labels is None:
254
- return None
255
-
256
- if isinstance(class_labels, str):
257
- return self.get_label_ids(class_labels)[0]
258
-
259
- if isinstance(class_labels, list) and class_labels and isinstance(class_labels[0], str):
260
- return self.get_label_ids(class_labels)
261
-
262
- return class_labels
263
 
264
- def check_inputs(
265
- self,
266
- class_labels: Optional[Union[int, str, List[Union[int, str]], torch.Tensor]],
267
- height: Optional[int],
268
- width: Optional[int],
269
- ):
270
- if class_labels is None and self.unet.config.class_cond:
271
- raise ValueError("`class_labels` are required for class-conditional ADM checkpoints.")
272
-
273
- if class_labels is not None and self.classifier is None and not self.unet.config.class_cond:
274
- raise ValueError(
275
- "This checkpoint is unconditional and has no classifier. Load an ADM-G repo with a "
276
- "`classifier/` subfolder, or use a class-conditional UNet."
277
- )
278
-
279
- if height is not None and height % 8 != 0:
280
- raise ValueError(f"`height` must be divisible by 8 but is {height}.")
281
- if width is not None and width % 8 != 0:
282
- raise ValueError(f"`width` must be divisible by 8 but is {width}.")
283
-
284
- def _prepare_class_labels(
285
- self,
286
- class_labels: Optional[Union[int, List[int], torch.Tensor]],
287
- batch_size: int,
288
- device: torch.device,
289
- ) -> Optional[torch.Tensor]:
290
- if class_labels is None:
291
- return None
292
-
293
- if isinstance(class_labels, int):
294
- class_labels = [class_labels]
295
- if not torch.is_tensor(class_labels):
296
- class_labels = torch.tensor(class_labels, device=device, dtype=torch.long)
297
- else:
298
- class_labels = class_labels.to(device=device, dtype=torch.long)
299
-
300
- if class_labels.shape[0] != batch_size:
301
- raise ValueError(
302
- f"`class_labels` batch ({class_labels.shape[0]}) must match requested batch size ({batch_size})."
303
- )
304
- return class_labels
305
 
306
- def _get_classifier_grad(
307
- self,
308
- sample: torch.Tensor,
309
- timestep: torch.Tensor,
310
- class_labels: torch.Tensor,
311
- classifier_scale: float,
312
- ) -> torch.Tensor:
313
- return self.classifier.guidance_gradient(
314
- sample,
315
- timestep,
316
- class_labels,
317
- classifier_scale=classifier_scale,
318
- )
319
 
320
- def prepare_latents(
321
- self,
322
- batch_size: int,
323
- num_channels: int,
324
- height: int,
325
- width: int,
326
- dtype: torch.dtype,
327
- device: torch.device,
328
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
329
- latents: Optional[torch.Tensor] = None,
330
- ) -> torch.Tensor:
331
- """
332
- Prepare initial Gaussian noise for pixel-space sampling.
333
-
334
- Args:
335
- batch_size (`int`):
336
- Number of images to generate.
337
- num_channels (`int`):
338
- Number of image channels (typically 3).
339
- height (`int`):
340
- Image height in pixels.
341
- width (`int`):
342
- Image width in pixels.
343
- dtype (`torch.dtype`):
344
- Data type for the latent tensor.
345
- device (`torch.device`):
346
- Target device.
347
- generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
348
- RNG for deterministic sampling.
349
- latents (`torch.Tensor`, *optional*):
350
- Pre-generated noise tensor.
351
-
352
- Returns:
353
- `torch.Tensor`:
354
- Initial noise of shape `(batch_size, num_channels, height, width)`.
355
- """
356
- shape = (batch_size, num_channels, height, width)
357
- if latents is None:
358
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
359
- else:
360
- latents = latents.to(device=device, dtype=dtype)
361
- return latents
362
 
363
  @torch.no_grad()
364
  @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -369,142 +124,130 @@ class ADMPipeline(DiffusionPipeline):
369
  height: Optional[int] = None,
370
  width: Optional[int] = None,
371
  num_inference_steps: int = 250,
372
- use_ddim: bool = False,
 
373
  eta: float = 0.0,
374
  clip_denoised: bool = True,
375
- classifier_guidance_scale: float = 0.0,
376
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
377
  latents: Optional[torch.Tensor] = None,
378
  output_type: str = "pil",
379
  return_dict: bool = True,
380
- ) -> Union[ADMPipelineOutput, Tuple]:
381
  r"""
382
- Generate images with ADM.
383
-
384
- Args:
385
- class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
386
- ImageNet class indices or English label strings. Required for class-conditional UNets and for ADM-G
387
- classifier guidance. Strings are resolved via [`~ADMPipeline.get_label_ids`].
388
- batch_size (`int`, *optional*, defaults to 1):
389
- Number of images to generate when `class_labels` is not provided.
390
- height (`int`, *optional*):
391
- Height in pixels. Defaults to `unet.config.image_size`.
392
- width (`int`, *optional*):
393
- Width in pixels. Defaults to `unet.config.image_size`.
394
- num_inference_steps (`int`, *optional*, defaults to 250):
395
- Number of denoising steps.
396
- use_ddim (`bool`, *optional*, defaults to `False`):
397
- Use DDIM sampling instead of DDPM.
398
- eta (`float`, *optional*, defaults to 0.0):
399
- DDIM stochasticity parameter. Only used when `use_ddim=True`.
400
- clip_denoised (`bool`, *optional*, defaults to `True`):
401
- Clamp predicted `x_0` to `[-1, 1]` inside the scheduler.
402
- classifier_guidance_scale (`float`, *optional*, defaults to 0.0):
403
- ADM-G guidance strength. Values `> 0` require a loaded `classifier` (OpenAI `classifier_scale`).
404
- generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
405
- RNG for reproducible generation.
406
- latents (`torch.Tensor`, *optional*):
407
- Pre-generated initial noise.
408
- output_type (`str`, *optional*, defaults to `"pil"`):
409
- Output format: `"pil"`, `"np"`, or `"pt"`.
410
- return_dict (`bool`, *optional*, defaults to `True`):
411
- Return an [`ADMPipelineOutput`] instead of a tuple.
412
 
413
  Examples:
414
-
415
- Returns:
416
- [`ADMPipelineOutput`] or `tuple`:
417
- Generated images.
418
  """
419
- if height is None:
420
- height = int(self.unet.config.image_size)
421
- if width is None:
422
- width = int(self.unet.config.image_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
- class_labels = self._normalize_class_labels(class_labels)
425
- self.check_inputs(class_labels, height, width)
426
 
427
- if classifier_guidance_scale > 0 and self.classifier is None:
428
- raise ValueError("`classifier_guidance_scale > 0` requires a loaded `classifier` (ADM-G checkpoint).")
429
- if classifier_guidance_scale > 0 and class_labels is None:
430
- raise ValueError("`class_labels` are required when using classifier guidance.")
 
 
 
431
 
432
- self._classifier_guidance_scale = classifier_guidance_scale
433
  device = self._execution_device
434
- model_dtype = self.unet.dtype
 
435
 
 
 
 
436
  if class_labels is not None:
437
- if isinstance(class_labels, int):
438
- batch_size = 1
439
- elif isinstance(class_labels, list):
440
- batch_size = len(class_labels)
441
- elif torch.is_tensor(class_labels):
442
- batch_size = class_labels.shape[0]
443
-
444
- class_labels = self._prepare_class_labels(class_labels, batch_size, device)
445
-
446
- latents = self.prepare_latents(
447
- batch_size,
448
- 3,
449
- height,
450
- width,
451
- model_dtype,
452
- device,
453
- generator,
454
- latents,
455
- )
456
-
457
- self.scheduler.set_timesteps(num_inference_steps, device=device, use_ddim=use_ddim)
458
- self.scheduler._eta = eta
459
-
460
- self._num_timesteps = len(self.scheduler.timesteps)
461
-
462
- unet_class_labels = class_labels if self.unet.config.class_cond else None
463
 
464
- for t in tqdm(self.scheduler.timesteps, desc="Denoising"):
465
- timestep = torch.full((batch_size,), t, device=device, dtype=torch.long)
466
- model_timesteps = self.scheduler.scale_timesteps_for_model(timestep)
467
 
468
- model_output = self.unet(
469
- latents,
470
- model_timesteps,
471
- class_labels=unet_class_labels,
472
- return_dict=True,
473
- ).sample
474
 
475
  cond_grad = None
476
- if self.do_classifier_guidance:
477
- cond_grad = self._get_classifier_grad(
478
- latents,
479
- timestep,
480
- class_labels,
481
- classifier_guidance_scale,
482
  )
483
 
484
- latents = self.scheduler.step(
485
- model_output,
486
- t,
487
- latents,
488
- generator=generator,
489
- clip_denoised=clip_denoised,
490
- eta=eta,
491
- cond_grad=cond_grad,
492
- ).prev_sample
493
-
494
- image = latents
495
- has_nsfw_concept = None
496
-
497
- if output_type == "latent":
498
- image = latents
499
- elif output_type == "pt":
500
- image = (image / 2 + 0.5).clamp(0, 1)
501
- elif output_type in ("pil", "np"):
502
- image = (image / 2 + 0.5).clamp(0, 1)
 
 
 
 
 
503
  image = self.image_processor.postprocess(image, output_type=output_type)
504
 
505
  self.maybe_free_model_hooks()
506
-
507
  if not return_dict:
508
- return (image, has_nsfw_concept)
509
-
510
- return ADMPipelineOutput(images=image)
 
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
 
15
+ import inspect
16
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 
 
 
 
 
 
 
 
 
17
 
 
18
  import torch
 
19
 
20
  from diffusers.image_processor import VaeImageProcessor
21
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
22
+ from diffusers.schedulers import KarrasDiffusionSchedulers
23
+ from diffusers.utils import replace_example_docstring
24
  from diffusers.utils.torch_utils import randn_tensor
25
 
 
26
  EXAMPLE_DOC_STRING = """
27
  Examples:
28
  ```py
29
+ >>> from pathlib import Path
30
  >>> import torch
31
  >>> from diffusers import DiffusionPipeline
32
 
33
+ >>> model_dir = Path("path/to/BiliSakura/ADM-diffusers/ADM-G-512")
34
+ >>> pipe = DiffusionPipeline.from_pretrained(
35
+ ... str(model_dir),
36
+ ... local_files_only=True,
37
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
38
+ ... torch_dtype=torch.bfloat16,
39
+ ... )
40
+ >>> pipe = pipe.to("cuda")
41
+ >>> class_id = pipe.get_label_ids("golden retriever")[0]
42
+ >>> image = pipe(class_labels=class_id, guidance_scale=4.0).images[0]
 
 
43
  ```
44
  """
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class ADMPipeline(DiffusionPipeline):
48
+ r"""ADM/ADM-G pipeline compatible with Diffusers custom pipeline loading."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  model_cpu_offload_seq = "classifier->unet"
51
  _optional_components = ["classifier"]
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def __init__(
54
  self,
55
  unet,
56
+ scheduler: KarrasDiffusionSchedulers,
57
+ classifier: Optional[Any] = None,
58
+ id2label: Optional[Dict[str, str]] = None,
59
+ null_class_id: int = 1000,
60
+ ) -> None:
61
  super().__init__()
62
  self.register_modules(unet=unet, scheduler=scheduler, classifier=classifier)
63
+ self.register_to_config(null_class_id=int(null_class_id))
64
  self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
65
+ self._id2label = {int(k): v for k, v in (id2label or {}).items()}
 
66
  self.labels = self._build_label2id(self._id2label)
67
 
68
  @staticmethod
69
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
70
+ label2id: Dict[str, int] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  for class_id, value in id2label.items():
72
  for synonym in value.split(","):
73
  synonym = synonym.strip()
 
76
  return dict(sorted(label2id.items()))
77
 
78
  @property
79
+ def id2label(self) -> Dict[int, str]:
 
80
  return self._id2label
81
 
82
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
83
+ if not self.labels:
84
+ raise ValueError("No id2label mapping is available in this checkpoint.")
85
+ labels = [label] if isinstance(label, str) else label
86
+ missing = [item for item in labels if item not in self.labels]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  if missing:
88
+ preview = ", ".join(list(self.labels.keys())[:8])
89
+ raise ValueError(f"Unknown labels: {missing}. Example valid labels: {preview}, ...")
90
+ return [self.labels[item] for item in labels]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ @staticmethod
93
+ def prepare_extra_step_kwargs(
94
+ scheduler: KarrasDiffusionSchedulers,
95
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
96
+ eta: float,
97
+ ) -> Dict[str, Any]:
98
+ kwargs: Dict[str, Any] = {}
99
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
100
+ if "eta" in step_params:
101
+ kwargs["eta"] = eta
102
+ if "generator" in step_params:
103
+ kwargs["generator"] = generator
104
+ return kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ @staticmethod
107
+ def _is_ddim_like(step_params: Set[str]) -> bool:
108
+ return "eta" in step_params
 
 
 
 
 
 
 
 
 
 
109
 
110
+ @staticmethod
111
+ def _expand_timestep(timestep, batch: int, device: torch.device) -> torch.Tensor:
112
+ if not torch.is_tensor(timestep):
113
+ timestep = torch.tensor([timestep], dtype=torch.long, device=device)
114
+ elif timestep.ndim == 0:
115
+ timestep = timestep[None].to(device=device)
116
+ return timestep.expand(batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  @torch.no_grad()
119
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
124
  height: Optional[int] = None,
125
  width: Optional[int] = None,
126
  num_inference_steps: int = 250,
127
+ guidance_scale: float = 4.0,
128
+ classifier_guidance_scale: float = 0.0,
129
  eta: float = 0.0,
130
  clip_denoised: bool = True,
 
131
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
132
  latents: Optional[torch.Tensor] = None,
133
  output_type: str = "pil",
134
  return_dict: bool = True,
135
+ ) -> Union[ImagePipelineOutput, Tuple]:
136
  r"""
137
+ Generate samples from the ADM/ADM-G checkpoint.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  Examples:
140
+ <!-- this section is replaced by replace_example_docstring -->
 
 
 
141
  """
142
+ # Stage 1: check inputs
143
+ if isinstance(class_labels, str):
144
+ class_labels = self.get_label_ids(class_labels)[0]
145
+ if isinstance(class_labels, list) and class_labels and isinstance(class_labels[0], str):
146
+ class_labels = self.get_label_ids(class_labels)
147
+
148
+ native_size = int(getattr(self.unet.config, "image_size", 256))
149
+ height = native_size if height is None else int(height)
150
+ width = native_size if width is None else int(width)
151
+
152
+ if height % 8 != 0 or width % 8 != 0:
153
+ raise ValueError(f"height and width must be divisible by 8, got ({height}, {width}).")
154
+ if output_type not in {"pil", "np", "pt", "latent"}:
155
+ raise ValueError(f"Unsupported output_type: {output_type}")
156
+ # This checkpoint does not use classifier-free guidance (CFG).
157
+ # Keep classifier_guidance_scale for compatibility, but treat guidance_scale
158
+ # as the primary classifier-guidance strength.
159
+ effective_classifier_guidance_scale = (
160
+ float(classifier_guidance_scale) if classifier_guidance_scale > 0 else float(guidance_scale)
161
+ )
162
 
163
+ if class_labels is None and (self.unet.config.class_cond or effective_classifier_guidance_scale > 0):
164
+ raise ValueError("class_labels are required for class-conditional sampling and ADM-G guidance.")
165
 
166
+ if isinstance(class_labels, int):
167
+ batch_size = 1
168
+ class_labels = [class_labels]
169
+ elif isinstance(class_labels, list):
170
+ batch_size = len(class_labels)
171
+ elif torch.is_tensor(class_labels):
172
+ batch_size = int(class_labels.shape[0])
173
 
174
+ # Stage 2: define call parameters
175
  device = self._execution_device
176
+ channels = int(getattr(self.unet.config, "in_channels", 3))
177
+ dtype = self.unet.dtype
178
 
179
+ # Stage 3: prepare class conditioning
180
+ class_tensor = None
181
+ class_input = None
182
  if class_labels is not None:
183
+ class_tensor = class_labels if torch.is_tensor(class_labels) else torch.tensor(class_labels, dtype=torch.long)
184
+ class_tensor = class_tensor.to(device=device, dtype=torch.long).reshape(-1)
185
+ if class_tensor.shape[0] != batch_size:
186
+ raise ValueError("class_labels batch must match requested batch_size")
187
+ if self.unet.config.class_cond:
188
+ class_input = class_tensor
189
+
190
+ # Stage 4: prepare timesteps
191
+ scheduler = self.scheduler
192
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
193
+ scheduler.set_timesteps(num_inference_steps, device=device)
194
+
195
+ # Stage 5: prepare latent variables
196
+ shape = (batch_size, channels, height, width)
197
+ if latents is None:
198
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
199
+ else:
200
+ if tuple(latents.shape) != shape:
201
+ raise ValueError(f"Unexpected latents shape {tuple(latents.shape)}; expected {shape}.")
202
+ latents = latents.to(device=device, dtype=dtype)
203
+ latents = latents * scheduler.init_noise_sigma
 
 
 
 
 
204
 
205
+ # Stage 6: prepare extra step kwargs
206
+ extra_step_kwargs = self.prepare_extra_step_kwargs(scheduler, generator, eta)
 
207
 
208
+ # Stage 7: denoising loop
209
+ for timestep in self.progress_bar(scheduler.timesteps):
210
+ model_input = latents
211
+ model_input = scheduler.scale_model_input(model_input, timestep)
212
+ timestep_input = self._expand_timestep(timestep, model_input.shape[0], model_input.device)
213
+ model_output = self.unet(model_input, timestep_input, class_labels=class_input, return_dict=True).sample
214
 
215
  cond_grad = None
216
+ if effective_classifier_guidance_scale > 0:
217
+ if self.classifier is None or class_tensor is None:
218
+ raise ValueError("guidance_scale requires both classifier and class_labels.")
219
+ grad_t = self._expand_timestep(timestep, batch_size, latents.device)
220
+ cond_grad = self.classifier.guidance_gradient(
221
+ latents, grad_t, class_tensor, classifier_scale=effective_classifier_guidance_scale
222
  )
223
 
224
+ step_model_output = model_output
225
+ if cond_grad is not None:
226
+ if self._is_ddim_like(step_params):
227
+ eps = model_output[:, :channels] if model_output.shape[1] == 2 * channels else model_output
228
+ alpha_bar_t = scheduler.alphas_cumprod[timestep].to(device=latents.device, dtype=latents.dtype)
229
+ step_model_output = eps - (1 - alpha_bar_t).sqrt() * cond_grad
230
+ elif hasattr(scheduler, "_get_variance"):
231
+ pred_var = None
232
+ if model_output.shape[1] == 2 * channels:
233
+ _, pred_var = torch.split(model_output, channels, dim=1)
234
+ variance = scheduler._get_variance(int(timestep), predicted_variance=pred_var)
235
+ if scheduler.config.variance_type == "learned_range":
236
+ variance = torch.exp(variance)
237
+ latents = latents + variance * cond_grad
238
+ else:
239
+ raise ValueError(
240
+ "guidance_scale is not supported for the current scheduler. "
241
+ "Use a DDPM/DDIM-compatible scheduler or disable classifier guidance."
242
+ )
243
+
244
+ latents = scheduler.step(step_model_output, timestep, latents, return_dict=True, **extra_step_kwargs).prev_sample
245
+
246
+ image = latents if output_type == "latent" else (latents / 2 + 0.5).clamp(0, 1)
247
+ if output_type in {"pil", "np"}:
248
  image = self.image_processor.postprocess(image, output_type=output_type)
249
 
250
  self.maybe_free_model_hooks()
 
251
  if not return_dict:
252
+ return (image,)
253
+ return ImagePipelineOutput(images=image)
 
ADM-G-512/scheduler/scheduler_config.json CHANGED
@@ -1,11 +1,12 @@
1
  {
2
- "_class_name": "ADMScheduler",
3
  "_diffusers_version": "0.36.0",
4
- "learn_sigma": true,
5
- "noise_schedule": "linear",
6
- "predict_xstart": false,
7
- "rescale_timesteps": false,
8
- "sigma_small": false,
9
- "steps": 1000,
10
- "timestep_respacing": ""
 
11
  }
 
1
  {
2
+ "_class_name": "DDPMScheduler",
3
  "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "beta_start": 0.0001,
6
+ "beta_end": 0.02,
7
+ "beta_schedule": "linear",
8
+ "prediction_type": "epsilon",
9
+ "variance_type": "learned_range",
10
+ "clip_sample": true,
11
+ "timestep_spacing": "leading"
12
  }
ADM-G-512/unet/modeling_adm.py CHANGED
@@ -37,7 +37,10 @@ def avg_pool_nd(dims: int, *args, **kwargs):
37
 
38
  class GroupNorm32(nn.GroupNorm):
39
  def forward(self, x):
40
- return super().forward(x.float()).type(x.dtype)
 
 
 
41
 
42
 
43
  def normalization(channels: int):
@@ -475,19 +478,20 @@ class EncoderUNetModel(nn.Module):
475
  self.middle_block.apply(convert_module_to_f32)
476
 
477
  def forward(self, x, timesteps):
478
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
479
  results = []
480
- h = x.type(self.dtype)
481
  for module in self.input_blocks:
482
  h = module(h, emb)
483
  if self.pool.startswith("spatial"):
484
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
485
  h = self.middle_block(h, emb)
486
  if self.pool.startswith("spatial"):
487
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
488
  h = torch.cat(results, dim=-1)
489
  return self.out(h)
490
- h = h.type(x.dtype)
491
  return self.out(h)
492
 
493
 
@@ -673,12 +677,13 @@ class UNetModel(nn.Module):
673
  def forward(self, x, timesteps, y: Optional[torch.Tensor] = None):
674
  assert (y is not None) == (self.num_classes is not None)
675
  hs = []
676
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
677
  if self.num_classes is not None:
678
  assert y.shape == (x.shape[0],)
679
  emb = emb + self.label_emb(y)
680
 
681
- h = x.type(self.dtype)
682
  for module in self.input_blocks:
683
  h = module(h, emb)
684
  hs.append(h)
@@ -686,7 +691,7 @@ class UNetModel(nn.Module):
686
  for module in self.output_blocks:
687
  h = torch.cat([h, hs.pop()], dim=1)
688
  h = module(h, emb)
689
- h = h.type(x.dtype)
690
  return self.out(h)
691
 
692
 
 
37
 
38
  class GroupNorm32(nn.GroupNorm):
39
  def forward(self, x):
40
+ weight = self.weight.float() if self.weight is not None else None
41
+ bias = self.bias.float() if self.bias is not None else None
42
+ y = F.group_norm(x.float(), self.num_groups, weight, bias, self.eps)
43
+ return y.to(dtype=x.dtype)
44
 
45
 
46
  def normalization(channels: int):
 
478
  self.middle_block.apply(convert_module_to_f32)
479
 
480
  def forward(self, x, timesteps):
481
+ emb = timestep_embedding(timesteps, self.model_channels).to(dtype=self.time_embed[0].weight.dtype)
482
+ emb = self.time_embed(emb)
483
  results = []
484
+ h = x.to(dtype=self.time_embed[0].weight.dtype)
485
  for module in self.input_blocks:
486
  h = module(h, emb)
487
  if self.pool.startswith("spatial"):
488
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
489
  h = self.middle_block(h, emb)
490
  if self.pool.startswith("spatial"):
491
+ results.append(h.to(dtype=self.time_embed[0].weight.dtype).mean(dim=(2, 3)))
492
  h = torch.cat(results, dim=-1)
493
  return self.out(h)
494
+ h = h.to(dtype=self.time_embed[0].weight.dtype)
495
  return self.out(h)
496
 
497
 
 
677
  def forward(self, x, timesteps, y: Optional[torch.Tensor] = None):
678
  assert (y is not None) == (self.num_classes is not None)
679
  hs = []
680
+ emb = timestep_embedding(timesteps, self.model_channels).to(dtype=self.time_embed[0].weight.dtype)
681
+ emb = self.time_embed(emb)
682
  if self.num_classes is not None:
683
  assert y.shape == (x.shape[0],)
684
  emb = emb + self.label_emb(y)
685
 
686
+ h = x.to(dtype=self.time_embed[0].weight.dtype)
687
  for module in self.input_blocks:
688
  h = module(h, emb)
689
  hs.append(h)
 
691
  for module in self.output_blocks:
692
  h = torch.cat([h, hs.pop()], dim=1)
693
  h = module(h, emb)
694
+ h = h.to(dtype=self.time_embed[0].weight.dtype)
695
  return self.out(h)
696
 
697
 
ADM-G-512/unet/unet_adm.py CHANGED
@@ -12,7 +12,12 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
  from diffusers.models.modeling_utils import ModelMixin
13
  from diffusers.utils import BaseOutput
14
 
15
- from modeling_adm import create_adm_unet_model
 
 
 
 
 
16
 
17
 
18
  @dataclass
 
12
  from diffusers.models.modeling_utils import ModelMixin
13
  from diffusers.utils import BaseOutput
14
 
15
+ try:
16
+ from .modeling_adm import create_adm_unet_model
17
+ except ImportError:
18
+ import importlib
19
+
20
+ create_adm_unet_model = importlib.import_module("modeling_adm").create_adm_unet_model
21
 
22
 
23
  @dataclass
README.md CHANGED
@@ -28,7 +28,7 @@ This Hugging Face repo hosts **multiple self-contained checkpoints as subfolders
28
 
29
  ## Available checkpoints
30
 
31
- | Subfolder | Resolution | Classifier scale | OpenAI sources |
32
  | --- | --- | ---: | --- |
33
  | [`ADM-G-256/`](ADM-G-256/) | 256×256 | 1.0 | `256x256_diffusion.pt` + `256x256_classifier.pt` |
34
  | [`ADM-G-512/`](ADM-G-512/) | 512×512 | 4.0 | `512x512_diffusion.pt` + `512x512_classifier.pt` |
@@ -50,52 +50,33 @@ Chinese labels are still preserved in the main source repo under `src/labels/id2
50
 
51
  ![ADM-G-512 demo](ADM-G-512/demo.png)
52
 
53
- ## Load from Hugging Face
54
 
55
  ```python
56
- import sys
57
  from pathlib import Path
58
  import torch
59
- from huggingface_hub import snapshot_download
60
-
61
- repo_dir = Path(snapshot_download("BiliSakura/ADM-diffusers"))
62
- variant = "ADM-G-512" # or "ADM-G-256"
63
-
64
- sys.path.insert(0, str(repo_dir / variant))
65
- from pipeline import ADMPipeline
66
-
67
- pipe = ADMPipeline.from_pretrained(".")
68
- pipe.to("cuda")
69
- pipe.unet.float()
70
- pipe.classifier.float()
71
- pipe.classifier.model.dtype = torch.float32
72
-
73
- images = pipe(
74
- class_labels=207,
75
- num_inference_steps=250,
76
- classifier_guidance_scale=4.0 if variant == "ADM-G-512" else 1.0,
77
- ).images
78
-
79
- # Human-readable ImageNet labels (English)
80
- print(pipe.id2label[207]) # "golden retriever"
81
- pipe.get_label_ids("golden retriever") # [207]
82
- images = pipe(class_labels="golden retriever", classifier_guidance_scale=1.0).images
83
- ```
84
-
85
- ## Load from a local clone
86
-
87
- ```python
88
- import sys
89
- from pathlib import Path
90
-
91
- repo = Path("BiliSakura/ADM-diffusers").resolve()
92
- variant = "ADM-G-256"
93
-
94
- sys.path.insert(0, str(repo / variant))
95
- from pipeline import ADMPipeline
96
 
97
- pipe = ADMPipeline.from_pretrained(".")
98
- pipe.to("cuda")
99
  ```
100
 
101
  ## Repo layout
 
28
 
29
  ## Available checkpoints
30
 
31
+ | Subfolder | Resolution | Guidance scale | OpenAI sources |
32
  | --- | --- | ---: | --- |
33
  | [`ADM-G-256/`](ADM-G-256/) | 256×256 | 1.0 | `256x256_diffusion.pt` + `256x256_classifier.pt` |
34
  | [`ADM-G-512/`](ADM-G-512/) | 512×512 | 4.0 | `512x512_diffusion.pt` + `512x512_classifier.pt` |
 
50
 
51
  ![ADM-G-512 demo](ADM-G-512/demo.png)
52
 
53
+ Settings used for this demo image: `ADM-G-512`, `DDIMScheduler`, `num_inference_steps=50`, `guidance_scale=4.0`, `seed=42`, class `"golden retriever"`.
54
 
55
  ```python
 
56
  from pathlib import Path
57
  import torch
58
+ from diffusers import DDIMScheduler, DiffusionPipeline
59
+
60
+ model_dir = Path("./BiliSakura/ADM-diffusers/ADM-G-512")
61
+ pipe = DiffusionPipeline.from_pretrained(
62
+ str(model_dir),
63
+ local_files_only=True,
64
+ custom_pipeline=str(model_dir / "pipeline.py"),
65
+ torch_dtype=torch.bfloat16,
66
+ )
67
+ pipe = pipe.to("cuda")
68
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
69
+ class_id = pipe.get_label_ids("golden retriever")[0]
70
+ generator = torch.Generator(device="cuda").manual_seed(42)
71
+
72
+ out = pipe(
73
+ class_labels=class_id,
74
+ guidance_scale=4.0,
75
+ num_inference_steps=50,
76
+ generator=generator,
77
+ ).images[0]
78
+ out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
 
 
80
  ```
81
 
82
  ## Repo layout