BiliSakura commited on
Commit
098ef8f
·
verified ·
1 Parent(s): 4968e7f

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ PixelFlow-256/demo.png filter=lfs diff=lfs merge=lfs -text
PixelFlow-256/demo.png CHANGED

Git LFS Details

  • SHA256: 729a0166881da84ff71d6006df90284e4592b6330684fc81238ef70c49bf67b3
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
PixelFlow-256/model_index.json CHANGED
@@ -8,5 +8,1007 @@
8
  "transformer": [
9
  "transformer_pixelflow",
10
  "PixelFlowTransformer2DModel"
11
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  }
 
8
  "transformer": [
9
  "transformer_pixelflow",
10
  "PixelFlowTransformer2DModel"
11
+ ],
12
+ "id2label": {
13
+ "0": "tench, Tinca tinca",
14
+ "1": "goldfish, Carassius auratus",
15
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
16
+ "3": "tiger shark, Galeocerdo cuvieri",
17
+ "4": "hammerhead, hammerhead shark",
18
+ "5": "electric ray, crampfish, numbfish, torpedo",
19
+ "6": "stingray",
20
+ "7": "cock",
21
+ "8": "hen",
22
+ "9": "ostrich, Struthio camelus",
23
+ "10": "brambling, Fringilla montifringilla",
24
+ "11": "goldfinch, Carduelis carduelis",
25
+ "12": "house finch, linnet, Carpodacus mexicanus",
26
+ "13": "junco, snowbird",
27
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
28
+ "15": "robin, American robin, Turdus migratorius",
29
+ "16": "bulbul",
30
+ "17": "jay",
31
+ "18": "magpie",
32
+ "19": "chickadee",
33
+ "20": "water ouzel, dipper",
34
+ "21": "kite",
35
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
36
+ "23": "vulture",
37
+ "24": "great grey owl, great gray owl, Strix nebulosa",
38
+ "25": "European fire salamander, Salamandra salamandra",
39
+ "26": "common newt, Triturus vulgaris",
40
+ "27": "eft",
41
+ "28": "spotted salamander, Ambystoma maculatum",
42
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
43
+ "30": "bullfrog, Rana catesbeiana",
44
+ "31": "tree frog, tree-frog",
45
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
46
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
47
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
48
+ "35": "mud turtle",
49
+ "36": "terrapin",
50
+ "37": "box turtle, box tortoise",
51
+ "38": "banded gecko",
52
+ "39": "common iguana, iguana, Iguana iguana",
53
+ "40": "American chameleon, anole, Anolis carolinensis",
54
+ "41": "whiptail, whiptail lizard",
55
+ "42": "agama",
56
+ "43": "frilled lizard, Chlamydosaurus kingi",
57
+ "44": "alligator lizard",
58
+ "45": "Gila monster, Heloderma suspectum",
59
+ "46": "green lizard, Lacerta viridis",
60
+ "47": "African chameleon, Chamaeleo chamaeleon",
61
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
62
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
63
+ "50": "American alligator, Alligator mississipiensis",
64
+ "51": "triceratops",
65
+ "52": "thunder snake, worm snake, Carphophis amoenus",
66
+ "53": "ringneck snake, ring-necked snake, ring snake",
67
+ "54": "hognose snake, puff adder, sand viper",
68
+ "55": "green snake, grass snake",
69
+ "56": "king snake, kingsnake",
70
+ "57": "garter snake, grass snake",
71
+ "58": "water snake",
72
+ "59": "vine snake",
73
+ "60": "night snake, Hypsiglena torquata",
74
+ "61": "boa constrictor, Constrictor constrictor",
75
+ "62": "rock python, rock snake, Python sebae",
76
+ "63": "Indian cobra, Naja naja",
77
+ "64": "green mamba",
78
+ "65": "sea snake",
79
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
80
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
81
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
82
+ "69": "trilobite",
83
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
84
+ "71": "scorpion",
85
+ "72": "black and gold garden spider, Argiope aurantia",
86
+ "73": "barn spider, Araneus cavaticus",
87
+ "74": "garden spider, Aranea diademata",
88
+ "75": "black widow, Latrodectus mactans",
89
+ "76": "tarantula",
90
+ "77": "wolf spider, hunting spider",
91
+ "78": "tick",
92
+ "79": "centipede",
93
+ "80": "black grouse",
94
+ "81": "ptarmigan",
95
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
96
+ "83": "prairie chicken, prairie grouse, prairie fowl",
97
+ "84": "peacock",
98
+ "85": "quail",
99
+ "86": "partridge",
100
+ "87": "African grey, African gray, Psittacus erithacus",
101
+ "88": "macaw",
102
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
103
+ "90": "lorikeet",
104
+ "91": "coucal",
105
+ "92": "bee eater",
106
+ "93": "hornbill",
107
+ "94": "hummingbird",
108
+ "95": "jacamar",
109
+ "96": "toucan",
110
+ "97": "drake",
111
+ "98": "red-breasted merganser, Mergus serrator",
112
+ "99": "goose",
113
+ "100": "black swan, Cygnus atratus",
114
+ "101": "tusker",
115
+ "102": "echidna, spiny anteater, anteater",
116
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
117
+ "104": "wallaby, brush kangaroo",
118
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
119
+ "106": "wombat",
120
+ "107": "jellyfish",
121
+ "108": "sea anemone, anemone",
122
+ "109": "brain coral",
123
+ "110": "flatworm, platyhelminth",
124
+ "111": "nematode, nematode worm, roundworm",
125
+ "112": "conch",
126
+ "113": "snail",
127
+ "114": "slug",
128
+ "115": "sea slug, nudibranch",
129
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
130
+ "117": "chambered nautilus, pearly nautilus, nautilus",
131
+ "118": "Dungeness crab, Cancer magister",
132
+ "119": "rock crab, Cancer irroratus",
133
+ "120": "fiddler crab",
134
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
135
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
136
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
137
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
138
+ "125": "hermit crab",
139
+ "126": "isopod",
140
+ "127": "white stork, Ciconia ciconia",
141
+ "128": "black stork, Ciconia nigra",
142
+ "129": "spoonbill",
143
+ "130": "flamingo",
144
+ "131": "little blue heron, Egretta caerulea",
145
+ "132": "American egret, great white heron, Egretta albus",
146
+ "133": "bittern",
147
+ "134": "crane",
148
+ "135": "limpkin, Aramus pictus",
149
+ "136": "European gallinule, Porphyrio porphyrio",
150
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
151
+ "138": "bustard",
152
+ "139": "ruddy turnstone, Arenaria interpres",
153
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
154
+ "141": "redshank, Tringa totanus",
155
+ "142": "dowitcher",
156
+ "143": "oystercatcher, oyster catcher",
157
+ "144": "pelican",
158
+ "145": "king penguin, Aptenodytes patagonica",
159
+ "146": "albatross, mollymawk",
160
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
161
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
162
+ "149": "dugong, Dugong dugon",
163
+ "150": "sea lion",
164
+ "151": "Chihuahua",
165
+ "152": "Japanese spaniel",
166
+ "153": "Maltese dog, Maltese terrier, Maltese",
167
+ "154": "Pekinese, Pekingese, Peke",
168
+ "155": "Shih-Tzu",
169
+ "156": "Blenheim spaniel",
170
+ "157": "papillon",
171
+ "158": "toy terrier",
172
+ "159": "Rhodesian ridgeback",
173
+ "160": "Afghan hound, Afghan",
174
+ "161": "basset, basset hound",
175
+ "162": "beagle",
176
+ "163": "bloodhound, sleuthhound",
177
+ "164": "bluetick",
178
+ "165": "black-and-tan coonhound",
179
+ "166": "Walker hound, Walker foxhound",
180
+ "167": "English foxhound",
181
+ "168": "redbone",
182
+ "169": "borzoi, Russian wolfhound",
183
+ "170": "Irish wolfhound",
184
+ "171": "Italian greyhound",
185
+ "172": "whippet",
186
+ "173": "Ibizan hound, Ibizan Podenco",
187
+ "174": "Norwegian elkhound, elkhound",
188
+ "175": "otterhound, otter hound",
189
+ "176": "Saluki, gazelle hound",
190
+ "177": "Scottish deerhound, deerhound",
191
+ "178": "Weimaraner",
192
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
193
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
194
+ "181": "Bedlington terrier",
195
+ "182": "Border terrier",
196
+ "183": "Kerry blue terrier",
197
+ "184": "Irish terrier",
198
+ "185": "Norfolk terrier",
199
+ "186": "Norwich terrier",
200
+ "187": "Yorkshire terrier",
201
+ "188": "wire-haired fox terrier",
202
+ "189": "Lakeland terrier",
203
+ "190": "Sealyham terrier, Sealyham",
204
+ "191": "Airedale, Airedale terrier",
205
+ "192": "cairn, cairn terrier",
206
+ "193": "Australian terrier",
207
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
208
+ "195": "Boston bull, Boston terrier",
209
+ "196": "miniature schnauzer",
210
+ "197": "giant schnauzer",
211
+ "198": "standard schnauzer",
212
+ "199": "Scotch terrier, Scottish terrier, Scottie",
213
+ "200": "Tibetan terrier, chrysanthemum dog",
214
+ "201": "silky terrier, Sydney silky",
215
+ "202": "soft-coated wheaten terrier",
216
+ "203": "West Highland white terrier",
217
+ "204": "Lhasa, Lhasa apso",
218
+ "205": "flat-coated retriever",
219
+ "206": "curly-coated retriever",
220
+ "207": "golden retriever",
221
+ "208": "Labrador retriever",
222
+ "209": "Chesapeake Bay retriever",
223
+ "210": "German short-haired pointer",
224
+ "211": "vizsla, Hungarian pointer",
225
+ "212": "English setter",
226
+ "213": "Irish setter, red setter",
227
+ "214": "Gordon setter",
228
+ "215": "Brittany spaniel",
229
+ "216": "clumber, clumber spaniel",
230
+ "217": "English springer, English springer spaniel",
231
+ "218": "Welsh springer spaniel",
232
+ "219": "cocker spaniel, English cocker spaniel, cocker",
233
+ "220": "Sussex spaniel",
234
+ "221": "Irish water spaniel",
235
+ "222": "kuvasz",
236
+ "223": "schipperke",
237
+ "224": "groenendael",
238
+ "225": "malinois",
239
+ "226": "briard",
240
+ "227": "kelpie",
241
+ "228": "komondor",
242
+ "229": "Old English sheepdog, bobtail",
243
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
244
+ "231": "collie",
245
+ "232": "Border collie",
246
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
247
+ "234": "Rottweiler",
248
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
249
+ "236": "Doberman, Doberman pinscher",
250
+ "237": "miniature pinscher",
251
+ "238": "Greater Swiss Mountain dog",
252
+ "239": "Bernese mountain dog",
253
+ "240": "Appenzeller",
254
+ "241": "EntleBucher",
255
+ "242": "boxer",
256
+ "243": "bull mastiff",
257
+ "244": "Tibetan mastiff",
258
+ "245": "French bulldog",
259
+ "246": "Great Dane",
260
+ "247": "Saint Bernard, St Bernard",
261
+ "248": "Eskimo dog, husky",
262
+ "249": "malamute, malemute, Alaskan malamute",
263
+ "250": "Siberian husky",
264
+ "251": "dalmatian, coach dog, carriage dog",
265
+ "252": "affenpinscher, monkey pinscher, monkey dog",
266
+ "253": "basenji",
267
+ "254": "pug, pug-dog",
268
+ "255": "Leonberg",
269
+ "256": "Newfoundland, Newfoundland dog",
270
+ "257": "Great Pyrenees",
271
+ "258": "Samoyed, Samoyede",
272
+ "259": "Pomeranian",
273
+ "260": "chow, chow chow",
274
+ "261": "keeshond",
275
+ "262": "Brabancon griffon",
276
+ "263": "Pembroke, Pembroke Welsh corgi",
277
+ "264": "Cardigan, Cardigan Welsh corgi",
278
+ "265": "toy poodle",
279
+ "266": "miniature poodle",
280
+ "267": "standard poodle",
281
+ "268": "Mexican hairless",
282
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
283
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
284
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
285
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
286
+ "273": "dingo, warrigal, warragal, Canis dingo",
287
+ "274": "dhole, Cuon alpinus",
288
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
289
+ "276": "hyena, hyaena",
290
+ "277": "red fox, Vulpes vulpes",
291
+ "278": "kit fox, Vulpes macrotis",
292
+ "279": "Arctic fox, white fox, Alopex lagopus",
293
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
294
+ "281": "tabby, tabby cat",
295
+ "282": "tiger cat",
296
+ "283": "Persian cat",
297
+ "284": "Siamese cat, Siamese",
298
+ "285": "Egyptian cat",
299
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
300
+ "287": "lynx, catamount",
301
+ "288": "leopard, Panthera pardus",
302
+ "289": "snow leopard, ounce, Panthera uncia",
303
+ "290": "jaguar, panther, Panthera onca, Felis onca",
304
+ "291": "lion, king of beasts, Panthera leo",
305
+ "292": "tiger, Panthera tigris",
306
+ "293": "cheetah, chetah, Acinonyx jubatus",
307
+ "294": "brown bear, bruin, Ursus arctos",
308
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
309
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
310
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
311
+ "298": "mongoose",
312
+ "299": "meerkat, mierkat",
313
+ "300": "tiger beetle",
314
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
315
+ "302": "ground beetle, carabid beetle",
316
+ "303": "long-horned beetle, longicorn, longicorn beetle",
317
+ "304": "leaf beetle, chrysomelid",
318
+ "305": "dung beetle",
319
+ "306": "rhinoceros beetle",
320
+ "307": "weevil",
321
+ "308": "fly",
322
+ "309": "bee",
323
+ "310": "ant, emmet, pismire",
324
+ "311": "grasshopper, hopper",
325
+ "312": "cricket",
326
+ "313": "walking stick, walkingstick, stick insect",
327
+ "314": "cockroach, roach",
328
+ "315": "mantis, mantid",
329
+ "316": "cicada, cicala",
330
+ "317": "leafhopper",
331
+ "318": "lacewing, lacewing fly",
332
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
333
+ "320": "damselfly",
334
+ "321": "admiral",
335
+ "322": "ringlet, ringlet butterfly",
336
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
337
+ "324": "cabbage butterfly",
338
+ "325": "sulphur butterfly, sulfur butterfly",
339
+ "326": "lycaenid, lycaenid butterfly",
340
+ "327": "starfish, sea star",
341
+ "328": "sea urchin",
342
+ "329": "sea cucumber, holothurian",
343
+ "330": "wood rabbit, cottontail, cottontail rabbit",
344
+ "331": "hare",
345
+ "332": "Angora, Angora rabbit",
346
+ "333": "hamster",
347
+ "334": "porcupine, hedgehog",
348
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
349
+ "336": "marmot",
350
+ "337": "beaver",
351
+ "338": "guinea pig, Cavia cobaya",
352
+ "339": "sorrel",
353
+ "340": "zebra",
354
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
355
+ "342": "wild boar, boar, Sus scrofa",
356
+ "343": "warthog",
357
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
358
+ "345": "ox",
359
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
360
+ "347": "bison",
361
+ "348": "ram, tup",
362
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
363
+ "350": "ibex, Capra ibex",
364
+ "351": "hartebeest",
365
+ "352": "impala, Aepyceros melampus",
366
+ "353": "gazelle",
367
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
368
+ "355": "llama",
369
+ "356": "weasel",
370
+ "357": "mink",
371
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
372
+ "359": "black-footed ferret, ferret, Mustela nigripes",
373
+ "360": "otter",
374
+ "361": "skunk, polecat, wood pussy",
375
+ "362": "badger",
376
+ "363": "armadillo",
377
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
378
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
379
+ "366": "gorilla, Gorilla gorilla",
380
+ "367": "chimpanzee, chimp, Pan troglodytes",
381
+ "368": "gibbon, Hylobates lar",
382
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
383
+ "370": "guenon, guenon monkey",
384
+ "371": "patas, hussar monkey, Erythrocebus patas",
385
+ "372": "baboon",
386
+ "373": "macaque",
387
+ "374": "langur",
388
+ "375": "colobus, colobus monkey",
389
+ "376": "proboscis monkey, Nasalis larvatus",
390
+ "377": "marmoset",
391
+ "378": "capuchin, ringtail, Cebus capucinus",
392
+ "379": "howler monkey, howler",
393
+ "380": "titi, titi monkey",
394
+ "381": "spider monkey, Ateles geoffroyi",
395
+ "382": "squirrel monkey, Saimiri sciureus",
396
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
397
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
398
+ "385": "Indian elephant, Elephas maximus",
399
+ "386": "African elephant, Loxodonta africana",
400
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
401
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
402
+ "389": "barracouta, snoek",
403
+ "390": "eel",
404
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
405
+ "392": "rock beauty, Holocanthus tricolor",
406
+ "393": "anemone fish",
407
+ "394": "sturgeon",
408
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
409
+ "396": "lionfish",
410
+ "397": "puffer, pufferfish, blowfish, globefish",
411
+ "398": "abacus",
412
+ "399": "abaya",
413
+ "400": "academic gown, academic robe, judge robe",
414
+ "401": "accordion, piano accordion, squeeze box",
415
+ "402": "acoustic guitar",
416
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
417
+ "404": "airliner",
418
+ "405": "airship, dirigible",
419
+ "406": "altar",
420
+ "407": "ambulance",
421
+ "408": "amphibian, amphibious vehicle",
422
+ "409": "analog clock",
423
+ "410": "apiary, bee house",
424
+ "411": "apron",
425
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
426
+ "413": "assault rifle, assault gun",
427
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
428
+ "415": "bakery, bakeshop, bakehouse",
429
+ "416": "balance beam, beam",
430
+ "417": "balloon",
431
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
432
+ "419": "Band Aid",
433
+ "420": "banjo",
434
+ "421": "bannister, banister, balustrade, balusters, handrail",
435
+ "422": "barbell",
436
+ "423": "barber chair",
437
+ "424": "barbershop",
438
+ "425": "barn",
439
+ "426": "barometer",
440
+ "427": "barrel, cask",
441
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
442
+ "429": "baseball",
443
+ "430": "basketball",
444
+ "431": "bassinet",
445
+ "432": "bassoon",
446
+ "433": "bathing cap, swimming cap",
447
+ "434": "bath towel",
448
+ "435": "bathtub, bathing tub, bath, tub",
449
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
450
+ "437": "beacon, lighthouse, beacon light, pharos",
451
+ "438": "beaker",
452
+ "439": "bearskin, busby, shako",
453
+ "440": "beer bottle",
454
+ "441": "beer glass",
455
+ "442": "bell cote, bell cot",
456
+ "443": "bib",
457
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
458
+ "445": "bikini, two-piece",
459
+ "446": "binder, ring-binder",
460
+ "447": "binoculars, field glasses, opera glasses",
461
+ "448": "birdhouse",
462
+ "449": "boathouse",
463
+ "450": "bobsled, bobsleigh, bob",
464
+ "451": "bolo tie, bolo, bola tie, bola",
465
+ "452": "bonnet, poke bonnet",
466
+ "453": "bookcase",
467
+ "454": "bookshop, bookstore, bookstall",
468
+ "455": "bottlecap",
469
+ "456": "bow",
470
+ "457": "bow tie, bow-tie, bowtie",
471
+ "458": "brass, memorial tablet, plaque",
472
+ "459": "brassiere, bra, bandeau",
473
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
474
+ "461": "breastplate, aegis, egis",
475
+ "462": "broom",
476
+ "463": "bucket, pail",
477
+ "464": "buckle",
478
+ "465": "bulletproof vest",
479
+ "466": "bullet train, bullet",
480
+ "467": "butcher shop, meat market",
481
+ "468": "cab, hack, taxi, taxicab",
482
+ "469": "caldron, cauldron",
483
+ "470": "candle, taper, wax light",
484
+ "471": "cannon",
485
+ "472": "canoe",
486
+ "473": "can opener, tin opener",
487
+ "474": "cardigan",
488
+ "475": "car mirror",
489
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
490
+ "477": "carpenters kit, tool kit",
491
+ "478": "carton",
492
+ "479": "car wheel",
493
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
494
+ "481": "cassette",
495
+ "482": "cassette player",
496
+ "483": "castle",
497
+ "484": "catamaran",
498
+ "485": "CD player",
499
+ "486": "cello, violoncello",
500
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
501
+ "488": "chain",
502
+ "489": "chainlink fence",
503
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
504
+ "491": "chain saw, chainsaw",
505
+ "492": "chest",
506
+ "493": "chiffonier, commode",
507
+ "494": "chime, bell, gong",
508
+ "495": "china cabinet, china closet",
509
+ "496": "Christmas stocking",
510
+ "497": "church, church building",
511
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
512
+ "499": "cleaver, meat cleaver, chopper",
513
+ "500": "cliff dwelling",
514
+ "501": "cloak",
515
+ "502": "clog, geta, patten, sabot",
516
+ "503": "cocktail shaker",
517
+ "504": "coffee mug",
518
+ "505": "coffeepot",
519
+ "506": "coil, spiral, volute, whorl, helix",
520
+ "507": "combination lock",
521
+ "508": "computer keyboard, keypad",
522
+ "509": "confectionery, confectionary, candy store",
523
+ "510": "container ship, containership, container vessel",
524
+ "511": "convertible",
525
+ "512": "corkscrew, bottle screw",
526
+ "513": "cornet, horn, trumpet, trump",
527
+ "514": "cowboy boot",
528
+ "515": "cowboy hat, ten-gallon hat",
529
+ "516": "cradle",
530
+ "517": "crane",
531
+ "518": "crash helmet",
532
+ "519": "crate",
533
+ "520": "crib, cot",
534
+ "521": "Crock Pot",
535
+ "522": "croquet ball",
536
+ "523": "crutch",
537
+ "524": "cuirass",
538
+ "525": "dam, dike, dyke",
539
+ "526": "desk",
540
+ "527": "desktop computer",
541
+ "528": "dial telephone, dial phone",
542
+ "529": "diaper, nappy, napkin",
543
+ "530": "digital clock",
544
+ "531": "digital watch",
545
+ "532": "dining table, board",
546
+ "533": "dishrag, dishcloth",
547
+ "534": "dishwasher, dish washer, dishwashing machine",
548
+ "535": "disk brake, disc brake",
549
+ "536": "dock, dockage, docking facility",
550
+ "537": "dogsled, dog sled, dog sleigh",
551
+ "538": "dome",
552
+ "539": "doormat, welcome mat",
553
+ "540": "drilling platform, offshore rig",
554
+ "541": "drum, membranophone, tympan",
555
+ "542": "drumstick",
556
+ "543": "dumbbell",
557
+ "544": "Dutch oven",
558
+ "545": "electric fan, blower",
559
+ "546": "electric guitar",
560
+ "547": "electric locomotive",
561
+ "548": "entertainment center",
562
+ "549": "envelope",
563
+ "550": "espresso maker",
564
+ "551": "face powder",
565
+ "552": "feather boa, boa",
566
+ "553": "file, file cabinet, filing cabinet",
567
+ "554": "fireboat",
568
+ "555": "fire engine, fire truck",
569
+ "556": "fire screen, fireguard",
570
+ "557": "flagpole, flagstaff",
571
+ "558": "flute, transverse flute",
572
+ "559": "folding chair",
573
+ "560": "football helmet",
574
+ "561": "forklift",
575
+ "562": "fountain",
576
+ "563": "fountain pen",
577
+ "564": "four-poster",
578
+ "565": "freight car",
579
+ "566": "French horn, horn",
580
+ "567": "frying pan, frypan, skillet",
581
+ "568": "fur coat",
582
+ "569": "garbage truck, dustcart",
583
+ "570": "gasmask, respirator, gas helmet",
584
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
585
+ "572": "goblet",
586
+ "573": "go-kart",
587
+ "574": "golf ball",
588
+ "575": "golfcart, golf cart",
589
+ "576": "gondola",
590
+ "577": "gong, tam-tam",
591
+ "578": "gown",
592
+ "579": "grand piano, grand",
593
+ "580": "greenhouse, nursery, glasshouse",
594
+ "581": "grille, radiator grille",
595
+ "582": "grocery store, grocery, food market, market",
596
+ "583": "guillotine",
597
+ "584": "hair slide",
598
+ "585": "hair spray",
599
+ "586": "half track",
600
+ "587": "hammer",
601
+ "588": "hamper",
602
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
603
+ "590": "hand-held computer, hand-held microcomputer",
604
+ "591": "handkerchief, hankie, hanky, hankey",
605
+ "592": "hard disc, hard disk, fixed disk",
606
+ "593": "harmonica, mouth organ, harp, mouth harp",
607
+ "594": "harp",
608
+ "595": "harvester, reaper",
609
+ "596": "hatchet",
610
+ "597": "holster",
611
+ "598": "home theater, home theatre",
612
+ "599": "honeycomb",
613
+ "600": "hook, claw",
614
+ "601": "hoopskirt, crinoline",
615
+ "602": "horizontal bar, high bar",
616
+ "603": "horse cart, horse-cart",
617
+ "604": "hourglass",
618
+ "605": "iPod",
619
+ "606": "iron, smoothing iron",
620
+ "607": "jack-o-lantern",
621
+ "608": "jean, blue jean, denim",
622
+ "609": "jeep, landrover",
623
+ "610": "jersey, T-shirt, tee shirt",
624
+ "611": "jigsaw puzzle",
625
+ "612": "jinrikisha, ricksha, rickshaw",
626
+ "613": "joystick",
627
+ "614": "kimono",
628
+ "615": "knee pad",
629
+ "616": "knot",
630
+ "617": "lab coat, laboratory coat",
631
+ "618": "ladle",
632
+ "619": "lampshade, lamp shade",
633
+ "620": "laptop, laptop computer",
634
+ "621": "lawn mower, mower",
635
+ "622": "lens cap, lens cover",
636
+ "623": "letter opener, paper knife, paperknife",
637
+ "624": "library",
638
+ "625": "lifeboat",
639
+ "626": "lighter, light, igniter, ignitor",
640
+ "627": "limousine, limo",
641
+ "628": "liner, ocean liner",
642
+ "629": "lipstick, lip rouge",
643
+ "630": "Loafer",
644
+ "631": "lotion",
645
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
646
+ "633": "loupe, jewelers loupe",
647
+ "634": "lumbermill, sawmill",
648
+ "635": "magnetic compass",
649
+ "636": "mailbag, postbag",
650
+ "637": "mailbox, letter box",
651
+ "638": "maillot",
652
+ "639": "maillot, tank suit",
653
+ "640": "manhole cover",
654
+ "641": "maraca",
655
+ "642": "marimba, xylophone",
656
+ "643": "mask",
657
+ "644": "matchstick",
658
+ "645": "maypole",
659
+ "646": "maze, labyrinth",
660
+ "647": "measuring cup",
661
+ "648": "medicine chest, medicine cabinet",
662
+ "649": "megalith, megalithic structure",
663
+ "650": "microphone, mike",
664
+ "651": "microwave, microwave oven",
665
+ "652": "military uniform",
666
+ "653": "milk can",
667
+ "654": "minibus",
668
+ "655": "miniskirt, mini",
669
+ "656": "minivan",
670
+ "657": "missile",
671
+ "658": "mitten",
672
+ "659": "mixing bowl",
673
+ "660": "mobile home, manufactured home",
674
+ "661": "Model T",
675
+ "662": "modem",
676
+ "663": "monastery",
677
+ "664": "monitor",
678
+ "665": "moped",
679
+ "666": "mortar",
680
+ "667": "mortarboard",
681
+ "668": "mosque",
682
+ "669": "mosquito net",
683
+ "670": "motor scooter, scooter",
684
+ "671": "mountain bike, all-terrain bike, off-roader",
685
+ "672": "mountain tent",
686
+ "673": "mouse, computer mouse",
687
+ "674": "mousetrap",
688
+ "675": "moving van",
689
+ "676": "muzzle",
690
+ "677": "nail",
691
+ "678": "neck brace",
692
+ "679": "necklace",
693
+ "680": "nipple",
694
+ "681": "notebook, notebook computer",
695
+ "682": "obelisk",
696
+ "683": "oboe, hautboy, hautbois",
697
+ "684": "ocarina, sweet potato",
698
+ "685": "odometer, hodometer, mileometer, milometer",
699
+ "686": "oil filter",
700
+ "687": "organ, pipe organ",
701
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
702
+ "689": "overskirt",
703
+ "690": "oxcart",
704
+ "691": "oxygen mask",
705
+ "692": "packet",
706
+ "693": "paddle, boat paddle",
707
+ "694": "paddlewheel, paddle wheel",
708
+ "695": "padlock",
709
+ "696": "paintbrush",
710
+ "697": "pajama, pyjama, pjs, jammies",
711
+ "698": "palace",
712
+ "699": "panpipe, pandean pipe, syrinx",
713
+ "700": "paper towel",
714
+ "701": "parachute, chute",
715
+ "702": "parallel bars, bars",
716
+ "703": "park bench",
717
+ "704": "parking meter",
718
+ "705": "passenger car, coach, carriage",
719
+ "706": "patio, terrace",
720
+ "707": "pay-phone, pay-station",
721
+ "708": "pedestal, plinth, footstall",
722
+ "709": "pencil box, pencil case",
723
+ "710": "pencil sharpener",
724
+ "711": "perfume, essence",
725
+ "712": "Petri dish",
726
+ "713": "photocopier",
727
+ "714": "pick, plectrum, plectron",
728
+ "715": "pickelhaube",
729
+ "716": "picket fence, paling",
730
+ "717": "pickup, pickup truck",
731
+ "718": "pier",
732
+ "719": "piggy bank, penny bank",
733
+ "720": "pill bottle",
734
+ "721": "pillow",
735
+ "722": "ping-pong ball",
736
+ "723": "pinwheel",
737
+ "724": "pirate, pirate ship",
738
+ "725": "pitcher, ewer",
739
+ "726": "plane, carpenters plane, woodworking plane",
740
+ "727": "planetarium",
741
+ "728": "plastic bag",
742
+ "729": "plate rack",
743
+ "730": "plow, plough",
744
+ "731": "plunger, plumbers helper",
745
+ "732": "Polaroid camera, Polaroid Land camera",
746
+ "733": "pole",
747
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
748
+ "735": "poncho",
749
+ "736": "pool table, billiard table, snooker table",
750
+ "737": "pop bottle, soda bottle",
751
+ "738": "pot, flowerpot",
752
+ "739": "potters wheel",
753
+ "740": "power drill",
754
+ "741": "prayer rug, prayer mat",
755
+ "742": "printer",
756
+ "743": "prison, prison house",
757
+ "744": "projectile, missile",
758
+ "745": "projector",
759
+ "746": "puck, hockey puck",
760
+ "747": "punching bag, punch bag, punching ball, punchball",
761
+ "748": "purse",
762
+ "749": "quill, quill pen",
763
+ "750": "quilt, comforter, comfort, puff",
764
+ "751": "racer, race car, racing car",
765
+ "752": "racket, racquet",
766
+ "753": "radiator",
767
+ "754": "radio, wireless",
768
+ "755": "radio telescope, radio reflector",
769
+ "756": "rain barrel",
770
+ "757": "recreational vehicle, RV, R.V.",
771
+ "758": "reel",
772
+ "759": "reflex camera",
773
+ "760": "refrigerator, icebox",
774
+ "761": "remote control, remote",
775
+ "762": "restaurant, eating house, eating place, eatery",
776
+ "763": "revolver, six-gun, six-shooter",
777
+ "764": "rifle",
778
+ "765": "rocking chair, rocker",
779
+ "766": "rotisserie",
780
+ "767": "rubber eraser, rubber, pencil eraser",
781
+ "768": "rugby ball",
782
+ "769": "rule, ruler",
783
+ "770": "running shoe",
784
+ "771": "safe",
785
+ "772": "safety pin",
786
+ "773": "saltshaker, salt shaker",
787
+ "774": "sandal",
788
+ "775": "sarong",
789
+ "776": "sax, saxophone",
790
+ "777": "scabbard",
791
+ "778": "scale, weighing machine",
792
+ "779": "school bus",
793
+ "780": "schooner",
794
+ "781": "scoreboard",
795
+ "782": "screen, CRT screen",
796
+ "783": "screw",
797
+ "784": "screwdriver",
798
+ "785": "seat belt, seatbelt",
799
+ "786": "sewing machine",
800
+ "787": "shield, buckler",
801
+ "788": "shoe shop, shoe-shop, shoe store",
802
+ "789": "shoji",
803
+ "790": "shopping basket",
804
+ "791": "shopping cart",
805
+ "792": "shovel",
806
+ "793": "shower cap",
807
+ "794": "shower curtain",
808
+ "795": "ski",
809
+ "796": "ski mask",
810
+ "797": "sleeping bag",
811
+ "798": "slide rule, slipstick",
812
+ "799": "sliding door",
813
+ "800": "slot, one-armed bandit",
814
+ "801": "snorkel",
815
+ "802": "snowmobile",
816
+ "803": "snowplow, snowplough",
817
+ "804": "soap dispenser",
818
+ "805": "soccer ball",
819
+ "806": "sock",
820
+ "807": "solar dish, solar collector, solar furnace",
821
+ "808": "sombrero",
822
+ "809": "soup bowl",
823
+ "810": "space bar",
824
+ "811": "space heater",
825
+ "812": "space shuttle",
826
+ "813": "spatula",
827
+ "814": "speedboat",
828
+ "815": "spider web, spiders web",
829
+ "816": "spindle",
830
+ "817": "sports car, sport car",
831
+ "818": "spotlight, spot",
832
+ "819": "stage",
833
+ "820": "steam locomotive",
834
+ "821": "steel arch bridge",
835
+ "822": "steel drum",
836
+ "823": "stethoscope",
837
+ "824": "stole",
838
+ "825": "stone wall",
839
+ "826": "stopwatch, stop watch",
840
+ "827": "stove",
841
+ "828": "strainer",
842
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
843
+ "830": "stretcher",
844
+ "831": "studio couch, day bed",
845
+ "832": "stupa, tope",
846
+ "833": "submarine, pigboat, sub, U-boat",
847
+ "834": "suit, suit of clothes",
848
+ "835": "sundial",
849
+ "836": "sunglass",
850
+ "837": "sunglasses, dark glasses, shades",
851
+ "838": "sunscreen, sunblock, sun blocker",
852
+ "839": "suspension bridge",
853
+ "840": "swab, swob, mop",
854
+ "841": "sweatshirt",
855
+ "842": "swimming trunks, bathing trunks",
856
+ "843": "swing",
857
+ "844": "switch, electric switch, electrical switch",
858
+ "845": "syringe",
859
+ "846": "table lamp",
860
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
861
+ "848": "tape player",
862
+ "849": "teapot",
863
+ "850": "teddy, teddy bear",
864
+ "851": "television, television system",
865
+ "852": "tennis ball",
866
+ "853": "thatch, thatched roof",
867
+ "854": "theater curtain, theatre curtain",
868
+ "855": "thimble",
869
+ "856": "thresher, thrasher, threshing machine",
870
+ "857": "throne",
871
+ "858": "tile roof",
872
+ "859": "toaster",
873
+ "860": "tobacco shop, tobacconist shop, tobacconist",
874
+ "861": "toilet seat",
875
+ "862": "torch",
876
+ "863": "totem pole",
877
+ "864": "tow truck, tow car, wrecker",
878
+ "865": "toyshop",
879
+ "866": "tractor",
880
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
881
+ "868": "tray",
882
+ "869": "trench coat",
883
+ "870": "tricycle, trike, velocipede",
884
+ "871": "trimaran",
885
+ "872": "tripod",
886
+ "873": "triumphal arch",
887
+ "874": "trolleybus, trolley coach, trackless trolley",
888
+ "875": "trombone",
889
+ "876": "tub, vat",
890
+ "877": "turnstile",
891
+ "878": "typewriter keyboard",
892
+ "879": "umbrella",
893
+ "880": "unicycle, monocycle",
894
+ "881": "upright, upright piano",
895
+ "882": "vacuum, vacuum cleaner",
896
+ "883": "vase",
897
+ "884": "vault",
898
+ "885": "velvet",
899
+ "886": "vending machine",
900
+ "887": "vestment",
901
+ "888": "viaduct",
902
+ "889": "violin, fiddle",
903
+ "890": "volleyball",
904
+ "891": "waffle iron",
905
+ "892": "wall clock",
906
+ "893": "wallet, billfold, notecase, pocketbook",
907
+ "894": "wardrobe, closet, press",
908
+ "895": "warplane, military plane",
909
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
910
+ "897": "washer, automatic washer, washing machine",
911
+ "898": "water bottle",
912
+ "899": "water jug",
913
+ "900": "water tower",
914
+ "901": "whiskey jug",
915
+ "902": "whistle",
916
+ "903": "wig",
917
+ "904": "window screen",
918
+ "905": "window shade",
919
+ "906": "Windsor tie",
920
+ "907": "wine bottle",
921
+ "908": "wing",
922
+ "909": "wok",
923
+ "910": "wooden spoon",
924
+ "911": "wool, woolen, woollen",
925
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
926
+ "913": "wreck",
927
+ "914": "yawl",
928
+ "915": "yurt",
929
+ "916": "web site, website, internet site, site",
930
+ "917": "comic book",
931
+ "918": "crossword puzzle, crossword",
932
+ "919": "street sign",
933
+ "920": "traffic light, traffic signal, stoplight",
934
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
935
+ "922": "menu",
936
+ "923": "plate",
937
+ "924": "guacamole",
938
+ "925": "consomme",
939
+ "926": "hot pot, hotpot",
940
+ "927": "trifle",
941
+ "928": "ice cream, icecream",
942
+ "929": "ice lolly, lolly, lollipop, popsicle",
943
+ "930": "French loaf",
944
+ "931": "bagel, beigel",
945
+ "932": "pretzel",
946
+ "933": "cheeseburger",
947
+ "934": "hotdog, hot dog, red hot",
948
+ "935": "mashed potato",
949
+ "936": "head cabbage",
950
+ "937": "broccoli",
951
+ "938": "cauliflower",
952
+ "939": "zucchini, courgette",
953
+ "940": "spaghetti squash",
954
+ "941": "acorn squash",
955
+ "942": "butternut squash",
956
+ "943": "cucumber, cuke",
957
+ "944": "artichoke, globe artichoke",
958
+ "945": "bell pepper",
959
+ "946": "cardoon",
960
+ "947": "mushroom",
961
+ "948": "Granny Smith",
962
+ "949": "strawberry",
963
+ "950": "orange",
964
+ "951": "lemon",
965
+ "952": "fig",
966
+ "953": "pineapple, ananas",
967
+ "954": "banana",
968
+ "955": "jackfruit, jak, jack",
969
+ "956": "custard apple",
970
+ "957": "pomegranate",
971
+ "958": "hay",
972
+ "959": "carbonara",
973
+ "960": "chocolate sauce, chocolate syrup",
974
+ "961": "dough",
975
+ "962": "meat loaf, meatloaf",
976
+ "963": "pizza, pizza pie",
977
+ "964": "potpie",
978
+ "965": "burrito",
979
+ "966": "red wine",
980
+ "967": "espresso",
981
+ "968": "cup",
982
+ "969": "eggnog",
983
+ "970": "alp",
984
+ "971": "bubble",
985
+ "972": "cliff, drop, drop-off",
986
+ "973": "coral reef",
987
+ "974": "geyser",
988
+ "975": "lakeside, lakeshore",
989
+ "976": "promontory, headland, head, foreland",
990
+ "977": "sandbar, sand bar",
991
+ "978": "seashore, coast, seacoast, sea-coast",
992
+ "979": "valley, vale",
993
+ "980": "volcano",
994
+ "981": "ballplayer, baseball player",
995
+ "982": "groom, bridegroom",
996
+ "983": "scuba diver",
997
+ "984": "rapeseed",
998
+ "985": "daisy",
999
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1000
+ "987": "corn",
1001
+ "988": "acorn",
1002
+ "989": "hip, rose hip, rosehip",
1003
+ "990": "buckeye, horse chestnut, conker",
1004
+ "991": "coral fungus",
1005
+ "992": "agaric",
1006
+ "993": "gyromitra",
1007
+ "994": "stinkhorn, carrion fungus",
1008
+ "995": "earthstar",
1009
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1010
+ "997": "bolete",
1011
+ "998": "ear, spike, capitulum",
1012
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1013
+ }
1014
  }
PixelFlow-256/pipeline.py CHANGED
@@ -1,16 +1,23 @@
1
- """Hub custom pipeline: PixelFlowPipeline.
2
-
3
- Load with native Hugging Face diffusers and `trust_remote_code=True`.
4
- """
5
-
6
- from __future__ import annotations
 
 
 
 
 
 
 
7
 
8
  import importlib
 
9
  import math
10
  import sys
11
- from dataclasses import dataclass
12
  from pathlib import Path
13
- from typing import List, Optional, Tuple, Union
14
 
15
  import numpy as np
16
  import torch
@@ -19,25 +26,83 @@ from einops import rearrange
19
 
20
  from diffusers.image_processor import VaeImageProcessor
21
  from diffusers.models.embeddings import get_2d_rotary_pos_embed
22
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23
- from diffusers.utils import BaseOutput
 
24
  from diffusers.utils.torch_utils import randn_tensor
25
 
26
 
27
- @dataclass
28
- class PixelFlowPipelineOutput(BaseOutput):
29
- images: Union[torch.Tensor, List, np.ndarray]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  class PixelFlowPipeline(DiffusionPipeline):
33
- """Pipeline for PixelFlow pixel-space flow generation (class-conditional or text-to-image)."""
 
 
 
 
 
 
 
 
 
 
34
 
35
- model_cpu_offload_seq = "text_encoder->transformer"
36
- _optional_components = ["text_encoder", "tokenizer"]
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  @classmethod
39
  def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
40
  """Load a self-contained variant folder locally or from the Hub."""
 
 
 
41
  repo_root = Path(__file__).resolve().parent
42
 
43
  if pretrained_model_name_or_path in (None, "", "."):
@@ -62,109 +127,78 @@ class PixelFlowPipeline(DiffusionPipeline):
62
  if subfolder:
63
  variant = variant / subfolder
64
 
 
 
65
  model_kwargs = dict(kwargs)
66
- inserted: List[str] = []
67
-
68
- def _load_component(folder: str, module_name: str, class_name: str):
69
- comp_dir = variant / folder
70
- module_path = comp_dir / f"{module_name}.py"
71
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
72
- if not module_path.exists() or not has_weights:
73
- return None
74
-
75
- comp_path = str(comp_dir)
76
- if comp_path not in sys.path:
77
- sys.path.insert(0, comp_path)
78
- inserted.append(comp_path)
79
-
80
- module = importlib.import_module(module_name)
81
- component_cls = getattr(module, class_name)
82
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
83
-
84
- def _load_text_components():
85
- text_encoder = None
86
- tokenizer = None
87
- te_dir = variant / "text_encoder"
88
- tok_dir = variant / "tokenizer"
89
- if te_dir.exists() and (te_dir / "config.json").exists():
90
- from transformers import T5EncoderModel, T5Tokenizer
91
-
92
- text_encoder = T5EncoderModel.from_pretrained(str(te_dir), **model_kwargs)
93
- tokenizer = T5Tokenizer.from_pretrained(str(tok_dir))
94
- return text_encoder, tokenizer
95
 
96
  try:
97
- transformer = _load_component("transformer", "transformer_pixelflow", "PixelFlowTransformer2DModel")
98
- scheduler = _load_component("scheduler", "scheduling_pixelflow", "PixelFlowScheduler")
99
- text_encoder, tokenizer = _load_text_components()
100
-
101
- if scheduler is None:
102
- sched_dir = variant / "scheduler"
103
- if (sched_dir / "scheduling_pixelflow.py").exists():
104
- sched_path = str(sched_dir)
105
- if sched_path not in sys.path:
106
- sys.path.insert(0, sched_path)
107
- inserted.append(sched_path)
108
- scheduler = importlib.import_module("scheduling_pixelflow").PixelFlowScheduler()
109
-
110
- if transformer is None:
111
  raise ValueError(f"No loadable transformer found under {variant}")
112
 
113
- id2label = None
114
- id2label_cn = None
115
- labels_dir = variant.parent / "labels"
116
- if labels_dir.is_dir():
117
- labels_path = str(labels_dir)
118
- if labels_path not in sys.path:
119
- sys.path.insert(0, labels_path)
120
- inserted.append(labels_path)
121
- from imagenet_labels import load_id2label
122
-
123
- id2label = load_id2label(labels_dir, lang="en")
124
- id2label_cn = load_id2label(labels_dir, lang="cn")
125
-
126
- return cls(
127
- transformer=transformer,
128
- scheduler=scheduler,
129
- text_encoder=text_encoder,
130
- tokenizer=tokenizer,
131
- id2label=id2label,
132
- id2label_cn=id2label_cn,
133
- )
134
  finally:
135
  for comp_path in inserted:
136
  if comp_path in sys.path:
137
  sys.path.remove(comp_path)
138
 
139
- def __init__(
140
- self,
141
- transformer,
142
- scheduler,
143
- text_encoder=None,
144
- tokenizer=None,
145
- max_token_length: int = 512,
146
- id2label: Optional[dict[int, str]] = None,
147
- id2label_cn: Optional[dict[int, str]] = None,
148
- ):
149
- super().__init__()
150
- self.register_modules(
151
- transformer=transformer,
152
- scheduler=scheduler,
153
- text_encoder=text_encoder,
154
- tokenizer=tokenizer,
155
- )
156
- self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
157
- self.class_cond = transformer.config.num_classes > 0
158
- self.max_token_length = max_token_length
159
 
160
- self._id2label = id2label or {}
161
- self._id2label_cn = id2label_cn or {}
162
- self.labels = self._build_label2id(self._id2label)
163
- self.labels_cn = self._build_label2id(self._id2label_cn)
 
 
 
 
 
 
 
 
164
 
165
  @staticmethod
166
- def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
167
- label2id: dict[str, int] = {}
168
  for class_id, value in id2label.items():
169
  for synonym in value.split(","):
170
  synonym = synonym.strip()
@@ -173,37 +207,23 @@ class PixelFlowPipeline(DiffusionPipeline):
173
  return dict(sorted(label2id.items()))
174
 
175
  @property
176
- def id2label(self) -> dict[int, str]:
177
- """ImageNet class id to English label string (comma-separated synonyms)."""
 
178
  return self._id2label
179
 
180
- @property
181
- def id2label_cn(self) -> dict[int, str]:
182
- """ImageNet class id to Chinese label string (comma-separated synonyms)."""
183
- return self._id2label_cn
184
-
185
- def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
186
  r"""
187
  Map ImageNet label strings to class ids.
188
 
189
  Args:
190
  label (`str` or `list[str]`):
191
- One or more label strings. Each string must match a synonym in `id2label` (English)
192
- or `id2label_cn` (Chinese).
193
- lang (`str`, *optional*, defaults to `"en"`):
194
- `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
195
-
196
- Returns:
197
- `list[int]`: Class ids for [`~PixelFlowPipeline.__call__`].
198
  """
199
- if lang not in ("en", "cn"):
200
- raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
201
-
202
- label2id = self.labels if lang == "en" else self.labels_cn
203
  if not label2id:
204
- raise ValueError(
205
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
206
- )
207
 
208
  if isinstance(label, str):
209
  label = [label]
@@ -211,279 +231,246 @@ class PixelFlowPipeline(DiffusionPipeline):
211
  missing = [item for item in label if item not in label2id]
212
  if missing:
213
  preview = ", ".join(list(label2id.keys())[:8])
214
- raise ValueError(
215
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
216
- )
217
  return [label2id[item] for item in label]
218
 
219
  def _normalize_class_labels(
220
  self,
221
- class_labels: Optional[Union[int, str, List[Union[int, str]], torch.Tensor]],
222
- ) -> Optional[Union[int, List[int], torch.Tensor]]:
223
- if class_labels is None:
224
- return None
225
-
226
- if isinstance(class_labels, str):
227
- return self.get_label_ids(class_labels)[0]
228
-
229
- if isinstance(class_labels, list) and class_labels and isinstance(class_labels[0], str):
230
- if all(label in self.labels for label in class_labels):
231
- return self.get_label_ids(class_labels, lang="en")
232
- if all(label in self.labels_cn for label in class_labels):
233
- return self.get_label_ids(class_labels, lang="cn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  raise ValueError(
235
- "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
236
- "or Chinese synonyms from `pipe.labels_cn`."
237
  )
 
238
 
239
- return class_labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
- def sample_block_noise(self, bs, ch, height, width, eps=1e-6):
 
 
 
 
 
 
 
242
  gamma = self.scheduler.gamma
243
  dist = torch.distributions.multivariate_normal.MultivariateNormal(
244
  torch.zeros(4),
245
  torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
246
  )
247
- block_number = bs * ch * (height // 2) * (width // 2)
248
  noise = torch.stack([dist.sample() for _ in range(block_number)])
249
- noise = rearrange(
250
  noise,
251
  "(b c h w) (p q) -> b c (h p) (w q)",
252
- b=bs,
253
- c=ch,
254
  h=height // 2,
255
  w=width // 2,
256
  p=2,
257
  q=2,
258
  )
259
- return noise
260
-
261
- def _stage_guidance_scale(self, stage_idx: int) -> float:
262
- if not self.class_cond:
263
- return self._guidance_scale_value
264
- scale_dict = {0: 0, 1: 1 / 6, 2: 2 / 3, 3: 1}
265
- return (self._guidance_scale_value - 1) * scale_dict[stage_idx] + 1
266
 
267
- @property
268
- def do_classifier_free_guidance(self) -> bool:
269
- return self._guidance_scale_value > 0
270
-
271
- @torch.no_grad()
272
- def encode_prompt(
273
  self,
274
- prompt: Union[str, List[str]],
 
 
 
275
  device: torch.device,
276
- num_images_per_prompt: int = 1,
277
- do_classifier_free_guidance: bool = True,
278
- negative_prompt: Union[str, List[str]] = "",
279
- max_length: Optional[int] = None,
280
- ) -> Tuple[torch.Tensor, torch.Tensor]:
281
- if self.text_encoder is None or self.tokenizer is None:
282
- raise ValueError("Text-to-image generation requires `text_encoder` and `tokenizer`.")
283
-
284
- if isinstance(prompt, str):
285
- prompt = [prompt]
286
- batch_size = len(prompt)
287
- max_length = max_length or self.max_token_length
288
-
289
- text_inputs = self.tokenizer(
290
- prompt,
291
- padding="max_length",
292
- max_length=max_length,
293
- truncation=True,
294
- add_special_tokens=True,
295
- return_tensors="pt",
296
  )
297
- text_input_ids = text_inputs.input_ids.to(device)
298
- prompt_attention_mask = text_inputs.attention_mask.to(device)
299
- prompt_embeds = self.text_encoder(
300
- text_input_ids,
301
- attention_mask=prompt_attention_mask,
302
- )[0]
303
-
304
- dtype = self.text_encoder.dtype
305
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
306
- bs_embed, seq_len, _ = prompt_embeds.shape
307
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
308
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
309
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
310
-
311
- if do_classifier_free_guidance:
312
- if isinstance(negative_prompt, str):
313
- uncond_tokens = [negative_prompt] * batch_size
314
- elif isinstance(negative_prompt, list):
315
- if len(negative_prompt) != batch_size:
316
- raise ValueError(
317
- f"Negative prompt list length ({len(negative_prompt)}) must match prompt batch ({batch_size})."
318
- )
319
- uncond_tokens = negative_prompt
320
- else:
321
- raise ValueError("Negative prompt must be a string or list of strings.")
322
-
323
- uncond_inputs = self.tokenizer(
324
- uncond_tokens,
325
- padding="max_length",
326
- max_length=prompt_embeds.shape[1],
327
- truncation=True,
328
- return_attention_mask=True,
329
- add_special_tokens=True,
330
- return_tensors="pt",
331
- )
332
- negative_input_ids = uncond_inputs.input_ids.to(device)
333
- negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
334
- negative_prompt_embeds = self.text_encoder(
335
- negative_input_ids,
336
- attention_mask=negative_prompt_attention_mask,
337
- )[0]
338
 
339
- seq_len_neg = negative_prompt_embeds.shape[1]
340
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
341
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
342
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
343
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
344
-
345
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
346
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
347
 
348
- return prompt_embeds, prompt_attention_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
 
350
- @torch.no_grad()
 
351
  def __call__(
352
  self,
353
- prompt: Optional[Union[str, List[str]]] = None,
354
- class_labels: Optional[Union[int, str, List[Union[int, str]], torch.Tensor]] = None,
355
  height: Optional[int] = None,
356
  width: Optional[int] = None,
357
  num_inference_steps: Union[int, List[int]] = 10,
358
  guidance_scale: float = 4.0,
359
  shift: float = 1.0,
360
- negative_prompt: Union[str, List[str]] = "",
361
- num_images_per_prompt: int = 1,
362
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
363
  output_type: str = "pil",
364
  return_dict: bool = True,
365
- ) -> Union[PixelFlowPipelineOutput, Tuple]:
366
- if height is None:
367
- height = int(self.transformer.config.sample_size)
368
- if width is None:
369
- width = int(self.transformer.config.sample_size)
370
-
371
- device = self._execution_device
372
- self._guidance_scale_value = guidance_scale
373
 
374
- if isinstance(num_inference_steps, int):
375
- num_inference_steps = [num_inference_steps] * self.scheduler.num_stages
376
-
377
- prompt_attention_mask = None
378
- if self.class_cond:
379
- if class_labels is None:
380
- raise ValueError("`class_labels` are required for class-conditional PixelFlow checkpoints.")
381
- class_labels = self._normalize_class_labels(class_labels)
382
- if isinstance(class_labels, int):
383
- class_labels = [class_labels]
384
- if not torch.is_tensor(class_labels):
385
- class_labels = torch.tensor(class_labels, device=device, dtype=torch.long)
386
- else:
387
- class_labels = class_labels.to(device=device, dtype=torch.long)
388
-
389
- batch_size = class_labels.shape[0]
390
- prompt_embeds = class_labels
391
- negative_prompt_embeds = torch.full_like(prompt_embeds, self.transformer.config.num_classes)
392
- if self.do_classifier_free_guidance:
393
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
394
- else:
395
- if prompt is None:
396
- raise ValueError("`prompt` is required for text-to-image PixelFlow checkpoints.")
397
- if isinstance(prompt, str):
398
- prompt = [prompt]
399
- batch_size = len(prompt)
400
- prompt_embeds, prompt_attention_mask = self.encode_prompt(
401
- prompt,
402
- device,
403
- num_images_per_prompt=num_images_per_prompt,
404
- do_classifier_free_guidance=self.do_classifier_free_guidance and guidance_scale > 1.0,
405
- negative_prompt=negative_prompt,
406
- )
407
 
408
- init_factor = 2 ** (self.scheduler.num_stages - 1)
409
- height, width = height // init_factor, width // init_factor
410
- latents = randn_tensor(
411
- (batch_size * num_images_per_prompt, 3, height, width),
412
- generator=generator,
413
- device=device,
414
- dtype=torch.float32,
415
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
417
- for stage_idx in range(self.scheduler.num_stages):
418
- self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
419
- timesteps = self.scheduler.Timesteps
420
-
421
- if stage_idx > 0:
422
- height, width = height * 2, width * 2
423
- latents = F.interpolate(latents, size=(height, width), mode="nearest")
424
- original_start_t = self.scheduler.original_start_t[stage_idx]
425
- gamma = self.scheduler.gamma
426
- alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
427
- beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
428
-
429
- noise = self.sample_block_noise(*latents.shape)
430
- noise = noise.to(device=device, dtype=latents.dtype)
431
- latents = alpha * latents + beta * noise
432
-
433
- size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
434
- pos_embed = get_2d_rotary_pos_embed(
435
- embed_dim=self.transformer.attention_head_dim,
436
- crops_coords=((0, 0), (latents.shape[-1] // self.transformer.patch_size, latents.shape[-1] // self.transformer.patch_size)),
437
- grid_size=(latents.shape[-1] // self.transformer.patch_size, latents.shape[-1] // self.transformer.patch_size),
438
- device=device,
439
- output_type="pt",
440
- )
441
- rope_pos = torch.stack(pos_embed, -1)
442
-
443
- autocast_enabled = device.type == "cuda"
444
- autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
445
- for timestep in timesteps:
446
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
447
- timestep_batch = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
448
- with torch.autocast(device.type, enabled=autocast_enabled, dtype=autocast_dtype):
449
- if self.class_cond:
450
- noise_pred = self.transformer(
451
- latent_model_input,
452
- timestep=timestep_batch,
453
- class_labels=prompt_embeds,
454
- latent_size=size_tensor,
455
- pos_embed=rope_pos,
456
- ).sample
457
- else:
458
  noise_pred = self.transformer(
459
  latent_model_input,
460
- encoder_hidden_states=prompt_embeds,
461
- encoder_attention_mask=prompt_attention_mask,
462
  timestep=timestep_batch,
 
463
  latent_size=size_tensor,
464
  pos_embed=rope_pos,
465
  ).sample
466
 
467
- if self.do_classifier_free_guidance:
468
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
469
- noise_pred = noise_pred_uncond + self._stage_guidance_scale(stage_idx) * (
470
- noise_pred_text - noise_pred_uncond
471
- )
472
 
473
- latents = self.scheduler.step(model_output=noise_pred, sample=latents).prev_sample
474
-
475
- image = (latents / 2 + 0.5).clamp(0, 1)
476
-
477
- if output_type == "pt":
478
- pass
479
- elif output_type in ("pil", "np"):
480
- image = self.image_processor.postprocess(image, output_type=output_type)
481
- else:
482
- raise ValueError(f"Unsupported output_type: {output_type}")
483
 
 
484
  self.maybe_free_model_hooks()
485
 
486
  if not return_dict:
487
  return (image,)
 
488
 
489
- return PixelFlowPipelineOutput(images=image)
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
 
15
  import importlib
16
+ import json
17
  import math
18
  import sys
 
19
  from pathlib import Path
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
 
22
  import numpy as np
23
  import torch
 
26
 
27
  from diffusers.image_processor import VaeImageProcessor
28
  from diffusers.models.embeddings import get_2d_rotary_pos_embed
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
30
+ from diffusers.schedulers import KarrasDiffusionSchedulers
31
+ from diffusers.utils import replace_example_docstring
32
  from diffusers.utils.torch_utils import randn_tensor
33
 
34
 
35
+ DEFAULT_NATIVE_RESOLUTION = 256
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```py
40
+ >>> from pathlib import Path
41
+ >>> import torch
42
+ >>> from diffusers import DiffusionPipeline
43
+
44
+ >>> model_dir = Path("./PixelFlow-256").resolve()
45
+ >>> pipe = DiffusionPipeline.from_pretrained(
46
+ ... str(model_dir),
47
+ ... local_files_only=True,
48
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
49
+ ... trust_remote_code=True,
50
+ ... torch_dtype=torch.bfloat16,
51
+ ... )
52
+ >>> pipe = pipe.to("cuda")
53
+
54
+ >>> print(pipe.id2label[207])
55
+ >>> print(pipe.get_label_ids("golden retriever"))
56
+
57
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
58
+ >>> image = pipe(
59
+ ... class_labels="golden retriever",
60
+ ... height=256,
61
+ ... width=256,
62
+ ... num_inference_steps=[10, 10, 10, 10],
63
+ ... guidance_scale=4.0,
64
+ ... generator=generator,
65
+ ... ).images[0]
66
+ >>> image.save("demo.png")
67
+ ```
68
+ """
69
+
70
 
71
 
72
  class PixelFlowPipeline(DiffusionPipeline):
73
+ r"""
74
+ Pipeline for class-conditional PixelFlow pixel-space cascade generation.
75
+
76
+ Parameters:
77
+ transformer ([`PixelFlowTransformer2DModel`]):
78
+ Class-conditional PixelFlow transformer operating in pixel space.
79
+ scheduler ([`PixelFlowScheduler`] or [`KarrasDiffusionSchedulers`]):
80
+ Multi-stage flow scheduler used by PixelFlow cascade denoising.
81
+ id2label (`dict[int, str]`, *optional*):
82
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
83
+ """
84
 
85
+ model_cpu_offload_seq = "transformer"
86
+
87
+ def __init__(
88
+ self,
89
+ transformer: Any,
90
+ scheduler: Any,
91
+ id2label: Optional[Dict[Union[int, str], str]] = None,
92
+ ):
93
+ super().__init__()
94
+ self.register_modules(transformer=transformer, scheduler=scheduler)
95
+ self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
96
+ self._id2label = self._normalize_id2label(id2label)
97
+ self.labels = self._build_label2id(self._id2label)
98
+ self._labels_loaded_from_model_index = bool(self._id2label)
99
 
100
  @classmethod
101
  def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
102
  """Load a self-contained variant folder locally or from the Hub."""
103
+ import importlib
104
+ import sys
105
+
106
  repo_root = Path(__file__).resolve().parent
107
 
108
  if pretrained_model_name_or_path in (None, "", "."):
 
127
  if subfolder:
128
  variant = variant / subfolder
129
 
130
+ id2label_override = kwargs.pop("id2label", None)
131
+ kwargs.pop("trust_remote_code", None)
132
  model_kwargs = dict(kwargs)
133
+ scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
134
+ inserted = []
135
+
136
+ def _ensure_path(path: str) -> None:
137
+ if path not in sys.path:
138
+ sys.path.insert(0, path)
139
+ inserted.append(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  try:
142
+ transformer_dir = variant / "transformer"
143
+ if not (transformer_dir / "transformer_pixelflow.py").exists() or not (transformer_dir / "config.json").exists():
 
 
 
 
 
 
 
 
 
 
 
 
144
  raise ValueError(f"No loadable transformer found under {variant}")
145
 
146
+ _ensure_path(str(transformer_dir))
147
+ transformer_cls = getattr(importlib.import_module("transformer_pixelflow"), "PixelFlowTransformer2DModel")
148
+ transformer = transformer_cls.from_pretrained(str(transformer_dir), **model_kwargs)
149
+
150
+ scheduler_dir = variant / "scheduler"
151
+ if not (scheduler_dir / "scheduler_config.json").exists():
152
+ raise FileNotFoundError(f"Expected scheduler config in {scheduler_dir}")
153
+
154
+ _ensure_path(str(scheduler_dir))
155
+ scheduler_cls = getattr(importlib.import_module("scheduling_pixelflow"), "PixelFlowScheduler")
156
+ try:
157
+ scheduler = scheduler_cls.from_pretrained(str(scheduler_dir), **scheduler_kwargs)
158
+ except Exception:
159
+ scheduler = scheduler_cls(**scheduler_kwargs)
160
+
161
+ id2label = id2label_override or cls._read_id2label_from_model_index(str(variant))
162
+ pipe = cls(transformer=transformer, scheduler=scheduler, id2label=id2label)
163
+ if hasattr(pipe, "register_to_config"):
164
+ pipe.register_to_config(_name_or_path=str(variant))
165
+ return pipe
 
166
  finally:
167
  for comp_path in inserted:
168
  if comp_path in sys.path:
169
  sys.path.remove(comp_path)
170
 
171
+ @staticmethod
172
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
173
+ if not id2label:
174
+ return {}
175
+ return {int(key): value for key, value in id2label.items()}
176
+
177
+ def _ensure_labels_loaded(self) -> None:
178
+ if self._labels_loaded_from_model_index:
179
+ return
180
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
181
+ if loaded:
182
+ self._id2label = loaded
183
+ self.labels = self._build_label2id(self._id2label)
184
+ self._labels_loaded_from_model_index = True
 
 
 
 
 
 
185
 
186
+ @staticmethod
187
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
188
+ if not variant_path:
189
+ return {}
190
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
191
+ if not model_index_path.exists():
192
+ return {}
193
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
194
+ id2label = raw.get("id2label")
195
+ if not isinstance(id2label, dict):
196
+ return {}
197
+ return {int(key): value for key, value in id2label.items()}
198
 
199
  @staticmethod
200
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
201
+ label2id: Dict[str, int] = {}
202
  for class_id, value in id2label.items():
203
  for synonym in value.split(","):
204
  synonym = synonym.strip()
 
207
  return dict(sorted(label2id.items()))
208
 
209
  @property
210
+ def id2label(self) -> Dict[int, str]:
211
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
212
+ self._ensure_labels_loaded()
213
  return self._id2label
214
 
215
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
 
 
 
 
 
216
  r"""
217
  Map ImageNet label strings to class ids.
218
 
219
  Args:
220
  label (`str` or `list[str]`):
221
+ One or more English label strings. Each string must match a synonym in `id2label`.
 
 
 
 
 
 
222
  """
223
+ self._ensure_labels_loaded()
224
+ label2id = self.labels
 
 
225
  if not label2id:
226
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
227
 
228
  if isinstance(label, str):
229
  label = [label]
 
231
  missing = [item for item in label if item not in label2id]
232
  if missing:
233
  preview = ", ".join(list(label2id.keys())[:8])
234
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
 
 
235
  return [label2id[item] for item in label]
236
 
237
  def _normalize_class_labels(
238
  self,
239
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
240
+ ) -> torch.LongTensor:
241
+ if torch.is_tensor(class_labels):
242
+ return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)
243
+
244
+ if isinstance(class_labels, int):
245
+ class_label_ids = [class_labels]
246
+ elif isinstance(class_labels, str):
247
+ class_label_ids = self.get_label_ids(class_labels)
248
+ elif class_labels and isinstance(class_labels[0], str):
249
+ class_label_ids = self.get_label_ids(class_labels)
250
+ else:
251
+ class_label_ids = list(class_labels)
252
+
253
+ return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
254
+
255
+ def check_inputs(
256
+ self,
257
+ height: int,
258
+ width: int,
259
+ num_inference_steps: Union[int, List[int]],
260
+ output_type: str,
261
+ ) -> None:
262
+ if output_type not in {"pil", "np", "pt", "latent"}:
263
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
264
+
265
+ stage_steps = self._normalize_stage_steps(num_inference_steps)
266
+ if any(steps < 1 for steps in stage_steps):
267
+ raise ValueError("Each stage in num_inference_steps must be >= 1.")
268
+
269
+ if height <= 0 or width <= 0:
270
+ raise ValueError("height and width must be positive integers.")
271
+
272
+ def _normalize_stage_steps(self, num_inference_steps: Union[int, List[int]]) -> List[int]:
273
+ if isinstance(num_inference_steps, int):
274
+ return [num_inference_steps] * self.scheduler.num_stages
275
+ if len(num_inference_steps) != self.scheduler.num_stages:
276
  raise ValueError(
277
+ f"num_inference_steps must have length {self.scheduler.num_stages} "
278
+ f"(one value per stage), got {len(num_inference_steps)}."
279
  )
280
+ return list(num_inference_steps)
281
 
282
+ def prepare_latents(
283
+ self,
284
+ batch_size: int,
285
+ height: int,
286
+ width: int,
287
+ device: torch.device,
288
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
289
+ ) -> Tuple[torch.Tensor, int, int]:
290
+ init_factor = 2 ** (self.scheduler.num_stages - 1)
291
+ coarse_height = height // init_factor
292
+ coarse_width = width // init_factor
293
+ latents = randn_tensor(
294
+ (batch_size, 3, coarse_height, coarse_width),
295
+ generator=generator,
296
+ device=device,
297
+ dtype=torch.float32,
298
+ )
299
+ return latents, coarse_height, coarse_width
300
 
301
+ def _sample_block_noise(
302
+ self,
303
+ batch_size: int,
304
+ channels: int,
305
+ height: int,
306
+ width: int,
307
+ eps: float = 1e-6,
308
+ ) -> torch.Tensor:
309
  gamma = self.scheduler.gamma
310
  dist = torch.distributions.multivariate_normal.MultivariateNormal(
311
  torch.zeros(4),
312
  torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
313
  )
314
+ block_number = batch_size * channels * (height // 2) * (width // 2)
315
  noise = torch.stack([dist.sample() for _ in range(block_number)])
316
+ return rearrange(
317
  noise,
318
  "(b c h w) (p q) -> b c (h p) (w q)",
319
+ b=batch_size,
320
+ c=channels,
321
  h=height // 2,
322
  w=width // 2,
323
  p=2,
324
  q=2,
325
  )
 
 
 
 
 
 
 
326
 
327
+ def _upsample_latents_for_stage(
 
 
 
 
 
328
  self,
329
+ latents: torch.Tensor,
330
+ stage_idx: int,
331
+ height: int,
332
+ width: int,
333
  device: torch.device,
334
+ ) -> torch.Tensor:
335
+ latents = F.interpolate(latents, size=(height, width), mode="nearest")
336
+ original_start_t = self.scheduler.original_start_t[stage_idx]
337
+ gamma = self.scheduler.gamma
338
+ alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
339
+ beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
340
+
341
+ noise = self._sample_block_noise(*latents.shape)
342
+ noise = noise.to(device=device, dtype=latents.dtype)
343
+ return alpha * latents + beta * noise
344
+
345
+ def _prepare_rope_pos_embed(self, latents: torch.Tensor, device: torch.device) -> torch.Tensor:
346
+ grid_size = latents.shape[-1] // self.transformer.patch_size
347
+ pos_embed = get_2d_rotary_pos_embed(
348
+ embed_dim=self.transformer.attention_head_dim,
349
+ crops_coords=((0, 0), (grid_size, grid_size)),
350
+ grid_size=(grid_size, grid_size),
351
+ device=device,
352
+ output_type="pt",
 
353
  )
354
+ return torch.stack(pos_embed, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
+ def _stage_guidance_scale(self, stage_idx: int, guidance_scale: float) -> float:
357
+ scale_dict = {0: 0, 1: 1 / 6, 2: 2 / 3, 3: 1}
358
+ return (guidance_scale - 1) * scale_dict[stage_idx] + 1
 
 
 
 
 
359
 
360
+ def _encode_class_condition(
361
+ self,
362
+ class_labels_tensor: torch.LongTensor,
363
+ guidance_scale: float,
364
+ ) -> torch.LongTensor:
365
+ null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
366
+ if guidance_scale > 0:
367
+ return torch.cat([null_labels, class_labels_tensor], dim=0)
368
+ return class_labels_tensor
369
+
370
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
371
+ image = (latents / 2 + 0.5).clamp(0, 1)
372
+ if output_type == "latent":
373
+ return latents
374
+ if output_type == "pt":
375
+ return image
376
+ if output_type in {"pil", "np"}:
377
+ return self.image_processor.postprocess(image, output_type=output_type)
378
+ raise ValueError(f"output_type must be one of: 'pil', 'np', 'pt', 'latent'. Got {output_type}.")
379
 
380
+ @torch.inference_mode()
381
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
382
  def __call__(
383
  self,
384
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
 
385
  height: Optional[int] = None,
386
  width: Optional[int] = None,
387
  num_inference_steps: Union[int, List[int]] = 10,
388
  guidance_scale: float = 4.0,
389
  shift: float = 1.0,
 
 
390
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
391
  output_type: str = "pil",
392
  return_dict: bool = True,
393
+ ) -> Union[ImagePipelineOutput, Tuple]:
394
+ r"""
395
+ Generate class-conditional images with PixelFlow.
 
 
 
 
 
396
 
397
+ Examples:
398
+ <!-- this section is replaced by replace_example_docstring -->
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
+ Args:
401
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
402
+ ImageNet class indices or human-readable English label strings.
403
+ height (`int`, *optional*):
404
+ Output image height in pixels. Defaults to the transformer's native resolution.
405
+ width (`int`, *optional*):
406
+ Output image width in pixels. Defaults to the transformer's native resolution.
407
+ num_inference_steps (`int` or `list[int]`, defaults to `10`):
408
+ Number of denoising steps per cascade stage.
409
+ guidance_scale (`float`, defaults to `4.0`):
410
+ Classifier-free guidance scale. Guidance is stage-weighted for PixelFlow cascades.
411
+ shift (`float`, defaults to `1.0`):
412
+ Noise shift applied by the scheduler when building stage timesteps.
413
+ generator (`torch.Generator`, *optional*):
414
+ RNG for reproducibility.
415
+ output_type (`str`, defaults to `"pil"`):
416
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
417
+ return_dict (`bool`, defaults to `True`):
418
+ Return [`ImagePipelineOutput`] if True.
419
+ """
420
+ default_size = int(getattr(self.transformer.config, "sample_size", DEFAULT_NATIVE_RESOLUTION))
421
+ height = int(height or default_size)
422
+ width = int(width or default_size)
423
+ self.check_inputs(height, width, num_inference_steps, output_type)
424
 
425
+ device = self._execution_device
426
+ do_classifier_free_guidance = guidance_scale > 0
427
+ stage_steps = self._normalize_stage_steps(num_inference_steps)
428
+ class_labels_tensor = self._normalize_class_labels(class_labels)
429
+ batch_size = class_labels_tensor.numel()
430
+ conditioning = self._encode_class_condition(class_labels_tensor, guidance_scale)
431
+
432
+ latents, height, width = self.prepare_latents(batch_size, height, width, device, generator)
433
+ size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
434
+
435
+ autocast_enabled = device.type == "cuda"
436
+ autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
437
+
438
+ with self.progress_bar(total=sum(stage_steps)) as progress_bar:
439
+ for stage_idx in range(self.scheduler.num_stages):
440
+ self.scheduler.set_timesteps(stage_steps[stage_idx], stage_idx, device=device, shift=shift)
441
+ timesteps = self.scheduler.Timesteps
442
+
443
+ if stage_idx > 0:
444
+ height, width = height * 2, width * 2
445
+ latents = self._upsample_latents_for_stage(latents, stage_idx, height, width, device)
446
+ size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
447
+
448
+ rope_pos = self._prepare_rope_pos_embed(latents, device)
449
+
450
+ for timestep in timesteps:
451
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
452
+ timestep_batch = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
453
+ with torch.autocast(device.type, enabled=autocast_enabled, dtype=autocast_dtype):
 
 
 
 
 
 
 
 
 
 
 
 
454
  noise_pred = self.transformer(
455
  latent_model_input,
 
 
456
  timestep=timestep_batch,
457
+ class_labels=conditioning,
458
  latent_size=size_tensor,
459
  pos_embed=rope_pos,
460
  ).sample
461
 
462
+ if do_classifier_free_guidance:
463
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
464
+ stage_scale = self._stage_guidance_scale(stage_idx, guidance_scale)
465
+ noise_pred = noise_pred_uncond + stage_scale * (noise_pred_text - noise_pred_uncond)
 
466
 
467
+ latents = self.scheduler.step(model_output=noise_pred, sample=latents).prev_sample
468
+ progress_bar.update()
 
 
 
 
 
 
 
 
469
 
470
+ image = self.decode_latents(latents, output_type=output_type)
471
  self.maybe_free_model_hooks()
472
 
473
  if not return_dict:
474
  return (image,)
475
+ return ImagePipelineOutput(images=image)
476
 
 
PixelFlow-256/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_class_name": "PixelFlowScheduler",
3
  "_diffusers_version": "0.36.0",
4
- "gamma": -0.3333333333333333,
5
  "num_stages": 4,
6
- "num_train_timesteps": 1000
7
  }
 
1
  {
2
  "_class_name": "PixelFlowScheduler",
3
  "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
  "num_stages": 4,
6
+ "gamma": -0.3333333333333333
7
  }
PixelFlow-256/scheduler/scheduling_pixelflow.py CHANGED
@@ -1,3 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  from dataclasses import dataclass
3
  from typing import Optional, Tuple, Union
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
  import math
16
  from dataclasses import dataclass
17
  from typing import Optional, Tuple, Union
PixelFlow-256/transformer/transformer_pixelflow.py CHANGED
@@ -1,14 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from dataclasses import dataclass
2
  from typing import Optional, Tuple, Union
3
 
4
  import torch
 
 
5
 
6
  from diffusers.configuration_utils import ConfigMixin, register_to_config
 
7
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
8
  from diffusers.models.modeling_utils import ModelMixin
9
  from diffusers.utils import BaseOutput
10
 
11
- from modeling_pixelflow import PixelFlowModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  @dataclass
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
  from dataclasses import dataclass
17
  from typing import Optional, Tuple, Union
18
 
19
  import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
 
23
  from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.embeddings import LabelEmbedding, TimestepEmbedding, Timesteps
25
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
26
  from diffusers.models.modeling_utils import ModelMixin
27
  from diffusers.utils import BaseOutput
28
 
29
+ try:
30
+ from flash_attn import flash_attn_varlen_func
31
+ except ImportError:
32
+ warnings.warn("`flash-attn` is not installed. Training mode may not work properly.", UserWarning)
33
+ flash_attn_varlen_func = None
34
+
35
+
36
+ def apply_rotary_emb(
37
+ x: torch.Tensor,
38
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
39
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
40
+ cos, sin = freqs_cis.unbind(-1)
41
+ cos = cos[None, None]
42
+ sin = sin[None, None]
43
+ cos, sin = cos.to(x.device), sin.to(x.device)
44
+
45
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
46
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
47
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
48
+
49
+ return out
50
+
51
+
52
+ class PatchEmbed(nn.Module):
53
+ def __init__(self, patch_size, in_channels, embed_dim, bias=True):
54
+ super().__init__()
55
+ self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
56
+
57
+ def forward_unfold(self, x):
58
+ out_unfold = x.matmul(self.proj.weight.view(self.proj.weight.size(0), -1).t())
59
+ if self.proj.bias is not None:
60
+ out_unfold += self.proj.bias.to(out_unfold.dtype)
61
+ return out_unfold
62
+
63
+ def forward(self, x):
64
+ if self.training:
65
+ return self.forward_unfold(x)
66
+ out = self.proj(x)
67
+ out = out.flatten(2).transpose(1, 2)
68
+ return out
69
+
70
+
71
+ class AdaLayerNorm(nn.Module):
72
+ def __init__(self, embedding_dim):
73
+ super().__init__()
74
+ self.embedding_dim = embedding_dim
75
+ self.silu = nn.SiLU()
76
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
77
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
78
+
79
+ def forward(self, x, timestep, seqlen_list=None):
80
+ input_dtype = x.dtype
81
+ emb = self.linear(self.silu(timestep))
82
+
83
+ if seqlen_list is not None:
84
+ emb = torch.cat([one_emb[None].expand(repeat_time, -1) for one_emb, repeat_time in zip(emb, seqlen_list)])
85
+ else:
86
+ emb = emb.unsqueeze(1)
87
+
88
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.float().chunk(6, dim=-1)
89
+ x = self.norm(x).float() * (1 + scale_msa) + shift_msa
90
+ return x.to(input_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp
91
+
92
+
93
+ class FeedForward(nn.Module):
94
+ def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True):
95
+ super().__init__()
96
+ inner_dim = int(dim * mult) if inner_dim is None else inner_dim
97
+ dim_out = dim_out if dim_out is not None else dim
98
+ self.fc1 = nn.Linear(dim, inner_dim, bias=bias)
99
+ self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias)
100
+
101
+ def forward(self, hidden_states):
102
+ hidden_states = self.fc1(hidden_states)
103
+ hidden_states = F.gelu(hidden_states, approximate="tanh")
104
+ hidden_states = self.fc2(hidden_states)
105
+ return hidden_states
106
+
107
+
108
+ class RMSNorm(nn.Module):
109
+ def __init__(self, dim: int, eps=1e-6):
110
+ super().__init__()
111
+ self.weight = nn.Parameter(torch.ones(dim))
112
+ self.eps = eps
113
+
114
+ def forward(self, x):
115
+ output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
116
+ return (self.weight * output).to(x.dtype)
117
+
118
+
119
+ class Attention(nn.Module):
120
+ def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False):
121
+ super().__init__()
122
+ self.q_dim = q_dim
123
+ self.kv_dim = kv_dim if kv_dim is not None else q_dim
124
+ self.inner_dim = head_dim * heads
125
+ self.dropout = dropout
126
+ self.head_dim = head_dim
127
+ self.num_heads = heads
128
+
129
+ self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias)
130
+ self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
131
+ self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
132
+ self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias)
133
+ self.q_norm = RMSNorm(self.inner_dim)
134
+ self.k_norm = RMSNorm(self.inner_dim)
135
+
136
+ def prepare_attention_mask(self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3):
137
+ head_size = self.num_heads
138
+ if attention_mask is None:
139
+ return attention_mask
140
+
141
+ current_length: int = attention_mask.shape[-1]
142
+ if current_length != target_length:
143
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
144
+
145
+ if out_dim == 3:
146
+ if attention_mask.shape[0] < batch_size * head_size:
147
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
148
+ elif out_dim == 4:
149
+ attention_mask = attention_mask.unsqueeze(1)
150
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
151
+
152
+ return attention_mask
153
+
154
+ def forward(
155
+ self,
156
+ inputs_q,
157
+ inputs_kv,
158
+ attention_mask=None,
159
+ cross_attention=False,
160
+ rope_pos_embed=None,
161
+ cu_seqlens_q=None,
162
+ cu_seqlens_k=None,
163
+ max_seqlen_q=None,
164
+ max_seqlen_k=None,
165
+ ):
166
+ inputs_kv = inputs_q if inputs_kv is None else inputs_kv
167
+
168
+ query_states = self.q_proj(inputs_q)
169
+ key_states = self.k_proj(inputs_kv)
170
+ value_states = self.v_proj(inputs_kv)
171
+
172
+ query_states = self.q_norm(query_states)
173
+ key_states = self.k_norm(key_states)
174
+
175
+ if max_seqlen_q is None:
176
+ assert not self.training, "PixelFlow needs sequence packing for training"
177
+
178
+ bsz, q_len, _ = inputs_q.shape
179
+ _, kv_len, _ = inputs_kv.shape
180
+
181
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
182
+ key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
183
+ value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
184
+
185
+ query_states = apply_rotary_emb(query_states, rope_pos_embed)
186
+ if not cross_attention:
187
+ key_states = apply_rotary_emb(key_states, rope_pos_embed)
188
+
189
+ if attention_mask is not None:
190
+ attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz)
191
+ attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1])
192
+
193
+ attn_output = F.scaled_dot_product_attention(
194
+ query_states,
195
+ key_states,
196
+ value_states,
197
+ attn_mask=attention_mask,
198
+ dropout_p=self.dropout if self.training else 0.0,
199
+ is_causal=False,
200
+ )
201
+
202
+ attn_output = attn_output.transpose(1, 2).contiguous()
203
+ attn_output = attn_output.view(bsz, q_len, self.inner_dim)
204
+ attn_output = self.o_proj(attn_output)
205
+ return attn_output
206
+
207
+ query_states = query_states.view(-1, self.num_heads, self.head_dim)
208
+ key_states = key_states.view(-1, self.num_heads, self.head_dim)
209
+ value_states = value_states.view(-1, self.num_heads, self.head_dim)
210
+
211
+ query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
212
+ if not cross_attention:
213
+ key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
214
+
215
+ attn_output = flash_attn_varlen_func(
216
+ query_states,
217
+ key_states,
218
+ value_states,
219
+ cu_seqlens_q=cu_seqlens_q,
220
+ cu_seqlens_k=cu_seqlens_k,
221
+ max_seqlen_q=max_seqlen_q,
222
+ max_seqlen_k=max_seqlen_k,
223
+ )
224
+
225
+ attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
226
+ attn_output = self.o_proj(attn_output)
227
+ return attn_output
228
+
229
+
230
+ class TransformerBlock(nn.Module):
231
+ def __init__(
232
+ self,
233
+ dim,
234
+ num_attention_heads,
235
+ attention_head_dim,
236
+ dropout=0.0,
237
+ cross_attention_dim=None,
238
+ attention_bias=False,
239
+ ):
240
+ super().__init__()
241
+ self.norm1 = AdaLayerNorm(dim)
242
+ self.attn1 = Attention(
243
+ q_dim=dim,
244
+ kv_dim=None,
245
+ heads=num_attention_heads,
246
+ head_dim=attention_head_dim,
247
+ dropout=dropout,
248
+ bias=attention_bias,
249
+ )
250
+
251
+ if cross_attention_dim is not None:
252
+ self.norm2 = RMSNorm(dim, eps=1e-6)
253
+ self.attn2 = Attention(
254
+ q_dim=dim,
255
+ kv_dim=cross_attention_dim,
256
+ heads=num_attention_heads,
257
+ head_dim=attention_head_dim,
258
+ dropout=dropout,
259
+ bias=attention_bias,
260
+ )
261
+ else:
262
+ self.attn2 = None
263
+
264
+ self.norm3 = RMSNorm(dim, eps=1e-6)
265
+ self.mlp = FeedForward(dim)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ encoder_hidden_states=None,
271
+ encoder_attention_mask=None,
272
+ timestep=None,
273
+ rope_pos_embed=None,
274
+ cu_seqlens_q=None,
275
+ cu_seqlens_k=None,
276
+ seqlen_list_q=None,
277
+ seqlen_list_k=None,
278
+ ):
279
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, seqlen_list_q)
280
+
281
+ attn_output = self.attn1(
282
+ inputs_q=norm_hidden_states,
283
+ inputs_kv=None,
284
+ attention_mask=None,
285
+ cross_attention=False,
286
+ rope_pos_embed=rope_pos_embed,
287
+ cu_seqlens_q=cu_seqlens_q,
288
+ cu_seqlens_k=cu_seqlens_q,
289
+ max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
290
+ max_seqlen_k=max(seqlen_list_q) if seqlen_list_q is not None else None,
291
+ )
292
+
293
+ attn_output = (gate_msa * attn_output.float()).to(attn_output.dtype)
294
+ hidden_states = attn_output + hidden_states
295
+
296
+ if self.attn2 is not None:
297
+ norm_hidden_states = self.norm2(hidden_states)
298
+ attn_output = self.attn2(
299
+ inputs_q=norm_hidden_states,
300
+ inputs_kv=encoder_hidden_states,
301
+ attention_mask=encoder_attention_mask,
302
+ cross_attention=True,
303
+ rope_pos_embed=rope_pos_embed,
304
+ cu_seqlens_q=cu_seqlens_q,
305
+ cu_seqlens_k=cu_seqlens_k,
306
+ max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
307
+ max_seqlen_k=max(seqlen_list_k) if seqlen_list_k is not None else None,
308
+ )
309
+ hidden_states = hidden_states + attn_output
310
+
311
+ norm_hidden_states = self.norm3(hidden_states)
312
+ norm_hidden_states = (norm_hidden_states.float() * (1 + scale_mlp) + shift_mlp).to(norm_hidden_states.dtype)
313
+ ff_output = self.mlp(norm_hidden_states)
314
+ ff_output = (gate_mlp * ff_output.float()).to(ff_output.dtype)
315
+ hidden_states = ff_output + hidden_states
316
+
317
+ return hidden_states
318
+
319
+
320
+ class PixelFlowModel(nn.Module):
321
+ def __init__(
322
+ self,
323
+ in_channels,
324
+ out_channels,
325
+ num_attention_heads,
326
+ attention_head_dim,
327
+ depth,
328
+ patch_size,
329
+ dropout=0.0,
330
+ cross_attention_dim=None,
331
+ attention_bias=True,
332
+ num_classes=0,
333
+ init_weights=True,
334
+ ):
335
+ super().__init__()
336
+ self.patch_size = patch_size
337
+ self.attention_head_dim = attention_head_dim
338
+ self.num_classes = num_classes
339
+ self.out_channels = out_channels
340
+
341
+ embed_dim = num_attention_heads * attention_head_dim
342
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
343
+
344
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
345
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
346
+ self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
347
+ if self.num_classes > 0:
348
+ self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1)
349
+
350
+ self.transformer_blocks = nn.ModuleList(
351
+ [
352
+ TransformerBlock(
353
+ embed_dim,
354
+ num_attention_heads,
355
+ attention_head_dim,
356
+ dropout,
357
+ cross_attention_dim,
358
+ attention_bias,
359
+ )
360
+ for _ in range(depth)
361
+ ]
362
+ )
363
+
364
+ self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
365
+ self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim)
366
+ self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels)
367
+
368
+ if init_weights:
369
+ self.initialize_from_scratch()
370
+
371
+ def initialize_from_scratch(self):
372
+ def _basic_init(module):
373
+ if isinstance(module, nn.Linear):
374
+ torch.nn.init.xavier_uniform_(module.weight)
375
+ if module.bias is not None:
376
+ nn.init.constant_(module.bias, 0)
377
+
378
+ self.apply(_basic_init)
379
+
380
+ w = self.patch_embed.proj.weight.data
381
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
382
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
383
+
384
+ nn.init.normal_(self.timestep_embedder.linear_1.weight, std=0.02)
385
+ nn.init.normal_(self.timestep_embedder.linear_2.weight, std=0.02)
386
+ nn.init.normal_(self.latent_size_embedder.linear_1.weight, std=0.02)
387
+ nn.init.normal_(self.latent_size_embedder.linear_2.weight, std=0.02)
388
+
389
+ if self.num_classes > 0:
390
+ nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02)
391
+
392
+ for block in self.transformer_blocks:
393
+ nn.init.constant_(block.norm1.linear.weight, 0)
394
+ nn.init.constant_(block.norm1.linear.bias, 0)
395
+
396
+ nn.init.constant_(self.proj_out_1.weight, 0)
397
+ nn.init.constant_(self.proj_out_1.bias, 0)
398
+ nn.init.constant_(self.proj_out_2.weight, 0)
399
+ nn.init.constant_(self.proj_out_2.bias, 0)
400
+
401
+ def forward(
402
+ self,
403
+ hidden_states,
404
+ encoder_hidden_states=None,
405
+ class_labels=None,
406
+ timestep=None,
407
+ latent_size=None,
408
+ encoder_attention_mask=None,
409
+ pos_embed=None,
410
+ cu_seqlens_q=None,
411
+ cu_seqlens_k=None,
412
+ seqlen_list_q=None,
413
+ seqlen_list_k=None,
414
+ ):
415
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
416
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
417
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
418
+
419
+ orig_height, orig_width = hidden_states.shape[-2], hidden_states.shape[-1]
420
+ hidden_states = hidden_states.to(torch.float32)
421
+ hidden_states = self.patch_embed(hidden_states)
422
+
423
+ timesteps_proj = self.time_proj(timestep)
424
+ conditioning = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
425
+
426
+ if self.num_classes > 0:
427
+ class_embed = self.class_embedder(class_labels)
428
+ conditioning += class_embed
429
+
430
+ latent_size_proj = self.time_proj(latent_size)
431
+ latent_size_embed = self.latent_size_embedder(latent_size_proj.to(dtype=hidden_states.dtype))
432
+ conditioning += latent_size_embed
433
+
434
+ for block in self.transformer_blocks:
435
+ hidden_states = block(
436
+ hidden_states,
437
+ encoder_hidden_states=encoder_hidden_states,
438
+ encoder_attention_mask=encoder_attention_mask,
439
+ timestep=conditioning,
440
+ rope_pos_embed=pos_embed,
441
+ cu_seqlens_q=cu_seqlens_q,
442
+ cu_seqlens_k=cu_seqlens_k,
443
+ seqlen_list_q=seqlen_list_q,
444
+ seqlen_list_k=seqlen_list_k,
445
+ )
446
+
447
+ shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1)
448
+ if seqlen_list_q is None:
449
+ shift = shift.unsqueeze(1)
450
+ scale = scale.unsqueeze(1)
451
+ else:
452
+ shift = torch.cat([shift_i[None].expand(ri, -1) for shift_i, ri in zip(shift, seqlen_list_q)])
453
+ scale = torch.cat([scale_i[None].expand(ri, -1) for scale_i, ri in zip(scale, seqlen_list_q)])
454
+
455
+ hidden_states = (self.norm_out(hidden_states).float() * (1 + scale) + shift).to(hidden_states.dtype)
456
+ hidden_states = self.proj_out_2(hidden_states)
457
+ if self.training:
458
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], self.patch_size, self.patch_size, self.out_channels)
459
+ hidden_states = hidden_states.permute(0, 3, 1, 2).flatten(1)
460
+ return hidden_states
461
+
462
+ height, width = orig_height // self.patch_size, orig_width // self.patch_size
463
+ hidden_states = hidden_states.reshape(shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels))
464
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
465
+ output = hidden_states.reshape(shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size))
466
+
467
+ return output
468
 
469
 
470
  @dataclass
PixelFlow-T2I/__pycache__/pipeline.cpython-312.pyc CHANGED
Binary files a/PixelFlow-T2I/__pycache__/pipeline.cpython-312.pyc and b/PixelFlow-T2I/__pycache__/pipeline.cpython-312.pyc differ
 
PixelFlow-T2I/model_index.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_class_name": "PixelFlowPipeline",
3
  "_diffusers_version": "0.36.0",
4
  "scheduler": [
5
  "scheduling_pixelflow",
 
1
  {
2
+ "_class_name": "PixelFlowT2IPipeline",
3
  "_diffusers_version": "0.36.0",
4
  "scheduler": [
5
  "scheduling_pixelflow",
PixelFlow-T2I/pipeline.py CHANGED
@@ -1,16 +1,23 @@
1
- """Hub custom pipeline: PixelFlowPipeline.
2
-
3
- Load with native Hugging Face diffusers and `trust_remote_code=True`.
4
- """
5
-
6
- from __future__ import annotations
 
 
 
 
 
 
 
7
 
8
  import importlib
 
9
  import math
10
  import sys
11
- from dataclasses import dataclass
12
  from pathlib import Path
13
- from typing import List, Optional, Tuple, Union
14
 
15
  import numpy as np
16
  import torch
@@ -19,25 +26,91 @@ from einops import rearrange
19
 
20
  from diffusers.image_processor import VaeImageProcessor
21
  from diffusers.models.embeddings import get_2d_rotary_pos_embed
22
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23
- from diffusers.utils import BaseOutput
 
24
  from diffusers.utils.torch_utils import randn_tensor
25
 
26
 
27
- @dataclass
28
- class PixelFlowPipelineOutput(BaseOutput):
29
- images: Union[torch.Tensor, List, np.ndarray]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
31
 
32
- class PixelFlowPipeline(DiffusionPipeline):
33
- """Pipeline for PixelFlow pixel-space flow generation (class-conditional or text-to-image)."""
 
 
 
 
 
 
 
 
34
 
35
  model_cpu_offload_seq = "text_encoder->transformer"
36
  _optional_components = ["text_encoder", "tokenizer"]
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  @classmethod
39
  def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
40
  """Load a self-contained variant folder locally or from the Hub."""
 
 
 
 
 
41
  repo_root = Path(__file__).resolve().parent
42
 
43
  if pretrained_model_name_or_path in (None, "", "."):
@@ -63,129 +136,187 @@ class PixelFlowPipeline(DiffusionPipeline):
63
  variant = variant / subfolder
64
 
65
  model_kwargs = dict(kwargs)
66
- inserted: List[str] = []
67
-
68
- def _load_component(folder: str, module_name: str, class_name: str):
69
- comp_dir = variant / folder
70
- module_path = comp_dir / f"{module_name}.py"
71
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
72
- if not module_path.exists() or not has_weights:
73
- return None
74
-
75
- comp_path = str(comp_dir)
76
- if comp_path not in sys.path:
77
- sys.path.insert(0, comp_path)
78
- inserted.append(comp_path)
79
-
80
- module = importlib.import_module(module_name)
81
- component_cls = getattr(module, class_name)
82
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
83
-
84
- def _load_text_components():
85
- text_encoder = None
86
- tokenizer = None
87
- te_dir = variant / "text_encoder"
88
- tok_dir = variant / "tokenizer"
89
- if te_dir.exists() and (te_dir / "config.json").exists():
90
- from transformers import T5EncoderModel, T5Tokenizer
91
-
92
- text_encoder = T5EncoderModel.from_pretrained(str(te_dir), **model_kwargs)
93
- tokenizer = T5Tokenizer.from_pretrained(str(tok_dir))
94
- return text_encoder, tokenizer
95
 
96
- def _load_text_encoder_name() -> str:
97
- metadata_path = variant / "conversion_metadata.json"
98
- if metadata_path.exists():
99
- import json
100
-
101
- metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
102
- if metadata.get("text_encoder"):
103
- return metadata["text_encoder"]
104
- return "google/flan-t5-xl"
105
 
106
  try:
107
- transformer = _load_component("transformer", "transformer_pixelflow", "PixelFlowTransformer2DModel")
108
- scheduler = _load_component("scheduler", "scheduling_pixelflow", "PixelFlowScheduler")
109
- text_encoder, tokenizer = _load_text_components()
110
-
111
- if scheduler is None:
112
- sched_dir = variant / "scheduler"
113
- if (sched_dir / "scheduling_pixelflow.py").exists():
114
- sched_path = str(sched_dir)
115
- if sched_path not in sys.path:
116
- sys.path.insert(0, sched_path)
117
- inserted.append(sched_path)
118
- scheduler = importlib.import_module("scheduling_pixelflow").PixelFlowScheduler()
119
-
120
- if transformer is None:
121
  raise ValueError(f"No loadable transformer found under {variant}")
122
 
123
- if (
124
- text_encoder is None
125
- and tokenizer is None
126
- and transformer.config.num_classes == 0
127
- and transformer.config.cross_attention_dim is not None
128
- ):
129
- from transformers import T5EncoderModel, T5Tokenizer
 
 
 
 
 
 
 
130
 
131
- text_encoder_name = _load_text_encoder_name()
 
 
 
 
 
 
 
 
 
132
  text_encoder = T5EncoderModel.from_pretrained(text_encoder_name, **model_kwargs)
133
  tokenizer = T5Tokenizer.from_pretrained(text_encoder_name)
134
 
135
- return cls(
136
- transformer=transformer,
137
- scheduler=scheduler,
138
- text_encoder=text_encoder,
139
- tokenizer=tokenizer,
140
- )
141
  finally:
142
  for comp_path in inserted:
143
  if comp_path in sys.path:
144
  sys.path.remove(comp_path)
145
 
146
- def __init__(self, transformer, scheduler, text_encoder=None, tokenizer=None, max_token_length: int = 512):
147
- super().__init__()
148
- self.register_modules(
149
- transformer=transformer,
150
- scheduler=scheduler,
151
- text_encoder=text_encoder,
152
- tokenizer=tokenizer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  )
154
- self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
155
- self.class_cond = transformer.config.num_classes > 0
156
- self.max_token_length = max_token_length
157
 
158
- def sample_block_noise(self, bs, ch, height, width, eps=1e-6):
 
 
 
 
 
 
 
159
  gamma = self.scheduler.gamma
160
  dist = torch.distributions.multivariate_normal.MultivariateNormal(
161
  torch.zeros(4),
162
  torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
163
  )
164
- block_number = bs * ch * (height // 2) * (width // 2)
165
  noise = torch.stack([dist.sample() for _ in range(block_number)])
166
- noise = rearrange(
167
  noise,
168
  "(b c h w) (p q) -> b c (h p) (w q)",
169
- b=bs,
170
- c=ch,
171
  h=height // 2,
172
  w=width // 2,
173
  p=2,
174
  q=2,
175
  )
176
- return noise
177
 
178
- def _stage_guidance_scale(self, stage_idx: int) -> float:
179
- if not self.class_cond:
180
- return self._guidance_scale_value
181
- scale_dict = {0: 0, 1: 1 / 6, 2: 2 / 3, 3: 1}
182
- return (self._guidance_scale_value - 1) * scale_dict[stage_idx] + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- @property
185
- def do_classifier_free_guidance(self) -> bool:
186
- return self._guidance_scale_value > 0
 
 
 
 
 
 
187
 
188
- @torch.no_grad()
189
  def encode_prompt(
190
  self,
191
  prompt: Union[str, List[str]],
@@ -195,6 +326,23 @@ class PixelFlowPipeline(DiffusionPipeline):
195
  negative_prompt: Union[str, List[str]] = "",
196
  max_length: Optional[int] = None,
197
  ) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  if self.text_encoder is None or self.tokenizer is None:
199
  raise ValueError("Text-to-image generation requires `text_encoder` and `tokenizer`.")
200
 
@@ -257,18 +405,20 @@ class PixelFlowPipeline(DiffusionPipeline):
257
  negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
258
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
259
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
260
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
 
 
261
 
262
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
263
  prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
264
 
265
  return prompt_embeds, prompt_attention_mask
266
 
267
- @torch.no_grad()
 
268
  def __call__(
269
  self,
270
- prompt: Optional[Union[str, List[str]]] = None,
271
- class_labels: Optional[Union[int, List[int], torch.Tensor]] = None,
272
  height: Optional[int] = None,
273
  width: Optional[int] = None,
274
  num_inference_steps: Union[int, List[int]] = 10,
@@ -279,98 +429,91 @@ class PixelFlowPipeline(DiffusionPipeline):
279
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
280
  output_type: str = "pil",
281
  return_dict: bool = True,
282
- ) -> Union[PixelFlowPipelineOutput, Tuple]:
283
- if height is None:
284
- height = int(self.transformer.config.sample_size)
285
- if width is None:
286
- width = int(self.transformer.config.sample_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
- device = self._execution_device
289
- self._guidance_scale_value = guidance_scale
290
 
291
- if isinstance(num_inference_steps, int):
292
- num_inference_steps = [num_inference_steps] * self.scheduler.num_stages
293
-
294
- prompt_attention_mask = None
295
- if self.class_cond:
296
- if class_labels is None:
297
- raise ValueError("`class_labels` are required for class-conditional PixelFlow checkpoints.")
298
- if isinstance(class_labels, int):
299
- class_labels = [class_labels]
300
- if not torch.is_tensor(class_labels):
301
- class_labels = torch.tensor(class_labels, device=device, dtype=torch.long)
302
- else:
303
- class_labels = class_labels.to(device=device, dtype=torch.long)
304
 
305
- batch_size = class_labels.shape[0]
306
- prompt_embeds = class_labels
307
- negative_prompt_embeds = torch.full_like(prompt_embeds, self.transformer.config.num_classes)
308
- if self.do_classifier_free_guidance:
309
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
310
- else:
311
- if prompt is None:
312
- raise ValueError("`prompt` is required for text-to-image PixelFlow checkpoints.")
313
- if isinstance(prompt, str):
314
- prompt = [prompt]
315
- batch_size = len(prompt)
316
- prompt_embeds, prompt_attention_mask = self.encode_prompt(
317
- prompt,
318
- device,
319
- num_images_per_prompt=num_images_per_prompt,
320
- do_classifier_free_guidance=self.do_classifier_free_guidance and guidance_scale > 1.0,
321
- negative_prompt=negative_prompt,
322
- )
323
 
324
- init_factor = 2 ** (self.scheduler.num_stages - 1)
325
- height, width = height // init_factor, width // init_factor
326
- latents = randn_tensor(
327
- (batch_size * num_images_per_prompt, 3, height, width),
328
- generator=generator,
329
- device=device,
330
- dtype=torch.float32,
331
- )
332
 
333
- for stage_idx in range(self.scheduler.num_stages):
334
- self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
335
- timesteps = self.scheduler.Timesteps
336
-
337
- if stage_idx > 0:
338
- height, width = height * 2, width * 2
339
- latents = F.interpolate(latents, size=(height, width), mode="nearest")
340
- original_start_t = self.scheduler.original_start_t[stage_idx]
341
- gamma = self.scheduler.gamma
342
- alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
343
- beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
344
-
345
- noise = self.sample_block_noise(*latents.shape)
346
- noise = noise.to(device=device, dtype=latents.dtype)
347
- latents = alpha * latents + beta * noise
348
-
349
- size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
350
- pos_embed = get_2d_rotary_pos_embed(
351
- embed_dim=self.transformer.attention_head_dim,
352
- crops_coords=((0, 0), (latents.shape[-1] // self.transformer.patch_size, latents.shape[-1] // self.transformer.patch_size)),
353
- grid_size=(latents.shape[-1] // self.transformer.patch_size, latents.shape[-1] // self.transformer.patch_size),
354
- device=device,
355
- output_type="pt",
356
- )
357
- rope_pos = torch.stack(pos_embed, -1)
358
-
359
- autocast_enabled = device.type == "cuda"
360
- autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
361
- for timestep in timesteps:
362
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
363
- timestep_batch = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
364
- with torch.autocast(device.type, enabled=autocast_enabled, dtype=autocast_dtype):
365
- if self.class_cond:
366
- noise_pred = self.transformer(
367
- latent_model_input,
368
- timestep=timestep_batch,
369
- class_labels=prompt_embeds,
370
- latent_size=size_tensor,
371
- pos_embed=rope_pos,
372
- ).sample
373
- else:
374
  noise_pred = self.transformer(
375
  latent_model_input,
376
  encoder_hidden_states=prompt_embeds,
@@ -380,26 +523,17 @@ class PixelFlowPipeline(DiffusionPipeline):
380
  pos_embed=rope_pos,
381
  ).sample
382
 
383
- if self.do_classifier_free_guidance:
384
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
385
- noise_pred = noise_pred_uncond + self._stage_guidance_scale(stage_idx) * (
386
- noise_pred_text - noise_pred_uncond
387
- )
388
 
389
- latents = self.scheduler.step(model_output=noise_pred, sample=latents).prev_sample
390
-
391
- image = (latents / 2 + 0.5).clamp(0, 1)
392
-
393
- if output_type == "pt":
394
- pass
395
- elif output_type in ("pil", "np"):
396
- image = self.image_processor.postprocess(image, output_type=output_type)
397
- else:
398
- raise ValueError(f"Unsupported output_type: {output_type}")
399
 
 
400
  self.maybe_free_model_hooks()
401
 
402
  if not return_dict:
403
  return (image,)
 
404
 
405
- return PixelFlowPipelineOutput(images=image)
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
 
15
  import importlib
16
+ import json
17
  import math
18
  import sys
 
19
  from pathlib import Path
20
+ from typing import Any, List, Optional, Tuple, Union
21
 
22
  import numpy as np
23
  import torch
 
26
 
27
  from diffusers.image_processor import VaeImageProcessor
28
  from diffusers.models.embeddings import get_2d_rotary_pos_embed
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
30
+ from diffusers.schedulers import KarrasDiffusionSchedulers
31
+ from diffusers.utils import replace_example_docstring
32
  from diffusers.utils.torch_utils import randn_tensor
33
 
34
 
35
+ DEFAULT_NATIVE_RESOLUTION = 1024
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```py
40
+ >>> from pathlib import Path
41
+ >>> import torch
42
+ >>> from diffusers import DiffusionPipeline
43
+
44
+ >>> model_dir = Path("./PixelFlow-T2I").resolve()
45
+ >>> pipe = DiffusionPipeline.from_pretrained(
46
+ ... str(model_dir),
47
+ ... local_files_only=True,
48
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
49
+ ... trust_remote_code=True,
50
+ ... torch_dtype=torch.bfloat16,
51
+ ... )
52
+ >>> pipe = pipe.to("cuda")
53
+
54
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
55
+ >>> image = pipe(
56
+ ... prompt="A golden retriever playing in a sunny garden",
57
+ ... height=1024,
58
+ ... width=1024,
59
+ ... num_inference_steps=[10, 10, 10, 10],
60
+ ... guidance_scale=4.0,
61
+ ... generator=generator,
62
+ ... ).images[0]
63
+ >>> image.save("demo.png")
64
+ ```
65
+ """
66
+
67
+
68
 
69
+ class PixelFlowT2IPipeline(DiffusionPipeline):
70
+ r"""
71
+ Pipeline for text-to-image PixelFlow pixel-space cascade generation.
72
 
73
+ Parameters:
74
+ transformer ([`PixelFlowTransformer2DModel`]):
75
+ Text-conditioned PixelFlow transformer operating in pixel space.
76
+ scheduler ([`PixelFlowScheduler`] or [`KarrasDiffusionSchedulers`]):
77
+ Multi-stage flow scheduler used by PixelFlow cascade denoising.
78
+ text_encoder ([`T5EncoderModel`], *optional*):
79
+ Text encoder used to embed prompts.
80
+ tokenizer ([`T5Tokenizer`], *optional*):
81
+ Tokenizer paired with the text encoder.
82
+ """
83
 
84
  model_cpu_offload_seq = "text_encoder->transformer"
85
  _optional_components = ["text_encoder", "tokenizer"]
86
 
87
+ def __init__(
88
+ self,
89
+ transformer: Any,
90
+ scheduler: Any,
91
+ text_encoder=None,
92
+ tokenizer=None,
93
+ max_token_length: int = 512,
94
+ ):
95
+ super().__init__()
96
+ self.register_modules(
97
+ transformer=transformer,
98
+ scheduler=scheduler,
99
+ text_encoder=text_encoder,
100
+ tokenizer=tokenizer,
101
+ )
102
+ self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
103
+ self.max_token_length = max_token_length
104
+ self.set_progress_bar_config(disable=False)
105
+
106
  @classmethod
107
  def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
108
  """Load a self-contained variant folder locally or from the Hub."""
109
+ import importlib
110
+ import sys
111
+
112
+ from transformers import T5EncoderModel, T5Tokenizer
113
+
114
  repo_root = Path(__file__).resolve().parent
115
 
116
  if pretrained_model_name_or_path in (None, "", "."):
 
136
  variant = variant / subfolder
137
 
138
  model_kwargs = dict(kwargs)
139
+ model_kwargs.pop("trust_remote_code", None)
140
+ scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
141
+ inserted = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ def _ensure_path(path: str) -> None:
144
+ if path not in sys.path:
145
+ sys.path.insert(0, path)
146
+ inserted.append(path)
 
 
 
 
 
147
 
148
  try:
149
+ transformer_dir = variant / "transformer"
150
+ if not (transformer_dir / "transformer_pixelflow.py").exists() or not (transformer_dir / "config.json").exists():
 
 
 
 
 
 
 
 
 
 
 
 
151
  raise ValueError(f"No loadable transformer found under {variant}")
152
 
153
+ _ensure_path(str(transformer_dir))
154
+ transformer_cls = getattr(importlib.import_module("transformer_pixelflow"), "PixelFlowTransformer2DModel")
155
+ transformer = transformer_cls.from_pretrained(str(transformer_dir), **model_kwargs)
156
+
157
+ scheduler_dir = variant / "scheduler"
158
+ if not (scheduler_dir / "scheduler_config.json").exists():
159
+ raise FileNotFoundError(f"Expected scheduler config in {scheduler_dir}")
160
+
161
+ _ensure_path(str(scheduler_dir))
162
+ scheduler_cls = getattr(importlib.import_module("scheduling_pixelflow"), "PixelFlowScheduler")
163
+ try:
164
+ scheduler = scheduler_cls.from_pretrained(str(scheduler_dir), **scheduler_kwargs)
165
+ except Exception:
166
+ scheduler = scheduler_cls(**scheduler_kwargs)
167
 
168
+ text_encoder = None
169
+ tokenizer = None
170
+ text_encoder_dir = variant / "text_encoder"
171
+ tokenizer_dir = variant / "tokenizer"
172
+ if text_encoder_dir.exists() and (text_encoder_dir / "config.json").exists():
173
+ text_encoder = T5EncoderModel.from_pretrained(str(text_encoder_dir), **model_kwargs)
174
+ tokenizer = T5Tokenizer.from_pretrained(str(tokenizer_dir if tokenizer_dir.exists() else text_encoder_dir))
175
+
176
+ if text_encoder is None or tokenizer is None:
177
+ text_encoder_name = cls._read_text_encoder_name(variant)
178
  text_encoder = T5EncoderModel.from_pretrained(text_encoder_name, **model_kwargs)
179
  tokenizer = T5Tokenizer.from_pretrained(text_encoder_name)
180
 
181
+ pipe = cls(transformer=transformer, scheduler=scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
182
+ if hasattr(pipe, "register_to_config"):
183
+ pipe.register_to_config(_name_or_path=str(variant))
184
+ return pipe
 
 
185
  finally:
186
  for comp_path in inserted:
187
  if comp_path in sys.path:
188
  sys.path.remove(comp_path)
189
 
190
+ @staticmethod
191
+ def _read_text_encoder_name(variant_path: Path) -> str:
192
+ metadata_path = variant_path / "conversion_metadata.json"
193
+ if metadata_path.exists():
194
+ metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
195
+ if metadata.get("text_encoder"):
196
+ return metadata["text_encoder"]
197
+ return "google/flan-t5-xl"
198
+
199
+ def check_inputs(
200
+ self,
201
+ prompt: Union[str, List[str]],
202
+ height: int,
203
+ width: int,
204
+ num_inference_steps: Union[int, List[int]],
205
+ output_type: str,
206
+ negative_prompt: Optional[Union[str, List[str]]],
207
+ ) -> None:
208
+ if not isinstance(prompt, str) and not (isinstance(prompt, list) and all(isinstance(p, str) for p in prompt)):
209
+ raise TypeError("`prompt` must be a string or list of strings.")
210
+
211
+ if negative_prompt is not None and not isinstance(negative_prompt, str):
212
+ if not (isinstance(negative_prompt, list) and all(isinstance(p, str) for p in negative_prompt)):
213
+ raise TypeError("`negative_prompt` must be a string or list of strings.")
214
+
215
+ if output_type not in {"pil", "np", "pt", "latent"}:
216
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
217
+
218
+ stage_steps = self._normalize_stage_steps(num_inference_steps)
219
+ if any(steps < 1 for steps in stage_steps):
220
+ raise ValueError("Each stage in num_inference_steps must be >= 1.")
221
+
222
+ if height <= 0 or width <= 0:
223
+ raise ValueError("height and width must be positive integers.")
224
+
225
+ def _normalize_stage_steps(self, num_inference_steps: Union[int, List[int]]) -> List[int]:
226
+ if isinstance(num_inference_steps, int):
227
+ return [num_inference_steps] * self.scheduler.num_stages
228
+ if len(num_inference_steps) != self.scheduler.num_stages:
229
+ raise ValueError(
230
+ f"num_inference_steps must have length {self.scheduler.num_stages} "
231
+ f"(one value per stage), got {len(num_inference_steps)}."
232
+ )
233
+ return list(num_inference_steps)
234
+
235
+ def prepare_latents(
236
+ self,
237
+ batch_size: int,
238
+ height: int,
239
+ width: int,
240
+ device: torch.device,
241
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
242
+ ) -> Tuple[torch.Tensor, int, int]:
243
+ init_factor = 2 ** (self.scheduler.num_stages - 1)
244
+ coarse_height = height // init_factor
245
+ coarse_width = width // init_factor
246
+ latents = randn_tensor(
247
+ (batch_size, 3, coarse_height, coarse_width),
248
+ generator=generator,
249
+ device=device,
250
+ dtype=torch.float32,
251
  )
252
+ return latents, coarse_height, coarse_width
 
 
253
 
254
+ def _sample_block_noise(
255
+ self,
256
+ batch_size: int,
257
+ channels: int,
258
+ height: int,
259
+ width: int,
260
+ eps: float = 1e-6,
261
+ ) -> torch.Tensor:
262
  gamma = self.scheduler.gamma
263
  dist = torch.distributions.multivariate_normal.MultivariateNormal(
264
  torch.zeros(4),
265
  torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
266
  )
267
+ block_number = batch_size * channels * (height // 2) * (width // 2)
268
  noise = torch.stack([dist.sample() for _ in range(block_number)])
269
+ return rearrange(
270
  noise,
271
  "(b c h w) (p q) -> b c (h p) (w q)",
272
+ b=batch_size,
273
+ c=channels,
274
  h=height // 2,
275
  w=width // 2,
276
  p=2,
277
  q=2,
278
  )
 
279
 
280
+ def _upsample_latents_for_stage(
281
+ self,
282
+ latents: torch.Tensor,
283
+ stage_idx: int,
284
+ height: int,
285
+ width: int,
286
+ device: torch.device,
287
+ ) -> torch.Tensor:
288
+ latents = F.interpolate(latents, size=(height, width), mode="nearest")
289
+ original_start_t = self.scheduler.original_start_t[stage_idx]
290
+ gamma = self.scheduler.gamma
291
+ alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
292
+ beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
293
+
294
+ noise = self._sample_block_noise(*latents.shape)
295
+ noise = noise.to(device=device, dtype=latents.dtype)
296
+ return alpha * latents + beta * noise
297
+
298
+ def _prepare_rope_pos_embed(self, latents: torch.Tensor, device: torch.device) -> torch.Tensor:
299
+ grid_size = latents.shape[-1] // self.transformer.patch_size
300
+ pos_embed = get_2d_rotary_pos_embed(
301
+ embed_dim=self.transformer.attention_head_dim,
302
+ crops_coords=((0, 0), (grid_size, grid_size)),
303
+ grid_size=(grid_size, grid_size),
304
+ device=device,
305
+ output_type="pt",
306
+ )
307
+ return torch.stack(pos_embed, -1)
308
 
309
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
310
+ image = (latents / 2 + 0.5).clamp(0, 1)
311
+ if output_type == "latent":
312
+ return latents
313
+ if output_type == "pt":
314
+ return image
315
+ if output_type in {"pil", "np"}:
316
+ return self.image_processor.postprocess(image, output_type=output_type)
317
+ raise ValueError(f"output_type must be one of: 'pil', 'np', 'pt', 'latent'. Got {output_type}.")
318
 
319
+ @torch.inference_mode()
320
  def encode_prompt(
321
  self,
322
  prompt: Union[str, List[str]],
 
326
  negative_prompt: Union[str, List[str]] = "",
327
  max_length: Optional[int] = None,
328
  ) -> Tuple[torch.Tensor, torch.Tensor]:
329
+ r"""
330
+ Encode text prompts into hidden states for the PixelFlow transformer.
331
+
332
+ Args:
333
+ prompt (`str` or `list[str]`):
334
+ Prompt(s) to encode.
335
+ device (`torch.device`):
336
+ Target device for encoded tensors.
337
+ num_images_per_prompt (`int`, defaults to `1`):
338
+ Number of images to generate per prompt.
339
+ do_classifier_free_guidance (`bool`, defaults to `True`):
340
+ Whether to concatenate unconditional prompt embeddings for CFG.
341
+ negative_prompt (`str` or `list[str]`, defaults to `""`):
342
+ Negative prompt(s) used for classifier-free guidance.
343
+ max_length (`int`, *optional*):
344
+ Maximum token length. Defaults to `self.max_token_length`.
345
+ """
346
  if self.text_encoder is None or self.tokenizer is None:
347
  raise ValueError("Text-to-image generation requires `text_encoder` and `tokenizer`.")
348
 
 
405
  negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
406
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
407
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
408
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(
409
+ num_images_per_prompt, 1
410
+ )
411
 
412
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
413
  prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
414
 
415
  return prompt_embeds, prompt_attention_mask
416
 
417
+ @torch.inference_mode()
418
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
419
  def __call__(
420
  self,
421
+ prompt: Union[str, List[str]],
 
422
  height: Optional[int] = None,
423
  width: Optional[int] = None,
424
  num_inference_steps: Union[int, List[int]] = 10,
 
429
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
430
  output_type: str = "pil",
431
  return_dict: bool = True,
432
+ ) -> Union[ImagePipelineOutput, Tuple]:
433
+ r"""
434
+ Generate text-to-image samples with PixelFlow.
435
+
436
+ Examples:
437
+ <!-- this section is replaced by replace_example_docstring -->
438
+
439
+ Args:
440
+ prompt (`str` or `list[str]`):
441
+ Text prompt(s) describing the desired image.
442
+ height (`int`, *optional*):
443
+ Output image height in pixels. Defaults to the transformer's native resolution.
444
+ width (`int`, *optional*):
445
+ Output image width in pixels. Defaults to the transformer's native resolution.
446
+ num_inference_steps (`int` or `list[int]`, defaults to `10`):
447
+ Number of denoising steps per cascade stage.
448
+ guidance_scale (`float`, defaults to `4.0`):
449
+ Classifier-free guidance scale.
450
+ shift (`float`, defaults to `1.0`):
451
+ Noise shift applied by the scheduler when building stage timesteps.
452
+ negative_prompt (`str` or `list[str]`, defaults to `""`):
453
+ Negative prompt(s) for classifier-free guidance.
454
+ num_images_per_prompt (`int`, defaults to `1`):
455
+ Number of images to generate for each prompt.
456
+ generator (`torch.Generator`, *optional*):
457
+ RNG for reproducibility.
458
+ output_type (`str`, defaults to `"pil"`):
459
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
460
+ return_dict (`bool`, defaults to `True`):
461
+ Return [`ImagePipelineOutput`] if True.
462
+ """
463
+ if isinstance(prompt, str):
464
+ prompt_list = [prompt]
465
+ else:
466
+ prompt_list = prompt
467
+
468
+ default_size = int(getattr(self.transformer.config, "sample_size", DEFAULT_NATIVE_RESOLUTION))
469
+ height = int(height or default_size)
470
+ width = int(width or default_size)
471
+ self.check_inputs(prompt_list, height, width, num_inference_steps, output_type, negative_prompt)
472
+
473
+ device = self.transformer.device
474
+ text_encoder_device = self.text_encoder.device if self.text_encoder is not None else device
475
+ do_classifier_free_guidance = guidance_scale > 1.0
476
+ stage_steps = self._normalize_stage_steps(num_inference_steps)
477
+ batch_size = len(prompt_list)
478
+
479
+ prompt_embeds, prompt_attention_mask = self.encode_prompt(
480
+ prompt_list,
481
+ text_encoder_device,
482
+ num_images_per_prompt=num_images_per_prompt,
483
+ do_classifier_free_guidance=do_classifier_free_guidance,
484
+ negative_prompt=negative_prompt,
485
+ )
486
+ prompt_embeds = prompt_embeds.to(device)
487
+ prompt_attention_mask = prompt_attention_mask.to(device)
488
+
489
+ latents, height, width = self.prepare_latents(
490
+ batch_size * num_images_per_prompt,
491
+ height,
492
+ width,
493
+ device,
494
+ generator,
495
+ )
496
+ size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
497
 
498
+ autocast_enabled = device.type == "cuda"
499
+ autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
500
 
501
+ with self.progress_bar(total=sum(stage_steps)) as progress_bar:
502
+ for stage_idx in range(self.scheduler.num_stages):
503
+ self.scheduler.set_timesteps(stage_steps[stage_idx], stage_idx, device=device, shift=shift)
504
+ timesteps = self.scheduler.Timesteps
 
 
 
 
 
 
 
 
 
505
 
506
+ if stage_idx > 0:
507
+ height, width = height * 2, width * 2
508
+ latents = self._upsample_latents_for_stage(latents, stage_idx, height, width, device)
509
+ size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
 
511
+ rope_pos = self._prepare_rope_pos_embed(latents, device)
 
 
 
 
 
 
 
512
 
513
+ for timestep in timesteps:
514
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
515
+ timestep_batch = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
516
+ with torch.autocast(device.type, enabled=autocast_enabled, dtype=autocast_dtype):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  noise_pred = self.transformer(
518
  latent_model_input,
519
  encoder_hidden_states=prompt_embeds,
 
523
  pos_embed=rope_pos,
524
  ).sample
525
 
526
+ if do_classifier_free_guidance:
527
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
528
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
529
 
530
+ latents = self.scheduler.step(model_output=noise_pred, sample=latents).prev_sample
531
+ progress_bar.update()
 
 
 
 
 
 
 
 
532
 
533
+ image = self.decode_latents(latents, output_type=output_type)
534
  self.maybe_free_model_hooks()
535
 
536
  if not return_dict:
537
  return (image,)
538
+ return ImagePipelineOutput(images=image)
539
 
 
PixelFlow-T2I/scheduler/scheduler_config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_class_name": "PixelFlowScheduler",
3
  "_diffusers_version": "0.36.0",
4
- "gamma": -0.3333333333333333,
5
  "num_stages": 4,
6
- "num_train_timesteps": 1000
7
  }
 
1
  {
2
  "_class_name": "PixelFlowScheduler",
3
  "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
  "num_stages": 4,
6
+ "gamma": -0.3333333333333333
7
  }
PixelFlow-T2I/scheduler/scheduling_pixelflow.py CHANGED
@@ -1,3 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  from dataclasses import dataclass
3
  from typing import Optional, Tuple, Union
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
  import math
16
  from dataclasses import dataclass
17
  from typing import Optional, Tuple, Union
PixelFlow-T2I/transformer/transformer_pixelflow.py CHANGED
@@ -1,14 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from dataclasses import dataclass
2
  from typing import Optional, Tuple, Union
3
 
4
  import torch
 
 
5
 
6
  from diffusers.configuration_utils import ConfigMixin, register_to_config
 
7
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
8
  from diffusers.models.modeling_utils import ModelMixin
9
  from diffusers.utils import BaseOutput
10
 
11
- from modeling_pixelflow import PixelFlowModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  @dataclass
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
  from dataclasses import dataclass
17
  from typing import Optional, Tuple, Union
18
 
19
  import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
 
23
  from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.embeddings import LabelEmbedding, TimestepEmbedding, Timesteps
25
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
26
  from diffusers.models.modeling_utils import ModelMixin
27
  from diffusers.utils import BaseOutput
28
 
29
+ try:
30
+ from flash_attn import flash_attn_varlen_func
31
+ except ImportError:
32
+ warnings.warn("`flash-attn` is not installed. Training mode may not work properly.", UserWarning)
33
+ flash_attn_varlen_func = None
34
+
35
+
36
+ def apply_rotary_emb(
37
+ x: torch.Tensor,
38
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
39
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
40
+ cos, sin = freqs_cis.unbind(-1)
41
+ cos = cos[None, None]
42
+ sin = sin[None, None]
43
+ cos, sin = cos.to(x.device), sin.to(x.device)
44
+
45
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
46
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
47
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
48
+
49
+ return out
50
+
51
+
52
+ class PatchEmbed(nn.Module):
53
+ def __init__(self, patch_size, in_channels, embed_dim, bias=True):
54
+ super().__init__()
55
+ self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
56
+
57
+ def forward_unfold(self, x):
58
+ out_unfold = x.matmul(self.proj.weight.view(self.proj.weight.size(0), -1).t())
59
+ if self.proj.bias is not None:
60
+ out_unfold += self.proj.bias.to(out_unfold.dtype)
61
+ return out_unfold
62
+
63
+ def forward(self, x):
64
+ if self.training:
65
+ return self.forward_unfold(x)
66
+ out = self.proj(x)
67
+ out = out.flatten(2).transpose(1, 2)
68
+ return out
69
+
70
+
71
+ class AdaLayerNorm(nn.Module):
72
+ def __init__(self, embedding_dim):
73
+ super().__init__()
74
+ self.embedding_dim = embedding_dim
75
+ self.silu = nn.SiLU()
76
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
77
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
78
+
79
+ def forward(self, x, timestep, seqlen_list=None):
80
+ input_dtype = x.dtype
81
+ emb = self.linear(self.silu(timestep))
82
+
83
+ if seqlen_list is not None:
84
+ emb = torch.cat([one_emb[None].expand(repeat_time, -1) for one_emb, repeat_time in zip(emb, seqlen_list)])
85
+ else:
86
+ emb = emb.unsqueeze(1)
87
+
88
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.float().chunk(6, dim=-1)
89
+ x = self.norm(x).float() * (1 + scale_msa) + shift_msa
90
+ return x.to(input_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp
91
+
92
+
93
+ class FeedForward(nn.Module):
94
+ def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True):
95
+ super().__init__()
96
+ inner_dim = int(dim * mult) if inner_dim is None else inner_dim
97
+ dim_out = dim_out if dim_out is not None else dim
98
+ self.fc1 = nn.Linear(dim, inner_dim, bias=bias)
99
+ self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias)
100
+
101
+ def forward(self, hidden_states):
102
+ hidden_states = self.fc1(hidden_states)
103
+ hidden_states = F.gelu(hidden_states, approximate="tanh")
104
+ hidden_states = self.fc2(hidden_states)
105
+ return hidden_states
106
+
107
+
108
+ class RMSNorm(nn.Module):
109
+ def __init__(self, dim: int, eps=1e-6):
110
+ super().__init__()
111
+ self.weight = nn.Parameter(torch.ones(dim))
112
+ self.eps = eps
113
+
114
+ def forward(self, x):
115
+ output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
116
+ return (self.weight * output).to(x.dtype)
117
+
118
+
119
+ class Attention(nn.Module):
120
+ def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False):
121
+ super().__init__()
122
+ self.q_dim = q_dim
123
+ self.kv_dim = kv_dim if kv_dim is not None else q_dim
124
+ self.inner_dim = head_dim * heads
125
+ self.dropout = dropout
126
+ self.head_dim = head_dim
127
+ self.num_heads = heads
128
+
129
+ self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias)
130
+ self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
131
+ self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
132
+ self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias)
133
+ self.q_norm = RMSNorm(self.inner_dim)
134
+ self.k_norm = RMSNorm(self.inner_dim)
135
+
136
+ def prepare_attention_mask(self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3):
137
+ head_size = self.num_heads
138
+ if attention_mask is None:
139
+ return attention_mask
140
+
141
+ current_length: int = attention_mask.shape[-1]
142
+ if current_length != target_length:
143
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
144
+
145
+ if out_dim == 3:
146
+ if attention_mask.shape[0] < batch_size * head_size:
147
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
148
+ elif out_dim == 4:
149
+ attention_mask = attention_mask.unsqueeze(1)
150
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
151
+
152
+ return attention_mask
153
+
154
+ def forward(
155
+ self,
156
+ inputs_q,
157
+ inputs_kv,
158
+ attention_mask=None,
159
+ cross_attention=False,
160
+ rope_pos_embed=None,
161
+ cu_seqlens_q=None,
162
+ cu_seqlens_k=None,
163
+ max_seqlen_q=None,
164
+ max_seqlen_k=None,
165
+ ):
166
+ inputs_kv = inputs_q if inputs_kv is None else inputs_kv
167
+
168
+ query_states = self.q_proj(inputs_q)
169
+ key_states = self.k_proj(inputs_kv)
170
+ value_states = self.v_proj(inputs_kv)
171
+
172
+ query_states = self.q_norm(query_states)
173
+ key_states = self.k_norm(key_states)
174
+
175
+ if max_seqlen_q is None:
176
+ assert not self.training, "PixelFlow needs sequence packing for training"
177
+
178
+ bsz, q_len, _ = inputs_q.shape
179
+ _, kv_len, _ = inputs_kv.shape
180
+
181
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
182
+ key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
183
+ value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
184
+
185
+ query_states = apply_rotary_emb(query_states, rope_pos_embed)
186
+ if not cross_attention:
187
+ key_states = apply_rotary_emb(key_states, rope_pos_embed)
188
+
189
+ if attention_mask is not None:
190
+ attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz)
191
+ attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1])
192
+
193
+ attn_output = F.scaled_dot_product_attention(
194
+ query_states,
195
+ key_states,
196
+ value_states,
197
+ attn_mask=attention_mask,
198
+ dropout_p=self.dropout if self.training else 0.0,
199
+ is_causal=False,
200
+ )
201
+
202
+ attn_output = attn_output.transpose(1, 2).contiguous()
203
+ attn_output = attn_output.view(bsz, q_len, self.inner_dim)
204
+ attn_output = self.o_proj(attn_output)
205
+ return attn_output
206
+
207
+ query_states = query_states.view(-1, self.num_heads, self.head_dim)
208
+ key_states = key_states.view(-1, self.num_heads, self.head_dim)
209
+ value_states = value_states.view(-1, self.num_heads, self.head_dim)
210
+
211
+ query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
212
+ if not cross_attention:
213
+ key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
214
+
215
+ attn_output = flash_attn_varlen_func(
216
+ query_states,
217
+ key_states,
218
+ value_states,
219
+ cu_seqlens_q=cu_seqlens_q,
220
+ cu_seqlens_k=cu_seqlens_k,
221
+ max_seqlen_q=max_seqlen_q,
222
+ max_seqlen_k=max_seqlen_k,
223
+ )
224
+
225
+ attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
226
+ attn_output = self.o_proj(attn_output)
227
+ return attn_output
228
+
229
+
230
+ class TransformerBlock(nn.Module):
231
+ def __init__(
232
+ self,
233
+ dim,
234
+ num_attention_heads,
235
+ attention_head_dim,
236
+ dropout=0.0,
237
+ cross_attention_dim=None,
238
+ attention_bias=False,
239
+ ):
240
+ super().__init__()
241
+ self.norm1 = AdaLayerNorm(dim)
242
+ self.attn1 = Attention(
243
+ q_dim=dim,
244
+ kv_dim=None,
245
+ heads=num_attention_heads,
246
+ head_dim=attention_head_dim,
247
+ dropout=dropout,
248
+ bias=attention_bias,
249
+ )
250
+
251
+ if cross_attention_dim is not None:
252
+ self.norm2 = RMSNorm(dim, eps=1e-6)
253
+ self.attn2 = Attention(
254
+ q_dim=dim,
255
+ kv_dim=cross_attention_dim,
256
+ heads=num_attention_heads,
257
+ head_dim=attention_head_dim,
258
+ dropout=dropout,
259
+ bias=attention_bias,
260
+ )
261
+ else:
262
+ self.attn2 = None
263
+
264
+ self.norm3 = RMSNorm(dim, eps=1e-6)
265
+ self.mlp = FeedForward(dim)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ encoder_hidden_states=None,
271
+ encoder_attention_mask=None,
272
+ timestep=None,
273
+ rope_pos_embed=None,
274
+ cu_seqlens_q=None,
275
+ cu_seqlens_k=None,
276
+ seqlen_list_q=None,
277
+ seqlen_list_k=None,
278
+ ):
279
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, seqlen_list_q)
280
+
281
+ attn_output = self.attn1(
282
+ inputs_q=norm_hidden_states,
283
+ inputs_kv=None,
284
+ attention_mask=None,
285
+ cross_attention=False,
286
+ rope_pos_embed=rope_pos_embed,
287
+ cu_seqlens_q=cu_seqlens_q,
288
+ cu_seqlens_k=cu_seqlens_q,
289
+ max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
290
+ max_seqlen_k=max(seqlen_list_q) if seqlen_list_q is not None else None,
291
+ )
292
+
293
+ attn_output = (gate_msa * attn_output.float()).to(attn_output.dtype)
294
+ hidden_states = attn_output + hidden_states
295
+
296
+ if self.attn2 is not None:
297
+ norm_hidden_states = self.norm2(hidden_states)
298
+ attn_output = self.attn2(
299
+ inputs_q=norm_hidden_states,
300
+ inputs_kv=encoder_hidden_states,
301
+ attention_mask=encoder_attention_mask,
302
+ cross_attention=True,
303
+ rope_pos_embed=rope_pos_embed,
304
+ cu_seqlens_q=cu_seqlens_q,
305
+ cu_seqlens_k=cu_seqlens_k,
306
+ max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
307
+ max_seqlen_k=max(seqlen_list_k) if seqlen_list_k is not None else None,
308
+ )
309
+ hidden_states = hidden_states + attn_output
310
+
311
+ norm_hidden_states = self.norm3(hidden_states)
312
+ norm_hidden_states = (norm_hidden_states.float() * (1 + scale_mlp) + shift_mlp).to(norm_hidden_states.dtype)
313
+ ff_output = self.mlp(norm_hidden_states)
314
+ ff_output = (gate_mlp * ff_output.float()).to(ff_output.dtype)
315
+ hidden_states = ff_output + hidden_states
316
+
317
+ return hidden_states
318
+
319
+
320
+ class PixelFlowModel(nn.Module):
321
+ def __init__(
322
+ self,
323
+ in_channels,
324
+ out_channels,
325
+ num_attention_heads,
326
+ attention_head_dim,
327
+ depth,
328
+ patch_size,
329
+ dropout=0.0,
330
+ cross_attention_dim=None,
331
+ attention_bias=True,
332
+ num_classes=0,
333
+ init_weights=True,
334
+ ):
335
+ super().__init__()
336
+ self.patch_size = patch_size
337
+ self.attention_head_dim = attention_head_dim
338
+ self.num_classes = num_classes
339
+ self.out_channels = out_channels
340
+
341
+ embed_dim = num_attention_heads * attention_head_dim
342
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
343
+
344
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
345
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
346
+ self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
347
+ if self.num_classes > 0:
348
+ self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1)
349
+
350
+ self.transformer_blocks = nn.ModuleList(
351
+ [
352
+ TransformerBlock(
353
+ embed_dim,
354
+ num_attention_heads,
355
+ attention_head_dim,
356
+ dropout,
357
+ cross_attention_dim,
358
+ attention_bias,
359
+ )
360
+ for _ in range(depth)
361
+ ]
362
+ )
363
+
364
+ self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
365
+ self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim)
366
+ self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels)
367
+
368
+ if init_weights:
369
+ self.initialize_from_scratch()
370
+
371
+ def initialize_from_scratch(self):
372
+ def _basic_init(module):
373
+ if isinstance(module, nn.Linear):
374
+ torch.nn.init.xavier_uniform_(module.weight)
375
+ if module.bias is not None:
376
+ nn.init.constant_(module.bias, 0)
377
+
378
+ self.apply(_basic_init)
379
+
380
+ w = self.patch_embed.proj.weight.data
381
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
382
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
383
+
384
+ nn.init.normal_(self.timestep_embedder.linear_1.weight, std=0.02)
385
+ nn.init.normal_(self.timestep_embedder.linear_2.weight, std=0.02)
386
+ nn.init.normal_(self.latent_size_embedder.linear_1.weight, std=0.02)
387
+ nn.init.normal_(self.latent_size_embedder.linear_2.weight, std=0.02)
388
+
389
+ if self.num_classes > 0:
390
+ nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02)
391
+
392
+ for block in self.transformer_blocks:
393
+ nn.init.constant_(block.norm1.linear.weight, 0)
394
+ nn.init.constant_(block.norm1.linear.bias, 0)
395
+
396
+ nn.init.constant_(self.proj_out_1.weight, 0)
397
+ nn.init.constant_(self.proj_out_1.bias, 0)
398
+ nn.init.constant_(self.proj_out_2.weight, 0)
399
+ nn.init.constant_(self.proj_out_2.bias, 0)
400
+
401
+ def forward(
402
+ self,
403
+ hidden_states,
404
+ encoder_hidden_states=None,
405
+ class_labels=None,
406
+ timestep=None,
407
+ latent_size=None,
408
+ encoder_attention_mask=None,
409
+ pos_embed=None,
410
+ cu_seqlens_q=None,
411
+ cu_seqlens_k=None,
412
+ seqlen_list_q=None,
413
+ seqlen_list_k=None,
414
+ ):
415
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
416
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
417
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
418
+
419
+ orig_height, orig_width = hidden_states.shape[-2], hidden_states.shape[-1]
420
+ hidden_states = hidden_states.to(torch.float32)
421
+ hidden_states = self.patch_embed(hidden_states)
422
+
423
+ timesteps_proj = self.time_proj(timestep)
424
+ conditioning = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
425
+
426
+ if self.num_classes > 0:
427
+ class_embed = self.class_embedder(class_labels)
428
+ conditioning += class_embed
429
+
430
+ latent_size_proj = self.time_proj(latent_size)
431
+ latent_size_embed = self.latent_size_embedder(latent_size_proj.to(dtype=hidden_states.dtype))
432
+ conditioning += latent_size_embed
433
+
434
+ for block in self.transformer_blocks:
435
+ hidden_states = block(
436
+ hidden_states,
437
+ encoder_hidden_states=encoder_hidden_states,
438
+ encoder_attention_mask=encoder_attention_mask,
439
+ timestep=conditioning,
440
+ rope_pos_embed=pos_embed,
441
+ cu_seqlens_q=cu_seqlens_q,
442
+ cu_seqlens_k=cu_seqlens_k,
443
+ seqlen_list_q=seqlen_list_q,
444
+ seqlen_list_k=seqlen_list_k,
445
+ )
446
+
447
+ shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1)
448
+ if seqlen_list_q is None:
449
+ shift = shift.unsqueeze(1)
450
+ scale = scale.unsqueeze(1)
451
+ else:
452
+ shift = torch.cat([shift_i[None].expand(ri, -1) for shift_i, ri in zip(shift, seqlen_list_q)])
453
+ scale = torch.cat([scale_i[None].expand(ri, -1) for scale_i, ri in zip(scale, seqlen_list_q)])
454
+
455
+ hidden_states = (self.norm_out(hidden_states).float() * (1 + scale) + shift).to(hidden_states.dtype)
456
+ hidden_states = self.proj_out_2(hidden_states)
457
+ if self.training:
458
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], self.patch_size, self.patch_size, self.out_channels)
459
+ hidden_states = hidden_states.permute(0, 3, 1, 2).flatten(1)
460
+ return hidden_states
461
+
462
+ height, width = orig_height // self.patch_size, orig_width // self.patch_size
463
+ hidden_states = hidden_states.reshape(shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels))
464
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
465
+ output = hidden_states.reshape(shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size))
466
+
467
+ return output
468
 
469
 
470
  @dataclass
README.md CHANGED
@@ -1,99 +1,147 @@
1
- ---
2
- license: mit
3
- library_name: diffusers
4
- pipeline_tag: text-to-image
5
- tags:
6
- - diffusers
7
- - pixelflow
8
- - image-generation
9
- - class-conditional
10
- - flow-matching
11
- widget:
12
- - output:
13
- url: PixelFlow-256/demo.png
14
- language:
15
- - en
16
- ---
17
-
18
  # BiliSakura/PixelFlow-diffusers
19
 
20
- Self-contained PixelFlow checkpoints for Hugging Face diffusers. Each subfolder ships its own `pipeline.py`, component modules, and weights.
21
 
22
  ## Available checkpoints
23
 
24
- | Subfolder | Task | Resolution | Params |
25
- | --- | --- | ---: | ---: |
26
- | [`PixelFlow-256/`](PixelFlow-256/) | class-to-image | 256×256 | 677M |
27
- | [`PixelFlow-T2I/`](PixelFlow-T2I/) | text-to-image | 1024×1024 | 882M |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- ## ImageNet class labels
30
 
31
- For class-conditional [`PixelFlow-256/`](PixelFlow-256/), ImageNet-1k labels live in shared [`labels/`](labels/) at the repo root:
32
 
33
- | File | Direction | Value format |
34
- | --- | --- | --- |
35
- | `labels/id2label_en.json` | id → English | comma-separated synonyms, e.g. `"207": "golden retriever"` |
36
- | `labels/id2label_cn.json` | id → Chinese | comma-separated synonyms, e.g. `"207": "金毛猎犬"` |
37
 
38
- After `PixelFlowPipeline.from_pretrained(...)`, the pipeline exposes:
39
 
40
- - `pipe.id2label` / `pipe.id2label_cn` — inspect id → label correspondence
41
- - `pipe.labels` / `pipe.labels_cn` — reverse maps (synonym → id)
42
- - `pipe.get_label_ids("golden retriever")` or `pipe.get_label_ids("金毛猎犬", lang="cn")`
43
  - `pipe(class_labels="golden retriever", ...)` — string labels resolved automatically
44
 
45
  ## Demo
46
 
47
- ![PixelFlow-256 demo](PixelFlow-256/demo.png)
48
 
49
- ## Load from a local clone
 
 
50
 
51
- ```python
52
- import sys
53
- from pathlib import Path
54
 
55
- repo = Path("BiliSakura/PixelFlow-diffusers").resolve()
56
- variant = "PixelFlow-256"
 
57
 
58
- sys.path.insert(0, str(repo / variant))
59
- from pipeline import PixelFlowPipeline
60
 
61
- pipe = PixelFlowPipeline.from_pretrained(".")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  pipe.to("cuda")
63
 
64
- images = pipe(
65
- class_labels=207,
 
 
 
 
 
 
66
  num_inference_steps=[10, 10, 10, 10],
67
  guidance_scale=4.0,
68
- ).images
69
-
70
- # Human-readable ImageNet labels (English or Chinese)
71
- print(pipe.id2label[207]) # "golden retriever"
72
- print(pipe.id2label_cn[207]) # "金毛猎犬"
73
- pipe.get_label_ids("golden retriever") # [207]
74
- pipe.get_label_ids("金毛猎犬", lang="cn") # [207]
75
- images = pipe(class_labels="golden retriever", num_inference_steps=[10, 10, 10, 10]).images
76
  ```
77
 
78
  ### Text-to-image (`PixelFlow-T2I`)
79
 
80
- Uses [`google/flan-t5-xl`](https://huggingface.co/google/flan-t5-xl) as the text encoder (loaded from Hugging Face at runtime, not bundled in the repo).
81
 
82
  ```python
83
- variant = "PixelFlow-T2I"
84
- sys.path.insert(0, str(repo / variant))
85
- from pipeline import PixelFlowPipeline
86
-
87
- pipe = PixelFlowPipeline.from_pretrained(".")
 
 
 
 
 
 
 
88
  pipe.to("cuda")
89
 
90
- images = pipe(
 
91
  prompt="A golden retriever playing in a sunny garden",
 
 
92
  num_inference_steps=[10, 10, 10, 10],
93
  guidance_scale=4.0,
94
- ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  ```
96
 
 
 
97
  ## Conversion
98
 
99
  ```bash
@@ -107,4 +155,17 @@ python scripts/convert_pixelflow_to_diffusers.py \
107
  --config models/raw/PixelFlow/t2i/config.yaml \
108
  --output models/BiliSakura/PixelFlow-diffusers/PixelFlow-T2I \
109
  --skip-text-encoder
110
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # BiliSakura/PixelFlow-diffusers
2
 
3
+ Self-contained PixelFlow checkpoints for Hugging Face diffusers. Each variant folder ships its own `pipeline.py`, component modules, and weights.
4
 
5
  ## Available checkpoints
6
 
7
+ | Subfolder | Pipeline | Task | Resolution | Params |
8
+ | --- | --- | --- | ---: | ---: |
9
+ | [`PixelFlow-256/`](PixelFlow-256/) | `PixelFlowPipeline` | class-to-image | 256×256 | 677M |
10
+ | [`PixelFlow-T2I/`](PixelFlow-T2I/) | `PixelFlowT2IPipeline` | text-to-image | 1024×1024 | 882M |
11
+
12
+ ## Repo layout
13
+
14
+ ```text
15
+ BiliSakura/PixelFlow-diffusers/
16
+ ├── README.md
17
+ ├── PixelFlow-256/
18
+ │ ├── pipeline.py
19
+ │ ├── model_index.json
20
+ │ ├── scheduler/scheduler_config.json
21
+ │ └── transformer/
22
+ └── PixelFlow-T2I/
23
+ ├── pipeline.py
24
+ ├── model_index.json
25
+ ├── scheduler/scheduler_config.json
26
+ ├── text_encoder/
27
+ ├── tokenizer/
28
+ └── transformer/
29
+ ```
30
 
31
+ Each variant is self-contained. The `scheduler/` folder contains `scheduler_config.json` and `scheduling_pixelflow.py` with [`PixelFlowScheduler`](PixelFlow-256/scheduler/scheduling_pixelflow.py).
32
 
33
+ No shared helper modules at inference time; only PyPI `diffusers` plus the local variant directory.
34
 
35
+ ## ImageNet class labels
 
 
 
36
 
37
+ For class-conditional [`PixelFlow-256/`](PixelFlow-256/), `id2label` is embedded in `PixelFlow-256/model_index.json` (DiT-style).
38
 
39
+ - `pipe.id2label` — inspect id → English label correspondence
40
+ - `pipe.labels` — reverse map (English synonym → id)
41
+ - `pipe.get_label_ids("golden retriever")`
42
  - `pipe(class_labels="golden retriever", ...)` — string labels resolved automatically
43
 
44
  ## Demo
45
 
46
+ Class-to-image:
47
 
48
+ ```bash
49
+ python demo_inference_c2i.py
50
+ ```
51
 
52
+ Text-to-image:
 
 
53
 
54
+ ```bash
55
+ python demo_inference_t2i.py
56
+ ```
57
 
58
+ ## Load from a local clone
 
59
 
60
+ ### Class-to-image (`PixelFlow-256`)
61
+
62
+ ```python
63
+ from pathlib import Path
64
+ import torch
65
+ from diffusers import DiffusionPipeline
66
+
67
+ model_dir = Path("./PixelFlow-256").resolve()
68
+ pipe = DiffusionPipeline.from_pretrained(
69
+ str(model_dir),
70
+ local_files_only=True,
71
+ custom_pipeline=str(model_dir / "pipeline.py"),
72
+ trust_remote_code=True,
73
+ torch_dtype=torch.bfloat16,
74
+ )
75
  pipe.to("cuda")
76
 
77
+ print(pipe.id2label[207])
78
+ print(pipe.get_label_ids("golden retriever"))
79
+
80
+ generator = torch.Generator(device="cuda").manual_seed(42)
81
+ image = pipe(
82
+ class_labels="golden retriever",
83
+ height=256,
84
+ width=256,
85
  num_inference_steps=[10, 10, 10, 10],
86
  guidance_scale=4.0,
87
+ generator=generator,
88
+ ).images[0]
89
+ image.save("demo.png")
 
 
 
 
 
90
  ```
91
 
92
  ### Text-to-image (`PixelFlow-T2I`)
93
 
94
+ Uses [`google/flan-t5-xl`](https://huggingface.co/google/flan-t5-xl) when `text_encoder/` is not bundled.
95
 
96
  ```python
97
+ from pathlib import Path
98
+ import torch
99
+ from diffusers import DiffusionPipeline
100
+
101
+ model_dir = Path("./PixelFlow-T2I").resolve()
102
+ pipe = DiffusionPipeline.from_pretrained(
103
+ str(model_dir),
104
+ local_files_only=True,
105
+ custom_pipeline=str(model_dir / "pipeline.py"),
106
+ trust_remote_code=True,
107
+ torch_dtype=torch.bfloat16,
108
+ )
109
  pipe.to("cuda")
110
 
111
+ generator = torch.Generator(device="cuda").manual_seed(42)
112
+ image = pipe(
113
  prompt="A golden retriever playing in a sunny garden",
114
+ height=1024,
115
+ width=1024,
116
  num_inference_steps=[10, 10, 10, 10],
117
  guidance_scale=4.0,
118
+ generator=generator,
119
+ ).images[0]
120
+ image.save("demo.png")
121
+ ```
122
+
123
+ Load a **variant subfolder** (e.g. `./PixelFlow-256`), not the repo root.
124
+
125
+ ## Load from the Hub
126
+
127
+ ```python
128
+ import torch
129
+ from diffusers import DiffusionPipeline
130
+
131
+ pipe = DiffusionPipeline.from_pretrained(
132
+ "BiliSakura/PixelFlow-diffusers",
133
+ subfolder="PixelFlow-256",
134
+ custom_pipeline="pipeline.py",
135
+ trust_remote_code=True,
136
+ torch_dtype=torch.bfloat16,
137
+ )
138
+ pipe.to("cuda")
139
+
140
+ image = pipe(class_labels="golden retriever", num_inference_steps=[10, 10, 10, 10]).images[0]
141
  ```
142
 
143
+ Swap `subfolder="PixelFlow-T2I"` and call with `prompt=...` for text-to-image.
144
+
145
  ## Conversion
146
 
147
  ```bash
 
155
  --config models/raw/PixelFlow/t2i/config.yaml \
156
  --output models/BiliSakura/PixelFlow-diffusers/PixelFlow-T2I \
157
  --skip-text-encoder
158
+ ```
159
+
160
+ ## Citation
161
+
162
+ ```bibtex
163
+ @article{chen2025pixelflow,
164
+ title={PixelFlow: Pixel-Space Flow Matching for High-Resolution Image Synthesis},
165
+ author={Chen, Shoufa and others},
166
+ year={2025},
167
+ eprint={2504.07963},
168
+ archivePrefix={arXiv},
169
+ primaryClass={cs.CV}
170
+ }
171
+ ```
demo_inference_c2i.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a demo image with PixelFlow-256."""
3
+
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+
9
+ REPO_ROOT = Path(__file__).resolve().parent
10
+ MODEL_DIR = REPO_ROOT / "PixelFlow-256"
11
+ OUTPUT_PATH = REPO_ROOT / "PixelFlow-256" / "demo.png"
12
+
13
+
14
+ def main() -> None:
15
+ pipe = DiffusionPipeline.from_pretrained(
16
+ str(MODEL_DIR),
17
+ local_files_only=True,
18
+ custom_pipeline=str(MODEL_DIR / "pipeline.py"),
19
+ trust_remote_code=True,
20
+ torch_dtype=torch.bfloat16,
21
+ )
22
+ pipe.to("cuda")
23
+
24
+ print(pipe.id2label[207])
25
+ print(pipe.get_label_ids("golden retriever"))
26
+
27
+ generator = torch.Generator(device="cuda").manual_seed(42)
28
+ image = pipe(
29
+ class_labels="golden retriever",
30
+ height=256,
31
+ width=256,
32
+ num_inference_steps=[10, 10, 10, 10],
33
+ guidance_scale=4.0,
34
+ generator=generator,
35
+ ).images[0]
36
+ image.save(OUTPUT_PATH)
37
+ print(f"Saved demo image to {OUTPUT_PATH}")
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
demo_inference_t2i.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a demo image with PixelFlow-T2I."""
3
+
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+
9
+ REPO_ROOT = Path(__file__).resolve().parent
10
+ MODEL_DIR = REPO_ROOT / "PixelFlow-T2I"
11
+ OUTPUT_PATH = REPO_ROOT / "PixelFlow-T2I" / "demo.png"
12
+
13
+
14
+ def main() -> None:
15
+ pipe = DiffusionPipeline.from_pretrained(
16
+ str(MODEL_DIR),
17
+ local_files_only=True,
18
+ custom_pipeline=str(MODEL_DIR / "pipeline.py"),
19
+ trust_remote_code=True,
20
+ torch_dtype=torch.bfloat16,
21
+ )
22
+ pipe.to("cuda")
23
+
24
+ generator = torch.Generator(device="cuda").manual_seed(42)
25
+ image = pipe(
26
+ prompt="A golden retriever playing in a sunny garden",
27
+ height=1024,
28
+ width=1024,
29
+ num_inference_steps=[10, 10, 10, 10],
30
+ guidance_scale=4.0,
31
+ generator=generator,
32
+ ).images[0]
33
+ image.save(OUTPUT_PATH)
34
+ print(f"Saved demo image to {OUTPUT_PATH}")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ main()