BiliSakura committed on
Commit
9dc3cb9
·
verified ·
1 Parent(s): 26e1caf

Upload folder using huggingface_hub

Browse files
Files changed (45) hide show
  1. .gitattributes +3 -0
  2. DeCo-XL-16-256/decoder/__pycache__/decoder_deco.cpython-312.pyc +0 -0
  3. DeCo-XL-16-256/decoder/config.json +9 -0
  4. DeCo-XL-16-256/decoder/decoder_deco.py +163 -0
  5. DeCo-XL-16-256/decoder/diffusion_pytorch_model.safetensors +3 -0
  6. DeCo-XL-16-256/model_index.json +1021 -0
  7. DeCo-XL-16-256/pipeline.py +268 -0
  8. DeCo-XL-16-256/scheduler/scheduler_config.json +8 -0
  9. DeCo-XL-16-256/scheduler/scheduling_deco_flow_match_euler_discrete.py +82 -0
  10. DeCo-XL-16-256/transformer/__pycache__/transformer_deco.cpython-312.pyc +0 -0
  11. DeCo-XL-16-256/transformer/config.json +22 -0
  12. DeCo-XL-16-256/transformer/diffusion_pytorch_model.safetensors +3 -0
  13. DeCo-XL-16-256/transformer/transformer_deco.py +332 -0
  14. DeCo-XL-16-512/decoder/__pycache__/decoder_deco.cpython-312.pyc +0 -0
  15. DeCo-XL-16-512/decoder/config.json +8 -0
  16. DeCo-XL-16-512/decoder/decoder_deco.py +163 -0
  17. DeCo-XL-16-512/decoder/diffusion_pytorch_model.safetensors +3 -0
  18. DeCo-XL-16-512/decoder/diffusion_pytorch_model.safetensors.bak +3 -0
  19. DeCo-XL-16-512/demo.png +3 -0
  20. DeCo-XL-16-512/model_index.json +1021 -0
  21. DeCo-XL-16-512/pipeline.py +268 -0
  22. DeCo-XL-16-512/scheduler/scheduler_config.json +8 -0
  23. DeCo-XL-16-512/scheduler/scheduling_deco_flow_match_euler_discrete.py +82 -0
  24. DeCo-XL-16-512/transformer/__pycache__/transformer_deco.cpython-312.pyc +0 -0
  25. DeCo-XL-16-512/transformer/config.json +21 -0
  26. DeCo-XL-16-512/transformer/diffusion_pytorch_model.safetensors +3 -0
  27. DeCo-XL-16-512/transformer/diffusion_pytorch_model.safetensors.bak +3 -0
  28. DeCo-XL-16-512/transformer/transformer_deco.py +332 -0
  29. DeCo-XXL-16-512-t2i/decoder/__pycache__/decoder_deco.cpython-312.pyc +0 -0
  30. DeCo-XXL-16-512-t2i/decoder/config.json +8 -0
  31. DeCo-XXL-16-512-t2i/decoder/decoder_deco.py +177 -0
  32. DeCo-XXL-16-512-t2i/decoder/diffusion_pytorch_model.safetensors +3 -0
  33. DeCo-XXL-16-512-t2i/model_index.json +27 -0
  34. DeCo-XXL-16-512-t2i/pipeline.py +291 -0
  35. DeCo-XXL-16-512-t2i/scheduler/scheduler_config.json +13 -0
  36. DeCo-XXL-16-512-t2i/scheduler/scheduling_deco_flow_match_adam_discrete.py +200 -0
  37. DeCo-XXL-16-512-t2i/scheduler/scheduling_deco_flow_match_euler_discrete.py +82 -0
  38. DeCo-XXL-16-512-t2i/scripts/run_t2i_demo.py +47 -0
  39. DeCo-XXL-16-512-t2i/scripts/test_t2i_load.py +53 -0
  40. DeCo-XXL-16-512-t2i/transformer/__pycache__/transformer_deco_t2i.cpython-312.pyc +0 -0
  41. DeCo-XXL-16-512-t2i/transformer/config.json +21 -0
  42. DeCo-XXL-16-512-t2i/transformer/diffusion_pytorch_model.safetensors +3 -0
  43. DeCo-XXL-16-512-t2i/transformer/transformer_deco_t2i.py +411 -0
  44. README.md +109 -0
  45. t2i_DeCo.ckpt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ DeCo-XL-16-512/decoder/diffusion_pytorch_model.safetensors.bak filter=lfs diff=lfs merge=lfs -text
37
+ DeCo-XL-16-512/demo.png filter=lfs diff=lfs merge=lfs -text
38
+ DeCo-XL-16-512/transformer/diffusion_pytorch_model.safetensors.bak filter=lfs diff=lfs merge=lfs -text
DeCo-XL-16-256/decoder/__pycache__/decoder_deco.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
DeCo-XL-16-256/decoder/config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DeCoPatchDecoderModel",
3
+ "hidden_size_x": 32,
4
+ "in_channels": 3,
5
+ "max_freqs": 8,
6
+ "num_res_blocks": 3,
7
+ "patch_size": 16,
8
+ "z_channels": 1152
9
+ }
DeCo-XL-16-256/decoder/decoder_deco.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.checkpoint import checkpoint
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
+
18
+
19
+ def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
20
+ return x * (1 + scale) + shift
21
+
22
+
23
class NerfEmbedder(nn.Module):
    """Embed patch pixels concatenated with a fixed 2-D cosine positional basis.

    Each of the ``patch_size ** 2`` positions in a patch gets ``max_freqs ** 2``
    cosine features, which are appended to the per-position input channels and
    projected by a single linear layer to ``hidden_size_input`` features.

    Args:
        in_channels: Number of channels per input position.
        hidden_size_input: Output feature size of the linear projection.
        max_freqs: Number of frequencies per axis; the positional basis has
            ``max_freqs ** 2`` features.
    """

    def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
        super().__init__()
        self.max_freqs = max_freqs
        self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True))
        # Per-instance cache of positional bases. A class-level ``lru_cache`` on the
        # method would key on ``self`` and keep every instance (and its cached
        # tensors) alive for the lifetime of the class.
        self._pos_cache = {}

    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the cosine positional basis of shape ``(1, patch_size ** 2, max_freqs ** 2)``.

        The result is cached per instance for each combination of ``patch_size``,
        ``device``, ``dtype`` and the current ``max_freqs``; repeated calls return
        the same tensor object.
        """
        key = (patch_size, device, dtype, self.max_freqs)
        cached = self._pos_cache.get(key)
        if cached is not None:
            return cached

        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        # NOTE: the frequency grid is kept exactly as in the original code;
        # changing it would alter the features seen by already-trained weights.
        freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device)
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        # Higher joint frequencies are attenuated by 1 / (1 + fx * fy).
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct = (
            torch.cos(pos_x.reshape(-1, 1, 1) * freqs_x * torch.pi)
            * torch.cos(pos_y.reshape(-1, 1, 1) * freqs_y * torch.pi)
            * coeffs
        ).view(1, -1, self.max_freqs**2)

        self._pos_cache[key] = dct
        return dct

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Project ``inputs`` plus positional features to ``hidden_size_input`` channels.

        Args:
            inputs: Tensor of shape ``(batch, patch_tokens, in_channels)`` where
                ``patch_tokens`` must be a perfect square.

        Returns:
            Tensor of shape ``(batch, patch_tokens, hidden_size_input)``.

        Raises:
            ValueError: If ``patch_tokens`` is not a perfect square.
        """
        batch_size, patch_tokens, _ = inputs.shape
        # round() instead of int() so a float square root that lands just below
        # the exact integer does not get truncated to the wrong patch size.
        patch_size = round(patch_tokens**0.5)
        if patch_size * patch_size != patch_tokens:
            raise ValueError(
                f"Expected a square number of patch tokens, got {patch_tokens}."
            )
        # expand() broadcasts the cached basis over the batch without copying it.
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).expand(batch_size, -1, -1)
        return self.embedder(torch.cat([inputs, dct], dim=-1))
50
+
51
+
52
class ResBlock(nn.Module):
    """Residual MLP block whose normalised input is modulated by a conditioning tensor.

    The conditioning tensor ``y`` is mapped to three chunks (shift, scale, gate)
    along the last dimension. The layer-normalised input is scaled and shifted,
    passed through a two-layer SiLU MLP, multiplied by the gate, and added back
    to the input.

    Args:
        channels: Feature size of both the input and the conditioning tensor.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        mlp_layers = [
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the gated MLP output computed from the modulated, normalised ``x``."""
        shift, scale, gate = torch.chunk(self.adaLN_modulation(y), 3, dim=-1)
        modulated = self.in_ln(x) * (1 + scale) + shift
        residual = gate * self.mlp(modulated)
        return x + residual
66
+
67
+
68
class DecoderFinalLayer(nn.Module):
    """Output head: a parameter-free LayerNorm followed by a linear projection.

    Args:
        model_channels: Feature size of the incoming tensor.
        out_channels: Feature size produced by the projection.
    """

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # elementwise_affine=False: the norm carries no learnable scale or bias.
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and project it to ``out_channels``."""
        normalized = self.norm_final(x)
        return self.linear(normalized)
76
+
77
+
78
class SimpleMLPAdaLN(nn.Module):
    """Stack of conditioned residual MLP blocks followed by a final projection.

    A conditioning vector is expanded by one linear layer into a separate
    ``model_channels``-sized vector for each of the ``patch_size ** 2`` positions
    and fed to every residual block together with the projected input.

    Args:
        in_channels: Feature size of the input tensor ``x``.
        model_channels: Hidden feature size used by the residual blocks.
        out_channels: Feature size of the output.
        z_channels: Feature size of the conditioning vector ``c``.
        num_res_blocks: Number of ``ResBlock`` layers.
        patch_size: Side length of a patch; ``x`` is expected to carry
            ``patch_size ** 2`` positions.
        grad_checkpointing: If true, run each residual block under activation
            checkpointing (skipped while TorchScript-scripting).
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        patch_size: int,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.grad_checkpointing = grad_checkpointing
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = DecoderFinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero the modulation and output projections.

        With these zeros every block's gate is 0 and the final projection
        outputs 0, so a freshly constructed network returns all zeros.
        """
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Run the conditioned MLP stack.

        Args:
            x: Tensor of shape ``(batch, patch_size ** 2, in_channels)``.
            c: Conditioning tensor of shape ``(batch, z_channels)``.

        Returns:
            Tensor of shape ``(batch, patch_size ** 2, out_channels)``.
        """
        x = self.input_proj(x)
        # One conditioning vector per position inside the patch.
        y = self.cond_embed(c).reshape(c.shape[0], self.patch_size**2, -1)
        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        for block in self.res_blocks:
            if use_checkpoint:
                # use_reentrant is passed explicitly: omitting it triggers a
                # PyTorch warning and is announced to become an error.
                x = checkpoint(block, x, y, use_reentrant=False)
            else:
                x = block(x, y)
        return self.final_layer(x)
114
+
115
+
116
@dataclass
class DeCoPatchDecoderOutput(BaseOutput):
    """Output container returned by `DeCoPatchDecoderModel.forward` when `return_dict=True`.

    Args:
        sample (`torch.Tensor`):
            The tensor produced by the decoder network for the given patches.
    """

    sample: torch.Tensor
119
+
120
+
121
class DeCoPatchDecoderModel(ModelMixin, ConfigMixin):
    """Per-patch RGB decoder for DeCo (NerfEmbedder + AdaLN MLP).

    Patch pixels are first embedded together with positional features by
    `NerfEmbedder`, then decoded by a `SimpleMLPAdaLN` conditioned on a
    per-patch vector.

    Args:
        in_channels (`int`, defaults to 3):
            Channels per pixel; also the number of output channels.
        hidden_size_x (`int`, defaults to 32):
            Width of the pixel embedding and of the MLP blocks.
        z_channels (`int`, defaults to 1152):
            Feature size of the per-patch conditioning vector.
        num_res_blocks (`int`, defaults to 3):
            Number of residual blocks in the MLP.
        patch_size (`int`, defaults to 16):
            Side length of a patch in pixels.
        max_freqs (`int`, defaults to 8):
            Number of positional frequencies per axis.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        hidden_size_x: int = 32,
        z_channels: int = 1152,
        num_res_blocks: int = 3,
        patch_size: int = 16,
        max_freqs: int = 8,
    ):
        super().__init__()
        self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=max_freqs)
        # The MLP works at the embedding width and maps back to pixel channels.
        mlp_kwargs = dict(
            in_channels=hidden_size_x,
            model_channels=hidden_size_x,
            out_channels=in_channels,
            z_channels=z_channels,
            num_res_blocks=num_res_blocks,
            patch_size=patch_size,
        )
        self.dec_net = SimpleMLPAdaLN(**mlp_kwargs)

    def forward(
        self,
        patch_pixels: torch.Tensor,
        conditioning: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DeCoPatchDecoderOutput, tuple[torch.Tensor]]:
        """
        Decode a batch of flattened patches.

        Args:
            patch_pixels (`torch.Tensor`):
                Flattened patch pixels of shape `(batch * num_patches, patch_size ** 2, in_channels)`.
            conditioning (`torch.Tensor`):
                Per-patch conditioning of shape `(batch * num_patches, z_channels)`.
            return_dict (`bool`, defaults to `True`):
                Whether to wrap the result in a `DeCoPatchDecoderOutput` instead of a plain tuple.

        Returns:
            `DeCoPatchDecoderOutput` if `return_dict` is true, otherwise a one-element tuple
            holding the output tensor.
        """
        embedded = self.x_embedder(patch_pixels)
        output = self.dec_net(embedded, conditioning)
        if return_dict:
            return DeCoPatchDecoderOutput(sample=output)
        return (output,)
DeCo-XL-16-256/decoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2de852e7fd141788fe391192901098dd2c5f5196e7a6e988391a1b7be002e5f6
3
+ size 37862236
DeCo-XL-16-256/model_index.json ADDED
@@ -0,0 +1,1021 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "DeCoPipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "decoder": [
8
+ "decoder_deco",
9
+ "DeCoPatchDecoderModel"
10
+ ],
11
+ "id2label": {
12
+ "0": "tench, Tinca tinca",
13
+ "1": "goldfish, Carassius auratus",
14
+ "10": "brambling, Fringilla montifringilla",
15
+ "100": "black swan, Cygnus atratus",
16
+ "101": "tusker",
17
+ "102": "echidna, spiny anteater, anteater",
18
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
19
+ "104": "wallaby, brush kangaroo",
20
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
21
+ "106": "wombat",
22
+ "107": "jellyfish",
23
+ "108": "sea anemone, anemone",
24
+ "109": "brain coral",
25
+ "11": "goldfinch, Carduelis carduelis",
26
+ "110": "flatworm, platyhelminth",
27
+ "111": "nematode, nematode worm, roundworm",
28
+ "112": "conch",
29
+ "113": "snail",
30
+ "114": "slug",
31
+ "115": "sea slug, nudibranch",
32
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
33
+ "117": "chambered nautilus, pearly nautilus, nautilus",
34
+ "118": "Dungeness crab, Cancer magister",
35
+ "119": "rock crab, Cancer irroratus",
36
+ "12": "house finch, linnet, Carpodacus mexicanus",
37
+ "120": "fiddler crab",
38
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
39
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
40
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
41
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
42
+ "125": "hermit crab",
43
+ "126": "isopod",
44
+ "127": "white stork, Ciconia ciconia",
45
+ "128": "black stork, Ciconia nigra",
46
+ "129": "spoonbill",
47
+ "13": "junco, snowbird",
48
+ "130": "flamingo",
49
+ "131": "little blue heron, Egretta caerulea",
50
+ "132": "American egret, great white heron, Egretta albus",
51
+ "133": "bittern",
52
+ "134": "crane",
53
+ "135": "limpkin, Aramus pictus",
54
+ "136": "European gallinule, Porphyrio porphyrio",
55
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
56
+ "138": "bustard",
57
+ "139": "ruddy turnstone, Arenaria interpres",
58
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
59
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
60
+ "141": "redshank, Tringa totanus",
61
+ "142": "dowitcher",
62
+ "143": "oystercatcher, oyster catcher",
63
+ "144": "pelican",
64
+ "145": "king penguin, Aptenodytes patagonica",
65
+ "146": "albatross, mollymawk",
66
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
67
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
68
+ "149": "dugong, Dugong dugon",
69
+ "15": "robin, American robin, Turdus migratorius",
70
+ "150": "sea lion",
71
+ "151": "Chihuahua",
72
+ "152": "Japanese spaniel",
73
+ "153": "Maltese dog, Maltese terrier, Maltese",
74
+ "154": "Pekinese, Pekingese, Peke",
75
+ "155": "Shih-Tzu",
76
+ "156": "Blenheim spaniel",
77
+ "157": "papillon",
78
+ "158": "toy terrier",
79
+ "159": "Rhodesian ridgeback",
80
+ "16": "bulbul",
81
+ "160": "Afghan hound, Afghan",
82
+ "161": "basset, basset hound",
83
+ "162": "beagle",
84
+ "163": "bloodhound, sleuthhound",
85
+ "164": "bluetick",
86
+ "165": "black-and-tan coonhound",
87
+ "166": "Walker hound, Walker foxhound",
88
+ "167": "English foxhound",
89
+ "168": "redbone",
90
+ "169": "borzoi, Russian wolfhound",
91
+ "17": "jay",
92
+ "170": "Irish wolfhound",
93
+ "171": "Italian greyhound",
94
+ "172": "whippet",
95
+ "173": "Ibizan hound, Ibizan Podenco",
96
+ "174": "Norwegian elkhound, elkhound",
97
+ "175": "otterhound, otter hound",
98
+ "176": "Saluki, gazelle hound",
99
+ "177": "Scottish deerhound, deerhound",
100
+ "178": "Weimaraner",
101
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
102
+ "18": "magpie",
103
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
104
+ "181": "Bedlington terrier",
105
+ "182": "Border terrier",
106
+ "183": "Kerry blue terrier",
107
+ "184": "Irish terrier",
108
+ "185": "Norfolk terrier",
109
+ "186": "Norwich terrier",
110
+ "187": "Yorkshire terrier",
111
+ "188": "wire-haired fox terrier",
112
+ "189": "Lakeland terrier",
113
+ "19": "chickadee",
114
+ "190": "Sealyham terrier, Sealyham",
115
+ "191": "Airedale, Airedale terrier",
116
+ "192": "cairn, cairn terrier",
117
+ "193": "Australian terrier",
118
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
119
+ "195": "Boston bull, Boston terrier",
120
+ "196": "miniature schnauzer",
121
+ "197": "giant schnauzer",
122
+ "198": "standard schnauzer",
123
+ "199": "Scotch terrier, Scottish terrier, Scottie",
124
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
125
+ "20": "water ouzel, dipper",
126
+ "200": "Tibetan terrier, chrysanthemum dog",
127
+ "201": "silky terrier, Sydney silky",
128
+ "202": "soft-coated wheaten terrier",
129
+ "203": "West Highland white terrier",
130
+ "204": "Lhasa, Lhasa apso",
131
+ "205": "flat-coated retriever",
132
+ "206": "curly-coated retriever",
133
+ "207": "golden retriever",
134
+ "208": "Labrador retriever",
135
+ "209": "Chesapeake Bay retriever",
136
+ "21": "kite",
137
+ "210": "German short-haired pointer",
138
+ "211": "vizsla, Hungarian pointer",
139
+ "212": "English setter",
140
+ "213": "Irish setter, red setter",
141
+ "214": "Gordon setter",
142
+ "215": "Brittany spaniel",
143
+ "216": "clumber, clumber spaniel",
144
+ "217": "English springer, English springer spaniel",
145
+ "218": "Welsh springer spaniel",
146
+ "219": "cocker spaniel, English cocker spaniel, cocker",
147
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
148
+ "220": "Sussex spaniel",
149
+ "221": "Irish water spaniel",
150
+ "222": "kuvasz",
151
+ "223": "schipperke",
152
+ "224": "groenendael",
153
+ "225": "malinois",
154
+ "226": "briard",
155
+ "227": "kelpie",
156
+ "228": "komondor",
157
+ "229": "Old English sheepdog, bobtail",
158
+ "23": "vulture",
159
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
160
+ "231": "collie",
161
+ "232": "Border collie",
162
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
163
+ "234": "Rottweiler",
164
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
165
+ "236": "Doberman, Doberman pinscher",
166
+ "237": "miniature pinscher",
167
+ "238": "Greater Swiss Mountain dog",
168
+ "239": "Bernese mountain dog",
169
+ "24": "great grey owl, great gray owl, Strix nebulosa",
170
+ "240": "Appenzeller",
171
+ "241": "EntleBucher",
172
+ "242": "boxer",
173
+ "243": "bull mastiff",
174
+ "244": "Tibetan mastiff",
175
+ "245": "French bulldog",
176
+ "246": "Great Dane",
177
+ "247": "Saint Bernard, St Bernard",
178
+ "248": "Eskimo dog, husky",
179
+ "249": "malamute, malemute, Alaskan malamute",
180
+ "25": "European fire salamander, Salamandra salamandra",
181
+ "250": "Siberian husky",
182
+ "251": "dalmatian, coach dog, carriage dog",
183
+ "252": "affenpinscher, monkey pinscher, monkey dog",
184
+ "253": "basenji",
185
+ "254": "pug, pug-dog",
186
+ "255": "Leonberg",
187
+ "256": "Newfoundland, Newfoundland dog",
188
+ "257": "Great Pyrenees",
189
+ "258": "Samoyed, Samoyede",
190
+ "259": "Pomeranian",
191
+ "26": "common newt, Triturus vulgaris",
192
+ "260": "chow, chow chow",
193
+ "261": "keeshond",
194
+ "262": "Brabancon griffon",
195
+ "263": "Pembroke, Pembroke Welsh corgi",
196
+ "264": "Cardigan, Cardigan Welsh corgi",
197
+ "265": "toy poodle",
198
+ "266": "miniature poodle",
199
+ "267": "standard poodle",
200
+ "268": "Mexican hairless",
201
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
202
+ "27": "eft",
203
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
204
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
205
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
206
+ "273": "dingo, warrigal, warragal, Canis dingo",
207
+ "274": "dhole, Cuon alpinus",
208
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
209
+ "276": "hyena, hyaena",
210
+ "277": "red fox, Vulpes vulpes",
211
+ "278": "kit fox, Vulpes macrotis",
212
+ "279": "Arctic fox, white fox, Alopex lagopus",
213
+ "28": "spotted salamander, Ambystoma maculatum",
214
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
215
+ "281": "tabby, tabby cat",
216
+ "282": "tiger cat",
217
+ "283": "Persian cat",
218
+ "284": "Siamese cat, Siamese",
219
+ "285": "Egyptian cat",
220
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
221
+ "287": "lynx, catamount",
222
+ "288": "leopard, Panthera pardus",
223
+ "289": "snow leopard, ounce, Panthera uncia",
224
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
225
+ "290": "jaguar, panther, Panthera onca, Felis onca",
226
+ "291": "lion, king of beasts, Panthera leo",
227
+ "292": "tiger, Panthera tigris",
228
+ "293": "cheetah, chetah, Acinonyx jubatus",
229
+ "294": "brown bear, bruin, Ursus arctos",
230
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
231
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
232
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
233
+ "298": "mongoose",
234
+ "299": "meerkat, mierkat",
235
+ "3": "tiger shark, Galeocerdo cuvieri",
236
+ "30": "bullfrog, Rana catesbeiana",
237
+ "300": "tiger beetle",
238
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
239
+ "302": "ground beetle, carabid beetle",
240
+ "303": "long-horned beetle, longicorn, longicorn beetle",
241
+ "304": "leaf beetle, chrysomelid",
242
+ "305": "dung beetle",
243
+ "306": "rhinoceros beetle",
244
+ "307": "weevil",
245
+ "308": "fly",
246
+ "309": "bee",
247
+ "31": "tree frog, tree-frog",
248
+ "310": "ant, emmet, pismire",
249
+ "311": "grasshopper, hopper",
250
+ "312": "cricket",
251
+ "313": "walking stick, walkingstick, stick insect",
252
+ "314": "cockroach, roach",
253
+ "315": "mantis, mantid",
254
+ "316": "cicada, cicala",
255
+ "317": "leafhopper",
256
+ "318": "lacewing, lacewing fly",
257
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
258
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
259
+ "320": "damselfly",
260
+ "321": "admiral",
261
+ "322": "ringlet, ringlet butterfly",
262
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
263
+ "324": "cabbage butterfly",
264
+ "325": "sulphur butterfly, sulfur butterfly",
265
+ "326": "lycaenid, lycaenid butterfly",
266
+ "327": "starfish, sea star",
267
+ "328": "sea urchin",
268
+ "329": "sea cucumber, holothurian",
269
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
270
+ "330": "wood rabbit, cottontail, cottontail rabbit",
271
+ "331": "hare",
272
+ "332": "Angora, Angora rabbit",
273
+ "333": "hamster",
274
+ "334": "porcupine, hedgehog",
275
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
276
+ "336": "marmot",
277
+ "337": "beaver",
278
+ "338": "guinea pig, Cavia cobaya",
279
+ "339": "sorrel",
280
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
281
+ "340": "zebra",
282
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
283
+ "342": "wild boar, boar, Sus scrofa",
284
+ "343": "warthog",
285
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
286
+ "345": "ox",
287
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
288
+ "347": "bison",
289
+ "348": "ram, tup",
290
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
291
+ "35": "mud turtle",
292
+ "350": "ibex, Capra ibex",
293
+ "351": "hartebeest",
294
+ "352": "impala, Aepyceros melampus",
295
+ "353": "gazelle",
296
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
297
+ "355": "llama",
298
+ "356": "weasel",
299
+ "357": "mink",
300
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
301
+ "359": "black-footed ferret, ferret, Mustela nigripes",
302
+ "36": "terrapin",
303
+ "360": "otter",
304
+ "361": "skunk, polecat, wood pussy",
305
+ "362": "badger",
306
+ "363": "armadillo",
307
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
308
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
309
+ "366": "gorilla, Gorilla gorilla",
310
+ "367": "chimpanzee, chimp, Pan troglodytes",
311
+ "368": "gibbon, Hylobates lar",
312
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
313
+ "37": "box turtle, box tortoise",
314
+ "370": "guenon, guenon monkey",
315
+ "371": "patas, hussar monkey, Erythrocebus patas",
316
+ "372": "baboon",
317
+ "373": "macaque",
318
+ "374": "langur",
319
+ "375": "colobus, colobus monkey",
320
+ "376": "proboscis monkey, Nasalis larvatus",
321
+ "377": "marmoset",
322
+ "378": "capuchin, ringtail, Cebus capucinus",
323
+ "379": "howler monkey, howler",
324
+ "38": "banded gecko",
325
+ "380": "titi, titi monkey",
326
+ "381": "spider monkey, Ateles geoffroyi",
327
+ "382": "squirrel monkey, Saimiri sciureus",
328
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
329
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
330
+ "385": "Indian elephant, Elephas maximus",
331
+ "386": "African elephant, Loxodonta africana",
332
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
333
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
334
+ "389": "barracouta, snoek",
335
+ "39": "common iguana, iguana, Iguana iguana",
336
+ "390": "eel",
337
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
338
+ "392": "rock beauty, Holocanthus tricolor",
339
+ "393": "anemone fish",
340
+ "394": "sturgeon",
341
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
342
+ "396": "lionfish",
343
+ "397": "puffer, pufferfish, blowfish, globefish",
344
+ "398": "abacus",
345
+ "399": "abaya",
346
+ "4": "hammerhead, hammerhead shark",
347
+ "40": "American chameleon, anole, Anolis carolinensis",
348
+ "400": "academic gown, academic robe, judge robe",
349
+ "401": "accordion, piano accordion, squeeze box",
350
+ "402": "acoustic guitar",
351
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
352
+ "404": "airliner",
353
+ "405": "airship, dirigible",
354
+ "406": "altar",
355
+ "407": "ambulance",
356
+ "408": "amphibian, amphibious vehicle",
357
+ "409": "analog clock",
358
+ "41": "whiptail, whiptail lizard",
359
+ "410": "apiary, bee house",
360
+ "411": "apron",
361
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
362
+ "413": "assault rifle, assault gun",
363
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
364
+ "415": "bakery, bakeshop, bakehouse",
365
+ "416": "balance beam, beam",
366
+ "417": "balloon",
367
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
368
+ "419": "Band Aid",
369
+ "42": "agama",
370
+ "420": "banjo",
371
+ "421": "bannister, banister, balustrade, balusters, handrail",
372
+ "422": "barbell",
373
+ "423": "barber chair",
374
+ "424": "barbershop",
375
+ "425": "barn",
376
+ "426": "barometer",
377
+ "427": "barrel, cask",
378
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
379
+ "429": "baseball",
380
+ "43": "frilled lizard, Chlamydosaurus kingi",
381
+ "430": "basketball",
382
+ "431": "bassinet",
383
+ "432": "bassoon",
384
+ "433": "bathing cap, swimming cap",
385
+ "434": "bath towel",
386
+ "435": "bathtub, bathing tub, bath, tub",
387
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
388
+ "437": "beacon, lighthouse, beacon light, pharos",
389
+ "438": "beaker",
390
+ "439": "bearskin, busby, shako",
391
+ "44": "alligator lizard",
392
+ "440": "beer bottle",
393
+ "441": "beer glass",
394
+ "442": "bell cote, bell cot",
395
+ "443": "bib",
396
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
397
+ "445": "bikini, two-piece",
398
+ "446": "binder, ring-binder",
399
+ "447": "binoculars, field glasses, opera glasses",
400
+ "448": "birdhouse",
401
+ "449": "boathouse",
402
+ "45": "Gila monster, Heloderma suspectum",
403
+ "450": "bobsled, bobsleigh, bob",
404
+ "451": "bolo tie, bolo, bola tie, bola",
405
+ "452": "bonnet, poke bonnet",
406
+ "453": "bookcase",
407
+ "454": "bookshop, bookstore, bookstall",
408
+ "455": "bottlecap",
409
+ "456": "bow",
410
+ "457": "bow tie, bow-tie, bowtie",
411
+ "458": "brass, memorial tablet, plaque",
412
+ "459": "brassiere, bra, bandeau",
413
+ "46": "green lizard, Lacerta viridis",
414
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
415
+ "461": "breastplate, aegis, egis",
416
+ "462": "broom",
417
+ "463": "bucket, pail",
418
+ "464": "buckle",
419
+ "465": "bulletproof vest",
420
+ "466": "bullet train, bullet",
421
+ "467": "butcher shop, meat market",
422
+ "468": "cab, hack, taxi, taxicab",
423
+ "469": "caldron, cauldron",
424
+ "47": "African chameleon, Chamaeleo chamaeleon",
425
+ "470": "candle, taper, wax light",
426
+ "471": "cannon",
427
+ "472": "canoe",
428
+ "473": "can opener, tin opener",
429
+ "474": "cardigan",
430
+ "475": "car mirror",
431
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
432
+ "477": "carpenters kit, tool kit",
433
+ "478": "carton",
434
+ "479": "car wheel",
435
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
436
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
437
+ "481": "cassette",
438
+ "482": "cassette player",
439
+ "483": "castle",
440
+ "484": "catamaran",
441
+ "485": "CD player",
442
+ "486": "cello, violoncello",
443
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
444
+ "488": "chain",
445
+ "489": "chainlink fence",
446
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
447
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
448
+ "491": "chain saw, chainsaw",
449
+ "492": "chest",
450
+ "493": "chiffonier, commode",
451
+ "494": "chime, bell, gong",
452
+ "495": "china cabinet, china closet",
453
+ "496": "Christmas stocking",
454
+ "497": "church, church building",
455
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
456
+ "499": "cleaver, meat cleaver, chopper",
457
+ "5": "electric ray, crampfish, numbfish, torpedo",
458
+ "50": "American alligator, Alligator mississipiensis",
459
+ "500": "cliff dwelling",
460
+ "501": "cloak",
461
+ "502": "clog, geta, patten, sabot",
462
+ "503": "cocktail shaker",
463
+ "504": "coffee mug",
464
+ "505": "coffeepot",
465
+ "506": "coil, spiral, volute, whorl, helix",
466
+ "507": "combination lock",
467
+ "508": "computer keyboard, keypad",
468
+ "509": "confectionery, confectionary, candy store",
469
+ "51": "triceratops",
470
+ "510": "container ship, containership, container vessel",
471
+ "511": "convertible",
472
+ "512": "corkscrew, bottle screw",
473
+ "513": "cornet, horn, trumpet, trump",
474
+ "514": "cowboy boot",
475
+ "515": "cowboy hat, ten-gallon hat",
476
+ "516": "cradle",
477
+ "517": "crane",
478
+ "518": "crash helmet",
479
+ "519": "crate",
480
+ "52": "thunder snake, worm snake, Carphophis amoenus",
481
+ "520": "crib, cot",
482
+ "521": "Crock Pot",
483
+ "522": "croquet ball",
484
+ "523": "crutch",
485
+ "524": "cuirass",
486
+ "525": "dam, dike, dyke",
487
+ "526": "desk",
488
+ "527": "desktop computer",
489
+ "528": "dial telephone, dial phone",
490
+ "529": "diaper, nappy, napkin",
491
+ "53": "ringneck snake, ring-necked snake, ring snake",
492
+ "530": "digital clock",
493
+ "531": "digital watch",
494
+ "532": "dining table, board",
495
+ "533": "dishrag, dishcloth",
496
+ "534": "dishwasher, dish washer, dishwashing machine",
497
+ "535": "disk brake, disc brake",
498
+ "536": "dock, dockage, docking facility",
499
+ "537": "dogsled, dog sled, dog sleigh",
500
+ "538": "dome",
501
+ "539": "doormat, welcome mat",
502
+ "54": "hognose snake, puff adder, sand viper",
503
+ "540": "drilling platform, offshore rig",
504
+ "541": "drum, membranophone, tympan",
505
+ "542": "drumstick",
506
+ "543": "dumbbell",
507
+ "544": "Dutch oven",
508
+ "545": "electric fan, blower",
509
+ "546": "electric guitar",
510
+ "547": "electric locomotive",
511
+ "548": "entertainment center",
512
+ "549": "envelope",
513
+ "55": "green snake, grass snake",
514
+ "550": "espresso maker",
515
+ "551": "face powder",
516
+ "552": "feather boa, boa",
517
+ "553": "file, file cabinet, filing cabinet",
518
+ "554": "fireboat",
519
+ "555": "fire engine, fire truck",
520
+ "556": "fire screen, fireguard",
521
+ "557": "flagpole, flagstaff",
522
+ "558": "flute, transverse flute",
523
+ "559": "folding chair",
524
+ "56": "king snake, kingsnake",
525
+ "560": "football helmet",
526
+ "561": "forklift",
527
+ "562": "fountain",
528
+ "563": "fountain pen",
529
+ "564": "four-poster",
530
+ "565": "freight car",
531
+ "566": "French horn, horn",
532
+ "567": "frying pan, frypan, skillet",
533
+ "568": "fur coat",
534
+ "569": "garbage truck, dustcart",
535
+ "57": "garter snake, grass snake",
536
+ "570": "gasmask, respirator, gas helmet",
537
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
538
+ "572": "goblet",
539
+ "573": "go-kart",
540
+ "574": "golf ball",
541
+ "575": "golfcart, golf cart",
542
+ "576": "gondola",
543
+ "577": "gong, tam-tam",
544
+ "578": "gown",
545
+ "579": "grand piano, grand",
546
+ "58": "water snake",
547
+ "580": "greenhouse, nursery, glasshouse",
548
+ "581": "grille, radiator grille",
549
+ "582": "grocery store, grocery, food market, market",
550
+ "583": "guillotine",
551
+ "584": "hair slide",
552
+ "585": "hair spray",
553
+ "586": "half track",
554
+ "587": "hammer",
555
+ "588": "hamper",
556
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
557
+ "59": "vine snake",
558
+ "590": "hand-held computer, hand-held microcomputer",
559
+ "591": "handkerchief, hankie, hanky, hankey",
560
+ "592": "hard disc, hard disk, fixed disk",
561
+ "593": "harmonica, mouth organ, harp, mouth harp",
562
+ "594": "harp",
563
+ "595": "harvester, reaper",
564
+ "596": "hatchet",
565
+ "597": "holster",
566
+ "598": "home theater, home theatre",
567
+ "599": "honeycomb",
568
+ "6": "stingray",
569
+ "60": "night snake, Hypsiglena torquata",
570
+ "600": "hook, claw",
571
+ "601": "hoopskirt, crinoline",
572
+ "602": "horizontal bar, high bar",
573
+ "603": "horse cart, horse-cart",
574
+ "604": "hourglass",
575
+ "605": "iPod",
576
+ "606": "iron, smoothing iron",
577
+ "607": "jack-o-lantern",
578
+ "608": "jean, blue jean, denim",
579
+ "609": "jeep, landrover",
580
+ "61": "boa constrictor, Constrictor constrictor",
581
+ "610": "jersey, T-shirt, tee shirt",
582
+ "611": "jigsaw puzzle",
583
+ "612": "jinrikisha, ricksha, rickshaw",
584
+ "613": "joystick",
585
+ "614": "kimono",
586
+ "615": "knee pad",
587
+ "616": "knot",
588
+ "617": "lab coat, laboratory coat",
589
+ "618": "ladle",
590
+ "619": "lampshade, lamp shade",
591
+ "62": "rock python, rock snake, Python sebae",
592
+ "620": "laptop, laptop computer",
593
+ "621": "lawn mower, mower",
594
+ "622": "lens cap, lens cover",
595
+ "623": "letter opener, paper knife, paperknife",
596
+ "624": "library",
597
+ "625": "lifeboat",
598
+ "626": "lighter, light, igniter, ignitor",
599
+ "627": "limousine, limo",
600
+ "628": "liner, ocean liner",
601
+ "629": "lipstick, lip rouge",
602
+ "63": "Indian cobra, Naja naja",
603
+ "630": "Loafer",
604
+ "631": "lotion",
605
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
606
+ "633": "loupe, jewelers loupe",
607
+ "634": "lumbermill, sawmill",
608
+ "635": "magnetic compass",
609
+ "636": "mailbag, postbag",
610
+ "637": "mailbox, letter box",
611
+ "638": "maillot",
612
+ "639": "maillot, tank suit",
613
+ "64": "green mamba",
614
+ "640": "manhole cover",
615
+ "641": "maraca",
616
+ "642": "marimba, xylophone",
617
+ "643": "mask",
618
+ "644": "matchstick",
619
+ "645": "maypole",
620
+ "646": "maze, labyrinth",
621
+ "647": "measuring cup",
622
+ "648": "medicine chest, medicine cabinet",
623
+ "649": "megalith, megalithic structure",
624
+ "65": "sea snake",
625
+ "650": "microphone, mike",
626
+ "651": "microwave, microwave oven",
627
+ "652": "military uniform",
628
+ "653": "milk can",
629
+ "654": "minibus",
630
+ "655": "miniskirt, mini",
631
+ "656": "minivan",
632
+ "657": "missile",
633
+ "658": "mitten",
634
+ "659": "mixing bowl",
635
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
636
+ "660": "mobile home, manufactured home",
637
+ "661": "Model T",
638
+ "662": "modem",
639
+ "663": "monastery",
640
+ "664": "monitor",
641
+ "665": "moped",
642
+ "666": "mortar",
643
+ "667": "mortarboard",
644
+ "668": "mosque",
645
+ "669": "mosquito net",
646
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
647
+ "670": "motor scooter, scooter",
648
+ "671": "mountain bike, all-terrain bike, off-roader",
649
+ "672": "mountain tent",
650
+ "673": "mouse, computer mouse",
651
+ "674": "mousetrap",
652
+ "675": "moving van",
653
+ "676": "muzzle",
654
+ "677": "nail",
655
+ "678": "neck brace",
656
+ "679": "necklace",
657
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
658
+ "680": "nipple",
659
+ "681": "notebook, notebook computer",
660
+ "682": "obelisk",
661
+ "683": "oboe, hautboy, hautbois",
662
+ "684": "ocarina, sweet potato",
663
+ "685": "odometer, hodometer, mileometer, milometer",
664
+ "686": "oil filter",
665
+ "687": "organ, pipe organ",
666
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
667
+ "689": "overskirt",
668
+ "69": "trilobite",
669
+ "690": "oxcart",
670
+ "691": "oxygen mask",
671
+ "692": "packet",
672
+ "693": "paddle, boat paddle",
673
+ "694": "paddlewheel, paddle wheel",
674
+ "695": "padlock",
675
+ "696": "paintbrush",
676
+ "697": "pajama, pyjama, pjs, jammies",
677
+ "698": "palace",
678
+ "699": "panpipe, pandean pipe, syrinx",
679
+ "7": "cock",
680
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
681
+ "700": "paper towel",
682
+ "701": "parachute, chute",
683
+ "702": "parallel bars, bars",
684
+ "703": "park bench",
685
+ "704": "parking meter",
686
+ "705": "passenger car, coach, carriage",
687
+ "706": "patio, terrace",
688
+ "707": "pay-phone, pay-station",
689
+ "708": "pedestal, plinth, footstall",
690
+ "709": "pencil box, pencil case",
691
+ "71": "scorpion",
692
+ "710": "pencil sharpener",
693
+ "711": "perfume, essence",
694
+ "712": "Petri dish",
695
+ "713": "photocopier",
696
+ "714": "pick, plectrum, plectron",
697
+ "715": "pickelhaube",
698
+ "716": "picket fence, paling",
699
+ "717": "pickup, pickup truck",
700
+ "718": "pier",
701
+ "719": "piggy bank, penny bank",
702
+ "72": "black and gold garden spider, Argiope aurantia",
703
+ "720": "pill bottle",
704
+ "721": "pillow",
705
+ "722": "ping-pong ball",
706
+ "723": "pinwheel",
707
+ "724": "pirate, pirate ship",
708
+ "725": "pitcher, ewer",
709
+ "726": "plane, carpenters plane, woodworking plane",
710
+ "727": "planetarium",
711
+ "728": "plastic bag",
712
+ "729": "plate rack",
713
+ "73": "barn spider, Araneus cavaticus",
714
+ "730": "plow, plough",
715
+ "731": "plunger, plumbers helper",
716
+ "732": "Polaroid camera, Polaroid Land camera",
717
+ "733": "pole",
718
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
719
+ "735": "poncho",
720
+ "736": "pool table, billiard table, snooker table",
721
+ "737": "pop bottle, soda bottle",
722
+ "738": "pot, flowerpot",
723
+ "739": "potters wheel",
724
+ "74": "garden spider, Aranea diademata",
725
+ "740": "power drill",
726
+ "741": "prayer rug, prayer mat",
727
+ "742": "printer",
728
+ "743": "prison, prison house",
729
+ "744": "projectile, missile",
730
+ "745": "projector",
731
+ "746": "puck, hockey puck",
732
+ "747": "punching bag, punch bag, punching ball, punchball",
733
+ "748": "purse",
734
+ "749": "quill, quill pen",
735
+ "75": "black widow, Latrodectus mactans",
736
+ "750": "quilt, comforter, comfort, puff",
737
+ "751": "racer, race car, racing car",
738
+ "752": "racket, racquet",
739
+ "753": "radiator",
740
+ "754": "radio, wireless",
741
+ "755": "radio telescope, radio reflector",
742
+ "756": "rain barrel",
743
+ "757": "recreational vehicle, RV, R.V.",
744
+ "758": "reel",
745
+ "759": "reflex camera",
746
+ "76": "tarantula",
747
+ "760": "refrigerator, icebox",
748
+ "761": "remote control, remote",
749
+ "762": "restaurant, eating house, eating place, eatery",
750
+ "763": "revolver, six-gun, six-shooter",
751
+ "764": "rifle",
752
+ "765": "rocking chair, rocker",
753
+ "766": "rotisserie",
754
+ "767": "rubber eraser, rubber, pencil eraser",
755
+ "768": "rugby ball",
756
+ "769": "rule, ruler",
757
+ "77": "wolf spider, hunting spider",
758
+ "770": "running shoe",
759
+ "771": "safe",
760
+ "772": "safety pin",
761
+ "773": "saltshaker, salt shaker",
762
+ "774": "sandal",
763
+ "775": "sarong",
764
+ "776": "sax, saxophone",
765
+ "777": "scabbard",
766
+ "778": "scale, weighing machine",
767
+ "779": "school bus",
768
+ "78": "tick",
769
+ "780": "schooner",
770
+ "781": "scoreboard",
771
+ "782": "screen, CRT screen",
772
+ "783": "screw",
773
+ "784": "screwdriver",
774
+ "785": "seat belt, seatbelt",
775
+ "786": "sewing machine",
776
+ "787": "shield, buckler",
777
+ "788": "shoe shop, shoe-shop, shoe store",
778
+ "789": "shoji",
779
+ "79": "centipede",
780
+ "790": "shopping basket",
781
+ "791": "shopping cart",
782
+ "792": "shovel",
783
+ "793": "shower cap",
784
+ "794": "shower curtain",
785
+ "795": "ski",
786
+ "796": "ski mask",
787
+ "797": "sleeping bag",
788
+ "798": "slide rule, slipstick",
789
+ "799": "sliding door",
790
+ "8": "hen",
791
+ "80": "black grouse",
792
+ "800": "slot, one-armed bandit",
793
+ "801": "snorkel",
794
+ "802": "snowmobile",
795
+ "803": "snowplow, snowplough",
796
+ "804": "soap dispenser",
797
+ "805": "soccer ball",
798
+ "806": "sock",
799
+ "807": "solar dish, solar collector, solar furnace",
800
+ "808": "sombrero",
801
+ "809": "soup bowl",
802
+ "81": "ptarmigan",
803
+ "810": "space bar",
804
+ "811": "space heater",
805
+ "812": "space shuttle",
806
+ "813": "spatula",
807
+ "814": "speedboat",
808
+ "815": "spider web, spiders web",
809
+ "816": "spindle",
810
+ "817": "sports car, sport car",
811
+ "818": "spotlight, spot",
812
+ "819": "stage",
813
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
814
+ "820": "steam locomotive",
815
+ "821": "steel arch bridge",
816
+ "822": "steel drum",
817
+ "823": "stethoscope",
818
+ "824": "stole",
819
+ "825": "stone wall",
820
+ "826": "stopwatch, stop watch",
821
+ "827": "stove",
822
+ "828": "strainer",
823
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
824
+ "83": "prairie chicken, prairie grouse, prairie fowl",
825
+ "830": "stretcher",
826
+ "831": "studio couch, day bed",
827
+ "832": "stupa, tope",
828
+ "833": "submarine, pigboat, sub, U-boat",
829
+ "834": "suit, suit of clothes",
830
+ "835": "sundial",
831
+ "836": "sunglass",
832
+ "837": "sunglasses, dark glasses, shades",
833
+ "838": "sunscreen, sunblock, sun blocker",
834
+ "839": "suspension bridge",
835
+ "84": "peacock",
836
+ "840": "swab, swob, mop",
837
+ "841": "sweatshirt",
838
+ "842": "swimming trunks, bathing trunks",
839
+ "843": "swing",
840
+ "844": "switch, electric switch, electrical switch",
841
+ "845": "syringe",
842
+ "846": "table lamp",
843
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
844
+ "848": "tape player",
845
+ "849": "teapot",
846
+ "85": "quail",
847
+ "850": "teddy, teddy bear",
848
+ "851": "television, television system",
849
+ "852": "tennis ball",
850
+ "853": "thatch, thatched roof",
851
+ "854": "theater curtain, theatre curtain",
852
+ "855": "thimble",
853
+ "856": "thresher, thrasher, threshing machine",
854
+ "857": "throne",
855
+ "858": "tile roof",
856
+ "859": "toaster",
857
+ "86": "partridge",
858
+ "860": "tobacco shop, tobacconist shop, tobacconist",
859
+ "861": "toilet seat",
860
+ "862": "torch",
861
+ "863": "totem pole",
862
+ "864": "tow truck, tow car, wrecker",
863
+ "865": "toyshop",
864
+ "866": "tractor",
865
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
866
+ "868": "tray",
867
+ "869": "trench coat",
868
+ "87": "African grey, African gray, Psittacus erithacus",
869
+ "870": "tricycle, trike, velocipede",
870
+ "871": "trimaran",
871
+ "872": "tripod",
872
+ "873": "triumphal arch",
873
+ "874": "trolleybus, trolley coach, trackless trolley",
874
+ "875": "trombone",
875
+ "876": "tub, vat",
876
+ "877": "turnstile",
877
+ "878": "typewriter keyboard",
878
+ "879": "umbrella",
879
+ "88": "macaw",
880
+ "880": "unicycle, monocycle",
881
+ "881": "upright, upright piano",
882
+ "882": "vacuum, vacuum cleaner",
883
+ "883": "vase",
884
+ "884": "vault",
885
+ "885": "velvet",
886
+ "886": "vending machine",
887
+ "887": "vestment",
888
+ "888": "viaduct",
889
+ "889": "violin, fiddle",
890
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
891
+ "890": "volleyball",
892
+ "891": "waffle iron",
893
+ "892": "wall clock",
894
+ "893": "wallet, billfold, notecase, pocketbook",
895
+ "894": "wardrobe, closet, press",
896
+ "895": "warplane, military plane",
897
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
898
+ "897": "washer, automatic washer, washing machine",
899
+ "898": "water bottle",
900
+ "899": "water jug",
901
+ "9": "ostrich, Struthio camelus",
902
+ "90": "lorikeet",
903
+ "900": "water tower",
904
+ "901": "whiskey jug",
905
+ "902": "whistle",
906
+ "903": "wig",
907
+ "904": "window screen",
908
+ "905": "window shade",
909
+ "906": "Windsor tie",
910
+ "907": "wine bottle",
911
+ "908": "wing",
912
+ "909": "wok",
913
+ "91": "coucal",
914
+ "910": "wooden spoon",
915
+ "911": "wool, woolen, woollen",
916
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
917
+ "913": "wreck",
918
+ "914": "yawl",
919
+ "915": "yurt",
920
+ "916": "web site, website, internet site, site",
921
+ "917": "comic book",
922
+ "918": "crossword puzzle, crossword",
923
+ "919": "street sign",
924
+ "92": "bee eater",
925
+ "920": "traffic light, traffic signal, stoplight",
926
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
927
+ "922": "menu",
928
+ "923": "plate",
929
+ "924": "guacamole",
930
+ "925": "consomme",
931
+ "926": "hot pot, hotpot",
932
+ "927": "trifle",
933
+ "928": "ice cream, icecream",
934
+ "929": "ice lolly, lolly, lollipop, popsicle",
935
+ "93": "hornbill",
936
+ "930": "French loaf",
937
+ "931": "bagel, beigel",
938
+ "932": "pretzel",
939
+ "933": "cheeseburger",
940
+ "934": "hotdog, hot dog, red hot",
941
+ "935": "mashed potato",
942
+ "936": "head cabbage",
943
+ "937": "broccoli",
944
+ "938": "cauliflower",
945
+ "939": "zucchini, courgette",
946
+ "94": "hummingbird",
947
+ "940": "spaghetti squash",
948
+ "941": "acorn squash",
949
+ "942": "butternut squash",
950
+ "943": "cucumber, cuke",
951
+ "944": "artichoke, globe artichoke",
952
+ "945": "bell pepper",
953
+ "946": "cardoon",
954
+ "947": "mushroom",
955
+ "948": "Granny Smith",
956
+ "949": "strawberry",
957
+ "95": "jacamar",
958
+ "950": "orange",
959
+ "951": "lemon",
960
+ "952": "fig",
961
+ "953": "pineapple, ananas",
962
+ "954": "banana",
963
+ "955": "jackfruit, jak, jack",
964
+ "956": "custard apple",
965
+ "957": "pomegranate",
966
+ "958": "hay",
967
+ "959": "carbonara",
968
+ "96": "toucan",
969
+ "960": "chocolate sauce, chocolate syrup",
970
+ "961": "dough",
971
+ "962": "meat loaf, meatloaf",
972
+ "963": "pizza, pizza pie",
973
+ "964": "potpie",
974
+ "965": "burrito",
975
+ "966": "red wine",
976
+ "967": "espresso",
977
+ "968": "cup",
978
+ "969": "eggnog",
979
+ "97": "drake",
980
+ "970": "alp",
981
+ "971": "bubble",
982
+ "972": "cliff, drop, drop-off",
983
+ "973": "coral reef",
984
+ "974": "geyser",
985
+ "975": "lakeside, lakeshore",
986
+ "976": "promontory, headland, head, foreland",
987
+ "977": "sandbar, sand bar",
988
+ "978": "seashore, coast, seacoast, sea-coast",
989
+ "979": "valley, vale",
990
+ "98": "red-breasted merganser, Mergus serrator",
991
+ "980": "volcano",
992
+ "981": "ballplayer, baseball player",
993
+ "982": "groom, bridegroom",
994
+ "983": "scuba diver",
995
+ "984": "rapeseed",
996
+ "985": "daisy",
997
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
998
+ "987": "corn",
999
+ "988": "acorn",
1000
+ "989": "hip, rose hip, rosehip",
1001
+ "99": "goose",
1002
+ "990": "buckeye, horse chestnut, conker",
1003
+ "991": "coral fungus",
1004
+ "992": "agaric",
1005
+ "993": "gyromitra",
1006
+ "994": "stinkhorn, carrion fungus",
1007
+ "995": "earthstar",
1008
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1009
+ "997": "bolete",
1010
+ "998": "ear, spike, capitulum",
1011
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1012
+ },
1013
+ "scheduler": [
1014
+ "scheduling_deco_flow_match_euler_discrete",
1015
+ "DeCoFlowMatchEulerDiscreteScheduler"
1016
+ ],
1017
+ "transformer": [
1018
+ "transformer_deco",
1019
+ "DeCoTransformer2DModel"
1020
+ ]
1021
+ }
DeCo-XL-16-256/pipeline.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: DeCoPipeline (class-conditioned c2i).
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+
15
# Usage example for `DeCoPipeline`, written in the doctest-style layout used by
# diffusers example docstrings.
# NOTE(review): this constant is not referenced anywhere else in this file, so
# it is not attached to `DeCoPipeline.__call__`'s docstring.
# NOTE(review): the example loads "./DeCo-XL-16-512"; adjust the path when this
# file is used from another variant directory.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from pathlib import Path
        >>> from diffusers import DiffusionPipeline
        >>> import torch

        >>> model_dir = Path("./DeCo-XL-16-512").resolve()
        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     str(model_dir),
        ...     local_files_only=True,
        ...     custom_pipeline=str(model_dir / "pipeline.py"),
        ...     trust_remote_code=True,
        ...     torch_dtype=torch.bfloat16,
        ... )
        >>> pipe.to("cuda")

        >>> print(pipe.id2label[207])
        >>> print(pipe.get_label_ids("golden retriever"))

        >>> generator = torch.Generator(device="cuda").manual_seed(42)
        >>> image = pipe(
        ...     class_labels="golden retriever",
        ...     num_inference_steps=100,
        ...     guidance_scale=5.0,
        ...     generator=generator,
        ... ).images[0]
        ```
"""
44
+
45
+
46
class DeCoPipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with DeCo.

    Parameters:
        transformer ([`DeCoTransformer2DModel`]):
            Class-conditional DeCo transformer.
        scheduler ([`DeCoFlowMatchEulerDiscreteScheduler`]):
            Flow-matching Euler scheduler for DeCo.
        decoder ([`DeCoPatchDecoderModel`]):
            Per-patch RGB decoder (NerfEmbedder + AdaLN MLP).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer->decoder"

    def __init__(
        self,
        transformer,
        scheduler,
        decoder,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, decoder=decoder)
        # Keys may arrive as strings (JSON object keys); store them as ints.
        self._id2label = self._normalize_id2label(id2label)
        # Reverse lookup: every individual synonym -> class id.
        self.labels = self._build_label2id(self._id2label)
        # When no mapping was passed in, it is looked up lazily from
        # model_index.json on first use (see `_ensure_labels_loaded`).
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        r"""Lazily populate the label tables from `model_index.json` if they are still empty."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        r"""Return a copy of `id2label` with integer keys; `None` or an empty mapping yields `{}`."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        r"""
        Read the `id2label` table from `<variant_path>/model_index.json`.

        Returns an empty dict when the path is unset, the file does not exist, or the file has no
        dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        variant_dir = Path(variant_path).resolve()
        model_index_path = variant_dir / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        r"""
        Build the synonym -> class id lookup, sorted by synonym.

        Each value of `id2label` is split on commas and every non-empty, stripped piece becomes a key.
        Several classes can share a synonym (e.g. "missile" is the whole label of one class and a
        secondary synonym of another). A class whose *first* listed name equals the synonym always
        wins over a class that merely lists it as a later alternative; among entries of equal rank
        the one seen last in `id2label` wins.
        """
        label2id: Dict[str, int] = {}
        primary_labels = set()
        for class_id, value in id2label.items():
            for position, synonym in enumerate(value.split(",")):
                synonym = synonym.strip()
                if not synonym:
                    continue
                is_primary = position == 0
                # Never let a secondary synonym displace another class's primary name.
                if synonym in primary_labels and not is_primary:
                    continue
                label2id[synonym] = int(class_id)
                if is_primary:
                    primary_labels.add(synonym)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings. Each string must match a synonym in `id2label`.

        Returns:
            `list[int]`: One class id per input label, in input order.

        Raises:
            ValueError: If no label table is available or any label is unknown (matching is exact
                and case-sensitive).
        """
        self._ensure_labels_loaded()
        label2id = self.labels
        if not label2id:
            raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")

        if isinstance(label, str):
            label = [label]

        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
    ) -> torch.LongTensor:
        r"""
        Convert any supported `class_labels` input into a flat `torch.long` tensor on the execution device.

        Strings are resolved through `get_label_ids`. Lists may freely mix integer ids and label
        strings; each element is converted on its own.
        """
        if torch.is_tensor(class_labels):
            return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)

        if isinstance(class_labels, int):
            class_label_ids = [class_labels]
        elif isinstance(class_labels, str):
            class_label_ids = self.get_label_ids(class_labels)
        else:
            items = list(class_labels)
            is_string = [isinstance(item, str) for item in items]
            if items and all(is_string):
                # Resolve in one call so that every unknown label is reported together.
                class_label_ids = self.get_label_ids(items)
            elif any(is_string):
                # Mixed ids and names: resolve names individually, keep ids as given.
                class_label_ids = [
                    self.get_label_ids(item)[0] if item_is_string else int(item)
                    for item, item_is_string in zip(items, is_string)
                ]
            else:
                class_label_ids = items

        return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)

    def _default_sample_size(self) -> int:
        r"""Return `transformer.config.sample_size`, falling back to 256 when the config lacks it."""
        return int(getattr(self.transformer.config, "sample_size", 256))

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        batch_size: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with DeCo.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
                ImageNet class indices or human-readable English label strings.
            batch_size (`int`, *optional*):
                Number of images to generate. Defaults to the number of class labels. When a single
                class label is provided, repeats it to match `batch_size`.
            height (`int`, *optional*):
                Output image height in pixels. Defaults to `transformer.config.sample_size`.
            width (`int`, *optional*):
                Output image width in pixels. Defaults to `transformer.config.sample_size`.
            num_inference_steps (`int`, defaults to `50`):
                Number of denoising steps.
            guidance_scale (`float`, defaults to `1.0`):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            output_type (`str`, defaults to `"pil"`):
                `"pil"`, `"np"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Raises:
            ValueError: If `output_type` is not supported, or if the number of class labels does not
                match `batch_size`.
        """
        # Fail fast: reject a bad output_type before spending time on sampling.
        if output_type not in ("pil", "np", "latent"):
            raise ValueError("output_type must be one of {'pil', 'np', 'latent'}")

        device = self._execution_device
        dtype = next(self.transformer.parameters()).dtype
        do_cfg = guidance_scale is not None and float(guidance_scale) > 1.0

        sample_size = self._default_sample_size()
        height = int(height if height is not None else sample_size)
        width = int(width if width is not None else sample_size)

        class_labels = self._normalize_class_labels(class_labels)
        if batch_size is None:
            batch_size = int(class_labels.numel())
        elif class_labels.numel() == 1 and batch_size > 1:
            class_labels = class_labels.repeat(batch_size)
        elif class_labels.numel() != batch_size:
            raise ValueError("class_labels batch size must match batch_size")

        if do_cfg:
            # The index one past the last real class is used as the unconditional label.
            null_label = int(self.transformer.config.num_classes)
            uncond_labels = torch.full((batch_size,), null_label, device=device, dtype=torch.long)

        # Sampling starts from Gaussian noise at full output resolution.
        latents = randn_tensor(
            (batch_size, int(self.transformer.config.in_channels), height, width),
            generator=generator,
            device=device,
            dtype=dtype,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        # Drop the final entry of the schedule: the loop visits every timestep except the last one.
        timesteps = self.scheduler.timesteps[:-1]

        for timestep in self.progress_bar(timesteps):
            latent_model_input = self.scheduler.scale_model_input(latents, timestep)

            if do_cfg:
                # One batched forward pass: unconditional half first, conditional half second.
                latent_model_input = torch.cat([latent_model_input, latent_model_input], dim=0)
                model_output = self.transformer(
                    latent_model_input,
                    timestep,
                    class_labels=torch.cat([uncond_labels, class_labels], dim=0),
                    decoder=self.decoder,
                ).sample
                model_output_uncond, model_output_cond = model_output.chunk(2)
                model_output = model_output_uncond + float(guidance_scale) * (model_output_cond - model_output_uncond)
            else:
                model_output = self.transformer(
                    latent_model_input, timestep, class_labels=class_labels, decoder=self.decoder
                ).sample

            latents = self.scheduler.step(model_output, timestep, latents).prev_sample

        image = latents

        if output_type == "latent":
            if not return_dict:
                return (image,)
            return ImagePipelineOutput(images=image)

        # Rescale to [0, 1] (assumes samples are nominally in [-1, 1]) and move to NHWC numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)
        elif output_type != "np":
            raise ValueError("output_type must be one of {'pil', 'np', 'latent'}")

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
DeCo-XL-16-256/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DeCoFlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.31.0",
4
+ "last_step": null,
5
+ "num_train_timesteps": 1000,
6
+ "prediction_type": "v_prediction",
7
+ "shift": 1.0
8
+ }
DeCo-XL-16-256/scheduler/scheduling_deco_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
9
+
10
+
11
+
12
+ def _shift_respace_fn(t: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
13
+ return t / (t + (1 - t) * shift)
14
+
15
+
16
class DeCoFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """Euler-step scheduler over a (optionally shifted) timestep grid running from 0.0 to 1.0.

    `set_timesteps` builds ``num_inference_steps + 1`` timesteps; every call to `step` advances the
    sample by ``model_output * dt`` where ``dt`` is the gap between two consecutive timesteps.
    Only the `shift` and `last_step` config values are read by this class.
    """

    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        last_step: Optional[float] = None,
        prediction_type: str = "v_prediction",
    ):
        # Runtime state; all of it is (re)initialised by `set_timesteps`.
        self._step_index: int = 0
        self.num_inference_steps: Optional[int] = None
        self.timesteps = torch.tensor([], dtype=torch.float32)

    @property
    def init_noise_sigma(self) -> float:
        """Scale of the initial noise; always 1.0 for this scheduler."""
        return 1.0

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Build the timestep grid for `num_inference_steps` steps and reset the step counter.

        The grid is ``linspace(0, 1 - last_step, num_inference_steps)`` followed by ``1.0``, passed
        through `_shift_respace_fn` with the configured `shift`. When `last_step` is not configured
        it defaults to ``1 / num_inference_steps``.

        Raises:
            ValueError: If `num_inference_steps` is not positive.
        """
        if num_inference_steps <= 0:
            raise ValueError("num_inference_steps must be > 0")

        step_count = int(num_inference_steps)
        self.num_inference_steps = step_count

        final_gap = self.config.last_step
        if final_gap is None:
            final_gap = 1.0 / float(step_count)

        grid = torch.linspace(0.0, 1.0 - float(final_gap), step_count, dtype=torch.float32)
        terminal = torch.tensor([1.0], dtype=torch.float32)
        grid = torch.cat([grid, terminal], dim=0)
        schedule = _shift_respace_fn(grid, shift=float(self.config.shift))

        if device is not None:
            schedule = schedule.to(device)

        self.timesteps = schedule
        self._step_index = 0

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return `sample` untouched; no input scaling is applied by this scheduler."""
        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[torch.Tensor, float],
        sample: torch.Tensor,
        return_dict: bool = True,
    ):
        """Take one Euler step: ``sample + model_output * dt``.

        The position in the schedule comes from an internal counter, not from `timestep` (which is
        accepted but not read). The counter is clamped so that calls beyond the end of the schedule
        reuse the last interval.

        Raises:
            ValueError: If `set_timesteps` has not been called.
        """
        if self.num_inference_steps is None or self.timesteps.numel() == 0:
            raise ValueError("Call set_timesteps before step")

        # Clamp to the last valid interval [len - 2, len - 1].
        current = min(self._step_index, len(self.timesteps) - 2)
        interval = self.timesteps[current + 1] - self.timesteps[current]
        dt = interval.to(device=sample.device, dtype=sample.dtype)

        prev_sample = sample + model_output * dt

        self._step_index += 1

        if return_dict:
            return SchedulerOutput(prev_sample=prev_sample)
        return (prev_sample,)

    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """Interpolate between data and noise: ``t * original_samples + (1 - t) * noise``.

        `timesteps` is reshaped to ``(-1, 1, 1, 1)`` so it broadcasts over 4-D samples.
        """
        if timesteps.ndim == 0:
            timesteps = timesteps.unsqueeze(0)
        weight = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
        weight = weight.view(-1, 1, 1, 1)
        return weight * original_samples + (1.0 - weight) * noise
DeCo-XL-16-256/transformer/__pycache__/transformer_deco.cpython-312.pyc ADDED
Binary file (23.2 kB). View file
 
DeCo-XL-16-256/transformer/config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DeCoTransformer2DModel",
3
+ "conditioning_type": "class",
4
+ "decoder_hidden_size": 64,
5
+ "deep_supervision": 0,
6
+ "hidden_size": 1152,
7
+ "hidden_size_x": 32,
8
+ "in_channels": 3,
9
+ "learn_sigma": true,
10
+ "nerf_mlpratio": 4,
11
+ "num_blocks": 31,
12
+ "num_classes": 1000,
13
+ "num_cond_blocks": 28,
14
+ "num_decoder_blocks": 4,
15
+ "num_encoder_blocks": 18,
16
+ "num_groups": 16,
17
+ "num_text_blocks": 4,
18
+ "patch_size": 16,
19
+ "sample_size": 256,
20
+ "txt_embed_dim": 1024,
21
+ "txt_max_length": 100
22
+ }
DeCo-XL-16-256/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75c24fe14dde1f4def9b52ab7211252b7baa344f09d7a3da7b95a5033ccfb824
3
+ size 2691309848
DeCo-XL-16-256/transformer/transformer_deco.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.nn.functional import scaled_dot_product_attention
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
+ from diffusers.models.normalization import RMSNorm
18
+
19
+
20
def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine modulation: multiply `x` by `(1 + scale)` and then add `shift`."""
    return x * (1 + scale) + shift
22
+
23
+
24
class PatchEmbed(nn.Module):
    """Linear projection of flattened patch vectors into the embedding space."""

    def __init__(self, in_chans: int, embed_dim: int, bias: bool = True):
        """Create the projection from `in_chans` features to `embed_dim` features."""
        super().__init__()
        # The attribute name `proj` determines the parameter keys in the state dict.
        self.proj = nn.Linear(in_features=in_chans, out_features=embed_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of `x` from `in_chans` to `embed_dim`."""
        embedded = self.proj(x)
        return embedded
31
+
32
+
33
class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep embedding with checkpoint-compatible `mlp` module names.

    A timestep is first expanded into `frequency_embedding_size` cos/sin
    features and then passed through a two-layer MLP with a SiLU in between.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        """Build the MLP mapping `frequency_embedding_size` features to `hidden_size`."""
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10) -> torch.Tensor:
        """Return cos/sin features of `t` with a trailing axis of size `dim`.

        Args:
            t: Timesteps of any shape; a new last axis of size `dim` is appended.
            dim: Size of the embedding axis. For odd `dim` one zero column is appended.
            max_period: Base of the geometric frequency progression.

        Returns:
            Tensor of shape `t.shape + (dim,)`. It is cast back to `t.dtype`
            when `t` is floating point and left in float32 otherwise, so that
            integer timesteps do not truncate the embedding.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad along the last axis only; this works for any rank of `t`
            # and also when `half == 0` (no cos/sin columns to copy a shape from).
            padding = embedding.new_zeros(*embedding.shape[:-1], 1)
            embedding = torch.cat([embedding, padding], dim=-1)
        if t.is_floating_point():
            return embedding.to(t.dtype)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps `t` and project them through the MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
60
+
61
+
62
class DeCoSwiGLU(nn.Module):
    """SwiGLU feed-forward block using the w1/w2/w3 parameter names of the DeCo checkpoints."""

    def __init__(self, dim: int, hidden_dim: int):
        """Create the three bias-free projections.

        The effective inner width is `int(2 * hidden_dim / 3)`.
        """
        super().__init__()
        inner_dim = int(2 * hidden_dim / 3)
        # Registration order (w1, w3, w2) is kept so parameter order is unchanged.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `w2(silu(w1(x)) * w3(x))`."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
74
+
75
+
76
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale: float = 16.0) -> torch.Tensor:
    """Build the complex rotary table for a `height` x `width` grid.

    Grid coordinates are spread evenly over `[0, scale]` on each axis. For
    every grid cell (row-major order) the result interleaves the x-rotation
    and the y-rotation of each frequency.

    Returns:
        Complex tensor of shape `(height * width, 2 * (dim // 4))`.
    """
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(0, scale, height),
        torch.linspace(0, scale, width),
        indexing="ij",
    )
    flat_x = grid_x.reshape(-1)
    flat_y = grid_y.reshape(-1)

    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    inv_freq = 1.0 / (theta**exponents)

    angles_x = torch.outer(flat_x, inv_freq).float()
    angles_y = torch.outer(flat_y, inv_freq).float()
    unit = torch.ones_like(angles_x)
    rot_x = torch.polar(unit, angles_x)
    rot_y = torch.polar(unit, angles_y)

    # Stacking on a new last axis and flattening interleaves x and y per frequency.
    interleaved = torch.stack([rot_x, rot_y], dim=-1)
    return interleaved.reshape(height * width, -1)
89
+
90
+
91
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the complex factors in `freqs_cis`.

    Consecutive pairs along the last axis of `xq` / `xk` are treated as the
    real and imaginary parts of complex numbers, multiplied by `freqs_cis`
    (broadcast over the first and third axes), and flattened back.

    Returns:
        The rotated `(xq, xk)`, each cast back to the dtype of its input.
    """
    rotation = freqs_cis[None, :, None, :]

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        # Pair up the last axis, rotate in the complex plane, then unpair.
        paired = tensor.float().reshape(*tensor.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(paired) * rotation
        return torch.view_as_real(rotated).flatten(3).type_as(tensor)

    return _rotate(xq), _rotate(xk)
102
+
103
+
104
class LabelEmbedder(nn.Module):
    """Lookup table that maps integer class labels to `hidden_size` vectors."""

    def __init__(self, num_classes: int, hidden_size: int):
        """Create an embedding table with `num_classes` rows."""
        super().__init__()
        self.embedding_table = nn.Embedding(num_embeddings=num_classes, embedding_dim=hidden_size)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return the embedding rows selected by `labels`."""
        embedded = self.embedding_table(labels)
        return embedded
111
+
112
+
113
class RAttention(nn.Module):
    """Multi-head self-attention with optional query/key RMS normalisation and rotary embedding."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        proj_drop: float = 0.0,
    ) -> None:
        """Build the fused QKV projection, the optional Q/K norms and the output projection.

        Args:
            dim: Channel width of the input tokens; must be divisible by `num_heads`.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused QKV projection has a bias term.
            qk_norm: When True, queries and keys are RMS-normalised over `head_dim`.
            proj_drop: Dropout probability applied after the output projection.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = RMSNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run self-attention over the tokens of `x`.

        Args:
            x: Tokens of shape `(batch, tokens, channels)`.
            pos: Rotary factors forwarded to `apply_rotary_emb` as `freqs_cis`.
            mask: Optional attention mask passed to `scaled_dot_product_attention`.

        Returns:
            Tensor of shape `(batch, tokens, channels)`.
        """
        batch_size, num_tokens, channels = x.shape
        # (B, N, 3*C) -> (B, N, 3, H, D) -> (3, B, N, H, D) so Q/K/V can be split on axis 0.
        qkv = self.qkv(x).reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]
        query = self.q_norm(query)
        key = self.k_norm(key)
        # Rotary embedding is applied while the layout is still (B, N, H, D).
        query, key = apply_rotary_emb(query, key, freqs_cis=pos)
        # Move heads in front of tokens: (B, N, H, D) -> (B, H, N, D).
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        x = scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=0.0)
        # Merge heads back into the channel axis.
        x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
        return self.proj_drop(self.proj(x))
145
+
146
+
147
class FlattenDiTBlock(nn.Module):
    """Transformer block whose attention and MLP branches are modulated by a conditioning signal."""

    def __init__(self, hidden_size: int, groups: int, mlp_ratio: float = 4.0):
        """Create the norms, attention, MLP and the linear layer producing six modulation chunks.

        Args:
            hidden_size: Token channel width.
            groups: Number of attention heads passed to `RAttention`.
            mlp_ratio: Multiplier giving the hidden width handed to `DeCoSwiGLU`.
        """
        super().__init__()
        # Construction order is unchanged so parameter order and init are identical.
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = DeCoSwiGLU(hidden_size, int(hidden_size * mlp_ratio))
        self.adaLN_modulation = nn.Sequential(nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor, pos: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the gated attention branch followed by the gated MLP branch.

        `c` is projected into shift/scale/gate triples for each branch.
        """
        modulation = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = modulation

        attn_in = _modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn * self.attn(attn_in, pos, mask=mask)

        ff_in = _modulate(self.norm2(x), shift_ff, scale_ff)
        return x + gate_ff * self.mlp(ff_in)
160
+
161
+
162
@dataclass
class DeCoTransformer2DModelOutput(BaseOutput):
    """Output container with a single `sample` tensor field."""

    # Tensor carried by this output object.
    sample: torch.Tensor
165
+
166
+
167
class _DeCoTransformerBackbone(nn.Module):
    """Class-conditioned DeCo conditioning trunk. Checkpoint weights live under the `backbone.` prefix.

    The trunk patchifies the input, runs the patch tokens through a stack of
    `FlattenDiTBlock`s conditioned on timestep and class label, and hands each
    patch plus its conditioning vector to an external per-patch `decoder`.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        num_groups: int,
        hidden_size: int,
        num_cond_blocks: int,
        num_classes: int,
        learn_sigma: bool,
        deep_supervision: int,
    ):
        """Build the embedders and conditioning blocks.

        Args:
            in_channels: Channels of the input image tensor.
            patch_size: Side length of the square patches.
            num_groups: Number of attention heads in every block.
            hidden_size: Token channel width.
            num_cond_blocks: Number of `FlattenDiTBlock`s in the trunk.
            num_classes: Number of class labels; the table has one extra row.
            learn_sigma: Stored on the instance; not read elsewhere in this class.
            deep_supervision: Stored on the instance; not read elsewhere in this class.
        """
        super().__init__()
        self.learn_sigma = learn_sigma
        self.deep_supervision = deep_supervision
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_cond_blocks = num_cond_blocks

        # Each flattened patch has in_channels * patch_size**2 values.
        self.s_embedder = PatchEmbed(in_channels * patch_size**2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        # One row more than `num_classes` — presumably the null/unconditional label; TODO confirm.
        self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
        self.blocks = nn.ModuleList([FlattenDiTBlock(hidden_size, num_groups) for _ in range(num_cond_blocks)])
        # Rotary tables cached per (grid_height, grid_width); plain dict, not part of the state dict.
        self.precompute_pos: dict[tuple[int, int], torch.Tensor] = {}
        self._init_weights()

    def _init_weights(self) -> None:
        """Initialise the patch, label and timestep embedders."""
        weight = self.s_embedder.proj.weight.data
        # The view shares storage with the weight, so the in-place init updates the parameter.
        nn.init.xavier_uniform_(weight.view([weight.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def fetch_pos(self, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return the rotary table for a `height` x `width` token grid on `device`.

        The table is computed once per grid size and cached as created;
        `.to(device)` is applied on every call.
        """
        key = (height, width)
        if key not in self.precompute_pos:
            # Rotary dimension is the per-head width.
            self.precompute_pos[key] = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width)
        return self.precompute_pos[key].to(device)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        decoder: nn.Module,
        s: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict an image-shaped tensor for the input `x`.

        Args:
            x: Input of shape `(batch, channels, height, width)`.
            t: Timesteps; flattened before embedding.
            y: Class label indices for `y_embedder`.
            decoder: Module called as `decoder(patch_pixels, conditioning)`; its
                result must expose a `.sample` attribute.
            s: Optional precomputed conditioning tokens; when given, the
                embedder and blocks are skipped.
            mask: Optional attention mask forwarded to every block.

        Returns:
            Tensor folded back to spatial size `(height, width)`.
        """
        batch_size, _, height, width = x.shape
        pos = self.fetch_pos(height // self.patch_size, width // self.patch_size, x.device)
        # Non-overlapping patches: (B, C*p*p, L) -> (B, L, C*p*p).
        x = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(batch_size, -1, self.hidden_size)
        y = self.y_embedder(y).view(batch_size, 1, self.hidden_size)
        c = F.silu(t + y)
        if s is None:
            s = self.s_embedder(x)
            for block in self.blocks:
                s = block(s, c, pos, mask)
            s = F.silu(t + s)
        batch_size, length, _ = s.shape
        # Every patch becomes its own batch item: (B*L, p*p, C) pixels with a (B*L, hidden) condition.
        patch_pixels = x.reshape(batch_size * length, self.in_channels, self.patch_size**2).transpose(1, 2)
        conditioning = s.view(batch_size * length, self.hidden_size)
        decoded = decoder(patch_pixels, conditioning).sample
        # Undo the per-patch layout before folding patches back into an image.
        x = decoded.transpose(1, 2).reshape(batch_size, length, -1)
        return F.fold(
            x.transpose(1, 2).contiguous(),
            (height, width),
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
242
+
243
+
244
class DeCoTransformer2DModel(ModelMixin, ConfigMixin):
    """Class-conditioned DeCo transformer (c2i) for Diffusers pipelines.

    Thin wrapper that validates its inputs and delegates the computation to
    `_DeCoTransformerBackbone`, stored as `self.backbone`.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        num_groups: int = 12,
        hidden_size: int = 1152,
        hidden_size_x: int = 64,
        num_blocks: int = 18,
        num_cond_blocks: int = 4,
        num_classes: int = 1000,
        learn_sigma: bool = True,
        deep_supervision: int = 0,
        sample_size: int = 256,
        # Deprecated config keys kept for backward-compatible hub configs.
        conditioning_type: str = "class",
        nerf_mlpratio: int = 4,
        decoder_hidden_size: int = 64,
        num_encoder_blocks: int = 18,
        num_decoder_blocks: int = 4,
        num_text_blocks: int = 4,
        txt_embed_dim: int = 1024,
        txt_max_length: int = 100,
    ):
        """Build the backbone from the config values.

        Raises:
            ValueError: If `conditioning_type` is anything other than `"class"`.
        """
        super().__init__()
        # These keys are accepted only so existing configs load; they are not used here.
        del hidden_size_x, nerf_mlpratio, decoder_hidden_size, num_encoder_blocks, num_decoder_blocks
        del num_text_blocks, txt_embed_dim, txt_max_length
        if conditioning_type != "class":
            raise ValueError("DeCoTransformer2DModel only supports class conditioning (c2i).")

        backbone_kwargs = dict(
            in_channels=in_channels,
            patch_size=patch_size,
            num_groups=num_groups,
            hidden_size=hidden_size,
            num_cond_blocks=num_cond_blocks,
            num_classes=num_classes,
            learn_sigma=learn_sigma,
            deep_supervision=deep_supervision,
        )
        self.backbone = _DeCoTransformerBackbone(**backbone_kwargs)

    @property
    def in_channels(self) -> int:
        """Number of input channels as recorded in the config."""
        return int(self.config.in_channels)

    def _prepare_timestep(
        self, timestep: Union[torch.Tensor, float, int], batch_size: int, sample: torch.Tensor
    ) -> torch.Tensor:
        """Convert `timestep` to a tensor on the device/dtype of `sample`.

        A scalar becomes a one-element tensor, and a one-element tensor is
        repeated `batch_size` times when the batch holds more than one sample.
        """
        target = dict(device=sample.device, dtype=sample.dtype)
        if isinstance(timestep, torch.Tensor):
            prepared = timestep.to(**target)
        else:
            prepared = torch.tensor([timestep], **target)
        if prepared.ndim == 0:
            prepared = prepared[None]
        needs_broadcast = prepared.shape[0] == 1 and batch_size > 1
        if needs_broadcast:
            prepared = prepared.repeat(batch_size)
        return prepared

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        decoder: Optional[nn.Module] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DeCoTransformer2DModelOutput, tuple[torch.Tensor]]:
        """Run the backbone on `sample`.

        Args:
            sample: Input tensor; its first axis is the batch axis.
            timestep: Scalar or tensor timestep(s).
            class_labels: Required class indices; cast to `torch.long`.
            decoder: Required per-patch decoder module handed to the backbone.
            encoder_hidden_states: Must be None; text conditioning is rejected.
            return_dict: When False, return a 1-tuple instead of the output class.

        Raises:
            ValueError: If `encoder_hidden_states` is given, or if
                `class_labels` or `decoder` is missing.
        """
        if encoder_hidden_states is not None:
            raise ValueError("encoder_hidden_states is not supported; use class_labels for c2i DeCo models.")
        if class_labels is None:
            raise ValueError("class_labels must be provided for class-conditioned DeCo models.")
        if decoder is None:
            raise ValueError("decoder must be provided; load DeCoPatchDecoderModel as a separate pipeline component.")

        t = self._prepare_timestep(timestep=timestep, batch_size=sample.shape[0], sample=sample)
        labels = class_labels.to(device=sample.device, dtype=torch.long)
        output = self.backbone(sample, t, labels, decoder=decoder)

        if return_dict:
            return DeCoTransformer2DModelOutput(sample=output)
        return (output,)
DeCo-XL-16-512/decoder/__pycache__/decoder_deco.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
DeCo-XL-16-512/decoder/config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "in_channels": 3,
3
+ "hidden_size_x": 32,
4
+ "z_channels": 1152,
5
+ "max_freqs": 8,
6
+ "num_res_blocks": 3,
7
+ "patch_size": 16
8
+ }
DeCo-XL-16-512/decoder/decoder_deco.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.checkpoint import checkpoint
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
+
18
+
19
def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine modulation: multiply `x` by `(1 + scale)` and then add `shift`."""
    return x * (1 + scale) + shift
21
+
22
+
23
class NerfEmbedder(nn.Module):
    """Embed per-pixel values together with cosine positional features of the pixel's place in its patch."""

    def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
        """Create the linear embedder.

        Args:
            in_channels: Number of values per pixel.
            hidden_size_input: Output width of the embedder.
            max_freqs: Frequencies per axis; `max_freqs ** 2` positional
                features are concatenated to each pixel.
        """
        super().__init__()
        self.max_freqs = max_freqs
        self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True))
        # Per-instance cache. A functools.lru_cache on the method would key on
        # `self` and keep every instance (and its parameters) alive for the
        # lifetime of the class-level cache.
        self._pos_cache: dict[tuple[int, torch.device, torch.dtype], torch.Tensor] = {}

    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the cached positional features for a `patch_size` x `patch_size` patch.

        Returns:
            Tensor of shape `(1, patch_size ** 2, max_freqs ** 2)`. Repeated
            calls with the same arguments return the same tensor object.
        """
        key = (patch_size, device, dtype)
        cached = self._pos_cache.get(key)
        if cached is None:
            cached = self._build_pos(patch_size, device, dtype)
            self._pos_cache[key] = cached
        return cached

    def _build_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Compute the cosine positional table for one patch size."""
        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device)
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        # Higher joint frequencies get a smaller weight.
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct = (
            torch.cos(pos_x.reshape(-1, 1, 1) * freqs_x * torch.pi)
            * torch.cos(pos_y.reshape(-1, 1, 1) * freqs_y * torch.pi)
            * coeffs
        ).view(1, -1, self.max_freqs**2)
        return dct

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Embed `inputs` of shape `(batch, patch_size ** 2, in_channels)`.

        The patch side length is taken as the integer square root of the
        token count.
        """
        batch_size, patch_tokens, _ = inputs.shape
        patch_size = int(patch_tokens**0.5)
        # expand() broadcasts the shared table without copying; cat materialises the result.
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).expand(batch_size, -1, -1)
        return self.embedder(torch.cat([inputs, dct], dim=-1))
50
+
51
+
52
class ResBlock(nn.Module):
    """Residual MLP block whose normalised input is shifted, scaled and gated by a conditioning tensor."""

    def __init__(self, channels: int):
        """Create the layer norm, the two-layer MLP and the modulation projection (3 * channels outputs)."""
        super().__init__()
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        mlp_layers = [
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return `x` plus the gated MLP output computed from the modulated, normalised `x`."""
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        normed = self.in_ln(x)
        update = self.mlp(_modulate(normed, shift, scale))
        return x + gate * update
66
+
67
+
68
class DecoderFinalLayer(nn.Module):
    """Output head: parameter-free layer norm followed by a linear projection."""

    def __init__(self, model_channels: int, out_channels: int):
        """Create the norm (no learnable affine) and the projection to `out_channels`."""
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, eps=1e-6, elementwise_affine=False)
        self.linear = nn.Linear(in_features=model_channels, out_features=out_channels, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise `x` over its last axis and project it."""
        normed = self.norm_final(x)
        return self.linear(normed)
76
+
77
+
78
class SimpleMLPAdaLN(nn.Module):
    """Stack of conditioned residual MLP blocks that decodes the pixels of one patch.

    A per-patch conditioning vector is expanded into one conditioning vector
    per pixel, which then modulates every `ResBlock`.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        patch_size: int,
        grad_checkpointing: bool = False,
    ):
        """Build the projections and residual blocks.

        Args:
            in_channels: Feature width of the incoming per-pixel tokens.
            model_channels: Working width of the residual blocks.
            out_channels: Output values per pixel.
            z_channels: Width of the per-patch conditioning vector.
            num_res_blocks: Number of `ResBlock`s.
            patch_size: Patch side length; there are `patch_size ** 2` pixels per patch.
            grad_checkpointing: When True, blocks are run under activation checkpointing.
        """
        super().__init__()
        self.patch_size = patch_size
        self.grad_checkpointing = grad_checkpointing
        # One conditioning vector of width `model_channels` per pixel of the patch.
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = DecoderFinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero the modulation projections and the output layer."""
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Decode per-pixel tokens `x` conditioned on the per-patch vectors `c`.

        Args:
            x: Tokens whose last axis has `in_channels` features.
            c: Conditioning; its first axis must match the patch axis of `x`.

        Returns:
            Output of the final layer, with `out_channels` values per pixel.
        """
        x = self.input_proj(x)
        y = self.cond_embed(c).reshape(c.shape[0], self.patch_size**2, -1)
        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        for block in self.res_blocks:
            if use_checkpoint:
                # Pass use_reentrant explicitly: omitting it is deprecated, and the
                # non-reentrant variant still produces gradients when no input
                # tensor requires grad.
                x = checkpoint(block, x, y, use_reentrant=False)
            else:
                x = block(x, y)
        return self.final_layer(x)
114
+
115
+
116
@dataclass
class DeCoPatchDecoderOutput(BaseOutput):
    """Output container with a single `sample` tensor field."""

    # Tensor carried by this output object.
    sample: torch.Tensor
119
+
120
+
121
class DeCoPatchDecoderModel(ModelMixin, ConfigMixin):
    """Per-patch RGB decoder for DeCo (NerfEmbedder + AdaLN MLP)."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        hidden_size_x: int = 32,
        z_channels: int = 1152,
        num_res_blocks: int = 3,
        patch_size: int = 16,
        max_freqs: int = 8,
    ):
        """Build the pixel embedder and the conditioned MLP decoder.

        The decoder works at width `hidden_size_x` and maps back to
        `in_channels` values per pixel.
        """
        super().__init__()
        self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=max_freqs)
        decoder_kwargs = dict(
            in_channels=hidden_size_x,
            model_channels=hidden_size_x,
            out_channels=in_channels,
            z_channels=z_channels,
            num_res_blocks=num_res_blocks,
            patch_size=patch_size,
        )
        self.dec_net = SimpleMLPAdaLN(**decoder_kwargs)

    def forward(
        self,
        patch_pixels: torch.Tensor,
        conditioning: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DeCoPatchDecoderOutput, tuple[torch.Tensor]]:
        """
        Decode every patch from its pixels and its conditioning vector.

        Args:
            patch_pixels (`torch.Tensor`):
                Flattened patch pixels of shape `(batch * num_patches, patch_size ** 2, in_channels)`.
            conditioning (`torch.Tensor`):
                Per-patch conditioning of shape `(batch * num_patches, z_channels)`.
            return_dict (`bool`):
                When False, a 1-tuple is returned instead of `DeCoPatchDecoderOutput`.
        """
        embedded = self.x_embedder(patch_pixels)
        output = self.dec_net(embedded, conditioning)
        if return_dict:
            return DeCoPatchDecoderOutput(sample=output)
        return (output,)
DeCo-XL-16-512/decoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ca6476afbc38d431cc503a810567d5d30075c57e9209567b1b12279d749b5a8
3
+ size 37862236
DeCo-XL-16-512/decoder/diffusion_pytorch_model.safetensors.bak ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ca6476afbc38d431cc503a810567d5d30075c57e9209567b1b12279d749b5a8
3
+ size 37862236
DeCo-XL-16-512/demo.png ADDED

Git LFS Details

  • SHA256: af1ae9ea3f293d2f531f437c93c28df86e648f28fbdba5ec3ce65724cf480822
  • Pointer size: 131 Bytes
  • Size of remote file: 504 kB
DeCo-XL-16-512/model_index.json ADDED
@@ -0,0 +1,1021 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "DeCoPipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "transformer": [
8
+ "transformer_deco",
9
+ "DeCoTransformer2DModel"
10
+ ],
11
+ "scheduler": [
12
+ "scheduling_deco_flow_match_euler_discrete",
13
+ "DeCoFlowMatchEulerDiscreteScheduler"
14
+ ],
15
+ "decoder": [
16
+ "decoder_deco",
17
+ "DeCoPatchDecoderModel"
18
+ ],
19
+ "id2label": {
20
+ "0": "tench, Tinca tinca",
21
+ "1": "goldfish, Carassius auratus",
22
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
23
+ "3": "tiger shark, Galeocerdo cuvieri",
24
+ "4": "hammerhead, hammerhead shark",
25
+ "5": "electric ray, crampfish, numbfish, torpedo",
26
+ "6": "stingray",
27
+ "7": "cock",
28
+ "8": "hen",
29
+ "9": "ostrich, Struthio camelus",
30
+ "10": "brambling, Fringilla montifringilla",
31
+ "11": "goldfinch, Carduelis carduelis",
32
+ "12": "house finch, linnet, Carpodacus mexicanus",
33
+ "13": "junco, snowbird",
34
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
35
+ "15": "robin, American robin, Turdus migratorius",
36
+ "16": "bulbul",
37
+ "17": "jay",
38
+ "18": "magpie",
39
+ "19": "chickadee",
40
+ "20": "water ouzel, dipper",
41
+ "21": "kite",
42
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
43
+ "23": "vulture",
44
+ "24": "great grey owl, great gray owl, Strix nebulosa",
45
+ "25": "European fire salamander, Salamandra salamandra",
46
+ "26": "common newt, Triturus vulgaris",
47
+ "27": "eft",
48
+ "28": "spotted salamander, Ambystoma maculatum",
49
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
50
+ "30": "bullfrog, Rana catesbeiana",
51
+ "31": "tree frog, tree-frog",
52
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
53
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
54
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
55
+ "35": "mud turtle",
56
+ "36": "terrapin",
57
+ "37": "box turtle, box tortoise",
58
+ "38": "banded gecko",
59
+ "39": "common iguana, iguana, Iguana iguana",
60
+ "40": "American chameleon, anole, Anolis carolinensis",
61
+ "41": "whiptail, whiptail lizard",
62
+ "42": "agama",
63
+ "43": "frilled lizard, Chlamydosaurus kingi",
64
+ "44": "alligator lizard",
65
+ "45": "Gila monster, Heloderma suspectum",
66
+ "46": "green lizard, Lacerta viridis",
67
+ "47": "African chameleon, Chamaeleo chamaeleon",
68
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
69
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
70
+ "50": "American alligator, Alligator mississipiensis",
71
+ "51": "triceratops",
72
+ "52": "thunder snake, worm snake, Carphophis amoenus",
73
+ "53": "ringneck snake, ring-necked snake, ring snake",
74
+ "54": "hognose snake, puff adder, sand viper",
75
+ "55": "green snake, grass snake",
76
+ "56": "king snake, kingsnake",
77
+ "57": "garter snake, grass snake",
78
+ "58": "water snake",
79
+ "59": "vine snake",
80
+ "60": "night snake, Hypsiglena torquata",
81
+ "61": "boa constrictor, Constrictor constrictor",
82
+ "62": "rock python, rock snake, Python sebae",
83
+ "63": "Indian cobra, Naja naja",
84
+ "64": "green mamba",
85
+ "65": "sea snake",
86
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
87
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
88
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
89
+ "69": "trilobite",
90
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
91
+ "71": "scorpion",
92
+ "72": "black and gold garden spider, Argiope aurantia",
93
+ "73": "barn spider, Araneus cavaticus",
94
+ "74": "garden spider, Aranea diademata",
95
+ "75": "black widow, Latrodectus mactans",
96
+ "76": "tarantula",
97
+ "77": "wolf spider, hunting spider",
98
+ "78": "tick",
99
+ "79": "centipede",
100
+ "80": "black grouse",
101
+ "81": "ptarmigan",
102
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
103
+ "83": "prairie chicken, prairie grouse, prairie fowl",
104
+ "84": "peacock",
105
+ "85": "quail",
106
+ "86": "partridge",
107
+ "87": "African grey, African gray, Psittacus erithacus",
108
+ "88": "macaw",
109
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
110
+ "90": "lorikeet",
111
+ "91": "coucal",
112
+ "92": "bee eater",
113
+ "93": "hornbill",
114
+ "94": "hummingbird",
115
+ "95": "jacamar",
116
+ "96": "toucan",
117
+ "97": "drake",
118
+ "98": "red-breasted merganser, Mergus serrator",
119
+ "99": "goose",
120
+ "100": "black swan, Cygnus atratus",
121
+ "101": "tusker",
122
+ "102": "echidna, spiny anteater, anteater",
123
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
124
+ "104": "wallaby, brush kangaroo",
125
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
126
+ "106": "wombat",
127
+ "107": "jellyfish",
128
+ "108": "sea anemone, anemone",
129
+ "109": "brain coral",
130
+ "110": "flatworm, platyhelminth",
131
+ "111": "nematode, nematode worm, roundworm",
132
+ "112": "conch",
133
+ "113": "snail",
134
+ "114": "slug",
135
+ "115": "sea slug, nudibranch",
136
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
137
+ "117": "chambered nautilus, pearly nautilus, nautilus",
138
+ "118": "Dungeness crab, Cancer magister",
139
+ "119": "rock crab, Cancer irroratus",
140
+ "120": "fiddler crab",
141
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
142
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
143
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
144
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
145
+ "125": "hermit crab",
146
+ "126": "isopod",
147
+ "127": "white stork, Ciconia ciconia",
148
+ "128": "black stork, Ciconia nigra",
149
+ "129": "spoonbill",
150
+ "130": "flamingo",
151
+ "131": "little blue heron, Egretta caerulea",
152
+ "132": "American egret, great white heron, Egretta albus",
153
+ "133": "bittern",
154
+ "134": "crane",
155
+ "135": "limpkin, Aramus pictus",
156
+ "136": "European gallinule, Porphyrio porphyrio",
157
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
158
+ "138": "bustard",
159
+ "139": "ruddy turnstone, Arenaria interpres",
160
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
161
+ "141": "redshank, Tringa totanus",
162
+ "142": "dowitcher",
163
+ "143": "oystercatcher, oyster catcher",
164
+ "144": "pelican",
165
+ "145": "king penguin, Aptenodytes patagonica",
166
+ "146": "albatross, mollymawk",
167
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
168
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
169
+ "149": "dugong, Dugong dugon",
170
+ "150": "sea lion",
171
+ "151": "Chihuahua",
172
+ "152": "Japanese spaniel",
173
+ "153": "Maltese dog, Maltese terrier, Maltese",
174
+ "154": "Pekinese, Pekingese, Peke",
175
+ "155": "Shih-Tzu",
176
+ "156": "Blenheim spaniel",
177
+ "157": "papillon",
178
+ "158": "toy terrier",
179
+ "159": "Rhodesian ridgeback",
180
+ "160": "Afghan hound, Afghan",
181
+ "161": "basset, basset hound",
182
+ "162": "beagle",
183
+ "163": "bloodhound, sleuthhound",
184
+ "164": "bluetick",
185
+ "165": "black-and-tan coonhound",
186
+ "166": "Walker hound, Walker foxhound",
187
+ "167": "English foxhound",
188
+ "168": "redbone",
189
+ "169": "borzoi, Russian wolfhound",
190
+ "170": "Irish wolfhound",
191
+ "171": "Italian greyhound",
192
+ "172": "whippet",
193
+ "173": "Ibizan hound, Ibizan Podenco",
194
+ "174": "Norwegian elkhound, elkhound",
195
+ "175": "otterhound, otter hound",
196
+ "176": "Saluki, gazelle hound",
197
+ "177": "Scottish deerhound, deerhound",
198
+ "178": "Weimaraner",
199
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
200
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
201
+ "181": "Bedlington terrier",
202
+ "182": "Border terrier",
203
+ "183": "Kerry blue terrier",
204
+ "184": "Irish terrier",
205
+ "185": "Norfolk terrier",
206
+ "186": "Norwich terrier",
207
+ "187": "Yorkshire terrier",
208
+ "188": "wire-haired fox terrier",
209
+ "189": "Lakeland terrier",
210
+ "190": "Sealyham terrier, Sealyham",
211
+ "191": "Airedale, Airedale terrier",
212
+ "192": "cairn, cairn terrier",
213
+ "193": "Australian terrier",
214
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
215
+ "195": "Boston bull, Boston terrier",
216
+ "196": "miniature schnauzer",
217
+ "197": "giant schnauzer",
218
+ "198": "standard schnauzer",
219
+ "199": "Scotch terrier, Scottish terrier, Scottie",
220
+ "200": "Tibetan terrier, chrysanthemum dog",
221
+ "201": "silky terrier, Sydney silky",
222
+ "202": "soft-coated wheaten terrier",
223
+ "203": "West Highland white terrier",
224
+ "204": "Lhasa, Lhasa apso",
225
+ "205": "flat-coated retriever",
226
+ "206": "curly-coated retriever",
227
+ "207": "golden retriever",
228
+ "208": "Labrador retriever",
229
+ "209": "Chesapeake Bay retriever",
230
+ "210": "German short-haired pointer",
231
+ "211": "vizsla, Hungarian pointer",
232
+ "212": "English setter",
233
+ "213": "Irish setter, red setter",
234
+ "214": "Gordon setter",
235
+ "215": "Brittany spaniel",
236
+ "216": "clumber, clumber spaniel",
237
+ "217": "English springer, English springer spaniel",
238
+ "218": "Welsh springer spaniel",
239
+ "219": "cocker spaniel, English cocker spaniel, cocker",
240
+ "220": "Sussex spaniel",
241
+ "221": "Irish water spaniel",
242
+ "222": "kuvasz",
243
+ "223": "schipperke",
244
+ "224": "groenendael",
245
+ "225": "malinois",
246
+ "226": "briard",
247
+ "227": "kelpie",
248
+ "228": "komondor",
249
+ "229": "Old English sheepdog, bobtail",
250
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
251
+ "231": "collie",
252
+ "232": "Border collie",
253
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
254
+ "234": "Rottweiler",
255
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
256
+ "236": "Doberman, Doberman pinscher",
257
+ "237": "miniature pinscher",
258
+ "238": "Greater Swiss Mountain dog",
259
+ "239": "Bernese mountain dog",
260
+ "240": "Appenzeller",
261
+ "241": "EntleBucher",
262
+ "242": "boxer",
263
+ "243": "bull mastiff",
264
+ "244": "Tibetan mastiff",
265
+ "245": "French bulldog",
266
+ "246": "Great Dane",
267
+ "247": "Saint Bernard, St Bernard",
268
+ "248": "Eskimo dog, husky",
269
+ "249": "malamute, malemute, Alaskan malamute",
270
+ "250": "Siberian husky",
271
+ "251": "dalmatian, coach dog, carriage dog",
272
+ "252": "affenpinscher, monkey pinscher, monkey dog",
273
+ "253": "basenji",
274
+ "254": "pug, pug-dog",
275
+ "255": "Leonberg",
276
+ "256": "Newfoundland, Newfoundland dog",
277
+ "257": "Great Pyrenees",
278
+ "258": "Samoyed, Samoyede",
279
+ "259": "Pomeranian",
280
+ "260": "chow, chow chow",
281
+ "261": "keeshond",
282
+ "262": "Brabancon griffon",
283
+ "263": "Pembroke, Pembroke Welsh corgi",
284
+ "264": "Cardigan, Cardigan Welsh corgi",
285
+ "265": "toy poodle",
286
+ "266": "miniature poodle",
287
+ "267": "standard poodle",
288
+ "268": "Mexican hairless",
289
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
290
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
291
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
292
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
293
+ "273": "dingo, warrigal, warragal, Canis dingo",
294
+ "274": "dhole, Cuon alpinus",
295
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
296
+ "276": "hyena, hyaena",
297
+ "277": "red fox, Vulpes vulpes",
298
+ "278": "kit fox, Vulpes macrotis",
299
+ "279": "Arctic fox, white fox, Alopex lagopus",
300
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
301
+ "281": "tabby, tabby cat",
302
+ "282": "tiger cat",
303
+ "283": "Persian cat",
304
+ "284": "Siamese cat, Siamese",
305
+ "285": "Egyptian cat",
306
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
307
+ "287": "lynx, catamount",
308
+ "288": "leopard, Panthera pardus",
309
+ "289": "snow leopard, ounce, Panthera uncia",
310
+ "290": "jaguar, panther, Panthera onca, Felis onca",
311
+ "291": "lion, king of beasts, Panthera leo",
312
+ "292": "tiger, Panthera tigris",
313
+ "293": "cheetah, chetah, Acinonyx jubatus",
314
+ "294": "brown bear, bruin, Ursus arctos",
315
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
316
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
317
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
318
+ "298": "mongoose",
319
+ "299": "meerkat, mierkat",
320
+ "300": "tiger beetle",
321
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
322
+ "302": "ground beetle, carabid beetle",
323
+ "303": "long-horned beetle, longicorn, longicorn beetle",
324
+ "304": "leaf beetle, chrysomelid",
325
+ "305": "dung beetle",
326
+ "306": "rhinoceros beetle",
327
+ "307": "weevil",
328
+ "308": "fly",
329
+ "309": "bee",
330
+ "310": "ant, emmet, pismire",
331
+ "311": "grasshopper, hopper",
332
+ "312": "cricket",
333
+ "313": "walking stick, walkingstick, stick insect",
334
+ "314": "cockroach, roach",
335
+ "315": "mantis, mantid",
336
+ "316": "cicada, cicala",
337
+ "317": "leafhopper",
338
+ "318": "lacewing, lacewing fly",
339
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
340
+ "320": "damselfly",
341
+ "321": "admiral",
342
+ "322": "ringlet, ringlet butterfly",
343
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
344
+ "324": "cabbage butterfly",
345
+ "325": "sulphur butterfly, sulfur butterfly",
346
+ "326": "lycaenid, lycaenid butterfly",
347
+ "327": "starfish, sea star",
348
+ "328": "sea urchin",
349
+ "329": "sea cucumber, holothurian",
350
+ "330": "wood rabbit, cottontail, cottontail rabbit",
351
+ "331": "hare",
352
+ "332": "Angora, Angora rabbit",
353
+ "333": "hamster",
354
+ "334": "porcupine, hedgehog",
355
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
356
+ "336": "marmot",
357
+ "337": "beaver",
358
+ "338": "guinea pig, Cavia cobaya",
359
+ "339": "sorrel",
360
+ "340": "zebra",
361
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
362
+ "342": "wild boar, boar, Sus scrofa",
363
+ "343": "warthog",
364
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
365
+ "345": "ox",
366
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
367
+ "347": "bison",
368
+ "348": "ram, tup",
369
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
370
+ "350": "ibex, Capra ibex",
371
+ "351": "hartebeest",
372
+ "352": "impala, Aepyceros melampus",
373
+ "353": "gazelle",
374
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
375
+ "355": "llama",
376
+ "356": "weasel",
377
+ "357": "mink",
378
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
379
+ "359": "black-footed ferret, ferret, Mustela nigripes",
380
+ "360": "otter",
381
+ "361": "skunk, polecat, wood pussy",
382
+ "362": "badger",
383
+ "363": "armadillo",
384
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
385
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
386
+ "366": "gorilla, Gorilla gorilla",
387
+ "367": "chimpanzee, chimp, Pan troglodytes",
388
+ "368": "gibbon, Hylobates lar",
389
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
390
+ "370": "guenon, guenon monkey",
391
+ "371": "patas, hussar monkey, Erythrocebus patas",
392
+ "372": "baboon",
393
+ "373": "macaque",
394
+ "374": "langur",
395
+ "375": "colobus, colobus monkey",
396
+ "376": "proboscis monkey, Nasalis larvatus",
397
+ "377": "marmoset",
398
+ "378": "capuchin, ringtail, Cebus capucinus",
399
+ "379": "howler monkey, howler",
400
+ "380": "titi, titi monkey",
401
+ "381": "spider monkey, Ateles geoffroyi",
402
+ "382": "squirrel monkey, Saimiri sciureus",
403
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
404
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
405
+ "385": "Indian elephant, Elephas maximus",
406
+ "386": "African elephant, Loxodonta africana",
407
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
408
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
409
+ "389": "barracouta, snoek",
410
+ "390": "eel",
411
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
412
+ "392": "rock beauty, Holocanthus tricolor",
413
+ "393": "anemone fish",
414
+ "394": "sturgeon",
415
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
416
+ "396": "lionfish",
417
+ "397": "puffer, pufferfish, blowfish, globefish",
418
+ "398": "abacus",
419
+ "399": "abaya",
420
+ "400": "academic gown, academic robe, judge robe",
421
+ "401": "accordion, piano accordion, squeeze box",
422
+ "402": "acoustic guitar",
423
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
424
+ "404": "airliner",
425
+ "405": "airship, dirigible",
426
+ "406": "altar",
427
+ "407": "ambulance",
428
+ "408": "amphibian, amphibious vehicle",
429
+ "409": "analog clock",
430
+ "410": "apiary, bee house",
431
+ "411": "apron",
432
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
433
+ "413": "assault rifle, assault gun",
434
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
435
+ "415": "bakery, bakeshop, bakehouse",
436
+ "416": "balance beam, beam",
437
+ "417": "balloon",
438
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
439
+ "419": "Band Aid",
440
+ "420": "banjo",
441
+ "421": "bannister, banister, balustrade, balusters, handrail",
442
+ "422": "barbell",
443
+ "423": "barber chair",
444
+ "424": "barbershop",
445
+ "425": "barn",
446
+ "426": "barometer",
447
+ "427": "barrel, cask",
448
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
449
+ "429": "baseball",
450
+ "430": "basketball",
451
+ "431": "bassinet",
452
+ "432": "bassoon",
453
+ "433": "bathing cap, swimming cap",
454
+ "434": "bath towel",
455
+ "435": "bathtub, bathing tub, bath, tub",
456
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
457
+ "437": "beacon, lighthouse, beacon light, pharos",
458
+ "438": "beaker",
459
+ "439": "bearskin, busby, shako",
460
+ "440": "beer bottle",
461
+ "441": "beer glass",
462
+ "442": "bell cote, bell cot",
463
+ "443": "bib",
464
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
465
+ "445": "bikini, two-piece",
466
+ "446": "binder, ring-binder",
467
+ "447": "binoculars, field glasses, opera glasses",
468
+ "448": "birdhouse",
469
+ "449": "boathouse",
470
+ "450": "bobsled, bobsleigh, bob",
471
+ "451": "bolo tie, bolo, bola tie, bola",
472
+ "452": "bonnet, poke bonnet",
473
+ "453": "bookcase",
474
+ "454": "bookshop, bookstore, bookstall",
475
+ "455": "bottlecap",
476
+ "456": "bow",
477
+ "457": "bow tie, bow-tie, bowtie",
478
+ "458": "brass, memorial tablet, plaque",
479
+ "459": "brassiere, bra, bandeau",
480
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
481
+ "461": "breastplate, aegis, egis",
482
+ "462": "broom",
483
+ "463": "bucket, pail",
484
+ "464": "buckle",
485
+ "465": "bulletproof vest",
486
+ "466": "bullet train, bullet",
487
+ "467": "butcher shop, meat market",
488
+ "468": "cab, hack, taxi, taxicab",
489
+ "469": "caldron, cauldron",
490
+ "470": "candle, taper, wax light",
491
+ "471": "cannon",
492
+ "472": "canoe",
493
+ "473": "can opener, tin opener",
494
+ "474": "cardigan",
495
+ "475": "car mirror",
496
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
497
+ "477": "carpenters kit, tool kit",
498
+ "478": "carton",
499
+ "479": "car wheel",
500
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
501
+ "481": "cassette",
502
+ "482": "cassette player",
503
+ "483": "castle",
504
+ "484": "catamaran",
505
+ "485": "CD player",
506
+ "486": "cello, violoncello",
507
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
508
+ "488": "chain",
509
+ "489": "chainlink fence",
510
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
511
+ "491": "chain saw, chainsaw",
512
+ "492": "chest",
513
+ "493": "chiffonier, commode",
514
+ "494": "chime, bell, gong",
515
+ "495": "china cabinet, china closet",
516
+ "496": "Christmas stocking",
517
+ "497": "church, church building",
518
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
519
+ "499": "cleaver, meat cleaver, chopper",
520
+ "500": "cliff dwelling",
521
+ "501": "cloak",
522
+ "502": "clog, geta, patten, sabot",
523
+ "503": "cocktail shaker",
524
+ "504": "coffee mug",
525
+ "505": "coffeepot",
526
+ "506": "coil, spiral, volute, whorl, helix",
527
+ "507": "combination lock",
528
+ "508": "computer keyboard, keypad",
529
+ "509": "confectionery, confectionary, candy store",
530
+ "510": "container ship, containership, container vessel",
531
+ "511": "convertible",
532
+ "512": "corkscrew, bottle screw",
533
+ "513": "cornet, horn, trumpet, trump",
534
+ "514": "cowboy boot",
535
+ "515": "cowboy hat, ten-gallon hat",
536
+ "516": "cradle",
537
+ "517": "crane",
538
+ "518": "crash helmet",
539
+ "519": "crate",
540
+ "520": "crib, cot",
541
+ "521": "Crock Pot",
542
+ "522": "croquet ball",
543
+ "523": "crutch",
544
+ "524": "cuirass",
545
+ "525": "dam, dike, dyke",
546
+ "526": "desk",
547
+ "527": "desktop computer",
548
+ "528": "dial telephone, dial phone",
549
+ "529": "diaper, nappy, napkin",
550
+ "530": "digital clock",
551
+ "531": "digital watch",
552
+ "532": "dining table, board",
553
+ "533": "dishrag, dishcloth",
554
+ "534": "dishwasher, dish washer, dishwashing machine",
555
+ "535": "disk brake, disc brake",
556
+ "536": "dock, dockage, docking facility",
557
+ "537": "dogsled, dog sled, dog sleigh",
558
+ "538": "dome",
559
+ "539": "doormat, welcome mat",
560
+ "540": "drilling platform, offshore rig",
561
+ "541": "drum, membranophone, tympan",
562
+ "542": "drumstick",
563
+ "543": "dumbbell",
564
+ "544": "Dutch oven",
565
+ "545": "electric fan, blower",
566
+ "546": "electric guitar",
567
+ "547": "electric locomotive",
568
+ "548": "entertainment center",
569
+ "549": "envelope",
570
+ "550": "espresso maker",
571
+ "551": "face powder",
572
+ "552": "feather boa, boa",
573
+ "553": "file, file cabinet, filing cabinet",
574
+ "554": "fireboat",
575
+ "555": "fire engine, fire truck",
576
+ "556": "fire screen, fireguard",
577
+ "557": "flagpole, flagstaff",
578
+ "558": "flute, transverse flute",
579
+ "559": "folding chair",
580
+ "560": "football helmet",
581
+ "561": "forklift",
582
+ "562": "fountain",
583
+ "563": "fountain pen",
584
+ "564": "four-poster",
585
+ "565": "freight car",
586
+ "566": "French horn, horn",
587
+ "567": "frying pan, frypan, skillet",
588
+ "568": "fur coat",
589
+ "569": "garbage truck, dustcart",
590
+ "570": "gasmask, respirator, gas helmet",
591
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
592
+ "572": "goblet",
593
+ "573": "go-kart",
594
+ "574": "golf ball",
595
+ "575": "golfcart, golf cart",
596
+ "576": "gondola",
597
+ "577": "gong, tam-tam",
598
+ "578": "gown",
599
+ "579": "grand piano, grand",
600
+ "580": "greenhouse, nursery, glasshouse",
601
+ "581": "grille, radiator grille",
602
+ "582": "grocery store, grocery, food market, market",
603
+ "583": "guillotine",
604
+ "584": "hair slide",
605
+ "585": "hair spray",
606
+ "586": "half track",
607
+ "587": "hammer",
608
+ "588": "hamper",
609
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
610
+ "590": "hand-held computer, hand-held microcomputer",
611
+ "591": "handkerchief, hankie, hanky, hankey",
612
+ "592": "hard disc, hard disk, fixed disk",
613
+ "593": "harmonica, mouth organ, harp, mouth harp",
614
+ "594": "harp",
615
+ "595": "harvester, reaper",
616
+ "596": "hatchet",
617
+ "597": "holster",
618
+ "598": "home theater, home theatre",
619
+ "599": "honeycomb",
620
+ "600": "hook, claw",
621
+ "601": "hoopskirt, crinoline",
622
+ "602": "horizontal bar, high bar",
623
+ "603": "horse cart, horse-cart",
624
+ "604": "hourglass",
625
+ "605": "iPod",
626
+ "606": "iron, smoothing iron",
627
+ "607": "jack-o-lantern",
628
+ "608": "jean, blue jean, denim",
629
+ "609": "jeep, landrover",
630
+ "610": "jersey, T-shirt, tee shirt",
631
+ "611": "jigsaw puzzle",
632
+ "612": "jinrikisha, ricksha, rickshaw",
633
+ "613": "joystick",
634
+ "614": "kimono",
635
+ "615": "knee pad",
636
+ "616": "knot",
637
+ "617": "lab coat, laboratory coat",
638
+ "618": "ladle",
639
+ "619": "lampshade, lamp shade",
640
+ "620": "laptop, laptop computer",
641
+ "621": "lawn mower, mower",
642
+ "622": "lens cap, lens cover",
643
+ "623": "letter opener, paper knife, paperknife",
644
+ "624": "library",
645
+ "625": "lifeboat",
646
+ "626": "lighter, light, igniter, ignitor",
647
+ "627": "limousine, limo",
648
+ "628": "liner, ocean liner",
649
+ "629": "lipstick, lip rouge",
650
+ "630": "Loafer",
651
+ "631": "lotion",
652
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
653
+ "633": "loupe, jewelers loupe",
654
+ "634": "lumbermill, sawmill",
655
+ "635": "magnetic compass",
656
+ "636": "mailbag, postbag",
657
+ "637": "mailbox, letter box",
658
+ "638": "maillot",
659
+ "639": "maillot, tank suit",
660
+ "640": "manhole cover",
661
+ "641": "maraca",
662
+ "642": "marimba, xylophone",
663
+ "643": "mask",
664
+ "644": "matchstick",
665
+ "645": "maypole",
666
+ "646": "maze, labyrinth",
667
+ "647": "measuring cup",
668
+ "648": "medicine chest, medicine cabinet",
669
+ "649": "megalith, megalithic structure",
670
+ "650": "microphone, mike",
671
+ "651": "microwave, microwave oven",
672
+ "652": "military uniform",
673
+ "653": "milk can",
674
+ "654": "minibus",
675
+ "655": "miniskirt, mini",
676
+ "656": "minivan",
677
+ "657": "missile",
678
+ "658": "mitten",
679
+ "659": "mixing bowl",
680
+ "660": "mobile home, manufactured home",
681
+ "661": "Model T",
682
+ "662": "modem",
683
+ "663": "monastery",
684
+ "664": "monitor",
685
+ "665": "moped",
686
+ "666": "mortar",
687
+ "667": "mortarboard",
688
+ "668": "mosque",
689
+ "669": "mosquito net",
690
+ "670": "motor scooter, scooter",
691
+ "671": "mountain bike, all-terrain bike, off-roader",
692
+ "672": "mountain tent",
693
+ "673": "mouse, computer mouse",
694
+ "674": "mousetrap",
695
+ "675": "moving van",
696
+ "676": "muzzle",
697
+ "677": "nail",
698
+ "678": "neck brace",
699
+ "679": "necklace",
700
+ "680": "nipple",
701
+ "681": "notebook, notebook computer",
702
+ "682": "obelisk",
703
+ "683": "oboe, hautboy, hautbois",
704
+ "684": "ocarina, sweet potato",
705
+ "685": "odometer, hodometer, mileometer, milometer",
706
+ "686": "oil filter",
707
+ "687": "organ, pipe organ",
708
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
709
+ "689": "overskirt",
710
+ "690": "oxcart",
711
+ "691": "oxygen mask",
712
+ "692": "packet",
713
+ "693": "paddle, boat paddle",
714
+ "694": "paddlewheel, paddle wheel",
715
+ "695": "padlock",
716
+ "696": "paintbrush",
717
+ "697": "pajama, pyjama, pjs, jammies",
718
+ "698": "palace",
719
+ "699": "panpipe, pandean pipe, syrinx",
720
+ "700": "paper towel",
721
+ "701": "parachute, chute",
722
+ "702": "parallel bars, bars",
723
+ "703": "park bench",
724
+ "704": "parking meter",
725
+ "705": "passenger car, coach, carriage",
726
+ "706": "patio, terrace",
727
+ "707": "pay-phone, pay-station",
728
+ "708": "pedestal, plinth, footstall",
729
+ "709": "pencil box, pencil case",
730
+ "710": "pencil sharpener",
731
+ "711": "perfume, essence",
732
+ "712": "Petri dish",
733
+ "713": "photocopier",
734
+ "714": "pick, plectrum, plectron",
735
+ "715": "pickelhaube",
736
+ "716": "picket fence, paling",
737
+ "717": "pickup, pickup truck",
738
+ "718": "pier",
739
+ "719": "piggy bank, penny bank",
740
+ "720": "pill bottle",
741
+ "721": "pillow",
742
+ "722": "ping-pong ball",
743
+ "723": "pinwheel",
744
+ "724": "pirate, pirate ship",
745
+ "725": "pitcher, ewer",
746
+ "726": "plane, carpenters plane, woodworking plane",
747
+ "727": "planetarium",
748
+ "728": "plastic bag",
749
+ "729": "plate rack",
750
+ "730": "plow, plough",
751
+ "731": "plunger, plumbers helper",
752
+ "732": "Polaroid camera, Polaroid Land camera",
753
+ "733": "pole",
754
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
755
+ "735": "poncho",
756
+ "736": "pool table, billiard table, snooker table",
757
+ "737": "pop bottle, soda bottle",
758
+ "738": "pot, flowerpot",
759
+ "739": "potters wheel",
760
+ "740": "power drill",
761
+ "741": "prayer rug, prayer mat",
762
+ "742": "printer",
763
+ "743": "prison, prison house",
764
+ "744": "projectile, missile",
765
+ "745": "projector",
766
+ "746": "puck, hockey puck",
767
+ "747": "punching bag, punch bag, punching ball, punchball",
768
+ "748": "purse",
769
+ "749": "quill, quill pen",
770
+ "750": "quilt, comforter, comfort, puff",
771
+ "751": "racer, race car, racing car",
772
+ "752": "racket, racquet",
773
+ "753": "radiator",
774
+ "754": "radio, wireless",
775
+ "755": "radio telescope, radio reflector",
776
+ "756": "rain barrel",
777
+ "757": "recreational vehicle, RV, R.V.",
778
+ "758": "reel",
779
+ "759": "reflex camera",
780
+ "760": "refrigerator, icebox",
781
+ "761": "remote control, remote",
782
+ "762": "restaurant, eating house, eating place, eatery",
783
+ "763": "revolver, six-gun, six-shooter",
784
+ "764": "rifle",
785
+ "765": "rocking chair, rocker",
786
+ "766": "rotisserie",
787
+ "767": "rubber eraser, rubber, pencil eraser",
788
+ "768": "rugby ball",
789
+ "769": "rule, ruler",
790
+ "770": "running shoe",
791
+ "771": "safe",
792
+ "772": "safety pin",
793
+ "773": "saltshaker, salt shaker",
794
+ "774": "sandal",
795
+ "775": "sarong",
796
+ "776": "sax, saxophone",
797
+ "777": "scabbard",
798
+ "778": "scale, weighing machine",
799
+ "779": "school bus",
800
+ "780": "schooner",
801
+ "781": "scoreboard",
802
+ "782": "screen, CRT screen",
803
+ "783": "screw",
804
+ "784": "screwdriver",
805
+ "785": "seat belt, seatbelt",
806
+ "786": "sewing machine",
807
+ "787": "shield, buckler",
808
+ "788": "shoe shop, shoe-shop, shoe store",
809
+ "789": "shoji",
810
+ "790": "shopping basket",
811
+ "791": "shopping cart",
812
+ "792": "shovel",
813
+ "793": "shower cap",
814
+ "794": "shower curtain",
815
+ "795": "ski",
816
+ "796": "ski mask",
817
+ "797": "sleeping bag",
818
+ "798": "slide rule, slipstick",
819
+ "799": "sliding door",
820
+ "800": "slot, one-armed bandit",
821
+ "801": "snorkel",
822
+ "802": "snowmobile",
823
+ "803": "snowplow, snowplough",
824
+ "804": "soap dispenser",
825
+ "805": "soccer ball",
826
+ "806": "sock",
827
+ "807": "solar dish, solar collector, solar furnace",
828
+ "808": "sombrero",
829
+ "809": "soup bowl",
830
+ "810": "space bar",
831
+ "811": "space heater",
832
+ "812": "space shuttle",
833
+ "813": "spatula",
834
+ "814": "speedboat",
835
+ "815": "spider web, spiders web",
836
+ "816": "spindle",
837
+ "817": "sports car, sport car",
838
+ "818": "spotlight, spot",
839
+ "819": "stage",
840
+ "820": "steam locomotive",
841
+ "821": "steel arch bridge",
842
+ "822": "steel drum",
843
+ "823": "stethoscope",
844
+ "824": "stole",
845
+ "825": "stone wall",
846
+ "826": "stopwatch, stop watch",
847
+ "827": "stove",
848
+ "828": "strainer",
849
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
850
+ "830": "stretcher",
851
+ "831": "studio couch, day bed",
852
+ "832": "stupa, tope",
853
+ "833": "submarine, pigboat, sub, U-boat",
854
+ "834": "suit, suit of clothes",
855
+ "835": "sundial",
856
+ "836": "sunglass",
857
+ "837": "sunglasses, dark glasses, shades",
858
+ "838": "sunscreen, sunblock, sun blocker",
859
+ "839": "suspension bridge",
860
+ "840": "swab, swob, mop",
861
+ "841": "sweatshirt",
862
+ "842": "swimming trunks, bathing trunks",
863
+ "843": "swing",
864
+ "844": "switch, electric switch, electrical switch",
865
+ "845": "syringe",
866
+ "846": "table lamp",
867
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
868
+ "848": "tape player",
869
+ "849": "teapot",
870
+ "850": "teddy, teddy bear",
871
+ "851": "television, television system",
872
+ "852": "tennis ball",
873
+ "853": "thatch, thatched roof",
874
+ "854": "theater curtain, theatre curtain",
875
+ "855": "thimble",
876
+ "856": "thresher, thrasher, threshing machine",
877
+ "857": "throne",
878
+ "858": "tile roof",
879
+ "859": "toaster",
880
+ "860": "tobacco shop, tobacconist shop, tobacconist",
881
+ "861": "toilet seat",
882
+ "862": "torch",
883
+ "863": "totem pole",
884
+ "864": "tow truck, tow car, wrecker",
885
+ "865": "toyshop",
886
+ "866": "tractor",
887
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
888
+ "868": "tray",
889
+ "869": "trench coat",
890
+ "870": "tricycle, trike, velocipede",
891
+ "871": "trimaran",
892
+ "872": "tripod",
893
+ "873": "triumphal arch",
894
+ "874": "trolleybus, trolley coach, trackless trolley",
895
+ "875": "trombone",
896
+ "876": "tub, vat",
897
+ "877": "turnstile",
898
+ "878": "typewriter keyboard",
899
+ "879": "umbrella",
900
+ "880": "unicycle, monocycle",
901
+ "881": "upright, upright piano",
902
+ "882": "vacuum, vacuum cleaner",
903
+ "883": "vase",
904
+ "884": "vault",
905
+ "885": "velvet",
906
+ "886": "vending machine",
907
+ "887": "vestment",
908
+ "888": "viaduct",
909
+ "889": "violin, fiddle",
910
+ "890": "volleyball",
911
+ "891": "waffle iron",
912
+ "892": "wall clock",
913
+ "893": "wallet, billfold, notecase, pocketbook",
914
+ "894": "wardrobe, closet, press",
915
+ "895": "warplane, military plane",
916
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
917
+ "897": "washer, automatic washer, washing machine",
918
+ "898": "water bottle",
919
+ "899": "water jug",
920
+ "900": "water tower",
921
+ "901": "whiskey jug",
922
+ "902": "whistle",
923
+ "903": "wig",
924
+ "904": "window screen",
925
+ "905": "window shade",
926
+ "906": "Windsor tie",
927
+ "907": "wine bottle",
928
+ "908": "wing",
929
+ "909": "wok",
930
+ "910": "wooden spoon",
931
+ "911": "wool, woolen, woollen",
932
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
933
+ "913": "wreck",
934
+ "914": "yawl",
935
+ "915": "yurt",
936
+ "916": "web site, website, internet site, site",
937
+ "917": "comic book",
938
+ "918": "crossword puzzle, crossword",
939
+ "919": "street sign",
940
+ "920": "traffic light, traffic signal, stoplight",
941
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
942
+ "922": "menu",
943
+ "923": "plate",
944
+ "924": "guacamole",
945
+ "925": "consomme",
946
+ "926": "hot pot, hotpot",
947
+ "927": "trifle",
948
+ "928": "ice cream, icecream",
949
+ "929": "ice lolly, lolly, lollipop, popsicle",
950
+ "930": "French loaf",
951
+ "931": "bagel, beigel",
952
+ "932": "pretzel",
953
+ "933": "cheeseburger",
954
+ "934": "hotdog, hot dog, red hot",
955
+ "935": "mashed potato",
956
+ "936": "head cabbage",
957
+ "937": "broccoli",
958
+ "938": "cauliflower",
959
+ "939": "zucchini, courgette",
960
+ "940": "spaghetti squash",
961
+ "941": "acorn squash",
962
+ "942": "butternut squash",
963
+ "943": "cucumber, cuke",
964
+ "944": "artichoke, globe artichoke",
965
+ "945": "bell pepper",
966
+ "946": "cardoon",
967
+ "947": "mushroom",
968
+ "948": "Granny Smith",
969
+ "949": "strawberry",
970
+ "950": "orange",
971
+ "951": "lemon",
972
+ "952": "fig",
973
+ "953": "pineapple, ananas",
974
+ "954": "banana",
975
+ "955": "jackfruit, jak, jack",
976
+ "956": "custard apple",
977
+ "957": "pomegranate",
978
+ "958": "hay",
979
+ "959": "carbonara",
980
+ "960": "chocolate sauce, chocolate syrup",
981
+ "961": "dough",
982
+ "962": "meat loaf, meatloaf",
983
+ "963": "pizza, pizza pie",
984
+ "964": "potpie",
985
+ "965": "burrito",
986
+ "966": "red wine",
987
+ "967": "espresso",
988
+ "968": "cup",
989
+ "969": "eggnog",
990
+ "970": "alp",
991
+ "971": "bubble",
992
+ "972": "cliff, drop, drop-off",
993
+ "973": "coral reef",
994
+ "974": "geyser",
995
+ "975": "lakeside, lakeshore",
996
+ "976": "promontory, headland, head, foreland",
997
+ "977": "sandbar, sand bar",
998
+ "978": "seashore, coast, seacoast, sea-coast",
999
+ "979": "valley, vale",
1000
+ "980": "volcano",
1001
+ "981": "ballplayer, baseball player",
1002
+ "982": "groom, bridegroom",
1003
+ "983": "scuba diver",
1004
+ "984": "rapeseed",
1005
+ "985": "daisy",
1006
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1007
+ "987": "corn",
1008
+ "988": "acorn",
1009
+ "989": "hip, rose hip, rosehip",
1010
+ "990": "buckeye, horse chestnut, conker",
1011
+ "991": "coral fungus",
1012
+ "992": "agaric",
1013
+ "993": "gyromitra",
1014
+ "994": "stinkhorn, carrion fungus",
1015
+ "995": "earthstar",
1016
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1017
+ "997": "bolete",
1018
+ "998": "ear, spike, capitulum",
1019
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1020
+ }
1021
+ }
DeCo-XL-16-512/pipeline.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: DeCoPipeline (class-conditioned c2i).
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+
15
+ EXAMPLE_DOC_STRING = """
16
+ Examples:
17
+ ```py
18
+ >>> from pathlib import Path
19
+ >>> from diffusers import DiffusionPipeline
20
+ >>> import torch
21
+
22
+ >>> model_dir = Path("./DeCo-XL-16-512").resolve()
23
+ >>> pipe = DiffusionPipeline.from_pretrained(
24
+ ... str(model_dir),
25
+ ... local_files_only=True,
26
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
27
+ ... trust_remote_code=True,
28
+ ... torch_dtype=torch.bfloat16,
29
+ ... )
30
+ >>> pipe.to("cuda")
31
+
32
+ >>> print(pipe.id2label[207])
33
+ >>> print(pipe.get_label_ids("golden retriever"))
34
+
35
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
36
+ >>> image = pipe(
37
+ ... class_labels="golden retriever",
38
+ ... num_inference_steps=100,
39
+ ... guidance_scale=5.0,
40
+ ... generator=generator,
41
+ ... ).images[0]
42
+ ```
43
+ """
44
+
45
+
46
class DeCoPipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with DeCo.

    Parameters:
        transformer ([`DeCoTransformer2DModel`]):
            Class-conditional DeCo transformer.
        scheduler ([`DeCoFlowMatchEulerDiscreteScheduler`]):
            Flow-matching Euler scheduler for DeCo.
        decoder ([`DeCoPatchDecoderModel`]):
            Per-patch RGB decoder (NerfEmbedder + AdaLN MLP).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer->decoder"

    def __init__(
        self,
        transformer,
        scheduler,
        decoder,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, decoder=decoder)
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # False when no mapping was passed in; `_ensure_labels_loaded` then tries model_index.json.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        r"""Lazily read `id2label` from `model_index.json` when none was given to the constructor."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        r"""Return a copy of `id2label` with integer keys (JSON stores the keys as strings)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        r"""
        Read the `id2label` mapping from `<variant_path>/model_index.json`.

        Returns an empty dict when the path is unset, the file does not exist, or the file has no
        dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        variant_dir = Path(variant_path).resolve()
        model_index_path = variant_dir / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        r"""
        Invert `id2label` into a synonym -> class id mapping, sorted by synonym.

        Each value is split on commas and every non-empty, stripped synonym becomes a key. When the same
        synonym appears under several class ids, the id iterated last wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings. Each string must match a synonym in `id2label`.

        Returns:
            `list[int]`: One class id per requested label, in the same order.

        Raises:
            ValueError: If no labels are available or any requested label is unknown.
        """
        self._ensure_labels_loaded()
        label2id = self.labels
        if not label2id:
            raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")

        if isinstance(label, str):
            label = [label]

        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
    ) -> torch.LongTensor:
        r"""
        Convert any accepted `class_labels` form into a flat `torch.long` tensor on the execution device.

        Sequences may mix integer class ids and English label strings; strings are resolved through
        `get_label_ids` (in a single call, so every unknown label is reported together) and the original
        order is preserved.
        """
        if torch.is_tensor(class_labels):
            return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)

        if isinstance(class_labels, (int, str)):
            items = [class_labels]
        else:
            items = list(class_labels)

        # Resolve only the string entries; integer entries never require a label table.
        names = [item for item in items if isinstance(item, str)]
        resolved = iter(self.get_label_ids(names)) if names else iter(())
        class_label_ids = [next(resolved) if isinstance(item, str) else item for item in items]

        return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)

    def _default_sample_size(self) -> int:
        r"""Return `transformer.config.sample_size`, falling back to 256 when the key is absent."""
        return int(getattr(self.transformer.config, "sample_size", 256))

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        batch_size: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with DeCo.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
                ImageNet class indices or human-readable English label strings. Lists may mix both.
            batch_size (`int`, *optional*):
                Number of images to generate. Defaults to the number of class labels. When a single
                class label is provided, repeats it to match `batch_size`.
            height (`int`, *optional*):
                Output image height in pixels. Defaults to `transformer.config.sample_size`.
            width (`int`, *optional*):
                Output image width in pixels. Defaults to `transformer.config.sample_size`.
            num_inference_steps (`int`, defaults to `50`):
                Number of denoising steps.
            guidance_scale (`float`, defaults to `1.0`):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            output_type (`str`, defaults to `"pil"`):
                `"pil"`, `"np"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Returns:
            [`ImagePipelineOutput`] or `tuple`: The generated images; a one-element tuple when
            `return_dict` is False.

        Raises:
            ValueError: If `output_type` is unsupported, `class_labels` is empty, the number of labels
                does not match `batch_size`, or `height`/`width` are not multiples of the patch size.
        """
        # Reject a bad output format before doing any sampling work.
        if output_type not in ("pil", "np", "latent"):
            raise ValueError("output_type must be one of {'pil', 'np', 'latent'}")

        device = self._execution_device
        dtype = next(self.transformer.parameters()).dtype
        do_cfg = guidance_scale is not None and float(guidance_scale) > 1.0

        sample_size = self._default_sample_size()
        height = int(height if height is not None else sample_size)
        width = int(width if width is not None else sample_size)

        # The transformer patchifies with a non-overlapping unfold; a remainder would be left un-denoised.
        patch_size = int(getattr(self.transformer.config, "patch_size", 1) or 1)
        if height % patch_size or width % patch_size:
            raise ValueError(
                f"`height` and `width` must be multiples of the transformer patch size ({patch_size}); "
                f"got {height}x{width}."
            )

        class_labels = self._normalize_class_labels(class_labels)
        if class_labels.numel() == 0:
            raise ValueError("class_labels must contain at least one label")
        if batch_size is None:
            batch_size = int(class_labels.numel())
        elif class_labels.numel() == 1 and batch_size > 1:
            class_labels = class_labels.repeat(batch_size)
        elif class_labels.numel() != batch_size:
            raise ValueError("class_labels batch size must match batch_size")

        if do_cfg:
            # Index `num_classes` (one past the last real class) is used as the unconditional label.
            null_label = int(self.transformer.config.num_classes)
            uncond_labels = torch.full((batch_size,), null_label, device=device, dtype=torch.long)

        latents = randn_tensor(
            (batch_size, int(self.transformer.config.in_channels), height, width),
            generator=generator,
            device=device,
            dtype=dtype,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        # The final schedule entry is only the integration end point, not a model evaluation time.
        timesteps = self.scheduler.timesteps[:-1]

        for timestep in self.progress_bar(timesteps):
            latent_model_input = self.scheduler.scale_model_input(latents, timestep)

            if do_cfg:
                # One batched forward pass: unconditional half first, conditional half second.
                latent_model_input = torch.cat([latent_model_input, latent_model_input], dim=0)
                model_output = self.transformer(
                    latent_model_input,
                    timestep,
                    class_labels=torch.cat([uncond_labels, class_labels], dim=0),
                    decoder=self.decoder,
                ).sample
                model_output_uncond, model_output_cond = model_output.chunk(2)
                model_output = model_output_uncond + float(guidance_scale) * (model_output_cond - model_output_uncond)
            else:
                model_output = self.transformer(
                    latent_model_input, timestep, class_labels=class_labels, decoder=self.decoder
                ).sample

            latents = self.scheduler.step(model_output, timestep, latents).prev_sample

        if output_type == "latent":
            image = latents
        else:
            # Map from [-1, 1] to [0, 1] and convert to channels-last float32 numpy.
            image = (latents / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
            if output_type == "pil":
                image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
DeCo-XL-16-512/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DeCoFlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.31.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "last_step": null,
7
+ "prediction_type": "v_prediction"
8
+ }
DeCo-XL-16-512/scheduler/scheduling_deco_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
9
+
10
+
11
+
12
+ def _shift_respace_fn(t: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
13
+ return t / (t + (1 - t) * shift)
14
+
15
+
16
class DeCoFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """Explicit Euler scheduler over a timestep grid that runs from 0.0 to 1.0.

    `set_timesteps` builds `num_inference_steps + 1` grid points; `step` advances the sample by
    `model_output * dt`, where `dt` is the gap between consecutive grid points.
    """

    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        last_step: Optional[float] = None,
        prediction_type: str = "v_prediction",
    ):
        # The arguments are recorded on `self.config` by `register_to_config`; only runtime state is set here.
        self.timesteps = torch.tensor([], dtype=torch.float32)
        self.num_inference_steps: Optional[int] = None
        self._step_index: int = 0

    @property
    def init_noise_sigma(self) -> float:
        """Scale of the initial noise; always 1.0 for this scheduler."""
        return 1.0

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Build the timestep grid and reset the internal step counter.

        The grid is `linspace(0, 1 - last_step, num_inference_steps)` followed by `1.0`, passed through
        `_shift_respace_fn` with the configured `shift`. When `config.last_step` is None it defaults to
        `1 / num_inference_steps`.

        Raises:
            ValueError: If `num_inference_steps` is not positive.
        """
        if num_inference_steps <= 0:
            raise ValueError("num_inference_steps must be > 0")

        self.num_inference_steps = int(num_inference_steps)
        last_step = self.config.last_step
        if last_step is None:
            last_step = 1.0 / float(self.num_inference_steps)

        base_timesteps = torch.linspace(0.0, 1.0 - float(last_step), self.num_inference_steps, dtype=torch.float32)
        base_timesteps = torch.cat([base_timesteps, torch.tensor([1.0], dtype=torch.float32)], dim=0)
        timesteps = _shift_respace_fn(base_timesteps, shift=float(self.config.shift))

        if device is not None:
            timesteps = timesteps.to(device)

        self.timesteps = timesteps
        self._step_index = 0

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return `sample` unchanged; present for scheduler API compatibility."""
        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[torch.Tensor, float],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, tuple]:
        """Advance `sample` by one Euler step: `sample + model_output * dt`.

        `timestep` is not read; the step size comes from the internal step counter, which is clamped to
        the last available interval of the grid.

        Raises:
            ValueError: If `set_timesteps` has not been called.
        """
        if self.num_inference_steps is None or self.timesteps.numel() == 0:
            raise ValueError("Call set_timesteps before step")

        step_index = min(self._step_index, len(self.timesteps) - 2)
        dt = (self.timesteps[step_index + 1] - self.timesteps[step_index]).to(device=sample.device, dtype=sample.dtype)

        prev_sample = sample + model_output * dt

        self._step_index += 1

        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """Interpolate between noise and data: `t * original_samples + (1 - t) * noise`.

        `timesteps` holds one value per batch element (a 0-d tensor is treated as a single value) and is
        broadcast over every non-batch dimension of `original_samples`, whatever its rank.
        """
        if timesteps.ndim == 0:
            timesteps = timesteps[None]
        t = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
        # One trailing singleton axis per non-batch dimension so `t` broadcasts for any sample rank.
        t = t.reshape(-1, *([1] * (original_samples.ndim - 1)))
        return t * original_samples + (1.0 - t) * noise
DeCo-XL-16-512/transformer/__pycache__/transformer_deco.cpython-312.pyc ADDED
Binary file (23.2 kB). View file
 
DeCo-XL-16-512/transformer/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sample_size": 512,
3
+ "conditioning_type": "class",
4
+ "decoder_hidden_size": 64,
5
+ "deep_supervision": 0,
6
+ "hidden_size": 1152,
7
+ "hidden_size_x": 32,
8
+ "in_channels": 3,
9
+ "learn_sigma": true,
10
+ "nerf_mlpratio": 4,
11
+ "num_blocks": 31,
12
+ "num_classes": 1000,
13
+ "num_cond_blocks": 28,
14
+ "num_decoder_blocks": 4,
15
+ "num_encoder_blocks": 18,
16
+ "num_groups": 16,
17
+ "num_text_blocks": 4,
18
+ "patch_size": 16,
19
+ "txt_embed_dim": 1024,
20
+ "txt_max_length": 100
21
+ }
DeCo-XL-16-512/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0468009ed0ab700db3cbf906ae88f6ae19ac6548ddef7f8f2a8f0195c2fe33f
3
+ size 2691309848
DeCo-XL-16-512/transformer/diffusion_pytorch_model.safetensors.bak ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d5ae272eea0747e306bcef99cc32b014cc1180f7fc1462cdb3e6e27ee0ffd3e
3
+ size 2691309848
DeCo-XL-16-512/transformer/transformer_deco.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.nn.functional import scaled_dot_product_attention
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
+ from diffusers.models.normalization import RMSNorm
18
+
19
+
20
+ def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
21
+ return x * (1 + scale) + shift
22
+
23
+
24
class PatchEmbed(nn.Module):
    """Linear projection of already-flattened patches from `in_chans` to `embed_dim` features."""

    def __init__(self, in_chans: int, embed_dim: int, bias: bool = True):
        super().__init__()
        # Attribute name `proj` is part of the checkpoint key layout; do not rename.
        self.proj = nn.Linear(in_features=in_chans, out_features=embed_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of `x` with the linear layer."""
        projected = self.proj(x)
        return projected
31
+
32
+
33
class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep embedding with checkpoint-compatible `mlp` module names."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10) -> torch.Tensor:
        """Return sinusoidal features for `t` with a new trailing axis of size `dim`.

        The last axis holds `dim // 2` cosines followed by `dim // 2` sines, plus one zero column when
        `dim` is odd. `t` may have any number of dimensions. The result keeps `t`'s dtype when it is a
        floating-point dtype and is float32 otherwise, so integer timesteps are not truncated.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Slice on the last axis with `...` so the zero padding matches `t` of any rank.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
        out_dtype = t.dtype if t.is_floating_point() else torch.float32
        return embedding.to(out_dtype)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed `t` sinusoidally, then project it through the two-layer MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
60
+
61
+
62
class DeCoSwiGLU(nn.Module):
    """SwiGLU MLP with w1/w2/w3 layout matching official DeCo checkpoints."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # The inner width is two thirds of the requested hidden width, truncated to an int.
        inner_dim = int(2 * hidden_dim / 3)
        # Creation order (w1, w3, w2) is kept so seeded initialization is unchanged.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute `w2(silu(w1(x)) * w3(x))`."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
74
+
75
+
76
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale: float = 16.0) -> torch.Tensor:
    """Build complex 2-D rotary factors for a `height` x `width` grid.

    Grid coordinates are spaced linearly over `[0, scale]` on each axis. The result has shape
    `(height * width, 2 * (dim // 4))`, with x and y factors interleaved per frequency.
    """
    cols = torch.linspace(0, scale, width)
    rows = torch.linspace(0, scale, height)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    flat_y = grid_y.reshape(-1)
    flat_x = grid_x.reshape(-1)

    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    angle_x = torch.outer(flat_x, inv_freq).float()
    angle_y = torch.outer(flat_y, inv_freq).float()
    cis_x = torch.polar(torch.ones_like(angle_x), angle_x)
    cis_y = torch.polar(torch.ones_like(angle_y), angle_y)

    # Pair x/y factors on a new last axis, then flatten so they interleave per frequency.
    paired = torch.cat([cis_x.unsqueeze(dim=-1), cis_y.unsqueeze(dim=-1)], dim=-1)
    return paired.reshape(height * width, -1)
89
+
90
+
91
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex factors in `freqs_cis`.

    Adjacent pairs on the last axis are treated as (real, imaginary) parts, multiplied by `freqs_cis`
    (broadcast over the first and third axes), and flattened back. Outputs keep the input dtypes.
    """
    rotation = freqs_cis[None, :, None, :]

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        # Complex multiplication is done in float32, then cast back to the input dtype.
        as_complex = torch.view_as_complex(tensor.float().reshape(*tensor.shape[:-1], -1, 2))
        rotated = torch.view_as_real(as_complex * rotation).flatten(3)
        return rotated.type_as(tensor)

    return _rotate(xq), _rotate(xk)
102
+
103
+
104
class LabelEmbedder(nn.Module):
    """Lookup table mapping integer class labels to `hidden_size`-dimensional vectors."""

    def __init__(self, num_classes: int, hidden_size: int):
        super().__init__()
        # Attribute name `embedding_table` is part of the checkpoint key layout; do not rename.
        self.embedding_table = nn.Embedding(num_embeddings=num_classes, embedding_dim=hidden_size)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return the embedding row for every entry of `labels`."""
        embedded = self.embedding_table(labels)
        return embedded
111
+
112
+
113
class RAttention(nn.Module):
    """Multi-head self-attention with optional RMS-normalized queries/keys and rotary position factors."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # A single fused projection produces queries, keys and values.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = RMSNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over the token axis of `x` (batch, tokens, channels) using rotary factors `pos`."""
        batch_size, num_tokens, channels = x.shape

        # Split the fused projection into per-head q/k/v, each (batch, tokens, heads, head_dim).
        packed = self.qkv(x).reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
        query, key, value = packed.unbind(dim=2)

        query, key = apply_rotary_emb(self.q_norm(query), self.k_norm(key), freqs_cis=pos)

        # scaled_dot_product_attention expects (batch, heads, tokens, head_dim).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        attended = scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=0.0)

        merged = attended.transpose(1, 2).reshape(batch_size, num_tokens, channels)
        return self.proj_drop(self.proj(merged))
145
+
146
+
147
class FlattenDiTBlock(nn.Module):
    """Transformer block whose attention and MLP branches are shifted, scaled and gated by a conditioning vector."""

    def __init__(self, hidden_size: int, groups: int, mlp_ratio: float = 4.0):
        super().__init__()
        # Creation order is kept so seeded initialization is unchanged.
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = DeCoSwiGLU(hidden_size, int(hidden_size * mlp_ratio))
        # One linear layer yields six modulation tensors: shift/scale/gate for each branch.
        self.adaLN_modulation = nn.Sequential(nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor, pos: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the gated attention branch, then the gated MLP branch, each as a residual update of `x`."""
        modulation = self.adaLN_modulation(c)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)

        attn_input = _modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa * self.attn(attn_input, pos, mask=mask)

        mlp_input = _modulate(self.norm2(x), shift_mlp, scale_mlp)
        return x + gate_mlp * self.mlp(mlp_input)
160
+
161
+
162
@dataclass
class DeCoTransformer2DModelOutput(BaseOutput):
    """Output container returned by `DeCoTransformer2DModel.forward` when `return_dict` is True."""

    # Tensor produced by the backbone for the given sample, timestep and class labels.
    sample: torch.Tensor
165
+
166
+
167
class _DeCoTransformerBackbone(nn.Module):
    """Class-conditioned DeCo conditioning trunk. Checkpoint weights live under the `backbone.` prefix."""

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        num_groups: int,
        hidden_size: int,
        num_cond_blocks: int,
        num_classes: int,
        learn_sigma: bool,
        deep_supervision: int,
    ):
        super().__init__()
        # `learn_sigma` and `deep_supervision` are stored but not read anywhere in this class.
        self.learn_sigma = learn_sigma
        self.deep_supervision = deep_supervision
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_cond_blocks = num_cond_blocks

        self.s_embedder = PatchEmbed(in_channels * patch_size**2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        # One extra row beyond the real class ids (presumably the unconditional label for guidance).
        self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
        self.blocks = nn.ModuleList([FlattenDiTBlock(hidden_size, num_groups) for _ in range(num_cond_blocks)])
        # Cache of rotary factors keyed by patch-grid size (grid_height, grid_width).
        self.precompute_pos: dict[tuple[int, int], torch.Tensor] = {}
        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize the patch, label and timestep embedders."""
        weight = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(weight.view([weight.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def fetch_pos(self, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return rotary factors for a `height` x `width` patch grid on `device`, computing them once per grid size."""
        key = (height, width)
        if key not in self.precompute_pos:
            self.precompute_pos[key] = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width)
        return self.precompute_pos[key].to(device)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        decoder: nn.Module,
        s: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute per-patch conditioning with the transformer blocks and decode every patch.

        Args:
            x: 4-D input (batch, channels, height, width); height and width must be multiples of `patch_size`.
            t: Timesteps, flattened before embedding.
            y: Integer class labels, one per batch element.
            decoder: Module called as `decoder(patch_pixels, conditioning)`; its `.sample` is folded back
                into an image of the input's height and width.
            s: Optional precomputed token states; when given, the transformer blocks are skipped.
            mask: Optional attention mask forwarded to every block.

        Raises:
            ValueError: If the spatial size of `x` is not a multiple of `patch_size`.
        """
        batch_size, _, height, width = x.shape
        # F.unfold silently drops a partial last patch and F.fold zero-fills it, so reject such sizes.
        if height % self.patch_size or width % self.patch_size:
            raise ValueError(
                f"Input height and width must be multiples of patch_size={self.patch_size}; "
                f"got {height}x{width}."
            )
        pos = self.fetch_pos(height // self.patch_size, width // self.patch_size, x.device)
        # Non-overlapping patches: (batch, num_patches, channels * patch_size**2).
        x = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(batch_size, -1, self.hidden_size)
        y = self.y_embedder(y).view(batch_size, 1, self.hidden_size)
        c = F.silu(t + y)
        if s is None:
            s = self.s_embedder(x)
            for block in self.blocks:
                s = block(s, c, pos, mask)
            s = F.silu(t + s)
        batch_size, length, _ = s.shape
        # Every patch is decoded independently: (batch * num_patches, patch_size**2, channels).
        patch_pixels = x.reshape(batch_size * length, self.in_channels, self.patch_size**2).transpose(1, 2)
        conditioning = s.view(batch_size * length, self.hidden_size)
        decoded = decoder(patch_pixels, conditioning).sample
        x = decoded.transpose(1, 2).reshape(batch_size, length, -1)
        return F.fold(
            x.transpose(1, 2).contiguous(),
            (height, width),
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
242
+
243
+
244
class DeCoTransformer2DModel(ModelMixin, ConfigMixin):
    """Class-conditioned DeCo transformer (c2i) for Diffusers pipelines."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        num_groups: int = 12,
        hidden_size: int = 1152,
        hidden_size_x: int = 64,
        num_blocks: int = 18,
        num_cond_blocks: int = 4,
        num_classes: int = 1000,
        learn_sigma: bool = True,
        deep_supervision: int = 0,
        sample_size: int = 256,
        # Deprecated config keys kept for backward-compatible hub configs.
        conditioning_type: str = "class",
        nerf_mlpratio: int = 4,
        decoder_hidden_size: int = 64,
        num_encoder_blocks: int = 18,
        num_decoder_blocks: int = 4,
        num_text_blocks: int = 4,
        txt_embed_dim: int = 1024,
        txt_max_length: int = 100,
    ):
        super().__init__()
        # These keys are accepted only so older hub configs still load; they are not used here.
        del hidden_size_x, nerf_mlpratio, decoder_hidden_size
        del num_encoder_blocks, num_decoder_blocks
        del num_text_blocks, txt_embed_dim, txt_max_length
        if conditioning_type != "class":
            raise ValueError("DeCoTransformer2DModel only supports class conditioning (c2i).")

        backbone_kwargs = dict(
            in_channels=in_channels,
            patch_size=patch_size,
            num_groups=num_groups,
            hidden_size=hidden_size,
            num_cond_blocks=num_cond_blocks,
            num_classes=num_classes,
            learn_sigma=learn_sigma,
            deep_supervision=deep_supervision,
        )
        self.backbone = _DeCoTransformerBackbone(**backbone_kwargs)

    @property
    def in_channels(self) -> int:
        """Number of input channels, as recorded in the model config."""
        return int(self.config.in_channels)

    def _prepare_timestep(
        self, timestep: Union[torch.Tensor, float, int], batch_size: int, sample: torch.Tensor
    ) -> torch.Tensor:
        """Return `timestep` as a tensor on `sample`'s device/dtype, repeated to `batch_size` when it holds one entry."""
        if isinstance(timestep, torch.Tensor):
            prepared = timestep.to(device=sample.device, dtype=sample.dtype)
        else:
            prepared = torch.tensor([timestep], device=sample.device, dtype=sample.dtype)
        if prepared.ndim == 0:
            prepared = prepared[None]
        needs_broadcast = prepared.shape[0] == 1 and batch_size > 1
        return prepared.repeat(batch_size) if needs_broadcast else prepared

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        decoder: Optional[nn.Module] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DeCoTransformer2DModelOutput, tuple[torch.Tensor]]:
        """Run the class-conditioned backbone on `sample` at `timestep`.

        Args:
            sample: Input tensor; its first dimension is the batch size.
            timestep: Scalar or tensor timestep, broadcast to the batch when it has a single entry.
            class_labels: Class ids; required.
            decoder: Patch decoder module passed through to the backbone; required.
            encoder_hidden_states: Must be None; text conditioning is not supported.
            return_dict: Return a `DeCoTransformer2DModelOutput` when True, else a one-element tuple.

        Raises:
            ValueError: If `encoder_hidden_states` is given, or `class_labels` or `decoder` is missing.
        """
        if encoder_hidden_states is not None:
            raise ValueError("encoder_hidden_states is not supported; use class_labels for c2i DeCo models.")
        if class_labels is None:
            raise ValueError("class_labels must be provided for class-conditioned DeCo models.")
        if decoder is None:
            raise ValueError("decoder must be provided; load DeCoPatchDecoderModel as a separate pipeline component.")

        t = self._prepare_timestep(timestep=timestep, batch_size=sample.shape[0], sample=sample)
        labels = class_labels.to(device=sample.device, dtype=torch.long)
        output = self.backbone(sample, t, labels, decoder=decoder)

        if return_dict:
            return DeCoTransformer2DModelOutput(sample=output)
        return (output,)
DeCo-XXL-16-512-t2i/decoder/__pycache__/decoder_deco.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
DeCo-XXL-16-512-t2i/decoder/config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "in_channels": 3,
3
+ "hidden_size_x": 32,
4
+ "z_channels": 1536,
5
+ "max_freqs": 8,
6
+ "num_res_blocks": 3,
7
+ "patch_size": 16
8
+ }
DeCo-XXL-16-512-t2i/decoder/decoder_deco.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.checkpoint import checkpoint
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
+
18
+
19
+ def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
20
+ return x * (1 + scale) + shift
21
+
22
+
23
+ def _precompute_freqs_cis_ex2d(
24
+ dim: int,
25
+ height: int,
26
+ width: int,
27
+ theta: float = 10000.0,
28
+ scale: float = 1.0,
29
+ ) -> torch.Tensor:
30
+ """Match Zehong-Ma/DeCo `precompute_freqs_cis_ex2d` used by NerfEmbedder."""
31
+ if isinstance(scale, float):
32
+ scale = (scale, scale)
33
+ x_pos = torch.linspace(0, height * scale[0], width)
34
+ y_pos = torch.linspace(0, width * scale[1], height)
35
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
36
+ y_pos = y_pos.reshape(-1)
37
+ x_pos = x_pos.reshape(-1)
38
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
39
+ x_freqs = torch.outer(x_pos, freqs).float()
40
+ y_freqs = torch.outer(y_pos, freqs).float()
41
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
42
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
43
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
44
+ return freqs_cis.reshape(height * width, -1)
45
+
46
+
47
class NerfEmbedder(nn.Module):
    """Embed patch pixels concatenated with a fixed 2D positional encoding."""

    def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
        super().__init__()
        self.max_freqs = max_freqs
        self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True))
        # Per-instance cache keyed by (patch_size, device, dtype). A class-level
        # ``lru_cache`` on the method would keep every instance (and its cached
        # tensors) alive for the lifetime of the process.
        self._pos_cache: dict = {}

    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the cached positional encoding with a leading batch axis of size 1."""
        key = (patch_size, device, dtype)
        cached = self._pos_cache.get(key)
        if cached is None:
            pos = _precompute_freqs_cis_ex2d(self.max_freqs**2 * 2, patch_size, patch_size)
            # Official code casts complex cis to real when concatenating with patch pixels.
            cached = pos[None, :, :].to(device=device, dtype=dtype)
            self._pos_cache[key] = cached
        return cached

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Embed ``inputs`` of shape ``(batch, patch_tokens, channels)``.

        The patch side length is taken as the integer square root of
        ``patch_tokens``.
        """
        batch_size, patch_tokens, _ = inputs.shape
        patch_size = int(patch_tokens**0.5)
        # ``expand`` creates a broadcast view; ``torch.cat`` materialises the result,
        # so no intermediate per-batch copy of the encoding is needed.
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).expand(batch_size, -1, -1)
        return self.embedder(torch.cat([inputs, dct], dim=-1))
64
+
65
+
66
class ResBlock(nn.Module):
    """Gated residual MLP block whose shift, scale and gate come from a conditioning tensor."""

    def __init__(self, channels: int):
        super().__init__()
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        hidden_layers = (
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.mlp = nn.Sequential(*hidden_layers)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the gated MLP of the modulated, normalised ``x``."""
        modulation = self.adaLN_modulation(y)
        shift_mlp, scale_mlp, gate_mlp = modulation.chunk(3, dim=-1)
        # Affine modulation: normalised input scaled by (1 + scale), then shifted.
        modulated = self.in_ln(x) * (1 + scale_mlp) + shift_mlp
        residual = gate_mlp * self.mlp(modulated)
        return x + residual
80
+
81
+
82
class DecoderFinalLayer(nn.Module):
    """Output head: parameter-free LayerNorm followed by a linear projection."""

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # No learnable affine terms on the final norm.
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and project to ``out_channels``."""
        normalised = self.norm_final(x)
        return self.linear(normalised)
90
+
91
+
92
class SimpleMLPAdaLN(nn.Module):
    """Stack of conditioned residual MLP blocks with a zero-initialised output head.

    Args:
        in_channels: Feature size of the input tokens.
        model_channels: Hidden width used by every residual block.
        out_channels: Feature size produced by the final layer.
        z_channels: Feature size of the conditioning vector.
        num_res_blocks: Number of ``ResBlock`` layers.
        patch_size: Patch side length; the conditioning is expanded to
            ``patch_size ** 2`` per-token vectors.
        grad_checkpointing: Recompute block activations during backward to save memory.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        patch_size: int,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.grad_checkpointing = grad_checkpointing
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = DecoderFinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero the modulation and output projections so the net starts by emitting zeros."""
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Project ``x``, apply every residual block conditioned on ``c``, and decode."""
        x = self.input_proj(x)
        # One conditioning vector per pixel position of the patch.
        y = self.cond_embed(c).reshape(c.shape[0], self.patch_size**2, -1)
        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        for block in self.res_blocks:
            if use_checkpoint:
                # ``use_reentrant`` must be passed explicitly on recent PyTorch;
                # the non-reentrant variant is the recommended one.
                x = checkpoint(block, x, y, use_reentrant=False)
            else:
                x = block(x, y)
        return self.final_layer(x)
128
+
129
+
130
@dataclass
class DeCoPatchDecoderOutput(BaseOutput):
    """Output container returned by ``DeCoPatchDecoderModel.forward``."""

    # Decoded tensor produced by the patch decoder.
    sample: torch.Tensor
133
+
134
+
135
class DeCoPatchDecoderModel(ModelMixin, ConfigMixin):
    """Per-patch RGB decoder for DeCo (NerfEmbedder + AdaLN MLP)."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        hidden_size_x: int = 32,
        z_channels: int = 1152,
        num_res_blocks: int = 3,
        patch_size: int = 16,
        max_freqs: int = 8,
    ):
        super().__init__()
        # Pixel embedder: patch pixels + positional encoding -> hidden_size_x features.
        self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=max_freqs)
        # Conditioned MLP mapping the embedded pixels back to ``in_channels``.
        mlp_kwargs = dict(
            in_channels=hidden_size_x,
            model_channels=hidden_size_x,
            out_channels=in_channels,
            z_channels=z_channels,
            num_res_blocks=num_res_blocks,
            patch_size=patch_size,
        )
        self.dec_net = SimpleMLPAdaLN(**mlp_kwargs)

    def forward(
        self,
        patch_pixels: torch.Tensor,
        conditioning: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DeCoPatchDecoderOutput, tuple[torch.Tensor]]:
        """
        Decode patch pixels under per-patch conditioning.

        Args:
            patch_pixels (`torch.Tensor`):
                Flattened patch pixels of shape `(batch * num_patches, patch_size ** 2, in_channels)`.
            conditioning (`torch.Tensor`):
                Per-patch conditioning of shape `(batch * num_patches, z_channels)`.
            return_dict (`bool`):
                When `False`, return a one-element tuple instead of `DeCoPatchDecoderOutput`.
        """
        embedded = self.x_embedder(patch_pixels)
        decoded = self.dec_net(embedded, conditioning)
        if return_dict:
            return DeCoPatchDecoderOutput(sample=decoded)
        return (decoded,)
DeCo-XXL-16-512-t2i/decoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2468211dbc3dd72c7ebd9d7d86913442c9a1dc93fec03cfa135a658d84d5fd5e
3
+ size 50445148
DeCo-XXL-16-512-t2i/model_index.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "DeCoT2IPipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "transformer": [
8
+ "transformer_deco_t2i",
9
+ "DeCoT2ITransformer2DModel"
10
+ ],
11
+ "decoder": [
12
+ "decoder_deco",
13
+ "DeCoPatchDecoderModel"
14
+ ],
15
+ "scheduler": [
16
+ "scheduling_deco_flow_match_adam_discrete",
17
+ "DeCoFlowMatchAdamDiscreteScheduler"
18
+ ],
19
+ "text_encoder": [
20
+ "transformers",
21
+ "Qwen3Model"
22
+ ],
23
+ "tokenizer": [
24
+ "transformers",
25
+ "Qwen2Tokenizer"
26
+ ]
27
+ }
DeCo-XXL-16-512-t2i/pipeline.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: DeCoT2IPipeline (text-to-image, 512×512).
2
+
3
+ Sampling matches official DeCo AdamLMSampler:
4
+ https://github.com/Zehong-Ma/DeCo/blob/main/src/diffusion/flow_matching/adam_sampling.py
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from pathlib import Path
10
+ from typing import List, Optional, Tuple, Union
11
+
12
+ import torch
13
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+
16
+ DEFAULT_TEXT_ENCODER_REPO = "Qwen/Qwen3-1.7B"
17
+
18
+
19
class DeCoT2IPipeline(DiffusionPipeline):
    """Text-to-image DeCo pipeline sampling in pixel space with a multi-step solver.

    Registered components: ``transformer``, ``scheduler``, ``decoder`` and the
    optional ``text_encoder`` / ``tokenizer`` (loaded lazily by
    ``from_pretrained`` when they are not part of the checkpoint).
    """

    model_cpu_offload_seq = "text_encoder->transformer->decoder"
    _optional_components = ["text_encoder", "tokenizer"]

    def __init__(
        self,
        transformer,
        scheduler,
        decoder,
        text_encoder=None,
        tokenizer=None,
    ):
        super().__init__()
        self.register_modules(
            transformer=transformer,
            scheduler=scheduler,
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the pipeline, then attach a text encoder/tokenizer if none was loaded."""
        pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        if pipe.text_encoder is None or pipe.tokenizer is None:
            model_dir = Path(getattr(pipe.config, "_name_or_path", pretrained_model_name_or_path)).resolve()
            pipe._load_text_encoder(model_dir, **kwargs)
        return pipe

    @staticmethod
    def _resolve_text_encoder_path(model_dir: Path) -> Path:
        """Pick the text-encoder location.

        Order of preference: an existing path named in the hint file, a local
        ``text_encoder`` directory, then ``DEFAULT_TEXT_ENCODER_REPO``.
        """
        hint = model_dir / "text_encoder_pretrained_model_name_or_path.txt"
        if hint.exists():
            lines = hint.read_text(encoding="utf-8").strip().splitlines()
            # An empty hint file used to raise IndexError; treat it as "no hint".
            raw = lines[0].strip() if lines else ""
            if raw:
                path = Path(raw)
                if not path.is_absolute():
                    path = (model_dir / path).resolve()
                if path.exists():
                    return path
        local = model_dir / "text_encoder"
        if local.exists():
            return local.resolve()
        return Path(DEFAULT_TEXT_ENCODER_REPO)

    def _load_text_encoder(self, model_dir: Path, **kwargs) -> None:
        """Load the Qwen text encoder and tokenizer and register them on the pipeline."""
        from transformers import Qwen2Tokenizer, Qwen3Model

        text_path = self._resolve_text_encoder_path(model_dir)
        load_kwargs = {
            k: kwargs[k]
            for k in ("torch_dtype", "device_map", "local_files_only", "revision", "cache_dir")
            if k in kwargs
        }
        text_encoder = Qwen3Model.from_pretrained(str(text_path), **load_kwargs)
        tokenizer = Qwen2Tokenizer.from_pretrained(
            str(text_path),
            max_length=self.txt_max_length,
            padding_side="right",
            **{k: v for k, v in load_kwargs.items() if k in ("local_files_only", "revision", "cache_dir")},
        )
        self.register_modules(text_encoder=text_encoder, tokenizer=tokenizer)

    @property
    def txt_embed_dim(self) -> int:
        """Text embedding width expected by the transformer (config value, default 2048)."""
        return int(getattr(self.transformer.config, "txt_embed_dim", 2048))

    @property
    def txt_max_length(self) -> int:
        """Maximum tokenised prompt length (config value, default 128)."""
        return int(getattr(self.transformer.config, "txt_max_length", 128))

    @staticmethod
    def _effective_guidance_scale(
        timestep: Union[torch.Tensor, float],
        guidance_scale: float,
        do_cfg: bool,
        guidance_interval_min: float,
        guidance_interval_max: float,
    ) -> float:
        """Match official AdamLMSampler: CFG when t > min and t < max."""
        if not do_cfg:
            return 1.0
        t = float(timestep)
        if t > guidance_interval_min and t < guidance_interval_max:
            return float(guidance_scale)
        return 1.0

    @staticmethod
    def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
        """Map values from ``[-1, 1]`` to rounded, clipped ``uint8`` in ``[0, 255]``."""
        return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode prompts and negative prompts into text embeddings.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds)``; the last dimension is
            zero-padded or truncated to ``txt_embed_dim``.

        Raises:
            ValueError: If the text encoder/tokenizer is missing, or a list of
                negative prompts does not match the number of prompts.
        """
        if self.text_encoder is None or self.tokenizer is None:
            raise ValueError("text_encoder and tokenizer must be loaded for t2i inference.")

        device = device or self._execution_device
        dtype = dtype or next(self.text_encoder.parameters()).dtype

        if isinstance(prompt, str):
            prompt = [prompt]
        batch_size = len(prompt)

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        elif len(negative_prompt) != batch_size:
            # A mismatch would otherwise mis-split the concatenated CFG batch later on.
            raise ValueError(
                f"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {batch_size}."
            )

        def _encode(texts: List[str]) -> torch.Tensor:
            tokenized = self.tokenizer(
                texts,
                truncation=True,
                max_length=self.txt_max_length,
                padding="max_length",
                return_tensors="pt",
            )
            input_ids = tokenized.input_ids.to(device)
            attention_mask = tokenized.attention_mask.to(device)
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            hidden = outputs[0]
            embed_dim = self.txt_embed_dim
            if hidden.shape[-1] < embed_dim:
                pad = torch.zeros(
                    hidden.shape[0],
                    hidden.shape[1],
                    embed_dim - hidden.shape[-1],
                    device=hidden.device,
                    dtype=hidden.dtype,
                )
                hidden = torch.cat([hidden, pad], dim=-1)
            elif hidden.shape[-1] > embed_dim:
                hidden = hidden[:, :, :embed_dim]
            return hidden.to(dtype=dtype)

        return _encode(prompt), _encode(negative_prompt)

    def _default_sample_size(self) -> int:
        """Default output resolution (config ``sample_size``, falling back to 512)."""
        return int(getattr(self.transformer.config, "sample_size", 512))

    @torch.no_grad()
    @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available())
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 25,
        guidance_scale: float = 4.0,
        timeshift: Optional[float] = None,
        order: Optional[int] = None,
        guidance_interval_min: Optional[float] = None,
        guidance_interval_max: Optional[float] = None,
        generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate images from text prompts (or precomputed embeddings).

        Args:
            prompt / negative_prompt: Text inputs; ignored for encoding when
                ``prompt_embeds`` is given.
            prompt_embeds / negative_prompt_embeds: Precomputed embeddings; a missing
                negative embedding defaults to zeros.
            height / width: Output size, rounded down to a multiple of 16.
            num_inference_steps, timeshift, order: Forwarded to the scheduler.
            guidance_scale: Classifier-free guidance scale; values <= 1 disable it.
            guidance_interval_min / guidance_interval_max: Open timestep interval in
                which guidance is applied; default to the scheduler config.
            generator: Optional generator, or list of generators, for the initial noise.
            output_type: ``"pil"``, ``"np"`` or ``"latent"``.
            return_dict: When false, return a one-element tuple.

        Raises:
            ValueError: On an unknown ``output_type`` or when neither ``prompt`` nor
                ``prompt_embeds`` is provided.
        """
        # Fail fast: validating only after sampling would waste the whole run.
        if output_type not in ("pil", "np", "latent"):
            raise ValueError("output_type must be one of {'pil', 'np', 'latent'}")

        device = self._execution_device
        dtype = next(self.transformer.parameters()).dtype
        do_cfg = guidance_scale is not None and float(guidance_scale) > 1.0

        if prompt_embeds is not None:
            batch_size = int(prompt_embeds.shape[0])
        elif prompt is None:
            raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")
        elif isinstance(prompt, str):
            batch_size = 1
        else:
            batch_size = len(prompt)

        sample_size = self._default_sample_size()
        height = int(height if height is not None else sample_size)
        width = int(width if width is not None else sample_size)
        height = height // 16 * 16
        width = width // 16 * 16

        interval_min = (
            float(guidance_interval_min)
            if guidance_interval_min is not None
            else float(getattr(self.scheduler.config, "guidance_interval_min", 0.0))
        )
        interval_max = (
            float(guidance_interval_max)
            if guidance_interval_max is not None
            else float(getattr(self.scheduler.config, "guidance_interval_max", 1.0))
        )

        if prompt_embeds is None:
            prompt_embeds, negative_embeds = self.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                device=device,
                dtype=dtype,
            )
        else:
            negative_embeds = negative_prompt_embeds
            if negative_embeds is None:
                negative_embeds = torch.zeros_like(prompt_embeds)

        # Official DeCo t2i: float32 noise on CPU, then move to device (see app.py / GenEval).
        noise_shape = (batch_size, int(self.transformer.config.in_channels), height, width)
        # Pass a real ``torch.device`` (not the string "cpu") so ``randn_tensor`` works
        # on diffusers versions that read ``device.type`` directly.
        cpu_device = torch.device("cpu")
        # For a list of generators, the first one decides where noise is drawn.
        probe = generator[0] if isinstance(generator, list) and generator else generator
        gen_device = getattr(probe, "device", None)
        if gen_device is not None and str(gen_device).startswith("cuda"):
            latents = randn_tensor(noise_shape, generator=generator, device=device, dtype=torch.float32)
        else:
            latents = randn_tensor(
                noise_shape, generator=generator, device=cpu_device, dtype=torch.float32
            ).to(device)

        set_kwargs = {
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "device": device,
        }
        if timeshift is not None:
            set_kwargs["timeshift"] = timeshift
        if order is not None:
            set_kwargs["order"] = order
        self.scheduler.set_timesteps(**set_kwargs)

        cfg_condition = torch.cat([negative_embeds, prompt_embeds], dim=0)
        pred_trajectory: list[torch.Tensor] = []
        t_cur = torch.zeros(batch_size, device=device, dtype=torch.float32)
        timedeltas = self.scheduler._timedeltas
        solver_coeffs = self.scheduler._solver_coeffs
        # Only the most recent predictions are ever combined, so older ones can be dropped.
        max_history = max((len(coeffs) for coeffs in solver_coeffs), default=1)

        for i in self.progress_bar(range(len(timedeltas))):
            cfg_x = torch.cat([latents, latents], dim=0)
            cfg_t = t_cur.repeat(2)
            out = self.transformer(cfg_x, cfg_t, encoder_hidden_states=cfg_condition, decoder=self.decoder).sample

            # Guidance is applied only strictly inside the (min, max) timestep interval.
            if do_cfg and t_cur[0] > interval_min and t_cur[0] < interval_max:
                cfg_scale = float(guidance_scale)
            else:
                cfg_scale = 1.0
            uncond, cond = out.chunk(2, dim=0)
            out = uncond + cfg_scale * (cond - uncond)

            pred_trajectory.append(out)
            step_coeffs = solver_coeffs[i]
            recent = pred_trajectory[-len(step_coeffs):]
            combined = torch.zeros_like(out)
            for coeff, prediction in zip(step_coeffs, recent):
                combined = combined + coeff * prediction
            latents = latents + combined * timedeltas[i]
            t_cur = t_cur + timedeltas[i]
            del pred_trajectory[:-max_history]

        # Release offload hooks (no-op when model offloading is not enabled).
        self.maybe_free_model_hooks()

        if output_type == "latent":
            if not return_dict:
                return (latents,)
            return ImagePipelineOutput(images=latents)

        images_uint8 = self._fp_to_uint8(latents.float()).permute(0, 2, 3, 1).cpu().numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(images_uint8)
        else:  # "np" — the value was validated on entry
            image = images_uint8

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
DeCo-XXL-16-512-t2i/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DeCoFlowMatchAdamDiscreteScheduler",
3
+ "_diffusers_version": "0.31.0",
4
+ "num_train_timesteps": 1000,
5
+ "num_inference_steps": 25,
6
+ "guidance_scale": 4.0,
7
+ "timeshift": 3.0,
8
+ "order": 2,
9
+ "guidance_interval_min": 0.0,
10
+ "guidance_interval_max": 1.0,
11
+ "last_step": null,
12
+ "prediction_type": "v_prediction"
13
+ }
DeCo-XXL-16-512-t2i/scheduler/scheduling_deco_flow_match_adam_discrete.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow-matching AdamLM scheduler matching Zehong-Ma/DeCo AdamLMSampler."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+
12
+
13
@dataclass
class DeCoFlowMatchAdamSchedulerOutput:
    """Result of one ``DeCoFlowMatchAdamDiscreteScheduler.step`` call."""

    # Sample after applying the solver update for the current step.
    prev_sample: torch.Tensor
16
+
17
+
18
class DeCoFlowMatchAdamDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """AdamLM multi-step flow-matching sampler (order=2 by default for t2i)."""

    config_name = "scheduler_config.json"
    order = 1
    init_noise_sigma = 1.0

    @staticmethod
    def _lagrange_coeffs(order: int, pre_ts: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> List[float]:
        """Return normalised weights for the last ``order`` predictions.

        Each weight is the integral over ``[t_start, t_end]`` of the Lagrange basis
        polynomial through the last ``order`` entries of ``pre_ts``, divided by the
        sum of all such integrals.

        Raises:
            ValueError: If ``order`` is not 1, 2 or 3.
        """
        ts = [float(v) for v in pre_ts[-order:].tolist()]
        a = float(t_start)
        b = float(t_end)

        if order == 1:
            return [1.0]
        if order == 2:
            t1, t2 = ts
            int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
            int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
            total = int1 + int2
            return [int1 / total, int2 / total]
        if order == 3:
            t1, t2, t3 = ts
            int1_denom = (t1 - t2) * (t1 - t3)
            int1 = ((1 / 3) * b**3 - 0.5 * (t2 + t3) * b**2 + (t2 * t3) * b) - (
                (1 / 3) * a**3 - 0.5 * (t2 + t3) * a**2 + (t2 * t3) * a
            )
            int1 = int1 / int1_denom
            int2_denom = (t2 - t1) * (t2 - t3)
            int2 = ((1 / 3) * b**3 - 0.5 * (t1 + t3) * b**2 + (t1 * t3) * b) - (
                (1 / 3) * a**3 - 0.5 * (t1 + t3) * a**2 + (t1 * t3) * a
            )
            int2 = int2 / int2_denom
            int3_denom = (t3 - t1) * (t3 - t2)
            int3 = ((1 / 3) * b**3 - 0.5 * (t1 + t2) * b**2 + (t1 * t2) * b) - (
                (1 / 3) * a**3 - 0.5 * (t1 + t2) * a**2 + (t1 * t2) * a
            )
            int3 = int3 / int3_denom
            total = int1 + int2 + int3
            return [int1 / total, int2 / total, int3 / total]
        raise ValueError(f"Unsupported solver order: {order}.")

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        num_inference_steps: int = 25,
        guidance_scale: float = 4.0,
        timeshift: float = 3.0,
        order: int = 2,
        guidance_interval_min: float = 0.0,
        guidance_interval_max: float = 1.0,
        last_step: Optional[float] = None,
        prediction_type: str = "v_prediction",
    ) -> None:
        self.num_inference_steps = int(num_inference_steps)
        self.guidance_scale = float(guidance_scale)
        self.timeshift = float(timeshift)
        self.order = int(order)
        self.guidance_interval_min = float(guidance_interval_min)
        self.guidance_interval_max = float(guidance_interval_max)
        self.last_step = last_step
        self._reset_state()

    def _reset_state(self) -> None:
        """Clear every per-run field; ``set_timesteps`` must be called before ``step``."""
        self.timesteps: Optional[torch.Tensor] = None
        self._timedeltas: Optional[torch.Tensor] = None
        self._solver_coeffs: Optional[List[List[float]]] = None
        self._model_outputs: List[torch.Tensor] = []
        # Number of past model outputs that ``step`` needs to keep around.
        self._max_history = 1
        self._step_index = 0

    @staticmethod
    def _shift_respace_fn(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
        """Respace ``t`` as ``t / (t + (1 - t) * shift)``."""
        return t / (t + (1 - t) * shift)

    def _build_solver_state(
        self,
        num_inference_steps: int,
        timeshift: float,
        device: Optional[Union[str, torch.device]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[float]]]:
        """Build step start times, step sizes and per-step solver coefficients.

        Returns:
            ``(timesteps, timedeltas, solver_coeffs)`` where the two tensors have
            ``num_inference_steps`` entries and live on ``device``.
        """
        steps = int(num_inference_steps)
        last_step = self.last_step
        if last_step is None:
            last_step = 1.0 / float(steps)

        endpoints = torch.linspace(0.0, 1.0 - float(last_step), steps, dtype=torch.float32)
        endpoints = torch.cat([endpoints, torch.tensor([1.0], dtype=torch.float32)], dim=0)
        # Everything below stays on the CPU: the coefficients are plain Python floats,
        # so reading them from a device tensor would force a host sync per step.
        timesteps = self._shift_respace_fn(endpoints, timeshift)
        timedeltas = timesteps[1:] - timesteps[:-1]

        solver_coeffs: List[List[float]] = []
        for i in range(steps):
            # Fewer predictions are available during the first steps.
            order = min(self.order, i + 1)
            pre_ts = timesteps[: i + 1]
            solver_coeffs.append(self._lagrange_coeffs(order, pre_ts, pre_ts[i], timesteps[i + 1]))
        return timesteps[:-1].to(device=device), timedeltas.to(device=device), solver_coeffs

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timeshift: Optional[float] = None,
        guidance_scale: Optional[float] = None,
        order: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Configure the sampling schedule and reset the multi-step history.

        ``timeshift`` and ``order`` fall back to the scheduler config when omitted;
        ``num_inference_steps`` and ``guidance_scale`` keep their current values.

        Raises:
            ValueError: If the resulting number of inference steps is not positive.
        """
        if num_inference_steps is not None:
            self.num_inference_steps = int(num_inference_steps)
        # Previously a zero step count surfaced as ZeroDivisionError; fail clearly instead.
        if self.num_inference_steps <= 0:
            raise ValueError("num_inference_steps must be > 0")
        if timeshift is not None:
            self.timeshift = float(timeshift)
        else:
            self.timeshift = float(getattr(self.config, "timeshift", self.timeshift))
        if guidance_scale is not None:
            self.guidance_scale = float(guidance_scale)
        if order is not None:
            self.order = int(order)
        else:
            self.order = int(getattr(self.config, "order", self.order))

        timesteps, timedeltas, solver_coeffs = self._build_solver_state(
            self.num_inference_steps,
            self.timeshift,
            device=device,
        )
        self.timesteps = timesteps
        self._timedeltas = timedeltas
        self._solver_coeffs = solver_coeffs
        self._max_history = max((len(coeffs) for coeffs in solver_coeffs), default=1)
        self._model_outputs = []
        self._step_index = 0

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``sample`` unchanged; no input scaling is applied."""
        return sample

    def classifier_free_guidance(
        self,
        model_output: torch.Tensor,
        guidance_scale: Optional[float] = None,
    ) -> torch.Tensor:
        """Combine a concatenated ``[uncond, cond]`` batch into a guided prediction.

        Raises:
            ValueError: If the batch dimension is odd.
        """
        if model_output.shape[0] % 2 != 0:
            raise ValueError("Classifier-free guidance expects concatenated unconditional/conditional batches.")
        scale = self.guidance_scale if guidance_scale is None else float(guidance_scale)
        uncond, cond = model_output.chunk(2, dim=0)
        return uncond + scale * (cond - uncond)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        sample: torch.Tensor,
        return_dict: bool = True,
        **kwargs: Any,
    ) -> Union[DeCoFlowMatchAdamSchedulerOutput, Tuple[torch.Tensor]]:
        """Advance ``sample`` by one solver step.

        ``timestep`` is ignored; the scheduler tracks its own step index.

        Raises:
            RuntimeError: If ``set_timesteps`` was not called, or more steps are
                requested than were configured.
        """
        del timestep, kwargs
        if self.timesteps is None or self._timedeltas is None or self._solver_coeffs is None:
            raise RuntimeError("`set_timesteps` must be called before `step`.")
        if self._step_index >= len(self._solver_coeffs):
            raise RuntimeError("Scheduler step index exceeded configured timesteps.")

        coeffs = self._solver_coeffs[self._step_index]
        self._model_outputs.append(model_output)
        recent = self._model_outputs[-len(coeffs):]
        pred = torch.zeros_like(model_output)
        for coeff, output in zip(coeffs, recent):
            pred = pred + coeff * output

        prev_sample = sample + pred * self._timedeltas[self._step_index]
        self._step_index += 1
        # Older outputs are never read again; drop them to bound memory use.
        del self._model_outputs[:-self._max_history]

        if not return_dict:
            return (prev_sample,)
        return DeCoFlowMatchAdamSchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Interpolate ``t * original_samples + (1 - t) * noise`` per batch element."""
        # Move to the samples' device so CPU timesteps can be mixed with GPU samples.
        t = timesteps.to(device=original_samples.device)
        alpha = t.view(-1, 1, 1, 1)
        sigma = (1.0 - t).view(-1, 1, 1, 1)
        return alpha * original_samples + sigma * noise
DeCo-XXL-16-512-t2i/scheduler/scheduling_deco_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
9
+
10
+
11
+
12
+ def _shift_respace_fn(t: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
13
+ return t / (t + (1 - t) * shift)
14
+
15
+
16
class DeCoFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """First-order (Euler) flow-matching scheduler over a shifted timestep grid."""

    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        last_step: Optional[float] = None,
        prediction_type: str = "v_prediction",
    ):
        # Empty schedule until ``set_timesteps`` is called.
        self.num_inference_steps: Optional[int] = None
        self._step_index: int = 0
        self.timesteps = torch.tensor([], dtype=torch.float32)

    @property
    def init_noise_sigma(self) -> float:
        """Scale of the initial noise; always 1.0."""
        return 1.0

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Build ``num_inference_steps + 1`` shifted timestep endpoints ending at 1.0.

        Raises:
            ValueError: If ``num_inference_steps`` is not positive.
        """
        if num_inference_steps <= 0:
            raise ValueError("num_inference_steps must be > 0")

        steps = int(num_inference_steps)
        self.num_inference_steps = steps

        # Without an explicit ``last_step`` the final gap is one uniform step.
        final_gap = self.config.last_step
        if final_gap is None:
            final_gap = 1.0 / float(steps)

        grid = torch.linspace(0.0, 1.0 - float(final_gap), steps, dtype=torch.float32)
        grid = torch.cat([grid, torch.tensor([1.0], dtype=torch.float32)], dim=0)
        shifted = _shift_respace_fn(grid, shift=float(self.config.shift))

        self.timesteps = shifted if device is None else shifted.to(device)
        self._step_index = 0

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``sample`` unchanged; no input scaling is applied."""
        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[torch.Tensor, float],
        sample: torch.Tensor,
        return_dict: bool = True,
    ):
        """Take one Euler step: ``sample + model_output * dt``.

        The step size comes from the internal step index (clamped to the last
        interval), not from ``timestep``.

        Raises:
            ValueError: If ``set_timesteps`` has not been called.
        """
        schedule_missing = self.num_inference_steps is None or self.timesteps.numel() == 0
        if schedule_missing:
            raise ValueError("Call set_timesteps before step")

        # Clamp so extra calls reuse the final interval instead of indexing past the end.
        idx = min(self._step_index, len(self.timesteps) - 2)
        t_now = self.timesteps[idx]
        t_next = self.timesteps[idx + 1]
        dt = (t_next - t_now).to(device=sample.device, dtype=sample.dtype)

        prev_sample = sample + model_output * dt
        self._step_index += 1

        if return_dict:
            return SchedulerOutput(prev_sample=prev_sample)
        return (prev_sample,)

    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """Interpolate ``t * original_samples + (1 - t) * noise`` per batch element."""
        if timesteps.ndim == 0:
            timesteps = timesteps[None]
        t = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
        t = t.view(-1, 1, 1, 1)
        return t * original_samples + (1.0 - t) * noise
DeCo-XXL-16-512-t2i/scripts/run_t2i_demo.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Full t2i demo: load local Qwen text encoder and save demo.png."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ from diffusers import DiffusionPipeline
10
+ MODEL_DIR = Path(__file__).resolve().parents[1]
11
+
12
+
13
def main() -> None:
    """Load the local pipeline, generate one image for a fixed prompt, and save ``demo.png``."""
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.bfloat16 if use_cuda else torch.float32

    load_options = dict(
        local_files_only=True,
        custom_pipeline=str(MODEL_DIR / "pipeline.py"),
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    pipe = DiffusionPipeline.from_pretrained(str(MODEL_DIR), **load_options)
    print("text_encoder:", type(pipe.text_encoder).__name__)
    pipe.to(device)

    prompt = "a golden retriever playing in the snow, high quality photograph"
    # Official DeCo uses CPU generator for reproducible noise (app.py / GenEval).
    generator = torch.Generator(device="cpu").manual_seed(42)

    print("generating...", prompt)
    sampling_options = dict(
        prompt=prompt,
        negative_prompt="Unrealistic, JPEG artifacts.",
        num_inference_steps=25,
        guidance_scale=4.0,
        generator=generator,
        output_type="pil",
    )
    image = pipe(**sampling_options).images[0]

    out_path = MODEL_DIR / "demo.png"
    image.save(out_path)
    print("saved", out_path, image.size)


if __name__ == "__main__":
    main()
DeCo-XXL-16-512-t2i/scripts/test_t2i_load.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Smoke test: load converted DeCo-XXL-16-512-t2i and run 2-step denoise with dummy text."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from diffusers import DiffusionPipeline
11
+
12
+ MODEL_DIR = Path(__file__).resolve().parents[1]
13
+
14
+
15
def main() -> None:
    """Smoke-test the converted checkpoint with one transformer forward pass.

    Loads the pipeline from ``MODEL_DIR`` (offline, float32), feeds the
    transformer a random image and a random stand-in for the text embedding,
    and checks that the output has the same shape as the input image.

    Raises:
        AssertionError: If a required component is missing or the output
            shape is wrong. Raised explicitly so the checks survive
            ``python -O`` and the ``FAILED:`` line carries a message.
    """
    pipe = DiffusionPipeline.from_pretrained(
        str(MODEL_DIR),
        local_files_only=True,
        custom_pipeline=str(MODEL_DIR / "pipeline.py"),
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    if pipe.decoder is None or pipe.transformer is None:
        raise AssertionError("pipeline is missing its decoder and/or transformer component")

    config = pipe.transformer.config
    batch_size = 1
    seq_len = int(config.txt_max_length)
    embed_dim = int(config.txt_embed_dim)
    channels = int(config.in_channels)
    size = int(config.sample_size)

    # Seed BEFORE drawing any random tensor so both dummy inputs are reproducible.
    torch.manual_seed(0)
    hidden = torch.randn(batch_size, seq_len, embed_dim)
    sample = torch.randn(batch_size, channels, size, size)

    with torch.inference_mode():
        result = pipe.transformer(
            sample,
            0.5,
            encoder_hidden_states=hidden,
            decoder=pipe.decoder,
            return_dict=True,
        )

    out = result.sample
    expected_shape = (batch_size, channels, size, size)
    if tuple(out.shape) != expected_shape:
        raise AssertionError(f"unexpected output shape {tuple(out.shape)}, expected {expected_shape}")
    print("transformer:", type(pipe.transformer).__name__)
    print("decoder:", type(pipe.decoder).__name__)
    print("output shape:", tuple(out.shape))
    print("ok")


if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        # Surface a one-line summary on stderr, then re-raise for the traceback and exit code.
        print(f"FAILED: {exc}", file=sys.stderr)
        raise
DeCo-XXL-16-512-t2i/transformer/__pycache__/transformer_deco_t2i.cpython-312.pyc ADDED
Binary file (29.7 kB). View file
 
DeCo-XXL-16-512-t2i/transformer/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sample_size": 512,
3
+ "conditioning_type": "text",
4
+ "decoder_hidden_size": 32,
5
+ "deep_supervision": 0,
6
+ "hidden_size": 1536,
7
+ "hidden_size_x": 32,
8
+ "in_channels": 3,
9
+ "learn_sigma": true,
10
+ "nerf_mlpratio": 4,
11
+ "num_blocks": 19,
12
+ "num_classes": 0,
13
+ "num_cond_blocks": 16,
14
+ "num_decoder_blocks": 3,
15
+ "num_encoder_blocks": 16,
16
+ "num_groups": 24,
17
+ "num_text_blocks": 4,
18
+ "patch_size": 16,
19
+ "txt_embed_dim": 2048,
20
+ "txt_max_length": 128
21
+ }
DeCo-XXL-16-512-t2i/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dacd4b933da1f5326035b1f63d21d127aa72f598d97eb07ca57b54c0dab9b08
3
+ size 4484623152
DeCo-XXL-16-512-t2i/transformer/transformer_deco_t2i.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.nn.functional import scaled_dot_product_attention
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
class RMSNorm(nn.Module):
    """Match Zehong-Ma/DeCo `src.models.layers.rmsnorm.RMSNorm` (not diffusers variant).

    Normalizes the last dimension by its root mean square, computed in
    float32, then applies a learned per-channel scale and casts back to the
    input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` RMS-normalized over its last dimension."""
        original_dtype = hidden_states.dtype
        # Statistics are computed in float32 regardless of the input dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = self.weight * normalized
        return scaled.to(original_dtype)
31
+
32
+
33
+ def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
34
+ return x * (1 + scale) + shift
35
+
36
+
37
class PatchEmbed(nn.Module):
    """Project already-flattened patch vectors to the embedding width with one linear layer."""

    def __init__(self, in_chans: int, embed_dim: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(in_features=in_chans, out_features=embed_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the linear projection of ``x`` over its last dimension."""
        embedded = self.proj(x)
        return embedded
44
+
45
+
46
class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep embedding with checkpoint-compatible `mlp` module names.

    A timestep is first expanded into ``frequency_embedding_size`` cos/sin
    features and then passed through a two-layer SiLU MLP that outputs
    ``hidden_size`` channels.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10) -> torch.Tensor:
        """Return cos/sin features of ``t`` with a trailing dimension of size ``dim``.

        Args:
            t: Timesteps of any shape; features are appended as a new last axis.
            dim: Size of the feature axis. When odd, one zero channel is appended.
            max_period: Controls the lowest frequency (``1 / max_period``-scaled
                geometric progression of ``dim // 2`` frequencies).

        Returns:
            Tensor of shape ``(*t.shape, dim)`` in ``t``'s dtype.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad along the LAST axis. Slicing with `[:, :1]` only works for 1-D `t`
            # and breaks the concatenation for batched/multi-dimensional timesteps.
            padding = embedding.new_zeros(*embedding.shape[:-1], 1)
            embedding = torch.cat([embedding, padding], dim=-1)
        return embedding.to(t.dtype)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps ``t`` and project them through the MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
73
+
74
+
75
class DeCoSwiGLU(nn.Module):
    """SwiGLU MLP with w1/w2/w3 layout matching official DeCo checkpoints.

    The requested ``hidden_dim`` is shrunk to ``int(2 * hidden_dim / 3)``
    before the bias-free projections are built.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        inner_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``w2(silu(w1(x)) * w3(x))``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
87
+
88
+
89
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float = 10000.0,
    scale: Union[float, Tuple[float, float]] = 1.0,
) -> torch.Tensor:
    """Official t2i uses `precompute_freqs_cis_ex2d` (aliased as precompute_freqs_cis_2d).

    Builds unit-magnitude complex rotary factors for a ``height x width`` token grid.

    Args:
        dim: Per-head channel count; ``dim // 4`` frequencies are used per axis.
        height: Number of token rows.
        width: Number of token columns.
        theta: Base of the geometric frequency progression.
        scale: Either one scalar applied to both axes or an ``(x, y)`` pair.
            Integers are accepted as scalars as well as floats.

    Returns:
        Complex tensor with ``height * width`` rows, x and y factors interleaved
        along the last axis.
    """
    # Accept any real scalar; previously an int such as `scale=1` fell through
    # to `scale[0]` and raised TypeError.
    if isinstance(scale, (int, float)):
        scale = (scale, scale)
    # NOTE: the height/width pairing below is kept exactly as in the reference
    # implementation this function mirrors; changing it would alter the
    # positions the checkpoint was trained with.
    x_pos = torch.linspace(0, height * scale[0], width)
    y_pos = torch.linspace(0, width * scale[1], height)
    y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
    y_pos = y_pos.reshape(-1)
    x_pos = x_pos.reshape(-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    x_freqs = torch.outer(x_pos, freqs).float()
    y_freqs = torch.outer(y_pos, freqs).float()
    x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
    y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
    freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
    return freqs_cis.reshape(height * width, -1)
111
+
112
+
113
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by complex rotary factors.

    Consecutive channel pairs of ``xq`` / ``xk`` are treated as complex
    numbers, multiplied by ``freqs_cis`` (broadcast over the two leading
    axes), and unpacked back to real channels in the input dtype.

    Returns:
        The rotated ``(xq, xk)`` pair.
    """
    rotation = freqs_cis[None, None, :, :]

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        # Pair up the last axis as (real, imag) in float32 before the complex product.
        paired = tensor.float().reshape(*tensor.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(paired) * rotation
        return torch.view_as_real(rotated).flatten(3).type_as(tensor)

    return _rotate(xq), _rotate(xk)
124
+
125
+
126
+
127
class DeCoT2ISwiGLU(nn.Module):
    """Official DeCo-XXL t2i SwiGLU (w12/w3), distinct from c2i w1/w2/w3 layout.

    ``w12`` produces the gate and value halves in a single bias-free
    projection; ``w3`` maps the gated product back to ``dim``.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``w3(silu(gate) * value)`` where ``gate, value`` are the halves of ``w12(x)``."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        gated = F.silu(gate) * value
        return self.w3(gated)
138
+
139
+
140
+ def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
141
+ return x * (1 + scale) + shift
142
+
143
+
144
class TextEmbedder(nn.Module):
    """Project text-encoder features to the model width and RMS-normalize them."""

    def __init__(self, in_channels: int, embed_dim: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(in_features=in_channels, out_features=embed_dim, bias=bias)
        self.norm = RMSNorm(embed_dim, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``norm(proj(x))``."""
        projected = self.proj(x)
        return self.norm(projected)
152
+
153
+
154
class CrossAttention(nn.Module):
    """Attention in which image tokens attend to image and text tokens jointly.

    Queries are produced from ``x`` only. Keys and values are the
    concatenation (along the token axis) of projections of ``x`` and of ``y``.
    Rotary factors ``pos`` are applied to the ``x`` queries and keys; the ``y``
    keys are normalized but not rotated.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_drop: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv_x = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.kv_y = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_norm = RMSNorm(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, y: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` over ``[x; y]`` and project the result back to ``dim`` channels."""
        batch_size, num_tokens, channels = x.shape
        heads, head_dim = self.num_heads, self.head_dim
        head_shape = (batch_size, heads, -1, head_dim)

        # (B, N, 3 * C) -> (3, B, heads, N, head_dim)
        image_qkv = self.qkv_x(x).reshape(batch_size, num_tokens, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
        query = self.q_norm(image_qkv[0].contiguous())
        image_key = self.k_norm(image_qkv[1].contiguous())
        image_value = image_qkv[2]
        query, image_key = apply_rotary_emb(query, image_key, freqs_cis=pos)

        # (B, M, 2 * C) -> (2, B, heads, M, head_dim); text keys share `k_norm` with image keys.
        text_kv = self.kv_y(y).reshape(batch_size, -1, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        text_key = self.k_norm(text_kv[0].contiguous())
        text_value = text_kv[1]

        query = query.view(head_shape)
        key = torch.cat([image_key, text_key], dim=2).view(head_shape).contiguous()
        value = torch.cat([image_value, text_value], dim=2).view(head_shape).contiguous()

        attended = scaled_dot_product_attention(query, key, value, dropout_p=0.0)
        merged = attended.transpose(1, 2).reshape(batch_size, num_tokens, channels)
        return self.proj_drop(self.proj(merged))
188
+
189
+
190
class TextRefineAttention(nn.Module):
    """Multi-head self-attention with RMS-normalized queries and keys and no positional rotation."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_drop: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = RMSNorm(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attend over the tokens of ``x`` and project back to ``dim`` channels."""
        batch_size, num_tokens, channels = x.shape
        heads, head_dim = self.num_heads, self.head_dim
        head_shape = (batch_size, heads, -1, head_dim)

        # (B, N, 3 * C) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch_size, num_tokens, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
        query = self.q_norm(packed[0].contiguous()).view(head_shape)
        key = self.k_norm(packed[1].contiguous()).view(head_shape).contiguous()
        value = packed[2].view(head_shape).contiguous()

        attended = scaled_dot_product_attention(query, key, value, dropout_p=0.0)
        merged = attended.transpose(1, 2).reshape(batch_size, num_tokens, channels)
        return self.proj_drop(self.proj(merged))
214
+
215
+
216
class T2IFlattenDiTBlock(nn.Module):
    """Image-token block: adaLN-modulated joint attention followed by a SwiGLU MLP.

    The conditioning ``c`` is mapped by a single linear layer to six
    modulation tensors (shift/scale/gate for the attention and MLP branches).
    """

    def __init__(self, hidden_size: int, groups: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = CrossAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = DeCoT2ISwiGLU(hidden_size, int(hidden_size * mlp_ratio))
        self.adaLN_modulation = nn.Sequential(nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Update image tokens ``x`` given text tokens ``y``, conditioning ``c`` and rotary factors ``pos``."""
        modulation = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation

        # Gated residual attention branch.
        attn_input = _modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa * self.attn(attn_input, y, pos)

        # Gated residual MLP branch.
        mlp_input = _modulate(self.norm2(x), shift_mlp, scale_mlp)
        return x + gate_mlp * self.mlp(mlp_input)
229
+
230
+
231
class TextRefineBlock(nn.Module):
    """Text-token block: adaLN-modulated self-attention followed by a SwiGLU MLP.

    The conditioning ``c`` is mapped by a single linear layer to six
    modulation tensors (shift/scale/gate for the attention and MLP branches).
    """

    def __init__(self, hidden_size: int, groups: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = TextRefineAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = DeCoT2ISwiGLU(hidden_size, int(hidden_size * mlp_ratio))
        self.adaLN_modulation = nn.Sequential(nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Refine tokens ``x`` under conditioning ``c``."""
        modulation = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation

        # Gated residual attention branch.
        attn_input = _modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa * self.attn(attn_input)

        # Gated residual MLP branch.
        mlp_input = _modulate(self.norm2(x), shift_mlp, scale_mlp)
        return x + gate_mlp * self.mlp(mlp_input)
244
+
245
+
246
@dataclass
class DeCoT2ITransformer2DModelOutput(BaseOutput):
    """Return container of `DeCoT2ITransformer2DModel.forward` when ``return_dict=True``."""

    # Model output; the backbone folds it back to the input's (height, width).
    sample: torch.Tensor
249
+
250
+
251
class _DeCoT2ITransformerBackbone(nn.Module):
    """Patch-level transformer that conditions an external per-patch decoder.

    The image is split into non-overlapping ``patch_size`` patches, the text
    features are refined by ``num_text_blocks`` self-attention blocks, the
    patch tokens are processed by ``num_encoder_blocks`` joint-attention
    blocks, and the resulting per-patch vectors are handed to ``decoder``
    together with the raw patch pixels.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        num_groups: int,
        hidden_size: int,
        num_encoder_blocks: int,
        num_text_blocks: int,
        txt_embed_dim: int,
        txt_max_length: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_encoder_blocks = num_encoder_blocks
        self.txt_max_length = txt_max_length

        # Submodules are created in a fixed order so random initialization is reproducible.
        flattened_patch_dim = in_channels * patch_size**2
        self.s_embedder = PatchEmbed(flattened_patch_dim, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = TextEmbedder(txt_embed_dim, hidden_size, bias=True)
        self.y_pos_embedding = nn.Parameter(torch.randn(1, txt_max_length, hidden_size))
        self.blocks = nn.ModuleList(
            [T2IFlattenDiTBlock(hidden_size, num_groups) for _ in range(num_encoder_blocks)]
        )
        self.text_refine_blocks = nn.ModuleList(
            [TextRefineBlock(hidden_size, num_groups) for _ in range(num_text_blocks)]
        )
        # Rotary factors cached per (rows, cols) token grid.
        self.precompute_pos: dict[tuple[int, int], torch.Tensor] = {}
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialize the patch embedder (Xavier, zero bias) and timestep MLP weights (normal, std 0.02)."""
        patch_weight = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(patch_weight.view([patch_weight.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)
        for layer_index in (0, 2):
            nn.init.normal_(self.t_embedder.mlp[layer_index].weight, std=0.02)

    def fetch_pos(self, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return rotary factors for a ``height x width`` token grid on ``device``, computing them once per grid."""
        grid = (height, width)
        if grid not in self.precompute_pos:
            head_dim = self.hidden_size // self.num_groups
            self.precompute_pos[grid] = precompute_freqs_cis_2d(head_dim, height, width)
        return self.precompute_pos[grid].to(device)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        decoder: nn.Module,
    ) -> torch.Tensor:
        """Run the backbone and the per-patch ``decoder``; return an image-shaped tensor.

        ``decoder`` is called as ``decoder(patch_pixels, conditioning)`` and must
        return an object with a ``.sample`` attribute.
        """
        batch_size, _, height, width = x.shape
        patch = self.patch_size
        pos = self.fetch_pos(height // patch, width // patch, x.device)

        # (B, C, H, W) -> (B, num_patches, C * patch * patch)
        patches = F.unfold(x, kernel_size=patch, stride=patch).transpose(1, 2)
        time_embedding = self.t_embedder(t.view(-1)).view(batch_size, -1, self.hidden_size)
        text_tokens = self.y_embedder(encoder_hidden_states) + self.y_pos_embedding.to(encoder_hidden_states.dtype)
        condition = F.silu(time_embedding)

        for refine in self.text_refine_blocks:
            text_tokens = refine(text_tokens, condition)

        tokens = self.s_embedder(patches)
        for block in self.blocks:
            tokens = block(tokens, text_tokens, condition, pos)
        tokens = F.silu(time_embedding + tokens)

        # Flatten batch and patch axes: the decoder sees one patch per row.
        batch_size, length, _ = tokens.shape
        patch_pixels = patches.reshape(batch_size * length, self.in_channels, patch**2).transpose(1, 2)
        conditioning = tokens.view(batch_size * length, self.hidden_size)
        decoded = decoder(patch_pixels, conditioning).sample

        # Reassemble the decoded patches into a (B, C, H, W) image.
        per_patch = decoded.transpose(1, 2).reshape(batch_size, length, -1)
        return F.fold(
            per_patch.transpose(1, 2).contiguous(),
            (height, width),
            kernel_size=patch,
            stride=patch,
        )
330
+
331
+
332
class DeCoT2ITransformer2DModel(ModelMixin, ConfigMixin):
    """Diffusers wrapper around the text-conditioned DeCo backbone.

    All constructor arguments are registered in the config. Several of them
    are accepted only so the stored config round-trips and are not used to
    build the network. The per-patch decoder is not owned by this model; it
    must be passed to `forward`.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        patch_size: int = 16,
        num_groups: int = 24,
        hidden_size: int = 1536,
        hidden_size_x: int = 32,
        num_blocks: int = 19,
        num_encoder_blocks: int = 16,
        num_decoder_blocks: int = 3,
        num_text_blocks: int = 4,
        num_cond_blocks: int = 16,
        num_classes: int = 0,
        learn_sigma: bool = True,
        deep_supervision: int = 0,
        sample_size: int = 512,
        conditioning_type: str = "text",
        nerf_mlpratio: int = 4,
        decoder_hidden_size: int = 32,
        txt_embed_dim: int = 2048,
        txt_max_length: int = 128,
    ):
        super().__init__()
        # Config-only arguments: recorded by `register_to_config`, unused here.
        del hidden_size_x, nerf_mlpratio, num_blocks, num_cond_blocks, num_classes, learn_sigma, deep_supervision
        if conditioning_type != "text":
            raise ValueError("DeCoT2ITransformer2DModel only supports text conditioning (t2i).")

        backbone_kwargs = {
            "in_channels": in_channels,
            "patch_size": patch_size,
            "num_groups": num_groups,
            "hidden_size": hidden_size,
            "num_encoder_blocks": num_encoder_blocks,
            "txt_embed_dim": txt_embed_dim,
            "txt_max_length": txt_max_length,
            "num_text_blocks": num_text_blocks,
        }
        self.backbone = _DeCoT2ITransformerBackbone(**backbone_kwargs)

    @property
    def in_channels(self) -> int:
        """Number of image channels, read from the registered config."""
        return int(self.config.in_channels)

    def _prepare_timestep(
        self, timestep: Union[torch.Tensor, float, int], batch_size: int, sample: torch.Tensor
    ) -> torch.Tensor:
        """Convert ``timestep`` to a tensor on ``sample``'s device/dtype, repeating a single value per batch item."""
        if isinstance(timestep, torch.Tensor):
            steps = timestep
        else:
            steps = torch.tensor([timestep], device=sample.device, dtype=sample.dtype)
        steps = steps.to(device=sample.device, dtype=sample.dtype)
        if steps.ndim == 0:
            steps = steps[None]
        needs_repeat = steps.shape[0] == 1 and batch_size > 1
        return steps.repeat(batch_size) if needs_repeat else steps

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: Optional[torch.Tensor] = None,
        decoder: Optional[nn.Module] = None,
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DeCoT2ITransformer2DModelOutput, tuple[torch.Tensor]]:
        """Run one denoising forward pass.

        Args:
            sample: Input image batch.
            timestep: Python scalar or tensor; a single value is repeated per batch item.
            encoder_hidden_states: Text features (required).
            decoder: Per-patch decoder module (required).
            class_labels: Must be ``None``; present only to reject class-conditional calls.
            return_dict: When false, return a one-element tuple instead of the output dataclass.

        Raises:
            ValueError: If ``class_labels`` is given, or ``encoder_hidden_states`` or
                ``decoder`` is missing.
        """
        if class_labels is not None:
            raise ValueError("class_labels are not supported; use encoder_hidden_states for t2i DeCo models.")
        if encoder_hidden_states is None:
            raise ValueError("encoder_hidden_states must be provided for text-conditioned DeCo models.")
        if decoder is None:
            raise ValueError("decoder must be provided; load DeCoPatchDecoderModel as a separate pipeline component.")

        steps = self._prepare_timestep(timestep=timestep, batch_size=sample.shape[0], sample=sample)
        output = self.backbone(sample, steps, encoder_hidden_states, decoder=decoder)
        if return_dict:
            return DeCoT2ITransformer2DModelOutput(sample=output)
        return (output,)
README.md ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ pipeline_tag: unconditional-image-generation
4
+ tags:
5
+ - diffusers
6
+ - deco
7
+ - image-generation
8
+ - class-conditional
9
+ - imagenet
10
+ license: mit
11
+ inference: true
12
+ widget:
13
+ - text: golden retriever
14
+ output:
15
+ url: DeCo-XL-16-512/demo.png
16
+ language:
17
+ - en
18
+ ---
19
+
20
+ # DeCo-diffusers
21
+
22
+ Diffusers-ready checkpoints for **DeCo** (Frequency-Decoupled Pixel Diffusion), converted for local/offline use.
23
+
24
+ This root folder is a model collection that contains:
25
+
26
+ - `DeCo-XL-16-256`
27
+ - `DeCo-XL-16-512`
28
+ - `DeCo-XXL-16-512-t2i` (text-to-image; requires `Qwen/Qwen3-1.7B` text encoder)
29
+
30
+ Each subfolder is a self-contained Diffusers model repo with:
31
+
32
+ - `pipeline.py`
33
+ - `transformer/transformer_deco.py` (`transformer/transformer_deco_t2i.py` in the t2i variant)
34
+ - `scheduler/scheduling_deco_flow_match_euler_discrete.py`
35
+ - `transformer/diffusion_pytorch_model.safetensors`
36
+ - `decoder/decoder_deco.py`
37
+
38
+ Each variant embeds English `id2label` directly in `model_index.json` (DiT-style), so class labels can be passed as
39
+ ImageNet ids or English synonym strings.
40
+
41
+ - `pipe.id2label` — id → English label (comma-separated synonyms)
42
+ - `pipe.get_label_ids("golden retriever")` — English label → id
43
+
44
+ ## Demo
45
+
46
+ ![DeCo-XL-16-512 demo](DeCo-XL-16-512/demo.png)
47
+
48
+ Class-conditional sample (ImageNet class **207**, golden retriever), `DeCo-XL/16` at 512×512, 100 steps, CFG 5.0, seed 42.
49
+
50
+ ## Model Paths
51
+
52
+ Use paths relative to this root README:
53
+
54
+ | Model | Resolution | Source checkpoint | Local path |
55
+ | --- | ---: | --- | --- |
56
+ | DeCo-XL/16 | 256×256 | `imagenet256_epoch800.ckpt` (EMA) | `./DeCo-XL-16-256` |
57
+ | DeCo-XL/16 | 512×512 | `imagenet512_epoch340.ckpt` (EMA) | `./DeCo-XL-16-512` |
58
+ | DeCo-XXL/16 | 512×512 t2i | `t2i_DeCo.ckpt` (EMA) | `./DeCo-XXL-16-512-t2i` |
59
+
60
+ ## Inference Demo (Diffusers)
61
+
62
+ ### 1) Load a local subfolder checkpoint
63
+
64
+ ```python
65
+ import torch
66
+ from diffusers import DiffusionPipeline
67
+
68
+ model_path = "./DeCo-XL-16-512" # change to ./DeCo-XL-16-256 for 256px
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
+
71
+ pipe = DiffusionPipeline.from_pretrained(
72
+ model_path,
73
+ trust_remote_code=True,
74
+ torch_dtype=torch.bfloat16,
75
+ ).to(device)
76
+
77
+ generator = torch.Generator(device=device).manual_seed(42)
78
+
79
+ # ImageNet class example: 207 = golden retriever
80
+ print(pipe.id2label[207])
81
+ print(pipe.get_label_ids("golden retriever")) # [207]
82
+
83
+ result = pipe(
84
+ class_labels="golden retriever",
85
+ num_inference_steps=100,
86
+ guidance_scale=5.0, # use 3.2 for DeCo-XL-16-256
87
+ generator=generator,
88
+ )
89
+
90
+ image = result.images[0]
91
+ image.save("deco_xl_512_demo.png")
92
+ ```
93
+
94
+ ### 2) Quick variant switch (256 model)
95
+
96
+ ```python
97
+ model_path = "./DeCo-XL-16-256"
98
+
99
+ pipe = DiffusionPipeline.from_pretrained(model_path, trust_remote_code=True).to(device)
100
+ image = pipe(
101
+ class_labels=207,
102
+ num_inference_steps=100,
103
+ guidance_scale=3.2,
104
+ generator=generator,
105
+ ).images[0]
106
+ image.save("deco_xl_256_demo.png")
107
+ ```
108
+
109
+ Integer class ids, batched labels, and optional `batch_size` for repeating a single label are also supported.
t2i_DeCo.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:433f55684b19c446d9fad4591f840fcdf9770a668c383ea910da86362651492f
3
+ size 4558758567