BiliSakura committed on
Commit
c8b6e26
·
verified ·
1 Parent(s): cccd7ed

Upload folder using huggingface_hub

Browse files
Files changed (47) hide show
  1. .gitattributes +3 -0
  2. README.md +63 -0
  3. pMF-B-16/README.md +44 -0
  4. pMF-B-16/model_index.json +1017 -0
  5. pMF-B-16/pipeline.py +285 -0
  6. pMF-B-16/scheduler/scheduler_config.json +7 -0
  7. pMF-B-16/transformer/config.json +26 -0
  8. pMF-B-16/transformer/diffusion_pytorch_model.safetensors +3 -0
  9. pMF-B-16/transformer/transformer_pmf.py +664 -0
  10. pMF-B-32/README.md +44 -0
  11. pMF-B-32/model_index.json +1017 -0
  12. pMF-B-32/pipeline.py +285 -0
  13. pMF-B-32/scheduler/scheduler_config.json +7 -0
  14. pMF-B-32/transformer/config.json +26 -0
  15. pMF-B-32/transformer/diffusion_pytorch_model.safetensors +3 -0
  16. pMF-B-32/transformer/transformer_pmf.py +664 -0
  17. pMF-H-16/README.md +44 -0
  18. pMF-H-16/demo.png +3 -0
  19. pMF-H-16/model_index.json +1017 -0
  20. pMF-H-16/pipeline.py +285 -0
  21. pMF-H-16/scheduler/scheduler_config.json +7 -0
  22. pMF-H-16/transformer/config.json +26 -0
  23. pMF-H-16/transformer/diffusion_pytorch_model.safetensors +3 -0
  24. pMF-H-16/transformer/transformer_pmf.py +664 -0
  25. pMF-H-32/README.md +44 -0
  26. pMF-H-32/demo.png +3 -0
  27. pMF-H-32/model_index.json +1017 -0
  28. pMF-H-32/pipeline.py +285 -0
  29. pMF-H-32/scheduler/scheduler_config.json +7 -0
  30. pMF-H-32/transformer/config.json +26 -0
  31. pMF-H-32/transformer/diffusion_pytorch_model.safetensors +3 -0
  32. pMF-H-32/transformer/transformer_pmf.py +664 -0
  33. pMF-L-16/README.md +44 -0
  34. pMF-L-16/model_index.json +1017 -0
  35. pMF-L-16/pipeline.py +285 -0
  36. pMF-L-16/scheduler/scheduler_config.json +7 -0
  37. pMF-L-16/transformer/config.json +26 -0
  38. pMF-L-16/transformer/diffusion_pytorch_model.safetensors +3 -0
  39. pMF-L-16/transformer/transformer_pmf.py +664 -0
  40. pMF-L-32/README.md +44 -0
  41. pMF-L-32/demo.png +3 -0
  42. pMF-L-32/model_index.json +1017 -0
  43. pMF-L-32/pipeline.py +285 -0
  44. pMF-L-32/scheduler/scheduler_config.json +7 -0
  45. pMF-L-32/transformer/config.json +26 -0
  46. pMF-L-32/transformer/diffusion_pytorch_model.safetensors +3 -0
  47. pMF-L-32/transformer/transformer_pmf.py +664 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ pMF-H-16/demo.png filter=lfs diff=lfs merge=lfs -text
37
+ pMF-H-32/demo.png filter=lfs diff=lfs merge=lfs -text
38
+ pMF-L-32/demo.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ tags:
6
+ - diffusers
7
+ - pmf
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ ---
13
+
14
+ # pMF-diffusers
15
+
16
+ Native diffusers implementation of [Pixel Mean Flows (pMF)](https://arxiv.org/abs/2601.22158). Each variant folder is self-contained:
17
+
18
+ - `pipeline.py` — `PMFPipeline`
19
+ - `scheduler/scheduler_config.json` — `FlowMatchEulerDiscreteScheduler` config
20
+ - `transformer/transformer_pmf.py` — `PMFTransformer2DModel`
21
+ - `transformer/` — converted weights and config
22
+
23
+ ## Available checkpoints
24
+
25
+ | Checkpoint | Path | Resolution | Recommended CFG (ω) | CFG interval | Noise scale |
26
+ | --- | --- | --- | --- | --- | --- |
27
+ | pMF-B/16 | `./pMF-B-16` | 256×256 | 7.5 | [0.1, 0.8] | 1.0 |
28
+ | pMF-B/32 | `./pMF-B-32` | 512×512 | 6.5 | [0.1, 0.7] | 2.0 |
29
+ | pMF-L/16 | `./pMF-L-16` | 256×256 | 7.0 | [0.2, 0.7] | 1.0 |
30
+ | pMF-L/32 | `./pMF-L-32` | 512×512 | 7.5 | [0.2, 0.6] | 4.0 |
31
+ | pMF-H/16 | `./pMF-H-16` | 256×256 | 7.0 | [0.2, 0.6] | 2.0 |
32
+ | pMF-H/32 | `./pMF-H-32` | 512×512 | 5.5 | [0.1, 0.6] | 4.0 |
33
+
34
+ ## Inference
35
+
36
+ ```python
37
+ from pathlib import Path
38
+ from diffusers import DiffusionPipeline
39
+ import torch
40
+
41
+ model_dir = Path("./pMF-L-16")
42
+ pipe = DiffusionPipeline.from_pretrained(
43
+ str(model_dir),
44
+ local_files_only=True,
45
+ custom_pipeline=str(model_dir / "pipeline.py"),
46
+ trust_remote_code=True,
47
+ torch_dtype=torch.float32,
48
+ ).to("cuda")
49
+
50
+ generator = torch.Generator(device="cuda").manual_seed(42)
51
+ image = pipe(
52
+ class_labels="golden retriever",
53
+ num_inference_steps=1,
54
+ guidance_scale=7.0,
55
+ guidance_interval_min=0.2,
56
+ guidance_interval_max=0.7,
57
+ noise_scale=1.0,
58
+ generator=generator,
59
+ ).images[0]
60
+ image.save("demo.png")
61
+ ```
62
+
63
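+ Load a **variant subfolder** (e.g. `./pMF-L-16`), not the repo root: each variant folder carries its own `model_index.json`, `pipeline.py`, scheduler config, and transformer weights.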
pMF-B-16/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ tags:
6
+ - diffusers
7
+ - pmf
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ ---
13
+
14
+ # pMF-B-16
15
+
16
+ Self-contained Diffusers variant for **pMF-B/16** (Pixel Mean Flows).
17
+
18
+ Recommended settings: `guidance_scale=7.5`, guidance interval `[0.1, 0.8]` (`guidance_interval_min=0.1`, `guidance_interval_max=0.8`), `noise_scale=1.0`.
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ from pathlib import Path
24
+ from diffusers import DiffusionPipeline
25
+ import torch
26
+
27
+ model_dir = Path("./pMF-B-16")
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ str(model_dir),
30
+ local_files_only=True,
31
+ custom_pipeline=str(model_dir / "pipeline.py"),
32
+ trust_remote_code=True,
33
+ torch_dtype=torch.float32,
34
+ ).to("cuda")
35
+
36
+ image = pipe(
37
+ class_labels=207,
38
+ num_inference_steps=1,
39
+ guidance_scale=7.5,
40
+ guidance_interval_min=0.1,
41
+ guidance_interval_max=0.8,
42
+ noise_scale=1.0,
43
+ ).images[0]
44
+ ```
pMF-B-16/model_index.json ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PMFPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_pmf",
13
+ "PMFTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
pMF-B-16/pipeline.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Hub custom pipeline: PMFPipeline.
16
+
17
+ Load with native Hugging Face diffusers and trust_remote_code=True.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+
32
# Classifier-free-guidance presets, keyed by the transformer config's `model_type`.
# Each entry supplies the guidance scale and the [min, max] guidance interval that
# `PMFPipeline.__call__` falls back to when the caller passes `None`.
DEFAULT_CFG_BY_MODEL: Dict[str, Dict[str, float]] = {
    "pMF-B/16": dict(guidance_scale=7.5, guidance_interval_min=0.1, guidance_interval_max=0.8),
    "pMF-B/32": dict(guidance_scale=6.5, guidance_interval_min=0.1, guidance_interval_max=0.7),
    "pMF-L/16": dict(guidance_scale=7.0, guidance_interval_min=0.2, guidance_interval_max=0.7),
    "pMF-L/32": dict(guidance_scale=7.5, guidance_interval_min=0.2, guidance_interval_max=0.6),
    "pMF-H/16": dict(guidance_scale=7.0, guidance_interval_min=0.2, guidance_interval_max=0.6),
    "pMF-H/32": dict(guidance_scale=5.5, guidance_interval_min=0.1, guidance_interval_max=0.6),
}

# Initial Gaussian noise multiplier per model variant, keyed like the table above.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = dict(
    [
        ("pMF-B/16", 1.0),
        ("pMF-B/32", 2.0),
        ("pMF-L/16", 1.0),
        ("pMF-L/32", 4.0),
        ("pMF-H/16", 2.0),
        ("pMF-H/32", 4.0),
    ]
)
49
+
50
+
51
+ def _set_pmf_timesteps(
52
+ scheduler: FlowMatchEulerDiscreteScheduler,
53
+ num_inference_steps: int,
54
+ device: torch.device,
55
+ ) -> torch.Tensor:
56
+ r"""Set linear flow sigmas from 1.0 to 0.0 for pMF sampling."""
57
+ flow_sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=torch.float32)
58
+ scheduler.set_timesteps(sigmas=flow_sigmas.tolist(), device=device)
59
+ return flow_sigmas
60
+
61
+
62
class PMFPipeline(DiffusionPipeline):
    r"""
    Pipeline for ImageNet class-conditional generation with Pixel Mean Flows (pMF).

    Parameters:
        transformer ([`PMFTransformer2DModel`]):
            Class-conditioned pMF transformer that predicts mean-flow velocity.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Built-in flow-matching Euler scheduler. When `None`, a scheduler with
            `num_train_timesteps=1000`, `shift=1.0` and `stochastic_sampling=False` is created.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Keys may be ints or numeric strings;
            each value is a comma-separated list of synonyms for that class.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        if scheduler is None:
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=1.0,
                stochastic_sampling=False,
            )
        self.register_modules(transformer=transformer, scheduler=scheduler)
        self._id2label = self._normalize_id2label(id2label)
        # `labels` is the public synonym -> class-id lookup table.
        self.labels = self._build_label2id(self._id2label)
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Lazily load `id2label` from `model_index.json` when none was passed to `__init__`."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with integer keys; `None` or empty input yields `{}`."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """
        Read the `id2label` mapping from `<variant_path>/model_index.json`.

        Returns `{}` when the path is unset, the file does not exist, or the file has no
        dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        # Share the key normalisation with the constructor path.
        return PMFPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """
        Build a synonym -> class-id lookup from comma-separated label strings.

        The first non-empty synonym of a class is its primary name. A primary name always
        wins over another class's secondary synonym, so e.g. `"missile"` resolves to the
        class labelled `"missile"` rather than to `"projectile, missile"`. Among names of
        the same rank the class that appears later in `id2label` wins. The result is sorted
        by synonym.
        """
        primary: Dict[str, int] = {}
        secondary: Dict[str, int] = {}
        for class_id, value in id2label.items():
            synonyms = [part.strip() for part in value.split(",")]
            synonyms = [name for name in synonyms if name]
            for position, name in enumerate(synonyms):
                target = primary if position == 0 else secondary
                target[name] = int(class_id)
        # Primary names are merged last so they override colliding secondary synonyms.
        label2id = {**secondary, **primary}
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """Class id -> label string mapping, loaded from `model_index.json` on first use if needed."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        """
        Map one label name or a list of label names to class ids.

        Raises:
            ValueError: If no labels are available or any name is unknown.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
        labels = [label] if isinstance(label, str) else label
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
        """
        Convert class ids and/or label names into a flat list of integer class ids.

        Ids and names may be mixed freely within one list; every element is resolved on its
        own, and all names are looked up together so a single error reports every unknown name.

        Raises:
            ValueError: If `class_labels` is empty or contains an unknown label name.
        """
        if isinstance(class_labels, (int, str)):
            class_labels = [class_labels]
        items = list(class_labels)
        if not items:
            raise ValueError("`class_labels` must contain at least one class id or label name.")
        names = [item for item in items if isinstance(item, str)]
        resolved = iter(self.get_label_ids(names)) if names else iter(())
        return [next(resolved) if isinstance(item, str) else int(item) for item in items]

    def _recommended_noise_scale(self) -> float:
        """Return the preset noise scale for this model type, else a fallback keyed on image size."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in RECOMMENDED_NOISE_BY_MODEL:
            return RECOMMENDED_NOISE_BY_MODEL[model_type]
        image_size = int(self.transformer.config.sample_size)
        return {256: 1.0, 512: 2.0}.get(image_size, 1.0)

    def _default_cfg(self) -> Dict[str, float]:
        """Return a copy of the guidance preset for this model type, else a generic default."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in DEFAULT_CFG_BY_MODEL:
            return dict(DEFAULT_CFG_BY_MODEL[model_type])
        return {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8}

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        num_inference_steps: int = 1,
        guidance_scale: Optional[float] = None,
        guidance_interval_min: Optional[float] = None,
        guidance_interval_max: Optional[float] = None,
        noise_scale: Optional[float] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with pMF.

        Args:
            class_labels (`int`, `str`, or `list`):
                ImageNet class id(s) or label name(s). One image is generated per entry.
                Ids outside the valid range are clamped into it.
            num_inference_steps (`int`, *optional*, defaults to 1):
                Number of flow steps. pMF is typically used with 1 step.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. Defaults to model-specific preset.
            guidance_interval_min (`float`, *optional*):
                Lower bound of the CFG interval in normalized time.
            guidance_interval_max (`float`, *optional*):
                Upper bound of the CFG interval in normalized time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale. Defaults to model-specific preset.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                Random generator(s) for reproducibility.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: `"pil"`, `"np"`, or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`].

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                Generated images, with values in `[0, 1]` for `"pt"` and `"np"` outputs.

        Raises:
            ValueError: If `num_inference_steps < 1`, `output_type` is unsupported,
                `class_labels` is empty, or a label name is unknown.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")

        defaults = self._default_cfg()
        if guidance_scale is None:
            guidance_scale = defaults["guidance_scale"]
        if guidance_interval_min is None:
            guidance_interval_min = defaults["guidance_interval_min"]
        if guidance_interval_max is None:
            guidance_interval_max = defaults["guidance_interval_max"]
        if noise_scale is None:
            noise_scale = self._recommended_noise_scale()

        class_label_ids = self._normalize_class_labels(class_labels)
        batch_size = len(class_label_ids)
        transformer_config = self.transformer.config
        image_size = int(transformer_config.sample_size)
        channels = int(transformer_config.in_channels)
        # A config may carry `num_classes` or `num_class_embeds`, and either may be null;
        # use the first usable value and fall back to 1000.
        null_class_val = int(
            getattr(transformer_config, "num_classes", None)
            or getattr(transformer_config, "num_class_embeds", None)
            or 1000
        )

        latents = randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        ) * noise_scale

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # Keep ids inside [0, num_classes - 1].
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)

        device = latents.device
        dtype = latents.dtype
        # Per-sample guidance inputs handed to the transformer.
        omega = torch.full((batch_size,), guidance_scale, device=device, dtype=dtype)
        t_min = torch.full((batch_size,), guidance_interval_min, device=device, dtype=dtype)
        t_max = torch.full((batch_size,), guidance_interval_max, device=device, dtype=dtype)

        flow_sigmas = _set_pmf_timesteps(self.scheduler, num_inference_steps, device)

        for step_index in self.progress_bar(range(num_inference_steps)):
            t = flow_sigmas[step_index]
            t_next = flow_sigmas[step_index + 1]
            # `h` is the width of the current interval, `t - t_next`.
            h = (t - t_next).expand(batch_size).to(device=device, dtype=dtype)
            t_batch = t.expand(batch_size).to(device=device, dtype=dtype)

            output = self.transformer(
                sample=latents,
                timestep=t_batch,
                class_labels=class_labels_t,
                h=h,
                omega=omega,
                guidance_interval_min=t_min,
                guidance_interval_max=t_max,
                return_dict=True,
            )
            latents = self.scheduler.step(output.u, self.scheduler.timesteps[step_index], latents).prev_sample

        # Map from [-1, 1] to [0, 1] and move to CPU for conversion.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)


# pMF-specific alias for the generic image output class.
PMFPipelineOutput = ImagePipelineOutput
pMF-B-16/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
pMF-B-16/transformer/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PMFTransformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "aux_head_depth": 8,
5
+ "bottleneck_dim": 128,
6
+ "depth": 16,
7
+ "embedding_init_constant": 1.0,
8
+ "eval_mode": true,
9
+ "hidden_size": 768,
10
+ "in_channels": 3,
11
+ "mlp_ratio": 2.6666666666666665,
12
+ "model_type": "pMF-B/16",
13
+ "norm_eps": 1e-06,
14
+ "num_attention_heads": 12,
15
+ "num_cfg_tokens": 4,
16
+ "num_class_embeds": null,
17
+ "num_class_tokens": 8,
18
+ "num_classes": 1000,
19
+ "num_interval_tokens": 2,
20
+ "num_time_tokens": 4,
21
+ "patch_size": 16,
22
+ "sample_size": 256,
23
+ "t_clip_min": 0.05,
24
+ "token_init_constant": 1.0,
25
+ "weight_init_constant": 0.32
26
+ }
pMF-B-16/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc6554cc507ea6d53c72c01575c20ca4b5ff0b3a8d1a709e9e5e1c1da3e9a552
3
+ size 472960528
pMF-B-16/transformer/transformer_pmf.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from math import sqrt
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.normalization import RMSNorm
14
+ from diffusers.utils import BaseOutput
15
+
16
+
17
# Architecture hyper-parameters for each released pMF variant, keyed by model type.
PMF_PRESET_CONFIGS: Dict[str, Dict[str, object]] = {
    "pMF-B/16": dict(
        sample_size=256, patch_size=16, hidden_size=768, depth=16,
        num_attention_heads=12, bottleneck_dim=128, aux_head_depth=8,
    ),
    "pMF-B/32": dict(
        sample_size=512, patch_size=32, hidden_size=768, depth=16,
        num_attention_heads=12, bottleneck_dim=128, aux_head_depth=8,
    ),
    "pMF-L/16": dict(
        sample_size=256, patch_size=16, hidden_size=1024, depth=32,
        num_attention_heads=16, bottleneck_dim=128, aux_head_depth=8,
    ),
    "pMF-L/32": dict(
        sample_size=512, patch_size=32, hidden_size=1024, depth=32,
        num_attention_heads=16, bottleneck_dim=128, aux_head_depth=8,
    ),
    "pMF-H/16": dict(
        sample_size=256, patch_size=16, hidden_size=1280, depth=48,
        num_attention_heads=16, bottleneck_dim=256, aux_head_depth=8,
    ),
    "pMF-H/32": dict(
        sample_size=512, patch_size=32, hidden_size=1280, depth=48,
        num_attention_heads=16, bottleneck_dim=256, aux_head_depth=8,
    ),
}

# Initial noise multiplier per model variant.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = dict(
    [
        ("pMF-B/16", 1.0),
        ("pMF-B/32", 2.0),
        ("pMF-L/16", 1.0),
        ("pMF-L/32", 4.0),
        ("pMF-H/16", 2.0),
        ("pMF-H/32", 4.0),
    ]
)

# Legacy torch repo keys (pmfDiT_<size>_<patch>) mapped to the native preset names.
LEGACY_MODEL_ALIASES: Dict[str, str] = {
    f"pmfDiT_{size}_{patch}": f"pMF-{size}/{patch}"
    for size in ("B", "L", "H")
    for patch in (16, 32)
}
92
+
93
+
94
@dataclass
class PMFTransformer2DOutput(BaseOutput):
    """Output container with a required `u` tensor and an optional `v` tensor."""

    # Primary prediction; always present.
    # NOTE(review): presumably the mean-flow velocity - confirm against the model's forward.
    u: torch.Tensor
    # Secondary prediction; `None` when not produced.
    # NOTE(review): meaning is not visible from this definition - confirm against the model's forward.
    v: Optional[torch.Tensor] = None
98
+
99
+
100
def remap_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Translate legacy checkpoint keys into native PMFTransformer2DModel keys.

    For every key: at most one leading wrapper prefix (`transformer.` or `net.`) is
    removed, every `._flax_linear` / `._flax_embedding` segment is dropped, and an entry
    whose resulting key is exactly `rope_freqs` is discarded. Values are passed through
    untouched.
    """
    wrapper_prefixes = ("transformer.", "net.")
    # Official PyTorch checkpoints use TorchLinear/TorchEmbedding wrappers.
    wrapper_markers = ("._flax_linear", "._flax_embedding")

    native: Dict[str, torch.Tensor] = {}
    for legacy_key, tensor in state_dict.items():
        # Only the first matching prefix is stripped; nested prefixes are left alone.
        matched = next((prefix for prefix in wrapper_prefixes if legacy_key.startswith(prefix)), "")
        name = legacy_key[len(matched) :]
        for marker in wrapper_markers:
            name = name.replace(marker, "")
        if name != "rope_freqs":
            native[name] = tensor
    return native
115
+
116
+
117
def config_from_legacy(config: Dict[str, object]) -> Dict[str, object]:
    """Turn a legacy ``config.json`` dict into kwargs for the native model.

    The preset name is read from ``model_type``, ``model_name`` or
    ``model_str`` (first truthy value wins) and legacy ``pmfDiT_*`` names are
    mapped to their canonical form. The result starts as a copy of the preset
    and is extended with ``num_classes`` (from ``num_class_embeds`` or
    ``num_classes``, defaulting to 1000) and ``model_type``. ``sample_size``
    and ``eval_mode`` override the preset only when present and not ``None``.

    Raises:
        ValueError: if the resolved name is not a known preset.
    """
    requested = config.get("model_type") or config.get("model_name") or config.get("model_str")
    model_type = LEGACY_MODEL_ALIASES.get(requested, requested)
    if model_type not in PMF_PRESET_CONFIGS:
        raise ValueError(f"Unknown pMF preset '{model_type}'. Known: {list(PMF_PRESET_CONFIGS)}")

    class_count = config.get("num_class_embeds") or config.get("num_classes") or 1000
    native: Dict[str, object] = {
        **PMF_PRESET_CONFIGS[model_type],
        "num_classes": int(class_count),
        "model_type": model_type,
    }

    # Optional overrides, each coerced to the type the model constructor expects.
    for key, coerce in (("sample_size", int), ("eval_mode", bool)):
        raw = config.get(key)
        if raw is not None:
            native[key] = coerce(raw)
    return native
133
+
134
+
135
+ def _scaled_linear(
136
+ in_features: int,
137
+ out_features: int,
138
+ *,
139
+ bias: bool = True,
140
+ weight_init: str = "scaled_variance",
141
+ init_constant: float = 1.0,
142
+ bias_init: str = "zeros",
143
+ ) -> nn.Linear:
144
+ layer = nn.Linear(in_features, out_features, bias=bias)
145
+ if weight_init == "scaled_variance":
146
+ std = init_constant / sqrt(in_features)
147
+ nn.init.normal_(layer.weight, std=std)
148
+ elif weight_init == "zeros":
149
+ nn.init.zeros_(layer.weight)
150
+ else:
151
+ raise ValueError(f"Invalid weight_init: {weight_init}")
152
+
153
+ if bias:
154
+ if bias_init == "zeros":
155
+ nn.init.zeros_(layer.bias)
156
+ else:
157
+ raise ValueError(f"Invalid bias_init: {bias_init}")
158
+ return layer
159
+
160
+
161
class PMFTimestepEmbedder(nn.Module):
    """Embed a batch of scalars with sinusoidal features followed by a two-layer MLP.

    Args:
        hidden_size: Width of the returned embedding.
        frequency_embedding_size: Number of sinusoidal features fed to the MLP.
        init_constant: Scale passed to ``_scaled_linear`` for both MLP layers.
    """

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        init_constant: float = 1.0,
    ):
        super().__init__()
        init_kwargs = dict(
            out_features=hidden_size,
            bias=True,
            weight_init="scaled_variance",
            init_constant=init_constant,
            bias_init="zeros",
        )
        self.mlp = nn.Sequential(
            _scaled_linear(frequency_embedding_size, **init_kwargs),
            nn.SiLU(),
            _scaled_linear(hidden_size, **init_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return float32 sinusoidal features for the 1-D tensor ``t``.

        The first ``dim // 2`` columns are cosines and the next ``dim // 2``
        are sines, over geometrically spaced frequencies from 1 down towards
        ``1 / max_period``. An odd ``dim`` gets one trailing zero column.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` (one scalar per batch element) into ``hidden_size`` features."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # The sinusoidal features are always float32. Cast them to the MLP's
        # parameter dtype so a half/bfloat16/double model does not fail with a
        # dtype mismatch in the first linear layer. No-op for float32 weights.
        t_freq = t_freq.to(self.mlp[0].weight.dtype)
        return self.mlp(t_freq)
200
+
201
+
202
class PMFLabelEmbedder(nn.Module):
    """Lookup table that maps integer class labels to ``hidden_size`` vectors.

    The table holds ``num_classes + 1`` rows (one more than the number of
    classes; presumably reserved for an unconditional/null label, to be
    confirmed against the pipeline). Rows are drawn from a normal distribution
    with ``std = init_constant / sqrt(hidden_size)``.
    """

    def __init__(self, num_classes: int, hidden_size: int, init_constant: float = 1.0):
        super().__init__()
        row_count = num_classes + 1
        self.embedding_table = nn.Embedding(row_count, hidden_size)
        row_std = init_constant / sqrt(hidden_size)
        nn.init.normal_(self.embedding_table.weight, std=row_std)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return the table rows selected by ``labels``."""
        vectors = self.embedding_table(labels)
        return vectors
210
+
211
+
212
class PMFBottleneckPatchEmbedder(nn.Module):
    """Patchify an image through a low-dimensional bottleneck.

    ``proj1`` is a strided convolution (kernel = stride = ``patch_size``) that
    maps each patch to ``pca_channels`` features; ``proj2`` is a 1x1
    convolution that lifts those features to ``hidden_size``. Both weights are
    drawn uniformly from ``[-limit, limit]`` with
    ``limit = sqrt(6 / (fan_in + fan_out))``; biases, when present, are zeroed.
    """

    def __init__(
        self,
        input_size: int,
        patch_size: int,
        pca_channels: int,
        in_channels: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        patches_per_side = input_size // patch_size
        self.num_patches = patches_per_side**2

        self.proj1 = nn.Conv2d(
            in_channels,
            pca_channels,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )
        self.proj2 = nn.Conv2d(pca_channels, hidden_size, kernel_size=1, stride=1, bias=bias)

        # (layer, fan_in, fan_out) for each projection, initialised in order.
        init_plan = (
            (self.proj1, patch_size * patch_size * in_channels, pca_channels),
            (self.proj2, pca_channels, hidden_size),
        )
        for conv, fan_in, fan_out in init_plan:
            bound = math.sqrt(6.0 / (fan_in + fan_out))
            nn.init.uniform_(conv.weight, -bound, bound)
            if bias:
                nn.init.zeros_(conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens with the spatial grid flattened into dim 1."""
        bottleneck = self.proj1(x)
        features = self.proj2(bottleneck)
        # (batch, hidden, gh, gw) -> (batch, gh * gw, hidden)
        return features.flatten(2).transpose(1, 2)
250
+
251
+
252
def precompute_rope_freqs(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the complex rotary table for a square grid of ``seq_len`` tokens.

    The grid side is ``int(seq_len ** 0.5)``. Each token gets ``dim // 2``
    unit-modulus complex numbers: the first half encodes its row index, the
    second half its column index, both over the same frequencies.

    Returns:
        A complex tensor of shape ``(seq_len, dim // 2)``.
    """
    half_dim = dim // 2
    side = int(seq_len**0.5)

    inv_freq = 1.0 / (theta ** (torch.arange(0, half_dim, 2, dtype=torch.float32) / half_dim))
    coords = torch.arange(side, dtype=torch.float32)
    # Rows and columns share one (side, n_freq) angle table.
    axis_angles = torch.einsum("i,j->ij", coords, inv_freq)

    row_angles = torch.tile(axis_angles[:, None, :], (1, side, 1))
    col_angles = torch.tile(axis_angles[None, :, :], (side, 1, 1))
    grid_angles = torch.cat([row_angles, col_angles], dim=-1)

    cos_part = torch.cos(grid_angles).reshape(seq_len, half_dim)
    sin_part = torch.sin(grid_angles).reshape(seq_len, half_dim)
    return torch.complex(cos_part, sin_part)
269
+
270
+
271
def apply_rotary_pos_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the trailing tokens of ``x`` by the complex factors in ``freqs_cis``.

    Adjacent pairs along the last dimension of ``x`` are treated as complex
    numbers. Only the last ``freqs_cis.shape[0]`` positions along dim 1 are
    multiplied by ``freqs_cis``; earlier positions pass through unchanged. The
    computation runs in float32 and the result is cast back to ``x.dtype``.
    """
    source_dtype = x.dtype
    as_float = x.to(torch.float32)
    pairs = as_float.reshape(*as_float.shape[:-1], -1, 2).contiguous()
    as_complex = torch.view_as_complex(pairs)

    # Add broadcast axes for the batch (dim 0) and head (dim 2) dimensions.
    phases = freqs_cis[None, :, None]
    rotated_len = phases.shape[1]

    result = as_complex.clone()
    result[:, -rotated_len:, :] = as_complex[:, -rotated_len:, :] * phases
    return torch.view_as_real(result).flatten(-2).to(source_dtype)
280
+
281
+
282
class PMFAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised queries/keys and rotary embeddings.

    Args:
        hidden_size: Token width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        weight_init_constant: Scale passed to ``_scaled_linear`` for all four
            bias-free projections.
        eps: Epsilon for the per-head query/key ``RMSNorm`` layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        init_kwargs = dict(
            bias=False,
            weight_init="scaled_variance",
            init_constant=weight_init_constant,
        )
        self.q_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.k_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.v_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.out_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.q_norm = RMSNorm(self.head_dim, eps=eps)
        self.k_norm = RMSNorm(self.head_dim, eps=eps)

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, seq_len, channels)``.

        ``rope_freqs`` is forwarded to ``apply_rotary_pos_emb`` for both
        queries and keys. Returns a tensor shaped like ``x``.
        """
        batch_size, seq_len, channels = x.shape
        head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        q = self.q_proj(x).reshape(head_shape)
        k = self.k_proj(x).reshape(head_shape)
        v = self.v_proj(x).reshape(head_shape)

        q = apply_rotary_pos_emb(self.q_norm(q), rope_freqs)
        k = apply_rotary_pos_emb(self.k_norm(k), rope_freqs)

        query = q / math.sqrt(self.head_dim)
        attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, k)
        # Softmax runs in float32 for numerical stability. The weights are then
        # cast back to the value dtype: einsum does not accept mixed dtypes, so
        # without the cast a half/bfloat16 model fails here. No-op for float32.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        attn = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        attn = attn.reshape(batch_size, seq_len, channels)
        return self.out_proj(attn)
322
+
323
+
324
class PMFSwiGLUMlp(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.

    All three projections are bias-free and initialised through
    ``_scaled_linear`` with ``weight_init_constant``.
    """

    def __init__(self, dim: int, hidden_dim: int, weight_init_constant: float = 0.32):
        super().__init__()

        def _projection(width_in: int, width_out: int) -> nn.Linear:
            return _scaled_linear(
                width_in,
                width_out,
                bias=False,
                weight_init="scaled_variance",
                init_constant=weight_init_constant,
            )

        self.w1 = _projection(dim, hidden_dim)
        self.w3 = _projection(dim, hidden_dim)
        self.w2 = _projection(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` along its last dimension."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
334
+
335
+
336
class PMFTransformerBlock(nn.Module):
    """Pre-norm transformer block with learned per-channel residual scales.

    Both residual branches (attention, then SwiGLU MLP) are multiplied by a
    ``hidden_size`` vector initialised to zero, so a freshly built block is the
    identity. The MLP width is ``int(hidden_size * mlp_ratio)``, rounded up to
    a multiple of 8 only when ``hidden_size`` exceeds 1024.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 8 / 3,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        mlp_width = int(hidden_size * mlp_ratio)
        if hidden_size > 1024:
            # Round up to the next multiple of 8.
            mlp_width = (mlp_width + 7) // 8 * 8

        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = PMFAttention(hidden_size, num_heads, weight_init_constant=weight_init_constant, eps=eps)
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = PMFSwiGLUMlp(hidden_size, mlp_width, weight_init_constant=weight_init_constant)
        self.attn_scale = nn.Parameter(torch.zeros(hidden_size))
        self.mlp_scale = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Run the attention and MLP residual branches in sequence."""
        attn_branch = self.attn(self.norm1(x), rope_freqs)
        x = x + attn_branch * self.attn_scale
        mlp_branch = self.mlp(self.norm2(x))
        return x + mlp_branch * self.mlp_scale
360
+
361
+
362
class PMFFinalLayer(nn.Module):
    """Output head: RMSNorm followed by a zero-initialised linear projection.

    Each token is projected to ``patch_size * patch_size * out_channels``
    values. Weight and bias both start at zero, so a freshly built head
    outputs zeros.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, eps: float = 1e-6):
        super().__init__()
        values_per_patch = patch_size * patch_size * out_channels
        self.norm = RMSNorm(hidden_size, eps=eps)
        self.linear = _scaled_linear(
            hidden_size,
            values_per_patch,
            bias=True,
            weight_init="zeros",
            bias_init="zeros",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and project every token to its patch values."""
        normed = self.norm(x)
        return self.linear(normed)
376
+
377
+
378
class PMFTransformer2DModel(ModelMixin, ConfigMixin):
    """Native diffusers implementation of the pMF DiT backbone.

    The input image is turned into patch tokens and prefixed with learned
    conditioning tokens (class, ``omega``, guidance-interval min/max, ``h``).
    The sequence runs through ``depth - aux_head_depth`` shared blocks and then
    through the ``u`` head blocks. When ``eval_mode`` is False a parallel ``v``
    head is built as well; otherwise ``v`` is never computed.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "rope_freqs"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        depth: int = 16,
        num_attention_heads: int = 12,
        mlp_ratio: float = 8 / 3,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        aux_head_depth: int = 8,
        num_class_tokens: int = 8,
        num_time_tokens: int = 4,
        num_cfg_tokens: int = 4,
        num_interval_tokens: int = 2,
        token_init_constant: float = 1.0,
        embedding_init_constant: float = 1.0,
        weight_init_constant: float = 0.32,
        eval_mode: bool = True,
        model_type: str | None = None,
        num_class_embeds: int | None = None,
        t_clip_min: float = 0.05,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        # ``num_class_embeds`` is an alternate spelling that takes precedence.
        if num_class_embeds is not None:
            num_classes = int(num_class_embeds)
        if model_type in LEGACY_MODEL_ALIASES:
            model_type = LEGACY_MODEL_ALIASES[model_type]
        # A known preset overrides the explicitly passed geometry arguments.
        if model_type in PMF_PRESET_CONFIGS:
            preset = PMF_PRESET_CONFIGS[model_type]
            sample_size = int(preset["sample_size"])
            patch_size = int(preset["patch_size"])
            hidden_size = int(preset["hidden_size"])
            depth = int(preset["depth"])
            num_attention_heads = int(preset["num_attention_heads"])
            bottleneck_dim = int(preset["bottleneck_dim"])
            aux_head_depth = int(preset["aux_head_depth"])

        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_attention_heads = num_attention_heads
        self.aux_head_depth = aux_head_depth
        self.num_class_tokens = num_class_tokens
        self.num_time_tokens = num_time_tokens
        self.num_cfg_tokens = num_cfg_tokens
        self.num_interval_tokens = num_interval_tokens
        # Conditioning tokens placed in front of the patch tokens; the interval
        # tokens are counted twice (one group for t_min, one for t_max).
        self.prefix_tokens = (
            num_class_tokens + num_cfg_tokens + 2 * num_interval_tokens + num_time_tokens
        )
        self.t_clip_min = t_clip_min
        self.eval_mode = eval_mode
        self.gradient_checkpointing = False

        self.x_embedder = PMFBottleneckPatchEmbedder(
            sample_size,
            patch_size,
            bottleneck_dim,
            in_channels,
            hidden_size,
            bias=True,
        )
        embed_kwargs = dict(hidden_size=hidden_size, init_constant=embedding_init_constant)
        self.h_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.omega_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.cfg_t_start_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.cfg_t_end_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.y_embedder = PMFLabelEmbedder(num_classes, hidden_size, init_constant=embedding_init_constant)

        token_std = token_init_constant / math.sqrt(hidden_size)
        self.time_tokens = nn.Parameter(torch.randn(1, num_time_tokens, hidden_size) * token_std)
        self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, hidden_size) * token_std)
        self.omega_tokens = nn.Parameter(torch.randn(1, num_cfg_tokens, hidden_size) * token_std)
        self.t_min_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
        self.t_max_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)

        total_tokens = self.x_embedder.num_patches + self.prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, total_tokens, hidden_size) * 0.02)

        # The rotary table covers patch tokens only and is not saved in checkpoints.
        head_dim = hidden_size // num_attention_heads
        self.register_buffer(
            "rope_freqs",
            precompute_rope_freqs(head_dim, self.x_embedder.num_patches),
            persistent=False,
        )

        shared_depth = depth - aux_head_depth
        block_kwargs = dict(
            hidden_size=hidden_size,
            num_heads=num_attention_heads,
            mlp_ratio=mlp_ratio,
            weight_init_constant=weight_init_constant,
            eps=norm_eps,
        )
        self.shared_blocks = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(shared_depth)])
        self.u_heads = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth)])
        # The v head only exists outside eval mode.
        self.v_heads = nn.ModuleList(
            [PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth if not eval_mode else 0)]
        )
        self.u_final_layer = PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
        self.v_final_layer = (
            PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
            if not eval_mode
            else None
        )

    def _build_sequence(
        self,
        sample: torch.Tensor,
        h: torch.Tensor,
        omega: torch.Tensor,
        t_min: torch.Tensor,
        t_max: torch.Tensor,
        class_labels: torch.Tensor,
    ) -> torch.Tensor:
        """Assemble the token sequence: conditioning prefix followed by patch tokens.

        Each conditioning embedding is added to its group of learned tokens.
        ``omega`` is embedded as ``1 - 1 / omega``, so it must be non-zero.
        """
        x_embed = self.x_embedder(sample)
        h_embed = self.h_embedder(h)
        omega_embed = self.omega_embedder(1 - 1 / omega)
        t_min_embed = self.cfg_t_start_embedder(t_min)
        t_max_embed = self.cfg_t_end_embedder(t_max)
        y_embed = self.y_embedder(class_labels)

        time_tokens = self.time_tokens + h_embed.unsqueeze(1)
        omega_tokens = self.omega_tokens + omega_embed.unsqueeze(1)
        t_min_tokens = self.t_min_tokens + t_min_embed.unsqueeze(1)
        t_max_tokens = self.t_max_tokens + t_max_embed.unsqueeze(1)
        class_tokens = self.class_tokens + y_embed.unsqueeze(1)

        seq = torch.cat(
            [class_tokens, omega_tokens, t_min_tokens, t_max_tokens, time_tokens, x_embed],
            dim=1,
        )
        return seq + self.pos_embed

    def _unpatchify(self, tokens: torch.Tensor) -> torch.Tensor:
        """Fold ``(batch, grid * grid, patch * patch * channels)`` tokens into an image."""
        batch_size = tokens.shape[0]
        patch = self.patch_size
        grid = int(tokens.shape[1] ** 0.5)
        channels = self.out_channels
        x = tokens.reshape(batch_size, grid, grid, patch, patch, channels)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(batch_size, channels, grid * patch, grid * patch)

    def _run_blocks(self, blocks: nn.ModuleList, seq: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Apply ``blocks`` in order, checkpointing activations when enabled in training."""
        use_checkpoint = self.training and self.gradient_checkpointing
        for block in blocks:
            if use_checkpoint:
                seq = torch.utils.checkpoint.checkpoint(block, seq, rope_freqs, use_reentrant=False)
            else:
                seq = block(seq, rope_freqs)
        return seq

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        h: Optional[torch.Tensor] = None,
        omega: Optional[torch.Tensor] = None,
        guidance_interval_min: Optional[torch.Tensor] = None,
        guidance_interval_max: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> PMFTransformer2DOutput | Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Predict ``u`` (and ``v`` when the v head exists) for ``sample``.

        Args:
            sample: Input image batch.
            timestep: Scalar or per-sample value, used as the denominator of
                the output after clamping to ``t_clip_min``.
            class_labels: Integer labels for the label embedder.
            h: Value fed to the ``h`` embedder; defaults to ``timestep``.
            omega: Defaults to ones (presumably the guidance scale; confirm
                against the pipeline).
            guidance_interval_min: Defaults to zeros.
            guidance_interval_max: Defaults to ones.
            return_dict: Return a ``PMFTransformer2DOutput`` when True, else
                the tuple ``(u, v)``.

        Each head output is ``(sample - prediction) / clamp(timestep, min=t_clip_min)``.
        """
        batch_size = sample.shape[0]
        timestep = self._expand_batch(timestep, batch_size, sample.device, sample.dtype)
        h = self._expand_batch(h if h is not None else timestep, batch_size, sample.device, sample.dtype)
        omega = self._expand_batch(
            omega if omega is not None else torch.ones(batch_size, device=sample.device),
            batch_size,
            sample.device,
            sample.dtype,
        )
        guidance_interval_min = self._expand_batch(
            guidance_interval_min
            if guidance_interval_min is not None
            else torch.zeros(batch_size, device=sample.device),
            batch_size,
            sample.device,
            sample.dtype,
        )
        guidance_interval_max = self._expand_batch(
            guidance_interval_max
            if guidance_interval_max is not None
            else torch.ones(batch_size, device=sample.device),
            batch_size,
            sample.device,
            sample.dtype,
        )

        seq = self._build_sequence(sample, h, omega, guidance_interval_min, guidance_interval_max, class_labels)

        rope_freqs = self.rope_freqs
        if not torch.is_complex(rope_freqs):
            # ``Module.to(dtype=...)`` also casts complex buffers to the real
            # target dtype, which drops the imaginary (sine) part and would
            # silently corrupt the rotation. Rebuild the table in that case.
            rope_freqs = precompute_rope_freqs(
                self.hidden_size // self.num_attention_heads,
                self.x_embedder.num_patches,
            )
        rope_freqs = rope_freqs.to(device=sample.device)

        seq = self._run_blocks(self.shared_blocks, seq, rope_freqs)
        u_seq = self._run_blocks(self.u_heads, seq, rope_freqs)
        # ``v_heads`` is empty in eval mode, so ``v_seq`` is then just ``seq``.
        v_seq = self._run_blocks(self.v_heads, seq, rope_freqs)

        # Drop the conditioning prefix; only patch tokens are decoded.
        u_tokens = u_seq[:, self.prefix_tokens :]
        u_pred = self._unpatchify(self.u_final_layer(u_tokens))
        t = timestep.reshape(batch_size, 1, 1, 1)
        u = (sample - u_pred) / torch.clamp(t, min=self.t_clip_min)

        v = None
        if self.v_final_layer is not None:
            v_tokens = v_seq[:, self.prefix_tokens :]
            v_pred = self._unpatchify(self.v_final_layer(v_tokens))
            v = (sample - v_pred) / torch.clamp(t, min=self.t_clip_min)

        if not return_dict:
            return (u, v)
        return PMFTransformer2DOutput(u=u, v=v)

    @staticmethod
    def _expand_batch(
        value: torch.Tensor,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Convert ``value`` to a 1-D tensor of length ``batch_size``.

        Scalars and single-element inputs are broadcast across the batch; any
        other input must already contain exactly ``batch_size`` elements.
        """
        value = torch.as_tensor(value, device=device, dtype=dtype)
        if value.ndim == 0:
            value = value.reshape(1)
        if value.shape[0] == 1 and batch_size > 1:
            value = value.expand(batch_size)
        return value.reshape(batch_size)

    @classmethod
    def from_pmf_checkpoint(
        cls,
        checkpoint_path: str,
        model_type: str | None = None,
        map_location: str = "cpu",
        strict: bool = False,
    ) -> Tuple["PMFTransformer2DModel", Dict[str, object]]:
        """Build an eval-mode model from a legacy checkpoint file.

        Args:
            checkpoint_path: Path handed to ``torch.load``.
            model_type: Preset or legacy alias. When omitted it is read from
                the checkpoint keys ``model_type``, ``model_str`` or ``model``.
            map_location: Forwarded to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``(model, metadata)`` where ``metadata`` records the checkpoint
            path and the resolved preset name.

        Raises:
            ValueError: if no model type is given and none can be inferred.
            KeyError: if the resolved model type is not a known preset.
        """
        # SECURITY NOTE(review): ``weights_only=False`` unpickles arbitrary
        # Python objects and can execute code from the file. Only load
        # checkpoints from trusted sources; switch to ``weights_only=True`` if
        # the legacy checkpoints are confirmed to contain plain tensors.
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        if model_type is None:
            for key in ("model_type", "model_str", "model"):
                if isinstance(checkpoint, dict) and key in checkpoint:
                    model_type = checkpoint[key]
                    break
        if model_type in LEGACY_MODEL_ALIASES:
            model_type = LEGACY_MODEL_ALIASES[model_type]
        if model_type is None:
            raise ValueError("model_type is required when it cannot be inferred from the checkpoint.")

        config = dict(PMF_PRESET_CONFIGS[model_type])
        config["model_type"] = model_type
        config["eval_mode"] = True
        model = cls(**config)
        model.load_state_dict(remap_legacy_state_dict(state_dict), strict=strict)
        metadata = {"checkpoint_path": checkpoint_path, "model_type": model_type}
        return model, metadata

    def to_pmf_checkpoint(self, prefix: str = "net.") -> Dict[str, torch.Tensor]:
        """Return the state dict as detached CPU tensors with ``prefix`` on every key."""
        return {f"{prefix}{key}": value.detach().cpu() for key, value in self.state_dict().items()}

    @property
    def net(self):
        """Return the model itself, so ``model.net`` attribute access still works."""
        return self


PMFDiffusersModel = PMFTransformer2DModel
pMF-B-32/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ tags:
6
+ - diffusers
7
+ - pmf
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ ---
13
+
14
+ # pMF-B-32
15
+
16
+ Self-contained Diffusers variant for **pMF-B/32** (Pixel Mean Flows).
17
+
18
+ Recommended settings: `guidance_scale=6.5`, interval `[0.1, 0.7]`, `noise_scale=2.0`.
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ from pathlib import Path
24
+ from diffusers import DiffusionPipeline
25
+ import torch
26
+
27
+ model_dir = Path("./pMF-B-32")
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ str(model_dir),
30
+ local_files_only=True,
31
+ custom_pipeline=str(model_dir / "pipeline.py"),
32
+ trust_remote_code=True,
33
+ torch_dtype=torch.float32,
34
+ ).to("cuda")
35
+
36
+ image = pipe(
37
+ class_labels=207,
38
+ num_inference_steps=1,
39
+ guidance_scale=6.5,
40
+ guidance_interval_min=0.1,
41
+ guidance_interval_max=0.7,
42
+ noise_scale=2.0,
43
+ ).images[0]
44
+ ```
pMF-B-32/model_index.json ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PMFPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_pmf",
13
+ "PMFTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
pMF-B-32/pipeline.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Hub custom pipeline: PMFPipeline.
16
+
17
+ Load with native Hugging Face diffusers and trust_remote_code=True.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+
32
# Per-checkpoint classifier-free-guidance presets, keyed by the transformer's `model_type`.
# Columns of the table below: name, guidance_scale, guidance_interval_min, guidance_interval_max.
DEFAULT_CFG_BY_MODEL: Dict[str, Dict[str, float]] = {
    name: {
        "guidance_scale": scale,
        "guidance_interval_min": interval_min,
        "guidance_interval_max": interval_max,
    }
    for name, scale, interval_min, interval_max in (
        ("pMF-B/16", 7.5, 0.1, 0.8),
        ("pMF-B/32", 6.5, 0.1, 0.7),
        ("pMF-L/16", 7.0, 0.2, 0.7),
        ("pMF-L/32", 7.5, 0.2, 0.6),
        ("pMF-H/16", 7.0, 0.2, 0.6),
        ("pMF-H/32", 5.5, 0.1, 0.6),
    )
}

# Per-checkpoint scale applied to the initial Gaussian noise, keyed by `model_type`.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
    # B backbones
    "pMF-B/16": 1.0,
    "pMF-B/32": 2.0,
    # L backbones
    "pMF-L/16": 1.0,
    "pMF-L/32": 4.0,
    # H backbones
    "pMF-H/16": 2.0,
    "pMF-H/32": 4.0,
}
49
+
50
+
51
+ def _set_pmf_timesteps(
52
+ scheduler: FlowMatchEulerDiscreteScheduler,
53
+ num_inference_steps: int,
54
+ device: torch.device,
55
+ ) -> torch.Tensor:
56
+ r"""Set linear flow sigmas from 1.0 to 0.0 for pMF sampling."""
57
+ flow_sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=torch.float32)
58
+ scheduler.set_timesteps(sigmas=flow_sigmas.tolist(), device=device)
59
+ return flow_sigmas
60
+
61
+
62
class PMFPipeline(DiffusionPipeline):
    r"""
    Pipeline for ImageNet class-conditional generation with Pixel Mean Flows (pMF).

    Parameters:
        transformer ([`PMFTransformer2DModel`]):
            Class-conditioned pMF transformer that predicts mean-flow velocity.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Built-in flow-matching Euler scheduler. When `None`, a default one is created with
            `num_train_timesteps=1000`, `shift=1.0` and `stochastic_sampling=False`.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Keys may be ints or numeric strings; each value may
            hold several comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        if scheduler is None:
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=1.0,
                stochastic_sampling=False,
            )
        self.register_modules(transformer=transformer, scheduler=scheduler)
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # True once a non-empty mapping is available; otherwise it is looked up lazily on first use.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Lazily load `id2label` from `model_index.json` if no mapping was passed to `__init__`."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with every key coerced to `int` (empty dict for `None`/empty)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the `id2label` entry of `<variant_path>/model_index.json`.

        Returns an empty dict when the path is unset, the file does not exist, or the entry is not a dict.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        # Same key coercion as for a mapping passed to the constructor.
        return PMFPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert `id2label` into a name-sorted `{synonym: class_id}` dict.

        Each value is split on commas and every non-empty, stripped synonym becomes a key. When the same
        synonym occurs under several ids, the id encountered last wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """Class id to label mapping, loaded lazily from `model_index.json` when needed."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        """Map one label name or a list of label names to class ids.

        Raises:
            ValueError: If no labels are available, or if any requested name is unknown.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
        labels = [label] if isinstance(label, str) else label
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
        """Convert ids and/or label names into a flat list of integer class ids.

        A bare `int` or `str` is treated as a batch of one. Sequences may freely mix ids and names: names are
        resolved through `get_label_ids`, everything else is converted with `int()`.

        Raises:
            ValueError: Propagated from `get_label_ids` when a label name cannot be resolved.
        """
        items = [class_labels] if isinstance(class_labels, (int, str)) else list(class_labels)
        names = [item for item in items if isinstance(item, str)]
        if not names:
            return [int(item) for item in items]
        # Resolve all names in a single call so that every unknown label is reported together.
        resolved = iter(self.get_label_ids(names))
        return [next(resolved) if isinstance(item, str) else int(item) for item in items]

    def _recommended_noise_scale(self) -> float:
        """Return the noise-scale preset for the transformer's `model_type`.

        Falls back to a choice by `sample_size` (256 -> 1.0, 512 -> 2.0, anything else -> 1.0).
        """
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in RECOMMENDED_NOISE_BY_MODEL:
            return RECOMMENDED_NOISE_BY_MODEL[model_type]
        image_size = int(self.transformer.config.sample_size)
        return {256: 1.0, 512: 2.0}.get(image_size, 1.0)

    def _default_cfg(self) -> Dict[str, float]:
        """Return a copy of the guidance preset for the transformer's `model_type`, or a generic default."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in DEFAULT_CFG_BY_MODEL:
            return dict(DEFAULT_CFG_BY_MODEL[model_type])
        return {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8}

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        num_inference_steps: int = 1,
        guidance_scale: Optional[float] = None,
        guidance_interval_min: Optional[float] = None,
        guidance_interval_max: Optional[float] = None,
        noise_scale: Optional[float] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with pMF.

        Args:
            class_labels (`int`, `str`, or `list`):
                ImageNet class id(s) or label name(s). A list may mix ids and names.
            num_inference_steps (`int`, *optional*, defaults to 1):
                Number of flow steps. pMF is typically used with 1 step.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. Defaults to model-specific preset. Must not be 0.
            guidance_interval_min (`float`, *optional*):
                Lower bound of the CFG interval in normalized time.
            guidance_interval_max (`float`, *optional*):
                Upper bound of the CFG interval in normalized time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale. Defaults to model-specific preset.
            generator (`torch.Generator`, *optional*):
                Random generator for reproducibility.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: `"pil"`, `"np"`, or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`].

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                Generated images.

        Raises:
            ValueError: If `num_inference_steps < 1`, `output_type` is unsupported, `guidance_scale` is 0, or a
                label name cannot be resolved.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")

        defaults = self._default_cfg()
        if guidance_scale is None:
            guidance_scale = defaults["guidance_scale"]
        if guidance_interval_min is None:
            guidance_interval_min = defaults["guidance_interval_min"]
        if guidance_interval_max is None:
            guidance_interval_max = defaults["guidance_interval_max"]
        if noise_scale is None:
            noise_scale = self._recommended_noise_scale()
        # The transformer conditions on 1 - 1 / guidance_scale; a zero scale would turn into inf/NaN there.
        if guidance_scale == 0:
            raise ValueError("guidance_scale must be non-zero.")

        class_label_ids = self._normalize_class_labels(class_labels)
        batch_size = len(class_label_ids)
        image_size = int(self.transformer.config.sample_size)
        channels = int(self.transformer.config.in_channels)
        null_class_val = int(
            getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
        )

        # Sampling starts from scaled Gaussian noise with the transformer's image shape.
        latents = randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        ) * noise_scale

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # Out-of-range ids are clamped into [0, null_class_val - 1] rather than rejected.
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)

        device = latents.device
        dtype = latents.dtype
        # The guidance settings are passed to the transformer as per-sample tensors.
        omega = torch.full((batch_size,), guidance_scale, device=device, dtype=dtype)
        t_min = torch.full((batch_size,), guidance_interval_min, device=device, dtype=dtype)
        t_max = torch.full((batch_size,), guidance_interval_max, device=device, dtype=dtype)

        flow_sigmas = _set_pmf_timesteps(self.scheduler, num_inference_steps, device)

        for step_index in self.progress_bar(range(num_inference_steps)):
            t = flow_sigmas[step_index]
            t_next = flow_sigmas[step_index + 1]
            # h is the width of the current step on the sigma grid.
            h = (t - t_next).expand(batch_size).to(device=device, dtype=dtype)
            t_batch = t.expand(batch_size).to(device=device, dtype=dtype)

            output = self.transformer(
                sample=latents,
                timestep=t_batch,
                class_labels=class_labels_t,
                h=h,
                omega=omega,
                guidance_interval_min=t_min,
                guidance_interval_max=t_max,
                return_dict=True,
            )
            latents = self.scheduler.step(output.u, self.scheduler.timesteps[step_index], latents).prev_sample

        # Map samples from [-1, 1] to [0, 1] and move them to the CPU.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
283
+
284
+
285
# Alias: `PMFPipelineOutput` is the stock `ImagePipelineOutput` under a pMF-specific name.
PMFPipelineOutput = ImagePipelineOutput
pMF-B-32/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
pMF-B-32/transformer/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PMFTransformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "aux_head_depth": 8,
5
+ "bottleneck_dim": 128,
6
+ "depth": 16,
7
+ "embedding_init_constant": 1.0,
8
+ "eval_mode": true,
9
+ "hidden_size": 768,
10
+ "in_channels": 3,
11
+ "mlp_ratio": 2.6666666666666665,
12
+ "model_type": "pMF-B/32",
13
+ "norm_eps": 1e-06,
14
+ "num_attention_heads": 12,
15
+ "num_cfg_tokens": 4,
16
+ "num_class_embeds": null,
17
+ "num_class_tokens": 8,
18
+ "num_classes": 1000,
19
+ "num_interval_tokens": 2,
20
+ "num_time_tokens": 4,
21
+ "patch_size": 32,
22
+ "sample_size": 512,
23
+ "t_clip_min": 0.05,
24
+ "token_init_constant": 1.0,
25
+ "weight_init_constant": 0.32
26
+ }
pMF-B-32/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95799f21538c20de6ec71ac20c4881126a4a58d1dc399419f79c7732b6ccce13
3
+ size 481227280
pMF-B-32/transformer/transformer_pmf.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from math import sqrt
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.normalization import RMSNorm
14
+ from diffusers.utils import BaseOutput
15
+
16
+
17
# Architecture presets keyed by model name. Every backbone size (B / L / H) is combined with two
# patch layouts: patch 16 at 256-pixel samples and patch 32 at 512-pixel samples.
PMF_PRESET_CONFIGS: Dict[str, Dict[str, object]] = {
    f"pMF-{size}/{patch}": {
        "sample_size": resolution,
        "patch_size": patch,
        "hidden_size": width,
        "depth": layers,
        "num_attention_heads": heads,
        "bottleneck_dim": bottleneck,
        "aux_head_depth": 8,
    }
    for size, width, layers, heads, bottleneck in (
        ("B", 768, 16, 12, 128),
        ("L", 1024, 32, 16, 128),
        ("H", 1280, 48, 16, 256),
    )
    for patch, resolution in ((16, 256), (32, 512))
}

# Per-preset scale for the initial Gaussian noise.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
    # B backbones
    "pMF-B/16": 1.0,
    "pMF-B/32": 2.0,
    # L backbones
    "pMF-L/16": 1.0,
    "pMF-L/32": 4.0,
    # H backbones
    "pMF-H/16": 2.0,
    "pMF-H/32": 4.0,
}

# Legacy torch repo keys (pmfDiT_*) mapped to the preset names above.
LEGACY_MODEL_ALIASES: Dict[str, str] = {
    f"pmfDiT_{size}_{patch}": f"pMF-{size}/{patch}" for size in ("B", "L", "H") for patch in (16, 32)
}
92
+
93
+
94
@dataclass
class PMFTransformer2DOutput(BaseOutput):
    """Output container returned by the pMF transformer.

    NOTE(review): `v` presumably carries the auxiliary v-head prediction, which is only built when the model
    is created with `eval_mode=False` -- confirm against the model's forward pass.
    """

    # Primary prediction tensor (always present).
    u: torch.Tensor
    # Optional second prediction tensor; `None` when not produced.
    v: Optional[torch.Tensor] = None
98
+
99
+
100
def remap_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename the keys of a legacy checkpoint to the names used by the native PMFTransformer2DModel.

    For every key: at most one leading wrapper prefix (`transformer.` is tried before `net.`) is removed,
    the `._flax_linear` / `._flax_embedding` wrapper segments are dropped, and an entry whose resulting name
    is exactly `rope_freqs` is discarded. Values are passed through untouched.
    """
    wrapper_prefixes = ("transformer.", "net.")
    converted: Dict[str, torch.Tensor] = {}
    for legacy_key, tensor in state_dict.items():
        # Only the first matching prefix is stripped; a second nested prefix is left in place.
        matched = next((prefix for prefix in wrapper_prefixes if legacy_key.startswith(prefix)), None)
        name = legacy_key if matched is None else legacy_key[len(matched) :]
        # Official PyTorch checkpoints use TorchLinear/TorchEmbedding wrappers.
        for wrapper_segment in ("._flax_linear", "._flax_embedding"):
            name = name.replace(wrapper_segment, "")
        if name != "rope_freqs":
            converted[name] = tensor
    return converted
115
+
116
+
117
def config_from_legacy(config: Dict[str, object]) -> Dict[str, object]:
    """Translate a legacy `config.json` dict into keyword arguments for the native model.

    The preset name is taken from the first truthy value among `model_type`, `model_name` and `model_str`,
    and legacy `pmfDiT_*` names are mapped to their native equivalents. The returned dict is a copy of the
    matching preset plus `num_classes` and `model_type`; `sample_size` and `eval_mode` from `config`
    override or extend it when they are not `None`.

    Raises:
        ValueError: If the resolved name is not a known preset.
    """
    requested = config.get("model_type") or config.get("model_name") or config.get("model_str")
    model_type = LEGACY_MODEL_ALIASES.get(requested, requested)
    if model_type not in PMF_PRESET_CONFIGS:
        raise ValueError(f"Unknown pMF preset '{model_type}'. Known: {list(PMF_PRESET_CONFIGS)}")

    # `num_class_embeds` wins over `num_classes`; 1000 is used when neither is set (or both are falsy).
    class_count = config.get("num_class_embeds") or config.get("num_classes") or 1000
    native = {
        **PMF_PRESET_CONFIGS[model_type],
        "num_classes": int(class_count),
        "model_type": model_type,
    }
    for key, cast in (("sample_size", int), ("eval_mode", bool)):
        override = config.get(key)
        if override is not None:
            native[key] = cast(override)
    return native
133
+
134
+
135
+ def _scaled_linear(
136
+ in_features: int,
137
+ out_features: int,
138
+ *,
139
+ bias: bool = True,
140
+ weight_init: str = "scaled_variance",
141
+ init_constant: float = 1.0,
142
+ bias_init: str = "zeros",
143
+ ) -> nn.Linear:
144
+ layer = nn.Linear(in_features, out_features, bias=bias)
145
+ if weight_init == "scaled_variance":
146
+ std = init_constant / sqrt(in_features)
147
+ nn.init.normal_(layer.weight, std=std)
148
+ elif weight_init == "zeros":
149
+ nn.init.zeros_(layer.weight)
150
+ else:
151
+ raise ValueError(f"Invalid weight_init: {weight_init}")
152
+
153
+ if bias:
154
+ if bias_init == "zeros":
155
+ nn.init.zeros_(layer.bias)
156
+ else:
157
+ raise ValueError(f"Invalid bias_init: {bias_init}")
158
+ return layer
159
+
160
+
161
class PMFTimestepEmbedder(nn.Module):
    """Embed a batch of scalars with sinusoidal features followed by a Linear-SiLU-Linear MLP."""

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        init_constant: float = 1.0,
    ):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size

        def projection(in_features: int) -> nn.Linear:
            # Both MLP layers share the output width and the initialisation scheme.
            return _scaled_linear(
                in_features,
                out_features=hidden_size,
                bias=True,
                weight_init="scaled_variance",
                init_constant=init_constant,
                bias_init="zeros",
            )

        self.mlp = nn.Sequential(
            projection(frequency_embedding_size),
            nn.SiLU(),
            projection(hidden_size),
        )

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return sinusoidal features of width `dim` for the 1-D tensor `t`.

        The first `dim // 2` columns are cosines and the next `dim // 2` are sines, over frequencies
        `exp(-log(max_period) * i / (dim // 2))`. An odd `dim` gets one trailing zero column.
        """
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * exponents / half)
        angles = t[:, None].float() * freqs[None]
        features = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Pad to the requested odd width.
            features.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(features, dim=-1)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed `t` and project the features to `hidden_size`."""
        sinusoidal = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoidal)
200
+
201
+
202
class PMFLabelEmbedder(nn.Module):
    """Lookup table that maps integer class labels to `hidden_size`-wide vectors.

    The table has `num_classes + 1` rows and is initialised from a normal distribution with standard
    deviation `init_constant / sqrt(hidden_size)`.
    """

    def __init__(self, num_classes: int, hidden_size: int, init_constant: float = 1.0):
        super().__init__()
        table_rows = num_classes + 1  # one row more than there are classes
        table_std = init_constant / sqrt(hidden_size)
        self.embedding_table = nn.Embedding(table_rows, hidden_size)
        nn.init.normal_(self.embedding_table.weight, std=table_std)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return the embedding rows selected by `labels`."""
        embedded = self.embedding_table(labels)
        return embedded
210
+
211
+
212
class PMFBottleneckPatchEmbedder(nn.Module):
    """Turn an image into a sequence of patch tokens through a two-convolution bottleneck.

    `proj1` is a strided convolution (kernel and stride equal to `patch_size`) that maps each patch to
    `pca_channels` features; `proj2` is a 1x1 convolution that expands them to `hidden_size`.
    """

    def __init__(
        self,
        input_size: int,
        patch_size: int,
        pca_channels: int,
        in_channels: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        self.num_patches = (input_size // patch_size) ** 2
        self.proj1 = nn.Conv2d(in_channels, pca_channels, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.proj2 = nn.Conv2d(pca_channels, hidden_size, kernel_size=1, stride=1, bias=bias)

        # Uniform init in [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)); fan_out is taken as
        # the number of output channels for both layers.
        fan_table = (
            (self.proj1, patch_size * patch_size * in_channels, pca_channels),
            (self.proj2, pca_channels, hidden_size),
        )
        for conv, fan_in, fan_out in fan_table:
            bound = math.sqrt(6.0 / (fan_in + fan_out))
            nn.init.uniform_(conv.weight, -bound, bound)
        if bias:
            for conv in (self.proj1, self.proj2):
                nn.init.zeros_(conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` and flatten the spatial grid into a token axis: (B, C, H, W) -> (B, tokens, C')."""
        bottleneck = self.proj1(x)
        feature_map = self.proj2(bottleneck)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)
250
+
251
+
252
def precompute_rope_freqs(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Build complex rotary factors for a square grid of `seq_len` positions.

    Args:
        dim: Feature width to rotate; the result has `dim // 2` complex entries per position.
        seq_len: Number of grid positions; the grid side is `int(seq_len ** 0.5)`.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape `(seq_len, dim // 2)`. For each position the first half of the entries
        encodes its index along the first grid axis and the second half its index along the second axis.
    """
    rotary_dim = dim // 2
    side = int(seq_len**0.5)
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
    inv_freq = 1.0 / (theta**exponents)
    coords = torch.arange(side, dtype=torch.float32)
    # Both grid axes use the same position-times-frequency table.
    axis_angles = coords[:, None] * inv_freq[None, :]
    first_axis = axis_angles[:, None, :].expand(side, side, -1)
    second_axis = axis_angles[None, :, :].expand(side, side, -1)
    angles = torch.cat([first_axis, second_axis], dim=-1).reshape(seq_len, rotary_dim)
    return torch.complex(torch.cos(angles), torch.sin(angles))
269
+
270
+
271
def apply_rotary_pos_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the trailing tokens of `x` by the complex factors in `freqs_cis`.

    Consecutive pairs along the last axis of `x` are treated as (real, imaginary) parts. Only the last
    `freqs_cis.shape[0]` tokens along axis 1 are multiplied by `freqs_cis`; earlier tokens pass through
    unchanged. The arithmetic runs in float32 and the result is cast back to `x.dtype`.
    """
    as_pairs = x.to(torch.float32)
    as_pairs = as_pairs.reshape(*as_pairs.shape[:-1], -1, 2).contiguous()
    as_complex = torch.view_as_complex(as_pairs)

    # Add broadcast axes at position 0 and position 2 (batch and heads for a 4-D `x`).
    rotation = freqs_cis.unsqueeze(0).unsqueeze(2)
    tail = rotation.shape[1]

    untouched = as_complex[:, :-tail, :]
    rotated = as_complex[:, -tail:, :] * rotation
    combined = torch.cat([untouched, rotated], dim=1)
    return torch.view_as_real(combined).flatten(-2).to(x.dtype)
280
+
281
+
282
class PMFAttention(nn.Module):
    """Multi-head self-attention with per-head RMS-normalised queries/keys and rotary position factors.

    All four projections are bias-free linears initialised with the scaled-variance scheme.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        init_kwargs = dict(
            bias=False,
            weight_init="scaled_variance",
            init_constant=weight_init_constant,
        )
        self.q_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.k_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.v_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.out_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.q_norm = RMSNorm(self.head_dim, eps=eps)
        self.k_norm = RMSNorm(self.head_dim, eps=eps)

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Attend over the token axis of `x`.

        Args:
            x: Input of shape `(batch, tokens, channels)`.
            rope_freqs: Complex rotary factors forwarded to `apply_rotary_pos_emb` for queries and keys.

        Returns:
            A tensor with the same shape as `x`.
        """
        batch_size, seq_len, channels = x.shape
        head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        q = self.q_proj(x).reshape(head_shape)
        k = self.k_proj(x).reshape(head_shape)
        v = self.v_proj(x).reshape(head_shape)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q = apply_rotary_pos_emb(q, rope_freqs)
        k = apply_rotary_pos_emb(k, rope_freqs)

        query = q / math.sqrt(self.head_dim)
        attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, k)
        # Softmax runs in float32 for stability. The probabilities are cast back to the dtype of `v`
        # so that both einsum operands match when the model runs in half precision (no-op in float32).
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        attn = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        attn = attn.reshape(batch_size, seq_len, channels)
        return self.out_proj(attn)
322
+
323
+
324
class PMFSwiGLUMlp(nn.Module):
    """Gated feed-forward layer: `w2(silu(w1(x)) * w3(x))`, built from bias-free linears."""

    def __init__(self, dim: int, hidden_dim: int, weight_init_constant: float = 0.32):
        super().__init__()

        def make(in_features: int, out_features: int) -> nn.Linear:
            # All three projections share the same initialisation and have no bias.
            return _scaled_linear(
                in_features,
                out_features,
                bias=False,
                weight_init="scaled_variance",
                init_constant=weight_init_constant,
            )

        self.w1 = make(dim, hidden_dim)
        self.w3 = make(dim, hidden_dim)
        self.w2 = make(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to `x`; the output has the same trailing width as the input."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
334
+
335
+
336
class PMFTransformerBlock(nn.Module):
    """Pre-norm transformer block whose two residual branches are scaled by learned per-channel gains.

    Both gains (`attn_scale`, `mlp_scale`) start at zero, so a freshly constructed block returns its
    input unchanged.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 8 / 3,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        mlp_width = int(hidden_size * mlp_ratio)
        if hidden_size > 1024:
            # Round the MLP width up to the next multiple of 8.
            mlp_width += -mlp_width % 8

        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = PMFAttention(hidden_size, num_heads, weight_init_constant=weight_init_constant, eps=eps)
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = PMFSwiGLUMlp(hidden_size, mlp_width, weight_init_constant=weight_init_constant)
        self.attn_scale = nn.Parameter(torch.zeros(hidden_size))
        self.mlp_scale = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Apply the attention branch and then the MLP branch, each added back to `x` after scaling."""
        attn_branch = self.attn(self.norm1(x), rope_freqs)
        x = x + attn_branch * self.attn_scale
        mlp_branch = self.mlp(self.norm2(x))
        return x + mlp_branch * self.mlp_scale
360
+
361
+
362
class PMFFinalLayer(nn.Module):
    """Output head: RMS-normalise each token, then project it to `patch_size**2 * out_channels` values.

    Weight and bias of the projection are initialised to zero.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, eps: float = 1e-6):
        super().__init__()
        values_per_patch = patch_size * patch_size * out_channels
        self.norm = RMSNorm(hidden_size, eps=eps)
        self.linear = _scaled_linear(
            hidden_size,
            values_per_patch,
            bias=True,
            weight_init="zeros",
            bias_init="zeros",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-token patch values for `x`."""
        normalised = self.norm(x)
        return self.linear(normalised)
376
+
377
+
378
+ class PMFTransformer2DModel(ModelMixin, ConfigMixin):
379
+ """Native diffusers implementation of the pMF DiT backbone."""
380
+
381
+ _supports_gradient_checkpointing = True
382
+ _skip_layerwise_casting_patterns = ["pos_embed", "rope_freqs"]
383
+
384
+ @register_to_config
385
+ def __init__(
386
+ self,
387
+ sample_size: int = 256,
388
+ patch_size: int = 16,
389
+ in_channels: int = 3,
390
+ hidden_size: int = 768,
391
+ depth: int = 16,
392
+ num_attention_heads: int = 12,
393
+ mlp_ratio: float = 8 / 3,
394
+ num_classes: int = 1000,
395
+ bottleneck_dim: int = 128,
396
+ aux_head_depth: int = 8,
397
+ num_class_tokens: int = 8,
398
+ num_time_tokens: int = 4,
399
+ num_cfg_tokens: int = 4,
400
+ num_interval_tokens: int = 2,
401
+ token_init_constant: float = 1.0,
402
+ embedding_init_constant: float = 1.0,
403
+ weight_init_constant: float = 0.32,
404
+ eval_mode: bool = True,
405
+ model_type: str | None = None,
406
+ num_class_embeds: int | None = None,
407
+ t_clip_min: float = 0.05,
408
+ norm_eps: float = 1e-6,
409
+ ):
410
+ super().__init__()
411
+ if num_class_embeds is not None:
412
+ num_classes = int(num_class_embeds)
413
+ if model_type in LEGACY_MODEL_ALIASES:
414
+ model_type = LEGACY_MODEL_ALIASES[model_type]
415
+ if model_type in PMF_PRESET_CONFIGS:
416
+ preset = PMF_PRESET_CONFIGS[model_type]
417
+ sample_size = int(preset["sample_size"])
418
+ patch_size = int(preset["patch_size"])
419
+ hidden_size = int(preset["hidden_size"])
420
+ depth = int(preset["depth"])
421
+ num_attention_heads = int(preset["num_attention_heads"])
422
+ bottleneck_dim = int(preset["bottleneck_dim"])
423
+ aux_head_depth = int(preset["aux_head_depth"])
424
+
425
+ self.sample_size = sample_size
426
+ self.patch_size = patch_size
427
+ self.in_channels = in_channels
428
+ self.out_channels = in_channels
429
+ self.hidden_size = hidden_size
430
+ self.depth = depth
431
+ self.num_attention_heads = num_attention_heads
432
+ self.aux_head_depth = aux_head_depth
433
+ self.num_class_tokens = num_class_tokens
434
+ self.num_time_tokens = num_time_tokens
435
+ self.num_cfg_tokens = num_cfg_tokens
436
+ self.num_interval_tokens = num_interval_tokens
437
+ self.prefix_tokens = (
438
+ num_class_tokens + num_cfg_tokens + 2 * num_interval_tokens + num_time_tokens
439
+ )
440
+ self.t_clip_min = t_clip_min
441
+ self.eval_mode = eval_mode
442
+ self.gradient_checkpointing = False
443
+
444
+ self.x_embedder = PMFBottleneckPatchEmbedder(
445
+ sample_size,
446
+ patch_size,
447
+ bottleneck_dim,
448
+ in_channels,
449
+ hidden_size,
450
+ bias=True,
451
+ )
452
+ embed_kwargs = dict(hidden_size=hidden_size, init_constant=embedding_init_constant)
453
+ self.h_embedder = PMFTimestepEmbedder(**embed_kwargs)
454
+ self.omega_embedder = PMFTimestepEmbedder(**embed_kwargs)
455
+ self.cfg_t_start_embedder = PMFTimestepEmbedder(**embed_kwargs)
456
+ self.cfg_t_end_embedder = PMFTimestepEmbedder(**embed_kwargs)
457
+ self.y_embedder = PMFLabelEmbedder(num_classes, hidden_size, init_constant=embedding_init_constant)
458
+
459
+ token_std = token_init_constant / math.sqrt(hidden_size)
460
+ self.time_tokens = nn.Parameter(torch.randn(1, num_time_tokens, hidden_size) * token_std)
461
+ self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, hidden_size) * token_std)
462
+ self.omega_tokens = nn.Parameter(torch.randn(1, num_cfg_tokens, hidden_size) * token_std)
463
+ self.t_min_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
464
+ self.t_max_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
465
+
466
+ total_tokens = self.x_embedder.num_patches + self.prefix_tokens
467
+ self.pos_embed = nn.Parameter(torch.randn(1, total_tokens, hidden_size) * 0.02)
468
+
469
+ head_dim = hidden_size // num_attention_heads
470
+ self.register_buffer(
471
+ "rope_freqs",
472
+ precompute_rope_freqs(head_dim, self.x_embedder.num_patches),
473
+ persistent=False,
474
+ )
475
+
476
+ shared_depth = depth - aux_head_depth
477
+ block_kwargs = dict(
478
+ hidden_size=hidden_size,
479
+ num_heads=num_attention_heads,
480
+ mlp_ratio=mlp_ratio,
481
+ weight_init_constant=weight_init_constant,
482
+ eps=norm_eps,
483
+ )
484
+ self.shared_blocks = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(shared_depth)])
485
+ self.u_heads = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth)])
486
+ self.v_heads = nn.ModuleList(
487
+ [PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth if not eval_mode else 0)]
488
+ )
489
+ self.u_final_layer = PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
490
+ self.v_final_layer = (
491
+ PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
492
+ if not eval_mode
493
+ else None
494
+ )
495
+
496
+ def _build_sequence(
497
+ self,
498
+ sample: torch.Tensor,
499
+ h: torch.Tensor,
500
+ omega: torch.Tensor,
501
+ t_min: torch.Tensor,
502
+ t_max: torch.Tensor,
503
+ class_labels: torch.Tensor,
504
+ ) -> torch.Tensor:
505
+ x_embed = self.x_embedder(sample)
506
+ h_embed = self.h_embedder(h)
507
+ omega_embed = self.omega_embedder(1 - 1 / omega)
508
+ t_min_embed = self.cfg_t_start_embedder(t_min)
509
+ t_max_embed = self.cfg_t_end_embedder(t_max)
510
+ y_embed = self.y_embedder(class_labels)
511
+
512
+ time_tokens = self.time_tokens + h_embed.unsqueeze(1)
513
+ omega_tokens = self.omega_tokens + omega_embed.unsqueeze(1)
514
+ t_min_tokens = self.t_min_tokens + t_min_embed.unsqueeze(1)
515
+ t_max_tokens = self.t_max_tokens + t_max_embed.unsqueeze(1)
516
+ class_tokens = self.class_tokens + y_embed.unsqueeze(1)
517
+
518
+ seq = torch.cat(
519
+ [class_tokens, omega_tokens, t_min_tokens, t_max_tokens, time_tokens, x_embed],
520
+ dim=1,
521
+ )
522
+ return seq + self.pos_embed
523
+
524
+ def _unpatchify(self, tokens: torch.Tensor) -> torch.Tensor:
525
+ batch_size = tokens.shape[0]
526
+ patch = self.patch_size
527
+ grid = int(tokens.shape[1] ** 0.5)
528
+ channels = self.out_channels
529
+ x = tokens.reshape(batch_size, grid, grid, patch, patch, channels)
530
+ x = torch.einsum("nhwpqc->nchpwq", x)
531
+ return x.reshape(batch_size, channels, grid * patch, grid * patch)
532
+
533
def forward(
    self,
    sample: torch.Tensor,
    timestep: torch.Tensor,
    class_labels: torch.Tensor,
    h: Optional[torch.Tensor] = None,
    omega: Optional[torch.Tensor] = None,
    guidance_interval_min: Optional[torch.Tensor] = None,
    guidance_interval_max: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> PMFTransformer2DOutput | Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run the shared trunk and the ``u`` / ``v`` heads on ``sample``.

    Args:
        sample: Input batch; its first dimension is the batch size and its
            device/dtype are used for every conditioning value.
        timestep: Per-sample (or broadcastable) time value used as the divisor
            of the head outputs.
        class_labels: Labels forwarded to the label embedder.
        h: Conditioning value for the time tokens; defaults to ``timestep``.
        omega: Guidance value; defaults to ones.
        guidance_interval_min: Defaults to zeros.
        guidance_interval_max: Defaults to ones.
        return_dict: When ``False`` a plain ``(u, v)`` tuple is returned.

    Returns:
        ``PMFTransformer2DOutput(u=u, v=v)`` or ``(u, v)``. Each output is
        ``(sample - head_prediction) / clamp(t, min=self.t_clip_min)``.
        ``v`` is ``None`` when ``self.v_final_layer`` is ``None``.
    """
    batch_size = sample.shape[0]
    device, dtype = sample.device, sample.dtype

    def per_sample(value):
        # Bring a conditioning value to shape (batch_size,) on sample's device/dtype.
        return self._expand_batch(value, batch_size, device, dtype)

    timestep = per_sample(timestep)
    h = per_sample(timestep if h is None else h)
    if omega is None:
        omega = torch.ones(batch_size, device=device)
    omega = per_sample(omega)
    if guidance_interval_min is None:
        guidance_interval_min = torch.zeros(batch_size, device=device)
    guidance_interval_min = per_sample(guidance_interval_min)
    if guidance_interval_max is None:
        guidance_interval_max = torch.ones(batch_size, device=device)
    guidance_interval_max = per_sample(guidance_interval_max)

    seq = self._build_sequence(sample, h, omega, guidance_interval_min, guidance_interval_max, class_labels)
    rope_freqs = self.rope_freqs.to(device=sample.device)

    # Activation checkpointing is only used while training with the flag enabled.
    use_checkpoint = self.training and self.gradient_checkpointing

    def run_blocks(blocks, hidden):
        for layer in blocks:
            if use_checkpoint:
                hidden = torch.utils.checkpoint.checkpoint(layer, hidden, rope_freqs, use_reentrant=False)
            else:
                hidden = layer(hidden, rope_freqs)
        return hidden

    trunk = run_blocks(self.shared_blocks, seq)
    # Both heads branch off the same trunk output.
    u_seq = run_blocks(self.u_heads, trunk)
    v_seq = run_blocks(self.v_heads, trunk)

    t = timestep.reshape(batch_size, 1, 1, 1)
    # Clamp so a very small t cannot blow up the division below.
    divisor = torch.clamp(t, min=self.t_clip_min)

    def head_output(final_layer, hidden):
        # Drop the conditioning prefix, project the patch tokens and fold them into an image.
        prediction = self._unpatchify(final_layer(hidden[:, self.prefix_tokens :]))
        return (sample - prediction) / divisor

    u = head_output(self.u_final_layer, u_seq)
    v = None if self.v_final_layer is None else head_output(self.v_final_layer, v_seq)

    if return_dict:
        return PMFTransformer2DOutput(u=u, v=v)
    return (u, v)
606
+
607
+ @staticmethod
608
+ def _expand_batch(
609
+ value: torch.Tensor,
610
+ batch_size: int,
611
+ device: torch.device,
612
+ dtype: torch.dtype,
613
+ ) -> torch.Tensor:
614
+ value = torch.as_tensor(value, device=device, dtype=dtype)
615
+ if value.ndim == 0:
616
+ value = value.reshape(1)
617
+ if value.shape[0] == 1 and batch_size > 1:
618
+ value = value.expand(batch_size)
619
+ return value.reshape(batch_size)
620
+
621
@classmethod
def from_pmf_checkpoint(
    cls,
    checkpoint_path: str,
    model_type: str | None = None,
    map_location: str = "cpu",
    strict: bool = False,
) -> Tuple["PMFTransformer2DModel", Dict[str, object]]:
    """Build an eval-mode model from a pMF checkpoint file.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        model_type: Key into ``PMF_PRESET_CONFIGS``. When ``None`` it is taken
            from the first of the checkpoint entries ``"model_type"``,
            ``"model_str"``, ``"model"`` that holds a string. Names found in
            ``LEGACY_MODEL_ALIASES`` are translated.
        map_location: Forwarded to ``torch.load``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        ``(model, metadata)``. ``metadata`` contains ``"checkpoint_path"``,
        ``"model_type"`` and the ``"missing_keys"`` / ``"unexpected_keys"``
        reported by ``load_state_dict``, so a non-strict load that skipped
        weights can be detected by the caller.

    Raises:
        ValueError: If ``model_type`` is not given and cannot be inferred.
        KeyError: If ``model_type`` is not a key of ``PMF_PRESET_CONFIGS``.
    """
    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints from a
    # trusted source. Kept as-is because switching to weights_only=True could
    # reject existing checkpoints that store non-tensor objects.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
    is_mapping = isinstance(checkpoint, dict)
    if is_mapping and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    if model_type is None and is_mapping:
        for key in ("model_type", "model_str", "model"):
            candidate = checkpoint.get(key)
            # Skip non-string entries (e.g. a nested weight dict stored under
            # "model"); they cannot name a preset and may not be hashable.
            if isinstance(candidate, str):
                model_type = candidate
                break
    if model_type in LEGACY_MODEL_ALIASES:
        model_type = LEGACY_MODEL_ALIASES[model_type]
    if model_type is None:
        raise ValueError("model_type is required when it cannot be inferred from the checkpoint.")

    # Copy the preset so the shared table is not mutated.
    config = dict(PMF_PRESET_CONFIGS[model_type])
    config["model_type"] = model_type
    config["eval_mode"] = True
    model = cls(**config)

    load_result = model.load_state_dict(remap_legacy_state_dict(state_dict), strict=strict)
    metadata = {
        "checkpoint_path": checkpoint_path,
        "model_type": model_type,
        "missing_keys": list(getattr(load_result, "missing_keys", [])),
        "unexpected_keys": list(getattr(load_result, "unexpected_keys", [])),
    }
    return model, metadata
652
+
653
def to_pmf_checkpoint(self, prefix: str = "net.") -> Dict[str, torch.Tensor]:
    """Export the model weights as a flat dict with ``prefix`` prepended to every key.

    Each tensor from ``self.state_dict()`` is detached and moved to the CPU.
    """
    return {
        f"{prefix}{name}": tensor.detach().cpu()
        for name, tensor in self.state_dict().items()
    }
658
+
659
+ @property
660
+ def net(self):  # Read-only property that evaluates to this same object.
661
+ return self  # NOTE(review): presumably mirrors the "net." key prefix used by to_pmf_checkpoint - confirm.
662
+
663
+
664
+ PMFDiffusersModel = PMFTransformer2DModel  # Second module-level name bound to the same class object.
pMF-H-16/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ tags:
6
+ - diffusers
7
+ - pmf
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ ---
13
+
14
+ # pMF-H-16
15
+
16
+ Self-contained Diffusers variant for **pMF-H/16** (Pixel Mean Flows).
17
+
18
+ Recommended settings: `guidance_scale=7.0`, guidance interval `[0.2, 0.6]` (`guidance_interval_min=0.2`, `guidance_interval_max=0.6`), `noise_scale=2.0`.
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ from pathlib import Path
24
+ from diffusers import DiffusionPipeline
25
+ import torch
26
+
27
+ model_dir = Path("./pMF-H-16")
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ str(model_dir),
30
+ local_files_only=True,
31
+ custom_pipeline=str(model_dir / "pipeline.py"),
32
+ trust_remote_code=True,
33
+ torch_dtype=torch.float32,
34
+ ).to("cuda")
35
+
36
+ image = pipe(
37
+ class_labels=207,
38
+ num_inference_steps=1,
39
+ guidance_scale=7.0,
40
+ guidance_interval_min=0.2,
41
+ guidance_interval_max=0.6,
42
+ noise_scale=2.0,
43
+ ).images[0]
44
+ ```
pMF-H-16/demo.png ADDED

Git LFS Details

  • SHA256: 340c0bec765db16b20c22e583ec2f4d9e72f33056298703ff442cc91d4b00365
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
pMF-H-16/model_index.json ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PMFPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_pmf",
13
+ "PMFTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
pMF-H-16/pipeline.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Hub custom pipeline: PMFPipeline.
16
+
17
+ Load with native Hugging Face diffusers and trust_remote_code=True.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+
32
# Classifier-free guidance defaults, keyed by the transformer config's `model_type`.
# `PMFPipeline._default_cfg` copies the matching entry and `PMFPipeline.__call__` uses it
# for any of `guidance_scale`, `guidance_interval_min`, `guidance_interval_max` that the
# caller leaves as `None`.
DEFAULT_CFG_BY_MODEL: Dict[str, Dict[str, float]] = {
    "pMF-B/16": {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8},
    "pMF-B/32": {"guidance_scale": 6.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.7},
    "pMF-L/16": {"guidance_scale": 7.0, "guidance_interval_min": 0.2, "guidance_interval_max": 0.7},
    "pMF-L/32": {"guidance_scale": 7.5, "guidance_interval_min": 0.2, "guidance_interval_max": 0.6},
    "pMF-H/16": {"guidance_scale": 7.0, "guidance_interval_min": 0.2, "guidance_interval_max": 0.6},
    "pMF-H/32": {"guidance_scale": 5.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.6},
}

# Initial noise scale, keyed by the transformer config's `model_type`.
# `PMFPipeline._recommended_noise_scale` returns the matching entry, and `__call__`
# multiplies the initial Gaussian latents by it when `noise_scale` is `None`.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
    "pMF-B/16": 1.0,
    "pMF-B/32": 2.0,
    "pMF-L/16": 1.0,
    "pMF-L/32": 4.0,
    "pMF-H/16": 2.0,
    "pMF-H/32": 4.0,
}
49
+
50
+
51
def _set_pmf_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler,
    num_inference_steps: int,
    device: torch.device,
) -> torch.Tensor:
    r"""
    Configure `scheduler` with a linear sigma grid running from 1.0 down to 0.0.

    Args:
        scheduler: Scheduler whose `set_timesteps(sigmas=..., device=...)` is called.
        num_inference_steps: Number of sampling steps; the grid has one more point than this.
        device: Device on which the grid tensor is created and which is forwarded to the scheduler.

    Returns:
        The float32 grid tensor of length `num_inference_steps + 1`.
    """
    # One extra point so that every step has both a start and an end sigma.
    num_grid_points = num_inference_steps + 1
    sigma_grid = torch.linspace(
        start=1.0,
        end=0.0,
        steps=num_grid_points,
        dtype=torch.float32,
        device=device,
    )
    # The scheduler receives the same values as a plain Python list.
    scheduler.set_timesteps(sigmas=sigma_grid.tolist(), device=device)
    return sigma_grid
60
+
61
+
62
class PMFPipeline(DiffusionPipeline):
    r"""
    Pipeline for ImageNet class-conditional generation with Pixel Mean Flows (pMF).

    Parameters:
        transformer ([`PMFTransformer2DModel`]):
            Class-conditioned pMF transformer that predicts mean-flow velocity.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Built-in flow-matching Euler scheduler. When `None`, one is created with
            `num_train_timesteps=1000`, `shift=1.0` and `stochastic_sampling=False`.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Keys may be ints or numeric strings and
            each value is a comma-separated list of synonyms. When omitted, the mapping is read
            lazily from `model_index.json` in the directory named by `self.config._name_or_path`.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        if scheduler is None:
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=1.0,
                stochastic_sampling=False,
            )
        self.register_modules(transformer=transformer, scheduler=scheduler)
        self._id2label = self._normalize_id2label(id2label)
        # `labels` is the reverse map: one entry per synonym -> class id.
        self.labels = self._build_label2id(self._id2label)
        # Stays False when no mapping was passed in, so the first label lookup tries the disk.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        r"""Load `id2label` from `model_index.json` once if none was given at construction."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        r"""Return a copy of `id2label` with every key converted to `int` (`{}` for empty/None)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        r"""
        Read the `id2label` table from `<variant_path>/model_index.json`.

        Returns `{}` when `variant_path` is empty, the file does not exist, or the file has no
        dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        # Same key normalisation as the constructor path, kept in one place.
        return PMFPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        r"""
        Build a name-sorted synonym -> class id map from comma-separated label strings.

        The first synonym of each entry is treated as that class's primary name. A primary name
        always wins over the same word appearing as a secondary synonym of another class, so
        `"missile"` resolves to the class labelled `"missile"` rather than to
        `"projectile, missile"`. Among names of equal rank, the entry seen last wins.
        """
        primary: Dict[str, int] = {}
        secondary: Dict[str, int] = {}
        for class_id, value in id2label.items():
            names = [name for name in (part.strip() for part in value.split(",")) if name]
            for position, name in enumerate(names):
                bucket = primary if position == 0 else secondary
                bucket[name] = int(class_id)
        # Primary names are merged last so they override colliding secondary synonyms.
        return dict(sorted({**secondary, **primary}.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""Class id -> label string mapping, loaded from `model_index.json` on first use if needed."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map label name(s) to class ids.

        Args:
            label (`str` or `list[str]`):
                One label name or a list of them; each must be a key of `self.labels`.

        Returns:
            `list[int]`: Class ids in the same order as the input.

        Raises:
            ValueError: If no labels are available or any name is unknown.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
        labels = [label] if isinstance(label, str) else label
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
        r"""
        Convert ids and/or label names into a flat list of integer class ids.

        Every element of a sequence is examined, so ids and names may be mixed freely
        (e.g. `[1, "daisy"]`). Names are resolved through `get_label_ids` in a single call so
        that all unknown names are reported together.
        """
        if isinstance(class_labels, int):
            return [class_labels]
        if isinstance(class_labels, str):
            return self.get_label_ids(class_labels)
        items = list(class_labels)
        names = [item for item in items if isinstance(item, str)]
        # Only touch the label table when at least one name actually needs resolving.
        resolved = iter(self.get_label_ids(names)) if names else iter(())
        return [next(resolved) if isinstance(item, str) else int(item) for item in items]

    def _recommended_noise_scale(self) -> float:
        r"""Return the preset noise scale for this model, falling back to a size-based default."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in RECOMMENDED_NOISE_BY_MODEL:
            return RECOMMENDED_NOISE_BY_MODEL[model_type]
        image_size = int(self.transformer.config.sample_size)
        return {256: 1.0, 512: 2.0}.get(image_size, 1.0)

    def _default_cfg(self) -> Dict[str, float]:
        r"""Return a copy of the preset guidance settings for this model, or generic defaults."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in DEFAULT_CFG_BY_MODEL:
            return dict(DEFAULT_CFG_BY_MODEL[model_type])
        return {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8}

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        num_inference_steps: int = 1,
        guidance_scale: Optional[float] = None,
        guidance_interval_min: Optional[float] = None,
        guidance_interval_max: Optional[float] = None,
        noise_scale: Optional[float] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with pMF.

        Args:
            class_labels (`int`, `str`, or `list`):
                ImageNet class id(s) or label name(s). A list may mix ids and names. Ids outside
                the valid range are clamped into it rather than rejected.
            num_inference_steps (`int`, *optional*, defaults to 1):
                Number of flow steps. pMF is typically used with 1 step.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. Defaults to model-specific preset.
            guidance_interval_min (`float`, *optional*):
                Lower bound of the CFG interval in normalized time.
            guidance_interval_max (`float`, *optional*):
                Upper bound of the CFG interval in normalized time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale. Defaults to model-specific preset.
            generator (`torch.Generator`, *optional*):
                Random generator for reproducibility.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: `"pil"`, `"np"`, or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`].

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                Generated images, with values in `[0, 1]` for `"np"` and `"pt"` outputs.

        Raises:
            ValueError: If `num_inference_steps < 1`, `output_type` is unsupported, or a label
                name cannot be resolved.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")

        # Fill in any sampling knob the caller left unset from the per-model presets.
        defaults = self._default_cfg()
        if guidance_scale is None:
            guidance_scale = defaults["guidance_scale"]
        if guidance_interval_min is None:
            guidance_interval_min = defaults["guidance_interval_min"]
        if guidance_interval_max is None:
            guidance_interval_max = defaults["guidance_interval_max"]
        if noise_scale is None:
            noise_scale = self._recommended_noise_scale()

        class_label_ids = self._normalize_class_labels(class_labels)
        batch_size = len(class_label_ids)
        transformer_config = self.transformer.config
        image_size = int(transformer_config.sample_size)
        channels = int(transformer_config.in_channels)

        # A config may carry `num_class_embeds: null` (or omit `num_classes`); skip None values
        # instead of passing them to `int()`.
        num_classes = getattr(transformer_config, "num_classes", None)
        if num_classes is None:
            num_classes = getattr(transformer_config, "num_class_embeds", None)
        num_classes = 1000 if num_classes is None else int(num_classes)

        latents = randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        ) * noise_scale

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # Out-of-range ids are clamped into [0, num_classes - 1].
        class_labels_t = class_labels_t.clamp(0, num_classes - 1)

        device = latents.device
        dtype = latents.dtype
        # Per-sample conditioning tensors handed to the transformer alongside the latents.
        omega = torch.full((batch_size,), guidance_scale, device=device, dtype=dtype)
        t_min = torch.full((batch_size,), guidance_interval_min, device=device, dtype=dtype)
        t_max = torch.full((batch_size,), guidance_interval_max, device=device, dtype=dtype)

        # Configures the scheduler and returns the same sigma grid for computing t and h below.
        flow_sigmas = _set_pmf_timesteps(self.scheduler, num_inference_steps, device)

        for step_index in self.progress_bar(range(num_inference_steps)):
            t = flow_sigmas[step_index]
            t_next = flow_sigmas[step_index + 1]
            # `h` is the width of the current step on the sigma grid.
            h = (t - t_next).expand(batch_size).to(device=device, dtype=dtype)
            t_batch = t.expand(batch_size).to(device=device, dtype=dtype)

            output = self.transformer(
                sample=latents,
                timestep=t_batch,
                class_labels=class_labels_t,
                h=h,
                omega=omega,
                guidance_interval_min=t_min,
                guidance_interval_max=t_max,
                return_dict=True,
            )
            latents = self.scheduler.step(output.u, self.scheduler.timesteps[step_index], latents).prev_sample

        # Map from [-1, 1] to [0, 1] and move to CPU before converting to the requested format.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
283
+
284
+
285
# `PMFPipelineOutput` is a second name for `ImagePipelineOutput`, the type `PMFPipeline.__call__`
# returns; no new class is defined.
# NOTE(review): presumably exported so callers can import a pipeline-specific name — confirm.
PMFPipelineOutput = ImagePipelineOutput
pMF-H-16/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
pMF-H-16/transformer/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PMFTransformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "aux_head_depth": 8,
5
+ "bottleneck_dim": 256,
6
+ "depth": 48,
7
+ "embedding_init_constant": 1.0,
8
+ "eval_mode": true,
9
+ "hidden_size": 1280,
10
+ "in_channels": 3,
11
+ "mlp_ratio": 2.6666666666666665,
12
+ "model_type": "pMF-H/16",
13
+ "norm_eps": 1e-06,
14
+ "num_attention_heads": 16,
15
+ "num_cfg_tokens": 4,
16
+ "num_class_embeds": null,
17
+ "num_class_tokens": 8,
18
+ "num_classes": 1000,
19
+ "num_interval_tokens": 2,
20
+ "num_time_tokens": 4,
21
+ "patch_size": 16,
22
+ "sample_size": 256,
23
+ "t_clip_min": 0.05,
24
+ "token_init_constant": 1.0,
25
+ "weight_init_constant": 0.32
26
+ }
pMF-H-16/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58fe8d472810713e6bf301a432d8718fd3bebfa40f8922507a1b87cbfbb362a7
3
+ size 3822104552
pMF-H-16/transformer/transformer_pmf.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from math import sqrt
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.normalization import RMSNorm
14
+ from diffusers.utils import BaseOutput
15
+
16
+
17
# Architecture hyperparameters for each named pMF preset.
# The "/16" presets pair `sample_size` 256 with `patch_size` 16; the "/32" presets pair
# `sample_size` 512 with `patch_size` 32. B/L/H differ in `hidden_size`, `depth`,
# `num_attention_heads` and `bottleneck_dim`; `aux_head_depth` is 8 for all of them.
PMF_PRESET_CONFIGS: Dict[str, Dict[str, object]] = {
    "pMF-B/16": {
        "sample_size": 256,
        "patch_size": 16,
        "hidden_size": 768,
        "depth": 16,
        "num_attention_heads": 12,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-B/32": {
        "sample_size": 512,
        "patch_size": 32,
        "hidden_size": 768,
        "depth": 16,
        "num_attention_heads": 12,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-L/16": {
        "sample_size": 256,
        "patch_size": 16,
        "hidden_size": 1024,
        "depth": 32,
        "num_attention_heads": 16,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-L/32": {
        "sample_size": 512,
        "patch_size": 32,
        "hidden_size": 1024,
        "depth": 32,
        "num_attention_heads": 16,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-H/16": {
        "sample_size": 256,
        "patch_size": 16,
        "hidden_size": 1280,
        "depth": 48,
        "num_attention_heads": 16,
        "bottleneck_dim": 256,
        "aux_head_depth": 8,
    },
    "pMF-H/32": {
        "sample_size": 512,
        "patch_size": 32,
        "hidden_size": 1280,
        "depth": 48,
        "num_attention_heads": 16,
        "bottleneck_dim": 256,
        "aux_head_depth": 8,
    },
}

# Noise scale per preset name, using the same keys as `PMF_PRESET_CONFIGS`.
# NOTE(review): where this module consumes the table is outside the section reviewed here;
# the sampling pipeline keeps its own table of the same name — confirm the two stay in sync.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
    "pMF-B/16": 1.0,
    "pMF-B/32": 2.0,
    "pMF-L/16": 1.0,
    "pMF-L/32": 4.0,
    "pMF-H/16": 2.0,
    "pMF-H/32": 4.0,
}

# Legacy torch repo keys (pmfDiT_*), each mapped to the preset name used as a key in the
# tables above.
LEGACY_MODEL_ALIASES: Dict[str, str] = {
    "pmfDiT_B_16": "pMF-B/16",
    "pmfDiT_B_32": "pMF-B/32",
    "pmfDiT_L_16": "pMF-L/16",
    "pmfDiT_L_32": "pMF-L/32",
    "pmfDiT_H_16": "pMF-H/16",
    "pmfDiT_H_32": "pMF-H/32",
}
92
+
93
+
94
@dataclass
class PMFTransformer2DOutput(BaseOutput):
    """Container for the tensors produced by ``PMFTransformer2DModel.forward``."""

    # Output of the ``u`` head; always present.
    u: torch.Tensor
    # Output of the ``v`` head; ``None`` when the model was built without one.
    v: Optional[torch.Tensor] = None
98
+
99
+
100
def remap_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Translate legacy checkpoint keys into native ``PMFTransformer2DModel`` keys.

    For every key, at most one leading wrapper prefix (``"transformer."`` is
    tried before ``"net."``) is removed, then every ``"._flax_linear"`` /
    ``"._flax_embedding"`` fragment is dropped.  An entry whose resulting key
    is exactly ``"rope_freqs"`` is omitted from the result.

    Args:
        state_dict: Mapping of legacy parameter names to tensors.

    Returns:
        A new dict with renamed keys; the values are passed through untouched.
    """
    wrapper_prefixes = ("transformer.", "net.")
    # Official PyTorch checkpoints use TorchLinear/TorchEmbedding wrappers.
    wrapper_fragments = ("._flax_linear", "._flax_embedding")

    def _native_name(legacy_name: str) -> str:
        # Only the first matching prefix is stripped, never both.
        name = next(
            (legacy_name[len(p):] for p in wrapper_prefixes if legacy_name.startswith(p)),
            legacy_name,
        )
        for fragment in wrapper_fragments:
            name = name.replace(fragment, "")
        return name

    renamed = ((_native_name(name), tensor) for name, tensor in state_dict.items())
    return {name: tensor for name, tensor in renamed if name != "rope_freqs"}
115
+
116
+
117
def config_from_legacy(config: Dict[str, object]) -> Dict[str, object]:
    """Turn a legacy ``config.json`` dict into native constructor kwargs.

    The preset name is the first truthy value among ``model_type``,
    ``model_name`` and ``model_str``; legacy ``pmfDiT_*`` names are mapped to
    their ``pMF-*`` preset.  The result is a copy of that preset extended with
    ``num_classes`` (``num_class_embeds``, else ``num_classes``, else 1000) and
    ``model_type``, plus ``sample_size`` / ``eval_mode`` when the legacy config
    supplies a non-``None`` value for them.

    Raises:
        ValueError: If the resolved name is not a key of ``PMF_PRESET_CONFIGS``.
    """
    preset_name = config.get("model_type") or config.get("model_name") or config.get("model_str")
    preset_name = LEGACY_MODEL_ALIASES.get(preset_name, preset_name)
    if preset_name not in PMF_PRESET_CONFIGS:
        raise ValueError(f"Unknown pMF preset '{preset_name}'. Known: {list(PMF_PRESET_CONFIGS)}")

    class_count = config.get("num_class_embeds") or config.get("num_classes") or 1000
    native_kwargs = {
        **PMF_PRESET_CONFIGS[preset_name],
        "num_classes": int(class_count),
        "model_type": preset_name,
    }

    # Optional overrides; each is coerced to the type the constructor expects.
    for field, coerce in (("sample_size", int), ("eval_mode", bool)):
        raw_value = config.get(field)
        if raw_value is not None:
            native_kwargs[field] = coerce(raw_value)
    return native_kwargs
133
+
134
+
135
def _scaled_linear(
    in_features: int,
    out_features: int,
    *,
    bias: bool = True,
    weight_init: str = "scaled_variance",
    init_constant: float = 1.0,
    bias_init: str = "zeros",
) -> nn.Linear:
    """Build an ``nn.Linear`` and re-initialise it in place.

    Args:
        in_features: Input width of the layer.
        out_features: Output width of the layer.
        bias: Whether the layer has a bias term.
        weight_init: ``"scaled_variance"`` draws weights from a normal
            distribution with ``std = init_constant / sqrt(in_features)``;
            ``"zeros"`` fills them with zeros.
        init_constant: Numerator of the standard deviation above.
        bias_init: Only ``"zeros"`` is accepted; checked only when ``bias`` is true.

    Returns:
        The initialised layer.

    Raises:
        ValueError: On an unrecognised ``weight_init`` or ``bias_init``.
    """
    linear = nn.Linear(in_features, out_features, bias=bias)

    if weight_init not in ("scaled_variance", "zeros"):
        raise ValueError(f"Invalid weight_init: {weight_init}")
    if weight_init == "zeros":
        nn.init.zeros_(linear.weight)
    else:
        nn.init.normal_(linear.weight, std=init_constant / sqrt(in_features))

    if not bias:
        return linear
    if bias_init != "zeros":
        raise ValueError(f"Invalid bias_init: {bias_init}")
    nn.init.zeros_(linear.bias)
    return linear
159
+
160
+
161
class PMFTimestepEmbedder(nn.Module):
    """Embed a batch of scalars with sinusoidal features followed by a two-layer MLP.

    Args:
        hidden_size: Output width of the MLP.
        frequency_embedding_size: Width of the sinusoidal feature vector.
        init_constant: Scale passed to ``_scaled_linear`` for both MLP layers.
    """

    # Parameter dtypes the sinusoidal features are cast to before the MLP.
    # Any other parameter dtype (e.g. a reduced-precision storage format)
    # leaves the features in float32, as before.
    _COMPUTE_DTYPES = (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        init_constant: float = 1.0,
    ):
        super().__init__()
        init_kwargs = dict(
            out_features=hidden_size,
            bias=True,
            weight_init="scaled_variance",
            init_constant=init_constant,
            bias_init="zeros",
        )
        self.mlp = nn.Sequential(
            _scaled_linear(frequency_embedding_size, **init_kwargs),
            nn.SiLU(),
            _scaled_linear(hidden_size, **init_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return float32 ``[cos | sin]`` features of shape ``(len(t), dim)``.

        Args:
            t: 1-D tensor of values to embed.
            dim: Feature width; for odd ``dim`` a trailing zero column is appended.
            max_period: Controls the lowest frequency of the geometric series.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Allocate the pad column explicitly: slicing ``embedding[:, :1]``
            # yields zero columns when ``half == 0`` (``dim == 1``).
            pad = embedding.new_zeros(embedding.shape[0], 1)
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` (1-D) and project it to ``hidden_size``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # The features are always float32; match the MLP's parameter dtype so
        # half/double precision weights do not hit a dtype mismatch.
        weight_dtype = self.mlp[0].weight.dtype
        if weight_dtype in self._COMPUTE_DTYPES:
            t_freq = t_freq.to(weight_dtype)
        return self.mlp(t_freq)
200
+
201
+
202
class PMFLabelEmbedder(nn.Module):
    """Lookup table mapping integer labels to ``hidden_size`` vectors.

    The table holds ``num_classes + 1`` rows (presumably the extra row is an
    unconditional/null label -- confirm against the calling pipeline) and is
    initialised from a normal distribution with
    ``std = init_constant / sqrt(hidden_size)``.
    """

    def __init__(self, num_classes: int, hidden_size: int, init_constant: float = 1.0):
        super().__init__()
        table = nn.Embedding(num_classes + 1, hidden_size)
        table_std = init_constant / sqrt(hidden_size)
        nn.init.normal_(table.weight, std=table_std)
        self.embedding_table = table

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return the embedding row for every entry of ``labels``."""
        embedded = self.embedding_table(labels)
        return embedded
210
+
211
+
212
class PMFBottleneckPatchEmbedder(nn.Module):
    """Patchify an image through a low-width bottleneck.

    ``proj1`` is a strided convolution (kernel = stride = ``patch_size``) into
    ``pca_channels``; ``proj2`` is a 1x1 convolution up to ``hidden_size``.
    Both weights are drawn uniformly from ``[-limit, limit]`` with
    ``limit = sqrt(6 / (fan_in + fan_out))``, where ``fan_out`` is the number
    of output channels only; biases start at zero.
    """

    def __init__(
        self,
        input_size: int,
        patch_size: int,
        pca_channels: int,
        in_channels: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        self.num_patches = (input_size // patch_size) ** 2
        self.proj1 = nn.Conv2d(
            in_channels, pca_channels, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        self.proj2 = nn.Conv2d(pca_channels, hidden_size, kernel_size=1, stride=1, bias=bias)

        # (layer, fan_in, fan_out) in the order the weights are re-drawn.
        init_plan = (
            (self.proj1, patch_size * patch_size * in_channels, pca_channels),
            (self.proj2, pca_channels, hidden_size),
        )
        for conv, fan_in, fan_out in init_plan:
            bound = math.sqrt(6.0 / (fan_in + fan_out))
            nn.init.uniform_(conv.weight, -bound, bound)
        if bias:
            for conv, _, _ in init_plan:
                nn.init.zeros_(conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens: spatial dims flattened and moved before channels."""
        bottleneck = self.proj1(x)
        feature_map = self.proj2(bottleneck)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)
250
+
251
+
252
def precompute_rope_freqs(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute 2-D rotary factors for a square grid of ``seq_len`` tokens.

    Tokens are laid out row-major on a ``sqrt(seq_len) x sqrt(seq_len)`` grid.
    For every token the first half of the returned row encodes its row index
    and the second half its column index, each as ``cos(angle) + i*sin(angle)``.

    Args:
        dim: Per-head channel count; the result has ``dim // 2`` complex columns.
        seq_len: Number of grid tokens; must be a perfect square.
        theta: Base of the geometric frequency series.

    Returns:
        A complex tensor of shape ``(seq_len, dim // 2)``.

    Raises:
        ValueError: If ``seq_len`` is not a perfect square, or ``dim // 2`` is
            odd (the row/column halves could not be equally wide).
    """
    rotary_dim = dim // 2
    # Exact integer square root; avoids float rounding and lets us reject
    # non-square token counts up front instead of failing inside ``reshape``.
    grid_size = math.isqrt(seq_len)
    if grid_size * grid_size != seq_len:
        raise ValueError(f"seq_len must be a perfect square, got {seq_len}")
    if rotary_dim % 2:
        raise ValueError(f"dim // 2 must be even to split between the two axes, got dim={dim}")

    inv_freq = 1.0 / (theta ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    positions = torch.arange(grid_size, dtype=torch.float32)
    # Rows and columns share the same (position x frequency) angle table.
    axis_angles = torch.outer(positions, inv_freq)
    angles = torch.cat(
        [
            axis_angles[:, None, :].expand(grid_size, grid_size, -1),  # varies with the row
            axis_angles[None, :, :].expand(grid_size, grid_size, -1),  # varies with the column
        ],
        dim=-1,
    ).reshape(seq_len, rotary_dim)
    return torch.complex(torch.cos(angles), torch.sin(angles))
269
+
270
+
271
def apply_rotary_pos_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the trailing tokens of ``x`` by complex factors ``freqs_cis``.

    Adjacent channel pairs of ``x`` are read as complex numbers (in float32).
    Only the last ``freqs_cis.shape[0]`` positions along dim 1 are multiplied
    by ``freqs_cis``; earlier positions pass through unchanged.  The result is
    cast back to the dtype of ``x``.
    """
    as_float = x.to(torch.float32)
    channel_pairs = as_float.reshape(as_float.shape[:-1] + (-1, 2))
    as_complex = torch.view_as_complex(channel_pairs.contiguous())

    # Insert broadcast axes at positions 0 and 2 so the factors line up with dim 1.
    rotation = freqs_cis[None, :, None]
    tail = slice(-rotation.shape[1], None)

    result = as_complex.clone()
    result[:, tail] = as_complex[:, tail] * rotation
    return torch.view_as_real(result).flatten(-2).to(x.dtype)
280
+
281
+
282
class PMFAttention(nn.Module):
    """Multi-head self-attention with per-head RMS-normalised, rotary-encoded q/k.

    Args:
        hidden_size: Model width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        weight_init_constant: Scale passed to ``_scaled_linear`` for all four
            bias-free projections.
        eps: Epsilon of the q/k ``RMSNorm`` layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        init_kwargs = dict(
            bias=False,
            weight_init="scaled_variance",
            init_constant=weight_init_constant,
        )
        self.q_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.k_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.v_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.out_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.q_norm = RMSNorm(self.head_dim, eps=eps)
        self.k_norm = RMSNorm(self.head_dim, eps=eps)

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, seq_len, channels)``.

        ``rope_freqs`` is forwarded to ``apply_rotary_pos_emb`` for both q and
        k.  The result has the same shape as ``x``.
        """
        batch_size, seq_len, channels = x.shape
        per_head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        q = self.q_proj(x).reshape(per_head_shape)
        k = self.k_proj(x).reshape(per_head_shape)
        v = self.v_proj(x).reshape(per_head_shape)

        q = apply_rotary_pos_emb(self.q_norm(q), rope_freqs)
        k = apply_rotary_pos_emb(self.k_norm(k), rope_freqs)

        query = q / math.sqrt(self.head_dim)
        attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, k)
        # Softmax is evaluated in float32 for stability, then cast back to the
        # value dtype: contracting float32 probabilities with fp16/bf16 values
        # would otherwise fail with a dtype mismatch outside autocast.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        attn = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        attn = attn.reshape(batch_size, seq_len, channels)
        return self.out_proj(attn)
322
+
323
+
324
class PMFSwiGLUMlp(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, dim: int, hidden_dim: int, weight_init_constant: float = 0.32):
        super().__init__()

        def _projection(fan_in: int, fan_out: int) -> nn.Linear:
            return _scaled_linear(
                fan_in,
                fan_out,
                bias=False,
                weight_init="scaled_variance",
                init_constant=weight_init_constant,
            )

        # Creation order (w1, w3, w2) fixes both RNG consumption and parameter order.
        self.w1 = _projection(dim, hidden_dim)
        self.w3 = _projection(dim, hidden_dim)
        self.w2 = _projection(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x``; the last dimension stays ``dim``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
334
+
335
+
336
class PMFTransformerBlock(nn.Module):
    """Pre-norm transformer block with learned per-channel residual scales.

    Both residual branches (attention and SwiGLU MLP) are multiplied by a
    per-channel parameter that starts at zero, so a freshly constructed block
    returns its input unchanged.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 8 / 3,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        ffn_width = int(hidden_size * mlp_ratio)
        if hidden_size > 1024:
            # Round up to the next multiple of 8 (wide models only).
            ffn_width += -ffn_width % 8

        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = PMFAttention(
            hidden_size,
            num_heads,
            weight_init_constant=weight_init_constant,
            eps=eps,
        )
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = PMFSwiGLUMlp(hidden_size, ffn_width, weight_init_constant=weight_init_constant)
        self.attn_scale = nn.Parameter(torch.zeros(hidden_size))
        self.mlp_scale = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Add the scaled attention branch, then the scaled MLP branch, to ``x``."""
        attn_branch = self.attn(self.norm1(x), rope_freqs)
        x = x + attn_branch * self.attn_scale
        mlp_branch = self.mlp(self.norm2(x))
        return x + mlp_branch * self.mlp_scale
360
+
361
+
362
class PMFFinalLayer(nn.Module):
    """RMS-normalise tokens and project each to ``patch_size**2 * out_channels`` values.

    The projection's weight and bias are initialised to zero, so a freshly
    constructed layer outputs zeros.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, eps: float = 1e-6):
        super().__init__()
        values_per_patch = patch_size * patch_size * out_channels
        self.norm = RMSNorm(hidden_size, eps=eps)
        self.linear = _scaled_linear(
            hidden_size,
            values_per_patch,
            bias=True,
            weight_init="zeros",
            bias_init="zeros",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-token patch values for ``x``."""
        normed = self.norm(x)
        return self.linear(normed)
376
+
377
+
378
class PMFTransformer2DModel(ModelMixin, ConfigMixin):
    """Native diffusers implementation of the pMF DiT backbone.

    The input sequence is ``[class, omega, t_min, t_max, time, patches]``:
    learned prefix tokens, each offset by the embedding of its conditioning
    value, followed by the patch embeddings, plus a learned positional
    embedding.  A stack of shared blocks feeds the ``u`` head and, when
    ``eval_mode`` is ``False``, a parallel ``v`` head.

    NOTE: when ``model_type`` names a preset (or a legacy alias of one), the
    preset's ``sample_size``, ``patch_size``, ``hidden_size``, ``depth``,
    ``num_attention_heads``, ``bottleneck_dim`` and ``aux_head_depth`` replace
    the corresponding arguments, including explicitly passed ones.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "rope_freqs"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        depth: int = 16,
        num_attention_heads: int = 12,
        mlp_ratio: float = 8 / 3,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        aux_head_depth: int = 8,
        num_class_tokens: int = 8,
        num_time_tokens: int = 4,
        num_cfg_tokens: int = 4,
        num_interval_tokens: int = 2,
        token_init_constant: float = 1.0,
        embedding_init_constant: float = 1.0,
        weight_init_constant: float = 0.32,
        eval_mode: bool = True,
        model_type: str | None = None,
        num_class_embeds: int | None = None,
        t_clip_min: float = 0.05,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        # ``num_class_embeds`` is an alternate spelling that wins over ``num_classes``.
        if num_class_embeds is not None:
            num_classes = int(num_class_embeds)
        if model_type in LEGACY_MODEL_ALIASES:
            model_type = LEGACY_MODEL_ALIASES[model_type]
        if model_type in PMF_PRESET_CONFIGS:
            preset = PMF_PRESET_CONFIGS[model_type]
            sample_size = int(preset["sample_size"])
            patch_size = int(preset["patch_size"])
            hidden_size = int(preset["hidden_size"])
            depth = int(preset["depth"])
            num_attention_heads = int(preset["num_attention_heads"])
            bottleneck_dim = int(preset["bottleneck_dim"])
            aux_head_depth = int(preset["aux_head_depth"])

        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_attention_heads = num_attention_heads
        self.aux_head_depth = aux_head_depth
        self.num_class_tokens = num_class_tokens
        self.num_time_tokens = num_time_tokens
        self.num_cfg_tokens = num_cfg_tokens
        self.num_interval_tokens = num_interval_tokens
        # Interval tokens are counted twice: one group for t_min, one for t_max.
        self.prefix_tokens = (
            num_class_tokens + num_cfg_tokens + 2 * num_interval_tokens + num_time_tokens
        )
        self.t_clip_min = t_clip_min
        self.eval_mode = eval_mode
        self.gradient_checkpointing = False

        self.x_embedder = PMFBottleneckPatchEmbedder(
            sample_size,
            patch_size,
            bottleneck_dim,
            in_channels,
            hidden_size,
            bias=True,
        )
        embed_kwargs = dict(hidden_size=hidden_size, init_constant=embedding_init_constant)
        self.h_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.omega_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.cfg_t_start_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.cfg_t_end_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.y_embedder = PMFLabelEmbedder(num_classes, hidden_size, init_constant=embedding_init_constant)

        token_std = token_init_constant / math.sqrt(hidden_size)
        self.time_tokens = nn.Parameter(torch.randn(1, num_time_tokens, hidden_size) * token_std)
        self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, hidden_size) * token_std)
        self.omega_tokens = nn.Parameter(torch.randn(1, num_cfg_tokens, hidden_size) * token_std)
        self.t_min_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
        self.t_max_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)

        total_tokens = self.x_embedder.num_patches + self.prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, total_tokens, hidden_size) * 0.02)

        # Rotary factors cover the patch tokens only and are rebuilt on
        # construction rather than stored in checkpoints (non-persistent).
        head_dim = hidden_size // num_attention_heads
        self.register_buffer(
            "rope_freqs",
            precompute_rope_freqs(head_dim, self.x_embedder.num_patches),
            persistent=False,
        )

        shared_depth = depth - aux_head_depth
        block_kwargs = dict(
            hidden_size=hidden_size,
            num_heads=num_attention_heads,
            mlp_ratio=mlp_ratio,
            weight_init_constant=weight_init_constant,
            eps=norm_eps,
        )
        self.shared_blocks = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(shared_depth)])
        self.u_heads = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth)])
        # The v head is only built when eval_mode is False.
        self.v_heads = nn.ModuleList(
            [PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth if not eval_mode else 0)]
        )
        self.u_final_layer = PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
        self.v_final_layer = (
            PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
            if not eval_mode
            else None
        )

    def _build_sequence(
        self,
        sample: torch.Tensor,
        h: torch.Tensor,
        omega: torch.Tensor,
        t_min: torch.Tensor,
        t_max: torch.Tensor,
        class_labels: torch.Tensor,
    ) -> torch.Tensor:
        """Assemble ``[class, omega, t_min, t_max, time, patches]`` plus positions.

        ``omega`` is embedded as ``1 - 1 / omega`` and therefore must be non-zero.
        """
        x_embed = self.x_embedder(sample)
        h_embed = self.h_embedder(h)
        omega_embed = self.omega_embedder(1 - 1 / omega)
        t_min_embed = self.cfg_t_start_embedder(t_min)
        t_max_embed = self.cfg_t_end_embedder(t_max)
        y_embed = self.y_embedder(class_labels)

        # Every learned token in a group is shifted by that group's embedding.
        time_tokens = self.time_tokens + h_embed.unsqueeze(1)
        omega_tokens = self.omega_tokens + omega_embed.unsqueeze(1)
        t_min_tokens = self.t_min_tokens + t_min_embed.unsqueeze(1)
        t_max_tokens = self.t_max_tokens + t_max_embed.unsqueeze(1)
        class_tokens = self.class_tokens + y_embed.unsqueeze(1)

        seq = torch.cat(
            [class_tokens, omega_tokens, t_min_tokens, t_max_tokens, time_tokens, x_embed],
            dim=1,
        )
        return seq + self.pos_embed

    def _unpatchify(self, tokens: torch.Tensor) -> torch.Tensor:
        """Fold ``(batch, grid*grid, patch*patch*channels)`` tokens into an image.

        Returns a tensor of shape ``(batch, channels, grid*patch, grid*patch)``.
        """
        batch_size = tokens.shape[0]
        patch = self.patch_size
        # Exact integer square root of the token count.
        grid = math.isqrt(tokens.shape[1])
        channels = self.out_channels
        x = tokens.reshape(batch_size, grid, grid, patch, patch, channels)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(batch_size, channels, grid * patch, grid * patch)

    def _run_blocks(self, blocks: nn.ModuleList, seq: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Apply ``blocks`` in order, checkpointing each one when enabled in training."""
        for block in blocks:
            if self.training and self.gradient_checkpointing:
                seq = torch.utils.checkpoint.checkpoint(block, seq, rope_freqs, use_reentrant=False)
            else:
                seq = block(seq, rope_freqs)
        return seq

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        h: Optional[torch.Tensor] = None,
        omega: Optional[torch.Tensor] = None,
        guidance_interval_min: Optional[torch.Tensor] = None,
        guidance_interval_max: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> PMFTransformer2DOutput | Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the backbone and return ``u`` (and ``v`` when a v head exists).

        Args:
            sample: Input batch; its device and dtype are applied to all scalars.
            timestep: Scalar or per-sample value used in the output division.
            class_labels: Integer labels for the label embedder.
            h: Value for the time tokens; defaults to ``timestep``.
            omega: Value for the omega tokens; defaults to 1.
            guidance_interval_min: Value for the t_min tokens; defaults to 0.
            guidance_interval_max: Value for the t_max tokens; defaults to 1.
            return_dict: Return a ``PMFTransformer2DOutput`` instead of a tuple.

        Returns:
            ``(u, v)`` or the equivalent output object.  Each is
            ``(sample - prediction) / clamp(timestep, min=t_clip_min)``; ``v``
            is ``None`` when the model has no v head.
        """
        batch_size = sample.shape[0]
        device, dtype = sample.device, sample.dtype

        def _as_batch(value: Optional[torch.Tensor], default: float) -> torch.Tensor:
            if value is None:
                value = torch.full((batch_size,), default, device=device)
            return self._expand_batch(value, batch_size, device, dtype)

        timestep = self._expand_batch(timestep, batch_size, device, dtype)
        h = timestep if h is None else self._expand_batch(h, batch_size, device, dtype)
        omega = _as_batch(omega, 1.0)
        guidance_interval_min = _as_batch(guidance_interval_min, 0.0)
        guidance_interval_max = _as_batch(guidance_interval_max, 1.0)

        seq = self._build_sequence(sample, h, omega, guidance_interval_min, guidance_interval_max, class_labels)
        rope_freqs = self.rope_freqs.to(device=sample.device)

        seq = self._run_blocks(self.shared_blocks, seq, rope_freqs)
        u_seq = self._run_blocks(self.u_heads, seq, rope_freqs)
        v_seq = self._run_blocks(self.v_heads, seq, rope_freqs)

        # Only the patch tokens are decoded; the prefix tokens are dropped.
        u_tokens = u_seq[:, self.prefix_tokens :]
        u_pred = self._unpatchify(self.u_final_layer(u_tokens))
        t = timestep.reshape(batch_size, 1, 1, 1)
        # Clamp the divisor so small timesteps cannot blow up the output.
        u = (sample - u_pred) / torch.clamp(t, min=self.t_clip_min)

        v = None
        if self.v_final_layer is not None:
            v_tokens = v_seq[:, self.prefix_tokens :]
            v_pred = self._unpatchify(self.v_final_layer(v_tokens))
            v = (sample - v_pred) / torch.clamp(t, min=self.t_clip_min)

        if not return_dict:
            return (u, v)
        return PMFTransformer2DOutput(u=u, v=v)

    @staticmethod
    def _expand_batch(
        value: torch.Tensor,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Convert ``value`` to a 1-D tensor of length ``batch_size``.

        Scalars and length-1 inputs are broadcast; anything else must already
        hold exactly ``batch_size`` elements.
        """
        value = torch.as_tensor(value, device=device, dtype=dtype)
        if value.ndim == 0:
            value = value.reshape(1)
        if value.shape[0] == 1 and batch_size > 1:
            value = value.expand(batch_size)
        return value.reshape(batch_size)

    @classmethod
    def from_pmf_checkpoint(
        cls,
        checkpoint_path: str,
        model_type: str | None = None,
        map_location: str = "cpu",
        strict: bool = False,
    ) -> Tuple["PMFTransformer2DModel", Dict[str, object]]:
        """Build an eval-mode model from a legacy checkpoint file.

        Args:
            checkpoint_path: Path handed to ``torch.load``.
            model_type: Preset name or legacy alias.  When ``None`` it is read
                from the first of the checkpoint keys ``model_type``,
                ``model_str``, ``model`` that holds a string.
            map_location: Forwarded to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``(model, metadata)``.  ``metadata`` holds ``checkpoint_path``,
            ``model_type`` and the ``missing_keys`` / ``unexpected_keys``
            reported by ``load_state_dict``.

        Raises:
            ValueError: If no model type is given and none can be inferred.
            KeyError: If the model type is not a known preset.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects.  Only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        if model_type is None and isinstance(checkpoint, dict):
            for key in ("model_type", "model_str", "model"):
                candidate = checkpoint.get(key)
                # Skip non-string entries (e.g. a nested state dict stored
                # under "model"), which cannot name a preset.
                if isinstance(candidate, str):
                    model_type = candidate
                    break
        if model_type in LEGACY_MODEL_ALIASES:
            model_type = LEGACY_MODEL_ALIASES[model_type]
        if model_type is None:
            raise ValueError("model_type is required when it cannot be inferred from the checkpoint.")

        config = dict(PMF_PRESET_CONFIGS[model_type])
        config["model_type"] = model_type
        config["eval_mode"] = True
        model = cls(**config)
        load_result = model.load_state_dict(remap_legacy_state_dict(state_dict), strict=strict)
        # With strict=False mismatched keys are skipped silently; surface them
        # so callers can detect a partially loaded model.
        metadata = {
            "checkpoint_path": checkpoint_path,
            "model_type": model_type,
            "missing_keys": list(load_result.missing_keys),
            "unexpected_keys": list(load_result.unexpected_keys),
        }
        return model, metadata

    def to_pmf_checkpoint(self, prefix: str = "net.") -> Dict[str, torch.Tensor]:
        """Return this model's state dict as detached CPU tensors with ``prefix`` on every key."""
        return {f"{prefix}{key}": value.detach().cpu() for key, value in self.state_dict().items()}

    @property
    def net(self):
        """Return ``self``, so ``model.net`` resolves to the model itself."""
        return self


# Second module-level name bound to the same class object.
PMFDiffusersModel = PMFTransformer2DModel
pMF-H-32/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ tags:
6
+ - diffusers
7
+ - pmf
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ ---
13
+
14
+ # pMF-H-32
15
+
16
+ Self-contained Diffusers variant for **pMF-H/32** (Pixel Mean Flows).
17
+
18
+ Recommended settings: `guidance_scale=5.5`, interval `[0.1, 0.6]`, `noise_scale=4.0`.
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ from pathlib import Path
24
+ from diffusers import DiffusionPipeline
25
+ import torch
26
+
27
+ model_dir = Path("./pMF-H-32")
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ str(model_dir),
30
+ local_files_only=True,
31
+ custom_pipeline=str(model_dir / "pipeline.py"),
32
+ trust_remote_code=True,
33
+ torch_dtype=torch.float32,
34
+ ).to("cuda")
35
+
36
+ image = pipe(
37
+ class_labels=207,
38
+ num_inference_steps=1,
39
+ guidance_scale=5.5,
40
+ guidance_interval_min=0.1,
41
+ guidance_interval_max=0.6,
42
+ noise_scale=4.0,
43
+ ).images[0]
44
+ ```
pMF-H-32/demo.png ADDED

Git LFS Details

  • SHA256: a87cf6f0b8d77877449c12bd88b7830a9416b02ab229d9382ccb4e814e33d428
  • Pointer size: 131 Bytes
  • Size of remote file: 295 kB
pMF-H-32/model_index.json ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PMFPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_pmf",
13
+ "PMFTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
pMF-H-32/pipeline.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Hub custom pipeline: PMFPipeline.
16
+
17
+ Load with native Hugging Face diffusers and trust_remote_code=True.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+
32
# Classifier-free-guidance presets per checkpoint. Each row below is
# (guidance_scale, guidance_interval_min, guidance_interval_max).
_CFG_PRESET_FIELDS = ("guidance_scale", "guidance_interval_min", "guidance_interval_max")

DEFAULT_CFG_BY_MODEL: Dict[str, Dict[str, float]] = {
    model_name: dict(zip(_CFG_PRESET_FIELDS, preset))
    for model_name, preset in (
        ("pMF-B/16", (7.5, 0.1, 0.8)),
        ("pMF-B/32", (6.5, 0.1, 0.7)),
        ("pMF-L/16", (7.0, 0.2, 0.7)),
        ("pMF-L/32", (7.5, 0.2, 0.6)),
        ("pMF-H/16", (7.0, 0.2, 0.6)),
        ("pMF-H/32", (5.5, 0.1, 0.6)),
    )
}

# Multiplier applied to the initial Gaussian noise when the caller does not
# pass an explicit `noise_scale`.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = dict(
    [
        ("pMF-B/16", 1.0),
        ("pMF-B/32", 2.0),
        ("pMF-L/16", 1.0),
        ("pMF-L/32", 4.0),
        ("pMF-H/16", 2.0),
        ("pMF-H/32", 4.0),
    ]
)
49
+
50
+
51
def _set_pmf_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler,
    num_inference_steps: int,
    device: torch.device,
) -> torch.Tensor:
    r"""Install a linear sigma grid on `scheduler` and return that grid.

    The grid holds ``num_inference_steps + 1`` evenly spaced float32 values
    running from 1.0 down to 0.0, so consecutive entries delimit one sampling
    step each.

    Args:
        scheduler: Scheduler whose ``set_timesteps`` is called with the grid.
        num_inference_steps: Number of sampling steps the grid should cover.
        device: Device on which the returned tensor is created; also forwarded
            to the scheduler.

    Returns:
        The sigma grid as a 1-D float32 tensor on `device`.
    """
    grid_points = num_inference_steps + 1
    sigma_grid = torch.linspace(1.0, 0.0, grid_points, device=device, dtype=torch.float32)
    # The scheduler is handed plain Python floats; the caller keeps the tensor.
    scheduler.set_timesteps(sigmas=sigma_grid.tolist(), device=device)
    return sigma_grid
60
+
61
+
62
class PMFPipeline(DiffusionPipeline):
    r"""
    Pipeline for ImageNet class-conditional generation with Pixel Mean Flows (pMF).

    Parameters:
        transformer ([`PMFTransformer2DModel`]):
            Class-conditioned pMF transformer that predicts mean-flow velocity.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Built-in flow-matching Euler scheduler. When `None`, a scheduler with
            `num_train_timesteps=1000`, `shift=1.0` and `stochastic_sampling=False` is created.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Keys may be ints or numeric strings and each value
            may list several comma-separated synonyms. When omitted, the mapping is read lazily from a
            `model_index.json` found under the pipeline's `_name_or_path`, if that is a local directory.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        if scheduler is None:
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=1.0,
                stochastic_sampling=False,
            )
        self.register_modules(transformer=transformer, scheduler=scheduler)
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # An empty mapping means "not loaded yet"; `_ensure_labels_loaded` will
        # then try model_index.json on first use.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Populate the label tables from model_index.json if they are still empty."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with every key converted to `int` (empty dict for falsy input)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the `id2label` table from `<variant_path>/model_index.json`.

        Returns an empty dict when `variant_path` is falsy, the file does not
        exist (for example when the path is not a local directory), or the JSON
        has no dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        # Share the key normalisation with the constructor path.
        return PMFPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert `id2label` into a synonym -> class id table, sorted by synonym.

        Each value is split on commas and every non-empty, stripped synonym
        becomes a key. When two classes share a synonym, the class that comes
        later in `id2label` wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """Class id -> label string mapping (loaded lazily if needed)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        """Map one label name or a list of label names to class ids.

        Raises:
            ValueError: If no label table is available, or if any requested
                name is not a known synonym.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
        labels = [label] if isinstance(label, str) else label
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
        """Turn the user-facing `class_labels` argument into a flat list of class ids.

        Accepts a single id, a single label name, or a list that may freely mix
        ids and names. Names are resolved through `get_label_ids` in one call
        (so an error lists every unknown name); ids are passed through untouched
        and keep their position.
        """
        if isinstance(class_labels, (int, str)):
            class_labels = [class_labels]
        names = [item for item in class_labels if isinstance(item, str)]
        # Only touch the label table when at least one name was supplied, so
        # pure-id calls keep working without an `id2label` mapping.
        resolved_names = iter(self.get_label_ids(names)) if names else iter(())
        return [next(resolved_names) if isinstance(item, str) else item for item in class_labels]

    def _num_classes(self) -> int:
        """Number of real classes the transformer was built with.

        `num_class_embeds` takes precedence over `num_classes` when it is set,
        mirroring how the transformer's constructor sizes its label embedding.
        Falls back to 1000 when neither config entry is available.
        """
        config = self.transformer.config
        for key in ("num_class_embeds", "num_classes"):
            value = getattr(config, key, None)
            if value is not None:
                return int(value)
        return 1000

    def _recommended_noise_scale(self) -> float:
        """Default initial-noise multiplier: per-model preset, else by image size, else 1.0."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in RECOMMENDED_NOISE_BY_MODEL:
            return RECOMMENDED_NOISE_BY_MODEL[model_type]
        image_size = int(self.transformer.config.sample_size)
        return {256: 1.0, 512: 2.0}.get(image_size, 1.0)

    def _default_cfg(self) -> Dict[str, float]:
        """Default guidance settings: a copy of the per-model preset, else a generic fallback."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in DEFAULT_CFG_BY_MODEL:
            return dict(DEFAULT_CFG_BY_MODEL[model_type])
        return {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8}

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        num_inference_steps: int = 1,
        guidance_scale: Optional[float] = None,
        guidance_interval_min: Optional[float] = None,
        guidance_interval_max: Optional[float] = None,
        noise_scale: Optional[float] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with pMF.

        Args:
            class_labels (`int`, `str`, or `list`):
                ImageNet class id(s) or label name(s). A list may mix ids and names; one image is
                generated per entry. Ids outside `[0, num_classes - 1]` are clamped into that range.
            num_inference_steps (`int`, *optional*, defaults to 1):
                Number of flow steps. pMF is typically used with 1 step.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. Defaults to model-specific preset.
            guidance_interval_min (`float`, *optional*):
                Lower bound of the CFG interval in normalized time.
            guidance_interval_max (`float`, *optional*):
                Upper bound of the CFG interval in normalized time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale. Defaults to model-specific preset.
            generator (`torch.Generator`, *optional*):
                Random generator for reproducibility.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: `"pil"`, `"np"`, or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`].

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                Generated images.

        Raises:
            ValueError: If `num_inference_steps < 1`, `output_type` is not supported, or a label name
                cannot be resolved.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")

        defaults = self._default_cfg()
        if guidance_scale is None:
            guidance_scale = defaults["guidance_scale"]
        if guidance_interval_min is None:
            guidance_interval_min = defaults["guidance_interval_min"]
        if guidance_interval_max is None:
            guidance_interval_max = defaults["guidance_interval_max"]
        if noise_scale is None:
            noise_scale = self._recommended_noise_scale()

        class_label_ids = self._normalize_class_labels(class_labels)
        batch_size = len(class_label_ids)
        image_size = int(self.transformer.config.sample_size)
        channels = int(self.transformer.config.in_channels)
        null_class_val = self._num_classes()

        latents = randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        ) * noise_scale

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # Out-of-range ids are clamped rather than rejected.
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)

        device = latents.device
        dtype = latents.dtype
        omega = torch.full((batch_size,), guidance_scale, device=device, dtype=dtype)
        t_min = torch.full((batch_size,), guidance_interval_min, device=device, dtype=dtype)
        t_max = torch.full((batch_size,), guidance_interval_max, device=device, dtype=dtype)

        flow_sigmas = _set_pmf_timesteps(self.scheduler, num_inference_steps, device)

        for step_index in self.progress_bar(range(num_inference_steps)):
            t = flow_sigmas[step_index]
            t_next = flow_sigmas[step_index + 1]
            # `h` is the width of the current step on the sigma grid.
            h = (t - t_next).expand(batch_size).to(device=device, dtype=dtype)
            t_batch = t.expand(batch_size).to(device=device, dtype=dtype)

            output = self.transformer(
                sample=latents,
                timestep=t_batch,
                class_labels=class_labels_t,
                h=h,
                omega=omega,
                guidance_interval_min=t_min,
                guidance_interval_max=t_max,
                return_dict=True,
            )
            latents = self.scheduler.step(output.u, self.scheduler.timesteps[step_index], latents).prev_sample

        # Map from [-1, 1] to [0, 1] and move to CPU for conversion.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
283
+
284
+
285
+ PMFPipelineOutput = ImagePipelineOutput  # pMF-specific alias for the output type returned by PMFPipeline.__call__
pMF-H-32/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
pMF-H-32/transformer/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PMFTransformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "aux_head_depth": 8,
5
+ "bottleneck_dim": 256,
6
+ "depth": 48,
7
+ "embedding_init_constant": 1.0,
8
+ "eval_mode": true,
9
+ "hidden_size": 1280,
10
+ "in_channels": 3,
11
+ "mlp_ratio": 2.6666666666666665,
12
+ "model_type": "pMF-H/32",
13
+ "norm_eps": 1e-06,
14
+ "num_attention_heads": 16,
15
+ "num_cfg_tokens": 4,
16
+ "num_class_embeds": null,
17
+ "num_class_tokens": 8,
18
+ "num_classes": 1000,
19
+ "num_interval_tokens": 2,
20
+ "num_time_tokens": 4,
21
+ "patch_size": 32,
22
+ "sample_size": 512,
23
+ "t_clip_min": 0.05,
24
+ "token_init_constant": 1.0,
25
+ "weight_init_constant": 0.32
26
+ }
pMF-H-32/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f27a084b5e4ea148216425bdc168acdbd9ee7b0b862d21bfe8b277776ff314a
3
+ size 3836269544
pMF-H-32/transformer/transformer_pmf.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from math import sqrt
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.normalization import RMSNorm
14
+ from diffusers.utils import BaseOutput
15
+
16
+
17
# Backbone hyper-parameters per model size:
# (hidden_size, depth, num_attention_heads, bottleneck_dim).
_PMF_BACKBONES = {
    "B": (768, 16, 12, 128),
    "L": (1024, 32, 16, 128),
    "H": (1280, 48, 16, 256),
}
# Each released patch size is paired with one image resolution: patch_size -> sample_size.
_PMF_RESOLUTIONS = {16: 256, 32: 512}

PMF_PRESET_CONFIGS: Dict[str, Dict[str, object]] = {
    f"pMF-{size}/{patch}": {
        "sample_size": sample,
        "patch_size": patch,
        "hidden_size": hidden,
        "depth": depth,
        "num_attention_heads": heads,
        "bottleneck_dim": bottleneck,
        "aux_head_depth": 8,
    }
    for size, (hidden, depth, heads, bottleneck) in _PMF_BACKBONES.items()
    for patch, sample in _PMF_RESOLUTIONS.items()
}

# Initial-noise multiplier recommended for each preset.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = dict(
    [
        ("pMF-B/16", 1.0),
        ("pMF-B/32", 2.0),
        ("pMF-L/16", 1.0),
        ("pMF-L/32", 4.0),
        ("pMF-H/16", 2.0),
        ("pMF-H/32", 4.0),
    ]
)

# Legacy torch repo keys (pmfDiT_*)
LEGACY_MODEL_ALIASES: Dict[str, str] = {
    f"pmfDiT_{size}_{patch}": f"pMF-{size}/{patch}"
    for size in ("B", "L", "H")
    for patch in (16, 32)
}
92
+
93
+
94
@dataclass
class PMFTransformer2DOutput(BaseOutput):
    """Container for the tensors produced by the pMF transformer."""

    # Primary prediction; always present.
    u: torch.Tensor
    # Optional second prediction; defaults to None.
    # NOTE(review): presumably only populated when the model is built with
    # `eval_mode=False` (the v-head is only created then) — confirm in forward().
    v: Optional[torch.Tensor] = None
98
+
99
+
100
def remap_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Translate legacy checkpoint keys into native PMFTransformer2DModel keys.

    For every entry, at most one leading wrapper prefix (``transformer.`` or
    ``net.``, tried in that order) is removed, the ``._flax_linear`` and
    ``._flax_embedding`` wrapper segments are dropped, and an entry whose
    resulting key is exactly ``rope_freqs`` is discarded. Values are passed
    through unchanged.
    """
    wrapper_prefixes = ("transformer.", "net.")
    # Official PyTorch checkpoints use TorchLinear/TorchEmbedding wrappers.
    wrapper_segments = ("._flax_linear", "._flax_embedding")

    native: Dict[str, torch.Tensor] = {}
    for legacy_key, tensor in state_dict.items():
        matched = next((p for p in wrapper_prefixes if legacy_key.startswith(p)), None)
        name = legacy_key if matched is None else legacy_key[len(matched) :]
        for segment in wrapper_segments:
            name = name.replace(segment, "")
        if name != "rope_freqs":
            native[name] = tensor
    return native
115
+
116
+
117
def config_from_legacy(config: Dict[str, object]) -> Dict[str, object]:
    """Derive native constructor kwargs from a legacy ``config.json`` dict.

    The preset name is taken from the first truthy value among ``model_type``,
    ``model_name`` and ``model_str``; legacy ``pmfDiT_*`` names are translated.
    The result is a copy of the matching preset, extended with ``num_classes``
    (``num_class_embeds`` wins over ``num_classes``; 1000 if neither is truthy)
    and ``model_type``, and with ``sample_size`` / ``eval_mode`` overridden when
    the legacy config provides them.

    Raises:
        ValueError: If the resolved name is not a known preset.
    """
    requested = config.get("model_type") or config.get("model_name") or config.get("model_str")
    model_type = LEGACY_MODEL_ALIASES[requested] if requested in LEGACY_MODEL_ALIASES else requested
    if model_type not in PMF_PRESET_CONFIGS:
        raise ValueError(f"Unknown pMF preset '{model_type}'. Known: {list(PMF_PRESET_CONFIGS)}")

    class_count = config.get("num_class_embeds") or config.get("num_classes") or 1000
    native = {
        **PMF_PRESET_CONFIGS[model_type],
        "num_classes": int(class_count),
        "model_type": model_type,
    }
    # Only these two fields may be overridden by the legacy config.
    for key, cast in (("sample_size", int), ("eval_mode", bool)):
        if config.get(key) is not None:
            native[key] = cast(config[key])
    return native
133
+
134
+
135
def _scaled_linear(
    in_features: int,
    out_features: int,
    *,
    bias: bool = True,
    weight_init: str = "scaled_variance",
    init_constant: float = 1.0,
    bias_init: str = "zeros",
) -> nn.Linear:
    """Create an ``nn.Linear`` and re-initialise it with the requested scheme.

    Args:
        in_features: Input width.
        out_features: Output width.
        bias: Whether the layer has a bias term.
        weight_init: ``"scaled_variance"`` draws weights from a normal
            distribution with std ``init_constant / sqrt(in_features)``;
            ``"zeros"`` sets them to zero.
        init_constant: Numerator of the std used by ``"scaled_variance"``.
        bias_init: Only ``"zeros"`` is supported; ignored when `bias` is False.

    Raises:
        ValueError: On an unsupported `weight_init`, or an unsupported
            `bias_init` when `bias` is True.
    """
    linear = nn.Linear(in_features, out_features, bias=bias)

    if weight_init == "zeros":
        nn.init.zeros_(linear.weight)
    elif weight_init == "scaled_variance":
        nn.init.normal_(linear.weight, std=init_constant / sqrt(in_features))
    else:
        raise ValueError(f"Invalid weight_init: {weight_init}")

    if bias and bias_init != "zeros":
        raise ValueError(f"Invalid bias_init: {bias_init}")
    if bias:
        nn.init.zeros_(linear.bias)
    return linear
159
+
160
+
161
class PMFTimestepEmbedder(nn.Module):
    """Embed a batch of scalars: sinusoidal features followed by a Linear-SiLU-Linear MLP."""

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        init_constant: float = 1.0,
    ):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size

        def make_linear(in_features: int) -> nn.Linear:
            # Both MLP layers project to `hidden_size` with the same init scheme.
            return _scaled_linear(
                in_features,
                out_features=hidden_size,
                bias=True,
                weight_init="scaled_variance",
                init_constant=init_constant,
                bias_init="zeros",
            )

        self.mlp = nn.Sequential(
            make_linear(frequency_embedding_size),
            nn.SiLU(),
            make_linear(hidden_size),
        )

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return sinusoidal features of shape ``(len(t), dim)`` for the 1-D tensor `t`.

        The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are
        sines, over frequencies decaying geometrically from 1 towards
        ``1 / max_period``. For odd `dim` a trailing zero column is appended.
        """
        half = dim // 2
        indices = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * indices / half)
        angles = t[:, None].float() * freqs[None]
        parts = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Pad so the feature width equals `dim` exactly.
            parts.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed `t` to ``(len(t), hidden_size)``."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
200
+
201
+
202
class PMFLabelEmbedder(nn.Module):
    """Learned lookup table that maps integer class labels to `hidden_size` vectors.

    The table has ``num_classes + 1`` rows.
    NOTE(review): the extra row presumably serves an unconditional/null class
    at index ``num_classes`` — confirm against the training code.
    """

    def __init__(self, num_classes: int, hidden_size: int, init_constant: float = 1.0):
        super().__init__()
        table = nn.Embedding(num_classes + 1, hidden_size)
        # Replace the default init with normal(0, init_constant / sqrt(hidden_size)).
        nn.init.normal_(table.weight, std=init_constant / sqrt(hidden_size))
        self.embedding_table = table

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Look up the embedding row for each entry of `labels`."""
        return self.embedding_table(labels)
210
+
211
+
212
class PMFBottleneckPatchEmbedder(nn.Module):
    """Patchify an image through a low-width bottleneck and project to `hidden_size`.

    `proj1` is a strided convolution that maps every ``patch_size x patch_size``
    patch to `pca_channels` features; `proj2` is a 1x1 convolution that lifts
    those to `hidden_size`. Both weights are drawn uniformly from
    ``[-limit, limit]`` with ``limit = sqrt(6 / (fan_in + fan_out))``, where
    fan_out is the number of output channels; biases start at zero.
    """

    def __init__(
        self,
        input_size: int,
        patch_size: int,
        pca_channels: int,
        in_channels: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        self.num_patches = (input_size // patch_size) ** 2
        self.proj1 = nn.Conv2d(
            in_channels,
            pca_channels,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )
        self.proj2 = nn.Conv2d(pca_channels, hidden_size, kernel_size=1, stride=1, bias=bias)

        # (layer, fan_in, fan_out) for the two projections, initialised in order.
        fan_table = (
            (self.proj1, patch_size * patch_size * in_channels, pca_channels),
            (self.proj2, pca_channels, hidden_size),
        )
        for conv, fan_in, fan_out in fan_table:
            bound = math.sqrt(6.0 / (fan_in + fan_out))
            nn.init.uniform_(conv.weight, -bound, bound)
        if bias:
            for conv in (self.proj1, self.proj2):
                nn.init.zeros_(conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens of shape ``(batch, num_patches, hidden_size)``."""
        features = self.proj1(x)
        features = self.proj2(features)
        # (B, C, H', W') -> (B, H' * W', C)
        return features.flatten(2).transpose(1, 2)
250
+
251
+
252
def precompute_rope_freqs(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Build 2-D rotary phase factors for a square grid of `seq_len` tokens.

    Args:
        dim: Per-head channel width; the rotation acts on ``dim // 2`` complex pairs.
        seq_len: Number of grid tokens; the grid side is ``int(seq_len ** 0.5)``.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape ``(seq_len, dim // 2)`` with unit-modulus
        entries. Tokens are ordered row-major; for each token the first half of
        the entries encodes its row index and the second half its column index.
    """
    rot_dim = dim // 2
    side = int(seq_len**0.5)

    exponents = torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim
    inv_freq = 1.0 / (theta**exponents)
    coords = torch.arange(side, dtype=torch.float32)
    # Rows and columns share the same per-axis angle table.
    axis_angles = torch.einsum("i,j->ij", coords, inv_freq)

    row_angles = axis_angles[:, None, :].expand(side, side, -1)
    col_angles = axis_angles[None, :, :].expand(side, side, -1)
    angles = torch.cat([row_angles, col_angles], dim=-1).reshape(seq_len, rot_dim)
    return torch.complex(torch.cos(angles), torch.sin(angles))
269
+
270
+
271
def apply_rotary_pos_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the trailing tokens of `x` by the complex phases in `freqs_cis`.

    Adjacent channel pairs of `x` are treated as complex numbers. Only the last
    ``freqs_cis.shape[0]`` positions along dim 1 are multiplied by the phases;
    earlier positions are returned unchanged. The arithmetic runs in float32
    and the result is cast back to the dtype of `x`.

    NOTE(review): indexing implies `x` is (batch, tokens, heads, head_dim) and
    `freqs_cis` is (rotated_tokens, head_dim // 2) — confirm with callers.
    """
    original_dtype = x.dtype
    pairs = x.to(torch.float32)
    pairs = pairs.reshape(*pairs.shape[:-1], -1, 2).contiguous()
    as_complex = torch.view_as_complex(pairs)

    # Broadcast the phases over the batch and head axes.
    phases = freqs_cis.unsqueeze(0).unsqueeze(2)
    tail = phases.shape[1]

    rotated = as_complex.clone()
    rotated[:, -tail:, :] = as_complex[:, -tail:, :] * phases
    return torch.view_as_real(rotated).flatten(-2).to(original_dtype)
280
+
281
+
282
class PMFAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised queries/keys and rotary position phases.

    All four projections are bias-free and initialised with
    ``std = weight_init_constant / sqrt(hidden_size)``. Queries and keys are
    RMS-normalised per head before the rotary phases are applied.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        init_kwargs = dict(
            bias=False,
            weight_init="scaled_variance",
            init_constant=weight_init_constant,
        )
        self.q_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.k_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.v_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.out_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.q_norm = RMSNorm(self.head_dim, eps=eps)
        self.k_norm = RMSNorm(self.head_dim, eps=eps)

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Attend over the token axis of `x`.

        Args:
            x: Input of shape ``(batch, tokens, hidden_size)``.
            rope_freqs: Complex rotary phases handed to `apply_rotary_pos_emb`
                for both queries and keys.

        Returns:
            Tensor with the same shape as `x`.
        """
        batch_size, seq_len, channels = x.shape
        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q = apply_rotary_pos_emb(q, rope_freqs)
        k = apply_rotary_pos_emb(k, rope_freqs)

        query = q / math.sqrt(self.head_dim)
        attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, k)
        # Softmax is evaluated in float32 for stability, then cast back so the
        # contraction with `v` sees matching dtypes when the module runs in
        # fp16/bf16. In fp32 the cast is a no-op.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        attn = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        attn = attn.reshape(batch_size, seq_len, channels)
        return self.out_proj(attn)
322
+
323
+
324
class PMFSwiGLUMlp(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, dim: int, hidden_dim: int, weight_init_constant: float = 0.32):
        super().__init__()

        def projection(fan_in: int, fan_out: int) -> nn.Linear:
            return _scaled_linear(
                fan_in,
                fan_out,
                bias=False,
                weight_init="scaled_variance",
                init_constant=weight_init_constant,
            )

        # Creation order (w1, w3, w2) is kept so parameter order and RNG
        # consumption are unchanged.
        self.w1 = projection(dim, hidden_dim)
        self.w3 = projection(dim, hidden_dim)
        self.w2 = projection(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP along the last dimension of `x`."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
334
+
335
+
336
class PMFTransformerBlock(nn.Module):
    """Pre-norm transformer block with learned per-channel residual scales.

    Both residual branches (attention and SwiGLU MLP) are multiplied by a
    learnable vector that starts at zero, so a freshly constructed block
    returns its input unchanged.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 8 / 3,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = PMFAttention(hidden_size, num_heads, weight_init_constant=weight_init_constant, eps=eps)
        self.norm2 = RMSNorm(hidden_size, eps=eps)

        mlp_width = int(hidden_size * mlp_ratio)
        if hidden_size > 1024:
            # Round the MLP width up to the next multiple of 8 for wide models.
            mlp_width = -(-mlp_width // 8) * 8
        self.mlp = PMFSwiGLUMlp(hidden_size, mlp_width, weight_init_constant=weight_init_constant)

        self.attn_scale = nn.Parameter(torch.zeros(hidden_size))
        self.mlp_scale = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Apply the attention branch, then the MLP branch, each as a scaled residual."""
        attn_update = self.attn(self.norm1(x), rope_freqs)
        x = x + attn_update * self.attn_scale
        mlp_update = self.mlp(self.norm2(x))
        return x + mlp_update * self.mlp_scale
360
+
361
+
362
class PMFFinalLayer(nn.Module):
    """Output head: RMSNorm followed by a linear map to flattened patch pixels.

    The linear layer produces ``patch_size * patch_size * out_channels`` values
    per token and is zero-initialised (weights and bias).
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, eps: float = 1e-6):
        super().__init__()
        patch_dim = patch_size * patch_size * out_channels
        self.norm = RMSNorm(hidden_size, eps=eps)
        self.linear = _scaled_linear(
            hidden_size,
            patch_dim,
            bias=True,
            weight_init="zeros",
            bias_init="zeros",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise `x` and project every token to its patch values."""
        normed = self.norm(x)
        return self.linear(normed)
376
+
377
+
378
+ class PMFTransformer2DModel(ModelMixin, ConfigMixin):
379
+ """Native diffusers implementation of the pMF DiT backbone."""
380
+
381
+ _supports_gradient_checkpointing = True
382
+ _skip_layerwise_casting_patterns = ["pos_embed", "rope_freqs"]
383
+
384
+ @register_to_config
385
+ def __init__(
386
+ self,
387
+ sample_size: int = 256,
388
+ patch_size: int = 16,
389
+ in_channels: int = 3,
390
+ hidden_size: int = 768,
391
+ depth: int = 16,
392
+ num_attention_heads: int = 12,
393
+ mlp_ratio: float = 8 / 3,
394
+ num_classes: int = 1000,
395
+ bottleneck_dim: int = 128,
396
+ aux_head_depth: int = 8,
397
+ num_class_tokens: int = 8,
398
+ num_time_tokens: int = 4,
399
+ num_cfg_tokens: int = 4,
400
+ num_interval_tokens: int = 2,
401
+ token_init_constant: float = 1.0,
402
+ embedding_init_constant: float = 1.0,
403
+ weight_init_constant: float = 0.32,
404
+ eval_mode: bool = True,
405
+ model_type: str | None = None,
406
+ num_class_embeds: int | None = None,
407
+ t_clip_min: float = 0.05,
408
+ norm_eps: float = 1e-6,
409
+ ):
410
+ super().__init__()
411
+ if num_class_embeds is not None:
412
+ num_classes = int(num_class_embeds)
413
+ if model_type in LEGACY_MODEL_ALIASES:
414
+ model_type = LEGACY_MODEL_ALIASES[model_type]
415
+ if model_type in PMF_PRESET_CONFIGS:
416
+ preset = PMF_PRESET_CONFIGS[model_type]
417
+ sample_size = int(preset["sample_size"])
418
+ patch_size = int(preset["patch_size"])
419
+ hidden_size = int(preset["hidden_size"])
420
+ depth = int(preset["depth"])
421
+ num_attention_heads = int(preset["num_attention_heads"])
422
+ bottleneck_dim = int(preset["bottleneck_dim"])
423
+ aux_head_depth = int(preset["aux_head_depth"])
424
+
425
+ self.sample_size = sample_size
426
+ self.patch_size = patch_size
427
+ self.in_channels = in_channels
428
+ self.out_channels = in_channels
429
+ self.hidden_size = hidden_size
430
+ self.depth = depth
431
+ self.num_attention_heads = num_attention_heads
432
+ self.aux_head_depth = aux_head_depth
433
+ self.num_class_tokens = num_class_tokens
434
+ self.num_time_tokens = num_time_tokens
435
+ self.num_cfg_tokens = num_cfg_tokens
436
+ self.num_interval_tokens = num_interval_tokens
437
+ self.prefix_tokens = (
438
+ num_class_tokens + num_cfg_tokens + 2 * num_interval_tokens + num_time_tokens
439
+ )
440
+ self.t_clip_min = t_clip_min
441
+ self.eval_mode = eval_mode
442
+ self.gradient_checkpointing = False
443
+
444
+ self.x_embedder = PMFBottleneckPatchEmbedder(
445
+ sample_size,
446
+ patch_size,
447
+ bottleneck_dim,
448
+ in_channels,
449
+ hidden_size,
450
+ bias=True,
451
+ )
452
+ embed_kwargs = dict(hidden_size=hidden_size, init_constant=embedding_init_constant)
453
+ self.h_embedder = PMFTimestepEmbedder(**embed_kwargs)
454
+ self.omega_embedder = PMFTimestepEmbedder(**embed_kwargs)
455
+ self.cfg_t_start_embedder = PMFTimestepEmbedder(**embed_kwargs)
456
+ self.cfg_t_end_embedder = PMFTimestepEmbedder(**embed_kwargs)
457
+ self.y_embedder = PMFLabelEmbedder(num_classes, hidden_size, init_constant=embedding_init_constant)
458
+
459
+ token_std = token_init_constant / math.sqrt(hidden_size)
460
+ self.time_tokens = nn.Parameter(torch.randn(1, num_time_tokens, hidden_size) * token_std)
461
+ self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, hidden_size) * token_std)
462
+ self.omega_tokens = nn.Parameter(torch.randn(1, num_cfg_tokens, hidden_size) * token_std)
463
+ self.t_min_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
464
+ self.t_max_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
465
+
466
+ total_tokens = self.x_embedder.num_patches + self.prefix_tokens
467
+ self.pos_embed = nn.Parameter(torch.randn(1, total_tokens, hidden_size) * 0.02)
468
+
469
+ head_dim = hidden_size // num_attention_heads
470
+ self.register_buffer(
471
+ "rope_freqs",
472
+ precompute_rope_freqs(head_dim, self.x_embedder.num_patches),
473
+ persistent=False,
474
+ )
475
+
476
+ shared_depth = depth - aux_head_depth
477
+ block_kwargs = dict(
478
+ hidden_size=hidden_size,
479
+ num_heads=num_attention_heads,
480
+ mlp_ratio=mlp_ratio,
481
+ weight_init_constant=weight_init_constant,
482
+ eps=norm_eps,
483
+ )
484
+ self.shared_blocks = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(shared_depth)])
485
+ self.u_heads = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth)])
486
+ self.v_heads = nn.ModuleList(
487
+ [PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth if not eval_mode else 0)]
488
+ )
489
+ self.u_final_layer = PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
490
+ self.v_final_layer = (
491
+ PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
492
+ if not eval_mode
493
+ else None
494
+ )
495
+
496
+ def _build_sequence(
497
+ self,
498
+ sample: torch.Tensor,
499
+ h: torch.Tensor,
500
+ omega: torch.Tensor,
501
+ t_min: torch.Tensor,
502
+ t_max: torch.Tensor,
503
+ class_labels: torch.Tensor,
504
+ ) -> torch.Tensor:
505
+ x_embed = self.x_embedder(sample)
506
+ h_embed = self.h_embedder(h)
507
+ omega_embed = self.omega_embedder(1 - 1 / omega)
508
+ t_min_embed = self.cfg_t_start_embedder(t_min)
509
+ t_max_embed = self.cfg_t_end_embedder(t_max)
510
+ y_embed = self.y_embedder(class_labels)
511
+
512
+ time_tokens = self.time_tokens + h_embed.unsqueeze(1)
513
+ omega_tokens = self.omega_tokens + omega_embed.unsqueeze(1)
514
+ t_min_tokens = self.t_min_tokens + t_min_embed.unsqueeze(1)
515
+ t_max_tokens = self.t_max_tokens + t_max_embed.unsqueeze(1)
516
+ class_tokens = self.class_tokens + y_embed.unsqueeze(1)
517
+
518
+ seq = torch.cat(
519
+ [class_tokens, omega_tokens, t_min_tokens, t_max_tokens, time_tokens, x_embed],
520
+ dim=1,
521
+ )
522
+ return seq + self.pos_embed
523
+
524
+ def _unpatchify(self, tokens: torch.Tensor) -> torch.Tensor:
525
+ batch_size = tokens.shape[0]
526
+ patch = self.patch_size
527
+ grid = int(tokens.shape[1] ** 0.5)
528
+ channels = self.out_channels
529
+ x = tokens.reshape(batch_size, grid, grid, patch, patch, channels)
530
+ x = torch.einsum("nhwpqc->nchpwq", x)
531
+ return x.reshape(batch_size, channels, grid * patch, grid * patch)
532
+
533
def forward(
    self,
    sample: torch.Tensor,
    timestep: torch.Tensor,
    class_labels: torch.Tensor,
    h: Optional[torch.Tensor] = None,
    omega: Optional[torch.Tensor] = None,
    guidance_interval_min: Optional[torch.Tensor] = None,
    guidance_interval_max: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> PMFTransformer2DOutput | Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run the shared trunk and the ``u`` / ``v`` heads on ``sample``.

    All conditioning values are normalised with ``_expand_batch`` to one
    value per sample on ``sample``'s device and dtype. Note that
    ``timestep`` is not embedded here: the sequence is built from ``h``,
    and ``timestep`` only enters through the final division.

    Args:
        sample: Input batch; ``sample.shape[0]`` is the batch size.
        timestep: Value whose clamped version (``min=self.t_clip_min``)
            divides ``sample - prediction``.
        class_labels: Labels forwarded to ``_build_sequence``.
        h: Conditioning value for the ``h`` embedder; defaults to the
            expanded ``timestep``.
        omega: Guidance value; defaults to ones.
        guidance_interval_min: Defaults to zeros.
        guidance_interval_max: Defaults to ones.
        return_dict: When false, return the plain tuple ``(u, v)``.

    Returns:
        ``PMFTransformer2DOutput(u=u, v=v)`` or ``(u, v)``. ``v`` is
        ``None`` when ``self.v_final_layer`` is ``None``.
    """
    num_samples = sample.shape[0]
    device, dtype = sample.device, sample.dtype

    def per_sample(value):
        # One conditioning value per batch element, on sample's device/dtype.
        return self._expand_batch(value, num_samples, device, dtype)

    timestep = per_sample(timestep)
    h = per_sample(timestep if h is None else h)

    if omega is None:
        omega = torch.ones(num_samples, device=device)
    omega = per_sample(omega)

    if guidance_interval_min is None:
        guidance_interval_min = torch.zeros(num_samples, device=device)
    guidance_interval_min = per_sample(guidance_interval_min)

    if guidance_interval_max is None:
        guidance_interval_max = torch.ones(num_samples, device=device)
    guidance_interval_max = per_sample(guidance_interval_max)

    hidden = self._build_sequence(
        sample, h, omega, guidance_interval_min, guidance_interval_max, class_labels
    )
    rope_freqs = self.rope_freqs.to(device=device)
    checkpointing = self.training and self.gradient_checkpointing

    def run_blocks(blocks, states):
        # Apply the blocks in order, using activation checkpointing when
        # training with gradient checkpointing enabled.
        for layer in blocks:
            if checkpointing:
                states = torch.utils.checkpoint.checkpoint(
                    layer, states, rope_freqs, use_reentrant=False
                )
            else:
                states = layer(states, rope_freqs)
        return states

    hidden = run_blocks(self.shared_blocks, hidden)
    # Both heads branch off the same trunk output.
    u_hidden = run_blocks(self.u_heads, hidden)
    v_hidden = run_blocks(self.v_heads, hidden)

    # Only the patch tokens (after the prefix tokens) are decoded.
    u_pred = self._unpatchify(self.u_final_layer(u_hidden[:, self.prefix_tokens :]))
    denom = torch.clamp(timestep.reshape(num_samples, 1, 1, 1), min=self.t_clip_min)
    u = (sample - u_pred) / denom

    v = None
    if self.v_final_layer is not None:
        v_pred = self._unpatchify(self.v_final_layer(v_hidden[:, self.prefix_tokens :]))
        v = (sample - v_pred) / denom

    if return_dict:
        return PMFTransformer2DOutput(u=u, v=v)
    return (u, v)
606
+
607
+ @staticmethod
608
+ def _expand_batch(
609
+ value: torch.Tensor,
610
+ batch_size: int,
611
+ device: torch.device,
612
+ dtype: torch.dtype,
613
+ ) -> torch.Tensor:
614
+ value = torch.as_tensor(value, device=device, dtype=dtype)
615
+ if value.ndim == 0:
616
+ value = value.reshape(1)
617
+ if value.shape[0] == 1 and batch_size > 1:
618
+ value = value.expand(batch_size)
619
+ return value.reshape(batch_size)
620
+
621
@classmethod
def from_pmf_checkpoint(
    cls,
    checkpoint_path: str,
    model_type: str | None = None,
    map_location: str = "cpu",
    strict: bool = False,
) -> Tuple["PMFTransformer2DModel", Dict[str, object]]:
    """Build a model from a checkpoint file and load its weights.

    The weights are taken from ``checkpoint["state_dict"]`` when the loaded
    object is a dict with that key, otherwise the loaded object itself is
    used. They are passed through ``remap_legacy_state_dict`` before being
    loaded. The model is always constructed with ``eval_mode=True``.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        model_type: Key into ``PMF_PRESET_CONFIGS``. When omitted, the first
            string found under the checkpoint keys ``"model_type"``,
            ``"model_str"`` or ``"model"`` is used. Names listed in
            ``LEGACY_MODEL_ALIASES`` are translated.
        map_location: Forwarded to ``torch.load``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        ``(model, metadata)`` where ``metadata`` holds ``checkpoint_path``
        and the resolved ``model_type``.

    Raises:
        ValueError: If ``model_type`` is not given and cannot be inferred.
        KeyError: If the resolved ``model_type`` has no preset config.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only use this with checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
    is_mapping = isinstance(checkpoint, dict)
    if is_mapping and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    if model_type is None and is_mapping:
        for key in ("model_type", "model_str", "model"):
            candidate = checkpoint.get(key)
            # Accept only strings: an entry such as "model" may hold a
            # weights dict, which would fail the (hash-based) alias lookup.
            if isinstance(candidate, str):
                model_type = candidate
                break
    if model_type in LEGACY_MODEL_ALIASES:
        model_type = LEGACY_MODEL_ALIASES[model_type]
    if model_type is None:
        raise ValueError("model_type is required when it cannot be inferred from the checkpoint.")

    try:
        preset = PMF_PRESET_CONFIGS[model_type]
    except KeyError:
        raise KeyError(
            f"Unknown model_type {model_type!r}; available presets: {list(PMF_PRESET_CONFIGS)}"
        ) from None

    config = dict(preset)
    config["model_type"] = model_type
    config["eval_mode"] = True
    model = cls(**config)
    model.load_state_dict(remap_legacy_state_dict(state_dict), strict=strict)
    metadata = {"checkpoint_path": checkpoint_path, "model_type": model_type}
    return model, metadata
652
+
653
def to_pmf_checkpoint(self, prefix: str = "net.") -> Dict[str, torch.Tensor]:
    """Export the model weights as a flat dict with prefixed keys.

    Args:
        prefix: String prepended to every ``state_dict`` key.

    Returns:
        A new dict mapping ``prefix + key`` to a detached CPU copy of each
        entry of ``self.state_dict()``.
    """
    return {
        f"{prefix}{name}": tensor.detach().cpu()
        for name, tensor in self.state_dict().items()
    }
658
+
659
@property
def net(self):
    """Return this model itself, so that ``model.net`` is ``model``."""
    # NOTE(review): presumably lets code written against a wrapper exposing
    # a ``.net`` attribute work with this class directly — confirm with callers.
    return self
662
+
663
+
664
# Alias: the same class is also available under the name PMFDiffusersModel.
# NOTE(review): presumably kept for code that imports the older name — confirm.
PMFDiffusersModel = PMFTransformer2DModel
pMF-L-16/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ tags:
6
+ - diffusers
7
+ - pmf
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ ---
13
+
14
+ # pMF-L-16
15
+
16
+ Self-contained Diffusers variant for **pMF-L/16** (Pixel Mean Flows).
17
+
18
+ Recommended settings: `guidance_scale=7.0`, guidance interval `[0.2, 0.7]` (`guidance_interval_min=0.2`, `guidance_interval_max=0.7`), `noise_scale=1.0`.
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ from pathlib import Path
24
+ from diffusers import DiffusionPipeline
25
+ import torch
26
+
27
+ model_dir = Path("./pMF-L-16")
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ str(model_dir),
30
+ local_files_only=True,
31
+ custom_pipeline=str(model_dir / "pipeline.py"),
32
+ trust_remote_code=True,
33
+ torch_dtype=torch.float32,
34
+ ).to("cuda")
35
+
36
+ image = pipe(
37
+ class_labels=207,
38
+ num_inference_steps=1,
39
+ guidance_scale=7.0,
40
+ guidance_interval_min=0.2,
41
+ guidance_interval_max=0.7,
42
+ noise_scale=1.0,
43
+ ).images[0]
44
+ ```
pMF-L-16/model_index.json ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PMFPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_pmf",
13
+ "PMFTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
pMF-L-16/pipeline.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Hub custom pipeline: PMFPipeline.
16
+
17
+ Load with native Hugging Face diffusers and trust_remote_code=True.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+
32
# Default classifier-free-guidance presets, keyed by the transformer's
# ``model_type``. Each row of the table below is
# (guidance_scale, guidance_interval_min, guidance_interval_max).
DEFAULT_CFG_BY_MODEL: Dict[str, Dict[str, float]] = {
    model_name: {
        "guidance_scale": scale,
        "guidance_interval_min": interval_min,
        "guidance_interval_max": interval_max,
    }
    for model_name, scale, interval_min, interval_max in (
        ("pMF-B/16", 7.5, 0.1, 0.8),
        ("pMF-B/32", 6.5, 0.1, 0.7),
        ("pMF-L/16", 7.0, 0.2, 0.7),
        ("pMF-L/32", 7.5, 0.2, 0.6),
        ("pMF-H/16", 7.0, 0.2, 0.6),
        ("pMF-H/32", 5.5, 0.1, 0.6),
    )
}
40
+
41
# Multiplier applied to the initial Gaussian sample when the caller does not
# pass ``noise_scale``; looked up by the transformer's ``model_type``.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
    # Base
    "pMF-B/16": 1.0,
    "pMF-B/32": 2.0,
    # Large
    "pMF-L/16": 1.0,
    "pMF-L/32": 4.0,
    # Huge
    "pMF-H/16": 2.0,
    "pMF-H/32": 4.0,
}
49
+
50
+
51
+ def _set_pmf_timesteps(
52
+ scheduler: FlowMatchEulerDiscreteScheduler,
53
+ num_inference_steps: int,
54
+ device: torch.device,
55
+ ) -> torch.Tensor:
56
+ r"""Set linear flow sigmas from 1.0 to 0.0 for pMF sampling."""
57
+ flow_sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=torch.float32)
58
+ scheduler.set_timesteps(sigmas=flow_sigmas.tolist(), device=device)
59
+ return flow_sigmas
60
+
61
+
62
class PMFPipeline(DiffusionPipeline):
    r"""
    Pipeline for ImageNet class-conditional generation with Pixel Mean Flows (pMF).

    Parameters:
        transformer ([`PMFTransformer2DModel`]):
            Class-conditioned pMF transformer that predicts mean-flow velocity.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Built-in flow-matching Euler scheduler. When `None`, a default
            scheduler (1000 train timesteps, shift 1.0, deterministic) is created.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Keys may be ints or
            numeric strings. When omitted, the mapping is read lazily from the
            `model_index.json` next to the loaded checkpoint.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        if scheduler is None:
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=1.0,
                stochastic_sampling=False,
            )
        self.register_modules(transformer=transformer, scheduler=scheduler)
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # False means "not loaded yet": the first label lookup will try model_index.json.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Populate the label tables from `model_index.json` if they are still empty."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with integer keys (JSON stores them as strings)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read `id2label` from `<variant_path>/model_index.json`.

        Returns an empty dict when the path is unset, the file does not exist,
        or the file has no dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return PMFPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Build a name -> class id lookup from comma-separated label strings.

        Every comma-separated name of a class becomes a key. The first name of a
        class is its primary label and always wins over a secondary synonym of
        another class (e.g. "missile" maps to the class labelled "missile", not
        to "projectile, missile"). Among names of equal rank the later class
        wins. The result is sorted by name.
        """
        primary_ids: Dict[str, int] = {}
        synonym_ids: Dict[str, int] = {}
        for class_id, value in id2label.items():
            names = [name.strip() for name in value.split(",")]
            names = [name for name in names if name]
            if not names:
                continue
            primary_ids[names[0]] = int(class_id)
            for name in names[1:]:
                synonym_ids[name] = int(class_id)
        # Primaries are merged last so they override clashing synonyms.
        label2id = {**synonym_ids, **primary_ids}
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """Class id -> label string mapping, loaded lazily if necessary."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        """Map one label name or a list of label names to class ids.

        Raises:
            ValueError: If no labels are available or any name is unknown.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
        labels = [label] if isinstance(label, str) else label
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
        """Convert ids and/or label names into a flat list of class ids.

        Accepts a single id, a single name, or a list that may mix both; names
        are resolved through `get_label_ids` and input order is preserved.
        """
        items = [class_labels] if isinstance(class_labels, (int, str)) else list(class_labels)
        names = [item for item in items if isinstance(item, str)]
        if not names:
            return items
        # Resolve all names in one call so an error lists every unknown label.
        name_ids = iter(self.get_label_ids(names))
        return [next(name_ids) if isinstance(item, str) else item for item in items]

    def _recommended_noise_scale(self) -> float:
        """Noise-scale preset for this model, falling back to a per-resolution default."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in RECOMMENDED_NOISE_BY_MODEL:
            return RECOMMENDED_NOISE_BY_MODEL[model_type]
        image_size = int(self.transformer.config.sample_size)
        return {256: 1.0, 512: 2.0}.get(image_size, 1.0)

    def _default_cfg(self) -> Dict[str, float]:
        """Guidance preset for this model as a fresh dict (callers may mutate it)."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in DEFAULT_CFG_BY_MODEL:
            return dict(DEFAULT_CFG_BY_MODEL[model_type])
        return {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8}

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        num_inference_steps: int = 1,
        guidance_scale: Optional[float] = None,
        guidance_interval_min: Optional[float] = None,
        guidance_interval_max: Optional[float] = None,
        noise_scale: Optional[float] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with pMF.

        Args:
            class_labels (`int`, `str`, or `list`):
                ImageNet class id(s) or label name(s). Lists may mix ids and names.
            num_inference_steps (`int`, *optional*, defaults to 1):
                Number of flow steps. pMF is typically used with 1 step.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. Defaults to model-specific preset.
            guidance_interval_min (`float`, *optional*):
                Lower bound of the CFG interval in normalized time.
            guidance_interval_max (`float`, *optional*):
                Upper bound of the CFG interval in normalized time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale. Defaults to model-specific preset.
            generator (`torch.Generator`, *optional*):
                Random generator for reproducibility.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: `"pil"`, `"np"`, or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`].

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                Generated images.

        Raises:
            ValueError: If `num_inference_steps < 1`, `output_type` is not
                supported, or a label name cannot be resolved.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")

        defaults = self._default_cfg()
        if guidance_scale is None:
            guidance_scale = defaults["guidance_scale"]
        if guidance_interval_min is None:
            guidance_interval_min = defaults["guidance_interval_min"]
        if guidance_interval_max is None:
            guidance_interval_max = defaults["guidance_interval_max"]
        if noise_scale is None:
            noise_scale = self._recommended_noise_scale()

        class_label_ids = self._normalize_class_labels(class_labels)
        batch_size = len(class_label_ids)
        transformer_config = self.transformer.config
        image_size = int(transformer_config.sample_size)
        channels = int(transformer_config.in_channels)
        # Either config field may be missing or null; fall back to ImageNet-1k.
        null_class_val = int(
            getattr(transformer_config, "num_classes", None)
            or getattr(transformer_config, "num_class_embeds", None)
            or 1000
        )

        latents = randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        ) * noise_scale

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # Out-of-range ids are clamped into [0, null_class_val - 1] rather than rejected.
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)

        device = latents.device
        dtype = latents.dtype
        omega = torch.full((batch_size,), guidance_scale, device=device, dtype=dtype)
        t_min = torch.full((batch_size,), guidance_interval_min, device=device, dtype=dtype)
        t_max = torch.full((batch_size,), guidance_interval_max, device=device, dtype=dtype)

        flow_sigmas = _set_pmf_timesteps(self.scheduler, num_inference_steps, device)

        for step_index in self.progress_bar(range(num_inference_steps)):
            t = flow_sigmas[step_index]
            t_next = flow_sigmas[step_index + 1]
            # h is the width of the current time interval, broadcast over the batch.
            h = (t - t_next).expand(batch_size).to(device=device, dtype=dtype)
            t_batch = t.expand(batch_size).to(device=device, dtype=dtype)

            output = self.transformer(
                sample=latents,
                timestep=t_batch,
                class_labels=class_labels_t,
                h=h,
                omega=omega,
                guidance_interval_min=t_min,
                guidance_interval_max=t_max,
                return_dict=True,
            )
            latents = self.scheduler.step(output.u, self.scheduler.timesteps[step_index], latents).prev_sample

        # Map samples from [-1, 1] to [0, 1] on the CPU before format conversion.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)


# Alias so callers can refer to the pipeline's output type by a pMF-specific name.
PMFPipelineOutput = ImagePipelineOutput
pMF-L-16/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
pMF-L-16/transformer/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PMFTransformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "aux_head_depth": 8,
5
+ "bottleneck_dim": 128,
6
+ "depth": 32,
7
+ "embedding_init_constant": 1.0,
8
+ "eval_mode": true,
9
+ "hidden_size": 1024,
10
+ "in_channels": 3,
11
+ "mlp_ratio": 2.6666666666666665,
12
+ "model_type": "pMF-L/16",
13
+ "norm_eps": 1e-06,
14
+ "num_attention_heads": 16,
15
+ "num_cfg_tokens": 4,
16
+ "num_class_embeds": null,
17
+ "num_class_tokens": 8,
18
+ "num_classes": 1000,
19
+ "num_interval_tokens": 2,
20
+ "num_time_tokens": 4,
21
+ "patch_size": 16,
22
+ "sample_size": 256,
23
+ "t_clip_min": 0.05,
24
+ "token_init_constant": 1.0,
25
+ "weight_init_constant": 0.32
26
+ }
pMF-L-16/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72dfc4b443595b9ad8ce8aa9aa8566de269c0cafb440a0d2086800fae6b3ea20
3
+ size 1641329160
pMF-L-16/transformer/transformer_pmf.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from math import sqrt
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.normalization import RMSNorm
14
+ from diffusers.utils import BaseOutput
15
+
16
+
17
# Architecture presets keyed by "pMF-<size>/<patch>". Each size letter fixes
# (hidden_size, depth, num_attention_heads, bottleneck_dim); each patch size is
# paired with one sample size. All presets share aux_head_depth=8.
PMF_PRESET_CONFIGS: Dict[str, Dict[str, object]] = {
    f"pMF-{size}/{patch_size}": {
        "sample_size": sample_size,
        "patch_size": patch_size,
        "hidden_size": hidden_size,
        "depth": depth,
        "num_attention_heads": num_heads,
        "bottleneck_dim": bottleneck_dim,
        "aux_head_depth": 8,
    }
    for size, hidden_size, depth, num_heads, bottleneck_dim in (
        ("B", 768, 16, 12, 128),
        ("L", 1024, 32, 16, 128),
        ("H", 1280, 48, 16, 256),
    )
    for patch_size, sample_size in ((16, 256), (32, 512))
}
73
+
74
+ RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
75
+ "pMF-B/16": 1.0,
76
+ "pMF-B/32": 2.0,
77
+ "pMF-L/16": 1.0,
78
+ "pMF-L/32": 4.0,
79
+ "pMF-H/16": 2.0,
80
+ "pMF-H/32": 4.0,
81
+ }
82
+
83
+ # Legacy torch repo keys (pmfDiT_*)
84
+ LEGACY_MODEL_ALIASES: Dict[str, str] = {
85
+ "pmfDiT_B_16": "pMF-B/16",
86
+ "pmfDiT_B_32": "pMF-B/32",
87
+ "pmfDiT_L_16": "pMF-L/16",
88
+ "pmfDiT_L_32": "pMF-L/32",
89
+ "pmfDiT_H_16": "pMF-H/16",
90
+ "pmfDiT_H_32": "pMF-H/32",
91
+ }
92
+
93
+
94
@dataclass
class PMFTransformer2DOutput(BaseOutput):
    """Output container returned by ``PMFTransformer2DModel.forward``."""

    # Result computed from the ``u`` head; always populated.
    u: torch.Tensor
    # Result computed from the ``v`` head; ``None`` when the model was built
    # without a ``v`` final layer.
    v: Optional[torch.Tensor] = None
98
+
99
+
100
def remap_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Translate legacy checkpoint keys into native PMFTransformer2DModel keys.

    For every key: at most one leading wrapper prefix (``"transformer."`` or
    ``"net."``, first match wins) is removed, the ``"._flax_linear"`` and
    ``"._flax_embedding"`` wrapper segments are dropped, and an entry whose
    resulting key is exactly ``"rope_freqs"`` is discarded.

    Args:
        state_dict: Mapping from legacy parameter names to values.

    Returns:
        A new dict with renamed keys; the input mapping is left untouched.
    """
    wrapper_prefixes = ("transformer.", "net.")
    # Official PyTorch checkpoints use TorchLinear/TorchEmbedding wrappers.
    wrapper_segments = ("._flax_linear", "._flax_embedding")

    def _rename(name: str) -> str:
        stripped = next(
            (name[len(prefix) :] for prefix in wrapper_prefixes if name.startswith(prefix)),
            name,
        )
        for segment in wrapper_segments:
            stripped = stripped.replace(segment, "")
        return stripped

    renamed = ((_rename(name), tensor) for name, tensor in state_dict.items())
    return {name: tensor for name, tensor in renamed if name != "rope_freqs"}
115
+
116
+
117
def config_from_legacy(config: Dict[str, object]) -> Dict[str, object]:
    """Build native constructor kwargs from a legacy ``config.json`` dict.

    The preset name is read from ``model_type``, ``model_name`` or
    ``model_str`` (first truthy value), translated through
    ``LEGACY_MODEL_ALIASES`` and used to copy the matching preset. The class
    count comes from ``num_class_embeds`` or ``num_classes`` (default 1000).
    ``sample_size`` and ``eval_mode`` are carried over when present.

    Raises:
        ValueError: If the resolved name is not a known preset.
    """
    raw_name = config.get("model_type") or config.get("model_name") or config.get("model_str")
    model_type = LEGACY_MODEL_ALIASES.get(raw_name, raw_name)
    if model_type not in PMF_PRESET_CONFIGS:
        raise ValueError(f"Unknown pMF preset '{model_type}'. Known: {list(PMF_PRESET_CONFIGS)}")

    class_count = config.get("num_class_embeds") or config.get("num_classes") or 1000
    native = {
        **PMF_PRESET_CONFIGS[model_type],
        "num_classes": int(class_count),
        "model_type": model_type,
    }
    # Optional overrides, coerced to the type the constructor expects.
    for field, coerce in (("sample_size", int), ("eval_mode", bool)):
        override = config.get(field)
        if override is not None:
            native[field] = coerce(override)
    return native
133
+
134
+
135
def _scaled_linear(
    in_features: int,
    out_features: int,
    *,
    bias: bool = True,
    weight_init: str = "scaled_variance",
    init_constant: float = 1.0,
    bias_init: str = "zeros",
) -> nn.Linear:
    """Create an ``nn.Linear`` with the initialisation used across this file.

    Args:
        in_features: Input width.
        out_features: Output width.
        bias: Whether the layer has a bias term.
        weight_init: ``"scaled_variance"`` draws weights from a normal with
            ``std = init_constant / sqrt(in_features)``; ``"zeros"`` zeroes them.
        init_constant: Numerator of the standard deviation above.
        bias_init: Only ``"zeros"`` is supported; checked only when ``bias``
            is true.

    Raises:
        ValueError: For an unsupported ``weight_init`` or ``bias_init``.
    """
    linear = nn.Linear(in_features, out_features, bias=bias)

    if weight_init == "zeros":
        nn.init.zeros_(linear.weight)
    elif weight_init == "scaled_variance":
        nn.init.normal_(linear.weight, std=init_constant / sqrt(in_features))
    else:
        raise ValueError(f"Invalid weight_init: {weight_init}")

    if not bias:
        return linear
    if bias_init != "zeros":
        raise ValueError(f"Invalid bias_init: {bias_init}")
    nn.init.zeros_(linear.bias)
    return linear
159
+
160
+
161
class PMFTimestepEmbedder(nn.Module):
    """Embed a per-sample scalar with sinusoidal features followed by an MLP.

    Args:
        hidden_size: Output width of the embedding.
        frequency_embedding_size: Width of the sinusoidal feature vector fed
            to the MLP.
        init_constant: Scale passed to ``_scaled_linear`` for both MLP layers.
    """

    # Dtypes it is safe to cast the features to. Low-precision *storage*
    # dtypes (e.g. float8) are deliberately left out so the features are not
    # destroyed when weights are only stored, not computed, in that dtype.
    _COMPUTE_DTYPES = (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        init_constant: float = 1.0,
    ):
        super().__init__()
        init_kwargs = dict(
            out_features=hidden_size,
            bias=True,
            weight_init="scaled_variance",
            init_constant=init_constant,
            bias_init="zeros",
        )
        self.mlp = nn.Sequential(
            _scaled_linear(frequency_embedding_size, **init_kwargs),
            nn.SiLU(),
            _scaled_linear(hidden_size, **init_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return float32 sinusoidal features ``[cos | sin]`` of width ``dim``.

        Args:
            t: 1-D tensor with one value per sample.
            dim: Feature width; when odd, a trailing zero column is appended.
            max_period: Controls the lowest frequency of the geometric ladder.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` (one value per sample) into ``hidden_size`` features."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # The features are always float32. Align them with the MLP weights so
        # the module also works when the model runs in fp16/bf16 (a float32
        # input against half-precision weights raises a dtype-mismatch error).
        weight_dtype = self.mlp[0].weight.dtype
        if weight_dtype in self._COMPUTE_DTYPES:
            t_freq = t_freq.to(weight_dtype)
        return self.mlp(t_freq)
200
+
201
+
202
class PMFLabelEmbedder(nn.Module):
    """Lookup table turning integer class labels into ``hidden_size`` vectors.

    The table has ``num_classes + 1`` rows, i.e. one row beyond the regular
    class indices (presumably a "no class" entry; confirm with the caller).
    Rows are drawn from a normal with ``std = init_constant / sqrt(hidden_size)``.
    """

    def __init__(self, num_classes: int, hidden_size: int, init_constant: float = 1.0):
        super().__init__()
        table = nn.Embedding(num_classes + 1, hidden_size)
        row_std = init_constant / sqrt(hidden_size)
        nn.init.normal_(table.weight, std=row_std)
        self.embedding_table = table

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return the table rows selected by ``labels``."""
        return self.embedding_table(labels)
210
+
211
+
212
class PMFBottleneckPatchEmbedder(nn.Module):
    """Patchify an image through a low-width bottleneck.

    ``proj1`` is a strided convolution (kernel = stride = ``patch_size``)
    mapping ``in_channels`` to ``pca_channels``; ``proj2`` is a 1x1
    convolution lifting that to ``hidden_size``. Both weights are re-drawn
    uniformly in ``[-limit, limit]`` with
    ``limit = sqrt(6 / (fan_in + fan_out))`` and biases are zeroed.
    """

    def __init__(
        self,
        input_size: int,
        patch_size: int,
        pca_channels: int,
        in_channels: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        self.num_patches = (input_size // patch_size) ** 2
        self.proj1 = nn.Conv2d(
            in_channels,
            pca_channels,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )
        self.proj2 = nn.Conv2d(pca_channels, hidden_size, kernel_size=1, stride=1, bias=bias)

        # (layer, fan_in, fan_out); note fan_out is the plain output channel
        # count for both layers, not multiplied by the kernel area.
        fan_table = (
            (self.proj1, patch_size * patch_size * in_channels, pca_channels),
            (self.proj2, pca_channels, hidden_size),
        )
        for conv, fan_in, fan_out in fan_table:
            bound = math.sqrt(6.0 / (fan_in + fan_out))
            nn.init.uniform_(conv.weight, -bound, bound)
        if bias:
            for conv in (self.proj1, self.proj2):
                nn.init.zeros_(conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens with the flattened spatial grid on dim 1."""
        features = self.proj1(x)
        features = self.proj2(features)
        return torch.flatten(features, start_dim=2).permute(0, 2, 1)
250
+
251
+
252
def precompute_rope_freqs(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute 2-D rotary phases for a square grid of ``seq_len`` tokens.

    Half of the complex features encode the row index and the other half
    the column index of each token (row-major token order).

    Args:
        dim: Per-head feature width; the result has ``dim // 2`` complex
            features per token.
        seq_len: Number of grid tokens; the grid side is ``int(seq_len ** 0.5)``.
        theta: Base of the geometric frequency ladder.

    Returns:
        Complex tensor of shape ``(seq_len, dim // 2)`` with unit modulus.
    """
    axis_dim = dim // 2
    side = int(seq_len**0.5)

    exponents = torch.arange(0, axis_dim, 2, dtype=torch.float32) / axis_dim
    inv_freq = 1.0 / (theta**exponents)
    coords = torch.arange(side, dtype=torch.float32)
    # Outer product: angle[i, j] = coords[i] * inv_freq[j].
    axis_angles = coords[:, None] * inv_freq[None, :]

    row_angles = axis_angles[:, None, :].expand(side, side, -1)
    col_angles = axis_angles[None, :, :].expand(side, side, -1)
    angles = torch.cat([row_angles, col_angles], dim=-1).reshape(seq_len, axis_dim)
    return torch.complex(torch.cos(angles), torch.sin(angles))
269
+
270
+
271
def apply_rotary_pos_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the trailing tokens of ``x`` by the complex phases in ``freqs_cis``.

    Adjacent feature pairs of ``x`` are treated as complex numbers. Only the
    last ``freqs_cis.shape[0]`` positions along dim 1 are multiplied by the
    phases; earlier positions pass through unchanged. The computation runs
    in float32 and the result is cast back to ``x.dtype``.

    Args:
        x: Tensor whose dim 1 is the token axis and whose last dim is even;
            the caller in this file passes ``(batch, tokens, heads, head_dim)``.
        freqs_cis: Complex phases, one row per rotated token.
    """
    x32 = x.float()
    pairs = x32.reshape(*x32.shape[:-1], -1, 2).contiguous()
    as_complex = torch.view_as_complex(pairs)

    # Insert broadcast axes for batch (dim 0) and heads (dim 2).
    phases = freqs_cis[None, :, None]
    n_rotated = phases.shape[1]

    rotated = as_complex.clone()
    rotated[:, -n_rotated:, :] = as_complex[:, -n_rotated:, :] * phases
    return torch.view_as_real(rotated).flatten(-2).to(x.dtype)
280
+
281
+
282
class PMFAttention(nn.Module):
    """Multi-head self-attention with per-head RMS-normalised, rotary q/k.

    Args:
        hidden_size: Model width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        weight_init_constant: Scale passed to ``_scaled_linear`` for all four
            bias-free projections.
        eps: Epsilon of the q/k ``RMSNorm`` layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        init_kwargs = dict(
            bias=False,
            weight_init="scaled_variance",
            init_constant=weight_init_constant,
        )
        self.q_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.k_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.v_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.out_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.q_norm = RMSNorm(self.head_dim, eps=eps)
        self.k_norm = RMSNorm(self.head_dim, eps=eps)

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Attend over all tokens of ``x``.

        Args:
            x: Input of shape ``(batch, tokens, hidden_size)``.
            rope_freqs: Complex rotary phases; applied to the trailing
                ``rope_freqs.shape[0]`` tokens of q and k.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len, channels = x.shape
        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q = apply_rotary_pos_emb(q, rope_freqs)
        k = apply_rotary_pos_emb(k, rope_freqs)

        query = q / math.sqrt(self.head_dim)
        attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, k)
        # Softmax is evaluated in float32 for numerical stability, then cast
        # back to the value dtype: contracting float32 weights with fp16/bf16
        # values would otherwise fail with a dtype mismatch. No-op in float32.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        attn = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        attn = attn.reshape(batch_size, seq_len, channels)
        return self.out_proj(attn)
322
+
323
+
324
class PMFSwiGLUMlp(nn.Module):
    """Bias-free SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the gated hidden layer.
        weight_init_constant: Scale passed to ``_scaled_linear`` for all
            three projections.
    """

    def __init__(self, dim: int, hidden_dim: int, weight_init_constant: float = 0.32):
        super().__init__()

        def _projection(n_in: int, n_out: int) -> nn.Linear:
            return _scaled_linear(
                n_in,
                n_out,
                bias=False,
                weight_init="scaled_variance",
                init_constant=weight_init_constant,
            )

        self.w1 = _projection(dim, hidden_dim)
        self.w3 = _projection(dim, hidden_dim)
        self.w2 = _projection(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
334
+
335
+
336
class PMFTransformerBlock(nn.Module):
    """Pre-norm transformer block with learned per-channel residual scales.

    Both residual branches (attention and SwiGLU MLP) are multiplied by a
    ``hidden_size`` vector initialised to zero, so a freshly constructed
    block returns its input unchanged.

    Args:
        hidden_size: Model width.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``; for
            ``hidden_size > 1024`` the result is rounded up to a multiple of 8.
        weight_init_constant: Init scale forwarded to attention and MLP.
        eps: Epsilon of the two ``RMSNorm`` layers (and the q/k norms).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 8 / 3,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)
        if hidden_size > 1024:
            # Ceiling division: round up to the next multiple of 8.
            inner_dim = -(-inner_dim // 8) * 8

        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = PMFAttention(hidden_size, num_heads, weight_init_constant=weight_init_constant, eps=eps)
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = PMFSwiGLUMlp(hidden_size, inner_dim, weight_init_constant=weight_init_constant)
        self.attn_scale = nn.Parameter(torch.zeros(hidden_size))
        self.mlp_scale = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x), rope_freqs)
        x = x + attn_out * self.attn_scale
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out * self.mlp_scale
360
+
361
+
362
class PMFFinalLayer(nn.Module):
    """RMS-normalise tokens and project each to a flattened output patch.

    The projection has ``patch_size * patch_size * out_channels`` outputs
    and starts with zero weights and zero bias.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, eps: float = 1e-6):
        super().__init__()
        patch_dim = patch_size * patch_size * out_channels
        self.norm = RMSNorm(hidden_size, eps=eps)
        self.linear = _scaled_linear(
            hidden_size,
            patch_dim,
            bias=True,
            weight_init="zeros",
            bias_init="zeros",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        return self.linear(normed)
376
+
377
+
378
class PMFTransformer2DModel(ModelMixin, ConfigMixin):
    """Native diffusers implementation of the pMF DiT backbone.

    The input sequence is a learned prefix (class, omega, guidance-interval
    start, guidance-interval end and time tokens, in that order) followed by
    the patch tokens. It runs through ``depth - aux_head_depth`` shared
    blocks and then through the ``u`` head blocks. A parallel ``v`` head is
    only built when ``eval_mode`` is ``False``.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "rope_freqs"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        depth: int = 16,
        num_attention_heads: int = 12,
        mlp_ratio: float = 8 / 3,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        aux_head_depth: int = 8,
        num_class_tokens: int = 8,
        num_time_tokens: int = 4,
        num_cfg_tokens: int = 4,
        num_interval_tokens: int = 2,
        token_init_constant: float = 1.0,
        embedding_init_constant: float = 1.0,
        weight_init_constant: float = 0.32,
        eval_mode: bool = True,
        model_type: str | None = None,
        num_class_embeds: int | None = None,
        t_clip_min: float = 0.05,
        norm_eps: float = 1e-6,
    ):
        """Build the backbone.

        ``num_class_embeds``, when given, replaces ``num_classes``.
        ``model_type`` may be a preset name or a legacy alias.

        NOTE: when ``model_type`` names a known preset, the preset values for
        ``sample_size``, ``patch_size``, ``hidden_size``, ``depth``,
        ``num_attention_heads``, ``bottleneck_dim`` and ``aux_head_depth``
        replace whatever was passed for those arguments.
        """
        super().__init__()
        if num_class_embeds is not None:
            num_classes = int(num_class_embeds)
        if model_type in LEGACY_MODEL_ALIASES:
            model_type = LEGACY_MODEL_ALIASES[model_type]
        if model_type in PMF_PRESET_CONFIGS:
            preset = PMF_PRESET_CONFIGS[model_type]
            sample_size = int(preset["sample_size"])
            patch_size = int(preset["patch_size"])
            hidden_size = int(preset["hidden_size"])
            depth = int(preset["depth"])
            num_attention_heads = int(preset["num_attention_heads"])
            bottleneck_dim = int(preset["bottleneck_dim"])
            aux_head_depth = int(preset["aux_head_depth"])

        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_attention_heads = num_attention_heads
        self.aux_head_depth = aux_head_depth
        self.num_class_tokens = num_class_tokens
        self.num_time_tokens = num_time_tokens
        self.num_cfg_tokens = num_cfg_tokens
        self.num_interval_tokens = num_interval_tokens
        # Interval tokens are counted twice: one set for t_min, one for t_max.
        self.prefix_tokens = (
            num_class_tokens + num_cfg_tokens + 2 * num_interval_tokens + num_time_tokens
        )
        self.t_clip_min = t_clip_min
        self.eval_mode = eval_mode
        self.gradient_checkpointing = False

        self.x_embedder = PMFBottleneckPatchEmbedder(
            sample_size,
            patch_size,
            bottleneck_dim,
            in_channels,
            hidden_size,
            bias=True,
        )
        embed_kwargs = dict(hidden_size=hidden_size, init_constant=embedding_init_constant)
        self.h_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.omega_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.cfg_t_start_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.cfg_t_end_embedder = PMFTimestepEmbedder(**embed_kwargs)
        self.y_embedder = PMFLabelEmbedder(num_classes, hidden_size, init_constant=embedding_init_constant)

        token_std = token_init_constant / math.sqrt(hidden_size)
        self.time_tokens = nn.Parameter(torch.randn(1, num_time_tokens, hidden_size) * token_std)
        self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, hidden_size) * token_std)
        self.omega_tokens = nn.Parameter(torch.randn(1, num_cfg_tokens, hidden_size) * token_std)
        self.t_min_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
        self.t_max_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)

        total_tokens = self.x_embedder.num_patches + self.prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, total_tokens, hidden_size) * 0.02)

        # Rotary phases cover the patch tokens only and are recomputed at
        # construction time, so the buffer is not saved in checkpoints.
        head_dim = hidden_size // num_attention_heads
        self.register_buffer(
            "rope_freqs",
            precompute_rope_freqs(head_dim, self.x_embedder.num_patches),
            persistent=False,
        )

        shared_depth = depth - aux_head_depth
        block_kwargs = dict(
            hidden_size=hidden_size,
            num_heads=num_attention_heads,
            mlp_ratio=mlp_ratio,
            weight_init_constant=weight_init_constant,
            eps=norm_eps,
        )
        self.shared_blocks = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(shared_depth)])
        self.u_heads = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth)])
        # The v head (blocks and final layer) is skipped entirely in eval mode.
        self.v_heads = nn.ModuleList(
            [PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth if not eval_mode else 0)]
        )
        self.u_final_layer = PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
        self.v_final_layer = (
            PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
            if not eval_mode
            else None
        )

    def _build_sequence(
        self,
        sample: torch.Tensor,
        h: torch.Tensor,
        omega: torch.Tensor,
        t_min: torch.Tensor,
        t_max: torch.Tensor,
        class_labels: torch.Tensor,
    ) -> torch.Tensor:
        """Assemble prefix and patch tokens and add the positional embedding.

        Each conditioning embedding is broadcast-added to its group of
        learned tokens. ``omega`` is embedded as ``1 - 1 / omega``.
        """
        x_embed = self.x_embedder(sample)
        h_embed = self.h_embedder(h)
        omega_embed = self.omega_embedder(1 - 1 / omega)
        t_min_embed = self.cfg_t_start_embedder(t_min)
        t_max_embed = self.cfg_t_end_embedder(t_max)
        y_embed = self.y_embedder(class_labels)

        time_tokens = self.time_tokens + h_embed.unsqueeze(1)
        omega_tokens = self.omega_tokens + omega_embed.unsqueeze(1)
        t_min_tokens = self.t_min_tokens + t_min_embed.unsqueeze(1)
        t_max_tokens = self.t_max_tokens + t_max_embed.unsqueeze(1)
        class_tokens = self.class_tokens + y_embed.unsqueeze(1)

        seq = torch.cat(
            [class_tokens, omega_tokens, t_min_tokens, t_max_tokens, time_tokens, x_embed],
            dim=1,
        )
        return seq + self.pos_embed

    def _unpatchify(self, tokens: torch.Tensor) -> torch.Tensor:
        """Fold ``(batch, grid * grid, patch * patch * channels)`` tokens into an image."""
        batch_size = tokens.shape[0]
        patch = self.patch_size
        grid = int(tokens.shape[1] ** 0.5)
        channels = self.out_channels
        x = tokens.reshape(batch_size, grid, grid, patch, patch, channels)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(batch_size, channels, grid * patch, grid * patch)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        h: Optional[torch.Tensor] = None,
        omega: Optional[torch.Tensor] = None,
        guidance_interval_min: Optional[torch.Tensor] = None,
        guidance_interval_max: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> PMFTransformer2DOutput | Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the backbone.

        Args:
            sample: Image batch; dim 0 is the batch dimension.
            timestep: Scalar or per-sample value ``t``.
            class_labels: Integer labels for the label embedder.
            h: Value fed to the time tokens; defaults to ``timestep``.
            omega: Value fed to the omega tokens; defaults to 1.
            guidance_interval_min: Interval start; defaults to 0.
            guidance_interval_max: Interval end; defaults to 1.
            return_dict: Return ``PMFTransformer2DOutput`` when true,
                otherwise the tuple ``(u, v)``.

        Returns:
            ``u = (sample - u_pred) / clamp(t, min=t_clip_min)`` and the
            analogous ``v``, which is ``None`` when the model has no v head.
        """
        batch_size = sample.shape[0]
        timestep = self._expand_batch(timestep, batch_size, sample.device, sample.dtype)
        h = self._expand_batch(h if h is not None else timestep, batch_size, sample.device, sample.dtype)
        omega = self._expand_batch(
            omega if omega is not None else torch.ones(batch_size, device=sample.device),
            batch_size,
            sample.device,
            sample.dtype,
        )
        guidance_interval_min = self._expand_batch(
            guidance_interval_min
            if guidance_interval_min is not None
            else torch.zeros(batch_size, device=sample.device),
            batch_size,
            sample.device,
            sample.dtype,
        )
        guidance_interval_max = self._expand_batch(
            guidance_interval_max
            if guidance_interval_max is not None
            else torch.ones(batch_size, device=sample.device),
            batch_size,
            sample.device,
            sample.dtype,
        )

        seq = self._build_sequence(sample, h, omega, guidance_interval_min, guidance_interval_max, class_labels)
        rope_freqs = self.rope_freqs.to(device=sample.device)

        for block in self.shared_blocks:
            if self.training and self.gradient_checkpointing:
                seq = torch.utils.checkpoint.checkpoint(block, seq, rope_freqs, use_reentrant=False)
            else:
                seq = block(seq, rope_freqs)

        # Both heads start from the same shared features.
        u_seq = v_seq = seq
        for block in self.u_heads:
            if self.training and self.gradient_checkpointing:
                u_seq = torch.utils.checkpoint.checkpoint(block, u_seq, rope_freqs, use_reentrant=False)
            else:
                u_seq = block(u_seq, rope_freqs)

        # Empty in eval mode, so this loop is then a no-op.
        for block in self.v_heads:
            if self.training and self.gradient_checkpointing:
                v_seq = torch.utils.checkpoint.checkpoint(block, v_seq, rope_freqs, use_reentrant=False)
            else:
                v_seq = block(v_seq, rope_freqs)

        # Drop the prefix tokens; only patch tokens are decoded.
        u_tokens = u_seq[:, self.prefix_tokens :]
        u_pred = self._unpatchify(self.u_final_layer(u_tokens))
        t = timestep.reshape(batch_size, 1, 1, 1)
        # Clamp t from below so the division stays bounded near t = 0.
        u = (sample - u_pred) / torch.clamp(t, min=self.t_clip_min)

        v = None
        if self.v_final_layer is not None:
            v_tokens = v_seq[:, self.prefix_tokens :]
            v_pred = self._unpatchify(self.v_final_layer(v_tokens))
            v = (sample - v_pred) / torch.clamp(t, min=self.t_clip_min)

        if not return_dict:
            return (u, v)
        return PMFTransformer2DOutput(u=u, v=v)

    @staticmethod
    def _expand_batch(
        value: torch.Tensor,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Coerce ``value`` to a ``(batch_size,)`` tensor on ``device``/``dtype``.

        Scalars and length-1 tensors are broadcast across the batch.
        """
        value = torch.as_tensor(value, device=device, dtype=dtype)
        if value.ndim == 0:
            value = value.reshape(1)
        if value.shape[0] == 1 and batch_size > 1:
            value = value.expand(batch_size)
        return value.reshape(batch_size)

    @classmethod
    def from_pmf_checkpoint(
        cls,
        checkpoint_path: str,
        model_type: str | None = None,
        map_location: str = "cpu",
        strict: bool = False,
    ) -> Tuple["PMFTransformer2DModel", Dict[str, object]]:
        """Build an eval-mode model from a legacy pMF checkpoint file.

        Args:
            checkpoint_path: Path handed to ``torch.load``.
            model_type: Preset name or legacy alias. When ``None`` it is read
                from the checkpoint's ``model_type``, ``model_str`` or
                ``model`` entry (first one holding a string).
            map_location: Forwarded to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``(model, metadata)`` where ``metadata`` holds the checkpoint
            path and the resolved preset name.

        Raises:
            ValueError: If ``model_type`` is not given and cannot be inferred.
            KeyError: If the resolved ``model_type`` is not a known preset.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        if model_type is None and isinstance(checkpoint, dict):
            for key in ("model_type", "model_str", "model"):
                candidate = checkpoint.get(key)
                # Only a string can name a preset; entries such as "model"
                # may hold non-string objects (e.g. nested weights).
                if isinstance(candidate, str):
                    model_type = candidate
                    break
        if model_type is None:
            raise ValueError("model_type is required when it cannot be inferred from the checkpoint.")
        model_type = LEGACY_MODEL_ALIASES.get(model_type, model_type)
        if model_type not in PMF_PRESET_CONFIGS:
            raise KeyError(f"Unknown pMF preset '{model_type}'. Known: {list(PMF_PRESET_CONFIGS)}")

        config = dict(PMF_PRESET_CONFIGS[model_type])
        config["model_type"] = model_type
        config["eval_mode"] = True
        model = cls(**config)
        model.load_state_dict(remap_legacy_state_dict(state_dict), strict=strict)
        metadata = {"checkpoint_path": checkpoint_path, "model_type": model_type}
        return model, metadata

    def to_pmf_checkpoint(self, prefix: str = "net.") -> Dict[str, torch.Tensor]:
        """Return this model's state dict on CPU with ``prefix`` prepended to every key."""
        return {f"{prefix}{key}": value.detach().cpu() for key, value in self.state_dict().items()}

    @property
    def net(self):
        """Return the model itself, so ``model.net`` resolves to ``model``."""
        return self


# Alternate name for the same class.
PMFDiffusersModel = PMFTransformer2DModel
pMF-L-32/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ tags:
6
+ - diffusers
7
+ - pmf
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ ---
13
+
14
+ # pMF-L-32
15
+
16
+ Self-contained Diffusers variant for **pMF-L/32** (Pixel Mean Flows).
17
+
18
+ Recommended settings: `guidance_scale=7.5`, interval `[0.2, 0.6]`, `noise_scale=4.0`.
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ from pathlib import Path
24
+ from diffusers import DiffusionPipeline
25
+ import torch
26
+
27
+ model_dir = Path("./pMF-L-32")
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ str(model_dir),
30
+ local_files_only=True,
31
+ custom_pipeline=str(model_dir / "pipeline.py"),
32
+ trust_remote_code=True,
33
+ torch_dtype=torch.float32,
34
+ ).to("cuda")
35
+
36
+ image = pipe(
37
+ class_labels=207,
38
+ num_inference_steps=1,
39
+ guidance_scale=7.5,
40
+ guidance_interval_min=0.2,
41
+ guidance_interval_max=0.6,
42
+ noise_scale=4.0,
43
+ ).images[0]
44
+ ```
pMF-L-32/demo.png ADDED

Git LFS Details

  • SHA256: 2cb1e80baaa93a0018f39aad046fe3cf4cbfda23dc1162d48d774ff2a5618079
  • Pointer size: 131 Bytes
  • Size of remote file: 306 kB
pMF-L-32/model_index.json ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PMFPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_pmf",
13
+ "PMFTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
pMF-L-32/pipeline.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Hub custom pipeline: PMFPipeline.
16
+
17
+ Load with native Hugging Face diffusers and trust_remote_code=True.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+
32
# Classifier-free-guidance presets keyed by the transformer config's `model_type`.
# `PMFPipeline._default_cfg` returns a copy of the matching entry when the caller
# leaves `guidance_scale` / `guidance_interval_min` / `guidance_interval_max` unset.
DEFAULT_CFG_BY_MODEL: Dict[str, Dict[str, float]] = {
    "pMF-B/16": {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8},
    "pMF-B/32": {"guidance_scale": 6.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.7},
    "pMF-L/16": {"guidance_scale": 7.0, "guidance_interval_min": 0.2, "guidance_interval_max": 0.7},
    "pMF-L/32": {"guidance_scale": 7.5, "guidance_interval_min": 0.2, "guidance_interval_max": 0.6},
    "pMF-H/16": {"guidance_scale": 7.0, "guidance_interval_min": 0.2, "guidance_interval_max": 0.6},
    "pMF-H/32": {"guidance_scale": 5.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.6},
}

# Multiplier applied to the initial Gaussian noise, keyed by `model_type`.
# `PMFPipeline._recommended_noise_scale` consults this table when `noise_scale` is not given.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
    "pMF-B/16": 1.0,
    "pMF-B/32": 2.0,
    "pMF-L/16": 1.0,
    "pMF-L/32": 4.0,
    "pMF-H/16": 2.0,
    "pMF-H/32": 4.0,
}
49
+
50
+
51
def _set_pmf_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler,
    num_inference_steps: int,
    device: torch.device,
) -> torch.Tensor:
    r"""
    Configure `scheduler` with a linear sigma grid running from 1.0 down to 0.0.

    Args:
        scheduler: Scheduler whose `set_timesteps(sigmas=..., device=...)` is called.
        num_inference_steps: Number of sampling steps; the grid has one more entry than this
            because it stores both ends of every step.
        device: Device on which the grid tensor is created and which is forwarded to the scheduler.

    Returns:
        The float32 grid tensor with `num_inference_steps + 1` entries.
    """
    num_boundaries = num_inference_steps + 1
    sigma_grid = torch.linspace(1.0, 0.0, num_boundaries, device=device, dtype=torch.float32)
    # The scheduler receives plain Python floats; the tensor itself is handed back to the caller.
    scheduler.set_timesteps(sigmas=sigma_grid.tolist(), device=device)
    return sigma_grid
60
+
61
+
62
class PMFPipeline(DiffusionPipeline):
    r"""
    Pipeline for ImageNet class-conditional generation with Pixel Mean Flows (pMF).

    Parameters:
        transformer ([`PMFTransformer2DModel`]):
            Class-conditioned pMF transformer that predicts mean-flow velocity.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Built-in flow-matching Euler scheduler. When `None`, a scheduler with
            `num_train_timesteps=1000`, `shift=1.0` and `stochastic_sampling=False` is created.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Keys may be ints or numeric strings; each value is a
            comma-separated list of synonyms for that class.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        if scheduler is None:
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=1.0,
                stochastic_sampling=False,
            )
        self.register_modules(transformer=transformer, scheduler=scheduler)
        self._id2label = self._normalize_id2label(id2label)
        # Synonym -> class id lookup derived from `id2label`.
        self.labels = self._build_label2id(self._id2label)
        # When no labels were passed in, a lazy read of model_index.json is attempted on first use.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Populate the label tables from `model_index.json` if they are still empty."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return `id2label` with every key converted to `int`; `None`/empty input gives `{}`."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """
        Read the `id2label` mapping from `<variant_path>/model_index.json`.

        Returns `{}` when `variant_path` is empty, the file does not exist, or the file has no
        dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        # Share the key normalisation with the constructor path instead of duplicating it.
        return PMFPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """
        Build a synonym -> class id table, sorted by synonym.

        Each value of `id2label` is split on commas and every non-empty, stripped piece becomes a key.
        A synonym that occurs for several classes maps to the class visited last.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """Class id -> label string mapping, loaded lazily from `model_index.json` when needed."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        """
        Translate one label name or a list of label names into class ids.

        Raises:
            ValueError: If no labels are available, or if any name is not a known synonym
                (matching is exact and case-sensitive).
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
        labels = [label] if isinstance(label, str) else label
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
        """
        Convert the user-facing `class_labels` argument into a flat list of class ids.

        Accepts a single id, a single label name, or a list that may freely mix ids and names.
        Names are resolved per item, so a list such as `[207, "daisy"]` works regardless of which
        kind of entry comes first.

        Raises:
            ValueError: If any label name is unknown or no label table is available.
        """
        if isinstance(class_labels, int):
            return [class_labels]
        if isinstance(class_labels, str):
            return self.get_label_ids(class_labels)

        items = list(class_labels)
        names = [item for item in items if isinstance(item, str)]
        # Resolve all names in one call so that every unknown name is reported in a single error.
        resolved = iter(self.get_label_ids(names)) if names else iter(())
        return [next(resolved) if isinstance(item, str) else int(item) for item in items]

    def _recommended_noise_scale(self) -> float:
        """Return the preset noise scale for this model, falling back to a resolution-based default."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in RECOMMENDED_NOISE_BY_MODEL:
            return RECOMMENDED_NOISE_BY_MODEL[model_type]
        image_size = int(self.transformer.config.sample_size)
        return {256: 1.0, 512: 2.0}.get(image_size, 1.0)

    def _default_cfg(self) -> Dict[str, float]:
        """Return a copy of the guidance preset for this model, or a generic preset if none matches."""
        model_type = getattr(self.transformer.config, "model_type", None)
        if model_type in DEFAULT_CFG_BY_MODEL:
            return dict(DEFAULT_CFG_BY_MODEL[model_type])
        return {"guidance_scale": 7.5, "guidance_interval_min": 0.1, "guidance_interval_max": 0.8}

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        num_inference_steps: int = 1,
        guidance_scale: Optional[float] = None,
        guidance_interval_min: Optional[float] = None,
        guidance_interval_max: Optional[float] = None,
        noise_scale: Optional[float] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with pMF.

        Args:
            class_labels (`int`, `str`, or `list`):
                ImageNet class id(s) or label name(s). Lists may mix ids and names.
            num_inference_steps (`int`, *optional*, defaults to 1):
                Number of flow steps. pMF is typically used with 1 step.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. Defaults to model-specific preset. Must be non-zero.
            guidance_interval_min (`float`, *optional*):
                Lower bound of the CFG interval in normalized time.
            guidance_interval_max (`float`, *optional*):
                Upper bound of the CFG interval in normalized time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale. Defaults to model-specific preset.
            generator (`torch.Generator`, *optional*):
                Random generator for reproducibility.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format: `"pil"`, `"np"`, or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`].

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                Generated images.

        Raises:
            ValueError: If `num_inference_steps < 1`, `output_type` is not supported, `guidance_scale`
                is zero, or a label name cannot be resolved.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")

        defaults = self._default_cfg()
        if guidance_scale is None:
            guidance_scale = defaults["guidance_scale"]
        if guidance_interval_min is None:
            guidance_interval_min = defaults["guidance_interval_min"]
        if guidance_interval_max is None:
            guidance_interval_max = defaults["guidance_interval_max"]
        if noise_scale is None:
            noise_scale = self._recommended_noise_scale()
        # The transformer embeds `1 - 1 / guidance_scale`; a zero scale would yield non-finite conditioning.
        if guidance_scale == 0:
            raise ValueError("guidance_scale must be non-zero.")

        class_label_ids = self._normalize_class_labels(class_labels)
        batch_size = len(class_label_ids)
        image_size = int(self.transformer.config.sample_size)
        channels = int(self.transformer.config.in_channels)
        null_class_val = int(
            getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
        )

        latents = randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        ) * noise_scale

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # Out-of-range ids are clamped into [0, num_classes - 1] rather than rejected.
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)

        device = latents.device
        dtype = latents.dtype
        omega = torch.full((batch_size,), guidance_scale, device=device, dtype=dtype)
        t_min = torch.full((batch_size,), guidance_interval_min, device=device, dtype=dtype)
        t_max = torch.full((batch_size,), guidance_interval_max, device=device, dtype=dtype)

        flow_sigmas = _set_pmf_timesteps(self.scheduler, num_inference_steps, device)

        for step_index in self.progress_bar(range(num_inference_steps)):
            t = flow_sigmas[step_index]
            t_next = flow_sigmas[step_index + 1]
            # `h` is the width of the current step on the sigma grid.
            h = (t - t_next).expand(batch_size).to(device=device, dtype=dtype)
            t_batch = t.expand(batch_size).to(device=device, dtype=dtype)

            output = self.transformer(
                sample=latents,
                timestep=t_batch,
                class_labels=class_labels_t,
                h=h,
                omega=omega,
                guidance_interval_min=t_min,
                guidance_interval_max=t_max,
                return_dict=True,
            )
            latents = self.scheduler.step(output.u, self.scheduler.timesteps[step_index], latents).prev_sample

        # Map from [-1, 1] to [0, 1] and move to CPU before converting to the requested format.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
283
+
284
+
285
# Alias: `PMFPipelineOutput` is the same class object as diffusers' `ImagePipelineOutput`.
PMFPipelineOutput = ImagePipelineOutput
pMF-L-32/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
pMF-L-32/transformer/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PMFTransformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "aux_head_depth": 8,
5
+ "bottleneck_dim": 128,
6
+ "depth": 32,
7
+ "embedding_init_constant": 1.0,
8
+ "eval_mode": true,
9
+ "hidden_size": 1024,
10
+ "in_channels": 3,
11
+ "mlp_ratio": 2.6666666666666665,
12
+ "model_type": "pMF-L/32",
13
+ "norm_eps": 1e-06,
14
+ "num_attention_heads": 16,
15
+ "num_cfg_tokens": 4,
16
+ "num_class_embeds": null,
17
+ "num_class_tokens": 8,
18
+ "num_classes": 1000,
19
+ "num_interval_tokens": 2,
20
+ "num_time_tokens": 4,
21
+ "patch_size": 32,
22
+ "sample_size": 512,
23
+ "t_clip_min": 0.05,
24
+ "token_init_constant": 1.0,
25
+ "weight_init_constant": 0.32
26
+ }
pMF-L-32/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf058e7683ada392c730b15e0ff98f5082039d50744c6a2616615c917aac2fc9
3
+ size 1651955208
pMF-L-32/transformer/transformer_pmf.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from math import sqrt
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.normalization import RMSNorm
14
+ from diffusers.utils import BaseOutput
15
+
16
+
17
# Architecture hyper-parameters for every named pMF variant. When a model is built with a
# matching `model_type`, these values replace the corresponding constructor arguments, and
# `config_from_legacy` uses them as the base of the converted config.
PMF_PRESET_CONFIGS: Dict[str, Dict[str, object]] = {
    "pMF-B/16": {
        "sample_size": 256,
        "patch_size": 16,
        "hidden_size": 768,
        "depth": 16,
        "num_attention_heads": 12,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-B/32": {
        "sample_size": 512,
        "patch_size": 32,
        "hidden_size": 768,
        "depth": 16,
        "num_attention_heads": 12,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-L/16": {
        "sample_size": 256,
        "patch_size": 16,
        "hidden_size": 1024,
        "depth": 32,
        "num_attention_heads": 16,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-L/32": {
        "sample_size": 512,
        "patch_size": 32,
        "hidden_size": 1024,
        "depth": 32,
        "num_attention_heads": 16,
        "bottleneck_dim": 128,
        "aux_head_depth": 8,
    },
    "pMF-H/16": {
        "sample_size": 256,
        "patch_size": 16,
        "hidden_size": 1280,
        "depth": 48,
        "num_attention_heads": 16,
        "bottleneck_dim": 256,
        "aux_head_depth": 8,
    },
    "pMF-H/32": {
        "sample_size": 512,
        "patch_size": 32,
        "hidden_size": 1280,
        "depth": 48,
        "num_attention_heads": 16,
        "bottleneck_dim": 256,
        "aux_head_depth": 8,
    },
}

# Initial noise scale per variant, keyed by `model_type`.
# NOTE(review): carries the same values as the table of the same name in the pipeline module;
# keep the two in sync.
RECOMMENDED_NOISE_BY_MODEL: Dict[str, float] = {
    "pMF-B/16": 1.0,
    "pMF-B/32": 2.0,
    "pMF-L/16": 1.0,
    "pMF-L/32": 4.0,
    "pMF-H/16": 2.0,
    "pMF-H/32": 4.0,
}

# Legacy torch repo keys (pmfDiT_*), mapped to the preset names used in PMF_PRESET_CONFIGS.
LEGACY_MODEL_ALIASES: Dict[str, str] = {
    "pmfDiT_B_16": "pMF-B/16",
    "pmfDiT_B_32": "pMF-B/32",
    "pmfDiT_L_16": "pMF-L/16",
    "pmfDiT_L_32": "pMF-L/32",
    "pmfDiT_H_16": "pMF-H/16",
    "pmfDiT_H_32": "pMF-H/32",
}
92
+
93
+
94
@dataclass
class PMFTransformer2DOutput(BaseOutput):
    """Output container for the pMF transformer."""

    # Required tensor field; the sampling pipeline passes `.u` to the scheduler step as the model output.
    u: torch.Tensor
    # Optional second tensor, `None` by default.
    # NOTE(review): presumably the auxiliary v-head prediction, only produced when the model is
    # built with `eval_mode=False` — confirm against the model's forward.
    v: Optional[torch.Tensor] = None
98
+
99
+
100
def remap_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Rename the keys of a legacy checkpoint to the native PMFTransformer2DModel layout.

    For every key, at most one leading wrapper prefix (``transformer.`` or ``net.``, tried in
    that order) is removed, then the ``._flax_linear`` / ``._flax_embedding`` path segments are
    dropped. An entry whose resulting key is exactly ``rope_freqs`` is omitted.

    Args:
        state_dict: Mapping from legacy parameter names to values.

    Returns:
        A new dict with converted names, in the input's iteration order; the input is not modified.
    """
    wrapper_prefixes = ("transformer.", "net.")
    converted: Dict[str, torch.Tensor] = {}
    for legacy_key, tensor in state_dict.items():
        # Only the first matching prefix is stripped, and only once.
        stripped = next(
            (legacy_key[len(prefix) :] for prefix in wrapper_prefixes if legacy_key.startswith(prefix)),
            legacy_key,
        )
        # Official PyTorch checkpoints use TorchLinear/TorchEmbedding wrappers.
        native_key = stripped.replace("._flax_linear", "").replace("._flax_embedding", "")
        if native_key != "rope_freqs":
            converted[native_key] = tensor
    return converted
115
+
116
+
117
def config_from_legacy(config: Dict[str, object]) -> Dict[str, object]:
    """
    Build native config kwargs from a legacy config.json dict.

    The preset name is taken from the first truthy value among ``model_type``, ``model_name``
    and ``model_str``; legacy ``pmfDiT_*`` names are translated through LEGACY_MODEL_ALIASES.

    Args:
        config: Parsed legacy configuration.

    Returns:
        A copy of the matching preset extended with ``num_classes`` and ``model_type``, with
        ``sample_size`` / ``eval_mode`` overridden when the legacy config provides them.

    Raises:
        ValueError: If the resolved name is not a key of PMF_PRESET_CONFIGS.
    """
    requested = config.get("model_type") or config.get("model_name") or config.get("model_str")
    model_type = LEGACY_MODEL_ALIASES.get(requested, requested)
    if model_type not in PMF_PRESET_CONFIGS:
        raise ValueError(f"Unknown pMF preset '{model_type}'. Known: {list(PMF_PRESET_CONFIGS)}")

    native = {**PMF_PRESET_CONFIGS[model_type]}
    # `num_class_embeds` takes precedence over `num_classes`; 1000 is the fallback.
    native["num_classes"] = int(config.get("num_class_embeds") or config.get("num_classes") or 1000)
    native["model_type"] = model_type
    for key, cast in (("sample_size", int), ("eval_mode", bool)):
        override = config.get(key)
        if override is not None:
            native[key] = cast(override)
    return native
133
+
134
+
135
def _scaled_linear(
    in_features: int,
    out_features: int,
    *,
    bias: bool = True,
    weight_init: str = "scaled_variance",
    init_constant: float = 1.0,
    bias_init: str = "zeros",
) -> nn.Linear:
    """
    Create an ``nn.Linear`` and re-initialise its parameters.

    Args:
        in_features: Input width of the layer.
        out_features: Output width of the layer.
        bias: Whether the layer has a bias term.
        weight_init: ``"scaled_variance"`` draws weights from a normal distribution with
            standard deviation ``init_constant / sqrt(in_features)``; ``"zeros"`` zeroes them.
        init_constant: Numerator of the standard deviation for ``"scaled_variance"``.
        bias_init: Only ``"zeros"`` is supported; ignored when ``bias`` is false.

    Returns:
        The initialised layer.

    Raises:
        ValueError: If ``weight_init`` is unknown, or ``bias`` is true and ``bias_init`` is unknown.
    """
    linear = nn.Linear(in_features, out_features, bias=bias)

    if weight_init == "zeros":
        nn.init.zeros_(linear.weight)
    elif weight_init == "scaled_variance":
        nn.init.normal_(linear.weight, std=init_constant / sqrt(in_features))
    else:
        raise ValueError(f"Invalid weight_init: {weight_init}")

    if not bias:
        return linear
    if bias_init != "zeros":
        raise ValueError(f"Invalid bias_init: {bias_init}")
    nn.init.zeros_(linear.bias)
    return linear
159
+
160
+
161
class PMFTimestepEmbedder(nn.Module):
    """
    Embed a batch of scalars: sinusoidal features followed by a Linear -> SiLU -> Linear MLP.

    Args:
        hidden_size: Output width of both linear layers.
        frequency_embedding_size: Width of the sinusoidal feature vector fed to the MLP.
        init_constant: Forwarded to the scaled-variance initialisation of both linear layers.
    """

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        init_constant: float = 1.0,
    ):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size

        def make_linear(in_features: int) -> nn.Linear:
            # Both layers share the same output width and initialisation settings.
            return _scaled_linear(
                in_features,
                out_features=hidden_size,
                bias=True,
                weight_init="scaled_variance",
                init_constant=init_constant,
                bias_init="zeros",
            )

        self.mlp = nn.Sequential(
            make_linear(frequency_embedding_size),
            nn.SiLU(),
            make_linear(hidden_size),
        )

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """
        Compute sinusoidal features of `t`.

        Args:
            t: 1-D tensor of scalars (it is indexed as `t[:, None]`).
            dim: Width of the returned features; cosines fill the first `dim // 2` columns and
                sines the next `dim // 2`. For odd `dim` a zero column is appended.
            max_period: Controls the smallest frequency of the geometric frequency ladder.

        Returns:
            Float32 tensor with one row per entry of `t`.
        """
        half_dim = dim // 2
        ladder = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * ladder / half_dim)
        angles = t[:, None].float() * freqs[None]
        features = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            zero_column = torch.zeros_like(features[:, :1])
            features = torch.cat([features, zero_column], dim=-1)
        return features

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Return the MLP embedding of the sinusoidal features of `t`."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
200
+
201
+
202
class PMFLabelEmbedder(nn.Module):
    """
    Lookup table turning integer class labels into embedding vectors.

    The table has `num_classes + 1` rows and is initialised from a normal distribution with
    standard deviation `init_constant / sqrt(hidden_size)`.
    NOTE(review): the extra row presumably serves as the null/unconditional class — confirm.
    """

    def __init__(self, num_classes: int, hidden_size: int, init_constant: float = 1.0):
        super().__init__()
        table = nn.Embedding(num_classes + 1, hidden_size)
        nn.init.normal_(table.weight, std=init_constant / sqrt(hidden_size))
        self.embedding_table = table

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return the table rows selected by `labels`."""
        return self.embedding_table(labels)
210
+
211
+
212
class PMFBottleneckPatchEmbedder(nn.Module):
    """
    Two-stage patch embedding through a narrow bottleneck.

    A strided convolution maps each `patch_size` x `patch_size` patch to `pca_channels`
    features, and a 1x1 convolution then expands them to `hidden_size`. Both weight tensors are
    drawn uniformly from `[-limit, limit]` with `limit = sqrt(6 / (fan_in + fan_out))`;
    biases, when present, start at zero.

    Args:
        input_size: Side length of the input image, used only to compute `num_patches`.
        patch_size: Side length of a patch (kernel size and stride of the first convolution).
        pca_channels: Width of the bottleneck between the two convolutions.
        in_channels: Number of input image channels.
        hidden_size: Output feature width per patch.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(
        self,
        input_size: int,
        patch_size: int,
        pca_channels: int,
        in_channels: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        patches_per_side = input_size // patch_size
        self.num_patches = patches_per_side**2

        self.proj1 = nn.Conv2d(
            in_channels,
            pca_channels,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )
        self.proj2 = nn.Conv2d(pca_channels, hidden_size, kernel_size=1, stride=1, bias=bias)

        def uniform_limit(fan_in: int, fan_out: int) -> float:
            return math.sqrt(6.0 / (fan_in + fan_out))

        # Stage 1 fan-in counts every input value seen by one kernel application.
        limit1 = uniform_limit(patch_size * patch_size * in_channels, pca_channels)
        nn.init.uniform_(self.proj1.weight, -limit1, limit1)
        limit2 = uniform_limit(pca_channels, hidden_size)
        nn.init.uniform_(self.proj2.weight, -limit2, limit2)

        if bias:
            for conv in (self.proj1, self.proj2):
                nn.init.zeros_(conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed an image batch into a sequence of patch tokens.

        For a (batch, channels, height, width) input the result is (batch, patches, hidden_size).
        """
        bottleneck = self.proj1(x)
        features = self.proj2(bottleneck)
        return features.flatten(2).transpose(1, 2)
250
+
251
+
252
def precompute_rope_freqs(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute 2-D rotary position factors for a square grid of tokens.

    The `seq_len` tokens are treated as a row-major `grid x grid` layout. Each token gets
    `dim // 2` unit-modulus complex factors: the first half encodes its row index, the second
    half its column index, both using the same frequency ladder.

    Args:
        dim: Per-head feature width; `dim // 2` complex factors are produced per token.
        seq_len: Number of grid tokens; must be a perfect square.
        theta: Base of the geometric frequency ladder.

    Returns:
        Complex tensor of shape `(seq_len, dim // 2)`.

    Raises:
        ValueError: If `seq_len` is not a perfect square or `dim // 2` is odd (the row and column
            halves could then not fill `dim // 2` factors evenly).
    """
    rot_dim = dim // 2
    if rot_dim % 2:
        raise ValueError(f"dim // 2 must be even to split between rows and columns, got dim={dim}.")
    # Exact integer square root: avoids float rounding and lets non-square lengths be rejected
    # with a clear message instead of an opaque reshape failure further down.
    grid_size = math.isqrt(seq_len)
    if grid_size * grid_size != seq_len:
        raise ValueError(f"seq_len must be a perfect square, got {seq_len}.")

    freqs = 1.0 / (theta ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
    positions = torch.arange(grid_size, dtype=torch.float32)
    # Rows and columns share the same position/frequency angle table.
    axis_angles = torch.einsum("i,j->ij", positions, freqs)
    freqs_2d = torch.cat(
        [
            torch.tile(axis_angles[:, None, :], (1, grid_size, 1)),
            torch.tile(axis_angles[None, :, :], (grid_size, 1, 1)),
        ],
        dim=-1,
    )
    real = torch.cos(freqs_2d).reshape(seq_len, rot_dim)
    imag = torch.sin(freqs_2d).reshape(seq_len, rot_dim)
    return torch.complex(real, imag)
269
+
270
+
271
def apply_rotary_pos_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Rotate the trailing tokens of `x` by complex rotary factors.

    Consecutive pairs along the last axis of `x` are read as (real, imaginary) parts. Only the
    last `freqs_cis.shape[0]` positions along axis 1 are multiplied by `freqs_cis`; earlier
    positions pass through unchanged. The arithmetic runs in float32 and the result is cast
    back to the dtype of `x`.

    Args:
        x: Tensor indexed here as (batch, tokens, heads, features) with an even feature width.
        freqs_cis: Complex factors; the first axis is tokens and the last matches features / 2.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    as_float = x.to(torch.float32)
    paired = as_float.reshape(*as_float.shape[:-1], -1, 2).contiguous()
    as_complex = torch.view_as_complex(paired)

    # Insert broadcast axes for batch (front) and heads (axis 2).
    rotation = freqs_cis.unsqueeze(0).unsqueeze(2)
    num_rotated = rotation.shape[1]

    rotated = as_complex.clone()
    rotated[:, -num_rotated:, :] = as_complex[:, -num_rotated:, :] * rotation
    return torch.view_as_real(rotated).flatten(-2).to(x.dtype)
280
+
281
+
282
class PMFAttention(nn.Module):
    """
    Multi-head self-attention with per-head RMS-normalised queries/keys and rotary embeddings.

    Args:
        hidden_size: Width of the input and output features.
        num_heads: Number of attention heads; each head has `hidden_size // num_heads` features.
        weight_init_constant: Scale of the scaled-variance initialisation of all four projections.
        eps: Epsilon of the query/key RMS norms.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        init_kwargs = dict(
            bias=False,
            weight_init="scaled_variance",
            init_constant=weight_init_constant,
        )
        self.q_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.k_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.v_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.out_proj = _scaled_linear(hidden_size, hidden_size, **init_kwargs)
        self.q_norm = RMSNorm(self.head_dim, eps=eps)
        self.k_norm = RMSNorm(self.head_dim, eps=eps)

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """
        Apply self-attention to a token sequence.

        Args:
            x: Tensor of shape (batch, tokens, hidden_size).
            rope_freqs: Complex rotary factors applied to the trailing tokens of q and k.

        Returns:
            Tensor of the same shape as `x`.
        """
        batch_size, seq_len, channels = x.shape
        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q = apply_rotary_pos_emb(q, rope_freqs)
        k = apply_rotary_pos_emb(k, rope_freqs)

        query = q / math.sqrt(self.head_dim)
        attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, k)
        # Softmax runs in float32 for stability; cast back so the contraction with `v` sees a
        # single dtype when the module holds fp16/bf16 weights (a no-op in float32).
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        attn = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        attn = attn.reshape(batch_size, seq_len, channels)
        return self.out_proj(attn)
322
+
323
+
324
class PMFSwiGLUMlp(nn.Module):
    """
    Gated feed-forward layer: `w2(silu(w1(x)) * w3(x))`, all projections without bias.

    Args:
        dim: Input and output feature width.
        hidden_dim: Width of the gated intermediate features.
        weight_init_constant: Scale of the scaled-variance initialisation of the projections.
    """

    def __init__(self, dim: int, hidden_dim: int, weight_init_constant: float = 0.32):
        super().__init__()

        def make_proj(n_in: int, n_out: int) -> nn.Linear:
            return _scaled_linear(
                n_in,
                n_out,
                bias=False,
                weight_init="scaled_variance",
                init_constant=weight_init_constant,
            )

        # Creation order (w1, w3, w2) is kept so random initialisation draws stay in sequence.
        self.w1 = make_proj(dim, hidden_dim)
        self.w3 = make_proj(dim, hidden_dim)
        self.w2 = make_proj(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated projection of `x`."""
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
334
+
335
+
336
class PMFTransformerBlock(nn.Module):
    """
    Pre-norm transformer block with learnable per-channel residual scales.

    Each sub-layer (attention, then SwiGLU MLP) is applied to an RMS-normalised input and its
    output is multiplied by a scale vector before being added to the residual stream. Both
    scale vectors are initialised to zero.

    Args:
        hidden_size: Feature width of the residual stream.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of `hidden_size` (truncated to an int; rounded
            up to a multiple of 8 when `hidden_size > 1024`).
        weight_init_constant: Initialisation scale forwarded to attention and MLP projections.
        eps: Epsilon of the RMS norms.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 8 / 3,
        weight_init_constant: float = 0.32,
        eps: float = 1e-6,
    ):
        super().__init__()
        ffn_width = int(hidden_size * mlp_ratio)
        if hidden_size > 1024:
            ffn_width = 8 * ((ffn_width + 7) // 8)

        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = PMFAttention(hidden_size, num_heads, weight_init_constant=weight_init_constant, eps=eps)
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = PMFSwiGLUMlp(hidden_size, ffn_width, weight_init_constant=weight_init_constant)
        self.attn_scale = nn.Parameter(torch.zeros(hidden_size))
        self.mlp_scale = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor, rope_freqs: torch.Tensor) -> torch.Tensor:
        """Apply the scaled attention and MLP residual updates to `x`."""
        attn_update = self.attn(self.norm1(x), rope_freqs)
        x = x + attn_update * self.attn_scale
        mlp_update = self.mlp(self.norm2(x))
        return x + mlp_update * self.mlp_scale
360
+
361
+
362
class PMFFinalLayer(nn.Module):
    """
    Output head: RMS norm followed by a linear projection to flattened patch values.

    The projection maps `hidden_size` features to `patch_size * patch_size * out_channels`
    values per token; its weight and bias both start at zero.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, eps: float = 1e-6):
        super().__init__()
        self.norm = RMSNorm(hidden_size, eps=eps)
        values_per_patch = patch_size * patch_size * out_channels
        self.linear = _scaled_linear(
            hidden_size,
            values_per_patch,
            bias=True,
            weight_init="zeros",
            bias_init="zeros",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise `x` and project every token to its flattened patch values."""
        normed = self.norm(x)
        return self.linear(normed)
376
+
377
+
378
+ class PMFTransformer2DModel(ModelMixin, ConfigMixin):
379
+ """Native diffusers implementation of the pMF DiT backbone."""
380
+
381
+ _supports_gradient_checkpointing = True
382
+ _skip_layerwise_casting_patterns = ["pos_embed", "rope_freqs"]
383
+
384
+ @register_to_config
385
+ def __init__(
386
+ self,
387
+ sample_size: int = 256,
388
+ patch_size: int = 16,
389
+ in_channels: int = 3,
390
+ hidden_size: int = 768,
391
+ depth: int = 16,
392
+ num_attention_heads: int = 12,
393
+ mlp_ratio: float = 8 / 3,
394
+ num_classes: int = 1000,
395
+ bottleneck_dim: int = 128,
396
+ aux_head_depth: int = 8,
397
+ num_class_tokens: int = 8,
398
+ num_time_tokens: int = 4,
399
+ num_cfg_tokens: int = 4,
400
+ num_interval_tokens: int = 2,
401
+ token_init_constant: float = 1.0,
402
+ embedding_init_constant: float = 1.0,
403
+ weight_init_constant: float = 0.32,
404
+ eval_mode: bool = True,
405
+ model_type: str | None = None,
406
+ num_class_embeds: int | None = None,
407
+ t_clip_min: float = 0.05,
408
+ norm_eps: float = 1e-6,
409
+ ):
410
+ super().__init__()
411
+ if num_class_embeds is not None:
412
+ num_classes = int(num_class_embeds)
413
+ if model_type in LEGACY_MODEL_ALIASES:
414
+ model_type = LEGACY_MODEL_ALIASES[model_type]
415
+ if model_type in PMF_PRESET_CONFIGS:
416
+ preset = PMF_PRESET_CONFIGS[model_type]
417
+ sample_size = int(preset["sample_size"])
418
+ patch_size = int(preset["patch_size"])
419
+ hidden_size = int(preset["hidden_size"])
420
+ depth = int(preset["depth"])
421
+ num_attention_heads = int(preset["num_attention_heads"])
422
+ bottleneck_dim = int(preset["bottleneck_dim"])
423
+ aux_head_depth = int(preset["aux_head_depth"])
424
+
425
+ self.sample_size = sample_size
426
+ self.patch_size = patch_size
427
+ self.in_channels = in_channels
428
+ self.out_channels = in_channels
429
+ self.hidden_size = hidden_size
430
+ self.depth = depth
431
+ self.num_attention_heads = num_attention_heads
432
+ self.aux_head_depth = aux_head_depth
433
+ self.num_class_tokens = num_class_tokens
434
+ self.num_time_tokens = num_time_tokens
435
+ self.num_cfg_tokens = num_cfg_tokens
436
+ self.num_interval_tokens = num_interval_tokens
437
+ self.prefix_tokens = (
438
+ num_class_tokens + num_cfg_tokens + 2 * num_interval_tokens + num_time_tokens
439
+ )
440
+ self.t_clip_min = t_clip_min
441
+ self.eval_mode = eval_mode
442
+ self.gradient_checkpointing = False
443
+
444
+ self.x_embedder = PMFBottleneckPatchEmbedder(
445
+ sample_size,
446
+ patch_size,
447
+ bottleneck_dim,
448
+ in_channels,
449
+ hidden_size,
450
+ bias=True,
451
+ )
452
+ embed_kwargs = dict(hidden_size=hidden_size, init_constant=embedding_init_constant)
453
+ self.h_embedder = PMFTimestepEmbedder(**embed_kwargs)
454
+ self.omega_embedder = PMFTimestepEmbedder(**embed_kwargs)
455
+ self.cfg_t_start_embedder = PMFTimestepEmbedder(**embed_kwargs)
456
+ self.cfg_t_end_embedder = PMFTimestepEmbedder(**embed_kwargs)
457
+ self.y_embedder = PMFLabelEmbedder(num_classes, hidden_size, init_constant=embedding_init_constant)
458
+
459
+ token_std = token_init_constant / math.sqrt(hidden_size)
460
+ self.time_tokens = nn.Parameter(torch.randn(1, num_time_tokens, hidden_size) * token_std)
461
+ self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, hidden_size) * token_std)
462
+ self.omega_tokens = nn.Parameter(torch.randn(1, num_cfg_tokens, hidden_size) * token_std)
463
+ self.t_min_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
464
+ self.t_max_tokens = nn.Parameter(torch.randn(1, num_interval_tokens, hidden_size) * token_std)
465
+
466
+ total_tokens = self.x_embedder.num_patches + self.prefix_tokens
467
+ self.pos_embed = nn.Parameter(torch.randn(1, total_tokens, hidden_size) * 0.02)
468
+
469
+ head_dim = hidden_size // num_attention_heads
470
+ self.register_buffer(
471
+ "rope_freqs",
472
+ precompute_rope_freqs(head_dim, self.x_embedder.num_patches),
473
+ persistent=False,
474
+ )
475
+
476
+ shared_depth = depth - aux_head_depth
477
+ block_kwargs = dict(
478
+ hidden_size=hidden_size,
479
+ num_heads=num_attention_heads,
480
+ mlp_ratio=mlp_ratio,
481
+ weight_init_constant=weight_init_constant,
482
+ eps=norm_eps,
483
+ )
484
+ self.shared_blocks = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(shared_depth)])
485
+ self.u_heads = nn.ModuleList([PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth)])
486
+ self.v_heads = nn.ModuleList(
487
+ [PMFTransformerBlock(**block_kwargs) for _ in range(aux_head_depth if not eval_mode else 0)]
488
+ )
489
+ self.u_final_layer = PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
490
+ self.v_final_layer = (
491
+ PMFFinalLayer(hidden_size, patch_size, in_channels, eps=norm_eps)
492
+ if not eval_mode
493
+ else None
494
+ )
495
+
496
+ def _build_sequence(
497
+ self,
498
+ sample: torch.Tensor,
499
+ h: torch.Tensor,
500
+ omega: torch.Tensor,
501
+ t_min: torch.Tensor,
502
+ t_max: torch.Tensor,
503
+ class_labels: torch.Tensor,
504
+ ) -> torch.Tensor:
505
+ x_embed = self.x_embedder(sample)
506
+ h_embed = self.h_embedder(h)
507
+ omega_embed = self.omega_embedder(1 - 1 / omega)
508
+ t_min_embed = self.cfg_t_start_embedder(t_min)
509
+ t_max_embed = self.cfg_t_end_embedder(t_max)
510
+ y_embed = self.y_embedder(class_labels)
511
+
512
+ time_tokens = self.time_tokens + h_embed.unsqueeze(1)
513
+ omega_tokens = self.omega_tokens + omega_embed.unsqueeze(1)
514
+ t_min_tokens = self.t_min_tokens + t_min_embed.unsqueeze(1)
515
+ t_max_tokens = self.t_max_tokens + t_max_embed.unsqueeze(1)
516
+ class_tokens = self.class_tokens + y_embed.unsqueeze(1)
517
+
518
+ seq = torch.cat(
519
+ [class_tokens, omega_tokens, t_min_tokens, t_max_tokens, time_tokens, x_embed],
520
+ dim=1,
521
+ )
522
+ return seq + self.pos_embed
523
+
524
+ def _unpatchify(self, tokens: torch.Tensor) -> torch.Tensor:
525
+ batch_size = tokens.shape[0]
526
+ patch = self.patch_size
527
+ grid = int(tokens.shape[1] ** 0.5)
528
+ channels = self.out_channels
529
+ x = tokens.reshape(batch_size, grid, grid, patch, patch, channels)
530
+ x = torch.einsum("nhwpqc->nchpwq", x)
531
+ return x.reshape(batch_size, channels, grid * patch, grid * patch)
532
+
533
def forward(
    self,
    sample: torch.Tensor,
    timestep: torch.Tensor,
    class_labels: torch.Tensor,
    h: Optional[torch.Tensor] = None,
    omega: Optional[torch.Tensor] = None,
    guidance_interval_min: Optional[torch.Tensor] = None,
    guidance_interval_max: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> PMFTransformer2DOutput | Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run the transformer and return the ``u`` (and, if available, ``v``) outputs.

    All per-sample conditioning values are normalised to shape ``(batch,)`` on
    ``sample``'s device and dtype through ``_expand_batch``. Missing values
    default to: ``h`` = ``timestep``, ``omega`` = 1, ``guidance_interval_min`` = 0,
    ``guidance_interval_max`` = 1.

    The sequence from ``_build_sequence`` goes through ``shared_blocks``, then
    separately through ``u_heads`` and ``v_heads``. For each head the prefix
    tokens are dropped, the final layer and ``_unpatchify`` are applied, and the
    output is ``(sample - prediction) / clamp(timestep, min=t_clip_min)``.

    Args:
        sample: Input tensor; dim 0 is the batch.
        timestep: Per-sample (or scalar) timestep used as the divisor above.
        class_labels: Labels forwarded to ``_build_sequence``.
        h: Optional conditioning value; defaults to ``timestep``.
        omega: Optional guidance value; defaults to ones.
        guidance_interval_min: Optional interval start; defaults to zeros.
        guidance_interval_max: Optional interval end; defaults to ones.
        return_dict: If ``False`` return a plain ``(u, v)`` tuple.

    Returns:
        ``PMFTransformer2DOutput(u=u, v=v)`` or ``(u, v)``. ``v`` is ``None``
        when ``self.v_final_layer`` is ``None``.
    """
    batch_size = sample.shape[0]
    device, dtype = sample.device, sample.dtype

    def per_sample(value, make_default):
        # Fill in the default lazily, then broadcast/cast to a (batch,) tensor.
        if value is None:
            value = make_default()
        return self._expand_batch(value, batch_size, device, dtype)

    timestep = self._expand_batch(timestep, batch_size, device, dtype)
    h = per_sample(h, lambda: timestep)
    omega = per_sample(omega, lambda: torch.ones(batch_size, device=device))
    guidance_interval_min = per_sample(
        guidance_interval_min, lambda: torch.zeros(batch_size, device=device)
    )
    guidance_interval_max = per_sample(
        guidance_interval_max, lambda: torch.ones(batch_size, device=device)
    )

    hidden = self._build_sequence(
        sample, h, omega, guidance_interval_min, guidance_interval_max, class_labels
    )
    rope_freqs = self.rope_freqs.to(device=device)

    def run_blocks(blocks, tokens):
        # Activation checkpointing is only used while training with the flag set.
        for layer in blocks:
            if self.training and self.gradient_checkpointing:
                tokens = torch.utils.checkpoint.checkpoint(
                    layer, tokens, rope_freqs, use_reentrant=False
                )
            else:
                tokens = layer(tokens, rope_freqs)
        return tokens

    hidden = run_blocks(self.shared_blocks, hidden)
    # Both heads branch from the same shared representation.
    u_hidden = run_blocks(self.u_heads, hidden)
    v_hidden = run_blocks(self.v_heads, hidden)

    # Clamp the divisor from below by t_clip_min.
    t_safe = torch.clamp(timestep.reshape(batch_size, 1, 1, 1), min=self.t_clip_min)

    def finish_head(head_hidden, final_layer):
        # Drop the conditioning prefix, project patches, and fold back to image shape.
        prediction = self._unpatchify(final_layer(head_hidden[:, self.prefix_tokens :]))
        return (sample - prediction) / t_safe

    u = finish_head(u_hidden, self.u_final_layer)
    v = None if self.v_final_layer is None else finish_head(v_hidden, self.v_final_layer)

    if return_dict:
        return PMFTransformer2DOutput(u=u, v=v)
    return (u, v)
606
+
607
+ @staticmethod
608
+ def _expand_batch(
609
+ value: torch.Tensor,
610
+ batch_size: int,
611
+ device: torch.device,
612
+ dtype: torch.dtype,
613
+ ) -> torch.Tensor:
614
+ value = torch.as_tensor(value, device=device, dtype=dtype)
615
+ if value.ndim == 0:
616
+ value = value.reshape(1)
617
+ if value.shape[0] == 1 and batch_size > 1:
618
+ value = value.expand(batch_size)
619
+ return value.reshape(batch_size)
620
+
621
@classmethod
def from_pmf_checkpoint(
    cls,
    checkpoint_path: str,
    model_type: str | None = None,
    map_location: str = "cpu",
    strict: bool = False,
) -> Tuple["PMFTransformer2DModel", Dict[str, object]]:
    """Build a model from an original pMF checkpoint file.

    The checkpoint is read with ``torch.load``. If it is a dict with a
    ``"state_dict"`` entry, that entry is used as the weights; otherwise the
    loaded object itself is. When ``model_type`` is not given it is looked up
    under the keys ``"model_type"``, ``"model_str"`` and ``"model"`` (first
    string value wins). Legacy names are translated via ``LEGACY_MODEL_ALIASES``.
    The model is always constructed from the matching preset with
    ``eval_mode=True``, and the weights pass through ``remap_legacy_state_dict``
    before loading.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        model_type: Preset name; inferred from the checkpoint when ``None``.
        map_location: Forwarded to ``torch.load``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        ``(model, metadata)`` where ``metadata`` holds ``checkpoint_path`` and
        the resolved ``model_type``.

    Raises:
        ValueError: If ``model_type`` is not given and cannot be inferred.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only use this on checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
    is_mapping = isinstance(checkpoint, dict)
    if is_mapping and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    if model_type is None and is_mapping:
        for key in ("model_type", "model_str", "model"):
            candidate = checkpoint.get(key)
            # Only a string can name a preset. In particular "model" may hold
            # the weights themselves; using that dict as a lookup key would
            # fail with an unhashable-type TypeError.
            if isinstance(candidate, str):
                model_type = candidate
                break
    if model_type in LEGACY_MODEL_ALIASES:
        model_type = LEGACY_MODEL_ALIASES[model_type]
    if model_type is None:
        raise ValueError("model_type is required when it cannot be inferred from the checkpoint.")

    # Copy the preset so the shared table is not mutated by the overrides below.
    config = dict(PMF_PRESET_CONFIGS[model_type])
    config["model_type"] = model_type
    config["eval_mode"] = True
    model = cls(**config)
    model.load_state_dict(remap_legacy_state_dict(state_dict), strict=strict)
    metadata = {"checkpoint_path": checkpoint_path, "model_type": model_type}
    return model, metadata
652
+
653
def to_pmf_checkpoint(self, prefix: str = "net.") -> Dict[str, torch.Tensor]:
    """Export this model's ``state_dict`` with ``prefix`` prepended to every key.

    Args:
        prefix: String placed in front of each parameter/buffer name.

    Returns:
        A new dict mapping prefixed names to detached CPU tensors.
    """
    return {
        f"{prefix}{name}": tensor.detach().cpu()
        for name, tensor in self.state_dict().items()
    }
658
+
659
@property
def net(self):
    """Return this model itself, so ``model.net`` is the same object as ``model``."""
    # NOTE(review): presumably mirrors the "net." key prefix used by
    # ``to_pmf_checkpoint`` for code expecting a ``.net`` attribute -- confirm.
    return self
662
+
663
+
664
# Second module-level name bound to the same class object.
# NOTE(review): presumably kept as a backward-compatible alias -- confirm against callers.
PMFDiffusersModel = PMFTransformer2DModel