BiliSakura commited on
Commit
24196fc
·
verified ·
1 Parent(s): 7c3c8a2

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ProMoE-B-256/__pycache__/pipeline.cpython-312.pyc +0 -0
  2. ProMoE-B-256/model_index.json +1021 -0
  3. ProMoE-B-256/pipeline.py +259 -0
  4. ProMoE-B-256/scheduler/config.json +5 -0
  5. ProMoE-B-256/scheduler/scheduler_config.json +7 -0
  6. ProMoE-B-256/scheduler/scheduling_flow_match_promoe.py +43 -0
  7. ProMoE-B-256/transformer/backbone_diffmoe.py +302 -0
  8. ProMoE-B-256/transformer/backbone_dit.py +123 -0
  9. ProMoE-B-256/transformer/backbone_ecdit.py +220 -0
  10. ProMoE-B-256/transformer/backbone_promoe_ec.py +286 -0
  11. ProMoE-B-256/transformer/backbone_promoe_tc.py +355 -0
  12. ProMoE-B-256/transformer/backbone_tcdit.py +304 -0
  13. ProMoE-B-256/transformer/config.json +22 -0
  14. ProMoE-B-256/transformer/diffusion_pytorch_model.safetensors +3 -0
  15. ProMoE-B-256/transformer/modeling_promoe_common.py +291 -0
  16. ProMoE-B-256/transformer/transformer_promoe.py +137 -0
  17. ProMoE-B-256/vae/config.json +29 -0
  18. ProMoE-B-256/vae/diffusion_pytorch_model.safetensors +3 -0
  19. ProMoE-L-256/model_index.json +1021 -0
  20. ProMoE-L-256/pipeline.py +259 -0
  21. ProMoE-L-256/scheduler/config.json +5 -0
  22. ProMoE-L-256/scheduler/scheduler_config.json +7 -0
  23. ProMoE-L-256/scheduler/scheduling_flow_match_promoe.py +43 -0
  24. ProMoE-L-256/transformer/backbone_diffmoe.py +302 -0
  25. ProMoE-L-256/transformer/backbone_dit.py +123 -0
  26. ProMoE-L-256/transformer/backbone_ecdit.py +220 -0
  27. ProMoE-L-256/transformer/backbone_promoe_ec.py +286 -0
  28. ProMoE-L-256/transformer/backbone_promoe_tc.py +355 -0
  29. ProMoE-L-256/transformer/backbone_tcdit.py +304 -0
  30. ProMoE-L-256/transformer/config.json +22 -0
  31. ProMoE-L-256/transformer/diffusion_pytorch_model.safetensors +3 -0
  32. ProMoE-L-256/transformer/modeling_promoe_common.py +291 -0
  33. ProMoE-L-256/transformer/transformer_promoe.py +137 -0
  34. ProMoE-L-256/vae/config.json +29 -0
  35. ProMoE-L-256/vae/diffusion_pytorch_model.safetensors +3 -0
  36. ProMoE-XL-256/model_index.json +1021 -0
  37. ProMoE-XL-256/pipeline.py +259 -0
  38. ProMoE-XL-256/scheduler/config.json +5 -0
  39. ProMoE-XL-256/scheduler/scheduler_config.json +7 -0
  40. ProMoE-XL-256/scheduler/scheduling_flow_match_promoe.py +43 -0
  41. ProMoE-XL-256/transformer/backbone_diffmoe.py +302 -0
  42. ProMoE-XL-256/transformer/backbone_dit.py +123 -0
  43. ProMoE-XL-256/transformer/backbone_ecdit.py +220 -0
  44. ProMoE-XL-256/transformer/backbone_promoe_ec.py +286 -0
  45. ProMoE-XL-256/transformer/backbone_promoe_tc.py +355 -0
  46. ProMoE-XL-256/transformer/backbone_tcdit.py +304 -0
  47. ProMoE-XL-256/transformer/config.json +22 -0
  48. ProMoE-XL-256/transformer/diffusion_pytorch_model.safetensors +3 -0
  49. ProMoE-XL-256/transformer/modeling_promoe_common.py +291 -0
  50. ProMoE-XL-256/transformer/transformer_promoe.py +137 -0
ProMoE-B-256/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (15.1 kB). View file
 
ProMoE-B-256/model_index.json ADDED
@@ -0,0 +1,1021 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "ProMoEPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_flow_match_promoe",
9
+ "ProMoEFlowMatchScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_promoe",
13
+ "ProMoETransformer2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ],
19
+ "id2label": {
20
+ "0": "tench, Tinca tinca",
21
+ "1": "goldfish, Carassius auratus",
22
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
23
+ "3": "tiger shark, Galeocerdo cuvieri",
24
+ "4": "hammerhead, hammerhead shark",
25
+ "5": "electric ray, crampfish, numbfish, torpedo",
26
+ "6": "stingray",
27
+ "7": "cock",
28
+ "8": "hen",
29
+ "9": "ostrich, Struthio camelus",
30
+ "10": "brambling, Fringilla montifringilla",
31
+ "11": "goldfinch, Carduelis carduelis",
32
+ "12": "house finch, linnet, Carpodacus mexicanus",
33
+ "13": "junco, snowbird",
34
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
35
+ "15": "robin, American robin, Turdus migratorius",
36
+ "16": "bulbul",
37
+ "17": "jay",
38
+ "18": "magpie",
39
+ "19": "chickadee",
40
+ "20": "water ouzel, dipper",
41
+ "21": "kite",
42
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
43
+ "23": "vulture",
44
+ "24": "great grey owl, great gray owl, Strix nebulosa",
45
+ "25": "European fire salamander, Salamandra salamandra",
46
+ "26": "common newt, Triturus vulgaris",
47
+ "27": "eft",
48
+ "28": "spotted salamander, Ambystoma maculatum",
49
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
50
+ "30": "bullfrog, Rana catesbeiana",
51
+ "31": "tree frog, tree-frog",
52
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
53
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
54
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
55
+ "35": "mud turtle",
56
+ "36": "terrapin",
57
+ "37": "box turtle, box tortoise",
58
+ "38": "banded gecko",
59
+ "39": "common iguana, iguana, Iguana iguana",
60
+ "40": "American chameleon, anole, Anolis carolinensis",
61
+ "41": "whiptail, whiptail lizard",
62
+ "42": "agama",
63
+ "43": "frilled lizard, Chlamydosaurus kingi",
64
+ "44": "alligator lizard",
65
+ "45": "Gila monster, Heloderma suspectum",
66
+ "46": "green lizard, Lacerta viridis",
67
+ "47": "African chameleon, Chamaeleo chamaeleon",
68
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
69
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
70
+ "50": "American alligator, Alligator mississipiensis",
71
+ "51": "triceratops",
72
+ "52": "thunder snake, worm snake, Carphophis amoenus",
73
+ "53": "ringneck snake, ring-necked snake, ring snake",
74
+ "54": "hognose snake, puff adder, sand viper",
75
+ "55": "green snake, grass snake",
76
+ "56": "king snake, kingsnake",
77
+ "57": "garter snake, grass snake",
78
+ "58": "water snake",
79
+ "59": "vine snake",
80
+ "60": "night snake, Hypsiglena torquata",
81
+ "61": "boa constrictor, Constrictor constrictor",
82
+ "62": "rock python, rock snake, Python sebae",
83
+ "63": "Indian cobra, Naja naja",
84
+ "64": "green mamba",
85
+ "65": "sea snake",
86
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
87
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
88
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
89
+ "69": "trilobite",
90
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
91
+ "71": "scorpion",
92
+ "72": "black and gold garden spider, Argiope aurantia",
93
+ "73": "barn spider, Araneus cavaticus",
94
+ "74": "garden spider, Aranea diademata",
95
+ "75": "black widow, Latrodectus mactans",
96
+ "76": "tarantula",
97
+ "77": "wolf spider, hunting spider",
98
+ "78": "tick",
99
+ "79": "centipede",
100
+ "80": "black grouse",
101
+ "81": "ptarmigan",
102
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
103
+ "83": "prairie chicken, prairie grouse, prairie fowl",
104
+ "84": "peacock",
105
+ "85": "quail",
106
+ "86": "partridge",
107
+ "87": "African grey, African gray, Psittacus erithacus",
108
+ "88": "macaw",
109
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
110
+ "90": "lorikeet",
111
+ "91": "coucal",
112
+ "92": "bee eater",
113
+ "93": "hornbill",
114
+ "94": "hummingbird",
115
+ "95": "jacamar",
116
+ "96": "toucan",
117
+ "97": "drake",
118
+ "98": "red-breasted merganser, Mergus serrator",
119
+ "99": "goose",
120
+ "100": "black swan, Cygnus atratus",
121
+ "101": "tusker",
122
+ "102": "echidna, spiny anteater, anteater",
123
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
124
+ "104": "wallaby, brush kangaroo",
125
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
126
+ "106": "wombat",
127
+ "107": "jellyfish",
128
+ "108": "sea anemone, anemone",
129
+ "109": "brain coral",
130
+ "110": "flatworm, platyhelminth",
131
+ "111": "nematode, nematode worm, roundworm",
132
+ "112": "conch",
133
+ "113": "snail",
134
+ "114": "slug",
135
+ "115": "sea slug, nudibranch",
136
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
137
+ "117": "chambered nautilus, pearly nautilus, nautilus",
138
+ "118": "Dungeness crab, Cancer magister",
139
+ "119": "rock crab, Cancer irroratus",
140
+ "120": "fiddler crab",
141
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
142
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
143
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
144
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
145
+ "125": "hermit crab",
146
+ "126": "isopod",
147
+ "127": "white stork, Ciconia ciconia",
148
+ "128": "black stork, Ciconia nigra",
149
+ "129": "spoonbill",
150
+ "130": "flamingo",
151
+ "131": "little blue heron, Egretta caerulea",
152
+ "132": "American egret, great white heron, Egretta albus",
153
+ "133": "bittern",
154
+ "134": "crane",
155
+ "135": "limpkin, Aramus pictus",
156
+ "136": "European gallinule, Porphyrio porphyrio",
157
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
158
+ "138": "bustard",
159
+ "139": "ruddy turnstone, Arenaria interpres",
160
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
161
+ "141": "redshank, Tringa totanus",
162
+ "142": "dowitcher",
163
+ "143": "oystercatcher, oyster catcher",
164
+ "144": "pelican",
165
+ "145": "king penguin, Aptenodytes patagonica",
166
+ "146": "albatross, mollymawk",
167
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
168
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
169
+ "149": "dugong, Dugong dugon",
170
+ "150": "sea lion",
171
+ "151": "Chihuahua",
172
+ "152": "Japanese spaniel",
173
+ "153": "Maltese dog, Maltese terrier, Maltese",
174
+ "154": "Pekinese, Pekingese, Peke",
175
+ "155": "Shih-Tzu",
176
+ "156": "Blenheim spaniel",
177
+ "157": "papillon",
178
+ "158": "toy terrier",
179
+ "159": "Rhodesian ridgeback",
180
+ "160": "Afghan hound, Afghan",
181
+ "161": "basset, basset hound",
182
+ "162": "beagle",
183
+ "163": "bloodhound, sleuthhound",
184
+ "164": "bluetick",
185
+ "165": "black-and-tan coonhound",
186
+ "166": "Walker hound, Walker foxhound",
187
+ "167": "English foxhound",
188
+ "168": "redbone",
189
+ "169": "borzoi, Russian wolfhound",
190
+ "170": "Irish wolfhound",
191
+ "171": "Italian greyhound",
192
+ "172": "whippet",
193
+ "173": "Ibizan hound, Ibizan Podenco",
194
+ "174": "Norwegian elkhound, elkhound",
195
+ "175": "otterhound, otter hound",
196
+ "176": "Saluki, gazelle hound",
197
+ "177": "Scottish deerhound, deerhound",
198
+ "178": "Weimaraner",
199
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
200
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
201
+ "181": "Bedlington terrier",
202
+ "182": "Border terrier",
203
+ "183": "Kerry blue terrier",
204
+ "184": "Irish terrier",
205
+ "185": "Norfolk terrier",
206
+ "186": "Norwich terrier",
207
+ "187": "Yorkshire terrier",
208
+ "188": "wire-haired fox terrier",
209
+ "189": "Lakeland terrier",
210
+ "190": "Sealyham terrier, Sealyham",
211
+ "191": "Airedale, Airedale terrier",
212
+ "192": "cairn, cairn terrier",
213
+ "193": "Australian terrier",
214
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
215
+ "195": "Boston bull, Boston terrier",
216
+ "196": "miniature schnauzer",
217
+ "197": "giant schnauzer",
218
+ "198": "standard schnauzer",
219
+ "199": "Scotch terrier, Scottish terrier, Scottie",
220
+ "200": "Tibetan terrier, chrysanthemum dog",
221
+ "201": "silky terrier, Sydney silky",
222
+ "202": "soft-coated wheaten terrier",
223
+ "203": "West Highland white terrier",
224
+ "204": "Lhasa, Lhasa apso",
225
+ "205": "flat-coated retriever",
226
+ "206": "curly-coated retriever",
227
+ "207": "golden retriever",
228
+ "208": "Labrador retriever",
229
+ "209": "Chesapeake Bay retriever",
230
+ "210": "German short-haired pointer",
231
+ "211": "vizsla, Hungarian pointer",
232
+ "212": "English setter",
233
+ "213": "Irish setter, red setter",
234
+ "214": "Gordon setter",
235
+ "215": "Brittany spaniel",
236
+ "216": "clumber, clumber spaniel",
237
+ "217": "English springer, English springer spaniel",
238
+ "218": "Welsh springer spaniel",
239
+ "219": "cocker spaniel, English cocker spaniel, cocker",
240
+ "220": "Sussex spaniel",
241
+ "221": "Irish water spaniel",
242
+ "222": "kuvasz",
243
+ "223": "schipperke",
244
+ "224": "groenendael",
245
+ "225": "malinois",
246
+ "226": "briard",
247
+ "227": "kelpie",
248
+ "228": "komondor",
249
+ "229": "Old English sheepdog, bobtail",
250
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
251
+ "231": "collie",
252
+ "232": "Border collie",
253
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
254
+ "234": "Rottweiler",
255
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
256
+ "236": "Doberman, Doberman pinscher",
257
+ "237": "miniature pinscher",
258
+ "238": "Greater Swiss Mountain dog",
259
+ "239": "Bernese mountain dog",
260
+ "240": "Appenzeller",
261
+ "241": "EntleBucher",
262
+ "242": "boxer",
263
+ "243": "bull mastiff",
264
+ "244": "Tibetan mastiff",
265
+ "245": "French bulldog",
266
+ "246": "Great Dane",
267
+ "247": "Saint Bernard, St Bernard",
268
+ "248": "Eskimo dog, husky",
269
+ "249": "malamute, malemute, Alaskan malamute",
270
+ "250": "Siberian husky",
271
+ "251": "dalmatian, coach dog, carriage dog",
272
+ "252": "affenpinscher, monkey pinscher, monkey dog",
273
+ "253": "basenji",
274
+ "254": "pug, pug-dog",
275
+ "255": "Leonberg",
276
+ "256": "Newfoundland, Newfoundland dog",
277
+ "257": "Great Pyrenees",
278
+ "258": "Samoyed, Samoyede",
279
+ "259": "Pomeranian",
280
+ "260": "chow, chow chow",
281
+ "261": "keeshond",
282
+ "262": "Brabancon griffon",
283
+ "263": "Pembroke, Pembroke Welsh corgi",
284
+ "264": "Cardigan, Cardigan Welsh corgi",
285
+ "265": "toy poodle",
286
+ "266": "miniature poodle",
287
+ "267": "standard poodle",
288
+ "268": "Mexican hairless",
289
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
290
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
291
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
292
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
293
+ "273": "dingo, warrigal, warragal, Canis dingo",
294
+ "274": "dhole, Cuon alpinus",
295
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
296
+ "276": "hyena, hyaena",
297
+ "277": "red fox, Vulpes vulpes",
298
+ "278": "kit fox, Vulpes macrotis",
299
+ "279": "Arctic fox, white fox, Alopex lagopus",
300
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
301
+ "281": "tabby, tabby cat",
302
+ "282": "tiger cat",
303
+ "283": "Persian cat",
304
+ "284": "Siamese cat, Siamese",
305
+ "285": "Egyptian cat",
306
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
307
+ "287": "lynx, catamount",
308
+ "288": "leopard, Panthera pardus",
309
+ "289": "snow leopard, ounce, Panthera uncia",
310
+ "290": "jaguar, panther, Panthera onca, Felis onca",
311
+ "291": "lion, king of beasts, Panthera leo",
312
+ "292": "tiger, Panthera tigris",
313
+ "293": "cheetah, chetah, Acinonyx jubatus",
314
+ "294": "brown bear, bruin, Ursus arctos",
315
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
316
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
317
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
318
+ "298": "mongoose",
319
+ "299": "meerkat, mierkat",
320
+ "300": "tiger beetle",
321
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
322
+ "302": "ground beetle, carabid beetle",
323
+ "303": "long-horned beetle, longicorn, longicorn beetle",
324
+ "304": "leaf beetle, chrysomelid",
325
+ "305": "dung beetle",
326
+ "306": "rhinoceros beetle",
327
+ "307": "weevil",
328
+ "308": "fly",
329
+ "309": "bee",
330
+ "310": "ant, emmet, pismire",
331
+ "311": "grasshopper, hopper",
332
+ "312": "cricket",
333
+ "313": "walking stick, walkingstick, stick insect",
334
+ "314": "cockroach, roach",
335
+ "315": "mantis, mantid",
336
+ "316": "cicada, cicala",
337
+ "317": "leafhopper",
338
+ "318": "lacewing, lacewing fly",
339
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
340
+ "320": "damselfly",
341
+ "321": "admiral",
342
+ "322": "ringlet, ringlet butterfly",
343
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
344
+ "324": "cabbage butterfly",
345
+ "325": "sulphur butterfly, sulfur butterfly",
346
+ "326": "lycaenid, lycaenid butterfly",
347
+ "327": "starfish, sea star",
348
+ "328": "sea urchin",
349
+ "329": "sea cucumber, holothurian",
350
+ "330": "wood rabbit, cottontail, cottontail rabbit",
351
+ "331": "hare",
352
+ "332": "Angora, Angora rabbit",
353
+ "333": "hamster",
354
+ "334": "porcupine, hedgehog",
355
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
356
+ "336": "marmot",
357
+ "337": "beaver",
358
+ "338": "guinea pig, Cavia cobaya",
359
+ "339": "sorrel",
360
+ "340": "zebra",
361
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
362
+ "342": "wild boar, boar, Sus scrofa",
363
+ "343": "warthog",
364
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
365
+ "345": "ox",
366
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
367
+ "347": "bison",
368
+ "348": "ram, tup",
369
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
370
+ "350": "ibex, Capra ibex",
371
+ "351": "hartebeest",
372
+ "352": "impala, Aepyceros melampus",
373
+ "353": "gazelle",
374
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
375
+ "355": "llama",
376
+ "356": "weasel",
377
+ "357": "mink",
378
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
379
+ "359": "black-footed ferret, ferret, Mustela nigripes",
380
+ "360": "otter",
381
+ "361": "skunk, polecat, wood pussy",
382
+ "362": "badger",
383
+ "363": "armadillo",
384
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
385
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
386
+ "366": "gorilla, Gorilla gorilla",
387
+ "367": "chimpanzee, chimp, Pan troglodytes",
388
+ "368": "gibbon, Hylobates lar",
389
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
390
+ "370": "guenon, guenon monkey",
391
+ "371": "patas, hussar monkey, Erythrocebus patas",
392
+ "372": "baboon",
393
+ "373": "macaque",
394
+ "374": "langur",
395
+ "375": "colobus, colobus monkey",
396
+ "376": "proboscis monkey, Nasalis larvatus",
397
+ "377": "marmoset",
398
+ "378": "capuchin, ringtail, Cebus capucinus",
399
+ "379": "howler monkey, howler",
400
+ "380": "titi, titi monkey",
401
+ "381": "spider monkey, Ateles geoffroyi",
402
+ "382": "squirrel monkey, Saimiri sciureus",
403
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
404
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
405
+ "385": "Indian elephant, Elephas maximus",
406
+ "386": "African elephant, Loxodonta africana",
407
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
408
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
409
+ "389": "barracouta, snoek",
410
+ "390": "eel",
411
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
412
+ "392": "rock beauty, Holocanthus tricolor",
413
+ "393": "anemone fish",
414
+ "394": "sturgeon",
415
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
416
+ "396": "lionfish",
417
+ "397": "puffer, pufferfish, blowfish, globefish",
418
+ "398": "abacus",
419
+ "399": "abaya",
420
+ "400": "academic gown, academic robe, judge robe",
421
+ "401": "accordion, piano accordion, squeeze box",
422
+ "402": "acoustic guitar",
423
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
424
+ "404": "airliner",
425
+ "405": "airship, dirigible",
426
+ "406": "altar",
427
+ "407": "ambulance",
428
+ "408": "amphibian, amphibious vehicle",
429
+ "409": "analog clock",
430
+ "410": "apiary, bee house",
431
+ "411": "apron",
432
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
433
+ "413": "assault rifle, assault gun",
434
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
435
+ "415": "bakery, bakeshop, bakehouse",
436
+ "416": "balance beam, beam",
437
+ "417": "balloon",
438
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
439
+ "419": "Band Aid",
440
+ "420": "banjo",
441
+ "421": "bannister, banister, balustrade, balusters, handrail",
442
+ "422": "barbell",
443
+ "423": "barber chair",
444
+ "424": "barbershop",
445
+ "425": "barn",
446
+ "426": "barometer",
447
+ "427": "barrel, cask",
448
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
449
+ "429": "baseball",
450
+ "430": "basketball",
451
+ "431": "bassinet",
452
+ "432": "bassoon",
453
+ "433": "bathing cap, swimming cap",
454
+ "434": "bath towel",
455
+ "435": "bathtub, bathing tub, bath, tub",
456
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
457
+ "437": "beacon, lighthouse, beacon light, pharos",
458
+ "438": "beaker",
459
+ "439": "bearskin, busby, shako",
460
+ "440": "beer bottle",
461
+ "441": "beer glass",
462
+ "442": "bell cote, bell cot",
463
+ "443": "bib",
464
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
465
+ "445": "bikini, two-piece",
466
+ "446": "binder, ring-binder",
467
+ "447": "binoculars, field glasses, opera glasses",
468
+ "448": "birdhouse",
469
+ "449": "boathouse",
470
+ "450": "bobsled, bobsleigh, bob",
471
+ "451": "bolo tie, bolo, bola tie, bola",
472
+ "452": "bonnet, poke bonnet",
473
+ "453": "bookcase",
474
+ "454": "bookshop, bookstore, bookstall",
475
+ "455": "bottlecap",
476
+ "456": "bow",
477
+ "457": "bow tie, bow-tie, bowtie",
478
+ "458": "brass, memorial tablet, plaque",
479
+ "459": "brassiere, bra, bandeau",
480
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
481
+ "461": "breastplate, aegis, egis",
482
+ "462": "broom",
483
+ "463": "bucket, pail",
484
+ "464": "buckle",
485
+ "465": "bulletproof vest",
486
+ "466": "bullet train, bullet",
487
+ "467": "butcher shop, meat market",
488
+ "468": "cab, hack, taxi, taxicab",
489
+ "469": "caldron, cauldron",
490
+ "470": "candle, taper, wax light",
491
+ "471": "cannon",
492
+ "472": "canoe",
493
+ "473": "can opener, tin opener",
494
+ "474": "cardigan",
495
+ "475": "car mirror",
496
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
497
+ "477": "carpenters kit, tool kit",
498
+ "478": "carton",
499
+ "479": "car wheel",
500
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
501
+ "481": "cassette",
502
+ "482": "cassette player",
503
+ "483": "castle",
504
+ "484": "catamaran",
505
+ "485": "CD player",
506
+ "486": "cello, violoncello",
507
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
508
+ "488": "chain",
509
+ "489": "chainlink fence",
510
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
511
+ "491": "chain saw, chainsaw",
512
+ "492": "chest",
513
+ "493": "chiffonier, commode",
514
+ "494": "chime, bell, gong",
515
+ "495": "china cabinet, china closet",
516
+ "496": "Christmas stocking",
517
+ "497": "church, church building",
518
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
519
+ "499": "cleaver, meat cleaver, chopper",
520
+ "500": "cliff dwelling",
521
+ "501": "cloak",
522
+ "502": "clog, geta, patten, sabot",
523
+ "503": "cocktail shaker",
524
+ "504": "coffee mug",
525
+ "505": "coffeepot",
526
+ "506": "coil, spiral, volute, whorl, helix",
527
+ "507": "combination lock",
528
+ "508": "computer keyboard, keypad",
529
+ "509": "confectionery, confectionary, candy store",
530
+ "510": "container ship, containership, container vessel",
531
+ "511": "convertible",
532
+ "512": "corkscrew, bottle screw",
533
+ "513": "cornet, horn, trumpet, trump",
534
+ "514": "cowboy boot",
535
+ "515": "cowboy hat, ten-gallon hat",
536
+ "516": "cradle",
537
+ "517": "crane",
538
+ "518": "crash helmet",
539
+ "519": "crate",
540
+ "520": "crib, cot",
541
+ "521": "Crock Pot",
542
+ "522": "croquet ball",
543
+ "523": "crutch",
544
+ "524": "cuirass",
545
+ "525": "dam, dike, dyke",
546
+ "526": "desk",
547
+ "527": "desktop computer",
548
+ "528": "dial telephone, dial phone",
549
+ "529": "diaper, nappy, napkin",
550
+ "530": "digital clock",
551
+ "531": "digital watch",
552
+ "532": "dining table, board",
553
+ "533": "dishrag, dishcloth",
554
+ "534": "dishwasher, dish washer, dishwashing machine",
555
+ "535": "disk brake, disc brake",
556
+ "536": "dock, dockage, docking facility",
557
+ "537": "dogsled, dog sled, dog sleigh",
558
+ "538": "dome",
559
+ "539": "doormat, welcome mat",
560
+ "540": "drilling platform, offshore rig",
561
+ "541": "drum, membranophone, tympan",
562
+ "542": "drumstick",
563
+ "543": "dumbbell",
564
+ "544": "Dutch oven",
565
+ "545": "electric fan, blower",
566
+ "546": "electric guitar",
567
+ "547": "electric locomotive",
568
+ "548": "entertainment center",
569
+ "549": "envelope",
570
+ "550": "espresso maker",
571
+ "551": "face powder",
572
+ "552": "feather boa, boa",
573
+ "553": "file, file cabinet, filing cabinet",
574
+ "554": "fireboat",
575
+ "555": "fire engine, fire truck",
576
+ "556": "fire screen, fireguard",
577
+ "557": "flagpole, flagstaff",
578
+ "558": "flute, transverse flute",
579
+ "559": "folding chair",
580
+ "560": "football helmet",
581
+ "561": "forklift",
582
+ "562": "fountain",
583
+ "563": "fountain pen",
584
+ "564": "four-poster",
585
+ "565": "freight car",
586
+ "566": "French horn, horn",
587
+ "567": "frying pan, frypan, skillet",
588
+ "568": "fur coat",
589
+ "569": "garbage truck, dustcart",
590
+ "570": "gasmask, respirator, gas helmet",
591
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
592
+ "572": "goblet",
593
+ "573": "go-kart",
594
+ "574": "golf ball",
595
+ "575": "golfcart, golf cart",
596
+ "576": "gondola",
597
+ "577": "gong, tam-tam",
598
+ "578": "gown",
599
+ "579": "grand piano, grand",
600
+ "580": "greenhouse, nursery, glasshouse",
601
+ "581": "grille, radiator grille",
602
+ "582": "grocery store, grocery, food market, market",
603
+ "583": "guillotine",
604
+ "584": "hair slide",
605
+ "585": "hair spray",
606
+ "586": "half track",
607
+ "587": "hammer",
608
+ "588": "hamper",
609
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
610
+ "590": "hand-held computer, hand-held microcomputer",
611
+ "591": "handkerchief, hankie, hanky, hankey",
612
+ "592": "hard disc, hard disk, fixed disk",
613
+ "593": "harmonica, mouth organ, harp, mouth harp",
614
+ "594": "harp",
615
+ "595": "harvester, reaper",
616
+ "596": "hatchet",
617
+ "597": "holster",
618
+ "598": "home theater, home theatre",
619
+ "599": "honeycomb",
620
+ "600": "hook, claw",
621
+ "601": "hoopskirt, crinoline",
622
+ "602": "horizontal bar, high bar",
623
+ "603": "horse cart, horse-cart",
624
+ "604": "hourglass",
625
+ "605": "iPod",
626
+ "606": "iron, smoothing iron",
627
+ "607": "jack-o-lantern",
628
+ "608": "jean, blue jean, denim",
629
+ "609": "jeep, landrover",
630
+ "610": "jersey, T-shirt, tee shirt",
631
+ "611": "jigsaw puzzle",
632
+ "612": "jinrikisha, ricksha, rickshaw",
633
+ "613": "joystick",
634
+ "614": "kimono",
635
+ "615": "knee pad",
636
+ "616": "knot",
637
+ "617": "lab coat, laboratory coat",
638
+ "618": "ladle",
639
+ "619": "lampshade, lamp shade",
640
+ "620": "laptop, laptop computer",
641
+ "621": "lawn mower, mower",
642
+ "622": "lens cap, lens cover",
643
+ "623": "letter opener, paper knife, paperknife",
644
+ "624": "library",
645
+ "625": "lifeboat",
646
+ "626": "lighter, light, igniter, ignitor",
647
+ "627": "limousine, limo",
648
+ "628": "liner, ocean liner",
649
+ "629": "lipstick, lip rouge",
650
+ "630": "Loafer",
651
+ "631": "lotion",
652
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
653
+ "633": "loupe, jewelers loupe",
654
+ "634": "lumbermill, sawmill",
655
+ "635": "magnetic compass",
656
+ "636": "mailbag, postbag",
657
+ "637": "mailbox, letter box",
658
+ "638": "maillot",
659
+ "639": "maillot, tank suit",
660
+ "640": "manhole cover",
661
+ "641": "maraca",
662
+ "642": "marimba, xylophone",
663
+ "643": "mask",
664
+ "644": "matchstick",
665
+ "645": "maypole",
666
+ "646": "maze, labyrinth",
667
+ "647": "measuring cup",
668
+ "648": "medicine chest, medicine cabinet",
669
+ "649": "megalith, megalithic structure",
670
+ "650": "microphone, mike",
671
+ "651": "microwave, microwave oven",
672
+ "652": "military uniform",
673
+ "653": "milk can",
674
+ "654": "minibus",
675
+ "655": "miniskirt, mini",
676
+ "656": "minivan",
677
+ "657": "missile",
678
+ "658": "mitten",
679
+ "659": "mixing bowl",
680
+ "660": "mobile home, manufactured home",
681
+ "661": "Model T",
682
+ "662": "modem",
683
+ "663": "monastery",
684
+ "664": "monitor",
685
+ "665": "moped",
686
+ "666": "mortar",
687
+ "667": "mortarboard",
688
+ "668": "mosque",
689
+ "669": "mosquito net",
690
+ "670": "motor scooter, scooter",
691
+ "671": "mountain bike, all-terrain bike, off-roader",
692
+ "672": "mountain tent",
693
+ "673": "mouse, computer mouse",
694
+ "674": "mousetrap",
695
+ "675": "moving van",
696
+ "676": "muzzle",
697
+ "677": "nail",
698
+ "678": "neck brace",
699
+ "679": "necklace",
700
+ "680": "nipple",
701
+ "681": "notebook, notebook computer",
702
+ "682": "obelisk",
703
+ "683": "oboe, hautboy, hautbois",
704
+ "684": "ocarina, sweet potato",
705
+ "685": "odometer, hodometer, mileometer, milometer",
706
+ "686": "oil filter",
707
+ "687": "organ, pipe organ",
708
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
709
+ "689": "overskirt",
710
+ "690": "oxcart",
711
+ "691": "oxygen mask",
712
+ "692": "packet",
713
+ "693": "paddle, boat paddle",
714
+ "694": "paddlewheel, paddle wheel",
715
+ "695": "padlock",
716
+ "696": "paintbrush",
717
+ "697": "pajama, pyjama, pjs, jammies",
718
+ "698": "palace",
719
+ "699": "panpipe, pandean pipe, syrinx",
720
+ "700": "paper towel",
721
+ "701": "parachute, chute",
722
+ "702": "parallel bars, bars",
723
+ "703": "park bench",
724
+ "704": "parking meter",
725
+ "705": "passenger car, coach, carriage",
726
+ "706": "patio, terrace",
727
+ "707": "pay-phone, pay-station",
728
+ "708": "pedestal, plinth, footstall",
729
+ "709": "pencil box, pencil case",
730
+ "710": "pencil sharpener",
731
+ "711": "perfume, essence",
732
+ "712": "Petri dish",
733
+ "713": "photocopier",
734
+ "714": "pick, plectrum, plectron",
735
+ "715": "pickelhaube",
736
+ "716": "picket fence, paling",
737
+ "717": "pickup, pickup truck",
738
+ "718": "pier",
739
+ "719": "piggy bank, penny bank",
740
+ "720": "pill bottle",
741
+ "721": "pillow",
742
+ "722": "ping-pong ball",
743
+ "723": "pinwheel",
744
+ "724": "pirate, pirate ship",
745
+ "725": "pitcher, ewer",
746
+ "726": "plane, carpenters plane, woodworking plane",
747
+ "727": "planetarium",
748
+ "728": "plastic bag",
749
+ "729": "plate rack",
750
+ "730": "plow, plough",
751
+ "731": "plunger, plumbers helper",
752
+ "732": "Polaroid camera, Polaroid Land camera",
753
+ "733": "pole",
754
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
755
+ "735": "poncho",
756
+ "736": "pool table, billiard table, snooker table",
757
+ "737": "pop bottle, soda bottle",
758
+ "738": "pot, flowerpot",
759
+ "739": "potters wheel",
760
+ "740": "power drill",
761
+ "741": "prayer rug, prayer mat",
762
+ "742": "printer",
763
+ "743": "prison, prison house",
764
+ "744": "projectile, missile",
765
+ "745": "projector",
766
+ "746": "puck, hockey puck",
767
+ "747": "punching bag, punch bag, punching ball, punchball",
768
+ "748": "purse",
769
+ "749": "quill, quill pen",
770
+ "750": "quilt, comforter, comfort, puff",
771
+ "751": "racer, race car, racing car",
772
+ "752": "racket, racquet",
773
+ "753": "radiator",
774
+ "754": "radio, wireless",
775
+ "755": "radio telescope, radio reflector",
776
+ "756": "rain barrel",
777
+ "757": "recreational vehicle, RV, R.V.",
778
+ "758": "reel",
779
+ "759": "reflex camera",
780
+ "760": "refrigerator, icebox",
781
+ "761": "remote control, remote",
782
+ "762": "restaurant, eating house, eating place, eatery",
783
+ "763": "revolver, six-gun, six-shooter",
784
+ "764": "rifle",
785
+ "765": "rocking chair, rocker",
786
+ "766": "rotisserie",
787
+ "767": "rubber eraser, rubber, pencil eraser",
788
+ "768": "rugby ball",
789
+ "769": "rule, ruler",
790
+ "770": "running shoe",
791
+ "771": "safe",
792
+ "772": "safety pin",
793
+ "773": "saltshaker, salt shaker",
794
+ "774": "sandal",
795
+ "775": "sarong",
796
+ "776": "sax, saxophone",
797
+ "777": "scabbard",
798
+ "778": "scale, weighing machine",
799
+ "779": "school bus",
800
+ "780": "schooner",
801
+ "781": "scoreboard",
802
+ "782": "screen, CRT screen",
803
+ "783": "screw",
804
+ "784": "screwdriver",
805
+ "785": "seat belt, seatbelt",
806
+ "786": "sewing machine",
807
+ "787": "shield, buckler",
808
+ "788": "shoe shop, shoe-shop, shoe store",
809
+ "789": "shoji",
810
+ "790": "shopping basket",
811
+ "791": "shopping cart",
812
+ "792": "shovel",
813
+ "793": "shower cap",
814
+ "794": "shower curtain",
815
+ "795": "ski",
816
+ "796": "ski mask",
817
+ "797": "sleeping bag",
818
+ "798": "slide rule, slipstick",
819
+ "799": "sliding door",
820
+ "800": "slot, one-armed bandit",
821
+ "801": "snorkel",
822
+ "802": "snowmobile",
823
+ "803": "snowplow, snowplough",
824
+ "804": "soap dispenser",
825
+ "805": "soccer ball",
826
+ "806": "sock",
827
+ "807": "solar dish, solar collector, solar furnace",
828
+ "808": "sombrero",
829
+ "809": "soup bowl",
830
+ "810": "space bar",
831
+ "811": "space heater",
832
+ "812": "space shuttle",
833
+ "813": "spatula",
834
+ "814": "speedboat",
835
+ "815": "spider web, spiders web",
836
+ "816": "spindle",
837
+ "817": "sports car, sport car",
838
+ "818": "spotlight, spot",
839
+ "819": "stage",
840
+ "820": "steam locomotive",
841
+ "821": "steel arch bridge",
842
+ "822": "steel drum",
843
+ "823": "stethoscope",
844
+ "824": "stole",
845
+ "825": "stone wall",
846
+ "826": "stopwatch, stop watch",
847
+ "827": "stove",
848
+ "828": "strainer",
849
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
850
+ "830": "stretcher",
851
+ "831": "studio couch, day bed",
852
+ "832": "stupa, tope",
853
+ "833": "submarine, pigboat, sub, U-boat",
854
+ "834": "suit, suit of clothes",
855
+ "835": "sundial",
856
+ "836": "sunglass",
857
+ "837": "sunglasses, dark glasses, shades",
858
+ "838": "sunscreen, sunblock, sun blocker",
859
+ "839": "suspension bridge",
860
+ "840": "swab, swob, mop",
861
+ "841": "sweatshirt",
862
+ "842": "swimming trunks, bathing trunks",
863
+ "843": "swing",
864
+ "844": "switch, electric switch, electrical switch",
865
+ "845": "syringe",
866
+ "846": "table lamp",
867
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
868
+ "848": "tape player",
869
+ "849": "teapot",
870
+ "850": "teddy, teddy bear",
871
+ "851": "television, television system",
872
+ "852": "tennis ball",
873
+ "853": "thatch, thatched roof",
874
+ "854": "theater curtain, theatre curtain",
875
+ "855": "thimble",
876
+ "856": "thresher, thrasher, threshing machine",
877
+ "857": "throne",
878
+ "858": "tile roof",
879
+ "859": "toaster",
880
+ "860": "tobacco shop, tobacconist shop, tobacconist",
881
+ "861": "toilet seat",
882
+ "862": "torch",
883
+ "863": "totem pole",
884
+ "864": "tow truck, tow car, wrecker",
885
+ "865": "toyshop",
886
+ "866": "tractor",
887
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
888
+ "868": "tray",
889
+ "869": "trench coat",
890
+ "870": "tricycle, trike, velocipede",
891
+ "871": "trimaran",
892
+ "872": "tripod",
893
+ "873": "triumphal arch",
894
+ "874": "trolleybus, trolley coach, trackless trolley",
895
+ "875": "trombone",
896
+ "876": "tub, vat",
897
+ "877": "turnstile",
898
+ "878": "typewriter keyboard",
899
+ "879": "umbrella",
900
+ "880": "unicycle, monocycle",
901
+ "881": "upright, upright piano",
902
+ "882": "vacuum, vacuum cleaner",
903
+ "883": "vase",
904
+ "884": "vault",
905
+ "885": "velvet",
906
+ "886": "vending machine",
907
+ "887": "vestment",
908
+ "888": "viaduct",
909
+ "889": "violin, fiddle",
910
+ "890": "volleyball",
911
+ "891": "waffle iron",
912
+ "892": "wall clock",
913
+ "893": "wallet, billfold, notecase, pocketbook",
914
+ "894": "wardrobe, closet, press",
915
+ "895": "warplane, military plane",
916
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
917
+ "897": "washer, automatic washer, washing machine",
918
+ "898": "water bottle",
919
+ "899": "water jug",
920
+ "900": "water tower",
921
+ "901": "whiskey jug",
922
+ "902": "whistle",
923
+ "903": "wig",
924
+ "904": "window screen",
925
+ "905": "window shade",
926
+ "906": "Windsor tie",
927
+ "907": "wine bottle",
928
+ "908": "wing",
929
+ "909": "wok",
930
+ "910": "wooden spoon",
931
+ "911": "wool, woolen, woollen",
932
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
933
+ "913": "wreck",
934
+ "914": "yawl",
935
+ "915": "yurt",
936
+ "916": "web site, website, internet site, site",
937
+ "917": "comic book",
938
+ "918": "crossword puzzle, crossword",
939
+ "919": "street sign",
940
+ "920": "traffic light, traffic signal, stoplight",
941
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
942
+ "922": "menu",
943
+ "923": "plate",
944
+ "924": "guacamole",
945
+ "925": "consomme",
946
+ "926": "hot pot, hotpot",
947
+ "927": "trifle",
948
+ "928": "ice cream, icecream",
949
+ "929": "ice lolly, lolly, lollipop, popsicle",
950
+ "930": "French loaf",
951
+ "931": "bagel, beigel",
952
+ "932": "pretzel",
953
+ "933": "cheeseburger",
954
+ "934": "hotdog, hot dog, red hot",
955
+ "935": "mashed potato",
956
+ "936": "head cabbage",
957
+ "937": "broccoli",
958
+ "938": "cauliflower",
959
+ "939": "zucchini, courgette",
960
+ "940": "spaghetti squash",
961
+ "941": "acorn squash",
962
+ "942": "butternut squash",
963
+ "943": "cucumber, cuke",
964
+ "944": "artichoke, globe artichoke",
965
+ "945": "bell pepper",
966
+ "946": "cardoon",
967
+ "947": "mushroom",
968
+ "948": "Granny Smith",
969
+ "949": "strawberry",
970
+ "950": "orange",
971
+ "951": "lemon",
972
+ "952": "fig",
973
+ "953": "pineapple, ananas",
974
+ "954": "banana",
975
+ "955": "jackfruit, jak, jack",
976
+ "956": "custard apple",
977
+ "957": "pomegranate",
978
+ "958": "hay",
979
+ "959": "carbonara",
980
+ "960": "chocolate sauce, chocolate syrup",
981
+ "961": "dough",
982
+ "962": "meat loaf, meatloaf",
983
+ "963": "pizza, pizza pie",
984
+ "964": "potpie",
985
+ "965": "burrito",
986
+ "966": "red wine",
987
+ "967": "espresso",
988
+ "968": "cup",
989
+ "969": "eggnog",
990
+ "970": "alp",
991
+ "971": "bubble",
992
+ "972": "cliff, drop, drop-off",
993
+ "973": "coral reef",
994
+ "974": "geyser",
995
+ "975": "lakeside, lakeshore",
996
+ "976": "promontory, headland, head, foreland",
997
+ "977": "sandbar, sand bar",
998
+ "978": "seashore, coast, seacoast, sea-coast",
999
+ "979": "valley, vale",
1000
+ "980": "volcano",
1001
+ "981": "ballplayer, baseball player",
1002
+ "982": "groom, bridegroom",
1003
+ "983": "scuba diver",
1004
+ "984": "rapeseed",
1005
+ "985": "daisy",
1006
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1007
+ "987": "corn",
1008
+ "988": "acorn",
1009
+ "989": "hip, rose hip, rosehip",
1010
+ "990": "buckeye, horse chestnut, conker",
1011
+ "991": "coral fungus",
1012
+ "992": "agaric",
1013
+ "993": "gyromitra",
1014
+ "994": "stinkhorn, carrion fungus",
1015
+ "995": "earthstar",
1016
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1017
+ "997": "bolete",
1018
+ "998": "ear, spike, capitulum",
1019
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1020
+ }
1021
+ }
ProMoE-B-256/pipeline.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: ProMoEPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ from PIL import Image
15
+
16
+ try:
17
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
18
+ except Exception: # pragma: no cover
19
+ class DiffusionPipeline:
20
+ def __init__(self):
21
+ self._execution_device = torch.device("cpu")
22
+
23
+ def register_modules(self, **kwargs):
24
+ for key, value in kwargs.items():
25
+ setattr(self, key, value)
26
+
27
+ def to(self, device):
28
+ self._execution_device = torch.device(device)
29
+ for module in (getattr(self, "transformer", None), getattr(self, "vae", None)):
30
+ if module is not None and hasattr(module, "to"):
31
+ module.to(device)
32
+ return self
33
+
34
+ def progress_bar(self, iterable):
35
+ return iterable
36
+
37
+ def maybe_free_model_hooks(self):
38
+ return None
39
+
40
+ @dataclass
41
+ class ProMoEPipelineOutput:
42
+ images: Union[List[Image.Image], np.ndarray, torch.Tensor]
43
+
44
+ class ProMoEPipeline(DiffusionPipeline):
45
+ r"""
46
+ Pipeline for class-conditional image generation with ProMoE.
47
+
48
+ Parameters:
49
+ transformer ([`ProMoETransformer2DModel`]):
50
+ Class-conditional ProMoE transformer for flow-matching in latent space.
51
+ scheduler ([`ProMoEFlowMatchScheduler`]):
52
+ Flow-matching scheduler used during denoising.
53
+ vae ([`AutoencoderKL`], *optional*):
54
+ Variational autoencoder used to decode latents to pixels.
55
+ id2label (`dict[int, str]`, *optional*):
56
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
57
+ """
58
+
59
+ model_cpu_offload_seq = "transformer->vae"
60
+ _optional_components = ["vae"]
61
+
62
+ def __init__(
63
+ self,
64
+ transformer,
65
+ scheduler,
66
+ vae=None,
67
+ id2label: Optional[Dict[Union[int, str], str]] = None,
68
+ ):
69
+ super().__init__()
70
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
71
+ self._id2label = self._normalize_id2label(id2label)
72
+ self.labels = self._build_label2id(self._id2label)
73
+ self._labels_loaded_from_model_index = bool(self._id2label)
74
+
75
+ def _ensure_labels_loaded(self) -> None:
76
+ if self._labels_loaded_from_model_index:
77
+ return
78
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
79
+ if loaded:
80
+ self._id2label = loaded
81
+ self.labels = self._build_label2id(self._id2label)
82
+ self._labels_loaded_from_model_index = True
83
+
84
+ @staticmethod
85
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
86
+ if not id2label:
87
+ return {}
88
+ return {int(key): value for key, value in id2label.items()}
89
+
90
+ @staticmethod
91
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
92
+ if not variant_path:
93
+ return {}
94
+ variant_dir = Path(variant_path).resolve()
95
+ model_index_path = variant_dir / "model_index.json"
96
+ if not model_index_path.exists():
97
+ return {}
98
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
99
+ id2label = raw.get("id2label")
100
+ if not isinstance(id2label, dict):
101
+ return {}
102
+ return {int(key): value for key, value in id2label.items()}
103
+
104
+ @staticmethod
105
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
106
+ label2id: Dict[str, int] = {}
107
+ for class_id, value in id2label.items():
108
+ for synonym in value.split(","):
109
+ synonym = synonym.strip()
110
+ if synonym:
111
+ label2id[synonym] = int(class_id)
112
+ return dict(sorted(label2id.items()))
113
+
114
+ @property
115
+ def id2label(self) -> Dict[int, str]:
116
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
117
+ self._ensure_labels_loaded()
118
+ return self._id2label
119
+
120
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
121
+ r"""
122
+ Map ImageNet label strings to class ids.
123
+
124
+ Args:
125
+ label (`str` or `list[str]`):
126
+ One or more English label strings. Each string must match a synonym in `id2label`.
127
+ """
128
+ self._ensure_labels_loaded()
129
+ label2id = self.labels
130
+ if not label2id:
131
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
132
+
133
+ if isinstance(label, str):
134
+ label = [label]
135
+
136
+ missing = [item for item in label if item not in label2id]
137
+ if missing:
138
+ preview = ", ".join(list(label2id.keys())[:8])
139
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
140
+ return [label2id[item] for item in label]
141
+
142
+ def _get_vae_spatial_downsample(self) -> int:
143
+ if self.vae is None:
144
+ return 8
145
+ block_out_channels = getattr(getattr(self.vae, "config", None), "block_out_channels", [0, 0, 0, 0])
146
+ return 2 ** (len(block_out_channels) - 1)
147
+
148
+ def _normalize_class_labels(
149
+ self,
150
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
151
+ device: torch.device,
152
+ ) -> torch.LongTensor:
153
+ if torch.is_tensor(class_labels):
154
+ return class_labels.to(device=device, dtype=torch.long).reshape(-1)
155
+
156
+ if isinstance(class_labels, int):
157
+ class_label_ids = [class_labels]
158
+ elif isinstance(class_labels, str):
159
+ class_label_ids = self.get_label_ids(class_labels)
160
+ elif class_labels and isinstance(class_labels[0], str):
161
+ class_label_ids = self.get_label_ids(class_labels)
162
+ else:
163
+ class_label_ids = list(class_labels)
164
+
165
+ return torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)
166
+
167
+ def _prepare_latents(
168
+ self,
169
+ batch_size: int,
170
+ latent_height: int,
171
+ latent_width: int,
172
+ dtype: torch.dtype,
173
+ device: torch.device,
174
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
175
+ ) -> torch.Tensor:
176
+ shape = (batch_size, self.transformer.in_channels, latent_height, latent_width)
177
+ if isinstance(generator, list):
178
+ latents = [torch.randn((1, *shape[1:]), generator=g, device=device, dtype=dtype) for g in generator]
179
+ return torch.cat(latents, dim=0)
180
+ return torch.randn(shape, generator=generator, device=device, dtype=dtype)
181
+
182
+ def _decode_latents(self, latents: torch.Tensor, output_type: str):
183
+ if output_type == "latent":
184
+ return latents
185
+ if self.vae is not None:
186
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
187
+ decode_dtype = next(self.vae.parameters()).dtype
188
+ latents = (latents / scaling_factor).to(dtype=decode_dtype)
189
+ image = self.vae.decode(latents, return_dict=False)[0]
190
+ else:
191
+ image = latents
192
+
193
+ image = (image / 2 + 0.5).clamp(0, 1)
194
+ if output_type == "pt":
195
+ return image
196
+ image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
197
+ if output_type == "np":
198
+ return image
199
+ pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image]
200
+ return pil_images
201
+
202
+ @torch.no_grad()
203
+ def __call__(
204
+ self,
205
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
206
+ height: int = 256,
207
+ width: int = 256,
208
+ num_inference_steps: int = 50,
209
+ guidance_scale: float = 1.0,
210
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
211
+ output_type: str = "pil",
212
+ return_dict: bool = True,
213
+ ) -> Union[ProMoEPipelineOutput, Tuple]:
214
+ r"""
215
+ Generate class-conditional images with ProMoE.
216
+
217
+ Args:
218
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
219
+ ImageNet class indices or human-readable English label strings.
220
+ """
221
+ device = self._execution_device if hasattr(self, "_execution_device") else torch.device("cpu")
222
+ model_dtype = next(self.transformer.parameters()).dtype
223
+ class_labels = self._normalize_class_labels(class_labels, device)
224
+ batch_size = class_labels.shape[0]
225
+
226
+ vae_scale = self._get_vae_spatial_downsample()
227
+ latent_height = height // vae_scale
228
+ latent_width = width // vae_scale
229
+ latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator)
230
+
231
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
232
+ null_labels = torch.full_like(class_labels, getattr(self.transformer.backbone.y_embedder, "num_classes", 1000))
233
+
234
+ for t in self.progress_bar(self.scheduler.timesteps):
235
+ if guidance_scale > 1.0:
236
+ latent_input = torch.cat([latents, latents], dim=0)
237
+ labels = torch.cat([class_labels, null_labels], dim=0)
238
+ else:
239
+ latent_input = latents
240
+ labels = class_labels
241
+ timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype)
242
+ model_output = self.transformer(
243
+ hidden_states=latent_input,
244
+ timestep=timestep,
245
+ class_labels=labels,
246
+ return_dict=True,
247
+ ).sample
248
+ if model_output.shape[1] != latents.shape[1]:
249
+ model_output = model_output.chunk(2, dim=1)[0]
250
+ if guidance_scale > 1.0:
251
+ model_output_cond, model_output_uncond = model_output.chunk(2)
252
+ model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
253
+ latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample
254
+
255
+ images = self._decode_latents(latents, output_type)
256
+ self.maybe_free_model_hooks()
257
+ if not return_dict:
258
+ return (images,)
259
+ return ProMoEPipelineOutput(images=images)
ProMoE-B-256/scheduler/config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoEFlowMatchScheduler",
3
+ "num_train_timesteps": 1000,
4
+ "shift": 1.0
5
+ }
ProMoE-B-256/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoEFlowMatchScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
ProMoE-B-256/scheduler/scheduling_flow_match_promoe.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from types import SimpleNamespace
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ try:
8
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
9
+ except Exception: # pragma: no cover
10
+ FlowMatchEulerDiscreteScheduler = None
11
+
12
+
13
+ @dataclass
14
+ class ProMoEFlowMatchSchedulerOutput:
15
+ prev_sample: torch.FloatTensor
16
+
17
+
18
+ if FlowMatchEulerDiscreteScheduler is not None:
19
+
20
+ class ProMoEFlowMatchScheduler(FlowMatchEulerDiscreteScheduler):
21
+ pass
22
+
23
+ else:
24
+
25
+ class ProMoEFlowMatchScheduler:
26
+ def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
27
+ self.config = SimpleNamespace(num_train_timesteps=num_train_timesteps, shift=shift, stochastic_sampling=False)
28
+ self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1, dtype=torch.float32)
29
+
30
+ def set_timesteps(self, num_inference_steps: int, device: Optional[torch.device] = None):
31
+ self.timesteps = torch.linspace(
32
+ self.config.num_train_timesteps - 1,
33
+ 0,
34
+ num_inference_steps,
35
+ dtype=torch.float32,
36
+ device=device,
37
+ )
38
+
39
+ def step(self, model_output, timestep, sample, generator=None):
40
+ del generator
41
+ dt = 1.0 / max(len(self.timesteps), 1)
42
+ prev_sample = sample - dt * model_output
43
+ return ProMoEFlowMatchSchedulerOutput(prev_sample=prev_sample)
ProMoE-B-256/transformer/backbone_diffmoe.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .modeling_promoe_common import (
7
+ Attention,
8
+ FinalLayer,
9
+ LabelEmbedder,
10
+ Mlp,
11
+ MoeMLP_DiffMoE as MoeMLP,
12
+ PatchEmbed,
13
+ TimestepEmbedder,
14
+ get_2d_sincos_pos_embed,
15
+ modulate,
16
+ )
17
+
18
+
19
+ class SparseMoEBlock(nn.Module):
20
+ def __init__(
21
+ self,
22
+ experts,
23
+ hidden_dim,
24
+ num_experts,
25
+ n_shared_experts=0,
26
+ capacity=2,
27
+ mlp_ratio=4.0,
28
+ use_diff_expert=False,
29
+ ):
30
+ super().__init__()
31
+ self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim)))
32
+ nn.init.normal_(self.gate_weight, std=0.006)
33
+ self.experts = nn.ModuleList(experts)
34
+ self.capacity = capacity
35
+ self.num_experts = num_experts
36
+ self.n_shared_experts = n_shared_experts
37
+ self.use_diff_expert = use_diff_expert
38
+ if use_diff_expert:
39
+ self.diff_expert = MoeMLP(hidden_size=hidden_dim, intermediate_size=int(hidden_dim * mlp_ratio))
40
+
41
+ self.capacity_predictor = nn.Sequential(
42
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
43
+ nn.SiLU(),
44
+ nn.Linear(hidden_dim, self.num_experts, bias=True),
45
+ )
46
+
47
+ if self.n_shared_experts > 0:
48
+ mlp_hidden_dim = int(hidden_dim * mlp_ratio * 2)
49
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
50
+ self.shared_experts = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
51
+
52
+ self.register_buffer("expert_threshold", torch.tensor([0.0] * num_experts))
53
+ self.register_buffer("ema_decay", torch.tensor([0.95]))
54
+
55
+ def forward(self, x):
56
+ if self.training:
57
+ return self.forward_train(x)
58
+ return self.forward_eval(x)
59
+
60
+ def update_threshold(self, capacity_pred):
61
+ if not self.training:
62
+ return
63
+ capacity_pred = torch.sigmoid(capacity_pred)
64
+ seq_len = capacity_pred.size(0)
65
+ topk = int((seq_len / self.num_experts) * self.capacity)
66
+ threshold = self.expert_threshold
67
+ ema_decay = self.ema_decay
68
+ for i in range(self.num_experts):
69
+ scores, _ = torch.topk(capacity_pred[:, i], k=topk, dim=-1, sorted=True)
70
+ quantile = scores[-1].detach()
71
+ threshold[i] = threshold[i] * ema_decay + (1 - ema_decay) * quantile
72
+ if dist.is_available() and dist.is_initialized():
73
+ dist.all_reduce(threshold, op=dist.ReduceOp.SUM)
74
+ threshold /= dist.get_world_size()
75
+ self.expert_threshold = threshold
76
+
77
+ def forward_train(self, x):
78
+ bsz, seq_len, hidden_dim = x.shape
79
+ identity = x
80
+ x = x.view(-1, hidden_dim)
81
+ total_tokens = x.shape[0]
82
+ capacity_pred = self.capacity_predictor(x.detach())
83
+ k = int((total_tokens / self.num_experts) * self.capacity)
84
+ logits = F.linear(x, self.gate_weight, None)
85
+ scores = logits.softmax(dim=-1).permute(1, 0)
86
+ gating, index = torch.topk(scores, k=k, dim=-1, sorted=False)
87
+ mask = torch.zeros((self.num_experts, total_tokens), dtype=x.dtype, device=x.device)
88
+ mask.scatter_(1, index, 1.0)
89
+ expert_inputs = x[index]
90
+ expert_outputs = torch.stack([expert(expert_inputs[i]) for i, expert in enumerate(self.experts)])
91
+ gated_outputs = gating.unsqueeze(-1) * expert_outputs
92
+
93
+ y = torch.zeros((total_tokens * self.num_experts, hidden_dim), dtype=x.dtype, device=x.device)
94
+ offset = torch.arange(0, self.num_experts, device=x.device).unsqueeze(1) * total_tokens
95
+ flat_index = (index + offset.long()).view(-1)
96
+ y = torch.scatter(y, 0, flat_index.unsqueeze(1).expand(-1, hidden_dim), gated_outputs.view(-1, hidden_dim))
97
+ y = y.view(self.num_experts, total_tokens, hidden_dim).sum(dim=0, keepdim=False)
98
+
99
+ self.update_threshold(capacity_pred)
100
+ x_out = y.view(bsz, seq_len, hidden_dim)
101
+ ones = mask.permute(1, 0).view(bsz, seq_len, self.num_experts)
102
+ capacity_pred = capacity_pred.view(bsz, seq_len, self.num_experts)
103
+ if self.n_shared_experts > 0:
104
+ x_out = x_out + self.shared_experts(identity)
105
+ if self.use_diff_expert:
106
+ x_out = x_out - self.diff_expert(identity)
107
+ return x_out, ones, capacity_pred
108
+
109
+ def forward_eval(self, x):
110
+ bsz, seq_len, hidden_dim = x.shape
111
+ identity = x
112
+ x = x.view(-1, hidden_dim)
113
+ total_tokens = x.shape[0]
114
+ capacity_pred = torch.sigmoid(self.capacity_predictor(x.detach()))
115
+ threshold = self.expert_threshold
116
+ logits = F.linear(x, self.gate_weight, None)
117
+ scores = logits.softmax(dim=-1).permute(-1, -2)
118
+ y = torch.zeros_like(x, dtype=x.dtype)
119
+ for i, expert in enumerate(self.experts):
120
+ k_fixed = torch.where(capacity_pred[:, i] > threshold[i], 1, 0).sum()
121
+ gating, index = torch.topk(scores[i], k=k_fixed, dim=-1, sorted=False)
122
+ y[index, :] += gating.unsqueeze(-1) * expert(x[index, :])
123
+ x_out = y.view(bsz, seq_len, hidden_dim)
124
+ if self.n_shared_experts > 0:
125
+ x_out = x_out + self.shared_experts(identity)
126
+ return x_out, None, None
127
+
128
+
129
+ class DiTBlock(nn.Module):
130
+ def __init__(
131
+ self,
132
+ hidden_size,
133
+ num_heads,
134
+ head_dim=None,
135
+ mlp_ratio=4.0,
136
+ use_swiglu=False,
137
+ MoE_config=None,
138
+ use_moe=False,
139
+ qk_norm=False,
140
+ **block_kwargs,
141
+ ):
142
+ super().__init__()
143
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
144
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=qk_norm, **block_kwargs)
145
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
146
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
147
+ self.use_moe = use_moe
148
+ if use_moe:
149
+ if not use_swiglu:
150
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
151
+ experts = [
152
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
153
+ for _ in range(MoE_config.num_experts)
154
+ ]
155
+ else:
156
+ experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)]
157
+ self.mlp = SparseMoEBlock(
158
+ experts=experts,
159
+ hidden_dim=hidden_size,
160
+ num_experts=MoE_config.num_experts,
161
+ capacity=MoE_config.capacity,
162
+ n_shared_experts=MoE_config.n_shared_experts,
163
+ mlp_ratio=4.0,
164
+ )
165
+ else:
166
+ if not use_swiglu:
167
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
169
+ else:
170
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
171
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
172
+
173
+ def forward(self, x, c):
174
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
175
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
176
+ if self.use_moe:
177
+ x_mlp, ones, pred_c = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
178
+ x = x + gate_mlp.unsqueeze(1) * x_mlp
179
+ return x, ones, pred_c
180
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
181
+ return x, None, None
182
+
183
+
184
+ class DiT(nn.Module):
185
+ def __init__(
186
+ self,
187
+ input_size=32,
188
+ patch_size=2,
189
+ in_channels=4,
190
+ hidden_size=1152,
191
+ depth=28,
192
+ num_heads=16,
193
+ mlp_ratio=4.0,
194
+ qk_norm=False,
195
+ class_dropout_prob=0.1,
196
+ num_classes=1000,
197
+ learn_sigma=True,
198
+ use_swiglu=False,
199
+ MoE_config=None,
200
+ head_dim=None,
201
+ CapacityPred_loss_weight=0.01,
202
+ ):
203
+ super().__init__()
204
+ self.learn_sigma = learn_sigma
205
+ self.in_channels = in_channels
206
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
207
+ self.patch_size = patch_size
208
+ self.num_heads = num_heads
209
+ self.MoE_config = MoE_config
210
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
211
+ self.CapacityPred_loss_weight = CapacityPred_loss_weight
212
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
213
+ self.t_embedder = TimestepEmbedder(hidden_size)
214
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
215
+ num_patches = self.x_embedder.num_patches
216
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
217
+ self.blocks = nn.ModuleList(
218
+ [
219
+ DiTBlock(
220
+ hidden_size,
221
+ num_heads,
222
+ head_dim=head_dim,
223
+ mlp_ratio=mlp_ratio,
224
+ qk_norm=qk_norm,
225
+ use_swiglu=use_swiglu,
226
+ MoE_config=MoE_config,
227
+ use_moe=use_moe_flag[i],
228
+ )
229
+ for i in range(depth)
230
+ ]
231
+ )
232
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
233
+ self.init_MoeMLP = MoE_config.init_MoeMLP
234
+ self.initialize_weights()
235
+ self.capacity_schedule = MoE_config.get("capacity_schedule", None)
236
+ if self.capacity_schedule:
237
+ self.training_iters = -1
238
+
239
+ def initialize_weights(self):
240
+ def _basic_init(module):
241
+ if isinstance(module, nn.Linear):
242
+ torch.nn.init.xavier_uniform_(module.weight)
243
+ if module.bias is not None:
244
+ nn.init.constant_(module.bias, 0)
245
+
246
+ self.apply(_basic_init)
247
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
248
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
249
+ w = self.x_embedder.proj.weight.data
250
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
251
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
252
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
253
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
254
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
255
+ for block in self.blocks:
256
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
257
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
258
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
259
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
260
+ nn.init.constant_(self.final_layer.linear.weight, 0)
261
+ nn.init.constant_(self.final_layer.linear.bias, 0)
262
+
263
+ def unpatchify(self, x):
264
+ c = self.out_channels
265
+ p = self.x_embedder.patch_size[0]
266
+ h = w = int(x.shape[1] ** 0.5)
267
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
268
+ x = torch.einsum("nhwpqc->nchpwq", x)
269
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
270
+
271
+ def forward(self, x, t, context, **kwargs):
272
+ y = context
273
+ if len(x.shape) != 4:
274
+ x = x.squeeze(2)
275
+
276
+ if self.training and self.capacity_schedule:
277
+ num_experts = self.MoE_config.num_experts
278
+ capacity = self.MoE_config.capacity
279
+ stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters
280
+ stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters
281
+ if self.training_iters <= stage_i:
282
+ capacity = num_experts
283
+ elif self.training_iters <= stage_ii:
284
+ capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i)
285
+ for block in self.blocks:
286
+ if hasattr(block.mlp, "capacity"):
287
+ block.mlp.capacity = capacity
288
+
289
+ x = self.x_embedder(x) + self.pos_embed
290
+ t = self.t_embedder(t)
291
+ y = self.y_embedder(y, self.training)
292
+ c = t + y
293
+ ones_list, pred_c_list, layer_idx_list = [], [], []
294
+ for layer_idx, block in enumerate(self.blocks):
295
+ x, ones, pred_c = block(x, c)
296
+ if ones is not None:
297
+ ones_list.append(ones)
298
+ pred_c_list.append(pred_c)
299
+ layer_idx_list.append(layer_idx)
300
+ x = self.final_layer(x, c)
301
+ x = self.unpatchify(x)
302
+ return x, "Capacity_Pred", layer_idx_list, ones_list, pred_c_list, self.CapacityPred_loss_weight
ProMoE-B-256/transformer/backbone_dit.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .modeling_promoe_common import (
5
+ Attention,
6
+ FinalLayer,
7
+ LabelEmbedder,
8
+ Mlp,
9
+ PatchEmbed,
10
+ TimestepEmbedder,
11
+ get_2d_sincos_pos_embed,
12
+ modulate,
13
+ )
14
+
15
+
16
+ class DiTBlock(nn.Module):
17
+ def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs):
18
+ super().__init__()
19
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
20
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
21
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
22
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
23
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
24
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
25
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
26
+
27
+ def forward(self, x, c):
28
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
29
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
30
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
31
+ return x
32
+
33
+
34
+ class DiT(nn.Module):
35
+ def __init__(
36
+ self,
37
+ input_size=32,
38
+ patch_size=2,
39
+ in_channels=4,
40
+ hidden_size=1152,
41
+ depth=28,
42
+ num_heads=16,
43
+ mlp_ratio=4.0,
44
+ qk_norm=False,
45
+ class_dropout_prob=0.1,
46
+ num_classes=1000,
47
+ learn_sigma=True,
48
+ head_dim=None,
49
+ use_swiglu=False,
50
+ ):
51
+ super().__init__()
52
+ self.learn_sigma = learn_sigma
53
+ self.in_channels = in_channels
54
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
55
+ self.patch_size = patch_size
56
+ self.num_heads = num_heads
57
+
58
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
59
+ self.t_embedder = TimestepEmbedder(hidden_size)
60
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
61
+ num_patches = self.x_embedder.num_patches
62
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
63
+
64
+ self.blocks = nn.ModuleList(
65
+ [
66
+ DiTBlock(
67
+ hidden_size,
68
+ num_heads,
69
+ head_dim=head_dim,
70
+ mlp_ratio=mlp_ratio,
71
+ qk_norm=qk_norm,
72
+ use_swiglu=use_swiglu,
73
+ )
74
+ for _ in range(depth)
75
+ ]
76
+ )
77
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
78
+ self.initialize_weights()
79
+
80
+ def initialize_weights(self):
81
+ def _basic_init(module):
82
+ if isinstance(module, nn.Linear):
83
+ torch.nn.init.xavier_uniform_(module.weight)
84
+ if module.bias is not None:
85
+ nn.init.constant_(module.bias, 0)
86
+
87
+ self.apply(_basic_init)
88
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
89
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
90
+ w = self.x_embedder.proj.weight.data
91
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
92
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
93
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
94
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
95
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
96
+ for block in self.blocks:
97
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
98
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
99
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
100
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
101
+ nn.init.constant_(self.final_layer.linear.weight, 0)
102
+ nn.init.constant_(self.final_layer.linear.bias, 0)
103
+
104
+ def unpatchify(self, x):
105
+ c = self.out_channels
106
+ p = self.x_embedder.patch_size[0]
107
+ h = w = int(x.shape[1] ** 0.5)
108
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
109
+ x = torch.einsum("nhwpqc->nchpwq", x)
110
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
111
+
112
+ def forward(self, x, t, context, **kwargs):
113
+ y = context
114
+ if len(x.shape) != 4:
115
+ x = x.squeeze(2)
116
+ x = self.x_embedder(x) + self.pos_embed
117
+ t = self.t_embedder(t)
118
+ y = self.y_embedder(y, self.training)
119
+ c = t + y
120
+ for block in self.blocks:
121
+ x = block(x, c)
122
+ x = self.final_layer(x, c)
123
+ return self.unpatchify(x)
ProMoE-B-256/transformer/backbone_ecdit.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP_DiffMoE as MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class SparseMoEBlock(nn.Module):
19
+ def __init__(self, experts, hidden_dim, num_experts, n_shared_experts=0, capacity=2):
20
+ super().__init__()
21
+ self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim)))
22
+ nn.init.normal_(self.gate_weight, std=0.006)
23
+ self.experts = nn.ModuleList(experts)
24
+ self.capacity = capacity
25
+ self.num_experts = num_experts
26
+ self.n_shared_experts = n_shared_experts
27
+ if self.n_shared_experts > 0:
28
+ intermediate_size = hidden_dim * self.n_shared_experts
29
+ self.shared_experts = MoeMLP(hidden_size=hidden_dim, intermediate_size=intermediate_size, pretraining_tp=2)
30
+
31
+ def forward(self, x):
32
+ identity = x
33
+ batch_size, seq_len, _ = x.shape
34
+ logits = F.linear(x, self.gate_weight, None)
35
+ affinity = logits.softmax(dim=-1)
36
+ affinity = torch.einsum("b s e -> b e s", affinity)
37
+ k = int((seq_len / self.num_experts) * self.capacity)
38
+ gating, index = torch.topk(affinity, k=k, dim=-1, sorted=False)
39
+ dispatch = F.one_hot(index, num_classes=seq_len).to(device=x.device, dtype=x.dtype)
40
+ x_in = torch.einsum("b e c s, b s d -> b e c d", dispatch, x)
41
+ x_e = [self.experts[e](x_in[:, e]) for e in range(self.num_experts)]
42
+ x_e = torch.stack(x_e, dim=1)
43
+ x_out = torch.einsum("b e c s, b e c, b e c d -> b s d", dispatch, gating, x_e)
44
+ if self.n_shared_experts > 0:
45
+ x_out = x_out + self.shared_experts(identity)
46
+ return x_out
47
+
48
+
49
+ class DiTBlock(nn.Module):
50
+ def __init__(
51
+ self,
52
+ hidden_size,
53
+ num_heads,
54
+ head_dim=None,
55
+ mlp_ratio=4.0,
56
+ use_swiglu=False,
57
+ MoE_config=None,
58
+ use_moe=False,
59
+ **block_kwargs,
60
+ ):
61
+ super().__init__()
62
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
63
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
64
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
65
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
66
+ if use_moe:
67
+ if not use_swiglu:
68
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
69
+ experts = [
70
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
71
+ for _ in range(MoE_config.num_experts)
72
+ ]
73
+ else:
74
+ experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)]
75
+ self.mlp = SparseMoEBlock(
76
+ experts=experts,
77
+ hidden_dim=hidden_size,
78
+ num_experts=MoE_config.num_experts,
79
+ capacity=MoE_config.capacity,
80
+ n_shared_experts=MoE_config.n_shared_experts,
81
+ )
82
+ else:
83
+ if not use_swiglu:
84
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
85
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
86
+ else:
87
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
88
+
89
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
90
+
91
+ def forward(self, x, c):
92
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
93
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
94
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
95
+ return x
96
+
97
+
98
+ class DiT(nn.Module):
99
+ def __init__(
100
+ self,
101
+ input_size=32,
102
+ patch_size=2,
103
+ in_channels=4,
104
+ hidden_size=1152,
105
+ depth=28,
106
+ num_heads=16,
107
+ mlp_ratio=4.0,
108
+ qk_norm=False,
109
+ class_dropout_prob=0.1,
110
+ num_classes=1000,
111
+ learn_sigma=True,
112
+ use_swiglu=False,
113
+ MoE_config=None,
114
+ head_dim=None,
115
+ ):
116
+ super().__init__()
117
+ self.learn_sigma = learn_sigma
118
+ self.in_channels = in_channels
119
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
120
+ self.patch_size = patch_size
121
+ self.num_heads = num_heads
122
+ self.MoE_config = MoE_config
123
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
124
+
125
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
126
+ self.t_embedder = TimestepEmbedder(hidden_size)
127
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
128
+ num_patches = self.x_embedder.num_patches
129
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
130
+ self.blocks = nn.ModuleList(
131
+ [
132
+ DiTBlock(
133
+ hidden_size,
134
+ num_heads,
135
+ head_dim=head_dim,
136
+ mlp_ratio=mlp_ratio,
137
+ qk_norm=qk_norm,
138
+ use_swiglu=use_swiglu,
139
+ MoE_config=MoE_config,
140
+ use_moe=use_moe_flag[i],
141
+ )
142
+ for i in range(depth)
143
+ ]
144
+ )
145
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
146
+ self.init_MoeMLP = MoE_config.init_MoeMLP
147
+ self.initialize_weights()
148
+ self.capacity_schedule = MoE_config.get("capacity_schedule", None)
149
+ if self.capacity_schedule:
150
+ self.training_iters = -1
151
+
152
+ def initialize_weights(self):
153
+ def _basic_init(module):
154
+ if isinstance(module, nn.Linear):
155
+ torch.nn.init.xavier_uniform_(module.weight)
156
+ if module.bias is not None:
157
+ nn.init.constant_(module.bias, 0)
158
+
159
+ self.apply(_basic_init)
160
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
161
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
162
+ w = self.x_embedder.proj.weight.data
163
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
164
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
165
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
166
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
167
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
168
+ for block in self.blocks:
169
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
170
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
171
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
172
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
173
+ nn.init.constant_(self.final_layer.linear.weight, 0)
174
+ nn.init.constant_(self.final_layer.linear.bias, 0)
175
+
176
+ def init_moe_mlp(module, std=0.006):
177
+ nn.init.normal_(module.gate_proj.weight, std=std)
178
+ nn.init.normal_(module.up_proj.weight, std=std)
179
+ nn.init.normal_(module.down_proj.weight, std=std)
180
+
181
+ if self.init_MoeMLP:
182
+ for block in self.blocks:
183
+ if hasattr(block.mlp, "experts"):
184
+ for expert in block.mlp.experts:
185
+ if hasattr(expert, "gate_proj"):
186
+ init_moe_mlp(expert)
187
+
188
+ def unpatchify(self, x):
189
+ c = self.out_channels
190
+ p = self.x_embedder.patch_size[0]
191
+ h = w = int(x.shape[1] ** 0.5)
192
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
193
+ x = torch.einsum("nhwpqc->nchpwq", x)
194
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
195
+
196
+ def forward(self, x, t, context, **kwargs):
197
+ y = context
198
+ if len(x.shape) != 4:
199
+ x = x.squeeze(2)
200
+ if self.training and self.capacity_schedule:
201
+ num_experts = self.MoE_config.num_experts
202
+ capacity = self.MoE_config.capacity
203
+ stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters
204
+ stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters
205
+ if self.training_iters <= stage_i:
206
+ capacity = num_experts
207
+ elif self.training_iters <= stage_ii:
208
+ capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i)
209
+ for block in self.blocks:
210
+ if hasattr(block.mlp, "capacity"):
211
+ block.mlp.capacity = capacity
212
+
213
+ x = self.x_embedder(x) + self.pos_embed
214
+ t = self.t_embedder(t)
215
+ y = self.y_embedder(y, self.training)
216
+ c = t + y
217
+ for block in self.blocks:
218
+ x = block(x, c)
219
+ x = self.final_layer(x, c)
220
+ return self.unpatchify(x)
ProMoE-B-256/transformer/backbone_promoe_ec.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class AddAuxiliaryLoss(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, x, loss):
21
+ ctx.dtype = loss.dtype
22
+ ctx.required_aux_loss = loss.requires_grad
23
+ return x
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output):
27
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
28
+ return grad_output, grad_loss
29
+
30
+
31
+ class SparseMoeBlock(nn.Module):
32
+ def __init__(
33
+ self,
34
+ num_routed_experts,
35
+ hidden_size,
36
+ moe_intermediate_size,
37
+ shared_expert_intermediate_size,
38
+ top_k=1,
39
+ load_balance_loss_coef=0,
40
+ norm_topk_prob=False,
41
+ seq_aux=False,
42
+ use_shared_expert=True,
43
+ use_uncond_expert=True,
44
+ router_weight_mode="softmax",
45
+ routing_contrastive_lam=0,
46
+ use_top_k_for_routing_contrastive=False,
47
+ routing_contrastive_temperature=0.1,
48
+ **kwargs,
49
+ ):
50
+ super().__init__()
51
+ del load_balance_loss_coef, norm_topk_prob, seq_aux, use_top_k_for_routing_contrastive
52
+ self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts
53
+ self.num_routed_experts = num_routed_experts
54
+ self.hidden_size = hidden_size
55
+ self.top_k = top_k
56
+ self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size))
57
+ self.use_shared_expert = use_shared_expert
58
+ self.use_uncond_expert = use_uncond_expert
59
+ self.router_weight_mode = router_weight_mode
60
+ self.routing_contrastive_lam = routing_contrastive_lam
61
+ self.routing_contrastive_temperature = routing_contrastive_temperature
62
+ self.experts = nn.ModuleList(
63
+ [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)]
64
+ )
65
+ if use_shared_expert:
66
+ self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size)
67
+ self._init_weights()
68
+
69
+ def compute_router(self, cond_hidden_states):
70
+ b_cond, seq_len, _ = cond_hidden_states.shape
71
+ num_cond_experts = self.num_routed_experts
72
+ input_norm = F.normalize(cond_hidden_states, p=2, dim=-1)
73
+ cluster_norm = F.normalize(self.cluster_centers, p=2, dim=-1)
74
+ cos_sim = input_norm @ cluster_norm.T
75
+ cos_sim_expert_view = cos_sim.transpose(1, 2)
76
+ if self.router_weight_mode == "softmax":
77
+ cond_weights = F.softmax(cos_sim_expert_view, dim=-1)
78
+ elif self.router_weight_mode == "sigmoid":
79
+ cond_weights = torch.sigmoid(cos_sim_expert_view)
80
+ elif self.router_weight_mode == "identity":
81
+ cond_weights = cos_sim_expert_view
82
+ else:
83
+ raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}")
84
+ k = max(1, min(int((seq_len / num_cond_experts) * self.top_k), seq_len))
85
+ router_weights, indices = torch.topk(cond_weights, k=k, dim=-1, sorted=False)
86
+ dispatch_mask = F.one_hot(indices, num_classes=seq_len).to(dtype=cond_hidden_states.dtype)
87
+ expert_inputs = torch.einsum("becs,bsd->becd", dispatch_mask, cond_hidden_states)
88
+ return dispatch_mask, router_weights, expert_inputs
89
+
90
+ def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor):
91
+ identity = hidden_states
92
+ batch_size, _, hidden_dim = hidden_states.shape
93
+ final_output = torch.zeros_like(hidden_states)
94
+ loss = None
95
+ cond_batch_mask = (
96
+ labels.view(-1) != 1000
97
+ ) if self.use_uncond_expert else torch.ones(batch_size, dtype=torch.bool, device=hidden_states.device)
98
+ uncond_batch_mask = ~cond_batch_mask
99
+ cond_experts = self.experts[:-1] if self.use_uncond_expert else self.experts
100
+
101
+ if cond_batch_mask.any():
102
+ cond_hidden_states = hidden_states[cond_batch_mask]
103
+ dispatch_mask, gating_scores, expert_inputs = self.compute_router(cond_hidden_states)
104
+ num_cond_experts = len(cond_experts)
105
+ expert_outputs = torch.stack([cond_experts[e](expert_inputs[:, e]) for e in range(num_cond_experts)], dim=1)
106
+ cond_output = torch.einsum("becs,bec,becd->bsd", dispatch_mask, gating_scores, expert_outputs).to(hidden_states.dtype)
107
+ final_output[cond_batch_mask] = cond_output
108
+ if self.training and self.routing_contrastive_lam > 0 and num_cond_experts > 1:
109
+ expert_token_means = expert_inputs.mean(dim=2)
110
+ routing_contrastive_loss = self.compute_routing_contrastive_loss(expert_token_means)
111
+ loss = routing_contrastive_loss * self.routing_contrastive_lam
112
+ else:
113
+ dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
114
+ for expert in cond_experts:
115
+ final_output = final_output + expert(dummy_input).sum() * 0
116
+
117
+ if self.use_uncond_expert:
118
+ if uncond_batch_mask.any():
119
+ uncond_hidden_states = hidden_states[uncond_batch_mask]
120
+ final_output[uncond_batch_mask] = self.experts[-1](uncond_hidden_states).to(final_output.dtype)
121
+ else:
122
+ dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
123
+ final_output = final_output + self.experts[-1](dummy_input).sum() * 0
124
+
125
+ if self.use_shared_expert:
126
+ final_output += self.shared_expert(identity).to(hidden_states.dtype)
127
+ return final_output, loss
128
+
129
+ def compute_routing_contrastive_loss(self, expert_token_means):
130
+ batch_size, num_cond_experts, _ = expert_token_means.shape
131
+ if num_cond_experts < 2:
132
+ return torch.tensor(0.0, device=expert_token_means.device)
133
+ centers_norm = F.normalize(self.cluster_centers, p=2, dim=1)
134
+ means_norm = F.normalize(expert_token_means, p=2, dim=2)
135
+ sim_matrix = torch.einsum("id,bjd->bij", centers_norm, means_norm)
136
+ logits = sim_matrix / self.routing_contrastive_temperature
137
+ labels = torch.arange(num_cond_experts, device=logits.device).unsqueeze(0).expand(batch_size, -1)
138
+ return F.cross_entropy(logits.reshape(batch_size * num_cond_experts, -1), labels.reshape(-1))
139
+
140
+ def _init_weights(self):
141
+ nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02)
142
+
143
+
144
+ class DiTBlock(nn.Module):
145
+ def __init__(
146
+ self,
147
+ hidden_size,
148
+ num_heads,
149
+ head_dim=None,
150
+ mlp_ratio=4.0,
151
+ use_swiglu=False,
152
+ MoE_config=None,
153
+ use_moe=False,
154
+ **block_kwargs,
155
+ ):
156
+ super().__init__()
157
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
158
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
159
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
160
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
161
+ self.use_moe = use_moe
162
+ if use_moe:
163
+ self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config)
164
+ else:
165
+ if not use_swiglu:
166
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
167
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
168
+ else:
169
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
170
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
171
+
172
+ def forward(self, x, c, label):
173
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
174
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
175
+ if self.use_moe:
176
+ x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label)
177
+ if aux_loss is not None:
178
+ x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss)
179
+ return x + gate_mlp.unsqueeze(1) * x_mlp
180
+ return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
181
+
182
+
183
+ class DiT(nn.Module):
184
+ def __init__(
185
+ self,
186
+ input_size=32,
187
+ patch_size=2,
188
+ in_channels=4,
189
+ hidden_size=1152,
190
+ depth=28,
191
+ num_heads=16,
192
+ mlp_ratio=4.0,
193
+ qk_norm=False,
194
+ class_dropout_prob=0.1,
195
+ num_classes=1000,
196
+ learn_sigma=True,
197
+ use_swiglu=False,
198
+ MoE_config=None,
199
+ head_dim=None,
200
+ ):
201
+ super().__init__()
202
+ self.learn_sigma = learn_sigma
203
+ self.in_channels = in_channels
204
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
205
+ self.patch_size = patch_size
206
+ self.num_heads = num_heads
207
+ self.MoE_config = MoE_config
208
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
209
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
210
+ self.t_embedder = TimestepEmbedder(hidden_size)
211
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True)
212
+ num_patches = self.x_embedder.num_patches
213
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
214
+ self.blocks = nn.ModuleList(
215
+ [
216
+ DiTBlock(
217
+ hidden_size,
218
+ num_heads,
219
+ head_dim=head_dim,
220
+ mlp_ratio=mlp_ratio,
221
+ qk_norm=qk_norm,
222
+ use_swiglu=use_swiglu,
223
+ MoE_config=MoE_config,
224
+ use_moe=use_moe_flag[i],
225
+ )
226
+ for i in range(depth)
227
+ ]
228
+ )
229
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
230
+ self.init_MoeMLP = MoE_config.init_MoeMLP
231
+ self.initialize_weights()
232
+
233
+ def initialize_weights(self):
234
+ def _basic_init(module):
235
+ if isinstance(module, nn.Linear):
236
+ torch.nn.init.xavier_uniform_(module.weight)
237
+ if module.bias is not None:
238
+ nn.init.constant_(module.bias, 0)
239
+
240
+ self.apply(_basic_init)
241
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
242
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
243
+ w = self.x_embedder.proj.weight.data
244
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
245
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
246
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
247
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
248
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
249
+ for block in self.blocks:
250
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
251
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
252
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
253
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
254
+ nn.init.constant_(self.final_layer.linear.weight, 0)
255
+ nn.init.constant_(self.final_layer.linear.bias, 0)
256
+
257
+ def init_moe_mlp(module, std=0.006):
258
+ nn.init.normal_(module.up_proj.weight, std=std)
259
+ nn.init.normal_(module.down_proj.weight, std=std)
260
+
261
+ if self.init_MoeMLP:
262
+ for block in self.blocks:
263
+ if hasattr(block.mlp, "experts"):
264
+ for expert in block.mlp.experts:
265
+ init_moe_mlp(expert)
266
+
267
+ def unpatchify(self, x):
268
+ c = self.out_channels
269
+ p = self.x_embedder.patch_size[0]
270
+ h = w = int(x.shape[1] ** 0.5)
271
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
272
+ x = torch.einsum("nhwpqc->nchpwq", x)
273
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
274
+
275
+ def forward(self, x, timestep, context, **kwargs):
276
+ y = context
277
+ if len(x.shape) != 4:
278
+ x = x.squeeze(2)
279
+ x = self.x_embedder(x) + self.pos_embed
280
+ t = self.t_embedder(timestep)
281
+ y, labels = self.y_embedder(y, self.training)
282
+ c = t + y
283
+ for block in self.blocks:
284
+ x = block(x, c, labels)
285
+ x = self.final_layer(x, c)
286
+ return self.unpatchify(x)
ProMoE-B-256/transformer/backbone_promoe_tc.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class AddAuxiliaryLoss(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, x, loss):
21
+ ctx.dtype = loss.dtype
22
+ ctx.required_aux_loss = loss.requires_grad
23
+ return x
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output):
27
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
28
+ return grad_output, grad_loss
29
+
30
+
31
+ class SparseMoeBlock(nn.Module):
32
+ def __init__(
33
+ self,
34
+ num_routed_experts,
35
+ hidden_size,
36
+ moe_intermediate_size,
37
+ shared_expert_intermediate_size,
38
+ top_k=2,
39
+ load_balance_loss_coef=0,
40
+ norm_topk_prob=False,
41
+ seq_aux=False,
42
+ use_shared_expert=True,
43
+ use_uncond_expert=True,
44
+ router_weight_mode="softmax",
45
+ routing_contrastive_lam=0,
46
+ use_top_k_for_routing_contrastive=False,
47
+ routing_contrastive_temperature=0.1,
48
+ **kwargs,
49
+ ):
50
+ super().__init__()
51
+ del norm_topk_prob
52
+ self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts
53
+ self.num_routed_experts = num_routed_experts
54
+ self.seq_aux = seq_aux
55
+ self.hidden_size = hidden_size
56
+ self.top_k = top_k
57
+ self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size))
58
+ self.alpha = load_balance_loss_coef
59
+ self.use_shared_expert = use_shared_expert
60
+ self.use_uncond_expert = use_uncond_expert
61
+ self.router_weight_mode = router_weight_mode
62
+ self.routing_contrastive_lam = routing_contrastive_lam
63
+ self.use_top_k_for_routing_contrastive = use_top_k_for_routing_contrastive
64
+ self.routing_contrastive_temperature = routing_contrastive_temperature
65
+ self.experts = nn.ModuleList(
66
+ [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)]
67
+ )
68
+ if use_shared_expert:
69
+ self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size)
70
+ self._init_weights()
71
+
72
+ def compute_router(self, hidden_states, labels):
73
+ batch_size, seq_len, _ = hidden_states.shape
74
+ device = hidden_states.device
75
+ flat_input = hidden_states.view(-1, self.hidden_size)
76
+ flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1)
77
+ if self.use_uncond_expert and flat_labels is not None:
78
+ uncond_mask = flat_labels == 1000
79
+ cond_mask = ~uncond_mask
80
+ else:
81
+ uncond_mask = None
82
+ cond_mask = torch.ones_like(flat_labels, dtype=torch.bool)
83
+
84
+ router_weights = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=hidden_states.dtype)
85
+ expert_indices = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=torch.long)
86
+
87
+ if uncond_mask is not None and uncond_mask.any():
88
+ uncond_positions = torch.where(uncond_mask)[0]
89
+ router_weights[uncond_positions, 0] = 1.0
90
+ expert_indices[uncond_positions] = self.num_experts - 1
91
+
92
+ cond_weights = None
93
+ topk_idx = None
94
+ if cond_mask.any():
95
+ cond_positions = torch.where(cond_mask)[0]
96
+ cond_input = flat_input[cond_positions]
97
+ input_norm = F.normalize(cond_input, p=2, dim=1)
98
+ cluster_norm = F.normalize(self.cluster_centers, p=2, dim=1)
99
+ cos_sim = input_norm @ cluster_norm.T
100
+ if self.router_weight_mode == "softmax":
101
+ cond_weights = F.softmax(cos_sim, dim=1)
102
+ elif self.router_weight_mode == "sigmoid":
103
+ cond_weights = torch.sigmoid(cos_sim)
104
+ elif self.router_weight_mode == "identity":
105
+ cond_weights = cos_sim
106
+ else:
107
+ raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}")
108
+ topk_scores, topk_idx = torch.topk(cond_weights, k=self.top_k, dim=1)
109
+ router_weights[cond_positions] = topk_scores.to(router_weights.dtype)
110
+ expert_indices[cond_positions] = topk_idx
111
+
112
+ router_weights = router_weights.view(batch_size, seq_len, self.top_k)
113
+ expert_indices = expert_indices.view(batch_size, seq_len, self.top_k)
114
+
115
+ load_balance_loss = None
116
+ if self.training and self.alpha > 0.0 and cond_weights is not None and topk_idx is not None:
117
+ cond_batch_size = (labels != 1000).sum()
118
+ scores_for_aux = F.softmax(cond_weights, dim=1) if self.router_weight_mode != "softmax" else cond_weights
119
+ topk_idx_for_aux_loss = topk_idx.view(cond_batch_size, -1)
120
+ if self.seq_aux:
121
+ scores_for_seq_aux = scores_for_aux.view(cond_batch_size, seq_len, -1)
122
+ ce = torch.zeros(cond_batch_size, self.num_routed_experts, device=hidden_states.device)
123
+ ce.scatter_add_(
124
+ 1,
125
+ topk_idx_for_aux_loss,
126
+ torch.ones(cond_batch_size, seq_len * self.top_k, device=hidden_states.device),
127
+ ).div_(seq_len * self.top_k / self.num_routed_experts)
128
+ load_balance_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
129
+ else:
130
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.num_routed_experts)
131
+ ce = mask_ce.float().mean(0)
132
+ pi = scores_for_aux.mean(0)
133
+ fi = ce * self.num_routed_experts
134
+ load_balance_loss = (pi * fi).sum() * self.alpha
135
+ return router_weights, expert_indices, load_balance_loss
136
+
137
+ def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor):
138
+ router_weights, expert_indices, load_balance_loss = self.compute_router(hidden_states, labels)
139
+ batch_size, seq_len, hidden_dim = hidden_states.shape
140
+ flat_input = hidden_states.view(-1, hidden_dim)
141
+ flat_weights = router_weights.view(-1, self.top_k)
142
+ flat_indices = expert_indices.view(-1, self.top_k)
143
+ total_tokens = batch_size * seq_len
144
+ final_output = torch.zeros(total_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
145
+
146
+ for expert_id in range(self.num_experts):
147
+ expert_mask = (flat_indices == expert_id).any(dim=1)
148
+ token_ids = torch.where(expert_mask)[0]
149
+ if token_ids.numel() > 0:
150
+ expert_input = flat_input[token_ids]
151
+ expert_weight_mask = flat_indices[token_ids] == expert_id
152
+ expert_weights = flat_weights[token_ids] * expert_weight_mask.to(dtype=flat_weights.dtype)
153
+ combined_weights = expert_weights.sum(dim=1)
154
+ expert_output = self.experts[expert_id](expert_input)
155
+ weighted_output = expert_output * combined_weights.unsqueeze(1)
156
+ final_output.index_add_(0, token_ids, weighted_output)
157
+ else:
158
+ dummy_input = torch.zeros(1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
159
+ final_output[0] += self.experts[expert_id](dummy_input)[0] * 0
160
+
161
+ final_output = final_output.view(batch_size, seq_len, hidden_dim)
162
+ if self.use_shared_expert:
163
+ final_output += self.shared_expert(hidden_states)
164
+
165
+ loss = load_balance_loss
166
+ if self.training and self.routing_contrastive_lam > 0:
167
+ flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1)
168
+ cond_mask = ~(
169
+ flat_labels == 1000
170
+ ) if self.use_uncond_expert else torch.ones(batch_size * seq_len, dtype=torch.bool, device=hidden_states.device)
171
+ cond_token_embeddings = flat_input[cond_mask]
172
+ if self.use_top_k_for_routing_contrastive:
173
+ cond_cluster_assignments = expert_indices.view(batch_size * seq_len, self.top_k)[cond_mask]
174
+ else:
175
+ top1_expert_indices = expert_indices.view(batch_size * seq_len, self.top_k)[:, 0]
176
+ cond_cluster_assignments = top1_expert_indices[cond_mask]
177
+ routing_contrastive_loss = self.compute_routing_contrastive_loss(
178
+ cond_token_embeddings,
179
+ cond_cluster_assignments,
180
+ use_top_k=self.use_top_k_for_routing_contrastive,
181
+ )
182
+ routing_contrastive_loss = routing_contrastive_loss * self.routing_contrastive_lam
183
+ loss = routing_contrastive_loss if loss is None else loss + routing_contrastive_loss
184
+
185
+ return final_output, loss
186
+
187
+ def compute_routing_contrastive_loss(self, token_embeddings, cluster_assignments, use_top_k=False):
188
+ cluster_centers = self.cluster_centers
189
+ num_clusters = cluster_centers.size(0)
190
+ device = cluster_centers.device
191
+ cluster_means = []
192
+ valid_clusters = []
193
+ for cluster_id in range(num_clusters):
194
+ mask = (cluster_assignments == cluster_id).any(dim=1) if use_top_k else cluster_assignments == cluster_id
195
+ if mask.sum() > 0:
196
+ cluster_means.append(token_embeddings[mask].mean(dim=0, keepdim=True))
197
+ valid_clusters.append(cluster_id)
198
+ if len(valid_clusters) < 2:
199
+ return torch.tensor(0.0, device=device)
200
+ cluster_means = torch.cat(cluster_means, dim=0)
201
+ valid_centers = cluster_centers[valid_clusters]
202
+ centers_norm = F.normalize(valid_centers, p=2, dim=1)
203
+ means_norm = F.normalize(cluster_means, p=2, dim=1)
204
+ sim_matrix = centers_norm @ means_norm.T
205
+ logits = sim_matrix / self.routing_contrastive_temperature
206
+ labels = torch.arange(sim_matrix.size(0), device=device)
207
+ return F.cross_entropy(logits, labels)
208
+
209
+ def _init_weights(self):
210
+ nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02)
211
+
212
+
213
+ class DiTBlock(nn.Module):
214
+ def __init__(
215
+ self,
216
+ hidden_size,
217
+ num_heads,
218
+ head_dim=None,
219
+ mlp_ratio=4.0,
220
+ use_swiglu=False,
221
+ MoE_config=None,
222
+ use_moe=False,
223
+ **block_kwargs,
224
+ ):
225
+ super().__init__()
226
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
227
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
228
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
229
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
230
+ self.use_moe = use_moe
231
+ if use_moe:
232
+ self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config)
233
+ else:
234
+ if not use_swiglu:
235
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
236
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
237
+ else:
238
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
239
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
240
+
241
+ def forward(self, x, c, label):
242
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
243
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
244
+ if self.use_moe:
245
+ x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label)
246
+ if aux_loss is not None:
247
+ x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss)
248
+ return x + gate_mlp.unsqueeze(1) * x_mlp
249
+ return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
250
+
251
+
252
+ class DiT(nn.Module):
253
+ def __init__(
254
+ self,
255
+ input_size=32,
256
+ patch_size=2,
257
+ in_channels=4,
258
+ hidden_size=1152,
259
+ depth=28,
260
+ num_heads=16,
261
+ mlp_ratio=4.0,
262
+ qk_norm=False,
263
+ class_dropout_prob=0.1,
264
+ num_classes=1000,
265
+ learn_sigma=True,
266
+ use_swiglu=False,
267
+ MoE_config=None,
268
+ head_dim=None,
269
+ ):
270
+ super().__init__()
271
+ self.learn_sigma = learn_sigma
272
+ self.in_channels = in_channels
273
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
274
+ self.patch_size = patch_size
275
+ self.num_heads = num_heads
276
+ self.MoE_config = MoE_config
277
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
278
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
279
+ self.t_embedder = TimestepEmbedder(hidden_size)
280
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True)
281
+ num_patches = self.x_embedder.num_patches
282
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
283
+ self.blocks = nn.ModuleList(
284
+ [
285
+ DiTBlock(
286
+ hidden_size,
287
+ num_heads,
288
+ head_dim=head_dim,
289
+ mlp_ratio=mlp_ratio,
290
+ qk_norm=qk_norm,
291
+ use_swiglu=use_swiglu,
292
+ MoE_config=MoE_config,
293
+ use_moe=use_moe_flag[i],
294
+ )
295
+ for i in range(depth)
296
+ ]
297
+ )
298
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
299
+ self.init_MoeMLP = MoE_config.init_MoeMLP
300
+ self.initialize_weights()
301
+
302
+ def initialize_weights(self):
303
+ def _basic_init(module):
304
+ if isinstance(module, nn.Linear):
305
+ torch.nn.init.xavier_uniform_(module.weight)
306
+ if module.bias is not None:
307
+ nn.init.constant_(module.bias, 0)
308
+
309
+ self.apply(_basic_init)
310
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
311
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
312
+ w = self.x_embedder.proj.weight.data
313
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
314
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
315
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
316
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
317
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
318
+ for block in self.blocks:
319
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
320
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
321
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
322
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
323
+ nn.init.constant_(self.final_layer.linear.weight, 0)
324
+ nn.init.constant_(self.final_layer.linear.bias, 0)
325
+
326
+ def init_moe_mlp(module, std=0.006):
327
+ nn.init.normal_(module.up_proj.weight, std=std)
328
+ nn.init.normal_(module.down_proj.weight, std=std)
329
+
330
+ if self.init_MoeMLP:
331
+ for block in self.blocks:
332
+ if hasattr(block.mlp, "experts"):
333
+ for expert in block.mlp.experts:
334
+ init_moe_mlp(expert)
335
+
336
+ def unpatchify(self, x):
337
+ c = self.out_channels
338
+ p = self.x_embedder.patch_size[0]
339
+ h = w = int(x.shape[1] ** 0.5)
340
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
341
+ x = torch.einsum("nhwpqc->nchpwq", x)
342
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
343
+
344
+ def forward(self, x, timestep, context, **kwargs):
345
+ y = context
346
+ if len(x.shape) != 4:
347
+ x = x.squeeze(2)
348
+ x = self.x_embedder(x) + self.pos_embed
349
+ t = self.t_embedder(timestep)
350
+ y, labels = self.y_embedder(y, self.training)
351
+ c = t + y
352
+ for block in self.blocks:
353
+ x = block(x, c, labels)
354
+ x = self.final_layer(x, c)
355
+ return self.unpatchify(x)
ProMoE-B-256/transformer/backbone_tcdit.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .modeling_promoe_common import (
8
+ Attention,
9
+ FinalLayer,
10
+ LabelEmbedder,
11
+ Mlp,
12
+ MoeMLP_DiffMoE as MoeMLP,
13
+ PatchEmbed,
14
+ TimestepEmbedder,
15
+ get_2d_sincos_pos_embed,
16
+ modulate,
17
+ )
18
+
19
+
20
+ class MoEGate(nn.Module):
21
+ def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
22
+ super().__init__()
23
+ self.top_k = num_experts_per_tok
24
+ self.n_routed_experts = num_experts
25
+ self.scoring_func = "softmax"
26
+ self.alpha = aux_loss_alpha
27
+ self.seq_aux = False
28
+ self.norm_topk_prob = False
29
+ self.gating_dim = embed_dim
30
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
31
+ self.reset_parameters()
32
+
33
+ def reset_parameters(self):
34
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
35
+
36
+ def forward(self, hidden_states):
37
+ bsz, seq_len, h = hidden_states.shape
38
+ hidden_states = hidden_states.view(-1, h)
39
+ logits = F.linear(hidden_states, self.weight, None)
40
+ if self.scoring_func != "softmax":
41
+ raise NotImplementedError(f"Unsupported gating scoring function: {self.scoring_func}")
42
+ scores = logits.softmax(dim=-1)
43
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
44
+ if self.top_k > 1 and self.norm_topk_prob:
45
+ topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
46
+
47
+ if self.training and self.alpha > 0.0:
48
+ scores_for_aux = scores
49
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
50
+ if self.seq_aux:
51
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
52
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
53
+ ce.scatter_add_(
54
+ 1,
55
+ topk_idx_for_aux_loss,
56
+ torch.ones(bsz, seq_len * self.top_k, device=hidden_states.device),
57
+ ).div_(seq_len * self.top_k / self.n_routed_experts)
58
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
59
+ else:
60
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
61
+ ce = mask_ce.float().mean(0)
62
+ pi = scores_for_aux.mean(0)
63
+ fi = ce * self.n_routed_experts
64
+ aux_loss = (pi * fi).sum() * self.alpha
65
+ else:
66
+ aux_loss = None
67
+ return topk_idx, topk_weight, aux_loss
68
+
69
+
70
+ class AddAuxiliaryLoss(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, x, loss):
73
+ ctx.dtype = loss.dtype
74
+ ctx.required_aux_loss = loss.requires_grad
75
+ return x
76
+
77
+ @staticmethod
78
+ def backward(ctx, grad_output):
79
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
80
+ return grad_output, grad_loss
81
+
82
+
83
+ class SparseMoEBlock(nn.Module):
84
+ def __init__(
85
+ self,
86
+ experts,
87
+ hidden_dim,
88
+ mlp_ratio=4,
89
+ num_experts=16,
90
+ num_experts_per_tok=2,
91
+ pretraining_tp=2,
92
+ n_shared_experts=2,
93
+ ):
94
+ super().__init__()
95
+ self.top_k = num_experts_per_tok
96
+ self.experts = nn.ModuleList(experts)
97
+ self.gate = MoEGate(embed_dim=hidden_dim, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)
98
+ self.n_shared_experts = n_shared_experts
99
+ if self.n_shared_experts > 0:
100
+ intermediate_size = hidden_dim * self.n_shared_experts
101
+ self.shared_experts = MoeMLP(
102
+ hidden_size=hidden_dim,
103
+ intermediate_size=intermediate_size,
104
+ pretraining_tp=pretraining_tp,
105
+ )
106
+
107
+ def forward(self, hidden_states):
108
+ identity = hidden_states
109
+ orig_shape = hidden_states.shape
110
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
111
+
112
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
113
+ flat_topk_idx = topk_idx.view(-1)
114
+ if self.training:
115
+ hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
116
+ y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
117
+ for i, expert in enumerate(self.experts):
118
+ y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]).float()
119
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
120
+ y = y.view(*orig_shape)
121
+ y = AddAuxiliaryLoss.apply(y, aux_loss)
122
+ else:
123
+ y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
124
+ if self.n_shared_experts > 0:
125
+ y = y + self.shared_experts(identity)
126
+ return y
127
+
128
+ @torch.no_grad()
129
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
130
+ expert_cache = torch.zeros_like(x)
131
+ idxs = flat_expert_indices.argsort()
132
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
133
+ token_idxs = idxs // self.top_k
134
+ for i, end_idx in enumerate(tokens_per_expert):
135
+ start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
136
+ if start_idx == end_idx:
137
+ continue
138
+ expert = self.experts[i]
139
+ exp_token_idx = token_idxs[start_idx:end_idx]
140
+ expert_tokens = x[exp_token_idx]
141
+ expert_out = expert(expert_tokens)
142
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
143
+ expert_cache = expert_cache.to(expert_out.dtype)
144
+ expert_cache.scatter_reduce_(
145
+ 0,
146
+ exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
147
+ expert_out,
148
+ reduce="sum",
149
+ )
150
+ return expert_cache
151
+
152
+
153
+ class DiTBlock(nn.Module):
154
+ def __init__(
155
+ self,
156
+ hidden_size,
157
+ num_heads,
158
+ mlp_ratio=4,
159
+ pretraining_tp=2,
160
+ use_swiglu=False,
161
+ MoE_config=None,
162
+ use_moe=True,
163
+ **block_kwargs,
164
+ ):
165
+ super().__init__()
166
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
167
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
168
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
169
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
170
+ self.use_moe = use_moe
171
+ if use_moe:
172
+ if not use_swiglu:
173
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
174
+ experts = [
175
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
176
+ for _ in range(MoE_config.num_experts)
177
+ ]
178
+ else:
179
+ experts = [
180
+ MoeMLP(
181
+ hidden_size=hidden_size,
182
+ intermediate_size=mlp_hidden_dim,
183
+ pretraining_tp=pretraining_tp,
184
+ )
185
+ for _ in range(MoE_config.num_experts)
186
+ ]
187
+ self.mlp = SparseMoEBlock(
188
+ experts=experts,
189
+ hidden_dim=hidden_size,
190
+ num_experts=MoE_config.num_experts,
191
+ num_experts_per_tok=MoE_config.capacity,
192
+ n_shared_experts=MoE_config.n_shared_experts,
193
+ )
194
+ else:
195
+ if not use_swiglu:
196
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
197
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
198
+ else:
199
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
200
+
201
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
202
+
203
+ def forward(self, x, c):
204
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
205
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
206
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
207
+ return x
208
+
209
+
210
+ class DiT(nn.Module):
211
+ def __init__(
212
+ self,
213
+ input_size=32,
214
+ patch_size=2,
215
+ in_channels=4,
216
+ hidden_size=1152,
217
+ depth=28,
218
+ num_heads=16,
219
+ mlp_ratio=4,
220
+ qk_norm=False,
221
+ class_dropout_prob=0.1,
222
+ num_classes=1000,
223
+ pretraining_tp=1,
224
+ learn_sigma=True,
225
+ use_swiglu=False,
226
+ MoE_config=None,
227
+ ):
228
+ super().__init__()
229
+ self.learn_sigma = learn_sigma
230
+ self.in_channels = in_channels
231
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
232
+ self.patch_size = patch_size
233
+ self.num_heads = num_heads
234
+ self.MoE_config = MoE_config
235
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
236
+
237
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
238
+ self.t_embedder = TimestepEmbedder(hidden_size)
239
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
240
+ num_patches = self.x_embedder.num_patches
241
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
242
+
243
+ self.blocks = nn.ModuleList(
244
+ [
245
+ DiTBlock(
246
+ hidden_size,
247
+ num_heads,
248
+ mlp_ratio=mlp_ratio,
249
+ qk_norm=qk_norm,
250
+ use_swiglu=use_swiglu,
251
+ pretraining_tp=pretraining_tp,
252
+ MoE_config=MoE_config,
253
+ use_moe=use_moe_flag[i],
254
+ )
255
+ for i in range(depth)
256
+ ]
257
+ )
258
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
259
+ self.initialize_weights()
260
+
261
+ def initialize_weights(self):
262
+ def _basic_init(module):
263
+ if isinstance(module, nn.Linear):
264
+ torch.nn.init.xavier_uniform_(module.weight)
265
+ if module.bias is not None:
266
+ nn.init.constant_(module.bias, 0)
267
+
268
+ self.apply(_basic_init)
269
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
270
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
271
+ w = self.x_embedder.proj.weight.data
272
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
273
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
274
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
275
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
276
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
277
+ for block in self.blocks:
278
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
279
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
280
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
281
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
282
+ nn.init.constant_(self.final_layer.linear.weight, 0)
283
+ nn.init.constant_(self.final_layer.linear.bias, 0)
284
+
285
+ def unpatchify(self, x):
286
+ c = self.out_channels
287
+ p = self.x_embedder.patch_size[0]
288
+ h = w = int(x.shape[1] ** 0.5)
289
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
290
+ x = torch.einsum("nhwpqc->nchpwq", x)
291
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
292
+
293
+ def forward(self, x, t, context, **kwargs):
294
+ y = context
295
+ if len(x.shape) != 4:
296
+ x = x.squeeze(2)
297
+ x = self.x_embedder(x) + self.pos_embed
298
+ t = self.t_embedder(t)
299
+ y = self.y_embedder(y, self.training)
300
+ c = t + y
301
+ for block in self.blocks:
302
+ x = block(x, c)
303
+ x = self.final_layer(x, c)
304
+ return self.unpatchify(x)
ProMoE-B-256/transformer/config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoETransformer2DModel",
3
+ "architecture": "promoe_tc",
4
+ "model_config": {
5
+ "MoE_config": {
6
+ "init_MoeMLP": false,
7
+ "interleave": true,
8
+ "moe_intermediate_size": 1536,
9
+ "num_routed_experts": 12,
10
+ "shared_expert_intermediate_size": 1536,
11
+ "top_k": 1,
12
+ "use_shared_expert": true,
13
+ "use_uncond_expert": true
14
+ },
15
+ "depth": 12,
16
+ "hidden_size": 768,
17
+ "input_size": 32,
18
+ "num_classes": 1000,
19
+ "num_heads": 12,
20
+ "patch_size": 2
21
+ }
22
+ }
ProMoE-B-256/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08f42b814b6cb9b8948665fc996bbb559e6d55b0961253573a8f0d6e4c64fdcd
3
+ size 1202482576
ProMoE-B-256/transformer/modeling_promoe_common.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import math
3
+ from dataclasses import dataclass
4
+ from itertools import repeat
5
+ from typing import Any, Dict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def _ntuple(n):
14
+ def parse(x):
15
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
16
+ return tuple(x)
17
+ return tuple(repeat(x, n))
18
+
19
+ return parse
20
+
21
+
22
+ to_2tuple = _ntuple(2)
23
+
24
+
25
+ class AttrDict(dict):
26
+ def __getattr__(self, item):
27
+ try:
28
+ return self[item]
29
+ except KeyError as error:
30
+ raise AttributeError(item) from error
31
+
32
+ def __setattr__(self, key, value):
33
+ self[key] = value
34
+
35
+ @staticmethod
36
+ def from_data(data: Any) -> Any:
37
+ if isinstance(data, dict):
38
+ return AttrDict({k: AttrDict.from_data(v) for k, v in data.items()})
39
+ if isinstance(data, list):
40
+ return [AttrDict.from_data(v) for v in data]
41
+ return data
42
+
43
+
44
+ class PatchEmbed(nn.Module):
45
+ def __init__(self, input_size: int, patch_size: int, in_channels: int, embed_dim: int, bias: bool = True):
46
+ super().__init__()
47
+ self.img_size = to_2tuple(input_size)
48
+ self.patch_size = to_2tuple(patch_size)
49
+ self.grid_size = (
50
+ self.img_size[0] // self.patch_size[0],
51
+ self.img_size[1] // self.patch_size[1],
52
+ )
53
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
54
+ self.proj = nn.Conv2d(
55
+ in_channels,
56
+ embed_dim,
57
+ kernel_size=self.patch_size,
58
+ stride=self.patch_size,
59
+ bias=bias,
60
+ )
61
+
62
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
63
+ hidden_states = self.proj(hidden_states)
64
+ return hidden_states.flatten(2).transpose(1, 2)
65
+
66
+
67
+ class Mlp(nn.Module):
68
+ def __init__(
69
+ self,
70
+ in_features,
71
+ hidden_features=None,
72
+ out_features=None,
73
+ act_layer=nn.GELU,
74
+ norm_layer=None,
75
+ bias=True,
76
+ drop=0.0,
77
+ ):
78
+ super().__init__()
79
+ out_features = out_features or in_features
80
+ hidden_features = hidden_features or in_features
81
+ bias = to_2tuple(bias)
82
+ drop_probs = to_2tuple(drop)
83
+
84
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
85
+ self.act = act_layer()
86
+ self.drop1 = nn.Dropout(drop_probs[0])
87
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
88
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
89
+ self.drop2 = nn.Dropout(drop_probs[1])
90
+
91
+ def forward(self, x):
92
+ x = self.fc1(x)
93
+ x = self.act(x)
94
+ x = self.drop1(x)
95
+ x = self.norm(x)
96
+ x = self.fc2(x)
97
+ x = self.drop2(x)
98
+ return x
99
+
100
+
101
+ class MoeMLP(nn.Module):
102
+ def __init__(self, hidden_size, intermediate_size):
103
+ super().__init__()
104
+ self.hidden_size = hidden_size
105
+ self.intermediate_size = intermediate_size
106
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size)
107
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
108
+ self.act_fn = nn.GELU(approximate="tanh")
109
+
110
+ def forward(self, x):
111
+ return self.down_proj(self.act_fn(self.up_proj(x)))
112
+
113
+
114
+ class MoeMLP_DiffMoE(nn.Module):
115
+ def __init__(self, hidden_size, intermediate_size, pretraining_tp=2):
116
+ super().__init__()
117
+ self.hidden_size = hidden_size
118
+ self.intermediate_size = intermediate_size
119
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
120
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
121
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
122
+ self.act_fn = nn.SiLU()
123
+ self.pretraining_tp = pretraining_tp
124
+
125
+ def forward(self, x):
126
+ if self.pretraining_tp > 1:
127
+ split_size = self.intermediate_size // self.pretraining_tp
128
+ gate_proj_slices = self.gate_proj.weight.split(split_size, dim=0)
129
+ up_proj_slices = self.up_proj.weight.split(split_size, dim=0)
130
+ down_proj_slices = self.down_proj.weight.split(split_size, dim=1)
131
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
132
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
133
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(split_size, dim=-1)
134
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
135
+ return sum(down_proj)
136
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
137
+
138
+
139
+ class Attention(nn.Module):
140
+ def __init__(
141
+ self,
142
+ dim: int,
143
+ num_heads: int = 8,
144
+ qkv_bias: bool = False,
145
+ qk_norm: bool = False,
146
+ attn_drop: float = 0.0,
147
+ proj_drop: float = 0.0,
148
+ head_dim=None,
149
+ norm_layer: nn.Module = nn.LayerNorm,
150
+ ):
151
+ super().__init__()
152
+ self.num_heads = num_heads
153
+ if head_dim is None:
154
+ if dim % num_heads != 0:
155
+ raise ValueError("dim must be divisible by num_heads")
156
+ self.head_dim = dim // num_heads
157
+ else:
158
+ self.head_dim = head_dim
159
+ self.scale = self.head_dim**-0.5
160
+ self.fused_attn = True
161
+ self.qkv = nn.Linear(dim, self.head_dim * self.num_heads * 3, bias=qkv_bias)
162
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
163
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
164
+ self.attn_drop = nn.Dropout(attn_drop)
165
+ self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
166
+ self.proj_drop = nn.Dropout(proj_drop)
167
+
168
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
169
+ batch_size, seq_len, _ = x.shape
170
+ qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
171
+ q, k, v = qkv.unbind(0)
172
+ q, k = self.q_norm(q), self.k_norm(k)
173
+
174
+ if self.fused_attn:
175
+ x = F.scaled_dot_product_attention(
176
+ q,
177
+ k,
178
+ v,
179
+ dropout_p=self.attn_drop.p if self.training else 0.0,
180
+ )
181
+ else:
182
+ q = q * self.scale
183
+ attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
184
+ attn = self.attn_drop(attn)
185
+ x = attn @ v
186
+
187
+ x = x.transpose(1, 2).reshape(batch_size, seq_len, -1)
188
+ x = self.proj(x)
189
+ return self.proj_drop(x)
190
+
191
+
192
+ def modulate(x, shift, scale):
193
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
194
+
195
+
196
+ class TimestepEmbedder(nn.Module):
197
+ def __init__(self, hidden_size, frequency_embedding_size=256):
198
+ super().__init__()
199
+ self.mlp = nn.Sequential(
200
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
201
+ nn.SiLU(),
202
+ nn.Linear(hidden_size, hidden_size, bias=True),
203
+ )
204
+ self.frequency_embedding_size = frequency_embedding_size
205
+
206
+ @staticmethod
207
+ def timestep_embedding(t, dim, max_period=10000):
208
+ half = dim // 2
209
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
210
+ device=t.device
211
+ )
212
+ args = t[:, None].float() * freqs[None]
213
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
214
+ if dim % 2:
215
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
216
+ return embedding
217
+
218
+ def forward(self, t):
219
+ t_freq = self.timestep_embedding(t.float(), self.frequency_embedding_size)
220
+ weight_dtype = self.mlp[0].weight.dtype
221
+ return self.mlp(t_freq.to(dtype=weight_dtype))
222
+
223
+
224
+ class LabelEmbedder(nn.Module):
225
+ def __init__(self, num_classes, hidden_size, dropout_prob, return_labels=False):
226
+ super().__init__()
227
+ use_cfg_embedding = dropout_prob > 0
228
+ self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size)
229
+ self.num_classes = num_classes
230
+ self.dropout_prob = dropout_prob
231
+ self.return_labels = return_labels
232
+
233
+ def token_drop(self, labels, force_drop_ids=None):
234
+ if force_drop_ids is None:
235
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
236
+ else:
237
+ drop_ids = force_drop_ids == 1
238
+ return torch.where(drop_ids, self.num_classes, labels)
239
+
240
+ def forward(self, labels, train, force_drop_ids=None):
241
+ if (train and self.dropout_prob > 0) or (force_drop_ids is not None):
242
+ labels = self.token_drop(labels, force_drop_ids)
243
+ embeddings = self.embedding_table(labels)
244
+ if self.return_labels:
245
+ return embeddings, labels
246
+ return embeddings
247
+
248
+
249
+ class FinalLayer(nn.Module):
250
+ def __init__(self, hidden_size, patch_size, out_channels):
251
+ super().__init__()
252
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
253
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
254
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
255
+
256
+ def forward(self, x, c):
257
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
258
+ x = modulate(self.norm_final(x), shift, scale)
259
+ return self.linear(x)
260
+
261
+
262
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
263
+ grid_h = np.arange(grid_size, dtype=np.float32)
264
+ grid_w = np.arange(grid_size, dtype=np.float32)
265
+ grid = np.meshgrid(grid_w, grid_h)
266
+ grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
267
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
268
+ if cls_token and extra_tokens > 0:
269
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
270
+ return pos_embed
271
+
272
+
273
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
274
+ if embed_dim % 2 != 0:
275
+ raise ValueError("embed_dim must be even")
276
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
277
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
278
+ return np.concatenate([emb_h, emb_w], axis=1)
279
+
280
+
281
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
282
+ if embed_dim % 2 != 0:
283
+ raise ValueError("embed_dim must be even")
284
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
285
+ omega /= embed_dim / 2.0
286
+ omega = 1.0 / 10000**omega
287
+ pos = pos.reshape(-1)
288
+ out = np.einsum("m,d->md", pos, omega)
289
+ emb_sin = np.sin(out)
290
+ emb_cos = np.cos(out)
291
+ return np.concatenate([emb_sin, emb_cos], axis=1)
ProMoE-B-256/transformer/transformer_promoe.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ try:
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.utils import BaseOutput
11
+ except Exception: # pragma: no cover
12
+ class BaseOutput(dict):
13
+ def __post_init__(self):
14
+ self.update(self.__dict__)
15
+
16
+ class _Config(dict):
17
+ def __getattr__(self, key):
18
+ try:
19
+ return self[key]
20
+ except KeyError as error:
21
+ raise AttributeError(key) from error
22
+
23
+ class ConfigMixin:
24
+ config_name = "config.json"
25
+
26
+ class ModelMixin(nn.Module):
27
+ pass
28
+
29
+ def register_to_config(init):
30
+ def wrapper(self, *args, **kwargs):
31
+ import inspect
32
+
33
+ signature = inspect.signature(init)
34
+ bound = signature.bind(self, *args, **kwargs)
35
+ bound.apply_defaults()
36
+ self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"})
37
+ init(self, *args, **kwargs)
38
+
39
+ return wrapper
40
+
41
+ from .backbone_diffmoe import DiT as DiffMoEBackbone
42
+ from .backbone_dit import DiT as DiTBackbone
43
+ from .backbone_ecdit import DiT as ECDiTBackbone
44
+ from .backbone_promoe_ec import DiT as ProMoEECBackbone
45
+ from .backbone_promoe_tc import DiT as ProMoETCBackbone
46
+ from .backbone_tcdit import DiT as TCDiTBackbone
47
+ from .modeling_promoe_common import AttrDict
48
+
49
+
50
+ @dataclass
51
+ class ProMoETransformer2DModelOutput(BaseOutput):
52
+ sample: torch.FloatTensor
53
+ loss_strategy: Optional[str] = None
54
+ layer_idx_list: Optional[Tuple[int, ...]] = None
55
+ ones_list: Optional[Tuple[torch.FloatTensor, ...]] = None
56
+ pred_c_list: Optional[Tuple[torch.FloatTensor, ...]] = None
57
+ capacity_pred_loss_weight: Optional[float] = None
58
+
59
+
60
+ _BACKBONES = {
61
+ "dit": DiTBackbone,
62
+ "tcdit": TCDiTBackbone,
63
+ "ecdit": ECDiTBackbone,
64
+ "diffmoe": DiffMoEBackbone,
65
+ "promoe_tc": ProMoETCBackbone,
66
+ "promoe_ec": ProMoEECBackbone,
67
+ }
68
+
69
+
70
+ class ProMoETransformer2DModel(ModelMixin, ConfigMixin):
71
+ config_name = "config.json"
72
+
73
+ @register_to_config
74
+ def __init__(self, architecture: str = "promoe_tc", model_config: Optional[Dict[str, Any]] = None):
75
+ super().__init__()
76
+ if architecture not in _BACKBONES:
77
+ raise ValueError(f"Unsupported architecture: {architecture}. Valid: {sorted(_BACKBONES)}")
78
+ model_config = model_config or {}
79
+ self.architecture = architecture
80
+ self.model_config = model_config
81
+ self.backbone = _BACKBONES[architecture](**self._prepare_config(model_config))
82
+ self.in_channels = getattr(self.backbone, "in_channels", model_config.get("in_channels", 4))
83
+ self.out_channels = getattr(self.backbone, "out_channels", model_config.get("in_channels", 4))
84
+
85
+ def _prepare_config(self, model_config: Dict[str, Any]) -> Dict[str, Any]:
86
+ prepared = {}
87
+ for key, value in model_config.items():
88
+ prepared[key] = AttrDict.from_data(value)
89
+ return prepared
90
+
91
+ def forward(
92
+ self,
93
+ hidden_states: torch.Tensor,
94
+ timestep: Union[torch.Tensor, float, int],
95
+ class_labels: Optional[torch.LongTensor] = None,
96
+ context: Optional[torch.LongTensor] = None,
97
+ return_dict: bool = True,
98
+ **kwargs,
99
+ ) -> Union[ProMoETransformer2DModelOutput, Tuple[torch.Tensor, ...]]:
100
+ labels = class_labels if class_labels is not None else context
101
+ if labels is None:
102
+ raise ValueError("Either `class_labels` or `context` must be provided.")
103
+
104
+ if not torch.is_tensor(timestep):
105
+ timestep = torch.tensor([timestep], device=hidden_states.device, dtype=hidden_states.dtype)
106
+ timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype).flatten()
107
+ if timestep.numel() == 1:
108
+ timestep = timestep.repeat(labels.shape[0])
109
+
110
+ sample = self.backbone(hidden_states, timestep, labels, **kwargs)
111
+ if isinstance(sample, tuple):
112
+ if len(sample) == 6 and sample[1] == "Capacity_Pred":
113
+ output = ProMoETransformer2DModelOutput(
114
+ sample=sample[0],
115
+ loss_strategy=sample[1],
116
+ layer_idx_list=tuple(sample[2]),
117
+ ones_list=tuple(sample[3]),
118
+ pred_c_list=tuple(sample[4]),
119
+ capacity_pred_loss_weight=float(sample[5]),
120
+ )
121
+ else:
122
+ output = ProMoETransformer2DModelOutput(sample=sample[0])
123
+ else:
124
+ output = ProMoETransformer2DModelOutput(sample=sample)
125
+
126
+ if not return_dict:
127
+ if output.loss_strategy is None:
128
+ return (output.sample,)
129
+ return (
130
+ output.sample,
131
+ output.loss_strategy,
132
+ output.layer_idx_list,
133
+ output.ones_list,
134
+ output.pred_c_list,
135
+ output.capacity_pred_loss_weight,
136
+ )
137
+ return output
ProMoE-B-256/vae/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.4.2",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "in_channels": 3,
18
+ "latent_channels": 4,
19
+ "layers_per_block": 2,
20
+ "norm_num_groups": 32,
21
+ "out_channels": 3,
22
+ "sample_size": 256,
23
+ "up_block_types": [
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D"
28
+ ]
29
+ }
ProMoE-B-256/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
3
+ size 334643268
ProMoE-L-256/model_index.json ADDED
@@ -0,0 +1,1021 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "ProMoEPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "id2label": {
8
+ "0": "tench, Tinca tinca",
9
+ "1": "goldfish, Carassius auratus",
10
+ "10": "brambling, Fringilla montifringilla",
11
+ "100": "black swan, Cygnus atratus",
12
+ "101": "tusker",
13
+ "102": "echidna, spiny anteater, anteater",
14
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
15
+ "104": "wallaby, brush kangaroo",
16
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
17
+ "106": "wombat",
18
+ "107": "jellyfish",
19
+ "108": "sea anemone, anemone",
20
+ "109": "brain coral",
21
+ "11": "goldfinch, Carduelis carduelis",
22
+ "110": "flatworm, platyhelminth",
23
+ "111": "nematode, nematode worm, roundworm",
24
+ "112": "conch",
25
+ "113": "snail",
26
+ "114": "slug",
27
+ "115": "sea slug, nudibranch",
28
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
29
+ "117": "chambered nautilus, pearly nautilus, nautilus",
30
+ "118": "Dungeness crab, Cancer magister",
31
+ "119": "rock crab, Cancer irroratus",
32
+ "12": "house finch, linnet, Carpodacus mexicanus",
33
+ "120": "fiddler crab",
34
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
35
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
36
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
37
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
38
+ "125": "hermit crab",
39
+ "126": "isopod",
40
+ "127": "white stork, Ciconia ciconia",
41
+ "128": "black stork, Ciconia nigra",
42
+ "129": "spoonbill",
43
+ "13": "junco, snowbird",
44
+ "130": "flamingo",
45
+ "131": "little blue heron, Egretta caerulea",
46
+ "132": "American egret, great white heron, Egretta albus",
47
+ "133": "bittern",
48
+ "134": "crane",
49
+ "135": "limpkin, Aramus pictus",
50
+ "136": "European gallinule, Porphyrio porphyrio",
51
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
52
+ "138": "bustard",
53
+ "139": "ruddy turnstone, Arenaria interpres",
54
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
55
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
56
+ "141": "redshank, Tringa totanus",
57
+ "142": "dowitcher",
58
+ "143": "oystercatcher, oyster catcher",
59
+ "144": "pelican",
60
+ "145": "king penguin, Aptenodytes patagonica",
61
+ "146": "albatross, mollymawk",
62
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
63
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
64
+ "149": "dugong, Dugong dugon",
65
+ "15": "robin, American robin, Turdus migratorius",
66
+ "150": "sea lion",
67
+ "151": "Chihuahua",
68
+ "152": "Japanese spaniel",
69
+ "153": "Maltese dog, Maltese terrier, Maltese",
70
+ "154": "Pekinese, Pekingese, Peke",
71
+ "155": "Shih-Tzu",
72
+ "156": "Blenheim spaniel",
73
+ "157": "papillon",
74
+ "158": "toy terrier",
75
+ "159": "Rhodesian ridgeback",
76
+ "16": "bulbul",
77
+ "160": "Afghan hound, Afghan",
78
+ "161": "basset, basset hound",
79
+ "162": "beagle",
80
+ "163": "bloodhound, sleuthhound",
81
+ "164": "bluetick",
82
+ "165": "black-and-tan coonhound",
83
+ "166": "Walker hound, Walker foxhound",
84
+ "167": "English foxhound",
85
+ "168": "redbone",
86
+ "169": "borzoi, Russian wolfhound",
87
+ "17": "jay",
88
+ "170": "Irish wolfhound",
89
+ "171": "Italian greyhound",
90
+ "172": "whippet",
91
+ "173": "Ibizan hound, Ibizan Podenco",
92
+ "174": "Norwegian elkhound, elkhound",
93
+ "175": "otterhound, otter hound",
94
+ "176": "Saluki, gazelle hound",
95
+ "177": "Scottish deerhound, deerhound",
96
+ "178": "Weimaraner",
97
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
98
+ "18": "magpie",
99
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
100
+ "181": "Bedlington terrier",
101
+ "182": "Border terrier",
102
+ "183": "Kerry blue terrier",
103
+ "184": "Irish terrier",
104
+ "185": "Norfolk terrier",
105
+ "186": "Norwich terrier",
106
+ "187": "Yorkshire terrier",
107
+ "188": "wire-haired fox terrier",
108
+ "189": "Lakeland terrier",
109
+ "19": "chickadee",
110
+ "190": "Sealyham terrier, Sealyham",
111
+ "191": "Airedale, Airedale terrier",
112
+ "192": "cairn, cairn terrier",
113
+ "193": "Australian terrier",
114
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
115
+ "195": "Boston bull, Boston terrier",
116
+ "196": "miniature schnauzer",
117
+ "197": "giant schnauzer",
118
+ "198": "standard schnauzer",
119
+ "199": "Scotch terrier, Scottish terrier, Scottie",
120
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
121
+ "20": "water ouzel, dipper",
122
+ "200": "Tibetan terrier, chrysanthemum dog",
123
+ "201": "silky terrier, Sydney silky",
124
+ "202": "soft-coated wheaten terrier",
125
+ "203": "West Highland white terrier",
126
+ "204": "Lhasa, Lhasa apso",
127
+ "205": "flat-coated retriever",
128
+ "206": "curly-coated retriever",
129
+ "207": "golden retriever",
130
+ "208": "Labrador retriever",
131
+ "209": "Chesapeake Bay retriever",
132
+ "21": "kite",
133
+ "210": "German short-haired pointer",
134
+ "211": "vizsla, Hungarian pointer",
135
+ "212": "English setter",
136
+ "213": "Irish setter, red setter",
137
+ "214": "Gordon setter",
138
+ "215": "Brittany spaniel",
139
+ "216": "clumber, clumber spaniel",
140
+ "217": "English springer, English springer spaniel",
141
+ "218": "Welsh springer spaniel",
142
+ "219": "cocker spaniel, English cocker spaniel, cocker",
143
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
144
+ "220": "Sussex spaniel",
145
+ "221": "Irish water spaniel",
146
+ "222": "kuvasz",
147
+ "223": "schipperke",
148
+ "224": "groenendael",
149
+ "225": "malinois",
150
+ "226": "briard",
151
+ "227": "kelpie",
152
+ "228": "komondor",
153
+ "229": "Old English sheepdog, bobtail",
154
+ "23": "vulture",
155
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
156
+ "231": "collie",
157
+ "232": "Border collie",
158
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
159
+ "234": "Rottweiler",
160
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
161
+ "236": "Doberman, Doberman pinscher",
162
+ "237": "miniature pinscher",
163
+ "238": "Greater Swiss Mountain dog",
164
+ "239": "Bernese mountain dog",
165
+ "24": "great grey owl, great gray owl, Strix nebulosa",
166
+ "240": "Appenzeller",
167
+ "241": "EntleBucher",
168
+ "242": "boxer",
169
+ "243": "bull mastiff",
170
+ "244": "Tibetan mastiff",
171
+ "245": "French bulldog",
172
+ "246": "Great Dane",
173
+ "247": "Saint Bernard, St Bernard",
174
+ "248": "Eskimo dog, husky",
175
+ "249": "malamute, malemute, Alaskan malamute",
176
+ "25": "European fire salamander, Salamandra salamandra",
177
+ "250": "Siberian husky",
178
+ "251": "dalmatian, coach dog, carriage dog",
179
+ "252": "affenpinscher, monkey pinscher, monkey dog",
180
+ "253": "basenji",
181
+ "254": "pug, pug-dog",
182
+ "255": "Leonberg",
183
+ "256": "Newfoundland, Newfoundland dog",
184
+ "257": "Great Pyrenees",
185
+ "258": "Samoyed, Samoyede",
186
+ "259": "Pomeranian",
187
+ "26": "common newt, Triturus vulgaris",
188
+ "260": "chow, chow chow",
189
+ "261": "keeshond",
190
+ "262": "Brabancon griffon",
191
+ "263": "Pembroke, Pembroke Welsh corgi",
192
+ "264": "Cardigan, Cardigan Welsh corgi",
193
+ "265": "toy poodle",
194
+ "266": "miniature poodle",
195
+ "267": "standard poodle",
196
+ "268": "Mexican hairless",
197
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
198
+ "27": "eft",
199
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
200
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
201
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
202
+ "273": "dingo, warrigal, warragal, Canis dingo",
203
+ "274": "dhole, Cuon alpinus",
204
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
205
+ "276": "hyena, hyaena",
206
+ "277": "red fox, Vulpes vulpes",
207
+ "278": "kit fox, Vulpes macrotis",
208
+ "279": "Arctic fox, white fox, Alopex lagopus",
209
+ "28": "spotted salamander, Ambystoma maculatum",
210
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
211
+ "281": "tabby, tabby cat",
212
+ "282": "tiger cat",
213
+ "283": "Persian cat",
214
+ "284": "Siamese cat, Siamese",
215
+ "285": "Egyptian cat",
216
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
217
+ "287": "lynx, catamount",
218
+ "288": "leopard, Panthera pardus",
219
+ "289": "snow leopard, ounce, Panthera uncia",
220
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
221
+ "290": "jaguar, panther, Panthera onca, Felis onca",
222
+ "291": "lion, king of beasts, Panthera leo",
223
+ "292": "tiger, Panthera tigris",
224
+ "293": "cheetah, chetah, Acinonyx jubatus",
225
+ "294": "brown bear, bruin, Ursus arctos",
226
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
227
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
228
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
229
+ "298": "mongoose",
230
+ "299": "meerkat, mierkat",
231
+ "3": "tiger shark, Galeocerdo cuvieri",
232
+ "30": "bullfrog, Rana catesbeiana",
233
+ "300": "tiger beetle",
234
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
235
+ "302": "ground beetle, carabid beetle",
236
+ "303": "long-horned beetle, longicorn, longicorn beetle",
237
+ "304": "leaf beetle, chrysomelid",
238
+ "305": "dung beetle",
239
+ "306": "rhinoceros beetle",
240
+ "307": "weevil",
241
+ "308": "fly",
242
+ "309": "bee",
243
+ "31": "tree frog, tree-frog",
244
+ "310": "ant, emmet, pismire",
245
+ "311": "grasshopper, hopper",
246
+ "312": "cricket",
247
+ "313": "walking stick, walkingstick, stick insect",
248
+ "314": "cockroach, roach",
249
+ "315": "mantis, mantid",
250
+ "316": "cicada, cicala",
251
+ "317": "leafhopper",
252
+ "318": "lacewing, lacewing fly",
253
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
254
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
255
+ "320": "damselfly",
256
+ "321": "admiral",
257
+ "322": "ringlet, ringlet butterfly",
258
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
259
+ "324": "cabbage butterfly",
260
+ "325": "sulphur butterfly, sulfur butterfly",
261
+ "326": "lycaenid, lycaenid butterfly",
262
+ "327": "starfish, sea star",
263
+ "328": "sea urchin",
264
+ "329": "sea cucumber, holothurian",
265
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
266
+ "330": "wood rabbit, cottontail, cottontail rabbit",
267
+ "331": "hare",
268
+ "332": "Angora, Angora rabbit",
269
+ "333": "hamster",
270
+ "334": "porcupine, hedgehog",
271
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
272
+ "336": "marmot",
273
+ "337": "beaver",
274
+ "338": "guinea pig, Cavia cobaya",
275
+ "339": "sorrel",
276
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
277
+ "340": "zebra",
278
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
279
+ "342": "wild boar, boar, Sus scrofa",
280
+ "343": "warthog",
281
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
282
+ "345": "ox",
283
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
284
+ "347": "bison",
285
+ "348": "ram, tup",
286
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
287
+ "35": "mud turtle",
288
+ "350": "ibex, Capra ibex",
289
+ "351": "hartebeest",
290
+ "352": "impala, Aepyceros melampus",
291
+ "353": "gazelle",
292
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
293
+ "355": "llama",
294
+ "356": "weasel",
295
+ "357": "mink",
296
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
297
+ "359": "black-footed ferret, ferret, Mustela nigripes",
298
+ "36": "terrapin",
299
+ "360": "otter",
300
+ "361": "skunk, polecat, wood pussy",
301
+ "362": "badger",
302
+ "363": "armadillo",
303
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
304
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
305
+ "366": "gorilla, Gorilla gorilla",
306
+ "367": "chimpanzee, chimp, Pan troglodytes",
307
+ "368": "gibbon, Hylobates lar",
308
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
309
+ "37": "box turtle, box tortoise",
310
+ "370": "guenon, guenon monkey",
311
+ "371": "patas, hussar monkey, Erythrocebus patas",
312
+ "372": "baboon",
313
+ "373": "macaque",
314
+ "374": "langur",
315
+ "375": "colobus, colobus monkey",
316
+ "376": "proboscis monkey, Nasalis larvatus",
317
+ "377": "marmoset",
318
+ "378": "capuchin, ringtail, Cebus capucinus",
319
+ "379": "howler monkey, howler",
320
+ "38": "banded gecko",
321
+ "380": "titi, titi monkey",
322
+ "381": "spider monkey, Ateles geoffroyi",
323
+ "382": "squirrel monkey, Saimiri sciureus",
324
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
325
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
326
+ "385": "Indian elephant, Elephas maximus",
327
+ "386": "African elephant, Loxodonta africana",
328
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
329
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
330
+ "389": "barracouta, snoek",
331
+ "39": "common iguana, iguana, Iguana iguana",
332
+ "390": "eel",
333
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
334
+ "392": "rock beauty, Holocanthus tricolor",
335
+ "393": "anemone fish",
336
+ "394": "sturgeon",
337
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
338
+ "396": "lionfish",
339
+ "397": "puffer, pufferfish, blowfish, globefish",
340
+ "398": "abacus",
341
+ "399": "abaya",
342
+ "4": "hammerhead, hammerhead shark",
343
+ "40": "American chameleon, anole, Anolis carolinensis",
344
+ "400": "academic gown, academic robe, judge robe",
345
+ "401": "accordion, piano accordion, squeeze box",
346
+ "402": "acoustic guitar",
347
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
348
+ "404": "airliner",
349
+ "405": "airship, dirigible",
350
+ "406": "altar",
351
+ "407": "ambulance",
352
+ "408": "amphibian, amphibious vehicle",
353
+ "409": "analog clock",
354
+ "41": "whiptail, whiptail lizard",
355
+ "410": "apiary, bee house",
356
+ "411": "apron",
357
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
358
+ "413": "assault rifle, assault gun",
359
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
360
+ "415": "bakery, bakeshop, bakehouse",
361
+ "416": "balance beam, beam",
362
+ "417": "balloon",
363
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
364
+ "419": "Band Aid",
365
+ "42": "agama",
366
+ "420": "banjo",
367
+ "421": "bannister, banister, balustrade, balusters, handrail",
368
+ "422": "barbell",
369
+ "423": "barber chair",
370
+ "424": "barbershop",
371
+ "425": "barn",
372
+ "426": "barometer",
373
+ "427": "barrel, cask",
374
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
375
+ "429": "baseball",
376
+ "43": "frilled lizard, Chlamydosaurus kingi",
377
+ "430": "basketball",
378
+ "431": "bassinet",
379
+ "432": "bassoon",
380
+ "433": "bathing cap, swimming cap",
381
+ "434": "bath towel",
382
+ "435": "bathtub, bathing tub, bath, tub",
383
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
384
+ "437": "beacon, lighthouse, beacon light, pharos",
385
+ "438": "beaker",
386
+ "439": "bearskin, busby, shako",
387
+ "44": "alligator lizard",
388
+ "440": "beer bottle",
389
+ "441": "beer glass",
390
+ "442": "bell cote, bell cot",
391
+ "443": "bib",
392
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
393
+ "445": "bikini, two-piece",
394
+ "446": "binder, ring-binder",
395
+ "447": "binoculars, field glasses, opera glasses",
396
+ "448": "birdhouse",
397
+ "449": "boathouse",
398
+ "45": "Gila monster, Heloderma suspectum",
399
+ "450": "bobsled, bobsleigh, bob",
400
+ "451": "bolo tie, bolo, bola tie, bola",
401
+ "452": "bonnet, poke bonnet",
402
+ "453": "bookcase",
403
+ "454": "bookshop, bookstore, bookstall",
404
+ "455": "bottlecap",
405
+ "456": "bow",
406
+ "457": "bow tie, bow-tie, bowtie",
407
+ "458": "brass, memorial tablet, plaque",
408
+ "459": "brassiere, bra, bandeau",
409
+ "46": "green lizard, Lacerta viridis",
410
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
411
+ "461": "breastplate, aegis, egis",
412
+ "462": "broom",
413
+ "463": "bucket, pail",
414
+ "464": "buckle",
415
+ "465": "bulletproof vest",
416
+ "466": "bullet train, bullet",
417
+ "467": "butcher shop, meat market",
418
+ "468": "cab, hack, taxi, taxicab",
419
+ "469": "caldron, cauldron",
420
+ "47": "African chameleon, Chamaeleo chamaeleon",
421
+ "470": "candle, taper, wax light",
422
+ "471": "cannon",
423
+ "472": "canoe",
424
+ "473": "can opener, tin opener",
425
+ "474": "cardigan",
426
+ "475": "car mirror",
427
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
428
+ "477": "carpenters kit, tool kit",
429
+ "478": "carton",
430
+ "479": "car wheel",
431
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
432
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
433
+ "481": "cassette",
434
+ "482": "cassette player",
435
+ "483": "castle",
436
+ "484": "catamaran",
437
+ "485": "CD player",
438
+ "486": "cello, violoncello",
439
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
440
+ "488": "chain",
441
+ "489": "chainlink fence",
442
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
443
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
444
+ "491": "chain saw, chainsaw",
445
+ "492": "chest",
446
+ "493": "chiffonier, commode",
447
+ "494": "chime, bell, gong",
448
+ "495": "china cabinet, china closet",
449
+ "496": "Christmas stocking",
450
+ "497": "church, church building",
451
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
452
+ "499": "cleaver, meat cleaver, chopper",
453
+ "5": "electric ray, crampfish, numbfish, torpedo",
454
+ "50": "American alligator, Alligator mississipiensis",
455
+ "500": "cliff dwelling",
456
+ "501": "cloak",
457
+ "502": "clog, geta, patten, sabot",
458
+ "503": "cocktail shaker",
459
+ "504": "coffee mug",
460
+ "505": "coffeepot",
461
+ "506": "coil, spiral, volute, whorl, helix",
462
+ "507": "combination lock",
463
+ "508": "computer keyboard, keypad",
464
+ "509": "confectionery, confectionary, candy store",
465
+ "51": "triceratops",
466
+ "510": "container ship, containership, container vessel",
467
+ "511": "convertible",
468
+ "512": "corkscrew, bottle screw",
469
+ "513": "cornet, horn, trumpet, trump",
470
+ "514": "cowboy boot",
471
+ "515": "cowboy hat, ten-gallon hat",
472
+ "516": "cradle",
473
+ "517": "crane",
474
+ "518": "crash helmet",
475
+ "519": "crate",
476
+ "52": "thunder snake, worm snake, Carphophis amoenus",
477
+ "520": "crib, cot",
478
+ "521": "Crock Pot",
479
+ "522": "croquet ball",
480
+ "523": "crutch",
481
+ "524": "cuirass",
482
+ "525": "dam, dike, dyke",
483
+ "526": "desk",
484
+ "527": "desktop computer",
485
+ "528": "dial telephone, dial phone",
486
+ "529": "diaper, nappy, napkin",
487
+ "53": "ringneck snake, ring-necked snake, ring snake",
488
+ "530": "digital clock",
489
+ "531": "digital watch",
490
+ "532": "dining table, board",
491
+ "533": "dishrag, dishcloth",
492
+ "534": "dishwasher, dish washer, dishwashing machine",
493
+ "535": "disk brake, disc brake",
494
+ "536": "dock, dockage, docking facility",
495
+ "537": "dogsled, dog sled, dog sleigh",
496
+ "538": "dome",
497
+ "539": "doormat, welcome mat",
498
+ "54": "hognose snake, puff adder, sand viper",
499
+ "540": "drilling platform, offshore rig",
500
+ "541": "drum, membranophone, tympan",
501
+ "542": "drumstick",
502
+ "543": "dumbbell",
503
+ "544": "Dutch oven",
504
+ "545": "electric fan, blower",
505
+ "546": "electric guitar",
506
+ "547": "electric locomotive",
507
+ "548": "entertainment center",
508
+ "549": "envelope",
509
+ "55": "green snake, grass snake",
510
+ "550": "espresso maker",
511
+ "551": "face powder",
512
+ "552": "feather boa, boa",
513
+ "553": "file, file cabinet, filing cabinet",
514
+ "554": "fireboat",
515
+ "555": "fire engine, fire truck",
516
+ "556": "fire screen, fireguard",
517
+ "557": "flagpole, flagstaff",
518
+ "558": "flute, transverse flute",
519
+ "559": "folding chair",
520
+ "56": "king snake, kingsnake",
521
+ "560": "football helmet",
522
+ "561": "forklift",
523
+ "562": "fountain",
524
+ "563": "fountain pen",
525
+ "564": "four-poster",
526
+ "565": "freight car",
527
+ "566": "French horn, horn",
528
+ "567": "frying pan, frypan, skillet",
529
+ "568": "fur coat",
530
+ "569": "garbage truck, dustcart",
531
+ "57": "garter snake, grass snake",
532
+ "570": "gasmask, respirator, gas helmet",
533
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
534
+ "572": "goblet",
535
+ "573": "go-kart",
536
+ "574": "golf ball",
537
+ "575": "golfcart, golf cart",
538
+ "576": "gondola",
539
+ "577": "gong, tam-tam",
540
+ "578": "gown",
541
+ "579": "grand piano, grand",
542
+ "58": "water snake",
543
+ "580": "greenhouse, nursery, glasshouse",
544
+ "581": "grille, radiator grille",
545
+ "582": "grocery store, grocery, food market, market",
546
+ "583": "guillotine",
547
+ "584": "hair slide",
548
+ "585": "hair spray",
549
+ "586": "half track",
550
+ "587": "hammer",
551
+ "588": "hamper",
552
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
553
+ "59": "vine snake",
554
+ "590": "hand-held computer, hand-held microcomputer",
555
+ "591": "handkerchief, hankie, hanky, hankey",
556
+ "592": "hard disc, hard disk, fixed disk",
557
+ "593": "harmonica, mouth organ, harp, mouth harp",
558
+ "594": "harp",
559
+ "595": "harvester, reaper",
560
+ "596": "hatchet",
561
+ "597": "holster",
562
+ "598": "home theater, home theatre",
563
+ "599": "honeycomb",
564
+ "6": "stingray",
565
+ "60": "night snake, Hypsiglena torquata",
566
+ "600": "hook, claw",
567
+ "601": "hoopskirt, crinoline",
568
+ "602": "horizontal bar, high bar",
569
+ "603": "horse cart, horse-cart",
570
+ "604": "hourglass",
571
+ "605": "iPod",
572
+ "606": "iron, smoothing iron",
573
+ "607": "jack-o-lantern",
574
+ "608": "jean, blue jean, denim",
575
+ "609": "jeep, landrover",
576
+ "61": "boa constrictor, Constrictor constrictor",
577
+ "610": "jersey, T-shirt, tee shirt",
578
+ "611": "jigsaw puzzle",
579
+ "612": "jinrikisha, ricksha, rickshaw",
580
+ "613": "joystick",
581
+ "614": "kimono",
582
+ "615": "knee pad",
583
+ "616": "knot",
584
+ "617": "lab coat, laboratory coat",
585
+ "618": "ladle",
586
+ "619": "lampshade, lamp shade",
587
+ "62": "rock python, rock snake, Python sebae",
588
+ "620": "laptop, laptop computer",
589
+ "621": "lawn mower, mower",
590
+ "622": "lens cap, lens cover",
591
+ "623": "letter opener, paper knife, paperknife",
592
+ "624": "library",
593
+ "625": "lifeboat",
594
+ "626": "lighter, light, igniter, ignitor",
595
+ "627": "limousine, limo",
596
+ "628": "liner, ocean liner",
597
+ "629": "lipstick, lip rouge",
598
+ "63": "Indian cobra, Naja naja",
599
+ "630": "Loafer",
600
+ "631": "lotion",
601
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
602
+ "633": "loupe, jewelers loupe",
603
+ "634": "lumbermill, sawmill",
604
+ "635": "magnetic compass",
605
+ "636": "mailbag, postbag",
606
+ "637": "mailbox, letter box",
607
+ "638": "maillot",
608
+ "639": "maillot, tank suit",
609
+ "64": "green mamba",
610
+ "640": "manhole cover",
611
+ "641": "maraca",
612
+ "642": "marimba, xylophone",
613
+ "643": "mask",
614
+ "644": "matchstick",
615
+ "645": "maypole",
616
+ "646": "maze, labyrinth",
617
+ "647": "measuring cup",
618
+ "648": "medicine chest, medicine cabinet",
619
+ "649": "megalith, megalithic structure",
620
+ "65": "sea snake",
621
+ "650": "microphone, mike",
622
+ "651": "microwave, microwave oven",
623
+ "652": "military uniform",
624
+ "653": "milk can",
625
+ "654": "minibus",
626
+ "655": "miniskirt, mini",
627
+ "656": "minivan",
628
+ "657": "missile",
629
+ "658": "mitten",
630
+ "659": "mixing bowl",
631
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
632
+ "660": "mobile home, manufactured home",
633
+ "661": "Model T",
634
+ "662": "modem",
635
+ "663": "monastery",
636
+ "664": "monitor",
637
+ "665": "moped",
638
+ "666": "mortar",
639
+ "667": "mortarboard",
640
+ "668": "mosque",
641
+ "669": "mosquito net",
642
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
643
+ "670": "motor scooter, scooter",
644
+ "671": "mountain bike, all-terrain bike, off-roader",
645
+ "672": "mountain tent",
646
+ "673": "mouse, computer mouse",
647
+ "674": "mousetrap",
648
+ "675": "moving van",
649
+ "676": "muzzle",
650
+ "677": "nail",
651
+ "678": "neck brace",
652
+ "679": "necklace",
653
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
654
+ "680": "nipple",
655
+ "681": "notebook, notebook computer",
656
+ "682": "obelisk",
657
+ "683": "oboe, hautboy, hautbois",
658
+ "684": "ocarina, sweet potato",
659
+ "685": "odometer, hodometer, mileometer, milometer",
660
+ "686": "oil filter",
661
+ "687": "organ, pipe organ",
662
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
663
+ "689": "overskirt",
664
+ "69": "trilobite",
665
+ "690": "oxcart",
666
+ "691": "oxygen mask",
667
+ "692": "packet",
668
+ "693": "paddle, boat paddle",
669
+ "694": "paddlewheel, paddle wheel",
670
+ "695": "padlock",
671
+ "696": "paintbrush",
672
+ "697": "pajama, pyjama, pjs, jammies",
673
+ "698": "palace",
674
+ "699": "panpipe, pandean pipe, syrinx",
675
+ "7": "cock",
676
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
677
+ "700": "paper towel",
678
+ "701": "parachute, chute",
679
+ "702": "parallel bars, bars",
680
+ "703": "park bench",
681
+ "704": "parking meter",
682
+ "705": "passenger car, coach, carriage",
683
+ "706": "patio, terrace",
684
+ "707": "pay-phone, pay-station",
685
+ "708": "pedestal, plinth, footstall",
686
+ "709": "pencil box, pencil case",
687
+ "71": "scorpion",
688
+ "710": "pencil sharpener",
689
+ "711": "perfume, essence",
690
+ "712": "Petri dish",
691
+ "713": "photocopier",
692
+ "714": "pick, plectrum, plectron",
693
+ "715": "pickelhaube",
694
+ "716": "picket fence, paling",
695
+ "717": "pickup, pickup truck",
696
+ "718": "pier",
697
+ "719": "piggy bank, penny bank",
698
+ "72": "black and gold garden spider, Argiope aurantia",
699
+ "720": "pill bottle",
700
+ "721": "pillow",
701
+ "722": "ping-pong ball",
702
+ "723": "pinwheel",
703
+ "724": "pirate, pirate ship",
704
+ "725": "pitcher, ewer",
705
+ "726": "plane, carpenters plane, woodworking plane",
706
+ "727": "planetarium",
707
+ "728": "plastic bag",
708
+ "729": "plate rack",
709
+ "73": "barn spider, Araneus cavaticus",
710
+ "730": "plow, plough",
711
+ "731": "plunger, plumbers helper",
712
+ "732": "Polaroid camera, Polaroid Land camera",
713
+ "733": "pole",
714
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
715
+ "735": "poncho",
716
+ "736": "pool table, billiard table, snooker table",
717
+ "737": "pop bottle, soda bottle",
718
+ "738": "pot, flowerpot",
719
+ "739": "potters wheel",
720
+ "74": "garden spider, Aranea diademata",
721
+ "740": "power drill",
722
+ "741": "prayer rug, prayer mat",
723
+ "742": "printer",
724
+ "743": "prison, prison house",
725
+ "744": "projectile, missile",
726
+ "745": "projector",
727
+ "746": "puck, hockey puck",
728
+ "747": "punching bag, punch bag, punching ball, punchball",
729
+ "748": "purse",
730
+ "749": "quill, quill pen",
731
+ "75": "black widow, Latrodectus mactans",
732
+ "750": "quilt, comforter, comfort, puff",
733
+ "751": "racer, race car, racing car",
734
+ "752": "racket, racquet",
735
+ "753": "radiator",
736
+ "754": "radio, wireless",
737
+ "755": "radio telescope, radio reflector",
738
+ "756": "rain barrel",
739
+ "757": "recreational vehicle, RV, R.V.",
740
+ "758": "reel",
741
+ "759": "reflex camera",
742
+ "76": "tarantula",
743
+ "760": "refrigerator, icebox",
744
+ "761": "remote control, remote",
745
+ "762": "restaurant, eating house, eating place, eatery",
746
+ "763": "revolver, six-gun, six-shooter",
747
+ "764": "rifle",
748
+ "765": "rocking chair, rocker",
749
+ "766": "rotisserie",
750
+ "767": "rubber eraser, rubber, pencil eraser",
751
+ "768": "rugby ball",
752
+ "769": "rule, ruler",
753
+ "77": "wolf spider, hunting spider",
754
+ "770": "running shoe",
755
+ "771": "safe",
756
+ "772": "safety pin",
757
+ "773": "saltshaker, salt shaker",
758
+ "774": "sandal",
759
+ "775": "sarong",
760
+ "776": "sax, saxophone",
761
+ "777": "scabbard",
762
+ "778": "scale, weighing machine",
763
+ "779": "school bus",
764
+ "78": "tick",
765
+ "780": "schooner",
766
+ "781": "scoreboard",
767
+ "782": "screen, CRT screen",
768
+ "783": "screw",
769
+ "784": "screwdriver",
770
+ "785": "seat belt, seatbelt",
771
+ "786": "sewing machine",
772
+ "787": "shield, buckler",
773
+ "788": "shoe shop, shoe-shop, shoe store",
774
+ "789": "shoji",
775
+ "79": "centipede",
776
+ "790": "shopping basket",
777
+ "791": "shopping cart",
778
+ "792": "shovel",
779
+ "793": "shower cap",
780
+ "794": "shower curtain",
781
+ "795": "ski",
782
+ "796": "ski mask",
783
+ "797": "sleeping bag",
784
+ "798": "slide rule, slipstick",
785
+ "799": "sliding door",
786
+ "8": "hen",
787
+ "80": "black grouse",
788
+ "800": "slot, one-armed bandit",
789
+ "801": "snorkel",
790
+ "802": "snowmobile",
791
+ "803": "snowplow, snowplough",
792
+ "804": "soap dispenser",
793
+ "805": "soccer ball",
794
+ "806": "sock",
795
+ "807": "solar dish, solar collector, solar furnace",
796
+ "808": "sombrero",
797
+ "809": "soup bowl",
798
+ "81": "ptarmigan",
799
+ "810": "space bar",
800
+ "811": "space heater",
801
+ "812": "space shuttle",
802
+ "813": "spatula",
803
+ "814": "speedboat",
804
+ "815": "spider web, spiders web",
805
+ "816": "spindle",
806
+ "817": "sports car, sport car",
807
+ "818": "spotlight, spot",
808
+ "819": "stage",
809
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
810
+ "820": "steam locomotive",
811
+ "821": "steel arch bridge",
812
+ "822": "steel drum",
813
+ "823": "stethoscope",
814
+ "824": "stole",
815
+ "825": "stone wall",
816
+ "826": "stopwatch, stop watch",
817
+ "827": "stove",
818
+ "828": "strainer",
819
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
820
+ "83": "prairie chicken, prairie grouse, prairie fowl",
821
+ "830": "stretcher",
822
+ "831": "studio couch, day bed",
823
+ "832": "stupa, tope",
824
+ "833": "submarine, pigboat, sub, U-boat",
825
+ "834": "suit, suit of clothes",
826
+ "835": "sundial",
827
+ "836": "sunglass",
828
+ "837": "sunglasses, dark glasses, shades",
829
+ "838": "sunscreen, sunblock, sun blocker",
830
+ "839": "suspension bridge",
831
+ "84": "peacock",
832
+ "840": "swab, swob, mop",
833
+ "841": "sweatshirt",
834
+ "842": "swimming trunks, bathing trunks",
835
+ "843": "swing",
836
+ "844": "switch, electric switch, electrical switch",
837
+ "845": "syringe",
838
+ "846": "table lamp",
839
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
840
+ "848": "tape player",
841
+ "849": "teapot",
842
+ "85": "quail",
843
+ "850": "teddy, teddy bear",
844
+ "851": "television, television system",
845
+ "852": "tennis ball",
846
+ "853": "thatch, thatched roof",
847
+ "854": "theater curtain, theatre curtain",
848
+ "855": "thimble",
849
+ "856": "thresher, thrasher, threshing machine",
850
+ "857": "throne",
851
+ "858": "tile roof",
852
+ "859": "toaster",
853
+ "86": "partridge",
854
+ "860": "tobacco shop, tobacconist shop, tobacconist",
855
+ "861": "toilet seat",
856
+ "862": "torch",
857
+ "863": "totem pole",
858
+ "864": "tow truck, tow car, wrecker",
859
+ "865": "toyshop",
860
+ "866": "tractor",
861
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
862
+ "868": "tray",
863
+ "869": "trench coat",
864
+ "87": "African grey, African gray, Psittacus erithacus",
865
+ "870": "tricycle, trike, velocipede",
866
+ "871": "trimaran",
867
+ "872": "tripod",
868
+ "873": "triumphal arch",
869
+ "874": "trolleybus, trolley coach, trackless trolley",
870
+ "875": "trombone",
871
+ "876": "tub, vat",
872
+ "877": "turnstile",
873
+ "878": "typewriter keyboard",
874
+ "879": "umbrella",
875
+ "88": "macaw",
876
+ "880": "unicycle, monocycle",
877
+ "881": "upright, upright piano",
878
+ "882": "vacuum, vacuum cleaner",
879
+ "883": "vase",
880
+ "884": "vault",
881
+ "885": "velvet",
882
+ "886": "vending machine",
883
+ "887": "vestment",
884
+ "888": "viaduct",
885
+ "889": "violin, fiddle",
886
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
887
+ "890": "volleyball",
888
+ "891": "waffle iron",
889
+ "892": "wall clock",
890
+ "893": "wallet, billfold, notecase, pocketbook",
891
+ "894": "wardrobe, closet, press",
892
+ "895": "warplane, military plane",
893
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
894
+ "897": "washer, automatic washer, washing machine",
895
+ "898": "water bottle",
896
+ "899": "water jug",
897
+ "9": "ostrich, Struthio camelus",
898
+ "90": "lorikeet",
899
+ "900": "water tower",
900
+ "901": "whiskey jug",
901
+ "902": "whistle",
902
+ "903": "wig",
903
+ "904": "window screen",
904
+ "905": "window shade",
905
+ "906": "Windsor tie",
906
+ "907": "wine bottle",
907
+ "908": "wing",
908
+ "909": "wok",
909
+ "91": "coucal",
910
+ "910": "wooden spoon",
911
+ "911": "wool, woolen, woollen",
912
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
913
+ "913": "wreck",
914
+ "914": "yawl",
915
+ "915": "yurt",
916
+ "916": "web site, website, internet site, site",
917
+ "917": "comic book",
918
+ "918": "crossword puzzle, crossword",
919
+ "919": "street sign",
920
+ "92": "bee eater",
921
+ "920": "traffic light, traffic signal, stoplight",
922
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
923
+ "922": "menu",
924
+ "923": "plate",
925
+ "924": "guacamole",
926
+ "925": "consomme",
927
+ "926": "hot pot, hotpot",
928
+ "927": "trifle",
929
+ "928": "ice cream, icecream",
930
+ "929": "ice lolly, lolly, lollipop, popsicle",
931
+ "93": "hornbill",
932
+ "930": "French loaf",
933
+ "931": "bagel, beigel",
934
+ "932": "pretzel",
935
+ "933": "cheeseburger",
936
+ "934": "hotdog, hot dog, red hot",
937
+ "935": "mashed potato",
938
+ "936": "head cabbage",
939
+ "937": "broccoli",
940
+ "938": "cauliflower",
941
+ "939": "zucchini, courgette",
942
+ "94": "hummingbird",
943
+ "940": "spaghetti squash",
944
+ "941": "acorn squash",
945
+ "942": "butternut squash",
946
+ "943": "cucumber, cuke",
947
+ "944": "artichoke, globe artichoke",
948
+ "945": "bell pepper",
949
+ "946": "cardoon",
950
+ "947": "mushroom",
951
+ "948": "Granny Smith",
952
+ "949": "strawberry",
953
+ "95": "jacamar",
954
+ "950": "orange",
955
+ "951": "lemon",
956
+ "952": "fig",
957
+ "953": "pineapple, ananas",
958
+ "954": "banana",
959
+ "955": "jackfruit, jak, jack",
960
+ "956": "custard apple",
961
+ "957": "pomegranate",
962
+ "958": "hay",
963
+ "959": "carbonara",
964
+ "96": "toucan",
965
+ "960": "chocolate sauce, chocolate syrup",
966
+ "961": "dough",
967
+ "962": "meat loaf, meatloaf",
968
+ "963": "pizza, pizza pie",
969
+ "964": "potpie",
970
+ "965": "burrito",
971
+ "966": "red wine",
972
+ "967": "espresso",
973
+ "968": "cup",
974
+ "969": "eggnog",
975
+ "97": "drake",
976
+ "970": "alp",
977
+ "971": "bubble",
978
+ "972": "cliff, drop, drop-off",
979
+ "973": "coral reef",
980
+ "974": "geyser",
981
+ "975": "lakeside, lakeshore",
982
+ "976": "promontory, headland, head, foreland",
983
+ "977": "sandbar, sand bar",
984
+ "978": "seashore, coast, seacoast, sea-coast",
985
+ "979": "valley, vale",
986
+ "98": "red-breasted merganser, Mergus serrator",
987
+ "980": "volcano",
988
+ "981": "ballplayer, baseball player",
989
+ "982": "groom, bridegroom",
990
+ "983": "scuba diver",
991
+ "984": "rapeseed",
992
+ "985": "daisy",
993
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
994
+ "987": "corn",
995
+ "988": "acorn",
996
+ "989": "hip, rose hip, rosehip",
997
+ "99": "goose",
998
+ "990": "buckeye, horse chestnut, conker",
999
+ "991": "coral fungus",
1000
+ "992": "agaric",
1001
+ "993": "gyromitra",
1002
+ "994": "stinkhorn, carrion fungus",
1003
+ "995": "earthstar",
1004
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1005
+ "997": "bolete",
1006
+ "998": "ear, spike, capitulum",
1007
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1008
+ },
1009
+ "scheduler": [
1010
+ "scheduling_flow_match_promoe",
1011
+ "ProMoEFlowMatchScheduler"
1012
+ ],
1013
+ "transformer": [
1014
+ "transformer_promoe",
1015
+ "ProMoETransformer2DModel"
1016
+ ],
1017
+ "vae": [
1018
+ "diffusers",
1019
+ "AutoencoderKL"
1020
+ ]
1021
+ }
ProMoE-L-256/pipeline.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: ProMoEPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ from PIL import Image
15
+
16
+ try:
17
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
18
+ except Exception: # pragma: no cover
19
+ class DiffusionPipeline:
20
+ def __init__(self):
21
+ self._execution_device = torch.device("cpu")
22
+
23
+ def register_modules(self, **kwargs):
24
+ for key, value in kwargs.items():
25
+ setattr(self, key, value)
26
+
27
+ def to(self, device):
28
+ self._execution_device = torch.device(device)
29
+ for module in (getattr(self, "transformer", None), getattr(self, "vae", None)):
30
+ if module is not None and hasattr(module, "to"):
31
+ module.to(device)
32
+ return self
33
+
34
+ def progress_bar(self, iterable):
35
+ return iterable
36
+
37
+ def maybe_free_model_hooks(self):
38
+ return None
39
+
40
+ @dataclass
41
+ class ProMoEPipelineOutput:
42
+ images: Union[List[Image.Image], np.ndarray, torch.Tensor]
43
+
44
+ class ProMoEPipeline(DiffusionPipeline):
45
+ r"""
46
+ Pipeline for class-conditional image generation with ProMoE.
47
+
48
+ Parameters:
49
+ transformer ([`ProMoETransformer2DModel`]):
50
+ Class-conditional ProMoE transformer for flow-matching in latent space.
51
+ scheduler ([`ProMoEFlowMatchScheduler`]):
52
+ Flow-matching scheduler used during denoising.
53
+ vae ([`AutoencoderKL`], *optional*):
54
+ Variational autoencoder used to decode latents to pixels.
55
+ id2label (`dict[int, str]`, *optional*):
56
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
57
+ """
58
+
59
+ model_cpu_offload_seq = "transformer->vae"
60
+ _optional_components = ["vae"]
61
+
62
+ def __init__(
63
+ self,
64
+ transformer,
65
+ scheduler,
66
+ vae=None,
67
+ id2label: Optional[Dict[Union[int, str], str]] = None,
68
+ ):
69
+ super().__init__()
70
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
71
+ self._id2label = self._normalize_id2label(id2label)
72
+ self.labels = self._build_label2id(self._id2label)
73
+ self._labels_loaded_from_model_index = bool(self._id2label)
74
+
75
+ def _ensure_labels_loaded(self) -> None:
76
+ if self._labels_loaded_from_model_index:
77
+ return
78
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
79
+ if loaded:
80
+ self._id2label = loaded
81
+ self.labels = self._build_label2id(self._id2label)
82
+ self._labels_loaded_from_model_index = True
83
+
84
+ @staticmethod
85
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
86
+ if not id2label:
87
+ return {}
88
+ return {int(key): value for key, value in id2label.items()}
89
+
90
+ @staticmethod
91
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
92
+ if not variant_path:
93
+ return {}
94
+ variant_dir = Path(variant_path).resolve()
95
+ model_index_path = variant_dir / "model_index.json"
96
+ if not model_index_path.exists():
97
+ return {}
98
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
99
+ id2label = raw.get("id2label")
100
+ if not isinstance(id2label, dict):
101
+ return {}
102
+ return {int(key): value for key, value in id2label.items()}
103
+
104
+ @staticmethod
105
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
106
+ label2id: Dict[str, int] = {}
107
+ for class_id, value in id2label.items():
108
+ for synonym in value.split(","):
109
+ synonym = synonym.strip()
110
+ if synonym:
111
+ label2id[synonym] = int(class_id)
112
+ return dict(sorted(label2id.items()))
113
+
114
+ @property
115
+ def id2label(self) -> Dict[int, str]:
116
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
117
+ self._ensure_labels_loaded()
118
+ return self._id2label
119
+
120
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
121
+ r"""
122
+ Map ImageNet label strings to class ids.
123
+
124
+ Args:
125
+ label (`str` or `list[str]`):
126
+ One or more English label strings. Each string must match a synonym in `id2label`.
127
+ """
128
+ self._ensure_labels_loaded()
129
+ label2id = self.labels
130
+ if not label2id:
131
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
132
+
133
+ if isinstance(label, str):
134
+ label = [label]
135
+
136
+ missing = [item for item in label if item not in label2id]
137
+ if missing:
138
+ preview = ", ".join(list(label2id.keys())[:8])
139
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
140
+ return [label2id[item] for item in label]
141
+
142
+ def _get_vae_spatial_downsample(self) -> int:
143
+ if self.vae is None:
144
+ return 8
145
+ block_out_channels = getattr(getattr(self.vae, "config", None), "block_out_channels", [0, 0, 0, 0])
146
+ return 2 ** (len(block_out_channels) - 1)
147
+
148
+ def _normalize_class_labels(
149
+ self,
150
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
151
+ device: torch.device,
152
+ ) -> torch.LongTensor:
153
+ if torch.is_tensor(class_labels):
154
+ return class_labels.to(device=device, dtype=torch.long).reshape(-1)
155
+
156
+ if isinstance(class_labels, int):
157
+ class_label_ids = [class_labels]
158
+ elif isinstance(class_labels, str):
159
+ class_label_ids = self.get_label_ids(class_labels)
160
+ elif class_labels and isinstance(class_labels[0], str):
161
+ class_label_ids = self.get_label_ids(class_labels)
162
+ else:
163
+ class_label_ids = list(class_labels)
164
+
165
+ return torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)
166
+
167
+ def _prepare_latents(
168
+ self,
169
+ batch_size: int,
170
+ latent_height: int,
171
+ latent_width: int,
172
+ dtype: torch.dtype,
173
+ device: torch.device,
174
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
175
+ ) -> torch.Tensor:
176
+ shape = (batch_size, self.transformer.in_channels, latent_height, latent_width)
177
+ if isinstance(generator, list):
178
+ latents = [torch.randn((1, *shape[1:]), generator=g, device=device, dtype=dtype) for g in generator]
179
+ return torch.cat(latents, dim=0)
180
+ return torch.randn(shape, generator=generator, device=device, dtype=dtype)
181
+
182
+ def _decode_latents(self, latents: torch.Tensor, output_type: str):
183
+ if output_type == "latent":
184
+ return latents
185
+ if self.vae is not None:
186
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
187
+ decode_dtype = next(self.vae.parameters()).dtype
188
+ latents = (latents / scaling_factor).to(dtype=decode_dtype)
189
+ image = self.vae.decode(latents, return_dict=False)[0]
190
+ else:
191
+ image = latents
192
+
193
+ image = (image / 2 + 0.5).clamp(0, 1)
194
+ if output_type == "pt":
195
+ return image
196
+ image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
197
+ if output_type == "np":
198
+ return image
199
+ pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image]
200
+ return pil_images
201
+
202
+ @torch.no_grad()
203
+ def __call__(
204
+ self,
205
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
206
+ height: int = 256,
207
+ width: int = 256,
208
+ num_inference_steps: int = 50,
209
+ guidance_scale: float = 1.0,
210
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
211
+ output_type: str = "pil",
212
+ return_dict: bool = True,
213
+ ) -> Union[ProMoEPipelineOutput, Tuple]:
214
+ r"""
215
+ Generate class-conditional images with ProMoE.
216
+
217
+ Args:
218
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
219
+ ImageNet class indices or human-readable English label strings.
220
+ """
221
+ device = self._execution_device if hasattr(self, "_execution_device") else torch.device("cpu")
222
+ model_dtype = next(self.transformer.parameters()).dtype
223
+ class_labels = self._normalize_class_labels(class_labels, device)
224
+ batch_size = class_labels.shape[0]
225
+
226
+ vae_scale = self._get_vae_spatial_downsample()
227
+ latent_height = height // vae_scale
228
+ latent_width = width // vae_scale
229
+ latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator)
230
+
231
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
232
+ null_labels = torch.full_like(class_labels, getattr(self.transformer.backbone.y_embedder, "num_classes", 1000))
233
+
234
+ for t in self.progress_bar(self.scheduler.timesteps):
235
+ if guidance_scale > 1.0:
236
+ latent_input = torch.cat([latents, latents], dim=0)
237
+ labels = torch.cat([class_labels, null_labels], dim=0)
238
+ else:
239
+ latent_input = latents
240
+ labels = class_labels
241
+ timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype)
242
+ model_output = self.transformer(
243
+ hidden_states=latent_input,
244
+ timestep=timestep,
245
+ class_labels=labels,
246
+ return_dict=True,
247
+ ).sample
248
+ if model_output.shape[1] != latents.shape[1]:
249
+ model_output = model_output.chunk(2, dim=1)[0]
250
+ if guidance_scale > 1.0:
251
+ model_output_cond, model_output_uncond = model_output.chunk(2)
252
+ model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
253
+ latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample
254
+
255
+ images = self._decode_latents(latents, output_type)
256
+ self.maybe_free_model_hooks()
257
+ if not return_dict:
258
+ return (images,)
259
+ return ProMoEPipelineOutput(images=images)
ProMoE-L-256/scheduler/config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoEFlowMatchScheduler",
3
+ "num_train_timesteps": 1000,
4
+ "shift": 1.0
5
+ }
ProMoE-L-256/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoEFlowMatchScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
ProMoE-L-256/scheduler/scheduling_flow_match_promoe.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from types import SimpleNamespace
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ try:
8
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
9
+ except Exception: # pragma: no cover
10
+ FlowMatchEulerDiscreteScheduler = None
11
+
12
+
13
+ @dataclass
14
+ class ProMoEFlowMatchSchedulerOutput:
15
+ prev_sample: torch.FloatTensor
16
+
17
+
18
+ if FlowMatchEulerDiscreteScheduler is not None:
19
+
20
+ class ProMoEFlowMatchScheduler(FlowMatchEulerDiscreteScheduler):
21
+ pass
22
+
23
+ else:
24
+
25
+ class ProMoEFlowMatchScheduler:
26
+ def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
27
+ self.config = SimpleNamespace(num_train_timesteps=num_train_timesteps, shift=shift, stochastic_sampling=False)
28
+ self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1, dtype=torch.float32)
29
+
30
+ def set_timesteps(self, num_inference_steps: int, device: Optional[torch.device] = None):
31
+ self.timesteps = torch.linspace(
32
+ self.config.num_train_timesteps - 1,
33
+ 0,
34
+ num_inference_steps,
35
+ dtype=torch.float32,
36
+ device=device,
37
+ )
38
+
39
+ def step(self, model_output, timestep, sample, generator=None):
40
+ del generator
41
+ dt = 1.0 / max(len(self.timesteps), 1)
42
+ prev_sample = sample - dt * model_output
43
+ return ProMoEFlowMatchSchedulerOutput(prev_sample=prev_sample)
ProMoE-L-256/transformer/backbone_diffmoe.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .modeling_promoe_common import (
7
+ Attention,
8
+ FinalLayer,
9
+ LabelEmbedder,
10
+ Mlp,
11
+ MoeMLP_DiffMoE as MoeMLP,
12
+ PatchEmbed,
13
+ TimestepEmbedder,
14
+ get_2d_sincos_pos_embed,
15
+ modulate,
16
+ )
17
+
18
+
19
+ class SparseMoEBlock(nn.Module):
20
+ def __init__(
21
+ self,
22
+ experts,
23
+ hidden_dim,
24
+ num_experts,
25
+ n_shared_experts=0,
26
+ capacity=2,
27
+ mlp_ratio=4.0,
28
+ use_diff_expert=False,
29
+ ):
30
+ super().__init__()
31
+ self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim)))
32
+ nn.init.normal_(self.gate_weight, std=0.006)
33
+ self.experts = nn.ModuleList(experts)
34
+ self.capacity = capacity
35
+ self.num_experts = num_experts
36
+ self.n_shared_experts = n_shared_experts
37
+ self.use_diff_expert = use_diff_expert
38
+ if use_diff_expert:
39
+ self.diff_expert = MoeMLP(hidden_size=hidden_dim, intermediate_size=int(hidden_dim * mlp_ratio))
40
+
41
+ self.capacity_predictor = nn.Sequential(
42
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
43
+ nn.SiLU(),
44
+ nn.Linear(hidden_dim, self.num_experts, bias=True),
45
+ )
46
+
47
+ if self.n_shared_experts > 0:
48
+ mlp_hidden_dim = int(hidden_dim * mlp_ratio * 2)
49
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
50
+ self.shared_experts = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
51
+
52
+ self.register_buffer("expert_threshold", torch.tensor([0.0] * num_experts))
53
+ self.register_buffer("ema_decay", torch.tensor([0.95]))
54
+
55
+ def forward(self, x):
56
+ if self.training:
57
+ return self.forward_train(x)
58
+ return self.forward_eval(x)
59
+
60
+ def update_threshold(self, capacity_pred):
61
+ if not self.training:
62
+ return
63
+ capacity_pred = torch.sigmoid(capacity_pred)
64
+ seq_len = capacity_pred.size(0)
65
+ topk = int((seq_len / self.num_experts) * self.capacity)
66
+ threshold = self.expert_threshold
67
+ ema_decay = self.ema_decay
68
+ for i in range(self.num_experts):
69
+ scores, _ = torch.topk(capacity_pred[:, i], k=topk, dim=-1, sorted=True)
70
+ quantile = scores[-1].detach()
71
+ threshold[i] = threshold[i] * ema_decay + (1 - ema_decay) * quantile
72
+ if dist.is_available() and dist.is_initialized():
73
+ dist.all_reduce(threshold, op=dist.ReduceOp.SUM)
74
+ threshold /= dist.get_world_size()
75
+ self.expert_threshold = threshold
76
+
77
+ def forward_train(self, x):
78
+ bsz, seq_len, hidden_dim = x.shape
79
+ identity = x
80
+ x = x.view(-1, hidden_dim)
81
+ total_tokens = x.shape[0]
82
+ capacity_pred = self.capacity_predictor(x.detach())
83
+ k = int((total_tokens / self.num_experts) * self.capacity)
84
+ logits = F.linear(x, self.gate_weight, None)
85
+ scores = logits.softmax(dim=-1).permute(1, 0)
86
+ gating, index = torch.topk(scores, k=k, dim=-1, sorted=False)
87
+ mask = torch.zeros((self.num_experts, total_tokens), dtype=x.dtype, device=x.device)
88
+ mask.scatter_(1, index, 1.0)
89
+ expert_inputs = x[index]
90
+ expert_outputs = torch.stack([expert(expert_inputs[i]) for i, expert in enumerate(self.experts)])
91
+ gated_outputs = gating.unsqueeze(-1) * expert_outputs
92
+
93
+ y = torch.zeros((total_tokens * self.num_experts, hidden_dim), dtype=x.dtype, device=x.device)
94
+ offset = torch.arange(0, self.num_experts, device=x.device).unsqueeze(1) * total_tokens
95
+ flat_index = (index + offset.long()).view(-1)
96
+ y = torch.scatter(y, 0, flat_index.unsqueeze(1).expand(-1, hidden_dim), gated_outputs.view(-1, hidden_dim))
97
+ y = y.view(self.num_experts, total_tokens, hidden_dim).sum(dim=0, keepdim=False)
98
+
99
+ self.update_threshold(capacity_pred)
100
+ x_out = y.view(bsz, seq_len, hidden_dim)
101
+ ones = mask.permute(1, 0).view(bsz, seq_len, self.num_experts)
102
+ capacity_pred = capacity_pred.view(bsz, seq_len, self.num_experts)
103
+ if self.n_shared_experts > 0:
104
+ x_out = x_out + self.shared_experts(identity)
105
+ if self.use_diff_expert:
106
+ x_out = x_out - self.diff_expert(identity)
107
+ return x_out, ones, capacity_pred
108
+
109
+ def forward_eval(self, x):
110
+ bsz, seq_len, hidden_dim = x.shape
111
+ identity = x
112
+ x = x.view(-1, hidden_dim)
113
+ total_tokens = x.shape[0]
114
+ capacity_pred = torch.sigmoid(self.capacity_predictor(x.detach()))
115
+ threshold = self.expert_threshold
116
+ logits = F.linear(x, self.gate_weight, None)
117
+ scores = logits.softmax(dim=-1).permute(-1, -2)
118
+ y = torch.zeros_like(x, dtype=x.dtype)
119
+ for i, expert in enumerate(self.experts):
120
+ k_fixed = torch.where(capacity_pred[:, i] > threshold[i], 1, 0).sum()
121
+ gating, index = torch.topk(scores[i], k=k_fixed, dim=-1, sorted=False)
122
+ y[index, :] += gating.unsqueeze(-1) * expert(x[index, :])
123
+ x_out = y.view(bsz, seq_len, hidden_dim)
124
+ if self.n_shared_experts > 0:
125
+ x_out = x_out + self.shared_experts(identity)
126
+ return x_out, None, None
127
+
128
+
129
+ class DiTBlock(nn.Module):
130
+ def __init__(
131
+ self,
132
+ hidden_size,
133
+ num_heads,
134
+ head_dim=None,
135
+ mlp_ratio=4.0,
136
+ use_swiglu=False,
137
+ MoE_config=None,
138
+ use_moe=False,
139
+ qk_norm=False,
140
+ **block_kwargs,
141
+ ):
142
+ super().__init__()
143
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
144
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=qk_norm, **block_kwargs)
145
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
146
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
147
+ self.use_moe = use_moe
148
+ if use_moe:
149
+ if not use_swiglu:
150
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
151
+ experts = [
152
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
153
+ for _ in range(MoE_config.num_experts)
154
+ ]
155
+ else:
156
+ experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)]
157
+ self.mlp = SparseMoEBlock(
158
+ experts=experts,
159
+ hidden_dim=hidden_size,
160
+ num_experts=MoE_config.num_experts,
161
+ capacity=MoE_config.capacity,
162
+ n_shared_experts=MoE_config.n_shared_experts,
163
+ mlp_ratio=4.0,
164
+ )
165
+ else:
166
+ if not use_swiglu:
167
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
169
+ else:
170
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
171
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
172
+
173
+ def forward(self, x, c):
174
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
175
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
176
+ if self.use_moe:
177
+ x_mlp, ones, pred_c = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
178
+ x = x + gate_mlp.unsqueeze(1) * x_mlp
179
+ return x, ones, pred_c
180
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
181
+ return x, None, None
182
+
183
+
184
+ class DiT(nn.Module):
185
+ def __init__(
186
+ self,
187
+ input_size=32,
188
+ patch_size=2,
189
+ in_channels=4,
190
+ hidden_size=1152,
191
+ depth=28,
192
+ num_heads=16,
193
+ mlp_ratio=4.0,
194
+ qk_norm=False,
195
+ class_dropout_prob=0.1,
196
+ num_classes=1000,
197
+ learn_sigma=True,
198
+ use_swiglu=False,
199
+ MoE_config=None,
200
+ head_dim=None,
201
+ CapacityPred_loss_weight=0.01,
202
+ ):
203
+ super().__init__()
204
+ self.learn_sigma = learn_sigma
205
+ self.in_channels = in_channels
206
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
207
+ self.patch_size = patch_size
208
+ self.num_heads = num_heads
209
+ self.MoE_config = MoE_config
210
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
211
+ self.CapacityPred_loss_weight = CapacityPred_loss_weight
212
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
213
+ self.t_embedder = TimestepEmbedder(hidden_size)
214
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
215
+ num_patches = self.x_embedder.num_patches
216
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
217
+ self.blocks = nn.ModuleList(
218
+ [
219
+ DiTBlock(
220
+ hidden_size,
221
+ num_heads,
222
+ head_dim=head_dim,
223
+ mlp_ratio=mlp_ratio,
224
+ qk_norm=qk_norm,
225
+ use_swiglu=use_swiglu,
226
+ MoE_config=MoE_config,
227
+ use_moe=use_moe_flag[i],
228
+ )
229
+ for i in range(depth)
230
+ ]
231
+ )
232
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
233
+ self.init_MoeMLP = MoE_config.init_MoeMLP
234
+ self.initialize_weights()
235
+ self.capacity_schedule = MoE_config.get("capacity_schedule", None)
236
+ if self.capacity_schedule:
237
+ self.training_iters = -1
238
+
239
+ def initialize_weights(self):
240
+ def _basic_init(module):
241
+ if isinstance(module, nn.Linear):
242
+ torch.nn.init.xavier_uniform_(module.weight)
243
+ if module.bias is not None:
244
+ nn.init.constant_(module.bias, 0)
245
+
246
+ self.apply(_basic_init)
247
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
248
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
249
+ w = self.x_embedder.proj.weight.data
250
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
251
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
252
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
253
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
254
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
255
+ for block in self.blocks:
256
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
257
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
258
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
259
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
260
+ nn.init.constant_(self.final_layer.linear.weight, 0)
261
+ nn.init.constant_(self.final_layer.linear.bias, 0)
262
+
263
+ def unpatchify(self, x):
264
+ c = self.out_channels
265
+ p = self.x_embedder.patch_size[0]
266
+ h = w = int(x.shape[1] ** 0.5)
267
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
268
+ x = torch.einsum("nhwpqc->nchpwq", x)
269
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
270
+
271
+ def forward(self, x, t, context, **kwargs):
272
+ y = context
273
+ if len(x.shape) != 4:
274
+ x = x.squeeze(2)
275
+
276
+ if self.training and self.capacity_schedule:
277
+ num_experts = self.MoE_config.num_experts
278
+ capacity = self.MoE_config.capacity
279
+ stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters
280
+ stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters
281
+ if self.training_iters <= stage_i:
282
+ capacity = num_experts
283
+ elif self.training_iters <= stage_ii:
284
+ capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i)
285
+ for block in self.blocks:
286
+ if hasattr(block.mlp, "capacity"):
287
+ block.mlp.capacity = capacity
288
+
289
+ x = self.x_embedder(x) + self.pos_embed
290
+ t = self.t_embedder(t)
291
+ y = self.y_embedder(y, self.training)
292
+ c = t + y
293
+ ones_list, pred_c_list, layer_idx_list = [], [], []
294
+ for layer_idx, block in enumerate(self.blocks):
295
+ x, ones, pred_c = block(x, c)
296
+ if ones is not None:
297
+ ones_list.append(ones)
298
+ pred_c_list.append(pred_c)
299
+ layer_idx_list.append(layer_idx)
300
+ x = self.final_layer(x, c)
301
+ x = self.unpatchify(x)
302
+ return x, "Capacity_Pred", layer_idx_list, ones_list, pred_c_list, self.CapacityPred_loss_weight
ProMoE-L-256/transformer/backbone_dit.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .modeling_promoe_common import (
5
+ Attention,
6
+ FinalLayer,
7
+ LabelEmbedder,
8
+ Mlp,
9
+ PatchEmbed,
10
+ TimestepEmbedder,
11
+ get_2d_sincos_pos_embed,
12
+ modulate,
13
+ )
14
+
15
+
16
+ class DiTBlock(nn.Module):
17
+ def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs):
18
+ super().__init__()
19
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
20
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
21
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
22
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
23
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
24
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
25
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
26
+
27
+ def forward(self, x, c):
28
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
29
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
30
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
31
+ return x
32
+
33
+
34
+ class DiT(nn.Module):
35
+ def __init__(
36
+ self,
37
+ input_size=32,
38
+ patch_size=2,
39
+ in_channels=4,
40
+ hidden_size=1152,
41
+ depth=28,
42
+ num_heads=16,
43
+ mlp_ratio=4.0,
44
+ qk_norm=False,
45
+ class_dropout_prob=0.1,
46
+ num_classes=1000,
47
+ learn_sigma=True,
48
+ head_dim=None,
49
+ use_swiglu=False,
50
+ ):
51
+ super().__init__()
52
+ self.learn_sigma = learn_sigma
53
+ self.in_channels = in_channels
54
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
55
+ self.patch_size = patch_size
56
+ self.num_heads = num_heads
57
+
58
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
59
+ self.t_embedder = TimestepEmbedder(hidden_size)
60
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
61
+ num_patches = self.x_embedder.num_patches
62
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
63
+
64
+ self.blocks = nn.ModuleList(
65
+ [
66
+ DiTBlock(
67
+ hidden_size,
68
+ num_heads,
69
+ head_dim=head_dim,
70
+ mlp_ratio=mlp_ratio,
71
+ qk_norm=qk_norm,
72
+ use_swiglu=use_swiglu,
73
+ )
74
+ for _ in range(depth)
75
+ ]
76
+ )
77
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
78
+ self.initialize_weights()
79
+
80
+ def initialize_weights(self):
81
+ def _basic_init(module):
82
+ if isinstance(module, nn.Linear):
83
+ torch.nn.init.xavier_uniform_(module.weight)
84
+ if module.bias is not None:
85
+ nn.init.constant_(module.bias, 0)
86
+
87
+ self.apply(_basic_init)
88
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
89
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
90
+ w = self.x_embedder.proj.weight.data
91
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
92
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
93
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
94
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
95
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
96
+ for block in self.blocks:
97
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
98
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
99
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
100
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
101
+ nn.init.constant_(self.final_layer.linear.weight, 0)
102
+ nn.init.constant_(self.final_layer.linear.bias, 0)
103
+
104
+ def unpatchify(self, x):
105
+ c = self.out_channels
106
+ p = self.x_embedder.patch_size[0]
107
+ h = w = int(x.shape[1] ** 0.5)
108
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
109
+ x = torch.einsum("nhwpqc->nchpwq", x)
110
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
111
+
112
+ def forward(self, x, t, context, **kwargs):
113
+ y = context
114
+ if len(x.shape) != 4:
115
+ x = x.squeeze(2)
116
+ x = self.x_embedder(x) + self.pos_embed
117
+ t = self.t_embedder(t)
118
+ y = self.y_embedder(y, self.training)
119
+ c = t + y
120
+ for block in self.blocks:
121
+ x = block(x, c)
122
+ x = self.final_layer(x, c)
123
+ return self.unpatchify(x)
ProMoE-L-256/transformer/backbone_ecdit.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP_DiffMoE as MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class SparseMoEBlock(nn.Module):
19
+ def __init__(self, experts, hidden_dim, num_experts, n_shared_experts=0, capacity=2):
20
+ super().__init__()
21
+ self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim)))
22
+ nn.init.normal_(self.gate_weight, std=0.006)
23
+ self.experts = nn.ModuleList(experts)
24
+ self.capacity = capacity
25
+ self.num_experts = num_experts
26
+ self.n_shared_experts = n_shared_experts
27
+ if self.n_shared_experts > 0:
28
+ intermediate_size = hidden_dim * self.n_shared_experts
29
+ self.shared_experts = MoeMLP(hidden_size=hidden_dim, intermediate_size=intermediate_size, pretraining_tp=2)
30
+
31
+ def forward(self, x):
32
+ identity = x
33
+ batch_size, seq_len, _ = x.shape
34
+ logits = F.linear(x, self.gate_weight, None)
35
+ affinity = logits.softmax(dim=-1)
36
+ affinity = torch.einsum("b s e -> b e s", affinity)
37
+ k = int((seq_len / self.num_experts) * self.capacity)
38
+ gating, index = torch.topk(affinity, k=k, dim=-1, sorted=False)
39
+ dispatch = F.one_hot(index, num_classes=seq_len).to(device=x.device, dtype=x.dtype)
40
+ x_in = torch.einsum("b e c s, b s d -> b e c d", dispatch, x)
41
+ x_e = [self.experts[e](x_in[:, e]) for e in range(self.num_experts)]
42
+ x_e = torch.stack(x_e, dim=1)
43
+ x_out = torch.einsum("b e c s, b e c, b e c d -> b s d", dispatch, gating, x_e)
44
+ if self.n_shared_experts > 0:
45
+ x_out = x_out + self.shared_experts(identity)
46
+ return x_out
47
+
48
+
49
+ class DiTBlock(nn.Module):
50
+ def __init__(
51
+ self,
52
+ hidden_size,
53
+ num_heads,
54
+ head_dim=None,
55
+ mlp_ratio=4.0,
56
+ use_swiglu=False,
57
+ MoE_config=None,
58
+ use_moe=False,
59
+ **block_kwargs,
60
+ ):
61
+ super().__init__()
62
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
63
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
64
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
65
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
66
+ if use_moe:
67
+ if not use_swiglu:
68
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
69
+ experts = [
70
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
71
+ for _ in range(MoE_config.num_experts)
72
+ ]
73
+ else:
74
+ experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)]
75
+ self.mlp = SparseMoEBlock(
76
+ experts=experts,
77
+ hidden_dim=hidden_size,
78
+ num_experts=MoE_config.num_experts,
79
+ capacity=MoE_config.capacity,
80
+ n_shared_experts=MoE_config.n_shared_experts,
81
+ )
82
+ else:
83
+ if not use_swiglu:
84
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
85
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
86
+ else:
87
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
88
+
89
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
90
+
91
+ def forward(self, x, c):
92
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
93
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
94
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
95
+ return x
96
+
97
+
98
+ class DiT(nn.Module):
99
+ def __init__(
100
+ self,
101
+ input_size=32,
102
+ patch_size=2,
103
+ in_channels=4,
104
+ hidden_size=1152,
105
+ depth=28,
106
+ num_heads=16,
107
+ mlp_ratio=4.0,
108
+ qk_norm=False,
109
+ class_dropout_prob=0.1,
110
+ num_classes=1000,
111
+ learn_sigma=True,
112
+ use_swiglu=False,
113
+ MoE_config=None,
114
+ head_dim=None,
115
+ ):
116
+ super().__init__()
117
+ self.learn_sigma = learn_sigma
118
+ self.in_channels = in_channels
119
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
120
+ self.patch_size = patch_size
121
+ self.num_heads = num_heads
122
+ self.MoE_config = MoE_config
123
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
124
+
125
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
126
+ self.t_embedder = TimestepEmbedder(hidden_size)
127
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
128
+ num_patches = self.x_embedder.num_patches
129
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
130
+ self.blocks = nn.ModuleList(
131
+ [
132
+ DiTBlock(
133
+ hidden_size,
134
+ num_heads,
135
+ head_dim=head_dim,
136
+ mlp_ratio=mlp_ratio,
137
+ qk_norm=qk_norm,
138
+ use_swiglu=use_swiglu,
139
+ MoE_config=MoE_config,
140
+ use_moe=use_moe_flag[i],
141
+ )
142
+ for i in range(depth)
143
+ ]
144
+ )
145
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
146
+ self.init_MoeMLP = MoE_config.init_MoeMLP
147
+ self.initialize_weights()
148
+ self.capacity_schedule = MoE_config.get("capacity_schedule", None)
149
+ if self.capacity_schedule:
150
+ self.training_iters = -1
151
+
152
+ def initialize_weights(self):
153
+ def _basic_init(module):
154
+ if isinstance(module, nn.Linear):
155
+ torch.nn.init.xavier_uniform_(module.weight)
156
+ if module.bias is not None:
157
+ nn.init.constant_(module.bias, 0)
158
+
159
+ self.apply(_basic_init)
160
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
161
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
162
+ w = self.x_embedder.proj.weight.data
163
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
164
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
165
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
166
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
167
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
168
+ for block in self.blocks:
169
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
170
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
171
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
172
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
173
+ nn.init.constant_(self.final_layer.linear.weight, 0)
174
+ nn.init.constant_(self.final_layer.linear.bias, 0)
175
+
176
+ def init_moe_mlp(module, std=0.006):
177
+ nn.init.normal_(module.gate_proj.weight, std=std)
178
+ nn.init.normal_(module.up_proj.weight, std=std)
179
+ nn.init.normal_(module.down_proj.weight, std=std)
180
+
181
+ if self.init_MoeMLP:
182
+ for block in self.blocks:
183
+ if hasattr(block.mlp, "experts"):
184
+ for expert in block.mlp.experts:
185
+ if hasattr(expert, "gate_proj"):
186
+ init_moe_mlp(expert)
187
+
188
+ def unpatchify(self, x):
189
+ c = self.out_channels
190
+ p = self.x_embedder.patch_size[0]
191
+ h = w = int(x.shape[1] ** 0.5)
192
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
193
+ x = torch.einsum("nhwpqc->nchpwq", x)
194
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
195
+
196
+ def forward(self, x, t, context, **kwargs):
197
+ y = context
198
+ if len(x.shape) != 4:
199
+ x = x.squeeze(2)
200
+ if self.training and self.capacity_schedule:
201
+ num_experts = self.MoE_config.num_experts
202
+ capacity = self.MoE_config.capacity
203
+ stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters
204
+ stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters
205
+ if self.training_iters <= stage_i:
206
+ capacity = num_experts
207
+ elif self.training_iters <= stage_ii:
208
+ capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i)
209
+ for block in self.blocks:
210
+ if hasattr(block.mlp, "capacity"):
211
+ block.mlp.capacity = capacity
212
+
213
+ x = self.x_embedder(x) + self.pos_embed
214
+ t = self.t_embedder(t)
215
+ y = self.y_embedder(y, self.training)
216
+ c = t + y
217
+ for block in self.blocks:
218
+ x = block(x, c)
219
+ x = self.final_layer(x, c)
220
+ return self.unpatchify(x)
ProMoE-L-256/transformer/backbone_promoe_ec.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class AddAuxiliaryLoss(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, x, loss):
21
+ ctx.dtype = loss.dtype
22
+ ctx.required_aux_loss = loss.requires_grad
23
+ return x
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output):
27
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
28
+ return grad_output, grad_loss
29
+
30
+
31
+ class SparseMoeBlock(nn.Module):
32
+ def __init__(
33
+ self,
34
+ num_routed_experts,
35
+ hidden_size,
36
+ moe_intermediate_size,
37
+ shared_expert_intermediate_size,
38
+ top_k=1,
39
+ load_balance_loss_coef=0,
40
+ norm_topk_prob=False,
41
+ seq_aux=False,
42
+ use_shared_expert=True,
43
+ use_uncond_expert=True,
44
+ router_weight_mode="softmax",
45
+ routing_contrastive_lam=0,
46
+ use_top_k_for_routing_contrastive=False,
47
+ routing_contrastive_temperature=0.1,
48
+ **kwargs,
49
+ ):
50
+ super().__init__()
51
+ del load_balance_loss_coef, norm_topk_prob, seq_aux, use_top_k_for_routing_contrastive
52
+ self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts
53
+ self.num_routed_experts = num_routed_experts
54
+ self.hidden_size = hidden_size
55
+ self.top_k = top_k
56
+ self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size))
57
+ self.use_shared_expert = use_shared_expert
58
+ self.use_uncond_expert = use_uncond_expert
59
+ self.router_weight_mode = router_weight_mode
60
+ self.routing_contrastive_lam = routing_contrastive_lam
61
+ self.routing_contrastive_temperature = routing_contrastive_temperature
62
+ self.experts = nn.ModuleList(
63
+ [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)]
64
+ )
65
+ if use_shared_expert:
66
+ self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size)
67
+ self._init_weights()
68
+
69
+ def compute_router(self, cond_hidden_states):
70
+ b_cond, seq_len, _ = cond_hidden_states.shape
71
+ num_cond_experts = self.num_routed_experts
72
+ input_norm = F.normalize(cond_hidden_states, p=2, dim=-1)
73
+ cluster_norm = F.normalize(self.cluster_centers, p=2, dim=-1)
74
+ cos_sim = input_norm @ cluster_norm.T
75
+ cos_sim_expert_view = cos_sim.transpose(1, 2)
76
+ if self.router_weight_mode == "softmax":
77
+ cond_weights = F.softmax(cos_sim_expert_view, dim=-1)
78
+ elif self.router_weight_mode == "sigmoid":
79
+ cond_weights = torch.sigmoid(cos_sim_expert_view)
80
+ elif self.router_weight_mode == "identity":
81
+ cond_weights = cos_sim_expert_view
82
+ else:
83
+ raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}")
84
+ k = max(1, min(int((seq_len / num_cond_experts) * self.top_k), seq_len))
85
+ router_weights, indices = torch.topk(cond_weights, k=k, dim=-1, sorted=False)
86
+ dispatch_mask = F.one_hot(indices, num_classes=seq_len).to(dtype=cond_hidden_states.dtype)
87
+ expert_inputs = torch.einsum("becs,bsd->becd", dispatch_mask, cond_hidden_states)
88
+ return dispatch_mask, router_weights, expert_inputs
89
+
90
+ def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor):
91
+ identity = hidden_states
92
+ batch_size, _, hidden_dim = hidden_states.shape
93
+ final_output = torch.zeros_like(hidden_states)
94
+ loss = None
95
+ cond_batch_mask = (
96
+ labels.view(-1) != 1000
97
+ ) if self.use_uncond_expert else torch.ones(batch_size, dtype=torch.bool, device=hidden_states.device)
98
+ uncond_batch_mask = ~cond_batch_mask
99
+ cond_experts = self.experts[:-1] if self.use_uncond_expert else self.experts
100
+
101
+ if cond_batch_mask.any():
102
+ cond_hidden_states = hidden_states[cond_batch_mask]
103
+ dispatch_mask, gating_scores, expert_inputs = self.compute_router(cond_hidden_states)
104
+ num_cond_experts = len(cond_experts)
105
+ expert_outputs = torch.stack([cond_experts[e](expert_inputs[:, e]) for e in range(num_cond_experts)], dim=1)
106
+ cond_output = torch.einsum("becs,bec,becd->bsd", dispatch_mask, gating_scores, expert_outputs).to(hidden_states.dtype)
107
+ final_output[cond_batch_mask] = cond_output
108
+ if self.training and self.routing_contrastive_lam > 0 and num_cond_experts > 1:
109
+ expert_token_means = expert_inputs.mean(dim=2)
110
+ routing_contrastive_loss = self.compute_routing_contrastive_loss(expert_token_means)
111
+ loss = routing_contrastive_loss * self.routing_contrastive_lam
112
+ else:
113
+ dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
114
+ for expert in cond_experts:
115
+ final_output = final_output + expert(dummy_input).sum() * 0
116
+
117
+ if self.use_uncond_expert:
118
+ if uncond_batch_mask.any():
119
+ uncond_hidden_states = hidden_states[uncond_batch_mask]
120
+ final_output[uncond_batch_mask] = self.experts[-1](uncond_hidden_states).to(final_output.dtype)
121
+ else:
122
+ dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
123
+ final_output = final_output + self.experts[-1](dummy_input).sum() * 0
124
+
125
+ if self.use_shared_expert:
126
+ final_output += self.shared_expert(identity).to(hidden_states.dtype)
127
+ return final_output, loss
128
+
129
+ def compute_routing_contrastive_loss(self, expert_token_means):
130
+ batch_size, num_cond_experts, _ = expert_token_means.shape
131
+ if num_cond_experts < 2:
132
+ return torch.tensor(0.0, device=expert_token_means.device)
133
+ centers_norm = F.normalize(self.cluster_centers, p=2, dim=1)
134
+ means_norm = F.normalize(expert_token_means, p=2, dim=2)
135
+ sim_matrix = torch.einsum("id,bjd->bij", centers_norm, means_norm)
136
+ logits = sim_matrix / self.routing_contrastive_temperature
137
+ labels = torch.arange(num_cond_experts, device=logits.device).unsqueeze(0).expand(batch_size, -1)
138
+ return F.cross_entropy(logits.reshape(batch_size * num_cond_experts, -1), labels.reshape(-1))
139
+
140
+ def _init_weights(self):
141
+ nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02)
142
+
143
+
144
+ class DiTBlock(nn.Module):
145
+ def __init__(
146
+ self,
147
+ hidden_size,
148
+ num_heads,
149
+ head_dim=None,
150
+ mlp_ratio=4.0,
151
+ use_swiglu=False,
152
+ MoE_config=None,
153
+ use_moe=False,
154
+ **block_kwargs,
155
+ ):
156
+ super().__init__()
157
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
158
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
159
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
160
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
161
+ self.use_moe = use_moe
162
+ if use_moe:
163
+ self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config)
164
+ else:
165
+ if not use_swiglu:
166
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
167
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
168
+ else:
169
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
170
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
171
+
172
+ def forward(self, x, c, label):
173
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
174
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
175
+ if self.use_moe:
176
+ x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label)
177
+ if aux_loss is not None:
178
+ x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss)
179
+ return x + gate_mlp.unsqueeze(1) * x_mlp
180
+ return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
181
+
182
+
183
+ class DiT(nn.Module):
184
+ def __init__(
185
+ self,
186
+ input_size=32,
187
+ patch_size=2,
188
+ in_channels=4,
189
+ hidden_size=1152,
190
+ depth=28,
191
+ num_heads=16,
192
+ mlp_ratio=4.0,
193
+ qk_norm=False,
194
+ class_dropout_prob=0.1,
195
+ num_classes=1000,
196
+ learn_sigma=True,
197
+ use_swiglu=False,
198
+ MoE_config=None,
199
+ head_dim=None,
200
+ ):
201
+ super().__init__()
202
+ self.learn_sigma = learn_sigma
203
+ self.in_channels = in_channels
204
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
205
+ self.patch_size = patch_size
206
+ self.num_heads = num_heads
207
+ self.MoE_config = MoE_config
208
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
209
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
210
+ self.t_embedder = TimestepEmbedder(hidden_size)
211
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True)
212
+ num_patches = self.x_embedder.num_patches
213
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
214
+ self.blocks = nn.ModuleList(
215
+ [
216
+ DiTBlock(
217
+ hidden_size,
218
+ num_heads,
219
+ head_dim=head_dim,
220
+ mlp_ratio=mlp_ratio,
221
+ qk_norm=qk_norm,
222
+ use_swiglu=use_swiglu,
223
+ MoE_config=MoE_config,
224
+ use_moe=use_moe_flag[i],
225
+ )
226
+ for i in range(depth)
227
+ ]
228
+ )
229
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
230
+ self.init_MoeMLP = MoE_config.init_MoeMLP
231
+ self.initialize_weights()
232
+
233
+ def initialize_weights(self):
234
+ def _basic_init(module):
235
+ if isinstance(module, nn.Linear):
236
+ torch.nn.init.xavier_uniform_(module.weight)
237
+ if module.bias is not None:
238
+ nn.init.constant_(module.bias, 0)
239
+
240
+ self.apply(_basic_init)
241
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
242
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
243
+ w = self.x_embedder.proj.weight.data
244
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
245
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
246
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
247
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
248
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
249
+ for block in self.blocks:
250
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
251
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
252
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
253
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
254
+ nn.init.constant_(self.final_layer.linear.weight, 0)
255
+ nn.init.constant_(self.final_layer.linear.bias, 0)
256
+
257
+ def init_moe_mlp(module, std=0.006):
258
+ nn.init.normal_(module.up_proj.weight, std=std)
259
+ nn.init.normal_(module.down_proj.weight, std=std)
260
+
261
+ if self.init_MoeMLP:
262
+ for block in self.blocks:
263
+ if hasattr(block.mlp, "experts"):
264
+ for expert in block.mlp.experts:
265
+ init_moe_mlp(expert)
266
+
267
+ def unpatchify(self, x):
268
+ c = self.out_channels
269
+ p = self.x_embedder.patch_size[0]
270
+ h = w = int(x.shape[1] ** 0.5)
271
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
272
+ x = torch.einsum("nhwpqc->nchpwq", x)
273
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
274
+
275
+ def forward(self, x, timestep, context, **kwargs):
276
+ y = context
277
+ if len(x.shape) != 4:
278
+ x = x.squeeze(2)
279
+ x = self.x_embedder(x) + self.pos_embed
280
+ t = self.t_embedder(timestep)
281
+ y, labels = self.y_embedder(y, self.training)
282
+ c = t + y
283
+ for block in self.blocks:
284
+ x = block(x, c, labels)
285
+ x = self.final_layer(x, c)
286
+ return self.unpatchify(x)
ProMoE-L-256/transformer/backbone_promoe_tc.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class AddAuxiliaryLoss(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, x, loss):
21
+ ctx.dtype = loss.dtype
22
+ ctx.required_aux_loss = loss.requires_grad
23
+ return x
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output):
27
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
28
+ return grad_output, grad_loss
29
+
30
+
31
+ class SparseMoeBlock(nn.Module):
32
+ def __init__(
33
+ self,
34
+ num_routed_experts,
35
+ hidden_size,
36
+ moe_intermediate_size,
37
+ shared_expert_intermediate_size,
38
+ top_k=2,
39
+ load_balance_loss_coef=0,
40
+ norm_topk_prob=False,
41
+ seq_aux=False,
42
+ use_shared_expert=True,
43
+ use_uncond_expert=True,
44
+ router_weight_mode="softmax",
45
+ routing_contrastive_lam=0,
46
+ use_top_k_for_routing_contrastive=False,
47
+ routing_contrastive_temperature=0.1,
48
+ **kwargs,
49
+ ):
50
+ super().__init__()
51
+ del norm_topk_prob
52
+ self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts
53
+ self.num_routed_experts = num_routed_experts
54
+ self.seq_aux = seq_aux
55
+ self.hidden_size = hidden_size
56
+ self.top_k = top_k
57
+ self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size))
58
+ self.alpha = load_balance_loss_coef
59
+ self.use_shared_expert = use_shared_expert
60
+ self.use_uncond_expert = use_uncond_expert
61
+ self.router_weight_mode = router_weight_mode
62
+ self.routing_contrastive_lam = routing_contrastive_lam
63
+ self.use_top_k_for_routing_contrastive = use_top_k_for_routing_contrastive
64
+ self.routing_contrastive_temperature = routing_contrastive_temperature
65
+ self.experts = nn.ModuleList(
66
+ [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)]
67
+ )
68
+ if use_shared_expert:
69
+ self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size)
70
+ self._init_weights()
71
+
72
+ def compute_router(self, hidden_states, labels):
73
+ batch_size, seq_len, _ = hidden_states.shape
74
+ device = hidden_states.device
75
+ flat_input = hidden_states.view(-1, self.hidden_size)
76
+ flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1)
77
+ if self.use_uncond_expert and flat_labels is not None:
78
+ uncond_mask = flat_labels == 1000
79
+ cond_mask = ~uncond_mask
80
+ else:
81
+ uncond_mask = None
82
+ cond_mask = torch.ones_like(flat_labels, dtype=torch.bool)
83
+
84
+ router_weights = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=hidden_states.dtype)
85
+ expert_indices = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=torch.long)
86
+
87
+ if uncond_mask is not None and uncond_mask.any():
88
+ uncond_positions = torch.where(uncond_mask)[0]
89
+ router_weights[uncond_positions, 0] = 1.0
90
+ expert_indices[uncond_positions] = self.num_experts - 1
91
+
92
+ cond_weights = None
93
+ topk_idx = None
94
+ if cond_mask.any():
95
+ cond_positions = torch.where(cond_mask)[0]
96
+ cond_input = flat_input[cond_positions]
97
+ input_norm = F.normalize(cond_input, p=2, dim=1)
98
+ cluster_norm = F.normalize(self.cluster_centers, p=2, dim=1)
99
+ cos_sim = input_norm @ cluster_norm.T
100
+ if self.router_weight_mode == "softmax":
101
+ cond_weights = F.softmax(cos_sim, dim=1)
102
+ elif self.router_weight_mode == "sigmoid":
103
+ cond_weights = torch.sigmoid(cos_sim)
104
+ elif self.router_weight_mode == "identity":
105
+ cond_weights = cos_sim
106
+ else:
107
+ raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}")
108
+ topk_scores, topk_idx = torch.topk(cond_weights, k=self.top_k, dim=1)
109
+ router_weights[cond_positions] = topk_scores.to(router_weights.dtype)
110
+ expert_indices[cond_positions] = topk_idx
111
+
112
+ router_weights = router_weights.view(batch_size, seq_len, self.top_k)
113
+ expert_indices = expert_indices.view(batch_size, seq_len, self.top_k)
114
+
115
+ load_balance_loss = None
116
+ if self.training and self.alpha > 0.0 and cond_weights is not None and topk_idx is not None:
117
+ cond_batch_size = (labels != 1000).sum()
118
+ scores_for_aux = F.softmax(cond_weights, dim=1) if self.router_weight_mode != "softmax" else cond_weights
119
+ topk_idx_for_aux_loss = topk_idx.view(cond_batch_size, -1)
120
+ if self.seq_aux:
121
+ scores_for_seq_aux = scores_for_aux.view(cond_batch_size, seq_len, -1)
122
+ ce = torch.zeros(cond_batch_size, self.num_routed_experts, device=hidden_states.device)
123
+ ce.scatter_add_(
124
+ 1,
125
+ topk_idx_for_aux_loss,
126
+ torch.ones(cond_batch_size, seq_len * self.top_k, device=hidden_states.device),
127
+ ).div_(seq_len * self.top_k / self.num_routed_experts)
128
+ load_balance_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
129
+ else:
130
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.num_routed_experts)
131
+ ce = mask_ce.float().mean(0)
132
+ pi = scores_for_aux.mean(0)
133
+ fi = ce * self.num_routed_experts
134
+ load_balance_loss = (pi * fi).sum() * self.alpha
135
+ return router_weights, expert_indices, load_balance_loss
136
+
137
+ def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor):
138
+ router_weights, expert_indices, load_balance_loss = self.compute_router(hidden_states, labels)
139
+ batch_size, seq_len, hidden_dim = hidden_states.shape
140
+ flat_input = hidden_states.view(-1, hidden_dim)
141
+ flat_weights = router_weights.view(-1, self.top_k)
142
+ flat_indices = expert_indices.view(-1, self.top_k)
143
+ total_tokens = batch_size * seq_len
144
+ final_output = torch.zeros(total_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
145
+
146
+ for expert_id in range(self.num_experts):
147
+ expert_mask = (flat_indices == expert_id).any(dim=1)
148
+ token_ids = torch.where(expert_mask)[0]
149
+ if token_ids.numel() > 0:
150
+ expert_input = flat_input[token_ids]
151
+ expert_weight_mask = flat_indices[token_ids] == expert_id
152
+ expert_weights = flat_weights[token_ids] * expert_weight_mask.to(dtype=flat_weights.dtype)
153
+ combined_weights = expert_weights.sum(dim=1)
154
+ expert_output = self.experts[expert_id](expert_input)
155
+ weighted_output = expert_output * combined_weights.unsqueeze(1)
156
+ final_output.index_add_(0, token_ids, weighted_output)
157
+ else:
158
+ dummy_input = torch.zeros(1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
159
+ final_output[0] += self.experts[expert_id](dummy_input)[0] * 0
160
+
161
+ final_output = final_output.view(batch_size, seq_len, hidden_dim)
162
+ if self.use_shared_expert:
163
+ final_output += self.shared_expert(hidden_states)
164
+
165
+ loss = load_balance_loss
166
+ if self.training and self.routing_contrastive_lam > 0:
167
+ flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1)
168
+ cond_mask = ~(
169
+ flat_labels == 1000
170
+ ) if self.use_uncond_expert else torch.ones(batch_size * seq_len, dtype=torch.bool, device=hidden_states.device)
171
+ cond_token_embeddings = flat_input[cond_mask]
172
+ if self.use_top_k_for_routing_contrastive:
173
+ cond_cluster_assignments = expert_indices.view(batch_size * seq_len, self.top_k)[cond_mask]
174
+ else:
175
+ top1_expert_indices = expert_indices.view(batch_size * seq_len, self.top_k)[:, 0]
176
+ cond_cluster_assignments = top1_expert_indices[cond_mask]
177
+ routing_contrastive_loss = self.compute_routing_contrastive_loss(
178
+ cond_token_embeddings,
179
+ cond_cluster_assignments,
180
+ use_top_k=self.use_top_k_for_routing_contrastive,
181
+ )
182
+ routing_contrastive_loss = routing_contrastive_loss * self.routing_contrastive_lam
183
+ loss = routing_contrastive_loss if loss is None else loss + routing_contrastive_loss
184
+
185
+ return final_output, loss
186
+
187
+ def compute_routing_contrastive_loss(self, token_embeddings, cluster_assignments, use_top_k=False):
188
+ cluster_centers = self.cluster_centers
189
+ num_clusters = cluster_centers.size(0)
190
+ device = cluster_centers.device
191
+ cluster_means = []
192
+ valid_clusters = []
193
+ for cluster_id in range(num_clusters):
194
+ mask = (cluster_assignments == cluster_id).any(dim=1) if use_top_k else cluster_assignments == cluster_id
195
+ if mask.sum() > 0:
196
+ cluster_means.append(token_embeddings[mask].mean(dim=0, keepdim=True))
197
+ valid_clusters.append(cluster_id)
198
+ if len(valid_clusters) < 2:
199
+ return torch.tensor(0.0, device=device)
200
+ cluster_means = torch.cat(cluster_means, dim=0)
201
+ valid_centers = cluster_centers[valid_clusters]
202
+ centers_norm = F.normalize(valid_centers, p=2, dim=1)
203
+ means_norm = F.normalize(cluster_means, p=2, dim=1)
204
+ sim_matrix = centers_norm @ means_norm.T
205
+ logits = sim_matrix / self.routing_contrastive_temperature
206
+ labels = torch.arange(sim_matrix.size(0), device=device)
207
+ return F.cross_entropy(logits, labels)
208
+
209
+ def _init_weights(self):
210
+ nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02)
211
+
212
+
213
+ class DiTBlock(nn.Module):
214
+ def __init__(
215
+ self,
216
+ hidden_size,
217
+ num_heads,
218
+ head_dim=None,
219
+ mlp_ratio=4.0,
220
+ use_swiglu=False,
221
+ MoE_config=None,
222
+ use_moe=False,
223
+ **block_kwargs,
224
+ ):
225
+ super().__init__()
226
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
227
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
228
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
229
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
230
+ self.use_moe = use_moe
231
+ if use_moe:
232
+ self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config)
233
+ else:
234
+ if not use_swiglu:
235
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
236
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
237
+ else:
238
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
239
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
240
+
241
+ def forward(self, x, c, label):
242
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
243
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
244
+ if self.use_moe:
245
+ x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label)
246
+ if aux_loss is not None:
247
+ x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss)
248
+ return x + gate_mlp.unsqueeze(1) * x_mlp
249
+ return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
250
+
251
+
252
+ class DiT(nn.Module):
253
+ def __init__(
254
+ self,
255
+ input_size=32,
256
+ patch_size=2,
257
+ in_channels=4,
258
+ hidden_size=1152,
259
+ depth=28,
260
+ num_heads=16,
261
+ mlp_ratio=4.0,
262
+ qk_norm=False,
263
+ class_dropout_prob=0.1,
264
+ num_classes=1000,
265
+ learn_sigma=True,
266
+ use_swiglu=False,
267
+ MoE_config=None,
268
+ head_dim=None,
269
+ ):
270
+ super().__init__()
271
+ self.learn_sigma = learn_sigma
272
+ self.in_channels = in_channels
273
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
274
+ self.patch_size = patch_size
275
+ self.num_heads = num_heads
276
+ self.MoE_config = MoE_config
277
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
278
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
279
+ self.t_embedder = TimestepEmbedder(hidden_size)
280
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True)
281
+ num_patches = self.x_embedder.num_patches
282
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
283
+ self.blocks = nn.ModuleList(
284
+ [
285
+ DiTBlock(
286
+ hidden_size,
287
+ num_heads,
288
+ head_dim=head_dim,
289
+ mlp_ratio=mlp_ratio,
290
+ qk_norm=qk_norm,
291
+ use_swiglu=use_swiglu,
292
+ MoE_config=MoE_config,
293
+ use_moe=use_moe_flag[i],
294
+ )
295
+ for i in range(depth)
296
+ ]
297
+ )
298
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
299
+ self.init_MoeMLP = MoE_config.init_MoeMLP
300
+ self.initialize_weights()
301
+
302
+ def initialize_weights(self):
303
+ def _basic_init(module):
304
+ if isinstance(module, nn.Linear):
305
+ torch.nn.init.xavier_uniform_(module.weight)
306
+ if module.bias is not None:
307
+ nn.init.constant_(module.bias, 0)
308
+
309
+ self.apply(_basic_init)
310
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
311
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
312
+ w = self.x_embedder.proj.weight.data
313
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
314
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
315
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
316
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
317
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
318
+ for block in self.blocks:
319
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
320
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
321
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
322
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
323
+ nn.init.constant_(self.final_layer.linear.weight, 0)
324
+ nn.init.constant_(self.final_layer.linear.bias, 0)
325
+
326
+ def init_moe_mlp(module, std=0.006):
327
+ nn.init.normal_(module.up_proj.weight, std=std)
328
+ nn.init.normal_(module.down_proj.weight, std=std)
329
+
330
+ if self.init_MoeMLP:
331
+ for block in self.blocks:
332
+ if hasattr(block.mlp, "experts"):
333
+ for expert in block.mlp.experts:
334
+ init_moe_mlp(expert)
335
+
336
+ def unpatchify(self, x):
337
+ c = self.out_channels
338
+ p = self.x_embedder.patch_size[0]
339
+ h = w = int(x.shape[1] ** 0.5)
340
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
341
+ x = torch.einsum("nhwpqc->nchpwq", x)
342
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
343
+
344
+ def forward(self, x, timestep, context, **kwargs):
345
+ y = context
346
+ if len(x.shape) != 4:
347
+ x = x.squeeze(2)
348
+ x = self.x_embedder(x) + self.pos_embed
349
+ t = self.t_embedder(timestep)
350
+ y, labels = self.y_embedder(y, self.training)
351
+ c = t + y
352
+ for block in self.blocks:
353
+ x = block(x, c, labels)
354
+ x = self.final_layer(x, c)
355
+ return self.unpatchify(x)
ProMoE-L-256/transformer/backbone_tcdit.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .modeling_promoe_common import (
8
+ Attention,
9
+ FinalLayer,
10
+ LabelEmbedder,
11
+ Mlp,
12
+ MoeMLP_DiffMoE as MoeMLP,
13
+ PatchEmbed,
14
+ TimestepEmbedder,
15
+ get_2d_sincos_pos_embed,
16
+ modulate,
17
+ )
18
+
19
+
20
+ class MoEGate(nn.Module):
21
+ def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
22
+ super().__init__()
23
+ self.top_k = num_experts_per_tok
24
+ self.n_routed_experts = num_experts
25
+ self.scoring_func = "softmax"
26
+ self.alpha = aux_loss_alpha
27
+ self.seq_aux = False
28
+ self.norm_topk_prob = False
29
+ self.gating_dim = embed_dim
30
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
31
+ self.reset_parameters()
32
+
33
+ def reset_parameters(self):
34
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
35
+
36
+ def forward(self, hidden_states):
37
+ bsz, seq_len, h = hidden_states.shape
38
+ hidden_states = hidden_states.view(-1, h)
39
+ logits = F.linear(hidden_states, self.weight, None)
40
+ if self.scoring_func != "softmax":
41
+ raise NotImplementedError(f"Unsupported gating scoring function: {self.scoring_func}")
42
+ scores = logits.softmax(dim=-1)
43
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
44
+ if self.top_k > 1 and self.norm_topk_prob:
45
+ topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
46
+
47
+ if self.training and self.alpha > 0.0:
48
+ scores_for_aux = scores
49
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
50
+ if self.seq_aux:
51
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
52
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
53
+ ce.scatter_add_(
54
+ 1,
55
+ topk_idx_for_aux_loss,
56
+ torch.ones(bsz, seq_len * self.top_k, device=hidden_states.device),
57
+ ).div_(seq_len * self.top_k / self.n_routed_experts)
58
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
59
+ else:
60
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
61
+ ce = mask_ce.float().mean(0)
62
+ pi = scores_for_aux.mean(0)
63
+ fi = ce * self.n_routed_experts
64
+ aux_loss = (pi * fi).sum() * self.alpha
65
+ else:
66
+ aux_loss = None
67
+ return topk_idx, topk_weight, aux_loss
68
+
69
+
70
+ class AddAuxiliaryLoss(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, x, loss):
73
+ ctx.dtype = loss.dtype
74
+ ctx.required_aux_loss = loss.requires_grad
75
+ return x
76
+
77
+ @staticmethod
78
+ def backward(ctx, grad_output):
79
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
80
+ return grad_output, grad_loss
81
+
82
+
83
+ class SparseMoEBlock(nn.Module):
84
+ def __init__(
85
+ self,
86
+ experts,
87
+ hidden_dim,
88
+ mlp_ratio=4,
89
+ num_experts=16,
90
+ num_experts_per_tok=2,
91
+ pretraining_tp=2,
92
+ n_shared_experts=2,
93
+ ):
94
+ super().__init__()
95
+ self.top_k = num_experts_per_tok
96
+ self.experts = nn.ModuleList(experts)
97
+ self.gate = MoEGate(embed_dim=hidden_dim, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)
98
+ self.n_shared_experts = n_shared_experts
99
+ if self.n_shared_experts > 0:
100
+ intermediate_size = hidden_dim * self.n_shared_experts
101
+ self.shared_experts = MoeMLP(
102
+ hidden_size=hidden_dim,
103
+ intermediate_size=intermediate_size,
104
+ pretraining_tp=pretraining_tp,
105
+ )
106
+
107
+ def forward(self, hidden_states):
108
+ identity = hidden_states
109
+ orig_shape = hidden_states.shape
110
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
111
+
112
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
113
+ flat_topk_idx = topk_idx.view(-1)
114
+ if self.training:
115
+ hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
116
+ y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
117
+ for i, expert in enumerate(self.experts):
118
+ y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]).float()
119
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
120
+ y = y.view(*orig_shape)
121
+ y = AddAuxiliaryLoss.apply(y, aux_loss)
122
+ else:
123
+ y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
124
+ if self.n_shared_experts > 0:
125
+ y = y + self.shared_experts(identity)
126
+ return y
127
+
128
+ @torch.no_grad()
129
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
130
+ expert_cache = torch.zeros_like(x)
131
+ idxs = flat_expert_indices.argsort()
132
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
133
+ token_idxs = idxs // self.top_k
134
+ for i, end_idx in enumerate(tokens_per_expert):
135
+ start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
136
+ if start_idx == end_idx:
137
+ continue
138
+ expert = self.experts[i]
139
+ exp_token_idx = token_idxs[start_idx:end_idx]
140
+ expert_tokens = x[exp_token_idx]
141
+ expert_out = expert(expert_tokens)
142
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
143
+ expert_cache = expert_cache.to(expert_out.dtype)
144
+ expert_cache.scatter_reduce_(
145
+ 0,
146
+ exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
147
+ expert_out,
148
+ reduce="sum",
149
+ )
150
+ return expert_cache
151
+
152
+
153
+ class DiTBlock(nn.Module):
154
+ def __init__(
155
+ self,
156
+ hidden_size,
157
+ num_heads,
158
+ mlp_ratio=4,
159
+ pretraining_tp=2,
160
+ use_swiglu=False,
161
+ MoE_config=None,
162
+ use_moe=True,
163
+ **block_kwargs,
164
+ ):
165
+ super().__init__()
166
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
167
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
168
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
169
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
170
+ self.use_moe = use_moe
171
+ if use_moe:
172
+ if not use_swiglu:
173
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
174
+ experts = [
175
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
176
+ for _ in range(MoE_config.num_experts)
177
+ ]
178
+ else:
179
+ experts = [
180
+ MoeMLP(
181
+ hidden_size=hidden_size,
182
+ intermediate_size=mlp_hidden_dim,
183
+ pretraining_tp=pretraining_tp,
184
+ )
185
+ for _ in range(MoE_config.num_experts)
186
+ ]
187
+ self.mlp = SparseMoEBlock(
188
+ experts=experts,
189
+ hidden_dim=hidden_size,
190
+ num_experts=MoE_config.num_experts,
191
+ num_experts_per_tok=MoE_config.capacity,
192
+ n_shared_experts=MoE_config.n_shared_experts,
193
+ )
194
+ else:
195
+ if not use_swiglu:
196
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
197
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
198
+ else:
199
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
200
+
201
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
202
+
203
+ def forward(self, x, c):
204
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
205
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
206
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
207
+ return x
208
+
209
+
210
+ class DiT(nn.Module):
211
+ def __init__(
212
+ self,
213
+ input_size=32,
214
+ patch_size=2,
215
+ in_channels=4,
216
+ hidden_size=1152,
217
+ depth=28,
218
+ num_heads=16,
219
+ mlp_ratio=4,
220
+ qk_norm=False,
221
+ class_dropout_prob=0.1,
222
+ num_classes=1000,
223
+ pretraining_tp=1,
224
+ learn_sigma=True,
225
+ use_swiglu=False,
226
+ MoE_config=None,
227
+ ):
228
+ super().__init__()
229
+ self.learn_sigma = learn_sigma
230
+ self.in_channels = in_channels
231
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
232
+ self.patch_size = patch_size
233
+ self.num_heads = num_heads
234
+ self.MoE_config = MoE_config
235
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
236
+
237
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
238
+ self.t_embedder = TimestepEmbedder(hidden_size)
239
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
240
+ num_patches = self.x_embedder.num_patches
241
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
242
+
243
+ self.blocks = nn.ModuleList(
244
+ [
245
+ DiTBlock(
246
+ hidden_size,
247
+ num_heads,
248
+ mlp_ratio=mlp_ratio,
249
+ qk_norm=qk_norm,
250
+ use_swiglu=use_swiglu,
251
+ pretraining_tp=pretraining_tp,
252
+ MoE_config=MoE_config,
253
+ use_moe=use_moe_flag[i],
254
+ )
255
+ for i in range(depth)
256
+ ]
257
+ )
258
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
259
+ self.initialize_weights()
260
+
261
+ def initialize_weights(self):
262
+ def _basic_init(module):
263
+ if isinstance(module, nn.Linear):
264
+ torch.nn.init.xavier_uniform_(module.weight)
265
+ if module.bias is not None:
266
+ nn.init.constant_(module.bias, 0)
267
+
268
+ self.apply(_basic_init)
269
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
270
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
271
+ w = self.x_embedder.proj.weight.data
272
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
273
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
274
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
275
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
276
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
277
+ for block in self.blocks:
278
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
279
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
280
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
281
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
282
+ nn.init.constant_(self.final_layer.linear.weight, 0)
283
+ nn.init.constant_(self.final_layer.linear.bias, 0)
284
+
285
+ def unpatchify(self, x):
286
+ c = self.out_channels
287
+ p = self.x_embedder.patch_size[0]
288
+ h = w = int(x.shape[1] ** 0.5)
289
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
290
+ x = torch.einsum("nhwpqc->nchpwq", x)
291
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
292
+
293
+ def forward(self, x, t, context, **kwargs):
294
+ y = context
295
+ if len(x.shape) != 4:
296
+ x = x.squeeze(2)
297
+ x = self.x_embedder(x) + self.pos_embed
298
+ t = self.t_embedder(t)
299
+ y = self.y_embedder(y, self.training)
300
+ c = t + y
301
+ for block in self.blocks:
302
+ x = block(x, c)
303
+ x = self.final_layer(x, c)
304
+ return self.unpatchify(x)
ProMoE-L-256/transformer/config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoETransformer2DModel",
3
+ "architecture": "promoe_tc",
4
+ "model_config": {
5
+ "MoE_config": {
6
+ "init_MoeMLP": false,
7
+ "interleave": true,
8
+ "moe_intermediate_size": 2048,
9
+ "num_routed_experts": 12,
10
+ "shared_expert_intermediate_size": 2048,
11
+ "top_k": 1,
12
+ "use_shared_expert": true,
13
+ "use_uncond_expert": true
14
+ },
15
+ "depth": 24,
16
+ "hidden_size": 1024,
17
+ "input_size": 32,
18
+ "num_classes": 1000,
19
+ "num_heads": 16,
20
+ "patch_size": 2
21
+ }
22
+ }
ProMoE-L-256/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d6f52a00ecfdb55d68bd525851a28d802f57e218b4d4dd0de8e5136e3c16c75
3
+ size 4250844688
ProMoE-L-256/transformer/modeling_promoe_common.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import math
3
+ from dataclasses import dataclass
4
+ from itertools import repeat
5
+ from typing import Any, Dict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def _ntuple(n):
14
+ def parse(x):
15
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
16
+ return tuple(x)
17
+ return tuple(repeat(x, n))
18
+
19
+ return parse
20
+
21
+
22
+ to_2tuple = _ntuple(2)
23
+
24
+
25
+ class AttrDict(dict):
26
+ def __getattr__(self, item):
27
+ try:
28
+ return self[item]
29
+ except KeyError as error:
30
+ raise AttributeError(item) from error
31
+
32
+ def __setattr__(self, key, value):
33
+ self[key] = value
34
+
35
+ @staticmethod
36
+ def from_data(data: Any) -> Any:
37
+ if isinstance(data, dict):
38
+ return AttrDict({k: AttrDict.from_data(v) for k, v in data.items()})
39
+ if isinstance(data, list):
40
+ return [AttrDict.from_data(v) for v in data]
41
+ return data
42
+
43
+
44
+ class PatchEmbed(nn.Module):
45
+ def __init__(self, input_size: int, patch_size: int, in_channels: int, embed_dim: int, bias: bool = True):
46
+ super().__init__()
47
+ self.img_size = to_2tuple(input_size)
48
+ self.patch_size = to_2tuple(patch_size)
49
+ self.grid_size = (
50
+ self.img_size[0] // self.patch_size[0],
51
+ self.img_size[1] // self.patch_size[1],
52
+ )
53
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
54
+ self.proj = nn.Conv2d(
55
+ in_channels,
56
+ embed_dim,
57
+ kernel_size=self.patch_size,
58
+ stride=self.patch_size,
59
+ bias=bias,
60
+ )
61
+
62
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
63
+ hidden_states = self.proj(hidden_states)
64
+ return hidden_states.flatten(2).transpose(1, 2)
65
+
66
+
67
+ class Mlp(nn.Module):
68
+ def __init__(
69
+ self,
70
+ in_features,
71
+ hidden_features=None,
72
+ out_features=None,
73
+ act_layer=nn.GELU,
74
+ norm_layer=None,
75
+ bias=True,
76
+ drop=0.0,
77
+ ):
78
+ super().__init__()
79
+ out_features = out_features or in_features
80
+ hidden_features = hidden_features or in_features
81
+ bias = to_2tuple(bias)
82
+ drop_probs = to_2tuple(drop)
83
+
84
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
85
+ self.act = act_layer()
86
+ self.drop1 = nn.Dropout(drop_probs[0])
87
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
88
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
89
+ self.drop2 = nn.Dropout(drop_probs[1])
90
+
91
+ def forward(self, x):
92
+ x = self.fc1(x)
93
+ x = self.act(x)
94
+ x = self.drop1(x)
95
+ x = self.norm(x)
96
+ x = self.fc2(x)
97
+ x = self.drop2(x)
98
+ return x
99
+
100
+
101
+ class MoeMLP(nn.Module):
102
+ def __init__(self, hidden_size, intermediate_size):
103
+ super().__init__()
104
+ self.hidden_size = hidden_size
105
+ self.intermediate_size = intermediate_size
106
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size)
107
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
108
+ self.act_fn = nn.GELU(approximate="tanh")
109
+
110
+ def forward(self, x):
111
+ return self.down_proj(self.act_fn(self.up_proj(x)))
112
+
113
+
114
+ class MoeMLP_DiffMoE(nn.Module):
115
+ def __init__(self, hidden_size, intermediate_size, pretraining_tp=2):
116
+ super().__init__()
117
+ self.hidden_size = hidden_size
118
+ self.intermediate_size = intermediate_size
119
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
120
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
121
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
122
+ self.act_fn = nn.SiLU()
123
+ self.pretraining_tp = pretraining_tp
124
+
125
+ def forward(self, x):
126
+ if self.pretraining_tp > 1:
127
+ split_size = self.intermediate_size // self.pretraining_tp
128
+ gate_proj_slices = self.gate_proj.weight.split(split_size, dim=0)
129
+ up_proj_slices = self.up_proj.weight.split(split_size, dim=0)
130
+ down_proj_slices = self.down_proj.weight.split(split_size, dim=1)
131
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
132
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
133
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(split_size, dim=-1)
134
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
135
+ return sum(down_proj)
136
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
137
+
138
+
139
+ class Attention(nn.Module):
140
+ def __init__(
141
+ self,
142
+ dim: int,
143
+ num_heads: int = 8,
144
+ qkv_bias: bool = False,
145
+ qk_norm: bool = False,
146
+ attn_drop: float = 0.0,
147
+ proj_drop: float = 0.0,
148
+ head_dim=None,
149
+ norm_layer: nn.Module = nn.LayerNorm,
150
+ ):
151
+ super().__init__()
152
+ self.num_heads = num_heads
153
+ if head_dim is None:
154
+ if dim % num_heads != 0:
155
+ raise ValueError("dim must be divisible by num_heads")
156
+ self.head_dim = dim // num_heads
157
+ else:
158
+ self.head_dim = head_dim
159
+ self.scale = self.head_dim**-0.5
160
+ self.fused_attn = True
161
+ self.qkv = nn.Linear(dim, self.head_dim * self.num_heads * 3, bias=qkv_bias)
162
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
163
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
164
+ self.attn_drop = nn.Dropout(attn_drop)
165
+ self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
166
+ self.proj_drop = nn.Dropout(proj_drop)
167
+
168
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
169
+ batch_size, seq_len, _ = x.shape
170
+ qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
171
+ q, k, v = qkv.unbind(0)
172
+ q, k = self.q_norm(q), self.k_norm(k)
173
+
174
+ if self.fused_attn:
175
+ x = F.scaled_dot_product_attention(
176
+ q,
177
+ k,
178
+ v,
179
+ dropout_p=self.attn_drop.p if self.training else 0.0,
180
+ )
181
+ else:
182
+ q = q * self.scale
183
+ attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
184
+ attn = self.attn_drop(attn)
185
+ x = attn @ v
186
+
187
+ x = x.transpose(1, 2).reshape(batch_size, seq_len, -1)
188
+ x = self.proj(x)
189
+ return self.proj_drop(x)
190
+
191
+
192
+ def modulate(x, shift, scale):
193
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
194
+
195
+
196
+ class TimestepEmbedder(nn.Module):
197
+ def __init__(self, hidden_size, frequency_embedding_size=256):
198
+ super().__init__()
199
+ self.mlp = nn.Sequential(
200
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
201
+ nn.SiLU(),
202
+ nn.Linear(hidden_size, hidden_size, bias=True),
203
+ )
204
+ self.frequency_embedding_size = frequency_embedding_size
205
+
206
+ @staticmethod
207
+ def timestep_embedding(t, dim, max_period=10000):
208
+ half = dim // 2
209
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
210
+ device=t.device
211
+ )
212
+ args = t[:, None].float() * freqs[None]
213
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
214
+ if dim % 2:
215
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
216
+ return embedding
217
+
218
+ def forward(self, t):
219
+ t_freq = self.timestep_embedding(t.float(), self.frequency_embedding_size)
220
+ weight_dtype = self.mlp[0].weight.dtype
221
+ return self.mlp(t_freq.to(dtype=weight_dtype))
222
+
223
+
224
+ class LabelEmbedder(nn.Module):
225
+ def __init__(self, num_classes, hidden_size, dropout_prob, return_labels=False):
226
+ super().__init__()
227
+ use_cfg_embedding = dropout_prob > 0
228
+ self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size)
229
+ self.num_classes = num_classes
230
+ self.dropout_prob = dropout_prob
231
+ self.return_labels = return_labels
232
+
233
+ def token_drop(self, labels, force_drop_ids=None):
234
+ if force_drop_ids is None:
235
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
236
+ else:
237
+ drop_ids = force_drop_ids == 1
238
+ return torch.where(drop_ids, self.num_classes, labels)
239
+
240
+ def forward(self, labels, train, force_drop_ids=None):
241
+ if (train and self.dropout_prob > 0) or (force_drop_ids is not None):
242
+ labels = self.token_drop(labels, force_drop_ids)
243
+ embeddings = self.embedding_table(labels)
244
+ if self.return_labels:
245
+ return embeddings, labels
246
+ return embeddings
247
+
248
+
249
+ class FinalLayer(nn.Module):
250
+ def __init__(self, hidden_size, patch_size, out_channels):
251
+ super().__init__()
252
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
253
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
254
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
255
+
256
+ def forward(self, x, c):
257
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
258
+ x = modulate(self.norm_final(x), shift, scale)
259
+ return self.linear(x)
260
+
261
+
262
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
263
+ grid_h = np.arange(grid_size, dtype=np.float32)
264
+ grid_w = np.arange(grid_size, dtype=np.float32)
265
+ grid = np.meshgrid(grid_w, grid_h)
266
+ grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
267
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
268
+ if cls_token and extra_tokens > 0:
269
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
270
+ return pos_embed
271
+
272
+
273
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
274
+ if embed_dim % 2 != 0:
275
+ raise ValueError("embed_dim must be even")
276
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
277
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
278
+ return np.concatenate([emb_h, emb_w], axis=1)
279
+
280
+
281
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
282
+ if embed_dim % 2 != 0:
283
+ raise ValueError("embed_dim must be even")
284
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
285
+ omega /= embed_dim / 2.0
286
+ omega = 1.0 / 10000**omega
287
+ pos = pos.reshape(-1)
288
+ out = np.einsum("m,d->md", pos, omega)
289
+ emb_sin = np.sin(out)
290
+ emb_cos = np.cos(out)
291
+ return np.concatenate([emb_sin, emb_cos], axis=1)
ProMoE-L-256/transformer/transformer_promoe.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ try:
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.utils import BaseOutput
11
+ except Exception: # pragma: no cover
12
+ class BaseOutput(dict):
13
+ def __post_init__(self):
14
+ self.update(self.__dict__)
15
+
16
+ class _Config(dict):
17
+ def __getattr__(self, key):
18
+ try:
19
+ return self[key]
20
+ except KeyError as error:
21
+ raise AttributeError(key) from error
22
+
23
+ class ConfigMixin:
24
+ config_name = "config.json"
25
+
26
+ class ModelMixin(nn.Module):
27
+ pass
28
+
29
+ def register_to_config(init):
30
+ def wrapper(self, *args, **kwargs):
31
+ import inspect
32
+
33
+ signature = inspect.signature(init)
34
+ bound = signature.bind(self, *args, **kwargs)
35
+ bound.apply_defaults()
36
+ self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"})
37
+ init(self, *args, **kwargs)
38
+
39
+ return wrapper
40
+
41
+ from .backbone_diffmoe import DiT as DiffMoEBackbone
42
+ from .backbone_dit import DiT as DiTBackbone
43
+ from .backbone_ecdit import DiT as ECDiTBackbone
44
+ from .backbone_promoe_ec import DiT as ProMoEECBackbone
45
+ from .backbone_promoe_tc import DiT as ProMoETCBackbone
46
+ from .backbone_tcdit import DiT as TCDiTBackbone
47
+ from .modeling_promoe_common import AttrDict
48
+
49
+
50
+ @dataclass
51
+ class ProMoETransformer2DModelOutput(BaseOutput):
52
+ sample: torch.FloatTensor
53
+ loss_strategy: Optional[str] = None
54
+ layer_idx_list: Optional[Tuple[int, ...]] = None
55
+ ones_list: Optional[Tuple[torch.FloatTensor, ...]] = None
56
+ pred_c_list: Optional[Tuple[torch.FloatTensor, ...]] = None
57
+ capacity_pred_loss_weight: Optional[float] = None
58
+
59
+
60
+ _BACKBONES = {
61
+ "dit": DiTBackbone,
62
+ "tcdit": TCDiTBackbone,
63
+ "ecdit": ECDiTBackbone,
64
+ "diffmoe": DiffMoEBackbone,
65
+ "promoe_tc": ProMoETCBackbone,
66
+ "promoe_ec": ProMoEECBackbone,
67
+ }
68
+
69
+
70
+ class ProMoETransformer2DModel(ModelMixin, ConfigMixin):
71
+ config_name = "config.json"
72
+
73
+ @register_to_config
74
+ def __init__(self, architecture: str = "promoe_tc", model_config: Optional[Dict[str, Any]] = None):
75
+ super().__init__()
76
+ if architecture not in _BACKBONES:
77
+ raise ValueError(f"Unsupported architecture: {architecture}. Valid: {sorted(_BACKBONES)}")
78
+ model_config = model_config or {}
79
+ self.architecture = architecture
80
+ self.model_config = model_config
81
+ self.backbone = _BACKBONES[architecture](**self._prepare_config(model_config))
82
+ self.in_channels = getattr(self.backbone, "in_channels", model_config.get("in_channels", 4))
83
+ self.out_channels = getattr(self.backbone, "out_channels", model_config.get("in_channels", 4))
84
+
85
+ def _prepare_config(self, model_config: Dict[str, Any]) -> Dict[str, Any]:
86
+ prepared = {}
87
+ for key, value in model_config.items():
88
+ prepared[key] = AttrDict.from_data(value)
89
+ return prepared
90
+
91
+ def forward(
92
+ self,
93
+ hidden_states: torch.Tensor,
94
+ timestep: Union[torch.Tensor, float, int],
95
+ class_labels: Optional[torch.LongTensor] = None,
96
+ context: Optional[torch.LongTensor] = None,
97
+ return_dict: bool = True,
98
+ **kwargs,
99
+ ) -> Union[ProMoETransformer2DModelOutput, Tuple[torch.Tensor, ...]]:
100
+ labels = class_labels if class_labels is not None else context
101
+ if labels is None:
102
+ raise ValueError("Either `class_labels` or `context` must be provided.")
103
+
104
+ if not torch.is_tensor(timestep):
105
+ timestep = torch.tensor([timestep], device=hidden_states.device, dtype=hidden_states.dtype)
106
+ timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype).flatten()
107
+ if timestep.numel() == 1:
108
+ timestep = timestep.repeat(labels.shape[0])
109
+
110
+ sample = self.backbone(hidden_states, timestep, labels, **kwargs)
111
+ if isinstance(sample, tuple):
112
+ if len(sample) == 6 and sample[1] == "Capacity_Pred":
113
+ output = ProMoETransformer2DModelOutput(
114
+ sample=sample[0],
115
+ loss_strategy=sample[1],
116
+ layer_idx_list=tuple(sample[2]),
117
+ ones_list=tuple(sample[3]),
118
+ pred_c_list=tuple(sample[4]),
119
+ capacity_pred_loss_weight=float(sample[5]),
120
+ )
121
+ else:
122
+ output = ProMoETransformer2DModelOutput(sample=sample[0])
123
+ else:
124
+ output = ProMoETransformer2DModelOutput(sample=sample)
125
+
126
+ if not return_dict:
127
+ if output.loss_strategy is None:
128
+ return (output.sample,)
129
+ return (
130
+ output.sample,
131
+ output.loss_strategy,
132
+ output.layer_idx_list,
133
+ output.ones_list,
134
+ output.pred_c_list,
135
+ output.capacity_pred_loss_weight,
136
+ )
137
+ return output
ProMoE-L-256/vae/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.4.2",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "in_channels": 3,
18
+ "latent_channels": 4,
19
+ "layers_per_block": 2,
20
+ "norm_num_groups": 32,
21
+ "out_channels": 3,
22
+ "sample_size": 256,
23
+ "up_block_types": [
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D"
28
+ ]
29
+ }
ProMoE-L-256/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
3
+ size 334643268
ProMoE-XL-256/model_index.json ADDED
@@ -0,0 +1,1021 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "ProMoEPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "id2label": {
8
+ "0": "tench, Tinca tinca",
9
+ "1": "goldfish, Carassius auratus",
10
+ "10": "brambling, Fringilla montifringilla",
11
+ "100": "black swan, Cygnus atratus",
12
+ "101": "tusker",
13
+ "102": "echidna, spiny anteater, anteater",
14
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
15
+ "104": "wallaby, brush kangaroo",
16
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
17
+ "106": "wombat",
18
+ "107": "jellyfish",
19
+ "108": "sea anemone, anemone",
20
+ "109": "brain coral",
21
+ "11": "goldfinch, Carduelis carduelis",
22
+ "110": "flatworm, platyhelminth",
23
+ "111": "nematode, nematode worm, roundworm",
24
+ "112": "conch",
25
+ "113": "snail",
26
+ "114": "slug",
27
+ "115": "sea slug, nudibranch",
28
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
29
+ "117": "chambered nautilus, pearly nautilus, nautilus",
30
+ "118": "Dungeness crab, Cancer magister",
31
+ "119": "rock crab, Cancer irroratus",
32
+ "12": "house finch, linnet, Carpodacus mexicanus",
33
+ "120": "fiddler crab",
34
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
35
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
36
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
37
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
38
+ "125": "hermit crab",
39
+ "126": "isopod",
40
+ "127": "white stork, Ciconia ciconia",
41
+ "128": "black stork, Ciconia nigra",
42
+ "129": "spoonbill",
43
+ "13": "junco, snowbird",
44
+ "130": "flamingo",
45
+ "131": "little blue heron, Egretta caerulea",
46
+ "132": "American egret, great white heron, Egretta albus",
47
+ "133": "bittern",
48
+ "134": "crane",
49
+ "135": "limpkin, Aramus pictus",
50
+ "136": "European gallinule, Porphyrio porphyrio",
51
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
52
+ "138": "bustard",
53
+ "139": "ruddy turnstone, Arenaria interpres",
54
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
55
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
56
+ "141": "redshank, Tringa totanus",
57
+ "142": "dowitcher",
58
+ "143": "oystercatcher, oyster catcher",
59
+ "144": "pelican",
60
+ "145": "king penguin, Aptenodytes patagonica",
61
+ "146": "albatross, mollymawk",
62
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
63
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
64
+ "149": "dugong, Dugong dugon",
65
+ "15": "robin, American robin, Turdus migratorius",
66
+ "150": "sea lion",
67
+ "151": "Chihuahua",
68
+ "152": "Japanese spaniel",
69
+ "153": "Maltese dog, Maltese terrier, Maltese",
70
+ "154": "Pekinese, Pekingese, Peke",
71
+ "155": "Shih-Tzu",
72
+ "156": "Blenheim spaniel",
73
+ "157": "papillon",
74
+ "158": "toy terrier",
75
+ "159": "Rhodesian ridgeback",
76
+ "16": "bulbul",
77
+ "160": "Afghan hound, Afghan",
78
+ "161": "basset, basset hound",
79
+ "162": "beagle",
80
+ "163": "bloodhound, sleuthhound",
81
+ "164": "bluetick",
82
+ "165": "black-and-tan coonhound",
83
+ "166": "Walker hound, Walker foxhound",
84
+ "167": "English foxhound",
85
+ "168": "redbone",
86
+ "169": "borzoi, Russian wolfhound",
87
+ "17": "jay",
88
+ "170": "Irish wolfhound",
89
+ "171": "Italian greyhound",
90
+ "172": "whippet",
91
+ "173": "Ibizan hound, Ibizan Podenco",
92
+ "174": "Norwegian elkhound, elkhound",
93
+ "175": "otterhound, otter hound",
94
+ "176": "Saluki, gazelle hound",
95
+ "177": "Scottish deerhound, deerhound",
96
+ "178": "Weimaraner",
97
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
98
+ "18": "magpie",
99
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
100
+ "181": "Bedlington terrier",
101
+ "182": "Border terrier",
102
+ "183": "Kerry blue terrier",
103
+ "184": "Irish terrier",
104
+ "185": "Norfolk terrier",
105
+ "186": "Norwich terrier",
106
+ "187": "Yorkshire terrier",
107
+ "188": "wire-haired fox terrier",
108
+ "189": "Lakeland terrier",
109
+ "19": "chickadee",
110
+ "190": "Sealyham terrier, Sealyham",
111
+ "191": "Airedale, Airedale terrier",
112
+ "192": "cairn, cairn terrier",
113
+ "193": "Australian terrier",
114
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
115
+ "195": "Boston bull, Boston terrier",
116
+ "196": "miniature schnauzer",
117
+ "197": "giant schnauzer",
118
+ "198": "standard schnauzer",
119
+ "199": "Scotch terrier, Scottish terrier, Scottie",
120
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
121
+ "20": "water ouzel, dipper",
122
+ "200": "Tibetan terrier, chrysanthemum dog",
123
+ "201": "silky terrier, Sydney silky",
124
+ "202": "soft-coated wheaten terrier",
125
+ "203": "West Highland white terrier",
126
+ "204": "Lhasa, Lhasa apso",
127
+ "205": "flat-coated retriever",
128
+ "206": "curly-coated retriever",
129
+ "207": "golden retriever",
130
+ "208": "Labrador retriever",
131
+ "209": "Chesapeake Bay retriever",
132
+ "21": "kite",
133
+ "210": "German short-haired pointer",
134
+ "211": "vizsla, Hungarian pointer",
135
+ "212": "English setter",
136
+ "213": "Irish setter, red setter",
137
+ "214": "Gordon setter",
138
+ "215": "Brittany spaniel",
139
+ "216": "clumber, clumber spaniel",
140
+ "217": "English springer, English springer spaniel",
141
+ "218": "Welsh springer spaniel",
142
+ "219": "cocker spaniel, English cocker spaniel, cocker",
143
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
144
+ "220": "Sussex spaniel",
145
+ "221": "Irish water spaniel",
146
+ "222": "kuvasz",
147
+ "223": "schipperke",
148
+ "224": "groenendael",
149
+ "225": "malinois",
150
+ "226": "briard",
151
+ "227": "kelpie",
152
+ "228": "komondor",
153
+ "229": "Old English sheepdog, bobtail",
154
+ "23": "vulture",
155
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
156
+ "231": "collie",
157
+ "232": "Border collie",
158
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
159
+ "234": "Rottweiler",
160
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
161
+ "236": "Doberman, Doberman pinscher",
162
+ "237": "miniature pinscher",
163
+ "238": "Greater Swiss Mountain dog",
164
+ "239": "Bernese mountain dog",
165
+ "24": "great grey owl, great gray owl, Strix nebulosa",
166
+ "240": "Appenzeller",
167
+ "241": "EntleBucher",
168
+ "242": "boxer",
169
+ "243": "bull mastiff",
170
+ "244": "Tibetan mastiff",
171
+ "245": "French bulldog",
172
+ "246": "Great Dane",
173
+ "247": "Saint Bernard, St Bernard",
174
+ "248": "Eskimo dog, husky",
175
+ "249": "malamute, malemute, Alaskan malamute",
176
+ "25": "European fire salamander, Salamandra salamandra",
177
+ "250": "Siberian husky",
178
+ "251": "dalmatian, coach dog, carriage dog",
179
+ "252": "affenpinscher, monkey pinscher, monkey dog",
180
+ "253": "basenji",
181
+ "254": "pug, pug-dog",
182
+ "255": "Leonberg",
183
+ "256": "Newfoundland, Newfoundland dog",
184
+ "257": "Great Pyrenees",
185
+ "258": "Samoyed, Samoyede",
186
+ "259": "Pomeranian",
187
+ "26": "common newt, Triturus vulgaris",
188
+ "260": "chow, chow chow",
189
+ "261": "keeshond",
190
+ "262": "Brabancon griffon",
191
+ "263": "Pembroke, Pembroke Welsh corgi",
192
+ "264": "Cardigan, Cardigan Welsh corgi",
193
+ "265": "toy poodle",
194
+ "266": "miniature poodle",
195
+ "267": "standard poodle",
196
+ "268": "Mexican hairless",
197
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
198
+ "27": "eft",
199
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
200
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
201
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
202
+ "273": "dingo, warrigal, warragal, Canis dingo",
203
+ "274": "dhole, Cuon alpinus",
204
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
205
+ "276": "hyena, hyaena",
206
+ "277": "red fox, Vulpes vulpes",
207
+ "278": "kit fox, Vulpes macrotis",
208
+ "279": "Arctic fox, white fox, Alopex lagopus",
209
+ "28": "spotted salamander, Ambystoma maculatum",
210
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
211
+ "281": "tabby, tabby cat",
212
+ "282": "tiger cat",
213
+ "283": "Persian cat",
214
+ "284": "Siamese cat, Siamese",
215
+ "285": "Egyptian cat",
216
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
217
+ "287": "lynx, catamount",
218
+ "288": "leopard, Panthera pardus",
219
+ "289": "snow leopard, ounce, Panthera uncia",
220
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
221
+ "290": "jaguar, panther, Panthera onca, Felis onca",
222
+ "291": "lion, king of beasts, Panthera leo",
223
+ "292": "tiger, Panthera tigris",
224
+ "293": "cheetah, chetah, Acinonyx jubatus",
225
+ "294": "brown bear, bruin, Ursus arctos",
226
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
227
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
228
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
229
+ "298": "mongoose",
230
+ "299": "meerkat, mierkat",
231
+ "3": "tiger shark, Galeocerdo cuvieri",
232
+ "30": "bullfrog, Rana catesbeiana",
233
+ "300": "tiger beetle",
234
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
235
+ "302": "ground beetle, carabid beetle",
236
+ "303": "long-horned beetle, longicorn, longicorn beetle",
237
+ "304": "leaf beetle, chrysomelid",
238
+ "305": "dung beetle",
239
+ "306": "rhinoceros beetle",
240
+ "307": "weevil",
241
+ "308": "fly",
242
+ "309": "bee",
243
+ "31": "tree frog, tree-frog",
244
+ "310": "ant, emmet, pismire",
245
+ "311": "grasshopper, hopper",
246
+ "312": "cricket",
247
+ "313": "walking stick, walkingstick, stick insect",
248
+ "314": "cockroach, roach",
249
+ "315": "mantis, mantid",
250
+ "316": "cicada, cicala",
251
+ "317": "leafhopper",
252
+ "318": "lacewing, lacewing fly",
253
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
254
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
255
+ "320": "damselfly",
256
+ "321": "admiral",
257
+ "322": "ringlet, ringlet butterfly",
258
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
259
+ "324": "cabbage butterfly",
260
+ "325": "sulphur butterfly, sulfur butterfly",
261
+ "326": "lycaenid, lycaenid butterfly",
262
+ "327": "starfish, sea star",
263
+ "328": "sea urchin",
264
+ "329": "sea cucumber, holothurian",
265
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
266
+ "330": "wood rabbit, cottontail, cottontail rabbit",
267
+ "331": "hare",
268
+ "332": "Angora, Angora rabbit",
269
+ "333": "hamster",
270
+ "334": "porcupine, hedgehog",
271
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
272
+ "336": "marmot",
273
+ "337": "beaver",
274
+ "338": "guinea pig, Cavia cobaya",
275
+ "339": "sorrel",
276
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
277
+ "340": "zebra",
278
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
279
+ "342": "wild boar, boar, Sus scrofa",
280
+ "343": "warthog",
281
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
282
+ "345": "ox",
283
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
284
+ "347": "bison",
285
+ "348": "ram, tup",
286
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
287
+ "35": "mud turtle",
288
+ "350": "ibex, Capra ibex",
289
+ "351": "hartebeest",
290
+ "352": "impala, Aepyceros melampus",
291
+ "353": "gazelle",
292
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
293
+ "355": "llama",
294
+ "356": "weasel",
295
+ "357": "mink",
296
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
297
+ "359": "black-footed ferret, ferret, Mustela nigripes",
298
+ "36": "terrapin",
299
+ "360": "otter",
300
+ "361": "skunk, polecat, wood pussy",
301
+ "362": "badger",
302
+ "363": "armadillo",
303
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
304
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
305
+ "366": "gorilla, Gorilla gorilla",
306
+ "367": "chimpanzee, chimp, Pan troglodytes",
307
+ "368": "gibbon, Hylobates lar",
308
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
309
+ "37": "box turtle, box tortoise",
310
+ "370": "guenon, guenon monkey",
311
+ "371": "patas, hussar monkey, Erythrocebus patas",
312
+ "372": "baboon",
313
+ "373": "macaque",
314
+ "374": "langur",
315
+ "375": "colobus, colobus monkey",
316
+ "376": "proboscis monkey, Nasalis larvatus",
317
+ "377": "marmoset",
318
+ "378": "capuchin, ringtail, Cebus capucinus",
319
+ "379": "howler monkey, howler",
320
+ "38": "banded gecko",
321
+ "380": "titi, titi monkey",
322
+ "381": "spider monkey, Ateles geoffroyi",
323
+ "382": "squirrel monkey, Saimiri sciureus",
324
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
325
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
326
+ "385": "Indian elephant, Elephas maximus",
327
+ "386": "African elephant, Loxodonta africana",
328
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
329
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
330
+ "389": "barracouta, snoek",
331
+ "39": "common iguana, iguana, Iguana iguana",
332
+ "390": "eel",
333
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
334
+ "392": "rock beauty, Holocanthus tricolor",
335
+ "393": "anemone fish",
336
+ "394": "sturgeon",
337
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
338
+ "396": "lionfish",
339
+ "397": "puffer, pufferfish, blowfish, globefish",
340
+ "398": "abacus",
341
+ "399": "abaya",
342
+ "4": "hammerhead, hammerhead shark",
343
+ "40": "American chameleon, anole, Anolis carolinensis",
344
+ "400": "academic gown, academic robe, judge robe",
345
+ "401": "accordion, piano accordion, squeeze box",
346
+ "402": "acoustic guitar",
347
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
348
+ "404": "airliner",
349
+ "405": "airship, dirigible",
350
+ "406": "altar",
351
+ "407": "ambulance",
352
+ "408": "amphibian, amphibious vehicle",
353
+ "409": "analog clock",
354
+ "41": "whiptail, whiptail lizard",
355
+ "410": "apiary, bee house",
356
+ "411": "apron",
357
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
358
+ "413": "assault rifle, assault gun",
359
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
360
+ "415": "bakery, bakeshop, bakehouse",
361
+ "416": "balance beam, beam",
362
+ "417": "balloon",
363
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
364
+ "419": "Band Aid",
365
+ "42": "agama",
366
+ "420": "banjo",
367
+ "421": "bannister, banister, balustrade, balusters, handrail",
368
+ "422": "barbell",
369
+ "423": "barber chair",
370
+ "424": "barbershop",
371
+ "425": "barn",
372
+ "426": "barometer",
373
+ "427": "barrel, cask",
374
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
375
+ "429": "baseball",
376
+ "43": "frilled lizard, Chlamydosaurus kingi",
377
+ "430": "basketball",
378
+ "431": "bassinet",
379
+ "432": "bassoon",
380
+ "433": "bathing cap, swimming cap",
381
+ "434": "bath towel",
382
+ "435": "bathtub, bathing tub, bath, tub",
383
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
384
+ "437": "beacon, lighthouse, beacon light, pharos",
385
+ "438": "beaker",
386
+ "439": "bearskin, busby, shako",
387
+ "44": "alligator lizard",
388
+ "440": "beer bottle",
389
+ "441": "beer glass",
390
+ "442": "bell cote, bell cot",
391
+ "443": "bib",
392
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
393
+ "445": "bikini, two-piece",
394
+ "446": "binder, ring-binder",
395
+ "447": "binoculars, field glasses, opera glasses",
396
+ "448": "birdhouse",
397
+ "449": "boathouse",
398
+ "45": "Gila monster, Heloderma suspectum",
399
+ "450": "bobsled, bobsleigh, bob",
400
+ "451": "bolo tie, bolo, bola tie, bola",
401
+ "452": "bonnet, poke bonnet",
402
+ "453": "bookcase",
403
+ "454": "bookshop, bookstore, bookstall",
404
+ "455": "bottlecap",
405
+ "456": "bow",
406
+ "457": "bow tie, bow-tie, bowtie",
407
+ "458": "brass, memorial tablet, plaque",
408
+ "459": "brassiere, bra, bandeau",
409
+ "46": "green lizard, Lacerta viridis",
410
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
411
+ "461": "breastplate, aegis, egis",
412
+ "462": "broom",
413
+ "463": "bucket, pail",
414
+ "464": "buckle",
415
+ "465": "bulletproof vest",
416
+ "466": "bullet train, bullet",
417
+ "467": "butcher shop, meat market",
418
+ "468": "cab, hack, taxi, taxicab",
419
+ "469": "caldron, cauldron",
420
+ "47": "African chameleon, Chamaeleo chamaeleon",
421
+ "470": "candle, taper, wax light",
422
+ "471": "cannon",
423
+ "472": "canoe",
424
+ "473": "can opener, tin opener",
425
+ "474": "cardigan",
426
+ "475": "car mirror",
427
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
428
+ "477": "carpenters kit, tool kit",
429
+ "478": "carton",
430
+ "479": "car wheel",
431
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
432
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
433
+ "481": "cassette",
434
+ "482": "cassette player",
435
+ "483": "castle",
436
+ "484": "catamaran",
437
+ "485": "CD player",
438
+ "486": "cello, violoncello",
439
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
440
+ "488": "chain",
441
+ "489": "chainlink fence",
442
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
443
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
444
+ "491": "chain saw, chainsaw",
445
+ "492": "chest",
446
+ "493": "chiffonier, commode",
447
+ "494": "chime, bell, gong",
448
+ "495": "china cabinet, china closet",
449
+ "496": "Christmas stocking",
450
+ "497": "church, church building",
451
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
452
+ "499": "cleaver, meat cleaver, chopper",
453
+ "5": "electric ray, crampfish, numbfish, torpedo",
454
+ "50": "American alligator, Alligator mississipiensis",
455
+ "500": "cliff dwelling",
456
+ "501": "cloak",
457
+ "502": "clog, geta, patten, sabot",
458
+ "503": "cocktail shaker",
459
+ "504": "coffee mug",
460
+ "505": "coffeepot",
461
+ "506": "coil, spiral, volute, whorl, helix",
462
+ "507": "combination lock",
463
+ "508": "computer keyboard, keypad",
464
+ "509": "confectionery, confectionary, candy store",
465
+ "51": "triceratops",
466
+ "510": "container ship, containership, container vessel",
467
+ "511": "convertible",
468
+ "512": "corkscrew, bottle screw",
469
+ "513": "cornet, horn, trumpet, trump",
470
+ "514": "cowboy boot",
471
+ "515": "cowboy hat, ten-gallon hat",
472
+ "516": "cradle",
473
+ "517": "crane",
474
+ "518": "crash helmet",
475
+ "519": "crate",
476
+ "52": "thunder snake, worm snake, Carphophis amoenus",
477
+ "520": "crib, cot",
478
+ "521": "Crock Pot",
479
+ "522": "croquet ball",
480
+ "523": "crutch",
481
+ "524": "cuirass",
482
+ "525": "dam, dike, dyke",
483
+ "526": "desk",
484
+ "527": "desktop computer",
485
+ "528": "dial telephone, dial phone",
486
+ "529": "diaper, nappy, napkin",
487
+ "53": "ringneck snake, ring-necked snake, ring snake",
488
+ "530": "digital clock",
489
+ "531": "digital watch",
490
+ "532": "dining table, board",
491
+ "533": "dishrag, dishcloth",
492
+ "534": "dishwasher, dish washer, dishwashing machine",
493
+ "535": "disk brake, disc brake",
494
+ "536": "dock, dockage, docking facility",
495
+ "537": "dogsled, dog sled, dog sleigh",
496
+ "538": "dome",
497
+ "539": "doormat, welcome mat",
498
+ "54": "hognose snake, puff adder, sand viper",
499
+ "540": "drilling platform, offshore rig",
500
+ "541": "drum, membranophone, tympan",
501
+ "542": "drumstick",
502
+ "543": "dumbbell",
503
+ "544": "Dutch oven",
504
+ "545": "electric fan, blower",
505
+ "546": "electric guitar",
506
+ "547": "electric locomotive",
507
+ "548": "entertainment center",
508
+ "549": "envelope",
509
+ "55": "green snake, grass snake",
510
+ "550": "espresso maker",
511
+ "551": "face powder",
512
+ "552": "feather boa, boa",
513
+ "553": "file, file cabinet, filing cabinet",
514
+ "554": "fireboat",
515
+ "555": "fire engine, fire truck",
516
+ "556": "fire screen, fireguard",
517
+ "557": "flagpole, flagstaff",
518
+ "558": "flute, transverse flute",
519
+ "559": "folding chair",
520
+ "56": "king snake, kingsnake",
521
+ "560": "football helmet",
522
+ "561": "forklift",
523
+ "562": "fountain",
524
+ "563": "fountain pen",
525
+ "564": "four-poster",
526
+ "565": "freight car",
527
+ "566": "French horn, horn",
528
+ "567": "frying pan, frypan, skillet",
529
+ "568": "fur coat",
530
+ "569": "garbage truck, dustcart",
531
+ "57": "garter snake, grass snake",
532
+ "570": "gasmask, respirator, gas helmet",
533
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
534
+ "572": "goblet",
535
+ "573": "go-kart",
536
+ "574": "golf ball",
537
+ "575": "golfcart, golf cart",
538
+ "576": "gondola",
539
+ "577": "gong, tam-tam",
540
+ "578": "gown",
541
+ "579": "grand piano, grand",
542
+ "58": "water snake",
543
+ "580": "greenhouse, nursery, glasshouse",
544
+ "581": "grille, radiator grille",
545
+ "582": "grocery store, grocery, food market, market",
546
+ "583": "guillotine",
547
+ "584": "hair slide",
548
+ "585": "hair spray",
549
+ "586": "half track",
550
+ "587": "hammer",
551
+ "588": "hamper",
552
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
553
+ "59": "vine snake",
554
+ "590": "hand-held computer, hand-held microcomputer",
555
+ "591": "handkerchief, hankie, hanky, hankey",
556
+ "592": "hard disc, hard disk, fixed disk",
557
+ "593": "harmonica, mouth organ, harp, mouth harp",
558
+ "594": "harp",
559
+ "595": "harvester, reaper",
560
+ "596": "hatchet",
561
+ "597": "holster",
562
+ "598": "home theater, home theatre",
563
+ "599": "honeycomb",
564
+ "6": "stingray",
565
+ "60": "night snake, Hypsiglena torquata",
566
+ "600": "hook, claw",
567
+ "601": "hoopskirt, crinoline",
568
+ "602": "horizontal bar, high bar",
569
+ "603": "horse cart, horse-cart",
570
+ "604": "hourglass",
571
+ "605": "iPod",
572
+ "606": "iron, smoothing iron",
573
+ "607": "jack-o-lantern",
574
+ "608": "jean, blue jean, denim",
575
+ "609": "jeep, landrover",
576
+ "61": "boa constrictor, Constrictor constrictor",
577
+ "610": "jersey, T-shirt, tee shirt",
578
+ "611": "jigsaw puzzle",
579
+ "612": "jinrikisha, ricksha, rickshaw",
580
+ "613": "joystick",
581
+ "614": "kimono",
582
+ "615": "knee pad",
583
+ "616": "knot",
584
+ "617": "lab coat, laboratory coat",
585
+ "618": "ladle",
586
+ "619": "lampshade, lamp shade",
587
+ "62": "rock python, rock snake, Python sebae",
588
+ "620": "laptop, laptop computer",
589
+ "621": "lawn mower, mower",
590
+ "622": "lens cap, lens cover",
591
+ "623": "letter opener, paper knife, paperknife",
592
+ "624": "library",
593
+ "625": "lifeboat",
594
+ "626": "lighter, light, igniter, ignitor",
595
+ "627": "limousine, limo",
596
+ "628": "liner, ocean liner",
597
+ "629": "lipstick, lip rouge",
598
+ "63": "Indian cobra, Naja naja",
599
+ "630": "Loafer",
600
+ "631": "lotion",
601
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
602
+ "633": "loupe, jewelers loupe",
603
+ "634": "lumbermill, sawmill",
604
+ "635": "magnetic compass",
605
+ "636": "mailbag, postbag",
606
+ "637": "mailbox, letter box",
607
+ "638": "maillot",
608
+ "639": "maillot, tank suit",
609
+ "64": "green mamba",
610
+ "640": "manhole cover",
611
+ "641": "maraca",
612
+ "642": "marimba, xylophone",
613
+ "643": "mask",
614
+ "644": "matchstick",
615
+ "645": "maypole",
616
+ "646": "maze, labyrinth",
617
+ "647": "measuring cup",
618
+ "648": "medicine chest, medicine cabinet",
619
+ "649": "megalith, megalithic structure",
620
+ "65": "sea snake",
621
+ "650": "microphone, mike",
622
+ "651": "microwave, microwave oven",
623
+ "652": "military uniform",
624
+ "653": "milk can",
625
+ "654": "minibus",
626
+ "655": "miniskirt, mini",
627
+ "656": "minivan",
628
+ "657": "missile",
629
+ "658": "mitten",
630
+ "659": "mixing bowl",
631
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
632
+ "660": "mobile home, manufactured home",
633
+ "661": "Model T",
634
+ "662": "modem",
635
+ "663": "monastery",
636
+ "664": "monitor",
637
+ "665": "moped",
638
+ "666": "mortar",
639
+ "667": "mortarboard",
640
+ "668": "mosque",
641
+ "669": "mosquito net",
642
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
643
+ "670": "motor scooter, scooter",
644
+ "671": "mountain bike, all-terrain bike, off-roader",
645
+ "672": "mountain tent",
646
+ "673": "mouse, computer mouse",
647
+ "674": "mousetrap",
648
+ "675": "moving van",
649
+ "676": "muzzle",
650
+ "677": "nail",
651
+ "678": "neck brace",
652
+ "679": "necklace",
653
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
654
+ "680": "nipple",
655
+ "681": "notebook, notebook computer",
656
+ "682": "obelisk",
657
+ "683": "oboe, hautboy, hautbois",
658
+ "684": "ocarina, sweet potato",
659
+ "685": "odometer, hodometer, mileometer, milometer",
660
+ "686": "oil filter",
661
+ "687": "organ, pipe organ",
662
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
663
+ "689": "overskirt",
664
+ "69": "trilobite",
665
+ "690": "oxcart",
666
+ "691": "oxygen mask",
667
+ "692": "packet",
668
+ "693": "paddle, boat paddle",
669
+ "694": "paddlewheel, paddle wheel",
670
+ "695": "padlock",
671
+ "696": "paintbrush",
672
+ "697": "pajama, pyjama, pjs, jammies",
673
+ "698": "palace",
674
+ "699": "panpipe, pandean pipe, syrinx",
675
+ "7": "cock",
676
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
677
+ "700": "paper towel",
678
+ "701": "parachute, chute",
679
+ "702": "parallel bars, bars",
680
+ "703": "park bench",
681
+ "704": "parking meter",
682
+ "705": "passenger car, coach, carriage",
683
+ "706": "patio, terrace",
684
+ "707": "pay-phone, pay-station",
685
+ "708": "pedestal, plinth, footstall",
686
+ "709": "pencil box, pencil case",
687
+ "71": "scorpion",
688
+ "710": "pencil sharpener",
689
+ "711": "perfume, essence",
690
+ "712": "Petri dish",
691
+ "713": "photocopier",
692
+ "714": "pick, plectrum, plectron",
693
+ "715": "pickelhaube",
694
+ "716": "picket fence, paling",
695
+ "717": "pickup, pickup truck",
696
+ "718": "pier",
697
+ "719": "piggy bank, penny bank",
698
+ "72": "black and gold garden spider, Argiope aurantia",
699
+ "720": "pill bottle",
700
+ "721": "pillow",
701
+ "722": "ping-pong ball",
702
+ "723": "pinwheel",
703
+ "724": "pirate, pirate ship",
704
+ "725": "pitcher, ewer",
705
+ "726": "plane, carpenters plane, woodworking plane",
706
+ "727": "planetarium",
707
+ "728": "plastic bag",
708
+ "729": "plate rack",
709
+ "73": "barn spider, Araneus cavaticus",
710
+ "730": "plow, plough",
711
+ "731": "plunger, plumbers helper",
712
+ "732": "Polaroid camera, Polaroid Land camera",
713
+ "733": "pole",
714
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
715
+ "735": "poncho",
716
+ "736": "pool table, billiard table, snooker table",
717
+ "737": "pop bottle, soda bottle",
718
+ "738": "pot, flowerpot",
719
+ "739": "potters wheel",
720
+ "74": "garden spider, Aranea diademata",
721
+ "740": "power drill",
722
+ "741": "prayer rug, prayer mat",
723
+ "742": "printer",
724
+ "743": "prison, prison house",
725
+ "744": "projectile, missile",
726
+ "745": "projector",
727
+ "746": "puck, hockey puck",
728
+ "747": "punching bag, punch bag, punching ball, punchball",
729
+ "748": "purse",
730
+ "749": "quill, quill pen",
731
+ "75": "black widow, Latrodectus mactans",
732
+ "750": "quilt, comforter, comfort, puff",
733
+ "751": "racer, race car, racing car",
734
+ "752": "racket, racquet",
735
+ "753": "radiator",
736
+ "754": "radio, wireless",
737
+ "755": "radio telescope, radio reflector",
738
+ "756": "rain barrel",
739
+ "757": "recreational vehicle, RV, R.V.",
740
+ "758": "reel",
741
+ "759": "reflex camera",
742
+ "76": "tarantula",
743
+ "760": "refrigerator, icebox",
744
+ "761": "remote control, remote",
745
+ "762": "restaurant, eating house, eating place, eatery",
746
+ "763": "revolver, six-gun, six-shooter",
747
+ "764": "rifle",
748
+ "765": "rocking chair, rocker",
749
+ "766": "rotisserie",
750
+ "767": "rubber eraser, rubber, pencil eraser",
751
+ "768": "rugby ball",
752
+ "769": "rule, ruler",
753
+ "77": "wolf spider, hunting spider",
754
+ "770": "running shoe",
755
+ "771": "safe",
756
+ "772": "safety pin",
757
+ "773": "saltshaker, salt shaker",
758
+ "774": "sandal",
759
+ "775": "sarong",
760
+ "776": "sax, saxophone",
761
+ "777": "scabbard",
762
+ "778": "scale, weighing machine",
763
+ "779": "school bus",
764
+ "78": "tick",
765
+ "780": "schooner",
766
+ "781": "scoreboard",
767
+ "782": "screen, CRT screen",
768
+ "783": "screw",
769
+ "784": "screwdriver",
770
+ "785": "seat belt, seatbelt",
771
+ "786": "sewing machine",
772
+ "787": "shield, buckler",
773
+ "788": "shoe shop, shoe-shop, shoe store",
774
+ "789": "shoji",
775
+ "79": "centipede",
776
+ "790": "shopping basket",
777
+ "791": "shopping cart",
778
+ "792": "shovel",
779
+ "793": "shower cap",
780
+ "794": "shower curtain",
781
+ "795": "ski",
782
+ "796": "ski mask",
783
+ "797": "sleeping bag",
784
+ "798": "slide rule, slipstick",
785
+ "799": "sliding door",
786
+ "8": "hen",
787
+ "80": "black grouse",
788
+ "800": "slot, one-armed bandit",
789
+ "801": "snorkel",
790
+ "802": "snowmobile",
791
+ "803": "snowplow, snowplough",
792
+ "804": "soap dispenser",
793
+ "805": "soccer ball",
794
+ "806": "sock",
795
+ "807": "solar dish, solar collector, solar furnace",
796
+ "808": "sombrero",
797
+ "809": "soup bowl",
798
+ "81": "ptarmigan",
799
+ "810": "space bar",
800
+ "811": "space heater",
801
+ "812": "space shuttle",
802
+ "813": "spatula",
803
+ "814": "speedboat",
804
+ "815": "spider web, spiders web",
805
+ "816": "spindle",
806
+ "817": "sports car, sport car",
807
+ "818": "spotlight, spot",
808
+ "819": "stage",
809
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
810
+ "820": "steam locomotive",
811
+ "821": "steel arch bridge",
812
+ "822": "steel drum",
813
+ "823": "stethoscope",
814
+ "824": "stole",
815
+ "825": "stone wall",
816
+ "826": "stopwatch, stop watch",
817
+ "827": "stove",
818
+ "828": "strainer",
819
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
820
+ "83": "prairie chicken, prairie grouse, prairie fowl",
821
+ "830": "stretcher",
822
+ "831": "studio couch, day bed",
823
+ "832": "stupa, tope",
824
+ "833": "submarine, pigboat, sub, U-boat",
825
+ "834": "suit, suit of clothes",
826
+ "835": "sundial",
827
+ "836": "sunglass",
828
+ "837": "sunglasses, dark glasses, shades",
829
+ "838": "sunscreen, sunblock, sun blocker",
830
+ "839": "suspension bridge",
831
+ "84": "peacock",
832
+ "840": "swab, swob, mop",
833
+ "841": "sweatshirt",
834
+ "842": "swimming trunks, bathing trunks",
835
+ "843": "swing",
836
+ "844": "switch, electric switch, electrical switch",
837
+ "845": "syringe",
838
+ "846": "table lamp",
839
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
840
+ "848": "tape player",
841
+ "849": "teapot",
842
+ "85": "quail",
843
+ "850": "teddy, teddy bear",
844
+ "851": "television, television system",
845
+ "852": "tennis ball",
846
+ "853": "thatch, thatched roof",
847
+ "854": "theater curtain, theatre curtain",
848
+ "855": "thimble",
849
+ "856": "thresher, thrasher, threshing machine",
850
+ "857": "throne",
851
+ "858": "tile roof",
852
+ "859": "toaster",
853
+ "86": "partridge",
854
+ "860": "tobacco shop, tobacconist shop, tobacconist",
855
+ "861": "toilet seat",
856
+ "862": "torch",
857
+ "863": "totem pole",
858
+ "864": "tow truck, tow car, wrecker",
859
+ "865": "toyshop",
860
+ "866": "tractor",
861
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
862
+ "868": "tray",
863
+ "869": "trench coat",
864
+ "87": "African grey, African gray, Psittacus erithacus",
865
+ "870": "tricycle, trike, velocipede",
866
+ "871": "trimaran",
867
+ "872": "tripod",
868
+ "873": "triumphal arch",
869
+ "874": "trolleybus, trolley coach, trackless trolley",
870
+ "875": "trombone",
871
+ "876": "tub, vat",
872
+ "877": "turnstile",
873
+ "878": "typewriter keyboard",
874
+ "879": "umbrella",
875
+ "88": "macaw",
876
+ "880": "unicycle, monocycle",
877
+ "881": "upright, upright piano",
878
+ "882": "vacuum, vacuum cleaner",
879
+ "883": "vase",
880
+ "884": "vault",
881
+ "885": "velvet",
882
+ "886": "vending machine",
883
+ "887": "vestment",
884
+ "888": "viaduct",
885
+ "889": "violin, fiddle",
886
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
887
+ "890": "volleyball",
888
+ "891": "waffle iron",
889
+ "892": "wall clock",
890
+ "893": "wallet, billfold, notecase, pocketbook",
891
+ "894": "wardrobe, closet, press",
892
+ "895": "warplane, military plane",
893
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
894
+ "897": "washer, automatic washer, washing machine",
895
+ "898": "water bottle",
896
+ "899": "water jug",
897
+ "9": "ostrich, Struthio camelus",
898
+ "90": "lorikeet",
899
+ "900": "water tower",
900
+ "901": "whiskey jug",
901
+ "902": "whistle",
902
+ "903": "wig",
903
+ "904": "window screen",
904
+ "905": "window shade",
905
+ "906": "Windsor tie",
906
+ "907": "wine bottle",
907
+ "908": "wing",
908
+ "909": "wok",
909
+ "91": "coucal",
910
+ "910": "wooden spoon",
911
+ "911": "wool, woolen, woollen",
912
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
913
+ "913": "wreck",
914
+ "914": "yawl",
915
+ "915": "yurt",
916
+ "916": "web site, website, internet site, site",
917
+ "917": "comic book",
918
+ "918": "crossword puzzle, crossword",
919
+ "919": "street sign",
920
+ "92": "bee eater",
921
+ "920": "traffic light, traffic signal, stoplight",
922
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
923
+ "922": "menu",
924
+ "923": "plate",
925
+ "924": "guacamole",
926
+ "925": "consomme",
927
+ "926": "hot pot, hotpot",
928
+ "927": "trifle",
929
+ "928": "ice cream, icecream",
930
+ "929": "ice lolly, lolly, lollipop, popsicle",
931
+ "93": "hornbill",
932
+ "930": "French loaf",
933
+ "931": "bagel, beigel",
934
+ "932": "pretzel",
935
+ "933": "cheeseburger",
936
+ "934": "hotdog, hot dog, red hot",
937
+ "935": "mashed potato",
938
+ "936": "head cabbage",
939
+ "937": "broccoli",
940
+ "938": "cauliflower",
941
+ "939": "zucchini, courgette",
942
+ "94": "hummingbird",
943
+ "940": "spaghetti squash",
944
+ "941": "acorn squash",
945
+ "942": "butternut squash",
946
+ "943": "cucumber, cuke",
947
+ "944": "artichoke, globe artichoke",
948
+ "945": "bell pepper",
949
+ "946": "cardoon",
950
+ "947": "mushroom",
951
+ "948": "Granny Smith",
952
+ "949": "strawberry",
953
+ "95": "jacamar",
954
+ "950": "orange",
955
+ "951": "lemon",
956
+ "952": "fig",
957
+ "953": "pineapple, ananas",
958
+ "954": "banana",
959
+ "955": "jackfruit, jak, jack",
960
+ "956": "custard apple",
961
+ "957": "pomegranate",
962
+ "958": "hay",
963
+ "959": "carbonara",
964
+ "96": "toucan",
965
+ "960": "chocolate sauce, chocolate syrup",
966
+ "961": "dough",
967
+ "962": "meat loaf, meatloaf",
968
+ "963": "pizza, pizza pie",
969
+ "964": "potpie",
970
+ "965": "burrito",
971
+ "966": "red wine",
972
+ "967": "espresso",
973
+ "968": "cup",
974
+ "969": "eggnog",
975
+ "97": "drake",
976
+ "970": "alp",
977
+ "971": "bubble",
978
+ "972": "cliff, drop, drop-off",
979
+ "973": "coral reef",
980
+ "974": "geyser",
981
+ "975": "lakeside, lakeshore",
982
+ "976": "promontory, headland, head, foreland",
983
+ "977": "sandbar, sand bar",
984
+ "978": "seashore, coast, seacoast, sea-coast",
985
+ "979": "valley, vale",
986
+ "98": "red-breasted merganser, Mergus serrator",
987
+ "980": "volcano",
988
+ "981": "ballplayer, baseball player",
989
+ "982": "groom, bridegroom",
990
+ "983": "scuba diver",
991
+ "984": "rapeseed",
992
+ "985": "daisy",
993
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
994
+ "987": "corn",
995
+ "988": "acorn",
996
+ "989": "hip, rose hip, rosehip",
997
+ "99": "goose",
998
+ "990": "buckeye, horse chestnut, conker",
999
+ "991": "coral fungus",
1000
+ "992": "agaric",
1001
+ "993": "gyromitra",
1002
+ "994": "stinkhorn, carrion fungus",
1003
+ "995": "earthstar",
1004
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1005
+ "997": "bolete",
1006
+ "998": "ear, spike, capitulum",
1007
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1008
+ },
1009
+ "scheduler": [
1010
+ "scheduling_flow_match_promoe",
1011
+ "ProMoEFlowMatchScheduler"
1012
+ ],
1013
+ "transformer": [
1014
+ "transformer_promoe",
1015
+ "ProMoETransformer2DModel"
1016
+ ],
1017
+ "vae": [
1018
+ "diffusers",
1019
+ "AutoencoderKL"
1020
+ ]
1021
+ }
ProMoE-XL-256/pipeline.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: ProMoEPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ from PIL import Image
15
+
16
+ try:
17
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
18
+ except Exception: # pragma: no cover
19
+ class DiffusionPipeline:
20
+ def __init__(self):
21
+ self._execution_device = torch.device("cpu")
22
+
23
+ def register_modules(self, **kwargs):
24
+ for key, value in kwargs.items():
25
+ setattr(self, key, value)
26
+
27
+ def to(self, device):
28
+ self._execution_device = torch.device(device)
29
+ for module in (getattr(self, "transformer", None), getattr(self, "vae", None)):
30
+ if module is not None and hasattr(module, "to"):
31
+ module.to(device)
32
+ return self
33
+
34
+ def progress_bar(self, iterable):
35
+ return iterable
36
+
37
+ def maybe_free_model_hooks(self):
38
+ return None
39
+
40
+ @dataclass
41
+ class ProMoEPipelineOutput:
42
+ images: Union[List[Image.Image], np.ndarray, torch.Tensor]
43
+
44
+ class ProMoEPipeline(DiffusionPipeline):
45
+ r"""
46
+ Pipeline for class-conditional image generation with ProMoE.
47
+
48
+ Parameters:
49
+ transformer ([`ProMoETransformer2DModel`]):
50
+ Class-conditional ProMoE transformer for flow-matching in latent space.
51
+ scheduler ([`ProMoEFlowMatchScheduler`]):
52
+ Flow-matching scheduler used during denoising.
53
+ vae ([`AutoencoderKL`], *optional*):
54
+ Variational autoencoder used to decode latents to pixels.
55
+ id2label (`dict[int, str]`, *optional*):
56
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
57
+ """
58
+
59
+ model_cpu_offload_seq = "transformer->vae"
60
+ _optional_components = ["vae"]
61
+
62
+ def __init__(
63
+ self,
64
+ transformer,
65
+ scheduler,
66
+ vae=None,
67
+ id2label: Optional[Dict[Union[int, str], str]] = None,
68
+ ):
69
+ super().__init__()
70
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
71
+ self._id2label = self._normalize_id2label(id2label)
72
+ self.labels = self._build_label2id(self._id2label)
73
+ self._labels_loaded_from_model_index = bool(self._id2label)
74
+
75
+ def _ensure_labels_loaded(self) -> None:
76
+ if self._labels_loaded_from_model_index:
77
+ return
78
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
79
+ if loaded:
80
+ self._id2label = loaded
81
+ self.labels = self._build_label2id(self._id2label)
82
+ self._labels_loaded_from_model_index = True
83
+
84
+ @staticmethod
85
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
86
+ if not id2label:
87
+ return {}
88
+ return {int(key): value for key, value in id2label.items()}
89
+
90
+ @staticmethod
91
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
92
+ if not variant_path:
93
+ return {}
94
+ variant_dir = Path(variant_path).resolve()
95
+ model_index_path = variant_dir / "model_index.json"
96
+ if not model_index_path.exists():
97
+ return {}
98
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
99
+ id2label = raw.get("id2label")
100
+ if not isinstance(id2label, dict):
101
+ return {}
102
+ return {int(key): value for key, value in id2label.items()}
103
+
104
+ @staticmethod
105
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
106
+ label2id: Dict[str, int] = {}
107
+ for class_id, value in id2label.items():
108
+ for synonym in value.split(","):
109
+ synonym = synonym.strip()
110
+ if synonym:
111
+ label2id[synonym] = int(class_id)
112
+ return dict(sorted(label2id.items()))
113
+
114
+ @property
115
+ def id2label(self) -> Dict[int, str]:
116
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
117
+ self._ensure_labels_loaded()
118
+ return self._id2label
119
+
120
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
121
+ r"""
122
+ Map ImageNet label strings to class ids.
123
+
124
+ Args:
125
+ label (`str` or `list[str]`):
126
+ One or more English label strings. Each string must match a synonym in `id2label`.
127
+ """
128
+ self._ensure_labels_loaded()
129
+ label2id = self.labels
130
+ if not label2id:
131
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
132
+
133
+ if isinstance(label, str):
134
+ label = [label]
135
+
136
+ missing = [item for item in label if item not in label2id]
137
+ if missing:
138
+ preview = ", ".join(list(label2id.keys())[:8])
139
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
140
+ return [label2id[item] for item in label]
141
+
142
+ def _get_vae_spatial_downsample(self) -> int:
143
+ if self.vae is None:
144
+ return 8
145
+ block_out_channels = getattr(getattr(self.vae, "config", None), "block_out_channels", [0, 0, 0, 0])
146
+ return 2 ** (len(block_out_channels) - 1)
147
+
148
+ def _normalize_class_labels(
149
+ self,
150
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
151
+ device: torch.device,
152
+ ) -> torch.LongTensor:
153
+ if torch.is_tensor(class_labels):
154
+ return class_labels.to(device=device, dtype=torch.long).reshape(-1)
155
+
156
+ if isinstance(class_labels, int):
157
+ class_label_ids = [class_labels]
158
+ elif isinstance(class_labels, str):
159
+ class_label_ids = self.get_label_ids(class_labels)
160
+ elif class_labels and isinstance(class_labels[0], str):
161
+ class_label_ids = self.get_label_ids(class_labels)
162
+ else:
163
+ class_label_ids = list(class_labels)
164
+
165
+ return torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)
166
+
167
+ def _prepare_latents(
168
+ self,
169
+ batch_size: int,
170
+ latent_height: int,
171
+ latent_width: int,
172
+ dtype: torch.dtype,
173
+ device: torch.device,
174
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
175
+ ) -> torch.Tensor:
176
+ shape = (batch_size, self.transformer.in_channels, latent_height, latent_width)
177
+ if isinstance(generator, list):
178
+ latents = [torch.randn((1, *shape[1:]), generator=g, device=device, dtype=dtype) for g in generator]
179
+ return torch.cat(latents, dim=0)
180
+ return torch.randn(shape, generator=generator, device=device, dtype=dtype)
181
+
182
+ def _decode_latents(self, latents: torch.Tensor, output_type: str):
183
+ if output_type == "latent":
184
+ return latents
185
+ if self.vae is not None:
186
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
187
+ decode_dtype = next(self.vae.parameters()).dtype
188
+ latents = (latents / scaling_factor).to(dtype=decode_dtype)
189
+ image = self.vae.decode(latents, return_dict=False)[0]
190
+ else:
191
+ image = latents
192
+
193
+ image = (image / 2 + 0.5).clamp(0, 1)
194
+ if output_type == "pt":
195
+ return image
196
+ image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
197
+ if output_type == "np":
198
+ return image
199
+ pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image]
200
+ return pil_images
201
+
202
+ @torch.no_grad()
203
+ def __call__(
204
+ self,
205
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
206
+ height: int = 256,
207
+ width: int = 256,
208
+ num_inference_steps: int = 50,
209
+ guidance_scale: float = 1.0,
210
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
211
+ output_type: str = "pil",
212
+ return_dict: bool = True,
213
+ ) -> Union[ProMoEPipelineOutput, Tuple]:
214
+ r"""
215
+ Generate class-conditional images with ProMoE.
216
+
217
+ Args:
218
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
219
+ ImageNet class indices or human-readable English label strings.
220
+ """
221
+ device = self._execution_device if hasattr(self, "_execution_device") else torch.device("cpu")
222
+ model_dtype = next(self.transformer.parameters()).dtype
223
+ class_labels = self._normalize_class_labels(class_labels, device)
224
+ batch_size = class_labels.shape[0]
225
+
226
+ vae_scale = self._get_vae_spatial_downsample()
227
+ latent_height = height // vae_scale
228
+ latent_width = width // vae_scale
229
+ latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator)
230
+
231
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
232
+ null_labels = torch.full_like(class_labels, getattr(self.transformer.backbone.y_embedder, "num_classes", 1000))
233
+
234
+ for t in self.progress_bar(self.scheduler.timesteps):
235
+ if guidance_scale > 1.0:
236
+ latent_input = torch.cat([latents, latents], dim=0)
237
+ labels = torch.cat([class_labels, null_labels], dim=0)
238
+ else:
239
+ latent_input = latents
240
+ labels = class_labels
241
+ timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype)
242
+ model_output = self.transformer(
243
+ hidden_states=latent_input,
244
+ timestep=timestep,
245
+ class_labels=labels,
246
+ return_dict=True,
247
+ ).sample
248
+ if model_output.shape[1] != latents.shape[1]:
249
+ model_output = model_output.chunk(2, dim=1)[0]
250
+ if guidance_scale > 1.0:
251
+ model_output_cond, model_output_uncond = model_output.chunk(2)
252
+ model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
253
+ latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample
254
+
255
+ images = self._decode_latents(latents, output_type)
256
+ self.maybe_free_model_hooks()
257
+ if not return_dict:
258
+ return (images,)
259
+ return ProMoEPipelineOutput(images=images)
ProMoE-XL-256/scheduler/config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoEFlowMatchScheduler",
3
+ "num_train_timesteps": 1000,
4
+ "shift": 1.0
5
+ }
ProMoE-XL-256/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoEFlowMatchScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
ProMoE-XL-256/scheduler/scheduling_flow_match_promoe.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from types import SimpleNamespace
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ try:
8
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
9
+ except Exception: # pragma: no cover
10
+ FlowMatchEulerDiscreteScheduler = None
11
+
12
+
13
+ @dataclass
14
+ class ProMoEFlowMatchSchedulerOutput:
15
+ prev_sample: torch.FloatTensor
16
+
17
+
18
+ if FlowMatchEulerDiscreteScheduler is not None:
19
+
20
+ class ProMoEFlowMatchScheduler(FlowMatchEulerDiscreteScheduler):
21
+ pass
22
+
23
+ else:
24
+
25
+ class ProMoEFlowMatchScheduler:
26
+ def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
27
+ self.config = SimpleNamespace(num_train_timesteps=num_train_timesteps, shift=shift, stochastic_sampling=False)
28
+ self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1, dtype=torch.float32)
29
+
30
+ def set_timesteps(self, num_inference_steps: int, device: Optional[torch.device] = None):
31
+ self.timesteps = torch.linspace(
32
+ self.config.num_train_timesteps - 1,
33
+ 0,
34
+ num_inference_steps,
35
+ dtype=torch.float32,
36
+ device=device,
37
+ )
38
+
39
+ def step(self, model_output, timestep, sample, generator=None):
40
+ del generator
41
+ dt = 1.0 / max(len(self.timesteps), 1)
42
+ prev_sample = sample - dt * model_output
43
+ return ProMoEFlowMatchSchedulerOutput(prev_sample=prev_sample)
ProMoE-XL-256/transformer/backbone_diffmoe.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .modeling_promoe_common import (
7
+ Attention,
8
+ FinalLayer,
9
+ LabelEmbedder,
10
+ Mlp,
11
+ MoeMLP_DiffMoE as MoeMLP,
12
+ PatchEmbed,
13
+ TimestepEmbedder,
14
+ get_2d_sincos_pos_embed,
15
+ modulate,
16
+ )
17
+
18
+
19
+ class SparseMoEBlock(nn.Module):
20
+ def __init__(
21
+ self,
22
+ experts,
23
+ hidden_dim,
24
+ num_experts,
25
+ n_shared_experts=0,
26
+ capacity=2,
27
+ mlp_ratio=4.0,
28
+ use_diff_expert=False,
29
+ ):
30
+ super().__init__()
31
+ self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim)))
32
+ nn.init.normal_(self.gate_weight, std=0.006)
33
+ self.experts = nn.ModuleList(experts)
34
+ self.capacity = capacity
35
+ self.num_experts = num_experts
36
+ self.n_shared_experts = n_shared_experts
37
+ self.use_diff_expert = use_diff_expert
38
+ if use_diff_expert:
39
+ self.diff_expert = MoeMLP(hidden_size=hidden_dim, intermediate_size=int(hidden_dim * mlp_ratio))
40
+
41
+ self.capacity_predictor = nn.Sequential(
42
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
43
+ nn.SiLU(),
44
+ nn.Linear(hidden_dim, self.num_experts, bias=True),
45
+ )
46
+
47
+ if self.n_shared_experts > 0:
48
+ mlp_hidden_dim = int(hidden_dim * mlp_ratio * 2)
49
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
50
+ self.shared_experts = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
51
+
52
+ self.register_buffer("expert_threshold", torch.tensor([0.0] * num_experts))
53
+ self.register_buffer("ema_decay", torch.tensor([0.95]))
54
+
55
+ def forward(self, x):
56
+ if self.training:
57
+ return self.forward_train(x)
58
+ return self.forward_eval(x)
59
+
60
+ def update_threshold(self, capacity_pred):
61
+ if not self.training:
62
+ return
63
+ capacity_pred = torch.sigmoid(capacity_pred)
64
+ seq_len = capacity_pred.size(0)
65
+ topk = int((seq_len / self.num_experts) * self.capacity)
66
+ threshold = self.expert_threshold
67
+ ema_decay = self.ema_decay
68
+ for i in range(self.num_experts):
69
+ scores, _ = torch.topk(capacity_pred[:, i], k=topk, dim=-1, sorted=True)
70
+ quantile = scores[-1].detach()
71
+ threshold[i] = threshold[i] * ema_decay + (1 - ema_decay) * quantile
72
+ if dist.is_available() and dist.is_initialized():
73
+ dist.all_reduce(threshold, op=dist.ReduceOp.SUM)
74
+ threshold /= dist.get_world_size()
75
+ self.expert_threshold = threshold
76
+
77
+ def forward_train(self, x):
78
+ bsz, seq_len, hidden_dim = x.shape
79
+ identity = x
80
+ x = x.view(-1, hidden_dim)
81
+ total_tokens = x.shape[0]
82
+ capacity_pred = self.capacity_predictor(x.detach())
83
+ k = int((total_tokens / self.num_experts) * self.capacity)
84
+ logits = F.linear(x, self.gate_weight, None)
85
+ scores = logits.softmax(dim=-1).permute(1, 0)
86
+ gating, index = torch.topk(scores, k=k, dim=-1, sorted=False)
87
+ mask = torch.zeros((self.num_experts, total_tokens), dtype=x.dtype, device=x.device)
88
+ mask.scatter_(1, index, 1.0)
89
+ expert_inputs = x[index]
90
+ expert_outputs = torch.stack([expert(expert_inputs[i]) for i, expert in enumerate(self.experts)])
91
+ gated_outputs = gating.unsqueeze(-1) * expert_outputs
92
+
93
+ y = torch.zeros((total_tokens * self.num_experts, hidden_dim), dtype=x.dtype, device=x.device)
94
+ offset = torch.arange(0, self.num_experts, device=x.device).unsqueeze(1) * total_tokens
95
+ flat_index = (index + offset.long()).view(-1)
96
+ y = torch.scatter(y, 0, flat_index.unsqueeze(1).expand(-1, hidden_dim), gated_outputs.view(-1, hidden_dim))
97
+ y = y.view(self.num_experts, total_tokens, hidden_dim).sum(dim=0, keepdim=False)
98
+
99
+ self.update_threshold(capacity_pred)
100
+ x_out = y.view(bsz, seq_len, hidden_dim)
101
+ ones = mask.permute(1, 0).view(bsz, seq_len, self.num_experts)
102
+ capacity_pred = capacity_pred.view(bsz, seq_len, self.num_experts)
103
+ if self.n_shared_experts > 0:
104
+ x_out = x_out + self.shared_experts(identity)
105
+ if self.use_diff_expert:
106
+ x_out = x_out - self.diff_expert(identity)
107
+ return x_out, ones, capacity_pred
108
+
109
+ def forward_eval(self, x):
110
+ bsz, seq_len, hidden_dim = x.shape
111
+ identity = x
112
+ x = x.view(-1, hidden_dim)
113
+ total_tokens = x.shape[0]
114
+ capacity_pred = torch.sigmoid(self.capacity_predictor(x.detach()))
115
+ threshold = self.expert_threshold
116
+ logits = F.linear(x, self.gate_weight, None)
117
+ scores = logits.softmax(dim=-1).permute(-1, -2)
118
+ y = torch.zeros_like(x, dtype=x.dtype)
119
+ for i, expert in enumerate(self.experts):
120
+ k_fixed = torch.where(capacity_pred[:, i] > threshold[i], 1, 0).sum()
121
+ gating, index = torch.topk(scores[i], k=k_fixed, dim=-1, sorted=False)
122
+ y[index, :] += gating.unsqueeze(-1) * expert(x[index, :])
123
+ x_out = y.view(bsz, seq_len, hidden_dim)
124
+ if self.n_shared_experts > 0:
125
+ x_out = x_out + self.shared_experts(identity)
126
+ return x_out, None, None
127
+
128
+
129
+ class DiTBlock(nn.Module):
130
+ def __init__(
131
+ self,
132
+ hidden_size,
133
+ num_heads,
134
+ head_dim=None,
135
+ mlp_ratio=4.0,
136
+ use_swiglu=False,
137
+ MoE_config=None,
138
+ use_moe=False,
139
+ qk_norm=False,
140
+ **block_kwargs,
141
+ ):
142
+ super().__init__()
143
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
144
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=qk_norm, **block_kwargs)
145
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
146
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
147
+ self.use_moe = use_moe
148
+ if use_moe:
149
+ if not use_swiglu:
150
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
151
+ experts = [
152
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
153
+ for _ in range(MoE_config.num_experts)
154
+ ]
155
+ else:
156
+ experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)]
157
+ self.mlp = SparseMoEBlock(
158
+ experts=experts,
159
+ hidden_dim=hidden_size,
160
+ num_experts=MoE_config.num_experts,
161
+ capacity=MoE_config.capacity,
162
+ n_shared_experts=MoE_config.n_shared_experts,
163
+ mlp_ratio=4.0,
164
+ )
165
+ else:
166
+ if not use_swiglu:
167
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
169
+ else:
170
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
171
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
172
+
173
+ def forward(self, x, c):
174
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
175
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
176
+ if self.use_moe:
177
+ x_mlp, ones, pred_c = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
178
+ x = x + gate_mlp.unsqueeze(1) * x_mlp
179
+ return x, ones, pred_c
180
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
181
+ return x, None, None
182
+
183
+
184
+ class DiT(nn.Module):
185
+ def __init__(
186
+ self,
187
+ input_size=32,
188
+ patch_size=2,
189
+ in_channels=4,
190
+ hidden_size=1152,
191
+ depth=28,
192
+ num_heads=16,
193
+ mlp_ratio=4.0,
194
+ qk_norm=False,
195
+ class_dropout_prob=0.1,
196
+ num_classes=1000,
197
+ learn_sigma=True,
198
+ use_swiglu=False,
199
+ MoE_config=None,
200
+ head_dim=None,
201
+ CapacityPred_loss_weight=0.01,
202
+ ):
203
+ super().__init__()
204
+ self.learn_sigma = learn_sigma
205
+ self.in_channels = in_channels
206
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
207
+ self.patch_size = patch_size
208
+ self.num_heads = num_heads
209
+ self.MoE_config = MoE_config
210
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
211
+ self.CapacityPred_loss_weight = CapacityPred_loss_weight
212
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
213
+ self.t_embedder = TimestepEmbedder(hidden_size)
214
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
215
+ num_patches = self.x_embedder.num_patches
216
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
217
+ self.blocks = nn.ModuleList(
218
+ [
219
+ DiTBlock(
220
+ hidden_size,
221
+ num_heads,
222
+ head_dim=head_dim,
223
+ mlp_ratio=mlp_ratio,
224
+ qk_norm=qk_norm,
225
+ use_swiglu=use_swiglu,
226
+ MoE_config=MoE_config,
227
+ use_moe=use_moe_flag[i],
228
+ )
229
+ for i in range(depth)
230
+ ]
231
+ )
232
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
233
+ self.init_MoeMLP = MoE_config.init_MoeMLP
234
+ self.initialize_weights()
235
+ self.capacity_schedule = MoE_config.get("capacity_schedule", None)
236
+ if self.capacity_schedule:
237
+ self.training_iters = -1
238
+
239
+ def initialize_weights(self):
240
+ def _basic_init(module):
241
+ if isinstance(module, nn.Linear):
242
+ torch.nn.init.xavier_uniform_(module.weight)
243
+ if module.bias is not None:
244
+ nn.init.constant_(module.bias, 0)
245
+
246
+ self.apply(_basic_init)
247
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
248
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
249
+ w = self.x_embedder.proj.weight.data
250
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
251
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
252
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
253
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
254
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
255
+ for block in self.blocks:
256
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
257
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
258
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
259
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
260
+ nn.init.constant_(self.final_layer.linear.weight, 0)
261
+ nn.init.constant_(self.final_layer.linear.bias, 0)
262
+
263
+ def unpatchify(self, x):
264
+ c = self.out_channels
265
+ p = self.x_embedder.patch_size[0]
266
+ h = w = int(x.shape[1] ** 0.5)
267
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
268
+ x = torch.einsum("nhwpqc->nchpwq", x)
269
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
270
+
271
+ def forward(self, x, t, context, **kwargs):
272
+ y = context
273
+ if len(x.shape) != 4:
274
+ x = x.squeeze(2)
275
+
276
+ if self.training and self.capacity_schedule:
277
+ num_experts = self.MoE_config.num_experts
278
+ capacity = self.MoE_config.capacity
279
+ stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters
280
+ stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters
281
+ if self.training_iters <= stage_i:
282
+ capacity = num_experts
283
+ elif self.training_iters <= stage_ii:
284
+ capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i)
285
+ for block in self.blocks:
286
+ if hasattr(block.mlp, "capacity"):
287
+ block.mlp.capacity = capacity
288
+
289
+ x = self.x_embedder(x) + self.pos_embed
290
+ t = self.t_embedder(t)
291
+ y = self.y_embedder(y, self.training)
292
+ c = t + y
293
+ ones_list, pred_c_list, layer_idx_list = [], [], []
294
+ for layer_idx, block in enumerate(self.blocks):
295
+ x, ones, pred_c = block(x, c)
296
+ if ones is not None:
297
+ ones_list.append(ones)
298
+ pred_c_list.append(pred_c)
299
+ layer_idx_list.append(layer_idx)
300
+ x = self.final_layer(x, c)
301
+ x = self.unpatchify(x)
302
+ return x, "Capacity_Pred", layer_idx_list, ones_list, pred_c_list, self.CapacityPred_loss_weight
ProMoE-XL-256/transformer/backbone_dit.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .modeling_promoe_common import (
5
+ Attention,
6
+ FinalLayer,
7
+ LabelEmbedder,
8
+ Mlp,
9
+ PatchEmbed,
10
+ TimestepEmbedder,
11
+ get_2d_sincos_pos_embed,
12
+ modulate,
13
+ )
14
+
15
+
16
+ class DiTBlock(nn.Module):
17
+ def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs):
18
+ super().__init__()
19
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
20
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
21
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
22
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
23
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
24
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
25
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
26
+
27
+ def forward(self, x, c):
28
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
29
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
30
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
31
+ return x
32
+
33
+
34
+ class DiT(nn.Module):
35
+ def __init__(
36
+ self,
37
+ input_size=32,
38
+ patch_size=2,
39
+ in_channels=4,
40
+ hidden_size=1152,
41
+ depth=28,
42
+ num_heads=16,
43
+ mlp_ratio=4.0,
44
+ qk_norm=False,
45
+ class_dropout_prob=0.1,
46
+ num_classes=1000,
47
+ learn_sigma=True,
48
+ head_dim=None,
49
+ use_swiglu=False,
50
+ ):
51
+ super().__init__()
52
+ self.learn_sigma = learn_sigma
53
+ self.in_channels = in_channels
54
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
55
+ self.patch_size = patch_size
56
+ self.num_heads = num_heads
57
+
58
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
59
+ self.t_embedder = TimestepEmbedder(hidden_size)
60
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
61
+ num_patches = self.x_embedder.num_patches
62
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
63
+
64
+ self.blocks = nn.ModuleList(
65
+ [
66
+ DiTBlock(
67
+ hidden_size,
68
+ num_heads,
69
+ head_dim=head_dim,
70
+ mlp_ratio=mlp_ratio,
71
+ qk_norm=qk_norm,
72
+ use_swiglu=use_swiglu,
73
+ )
74
+ for _ in range(depth)
75
+ ]
76
+ )
77
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
78
+ self.initialize_weights()
79
+
80
+ def initialize_weights(self):
81
+ def _basic_init(module):
82
+ if isinstance(module, nn.Linear):
83
+ torch.nn.init.xavier_uniform_(module.weight)
84
+ if module.bias is not None:
85
+ nn.init.constant_(module.bias, 0)
86
+
87
+ self.apply(_basic_init)
88
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
89
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
90
+ w = self.x_embedder.proj.weight.data
91
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
92
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
93
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
94
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
95
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
96
+ for block in self.blocks:
97
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
98
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
99
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
100
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
101
+ nn.init.constant_(self.final_layer.linear.weight, 0)
102
+ nn.init.constant_(self.final_layer.linear.bias, 0)
103
+
104
+ def unpatchify(self, x):
105
+ c = self.out_channels
106
+ p = self.x_embedder.patch_size[0]
107
+ h = w = int(x.shape[1] ** 0.5)
108
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
109
+ x = torch.einsum("nhwpqc->nchpwq", x)
110
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
111
+
112
+ def forward(self, x, t, context, **kwargs):
113
+ y = context
114
+ if len(x.shape) != 4:
115
+ x = x.squeeze(2)
116
+ x = self.x_embedder(x) + self.pos_embed
117
+ t = self.t_embedder(t)
118
+ y = self.y_embedder(y, self.training)
119
+ c = t + y
120
+ for block in self.blocks:
121
+ x = block(x, c)
122
+ x = self.final_layer(x, c)
123
+ return self.unpatchify(x)
ProMoE-XL-256/transformer/backbone_ecdit.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP_DiffMoE as MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class SparseMoEBlock(nn.Module):
19
+ def __init__(self, experts, hidden_dim, num_experts, n_shared_experts=0, capacity=2):
20
+ super().__init__()
21
+ self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim)))
22
+ nn.init.normal_(self.gate_weight, std=0.006)
23
+ self.experts = nn.ModuleList(experts)
24
+ self.capacity = capacity
25
+ self.num_experts = num_experts
26
+ self.n_shared_experts = n_shared_experts
27
+ if self.n_shared_experts > 0:
28
+ intermediate_size = hidden_dim * self.n_shared_experts
29
+ self.shared_experts = MoeMLP(hidden_size=hidden_dim, intermediate_size=intermediate_size, pretraining_tp=2)
30
+
31
+ def forward(self, x):
32
+ identity = x
33
+ batch_size, seq_len, _ = x.shape
34
+ logits = F.linear(x, self.gate_weight, None)
35
+ affinity = logits.softmax(dim=-1)
36
+ affinity = torch.einsum("b s e -> b e s", affinity)
37
+ k = int((seq_len / self.num_experts) * self.capacity)
38
+ gating, index = torch.topk(affinity, k=k, dim=-1, sorted=False)
39
+ dispatch = F.one_hot(index, num_classes=seq_len).to(device=x.device, dtype=x.dtype)
40
+ x_in = torch.einsum("b e c s, b s d -> b e c d", dispatch, x)
41
+ x_e = [self.experts[e](x_in[:, e]) for e in range(self.num_experts)]
42
+ x_e = torch.stack(x_e, dim=1)
43
+ x_out = torch.einsum("b e c s, b e c, b e c d -> b s d", dispatch, gating, x_e)
44
+ if self.n_shared_experts > 0:
45
+ x_out = x_out + self.shared_experts(identity)
46
+ return x_out
47
+
48
+
49
+ class DiTBlock(nn.Module):
50
+ def __init__(
51
+ self,
52
+ hidden_size,
53
+ num_heads,
54
+ head_dim=None,
55
+ mlp_ratio=4.0,
56
+ use_swiglu=False,
57
+ MoE_config=None,
58
+ use_moe=False,
59
+ **block_kwargs,
60
+ ):
61
+ super().__init__()
62
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
63
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
64
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
65
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
66
+ if use_moe:
67
+ if not use_swiglu:
68
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
69
+ experts = [
70
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
71
+ for _ in range(MoE_config.num_experts)
72
+ ]
73
+ else:
74
+ experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)]
75
+ self.mlp = SparseMoEBlock(
76
+ experts=experts,
77
+ hidden_dim=hidden_size,
78
+ num_experts=MoE_config.num_experts,
79
+ capacity=MoE_config.capacity,
80
+ n_shared_experts=MoE_config.n_shared_experts,
81
+ )
82
+ else:
83
+ if not use_swiglu:
84
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
85
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
86
+ else:
87
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
88
+
89
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
90
+
91
+ def forward(self, x, c):
92
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
93
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
94
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
95
+ return x
96
+
97
+
98
+ class DiT(nn.Module):
99
+ def __init__(
100
+ self,
101
+ input_size=32,
102
+ patch_size=2,
103
+ in_channels=4,
104
+ hidden_size=1152,
105
+ depth=28,
106
+ num_heads=16,
107
+ mlp_ratio=4.0,
108
+ qk_norm=False,
109
+ class_dropout_prob=0.1,
110
+ num_classes=1000,
111
+ learn_sigma=True,
112
+ use_swiglu=False,
113
+ MoE_config=None,
114
+ head_dim=None,
115
+ ):
116
+ super().__init__()
117
+ self.learn_sigma = learn_sigma
118
+ self.in_channels = in_channels
119
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
120
+ self.patch_size = patch_size
121
+ self.num_heads = num_heads
122
+ self.MoE_config = MoE_config
123
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
124
+
125
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
126
+ self.t_embedder = TimestepEmbedder(hidden_size)
127
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
128
+ num_patches = self.x_embedder.num_patches
129
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
130
+ self.blocks = nn.ModuleList(
131
+ [
132
+ DiTBlock(
133
+ hidden_size,
134
+ num_heads,
135
+ head_dim=head_dim,
136
+ mlp_ratio=mlp_ratio,
137
+ qk_norm=qk_norm,
138
+ use_swiglu=use_swiglu,
139
+ MoE_config=MoE_config,
140
+ use_moe=use_moe_flag[i],
141
+ )
142
+ for i in range(depth)
143
+ ]
144
+ )
145
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
146
+ self.init_MoeMLP = MoE_config.init_MoeMLP
147
+ self.initialize_weights()
148
+ self.capacity_schedule = MoE_config.get("capacity_schedule", None)
149
+ if self.capacity_schedule:
150
+ self.training_iters = -1
151
+
152
+ def initialize_weights(self):
153
+ def _basic_init(module):
154
+ if isinstance(module, nn.Linear):
155
+ torch.nn.init.xavier_uniform_(module.weight)
156
+ if module.bias is not None:
157
+ nn.init.constant_(module.bias, 0)
158
+
159
+ self.apply(_basic_init)
160
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
161
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
162
+ w = self.x_embedder.proj.weight.data
163
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
164
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
165
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
166
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
167
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
168
+ for block in self.blocks:
169
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
170
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
171
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
172
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
173
+ nn.init.constant_(self.final_layer.linear.weight, 0)
174
+ nn.init.constant_(self.final_layer.linear.bias, 0)
175
+
176
+ def init_moe_mlp(module, std=0.006):
177
+ nn.init.normal_(module.gate_proj.weight, std=std)
178
+ nn.init.normal_(module.up_proj.weight, std=std)
179
+ nn.init.normal_(module.down_proj.weight, std=std)
180
+
181
+ if self.init_MoeMLP:
182
+ for block in self.blocks:
183
+ if hasattr(block.mlp, "experts"):
184
+ for expert in block.mlp.experts:
185
+ if hasattr(expert, "gate_proj"):
186
+ init_moe_mlp(expert)
187
+
188
+ def unpatchify(self, x):
189
+ c = self.out_channels
190
+ p = self.x_embedder.patch_size[0]
191
+ h = w = int(x.shape[1] ** 0.5)
192
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
193
+ x = torch.einsum("nhwpqc->nchpwq", x)
194
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
195
+
196
+ def forward(self, x, t, context, **kwargs):
197
+ y = context
198
+ if len(x.shape) != 4:
199
+ x = x.squeeze(2)
200
+ if self.training and self.capacity_schedule:
201
+ num_experts = self.MoE_config.num_experts
202
+ capacity = self.MoE_config.capacity
203
+ stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters
204
+ stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters
205
+ if self.training_iters <= stage_i:
206
+ capacity = num_experts
207
+ elif self.training_iters <= stage_ii:
208
+ capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i)
209
+ for block in self.blocks:
210
+ if hasattr(block.mlp, "capacity"):
211
+ block.mlp.capacity = capacity
212
+
213
+ x = self.x_embedder(x) + self.pos_embed
214
+ t = self.t_embedder(t)
215
+ y = self.y_embedder(y, self.training)
216
+ c = t + y
217
+ for block in self.blocks:
218
+ x = block(x, c)
219
+ x = self.final_layer(x, c)
220
+ return self.unpatchify(x)
ProMoE-XL-256/transformer/backbone_promoe_ec.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class AddAuxiliaryLoss(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, x, loss):
21
+ ctx.dtype = loss.dtype
22
+ ctx.required_aux_loss = loss.requires_grad
23
+ return x
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output):
27
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
28
+ return grad_output, grad_loss
29
+
30
+
31
+ class SparseMoeBlock(nn.Module):
32
+ def __init__(
33
+ self,
34
+ num_routed_experts,
35
+ hidden_size,
36
+ moe_intermediate_size,
37
+ shared_expert_intermediate_size,
38
+ top_k=1,
39
+ load_balance_loss_coef=0,
40
+ norm_topk_prob=False,
41
+ seq_aux=False,
42
+ use_shared_expert=True,
43
+ use_uncond_expert=True,
44
+ router_weight_mode="softmax",
45
+ routing_contrastive_lam=0,
46
+ use_top_k_for_routing_contrastive=False,
47
+ routing_contrastive_temperature=0.1,
48
+ **kwargs,
49
+ ):
50
+ super().__init__()
51
+ del load_balance_loss_coef, norm_topk_prob, seq_aux, use_top_k_for_routing_contrastive
52
+ self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts
53
+ self.num_routed_experts = num_routed_experts
54
+ self.hidden_size = hidden_size
55
+ self.top_k = top_k
56
+ self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size))
57
+ self.use_shared_expert = use_shared_expert
58
+ self.use_uncond_expert = use_uncond_expert
59
+ self.router_weight_mode = router_weight_mode
60
+ self.routing_contrastive_lam = routing_contrastive_lam
61
+ self.routing_contrastive_temperature = routing_contrastive_temperature
62
+ self.experts = nn.ModuleList(
63
+ [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)]
64
+ )
65
+ if use_shared_expert:
66
+ self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size)
67
+ self._init_weights()
68
+
69
+ def compute_router(self, cond_hidden_states):
70
+ b_cond, seq_len, _ = cond_hidden_states.shape
71
+ num_cond_experts = self.num_routed_experts
72
+ input_norm = F.normalize(cond_hidden_states, p=2, dim=-1)
73
+ cluster_norm = F.normalize(self.cluster_centers, p=2, dim=-1)
74
+ cos_sim = input_norm @ cluster_norm.T
75
+ cos_sim_expert_view = cos_sim.transpose(1, 2)
76
+ if self.router_weight_mode == "softmax":
77
+ cond_weights = F.softmax(cos_sim_expert_view, dim=-1)
78
+ elif self.router_weight_mode == "sigmoid":
79
+ cond_weights = torch.sigmoid(cos_sim_expert_view)
80
+ elif self.router_weight_mode == "identity":
81
+ cond_weights = cos_sim_expert_view
82
+ else:
83
+ raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}")
84
+ k = max(1, min(int((seq_len / num_cond_experts) * self.top_k), seq_len))
85
+ router_weights, indices = torch.topk(cond_weights, k=k, dim=-1, sorted=False)
86
+ dispatch_mask = F.one_hot(indices, num_classes=seq_len).to(dtype=cond_hidden_states.dtype)
87
+ expert_inputs = torch.einsum("becs,bsd->becd", dispatch_mask, cond_hidden_states)
88
+ return dispatch_mask, router_weights, expert_inputs
89
+
90
+ def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor):
91
+ identity = hidden_states
92
+ batch_size, _, hidden_dim = hidden_states.shape
93
+ final_output = torch.zeros_like(hidden_states)
94
+ loss = None
95
+ cond_batch_mask = (
96
+ labels.view(-1) != 1000
97
+ ) if self.use_uncond_expert else torch.ones(batch_size, dtype=torch.bool, device=hidden_states.device)
98
+ uncond_batch_mask = ~cond_batch_mask
99
+ cond_experts = self.experts[:-1] if self.use_uncond_expert else self.experts
100
+
101
+ if cond_batch_mask.any():
102
+ cond_hidden_states = hidden_states[cond_batch_mask]
103
+ dispatch_mask, gating_scores, expert_inputs = self.compute_router(cond_hidden_states)
104
+ num_cond_experts = len(cond_experts)
105
+ expert_outputs = torch.stack([cond_experts[e](expert_inputs[:, e]) for e in range(num_cond_experts)], dim=1)
106
+ cond_output = torch.einsum("becs,bec,becd->bsd", dispatch_mask, gating_scores, expert_outputs).to(hidden_states.dtype)
107
+ final_output[cond_batch_mask] = cond_output
108
+ if self.training and self.routing_contrastive_lam > 0 and num_cond_experts > 1:
109
+ expert_token_means = expert_inputs.mean(dim=2)
110
+ routing_contrastive_loss = self.compute_routing_contrastive_loss(expert_token_means)
111
+ loss = routing_contrastive_loss * self.routing_contrastive_lam
112
+ else:
113
+ dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
114
+ for expert in cond_experts:
115
+ final_output = final_output + expert(dummy_input).sum() * 0
116
+
117
+ if self.use_uncond_expert:
118
+ if uncond_batch_mask.any():
119
+ uncond_hidden_states = hidden_states[uncond_batch_mask]
120
+ final_output[uncond_batch_mask] = self.experts[-1](uncond_hidden_states).to(final_output.dtype)
121
+ else:
122
+ dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
123
+ final_output = final_output + self.experts[-1](dummy_input).sum() * 0
124
+
125
+ if self.use_shared_expert:
126
+ final_output += self.shared_expert(identity).to(hidden_states.dtype)
127
+ return final_output, loss
128
+
129
+ def compute_routing_contrastive_loss(self, expert_token_means):
130
+ batch_size, num_cond_experts, _ = expert_token_means.shape
131
+ if num_cond_experts < 2:
132
+ return torch.tensor(0.0, device=expert_token_means.device)
133
+ centers_norm = F.normalize(self.cluster_centers, p=2, dim=1)
134
+ means_norm = F.normalize(expert_token_means, p=2, dim=2)
135
+ sim_matrix = torch.einsum("id,bjd->bij", centers_norm, means_norm)
136
+ logits = sim_matrix / self.routing_contrastive_temperature
137
+ labels = torch.arange(num_cond_experts, device=logits.device).unsqueeze(0).expand(batch_size, -1)
138
+ return F.cross_entropy(logits.reshape(batch_size * num_cond_experts, -1), labels.reshape(-1))
139
+
140
+ def _init_weights(self):
141
+ nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02)
142
+
143
+
144
+ class DiTBlock(nn.Module):
145
+ def __init__(
146
+ self,
147
+ hidden_size,
148
+ num_heads,
149
+ head_dim=None,
150
+ mlp_ratio=4.0,
151
+ use_swiglu=False,
152
+ MoE_config=None,
153
+ use_moe=False,
154
+ **block_kwargs,
155
+ ):
156
+ super().__init__()
157
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
158
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
159
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
160
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
161
+ self.use_moe = use_moe
162
+ if use_moe:
163
+ self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config)
164
+ else:
165
+ if not use_swiglu:
166
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
167
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
168
+ else:
169
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
170
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
171
+
172
+ def forward(self, x, c, label):
173
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
174
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
175
+ if self.use_moe:
176
+ x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label)
177
+ if aux_loss is not None:
178
+ x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss)
179
+ return x + gate_mlp.unsqueeze(1) * x_mlp
180
+ return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
181
+
182
+
183
+ class DiT(nn.Module):
184
+ def __init__(
185
+ self,
186
+ input_size=32,
187
+ patch_size=2,
188
+ in_channels=4,
189
+ hidden_size=1152,
190
+ depth=28,
191
+ num_heads=16,
192
+ mlp_ratio=4.0,
193
+ qk_norm=False,
194
+ class_dropout_prob=0.1,
195
+ num_classes=1000,
196
+ learn_sigma=True,
197
+ use_swiglu=False,
198
+ MoE_config=None,
199
+ head_dim=None,
200
+ ):
201
+ super().__init__()
202
+ self.learn_sigma = learn_sigma
203
+ self.in_channels = in_channels
204
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
205
+ self.patch_size = patch_size
206
+ self.num_heads = num_heads
207
+ self.MoE_config = MoE_config
208
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
209
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
210
+ self.t_embedder = TimestepEmbedder(hidden_size)
211
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True)
212
+ num_patches = self.x_embedder.num_patches
213
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
214
+ self.blocks = nn.ModuleList(
215
+ [
216
+ DiTBlock(
217
+ hidden_size,
218
+ num_heads,
219
+ head_dim=head_dim,
220
+ mlp_ratio=mlp_ratio,
221
+ qk_norm=qk_norm,
222
+ use_swiglu=use_swiglu,
223
+ MoE_config=MoE_config,
224
+ use_moe=use_moe_flag[i],
225
+ )
226
+ for i in range(depth)
227
+ ]
228
+ )
229
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
230
+ self.init_MoeMLP = MoE_config.init_MoeMLP
231
+ self.initialize_weights()
232
+
233
+ def initialize_weights(self):
234
+ def _basic_init(module):
235
+ if isinstance(module, nn.Linear):
236
+ torch.nn.init.xavier_uniform_(module.weight)
237
+ if module.bias is not None:
238
+ nn.init.constant_(module.bias, 0)
239
+
240
+ self.apply(_basic_init)
241
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
242
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
243
+ w = self.x_embedder.proj.weight.data
244
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
245
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
246
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
247
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
248
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
249
+ for block in self.blocks:
250
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
251
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
252
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
253
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
254
+ nn.init.constant_(self.final_layer.linear.weight, 0)
255
+ nn.init.constant_(self.final_layer.linear.bias, 0)
256
+
257
+ def init_moe_mlp(module, std=0.006):
258
+ nn.init.normal_(module.up_proj.weight, std=std)
259
+ nn.init.normal_(module.down_proj.weight, std=std)
260
+
261
+ if self.init_MoeMLP:
262
+ for block in self.blocks:
263
+ if hasattr(block.mlp, "experts"):
264
+ for expert in block.mlp.experts:
265
+ init_moe_mlp(expert)
266
+
267
+ def unpatchify(self, x):
268
+ c = self.out_channels
269
+ p = self.x_embedder.patch_size[0]
270
+ h = w = int(x.shape[1] ** 0.5)
271
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
272
+ x = torch.einsum("nhwpqc->nchpwq", x)
273
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
274
+
275
+ def forward(self, x, timestep, context, **kwargs):
276
+ y = context
277
+ if len(x.shape) != 4:
278
+ x = x.squeeze(2)
279
+ x = self.x_embedder(x) + self.pos_embed
280
+ t = self.t_embedder(timestep)
281
+ y, labels = self.y_embedder(y, self.training)
282
+ c = t + y
283
+ for block in self.blocks:
284
+ x = block(x, c, labels)
285
+ x = self.final_layer(x, c)
286
+ return self.unpatchify(x)
ProMoE-XL-256/transformer/backbone_promoe_tc.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .modeling_promoe_common import (
6
+ Attention,
7
+ FinalLayer,
8
+ LabelEmbedder,
9
+ Mlp,
10
+ MoeMLP,
11
+ PatchEmbed,
12
+ TimestepEmbedder,
13
+ get_2d_sincos_pos_embed,
14
+ modulate,
15
+ )
16
+
17
+
18
+ class AddAuxiliaryLoss(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, x, loss):
21
+ ctx.dtype = loss.dtype
22
+ ctx.required_aux_loss = loss.requires_grad
23
+ return x
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output):
27
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
28
+ return grad_output, grad_loss
29
+
30
+
31
+ class SparseMoeBlock(nn.Module):
32
+ def __init__(
33
+ self,
34
+ num_routed_experts,
35
+ hidden_size,
36
+ moe_intermediate_size,
37
+ shared_expert_intermediate_size,
38
+ top_k=2,
39
+ load_balance_loss_coef=0,
40
+ norm_topk_prob=False,
41
+ seq_aux=False,
42
+ use_shared_expert=True,
43
+ use_uncond_expert=True,
44
+ router_weight_mode="softmax",
45
+ routing_contrastive_lam=0,
46
+ use_top_k_for_routing_contrastive=False,
47
+ routing_contrastive_temperature=0.1,
48
+ **kwargs,
49
+ ):
50
+ super().__init__()
51
+ del norm_topk_prob
52
+ self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts
53
+ self.num_routed_experts = num_routed_experts
54
+ self.seq_aux = seq_aux
55
+ self.hidden_size = hidden_size
56
+ self.top_k = top_k
57
+ self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size))
58
+ self.alpha = load_balance_loss_coef
59
+ self.use_shared_expert = use_shared_expert
60
+ self.use_uncond_expert = use_uncond_expert
61
+ self.router_weight_mode = router_weight_mode
62
+ self.routing_contrastive_lam = routing_contrastive_lam
63
+ self.use_top_k_for_routing_contrastive = use_top_k_for_routing_contrastive
64
+ self.routing_contrastive_temperature = routing_contrastive_temperature
65
+ self.experts = nn.ModuleList(
66
+ [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)]
67
+ )
68
+ if use_shared_expert:
69
+ self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size)
70
+ self._init_weights()
71
+
72
+ def compute_router(self, hidden_states, labels):
73
+ batch_size, seq_len, _ = hidden_states.shape
74
+ device = hidden_states.device
75
+ flat_input = hidden_states.view(-1, self.hidden_size)
76
+ flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1)
77
+ if self.use_uncond_expert and flat_labels is not None:
78
+ uncond_mask = flat_labels == 1000
79
+ cond_mask = ~uncond_mask
80
+ else:
81
+ uncond_mask = None
82
+ cond_mask = torch.ones_like(flat_labels, dtype=torch.bool)
83
+
84
+ router_weights = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=hidden_states.dtype)
85
+ expert_indices = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=torch.long)
86
+
87
+ if uncond_mask is not None and uncond_mask.any():
88
+ uncond_positions = torch.where(uncond_mask)[0]
89
+ router_weights[uncond_positions, 0] = 1.0
90
+ expert_indices[uncond_positions] = self.num_experts - 1
91
+
92
+ cond_weights = None
93
+ topk_idx = None
94
+ if cond_mask.any():
95
+ cond_positions = torch.where(cond_mask)[0]
96
+ cond_input = flat_input[cond_positions]
97
+ input_norm = F.normalize(cond_input, p=2, dim=1)
98
+ cluster_norm = F.normalize(self.cluster_centers, p=2, dim=1)
99
+ cos_sim = input_norm @ cluster_norm.T
100
+ if self.router_weight_mode == "softmax":
101
+ cond_weights = F.softmax(cos_sim, dim=1)
102
+ elif self.router_weight_mode == "sigmoid":
103
+ cond_weights = torch.sigmoid(cos_sim)
104
+ elif self.router_weight_mode == "identity":
105
+ cond_weights = cos_sim
106
+ else:
107
+ raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}")
108
+ topk_scores, topk_idx = torch.topk(cond_weights, k=self.top_k, dim=1)
109
+ router_weights[cond_positions] = topk_scores.to(router_weights.dtype)
110
+ expert_indices[cond_positions] = topk_idx
111
+
112
+ router_weights = router_weights.view(batch_size, seq_len, self.top_k)
113
+ expert_indices = expert_indices.view(batch_size, seq_len, self.top_k)
114
+
115
+ load_balance_loss = None
116
+ if self.training and self.alpha > 0.0 and cond_weights is not None and topk_idx is not None:
117
+ cond_batch_size = (labels != 1000).sum()
118
+ scores_for_aux = F.softmax(cond_weights, dim=1) if self.router_weight_mode != "softmax" else cond_weights
119
+ topk_idx_for_aux_loss = topk_idx.view(cond_batch_size, -1)
120
+ if self.seq_aux:
121
+ scores_for_seq_aux = scores_for_aux.view(cond_batch_size, seq_len, -1)
122
+ ce = torch.zeros(cond_batch_size, self.num_routed_experts, device=hidden_states.device)
123
+ ce.scatter_add_(
124
+ 1,
125
+ topk_idx_for_aux_loss,
126
+ torch.ones(cond_batch_size, seq_len * self.top_k, device=hidden_states.device),
127
+ ).div_(seq_len * self.top_k / self.num_routed_experts)
128
+ load_balance_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
129
+ else:
130
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.num_routed_experts)
131
+ ce = mask_ce.float().mean(0)
132
+ pi = scores_for_aux.mean(0)
133
+ fi = ce * self.num_routed_experts
134
+ load_balance_loss = (pi * fi).sum() * self.alpha
135
+ return router_weights, expert_indices, load_balance_loss
136
+
137
+ def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor):
138
+ router_weights, expert_indices, load_balance_loss = self.compute_router(hidden_states, labels)
139
+ batch_size, seq_len, hidden_dim = hidden_states.shape
140
+ flat_input = hidden_states.view(-1, hidden_dim)
141
+ flat_weights = router_weights.view(-1, self.top_k)
142
+ flat_indices = expert_indices.view(-1, self.top_k)
143
+ total_tokens = batch_size * seq_len
144
+ final_output = torch.zeros(total_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
145
+
146
+ for expert_id in range(self.num_experts):
147
+ expert_mask = (flat_indices == expert_id).any(dim=1)
148
+ token_ids = torch.where(expert_mask)[0]
149
+ if token_ids.numel() > 0:
150
+ expert_input = flat_input[token_ids]
151
+ expert_weight_mask = flat_indices[token_ids] == expert_id
152
+ expert_weights = flat_weights[token_ids] * expert_weight_mask.to(dtype=flat_weights.dtype)
153
+ combined_weights = expert_weights.sum(dim=1)
154
+ expert_output = self.experts[expert_id](expert_input)
155
+ weighted_output = expert_output * combined_weights.unsqueeze(1)
156
+ final_output.index_add_(0, token_ids, weighted_output)
157
+ else:
158
+ dummy_input = torch.zeros(1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
159
+ final_output[0] += self.experts[expert_id](dummy_input)[0] * 0
160
+
161
+ final_output = final_output.view(batch_size, seq_len, hidden_dim)
162
+ if self.use_shared_expert:
163
+ final_output += self.shared_expert(hidden_states)
164
+
165
+ loss = load_balance_loss
166
+ if self.training and self.routing_contrastive_lam > 0:
167
+ flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1)
168
+ cond_mask = ~(
169
+ flat_labels == 1000
170
+ ) if self.use_uncond_expert else torch.ones(batch_size * seq_len, dtype=torch.bool, device=hidden_states.device)
171
+ cond_token_embeddings = flat_input[cond_mask]
172
+ if self.use_top_k_for_routing_contrastive:
173
+ cond_cluster_assignments = expert_indices.view(batch_size * seq_len, self.top_k)[cond_mask]
174
+ else:
175
+ top1_expert_indices = expert_indices.view(batch_size * seq_len, self.top_k)[:, 0]
176
+ cond_cluster_assignments = top1_expert_indices[cond_mask]
177
+ routing_contrastive_loss = self.compute_routing_contrastive_loss(
178
+ cond_token_embeddings,
179
+ cond_cluster_assignments,
180
+ use_top_k=self.use_top_k_for_routing_contrastive,
181
+ )
182
+ routing_contrastive_loss = routing_contrastive_loss * self.routing_contrastive_lam
183
+ loss = routing_contrastive_loss if loss is None else loss + routing_contrastive_loss
184
+
185
+ return final_output, loss
186
+
187
+ def compute_routing_contrastive_loss(self, token_embeddings, cluster_assignments, use_top_k=False):
188
+ cluster_centers = self.cluster_centers
189
+ num_clusters = cluster_centers.size(0)
190
+ device = cluster_centers.device
191
+ cluster_means = []
192
+ valid_clusters = []
193
+ for cluster_id in range(num_clusters):
194
+ mask = (cluster_assignments == cluster_id).any(dim=1) if use_top_k else cluster_assignments == cluster_id
195
+ if mask.sum() > 0:
196
+ cluster_means.append(token_embeddings[mask].mean(dim=0, keepdim=True))
197
+ valid_clusters.append(cluster_id)
198
+ if len(valid_clusters) < 2:
199
+ return torch.tensor(0.0, device=device)
200
+ cluster_means = torch.cat(cluster_means, dim=0)
201
+ valid_centers = cluster_centers[valid_clusters]
202
+ centers_norm = F.normalize(valid_centers, p=2, dim=1)
203
+ means_norm = F.normalize(cluster_means, p=2, dim=1)
204
+ sim_matrix = centers_norm @ means_norm.T
205
+ logits = sim_matrix / self.routing_contrastive_temperature
206
+ labels = torch.arange(sim_matrix.size(0), device=device)
207
+ return F.cross_entropy(logits, labels)
208
+
209
+ def _init_weights(self):
210
+ nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02)
211
+
212
+
213
+ class DiTBlock(nn.Module):
214
+ def __init__(
215
+ self,
216
+ hidden_size,
217
+ num_heads,
218
+ head_dim=None,
219
+ mlp_ratio=4.0,
220
+ use_swiglu=False,
221
+ MoE_config=None,
222
+ use_moe=False,
223
+ **block_kwargs,
224
+ ):
225
+ super().__init__()
226
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
227
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
228
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
229
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
230
+ self.use_moe = use_moe
231
+ if use_moe:
232
+ self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config)
233
+ else:
234
+ if not use_swiglu:
235
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
236
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
237
+ else:
238
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
239
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
240
+
241
+ def forward(self, x, c, label):
242
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
243
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
244
+ if self.use_moe:
245
+ x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label)
246
+ if aux_loss is not None:
247
+ x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss)
248
+ return x + gate_mlp.unsqueeze(1) * x_mlp
249
+ return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
250
+
251
+
252
+ class DiT(nn.Module):
253
+ def __init__(
254
+ self,
255
+ input_size=32,
256
+ patch_size=2,
257
+ in_channels=4,
258
+ hidden_size=1152,
259
+ depth=28,
260
+ num_heads=16,
261
+ mlp_ratio=4.0,
262
+ qk_norm=False,
263
+ class_dropout_prob=0.1,
264
+ num_classes=1000,
265
+ learn_sigma=True,
266
+ use_swiglu=False,
267
+ MoE_config=None,
268
+ head_dim=None,
269
+ ):
270
+ super().__init__()
271
+ self.learn_sigma = learn_sigma
272
+ self.in_channels = in_channels
273
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
274
+ self.patch_size = patch_size
275
+ self.num_heads = num_heads
276
+ self.MoE_config = MoE_config
277
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
278
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
279
+ self.t_embedder = TimestepEmbedder(hidden_size)
280
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True)
281
+ num_patches = self.x_embedder.num_patches
282
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
283
+ self.blocks = nn.ModuleList(
284
+ [
285
+ DiTBlock(
286
+ hidden_size,
287
+ num_heads,
288
+ head_dim=head_dim,
289
+ mlp_ratio=mlp_ratio,
290
+ qk_norm=qk_norm,
291
+ use_swiglu=use_swiglu,
292
+ MoE_config=MoE_config,
293
+ use_moe=use_moe_flag[i],
294
+ )
295
+ for i in range(depth)
296
+ ]
297
+ )
298
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
299
+ self.init_MoeMLP = MoE_config.init_MoeMLP
300
+ self.initialize_weights()
301
+
302
+ def initialize_weights(self):
303
+ def _basic_init(module):
304
+ if isinstance(module, nn.Linear):
305
+ torch.nn.init.xavier_uniform_(module.weight)
306
+ if module.bias is not None:
307
+ nn.init.constant_(module.bias, 0)
308
+
309
+ self.apply(_basic_init)
310
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
311
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
312
+ w = self.x_embedder.proj.weight.data
313
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
314
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
315
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
316
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
317
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
318
+ for block in self.blocks:
319
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
320
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
321
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
322
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
323
+ nn.init.constant_(self.final_layer.linear.weight, 0)
324
+ nn.init.constant_(self.final_layer.linear.bias, 0)
325
+
326
+ def init_moe_mlp(module, std=0.006):
327
+ nn.init.normal_(module.up_proj.weight, std=std)
328
+ nn.init.normal_(module.down_proj.weight, std=std)
329
+
330
+ if self.init_MoeMLP:
331
+ for block in self.blocks:
332
+ if hasattr(block.mlp, "experts"):
333
+ for expert in block.mlp.experts:
334
+ init_moe_mlp(expert)
335
+
336
+ def unpatchify(self, x):
337
+ c = self.out_channels
338
+ p = self.x_embedder.patch_size[0]
339
+ h = w = int(x.shape[1] ** 0.5)
340
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
341
+ x = torch.einsum("nhwpqc->nchpwq", x)
342
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
343
+
344
+ def forward(self, x, timestep, context, **kwargs):
345
+ y = context
346
+ if len(x.shape) != 4:
347
+ x = x.squeeze(2)
348
+ x = self.x_embedder(x) + self.pos_embed
349
+ t = self.t_embedder(timestep)
350
+ y, labels = self.y_embedder(y, self.training)
351
+ c = t + y
352
+ for block in self.blocks:
353
+ x = block(x, c, labels)
354
+ x = self.final_layer(x, c)
355
+ return self.unpatchify(x)
ProMoE-XL-256/transformer/backbone_tcdit.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .modeling_promoe_common import (
8
+ Attention,
9
+ FinalLayer,
10
+ LabelEmbedder,
11
+ Mlp,
12
+ MoeMLP_DiffMoE as MoeMLP,
13
+ PatchEmbed,
14
+ TimestepEmbedder,
15
+ get_2d_sincos_pos_embed,
16
+ modulate,
17
+ )
18
+
19
+
20
+ class MoEGate(nn.Module):
21
+ def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
22
+ super().__init__()
23
+ self.top_k = num_experts_per_tok
24
+ self.n_routed_experts = num_experts
25
+ self.scoring_func = "softmax"
26
+ self.alpha = aux_loss_alpha
27
+ self.seq_aux = False
28
+ self.norm_topk_prob = False
29
+ self.gating_dim = embed_dim
30
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
31
+ self.reset_parameters()
32
+
33
+ def reset_parameters(self):
34
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
35
+
36
+ def forward(self, hidden_states):
37
+ bsz, seq_len, h = hidden_states.shape
38
+ hidden_states = hidden_states.view(-1, h)
39
+ logits = F.linear(hidden_states, self.weight, None)
40
+ if self.scoring_func != "softmax":
41
+ raise NotImplementedError(f"Unsupported gating scoring function: {self.scoring_func}")
42
+ scores = logits.softmax(dim=-1)
43
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
44
+ if self.top_k > 1 and self.norm_topk_prob:
45
+ topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
46
+
47
+ if self.training and self.alpha > 0.0:
48
+ scores_for_aux = scores
49
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
50
+ if self.seq_aux:
51
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
52
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
53
+ ce.scatter_add_(
54
+ 1,
55
+ topk_idx_for_aux_loss,
56
+ torch.ones(bsz, seq_len * self.top_k, device=hidden_states.device),
57
+ ).div_(seq_len * self.top_k / self.n_routed_experts)
58
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
59
+ else:
60
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
61
+ ce = mask_ce.float().mean(0)
62
+ pi = scores_for_aux.mean(0)
63
+ fi = ce * self.n_routed_experts
64
+ aux_loss = (pi * fi).sum() * self.alpha
65
+ else:
66
+ aux_loss = None
67
+ return topk_idx, topk_weight, aux_loss
68
+
69
+
70
+ class AddAuxiliaryLoss(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, x, loss):
73
+ ctx.dtype = loss.dtype
74
+ ctx.required_aux_loss = loss.requires_grad
75
+ return x
76
+
77
+ @staticmethod
78
+ def backward(ctx, grad_output):
79
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None
80
+ return grad_output, grad_loss
81
+
82
+
83
+ class SparseMoEBlock(nn.Module):
84
+ def __init__(
85
+ self,
86
+ experts,
87
+ hidden_dim,
88
+ mlp_ratio=4,
89
+ num_experts=16,
90
+ num_experts_per_tok=2,
91
+ pretraining_tp=2,
92
+ n_shared_experts=2,
93
+ ):
94
+ super().__init__()
95
+ self.top_k = num_experts_per_tok
96
+ self.experts = nn.ModuleList(experts)
97
+ self.gate = MoEGate(embed_dim=hidden_dim, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)
98
+ self.n_shared_experts = n_shared_experts
99
+ if self.n_shared_experts > 0:
100
+ intermediate_size = hidden_dim * self.n_shared_experts
101
+ self.shared_experts = MoeMLP(
102
+ hidden_size=hidden_dim,
103
+ intermediate_size=intermediate_size,
104
+ pretraining_tp=pretraining_tp,
105
+ )
106
+
107
+ def forward(self, hidden_states):
108
+ identity = hidden_states
109
+ orig_shape = hidden_states.shape
110
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
111
+
112
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
113
+ flat_topk_idx = topk_idx.view(-1)
114
+ if self.training:
115
+ hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
116
+ y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
117
+ for i, expert in enumerate(self.experts):
118
+ y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]).float()
119
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
120
+ y = y.view(*orig_shape)
121
+ y = AddAuxiliaryLoss.apply(y, aux_loss)
122
+ else:
123
+ y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
124
+ if self.n_shared_experts > 0:
125
+ y = y + self.shared_experts(identity)
126
+ return y
127
+
128
+ @torch.no_grad()
129
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
130
+ expert_cache = torch.zeros_like(x)
131
+ idxs = flat_expert_indices.argsort()
132
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
133
+ token_idxs = idxs // self.top_k
134
+ for i, end_idx in enumerate(tokens_per_expert):
135
+ start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
136
+ if start_idx == end_idx:
137
+ continue
138
+ expert = self.experts[i]
139
+ exp_token_idx = token_idxs[start_idx:end_idx]
140
+ expert_tokens = x[exp_token_idx]
141
+ expert_out = expert(expert_tokens)
142
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
143
+ expert_cache = expert_cache.to(expert_out.dtype)
144
+ expert_cache.scatter_reduce_(
145
+ 0,
146
+ exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
147
+ expert_out,
148
+ reduce="sum",
149
+ )
150
+ return expert_cache
151
+
152
+
153
+ class DiTBlock(nn.Module):
154
+ def __init__(
155
+ self,
156
+ hidden_size,
157
+ num_heads,
158
+ mlp_ratio=4,
159
+ pretraining_tp=2,
160
+ use_swiglu=False,
161
+ MoE_config=None,
162
+ use_moe=True,
163
+ **block_kwargs,
164
+ ):
165
+ super().__init__()
166
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
167
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
168
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
169
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
170
+ self.use_moe = use_moe
171
+ if use_moe:
172
+ if not use_swiglu:
173
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
174
+ experts = [
175
+ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
176
+ for _ in range(MoE_config.num_experts)
177
+ ]
178
+ else:
179
+ experts = [
180
+ MoeMLP(
181
+ hidden_size=hidden_size,
182
+ intermediate_size=mlp_hidden_dim,
183
+ pretraining_tp=pretraining_tp,
184
+ )
185
+ for _ in range(MoE_config.num_experts)
186
+ ]
187
+ self.mlp = SparseMoEBlock(
188
+ experts=experts,
189
+ hidden_dim=hidden_size,
190
+ num_experts=MoE_config.num_experts,
191
+ num_experts_per_tok=MoE_config.capacity,
192
+ n_shared_experts=MoE_config.n_shared_experts,
193
+ )
194
+ else:
195
+ if not use_swiglu:
196
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
197
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
198
+ else:
199
+ self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim)
200
+
201
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
202
+
203
+ def forward(self, x, c):
204
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
205
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
206
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
207
+ return x
208
+
209
+
210
+ class DiT(nn.Module):
211
+ def __init__(
212
+ self,
213
+ input_size=32,
214
+ patch_size=2,
215
+ in_channels=4,
216
+ hidden_size=1152,
217
+ depth=28,
218
+ num_heads=16,
219
+ mlp_ratio=4,
220
+ qk_norm=False,
221
+ class_dropout_prob=0.1,
222
+ num_classes=1000,
223
+ pretraining_tp=1,
224
+ learn_sigma=True,
225
+ use_swiglu=False,
226
+ MoE_config=None,
227
+ ):
228
+ super().__init__()
229
+ self.learn_sigma = learn_sigma
230
+ self.in_channels = in_channels
231
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
232
+ self.patch_size = patch_size
233
+ self.num_heads = num_heads
234
+ self.MoE_config = MoE_config
235
+ use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth
236
+
237
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
238
+ self.t_embedder = TimestepEmbedder(hidden_size)
239
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
240
+ num_patches = self.x_embedder.num_patches
241
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
242
+
243
+ self.blocks = nn.ModuleList(
244
+ [
245
+ DiTBlock(
246
+ hidden_size,
247
+ num_heads,
248
+ mlp_ratio=mlp_ratio,
249
+ qk_norm=qk_norm,
250
+ use_swiglu=use_swiglu,
251
+ pretraining_tp=pretraining_tp,
252
+ MoE_config=MoE_config,
253
+ use_moe=use_moe_flag[i],
254
+ )
255
+ for i in range(depth)
256
+ ]
257
+ )
258
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
259
+ self.initialize_weights()
260
+
261
+ def initialize_weights(self):
262
+ def _basic_init(module):
263
+ if isinstance(module, nn.Linear):
264
+ torch.nn.init.xavier_uniform_(module.weight)
265
+ if module.bias is not None:
266
+ nn.init.constant_(module.bias, 0)
267
+
268
+ self.apply(_basic_init)
269
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
270
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
271
+ w = self.x_embedder.proj.weight.data
272
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
273
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
274
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
275
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
276
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
277
+ for block in self.blocks:
278
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
279
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
280
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
281
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
282
+ nn.init.constant_(self.final_layer.linear.weight, 0)
283
+ nn.init.constant_(self.final_layer.linear.bias, 0)
284
+
285
+ def unpatchify(self, x):
286
+ c = self.out_channels
287
+ p = self.x_embedder.patch_size[0]
288
+ h = w = int(x.shape[1] ** 0.5)
289
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
290
+ x = torch.einsum("nhwpqc->nchpwq", x)
291
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
292
+
293
+ def forward(self, x, t, context, **kwargs):
294
+ y = context
295
+ if len(x.shape) != 4:
296
+ x = x.squeeze(2)
297
+ x = self.x_embedder(x) + self.pos_embed
298
+ t = self.t_embedder(t)
299
+ y = self.y_embedder(y, self.training)
300
+ c = t + y
301
+ for block in self.blocks:
302
+ x = block(x, c)
303
+ x = self.final_layer(x, c)
304
+ return self.unpatchify(x)
ProMoE-XL-256/transformer/config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ProMoETransformer2DModel",
3
+ "architecture": "promoe_tc",
4
+ "model_config": {
5
+ "MoE_config": {
6
+ "init_MoeMLP": false,
7
+ "interleave": true,
8
+ "moe_intermediate_size": 2304,
9
+ "num_routed_experts": 12,
10
+ "shared_expert_intermediate_size": 2304,
11
+ "top_k": 1,
12
+ "use_shared_expert": true,
13
+ "use_uncond_expert": true
14
+ },
15
+ "depth": 28,
16
+ "hidden_size": 1152,
17
+ "input_size": 32,
18
+ "num_classes": 1000,
19
+ "num_heads": 16,
20
+ "patch_size": 2
21
+ }
22
+ }
ProMoE-XL-256/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c02abf475a881cdd445984c14642463ec2679aa8f28967ac20761da09f105580
3
+ size 6271058696
ProMoE-XL-256/transformer/modeling_promoe_common.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import math
3
+ from dataclasses import dataclass
4
+ from itertools import repeat
5
+ from typing import Any, Dict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def _ntuple(n):
14
+ def parse(x):
15
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
16
+ return tuple(x)
17
+ return tuple(repeat(x, n))
18
+
19
+ return parse
20
+
21
+
22
+ to_2tuple = _ntuple(2)
23
+
24
+
25
+ class AttrDict(dict):
26
+ def __getattr__(self, item):
27
+ try:
28
+ return self[item]
29
+ except KeyError as error:
30
+ raise AttributeError(item) from error
31
+
32
+ def __setattr__(self, key, value):
33
+ self[key] = value
34
+
35
+ @staticmethod
36
+ def from_data(data: Any) -> Any:
37
+ if isinstance(data, dict):
38
+ return AttrDict({k: AttrDict.from_data(v) for k, v in data.items()})
39
+ if isinstance(data, list):
40
+ return [AttrDict.from_data(v) for v in data]
41
+ return data
42
+
43
+
44
+ class PatchEmbed(nn.Module):
45
+ def __init__(self, input_size: int, patch_size: int, in_channels: int, embed_dim: int, bias: bool = True):
46
+ super().__init__()
47
+ self.img_size = to_2tuple(input_size)
48
+ self.patch_size = to_2tuple(patch_size)
49
+ self.grid_size = (
50
+ self.img_size[0] // self.patch_size[0],
51
+ self.img_size[1] // self.patch_size[1],
52
+ )
53
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
54
+ self.proj = nn.Conv2d(
55
+ in_channels,
56
+ embed_dim,
57
+ kernel_size=self.patch_size,
58
+ stride=self.patch_size,
59
+ bias=bias,
60
+ )
61
+
62
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
63
+ hidden_states = self.proj(hidden_states)
64
+ return hidden_states.flatten(2).transpose(1, 2)
65
+
66
+
67
+ class Mlp(nn.Module):
68
+ def __init__(
69
+ self,
70
+ in_features,
71
+ hidden_features=None,
72
+ out_features=None,
73
+ act_layer=nn.GELU,
74
+ norm_layer=None,
75
+ bias=True,
76
+ drop=0.0,
77
+ ):
78
+ super().__init__()
79
+ out_features = out_features or in_features
80
+ hidden_features = hidden_features or in_features
81
+ bias = to_2tuple(bias)
82
+ drop_probs = to_2tuple(drop)
83
+
84
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
85
+ self.act = act_layer()
86
+ self.drop1 = nn.Dropout(drop_probs[0])
87
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
88
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
89
+ self.drop2 = nn.Dropout(drop_probs[1])
90
+
91
+ def forward(self, x):
92
+ x = self.fc1(x)
93
+ x = self.act(x)
94
+ x = self.drop1(x)
95
+ x = self.norm(x)
96
+ x = self.fc2(x)
97
+ x = self.drop2(x)
98
+ return x
99
+
100
+
101
+ class MoeMLP(nn.Module):
102
+ def __init__(self, hidden_size, intermediate_size):
103
+ super().__init__()
104
+ self.hidden_size = hidden_size
105
+ self.intermediate_size = intermediate_size
106
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size)
107
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
108
+ self.act_fn = nn.GELU(approximate="tanh")
109
+
110
+ def forward(self, x):
111
+ return self.down_proj(self.act_fn(self.up_proj(x)))
112
+
113
+
114
+ class MoeMLP_DiffMoE(nn.Module):
115
+ def __init__(self, hidden_size, intermediate_size, pretraining_tp=2):
116
+ super().__init__()
117
+ self.hidden_size = hidden_size
118
+ self.intermediate_size = intermediate_size
119
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
120
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
121
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
122
+ self.act_fn = nn.SiLU()
123
+ self.pretraining_tp = pretraining_tp
124
+
125
+ def forward(self, x):
126
+ if self.pretraining_tp > 1:
127
+ split_size = self.intermediate_size // self.pretraining_tp
128
+ gate_proj_slices = self.gate_proj.weight.split(split_size, dim=0)
129
+ up_proj_slices = self.up_proj.weight.split(split_size, dim=0)
130
+ down_proj_slices = self.down_proj.weight.split(split_size, dim=1)
131
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
132
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
133
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(split_size, dim=-1)
134
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
135
+ return sum(down_proj)
136
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
137
+
138
+
139
+ class Attention(nn.Module):
140
+ def __init__(
141
+ self,
142
+ dim: int,
143
+ num_heads: int = 8,
144
+ qkv_bias: bool = False,
145
+ qk_norm: bool = False,
146
+ attn_drop: float = 0.0,
147
+ proj_drop: float = 0.0,
148
+ head_dim=None,
149
+ norm_layer: nn.Module = nn.LayerNorm,
150
+ ):
151
+ super().__init__()
152
+ self.num_heads = num_heads
153
+ if head_dim is None:
154
+ if dim % num_heads != 0:
155
+ raise ValueError("dim must be divisible by num_heads")
156
+ self.head_dim = dim // num_heads
157
+ else:
158
+ self.head_dim = head_dim
159
+ self.scale = self.head_dim**-0.5
160
+ self.fused_attn = True
161
+ self.qkv = nn.Linear(dim, self.head_dim * self.num_heads * 3, bias=qkv_bias)
162
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
163
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
164
+ self.attn_drop = nn.Dropout(attn_drop)
165
+ self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
166
+ self.proj_drop = nn.Dropout(proj_drop)
167
+
168
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
169
+ batch_size, seq_len, _ = x.shape
170
+ qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
171
+ q, k, v = qkv.unbind(0)
172
+ q, k = self.q_norm(q), self.k_norm(k)
173
+
174
+ if self.fused_attn:
175
+ x = F.scaled_dot_product_attention(
176
+ q,
177
+ k,
178
+ v,
179
+ dropout_p=self.attn_drop.p if self.training else 0.0,
180
+ )
181
+ else:
182
+ q = q * self.scale
183
+ attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
184
+ attn = self.attn_drop(attn)
185
+ x = attn @ v
186
+
187
+ x = x.transpose(1, 2).reshape(batch_size, seq_len, -1)
188
+ x = self.proj(x)
189
+ return self.proj_drop(x)
190
+
191
+
192
+ def modulate(x, shift, scale):
193
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
194
+
195
+
196
+ class TimestepEmbedder(nn.Module):
197
+ def __init__(self, hidden_size, frequency_embedding_size=256):
198
+ super().__init__()
199
+ self.mlp = nn.Sequential(
200
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
201
+ nn.SiLU(),
202
+ nn.Linear(hidden_size, hidden_size, bias=True),
203
+ )
204
+ self.frequency_embedding_size = frequency_embedding_size
205
+
206
+ @staticmethod
207
+ def timestep_embedding(t, dim, max_period=10000):
208
+ half = dim // 2
209
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
210
+ device=t.device
211
+ )
212
+ args = t[:, None].float() * freqs[None]
213
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
214
+ if dim % 2:
215
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
216
+ return embedding
217
+
218
+ def forward(self, t):
219
+ t_freq = self.timestep_embedding(t.float(), self.frequency_embedding_size)
220
+ weight_dtype = self.mlp[0].weight.dtype
221
+ return self.mlp(t_freq.to(dtype=weight_dtype))
222
+
223
+
224
+ class LabelEmbedder(nn.Module):
225
+ def __init__(self, num_classes, hidden_size, dropout_prob, return_labels=False):
226
+ super().__init__()
227
+ use_cfg_embedding = dropout_prob > 0
228
+ self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size)
229
+ self.num_classes = num_classes
230
+ self.dropout_prob = dropout_prob
231
+ self.return_labels = return_labels
232
+
233
+ def token_drop(self, labels, force_drop_ids=None):
234
+ if force_drop_ids is None:
235
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
236
+ else:
237
+ drop_ids = force_drop_ids == 1
238
+ return torch.where(drop_ids, self.num_classes, labels)
239
+
240
+ def forward(self, labels, train, force_drop_ids=None):
241
+ if (train and self.dropout_prob > 0) or (force_drop_ids is not None):
242
+ labels = self.token_drop(labels, force_drop_ids)
243
+ embeddings = self.embedding_table(labels)
244
+ if self.return_labels:
245
+ return embeddings, labels
246
+ return embeddings
247
+
248
+
249
+ class FinalLayer(nn.Module):
250
+ def __init__(self, hidden_size, patch_size, out_channels):
251
+ super().__init__()
252
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
253
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
254
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
255
+
256
+ def forward(self, x, c):
257
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
258
+ x = modulate(self.norm_final(x), shift, scale)
259
+ return self.linear(x)
260
+
261
+
262
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
263
+ grid_h = np.arange(grid_size, dtype=np.float32)
264
+ grid_w = np.arange(grid_size, dtype=np.float32)
265
+ grid = np.meshgrid(grid_w, grid_h)
266
+ grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
267
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
268
+ if cls_token and extra_tokens > 0:
269
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
270
+ return pos_embed
271
+
272
+
273
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
274
+ if embed_dim % 2 != 0:
275
+ raise ValueError("embed_dim must be even")
276
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
277
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
278
+ return np.concatenate([emb_h, emb_w], axis=1)
279
+
280
+
281
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
282
+ if embed_dim % 2 != 0:
283
+ raise ValueError("embed_dim must be even")
284
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
285
+ omega /= embed_dim / 2.0
286
+ omega = 1.0 / 10000**omega
287
+ pos = pos.reshape(-1)
288
+ out = np.einsum("m,d->md", pos, omega)
289
+ emb_sin = np.sin(out)
290
+ emb_cos = np.cos(out)
291
+ return np.concatenate([emb_sin, emb_cos], axis=1)
ProMoE-XL-256/transformer/transformer_promoe.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ try:
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.utils import BaseOutput
11
+ except Exception: # pragma: no cover
12
+ class BaseOutput(dict):
13
+ def __post_init__(self):
14
+ self.update(self.__dict__)
15
+
16
+ class _Config(dict):
17
+ def __getattr__(self, key):
18
+ try:
19
+ return self[key]
20
+ except KeyError as error:
21
+ raise AttributeError(key) from error
22
+
23
+ class ConfigMixin:
24
+ config_name = "config.json"
25
+
26
+ class ModelMixin(nn.Module):
27
+ pass
28
+
29
+ def register_to_config(init):
30
+ def wrapper(self, *args, **kwargs):
31
+ import inspect
32
+
33
+ signature = inspect.signature(init)
34
+ bound = signature.bind(self, *args, **kwargs)
35
+ bound.apply_defaults()
36
+ self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"})
37
+ init(self, *args, **kwargs)
38
+
39
+ return wrapper
40
+
41
+ from .backbone_diffmoe import DiT as DiffMoEBackbone
42
+ from .backbone_dit import DiT as DiTBackbone
43
+ from .backbone_ecdit import DiT as ECDiTBackbone
44
+ from .backbone_promoe_ec import DiT as ProMoEECBackbone
45
+ from .backbone_promoe_tc import DiT as ProMoETCBackbone
46
+ from .backbone_tcdit import DiT as TCDiTBackbone
47
+ from .modeling_promoe_common import AttrDict
48
+
49
+
50
+ @dataclass
51
+ class ProMoETransformer2DModelOutput(BaseOutput):
52
+ sample: torch.FloatTensor
53
+ loss_strategy: Optional[str] = None
54
+ layer_idx_list: Optional[Tuple[int, ...]] = None
55
+ ones_list: Optional[Tuple[torch.FloatTensor, ...]] = None
56
+ pred_c_list: Optional[Tuple[torch.FloatTensor, ...]] = None
57
+ capacity_pred_loss_weight: Optional[float] = None
58
+
59
+
60
+ _BACKBONES = {
61
+ "dit": DiTBackbone,
62
+ "tcdit": TCDiTBackbone,
63
+ "ecdit": ECDiTBackbone,
64
+ "diffmoe": DiffMoEBackbone,
65
+ "promoe_tc": ProMoETCBackbone,
66
+ "promoe_ec": ProMoEECBackbone,
67
+ }
68
+
69
+
70
+ class ProMoETransformer2DModel(ModelMixin, ConfigMixin):
71
+ config_name = "config.json"
72
+
73
+ @register_to_config
74
+ def __init__(self, architecture: str = "promoe_tc", model_config: Optional[Dict[str, Any]] = None):
75
+ super().__init__()
76
+ if architecture not in _BACKBONES:
77
+ raise ValueError(f"Unsupported architecture: {architecture}. Valid: {sorted(_BACKBONES)}")
78
+ model_config = model_config or {}
79
+ self.architecture = architecture
80
+ self.model_config = model_config
81
+ self.backbone = _BACKBONES[architecture](**self._prepare_config(model_config))
82
+ self.in_channels = getattr(self.backbone, "in_channels", model_config.get("in_channels", 4))
83
+ self.out_channels = getattr(self.backbone, "out_channels", model_config.get("in_channels", 4))
84
+
85
+ def _prepare_config(self, model_config: Dict[str, Any]) -> Dict[str, Any]:
86
+ prepared = {}
87
+ for key, value in model_config.items():
88
+ prepared[key] = AttrDict.from_data(value)
89
+ return prepared
90
+
91
+ def forward(
92
+ self,
93
+ hidden_states: torch.Tensor,
94
+ timestep: Union[torch.Tensor, float, int],
95
+ class_labels: Optional[torch.LongTensor] = None,
96
+ context: Optional[torch.LongTensor] = None,
97
+ return_dict: bool = True,
98
+ **kwargs,
99
+ ) -> Union[ProMoETransformer2DModelOutput, Tuple[torch.Tensor, ...]]:
100
+ labels = class_labels if class_labels is not None else context
101
+ if labels is None:
102
+ raise ValueError("Either `class_labels` or `context` must be provided.")
103
+
104
+ if not torch.is_tensor(timestep):
105
+ timestep = torch.tensor([timestep], device=hidden_states.device, dtype=hidden_states.dtype)
106
+ timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype).flatten()
107
+ if timestep.numel() == 1:
108
+ timestep = timestep.repeat(labels.shape[0])
109
+
110
+ sample = self.backbone(hidden_states, timestep, labels, **kwargs)
111
+ if isinstance(sample, tuple):
112
+ if len(sample) == 6 and sample[1] == "Capacity_Pred":
113
+ output = ProMoETransformer2DModelOutput(
114
+ sample=sample[0],
115
+ loss_strategy=sample[1],
116
+ layer_idx_list=tuple(sample[2]),
117
+ ones_list=tuple(sample[3]),
118
+ pred_c_list=tuple(sample[4]),
119
+ capacity_pred_loss_weight=float(sample[5]),
120
+ )
121
+ else:
122
+ output = ProMoETransformer2DModelOutput(sample=sample[0])
123
+ else:
124
+ output = ProMoETransformer2DModelOutput(sample=sample)
125
+
126
+ if not return_dict:
127
+ if output.loss_strategy is None:
128
+ return (output.sample,)
129
+ return (
130
+ output.sample,
131
+ output.loss_strategy,
132
+ output.layer_idx_list,
133
+ output.ones_list,
134
+ output.pred_c_list,
135
+ output.capacity_pred_loss_weight,
136
+ )
137
+ return output