benzweijia committed on
Commit 9294bc7 · verified · 1 Parent(s): b1d39cd

Upload 61 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc +0 -0
  3. adv_grpo/__pycache__/discriminator.cpython-310.pyc +0 -0
  4. adv_grpo/__pycache__/ema.cpython-310.pyc +0 -0
  5. adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc +0 -0
  6. adv_grpo/__pycache__/inflated_layers.cpython-310.pyc +0 -0
  7. adv_grpo/__pycache__/inflated_lib.cpython-310.pyc +0 -0
  8. adv_grpo/__pycache__/ocr.cpython-310.pyc +0 -0
  9. adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc +0 -0
  10. adv_grpo/__pycache__/pick_score_training.cpython-310.pyc +0 -0
  11. adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc +0 -0
  12. adv_grpo/__pycache__/prompts.cpython-310.pyc +0 -0
  13. adv_grpo/__pycache__/rewards.cpython-310.pyc +0 -0
  14. adv_grpo/__pycache__/stat_tracking.cpython-310.pyc +0 -0
  15. adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc +0 -0
  16. adv_grpo/aesthetic_scorer.py +53 -0
  17. adv_grpo/assets/activities.txt +3 -0
  18. adv_grpo/assets/activities_v0.txt +3 -0
  19. adv_grpo/assets/flow_grpo_fast.png +3 -0
  20. adv_grpo/assets/imagenet_classes.txt +1000 -0
  21. adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth +3 -0
  22. adv_grpo/assets/simple_animals.txt +45 -0
  23. adv_grpo/assets/simple_ocr_animals.txt +5 -0
  24. adv_grpo/assets/simple_ocr_animals_digit1.txt +45 -0
  25. adv_grpo/assets/simple_ocr_animals_digit3.txt +45 -0
  26. adv_grpo/assets/simple_ocr_animals_digit5.txt +50 -0
  27. adv_grpo/assets/test.jpg +0 -0
  28. adv_grpo/clip_scorer.py +97 -0
  29. adv_grpo/conv_gradfix.py +345 -0
  30. adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc +0 -0
  31. adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc +0 -0
  32. adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc +0 -0
  33. adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py +255 -0
  34. adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py +187 -0
  35. adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py +198 -0
  36. adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py +1081 -0
  37. adv_grpo/diffusers_patch/sd3_sde_with_logprob.py +139 -0
  38. adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py +144 -0
  39. adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py +144 -0
  40. adv_grpo/diffusers_patch/wan_pipeline_with_logprob.py +373 -0
  41. adv_grpo/diffusers_patch/wan_prompt_embedding.py +97 -0
  42. adv_grpo/ema.py +88 -0
  43. adv_grpo/imagereward_scorer.py +40 -0
  44. adv_grpo/inflated_layers.py +305 -0
  45. adv_grpo/inflated_lib.py +346 -0
  46. adv_grpo/ocr.py +138 -0
  47. adv_grpo/pick_score_training.py +385 -0
  48. adv_grpo/pickscore_scorer.py +70 -0
  49. adv_grpo/pickscore_scorer_constractive.py +89 -0
  50. adv_grpo/pickscore_scorer_patch.py +78 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ adv_grpo/assets/flow_grpo_fast.png filter=lfs diff=lfs merge=lfs -text
adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc ADDED
Binary file (9.28 kB).

adv_grpo/__pycache__/discriminator.cpython-310.pyc ADDED
Binary file (2.95 kB).

adv_grpo/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.3 kB).

adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc ADDED
Binary file (625 Bytes).

adv_grpo/__pycache__/inflated_layers.cpython-310.pyc ADDED
Binary file (7.45 kB).

adv_grpo/__pycache__/inflated_lib.cpython-310.pyc ADDED
Binary file (10.1 kB).

adv_grpo/__pycache__/ocr.cpython-310.pyc ADDED
Binary file (4.41 kB).

adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc ADDED
Binary file (5.55 kB).

adv_grpo/__pycache__/pick_score_training.cpython-310.pyc ADDED
Binary file (9.12 kB).

adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc ADDED
Binary file (2.44 kB).

adv_grpo/__pycache__/prompts.cpython-310.pyc ADDED
Binary file (2.93 kB).

adv_grpo/__pycache__/rewards.cpython-310.pyc ADDED
Binary file (29.7 kB).

adv_grpo/__pycache__/stat_tracking.cpython-310.pyc ADDED
Binary file (2.62 kB).

adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc ADDED
Binary file (14.1 kB).
adv_grpo/aesthetic_scorer.py ADDED
@@ -0,0 +1,53 @@
+ # Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py
+
+ from importlib import resources
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from transformers import CLIPModel, CLIPProcessor
+ from PIL import Image
+
+ ASSETS_PATH = resources.files("adv_grpo.assets")
+
+
+ class MLP(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.layers = nn.Sequential(
+             nn.Linear(768, 1024),
+             nn.Dropout(0.2),
+             nn.Linear(1024, 128),
+             nn.Dropout(0.2),
+             nn.Linear(128, 64),
+             nn.Dropout(0.1),
+             nn.Linear(64, 16),
+             nn.Linear(16, 1),
+         )
+
+     @torch.no_grad()
+     def forward(self, embed):
+         return self.layers(embed)
+
+
+ class AestheticScorer(torch.nn.Module):
+     def __init__(self, dtype):
+         super().__init__()
+         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+         self.mlp = MLP()
+         state_dict = torch.load(
+             ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")
+         )
+         self.mlp.load_state_dict(state_dict)
+         self.dtype = dtype
+         self.eval()
+
+     @torch.no_grad()
+     def __call__(self, images):
+         device = next(self.parameters()).device
+         inputs = self.processor(images=images, return_tensors="pt")
+         inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
+         embed = self.clip.get_image_features(**inputs)
+         # normalize embedding
+         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
+         return self.mlp(embed).squeeze(1)
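A minimal usage sketch for the scorer above (an illustration, not one of the uploaded files; it assumes a CUDA device, that adv_grpo is importable as a package so the bundled weights resolve, and reuses the shipped adv_grpo/assets/test.jpg):

import torch
from PIL import Image
from adv_grpo.aesthetic_scorer import AestheticScorer

# CLIP ViT-L/14 backbone plus the linear aesthetic head loaded from adv_grpo/assets.
scorer = AestheticScorer(dtype=torch.float32).to("cuda")

# CLIPProcessor accepts PIL images directly; the scorer returns one score per image.
images = [Image.open("adv_grpo/assets/test.jpg")]
print(scorer(images))  # 1-D tensor of aesthetic scores, higher is better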
adv_grpo/assets/activities.txt ADDED
@@ -0,0 +1,3 @@
1
+ washing the dishes
2
+ riding a bike
3
+ playing chess
adv_grpo/assets/activities_v0.txt ADDED
@@ -0,0 +1,3 @@
1
+ washing the dishes
2
+ riding a bike
3
+ playing chess
adv_grpo/assets/flow_grpo_fast.png ADDED

Git LFS Details

  • SHA256: 35709d674818e29d39728e036479e51b8e015bcc7caf4dc54eb3e4f41cc05ab1
  • Pointer size: 131 Bytes
  • Size of remote file: 222 kB
adv_grpo/assets/imagenet_classes.txt ADDED
@@ -0,0 +1,1000 @@
1
+ tench, Tinca tinca
2
+ goldfish, Carassius auratus
3
+ great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
4
+ tiger shark, Galeocerdo cuvieri
5
+ hammerhead, hammerhead shark
6
+ electric ray, crampfish, numbfish, torpedo
7
+ stingray
8
+ cock
9
+ hen
10
+ ostrich, Struthio camelus
11
+ brambling, Fringilla montifringilla
12
+ goldfinch, Carduelis carduelis
13
+ house finch, linnet, Carpodacus mexicanus
14
+ junco, snowbird
15
+ indigo bunting, indigo finch, indigo bird, Passerina cyanea
16
+ robin, American robin, Turdus migratorius
17
+ bulbul
18
+ jay
19
+ magpie
20
+ chickadee
21
+ water ouzel, dipper
22
+ kite
23
+ bald eagle, American eagle, Haliaeetus leucocephalus
24
+ vulture
25
+ great grey owl, great gray owl, Strix nebulosa
26
+ European fire salamander, Salamandra salamandra
27
+ common newt, Triturus vulgaris
28
+ eft
29
+ spotted salamander, Ambystoma maculatum
30
+ axolotl, mud puppy, Ambystoma mexicanum
31
+ bullfrog, Rana catesbeiana
32
+ tree frog, tree-frog
33
+ tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
34
+ loggerhead, loggerhead turtle, Caretta caretta
35
+ leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
36
+ mud turtle
37
+ terrapin
38
+ box turtle, box tortoise
39
+ banded gecko
40
+ common iguana, iguana, Iguana iguana
41
+ American chameleon, anole, Anolis carolinensis
42
+ whiptail, whiptail lizard
43
+ agama
44
+ frilled lizard, Chlamydosaurus kingi
45
+ alligator lizard
46
+ Gila monster, Heloderma suspectum
47
+ green lizard, Lacerta viridis
48
+ African chameleon, Chamaeleo chamaeleon
49
+ Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
50
+ African crocodile, Nile crocodile, Crocodylus niloticus
51
+ American alligator, Alligator mississipiensis
52
+ triceratops
53
+ thunder snake, worm snake, Carphophis amoenus
54
+ ringneck snake, ring-necked snake, ring snake
55
+ hognose snake, puff adder, sand viper
56
+ green snake, grass snake
57
+ king snake, kingsnake
58
+ garter snake, grass snake
59
+ water snake
60
+ vine snake
61
+ night snake, Hypsiglena torquata
62
+ boa constrictor, Constrictor constrictor
63
+ rock python, rock snake, Python sebae
64
+ Indian cobra, Naja naja
65
+ green mamba
66
+ sea snake
67
+ horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
68
+ diamondback, diamondback rattlesnake, Crotalus adamanteus
69
+ sidewinder, horned rattlesnake, Crotalus cerastes
70
+ trilobite
71
+ harvestman, daddy longlegs, Phalangium opilio
72
+ scorpion
73
+ black and gold garden spider, Argiope aurantia
74
+ barn spider, Araneus cavaticus
75
+ garden spider, Aranea diademata
76
+ black widow, Latrodectus mactans
77
+ tarantula
78
+ wolf spider, hunting spider
79
+ tick
80
+ centipede
81
+ black grouse
82
+ ptarmigan
83
+ ruffed grouse, partridge, Bonasa umbellus
84
+ prairie chicken, prairie grouse, prairie fowl
85
+ peacock
86
+ quail
87
+ partridge
88
+ African grey, African gray, Psittacus erithacus
89
+ macaw
90
+ sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
91
+ lorikeet
92
+ coucal
93
+ bee eater
94
+ hornbill
95
+ hummingbird
96
+ jacamar
97
+ toucan
98
+ drake
99
+ red-breasted merganser, Mergus serrator
100
+ goose
101
+ black swan, Cygnus atratus
102
+ tusker
103
+ echidna, spiny anteater, anteater
104
+ platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
105
+ wallaby, brush kangaroo
106
+ koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
107
+ wombat
108
+ jellyfish
109
+ sea anemone, anemone
110
+ brain coral
111
+ flatworm, platyhelminth
112
+ nematode, nematode worm, roundworm
113
+ conch
114
+ snail
115
+ slug
116
+ sea slug, nudibranch
117
+ chiton, coat-of-mail shell, sea cradle, polyplacophore
118
+ chambered nautilus, pearly nautilus, nautilus
119
+ Dungeness crab, Cancer magister
120
+ rock crab, Cancer irroratus
121
+ fiddler crab
122
+ king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
123
+ American lobster, Northern lobster, Maine lobster, Homarus americanus
124
+ spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
125
+ crayfish, crawfish, crawdad, crawdaddy
126
+ hermit crab
127
+ isopod
128
+ white stork, Ciconia ciconia
129
+ black stork, Ciconia nigra
130
+ spoonbill
131
+ flamingo
132
+ little blue heron, Egretta caerulea
133
+ American egret, great white heron, Egretta albus
134
+ bittern
135
+ crane
136
+ limpkin, Aramus pictus
137
+ European gallinule, Porphyrio porphyrio
138
+ American coot, marsh hen, mud hen, water hen, Fulica americana
139
+ bustard
140
+ ruddy turnstone, Arenaria interpres
141
+ red-backed sandpiper, dunlin, Erolia alpina
142
+ redshank, Tringa totanus
143
+ dowitcher
144
+ oystercatcher, oyster catcher
145
+ pelican
146
+ king penguin, Aptenodytes patagonica
147
+ albatross, mollymawk
148
+ grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
149
+ killer whale, killer, orca, grampus, sea wolf, Orcinus orca
150
+ dugong, Dugong dugon
151
+ sea lion
152
+ Chihuahua
153
+ Japanese spaniel
154
+ Maltese dog, Maltese terrier, Maltese
155
+ Pekinese, Pekingese, Peke
156
+ Shih-Tzu
157
+ Blenheim spaniel
158
+ papillon
159
+ toy terrier
160
+ Rhodesian ridgeback
161
+ Afghan hound, Afghan
162
+ basset, basset hound
163
+ beagle
164
+ bloodhound, sleuthhound
165
+ bluetick
166
+ black-and-tan coonhound
167
+ Walker hound, Walker foxhound
168
+ English foxhound
169
+ redbone
170
+ borzoi, Russian wolfhound
171
+ Irish wolfhound
172
+ Italian greyhound
173
+ whippet
174
+ Ibizan hound, Ibizan Podenco
175
+ Norwegian elkhound, elkhound
176
+ otterhound, otter hound
177
+ Saluki, gazelle hound
178
+ Scottish deerhound, deerhound
179
+ Weimaraner
180
+ Staffordshire bullterrier, Staffordshire bull terrier
181
+ American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
182
+ Bedlington terrier
183
+ Border terrier
184
+ Kerry blue terrier
185
+ Irish terrier
186
+ Norfolk terrier
187
+ Norwich terrier
188
+ Yorkshire terrier
189
+ wire-haired fox terrier
190
+ Lakeland terrier
191
+ Sealyham terrier, Sealyham
192
+ Airedale, Airedale terrier
193
+ cairn, cairn terrier
194
+ Australian terrier
195
+ Dandie Dinmont, Dandie Dinmont terrier
196
+ Boston bull, Boston terrier
197
+ miniature schnauzer
198
+ giant schnauzer
199
+ standard schnauzer
200
+ Scotch terrier, Scottish terrier, Scottie
201
+ Tibetan terrier, chrysanthemum dog
202
+ silky terrier, Sydney silky
203
+ soft-coated wheaten terrier
204
+ West Highland white terrier
205
+ Lhasa, Lhasa apso
206
+ flat-coated retriever
207
+ curly-coated retriever
208
+ golden retriever
209
+ Labrador retriever
210
+ Chesapeake Bay retriever
211
+ German short-haired pointer
212
+ vizsla, Hungarian pointer
213
+ English setter
214
+ Irish setter, red setter
215
+ Gordon setter
216
+ Brittany spaniel
217
+ clumber, clumber spaniel
218
+ English springer, English springer spaniel
219
+ Welsh springer spaniel
220
+ cocker spaniel, English cocker spaniel, cocker
221
+ Sussex spaniel
222
+ Irish water spaniel
223
+ kuvasz
224
+ schipperke
225
+ groenendael
226
+ malinois
227
+ briard
228
+ kelpie
229
+ komondor
230
+ Old English sheepdog, bobtail
231
+ Shetland sheepdog, Shetland sheep dog, Shetland
232
+ collie
233
+ Border collie
234
+ Bouvier des Flandres, Bouviers des Flandres
235
+ Rottweiler
236
+ German shepherd, German shepherd dog, German police dog, alsatian
237
+ Doberman, Doberman pinscher
238
+ miniature pinscher
239
+ Greater Swiss Mountain dog
240
+ Bernese mountain dog
241
+ Appenzeller
242
+ EntleBucher
243
+ boxer
244
+ bull mastiff
245
+ Tibetan mastiff
246
+ French bulldog
247
+ Great Dane
248
+ Saint Bernard, St Bernard
249
+ Eskimo dog, husky
250
+ malamute, malemute, Alaskan malamute
251
+ Siberian husky
252
+ dalmatian, coach dog, carriage dog
253
+ affenpinscher, monkey pinscher, monkey dog
254
+ basenji
255
+ pug, pug-dog
256
+ Leonberg
257
+ Newfoundland, Newfoundland dog
258
+ Great Pyrenees
259
+ Samoyed, Samoyede
260
+ Pomeranian
261
+ chow, chow chow
262
+ keeshond
263
+ Brabancon griffon
264
+ Pembroke, Pembroke Welsh corgi
265
+ Cardigan, Cardigan Welsh corgi
266
+ toy poodle
267
+ miniature poodle
268
+ standard poodle
269
+ Mexican hairless
270
+ timber wolf, grey wolf, gray wolf, Canis lupus
271
+ white wolf, Arctic wolf, Canis lupus tundrarum
272
+ red wolf, maned wolf, Canis rufus, Canis niger
273
+ coyote, prairie wolf, brush wolf, Canis latrans
274
+ dingo, warrigal, warragal, Canis dingo
275
+ dhole, Cuon alpinus
276
+ African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
277
+ hyena, hyaena
278
+ red fox, Vulpes vulpes
279
+ kit fox, Vulpes macrotis
280
+ Arctic fox, white fox, Alopex lagopus
281
+ grey fox, gray fox, Urocyon cinereoargenteus
282
+ tabby, tabby cat
283
+ tiger cat
284
+ Persian cat
285
+ Siamese cat, Siamese
286
+ Egyptian cat
287
+ cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
288
+ lynx, catamount
289
+ leopard, Panthera pardus
290
+ snow leopard, ounce, Panthera uncia
291
+ jaguar, panther, Panthera onca, Felis onca
292
+ lion, king of beasts, Panthera leo
293
+ tiger, Panthera tigris
294
+ cheetah, chetah, Acinonyx jubatus
295
+ brown bear, bruin, Ursus arctos
296
+ American black bear, black bear, Ursus americanus, Euarctos americanus
297
+ ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
298
+ sloth bear, Melursus ursinus, Ursus ursinus
299
+ mongoose
300
+ meerkat, mierkat
301
+ tiger beetle
302
+ ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
303
+ ground beetle, carabid beetle
304
+ long-horned beetle, longicorn, longicorn beetle
305
+ leaf beetle, chrysomelid
306
+ dung beetle
307
+ rhinoceros beetle
308
+ weevil
309
+ fly
310
+ bee
311
+ ant, emmet, pismire
312
+ grasshopper, hopper
313
+ cricket
314
+ walking stick, walkingstick, stick insect
315
+ cockroach, roach
316
+ mantis, mantid
317
+ cicada, cicala
318
+ leafhopper
319
+ lacewing, lacewing fly
320
+ dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
321
+ damselfly
322
+ admiral
323
+ ringlet, ringlet butterfly
324
+ monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
325
+ cabbage butterfly
326
+ sulphur butterfly, sulfur butterfly
327
+ lycaenid, lycaenid butterfly
328
+ starfish, sea star
329
+ sea urchin
330
+ sea cucumber, holothurian
331
+ wood rabbit, cottontail, cottontail rabbit
332
+ hare
333
+ Angora, Angora rabbit
334
+ hamster
335
+ porcupine, hedgehog
336
+ fox squirrel, eastern fox squirrel, Sciurus niger
337
+ marmot
338
+ beaver
339
+ guinea pig, Cavia cobaya
340
+ sorrel
341
+ zebra
342
+ hog, pig, grunter, squealer, Sus scrofa
343
+ wild boar, boar, Sus scrofa
344
+ warthog
345
+ hippopotamus, hippo, river horse, Hippopotamus amphibius
346
+ ox
347
+ water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
348
+ bison
349
+ ram, tup
350
+ bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
351
+ ibex, Capra ibex
352
+ hartebeest
353
+ impala, Aepyceros melampus
354
+ gazelle
355
+ Arabian camel, dromedary, Camelus dromedarius
356
+ llama
357
+ weasel
358
+ mink
359
+ polecat, fitch, foulmart, foumart, Mustela putorius
360
+ black-footed ferret, ferret, Mustela nigripes
361
+ otter
362
+ skunk, polecat, wood pussy
363
+ badger
364
+ armadillo
365
+ three-toed sloth, ai, Bradypus tridactylus
366
+ orangutan, orang, orangutang, Pongo pygmaeus
367
+ gorilla, Gorilla gorilla
368
+ chimpanzee, chimp, Pan troglodytes
369
+ gibbon, Hylobates lar
370
+ siamang, Hylobates syndactylus, Symphalangus syndactylus
371
+ guenon, guenon monkey
372
+ patas, hussar monkey, Erythrocebus patas
373
+ baboon
374
+ macaque
375
+ langur
376
+ colobus, colobus monkey
377
+ proboscis monkey, Nasalis larvatus
378
+ marmoset
379
+ capuchin, ringtail, Cebus capucinus
380
+ howler monkey, howler
381
+ titi, titi monkey
382
+ spider monkey, Ateles geoffroyi
383
+ squirrel monkey, Saimiri sciureus
384
+ Madagascar cat, ring-tailed lemur, Lemur catta
385
+ indri, indris, Indri indri, Indri brevicaudatus
386
+ Indian elephant, Elephas maximus
387
+ African elephant, Loxodonta africana
388
+ lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
389
+ giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
390
+ barracouta, snoek
391
+ eel
392
+ coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
393
+ rock beauty, Holocanthus tricolor
394
+ anemone fish
395
+ sturgeon
396
+ gar, garfish, garpike, billfish, Lepisosteus osseus
397
+ lionfish
398
+ puffer, pufferfish, blowfish, globefish
399
+ abacus
400
+ abaya
401
+ academic gown, academic robe, judge's robe
402
+ accordion, piano accordion, squeeze box
403
+ acoustic guitar
404
+ aircraft carrier, carrier, flattop, attack aircraft carrier
405
+ airliner
406
+ airship, dirigible
407
+ altar
408
+ ambulance
409
+ amphibian, amphibious vehicle
410
+ analog clock
411
+ apiary, bee house
412
+ apron
413
+ ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
414
+ assault rifle, assault gun
415
+ backpack, back pack, knapsack, packsack, rucksack, haversack
416
+ bakery, bakeshop, bakehouse
417
+ balance beam, beam
418
+ balloon
419
+ ballpoint, ballpoint pen, ballpen, Biro
420
+ Band Aid
421
+ banjo
422
+ bannister, banister, balustrade, balusters, handrail
423
+ barbell
424
+ barber chair
425
+ barbershop
426
+ barn
427
+ barometer
428
+ barrel, cask
429
+ barrow, garden cart, lawn cart, wheelbarrow
430
+ baseball
431
+ basketball
432
+ bassinet
433
+ bassoon
434
+ bathing cap, swimming cap
435
+ bath towel
436
+ bathtub, bathing tub, bath, tub
437
+ beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
438
+ beacon, lighthouse, beacon light, pharos
439
+ beaker
440
+ bearskin, busby, shako
441
+ beer bottle
442
+ beer glass
443
+ bell cote, bell cot
444
+ bib
445
+ bicycle-built-for-two, tandem bicycle, tandem
446
+ bikini, two-piece
447
+ binder, ring-binder
448
+ binoculars, field glasses, opera glasses
449
+ birdhouse
450
+ boathouse
451
+ bobsled, bobsleigh, bob
452
+ bolo tie, bolo, bola tie, bola
453
+ bonnet, poke bonnet
454
+ bookcase
455
+ bookshop, bookstore, bookstall
456
+ bottlecap
457
+ bow
458
+ bow tie, bow-tie, bowtie
459
+ brass, memorial tablet, plaque
460
+ brassiere, bra, bandeau
461
+ breakwater, groin, groyne, mole, bulwark, seawall, jetty
462
+ breastplate, aegis, egis
463
+ broom
464
+ bucket, pail
465
+ buckle
466
+ bulletproof vest
467
+ bullet train, bullet
468
+ butcher shop, meat market
469
+ cab, hack, taxi, taxicab
470
+ caldron, cauldron
471
+ candle, taper, wax light
472
+ cannon
473
+ canoe
474
+ can opener, tin opener
475
+ cardigan
476
+ car mirror
477
+ carousel, carrousel, merry-go-round, roundabout, whirligig
478
+ carpenter's kit, tool kit
479
+ carton
480
+ car wheel
481
+ cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
482
+ cassette
483
+ cassette player
484
+ castle
485
+ catamaran
486
+ CD player
487
+ cello, violoncello
488
+ cellular telephone, cellular phone, cellphone, cell, mobile phone
489
+ chain
490
+ chainlink fence
491
+ chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
492
+ chain saw, chainsaw
493
+ chest
494
+ chiffonier, commode
495
+ chime, bell, gong
496
+ china cabinet, china closet
497
+ Christmas stocking
498
+ church, church building
499
+ cinema, movie theater, movie theatre, movie house, picture palace
500
+ cleaver, meat cleaver, chopper
501
+ cliff dwelling
502
+ cloak
503
+ clog, geta, patten, sabot
504
+ cocktail shaker
505
+ coffee mug
506
+ coffeepot
507
+ coil, spiral, volute, whorl, helix
508
+ combination lock
509
+ computer keyboard, keypad
510
+ confectionery, confectionary, candy store
511
+ container ship, containership, container vessel
512
+ convertible
513
+ corkscrew, bottle screw
514
+ cornet, horn, trumpet, trump
515
+ cowboy boot
516
+ cowboy hat, ten-gallon hat
517
+ cradle
518
+ crane
519
+ crash helmet
520
+ crate
521
+ crib, cot
522
+ Crock Pot
523
+ croquet ball
524
+ crutch
525
+ cuirass
526
+ dam, dike, dyke
527
+ desk
528
+ desktop computer
529
+ dial telephone, dial phone
530
+ diaper, nappy, napkin
531
+ digital clock
532
+ digital watch
533
+ dining table, board
534
+ dishrag, dishcloth
535
+ dishwasher, dish washer, dishwashing machine
536
+ disk brake, disc brake
537
+ dock, dockage, docking facility
538
+ dogsled, dog sled, dog sleigh
539
+ dome
540
+ doormat, welcome mat
541
+ drilling platform, offshore rig
542
+ drum, membranophone, tympan
543
+ drumstick
544
+ dumbbell
545
+ Dutch oven
546
+ electric fan, blower
547
+ electric guitar
548
+ electric locomotive
549
+ entertainment center
550
+ envelope
551
+ espresso maker
552
+ face powder
553
+ feather boa, boa
554
+ file, file cabinet, filing cabinet
555
+ fireboat
556
+ fire engine, fire truck
557
+ fire screen, fireguard
558
+ flagpole, flagstaff
559
+ flute, transverse flute
560
+ folding chair
561
+ football helmet
562
+ forklift
563
+ fountain
564
+ fountain pen
565
+ four-poster
566
+ freight car
567
+ French horn, horn
568
+ frying pan, frypan, skillet
569
+ fur coat
570
+ garbage truck, dustcart
571
+ gasmask, respirator, gas helmet
572
+ gas pump, gasoline pump, petrol pump, island dispenser
573
+ goblet
574
+ go-kart
575
+ golf ball
576
+ golfcart, golf cart
577
+ gondola
578
+ gong, tam-tam
579
+ gown
580
+ grand piano, grand
581
+ greenhouse, nursery, glasshouse
582
+ grille, radiator grille
583
+ grocery store, grocery, food market, market
584
+ guillotine
585
+ hair slide
586
+ hair spray
587
+ half track
588
+ hammer
589
+ hamper
590
+ hand blower, blow dryer, blow drier, hair dryer, hair drier
591
+ hand-held computer, hand-held microcomputer
592
+ handkerchief, hankie, hanky, hankey
593
+ hard disc, hard disk, fixed disk
594
+ harmonica, mouth organ, harp, mouth harp
595
+ harp
596
+ harvester, reaper
597
+ hatchet
598
+ holster
599
+ home theater, home theatre
600
+ honeycomb
601
+ hook, claw
602
+ hoopskirt, crinoline
603
+ horizontal bar, high bar
604
+ horse cart, horse-cart
605
+ hourglass
606
+ iPod
607
+ iron, smoothing iron
608
+ jack-o'-lantern
609
+ jean, blue jean, denim
610
+ jeep, landrover
611
+ jersey, T-shirt, tee shirt
612
+ jigsaw puzzle
613
+ jinrikisha, ricksha, rickshaw
614
+ joystick
615
+ kimono
616
+ knee pad
617
+ knot
618
+ lab coat, laboratory coat
619
+ ladle
620
+ lampshade, lamp shade
621
+ laptop, laptop computer
622
+ lawn mower, mower
623
+ lens cap, lens cover
624
+ letter opener, paper knife, paperknife
625
+ library
626
+ lifeboat
627
+ lighter, light, igniter, ignitor
628
+ limousine, limo
629
+ liner, ocean liner
630
+ lipstick, lip rouge
631
+ Loafer
632
+ lotion
633
+ loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
634
+ loupe, jeweler's loupe
635
+ lumbermill, sawmill
636
+ magnetic compass
637
+ mailbag, postbag
638
+ mailbox, letter box
639
+ maillot
640
+ maillot, tank suit
641
+ manhole cover
642
+ maraca
643
+ marimba, xylophone
644
+ mask
645
+ matchstick
646
+ maypole
647
+ maze, labyrinth
648
+ measuring cup
649
+ medicine chest, medicine cabinet
650
+ megalith, megalithic structure
651
+ microphone, mike
652
+ microwave, microwave oven
653
+ military uniform
654
+ milk can
655
+ minibus
656
+ miniskirt, mini
657
+ minivan
658
+ missile
659
+ mitten
660
+ mixing bowl
661
+ mobile home, manufactured home
662
+ Model T
663
+ modem
664
+ monastery
665
+ monitor
666
+ moped
667
+ mortar
668
+ mortarboard
669
+ mosque
670
+ mosquito net
671
+ motor scooter, scooter
672
+ mountain bike, all-terrain bike, off-roader
673
+ mountain tent
674
+ mouse, computer mouse
675
+ mousetrap
676
+ moving van
677
+ muzzle
678
+ nail
679
+ neck brace
680
+ necklace
681
+ nipple
682
+ notebook, notebook computer
683
+ obelisk
684
+ oboe, hautboy, hautbois
685
+ ocarina, sweet potato
686
+ odometer, hodometer, mileometer, milometer
687
+ oil filter
688
+ organ, pipe organ
689
+ oscilloscope, scope, cathode-ray oscilloscope, CRO
690
+ overskirt
691
+ oxcart
692
+ oxygen mask
693
+ packet
694
+ paddle, boat paddle
695
+ paddlewheel, paddle wheel
696
+ padlock
697
+ paintbrush
698
+ pajama, pyjama, pj's, jammies
699
+ palace
700
+ panpipe, pandean pipe, syrinx
701
+ paper towel
702
+ parachute, chute
703
+ parallel bars, bars
704
+ park bench
705
+ parking meter
706
+ passenger car, coach, carriage
707
+ patio, terrace
708
+ pay-phone, pay-station
709
+ pedestal, plinth, footstall
710
+ pencil box, pencil case
711
+ pencil sharpener
712
+ perfume, essence
713
+ Petri dish
714
+ photocopier
715
+ pick, plectrum, plectron
716
+ pickelhaube
717
+ picket fence, paling
718
+ pickup, pickup truck
719
+ pier
720
+ piggy bank, penny bank
721
+ pill bottle
722
+ pillow
723
+ ping-pong ball
724
+ pinwheel
725
+ pirate, pirate ship
726
+ pitcher, ewer
727
+ plane, carpenter's plane, woodworking plane
728
+ planetarium
729
+ plastic bag
730
+ plate rack
731
+ plow, plough
732
+ plunger, plumber's helper
733
+ Polaroid camera, Polaroid Land camera
734
+ pole
735
+ police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
736
+ poncho
737
+ pool table, billiard table, snooker table
738
+ pop bottle, soda bottle
739
+ pot, flowerpot
740
+ potter's wheel
741
+ power drill
742
+ prayer rug, prayer mat
743
+ printer
744
+ prison, prison house
745
+ projectile, missile
746
+ projector
747
+ puck, hockey puck
748
+ punching bag, punch bag, punching ball, punchball
749
+ purse
750
+ quill, quill pen
751
+ quilt, comforter, comfort, puff
752
+ racer, race car, racing car
753
+ racket, racquet
754
+ radiator
755
+ radio, wireless
756
+ radio telescope, radio reflector
757
+ rain barrel
758
+ recreational vehicle, RV, R.V.
759
+ reel
760
+ reflex camera
761
+ refrigerator, icebox
762
+ remote control, remote
763
+ restaurant, eating house, eating place, eatery
764
+ revolver, six-gun, six-shooter
765
+ rifle
766
+ rocking chair, rocker
767
+ rotisserie
768
+ rubber eraser, rubber, pencil eraser
769
+ rugby ball
770
+ rule, ruler
771
+ running shoe
772
+ safe
773
+ safety pin
774
+ saltshaker, salt shaker
775
+ sandal
776
+ sarong
777
+ sax, saxophone
778
+ scabbard
779
+ scale, weighing machine
780
+ school bus
781
+ schooner
782
+ scoreboard
783
+ screen, CRT screen
784
+ screw
785
+ screwdriver
786
+ seat belt, seatbelt
787
+ sewing machine
788
+ shield, buckler
789
+ shoe shop, shoe-shop, shoe store
790
+ shoji
791
+ shopping basket
792
+ shopping cart
793
+ shovel
794
+ shower cap
795
+ shower curtain
796
+ ski
797
+ ski mask
798
+ sleeping bag
799
+ slide rule, slipstick
800
+ sliding door
801
+ slot, one-armed bandit
802
+ snorkel
803
+ snowmobile
804
+ snowplow, snowplough
805
+ soap dispenser
806
+ soccer ball
807
+ sock
808
+ solar dish, solar collector, solar furnace
809
+ sombrero
810
+ soup bowl
811
+ space bar
812
+ space heater
813
+ space shuttle
814
+ spatula
815
+ speedboat
816
+ spider web, spider's web
817
+ spindle
818
+ sports car, sport car
819
+ spotlight, spot
820
+ stage
821
+ steam locomotive
822
+ steel arch bridge
823
+ steel drum
824
+ stethoscope
825
+ stole
826
+ stone wall
827
+ stopwatch, stop watch
828
+ stove
829
+ strainer
830
+ streetcar, tram, tramcar, trolley, trolley car
831
+ stretcher
832
+ studio couch, day bed
833
+ stupa, tope
834
+ submarine, pigboat, sub, U-boat
835
+ suit, suit of clothes
836
+ sundial
837
+ sunglass
838
+ sunglasses, dark glasses, shades
839
+ sunscreen, sunblock, sun blocker
840
+ suspension bridge
841
+ swab, swob, mop
842
+ sweatshirt
843
+ swimming trunks, bathing trunks
844
+ swing
845
+ switch, electric switch, electrical switch
846
+ syringe
847
+ table lamp
848
+ tank, army tank, armored combat vehicle, armoured combat vehicle
849
+ tape player
850
+ teapot
851
+ teddy, teddy bear
852
+ television, television system
853
+ tennis ball
854
+ thatch, thatched roof
855
+ theater curtain, theatre curtain
856
+ thimble
857
+ thresher, thrasher, threshing machine
858
+ throne
859
+ tile roof
860
+ toaster
861
+ tobacco shop, tobacconist shop, tobacconist
862
+ toilet seat
863
+ torch
864
+ totem pole
865
+ tow truck, tow car, wrecker
866
+ toyshop
867
+ tractor
868
+ trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
869
+ tray
870
+ trench coat
871
+ tricycle, trike, velocipede
872
+ trimaran
873
+ tripod
874
+ triumphal arch
875
+ trolleybus, trolley coach, trackless trolley
876
+ trombone
877
+ tub, vat
878
+ turnstile
879
+ typewriter keyboard
880
+ umbrella
881
+ unicycle, monocycle
882
+ upright, upright piano
883
+ vacuum, vacuum cleaner
884
+ vase
885
+ vault
886
+ velvet
887
+ vending machine
888
+ vestment
889
+ viaduct
890
+ violin, fiddle
891
+ volleyball
892
+ waffle iron
893
+ wall clock
894
+ wallet, billfold, notecase, pocketbook
895
+ wardrobe, closet, press
896
+ warplane, military plane
897
+ washbasin, handbasin, washbowl, lavabo, wash-hand basin
898
+ washer, automatic washer, washing machine
899
+ water bottle
900
+ water jug
901
+ water tower
902
+ whiskey jug
903
+ whistle
904
+ wig
905
+ window screen
906
+ window shade
907
+ Windsor tie
908
+ wine bottle
909
+ wing
910
+ wok
911
+ wooden spoon
912
+ wool, woolen, woollen
913
+ worm fence, snake fence, snake-rail fence, Virginia fence
914
+ wreck
915
+ yawl
916
+ yurt
917
+ web site, website, internet site, site
918
+ comic book
919
+ crossword puzzle, crossword
920
+ street sign
921
+ traffic light, traffic signal, stoplight
922
+ book jacket, dust cover, dust jacket, dust wrapper
923
+ menu
924
+ plate
925
+ guacamole
926
+ consomme
927
+ hot pot, hotpot
928
+ trifle
929
+ ice cream, icecream
930
+ ice lolly, lolly, lollipop, popsicle
931
+ French loaf
932
+ bagel, beigel
933
+ pretzel
934
+ cheeseburger
935
+ hotdog, hot dog, red hot
936
+ mashed potato
937
+ head cabbage
938
+ broccoli
939
+ cauliflower
940
+ zucchini, courgette
941
+ spaghetti squash
942
+ acorn squash
943
+ butternut squash
944
+ cucumber, cuke
945
+ artichoke, globe artichoke
946
+ bell pepper
947
+ cardoon
948
+ mushroom
949
+ Granny Smith
950
+ strawberry
951
+ orange
952
+ lemon
953
+ fig
954
+ pineapple, ananas
955
+ banana
956
+ jackfruit, jak, jack
957
+ custard apple
958
+ pomegranate
959
+ hay
960
+ carbonara
961
+ chocolate sauce, chocolate syrup
962
+ dough
963
+ meat loaf, meatloaf
964
+ pizza, pizza pie
965
+ potpie
966
+ burrito
967
+ red wine
968
+ espresso
969
+ cup
970
+ eggnog
971
+ alp
972
+ bubble
973
+ cliff, drop, drop-off
974
+ coral reef
975
+ geyser
976
+ lakeside, lakeshore
977
+ promontory, headland, head, foreland
978
+ sandbar, sand bar
979
+ seashore, coast, seacoast, sea-coast
980
+ valley, vale
981
+ volcano
982
+ ballplayer, baseball player
983
+ groom, bridegroom
984
+ scuba diver
985
+ rapeseed
986
+ daisy
987
+ yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
988
+ corn
989
+ acorn
990
+ hip, rose hip, rosehip
991
+ buckeye, horse chestnut, conker
992
+ coral fungus
993
+ agaric
994
+ gyromitra
995
+ stinkhorn, carrion fungus
996
+ earthstar
997
+ hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
998
+ bolete
999
+ ear, spike, capitulum
1000
+ toilet tissue, toilet paper, bathroom tissue
adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21dd590f3ccdc646f0d53120778b296013b096a035a2718c9cb0d511bff0f1e0
3
+ size 3714759
adv_grpo/assets/simple_animals.txt ADDED
@@ -0,0 +1,45 @@
1
+ cat
2
+ dog
3
+ horse
4
+ monkey
5
+ rabbit
6
+ zebra
7
+ spider
8
+ bird
9
+ sheep
10
+ deer
11
+ cow
12
+ goat
13
+ lion
14
+ tiger
15
+ bear
16
+ raccoon
17
+ fox
18
+ wolf
19
+ lizard
20
+ beetle
21
+ ant
22
+ butterfly
23
+ fish
24
+ shark
25
+ whale
26
+ dolphin
27
+ squirrel
28
+ mouse
29
+ rat
30
+ snake
31
+ turtle
32
+ frog
33
+ chicken
34
+ duck
35
+ goose
36
+ bee
37
+ pig
38
+ turkey
39
+ fly
40
+ llama
41
+ camel
42
+ bat
43
+ gorilla
44
+ hedgehog
45
+ kangaroo
adv_grpo/assets/simple_ocr_animals.txt ADDED
@@ -0,0 +1,5 @@
1
+ cat
2
+ dog
3
+ horse
4
+ monkey
5
+ rabbit
adv_grpo/assets/simple_ocr_animals_digit1.txt ADDED
@@ -0,0 +1,45 @@
1
+ A cat holding a sign that says '0'
2
+ A dog holding a sign that says '0'
3
+ A horse holding a sign that says '0'
4
+ A monkey holding a sign that says '0'
5
+ A rabbit holding a sign that says '0'
6
+ A cat holding a sign that says '1'
7
+ A dog holding a sign that says '1'
8
+ A horse holding a sign that says '1'
9
+ A monkey holding a sign that says '1'
10
+ A rabbit holding a sign that says '1'
11
+ A cat holding a sign that says '2'
12
+ A dog holding a sign that says '2'
13
+ A horse holding a sign that says '2'
14
+ A monkey holding a sign that says '2'
15
+ A rabbit holding a sign that says '2'
16
+ A cat holding a sign that says '3'
17
+ A dog holding a sign that says '3'
18
+ A horse holding a sign that says '3'
19
+ A monkey holding a sign that says '3'
20
+ A rabbit holding a sign that says '3'
21
+ A cat holding a sign that says '4'
22
+ A dog holding a sign that says '4'
23
+ A horse holding a sign that says '4'
24
+ A monkey holding a sign that says '4'
25
+ A rabbit holding a sign that says '4'
26
+ A cat holding a sign that says '5'
27
+ A dog holding a sign that says '5'
28
+ A horse holding a sign that says '5'
29
+ A monkey holding a sign that says '5'
30
+ A rabbit holding a sign that says '5'
31
+ A cat holding a sign that says '6'
32
+ A dog holding a sign that says '6'
33
+ A horse holding a sign that says '6'
34
+ A monkey holding a sign that says '6'
35
+ A rabbit holding a sign that says '6'
36
+ A cat holding a sign that says '7'
37
+ A dog holding a sign that says '7'
38
+ A horse holding a sign that says '7'
39
+ A monkey holding a sign that says '7'
40
+ A rabbit holding a sign that says '7'
41
+ A cat holding a sign that says '8'
42
+ A dog holding a sign that says '8'
43
+ A horse holding a sign that says '8'
44
+ A monkey holding a sign that says '8'
45
+ A rabbit holding a sign that says '8'
adv_grpo/assets/simple_ocr_animals_digit3.txt ADDED
@@ -0,0 +1,45 @@
1
+ A cat holding a sign that says '123'
2
+ A dog holding a sign that says '234'
3
+ A horse holding a sign that says '345'
4
+ A monkey holding a sign that says '456'
5
+ A rabbit holding a sign that says '567'
6
+ A cat holding a sign that says '678'
7
+ A dog holding a sign that says '789'
8
+ A horse holding a sign that says '123'
9
+ A monkey holding a sign that says '234'
10
+ A rabbit holding a sign that says '345'
11
+ A cat holding a sign that says '456'
12
+ A dog holding a sign that says '567'
13
+ A horse holding a sign that says '678'
14
+ A monkey holding a sign that says '789'
15
+ A rabbit holding a sign that says '123'
16
+ A cat holding a sign that says '234'
17
+ A dog holding a sign that says '345'
18
+ A horse holding a sign that says '456'
19
+ A monkey holding a sign that says '567'
20
+ A rabbit holding a sign that says '678'
21
+ A cat holding a sign that says '789'
22
+ A dog holding a sign that says '123'
23
+ A horse holding a sign that says '234'
24
+ A monkey holding a sign that says '345'
25
+ A rabbit holding a sign that says '456'
26
+ A cat holding a sign that says '567'
27
+ A dog holding a sign that says '678'
28
+ A horse holding a sign that says '789'
29
+ A monkey holding a sign that says '123'
30
+ A rabbit holding a sign that says '234'
31
+ A cat holding a sign that says '345'
32
+ A dog holding a sign that says '456'
33
+ A horse holding a sign that says '567'
34
+ A monkey holding a sign that says '678'
35
+ A rabbit holding a sign that says '789'
36
+ A cat holding a sign that says '123'
37
+ A dog holding a sign that says '234'
38
+ A horse holding a sign that says '345'
39
+ A monkey holding a sign that says '456'
40
+ A rabbit holding a sign that says '567'
41
+ A cat holding a sign that says '678'
42
+ A dog holding a sign that says '789'
43
+ A horse holding a sign that says '123'
44
+ A monkey holding a sign that says '234'
45
+ A rabbit holding a sign that says '345'
adv_grpo/assets/simple_ocr_animals_digit5.txt ADDED
@@ -0,0 +1,50 @@
1
+ A cat holding a sign that says '12345'
2
+ A dog holding a sign that says '23456'
3
+ A horse holding a sign that says '34567'
4
+ A monkey holding a sign that says '45678'
5
+ A rabbit holding a sign that says '56789'
6
+ A cat holding a sign that says '54321'
7
+ A dog holding a sign that says '65432'
8
+ A horse holding a sign that says '76543'
9
+ A monkey holding a sign that says '87654'
10
+ A rabbit holding a sign that says '98765'
11
+ A cat holding a sign that says '12345'
12
+ A dog holding a sign that says '23456'
13
+ A horse holding a sign that says '34567'
14
+ A monkey holding a sign that says '45678'
15
+ A rabbit holding a sign that says '56789'
16
+ A cat holding a sign that says '54321'
17
+ A dog holding a sign that says '65432'
18
+ A horse holding a sign that says '76543'
19
+ A monkey holding a sign that says '87654'
20
+ A rabbit holding a sign that says '98765'
21
+ A cat holding a sign that says '12345'
22
+ A dog holding a sign that says '23456'
23
+ A horse holding a sign that says '34567'
24
+ A monkey holding a sign that says '45678'
25
+ A rabbit holding a sign that says '56789'
26
+ A cat holding a sign that says '54321'
27
+ A dog holding a sign that says '65432'
28
+ A horse holding a sign that says '76543'
29
+ A monkey holding a sign that says '87654'
30
+ A rabbit holding a sign that says '98765'
31
+ A cat holding a sign that says '12345'
32
+ A dog holding a sign that says '23456'
33
+ A horse holding a sign that says '34567'
34
+ A monkey holding a sign that says '45678'
35
+ A rabbit holding a sign that says '56789'
36
+ A cat holding a sign that says '54321'
37
+ A dog holding a sign that says '65432'
38
+ A horse holding a sign that says '76543'
39
+ A monkey holding a sign that says '87654'
40
+ A rabbit holding a sign that says '98765'
41
+ A cat holding a sign that says '12345'
42
+ A dog holding a sign that says '23456'
43
+ A horse holding a sign that says '34567'
44
+ A monkey holding a sign that says '45678'
45
+ A rabbit holding a sign that says '56789'
46
+ A cat holding a sign that says '54321'
47
+ A dog holding a sign that says '65432'
48
+ A horse holding a sign that says '76543'
49
+ A monkey holding a sign that says '87654'
50
+ A rabbit holding a sign that says '98765'
adv_grpo/assets/test.jpg ADDED
adv_grpo/clip_scorer.py ADDED
@@ -0,0 +1,97 @@
+ # Based on https://github.com/RE-N-Y/imscore/blob/main/src/imscore/preference/model.py
+
+ from importlib import resources
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import torchvision.transforms as T
+ from transformers import AutoImageProcessor, CLIPProcessor, CLIPModel
+ import numpy as np
+ from PIL import Image
+
+ def get_size(size):
+     if isinstance(size, int):
+         return (size, size)
+     elif "height" in size and "width" in size:
+         return (size["height"], size["width"])
+     elif "shortest_edge" in size:
+         return size["shortest_edge"]
+     else:
+         raise ValueError(f"Invalid size: {size}")
+
+ def get_image_transform(processor: AutoImageProcessor):
+     config = processor.to_dict()
+     resize = T.Resize(get_size(config.get("size"))) if config.get("do_resize") else nn.Identity()
+     crop = T.CenterCrop(get_size(config.get("crop_size"))) if config.get("do_center_crop") else nn.Identity()
+     normalise = T.Normalize(mean=processor.image_mean, std=processor.image_std) if config.get("do_normalize") else nn.Identity()
+
+     return T.Compose([resize, crop, normalise])
+
+ class ClipScorer(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         # self.device="cuda"
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+         self.tform = get_image_transform(self.processor.image_processor)
+         self.eval()
+
+     def _process(self, pixels):
+         dtype = pixels.dtype
+         pixels = self.tform(pixels)
+         pixels = pixels.to(dtype=dtype)
+
+         return pixels
+
+     @torch.no_grad()
+     def __call__(self, pixels, prompts, return_img_embedding=False):
+         device = next(self.parameters()).device
+         texts = self.processor(text=prompts, padding='max_length', truncation=True, return_tensors="pt").to(device)
+         pixels = self._process(pixels).to(device)
+         outputs = self.model(pixel_values=pixels, **texts)
+         if return_img_embedding:
+             return outputs.logits_per_image.diagonal()/30, outputs.image_embeds
+         return outputs.logits_per_image.diagonal()/30
+
+     @torch.no_grad()
+     def image_similarity(self, pixels, ref_pixels):
+         device = next(self.parameters()).device
+         pixels = self._process(pixels).to(device)
+         ref_pixels = self._process(ref_pixels).to(device)
+
+         pixel_embeds = self.model.get_image_features(pixel_values=pixels)
+         ref_embeds = self.model.get_image_features(pixel_values=ref_pixels)
+
+         pixel_embeds = pixel_embeds / pixel_embeds.norm(p=2, dim=-1, keepdim=True)
+         ref_embeds = ref_embeds / ref_embeds.norm(p=2, dim=-1, keepdim=True)
+
+         sim = pixel_embeds @ ref_embeds.T
+         # sim = torch.diagonal(sim, 0)
+         sim = sim.squeeze(-1)
+         return sim
+
+
+ def main():
+     # scorer = ClipScorer(
+     #     device='cuda'
+     # )
+     scorer = ClipScorer(
+     )
+
+     images = [
+         "assets/test.jpg",
+         "assets/test.jpg"
+     ]
+     pil_images = [Image.open(img) for img in images]
+     prompts = [
+         'an image of cat',
+         'not an image of cat'
+     ]
+     images = [np.array(img) for img in pil_images]
+     images = np.array(images)
+     images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
+     images = torch.tensor(images, dtype=torch.uint8)/255.0
+     print(scorer(images, prompts))
+
+ if __name__ == "__main__":
+     main()
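The main() above already exercises the prompt-image score; here is a similar minimal sketch for the image_similarity path (an illustration, not one of the uploaded files; the to_batch helper is only for this sketch, which follows the same NHWC-to-NCHW float convention as main() and reuses adv_grpo/assets/test.jpg):

import numpy as np
import torch
from PIL import Image
from adv_grpo.clip_scorer import ClipScorer

scorer = ClipScorer()

def to_batch(path):
    # PIL image -> NCHW float tensor in [0, 1], batch of one, as in main() above.
    arr = np.array(Image.open(path)).transpose(2, 0, 1)[None]
    return torch.tensor(arr, dtype=torch.uint8) / 255.0

pixels = to_batch("adv_grpo/assets/test.jpg")
ref_pixels = to_batch("adv_grpo/assets/test.jpg")
print(scorer.image_similarity(pixels, ref_pixels))  # CLIP cosine similarity, ~1.0 for identical images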
adv_grpo/conv_gradfix.py ADDED
@@ -0,0 +1,345 @@
1
+ """
2
+ Custom replacement for `torch.nn.functional.convNd` and `torch.nn.functional.conv_transposeNd`
3
+ that supports arbitrarily high order gradients with zero performance penalty.
4
+ Modified from https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py
5
+ """
6
+
7
+ import contextlib
8
+ import warnings
9
+ from typing import Optional
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import Tensor
13
+ from torch.nn import Conv2d, Conv3d
14
+
15
+ # pylint: disable=redefined-builtin
16
+ # pylint: disable=arguments-differ
17
+ # pylint: disable=protected-access
18
+
19
+ # ----------------------------------------------------------------------------
20
+
21
+ enabled = False # Enable the custom op by setting this to true.
22
+ weight_gradients_disabled = (
23
+ False # Forcefully disable computation of gradients with respect to the weights.
24
+ )
25
+
26
+
27
+ @contextlib.contextmanager
28
+ def no_weight_gradients():
29
+ global weight_gradients_disabled
30
+ old = weight_gradients_disabled
31
+ weight_gradients_disabled = True
32
+ yield
33
+ weight_gradients_disabled = old
34
+
35
+
36
+ # ----------------------------------------------------------------------------
37
+ class GradFixConv2d(Conv2d):
38
+ def __init__(self, *args, use_gradfix: bool = False, **kwargs):
39
+ self.use_gradfix = use_gradfix
40
+ super().__init__(*args, **kwargs)
41
+
42
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
43
+ conv_fn = F.conv2d if not self.use_gradfix else convNd
44
+ if self.padding_mode != "zeros":
45
+ return conv_fn(
46
+ F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
47
+ weight,
48
+ bias,
49
+ self.stride,
50
+ (0, 0),
51
+ self.dilation,
52
+ self.groups,
53
+ )
54
+ return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
55
+
56
+ def forward(
57
+ self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
58
+ ) -> Tensor:
59
+ weight = self.weight if weight is None else weight
60
+ bias = self.bias if bias is None else bias
61
+ return self._conv_forward(input, weight, bias)
62
+
63
+
64
+ class GradFixConv3d(Conv3d):
65
+ def __init__(self, *args, use_gradfix: bool = False, **kwargs):
66
+ self.use_gradfix = use_gradfix
67
+ super().__init__(*args, **kwargs)
68
+
69
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
70
+ conv_fn = F.conv3d if not self.use_gradfix else convNd
71
+ if self.padding_mode != "zeros":
72
+ return conv_fn(
73
+ F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
74
+ weight,
75
+ bias,
76
+ self.stride,
77
+ (0, 0, 0),
78
+ self.dilation,
79
+ self.groups,
80
+ )
81
+ return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
82
+
83
+ def forward(
84
+ self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
85
+ ) -> Tensor:
86
+ weight = self.weight if weight is None else weight
87
+ bias = self.bias if bias is None else bias
88
+ return self._conv_forward(input, weight, bias)
89
+
90
+
91
+ # ----------------------------------------------------------------------------
92
+
93
+
94
+ def convNd(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
95
+ N = weight.ndim - 2
96
+ if _should_use_custom_op(input):
97
+ return _conv_gradfix(
98
+ transpose=False,
99
+ weight_shape=weight.shape,
100
+ stride=stride,
101
+ padding=padding,
102
+ output_padding=0,
103
+ dilation=dilation,
104
+ groups=groups,
105
+ ).apply(input, weight, bias)
106
+ return getattr(torch.nn.functional, f"conv{N}d")(
107
+ input=input,
108
+ weight=weight,
109
+ bias=bias,
110
+ stride=stride,
111
+ padding=padding,
112
+ dilation=dilation,
113
+ groups=groups,
114
+ )
115
+
116
+
117
+ def conv_transposeNd(
118
+ input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1
119
+ ):
120
+ N = weight.ndim - 2
121
+ if _should_use_custom_op(input):
122
+ return _conv_gradfix(
123
+ transpose=True,
124
+ weight_shape=weight.shape,
125
+ stride=stride,
126
+ padding=padding,
127
+ output_padding=output_padding,
128
+ groups=groups,
129
+ dilation=dilation,
130
+ ).apply(input, weight, bias)
131
+ return getattr(torch.nn.functional, f"conv_transpose{N}d")(
132
+ input=input,
133
+ weight=weight,
134
+ bias=bias,
135
+ stride=stride,
136
+ padding=padding,
137
+ output_padding=output_padding,
138
+ groups=groups,
139
+ dilation=dilation,
140
+ )
141
+
142
+
143
+ # ----------------------------------------------------------------------------
144
+
145
+
146
+ def _should_use_custom_op(input):
147
+ assert isinstance(input, torch.Tensor)
148
+ if (not enabled) or (not torch.backends.cudnn.enabled):
149
+ return False
150
+ if input.device.type != "cuda":
151
+ return False
152
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9"]):
153
+ return True
154
+ if torch.__version__.startswith("2"):
155
+ return True
156
+ warnings.warn(
157
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. "
158
+ f"Falling back to torch.nn.functional.conv2d()."
159
+ )
160
+ return False
161
+
162
+
163
+ def _tuple_of_ints(xs, ndim):
164
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
165
+ assert len(xs) == ndim
166
+ assert all(isinstance(x, int) for x in xs)
167
+ return xs
168
+
169
+
170
+ # ----------------------------------------------------------------------------
171
+
172
+ _conv_gradfix_cache = dict()
173
+
174
+
175
+ def _conv_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
176
+ ndim = len(weight_shape) - 2
177
+ # Parse arguments.
178
+ weight_shape = tuple(weight_shape)
179
+ stride = _tuple_of_ints(stride, ndim)
180
+ padding = _tuple_of_ints(padding, ndim)
181
+ output_padding = _tuple_of_ints(output_padding, ndim)
182
+ dilation = _tuple_of_ints(dilation, ndim)
183
+
184
+ # Lookup from cache.
185
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
186
+ if key in _conv_gradfix_cache:
187
+ return _conv_gradfix_cache[key]
188
+
189
+ # Validate arguments.
190
+ assert groups >= 1
191
+ assert all(stride[i] >= 1 for i in range(ndim))
192
+ assert all(padding[i] >= 0 for i in range(ndim))
193
+ assert all(dilation[i] >= 0 for i in range(ndim))
194
+ if not transpose:
195
+ assert all(output_padding[i] == 0 for i in range(ndim))
196
+ else: # transpose
197
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
198
+
199
+ # Helpers.
200
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
201
+
202
+ def calc_output_padding(input_shape, output_shape):
203
+ if transpose:
204
+ return [
205
+ 0,
206
+ ] * ndim
207
+ return [
208
+ input_shape[i + 2]
209
+ - (output_shape[i + 2] - 1) * stride[i]
210
+ - (1 - 2 * padding[i])
211
+ - dilation[i] * (weight_shape[i + 2] - 1)
212
+ for i in range(ndim)
213
+ ]
214
+
215
+ # Forward & backward.
216
+ class ConvNd(torch.autograd.Function):
217
+ @staticmethod
218
+ def forward(ctx, input, weight, bias):
219
+ """
220
+ input size: [B, C, ...]
221
+ weight size:
222
+ -> Conv: [C_out, C_in // groups, ...]
223
+ -> Transpose: [C_in, C_out // groups, ...]
224
+ """
225
+ assert weight.shape == weight_shape
226
+ ctx.save_for_backward(input, weight)
227
+
228
+ # General case => cuDNN.
229
+ if transpose:
230
+ return getattr(torch.nn.functional, f"conv_transpose{ndim}d")(
231
+ input=input,
232
+ weight=weight.to(input.dtype),
233
+ bias=bias,
234
+ output_padding=output_padding,
235
+ **common_kwargs,
236
+ )
237
+ return getattr(torch.nn.functional, f"conv{ndim}d")(
238
+ input=input, weight=weight.to(input.dtype), bias=bias, **common_kwargs
239
+ )
240
+
241
+ @staticmethod
242
+ def backward(ctx, grad_output):
243
+ input, weight = ctx.saved_tensors
244
+ grad_input = None
245
+ grad_weight = None
246
+ grad_bias = None
247
+
248
+ if ctx.needs_input_grad[0]: # Input
249
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
250
+ op = _conv_gradfix(
251
+ transpose=(not transpose),
252
+ weight_shape=weight_shape,
253
+ output_padding=p,
254
+ **common_kwargs,
255
+ )
256
+ grad_input = op.apply(grad_output, weight, None)
257
+ assert grad_input.shape == input.shape
258
+
259
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled: # Weight
260
+ grad_weight = ConvNdGradWeight.apply(grad_output, input)
261
+ assert grad_weight.shape == weight_shape
262
+
263
+ if ctx.needs_input_grad[2]: # Bias
264
+ grad_bias = grad_output.transpose(0, 1).flatten(1).sum(1)
265
+
266
+ return grad_input, grad_weight, grad_bias
267
+
268
+ # Gradient with respect to the weights.
269
+ class ConvNdGradWeight(torch.autograd.Function):
270
+ @staticmethod
271
+ def forward(ctx, grad_output, input):
272
+ flags = [
273
+ torch.backends.cudnn.benchmark,
274
+ torch.backends.cudnn.deterministic,
275
+ torch.backends.cudnn.allow_tf32,
276
+ ]
277
+ if torch.__version__.startswith("1"):
278
+ op = torch._C._jit_get_operation(
279
+ "aten::cudnn_convolution_backward_weight"
280
+ if not transpose
281
+ else "aten::cudnn_convolution_transpose_backward_weight"
282
+ )
283
+ grad_weight = op(
284
+ weight_shape,
285
+ grad_output,
286
+ input.to(grad_output.dtype),
287
+ padding,
288
+ stride,
289
+ dilation,
290
+ groups,
291
+ *flags,
292
+ )
293
+ elif torch.__version__.startswith("2"):
294
+ # https://github.com/pytorch/pytorch/issues/74437
295
+ op, _ = torch._C._jit_get_operation("aten::convolution_backward")
296
+ dummy_weight = torch.tensor(
297
+ 0.0, dtype=grad_output.dtype, device=input.device
298
+ ).expand(weight_shape)
299
+ grad_weight = op(
300
+ grad_output,
301
+ input.to(grad_output.dtype),
302
+ dummy_weight,
303
+ None,
304
+ stride,
305
+ padding,
306
+ dilation,
307
+ transpose,
308
+ (0,) * ndim,
309
+ groups,
310
+ [False, True, False],
311
+ )[1]
312
+ else:
313
+ raise NotImplementedError
314
+ assert grad_weight.shape == weight_shape
315
+ ctx.save_for_backward(grad_output, input)
316
+ return grad_weight
317
+
318
+ @staticmethod
319
+ def backward(ctx, grad2_grad_weight):
320
+ grad_output, input = ctx.saved_tensors
321
+ grad2_grad_output = None
322
+ grad2_input = None
323
+
324
+ if ctx.needs_input_grad[0]: # Grad of Weight
325
+ grad2_grad_output = ConvNd.apply(input, grad2_grad_weight, None)
326
+ assert grad2_grad_output.shape == grad_output.shape
327
+
328
+ if ctx.needs_input_grad[1]: # Input
329
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
330
+ op = _conv_gradfix(
331
+ transpose=(not transpose),
332
+ weight_shape=weight_shape,
333
+ output_padding=p,
334
+ **common_kwargs,
335
+ )
336
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
337
+ assert grad2_input.shape == input.shape
338
+
339
+ return grad2_grad_output, grad2_input
340
+
341
+ _conv_gradfix_cache[key] = ConvNd
342
+ return ConvNd
343
+
344
+
345
+ # ----------------------------------------------------------------------------
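The custom autograd Functions above exist so that second-order gradients through convolutions (for example an R1-style gradient penalty on a discriminator) stay on the fast cuDNN path instead of falling back to a slow reference implementation. Below is a minimal sketch of the double-backward pattern this enables; the `conv2d` wrapper and its import path are assumed to follow the upstream StyleGAN-style module layout and are not confirmed by this diff.

import torch
# Assumed module-level wrapper around _conv_gradfix, as in the upstream StyleGAN code;
# the exact import path in this repo may differ.
from adv_grpo.conv_gradfix import conv2d

x = torch.randn(4, 3, 64, 64, device="cuda", requires_grad=True)
w = torch.randn(8, 3, 3, 3, device="cuda", requires_grad=True)

out = conv2d(input=x, weight=w, padding=1)                      # forward through ConvNd
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)  # first backward, keeps graph
r1_penalty = grad_x.pow(2).sum()                                # gradient-penalty term
r1_penalty.backward()                                           # second backward runs through ConvNdGradWeight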
adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc ADDED
Binary file (15.5 kB).
 
adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc ADDED
Binary file (3.54 kB).
 
adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc ADDED
Binary file (2.48 kB).
 
adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py ADDED
@@ -0,0 +1,255 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
2
+
3
+ from typing import Any, Dict, List, Optional, Union, Callable
4
+ import torch
5
+ import numpy as np
6
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
7
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
8
+ from diffusers.utils import logging
9
+ from .sd3_sde_with_logprob import sde_step_with_logprob
10
+
11
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
12
+
13
+ PREFERRED_KONTEXT_RESOLUTIONS = [
14
+ (672, 1568),
15
+ (688, 1504),
16
+ (720, 1456),
17
+ (752, 1392),
18
+ (800, 1328),
19
+ (832, 1248),
20
+ (880, 1184),
21
+ (944, 1104),
22
+ (1024, 1024),
23
+ (1104, 944),
24
+ (1184, 880),
25
+ (1248, 832),
26
+ (1328, 800),
27
+ (1392, 752),
28
+ (1456, 720),
29
+ (1504, 688),
30
+ (1568, 672),
31
+ ]
32
+
33
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
34
+ def retrieve_latents(
35
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
36
+ ):
37
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
38
+ return encoder_output.latent_dist.sample(generator)
39
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
40
+ return encoder_output.latent_dist.mode()
41
+ elif hasattr(encoder_output, "latents"):
42
+ return encoder_output.latents
43
+ else:
44
+ raise AttributeError("Could not access latents of provided encoder_output")
45
+
46
+ def calculate_shift(
47
+ image_seq_len,
48
+ base_seq_len: int = 256,
49
+ max_seq_len: int = 4096,
50
+ base_shift: float = 0.5,
51
+ max_shift: float = 1.15,
52
+ ):
53
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
54
+ b = base_shift - m * base_seq_len
55
+ mu = image_seq_len * m + b
56
+ return mu
57
+
58
+ @torch.no_grad()
59
+ def pipeline_with_logprob(
60
+ self,
61
+ image: Optional[PipelineImageInput] = None,
62
+ prompt: Union[str, List[str]] = None,
63
+ prompt_2: Optional[Union[str, List[str]]] = None,
64
+ negative_prompt: Union[str, List[str]] = None,
65
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
66
+ height: Optional[int] = None,
67
+ width: Optional[int] = None,
68
+ num_inference_steps: int = 28,
69
+ sigmas: Optional[List[float]] = None,
70
+ guidance_scale: float = 3.5,
71
+ num_images_per_prompt: Optional[int] = 1,
72
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
73
+ latents: Optional[torch.FloatTensor] = None,
74
+ prompt_embeds: Optional[torch.FloatTensor] = None,
75
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
76
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
77
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
78
+ output_type: Optional[str] = "pil",
79
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
80
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
81
+ max_sequence_length: int = 512,
82
+ max_area: int = 1024**2,
83
+ _auto_resize: bool = True,
84
+ noise_level: float = 0.7,
85
+ ):
86
+ height = height or self.default_sample_size * self.vae_scale_factor
87
+ width = width or self.default_sample_size * self.vae_scale_factor
88
+
89
+ original_height, original_width = height, width
90
+ aspect_ratio = width / height
91
+ width = round((max_area * aspect_ratio) ** 0.5)
92
+ height = round((max_area / aspect_ratio) ** 0.5)
93
+
94
+ multiple_of = self.vae_scale_factor * 2
95
+ width = width // multiple_of * multiple_of
96
+ height = height // multiple_of * multiple_of
97
+
98
+ if height != original_height or width != original_width:
99
+ logger.warning(
100
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
101
+ )
102
+
103
+ # 1. Check inputs. Raise error if not correct
104
+ self.check_inputs(
105
+ prompt,
106
+ prompt_2,
107
+ height,
108
+ width,
109
+ negative_prompt=negative_prompt,
110
+ negative_prompt_2=negative_prompt_2,
111
+ prompt_embeds=prompt_embeds,
112
+ negative_prompt_embeds=negative_prompt_embeds,
113
+ pooled_prompt_embeds=pooled_prompt_embeds,
114
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
115
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
116
+ max_sequence_length=max_sequence_length,
117
+ )
118
+
119
+ self._guidance_scale = guidance_scale
120
+ self._joint_attention_kwargs = joint_attention_kwargs
121
+ self._current_timestep = None
122
+ self._interrupt = False
123
+
124
+ # 2. Define call parameters
125
+ if prompt is not None and isinstance(prompt, str):
126
+ batch_size = 1
127
+ elif prompt is not None and isinstance(prompt, list):
128
+ batch_size = len(prompt)
129
+ else:
130
+ batch_size = prompt_embeds.shape[0]
131
+
132
+ device = self._execution_device
133
+
134
+ lora_scale = (
135
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
136
+ )
137
+ (
138
+ prompt_embeds,
139
+ pooled_prompt_embeds,
140
+ text_ids,
141
+ ) = self.encode_prompt(
142
+ prompt=prompt,
143
+ prompt_2=prompt_2,
144
+ prompt_embeds=prompt_embeds,
145
+ pooled_prompt_embeds=pooled_prompt_embeds,
146
+ device=device,
147
+ num_images_per_prompt=num_images_per_prompt,
148
+ max_sequence_length=max_sequence_length,
149
+ lora_scale=lora_scale,
150
+ )
151
+
152
+ # 3. Preprocess image
153
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
154
+ image = self.image_processor.resize(image, height, width)
155
+ image = self.image_processor.preprocess(image, height, width)
156
+ # 4. Prepare latent variables
157
+ num_channels_latents = self.transformer.config.in_channels // 4
158
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
159
+ image.float(),
160
+ batch_size * num_images_per_prompt,
161
+ num_channels_latents,
162
+ height,
163
+ width,
164
+ prompt_embeds.dtype,
165
+ device,
166
+ generator,
167
+ latents,
168
+ )
169
+ if image_ids is not None:
170
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
171
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
172
+ image_seq_len = latents.shape[1]
173
+ mu = calculate_shift(
174
+ image_seq_len,
175
+ self.scheduler.config.get("base_image_seq_len", 256),
176
+ self.scheduler.config.get("max_image_seq_len", 4096),
177
+ self.scheduler.config.get("base_shift", 0.5),
178
+ self.scheduler.config.get("max_shift", 1.15),
179
+ )
180
+ timesteps, num_inference_steps = retrieve_timesteps(
181
+ self.scheduler,
182
+ num_inference_steps,
183
+ device,
184
+ sigmas=sigmas,
185
+ mu=mu,
186
+ )
187
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
188
+ self._num_timesteps = len(timesteps)
189
+
190
+ # handle guidance
191
+ if self.transformer.config.guidance_embeds:
192
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
193
+ guidance = guidance.expand(latents.shape[0])
194
+ else:
195
+ guidance = None
196
+
197
+ # 6. Prepare image embeddings
198
+ all_latents = [latents]
199
+ all_log_probs = []
200
+
201
+ # 7. Denoising loop
202
+ self.scheduler.set_begin_index(0)
203
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
204
+ for i, t in enumerate(timesteps):
205
+ if self.interrupt:
206
+ continue
207
+ self._current_timestep = t
208
+ latent_model_input = latents
209
+ if image_latents is not None:
210
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
211
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
212
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
213
+ noise_pred = self.transformer(
214
+ hidden_states=latent_model_input,
215
+ timestep=timestep / 1000,
216
+ guidance=guidance,
217
+ pooled_projections=pooled_prompt_embeds,
218
+ encoder_hidden_states=prompt_embeds,
219
+ txt_ids=text_ids,
220
+ img_ids=latent_ids,
221
+ joint_attention_kwargs=self.joint_attention_kwargs,
222
+ return_dict=False,
223
+ )[0]
224
+ if noise_pred.isnan().any():
225
+ breakpoint()  # debug hook left in place: pauses execution when the transformer outputs NaNs
226
+ print("noise_pred contains nan")
227
+ noise_pred = noise_pred[:, : latents.size(1)]
228
+ latents_dtype = latents.dtype
229
+
230
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
231
+ self.scheduler,
232
+ noise_pred.float(),
233
+ t.unsqueeze(0).repeat(latents.shape[0]),
234
+ latents.float(),
235
+ noise_level=noise_level,
236
+ )
237
+
238
+ if latents.dtype != latents_dtype:
239
+ latents = latents.to(latents_dtype)
240
+ all_latents.append(latents)
241
+ all_log_probs.append(log_prob)
242
+ # call the callback, if provided
243
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
244
+ progress_bar.update()
245
+
246
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
247
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
248
+ latents = latents.to(dtype=self.vae.dtype)
249
+ image = self.vae.decode(latents, return_dict=False)[0]
250
+ image = self.image_processor.postprocess(image, output_type=output_type)
251
+
252
+ # Offload all models
253
+ self.maybe_free_model_hooks()
254
+
255
+ return image, all_latents, latent_ids, text_ids, all_log_probs, image_latents
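The per-step log-probs returned above are what make this sampler usable as a stochastic policy for RL-style fine-tuning. Below is a minimal, illustrative sketch of how such log-probs typically enter a clipped policy-gradient (GRPO/PPO-style) objective; the function name and clip value are illustrative and not taken from this repo's trainer.

import torch

def clipped_pg_loss(new_log_probs, old_log_probs, advantages, clip_range=1e-4):
    # ratio between the current policy and the policy that generated the samples
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # pessimistic (clipped) policy-gradient surrogate, averaged over the batch
    return torch.maximum(unclipped, clipped).mean()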
adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py
2
+
3
+ from typing import Any, Dict, List, Optional, Union, Callable
4
+ import torch
5
+ import numpy as np
6
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
7
+ from .sd3_sde_with_logprob import sde_step_with_logprob
8
+
9
+ def calculate_shift(
10
+ image_seq_len,
11
+ base_seq_len: int = 256,
12
+ max_seq_len: int = 4096,
13
+ base_shift: float = 0.5,
14
+ max_shift: float = 1.15,
15
+ ):
16
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
17
+ b = base_shift - m * base_seq_len
18
+ mu = image_seq_len * m + b
19
+ return mu
20
+
21
+ @torch.no_grad()
22
+ def pipeline_with_logprob(
23
+ self,
24
+ prompt: Union[str, List[str]] = None,
25
+ prompt_2: Optional[Union[str, List[str]]] = None,
26
+ negative_prompt: Union[str, List[str]] = None,
27
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
28
+ height: Optional[int] = None,
29
+ width: Optional[int] = None,
30
+ num_inference_steps: int = 28,
31
+ sigmas: Optional[List[float]] = None,
32
+ guidance_scale: float = 3.5,
33
+ num_images_per_prompt: Optional[int] = 1,
34
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
35
+ latents: Optional[torch.FloatTensor] = None,
36
+ prompt_embeds: Optional[torch.FloatTensor] = None,
37
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
38
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
39
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
40
+ output_type: Optional[str] = "pil",
41
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
42
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
43
+ max_sequence_length: int = 512,
44
+ noise_level: float = 0.7,
45
+ ):
46
+ height = height or self.default_sample_size * self.vae_scale_factor
47
+ width = width or self.default_sample_size * self.vae_scale_factor
48
+
49
+ # 1. Check inputs. Raise error if not correct
50
+ self.check_inputs(
51
+ prompt,
52
+ prompt_2,
53
+ height,
54
+ width,
55
+ negative_prompt=negative_prompt,
56
+ negative_prompt_2=negative_prompt_2,
57
+ prompt_embeds=prompt_embeds,
58
+ negative_prompt_embeds=negative_prompt_embeds,
59
+ pooled_prompt_embeds=pooled_prompt_embeds,
60
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
61
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
62
+ max_sequence_length=max_sequence_length,
63
+ )
64
+
65
+ self._guidance_scale = guidance_scale
66
+ self._joint_attention_kwargs = joint_attention_kwargs
67
+ self._current_timestep = None
68
+ self._interrupt = False
69
+
70
+ # 2. Define call parameters
71
+ if prompt is not None and isinstance(prompt, str):
72
+ batch_size = 1
73
+ elif prompt is not None and isinstance(prompt, list):
74
+ batch_size = len(prompt)
75
+ else:
76
+ batch_size = prompt_embeds.shape[0]
77
+
78
+ device = self._execution_device
79
+
80
+ lora_scale = (
81
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
82
+ )
83
+ (
84
+ prompt_embeds,
85
+ pooled_prompt_embeds,
86
+ text_ids,
87
+ ) = self.encode_prompt(
88
+ prompt=prompt,
89
+ prompt_2=prompt_2,
90
+ prompt_embeds=prompt_embeds,
91
+ pooled_prompt_embeds=pooled_prompt_embeds,
92
+ device=device,
93
+ num_images_per_prompt=num_images_per_prompt,
94
+ max_sequence_length=max_sequence_length,
95
+ lora_scale=lora_scale,
96
+ )
97
+
98
+ # 4. Prepare latent variables
99
+ num_channels_latents = self.transformer.config.in_channels // 4
100
+ latents, latent_image_ids = self.prepare_latents(
101
+ batch_size * num_images_per_prompt,
102
+ num_channels_latents,
103
+ height,
104
+ width,
105
+ prompt_embeds.dtype,
106
+ device,
107
+ generator,
108
+ latents,
109
+ )
110
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
111
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
112
+ sigmas = None
113
+ image_seq_len = latents.shape[1]
114
+ mu = calculate_shift(
115
+ image_seq_len,
116
+ self.scheduler.config.get("base_image_seq_len", 256),
117
+ self.scheduler.config.get("max_image_seq_len", 4096),
118
+ self.scheduler.config.get("base_shift", 0.5),
119
+ self.scheduler.config.get("max_shift", 1.15),
120
+ )
121
+ timesteps, num_inference_steps = retrieve_timesteps(
122
+ self.scheduler,
123
+ num_inference_steps,
124
+ device,
125
+ sigmas=sigmas,
126
+ mu=mu,
127
+ )
128
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
129
+ self._num_timesteps = len(timesteps)
130
+
131
+ # handle guidance
132
+ if self.transformer.config.guidance_embeds:
133
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
134
+ guidance = guidance.expand(latents.shape[0])
135
+ else:
136
+ guidance = None
137
+
138
+ # 6. Prepare image embeddings
139
+ all_latents = [latents]
140
+ all_log_probs = []
141
+
142
+ # 7. Denoising loop
143
+ self.scheduler.set_begin_index(0)
144
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
145
+ for i, t in enumerate(timesteps):
146
+ if self.interrupt:
147
+ continue
148
+ self._current_timestep = t
149
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
150
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
151
+ noise_pred = self.transformer(
152
+ hidden_states=latents,
153
+ timestep=timestep / 1000,
154
+ guidance=guidance,
155
+ pooled_projections=pooled_prompt_embeds,
156
+ encoder_hidden_states=prompt_embeds,
157
+ txt_ids=text_ids,
158
+ img_ids=latent_image_ids,
159
+ joint_attention_kwargs=self.joint_attention_kwargs,
160
+ return_dict=False,
161
+ )[0]
162
+ latents_dtype = latents.dtype
163
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
164
+ self.scheduler,
165
+ noise_pred.float(),
166
+ t.unsqueeze(0).repeat(latents.shape[0]),
167
+ latents.float(),
168
+ noise_level=noise_level,
169
+ )
170
+ if latents.dtype != latents_dtype:
171
+ latents = latents.to(latents_dtype)
172
+ all_latents.append(latents)
173
+ all_log_probs.append(log_prob)
174
+ # call the callback, if provided
175
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
176
+ progress_bar.update()
177
+
178
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
179
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
180
+ latents = latents.to(dtype=self.vae.dtype)
181
+ image = self.vae.decode(latents, return_dict=False)[0]
182
+ image = self.image_processor.postprocess(image, output_type=output_type)
183
+
184
+ # Offload all models
185
+ self.maybe_free_model_hooks()
186
+
187
+ return image, all_latents, latent_image_ids, text_ids, all_log_probs
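`calculate_shift` above is a linear interpolation of the scheduler shift `mu` in the packed image token count: it returns `base_shift` at `base_seq_len` tokens and `max_shift` at `max_seq_len` tokens. A small self-contained check (illustrative, using the default config values) is below.

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

assert abs(calculate_shift(256) - 0.5) < 1e-9    # base sequence length -> base_shift
assert abs(calculate_shift(4096) - 1.15) < 1e-9  # max sequence length  -> max_shift
print(round(calculate_shift(1024), 3))           # ~0.63 for a 1024-token packed latent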
adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py ADDED
@@ -0,0 +1,198 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
2
+ # with the following modifications:
3
+ # - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`.
4
+ # - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
5
+ from typing import Any, Dict, List, Optional, Union
6
+ import torch
7
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
8
+ from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob
9
+
10
+ @torch.no_grad()
11
+ def pipeline_with_logprob(
12
+ self,
13
+ prompt: Union[str, List[str]] = None,
14
+ prompt_2: Optional[Union[str, List[str]]] = None,
15
+ prompt_3: Optional[Union[str, List[str]]] = None,
16
+ height: Optional[int] = None,
17
+ width: Optional[int] = None,
18
+ num_inference_steps: int = 28,
19
+ sigmas: Optional[List[float]] = None,
20
+ guidance_scale: float = 7.0,
21
+ negative_prompt: Optional[Union[str, List[str]]] = None,
22
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
23
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
24
+ num_images_per_prompt: Optional[int] = 1,
25
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
26
+ latents: Optional[torch.FloatTensor] = None,
27
+ prompt_embeds: Optional[torch.FloatTensor] = None,
28
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
29
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
30
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
31
+ output_type: Optional[str] = "pil",
32
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
33
+ clip_skip: Optional[int] = None,
34
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
35
+ max_sequence_length: int = 256,
36
+ skip_layer_guidance_scale: float = 2.8,
37
+ noise_level: float = 0.7,
38
+ ):
39
+ height = height or self.default_sample_size * self.vae_scale_factor
40
+ width = width or self.default_sample_size * self.vae_scale_factor
41
+
42
+ # 1. Check inputs. Raise error if not correct
43
+ self.check_inputs(
44
+ prompt,
45
+ prompt_2,
46
+ prompt_3,
47
+ height,
48
+ width,
49
+ negative_prompt=negative_prompt,
50
+ negative_prompt_2=negative_prompt_2,
51
+ negative_prompt_3=negative_prompt_3,
52
+ prompt_embeds=prompt_embeds,
53
+ negative_prompt_embeds=negative_prompt_embeds,
54
+ pooled_prompt_embeds=pooled_prompt_embeds,
55
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
56
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
57
+ max_sequence_length=max_sequence_length,
58
+ )
59
+
60
+ self._guidance_scale = guidance_scale
61
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
62
+ self._clip_skip = clip_skip
63
+ self._joint_attention_kwargs = joint_attention_kwargs
64
+ self._interrupt = False
65
+
66
+ # 2. Define call parameters
67
+ if prompt is not None and isinstance(prompt, str):
68
+ batch_size = 1
69
+ elif prompt is not None and isinstance(prompt, list):
70
+ batch_size = len(prompt)
71
+ else:
72
+ batch_size = prompt_embeds.shape[0]
73
+
74
+ device = self._execution_device
75
+
76
+ lora_scale = (
77
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
78
+ )
79
+ (
80
+ prompt_embeds,
81
+ negative_prompt_embeds,
82
+ pooled_prompt_embeds,
83
+ negative_pooled_prompt_embeds,
84
+ ) = self.encode_prompt(
85
+ prompt=prompt,
86
+ prompt_2=prompt_2,
87
+ prompt_3=prompt_3,
88
+ negative_prompt=negative_prompt,
89
+ negative_prompt_2=negative_prompt_2,
90
+ negative_prompt_3=negative_prompt_3,
91
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
92
+ prompt_embeds=prompt_embeds,
93
+ negative_prompt_embeds=negative_prompt_embeds,
94
+ pooled_prompt_embeds=pooled_prompt_embeds,
95
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
96
+ device=device,
97
+ clip_skip=self.clip_skip,
98
+ num_images_per_prompt=num_images_per_prompt,
99
+ max_sequence_length=max_sequence_length,
100
+ lora_scale=lora_scale,
101
+ )
102
+ if self.do_classifier_free_guidance:
103
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
104
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
105
+
106
+ # 4. Prepare latent variables
107
+ num_channels_latents = self.transformer.config.in_channels
108
+ # latents = self.prepare_latents(
109
+ # batch_size * num_images_per_prompt,
110
+ # num_channels_latents,
111
+ # height,
112
+ # width,
113
+ # prompt_embeds.dtype,
114
+ # device,
115
+ # generator,
116
+ # latents,
117
+ # ).float()
118
+ latents = self.prepare_latents(
119
+ batch_size * num_images_per_prompt,
120
+ num_channels_latents,
121
+ height,
122
+ width,
123
+ prompt_embeds.dtype,
124
+ device,
125
+ generator,
126
+ latents,
127
+ )
128
+
129
+ # 5. Prepare timesteps
130
+ scheduler_kwargs = {}
131
+ timesteps, num_inference_steps = retrieve_timesteps(
132
+ self.scheduler,
133
+ num_inference_steps,
134
+ device,
135
+ sigmas=sigmas,
136
+ **scheduler_kwargs,
137
+ )
138
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
139
+ self._num_timesteps = len(timesteps)
140
+
141
+ # 6. Prepare image embeddings
142
+ all_latents = [latents]
143
+ all_log_probs = []
144
+ # import pdb; pdb.set_trace()
145
+
146
+ # 7. Denoising loop
147
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
148
+ for i, t in enumerate(timesteps):
149
+ if self.interrupt:
150
+ continue
151
+
152
+ # expand the latents if we are doing classifier free guidance
153
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
154
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
155
+ timestep = t.expand(latent_model_input.shape[0])
156
+ # import pdb; pdb.set_trace()
157
+ noise_pred = self.transformer(
158
+ hidden_states=latent_model_input,
159
+ timestep=timestep,
160
+ encoder_hidden_states=prompt_embeds,
161
+ pooled_projections=pooled_prompt_embeds,
162
+ joint_attention_kwargs=self.joint_attention_kwargs,
163
+ return_dict=False,
164
+ )[0]
165
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
166
+ # perform guidance
167
+ if self.do_classifier_free_guidance:
168
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
169
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
170
+
171
+ latents_dtype = latents.dtype
172
+
173
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
174
+ self.scheduler,
175
+ noise_pred.float(),
176
+ t.unsqueeze(0),
177
+ latents.float(),
178
+ noise_level=noise_level,
179
+ )
180
+
181
+ all_latents.append(latents)
182
+ all_log_probs.append(log_prob)
183
+ if latents.dtype != latents_dtype:
184
+ latents = latents.to(latents_dtype)
185
+
186
+ # call the callback, if provided
187
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
188
+ progress_bar.update()
189
+
190
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
191
+ latents = latents.to(dtype=self.vae.dtype)
192
+ image = self.vae.decode(latents, return_dict=False)[0]
193
+ image = self.image_processor.postprocess(image, output_type=output_type)
194
+
195
+ # Offload all models
196
+ self.maybe_free_model_hooks()
197
+
198
+ return image, all_latents, all_log_probs
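Downstream, a trainer normally pairs consecutive entries of `all_latents` so that step t maps `all_latents[t]` to `all_latents[t + 1]`, and later re-evaluates the log-prob of that exact transition under the updated policy (typically by feeding the stored next latent back into the SDE step). The packing is sketched below; this is an illustrative helper, not code from this repo.

import torch

def pack_transitions(all_latents, all_log_probs):
    # all_latents: list of length T + 1 with tensors of shape (batch, C, H, W)
    # all_log_probs: list of length T with tensors of shape (batch,)
    latents = torch.stack(all_latents, dim=1)           # (batch, T + 1, C, H, W)
    next_latents = latents[:, 1:]                       # targets x_{t-1}
    latents = latents[:, :-1]                           # inputs  x_t
    old_log_probs = torch.stack(all_log_probs, dim=1)   # (batch, T)
    return latents, next_latents, old_log_probs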
adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py ADDED
@@ -0,0 +1,1081 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
2
+ # with the following modifications:
3
+ # - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`.
4
+ # - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
5
+ from typing import Any, Dict, List, Optional, Union
6
+ import torch
7
+ import random
8
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
9
+ from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob
10
+ from PIL import Image
11
+ from torchvision import transforms
12
+
13
+
14
+
15
+ @torch.no_grad()
16
+ def pipeline_with_logprob(
17
+ self,
18
+ prompt: Union[str, List[str]] = None,
19
+ prompt_2: Optional[Union[str, List[str]]] = None,
20
+ prompt_3: Optional[Union[str, List[str]]] = None,
21
+ height: Optional[int] = None,
22
+ width: Optional[int] = None,
23
+ num_inference_steps: int = 28,
24
+ mini_num_image_per_prompt: int = 1,
25
+ sigmas: Optional[List[float]] = None,
26
+ guidance_scale: float = 7.0,
27
+ negative_prompt: Optional[Union[str, List[str]]] = None,
28
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
29
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
30
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
31
+ latents: Optional[torch.FloatTensor] = None,
32
+ prompt_embeds: Optional[torch.FloatTensor] = None,
33
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
34
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
35
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
36
+ output_type: Optional[str] = "pil",
37
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
38
+ clip_skip: Optional[int] = None,
39
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
40
+ max_sequence_length: int = 256,
41
+ skip_layer_guidance_scale: float = 2.8,
42
+ noise_level: float = 0.7,
43
+ train_num_steps: int = 1,
44
+ process_index: int = 0,
45
+ sample_num_steps: int = 10,
46
+ random_timestep: Optional[int] = None,
47
+ ):
48
+ height = height or self.default_sample_size * self.vae_scale_factor
49
+ width = width or self.default_sample_size * self.vae_scale_factor
50
+
51
+ # 1. Check inputs. Raise error if not correct
52
+ self.check_inputs(
53
+ prompt,
54
+ prompt_2,
55
+ prompt_3,
56
+ height,
57
+ width,
58
+ negative_prompt=negative_prompt,
59
+ negative_prompt_2=negative_prompt_2,
60
+ negative_prompt_3=negative_prompt_3,
61
+ prompt_embeds=prompt_embeds,
62
+ negative_prompt_embeds=negative_prompt_embeds,
63
+ pooled_prompt_embeds=pooled_prompt_embeds,
64
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
65
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
66
+ max_sequence_length=max_sequence_length,
67
+ )
68
+
69
+ self._guidance_scale = guidance_scale
70
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
71
+ self._clip_skip = clip_skip
72
+ self._joint_attention_kwargs = joint_attention_kwargs
73
+ self._interrupt = False
74
+
75
+ # 2. Define call parameters
76
+ if prompt is not None and isinstance(prompt, str):
77
+ batch_size = 1
78
+ elif prompt is not None and isinstance(prompt, list):
79
+ batch_size = len(prompt)
80
+ else:
81
+ batch_size = prompt_embeds.shape[0]
82
+
83
+ device = self._execution_device
84
+
85
+ lora_scale = (
86
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
87
+ )
88
+ (
89
+ prompt_embeds,
90
+ negative_prompt_embeds,
91
+ pooled_prompt_embeds,
92
+ negative_pooled_prompt_embeds,
93
+ ) = self.encode_prompt(
94
+ prompt=prompt,
95
+ prompt_2=prompt_2,
96
+ prompt_3=prompt_3,
97
+ negative_prompt=negative_prompt,
98
+ negative_prompt_2=negative_prompt_2,
99
+ negative_prompt_3=negative_prompt_3,
100
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
101
+ prompt_embeds=prompt_embeds,
102
+ negative_prompt_embeds=negative_prompt_embeds,
103
+ pooled_prompt_embeds=pooled_prompt_embeds,
104
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
105
+ device=device,
106
+ clip_skip=self.clip_skip,
107
+ max_sequence_length=max_sequence_length,
108
+ lora_scale=lora_scale,
109
+ )
110
+ # import pdb; pdb.set_trace()
111
+
112
+ # 4. Prepare latent variables
113
+ num_channels_latents = self.transformer.config.in_channels
114
+ latents = self.prepare_latents(
115
+ batch_size,
116
+ num_channels_latents,
117
+ height,
118
+ width,
119
+ prompt_embeds.dtype,
120
+ device,
121
+ generator,
122
+ latents,
123
+ ).float()
124
+ # import pdb; pdb.set_trace()
125
+ # latents = latents.to(prompt_embeds.dtype)
126
+
127
+ # 5. Prepare timesteps
128
+ scheduler_kwargs = {}
129
+ timesteps, num_inference_steps = retrieve_timesteps(
130
+ self.scheduler,
131
+ num_inference_steps,
132
+ device,
133
+ sigmas=sigmas,
134
+ **scheduler_kwargs,
135
+ )
136
+ # timesteps = timesteps.to(prompt_embeds.dtype)
137
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
138
+ self._num_timesteps = len(timesteps)
139
+
140
+ random.seed(process_index)
141
+ if random_timestep is None:
142
+ random_timestep = random.randint(0, sample_num_steps//2)
143
+
144
+
145
+ # 6. Prepare image embeddings
146
+ all_latents = []
147
+ all_log_probs = []
148
+ all_timesteps = []
149
+
150
+ if self.do_classifier_free_guidance:
151
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
152
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
153
+ # 7. Denoising loop
154
+ # import pdb; pdb.set_trace()
155
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
156
+ # import pdb; pdb.set_trace()
157
+ for i, t in enumerate(timesteps):
158
+ if i < random_timestep:
159
+ cur_noise_level = 0
160
+ elif i == random_timestep:
161
+ cur_noise_level= noise_level
162
+ # repeat the latents mini_num_image_per_prompt times
163
+ latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
164
+ prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
165
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
166
+ negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
167
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
168
+ if self.do_classifier_free_guidance:
169
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
170
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
171
+ all_latents.append(latents)
172
+ elif i > random_timestep and i < random_timestep + train_num_steps:
173
+ cur_noise_level = noise_level
174
+ else:
175
+ cur_noise_level= 0
176
+ # expand the latents if we are doing classifier free guidance
177
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
178
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
179
+ timestep = t.expand(latent_model_input.shape[0])
180
+ # import pdb; pdb.set_trace()
181
+ # noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=tem_prompt_embeds,pooled_projections=tem_pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs,return_dict=False, )[0]
182
+ noise_pred = self.transformer(
183
+ hidden_states=latent_model_input,
184
+ timestep=timestep,
185
+ encoder_hidden_states=tem_prompt_embeds,
186
+ pooled_projections=tem_pooled_prompt_embeds,
187
+ joint_attention_kwargs=self.joint_attention_kwargs,
188
+ return_dict=False,
189
+ )[0]
190
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
191
+ # perform guidance
192
+ if self.do_classifier_free_guidance:
193
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
194
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
195
+
196
+ latents_dtype = latents.dtype
197
+
198
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
199
+ self.scheduler,
200
+ noise_pred.float(),
201
+ t.unsqueeze(0),
202
+ latents.float(),
203
+ noise_level=cur_noise_level,
204
+ )
205
+
206
+ # if latents.dtype != latents_dtype:
207
+ # latents = latents.to(latents_dtype)
208
+
209
+ if i >= random_timestep and i < random_timestep + train_num_steps:
210
+ all_latents.append(latents)
211
+ all_log_probs.append(log_prob)
212
+ all_timesteps.append(t.repeat(len(latents)))
213
+ # import pdb; pdb.set_trace()
214
+
215
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
216
+ progress_bar.update()
217
+
218
+
219
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
220
+ latents = latents.to(dtype=self.vae.dtype)
221
+ image = self.vae.decode(latents, return_dict=False)[0]
222
+ reconstructed_image = self.image_processor.postprocess(image, output_type="pil")
223
+ # reconstructed_image[0].save("0.png")
224
+ # import pdb; pdb.set_trace()
225
+ image = self.image_processor.postprocess(image, output_type=output_type)
226
+
227
+ # Offload all models
228
+ self.maybe_free_model_hooks()
229
+ return image, all_latents, all_log_probs, all_timesteps
230
+
231
+
232
+
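The noise schedule implemented by the branches in the loop above can be summarized as follows: SDE noise (`noise_level`) is injected only inside a window of `train_num_steps` steps starting at `random_timestep`, every step outside that window runs with `noise_level = 0` (a deterministic update), and the latents are fanned out into `mini_num_image_per_prompt` copies when the window opens. A condensed restatement (illustrative, assuming `train_num_steps >= 1`) is below.

def noise_level_for_step(i, random_timestep, train_num_steps, noise_level):
    # stochastic only inside [random_timestep, random_timestep + train_num_steps)
    in_window = random_timestep <= i < random_timestep + train_num_steps
    return noise_level if in_window else 0.0

assert noise_level_for_step(2, 3, 2, 0.7) == 0.0   # before the window: deterministic
assert noise_level_for_step(3, 3, 2, 0.7) == 0.7   # window start: stochastic
assert noise_level_for_step(4, 3, 2, 0.7) == 0.7   # inside the window
assert noise_level_for_step(5, 3, 2, 0.7) == 0.0   # after the window: deterministic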
233
+ @torch.no_grad()
234
+ def pipeline_with_logprob_new(
235
+ self,
236
+ prompt: Union[str, List[str]] = None,
237
+ prompt_2: Optional[Union[str, List[str]]] = None,
238
+ prompt_3: Optional[Union[str, List[str]]] = None,
239
+ height: Optional[int] = None,
240
+ width: Optional[int] = None,
241
+ num_inference_steps: int = 28,
242
+ mini_num_image_per_prompt: int = 1,
243
+ sigmas: Optional[List[float]] = None,
244
+ guidance_scale: float = 7.0,
245
+ negative_prompt: Optional[Union[str, List[str]]] = None,
246
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
247
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
248
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
249
+ latents: Optional[torch.FloatTensor] = None,
250
+ prompt_embeds: Optional[torch.FloatTensor] = None,
251
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
252
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
253
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
254
+ output_type: Optional[str] = "pil",
255
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
256
+ clip_skip: Optional[int] = None,
257
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
258
+ max_sequence_length: int = 256,
259
+ skip_layer_guidance_scale: float = 2.8,
260
+ noise_level: float = 0.7,
261
+ train_num_steps: int = 1,
262
+ process_index: int = 0,
263
+ sample_num_steps: int = 10,
264
+ random_timestep: Optional[int] = None,
265
+ ):
266
+ height = height or self.default_sample_size * self.vae_scale_factor
267
+ width = width or self.default_sample_size * self.vae_scale_factor
268
+ # import pdb; pdb.set_trace()
269
+
270
+ # 1. Check inputs. Raise error if not correct
271
+ self.check_inputs(
272
+ prompt,
273
+ prompt_2,
274
+ prompt_3,
275
+ height,
276
+ width,
277
+ negative_prompt=negative_prompt,
278
+ negative_prompt_2=negative_prompt_2,
279
+ negative_prompt_3=negative_prompt_3,
280
+ prompt_embeds=prompt_embeds,
281
+ negative_prompt_embeds=negative_prompt_embeds,
282
+ pooled_prompt_embeds=pooled_prompt_embeds,
283
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
284
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
285
+ max_sequence_length=max_sequence_length,
286
+ )
287
+ # import pdb; pdb.set_trace()
288
+
289
+ self._guidance_scale = guidance_scale
290
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
291
+ self._clip_skip = clip_skip
292
+ self._joint_attention_kwargs = joint_attention_kwargs
293
+ self._interrupt = False
294
+
295
+ # 2. Define call parameters
296
+ if prompt is not None and isinstance(prompt, str):
297
+ batch_size = 1
298
+ elif prompt is not None and isinstance(prompt, list):
299
+ batch_size = len(prompt)
300
+ else:
301
+ batch_size = prompt_embeds.shape[0]
302
+
303
+ device = self._execution_device
304
+
305
+ lora_scale = (
306
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
307
+ )
308
+ # import pdb; pdb.set_trace()
309
+ (
310
+ prompt_embeds,
311
+ negative_prompt_embeds,
312
+ pooled_prompt_embeds,
313
+ negative_pooled_prompt_embeds,
314
+ ) = self.encode_prompt(
315
+ prompt=prompt,
316
+ prompt_2=prompt_2,
317
+ prompt_3=prompt_3,
318
+ negative_prompt=negative_prompt,
319
+ negative_prompt_2=negative_prompt_2,
320
+ negative_prompt_3=negative_prompt_3,
321
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
322
+ prompt_embeds=prompt_embeds,
323
+ negative_prompt_embeds=negative_prompt_embeds,
324
+ pooled_prompt_embeds=pooled_prompt_embeds,
325
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
326
+ device=device,
327
+ clip_skip=self.clip_skip,
328
+ max_sequence_length=max_sequence_length,
329
+ lora_scale=lora_scale,
330
+ )
331
+ # import pdb; pdb.set_trace()
332
+
333
+ # 4. Prepare latent variables
334
+ num_channels_latents = self.transformer.config.in_channels
335
+ latents = self.prepare_latents(
336
+ batch_size,
337
+ num_channels_latents,
338
+ height,
339
+ width,
340
+ prompt_embeds.dtype,
341
+ device,
342
+ generator,
343
+ latents,
344
+ )
345
+ # import pdb; pdb.set_trace()
346
+ # latents = latents.to(prompt_embeds.dtype)
347
+
348
+ # 5. Prepare timesteps
349
+ scheduler_kwargs = {}
350
+ timesteps, num_inference_steps = retrieve_timesteps(
351
+ self.scheduler,
352
+ num_inference_steps,
353
+ device,
354
+ sigmas=sigmas,
355
+ **scheduler_kwargs,
356
+ )
357
+ # timesteps = timesteps.to(prompt_embeds.dtype)
358
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
359
+ self._num_timesteps = len(timesteps)
360
+
361
+ random.seed(process_index)
362
+ if random_timestep is None:
363
+ random_timestep = random.randint(0, sample_num_steps//2)
364
+
365
+
366
+ # 6. Prepare image embeddings
367
+ all_latents = []
368
+ all_log_probs = []
369
+ all_timesteps = []
370
+ # import pdb; pdb.set_trace()
371
+
372
+ if self.do_classifier_free_guidance:
373
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
374
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
375
+ # 7. Denoising loop
376
+ # import pdb; pdb.set_trace()
377
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
378
+ # import pdb; pdb.set_trace()
379
+ for i, t in enumerate(timesteps):
380
+ if i < random_timestep:
381
+ cur_noise_level = 0
382
+ elif i == random_timestep:
383
+ cur_noise_level= noise_level
384
+ # 将latents repeat mini_num_image_per_prompt次
385
+ latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
386
+ prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
387
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
388
+ negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
389
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
390
+ if self.do_classifier_free_guidance:
391
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
392
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
393
+ all_latents.append(latents)
394
+ elif i > random_timestep and i < random_timestep + train_num_steps:
395
+ cur_noise_level = noise_level
396
+ else:
397
+ cur_noise_level= 0
398
+ # expand the latents if we are doing classifier free guidance
399
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
400
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
401
+ timestep = t.expand(latent_model_input.shape[0])
402
+ # import pdb; pdb.set_trace()
403
+ # noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=tem_prompt_embeds,pooled_projections=tem_pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs,return_dict=False, )[0]
404
+ noise_pred = self.transformer(
405
+ hidden_states=latent_model_input,
406
+ timestep=timestep,
407
+ encoder_hidden_states=tem_prompt_embeds,
408
+ pooled_projections=tem_pooled_prompt_embeds,
409
+ joint_attention_kwargs=self.joint_attention_kwargs,
410
+ return_dict=False,
411
+ )[0]
412
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
413
+ # perform guidance
414
+ if self.do_classifier_free_guidance:
415
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
416
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
417
+
418
+ latents_dtype = latents.dtype
419
+
420
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
421
+ self.scheduler,
422
+ noise_pred.float(),
423
+ t.unsqueeze(0),
424
+ latents.float(),
425
+ noise_level=cur_noise_level,
426
+ )
427
+
428
+ if latents.dtype != latents_dtype:
429
+ latents = latents.to(latents_dtype)
430
+
431
+ if i >= random_timestep and i < random_timestep + train_num_steps:
432
+ all_latents.append(latents)
433
+ all_log_probs.append(log_prob)
434
+ all_timesteps.append(t.repeat(len(latents)))
435
+ # import pdb; pdb.set_trace()
436
+
437
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
438
+ progress_bar.update()
439
+
440
+
441
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
442
+ latents = latents.to(dtype=self.vae.dtype)
443
+ image = self.vae.decode(latents, return_dict=False)[0]
444
+ image = self.image_processor.postprocess(image, output_type=output_type)
445
+
446
+ # Offload all models
447
+ self.maybe_free_model_hooks()
448
+ return image, all_latents, all_log_probs, all_timesteps
449
+
450
+
451
+
452
+
453
+ @torch.no_grad()
454
+ def pipeline_with_logprob_random(
455
+ self,
456
+ prompt: Union[str, List[str]] = None,
457
+ prompt_2: Optional[Union[str, List[str]]] = None,
458
+ prompt_3: Optional[Union[str, List[str]]] = None,
459
+ height: Optional[int] = None,
460
+ width: Optional[int] = None,
461
+ num_inference_steps: int = 28,
462
+ mini_num_image_per_prompt: int = 1,
463
+ sigmas: Optional[List[float]] = None,
464
+ guidance_scale: float = 7.0,
465
+ negative_prompt: Optional[Union[str, List[str]]] = None,
466
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
467
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
468
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
469
+ latents: Optional[torch.FloatTensor] = None,
470
+ prompt_embeds: Optional[torch.FloatTensor] = None,
471
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
472
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
473
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
474
+ output_type: Optional[str] = "pil",
475
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
476
+ clip_skip: Optional[int] = None,
477
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
478
+ max_sequence_length: int = 256,
479
+ skip_layer_guidance_scale: float = 2.8,
480
+ noise_level: float = 0.7,
481
+ train_num_steps: int = 1,
482
+ process_index: int = 0,
483
+ sample_num_steps: int = 10,
484
+ random_timestep: Optional[int] = None,
485
+ ):
486
+ height = height or self.default_sample_size * self.vae_scale_factor
487
+ width = width or self.default_sample_size * self.vae_scale_factor
488
+ # import pdb; pdb.set_trace()
489
+
490
+ # 1. Check inputs. Raise error if not correct
491
+ self.check_inputs(
492
+ prompt,
493
+ prompt_2,
494
+ prompt_3,
495
+ height,
496
+ width,
497
+ negative_prompt=negative_prompt,
498
+ negative_prompt_2=negative_prompt_2,
499
+ negative_prompt_3=negative_prompt_3,
500
+ prompt_embeds=prompt_embeds,
501
+ negative_prompt_embeds=negative_prompt_embeds,
502
+ pooled_prompt_embeds=pooled_prompt_embeds,
503
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
504
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
505
+ max_sequence_length=max_sequence_length,
506
+ )
507
+ # import pdb; pdb.set_trace()
508
+
509
+ self._guidance_scale = guidance_scale
510
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
511
+ self._clip_skip = clip_skip
512
+ self._joint_attention_kwargs = joint_attention_kwargs
513
+ self._interrupt = False
514
+
515
+ # 2. Define call parameters
516
+ if prompt is not None and isinstance(prompt, str):
517
+ batch_size = 1
518
+ elif prompt is not None and isinstance(prompt, list):
519
+ batch_size = len(prompt)
520
+ else:
521
+ batch_size = prompt_embeds.shape[0]
522
+
523
+ device = self._execution_device
524
+
525
+ lora_scale = (
526
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
527
+ )
528
+ # import pdb; pdb.set_trace()
529
+ (
530
+ prompt_embeds,
531
+ negative_prompt_embeds,
532
+ pooled_prompt_embeds,
533
+ negative_pooled_prompt_embeds,
534
+ ) = self.encode_prompt(
535
+ prompt=prompt,
536
+ prompt_2=prompt_2,
537
+ prompt_3=prompt_3,
538
+ negative_prompt=negative_prompt,
539
+ negative_prompt_2=negative_prompt_2,
540
+ negative_prompt_3=negative_prompt_3,
541
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
542
+ prompt_embeds=prompt_embeds,
543
+ negative_prompt_embeds=negative_prompt_embeds,
544
+ pooled_prompt_embeds=pooled_prompt_embeds,
545
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
546
+ device=device,
547
+ clip_skip=self.clip_skip,
548
+ max_sequence_length=max_sequence_length,
549
+ lora_scale=lora_scale,
550
+ )
551
+ prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
552
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
553
+ negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
554
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
555
+ # import pdb; pdb.set_trace()
556
+
557
+ # 4. Prepare latent variables
558
+ num_channels_latents = self.transformer.config.in_channels
559
+ latents = self.prepare_latents(
560
+ prompt_embeds.shape[0],
561
+ num_channels_latents,
562
+ height,
563
+ width,
564
+ prompt_embeds.dtype,
565
+ device,
566
+ generator,
567
+ latents,
568
+ )
569
+ # import pdb; pdb.set_trace()
570
+ # latents = latents.to(prompt_embeds.dtype)
571
+
572
+ # 5. Prepare timesteps
573
+ scheduler_kwargs = {}
574
+ timesteps, num_inference_steps = retrieve_timesteps(
575
+ self.scheduler,
576
+ num_inference_steps,
577
+ device,
578
+ sigmas=sigmas,
579
+ **scheduler_kwargs,
580
+ )
581
+ # timesteps = timesteps.to(prompt_embeds.dtype)
582
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
583
+ self._num_timesteps = len(timesteps)
584
+
585
+ random.seed(process_index)
586
+ if random_timestep is None:
587
+ random_timestep = random.randint(0, sample_num_steps//2)
588
+
589
+
590
+ # 6. Prepare image embeddings
591
+ all_latents = []
592
+ all_log_probs = []
593
+ all_timesteps = []
594
+ if self.do_classifier_free_guidance:
595
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
596
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
597
+
601
+ # 7. Denoising loop
602
+ # import pdb; pdb.set_trace()
603
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
604
+ # import pdb; pdb.set_trace()
605
+ for i, t in enumerate(timesteps):
606
+ if i < random_timestep:
607
+ cur_noise_level = 0
608
+ elif i == random_timestep:
609
+ cur_noise_level= noise_level
610
+ # repeat the latents mini_num_image_per_prompt times
611
+ # latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
612
+ # prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
613
+ # pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
614
+ # negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
615
+ # negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
616
+ # if self.do_classifier_free_guidance:
617
+ # tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
618
+ # tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
619
+ all_latents.append(latents)
620
+ elif i > random_timestep and i < random_timestep + train_num_steps:
621
+ cur_noise_level = noise_level
622
+ else:
623
+ cur_noise_level= 0
624
+ # expand the latents if we are doing classifier free guidance
625
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
626
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
627
+ timestep = t.expand(latent_model_input.shape[0])
628
+ # import pdb; pdb.set_trace()
629
+ # noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=tem_prompt_embeds,pooled_projections=tem_pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs,return_dict=False, )[0]
630
+ noise_pred = self.transformer(
631
+ hidden_states=latent_model_input,
632
+ timestep=timestep,
633
+ encoder_hidden_states=tem_prompt_embeds,
634
+ pooled_projections=tem_pooled_prompt_embeds,
635
+ joint_attention_kwargs=self.joint_attention_kwargs,
636
+ return_dict=False,
637
+ )[0]
638
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
639
+ # perform guidance
640
+ if self.do_classifier_free_guidance:
641
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
642
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
643
+
644
+ latents_dtype = latents.dtype
645
+
646
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
647
+ self.scheduler,
648
+ noise_pred.float(),
649
+ t.unsqueeze(0),
650
+ latents.float(),
651
+ noise_level=cur_noise_level,
652
+ )
653
+
654
+ if latents.dtype != latents_dtype:
655
+ latents = latents.to(latents_dtype)
656
+
657
+ if i >= random_timestep and i < random_timestep + train_num_steps:
658
+ all_latents.append(latents)
659
+ all_log_probs.append(log_prob)
660
+ all_timesteps.append(t.repeat(len(latents)))
661
+ # import pdb; pdb.set_trace()
662
+
663
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
664
+ progress_bar.update()
665
+
666
+
667
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
668
+ latents = latents.to(dtype=self.vae.dtype)
669
+ image = self.vae.decode(latents, return_dict=False)[0]
670
+ image = self.image_processor.postprocess(image, output_type=output_type)
671
+
672
+ # Offload all models
673
+ self.maybe_free_model_hooks()
674
+ return image, all_latents, all_log_probs, all_timesteps
675
+
676
+
677
+
678
+ def move_scheduler_to_device(scheduler, device="cuda"):
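+ # Move every tensor attribute of the scheduler (e.g. sigmas, timesteps) onto the target device in place.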
679
+ for attr_name in dir(scheduler):
680
+ attr = getattr(scheduler, attr_name)
681
+ if isinstance(attr, torch.Tensor):
682
+ setattr(scheduler, attr_name, attr.to(device))
683
+ return scheduler
684
+
685
+
686
+ def image_to_latent(pipe, images: Union[Image.Image, List[Image.Image]], device="cuda"):
687
+ # Normalize the input to a list
688
+ if isinstance(images, Image.Image):
689
+ images = [images]
690
+
691
+ preprocess = transforms.Compose([
692
+ transforms.Resize((512, 512)),
693
+ transforms.ToTensor(), # to [0, 1]
694
+ transforms.Normalize([0.5], [0.5]) # map to [-1, 1]
695
+ ])
696
+
697
+ # Preprocess as a batch
698
+ img_tensors = [preprocess(img) for img in images] # list of [3,512,512]
699
+ img_tensor = torch.stack(img_tensors, dim=0).to(device, dtype=torch.float32) # [B,3,512,512]
700
+ # import pdb; pdb.set_trace()
701
+
702
+ # Encode through the VAE
703
+ latent = pipe.vae.encode(img_tensor).latent_dist.sample()
704
+ latent = latent * pipe.vae.config.scaling_factor
705
+ return latent.to(torch.bfloat16) # [B, C, 64, 64] for a 512x512 input (spatially downsampled 8x)
706
+
707
+
708
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
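+ # Look up the sigma that corresponds to each timestep in the scheduler's schedule and reshape it to n_dim for broadcasting against the latents.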
709
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
710
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
711
+ timesteps = timesteps.to(device)
712
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
713
+
714
+ sigma = sigmas[step_indices].flatten()
715
+ while len(sigma.shape) < n_dim:
716
+ sigma = sigma.unsqueeze(-1)
717
+ return sigma
718
+
719
+
720
+
721
+ @torch.no_grad()
722
+ def flux_to_sd3_denoise(
723
+ self,
724
+ prompt: Union[str, List[str]] = None,
725
+ prompt_2: Optional[Union[str, List[str]]] = None,
726
+ prompt_3: Optional[Union[str, List[str]]] = None,
727
+ flux_images=None,
728
+ device="cuda",
729
+ output_type: Optional[str] = "pil",
730
+ num_inference_steps: int = 20,
731
+ guidance_scale: float = 7.0,
732
+ negative_prompt: Optional[Union[str, List[str]]] = None,
733
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
734
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
735
+ prompt_embeds: Optional[torch.FloatTensor] = None,
736
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
737
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
738
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
739
+ max_sequence_length: int = 256,
740
+ noise_level: float = 0.7,
741
+ random_timestep: Optional[int] = None,
742
+ noise_timestep_ratio: float = 0.4,
743
+ clip_skip: Optional[int] = None,
744
+ ):
745
+ """
746
+ Take images generated by Flux, encode them to latents, add noise, then run multi-step denoising with SD3.
747
+ Outputs match pipeline_with_logprob: image, all_latents, all_log_probs, all_timesteps.
748
+ """
749
+ # 1. Convert the Flux images to latents
750
+ flux_latent = image_to_latent(self, flux_images, device)
751
+ self._guidance_scale = guidance_scale
752
+ self._clip_skip = clip_skip
753
+
754
+ # 2. Prepare the scheduler
755
+ noise_scheduler = self.scheduler
756
+ noise_scheduler.set_timesteps(num_inference_steps)
757
+ timesteps = noise_scheduler.timesteps.to(device)
758
+
759
+ # target_idx = torch.tensor([int(noise_timestep_ratio * (len(timesteps) - 1))], device=device)
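+ # NOTE: noise_timestep_ratio is used directly as an integer index into timesteps here (the ratio-based indexing above is commented out).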
760
+ target_idx = torch.tensor([noise_timestep_ratio], device=device)
761
+ t = timesteps[target_idx].to(device)
762
+
763
+ noise = torch.randn_like(flux_latent)
764
+ sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
765
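+ # Flow-matching forward process: interpolate between the clean Flux latent and Gaussian noise at level sigma.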
+ latents = (1.0 - sigmas) * flux_latent + sigmas * noise
766
+ num_channels_latents = self.transformer.config.in_channels
767
+ if prompt is not None and isinstance(prompt, str):
768
+ batch_size = 1
769
+ elif prompt is not None and isinstance(prompt, list):
770
+ batch_size = len(prompt)
771
+ else:
772
+ batch_size = prompt_embeds.shape[0]
773
+
774
+ # latents = self.prepare_latents(
775
+ # batch_size,
776
+ # num_channels_latents,
777
+ # 512,
778
+ # 512,
779
+ # prompt_embeds.dtype,
780
+ # device,
781
+ # None,
782
+ # None,
783
+ # )
784
+
785
+
786
+
787
+ # import pdb; pdb.set_trace()
788
+
789
+ # noisy_latent_vis = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
790
+ # noisy_latent_vis = noisy_latent_vis.to(dtype=self.vae.dtype)
791
+
792
+ # noisy_image = self.vae.decode(noisy_latent_vis, return_dict=False)[0]
793
+ # noisy_image = self.image_processor.postprocess(noisy_image, output_type="pil")[0]
794
+
795
+ # Save locally (debug)
796
+ # noisy_image.save("noisy_image.png")
797
+ # import pdb; pdb.set_trace()
798
+
799
+ # 4. Encode prompts (mirrors the pipeline_with_logprob handling)
800
+ # lora_scale = (
801
+ # self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
802
+ # )
803
+ lora_scale = None
804
+ (
805
+ prompt_embeds,
806
+ negative_prompt_embeds,
807
+ pooled_prompt_embeds,
808
+ negative_pooled_prompt_embeds,
809
+ ) = self.encode_prompt(
810
+ prompt=prompt,
811
+ prompt_2=prompt_2,
812
+ prompt_3=prompt_3,
813
+ negative_prompt=negative_prompt,
814
+ negative_prompt_2=negative_prompt_2,
815
+ negative_prompt_3=negative_prompt_3,
816
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
817
+ prompt_embeds=prompt_embeds,
818
+ negative_prompt_embeds=negative_prompt_embeds,
819
+ pooled_prompt_embeds=pooled_prompt_embeds,
820
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
821
+ device=device,
822
+ clip_skip=self.clip_skip,
823
+ max_sequence_length=max_sequence_length,
824
+ lora_scale=lora_scale,
825
+ )
826
+ # import pdb; pdb.set_trace()
827
+
828
+
829
+ prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1)
830
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1)
831
+ negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1)
832
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1)
833
+
834
+
835
+ if self.do_classifier_free_guidance:
836
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
837
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
838
+ else:
839
+ tem_prompt_embeds = prompt_embeds
840
+ tem_pooled_prompt_embeds = pooled_prompt_embeds
841
+
842
+ # 5. Denoise starting from the current timestep t
843
+ noise_scheduler.set_timesteps(num_inference_steps)
844
+ timesteps = noise_scheduler.timesteps.to(device)
845
+ start_idx = (timesteps >= t[0]).nonzero()[-1].item()
846
+ timesteps = timesteps[start_idx:]
847
+
848
+ all_latents, all_log_probs, all_timesteps = [], [], []
849
+ noise_scheduler = move_scheduler_to_device(noise_scheduler, device)
850
+
851
+ for index, t_cur in enumerate(timesteps):
852
+ # import pdb; pdb.set_trace()
853
+ if index==0:
854
+ all_latents.append(latents)
855
+
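+ # The first two denoising steps form the trainable window: their latents, log-probs, and timesteps are recorded below.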
856
+ if index<2:
857
+ cur_noise_level = noise_level
858
+ else:
859
+ cur_noise_level = 0.0
860
+
861
+ latent_model_input = (
862
+ torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
863
+ )
864
+ t_input = t_cur.expand(latent_model_input.shape[0]).to(device)
865
+
866
+ latents_dtype = latents.dtype
867
+ model_pred = self.transformer(
868
+ hidden_states=latent_model_input,
869
+ timestep=t_input,
870
+ encoder_hidden_states=tem_prompt_embeds,
871
+ pooled_projections=tem_pooled_prompt_embeds,
872
+ return_dict=False,
873
+ )[0]
874
+
875
+ if self.do_classifier_free_guidance:
876
+ noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
877
+ model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
878
+ # import pdb; pdb.set_trace()
879
+
880
+
881
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
882
+ noise_scheduler,
883
+ model_pred.float(),
884
+ t_cur.repeat(len(latents)),
885
+ latents.float(),
886
+ noise_level=cur_noise_level,  # use the per-step noise level computed above (cur_noise_level was otherwise unused)
887
+ )
888
+ if latents.dtype != latents_dtype:
889
+ latents = latents.to(latents_dtype)
890
+
891
+ if index>=0 and index<2:
892
+ # if index<2:
893
+ # print(model_pred)
894
+ all_latents.append(latents)
895
+ all_log_probs.append(log_prob)
896
+ all_timesteps.append(t_cur.repeat(len(latents)))
897
+ # import pdb; pdb.set_trace()
898
+
899
+ # 6. Final decode
900
+ denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
901
+ denoised_latent = denoised_latent.to(dtype=self.vae.dtype)
902
+
903
+ image = self.vae.decode(denoised_latent, return_dict=False)[0]
904
+ # reconstructd_image = self.image_processor.postprocess(image, output_type="pil")[0]
905
+ image = self.image_processor.postprocess(image, output_type=output_type)
906
+
907
+ return image, all_latents, all_log_probs, all_timesteps
908
+
909
+
910
+
911
+
912
+
913
+ @torch.no_grad()
914
+ def flux_to_sd3_denoise_random(
915
+ self,
916
+ prompt: Union[str, List[str]] = None,
917
+ prompt_2: Optional[Union[str, List[str]]] = None,
918
+ prompt_3: Optional[Union[str, List[str]]] = None,
919
+ flux_images=None,
920
+ device="cuda",
921
+ output_type: Optional[str] = "pil",
922
+ num_inference_steps: int = 20,
923
+ guidance_scale: float = 7.0,
924
+ negative_prompt: Optional[Union[str, List[str]]] = None,
925
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
926
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
927
+ prompt_embeds: Optional[torch.FloatTensor] = None,
928
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
929
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
930
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
931
+ max_sequence_length: int = 256,
932
+ noise_level: float = 0.7,
933
+ random_timestep: Optional[int] = None,
934
+ noise_timestep_ratio: float = 0.4,
935
+ clip_skip: Optional[int] = None,
936
+ ):
937
+ """
938
+ Take images generated by Flux, encode them to latents, add noise at a randomly chosen timestep, then run multi-step denoising with SD3.
939
+ Outputs match pipeline_with_logprob: image, all_latents, all_log_probs, all_timesteps.
940
+ """
941
+ # 1. Convert the Flux images to latents
942
+ flux_latent = image_to_latent(self, flux_images, device)
943
+ self._guidance_scale = guidance_scale
944
+ self._clip_skip = clip_skip
945
+
946
+ # 2. Prepare the scheduler
947
+ noise_scheduler = self.scheduler
948
+ noise_scheduler.set_timesteps(num_inference_steps)
949
+ timesteps = noise_scheduler.timesteps.to(device)
950
+
951
+ # target_idx = torch.tensor([int(noise_timestep_ratio * (len(timesteps) - 1))], device=device)
952
+ # t = timesteps[target_idx].to(device)
953
+
954
+ # noise = torch.randn_like(flux_latent)
955
+ # sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
956
+ # latents = (1.0 - sigmas) * flux_latent + sigmas * noise
957
+
958
+ target_idx = torch.tensor([random.randint(5, 10)], device=device)
959
+ t = timesteps[target_idx].to(device)
960
+ # Sample standard Gaussian noise
961
+ noise = torch.randn_like(flux_latent)
962
+ # Look up the corresponding sigma
963
+ sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
964
+ # Add noise to the latent
965
+ latents = (1.0 - sigmas) * flux_latent + sigmas * noise
966
+
967
+ # import pdb; pdb.set_trace()
968
+
969
+ # noisy_latent_vis = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
970
+ # noisy_latent_vis = noisy_latent_vis.to(dtype=self.vae.dtype)
971
+
972
+ # noisy_image = self.vae.decode(noisy_latent_vis, return_dict=False)[0]
973
+ # noisy_image = self.image_processor.postprocess(noisy_image, output_type="pil")[0]
974
+
975
+ # Save locally (debug)
976
+ # noisy_image.save("noisy_image.png")
977
+ # import pdb; pdb.set_trace()
978
+
979
+ # 4. Encode prompts (mirrors the pipeline_with_logprob handling)
980
+ # lora_scale = (
981
+ # self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
982
+ # )
983
+ lora_scale = None
984
+ (
985
+ prompt_embeds,
986
+ negative_prompt_embeds,
987
+ pooled_prompt_embeds,
988
+ negative_pooled_prompt_embeds,
989
+ ) = self.encode_prompt(
990
+ prompt=prompt,
991
+ prompt_2=prompt_2,
992
+ prompt_3=prompt_3,
993
+ negative_prompt=negative_prompt,
994
+ negative_prompt_2=negative_prompt_2,
995
+ negative_prompt_3=negative_prompt_3,
996
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
997
+ prompt_embeds=prompt_embeds,
998
+ negative_prompt_embeds=negative_prompt_embeds,
999
+ pooled_prompt_embeds=pooled_prompt_embeds,
1000
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1001
+ device=device,
1002
+ clip_skip=self.clip_skip,
1003
+ max_sequence_length=max_sequence_length,
1004
+ lora_scale=lora_scale,
1005
+ )
1006
+ # import pdb; pdb.set_trace()
1007
+
1008
+
1009
+ prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1)
1010
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1)
1011
+ negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1)
1012
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1)
1013
+
1014
+
1015
+ if self.do_classifier_free_guidance:
1016
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1017
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1018
+ else:
1019
+ tem_prompt_embeds = prompt_embeds
1020
+ tem_pooled_prompt_embeds = pooled_prompt_embeds
1021
+
1022
+ # 5. Denoise starting from the current timestep t
1023
+ noise_scheduler.set_timesteps(num_inference_steps)
1024
+ timesteps = noise_scheduler.timesteps.to(device)
1025
+ start_idx = (timesteps >= t[0]).nonzero()[-1].item()
1026
+ timesteps = timesteps[start_idx:]
1027
+
1028
+ all_latents, all_log_probs, all_timesteps = [], [], []
1029
+ noise_scheduler = move_scheduler_to_device(noise_scheduler, device)
1030
+
1031
+ for index, t_cur in enumerate(timesteps):
1032
+ if index==0:
1033
+ all_latents.append(latents)
1034
+
1035
+ latent_model_input = (
1036
+ torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1037
+ )
1038
+ t_input = t_cur.expand(latent_model_input.shape[0]).to(device)
1039
+
1040
+ latents_dtype = latents.dtype
1041
+ model_pred = self.transformer(
1042
+ hidden_states=latent_model_input,
1043
+ timestep=t_input,
1044
+ encoder_hidden_states=tem_prompt_embeds,
1045
+ pooled_projections=tem_pooled_prompt_embeds,
1046
+ return_dict=False,
1047
+ )[0]
1048
+
1049
+ if self.do_classifier_free_guidance:
1050
+ noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
1051
+ model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1052
+ # import pdb; pdb.set_trace()
1053
+
1054
+
1055
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
1056
+ noise_scheduler,
1057
+ model_pred.float(),
1058
+ t_cur.repeat(len(latents)),
1059
+ latents.float(),
1060
+ noise_level=noise_level,
1061
+ )
1062
+ if latents.dtype != latents_dtype:
1063
+ latents = latents.to(latents_dtype)
1064
+
1065
+ # if index>=2 and index<4:
1066
+ if index<2:
1067
+ # print(model_pred)
1068
+ all_latents.append(latents)
1069
+ all_log_probs.append(log_prob)
1070
+ all_timesteps.append(t_cur.repeat(len(latents)))
1071
+ # import pdb; pdb.set_trace()
1072
+
1073
+ # 6. Final decode
1074
+ denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1075
+ denoised_latent = denoised_latent.to(dtype=self.vae.dtype)
1076
+
1077
+ image = self.vae.decode(denoised_latent, return_dict=False)[0]
1078
+ # reconstructd_image = self.image_processor.postprocess(image, output_type="pil")[0]
1079
+ image = self.image_processor.postprocess(image, output_type=output_type)
1080
+
1081
+ return image, all_latents, all_log_probs, all_timesteps
adv_grpo/diffusers_patch/sd3_sde_with_logprob.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
2
+ # We adapt it from DDIM to flow matching.
3
+
4
+ import math
5
+ from typing import Optional, Union
6
+ import torch
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
10
+
11
+
12
+
13
+ def sde_step_with_logprob(
14
+ self: FlowMatchEulerDiscreteScheduler,
15
+ model_output: torch.FloatTensor,
16
+ timestep: Union[float, torch.FloatTensor],
17
+ sample: torch.FloatTensor,
18
+ noise_level: float = 0.7,
19
+ prev_sample: Optional[torch.FloatTensor] = None,
20
+ generator: Optional[torch.Generator] = None,
21
+ ):
22
+ """
23
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
24
+ process from the learned model outputs (most often the predicted velocity).
25
+
26
+ Args:
27
+ model_output (`torch.FloatTensor`):
28
+ The direct output from learned flow model.
29
+ timestep (`float`):
30
+ The current discrete timestep in the diffusion chain.
31
+ sample (`torch.FloatTensor`):
32
+ A current instance of a sample created by the diffusion process.
33
+ generator (`torch.Generator`, *optional*):
34
+ A random number generator.
35
+ """
36
+ # bf16 can overflow here when computing prev_sample_mean, so we must convert all variables to fp32
37
+ model_output=model_output.float()
38
+ sample=sample.float()
39
+ if prev_sample is not None:
40
+ prev_sample=prev_sample.float()
41
+
42
+ step_index = [self.index_for_timestep(t) for t in timestep]
43
+ prev_step_index = [step+1 for step in step_index]
44
+ sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
45
+ sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
46
+ sigma_max = self.sigmas[1].item()
47
+ dt = sigma_prev - sigma
48
+
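+ # Diffusion coefficient of the reverse-time SDE, scaled by noise_level; with noise_level = 0 the update reduces to the deterministic Euler step.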
49
+ std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level
50
+ # import pdb; pdb.set_trace()
51
+
52
+ # our sde
53
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
54
+
55
+ if prev_sample is None:
56
+ variance_noise = randn_tensor(
57
+ model_output.shape,
58
+ generator=generator,
59
+ device=model_output.device,
60
+ dtype=model_output.dtype,
61
+ )
62
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
63
+
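+ # Gaussian log-density of prev_sample under N(prev_sample_mean, (std_dev_t * sqrt(-dt))^2).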
64
+ log_prob = (
65
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
66
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
67
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
68
+ )
69
+
70
+ # mean along all but batch dimension
71
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
72
+
73
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
74
+
75
+
76
+
77
+ def sde_step_with_logprob_new(
78
+ self: FlowMatchEulerDiscreteScheduler,
79
+ model_output: torch.FloatTensor,
80
+ timestep: Union[float, torch.FloatTensor],
81
+ sample: torch.FloatTensor,
82
+ noise_level: float = 0.7,
83
+ prev_sample: Optional[torch.FloatTensor] = None,
84
+ generator: Optional[torch.Generator] = None,
85
+ ):
86
+ """
87
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
88
+ process from the learned model outputs (most often the predicted velocity).
89
+
90
+ Args:
91
+ model_output (`torch.FloatTensor`):
92
+ The direct output from learned flow model.
93
+ timestep (`float`):
94
+ The current discrete timestep in the diffusion chain.
95
+ sample (`torch.FloatTensor`):
96
+ A current instance of a sample created by the diffusion process.
97
+ generator (`torch.Generator`, *optional*):
98
+ A random number generator.
99
+ """
100
+ # bf16 can overflow here when computing prev_sample_mean, so we must convert all variables to fp32
101
+ model_output=model_output.float()
102
+ sample=sample.float()
103
+ if prev_sample is not None:
104
+ prev_sample=prev_sample.float()
105
+
106
+ step_index = [self.index_for_timestep(t) for t in timestep]
107
+ prev_step_index = [step+1 for step in step_index]
108
+ sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
109
+ sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
110
+ sigma_max = self.sigmas[1].item()
111
+ dt = sigma_prev - sigma
112
+
113
+ # Flow-SDE
114
+ #std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level * torch.sqrt(-1*dt)
115
+ # prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
116
+
117
+ # Flow-CPS
118
+ std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2) # sigma_t in paper
119
+ pred_original_sample = sample - sigma * model_output # predicted x_0 in paper
120
+ noise_estimate = sample + model_output * (1 - sigma) # predicted x_1 in paper
121
+ prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2)
122
+ # import pdb; pdb.set_trace()
123
+
124
+ if prev_sample is None:
125
+ variance_noise = randn_tensor(
126
+ model_output.shape,
127
+ generator=generator,
128
+ device=model_output.device,
129
+ dtype=model_output.dtype,
130
+ )
131
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
132
+
133
+ # remove all constants
134
+ log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
135
+
136
+ # mean along all but batch dimension
137
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
138
+
139
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py ADDED
@@ -0,0 +1,144 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import torch
17
+
18
+
19
+ def _encode_prompt_with_t5(
20
+ text_encoder,
21
+ tokenizer,
22
+ max_sequence_length=512,
23
+ prompt=None,
24
+ num_images_per_prompt=1,
25
+ device=None,
26
+ text_input_ids=None,
27
+ ):
28
+ prompt = [prompt] if isinstance(prompt, str) else prompt
29
+ batch_size = len(prompt)
30
+
31
+ if tokenizer is not None:
32
+ text_inputs = tokenizer(
33
+ prompt,
34
+ padding="max_length",
35
+ max_length=max_sequence_length,
36
+ truncation=True,
37
+ return_length=False,
38
+ return_overflowing_tokens=False,
39
+ return_tensors="pt",
40
+ )
41
+ text_input_ids = text_inputs.input_ids
42
+ else:
43
+ if text_input_ids is None:
44
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
45
+
46
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
47
+
48
+ if hasattr(text_encoder, "module"):
49
+ dtype = text_encoder.module.dtype
50
+ else:
51
+ dtype = text_encoder.dtype
52
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
53
+
54
+ _, seq_len, _ = prompt_embeds.shape
55
+
56
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
57
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
58
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
59
+
60
+ return prompt_embeds
61
+
62
+
63
+ def _encode_prompt_with_clip(
64
+ text_encoder,
65
+ tokenizer,
66
+ prompt: str,
67
+ device=None,
68
+ text_input_ids=None,
69
+ num_images_per_prompt: int = 1,
70
+ ):
71
+ prompt = [prompt] if isinstance(prompt, str) else prompt
72
+ batch_size = len(prompt)
73
+
74
+ if tokenizer is not None:
75
+ text_inputs = tokenizer(
76
+ prompt,
77
+ padding="max_length",
78
+ max_length=77,
79
+ truncation=True,
80
+ return_overflowing_tokens=False,
81
+ return_length=False,
82
+ return_tensors="pt",
83
+ )
84
+
85
+ text_input_ids = text_inputs.input_ids
86
+ else:
87
+ if text_input_ids is None:
88
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
89
+
90
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
91
+
92
+ if hasattr(text_encoder, "module"):
93
+ dtype = text_encoder.module.dtype
94
+ else:
95
+ dtype = text_encoder.dtype
96
+ # Use pooled output of CLIPTextModel
97
+ prompt_embeds = prompt_embeds.pooler_output
98
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
99
+
100
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
101
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
102
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
103
+
104
+ return prompt_embeds
105
+
106
+
107
+ def encode_prompt(
108
+ text_encoders,
109
+ tokenizers,
110
+ prompt: str,
111
+ max_sequence_length,
112
+ device=None,
113
+ num_images_per_prompt: int = 1,
114
+ text_input_ids_list=None,
115
+ ):
116
+ prompt = [prompt] if isinstance(prompt, str) else prompt
117
+
118
+ if hasattr(text_encoders[0], "module"):
119
+ dtype = text_encoders[0].module.dtype
120
+ else:
121
+ dtype = text_encoders[0].dtype
122
+
123
+ pooled_prompt_embeds = _encode_prompt_with_clip(
124
+ text_encoder=text_encoders[0],
125
+ tokenizer=tokenizers[0],
126
+ prompt=prompt,
127
+ device=device if device is not None else text_encoders[0].device,
128
+ num_images_per_prompt=num_images_per_prompt,
129
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
130
+ )
131
+
132
+ prompt_embeds = _encode_prompt_with_t5(
133
+ text_encoder=text_encoders[1],
134
+ tokenizer=tokenizers[1],
135
+ max_sequence_length=max_sequence_length,
136
+ prompt=prompt,
137
+ num_images_per_prompt=num_images_per_prompt,
138
+ device=device if device is not None else text_encoders[1].device,
139
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
140
+ )
141
+
142
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
143
+
144
+ return prompt_embeds, pooled_prompt_embeds, text_ids
adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py ADDED
@@ -0,0 +1,144 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import torch
17
+
18
+
19
+ def _encode_prompt_with_t5(
20
+ text_encoder,
21
+ tokenizer,
22
+ max_sequence_length,
23
+ prompt=None,
24
+ num_images_per_prompt=1,
25
+ device=None,
26
+ text_input_ids=None,
27
+ ):
28
+ prompt = [prompt] if isinstance(prompt, str) else prompt
29
+ batch_size = len(prompt)
30
+
31
+ if tokenizer is not None:
32
+ text_inputs = tokenizer(
33
+ prompt,
34
+ padding="max_length",
35
+ max_length=max_sequence_length,
36
+ truncation=True,
37
+ add_special_tokens=True,
38
+ return_tensors="pt",
39
+ )
40
+ text_input_ids = text_inputs.input_ids
41
+ else:
42
+ if text_input_ids is None:
43
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
44
+
45
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
46
+
47
+ dtype = text_encoder.dtype
48
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
49
+
50
+ _, seq_len, _ = prompt_embeds.shape
51
+
52
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
53
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
54
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
55
+
56
+ return prompt_embeds
57
+
58
+
59
+ def _encode_prompt_with_clip(
60
+ text_encoder,
61
+ tokenizer,
62
+ prompt: str,
63
+ device=None,
64
+ text_input_ids=None,
65
+ num_images_per_prompt: int = 1,
66
+ ):
67
+ prompt = [prompt] if isinstance(prompt, str) else prompt
68
+ batch_size = len(prompt)
69
+
70
+ if tokenizer is not None:
71
+ text_inputs = tokenizer(
72
+ prompt,
73
+ padding="max_length",
74
+ max_length=77,
75
+ truncation=True,
76
+ return_tensors="pt",
77
+ )
78
+
79
+ text_input_ids = text_inputs.input_ids
80
+ else:
81
+ if text_input_ids is None:
82
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
83
+
84
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
85
+
86
+ pooled_prompt_embeds = prompt_embeds[0]
87
+ prompt_embeds = prompt_embeds.hidden_states[-2]
88
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
89
+
90
+ _, seq_len, _ = prompt_embeds.shape
91
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
92
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
93
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
94
+
95
+ return prompt_embeds, pooled_prompt_embeds
96
+
97
+
98
+ def encode_prompt(
99
+ text_encoders,
100
+ tokenizers,
101
+ prompt: str,
102
+ max_sequence_length,
103
+ device=None,
104
+ num_images_per_prompt: int = 1,
105
+ text_input_ids_list=None,
106
+ ):
107
+ prompt = [prompt] if isinstance(prompt, str) else prompt
108
+
109
+ clip_tokenizers = tokenizers[:2]
110
+ clip_text_encoders = text_encoders[:2]
111
+
112
+ clip_prompt_embeds_list = []
113
+ clip_pooled_prompt_embeds_list = []
114
+ for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
115
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
116
+ text_encoder=text_encoder,
117
+ tokenizer=tokenizer,
118
+ prompt=prompt,
119
+ device=device if device is not None else text_encoder.device,
120
+ num_images_per_prompt=num_images_per_prompt,
121
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
122
+ )
123
+ clip_prompt_embeds_list.append(prompt_embeds)
124
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
125
+
126
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
127
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
128
+
129
+ t5_prompt_embed = _encode_prompt_with_t5(
130
+ text_encoders[-1],
131
+ tokenizers[-1],
132
+ max_sequence_length,
133
+ prompt=prompt,
134
+ num_images_per_prompt=num_images_per_prompt,
135
+ text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
136
+ device=device if device is not None else text_encoders[-1].device,
137
+ )
138
+
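+ # Zero-pad the concatenated CLIP embeddings to the T5 hidden width, then stack CLIP and T5 tokens along the sequence dimension (SD3 joint text conditioning).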
139
+ clip_prompt_embeds = torch.nn.functional.pad(
140
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
141
+ )
142
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
143
+
144
+ return prompt_embeds, pooled_prompt_embeds
adv_grpo/diffusers_patch/wan_pipeline_with_logprob.py ADDED
@@ -0,0 +1,373 @@
1
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
2
+ import torch
3
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
4
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+ import math
7
+ import numpy as np
8
+ # import logger
+ # WanPipelineOutput is constructed in the return below; this import path assumes a recent diffusers layout.
+ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
9
+
10
+ def sde_step_with_logprob(
11
+ self: UniPCMultistepScheduler,
12
+ model_output: torch.FloatTensor,
13
+ timestep: Union[float, torch.FloatTensor],
14
+ sample: torch.FloatTensor,
15
+ prev_sample: Optional[torch.FloatTensor] = None,
16
+ generator: Optional[torch.Generator] = None,
17
+ determistic: bool = False,
18
+ return_pixel_log_prob: bool = False,
19
+ return_dt_and_std_dev_t: bool = False
20
+ ):
21
+ """
22
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
23
+ process from the learned model outputs (most often the predicted velocity).
24
+
25
+ Args:
26
+ model_output (`torch.FloatTensor`):
27
+ The direct output from learned flow model.
28
+ timestep (`float`):
29
+ The current discrete timestep in the diffusion chain.
30
+ sample (`torch.FloatTensor`):
31
+ A current instance of a sample created by the diffusion process.
32
+ generator (`torch.Generator`, *optional*):
33
+ A random number generator.
34
+ """
35
+ # bf16 can overflow when computing prev_sample_mean, so we must convert all variables to fp32
36
+ model_output=model_output.float()
37
+ sample=sample.float()
38
+ if prev_sample is not None:
39
+ prev_sample=prev_sample.float()
40
+
41
+ step_index = [self.index_for_timestep(t) for t in timestep]
42
+ prev_step_index = [step+1 for step in step_index]
43
+
44
+ self.sigmas = self.sigmas.to(sample.device)
45
+ sigma = self.sigmas[step_index].view(-1, 1, 1, 1, 1)
46
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1, 1)
47
+ sigma_max = self.sigmas[1].item()
48
+ sigma_min = self.sigmas[-1].item()
49
+ dt = sigma_prev - sigma
50
+
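+ # Noise scale for this SDE step: interpolates linearly between sigma_min and sigma_max as a function of the current sigma.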
51
+ std_dev_t = sigma_min + (sigma_max - sigma_min) * sigma
52
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
53
+
54
+ if prev_sample is not None and generator is not None:
55
+ raise ValueError(
56
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
57
+ " `prev_sample` stays `None`."
58
+ )
59
+
60
+ if prev_sample is None:
61
+ variance_noise = randn_tensor(
62
+ model_output.shape,
63
+ generator=generator,
64
+ device=model_output.device,
65
+ dtype=model_output.dtype,
66
+ )
67
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
68
+
69
+ # No noise is added during evaluation
70
+ if determistic:
71
+ prev_sample = sample + dt * model_output
72
+
73
+ log_prob = (
74
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
75
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
76
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
77
+ )
78
+
79
+ # mean along all but batch dimension
80
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
81
+
82
+ if return_dt_and_std_dev_t:
83
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t, torch.sqrt(-1*dt)
84
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
85
+
86
+ def wan_pipeline_with_logprob(
87
+ self,
88
+ prompt: Union[str, List[str]] = None,
89
+ negative_prompt: Union[str, List[str]] = None,
90
+ height: int = 480,
91
+ width: int = 832,
92
+ num_frames: int = 81,
93
+ num_inference_steps: int = 50,
94
+ guidance_scale: float = 5.0,
95
+ num_videos_per_prompt: Optional[int] = 1,
96
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
97
+ latents: Optional[torch.Tensor] = None,
98
+ prompt_embeds: Optional[torch.Tensor] = None,
99
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
100
+ output_type: Optional[str] = "np",
101
+ return_dict: bool = True,
102
+ attention_kwargs: Optional[Dict[str, Any]] = None,
103
+ callback_on_step_end: Optional[
104
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
105
+ ] = None,
106
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
107
+ max_sequence_length: int = 512,
108
+ determistic: bool = False,
109
+ kl_reward: float = 0.0,
110
+ return_pixel_log_prob: bool = False,
111
+ ):
112
+ r"""
113
+ The call function to the pipeline for generation.
114
+
115
+ Args:
116
+ prompt (`str` or `List[str]`, *optional*):
117
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
118
+ instead.
119
+ height (`int`, defaults to `480`):
120
+ The height in pixels of the generated image.
121
+ width (`int`, defaults to `832`):
122
+ The width in pixels of the generated image.
123
+ num_frames (`int`, defaults to `81`):
124
+ The number of frames in the generated video.
125
+ num_inference_steps (`int`, defaults to `50`):
126
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
127
+ expense of slower inference.
128
+ guidance_scale (`float`, defaults to `5.0`):
129
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
130
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
131
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
132
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
133
+ usually at the expense of lower image quality.
134
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
135
+ The number of images to generate per prompt.
136
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
137
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
138
+ generation deterministic.
139
+ latents (`torch.Tensor`, *optional*):
140
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
141
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
142
+ tensor is generated by sampling using the supplied random `generator`.
143
+ prompt_embeds (`torch.Tensor`, *optional*):
144
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
145
+ provided, text embeddings are generated from the `prompt` input argument.
146
+ output_type (`str`, *optional*, defaults to `"pil"`):
147
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
148
+ return_dict (`bool`, *optional*, defaults to `True`):
149
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
150
+ attention_kwargs (`dict`, *optional*):
151
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
152
+ `self.processor` in
153
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
154
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
155
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
156
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
157
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
158
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
159
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
160
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
161
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
162
+ `._callback_tensor_inputs` attribute of your pipeline class.
163
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
164
+ The dtype to use for the torch.amp.autocast.
165
+
166
+ Examples:
167
+
168
+ Returns:
169
+ [`~WanPipelineOutput`] or `tuple`:
170
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
171
+ the first element is the generated video frames and the remaining elements are the collected
172
+ latents, per-step log probabilities, and per-step KL terms.
173
+ """
174
+
175
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
176
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
177
+
178
+ # 1. Check inputs. Raise error if not correct
179
+ self.check_inputs(
180
+ prompt,
181
+ negative_prompt,
182
+ height,
183
+ width,
184
+ prompt_embeds,
185
+ negative_prompt_embeds,
186
+ callback_on_step_end_tensor_inputs,
187
+ )
188
+
189
+ if num_frames % self.vae_scale_factor_temporal != 1:
190
+ print(
191
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
192
+ )
193
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
194
+ num_frames = max(num_frames, 1)
195
+
196
+ self._guidance_scale = guidance_scale
197
+ self._attention_kwargs = attention_kwargs
198
+ self._current_timestep = None
199
+ self._interrupt = False
200
+
201
+ device = self._execution_device
202
+
203
+ # 2. Define call parameters
204
+ if prompt is not None and isinstance(prompt, str):
205
+ batch_size = 1
206
+ elif prompt is not None and isinstance(prompt, list):
207
+ batch_size = len(prompt)
208
+ else:
209
+ batch_size = prompt_embeds.shape[0]
210
+
211
+ # 3. Encode input prompt
212
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
213
+ prompt=prompt,
214
+ negative_prompt=negative_prompt,
215
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
216
+ num_videos_per_prompt=num_videos_per_prompt,
217
+ prompt_embeds=prompt_embeds,
218
+ negative_prompt_embeds=negative_prompt_embeds,
219
+ max_sequence_length=max_sequence_length,
220
+ device=device,
221
+ )
222
+
223
+ transformer_dtype = self.transformer.dtype
224
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
225
+ if negative_prompt_embeds is not None:
226
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
227
+
228
+ # 4. Prepare timesteps
229
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
230
+ timesteps = self.scheduler.timesteps
231
+
232
+ # 5. Prepare latent variables
233
+ num_channels_latents = self.transformer.config.in_channels
234
+ latents = self.prepare_latents(
235
+ batch_size * num_videos_per_prompt,
236
+ num_channels_latents,
237
+ height,
238
+ width,
239
+ num_frames,
240
+ torch.float32,
241
+ device,
242
+ generator,
243
+ latents,
244
+ )
245
+
246
+ all_latents = [latents]
247
+ all_log_probs = []
248
+ all_kl = []
249
+
250
+ # 6. Denoising loop
251
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
252
+ self._num_timesteps = len(timesteps)
253
+ # print(timesteps)
254
+
255
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
256
+ for i, t in enumerate(timesteps):
257
+ if self.interrupt:
258
+ continue
259
+
260
+ latents_ori = latents.clone()
261
+ self._current_timestep = t
262
+ latent_model_input = latents.to(transformer_dtype)
263
+ timestep = t.expand(latents.shape[0])
264
+
265
+ noise_pred = self.transformer(
266
+ hidden_states=latent_model_input,
267
+ timestep=timestep,
268
+ encoder_hidden_states=prompt_embeds,
269
+ attention_kwargs=attention_kwargs,
270
+ return_dict=False,
271
+ )[0]
272
+ noise_pred = noise_pred.to(prompt_embeds.dtype)
273
+
274
+ if self.do_classifier_free_guidance:
275
+ noise_uncond = self.transformer(
276
+ hidden_states=latent_model_input,
277
+ timestep=timestep,
278
+ encoder_hidden_states=negative_prompt_embeds,
279
+ attention_kwargs=attention_kwargs,
280
+ return_dict=False,
281
+ )[0]
282
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
283
+
284
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
285
+ self.scheduler,
286
+ noise_pred.float(),
287
+ t.unsqueeze(0),
288
+ latents.float(),
289
+ determistic=determistic,
290
+ return_pixel_log_prob=return_pixel_log_prob
291
+ )
292
+ prev_latents = latents.clone()
293
+
294
+ all_latents.append(latents)
295
+ all_log_probs.append(log_prob)
296
+
297
+ # compute the previous noisy sample x_t -> x_t-1
298
+ # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
299
+
300
+ if callback_on_step_end is not None:
301
+ callback_kwargs = {}
302
+ for k in callback_on_step_end_tensor_inputs:
303
+ callback_kwargs[k] = locals()[k]
304
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
305
+
306
+ latents = callback_outputs.pop("latents", latents)
307
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
308
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
309
+
310
+ # KL reward enabled during sampling: re-run this step with the LoRA adapter disabled to get the reference policy, then accumulate the per-step KL between the two Gaussian transitions.
311
+ if kl_reward>0 and not determistic:
312
+ latent_model_input = torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori
313
+ with self.transformer.disable_adapter():
314
+ noise_pred = self.transformer(
315
+ hidden_states=latent_model_input,
316
+ timestep=timestep,
317
+ encoder_hidden_states=prompt_embeds,
318
+ attention_kwargs=attention_kwargs,
319
+ return_dict=False,
320
+ )[0]
321
+ noise_pred = noise_pred.to(prompt_embeds.dtype)
322
+ # perform guidance
323
+ if self.do_classifier_free_guidance:
324
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
325
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
326
+
327
+ _, ref_log_prob, ref_prev_latents_mean, ref_std_dev_t = sde_step_with_logprob(
328
+ self.scheduler,
329
+ noise_pred.float(),
330
+ t.unsqueeze(0),
331
+ latents_ori.float(),
332
+ prev_sample=prev_latents.float(),
333
+ determistic=determistic,
334
+ )
335
+ assert torch.equal(std_dev_t, ref_std_dev_t)  # element-wise check; a bare tensor assert fails for batch size > 1
336
+ kl = (prev_latents_mean - ref_prev_latents_mean)**2 / (2 * std_dev_t**2)
337
+ kl = kl.mean(dim=tuple(range(1, kl.ndim)))
338
+ all_kl.append(kl)
339
+ else:
340
+ # No KL reward: skip the reference computation and store a placeholder of zeros.
341
+ all_kl.append(torch.zeros(len(latents), device=latents.device))
342
+
343
+ # call the callback, if provided
344
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
345
+ progress_bar.update()
346
+
347
+ # if XLA_AVAILABLE:
348
+ # xm.mark_step()
349
+
350
+ self._current_timestep = None
351
+
352
+ if not output_type == "latent":
353
+ latents = latents.to(self.vae.dtype)
354
+ latents_mean = (
355
+ torch.tensor(self.vae.config.latents_mean)
356
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
357
+ .to(latents.device, latents.dtype)
358
+ )
359
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
360
+ latents.device, latents.dtype
361
+ )
362
+ latents = latents / latents_std + latents_mean
363
+ video = self.vae.decode(latents, return_dict=False)[0]
364
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
365
+ else:
366
+ video = latents
367
+
368
+ self.maybe_free_model_hooks()
369
+
370
+ if not return_dict:
371
+ return (video, all_latents, all_log_probs, all_kl)
372
+
373
+ return WanPipelineOutput(frames=video), all_latents, all_log_probs, all_kl
adv_grpo/diffusers_patch/wan_prompt_embedding.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ def _get_t5_prompt_embeds(
5
+ text_encoder,
6
+ tokenizer,
7
+ prompt: Union[str, List[str]] = None,
8
+ max_sequence_length: int = 226,
9
+ num_videos_per_prompt: int = 1,
10
+ device: Optional[torch.device] = None,
11
+ dtype: Optional[torch.dtype] = None,
12
+ ):
13
+
14
+ prompt = [prompt] if isinstance(prompt, str) else prompt
15
+ batch_size = len(prompt)
16
+
17
+ text_inputs = tokenizer(
18
+ prompt,
19
+ padding="max_length",
20
+ max_length=max_sequence_length,
21
+ truncation=True,
22
+ add_special_tokens=True,
23
+ return_attention_mask=True,
24
+ return_tensors="pt",
25
+ )
26
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
27
+ seq_lens = mask.gt(0).sum(dim=1).long()
28
+
29
+ prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
30
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
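+ # Trim each sequence to its true token length, then zero-pad back to max_sequence_length so padded positions contribute nothing.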
31
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
32
+ prompt_embeds = torch.stack(
33
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
34
+ )
35
+
36
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
37
+ _, seq_len, _ = prompt_embeds.shape
38
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
39
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
40
+
41
+ return prompt_embeds
42
+
43
+ def encode_prompt(
44
+ text_encoder,
45
+ tokenizer,
46
+ prompt: Union[str, List[str]],
47
+ max_sequence_length: int = 226,
48
+ num_videos_per_prompt: int = 1,
49
+ device: Optional[torch.device] = None,
50
+ dtype: Optional[torch.dtype] = None,
51
+ ):
52
+ r"""
53
+ Encodes the prompt into text encoder hidden states.
54
+
55
+ Args:
56
+ prompt (`str` or `List[str]`, *optional*):
57
+ prompt to be encoded
58
+ negative_prompt (`str` or `List[str]`, *optional*):
59
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
60
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
61
+ less than `1`).
62
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
63
+ Whether to use classifier free guidance or not.
64
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
65
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
66
+ prompt_embeds (`torch.Tensor`, *optional*):
67
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
68
+ provided, text embeddings will be generated from `prompt` input argument.
69
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
70
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
71
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
72
+ argument.
73
+ device: (`torch.device`, *optional*):
74
+ torch device
75
+ dtype: (`torch.dtype`, *optional*):
76
+ torch dtype
77
+ """
78
+ device = text_encoder[0].device
79
+ dtype = text_encoder[0].dtype
80
+
81
+ prompt = [prompt] if isinstance(prompt, str) else prompt
82
+ if prompt is not None:
83
+ batch_size = len(prompt)
84
+ else:
85
+ batch_size = prompt_embeds.shape[0]
86
+
87
+ prompt_embeds = _get_t5_prompt_embeds(
88
+ text_encoder=text_encoder[0],
89
+ tokenizer=tokenizer[0],
90
+ prompt=prompt,
91
+ max_sequence_length=max_sequence_length,
92
+ num_videos_per_prompt=num_videos_per_prompt,
93
+ device=device,
94
+ dtype=dtype,
95
+ )
96
+
97
+ return prompt_embeds
adv_grpo/ema.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copied from another repo, but I can't remember exactly which one.
2
+
3
+ from collections.abc import Iterable
4
+
5
+ import torch
6
+
7
+
8
+ class EMAModuleWrapper:
9
+ def __init__(
10
+ self,
11
+ parameters: Iterable[torch.nn.Parameter],
12
+ decay: float = 0.9999,
13
+ update_step_interval: int = 1,
14
+ device: torch.device | None = None,
15
+ ):
16
+ parameters = list(parameters)
17
+ self.ema_parameters = [p.clone().detach().to(device) for p in parameters]
18
+
19
+ self.temp_stored_parameters = None
20
+
21
+ self.decay = decay
22
+ self.update_step_interval = update_step_interval
23
+ self.device = device
24
+
25
+ # TODO: add an automatic decay calculation based on this formula:
26
+ # The impact of the last n steps can be calculated as:
27
+ # impact = 1-(decay^n)
28
+ # The number of steps needed to reach a specific impact is:
29
+ # n = log_decay(1-impact)
30
+ # The decay needed to reach a specific impact after n steps is:
31
+ # decay = (1-impact)^(1/n)
32
+
33
+ def get_current_decay(self, optimization_step) -> float:
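+ # Warm-up schedule: early in training the EMA tracks the live weights closely, then the decay approaches the configured value.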
34
+ return min(
35
+ (1 + optimization_step) / (10 + optimization_step),
36
+ self.decay
37
+ )
38
+
39
+ @torch.no_grad()
40
+ def step(self, parameters: Iterable[torch.nn.Parameter], optimization_step):
41
+ parameters = list(parameters)
42
+
43
+ one_minus_decay = 1 - self.get_current_decay(optimization_step)
44
+
45
+ if (optimization_step + 1) % self.update_step_interval == 0:
46
+ for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
47
+ if parameter.requires_grad:
48
+ if ema_parameter.device == parameter.device:
49
+ ema_parameter.add_(one_minus_decay * (parameter - ema_parameter))
50
+ else:
51
+ # in place calculations to save memory
52
+ parameter_copy = parameter.detach().to(ema_parameter.device)
53
+ parameter_copy.sub_(ema_parameter)
54
+ parameter_copy.mul_(one_minus_decay)
55
+ ema_parameter.add_(parameter_copy)
56
+ del parameter_copy
57
+
58
+ def to(self, device: torch.device = None, dtype: torch.dtype = None) -> None:
59
+ self.device = device
60
+ self.ema_parameters = [
61
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
62
+ for p in self.ema_parameters
63
+ ]
64
+
65
+ def copy_ema_to(self, parameters: Iterable[torch.nn.Parameter], store_temp: bool = True) -> None:
66
+ parameters = list(parameters)
+ if store_temp:
+ self.temp_stored_parameters = [parameter.detach().cpu() for parameter in parameters]
+
70
+ for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
71
+ parameter.data.copy_(ema_parameter.to(parameter.device).data)
72
+
73
+ def copy_temp_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
74
+ for temp_parameter, parameter in zip(self.temp_stored_parameters, parameters, strict=True):
75
+ parameter.data.copy_(temp_parameter.data)
76
+
77
+ self.temp_stored_parameters = None
78
+
79
+ def load_state_dict(self, state_dict: dict) -> None:
80
+ self.decay = self.decay if self.decay else state_dict.get("decay", self.decay)
81
+ self.ema_parameters = state_dict.get("ema_parameters")
82
+ self.to(self.device)
83
+
84
+ def state_dict(self) -> dict:
85
+ return {
86
+ "decay": self.decay,
87
+ "ema_parameters": self.ema_parameters,
88
+ }
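The decay formula in the TODO comment above can be sanity-checked in isolation. A minimal standalone sketch (not part of the file; `decay_for_impact` and `steps_for_impact` are illustrative names):

import math

def decay_for_impact(impact: float, n: int) -> float:
    # decay = (1 - impact) ** (1 / n): decay needed so the last n steps carry `impact`
    return (1.0 - impact) ** (1.0 / n)

def steps_for_impact(impact: float, decay: float) -> int:
    # n = log_decay(1 - impact)
    return math.ceil(math.log(1.0 - impact) / math.log(decay))

print(decay_for_impact(0.9, 1000))    # ~0.9977: the last 1000 steps carry 90% of the EMA
print(steps_for_impact(0.9, 0.9999))  # ~23026 steps needed at the default decay of 0.9999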
adv_grpo/imagereward_scorer.py ADDED
@@ -0,0 +1,40 @@
1
+ from transformers import AutoProcessor, AutoModel
2
+ from PIL import Image
3
+ import torch
4
+ import ImageReward as RM
5
+
6
+ class ImageRewardScorer(torch.nn.Module):
7
+ def __init__(self, device="cuda", dtype=torch.float32):
8
+ super().__init__()
9
+ self.model_path = "ImageReward-v1.0"
10
+ self.device = device
11
+ self.dtype = dtype
12
+ self.model = RM.load(self.model_path, device=device).eval().to(dtype=dtype)
13
+ self.model.requires_grad_(False)
14
+
15
+ @torch.no_grad()
16
+ def __call__(self, prompts, images):
17
+ rewards = []
18
+ for prompt,image in zip(prompts, images):
19
+ _, reward = self.model.inference_rank(prompt, [image])
20
+ rewards.append(reward)
21
+ return rewards
22
+
23
+ # Usage example
24
+ def main():
25
+ scorer = ImageRewardScorer(
26
+ device="cuda",
27
+ dtype=torch.float32
28
+ )
29
+
30
+ images=[
31
+ "astronaut.jpg",
32
+ ]
33
+ pil_images = [Image.open(img) for img in images]
34
+ prompts=[
35
+ 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
36
+ ]
37
+ print(scorer(prompts, pil_images))
38
+
39
+ if __name__ == "__main__":
40
+ main()
adv_grpo/inflated_layers.py ADDED
@@ -0,0 +1,305 @@
1
+ from functools import partial
2
+ from typing import Literal
3
+ from einops import rearrange
4
+ from torch import Tensor
5
+ from torch.nn import ConvTranspose2d, ConvTranspose3d
6
+
7
+ from flow_grpo.inflated_lib import (
8
+ MemoryState,
9
+ extend_head,
10
+ inflate_bias,
11
+ inflate_distribution_bias,
12
+ inflate_distribution_weight,
13
+ inflate_weight,
14
+ modify_state_dict,
15
+ )
16
+ from flow_grpo.conv_gradfix import GradFixConv2d, GradFixConv3d
17
+
18
+ VERBOSE = False
19
+
20
+ _inflation_mode_t = Literal["none", "flatten", "partial_flatten", "pad", "tile"]
21
+ _direction_t = Literal["", "out", "in"]
22
+
23
+
24
+ class InflatedCausalConv3d(GradFixConv3d):
25
+ """
26
+ Note:
27
+ To align the behavior of pretrained 2D models,
28
+ if you compose a video clip from a single image by:
29
+ - duplicating: set shape_norm = True
30
+ - padding zeros: set shape_norm = False
31
+ to avoid gaps at the beginning of the training process.
32
+ """
33
+
34
+ def __init__(
35
+ self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs
36
+ ):
37
+ self.shape_norm = shape_norm
38
+ self.inflation_mode = inflation_mode
39
+ self.padding_bank = None
40
+ super().__init__(*args, **kwargs)
41
+ self.temporal_padding = self.padding[0]
42
+ self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal.
43
+
44
+ def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
45
+ bank_size = self.stride[0] - self.kernel_size[0]
46
+ padding_bank = (
47
+ input[:, :, bank_size:].detach()
48
+ if (bank_size != 0 and memory_state != MemoryState.DISABLED)
49
+ else None
50
+ )
51
+ if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE):
52
+ input = extend_head(input, memory=self.padding_bank)
53
+ else:
54
+ input = extend_head(input, times=self.temporal_padding * 2)
55
+ if memory_state != MemoryState.DISABLED and not self.training:
56
+ self.padding_bank = padding_bank
57
+ return super().forward(input)
58
+
59
+ def _load_from_state_dict(
60
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
61
+ ):
62
+ if self.inflation_mode == "none":
63
+ super()._load_from_state_dict(
64
+ state_dict,
65
+ prefix,
66
+ local_metadata,
67
+ strict,
68
+ missing_keys,
69
+ unexpected_keys,
70
+ error_msgs,
71
+ )
72
+ else:
73
+ # NOTE: need to switch off strict
74
+ super()._load_from_state_dict(
75
+ modify_state_dict(
76
+ self,
77
+ state_dict,
78
+ prefix,
79
+ verbose=VERBOSE,
80
+ inflate_weight_fn=partial(inflate_weight, position="tail"),
81
+ inflate_bias_fn=partial(inflate_bias, position="tail"),
82
+ ),
83
+ prefix,
84
+ local_metadata,
85
+ False,
86
+ missing_keys,
87
+ unexpected_keys,
88
+ error_msgs,
89
+ )
90
+
91
+
92
+ class InflatedDistributionCausalConv3d(GradFixConv3d):
93
+ """
94
+ Note:
95
+ Direction:
96
+ - out: this layer generates mean/std of some distribution;
97
+ - in: this layer takes tensors sampled from output of `out` layer as input.
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ *args,
103
+ direction: _direction_t,
104
+ inflation_mode: _inflation_mode_t,
105
+ shape_norm: bool = True,
106
+ **kwargs,
107
+ ):
108
+ self.shape_norm = shape_norm
109
+ self.inflation_mode = inflation_mode
110
+ self.direction = direction
111
+ self.padding_bank = None
112
+ super().__init__(*args, **kwargs)
113
+ self.temporal_padding = self.padding[0]
114
+ self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal.
115
+
116
+ def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
117
+ bank_size = self.stride[0] - self.kernel_size[0]
118
+ padding_bank = (
119
+ input[:, :, bank_size:].detach()
120
+ if (bank_size != 0 and memory_state != MemoryState.DISABLED)
121
+ else None
122
+ )
123
+ if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE):
124
+ input = extend_head(input, memory=self.padding_bank)
125
+ else:
126
+ input = extend_head(input, times=self.temporal_padding * 2)
127
+ if memory_state != MemoryState.DISABLED and not self.training:
128
+ self.padding_bank = padding_bank
129
+ return super().forward(input)
130
+
131
+ def _load_from_state_dict(
132
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
133
+ ):
134
+ if self.inflation_mode == "none":
135
+ super()._load_from_state_dict(
136
+ state_dict,
137
+ prefix,
138
+ local_metadata,
139
+ strict,
140
+ missing_keys,
141
+ unexpected_keys,
142
+ error_msgs,
143
+ )
144
+ else:
145
+ super()._load_from_state_dict(
146
+ modify_state_dict(
147
+ self,
148
+ state_dict,
149
+ prefix,
150
+ verbose=VERBOSE,
151
+ inflate_weight_fn=partial(
152
+ inflate_distribution_weight, direction=self.direction, position="tail"
153
+ ),
154
+ inflate_bias_fn=partial(
155
+ inflate_distribution_bias, direction=self.direction, position="tail"
156
+ ),
157
+ ),
158
+ prefix,
159
+ local_metadata,
160
+ False,
161
+ missing_keys,
162
+ unexpected_keys,
163
+ error_msgs,
164
+ )
165
+
166
+
167
+ class InflatedConvTranspose3d(ConvTranspose3d):
168
+ # Note: It's not a causal one.
169
+ def __init__(
170
+ self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs
171
+ ):
172
+ self.shape_norm = shape_norm
173
+ self.inflation_mode = inflation_mode
174
+ super().__init__(*args, **kwargs)
175
+
176
+ def _load_from_state_dict(
177
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
178
+ ):
179
+ if self.inflation_mode == "none":
180
+ super()._load_from_state_dict(
181
+ state_dict,
182
+ prefix,
183
+ local_metadata,
184
+ strict,
185
+ missing_keys,
186
+ unexpected_keys,
187
+ error_msgs,
188
+ )
189
+ else:
190
+ # NOTE: need to switch off strict
191
+ super()._load_from_state_dict(
192
+ modify_state_dict(
193
+ self,
194
+ state_dict,
195
+ prefix,
196
+ verbose=VERBOSE,
197
+ inflate_weight_fn=partial(inflate_weight, position="center"),
198
+ inflate_bias_fn=partial(inflate_bias, position="center"),
199
+ ),
200
+ prefix,
201
+ local_metadata,
202
+ False,
203
+ missing_keys,
204
+ unexpected_keys,
205
+ error_msgs,
206
+ )
207
+
208
+
209
+ class FlattenedConvTranspose3d(ConvTranspose2d):
210
+ def forward(self, input: Tensor, **kwargs) -> Tensor:
211
+ output = rearrange(input, "b c f h w -> (b f) c h w")
212
+ output = super().forward(output)
213
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2))
214
+ return output
215
+
216
+
217
+ class FlattenedConv3d(GradFixConv2d):
218
+ def forward(self, input: Tensor, **kwargs) -> Tensor:
219
+ output = rearrange(input, "b c f h w -> (b f) c h w")
220
+ output = super().forward(output)
221
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2))
222
+ return output
223
+
224
+
225
+ def init_causal_conv3d(
226
+ *args,
227
+ inflation_mode: _inflation_mode_t,
228
+ direction: _direction_t = "",
229
+ partial_switch: bool = False,
230
+ **kwargs,
231
+ ):
232
+ """
233
+ Initialize a Causal-3D convolution layer.
234
+ Parameters:
235
+ inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have.
236
+ - none: No inflation will be conducted.
237
+ The loading logic of state dict will fall back to default.
238
+ - flatten: It will produce a `fake` 3D layer,
239
+ which simply squeeze the axis of batch size and depth together,
240
+ and then conduct 2D convolution.
241
+ - partial_flatten:
242
+ - layers with `partial_switch` on: using `none` mode.
243
+ - layers with `partial_switch` off: using `flatten` mode.
244
+ - pad / tile: Refer to the definition of `InflatedCausalConv3d`.
245
+ direction:
246
+ - empty string: Ordinary causal convolution layer.
247
+ - out / in: Refer to the definition of `InflatedDistributionCausalConv3d`.
248
+ partial_switch: Only works when `inflation_mode` is `partial_flatten`.
249
+ """
250
+ stride = kwargs.get("stride", args[3] if len(args) > 3 else None)
251
+ padding = kwargs.get("padding", args[4] if len(args) > 4 else None)
252
+ if "flatten" in inflation_mode:
253
+ if (
254
+ (
255
+ (not stride)
256
+ or isinstance(stride, int)
257
+ or (isinstance(stride, (list, tuple)) and len(stride) < 3)
258
+ ) # if the config of stride can be used for 2D conv
259
+ and (
260
+ (not padding)
261
+ or isinstance(padding, int)
262
+ or (isinstance(padding, (list, tuple)) and len(padding) < 3)
263
+ ) # if the config of padding can be used for 2D conv
264
+ and (("partial" not in inflation_mode) or (not partial_switch))
265
+ # if it's fully-flatten mode, or with `partial_switch` off
266
+ ):
267
+ return FlattenedConv3d(*args, **kwargs)
268
+ else:
269
+ return InflatedCausalConv3d(*args, inflation_mode="none", **kwargs)
270
+ # Force-override
271
+ else:
272
+ if direction:
273
+ return InflatedDistributionCausalConv3d(
274
+ *args, direction=direction, inflation_mode=inflation_mode, **kwargs
275
+ )
276
+ else:
277
+ return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs)
278
+
279
+
280
+ def init_transposed_conv3d(
281
+ *args, inflation_mode: _inflation_mode_t, partial_switch: bool = False, **kwargs
282
+ ):
283
+ stride = kwargs.get("stride", args[3] if len(args) > 3 else None)
284
+ padding = kwargs.get("padding", args[4] if len(args) > 4 else None)
285
+ if "flatten" in inflation_mode:
286
+ if (
287
+ (
288
+ (not stride)
289
+ or isinstance(stride, int)
290
+ or (isinstance(stride, (list, tuple)) and len(stride) < 3)
291
+ )
292
+ and (
293
+ (not padding)
294
+ or isinstance(padding, int)
295
+ or (isinstance(padding, (list, tuple)) and len(padding) < 3)
296
+ )
297
+ or (("partial" in inflation_mode) and not partial_switch)
298
+ ):
299
+ return FlattenedConvTranspose3d(*args, **kwargs)
300
+ else:
301
+ return InflatedConvTranspose3d(
302
+ *args, inflation_mode="none", **kwargs
303
+ ) # Force-override
304
+ else:
305
+ return InflatedConvTranspose3d(*args, inflation_mode=inflation_mode, **kwargs)
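A hedged usage sketch of `init_causal_conv3d` defined above. It assumes `GradFixConv3d` keeps the standard `Conv3d` constructor and that the `flow_grpo` package (the import path used in this file) is importable; the argument values are illustrative only:

import torch
from flow_grpo.inflated_layers import init_causal_conv3d
from flow_grpo.inflated_lib import MemoryState

# "pad" mode inflates a pretrained 2D kernel into a causal 3D one when a state dict is loaded
conv = init_causal_conv3d(64, 64, kernel_size=3, stride=1, padding=1, inflation_mode="pad")
x = torch.randn(1, 64, 9, 32, 32)               # (batch, channels, frames, height, width)
y = conv(x, memory_state=MemoryState.DISABLED)  # duplicated head frames replace the temporal padding
print(y.shape)                                   # expected (1, 64, 9, 32, 32): the frame count is preserved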
adv_grpo/inflated_lib.py ADDED
@@ -0,0 +1,346 @@
1
+ import math
2
+ from enum import Enum
3
+ from typing import Optional
4
+ import numpy as np
5
+ import torch
6
+ from diffusers.models.attention_processor import SpatialNorm
7
+ from diffusers.models.normalization import RMSNorm
8
+ from einops import rearrange
9
+ from torch import Tensor, nn
10
+
11
+ # from common.logger import get_logger
12
+
13
+ # logger = get_logger(__name__)
14
+
15
+
16
+ class MemoryState(Enum):
17
+ """
18
+ State[Disabled]: No memory bank will be enabled.
19
+ State[Initializing]: The model is handling the first clip,
20
+ need to reset / initialize the memory bank.
21
+ State[Active]: There has been some data in the memory bank.
22
+ """
23
+
24
+ DISABLED = 0
25
+ INITIALIZING = 1
26
+ ACTIVE = 2
27
+
28
+
29
+ def norm_wrapper(
30
+ norm_layer: nn.Module,
31
+ x: torch.Tensor,
32
+ y: Optional[torch.Tensor] = None,
33
+ keep_causal: bool = False,
34
+ ) -> torch.Tensor:
35
+ if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)):
36
+ if x.ndim == 4:
37
+ x = rearrange(x, "b c h w -> b h w c")
38
+ x = norm_layer(x)
39
+ x = rearrange(x, "b h w c -> b c h w")
40
+ return x
41
+ if x.ndim == 5:
42
+ x = rearrange(x, "b c t h w -> b t h w c")
43
+ x = norm_layer(x)
44
+ x = rearrange(x, "b t h w c -> b c t h w")
45
+ return x
46
+ if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
47
+ if x.ndim <= 4 or (not keep_causal and not isinstance(norm_layer, nn.BatchNorm2d)):
48
+ return norm_layer(x)
49
+ if x.ndim == 5:
50
+ t = x.size(2)
51
+ x = rearrange(x, "b c t h w -> (b t) c h w")
52
+ x = norm_layer(x)
53
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
54
+ return x
55
+ if isinstance(norm_layer, SpatialNorm):
56
+ t = -1
57
+ if x.ndim == 5:
58
+ t = x.size(2)
59
+ x = rearrange(x, "b c t h w -> (b t) c h w")
60
+ if y.ndim == 5:
61
+ y = rearrange(y, "b c t h w -> (b t) c h w")
62
+ if x.ndim != 4 or y.ndim != 4:
63
+ raise NotImplementedError
64
+ x = norm_layer(x, y)
65
+ if t != -1:
66
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
67
+ return x
68
+ raise NotImplementedError
69
+
70
+
71
+ def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
72
+ """
73
+ Remove duplicated first frame features in the up-sampling process.
74
+ """
75
+ if times == 0:
76
+ return tensor
77
+ return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
78
+
79
+
80
+ def extend_head(
81
+ tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None
82
+ ) -> Tensor:
83
+ """
84
+ When memory is None:
85
+ - Duplicate first frame features in the down-sampling process.
86
+ When memory is not None:
87
+ - Concatenate memory features with the input features to keep temporal consistency.
88
+ """
89
+ if times == 0:
90
+ return tensor
91
+ if memory is not None:
92
+ return torch.cat((memory.to(tensor), tensor), dim=2)
93
+ else:
94
+ tile_repeat = np.ones(tensor.ndim).astype(int)
95
+ tile_repeat[2] = times
96
+ return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2)
97
+
98
+
99
+ def fill_weight_in_depth(weight: torch.Tensor, source: torch.Tensor, position: str):
100
+ """
101
+ Inflate a 2D convolution weight matrix to a 3D one by padding zeros in the channel of depth.
102
+ Parameters:
103
+ weight: The weight parameters of 3D conv kernel to be initialized.
104
+ source: The weight parameters of 2D conv kernel to be inflated.
105
+ position: Where to insert the 2D weights, can be chosen from
106
+ - tail: Pad zeros in the front of the 2D kernel. Used for casual inflation.
107
+ - center: Pad zeros around the 2D kernel. Used for normal inflation.
108
+ """
109
+ assert position in ["tail", "center"], "Unsupported fill-in position for weight inflation."
110
+ depth = weight.size(2)
111
+ weight.fill_(0.0)
112
+ if position == "center":
113
+ if depth % 2 == 1:
114
+ weight[:, :, depth // 2].copy_(source.squeeze(2))
115
+ else:
116
+ weight[:, :, depth // 2].copy_(source.squeeze(2) / 2.0)
117
+ weight[:, :, depth // 2 - 1].copy_(source.squeeze(2) / 2.0)
118
+ else:
119
+ if depth % 2 == 1:
120
+ weight[:, :, -1].copy_(source.squeeze(2))
121
+ else:
122
+ weight[:, :, -1].copy_(source.squeeze(2) / 2.0)
123
+ weight[:, :, -2].copy_(source.squeeze(2) / 2.0)
124
+ return weight
125
+
126
+
127
+ def inflate_weight(
128
+ weight_2d: torch.Tensor,
129
+ weight_3d: torch.Tensor,
130
+ shape_norm: bool,
131
+ name: str,
132
+ inflation_mode: str,
133
+ position: str,
134
+ verbose: bool = True,
135
+ ):
136
+ """
137
+ Inflate a 2D convolution weight matrix to a 3D one.
138
+ Parameters:
139
+ weight_2d: The weight matrix of 2D conv to be inflated.
140
+ weight_3d: The weight matrix of 3D conv to be initialized.
141
+ inflation_mode: the mode of inflation
142
+ - pad: pad zeros around 2D kernel.
143
+ - tile: tile 2D kernel along the depth axis.
144
+
145
+ shape_norm: Whether to scale the parameters of 2D kernel so that the untrained
146
+ inflated model behaves exactly the same as the original 2D model
147
+ in the reconstruction of images and videos. It is recommended to switch this on.
148
+
149
+ name: The name of inflated module. Only be used in logging.
150
+ position: Refer to the doc of `fill_weight_in_depth`.
151
+ Only works when `inflation_mode` is `pad`.
152
+ verbose: Whether to log information about inflation.
153
+ """
154
+ assert inflation_mode in ["pad", "tile"]
155
+ depth = weight_3d.size(2)
156
+ tgt_out, tgt_in = weight_3d.size()[:2]
157
+ src_out, src_in = weight_2d.size()[:2]
158
+ assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0)
159
+ out_fan, in_fan = tgt_out // src_out, tgt_in // src_in
160
+ depth_factor = 1 if inflation_mode == "pad" else depth
161
+ factor = (depth_factor * math.sqrt(out_fan) * math.sqrt(in_fan)) if shape_norm else 1
162
+ with torch.no_grad():
163
+ channel_inflation = weight_2d.unsqueeze(2).repeat(out_fan, in_fan, 1, 1, 1) / factor
164
+ if inflation_mode == "tile":
165
+ weight_3d.copy_(channel_inflation.repeat(1, 1, depth, 1, 1))
166
+ else:
167
+ weight_3d = fill_weight_in_depth(weight_3d, channel_inflation, position)
168
+ if verbose:
169
+ print(
170
+ f"*** {name}weight {weight_2d.size()} is inflated to {weight_3d.size()} ***"
171
+ )
172
+ return weight_3d
173
+
174
+
175
+ def inflate_bias(
176
+ bias_2d: torch.Tensor,
177
+ bias_3d: torch.Tensor,
178
+ shape_norm: bool,
179
+ name: str,
180
+ inflation_mode: str,
181
+ position: str,
182
+ verbose: bool = True,
183
+ ):
184
+ """
185
+ Inflate a 2D convolution bias tensor to a 3D one
186
+ Parameters:
187
+ bias_2d: The bias tensor of 2D conv to be inflated.
188
+ bias_3d: The bias tensor of 3D conv to be initialized.
189
+ shape_norm: Refer to `inflate_weight` function.
190
+ name: The name of inflated module. Only be used in logging.
191
+ inflation_mode: Placeholder to align `inflate_weight`.
192
+ position: Placeholder to align `inflate_weight`.
193
+ verbose: Whether to log information about inflation.
194
+ """
195
+ tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0)
196
+ assert tgt_ch % src_ch == 0
197
+ fan = tgt_ch // src_ch
198
+ factor = math.sqrt(fan) if shape_norm else 1
199
+ with torch.no_grad():
200
+ bias_3d.copy_(bias_2d.repeat(fan) / factor)
201
+ if (tgt_ch != src_ch) and verbose:
202
+ print(f"*** {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***")
203
+ return bias_3d
204
+
205
+
206
+ def inflate_distribution_weight(
207
+ weight_2d: torch.Tensor,
208
+ weight_3d: torch.Tensor,
209
+ shape_norm: bool,
210
+ name: str,
211
+ direction: str,
212
+ inflation_mode: str,
213
+ position: str,
214
+ verbose: bool = True,
215
+ ):
216
+ """
217
+ Inflate a 2D convolution weight matrix to a 3D one.
218
+ Note: Different from `inflate_weight`,
219
+ it's designed for `quant_conv` or `post_quant_conv` layers.
220
+ i.e., a convolution layer used to produce `mean` and `std` of some distribution,
221
+ or its subsequent layer.
222
+ Parameters: Refer to `inflate_weight`.
223
+ direction:
224
+ - out: this layer generates `mean` and `std` of some distribution.
225
+ - in: this layer takes tensors sampled from output of `out` layer as input.
226
+ """
227
+ assert inflation_mode in ["pad", "tile"]
228
+ depth = weight_3d.size(2)
229
+ tgt_out, tgt_in = weight_3d.size()[:2]
230
+ src_out, src_in = weight_2d.size()[:2]
231
+ assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0)
232
+ out_fan, in_fan = tgt_out // src_out, tgt_in // src_in
233
+ depth_factor = 1 if inflation_mode == "pad" else depth
234
+ if direction == "out":
235
+ factor = (depth_factor * math.sqrt(in_fan)) if shape_norm else 1
236
+ with torch.no_grad():
237
+ in_inflation = weight_2d.unsqueeze(2).repeat(1, in_fan, 1, 1, 1) / factor
238
+ # [src_out, src_in, k_h, k_w] -> [src_out, tgt_in, 1, k_h, k_w]
239
+ out_mean_weight, out_std_weight = torch.chunk(in_inflation, 2, dim=0)
240
+ mean_slice = slice(src_out // 2)
241
+ std_slice = slice(tgt_out // 2, tgt_out // 2 + src_out // 2)
242
+ if inflation_mode == "tile":
243
+ weight_3d[mean_slice] = out_mean_weight
244
+ weight_3d[std_slice] = out_std_weight
245
+ # Other part will be randomly initialized.
246
+ else:
247
+ weight_3d[mean_slice] = fill_weight_in_depth(
248
+ weight_3d[mean_slice], out_mean_weight, position
249
+ )
250
+ weight_3d[std_slice] = fill_weight_in_depth(
251
+ weight_3d[std_slice], out_std_weight, position
252
+ )
253
+ # Other part will be randomly initialized.
254
+ elif direction == "in":
255
+ factor = (depth_factor * math.sqrt(out_fan)) if shape_norm else 1
256
+ with torch.no_grad():
257
+ out_inflation = weight_2d.unsqueeze(2).repeat(out_fan, 1, 1, 1, 1) / factor
258
+ # [src_out, src_in, k_h, k_w] -> [tgt_out, src_in, 1, k_h, k_w]
259
+ if inflation_mode == "tile":
260
+ weight_3d[:, :src_in] = out_inflation
261
+ else:
262
+ weight_3d[:, :src_in] = fill_weight_in_depth(
263
+ weight_3d[:, :src_in], out_inflation, position
264
+ )
265
+ weight_3d[:, src_in:].fill_(0.0)
266
+ else:
267
+ raise NotImplementedError
268
+ if verbose:
269
+ print(
270
+ f"*** [Distribution] {name}weight {weight_2d.size()} "
271
+ f"is inflated to {weight_3d.size()} ***"
272
+ )
273
+ return weight_3d
274
+
275
+
276
+ def inflate_distribution_bias(
277
+ bias_2d: torch.Tensor,
278
+ bias_3d: torch.Tensor,
279
+ shape_norm: bool,
280
+ name: str,
281
+ direction: str,
282
+ inflation_mode: str,
283
+ position: str,
284
+ verbose: bool = True,
285
+ ):
286
+ """
287
+ The combination of `inflate_distribution_weight` and `inflate_bias`.
288
+ """
289
+ tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0)
290
+ assert tgt_ch % src_ch == 0
291
+ if direction == "out":
292
+ with torch.no_grad():
293
+ out_mean_bias, out_std_bias = torch.chunk(bias_2d, 2, dim=0)
294
+ bias_3d[: src_ch // 2] = out_mean_bias
295
+ bias_3d[tgt_ch // 2 : tgt_ch // 2 + src_ch // 2] = out_std_bias
296
+ elif direction == "in":
297
+ with torch.no_grad():
298
+ bias_3d[:src_ch] = bias_2d
299
+ bias_3d[src_ch:].fill_(0.0)
300
+ else:
301
+ raise NotImplementedError
302
+ if verbose:
303
+ print(
304
+ f"*** [Distribution] {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***"
305
+ )
306
+ return bias_3d
307
+
308
+
309
+ def modify_state_dict(
310
+ layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn, verbose=False
311
+ ):
312
+ """
313
+ The main function to inflate 2D parameters to 3D.
314
+ """
315
+ weight_name = prefix + "weight"
316
+ bias_name = prefix + "bias"
317
+ if weight_name in state_dict:
318
+ weight_2d = state_dict[weight_name]
319
+ if (
320
+ weight_2d.dim() == 4
321
+ ): # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
322
+ weight_3d = inflate_weight_fn(
323
+ weight_2d=weight_2d,
324
+ weight_3d=layer.weight,
325
+ shape_norm=layer.shape_norm,
326
+ name=prefix,
327
+ verbose=verbose,
328
+ inflation_mode=layer.inflation_mode,
329
+ )
330
+ state_dict[weight_name] = weight_3d
331
+ else:
332
+ return state_dict
333
+ # It's a 3d state dict, should not do inflation on both bias and weight.
334
+ if bias_name in state_dict:
335
+ bias_2d = state_dict[bias_name]
336
+ if bias_2d.dim() == 1: # Assuming the 2D biases are 1D tensors (out_channels,)
337
+ bias_3d = inflate_bias_fn(
338
+ bias_2d=bias_2d,
339
+ bias_3d=layer.bias,
340
+ shape_norm=layer.shape_norm,
341
+ name=prefix,
342
+ verbose=verbose,
343
+ inflation_mode=layer.inflation_mode,
344
+ )
345
+ state_dict[bias_name] = bias_3d
346
+ return state_dict
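A small self-contained sanity check of the `tile` inflation normalization described in `inflate_weight` above: tiling a 2D kernel across depth `d` and dividing by `d` (the `shape_norm` factor when channel counts are unchanged) keeps the per-frame response of a time-constant input identical to the original 2D convolution. Shapes are chosen purely for illustration:

import torch
import torch.nn.functional as F

out_c, in_c, k, d = 4, 3, 3, 3
w2d = torch.randn(out_c, in_c, k, k)
w3d = w2d.unsqueeze(2).repeat(1, 1, d, 1, 1) / d   # "tile" inflation with shape_norm

x2d = torch.randn(1, in_c, 8, 8)
x3d = x2d.unsqueeze(2).repeat(1, 1, d, 1, 1)       # the same frame repeated d times

y2d = F.conv2d(x2d, w2d, padding=1)
y3d = F.conv3d(x3d, w3d, padding=(0, 1, 1))
print(torch.allclose(y2d, y3d[:, :, 0], atol=1e-4))  # True: the inflated conv matches the 2D one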
adv_grpo/ocr.py ADDED
@@ -0,0 +1,138 @@
1
+ from paddleocr import PaddleOCR
2
+ import torch
3
+ import numpy as np
4
+ from Levenshtein import distance
5
+ from typing import List, Union, Tuple
6
+ from PIL import Image
7
+
8
+ class OcrScorer:
9
+ def __init__(self, use_gpu: bool = False):
10
+ """
11
+ OCR reward calculator
12
+ :param use_gpu: Whether to use GPU acceleration for PaddleOCR
13
+ """
14
+ self.ocr = PaddleOCR(
15
+ use_angle_cls=False,
16
+ lang="en",
17
+ use_gpu=use_gpu,
18
+ show_log=False # Disable unnecessary log output
19
+ )
20
+
21
+ @torch.no_grad()
22
+ def __call__(self,
23
+ images: Union[List[Image.Image], List[np.ndarray]],
24
+ prompts: List[str]) -> List[float]:
25
+ """
26
+ Calculate OCR reward
27
+ :param images: List of input images (PIL or numpy format)
28
+ :param prompts: Corresponding target text list
29
+ :return: List of per-image rewards in [0, 1]
30
+ """
31
+ # import pdb; pdb.set_trace()
32
+ prompts = [prompt.split('"')[1] for prompt in prompts]
33
+ rewards = []
34
+ # Ensure input lengths are consistent
35
+ assert len(images) == len(prompts), "Images and prompts must have the same length"
36
+ for img, prompt in zip(images, prompts):
37
+ # Convert image format
38
+ if isinstance(img, Image.Image):
39
+ img = np.array(img)
40
+
41
+ try:
42
+ # OCR recognition
43
+ result = self.ocr.ocr(img, cls=False)
44
+ # Extract recognized text (handle possible multi-line results)
45
+ recognized_text = ''.join([res[1][0] if res[1][1] > 0 else '' for res in result[0]]) if result[0] else ''
46
+
47
+ recognized_text = recognized_text.replace(' ', '').lower()
48
+ prompt = prompt.replace(' ', '').lower()
49
+ if prompt in recognized_text:
50
+ dist = 0
51
+ else:
52
+ dist = distance(recognized_text, prompt)
53
+ # import pdb; pdb.set_trace()
54
+ # If many unrelated characters were recognized, cap the penalty at the prompt length
55
+ if dist > len(prompt):
56
+ dist = len(prompt)
57
+
58
+ except Exception as e:
59
+ # Error handling (e.g., OCR parsing failure)
60
+ print(f"OCR processing failed: {str(e)}")
61
+ dist = len(prompt) # Maximum penalty
62
+ reward = 1 - dist / len(prompt)
63
+ rewards.append(reward)
64
+
65
+ return rewards
66
+
67
+ class OcrScorer_video_or_image:
68
+ def __init__(self, use_gpu: bool = False):
69
+ """
70
+ OCR reward calculator
71
+ :param use_gpu: Whether to use GPU acceleration for PaddleOCR
72
+ """
73
+ self.ocr = PaddleOCR(
74
+ use_angle_cls=False,
75
+ lang="en",
76
+ use_gpu=use_gpu,
77
+ show_log=False # Disable unnecessary log output
78
+ )
79
+ self.frame_interval = 4
80
+
81
+ @torch.no_grad()
82
+ def __call__(self, images: Union[List[Image.Image], List[np.ndarray]], prompts: List[str]) -> List[float]:
83
+ """
84
+ :param images: List of images or videos (each video as np.ndarray of shape [F, H, W, C])
85
+ :param prompts: List of prompts containing target text
86
+ :return: List of OCR rewards, averaged over the sampled frames
87
+ """
88
+ prompts = [prompt.split('"')[1] for prompt in prompts]
89
+ assert len(images) == len(prompts), "Mismatch between images and prompts."
90
+
91
+ rewards = []
92
+ for img, prompt in zip(images, prompts):
93
+ prompt = prompt.replace(' ', '').lower()
94
+ frame_rewards = []
95
+
96
+ # Handle video: shape (F, H, W, C)
97
+ if isinstance(img, np.ndarray) and img.ndim == 4:
98
+ sampled_frames = img[::self.frame_interval]
99
+ else:
100
+ sampled_frames = [img]
101
+
102
+ for frame in sampled_frames:
103
+ region = None
104
+ if isinstance(frame, Image.Image):
105
+ frame = np.array(frame)
106
+ try:
107
+ result = self.ocr.ocr(frame, cls=False)
108
+ text = ''.join([res[1][0] if res[1][1] > 0 else '' for res in result[0]]) if result[0] else ''
109
+ text = text.replace(' ', '').lower()
110
+
111
+ dist = distance(text, prompt)
112
+ dist = min(dist, len(prompt))
113
+
114
+ except Exception as e:
115
+ print(f"OCR failed on frame: {e}")
116
+ dist = len(prompt)
117
+
118
+ reward = 1 - dist / len(prompt)
119
+ if reward > 0:
120
+ frame_rewards.append(reward)
121
+
122
+ if frame_rewards:
123
+ rewards.append(sum(frame_rewards) / len(frame_rewards))
124
+ else:
125
+ rewards.append(0.0)
126
+
127
+ return rewards
128
+
129
+ if __name__ == "__main__":
130
+ example_image_path = "media_images_eval_images_499_ef42de47b8ec98892954.jpg"
131
+ example_image = Image.open(example_image_path)
132
+ example_prompt = 'New York Skyline with "Hello World" written with fireworks on the sky'
133
+ # Instantiate scorer
134
+ scorer = OcrScorer(use_gpu=False)
135
+
136
+ # Call scorer and print result
137
+ reward = scorer([example_image], [example_prompt])
138
+ print(f"OCR Reward: {reward}")
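A short, hedged illustration of the reward rule used by `OcrScorer` above, decoupled from PaddleOCR so it runs anywhere `Levenshtein` is installed (the strings are made up):

from Levenshtein import distance

def ocr_reward(recognized: str, target: str) -> float:
    recognized = recognized.replace(" ", "").lower()
    target = target.replace(" ", "").lower()
    dist = 0 if target in recognized else min(distance(recognized, target), len(target))
    return 1 - dist / len(target)

print(ocr_reward("Hello Wor1d", "Hello World"))                # 0.9: one wrong character out of ten
print(ocr_reward("completely unrelated text", "Hello World"))  # 0.0: the penalty is capped at len(target)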
adv_grpo/pick_score_training.py ADDED
@@ -0,0 +1,385 @@
1
+ import torch
2
+ from transformers import CLIPProcessor, CLIPModel
3
+ from PIL import Image
4
+ from torch.utils.data import DataLoader
5
+
6
+ # ====== Use the CLIPCriterion you found ======
7
+ from dataclasses import dataclass
8
+ from torch.nn.modules.loss import _Loss
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import os
11
+ import json
12
+ import torch
13
+ import torch.distributed as dist
+ import torch.distributed.nn  # needed by gather_features (torch.distributed.nn.all_gather)
14
+ from torch.utils.data import DataLoader, DistributedSampler
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+
17
+
18
+
19
+ def evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device, max_eval=100):
20
+ """
21
+ Simple evaluation: take the first max_eval Qwen vs SD3 pairs and compute the average score for each side
22
+ """
23
+ model.eval()
24
+ if hasattr(model, "module"): # DDP case
25
+ model = model.module
26
+
27
+ with open(json_file, "r") as f:
28
+ prompt2img = json.load(f)
29
+
30
+ prompts = list(prompt2img.keys())[:max_eval]
31
+
32
+ qwen_scores, sd3_scores = [], []
33
+
34
+ for prompt in prompts:
35
+ filename = prompt2img[prompt]
36
+ qwen_img_path = os.path.join(qwen_dir, filename)
37
+ sd3_img_path = os.path.join(sd3_dir, filename)
38
+
39
+ if not (os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path)):
40
+ continue
41
+
42
+ qwen_img = Image.open(qwen_img_path).convert("RGB")
43
+ sd3_img = Image.open(sd3_img_path).convert("RGB")
44
+
45
+ # Text & image inputs
46
+ text_inputs = processor.tokenizer(
47
+ prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77
48
+ ).to(device)
49
+ qwen_inputs = processor(images=qwen_img, return_tensors="pt").to(device)
50
+ sd3_inputs = processor(images=sd3_img, return_tensors="pt").to(device)
51
+
52
+ with torch.no_grad():
53
+ text_features = model.get_text_features(**text_inputs)
54
+ qwen_features = model.get_image_features(**qwen_inputs)
55
+ sd3_features = model.get_image_features(**sd3_inputs)
56
+
57
+ # 归一化
58
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
59
+ qwen_features = qwen_features / qwen_features.norm(dim=-1, keepdim=True)
60
+ sd3_features = sd3_features / sd3_features.norm(dim=-1, keepdim=True)
61
+
62
+ # 相似度分数
63
+ logit_scale = model.logit_scale.exp()
64
+ qwen_score = (logit_scale * (text_features @ qwen_features.T)).item()
65
+ sd3_score = (logit_scale * (text_features @ sd3_features.T)).item()
66
+
67
+ qwen_scores.append(qwen_score)
68
+ sd3_scores.append(sd3_score)
69
+
70
+ model.train()
71
+ if len(qwen_scores) > 0:
72
+ print(f"[Eval] Qwen avg={sum(qwen_scores)/len(qwen_scores):.4f} "
73
+ f"| SD3 avg={sum(sd3_scores)/len(sd3_scores):.4f}")
74
+
75
+
76
+ @dataclass
77
+ class CLIPCriterionConfig:
78
+ _target_: str = "trainer.criterions.clip_criterion.CLIPCriterion"
79
+ is_distributed: bool = False # turned off for local runs for now
80
+ label_0_column_name: str = "label_0"
81
+ label_1_column_name: str = "label_1"
82
+ input_ids_column_name: str = "input_ids"
83
+ pixels_0_column_name: str = "pixels_0"
84
+ pixels_1_column_name: str = "pixels_1"
85
+ num_examples_per_prompt_column_name: str = "num_examples_per_prompt"
86
+ in_batch_negatives: bool = False
87
+
88
+
89
+ class CLIPCriterion(_Loss):
90
+ def __init__(self, cfg: CLIPCriterionConfig):
91
+ super().__init__()
92
+ self.cfg = cfg
93
+
94
+ @staticmethod
95
+ def get_features(model, input_ids, pixels_0_values, pixels_1_values):
96
+ # import pdb; pdb.set_trace()
97
+ # if hasattr(model, "module"):
98
+ # model = model.module
99
+ all_pixel_values = torch.cat([pixels_0_values, pixels_1_values], dim=0)
100
+ # text_features, all_image_features = model(text_inputs=input_ids, image_inputs=all_pixel_values)
101
+ text_features = model.get_text_features(input_ids=input_ids)
102
+ all_image_features = model.get_image_features(pixel_values=all_pixel_values)
103
+ all_image_features = all_image_features / all_image_features.norm(dim=-1, keepdim=True)
104
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
105
+ image_0_features, image_1_features = all_image_features.chunk(2, dim=0)
106
+ return image_0_features, image_1_features, text_features
107
+
108
+ @staticmethod
109
+ def gather_features(features):
110
+ all_features = torch.cat(torch.distributed.nn.all_gather(features), dim=0)
111
+ return all_features
112
+
113
+ # def safe_sync(self, msg):
114
+ # torch.cuda.synchronize()
115
+ # print(f"[Rank {dist.get_rank()}] OK at {msg}")
116
+
117
+ def calc_loss(
118
+ self,
119
+ text_features,
120
+ image_0_features,
121
+ image_1_features,
122
+ logit_scale,
123
+ label_0,
124
+ label_1,
125
+ num_examples_per_prompt,
126
+ *args,
127
+ **kwargs
128
+ ):
129
+ # self.safe_sync("start")
130
+
131
+ device = image_0_features.device
132
+
133
+ # gather features
134
+ if self.cfg.is_distributed:
135
+ image_0_features = self.gather_features(image_0_features)
136
+ image_1_features = self.gather_features(image_1_features)
137
+ text_features = self.gather_features(text_features)
138
+ label_0 = self.gather_features(label_0)
139
+ label_1 = self.gather_features(label_1)
140
+ num_examples_per_prompt = self.gather_features(num_examples_per_prompt)
141
+
142
+ # calc logits # TODO use local loss as open-clip does
143
+ all_image_features = torch.cat([image_0_features, image_1_features], dim=0) # (2 * batch_size, dim)
144
+ logits_per_image = logit_scale * all_image_features @ text_features.T
145
+ image_0_logits, image_1_logits = logits_per_image.chunk(2, dim=0)
146
+ text_logits = logit_scale * text_features @ all_image_features.T
147
+
148
+ if self.cfg.in_batch_negatives:
149
+ # get labels
150
+ num_images = all_image_features.shape[0]
151
+ image_labels = torch.arange(num_images, device=device, dtype=torch.long)
152
+ image_0_labels, image_1_labels = image_labels.chunk(2, dim=0)
153
+ num_texts = text_features.shape[0]
154
+ text_labels = torch.arange(num_texts, device=device, dtype=torch.long)
155
+
156
+ # image loss - we want to increase the logits of the preferred image to the text
157
+ image_0_loss = torch.nn.functional.cross_entropy(image_0_logits, text_labels, reduction="none")
158
+ image_1_loss = torch.nn.functional.cross_entropy(image_1_logits, text_labels, reduction="none")
159
+ # if we have a tie, we will increase both images equally, and average so the image loss of each example is
160
+ # proportional
161
+ image_loss = label_0 * image_0_loss + label_1 * image_1_loss
162
+
163
+ # text loss - we want to increase the logits of the text to the preferred image
164
+ text_0_loss = torch.nn.functional.cross_entropy(text_logits, image_0_labels, reduction="none")
165
+ text_1_loss = torch.nn.functional.cross_entropy(text_logits, image_1_labels, reduction="none")
166
+
167
+ else:
168
+ text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
169
+ index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
170
+
171
+ text_0_logits = text_0_logits[index, index]
172
+ text_1_logits = text_1_logits[index, index]
173
+ text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
174
+ text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
175
+ text_1_labels = text_0_labels + 1
176
+ text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
177
+ text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
178
+
179
+ # if we have a tie we want the logits of for each image to be equal
180
+ text_loss = label_0 * text_0_loss + label_1 * text_1_loss
181
+ # we want the ideal loss to be 0, currently, if there is a tie, it is 0.5 * log(0.5) + 0.5 * log(0.5)
182
+ # so we add log(0.5) to the loss
183
+ is_tie = (label_0 == label_1).float()
184
+ is_tie *= torch.log(torch.tensor(0.5, device=device))
185
+ text_loss += is_tie
186
+
187
+ # we average the image and text loss
188
+ if self.cfg.in_batch_negatives:
189
+ loss = (image_loss + text_loss) / 2
190
+ else:
191
+ loss = text_loss
192
+ # import pdb; pdb.set_trace()
193
+
194
+ # some prompts have lots of interactions, we want weight them accordingly
195
+ # absolute_example_weight = 1 / num_examples_per_prompt
196
+ # denominator = absolute_example_weight.sum()
197
+ # weight_per_example = absolute_example_weight / denominator
198
+ # loss *= weight_per_example
199
+ loss = loss.mean()
200
+ # import pdb; pdb.set_trace()
201
+
202
+ # loss = loss.sum()
203
+ return loss
204
+
205
+ def forward(self, model, batch):
206
+ # import pdb; pdb.set_trace()
207
+ image_0_features, image_1_features, text_features = self.get_features(
208
+ model,
209
+ batch[self.cfg.input_ids_column_name],
210
+ batch[self.cfg.pixels_0_column_name],
211
+ batch[self.cfg.pixels_1_column_name]
212
+ )
213
+ # print("text_features:", text_features.shape)
214
+
215
+ loss = self.calc_loss(
216
+ text_features,
217
+ image_0_features,
218
+ image_1_features,
219
+ model.logit_scale.exp(),
220
+ batch[self.cfg.label_0_column_name],
221
+ batch[self.cfg.label_1_column_name],
222
+ batch[self.cfg.num_examples_per_prompt_column_name],
223
+ )
224
+ return loss
225
+
226
+
227
+ # ====== Data preparation ======
228
+ class QwenSD3JsonDataset(Dataset):
229
+ def __init__(self, processor, json_file, qwen_dir, sd3_dir):
230
+ """
231
+ json_file: prompt2img.json {prompt: filename}
232
+ qwen_dir: directory containing the Qwen images
233
+ sd3_dir: directory containing the SD3 images
234
+ """
235
+ self.processor = processor
236
+
237
+ with open(json_file, "r") as f:
238
+ self.prompt2img = json.load(f)
239
+
240
+ self.prompts = list(self.prompt2img.keys())
241
+ self.qwen_dir = qwen_dir
242
+ self.sd3_dir = sd3_dir
243
+
244
+ def __len__(self):
245
+ return len(self.prompts)
246
+
247
+ def __getitem__(self, idx):
248
+ prompt = self.prompts[idx]
249
+ filename = self.prompt2img[prompt]
250
+
251
+ qwen_img_path = os.path.join(self.qwen_dir, filename)
252
+ sd3_img_path = os.path.join(self.sd3_dir, filename)
253
+
254
+ if os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path):
255
+ qwen_img = Image.open(qwen_img_path).convert("RGB")
256
+ sd3_img = Image.open(sd3_img_path).convert("RGB")
257
+ else:
258
+ qwen_img = Image.open(sd3_img_path).convert("RGB")
259
+ sd3_img = Image.open(sd3_img_path).convert("RGB")
260
+
261
+ # Text tokens
262
+ text_inputs = self.processor.tokenizer(
263
+ prompt,
264
+ padding="max_length",
265
+ truncation=True,
266
+ max_length=77,
267
+ return_tensors="pt"
268
+ )
269
+ input_ids = text_inputs["input_ids"].squeeze(0)
270
+
271
+ # Image preprocessing
272
+ pixels_0 = self.processor(images=qwen_img, return_tensors="pt")["pixel_values"].squeeze(0)
273
+ pixels_1 = self.processor(images=sd3_img, return_tensors="pt")["pixel_values"].squeeze(0)
274
+
275
+ return {
276
+ "input_ids": input_ids,
277
+ "pixels_0": pixels_0, # positive sample (Qwen)
278
+ "pixels_1": pixels_1, # negative sample (SD3)
279
+ "label_0": torch.tensor(1.0),
280
+ "label_1": torch.tensor(0.0),
281
+ "num_examples_per_prompt": torch.tensor(1.0)
282
+ }
283
+
284
+
285
+ # ====== Training loop ======
286
+ # def finetune_pickscore(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6, device="cuda"):
287
+ # processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
288
+ # model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
289
+
290
+ # dataset = QwenSD3JsonDataset(processor,json_file, qwen_dir, sd3_dir)
291
+ # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
292
+
293
+ # criterion = CLIPCriterion(CLIPCriterionConfig())
294
+ # optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
295
+ # # import pdb; pdb.set_trace()
296
+
297
+ # model.train()
298
+ # for epoch in range(epochs):
299
+ # total_loss = 0.0
300
+ # for batch in dataloader:
301
+ # batch = {k: v.to(device) for k, v in batch.items()}
302
+ # loss = criterion(model, batch)
303
+
304
+ # optimizer.zero_grad()
305
+ # loss.backward()
306
+ # optimizer.step()
307
+
308
+ # total_loss += loss.item()
309
+ # print(f"Epoch {epoch} | Loss {total_loss/len(dataloader):.4f}")
310
+
311
+ # model.save_pretrained("pickscore_qwen_finetuned")
312
+ # return model
313
+
314
+ def finetune_pickscore_distributed(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6):
315
+ # 1. Initialize the distributed process group
316
+ dist.init_process_group(backend="nccl")
317
+ local_rank = int(os.environ["LOCAL_RANK"])
318
+ torch.cuda.set_device(local_rank)
319
+ device = torch.device("cuda", local_rank)
320
+
321
+ # 2. Prepare the data
322
+ processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
323
+ dataset = QwenSD3JsonDataset(processor, json_file, qwen_dir, sd3_dir)
324
+ sampler = DistributedSampler(dataset)
325
+ dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
326
+
327
+ # 3. Model + DDP
328
+ model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
329
+ model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
330
+
331
+ criterion = CLIPCriterion(CLIPCriterionConfig())
332
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
333
+
334
+ # 4. Train
335
+ model.train()
336
+ if dist.get_rank() == 0:
337
+ evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)
338
+ for epoch in range(epochs):
339
+ sampler.set_epoch(epoch) # keep the shuffle consistent across ranks for each epoch
340
+ total_loss = 0.0
341
+
342
+ for step, batch in enumerate(dataloader):
343
+ batch = {k: v.to(device) for k, v in batch.items()}
344
+ loss = criterion(model.module, batch)
345
+
346
+ optimizer.zero_grad()
347
+ loss.backward()
348
+ optimizer.step()
349
+
350
+ # accumulate the loss (locally first)
351
+ total_loss += loss.item()
352
+
353
+ # print every few steps (rank 0 only)
354
+ if step % 50 == 0: # can be changed to 10 or 100
355
+ # all_reduce averages the loss across all GPUs
356
+ avg_loss = torch.tensor(loss.item(), device=device)
357
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
358
+ if dist.get_rank() == 0:
359
+ print(f"[Epoch {epoch} | Step {step}/{len(dataloader)}] "
360
+ f"local_loss={loss.item():.4f} | avg_loss={avg_loss.item():.4f}")
361
+
362
+ # print the average loss for each epoch
363
+ epoch_loss = torch.tensor(total_loss / len(dataloader), device=device)
364
+ dist.all_reduce(epoch_loss, op=dist.ReduceOp.AVG)
365
+ if dist.get_rank() == 0:
366
+ print(f"===> Epoch {epoch} done | avg_epoch_loss={epoch_loss.item():.4f}")
367
+ evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)
368
+
369
+ # 5. Save the model (rank 0 only)
370
+ if dist.get_rank() == 0:
371
+ model.module.save_pretrained("pickscore_qwen_finetuned")
372
+
373
+ dist.destroy_process_group()
374
+
375
+
376
+ # ====== Usage example ======
377
+ if __name__ == "__main__":
378
+ finetune_pickscore_distributed(
379
+ json_file="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images/prompt2img.json",
380
+ qwen_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images",
381
+ sd3_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images",
382
+ epochs=2,
383
+ batch_size=4,
384
+ lr=1e-6,
385
+ )
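A minimal, self-contained sketch of the pairwise branch of `calc_loss` above (`in_batch_negatives=False`), using toy logits instead of CLIP features to show how the tie correction drives the ideal loss to zero:

import torch
import torch.nn.functional as F

# toy text-to-image logits for two prompts: a clear preference for image_0, then a tie
text_logits = torch.tensor([[3.0, 1.0], [2.0, 2.0]])
label_0 = torch.tensor([1.0, 0.5])
label_1 = torch.tensor([0.0, 0.5])

loss_0 = F.cross_entropy(text_logits, torch.zeros(2, dtype=torch.long), reduction="none")
loss_1 = F.cross_entropy(text_logits, torch.ones(2, dtype=torch.long), reduction="none")
loss = label_0 * loss_0 + label_1 * loss_1
loss += (label_0 == label_1).float() * torch.log(torch.tensor(0.5))  # tie correction
print(loss)  # ~[0.13, 0.00]: the tied pair contributes no loss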
adv_grpo/pickscore_scorer.py ADDED
@@ -0,0 +1,70 @@
1
+ from transformers import CLIPProcessor, CLIPModel
2
+ from PIL import Image
3
+ import torch
4
+
5
+ class PickScoreScorer(torch.nn.Module):
6
+ def __init__(self, device="cuda", dtype=torch.float32):
7
+ super().__init__()
8
+ processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
9
+ model_path = "yuvalkirstain/PickScore_v1"
10
+ self.device = device
11
+ self.dtype = dtype
12
+ self.processor = CLIPProcessor.from_pretrained(processor_path)
13
+ self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
14
+ self.model = self.model.to(dtype=dtype)
15
+
16
+ @torch.no_grad()
17
+ def __call__(self, prompt, images):
18
+ # Preprocess images
19
+ if hasattr(self.model, "module"):
20
+ self.model = self.model.module
21
+ image_inputs = self.processor(
22
+ images=images,
23
+ padding=True,
24
+ truncation=True,
25
+ max_length=77,
26
+ return_tensors="pt",
27
+ )
28
+ image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
29
+ # Preprocess text
30
+ text_inputs = self.processor(
31
+ text=prompt,
32
+ padding=True,
33
+ truncation=True,
34
+ max_length=77,
35
+ return_tensors="pt",
36
+ )
37
+ text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}
38
+
39
+ # Get embeddings
40
+ image_embs = self.model.get_image_features(**image_inputs)
41
+ image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)
42
+
43
+ text_embs = self.model.get_text_features(**text_inputs)
44
+ text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)
45
+
46
+ # Calculate scores
47
+ logit_scale = self.model.logit_scale.exp()
48
+ scores = logit_scale * (text_embs @ image_embs.T)
49
+ scores = scores.diag()
50
+ # norm to 0-1
51
+ scores = scores/26
52
+ return scores
53
+
54
+ # Usage example
55
+ def main():
56
+ scorer = PickScoreScorer(
57
+ device="cuda",
58
+ dtype=torch.float32
59
+ )
60
+ images=[
61
+ "nasa.jpg",
62
+ ]
63
+ pil_images = [Image.open(img) for img in images]
64
+ prompts=[
65
+ 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
66
+ ]
67
+ print(scorer(prompts, pil_images))
68
+
69
+ if __name__ == "__main__":
70
+ main()
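A small, hedged illustration of the `scores.diag()` step in `PickScoreScorer.__call__` above: with N prompts and N images, the diagonal of the N x N similarity matrix pairs prompt i with image i (toy embeddings stand in for real CLIP features):

import torch
import torch.nn.functional as F

text_embs = F.normalize(torch.randn(3, 8), dim=-1)
image_embs = F.normalize(torch.randn(3, 8), dim=-1)
logit_scale = torch.tensor(100.0)                   # stand-in for model.logit_scale.exp()
scores = logit_scale * (text_embs @ image_embs.T)   # (3, 3): every prompt against every image
paired = scores.diag()                              # prompt i scored against image i only
print(paired.shape)                                 # torch.Size([3])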
adv_grpo/pickscore_scorer_constractive.py ADDED
@@ -0,0 +1,89 @@
1
+ from transformers import CLIPProcessor, CLIPModel
2
+ from PIL import Image
3
+ import torch
4
+
5
+ class PickScoreScorerConstractive(torch.nn.Module):
6
+ def __init__(self, device="cuda", dtype=torch.float32):
7
+ super().__init__()
8
+ processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
9
+ model_path = "yuvalkirstain/PickScore_v1"
10
+ self.device = device
11
+ self.dtype = dtype
12
+ self.processor = CLIPProcessor.from_pretrained(processor_path)
13
+ self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
14
+ self.model = self.model.to(dtype=dtype)
15
+
16
+ @torch.no_grad()
17
+ def __call__(self, prompt, images, ref_images):
18
+ # Preprocess images
19
+ image_inputs = self.processor(
20
+ images=images,
21
+ padding=True,
22
+ truncation=True,
23
+ max_length=77,
24
+ return_tensors="pt",
25
+ )
26
+ image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
27
+
28
+ ref_image_inputs = self.processor(
29
+ images=ref_images,
30
+ padding=True,
31
+ truncation=True,
32
+ max_length=77,
33
+ return_tensors="pt",
34
+ )
35
+ ref_image_inputs = {k: v.to(device=self.device) for k, v in ref_image_inputs.items()}
36
+
37
+
38
+
39
+ # Preprocess text
40
+ text_inputs = self.processor(
41
+ text=prompt,
42
+ padding=True,
43
+ truncation=True,
44
+ max_length=77,
45
+ return_tensors="pt",
46
+ )
47
+ text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}
48
+
49
+ # Get embeddings
50
+ image_embs = self.model.get_image_features(**image_inputs)
51
+ image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)
52
+
53
+ ref_image_embs = self.model.get_image_features(**ref_image_inputs)
54
+ ref_image_embs = ref_image_embs / ref_image_embs.norm(p=2, dim=-1, keepdim=True)
55
+
56
+ text_embs = self.model.get_text_features(**text_inputs)
57
+ text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)
58
+
59
+ # Calculate scores
60
+ logit_scale = self.model.logit_scale.exp()
61
+ scores = logit_scale * (text_embs @ image_embs.T)
62
+ scores = scores.diag()
63
+ # norm to 0-1
64
+ scores = scores/26
65
+
66
+ ref_scores = logit_scale * (text_embs @ ref_image_embs.T)
67
+ ref_scores = ref_scores.diag()
68
+ ref_scores = ref_scores/26
69
+
70
+
71
+ return scores, ref_scores, image_embs, ref_image_embs
72
+
73
+ # Usage example
74
+ def main():
75
+ scorer = PickScoreScorerConstractive(
76
+ device="cuda",
77
+ dtype=torch.float32
78
+ )
79
+ images=[
80
+ "nasa.jpg",
81
+ ]
82
+ pil_images = [Image.open(img) for img in images]
83
+ prompts=[
84
+ 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
85
+ ]
86
+ print(scorer(prompts, pil_images, pil_images)) # the contrastive scorer also needs reference images; reuse the same ones here
87
+
88
+ if __name__ == "__main__":
89
+ main()
adv_grpo/pickscore_scorer_patch.py ADDED
@@ -0,0 +1,78 @@
1
+ from transformers import CLIPProcessor, CLIPModel
2
+ from PIL import Image
3
+ import torch
4
+
5
+ class PickScoreScorer(torch.nn.Module):
6
+ def __init__(self, device="cuda", dtype=torch.float32):
7
+ super().__init__()
8
+ processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
9
+ model_path = "yuvalkirstain/PickScore_v1"
10
+ self.device = device
11
+ self.dtype = dtype
12
+ self.processor = CLIPProcessor.from_pretrained(processor_path)
13
+ self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
14
+ self.model = self.model.to(dtype=dtype)
15
+
16
+ @torch.no_grad()
17
+ def __call__(self, prompt, images):
18
+ # Preprocess images
19
+ if hasattr(self.model, "module"):
20
+ self.model = self.model.module
21
+ image_inputs = self.processor(
22
+ images=images,
23
+ padding=True,
24
+ truncation=True,
25
+ max_length=77,
26
+ return_tensors="pt",
27
+ )
28
+ image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
29
+ # Preprocess text
30
+ text_inputs = self.processor(
31
+ text=prompt,
32
+ padding=True,
33
+ truncation=True,
34
+ max_length=77,
35
+ return_tensors="pt",
36
+ )
37
+ text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}
38
+
39
+ # Get embeddings
40
+ # image_embs = self.model.get_image_features(**image_inputs)
41
+ # import pdb; pdb.set_trace()
42
+ image_embs = self.model.vision_model(image_inputs["pixel_values"],output_hidden_states=True)
43
+ image_embs = image_embs.last_hidden_state
44
+
45
+ image_embs = self.model.visual_projection(image_embs) # [B, N, 1024]
46
+ image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)
47
+
48
+ text_embs = self.model.get_text_features(**text_inputs)
49
+ text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)
50
+
51
+ # Calculate scores
52
+ logit_scale = self.model.logit_scale.exp()
53
+ # scores = logit_scale * (text_embs @ image_embs.T)
54
+ patch_scores = torch.einsum("bd,bnd->bn", text_embs, image_embs) # [B, N]
55
+ scores = logit_scale * patch_scores.mean(dim=1) # average over all patches
56
+ # scores = scores.diag()
57
+ # norm to 0-1
58
+ scores = scores/26
59
+ # import pdb; pdb.set_trace()
60
+ return scores
61
+
62
+ # Usage example
63
+ def main():
64
+ scorer = PickScoreScorer(
65
+ device="cuda",
66
+ dtype=torch.float32
67
+ )
68
+ images=[
69
+ "nasa.jpg",
70
+ ]
71
+ pil_images = [Image.open(img) for img in images]
72
+ prompts=[
73
+ 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
74
+ ]
75
+ print(scorer(prompts, pil_images))
76
+
77
+ if __name__ == "__main__":
78
+ main()