diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c181295ffbb693822919fab3b78e7fa27fcb6405 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+adv_grpo/assets/flow_grpo_fast.png filter=lfs diff=lfs merge=lfs -text
diff --git a/adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc b/adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bd14ec63f59adaf6c2a71cc495f56561d661c24
Binary files /dev/null and b/adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/discriminator.cpython-310.pyc b/adv_grpo/__pycache__/discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..057eed8804426cead5d24583ac1478edb69eecf7
Binary files /dev/null and b/adv_grpo/__pycache__/discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/ema.cpython-310.pyc b/adv_grpo/__pycache__/ema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a67a9a31ba71037e9c61adbb849f49e3a7f3c75a
Binary files /dev/null and b/adv_grpo/__pycache__/ema.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc b/adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..113135015d4be509379b6f852fa5824bfe4788c3
Binary files /dev/null and b/adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/inflated_layers.cpython-310.pyc b/adv_grpo/__pycache__/inflated_layers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3af95ea15e99987014cd96f50a7fc1cf3470770
Binary files /dev/null and b/adv_grpo/__pycache__/inflated_layers.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/inflated_lib.cpython-310.pyc b/adv_grpo/__pycache__/inflated_lib.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68a7a54674316b88d1c1c1bdd19efa849f255a3a
Binary files /dev/null and b/adv_grpo/__pycache__/inflated_lib.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/ocr.cpython-310.pyc b/adv_grpo/__pycache__/ocr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6a42af366cef7b37766039eac170c00c23088b5
Binary files /dev/null and b/adv_grpo/__pycache__/ocr.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc b/adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfb4ede079452fc65682fb9e675bb5f4f600e60f
Binary files /dev/null and b/adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/pick_score_training.cpython-310.pyc b/adv_grpo/__pycache__/pick_score_training.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63a42d7a8696445e14f5bcd4b110194f7b330a24
Binary files /dev/null and b/adv_grpo/__pycache__/pick_score_training.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc b/adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d754ee400786b3fb7965c84727681f534026259f
Binary files /dev/null and b/adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/prompts.cpython-310.pyc b/adv_grpo/__pycache__/prompts.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3797ed3b5ae1faf3ec861941368269f13477898e
Binary files /dev/null and b/adv_grpo/__pycache__/prompts.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/rewards.cpython-310.pyc b/adv_grpo/__pycache__/rewards.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..401e17e811fcb9a6c0373c6240848832adef7e02
Binary files /dev/null and b/adv_grpo/__pycache__/rewards.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/stat_tracking.cpython-310.pyc b/adv_grpo/__pycache__/stat_tracking.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3846c5741ce33be3af8421c5c31327110fa5d77
Binary files /dev/null and b/adv_grpo/__pycache__/stat_tracking.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc b/adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b05c6efd0e877da5d4cca2b63ff9221cce32817
Binary files /dev/null and b/adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/aesthetic_scorer.py b/adv_grpo/aesthetic_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff364c45848fa4a8e9e121b6cae755b90056b64e
--- /dev/null
+++ b/adv_grpo/aesthetic_scorer.py
@@ -0,0 +1,53 @@
+# Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py
+
+from importlib import resources
+import torch
+import torch.nn as nn
+import numpy as np
+from transformers import CLIPModel, CLIPProcessor
+from PIL import Image
+
+ASSETS_PATH = resources.files("adv_grpo.assets")
+
+
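+# Aesthetic-score head: a small MLP that maps a 768-d CLIP ViT-L/14 image embedding to a scalar score.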
+class MLP(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(768, 1024),
+ nn.Dropout(0.2),
+ nn.Linear(1024, 128),
+ nn.Dropout(0.2),
+ nn.Linear(128, 64),
+ nn.Dropout(0.1),
+ nn.Linear(64, 16),
+ nn.Linear(16, 1),
+ )
+
+ @torch.no_grad()
+ def forward(self, embed):
+ return self.layers(embed)
+
+
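+# Wraps CLIP ViT-L/14 together with the improved-aesthetic-predictor MLP weights
+# (sac+logos+ava1-l14-linearMSE) bundled in adv_grpo/assets.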
+class AestheticScorer(torch.nn.Module):
+ def __init__(self, dtype):
+ super().__init__()
+ self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ self.mlp = MLP()
+ state_dict = torch.load(
+ ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")
+ )
+ self.mlp.load_state_dict(state_dict)
+ self.dtype = dtype
+ self.eval()
+
+ @torch.no_grad()
+ def __call__(self, images):
+ device = next(self.parameters()).device
+ inputs = self.processor(images=images, return_tensors="pt")
+ inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
+ embed = self.clip.get_image_features(**inputs)
+ # normalize embedding
+ embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
+ return self.mlp(embed).squeeze(1)
diff --git a/adv_grpo/assets/activities.txt b/adv_grpo/assets/activities.txt
new file mode 100644
index 0000000000000000000000000000000000000000..abea0458a5836b50ec85da2f732ff3a7d63b8c3a
--- /dev/null
+++ b/adv_grpo/assets/activities.txt
@@ -0,0 +1,3 @@
+washing the dishes
+riding a bike
+playing chess
\ No newline at end of file
diff --git a/adv_grpo/assets/activities_v0.txt b/adv_grpo/assets/activities_v0.txt
new file mode 100644
index 0000000000000000000000000000000000000000..abea0458a5836b50ec85da2f732ff3a7d63b8c3a
--- /dev/null
+++ b/adv_grpo/assets/activities_v0.txt
@@ -0,0 +1,3 @@
+washing the dishes
+riding a bike
+playing chess
\ No newline at end of file
diff --git a/adv_grpo/assets/flow_grpo_fast.png b/adv_grpo/assets/flow_grpo_fast.png
new file mode 100644
index 0000000000000000000000000000000000000000..0de32b018bd203d0f7bb735017796cc25989ab0e
--- /dev/null
+++ b/adv_grpo/assets/flow_grpo_fast.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35709d674818e29d39728e036479e51b8e015bcc7caf4dc54eb3e4f41cc05ab1
+size 222022
diff --git a/adv_grpo/assets/imagenet_classes.txt b/adv_grpo/assets/imagenet_classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..722c984560c36a6014113deb8106a6cf50050ad5
--- /dev/null
+++ b/adv_grpo/assets/imagenet_classes.txt
@@ -0,0 +1,1000 @@
+tench, Tinca tinca
+goldfish, Carassius auratus
+great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
+tiger shark, Galeocerdo cuvieri
+hammerhead, hammerhead shark
+electric ray, crampfish, numbfish, torpedo
+stingray
+cock
+hen
+ostrich, Struthio camelus
+brambling, Fringilla montifringilla
+goldfinch, Carduelis carduelis
+house finch, linnet, Carpodacus mexicanus
+junco, snowbird
+indigo bunting, indigo finch, indigo bird, Passerina cyanea
+robin, American robin, Turdus migratorius
+bulbul
+jay
+magpie
+chickadee
+water ouzel, dipper
+kite
+bald eagle, American eagle, Haliaeetus leucocephalus
+vulture
+great grey owl, great gray owl, Strix nebulosa
+European fire salamander, Salamandra salamandra
+common newt, Triturus vulgaris
+eft
+spotted salamander, Ambystoma maculatum
+axolotl, mud puppy, Ambystoma mexicanum
+bullfrog, Rana catesbeiana
+tree frog, tree-frog
+tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
+loggerhead, loggerhead turtle, Caretta caretta
+leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
+mud turtle
+terrapin
+box turtle, box tortoise
+banded gecko
+common iguana, iguana, Iguana iguana
+American chameleon, anole, Anolis carolinensis
+whiptail, whiptail lizard
+agama
+frilled lizard, Chlamydosaurus kingi
+alligator lizard
+Gila monster, Heloderma suspectum
+green lizard, Lacerta viridis
+African chameleon, Chamaeleo chamaeleon
+Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
+African crocodile, Nile crocodile, Crocodylus niloticus
+American alligator, Alligator mississipiensis
+triceratops
+thunder snake, worm snake, Carphophis amoenus
+ringneck snake, ring-necked snake, ring snake
+hognose snake, puff adder, sand viper
+green snake, grass snake
+king snake, kingsnake
+garter snake, grass snake
+water snake
+vine snake
+night snake, Hypsiglena torquata
+boa constrictor, Constrictor constrictor
+rock python, rock snake, Python sebae
+Indian cobra, Naja naja
+green mamba
+sea snake
+horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
+diamondback, diamondback rattlesnake, Crotalus adamanteus
+sidewinder, horned rattlesnake, Crotalus cerastes
+trilobite
+harvestman, daddy longlegs, Phalangium opilio
+scorpion
+black and gold garden spider, Argiope aurantia
+barn spider, Araneus cavaticus
+garden spider, Aranea diademata
+black widow, Latrodectus mactans
+tarantula
+wolf spider, hunting spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse, partridge, Bonasa umbellus
+prairie chicken, prairie grouse, prairie fowl
+peacock
+quail
+partridge
+African grey, African gray, Psittacus erithacus
+macaw
+sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser, Mergus serrator
+goose
+black swan, Cygnus atratus
+tusker
+echidna, spiny anteater, anteater
+platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
+wallaby, brush kangaroo
+koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
+wombat
+jellyfish
+sea anemone, anemone
+brain coral
+flatworm, platyhelminth
+nematode, nematode worm, roundworm
+conch
+snail
+slug
+sea slug, nudibranch
+chiton, coat-of-mail shell, sea cradle, polyplacophore
+chambered nautilus, pearly nautilus, nautilus
+Dungeness crab, Cancer magister
+rock crab, Cancer irroratus
+fiddler crab
+king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
+American lobster, Northern lobster, Maine lobster, Homarus americanus
+spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
+crayfish, crawfish, crawdad, crawdaddy
+hermit crab
+isopod
+white stork, Ciconia ciconia
+black stork, Ciconia nigra
+spoonbill
+flamingo
+little blue heron, Egretta caerulea
+American egret, great white heron, Egretta albus
+bittern
+crane
+limpkin, Aramus pictus
+European gallinule, Porphyrio porphyrio
+American coot, marsh hen, mud hen, water hen, Fulica americana
+bustard
+ruddy turnstone, Arenaria interpres
+red-backed sandpiper, dunlin, Erolia alpina
+redshank, Tringa totanus
+dowitcher
+oystercatcher, oyster catcher
+pelican
+king penguin, Aptenodytes patagonica
+albatross, mollymawk
+grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
+killer whale, killer, orca, grampus, sea wolf, Orcinus orca
+dugong, Dugong dugon
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog, Maltese terrier, Maltese
+Pekinese, Pekingese, Peke
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound, Afghan
+basset, basset hound
+beagle
+bloodhound, sleuthhound
+bluetick
+black-and-tan coonhound
+Walker hound, Walker foxhound
+English foxhound
+redbone
+borzoi, Russian wolfhound
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound, Ibizan Podenco
+Norwegian elkhound, elkhound
+otterhound, otter hound
+Saluki, gazelle hound
+Scottish deerhound, deerhound
+Weimaraner
+Staffordshire bullterrier, Staffordshire bull terrier
+American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier, Sealyham
+Airedale, Airedale terrier
+cairn, cairn terrier
+Australian terrier
+Dandie Dinmont, Dandie Dinmont terrier
+Boston bull, Boston terrier
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier, Scottish terrier, Scottie
+Tibetan terrier, chrysanthemum dog
+silky terrier, Sydney silky
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa, Lhasa apso
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla, Hungarian pointer
+English setter
+Irish setter, red setter
+Gordon setter
+Brittany spaniel
+clumber, clumber spaniel
+English springer, English springer spaniel
+Welsh springer spaniel
+cocker spaniel, English cocker spaniel, cocker
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog, bobtail
+Shetland sheepdog, Shetland sheep dog, Shetland
+collie
+Border collie
+Bouvier des Flandres, Bouviers des Flandres
+Rottweiler
+German shepherd, German shepherd dog, German police dog, alsatian
+Doberman, Doberman pinscher
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard, St Bernard
+Eskimo dog, husky
+malamute, malemute, Alaskan malamute
+Siberian husky
+dalmatian, coach dog, carriage dog
+affenpinscher, monkey pinscher, monkey dog
+basenji
+pug, pug-dog
+Leonberg
+Newfoundland, Newfoundland dog
+Great Pyrenees
+Samoyed, Samoyede
+Pomeranian
+chow, chow chow
+keeshond
+Brabancon griffon
+Pembroke, Pembroke Welsh corgi
+Cardigan, Cardigan Welsh corgi
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf, grey wolf, gray wolf, Canis lupus
+white wolf, Arctic wolf, Canis lupus tundrarum
+red wolf, maned wolf, Canis rufus, Canis niger
+coyote, prairie wolf, brush wolf, Canis latrans
+dingo, warrigal, warragal, Canis dingo
+dhole, Cuon alpinus
+African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
+hyena, hyaena
+red fox, Vulpes vulpes
+kit fox, Vulpes macrotis
+Arctic fox, white fox, Alopex lagopus
+grey fox, gray fox, Urocyon cinereoargenteus
+tabby, tabby cat
+tiger cat
+Persian cat
+Siamese cat, Siamese
+Egyptian cat
+cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
+lynx, catamount
+leopard, Panthera pardus
+snow leopard, ounce, Panthera uncia
+jaguar, panther, Panthera onca, Felis onca
+lion, king of beasts, Panthera leo
+tiger, Panthera tigris
+cheetah, chetah, Acinonyx jubatus
+brown bear, bruin, Ursus arctos
+American black bear, black bear, Ursus americanus, Euarctos americanus
+ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
+sloth bear, Melursus ursinus, Ursus ursinus
+mongoose
+meerkat, mierkat
+tiger beetle
+ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
+ground beetle, carabid beetle
+long-horned beetle, longicorn, longicorn beetle
+leaf beetle, chrysomelid
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant, emmet, pismire
+grasshopper, hopper
+cricket
+walking stick, walkingstick, stick insect
+cockroach, roach
+mantis, mantid
+cicada, cicala
+leafhopper
+lacewing, lacewing fly
+dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
+damselfly
+admiral
+ringlet, ringlet butterfly
+monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
+cabbage butterfly
+sulphur butterfly, sulfur butterfly
+lycaenid, lycaenid butterfly
+starfish, sea star
+sea urchin
+sea cucumber, holothurian
+wood rabbit, cottontail, cottontail rabbit
+hare
+Angora, Angora rabbit
+hamster
+porcupine, hedgehog
+fox squirrel, eastern fox squirrel, Sciurus niger
+marmot
+beaver
+guinea pig, Cavia cobaya
+sorrel
+zebra
+hog, pig, grunter, squealer, Sus scrofa
+wild boar, boar, Sus scrofa
+warthog
+hippopotamus, hippo, river horse, Hippopotamus amphibius
+ox
+water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
+bison
+ram, tup
+bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
+ibex, Capra ibex
+hartebeest
+impala, Aepyceros melampus
+gazelle
+Arabian camel, dromedary, Camelus dromedarius
+llama
+weasel
+mink
+polecat, fitch, foulmart, foumart, Mustela putorius
+black-footed ferret, ferret, Mustela nigripes
+otter
+skunk, polecat, wood pussy
+badger
+armadillo
+three-toed sloth, ai, Bradypus tridactylus
+orangutan, orang, orangutang, Pongo pygmaeus
+gorilla, Gorilla gorilla
+chimpanzee, chimp, Pan troglodytes
+gibbon, Hylobates lar
+siamang, Hylobates syndactylus, Symphalangus syndactylus
+guenon, guenon monkey
+patas, hussar monkey, Erythrocebus patas
+baboon
+macaque
+langur
+colobus, colobus monkey
+proboscis monkey, Nasalis larvatus
+marmoset
+capuchin, ringtail, Cebus capucinus
+howler monkey, howler
+titi, titi monkey
+spider monkey, Ateles geoffroyi
+squirrel monkey, Saimiri sciureus
+Madagascar cat, ring-tailed lemur, Lemur catta
+indri, indris, Indri indri, Indri brevicaudatus
+Indian elephant, Elephas maximus
+African elephant, Loxodonta africana
+lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
+giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
+barracouta, snoek
+eel
+coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
+rock beauty, Holocanthus tricolor
+anemone fish
+sturgeon
+gar, garfish, garpike, billfish, Lepisosteus osseus
+lionfish
+puffer, pufferfish, blowfish, globefish
+abacus
+abaya
+academic gown, academic robe, judge's robe
+accordion, piano accordion, squeeze box
+acoustic guitar
+aircraft carrier, carrier, flattop, attack aircraft carrier
+airliner
+airship, dirigible
+altar
+ambulance
+amphibian, amphibious vehicle
+analog clock
+apiary, bee house
+apron
+ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
+assault rifle, assault gun
+backpack, back pack, knapsack, packsack, rucksack, haversack
+bakery, bakeshop, bakehouse
+balance beam, beam
+balloon
+ballpoint, ballpoint pen, ballpen, Biro
+Band Aid
+banjo
+bannister, banister, balustrade, balusters, handrail
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel, cask
+barrow, garden cart, lawn cart, wheelbarrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap, swimming cap
+bath towel
+bathtub, bathing tub, bath, tub
+beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
+beacon, lighthouse, beacon light, pharos
+beaker
+bearskin, busby, shako
+beer bottle
+beer glass
+bell cote, bell cot
+bib
+bicycle-built-for-two, tandem bicycle, tandem
+bikini, two-piece
+binder, ring-binder
+binoculars, field glasses, opera glasses
+birdhouse
+boathouse
+bobsled, bobsleigh, bob
+bolo tie, bolo, bola tie, bola
+bonnet, poke bonnet
+bookcase
+bookshop, bookstore, bookstall
+bottlecap
+bow
+bow tie, bow-tie, bowtie
+brass, memorial tablet, plaque
+brassiere, bra, bandeau
+breakwater, groin, groyne, mole, bulwark, seawall, jetty
+breastplate, aegis, egis
+broom
+bucket, pail
+buckle
+bulletproof vest
+bullet train, bullet
+butcher shop, meat market
+cab, hack, taxi, taxicab
+caldron, cauldron
+candle, taper, wax light
+cannon
+canoe
+can opener, tin opener
+cardigan
+car mirror
+carousel, carrousel, merry-go-round, roundabout, whirligig
+carpenter's kit, tool kit
+carton
+car wheel
+cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello, violoncello
+cellular telephone, cellular phone, cellphone, cell, mobile phone
+chain
+chainlink fence
+chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
+chain saw, chainsaw
+chest
+chiffonier, commode
+chime, bell, gong
+china cabinet, china closet
+Christmas stocking
+church, church building
+cinema, movie theater, movie theatre, movie house, picture palace
+cleaver, meat cleaver, chopper
+cliff dwelling
+cloak
+clog, geta, patten, sabot
+cocktail shaker
+coffee mug
+coffeepot
+coil, spiral, volute, whorl, helix
+combination lock
+computer keyboard, keypad
+confectionery, confectionary, candy store
+container ship, containership, container vessel
+convertible
+corkscrew, bottle screw
+cornet, horn, trumpet, trump
+cowboy boot
+cowboy hat, ten-gallon hat
+cradle
+crane
+crash helmet
+crate
+crib, cot
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam, dike, dyke
+desk
+desktop computer
+dial telephone, dial phone
+diaper, nappy, napkin
+digital clock
+digital watch
+dining table, board
+dishrag, dishcloth
+dishwasher, dish washer, dishwashing machine
+disk brake, disc brake
+dock, dockage, docking facility
+dogsled, dog sled, dog sleigh
+dome
+doormat, welcome mat
+drilling platform, offshore rig
+drum, membranophone, tympan
+drumstick
+dumbbell
+Dutch oven
+electric fan, blower
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa, boa
+file, file cabinet, filing cabinet
+fireboat
+fire engine, fire truck
+fire screen, fireguard
+flagpole, flagstaff
+flute, transverse flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn, horn
+frying pan, frypan, skillet
+fur coat
+garbage truck, dustcart
+gasmask, respirator, gas helmet
+gas pump, gasoline pump, petrol pump, island dispenser
+goblet
+go-kart
+golf ball
+golfcart, golf cart
+gondola
+gong, tam-tam
+gown
+grand piano, grand
+greenhouse, nursery, glasshouse
+grille, radiator grille
+grocery store, grocery, food market, market
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower, blow dryer, blow drier, hair dryer, hair drier
+hand-held computer, hand-held microcomputer
+handkerchief, hankie, hanky, hankey
+hard disc, hard disk, fixed disk
+harmonica, mouth organ, harp, mouth harp
+harp
+harvester, reaper
+hatchet
+holster
+home theater, home theatre
+honeycomb
+hook, claw
+hoopskirt, crinoline
+horizontal bar, high bar
+horse cart, horse-cart
+hourglass
+iPod
+iron, smoothing iron
+jack-o'-lantern
+jean, blue jean, denim
+jeep, landrover
+jersey, T-shirt, tee shirt
+jigsaw puzzle
+jinrikisha, ricksha, rickshaw
+joystick
+kimono
+knee pad
+knot
+lab coat, laboratory coat
+ladle
+lampshade, lamp shade
+laptop, laptop computer
+lawn mower, mower
+lens cap, lens cover
+letter opener, paper knife, paperknife
+library
+lifeboat
+lighter, light, igniter, ignitor
+limousine, limo
+liner, ocean liner
+lipstick, lip rouge
+Loafer
+lotion
+loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
+loupe, jeweler's loupe
+lumbermill, sawmill
+magnetic compass
+mailbag, postbag
+mailbox, letter box
+maillot
+maillot, tank suit
+manhole cover
+maraca
+marimba, xylophone
+mask
+matchstick
+maypole
+maze, labyrinth
+measuring cup
+medicine chest, medicine cabinet
+megalith, megalithic structure
+microphone, mike
+microwave, microwave oven
+military uniform
+milk can
+minibus
+miniskirt, mini
+minivan
+missile
+mitten
+mixing bowl
+mobile home, manufactured home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter, scooter
+mountain bike, all-terrain bike, off-roader
+mountain tent
+mouse, computer mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook, notebook computer
+obelisk
+oboe, hautboy, hautbois
+ocarina, sweet potato
+odometer, hodometer, mileometer, milometer
+oil filter
+organ, pipe organ
+oscilloscope, scope, cathode-ray oscilloscope, CRO
+overskirt
+oxcart
+oxygen mask
+packet
+paddle, boat paddle
+paddlewheel, paddle wheel
+padlock
+paintbrush
+pajama, pyjama, pj's, jammies
+palace
+panpipe, pandean pipe, syrinx
+paper towel
+parachute, chute
+parallel bars, bars
+park bench
+parking meter
+passenger car, coach, carriage
+patio, terrace
+pay-phone, pay-station
+pedestal, plinth, footstall
+pencil box, pencil case
+pencil sharpener
+perfume, essence
+Petri dish
+photocopier
+pick, plectrum, plectron
+pickelhaube
+picket fence, paling
+pickup, pickup truck
+pier
+piggy bank, penny bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate, pirate ship
+pitcher, ewer
+plane, carpenter's plane, woodworking plane
+planetarium
+plastic bag
+plate rack
+plow, plough
+plunger, plumber's helper
+Polaroid camera, Polaroid Land camera
+pole
+police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
+poncho
+pool table, billiard table, snooker table
+pop bottle, soda bottle
+pot, flowerpot
+potter's wheel
+power drill
+prayer rug, prayer mat
+printer
+prison, prison house
+projectile, missile
+projector
+puck, hockey puck
+punching bag, punch bag, punching ball, punchball
+purse
+quill, quill pen
+quilt, comforter, comfort, puff
+racer, race car, racing car
+racket, racquet
+radiator
+radio, wireless
+radio telescope, radio reflector
+rain barrel
+recreational vehicle, RV, R.V.
+reel
+reflex camera
+refrigerator, icebox
+remote control, remote
+restaurant, eating house, eating place, eatery
+revolver, six-gun, six-shooter
+rifle
+rocking chair, rocker
+rotisserie
+rubber eraser, rubber, pencil eraser
+rugby ball
+rule, ruler
+running shoe
+safe
+safety pin
+saltshaker, salt shaker
+sandal
+sarong
+sax, saxophone
+scabbard
+scale, weighing machine
+school bus
+schooner
+scoreboard
+screen, CRT screen
+screw
+screwdriver
+seat belt, seatbelt
+sewing machine
+shield, buckler
+shoe shop, shoe-shop, shoe store
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule, slipstick
+sliding door
+slot, one-armed bandit
+snorkel
+snowmobile
+snowplow, snowplough
+soap dispenser
+soccer ball
+sock
+solar dish, solar collector, solar furnace
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web, spider's web
+spindle
+sports car, sport car
+spotlight, spot
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch, stop watch
+stove
+strainer
+streetcar, tram, tramcar, trolley, trolley car
+stretcher
+studio couch, day bed
+stupa, tope
+submarine, pigboat, sub, U-boat
+suit, suit of clothes
+sundial
+sunglass
+sunglasses, dark glasses, shades
+sunscreen, sunblock, sun blocker
+suspension bridge
+swab, swob, mop
+sweatshirt
+swimming trunks, bathing trunks
+swing
+switch, electric switch, electrical switch
+syringe
+table lamp
+tank, army tank, armored combat vehicle, armoured combat vehicle
+tape player
+teapot
+teddy, teddy bear
+television, television system
+tennis ball
+thatch, thatched roof
+theater curtain, theatre curtain
+thimble
+thresher, thrasher, threshing machine
+throne
+tile roof
+toaster
+tobacco shop, tobacconist shop, tobacconist
+toilet seat
+torch
+totem pole
+tow truck, tow car, wrecker
+toyshop
+tractor
+trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
+tray
+trench coat
+tricycle, trike, velocipede
+trimaran
+tripod
+triumphal arch
+trolleybus, trolley coach, trackless trolley
+trombone
+tub, vat
+turnstile
+typewriter keyboard
+umbrella
+unicycle, monocycle
+upright, upright piano
+vacuum, vacuum cleaner
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin, fiddle
+volleyball
+waffle iron
+wall clock
+wallet, billfold, notecase, pocketbook
+wardrobe, closet, press
+warplane, military plane
+washbasin, handbasin, washbowl, lavabo, wash-hand basin
+washer, automatic washer, washing machine
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool, woolen, woollen
+worm fence, snake fence, snake-rail fence, Virginia fence
+wreck
+yawl
+yurt
+web site, website, internet site, site
+comic book
+crossword puzzle, crossword
+street sign
+traffic light, traffic signal, stoplight
+book jacket, dust cover, dust jacket, dust wrapper
+menu
+plate
+guacamole
+consomme
+hot pot, hotpot
+trifle
+ice cream, icecream
+ice lolly, lolly, lollipop, popsicle
+French loaf
+bagel, beigel
+pretzel
+cheeseburger
+hotdog, hot dog, red hot
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini, courgette
+spaghetti squash
+acorn squash
+butternut squash
+cucumber, cuke
+artichoke, globe artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple, ananas
+banana
+jackfruit, jak, jack
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce, chocolate syrup
+dough
+meat loaf, meatloaf
+pizza, pizza pie
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff, drop, drop-off
+coral reef
+geyser
+lakeside, lakeshore
+promontory, headland, head, foreland
+sandbar, sand bar
+seashore, coast, seacoast, sea-coast
+valley, vale
+volcano
+ballplayer, baseball player
+groom, bridegroom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
+corn
+acorn
+hip, rose hip, rosehip
+buckeye, horse chestnut, conker
+coral fungus
+agaric
+gyromitra
+stinkhorn, carrion fungus
+earthstar
+hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
+bolete
+ear, spike, capitulum
+toilet tissue, toilet paper, bathroom tissue
\ No newline at end of file
diff --git a/adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth b/adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth
new file mode 100644
index 0000000000000000000000000000000000000000..aae5780851125baf1a30834c3a715d3866858a4d
--- /dev/null
+++ b/adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21dd590f3ccdc646f0d53120778b296013b096a035a2718c9cb0d511bff0f1e0
+size 3714759
diff --git a/adv_grpo/assets/simple_animals.txt b/adv_grpo/assets/simple_animals.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bc9e1176a2eb831541d98dcd810b674a36602a78
--- /dev/null
+++ b/adv_grpo/assets/simple_animals.txt
@@ -0,0 +1,45 @@
+cat
+dog
+horse
+monkey
+rabbit
+zebra
+spider
+bird
+sheep
+deer
+cow
+goat
+lion
+tiger
+bear
+raccoon
+fox
+wolf
+lizard
+beetle
+ant
+butterfly
+fish
+shark
+whale
+dolphin
+squirrel
+mouse
+rat
+snake
+turtle
+frog
+chicken
+duck
+goose
+bee
+pig
+turkey
+fly
+llama
+camel
+bat
+gorilla
+hedgehog
+kangaroo
diff --git a/adv_grpo/assets/simple_ocr_animals.txt b/adv_grpo/assets/simple_ocr_animals.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fd39766d1798ae1968138f5bcf3dac66b6705d8b
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals.txt
@@ -0,0 +1,5 @@
+cat
+dog
+horse
+monkey
+rabbit
\ No newline at end of file
diff --git a/adv_grpo/assets/simple_ocr_animals_digit1.txt b/adv_grpo/assets/simple_ocr_animals_digit1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..017892c8d121fb6fb70aa8ce9f5f010e30f60052
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals_digit1.txt
@@ -0,0 +1,45 @@
+A cat holding a sign that says '0'
+A dog holding a sign that says '0'
+A horse holding a sign that says '0'
+A monkey holding a sign that says '0'
+A rabbit holding a sign that says '0'
+A cat holding a sign that says '1'
+A dog holding a sign that says '1'
+A horse holding a sign that says '1'
+A monkey holding a sign that says '1'
+A rabbit holding a sign that says '1'
+A cat holding a sign that says '2'
+A dog holding a sign that says '2'
+A horse holding a sign that says '2'
+A monkey holding a sign that says '2'
+A rabbit holding a sign that says '2'
+A cat holding a sign that says '3'
+A dog holding a sign that says '3'
+A horse holding a sign that says '3'
+A monkey holding a sign that says '3'
+A rabbit holding a sign that says '3'
+A cat holding a sign that says '4'
+A dog holding a sign that says '4'
+A horse holding a sign that says '4'
+A monkey holding a sign that says '4'
+A rabbit holding a sign that says '4'
+A cat holding a sign that says '5'
+A dog holding a sign that says '5'
+A horse holding a sign that says '5'
+A monkey holding a sign that says '5'
+A rabbit holding a sign that says '5'
+A cat holding a sign that says '6'
+A dog holding a sign that says '6'
+A horse holding a sign that says '6'
+A monkey holding a sign that says '6'
+A rabbit holding a sign that says '6'
+A cat holding a sign that says '7'
+A dog holding a sign that says '7'
+A horse holding a sign that says '7'
+A monkey holding a sign that says '7'
+A rabbit holding a sign that says '7'
+A cat holding a sign that says '8'
+A dog holding a sign that says '8'
+A horse holding a sign that says '8'
+A monkey holding a sign that says '8'
+A rabbit holding a sign that says '8'
\ No newline at end of file
diff --git a/adv_grpo/assets/simple_ocr_animals_digit3.txt b/adv_grpo/assets/simple_ocr_animals_digit3.txt
new file mode 100644
index 0000000000000000000000000000000000000000..382ca9ce7faca644b367493d92754251896735f0
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals_digit3.txt
@@ -0,0 +1,45 @@
+A cat holding a sign that says '123'
+A dog holding a sign that says '234'
+A horse holding a sign that says '345'
+A monkey holding a sign that says '456'
+A rabbit holding a sign that says '567'
+A cat holding a sign that says '678'
+A dog holding a sign that says '789'
+A horse holding a sign that says '123'
+A monkey holding a sign that says '234'
+A rabbit holding a sign that says '345'
+A cat holding a sign that says '456'
+A dog holding a sign that says '567'
+A horse holding a sign that says '678'
+A monkey holding a sign that says '789'
+A rabbit holding a sign that says '123'
+A cat holding a sign that says '234'
+A dog holding a sign that says '345'
+A horse holding a sign that says '456'
+A monkey holding a sign that says '567'
+A rabbit holding a sign that says '678'
+A cat holding a sign that says '789'
+A dog holding a sign that says '123'
+A horse holding a sign that says '234'
+A monkey holding a sign that says '345'
+A rabbit holding a sign that says '456'
+A cat holding a sign that says '567'
+A dog holding a sign that says '678'
+A horse holding a sign that says '789'
+A monkey holding a sign that says '123'
+A rabbit holding a sign that says '234'
+A cat holding a sign that says '345'
+A dog holding a sign that says '456'
+A horse holding a sign that says '567'
+A monkey holding a sign that says '678'
+A rabbit holding a sign that says '789'
+A cat holding a sign that says '123'
+A dog holding a sign that says '234'
+A horse holding a sign that says '345'
+A monkey holding a sign that says '456'
+A rabbit holding a sign that says '567'
+A cat holding a sign that says '678'
+A dog holding a sign that says '789'
+A horse holding a sign that says '123'
+A monkey holding a sign that says '234'
+A rabbit holding a sign that says '345'
\ No newline at end of file
diff --git a/adv_grpo/assets/simple_ocr_animals_digit5.txt b/adv_grpo/assets/simple_ocr_animals_digit5.txt
new file mode 100644
index 0000000000000000000000000000000000000000..caa91b383e3d1ac2a2bd02e1c2dd27d42373ff7d
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals_digit5.txt
@@ -0,0 +1,50 @@
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
\ No newline at end of file
diff --git a/adv_grpo/assets/test.jpg b/adv_grpo/assets/test.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5fa2a71e3786313e40d76128136eef4afab05ba3
Binary files /dev/null and b/adv_grpo/assets/test.jpg differ
diff --git a/adv_grpo/clip_scorer.py b/adv_grpo/clip_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eea6458d476d7726e177f27b6af2e0154cd0239
--- /dev/null
+++ b/adv_grpo/clip_scorer.py
@@ -0,0 +1,97 @@
+# Based on https://github.com/RE-N-Y/imscore/blob/main/src/imscore/preference/model.py
+
+from importlib import resources
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as T
+from transformers import AutoImageProcessor, CLIPProcessor, CLIPModel
+import numpy as np
+from PIL import Image
+
+def get_size(size):
+ if isinstance(size, int):
+ return (size, size)
+ elif "height" in size and "width" in size:
+ return (size["height"], size["width"])
+ elif "shortest_edge" in size:
+ return size["shortest_edge"]
+ else:
+ raise ValueError(f"Invalid size: {size}")
+
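+# Rebuild the processor's image preprocessing (resize / center-crop / normalize) as torchvision
+# transforms so it can be applied directly to batched image tensors.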
+def get_image_transform(processor: AutoImageProcessor):
+ config = processor.to_dict()
+ resize = T.Resize(get_size(config.get("size"))) if config.get("do_resize") else nn.Identity()
+ crop = T.CenterCrop(get_size(config.get("crop_size"))) if config.get("do_center_crop") else nn.Identity()
+ normalise = T.Normalize(mean=processor.image_mean, std=processor.image_std) if config.get("do_normalize") else nn.Identity()
+
+ return T.Compose([resize, crop, normalise])
+
+class ClipScorer(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ # self.device="cuda"
+ self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ self.tform = get_image_transform(self.processor.image_processor)
+ self.eval()
+
+ def _process(self, pixels):
+ dtype = pixels.dtype
+ pixels = self.tform(pixels)
+ pixels = pixels.to(dtype=dtype)
+
+ return pixels
+
+ @torch.no_grad()
+ def __call__(self, pixels, prompts, return_img_embedding=False):
+ device = next(self.parameters()).device
+ texts = self.processor(text=prompts, padding='max_length', truncation=True, return_tensors="pt").to(device)
+ pixels = self._process(pixels).to(device)
+ outputs = self.model(pixel_values=pixels, **texts)
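+        # logits_per_image[i, j] is CLIP's scaled image-text similarity; the diagonal pairs each image
+        # with its own prompt, and the constant /30 rescales the logits into a smaller score range.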
+ if return_img_embedding:
+ return outputs.logits_per_image.diagonal()/30, outputs.image_embeds
+ return outputs.logits_per_image.diagonal()/30
+
+ @torch.no_grad()
+ def image_similarity(self, pixels, ref_pixels):
+ device = next(self.parameters()).device
+ pixels = self._process(pixels).to(device)
+ ref_pixels = self._process(ref_pixels).to(device)
+
+ pixel_embeds = self.model.get_image_features(pixel_values=pixels)
+ ref_embeds = self.model.get_image_features(pixel_values=ref_pixels)
+
+ pixel_embeds = pixel_embeds / pixel_embeds.norm(p=2, dim=-1, keepdim=True)
+ ref_embeds = ref_embeds / ref_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ sim = pixel_embeds @ ref_embeds.T
+ # sim = torch.diagonal(sim, 0)
+ sim = sim.squeeze(-1)
+ return sim
+
+
+def main():
+    scorer = ClipScorer()
+
+ images=[
+ "assets/test.jpg",
+ "assets/test.jpg"
+ ]
+ pil_images = [Image.open(img) for img in images]
+ prompts=[
+ 'an image of cat',
+ 'not an image of cat'
+ ]
+ images = [np.array(img) for img in pil_images]
+ images = np.array(images)
+ images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW
+ images = torch.tensor(images, dtype=torch.uint8)/255.0
+ print(scorer(images, prompts))
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/adv_grpo/conv_gradfix.py b/adv_grpo/conv_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..8989fdb3aa251b24fa2df457306dc202d913d723
--- /dev/null
+++ b/adv_grpo/conv_gradfix.py
@@ -0,0 +1,345 @@
+"""
+Custom replacement for `torch.nn.functional.convNd` and `torch.nn.functional.conv_transposeNd`
+that supports arbitrarily high order gradients with zero performance penalty.
+Modified from https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py
+"""
+
+import contextlib
+import warnings
+from typing import Optional
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Conv2d, Conv3d
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+# ----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = (
+ False # Forcefully disable computation of gradients with respect to the weights.
+)
+
+
+@contextlib.contextmanager
+def no_weight_gradients():
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+
+# ----------------------------------------------------------------------------
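+# Drop-in nn.Conv2d / nn.Conv3d subclasses: when `use_gradfix` is set, forward routes through convNd
+# below, which dispatches to the custom gradfix autograd op (double backward, no_weight_gradients)
+# whenever it is enabled; forward() also accepts externally supplied weight/bias tensors.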
+class GradFixConv2d(Conv2d):
+ def __init__(self, *args, use_gradfix: bool = False, **kwargs):
+ self.use_gradfix = use_gradfix
+ super().__init__(*args, **kwargs)
+
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
+ conv_fn = F.conv2d if not self.use_gradfix else convNd
+ if self.padding_mode != "zeros":
+ return conv_fn(
+ F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+ weight,
+ bias,
+ self.stride,
+ (0, 0),
+ self.dilation,
+ self.groups,
+ )
+ return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+
+ def forward(
+ self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
+ ) -> Tensor:
+ weight = self.weight if weight is None else weight
+ bias = self.bias if bias is None else bias
+ return self._conv_forward(input, weight, bias)
+
+
+class GradFixConv3d(Conv3d):
+ def __init__(self, *args, use_gradfix: bool = False, **kwargs):
+ self.use_gradfix = use_gradfix
+ super().__init__(*args, **kwargs)
+
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
+ conv_fn = F.conv3d if not self.use_gradfix else convNd
+ if self.padding_mode != "zeros":
+ return conv_fn(
+ F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+ weight,
+ bias,
+ self.stride,
+ (0, 0, 0),
+ self.dilation,
+ self.groups,
+ )
+ return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+
+ def forward(
+ self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
+ ) -> Tensor:
+ weight = self.weight if weight is None else weight
+ bias = self.bias if bias is None else bias
+ return self._conv_forward(input, weight, bias)
+
+
+# ----------------------------------------------------------------------------
+
+
+def convNd(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ N = weight.ndim - 2
+ if _should_use_custom_op(input):
+ return _conv_gradfix(
+ transpose=False,
+ weight_shape=weight.shape,
+ stride=stride,
+ padding=padding,
+ output_padding=0,
+ dilation=dilation,
+ groups=groups,
+ ).apply(input, weight, bias)
+ return getattr(torch.nn.functional, f"conv{N}d")(
+ input=input,
+ weight=weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
+
+def conv_transposeNd(
+ input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1
+):
+ N = weight.ndim - 2
+ if _should_use_custom_op(input):
+ return _conv_gradfix(
+ transpose=True,
+ weight_shape=weight.shape,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ).apply(input, weight, bias)
+ return getattr(torch.nn.functional, f"conv_transpose{N}d")(
+ input=input,
+ weight=weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
+
+
+# ----------------------------------------------------------------------------
+
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != "cuda":
+ return False
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9"]):
+ return True
+ if torch.__version__.startswith("2"):
+ return True
+ warnings.warn(
+        f"conv_gradfix not supported on PyTorch {torch.__version__}. "
+        f"Falling back to the standard torch.nn.functional convolution."
+ )
+ return False
+
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+
+# ----------------------------------------------------------------------------
+
+_conv_gradfix_cache = dict()
+
+
+def _conv_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ ndim = len(weight_shape) - 2
+ # Parse arguments.
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv_gradfix_cache:
+ return _conv_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+
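+    # output_padding needed so the opposite-direction convolution used in the backward pass
+    # reproduces the exact input shape (all zeros when the forward op is already transposed).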
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [
+ 0,
+ ] * ndim
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class ConvNd(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ """
+ input size: [B, C, ...]
+ weight size:
+ -> Conv: [C_out, C_in // groups, ...]
+ -> Transpose: [C_in, C_out // groups, ...]
+ """
+ assert weight.shape == weight_shape
+ ctx.save_for_backward(input, weight)
+
+ # General case => cuDNN.
+ if transpose:
+ return getattr(torch.nn.functional, f"conv_transpose{ndim}d")(
+ input=input,
+ weight=weight.to(input.dtype),
+ bias=bias,
+ output_padding=output_padding,
+ **common_kwargs,
+ )
+ return getattr(torch.nn.functional, f"conv{ndim}d")(
+ input=input, weight=weight.to(input.dtype), bias=bias, **common_kwargs
+ )
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]: # Input
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ op = _conv_gradfix(
+ transpose=(not transpose),
+ weight_shape=weight_shape,
+ output_padding=p,
+ **common_kwargs,
+ )
+ grad_input = op.apply(grad_output, weight, None)
+ assert grad_input.shape == input.shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled: # Weight
+ grad_weight = ConvNdGradWeight.apply(grad_output, input)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]: # Bias
+ grad_bias = grad_output.transpose(0, 1).flatten(1).sum(1)
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class ConvNdGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input):
+ flags = [
+ torch.backends.cudnn.benchmark,
+ torch.backends.cudnn.deterministic,
+ torch.backends.cudnn.allow_tf32,
+ ]
+ if torch.__version__.startswith("1"):
+ op = torch._C._jit_get_operation(
+ "aten::cudnn_convolution_backward_weight"
+ if not transpose
+ else "aten::cudnn_convolution_transpose_backward_weight"
+ )
+ grad_weight = op(
+ weight_shape,
+ grad_output,
+ input.to(grad_output.dtype),
+ padding,
+ stride,
+ dilation,
+ groups,
+ *flags,
+ )
+ elif torch.__version__.startswith("2"):
+ # https://github.com/pytorch/pytorch/issues/74437
+ op, _ = torch._C._jit_get_operation("aten::convolution_backward")
+ dummy_weight = torch.tensor(
+ 0.0, dtype=grad_output.dtype, device=input.device
+ ).expand(weight_shape)
+ grad_weight = op(
+ grad_output,
+ input.to(grad_output.dtype),
+ dummy_weight,
+ None,
+ stride,
+ padding,
+ dilation,
+ transpose,
+ (0,) * ndim,
+ groups,
+ [False, True, False],
+ )[1]
+ else:
+ raise NotImplementedError
+ assert grad_weight.shape == weight_shape
+ ctx.save_for_backward(grad_output, input)
+ return grad_weight
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]: # Grad of Weight
+ grad2_grad_output = ConvNd.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output.shape
+
+ if ctx.needs_input_grad[1]: # Input
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ op = _conv_gradfix(
+ transpose=(not transpose),
+ weight_shape=weight_shape,
+ output_padding=p,
+ **common_kwargs,
+ )
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input.shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv_gradfix_cache[key] = ConvNd
+ return ConvNd
+
+
+# ----------------------------------------------------------------------------
diff --git a/adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc b/adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abfe709ca3b831dee3e248d3f3b1e6dc27737966
Binary files /dev/null and b/adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc differ
diff --git a/adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc b/adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9321923338fe113f51150fe42198a8b84bb851d9
Binary files /dev/null and b/adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc differ
diff --git a/adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc b/adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fee662b9067583e744cdbceda8c7b329ee4de16
Binary files /dev/null and b/adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc differ
diff --git a/adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py b/adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f1d57175510c73d118f2becdb024d03139ebf9
--- /dev/null
+++ b/adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py
@@ -0,0 +1,255 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+
+from typing import Any, Dict, List, Optional, Union, Callable
+import torch
+import numpy as np
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.utils import logging
+from .sd3_sde_with_logprob import sde_step_with_logprob
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
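+# Linearly interpolate the scheduler's timestep-shift parameter mu between base_shift and max_shift
+# as a function of the image token sequence length (passed as `mu` to retrieve_timesteps below).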
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+@torch.no_grad()
+def pipeline_with_logprob(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ noise_level: float = 0.7,
+):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ image = self.image_processor.resize(image, height, width)
+ image = self.image_processor.preprocess(image, height, width)
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
+ image.float(),
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ if image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+    # 6. Prepare storage for the intermediate latents and per-step log probs
+ all_latents = [latents]
+ all_log_probs = []
+
+ # 7. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+ self._current_timestep = t
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+            if noise_pred.isnan().any():
+                logger.warning("NaN detected in noise_pred at denoising step %d", i)
+ noise_pred = noise_pred[:, : latents.size(1)]
+ latents_dtype = latents.dtype
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0).repeat(latents.shape[0]),
+ latents.float(),
+ noise_level=noise_level,
+ )
+
+ if latents.dtype != latents_dtype:
+ latents = latents.to(latents_dtype)
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ latents = latents.to(dtype=self.vae.dtype)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ return image, all_latents, latent_ids, text_ids, all_log_probs, image_latents
diff --git a/adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py b/adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py
new file mode 100644
index 0000000000000000000000000000000000000000..530ac1b73d206a542dd8c7875e6d027fb631eaff
--- /dev/null
+++ b/adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py
@@ -0,0 +1,187 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py
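+# Same idea as the Kontext variant above: the denoising loop uses the patched `sde_step_with_logprob`
+# and returns the intermediate latents and per-step log probabilities (plus the latent/text ids)
+# together with the decoded image.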
+
+from typing import Any, Dict, List, Optional, Union, Callable
+import torch
+import numpy as np
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
+from .sd3_sde_with_logprob import sde_step_with_logprob
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+@torch.no_grad()
+def pipeline_with_logprob(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ noise_level: float = 0.7,
+):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+    # 6. Prepare storage for the intermediate latents and per-step log probs
+ all_latents = [latents]
+ all_log_probs = []
+
+ # 7. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ latents_dtype = latents.dtype
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0).repeat(latents.shape[0]),
+ latents.float(),
+ noise_level=noise_level,
+ )
+ if latents.dtype != latents_dtype:
+ latents = latents.to(latents_dtype)
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ latents = latents.to(dtype=self.vae.dtype)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ return image, all_latents, latent_image_ids, text_ids, all_log_probs
diff --git a/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a28e04b835302dbdba57c24d9aad8670e1613f6
--- /dev/null
+++ b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py
@@ -0,0 +1,198 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+# with the following modifications:
+# - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`.
+# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
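+#
+# A minimal usage sketch (an illustrative assumption, not part of this repo): the function below is
+# written as an unbound method, so it is typically attached to an existing `StableDiffusion3Pipeline`
+# instance so that `self` resolves to the pipeline, e.g.
+#
+#     import types
+#     from diffusers import StableDiffusion3Pipeline
+#     pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
+#     pipe.pipeline_with_logprob = types.MethodType(pipeline_with_logprob, pipe)
+#     image, all_latents, all_log_probs = pipe.pipeline_with_logprob("a photo of a cat", num_inference_steps=10)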
+from typing import Any, Dict, List, Optional, Union
+import torch
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
+from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob
+
+@torch.no_grad()
+def pipeline_with_logprob(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ skip_layer_guidance_scale: float = 2.8,
+ noise_level: float = 0.7,
+):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ # latents = self.prepare_latents(
+ # batch_size * num_images_per_prompt,
+ # num_channels_latents,
+ # height,
+ # width,
+ # prompt_embeds.dtype,
+ # device,
+ # generator,
+ # latents,
+ # ).float()
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ scheduler_kwargs = {}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+    # 6. Prepare storage for the intermediate latents and per-step log probs
+    all_latents = [latents]
+    all_log_probs = []
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+ # import pdb; pdb.set_trace()
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents_dtype = latents.dtype
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0),
+ latents.float(),
+ noise_level=noise_level,
+ )
+
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+ if latents.dtype != latents_dtype:
+ latents = latents.to(latents_dtype)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ latents = latents.to(dtype=self.vae.dtype)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ return image, all_latents, all_log_probs
diff --git a/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca98c131dd7bf27eeb74b0343c8c7da5d0c155bb
--- /dev/null
+++ b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py
@@ -0,0 +1,1081 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+# with the following modifications:
+# - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`.
+# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
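+# - The variants below additionally inject SDE noise only inside a window of `train_num_steps` steps
+#   starting at `random_timestep` (drawn at random when not given); at that step
+#   `pipeline_with_logprob` / `pipeline_with_logprob_new` repeat the latents and prompt embeddings
+#   `mini_num_image_per_prompt` times (the `_random` variant repeats the embeddings up front), and
+#   only the latents, log probs and timesteps inside that window are collected for training.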
+from typing import Any, Dict, List, Optional, Union
+import torch
+import random
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
+from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob
+from PIL import Image
+from torchvision import transforms
+
+
+
+@torch.no_grad()
+def pipeline_with_logprob(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ mini_num_image_per_prompt: int = 1,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ skip_layer_guidance_scale: float = 2.8,
+ noise_level: float = 0.7,
+ train_num_steps: int = 1,
+ process_index: int = 0,
+ sample_num_steps: int = 10,
+ random_timestep: Optional[int] = None,
+):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ # import pdb; pdb.set_trace()
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ ).float()
+ # import pdb; pdb.set_trace()
+ # latents = latents.to(prompt_embeds.dtype)
+
+ # 5. Prepare timesteps
+ scheduler_kwargs = {}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ # timesteps = timesteps.to(prompt_embeds.dtype)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ random.seed(process_index)
+ if random_timestep is None:
+ random_timestep = random.randint(0, sample_num_steps//2)
+
+
+    # 6. Prepare storage for the intermediate latents, per-step log probs and timesteps
+ all_latents = []
+ all_log_probs = []
+ all_timesteps = []
+
+ if self.do_classifier_free_guidance:
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ # 7. Denoising loop
+ # import pdb; pdb.set_trace()
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # import pdb; pdb.set_trace()
+ for i, t in enumerate(timesteps):
+ if i < random_timestep:
+ cur_noise_level = 0
+            elif i == random_timestep:
+                cur_noise_level = noise_level
+                # repeat the latents mini_num_image_per_prompt times
+ latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
+ prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ if self.do_classifier_free_guidance:
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ all_latents.append(latents)
+ elif i > random_timestep and i < random_timestep + train_num_steps:
+ cur_noise_level = noise_level
+ else:
+                cur_noise_level = 0
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+ # import pdb; pdb.set_trace()
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=tem_prompt_embeds,
+ pooled_projections=tem_pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents_dtype = latents.dtype
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0),
+ latents.float(),
+ noise_level=cur_noise_level,
+ )
+
+ # if latents.dtype != latents_dtype:
+ # latents = latents.to(latents_dtype)
+
+ if i >= random_timestep and i < random_timestep + train_num_steps:
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+ all_timesteps.append(t.repeat(len(latents)))
+ # import pdb; pdb.set_trace()
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ latents = latents.to(dtype=self.vae.dtype)
+ image = self.vae.decode(latents, return_dict=False)[0]
+    reconstructed_image = self.image_processor.postprocess(image, output_type="pil")
+    # reconstructed_image[0].save("0.png")
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+ return image, all_latents, all_log_probs, all_timesteps
+
+
+
+@torch.no_grad()
+def pipeline_with_logprob_new(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ mini_num_image_per_prompt: int = 1,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ skip_layer_guidance_scale: float = 2.8,
+ noise_level: float = 0.7,
+ train_num_steps: int = 1,
+ process_index: int = 0,
+ sample_num_steps: int = 10,
+ random_timestep: Optional[int] = None,
+):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ # import pdb; pdb.set_trace()
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+ # import pdb; pdb.set_trace()
+
+ self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ # import pdb; pdb.set_trace()
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ # import pdb; pdb.set_trace()
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # import pdb; pdb.set_trace()
+ # latents = latents.to(prompt_embeds.dtype)
+
+ # 5. Prepare timesteps
+ scheduler_kwargs = {}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ # timesteps = timesteps.to(prompt_embeds.dtype)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ random.seed(process_index)
+ if random_timestep is None:
+ random_timestep = random.randint(0, sample_num_steps//2)
+
+
+    # 6. Prepare storage for the intermediate latents, per-step log probs and timesteps
+ all_latents = []
+ all_log_probs = []
+ all_timesteps = []
+ # import pdb; pdb.set_trace()
+
+ if self.do_classifier_free_guidance:
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ # 7. Denoising loop
+ # import pdb; pdb.set_trace()
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # import pdb; pdb.set_trace()
+ for i, t in enumerate(timesteps):
+ if i < random_timestep:
+ cur_noise_level = 0
+ elif i == random_timestep:
+                cur_noise_level = noise_level
+                # repeat the latents mini_num_image_per_prompt times
+ latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
+ prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ if self.do_classifier_free_guidance:
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ all_latents.append(latents)
+ elif i > random_timestep and i < random_timestep + train_num_steps:
+ cur_noise_level = noise_level
+ else:
+                cur_noise_level = 0
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+ # import pdb; pdb.set_trace()
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=tem_prompt_embeds,
+ pooled_projections=tem_pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents_dtype = latents.dtype
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0),
+ latents.float(),
+ noise_level=cur_noise_level,
+ )
+
+ if latents.dtype != latents_dtype:
+ latents = latents.to(latents_dtype)
+
+ if i >= random_timestep and i < random_timestep + train_num_steps:
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+ all_timesteps.append(t.repeat(len(latents)))
+ # import pdb; pdb.set_trace()
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ latents = latents.to(dtype=self.vae.dtype)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+ return image, all_latents, all_log_probs, all_timesteps
+
+
+
+
+@torch.no_grad()
+def pipeline_with_logprob_random(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ mini_num_image_per_prompt: int = 1,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ skip_layer_guidance_scale: float = 2.8,
+ noise_level: float = 0.7,
+ train_num_steps: int = 1,
+ process_index: int = 0,
+ sample_num_steps: int = 10,
+ random_timestep: Optional[int] = None,
+):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ # import pdb; pdb.set_trace()
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+ # import pdb; pdb.set_trace()
+
+ self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ # import pdb; pdb.set_trace()
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ # import pdb; pdb.set_trace()
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ prompt_embeds.shape[0],
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # import pdb; pdb.set_trace()
+ # latents = latents.to(prompt_embeds.dtype)
+
+ # 5. Prepare timesteps
+ scheduler_kwargs = {}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ # timesteps = timesteps.to(prompt_embeds.dtype)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ random.seed(process_index)
+ if random_timestep is None:
+ random_timestep = random.randint(0, sample_num_steps//2)
+
+
+    # 6. Prepare storage for the intermediate latents, per-step log probs and timesteps
+ all_latents = []
+ all_log_probs = []
+ all_timesteps = []
+ if self.do_classifier_free_guidance:
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ # 7. Denoising loop
+ # import pdb; pdb.set_trace()
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # import pdb; pdb.set_trace()
+ for i, t in enumerate(timesteps):
+ if i < random_timestep:
+ cur_noise_level = 0
+ elif i == random_timestep:
+                cur_noise_level = noise_level
+                # repeat the latents mini_num_image_per_prompt times (disabled here; the embeddings are repeated up front)
+ # latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
+ # prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ # pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ # negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
+ # negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
+ # if self.do_classifier_free_guidance:
+ # tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ # tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ all_latents.append(latents)
+ elif i > random_timestep and i < random_timestep + train_num_steps:
+ cur_noise_level = noise_level
+ else:
+                cur_noise_level = 0
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+ # import pdb; pdb.set_trace()
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=tem_prompt_embeds,
+ pooled_projections=tem_pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ # noise_pred = noise_pred.to(prompt_embeds.dtype)
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents_dtype = latents.dtype
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0),
+ latents.float(),
+ noise_level=cur_noise_level,
+ )
+
+ if latents.dtype != latents_dtype:
+ latents = latents.to(latents_dtype)
+
+ if i >= random_timestep and i < random_timestep + train_num_steps:
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+ all_timesteps.append(t.repeat(len(latents)))
+ # import pdb; pdb.set_trace()
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ latents = latents.to(dtype=self.vae.dtype)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+ return image, all_latents, all_log_probs, all_timesteps
+
+
+
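+# Descriptive note: some scheduler buffers (e.g. `sigmas`, `timesteps`) can still live on the CPU
+# after `set_timesteps`; this helper moves every tensor attribute of the scheduler to the given device.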
+def move_scheduler_to_device(scheduler, device="cuda"):
+ for attr_name in dir(scheduler):
+ attr = getattr(scheduler, attr_name)
+ if isinstance(attr, torch.Tensor):
+ setattr(scheduler, attr_name, attr.to(device))
+ return scheduler
+
+
+def image_to_latent(pipe, images: Union[Image.Image, List[Image.Image]], device="cuda"):
+    # normalize the input to a list
+    if isinstance(images, Image.Image):
+        images = [images]
+
+    preprocess = transforms.Compose([
+        transforms.Resize((512, 512)),
+        transforms.ToTensor(),                # to [0, 1]
+        transforms.Normalize([0.5], [0.5])    # map to [-1, 1]
+    ])
+
+    # batch preprocessing
+    img_tensors = [preprocess(img) for img in images]  # list of [3,512,512]
+    img_tensor = torch.stack(img_tensors, dim=0).to(device, dtype=torch.float32)  # [B,3,512,512]
+
+    # encode with the VAE
+    latent = pipe.vae.encode(img_tensor).latent_dist.sample()
+    latent = latent * pipe.vae.config.scaling_factor
+    return latent.to(torch.bfloat16)  # [B,4,64,64] (assuming a 512x512 input, downscaled 8x)
+
+
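+# Descriptive note: `get_sigmas` looks up the sigma that the scheduler associates with each given
+# timestep and unsqueezes it so it broadcasts against an `n_dim`-dimensional latent tensor.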
+def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
+ timesteps = timesteps.to(device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+
+
+@torch.no_grad()
+def flux_to_sd3_denoise(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ flux_images=None,
+ device="cuda",
+ output_type: Optional[str] = "pil",
+ num_inference_steps: int = 20,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 256,
+ noise_level: float = 0.7,
+ random_timestep: Optional[int] = None,
+ noise_timestep_ratio: float = 0.4,
+ clip_skip: Optional[int] = None,
+):
+ """
+ 用 Flux 生成的图像 -> 转 latent -> 加噪 -> 用 SD3 多步去噪
+ 输出与 pipeline_with_logprob 对齐: image, all_latents, all_log_probs, all_timesteps
+ """
+ # 1. 转 latent
+ flux_latent = image_to_latent(self, flux_images, device)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+
+    # 2. Prepare the scheduler
+ noise_scheduler = self.scheduler
+ noise_scheduler.set_timesteps(num_inference_steps)
+ timesteps = noise_scheduler.timesteps.to(device)
+
+ # target_idx = torch.tensor([int(noise_timestep_ratio * (len(timesteps) - 1))], device=device)
+ target_idx = torch.tensor([noise_timestep_ratio], device=device)
+ t = timesteps[target_idx].to(device)
+
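+    # 3. Add noise with the flow-matching forward process: x_t = (1 - sigma_t) * x_0 + sigma_t * eps,
+    #    so sigma_t = 0 keeps the clean Flux latent and sigma_t = 1 gives pure Gaussian noise.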
+ noise = torch.randn_like(flux_latent)
+ sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
+ latents = (1.0 - sigmas) * flux_latent + sigmas * noise
+ num_channels_latents = self.transformer.config.in_channels
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # latents = self.prepare_latents(
+ # batch_size,
+ # num_channels_latents,
+ # 512,
+ # 512,
+ # prompt_embeds.dtype,
+ # device,
+ # None,
+ # None,
+ # )
+
+
+
+ # import pdb; pdb.set_trace()
+
+ # noisy_latent_vis = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ # noisy_latent_vis = noisy_latent_vis.to(dtype=self.vae.dtype)
+
+ # noisy_image = self.vae.decode(noisy_latent_vis, return_dict=False)[0]
+ # noisy_image = self.image_processor.postprocess(noisy_image, output_type="pil")[0]
+
+    # save to disk
+ # noisy_image.save("noisy_image.png")
+ # import pdb; pdb.set_trace()
+
+    # 4. Encode prompts (aligned with the handling in pipeline_with_logprob)
+ # lora_scale = (
+ # self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ # )
+ lora_scale = None
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ # import pdb; pdb.set_trace()
+
+
+ prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1)
+
+
+ if self.do_classifier_free_guidance:
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ else:
+ tem_prompt_embeds = prompt_embeds
+ tem_pooled_prompt_embeds = pooled_prompt_embeds
+
+    # 5. Denoise starting from the current t
+ noise_scheduler.set_timesteps(num_inference_steps)
+ timesteps = noise_scheduler.timesteps.to(device)
+ start_idx = (timesteps >= t[0]).nonzero()[-1].item()
+ timesteps = timesteps[start_idx:]
+
+ all_latents, all_log_probs, all_timesteps = [], [], []
+ noise_scheduler = move_scheduler_to_device(noise_scheduler, device)
+
+    for index, t_cur in enumerate(timesteps):
+        # import pdb; pdb.set_trace()
+        if index == 0:
+            all_latents.append(latents)
+
+        if index < 2:
+            cur_noise_level = noise_level
+        else:
+            cur_noise_level = 0.0
+
+ latent_model_input = (
+ torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ )
+ t_input = t_cur.expand(latent_model_input.shape[0]).to(device)
+
+ latents_dtype = latents.dtype
+ model_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=t_input,
+ encoder_hidden_states=tem_prompt_embeds,
+ pooled_projections=tem_pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
+ model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ # import pdb; pdb.set_trace()
+
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ noise_scheduler,
+ model_pred.float(),
+ t_cur.repeat(len(latents)),
+ latents.float(),
+            noise_level=cur_noise_level,
+ )
+ if latents.dtype != latents_dtype:
+ latents = latents.to(latents_dtype)
+
+        if index < 2:
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+ all_timesteps.append(t_cur.repeat(len(latents)))
+ # import pdb; pdb.set_trace()
+
+    # 6. Final decode
+ denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ denoised_latent = denoised_latent.to(dtype=self.vae.dtype)
+
+ image = self.vae.decode(denoised_latent, return_dict=False)[0]
+ # reconstructd_image = self.image_processor.postprocess(image, output_type="pil")[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ return image, all_latents, all_log_probs, all_timesteps
+
+
+
+
+
+@torch.no_grad()
+def flux_to_sd3_denoise_random(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ flux_images=None,
+ device="cuda",
+ output_type: Optional[str] = "pil",
+ num_inference_steps: int = 20,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 256,
+ noise_level: float = 0.7,
+ random_timestep: Optional[int] = None,
+ noise_timestep_ratio: float = 0.4,
+ clip_skip: Optional[int] = None,
+):
+ """
+ 用 Flux 生成的图像 -> 转 latent -> 加噪 -> 用 SD3 多步去噪
+ 输出与 pipeline_with_logprob 对齐: image, all_latents, all_log_probs, all_timesteps
+ """
+ # 1. 转 latent
+ flux_latent = image_to_latent(self, flux_images, device)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+
+    # 2. Prepare the scheduler
+ noise_scheduler = self.scheduler
+ noise_scheduler.set_timesteps(num_inference_steps)
+ timesteps = noise_scheduler.timesteps.to(device)
+
+ # target_idx = torch.tensor([int(noise_timestep_ratio * (len(timesteps) - 1))], device=device)
+ # t = timesteps[target_idx].to(device)
+
+ # noise = torch.randn_like(flux_latent)
+ # sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
+ # latents = (1.0 - sigmas) * flux_latent + sigmas * noise
+
+ target_idx = torch.tensor([random.randint(5, 10)], device=device)
+ t = timesteps[target_idx].to(device)
+    # sample standard Gaussian noise
+    noise = torch.randn_like(flux_latent)
+    # look up the corresponding sigma
+    sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
+    # add noise to the latent
+    latents = (1.0 - sigmas) * flux_latent + sigmas * noise
+
+ # import pdb; pdb.set_trace()
+
+ # noisy_latent_vis = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ # noisy_latent_vis = noisy_latent_vis.to(dtype=self.vae.dtype)
+
+ # noisy_image = self.vae.decode(noisy_latent_vis, return_dict=False)[0]
+ # noisy_image = self.image_processor.postprocess(noisy_image, output_type="pil")[0]
+
+    # save to disk
+ # noisy_image.save("noisy_image.png")
+ # import pdb; pdb.set_trace()
+
+    # 4. Encode prompts (aligned with the handling in pipeline_with_logprob)
+ # lora_scale = (
+ # self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ # )
+ lora_scale = None
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ # import pdb; pdb.set_trace()
+
+
+ prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1)
+
+
+ if self.do_classifier_free_guidance:
+ tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ else:
+ tem_prompt_embeds = prompt_embeds
+ tem_pooled_prompt_embeds = pooled_prompt_embeds
+
+    # 5. Denoise starting from the current t
+ noise_scheduler.set_timesteps(num_inference_steps)
+ timesteps = noise_scheduler.timesteps.to(device)
+ start_idx = (timesteps >= t[0]).nonzero()[-1].item()
+ timesteps = timesteps[start_idx:]
+
+ all_latents, all_log_probs, all_timesteps = [], [], []
+ noise_scheduler = move_scheduler_to_device(noise_scheduler, device)
+
+ for index, t_cur in enumerate(timesteps):
+        if index == 0:
+ all_latents.append(latents)
+
+ latent_model_input = (
+ torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ )
+ t_input = t_cur.expand(latent_model_input.shape[0]).to(device)
+
+ latents_dtype = latents.dtype
+ model_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=t_input,
+ encoder_hidden_states=tem_prompt_embeds,
+ pooled_projections=tem_pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
+ model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ # import pdb; pdb.set_trace()
+
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ noise_scheduler,
+ model_pred.float(),
+ t_cur.repeat(len(latents)),
+ latents.float(),
+ noise_level=noise_level,
+ )
+ if latents.dtype != latents_dtype:
+ latents = latents.to(latents_dtype)
+
+        # Only the first two SDE transitions are stored for log-prob training; the
+        # remaining steps are run purely to produce the final decoded image.
+        # if index >= 2 and index < 4:
+        if index < 2:
+            # print(model_pred)
+            all_latents.append(latents)
+            all_log_probs.append(log_prob)
+            all_timesteps.append(t_cur.repeat(len(latents)))
+ # import pdb; pdb.set_trace()
+
+    # 6. Final decode of the latents
+ denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ denoised_latent = denoised_latent.to(dtype=self.vae.dtype)
+
+ image = self.vae.decode(denoised_latent, return_dict=False)[0]
+ # reconstructd_image = self.image_processor.postprocess(image, output_type="pil")[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ return image, all_latents, all_log_probs, all_timesteps
diff --git a/adv_grpo/diffusers_patch/sd3_sde_with_logprob.py b/adv_grpo/diffusers_patch/sd3_sde_with_logprob.py
new file mode 100644
index 0000000000000000000000000000000000000000..47ae91bbd2160b653ddbc663390ebe912c4c46fd
--- /dev/null
+++ b/adv_grpo/diffusers_patch/sd3_sde_with_logprob.py
@@ -0,0 +1,139 @@
+# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
+# We adapt it from DDIM to flow matching.
+
+import math
+from typing import Optional, Union
+import torch
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+
+
+
+def sde_step_with_logprob(
+ self: FlowMatchEulerDiscreteScheduler,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ noise_level: float = 0.7,
+ prev_sample: Optional[torch.FloatTensor] = None,
+ generator: Optional[torch.Generator] = None,
+):
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
+ process from the learned model outputs (most often the predicted velocity).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned flow model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ """
+    # bf16 can overflow when computing prev_sample_mean, so convert all variables to fp32.
+ model_output=model_output.float()
+ sample=sample.float()
+ if prev_sample is not None:
+ prev_sample=prev_sample.float()
+
+ step_index = [self.index_for_timestep(t) for t in timestep]
+ prev_step_index = [step+1 for step in step_index]
+ sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
+ sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
+ sigma_max = self.sigmas[1].item()
+ dt = sigma_prev - sigma
+
+ std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level
+ # import pdb; pdb.set_trace()
+
+ # our sde
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
+
+ if prev_sample is None:
+ variance_noise = randn_tensor(
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=model_output.dtype,
+ )
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
+
+ log_prob = (
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
+ )
+
+ # mean along all but batch dimension
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
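+# Sanity check on the update above: with noise_level = 0 we get std_dev_t = 0, so
+# prev_sample_mean collapses to sample + model_output * dt, i.e. the deterministic
+# Euler step of the flow ODE, and the injected noise term vanishes.
+# The returned log_prob is the diagonal-Gaussian log-density of prev_sample under
+# N(prev_sample_mean, (std_dev_t * sqrt(-dt))^2 I), averaged over all non-batch dims.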
+
+
+
+def sde_step_with_logprob_new(
+ self: FlowMatchEulerDiscreteScheduler,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ noise_level: float = 0.7,
+ prev_sample: Optional[torch.FloatTensor] = None,
+ generator: Optional[torch.Generator] = None,
+):
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
+ process from the learned model outputs (most often the predicted velocity).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned flow model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ """
+    # bf16 can overflow when computing prev_sample_mean, so convert all variables to fp32.
+ model_output=model_output.float()
+ sample=sample.float()
+ if prev_sample is not None:
+ prev_sample=prev_sample.float()
+
+ step_index = [self.index_for_timestep(t) for t in timestep]
+ prev_step_index = [step+1 for step in step_index]
+ sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
+ sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
+ sigma_max = self.sigmas[1].item()
+ dt = sigma_prev - sigma
+
+ # Flow-SDE
+ #std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level * torch.sqrt(-1*dt)
+ # prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
+
+ # Flow-CPS
+ std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2) # sigma_t in paper
+ pred_original_sample = sample - sigma * model_output # predicted x_0 in paper
+ noise_estimate = sample + model_output * (1 - sigma) # predicted x_1 in paper
+ prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2)
+ # import pdb; pdb.set_trace()
+
+ if prev_sample is None:
+ variance_noise = randn_tensor(
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=model_output.dtype,
+ )
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
+
+    # Drop the Gaussian normalizing constants and the 1/(2*std_dev_t**2) scaling:
+    # this is a negative squared error rather than an exact log-density.
+ log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
+
+ # mean along all but batch dimension
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
\ No newline at end of file
diff --git a/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py b/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..1100e280c247293707be5bbe2c5f08a27c6793b9
--- /dev/null
+++ b/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length=512,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ if hasattr(text_encoder, "module"):
+ dtype = text_encoder.module.dtype
+ else:
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ text_input_ids=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ if hasattr(text_encoder, "module"):
+ dtype = text_encoder.module.dtype
+ else:
+ dtype = text_encoder.dtype
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if hasattr(text_encoders[0], "module"):
+ dtype = text_encoders[0].module.dtype
+ else:
+ dtype = text_encoders[0].dtype
+
+ pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoders[0],
+ tokenizer=tokenizers[0],
+ prompt=prompt,
+ device=device if device is not None else text_encoders[0].device,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
+ )
+
+ prompt_embeds = _encode_prompt_with_t5(
+ text_encoder=text_encoders[1],
+ tokenizer=tokenizers[1],
+ max_sequence_length=max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device if device is not None else text_encoders[1].device,
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
+ )
+
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
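+
+
+# --- Hedged usage sketch (illustrative only, not used by the training code) ---
+# The checkpoint id and subfolder names below are assumptions about the usual Flux layout.
+if __name__ == "__main__":
+    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+    repo = "black-forest-labs/FLUX.1-dev"  # assumed checkpoint layout
+    tokenizers = [
+        CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
+        T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2"),
+    ]
+    text_encoders = [
+        CLIPTextModel.from_pretrained(repo, subfolder="text_encoder"),
+        T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2"),
+    ]
+    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+        text_encoders, tokenizers, prompt="a photo of a corgi", max_sequence_length=512
+    )
+    print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)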
diff --git a/adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py b/adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd908170bbc2d773109cebb799b2a4d1042335e6
--- /dev/null
+++ b/adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ text_input_ids=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ clip_tokenizers = tokenizers[:2]
+ clip_text_encoders = text_encoders[:2]
+
+ clip_prompt_embeds_list = []
+ clip_pooled_prompt_embeds_list = []
+ for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ device=device if device is not None else text_encoder.device,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
+ )
+ clip_prompt_embeds_list.append(prompt_embeds)
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
+
+ t5_prompt_embed = _encode_prompt_with_t5(
+ text_encoders[-1],
+ tokenizers[-1],
+ max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
+ device=device if device is not None else text_encoders[-1].device,
+ )
+
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+
+ return prompt_embeds, pooled_prompt_embeds
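+
+
+# --- Hedged usage sketch (illustrative only) ---
+# SD3 expects two CLIP encoders plus a T5 encoder; the checkpoint id and subfolder
+# names below are assumptions about the usual diffusers layout.
+if __name__ == "__main__":
+    from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+    repo = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed checkpoint
+    tokenizers = [
+        CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
+        CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2"),
+        T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_3"),
+    ]
+    text_encoders = [
+        CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder"),
+        CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2"),
+        T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_3"),
+    ]
+    prompt_embeds, pooled_prompt_embeds = encode_prompt(
+        text_encoders, tokenizers, prompt="a photo of a corgi", max_sequence_length=256
+    )
+    print(prompt_embeds.shape, pooled_prompt_embeds.shape)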
diff --git a/adv_grpo/diffusers_patch/wan_pipeline_with_logprob.py b/adv_grpo/diffusers_patch/wan_pipeline_with_logprob.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5b5f99a36db6a8c5317ad32d5bb44976678156b
--- /dev/null
+++ b/adv_grpo/diffusers_patch/wan_pipeline_with_logprob.py
@@ -0,0 +1,373 @@
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+import torch
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+from diffusers.utils.torch_utils import randn_tensor
+# WanPipelineOutput is used in the return value below; this import path matches recent
+# diffusers releases (an assumption; adjust it if your installed version differs).
+from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
+import math
+import numpy as np
+# import logger
+
+def sde_step_with_logprob(
+ self: UniPCMultistepScheduler,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ prev_sample: Optional[torch.FloatTensor] = None,
+ generator: Optional[torch.Generator] = None,
+ determistic: bool = False,
+ return_pixel_log_prob: bool = False,
+ return_dt_and_std_dev_t: bool = False
+):
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
+ process from the learned model outputs (most often the predicted velocity).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned flow model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ """
+    # bf16 can overflow when computing prev_sample_mean, so convert all variables to fp32.
+ model_output=model_output.float()
+ sample=sample.float()
+ if prev_sample is not None:
+ prev_sample=prev_sample.float()
+
+ step_index = [self.index_for_timestep(t) for t in timestep]
+ prev_step_index = [step+1 for step in step_index]
+
+ self.sigmas = self.sigmas.to(sample.device)
+ sigma = self.sigmas[step_index].view(-1, 1, 1, 1, 1)
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1, 1)
+ sigma_max = self.sigmas[1].item()
+ sigma_min = self.sigmas[-1].item()
+ dt = sigma_prev - sigma
+
+ std_dev_t = sigma_min + (sigma_max - sigma_min) * sigma
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
+
+ if prev_sample is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
+ " `prev_sample` stays `None`."
+ )
+
+ if prev_sample is None:
+ variance_noise = randn_tensor(
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=model_output.dtype,
+ )
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
+
+ # No noise is added during evaluation
+ if determistic:
+ prev_sample = sample + dt * model_output
+
+ log_prob = (
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
+ )
+
+ # mean along all but batch dimension
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+ if return_dt_and_std_dev_t:
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t, torch.sqrt(-1*dt)
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
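+# Note: unlike the SD3 patch, when `return_dt_and_std_dev_t` is False the fourth return value is
+# already the full per-step standard deviation std_dev_t * sqrt(-dt). With determistic=True the
+# stochastic update is replaced by the plain Euler step `sample + dt * model_output` (no noise).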
+
+def wan_pipeline_with_logprob(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ determistic: bool = False,
+ kl_reward: float = 0.0,
+ return_pixel_log_prob: bool = False,
+):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+            The prompt or prompts to guide the video generation. If not defined, one has to pass
+            `prompt_embeds` instead.
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ The dtype to use for the torch.amp.autocast.
+
+ Examples:
+
+    Returns:
+        [`~WanPipelineOutput`] or `tuple`:
+            If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` whose first
+            element is the generated video. In both cases the trajectory latents, per-step log-probabilities,
+            and per-step KL terms are also returned.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ print(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ all_latents = [latents]
+ all_log_probs = []
+ all_kl = []
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ # print(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latents_ori = latents.clone()
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.to(prompt_embeds.dtype)
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0),
+ latents.float(),
+ determistic=determistic,
+ return_pixel_log_prob=return_pixel_log_prob
+ )
+ prev_latents = latents.clone()
+
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+            # Use the KL reward only during stochastic sampling (not deterministic evaluation).
+ if kl_reward>0 and not determistic:
+ latent_model_input = torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori
+ with self.transformer.disable_adapter():
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.to(prompt_embeds.dtype)
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ _, ref_log_prob, ref_prev_latents_mean, ref_std_dev_t = sde_step_with_logprob(
+ self.scheduler,
+ noise_pred.float(),
+ t.unsqueeze(0),
+ latents_ori.float(),
+ prev_sample=prev_latents.float(),
+ determistic=determistic,
+ )
+                assert torch.allclose(std_dev_t, ref_std_dev_t)
+ kl = (prev_latents_mean - ref_prev_latents_mean)**2 / (2 * std_dev_t**2)
+ kl = kl.mean(dim=tuple(range(1, kl.ndim)))
+ all_kl.append(kl)
+ else:
+                # No KL reward requested: skip the reference pass and append a zero placeholder.
+ all_kl.append(torch.zeros(len(latents), device=latents.device))
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # if XLA_AVAILABLE:
+ # xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video, all_latents, all_log_probs, all_kl)
+
+ return WanPipelineOutput(frames=video), all_latents, all_log_probs, all_kl
diff --git a/adv_grpo/diffusers_patch/wan_prompt_embedding.py b/adv_grpo/diffusers_patch/wan_prompt_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc73b64fea6889054ec5cf72557bfa3e26847a3d
--- /dev/null
+++ b/adv_grpo/diffusers_patch/wan_prompt_embedding.py
@@ -0,0 +1,97 @@
+import torch
+from typing import Any, Callable, Dict, List, Optional, Union
+
+def _get_t5_prompt_embeds(
+ text_encoder,
+ tokenizer,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 226,
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+):
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+def encode_prompt(
+ text_encoder,
+ tokenizer,
+ prompt: Union[str, List[str]],
+ max_sequence_length: int = 226,
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+    r"""
+    Encodes the prompt into text encoder hidden states.
+
+    Args:
+        prompt (`str` or `List[str]`):
+            The prompt or prompts to be encoded.
+        max_sequence_length (`int`, *optional*, defaults to `226`):
+            Maximum sequence length used by the T5 tokenizer.
+        num_videos_per_prompt (`int`, *optional*, defaults to 1):
+            Number of videos that should be generated per prompt.
+        device: (`torch.device`, *optional*):
+            torch device to place the resulting embeddings on.
+        dtype: (`torch.dtype`, *optional*):
+            torch dtype of the resulting embeddings.
+    """
+ device = text_encoder[0].device
+ dtype = text_encoder[0].dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ prompt_embeds = _get_t5_prompt_embeds(
+ text_encoder=text_encoder[0],
+ tokenizer=tokenizer[0],
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds
\ No newline at end of file
diff --git a/adv_grpo/ema.py b/adv_grpo/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcdfc2caf0f1b4e11c85965f7747bdc4f52833cb
--- /dev/null
+++ b/adv_grpo/ema.py
@@ -0,0 +1,88 @@
+# Copied from another repo, but I can't remember exactly which one.
+
+from collections.abc import Iterable
+
+import torch
+
+
+class EMAModuleWrapper:
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter],
+ decay: float = 0.9999,
+ update_step_interval: int = 1,
+ device: torch.device | None = None,
+ ):
+ parameters = list(parameters)
+ self.ema_parameters = [p.clone().detach().to(device) for p in parameters]
+
+ self.temp_stored_parameters = None
+
+ self.decay = decay
+ self.update_step_interval = update_step_interval
+ self.device = device
+
+ # TODO: add an automatic decay calculation based on this formula:
+ # The impact of the last n steps can be calculated as:
+ # impact = 1-(decay^n)
+ # The number of steps needed to reach a specific impact is:
+ # n = log_decay(1-impact)
+ # The decay needed to reach a specific impact after n steps is:
+ # decay = (1-impact)^(1/n)
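+        # Example: with decay = 0.9999, the most recent 10_000 steps account for an impact of
+        # 1 - 0.9999**10_000 ≈ 0.63 (since 0.9999**10_000 ≈ e**-1 ≈ 0.37).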
+
+ def get_current_decay(self, optimization_step) -> float:
+ return min(
+ (1 + optimization_step) / (10 + optimization_step),
+ self.decay
+ )
+
+ @torch.no_grad()
+ def step(self, parameters: Iterable[torch.nn.Parameter], optimization_step):
+ parameters = list(parameters)
+
+ one_minus_decay = 1 - self.get_current_decay(optimization_step)
+
+ if (optimization_step + 1) % self.update_step_interval == 0:
+ for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
+ if parameter.requires_grad:
+ if ema_parameter.device == parameter.device:
+ ema_parameter.add_(one_minus_decay * (parameter - ema_parameter))
+ else:
+ # in place calculations to save memory
+ parameter_copy = parameter.detach().to(ema_parameter.device)
+ parameter_copy.sub_(ema_parameter)
+ parameter_copy.mul_(one_minus_decay)
+ ema_parameter.add_(parameter_copy)
+ del parameter_copy
+
+ def to(self, device: torch.device = None, dtype: torch.dtype = None) -> None:
+ self.device = device
+ self.ema_parameters = [
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+ for p in self.ema_parameters
+ ]
+
+    def copy_ema_to(self, parameters: Iterable[torch.nn.Parameter], store_temp: bool = True) -> None:
+        # Materialize first so that a generator argument is not exhausted by the temp copy below.
+        parameters = list(parameters)
+        if store_temp:
+            self.temp_stored_parameters = [parameter.detach().cpu() for parameter in parameters]
+
+ for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
+ parameter.data.copy_(ema_parameter.to(parameter.device).data)
+
+ def copy_temp_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ for temp_parameter, parameter in zip(self.temp_stored_parameters, parameters, strict=True):
+ parameter.data.copy_(temp_parameter.data)
+
+ self.temp_stored_parameters = None
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ self.decay = self.decay if self.decay else state_dict.get("decay", self.decay)
+ self.ema_parameters = state_dict.get("ema_parameters")
+ self.to(self.device)
+
+ def state_dict(self) -> dict:
+ return {
+ "decay": self.decay,
+ "ema_parameters": self.ema_parameters,
+ }
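+
+
+# --- Hedged usage sketch (illustrative only) ---
+if __name__ == "__main__":
+    model = torch.nn.Linear(4, 4)
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
+    ema = EMAModuleWrapper(model.parameters(), decay=0.999, device=torch.device("cpu"))
+
+    for step in range(100):
+        loss = model(torch.randn(8, 4)).pow(2).mean()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        ema.step(model.parameters(), step)
+
+    # Swap in the EMA weights for evaluation, then restore the raw training weights.
+    ema.copy_ema_to(model.parameters(), store_temp=True)
+    ema.copy_temp_to(model.parameters())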
diff --git a/adv_grpo/imagereward_scorer.py b/adv_grpo/imagereward_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab192eb17ee54474c592e98a91945132365fab46
--- /dev/null
+++ b/adv_grpo/imagereward_scorer.py
@@ -0,0 +1,40 @@
+from transformers import AutoProcessor, AutoModel
+from PIL import Image
+import torch
+import ImageReward as RM
+
+class ImageRewardScorer(torch.nn.Module):
+ def __init__(self, device="cuda", dtype=torch.float32):
+ super().__init__()
+ self.model_path = "ImageReward-v1.0"
+ self.device = device
+ self.dtype = dtype
+ self.model = RM.load(self.model_path, device=device).eval().to(dtype=dtype)
+ self.model.requires_grad_(False)
+
+ @torch.no_grad()
+ def __call__(self, prompts, images):
+ rewards = []
+        for prompt, image in zip(prompts, images):
+ _, reward = self.model.inference_rank(prompt, [image])
+ rewards.append(reward)
+ return rewards
+
+# Usage example
+def main():
+ scorer = ImageRewardScorer(
+ device="cuda",
+ dtype=torch.float32
+ )
+
+ images=[
+ "astronaut.jpg",
+ ]
+ pil_images = [Image.open(img) for img in images]
+ prompts=[
+ 'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+ ]
+ print(scorer(prompts, pil_images))
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/adv_grpo/inflated_layers.py b/adv_grpo/inflated_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4eb14e9ae7637821db6bf58bd8149aca8b0b16b
--- /dev/null
+++ b/adv_grpo/inflated_layers.py
@@ -0,0 +1,305 @@
+from functools import partial
+from typing import Literal
+from einops import rearrange
+from torch import Tensor
+from torch.nn import ConvTranspose2d, ConvTranspose3d
+
+from flow_grpo.inflated_lib import (
+ MemoryState,
+ extend_head,
+ inflate_bias,
+ inflate_distribution_bias,
+ inflate_distribution_weight,
+ inflate_weight,
+ modify_state_dict,
+)
+from flow_grpo.conv_gradfix import GradFixConv2d, GradFixConv3d
+
+VERBOSE = False
+
+_inflation_mode_t = Literal["none", "flatten", "partial_flatten", "pad", "tile"]
+_direction_t = Literal["", "out", "in"]
+
+
+class InflatedCausalConv3d(GradFixConv3d):
+ """
+ Note:
+ To align the behavior of pretrained 2D models,
+ if you compose a video clip from a single image by:
+ - duplicating: set shape_norm = True
+ - padding zeros: set shape_norm = False
+ to avoid gaps in the beginning of training process.
+ """
+
+ def __init__(
+ self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs
+ ):
+ self.shape_norm = shape_norm
+ self.inflation_mode = inflation_mode
+ self.padding_bank = None
+ super().__init__(*args, **kwargs)
+ self.temporal_padding = self.padding[0]
+ self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal.
+
+ def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
+ bank_size = self.stride[0] - self.kernel_size[0]
+ padding_bank = (
+ input[:, :, bank_size:].detach()
+ if (bank_size != 0 and memory_state != MemoryState.DISABLED)
+ else None
+ )
+ if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE):
+ input = extend_head(input, memory=self.padding_bank)
+ else:
+ input = extend_head(input, times=self.temporal_padding * 2)
+ if memory_state != MemoryState.DISABLED and not self.training:
+ self.padding_bank = padding_bank
+ return super().forward(input)
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ if self.inflation_mode == "none":
+ super()._load_from_state_dict(
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+ else:
+ # NOTE: need to switch off strict
+ super()._load_from_state_dict(
+ modify_state_dict(
+ self,
+ state_dict,
+ prefix,
+ verbose=VERBOSE,
+ inflate_weight_fn=partial(inflate_weight, position="tail"),
+ inflate_bias_fn=partial(inflate_bias, position="tail"),
+ ),
+ prefix,
+ local_metadata,
+ False,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+
+
+class InflatedDistributionCausalConv3d(GradFixConv3d):
+ """
+ Note:
+ Direction:
+ - out: this layer generates mean/std of some distribution;
+ - in: this layer takes tensors sampled from output of `out` layer as input.
+ """
+
+ def __init__(
+ self,
+ *args,
+ direction: _direction_t,
+ inflation_mode: _inflation_mode_t,
+ shape_norm: bool = True,
+ **kwargs,
+ ):
+ self.shape_norm = shape_norm
+ self.inflation_mode = inflation_mode
+ self.direction = direction
+ self.padding_bank = None
+ super().__init__(*args, **kwargs)
+ self.temporal_padding = self.padding[0]
+ self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal.
+
+ def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
+ bank_size = self.stride[0] - self.kernel_size[0]
+ padding_bank = (
+ input[:, :, bank_size:].detach()
+ if (bank_size != 0 and memory_state != MemoryState.DISABLED)
+ else None
+ )
+ if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE):
+ input = extend_head(input, memory=self.padding_bank)
+ else:
+ input = extend_head(input, times=self.temporal_padding * 2)
+ if memory_state != MemoryState.DISABLED and not self.training:
+ self.padding_bank = padding_bank
+ return super().forward(input)
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ if self.inflation_mode == "none":
+ super()._load_from_state_dict(
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+ else:
+ super()._load_from_state_dict(
+ modify_state_dict(
+ self,
+ state_dict,
+ prefix,
+ verbose=VERBOSE,
+ inflate_weight_fn=partial(
+ inflate_distribution_weight, direction=self.direction, position="tail"
+ ),
+ inflate_bias_fn=partial(
+ inflate_distribution_bias, direction=self.direction, position="tail"
+ ),
+ ),
+ prefix,
+ local_metadata,
+ False,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+
+
+class InflatedConvTranspose3d(ConvTranspose3d):
+ # Note: It's not a causal one.
+ def __init__(
+ self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs
+ ):
+ self.shape_norm = shape_norm
+ self.inflation_mode = inflation_mode
+ super().__init__(*args, **kwargs)
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ if self.inflation_mode == "none":
+ super()._load_from_state_dict(
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+ else:
+ # NOTE: need to switch off strict
+ super()._load_from_state_dict(
+ modify_state_dict(
+ self,
+ state_dict,
+ prefix,
+ verbose=VERBOSE,
+ inflate_weight_fn=partial(inflate_weight, position="center"),
+ inflate_bias_fn=partial(inflate_bias, position="center"),
+ ),
+ prefix,
+ local_metadata,
+ False,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+
+
+class FlattenedConvTranspose3d(ConvTranspose2d):
+ def forward(self, input: Tensor, **kwargs) -> Tensor:
+ output = rearrange(input, "b c f h w -> (b f) c h w")
+ output = super().forward(output)
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2))
+ return output
+
+
+class FlattenedConv3d(GradFixConv2d):
+ def forward(self, input: Tensor, **kwargs) -> Tensor:
+ output = rearrange(input, "b c f h w -> (b f) c h w")
+ output = super().forward(output)
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2))
+ return output
+
+
+def init_causal_conv3d(
+ *args,
+ inflation_mode: _inflation_mode_t,
+ direction: _direction_t = "",
+ partial_switch: bool = False,
+ **kwargs,
+):
+ """
+ Initialize a Causal-3D convolution layer.
+ Parameters:
+ inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have.
+ - none: No inflation will be conducted.
+ The loading logic of state dict will fall back to default.
+ - flatten: It will produce a `fake` 3D layer,
+ which simply squeeze the axis of batch size and depth together,
+ and then conduct 2D convolution.
+ - partial_flatten:
+ - layers with `partial_switch` on: using `none` mode.
+ - layers with `partial_switch` off: using `flatten` mode.
+ - pad / tile: Refer to the definition of `InflatedCausalConv3d`.
+ direction:
+ - empty string: Ordinary causal convolution layer.
+ - out / in: Refer to the definition of `InflatedDistributionCausalConv3d`.
+ partial_switch: Only works when `inflation_mode` is `partial_flatten`.
+ """
+ stride = kwargs.get("stride", args[3] if len(args) > 3 else None)
+ padding = kwargs.get("padding", args[4] if len(args) > 4 else None)
+ if "flatten" in inflation_mode:
+ if (
+ (
+ (not stride)
+ or isinstance(stride, int)
+                or (isinstance(stride, (list, tuple)) and len(stride) < 3)
+ ) # if the config of stride can be used for 2D conv
+ and (
+ (not padding)
+ or isinstance(padding, int)
+                or (isinstance(padding, (list, tuple)) and len(padding) < 3)
+ ) # if the config of padding can be used for 2D conv
+ and (("partial" not in inflation_mode) or (not partial_switch))
+ # if it's fully-flatten mode, or with `partial_switch` off
+ ):
+ return FlattenedConv3d(*args, **kwargs)
+ else:
+ return InflatedCausalConv3d(*args, inflation_mode="none", **kwargs)
+ # Force-override
+ else:
+ if direction:
+ return InflatedDistributionCausalConv3d(
+ *args, direction=direction, inflation_mode=inflation_mode, **kwargs
+ )
+ else:
+ return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs)
+
+
+def init_transposed_conv3d(
+ *args, inflation_mode: _inflation_mode_t, partial_switch: bool = False, **kwargs
+):
+ stride = kwargs.get("stride", args[3] if len(args) > 3 else None)
+ padding = kwargs.get("padding", args[4] if len(args) > 4 else None)
+ if "flatten" in inflation_mode:
+ if (
+ (
+ (not stride)
+ or isinstance(stride, int)
+                or (isinstance(stride, (list, tuple)) and len(stride) < 3)
+ )
+ and (
+ (not padding)
+ or isinstance(padding, int)
+                or (isinstance(padding, (list, tuple)) and len(padding) < 3)
+ )
+ or (("partial" in inflation_mode) and not partial_switch)
+ ):
+ return FlattenedConvTranspose3d(*args, **kwargs)
+ else:
+ return InflatedConvTranspose3d(
+ *args, inflation_mode="none", **kwargs
+ ) # Force-override
+ else:
+ return InflatedConvTranspose3d(*args, inflation_mode=inflation_mode, **kwargs)
\ No newline at end of file
diff --git a/adv_grpo/inflated_lib.py b/adv_grpo/inflated_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..589ffb773351ade6ed806137e8b3df03b11bde27
--- /dev/null
+++ b/adv_grpo/inflated_lib.py
@@ -0,0 +1,346 @@
+import math
+from enum import Enum
+from typing import Optional
+import numpy as np
+import torch
+from diffusers.models.attention_processor import SpatialNorm
+from diffusers.models.normalization import RMSNorm
+from einops import rearrange
+from torch import Tensor, nn
+
+# from common.logger import get_logger
+
+# logger = get_logger(__name__)
+
+
+class MemoryState(Enum):
+ """
+ State[Disabled]: No memory bank will be enabled.
+ State[Initializing]: The model is handling the first clip,
+ need to reset / initialize the memory bank.
+ State[Active]: There has been some data in the memory bank.
+ """
+
+ DISABLED = 0
+ INITIALIZING = 1
+ ACTIVE = 2
+
+
+def norm_wrapper(
+ norm_layer: nn.Module,
+ x: torch.Tensor,
+ y: Optional[torch.Tensor] = None,
+ keep_causal: bool = False,
+) -> torch.Tensor:
+ if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)):
+ if x.ndim == 4:
+ x = rearrange(x, "b c h w -> b h w c")
+ x = norm_layer(x)
+ x = rearrange(x, "b h w c -> b c h w")
+ return x
+ if x.ndim == 5:
+ x = rearrange(x, "b c t h w -> b t h w c")
+ x = norm_layer(x)
+ x = rearrange(x, "b t h w c -> b c t h w")
+ return x
+ if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
+ if x.ndim <= 4 or (not keep_causal and not isinstance(norm_layer, nn.BatchNorm2d)):
+ return norm_layer(x)
+ if x.ndim == 5:
+ t = x.size(2)
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = norm_layer(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+ if isinstance(norm_layer, SpatialNorm):
+ t = -1
+ if x.ndim == 5:
+ t = x.size(2)
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ if y.ndim == 5:
+ y = rearrange(y, "b c t h w -> (b t) c h w")
+ if x.ndim != 4 or y.ndim != 4:
+ raise NotImplementedError
+ x = norm_layer(x, y)
+ if t != -1:
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+ raise NotImplementedError
+
+
+def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
+ """
+ Remove duplicated first frame features in the up-sampling process.
+ """
+ if times == 0:
+ return tensor
+ return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
+
+
+def extend_head(
+ tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None
+) -> Tensor:
+ """
+ When memory is None:
+ - Duplicate first frame features in the down-sampling process.
+ When memory is not None:
+ - Concatenate memory features with the input features to keep temporal consistency.
+ """
+ if times == 0:
+ return tensor
+ if memory is not None:
+ return torch.cat((memory.to(tensor), tensor), dim=2)
+ else:
+ tile_repeat = np.ones(tensor.ndim).astype(int)
+ tile_repeat[2] = times
+ return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2)
+
+
+def fill_weight_in_depth(weight: torch.Tensor, source: torch.Tensor, position: str):
+ """
+ Inflate a 2D convolution weight matrix to a 3D one by padding zeros in the channel of depth.
+ Parameters:
+ weight: The weight parameters of 3D conv kernel to be initialized.
+ source: The weight parameters of 2D conv kernel to be inflated.
+ position: Where to insert the 2D weights, can be chosen from
+            - tail: Pad zeros in front of the 2D kernel. Used for causal inflation.
+ - center: Pad zeros around the 2D kernel. Used for normal inflation.
+ """
+ assert position in ["tail", "center"], "Unsupported fill-in position for weight inflation."
+ depth = weight.size(2)
+ weight.fill_(0.0)
+ if position == "center":
+ if depth % 2 == 1:
+ weight[:, :, depth // 2].copy_(source.squeeze(2))
+ else:
+ weight[:, :, depth // 2].copy_(source.squeeze(2) / 2.0)
+ weight[:, :, depth // 2 - 1].copy_(source.squeeze(2) / 2.0)
+ else:
+ if depth % 2 == 1:
+ weight[:, :, -1].copy_(source.squeeze(2))
+ else:
+ weight[:, :, -1].copy_(source.squeeze(2) / 2.0)
+ weight[:, :, -2].copy_(source.squeeze(2) / 2.0)
+ return weight
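+# A minimal doctest-style check (illustrative assumption, not part of the library API):
+# inflating a 1x1 2D kernel of ones into depth 3 with position="tail" leaves only the
+# last depth slice non-zero, which is what keeps the inflated convolution causal.
+# >>> w3 = torch.zeros(1, 1, 3, 1, 1)
+# >>> fill_weight_in_depth(w3, torch.ones(1, 1, 1, 1, 1), "tail")[0, 0, :, 0, 0].tolist()
+# [0.0, 0.0, 1.0]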
+
+
+def inflate_weight(
+ weight_2d: torch.Tensor,
+ weight_3d: torch.Tensor,
+ shape_norm: bool,
+ name: str,
+ inflation_mode: str,
+ position: str,
+ verbose: bool = True,
+):
+ """
+ Inflate a 2D convolution weight matrix to a 3D one.
+ Parameters:
+ weight_2d: The weight matrix of 2D conv to be inflated.
+ weight_3d: The weight matrix of 3D conv to be initialized.
+ inflation_mode: the mode of inflation
+ - pad: pad zeros around 2D kernel.
+ - tile: tile 2D kernel along the depth axis.
+
+ shape_norm: Whether to scale the parameters of 2D kernel so that the untrained
+ inflated model behaves exactly the same as the original 2D model
+                    in the reconstruction of images and videos. Recommended to keep it on.
+
+ name: The name of inflated module. Only be used in logging.
+ position: Refer to the doc of `fill_weight_in_depth`.
+ Only works when `inflation_mode` is `pad`.
+ verbose: Whether to log information about inflation.
+ """
+ assert inflation_mode in ["pad", "tile"]
+ depth = weight_3d.size(2)
+ tgt_out, tgt_in = weight_3d.size()[:2]
+ src_out, src_in = weight_2d.size()[:2]
+ assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0)
+ out_fan, in_fan = tgt_out // src_out, tgt_in // src_in
+ depth_factor = 1 if inflation_mode == "pad" else depth
+ factor = (depth_factor * math.sqrt(out_fan) * math.sqrt(in_fan)) if shape_norm else 1
+ with torch.no_grad():
+ channel_inflation = weight_2d.unsqueeze(2).repeat(out_fan, in_fan, 1, 1, 1) / factor
+ if inflation_mode == "tile":
+ weight_3d.copy_(channel_inflation.repeat(1, 1, depth, 1, 1))
+ else:
+ weight_3d = fill_weight_in_depth(weight_3d, channel_inflation, position)
+ if verbose:
+ print(
+ f"*** {name}weight {weight_2d.size()} is inflated to {weight_3d.size()} ***"
+ )
+ return weight_3d
+
+
+def inflate_bias(
+ bias_2d: torch.Tensor,
+ bias_3d: torch.Tensor,
+ shape_norm: bool,
+ name: str,
+ inflation_mode: str,
+ position: str,
+ verbose: bool = True,
+):
+ """
+ Inflate a 2D convolution bias tensor to a 3D one
+ Parameters:
+ bias_2d: The bias tensor of 2D conv to be inflated.
+ bias_3d: The bias tensor of 3D conv to be initialized.
+ shape_norm: Refer to `inflate_weight` function.
+ name: The name of inflated module. Only be used in logging.
+ inflation_mode: Placeholder to align `inflate_weight`.
+ position: Placeholder to align `inflate_weight`.
+ verbose: Whether to log information about inflation.
+ """
+ tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0)
+ assert tgt_ch % src_ch == 0
+ fan = tgt_ch // src_ch
+ factor = math.sqrt(fan) if shape_norm else 1
+ with torch.no_grad():
+ bias_3d.copy_(bias_2d.repeat(fan) / factor)
+ if (tgt_ch != src_ch) and verbose:
+ print(f"*** {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***")
+ return bias_3d
+
+
+def inflate_distribution_weight(
+ weight_2d: torch.Tensor,
+ weight_3d: torch.Tensor,
+ shape_norm: bool,
+ name: str,
+ direction: str,
+ inflation_mode: str,
+ position: str,
+ verbose: bool = True,
+):
+ """
+ Inflate a 2D convolution weight matrix to a 3D one.
+ Note: Different from `inflate_weight`,
+ it's designed for `quant_conv` or `post_quant_conv` layers.
+ i.e., a convolution layer used to produce `mean` and `std` of some distribution,
+ or its subsequent layer.
+ Parameters: Refer to `inflate_weight`.
+ direction:
+ - out: this layer generates `mean` and `std` of some distribution.
+ - in: this layer takes tensors sampled from output of `out` layer as input.
+ """
+ assert inflation_mode in ["pad", "tile"]
+ depth = weight_3d.size(2)
+ tgt_out, tgt_in = weight_3d.size()[:2]
+ src_out, src_in = weight_2d.size()[:2]
+ assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0)
+ out_fan, in_fan = tgt_out // src_out, tgt_in // src_in
+ depth_factor = 1 if inflation_mode == "pad" else depth
+ if direction == "out":
+ factor = (depth_factor * math.sqrt(in_fan)) if shape_norm else 1
+ with torch.no_grad():
+ in_inflation = weight_2d.unsqueeze(2).repeat(1, in_fan, 1, 1, 1) / factor
+ # [src_out, src_in, k_h, k_w] -> [src_out, tgt_in, 1, k_h, k_w]
+ out_mean_weight, out_std_weight = torch.chunk(in_inflation, 2, dim=0)
+ mean_slice = slice(src_out // 2)
+ std_slice = slice(tgt_out // 2, tgt_out // 2 + src_out // 2)
+ if inflation_mode == "tile":
+ weight_3d[mean_slice] = out_mean_weight
+ weight_3d[std_slice] = out_std_weight
+ # Other part will be randomly initialized.
+ else:
+ weight_3d[mean_slice] = fill_weight_in_depth(
+ weight_3d[mean_slice], out_mean_weight, position
+ )
+ weight_3d[std_slice] = fill_weight_in_depth(
+ weight_3d[std_slice], out_std_weight, position
+ )
+ # Other part will be randomly initialized.
+ elif direction == "in":
+ factor = (depth_factor * math.sqrt(out_fan)) if shape_norm else 1
+ with torch.no_grad():
+ out_inflation = weight_2d.unsqueeze(2).repeat(out_fan, 1, 1, 1, 1) / factor
+ # [src_out, src_in, k_h, k_w] -> [tgt_out, src_in, 1, k_h, k_w]
+ if inflation_mode == "tile":
+ weight_3d[:, :src_in] = out_inflation
+ else:
+ weight_3d[:, :src_in] = fill_weight_in_depth(
+ weight_3d[:, :src_in], out_inflation, position
+ )
+ weight_3d[:, src_in:].fill_(0.0)
+ else:
+ raise NotImplementedError
+ if verbose:
+ print(
+ f"*** [Distribution] {name}weight {weight_2d.size()} "
+ f"is inflated to {weight_3d.size()} ***"
+ )
+ return weight_3d
+
+
+def inflate_distribution_bias(
+ bias_2d: torch.Tensor,
+ bias_3d: torch.Tensor,
+ shape_norm: bool,
+ name: str,
+ direction: str,
+ inflation_mode: str,
+ position: str,
+ verbose: bool = True,
+):
+ """
+ The combination of `inflate_distribution_weight` and `inflate_bias`.
+ """
+ tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0)
+ assert tgt_ch % src_ch == 0
+ if direction == "out":
+ with torch.no_grad():
+ out_mean_bias, out_std_bias = torch.chunk(bias_2d, 2, dim=0)
+ bias_3d[: src_ch // 2] = out_mean_bias
+ bias_3d[tgt_ch // 2 : tgt_ch // 2 + src_ch // 2] = out_std_bias
+ elif direction == "in":
+ with torch.no_grad():
+ bias_3d[:src_ch] = bias_2d
+ bias_3d[src_ch:].fill_(0.0)
+ else:
+ raise NotImplementedError
+ if verbose:
+ print(
+ f"*** [Distribution] {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***"
+ )
+ return bias_3d
+
+
+def modify_state_dict(
+ layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn, verbose=False
+):
+ """
+    The main function for inflating 2D parameters to 3D.
+ """
+ weight_name = prefix + "weight"
+ bias_name = prefix + "bias"
+ if weight_name in state_dict:
+ weight_2d = state_dict[weight_name]
+ if (
+ weight_2d.dim() == 4
+ ): # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
+ weight_3d = inflate_weight_fn(
+ weight_2d=weight_2d,
+ weight_3d=layer.weight,
+ shape_norm=layer.shape_norm,
+ name=prefix,
+ verbose=verbose,
+ inflation_mode=layer.inflation_mode,
+ )
+ state_dict[weight_name] = weight_3d
+ else:
+ return state_dict
+    # It's a 3D state dict; inflation should not be applied to either the weight or the bias.
+ if bias_name in state_dict:
+ bias_2d = state_dict[bias_name]
+ if bias_2d.dim() == 1: # Assuming the 2D biases are 1D tensors (out_channels,)
+ bias_3d = inflate_bias_fn(
+ bias_2d=bias_2d,
+ bias_3d=layer.bias,
+ shape_norm=layer.shape_norm,
+ name=prefix,
+ verbose=verbose,
+ inflation_mode=layer.inflation_mode,
+ )
+ state_dict[bias_name] = bias_3d
+ return state_dict
\ No newline at end of file
diff --git a/adv_grpo/ocr.py b/adv_grpo/ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..a973ef2f7daa12f4e7efadf6c5a998993b30078f
--- /dev/null
+++ b/adv_grpo/ocr.py
@@ -0,0 +1,138 @@
+from paddleocr import PaddleOCR
+import torch
+import numpy as np
+from Levenshtein import distance
+from typing import List, Union, Tuple
+from PIL import Image
+
+class OcrScorer:
+ def __init__(self, use_gpu: bool = False):
+ """
+ OCR reward calculator
+ :param use_gpu: Whether to use GPU acceleration for PaddleOCR
+ """
+ self.ocr = PaddleOCR(
+ use_angle_cls=False,
+ lang="en",
+ use_gpu=use_gpu,
+ show_log=False # Disable unnecessary log output
+ )
+
+ @torch.no_grad()
+ def __call__(self,
+ images: Union[List[Image.Image], List[np.ndarray]],
+                 prompts: List[str]) -> List[float]:
+ """
+ Calculate OCR reward
+ :param images: List of input images (PIL or numpy format)
+ :param prompts: Corresponding target text list
+        :return: List of per-image rewards
+ """
+ # import pdb; pdb.set_trace()
+ prompts = [prompt.split('"')[1] for prompt in prompts]
+ rewards = []
+ # Ensure input lengths are consistent
+ assert len(images) == len(prompts), "Images and prompts must have the same length"
+ for img, prompt in zip(images, prompts):
+ # Convert image format
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+
+ try:
+ # OCR recognition
+ result = self.ocr.ocr(img, cls=False)
+ # Extract recognized text (handle possible multi-line results)
+ recognized_text = ''.join([res[1][0] if res[1][1] > 0 else '' for res in result[0]]) if result[0] else ''
+
+ recognized_text = recognized_text.replace(' ', '').lower()
+ prompt = prompt.replace(' ', '').lower()
+ if prompt in recognized_text:
+ dist = 0
+ else:
+ dist = distance(recognized_text, prompt)
+ # import pdb; pdb.set_trace()
+                # If many unrelated characters were recognized, cap the penalty at the prompt length
+ if dist > len(prompt):
+ dist = len(prompt)
+
+ except Exception as e:
+ # Error handling (e.g., OCR parsing failure)
+ print(f"OCR processing failed: {str(e)}")
+ dist = len(prompt) # Maximum penalty
+ reward = 1-dist/(len(prompt))
+ rewards.append(reward)
+
+ return rewards
+
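+# --- Hedged worked example of the OCR reward above (illustration only) ---
+# The reward is 1 - min(edit_distance, len(prompt)) / len(prompt), computed on
+# lower-cased, whitespace-stripped strings: a perfect match scores 1.0 and
+# completely unrelated text bottoms out at 0.0.
+def _ocr_reward_example():
+    prompt, recognized = "hello", "helxo"
+    dist = min(distance(recognized, prompt), len(prompt))  # Levenshtein distance -> 1
+    return 1 - dist / len(prompt)                          # -> 0.8
+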
+class OcrScorer_video_or_image:
+ def __init__(self, use_gpu: bool = False):
+ """
+ OCR reward calculator
+ :param use_gpu: Whether to use GPU acceleration for PaddleOCR
+ """
+ self.ocr = PaddleOCR(
+ use_angle_cls=False,
+ lang="en",
+ use_gpu=use_gpu,
+ show_log=False # Disable unnecessary log output
+ )
+ self.frame_interval = 4
+
+ @torch.no_grad()
+    def __call__(self, images: Union[List[Image.Image], List[np.ndarray]], prompts: List[str]) -> List[float]:
+ """
+ :param images: List of images or videos (each video as np.ndarray of shape [F, H, W, C])
+ :param prompts: List of prompts containing target text
+        :return: List of OCR rewards, averaged over the sampled frames
+ """
+ prompts = [prompt.split('"')[1] for prompt in prompts]
+ assert len(images) == len(prompts), "Mismatch between images and prompts."
+
+ rewards = []
+ for img, prompt in zip(images, prompts):
+ prompt = prompt.replace(' ', '').lower()
+ frame_rewards = []
+
+ # Handle video: shape (F, H, W, C)
+ if isinstance(img, np.ndarray) and img.ndim == 4:
+ sampled_frames = img[::self.frame_interval]
+ else:
+ sampled_frames = [img]
+
+ for frame in sampled_frames:
+ region = None
+ if isinstance(frame, Image.Image):
+ frame = np.array(frame)
+ try:
+ result = self.ocr.ocr(frame, cls=False)
+ text = ''.join([res[1][0] if res[1][1] > 0 else '' for res in result[0]]) if result[0] else ''
+ text = text.replace(' ', '').lower()
+
+ dist = distance(text, prompt)
+ dist = min(dist, len(prompt))
+
+ except Exception as e:
+ print(f"OCR failed on frame: {e}")
+ dist = len(prompt)
+
+ reward = 1 - dist / len(prompt)
+ if reward > 0:
+ frame_rewards.append(reward)
+
+ if frame_rewards:
+ rewards.append(sum(frame_rewards) / len(frame_rewards))
+ else:
+ rewards.append(0.0)
+
+ return rewards
+
+if __name__ == "__main__":
+ example_image_path = "media_images_eval_images_499_ef42de47b8ec98892954.jpg"
+ example_image = Image.open(example_image_path)
+ example_prompt = 'New York Skyline with "Hello World" written with fireworks on the sky'
+ # Instantiate scorer
+ scorer = OcrScorer(use_gpu=False)
+
+ # Call scorer and print result
+ reward = scorer([example_image], [example_prompt])
+ print(f"OCR Reward: {reward}")
\ No newline at end of file
diff --git a/adv_grpo/pick_score_training.py b/adv_grpo/pick_score_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6782c978cda6e5331c0a9e177b311154573d01
--- /dev/null
+++ b/adv_grpo/pick_score_training.py
@@ -0,0 +1,385 @@
+import torch
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+from torch.utils.data import DataLoader
+
+# ====== CLIPCriterion reused from an existing implementation ======
+from dataclasses import dataclass
+from torch.nn.modules.loss import _Loss
+from torch.utils.data import Dataset, DataLoader
+import os
+import json
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+
+def evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device, max_eval=100):
+ """
+    Quick evaluation: take the first max_eval Qwen-vs-SD3 pairs and report the average scores.
+ """
+ model.eval()
+ if hasattr(model, "module"): # DDP 情况
+ model = model.module
+
+ with open(json_file, "r") as f:
+ prompt2img = json.load(f)
+
+ prompts = list(prompt2img.keys())[:max_eval]
+
+ qwen_scores, sd3_scores = [], []
+
+ for prompt in prompts:
+ filename = prompt2img[prompt]
+ qwen_img_path = os.path.join(qwen_dir, filename)
+ sd3_img_path = os.path.join(sd3_dir, filename)
+
+ if not (os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path)):
+ continue
+
+ qwen_img = Image.open(qwen_img_path).convert("RGB")
+ sd3_img = Image.open(sd3_img_path).convert("RGB")
+
+        # Text & image inputs
+ text_inputs = processor.tokenizer(
+ prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77
+ ).to(device)
+ qwen_inputs = processor(images=qwen_img, return_tensors="pt").to(device)
+ sd3_inputs = processor(images=sd3_img, return_tensors="pt").to(device)
+
+ with torch.no_grad():
+ text_features = model.get_text_features(**text_inputs)
+ qwen_features = model.get_image_features(**qwen_inputs)
+ sd3_features = model.get_image_features(**sd3_inputs)
+
+        # Normalize
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+ qwen_features = qwen_features / qwen_features.norm(dim=-1, keepdim=True)
+ sd3_features = sd3_features / sd3_features.norm(dim=-1, keepdim=True)
+
+        # Similarity scores
+ logit_scale = model.logit_scale.exp()
+ qwen_score = (logit_scale * (text_features @ qwen_features.T)).item()
+ sd3_score = (logit_scale * (text_features @ sd3_features.T)).item()
+
+ qwen_scores.append(qwen_score)
+ sd3_scores.append(sd3_score)
+
+ model.train()
+ if len(qwen_scores) > 0:
+ print(f"[Eval] Qwen avg={sum(qwen_scores)/len(qwen_scores):.4f} "
+ f"| SD3 avg={sum(sd3_scores)/len(sd3_scores):.4f}")
+
+
+@dataclass
+class CLIPCriterionConfig:
+ _target_: str = "trainer.criterions.clip_criterion.CLIPCriterion"
+    is_distributed: bool = False  # keep off for local runs
+ label_0_column_name: str = "label_0"
+ label_1_column_name: str = "label_1"
+ input_ids_column_name: str = "input_ids"
+ pixels_0_column_name: str = "pixels_0"
+ pixels_1_column_name: str = "pixels_1"
+ num_examples_per_prompt_column_name: str = "num_examples_per_prompt"
+ in_batch_negatives: bool = False
+
+
+class CLIPCriterion(_Loss):
+ def __init__(self, cfg: CLIPCriterionConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ @staticmethod
+ def get_features(model, input_ids, pixels_0_values, pixels_1_values):
+ # import pdb; pdb.set_trace()
+ # if hasattr(model, "module"):
+ # model = model.module
+ all_pixel_values = torch.cat([pixels_0_values, pixels_1_values], dim=0)
+ # text_features, all_image_features = model(text_inputs=input_ids, image_inputs=all_pixel_values)
+ text_features = model.get_text_features(input_ids=input_ids)
+ all_image_features = model.get_image_features(pixel_values=all_pixel_values)
+ all_image_features = all_image_features / all_image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+ image_0_features, image_1_features = all_image_features.chunk(2, dim=0)
+ return image_0_features, image_1_features, text_features
+
+ @staticmethod
+ def gather_features(features):
+ all_features = torch.cat(torch.distributed.nn.all_gather(features), dim=0)
+ return all_features
+
+ # def safe_sync(self, msg):
+ # torch.cuda.synchronize()
+ # print(f"[Rank {dist.get_rank()}] OK at {msg}")
+
+ def calc_loss(
+ self,
+ text_features,
+ image_0_features,
+ image_1_features,
+ logit_scale,
+ label_0,
+ label_1,
+ num_examples_per_prompt,
+ *args,
+ **kwargs
+ ):
+ # self.safe_sync("start")
+
+ device = image_0_features.device
+
+ # gather features
+ if self.cfg.is_distributed:
+ image_0_features = self.gather_features(image_0_features)
+ image_1_features = self.gather_features(image_1_features)
+ text_features = self.gather_features(text_features)
+ label_0 = self.gather_features(label_0)
+ label_1 = self.gather_features(label_1)
+ num_examples_per_prompt = self.gather_features(num_examples_per_prompt)
+
+ # calc logits # TODO use local loss as open-clip does
+ all_image_features = torch.cat([image_0_features, image_1_features], dim=0) # (2 * batch_size, dim)
+ logits_per_image = logit_scale * all_image_features @ text_features.T
+ image_0_logits, image_1_logits = logits_per_image.chunk(2, dim=0)
+ text_logits = logit_scale * text_features @ all_image_features.T
+
+ if self.cfg.in_batch_negatives:
+ # get labels
+ num_images = all_image_features.shape[0]
+ image_labels = torch.arange(num_images, device=device, dtype=torch.long)
+ image_0_labels, image_1_labels = image_labels.chunk(2, dim=0)
+ num_texts = text_features.shape[0]
+ text_labels = torch.arange(num_texts, device=device, dtype=torch.long)
+
+ # image loss - we want to increase the logits of the preferred image to the text
+ image_0_loss = torch.nn.functional.cross_entropy(image_0_logits, text_labels, reduction="none")
+ image_1_loss = torch.nn.functional.cross_entropy(image_1_logits, text_labels, reduction="none")
+ # if we have a tie, we will increase both images equally, and average so the image loss of each example is
+ # proportional
+ image_loss = label_0 * image_0_loss + label_1 * image_1_loss
+
+ # text loss - we want to increase the logits of the text to the preferred image
+ text_0_loss = torch.nn.functional.cross_entropy(text_logits, image_0_labels, reduction="none")
+ text_1_loss = torch.nn.functional.cross_entropy(text_logits, image_1_labels, reduction="none")
+
+ else:
+ text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
+ index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
+
+ text_0_logits = text_0_logits[index, index]
+ text_1_logits = text_1_logits[index, index]
+ text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
+ text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
+ text_1_labels = text_0_labels + 1
+ text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
+ text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
+
+ # if we have a tie we want the logits of for each image to be equal
+ text_loss = label_0 * text_0_loss + label_1 * text_1_loss
+ # we want the ideal loss to be 0, currently, if there is a tie, it is 0.5 * log(0.5) + 0.5 * log(0.5)
+ # so we add log(0.5) to the loss
+ is_tie = (label_0 == label_1).float()
+ is_tie *= torch.log(torch.tensor(0.5, device=device))
+ text_loss += is_tie
+
+ # we average the image and text loss
+ if self.cfg.in_batch_negatives:
+ loss = (image_loss + text_loss) / 2
+ else:
+ loss = text_loss
+ # import pdb; pdb.set_trace()
+
+ # some prompts have lots of interactions, we want weight them accordingly
+ # absolute_example_weight = 1 / num_examples_per_prompt
+ # denominator = absolute_example_weight.sum()
+ # weight_per_example = absolute_example_weight / denominator
+ # loss *= weight_per_example
+ loss = loss.mean()
+ # import pdb; pdb.set_trace()
+
+ # loss = loss.sum()
+ return loss
+
+ def forward(self, model, batch):
+ # import pdb; pdb.set_trace()
+ image_0_features, image_1_features, text_features = self.get_features(
+ model,
+ batch[self.cfg.input_ids_column_name],
+ batch[self.cfg.pixels_0_column_name],
+ batch[self.cfg.pixels_1_column_name]
+ )
+ # print("text_features:", text_features.shape)
+
+ loss = self.calc_loss(
+ text_features,
+ image_0_features,
+ image_1_features,
+ model.logit_scale.exp(),
+ batch[self.cfg.label_0_column_name],
+ batch[self.cfg.label_1_column_name],
+ batch[self.cfg.num_examples_per_prompt_column_name],
+ )
+ return loss
+
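+# --- Hedged toy illustration of the pairwise branch in calc_loss (not used by the trainer) ---
+# With in_batch_negatives disabled, each prompt is compared only against its own two
+# images: the two diagonal text-image logits are stacked and a cross-entropy is taken
+# against each candidate index, weighted by label_0 / label_1.
+def _pairwise_preference_loss_example():
+    import torch.nn.functional as F
+    text_0_logits = torch.tensor([21.0])  # prompt vs. preferred image
+    text_1_logits = torch.tensor([19.5])  # prompt vs. rejected image
+    logits = torch.stack([text_0_logits, text_1_logits], dim=-1)  # [B, 2]
+    label_0, label_1 = torch.tensor([1.0]), torch.tensor([0.0])
+    loss_0 = F.cross_entropy(logits, torch.zeros(1, dtype=torch.long), reduction="none")
+    loss_1 = F.cross_entropy(logits, torch.ones(1, dtype=torch.long), reduction="none")
+    return (label_0 * loss_0 + label_1 * loss_1).mean()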
+
+# ====== Data preparation ======
+class QwenSD3JsonDataset(Dataset):
+ def __init__(self, processor, json_file, qwen_dir, sd3_dir):
+ """
+ json_file: prompt2img.json {prompt: filename}
+        qwen_dir: directory containing the Qwen images
+        sd3_dir: directory containing the SD3 images
+ """
+ self.processor = processor
+
+ with open(json_file, "r") as f:
+ self.prompt2img = json.load(f)
+
+ self.prompts = list(self.prompt2img.keys())
+ self.qwen_dir = qwen_dir
+ self.sd3_dir = sd3_dir
+
+ def __len__(self):
+ return len(self.prompts)
+
+ def __getitem__(self, idx):
+ prompt = self.prompts[idx]
+ filename = self.prompt2img[prompt]
+
+ qwen_img_path = os.path.join(self.qwen_dir, filename)
+ sd3_img_path = os.path.join(self.sd3_dir, filename)
+
+ if os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path):
+ qwen_img = Image.open(qwen_img_path).convert("RGB")
+ sd3_img = Image.open(sd3_img_path).convert("RGB")
+ else:
+ qwen_img = Image.open(sd3_img_path).convert("RGB")
+ sd3_img = Image.open(sd3_img_path).convert("RGB")
+
+        # Text tokens
+ text_inputs = self.processor.tokenizer(
+ prompt,
+ padding="max_length",
+ truncation=True,
+ max_length=77,
+ return_tensors="pt"
+ )
+ input_ids = text_inputs["input_ids"].squeeze(0)
+
+        # Image preprocessing
+ pixels_0 = self.processor(images=qwen_img, return_tensors="pt")["pixel_values"].squeeze(0)
+ pixels_1 = self.processor(images=sd3_img, return_tensors="pt")["pixel_values"].squeeze(0)
+
+ return {
+ "input_ids": input_ids,
+ "pixels_0": pixels_0, # 正样本 (Qwen)
+ "pixels_1": pixels_1, # 负样本 (SD3)
+ "label_0": torch.tensor(1.0),
+ "label_1": torch.tensor(0.0),
+ "num_examples_per_prompt": torch.tensor(1.0)
+ }
+
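+# --- Hedged note on the expected prompt2img.json layout (filenames are illustrative) ---
+# {
+#     "a red car parked by the sea": "000001.png",
+#     "a cat reading a newspaper": "000002.png"
+# }
+# The same filename is looked up under both qwen_dir and sd3_dir.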
+
+# ====== Training loop ======
+# def finetune_pickscore(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6, device="cuda"):
+# processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+# model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
+
+# dataset = QwenSD3JsonDataset(processor,json_file, qwen_dir, sd3_dir)
+# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+# criterion = CLIPCriterion(CLIPCriterionConfig())
+# optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+# # import pdb; pdb.set_trace()
+
+# model.train()
+# for epoch in range(epochs):
+# total_loss = 0.0
+# for batch in dataloader:
+# batch = {k: v.to(device) for k, v in batch.items()}
+# loss = criterion(model, batch)
+
+# optimizer.zero_grad()
+# loss.backward()
+# optimizer.step()
+
+# total_loss += loss.item()
+# print(f"Epoch {epoch} | Loss {total_loss/len(dataloader):.4f}")
+
+# model.save_pretrained("pickscore_qwen_finetuned")
+# return model
+
+def finetune_pickscore_distributed(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6):
+    # 1. Initialize distributed training
+ dist.init_process_group(backend="nccl")
+ local_rank = int(os.environ["LOCAL_RANK"])
+ torch.cuda.set_device(local_rank)
+ device = torch.device("cuda", local_rank)
+
+    # 2. Prepare the data
+ processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ dataset = QwenSD3JsonDataset(processor, json_file, qwen_dir, sd3_dir)
+ sampler = DistributedSampler(dataset)
+ dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    # 3. Model + DDP
+ model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
+ model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
+
+ criterion = CLIPCriterion(CLIPCriterionConfig())
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+    # 4. Training
+ model.train()
+ if dist.get_rank() == 0:
+ evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)
+ for epoch in range(epochs):
+        sampler.set_epoch(epoch)  # reshuffle differently each epoch, consistently across ranks
+ total_loss = 0.0
+
+ for step, batch in enumerate(dataloader):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ loss = criterion(model.module, batch)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+            # Accumulate the loss (locally first)
+ total_loss += loss.item()
+
+            # Print every N steps (rank 0 only)
+            if step % 50 == 0:  # adjust to 10 or 100 as needed
+                # all_reduce averages the loss across all GPUs
+ avg_loss = torch.tensor(loss.item(), device=device)
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
+ if dist.get_rank() == 0:
+ print(f"[Epoch {epoch} | Step {step}/{len(dataloader)}] "
+ f"local_loss={loss.item():.4f} | avg_loss={avg_loss.item():.4f}")
+
+        # Print the epoch-average loss at the end of each epoch
+ epoch_loss = torch.tensor(total_loss / len(dataloader), device=device)
+ dist.all_reduce(epoch_loss, op=dist.ReduceOp.AVG)
+ if dist.get_rank() == 0:
+ print(f"===> Epoch {epoch} done | avg_epoch_loss={epoch_loss.item():.4f}")
+ evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)
+
+    # 5. Save the model (rank 0 only)
+ if dist.get_rank() == 0:
+ model.module.save_pretrained("pickscore_qwen_finetuned")
+
+ dist.destroy_process_group()
+
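+# --- Hedged launch note (assumption, not part of the original script) ---
+# This entry point expects a torchrun-style launcher that sets LOCAL_RANK, e.g.:
+#   torchrun --nproc_per_node=8 adv_grpo/pick_score_training.py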
+
+# ====== Usage example ======
+if __name__ == "__main__":
+ finetune_pickscore_distributed(
+ json_file="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images/prompt2img.json",
+ qwen_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images",
+ sd3_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images",
+ epochs=2,
+ batch_size=4,
+ lr=1e-6,
+ )
diff --git a/adv_grpo/pickscore_scorer.py b/adv_grpo/pickscore_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec09229bdef3d4351f3f810da9dcbe466df82e0
--- /dev/null
+++ b/adv_grpo/pickscore_scorer.py
@@ -0,0 +1,70 @@
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+import torch
+
+class PickScoreScorer(torch.nn.Module):
+ def __init__(self, device="cuda", dtype=torch.float32):
+ super().__init__()
+ processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+ model_path = "yuvalkirstain/PickScore_v1"
+ self.device = device
+ self.dtype = dtype
+ self.processor = CLIPProcessor.from_pretrained(processor_path)
+ self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
+ self.model = self.model.to(dtype=dtype)
+
+ @torch.no_grad()
+ def __call__(self, prompt, images):
+ # Preprocess images
+ if hasattr(self.model, "module"):
+ self.model = self.model.module
+ image_inputs = self.processor(
+ images=images,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
+ # Preprocess text
+ text_inputs = self.processor(
+ text=prompt,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}
+
+ # Get embeddings
+ image_embs = self.model.get_image_features(**image_inputs)
+ image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)
+
+ text_embs = self.model.get_text_features(**text_inputs)
+ text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)
+
+ # Calculate scores
+ logit_scale = self.model.logit_scale.exp()
+ scores = logit_scale * (text_embs @ image_embs.T)
+ scores = scores.diag()
+ # norm to 0-1
+ scores = scores/26
+ return scores
+
+# Usage example
+def main():
+ scorer = PickScoreScorer(
+ device="cuda",
+ dtype=torch.float32
+ )
+ images=[
+ "nasa.jpg",
+ ]
+ pil_images = [Image.open(img) for img in images]
+ prompts=[
+        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+ ]
+ print(scorer(prompts, pil_images))
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/adv_grpo/pickscore_scorer_constractive.py b/adv_grpo/pickscore_scorer_constractive.py
new file mode 100644
index 0000000000000000000000000000000000000000..32589bfcfb055122a9d7cc8a5ac2bdfa9bc48424
--- /dev/null
+++ b/adv_grpo/pickscore_scorer_constractive.py
@@ -0,0 +1,89 @@
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+import torch
+
+class PickScoreScorerConstractive(torch.nn.Module):
+ def __init__(self, device="cuda", dtype=torch.float32):
+ super().__init__()
+ processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+ model_path = "yuvalkirstain/PickScore_v1"
+ self.device = device
+ self.dtype = dtype
+ self.processor = CLIPProcessor.from_pretrained(processor_path)
+ self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
+ self.model = self.model.to(dtype=dtype)
+
+ @torch.no_grad()
+ def __call__(self, prompt, images, ref_images):
+ # Preprocess images
+ image_inputs = self.processor(
+ images=images,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
+
+ ref_image_inputs = self.processor(
+ images=ref_images,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ ref_image_inputs = {k: v.to(device=self.device) for k, v in ref_image_inputs.items()}
+
+
+
+ # Preprocess text
+ text_inputs = self.processor(
+ text=prompt,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}
+
+ # Get embeddings
+ image_embs = self.model.get_image_features(**image_inputs)
+ image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)
+
+ ref_image_embs = self.model.get_image_features(**ref_image_inputs)
+ ref_image_embs = ref_image_embs / ref_image_embs.norm(p=2, dim=-1, keepdim=True)
+
+ text_embs = self.model.get_text_features(**text_inputs)
+ text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)
+
+ # Calculate scores
+ logit_scale = self.model.logit_scale.exp()
+ scores = logit_scale * (text_embs @ image_embs.T)
+ scores = scores.diag()
+ # norm to 0-1
+ scores = scores/26
+
+ ref_scores = logit_scale * (text_embs @ ref_image_embs.T)
+ ref_scores = ref_scores.diag()
+ ref_scores = ref_scores/26
+
+
+ return scores, ref_scores, image_embs, ref_image_embs
+
+# Usage example
+def main():
+    scorer = PickScoreScorerConstractive(
+        device="cuda",
+        dtype=torch.float32
+    )
+    images = [
+        "nasa.jpg",
+    ]
+    pil_images = [Image.open(img) for img in images]
+    # The contrastive scorer also needs reference images; reuse the same file here.
+    pil_ref_images = [Image.open(img) for img in images]
+    prompts = [
+        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+    ]
+    scores, ref_scores, _, _ = scorer(prompts, pil_images, pil_ref_images)
+    print(scores, ref_scores)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/adv_grpo/pickscore_scorer_patch.py b/adv_grpo/pickscore_scorer_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..62f4180b7f4703e8c03e5c6093c6b49f90bd6ef1
--- /dev/null
+++ b/adv_grpo/pickscore_scorer_patch.py
@@ -0,0 +1,78 @@
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+import torch
+
+class PickScoreScorer(torch.nn.Module):
+ def __init__(self, device="cuda", dtype=torch.float32):
+ super().__init__()
+ processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+ model_path = "yuvalkirstain/PickScore_v1"
+ self.device = device
+ self.dtype = dtype
+ self.processor = CLIPProcessor.from_pretrained(processor_path)
+ self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
+ self.model = self.model.to(dtype=dtype)
+
+ @torch.no_grad()
+ def __call__(self, prompt, images):
+ # Preprocess images
+ if hasattr(self.model, "module"):
+ self.model = self.model.module
+ image_inputs = self.processor(
+ images=images,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
+ # Preprocess text
+ text_inputs = self.processor(
+ text=prompt,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}
+
+ # Get embeddings
+ # image_embs = self.model.get_image_features(**image_inputs)
+        # import pdb; pdb.set_trace()
+ image_embs = self.model.vision_model(image_inputs["pixel_values"],output_hidden_states=True)
+ image_embs = image_embs.last_hidden_state
+
+ image_embs = self.model.visual_projection(image_embs) # [B, N, 1024]
+ image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)
+
+ text_embs = self.model.get_text_features(**text_inputs)
+ text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)
+
+ # Calculate scores
+ logit_scale = self.model.logit_scale.exp()
+ # scores = logit_scale * (text_embs @ image_embs.T)
+ patch_scores = torch.einsum("bd,bnd->bn", text_embs, image_embs) # [B, N]
+        scores = logit_scale * patch_scores.mean(dim=1)  # average over all patches
+ # scores = scores.diag()
+ # norm to 0-1
+ scores = scores/26
+ # import pdb; pdb.set_trace()
+ return scores
+
+# Usage example
+def main():
+ scorer = PickScoreScorer(
+ device="cuda",
+ dtype=torch.float32
+ )
+ images=[
+ "nasa.jpg",
+ ]
+ pil_images = [Image.open(img) for img in images]
+ prompts=[
+        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+ ]
+ print(scorer(prompts, pil_images))
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/adv_grpo/prompts.py b/adv_grpo/prompts.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b0d55e465ebb8f4e5f16efa43ae6e33629dabb
--- /dev/null
+++ b/adv_grpo/prompts.py
@@ -0,0 +1,80 @@
+from importlib import resources
+import os
+import functools
+import random
+# import inflect
+
+# IE = inflect.engine()
+IE = None  # inflect is disabled here; nouns_activities() and counting() need it restored
+ASSETS_PATH = resources.files("adv_grpo.assets")
+
+
+@functools.cache
+def _load_lines(path):
+ """
+ Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the
+ `adv_grpo/assets` directory for a file named `path`.
+ """
+ if not os.path.exists(path):
+ newpath = ASSETS_PATH.joinpath(path)
+ if not os.path.exists(newpath):
+ raise FileNotFoundError(f"Could not find {path} or adv_grpo.assets/{path}")
+ path = newpath
+ with open(path, "r") as f:
+ return [line.strip() for line in f.readlines()]
+
+
+def from_file(path, low=None, high=None):
+ prompts = _load_lines(path)[low:high]
+ return random.choice(prompts), {}
+
+
+def imagenet_all():
+ return from_file("imagenet_classes.txt")
+
+
+def imagenet_animals():
+ return from_file("imagenet_classes.txt", 0, 398)
+
+
+def imagenet_dogs():
+ return from_file("imagenet_classes.txt", 151, 269)
+
+
+def simple_animals():
+ return from_file("simple_animals.txt")
+
+def general_ocr():
+ return from_file("general_ocr_train.txt")
+
+def simple_ocr_animals():
+ animals = _load_lines("simple_ocr_animals.txt")
+ # random_number = random.randint(100, 999)
+ # random_number = ''.join([str(random.randint(0, 9)) for _ in range(10)])
+ num=random.randint(1, 9)
+ random_number = ''.join([str(6) for _ in range(num)])
+ return f'A {random.choice(animals)} holding a sign that says "{random_number}"', {}
+
+def nouns_activities(nouns_file, activities_file):
+ nouns = _load_lines(nouns_file)
+ activities = _load_lines(activities_file)
+ return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
+
+
+def counting(nouns_file, low, high):
+ nouns = _load_lines(nouns_file)
+ number = IE.number_to_words(random.randint(low, high))
+ noun = random.choice(nouns)
+ plural_noun = IE.plural(noun)
+ prompt = f"{number} {plural_noun}"
+ metadata = {
+ "questions": [
+ f"How many {plural_noun} are there in this image?",
+ f"What animal is in this image?",
+ ],
+ "answers": [
+ number,
+ noun,
+ ],
+ }
+ return prompt, metadata
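+
+
+# --- Hedged illustration of the generated prompts (outputs are random; examples only) ---
+# simple_ocr_animals() might return something like
+#   ('A cat holding a sign that says "666"', {})
+# while counting(...) also fills the metadata dict with VQA-style questions/answers
+# (both nouns_activities() and counting() require the inflect engine `IE` to be restored).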
diff --git a/adv_grpo/qwenvl.py b/adv_grpo/qwenvl.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8534785cb9b57e67843d0988553334d688fd37f
--- /dev/null
+++ b/adv_grpo/qwenvl.py
@@ -0,0 +1,118 @@
+from PIL import Image
+import torch
+import re
+import base64
+from io import BytesIO
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+def pil_image_to_base64(image):
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ base64_qwen = f"data:image;base64,{encoded_image_text}"
+ return base64_qwen
+
+def extract_scores(output_text):
+ scores = []
+ for text in output_text:
+ match = re.search(r'(\d+)', text)
+ if match:
+ scores.append(float(match.group(1))/5)
+ else:
+ scores.append(0)
+ return scores
+
+class QwenVLScorer(torch.nn.Module):
+ def __init__(self, device="cuda", dtype=torch.bfloat16):
+ super().__init__()
+ self.device = device
+ self.dtype = dtype
+
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2.5-VL-7B-Instruct",
+ torch_dtype=self.dtype,
+ attn_implementation="flash_attention_2",
+ device_map=None,
+ ).to(self.device)
+ self.model.requires_grad_(False)
+ self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
+ self.task = '''
+Your role is to evaluate the aesthetic quality score of given images.
+1. Bad: Extremely blurry, underexposed with significant noise, indiscernible
+subjects, and chaotic composition.
+2. Poor: Noticeable blur, poor lighting, washed-out colors, and awkward
+composition with cut-off subjects.
+3. Fair: In focus with adequate lighting, dull colors, decent composition but
+lacks creativity.
+4. Good: Sharp, good exposure, vibrant colors, thoughtful composition with
+a clear focal point.
+5. Excellent: Exceptional clarity, perfect exposure, rich colors, masterful
+composition with emotional impact.
+
+Please first provide a detailed analysis of the evaluation process, including the criteria for judging aesthetic quality, within the tag. Then, give a final score from 1 to 5 within the tag.
+
+[Analyze the evaluation process in detail here]
+
+X
+'''
+
+ @torch.no_grad()
+ def __call__(self, prompt, images):
+ images_base64 = [pil_image_to_base64(image) for image in images]
+ messages=[]
+ for base64_qwen in images_base64:
+ messages.append([
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": base64_qwen},
+ {"type": "text", "text": self.task},
+ ],
+ },
+ ])
+
+ # Preparation for batch inference
+ texts = [
+ self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
+ for msg in messages
+ ]
+ image_inputs, video_inputs = process_vision_info(messages)
+ inputs = self.processor(
+ text=texts,
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
+ return_tensors="pt",
+ )
+ inputs = inputs.to(self.device)
+
+ # Batch Inference
+ generated_ids = self.model.generate(**inputs, max_new_tokens=2048)
+ generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ output_texts = self.processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+ rewards = extract_scores(output_texts)
+ return rewards
+
+# Usage example
+def main():
+ scorer = QwenVLScorer(
+ device="cuda",
+ dtype=torch.bfloat16
+ )
+ images=[
+ "nasa.jpg",
+ ]
+ pil_images = [Image.open(img) for img in images]
+ prompts=[
+        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+ ]
+
+ print(scorer(None, pil_images))
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/adv_grpo/rewards.py b/adv_grpo/rewards.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6bb4d39424e1b4a0d33f6375a1f5aa8141a96b1
--- /dev/null
+++ b/adv_grpo/rewards.py
@@ -0,0 +1,1126 @@
+from PIL import Image
+import io
+import numpy as np
+import torch
+from collections import defaultdict
+
+import torch
+import timm
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def jpeg_incompressibility():
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images = [Image.fromarray(image) for image in images]
+ buffers = [io.BytesIO() for _ in images]
+ for image, buffer in zip(images, buffers):
+ image.save(buffer, format="JPEG", quality=95)
+ sizes = [buffer.tell() / 1000 for buffer in buffers]
+ return np.array(sizes), {}
+
+ return _fn
+
+
+def jpeg_compressibility():
+ jpeg_fn = jpeg_incompressibility()
+
+ def _fn(images, prompts, metadata):
+ rew, meta = jpeg_fn(images, prompts, metadata)
+ return -rew/500, meta
+
+ return _fn
+
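+# --- Hedged usage sketch of the reward-factory pattern used throughout this file ---
+# Each factory returns a closure `_fn(images, prompts, metadata)` that yields
+# (scores, info). A toy call might look like this (inputs are made up):
+def _reward_factory_example():
+    reward_fn = jpeg_compressibility()
+    fake_images = np.zeros((2, 64, 64, 3), dtype=np.uint8)  # NHWC uint8 batch
+    scores, info = reward_fn(fake_images, prompts=None, metadata=None)
+    return scores  # higher (less negative) = more JPEG-compressible
+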
+def aesthetic_score():
+ from adv_grpo.aesthetic_scorer import AestheticScorer
+
+ scorer = AestheticScorer(dtype=torch.float32).cuda()
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8)
+ else:
+ images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW
+ images = torch.tensor(images, dtype=torch.uint8)
+ scores = scorer(images)
+ return scores, {}
+
+ return _fn
+
+def clip_score():
+ from adv_grpo.clip_scorer import ClipScorer
+
+ # scorer = ClipScorer(dtype=torch.float32).cuda()
+ scorer = ClipScorer().cuda()
+
+ def _fn(images, prompts, metadata):
+ if not isinstance(images, torch.Tensor):
+ images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW
+ images = torch.tensor(images, dtype=torch.uint8)/255.0
+ scores = scorer(images, prompts)
+ return scores, {}
+
+ return _fn
+
+
+def siglip_image_similarity_score(device):
+ import torch
+ import numpy as np
+ from transformers import SiglipModel
+ import torch.nn.functional as F
+
+    # 1. Load the SigLIP model (so400m-patch14-384 recommended)
+ scorer = SiglipModel.from_pretrained(
+ "google/siglip-so400m-patch14-384"
+ ).to(device).to(torch.bfloat16)
+ scorer.eval()
+
+ # SigLIP preprocess mean/std
+ siglip_mean = [0.5, 0.5, 0.5]
+ siglip_std = [0.5, 0.5, 0.5]
+
+    # Model input resolution (automatically matches the chosen SigLIP model)
+ image_size = scorer.config.vision_config.image_size # e.g., 224/256/384
+
+ def _preprocess(images):
+        # Convert to a tensor
+ if not isinstance(images, torch.Tensor):
+ images = torch.tensor(images, dtype=torch.float32)
+
+ # 0-255 → 0-1
+ if images.max() > 1.0:
+ images = images / 255.0
+
+ # NHWC → NCHW
+ if images.shape[-1] == 3:
+ images = images.permute(0, 3, 1, 2)
+
+ # resize to SigLIP input
+ images = F.interpolate(
+ images,
+ size=(image_size, image_size),
+ mode="bicubic",
+ align_corners=False
+ )
+
+ # normalize using SigLIP's mean/std
+ mean = torch.tensor(siglip_mean, device=device)[None, :, None, None]
+ std = torch.tensor(siglip_std, device=device)[None, :, None, None]
+ images = (images - mean) / std
+
+ return images.to(device).to(torch.bfloat16)
+
+ def _fn(images, ref_images):
+ # 2. preprocess
+ images = _preprocess(images)
+ ref_images = _preprocess(ref_images)
+
+ with torch.no_grad():
+            # Extract SigLIP features
+ out_img = scorer.vision_model(
+ pixel_values=images.to(torch.float32)
+ )
+ out_ref = scorer.vision_model(
+ pixel_values=ref_images.to(torch.float32)
+ )
+
+ emb_images = out_img.pooler_output.to(torch.bfloat16)
+ emb_ref = out_ref.pooler_output.to(torch.bfloat16)
+
+ # 3. normalize embeddings (L2)
+ emb_images = emb_images / emb_images.norm(dim=-1, keepdim=True)
+ emb_ref = emb_ref / emb_ref.norm(dim=-1, keepdim=True)
+
+ # 4. cosine similarity
+ scores = torch.matmul(emb_images, emb_ref.T) # [N,M]
+ per_img = scores.max(dim=1).values # [N]
+
+ return per_img.detach(), {"pairwise": scores.detach()}
+
+ return _fn
+
+
+
+def image_similarity_score(device):
+ import torch
+ import numpy as np
+ import timm
+ import torch.nn.functional as F
+
+    # 1. Load the DINOv2 model (ViT-Base/14 here; ViT-L/14 or ViT-G/14 also work)
+ model = timm.create_model("vit_base_patch14_dinov2.lvd142m", pretrained=True)
+ # model = timm.create_model('vit_large_patch16_dinov3_qkvb.lvd1689m', pretrained=True)
+ model.eval().to(device)
+
+
+ def _preprocess(images):
+ if not isinstance(images, torch.Tensor):
+ images = torch.tensor(images, dtype=torch.float32)
+ if images.max() > 1.0:
+ images = images / 255.0
+ if images.shape[-1] == 3: # NHWC -> NCHW
+ images = images.permute(0, 3, 1, 2)
+        # Resize to 518×518
+ # images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+ # DINOv2 normalization
+ # images = (images - 0.5) / 0.5
+ images = torch.nn.functional.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+ mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+ std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+ images = (images - mean) / std
+ return images.to(device)
+
+ def _fn(images, ref_images):
+        # 2. Preprocess
+ # import pdb; pdb.set_trace()
+ images = _preprocess(images)
+ ref_images = _preprocess(ref_images)
+
+ with torch.no_grad():
+ # import pdb; pdb.set_trace()
+ emb_images = model(images) # [N,D]
+ emb_ref = model(ref_images) # [M,D]
+
+            # 3. Normalize
+ emb_images = emb_images / emb_images.norm(dim=-1, keepdim=True)
+ emb_ref = emb_ref / emb_ref.norm(dim=-1, keepdim=True)
+
+            # 4. Compute similarity (cosine similarity)
+ scores = torch.matmul(emb_images, emb_ref.T) # [N,M]
+            per_img = scores.max(dim=1).values  # [N]; use scores.mean(dim=1) to average instead
+ # import pdb; pdb.set_trace()
+
+            # Return 1D scores; the pairwise matrix goes into the info dict
+ # return per_img.detach(), {"pairwise": scores.detach()}, emb_images, emb_ref
+ return per_img.detach(), {"pairwise": scores.detach()}
+
+
+ # return scores, {}
+
+ return _fn
+
+
+
+
+def image_similarity_score_eval(device):
+ import torch
+ import numpy as np
+ import timm
+ import torch.nn.functional as F
+
+    # 1. Load the DINOv2 model (ViT-Base/14 here; ViT-L/14 or ViT-G/14 also work)
+ model = timm.create_model("vit_base_patch14_dinov2.lvd142m", pretrained=True)
+ model.eval().to(device)
+
+ def _preprocess(images):
+ if not isinstance(images, torch.Tensor):
+ images = torch.tensor(images, dtype=torch.float32)
+ if images.max() > 1.0:
+ images = images / 255.0
+ if images.shape[-1] == 3: # NHWC -> NCHW
+ images = images.permute(0, 3, 1, 2)
+        # Resize to 518×518
+ # images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+ # DINOv2 normalization
+ # images = (images - 0.5) / 0.5
+ images = torch.nn.functional.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+ mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+ std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+ images = (images - mean) / std
+ return images.to(device)
+
+ def _fn(images, ref_images):
+        # 2. Preprocess
+ # import pdb; pdb.set_trace()
+ images = _preprocess(images)
+ ref_images = _preprocess(ref_images)
+
+ with torch.no_grad():
+ # import pdb; pdb.set_trace()
+ emb_images = model(images) # [N,D]
+ emb_ref = model(ref_images) # [M,D]
+
+            # 3. Normalize (disabled here)
+ # emb_images = emb_images / emb_images.norm(dim=-1, keepdim=True)
+ # emb_ref = emb_ref / emb_ref.norm(dim=-1, keepdim=True)
+
+            # 4. Compute similarity (cosine similarity)
+ scores = torch.matmul(emb_images, emb_ref.T) # [N,M]
+            per_img = scores.max(dim=1).values  # [N]; use scores.mean(dim=1) to average instead
+ # import pdb; pdb.set_trace()
+
+            # Return 1D scores; the pairwise matrix goes into the info dict
+ return per_img.detach(), {"pairwise": scores.detach()}, emb_images, emb_ref
+ # return per_img.detach(), {"pairwise": scores.detach()}
+
+
+ # return scores, {}
+
+ return _fn
+
+
+
+def dino_cotrain_score(device):
+ def _preprocess(images):
+ if not isinstance(images, torch.Tensor):
+ images = torch.tensor(images, dtype=torch.float32)
+ if images.max() > 1.0:
+ images = images / 255.0
+ if images.shape[-1] == 3: # NHWC -> NCHW
+ images = images.permute(0, 3, 1, 2)
+ images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+ # images = (images - 0.5) / 0.5
+ mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+ std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+ images = (images - mean) / std
+ return images.to(device).to(torch.bfloat16)
+
+ def _fn(scorer, head, images, prompts, metadata):
+ images = _preprocess(images)
+
+ with torch.no_grad():
+ emb = scorer(images) # [N,D]
+ emb = emb / emb.norm(dim=-1, keepdim=True)
+
+            # The head outputs the reward
+ scores = head(emb).squeeze(-1) # [N]
+ # import pdb; pdb.set_trace()
+
+ return scores.detach(), {"embeddings": emb.detach()}
+
+ return _fn
+
+
+
+
+def siglip_cotrain_score(device):
+ """
+    Used in the same way as the original dino_cotrain_score:
+ reward_fn = siglip_cotrain_score(device)
+ reward_fn(scorer, head, images, prompts, metadata)
+
+    scorer: SigLIPModel (loaded from HF transformers)
+    head: your reward head
+ """
+ from torchvision import transforms
+ tiny_jitter = transforms.ColorJitter(
+ brightness=0.02,
+ contrast=0.02,
+ )
+
+ def _preprocess(images, image_size=224):
+        # Convert to a tensor
+ if not isinstance(images, torch.Tensor):
+ images = torch.tensor(images, dtype=torch.float32)
+
+        # uint8 0-255 → 0-1
+ if images.max() > 1.0:
+ images = images / 255.0
+
+ # NHWC → NCHW
+ if images.shape[-1] == 3:
+ images = images.permute(0, 3, 1, 2)
+
+ imgs_aug = []
+ for img in images:
+            imgs_aug.append(tiny_jitter(img))  # apply a slight brightness/contrast jitter
+ images = torch.stack(imgs_aug)
+
+        # Resize to SigLIP's default input size (usually 224 / 256 / 384)
+ images = F.interpolate(
+ images,
+ size=(image_size, image_size),
+ mode="bicubic",
+ align_corners=False
+ )
+
+        # SigLIP's official mean/std (note: different from CLIP)
+ mean = torch.tensor([0.5, 0.5, 0.5], device=device)[None, :, None, None]
+ std = torch.tensor([0.5, 0.5, 0.5], device=device)[None, :, None, None]
+
+ images = (images - mean) / std
+
+ return images.to(device).to(torch.bfloat16)
+
+ def _fn(scorer, head, images, prompts, metadata):
+ """
+ scorer: SigLIPModel
+ head: reward head
+ """
+
+        # Image preprocessing (scorer.config.vision_config.image_size is usually 224)
+ # import pdb; pdb.set_trace()
+ images = _preprocess(images, 512)
+
+ scorer.eval()
+ with torch.no_grad():
+            # SigLIP feature extraction
+            # Returns the pooled embedding, similar to CLIP's global feature
+ vision_out = scorer.vision_model(
+ pixel_values=images.to(torch.float32)
+ )
+ emb = vision_out.pooler_output.to(torch.bfloat16) # [B, D]
+
+        # The head outputs the reward
+ scores = head(emb).squeeze(-1)
+
+ return scores.detach(), {"embeddings": emb.detach()}
+
+ return _fn
+
+
+def dino_patch_cotrain_score(device, n_patches=64):
+ import torch
+ import torch.nn.functional as F
+
+ def _preprocess(images):
+ if not isinstance(images, torch.Tensor):
+ images = torch.tensor(images, dtype=torch.float32)
+ if images.max() > 1.0:
+ images = images / 255.0
+ if images.shape[-1] == 3: # NHWC -> NCHW
+ images = images.permute(0, 3, 1, 2)
+ images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+ # images = (images - 0.5) / 0.5
+ mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+ std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+ images = (images - mean) / std
+ return images.to(device).to(torch.bfloat16)
+
+ def _fn(scorer, head, images, prompts, metadata, cls_weight=0.7):
+ images = _preprocess(images)
+ with torch.no_grad():
+            # Extract all token features: [B, N+1, D]
+ feats = scorer.forward_features(images)
+
+            # --- Split CLS and patch tokens ---
+ cls_emb = feats[:, 0, :] # [B, D]
+ patch_emb = feats[:, 1:, :] # [B, N, D]
+ B, N, D = patch_emb.shape
+
+            # --- Randomly sample patches ---
+ n_select = min(n_patches, N)
+ idx = torch.randint(0, N, (B, n_select), device=device)
+ sampled_patches = torch.gather(
+ patch_emb, 1, idx.unsqueeze(-1).expand(-1, -1, D)
+ ) # [B, n_select, D]
+
+            # --- Normalize ---
+ cls_emb = cls_emb / (cls_emb.norm(dim=-1, keepdim=True) + 1e-6)
+ sampled_patches = sampled_patches / (sampled_patches.norm(dim=-1, keepdim=True) + 1e-6)
+
+            # --- Compute scores ---
+ cls_score = head(cls_emb).squeeze(-1) # [B]
+ patch_scores = head(sampled_patches).squeeze(-1) # [B, n_select]
+ patch_score_mean = patch_scores.mean(dim=1) # [B]
+
+            # --- Hybrid reward ---
+ hybrid_score = cls_weight * cls_score + (1 - cls_weight) * patch_score_mean
+ # hybrid_score = hybrid_score.unsqueeze(-1) # [B, 1]
+
+ # import pdb; pdb.set_trace()
+
+            # --- Return results ---
+ return hybrid_score.detach(), {
+ "cls_score": cls_score.detach(),
+ "patch_scores": patch_scores.detach(),
+ "patch_indices": idx.detach(),
+ "cls_weight": cls_weight,
+ }
+
+ return _fn
+
+
+def _get_layer_tokens_timm(model, imgs, layer_ids=(2, 5, 8, 11)):
+ handles, feats = [], {i: None for i in layer_ids}
+
+ def make_hook(i):
+ def hook(_module, _inp, out):
+            # For timm ViTs, a block's output is usually [B, N+1, D]
+ feats[i] = out
+ return hook
+
+ for i in layer_ids:
+ assert 0 <= i < len(model.blocks), f"layer id {i} out of range"
+ handles.append(model.blocks[i].register_forward_hook(make_hook(i)))
+
+    # Trigger a forward pass. timm ViTs provide forward_features; otherwise call the model directly.
+ with torch.no_grad():
+ if hasattr(model, "forward_features"):
+ _ = model.forward_features(imgs)
+ else:
+ _ = model(imgs)
+
+ for h in handles:
+ h.remove()
+
+ return [feats[i] for i in layer_ids] # list of [B, N+1, D]
+
+# -------- Reward factory: per-layer heads + top-k pooling + fusion --------
+def dino_multi_cotrain_score(
+ device,
+    topk_tau=0.2,  # average the top tau fraction of patch logits per layer
+    apply_sigmoid=True,  # whether to pass the logit through a sigmoid to get a [0,1] reward
+    lambda_cls=0.5,
+    zscore=False,  # whether to z-score over the batch dimension (per-group handling can be added separately)
+):
+
+ def _preprocess(images):
+ if not isinstance(images, torch.Tensor):
+ images = torch.tensor(images, dtype=torch.float32)
+ if images.max() > 1.0:
+ images = images / 255.0
+ if images.shape[-1] == 3: # NHWC -> NCHW
+ images = images.permute(0, 3, 1, 2)
+ images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+ images = (images - 0.5) / 0.5
+ # mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+ # std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+ # images = (images - mean) / std
+ return images.to(device)
+
+ @torch.no_grad()
+ def _fn(scorer, heads, fusion, images,prompts=None, metadata=None, layer_ids=(8,),temperature=0.2):
+ """
+        scorer : timm ViT-DINOv2 backbone (already .eval() with requires_grad=False)
+        heads  : nn.ModuleList of length len(layer_ids), one head per layer
+        fusion : fusion module mapping (B,T) -> (B,)
+        images : [N,H,W,3] or [N,3,H,W], values in [0,1] or [0,255]
+        Returns:
+            rewards: [N] (float32)
+            aux: dict containing per_layer_scores etc., for debugging/visualization
+ """
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ hmod = heads.module if isinstance(heads, DDP) else heads
+ fmod = fusion.module if isinstance(fusion, DDP) else fusion
+ x = _preprocess(images).to(dtype=next(scorer.parameters()).dtype, device=device)
+
+        # Collect tokens from multiple layers
+ tokens_list = _get_layer_tokens_timm(scorer, x, layer_ids=layer_ids) # list of [B, N+1, D]
+ B = x.size(0)
+ T = len(tokens_list)
+
+ per_layer_scores = []
+        per_layer_logits = []  # store each layer's patch logits before top-k pooling (optional)
+ per_layer_cls_scores = []
+
+ for t in range(T):
+ tokens = tokens_list[t] # [B, N+1, D]
+            patch = tokens[:, 1:]  # [B, N, D], ignoring CLS
+ class_patch = tokens[:, 0]
+ Bn, N, D = patch.shape
+
+            # The head accepts [B,N,D] and outputs [B,N]
+ logits_patch = hmod[t](patch).squeeze(-1) # [B, N]
+ per_layer_logits.append(logits_patch)
+
+            # Top-k pooling within the layer
+ k = max(1, int(N * topk_tau))
+ pooled = logits_patch.topk(k, dim=1).values.mean(dim=1) # [B]
+ per_layer_scores.append(pooled)
+
+ cls_logit = hmod[t](class_patch).squeeze(-1) # [B]
+ per_layer_cls_scores.append(cls_logit)
+
+ per_layer_scores = torch.stack(per_layer_scores, dim=1) # [B, T]
+ per_layer_cls_scores = torch.stack(per_layer_cls_scores, dim=1)
+
+        # Fuse into the final logit
+ logit_patch = fmod(per_layer_scores) # [B]
+ logit_cls = fmod(per_layer_cls_scores)
+
+ # logits = (1.0 - float(lambda_cls)) * logit_patch + float(lambda_cls) * logit_cls # [B]
+ logits = logit_patch
+        # Calibrate into a reward
+ rewards = logits
+ # import pdb; pdb.set_trace()
+ if apply_sigmoid:
+ rewards = torch.sigmoid(rewards / float(temperature))
+ if zscore:
+ mu = rewards.mean(dim=0, keepdim=True)
+ sigma = rewards.std(dim=0, keepdim=True).clamp_min(1e-6)
+ rewards = (rewards - mu) / sigma
+
+        # Output float32 to avoid downstream issues when mixing with bf16
+ rewards = rewards.float()
+ # import pdb; pdb.set_trace()
+
+ aux = {
+ "per_layer_scores": per_layer_scores.float(), # [B,T]
+ "logits": logits.float(), # 未标定的融合 logit
+ # 下行可能很大(B×T×N),需要时再用;默认不返回节省带宽
+ # "per_layer_logits": [lp.float() for lp in per_layer_logits],
+ }
+ return rewards, aux
+
+ return _fn
+
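+# --- Hedged toy illustration of the per-layer top-k pooling above (not used in training) ---
+# Each layer's head scores every patch token, and only the top `topk_tau` fraction of
+# patch logits is averaged, focusing the layer score on its strongest regions.
+def _topk_pooling_example(topk_tau=0.2):
+    logits_patch = torch.randn(2, 100)                 # [B, N] patch logits for one layer
+    k = max(1, int(logits_patch.shape[1] * topk_tau))  # -> 20
+    return logits_patch.topk(k, dim=1).values.mean(dim=1)  # [B]
+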
+def pickscore_score(device):
+ from adv_grpo.pickscore_scorer import PickScoreScorer
+
+ scorer = PickScoreScorer(dtype=torch.float32, device=device)
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images = [Image.fromarray(image) for image in images]
+ scores = scorer(prompts, images)
+ return scores, {}
+
+ return _fn
+
+
+def pickscore_cotrain_score(device):
+
+
+ def _fn(scorer, images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images = [Image.fromarray(image) for image in images]
+ scores = scorer(prompts, images)
+ # import pdb; pdb.set_trace()
+ return scores, {}
+
+ return _fn
+
+
+def pickscore_score_patch(device):
+ from adv_grpo.pickscore_scorer_patch import PickScoreScorer
+
+
+ scorer = PickScoreScorer(dtype=torch.float32, device=device)
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images = [Image.fromarray(image) for image in images]
+ scores = scorer(prompts, images)
+ return scores, {}
+
+ return _fn
+
+
+
+
+def discriminator_score(device):
+ def _fn(scorer, images, prompts=None, metadata=None):
+        # Normalize to [-1,1]
+ if isinstance(images, torch.Tensor):
+            if images.max() > 1.5:  # likely 0~255
+ images = images / 255.0
+ images = (images - 0.5) * 2.0
+ else:
+ raise ValueError("images must be a torch.Tensor in [B,3,H,W]")
+
+ with torch.no_grad():
+ logits = scorer(images.to(device)) # StyleGAN: [B] / [B,1] PatchGAN: [B,1,H',W']
+
+ if logits.ndim == 1:
+            # StyleGAN D, already [B]
+ scores = torch.sigmoid(logits)
+ elif logits.ndim == 2 and logits.shape[1] == 1:
+            # StyleGAN D, output [B,1]
+ scores = torch.sigmoid(logits.squeeze(1))
+ elif logits.ndim == 4 and logits.shape[1] == 1:
+            # PatchGAN D, output [B,1,H',W']
+ scores = torch.sigmoid(logits).mean(dim=[1,2,3]) # -> [B]
+ else:
+ raise ValueError(f"Unexpected logits shape: {logits.shape}")
+
+ return scores.cpu(), {}
+
+ return _fn
+
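+# --- Hedged shape-handling illustration for discriminator_score (toy tensors only) ---
+# StyleGAN-style logits are already one value per image, while PatchGAN-style maps
+# [B,1,H',W'] are squashed with a sigmoid and averaged per image, so both end up as
+# a single score per sample.
+def _discriminator_shapes_example():
+    stylegan_logits = torch.randn(4)             # [B]
+    patchgan_logits = torch.randn(4, 1, 30, 30)  # [B,1,H',W']
+    s1 = torch.sigmoid(stylegan_logits)
+    s2 = torch.sigmoid(patchgan_logits).mean(dim=[1, 2, 3])
+    assert s1.shape == s2.shape == (4,)
+    return s1, s2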
+
+
+def imagereward_score(device):
+ from adv_grpo.imagereward_scorer import ImageRewardScorer
+
+ scorer = ImageRewardScorer(dtype=torch.float32, device=device)
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images = [Image.fromarray(image) for image in images]
+ prompts = [prompt for prompt in prompts]
+ scores = scorer(prompts, images)
+ return scores, {}
+
+ return _fn
+
+def qwenvl_score(device):
+ from adv_grpo.qwenvl import QwenVLScorer
+
+ scorer = QwenVLScorer(dtype=torch.bfloat16, device=device)
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images = [Image.fromarray(image) for image in images]
+ prompts = [prompt for prompt in prompts]
+ scores = scorer(prompts, images)
+ return scores, {}
+
+ return _fn
+
+
+def ocr_score(device):
+ from adv_grpo.ocr import OcrScorer
+
+ scorer = OcrScorer()
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ # import pdb; pdb.set_trace()
+ scores = scorer(images, prompts)
+ # change tensor to list
+ return scores, {}
+
+ return _fn
+
+def video_ocr_score(device):
+ from adv_grpo.ocr import OcrScorer_video_or_image
+
+ scorer = OcrScorer_video_or_image()
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ if images.dim() == 4 and images.shape[1] == 3:
+ images = images.permute(0, 2, 3, 1)
+ elif images.dim() == 5 and images.shape[2] == 3:
+ images = images.permute(0, 1, 3, 4, 2)
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ scores = scorer(images, prompts)
+ # change tensor to list
+ return scores, {}
+
+ return _fn
+
+def constractive_external(device, beta=0.5, top_n=2):
+ import torch
+ from PIL import Image
+ from adv_grpo.pickscore_scorer_constractive import PickScoreScorerConstractive
+
+ scorer = PickScoreScorerConstractive(dtype=torch.float32, device=device)
+
+ def _fn(images,ref_images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images = [Image.fromarray(image) for image in images]
+
+ ref_images = (ref_images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ ref_images = ref_images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ ref_images = [Image.fromarray(image) for image in ref_images]
+
+ # scorer
+ scores, ref_scores, img_embeds, ref_img_embeds = scorer(prompts, images, ref_images)
+
+ # external anchor
+ ref_embed = ref_img_embeds.mean(dim=0, keepdim=True) # (1,D)
+ ext_score = ref_scores.mean()
+
+ top_idx = torch.topk(scores, k=min(top_n, len(scores))).indices
+ hack_scores = scores[top_idx]
+ hack_embeds = img_embeds[top_idx] # (N,D)
+
+ if ext_score >= hack_scores.max():
+ return scores, {"raw_scores": scores, "ref_scores": ref_scores}
+
+        # Compute the contrastive correction
+ sim_to_ext = torch.nn.functional.cosine_similarity(img_embeds, ref_embed)
+ sim_to_hack = torch.nn.functional.cosine_similarity(
+ img_embeds.unsqueeze(1), hack_embeds.unsqueeze(0), dim=-1
+ ) # (num_images, N)
+ sim_to_hack = sim_to_hack.mean(dim=1)
+
+ adjusted_scores = scores + beta * (sim_to_ext - sim_to_hack)
+
+ return adjusted_scores, {
+ "raw_scores": scores,
+ "ref_scores": ref_scores,
+ "sim_to_ext": sim_to_ext,
+ "sim_to_hack": sim_to_hack,
+ "hack_scores": hack_scores
+ }
+
+ return _fn
+
+
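+# Remote-reward pattern shared by the server-backed scorers below: images are
+# JPEG-compressed, pickled together with any metadata, and POSTed to a local scoring
+# server; a requests.Session with aggressive retries rides out transient 500 errors.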
+def deqa_score_remote(device):
+ """Submits images to DeQA and computes a reward.
+ """
+ import requests
+ from requests.adapters import HTTPAdapter, Retry
+ from io import BytesIO
+ import pickle
+
+ batch_size = 64
+ url = "http://127.0.0.1:18086"
+ sess = requests.Session()
+ retries = Retry(
+ total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
+ )
+ sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+ def _fn(images, prompts, metadata):
+ del prompts
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+ all_scores = []
+ for image_batch in images_batched:
+ jpeg_images = []
+
+ # Compress the images using JPEG
+ for image in image_batch:
+ img = Image.fromarray(image)
+ buffer = BytesIO()
+ img.save(buffer, format="JPEG")
+ jpeg_images.append(buffer.getvalue())
+
+            # payload format expected by the DeQA server
+ data = {
+ "images": jpeg_images,
+ }
+ data_bytes = pickle.dumps(data)
+
+            # send a request to the DeQA server
+ response = sess.post(url, data=data_bytes, timeout=120)
+ response_data = pickle.loads(response.content)
+
+ all_scores += response_data["outputs"]
+
+ return all_scores, {}
+
+ return _fn
+
+
+
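+# GenEval reward: besides per-image scores, the server returns plain and strict rewards
+# plus per-group reward dictionaries; `only_strict` lets training skip the non-strict pass.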
+def geneval_score(device):
+ """Submits images to GenEval and computes a reward.
+ """
+ import requests
+ from requests.adapters import HTTPAdapter, Retry
+ from io import BytesIO
+ import pickle
+
+ batch_size = 64
+ url = "http://127.0.0.1:18085"
+ sess = requests.Session()
+ retries = Retry(
+ total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
+ )
+ sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+ def _fn(images, prompts, metadatas, only_strict):
+ del prompts
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+ metadatas_batched = np.array_split(metadatas, np.ceil(len(metadatas) / batch_size))
+ all_scores = []
+ all_rewards = []
+ all_strict_rewards = []
+ all_group_strict_rewards = []
+ all_group_rewards = []
+ for image_batch, metadata_batched in zip(images_batched, metadatas_batched):
+ jpeg_images = []
+
+ # Compress the images using JPEG
+ for image in image_batch:
+ img = Image.fromarray(image)
+ buffer = BytesIO()
+ img.save(buffer, format="JPEG")
+ jpeg_images.append(buffer.getvalue())
+
+            # payload format expected by the GenEval server
+ data = {
+ "images": jpeg_images,
+ "meta_datas": list(metadata_batched),
+ "only_strict": only_strict,
+ }
+ data_bytes = pickle.dumps(data)
+
+            # send a request to the GenEval server
+ response = sess.post(url, data=data_bytes, timeout=120)
+ response_data = pickle.loads(response.content)
+
+ all_scores += response_data["scores"]
+ all_rewards += response_data["rewards"]
+ all_strict_rewards += response_data["strict_rewards"]
+ all_group_strict_rewards.append(response_data["group_strict_rewards"])
+ all_group_rewards.append(response_data["group_rewards"])
+ all_group_strict_rewards_dict = defaultdict(list)
+ all_group_rewards_dict = defaultdict(list)
+ for current_dict in all_group_strict_rewards:
+ for key, value in current_dict.items():
+ all_group_strict_rewards_dict[key].extend(value)
+ all_group_strict_rewards_dict = dict(all_group_strict_rewards_dict)
+
+ for current_dict in all_group_rewards:
+ for key, value in current_dict.items():
+ all_group_rewards_dict[key].extend(value)
+ all_group_rewards_dict = dict(all_group_rewards_dict)
+
+ return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict
+
+ return _fn
+
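+# UnifiedReward via a remote pickle-over-HTTP server (same request format as the DeQA
+# scorer above, with the prompts sent alongside the JPEG-encoded images).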
+def unifiedreward_score_remote(device):
+ """Submits images to DeQA and computes a reward.
+ """
+ import requests
+ from requests.adapters import HTTPAdapter, Retry
+ from io import BytesIO
+ import pickle
+
+ batch_size = 64
+ url = "http://10.82.120.15:18085"
+ sess = requests.Session()
+ retries = Retry(
+ total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
+ )
+ sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+ def _fn(images, prompts, metadata):
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+ images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+ prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
+
+ all_scores = []
+ for image_batch, prompt_batch in zip(images_batched, prompts_batched):
+ jpeg_images = []
+
+ # Compress the images using JPEG
+ for image in image_batch:
+ img = Image.fromarray(image)
+ buffer = BytesIO()
+ img.save(buffer, format="JPEG")
+ jpeg_images.append(buffer.getvalue())
+
+            # payload format expected by the UnifiedReward server
+ data = {
+ "images": jpeg_images,
+ "prompts": prompt_batch
+ }
+ data_bytes = pickle.dumps(data)
+
+            # send a request to the UnifiedReward server
+ response = sess.post(url, data=data_bytes, timeout=120)
+ print("response: ", response)
+ print("response: ", response.content)
+ response_data = pickle.loads(response.content)
+
+ all_scores += response_data["outputs"]
+
+ return all_scores, {}
+
+ return _fn
+
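+# UnifiedReward via an OpenAI-compatible sglang endpoint: each image is base64-encoded,
+# the model is asked to end its judgement with "Final Score: <1-5>", the score is
+# regex-extracted and rescaled to [0, 1]. Per-batch requests run concurrently with asyncio.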
+def unifiedreward_score_sglang(device):
+ import asyncio
+ from openai import AsyncOpenAI
+ import base64
+ from io import BytesIO
+ import re
+
+ def pil_image_to_base64(image):
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ base64_qwen = f"data:image;base64,{encoded_image_text}"
+ return base64_qwen
+
+ def _extract_scores(text_outputs):
+ scores = []
+ pattern = r"Final Score:\s*([1-5](?:\.\d+)?)"
+ for text in text_outputs:
+ match = re.search(pattern, text)
+ if match:
+ try:
+ scores.append(float(match.group(1)))
+ except ValueError:
+ scores.append(0.0)
+ else:
+ scores.append(0.0)
+ return scores
+
+ client = AsyncOpenAI(base_url="http://127.0.0.1:17140/v1", api_key="flowgrpo")
+
+ async def evaluate_image(prompt, image):
+ question = f"\nYou are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n2. Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\nBased on the above criteria, assign a score from 1 to 5 after \'Final Score:\'.\nYour task is provided as follows:\nText Caption: [{prompt}]"
+ images_base64 = pil_image_to_base64(image)
+ response = await client.chat.completions.create(
+ model="UnifiedReward-7b-v1.5",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": images_base64},
+ },
+ {
+ "type": "text",
+ "text": question,
+ },
+ ],
+ },
+ ],
+ temperature=0,
+ )
+ return response.choices[0].message.content
+
+ async def evaluate_batch_image(images, prompts):
+ tasks = [evaluate_image(prompt, img) for prompt, img in zip(prompts, images)]
+ results = await asyncio.gather(*tasks)
+ return results
+
+ def _fn(images, prompts, metadata):
+        # Convert tensor input to uint8 numpy arrays (NCHW)
+ if isinstance(images, torch.Tensor):
+ images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+
+        # Convert to PIL Images and resize to 512x512
+ images = [Image.fromarray(image).resize((512, 512)) for image in images]
+
+        # Run the asynchronous batch evaluation
+ text_outputs = asyncio.run(evaluate_batch_image(images, prompts))
+ score = _extract_scores(text_outputs)
+ score = [sc/5.0 for sc in score]
+ return score, {}
+
+ return _fn
+
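+# Combines several reward functions into a single callable. `score_dict` maps reward
+# names to weights; the returned function reports each component under its own key and
+# the weighted sum under 'avg'. A minimal usage sketch (weights chosen for illustration):
+#   reward_fn = multi_score("cuda", {"pickscore": 0.5, "ocr": 0.5})
+#   details, _ = reward_fn(images, prompts, metadata)
+#   rewards = details['avg']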
+def multi_score(device, score_dict):
+ score_functions = {
+ "deqa": deqa_score_remote,
+ "ocr": ocr_score,
+ "video_ocr": video_ocr_score,
+ "imagereward": imagereward_score,
+ "pickscore": pickscore_score,
+ "qwenvl": qwenvl_score,
+ "aesthetic": aesthetic_score,
+ "jpeg_compressibility": jpeg_compressibility,
+ "unifiedreward": unifiedreward_score_sglang,
+ "geneval": geneval_score,
+ "clipscore": clip_score,
+ "image_similarity": image_similarity_score,
+ "image_similarity_eval": image_similarity_score_eval,
+ "constractive_external": constractive_external,
+ "discriminator": discriminator_score,
+ "pickscore_cotrain": pickscore_cotrain_score,
+ "pickscore_patch":pickscore_score_patch,
+ "dino_cotrain":dino_cotrain_score,
+ "dino_multi_cotrain": dino_multi_cotrain_score,
+ "dino_patch_cotrain": dino_patch_cotrain_score,
+ "siglip_cotrain": siglip_cotrain_score,
+ "siglip_image_similarity": siglip_image_similarity_score
+ }
+ score_fns={}
+ for score_name, weight in score_dict.items():
+        fn_factory = score_functions[score_name]
+        # factories that accept a `device` argument receive it; the rest take no arguments
+        score_fns[score_name] = fn_factory(device) if 'device' in fn_factory.__code__.co_varnames else fn_factory()
+
+ # only_strict is only for geneval. During training, only the strict reward is needed, and non-strict rewards don't need to be computed, reducing reward calculation time.
+ def _fn(images, prompts, metadata, scorer = None, ref_images=None, only_strict=True, head=None, fusion=None, layer_ids = None, temperature=0.2):
+ total_scores = []
+ score_details = {}
+
+ for score_name, weight in score_dict.items():
+ if score_name == "geneval":
+ scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](images, prompts, metadata, only_strict)
+ score_details['accuracy'] = rewards
+ score_details['strict_accuracy'] = strict_rewards
+ for key, value in group_strict_rewards.items():
+ score_details[f'{key}_strict_accuracy'] = value
+ for key, value in group_rewards.items():
+ score_details[f'{key}_accuracy'] = value
+ elif score_name == "image_similarity":
+ scores, rewards = score_fns[score_name](images, ref_images)
+ elif score_name == "siglip_image_similarity":
+ scores, rewards = score_fns[score_name](images, ref_images)
+ elif score_name == "image_similarity_eval":
+ scores, rewards, feat, ref_feat = score_fns[score_name](images, ref_images)
+ score_details['feat'] = feat
+ score_details['ref_feat'] = ref_feat
+ elif score_name == "constractive_external":
+            scores, rewards = score_fns[score_name](images, ref_images, prompts, metadata)
+ elif score_name == "discriminator":
+ scores, rewards = score_fns[score_name](scorer, images, prompts, ref_images)
+ elif score_name == "pickscore_cotrain":
+ scores, rewards = score_fns[score_name](scorer, images, prompts, metadata)
+ elif score_name == "dino_cotrain":
+ scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
+ elif score_name == "siglip_cotrain":
+ scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
+ elif score_name == "dino_multi_cotrain":
+ scores, rewards = score_fns[score_name](scorer, head, fusion, images, prompts, metadata, layer_ids, temperature)
+ elif score_name == "dino_patch_cotrain":
+ scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
+ elif score_name == "dinov3_patch_cotrain":
+ scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata)
+ else:
+ scores, rewards = score_fns[score_name](images, prompts, metadata)
+
+ score_details[score_name] = scores
+ weighted_scores = [weight * score for score in scores]
+
+ if not total_scores:
+ total_scores = weighted_scores
+ else:
+ total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)]
+
+ score_details['avg'] = total_scores
+ return score_details, {}
+
+ return _fn
+
+def main():
+ import torchvision.transforms as transforms
+
+ image_paths = [
+ "nasa.jpg",
+ ]
+
+ transform = transforms.Compose([
+ transforms.ToTensor(), # Convert to tensor
+ ])
+
+ images = torch.stack([transform(Image.open(image_path).convert('RGB')) for image_path in image_paths])
+ prompts=[
+        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+ ]
+ metadata = {} # Example metadata
+ score_dict = {
+ "unifiedreward": 1.0
+ }
+ # Initialize the multi_score function with a device and score_dict
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ scoring_fn = multi_score(device, score_dict)
+ # Get the scores
+ scores, _ = scoring_fn(images, prompts, metadata)
+ # Print the scores
+ print("Scores:", scores)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/adv_grpo/stat_tracking.py b/adv_grpo/stat_tracking.py
new file mode 100644
index 0000000000000000000000000000000000000000..98738b2953c3de5e9abb89554ad1ac551b2f6617
--- /dev/null
+++ b/adv_grpo/stat_tracking.py
@@ -0,0 +1,94 @@
+import numpy as np
+from collections import deque
+import torch
+
+class PerPromptStatTracker:
+ def __init__(self, global_std=False):
+ self.global_std = global_std
+ self.stats = {}
+ self.history_prompts = set()
+
+ def update(self, prompts, rewards, type='grpo'):
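+        # Advantage modes: 'grpo' normalizes each reward against its prompt's reward
+        # history, (r - mean) / (std + 1e-4), using the global batch std when
+        # global_std is set; 'rwr' passes rewards through unchanged; 'sft' places a
+        # one-hot on the best sample of each prompt group; 'dpo' marks the best sample
+        # with +1 and the worst with -1 within each group.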
+ prompts = np.array(prompts)
+ rewards = np.array(rewards, dtype=np.float64)
+ unique = np.unique(prompts)
+        advantages = np.zeros_like(rewards)
+
+ for prompt in unique:
+ prompt_rewards = rewards[prompts == prompt]
+ if prompt not in self.stats:
+ self.stats[prompt] = []
+ self.stats[prompt].extend(prompt_rewards)
+ self.history_prompts.add(hash(prompt)) # Add hash of prompt to history_prompts
+ for prompt in unique:
+ self.stats[prompt] = np.stack(self.stats[prompt])
+ prompt_rewards = rewards[prompts == prompt] # Fix: Recalculate prompt_rewards for each prompt
+ mean = np.mean(self.stats[prompt], axis=0, keepdims=True)
+ if self.global_std:
+ std = np.std(rewards, axis=0, keepdims=True) + 1e-4 # Use global std of all rewards
+ else:
+ std = np.std(self.stats[prompt], axis=0, keepdims=True) + 1e-4
+ if type=='grpo':
+ advantages[prompts == prompt] = (prompt_rewards - mean) / std
+ elif type=='rwr':
+ # advantages[prompts == prompt] = (prompt_rewards - mean) / std
+ advantages[prompts == prompt] = prompt_rewards
+ # advantages[prompts == prompt] = torch.softmax(torch.tensor(prompt_rewards), dim=0).numpy()
+ elif type=='sft':
+ advantages[prompts == prompt] = (torch.tensor(prompt_rewards) == torch.max(torch.tensor(prompt_rewards))).float().numpy()
+ elif type=='dpo':
+ # Get the advantages of the current prompt
+ prompt_advantages = torch.tensor(prompt_rewards)
+ # Find the indices of the maximum and minimum values
+ max_idx = torch.argmax(prompt_advantages)
+ min_idx = torch.argmin(prompt_advantages)
+ # If all rewards in a group are the same
+ if max_idx == min_idx:
+ min_idx = 0
+ max_idx = 1
+ result = torch.zeros_like(prompt_advantages).float()
+ # Set the maximum index to 1, minimum index to -1
+ result[max_idx] = 1.0
+ result[min_idx] = -1.0
+ advantages[prompts == prompt] = result.numpy()
+ # print("reward difference one group", prompt_advantages[max_idx]-prompt_advantages[min_idx])
+
+ return advantages
+
+ def get_stats(self):
+ avg_group_size = sum(len(v) for v in self.stats.values()) / len(self.stats) if self.stats else 0
+ history_prompts = len(self.history_prompts)
+ return avg_group_size, history_prompts
+
+ def clear(self):
+ self.stats = {}
+
+def main():
+ tracker = PerPromptStatTracker()
+ prompts = ['a', 'b', 'a', 'c', 'b', 'a']
+ rewards = [1, 2, 3, 4, 5, 6]
+ advantages = tracker.update(prompts, rewards)
+ print("Advantages:", advantages)
+ avg_group_size, history_prompts = tracker.get_stats()
+ print("Average Group Size:", avg_group_size)
+ print("History Prompts:", history_prompts)
+ tracker.clear()
+ print("Stats after clear:", tracker.stats)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..85a06606fe4b19e0513a861571441633d5f7ba4c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,223 @@
+import gradio as gr
+import spaces
+import torch
+from diffusers import StableDiffusion3Pipeline
+from adv_grpo.diffusers_patch.sd3_pipeline_with_logprob_fast import pipeline_with_logprob_random as pipeline_with_logprob
+from adv_grpo.diffusers_patch.train_dreambooth_lora_sd3 import encode_prompt
+from adv_grpo.ema import EMAModuleWrapper
+from peft import PeftModel
+from PIL import Image
+import numpy as np
+import os
+from ml_collections import config_flags
+from huggingface_hub import hf_hub_download
+from huggingface_hub import login
+
+login(os.environ["HF_TOKEN"])
+
+
+
+# ---------------------------------------------------------
+# GLOBAL VARIABLES
+# ---------------------------------------------------------
+
+pipeline = None
+config = None
+text_encoders = None
+tokenizers = None
+ema = None
+transformer_trainable_parameters = None
+
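+# Fetch the LoRA adapter files from the "PickScore" subfolder of the Hugging Face repo
+# into a local cache dir; hf_hub_download keeps the repo-relative path, so the adapter
+# ends up under /tmp/PickScore/PickScore.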
+def load_lora_from_subfolder():
+ repo_id = "benzweijia/Adv-GRPO"
+ subfolder = "PickScore"
+
+ local_dir = "/tmp/PickScore"
+ os.makedirs(local_dir, exist_ok=True)
+
+ for filename in ["adapter_config.json", "adapter_model.safetensors"]:
+ hf_hub_download(
+ repo_id=repo_id,
+ repo_type="model",
+ subfolder=subfolder,
+ filename=filename,
+ local_dir=local_dir,
+ force_download=False
+ )
+ return local_dir
+
+# -------------- Load Config ------------------------------
+def load_config():
+ """
+ """
+ import importlib.util
+
+ config_path = "config/base.py"
+ spec = importlib.util.spec_from_file_location("config", config_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module.get_config()
+
+
+# -------------- Embedding Function -----------------------
+def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders, tokenizers, prompt, max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(device)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+# ---------------------------------------------------------
+# GPU MODEL INITIALIZATION
+# ---------------------------------------------------------
+@spaces.GPU
+def init_model():
+ global pipeline, config, text_encoders, tokenizers, ema, transformer_trainable_parameters
+
+ print("🔥 Loading config...")
+ config = load_config()
+
+ print("🔥 Loading SD3 base model on GPU...")
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-medium"
+ )
+
+ # freeze non-trainable params
+ pipeline.vae.requires_grad_(False)
+ pipeline.text_encoder.requires_grad_(False)
+ pipeline.text_encoder_2.requires_grad_(False)
+ pipeline.text_encoder_3.requires_grad_(False)
+
+ pipeline.transformer.requires_grad_(not config.use_lora)
+
+ text_encoders = [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.text_encoder_3]
+ tokenizers = [pipeline.tokenizer, pipeline.tokenizer_2, pipeline.tokenizer_3]
+
+ pipeline.safety_checker = None
+ pipeline.set_progress_bar_config(disable=True)
+
+ # move to GPU
+ pipeline.vae.to("cuda")
+ pipeline.text_encoder.to("cuda")
+ pipeline.text_encoder_2.to("cuda")
+ pipeline.text_encoder_3.to("cuda")
+ pipeline.transformer.to("cuda")
+ config.train.lora_path = "benzweijia/Adv-GRPO/PickScore"
+ config.use_lora = True
+ lora_dir = load_lora_from_subfolder()
+
+ if config.use_lora and config.train.lora_path:
+ print("🔥 Loading LoRA from:", config.train.lora_path)
+ pipeline.transformer = PeftModel.from_pretrained(
+ pipeline.transformer,
+ os.path.join(lora_dir,"PickScore")
+ )
+ pipeline.transformer.set_adapter("default")
+
+ transformer_trainable_parameters = list(
+ filter(lambda p: p.requires_grad, pipeline.transformer.parameters())
+ )
+
+ # Setup EMA
+ ema = EMAModuleWrapper(
+ transformer_trainable_parameters,
+ decay=0.9,
+ update_step_interval=8,
+ device="cuda"
+ )
+
+ print("✅ Model initialized and ready.")
+
+
+# ---------------------------------------------------------
+# INFERENCE FUNCTION
+# ---------------------------------------------------------
+@spaces.GPU
+def infer(prompt):
+ print("start infer")
+ global pipeline, config
+ print(pipeline)
+
+ if pipeline is None:
+ init_model()
+ print(pipeline)
+ print("start infer 1111")
+
+
+ prompts = [prompt]
+
+ # get prompt embedding
+ prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers,
+ max_sequence_length=128,
+ device="cuda"
+ )
+ print("start infer 2")
+
+ neg_embed, neg_pooled_embed = compute_text_embeddings(
+ [""], text_encoders, tokenizers,
+ max_sequence_length=128,
+ device="cuda"
+ )
+
+ neg_prompt_embeds = neg_embed.repeat(1, 1, 1)
+ neg_pooled_prompt_embeds = neg_pooled_embed.repeat(1, 1)
+ print("start infer 3")
+
+ # generation seed
+ generator = torch.Generator().manual_seed(0)
+
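+    # The patched sampler returns the decoded images first, followed by additional
+    # sampling outputs that are only needed for training; only the images are used here.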
+ with torch.no_grad():
+ images, _, _, _ = pipeline_with_logprob(
+ pipeline,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_prompt_embeds=neg_prompt_embeds,
+ negative_pooled_prompt_embeds=neg_pooled_prompt_embeds,
+ num_inference_steps=config.sample.eval_num_steps,
+ guidance_scale=config.sample.guidance_scale,
+ output_type="pt",
+ height=config.resolution,
+ width=config.resolution,
+ noise_level=0,
+ mini_num_image_per_prompt=1,
+ process_index=0,
+ sample_num_steps=config.sample.num_steps,
+ random_timestep=0,
+ generator=generator,
+ )
+
+ print("images type:", type(images))
+ print("images len:", len(images))
+ print("first image shape:", images[0].shape)
+
+ # Convert to PIL
+ pil = Image.fromarray(
+ (images[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+ )
+
+ # Fixed 512x512 for output
+ pil = pil.resize((512, 512))
+
+ return pil
+
+
+# ---------------------------------------------------------
+# GRADIO UI
+# ---------------------------------------------------------
+# init_model()
+
+demo = gr.Interface(
+ fn=infer,
+ inputs=gr.Textbox(lines=2, label="Prompt"),
+ outputs=gr.Image(type="pil"),
+ title="Adv-GRPO(PickScore)",
+ description="Enter a prompt and generate image using Adv-GRPO",
+)
+
+demo.launch()
diff --git a/config/__pycache__/base.cpython-310.pyc b/config/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afa426d488add841064574435e65cd29d880463e
Binary files /dev/null and b/config/__pycache__/base.cpython-310.pyc differ
diff --git a/config/__pycache__/grpo.cpython-310.pyc b/config/__pycache__/grpo.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ed4ee8c72a05c1616160d2d549678ad06b89990
Binary files /dev/null and b/config/__pycache__/grpo.cpython-310.pyc differ
diff --git a/config/base.py b/config/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d969e684edff4fa4403af44ab1c5f2f6f98a88fd
--- /dev/null
+++ b/config/base.py
@@ -0,0 +1,113 @@
+import ml_collections
+
+
+def get_config():
+ config = ml_collections.ConfigDict()
+
+ ###### General ######
+ # run name for wandb logging and checkpoint saving -- if not provided, will be auto-generated based on the datetime.
+ config.run_name = ""
+ # random seed for reproducibility.
+ config.seed = 42
+ # top-level logging directory for checkpoint saving.
+ config.logdir = "logs"
+ # number of epochs between saving model checkpoints.
+ config.save_freq = 20
+ # number of epochs between evaluating the model.
+ config.eval_freq = 20
+ # number of checkpoints to keep before overwriting old ones.
+ config.num_checkpoint_limit = 5
+ # mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly.
+ config.mixed_precision = "fp16"
+ # allow tf32 on Ampere GPUs, which can speed up training.
+ config.allow_tf32 = True
+ # whether or not to use LoRA.
+ config.use_lora = True
+ config.dataset = ""
+ config.resolution = 768
+
+ ###### Pretrained Model ######
+ config.pretrained = pretrained = ml_collections.ConfigDict()
+ # base model to load. either a path to a local directory, or a model name from the HuggingFace model hub.
+ pretrained.model = "runwayml/stable-diffusion-v1-5"
+ # revision of the model to load.
+ pretrained.revision = "main"
+
+ ###### Sampling ######
+ config.sample = sample = ml_collections.ConfigDict()
+ # number of sampler inference steps for collecting dataset.
+ sample.num_steps = 40
+ # number of sampler inference steps for evaluation.
+ sample.eval_num_steps = 40
+ # classifier-free guidance weight. 1.0 is no guidance.
+ sample.guidance_scale = 4.5
+ # batch size (per GPU!) to use for sampling.
+ sample.train_batch_size = 1
+ sample.num_image_per_prompt = 1
+ sample.test_batch_size = 1
+ # number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch *
+ # batch_size * num_gpus`.
+ sample.num_batches_per_epoch = 2
+ # Whether use all samples in a batch to compute std
+ sample.global_std = True
+ # noise level
+ sample.noise_level = 0.7
+ # Whether to use the same noise for the same prompt
+ sample.same_latent = False
+
+ ###### Training ######
+ config.train = train = ml_collections.ConfigDict()
+ # batch size (per GPU!) to use for training.
+ train.batch_size = 1
+ # whether to use the 8bit Adam optimizer from bitsandbytes.
+ train.use_8bit_adam = False
+ # learning rate.
+ train.learning_rate = 3e-4
+ # Adam beta1.
+ train.adam_beta1 = 0.9
+ # Adam beta2.
+ train.adam_beta2 = 0.999
+ # Adam weight decay.
+ train.adam_weight_decay = 1e-4
+ # Adam epsilon.
+ train.adam_epsilon = 1e-8
+ # number of gradient accumulation steps. the effective batch size is `batch_size * num_gpus *
+ # gradient_accumulation_steps`.
+ train.gradient_accumulation_steps = 1
+ # maximum gradient norm for gradient clipping.
+ train.max_grad_norm = 1.0
+ # number of inner epochs per outer epoch. each inner epoch is one iteration through the data collected during one
+ # outer epoch's round of sampling.
+ train.num_inner_epochs = 1
+ # whether or not to use classifier-free guidance during training. if enabled, the same guidance scale used during
+ # sampling will be used during training.
+ train.cfg = True
+ # clip advantages to the range [-adv_clip_max, adv_clip_max].
+ train.adv_clip_max = 5
+ # the PPO clip range.
+ train.clip_range = 1e-4
+ # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
+ # timesteps for each sample. this will speed up training but reduce the accuracy of policy gradient estimates.
+ train.timestep_fraction = 1.0
+ # kl ratio
+ train.beta = 0.0
+ # pretrained lora path
+ train.lora_path = None
+ # save ema model
+ train.ema = False
+
+ ###### Prompt Function ######
+ # prompt function to use. see `prompts.py` for available prompt functions.
+ config.prompt_fn = "imagenet_animals"
+ # kwargs to pass to the prompt function.
+ config.prompt_fn_kwargs = {}
+
+ ###### Reward Function ######
+ # reward function to use. see `rewards.py` for available reward functions.
+ config.reward_fn = ml_collections.ConfigDict()
+ config.save_dir = ''
+
+ ###### Per-Prompt Stat Tracking ######
+ config.per_prompt_stat_tracking = True
+
+ return config
diff --git a/config/dpo.py b/config/dpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..96b2116a4d0ecb20808b7f540e6fdf37d5e96146
--- /dev/null
+++ b/config/dpo.py
@@ -0,0 +1,109 @@
+import ml_collections
+import imp
+import os
+
+base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
+
+def compressibility():
+ config = base.get_config()
+
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.use_lora = True
+
+ config.sample.batch_size = 8
+ config.sample.num_batches_per_epoch = 4
+
+ config.train.batch_size = 4
+ config.train.gradient_accumulation_steps = 2
+
+ # prompting
+ config.prompt_fn = "general_ocr"
+
+ # rewards
+ config.reward_fn = {"jpeg_compressibility": 1}
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+def geneval_sd3():
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/geneval")
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 40
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+ config.sample.train_batch_size = 24
+ config.sample.num_image_per_prompt = 24
+ config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 14 # Deliberate choice: the test set has 2212 samples, so gpu_num*bs*n should be as close to 2212 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+
+ config.train.algorithm = 'dpo'
+ # Change ref_update_step to a small number, e.g., 40, to switch to OnlineDPO.
+ config.train.ref_update_step=10000000
+ config.train.batch_size = config.sample.train_batch_size
+ config.train.gradient_accumulation_steps = 1
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.beta = 100
+ config.sample.global_std=True
+ config.train.ema=True
+ config.save_freq = 40 # epoch
+ config.eval_freq = 40
+ config.save_dir = 'logs/geneval/sd3.5-M-dpo'
+ config.reward_fn = {
+ "geneval": 1.0,
+ }
+
+ config.prompt_fn = "geneval"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+def pickscore_sd3():
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 40
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale=4.5
+
+ config.resolution = 512
+ config.sample.train_batch_size = 24
+ config.sample.num_image_per_prompt = 24
+ config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16 # Deliberate choice: the test set has 2048 samples, so gpu_num*bs*n should be as close to 2048 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+
+ config.train.algorithm = 'dpo'
+ # Change ref_update_step to a small number, e.g., 40, to switch to OnlineDPO.
+ config.train.ref_update_step=10000000
+
+ config.train.batch_size = config.sample.train_batch_size
+ config.train.gradient_accumulation_steps = 1
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.beta = 100
+ config.sample.global_std=True
+ config.train.ema=True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.save_dir = 'logs/pickscore/sd3.5-M-dpo'
+ config.reward_fn = {
+ "pickscore": 1.0,
+ }
+
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+def get_config(name):
+ return globals()[name]()
diff --git a/config/grpo.py b/config/grpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ccc08d2332895b7faded12c4645c4a9167bb6c
--- /dev/null
+++ b/config/grpo.py
@@ -0,0 +1,434 @@
+import ml_collections
+import imp
+import os
+
+base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
+
+def compressibility():
+ config = base.get_config()
+
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.use_lora = True
+
+ config.sample.batch_size = 8
+ config.sample.num_batches_per_epoch = 4
+
+ config.train.batch_size = 4
+ config.train.gradient_accumulation_steps = 2
+
+ # prompting
+ config.prompt_fn = "general_ocr"
+
+ # rewards
+ config.reward_fn = {"jpeg_compressibility": 1}
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+
+def dino_cotrain_sd3_fast():
+ gpu_number=8
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.mixed_precision = "bf16"
+ config.wandb_init = True
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 10
+ config.sample.train_num_steps = 2
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+    # fixed to 1 here
+ config.sample.train_batch_size = 1
+ config.sample.num_image_per_prompt = 16
+ config.sample.mini_num_image_per_prompt = 8
+ # config.sample.mini_num_image_per_prompt = 4
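+    # Rough arithmetic behind this value: with 8 GPUs, 8 mini-images per prompt per batch
+    # and 16 images per prompt group, each step covers 8*8/16 = 4 groups, so 48/4 = 12
+    # batches per epoch, i.e. roughly 48 full prompt groups are sampled per epoch.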
+ config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+ # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16 # Deliberate choice: the test set has 2048 samples, so gpu_num*bs*n should be as close to 2048 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+ config.sample.random_timestep = 0
+
+
+ config.train.batch_size = config.sample.mini_num_image_per_prompt
+ config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.clip_range = 1e-5
+ config.train.beta = 0
+ config.sample.global_std = True
+ config.sample.noise_level = 0.8
+ config.train.ema = True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.discriminator = "pickscore"
+ config.d_times=10
+ config.d_lr=1e-4
+ config.train.lora_path = None
+ config.tune_layer=-2
+
+ # config.use_lora = False
+ # config.train.learning_rate = 1e-5
+ # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+ config.train_d = True
+ config.weight_path = None
+ config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+ config.external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+ config.test_external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+ config.case_name = "fast_dino_cotrain_16_8_lr_times_10_1e4_new_loss_24_9_preprocess"
+ config.save_dir = 'logs/dino/sd3.5-M-fast_dino_cotrain_16_8_lr_times_10_1e4_new_loss_16_8_preprocess'
+ # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+ config.reward_fn = {
+ "dino_cotrain":1,
+ }
+ config.eval_reward_fn = {
+ "pickscore":1,
+ "image_similarity": 1
+ }
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+
+def dino_cotrain_sd3_patch_fast():
+ gpu_number=8
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.mixed_precision = "bf16"
+ config.wandb_init = True
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 10
+ config.sample.train_num_steps = 2
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+    # fixed to 1 here
+ config.sample.train_batch_size = 1
+ config.sample.num_image_per_prompt = 16
+ config.sample.mini_num_image_per_prompt = 8
+ # config.sample.mini_num_image_per_prompt = 4
+ config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+ # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16 # Deliberate choice: the test set has 2048 samples, so gpu_num*bs*n should be as close to 2048 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+ config.sample.random_timestep = 0
+
+
+ config.train.batch_size = config.sample.mini_num_image_per_prompt
+ config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.clip_range = 1e-5
+ config.train.beta = 0
+ config.sample.global_std = True
+ config.sample.noise_level = 0.8
+ config.train.ema = True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.discriminator = "pickscore"
+ config.d_times=10
+ config.d_lr=1e-4
+ config.train.lora_path = None
+ config.tune_layer=-2
+
+ # config.use_lora = False
+ # config.train.learning_rate = 1e-5
+ # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+ config.train_d = True
+ config.weight_path = None
+ config.limit = None
+ config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+ config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+ config.test_reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+ # config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_geneval.json"
+ # config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_geneval_multinode2"
+ # config.test_reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_ocr_test"
+
+ config.case_name = "fast_dino_cotrain_16_8_lr_times_10_1e4_patch_image_loss_73_again"
+ config.save_dir = 'logs/dino/sd3.5-M-fast_dino_cotrain_16_8_lr_times_10_1e4_patch_image_loss_73_again'
+ # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+ config.reward_fn = {
+ "dino_patch_cotrain":1,
+ }
+ config.eval_reward_fn = {
+ "pickscore":1,
+ "image_similarity": 1
+ }
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+def dino_cotrain_sd3_multi_fast():
+ gpu_number=8
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.mixed_precision = "bf16"
+ config.wandb_init = False
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 10
+ config.sample.train_num_steps = 2
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+    # fixed to 1 here
+ config.sample.train_batch_size = 1
+ config.sample.num_image_per_prompt = 8
+ config.sample.mini_num_image_per_prompt = 8
+ # config.sample.mini_num_image_per_prompt = 4
+ config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+ # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16 # Deliberate choice: the test set has 2048 samples, so gpu_num*bs*n should be as close to 2048 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+ config.sample.random_timestep = 0
+
+
+ config.train.batch_size = config.sample.mini_num_image_per_prompt
+ config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.clip_range = 1e-5
+ config.train.beta = 0.0
+ config.sample.global_std = True
+ config.sample.noise_level = 0.8
+ config.train.ema = True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.discriminator = "pickscore"
+ config.d_times=10
+ config.d_lr=1e-4
+ config.train.lora_path = None
+ config.tune_layer=(11,)
+ config.temperature = 2
+
+ # config.use_lora = False
+ # config.train.learning_rate = 1e-5
+ # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+ config.train_d = True
+ config.weight_path = None
+ config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+ config.external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+ config.test_external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+ config.case_name = "fast_dino_cotrain_16_8_lr_times_10_1e4_multi_image_loss_11_only_patch3_tem_2"
+ config.save_dir = 'logs/dino/sd3.5-M-fast_dino_cotrain_16_8_lr_times_10_1e4_multi_image_loss_11_only_patch3_tem_2'
+ # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+ config.reward_fn = {
+ "dino_multi_cotrain":1,
+ }
+ config.eval_reward_fn = {
+ "pickscore":1,
+ "image_similarity": 1
+ }
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+def eval_sd3_fast():
+ gpu_number=8
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.mixed_precision = "bf16"
+ config.wandb_init = False
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 10
+ config.sample.train_num_steps = 2
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+    # fixed to 1 here
+ config.sample.train_batch_size = 1
+ config.sample.num_image_per_prompt = 8
+ config.sample.mini_num_image_per_prompt = 8
+ # config.sample.mini_num_image_per_prompt = 4
+ config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+ # config.sample.num_batches_per_epoch = 1
+ # config.sample.test_batch_size = 16 # This bs is a special design, the test set has a total of 2048, to make gpu_num*bs*n as close as possible to 2048, because when the number of samples cannot be divided evenly by the number of cards, multi-card will fill the last batch to ensure each card has the same number of samples, affecting gradient synchronization.
+
+ config.sample.test_batch_size = 16
+ config.sample.repeat = 1
+
+ config.sample.random_timestep = 0
+
+
+ config.train.batch_size = config.sample.mini_num_image_per_prompt
+ config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.clip_range = 1e-5
+ config.train.beta = 0.0
+ config.sample.global_std = True
+ config.sample.noise_level = 0.8
+ config.train.ema = True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.discriminator = "pickscore"
+ config.d_times=10
+ config.d_lr=1e-4
+ config.tune_layer=-2
+
+ config.train.lora_path = ""
+ config.save_folder = "/mnt/bn/vgfm2/test_dit/weijia/outputs_flowgrpo_test2/sd3_dino_pickscore_test_1"
+ config.train_d = True
+ config.weight_path = None
+ config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+ config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+ config.test_reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+ config.reward_fn = {
+ "dino_cotrain":1,
+ }
+ config.eval_reward_fn = {
+ "pickscore":1,
+ }
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+
+def pickscore_cotrain_sd3_fast():
+ gpu_number=8
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.mixed_precision = "bf16"
+ config.wandb_init = True
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 10
+ config.sample.train_num_steps = 2
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+    # fixed to 1 here
+ config.sample.train_batch_size = 1
+ config.sample.num_image_per_prompt = 16
+ config.sample.mini_num_image_per_prompt = 8
+ # config.sample.mini_num_image_per_prompt = 4
+ config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+ # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16 # Deliberate choice: the test set has 2048 samples, so gpu_num*bs*n should be as close to 2048 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+ config.sample.random_timestep = 0
+
+ config.train.batch_size = config.sample.mini_num_image_per_prompt
+ config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.clip_range = 1e-5
+ config.train.beta = 0.0
+ config.sample.global_std = True
+ config.sample.noise_level = 0.8
+ config.train.ema = True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.discriminator = "pickscore"
+ config.d_times=20
+ config.d_lr=5e-6
+ config.train.lora_path = None
+ config.tune_layer=-1
+ # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+ config.train_d = True
+ config.weight_path = None
+ config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+ config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+
+ config.case_name = "fast_pickscore_cotrain_lr_5e6_last1_16_8"
+ config.save_dir = 'logs/pickscore/sd3.5-M-fast_pickscore_cotrain_lr_5e6_last1_16_8'
+ # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+ config.reward_fn = {
+ "pickscore_cotrain":1,
+ }
+ config.eval_reward_fn = {
+ "pickscore":1
+ }
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+def pickscore_sd3_fast():
+ gpu_number=8
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/ocr")
+
+ config.mixed_precision = "bf16"
+ config.case_name = "fast_1node_16_8_multireward_11"
+ config.wandb_init = True
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 10
+ config.sample.train_num_steps = 2
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+    # fixed to 1 here
+ config.sample.train_batch_size = 1
+ config.sample.num_image_per_prompt = 16
+ config.sample.mini_num_image_per_prompt = 8
+ # config.sample.mini_num_image_per_prompt = 4
+ config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+ # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16 # Deliberate choice: the test set has 2048 samples, so gpu_num*bs*n should be as close to 2048 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+ config.sample.random_timestep = None
+
+ config.train.batch_size = config.sample.mini_num_image_per_prompt
+ config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.clip_range = 1e-5
+ config.train.beta = 0.0
+ config.sample.global_std = True
+ config.sample.noise_level = 0.8
+ config.train.ema = True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.save_dir = 'logs/pickscore_again/sd3.5-M-fast_1node_16_8_multireward_11_ocr_pickscore'
+ config.external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images"
+ config.reward_fn = {
+ "pickscore": 0.5,
+ "ocr": 0.5,
+ }
+
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+
+
+def get_config(name):
+ return globals()[name]()
+
diff --git a/config/sft.py b/config/sft.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fefd840b84ad156f9e611b598acc5beb543c681
--- /dev/null
+++ b/config/sft.py
@@ -0,0 +1,109 @@
+import ml_collections
+import imp
+import os
+
+base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
+
+def compressibility():
+ config = base.get_config()
+
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ config.use_lora = True
+
+ config.sample.batch_size = 8
+ config.sample.num_batches_per_epoch = 4
+
+ config.train.batch_size = 4
+ config.train.gradient_accumulation_steps = 2
+
+ # prompting
+ config.prompt_fn = "general_ocr"
+
+ # rewards
+ config.reward_fn = {"jpeg_compressibility": 1}
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+def geneval_sd3():
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/geneval")
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 40
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale = 4.5
+
+ config.resolution = 512
+ config.sample.train_batch_size = 24
+ config.sample.num_image_per_prompt = 24
+ config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 14 # Deliberate choice: the test set has 2212 samples, so gpu_num*bs*n should be as close to 2212 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+
+ config.train.algorithm = 'sft'
+ # Change ref_update_step to a small number, e.g., 40, to switch to OnlineSFT.
+ config.train.ref_update_step=10000000
+ config.train.batch_size = config.sample.train_batch_size
+ config.train.gradient_accumulation_steps = 1
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.beta = 100
+ config.sample.global_std=True
+ config.train.ema=True
+ config.save_freq = 40 # epoch
+ config.eval_freq = 40
+ config.save_dir = 'logs/geneval/sd3.5-M-sft'
+ config.reward_fn = {
+ "geneval": 1.0,
+ }
+
+ config.prompt_fn = "geneval"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+def pickscore_sd3():
+ config = compressibility()
+ config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+ # sd3.5 medium
+ config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+ config.sample.num_steps = 40
+ config.sample.eval_num_steps = 40
+ config.sample.guidance_scale=4.5
+
+ config.resolution = 512
+ config.sample.train_batch_size = 24
+ config.sample.num_image_per_prompt = 24
+ config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16 # Deliberate choice: the test set has 2048 samples, so gpu_num*bs*n should be as close to 2048 as possible; when the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU sees the same number of samples, which would otherwise distort gradient synchronization.
+
+ config.train.algorithm = 'sft'
+ # Change ref_update_step to a small number, e.g., 40, to switch to OnlineSFT.
+ config.train.ref_update_step=10000000
+
+ config.train.batch_size = config.sample.train_batch_size
+ config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+ config.train.num_inner_epochs = 1
+ config.train.timestep_fraction = 0.99
+ config.train.beta = 100
+ config.sample.global_std=True
+ config.train.ema=True
+ config.save_freq = 60 # epoch
+ config.eval_freq = 60
+ config.save_dir = 'logs/pickscore/sd3.5-M-sft'
+ config.reward_fn = {
+ "pickscore": 1.0,
+ }
+
+ config.prompt_fn = "general_ocr"
+
+ config.per_prompt_stat_tracking = True
+ return config
+
+
+def get_config(name):
+ return globals()[name]()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1cab484b32586f8bb0b14e572bf5ccd09d08a694
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+torch==2.6.0
+torchvision==0.21.0
+transformers==4.54.0
+diffusers==0.33.1
+accelerate>=0.25.0
+peft>=0.6.2
+safetensors
+numpy==1.26.4
+Pillow
+gradio>=4.12.0
+spaces>=0.24.0
+tqdm
+ml_collections
+huggingface-hub
+sentencepiece