BiliSakura commited on
Commit
1071e0d
·
verified ·
1 Parent(s): 2cb3e63

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -35,4 +35,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  demo.png filter=lfs diff=lfs merge=lfs -text
37
  demo_images/jit_h32_test_inference.png filter=lfs diff=lfs merge=lfs -text
38
- demo_images/jit_h32_final_test.png filter=lfs diff=lfs merge=lfs -text
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  demo.png filter=lfs diff=lfs merge=lfs -text
37
  demo_images/jit_h32_test_inference.png filter=lfs diff=lfs merge=lfs -text
 
JiT-B-16/model_index.json CHANGED
@@ -5,11 +5,1013 @@
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
- "scheduling_jit",
9
- "JiTScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
 
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchHeunDiscreteScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
  }
JiT-B-16/pipeline.py CHANGED
@@ -12,8 +12,6 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from __future__ import annotations
16
-
17
  import importlib
18
  import json
19
  import sys
@@ -23,6 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
23
  import torch
24
 
25
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
 
@@ -39,12 +38,10 @@ class JiTPipeline(DiffusionPipeline):
39
  Parameters:
40
  transformer ([`JiTTransformer2DModel`]):
41
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
- scheduler ([`JiTScheduler`]):
43
- Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
  id2label (`dict[int, str]`, *optional*):
45
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
- id2label_cn (`dict[int, str]`, *optional*):
47
- ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
  """
49
 
50
  model_cpu_offload_seq = "transformer"
@@ -71,7 +68,7 @@ class JiTPipeline(DiffusionPipeline):
71
 
72
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
73
  if subfolder:
74
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
75
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
76
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
77
  else:
@@ -82,6 +79,7 @@ class JiTPipeline(DiffusionPipeline):
82
  if subfolder:
83
  variant = variant / subfolder
84
 
 
85
  model_kwargs = dict(kwargs)
86
  inserted: List[str] = []
87
 
@@ -103,19 +101,22 @@ class JiTPipeline(DiffusionPipeline):
103
 
104
  try:
105
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
106
- scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
 
 
 
107
 
108
  if transformer is None:
109
  raise ValueError(f"No loadable transformer found under {variant}")
110
 
111
  variant_path = str(variant)
112
- id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
 
113
 
114
  pipe = cls(
115
  transformer=transformer,
116
  scheduler=scheduler,
117
  id2label=id2label,
118
- id2label_cn=id2label_cn,
119
  )
120
  if variant_path and hasattr(pipe, "register_to_config"):
121
  pipe.register_to_config(_name_or_path=variant_path)
@@ -128,58 +129,31 @@ class JiTPipeline(DiffusionPipeline):
128
  def __init__(
129
  self,
130
  transformer,
131
- scheduler,
132
- id2label: Optional[Dict[int, str]] = None,
133
- id2label_cn: Optional[Dict[int, str]] = None,
134
  ):
135
  super().__init__()
 
136
  self.register_modules(transformer=transformer, scheduler=scheduler)
137
 
138
- self._id2label = id2label or {}
139
- self._id2label_cn = id2label_cn or {}
140
  self.labels = self._build_label2id(self._id2label)
141
- self.labels_cn = self._build_label2id(self._id2label_cn)
142
-
143
- def _ensure_labels_loaded(self) -> None:
144
- if self._id2label or self._id2label_cn:
145
- return
146
- loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
- if loaded_en:
148
- self._id2label = loaded_en
149
- self.labels = self._build_label2id(self._id2label)
150
- if loaded_cn:
151
- self._id2label_cn = loaded_cn
152
- self.labels_cn = self._build_label2id(self._id2label_cn)
153
 
154
  @staticmethod
155
- def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
- if not variant_path:
157
- return None
158
- variant_dir = Path(variant_path).resolve()
159
- labels_dir = variant_dir.parent / "labels"
160
- return labels_dir if labels_dir.is_dir() else None
161
 
162
  @staticmethod
163
- def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
- filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
- path = labels_dir / filename
166
- if not path.exists():
167
- raise FileNotFoundError(path)
168
- raw = json.loads(path.read_text(encoding="utf-8"))
169
- return {int(key): value for key, value in raw.items()}
170
-
171
- @classmethod
172
- def _load_labels_for_variant(
173
- cls,
174
- variant_path: Optional[str],
175
- ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
- labels_dir = cls._labels_dir_for_variant(variant_path)
177
- if labels_dir is None:
178
- return None, None
179
- try:
180
- return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
- except FileNotFoundError:
182
- return None, None
183
 
184
  @staticmethod
185
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
@@ -194,35 +168,19 @@ class JiTPipeline(DiffusionPipeline):
194
  @property
195
  def id2label(self) -> Dict[int, str]:
196
  """ImageNet class id to English label string (comma-separated synonyms)."""
197
- self._ensure_labels_loaded()
198
  return self._id2label
199
 
200
- @property
201
- def id2label_cn(self) -> Dict[int, str]:
202
- """ImageNet class id to Chinese label string (comma-separated synonyms)."""
203
- self._ensure_labels_loaded()
204
- return self._id2label_cn
205
-
206
- def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
207
  r"""
208
  Map ImageNet label strings to class ids.
209
 
210
  Args:
211
  label (`str` or `list[str]`):
212
- One or more label strings. Each string must match a synonym in `id2label` (English)
213
- or `id2label_cn` (Chinese).
214
- lang (`str`, *optional*, defaults to `"en"`):
215
- `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
216
  """
217
- if lang not in ("en", "cn"):
218
- raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
219
-
220
- self._ensure_labels_loaded()
221
- label2id = self.labels if lang == "en" else self.labels_cn
222
  if not label2id:
223
- raise ValueError(
224
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
225
- )
226
 
227
  if isinstance(label, str):
228
  label = [label]
@@ -231,7 +189,7 @@ class JiTPipeline(DiffusionPipeline):
231
  if missing:
232
  preview = ", ".join(list(label2id.keys())[:8])
233
  raise ValueError(
234
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
235
  )
236
  return [label2id[item] for item in label]
237
 
@@ -246,115 +204,10 @@ class JiTPipeline(DiffusionPipeline):
246
  return self.get_label_ids(class_labels)
247
 
248
  if class_labels and isinstance(class_labels[0], str):
249
- self._ensure_labels_loaded()
250
- if all(label in self.labels for label in class_labels):
251
- return self.get_label_ids(class_labels, lang="en")
252
- if all(label in self.labels_cn for label in class_labels):
253
- return self.get_label_ids(class_labels, lang="cn")
254
- raise ValueError(
255
- "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
- "or Chinese synonyms from `pipe.labels_cn`."
257
- )
258
 
259
  return list(class_labels)
260
 
261
- def _predict_velocity(
262
- self,
263
- z_value: torch.Tensor,
264
- t: torch.Tensor,
265
- class_labels: torch.Tensor,
266
- class_null: torch.Tensor,
267
- do_classifier_free_guidance: bool,
268
- guidance_scale: float,
269
- guidance_interval_min: float,
270
- guidance_interval_max: float,
271
- ) -> torch.Tensor:
272
- t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
- if do_classifier_free_guidance:
274
- z_in = torch.cat([z_value, z_value], dim=0)
275
- labels = torch.cat([class_labels, class_null], dim=0)
276
- else:
277
- z_in = z_value
278
- labels = class_labels
279
-
280
- t_batch = t.flatten().expand(z_in.shape[0])
281
- x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
- v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
-
284
- if not do_classifier_free_guidance:
285
- return v
286
-
287
- v_cond, v_uncond = v.chunk(2, dim=0)
288
- interval_mask = t < guidance_interval_max
289
- if guidance_interval_min != 0.0:
290
- interval_mask = interval_mask & (t > guidance_interval_min)
291
- scale = torch.where(
292
- interval_mask,
293
- torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
- torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
- )
296
- return v_uncond + scale * (v_cond - v_uncond)
297
-
298
- def _run_sampler(
299
- self,
300
- latents: torch.Tensor,
301
- class_labels: torch.Tensor,
302
- class_null: torch.Tensor,
303
- num_inference_steps: int,
304
- do_classifier_free_guidance: bool,
305
- guidance_scale: float,
306
- guidance_interval_min: float,
307
- guidance_interval_max: float,
308
- sampling_method: str,
309
- ) -> torch.Tensor:
310
- device = latents.device
311
- self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
312
- timesteps = self.scheduler.timesteps
313
-
314
- for i in self.progress_bar(range(num_inference_steps - 1)):
315
- t = timesteps[i]
316
- t_next = timesteps[i + 1]
317
- v = self._predict_velocity(
318
- latents,
319
- t,
320
- class_labels,
321
- class_null,
322
- do_classifier_free_guidance,
323
- guidance_scale,
324
- guidance_interval_min,
325
- guidance_interval_max,
326
- )
327
-
328
- if sampling_method == "heun":
329
- latents_euler = latents + (t_next - t) * v
330
- v_next = self._predict_velocity(
331
- latents_euler,
332
- t_next,
333
- class_labels,
334
- class_null,
335
- do_classifier_free_guidance,
336
- guidance_scale,
337
- guidance_interval_min,
338
- guidance_interval_max,
339
- )
340
- latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
341
- else:
342
- latents = self.scheduler.step(v, t, latents).prev_sample
343
-
344
- t = timesteps[-2]
345
- t_next = timesteps[-1]
346
- v = self._predict_velocity(
347
- latents,
348
- t,
349
- class_labels,
350
- class_null,
351
- do_classifier_free_guidance,
352
- guidance_scale,
353
- guidance_interval_min,
354
- guidance_interval_max,
355
- )
356
- return latents + (t_next - t) * v
357
-
358
  @torch.inference_mode()
359
  def __call__(
360
  self,
@@ -363,10 +216,12 @@ class JiTPipeline(DiffusionPipeline):
363
  guidance_interval_min: float = 0.1,
364
  guidance_interval_max: float = 1.0,
365
  noise_scale: Optional[float] = None,
366
- t_eps: Optional[float] = None,
367
- sampling_method: Optional[str] = None,
368
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
369
  num_inference_steps: int = 50,
 
 
 
370
  output_type: Optional[str] = "pil",
371
  return_dict: bool = True,
372
  ) -> Union[ImagePipelineOutput, Tuple]:
@@ -375,7 +230,7 @@ class JiTPipeline(DiffusionPipeline):
375
 
376
  Args:
377
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
378
- ImageNet class indices or human-readable label strings (English or Chinese).
379
  guidance_scale (`float`, *optional*):
380
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
381
  guidance_interval_min (`float`, defaults to `0.1`):
@@ -384,10 +239,8 @@ class JiTPipeline(DiffusionPipeline):
384
  Upper bound of the CFG interval in flow time.
385
  noise_scale (`float`, *optional*):
386
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
387
- t_eps (`float`, *optional*):
388
- Epsilon clamp for the `1 - t` denominator (scheduler config by default).
389
- sampling_method (`str`, *optional*):
390
- `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
391
  generator (`torch.Generator`, *optional*):
392
  RNG for reproducibility.
393
  num_inference_steps (`int`, defaults to `50`):
@@ -397,31 +250,34 @@ class JiTPipeline(DiffusionPipeline):
397
  return_dict (`bool`, *optional*, defaults to `True`):
398
  Return [`ImagePipelineOutput`] if True.
399
  """
400
- solver = sampling_method or self.scheduler.config.solver
401
- if solver not in {"heun", "euler"}:
402
- raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
403
  if num_inference_steps < 2:
404
  raise ValueError("num_inference_steps must be >= 2.")
405
 
406
- if t_eps is not None:
407
- self.scheduler.register_to_config(t_eps=t_eps)
408
-
409
  class_label_ids = self._normalize_class_labels(class_labels)
410
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
411
 
412
  batch_size = len(class_label_ids)
413
  image_size = int(self.transformer.config.sample_size)
 
 
 
 
 
 
 
 
 
414
  channels = int(self.transformer.config.in_channels)
415
  null_class_val = int(self.transformer.config.num_classes)
416
 
417
  if guidance_scale is None:
418
  guidance_scale = 1.0
419
  if noise_scale is None:
420
- noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
421
 
422
  latents = (
423
  randn_tensor(
424
- shape=(batch_size, channels, image_size, image_size),
425
  generator=generator,
426
  device=self._execution_device,
427
  dtype=self.transformer.dtype,
@@ -433,17 +289,47 @@ class JiTPipeline(DiffusionPipeline):
433
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
434
  class_null = torch.full_like(class_labels_t, null_class_val)
435
 
436
- latents = self._run_sampler(
437
- latents,
438
- class_labels_t,
439
- class_null,
440
- num_inference_steps,
441
- do_classifier_free_guidance,
442
- guidance_scale,
443
- guidance_interval_min,
444
- guidance_interval_max,
445
- solver,
446
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
449
  if output_type == "pt":
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import importlib
16
  import json
17
  import sys
 
21
  import torch
22
 
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
+ from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
 
 
38
  Parameters:
39
  transformer ([`JiTTransformer2DModel`]):
40
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
41
+ scheduler ([`KarrasDiffusionSchedulers`] or [`FlowMatchHeunDiscreteScheduler`]):
42
+ Diffusers scheduler interface for JiT generation (defaults to `FlowMatchHeunDiscreteScheduler(shift=4.0)`).
43
  id2label (`dict[int, str]`, *optional*):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
 
 
45
  """
46
 
47
  model_cpu_offload_seq = "transformer"
 
68
 
69
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
  if subfolder:
71
+ hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
  else:
 
79
  if subfolder:
80
  variant = variant / subfolder
81
 
82
+ id2label_override = kwargs.pop("id2label", None)
83
  model_kwargs = dict(kwargs)
84
  inserted: List[str] = []
85
 
 
101
 
102
  try:
103
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
+ try:
105
+ scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
+ except Exception:
107
+ scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
 
109
  if transformer is None:
110
  raise ValueError(f"No loadable transformer found under {variant}")
111
 
112
  variant_path = str(variant)
113
+ model_index_path = variant / "model_index.json"
114
+ id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
 
116
  pipe = cls(
117
  transformer=transformer,
118
  scheduler=scheduler,
119
  id2label=id2label,
 
120
  )
121
  if variant_path and hasattr(pipe, "register_to_config"):
122
  pipe.register_to_config(_name_or_path=variant_path)
 
129
  def __init__(
130
  self,
131
  transformer,
132
+ scheduler: FlowMatchHeunDiscreteScheduler,
133
+ id2label: Optional[Dict[Union[int, str], str]] = None,
 
134
  ):
135
  super().__init__()
136
+ scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
 
139
+ self._id2label = self._normalize_id2label(id2label)
 
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
144
+ if not id2label:
145
+ return {}
146
+ return {int(key): value for key, value in id2label.items()}
 
 
147
 
148
  @staticmethod
149
+ def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
150
+ if not model_index_path.exists():
151
+ return {}
152
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
153
+ id2label = raw.get("id2label")
154
+ if not isinstance(id2label, dict):
155
+ return {}
156
+ return {int(key): value for key, value in id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  @staticmethod
159
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
  """ImageNet class id to English label string (comma-separated synonyms)."""
 
171
  return self._id2label
172
 
173
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
 
 
 
 
 
 
174
  r"""
175
  Map ImageNet label strings to class ids.
176
 
177
  Args:
178
  label (`str` or `list[str]`):
179
+ One or more English label strings. Each string must match a synonym in `id2label`.
 
 
 
180
  """
181
+ label2id = self.labels
 
 
 
 
182
  if not label2id:
183
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
 
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
  raise ValueError(
192
+ f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
  )
194
  return [label2id[item] for item in label]
195
 
 
204
  return self.get_label_ids(class_labels)
205
 
206
  if class_labels and isinstance(class_labels[0], str):
207
+ return self.get_label_ids(class_labels)
 
 
 
 
 
 
 
 
208
 
209
  return list(class_labels)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  @torch.inference_mode()
212
  def __call__(
213
  self,
 
216
  guidance_interval_min: float = 0.1,
217
  guidance_interval_max: float = 1.0,
218
  noise_scale: Optional[float] = None,
219
+ t_eps: float = 5e-2,
 
220
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
221
  num_inference_steps: int = 50,
222
+ height: Optional[int] = None,
223
+ width: Optional[int] = None,
224
+ interpolate_pos_encoding: bool = True,
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
 
230
 
231
  Args:
232
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
+ ImageNet class indices or human-readable English label strings.
234
  guidance_scale (`float`, *optional*):
235
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
  guidance_interval_min (`float`, defaults to `0.1`):
 
239
  Upper bound of the CFG interval in flow time.
240
  noise_scale (`float`, *optional*):
241
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
+ t_eps (`float`, defaults to `5e-2`):
243
+ Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
 
 
244
  generator (`torch.Generator`, *optional*):
245
  RNG for reproducibility.
246
  num_inference_steps (`int`, defaults to `50`):
 
250
  return_dict (`bool`, *optional*, defaults to `True`):
251
  Return [`ImagePipelineOutput`] if True.
252
  """
 
 
 
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
255
 
 
 
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
258
 
259
  batch_size = len(class_label_ids)
260
  image_size = int(self.transformer.config.sample_size)
261
+ patch_size = int(self.transformer.config.patch_size)
262
+ height = int(height or image_size)
263
+ width = int(width or image_size)
264
+ if height <= 0 or width <= 0:
265
+ raise ValueError("height and width must be positive integers.")
266
+ if height % patch_size != 0 or width % patch_size != 0:
267
+ raise ValueError(
268
+ f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
+ )
270
  channels = int(self.transformer.config.in_channels)
271
  null_class_val = int(self.transformer.config.num_classes)
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
+ noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
  latents = (
279
  randn_tensor(
280
+ shape=(batch_size, channels, height, width),
281
  generator=generator,
282
  device=self._execution_device,
283
  dtype=self.transformer.dtype,
 
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
290
  class_null = torch.full_like(class_labels_t, null_class_val)
291
 
292
+ if do_classifier_free_guidance:
293
+ class_labels_input = torch.cat([class_labels_t, class_null], dim=0)
294
+ else:
295
+ class_labels_input = class_labels_t
296
+
297
+ self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
298
+ for t in self.progress_bar(self.scheduler.timesteps):
299
+ step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
+ sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
301
+ sigma = sigma.clamp_min(t_eps)
302
+ t_flow = (1.0 - sigma).clamp(0.0, 1.0)
303
+
304
+ if do_classifier_free_guidance:
305
+ latent_model_input = torch.cat([latents, latents], dim=0)
306
+ else:
307
+ latent_model_input = latents
308
+
309
+ timesteps = t_flow.flatten().expand(latent_model_input.shape[0])
310
+ x_pred = self.transformer(
311
+ latent_model_input,
312
+ timestep=timesteps,
313
+ class_labels=class_labels_input,
314
+ interpolate_pos_encoding=interpolate_pos_encoding,
315
+ ).sample
316
+
317
+ if do_classifier_free_guidance:
318
+ x_cond, x_uncond = x_pred.chunk(2, dim=0)
319
+ interval_mask = t_flow < guidance_interval_max
320
+ if guidance_interval_min != 0.0:
321
+ interval_mask = interval_mask & (t_flow > guidance_interval_min)
322
+ scale = torch.where(
323
+ interval_mask,
324
+ torch.tensor(guidance_scale, device=latents.device, dtype=latents.dtype),
325
+ torch.tensor(1.0, device=latents.device, dtype=latents.dtype),
326
+ )
327
+ x_pred = x_uncond + scale * (x_cond - x_uncond)
328
+
329
+ sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
+ # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
+ model_output = -(x_pred - latents) / sigma
332
+ latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
JiT-B-16/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,6 @@
1
  {
2
- "_class_name": "JiTScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
- "t_eps": 0.05,
6
- "solver": "heun"
7
  }
 
1
  {
2
+ "_class_name": "FlowMatchHeunDiscreteScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
+ "shift": 4.0
 
6
  }
JiT-B-16/transformer/jit_transformer_2d.py CHANGED
@@ -68,38 +68,58 @@ class JiTRotaryEmbedding(nn.Module):
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
- if custom_freqs is not None:
72
- freqs = custom_freqs
73
- else:
74
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
75
-
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
79
-
80
- freqs = torch.einsum("..., f -> ... f", t, freqs)
81
- freqs = freqs.repeat_interleave(2, dim=-1)
82
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
83
-
84
- if num_cls_token > 0:
85
- freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
86
- cos_img = freqs_flat.cos()
87
- sin_img = freqs_flat.sin()
88
-
89
- # prepend in-context cls token
90
- _, D = cos_img.shape
91
- cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
92
- sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
93
-
94
- self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
95
- self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
96
  else:
97
- self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
98
- self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
99
-
100
- def forward(self, t):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
102
  seq_len = t.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
103
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
104
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
105
 
@@ -195,7 +215,7 @@ class JiTAttention(nn.Module):
195
  self.proj = nn.Linear(dim, dim)
196
  self.proj_drop = nn.Dropout(proj_drop)
197
 
198
- def forward(self, x, rope=None):
199
  B, N, C = x.shape
200
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
201
  q, k, v = qkv[0], qkv[1], qkv[2]
@@ -206,8 +226,8 @@ class JiTAttention(nn.Module):
206
  if rope is not None:
207
  q = q.transpose(1, 2)
208
  k = k.transpose(1, 2)
209
- q = rope(q)
210
- k = rope(k)
211
  q = q.transpose(1, 2)
212
  k = k.transpose(1, 2)
213
 
@@ -254,7 +274,7 @@ class JiTBlock(nn.Module):
254
  self.act = nn.SiLU()
255
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
256
 
257
- def forward(self, x, c, feat_rope=None):
258
  # Apply activation
259
  c = self.act(c)
260
 
@@ -263,7 +283,7 @@ class JiTBlock(nn.Module):
263
  # Attention block
264
  norm_x = self.norm1(x)
265
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
266
- attn_out = self.attn(modulated_x, rope=feat_rope)
267
  x = x + gate_msa.unsqueeze(1) * attn_out
268
 
269
  # MLP block
@@ -437,11 +457,30 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
437
  self.act_final = nn.SiLU()
438
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  def forward(
441
  self,
442
  hidden_states: torch.Tensor,
443
  timestep: torch.LongTensor,
444
  class_labels: torch.LongTensor,
 
445
  return_dict: bool = True,
446
  ):
447
 
@@ -454,8 +493,19 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
454
  c = t_emb + y_emb
455
 
456
  # Patch Embed
 
457
  x = self.x_embedder(hidden_states)
458
- x = x + self.pos_embed.to(x.dtype)
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Blocks
461
  for i, block in enumerate(self.blocks):
@@ -467,15 +517,23 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
467
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
468
 
469
  if self.training and self.gradient_checkpointing:
 
 
 
 
 
 
 
 
 
470
  x = torch.utils.checkpoint.checkpoint(
471
- block,
472
  x,
473
  c,
474
- rope,
475
  use_reentrant=False,
476
  )
477
  else:
478
- x = block(x, c, feat_rope=rope)
479
 
480
  # Slice off in-context tokens
481
  if self.in_context_len > 0:
@@ -489,10 +547,11 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
489
  x = self.linear_final(x)
490
 
491
  # Unpatchify
492
- h = w = int(x.shape[1] ** 0.5)
493
- x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
494
  x = torch.einsum("nhwpqc->nchpwq", x)
495
- output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
 
 
496
 
497
  if not return_dict:
498
  return (output,)
 
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
+ self.dim = dim
72
+ self.pt_seq_len = pt_seq_len
73
+ self.theta = theta
74
+ self.num_cls_token = num_cls_token
75
+ self.custom_freqs = custom_freqs
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
+ self._cached_hw = None
79
+ cos, sin = self._build_freqs(ft_seq_len, ft_seq_len, device=torch.device("cpu"))
80
+ self.register_buffer("freqs_cos", cos, persistent=False)
81
+ self.register_buffer("freqs_sin", sin, persistent=False)
82
+ self._cached_hw = (ft_seq_len, ft_seq_len)
83
+
84
+ def _build_freqs(self, height, width, device):
85
+ if self.custom_freqs is not None:
86
+ freqs = self.custom_freqs.to(device=device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
87
  else:
88
+ freqs = 1.0 / (
89
+ self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)[: (self.dim // 2)] / self.dim)
90
+ )
91
+
92
+ t_h = torch.arange(height, device=device, dtype=torch.float32) / height * self.pt_seq_len
93
+ t_w = torch.arange(width, device=device, dtype=torch.float32) / width * self.pt_seq_len
94
+ freqs_h = torch.einsum("..., f -> ... f", t_h, freqs).repeat_interleave(2, dim=-1)
95
+ freqs_w = torch.einsum("..., f -> ... f", t_w, freqs).repeat_interleave(2, dim=-1)
96
+ freqs_2d = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
97
+ freqs_flat = freqs_2d.view(-1, freqs_2d.shape[-1])
98
+ cos_img = freqs_flat.cos()
99
+ sin_img = freqs_flat.sin()
100
+ if self.num_cls_token > 0:
101
+ _, dim_freq = cos_img.shape
102
+ cos_pad = torch.ones(self.num_cls_token, dim_freq, dtype=cos_img.dtype, device=device)
103
+ sin_pad = torch.zeros(self.num_cls_token, dim_freq, dtype=sin_img.dtype, device=device)
104
+ cos_img = torch.cat([cos_pad, cos_img], dim=0)
105
+ sin_img = torch.cat([sin_pad, sin_img], dim=0)
106
+ return cos_img, sin_img
107
+
108
+ def forward(self, t, height=None, width=None):
109
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
110
  seq_len = t.shape[1]
111
+ if height is None or width is None:
112
+ image_tokens = seq_len - self.num_cls_token
113
+ size = int(image_tokens**0.5)
114
+ if size * size != image_tokens:
115
+ raise ValueError(
116
+ f"Cannot infer square token grid from sequence length {seq_len} with {self.num_cls_token} class tokens."
117
+ )
118
+ height = size
119
+ width = size
120
+ if self._cached_hw != (height, width) or self.freqs_cos.device != t.device:
121
+ self.freqs_cos, self.freqs_sin = self._build_freqs(height, width, device=t.device)
122
+ self._cached_hw = (height, width)
123
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
124
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
125
 
 
215
  self.proj = nn.Linear(dim, dim)
216
  self.proj_drop = nn.Dropout(proj_drop)
217
 
218
+ def forward(self, x, rope=None, grid_height=None, grid_width=None):
219
  B, N, C = x.shape
220
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
221
  q, k, v = qkv[0], qkv[1], qkv[2]
 
226
  if rope is not None:
227
  q = q.transpose(1, 2)
228
  k = k.transpose(1, 2)
229
+ q = rope(q, height=grid_height, width=grid_width)
230
+ k = rope(k, height=grid_height, width=grid_width)
231
  q = q.transpose(1, 2)
232
  k = k.transpose(1, 2)
233
 
 
274
  self.act = nn.SiLU()
275
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
276
 
277
+ def forward(self, x, c, feat_rope=None, grid_height=None, grid_width=None):
278
  # Apply activation
279
  c = self.act(c)
280
 
 
283
  # Attention block
284
  norm_x = self.norm1(x)
285
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
286
+ attn_out = self.attn(modulated_x, rope=feat_rope, grid_height=grid_height, grid_width=grid_width)
287
  x = x + gate_msa.unsqueeze(1) * attn_out
288
 
289
  # MLP block
 
457
  self.act_final = nn.SiLU()
458
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
459
 
460
+ def _get_patch_grid(self, hidden_states):
461
+ height, width = hidden_states.shape[-2:]
462
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
463
+ raise ValueError(
464
+ f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
465
+ )
466
+ return height // self.patch_size, width // self.patch_size
467
+
468
+ def _interpolate_pos_encoding(self, tokens, grid_height, grid_width):
469
+ num_tokens = grid_height * grid_width
470
+ if self.pos_embed.shape[1] == num_tokens:
471
+ return self.pos_embed.to(device=tokens.device, dtype=tokens.dtype)
472
+ base_size = int(self.pos_embed.shape[1] ** 0.5)
473
+ pos_embed = self.pos_embed.reshape(1, base_size, base_size, self.hidden_size).permute(0, 3, 1, 2)
474
+ pos_embed = F.interpolate(pos_embed, size=(grid_height, grid_width), mode="bicubic", align_corners=False)
475
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, num_tokens, self.hidden_size)
476
+ return pos_embed.to(device=tokens.device, dtype=tokens.dtype)
477
+
478
  def forward(
479
  self,
480
  hidden_states: torch.Tensor,
481
  timestep: torch.LongTensor,
482
  class_labels: torch.LongTensor,
483
+ interpolate_pos_encoding: bool = True,
484
  return_dict: bool = True,
485
  ):
486
 
 
493
  c = t_emb + y_emb
494
 
495
  # Patch Embed
496
+ grid_height, grid_width = self._get_patch_grid(hidden_states)
497
  x = self.x_embedder(hidden_states)
498
+ if interpolate_pos_encoding:
499
+ pos_embed = self._interpolate_pos_encoding(x, grid_height, grid_width)
500
+ else:
501
+ expected_tokens = grid_height * grid_width
502
+ if self.pos_embed.shape[1] != expected_tokens:
503
+ raise ValueError(
504
+ f"pos_embed token count {self.pos_embed.shape[1]} does not match input token count {expected_tokens}. "
505
+ "Enable interpolate_pos_encoding for dynamic resolutions."
506
+ )
507
+ pos_embed = self.pos_embed.to(device=x.device, dtype=x.dtype)
508
+ x = x + pos_embed
509
 
510
  # Blocks
511
  for i, block in enumerate(self.blocks):
 
517
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
518
 
519
  if self.training and self.gradient_checkpointing:
520
+ def custom_forward(current_x, current_c):
521
+ return block(
522
+ current_x,
523
+ current_c,
524
+ feat_rope=rope,
525
+ grid_height=grid_height,
526
+ grid_width=grid_width,
527
+ )
528
+
529
  x = torch.utils.checkpoint.checkpoint(
530
+ custom_forward,
531
  x,
532
  c,
 
533
  use_reentrant=False,
534
  )
535
  else:
536
+ x = block(x, c, feat_rope=rope, grid_height=grid_height, grid_width=grid_width)
537
 
538
  # Slice off in-context tokens
539
  if self.in_context_len > 0:
 
547
  x = self.linear_final(x)
548
 
549
  # Unpatchify
550
+ x = x.reshape(shape=(x.shape[0], grid_height, grid_width, self.patch_size, self.patch_size, self.out_channels))
 
551
  x = torch.einsum("nhwpqc->nchpwq", x)
552
+ output = x.reshape(
553
+ shape=(x.shape[0], self.out_channels, grid_height * self.patch_size, grid_width * self.patch_size)
554
+ )
555
 
556
  if not return_dict:
557
  return (output,)
JiT-B-32/model_index.json CHANGED
@@ -5,11 +5,1013 @@
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
- "scheduling_jit",
9
- "JiTScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
 
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchHeunDiscreteScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
  }
JiT-B-32/pipeline.py CHANGED
@@ -12,8 +12,6 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from __future__ import annotations
16
-
17
  import importlib
18
  import json
19
  import sys
@@ -23,6 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
23
  import torch
24
 
25
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
 
@@ -39,12 +38,10 @@ class JiTPipeline(DiffusionPipeline):
39
  Parameters:
40
  transformer ([`JiTTransformer2DModel`]):
41
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
- scheduler ([`JiTScheduler`]):
43
- Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
  id2label (`dict[int, str]`, *optional*):
45
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
- id2label_cn (`dict[int, str]`, *optional*):
47
- ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
  """
49
 
50
  model_cpu_offload_seq = "transformer"
@@ -71,7 +68,7 @@ class JiTPipeline(DiffusionPipeline):
71
 
72
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
73
  if subfolder:
74
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
75
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
76
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
77
  else:
@@ -82,6 +79,7 @@ class JiTPipeline(DiffusionPipeline):
82
  if subfolder:
83
  variant = variant / subfolder
84
 
 
85
  model_kwargs = dict(kwargs)
86
  inserted: List[str] = []
87
 
@@ -103,19 +101,22 @@ class JiTPipeline(DiffusionPipeline):
103
 
104
  try:
105
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
106
- scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
 
 
 
107
 
108
  if transformer is None:
109
  raise ValueError(f"No loadable transformer found under {variant}")
110
 
111
  variant_path = str(variant)
112
- id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
 
113
 
114
  pipe = cls(
115
  transformer=transformer,
116
  scheduler=scheduler,
117
  id2label=id2label,
118
- id2label_cn=id2label_cn,
119
  )
120
  if variant_path and hasattr(pipe, "register_to_config"):
121
  pipe.register_to_config(_name_or_path=variant_path)
@@ -128,58 +129,31 @@ class JiTPipeline(DiffusionPipeline):
128
  def __init__(
129
  self,
130
  transformer,
131
- scheduler,
132
- id2label: Optional[Dict[int, str]] = None,
133
- id2label_cn: Optional[Dict[int, str]] = None,
134
  ):
135
  super().__init__()
 
136
  self.register_modules(transformer=transformer, scheduler=scheduler)
137
 
138
- self._id2label = id2label or {}
139
- self._id2label_cn = id2label_cn or {}
140
  self.labels = self._build_label2id(self._id2label)
141
- self.labels_cn = self._build_label2id(self._id2label_cn)
142
-
143
- def _ensure_labels_loaded(self) -> None:
144
- if self._id2label or self._id2label_cn:
145
- return
146
- loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
- if loaded_en:
148
- self._id2label = loaded_en
149
- self.labels = self._build_label2id(self._id2label)
150
- if loaded_cn:
151
- self._id2label_cn = loaded_cn
152
- self.labels_cn = self._build_label2id(self._id2label_cn)
153
 
154
  @staticmethod
155
- def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
- if not variant_path:
157
- return None
158
- variant_dir = Path(variant_path).resolve()
159
- labels_dir = variant_dir.parent / "labels"
160
- return labels_dir if labels_dir.is_dir() else None
161
 
162
  @staticmethod
163
- def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
- filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
- path = labels_dir / filename
166
- if not path.exists():
167
- raise FileNotFoundError(path)
168
- raw = json.loads(path.read_text(encoding="utf-8"))
169
- return {int(key): value for key, value in raw.items()}
170
-
171
- @classmethod
172
- def _load_labels_for_variant(
173
- cls,
174
- variant_path: Optional[str],
175
- ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
- labels_dir = cls._labels_dir_for_variant(variant_path)
177
- if labels_dir is None:
178
- return None, None
179
- try:
180
- return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
- except FileNotFoundError:
182
- return None, None
183
 
184
  @staticmethod
185
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
@@ -194,35 +168,19 @@ class JiTPipeline(DiffusionPipeline):
194
  @property
195
  def id2label(self) -> Dict[int, str]:
196
  """ImageNet class id to English label string (comma-separated synonyms)."""
197
- self._ensure_labels_loaded()
198
  return self._id2label
199
 
200
- @property
201
- def id2label_cn(self) -> Dict[int, str]:
202
- """ImageNet class id to Chinese label string (comma-separated synonyms)."""
203
- self._ensure_labels_loaded()
204
- return self._id2label_cn
205
-
206
- def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
207
  r"""
208
  Map ImageNet label strings to class ids.
209
 
210
  Args:
211
  label (`str` or `list[str]`):
212
- One or more label strings. Each string must match a synonym in `id2label` (English)
213
- or `id2label_cn` (Chinese).
214
- lang (`str`, *optional*, defaults to `"en"`):
215
- `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
216
  """
217
- if lang not in ("en", "cn"):
218
- raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
219
-
220
- self._ensure_labels_loaded()
221
- label2id = self.labels if lang == "en" else self.labels_cn
222
  if not label2id:
223
- raise ValueError(
224
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
225
- )
226
 
227
  if isinstance(label, str):
228
  label = [label]
@@ -231,7 +189,7 @@ class JiTPipeline(DiffusionPipeline):
231
  if missing:
232
  preview = ", ".join(list(label2id.keys())[:8])
233
  raise ValueError(
234
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
235
  )
236
  return [label2id[item] for item in label]
237
 
@@ -246,115 +204,10 @@ class JiTPipeline(DiffusionPipeline):
246
  return self.get_label_ids(class_labels)
247
 
248
  if class_labels and isinstance(class_labels[0], str):
249
- self._ensure_labels_loaded()
250
- if all(label in self.labels for label in class_labels):
251
- return self.get_label_ids(class_labels, lang="en")
252
- if all(label in self.labels_cn for label in class_labels):
253
- return self.get_label_ids(class_labels, lang="cn")
254
- raise ValueError(
255
- "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
- "or Chinese synonyms from `pipe.labels_cn`."
257
- )
258
 
259
  return list(class_labels)
260
 
261
- def _predict_velocity(
262
- self,
263
- z_value: torch.Tensor,
264
- t: torch.Tensor,
265
- class_labels: torch.Tensor,
266
- class_null: torch.Tensor,
267
- do_classifier_free_guidance: bool,
268
- guidance_scale: float,
269
- guidance_interval_min: float,
270
- guidance_interval_max: float,
271
- ) -> torch.Tensor:
272
- t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
- if do_classifier_free_guidance:
274
- z_in = torch.cat([z_value, z_value], dim=0)
275
- labels = torch.cat([class_labels, class_null], dim=0)
276
- else:
277
- z_in = z_value
278
- labels = class_labels
279
-
280
- t_batch = t.flatten().expand(z_in.shape[0])
281
- x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
- v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
-
284
- if not do_classifier_free_guidance:
285
- return v
286
-
287
- v_cond, v_uncond = v.chunk(2, dim=0)
288
- interval_mask = t < guidance_interval_max
289
- if guidance_interval_min != 0.0:
290
- interval_mask = interval_mask & (t > guidance_interval_min)
291
- scale = torch.where(
292
- interval_mask,
293
- torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
- torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
- )
296
- return v_uncond + scale * (v_cond - v_uncond)
297
-
298
- def _run_sampler(
299
- self,
300
- latents: torch.Tensor,
301
- class_labels: torch.Tensor,
302
- class_null: torch.Tensor,
303
- num_inference_steps: int,
304
- do_classifier_free_guidance: bool,
305
- guidance_scale: float,
306
- guidance_interval_min: float,
307
- guidance_interval_max: float,
308
- sampling_method: str,
309
- ) -> torch.Tensor:
310
- device = latents.device
311
- self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
312
- timesteps = self.scheduler.timesteps
313
-
314
- for i in self.progress_bar(range(num_inference_steps - 1)):
315
- t = timesteps[i]
316
- t_next = timesteps[i + 1]
317
- v = self._predict_velocity(
318
- latents,
319
- t,
320
- class_labels,
321
- class_null,
322
- do_classifier_free_guidance,
323
- guidance_scale,
324
- guidance_interval_min,
325
- guidance_interval_max,
326
- )
327
-
328
- if sampling_method == "heun":
329
- latents_euler = latents + (t_next - t) * v
330
- v_next = self._predict_velocity(
331
- latents_euler,
332
- t_next,
333
- class_labels,
334
- class_null,
335
- do_classifier_free_guidance,
336
- guidance_scale,
337
- guidance_interval_min,
338
- guidance_interval_max,
339
- )
340
- latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
341
- else:
342
- latents = self.scheduler.step(v, t, latents).prev_sample
343
-
344
- t = timesteps[-2]
345
- t_next = timesteps[-1]
346
- v = self._predict_velocity(
347
- latents,
348
- t,
349
- class_labels,
350
- class_null,
351
- do_classifier_free_guidance,
352
- guidance_scale,
353
- guidance_interval_min,
354
- guidance_interval_max,
355
- )
356
- return latents + (t_next - t) * v
357
-
358
  @torch.inference_mode()
359
  def __call__(
360
  self,
@@ -363,10 +216,12 @@ class JiTPipeline(DiffusionPipeline):
363
  guidance_interval_min: float = 0.1,
364
  guidance_interval_max: float = 1.0,
365
  noise_scale: Optional[float] = None,
366
- t_eps: Optional[float] = None,
367
- sampling_method: Optional[str] = None,
368
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
369
  num_inference_steps: int = 50,
 
 
 
370
  output_type: Optional[str] = "pil",
371
  return_dict: bool = True,
372
  ) -> Union[ImagePipelineOutput, Tuple]:
@@ -375,7 +230,7 @@ class JiTPipeline(DiffusionPipeline):
375
 
376
  Args:
377
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
378
- ImageNet class indices or human-readable label strings (English or Chinese).
379
  guidance_scale (`float`, *optional*):
380
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
381
  guidance_interval_min (`float`, defaults to `0.1`):
@@ -384,10 +239,8 @@ class JiTPipeline(DiffusionPipeline):
384
  Upper bound of the CFG interval in flow time.
385
  noise_scale (`float`, *optional*):
386
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
387
- t_eps (`float`, *optional*):
388
- Epsilon clamp for the `1 - t` denominator (scheduler config by default).
389
- sampling_method (`str`, *optional*):
390
- `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
391
  generator (`torch.Generator`, *optional*):
392
  RNG for reproducibility.
393
  num_inference_steps (`int`, defaults to `50`):
@@ -397,31 +250,34 @@ class JiTPipeline(DiffusionPipeline):
397
  return_dict (`bool`, *optional*, defaults to `True`):
398
  Return [`ImagePipelineOutput`] if True.
399
  """
400
- solver = sampling_method or self.scheduler.config.solver
401
- if solver not in {"heun", "euler"}:
402
- raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
403
  if num_inference_steps < 2:
404
  raise ValueError("num_inference_steps must be >= 2.")
405
 
406
- if t_eps is not None:
407
- self.scheduler.register_to_config(t_eps=t_eps)
408
-
409
  class_label_ids = self._normalize_class_labels(class_labels)
410
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
411
 
412
  batch_size = len(class_label_ids)
413
  image_size = int(self.transformer.config.sample_size)
 
 
 
 
 
 
 
 
 
414
  channels = int(self.transformer.config.in_channels)
415
  null_class_val = int(self.transformer.config.num_classes)
416
 
417
  if guidance_scale is None:
418
  guidance_scale = 1.0
419
  if noise_scale is None:
420
- noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
421
 
422
  latents = (
423
  randn_tensor(
424
- shape=(batch_size, channels, image_size, image_size),
425
  generator=generator,
426
  device=self._execution_device,
427
  dtype=self.transformer.dtype,
@@ -433,17 +289,47 @@ class JiTPipeline(DiffusionPipeline):
433
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
434
  class_null = torch.full_like(class_labels_t, null_class_val)
435
 
436
- latents = self._run_sampler(
437
- latents,
438
- class_labels_t,
439
- class_null,
440
- num_inference_steps,
441
- do_classifier_free_guidance,
442
- guidance_scale,
443
- guidance_interval_min,
444
- guidance_interval_max,
445
- solver,
446
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
449
  if output_type == "pt":
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import importlib
16
  import json
17
  import sys
 
21
  import torch
22
 
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
+ from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
 
 
38
  Parameters:
39
  transformer ([`JiTTransformer2DModel`]):
40
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
41
+ scheduler ([`KarrasDiffusionSchedulers`] or [`FlowMatchHeunDiscreteScheduler`]):
42
+ Diffusers scheduler interface for JiT generation (defaults to `FlowMatchHeunDiscreteScheduler(shift=4.0)`).
43
  id2label (`dict[int, str]`, *optional*):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
 
 
45
  """
46
 
47
  model_cpu_offload_seq = "transformer"
 
68
 
69
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
  if subfolder:
71
+ hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
  else:
 
79
  if subfolder:
80
  variant = variant / subfolder
81
 
82
+ id2label_override = kwargs.pop("id2label", None)
83
  model_kwargs = dict(kwargs)
84
  inserted: List[str] = []
85
 
 
101
 
102
  try:
103
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
+ try:
105
+ scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
+ except Exception:
107
+ scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
 
109
  if transformer is None:
110
  raise ValueError(f"No loadable transformer found under {variant}")
111
 
112
  variant_path = str(variant)
113
+ model_index_path = variant / "model_index.json"
114
+ id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
 
116
  pipe = cls(
117
  transformer=transformer,
118
  scheduler=scheduler,
119
  id2label=id2label,
 
120
  )
121
  if variant_path and hasattr(pipe, "register_to_config"):
122
  pipe.register_to_config(_name_or_path=variant_path)
 
129
  def __init__(
130
  self,
131
  transformer,
132
+ scheduler: FlowMatchHeunDiscreteScheduler,
133
+ id2label: Optional[Dict[Union[int, str], str]] = None,
 
134
  ):
135
  super().__init__()
136
+ scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
 
139
+ self._id2label = self._normalize_id2label(id2label)
 
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
144
+ if not id2label:
145
+ return {}
146
+ return {int(key): value for key, value in id2label.items()}
 
 
147
 
148
  @staticmethod
149
+ def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
150
+ if not model_index_path.exists():
151
+ return {}
152
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
153
+ id2label = raw.get("id2label")
154
+ if not isinstance(id2label, dict):
155
+ return {}
156
+ return {int(key): value for key, value in id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  @staticmethod
159
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
  """ImageNet class id to English label string (comma-separated synonyms)."""
 
171
  return self._id2label
172
 
173
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
 
 
 
 
 
 
174
  r"""
175
  Map ImageNet label strings to class ids.
176
 
177
  Args:
178
  label (`str` or `list[str]`):
179
+ One or more English label strings. Each string must match a synonym in `id2label`.
 
 
 
180
  """
181
+ label2id = self.labels
 
 
 
 
182
  if not label2id:
183
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
 
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
  raise ValueError(
192
+ f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
  )
194
  return [label2id[item] for item in label]
195
 
 
204
  return self.get_label_ids(class_labels)
205
 
206
  if class_labels and isinstance(class_labels[0], str):
207
+ return self.get_label_ids(class_labels)
 
 
 
 
 
 
 
 
208
 
209
  return list(class_labels)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  @torch.inference_mode()
212
  def __call__(
213
  self,
 
216
  guidance_interval_min: float = 0.1,
217
  guidance_interval_max: float = 1.0,
218
  noise_scale: Optional[float] = None,
219
+ t_eps: float = 5e-2,
 
220
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
221
  num_inference_steps: int = 50,
222
+ height: Optional[int] = None,
223
+ width: Optional[int] = None,
224
+ interpolate_pos_encoding: bool = True,
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
 
230
 
231
  Args:
232
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
+ ImageNet class indices or human-readable English label strings.
234
  guidance_scale (`float`, *optional*):
235
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
  guidance_interval_min (`float`, defaults to `0.1`):
 
239
  Upper bound of the CFG interval in flow time.
240
  noise_scale (`float`, *optional*):
241
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
+ t_eps (`float`, defaults to `5e-2`):
243
+ Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
 
 
244
  generator (`torch.Generator`, *optional*):
245
  RNG for reproducibility.
246
  num_inference_steps (`int`, defaults to `50`):
 
250
  return_dict (`bool`, *optional*, defaults to `True`):
251
  Return [`ImagePipelineOutput`] if True.
252
  """
 
 
 
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
255
 
 
 
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
258
 
259
  batch_size = len(class_label_ids)
260
  image_size = int(self.transformer.config.sample_size)
261
+ patch_size = int(self.transformer.config.patch_size)
262
+ height = int(height or image_size)
263
+ width = int(width or image_size)
264
+ if height <= 0 or width <= 0:
265
+ raise ValueError("height and width must be positive integers.")
266
+ if height % patch_size != 0 or width % patch_size != 0:
267
+ raise ValueError(
268
+ f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
+ )
270
  channels = int(self.transformer.config.in_channels)
271
  null_class_val = int(self.transformer.config.num_classes)
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
+ noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
  latents = (
279
  randn_tensor(
280
+ shape=(batch_size, channels, height, width),
281
  generator=generator,
282
  device=self._execution_device,
283
  dtype=self.transformer.dtype,
 
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
290
  class_null = torch.full_like(class_labels_t, null_class_val)
291
 
292
+ if do_classifier_free_guidance:
293
+ class_labels_input = torch.cat([class_labels_t, class_null], dim=0)
294
+ else:
295
+ class_labels_input = class_labels_t
296
+
297
+ self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
298
+ for t in self.progress_bar(self.scheduler.timesteps):
299
+ step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
+ sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
301
+ sigma = sigma.clamp_min(t_eps)
302
+ t_flow = (1.0 - sigma).clamp(0.0, 1.0)
303
+
304
+ if do_classifier_free_guidance:
305
+ latent_model_input = torch.cat([latents, latents], dim=0)
306
+ else:
307
+ latent_model_input = latents
308
+
309
+ timesteps = t_flow.flatten().expand(latent_model_input.shape[0])
310
+ x_pred = self.transformer(
311
+ latent_model_input,
312
+ timestep=timesteps,
313
+ class_labels=class_labels_input,
314
+ interpolate_pos_encoding=interpolate_pos_encoding,
315
+ ).sample
316
+
317
+ if do_classifier_free_guidance:
318
+ x_cond, x_uncond = x_pred.chunk(2, dim=0)
319
+ interval_mask = t_flow < guidance_interval_max
320
+ if guidance_interval_min != 0.0:
321
+ interval_mask = interval_mask & (t_flow > guidance_interval_min)
322
+ scale = torch.where(
323
+ interval_mask,
324
+ torch.tensor(guidance_scale, device=latents.device, dtype=latents.dtype),
325
+ torch.tensor(1.0, device=latents.device, dtype=latents.dtype),
326
+ )
327
+ x_pred = x_uncond + scale * (x_cond - x_uncond)
328
+
329
+ sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
+ # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
+ model_output = -(x_pred - latents) / sigma
332
+ latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
JiT-B-32/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,6 @@
1
  {
2
- "_class_name": "JiTScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
- "t_eps": 0.05,
6
- "solver": "heun"
7
  }
 
1
  {
2
+ "_class_name": "FlowMatchHeunDiscreteScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
+ "shift": 4.0
 
6
  }
JiT-B-32/transformer/jit_transformer_2d.py CHANGED
@@ -68,38 +68,58 @@ class JiTRotaryEmbedding(nn.Module):
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
- if custom_freqs is not None:
72
- freqs = custom_freqs
73
- else:
74
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
75
-
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
79
-
80
- freqs = torch.einsum("..., f -> ... f", t, freqs)
81
- freqs = freqs.repeat_interleave(2, dim=-1)
82
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
83
-
84
- if num_cls_token > 0:
85
- freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
86
- cos_img = freqs_flat.cos()
87
- sin_img = freqs_flat.sin()
88
-
89
- # prepend in-context cls token
90
- _, D = cos_img.shape
91
- cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
92
- sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
93
-
94
- self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
95
- self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
96
  else:
97
- self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
98
- self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
99
-
100
- def forward(self, t):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
102
  seq_len = t.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
103
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
104
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
105
 
@@ -195,7 +215,7 @@ class JiTAttention(nn.Module):
195
  self.proj = nn.Linear(dim, dim)
196
  self.proj_drop = nn.Dropout(proj_drop)
197
 
198
- def forward(self, x, rope=None):
199
  B, N, C = x.shape
200
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
201
  q, k, v = qkv[0], qkv[1], qkv[2]
@@ -206,8 +226,8 @@ class JiTAttention(nn.Module):
206
  if rope is not None:
207
  q = q.transpose(1, 2)
208
  k = k.transpose(1, 2)
209
- q = rope(q)
210
- k = rope(k)
211
  q = q.transpose(1, 2)
212
  k = k.transpose(1, 2)
213
 
@@ -254,7 +274,7 @@ class JiTBlock(nn.Module):
254
  self.act = nn.SiLU()
255
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
256
 
257
- def forward(self, x, c, feat_rope=None):
258
  # Apply activation
259
  c = self.act(c)
260
 
@@ -263,7 +283,7 @@ class JiTBlock(nn.Module):
263
  # Attention block
264
  norm_x = self.norm1(x)
265
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
266
- attn_out = self.attn(modulated_x, rope=feat_rope)
267
  x = x + gate_msa.unsqueeze(1) * attn_out
268
 
269
  # MLP block
@@ -437,11 +457,30 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
437
  self.act_final = nn.SiLU()
438
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  def forward(
441
  self,
442
  hidden_states: torch.Tensor,
443
  timestep: torch.LongTensor,
444
  class_labels: torch.LongTensor,
 
445
  return_dict: bool = True,
446
  ):
447
 
@@ -454,8 +493,19 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
454
  c = t_emb + y_emb
455
 
456
  # Patch Embed
 
457
  x = self.x_embedder(hidden_states)
458
- x = x + self.pos_embed.to(x.dtype)
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Blocks
461
  for i, block in enumerate(self.blocks):
@@ -467,15 +517,23 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
467
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
468
 
469
  if self.training and self.gradient_checkpointing:
 
 
 
 
 
 
 
 
 
470
  x = torch.utils.checkpoint.checkpoint(
471
- block,
472
  x,
473
  c,
474
- rope,
475
  use_reentrant=False,
476
  )
477
  else:
478
- x = block(x, c, feat_rope=rope)
479
 
480
  # Slice off in-context tokens
481
  if self.in_context_len > 0:
@@ -489,10 +547,11 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
489
  x = self.linear_final(x)
490
 
491
  # Unpatchify
492
- h = w = int(x.shape[1] ** 0.5)
493
- x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
494
  x = torch.einsum("nhwpqc->nchpwq", x)
495
- output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
 
 
496
 
497
  if not return_dict:
498
  return (output,)
 
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
+ self.dim = dim
72
+ self.pt_seq_len = pt_seq_len
73
+ self.theta = theta
74
+ self.num_cls_token = num_cls_token
75
+ self.custom_freqs = custom_freqs
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
+ self._cached_hw = None
79
+ cos, sin = self._build_freqs(ft_seq_len, ft_seq_len, device=torch.device("cpu"))
80
+ self.register_buffer("freqs_cos", cos, persistent=False)
81
+ self.register_buffer("freqs_sin", sin, persistent=False)
82
+ self._cached_hw = (ft_seq_len, ft_seq_len)
83
+
84
+ def _build_freqs(self, height, width, device):
85
+ if self.custom_freqs is not None:
86
+ freqs = self.custom_freqs.to(device=device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
87
  else:
88
+ freqs = 1.0 / (
89
+ self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)[: (self.dim // 2)] / self.dim)
90
+ )
91
+
92
+ t_h = torch.arange(height, device=device, dtype=torch.float32) / height * self.pt_seq_len
93
+ t_w = torch.arange(width, device=device, dtype=torch.float32) / width * self.pt_seq_len
94
+ freqs_h = torch.einsum("..., f -> ... f", t_h, freqs).repeat_interleave(2, dim=-1)
95
+ freqs_w = torch.einsum("..., f -> ... f", t_w, freqs).repeat_interleave(2, dim=-1)
96
+ freqs_2d = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
97
+ freqs_flat = freqs_2d.view(-1, freqs_2d.shape[-1])
98
+ cos_img = freqs_flat.cos()
99
+ sin_img = freqs_flat.sin()
100
+ if self.num_cls_token > 0:
101
+ _, dim_freq = cos_img.shape
102
+ cos_pad = torch.ones(self.num_cls_token, dim_freq, dtype=cos_img.dtype, device=device)
103
+ sin_pad = torch.zeros(self.num_cls_token, dim_freq, dtype=sin_img.dtype, device=device)
104
+ cos_img = torch.cat([cos_pad, cos_img], dim=0)
105
+ sin_img = torch.cat([sin_pad, sin_img], dim=0)
106
+ return cos_img, sin_img
107
+
108
+ def forward(self, t, height=None, width=None):
109
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
110
  seq_len = t.shape[1]
111
+ if height is None or width is None:
112
+ image_tokens = seq_len - self.num_cls_token
113
+ size = int(image_tokens**0.5)
114
+ if size * size != image_tokens:
115
+ raise ValueError(
116
+ f"Cannot infer square token grid from sequence length {seq_len} with {self.num_cls_token} class tokens."
117
+ )
118
+ height = size
119
+ width = size
120
+ if self._cached_hw != (height, width) or self.freqs_cos.device != t.device:
121
+ self.freqs_cos, self.freqs_sin = self._build_freqs(height, width, device=t.device)
122
+ self._cached_hw = (height, width)
123
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
124
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
125
 
 
215
  self.proj = nn.Linear(dim, dim)
216
  self.proj_drop = nn.Dropout(proj_drop)
217
 
218
+ def forward(self, x, rope=None, grid_height=None, grid_width=None):
219
  B, N, C = x.shape
220
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
221
  q, k, v = qkv[0], qkv[1], qkv[2]
 
226
  if rope is not None:
227
  q = q.transpose(1, 2)
228
  k = k.transpose(1, 2)
229
+ q = rope(q, height=grid_height, width=grid_width)
230
+ k = rope(k, height=grid_height, width=grid_width)
231
  q = q.transpose(1, 2)
232
  k = k.transpose(1, 2)
233
 
 
274
  self.act = nn.SiLU()
275
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
276
 
277
+ def forward(self, x, c, feat_rope=None, grid_height=None, grid_width=None):
278
  # Apply activation
279
  c = self.act(c)
280
 
 
283
  # Attention block
284
  norm_x = self.norm1(x)
285
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
286
+ attn_out = self.attn(modulated_x, rope=feat_rope, grid_height=grid_height, grid_width=grid_width)
287
  x = x + gate_msa.unsqueeze(1) * attn_out
288
 
289
  # MLP block
 
457
  self.act_final = nn.SiLU()
458
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
459
 
460
+ def _get_patch_grid(self, hidden_states):
461
+ height, width = hidden_states.shape[-2:]
462
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
463
+ raise ValueError(
464
+ f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
465
+ )
466
+ return height // self.patch_size, width // self.patch_size
467
+
468
+ def _interpolate_pos_encoding(self, tokens, grid_height, grid_width):
469
+ num_tokens = grid_height * grid_width
470
+ if self.pos_embed.shape[1] == num_tokens:
471
+ return self.pos_embed.to(device=tokens.device, dtype=tokens.dtype)
472
+ base_size = int(self.pos_embed.shape[1] ** 0.5)
473
+ pos_embed = self.pos_embed.reshape(1, base_size, base_size, self.hidden_size).permute(0, 3, 1, 2)
474
+ pos_embed = F.interpolate(pos_embed, size=(grid_height, grid_width), mode="bicubic", align_corners=False)
475
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, num_tokens, self.hidden_size)
476
+ return pos_embed.to(device=tokens.device, dtype=tokens.dtype)
477
+
478
  def forward(
479
  self,
480
  hidden_states: torch.Tensor,
481
  timestep: torch.LongTensor,
482
  class_labels: torch.LongTensor,
483
+ interpolate_pos_encoding: bool = True,
484
  return_dict: bool = True,
485
  ):
486
 
 
493
  c = t_emb + y_emb
494
 
495
  # Patch Embed
496
+ grid_height, grid_width = self._get_patch_grid(hidden_states)
497
  x = self.x_embedder(hidden_states)
498
+ if interpolate_pos_encoding:
499
+ pos_embed = self._interpolate_pos_encoding(x, grid_height, grid_width)
500
+ else:
501
+ expected_tokens = grid_height * grid_width
502
+ if self.pos_embed.shape[1] != expected_tokens:
503
+ raise ValueError(
504
+ f"pos_embed token count {self.pos_embed.shape[1]} does not match input token count {expected_tokens}. "
505
+ "Enable interpolate_pos_encoding for dynamic resolutions."
506
+ )
507
+ pos_embed = self.pos_embed.to(device=x.device, dtype=x.dtype)
508
+ x = x + pos_embed
509
 
510
  # Blocks
511
  for i, block in enumerate(self.blocks):
 
517
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
518
 
519
  if self.training and self.gradient_checkpointing:
520
+ def custom_forward(current_x, current_c):
521
+ return block(
522
+ current_x,
523
+ current_c,
524
+ feat_rope=rope,
525
+ grid_height=grid_height,
526
+ grid_width=grid_width,
527
+ )
528
+
529
  x = torch.utils.checkpoint.checkpoint(
530
+ custom_forward,
531
  x,
532
  c,
 
533
  use_reentrant=False,
534
  )
535
  else:
536
+ x = block(x, c, feat_rope=rope, grid_height=grid_height, grid_width=grid_width)
537
 
538
  # Slice off in-context tokens
539
  if self.in_context_len > 0:
 
547
  x = self.linear_final(x)
548
 
549
  # Unpatchify
550
+ x = x.reshape(shape=(x.shape[0], grid_height, grid_width, self.patch_size, self.patch_size, self.out_channels))
 
551
  x = torch.einsum("nhwpqc->nchpwq", x)
552
+ output = x.reshape(
553
+ shape=(x.shape[0], self.out_channels, grid_height * self.patch_size, grid_width * self.patch_size)
554
+ )
555
 
556
  if not return_dict:
557
  return (output,)
JiT-H-16/model_index.json CHANGED
@@ -5,11 +5,1013 @@
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
- "scheduling_jit",
9
- "JiTScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
 
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchHeunDiscreteScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
  }
JiT-H-16/pipeline.py CHANGED
@@ -12,8 +12,6 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from __future__ import annotations
16
-
17
  import importlib
18
  import json
19
  import sys
@@ -23,6 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
23
  import torch
24
 
25
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
 
@@ -39,12 +38,10 @@ class JiTPipeline(DiffusionPipeline):
39
  Parameters:
40
  transformer ([`JiTTransformer2DModel`]):
41
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
- scheduler ([`JiTScheduler`]):
43
- Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
  id2label (`dict[int, str]`, *optional*):
45
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
- id2label_cn (`dict[int, str]`, *optional*):
47
- ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
  """
49
 
50
  model_cpu_offload_seq = "transformer"
@@ -71,7 +68,7 @@ class JiTPipeline(DiffusionPipeline):
71
 
72
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
73
  if subfolder:
74
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
75
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
76
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
77
  else:
@@ -82,6 +79,7 @@ class JiTPipeline(DiffusionPipeline):
82
  if subfolder:
83
  variant = variant / subfolder
84
 
 
85
  model_kwargs = dict(kwargs)
86
  inserted: List[str] = []
87
 
@@ -103,19 +101,22 @@ class JiTPipeline(DiffusionPipeline):
103
 
104
  try:
105
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
106
- scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
 
 
 
107
 
108
  if transformer is None:
109
  raise ValueError(f"No loadable transformer found under {variant}")
110
 
111
  variant_path = str(variant)
112
- id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
 
113
 
114
  pipe = cls(
115
  transformer=transformer,
116
  scheduler=scheduler,
117
  id2label=id2label,
118
- id2label_cn=id2label_cn,
119
  )
120
  if variant_path and hasattr(pipe, "register_to_config"):
121
  pipe.register_to_config(_name_or_path=variant_path)
@@ -128,58 +129,31 @@ class JiTPipeline(DiffusionPipeline):
128
  def __init__(
129
  self,
130
  transformer,
131
- scheduler,
132
- id2label: Optional[Dict[int, str]] = None,
133
- id2label_cn: Optional[Dict[int, str]] = None,
134
  ):
135
  super().__init__()
 
136
  self.register_modules(transformer=transformer, scheduler=scheduler)
137
 
138
- self._id2label = id2label or {}
139
- self._id2label_cn = id2label_cn or {}
140
  self.labels = self._build_label2id(self._id2label)
141
- self.labels_cn = self._build_label2id(self._id2label_cn)
142
-
143
- def _ensure_labels_loaded(self) -> None:
144
- if self._id2label or self._id2label_cn:
145
- return
146
- loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
- if loaded_en:
148
- self._id2label = loaded_en
149
- self.labels = self._build_label2id(self._id2label)
150
- if loaded_cn:
151
- self._id2label_cn = loaded_cn
152
- self.labels_cn = self._build_label2id(self._id2label_cn)
153
 
154
  @staticmethod
155
- def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
- if not variant_path:
157
- return None
158
- variant_dir = Path(variant_path).resolve()
159
- labels_dir = variant_dir.parent / "labels"
160
- return labels_dir if labels_dir.is_dir() else None
161
 
162
  @staticmethod
163
- def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
- filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
- path = labels_dir / filename
166
- if not path.exists():
167
- raise FileNotFoundError(path)
168
- raw = json.loads(path.read_text(encoding="utf-8"))
169
- return {int(key): value for key, value in raw.items()}
170
-
171
- @classmethod
172
- def _load_labels_for_variant(
173
- cls,
174
- variant_path: Optional[str],
175
- ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
- labels_dir = cls._labels_dir_for_variant(variant_path)
177
- if labels_dir is None:
178
- return None, None
179
- try:
180
- return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
- except FileNotFoundError:
182
- return None, None
183
 
184
  @staticmethod
185
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
@@ -194,35 +168,19 @@ class JiTPipeline(DiffusionPipeline):
194
  @property
195
  def id2label(self) -> Dict[int, str]:
196
  """ImageNet class id to English label string (comma-separated synonyms)."""
197
- self._ensure_labels_loaded()
198
  return self._id2label
199
 
200
- @property
201
- def id2label_cn(self) -> Dict[int, str]:
202
- """ImageNet class id to Chinese label string (comma-separated synonyms)."""
203
- self._ensure_labels_loaded()
204
- return self._id2label_cn
205
-
206
- def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
207
  r"""
208
  Map ImageNet label strings to class ids.
209
 
210
  Args:
211
  label (`str` or `list[str]`):
212
- One or more label strings. Each string must match a synonym in `id2label` (English)
213
- or `id2label_cn` (Chinese).
214
- lang (`str`, *optional*, defaults to `"en"`):
215
- `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
216
  """
217
- if lang not in ("en", "cn"):
218
- raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
219
-
220
- self._ensure_labels_loaded()
221
- label2id = self.labels if lang == "en" else self.labels_cn
222
  if not label2id:
223
- raise ValueError(
224
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
225
- )
226
 
227
  if isinstance(label, str):
228
  label = [label]
@@ -231,7 +189,7 @@ class JiTPipeline(DiffusionPipeline):
231
  if missing:
232
  preview = ", ".join(list(label2id.keys())[:8])
233
  raise ValueError(
234
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
235
  )
236
  return [label2id[item] for item in label]
237
 
@@ -246,115 +204,10 @@ class JiTPipeline(DiffusionPipeline):
246
  return self.get_label_ids(class_labels)
247
 
248
  if class_labels and isinstance(class_labels[0], str):
249
- self._ensure_labels_loaded()
250
- if all(label in self.labels for label in class_labels):
251
- return self.get_label_ids(class_labels, lang="en")
252
- if all(label in self.labels_cn for label in class_labels):
253
- return self.get_label_ids(class_labels, lang="cn")
254
- raise ValueError(
255
- "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
- "or Chinese synonyms from `pipe.labels_cn`."
257
- )
258
 
259
  return list(class_labels)
260
 
261
- def _predict_velocity(
262
- self,
263
- z_value: torch.Tensor,
264
- t: torch.Tensor,
265
- class_labels: torch.Tensor,
266
- class_null: torch.Tensor,
267
- do_classifier_free_guidance: bool,
268
- guidance_scale: float,
269
- guidance_interval_min: float,
270
- guidance_interval_max: float,
271
- ) -> torch.Tensor:
272
- t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
- if do_classifier_free_guidance:
274
- z_in = torch.cat([z_value, z_value], dim=0)
275
- labels = torch.cat([class_labels, class_null], dim=0)
276
- else:
277
- z_in = z_value
278
- labels = class_labels
279
-
280
- t_batch = t.flatten().expand(z_in.shape[0])
281
- x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
- v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
-
284
- if not do_classifier_free_guidance:
285
- return v
286
-
287
- v_cond, v_uncond = v.chunk(2, dim=0)
288
- interval_mask = t < guidance_interval_max
289
- if guidance_interval_min != 0.0:
290
- interval_mask = interval_mask & (t > guidance_interval_min)
291
- scale = torch.where(
292
- interval_mask,
293
- torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
- torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
- )
296
- return v_uncond + scale * (v_cond - v_uncond)
297
-
298
- def _run_sampler(
299
- self,
300
- latents: torch.Tensor,
301
- class_labels: torch.Tensor,
302
- class_null: torch.Tensor,
303
- num_inference_steps: int,
304
- do_classifier_free_guidance: bool,
305
- guidance_scale: float,
306
- guidance_interval_min: float,
307
- guidance_interval_max: float,
308
- sampling_method: str,
309
- ) -> torch.Tensor:
310
- device = latents.device
311
- self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
312
- timesteps = self.scheduler.timesteps
313
-
314
- for i in self.progress_bar(range(num_inference_steps - 1)):
315
- t = timesteps[i]
316
- t_next = timesteps[i + 1]
317
- v = self._predict_velocity(
318
- latents,
319
- t,
320
- class_labels,
321
- class_null,
322
- do_classifier_free_guidance,
323
- guidance_scale,
324
- guidance_interval_min,
325
- guidance_interval_max,
326
- )
327
-
328
- if sampling_method == "heun":
329
- latents_euler = latents + (t_next - t) * v
330
- v_next = self._predict_velocity(
331
- latents_euler,
332
- t_next,
333
- class_labels,
334
- class_null,
335
- do_classifier_free_guidance,
336
- guidance_scale,
337
- guidance_interval_min,
338
- guidance_interval_max,
339
- )
340
- latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
341
- else:
342
- latents = self.scheduler.step(v, t, latents).prev_sample
343
-
344
- t = timesteps[-2]
345
- t_next = timesteps[-1]
346
- v = self._predict_velocity(
347
- latents,
348
- t,
349
- class_labels,
350
- class_null,
351
- do_classifier_free_guidance,
352
- guidance_scale,
353
- guidance_interval_min,
354
- guidance_interval_max,
355
- )
356
- return latents + (t_next - t) * v
357
-
358
  @torch.inference_mode()
359
  def __call__(
360
  self,
@@ -363,10 +216,12 @@ class JiTPipeline(DiffusionPipeline):
363
  guidance_interval_min: float = 0.1,
364
  guidance_interval_max: float = 1.0,
365
  noise_scale: Optional[float] = None,
366
- t_eps: Optional[float] = None,
367
- sampling_method: Optional[str] = None,
368
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
369
  num_inference_steps: int = 50,
 
 
 
370
  output_type: Optional[str] = "pil",
371
  return_dict: bool = True,
372
  ) -> Union[ImagePipelineOutput, Tuple]:
@@ -375,7 +230,7 @@ class JiTPipeline(DiffusionPipeline):
375
 
376
  Args:
377
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
378
- ImageNet class indices or human-readable label strings (English or Chinese).
379
  guidance_scale (`float`, *optional*):
380
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
381
  guidance_interval_min (`float`, defaults to `0.1`):
@@ -384,10 +239,8 @@ class JiTPipeline(DiffusionPipeline):
384
  Upper bound of the CFG interval in flow time.
385
  noise_scale (`float`, *optional*):
386
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
387
- t_eps (`float`, *optional*):
388
- Epsilon clamp for the `1 - t` denominator (scheduler config by default).
389
- sampling_method (`str`, *optional*):
390
- `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
391
  generator (`torch.Generator`, *optional*):
392
  RNG for reproducibility.
393
  num_inference_steps (`int`, defaults to `50`):
@@ -397,31 +250,34 @@ class JiTPipeline(DiffusionPipeline):
397
  return_dict (`bool`, *optional*, defaults to `True`):
398
  Return [`ImagePipelineOutput`] if True.
399
  """
400
- solver = sampling_method or self.scheduler.config.solver
401
- if solver not in {"heun", "euler"}:
402
- raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
403
  if num_inference_steps < 2:
404
  raise ValueError("num_inference_steps must be >= 2.")
405
 
406
- if t_eps is not None:
407
- self.scheduler.register_to_config(t_eps=t_eps)
408
-
409
  class_label_ids = self._normalize_class_labels(class_labels)
410
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
411
 
412
  batch_size = len(class_label_ids)
413
  image_size = int(self.transformer.config.sample_size)
 
 
 
 
 
 
 
 
 
414
  channels = int(self.transformer.config.in_channels)
415
  null_class_val = int(self.transformer.config.num_classes)
416
 
417
  if guidance_scale is None:
418
  guidance_scale = 1.0
419
  if noise_scale is None:
420
- noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
421
 
422
  latents = (
423
  randn_tensor(
424
- shape=(batch_size, channels, image_size, image_size),
425
  generator=generator,
426
  device=self._execution_device,
427
  dtype=self.transformer.dtype,
@@ -433,17 +289,47 @@ class JiTPipeline(DiffusionPipeline):
433
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
434
  class_null = torch.full_like(class_labels_t, null_class_val)
435
 
436
- latents = self._run_sampler(
437
- latents,
438
- class_labels_t,
439
- class_null,
440
- num_inference_steps,
441
- do_classifier_free_guidance,
442
- guidance_scale,
443
- guidance_interval_min,
444
- guidance_interval_max,
445
- solver,
446
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
449
  if output_type == "pt":
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import importlib
16
  import json
17
  import sys
 
21
  import torch
22
 
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
+ from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
 
 
38
  Parameters:
39
  transformer ([`JiTTransformer2DModel`]):
40
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
41
+ scheduler ([`KarrasDiffusionSchedulers`] or [`FlowMatchHeunDiscreteScheduler`]):
42
+ Diffusers scheduler interface for JiT generation (defaults to `FlowMatchHeunDiscreteScheduler(shift=4.0)`).
43
  id2label (`dict[int, str]`, *optional*):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
 
 
45
  """
46
 
47
  model_cpu_offload_seq = "transformer"
 
68
 
69
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
  if subfolder:
71
+ hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
  else:
 
79
  if subfolder:
80
  variant = variant / subfolder
81
 
82
+ id2label_override = kwargs.pop("id2label", None)
83
  model_kwargs = dict(kwargs)
84
  inserted: List[str] = []
85
 
 
101
 
102
  try:
103
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
+ try:
105
+ scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
+ except Exception:
107
+ scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
 
109
  if transformer is None:
110
  raise ValueError(f"No loadable transformer found under {variant}")
111
 
112
  variant_path = str(variant)
113
+ model_index_path = variant / "model_index.json"
114
+ id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
 
116
  pipe = cls(
117
  transformer=transformer,
118
  scheduler=scheduler,
119
  id2label=id2label,
 
120
  )
121
  if variant_path and hasattr(pipe, "register_to_config"):
122
  pipe.register_to_config(_name_or_path=variant_path)
 
129
  def __init__(
130
  self,
131
  transformer,
132
+ scheduler: FlowMatchHeunDiscreteScheduler,
133
+ id2label: Optional[Dict[Union[int, str], str]] = None,
 
134
  ):
135
  super().__init__()
136
+ scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
 
139
+ self._id2label = self._normalize_id2label(id2label)
 
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
144
+ if not id2label:
145
+ return {}
146
+ return {int(key): value for key, value in id2label.items()}
 
 
147
 
148
  @staticmethod
149
+ def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
150
+ if not model_index_path.exists():
151
+ return {}
152
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
153
+ id2label = raw.get("id2label")
154
+ if not isinstance(id2label, dict):
155
+ return {}
156
+ return {int(key): value for key, value in id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  @staticmethod
159
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
  """ImageNet class id to English label string (comma-separated synonyms)."""
 
171
  return self._id2label
172
 
173
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
 
 
 
 
 
 
174
  r"""
175
  Map ImageNet label strings to class ids.
176
 
177
  Args:
178
  label (`str` or `list[str]`):
179
+ One or more English label strings. Each string must match a synonym in `id2label`.
 
 
 
180
  """
181
+ label2id = self.labels
 
 
 
 
182
  if not label2id:
183
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
 
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
  raise ValueError(
192
+ f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
  )
194
  return [label2id[item] for item in label]
195
 
 
204
  return self.get_label_ids(class_labels)
205
 
206
  if class_labels and isinstance(class_labels[0], str):
207
+ return self.get_label_ids(class_labels)
 
 
 
 
 
 
 
 
208
 
209
  return list(class_labels)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  @torch.inference_mode()
212
  def __call__(
213
  self,
 
216
  guidance_interval_min: float = 0.1,
217
  guidance_interval_max: float = 1.0,
218
  noise_scale: Optional[float] = None,
219
+ t_eps: float = 5e-2,
 
220
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
221
  num_inference_steps: int = 50,
222
+ height: Optional[int] = None,
223
+ width: Optional[int] = None,
224
+ interpolate_pos_encoding: bool = True,
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
 
230
 
231
  Args:
232
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
+ ImageNet class indices or human-readable English label strings.
234
  guidance_scale (`float`, *optional*):
235
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
  guidance_interval_min (`float`, defaults to `0.1`):
 
239
  Upper bound of the CFG interval in flow time.
240
  noise_scale (`float`, *optional*):
241
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
+ t_eps (`float`, defaults to `5e-2`):
243
+ Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
 
 
244
  generator (`torch.Generator`, *optional*):
245
  RNG for reproducibility.
246
  num_inference_steps (`int`, defaults to `50`):
 
250
  return_dict (`bool`, *optional*, defaults to `True`):
251
  Return [`ImagePipelineOutput`] if True.
252
  """
 
 
 
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
255
 
 
 
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
258
 
259
  batch_size = len(class_label_ids)
260
  image_size = int(self.transformer.config.sample_size)
261
+ patch_size = int(self.transformer.config.patch_size)
262
+ height = int(height or image_size)
263
+ width = int(width or image_size)
264
+ if height <= 0 or width <= 0:
265
+ raise ValueError("height and width must be positive integers.")
266
+ if height % patch_size != 0 or width % patch_size != 0:
267
+ raise ValueError(
268
+ f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
+ )
270
  channels = int(self.transformer.config.in_channels)
271
  null_class_val = int(self.transformer.config.num_classes)
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
+ noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
  latents = (
279
  randn_tensor(
280
+ shape=(batch_size, channels, height, width),
281
  generator=generator,
282
  device=self._execution_device,
283
  dtype=self.transformer.dtype,
 
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
290
  class_null = torch.full_like(class_labels_t, null_class_val)
291
 
292
+ if do_classifier_free_guidance:
293
+ class_labels_input = torch.cat([class_labels_t, class_null], dim=0)
294
+ else:
295
+ class_labels_input = class_labels_t
296
+
297
+ self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
298
+ for t in self.progress_bar(self.scheduler.timesteps):
299
+ step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
+ sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
301
+ sigma = sigma.clamp_min(t_eps)
302
+ t_flow = (1.0 - sigma).clamp(0.0, 1.0)
303
+
304
+ if do_classifier_free_guidance:
305
+ latent_model_input = torch.cat([latents, latents], dim=0)
306
+ else:
307
+ latent_model_input = latents
308
+
309
+ timesteps = t_flow.flatten().expand(latent_model_input.shape[0])
310
+ x_pred = self.transformer(
311
+ latent_model_input,
312
+ timestep=timesteps,
313
+ class_labels=class_labels_input,
314
+ interpolate_pos_encoding=interpolate_pos_encoding,
315
+ ).sample
316
+
317
+ if do_classifier_free_guidance:
318
+ x_cond, x_uncond = x_pred.chunk(2, dim=0)
319
+ interval_mask = t_flow < guidance_interval_max
320
+ if guidance_interval_min != 0.0:
321
+ interval_mask = interval_mask & (t_flow > guidance_interval_min)
322
+ scale = torch.where(
323
+ interval_mask,
324
+ torch.tensor(guidance_scale, device=latents.device, dtype=latents.dtype),
325
+ torch.tensor(1.0, device=latents.device, dtype=latents.dtype),
326
+ )
327
+ x_pred = x_uncond + scale * (x_cond - x_uncond)
328
+
329
+ sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
+ # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
+ model_output = -(x_pred - latents) / sigma
332
+ latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
JiT-H-16/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,6 @@
1
  {
2
- "_class_name": "JiTScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
- "t_eps": 0.05,
6
- "solver": "heun"
7
  }
 
1
  {
2
+ "_class_name": "FlowMatchHeunDiscreteScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
+ "shift": 4.0
 
6
  }
JiT-H-16/transformer/jit_transformer_2d.py CHANGED
@@ -68,38 +68,58 @@ class JiTRotaryEmbedding(nn.Module):
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
- if custom_freqs is not None:
72
- freqs = custom_freqs
73
- else:
74
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
75
-
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
79
-
80
- freqs = torch.einsum("..., f -> ... f", t, freqs)
81
- freqs = freqs.repeat_interleave(2, dim=-1)
82
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
83
-
84
- if num_cls_token > 0:
85
- freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
86
- cos_img = freqs_flat.cos()
87
- sin_img = freqs_flat.sin()
88
-
89
- # prepend in-context cls token
90
- _, D = cos_img.shape
91
- cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
92
- sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
93
-
94
- self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
95
- self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
96
  else:
97
- self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
98
- self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
99
-
100
- def forward(self, t):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
102
  seq_len = t.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
103
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
104
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
105
 
@@ -195,7 +215,7 @@ class JiTAttention(nn.Module):
195
  self.proj = nn.Linear(dim, dim)
196
  self.proj_drop = nn.Dropout(proj_drop)
197
 
198
- def forward(self, x, rope=None):
199
  B, N, C = x.shape
200
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
201
  q, k, v = qkv[0], qkv[1], qkv[2]
@@ -206,8 +226,8 @@ class JiTAttention(nn.Module):
206
  if rope is not None:
207
  q = q.transpose(1, 2)
208
  k = k.transpose(1, 2)
209
- q = rope(q)
210
- k = rope(k)
211
  q = q.transpose(1, 2)
212
  k = k.transpose(1, 2)
213
 
@@ -254,7 +274,7 @@ class JiTBlock(nn.Module):
254
  self.act = nn.SiLU()
255
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
256
 
257
- def forward(self, x, c, feat_rope=None):
258
  # Apply activation
259
  c = self.act(c)
260
 
@@ -263,7 +283,7 @@ class JiTBlock(nn.Module):
263
  # Attention block
264
  norm_x = self.norm1(x)
265
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
266
- attn_out = self.attn(modulated_x, rope=feat_rope)
267
  x = x + gate_msa.unsqueeze(1) * attn_out
268
 
269
  # MLP block
@@ -437,11 +457,30 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
437
  self.act_final = nn.SiLU()
438
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  def forward(
441
  self,
442
  hidden_states: torch.Tensor,
443
  timestep: torch.LongTensor,
444
  class_labels: torch.LongTensor,
 
445
  return_dict: bool = True,
446
  ):
447
 
@@ -454,8 +493,19 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
454
  c = t_emb + y_emb
455
 
456
  # Patch Embed
 
457
  x = self.x_embedder(hidden_states)
458
- x = x + self.pos_embed.to(x.dtype)
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Blocks
461
  for i, block in enumerate(self.blocks):
@@ -467,15 +517,23 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
467
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
468
 
469
  if self.training and self.gradient_checkpointing:
 
 
 
 
 
 
 
 
 
470
  x = torch.utils.checkpoint.checkpoint(
471
- block,
472
  x,
473
  c,
474
- rope,
475
  use_reentrant=False,
476
  )
477
  else:
478
- x = block(x, c, feat_rope=rope)
479
 
480
  # Slice off in-context tokens
481
  if self.in_context_len > 0:
@@ -489,10 +547,11 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
489
  x = self.linear_final(x)
490
 
491
  # Unpatchify
492
- h = w = int(x.shape[1] ** 0.5)
493
- x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
494
  x = torch.einsum("nhwpqc->nchpwq", x)
495
- output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
 
 
496
 
497
  if not return_dict:
498
  return (output,)
 
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
+ self.dim = dim
72
+ self.pt_seq_len = pt_seq_len
73
+ self.theta = theta
74
+ self.num_cls_token = num_cls_token
75
+ self.custom_freqs = custom_freqs
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
+ self._cached_hw = None
79
+ cos, sin = self._build_freqs(ft_seq_len, ft_seq_len, device=torch.device("cpu"))
80
+ self.register_buffer("freqs_cos", cos, persistent=False)
81
+ self.register_buffer("freqs_sin", sin, persistent=False)
82
+ self._cached_hw = (ft_seq_len, ft_seq_len)
83
+
84
+ def _build_freqs(self, height, width, device):
85
+ if self.custom_freqs is not None:
86
+ freqs = self.custom_freqs.to(device=device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
87
  else:
88
+ freqs = 1.0 / (
89
+ self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)[: (self.dim // 2)] / self.dim)
90
+ )
91
+
92
+ t_h = torch.arange(height, device=device, dtype=torch.float32) / height * self.pt_seq_len
93
+ t_w = torch.arange(width, device=device, dtype=torch.float32) / width * self.pt_seq_len
94
+ freqs_h = torch.einsum("..., f -> ... f", t_h, freqs).repeat_interleave(2, dim=-1)
95
+ freqs_w = torch.einsum("..., f -> ... f", t_w, freqs).repeat_interleave(2, dim=-1)
96
+ freqs_2d = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
97
+ freqs_flat = freqs_2d.view(-1, freqs_2d.shape[-1])
98
+ cos_img = freqs_flat.cos()
99
+ sin_img = freqs_flat.sin()
100
+ if self.num_cls_token > 0:
101
+ _, dim_freq = cos_img.shape
102
+ cos_pad = torch.ones(self.num_cls_token, dim_freq, dtype=cos_img.dtype, device=device)
103
+ sin_pad = torch.zeros(self.num_cls_token, dim_freq, dtype=sin_img.dtype, device=device)
104
+ cos_img = torch.cat([cos_pad, cos_img], dim=0)
105
+ sin_img = torch.cat([sin_pad, sin_img], dim=0)
106
+ return cos_img, sin_img
107
+
108
+ def forward(self, t, height=None, width=None):
109
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
110
  seq_len = t.shape[1]
111
+ if height is None or width is None:
112
+ image_tokens = seq_len - self.num_cls_token
113
+ size = int(image_tokens**0.5)
114
+ if size * size != image_tokens:
115
+ raise ValueError(
116
+ f"Cannot infer square token grid from sequence length {seq_len} with {self.num_cls_token} class tokens."
117
+ )
118
+ height = size
119
+ width = size
120
+ if self._cached_hw != (height, width) or self.freqs_cos.device != t.device:
121
+ self.freqs_cos, self.freqs_sin = self._build_freqs(height, width, device=t.device)
122
+ self._cached_hw = (height, width)
123
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
124
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
125
 
 
215
  self.proj = nn.Linear(dim, dim)
216
  self.proj_drop = nn.Dropout(proj_drop)
217
 
218
+ def forward(self, x, rope=None, grid_height=None, grid_width=None):
219
  B, N, C = x.shape
220
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
221
  q, k, v = qkv[0], qkv[1], qkv[2]
 
226
  if rope is not None:
227
  q = q.transpose(1, 2)
228
  k = k.transpose(1, 2)
229
+ q = rope(q, height=grid_height, width=grid_width)
230
+ k = rope(k, height=grid_height, width=grid_width)
231
  q = q.transpose(1, 2)
232
  k = k.transpose(1, 2)
233
 
 
274
  self.act = nn.SiLU()
275
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
276
 
277
+ def forward(self, x, c, feat_rope=None, grid_height=None, grid_width=None):
278
  # Apply activation
279
  c = self.act(c)
280
 
 
283
  # Attention block
284
  norm_x = self.norm1(x)
285
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
286
+ attn_out = self.attn(modulated_x, rope=feat_rope, grid_height=grid_height, grid_width=grid_width)
287
  x = x + gate_msa.unsqueeze(1) * attn_out
288
 
289
  # MLP block
 
457
  self.act_final = nn.SiLU()
458
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
459
 
460
+ def _get_patch_grid(self, hidden_states):
461
+ height, width = hidden_states.shape[-2:]
462
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
463
+ raise ValueError(
464
+ f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
465
+ )
466
+ return height // self.patch_size, width // self.patch_size
467
+
468
+ def _interpolate_pos_encoding(self, tokens, grid_height, grid_width):
469
+ num_tokens = grid_height * grid_width
470
+ if self.pos_embed.shape[1] == num_tokens:
471
+ return self.pos_embed.to(device=tokens.device, dtype=tokens.dtype)
472
+ base_size = int(self.pos_embed.shape[1] ** 0.5)
473
+ pos_embed = self.pos_embed.reshape(1, base_size, base_size, self.hidden_size).permute(0, 3, 1, 2)
474
+ pos_embed = F.interpolate(pos_embed, size=(grid_height, grid_width), mode="bicubic", align_corners=False)
475
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, num_tokens, self.hidden_size)
476
+ return pos_embed.to(device=tokens.device, dtype=tokens.dtype)
477
+
478
  def forward(
479
  self,
480
  hidden_states: torch.Tensor,
481
  timestep: torch.LongTensor,
482
  class_labels: torch.LongTensor,
483
+ interpolate_pos_encoding: bool = True,
484
  return_dict: bool = True,
485
  ):
486
 
 
493
  c = t_emb + y_emb
494
 
495
  # Patch Embed
496
+ grid_height, grid_width = self._get_patch_grid(hidden_states)
497
  x = self.x_embedder(hidden_states)
498
+ if interpolate_pos_encoding:
499
+ pos_embed = self._interpolate_pos_encoding(x, grid_height, grid_width)
500
+ else:
501
+ expected_tokens = grid_height * grid_width
502
+ if self.pos_embed.shape[1] != expected_tokens:
503
+ raise ValueError(
504
+ f"pos_embed token count {self.pos_embed.shape[1]} does not match input token count {expected_tokens}. "
505
+ "Enable interpolate_pos_encoding for dynamic resolutions."
506
+ )
507
+ pos_embed = self.pos_embed.to(device=x.device, dtype=x.dtype)
508
+ x = x + pos_embed
509
 
510
  # Blocks
511
  for i, block in enumerate(self.blocks):
 
517
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
518
 
519
  if self.training and self.gradient_checkpointing:
520
+ def custom_forward(current_x, current_c):
521
+ return block(
522
+ current_x,
523
+ current_c,
524
+ feat_rope=rope,
525
+ grid_height=grid_height,
526
+ grid_width=grid_width,
527
+ )
528
+
529
  x = torch.utils.checkpoint.checkpoint(
530
+ custom_forward,
531
  x,
532
  c,
 
533
  use_reentrant=False,
534
  )
535
  else:
536
+ x = block(x, c, feat_rope=rope, grid_height=grid_height, grid_width=grid_width)
537
 
538
  # Slice off in-context tokens
539
  if self.in_context_len > 0:
 
547
  x = self.linear_final(x)
548
 
549
  # Unpatchify
550
+ x = x.reshape(shape=(x.shape[0], grid_height, grid_width, self.patch_size, self.patch_size, self.out_channels))
 
551
  x = torch.einsum("nhwpqc->nchpwq", x)
552
+ output = x.reshape(
553
+ shape=(x.shape[0], self.out_channels, grid_height * self.patch_size, grid_width * self.patch_size)
554
+ )
555
 
556
  if not return_dict:
557
  return (output,)
JiT-H-32/model_index.json CHANGED
@@ -5,11 +5,1013 @@
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
- "scheduling_jit",
9
- "JiTScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
 
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchHeunDiscreteScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
  }
JiT-H-32/pipeline.py CHANGED
@@ -12,8 +12,6 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from __future__ import annotations
16
-
17
  import importlib
18
  import json
19
  import sys
@@ -23,6 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
23
  import torch
24
 
25
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
 
@@ -39,12 +38,10 @@ class JiTPipeline(DiffusionPipeline):
39
  Parameters:
40
  transformer ([`JiTTransformer2DModel`]):
41
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
- scheduler ([`JiTScheduler`]):
43
- Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
  id2label (`dict[int, str]`, *optional*):
45
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
- id2label_cn (`dict[int, str]`, *optional*):
47
- ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
  """
49
 
50
  model_cpu_offload_seq = "transformer"
@@ -71,7 +68,7 @@ class JiTPipeline(DiffusionPipeline):
71
 
72
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
73
  if subfolder:
74
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
75
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
76
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
77
  else:
@@ -82,6 +79,7 @@ class JiTPipeline(DiffusionPipeline):
82
  if subfolder:
83
  variant = variant / subfolder
84
 
 
85
  model_kwargs = dict(kwargs)
86
  inserted: List[str] = []
87
 
@@ -103,19 +101,22 @@ class JiTPipeline(DiffusionPipeline):
103
 
104
  try:
105
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
106
- scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
 
 
 
107
 
108
  if transformer is None:
109
  raise ValueError(f"No loadable transformer found under {variant}")
110
 
111
  variant_path = str(variant)
112
- id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
 
113
 
114
  pipe = cls(
115
  transformer=transformer,
116
  scheduler=scheduler,
117
  id2label=id2label,
118
- id2label_cn=id2label_cn,
119
  )
120
  if variant_path and hasattr(pipe, "register_to_config"):
121
  pipe.register_to_config(_name_or_path=variant_path)
@@ -128,58 +129,31 @@ class JiTPipeline(DiffusionPipeline):
128
  def __init__(
129
  self,
130
  transformer,
131
- scheduler,
132
- id2label: Optional[Dict[int, str]] = None,
133
- id2label_cn: Optional[Dict[int, str]] = None,
134
  ):
135
  super().__init__()
 
136
  self.register_modules(transformer=transformer, scheduler=scheduler)
137
 
138
- self._id2label = id2label or {}
139
- self._id2label_cn = id2label_cn or {}
140
  self.labels = self._build_label2id(self._id2label)
141
- self.labels_cn = self._build_label2id(self._id2label_cn)
142
-
143
- def _ensure_labels_loaded(self) -> None:
144
- if self._id2label or self._id2label_cn:
145
- return
146
- loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
- if loaded_en:
148
- self._id2label = loaded_en
149
- self.labels = self._build_label2id(self._id2label)
150
- if loaded_cn:
151
- self._id2label_cn = loaded_cn
152
- self.labels_cn = self._build_label2id(self._id2label_cn)
153
 
154
  @staticmethod
155
- def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
- if not variant_path:
157
- return None
158
- variant_dir = Path(variant_path).resolve()
159
- labels_dir = variant_dir.parent / "labels"
160
- return labels_dir if labels_dir.is_dir() else None
161
 
162
  @staticmethod
163
- def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
- filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
- path = labels_dir / filename
166
- if not path.exists():
167
- raise FileNotFoundError(path)
168
- raw = json.loads(path.read_text(encoding="utf-8"))
169
- return {int(key): value for key, value in raw.items()}
170
-
171
- @classmethod
172
- def _load_labels_for_variant(
173
- cls,
174
- variant_path: Optional[str],
175
- ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
- labels_dir = cls._labels_dir_for_variant(variant_path)
177
- if labels_dir is None:
178
- return None, None
179
- try:
180
- return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
- except FileNotFoundError:
182
- return None, None
183
 
184
  @staticmethod
185
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
@@ -194,35 +168,19 @@ class JiTPipeline(DiffusionPipeline):
194
  @property
195
  def id2label(self) -> Dict[int, str]:
196
  """ImageNet class id to English label string (comma-separated synonyms)."""
197
- self._ensure_labels_loaded()
198
  return self._id2label
199
 
200
- @property
201
- def id2label_cn(self) -> Dict[int, str]:
202
- """ImageNet class id to Chinese label string (comma-separated synonyms)."""
203
- self._ensure_labels_loaded()
204
- return self._id2label_cn
205
-
206
- def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
207
  r"""
208
  Map ImageNet label strings to class ids.
209
 
210
  Args:
211
  label (`str` or `list[str]`):
212
- One or more label strings. Each string must match a synonym in `id2label` (English)
213
- or `id2label_cn` (Chinese).
214
- lang (`str`, *optional*, defaults to `"en"`):
215
- `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
216
  """
217
- if lang not in ("en", "cn"):
218
- raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
219
-
220
- self._ensure_labels_loaded()
221
- label2id = self.labels if lang == "en" else self.labels_cn
222
  if not label2id:
223
- raise ValueError(
224
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
225
- )
226
 
227
  if isinstance(label, str):
228
  label = [label]
@@ -231,7 +189,7 @@ class JiTPipeline(DiffusionPipeline):
231
  if missing:
232
  preview = ", ".join(list(label2id.keys())[:8])
233
  raise ValueError(
234
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
235
  )
236
  return [label2id[item] for item in label]
237
 
@@ -246,115 +204,10 @@ class JiTPipeline(DiffusionPipeline):
246
  return self.get_label_ids(class_labels)
247
 
248
  if class_labels and isinstance(class_labels[0], str):
249
- self._ensure_labels_loaded()
250
- if all(label in self.labels for label in class_labels):
251
- return self.get_label_ids(class_labels, lang="en")
252
- if all(label in self.labels_cn for label in class_labels):
253
- return self.get_label_ids(class_labels, lang="cn")
254
- raise ValueError(
255
- "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
- "or Chinese synonyms from `pipe.labels_cn`."
257
- )
258
 
259
  return list(class_labels)
260
 
261
- def _predict_velocity(
262
- self,
263
- z_value: torch.Tensor,
264
- t: torch.Tensor,
265
- class_labels: torch.Tensor,
266
- class_null: torch.Tensor,
267
- do_classifier_free_guidance: bool,
268
- guidance_scale: float,
269
- guidance_interval_min: float,
270
- guidance_interval_max: float,
271
- ) -> torch.Tensor:
272
- t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
- if do_classifier_free_guidance:
274
- z_in = torch.cat([z_value, z_value], dim=0)
275
- labels = torch.cat([class_labels, class_null], dim=0)
276
- else:
277
- z_in = z_value
278
- labels = class_labels
279
-
280
- t_batch = t.flatten().expand(z_in.shape[0])
281
- x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
- v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
-
284
- if not do_classifier_free_guidance:
285
- return v
286
-
287
- v_cond, v_uncond = v.chunk(2, dim=0)
288
- interval_mask = t < guidance_interval_max
289
- if guidance_interval_min != 0.0:
290
- interval_mask = interval_mask & (t > guidance_interval_min)
291
- scale = torch.where(
292
- interval_mask,
293
- torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
- torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
- )
296
- return v_uncond + scale * (v_cond - v_uncond)
297
-
298
- def _run_sampler(
299
- self,
300
- latents: torch.Tensor,
301
- class_labels: torch.Tensor,
302
- class_null: torch.Tensor,
303
- num_inference_steps: int,
304
- do_classifier_free_guidance: bool,
305
- guidance_scale: float,
306
- guidance_interval_min: float,
307
- guidance_interval_max: float,
308
- sampling_method: str,
309
- ) -> torch.Tensor:
310
- device = latents.device
311
- self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
312
- timesteps = self.scheduler.timesteps
313
-
314
- for i in self.progress_bar(range(num_inference_steps - 1)):
315
- t = timesteps[i]
316
- t_next = timesteps[i + 1]
317
- v = self._predict_velocity(
318
- latents,
319
- t,
320
- class_labels,
321
- class_null,
322
- do_classifier_free_guidance,
323
- guidance_scale,
324
- guidance_interval_min,
325
- guidance_interval_max,
326
- )
327
-
328
- if sampling_method == "heun":
329
- latents_euler = latents + (t_next - t) * v
330
- v_next = self._predict_velocity(
331
- latents_euler,
332
- t_next,
333
- class_labels,
334
- class_null,
335
- do_classifier_free_guidance,
336
- guidance_scale,
337
- guidance_interval_min,
338
- guidance_interval_max,
339
- )
340
- latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
341
- else:
342
- latents = self.scheduler.step(v, t, latents).prev_sample
343
-
344
- t = timesteps[-2]
345
- t_next = timesteps[-1]
346
- v = self._predict_velocity(
347
- latents,
348
- t,
349
- class_labels,
350
- class_null,
351
- do_classifier_free_guidance,
352
- guidance_scale,
353
- guidance_interval_min,
354
- guidance_interval_max,
355
- )
356
- return latents + (t_next - t) * v
357
-
358
  @torch.inference_mode()
359
  def __call__(
360
  self,
@@ -363,10 +216,12 @@ class JiTPipeline(DiffusionPipeline):
363
  guidance_interval_min: float = 0.1,
364
  guidance_interval_max: float = 1.0,
365
  noise_scale: Optional[float] = None,
366
- t_eps: Optional[float] = None,
367
- sampling_method: Optional[str] = None,
368
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
369
  num_inference_steps: int = 50,
 
 
 
370
  output_type: Optional[str] = "pil",
371
  return_dict: bool = True,
372
  ) -> Union[ImagePipelineOutput, Tuple]:
@@ -375,7 +230,7 @@ class JiTPipeline(DiffusionPipeline):
375
 
376
  Args:
377
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
378
- ImageNet class indices or human-readable label strings (English or Chinese).
379
  guidance_scale (`float`, *optional*):
380
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
381
  guidance_interval_min (`float`, defaults to `0.1`):
@@ -384,10 +239,8 @@ class JiTPipeline(DiffusionPipeline):
384
  Upper bound of the CFG interval in flow time.
385
  noise_scale (`float`, *optional*):
386
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
387
- t_eps (`float`, *optional*):
388
- Epsilon clamp for the `1 - t` denominator (scheduler config by default).
389
- sampling_method (`str`, *optional*):
390
- `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
391
  generator (`torch.Generator`, *optional*):
392
  RNG for reproducibility.
393
  num_inference_steps (`int`, defaults to `50`):
@@ -397,31 +250,34 @@ class JiTPipeline(DiffusionPipeline):
397
  return_dict (`bool`, *optional*, defaults to `True`):
398
  Return [`ImagePipelineOutput`] if True.
399
  """
400
- solver = sampling_method or self.scheduler.config.solver
401
- if solver not in {"heun", "euler"}:
402
- raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
403
  if num_inference_steps < 2:
404
  raise ValueError("num_inference_steps must be >= 2.")
405
 
406
- if t_eps is not None:
407
- self.scheduler.register_to_config(t_eps=t_eps)
408
-
409
  class_label_ids = self._normalize_class_labels(class_labels)
410
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
411
 
412
  batch_size = len(class_label_ids)
413
  image_size = int(self.transformer.config.sample_size)
 
 
 
 
 
 
 
 
 
414
  channels = int(self.transformer.config.in_channels)
415
  null_class_val = int(self.transformer.config.num_classes)
416
 
417
  if guidance_scale is None:
418
  guidance_scale = 1.0
419
  if noise_scale is None:
420
- noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
421
 
422
  latents = (
423
  randn_tensor(
424
- shape=(batch_size, channels, image_size, image_size),
425
  generator=generator,
426
  device=self._execution_device,
427
  dtype=self.transformer.dtype,
@@ -433,17 +289,47 @@ class JiTPipeline(DiffusionPipeline):
433
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
434
  class_null = torch.full_like(class_labels_t, null_class_val)
435
 
436
- latents = self._run_sampler(
437
- latents,
438
- class_labels_t,
439
- class_null,
440
- num_inference_steps,
441
- do_classifier_free_guidance,
442
- guidance_scale,
443
- guidance_interval_min,
444
- guidance_interval_max,
445
- solver,
446
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
449
  if output_type == "pt":
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import importlib
16
  import json
17
  import sys
 
21
  import torch
22
 
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
+ from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
 
 
38
  Parameters:
39
  transformer ([`JiTTransformer2DModel`]):
40
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
41
+ scheduler ([`KarrasDiffusionSchedulers`] or [`FlowMatchHeunDiscreteScheduler`]):
42
+ Diffusers scheduler interface for JiT generation (defaults to `FlowMatchHeunDiscreteScheduler(shift=4.0)`).
43
  id2label (`dict[int, str]`, *optional*):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
 
 
45
  """
46
 
47
  model_cpu_offload_seq = "transformer"
 
68
 
69
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
  if subfolder:
71
+ hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
  else:
 
79
  if subfolder:
80
  variant = variant / subfolder
81
 
82
+ id2label_override = kwargs.pop("id2label", None)
83
  model_kwargs = dict(kwargs)
84
  inserted: List[str] = []
85
 
 
101
 
102
  try:
103
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
+ try:
105
+ scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
+ except Exception:
107
+ scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
 
109
  if transformer is None:
110
  raise ValueError(f"No loadable transformer found under {variant}")
111
 
112
  variant_path = str(variant)
113
+ model_index_path = variant / "model_index.json"
114
+ id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
 
116
  pipe = cls(
117
  transformer=transformer,
118
  scheduler=scheduler,
119
  id2label=id2label,
 
120
  )
121
  if variant_path and hasattr(pipe, "register_to_config"):
122
  pipe.register_to_config(_name_or_path=variant_path)
 
129
  def __init__(
130
  self,
131
  transformer,
132
+ scheduler: FlowMatchHeunDiscreteScheduler,
133
+ id2label: Optional[Dict[Union[int, str], str]] = None,
 
134
  ):
135
  super().__init__()
136
+ scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
 
139
+ self._id2label = self._normalize_id2label(id2label)
 
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
144
+ if not id2label:
145
+ return {}
146
+ return {int(key): value for key, value in id2label.items()}
 
 
147
 
148
  @staticmethod
149
+ def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
150
+ if not model_index_path.exists():
151
+ return {}
152
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
153
+ id2label = raw.get("id2label")
154
+ if not isinstance(id2label, dict):
155
+ return {}
156
+ return {int(key): value for key, value in id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  @staticmethod
159
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
  """ImageNet class id to English label string (comma-separated synonyms)."""
 
171
  return self._id2label
172
 
173
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
 
 
 
 
 
 
174
  r"""
175
  Map ImageNet label strings to class ids.
176
 
177
  Args:
178
  label (`str` or `list[str]`):
179
+ One or more English label strings. Each string must match a synonym in `id2label`.
 
 
 
180
  """
181
+ label2id = self.labels
 
 
 
 
182
  if not label2id:
183
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
 
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
  raise ValueError(
192
+ f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
  )
194
  return [label2id[item] for item in label]
195
 
 
204
  return self.get_label_ids(class_labels)
205
 
206
  if class_labels and isinstance(class_labels[0], str):
207
+ return self.get_label_ids(class_labels)
 
 
 
 
 
 
 
 
208
 
209
  return list(class_labels)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  @torch.inference_mode()
212
  def __call__(
213
  self,
 
216
  guidance_interval_min: float = 0.1,
217
  guidance_interval_max: float = 1.0,
218
  noise_scale: Optional[float] = None,
219
+ t_eps: float = 5e-2,
 
220
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
221
  num_inference_steps: int = 50,
222
+ height: Optional[int] = None,
223
+ width: Optional[int] = None,
224
+ interpolate_pos_encoding: bool = True,
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
 
230
 
231
  Args:
232
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
+ ImageNet class indices or human-readable English label strings.
234
  guidance_scale (`float`, *optional*):
235
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
  guidance_interval_min (`float`, defaults to `0.1`):
 
239
  Upper bound of the CFG interval in flow time.
240
  noise_scale (`float`, *optional*):
241
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
+ t_eps (`float`, defaults to `5e-2`):
243
+ Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
 
 
244
  generator (`torch.Generator`, *optional*):
245
  RNG for reproducibility.
246
  num_inference_steps (`int`, defaults to `50`):
 
250
  return_dict (`bool`, *optional*, defaults to `True`):
251
  Return [`ImagePipelineOutput`] if True.
252
  """
 
 
 
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
255
 
 
 
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
258
 
259
  batch_size = len(class_label_ids)
260
  image_size = int(self.transformer.config.sample_size)
261
+ patch_size = int(self.transformer.config.patch_size)
262
+ height = int(height or image_size)
263
+ width = int(width or image_size)
264
+ if height <= 0 or width <= 0:
265
+ raise ValueError("height and width must be positive integers.")
266
+ if height % patch_size != 0 or width % patch_size != 0:
267
+ raise ValueError(
268
+ f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
+ )
270
  channels = int(self.transformer.config.in_channels)
271
  null_class_val = int(self.transformer.config.num_classes)
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
+ noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
  latents = (
279
  randn_tensor(
280
+ shape=(batch_size, channels, height, width),
281
  generator=generator,
282
  device=self._execution_device,
283
  dtype=self.transformer.dtype,
 
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
290
  class_null = torch.full_like(class_labels_t, null_class_val)
291
 
292
+ if do_classifier_free_guidance:
293
+ class_labels_input = torch.cat([class_labels_t, class_null], dim=0)
294
+ else:
295
+ class_labels_input = class_labels_t
296
+
297
+ self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
298
+ for t in self.progress_bar(self.scheduler.timesteps):
299
+ step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
+ sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
301
+ sigma = sigma.clamp_min(t_eps)
302
+ t_flow = (1.0 - sigma).clamp(0.0, 1.0)
303
+
304
+ if do_classifier_free_guidance:
305
+ latent_model_input = torch.cat([latents, latents], dim=0)
306
+ else:
307
+ latent_model_input = latents
308
+
309
+ timesteps = t_flow.flatten().expand(latent_model_input.shape[0])
310
+ x_pred = self.transformer(
311
+ latent_model_input,
312
+ timestep=timesteps,
313
+ class_labels=class_labels_input,
314
+ interpolate_pos_encoding=interpolate_pos_encoding,
315
+ ).sample
316
+
317
+ if do_classifier_free_guidance:
318
+ x_cond, x_uncond = x_pred.chunk(2, dim=0)
319
+ interval_mask = t_flow < guidance_interval_max
320
+ if guidance_interval_min != 0.0:
321
+ interval_mask = interval_mask & (t_flow > guidance_interval_min)
322
+ scale = torch.where(
323
+ interval_mask,
324
+ torch.tensor(guidance_scale, device=latents.device, dtype=latents.dtype),
325
+ torch.tensor(1.0, device=latents.device, dtype=latents.dtype),
326
+ )
327
+ x_pred = x_uncond + scale * (x_cond - x_uncond)
328
+
329
+ sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
+ # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
+ model_output = -(x_pred - latents) / sigma
332
+ latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
JiT-H-32/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,6 @@
1
  {
2
- "_class_name": "JiTScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
- "t_eps": 0.05,
6
- "solver": "heun"
7
  }
 
1
  {
2
+ "_class_name": "FlowMatchHeunDiscreteScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
+ "shift": 4.0
 
6
  }
JiT-H-32/transformer/jit_transformer_2d.py CHANGED
@@ -68,38 +68,58 @@ class JiTRotaryEmbedding(nn.Module):
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
- if custom_freqs is not None:
72
- freqs = custom_freqs
73
- else:
74
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
75
-
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
79
-
80
- freqs = torch.einsum("..., f -> ... f", t, freqs)
81
- freqs = freqs.repeat_interleave(2, dim=-1)
82
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
83
-
84
- if num_cls_token > 0:
85
- freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
86
- cos_img = freqs_flat.cos()
87
- sin_img = freqs_flat.sin()
88
-
89
- # prepend in-context cls token
90
- _, D = cos_img.shape
91
- cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
92
- sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
93
-
94
- self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
95
- self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
96
  else:
97
- self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
98
- self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
99
-
100
- def forward(self, t):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
102
  seq_len = t.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
103
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
104
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
105
 
@@ -195,7 +215,7 @@ class JiTAttention(nn.Module):
195
  self.proj = nn.Linear(dim, dim)
196
  self.proj_drop = nn.Dropout(proj_drop)
197
 
198
- def forward(self, x, rope=None):
199
  B, N, C = x.shape
200
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
201
  q, k, v = qkv[0], qkv[1], qkv[2]
@@ -206,8 +226,8 @@ class JiTAttention(nn.Module):
206
  if rope is not None:
207
  q = q.transpose(1, 2)
208
  k = k.transpose(1, 2)
209
- q = rope(q)
210
- k = rope(k)
211
  q = q.transpose(1, 2)
212
  k = k.transpose(1, 2)
213
 
@@ -254,7 +274,7 @@ class JiTBlock(nn.Module):
254
  self.act = nn.SiLU()
255
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
256
 
257
- def forward(self, x, c, feat_rope=None):
258
  # Apply activation
259
  c = self.act(c)
260
 
@@ -263,7 +283,7 @@ class JiTBlock(nn.Module):
263
  # Attention block
264
  norm_x = self.norm1(x)
265
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
266
- attn_out = self.attn(modulated_x, rope=feat_rope)
267
  x = x + gate_msa.unsqueeze(1) * attn_out
268
 
269
  # MLP block
@@ -437,11 +457,30 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
437
  self.act_final = nn.SiLU()
438
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  def forward(
441
  self,
442
  hidden_states: torch.Tensor,
443
  timestep: torch.LongTensor,
444
  class_labels: torch.LongTensor,
 
445
  return_dict: bool = True,
446
  ):
447
 
@@ -454,8 +493,19 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
454
  c = t_emb + y_emb
455
 
456
  # Patch Embed
 
457
  x = self.x_embedder(hidden_states)
458
- x = x + self.pos_embed.to(x.dtype)
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Blocks
461
  for i, block in enumerate(self.blocks):
@@ -467,15 +517,23 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
467
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
468
 
469
  if self.training and self.gradient_checkpointing:
 
 
 
 
 
 
 
 
 
470
  x = torch.utils.checkpoint.checkpoint(
471
- block,
472
  x,
473
  c,
474
- rope,
475
  use_reentrant=False,
476
  )
477
  else:
478
- x = block(x, c, feat_rope=rope)
479
 
480
  # Slice off in-context tokens
481
  if self.in_context_len > 0:
@@ -489,10 +547,11 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
489
  x = self.linear_final(x)
490
 
491
  # Unpatchify
492
- h = w = int(x.shape[1] ** 0.5)
493
- x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
494
  x = torch.einsum("nhwpqc->nchpwq", x)
495
- output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
 
 
496
 
497
  if not return_dict:
498
  return (output,)
 
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
+ self.dim = dim
72
+ self.pt_seq_len = pt_seq_len
73
+ self.theta = theta
74
+ self.num_cls_token = num_cls_token
75
+ self.custom_freqs = custom_freqs
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
+ self._cached_hw = None
79
+ cos, sin = self._build_freqs(ft_seq_len, ft_seq_len, device=torch.device("cpu"))
80
+ self.register_buffer("freqs_cos", cos, persistent=False)
81
+ self.register_buffer("freqs_sin", sin, persistent=False)
82
+ self._cached_hw = (ft_seq_len, ft_seq_len)
83
+
84
+ def _build_freqs(self, height, width, device):
85
+ if self.custom_freqs is not None:
86
+ freqs = self.custom_freqs.to(device=device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
87
  else:
88
+ freqs = 1.0 / (
89
+ self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)[: (self.dim // 2)] / self.dim)
90
+ )
91
+
92
+ t_h = torch.arange(height, device=device, dtype=torch.float32) / height * self.pt_seq_len
93
+ t_w = torch.arange(width, device=device, dtype=torch.float32) / width * self.pt_seq_len
94
+ freqs_h = torch.einsum("..., f -> ... f", t_h, freqs).repeat_interleave(2, dim=-1)
95
+ freqs_w = torch.einsum("..., f -> ... f", t_w, freqs).repeat_interleave(2, dim=-1)
96
+ freqs_2d = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
97
+ freqs_flat = freqs_2d.view(-1, freqs_2d.shape[-1])
98
+ cos_img = freqs_flat.cos()
99
+ sin_img = freqs_flat.sin()
100
+ if self.num_cls_token > 0:
101
+ _, dim_freq = cos_img.shape
102
+ cos_pad = torch.ones(self.num_cls_token, dim_freq, dtype=cos_img.dtype, device=device)
103
+ sin_pad = torch.zeros(self.num_cls_token, dim_freq, dtype=sin_img.dtype, device=device)
104
+ cos_img = torch.cat([cos_pad, cos_img], dim=0)
105
+ sin_img = torch.cat([sin_pad, sin_img], dim=0)
106
+ return cos_img, sin_img
107
+
108
+ def forward(self, t, height=None, width=None):
109
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
110
  seq_len = t.shape[1]
111
+ if height is None or width is None:
112
+ image_tokens = seq_len - self.num_cls_token
113
+ size = int(image_tokens**0.5)
114
+ if size * size != image_tokens:
115
+ raise ValueError(
116
+ f"Cannot infer square token grid from sequence length {seq_len} with {self.num_cls_token} class tokens."
117
+ )
118
+ height = size
119
+ width = size
120
+ if self._cached_hw != (height, width) or self.freqs_cos.device != t.device:
121
+ self.freqs_cos, self.freqs_sin = self._build_freqs(height, width, device=t.device)
122
+ self._cached_hw = (height, width)
123
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
124
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
125
 
 
215
  self.proj = nn.Linear(dim, dim)
216
  self.proj_drop = nn.Dropout(proj_drop)
217
 
218
+ def forward(self, x, rope=None, grid_height=None, grid_width=None):
219
  B, N, C = x.shape
220
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
221
  q, k, v = qkv[0], qkv[1], qkv[2]
 
226
  if rope is not None:
227
  q = q.transpose(1, 2)
228
  k = k.transpose(1, 2)
229
+ q = rope(q, height=grid_height, width=grid_width)
230
+ k = rope(k, height=grid_height, width=grid_width)
231
  q = q.transpose(1, 2)
232
  k = k.transpose(1, 2)
233
 
 
274
  self.act = nn.SiLU()
275
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
276
 
277
+ def forward(self, x, c, feat_rope=None, grid_height=None, grid_width=None):
278
  # Apply activation
279
  c = self.act(c)
280
 
 
283
  # Attention block
284
  norm_x = self.norm1(x)
285
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
286
+ attn_out = self.attn(modulated_x, rope=feat_rope, grid_height=grid_height, grid_width=grid_width)
287
  x = x + gate_msa.unsqueeze(1) * attn_out
288
 
289
  # MLP block
 
457
  self.act_final = nn.SiLU()
458
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
459
 
460
+ def _get_patch_grid(self, hidden_states):
461
+ height, width = hidden_states.shape[-2:]
462
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
463
+ raise ValueError(
464
+ f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
465
+ )
466
+ return height // self.patch_size, width // self.patch_size
467
+
468
+ def _interpolate_pos_encoding(self, tokens, grid_height, grid_width):
469
+ num_tokens = grid_height * grid_width
470
+ if self.pos_embed.shape[1] == num_tokens:
471
+ return self.pos_embed.to(device=tokens.device, dtype=tokens.dtype)
472
+ base_size = int(self.pos_embed.shape[1] ** 0.5)
473
+ pos_embed = self.pos_embed.reshape(1, base_size, base_size, self.hidden_size).permute(0, 3, 1, 2)
474
+ pos_embed = F.interpolate(pos_embed, size=(grid_height, grid_width), mode="bicubic", align_corners=False)
475
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, num_tokens, self.hidden_size)
476
+ return pos_embed.to(device=tokens.device, dtype=tokens.dtype)
477
+
478
  def forward(
479
  self,
480
  hidden_states: torch.Tensor,
481
  timestep: torch.LongTensor,
482
  class_labels: torch.LongTensor,
483
+ interpolate_pos_encoding: bool = True,
484
  return_dict: bool = True,
485
  ):
486
 
 
493
  c = t_emb + y_emb
494
 
495
  # Patch Embed
496
+ grid_height, grid_width = self._get_patch_grid(hidden_states)
497
  x = self.x_embedder(hidden_states)
498
+ if interpolate_pos_encoding:
499
+ pos_embed = self._interpolate_pos_encoding(x, grid_height, grid_width)
500
+ else:
501
+ expected_tokens = grid_height * grid_width
502
+ if self.pos_embed.shape[1] != expected_tokens:
503
+ raise ValueError(
504
+ f"pos_embed token count {self.pos_embed.shape[1]} does not match input token count {expected_tokens}. "
505
+ "Enable interpolate_pos_encoding for dynamic resolutions."
506
+ )
507
+ pos_embed = self.pos_embed.to(device=x.device, dtype=x.dtype)
508
+ x = x + pos_embed
509
 
510
  # Blocks
511
  for i, block in enumerate(self.blocks):
 
517
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
518
 
519
  if self.training and self.gradient_checkpointing:
520
+ def custom_forward(current_x, current_c):
521
+ return block(
522
+ current_x,
523
+ current_c,
524
+ feat_rope=rope,
525
+ grid_height=grid_height,
526
+ grid_width=grid_width,
527
+ )
528
+
529
  x = torch.utils.checkpoint.checkpoint(
530
+ custom_forward,
531
  x,
532
  c,
 
533
  use_reentrant=False,
534
  )
535
  else:
536
+ x = block(x, c, feat_rope=rope, grid_height=grid_height, grid_width=grid_width)
537
 
538
  # Slice off in-context tokens
539
  if self.in_context_len > 0:
 
547
  x = self.linear_final(x)
548
 
549
  # Unpatchify
550
+ x = x.reshape(shape=(x.shape[0], grid_height, grid_width, self.patch_size, self.patch_size, self.out_channels))
 
551
  x = torch.einsum("nhwpqc->nchpwq", x)
552
+ output = x.reshape(
553
+ shape=(x.shape[0], self.out_channels, grid_height * self.patch_size, grid_width * self.patch_size)
554
+ )
555
 
556
  if not return_dict:
557
  return (output,)
JiT-L-16/model_index.json CHANGED
@@ -5,11 +5,1013 @@
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
- "scheduling_jit",
9
- "JiTScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
 
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchHeunDiscreteScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
  }
JiT-L-16/pipeline.py CHANGED
@@ -12,8 +12,6 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from __future__ import annotations
16
-
17
  import importlib
18
  import json
19
  import sys
@@ -23,6 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
23
  import torch
24
 
25
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
 
@@ -39,12 +38,10 @@ class JiTPipeline(DiffusionPipeline):
39
  Parameters:
40
  transformer ([`JiTTransformer2DModel`]):
41
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
- scheduler ([`JiTScheduler`]):
43
- Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
  id2label (`dict[int, str]`, *optional*):
45
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
- id2label_cn (`dict[int, str]`, *optional*):
47
- ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
  """
49
 
50
  model_cpu_offload_seq = "transformer"
@@ -71,7 +68,7 @@ class JiTPipeline(DiffusionPipeline):
71
 
72
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
73
  if subfolder:
74
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
75
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
76
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
77
  else:
@@ -82,6 +79,7 @@ class JiTPipeline(DiffusionPipeline):
82
  if subfolder:
83
  variant = variant / subfolder
84
 
 
85
  model_kwargs = dict(kwargs)
86
  inserted: List[str] = []
87
 
@@ -103,19 +101,22 @@ class JiTPipeline(DiffusionPipeline):
103
 
104
  try:
105
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
106
- scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
 
 
 
107
 
108
  if transformer is None:
109
  raise ValueError(f"No loadable transformer found under {variant}")
110
 
111
  variant_path = str(variant)
112
- id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
 
113
 
114
  pipe = cls(
115
  transformer=transformer,
116
  scheduler=scheduler,
117
  id2label=id2label,
118
- id2label_cn=id2label_cn,
119
  )
120
  if variant_path and hasattr(pipe, "register_to_config"):
121
  pipe.register_to_config(_name_or_path=variant_path)
@@ -128,58 +129,31 @@ class JiTPipeline(DiffusionPipeline):
128
  def __init__(
129
  self,
130
  transformer,
131
- scheduler,
132
- id2label: Optional[Dict[int, str]] = None,
133
- id2label_cn: Optional[Dict[int, str]] = None,
134
  ):
135
  super().__init__()
 
136
  self.register_modules(transformer=transformer, scheduler=scheduler)
137
 
138
- self._id2label = id2label or {}
139
- self._id2label_cn = id2label_cn or {}
140
  self.labels = self._build_label2id(self._id2label)
141
- self.labels_cn = self._build_label2id(self._id2label_cn)
142
-
143
- def _ensure_labels_loaded(self) -> None:
144
- if self._id2label or self._id2label_cn:
145
- return
146
- loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
- if loaded_en:
148
- self._id2label = loaded_en
149
- self.labels = self._build_label2id(self._id2label)
150
- if loaded_cn:
151
- self._id2label_cn = loaded_cn
152
- self.labels_cn = self._build_label2id(self._id2label_cn)
153
 
154
  @staticmethod
155
- def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
- if not variant_path:
157
- return None
158
- variant_dir = Path(variant_path).resolve()
159
- labels_dir = variant_dir.parent / "labels"
160
- return labels_dir if labels_dir.is_dir() else None
161
 
162
  @staticmethod
163
- def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
- filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
- path = labels_dir / filename
166
- if not path.exists():
167
- raise FileNotFoundError(path)
168
- raw = json.loads(path.read_text(encoding="utf-8"))
169
- return {int(key): value for key, value in raw.items()}
170
-
171
- @classmethod
172
- def _load_labels_for_variant(
173
- cls,
174
- variant_path: Optional[str],
175
- ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
- labels_dir = cls._labels_dir_for_variant(variant_path)
177
- if labels_dir is None:
178
- return None, None
179
- try:
180
- return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
- except FileNotFoundError:
182
- return None, None
183
 
184
  @staticmethod
185
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
@@ -194,35 +168,19 @@ class JiTPipeline(DiffusionPipeline):
194
  @property
195
  def id2label(self) -> Dict[int, str]:
196
  """ImageNet class id to English label string (comma-separated synonyms)."""
197
- self._ensure_labels_loaded()
198
  return self._id2label
199
 
200
- @property
201
- def id2label_cn(self) -> Dict[int, str]:
202
- """ImageNet class id to Chinese label string (comma-separated synonyms)."""
203
- self._ensure_labels_loaded()
204
- return self._id2label_cn
205
-
206
- def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
207
  r"""
208
  Map ImageNet label strings to class ids.
209
 
210
  Args:
211
  label (`str` or `list[str]`):
212
- One or more label strings. Each string must match a synonym in `id2label` (English)
213
- or `id2label_cn` (Chinese).
214
- lang (`str`, *optional*, defaults to `"en"`):
215
- `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
216
  """
217
- if lang not in ("en", "cn"):
218
- raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
219
-
220
- self._ensure_labels_loaded()
221
- label2id = self.labels if lang == "en" else self.labels_cn
222
  if not label2id:
223
- raise ValueError(
224
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
225
- )
226
 
227
  if isinstance(label, str):
228
  label = [label]
@@ -231,7 +189,7 @@ class JiTPipeline(DiffusionPipeline):
231
  if missing:
232
  preview = ", ".join(list(label2id.keys())[:8])
233
  raise ValueError(
234
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
235
  )
236
  return [label2id[item] for item in label]
237
 
@@ -246,115 +204,10 @@ class JiTPipeline(DiffusionPipeline):
246
  return self.get_label_ids(class_labels)
247
 
248
  if class_labels and isinstance(class_labels[0], str):
249
- self._ensure_labels_loaded()
250
- if all(label in self.labels for label in class_labels):
251
- return self.get_label_ids(class_labels, lang="en")
252
- if all(label in self.labels_cn for label in class_labels):
253
- return self.get_label_ids(class_labels, lang="cn")
254
- raise ValueError(
255
- "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
- "or Chinese synonyms from `pipe.labels_cn`."
257
- )
258
 
259
  return list(class_labels)
260
 
261
- def _predict_velocity(
262
- self,
263
- z_value: torch.Tensor,
264
- t: torch.Tensor,
265
- class_labels: torch.Tensor,
266
- class_null: torch.Tensor,
267
- do_classifier_free_guidance: bool,
268
- guidance_scale: float,
269
- guidance_interval_min: float,
270
- guidance_interval_max: float,
271
- ) -> torch.Tensor:
272
- t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
- if do_classifier_free_guidance:
274
- z_in = torch.cat([z_value, z_value], dim=0)
275
- labels = torch.cat([class_labels, class_null], dim=0)
276
- else:
277
- z_in = z_value
278
- labels = class_labels
279
-
280
- t_batch = t.flatten().expand(z_in.shape[0])
281
- x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
- v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
-
284
- if not do_classifier_free_guidance:
285
- return v
286
-
287
- v_cond, v_uncond = v.chunk(2, dim=0)
288
- interval_mask = t < guidance_interval_max
289
- if guidance_interval_min != 0.0:
290
- interval_mask = interval_mask & (t > guidance_interval_min)
291
- scale = torch.where(
292
- interval_mask,
293
- torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
- torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
- )
296
- return v_uncond + scale * (v_cond - v_uncond)
297
-
298
- def _run_sampler(
299
- self,
300
- latents: torch.Tensor,
301
- class_labels: torch.Tensor,
302
- class_null: torch.Tensor,
303
- num_inference_steps: int,
304
- do_classifier_free_guidance: bool,
305
- guidance_scale: float,
306
- guidance_interval_min: float,
307
- guidance_interval_max: float,
308
- sampling_method: str,
309
- ) -> torch.Tensor:
310
- device = latents.device
311
- self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
312
- timesteps = self.scheduler.timesteps
313
-
314
- for i in self.progress_bar(range(num_inference_steps - 1)):
315
- t = timesteps[i]
316
- t_next = timesteps[i + 1]
317
- v = self._predict_velocity(
318
- latents,
319
- t,
320
- class_labels,
321
- class_null,
322
- do_classifier_free_guidance,
323
- guidance_scale,
324
- guidance_interval_min,
325
- guidance_interval_max,
326
- )
327
-
328
- if sampling_method == "heun":
329
- latents_euler = latents + (t_next - t) * v
330
- v_next = self._predict_velocity(
331
- latents_euler,
332
- t_next,
333
- class_labels,
334
- class_null,
335
- do_classifier_free_guidance,
336
- guidance_scale,
337
- guidance_interval_min,
338
- guidance_interval_max,
339
- )
340
- latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
341
- else:
342
- latents = self.scheduler.step(v, t, latents).prev_sample
343
-
344
- t = timesteps[-2]
345
- t_next = timesteps[-1]
346
- v = self._predict_velocity(
347
- latents,
348
- t,
349
- class_labels,
350
- class_null,
351
- do_classifier_free_guidance,
352
- guidance_scale,
353
- guidance_interval_min,
354
- guidance_interval_max,
355
- )
356
- return latents + (t_next - t) * v
357
-
358
  @torch.inference_mode()
359
  def __call__(
360
  self,
@@ -363,10 +216,12 @@ class JiTPipeline(DiffusionPipeline):
363
  guidance_interval_min: float = 0.1,
364
  guidance_interval_max: float = 1.0,
365
  noise_scale: Optional[float] = None,
366
- t_eps: Optional[float] = None,
367
- sampling_method: Optional[str] = None,
368
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
369
  num_inference_steps: int = 50,
 
 
 
370
  output_type: Optional[str] = "pil",
371
  return_dict: bool = True,
372
  ) -> Union[ImagePipelineOutput, Tuple]:
@@ -375,7 +230,7 @@ class JiTPipeline(DiffusionPipeline):
375
 
376
  Args:
377
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
378
- ImageNet class indices or human-readable label strings (English or Chinese).
379
  guidance_scale (`float`, *optional*):
380
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
381
  guidance_interval_min (`float`, defaults to `0.1`):
@@ -384,10 +239,8 @@ class JiTPipeline(DiffusionPipeline):
384
  Upper bound of the CFG interval in flow time.
385
  noise_scale (`float`, *optional*):
386
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
387
- t_eps (`float`, *optional*):
388
- Epsilon clamp for the `1 - t` denominator (scheduler config by default).
389
- sampling_method (`str`, *optional*):
390
- `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
391
  generator (`torch.Generator`, *optional*):
392
  RNG for reproducibility.
393
  num_inference_steps (`int`, defaults to `50`):
@@ -397,31 +250,34 @@ class JiTPipeline(DiffusionPipeline):
397
  return_dict (`bool`, *optional*, defaults to `True`):
398
  Return [`ImagePipelineOutput`] if True.
399
  """
400
- solver = sampling_method or self.scheduler.config.solver
401
- if solver not in {"heun", "euler"}:
402
- raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
403
  if num_inference_steps < 2:
404
  raise ValueError("num_inference_steps must be >= 2.")
405
 
406
- if t_eps is not None:
407
- self.scheduler.register_to_config(t_eps=t_eps)
408
-
409
  class_label_ids = self._normalize_class_labels(class_labels)
410
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
411
 
412
  batch_size = len(class_label_ids)
413
  image_size = int(self.transformer.config.sample_size)
 
 
 
 
 
 
 
 
 
414
  channels = int(self.transformer.config.in_channels)
415
  null_class_val = int(self.transformer.config.num_classes)
416
 
417
  if guidance_scale is None:
418
  guidance_scale = 1.0
419
  if noise_scale is None:
420
- noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
421
 
422
  latents = (
423
  randn_tensor(
424
- shape=(batch_size, channels, image_size, image_size),
425
  generator=generator,
426
  device=self._execution_device,
427
  dtype=self.transformer.dtype,
@@ -433,17 +289,47 @@ class JiTPipeline(DiffusionPipeline):
433
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
434
  class_null = torch.full_like(class_labels_t, null_class_val)
435
 
436
- latents = self._run_sampler(
437
- latents,
438
- class_labels_t,
439
- class_null,
440
- num_inference_steps,
441
- do_classifier_free_guidance,
442
- guidance_scale,
443
- guidance_interval_min,
444
- guidance_interval_max,
445
- solver,
446
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
449
  if output_type == "pt":
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import importlib
16
  import json
17
  import sys
 
21
  import torch
22
 
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
+ from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
 
 
38
  Parameters:
39
  transformer ([`JiTTransformer2DModel`]):
40
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
41
+ scheduler ([`KarrasDiffusionSchedulers`] or [`FlowMatchHeunDiscreteScheduler`]):
42
+ Diffusers scheduler interface for JiT generation (defaults to `FlowMatchHeunDiscreteScheduler(shift=4.0)`).
43
  id2label (`dict[int, str]`, *optional*):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
 
 
45
  """
46
 
47
  model_cpu_offload_seq = "transformer"
 
68
 
69
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
  if subfolder:
71
+ hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
  else:
 
79
  if subfolder:
80
  variant = variant / subfolder
81
 
82
+ id2label_override = kwargs.pop("id2label", None)
83
  model_kwargs = dict(kwargs)
84
  inserted: List[str] = []
85
 
 
101
 
102
  try:
103
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
+ try:
105
+ scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
+ except Exception:
107
+ scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
 
109
  if transformer is None:
110
  raise ValueError(f"No loadable transformer found under {variant}")
111
 
112
  variant_path = str(variant)
113
+ model_index_path = variant / "model_index.json"
114
+ id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
 
116
  pipe = cls(
117
  transformer=transformer,
118
  scheduler=scheduler,
119
  id2label=id2label,
 
120
  )
121
  if variant_path and hasattr(pipe, "register_to_config"):
122
  pipe.register_to_config(_name_or_path=variant_path)
 
129
  def __init__(
130
  self,
131
  transformer,
132
+ scheduler: FlowMatchHeunDiscreteScheduler,
133
+ id2label: Optional[Dict[Union[int, str], str]] = None,
 
134
  ):
135
  super().__init__()
136
+ scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
 
139
+ self._id2label = self._normalize_id2label(id2label)
 
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
144
+ if not id2label:
145
+ return {}
146
+ return {int(key): value for key, value in id2label.items()}
 
 
147
 
148
  @staticmethod
149
+ def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
150
+ if not model_index_path.exists():
151
+ return {}
152
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
153
+ id2label = raw.get("id2label")
154
+ if not isinstance(id2label, dict):
155
+ return {}
156
+ return {int(key): value for key, value in id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  @staticmethod
159
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
  """ImageNet class id to English label string (comma-separated synonyms)."""
 
171
  return self._id2label
172
 
173
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
 
 
 
 
 
 
174
  r"""
175
  Map ImageNet label strings to class ids.
176
 
177
  Args:
178
  label (`str` or `list[str]`):
179
+ One or more English label strings. Each string must match a synonym in `id2label`.
 
 
 
180
  """
181
+ label2id = self.labels
 
 
 
 
182
  if not label2id:
183
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
 
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
  raise ValueError(
192
+ f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
  )
194
  return [label2id[item] for item in label]
195
 
 
204
  return self.get_label_ids(class_labels)
205
 
206
  if class_labels and isinstance(class_labels[0], str):
207
+ return self.get_label_ids(class_labels)
 
 
 
 
 
 
 
 
208
 
209
  return list(class_labels)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  @torch.inference_mode()
212
  def __call__(
213
  self,
 
216
  guidance_interval_min: float = 0.1,
217
  guidance_interval_max: float = 1.0,
218
  noise_scale: Optional[float] = None,
219
+ t_eps: float = 5e-2,
 
220
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
221
  num_inference_steps: int = 50,
222
+ height: Optional[int] = None,
223
+ width: Optional[int] = None,
224
+ interpolate_pos_encoding: bool = True,
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
 
230
 
231
  Args:
232
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
+ ImageNet class indices or human-readable English label strings.
234
  guidance_scale (`float`, *optional*):
235
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
  guidance_interval_min (`float`, defaults to `0.1`):
 
239
  Upper bound of the CFG interval in flow time.
240
  noise_scale (`float`, *optional*):
241
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
+ t_eps (`float`, defaults to `5e-2`):
243
+ Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
 
 
244
  generator (`torch.Generator`, *optional*):
245
  RNG for reproducibility.
246
  num_inference_steps (`int`, defaults to `50`):
 
250
  return_dict (`bool`, *optional*, defaults to `True`):
251
  Return [`ImagePipelineOutput`] if True.
252
  """
 
 
 
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
255
 
 
 
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
258
 
259
  batch_size = len(class_label_ids)
260
  image_size = int(self.transformer.config.sample_size)
261
+ patch_size = int(self.transformer.config.patch_size)
262
+ height = int(height or image_size)
263
+ width = int(width or image_size)
264
+ if height <= 0 or width <= 0:
265
+ raise ValueError("height and width must be positive integers.")
266
+ if height % patch_size != 0 or width % patch_size != 0:
267
+ raise ValueError(
268
+ f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
+ )
270
  channels = int(self.transformer.config.in_channels)
271
  null_class_val = int(self.transformer.config.num_classes)
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
+ noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
  latents = (
279
  randn_tensor(
280
+ shape=(batch_size, channels, height, width),
281
  generator=generator,
282
  device=self._execution_device,
283
  dtype=self.transformer.dtype,
 
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
290
  class_null = torch.full_like(class_labels_t, null_class_val)
291
 
292
+ if do_classifier_free_guidance:
293
+ class_labels_input = torch.cat([class_labels_t, class_null], dim=0)
294
+ else:
295
+ class_labels_input = class_labels_t
296
+
297
+ self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
298
+ for t in self.progress_bar(self.scheduler.timesteps):
299
+ step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
+ sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
301
+ sigma = sigma.clamp_min(t_eps)
302
+ t_flow = (1.0 - sigma).clamp(0.0, 1.0)
303
+
304
+ if do_classifier_free_guidance:
305
+ latent_model_input = torch.cat([latents, latents], dim=0)
306
+ else:
307
+ latent_model_input = latents
308
+
309
+ timesteps = t_flow.flatten().expand(latent_model_input.shape[0])
310
+ x_pred = self.transformer(
311
+ latent_model_input,
312
+ timestep=timesteps,
313
+ class_labels=class_labels_input,
314
+ interpolate_pos_encoding=interpolate_pos_encoding,
315
+ ).sample
316
+
317
+ if do_classifier_free_guidance:
318
+ x_cond, x_uncond = x_pred.chunk(2, dim=0)
319
+ interval_mask = t_flow < guidance_interval_max
320
+ if guidance_interval_min != 0.0:
321
+ interval_mask = interval_mask & (t_flow > guidance_interval_min)
322
+ scale = torch.where(
323
+ interval_mask,
324
+ torch.tensor(guidance_scale, device=latents.device, dtype=latents.dtype),
325
+ torch.tensor(1.0, device=latents.device, dtype=latents.dtype),
326
+ )
327
+ x_pred = x_uncond + scale * (x_cond - x_uncond)
328
+
329
+ sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
+ # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
+ model_output = -(x_pred - latents) / sigma
332
+ latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
JiT-L-16/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,6 @@
1
  {
2
- "_class_name": "JiTScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
- "t_eps": 0.05,
6
- "solver": "heun"
7
  }
 
1
  {
2
+ "_class_name": "FlowMatchHeunDiscreteScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
+ "shift": 4.0
 
6
  }
JiT-L-16/transformer/jit_transformer_2d.py CHANGED
@@ -68,38 +68,58 @@ class JiTRotaryEmbedding(nn.Module):
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
- if custom_freqs is not None:
72
- freqs = custom_freqs
73
- else:
74
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
75
-
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
79
-
80
- freqs = torch.einsum("..., f -> ... f", t, freqs)
81
- freqs = freqs.repeat_interleave(2, dim=-1)
82
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
83
-
84
- if num_cls_token > 0:
85
- freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
86
- cos_img = freqs_flat.cos()
87
- sin_img = freqs_flat.sin()
88
-
89
- # prepend in-context cls token
90
- _, D = cos_img.shape
91
- cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
92
- sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
93
-
94
- self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
95
- self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
96
  else:
97
- self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
98
- self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
99
-
100
- def forward(self, t):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
102
  seq_len = t.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
103
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
104
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
105
 
@@ -195,7 +215,7 @@ class JiTAttention(nn.Module):
195
  self.proj = nn.Linear(dim, dim)
196
  self.proj_drop = nn.Dropout(proj_drop)
197
 
198
- def forward(self, x, rope=None):
199
  B, N, C = x.shape
200
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
201
  q, k, v = qkv[0], qkv[1], qkv[2]
@@ -206,8 +226,8 @@ class JiTAttention(nn.Module):
206
  if rope is not None:
207
  q = q.transpose(1, 2)
208
  k = k.transpose(1, 2)
209
- q = rope(q)
210
- k = rope(k)
211
  q = q.transpose(1, 2)
212
  k = k.transpose(1, 2)
213
 
@@ -254,7 +274,7 @@ class JiTBlock(nn.Module):
254
  self.act = nn.SiLU()
255
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
256
 
257
- def forward(self, x, c, feat_rope=None):
258
  # Apply activation
259
  c = self.act(c)
260
 
@@ -263,7 +283,7 @@ class JiTBlock(nn.Module):
263
  # Attention block
264
  norm_x = self.norm1(x)
265
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
266
- attn_out = self.attn(modulated_x, rope=feat_rope)
267
  x = x + gate_msa.unsqueeze(1) * attn_out
268
 
269
  # MLP block
@@ -437,11 +457,30 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
437
  self.act_final = nn.SiLU()
438
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  def forward(
441
  self,
442
  hidden_states: torch.Tensor,
443
  timestep: torch.LongTensor,
444
  class_labels: torch.LongTensor,
 
445
  return_dict: bool = True,
446
  ):
447
 
@@ -454,8 +493,19 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
454
  c = t_emb + y_emb
455
 
456
  # Patch Embed
 
457
  x = self.x_embedder(hidden_states)
458
- x = x + self.pos_embed.to(x.dtype)
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Blocks
461
  for i, block in enumerate(self.blocks):
@@ -467,15 +517,23 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
467
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
468
 
469
  if self.training and self.gradient_checkpointing:
 
 
 
 
 
 
 
 
 
470
  x = torch.utils.checkpoint.checkpoint(
471
- block,
472
  x,
473
  c,
474
- rope,
475
  use_reentrant=False,
476
  )
477
  else:
478
- x = block(x, c, feat_rope=rope)
479
 
480
  # Slice off in-context tokens
481
  if self.in_context_len > 0:
@@ -489,10 +547,11 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
489
  x = self.linear_final(x)
490
 
491
  # Unpatchify
492
- h = w = int(x.shape[1] ** 0.5)
493
- x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
494
  x = torch.einsum("nhwpqc->nchpwq", x)
495
- output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
 
 
496
 
497
  if not return_dict:
498
  return (output,)
 
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
+ self.dim = dim
72
+ self.pt_seq_len = pt_seq_len
73
+ self.theta = theta
74
+ self.num_cls_token = num_cls_token
75
+ self.custom_freqs = custom_freqs
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
+ self._cached_hw = None
79
+ cos, sin = self._build_freqs(ft_seq_len, ft_seq_len, device=torch.device("cpu"))
80
+ self.register_buffer("freqs_cos", cos, persistent=False)
81
+ self.register_buffer("freqs_sin", sin, persistent=False)
82
+ self._cached_hw = (ft_seq_len, ft_seq_len)
83
+
84
+ def _build_freqs(self, height, width, device):
85
+ if self.custom_freqs is not None:
86
+ freqs = self.custom_freqs.to(device=device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
87
  else:
88
+ freqs = 1.0 / (
89
+ self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)[: (self.dim // 2)] / self.dim)
90
+ )
91
+
92
+ t_h = torch.arange(height, device=device, dtype=torch.float32) / height * self.pt_seq_len
93
+ t_w = torch.arange(width, device=device, dtype=torch.float32) / width * self.pt_seq_len
94
+ freqs_h = torch.einsum("..., f -> ... f", t_h, freqs).repeat_interleave(2, dim=-1)
95
+ freqs_w = torch.einsum("..., f -> ... f", t_w, freqs).repeat_interleave(2, dim=-1)
96
+ freqs_2d = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
97
+ freqs_flat = freqs_2d.view(-1, freqs_2d.shape[-1])
98
+ cos_img = freqs_flat.cos()
99
+ sin_img = freqs_flat.sin()
100
+ if self.num_cls_token > 0:
101
+ _, dim_freq = cos_img.shape
102
+ cos_pad = torch.ones(self.num_cls_token, dim_freq, dtype=cos_img.dtype, device=device)
103
+ sin_pad = torch.zeros(self.num_cls_token, dim_freq, dtype=sin_img.dtype, device=device)
104
+ cos_img = torch.cat([cos_pad, cos_img], dim=0)
105
+ sin_img = torch.cat([sin_pad, sin_img], dim=0)
106
+ return cos_img, sin_img
107
+
108
+ def forward(self, t, height=None, width=None):
109
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
110
  seq_len = t.shape[1]
111
+ if height is None or width is None:
112
+ image_tokens = seq_len - self.num_cls_token
113
+ size = int(image_tokens**0.5)
114
+ if size * size != image_tokens:
115
+ raise ValueError(
116
+ f"Cannot infer square token grid from sequence length {seq_len} with {self.num_cls_token} class tokens."
117
+ )
118
+ height = size
119
+ width = size
120
+ if self._cached_hw != (height, width) or self.freqs_cos.device != t.device:
121
+ self.freqs_cos, self.freqs_sin = self._build_freqs(height, width, device=t.device)
122
+ self._cached_hw = (height, width)
123
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
124
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
125
 
 
215
  self.proj = nn.Linear(dim, dim)
216
  self.proj_drop = nn.Dropout(proj_drop)
217
 
218
+ def forward(self, x, rope=None, grid_height=None, grid_width=None):
219
  B, N, C = x.shape
220
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
221
  q, k, v = qkv[0], qkv[1], qkv[2]
 
226
  if rope is not None:
227
  q = q.transpose(1, 2)
228
  k = k.transpose(1, 2)
229
+ q = rope(q, height=grid_height, width=grid_width)
230
+ k = rope(k, height=grid_height, width=grid_width)
231
  q = q.transpose(1, 2)
232
  k = k.transpose(1, 2)
233
 
 
274
  self.act = nn.SiLU()
275
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
276
 
277
+ def forward(self, x, c, feat_rope=None, grid_height=None, grid_width=None):
278
  # Apply activation
279
  c = self.act(c)
280
 
 
283
  # Attention block
284
  norm_x = self.norm1(x)
285
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
286
+ attn_out = self.attn(modulated_x, rope=feat_rope, grid_height=grid_height, grid_width=grid_width)
287
  x = x + gate_msa.unsqueeze(1) * attn_out
288
 
289
  # MLP block
 
457
  self.act_final = nn.SiLU()
458
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
459
 
460
+ def _get_patch_grid(self, hidden_states):
461
+ height, width = hidden_states.shape[-2:]
462
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
463
+ raise ValueError(
464
+ f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
465
+ )
466
+ return height // self.patch_size, width // self.patch_size
467
+
468
+ def _interpolate_pos_encoding(self, tokens, grid_height, grid_width):
469
+ num_tokens = grid_height * grid_width
470
+ if self.pos_embed.shape[1] == num_tokens:
471
+ return self.pos_embed.to(device=tokens.device, dtype=tokens.dtype)
472
+ base_size = int(self.pos_embed.shape[1] ** 0.5)
473
+ pos_embed = self.pos_embed.reshape(1, base_size, base_size, self.hidden_size).permute(0, 3, 1, 2)
474
+ pos_embed = F.interpolate(pos_embed, size=(grid_height, grid_width), mode="bicubic", align_corners=False)
475
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, num_tokens, self.hidden_size)
476
+ return pos_embed.to(device=tokens.device, dtype=tokens.dtype)
477
+
478
  def forward(
479
  self,
480
  hidden_states: torch.Tensor,
481
  timestep: torch.LongTensor,
482
  class_labels: torch.LongTensor,
483
+ interpolate_pos_encoding: bool = True,
484
  return_dict: bool = True,
485
  ):
486
 
 
493
  c = t_emb + y_emb
494
 
495
  # Patch Embed
496
+ grid_height, grid_width = self._get_patch_grid(hidden_states)
497
  x = self.x_embedder(hidden_states)
498
+ if interpolate_pos_encoding:
499
+ pos_embed = self._interpolate_pos_encoding(x, grid_height, grid_width)
500
+ else:
501
+ expected_tokens = grid_height * grid_width
502
+ if self.pos_embed.shape[1] != expected_tokens:
503
+ raise ValueError(
504
+ f"pos_embed token count {self.pos_embed.shape[1]} does not match input token count {expected_tokens}. "
505
+ "Enable interpolate_pos_encoding for dynamic resolutions."
506
+ )
507
+ pos_embed = self.pos_embed.to(device=x.device, dtype=x.dtype)
508
+ x = x + pos_embed
509
 
510
  # Blocks
511
  for i, block in enumerate(self.blocks):
 
517
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
518
 
519
  if self.training and self.gradient_checkpointing:
520
+ def custom_forward(current_x, current_c):
521
+ return block(
522
+ current_x,
523
+ current_c,
524
+ feat_rope=rope,
525
+ grid_height=grid_height,
526
+ grid_width=grid_width,
527
+ )
528
+
529
  x = torch.utils.checkpoint.checkpoint(
530
+ custom_forward,
531
  x,
532
  c,
 
533
  use_reentrant=False,
534
  )
535
  else:
536
+ x = block(x, c, feat_rope=rope, grid_height=grid_height, grid_width=grid_width)
537
 
538
  # Slice off in-context tokens
539
  if self.in_context_len > 0:
 
547
  x = self.linear_final(x)
548
 
549
  # Unpatchify
550
+ x = x.reshape(shape=(x.shape[0], grid_height, grid_width, self.patch_size, self.patch_size, self.out_channels))
 
551
  x = torch.einsum("nhwpqc->nchpwq", x)
552
+ output = x.reshape(
553
+ shape=(x.shape[0], self.out_channels, grid_height * self.patch_size, grid_width * self.patch_size)
554
+ )
555
 
556
  if not return_dict:
557
  return (output,)
JiT-L-32/model_index.json CHANGED
@@ -5,11 +5,1013 @@
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
- "scheduling_jit",
9
- "JiTScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
 
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchHeunDiscreteScheduler"
10
  ],
11
  "transformer": [
12
  "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
  }
JiT-L-32/pipeline.py CHANGED
@@ -12,8 +12,6 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from __future__ import annotations
16
-
17
  import importlib
18
  import json
19
  import sys
@@ -23,6 +21,7 @@ from typing import Dict, List, Optional, Tuple, Union
23
  import torch
24
 
25
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
26
  from diffusers.utils.torch_utils import randn_tensor
27
 
28
 
@@ -39,12 +38,10 @@ class JiTPipeline(DiffusionPipeline):
39
  Parameters:
40
  transformer ([`JiTTransformer2DModel`]):
41
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
- scheduler ([`JiTScheduler`]):
43
- Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
  id2label (`dict[int, str]`, *optional*):
45
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
- id2label_cn (`dict[int, str]`, *optional*):
47
- ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
  """
49
 
50
  model_cpu_offload_seq = "transformer"
@@ -71,7 +68,7 @@ class JiTPipeline(DiffusionPipeline):
71
 
72
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
73
  if subfolder:
74
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
75
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
76
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
77
  else:
@@ -82,6 +79,7 @@ class JiTPipeline(DiffusionPipeline):
82
  if subfolder:
83
  variant = variant / subfolder
84
 
 
85
  model_kwargs = dict(kwargs)
86
  inserted: List[str] = []
87
 
@@ -103,19 +101,22 @@ class JiTPipeline(DiffusionPipeline):
103
 
104
  try:
105
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
106
- scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
 
 
 
107
 
108
  if transformer is None:
109
  raise ValueError(f"No loadable transformer found under {variant}")
110
 
111
  variant_path = str(variant)
112
- id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
 
113
 
114
  pipe = cls(
115
  transformer=transformer,
116
  scheduler=scheduler,
117
  id2label=id2label,
118
- id2label_cn=id2label_cn,
119
  )
120
  if variant_path and hasattr(pipe, "register_to_config"):
121
  pipe.register_to_config(_name_or_path=variant_path)
@@ -128,58 +129,31 @@ class JiTPipeline(DiffusionPipeline):
128
  def __init__(
129
  self,
130
  transformer,
131
- scheduler,
132
- id2label: Optional[Dict[int, str]] = None,
133
- id2label_cn: Optional[Dict[int, str]] = None,
134
  ):
135
  super().__init__()
 
136
  self.register_modules(transformer=transformer, scheduler=scheduler)
137
 
138
- self._id2label = id2label or {}
139
- self._id2label_cn = id2label_cn or {}
140
  self.labels = self._build_label2id(self._id2label)
141
- self.labels_cn = self._build_label2id(self._id2label_cn)
142
-
143
- def _ensure_labels_loaded(self) -> None:
144
- if self._id2label or self._id2label_cn:
145
- return
146
- loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
- if loaded_en:
148
- self._id2label = loaded_en
149
- self.labels = self._build_label2id(self._id2label)
150
- if loaded_cn:
151
- self._id2label_cn = loaded_cn
152
- self.labels_cn = self._build_label2id(self._id2label_cn)
153
 
154
  @staticmethod
155
- def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
- if not variant_path:
157
- return None
158
- variant_dir = Path(variant_path).resolve()
159
- labels_dir = variant_dir.parent / "labels"
160
- return labels_dir if labels_dir.is_dir() else None
161
 
162
  @staticmethod
163
- def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
- filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
- path = labels_dir / filename
166
- if not path.exists():
167
- raise FileNotFoundError(path)
168
- raw = json.loads(path.read_text(encoding="utf-8"))
169
- return {int(key): value for key, value in raw.items()}
170
-
171
- @classmethod
172
- def _load_labels_for_variant(
173
- cls,
174
- variant_path: Optional[str],
175
- ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
- labels_dir = cls._labels_dir_for_variant(variant_path)
177
- if labels_dir is None:
178
- return None, None
179
- try:
180
- return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
- except FileNotFoundError:
182
- return None, None
183
 
184
  @staticmethod
185
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
@@ -194,35 +168,19 @@ class JiTPipeline(DiffusionPipeline):
194
  @property
195
  def id2label(self) -> Dict[int, str]:
196
  """ImageNet class id to English label string (comma-separated synonyms)."""
197
- self._ensure_labels_loaded()
198
  return self._id2label
199
 
200
- @property
201
- def id2label_cn(self) -> Dict[int, str]:
202
- """ImageNet class id to Chinese label string (comma-separated synonyms)."""
203
- self._ensure_labels_loaded()
204
- return self._id2label_cn
205
-
206
- def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
207
  r"""
208
  Map ImageNet label strings to class ids.
209
 
210
  Args:
211
  label (`str` or `list[str]`):
212
- One or more label strings. Each string must match a synonym in `id2label` (English)
213
- or `id2label_cn` (Chinese).
214
- lang (`str`, *optional*, defaults to `"en"`):
215
- `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
216
  """
217
- if lang not in ("en", "cn"):
218
- raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
219
-
220
- self._ensure_labels_loaded()
221
- label2id = self.labels if lang == "en" else self.labels_cn
222
  if not label2id:
223
- raise ValueError(
224
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
225
- )
226
 
227
  if isinstance(label, str):
228
  label = [label]
@@ -231,7 +189,7 @@ class JiTPipeline(DiffusionPipeline):
231
  if missing:
232
  preview = ", ".join(list(label2id.keys())[:8])
233
  raise ValueError(
234
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
235
  )
236
  return [label2id[item] for item in label]
237
 
@@ -246,115 +204,10 @@ class JiTPipeline(DiffusionPipeline):
246
  return self.get_label_ids(class_labels)
247
 
248
  if class_labels and isinstance(class_labels[0], str):
249
- self._ensure_labels_loaded()
250
- if all(label in self.labels for label in class_labels):
251
- return self.get_label_ids(class_labels, lang="en")
252
- if all(label in self.labels_cn for label in class_labels):
253
- return self.get_label_ids(class_labels, lang="cn")
254
- raise ValueError(
255
- "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
- "or Chinese synonyms from `pipe.labels_cn`."
257
- )
258
 
259
  return list(class_labels)
260
 
261
- def _predict_velocity(
262
- self,
263
- z_value: torch.Tensor,
264
- t: torch.Tensor,
265
- class_labels: torch.Tensor,
266
- class_null: torch.Tensor,
267
- do_classifier_free_guidance: bool,
268
- guidance_scale: float,
269
- guidance_interval_min: float,
270
- guidance_interval_max: float,
271
- ) -> torch.Tensor:
272
- t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
- if do_classifier_free_guidance:
274
- z_in = torch.cat([z_value, z_value], dim=0)
275
- labels = torch.cat([class_labels, class_null], dim=0)
276
- else:
277
- z_in = z_value
278
- labels = class_labels
279
-
280
- t_batch = t.flatten().expand(z_in.shape[0])
281
- x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
- v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
-
284
- if not do_classifier_free_guidance:
285
- return v
286
-
287
- v_cond, v_uncond = v.chunk(2, dim=0)
288
- interval_mask = t < guidance_interval_max
289
- if guidance_interval_min != 0.0:
290
- interval_mask = interval_mask & (t > guidance_interval_min)
291
- scale = torch.where(
292
- interval_mask,
293
- torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
- torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
- )
296
- return v_uncond + scale * (v_cond - v_uncond)
297
-
298
- def _run_sampler(
299
- self,
300
- latents: torch.Tensor,
301
- class_labels: torch.Tensor,
302
- class_null: torch.Tensor,
303
- num_inference_steps: int,
304
- do_classifier_free_guidance: bool,
305
- guidance_scale: float,
306
- guidance_interval_min: float,
307
- guidance_interval_max: float,
308
- sampling_method: str,
309
- ) -> torch.Tensor:
310
- device = latents.device
311
- self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
312
- timesteps = self.scheduler.timesteps
313
-
314
- for i in self.progress_bar(range(num_inference_steps - 1)):
315
- t = timesteps[i]
316
- t_next = timesteps[i + 1]
317
- v = self._predict_velocity(
318
- latents,
319
- t,
320
- class_labels,
321
- class_null,
322
- do_classifier_free_guidance,
323
- guidance_scale,
324
- guidance_interval_min,
325
- guidance_interval_max,
326
- )
327
-
328
- if sampling_method == "heun":
329
- latents_euler = latents + (t_next - t) * v
330
- v_next = self._predict_velocity(
331
- latents_euler,
332
- t_next,
333
- class_labels,
334
- class_null,
335
- do_classifier_free_guidance,
336
- guidance_scale,
337
- guidance_interval_min,
338
- guidance_interval_max,
339
- )
340
- latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
341
- else:
342
- latents = self.scheduler.step(v, t, latents).prev_sample
343
-
344
- t = timesteps[-2]
345
- t_next = timesteps[-1]
346
- v = self._predict_velocity(
347
- latents,
348
- t,
349
- class_labels,
350
- class_null,
351
- do_classifier_free_guidance,
352
- guidance_scale,
353
- guidance_interval_min,
354
- guidance_interval_max,
355
- )
356
- return latents + (t_next - t) * v
357
-
358
  @torch.inference_mode()
359
  def __call__(
360
  self,
@@ -363,10 +216,12 @@ class JiTPipeline(DiffusionPipeline):
363
  guidance_interval_min: float = 0.1,
364
  guidance_interval_max: float = 1.0,
365
  noise_scale: Optional[float] = None,
366
- t_eps: Optional[float] = None,
367
- sampling_method: Optional[str] = None,
368
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
369
  num_inference_steps: int = 50,
 
 
 
370
  output_type: Optional[str] = "pil",
371
  return_dict: bool = True,
372
  ) -> Union[ImagePipelineOutput, Tuple]:
@@ -375,7 +230,7 @@ class JiTPipeline(DiffusionPipeline):
375
 
376
  Args:
377
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
378
- ImageNet class indices or human-readable label strings (English or Chinese).
379
  guidance_scale (`float`, *optional*):
380
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
381
  guidance_interval_min (`float`, defaults to `0.1`):
@@ -384,10 +239,8 @@ class JiTPipeline(DiffusionPipeline):
384
  Upper bound of the CFG interval in flow time.
385
  noise_scale (`float`, *optional*):
386
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
387
- t_eps (`float`, *optional*):
388
- Epsilon clamp for the `1 - t` denominator (scheduler config by default).
389
- sampling_method (`str`, *optional*):
390
- `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
391
  generator (`torch.Generator`, *optional*):
392
  RNG for reproducibility.
393
  num_inference_steps (`int`, defaults to `50`):
@@ -397,31 +250,34 @@ class JiTPipeline(DiffusionPipeline):
397
  return_dict (`bool`, *optional*, defaults to `True`):
398
  Return [`ImagePipelineOutput`] if True.
399
  """
400
- solver = sampling_method or self.scheduler.config.solver
401
- if solver not in {"heun", "euler"}:
402
- raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
403
  if num_inference_steps < 2:
404
  raise ValueError("num_inference_steps must be >= 2.")
405
 
406
- if t_eps is not None:
407
- self.scheduler.register_to_config(t_eps=t_eps)
408
-
409
  class_label_ids = self._normalize_class_labels(class_labels)
410
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
411
 
412
  batch_size = len(class_label_ids)
413
  image_size = int(self.transformer.config.sample_size)
 
 
 
 
 
 
 
 
 
414
  channels = int(self.transformer.config.in_channels)
415
  null_class_val = int(self.transformer.config.num_classes)
416
 
417
  if guidance_scale is None:
418
  guidance_scale = 1.0
419
  if noise_scale is None:
420
- noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
421
 
422
  latents = (
423
  randn_tensor(
424
- shape=(batch_size, channels, image_size, image_size),
425
  generator=generator,
426
  device=self._execution_device,
427
  dtype=self.transformer.dtype,
@@ -433,17 +289,47 @@ class JiTPipeline(DiffusionPipeline):
433
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
434
  class_null = torch.full_like(class_labels_t, null_class_val)
435
 
436
- latents = self._run_sampler(
437
- latents,
438
- class_labels_t,
439
- class_null,
440
- num_inference_steps,
441
- do_classifier_free_guidance,
442
- guidance_scale,
443
- guidance_interval_min,
444
- guidance_interval_max,
445
- solver,
446
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
449
  if output_type == "pt":
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import importlib
16
  import json
17
  import sys
 
21
  import torch
22
 
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
+ from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
 
 
38
  Parameters:
39
  transformer ([`JiTTransformer2DModel`]):
40
  A class-conditioned `JiTTransformer2DModel` to denoise the images.
41
+ scheduler ([`KarrasDiffusionSchedulers`] or [`FlowMatchHeunDiscreteScheduler`]):
42
+ Diffusers scheduler interface for JiT generation (defaults to `FlowMatchHeunDiscreteScheduler(shift=4.0)`).
43
  id2label (`dict[int, str]`, *optional*):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
 
 
45
  """
46
 
47
  model_cpu_offload_seq = "transformer"
 
68
 
69
  hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
  if subfolder:
71
+ hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
  cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
  variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
  else:
 
79
  if subfolder:
80
  variant = variant / subfolder
81
 
82
+ id2label_override = kwargs.pop("id2label", None)
83
  model_kwargs = dict(kwargs)
84
  inserted: List[str] = []
85
 
 
101
 
102
  try:
103
  transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
+ try:
105
+ scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
+ except Exception:
107
+ scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
 
109
  if transformer is None:
110
  raise ValueError(f"No loadable transformer found under {variant}")
111
 
112
  variant_path = str(variant)
113
+ model_index_path = variant / "model_index.json"
114
+ id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
 
116
  pipe = cls(
117
  transformer=transformer,
118
  scheduler=scheduler,
119
  id2label=id2label,
 
120
  )
121
  if variant_path and hasattr(pipe, "register_to_config"):
122
  pipe.register_to_config(_name_or_path=variant_path)
 
129
  def __init__(
130
  self,
131
  transformer,
132
+ scheduler: FlowMatchHeunDiscreteScheduler,
133
+ id2label: Optional[Dict[Union[int, str], str]] = None,
 
134
  ):
135
  super().__init__()
136
+ scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
 
139
+ self._id2label = self._normalize_id2label(id2label)
 
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
144
+ if not id2label:
145
+ return {}
146
+ return {int(key): value for key, value in id2label.items()}
 
 
147
 
148
  @staticmethod
149
+ def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
150
+ if not model_index_path.exists():
151
+ return {}
152
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
153
+ id2label = raw.get("id2label")
154
+ if not isinstance(id2label, dict):
155
+ return {}
156
+ return {int(key): value for key, value in id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  @staticmethod
159
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
  """ImageNet class id to English label string (comma-separated synonyms)."""
 
171
  return self._id2label
172
 
173
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
 
 
 
 
 
 
174
  r"""
175
  Map ImageNet label strings to class ids.
176
 
177
  Args:
178
  label (`str` or `list[str]`):
179
+ One or more English label strings. Each string must match a synonym in `id2label`.
 
 
 
180
  """
181
+ label2id = self.labels
 
 
 
 
182
  if not label2id:
183
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
 
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
  raise ValueError(
192
+ f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
  )
194
  return [label2id[item] for item in label]
195
 
 
204
  return self.get_label_ids(class_labels)
205
 
206
  if class_labels and isinstance(class_labels[0], str):
207
+ return self.get_label_ids(class_labels)
 
 
 
 
 
 
 
 
208
 
209
  return list(class_labels)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  @torch.inference_mode()
212
  def __call__(
213
  self,
 
216
  guidance_interval_min: float = 0.1,
217
  guidance_interval_max: float = 1.0,
218
  noise_scale: Optional[float] = None,
219
+ t_eps: float = 5e-2,
 
220
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
221
  num_inference_steps: int = 50,
222
+ height: Optional[int] = None,
223
+ width: Optional[int] = None,
224
+ interpolate_pos_encoding: bool = True,
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
 
230
 
231
  Args:
232
  class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
+ ImageNet class indices or human-readable English label strings.
234
  guidance_scale (`float`, *optional*):
235
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
  guidance_interval_min (`float`, defaults to `0.1`):
 
239
  Upper bound of the CFG interval in flow time.
240
  noise_scale (`float`, *optional*):
241
  Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
+ t_eps (`float`, defaults to `5e-2`):
243
+ Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
 
 
244
  generator (`torch.Generator`, *optional*):
245
  RNG for reproducibility.
246
  num_inference_steps (`int`, defaults to `50`):
 
250
  return_dict (`bool`, *optional*, defaults to `True`):
251
  Return [`ImagePipelineOutput`] if True.
252
  """
 
 
 
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
255
 
 
 
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
258
 
259
  batch_size = len(class_label_ids)
260
  image_size = int(self.transformer.config.sample_size)
261
+ patch_size = int(self.transformer.config.patch_size)
262
+ height = int(height or image_size)
263
+ width = int(width or image_size)
264
+ if height <= 0 or width <= 0:
265
+ raise ValueError("height and width must be positive integers.")
266
+ if height % patch_size != 0 or width % patch_size != 0:
267
+ raise ValueError(
268
+ f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
+ )
270
  channels = int(self.transformer.config.in_channels)
271
  null_class_val = int(self.transformer.config.num_classes)
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
+ noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
  latents = (
279
  randn_tensor(
280
+ shape=(batch_size, channels, height, width),
281
  generator=generator,
282
  device=self._execution_device,
283
  dtype=self.transformer.dtype,
 
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
290
  class_null = torch.full_like(class_labels_t, null_class_val)
291
 
292
+ if do_classifier_free_guidance:
293
+ class_labels_input = torch.cat([class_labels_t, class_null], dim=0)
294
+ else:
295
+ class_labels_input = class_labels_t
296
+
297
+ self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
298
+ for t in self.progress_bar(self.scheduler.timesteps):
299
+ step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
+ sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
301
+ sigma = sigma.clamp_min(t_eps)
302
+ t_flow = (1.0 - sigma).clamp(0.0, 1.0)
303
+
304
+ if do_classifier_free_guidance:
305
+ latent_model_input = torch.cat([latents, latents], dim=0)
306
+ else:
307
+ latent_model_input = latents
308
+
309
+ timesteps = t_flow.flatten().expand(latent_model_input.shape[0])
310
+ x_pred = self.transformer(
311
+ latent_model_input,
312
+ timestep=timesteps,
313
+ class_labels=class_labels_input,
314
+ interpolate_pos_encoding=interpolate_pos_encoding,
315
+ ).sample
316
+
317
+ if do_classifier_free_guidance:
318
+ x_cond, x_uncond = x_pred.chunk(2, dim=0)
319
+ interval_mask = t_flow < guidance_interval_max
320
+ if guidance_interval_min != 0.0:
321
+ interval_mask = interval_mask & (t_flow > guidance_interval_min)
322
+ scale = torch.where(
323
+ interval_mask,
324
+ torch.tensor(guidance_scale, device=latents.device, dtype=latents.dtype),
325
+ torch.tensor(1.0, device=latents.device, dtype=latents.dtype),
326
+ )
327
+ x_pred = x_uncond + scale * (x_cond - x_uncond)
328
+
329
+ sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
+ # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
+ model_output = -(x_pred - latents) / sigma
332
+ latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
JiT-L-32/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,6 @@
1
  {
2
- "_class_name": "JiTScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
- "t_eps": 0.05,
6
- "solver": "heun"
7
  }
 
1
  {
2
+ "_class_name": "FlowMatchHeunDiscreteScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
+ "shift": 4.0
 
6
  }
JiT-L-32/transformer/jit_transformer_2d.py CHANGED
@@ -68,38 +68,58 @@ class JiTRotaryEmbedding(nn.Module):
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
- if custom_freqs is not None:
72
- freqs = custom_freqs
73
- else:
74
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
75
-
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
79
-
80
- freqs = torch.einsum("..., f -> ... f", t, freqs)
81
- freqs = freqs.repeat_interleave(2, dim=-1)
82
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
83
-
84
- if num_cls_token > 0:
85
- freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
86
- cos_img = freqs_flat.cos()
87
- sin_img = freqs_flat.sin()
88
-
89
- # prepend in-context cls token
90
- _, D = cos_img.shape
91
- cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
92
- sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
93
-
94
- self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
95
- self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
96
  else:
97
- self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
98
- self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
99
-
100
- def forward(self, t):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
102
  seq_len = t.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
103
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
104
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
105
 
@@ -195,7 +215,7 @@ class JiTAttention(nn.Module):
195
  self.proj = nn.Linear(dim, dim)
196
  self.proj_drop = nn.Dropout(proj_drop)
197
 
198
- def forward(self, x, rope=None):
199
  B, N, C = x.shape
200
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
201
  q, k, v = qkv[0], qkv[1], qkv[2]
@@ -206,8 +226,8 @@ class JiTAttention(nn.Module):
206
  if rope is not None:
207
  q = q.transpose(1, 2)
208
  k = k.transpose(1, 2)
209
- q = rope(q)
210
- k = rope(k)
211
  q = q.transpose(1, 2)
212
  k = k.transpose(1, 2)
213
 
@@ -254,7 +274,7 @@ class JiTBlock(nn.Module):
254
  self.act = nn.SiLU()
255
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
256
 
257
- def forward(self, x, c, feat_rope=None):
258
  # Apply activation
259
  c = self.act(c)
260
 
@@ -263,7 +283,7 @@ class JiTBlock(nn.Module):
263
  # Attention block
264
  norm_x = self.norm1(x)
265
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
266
- attn_out = self.attn(modulated_x, rope=feat_rope)
267
  x = x + gate_msa.unsqueeze(1) * attn_out
268
 
269
  # MLP block
@@ -437,11 +457,30 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
437
  self.act_final = nn.SiLU()
438
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  def forward(
441
  self,
442
  hidden_states: torch.Tensor,
443
  timestep: torch.LongTensor,
444
  class_labels: torch.LongTensor,
 
445
  return_dict: bool = True,
446
  ):
447
 
@@ -454,8 +493,19 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
454
  c = t_emb + y_emb
455
 
456
  # Patch Embed
 
457
  x = self.x_embedder(hidden_states)
458
- x = x + self.pos_embed.to(x.dtype)
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Blocks
461
  for i, block in enumerate(self.blocks):
@@ -467,15 +517,23 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
467
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
468
 
469
  if self.training and self.gradient_checkpointing:
 
 
 
 
 
 
 
 
 
470
  x = torch.utils.checkpoint.checkpoint(
471
- block,
472
  x,
473
  c,
474
- rope,
475
  use_reentrant=False,
476
  )
477
  else:
478
- x = block(x, c, feat_rope=rope)
479
 
480
  # Slice off in-context tokens
481
  if self.in_context_len > 0:
@@ -489,10 +547,11 @@ class JiTTransformer2DModel(ModelMixin, ConfigMixin):
489
  x = self.linear_final(x)
490
 
491
  # Unpatchify
492
- h = w = int(x.shape[1] ** 0.5)
493
- x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
494
  x = torch.einsum("nhwpqc->nchpwq", x)
495
- output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
 
 
496
 
497
  if not return_dict:
498
  return (output,)
 
68
  num_cls_token=0,
69
  ):
70
  super().__init__()
71
+ self.dim = dim
72
+ self.pt_seq_len = pt_seq_len
73
+ self.theta = theta
74
+ self.num_cls_token = num_cls_token
75
+ self.custom_freqs = custom_freqs
76
  if ft_seq_len is None:
77
  ft_seq_len = pt_seq_len
78
+ self._cached_hw = None
79
+ cos, sin = self._build_freqs(ft_seq_len, ft_seq_len, device=torch.device("cpu"))
80
+ self.register_buffer("freqs_cos", cos, persistent=False)
81
+ self.register_buffer("freqs_sin", sin, persistent=False)
82
+ self._cached_hw = (ft_seq_len, ft_seq_len)
83
+
84
+ def _build_freqs(self, height, width, device):
85
+ if self.custom_freqs is not None:
86
+ freqs = self.custom_freqs.to(device=device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
87
  else:
88
+ freqs = 1.0 / (
89
+ self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)[: (self.dim // 2)] / self.dim)
90
+ )
91
+
92
+ t_h = torch.arange(height, device=device, dtype=torch.float32) / height * self.pt_seq_len
93
+ t_w = torch.arange(width, device=device, dtype=torch.float32) / width * self.pt_seq_len
94
+ freqs_h = torch.einsum("..., f -> ... f", t_h, freqs).repeat_interleave(2, dim=-1)
95
+ freqs_w = torch.einsum("..., f -> ... f", t_w, freqs).repeat_interleave(2, dim=-1)
96
+ freqs_2d = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
97
+ freqs_flat = freqs_2d.view(-1, freqs_2d.shape[-1])
98
+ cos_img = freqs_flat.cos()
99
+ sin_img = freqs_flat.sin()
100
+ if self.num_cls_token > 0:
101
+ _, dim_freq = cos_img.shape
102
+ cos_pad = torch.ones(self.num_cls_token, dim_freq, dtype=cos_img.dtype, device=device)
103
+ sin_pad = torch.zeros(self.num_cls_token, dim_freq, dtype=sin_img.dtype, device=device)
104
+ cos_img = torch.cat([cos_pad, cos_img], dim=0)
105
+ sin_img = torch.cat([sin_pad, sin_img], dim=0)
106
+ return cos_img, sin_img
107
+
108
+ def forward(self, t, height=None, width=None):
109
  # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
110
  seq_len = t.shape[1]
111
+ if height is None or width is None:
112
+ image_tokens = seq_len - self.num_cls_token
113
+ size = int(image_tokens**0.5)
114
+ if size * size != image_tokens:
115
+ raise ValueError(
116
+ f"Cannot infer square token grid from sequence length {seq_len} with {self.num_cls_token} class tokens."
117
+ )
118
+ height = size
119
+ width = size
120
+ if self._cached_hw != (height, width) or self.freqs_cos.device != t.device:
121
+ self.freqs_cos, self.freqs_sin = self._build_freqs(height, width, device=t.device)
122
+ self._cached_hw = (height, width)
123
  freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
124
  freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
125
 
 
215
  self.proj = nn.Linear(dim, dim)
216
  self.proj_drop = nn.Dropout(proj_drop)
217
 
218
+ def forward(self, x, rope=None, grid_height=None, grid_width=None):
219
  B, N, C = x.shape
220
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
221
  q, k, v = qkv[0], qkv[1], qkv[2]
 
226
  if rope is not None:
227
  q = q.transpose(1, 2)
228
  k = k.transpose(1, 2)
229
+ q = rope(q, height=grid_height, width=grid_width)
230
+ k = rope(k, height=grid_height, width=grid_width)
231
  q = q.transpose(1, 2)
232
  k = k.transpose(1, 2)
233
 
 
274
  self.act = nn.SiLU()
275
  self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
276
 
277
+ def forward(self, x, c, feat_rope=None, grid_height=None, grid_width=None):
278
  # Apply activation
279
  c = self.act(c)
280
 
 
283
  # Attention block
284
  norm_x = self.norm1(x)
285
  modulated_x = modulate(norm_x, shift_msa, scale_msa)
286
+ attn_out = self.attn(modulated_x, rope=feat_rope, grid_height=grid_height, grid_width=grid_width)
287
  x = x + gate_msa.unsqueeze(1) * attn_out
288
 
289
  # MLP block
 
457
  self.act_final = nn.SiLU()
458
  self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
459
 
460
+ def _get_patch_grid(self, hidden_states):
461
+ height, width = hidden_states.shape[-2:]
462
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
463
+ raise ValueError(
464
+ f"Input size {(height, width)} must be divisible by patch_size={self.patch_size}."
465
+ )
466
+ return height // self.patch_size, width // self.patch_size
467
+
468
+ def _interpolate_pos_encoding(self, tokens, grid_height, grid_width):
469
+ num_tokens = grid_height * grid_width
470
+ if self.pos_embed.shape[1] == num_tokens:
471
+ return self.pos_embed.to(device=tokens.device, dtype=tokens.dtype)
472
+ base_size = int(self.pos_embed.shape[1] ** 0.5)
473
+ pos_embed = self.pos_embed.reshape(1, base_size, base_size, self.hidden_size).permute(0, 3, 1, 2)
474
+ pos_embed = F.interpolate(pos_embed, size=(grid_height, grid_width), mode="bicubic", align_corners=False)
475
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, num_tokens, self.hidden_size)
476
+ return pos_embed.to(device=tokens.device, dtype=tokens.dtype)
477
+
478
  def forward(
479
  self,
480
  hidden_states: torch.Tensor,
481
  timestep: torch.LongTensor,
482
  class_labels: torch.LongTensor,
483
+ interpolate_pos_encoding: bool = True,
484
  return_dict: bool = True,
485
  ):
486
 
 
493
  c = t_emb + y_emb
494
 
495
  # Patch Embed
496
+ grid_height, grid_width = self._get_patch_grid(hidden_states)
497
  x = self.x_embedder(hidden_states)
498
+ if interpolate_pos_encoding:
499
+ pos_embed = self._interpolate_pos_encoding(x, grid_height, grid_width)
500
+ else:
501
+ expected_tokens = grid_height * grid_width
502
+ if self.pos_embed.shape[1] != expected_tokens:
503
+ raise ValueError(
504
+ f"pos_embed token count {self.pos_embed.shape[1]} does not match input token count {expected_tokens}. "
505
+ "Enable interpolate_pos_encoding for dynamic resolutions."
506
+ )
507
+ pos_embed = self.pos_embed.to(device=x.device, dtype=x.dtype)
508
+ x = x + pos_embed
509
 
510
  # Blocks
511
  for i, block in enumerate(self.blocks):
 
517
  rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
518
 
519
  if self.training and self.gradient_checkpointing:
520
+ def custom_forward(current_x, current_c):
521
+ return block(
522
+ current_x,
523
+ current_c,
524
+ feat_rope=rope,
525
+ grid_height=grid_height,
526
+ grid_width=grid_width,
527
+ )
528
+
529
  x = torch.utils.checkpoint.checkpoint(
530
+ custom_forward,
531
  x,
532
  c,
 
533
  use_reentrant=False,
534
  )
535
  else:
536
+ x = block(x, c, feat_rope=rope, grid_height=grid_height, grid_width=grid_width)
537
 
538
  # Slice off in-context tokens
539
  if self.in_context_len > 0:
 
547
  x = self.linear_final(x)
548
 
549
  # Unpatchify
550
+ x = x.reshape(shape=(x.shape[0], grid_height, grid_width, self.patch_size, self.patch_size, self.out_channels))
 
551
  x = torch.einsum("nhwpqc->nchpwq", x)
552
+ output = x.reshape(
553
+ shape=(x.shape[0], self.out_channels, grid_height * self.patch_size, grid_width * self.patch_size)
554
+ )
555
 
556
  if not return_dict:
557
  return (output,)
README.md CHANGED
@@ -19,17 +19,17 @@ language:
19
  Native diffusers implementation of **JiT** (Just image Transformer). Each variant folder is self-contained:
20
 
21
  - `pipeline.py` — `JiTPipeline`
22
- - `scheduler/scheduling_jit.py` — `JiTScheduler` (linear `t in [0, 1]`, Heun/Euler)
23
  - `transformer/jit_transformer_2d.py` — `JiTTransformer2DModel`
24
 
25
- Shared ImageNet-1k labels live in [`labels/`](labels/) at the repo root (not duplicated per variant).
26
 
27
  No separate `jit_diffusers` package; only PyPI `diffusers` plus local custom code in the variant directory.
28
 
29
  ## Available checkpoints
30
 
31
  | Checkpoint | Path | Resolution | Recommended CFG |
32
- |---|---|---|---|
33
  | JiT-B/16 | `./JiT-B-16` | 256×256 | 3.0 |
34
  | JiT-L/16 | `./JiT-L-16` | 256×256 | 2.4 |
35
  | JiT-H/16 | `./JiT-H-16` | 256×256 | 2.2 |
@@ -39,42 +39,52 @@ No separate `jit_diffusers` package; only PyPI `diffusers` plus local custom cod
39
 
40
  ## ImageNet class labels
41
 
42
- | File | Direction | Format |
43
- |---|---|---|
44
- | `labels/id2label_en.json` | id → English | comma-separated synonyms, e.g. `"207": "golden retriever"` |
45
- | `labels/id2label_cn.json` | id → Chinese | comma-separated synonyms, e.g. `"207": "金毛猎犬"` |
46
 
47
- - `pipe.id2label` / `pipe.id2label_cn` — inspect id → label correspondence
48
- - `pipe.labels` / `pipe.labels_cn` — reverse maps (synonym → id), sorted for browsing
49
- - `pipe.get_label_ids("golden retriever")` or `pipe.get_label_ids("金毛猎犬", lang="cn")`
50
  - `pipe(class_labels="golden retriever", ...)` — string labels resolved automatically
51
 
 
 
52
  ## Inference
53
 
 
 
 
 
 
 
 
 
54
  ```python
55
- from diffusers import DiffusionPipeline
 
56
  import torch
57
 
 
58
  pipe = DiffusionPipeline.from_pretrained(
59
- "./JiT-H-32",
60
- trust_remote_code=True,
61
  )
 
62
  pipe.to("cuda")
63
- pipe.transformer.to(dtype=torch.bfloat16)
64
 
65
  # Numeric or human-readable labels
66
  print(pipe.id2label[207])
67
  print(pipe.get_label_ids("golden retriever"))
68
 
69
  generator = torch.Generator(device="cuda").manual_seed(42)
70
- images = pipe(
71
  class_labels="golden retriever",
72
  num_inference_steps=50,
73
  guidance_scale=2.3,
74
- sampling_method="heun",
75
  generator=generator,
76
- ).images
77
- images[0].save("output.png")
78
  ```
79
 
 
 
80
  Load a **variant subfolder** (e.g. `./JiT-H-32`), not the repo root.
 
19
  Native diffusers implementation of **JiT** (Just image Transformer). Each variant folder is self-contained:
20
 
21
  - `pipeline.py` — `JiTPipeline`
22
+ - `scheduler/scheduler_config.json` — `FlowMatchHeunDiscreteScheduler` config (default `shift=4.0`)
23
  - `transformer/jit_transformer_2d.py` — `JiTTransformer2DModel`
24
 
25
+ The pipeline now supports dynamic inference resolution in `__call__` with positional interpolation.
26
 
27
  No separate `jit_diffusers` package; only PyPI `diffusers` plus local custom code in the variant directory.
28
 
29
  ## Available checkpoints
30
 
31
  | Checkpoint | Path | Resolution | Recommended CFG |
32
+ | --- | --- | --- | --- |
33
  | JiT-B/16 | `./JiT-B-16` | 256×256 | 3.0 |
34
  | JiT-L/16 | `./JiT-L-16` | 256×256 | 2.4 |
35
  | JiT-H/16 | `./JiT-H-16` | 256×256 | 2.2 |
 
39
 
40
  ## ImageNet class labels
41
 
42
+ Each variant keeps an English `id2label` map directly in its own `model_index.json` (DiT-style).
 
 
 
43
 
44
+ - `pipe.id2label` — inspect id → English label correspondence
45
+ - `pipe.labels` — reverse map (English synonym → id), sorted for browsing
46
+ - `pipe.get_label_ids("golden retriever")`
47
  - `pipe(class_labels="golden retriever", ...)` — string labels resolved automatically
48
 
49
+ Chinese labels are preserved in the main source repo under `src/labels/id2label_cn.json` for reference.
50
+
51
  ## Inference
52
 
53
+ Run the bundled demo script from the repo root:
54
+
55
+ ```bash
56
+ python demo_inference.py
57
+ ```
58
+
59
+ This writes `demo.png` using `JiT-H-32` with the settings below.
60
+
61
  ```python
62
+ from pathlib import Path
63
+ from diffusers import DiffusionPipeline, FlowMatchHeunDiscreteScheduler
64
  import torch
65
 
66
+ model_dir = Path("./JiT-H-32")
67
  pipe = DiffusionPipeline.from_pretrained(
68
+ str(model_dir),
69
+ custom_pipeline=str(model_dir / "pipeline.py"),
70
  )
71
+ pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(pipe.scheduler.config, shift=4.0)
72
  pipe.to("cuda")
 
73
 
74
  # Numeric or human-readable labels
75
  print(pipe.id2label[207])
76
  print(pipe.get_label_ids("golden retriever"))
77
 
78
  generator = torch.Generator(device="cuda").manual_seed(42)
79
+ image = pipe(
80
  class_labels="golden retriever",
81
  num_inference_steps=50,
82
  guidance_scale=2.3,
 
83
  generator=generator,
84
+ ).images[0]
85
+ image.save("demo.png")
86
  ```
87
 
88
+ `height` and `width` default to the checkpoint's native resolution when omitted.
89
+
90
  Load a **variant subfolder** (e.g. `./JiT-H-32`), not the repo root.
demo.png CHANGED

Git LFS Details

  • SHA256: f5fdbd0300f895de7642229d1294aff74facd75c0bb4c4a01efa8c75b14b6fc4
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB

Git LFS Details

  • SHA256: 406f8c46e0fc77bc2f612c16e2ffde230f41fcc6e4bfb2e11307b758b19169e7
  • Pointer size: 131 Bytes
  • Size of remote file: 539 kB
demo_inference.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a demo image with JiT-H-32."""
3
+
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from diffusers import DiffusionPipeline, FlowMatchHeunDiscreteScheduler
8
+
9
+ REPO_ROOT = Path(__file__).resolve().parent
10
+ MODEL_DIR = REPO_ROOT / "JiT-H-32"
11
+ OUTPUT_PATH = REPO_ROOT / "demo.png"
12
+
13
+
14
+ def main() -> None:
15
+ pipe = DiffusionPipeline.from_pretrained(
16
+ str(MODEL_DIR),
17
+ custom_pipeline=str(MODEL_DIR / "pipeline.py"),
18
+ torch_dtype=torch.bfloat16,
19
+ )
20
+ pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(pipe.scheduler.config, shift=4.0)
21
+ pipe.to("cuda")
22
+
23
+ print(pipe.id2label[207])
24
+ print(pipe.get_label_ids("golden retriever"))
25
+
26
+ generator = torch.Generator(device="cuda").manual_seed(42)
27
+ image = pipe(
28
+ class_labels="golden retriever",
29
+ num_inference_steps=50,
30
+ guidance_scale=2.3,
31
+ generator=generator,
32
+ ).images[0]
33
+ image.save(OUTPUT_PATH)
34
+ print(f"Saved demo image to {OUTPUT_PATH}")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ main()