diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c181295ffbb693822919fab3b78e7fa27fcb6405 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+adv_grpo/assets/flow_grpo_fast.png filter=lfs diff=lfs merge=lfs -text
diff --git a/adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc b/adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bd14ec63f59adaf6c2a71cc495f56561d661c24
Binary files /dev/null and b/adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/discriminator.cpython-310.pyc b/adv_grpo/__pycache__/discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..057eed8804426cead5d24583ac1478edb69eecf7
Binary files /dev/null and b/adv_grpo/__pycache__/discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/ema.cpython-310.pyc b/adv_grpo/__pycache__/ema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a67a9a31ba71037e9c61adbb849f49e3a7f3c75a
Binary files /dev/null and b/adv_grpo/__pycache__/ema.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc b/adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..113135015d4be509379b6f852fa5824bfe4788c3
Binary files /dev/null and b/adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/inflated_layers.cpython-310.pyc b/adv_grpo/__pycache__/inflated_layers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3af95ea15e99987014cd96f50a7fc1cf3470770
Binary files /dev/null and b/adv_grpo/__pycache__/inflated_layers.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/inflated_lib.cpython-310.pyc b/adv_grpo/__pycache__/inflated_lib.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68a7a54674316b88d1c1c1bdd19efa849f255a3a
Binary files /dev/null and b/adv_grpo/__pycache__/inflated_lib.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/ocr.cpython-310.pyc b/adv_grpo/__pycache__/ocr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6a42af366cef7b37766039eac170c00c23088b5
Binary files /dev/null and b/adv_grpo/__pycache__/ocr.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc b/adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfb4ede079452fc65682fb9e675bb5f4f600e60f
Binary files /dev/null and b/adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/pick_score_training.cpython-310.pyc b/adv_grpo/__pycache__/pick_score_training.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63a42d7a8696445e14f5bcd4b110194f7b330a24
Binary files /dev/null and b/adv_grpo/__pycache__/pick_score_training.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc b/adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d754ee400786b3fb7965c84727681f534026259f
Binary files /dev/null and b/adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/prompts.cpython-310.pyc b/adv_grpo/__pycache__/prompts.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3797ed3b5ae1faf3ec861941368269f13477898e
Binary files /dev/null and b/adv_grpo/__pycache__/prompts.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/rewards.cpython-310.pyc b/adv_grpo/__pycache__/rewards.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..401e17e811fcb9a6c0373c6240848832adef7e02
Binary files /dev/null and b/adv_grpo/__pycache__/rewards.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/stat_tracking.cpython-310.pyc b/adv_grpo/__pycache__/stat_tracking.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3846c5741ce33be3af8421c5c31327110fa5d77
Binary files /dev/null and b/adv_grpo/__pycache__/stat_tracking.cpython-310.pyc differ
diff --git a/adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc b/adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b05c6efd0e877da5d4cca2b63ff9221cce32817
Binary files /dev/null and b/adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc differ
diff --git a/adv_grpo/aesthetic_scorer.py b/adv_grpo/aesthetic_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff364c45848fa4a8e9e121b6cae755b90056b64e
--- /dev/null
+++ b/adv_grpo/aesthetic_scorer.py
@@ -0,0 +1,53 @@
+# Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py
+
+from importlib import resources
+import torch
+import torch.nn as nn
+import numpy as np
+from transformers import CLIPModel, CLIPProcessor
+from PIL import Image
+
+ASSETS_PATH = resources.files("adv_grpo.assets")
+
+
+class MLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(768, 1024),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            nn.Dropout(0.1),
+            nn.Linear(64, 16),
+            nn.Linear(16, 1),
+        )
+
+    @torch.no_grad()
+    def forward(self, embed):
+        return self.layers(embed)
+
+
+class AestheticScorer(torch.nn.Module):
+    def __init__(self, dtype):
+        super().__init__()
+        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        self.mlp = MLP()
+        state_dict = torch.load(
+            ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")
+        )
+        self.mlp.load_state_dict(state_dict)
+        self.dtype = dtype
+        self.eval()
+
+    @torch.no_grad()
+    def __call__(self, images):
+        device = next(self.parameters()).device
+        inputs = self.processor(images=images, return_tensors="pt")
+        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
+        embed = self.clip.get_image_features(**inputs)
+        # normalize embedding
+        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
+        return self.mlp(embed).squeeze(1)
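Note: a minimal usage sketch for the `AestheticScorer` added above (not part of the patch). It assumes the repo is importable and the LFS asset `sac+logos+ava1-l14-linearMSE.pth` has been pulled; the image path reuses the `adv_grpo/assets/test.jpg` file added later in this diff.

```python
import torch
from PIL import Image
from adv_grpo.aesthetic_scorer import AestheticScorer

# Any list of PIL images works; the CLIPProcessor inside the scorer
# handles resizing and normalization.
images = [Image.open("adv_grpo/assets/test.jpg")]

device = "cuda" if torch.cuda.is_available() else "cpu"
scorer = AestheticScorer(dtype=torch.float32).to(device)
scores = scorer(images)  # tensor of shape (N,); higher means more aesthetic
print(scores)
```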
diff --git a/adv_grpo/assets/activities.txt b/adv_grpo/assets/activities.txt
new file mode 100644
index 0000000000000000000000000000000000000000..abea0458a5836b50ec85da2f732ff3a7d63b8c3a
--- /dev/null
+++ b/adv_grpo/assets/activities.txt
@@ -0,0 +1,3 @@
+washing the dishes
+riding a bike
+playing chess
\ No newline at end of file
diff --git a/adv_grpo/assets/activities_v0.txt b/adv_grpo/assets/activities_v0.txt
new file mode 100644
index 0000000000000000000000000000000000000000..abea0458a5836b50ec85da2f732ff3a7d63b8c3a
--- /dev/null
+++ b/adv_grpo/assets/activities_v0.txt
@@ -0,0 +1,3 @@
+washing the dishes
+riding a bike
+playing chess
\ No newline at end of file
diff --git a/adv_grpo/assets/flow_grpo_fast.png b/adv_grpo/assets/flow_grpo_fast.png
new file mode 100644
index 0000000000000000000000000000000000000000..0de32b018bd203d0f7bb735017796cc25989ab0e
--- /dev/null
+++ b/adv_grpo/assets/flow_grpo_fast.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35709d674818e29d39728e036479e51b8e015bcc7caf4dc54eb3e4f41cc05ab1
+size 222022
diff --git a/adv_grpo/assets/imagenet_classes.txt b/adv_grpo/assets/imagenet_classes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..722c984560c36a6014113deb8106a6cf50050ad5
--- /dev/null
+++ b/adv_grpo/assets/imagenet_classes.txt
@@ -0,0 +1,1000 @@
+tench, Tinca tinca
+goldfish, Carassius auratus
+great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
+tiger shark, Galeocerdo cuvieri
+hammerhead, hammerhead shark
+electric ray, crampfish, numbfish, torpedo
+stingray
+cock
+hen
+ostrich, Struthio camelus
+brambling, Fringilla montifringilla
+goldfinch, Carduelis carduelis
+house finch, linnet, Carpodacus mexicanus
+junco, snowbird
+indigo bunting, indigo finch, indigo bird, Passerina cyanea
+robin, American robin, Turdus migratorius
+bulbul
+jay
+magpie
+chickadee
+water ouzel, dipper
+kite
+bald eagle, American eagle, Haliaeetus leucocephalus
+vulture
+great grey owl, great gray owl, Strix nebulosa
+European fire salamander, Salamandra salamandra
+common newt, Triturus vulgaris
+eft
+spotted salamander, Ambystoma maculatum
+axolotl, mud puppy, Ambystoma mexicanum
+bullfrog, Rana catesbeiana
+tree frog, tree-frog
+tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
+loggerhead, loggerhead turtle, Caretta caretta
+leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
+mud turtle
+terrapin
+box turtle, box tortoise
+banded gecko
+common iguana, iguana, Iguana iguana
+American chameleon, anole, Anolis carolinensis
+whiptail, whiptail lizard
+agama
+frilled lizard, Chlamydosaurus kingi
+alligator lizard
+Gila monster, Heloderma suspectum
+green lizard, Lacerta viridis
+African chameleon, Chamaeleo chamaeleon
+Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
+African crocodile, Nile crocodile, Crocodylus niloticus
+American alligator, Alligator mississipiensis
+triceratops
+thunder snake, worm snake, Carphophis amoenus
+ringneck snake, ring-necked snake, ring snake
+hognose snake, puff adder, sand viper
+green snake, grass snake
+king snake, kingsnake
+garter snake, grass snake
+water snake
+vine snake
+night snake, Hypsiglena torquata
+boa constrictor, Constrictor constrictor
+rock python, rock snake, Python sebae
+Indian cobra, Naja naja
+green mamba
+sea snake
+horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
+diamondback, diamondback rattlesnake, Crotalus adamanteus
+sidewinder, horned rattlesnake, Crotalus cerastes
+trilobite
+harvestman, daddy longlegs, Phalangium opilio
+scorpion
+black and gold garden spider, Argiope aurantia
+barn spider, Araneus cavaticus
+garden spider, Aranea diademata
+black widow, Latrodectus mactans
+tarantula
+wolf spider, hunting spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse, partridge, Bonasa umbellus
+prairie chicken, prairie grouse, prairie fowl
+peacock
+quail
+partridge
+African grey, African gray, Psittacus erithacus
+macaw
+sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser, Mergus serrator
+goose
+black swan, Cygnus atratus
+tusker
+echidna, spiny anteater, anteater
+platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
+wallaby, brush kangaroo
+koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
+wombat
+jellyfish
+sea anemone, anemone
+brain coral
+flatworm, platyhelminth
+nematode, nematode worm, roundworm
+conch
+snail
+slug
+sea slug, nudibranch
+chiton, coat-of-mail shell, sea cradle, polyplacophore
+chambered nautilus, pearly nautilus, nautilus
+Dungeness crab, Cancer magister
+rock crab, Cancer irroratus
+fiddler crab
+king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
+American lobster, Northern lobster, Maine lobster, Homarus americanus
+spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
+crayfish, crawfish, crawdad, crawdaddy
+hermit crab
+isopod
+white stork, Ciconia ciconia
+black stork, Ciconia nigra
+spoonbill
+flamingo
+little blue heron, Egretta caerulea
+American egret, great white heron, Egretta albus
+bittern
+crane
+limpkin, Aramus pictus
+European gallinule, Porphyrio porphyrio
+American coot, marsh hen, mud hen, water hen, Fulica americana
+bustard
+ruddy turnstone, Arenaria interpres
+red-backed sandpiper, dunlin, Erolia alpina
+redshank, Tringa totanus
+dowitcher
+oystercatcher, oyster catcher
+pelican
+king penguin, Aptenodytes patagonica
+albatross, mollymawk
+grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
+killer whale, killer, orca, grampus, sea wolf, Orcinus orca
+dugong, Dugong dugon
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog, Maltese terrier, Maltese
+Pekinese, Pekingese, Peke
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound, Afghan
+basset, basset hound
+beagle
+bloodhound, sleuthhound
+bluetick
+black-and-tan coonhound
+Walker hound, Walker foxhound
+English foxhound
+redbone
+borzoi, Russian wolfhound
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound, Ibizan Podenco
+Norwegian elkhound, elkhound
+otterhound, otter hound
+Saluki, gazelle hound
+Scottish deerhound, deerhound
+Weimaraner
+Staffordshire bullterrier, Staffordshire bull terrier
+American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier, Sealyham
+Airedale, Airedale terrier
+cairn, cairn terrier
+Australian terrier
+Dandie Dinmont, Dandie Dinmont terrier
+Boston bull, Boston terrier
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier, Scottish terrier, Scottie
+Tibetan terrier, chrysanthemum dog
+silky terrier, Sydney silky
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa, Lhasa apso
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla, Hungarian pointer
+English setter
+Irish setter, red setter
+Gordon setter
+Brittany spaniel
+clumber, clumber spaniel
+English springer, English springer spaniel
+Welsh springer spaniel
+cocker spaniel, English cocker spaniel, cocker
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog, bobtail
+Shetland sheepdog, Shetland sheep dog, Shetland
+collie
+Border collie
+Bouvier des Flandres, Bouviers des Flandres
+Rottweiler
+German shepherd, German shepherd dog, German police dog, alsatian
+Doberman, Doberman pinscher
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard, St Bernard
+Eskimo dog, husky
+malamute, malemute, Alaskan malamute
+Siberian husky
+dalmatian, coach dog, carriage dog
+affenpinscher, monkey pinscher, monkey dog
+basenji
+pug, pug-dog
+Leonberg
+Newfoundland, Newfoundland dog
+Great Pyrenees
+Samoyed, Samoyede
+Pomeranian
+chow, chow chow
+keeshond
+Brabancon griffon
+Pembroke, Pembroke Welsh corgi
+Cardigan, Cardigan Welsh corgi
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf, grey wolf, gray wolf, Canis lupus
+white wolf, Arctic wolf, Canis lupus tundrarum
+red wolf, maned wolf, Canis rufus, Canis niger
+coyote, prairie wolf, brush wolf, Canis latrans
+dingo, warrigal, warragal, Canis dingo
+dhole, Cuon alpinus
+African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
+hyena, hyaena
+red fox, Vulpes vulpes
+kit fox, Vulpes macrotis
+Arctic fox, white fox, Alopex lagopus
+grey fox, gray fox, Urocyon cinereoargenteus
+tabby, tabby cat
+tiger cat
+Persian cat
+Siamese cat, Siamese
+Egyptian cat
+cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
+lynx, catamount
+leopard, Panthera pardus
+snow leopard, ounce, Panthera uncia
+jaguar, panther, Panthera onca, Felis onca
+lion, king of beasts, Panthera leo
+tiger, Panthera tigris
+cheetah, chetah, Acinonyx jubatus
+brown bear, bruin, Ursus arctos
+American black bear, black bear, Ursus americanus, Euarctos americanus
+ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
+sloth bear, Melursus ursinus, Ursus ursinus
+mongoose
+meerkat, mierkat
+tiger beetle
+ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
+ground beetle, carabid beetle
+long-horned beetle, longicorn, longicorn beetle
+leaf beetle, chrysomelid
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant, emmet, pismire
+grasshopper, hopper
+cricket
+walking stick, walkingstick, stick insect
+cockroach, roach
+mantis, mantid
+cicada, cicala
+leafhopper
+lacewing, lacewing fly
+dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
+damselfly
+admiral
+ringlet, ringlet butterfly
+monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
+cabbage butterfly
+sulphur butterfly, sulfur butterfly
+lycaenid, lycaenid butterfly
+starfish, sea star
+sea urchin
+sea cucumber, holothurian
+wood rabbit, cottontail, cottontail rabbit
+hare
+Angora, Angora rabbit
+hamster
+porcupine, hedgehog
+fox squirrel, eastern fox squirrel, Sciurus niger
+marmot
+beaver
+guinea pig, Cavia cobaya
+sorrel
+zebra
+hog, pig, grunter, squealer, Sus scrofa
+wild boar, boar, Sus scrofa
+warthog
+hippopotamus, hippo, river horse, Hippopotamus amphibius
+ox
+water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
+bison
+ram, tup
+bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
+ibex, Capra ibex
+hartebeest
+impala, Aepyceros melampus
+gazelle
+Arabian camel, dromedary, Camelus dromedarius
+llama
+weasel
+mink
+polecat, fitch, foulmart, foumart, Mustela putorius
+black-footed ferret, ferret, Mustela nigripes
+otter
+skunk, polecat, wood pussy
+badger
+armadillo
+three-toed sloth, ai, Bradypus tridactylus
+orangutan, orang, orangutang, Pongo pygmaeus
+gorilla, Gorilla gorilla
+chimpanzee, chimp, Pan troglodytes
+gibbon, Hylobates lar
+siamang, Hylobates syndactylus, Symphalangus syndactylus
+guenon, guenon monkey
+patas, hussar monkey, Erythrocebus patas
+baboon
+macaque
+langur
+colobus, colobus monkey
+proboscis monkey, Nasalis larvatus
+marmoset
+capuchin, ringtail, Cebus capucinus
+howler monkey, howler
+titi, titi monkey
+spider monkey, Ateles geoffroyi
+squirrel monkey, Saimiri sciureus
+Madagascar cat, ring-tailed lemur, Lemur catta
+indri, indris, Indri indri, Indri brevicaudatus
+Indian elephant, Elephas maximus
+African elephant, Loxodonta africana
+lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
+giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
+barracouta, snoek
+eel
+coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
+rock beauty, Holocanthus tricolor
+anemone fish
+sturgeon
+gar, garfish, garpike, billfish, Lepisosteus osseus
+lionfish
+puffer, pufferfish, blowfish, globefish
+abacus
+abaya
+academic gown, academic robe, judge's robe
+accordion, piano accordion, squeeze box
+acoustic guitar
+aircraft carrier, carrier, flattop, attack aircraft carrier
+airliner
+airship, dirigible
+altar
+ambulance
+amphibian, amphibious vehicle
+analog clock
+apiary, bee house
+apron
+ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
+assault rifle, assault gun
+backpack, back pack, knapsack, packsack, rucksack, haversack
+bakery, bakeshop, bakehouse
+balance beam, beam
+balloon
+ballpoint, ballpoint pen, ballpen, Biro
+Band Aid
+banjo
+bannister, banister, balustrade, balusters, handrail
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel, cask
+barrow, garden cart, lawn cart, wheelbarrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap, swimming cap
+bath towel
+bathtub, bathing tub, bath, tub
+beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
+beacon, lighthouse, beacon light, pharos
+beaker
+bearskin, busby, shako
+beer bottle
+beer glass
+bell cote, bell cot
+bib
+bicycle-built-for-two, tandem bicycle, tandem
+bikini, two-piece
+binder, ring-binder
+binoculars, field glasses, opera glasses
+birdhouse
+boathouse
+bobsled, bobsleigh, bob
+bolo tie, bolo, bola tie, bola
+bonnet, poke bonnet
+bookcase
+bookshop, bookstore, bookstall
+bottlecap
+bow
+bow tie, bow-tie, bowtie
+brass, memorial tablet, plaque
+brassiere, bra, bandeau
+breakwater, groin, groyne, mole, bulwark, seawall, jetty
+breastplate, aegis, egis
+broom
+bucket, pail
+buckle
+bulletproof vest
+bullet train, bullet
+butcher shop, meat market
+cab, hack, taxi, taxicab
+caldron, cauldron
+candle, taper, wax light
+cannon
+canoe
+can opener, tin opener
+cardigan
+car mirror
+carousel, carrousel, merry-go-round, roundabout, whirligig
+carpenter's kit, tool kit
+carton
+car wheel
+cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello, violoncello
+cellular telephone, cellular phone, cellphone, cell, mobile phone
+chain
+chainlink fence
+chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
+chain saw, chainsaw
+chest
+chiffonier, commode
+chime, bell, gong
+china cabinet, china closet
+Christmas stocking
+church, church building
+cinema, movie theater, movie theatre, movie house, picture palace
+cleaver, meat cleaver, chopper
+cliff dwelling
+cloak
+clog, geta, patten, sabot
+cocktail shaker
+coffee mug
+coffeepot
+coil, spiral, volute, whorl, helix
+combination lock
+computer keyboard, keypad
+confectionery, confectionary, candy store
+container ship, containership, container vessel
+convertible
+corkscrew, bottle screw
+cornet, horn, trumpet, trump
+cowboy boot
+cowboy hat, ten-gallon hat
+cradle
+crane
+crash helmet
+crate
+crib, cot
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam, dike, dyke
+desk
+desktop computer
+dial telephone, dial phone
+diaper, nappy, napkin
+digital clock
+digital watch
+dining table, board
+dishrag, dishcloth
+dishwasher, dish washer, dishwashing machine
+disk brake, disc brake
+dock, dockage, docking facility
+dogsled, dog sled, dog sleigh
+dome
+doormat, welcome mat
+drilling platform, offshore rig
+drum, membranophone, tympan
+drumstick
+dumbbell
+Dutch oven
+electric fan, blower
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa, boa
+file, file cabinet, filing cabinet
+fireboat
+fire engine, fire truck
+fire screen, fireguard
+flagpole, flagstaff
+flute, transverse flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn, horn
+frying pan, frypan, skillet
+fur coat
+garbage truck, dustcart
+gasmask, respirator, gas helmet
+gas pump, gasoline pump, petrol pump, island dispenser
+goblet
+go-kart
+golf ball
+golfcart, golf cart
+gondola
+gong, tam-tam
+gown
+grand piano, grand
+greenhouse, nursery, glasshouse
+grille, radiator grille
+grocery store, grocery, food market, market
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower, blow dryer, blow drier, hair dryer, hair drier
+hand-held computer, hand-held microcomputer
+handkerchief, hankie, hanky, hankey
+hard disc, hard disk, fixed disk
+harmonica, mouth organ, harp, mouth harp
+harp
+harvester, reaper
+hatchet
+holster
+home theater, home theatre
+honeycomb
+hook, claw
+hoopskirt, crinoline
+horizontal bar, high bar
+horse cart, horse-cart
+hourglass
+iPod
+iron, smoothing iron
+jack-o'-lantern
+jean, blue jean, denim
+jeep, landrover
+jersey, T-shirt, tee shirt
+jigsaw puzzle
+jinrikisha, ricksha, rickshaw
+joystick
+kimono
+knee pad
+knot
+lab coat, laboratory coat
+ladle
+lampshade, lamp shade
+laptop, laptop computer
+lawn mower, mower
+lens cap, lens cover
+letter opener, paper knife, paperknife
+library
+lifeboat
+lighter, light, igniter, ignitor
+limousine, limo
+liner, ocean liner
+lipstick, lip rouge
+Loafer
+lotion
+loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
+loupe, jeweler's loupe
+lumbermill, sawmill
+magnetic compass
+mailbag, postbag
+mailbox, letter box
+maillot
+maillot, tank suit
+manhole cover
+maraca
+marimba, xylophone
+mask
+matchstick
+maypole
+maze, labyrinth
+measuring cup
+medicine chest, medicine cabinet
+megalith, megalithic structure
+microphone, mike
+microwave, microwave oven
+military uniform
+milk can
+minibus
+miniskirt, mini
+minivan
+missile
+mitten
+mixing bowl
+mobile home, manufactured home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter, scooter
+mountain bike, all-terrain bike, off-roader
+mountain tent
+mouse, computer mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook, notebook computer
+obelisk
+oboe, hautboy, hautbois
+ocarina, sweet potato
+odometer, hodometer, mileometer, milometer
+oil filter
+organ, pipe organ
+oscilloscope, scope, cathode-ray oscilloscope, CRO
+overskirt
+oxcart
+oxygen mask
+packet
+paddle, boat paddle
+paddlewheel, paddle wheel
+padlock
+paintbrush
+pajama, pyjama, pj's, jammies
+palace
+panpipe, pandean pipe, syrinx
+paper towel
+parachute, chute
+parallel bars, bars
+park bench
+parking meter
+passenger car, coach, carriage
+patio, terrace
+pay-phone, pay-station
+pedestal, plinth, footstall
+pencil box, pencil case
+pencil sharpener
+perfume, essence
+Petri dish
+photocopier
+pick, plectrum, plectron
+pickelhaube
+picket fence, paling
+pickup, pickup truck
+pier
+piggy bank, penny bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate, pirate ship
+pitcher, ewer
+plane, carpenter's plane, woodworking plane
+planetarium
+plastic bag
+plate rack
+plow, plough
+plunger, plumber's helper
+Polaroid camera, Polaroid Land camera
+pole
+police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
+poncho
+pool table, billiard table, snooker table
+pop bottle, soda bottle
+pot, flowerpot
+potter's wheel
+power drill
+prayer rug, prayer mat
+printer
+prison, prison house
+projectile, missile
+projector
+puck, hockey puck
+punching bag, punch bag, punching ball, punchball
+purse
+quill, quill pen
+quilt, comforter, comfort, puff
+racer, race car, racing car
+racket, racquet
+radiator
+radio, wireless
+radio telescope, radio reflector
+rain barrel
+recreational vehicle, RV, R.V.
+reel
+reflex camera
+refrigerator, icebox
+remote control, remote
+restaurant, eating house, eating place, eatery
+revolver, six-gun, six-shooter
+rifle
+rocking chair, rocker
+rotisserie
+rubber eraser, rubber, pencil eraser
+rugby ball
+rule, ruler
+running shoe
+safe
+safety pin
+saltshaker, salt shaker
+sandal
+sarong
+sax, saxophone
+scabbard
+scale, weighing machine
+school bus
+schooner
+scoreboard
+screen, CRT screen
+screw
+screwdriver
+seat belt, seatbelt
+sewing machine
+shield, buckler
+shoe shop, shoe-shop, shoe store
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule, slipstick
+sliding door
+slot, one-armed bandit
+snorkel
+snowmobile
+snowplow, snowplough
+soap dispenser
+soccer ball
+sock
+solar dish, solar collector, solar furnace
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web, spider's web
+spindle
+sports car, sport car
+spotlight, spot
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch, stop watch
+stove
+strainer
+streetcar, tram, tramcar, trolley, trolley car
+stretcher
+studio couch, day bed
+stupa, tope
+submarine, pigboat, sub, U-boat
+suit, suit of clothes
+sundial
+sunglass
+sunglasses, dark glasses, shades
+sunscreen, sunblock, sun blocker
+suspension bridge
+swab, swob, mop
+sweatshirt
+swimming trunks, bathing trunks
+swing
+switch, electric switch, electrical switch
+syringe
+table lamp
+tank, army tank, armored combat vehicle, armoured combat vehicle
+tape player
+teapot
+teddy, teddy bear
+television, television system
+tennis ball
+thatch, thatched roof
+theater curtain, theatre curtain
+thimble
+thresher, thrasher, threshing machine
+throne
+tile roof
+toaster
+tobacco shop, tobacconist shop, tobacconist
+toilet seat
+torch
+totem pole
+tow truck, tow car, wrecker
+toyshop
+tractor
+trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
+tray
+trench coat
+tricycle, trike, velocipede
+trimaran
+tripod
+triumphal arch
+trolleybus, trolley coach, trackless trolley
+trombone
+tub, vat
+turnstile
+typewriter keyboard
+umbrella
+unicycle, monocycle
+upright, upright piano
+vacuum, vacuum cleaner
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin, fiddle
+volleyball
+waffle iron
+wall clock
+wallet, billfold, notecase, pocketbook
+wardrobe, closet, press
+warplane, military plane
+washbasin, handbasin, washbowl, lavabo, wash-hand basin
+washer, automatic washer, washing machine
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool, woolen, woollen
+worm fence, snake fence, snake-rail fence, Virginia fence
+wreck
+yawl
+yurt
+web site, website, internet site, site
+comic book
+crossword puzzle, crossword
+street sign
+traffic light, traffic signal, stoplight
+book jacket, dust cover, dust jacket, dust wrapper
+menu
+plate
+guacamole
+consomme
+hot pot, hotpot
+trifle
+ice cream, icecream
+ice lolly, lolly, lollipop, popsicle
+French loaf
+bagel, beigel
+pretzel
+cheeseburger
+hotdog, hot dog, red hot
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini, courgette
+spaghetti squash
+acorn squash
+butternut squash
+cucumber, cuke
+artichoke, globe artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple, ananas
+banana
+jackfruit, jak, jack
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce, chocolate syrup
+dough
+meat loaf, meatloaf
+pizza, pizza pie
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff, drop, drop-off
+coral reef
+geyser
+lakeside, lakeshore
+promontory, headland, head, foreland
+sandbar, sand bar
+seashore, coast, seacoast, sea-coast
+valley, vale
+volcano
+ballplayer, baseball player
+groom, bridegroom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
+corn
+acorn
+hip, rose hip, rosehip
+buckeye, horse chestnut, conker
+coral fungus
+agaric
+gyromitra
+stinkhorn, carrion fungus
+earthstar
+hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
+bolete
+ear, spike, capitulum
+toilet tissue, toilet paper, bathroom tissue
\ No newline at end of file
diff --git a/adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth b/adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth
new file mode 100644
index 0000000000000000000000000000000000000000..aae5780851125baf1a30834c3a715d3866858a4d
--- /dev/null
+++ b/adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21dd590f3ccdc646f0d53120778b296013b096a035a2718c9cb0d511bff0f1e0
+size 3714759
diff --git a/adv_grpo/assets/simple_animals.txt b/adv_grpo/assets/simple_animals.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bc9e1176a2eb831541d98dcd810b674a36602a78
--- /dev/null
+++ b/adv_grpo/assets/simple_animals.txt
@@ -0,0 +1,45 @@
+cat
+dog
+horse
+monkey
+rabbit
+zebra
+spider
+bird
+sheep
+deer
+cow
+goat
+lion
+tiger
+bear
+raccoon
+fox
+wolf
+lizard
+beetle
+ant
+butterfly
+fish
+shark
+whale
+dolphin
+squirrel
+mouse
+rat
+snake
+turtle
+frog
+chicken
+duck
+goose
+bee
+pig
+turkey
+fly
+llama
+camel
+bat
+gorilla
+hedgehog
+kangaroo
diff --git a/adv_grpo/assets/simple_ocr_animals.txt b/adv_grpo/assets/simple_ocr_animals.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fd39766d1798ae1968138f5bcf3dac66b6705d8b
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals.txt
@@ -0,0 +1,5 @@
+cat
+dog
+horse
+monkey
+rabbit
\ No newline at end of file
diff --git a/adv_grpo/assets/simple_ocr_animals_digit1.txt b/adv_grpo/assets/simple_ocr_animals_digit1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..017892c8d121fb6fb70aa8ce9f5f010e30f60052
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals_digit1.txt
@@ -0,0 +1,45 @@
+A cat holding a sign that says '0'
+A dog holding a sign that says '0'
+A horse holding a sign that says '0'
+A monkey holding a sign that says '0'
+A rabbit holding a sign that says '0'
+A cat holding a sign that says '1'
+A dog holding a sign that says '1'
+A horse holding a sign that says '1'
+A monkey holding a sign that says '1'
+A rabbit holding a sign that says '1'
+A cat holding a sign that says '2'
+A dog holding a sign that says '2'
+A horse holding a sign that says '2'
+A monkey holding a sign that says '2'
+A rabbit holding a sign that says '2'
+A cat holding a sign that says '3'
+A dog holding a sign that says '3'
+A horse holding a sign that says '3'
+A monkey holding a sign that says '3'
+A rabbit holding a sign that says '3'
+A cat holding a sign that says '4'
+A dog holding a sign that says '4'
+A horse holding a sign that says '4'
+A monkey holding a sign that says '4'
+A rabbit holding a sign that says '4'
+A cat holding a sign that says '5'
+A dog holding a sign that says '5'
+A horse holding a sign that says '5'
+A monkey holding a sign that says '5'
+A rabbit holding a sign that says '5'
+A cat holding a sign that says '6'
+A dog holding a sign that says '6'
+A horse holding a sign that says '6'
+A monkey holding a sign that says '6'
+A rabbit holding a sign that says '6'
+A cat holding a sign that says '7'
+A dog holding a sign that says '7'
+A horse holding a sign that says '7'
+A monkey holding a sign that says '7'
+A rabbit holding a sign that says '7'
+A cat holding a sign that says '8'
+A dog holding a sign that says '8'
+A horse holding a sign that says '8'
+A monkey holding a sign that says '8'
+A rabbit holding a sign that says '8'
\ No newline at end of file
diff --git a/adv_grpo/assets/simple_ocr_animals_digit3.txt b/adv_grpo/assets/simple_ocr_animals_digit3.txt
new file mode 100644
index 0000000000000000000000000000000000000000..382ca9ce7faca644b367493d92754251896735f0
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals_digit3.txt
@@ -0,0 +1,45 @@
+A cat holding a sign that says '123'
+A dog holding a sign that says '234'
+A horse holding a sign that says '345'
+A monkey holding a sign that says '456'
+A rabbit holding a sign that says '567'
+A cat holding a sign that says '678'
+A dog holding a sign that says '789'
+A horse holding a sign that says '123'
+A monkey holding a sign that says '234'
+A rabbit holding a sign that says '345'
+A cat holding a sign that says '456'
+A dog holding a sign that says '567'
+A horse holding a sign that says '678'
+A monkey holding a sign that says '789'
+A rabbit holding a sign that says '123'
+A cat holding a sign that says '234'
+A dog holding a sign that says '345'
+A horse holding a sign that says '456'
+A monkey holding a sign that says '567'
+A rabbit holding a sign that says '678'
+A cat holding a sign that says '789'
+A dog holding a sign that says '123'
+A horse holding a sign that says '234'
+A monkey holding a sign that says '345'
+A rabbit holding a sign that says '456'
+A cat holding a sign that says '567'
+A dog holding a sign that says '678'
+A horse holding a sign that says '789'
+A monkey holding a sign that says '123'
+A rabbit holding a sign that says '234'
+A cat holding a sign that says '345'
+A dog holding a sign that says '456'
+A horse holding a sign that says '567'
+A monkey holding a sign that says '678'
+A rabbit holding a sign that says '789'
+A cat holding a sign that says '123'
+A dog holding a sign that says '234'
+A horse holding a sign that says '345'
+A monkey holding a sign that says '456'
+A rabbit holding a sign that says '567'
+A cat holding a sign that says '678'
+A dog holding a sign that says '789'
+A horse holding a sign that says '123'
+A monkey holding a sign that says '234'
+A rabbit holding a sign that says '345'
\ No newline at end of file
diff --git a/adv_grpo/assets/simple_ocr_animals_digit5.txt b/adv_grpo/assets/simple_ocr_animals_digit5.txt
new file mode 100644
index 0000000000000000000000000000000000000000..caa91b383e3d1ac2a2bd02e1c2dd27d42373ff7d
--- /dev/null
+++ b/adv_grpo/assets/simple_ocr_animals_digit5.txt
@@ -0,0 +1,50 @@
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
\ No newline at end of file
diff --git a/adv_grpo/assets/test.jpg b/adv_grpo/assets/test.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5fa2a71e3786313e40d76128136eef4afab05ba3
Binary files /dev/null and b/adv_grpo/assets/test.jpg differ
diff --git a/adv_grpo/clip_scorer.py b/adv_grpo/clip_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eea6458d476d7726e177f27b6af2e0154cd0239
--- /dev/null
+++ b/adv_grpo/clip_scorer.py
@@ -0,0 +1,97 @@
+# Based on https://github.com/RE-N-Y/imscore/blob/main/src/imscore/preference/model.py
+
+from importlib import resources
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as T
+from transformers import AutoImageProcessor, CLIPProcessor, CLIPModel
+import numpy as np
+from PIL import Image
+
+def get_size(size):
+    if isinstance(size, int):
+        return (size, size)
+    elif "height" in size and "width" in size:
+        return (size["height"], size["width"])
+    elif "shortest_edge" in size:
+        return size["shortest_edge"]
+    else:
+        raise ValueError(f"Invalid size: {size}")
+
+def get_image_transform(processor: AutoImageProcessor):
+    config = processor.to_dict()
+    resize = T.Resize(get_size(config.get("size"))) if config.get("do_resize") else nn.Identity()
+    crop = T.CenterCrop(get_size(config.get("crop_size"))) if config.get("do_center_crop") else nn.Identity()
+    normalise = T.Normalize(mean=processor.image_mean, std=processor.image_std) if config.get("do_normalize") else nn.Identity()
+
+    return T.Compose([resize, crop, normalise])
+
+class ClipScorer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # self.device="cuda"
+        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        self.tform = get_image_transform(self.processor.image_processor)
+        self.eval()
+
+    def _process(self, pixels):
+        dtype = pixels.dtype
+        pixels = self.tform(pixels)
+        pixels = pixels.to(dtype=dtype)
+
+        return pixels
+
+    @torch.no_grad()
+    def __call__(self, pixels, prompts, return_img_embedding=False):
+        device = next(self.parameters()).device
+        texts = self.processor(text=prompts, padding='max_length', truncation=True, return_tensors="pt").to(device)
+        pixels = self._process(pixels).to(device)
+        outputs = self.model(pixel_values=pixels, **texts)
+        if return_img_embedding:
+            return outputs.logits_per_image.diagonal() / 30, outputs.image_embeds
+        return outputs.logits_per_image.diagonal() / 30
+
+    @torch.no_grad()
+    def image_similarity(self, pixels, ref_pixels):
+        device = next(self.parameters()).device
+        pixels = self._process(pixels).to(device)
+        ref_pixels = self._process(ref_pixels).to(device)
+
+        pixel_embeds = self.model.get_image_features(pixel_values=pixels)
+        ref_embeds = self.model.get_image_features(pixel_values=ref_pixels)
+
+        pixel_embeds = pixel_embeds / pixel_embeds.norm(p=2, dim=-1, keepdim=True)
+        ref_embeds = ref_embeds / ref_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        sim = pixel_embeds @ ref_embeds.T
+        # sim = torch.diagonal(sim, 0)
+        sim = sim.squeeze(-1)
+        return sim
+
+
+def main():
+    # scorer = ClipScorer(
+    #     device='cuda'
+    # )
+    scorer = ClipScorer()
+
+    images = [
+        "assets/test.jpg",
+        "assets/test.jpg"
+    ]
+    pil_images = [Image.open(img) for img in images]
+    prompts = [
+        'an image of cat',
+        'not an image of cat'
+    ]
+    images = [np.array(img) for img in pil_images]
+    images = np.array(images)
+    images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
+    images = torch.tensor(images, dtype=torch.uint8) / 255.0
+    print(scorer(images, prompts))
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/adv_grpo/conv_gradfix.py b/adv_grpo/conv_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..8989fdb3aa251b24fa2df457306dc202d913d723
--- /dev/null
+++ b/adv_grpo/conv_gradfix.py
@@ -0,0 +1,345 @@
+"""
+Custom replacement for `torch.nn.functional.convNd` and `torch.nn.functional.conv_transposeNd`
+that supports arbitrarily high order gradients with zero performance penalty.
+Modified from https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py
+"""
+
+import contextlib
+import warnings
+from typing import Optional
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Conv2d, Conv3d
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+# ----------------------------------------------------------------------------
+
+enabled = False  # Enable the custom op by setting this to true.
+weight_gradients_disabled = (
+    False  # Forcefully disable computation of gradients with respect to the weights.
+)
+
+
+@contextlib.contextmanager
+def no_weight_gradients():
+    global weight_gradients_disabled
+    old = weight_gradients_disabled
+    weight_gradients_disabled = True
+    yield
+    weight_gradients_disabled = old
+
+
+# ----------------------------------------------------------------------------
+class GradFixConv2d(Conv2d):
+    def __init__(self, *args, use_gradfix: bool = False, **kwargs):
+        self.use_gradfix = use_gradfix
+        super().__init__(*args, **kwargs)
+
+    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
+        conv_fn = F.conv2d if not self.use_gradfix else convNd
+        if self.padding_mode != "zeros":
+            return conv_fn(
+                F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                weight,
+                bias,
+                self.stride,
+                (0, 0),
+                self.dilation,
+                self.groups,
+            )
+        return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+
+    def forward(
+        self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
+    ) -> Tensor:
+        weight = self.weight if weight is None else weight
+        bias = self.bias if bias is None else bias
+        return self._conv_forward(input, weight, bias)
+
+
+class GradFixConv3d(Conv3d):
+    def __init__(self, *args, use_gradfix: bool = False, **kwargs):
+        self.use_gradfix = use_gradfix
+        super().__init__(*args, **kwargs)
+
+    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
+        conv_fn = F.conv3d if not self.use_gradfix else convNd
+        if self.padding_mode != "zeros":
+            return conv_fn(
+                F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                weight,
+                bias,
+                self.stride,
+                (0, 0, 0),
+                self.dilation,
+                self.groups,
+            )
+        return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+
+    def forward(
+        self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
+    ) -> Tensor:
+        weight = self.weight if weight is None else weight
+        bias = self.bias if bias is None else bias
+        return self._conv_forward(input, weight, bias)
+
+
+# ----------------------------------------------------------------------------
+
+
+def convNd(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+    N = weight.ndim - 2
+    if _should_use_custom_op(input):
+        return _conv_gradfix(
+            transpose=False,
+            weight_shape=weight.shape,
+            stride=stride,
+            padding=padding,
+            output_padding=0,
+            dilation=dilation,
+            groups=groups,
+        ).apply(input, weight, bias)
+    return getattr(torch.nn.functional, f"conv{N}d")(
+        input=input,
+        weight=weight,
+        bias=bias,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+    )
+
+
+def conv_transposeNd(
+    input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1
+):
+    N = weight.ndim - 2
+    if _should_use_custom_op(input):
+        return _conv_gradfix(
+            transpose=True,
+            weight_shape=weight.shape,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ).apply(input, weight, bias)
+    return getattr(torch.nn.functional, f"conv_transpose{N}d")(
+        input=input,
+        weight=weight,
+        bias=bias,
+        stride=stride,
+        padding=padding,
+        output_padding=output_padding,
+        groups=groups,
+        dilation=dilation,
+    )
+
+
+# ----------------------------------------------------------------------------
+
+
+def _should_use_custom_op(input):
+    assert isinstance(input, torch.Tensor)
+    if (not enabled) or (not torch.backends.cudnn.enabled):
+        return False
+    if input.device.type != "cuda":
+        return False
+    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9"]):
+        return True
+    if torch.__version__.startswith("2"):
+        return True
+    warnings.warn(
+        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. "
+        f"Falling back to torch.nn.functional.conv2d()."
+    )
+    return False
+
+
+def _tuple_of_ints(xs, ndim):
+    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+    assert len(xs) == ndim
+    assert all(isinstance(x, int) for x in xs)
+    return xs
+
+
+# ----------------------------------------------------------------------------
+
+_conv_gradfix_cache = dict()
+
+
+def _conv_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+    ndim = len(weight_shape) - 2
+    # Parse arguments.
+    weight_shape = tuple(weight_shape)
+    stride = _tuple_of_ints(stride, ndim)
+    padding = _tuple_of_ints(padding, ndim)
+    output_padding = _tuple_of_ints(output_padding, ndim)
+    dilation = _tuple_of_ints(dilation, ndim)
+
+    # Lookup from cache.
+    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+    if key in _conv_gradfix_cache:
+        return _conv_gradfix_cache[key]
+
+    # Validate arguments.
+    assert groups >= 1
+    assert all(stride[i] >= 1 for i in range(ndim))
+    assert all(padding[i] >= 0 for i in range(ndim))
+    assert all(dilation[i] >= 0 for i in range(ndim))
+    if not transpose:
+        assert all(output_padding[i] == 0 for i in range(ndim))
+    else:  # transpose
+        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+    # Helpers.
+    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+    def calc_output_padding(input_shape, output_shape):
+        if transpose:
+            return [
+                0,
+            ] * ndim
+        return [
+            input_shape[i + 2]
+            - (output_shape[i + 2] - 1) * stride[i]
+            - (1 - 2 * padding[i])
+            - dilation[i] * (weight_shape[i + 2] - 1)
+            for i in range(ndim)
+        ]
+
+    # Forward & backward.
+    class ConvNd(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, weight, bias):
+            """
+            input size: [B, C, ...]
+            weight size:
+                -> Conv: [C_out, C_in // groups, ...]
+                -> Transpose: [C_in, C_out // groups, ...]
+            """
+            assert weight.shape == weight_shape
+            ctx.save_for_backward(input, weight)
+
+            # General case => cuDNN.
+            if transpose:
+                return getattr(torch.nn.functional, f"conv_transpose{ndim}d")(
+                    input=input,
+                    weight=weight.to(input.dtype),
+                    bias=bias,
+                    output_padding=output_padding,
+                    **common_kwargs,
+                )
+            return getattr(torch.nn.functional, f"conv{ndim}d")(
+                input=input, weight=weight.to(input.dtype), bias=bias, **common_kwargs
+            )
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            input, weight = ctx.saved_tensors
+            grad_input = None
+            grad_weight = None
+            grad_bias = None
+
+            if ctx.needs_input_grad[0]:  # Input
+                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+                op = _conv_gradfix(
+                    transpose=(not transpose),
+                    weight_shape=weight_shape,
+                    output_padding=p,
+                    **common_kwargs,
+                )
+                grad_input = op.apply(grad_output, weight, None)
+                assert grad_input.shape == input.shape
+
+            if ctx.needs_input_grad[1] and not weight_gradients_disabled:  # Weight
+                grad_weight = ConvNdGradWeight.apply(grad_output, input)
+                assert grad_weight.shape == weight_shape
+
+            if ctx.needs_input_grad[2]:  # Bias
+                grad_bias = grad_output.transpose(0, 1).flatten(1).sum(1)
+
+            return grad_input, grad_weight, grad_bias
+
+    # Gradient with respect to the weights.
+    class ConvNdGradWeight(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, grad_output, input):
+            flags = [
+                torch.backends.cudnn.benchmark,
+                torch.backends.cudnn.deterministic,
+                torch.backends.cudnn.allow_tf32,
+            ]
+            if torch.__version__.startswith("1"):
+                op = torch._C._jit_get_operation(
+                    "aten::cudnn_convolution_backward_weight"
+                    if not transpose
+                    else "aten::cudnn_convolution_transpose_backward_weight"
+                )
+                grad_weight = op(
+                    weight_shape,
+                    grad_output,
+                    input.to(grad_output.dtype),
+                    padding,
+                    stride,
+                    dilation,
+                    groups,
+                    *flags,
+                )
+            elif torch.__version__.startswith("2"):
+                # https://github.com/pytorch/pytorch/issues/74437
+                op, _ = torch._C._jit_get_operation("aten::convolution_backward")
+                dummy_weight = torch.tensor(
+                    0.0, dtype=grad_output.dtype, device=input.device
+                ).expand(weight_shape)
+                grad_weight = op(
+                    grad_output,
+                    input.to(grad_output.dtype),
+                    dummy_weight,
+                    None,
+                    stride,
+                    padding,
+                    dilation,
+                    transpose,
+                    (0,) * ndim,
+                    groups,
+                    [False, True, False],
+                )[1]
+            else:
+                raise NotImplementedError
+            assert grad_weight.shape == weight_shape
+            ctx.save_for_backward(grad_output, input)
+            return grad_weight
+
+        @staticmethod
+        def backward(ctx, grad2_grad_weight):
+            grad_output, input = ctx.saved_tensors
+            grad2_grad_output = None
+            grad2_input = None
+
+            if ctx.needs_input_grad[0]:  # Grad of Weight
+                grad2_grad_output = ConvNd.apply(input, grad2_grad_weight, None)
+                assert grad2_grad_output.shape == grad_output.shape
+
+            if ctx.needs_input_grad[1]:  # Input
+                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+                op = _conv_gradfix(
+                    transpose=(not transpose),
+                    weight_shape=weight_shape,
+                    output_padding=p,
+                    **common_kwargs,
+                )
+                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+                assert grad2_input.shape == input.shape
+
+            return grad2_grad_output, grad2_input
+
+    _conv_gradfix_cache[key] = ConvNd
+    return ConvNd
+
+
+# ----------------------------------------------------------------------------
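Note: a minimal sketch (not part of the patch) of how the gradfix machinery above can be exercised. It assumes a CUDA device, since `_should_use_custom_op` falls back to the standard ops on CPU; the tensor shapes are arbitrary.

```python
import torch
import adv_grpo.conv_gradfix as conv_gradfix
from adv_grpo.conv_gradfix import GradFixConv2d, no_weight_gradients

# The custom op only activates when this module-level flag is set
# and the input lives on a CUDA device.
conv_gradfix.enabled = True

conv = GradFixConv2d(3, 8, kernel_size=3, padding=1, use_gradfix=True).cuda()
x = torch.randn(2, 3, 64, 64, device="cuda", requires_grad=True)

# Double backprop (e.g., R1-style gradient penalties) works because the
# backward pass is itself built from differentiable autograd.Functions.
y = conv(x).sum()
(grad_x,) = torch.autograd.grad(y, x, create_graph=True)
grad_x.pow(2).sum().backward()

# Skip weight gradients entirely, e.g., when backpropagating a generator
# loss through a frozen discriminator.
with no_weight_gradients():
    out = conv(x)
```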
diff --git a/adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc b/adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abfe709ca3b831dee3e248d3f3b1e6dc27737966
Binary files /dev/null and b/adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc differ
diff --git a/adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc b/adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9321923338fe113f51150fe42198a8b84bb851d9
Binary files /dev/null and b/adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc differ
diff --git a/adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc b/adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fee662b9067583e744cdbceda8c7b329ee4de16
Binary files /dev/null and b/adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc differ
diff --git a/adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py b/adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f1d57175510c73d118f2becdb024d03139ebf9
--- /dev/null
+++ b/adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py
@@ -0,0 +1,255 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+
+from typing import Any, Dict, List, Optional, Union, Callable
+import torch
+import numpy as np
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.utils import logging
+from .sd3_sde_with_logprob import sde_step_with_logprob
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+    (672, 1568),
+    (688, 1504),
+    (720, 1456),
+    (752, 1392),
+    (800, 1328),
+    (832, 1248),
+    (880, 1184),
+    (944, 1104),
+    (1024, 1024),
+    (1104, 944),
+    (1184, 880),
+    (1248, 832),
+    (1328, 800),
+    (1392, 752),
+    (1456, 720),
+    (1504, 688),
+    (1568, 672),
+]
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+@torch.no_grad()
+def pipeline_with_logprob(
+    self,
+    image: Optional[PipelineImageInput] = None,
+    prompt: Union[str, List[str]] = None,
+    prompt_2: Optional[Union[str, List[str]]] = None,
+    negative_prompt: Union[str, List[str]] = None,
+    negative_prompt_2: Optional[Union[str, List[str]]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 28,
+    sigmas: Optional[List[float]] = None,
+    guidance_scale: float = 3.5,
+    num_images_per_prompt: Optional[int] = 1,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+    max_sequence_length: int = 512,
+    max_area: int = 1024**2,
+    _auto_resize: bool = True,
+    noise_level: float = 0.7,
+):
+    height = height or self.default_sample_size * self.vae_scale_factor
+    width = width or self.default_sample_size * self.vae_scale_factor
+
+    original_height, original_width = height, width
+    aspect_ratio = width / height
+    width = round((max_area * aspect_ratio) ** 0.5)
+    height = round((max_area / aspect_ratio) ** 0.5)
+
+    multiple_of = self.vae_scale_factor * 2
+    width = width // multiple_of * multiple_of
+    height = height // multiple_of * multiple_of
+
+    if height != original_height or width != original_width:
+        logger.warning(
+            f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+        )
+
+    # 1. Check inputs. Raise error if not correct
+    self.check_inputs(
+        prompt,
+        prompt_2,
+        height,
+        width,
+        negative_prompt=negative_prompt,
+        negative_prompt_2=negative_prompt_2,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+        max_sequence_length=max_sequence_length,
+    )
+
+    self._guidance_scale = guidance_scale
+    self._joint_attention_kwargs = joint_attention_kwargs
+    self._current_timestep = None
+    self._interrupt = False
+
+    # 2. Define call parameters
+    if prompt is not None and isinstance(prompt, str):
+        batch_size = 1
+    elif prompt is not None and isinstance(prompt, list):
+        batch_size = len(prompt)
+    else:
+        batch_size = prompt_embeds.shape[0]
+
+    device = self._execution_device
+
+    lora_scale = (
+        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+    )
+    (
+        prompt_embeds,
+        pooled_prompt_embeds,
+        text_ids,
+    ) = self.encode_prompt(
+        prompt=prompt,
+        prompt_2=prompt_2,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        max_sequence_length=max_sequence_length,
+        lora_scale=lora_scale,
+    )
+
+    # 3. Preprocess image
+    if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+        image = self.image_processor.resize(image, height, width)
+        image = self.image_processor.preprocess(image, height, width)
+ num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents, latent_ids, image_ids = self.prepare_latents( + image.float(), + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if image_ids is not None: + latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Prepare image embeddings + all_latents = [latents] + all_log_probs = [] + + # 7. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + if noise_pred.isnan().any(): + logger.warning("noise_pred contains NaN values; check input dtypes and scales") + noise_pred = noise_pred[:, : latents.size(1)] + latents_dtype = latents.dtype + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), + t.unsqueeze(0).repeat(latents.shape[0]), + latents.float(), + noise_level=noise_level, + ) + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + all_latents.append(latents) + all_log_probs.append(log_prob) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + return image, all_latents, latent_ids, text_ids, all_log_probs, image_latents diff --git a/adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py b/adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py new file mode 100644 index
0000000000000000000000000000000000000000..530ac1b73d206a542dd8c7875e6d027fb631eaff --- /dev/null +++ b/adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py @@ -0,0 +1,187 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py + +from typing import Any, Dict, List, Optional, Union, Callable +import torch +import numpy as np +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from .sd3_sde_with_logprob import sde_step_with_logprob + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + +@torch.no_grad() +def pipeline_with_logprob( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + noise_level: float = 0.7, +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Prepare image embeddings + all_latents = [latents] + all_log_probs = [] + + # 7. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + latents_dtype = latents.dtype + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), + t.unsqueeze(0).repeat(latents.shape[0]), + latents.float(), + noise_level=noise_level, + ) + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + all_latents.append(latents) + all_log_probs.append(log_prob) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + return image, all_latents, latent_image_ids, text_ids, all_log_probs diff --git a/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py new file mode 100644 index 0000000000000000000000000000000000000000..2a28e04b835302dbdba57c24d9aad8670e1613f6 --- /dev/null +++ b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py @@ -0,0 +1,198 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +# with the following modifications: 
+# - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`. +# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step. +from typing import Any, Dict, List, Optional, Union +import torch +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob + +@torch.no_grad() +def pipeline_with_logprob( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + skip_layer_guidance_scale: float = 2.8, + noise_level: float = 0.7, +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._skip_layer_guidance_scale = skip_layer_guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
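Define call parameters + # When classifier-free guidance is active, the negative and positive prompt embeddings are + # concatenated below so that both branches share a single transformer forward pass.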
+ if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + scheduler_kwargs = {} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare image embeddings + all_latents = [latents] + all_log_probs = [] + + # 7.
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents_dtype = latents.dtype + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), + t.unsqueeze(0), + latents.float(), + noise_level=noise_level, + ) + + all_latents.append(latents) + all_log_probs.append(log_prob) + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + return image, all_latents, all_log_probs diff --git a/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..ca98c131dd7bf27eeb74b0343c8c7da5d0c155bb --- /dev/null +++ b/adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py @@ -0,0 +1,1081 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +# with the following modifications: +# - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`. +# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
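+# - It samples SDE noise only inside a short training window: a start step is drawn at random and noise is injected for `train_num_steps` steps (see pipeline_with_logprob below). +# +# Usage sketch (illustrative assumption, not part of the upstream diffusers API): given a loaded +# `StableDiffusion3Pipeline` named `pipe`, the function is called unbound with the pipeline as `self`: +# +# image, all_latents, all_log_probs, all_timesteps = pipeline_with_logprob( +# pipe, +# prompt=["a photo of a corgi"], +# num_inference_steps=10, +# sample_num_steps=10, # window start is drawn from [0, sample_num_steps // 2] +# mini_num_image_per_prompt=4, # group size replicated at the window start +# train_num_steps=1, # number of SDE steps recorded for training +# noise_level=0.7, +# )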
+from typing import Any, Dict, List, Optional, Union +import torch +import random +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob +from PIL import Image +from torchvision import transforms + + + +@torch.no_grad() +def pipeline_with_logprob( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + mini_num_image_per_prompt: int = 1, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + skip_layer_guidance_scale: float = 2.8, + noise_level: float = 0.7, + train_num_steps: int = 1, + process_index: int = 0, + sample_num_steps: int = 10, + random_timestep: Optional[int] = None, +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._skip_layer_guidance_scale = skip_layer_guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ).float() + + # 5. Prepare timesteps + scheduler_kwargs = {} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + random.seed(process_index) + if random_timestep is None: + random_timestep = random.randint(0, sample_num_steps // 2) + + # 6. Prepare image embeddings + all_latents = [] + all_log_probs = [] + all_timesteps = [] + + if self.do_classifier_free_guidance: + tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + # 7.
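Denoising loop + # Truncated-window SDE sampling: steps before `random_timestep` run deterministically (noise level 0); + # at i == random_timestep the latents and prompt embeddings are replicated `mini_num_image_per_prompt` + # times so the whole group shares one prefix trajectory; SDE noise is then active for `train_num_steps` + # steps (the only steps whose latents, log probs and timesteps are recorded), after which sampling is + # deterministic again.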
+ with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if i < random_timestep: + cur_noise_level = 0 + elif i == random_timestep: + cur_noise_level = noise_level + # replicate latents and embeddings mini_num_image_per_prompt times + latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1) + prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1) + if self.do_classifier_free_guidance: + tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + all_latents.append(latents) + elif i > random_timestep and i < random_timestep + train_num_steps: + cur_noise_level = noise_level + else: + cur_noise_level = 0 + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=tem_prompt_embeds, + pooled_projections=tem_pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents_dtype = latents.dtype + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), + t.unsqueeze(0), + latents.float(), + noise_level=cur_noise_level, + ) + + if i >= random_timestep and i < random_timestep + train_num_steps: + all_latents.append(latents) + all_log_probs.append(log_prob) + all_timesteps.append(t.repeat(len(latents))) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + return image, all_latents, all_log_probs, all_timesteps + + + +@torch.no_grad() +def pipeline_with_logprob_new( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str,
List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + mini_num_image_per_prompt: int = 1, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + skip_layer_guidance_scale: float = 2.8, + noise_level: float = 0.7, + train_num_steps: int = 1, + process_index: int = 0, + sample_num_steps: int = 10, + random_timestep: Optional[int] = None, +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._skip_layer_guidance_scale = skip_layer_guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.
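Prepare latent variables + # Unlike pipeline_with_logprob above, the latents are kept in the prompt-embedding dtype here + # (no .float() upcast); sde_step_with_logprob upcasts to fp32 internally.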
+ num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + scheduler_kwargs = {} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + random.seed(process_index) + if random_timestep is None: + random_timestep = random.randint(0, sample_num_steps // 2) + + # 6. Prepare image embeddings + all_latents = [] + all_log_probs = [] + all_timesteps = [] + + if self.do_classifier_free_guidance: + tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if i < random_timestep: + cur_noise_level = 0 + elif i == random_timestep: + cur_noise_level = noise_level + # replicate latents and embeddings mini_num_image_per_prompt times + latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1) + prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1) + if self.do_classifier_free_guidance: + tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + all_latents.append(latents) + elif i > random_timestep and i < random_timestep + train_num_steps: + cur_noise_level = noise_level + else: + cur_noise_level = 0 + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=tem_prompt_embeds, + pooled_projections=tem_pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents_dtype = latents.dtype + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), +
t.unsqueeze(0), + latents.float(), + noise_level=cur_noise_level, + ) + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if i >= random_timestep and i < random_timestep + train_num_steps: + all_latents.append(latents) + all_log_probs.append(log_prob) + all_timesteps.append(t.repeat(len(latents))) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + return image, all_latents, all_log_probs, all_timesteps + + + + +@torch.no_grad() +def pipeline_with_logprob_random( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + mini_num_image_per_prompt: int = 1, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + skip_layer_guidance_scale: float = 2.8, + noise_level: float = 0.7, + train_num_steps: int = 1, + process_index: int = 0, + sample_num_steps: int = 10, + random_timestep: Optional[int] = None, +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._skip_layer_guidance_scale = skip_layer_guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2.
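Define call parameters + # In this "random" variant the prompt embeddings are replicated `mini_num_image_per_prompt` times + # up front, so each prompt draws that many independent initial latents instead of sharing a common + # prefix trajectory as in pipeline_with_logprob_new above.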
+ if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + prompt_embeds.shape[0], + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + scheduler_kwargs = {} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + random.seed(process_index) + if random_timestep is None: + random_timestep = random.randint(0, sample_num_steps // 2) + + # 6. Prepare image embeddings + all_latents = [] + all_log_probs = [] + all_timesteps = [] + if self.do_classifier_free_guidance: + tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + # 7.
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if i < random_timestep: + cur_noise_level = 0 + elif i == random_timestep: + cur_noise_level = noise_level + # latents/embeddings are not replicated here: the embeddings were already repeated before latent preparation + all_latents.append(latents) + elif i > random_timestep and i < random_timestep + train_num_steps: + cur_noise_level = noise_level + else: + cur_noise_level = 0 + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=tem_prompt_embeds, + pooled_projections=tem_pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents_dtype = latents.dtype + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), + t.unsqueeze(0), + latents.float(), + noise_level=cur_noise_level, + ) + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if i >= random_timestep and i < random_timestep + train_num_steps: + all_latents.append(latents) + all_log_probs.append(log_prob) + all_timesteps.append(t.repeat(len(latents))) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + return image, all_latents, all_log_probs, all_timesteps + + + +def move_scheduler_to_device(scheduler, device="cuda"): + # move every tensor attribute of the scheduler (e.g. sigmas, timesteps) to the target device + for attr_name in dir(scheduler): + attr = getattr(scheduler, attr_name) + if isinstance(attr, torch.Tensor): + setattr(scheduler, attr_name, attr.to(device)) + return scheduler + + +def image_to_latent(pipe, images:
Union[Image.Image, List[Image.Image]], device="cuda"): + # normalize the input to a list + if isinstance(images, Image.Image): + images = [images] + + preprocess = transforms.Compose([ + transforms.Resize((512, 512)), + transforms.ToTensor(), # to [0, 1] + transforms.Normalize([0.5], [0.5]) # map to [-1, 1] + ]) + + # batch the preprocessing + img_tensors = [preprocess(img) for img in images] # list of [3,512,512] + img_tensor = torch.stack(img_tensors, dim=0).to(device, dtype=torch.float32) # [B,3,512,512] + + # encode through the VAE + latent = pipe.vae.encode(img_tensor).latent_dist.sample() + latent = latent * pipe.vae.config.scaling_factor + return latent.to(torch.bfloat16) # [B,4,64,64] (assuming a 512 input, downsampled 8x) + + +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + # look up the sigma matching each requested timestep and reshape it for broadcasting + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + + +@torch.no_grad() +def flux_to_sd3_denoise( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + flux_images=None, + device="cuda", + output_type: Optional[str] = "pil", + num_inference_steps: int = 20, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 256, + noise_level: float = 0.7, + random_timestep: Optional[int] = None, + noise_timestep_ratio: float = 0.4, + clip_skip: Optional[int] = None, +): + """ + Take Flux-generated images -> convert them to latents -> add noise -> denoise over multiple steps with SD3. + The outputs match pipeline_with_logprob: image, all_latents, all_log_probs, all_timesteps + """ + # 1. Convert the images to latents + flux_latent = image_to_latent(self, flux_images, device) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + + # 2.
Prepare the scheduler + noise_scheduler = self.scheduler + noise_scheduler.set_timesteps(num_inference_steps) + timesteps = noise_scheduler.timesteps.to(device) + + # map the noising ratio to a discrete timestep index (indexing with the raw float would fail) + target_idx = torch.tensor([int(noise_timestep_ratio * (len(timesteps) - 1))], device=device) + t = timesteps[target_idx].to(device) + + noise = torch.randn_like(flux_latent) + sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype) + latents = (1.0 - sigmas) * flux_latent + sigmas * noise + num_channels_latents = self.transformer.config.in_channels + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 4. Encode prompts (aligned with the handling in pipeline_with_logprob) + lora_scale = None + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1) + negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1) + + if self.do_classifier_free_guidance: + tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + else: + tem_prompt_embeds = prompt_embeds + tem_pooled_prompt_embeds = pooled_prompt_embeds + + # 5.
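Denoise from the current t onward + # The Flux latent was renoised above with the flow-matching interpolation x_t = (1 - sigma_t) * x_0 + sigma_t * eps; + # denoising resumes at the closest scheduler timestep >= t, SDE noise is active for the first two steps, + # and only those two steps' latents and log probs are recorded for training.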
+ noise_scheduler.set_timesteps(num_inference_steps) + timesteps = noise_scheduler.timesteps.to(device) + start_idx = (timesteps >= t[0]).nonzero()[-1].item() + timesteps = timesteps[start_idx:] + + all_latents, all_log_probs, all_timesteps = [], [], [] + noise_scheduler = move_scheduler_to_device(noise_scheduler, device) + + for index, t_cur in enumerate(timesteps): + if index == 0: + all_latents.append(latents) + + if index < 2: + cur_noise_level = noise_level + else: + cur_noise_level = 0.0 + + latent_model_input = ( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + ) + t_input = t_cur.expand(latent_model_input.shape[0]).to(device) + + latents_dtype = latents.dtype + model_pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_input, + encoder_hidden_states=tem_prompt_embeds, + pooled_projections=tem_pooled_prompt_embeds, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = model_pred.chunk(2) + model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + noise_scheduler, + model_pred.float(), + t_cur.repeat(len(latents)), + latents.float(), + noise_level=cur_noise_level, + ) + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if index < 2: + all_latents.append(latents) + all_log_probs.append(log_prob) + all_timesteps.append(t_cur.repeat(len(latents))) + + # 6. Final decode + denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + denoised_latent = denoised_latent.to(dtype=self.vae.dtype) + + image = self.vae.decode(denoised_latent, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + return image, all_latents, all_log_probs, all_timesteps + + + + + +@torch.no_grad() +def flux_to_sd3_denoise_random( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + flux_images=None, + device="cuda", + output_type: Optional[str] = "pil", + num_inference_steps: int = 20, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 256, + noise_level: float = 0.7, + random_timestep: Optional[int] = None, + noise_timestep_ratio: float = 0.4, + clip_skip: Optional[int] = None, +): + """ + Take Flux-generated images -> convert them to latents -> add noise -> denoise over multiple steps with SD3. + The outputs match pipeline_with_logprob: image, all_latents, all_log_probs, all_timesteps + """ + # 1. Convert the images to latents + flux_latent = image_to_latent(self, flux_images, device) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + + # 2.
Prepare the scheduler + noise_scheduler = self.scheduler + noise_scheduler.set_timesteps(num_inference_steps) + timesteps = noise_scheduler.timesteps.to(device) + + # pick a random start index instead of a fixed noise_timestep_ratio + target_idx = torch.tensor([random.randint(5, 10)], device=device) + t = timesteps[target_idx].to(device) + # draw standard Gaussian noise + noise = torch.randn_like(flux_latent) + # look up the matching sigma + sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype) + # renoise the latent + latents = (1.0 - sigmas) * flux_latent + sigmas * noise + + # 4. Encode prompts (aligned with the handling in pipeline_with_logprob) + lora_scale = None + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1) + negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1) + + if self.do_classifier_free_guidance: + tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + else: + tem_prompt_embeds = prompt_embeds + tem_pooled_prompt_embeds = pooled_prompt_embeds + + # 5.
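Denoise from the current t onward + # Same procedure as flux_to_sd3_denoise above, except that the start index was drawn uniformly from + # [5, 10] and the SDE noise level stays at `noise_level` for every step; only the first two steps are recorded.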
+ noise_scheduler.set_timesteps(num_inference_steps) + timesteps = noise_scheduler.timesteps.to(device) + start_idx = (timesteps >= t[0]).nonzero()[-1].item() + timesteps = timesteps[start_idx:] + + all_latents, all_log_probs, all_timesteps = [], [], [] + noise_scheduler = move_scheduler_to_device(noise_scheduler, device) + + for index, t_cur in enumerate(timesteps): + if index == 0: + all_latents.append(latents) + + latent_model_input = ( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + ) + t_input = t_cur.expand(latent_model_input.shape[0]).to(device) + + latents_dtype = latents.dtype + model_pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_input, + encoder_hidden_states=tem_prompt_embeds, + pooled_projections=tem_pooled_prompt_embeds, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = model_pred.chunk(2) + model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + noise_scheduler, + model_pred.float(), + t_cur.repeat(len(latents)), + latents.float(), + noise_level=noise_level, + ) + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if index < 2: + all_latents.append(latents) + all_log_probs.append(log_prob) + all_timesteps.append(t_cur.repeat(len(latents))) + + # 6. Final decode + denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + denoised_latent = denoised_latent.to(dtype=self.vae.dtype) + + image = self.vae.decode(denoised_latent, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + return image, all_latents, all_log_probs, all_timesteps diff --git a/adv_grpo/diffusers_patch/sd3_sde_with_logprob.py b/adv_grpo/diffusers_patch/sd3_sde_with_logprob.py new file mode 100644 index 0000000000000000000000000000000000000000..47ae91bbd2160b653ddbc663390ebe912c4c46fd --- /dev/null +++ b/adv_grpo/diffusers_patch/sd3_sde_with_logprob.py @@ -0,0 +1,139 @@ +# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py +# We adapt it from diffusion (DDIM) to flow matching. + +import math +from typing import Optional, Union +import torch + +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + + + +def sde_step_with_logprob( + self: FlowMatchEulerDiscreteScheduler, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + noise_level: float = 0.7, + prev_sample: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, +): + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow + process from the learned model outputs (most often the predicted velocity). + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned flow model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator.
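+ + Sketch of the update implemented below: with sigma the current flow-matching noise level, + dt = sigma_prev - sigma < 0, and std_t = noise_level * sqrt(sigma / (1 - sigma)) (sigma is + clamped away from 1 at the first step), the reverse step is + + mean = x * (1 + std_t**2 / (2 * sigma) * dt) + v * (1 + std_t**2 * (1 - sigma) / (2 * sigma)) * dt + x_prev = mean + std_t * sqrt(-dt) * eps, eps ~ N(0, I) + + and log_prob is the Gaussian log-density of x_prev under this mean and scale std_t * sqrt(-dt), + averaged over all non-batch dimensions.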
+ """ + # bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32 + model_output=model_output.float() + sample=sample.float() + if prev_sample is not None: + prev_sample=prev_sample.float() + + step_index = [self.index_for_timestep(t) for t in timestep] + prev_step_index = [step+1 for step in step_index] + sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1))) + sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1))) + sigma_max = self.sigmas[1].item() + dt = sigma_prev - sigma + + std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level + # import pdb; pdb.set_trace() + + # our sde + prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise + + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2)) + - torch.log(std_dev_t * torch.sqrt(-1*dt)) + - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) + ) + + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return prev_sample, log_prob, prev_sample_mean, std_dev_t + + + +def sde_step_with_logprob_new( + self: FlowMatchEulerDiscreteScheduler, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + noise_level: float = 0.7, + prev_sample: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, +): + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow + process from the learned model outputs (most often the predicted velocity). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned flow model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. 
+ """ + # bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32 + model_output=model_output.float() + sample=sample.float() + if prev_sample is not None: + prev_sample=prev_sample.float() + + step_index = [self.index_for_timestep(t) for t in timestep] + prev_step_index = [step+1 for step in step_index] + sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1))) + sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1))) + sigma_max = self.sigmas[1].item() + dt = sigma_prev - sigma + + # Flow-SDE + #std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level * torch.sqrt(-1*dt) + # prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt + + # Flow-CPS + std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2) # sigma_t in paper + pred_original_sample = sample - sigma * model_output # predicted x_0 in paper + noise_estimate = sample + model_output * (1 - sigma) # predicted x_1 in paper + prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2) + # import pdb; pdb.set_trace() + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + # remove all constants + log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2) + + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return prev_sample, log_prob, prev_sample_mean, std_dev_t \ No newline at end of file diff --git a/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py b/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..1100e280c247293707be5bbe2c5f08a27c6793b9 --- /dev/null +++ b/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py b/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..1100e280c247293707be5bbe2c5f08a27c6793b9 --- /dev/null +++ b/adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if hasattr(text_encoders[0], "module"): + dtype = text_encoders[0].module.dtype + else: + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, + ) + + prompt_embeds = _encode_prompt_with_t5( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[1].device, + text_input_ids=text_input_ids_list[1] if
text_input_ids_list else None, + ) + + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids diff --git a/adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py b/adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..cd908170bbc2d773109cebb799b2a4d1042335e6 --- /dev/null +++ b/adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, pooled_prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, +
text_input_ids_list=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + clip_tokenizers = tokenizers[:2] + clip_text_encoders = text_encoders[:2] + + clip_prompt_embeds_list = [] + clip_pooled_prompt_embeds_list = [] + for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): + prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + device=device if device is not None else text_encoder.device, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[i] if text_input_ids_list else None, + ) + clip_prompt_embeds_list.append(prompt_embeds) + clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) + + clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1) + pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1) + + t5_prompt_embed = _encode_prompt_with_t5( + text_encoders[-1], + tokenizers[-1], + max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None, + device=device if device is not None else text_encoders[-1].device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + + return prompt_embeds, pooled_prompt_embeds
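A hypothetical usage sketch for the SD3 `encode_prompt` above, assuming the standard `stabilityai/stable-diffusion-3-medium-diffusers` checkpoint layout (two CLIP encoders plus T5; the subfolder names are the usual diffusers ones, and the weights are large):

```python
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from adv_grpo.diffusers_patch.train_dreambooth_lora_sd3 import encode_prompt

repo = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed checkpoint
tokenizers = [
    CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
    CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2"),
    T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_3"),
]
text_encoders = [
    CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder"),
    CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2"),
    T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_3"),
]

prompt_embeds, pooled = encode_prompt(
    text_encoders, tokenizers, "a photo of a corgi", max_sequence_length=256
)
# prompt_embeds: (1, 77 + 256, 4096) -- padded CLIP features concatenated with T5
# pooled:        (1, 2048)           -- the two pooled CLIP embeddings, concatenated
```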
+ """ + # prev_sample_mean, we must convert all variable to fp32 + model_output=model_output.float() + sample=sample.float() + if prev_sample is not None: + prev_sample=prev_sample.float() + + step_index = [self.index_for_timestep(t) for t in timestep] + prev_step_index = [step+1 for step in step_index] + + self.sigmas = self.sigmas.to(sample.device) + sigma = self.sigmas[step_index].view(-1, 1, 1, 1, 1) + sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1, 1) + sigma_max = self.sigmas[1].item() + sigma_min = self.sigmas[-1].item() + dt = sigma_prev - sigma + + std_dev_t = sigma_min + (sigma_max - sigma_min) * sigma + prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt + + if prev_sample is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" + " `prev_sample` stays `None`." + ) + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise + + # No noise is added during evaluation + if determistic: + prev_sample = sample + dt * model_output + + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2)) + - torch.log(std_dev_t * torch.sqrt(-1*dt)) + - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) + ) + + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + if return_dt_and_std_dev_t: + return prev_sample, log_prob, prev_sample_mean, std_dev_t, torch.sqrt(-1*dt) + return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt) + +def wan_pipeline_with_logprob( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + determistic: bool = False, + kl_reward: float = 0.0, + return_pixel_log_prob: bool = False, +): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference, with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned; otherwise a `tuple` is returned where + the first element is the generated video. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + print( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding down to the nearest valid value." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + all_latents = [latents] + all_log_probs = [] + all_kl = [] +
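The `kl_reward` branch inside the denoising loop below relies on the closed-form KL divergence between two Gaussian transition kernels that share a standard deviation (the `assert std_dev_t == ref_std_dev_t` guards this). Restated:

```latex
D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu_\theta,\sigma^2 I)\,\|\,\mathcal{N}(\mu_{\mathrm{ref}},\sigma^2 I)\bigr)
 = \frac{\lVert \mu_\theta - \mu_{\mathrm{ref}} \rVert^2}{2\sigma^2}
```

where mu_theta comes from the LoRA policy and mu_ref from the adapter-disabled reference pass; the code averages the per-element value over all non-batch dimensions.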
# 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latents_ori = latents.clone() + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.to(prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), + t.unsqueeze(0), + latents.float(), + determistic=determistic, + return_pixel_log_prob=return_pixel_log_prob + ) + prev_latents = latents.clone() + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # compute the previous noisy sample x_t -> x_t-1 + # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # KL reward is enabled and we are sampling (not a deterministic eval pass) + if kl_reward > 0 and not determistic: + latent_model_input = torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori + with self.transformer.disable_adapter(): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.to(prompt_embeds.dtype) + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + _, ref_log_prob, ref_prev_latents_mean, ref_std_dev_t = sde_step_with_logprob( + self.scheduler, + noise_pred.float(), + t.unsqueeze(0), + latents_ori.float(), + prev_sample=prev_latents.float(), + determistic=determistic, + ) + assert std_dev_t == ref_std_dev_t + kl = (prev_latents_mean - ref_prev_latents_mean)**2 / (2 * std_dev_t**2) + kl = kl.mean(dim=tuple(range(1, kl.ndim))) + all_kl.append(kl) + else: + # No KL reward requested: skip the reference pass and append zeros as placeholders + all_kl.append(torch.zeros(len(latents), device=latents.device)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # if XLA_AVAILABLE: + # xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = (
torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (video, all_latents, all_log_probs, all_kl) + + return WanPipelineOutput(frames=video), all_latents, all_log_probs, all_kl diff --git a/adv_grpo/diffusers_patch/wan_prompt_embedding.py b/adv_grpo/diffusers_patch/wan_prompt_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..cc73b64fea6889054ec5cf72557bfa3e26847a3d --- /dev/null +++ b/adv_grpo/diffusers_patch/wan_prompt_embedding.py @@ -0,0 +1,97 @@ +import torch +from typing import Any, Callable, Dict, List, Optional, Union + +def _get_t5_prompt_embeds( + text_encoder, + tokenizer, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 226, + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + +def encode_prompt( + text_encoder, + tokenizer, + prompt: Union[str, List[str]], + max_sequence_length: int = 226, + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into T5 text-encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the T5 tokenizer. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`, *optional*): + torch dtype + """
+ device = text_encoder[0].device + dtype = text_encoder[0].dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt_embeds = _get_t5_prompt_embeds( + text_encoder=text_encoder[0], + tokenizer=tokenizer[0], + prompt=prompt, + max_sequence_length=max_sequence_length, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=dtype, + ) + + return prompt_embeds \ No newline at end of file diff --git a/adv_grpo/ema.py b/adv_grpo/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..dcdfc2caf0f1b4e11c85965f7747bdc4f52833cb --- /dev/null +++ b/adv_grpo/ema.py @@ -0,0 +1,88 @@ +# Copied from another repo, but I can't remember exactly which one. + +from collections.abc import Iterable + +import torch + + +class EMAModuleWrapper: + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + update_step_interval: int = 1, + device: torch.device | None = None, + ): + parameters = list(parameters) + self.ema_parameters = [p.clone().detach().to(device) for p in parameters] + + self.temp_stored_parameters = None + + self.decay = decay + self.update_step_interval = update_step_interval + self.device = device + + # TODO: add an automatic decay calculation based on this formula: + # The impact of the last n steps can be calculated as: + # impact = 1-(decay^n) + # The number of steps needed to reach a specific impact is: + # n = log_decay(1-impact) + # The decay needed to reach a specific impact after n steps is: + # decay = (1-impact)^(1/n) + # A numeric check of these relations appears after `copy_ema_to` below. + + def get_current_decay(self, optimization_step) -> float: + return min( + (1 + optimization_step) / (10 + optimization_step), + self.decay + ) + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter], optimization_step): + parameters = list(parameters) + + one_minus_decay = 1 - self.get_current_decay(optimization_step) + + if (optimization_step + 1) % self.update_step_interval == 0: + for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True): + if parameter.requires_grad: + if ema_parameter.device == parameter.device: + ema_parameter.add_(one_minus_decay * (parameter - ema_parameter)) + else: + # in place calculations to save memory + parameter_copy = parameter.detach().to(ema_parameter.device) + parameter_copy.sub_(ema_parameter) + parameter_copy.mul_(one_minus_decay) + ema_parameter.add_(parameter_copy) + del parameter_copy + + def to(self, device: torch.device = None, dtype: torch.dtype = None) -> None: + self.device = device + self.ema_parameters = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.ema_parameters + ] + + def copy_ema_to(self, parameters: Iterable[torch.nn.Parameter], store_temp: bool = True) -> None: + if store_temp: + self.temp_stored_parameters = [parameter.detach().cpu() for parameter in parameters] + + parameters = list(parameters) + for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True): + parameter.data.copy_(ema_parameter.to(parameter.device).data)
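A quick numeric check of the decay relations listed in the TODO above (illustrative values only):

```python
import math

decay, n = 0.9999, 10_000
impact = 1 - decay**n                    # influence of the last n steps: about 0.632
n_back = math.log(1 - impact, decay)     # recovers n from impact
decay_back = (1 - impact) ** (1 / n)     # recovers decay from impact and n
print(round(impact, 3), round(n_back), round(decay_back, 6))  # 0.632 10000 0.9999
```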
+ def copy_temp_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + for temp_parameter, parameter in zip(self.temp_stored_parameters, parameters, strict=True): + parameter.data.copy_(temp_parameter.data) + + self.temp_stored_parameters = None + + def load_state_dict(self, state_dict: dict) -> None: + self.decay = self.decay if self.decay else state_dict.get("decay", self.decay) + self.ema_parameters = state_dict.get("ema_parameters") + self.to(self.device) + + def state_dict(self) -> dict: + return { + "decay": self.decay, + "ema_parameters": self.ema_parameters, + } diff --git a/adv_grpo/imagereward_scorer.py b/adv_grpo/imagereward_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab192eb17ee54474c592e98a91945132365fab46 --- /dev/null +++ b/adv_grpo/imagereward_scorer.py @@ -0,0 +1,40 @@ +from PIL import Image +import torch +import ImageReward as RM + +class ImageRewardScorer(torch.nn.Module): + def __init__(self, device="cuda", dtype=torch.float32): + super().__init__() + self.model_path = "ImageReward-v1.0" + self.device = device + self.dtype = dtype + self.model = RM.load(self.model_path, device=device).eval().to(dtype=dtype) + self.model.requires_grad_(False) + + @torch.no_grad() + def __call__(self, prompts, images): + rewards = [] + for prompt, image in zip(prompts, images): + _, reward = self.model.inference_rank(prompt, [image]) + rewards.append(reward) + return rewards + +# Usage example +def main(): + scorer = ImageRewardScorer( + device="cuda", + dtype=torch.float32 + ) + + images = [ + "astronaut.jpg", + ] + pil_images = [Image.open(img) for img in images] + prompts = [ + 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', + ] + print(scorer(prompts, pil_images)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/adv_grpo/inflated_layers.py b/adv_grpo/inflated_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b4eb14e9ae7637821db6bf58bd8149aca8b0b16b --- /dev/null +++ b/adv_grpo/inflated_layers.py @@ -0,0 +1,305 @@ +from functools import partial +from typing import Literal +from einops import rearrange +from torch import Tensor +from torch.nn import ConvTranspose2d, ConvTranspose3d + +from flow_grpo.inflated_lib import ( + MemoryState, + extend_head, + inflate_bias, + inflate_distribution_bias, + inflate_distribution_weight, + inflate_weight, + modify_state_dict, +) +from flow_grpo.conv_gradfix import GradFixConv2d, GradFixConv3d + +VERBOSE = False + +_inflation_mode_t = Literal["none", "flatten", "partial_flatten", "pad", "tile"] +_direction_t = Literal["", "out", "in"] + + +class InflatedCausalConv3d(GradFixConv3d): + """ + Note: + To align the behavior of pretrained 2D models, + if you compose a video clip from a single image by: + - duplicating: set shape_norm = True + - padding zeros: set shape_norm = False + to avoid gaps at the beginning of the training process. + """ + + def __init__( + self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs + ): + self.shape_norm = shape_norm + self.inflation_mode = inflation_mode + self.padding_bank = None + super().__init__(*args, **kwargs) + self.temporal_padding = self.padding[0] + self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal (see the toy illustration below).
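A toy illustration of the causal padding used by `forward` below: the first frame is duplicated `kernel_size - 1` times (as `extend_head` does), so the temporal length is preserved without letting frame `t` see frame `t+1`. Numbers are made up:

```python
import torch

x = torch.randn(1, 1, 8, 4, 4)                     # (B, C, T, H, W)
k_t = 3                                            # temporal kernel size
head = x[:, :, :1].repeat(1, 1, k_t - 1, 1, 1)     # duplicate the first frame, like extend_head(times=2)
x_padded = torch.cat([head, x], dim=2)             # all temporal context comes from the past
conv = torch.nn.Conv3d(1, 1, kernel_size=(k_t, 3, 3), padding=(0, 1, 1))
y = conv(x_padded)
print(y.shape)  # torch.Size([1, 1, 8, 4, 4]): same T as the input, causally computed
```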
+ + def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor: + bank_size = self.stride[0] - self.kernel_size[0] + padding_bank = ( + input[:, :, bank_size:].detach() + if (bank_size != 0 and memory_state != MemoryState.DISABLED) + else None + ) + if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=self.padding_bank) + else: + input = extend_head(input, times=self.temporal_padding * 2) + if memory_state != MemoryState.DISABLED and not self.training: + self.padding_bank = padding_bank + return super().forward(input) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if self.inflation_mode == "none": + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + else: + # NOTE: need to switch off strict + super()._load_from_state_dict( + modify_state_dict( + self, + state_dict, + prefix, + verbose=VERBOSE, + inflate_weight_fn=partial(inflate_weight, position="tail"), + inflate_bias_fn=partial(inflate_bias, position="tail"), + ), + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class InflatedDistributionCausalConv3d(GradFixConv3d): + """ + Note: + Direction: + - out: this layer generates mean/std of some distribution; + - in: this layer takes tensors sampled from output of `out` layer as input. + """ + + def __init__( + self, + *args, + direction: _direction_t, + inflation_mode: _inflation_mode_t, + shape_norm: bool = True, + **kwargs, + ): + self.shape_norm = shape_norm + self.inflation_mode = inflation_mode + self.direction = direction + self.padding_bank = None + super().__init__(*args, **kwargs) + self.temporal_padding = self.padding[0] + self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal. + + def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor: + bank_size = self.stride[0] - self.kernel_size[0] + padding_bank = ( + input[:, :, bank_size:].detach() + if (bank_size != 0 and memory_state != MemoryState.DISABLED) + else None + ) + if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=self.padding_bank) + else: + input = extend_head(input, times=self.temporal_padding * 2) + if memory_state != MemoryState.DISABLED and not self.training: + self.padding_bank = padding_bank + return super().forward(input) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if self.inflation_mode == "none": + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + else: + super()._load_from_state_dict( + modify_state_dict( + self, + state_dict, + prefix, + verbose=VERBOSE, + inflate_weight_fn=partial( + inflate_distribution_weight, direction=self.direction, position="tail" + ), + inflate_bias_fn=partial( + inflate_distribution_bias, direction=self.direction, position="tail" + ), + ), + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class InflatedConvTranspose3d(ConvTranspose3d): + # Note: It's not a causal one. 
+ def __init__( + self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs + ): + self.shape_norm = shape_norm + self.inflation_mode = inflation_mode + super().__init__(*args, **kwargs) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if self.inflation_mode == "none": + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + else: + # NOTE: need to switch off strict + super()._load_from_state_dict( + modify_state_dict( + self, + state_dict, + prefix, + verbose=VERBOSE, + inflate_weight_fn=partial(inflate_weight, position="center"), + inflate_bias_fn=partial(inflate_bias, position="center"), + ), + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class FlattenedConvTranspose3d(ConvTranspose2d): + def forward(self, input: Tensor, **kwargs) -> Tensor: + output = rearrange(input, "b c f h w -> (b f) c h w") + output = super().forward(output) + output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2)) + return output + + +class FlattenedConv3d(GradFixConv2d): + def forward(self, input: Tensor, **kwargs) -> Tensor: + output = rearrange(input, "b c f h w -> (b f) c h w") + output = super().forward(output) + output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2)) + return output + + +def init_causal_conv3d( + *args, + inflation_mode: _inflation_mode_t, + direction: _direction_t = "", + partial_switch: bool = False, + **kwargs, +): + """ + Initialize a Causal-3D convolution layer. + Parameters: + inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have. + - none: No inflation will be conducted. + The loading logic of the state dict falls back to the default. + - flatten: It will produce a `fake` 3D layer, + which simply squeezes the batch and depth axes together + and then conducts a 2D convolution. + - partial_flatten: + - layers with `partial_switch` on: using `none` mode. + - layers with `partial_switch` off: using `flatten` mode. + - pad / tile: Refer to the definition of `InflatedCausalConv3d`. + direction: + - empty string: Ordinary causal convolution layer. + - out / in: Refer to the definition of `InflatedDistributionCausalConv3d`. + partial_switch: Only works when `inflation_mode` is `partial_flatten`.
+ """ + stride = kwargs.get("stride", args[3] if len(args) > 3 else None) + padding = kwargs.get("padding", args[4] if len(args) > 4 else None) + if "flatten" in inflation_mode: + if ( + ( + (not stride) + or isinstance(stride, int) + or (isinstance(stride, list or tuple) and len(stride) < 3) + ) # if the config of stride can be used for 2D conv + and ( + (not padding) + or isinstance(padding, int) + or (isinstance(padding, list or tuple) and len(padding) < 3) + ) # if the config of padding can be used for 2D conv + and (("partial" not in inflation_mode) or (not partial_switch)) + # if it's fully-flatten mode, or with `partial_switch` off + ): + return FlattenedConv3d(*args, **kwargs) + else: + return InflatedCausalConv3d(*args, inflation_mode="none", **kwargs) + # Force-override + else: + if direction: + return InflatedDistributionCausalConv3d( + *args, direction=direction, inflation_mode=inflation_mode, **kwargs + ) + else: + return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs) + + +def init_transposed_conv3d( + *args, inflation_mode: _inflation_mode_t, partial_switch: bool = False, **kwargs +): + stride = kwargs.get("stride", args[3] if len(args) > 3 else None) + padding = kwargs.get("padding", args[4] if len(args) > 4 else None) + if "flatten" in inflation_mode: + if ( + ( + (not stride) + or isinstance(stride, int) + or (isinstance(stride, list or tuple) and len(stride) < 3) + ) + and ( + (not padding) + or isinstance(padding, int) + or (isinstance(padding, list or tuple) and len(padding) < 3) + ) + or (("partial" in inflation_mode) and not partial_switch) + ): + return FlattenedConvTranspose3d(*args, **kwargs) + else: + return InflatedConvTranspose3d( + *args, inflation_mode="none", **kwargs + ) # Force-override + else: + return InflatedConvTranspose3d(*args, inflation_mode=inflation_mode, **kwargs) \ No newline at end of file diff --git a/adv_grpo/inflated_lib.py b/adv_grpo/inflated_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..589ffb773351ade6ed806137e8b3df03b11bde27 --- /dev/null +++ b/adv_grpo/inflated_lib.py @@ -0,0 +1,346 @@ +import math +from enum import Enum +from typing import Optional +import numpy as np +import torch +from diffusers.models.attention_processor import SpatialNorm +from diffusers.models.normalization import RMSNorm +from einops import rearrange +from torch import Tensor, nn + +# from common.logger import get_logger + +# logger = get_logger(__name__) + + +class MemoryState(Enum): + """ + State[Disabled]: No memory bank will be enabled. + State[Initializing]: The model is handling the first clip, + need to reset / initialize the memory bank. + State[Active]: There has been some data in the memory bank. 
+ """ + + DISABLED = 0 + INITIALIZING = 1 + ACTIVE = 2 + + +def norm_wrapper( + norm_layer: nn.Module, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + keep_causal: bool = False, +) -> torch.Tensor: + if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)): + if x.ndim == 4: + x = rearrange(x, "b c h w -> b h w c") + x = norm_layer(x) + x = rearrange(x, "b h w c -> b c h w") + return x + if x.ndim == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = norm_layer(x) + x = rearrange(x, "b t h w c -> b c t h w") + return x + if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if x.ndim <= 4 or (not keep_causal and not isinstance(norm_layer, nn.BatchNorm2d)): + return norm_layer(x) + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + x = norm_layer(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + if isinstance(norm_layer, SpatialNorm): + t = -1 + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + if y.ndim == 5: + y = rearrange(y, "b c t h w -> (b t) c h w") + if x.ndim != 4 or y.ndim != 4: + raise NotImplementedError + x = norm_layer(x, y) + if t != -1: + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + raise NotImplementedError + + +def remove_head(tensor: Tensor, times: int = 1) -> Tensor: + """ + Remove duplicated first frame features in the up-sampling process. + """ + if times == 0: + return tensor + return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) + + +def extend_head( + tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None +) -> Tensor: + """ + When memory is None: + - Duplicate first frame features in the down-sampling process. + When memory is not None: + - Concatenate memory features with the input features to keep temporal consistency. + """ + if times == 0: + return tensor + if memory is not None: + return torch.cat((memory.to(tensor), tensor), dim=2) + else: + tile_repeat = np.ones(tensor.ndim).astype(int) + tile_repeat[2] = times + return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2) + + +def fill_weight_in_depth(weight: torch.Tensor, source: torch.Tensor, position: str): + """ + Inflate a 2D convolution weight matrix to a 3D one by padding zeros in the channel of depth. + Parameters: + weight: The weight parameters of 3D conv kernel to be initialized. + source: The weight parameters of 2D conv kernel to be inflated. + position: Where to insert the 2D weights, can be chosen from + - tail: Pad zeros in the front of the 2D kernel. Used for casual inflation. + - center: Pad zeros around the 2D kernel. Used for normal inflation. + """ + assert position in ["tail", "center"], "Unsupported fill-in position for weight inflation." + depth = weight.size(2) + weight.fill_(0.0) + if position == "center": + if depth % 2 == 1: + weight[:, :, depth // 2].copy_(source.squeeze(2)) + else: + weight[:, :, depth // 2].copy_(source.squeeze(2) / 2.0) + weight[:, :, depth // 2 - 1].copy_(source.squeeze(2) / 2.0) + else: + if depth % 2 == 1: + weight[:, :, -1].copy_(source.squeeze(2)) + else: + weight[:, :, -1].copy_(source.squeeze(2) / 2.0) + weight[:, :, -2].copy_(source.squeeze(2) / 2.0) + return weight + + +def inflate_weight( + weight_2d: torch.Tensor, + weight_3d: torch.Tensor, + shape_norm: bool, + name: str, + inflation_mode: str, + position: str, + verbose: bool = True, +): + """ + Inflate a 2D convolution weight matrix to a 3D one. 
+ Parameters: + weight_2d: The weight matrix of 2D conv to be inflated. + weight_3d: The weight matrix of 3D conv to be initialized. + inflation_mode: the mode of inflation + - pad: pad zeros around the 2D kernel. + - tile: tile the 2D kernel along the depth axis. + + shape_norm: Whether to scale the parameters of the 2D kernel so that the untrained + inflated model behaves exactly the same as the original 2D model + in the reconstruction of image and video. It is recommended to switch it on. + + name: The name of the inflated module. Only used in logging. + position: Refer to the doc of `fill_weight_in_depth`. + Only works when `inflation_mode` is `pad`. + (A toy numeric check follows `inflate_bias` below.) + verbose: Whether to log information about inflation. + """ + assert inflation_mode in ["pad", "tile"] + depth = weight_3d.size(2) + tgt_out, tgt_in = weight_3d.size()[:2] + src_out, src_in = weight_2d.size()[:2] + assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0) + out_fan, in_fan = tgt_out // src_out, tgt_in // src_in + depth_factor = 1 if inflation_mode == "pad" else depth + factor = (depth_factor * math.sqrt(out_fan) * math.sqrt(in_fan)) if shape_norm else 1 + with torch.no_grad(): + channel_inflation = weight_2d.unsqueeze(2).repeat(out_fan, in_fan, 1, 1, 1) / factor + if inflation_mode == "tile": + weight_3d.copy_(channel_inflation.repeat(1, 1, depth, 1, 1)) + else: + weight_3d = fill_weight_in_depth(weight_3d, channel_inflation, position) + if verbose: + print( + f"*** {name}weight {weight_2d.size()} is inflated to {weight_3d.size()} ***" + ) + return weight_3d + + +def inflate_bias( + bias_2d: torch.Tensor, + bias_3d: torch.Tensor, + shape_norm: bool, + name: str, + inflation_mode: str, + position: str, + verbose: bool = True, +): + """ + Inflate a 2D convolution bias tensor to a 3D one + Parameters: + bias_2d: The bias tensor of 2D conv to be inflated. + bias_3d: The bias tensor of 3D conv to be initialized. + shape_norm: Refer to `inflate_weight` function. + name: The name of the inflated module. Only used in logging. + inflation_mode: Placeholder to align `inflate_weight`. + position: Placeholder to align `inflate_weight`. + verbose: Whether to log information about inflation. + """ + tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0) + assert tgt_ch % src_ch == 0 + fan = tgt_ch // src_ch + factor = math.sqrt(fan) if shape_norm else 1 + with torch.no_grad(): + bias_3d.copy_(bias_2d.repeat(fan) / factor) + if (tgt_ch != src_ch) and verbose: + print(f"*** {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***") + return bias_3d
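A toy check of the `pad`/`tail` inflation implemented above: a 3x3 2D kernel copied into the last temporal slice of an otherwise zero 3x3x3 kernel reproduces the 2D output on a clip of duplicated frames (all shapes here are made up):

```python
import torch
import torch.nn.functional as F

w2d = torch.randn(8, 4, 3, 3)
w3d = torch.zeros(8, 4, 3, 3, 3)
w3d[:, :, -1] = w2d                              # "tail" position, as in fill_weight_in_depth

x = torch.randn(1, 4, 5, 5)
video = x.unsqueeze(2).repeat(1, 1, 3, 1, 1)     # clip built from duplicated frames

y2d = F.conv2d(x, w2d, padding=1)
y3d = F.conv3d(video, w3d, padding=(0, 1, 1))    # temporal padding 0 keeps it causal
print(torch.allclose(y2d, y3d[:, :, 0], atol=1e-5))  # True
```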
+ """ + assert inflation_mode in ["pad", "tile"] + depth = weight_3d.size(2) + tgt_out, tgt_in = weight_3d.size()[:2] + src_out, src_in = weight_2d.size()[:2] + assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0) + out_fan, in_fan = tgt_out // src_out, tgt_in // src_in + depth_factor = 1 if inflation_mode == "pad" else depth + if direction == "out": + factor = (depth_factor * math.sqrt(in_fan)) if shape_norm else 1 + with torch.no_grad(): + in_inflation = weight_2d.unsqueeze(2).repeat(1, in_fan, 1, 1, 1) / factor + # [src_out, src_in, k_h, k_w] -> [src_out, tgt_in, 1, k_h, k_w] + out_mean_weight, out_std_weight = torch.chunk(in_inflation, 2, dim=0) + mean_slice = slice(src_out // 2) + std_slice = slice(tgt_out // 2, tgt_out // 2 + src_out // 2) + if inflation_mode == "tile": + weight_3d[mean_slice] = out_mean_weight + weight_3d[std_slice] = out_std_weight + # Other part will be randomly initialized. + else: + weight_3d[mean_slice] = fill_weight_in_depth( + weight_3d[mean_slice], out_mean_weight, position + ) + weight_3d[std_slice] = fill_weight_in_depth( + weight_3d[std_slice], out_std_weight, position + ) + # Other part will be randomly initialized. + elif direction == "in": + factor = (depth_factor * math.sqrt(out_fan)) if shape_norm else 1 + with torch.no_grad(): + out_inflation = weight_2d.unsqueeze(2).repeat(out_fan, 1, 1, 1, 1) / factor + # [src_out, src_in, k_h, k_w] -> [tgt_out, src_in, 1, k_h, k_w] + if inflation_mode == "tile": + weight_3d[:, :src_in] = out_inflation + else: + weight_3d[:, :src_in] = fill_weight_in_depth( + weight_3d[:, :src_in], out_inflation, position + ) + weight_3d[:, src_in:].fill_(0.0) + else: + raise NotImplementedError + if verbose: + print( + f"*** [Distribution] {name}weight {weight_2d.size()} " + f"is inflated to {weight_3d.size()} ***" + ) + return weight_3d + + +def inflate_distribution_bias( + bias_2d: torch.Tensor, + bias_3d: torch.Tensor, + shape_norm: bool, + name: str, + direction: str, + inflation_mode: str, + position: str, + verbose: bool = True, +): + """ + The combination of `inflate_distribution_weight` and `inflate_bias`. + """ + tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0) + assert tgt_ch % src_ch == 0 + if direction == "out": + with torch.no_grad(): + out_mean_bias, out_std_bias = torch.chunk(bias_2d, 2, dim=0) + bias_3d[: src_ch // 2] = out_mean_bias + bias_3d[tgt_ch // 2 : tgt_ch // 2 + src_ch // 2] = out_std_bias + elif direction == "in": + with torch.no_grad(): + bias_3d[:src_ch] = bias_2d + bias_3d[src_ch:].fill_(0.0) + else: + raise NotImplementedError + if verbose: + print( + f"*** [Distribution] {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***" + ) + return bias_3d + + +def modify_state_dict( + layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn, verbose=False +): + """ + the main function to inflated 2D parameters to 3D. + """ + weight_name = prefix + "weight" + bias_name = prefix + "bias" + if weight_name in state_dict: + weight_2d = state_dict[weight_name] + if ( + weight_2d.dim() == 4 + ): # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) + weight_3d = inflate_weight_fn( + weight_2d=weight_2d, + weight_3d=layer.weight, + shape_norm=layer.shape_norm, + name=prefix, + verbose=verbose, + inflation_mode=layer.inflation_mode, + ) + state_dict[weight_name] = weight_3d + else: + return state_dict + # It's a 3d state dict, should not do inflation on both bias and weight. 
+ if bias_name in state_dict: + bias_2d = state_dict[bias_name] + if bias_2d.dim() == 1: # Assuming the 2D biases are 1D tensors (out_channels,) + bias_3d = inflate_bias_fn( + bias_2d=bias_2d, + bias_3d=layer.bias, + shape_norm=layer.shape_norm, + name=prefix, + verbose=verbose, + inflation_mode=layer.inflation_mode, + ) + state_dict[bias_name] = bias_3d + return state_dict \ No newline at end of file diff --git a/adv_grpo/ocr.py b/adv_grpo/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..a973ef2f7daa12f4e7efadf6c5a998993b30078f --- /dev/null +++ b/adv_grpo/ocr.py @@ -0,0 +1,138 @@ +from paddleocr import PaddleOCR +import torch +import numpy as np +from Levenshtein import distance +from typing import List, Union +from PIL import Image + +class OcrScorer: + def __init__(self, use_gpu: bool = False): + """ + OCR reward calculator + :param use_gpu: Whether to use GPU acceleration for PaddleOCR + """ + self.ocr = PaddleOCR( + use_angle_cls=False, + lang="en", + use_gpu=use_gpu, + show_log=False # Disable unnecessary log output + ) + + @torch.no_grad() + def __call__(self, + images: Union[List[Image.Image], List[np.ndarray]], + prompts: List[str]) -> torch.Tensor: + """ + Calculate OCR reward + :param images: List of input images (PIL or numpy format) + :param prompts: Corresponding target text list + :return: List of rewards + """ + # Extract the quoted target text from each prompt + prompts = [prompt.split('"')[1] for prompt in prompts] + rewards = [] + # Ensure input lengths are consistent + assert len(images) == len(prompts), "Images and prompts must have the same length" + for img, prompt in zip(images, prompts): + # Convert image format + if isinstance(img, Image.Image): + img = np.array(img) + + try: + # OCR recognition + result = self.ocr.ocr(img, cls=False) + # Extract recognized text (handle possible multi-line results) + recognized_text = ''.join([res[1][0] if res[1][1] > 0 else '' for res in result[0]]) if result[0] else '' + + recognized_text = recognized_text.replace(' ', '').lower() + prompt = prompt.replace(' ', '').lower() + if prompt in recognized_text: + dist = 0 + else: + dist = distance(recognized_text, prompt) + # If OCR returned many unrelated characters, cap the penalty at the prompt length + if dist > len(prompt): + dist = len(prompt) + + except Exception as e: + # Error handling (e.g., OCR parsing failure) + print(f"OCR processing failed: {str(e)}") + dist = len(prompt) # Maximum penalty + reward = 1 - dist / len(prompt) + rewards.append(reward) + + return rewards + +class OcrScorer_video_or_image: + def __init__(self, use_gpu: bool = False): + """ + OCR reward calculator + :param use_gpu: Whether to use GPU acceleration for PaddleOCR + """ + self.ocr = PaddleOCR( + use_angle_cls=False, + lang="en", + use_gpu=use_gpu, + show_log=False # Disable unnecessary log output + ) + self.frame_interval = 4 + + @torch.no_grad() + def __call__(self, images: Union[List[Image.Image], List[np.ndarray]], prompts: List[str]) -> List[float]: + """ + :param images: List of images or videos (each video as np.ndarray of shape [F, H, W, C]) + :param prompts: List of prompts containing target text + :return: List of OCR rewards, averaged over sampled frames + """ + prompts = [prompt.split('"')[1] for prompt in prompts] + assert len(images) == len(prompts), "Mismatch between images and prompts."
+
+class OcrScorer_video_or_image:
+    def __init__(self, use_gpu: bool = False):
+        """
+        OCR reward calculator for images or videos.
+        :param use_gpu: whether to use GPU acceleration for PaddleOCR
+        """
+        self.ocr = PaddleOCR(
+            use_angle_cls=False,
+            lang="en",
+            use_gpu=use_gpu,
+            show_log=False  # disable unnecessary log output
+        )
+        self.frame_interval = 4
+
+    @torch.no_grad()
+    def __call__(self, images: Union[List[Image.Image], List[np.ndarray]], prompts: List[str]) -> List[float]:
+        """
+        :param images: images, or videos given as np.ndarray of shape [F, H, W, C]
+        :param prompts: prompts containing the target text in double quotes
+        :return: list of per-sample OCR rewards (frame-averaged for videos)
+        """
+        prompts = [prompt.split('"')[1] for prompt in prompts]
+        assert len(images) == len(prompts), "Mismatch between images and prompts."
+
+        rewards = []
+        for img, prompt in zip(images, prompts):
+            prompt = prompt.replace(' ', '').lower()
+            frame_rewards = []
+
+            # Handle video: shape (F, H, W, C)
+            if isinstance(img, np.ndarray) and img.ndim == 4:
+                sampled_frames = img[::self.frame_interval]
+            else:
+                sampled_frames = [img]
+
+            for frame in sampled_frames:
+                if isinstance(frame, Image.Image):
+                    frame = np.array(frame)
+                try:
+                    result = self.ocr.ocr(frame, cls=False)
+                    text = ''.join([res[1][0] if res[1][1] > 0 else '' for res in result[0]]) if result[0] else ''
+                    text = text.replace(' ', '').lower()
+
+                    dist = distance(text, prompt)
+                    dist = min(dist, len(prompt))
+
+                except Exception as e:
+                    print(f"OCR failed on frame: {e}")
+                    dist = len(prompt)
+
+                reward = 1 - dist / max(len(prompt), 1)
+                if reward > 0:
+                    frame_rewards.append(reward)
+
+            # Average over frames with a positive reward; 0.0 if none.
+            if frame_rewards:
+                rewards.append(sum(frame_rewards) / len(frame_rewards))
+            else:
+                rewards.append(0.0)
+
+        return rewards
+
+if __name__ == "__main__":
+    example_image_path = "media_images_eval_images_499_ef42de47b8ec98892954.jpg"
+    example_image = Image.open(example_image_path)
+    example_prompt = 'New York Skyline with "Hello World" written with fireworks on the sky'
+    # Instantiate the scorer
+    scorer = OcrScorer(use_gpu=False)
+
+    # Score the example and print the result
+    reward = scorer([example_image], [example_prompt])
+    print(f"OCR Reward: {reward}")
\ No newline at end of file
diff --git a/adv_grpo/pick_score_training.py b/adv_grpo/pick_score_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6782c978cda6e5331c0a9e177b311154573d01
--- /dev/null
+++ b/adv_grpo/pick_score_training.py
@@ -0,0 +1,385 @@
+import os
+import json
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+from torch.nn.modules.loss import _Loss
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import Dataset, DataLoader, DistributedSampler
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+
+# ====== CLIPCriterion (the existing implementation, reused here) ======
+
+
+def evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device, max_eval=100):
+    """
+    Quick evaluation: score the first `max_eval` Qwen-vs-SD3 pairs and report the averages.
+    """
+    model.eval()
+    if hasattr(model, "module"):  # unwrap DDP
+        model = model.module
+
+    with open(json_file, "r") as f:
+        prompt2img = json.load(f)
+
+    prompts = list(prompt2img.keys())[:max_eval]
+
+    qwen_scores, sd3_scores = [], []
+
+    for prompt in prompts:
+        filename = prompt2img[prompt]
+        qwen_img_path = os.path.join(qwen_dir, filename)
+        sd3_img_path = os.path.join(sd3_dir, filename)
+
+        if not (os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path)):
+            continue
+
+        qwen_img = Image.open(qwen_img_path).convert("RGB")
+        sd3_img = Image.open(sd3_img_path).convert("RGB")
+
+        # Text & image inputs
+        text_inputs = processor.tokenizer(
+            prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77
+        ).to(device)
+        qwen_inputs = processor(images=qwen_img, return_tensors="pt").to(device)
+        sd3_inputs = processor(images=sd3_img, return_tensors="pt").to(device)
+
+        with torch.no_grad():
+            text_features = model.get_text_features(**text_inputs)
+            qwen_features = model.get_image_features(**qwen_inputs)
+            sd3_features = model.get_image_features(**sd3_inputs)
+
+        # Normalize
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        qwen_features = qwen_features / qwen_features.norm(dim=-1, keepdim=True)
+        sd3_features = sd3_features / sd3_features.norm(dim=-1, keepdim=True)
+
+        # Similarity scores
+        logit_scale = model.logit_scale.exp()
+        qwen_score = (logit_scale * (text_features @ qwen_features.T)).item()
+        sd3_score = (logit_scale * (text_features @ sd3_features.T)).item()
+
+        qwen_scores.append(qwen_score)
+        sd3_scores.append(sd3_score)
+
+    model.train()
+    if len(qwen_scores) > 0:
+        print(f"[Eval] Qwen avg={sum(qwen_scores)/len(qwen_scores):.4f} "
+              f"| SD3 avg={sum(sd3_scores)/len(sd3_scores):.4f}")
+
+
+@dataclass
+class CLIPCriterionConfig:
+    _target_: str = "trainer.criterions.clip_criterion.CLIPCriterion"
+    is_distributed: bool = False  # keep off for local, single-process runs
+    label_0_column_name: str = "label_0"
+    label_1_column_name: str = "label_1"
+    input_ids_column_name: str = "input_ids"
+    pixels_0_column_name: str = "pixels_0"
+    pixels_1_column_name: str = "pixels_1"
+    num_examples_per_prompt_column_name: str = "num_examples_per_prompt"
+    in_batch_negatives: bool = False
+
+
+class CLIPCriterion(_Loss):
+    def __init__(self, cfg: CLIPCriterionConfig):
+        super().__init__()
+        self.cfg = cfg
+
+    @staticmethod
+    def get_features(model, input_ids, pixels_0_values, pixels_1_values):
+        all_pixel_values = torch.cat([pixels_0_values, pixels_1_values], dim=0)
+        text_features = model.get_text_features(input_ids=input_ids)
+        all_image_features = model.get_image_features(pixel_values=all_pixel_values)
+        all_image_features = all_image_features / all_image_features.norm(dim=-1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        image_0_features, image_1_features = all_image_features.chunk(2, dim=0)
+        return image_0_features, image_1_features, text_features
+
+    @staticmethod
+    def gather_features(features):
+        all_features = torch.cat(torch.distributed.nn.all_gather(features), dim=0)
+        return all_features
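Without in-batch negatives, `calc_loss` below reduces to a two-way softmax between the two candidate images for each prompt. A tiny worked sketch with made-up logits (the 1.0/0.0 weights mirror label_0/label_1 for a clear preference):

    import torch
    import torch.nn.functional as F

    # One prompt, two candidate images; cosine similarities scaled by logit_scale:
    text_0_logit = torch.tensor([18.0])   # preferred image (label_0 = 1)
    text_1_logit = torch.tensor([15.0])   # rejected image  (label_1 = 0)
    logits = torch.stack([text_0_logit, text_1_logit], dim=-1)   # [1, 2]

    loss_prefer_0 = F.cross_entropy(logits, torch.tensor([0]))   # log(1 + e^-3) ~ 0.049
    loss_prefer_1 = F.cross_entropy(logits, torch.tensor([1]))
    loss = 1.0 * loss_prefer_0 + 0.0 * loss_prefer_1             # label-weighted, as in calc_loss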
+
+    def calc_loss(
+        self,
+        text_features,
+        image_0_features,
+        image_1_features,
+        logit_scale,
+        label_0,
+        label_1,
+        num_examples_per_prompt,
+        *args,
+        **kwargs
+    ):
+        device = image_0_features.device
+
+        # Gather features across ranks if distributed
+        if self.cfg.is_distributed:
+            image_0_features = self.gather_features(image_0_features)
+            image_1_features = self.gather_features(image_1_features)
+            text_features = self.gather_features(text_features)
+            label_0 = self.gather_features(label_0)
+            label_1 = self.gather_features(label_1)
+            num_examples_per_prompt = self.gather_features(num_examples_per_prompt)
+
+        # Calc logits  # TODO: use a local loss as open-clip does
+        all_image_features = torch.cat([image_0_features, image_1_features], dim=0)  # (2 * batch_size, dim)
+        logits_per_image = logit_scale * all_image_features @ text_features.T
+        image_0_logits, image_1_logits = logits_per_image.chunk(2, dim=0)
+        text_logits = logit_scale * text_features @ all_image_features.T
+
+        if self.cfg.in_batch_negatives:
+            # Get labels
+            num_images = all_image_features.shape[0]
+            image_labels = torch.arange(num_images, device=device, dtype=torch.long)
+            image_0_labels, image_1_labels = image_labels.chunk(2, dim=0)
+            num_texts = text_features.shape[0]
+            text_labels = torch.arange(num_texts, device=device, dtype=torch.long)
+
+            # Image loss: increase the logits of the preferred image to the text
+            image_0_loss = torch.nn.functional.cross_entropy(image_0_logits, text_labels, reduction="none")
+            image_1_loss = torch.nn.functional.cross_entropy(image_1_logits, text_labels, reduction="none")
+            # On a tie, both images are increased equally and averaged, so each example's
+            # image loss stays proportional.
+            image_loss = label_0 * image_0_loss + label_1 * image_1_loss
+
+            # Text loss: increase the logits of the text to the preferred image
+            text_0_loss = torch.nn.functional.cross_entropy(text_logits, image_0_labels, reduction="none")
+            text_1_loss = torch.nn.functional.cross_entropy(text_logits, image_1_labels, reduction="none")
+
+        else:
+            text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
+            index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
+
+            text_0_logits = text_0_logits[index, index]
+            text_1_logits = text_1_logits[index, index]
+            text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
+            text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
+            text_1_labels = text_0_labels + 1
+            text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
+            text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
+
+        # On a tie we want the logits for each image to be equal
+        text_loss = label_0 * text_0_loss + label_1 * text_1_loss
+        # The ideal loss should be 0; on a tie it is currently 0.5 * log(0.5) + 0.5 * log(0.5),
+        # so we add log(0.5) to compensate.
+        is_tie = (label_0 == label_1).float()
+        is_tie *= torch.log(torch.tensor(0.5, device=device))
+        text_loss += is_tie
+
+        # Average the image and text losses
+        if self.cfg.in_batch_negatives:
+            loss = (image_loss + text_loss) / 2
+        else:
+            loss = text_loss
+
+        # Some prompts have many interactions; optionally weight them accordingly:
+        # absolute_example_weight = 1 / num_examples_per_prompt
+        # denominator = absolute_example_weight.sum()
+        # weight_per_example = absolute_example_weight / denominator
+        # loss *= weight_per_example
+        loss = loss.mean()
+        return loss
+
+    def forward(self, model, batch):
+        image_0_features, image_1_features, text_features = self.get_features(
+            model,
+            batch[self.cfg.input_ids_column_name],
+            batch[self.cfg.pixels_0_column_name],
+            batch[self.cfg.pixels_1_column_name]
+        )
+
+        loss = self.calc_loss(
+            text_features,
+            image_0_features,
+            image_1_features,
+            model.logit_scale.exp(),
+            batch[self.cfg.label_0_column_name],
+            batch[self.cfg.label_1_column_name],
+            batch[self.cfg.num_examples_per_prompt_column_name],
+        )
+        return loss
+
+
+# ====== Data preparation ======
+class QwenSD3JsonDataset(Dataset):
+    def __init__(self, processor, json_file, qwen_dir, sd3_dir):
+        """
+        json_file: prompt2img.json mapping {prompt: filename}
+        qwen_dir: directory holding the Qwen images
+        sd3_dir: directory holding the SD3 images
+        """
+        self.processor = processor
+
+        with open(json_file, "r") as f:
+            self.prompt2img = json.load(f)
+
+        self.prompts = list(self.prompt2img.keys())
+        self.qwen_dir = qwen_dir
+        self.sd3_dir = sd3_dir
+
+    def __len__(self):
+        return len(self.prompts)
+
+    def __getitem__(self, idx):
+        prompt = self.prompts[idx]
+        filename = self.prompt2img[prompt]
+
+        qwen_img_path = os.path.join(self.qwen_dir, filename)
+        sd3_img_path = os.path.join(self.sd3_dir, filename)
+
+        if os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path):
+            qwen_img = Image.open(qwen_img_path).convert("RGB")
+            sd3_img = Image.open(sd3_img_path).convert("RGB")
+        else:
+            # Fall back to the SD3 image for both slots if the Qwen image is missing.
+            qwen_img = Image.open(sd3_img_path).convert("RGB")
+            sd3_img = Image.open(sd3_img_path).convert("RGB")
+
+        # Tokenize the text
+        text_inputs = self.processor.tokenizer(
+            prompt,
+            padding="max_length",
+            truncation=True,
+            max_length=77,
+            return_tensors="pt"
+        )
+        input_ids = text_inputs["input_ids"].squeeze(0)
+
+        # Preprocess the images
+        pixels_0 = self.processor(images=qwen_img, return_tensors="pt")["pixel_values"].squeeze(0)
+        pixels_1 = self.processor(images=sd3_img, return_tensors="pt")["pixel_values"].squeeze(0)
+
+        return {
+            "input_ids": input_ids,
+            "pixels_0": pixels_0,  # positive sample (Qwen)
+            "pixels_1": pixels_1,  # negative sample (SD3)
+            "label_0": torch.tensor(1.0),
+            "label_1": torch.tensor(0.0),
+            "num_examples_per_prompt": torch.tensor(1.0)
+        }
+
+
+# ====== Training loop (single-GPU variant, kept for reference) ======
+# def finetune_pickscore(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6, device="cuda"):
+#     processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+#     model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
+
+#     dataset = QwenSD3JsonDataset(processor, json_file, qwen_dir, sd3_dir)
+#     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+#     criterion = CLIPCriterion(CLIPCriterionConfig())
+#     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+#     model.train()
+#     for epoch in range(epochs):
+#         total_loss = 0.0
+#         for batch in dataloader:
+#             batch = {k: v.to(device) for k, v in batch.items()}
+#             loss = criterion(model, batch)
+
+#             optimizer.zero_grad()
+#             loss.backward()
+#             optimizer.step()
+
+#             total_loss += loss.item()
+#         print(f"Epoch {epoch} | Loss {total_loss/len(dataloader):.4f}")
+
+#     model.save_pretrained("pickscore_qwen_finetuned")
+#     return model
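The dataset assumes a JSON file mapping each prompt to a shared image filename present in both directories. A minimal illustrative layout and instantiation (file contents and paths here are hypothetical):

    # prompt2img.json, hypothetical contents:
    # {
    #   "a cat holding a sign that says \"hello\"": "000001.png",
    #   "a watercolor fox in the snow": "000002.png"
    # }
    from transformers import CLIPProcessor

    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    dataset = QwenSD3JsonDataset(processor, "prompt2img.json", "qwen_images/", "sd3_images/")
    sample = dataset[0]  # dict with input_ids, pixels_0/1, labels, num_examples_per_prompt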
+def finetune_pickscore_distributed(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6):
+    # 1. Initialize the process group
+    dist.init_process_group(backend="nccl")
+    local_rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(local_rank)
+    device = torch.device("cuda", local_rank)
+
+    # 2. Prepare the data
+    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+    dataset = QwenSD3JsonDataset(processor, json_file, qwen_dir, sd3_dir)
+    sampler = DistributedSampler(dataset)
+    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    # 3. Model + DDP
+    model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
+    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
+
+    criterion = CLIPCriterion(CLIPCriterionConfig())
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+    # 4. Training loop
+    model.train()
+    if dist.get_rank() == 0:
+        evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)
+    for epoch in range(epochs):
+        sampler.set_epoch(epoch)  # re-seed the sampler so each epoch gets a fresh shuffle
+        total_loss = 0.0
+
+        for step, batch in enumerate(dataloader):
+            batch = {k: v.to(device) for k, v in batch.items()}
+            loss = criterion(model.module, batch)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # Accumulate the (local) loss
+            total_loss += loss.item()
+
+            # Log periodically (rank 0 only); adjust the interval as needed
+            if step % 50 == 0:
+                # all_reduce averages the loss across all GPUs
+                avg_loss = torch.tensor(loss.item(), device=device)
+                dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
+                if dist.get_rank() == 0:
+                    print(f"[Epoch {epoch} | Step {step}/{len(dataloader)}] "
+                          f"local_loss={loss.item():.4f} | avg_loss={avg_loss.item():.4f}")
+
+        # Report the average loss for the epoch
+        epoch_loss = torch.tensor(total_loss / len(dataloader), device=device)
+        dist.all_reduce(epoch_loss, op=dist.ReduceOp.AVG)
+        if dist.get_rank() == 0:
+            print(f"===> Epoch {epoch} done | avg_epoch_loss={epoch_loss.item():.4f}")
+            evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)
+
+    # 5. Save the model (rank 0 only)
+    if dist.get_rank() == 0:
+        model.module.save_pretrained("pickscore_qwen_finetuned")
+
+    dist.destroy_process_group()
+
+
+# ====== Usage example ======
+if __name__ == "__main__":
+    finetune_pickscore_distributed(
+        json_file="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images/prompt2img.json",
+        qwen_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images",
+        sd3_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images",
+        epochs=2,
+        batch_size=4,
+        lr=1e-6,
+    )
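Since the script reads LOCAL_RANK and initializes NCCL, it is meant to be launched with a multi-process launcher such as `torchrun --nproc_per_node=8 pick_score_training.py`. As a side note on how PickScore-style scores turn into preferences, the preference probability is a softmax over the two image scores (values below are made up):

    import torch

    # Hypothetical PickScore logits for one prompt and two images
    score_qwen, score_sd3 = torch.tensor(21.3), torch.tensor(19.8)
    prob_qwen = torch.softmax(torch.stack([score_qwen, score_sd3]), dim=0)[0]
    print(f"P(qwen preferred) = {prob_qwen:.3f}")  # ~0.818 for a 1.5-logit gap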
device="cuda", + dtype=torch.float32 + ) + images=[ + "nasa.jpg", + ] + pil_images = [Image.open(img) for img in images] + prompts=[ + 'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', + ] + print(scorer(prompts, pil_images)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/adv_grpo/pickscore_scorer_constractive.py b/adv_grpo/pickscore_scorer_constractive.py new file mode 100644 index 0000000000000000000000000000000000000000..32589bfcfb055122a9d7cc8a5ac2bdfa9bc48424 --- /dev/null +++ b/adv_grpo/pickscore_scorer_constractive.py @@ -0,0 +1,89 @@ +from transformers import CLIPProcessor, CLIPModel +from PIL import Image +import torch + +class PickScoreScorerConstractive(torch.nn.Module): + def __init__(self, device="cuda", dtype=torch.float32): + super().__init__() + processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" + model_path = "yuvalkirstain/PickScore_v1" + self.device = device + self.dtype = dtype + self.processor = CLIPProcessor.from_pretrained(processor_path) + self.model = CLIPModel.from_pretrained(model_path).eval().to(device) + self.model = self.model.to(dtype=dtype) + + @torch.no_grad() + def __call__(self, prompt, images, ref_images): + # Preprocess images + image_inputs = self.processor( + images=images, + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ) + image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()} + + ref_image_inputs = self.processor( + images=ref_images, + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ) + ref_image_inputs = {k: v.to(device=self.device) for k, v in ref_image_inputs.items()} + + + + # Preprocess text + text_inputs = self.processor( + text=prompt, + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ) + text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()} + + # Get embeddings + image_embs = self.model.get_image_features(**image_inputs) + image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True) + + ref_image_embs = self.model.get_image_features(**ref_image_inputs) + ref_image_embs = ref_image_embs / ref_image_embs.norm(p=2, dim=-1, keepdim=True) + + text_embs = self.model.get_text_features(**text_inputs) + text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True) + + # Calculate scores + logit_scale = self.model.logit_scale.exp() + scores = logit_scale * (text_embs @ image_embs.T) + scores = scores.diag() + # norm to 0-1 + scores = scores/26 + + ref_scores = logit_scale * (text_embs @ ref_image_embs.T) + ref_scores = ref_scores.diag() + ref_scores = ref_scores/26 + + + return scores, ref_scores, image_embs, ref_image_embs + +# Usage example +def main(): + scorer = PickScoreScorer( + device="cuda", + dtype=torch.float32 + ) + images=[ + "nasa.jpg", + ] + pil_images = [Image.open(img) for img in images] + prompts=[ + 'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', + ] + print(scorer(prompts, pil_images)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/adv_grpo/pickscore_scorer_patch.py b/adv_grpo/pickscore_scorer_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..62f4180b7f4703e8c03e5c6093c6b49f90bd6ef1 --- /dev/null +++ b/adv_grpo/pickscore_scorer_patch.py @@ -0,0 +1,78 @@ +from transformers import CLIPProcessor, CLIPModel +from PIL import Image +import torch + +class PickScoreScorer(torch.nn.Module): + def __init__(self, device="cuda", dtype=torch.float32): + 
diff --git a/adv_grpo/pickscore_scorer_patch.py b/adv_grpo/pickscore_scorer_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..62f4180b7f4703e8c03e5c6093c6b49f90bd6ef1
--- /dev/null
+++ b/adv_grpo/pickscore_scorer_patch.py
@@ -0,0 +1,78 @@
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+import torch
+
+class PickScoreScorer(torch.nn.Module):
+    def __init__(self, device="cuda", dtype=torch.float32):
+        super().__init__()
+        processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+        model_path = "yuvalkirstain/PickScore_v1"
+        self.device = device
+        self.dtype = dtype
+        self.processor = CLIPProcessor.from_pretrained(processor_path)
+        self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
+        self.model = self.model.to(dtype=dtype)
+
+    @torch.no_grad()
+    def __call__(self, prompt, images):
+        # Unwrap DDP if necessary
+        if hasattr(self.model, "module"):
+            self.model = self.model.module
+        # Preprocess images
+        image_inputs = self.processor(
+            images=images,
+            padding=True,
+            truncation=True,
+            max_length=77,
+            return_tensors="pt",
+        )
+        image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
+        # Preprocess text
+        text_inputs = self.processor(
+            text=prompt,
+            padding=True,
+            truncation=True,
+            max_length=77,
+            return_tensors="pt",
+        )
+        text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}
+
+        # Get patch-level embeddings instead of the pooled global feature
+        image_embs = self.model.vision_model(image_inputs["pixel_values"], output_hidden_states=True)
+        image_embs = image_embs.last_hidden_state
+
+        image_embs = self.model.visual_projection(image_embs)  # [B, N, 1024]
+        image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)
+
+        text_embs = self.model.get_text_features(**text_inputs)
+        text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)
+
+        # Calculate scores: text-to-patch similarities, averaged over all patches
+        logit_scale = self.model.logit_scale.exp()
+        patch_scores = torch.einsum("bd,bnd->bn", text_embs, image_embs)  # [B, N]
+        scores = logit_scale * patch_scores.mean(dim=1)
+        # Rough normalization to [0, 1]
+        scores = scores / 26
+        return scores
+
+# Usage example
+def main():
+    scorer = PickScoreScorer(
+        device="cuda",
+        dtype=torch.float32
+    )
+    images = [
+        "nasa.jpg",
+    ]
+    pil_images = [Image.open(img) for img in images]
+    prompts = [
+        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+    ]
+    print(scorer(prompts, pil_images))
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
+ """ + if not os.path.exists(path): + newpath = ASSETS_PATH.joinpath(path) + if not os.path.exists(newpath): + raise FileNotFoundError(f"Could not find {path} or adv_grpo.assets/{path}") + path = newpath + with open(path, "r") as f: + return [line.strip() for line in f.readlines()] + + +def from_file(path, low=None, high=None): + prompts = _load_lines(path)[low:high] + return random.choice(prompts), {} + + +def imagenet_all(): + return from_file("imagenet_classes.txt") + + +def imagenet_animals(): + return from_file("imagenet_classes.txt", 0, 398) + + +def imagenet_dogs(): + return from_file("imagenet_classes.txt", 151, 269) + + +def simple_animals(): + return from_file("simple_animals.txt") + +def general_ocr(): + return from_file("general_ocr_train.txt") + +def simple_ocr_animals(): + animals = _load_lines("simple_ocr_animals.txt") + # random_number = random.randint(100, 999) + # random_number = ''.join([str(random.randint(0, 9)) for _ in range(10)]) + num=random.randint(1, 9) + random_number = ''.join([str(6) for _ in range(num)]) + return f'A {random.choice(animals)} holding a sign that says "{random_number}"', {} + +def nouns_activities(nouns_file, activities_file): + nouns = _load_lines(nouns_file) + activities = _load_lines(activities_file) + return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} + + +def counting(nouns_file, low, high): + nouns = _load_lines(nouns_file) + number = IE.number_to_words(random.randint(low, high)) + noun = random.choice(nouns) + plural_noun = IE.plural(noun) + prompt = f"{number} {plural_noun}" + metadata = { + "questions": [ + f"How many {plural_noun} are there in this image?", + f"What animal is in this image?", + ], + "answers": [ + number, + noun, + ], + } + return prompt, metadata diff --git a/adv_grpo/qwenvl.py b/adv_grpo/qwenvl.py new file mode 100644 index 0000000000000000000000000000000000000000..d8534785cb9b57e67843d0988553334d688fd37f --- /dev/null +++ b/adv_grpo/qwenvl.py @@ -0,0 +1,118 @@ +from PIL import Image +import torch +import re +import base64 +from io import BytesIO +from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor +from qwen_vl_utils import process_vision_info + +def pil_image_to_base64(image): + buffered = BytesIO() + image.save(buffered, format="PNG") + encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8") + base64_qwen = f"data:image;base64,{encoded_image_text}" + return base64_qwen + +def extract_scores(output_text): + scores = [] + for text in output_text: + match = re.search(r'(\d+)', text) + if match: + scores.append(float(match.group(1))/5) + else: + scores.append(0) + return scores + +class QwenVLScorer(torch.nn.Module): + def __init__(self, device="cuda", dtype=torch.bfloat16): + super().__init__() + self.device = device + self.dtype = dtype + + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", + torch_dtype=self.dtype, + attn_implementation="flash_attention_2", + device_map=None, + ).to(self.device) + self.model.requires_grad_(False) + self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True) + self.task = ''' +Your role is to evaluate the aesthetic quality score of given images. +1. Bad: Extremely blurry, underexposed with significant noise, indiscernible +subjects, and chaotic composition. +2. Poor: Noticeable blur, poor lighting, washed-out colors, and awkward +composition with cut-off subjects. +3. 
diff --git a/adv_grpo/qwenvl.py b/adv_grpo/qwenvl.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8534785cb9b57e67843d0988553334d688fd37f
--- /dev/null
+++ b/adv_grpo/qwenvl.py
@@ -0,0 +1,118 @@
+from PIL import Image
+import torch
+import re
+import base64
+from io import BytesIO
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+def pil_image_to_base64(image):
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    base64_qwen = f"data:image;base64,{encoded_image_text}"
+    return base64_qwen
+
+def extract_scores(output_text):
+    scores = []
+    for text in output_text:
+        match = re.search(r'(\d+)', text)
+        if match:
+            scores.append(float(match.group(1)) / 5)
+        else:
+            scores.append(0)
+    return scores
+
+class QwenVLScorer(torch.nn.Module):
+    def __init__(self, device="cuda", dtype=torch.bfloat16):
+        super().__init__()
+        self.device = device
+        self.dtype = dtype
+
+        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            torch_dtype=self.dtype,
+            attn_implementation="flash_attention_2",
+            device_map=None,
+        ).to(self.device)
+        self.model.requires_grad_(False)
+        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
+        self.task = '''
+Your role is to evaluate the aesthetic quality score of given images.
+1. Bad: Extremely blurry, underexposed with significant noise, indiscernible
+subjects, and chaotic composition.
+2. Poor: Noticeable blur, poor lighting, washed-out colors, and awkward
+composition with cut-off subjects.
+3. Fair: In focus with adequate lighting, dull colors, decent composition but
+lacks creativity.
+4. Good: Sharp, good exposure, vibrant colors, thoughtful composition with
+a clear focal point.
+5. Excellent: Exceptional clarity, perfect exposure, rich colors, masterful
+composition with emotional impact.
+
+Please first provide a detailed analysis of the evaluation process, including the criteria for judging aesthetic quality, within the tag. Then, give a final score from 1 to 5 within the tag.
+
+[Analyze the evaluation process in detail here]
+
+X
+'''
+
+    @torch.no_grad()
+    def __call__(self, prompt, images):
+        images_base64 = [pil_image_to_base64(image) for image in images]
+        messages = []
+        for base64_qwen in images_base64:
+            messages.append([
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": base64_qwen},
+                        {"type": "text", "text": self.task},
+                    ],
+                },
+            ])
+
+        # Preparation for batch inference
+        texts = [
+            self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
+            for msg in messages
+        ]
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = self.processor(
+            text=texts,
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(self.device)
+
+        # Batch inference
+        generated_ids = self.model.generate(**inputs, max_new_tokens=2048)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_texts = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        rewards = extract_scores(output_texts)
+        return rewards
+
+# Usage example
+def main():
+    scorer = QwenVLScorer(
+        device="cuda",
+        dtype=torch.bfloat16
+    )
+    images = [
+        "nasa.jpg",
+    ]
+    pil_images = [Image.open(img) for img in images]
+    prompts = [
+        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
+    ]
+
+    print(scorer(None, pil_images))
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
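Note that `extract_scores` simply grabs the first integer in the model's reply and divides by 5, so it is fairly forgiving about the reply format; the strings below are made-up replies:

    from adv_grpo.qwenvl import extract_scores

    print(extract_scores(["Final Score: 4", "no score here", "3/5"]))
    # [0.8, 0, 0.6]; unparseable replies fall back to 0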
diff --git a/adv_grpo/rewards.py b/adv_grpo/rewards.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6bb4d39424e1b4a0d33f6375a1f5aa8141a96b1
--- /dev/null
+++ b/adv_grpo/rewards.py
@@ -0,0 +1,1126 @@
+from PIL import Image
+import io
+import numpy as np
+from collections import defaultdict
+
+import torch
+import timm
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def jpeg_incompressibility():
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        images = [Image.fromarray(image) for image in images]
+        buffers = [io.BytesIO() for _ in images]
+        for image, buffer in zip(images, buffers):
+            image.save(buffer, format="JPEG", quality=95)
+        sizes = [buffer.tell() / 1000 for buffer in buffers]
+        return np.array(sizes), {}
+
+    return _fn
+
+
+def jpeg_compressibility():
+    jpeg_fn = jpeg_incompressibility()
+
+    def _fn(images, prompts, metadata):
+        rew, meta = jpeg_fn(images, prompts, metadata)
+        return -rew / 500, meta
+
+    return _fn
+
+
+def aesthetic_score():
+    from adv_grpo.aesthetic_scorer import AestheticScorer
+
+    scorer = AestheticScorer(dtype=torch.float32).cuda()
+
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8)
+        else:
+            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
+            images = torch.tensor(images, dtype=torch.uint8)
+        scores = scorer(images)
+        return scores, {}
+
+    return _fn
+
+
+def clip_score():
+    from adv_grpo.clip_scorer import ClipScorer
+
+    scorer = ClipScorer().cuda()
+
+    def _fn(images, prompts, metadata):
+        if not isinstance(images, torch.Tensor):
+            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
+            images = torch.tensor(images, dtype=torch.uint8) / 255.0
+        scores = scorer(images, prompts)
+        return scores, {}
+
+    return _fn
+
+
+def siglip_image_similarity_score(device):
+    from transformers import SiglipModel
+
+    # 1. Load the SigLIP model (so400m-patch14-384 recommended)
+    scorer = SiglipModel.from_pretrained(
+        "google/siglip-so400m-patch14-384"
+    ).to(device).to(torch.bfloat16)
+    scorer.eval()
+
+    # SigLIP preprocessing mean/std
+    siglip_mean = [0.5, 0.5, 0.5]
+    siglip_std = [0.5, 0.5, 0.5]
+
+    # Model input resolution (matches the chosen SigLIP variant, e.g. 224/256/384)
+    image_size = scorer.config.vision_config.image_size
+
+    def _preprocess(images):
+        # Convert to a tensor
+        if not isinstance(images, torch.Tensor):
+            images = torch.tensor(images, dtype=torch.float32)
+
+        # 0-255 -> 0-1
+        if images.max() > 1.0:
+            images = images / 255.0
+
+        # NHWC -> NCHW
+        if images.shape[-1] == 3:
+            images = images.permute(0, 3, 1, 2)
+
+        # Resize to the SigLIP input size
+        images = F.interpolate(
+            images,
+            size=(image_size, image_size),
+            mode="bicubic",
+            align_corners=False
+        )
+
+        # Normalize with SigLIP's mean/std
+        mean = torch.tensor(siglip_mean, device=device)[None, :, None, None]
+        std = torch.tensor(siglip_std, device=device)[None, :, None, None]
+        images = (images - mean) / std
+
+        return images.to(device).to(torch.bfloat16)
+
+    def _fn(images, ref_images):
+        # 2. Preprocess
+        images = _preprocess(images)
+        ref_images = _preprocess(ref_images)
+
+        with torch.no_grad():
+            # Extract SigLIP features
+            out_img = scorer.vision_model(
+                pixel_values=images.to(torch.float32)
+            )
+            out_ref = scorer.vision_model(
+                pixel_values=ref_images.to(torch.float32)
+            )
+
+            emb_images = out_img.pooler_output.to(torch.bfloat16)
+            emb_ref = out_ref.pooler_output.to(torch.bfloat16)
+
+            # 3. L2-normalize the embeddings
+            emb_images = emb_images / emb_images.norm(dim=-1, keepdim=True)
+            emb_ref = emb_ref / emb_ref.norm(dim=-1, keepdim=True)
+
+            # 4. Cosine similarity
+            scores = torch.matmul(emb_images, emb_ref.T)  # [N, M]
+            per_img = scores.max(dim=1).values  # [N]
+
+        return per_img.detach(), {"pairwise": scores.detach()}
+
+    return _fn
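The per-image score above is the maximum cosine similarity over all reference images. A tiny numeric sketch with made-up embeddings:

    import torch

    emb = torch.nn.functional.normalize(torch.randn(2, 8), dim=-1)  # 2 generated images
    ref = torch.nn.functional.normalize(torch.randn(3, 8), dim=-1)  # 3 references
    pairwise = emb @ ref.T                                          # [2, 3] cosine sims
    per_img = pairwise.max(dim=1).values                            # best-matching reference per image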
+
+
+def image_similarity_score(device):
+    # 1. Load the DINOv2 backbone (ViT-B/14 here; ViT-L/14 or ViT-G/14 are drop-in swaps)
+    model = timm.create_model("vit_base_patch14_dinov2.lvd142m", pretrained=True)
+    # model = timm.create_model('vit_large_patch16_dinov3_qkvb.lvd1689m', pretrained=True)
+    model.eval().to(device)
+
+    def _preprocess(images):
+        if not isinstance(images, torch.Tensor):
+            images = torch.tensor(images, dtype=torch.float32)
+        if images.max() > 1.0:
+            images = images / 255.0
+        if images.shape[-1] == 3:  # NHWC -> NCHW
+            images = images.permute(0, 3, 1, 2)
+        # Resize to 518x518 and apply ImageNet normalization
+        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+        images = (images - mean) / std
+        return images.to(device)
+
+    def _fn(images, ref_images):
+        # 2. Preprocess
+        images = _preprocess(images)
+        ref_images = _preprocess(ref_images)
+
+        with torch.no_grad():
+            emb_images = model(images)   # [N, D]
+            emb_ref = model(ref_images)  # [M, D]
+
+        # 3. L2-normalize
+        emb_images = emb_images / emb_images.norm(dim=-1, keepdim=True)
+        emb_ref = emb_ref / emb_ref.norm(dim=-1, keepdim=True)
+
+        # 4. Cosine similarity
+        scores = torch.matmul(emb_images, emb_ref.T)  # [N, M]
+        per_img = scores.max(dim=1).values  # [N]; use scores.mean(dim=1) for the average instead
+
+        # Return per-image scores; the pairwise matrix goes into the info dict
+        return per_img.detach(), {"pairwise": scores.detach()}
+
+    return _fn
+
+
+def image_similarity_score_eval(device):
+    # Evaluation variant: same backbone and preprocessing as image_similarity_score, but
+    # it skips embedding normalization and also returns the raw features for analysis.
+    model = timm.create_model("vit_base_patch14_dinov2.lvd142m", pretrained=True)
+    model.eval().to(device)
+
+    def _preprocess(images):
+        if not isinstance(images, torch.Tensor):
+            images = torch.tensor(images, dtype=torch.float32)
+        if images.max() > 1.0:
+            images = images / 255.0
+        if images.shape[-1] == 3:  # NHWC -> NCHW
+            images = images.permute(0, 3, 1, 2)
+        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+        images = (images - mean) / std
+        return images.to(device)
+
+    def _fn(images, ref_images):
+        images = _preprocess(images)
+        ref_images = _preprocess(ref_images)
+
+        with torch.no_grad():
+            emb_images = model(images)   # [N, D]
+            emb_ref = model(ref_images)  # [M, D]
+
+        # Normalization is intentionally disabled here, so the similarity below is an
+        # unnormalized dot product rather than a cosine.
+        scores = torch.matmul(emb_images, emb_ref.T)  # [N, M]
+        per_img = scores.max(dim=1).values  # [N]
+
+        # Also return the raw features for offline analysis
+        return per_img.detach(), {"pairwise": scores.detach()}, emb_images, emb_ref
+
+    return _fn
+
+
+def dino_cotrain_score(device):
+    def _preprocess(images):
+        if not isinstance(images, torch.Tensor):
+            images = torch.tensor(images, dtype=torch.float32)
+        if images.max() > 1.0:
+            images = images / 255.0
+        if images.shape[-1] == 3:  # NHWC -> NCHW
+            images = images.permute(0, 3, 1, 2)
+        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+        images = (images - mean) / std
+        return images.to(device).to(torch.bfloat16)
+
+    def _fn(scorer, head, images, prompts, metadata):
+        images = _preprocess(images)
+
+        with torch.no_grad():
+            emb = scorer(images)  # [N, D]
+            emb = emb / emb.norm(dim=-1, keepdim=True)
+
+        # The head maps embeddings to rewards
+        scores = head(emb).squeeze(-1)  # [N]
+
+        return scores.detach(), {"embeddings": emb.detach()}
+
+    return _fn
+
+
+def siglip_cotrain_score(device):
+    """
+    Used the same way as dino_cotrain_score:
+        reward_fn = siglip_cotrain_score(device)
+        reward_fn(scorer, head, images, prompts, metadata)
+
+    scorer: a SigLIP model (loaded via HF transformers)
+    head: your reward head
+    """
+    from torchvision import transforms
+    tiny_jitter = transforms.ColorJitter(
+        brightness=0.02,
+        contrast=0.02,
+    )
+
+    def _preprocess(images, image_size=224):
+        # Convert to a tensor
+        if not isinstance(images, torch.Tensor):
+            images = torch.tensor(images, dtype=torch.float32)
+
+        # uint8 0-255 -> 0-1
+        if images.max() > 1.0:
+            images = images / 255.0
+
+        # NHWC -> NCHW
+        if images.shape[-1] == 3:
+            images = images.permute(0, 3, 1, 2)
+
+        # Apply a tiny brightness/contrast jitter
+        imgs_aug = []
+        for img in images:
+            imgs_aug.append(tiny_jitter(img))
+        images = torch.stack(imgs_aug)
+
+        # Resize to the SigLIP input size (typically 224 / 256 / 384)
+        images = F.interpolate(
+            images,
+            size=(image_size, image_size),
+            mode="bicubic",
+            align_corners=False
+        )
+
+        # SigLIP's official mean/std (note: different from CLIP's)
+        mean = torch.tensor([0.5, 0.5, 0.5], device=device)[None, :, None, None]
+        std = torch.tensor([0.5, 0.5, 0.5], device=device)[None, :, None, None]
+
+        images = (images - mean) / std
+
+        return images.to(device).to(torch.bfloat16)
+
+    def _fn(scorer, head, images, prompts, metadata):
+        """
+        scorer: SigLIP model
+        head: reward head
+        """
+        # Image preprocessing (scorer.config.vision_config.image_size is typically 224)
+        images = _preprocess(images, 512)
+
+        scorer.eval()
+        with torch.no_grad():
+            # SigLIP features: the pooled embedding, analogous to CLIP's global feature
+            vision_out = scorer.vision_model(
+                pixel_values=images.to(torch.float32)
+            )
+            emb = vision_out.pooler_output.to(torch.bfloat16)  # [B, D]
+
+        # The head maps embeddings to rewards
+        scores = head(emb).squeeze(-1)
+
+        return scores.detach(), {"embeddings": emb.detach()}
+
+    return _fn
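The `head` passed into these co-training reward functions is not defined in this file. A minimal head consistent with how it is called (it must map [B, D] embeddings, or [B, N, D] patch tokens, to one logit per input) might look like the sketch below; this is an assumption for illustration, not the repo's actual head:

    import torch.nn as nn

    class RewardHead(nn.Module):
        def __init__(self, dim=768, hidden=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, hidden),
                nn.GELU(),
                nn.Linear(hidden, 1),  # one logit per embedding
            )

        def forward(self, emb):  # emb: [B, D] or [B, N, D]
            return self.net(emb)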
+
+
+def dino_patch_cotrain_score(device, n_patches=64):
+    def _preprocess(images):
+        if not isinstance(images, torch.Tensor):
+            images = torch.tensor(images, dtype=torch.float32)
+        if images.max() > 1.0:
+            images = images / 255.0
+        if images.shape[-1] == 3:  # NHWC -> NCHW
+            images = images.permute(0, 3, 1, 2)
+        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)[None, :, None, None]
+        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)[None, :, None, None]
+        images = (images - mean) / std
+        return images.to(device).to(torch.bfloat16)
+
+    def _fn(scorer, head, images, prompts, metadata, cls_weight=0.7):
+        images = _preprocess(images)
+        with torch.no_grad():
+            # Extract all token features: [B, N+1, D]
+            feats = scorer.forward_features(images)
+
+            # --- Split CLS vs patch tokens ---
+            cls_emb = feats[:, 0, :]     # [B, D]
+            patch_emb = feats[:, 1:, :]  # [B, N, D]
+            B, N, D = patch_emb.shape
+
+            # --- Randomly sample patches ---
+            n_select = min(n_patches, N)
+            idx = torch.randint(0, N, (B, n_select), device=device)
+            sampled_patches = torch.gather(
+                patch_emb, 1, idx.unsqueeze(-1).expand(-1, -1, D)
+            )  # [B, n_select, D]
+
+            # --- Normalize ---
+            cls_emb = cls_emb / (cls_emb.norm(dim=-1, keepdim=True) + 1e-6)
+            sampled_patches = sampled_patches / (sampled_patches.norm(dim=-1, keepdim=True) + 1e-6)
+
+            # --- Score with the head ---
+            cls_score = head(cls_emb).squeeze(-1)             # [B]
+            patch_scores = head(sampled_patches).squeeze(-1)  # [B, n_select]
+            patch_score_mean = patch_scores.mean(dim=1)       # [B]
+
+            # --- Blend the CLS and patch rewards ---
+            hybrid_score = cls_weight * cls_score + (1 - cls_weight) * patch_score_mean
+
+        # --- Return ---
+        return hybrid_score.detach(), {
+            "cls_score": cls_score.detach(),
+            "patch_scores": patch_scores.detach(),
+            "patch_indices": idx.detach(),
+            "cls_weight": cls_weight,
+        }
+
+    return _fn
+
+
+def _get_layer_tokens_timm(model, imgs, layer_ids=(2, 5, 8, 11)):
+    handles, feats = [], {i: None for i in layer_ids}
+
+    def make_hook(i):
+        def hook(_module, _inp, out):
+            # For timm ViTs, each block outputs [B, N+1, D]
+            feats[i] = out
+        return hook
+
+    for i in layer_ids:
+        assert 0 <= i < len(model.blocks), f"layer id {i} out of range"
+        handles.append(model.blocks[i].register_forward_hook(make_hook(i)))
+
+    # Run a forward pass; timm ViTs expose forward_features, otherwise call the model directly
+    with torch.no_grad():
+        if hasattr(model, "forward_features"):
+            _ = model.forward_features(imgs)
+        else:
+            _ = model(imgs)
+
+    for h in handles:
+        h.remove()
+
+    return [feats[i] for i in layer_ids]  # list of [B, N+1, D]
+
+# -------- Reward factory: per-layer heads + top-k pooling + fusion --------
+def dino_multi_cotrain_score(
+    device,
+    topk_tau=0.2,        # average the top tau fraction of patch logits in each layer
+    apply_sigmoid=True,  # squash the fused logit through a sigmoid to get a [0, 1] reward
+    lambda_cls=0.5,
+    zscore=False,        # z-score across the batch (per-group handling can be done downstream)
+):
+
+    def _preprocess(images):
+        if not isinstance(images, torch.Tensor):
+            images = torch.tensor(images, dtype=torch.float32)
+        if images.max() > 1.0:
+            images = images / 255.0
+        if images.shape[-1] == 3:  # NHWC -> NCHW
+            images = images.permute(0, 3, 1, 2)
+        images = F.interpolate(images, size=(518, 518), mode="bicubic", align_corners=False)
+        images = (images - 0.5) / 0.5
+        return images.to(device)
+
+    @torch.no_grad()
+    def _fn(scorer, heads, fusion, images, prompts=None, metadata=None, layer_ids=(8,), temperature=0.2):
+        """
+        scorer : the timm DINOv2 ViT backbone (already .eval() with requires_grad=False)
+        heads  : nn.ModuleList with len(layer_ids) entries, one head per layer
+        fusion : fusion module mapping (B, T) -> (B,)
+        images : [N,H,W,3] or [N,3,H,W], values in [0,1] or [0,255]
+        Returns:
+            rewards: [N] (float32)
+            aux: dict with per_layer_scores etc. for debugging/visualization
+        """
+        from torch.nn.parallel import DistributedDataParallel as DDP
+        hmod = heads.module if isinstance(heads, DDP) else heads
+        fmod = fusion.module if isinstance(fusion, DDP) else fusion
+        x = _preprocess(images).to(dtype=next(scorer.parameters()).dtype, device=device)
+
+        # Grab tokens from the requested layers
+        tokens_list = _get_layer_tokens_timm(scorer, x, layer_ids=layer_ids)  # list of [B, N+1, D]
+        B = x.size(0)
+        T = len(tokens_list)
+
+        per_layer_scores = []
+        per_layer_logits = []  # pre-top-k patch logits per layer (optional)
+        per_layer_cls_scores = []
+
+        for t in range(T):
+            tokens = tokens_list[t]  # [B, N+1, D]
+            patch = tokens[:, 1:]    # [B, N, D], CLS excluded
+            class_patch = tokens[:, 0]
+            Bn, N, D = patch.shape
+
+            # The head accepts [B, N, D] and outputs [B, N]
+            logits_patch = hmod[t](patch).squeeze(-1)  # [B, N]
+            per_layer_logits.append(logits_patch)
+
+            # Top-k pooling within the layer
+            k = max(1, int(N * topk_tau))
+            pooled = logits_patch.topk(k, dim=1).values.mean(dim=1)  # [B]
+            per_layer_scores.append(pooled)
+
+            cls_logit = hmod[t](class_patch).squeeze(-1)  # [B]
+            per_layer_cls_scores.append(cls_logit)
+
+        per_layer_scores = torch.stack(per_layer_scores, dim=1)  # [B, T]
+        per_layer_cls_scores = torch.stack(per_layer_cls_scores, dim=1)
+
+        # Fuse into the final logit
+        logit_patch = fmod(per_layer_scores)  # [B]
+        logit_cls = fmod(per_layer_cls_scores)
+
+        # logits = (1.0 - float(lambda_cls)) * logit_patch + float(lambda_cls) * logit_cls  # [B]
+        logits = logit_patch
+        # Calibrate into a reward
+        rewards = logits
+        if apply_sigmoid:
+            rewards = torch.sigmoid(rewards / float(temperature))
+        if zscore:
+            mu = rewards.mean(dim=0, keepdim=True)
+            sigma = rewards.std(dim=0, keepdim=True).clamp_min(1e-6)
+            rewards = (rewards - mu) / sigma
+
+        # Emit float32 to avoid downstream bf16 mixing issues
+        rewards = rewards.float()
+
+        aux = {
+            "per_layer_scores": per_layer_scores.float(),  # [B, T]
+            "logits": logits.float(),  # the fused, uncalibrated logit
+            # The entry below can be large (B x T x N); enable only when needed:
+            # "per_layer_logits": [lp.float() for lp in per_layer_logits],
+        }
+        return rewards, aux
+
+    return _fn
+
+def pickscore_score(device):
+    from adv_grpo.pickscore_scorer import PickScoreScorer
+
+    scorer = PickScoreScorer(dtype=torch.float32, device=device)
+
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        images = [Image.fromarray(image) for image in images]
+        scores = scorer(prompts, images)
+        return scores, {}
+
+    return _fn
+
+
+def pickscore_cotrain_score(device):
+    def _fn(scorer, images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        images = [Image.fromarray(image) for image in images]
+        scores = scorer(prompts, images)
+        return scores, {}
+
+    return _fn
+
+
+def pickscore_score_patch(device):
+    from adv_grpo.pickscore_scorer_patch import PickScoreScorer
+
+    scorer = PickScoreScorer(dtype=torch.float32, device=device)
+
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        images = [Image.fromarray(image) for image in images]
+        scores = scorer(prompts, images)
+        return scores, {}
+
+    return _fn
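Top-k pooling keeps only the highest-scoring fraction of patch logits per layer. With tau = 0.2 over ten patches, only the top two contribute (made-up logits):

    import torch

    logits_patch = torch.tensor([[0.1, 0.9, 0.3, 0.8, 0.2, 0.0, 0.4, 0.5, 0.6, 0.7]])
    k = max(1, int(logits_patch.shape[1] * 0.2))           # k = 2
    pooled = logits_patch.topk(k, dim=1).values.mean(dim=1)
    print(pooled)  # tensor([0.8500]); the mean of 0.9 and 0.8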
+
+
+def discriminator_score(device):
+    def _fn(scorer, images, prompts=None, metadata=None):
+        # Normalize to [-1, 1]
+        if isinstance(images, torch.Tensor):
+            if images.max() > 1.5:  # likely 0-255
+                images = images / 255.0
+            images = (images - 0.5) * 2.0
+        else:
+            raise ValueError("images must be a torch.Tensor in [B,3,H,W]")
+
+        with torch.no_grad():
+            logits = scorer(images.to(device))  # StyleGAN: [B] / [B,1]; PatchGAN: [B,1,H',W']
+
+        if logits.ndim == 1:
+            # StyleGAN D, already [B]
+            scores = torch.sigmoid(logits)
+        elif logits.ndim == 2 and logits.shape[1] == 1:
+            # StyleGAN D, outputs [B,1]
+            scores = torch.sigmoid(logits.squeeze(1))
+        elif logits.ndim == 4 and logits.shape[1] == 1:
+            # PatchGAN D, outputs [B,1,H',W']
+            scores = torch.sigmoid(logits).mean(dim=[1, 2, 3])  # -> [B]
+        else:
+            raise ValueError(f"Unexpected logits shape: {logits.shape}")
+
+        return scores.cpu(), {}
+
+    return _fn
+
+
+def imagereward_score(device):
+    from adv_grpo.imagereward_scorer import ImageRewardScorer
+
+    scorer = ImageRewardScorer(dtype=torch.float32, device=device)
+
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        images = [Image.fromarray(image) for image in images]
+        prompts = [prompt for prompt in prompts]
+        scores = scorer(prompts, images)
+        return scores, {}
+
+    return _fn
+
+def qwenvl_score(device):
+    from adv_grpo.qwenvl import QwenVLScorer
+
+    scorer = QwenVLScorer(dtype=torch.bfloat16, device=device)
+
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        images = [Image.fromarray(image) for image in images]
+        prompts = [prompt for prompt in prompts]
+        scores = scorer(prompts, images)
+        return scores, {}
+
+    return _fn
+
+
+def ocr_score(device):
+    from adv_grpo.ocr import OcrScorer
+
+    scorer = OcrScorer()
+
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        scores = scorer(images, prompts)
+        return scores, {}
+
+    return _fn
+
+def video_ocr_score(device):
+    from adv_grpo.ocr import OcrScorer_video_or_image
+
+    scorer = OcrScorer_video_or_image()
+
+    def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            if images.dim() == 4 and images.shape[1] == 3:
+                images = images.permute(0, 2, 3, 1)
+            elif images.dim() == 5 and images.shape[2] == 3:
+                images = images.permute(0, 1, 3, 4, 2)
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+        scores = scorer(images, prompts)
+        return scores, {}
+
+    return _fn
+
+def constractive_external(device, beta=0.5, top_n=2):
+    from adv_grpo.pickscore_scorer_constractive import PickScoreScorerConstractive
+
+    scorer = PickScoreScorerConstractive(dtype=torch.float32, device=device)
+
+    def _fn(images, ref_images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+            images = [Image.fromarray(image) for image in images]
+
+            ref_images = (ref_images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            ref_images = ref_images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+            ref_images = [Image.fromarray(image) for image in ref_images]
+
+        # Score both the generations and the references
+        scores, ref_scores, img_embeds, ref_img_embeds = scorer(prompts, images, ref_images)
+
+        # External anchor: the mean reference embedding and score
+        ref_embed = ref_img_embeds.mean(dim=0, keepdim=True)  # (1, D)
+        ext_score = ref_scores.mean()
+
+        # Treat the top-n highest-scoring generations as potential reward hacks
+        top_idx = torch.topk(scores, k=min(top_n, len(scores))).indices
+        hack_scores = scores[top_idx]
+        hack_embeds = img_embeds[top_idx]  # (N, D)
+
+        if ext_score >= hack_scores.max():
+            return scores, {"raw_scores": scores, "ref_scores": ref_scores}
+
+        # Contrastive correction
+        sim_to_ext = torch.nn.functional.cosine_similarity(img_embeds, ref_embed)
+        sim_to_hack = torch.nn.functional.cosine_similarity(
+            img_embeds.unsqueeze(1), hack_embeds.unsqueeze(0), dim=-1
+        )  # (num_images, N)
+        sim_to_hack = sim_to_hack.mean(dim=1)
+
+        adjusted_scores = scores + beta * (sim_to_ext - sim_to_hack)
+
+        return adjusted_scores, {
+            "raw_scores": scores,
+            "ref_scores": ref_scores,
+            "sim_to_ext": sim_to_ext,
+            "sim_to_hack": sim_to_hack,
+            "hack_scores": hack_scores
+        }
+
+    return _fn
+
+
+def deqa_score_remote(device):
+    """Submits images to a remote DeQA server and computes a reward."""
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 64
+    url = "http://127.0.0.1:18086"
+    sess = requests.Session()
+    retries = Retry(
+        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
+    )
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del prompts
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        all_scores = []
+        for image_batch in images_batched:
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG")
+                jpeg_images.append(buffer.getvalue())
+
+            # Format the payload for the scoring server
+            data = {
+                "images": jpeg_images,
+            }
+            data_bytes = pickle.dumps(data)
+
+            # Send the request to the server
+            response = sess.post(url, data=data_bytes, timeout=120)
+            response_data = pickle.loads(response.content)
+
+            all_scores += response_data["outputs"]
+
+        return all_scores, {}
+
+    return _fn
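The contrastive correction pulls scores toward images that resemble the external reference and away from the suspected reward-hacked cluster. Numerically, with beta = 0.5 and made-up similarities:

    import torch

    scores = torch.tensor([0.90, 0.60])       # raw PickScore-style rewards
    sim_to_ext = torch.tensor([0.20, 0.70])   # similarity to the reference anchor
    sim_to_hack = torch.tensor([0.95, 0.40])  # similarity to the top-scoring (suspect) images
    adjusted = scores + 0.5 * (sim_to_ext - sim_to_hack)
    print(adjusted)  # tensor([0.5250, 0.7500]); the likely-hacked image is demoted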
+ """ + import requests + from requests.adapters import HTTPAdapter, Retry + from io import BytesIO + import pickle + + batch_size = 64 + url = "http://127.0.0.1:18085" + sess = requests.Session() + retries = Retry( + total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False + ) + sess.mount("http://", HTTPAdapter(max_retries=retries)) + + def _fn(images, prompts, metadatas, only_strict): + del prompts + if isinstance(images, torch.Tensor): + images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC + images_batched = np.array_split(images, np.ceil(len(images) / batch_size)) + metadatas_batched = np.array_split(metadatas, np.ceil(len(metadatas) / batch_size)) + all_scores = [] + all_rewards = [] + all_strict_rewards = [] + all_group_strict_rewards = [] + all_group_rewards = [] + for image_batch, metadata_batched in zip(images_batched, metadatas_batched): + jpeg_images = [] + + # Compress the images using JPEG + for image in image_batch: + img = Image.fromarray(image) + buffer = BytesIO() + img.save(buffer, format="JPEG") + jpeg_images.append(buffer.getvalue()) + + # format for LLaVA server + data = { + "images": jpeg_images, + "meta_datas": list(metadata_batched), + "only_strict": only_strict, + } + data_bytes = pickle.dumps(data) + + # send a request to the llava server + response = sess.post(url, data=data_bytes, timeout=120) + response_data = pickle.loads(response.content) + + all_scores += response_data["scores"] + all_rewards += response_data["rewards"] + all_strict_rewards += response_data["strict_rewards"] + all_group_strict_rewards.append(response_data["group_strict_rewards"]) + all_group_rewards.append(response_data["group_rewards"]) + all_group_strict_rewards_dict = defaultdict(list) + all_group_rewards_dict = defaultdict(list) + for current_dict in all_group_strict_rewards: + for key, value in current_dict.items(): + all_group_strict_rewards_dict[key].extend(value) + all_group_strict_rewards_dict = dict(all_group_strict_rewards_dict) + + for current_dict in all_group_rewards: + for key, value in current_dict.items(): + all_group_rewards_dict[key].extend(value) + all_group_rewards_dict = dict(all_group_rewards_dict) + + return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict + + return _fn + +def unifiedreward_score_remote(device): + """Submits images to DeQA and computes a reward. 
+ """ + import requests + from requests.adapters import HTTPAdapter, Retry + from io import BytesIO + import pickle + + batch_size = 64 + url = "http://10.82.120.15:18085" + sess = requests.Session() + retries = Retry( + total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False + ) + sess.mount("http://", HTTPAdapter(max_retries=retries)) + + def _fn(images, prompts, metadata): + if isinstance(images, torch.Tensor): + images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC + images_batched = np.array_split(images, np.ceil(len(images) / batch_size)) + prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size)) + + all_scores = [] + for image_batch, prompt_batch in zip(images_batched, prompts_batched): + jpeg_images = [] + + # Compress the images using JPEG + for image in image_batch: + img = Image.fromarray(image) + buffer = BytesIO() + img.save(buffer, format="JPEG") + jpeg_images.append(buffer.getvalue()) + + # format for LLaVA server + data = { + "images": jpeg_images, + "prompts": prompt_batch + } + data_bytes = pickle.dumps(data) + + # send a request to the llava server + response = sess.post(url, data=data_bytes, timeout=120) + print("response: ", response) + print("response: ", response.content) + response_data = pickle.loads(response.content) + + all_scores += response_data["outputs"] + + return all_scores, {} + + return _fn + +def unifiedreward_score_sglang(device): + import asyncio + from openai import AsyncOpenAI + import base64 + from io import BytesIO + import re + + def pil_image_to_base64(image): + buffered = BytesIO() + image.save(buffered, format="PNG") + encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8") + base64_qwen = f"data:image;base64,{encoded_image_text}" + return base64_qwen + + def _extract_scores(text_outputs): + scores = [] + pattern = r"Final Score:\s*([1-5](?:\.\d+)?)" + for text in text_outputs: + match = re.search(pattern, text) + if match: + try: + scores.append(float(match.group(1))) + except ValueError: + scores.append(0.0) + else: + scores.append(0.0) + return scores + + client = AsyncOpenAI(base_url="http://127.0.0.1:17140/v1", api_key="flowgrpo") + + async def evaluate_image(prompt, image): + question = f"\nYou are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n2. 
Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\nBased on the above criteria, assign a score from 1 to 5 after \'Final Score:\'.\nYour task is provided as follows:\nText Caption: [{prompt}]" + images_base64 = pil_image_to_base64(image) + response = await client.chat.completions.create( + model="UnifiedReward-7b-v1.5", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": images_base64}, + }, + { + "type": "text", + "text": question, + }, + ], + }, + ], + temperature=0, + ) + return response.choices[0].message.content + + async def evaluate_batch_image(images, prompts): + tasks = [evaluate_image(prompt, img) for prompt, img in zip(prompts, images)] + results = await asyncio.gather(*tasks) + return results + + def _fn(images, prompts, metadata): + # 处理Tensor类型转换 + if isinstance(images, torch.Tensor): + images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC + + # 转换为PIL Image并调整尺寸 + images = [Image.fromarray(image).resize((512, 512)) for image in images] + + # 执行异步批量评估 + text_outputs = asyncio.run(evaluate_batch_image(images, prompts)) + score = _extract_scores(text_outputs) + score = [sc/5.0 for sc in score] + return score, {} + + return _fn + +def multi_score(device, score_dict): + score_functions = { + "deqa": deqa_score_remote, + "ocr": ocr_score, + "video_ocr": video_ocr_score, + "imagereward": imagereward_score, + "pickscore": pickscore_score, + "qwenvl": qwenvl_score, + "aesthetic": aesthetic_score, + "jpeg_compressibility": jpeg_compressibility, + "unifiedreward": unifiedreward_score_sglang, + "geneval": geneval_score, + "clipscore": clip_score, + "image_similarity": image_similarity_score, + "image_similarity_eval": image_similarity_score_eval, + "constractive_external": constractive_external, + "discriminator": discriminator_score, + "pickscore_cotrain": pickscore_cotrain_score, + "pickscore_patch":pickscore_score_patch, + "dino_cotrain":dino_cotrain_score, + "dino_multi_cotrain": dino_multi_cotrain_score, + "dino_patch_cotrain": dino_patch_cotrain_score, + "siglip_cotrain": siglip_cotrain_score, + "siglip_image_similarity": siglip_image_similarity_score + } + score_fns={} + for score_name, weight in score_dict.items(): + # import pdb; pdb.set_trace() + score_fns[score_name] = score_functions[score_name](device) if 'device' in score_functions[score_name].__code__.co_varnames else score_functions[score_name]() + + # only_strict is only for geneval. During training, only the strict reward is needed, and non-strict rewards don't need to be computed, reducing reward calculation time. 
+ def _fn(images, prompts, metadata, scorer = None, ref_images=None, only_strict=True, head=None, fusion=None, layer_ids = None, temperature=0.2): + total_scores = [] + score_details = {} + + for score_name, weight in score_dict.items(): + if score_name == "geneval": + scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](images, prompts, metadata, only_strict) + score_details['accuracy'] = rewards + score_details['strict_accuracy'] = strict_rewards + for key, value in group_strict_rewards.items(): + score_details[f'{key}_strict_accuracy'] = value + for key, value in group_rewards.items(): + score_details[f'{key}_accuracy'] = value + elif score_name == "image_similarity": + scores, rewards = score_fns[score_name](images, ref_images) + elif score_name == "siglip_image_similarity": + scores, rewards = score_fns[score_name](images, ref_images) + elif score_name == "image_similarity_eval": + scores, rewards, feat, ref_feat = score_fns[score_name](images, ref_images) + score_details['feat'] = feat + score_details['ref_feat'] = ref_feat + elif score_name == "constractive_external": + scores, rewards = score_fns[score_name](images, prompts, ref_images) + elif score_name == "discriminator": + scores, rewards = score_fns[score_name](scorer, images, prompts, ref_images) + elif score_name == "pickscore_cotrain": + scores, rewards = score_fns[score_name](scorer, images, prompts, metadata) + elif score_name == "dino_cotrain": + scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata) + elif score_name == "siglip_cotrain": + scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata) + elif score_name == "dino_multi_cotrain": + scores, rewards = score_fns[score_name](scorer, head, fusion, images, prompts, metadata, layer_ids, temperature) + elif score_name == "dino_patch_cotrain": + scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata) + elif score_name == "dinov3_patch_cotrain": + scores, rewards = score_fns[score_name](scorer, head, images, prompts, metadata) + else: + scores, rewards = score_fns[score_name](images, prompts, metadata) + + score_details[score_name] = scores + weighted_scores = [weight * score for score in scores] + + if not total_scores: + total_scores = weighted_scores + else: + total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)] + # import pdb; pdb.set_trace() + + score_details['avg'] = total_scores + return score_details, {} + + return _fn + +def main(): + import torchvision.transforms as transforms + + image_paths = [ + "nasa.jpg", + ] + + transform = transforms.Compose([ + transforms.ToTensor(), # Convert to tensor + ]) + + images = torch.stack([transform(Image.open(image_path).convert('RGB')) for image_path in image_paths]) + prompts=[ + 'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', + ] + metadata = {} # Example metadata + score_dict = { + "unifiedreward": 1.0 + } + # Initialize the multi_score function with a device and score_dict + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + scoring_fn = multi_score(device, score_dict) + # Get the scores + scores, _ = scoring_fn(images, prompts, metadata) + # Print the scores + print("Scores:", scores) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/adv_grpo/stat_tracking.py b/adv_grpo/stat_tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..98738b2953c3de5e9abb89554ad1ac551b2f6617 
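Both remote scorers above speak the same ad-hoc protocol: a pickled dict is POSTed over plain HTTP, and a pickled dict comes back. The servers themselves are not part of this commit. The sketch below is a minimal, stdlib-only illustration of the contract that `unifiedreward_score_remote` expects (a `{"images", "prompts"}` request answered with `{"outputs": [...]}`); `score_batch` is a stub standing in for the actual model, and pickle over HTTP should only be used on a trusted private network.

```python
# Minimal sketch of the reward-server contract assumed by unifiedreward_score_remote.
# The real server and its model are not in this diff; score_batch is a placeholder.
import pickle
from http.server import BaseHTTPRequestHandler, HTTPServer
from io import BytesIO
from PIL import Image


def score_batch(images, prompts):
    # A real server would run the reward model here.
    return [0.0 for _ in images]


class RewardHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # The client sends a pickled dict with JPEG-encoded images.
        payload = pickle.loads(self.rfile.read(int(self.headers["Content-Length"])))
        images = [Image.open(BytesIO(b)) for b in payload["images"]]
        scores = score_batch(images, payload.get("prompts", []))
        body = pickle.dumps({"outputs": scores})
        self.send_response(200)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


if __name__ == "__main__":
    # WARNING: unpickling untrusted data is unsafe; bind only on a trusted network.
    HTTPServer(("0.0.0.0", 18085), RewardHandler).serve_forever()
```

The geneval variant uses the same transport but a richer payload (`meta_datas`, `only_strict`) and a richer response (`scores`, `rewards`, `strict_rewards`, and the per-group dicts).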
diff --git a/adv_grpo/stat_tracking.py b/adv_grpo/stat_tracking.py
new file mode 100644
index 0000000000000000000000000000000000000000..98738b2953c3de5e9abb89554ad1ac551b2f6617
--- /dev/null
+++ b/adv_grpo/stat_tracking.py
@@ -0,0 +1,77 @@
+import numpy as np
+import torch
+
+class PerPromptStatTracker:
+    def __init__(self, global_std=False):
+        self.global_std = global_std
+        self.stats = {}
+        self.history_prompts = set()
+
+    def update(self, prompts, rewards, type='grpo'):
+        prompts = np.array(prompts)
+        rewards = np.array(rewards, dtype=np.float64)
+        unique = np.unique(prompts)
+        advantages = np.zeros_like(rewards)
+
+        for prompt in unique:
+            prompt_rewards = rewards[prompts == prompt]
+            if prompt not in self.stats:
+                self.stats[prompt] = []
+            self.stats[prompt].extend(prompt_rewards)
+            self.history_prompts.add(hash(prompt))  # Track every prompt ever seen
+        for prompt in unique:
+            self.stats[prompt] = np.stack(self.stats[prompt])
+            prompt_rewards = rewards[prompts == prompt]  # Recalculate prompt_rewards for each prompt
+            mean = np.mean(self.stats[prompt], axis=0, keepdims=True)
+            if self.global_std:
+                std = np.std(rewards, axis=0, keepdims=True) + 1e-4  # Use the global std of all rewards
+            else:
+                std = np.std(self.stats[prompt], axis=0, keepdims=True) + 1e-4
+            if type == 'grpo':
+                advantages[prompts == prompt] = (prompt_rewards - mean) / std
+            elif type == 'rwr':
+                # advantages[prompts == prompt] = (prompt_rewards - mean) / std
+                advantages[prompts == prompt] = prompt_rewards
+                # advantages[prompts == prompt] = torch.softmax(torch.tensor(prompt_rewards), dim=0).numpy()
+            elif type == 'sft':
+                advantages[prompts == prompt] = (torch.tensor(prompt_rewards) == torch.max(torch.tensor(prompt_rewards))).float().numpy()
+            elif type == 'dpo':
+                # Get the rewards of the current prompt group
+                prompt_advantages = torch.tensor(prompt_rewards)
+                # Find the indices of the maximum and minimum values
+                max_idx = torch.argmax(prompt_advantages)
+                min_idx = torch.argmin(prompt_advantages)
+                # If all rewards in a group are the same
+                if max_idx == min_idx:
+                    min_idx = 0
+                    max_idx = 1
+                result = torch.zeros_like(prompt_advantages).float()
+                # Set the maximum index to 1, the minimum index to -1
+                result[max_idx] = 1.0
+                result[min_idx] = -1.0
+                advantages[prompts == prompt] = result.numpy()
+
+        return advantages
+
+    def get_stats(self):
+        avg_group_size = sum(len(v) for v in self.stats.values()) / len(self.stats) if self.stats else 0
+        history_prompts = len(self.history_prompts)
+        return avg_group_size, history_prompts
+
+    def clear(self):
+        self.stats = {}
+
+def main():
+    tracker = PerPromptStatTracker()
+    prompts = ['a', 'b', 'a', 'c', 'b', 'a']
+    rewards = [1, 2, 3, 4, 5, 6]
+    advantages = tracker.update(prompts, rewards)
+    print("Advantages:", advantages)
+    avg_group_size, history_prompts = tracker.get_stats()
+    print("Average Group Size:", avg_group_size)
+    print("History Prompts:", history_prompts)
+    tracker.clear()
+    print("Stats after clear:", tracker.stats)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
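In the default 'grpo' branch above, each sample's advantage is its reward normalized by the statistics of its own prompt group, (r - mean) / (std + 1e-4); stats accumulate across `update` calls, so the mean is a running per-prompt mean. A quick check of the arithmetic using the class as committed:

```python
# The 'grpo' branch in action: advantages are (r - group mean) / (group std + 1e-4).
import numpy as np
from adv_grpo.stat_tracking import PerPromptStatTracker

adv = PerPromptStatTracker().update(["p", "p", "p"], [1.0, 3.0, 6.0], type="grpo")
print(np.round(adv, 2))  # [-1.14 -0.16  1.3 ]
```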
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..85a06606fe4b19e0513a861571441633d5f7ba4c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,221 @@
+import gradio as gr
+import spaces
+import torch
+from diffusers import StableDiffusion3Pipeline
+from adv_grpo.diffusers_patch.sd3_pipeline_with_logprob_fast import pipeline_with_logprob_random as pipeline_with_logprob
+from adv_grpo.diffusers_patch.train_dreambooth_lora_sd3 import encode_prompt
+from adv_grpo.ema import EMAModuleWrapper
+from peft import PeftModel
+from PIL import Image
+import numpy as np
+import os
+from ml_collections import config_flags
+from huggingface_hub import hf_hub_download
+from huggingface_hub import login
+
+login(os.environ["HF_TOKEN"])
+
+
+
+# ---------------------------------------------------------
+# GLOBAL VARIABLES
+# ---------------------------------------------------------
+
+pipeline = None
+config = None
+text_encoders = None
+tokenizers = None
+ema = None
+transformer_trainable_parameters = None
+
+def load_lora_from_subfolder():
+    repo_id = "benzweijia/Adv-GRPO"
+    subfolder = "PickScore"
+
+    local_dir = "/tmp/PickScore"
+    os.makedirs(local_dir, exist_ok=True)
+
+    for filename in ["adapter_config.json", "adapter_model.safetensors"]:
+        hf_hub_download(
+            repo_id=repo_id,
+            repo_type="model",
+            subfolder=subfolder,
+            filename=filename,
+            local_dir=local_dir,
+            force_download=False
+        )
+    return local_dir
+
+# -------------- Load Config ------------------------------
+def load_config():
+    """Load the ml_collections config from config/base.py.
+    """
+    import importlib.util
+
+    config_path = "config/base.py"
+    spec = importlib.util.spec_from_file_location("config", config_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module.get_config()
+
+
+# -------------- Embedding Function -----------------------
+def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds = encode_prompt(
+            text_encoders, tokenizers, prompt, max_sequence_length
+        )
+        prompt_embeds = prompt_embeds.to(device)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(device)
+    return prompt_embeds, pooled_prompt_embeds
+
+
+# ---------------------------------------------------------
+# GPU MODEL INITIALIZATION
+# ---------------------------------------------------------
+@spaces.GPU
+def init_model():
+    global pipeline, config, text_encoders, tokenizers, ema, transformer_trainable_parameters
+
+    print("🔥 Loading config...")
+    config = load_config()
+
+    print("🔥 Loading SD3 base model on GPU...")
+    pipeline = StableDiffusion3Pipeline.from_pretrained(
+        "stabilityai/stable-diffusion-3.5-medium"
+    )
+
+    # freeze non-trainable params
+    pipeline.vae.requires_grad_(False)
+    pipeline.text_encoder.requires_grad_(False)
+    pipeline.text_encoder_2.requires_grad_(False)
+    pipeline.text_encoder_3.requires_grad_(False)
+
+    pipeline.transformer.requires_grad_(not config.use_lora)
+
+    text_encoders = [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.text_encoder_3]
+    tokenizers = [pipeline.tokenizer, pipeline.tokenizer_2, pipeline.tokenizer_3]
+
+    pipeline.safety_checker = None
+    pipeline.set_progress_bar_config(disable=True)
+
+    # move to GPU
+    pipeline.vae.to("cuda")
+    pipeline.text_encoder.to("cuda")
+    pipeline.text_encoder_2.to("cuda")
+    pipeline.text_encoder_3.to("cuda")
+    pipeline.transformer.to("cuda")
+    config.train.lora_path = "benzweijia/Adv-GRPO/PickScore"
+    config.use_lora = True
+    lora_dir = load_lora_from_subfolder()
+
+    if config.use_lora and config.train.lora_path:
+        print("🔥 Loading LoRA from:", config.train.lora_path)
+        pipeline.transformer = PeftModel.from_pretrained(
+            pipeline.transformer,
+            os.path.join(lora_dir, "PickScore")
+        )
+        pipeline.transformer.set_adapter("default")
+
+    transformer_trainable_parameters = list(
+        filter(lambda p: p.requires_grad, pipeline.transformer.parameters())
+    )
+
+    # Setup EMA
+    ema = EMAModuleWrapper(
+        transformer_trainable_parameters,
+        decay=0.9,
+        update_step_interval=8,
+        device="cuda"
+    )
+
+    print("✅ Model initialized and ready.")
+
+
+# ---------------------------------------------------------
+# INFERENCE FUNCTION
+# ---------------------------------------------------------
+@spaces.GPU
+def infer(prompt):
+    print("Starting inference...")
+    global pipeline, config
+    print("Pipeline loaded:", pipeline is not None)
+
+    if pipeline is None:
+        init_model()
+        print("Pipeline initialized:", pipeline is not None)
+    print("Encoding prompt...")
+
+
+    prompts = [prompt]
+
+    # get prompt embedding
+    prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
+        prompts, text_encoders, tokenizers,
+        max_sequence_length=128,
+        device="cuda"
+    )
+    print("Encoding negative prompt...")
+
+    neg_embed, neg_pooled_embed = compute_text_embeddings(
+        [""], text_encoders, tokenizers,
+        max_sequence_length=128,
+        device="cuda"
+    )
+
+    neg_prompt_embeds = neg_embed.repeat(1, 1, 1)
+    neg_pooled_prompt_embeds = neg_pooled_embed.repeat(1, 1)
+    print("Sampling...")
+
+    # generation seed
+    generator = torch.Generator().manual_seed(0)
+
+    with torch.no_grad():
+        images, _, _, _ = pipeline_with_logprob(
+            pipeline,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=neg_prompt_embeds,
+            negative_pooled_prompt_embeds=neg_pooled_prompt_embeds,
+            num_inference_steps=config.sample.eval_num_steps,
+            guidance_scale=config.sample.guidance_scale,
+            output_type="pt",
+            height=config.resolution,
+            width=config.resolution,
+            noise_level=0,
+            mini_num_image_per_prompt=1,
+            process_index=0,
+            sample_num_steps=config.sample.num_steps,
+            random_timestep=0,
+            generator=generator,
+        )
+
+    print("images type:", type(images))
+    print("images len:", len(images))
+    print("first image shape:", images[0].shape)
+
+    # Convert to PIL
+    pil = Image.fromarray(
+        (images[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+    )
+
+    # Fixed 512x512 for output
+    pil = pil.resize((512, 512))
+
+    return pil
+
+
+# ---------------------------------------------------------
+# GRADIO UI
+# ---------------------------------------------------------
+# init_model()
+
+demo = gr.Interface(
+    fn=infer,
+    inputs=gr.Textbox(lines=2, label="Prompt"),
+    outputs=gr.Image(type="pil"),
+    title="Adv-GRPO (PickScore)",
+    description="Enter a prompt and generate an image using Adv-GRPO",
+)
+
+demo.launch()
diff --git a/config/__pycache__/base.cpython-310.pyc b/config/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afa426d488add841064574435e65cd29d880463e
Binary files /dev/null and b/config/__pycache__/base.cpython-310.pyc differ
diff --git a/config/__pycache__/grpo.cpython-310.pyc b/config/__pycache__/grpo.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ed4ee8c72a05c1616160d2d549678ad06b89990
Binary files /dev/null and b/config/__pycache__/grpo.cpython-310.pyc differ
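app.py ties checkpoint loading to the Gradio Space runtime. The same weights can be exercised locally without the UI; the sketch below is a minimal standalone variant under the same assumptions as above (repo benzweijia/Adv-GRPO, adapter in the PickScore subfolder, SD3.5-Medium base), except that it samples with the stock diffusers pipeline rather than the patched pipeline_with_logprob.

```python
# Minimal local sketch: attach the Adv-GRPO PickScore LoRA to SD3.5-Medium and sample.
# Assumes the HF repo/subfolder layout that app.py downloads; output path is illustrative.
import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import snapshot_download
from peft import PeftModel

adapter_dir = snapshot_download("benzweijia/Adv-GRPO", allow_patterns=["PickScore/*"])
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer = PeftModel.from_pretrained(pipe.transformer, f"{adapter_dir}/PickScore")

# Steps and guidance mirror the eval settings used elsewhere in this repo.
image = pipe("a corgi wearing sunglasses", num_inference_steps=40, guidance_scale=4.5).images[0]
image.save("sample.png")
```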
diff --git a/config/base.py b/config/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d969e684edff4fa4403af44ab1c5f2f6f98a88fd
--- /dev/null
+++ b/config/base.py
@@ -0,0 +1,113 @@
+import ml_collections
+
+
+def get_config():
+    config = ml_collections.ConfigDict()
+
+    ###### General ######
+    # run name for wandb logging and checkpoint saving -- if not provided, will be auto-generated based on the datetime.
+    config.run_name = ""
+    # random seed for reproducibility.
+    config.seed = 42
+    # top-level logging directory for checkpoint saving.
+    config.logdir = "logs"
+    # number of epochs between saving model checkpoints.
+    config.save_freq = 20
+    # number of epochs between evaluating the model.
+    config.eval_freq = 20
+    # number of checkpoints to keep before overwriting old ones.
+    config.num_checkpoint_limit = 5
+    # mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly.
+    config.mixed_precision = "fp16"
+    # allow tf32 on Ampere GPUs, which can speed up training.
+    config.allow_tf32 = True
+    # whether or not to use LoRA.
+    config.use_lora = True
+    config.dataset = ""
+    config.resolution = 768
+
+    ###### Pretrained Model ######
+    config.pretrained = pretrained = ml_collections.ConfigDict()
+    # base model to load. either a path to a local directory, or a model name from the HuggingFace model hub.
+    pretrained.model = "runwayml/stable-diffusion-v1-5"
+    # revision of the model to load.
+    pretrained.revision = "main"
+
+    ###### Sampling ######
+    config.sample = sample = ml_collections.ConfigDict()
+    # number of sampler inference steps for collecting the dataset.
+    sample.num_steps = 40
+    # number of sampler inference steps for evaluation.
+    sample.eval_num_steps = 40
+    # classifier-free guidance weight. 1.0 is no guidance.
+    sample.guidance_scale = 4.5
+    # batch size (per GPU!) to use for sampling.
+    sample.train_batch_size = 1
+    sample.num_image_per_prompt = 1
+    sample.test_batch_size = 1
+    # number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch *
+    # batch_size * num_gpus`.
+    sample.num_batches_per_epoch = 2
+    # Whether to use all samples in a batch to compute the std
+    sample.global_std = True
+    # noise level
+    sample.noise_level = 0.7
+    # Whether to use the same noise for the same prompt
+    sample.same_latent = False
+
+    ###### Training ######
+    config.train = train = ml_collections.ConfigDict()
+    # batch size (per GPU!) to use for training.
+    train.batch_size = 1
+    # whether to use the 8bit Adam optimizer from bitsandbytes.
+    train.use_8bit_adam = False
+    # learning rate.
+    train.learning_rate = 3e-4
+    # Adam beta1.
+    train.adam_beta1 = 0.9
+    # Adam beta2.
+    train.adam_beta2 = 0.999
+    # Adam weight decay.
+    train.adam_weight_decay = 1e-4
+    # Adam epsilon.
+    train.adam_epsilon = 1e-8
+    # number of gradient accumulation steps. the effective batch size is `batch_size * num_gpus *
+    # gradient_accumulation_steps`.
+    train.gradient_accumulation_steps = 1
+    # maximum gradient norm for gradient clipping.
+    train.max_grad_norm = 1.0
+    # number of inner epochs per outer epoch. each inner epoch is one iteration through the data collected during one
+    # outer epoch's round of sampling.
+    train.num_inner_epochs = 1
+    # whether or not to use classifier-free guidance during training. if enabled, the same guidance scale used during
+    # sampling will be used during training.
+    train.cfg = True
+    # clip advantages to the range [-adv_clip_max, adv_clip_max].
+    train.adv_clip_max = 5
+    # the PPO clip range.
+    train.clip_range = 1e-4
+    # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
+    # timesteps for each sample. this will speed up training but reduce the accuracy of policy gradient estimates.
+    train.timestep_fraction = 1.0
+    # kl ratio
+    train.beta = 0.0
+    # pretrained lora path
+    train.lora_path = None
+    # save ema model
+    train.ema = False
+
+    ###### Prompt Function ######
+    # prompt function to use. see `prompts.py` for available prompt functions.
+    config.prompt_fn = "imagenet_animals"
+    # kwargs to pass to the prompt function.
+    config.prompt_fn_kwargs = {}
+
+    ###### Reward Function ######
+    # reward function to use. see `rewards.py` for available reward functions.
+    config.reward_fn = ml_collections.ConfigDict()
+    config.save_dir = ''
+
+    ###### Per-Prompt Stat Tracking ######
+    config.per_prompt_stat_tracking = True
+
+    return config
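base.py and the named variants below follow the ml_collections convention, so they are normally consumed through config flags (app.py already imports config_flags). The training entry point is not included in this commit; a typical absl-style loader would look like the sketch below, where the `:name` suffix selects one of the named functions via `get_config(name)`. The `train.py` wiring here is an assumption, and absl-py is an extra dependency not pinned in requirements.txt.

```python
# Hypothetical entry-point sketch: load a named config via ml_collections config_flags.
# Run as: python train.py --config config/grpo.py:pickscore_sd3_fast
from absl import app, flags
from ml_collections import config_flags

config_flags.DEFINE_config_file("config", "config/base.py")
FLAGS = flags.FLAGS


def main(argv):
    config = FLAGS.config
    print(config.sample.num_steps, dict(config.reward_fn))


if __name__ == "__main__":
    app.run(main)
```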
diff --git a/config/dpo.py b/config/dpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..96b2116a4d0ecb20808b7f540e6fdf37d5e96146
--- /dev/null
+++ b/config/dpo.py
@@ -0,0 +1,109 @@
+import ml_collections
+import imp
+import os
+
+base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
+
+def compressibility():
+    config = base.get_config()
+
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.use_lora = True
+
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 4
+
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 2
+
+    # prompting
+    config.prompt_fn = "general_ocr"
+
+    # rewards
+    config.reward_fn = {"jpeg_compressibility": 1}
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+def geneval_sd3():
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/geneval")
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 40
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    config.sample.train_batch_size = 24
+    config.sample.num_image_per_prompt = 24
+    config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 14  # Deliberate choice: the test set has 2212 samples, and gpu_num*bs*n should be as close to 2212 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+
+    config.train.algorithm = 'dpo'
+    # Change ref_update_step to a small number, e.g., 40, to switch to OnlineDPO.
+    config.train.ref_update_step = 10000000
+    config.train.batch_size = config.sample.train_batch_size
+    config.train.gradient_accumulation_steps = 1
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.beta = 100
+    config.sample.global_std = True
+    config.train.ema = True
+    config.save_freq = 40  # epoch
+    config.eval_freq = 40
+    config.save_dir = 'logs/geneval/sd3.5-M-dpo'
+    config.reward_fn = {
+        "geneval": 1.0,
+    }
+
+    config.prompt_fn = "geneval"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+def pickscore_sd3():
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 40
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    config.sample.train_batch_size = 24
+    config.sample.num_image_per_prompt = 24
+    config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+
+    config.train.algorithm = 'dpo'
+    # Change ref_update_step to a small number, e.g., 40, to switch to OnlineDPO.
+    config.train.ref_update_step = 10000000
+
+    config.train.batch_size = config.sample.train_batch_size
+    config.train.gradient_accumulation_steps = 1
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.beta = 100
+    config.sample.global_std = True
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.save_dir = 'logs/pickscore/sd3.5-M-dpo'
+    config.reward_fn = {
+        "pickscore": 1.0,
+    }
+
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+def get_config(name):
+    return globals()[name]()
diff --git a/config/grpo.py b/config/grpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ccc08d2332895b7faded12c4645c4a9167bb6c
--- /dev/null
+++ b/config/grpo.py
@@ -0,0 +1,434 @@
+import ml_collections
+import imp
+import os
+
+base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
+
+def compressibility():
+    config = base.get_config()
+
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.use_lora = True
+
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 4
+
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 2
+
+    # prompting
+    config.prompt_fn = "general_ocr"
+
+    # rewards
+    config.reward_fn = {"jpeg_compressibility": 1}
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+
+def dino_cotrain_sd3_fast():
+    gpu_number = 8
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.mixed_precision = "bf16"
+    config.wandb_init = True
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 10
+    config.sample.train_num_steps = 2
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    # The sampling train_batch_size is fixed to 1 here
+    config.sample.train_batch_size = 1
+    config.sample.num_image_per_prompt = 16
+    config.sample.mini_num_image_per_prompt = 8
+    # config.sample.mini_num_image_per_prompt = 4
+    config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+    # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+    config.sample.random_timestep = 0
+
+
+    config.train.batch_size = config.sample.mini_num_image_per_prompt
+    config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.clip_range = 1e-5
+    config.train.beta = 0
+    config.sample.global_std = True
+    config.sample.noise_level = 0.8
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.discriminator = "pickscore"
+    config.d_times = 10
+    config.d_lr = 1e-4
+    config.train.lora_path = None
+    config.tune_layer = -2
+
+    # config.use_lora = False
+    # config.train.learning_rate = 1e-5
+    # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+    config.train_d = True
+    config.weight_path = None
+    config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+    config.external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+    config.test_external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+    config.case_name = "fast_dino_cotrain_16_8_lr_times_10_1e4_new_loss_24_9_preprocess"
+    config.save_dir = 'logs/dino/sd3.5-M-fast_dino_cotrain_16_8_lr_times_10_1e4_new_loss_16_8_preprocess'
+    # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+    config.reward_fn = {
+        "dino_cotrain": 1,
+    }
+    config.eval_reward_fn = {
+        "pickscore": 1,
+        "image_similarity": 1
+    }
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
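The num_batches_per_epoch expression above looks opaque, but it holds the number of complete 16-image prompt groups sampled per epoch constant (reading the constant 48 as "prompt groups per epoch" is our interpretation), regardless of how a group is split into mini-batches. Working the arithmetic for this config:

```python
# Worked example of the sampling arithmetic in dino_cotrain_sd3_fast.
gpu_number = 8
num_image_per_prompt = 16        # GRPO group size
mini_num_image_per_prompt = 8    # images generated per sampling step per GPU

# Each step, all GPUs together produce 8*8 = 64 images = 4 complete groups of 16,
# so 48 / (8*8/16) = 12 sampling steps per epoch yield 48 prompt groups per epoch.
num_batches_per_epoch = int(48 / (gpu_number * mini_num_image_per_prompt / num_image_per_prompt))
print(num_batches_per_epoch)                                           # 12
print(num_batches_per_epoch * gpu_number * mini_num_image_per_prompt)  # 768 = 48 * 16 images
```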
+
+
+
+def dino_cotrain_sd3_patch_fast():
+    gpu_number = 8
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.mixed_precision = "bf16"
+    config.wandb_init = True
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 10
+    config.sample.train_num_steps = 2
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    # The sampling train_batch_size is fixed to 1 here
+    config.sample.train_batch_size = 1
+    config.sample.num_image_per_prompt = 16
+    config.sample.mini_num_image_per_prompt = 8
+    # config.sample.mini_num_image_per_prompt = 4
+    config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+    # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+    config.sample.random_timestep = 0
+
+
+    config.train.batch_size = config.sample.mini_num_image_per_prompt
+    config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.clip_range = 1e-5
+    config.train.beta = 0
+    config.sample.global_std = True
+    config.sample.noise_level = 0.8
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.discriminator = "pickscore"
+    config.d_times = 10
+    config.d_lr = 1e-4
+    config.train.lora_path = None
+    config.tune_layer = -2
+
+    # config.use_lora = False
+    # config.train.learning_rate = 1e-5
+    # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+    config.train_d = True
+    config.weight_path = None
+    config.limit = None
+    config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+    config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+    config.test_reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+    # config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_geneval.json"
+    # config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_geneval_multinode2"
+    # config.test_reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_ocr_test"
+
+    config.case_name = "fast_dino_cotrain_16_8_lr_times_10_1e4_patch_image_loss_73_again"
+    config.save_dir = 'logs/dino/sd3.5-M-fast_dino_cotrain_16_8_lr_times_10_1e4_patch_image_loss_73_again'
+    # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+    config.reward_fn = {
+        "dino_patch_cotrain": 1,
+    }
+    config.eval_reward_fn = {
+        "pickscore": 1,
+        "image_similarity": 1
+    }
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+def dino_cotrain_sd3_multi_fast():
+    gpu_number = 8
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.mixed_precision = "bf16"
+    config.wandb_init = False
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 10
+    config.sample.train_num_steps = 2
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    # The sampling train_batch_size is fixed to 1 here
+    config.sample.train_batch_size = 1
+    config.sample.num_image_per_prompt = 8
+    config.sample.mini_num_image_per_prompt = 8
+    # config.sample.mini_num_image_per_prompt = 4
+    config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+    # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+    config.sample.random_timestep = 0
+
+
+    config.train.batch_size = config.sample.mini_num_image_per_prompt
+    config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.clip_range = 1e-5
+    config.train.beta = 0.0
+    config.sample.global_std = True
+    config.sample.noise_level = 0.8
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.discriminator = "pickscore"
+    config.d_times = 10
+    config.d_lr = 1e-4
+    config.train.lora_path = None
+    config.tune_layer = (11,)
+    config.temperature = 2
+
+    # config.use_lora = False
+    # config.train.learning_rate = 1e-5
+    # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+    config.train_d = True
+    config.weight_path = None
+    config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+    config.external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+    config.test_external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+    config.case_name = "fast_dino_cotrain_16_8_lr_times_10_1e4_multi_image_loss_11_only_patch3_tem_2"
+    config.save_dir = 'logs/dino/sd3.5-M-fast_dino_cotrain_16_8_lr_times_10_1e4_multi_image_loss_11_only_patch3_tem_2'
+    # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+    config.reward_fn = {
+        "dino_multi_cotrain": 1,
+    }
+    config.eval_reward_fn = {
+        "pickscore": 1,
+        "image_similarity": 1
+    }
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+def eval_sd3_fast():
+    gpu_number = 8
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.mixed_precision = "bf16"
+    config.wandb_init = False
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 10
+    config.sample.train_num_steps = 2
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    # The sampling train_batch_size is fixed to 1 here
+    config.sample.train_batch_size = 1
+    config.sample.num_image_per_prompt = 8
+    config.sample.mini_num_image_per_prompt = 8
+    # config.sample.mini_num_image_per_prompt = 4
+    config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+    # config.sample.num_batches_per_epoch = 1
+    # config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+
+    config.sample.test_batch_size = 16
+    config.sample.repeat = 1
+
+    config.sample.random_timestep = 0
+
+
+    config.train.batch_size = config.sample.mini_num_image_per_prompt
+    config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.clip_range = 1e-5
+    config.train.beta = 0.0
+    config.sample.global_std = True
+    config.sample.noise_level = 0.8
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.discriminator = "pickscore"
+    config.d_times = 10
+    config.d_lr = 1e-4
+    config.tune_layer = -2
+
+    config.train.lora_path = ""
+    config.save_folder = "/mnt/bn/vgfm2/test_dit/weijia/outputs_flowgrpo_test2/sd3_dino_pickscore_test_1"
+    config.train_d = True
+    config.weight_path = None
+    config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+    config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+    config.test_reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_test"
+
+    config.reward_fn = {
+        "dino_cotrain": 1,
+    }
+    config.eval_reward_fn = {
+        "pickscore": 1,
+    }
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+
+def pickscore_cotrain_sd3_fast():
+    gpu_number = 8
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.mixed_precision = "bf16"
+    config.wandb_init = True
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 10
+    config.sample.train_num_steps = 2
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    # The sampling train_batch_size is fixed to 1 here
+    config.sample.train_batch_size = 1
+    config.sample.num_image_per_prompt = 16
+    config.sample.mini_num_image_per_prompt = 8
+    # config.sample.mini_num_image_per_prompt = 4
+    config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+    # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+    config.sample.random_timestep = 0
+
+    config.train.batch_size = config.sample.mini_num_image_per_prompt
+    config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.clip_range = 1e-5
+    config.train.beta = 0.0
+    config.sample.global_std = True
+    config.sample.noise_level = 0.8
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.discriminator = "pickscore"
+    config.d_times = 20
+    config.d_lr = 5e-6
+    config.train.lora_path = None
+    config.tune_layer = -1
+    # config.train.lora_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/logs/pickscore_again/sd3.5-M-fast_1node_8_8/checkpoints/checkpoint-1800/lora"
+
+    config.train_d = True
+    config.weight_path = None
+    config.json_path = "/mnt/bn/vgfm2/test_dit/weijia/flow_grpo/prompt2img_merged_pickscore.json"
+    config.reference_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images_pickscore_8_multinode"
+
+    config.case_name = "fast_pickscore_cotrain_lr_5e6_last1_16_8"
+    config.save_dir = 'logs/pickscore/sd3.5-M-fast_pickscore_cotrain_lr_5e6_last1_16_8'
+    # config.save_dir = 'logs/discriminator_again/sd3.5-M-fast_pickscore_16_8'
+    config.reward_fn = {
+        "pickscore_cotrain": 1,
+    }
+    config.eval_reward_fn = {
+        "pickscore": 1
+    }
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+def pickscore_sd3_fast():
+    gpu_number = 8
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/ocr")
+
+    config.mixed_precision = "bf16"
+    config.case_name = "fast_1node_16_8_multireward_11"
+    config.wandb_init = True
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 10
+    config.sample.train_num_steps = 2
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    # The sampling train_batch_size is fixed to 1 here
+    config.sample.train_batch_size = 1
+    config.sample.num_image_per_prompt = 16
+    config.sample.mini_num_image_per_prompt = 8
+    # config.sample.mini_num_image_per_prompt = 4
+    config.sample.num_batches_per_epoch = int(48/(gpu_number*config.sample.mini_num_image_per_prompt/config.sample.num_image_per_prompt))
+    # config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+    config.sample.random_timestep = None
+
+    config.train.batch_size = config.sample.mini_num_image_per_prompt
+    config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.clip_range = 1e-5
+    config.train.beta = 0.0
+    config.sample.global_std = True
+    config.sample.noise_level = 0.8
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.save_dir = 'logs/pickscore_again/sd3.5-M-fast_1node_16_8_multireward_11_ocr_pickscore'
+    config.external_image_path = "/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images"
+    config.reward_fn = {
+        "pickscore": 0.5,
+        "ocr": 0.5,
+    }
+
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+
+
+def get_config(name):
+    return globals()[name]()
+
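sft.py below mirrors dpo.py; the only material change is `config.train.algorithm = 'sft'`, which switches PerPromptStatTracker from the DPO +1/-1 preference pair to a one-hot advantage on the best sample of each group. The difference is easy to see directly with the class from adv_grpo/stat_tracking.py above:

```python
# Comparing the 'sft' and 'dpo' advantage shapes from stat_tracking.py.
from adv_grpo.stat_tracking import PerPromptStatTracker

prompts = ["p", "p", "p", "p"]
rewards = [0.1, 0.9, 0.4, 0.2]

print(PerPromptStatTracker().update(prompts, rewards, type="sft"))
# [0. 1. 0. 0.]   -> only the best sample in the group is imitated
print(PerPromptStatTracker().update(prompts, rewards, type="dpo"))
# [-1.  1.  0.  0.] -> best-vs-worst preference pair
```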
diff --git a/config/sft.py b/config/sft.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fefd840b84ad156f9e611b598acc5beb543c681
--- /dev/null
+++ b/config/sft.py
@@ -0,0 +1,109 @@
+import ml_collections
+import imp
+import os
+
+base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
+
+def compressibility():
+    config = base.get_config()
+
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    config.use_lora = True
+
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 4
+
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 2
+
+    # prompting
+    config.prompt_fn = "general_ocr"
+
+    # rewards
+    config.reward_fn = {"jpeg_compressibility": 1}
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+def geneval_sd3():
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/geneval")
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 40
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    config.sample.train_batch_size = 24
+    config.sample.num_image_per_prompt = 24
+    config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 14  # Deliberate choice: the test set has 2212 samples, and gpu_num*bs*n should be as close to 2212 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+
+    config.train.algorithm = 'sft'
+    # Change ref_update_step to a small number, e.g., 40, to switch to OnlineSFT.
+    config.train.ref_update_step = 10000000
+    config.train.batch_size = config.sample.train_batch_size
+    config.train.gradient_accumulation_steps = 1
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.beta = 100
+    config.sample.global_std = True
+    config.train.ema = True
+    config.save_freq = 40  # epoch
+    config.eval_freq = 40
+    config.save_dir = 'logs/geneval/sd3.5-M-sft'
+    config.reward_fn = {
+        "geneval": 1.0,
+    }
+
+    config.prompt_fn = "geneval"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+def pickscore_sd3():
+    config = compressibility()
+    config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")
+
+    # sd3.5 medium
+    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
+    config.sample.num_steps = 40
+    config.sample.eval_num_steps = 40
+    config.sample.guidance_scale = 4.5
+
+    config.resolution = 512
+    config.sample.train_batch_size = 24
+    config.sample.num_image_per_prompt = 24
+    config.sample.num_batches_per_epoch = 1
+    config.sample.test_batch_size = 16  # Deliberate choice: the test set has 2048 samples, and gpu_num*bs*n should be as close to 2048 as possible. When the sample count is not divisible by the number of GPUs, the last batch is padded so every GPU gets the same number of samples, which affects gradient synchronization.
+
+    config.train.algorithm = 'sft'
+    # Change ref_update_step to a small number, e.g., 40, to switch to OnlineSFT.
+    config.train.ref_update_step = 10000000
+
+    config.train.batch_size = config.sample.train_batch_size
+    config.train.gradient_accumulation_steps = config.sample.num_batches_per_epoch//2
+    config.train.num_inner_epochs = 1
+    config.train.timestep_fraction = 0.99
+    config.train.beta = 100
+    config.sample.global_std = True
+    config.train.ema = True
+    config.save_freq = 60  # epoch
+    config.eval_freq = 60
+    config.save_dir = 'logs/pickscore/sd3.5-M-sft'
+    config.reward_fn = {
+        "pickscore": 1.0,
+    }
+
+    config.prompt_fn = "general_ocr"
+
+    config.per_prompt_stat_tracking = True
+    return config
+
+
+def get_config(name):
+    return globals()[name]()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1cab484b32586f8bb0b14e572bf5ccd09d08a694
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+torch==2.6.0
+torchvision==0.21.0
+transformers==4.54.0
+diffusers==0.33.1
+accelerate>=0.25.0
+peft>=0.6.2
+safetensors
+numpy==1.26.4
+Pillow
+gradio>=4.12.0
+spaces>=0.24.0
+tqdm
+ml_collections
+huggingface-hub
+sentencepiece
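With the pinned requirements installed, `python app.py` serves the demo locally on the default Gradio port. If the optional gradio_client package is also installed (it is not pinned in requirements.txt), the running demo can be queried programmatically; a hedged sketch, assuming gr.Interface's default endpoint name:

```python
# Hypothetical smoke test against a locally running app.py (requires `gradio_client`,
# which is an extra dependency). gr.Interface exposes a single "/predict" endpoint,
# and image outputs come back as a local file path.
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
image_path = client.predict("a corgi wearing sunglasses", api_name="/predict")
print("Saved generated image to:", image_path)
```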