diff --git a/ProMoE-B-256/__pycache__/pipeline.cpython-312.pyc b/ProMoE-B-256/__pycache__/pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5074d201d49a47c35ce0b2a15a1083065195e8e7 Binary files /dev/null and b/ProMoE-B-256/__pycache__/pipeline.cpython-312.pyc differ diff --git a/ProMoE-B-256/model_index.json b/ProMoE-B-256/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..5700d1111b22b2a886e2181c653d7c99291527ca --- /dev/null +++ b/ProMoE-B-256/model_index.json @@ -0,0 +1,1021 @@ +{ + "_class_name": [ + "pipeline", + "ProMoEPipeline" + ], + "_diffusers_version": "0.36.0", + "scheduler": [ + "scheduling_flow_match_promoe", + "ProMoEFlowMatchScheduler" + ], + "transformer": [ + "transformer_promoe", + "ProMoETransformer2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ], + "id2label": { + "0": "tench, Tinca tinca", + "1": "goldfish, Carassius auratus", + "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "3": "tiger shark, Galeocerdo cuvieri", + "4": "hammerhead, hammerhead shark", + "5": "electric ray, crampfish, numbfish, torpedo", + "6": "stingray", + "7": "cock", + "8": "hen", + "9": "ostrich, Struthio camelus", + "10": "brambling, Fringilla montifringilla", + "11": "goldfinch, Carduelis carduelis", + "12": "house finch, linnet, Carpodacus mexicanus", + "13": "junco, snowbird", + "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "15": "robin, American robin, Turdus migratorius", + "16": "bulbul", + "17": "jay", + "18": "magpie", + "19": "chickadee", + "20": "water ouzel, dipper", + "21": "kite", + "22": "bald eagle, American eagle, Haliaeetus leucocephalus", + "23": "vulture", + "24": "great grey owl, great gray owl, Strix nebulosa", + "25": "European fire salamander, Salamandra salamandra", + "26": "common newt, Triturus vulgaris", + "27": "eft", + "28": "spotted salamander, Ambystoma maculatum", + "29": 
"axolotl, mud puppy, Ambystoma mexicanum", + "30": "bullfrog, Rana catesbeiana", + "31": "tree frog, tree-frog", + "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "33": "loggerhead, loggerhead turtle, Caretta caretta", + "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "35": "mud turtle", + "36": "terrapin", + "37": "box turtle, box tortoise", + "38": "banded gecko", + "39": "common iguana, iguana, Iguana iguana", + "40": "American chameleon, anole, Anolis carolinensis", + "41": "whiptail, whiptail lizard", + "42": "agama", + "43": "frilled lizard, Chlamydosaurus kingi", + "44": "alligator lizard", + "45": "Gila monster, Heloderma suspectum", + "46": "green lizard, Lacerta viridis", + "47": "African chameleon, Chamaeleo chamaeleon", + "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "49": "African crocodile, Nile crocodile, Crocodylus niloticus", + "50": "American alligator, Alligator mississipiensis", + "51": "triceratops", + "52": "thunder snake, worm snake, Carphophis amoenus", + "53": "ringneck snake, ring-necked snake, ring snake", + "54": "hognose snake, puff adder, sand viper", + "55": "green snake, grass snake", + "56": "king snake, kingsnake", + "57": "garter snake, grass snake", + "58": "water snake", + "59": "vine snake", + "60": "night snake, Hypsiglena torquata", + "61": "boa constrictor, Constrictor constrictor", + "62": "rock python, rock snake, Python sebae", + "63": "Indian cobra, Naja naja", + "64": "green mamba", + "65": "sea snake", + "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "68": "sidewinder, horned rattlesnake, Crotalus cerastes", + "69": "trilobite", + "70": "harvestman, daddy longlegs, Phalangium opilio", + "71": "scorpion", + "72": "black and gold garden spider, Argiope aurantia", + "73": "barn spider, Araneus cavaticus", + "74": "garden 
spider, Aranea diademata", + "75": "black widow, Latrodectus mactans", + "76": "tarantula", + "77": "wolf spider, hunting spider", + "78": "tick", + "79": "centipede", + "80": "black grouse", + "81": "ptarmigan", + "82": "ruffed grouse, partridge, Bonasa umbellus", + "83": "prairie chicken, prairie grouse, prairie fowl", + "84": "peacock", + "85": "quail", + "86": "partridge", + "87": "African grey, African gray, Psittacus erithacus", + "88": "macaw", + "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "90": "lorikeet", + "91": "coucal", + "92": "bee eater", + "93": "hornbill", + "94": "hummingbird", + "95": "jacamar", + "96": "toucan", + "97": "drake", + "98": "red-breasted merganser, Mergus serrator", + "99": "goose", + "100": "black swan, Cygnus atratus", + "101": "tusker", + "102": "echidna, spiny anteater, anteater", + "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "104": "wallaby, brush kangaroo", + "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "106": "wombat", + "107": "jellyfish", + "108": "sea anemone, anemone", + "109": "brain coral", + "110": "flatworm, platyhelminth", + "111": "nematode, nematode worm, roundworm", + "112": "conch", + "113": "snail", + "114": "slug", + "115": "sea slug, nudibranch", + "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "117": "chambered nautilus, pearly nautilus, nautilus", + "118": "Dungeness crab, Cancer magister", + "119": "rock crab, Cancer irroratus", + "120": "fiddler crab", + "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "124": "crayfish, crawfish, crawdad, crawdaddy", + "125": "hermit crab", + "126": "isopod", + "127": "white stork, Ciconia ciconia", + "128": "black stork, 
Ciconia nigra", + "129": "spoonbill", + "130": "flamingo", + "131": "little blue heron, Egretta caerulea", + "132": "American egret, great white heron, Egretta albus", + "133": "bittern", + "134": "crane", + "135": "limpkin, Aramus pictus", + "136": "European gallinule, Porphyrio porphyrio", + "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", + "138": "bustard", + "139": "ruddy turnstone, Arenaria interpres", + "140": "red-backed sandpiper, dunlin, Erolia alpina", + "141": "redshank, Tringa totanus", + "142": "dowitcher", + "143": "oystercatcher, oyster catcher", + "144": "pelican", + "145": "king penguin, Aptenodytes patagonica", + "146": "albatross, mollymawk", + "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "149": "dugong, Dugong dugon", + "150": "sea lion", + "151": "Chihuahua", + "152": "Japanese spaniel", + "153": "Maltese dog, Maltese terrier, Maltese", + "154": "Pekinese, Pekingese, Peke", + "155": "Shih-Tzu", + "156": "Blenheim spaniel", + "157": "papillon", + "158": "toy terrier", + "159": "Rhodesian ridgeback", + "160": "Afghan hound, Afghan", + "161": "basset, basset hound", + "162": "beagle", + "163": "bloodhound, sleuthhound", + "164": "bluetick", + "165": "black-and-tan coonhound", + "166": "Walker hound, Walker foxhound", + "167": "English foxhound", + "168": "redbone", + "169": "borzoi, Russian wolfhound", + "170": "Irish wolfhound", + "171": "Italian greyhound", + "172": "whippet", + "173": "Ibizan hound, Ibizan Podenco", + "174": "Norwegian elkhound, elkhound", + "175": "otterhound, otter hound", + "176": "Saluki, gazelle hound", + "177": "Scottish deerhound, deerhound", + "178": "Weimaraner", + "179": "Staffordshire bullterrier, Staffordshire bull terrier", + "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "181": "Bedlington terrier", + "182": 
"Border terrier", + "183": "Kerry blue terrier", + "184": "Irish terrier", + "185": "Norfolk terrier", + "186": "Norwich terrier", + "187": "Yorkshire terrier", + "188": "wire-haired fox terrier", + "189": "Lakeland terrier", + "190": "Sealyham terrier, Sealyham", + "191": "Airedale, Airedale terrier", + "192": "cairn, cairn terrier", + "193": "Australian terrier", + "194": "Dandie Dinmont, Dandie Dinmont terrier", + "195": "Boston bull, Boston terrier", + "196": "miniature schnauzer", + "197": "giant schnauzer", + "198": "standard schnauzer", + "199": "Scotch terrier, Scottish terrier, Scottie", + "200": "Tibetan terrier, chrysanthemum dog", + "201": "silky terrier, Sydney silky", + "202": "soft-coated wheaten terrier", + "203": "West Highland white terrier", + "204": "Lhasa, Lhasa apso", + "205": "flat-coated retriever", + "206": "curly-coated retriever", + "207": "golden retriever", + "208": "Labrador retriever", + "209": "Chesapeake Bay retriever", + "210": "German short-haired pointer", + "211": "vizsla, Hungarian pointer", + "212": "English setter", + "213": "Irish setter, red setter", + "214": "Gordon setter", + "215": "Brittany spaniel", + "216": "clumber, clumber spaniel", + "217": "English springer, English springer spaniel", + "218": "Welsh springer spaniel", + "219": "cocker spaniel, English cocker spaniel, cocker", + "220": "Sussex spaniel", + "221": "Irish water spaniel", + "222": "kuvasz", + "223": "schipperke", + "224": "groenendael", + "225": "malinois", + "226": "briard", + "227": "kelpie", + "228": "komondor", + "229": "Old English sheepdog, bobtail", + "230": "Shetland sheepdog, Shetland sheep dog, Shetland", + "231": "collie", + "232": "Border collie", + "233": "Bouvier des Flandres, Bouviers des Flandres", + "234": "Rottweiler", + "235": "German shepherd, German shepherd dog, German police dog, alsatian", + "236": "Doberman, Doberman pinscher", + "237": "miniature pinscher", + "238": "Greater Swiss Mountain dog", + "239": "Bernese mountain 
dog", + "240": "Appenzeller", + "241": "EntleBucher", + "242": "boxer", + "243": "bull mastiff", + "244": "Tibetan mastiff", + "245": "French bulldog", + "246": "Great Dane", + "247": "Saint Bernard, St Bernard", + "248": "Eskimo dog, husky", + "249": "malamute, malemute, Alaskan malamute", + "250": "Siberian husky", + "251": "dalmatian, coach dog, carriage dog", + "252": "affenpinscher, monkey pinscher, monkey dog", + "253": "basenji", + "254": "pug, pug-dog", + "255": "Leonberg", + "256": "Newfoundland, Newfoundland dog", + "257": "Great Pyrenees", + "258": "Samoyed, Samoyede", + "259": "Pomeranian", + "260": "chow, chow chow", + "261": "keeshond", + "262": "Brabancon griffon", + "263": "Pembroke, Pembroke Welsh corgi", + "264": "Cardigan, Cardigan Welsh corgi", + "265": "toy poodle", + "266": "miniature poodle", + "267": "standard poodle", + "268": "Mexican hairless", + "269": "timber wolf, grey wolf, gray wolf, Canis lupus", + "270": "white wolf, Arctic wolf, Canis lupus tundrarum", + "271": "red wolf, maned wolf, Canis rufus, Canis niger", + "272": "coyote, prairie wolf, brush wolf, Canis latrans", + "273": "dingo, warrigal, warragal, Canis dingo", + "274": "dhole, Cuon alpinus", + "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "276": "hyena, hyaena", + "277": "red fox, Vulpes vulpes", + "278": "kit fox, Vulpes macrotis", + "279": "Arctic fox, white fox, Alopex lagopus", + "280": "grey fox, gray fox, Urocyon cinereoargenteus", + "281": "tabby, tabby cat", + "282": "tiger cat", + "283": "Persian cat", + "284": "Siamese cat, Siamese", + "285": "Egyptian cat", + "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "287": "lynx, catamount", + "288": "leopard, Panthera pardus", + "289": "snow leopard, ounce, Panthera uncia", + "290": "jaguar, panther, Panthera onca, Felis onca", + "291": "lion, king of beasts, Panthera leo", + "292": "tiger, Panthera tigris", + "293": "cheetah, chetah, Acinonyx jubatus", 
+ "294": "brown bear, bruin, Ursus arctos", + "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", + "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "297": "sloth bear, Melursus ursinus, Ursus ursinus", + "298": "mongoose", + "299": "meerkat, mierkat", + "300": "tiger beetle", + "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "302": "ground beetle, carabid beetle", + "303": "long-horned beetle, longicorn, longicorn beetle", + "304": "leaf beetle, chrysomelid", + "305": "dung beetle", + "306": "rhinoceros beetle", + "307": "weevil", + "308": "fly", + "309": "bee", + "310": "ant, emmet, pismire", + "311": "grasshopper, hopper", + "312": "cricket", + "313": "walking stick, walkingstick, stick insect", + "314": "cockroach, roach", + "315": "mantis, mantid", + "316": "cicada, cicala", + "317": "leafhopper", + "318": "lacewing, lacewing fly", + "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "320": "damselfly", + "321": "admiral", + "322": "ringlet, ringlet butterfly", + "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "324": "cabbage butterfly", + "325": "sulphur butterfly, sulfur butterfly", + "326": "lycaenid, lycaenid butterfly", + "327": "starfish, sea star", + "328": "sea urchin", + "329": "sea cucumber, holothurian", + "330": "wood rabbit, cottontail, cottontail rabbit", + "331": "hare", + "332": "Angora, Angora rabbit", + "333": "hamster", + "334": "porcupine, hedgehog", + "335": "fox squirrel, eastern fox squirrel, Sciurus niger", + "336": "marmot", + "337": "beaver", + "338": "guinea pig, Cavia cobaya", + "339": "sorrel", + "340": "zebra", + "341": "hog, pig, grunter, squealer, Sus scrofa", + "342": "wild boar, boar, Sus scrofa", + "343": "warthog", + "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "345": "ox", + "346": "water buffalo, water ox, Asiatic 
buffalo, Bubalus bubalis", + "347": "bison", + "348": "ram, tup", + "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "350": "ibex, Capra ibex", + "351": "hartebeest", + "352": "impala, Aepyceros melampus", + "353": "gazelle", + "354": "Arabian camel, dromedary, Camelus dromedarius", + "355": "llama", + "356": "weasel", + "357": "mink", + "358": "polecat, fitch, foulmart, foumart, Mustela putorius", + "359": "black-footed ferret, ferret, Mustela nigripes", + "360": "otter", + "361": "skunk, polecat, wood pussy", + "362": "badger", + "363": "armadillo", + "364": "three-toed sloth, ai, Bradypus tridactylus", + "365": "orangutan, orang, orangutang, Pongo pygmaeus", + "366": "gorilla, Gorilla gorilla", + "367": "chimpanzee, chimp, Pan troglodytes", + "368": "gibbon, Hylobates lar", + "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "370": "guenon, guenon monkey", + "371": "patas, hussar monkey, Erythrocebus patas", + "372": "baboon", + "373": "macaque", + "374": "langur", + "375": "colobus, colobus monkey", + "376": "proboscis monkey, Nasalis larvatus", + "377": "marmoset", + "378": "capuchin, ringtail, Cebus capucinus", + "379": "howler monkey, howler", + "380": "titi, titi monkey", + "381": "spider monkey, Ateles geoffroyi", + "382": "squirrel monkey, Saimiri sciureus", + "383": "Madagascar cat, ring-tailed lemur, Lemur catta", + "384": "indri, indris, Indri indri, Indri brevicaudatus", + "385": "Indian elephant, Elephas maximus", + "386": "African elephant, Loxodonta africana", + "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "389": "barracouta, snoek", + "390": "eel", + "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "392": "rock beauty, Holocanthus tricolor", + "393": "anemone fish", + "394": "sturgeon", + "395": "gar, garfish, garpike, 
billfish, Lepisosteus osseus", + "396": "lionfish", + "397": "puffer, pufferfish, blowfish, globefish", + "398": "abacus", + "399": "abaya", + "400": "academic gown, academic robe, judge robe", + "401": "accordion, piano accordion, squeeze box", + "402": "acoustic guitar", + "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "404": "airliner", + "405": "airship, dirigible", + "406": "altar", + "407": "ambulance", + "408": "amphibian, amphibious vehicle", + "409": "analog clock", + "410": "apiary, bee house", + "411": "apron", + "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "413": "assault rifle, assault gun", + "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "415": "bakery, bakeshop, bakehouse", + "416": "balance beam, beam", + "417": "balloon", + "418": "ballpoint, ballpoint pen, ballpen, Biro", + "419": "Band Aid", + "420": "banjo", + "421": "bannister, banister, balustrade, balusters, handrail", + "422": "barbell", + "423": "barber chair", + "424": "barbershop", + "425": "barn", + "426": "barometer", + "427": "barrel, cask", + "428": "barrow, garden cart, lawn cart, wheelbarrow", + "429": "baseball", + "430": "basketball", + "431": "bassinet", + "432": "bassoon", + "433": "bathing cap, swimming cap", + "434": "bath towel", + "435": "bathtub, bathing tub, bath, tub", + "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "437": "beacon, lighthouse, beacon light, pharos", + "438": "beaker", + "439": "bearskin, busby, shako", + "440": "beer bottle", + "441": "beer glass", + "442": "bell cote, bell cot", + "443": "bib", + "444": "bicycle-built-for-two, tandem bicycle, tandem", + "445": "bikini, two-piece", + "446": "binder, ring-binder", + "447": "binoculars, field glasses, opera glasses", + "448": "birdhouse", + "449": "boathouse", + "450": "bobsled, bobsleigh, bob", + "451": "bolo tie, bolo, bola tie, bola", + 
"452": "bonnet, poke bonnet", + "453": "bookcase", + "454": "bookshop, bookstore, bookstall", + "455": "bottlecap", + "456": "bow", + "457": "bow tie, bow-tie, bowtie", + "458": "brass, memorial tablet, plaque", + "459": "brassiere, bra, bandeau", + "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "461": "breastplate, aegis, egis", + "462": "broom", + "463": "bucket, pail", + "464": "buckle", + "465": "bulletproof vest", + "466": "bullet train, bullet", + "467": "butcher shop, meat market", + "468": "cab, hack, taxi, taxicab", + "469": "caldron, cauldron", + "470": "candle, taper, wax light", + "471": "cannon", + "472": "canoe", + "473": "can opener, tin opener", + "474": "cardigan", + "475": "car mirror", + "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", + "477": "carpenters kit, tool kit", + "478": "carton", + "479": "car wheel", + "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "481": "cassette", + "482": "cassette player", + "483": "castle", + "484": "catamaran", + "485": "CD player", + "486": "cello, violoncello", + "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "488": "chain", + "489": "chainlink fence", + "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "491": "chain saw, chainsaw", + "492": "chest", + "493": "chiffonier, commode", + "494": "chime, bell, gong", + "495": "china cabinet, china closet", + "496": "Christmas stocking", + "497": "church, church building", + "498": "cinema, movie theater, movie theatre, movie house, picture palace", + "499": "cleaver, meat cleaver, chopper", + "500": "cliff dwelling", + "501": "cloak", + "502": "clog, geta, patten, sabot", + "503": "cocktail shaker", + "504": "coffee mug", + "505": "coffeepot", + "506": "coil, spiral, volute, whorl, helix", + "507": "combination lock", + "508": "computer keyboard, keypad", + "509": 
"confectionery, confectionary, candy store", + "510": "container ship, containership, container vessel", + "511": "convertible", + "512": "corkscrew, bottle screw", + "513": "cornet, horn, trumpet, trump", + "514": "cowboy boot", + "515": "cowboy hat, ten-gallon hat", + "516": "cradle", + "517": "crane", + "518": "crash helmet", + "519": "crate", + "520": "crib, cot", + "521": "Crock Pot", + "522": "croquet ball", + "523": "crutch", + "524": "cuirass", + "525": "dam, dike, dyke", + "526": "desk", + "527": "desktop computer", + "528": "dial telephone, dial phone", + "529": "diaper, nappy, napkin", + "530": "digital clock", + "531": "digital watch", + "532": "dining table, board", + "533": "dishrag, dishcloth", + "534": "dishwasher, dish washer, dishwashing machine", + "535": "disk brake, disc brake", + "536": "dock, dockage, docking facility", + "537": "dogsled, dog sled, dog sleigh", + "538": "dome", + "539": "doormat, welcome mat", + "540": "drilling platform, offshore rig", + "541": "drum, membranophone, tympan", + "542": "drumstick", + "543": "dumbbell", + "544": "Dutch oven", + "545": "electric fan, blower", + "546": "electric guitar", + "547": "electric locomotive", + "548": "entertainment center", + "549": "envelope", + "550": "espresso maker", + "551": "face powder", + "552": "feather boa, boa", + "553": "file, file cabinet, filing cabinet", + "554": "fireboat", + "555": "fire engine, fire truck", + "556": "fire screen, fireguard", + "557": "flagpole, flagstaff", + "558": "flute, transverse flute", + "559": "folding chair", + "560": "football helmet", + "561": "forklift", + "562": "fountain", + "563": "fountain pen", + "564": "four-poster", + "565": "freight car", + "566": "French horn, horn", + "567": "frying pan, frypan, skillet", + "568": "fur coat", + "569": "garbage truck, dustcart", + "570": "gasmask, respirator, gas helmet", + "571": "gas pump, gasoline pump, petrol pump, island dispenser", + "572": "goblet", + "573": "go-kart", + "574": "golf ball", 
+ "575": "golfcart, golf cart", + "576": "gondola", + "577": "gong, tam-tam", + "578": "gown", + "579": "grand piano, grand", + "580": "greenhouse, nursery, glasshouse", + "581": "grille, radiator grille", + "582": "grocery store, grocery, food market, market", + "583": "guillotine", + "584": "hair slide", + "585": "hair spray", + "586": "half track", + "587": "hammer", + "588": "hamper", + "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "590": "hand-held computer, hand-held microcomputer", + "591": "handkerchief, hankie, hanky, hankey", + "592": "hard disc, hard disk, fixed disk", + "593": "harmonica, mouth organ, harp, mouth harp", + "594": "harp", + "595": "harvester, reaper", + "596": "hatchet", + "597": "holster", + "598": "home theater, home theatre", + "599": "honeycomb", + "600": "hook, claw", + "601": "hoopskirt, crinoline", + "602": "horizontal bar, high bar", + "603": "horse cart, horse-cart", + "604": "hourglass", + "605": "iPod", + "606": "iron, smoothing iron", + "607": "jack-o-lantern", + "608": "jean, blue jean, denim", + "609": "jeep, landrover", + "610": "jersey, T-shirt, tee shirt", + "611": "jigsaw puzzle", + "612": "jinrikisha, ricksha, rickshaw", + "613": "joystick", + "614": "kimono", + "615": "knee pad", + "616": "knot", + "617": "lab coat, laboratory coat", + "618": "ladle", + "619": "lampshade, lamp shade", + "620": "laptop, laptop computer", + "621": "lawn mower, mower", + "622": "lens cap, lens cover", + "623": "letter opener, paper knife, paperknife", + "624": "library", + "625": "lifeboat", + "626": "lighter, light, igniter, ignitor", + "627": "limousine, limo", + "628": "liner, ocean liner", + "629": "lipstick, lip rouge", + "630": "Loafer", + "631": "lotion", + "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "633": "loupe, jewelers loupe", + "634": "lumbermill, sawmill", + "635": "magnetic compass", + "636": "mailbag, postbag", + "637": "mailbox, letter box", + "638": 
"maillot", + "639": "maillot, tank suit", + "640": "manhole cover", + "641": "maraca", + "642": "marimba, xylophone", + "643": "mask", + "644": "matchstick", + "645": "maypole", + "646": "maze, labyrinth", + "647": "measuring cup", + "648": "medicine chest, medicine cabinet", + "649": "megalith, megalithic structure", + "650": "microphone, mike", + "651": "microwave, microwave oven", + "652": "military uniform", + "653": "milk can", + "654": "minibus", + "655": "miniskirt, mini", + "656": "minivan", + "657": "missile", + "658": "mitten", + "659": "mixing bowl", + "660": "mobile home, manufactured home", + "661": "Model T", + "662": "modem", + "663": "monastery", + "664": "monitor", + "665": "moped", + "666": "mortar", + "667": "mortarboard", + "668": "mosque", + "669": "mosquito net", + "670": "motor scooter, scooter", + "671": "mountain bike, all-terrain bike, off-roader", + "672": "mountain tent", + "673": "mouse, computer mouse", + "674": "mousetrap", + "675": "moving van", + "676": "muzzle", + "677": "nail", + "678": "neck brace", + "679": "necklace", + "680": "nipple", + "681": "notebook, notebook computer", + "682": "obelisk", + "683": "oboe, hautboy, hautbois", + "684": "ocarina, sweet potato", + "685": "odometer, hodometer, mileometer, milometer", + "686": "oil filter", + "687": "organ, pipe organ", + "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "689": "overskirt", + "690": "oxcart", + "691": "oxygen mask", + "692": "packet", + "693": "paddle, boat paddle", + "694": "paddlewheel, paddle wheel", + "695": "padlock", + "696": "paintbrush", + "697": "pajama, pyjama, pjs, jammies", + "698": "palace", + "699": "panpipe, pandean pipe, syrinx", + "700": "paper towel", + "701": "parachute, chute", + "702": "parallel bars, bars", + "703": "park bench", + "704": "parking meter", + "705": "passenger car, coach, carriage", + "706": "patio, terrace", + "707": "pay-phone, pay-station", + "708": "pedestal, plinth, footstall", + "709": "pencil box, pencil 
case", + "710": "pencil sharpener", + "711": "perfume, essence", + "712": "Petri dish", + "713": "photocopier", + "714": "pick, plectrum, plectron", + "715": "pickelhaube", + "716": "picket fence, paling", + "717": "pickup, pickup truck", + "718": "pier", + "719": "piggy bank, penny bank", + "720": "pill bottle", + "721": "pillow", + "722": "ping-pong ball", + "723": "pinwheel", + "724": "pirate, pirate ship", + "725": "pitcher, ewer", + "726": "plane, carpenters plane, woodworking plane", + "727": "planetarium", + "728": "plastic bag", + "729": "plate rack", + "730": "plow, plough", + "731": "plunger, plumbers helper", + "732": "Polaroid camera, Polaroid Land camera", + "733": "pole", + "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "735": "poncho", + "736": "pool table, billiard table, snooker table", + "737": "pop bottle, soda bottle", + "738": "pot, flowerpot", + "739": "potters wheel", + "740": "power drill", + "741": "prayer rug, prayer mat", + "742": "printer", + "743": "prison, prison house", + "744": "projectile, missile", + "745": "projector", + "746": "puck, hockey puck", + "747": "punching bag, punch bag, punching ball, punchball", + "748": "purse", + "749": "quill, quill pen", + "750": "quilt, comforter, comfort, puff", + "751": "racer, race car, racing car", + "752": "racket, racquet", + "753": "radiator", + "754": "radio, wireless", + "755": "radio telescope, radio reflector", + "756": "rain barrel", + "757": "recreational vehicle, RV, R.V.", + "758": "reel", + "759": "reflex camera", + "760": "refrigerator, icebox", + "761": "remote control, remote", + "762": "restaurant, eating house, eating place, eatery", + "763": "revolver, six-gun, six-shooter", + "764": "rifle", + "765": "rocking chair, rocker", + "766": "rotisserie", + "767": "rubber eraser, rubber, pencil eraser", + "768": "rugby ball", + "769": "rule, ruler", + "770": "running shoe", + "771": "safe", + "772": "safety pin", + "773": "saltshaker, salt 
shaker", + "774": "sandal", + "775": "sarong", + "776": "sax, saxophone", + "777": "scabbard", + "778": "scale, weighing machine", + "779": "school bus", + "780": "schooner", + "781": "scoreboard", + "782": "screen, CRT screen", + "783": "screw", + "784": "screwdriver", + "785": "seat belt, seatbelt", + "786": "sewing machine", + "787": "shield, buckler", + "788": "shoe shop, shoe-shop, shoe store", + "789": "shoji", + "790": "shopping basket", + "791": "shopping cart", + "792": "shovel", + "793": "shower cap", + "794": "shower curtain", + "795": "ski", + "796": "ski mask", + "797": "sleeping bag", + "798": "slide rule, slipstick", + "799": "sliding door", + "800": "slot, one-armed bandit", + "801": "snorkel", + "802": "snowmobile", + "803": "snowplow, snowplough", + "804": "soap dispenser", + "805": "soccer ball", + "806": "sock", + "807": "solar dish, solar collector, solar furnace", + "808": "sombrero", + "809": "soup bowl", + "810": "space bar", + "811": "space heater", + "812": "space shuttle", + "813": "spatula", + "814": "speedboat", + "815": "spider web, spiders web", + "816": "spindle", + "817": "sports car, sport car", + "818": "spotlight, spot", + "819": "stage", + "820": "steam locomotive", + "821": "steel arch bridge", + "822": "steel drum", + "823": "stethoscope", + "824": "stole", + "825": "stone wall", + "826": "stopwatch, stop watch", + "827": "stove", + "828": "strainer", + "829": "streetcar, tram, tramcar, trolley, trolley car", + "830": "stretcher", + "831": "studio couch, day bed", + "832": "stupa, tope", + "833": "submarine, pigboat, sub, U-boat", + "834": "suit, suit of clothes", + "835": "sundial", + "836": "sunglass", + "837": "sunglasses, dark glasses, shades", + "838": "sunscreen, sunblock, sun blocker", + "839": "suspension bridge", + "840": "swab, swob, mop", + "841": "sweatshirt", + "842": "swimming trunks, bathing trunks", + "843": "swing", + "844": "switch, electric switch, electrical switch", + "845": "syringe", + "846": "table 
lamp", + "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", + "848": "tape player", + "849": "teapot", + "850": "teddy, teddy bear", + "851": "television, television system", + "852": "tennis ball", + "853": "thatch, thatched roof", + "854": "theater curtain, theatre curtain", + "855": "thimble", + "856": "thresher, thrasher, threshing machine", + "857": "throne", + "858": "tile roof", + "859": "toaster", + "860": "tobacco shop, tobacconist shop, tobacconist", + "861": "toilet seat", + "862": "torch", + "863": "totem pole", + "864": "tow truck, tow car, wrecker", + "865": "toyshop", + "866": "tractor", + "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "868": "tray", + "869": "trench coat", + "870": "tricycle, trike, velocipede", + "871": "trimaran", + "872": "tripod", + "873": "triumphal arch", + "874": "trolleybus, trolley coach, trackless trolley", + "875": "trombone", + "876": "tub, vat", + "877": "turnstile", + "878": "typewriter keyboard", + "879": "umbrella", + "880": "unicycle, monocycle", + "881": "upright, upright piano", + "882": "vacuum, vacuum cleaner", + "883": "vase", + "884": "vault", + "885": "velvet", + "886": "vending machine", + "887": "vestment", + "888": "viaduct", + "889": "violin, fiddle", + "890": "volleyball", + "891": "waffle iron", + "892": "wall clock", + "893": "wallet, billfold, notecase, pocketbook", + "894": "wardrobe, closet, press", + "895": "warplane, military plane", + "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "897": "washer, automatic washer, washing machine", + "898": "water bottle", + "899": "water jug", + "900": "water tower", + "901": "whiskey jug", + "902": "whistle", + "903": "wig", + "904": "window screen", + "905": "window shade", + "906": "Windsor tie", + "907": "wine bottle", + "908": "wing", + "909": "wok", + "910": "wooden spoon", + "911": "wool, woolen, woollen", + "912": "worm fence, snake fence, snake-rail fence, Virginia 
fence", + "913": "wreck", + "914": "yawl", + "915": "yurt", + "916": "web site, website, internet site, site", + "917": "comic book", + "918": "crossword puzzle, crossword", + "919": "street sign", + "920": "traffic light, traffic signal, stoplight", + "921": "book jacket, dust cover, dust jacket, dust wrapper", + "922": "menu", + "923": "plate", + "924": "guacamole", + "925": "consomme", + "926": "hot pot, hotpot", + "927": "trifle", + "928": "ice cream, icecream", + "929": "ice lolly, lolly, lollipop, popsicle", + "930": "French loaf", + "931": "bagel, beigel", + "932": "pretzel", + "933": "cheeseburger", + "934": "hotdog, hot dog, red hot", + "935": "mashed potato", + "936": "head cabbage", + "937": "broccoli", + "938": "cauliflower", + "939": "zucchini, courgette", + "940": "spaghetti squash", + "941": "acorn squash", + "942": "butternut squash", + "943": "cucumber, cuke", + "944": "artichoke, globe artichoke", + "945": "bell pepper", + "946": "cardoon", + "947": "mushroom", + "948": "Granny Smith", + "949": "strawberry", + "950": "orange", + "951": "lemon", + "952": "fig", + "953": "pineapple, ananas", + "954": "banana", + "955": "jackfruit, jak, jack", + "956": "custard apple", + "957": "pomegranate", + "958": "hay", + "959": "carbonara", + "960": "chocolate sauce, chocolate syrup", + "961": "dough", + "962": "meat loaf, meatloaf", + "963": "pizza, pizza pie", + "964": "potpie", + "965": "burrito", + "966": "red wine", + "967": "espresso", + "968": "cup", + "969": "eggnog", + "970": "alp", + "971": "bubble", + "972": "cliff, drop, drop-off", + "973": "coral reef", + "974": "geyser", + "975": "lakeside, lakeshore", + "976": "promontory, headland, head, foreland", + "977": "sandbar, sand bar", + "978": "seashore, coast, seacoast, sea-coast", + "979": "valley, vale", + "980": "volcano", + "981": "ballplayer, baseball player", + "982": "groom, bridegroom", + "983": "scuba diver", + "984": "rapeseed", + "985": "daisy", + "986": "yellow ladys slipper, yellow 
lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "987": "corn", + "988": "acorn", + "989": "hip, rose hip, rosehip", + "990": "buckeye, horse chestnut, conker", + "991": "coral fungus", + "992": "agaric", + "993": "gyromitra", + "994": "stinkhorn, carrion fungus", + "995": "earthstar", + "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "997": "bolete", + "998": "ear, spike, capitulum", + "999": "toilet tissue, toilet paper, bathroom tissue" + } +} diff --git a/ProMoE-B-256/pipeline.py b/ProMoE-B-256/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a22aa2d52139703430ab9d7e7ebcc78db2d3d777 --- /dev/null +++ b/ProMoE-B-256/pipeline.py @@ -0,0 +1,259 @@ +"""Hub custom pipeline: ProMoEPipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image + +try: + from diffusers.pipelines.pipeline_utils import DiffusionPipeline +except Exception: # pragma: no cover + class DiffusionPipeline: + def __init__(self): + self._execution_device = torch.device("cpu") + + def register_modules(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def to(self, device): + self._execution_device = torch.device(device) + for module in (getattr(self, "transformer", None), getattr(self, "vae", None)): + if module is not None and hasattr(module, "to"): + module.to(device) + return self + + def progress_bar(self, iterable): + return iterable + + def maybe_free_model_hooks(self): + return None + +@dataclass +class ProMoEPipelineOutput: + images: Union[List[Image.Image], np.ndarray, torch.Tensor] + +class ProMoEPipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with ProMoE. 
@staticmethod
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
    r"""Invert an id -> label mapping into a synonym -> id lookup table.

    Each value of `id2label` is split on commas; every non-empty, whitespace-stripped
    piece becomes a key that maps to the integer class id.

    Args:
        id2label: Mapping from class id to a comma-separated string of synonyms.

    Returns:
        A new dict ordered by synonym. When the same synonym appears under several
        class ids, the id that is iterated last in `id2label` wins.
    """
    # Flatten to (class_id, synonym) pairs in the mapping's iteration order.
    flattened = (
        (class_id, piece.strip())
        for class_id, names in id2label.items()
        for piece in names.split(",")
    )
    # Later pairs overwrite earlier ones, so duplicates resolve to the last class id.
    lookup = {name: int(class_id) for class_id, name in flattened if name}
    # Re-key in sorted order so the result is deterministic for error previews.
    return {name: lookup[name] for name in sorted(lookup)}
def _prepare_latents(
    self,
    batch_size: int,
    latent_height: int,
    latent_width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> torch.Tensor:
    r"""Sample the initial Gaussian latents for the denoising loop.

    Args:
        batch_size: Number of samples to draw.
        latent_height: Height of the latent grid.
        latent_width: Width of the latent grid.
        dtype: Dtype of the returned tensor.
        device: Device the returned tensor is placed on.
        generator: `None`, a single generator shared by the whole batch, or a list of
            generators. A list must contain either exactly one generator (treated like a
            single shared generator) or exactly `batch_size` generators (one per sample).

    Returns:
        Tensor of shape `(batch_size, transformer.in_channels, latent_height, latent_width)`.

    Raises:
        ValueError: If `generator` is a list whose length is neither 1 nor `batch_size`.
    """
    shape = (batch_size, self.transformer.in_channels, latent_height, latent_width)

    def _sample(sample_shape, gen):
        # torch.randn requires the generator and the output to live on the same device.
        # Draw on the generator's own device (commonly CPU) and move afterwards so that a
        # CPU-seeded generator keeps working when the pipeline runs on an accelerator.
        sample_device = device if gen is None else gen.device
        noise = torch.randn(sample_shape, generator=gen, device=sample_device, dtype=dtype)
        return noise.to(device)

    if isinstance(generator, list):
        if len(generator) == 1:
            # A one-element list behaves like a single shared generator.
            generator = generator[0]
        elif len(generator) != batch_size:
            # Previously this silently produced len(generator) latents, which only failed
            # later with an opaque batch-size mismatch inside the transformer.
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested a "
                f"batch size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        else:
            per_sample = [_sample((1, *shape[1:]), g) for g in generator]
            return torch.cat(per_sample, dim=0)
    return _sample(shape, generator)
@torch.no_grad()
def __call__(
    self,
    class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
    height: int = 256,
    width: int = 256,
    num_inference_steps: int = 50,
    guidance_scale: float = 1.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
) -> Union[ProMoEPipelineOutput, Tuple]:
    r"""
    Generate class-conditional images with ProMoE.

    Args:
        class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
            ImageNet class indices or human-readable English label strings. The number of
            resolved labels determines the batch size.
        height (`int`, defaults to 256):
            Target image height; floor-divided by the VAE spatial downsampling factor to
            obtain the latent height.
        width (`int`, defaults to 256):
            Target image width; floor-divided by the VAE spatial downsampling factor to
            obtain the latent width.
        num_inference_steps (`int`, defaults to 50):
            Number of steps handed to `scheduler.set_timesteps`.
        guidance_scale (`float`, defaults to 1.0):
            Classifier-free guidance weight. Guidance is applied only when this is greater
            than 1.0; each step then runs the transformer on a doubled batch (the requested
            labels followed by the null label).
        generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
            Forwarded to the initial latent sampling and to every `scheduler.step` call.
        output_type (`str`, defaults to `"pil"`):
            Passed through to `_decode_latents`, which selects the output format.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in a [`ProMoEPipelineOutput`].

    Returns:
        [`ProMoEPipelineOutput`] or `tuple`:
            A `ProMoEPipelineOutput` whose `images` field holds the decoded result when
            `return_dict` is `True`, otherwise a one-element tuple containing the same value.
    """
    # Fall back to CPU when the base class does not expose an execution device.
    device = self._execution_device if hasattr(self, "_execution_device") else torch.device("cpu")
    # Latents and timesteps are created in the transformer's parameter dtype.
    model_dtype = next(self.transformer.parameters()).dtype
    class_labels = self._normalize_class_labels(class_labels, device)
    batch_size = class_labels.shape[0]

    vae_scale = self._get_vae_spatial_downsample()
    latent_height = height // vae_scale
    latent_width = width // vae_scale
    latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator)

    self.scheduler.set_timesteps(num_inference_steps, device=device)
    # Label id used for the unconditional branch: the label embedder's `num_classes`
    # (1000 when the attribute is missing). Presumably this indexes an extra "null class"
    # embedding row — TODO confirm against LabelEmbedder.
    null_labels = torch.full_like(class_labels, getattr(self.transformer.backbone.y_embedder, "num_classes", 1000))

    for t in self.progress_bar(self.scheduler.timesteps):
        if guidance_scale > 1.0:
            # Doubled batch: conditional half first, unconditional half second.
            latent_input = torch.cat([latents, latents], dim=0)
            labels = torch.cat([class_labels, null_labels], dim=0)
        else:
            latent_input = latents
            labels = class_labels
        # Broadcast the scalar scheduler timestep to one entry per batch row.
        timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype)
        model_output = self.transformer(
            hidden_states=latent_input,
            timestep=timestep,
            class_labels=labels,
            return_dict=True,
        ).sample
        if model_output.shape[1] != latents.shape[1]:
            # The model emitted a different channel count than the latents: keep only the
            # first half along the channel axis. Presumably the second half is a learned
            # variance prediction — TODO confirm.
            model_output = model_output.chunk(2, dim=1)[0]
        if guidance_scale > 1.0:
            # Split the doubled batch back in the order it was concatenated above and
            # extrapolate from the unconditional towards the conditional prediction.
            model_output_cond, model_output_uncond = model_output.chunk(2)
            model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
        latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample

    images = self._decode_latents(latents, output_type)
    self.maybe_free_model_hooks()
    if not return_dict:
        return (images,)
    return ProMoEPipelineOutput(images=images)
class SparseMoEBlock(nn.Module):
    """Mixture-of-experts layer in which each expert selects the tokens it processes.

    In training mode every expert takes a fixed number of tokens (derived from
    `capacity`) out of the flattened batch, and a per-expert threshold on the capacity
    predictor's output is tracked as an exponential moving average. In eval mode the
    number of tokens per expert is instead the count of tokens whose predicted score
    exceeds that stored threshold.
    """

    def __init__(
        self,
        experts,
        hidden_dim,
        num_experts,
        n_shared_experts=0,
        capacity=2,
        mlp_ratio=4.0,
        use_diff_expert=False,
    ):
        """Build the router, the capacity predictor and the optional extra experts.

        Args:
            experts: Iterable of expert modules, wrapped in an `nn.ModuleList`.
            hidden_dim: Size of the last dimension of the input.
            num_experts: Number of routed experts (rows of the router weight).
            n_shared_experts: When greater than 0, a single shared `Mlp` is added to the output.
            capacity: Multiplier on `tokens / num_experts` giving tokens per expert.
            mlp_ratio: Width multiplier for the shared and diff experts.
            use_diff_expert: When true, an extra `MoeMLP` is created whose output is
                subtracted in the training path.
        """
        super().__init__()
        # Router weight: one row per expert, initialised with a small normal std.
        self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim)))
        nn.init.normal_(self.gate_weight, std=0.006)
        self.experts = nn.ModuleList(experts)
        self.capacity = capacity
        self.num_experts = num_experts
        self.n_shared_experts = n_shared_experts
        self.use_diff_expert = use_diff_expert
        if use_diff_expert:
            self.diff_expert = MoeMLP(hidden_size=hidden_dim, intermediate_size=int(hidden_dim * mlp_ratio))

        # Small MLP producing one score per expert for every token; it is always fed a
        # detached copy of the input (see forward_train / forward_eval).
        self.capacity_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, self.num_experts, bias=True),
        )

        if self.n_shared_experts > 0:
            # The shared expert is twice as wide as hidden_dim * mlp_ratio.
            mlp_hidden_dim = int(hidden_dim * mlp_ratio * 2)
            approx_gelu = lambda: nn.GELU(approximate="tanh")
            self.shared_experts = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)

        # Per-expert threshold (starts at 0) and the EMA decay used to update it.
        self.register_buffer("expert_threshold", torch.tensor([0.0] * num_experts))
        self.register_buffer("ema_decay", torch.tensor([0.95]))

    def forward(self, x):
        """Dispatch to the training or eval path; both return a 3-tuple."""
        if self.training:
            return self.forward_train(x)
        return self.forward_eval(x)

    def update_threshold(self, capacity_pred):
        """Update the per-expert threshold buffer from the current predictor scores.

        Args:
            capacity_pred: Raw (pre-sigmoid) predictor output; the first dimension is
                the token axis and column `i` belongs to expert `i`.

        Does nothing outside training mode.
        """
        if not self.training:
            return
        capacity_pred = torch.sigmoid(capacity_pred)
        seq_len = capacity_pred.size(0)
        topk = int((seq_len / self.num_experts) * self.capacity)
        threshold = self.expert_threshold
        ema_decay = self.ema_decay
        for i in range(self.num_experts):
            # sorted=True puts the smallest of the top-k scores last, so scores[-1] is the
            # k-th largest score for expert i.
            # NOTE(review): when topk is 0 `scores` is empty and scores[-1] raises — confirm
            # the token count is always large enough.
            scores, _ = torch.topk(capacity_pred[:, i], k=topk, dim=-1, sorted=True)
            quantile = scores[-1].detach()
            # Exponential moving average, written in place into the buffer.
            threshold[i] = threshold[i] * ema_decay + (1 - ema_decay) * quantile
        if dist.is_available() and dist.is_initialized():
            # Average the threshold over all ranks (sum, then divide by world size).
            dist.all_reduce(threshold, op=dist.ReduceOp.SUM)
            threshold /= dist.get_world_size()
        self.expert_threshold = threshold

    def forward_train(self, x):
        """Training path: every expert processes its top-k tokens of the flattened batch.

        Args:
            x: Tensor unpacked as `(bsz, seq_len, hidden_dim)`.

        Returns:
            Tuple `(x_out, ones, capacity_pred)`: the layer output with the input's shape,
            the token-selection mask of shape `(bsz, seq_len, num_experts)` (1.0 where an
            expert picked the token), and the raw predictor output in that same shape.
        """
        bsz, seq_len, hidden_dim = x.shape
        identity = x
        # Routing is done over all tokens of the batch at once.
        x = x.view(-1, hidden_dim)
        total_tokens = x.shape[0]
        capacity_pred = self.capacity_predictor(x.detach())
        k = int((total_tokens / self.num_experts) * self.capacity)
        logits = F.linear(x, self.gate_weight, None)
        # Softmax over experts per token, then transpose to (num_experts, total_tokens)
        # so that top-k runs along the token axis.
        scores = logits.softmax(dim=-1).permute(1, 0)
        gating, index = torch.topk(scores, k=k, dim=-1, sorted=False)
        mask = torch.zeros((self.num_experts, total_tokens), dtype=x.dtype, device=x.device)
        mask.scatter_(1, index, 1.0)
        # Gather the chosen tokens: (num_experts, k, hidden_dim).
        expert_inputs = x[index]
        expert_outputs = torch.stack([expert(expert_inputs[i]) for i, expert in enumerate(self.experts)])
        gated_outputs = gating.unsqueeze(-1) * expert_outputs

        # Scatter into a buffer with a separate token-sized slab per expert (index shifted
        # by expert_id * total_tokens), then sum the slabs to combine experts per token.
        y = torch.zeros((total_tokens * self.num_experts, hidden_dim), dtype=x.dtype, device=x.device)
        offset = torch.arange(0, self.num_experts, device=x.device).unsqueeze(1) * total_tokens
        flat_index = (index + offset.long()).view(-1)
        y = torch.scatter(y, 0, flat_index.unsqueeze(1).expand(-1, hidden_dim), gated_outputs.view(-1, hidden_dim))
        y = y.view(self.num_experts, total_tokens, hidden_dim).sum(dim=0, keepdim=False)

        self.update_threshold(capacity_pred)
        x_out = y.view(bsz, seq_len, hidden_dim)
        ones = mask.permute(1, 0).view(bsz, seq_len, self.num_experts)
        capacity_pred = capacity_pred.view(bsz, seq_len, self.num_experts)
        if self.n_shared_experts > 0:
            x_out = x_out + self.shared_experts(identity)
        if self.use_diff_expert:
            # NOTE(review): the diff expert is subtracted here but not in forward_eval —
            # confirm the train/eval difference is intended.
            x_out = x_out - self.diff_expert(identity)
        return x_out, ones, capacity_pred

    def forward_eval(self, x):
        """Eval path: per-expert token counts come from the stored thresholds.

        Args:
            x: Tensor unpacked as `(bsz, seq_len, hidden_dim)`.

        Returns:
            Tuple `(x_out, None, None)`; the two `None` entries keep the arity of the
            training path's return value.
        """
        bsz, seq_len, hidden_dim = x.shape
        identity = x
        x = x.view(-1, hidden_dim)
        total_tokens = x.shape[0]
        capacity_pred = torch.sigmoid(self.capacity_predictor(x.detach()))
        threshold = self.expert_threshold
        logits = F.linear(x, self.gate_weight, None)
        # (num_experts, total_tokens) after swapping the last two dims.
        scores = logits.softmax(dim=-1).permute(-1, -2)
        y = torch.zeros_like(x, dtype=x.dtype)
        for i, expert in enumerate(self.experts):
            # Number of tokens whose predicted score exceeds expert i's threshold ...
            k_fixed = torch.where(capacity_pred[:, i] > threshold[i], 1, 0).sum()
            # ... but the tokens actually taken are the top-k_fixed by router score.
            gating, index = torch.topk(scores[i], k=k_fixed, dim=-1, sorted=False)
            y[index, :] += gating.unsqueeze(-1) * expert(x[index, :])
        x_out = y.view(bsz, seq_len, hidden_dim)
        if self.n_shared_experts > 0:
            x_out = x_out + self.shared_experts(identity)
        return x_out, None, None
def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + qk_norm=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=qk_norm, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + experts = [ + Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + for _ in range(MoE_config.num_experts) + ] + else: + experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)] + self.mlp = SparseMoEBlock( + experts=experts, + hidden_dim=hidden_size, + num_experts=MoE_config.num_experts, + capacity=MoE_config.capacity, + n_shared_experts=MoE_config.n_shared_experts, + mlp_ratio=4.0, + ) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + if self.use_moe: + x_mlp, ones, pred_c = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + gate_mlp.unsqueeze(1) * x_mlp + return x, ones, pred_c + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return 
x, None, None + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + CapacityPred_loss_weight=0.01, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.CapacityPred_loss_weight = CapacityPred_loss_weight + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + self.capacity_schedule = MoE_config.get("capacity_schedule", None) + if self.capacity_schedule: + self.training_iters = -1 + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + 
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + + if self.training and self.capacity_schedule: + num_experts = self.MoE_config.num_experts + capacity = self.MoE_config.capacity + stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters + stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters + if self.training_iters <= stage_i: + capacity = num_experts + elif self.training_iters <= stage_ii: + capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i) + for block in self.blocks: + if hasattr(block.mlp, "capacity"): + block.mlp.capacity = capacity + + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + ones_list, pred_c_list, layer_idx_list = [], [], [] + for layer_idx, block in enumerate(self.blocks): + x, ones, 
class DiTBlock(nn.Module):
    """Transformer block with attention and MLP branches modulated by a conditioning vector.

    The conditioning vector is projected to six chunks (shift, scale and gate for each
    branch); each branch is applied to a modulated LayerNorm of the input and added back
    through its gate.
    """

    def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs):
        """Create the sub-modules.

        `use_swiglu` is accepted but not read by this block; extra keyword arguments are
        forwarded to `Attention`.
        """
        super().__init__()

        def _tanh_gelu():
            # Factory handed to Mlp as its activation constructor.
            return nn.GELU(approximate="tanh")

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=_tanh_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x, c):
        """Apply the gated attention branch, then the gated MLP branch, to `x`."""
        (
            attn_shift,
            attn_scale,
            attn_gate,
            mlp_shift,
            mlp_scale,
            mlp_gate,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)

        attn_branch = self.attn(modulate(self.norm1(x), attn_shift, attn_scale))
        x = x + attn_gate.unsqueeze(1) * attn_branch

        mlp_branch = self.mlp(modulate(self.norm2(x), mlp_shift, mlp_scale))
        x = x + mlp_gate.unsqueeze(1) * mlp_branch
        return x
mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + head_dim=None, + use_swiglu=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + ) + for _ in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + 
class SparseMoEBlock(nn.Module):
    """Mixture-of-experts layer where each expert picks its top-k tokens per sample.

    Router affinities are a softmax over experts for every token; each expert then
    selects `int(seq_len / num_experts * capacity)` tokens of each sequence, processes
    them, and the gated results are scattered back to their token positions.
    """

    def __init__(self, experts, hidden_dim, num_experts, n_shared_experts=0, capacity=2):
        """Store the experts and create the router (and the shared expert if requested)."""
        super().__init__()
        router = nn.Parameter(torch.empty((num_experts, hidden_dim)))
        self.gate_weight = router
        nn.init.normal_(self.gate_weight, std=0.006)
        self.experts = nn.ModuleList(experts)
        self.capacity = capacity
        self.num_experts = num_experts
        self.n_shared_experts = n_shared_experts
        if self.n_shared_experts > 0:
            # The shared expert's width scales with the number of shared experts.
            self.shared_experts = MoeMLP(
                hidden_size=hidden_dim,
                intermediate_size=hidden_dim * self.n_shared_experts,
                pretraining_tp=2,
            )

    def forward(self, x):
        """Route `x` (unpacked as batch, sequence, feature) through the experts."""
        residual_input = x
        _, num_tokens, _ = x.shape

        # Per-token softmax over experts, rearranged to (batch, expert, token).
        router_logits = F.linear(x, self.gate_weight, None)
        token_probs = router_logits.softmax(dim=-1)
        expert_probs = torch.einsum("b s e -> b e s", token_probs)

        tokens_per_expert = int((num_tokens / self.num_experts) * self.capacity)
        weights, chosen = torch.topk(expert_probs, k=tokens_per_expert, dim=-1, sorted=False)

        # One-hot dispatch tensor: (batch, expert, slot, token).
        one_hot = F.one_hot(chosen, num_classes=num_tokens)
        one_hot = one_hot.to(device=x.device, dtype=x.dtype)
        routed = torch.einsum("b e c s, b s d -> b e c d", one_hot, x)

        per_expert = []
        for expert_id in range(self.num_experts):
            per_expert.append(self.experts[expert_id](routed[:, expert_id]))
        processed = torch.stack(per_expert, dim=1)

        # Weight each expert output by its gate and scatter it back to token positions.
        combined = torch.einsum("b e c s, b e c, b e c d -> b s d", one_hot, weights, processed)
        if self.n_shared_experts <= 0:
            return combined
        return combined + self.shared_experts(residual_input)
self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + self.capacity_schedule = MoE_config.get("capacity_schedule", 
None) + if self.capacity_schedule: + self.training_iters = -1 + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def init_moe_mlp(module, std=0.006): + nn.init.normal_(module.gate_proj.weight, std=std) + nn.init.normal_(module.up_proj.weight, std=std) + nn.init.normal_(module.down_proj.weight, std=std) + + if self.init_MoeMLP: + for block in self.blocks: + if hasattr(block.mlp, "experts"): + for expert in block.mlp.experts: + if hasattr(expert, "gate_proj"): + init_moe_mlp(expert) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + if self.training and self.capacity_schedule: + 
num_experts = self.MoE_config.num_experts + capacity = self.MoE_config.capacity + stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters + stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters + if self.training_iters <= stage_i: + capacity = num_experts + elif self.training_iters <= stage_ii: + capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i) + for block in self.blocks: + if hasattr(block.mlp, "capacity"): + block.mlp.capacity = capacity + + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-B-256/transformer/backbone_promoe_ec.py b/ProMoE-B-256/transformer/backbone_promoe_ec.py new file mode 100644 index 0000000000000000000000000000000000000000..05da901ed601ca8e683ab5d55da0af3922534015 --- /dev/null +++ b/ProMoE-B-256/transformer/backbone_promoe_ec.py @@ -0,0 +1,286 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None + return grad_output, grad_loss + + +class SparseMoeBlock(nn.Module): + def __init__( + self, + num_routed_experts, + hidden_size, + moe_intermediate_size, + shared_expert_intermediate_size, + top_k=1, + load_balance_loss_coef=0, + norm_topk_prob=False, + seq_aux=False, + use_shared_expert=True, + use_uncond_expert=True, + 
router_weight_mode="softmax", + routing_contrastive_lam=0, + use_top_k_for_routing_contrastive=False, + routing_contrastive_temperature=0.1, + **kwargs, + ): + super().__init__() + del load_balance_loss_coef, norm_topk_prob, seq_aux, use_top_k_for_routing_contrastive + self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts + self.num_routed_experts = num_routed_experts + self.hidden_size = hidden_size + self.top_k = top_k + self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size)) + self.use_shared_expert = use_shared_expert + self.use_uncond_expert = use_uncond_expert + self.router_weight_mode = router_weight_mode + self.routing_contrastive_lam = routing_contrastive_lam + self.routing_contrastive_temperature = routing_contrastive_temperature + self.experts = nn.ModuleList( + [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)] + ) + if use_shared_expert: + self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size) + self._init_weights() + + def compute_router(self, cond_hidden_states): + b_cond, seq_len, _ = cond_hidden_states.shape + num_cond_experts = self.num_routed_experts + input_norm = F.normalize(cond_hidden_states, p=2, dim=-1) + cluster_norm = F.normalize(self.cluster_centers, p=2, dim=-1) + cos_sim = input_norm @ cluster_norm.T + cos_sim_expert_view = cos_sim.transpose(1, 2) + if self.router_weight_mode == "softmax": + cond_weights = F.softmax(cos_sim_expert_view, dim=-1) + elif self.router_weight_mode == "sigmoid": + cond_weights = torch.sigmoid(cos_sim_expert_view) + elif self.router_weight_mode == "identity": + cond_weights = cos_sim_expert_view + else: + raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}") + k = max(1, min(int((seq_len / num_cond_experts) * self.top_k), seq_len)) + router_weights, indices = torch.topk(cond_weights, k=k, dim=-1, sorted=False) + 
dispatch_mask = F.one_hot(indices, num_classes=seq_len).to(dtype=cond_hidden_states.dtype) + expert_inputs = torch.einsum("becs,bsd->becd", dispatch_mask, cond_hidden_states) + return dispatch_mask, router_weights, expert_inputs + + def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor): + identity = hidden_states + batch_size, _, hidden_dim = hidden_states.shape + final_output = torch.zeros_like(hidden_states) + loss = None + cond_batch_mask = ( + labels.view(-1) != 1000 + ) if self.use_uncond_expert else torch.ones(batch_size, dtype=torch.bool, device=hidden_states.device) + uncond_batch_mask = ~cond_batch_mask + cond_experts = self.experts[:-1] if self.use_uncond_expert else self.experts + + if cond_batch_mask.any(): + cond_hidden_states = hidden_states[cond_batch_mask] + dispatch_mask, gating_scores, expert_inputs = self.compute_router(cond_hidden_states) + num_cond_experts = len(cond_experts) + expert_outputs = torch.stack([cond_experts[e](expert_inputs[:, e]) for e in range(num_cond_experts)], dim=1) + cond_output = torch.einsum("becs,bec,becd->bsd", dispatch_mask, gating_scores, expert_outputs).to(hidden_states.dtype) + final_output[cond_batch_mask] = cond_output + if self.training and self.routing_contrastive_lam > 0 and num_cond_experts > 1: + expert_token_means = expert_inputs.mean(dim=2) + routing_contrastive_loss = self.compute_routing_contrastive_loss(expert_token_means) + loss = routing_contrastive_loss * self.routing_contrastive_lam + else: + dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + for expert in cond_experts: + final_output = final_output + expert(dummy_input).sum() * 0 + + if self.use_uncond_expert: + if uncond_batch_mask.any(): + uncond_hidden_states = hidden_states[uncond_batch_mask] + final_output[uncond_batch_mask] = self.experts[-1](uncond_hidden_states).to(final_output.dtype) + else: + dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, 
dtype=hidden_states.dtype) + final_output = final_output + self.experts[-1](dummy_input).sum() * 0 + + if self.use_shared_expert: + final_output += self.shared_expert(identity).to(hidden_states.dtype) + return final_output, loss + + def compute_routing_contrastive_loss(self, expert_token_means): + batch_size, num_cond_experts, _ = expert_token_means.shape + if num_cond_experts < 2: + return torch.tensor(0.0, device=expert_token_means.device) + centers_norm = F.normalize(self.cluster_centers, p=2, dim=1) + means_norm = F.normalize(expert_token_means, p=2, dim=2) + sim_matrix = torch.einsum("id,bjd->bij", centers_norm, means_norm) + logits = sim_matrix / self.routing_contrastive_temperature + labels = torch.arange(num_cond_experts, device=logits.device).unsqueeze(0).expand(batch_size, -1) + return F.cross_entropy(logits.reshape(batch_size * num_cond_experts, -1), labels.reshape(-1)) + + def _init_weights(self): + nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02) + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + 
def forward(self, x, c, label): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + if self.use_moe: + x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label) + if aux_loss is not None: + x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss) + return x + gate_mlp.unsqueeze(1) * x_mlp + return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + + def 
initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def init_moe_mlp(module, std=0.006): + nn.init.normal_(module.up_proj.weight, std=std) + nn.init.normal_(module.down_proj.weight, std=std) + + if self.init_MoeMLP: + for block in self.blocks: + if hasattr(block.mlp, "experts"): + for expert in block.mlp.experts: + init_moe_mlp(expert) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, timestep, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(timestep) + y, labels = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c, labels) + x = 
self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-B-256/transformer/backbone_promoe_tc.py b/ProMoE-B-256/transformer/backbone_promoe_tc.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5f0036823de886748b5a375a98b7f0efc6377f --- /dev/null +++ b/ProMoE-B-256/transformer/backbone_promoe_tc.py @@ -0,0 +1,355 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None + return grad_output, grad_loss + + +class SparseMoeBlock(nn.Module): + def __init__( + self, + num_routed_experts, + hidden_size, + moe_intermediate_size, + shared_expert_intermediate_size, + top_k=2, + load_balance_loss_coef=0, + norm_topk_prob=False, + seq_aux=False, + use_shared_expert=True, + use_uncond_expert=True, + router_weight_mode="softmax", + routing_contrastive_lam=0, + use_top_k_for_routing_contrastive=False, + routing_contrastive_temperature=0.1, + **kwargs, + ): + super().__init__() + del norm_topk_prob + self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts + self.num_routed_experts = num_routed_experts + self.seq_aux = seq_aux + self.hidden_size = hidden_size + self.top_k = top_k + self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size)) + self.alpha = load_balance_loss_coef + self.use_shared_expert = use_shared_expert + self.use_uncond_expert = use_uncond_expert + self.router_weight_mode = router_weight_mode + self.routing_contrastive_lam = routing_contrastive_lam + 
self.use_top_k_for_routing_contrastive = use_top_k_for_routing_contrastive + self.routing_contrastive_temperature = routing_contrastive_temperature + self.experts = nn.ModuleList( + [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)] + ) + if use_shared_expert: + self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size) + self._init_weights() + + def compute_router(self, hidden_states, labels): + batch_size, seq_len, _ = hidden_states.shape + device = hidden_states.device + flat_input = hidden_states.view(-1, self.hidden_size) + flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1) + if self.use_uncond_expert and flat_labels is not None: + uncond_mask = flat_labels == 1000 + cond_mask = ~uncond_mask + else: + uncond_mask = None + cond_mask = torch.ones_like(flat_labels, dtype=torch.bool) + + router_weights = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=hidden_states.dtype) + expert_indices = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=torch.long) + + if uncond_mask is not None and uncond_mask.any(): + uncond_positions = torch.where(uncond_mask)[0] + router_weights[uncond_positions, 0] = 1.0 + expert_indices[uncond_positions] = self.num_experts - 1 + + cond_weights = None + topk_idx = None + if cond_mask.any(): + cond_positions = torch.where(cond_mask)[0] + cond_input = flat_input[cond_positions] + input_norm = F.normalize(cond_input, p=2, dim=1) + cluster_norm = F.normalize(self.cluster_centers, p=2, dim=1) + cos_sim = input_norm @ cluster_norm.T + if self.router_weight_mode == "softmax": + cond_weights = F.softmax(cos_sim, dim=1) + elif self.router_weight_mode == "sigmoid": + cond_weights = torch.sigmoid(cos_sim) + elif self.router_weight_mode == "identity": + cond_weights = cos_sim + else: + raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}") + topk_scores, topk_idx = 
torch.topk(cond_weights, k=self.top_k, dim=1) + router_weights[cond_positions] = topk_scores.to(router_weights.dtype) + expert_indices[cond_positions] = topk_idx + + router_weights = router_weights.view(batch_size, seq_len, self.top_k) + expert_indices = expert_indices.view(batch_size, seq_len, self.top_k) + + load_balance_loss = None + if self.training and self.alpha > 0.0 and cond_weights is not None and topk_idx is not None: + cond_batch_size = (labels != 1000).sum() + scores_for_aux = F.softmax(cond_weights, dim=1) if self.router_weight_mode != "softmax" else cond_weights + topk_idx_for_aux_loss = topk_idx.view(cond_batch_size, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(cond_batch_size, seq_len, -1) + ce = torch.zeros(cond_batch_size, self.num_routed_experts, device=hidden_states.device) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(cond_batch_size, seq_len * self.top_k, device=hidden_states.device), + ).div_(seq_len * self.top_k / self.num_routed_experts) + load_balance_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.num_routed_experts) + ce = mask_ce.float().mean(0) + pi = scores_for_aux.mean(0) + fi = ce * self.num_routed_experts + load_balance_loss = (pi * fi).sum() * self.alpha + return router_weights, expert_indices, load_balance_loss + + def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor): + router_weights, expert_indices, load_balance_loss = self.compute_router(hidden_states, labels) + batch_size, seq_len, hidden_dim = hidden_states.shape + flat_input = hidden_states.view(-1, hidden_dim) + flat_weights = router_weights.view(-1, self.top_k) + flat_indices = expert_indices.view(-1, self.top_k) + total_tokens = batch_size * seq_len + final_output = torch.zeros(total_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + + for expert_id in range(self.num_experts): + 
expert_mask = (flat_indices == expert_id).any(dim=1) + token_ids = torch.where(expert_mask)[0] + if token_ids.numel() > 0: + expert_input = flat_input[token_ids] + expert_weight_mask = flat_indices[token_ids] == expert_id + expert_weights = flat_weights[token_ids] * expert_weight_mask.to(dtype=flat_weights.dtype) + combined_weights = expert_weights.sum(dim=1) + expert_output = self.experts[expert_id](expert_input) + weighted_output = expert_output * combined_weights.unsqueeze(1) + final_output.index_add_(0, token_ids, weighted_output) + else: + dummy_input = torch.zeros(1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + final_output[0] += self.experts[expert_id](dummy_input)[0] * 0 + + final_output = final_output.view(batch_size, seq_len, hidden_dim) + if self.use_shared_expert: + final_output += self.shared_expert(hidden_states) + + loss = load_balance_loss + if self.training and self.routing_contrastive_lam > 0: + flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1) + cond_mask = ~( + flat_labels == 1000 + ) if self.use_uncond_expert else torch.ones(batch_size * seq_len, dtype=torch.bool, device=hidden_states.device) + cond_token_embeddings = flat_input[cond_mask] + if self.use_top_k_for_routing_contrastive: + cond_cluster_assignments = expert_indices.view(batch_size * seq_len, self.top_k)[cond_mask] + else: + top1_expert_indices = expert_indices.view(batch_size * seq_len, self.top_k)[:, 0] + cond_cluster_assignments = top1_expert_indices[cond_mask] + routing_contrastive_loss = self.compute_routing_contrastive_loss( + cond_token_embeddings, + cond_cluster_assignments, + use_top_k=self.use_top_k_for_routing_contrastive, + ) + routing_contrastive_loss = routing_contrastive_loss * self.routing_contrastive_lam + loss = routing_contrastive_loss if loss is None else loss + routing_contrastive_loss + + return final_output, loss + + def compute_routing_contrastive_loss(self, token_embeddings, cluster_assignments, 
use_top_k=False): + cluster_centers = self.cluster_centers + num_clusters = cluster_centers.size(0) + device = cluster_centers.device + cluster_means = [] + valid_clusters = [] + for cluster_id in range(num_clusters): + mask = (cluster_assignments == cluster_id).any(dim=1) if use_top_k else cluster_assignments == cluster_id + if mask.sum() > 0: + cluster_means.append(token_embeddings[mask].mean(dim=0, keepdim=True)) + valid_clusters.append(cluster_id) + if len(valid_clusters) < 2: + return torch.tensor(0.0, device=device) + cluster_means = torch.cat(cluster_means, dim=0) + valid_centers = cluster_centers[valid_clusters] + centers_norm = F.normalize(valid_centers, p=2, dim=1) + means_norm = F.normalize(cluster_means, p=2, dim=1) + sim_matrix = centers_norm @ means_norm.T + logits = sim_matrix / self.routing_contrastive_temperature + labels = torch.arange(sim_matrix.size(0), device=device) + return F.cross_entropy(logits, labels) + + def _init_weights(self): + nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02) + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), 
nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c, label): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + if self.use_moe: + x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label) + if aux_loss is not None: + x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss) + return x + gate_mlp.unsqueeze(1) * x_mlp + return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = 
MoE_config.init_MoeMLP + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def init_moe_mlp(module, std=0.006): + nn.init.normal_(module.up_proj.weight, std=std) + nn.init.normal_(module.down_proj.weight, std=std) + + if self.init_MoeMLP: + for block in self.blocks: + if hasattr(block.mlp, "experts"): + for expert in block.mlp.experts: + init_moe_mlp(expert) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, timestep, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(timestep) + y, labels = self.y_embedder(y, self.training) + c = t + y + for block 
in self.blocks: + x = block(x, c, labels) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-B-256/transformer/backbone_tcdit.py b/ProMoE-B-256/transformer/backbone_tcdit.py new file mode 100644 index 0000000000000000000000000000000000000000..18bc64b114caf8c359ff6842ffd54bdf18af2123 --- /dev/null +++ b/ProMoE-B-256/transformer/backbone_tcdit.py @@ -0,0 +1,304 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP_DiffMoE as MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01): + super().__init__() + self.top_k = num_experts_per_tok + self.n_routed_experts = num_experts + self.scoring_func = "softmax" + self.alpha = aux_loss_alpha + self.seq_aux = False + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + if self.scoring_func != "softmax": + raise NotImplementedError(f"Unsupported gating scoring function: {self.scoring_func}") + scores = logits.softmax(dim=-1) + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + if self.top_k > 1 and self.norm_topk_prob: + topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros(bsz, 
self.n_routed_experts, device=hidden_states.device) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * self.top_k, device=hidden_states.device), + ).div_(seq_len * self.top_k / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + ce = mask_ce.float().mean(0) + pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None + return grad_output, grad_loss + + +class SparseMoEBlock(nn.Module): + def __init__( + self, + experts, + hidden_dim, + mlp_ratio=4, + num_experts=16, + num_experts_per_tok=2, + pretraining_tp=2, + n_shared_experts=2, + ): + super().__init__() + self.top_k = num_experts_per_tok + self.experts = nn.ModuleList(experts) + self.gate = MoEGate(embed_dim=hidden_dim, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok) + self.n_shared_experts = n_shared_experts + if self.n_shared_experts > 0: + intermediate_size = hidden_dim * self.n_shared_experts + self.shared_experts = MoeMLP( + hidden_size=hidden_dim, + intermediate_size=intermediate_size, + pretraining_tp=pretraining_tp, + ) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0) + 
y = torch.empty_like(hidden_states, dtype=hidden_states.dtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]).float() + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + if self.n_shared_experts > 0: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.top_k + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i - 1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_( + 0, + exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), + expert_out, + reduce="sum", + ) + return expert_cache + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4, + pretraining_tp=2, + use_swiglu=False, + MoE_config=None, + use_moe=True, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + experts = [ + 
Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + for _ in range(MoE_config.num_experts) + ] + else: + experts = [ + MoeMLP( + hidden_size=hidden_size, + intermediate_size=mlp_hidden_dim, + pretraining_tp=pretraining_tp, + ) + for _ in range(MoE_config.num_experts) + ] + self.mlp = SparseMoEBlock( + experts=experts, + hidden_dim=hidden_size, + num_experts=MoE_config.num_experts, + num_experts_per_tok=MoE_config.capacity, + n_shared_experts=MoE_config.n_shared_experts, + ) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + pretraining_tp=1, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + 
self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + pretraining_tp=pretraining_tp, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return 
x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-B-256/transformer/config.json b/ProMoE-B-256/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..eddc2006a3c803628bee5ce4cc41d7dd919ca2fd --- /dev/null +++ b/ProMoE-B-256/transformer/config.json @@ -0,0 +1,22 @@ +{ + "_class_name": "ProMoETransformer2DModel", + "architecture": "promoe_tc", + "model_config": { + "MoE_config": { + "init_MoeMLP": false, + "interleave": true, + "moe_intermediate_size": 1536, + "num_routed_experts": 12, + "shared_expert_intermediate_size": 1536, + "top_k": 1, + "use_shared_expert": true, + "use_uncond_expert": true + }, + "depth": 12, + "hidden_size": 768, + "input_size": 32, + "num_classes": 1000, + "num_heads": 12, + "patch_size": 2 + } +} diff --git a/ProMoE-B-256/transformer/diffusion_pytorch_model.safetensors b/ProMoE-B-256/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..575743515c41f33b6eb4a213a0b06e4255c5293e --- /dev/null +++ b/ProMoE-B-256/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08f42b814b6cb9b8948665fc996bbb559e6d55b0961253573a8f0d6e4c64fdcd +size 1202482576 diff --git a/ProMoE-B-256/transformer/modeling_promoe_common.py b/ProMoE-B-256/transformer/modeling_promoe_common.py new file mode 100644 index 0000000000000000000000000000000000000000..0a82f2ece8db2dff46a45faafb9731af18f09a34 --- /dev/null +++ b/ProMoE-B-256/transformer/modeling_promoe_common.py @@ -0,0 +1,291 @@ +import collections.abc +import math +from dataclasses import dataclass +from itertools 
import repeat +from typing import Any, Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +class AttrDict(dict): + def __getattr__(self, item): + try: + return self[item] + except KeyError as error: + raise AttributeError(item) from error + + def __setattr__(self, key, value): + self[key] = value + + @staticmethod + def from_data(data: Any) -> Any: + if isinstance(data, dict): + return AttrDict({k: AttrDict.from_data(v) for k, v in data.items()}) + if isinstance(data, list): + return [AttrDict.from_data(v) for v in data] + return data + + +class PatchEmbed(nn.Module): + def __init__(self, input_size: int, patch_size: int, in_channels: int, embed_dim: int, bias: bool = True): + super().__init__() + self.img_size = to_2tuple(input_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = ( + self.img_size[0] // self.patch_size[0], + self.img_size[1] // self.patch_size[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=bias, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + return hidden_states.flatten(2).transpose(1, 2) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = 
norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class MoeMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size) + self.act_fn = nn.GELU(approximate="tanh") + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +class MoeMLP_DiffMoE(nn.Module): + def __init__(self, hidden_size, intermediate_size, pretraining_tp=2): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = nn.SiLU() + self.pretraining_tp = pretraining_tp + + def forward(self, x): + if self.pretraining_tp > 1: + split_size = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(split_size, dim=0) + up_proj_slices = self.up_proj.weight.split(split_size, dim=0) + down_proj_slices = self.down_proj.weight.split(split_size, dim=1) + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(split_size, dim=-1) + down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in 
range(self.pretraining_tp)] + return sum(down_proj) + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + head_dim=None, + norm_layer: nn.Module = nn.LayerNorm, + ): + super().__init__() + self.num_heads = num_heads + if head_dim is None: + if dim % num_heads != 0: + raise ValueError("dim must be divisible by num_heads") + self.head_dim = dim // num_heads + else: + self.head_dim = head_dim + self.scale = self.head_dim**-0.5 + self.fused_attn = True + self.qkv = nn.Linear(dim, self.head_dim * self.num_heads * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.head_dim * self.num_heads, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)).softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(batch_size, seq_len, -1) + x = self.proj(x) + return self.proj_drop(x) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + 
nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t.float(), self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + return self.mlp(t_freq.to(dtype=weight_dtype)) + + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size, dropout_prob, return_labels=False): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + self.return_labels = return_labels + + def token_drop(self, labels, force_drop_ids=None): + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + return torch.where(drop_ids, self.num_classes, labels) + + def forward(self, labels, train, force_drop_ids=None): + if (train and self.dropout_prob > 0) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + if self.return_labels: + return embeddings, labels + return embeddings + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = 
nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + return self.linear(x) + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + return np.concatenate([emb_h, emb_w], axis=1) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + emb_sin = np.sin(out) + emb_cos = np.cos(out) + return np.concatenate([emb_sin, emb_cos], axis=1) diff --git a/ProMoE-B-256/transformer/transformer_promoe.py b/ProMoE-B-256/transformer/transformer_promoe.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6369fcdcb55c394ef5c7fa8d4d50b7b32ba145 --- /dev/null +++ b/ProMoE-B-256/transformer/transformer_promoe.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from 
diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except Exception: # pragma: no cover + class BaseOutput(dict): + def __post_init__(self): + self.update(self.__dict__) + + class _Config(dict): + def __getattr__(self, key): + try: + return self[key] + except KeyError as error: + raise AttributeError(key) from error + + class ConfigMixin: + config_name = "config.json" + + class ModelMixin(nn.Module): + pass + + def register_to_config(init): + def wrapper(self, *args, **kwargs): + import inspect + + signature = inspect.signature(init) + bound = signature.bind(self, *args, **kwargs) + bound.apply_defaults() + self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"}) + init(self, *args, **kwargs) + + return wrapper + +from .backbone_diffmoe import DiT as DiffMoEBackbone +from .backbone_dit import DiT as DiTBackbone +from .backbone_ecdit import DiT as ECDiTBackbone +from .backbone_promoe_ec import DiT as ProMoEECBackbone +from .backbone_promoe_tc import DiT as ProMoETCBackbone +from .backbone_tcdit import DiT as TCDiTBackbone +from .modeling_promoe_common import AttrDict + + +@dataclass +class ProMoETransformer2DModelOutput(BaseOutput): + sample: torch.FloatTensor + loss_strategy: Optional[str] = None + layer_idx_list: Optional[Tuple[int, ...]] = None + ones_list: Optional[Tuple[torch.FloatTensor, ...]] = None + pred_c_list: Optional[Tuple[torch.FloatTensor, ...]] = None + capacity_pred_loss_weight: Optional[float] = None + + +_BACKBONES = { + "dit": DiTBackbone, + "tcdit": TCDiTBackbone, + "ecdit": ECDiTBackbone, + "diffmoe": DiffMoEBackbone, + "promoe_tc": ProMoETCBackbone, + "promoe_ec": ProMoEECBackbone, +} + + +class ProMoETransformer2DModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__(self, architecture: str = "promoe_tc", model_config: Optional[Dict[str, Any]] = None): + super().__init__() + if architecture not in _BACKBONES: + raise 
ValueError(f"Unsupported architecture: {architecture}. Valid: {sorted(_BACKBONES)}") + model_config = model_config or {} + self.architecture = architecture + self.model_config = model_config + self.backbone = _BACKBONES[architecture](**self._prepare_config(model_config)) + self.in_channels = getattr(self.backbone, "in_channels", model_config.get("in_channels", 4)) + self.out_channels = getattr(self.backbone, "out_channels", model_config.get("in_channels", 4)) + + def _prepare_config(self, model_config: Dict[str, Any]) -> Dict[str, Any]: + prepared = {} + for key, value in model_config.items(): + prepared[key] = AttrDict.from_data(value) + return prepared + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + class_labels: Optional[torch.LongTensor] = None, + context: Optional[torch.LongTensor] = None, + return_dict: bool = True, + **kwargs, + ) -> Union[ProMoETransformer2DModelOutput, Tuple[torch.Tensor, ...]]: + labels = class_labels if class_labels is not None else context + if labels is None: + raise ValueError("Either `class_labels` or `context` must be provided.") + + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=hidden_states.device, dtype=hidden_states.dtype) + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype).flatten() + if timestep.numel() == 1: + timestep = timestep.repeat(labels.shape[0]) + + sample = self.backbone(hidden_states, timestep, labels, **kwargs) + if isinstance(sample, tuple): + if len(sample) == 6 and sample[1] == "Capacity_Pred": + output = ProMoETransformer2DModelOutput( + sample=sample[0], + loss_strategy=sample[1], + layer_idx_list=tuple(sample[2]), + ones_list=tuple(sample[3]), + pred_c_list=tuple(sample[4]), + capacity_pred_loss_weight=float(sample[5]), + ) + else: + output = ProMoETransformer2DModelOutput(sample=sample[0]) + else: + output = ProMoETransformer2DModelOutput(sample=sample) + + if not return_dict: + if 
output.loss_strategy is None: + return (output.sample,) + return ( + output.sample, + output.loss_strategy, + output.layer_idx_list, + output.ones_list, + output.pred_c_list, + output.capacity_pred_loss_weight, + ) + return output diff --git a/ProMoE-B-256/vae/config.json b/ProMoE-B-256/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..0db26717579be63eb0ddbf15b43faa43700dfe5a --- /dev/null +++ b/ProMoE-B-256/vae/config.json @@ -0,0 +1,29 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.4.2", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ] +} diff --git a/ProMoE-B-256/vae/diffusion_pytorch_model.safetensors b/ProMoE-B-256/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..d6fc2b1f7ae2b1f4f83c25812f819a17473f0c1a --- /dev/null +++ b/ProMoE-B-256/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec +size 334643268 diff --git a/ProMoE-L-256/model_index.json b/ProMoE-L-256/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..c38159a94f58f8ec4828be64008655e80b74d2c2 --- /dev/null +++ b/ProMoE-L-256/model_index.json @@ -0,0 +1,1021 @@ +{ + "_class_name": [ + "pipeline", + "ProMoEPipeline" + ], + "_diffusers_version": "0.36.0", + "id2label": { + "0": "tench, Tinca tinca", + "1": "goldfish, Carassius auratus", + "10": "brambling, Fringilla montifringilla", + "100": "black swan, Cygnus atratus", + "101": "tusker", + 
"102": "echidna, spiny anteater, anteater", + "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "104": "wallaby, brush kangaroo", + "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "106": "wombat", + "107": "jellyfish", + "108": "sea anemone, anemone", + "109": "brain coral", + "11": "goldfinch, Carduelis carduelis", + "110": "flatworm, platyhelminth", + "111": "nematode, nematode worm, roundworm", + "112": "conch", + "113": "snail", + "114": "slug", + "115": "sea slug, nudibranch", + "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "117": "chambered nautilus, pearly nautilus, nautilus", + "118": "Dungeness crab, Cancer magister", + "119": "rock crab, Cancer irroratus", + "12": "house finch, linnet, Carpodacus mexicanus", + "120": "fiddler crab", + "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "124": "crayfish, crawfish, crawdad, crawdaddy", + "125": "hermit crab", + "126": "isopod", + "127": "white stork, Ciconia ciconia", + "128": "black stork, Ciconia nigra", + "129": "spoonbill", + "13": "junco, snowbird", + "130": "flamingo", + "131": "little blue heron, Egretta caerulea", + "132": "American egret, great white heron, Egretta albus", + "133": "bittern", + "134": "crane", + "135": "limpkin, Aramus pictus", + "136": "European gallinule, Porphyrio porphyrio", + "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", + "138": "bustard", + "139": "ruddy turnstone, Arenaria interpres", + "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "140": "red-backed sandpiper, dunlin, Erolia alpina", + "141": "redshank, Tringa totanus", + "142": "dowitcher", + "143": "oystercatcher, oyster catcher", + "144": "pelican", + 
"145": "king penguin, Aptenodytes patagonica", + "146": "albatross, mollymawk", + "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "149": "dugong, Dugong dugon", + "15": "robin, American robin, Turdus migratorius", + "150": "sea lion", + "151": "Chihuahua", + "152": "Japanese spaniel", + "153": "Maltese dog, Maltese terrier, Maltese", + "154": "Pekinese, Pekingese, Peke", + "155": "Shih-Tzu", + "156": "Blenheim spaniel", + "157": "papillon", + "158": "toy terrier", + "159": "Rhodesian ridgeback", + "16": "bulbul", + "160": "Afghan hound, Afghan", + "161": "basset, basset hound", + "162": "beagle", + "163": "bloodhound, sleuthhound", + "164": "bluetick", + "165": "black-and-tan coonhound", + "166": "Walker hound, Walker foxhound", + "167": "English foxhound", + "168": "redbone", + "169": "borzoi, Russian wolfhound", + "17": "jay", + "170": "Irish wolfhound", + "171": "Italian greyhound", + "172": "whippet", + "173": "Ibizan hound, Ibizan Podenco", + "174": "Norwegian elkhound, elkhound", + "175": "otterhound, otter hound", + "176": "Saluki, gazelle hound", + "177": "Scottish deerhound, deerhound", + "178": "Weimaraner", + "179": "Staffordshire bullterrier, Staffordshire bull terrier", + "18": "magpie", + "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "181": "Bedlington terrier", + "182": "Border terrier", + "183": "Kerry blue terrier", + "184": "Irish terrier", + "185": "Norfolk terrier", + "186": "Norwich terrier", + "187": "Yorkshire terrier", + "188": "wire-haired fox terrier", + "189": "Lakeland terrier", + "19": "chickadee", + "190": "Sealyham terrier, Sealyham", + "191": "Airedale, Airedale terrier", + "192": "cairn, cairn terrier", + "193": "Australian terrier", + "194": "Dandie Dinmont, Dandie Dinmont terrier", + "195": "Boston bull, Boston terrier", + "196": "miniature schnauzer", 
+ "197": "giant schnauzer", + "198": "standard schnauzer", + "199": "Scotch terrier, Scottish terrier, Scottie", + "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "20": "water ouzel, dipper", + "200": "Tibetan terrier, chrysanthemum dog", + "201": "silky terrier, Sydney silky", + "202": "soft-coated wheaten terrier", + "203": "West Highland white terrier", + "204": "Lhasa, Lhasa apso", + "205": "flat-coated retriever", + "206": "curly-coated retriever", + "207": "golden retriever", + "208": "Labrador retriever", + "209": "Chesapeake Bay retriever", + "21": "kite", + "210": "German short-haired pointer", + "211": "vizsla, Hungarian pointer", + "212": "English setter", + "213": "Irish setter, red setter", + "214": "Gordon setter", + "215": "Brittany spaniel", + "216": "clumber, clumber spaniel", + "217": "English springer, English springer spaniel", + "218": "Welsh springer spaniel", + "219": "cocker spaniel, English cocker spaniel, cocker", + "22": "bald eagle, American eagle, Haliaeetus leucocephalus", + "220": "Sussex spaniel", + "221": "Irish water spaniel", + "222": "kuvasz", + "223": "schipperke", + "224": "groenendael", + "225": "malinois", + "226": "briard", + "227": "kelpie", + "228": "komondor", + "229": "Old English sheepdog, bobtail", + "23": "vulture", + "230": "Shetland sheepdog, Shetland sheep dog, Shetland", + "231": "collie", + "232": "Border collie", + "233": "Bouvier des Flandres, Bouviers des Flandres", + "234": "Rottweiler", + "235": "German shepherd, German shepherd dog, German police dog, alsatian", + "236": "Doberman, Doberman pinscher", + "237": "miniature pinscher", + "238": "Greater Swiss Mountain dog", + "239": "Bernese mountain dog", + "24": "great grey owl, great gray owl, Strix nebulosa", + "240": "Appenzeller", + "241": "EntleBucher", + "242": "boxer", + "243": "bull mastiff", + "244": "Tibetan mastiff", + "245": "French bulldog", + "246": "Great Dane", + "247": "Saint Bernard, St Bernard", 
+ "248": "Eskimo dog, husky", + "249": "malamute, malemute, Alaskan malamute", + "25": "European fire salamander, Salamandra salamandra", + "250": "Siberian husky", + "251": "dalmatian, coach dog, carriage dog", + "252": "affenpinscher, monkey pinscher, monkey dog", + "253": "basenji", + "254": "pug, pug-dog", + "255": "Leonberg", + "256": "Newfoundland, Newfoundland dog", + "257": "Great Pyrenees", + "258": "Samoyed, Samoyede", + "259": "Pomeranian", + "26": "common newt, Triturus vulgaris", + "260": "chow, chow chow", + "261": "keeshond", + "262": "Brabancon griffon", + "263": "Pembroke, Pembroke Welsh corgi", + "264": "Cardigan, Cardigan Welsh corgi", + "265": "toy poodle", + "266": "miniature poodle", + "267": "standard poodle", + "268": "Mexican hairless", + "269": "timber wolf, grey wolf, gray wolf, Canis lupus", + "27": "eft", + "270": "white wolf, Arctic wolf, Canis lupus tundrarum", + "271": "red wolf, maned wolf, Canis rufus, Canis niger", + "272": "coyote, prairie wolf, brush wolf, Canis latrans", + "273": "dingo, warrigal, warragal, Canis dingo", + "274": "dhole, Cuon alpinus", + "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "276": "hyena, hyaena", + "277": "red fox, Vulpes vulpes", + "278": "kit fox, Vulpes macrotis", + "279": "Arctic fox, white fox, Alopex lagopus", + "28": "spotted salamander, Ambystoma maculatum", + "280": "grey fox, gray fox, Urocyon cinereoargenteus", + "281": "tabby, tabby cat", + "282": "tiger cat", + "283": "Persian cat", + "284": "Siamese cat, Siamese", + "285": "Egyptian cat", + "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "287": "lynx, catamount", + "288": "leopard, Panthera pardus", + "289": "snow leopard, ounce, Panthera uncia", + "29": "axolotl, mud puppy, Ambystoma mexicanum", + "290": "jaguar, panther, Panthera onca, Felis onca", + "291": "lion, king of beasts, Panthera leo", + "292": "tiger, Panthera tigris", + "293": "cheetah, chetah, Acinonyx 
jubatus", + "294": "brown bear, bruin, Ursus arctos", + "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", + "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "297": "sloth bear, Melursus ursinus, Ursus ursinus", + "298": "mongoose", + "299": "meerkat, mierkat", + "3": "tiger shark, Galeocerdo cuvieri", + "30": "bullfrog, Rana catesbeiana", + "300": "tiger beetle", + "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "302": "ground beetle, carabid beetle", + "303": "long-horned beetle, longicorn, longicorn beetle", + "304": "leaf beetle, chrysomelid", + "305": "dung beetle", + "306": "rhinoceros beetle", + "307": "weevil", + "308": "fly", + "309": "bee", + "31": "tree frog, tree-frog", + "310": "ant, emmet, pismire", + "311": "grasshopper, hopper", + "312": "cricket", + "313": "walking stick, walkingstick, stick insect", + "314": "cockroach, roach", + "315": "mantis, mantid", + "316": "cicada, cicala", + "317": "leafhopper", + "318": "lacewing, lacewing fly", + "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "320": "damselfly", + "321": "admiral", + "322": "ringlet, ringlet butterfly", + "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "324": "cabbage butterfly", + "325": "sulphur butterfly, sulfur butterfly", + "326": "lycaenid, lycaenid butterfly", + "327": "starfish, sea star", + "328": "sea urchin", + "329": "sea cucumber, holothurian", + "33": "loggerhead, loggerhead turtle, Caretta caretta", + "330": "wood rabbit, cottontail, cottontail rabbit", + "331": "hare", + "332": "Angora, Angora rabbit", + "333": "hamster", + "334": "porcupine, hedgehog", + "335": "fox squirrel, eastern fox squirrel, Sciurus niger", + "336": "marmot", + "337": "beaver", + "338": "guinea pig, Cavia cobaya", + "339": "sorrel", + 
"34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "340": "zebra", + "341": "hog, pig, grunter, squealer, Sus scrofa", + "342": "wild boar, boar, Sus scrofa", + "343": "warthog", + "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "345": "ox", + "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "347": "bison", + "348": "ram, tup", + "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "35": "mud turtle", + "350": "ibex, Capra ibex", + "351": "hartebeest", + "352": "impala, Aepyceros melampus", + "353": "gazelle", + "354": "Arabian camel, dromedary, Camelus dromedarius", + "355": "llama", + "356": "weasel", + "357": "mink", + "358": "polecat, fitch, foulmart, foumart, Mustela putorius", + "359": "black-footed ferret, ferret, Mustela nigripes", + "36": "terrapin", + "360": "otter", + "361": "skunk, polecat, wood pussy", + "362": "badger", + "363": "armadillo", + "364": "three-toed sloth, ai, Bradypus tridactylus", + "365": "orangutan, orang, orangutang, Pongo pygmaeus", + "366": "gorilla, Gorilla gorilla", + "367": "chimpanzee, chimp, Pan troglodytes", + "368": "gibbon, Hylobates lar", + "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "37": "box turtle, box tortoise", + "370": "guenon, guenon monkey", + "371": "patas, hussar monkey, Erythrocebus patas", + "372": "baboon", + "373": "macaque", + "374": "langur", + "375": "colobus, colobus monkey", + "376": "proboscis monkey, Nasalis larvatus", + "377": "marmoset", + "378": "capuchin, ringtail, Cebus capucinus", + "379": "howler monkey, howler", + "38": "banded gecko", + "380": "titi, titi monkey", + "381": "spider monkey, Ateles geoffroyi", + "382": "squirrel monkey, Saimiri sciureus", + "383": "Madagascar cat, ring-tailed lemur, Lemur catta", + "384": "indri, indris, Indri indri, Indri brevicaudatus", + "385": "Indian elephant, Elephas maximus", + "386": "African elephant, 
Loxodonta africana", + "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "389": "barracouta, snoek", + "39": "common iguana, iguana, Iguana iguana", + "390": "eel", + "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "392": "rock beauty, Holocanthus tricolor", + "393": "anemone fish", + "394": "sturgeon", + "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", + "396": "lionfish", + "397": "puffer, pufferfish, blowfish, globefish", + "398": "abacus", + "399": "abaya", + "4": "hammerhead, hammerhead shark", + "40": "American chameleon, anole, Anolis carolinensis", + "400": "academic gown, academic robe, judge robe", + "401": "accordion, piano accordion, squeeze box", + "402": "acoustic guitar", + "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "404": "airliner", + "405": "airship, dirigible", + "406": "altar", + "407": "ambulance", + "408": "amphibian, amphibious vehicle", + "409": "analog clock", + "41": "whiptail, whiptail lizard", + "410": "apiary, bee house", + "411": "apron", + "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "413": "assault rifle, assault gun", + "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "415": "bakery, bakeshop, bakehouse", + "416": "balance beam, beam", + "417": "balloon", + "418": "ballpoint, ballpoint pen, ballpen, Biro", + "419": "Band Aid", + "42": "agama", + "420": "banjo", + "421": "bannister, banister, balustrade, balusters, handrail", + "422": "barbell", + "423": "barber chair", + "424": "barbershop", + "425": "barn", + "426": "barometer", + "427": "barrel, cask", + "428": "barrow, garden cart, lawn cart, wheelbarrow", + "429": "baseball", + "43": "frilled lizard, Chlamydosaurus kingi", + "430": "basketball", + "431": "bassinet", + "432": "bassoon", + "433": "bathing 
cap, swimming cap", + "434": "bath towel", + "435": "bathtub, bathing tub, bath, tub", + "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "437": "beacon, lighthouse, beacon light, pharos", + "438": "beaker", + "439": "bearskin, busby, shako", + "44": "alligator lizard", + "440": "beer bottle", + "441": "beer glass", + "442": "bell cote, bell cot", + "443": "bib", + "444": "bicycle-built-for-two, tandem bicycle, tandem", + "445": "bikini, two-piece", + "446": "binder, ring-binder", + "447": "binoculars, field glasses, opera glasses", + "448": "birdhouse", + "449": "boathouse", + "45": "Gila monster, Heloderma suspectum", + "450": "bobsled, bobsleigh, bob", + "451": "bolo tie, bolo, bola tie, bola", + "452": "bonnet, poke bonnet", + "453": "bookcase", + "454": "bookshop, bookstore, bookstall", + "455": "bottlecap", + "456": "bow", + "457": "bow tie, bow-tie, bowtie", + "458": "brass, memorial tablet, plaque", + "459": "brassiere, bra, bandeau", + "46": "green lizard, Lacerta viridis", + "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "461": "breastplate, aegis, egis", + "462": "broom", + "463": "bucket, pail", + "464": "buckle", + "465": "bulletproof vest", + "466": "bullet train, bullet", + "467": "butcher shop, meat market", + "468": "cab, hack, taxi, taxicab", + "469": "caldron, cauldron", + "47": "African chameleon, Chamaeleo chamaeleon", + "470": "candle, taper, wax light", + "471": "cannon", + "472": "canoe", + "473": "can opener, tin opener", + "474": "cardigan", + "475": "car mirror", + "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", + "477": "carpenters kit, tool kit", + "478": "carton", + "479": "car wheel", + "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "481": "cassette", + "482": "cassette 
player", + "483": "castle", + "484": "catamaran", + "485": "CD player", + "486": "cello, violoncello", + "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "488": "chain", + "489": "chainlink fence", + "49": "African crocodile, Nile crocodile, Crocodylus niloticus", + "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "491": "chain saw, chainsaw", + "492": "chest", + "493": "chiffonier, commode", + "494": "chime, bell, gong", + "495": "china cabinet, china closet", + "496": "Christmas stocking", + "497": "church, church building", + "498": "cinema, movie theater, movie theatre, movie house, picture palace", + "499": "cleaver, meat cleaver, chopper", + "5": "electric ray, crampfish, numbfish, torpedo", + "50": "American alligator, Alligator mississipiensis", + "500": "cliff dwelling", + "501": "cloak", + "502": "clog, geta, patten, sabot", + "503": "cocktail shaker", + "504": "coffee mug", + "505": "coffeepot", + "506": "coil, spiral, volute, whorl, helix", + "507": "combination lock", + "508": "computer keyboard, keypad", + "509": "confectionery, confectionary, candy store", + "51": "triceratops", + "510": "container ship, containership, container vessel", + "511": "convertible", + "512": "corkscrew, bottle screw", + "513": "cornet, horn, trumpet, trump", + "514": "cowboy boot", + "515": "cowboy hat, ten-gallon hat", + "516": "cradle", + "517": "crane", + "518": "crash helmet", + "519": "crate", + "52": "thunder snake, worm snake, Carphophis amoenus", + "520": "crib, cot", + "521": "Crock Pot", + "522": "croquet ball", + "523": "crutch", + "524": "cuirass", + "525": "dam, dike, dyke", + "526": "desk", + "527": "desktop computer", + "528": "dial telephone, dial phone", + "529": "diaper, nappy, napkin", + "53": "ringneck snake, ring-necked snake, ring snake", + "530": "digital clock", + "531": "digital watch", + "532": "dining table, board", + "533": "dishrag, dishcloth", + "534": "dishwasher, dish 
washer, dishwashing machine", + "535": "disk brake, disc brake", + "536": "dock, dockage, docking facility", + "537": "dogsled, dog sled, dog sleigh", + "538": "dome", + "539": "doormat, welcome mat", + "54": "hognose snake, puff adder, sand viper", + "540": "drilling platform, offshore rig", + "541": "drum, membranophone, tympan", + "542": "drumstick", + "543": "dumbbell", + "544": "Dutch oven", + "545": "electric fan, blower", + "546": "electric guitar", + "547": "electric locomotive", + "548": "entertainment center", + "549": "envelope", + "55": "green snake, grass snake", + "550": "espresso maker", + "551": "face powder", + "552": "feather boa, boa", + "553": "file, file cabinet, filing cabinet", + "554": "fireboat", + "555": "fire engine, fire truck", + "556": "fire screen, fireguard", + "557": "flagpole, flagstaff", + "558": "flute, transverse flute", + "559": "folding chair", + "56": "king snake, kingsnake", + "560": "football helmet", + "561": "forklift", + "562": "fountain", + "563": "fountain pen", + "564": "four-poster", + "565": "freight car", + "566": "French horn, horn", + "567": "frying pan, frypan, skillet", + "568": "fur coat", + "569": "garbage truck, dustcart", + "57": "garter snake, grass snake", + "570": "gasmask, respirator, gas helmet", + "571": "gas pump, gasoline pump, petrol pump, island dispenser", + "572": "goblet", + "573": "go-kart", + "574": "golf ball", + "575": "golfcart, golf cart", + "576": "gondola", + "577": "gong, tam-tam", + "578": "gown", + "579": "grand piano, grand", + "58": "water snake", + "580": "greenhouse, nursery, glasshouse", + "581": "grille, radiator grille", + "582": "grocery store, grocery, food market, market", + "583": "guillotine", + "584": "hair slide", + "585": "hair spray", + "586": "half track", + "587": "hammer", + "588": "hamper", + "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "59": "vine snake", + "590": "hand-held computer, hand-held microcomputer", + "591": "handkerchief, 
hankie, hanky, hankey", + "592": "hard disc, hard disk, fixed disk", + "593": "harmonica, mouth organ, harp, mouth harp", + "594": "harp", + "595": "harvester, reaper", + "596": "hatchet", + "597": "holster", + "598": "home theater, home theatre", + "599": "honeycomb", + "6": "stingray", + "60": "night snake, Hypsiglena torquata", + "600": "hook, claw", + "601": "hoopskirt, crinoline", + "602": "horizontal bar, high bar", + "603": "horse cart, horse-cart", + "604": "hourglass", + "605": "iPod", + "606": "iron, smoothing iron", + "607": "jack-o-lantern", + "608": "jean, blue jean, denim", + "609": "jeep, landrover", + "61": "boa constrictor, Constrictor constrictor", + "610": "jersey, T-shirt, tee shirt", + "611": "jigsaw puzzle", + "612": "jinrikisha, ricksha, rickshaw", + "613": "joystick", + "614": "kimono", + "615": "knee pad", + "616": "knot", + "617": "lab coat, laboratory coat", + "618": "ladle", + "619": "lampshade, lamp shade", + "62": "rock python, rock snake, Python sebae", + "620": "laptop, laptop computer", + "621": "lawn mower, mower", + "622": "lens cap, lens cover", + "623": "letter opener, paper knife, paperknife", + "624": "library", + "625": "lifeboat", + "626": "lighter, light, igniter, ignitor", + "627": "limousine, limo", + "628": "liner, ocean liner", + "629": "lipstick, lip rouge", + "63": "Indian cobra, Naja naja", + "630": "Loafer", + "631": "lotion", + "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "633": "loupe, jewelers loupe", + "634": "lumbermill, sawmill", + "635": "magnetic compass", + "636": "mailbag, postbag", + "637": "mailbox, letter box", + "638": "maillot", + "639": "maillot, tank suit", + "64": "green mamba", + "640": "manhole cover", + "641": "maraca", + "642": "marimba, xylophone", + "643": "mask", + "644": "matchstick", + "645": "maypole", + "646": "maze, labyrinth", + "647": "measuring cup", + "648": "medicine chest, medicine cabinet", + "649": "megalith, megalithic structure", + "65": 
"sea snake", + "650": "microphone, mike", + "651": "microwave, microwave oven", + "652": "military uniform", + "653": "milk can", + "654": "minibus", + "655": "miniskirt, mini", + "656": "minivan", + "657": "missile", + "658": "mitten", + "659": "mixing bowl", + "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "660": "mobile home, manufactured home", + "661": "Model T", + "662": "modem", + "663": "monastery", + "664": "monitor", + "665": "moped", + "666": "mortar", + "667": "mortarboard", + "668": "mosque", + "669": "mosquito net", + "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "670": "motor scooter, scooter", + "671": "mountain bike, all-terrain bike, off-roader", + "672": "mountain tent", + "673": "mouse, computer mouse", + "674": "mousetrap", + "675": "moving van", + "676": "muzzle", + "677": "nail", + "678": "neck brace", + "679": "necklace", + "68": "sidewinder, horned rattlesnake, Crotalus cerastes", + "680": "nipple", + "681": "notebook, notebook computer", + "682": "obelisk", + "683": "oboe, hautboy, hautbois", + "684": "ocarina, sweet potato", + "685": "odometer, hodometer, mileometer, milometer", + "686": "oil filter", + "687": "organ, pipe organ", + "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "689": "overskirt", + "69": "trilobite", + "690": "oxcart", + "691": "oxygen mask", + "692": "packet", + "693": "paddle, boat paddle", + "694": "paddlewheel, paddle wheel", + "695": "padlock", + "696": "paintbrush", + "697": "pajama, pyjama, pjs, jammies", + "698": "palace", + "699": "panpipe, pandean pipe, syrinx", + "7": "cock", + "70": "harvestman, daddy longlegs, Phalangium opilio", + "700": "paper towel", + "701": "parachute, chute", + "702": "parallel bars, bars", + "703": "park bench", + "704": "parking meter", + "705": "passenger car, coach, carriage", + "706": "patio, terrace", + "707": "pay-phone, pay-station", + "708": "pedestal, plinth, footstall", + "709": "pencil box, pencil case", 
+ "71": "scorpion", + "710": "pencil sharpener", + "711": "perfume, essence", + "712": "Petri dish", + "713": "photocopier", + "714": "pick, plectrum, plectron", + "715": "pickelhaube", + "716": "picket fence, paling", + "717": "pickup, pickup truck", + "718": "pier", + "719": "piggy bank, penny bank", + "72": "black and gold garden spider, Argiope aurantia", + "720": "pill bottle", + "721": "pillow", + "722": "ping-pong ball", + "723": "pinwheel", + "724": "pirate, pirate ship", + "725": "pitcher, ewer", + "726": "plane, carpenters plane, woodworking plane", + "727": "planetarium", + "728": "plastic bag", + "729": "plate rack", + "73": "barn spider, Araneus cavaticus", + "730": "plow, plough", + "731": "plunger, plumbers helper", + "732": "Polaroid camera, Polaroid Land camera", + "733": "pole", + "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "735": "poncho", + "736": "pool table, billiard table, snooker table", + "737": "pop bottle, soda bottle", + "738": "pot, flowerpot", + "739": "potters wheel", + "74": "garden spider, Aranea diademata", + "740": "power drill", + "741": "prayer rug, prayer mat", + "742": "printer", + "743": "prison, prison house", + "744": "projectile, missile", + "745": "projector", + "746": "puck, hockey puck", + "747": "punching bag, punch bag, punching ball, punchball", + "748": "purse", + "749": "quill, quill pen", + "75": "black widow, Latrodectus mactans", + "750": "quilt, comforter, comfort, puff", + "751": "racer, race car, racing car", + "752": "racket, racquet", + "753": "radiator", + "754": "radio, wireless", + "755": "radio telescope, radio reflector", + "756": "rain barrel", + "757": "recreational vehicle, RV, R.V.", + "758": "reel", + "759": "reflex camera", + "76": "tarantula", + "760": "refrigerator, icebox", + "761": "remote control, remote", + "762": "restaurant, eating house, eating place, eatery", + "763": "revolver, six-gun, six-shooter", + "764": "rifle", + "765": "rocking chair, 
rocker", + "766": "rotisserie", + "767": "rubber eraser, rubber, pencil eraser", + "768": "rugby ball", + "769": "rule, ruler", + "77": "wolf spider, hunting spider", + "770": "running shoe", + "771": "safe", + "772": "safety pin", + "773": "saltshaker, salt shaker", + "774": "sandal", + "775": "sarong", + "776": "sax, saxophone", + "777": "scabbard", + "778": "scale, weighing machine", + "779": "school bus", + "78": "tick", + "780": "schooner", + "781": "scoreboard", + "782": "screen, CRT screen", + "783": "screw", + "784": "screwdriver", + "785": "seat belt, seatbelt", + "786": "sewing machine", + "787": "shield, buckler", + "788": "shoe shop, shoe-shop, shoe store", + "789": "shoji", + "79": "centipede", + "790": "shopping basket", + "791": "shopping cart", + "792": "shovel", + "793": "shower cap", + "794": "shower curtain", + "795": "ski", + "796": "ski mask", + "797": "sleeping bag", + "798": "slide rule, slipstick", + "799": "sliding door", + "8": "hen", + "80": "black grouse", + "800": "slot, one-armed bandit", + "801": "snorkel", + "802": "snowmobile", + "803": "snowplow, snowplough", + "804": "soap dispenser", + "805": "soccer ball", + "806": "sock", + "807": "solar dish, solar collector, solar furnace", + "808": "sombrero", + "809": "soup bowl", + "81": "ptarmigan", + "810": "space bar", + "811": "space heater", + "812": "space shuttle", + "813": "spatula", + "814": "speedboat", + "815": "spider web, spiders web", + "816": "spindle", + "817": "sports car, sport car", + "818": "spotlight, spot", + "819": "stage", + "82": "ruffed grouse, partridge, Bonasa umbellus", + "820": "steam locomotive", + "821": "steel arch bridge", + "822": "steel drum", + "823": "stethoscope", + "824": "stole", + "825": "stone wall", + "826": "stopwatch, stop watch", + "827": "stove", + "828": "strainer", + "829": "streetcar, tram, tramcar, trolley, trolley car", + "83": "prairie chicken, prairie grouse, prairie fowl", + "830": "stretcher", + "831": "studio couch, day bed", + 
"832": "stupa, tope", + "833": "submarine, pigboat, sub, U-boat", + "834": "suit, suit of clothes", + "835": "sundial", + "836": "sunglass", + "837": "sunglasses, dark glasses, shades", + "838": "sunscreen, sunblock, sun blocker", + "839": "suspension bridge", + "84": "peacock", + "840": "swab, swob, mop", + "841": "sweatshirt", + "842": "swimming trunks, bathing trunks", + "843": "swing", + "844": "switch, electric switch, electrical switch", + "845": "syringe", + "846": "table lamp", + "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", + "848": "tape player", + "849": "teapot", + "85": "quail", + "850": "teddy, teddy bear", + "851": "television, television system", + "852": "tennis ball", + "853": "thatch, thatched roof", + "854": "theater curtain, theatre curtain", + "855": "thimble", + "856": "thresher, thrasher, threshing machine", + "857": "throne", + "858": "tile roof", + "859": "toaster", + "86": "partridge", + "860": "tobacco shop, tobacconist shop, tobacconist", + "861": "toilet seat", + "862": "torch", + "863": "totem pole", + "864": "tow truck, tow car, wrecker", + "865": "toyshop", + "866": "tractor", + "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "868": "tray", + "869": "trench coat", + "87": "African grey, African gray, Psittacus erithacus", + "870": "tricycle, trike, velocipede", + "871": "trimaran", + "872": "tripod", + "873": "triumphal arch", + "874": "trolleybus, trolley coach, trackless trolley", + "875": "trombone", + "876": "tub, vat", + "877": "turnstile", + "878": "typewriter keyboard", + "879": "umbrella", + "88": "macaw", + "880": "unicycle, monocycle", + "881": "upright, upright piano", + "882": "vacuum, vacuum cleaner", + "883": "vase", + "884": "vault", + "885": "velvet", + "886": "vending machine", + "887": "vestment", + "888": "viaduct", + "889": "violin, fiddle", + "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "890": "volleyball", + "891": 
"waffle iron", + "892": "wall clock", + "893": "wallet, billfold, notecase, pocketbook", + "894": "wardrobe, closet, press", + "895": "warplane, military plane", + "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "897": "washer, automatic washer, washing machine", + "898": "water bottle", + "899": "water jug", + "9": "ostrich, Struthio camelus", + "90": "lorikeet", + "900": "water tower", + "901": "whiskey jug", + "902": "whistle", + "903": "wig", + "904": "window screen", + "905": "window shade", + "906": "Windsor tie", + "907": "wine bottle", + "908": "wing", + "909": "wok", + "91": "coucal", + "910": "wooden spoon", + "911": "wool, woolen, woollen", + "912": "worm fence, snake fence, snake-rail fence, Virginia fence", + "913": "wreck", + "914": "yawl", + "915": "yurt", + "916": "web site, website, internet site, site", + "917": "comic book", + "918": "crossword puzzle, crossword", + "919": "street sign", + "92": "bee eater", + "920": "traffic light, traffic signal, stoplight", + "921": "book jacket, dust cover, dust jacket, dust wrapper", + "922": "menu", + "923": "plate", + "924": "guacamole", + "925": "consomme", + "926": "hot pot, hotpot", + "927": "trifle", + "928": "ice cream, icecream", + "929": "ice lolly, lolly, lollipop, popsicle", + "93": "hornbill", + "930": "French loaf", + "931": "bagel, beigel", + "932": "pretzel", + "933": "cheeseburger", + "934": "hotdog, hot dog, red hot", + "935": "mashed potato", + "936": "head cabbage", + "937": "broccoli", + "938": "cauliflower", + "939": "zucchini, courgette", + "94": "hummingbird", + "940": "spaghetti squash", + "941": "acorn squash", + "942": "butternut squash", + "943": "cucumber, cuke", + "944": "artichoke, globe artichoke", + "945": "bell pepper", + "946": "cardoon", + "947": "mushroom", + "948": "Granny Smith", + "949": "strawberry", + "95": "jacamar", + "950": "orange", + "951": "lemon", + "952": "fig", + "953": "pineapple, ananas", + "954": "banana", + "955": "jackfruit, jak, 
jack", + "956": "custard apple", + "957": "pomegranate", + "958": "hay", + "959": "carbonara", + "96": "toucan", + "960": "chocolate sauce, chocolate syrup", + "961": "dough", + "962": "meat loaf, meatloaf", + "963": "pizza, pizza pie", + "964": "potpie", + "965": "burrito", + "966": "red wine", + "967": "espresso", + "968": "cup", + "969": "eggnog", + "97": "drake", + "970": "alp", + "971": "bubble", + "972": "cliff, drop, drop-off", + "973": "coral reef", + "974": "geyser", + "975": "lakeside, lakeshore", + "976": "promontory, headland, head, foreland", + "977": "sandbar, sand bar", + "978": "seashore, coast, seacoast, sea-coast", + "979": "valley, vale", + "98": "red-breasted merganser, Mergus serrator", + "980": "volcano", + "981": "ballplayer, baseball player", + "982": "groom, bridegroom", + "983": "scuba diver", + "984": "rapeseed", + "985": "daisy", + "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "987": "corn", + "988": "acorn", + "989": "hip, rose hip, rosehip", + "99": "goose", + "990": "buckeye, horse chestnut, conker", + "991": "coral fungus", + "992": "agaric", + "993": "gyromitra", + "994": "stinkhorn, carrion fungus", + "995": "earthstar", + "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "997": "bolete", + "998": "ear, spike, capitulum", + "999": "toilet tissue, toilet paper, bathroom tissue" + }, + "scheduler": [ + "scheduling_flow_match_promoe", + "ProMoEFlowMatchScheduler" + ], + "transformer": [ + "transformer_promoe", + "ProMoETransformer2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/ProMoE-L-256/pipeline.py b/ProMoE-L-256/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a22aa2d52139703430ab9d7e7ebcc78db2d3d777 --- /dev/null +++ b/ProMoE-L-256/pipeline.py @@ -0,0 +1,259 @@ +"""Hub custom pipeline: ProMoEPipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. 
try:
    from diffusers.pipelines.pipeline_utils import DiffusionPipeline
except Exception:  # pragma: no cover
    # Deliberately broad: any failure while importing diffusers falls back to
    # the minimal stand-in below so this module can still be imported.
    class DiffusionPipeline:
        """Minimal stand-in used only when diffusers cannot be imported."""

        def __init__(self):
            self._execution_device = torch.device("cpu")

        def register_modules(self, **kwargs):
            """Attach every keyword argument to the pipeline as an attribute."""
            for key, value in kwargs.items():
                setattr(self, key, value)

        def to(self, device):
            """Record `device` and move the transformer/VAE to it when present."""
            self._execution_device = torch.device(device)
            for module in (getattr(self, "transformer", None), getattr(self, "vae", None)):
                if module is not None and hasattr(module, "to"):
                    module.to(device)
            return self

        def progress_bar(self, iterable):
            """Return `iterable` unchanged (no progress reporting in the stand-in)."""
            return iterable

        def maybe_free_model_hooks(self):
            """No-op: the stand-in installs no offload hooks."""
            return None


@dataclass
class ProMoEPipelineOutput:
    """Return value of `ProMoEPipeline.__call__` when `return_dict=True`."""

    # PIL images, a NumPy array, or a tensor, depending on `output_type`.
    # Written as a string so the annotation is never evaluated eagerly.
    images: "Union[List[Image.Image], np.ndarray, torch.Tensor]"


class ProMoEPipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with ProMoE.

    Parameters:
        transformer ([`ProMoETransformer2DModel`]):
            Class-conditional ProMoE transformer for flow-matching in latent space.
        scheduler ([`ProMoEFlowMatchScheduler`]):
            Flow-matching scheduler used during denoising.
        vae ([`AutoencoderKL`], *optional*):
            Variational autoencoder used to decode latents to pixels.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer->vae"
    _optional_components = ["vae"]

    def __init__(
        self,
        transformer,
        scheduler,
        vae=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        self._id2label = self._normalize_id2label(id2label)
        # Public synonym -> class id lookup, sorted by synonym.
        self.labels = self._build_label2id(self._id2label)
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Lazily load `id2label` from `model_index.json` if none was supplied.

        The lookup is retried on every call until a non-empty mapping is found.
        """
        if self._labels_loaded_from_model_index:
            return
        # The fallback base class defines no `config`; treat that as "no path"
        # instead of failing with AttributeError.
        config = getattr(self, "config", None)
        loaded = self._read_id2label_from_model_index(getattr(config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with integer keys (`{}` for `None`/empty)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the `id2label` table from `<variant_path>/model_index.json`.

        Returns `{}` when the path is empty, the file does not exist, or the
        file has no dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return ProMoEPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Build a synonym -> class id lookup, sorted by synonym.

        Each value is split on commas; the first non-empty entry is the class's
        primary name and the rest are secondary synonyms. A primary name always
        wins over another class's secondary synonym, so a class whose only name
        is e.g. "missile" is not shadowed by "projectile, missile". Collisions
        between names of equal rank keep the later class in iteration order.
        """
        primary: Dict[str, int] = {}
        secondary: Dict[str, int] = {}
        for class_id, value in id2label.items():
            names = [name for name in (part.strip() for part in value.split(",")) if name]
            for rank, name in enumerate(names):
                (primary if rank == 0 else secondary)[name] = int(class_id)
        # Primary names overwrite any colliding secondary synonym.
        label2id = {**secondary, **primary}
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings. Each string must match a synonym in `id2label`.

        Returns:
            `list[int]`: One class id per requested label, in input order.

        Raises:
            ValueError: If no labels are loaded or any label is unknown.
        """
        self._ensure_labels_loaded()
        label2id = self.labels
        if not label2id:
            raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")

        if isinstance(label, str):
            label = [label]

        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [label2id[item] for item in label]

    def _get_vae_spatial_downsample(self) -> int:
        """Return the VAE spatial downsampling factor (8 when there is no VAE).

        Computed as `2 ** (len(block_out_channels) - 1)` from the VAE config,
        falling back to a four-entry default when the config lacks that field.
        """
        if self.vae is None:
            return 8
        block_out_channels = getattr(getattr(self.vae, "config", None), "block_out_channels", [0, 0, 0, 0])
        return 2 ** (len(block_out_channels) - 1)

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        device: torch.device,
    ) -> torch.LongTensor:
        """Convert any supported `class_labels` input to a flat long tensor on `device`.

        Strings are resolved through `get_label_ids`; sequences may mix ids and
        label strings. Tensors are cast, moved and flattened.

        Raises:
            ValueError: If a label string is unknown or no labels are loaded.
        """
        if torch.is_tensor(class_labels):
            return class_labels.to(device=device, dtype=torch.long).reshape(-1)

        if isinstance(class_labels, (int, str)):
            class_labels = [class_labels]

        items = list(class_labels)
        if any(isinstance(item, str) for item in items):
            # Resolve all strings in one call so every unknown label is reported
            # together, then weave the ids back in at their original positions.
            names = [item for item in items if isinstance(item, str)]
            resolved = iter(self.get_label_ids(names))
            class_label_ids = [next(resolved) if isinstance(item, str) else int(item) for item in items]
        else:
            class_label_ids = items

        return torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)

    @staticmethod
    def _sample_noise(
        shape,
        generator: Optional[torch.Generator],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Draw standard-normal noise of `shape` and return it on `device`.

        When the generator lives on a different device type than `device`
        (typically a CPU generator with an accelerator pipeline), the noise is
        sampled on the generator's device and then moved.
        """
        device = torch.device(device)
        sample_device = device
        if generator is not None and generator.device.type != device.type:
            sample_device = generator.device
        return torch.randn(shape, generator=generator, device=sample_device, dtype=dtype).to(device)

    def _prepare_latents(
        self,
        batch_size: int,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> torch.Tensor:
        """Sample initial latents of shape `(batch_size, in_channels, latent_height, latent_width)`.

        A list of generators yields one sample per generator.

        Raises:
            ValueError: If a generator list's length differs from `batch_size`.
        """
        shape = (batch_size, self.transformer.in_channels, latent_height, latent_width)
        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"Got {len(generator)} generators for a batch of {batch_size} class labels; "
                    "pass one generator per sample or a single generator."
                )
            samples = [self._sample_noise((1, *shape[1:]), g, device, dtype) for g in generator]
            return torch.cat(samples, dim=0)
        return self._sample_noise(shape, generator, device, dtype)

    def _decode_latents(self, latents: torch.Tensor, output_type: str):
        """Turn final latents into the requested output format.

        `"latent"` returns the latents untouched. Otherwise the latents are
        decoded by the VAE when one is present (used directly when not), mapped
        from [-1, 1] to [0, 1], and returned as a tensor (`"pt"`), a
        channels-last NumPy array (`"np"`), or a list of PIL images (any other
        value).
        """
        if output_type == "latent":
            return latents
        if self.vae is not None:
            scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
            decode_dtype = next(self.vae.parameters()).dtype
            latents = (latents / scaling_factor).to(dtype=decode_dtype)
            image = self.vae.decode(latents, return_dict=False)[0]
        else:
            image = latents

        image = (image / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return image
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "np":
            return image
        return [Image.fromarray((img * 255).round().astype("uint8")) for img in image]

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ProMoEPipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with ProMoE.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
                ImageNet class indices or human-readable English label strings.
            height (`int`, defaults to 256):
                Output height in pixels; floor-divided by the VAE downsampling factor.
            width (`int`, defaults to 256):
                Output width in pixels; floor-divided by the VAE downsampling factor.
            num_inference_steps (`int`, defaults to 50):
                Number of steps passed to `scheduler.set_timesteps`.
            guidance_scale (`float`, defaults to 1.0):
                Classifier-free guidance is applied only when this is greater than 1.0.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                Source of randomness; a list must hold one generator per sample.
            output_type (`str`, defaults to `"pil"`):
                `"latent"`, `"pt"`, `"np"`; any other value yields PIL images.
            return_dict (`bool`, defaults to `True`):
                Return a `ProMoEPipelineOutput` instead of a one-element tuple.

        Returns:
            `ProMoEPipelineOutput` or `tuple`: The generated images.
        """
        device = self._execution_device if hasattr(self, "_execution_device") else torch.device("cpu")
        model_dtype = next(self.transformer.parameters()).dtype
        class_labels = self._normalize_class_labels(class_labels, device)
        batch_size = class_labels.shape[0]

        vae_scale = self._get_vae_spatial_downsample()
        latent_height = height // vae_scale
        latent_width = width // vae_scale
        latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator)

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        use_cfg = guidance_scale > 1.0
        if use_cfg:
            # The unconditional branch uses label index `num_classes` (1000 when
            # the embedder does not expose it). Only resolved when guidance is
            # on, and built once because it does not change between steps.
            null_labels = torch.full_like(
                class_labels, getattr(self.transformer.backbone.y_embedder, "num_classes", 1000)
            )
            labels = torch.cat([class_labels, null_labels], dim=0)
        else:
            labels = class_labels

        for t in self.progress_bar(self.scheduler.timesteps):
            latent_input = torch.cat([latents, latents], dim=0) if use_cfg else latents
            timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype)
            model_output = self.transformer(
                hidden_states=latent_input,
                timestep=timestep,
                class_labels=labels,
                return_dict=True,
            ).sample
            if model_output.shape[1] != latents.shape[1]:
                # Keep only the first half of the channels when the model emits
                # more channels than the latents have.
                model_output = model_output.chunk(2, dim=1)[0]
            if use_cfg:
                model_output_cond, model_output_uncond = model_output.chunk(2)
                model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
            latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample

        images = self._decode_latents(latents, output_type)
        self.maybe_free_model_hooks()
        if not return_dict:
            return (images,)
        return ProMoEPipelineOutput(images=images)
@@ +{ + "_class_name": "ProMoEFlowMatchScheduler", + "num_train_timesteps": 1000, + "shift": 1.0 +} diff --git a/ProMoE-L-256/scheduler/scheduler_config.json b/ProMoE-L-256/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..d57a6cefb17ef05cb172b2d55177ab379a67a715 --- /dev/null +++ b/ProMoE-L-256/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "ProMoEFlowMatchScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "shift": 1.0, + "stochastic_sampling": false +} diff --git a/ProMoE-L-256/scheduler/scheduling_flow_match_promoe.py b/ProMoE-L-256/scheduler/scheduling_flow_match_promoe.py new file mode 100644 index 0000000000000000000000000000000000000000..d71fe31541e09779d1afe32a7bcb9418a453e69f --- /dev/null +++ b/ProMoE-L-256/scheduler/scheduling_flow_match_promoe.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Optional + +import torch + +try: + from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +except Exception: # pragma: no cover + FlowMatchEulerDiscreteScheduler = None + + +@dataclass +class ProMoEFlowMatchSchedulerOutput: + prev_sample: torch.FloatTensor + + +if FlowMatchEulerDiscreteScheduler is not None: + + class ProMoEFlowMatchScheduler(FlowMatchEulerDiscreteScheduler): + pass + +else: + + class ProMoEFlowMatchScheduler: + def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0): + self.config = SimpleNamespace(num_train_timesteps=num_train_timesteps, shift=shift, stochastic_sampling=False) + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1, dtype=torch.float32) + + def set_timesteps(self, num_inference_steps: int, device: Optional[torch.device] = None): + self.timesteps = torch.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=torch.float32, + device=device, + ) + + def step(self, model_output, timestep, sample, generator=None): 
+ del generator + dt = 1.0 / max(len(self.timesteps), 1) + prev_sample = sample - dt * model_output + return ProMoEFlowMatchSchedulerOutput(prev_sample=prev_sample) diff --git a/ProMoE-L-256/transformer/backbone_diffmoe.py b/ProMoE-L-256/transformer/backbone_diffmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..91f8dfcec6a943fdb985195fa5c706fdb94a4293 --- /dev/null +++ b/ProMoE-L-256/transformer/backbone_diffmoe.py @@ -0,0 +1,302 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP_DiffMoE as MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class SparseMoEBlock(nn.Module): + def __init__( + self, + experts, + hidden_dim, + num_experts, + n_shared_experts=0, + capacity=2, + mlp_ratio=4.0, + use_diff_expert=False, + ): + super().__init__() + self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim))) + nn.init.normal_(self.gate_weight, std=0.006) + self.experts = nn.ModuleList(experts) + self.capacity = capacity + self.num_experts = num_experts + self.n_shared_experts = n_shared_experts + self.use_diff_expert = use_diff_expert + if use_diff_expert: + self.diff_expert = MoeMLP(hidden_size=hidden_dim, intermediate_size=int(hidden_dim * mlp_ratio)) + + self.capacity_predictor = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim, bias=True), + nn.SiLU(), + nn.Linear(hidden_dim, self.num_experts, bias=True), + ) + + if self.n_shared_experts > 0: + mlp_hidden_dim = int(hidden_dim * mlp_ratio * 2) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.shared_experts = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + + self.register_buffer("expert_threshold", torch.tensor([0.0] * num_experts)) + self.register_buffer("ema_decay", torch.tensor([0.95])) + + def forward(self, x): + if self.training: + 
return self.forward_train(x) + return self.forward_eval(x) + + def update_threshold(self, capacity_pred): + if not self.training: + return + capacity_pred = torch.sigmoid(capacity_pred) + seq_len = capacity_pred.size(0) + topk = int((seq_len / self.num_experts) * self.capacity) + threshold = self.expert_threshold + ema_decay = self.ema_decay + for i in range(self.num_experts): + scores, _ = torch.topk(capacity_pred[:, i], k=topk, dim=-1, sorted=True) + quantile = scores[-1].detach() + threshold[i] = threshold[i] * ema_decay + (1 - ema_decay) * quantile + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(threshold, op=dist.ReduceOp.SUM) + threshold /= dist.get_world_size() + self.expert_threshold = threshold + + def forward_train(self, x): + bsz, seq_len, hidden_dim = x.shape + identity = x + x = x.view(-1, hidden_dim) + total_tokens = x.shape[0] + capacity_pred = self.capacity_predictor(x.detach()) + k = int((total_tokens / self.num_experts) * self.capacity) + logits = F.linear(x, self.gate_weight, None) + scores = logits.softmax(dim=-1).permute(1, 0) + gating, index = torch.topk(scores, k=k, dim=-1, sorted=False) + mask = torch.zeros((self.num_experts, total_tokens), dtype=x.dtype, device=x.device) + mask.scatter_(1, index, 1.0) + expert_inputs = x[index] + expert_outputs = torch.stack([expert(expert_inputs[i]) for i, expert in enumerate(self.experts)]) + gated_outputs = gating.unsqueeze(-1) * expert_outputs + + y = torch.zeros((total_tokens * self.num_experts, hidden_dim), dtype=x.dtype, device=x.device) + offset = torch.arange(0, self.num_experts, device=x.device).unsqueeze(1) * total_tokens + flat_index = (index + offset.long()).view(-1) + y = torch.scatter(y, 0, flat_index.unsqueeze(1).expand(-1, hidden_dim), gated_outputs.view(-1, hidden_dim)) + y = y.view(self.num_experts, total_tokens, hidden_dim).sum(dim=0, keepdim=False) + + self.update_threshold(capacity_pred) + x_out = y.view(bsz, seq_len, hidden_dim) + ones = mask.permute(1, 
0).view(bsz, seq_len, self.num_experts) + capacity_pred = capacity_pred.view(bsz, seq_len, self.num_experts) + if self.n_shared_experts > 0: + x_out = x_out + self.shared_experts(identity) + if self.use_diff_expert: + x_out = x_out - self.diff_expert(identity) + return x_out, ones, capacity_pred + + def forward_eval(self, x): + bsz, seq_len, hidden_dim = x.shape + identity = x + x = x.view(-1, hidden_dim) + total_tokens = x.shape[0] + capacity_pred = torch.sigmoid(self.capacity_predictor(x.detach())) + threshold = self.expert_threshold + logits = F.linear(x, self.gate_weight, None) + scores = logits.softmax(dim=-1).permute(-1, -2) + y = torch.zeros_like(x, dtype=x.dtype) + for i, expert in enumerate(self.experts): + k_fixed = torch.where(capacity_pred[:, i] > threshold[i], 1, 0).sum() + gating, index = torch.topk(scores[i], k=k_fixed, dim=-1, sorted=False) + y[index, :] += gating.unsqueeze(-1) * expert(x[index, :]) + x_out = y.view(bsz, seq_len, hidden_dim) + if self.n_shared_experts > 0: + x_out = x_out + self.shared_experts(identity) + return x_out, None, None + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + qk_norm=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=qk_norm, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + experts = [ + Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + for _ in range(MoE_config.num_experts) + ] + else: + experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ 
in range(MoE_config.num_experts)] + self.mlp = SparseMoEBlock( + experts=experts, + hidden_dim=hidden_size, + num_experts=MoE_config.num_experts, + capacity=MoE_config.capacity, + n_shared_experts=MoE_config.n_shared_experts, + mlp_ratio=4.0, + ) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + if self.use_moe: + x_mlp, ones, pred_c = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + gate_mlp.unsqueeze(1) * x_mlp + return x, ones, pred_c + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x, None, None + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + CapacityPred_loss_weight=0.01, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.CapacityPred_loss_weight = CapacityPred_loss_weight + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = 
TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + self.capacity_schedule = MoE_config.get("capacity_schedule", None) + if self.capacity_schedule: + self.training_iters = -1 + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = 
w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + + if self.training and self.capacity_schedule: + num_experts = self.MoE_config.num_experts + capacity = self.MoE_config.capacity + stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters + stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters + if self.training_iters <= stage_i: + capacity = num_experts + elif self.training_iters <= stage_ii: + capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i) + for block in self.blocks: + if hasattr(block.mlp, "capacity"): + block.mlp.capacity = capacity + + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + ones_list, pred_c_list, layer_idx_list = [], [], [] + for layer_idx, block in enumerate(self.blocks): + x, ones, pred_c = block(x, c) + if ones is not None: + ones_list.append(ones) + pred_c_list.append(pred_c) + layer_idx_list.append(layer_idx) + x = self.final_layer(x, c) + x = self.unpatchify(x) + return x, "Capacity_Pred", layer_idx_list, ones_list, pred_c_list, self.CapacityPred_loss_weight diff --git a/ProMoE-L-256/transformer/backbone_dit.py b/ProMoE-L-256/transformer/backbone_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..d8fde70ff5dc640a9467dfd563e36419f722c7c1 --- /dev/null +++ b/ProMoE-L-256/transformer/backbone_dit.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class DiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, 
use_swiglu=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + head_dim=None, + use_swiglu=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + ) + 
for _ in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-L-256/transformer/backbone_ecdit.py b/ProMoE-L-256/transformer/backbone_ecdit.py new file mode 100644 index 
0000000000000000000000000000000000000000..5ae2c725ae1bac6c23a23cf467d444bed11b9f3d --- /dev/null +++ b/ProMoE-L-256/transformer/backbone_ecdit.py @@ -0,0 +1,220 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP_DiffMoE as MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class SparseMoEBlock(nn.Module): + def __init__(self, experts, hidden_dim, num_experts, n_shared_experts=0, capacity=2): + super().__init__() + self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim))) + nn.init.normal_(self.gate_weight, std=0.006) + self.experts = nn.ModuleList(experts) + self.capacity = capacity + self.num_experts = num_experts + self.n_shared_experts = n_shared_experts + if self.n_shared_experts > 0: + intermediate_size = hidden_dim * self.n_shared_experts + self.shared_experts = MoeMLP(hidden_size=hidden_dim, intermediate_size=intermediate_size, pretraining_tp=2) + + def forward(self, x): + identity = x + batch_size, seq_len, _ = x.shape + logits = F.linear(x, self.gate_weight, None) + affinity = logits.softmax(dim=-1) + affinity = torch.einsum("b s e -> b e s", affinity) + k = int((seq_len / self.num_experts) * self.capacity) + gating, index = torch.topk(affinity, k=k, dim=-1, sorted=False) + dispatch = F.one_hot(index, num_classes=seq_len).to(device=x.device, dtype=x.dtype) + x_in = torch.einsum("b e c s, b s d -> b e c d", dispatch, x) + x_e = [self.experts[e](x_in[:, e]) for e in range(self.num_experts)] + x_e = torch.stack(x_e, dim=1) + x_out = torch.einsum("b e c s, b e c, b e c d -> b s d", dispatch, gating, x_e) + if self.n_shared_experts > 0: + x_out = x_out + self.shared_experts(identity) + return x_out + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + **block_kwargs, + 
): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if use_moe: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + experts = [ + Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + for _ in range(MoE_config.num_experts) + ] + else: + experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)] + self.mlp = SparseMoEBlock( + experts=experts, + hidden_dim=hidden_size, + num_experts=MoE_config.num_experts, + capacity=MoE_config.capacity, + n_shared_experts=MoE_config.n_shared_experts, + ) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = 
in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + self.capacity_schedule = MoE_config.get("capacity_schedule", None) + if self.capacity_schedule: + self.training_iters = -1 + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + 
nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def init_moe_mlp(module, std=0.006): + nn.init.normal_(module.gate_proj.weight, std=std) + nn.init.normal_(module.up_proj.weight, std=std) + nn.init.normal_(module.down_proj.weight, std=std) + + if self.init_MoeMLP: + for block in self.blocks: + if hasattr(block.mlp, "experts"): + for expert in block.mlp.experts: + if hasattr(expert, "gate_proj"): + init_moe_mlp(expert) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + if self.training and self.capacity_schedule: + num_experts = self.MoE_config.num_experts + capacity = self.MoE_config.capacity + stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters + stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters + if self.training_iters <= stage_i: + capacity = num_experts + elif self.training_iters <= stage_ii: + capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i) + for block in self.blocks: + if hasattr(block.mlp, "capacity"): + block.mlp.capacity = capacity + + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-L-256/transformer/backbone_promoe_ec.py b/ProMoE-L-256/transformer/backbone_promoe_ec.py new file mode 100644 
index 0000000000000000000000000000000000000000..05da901ed601ca8e683ab5d55da0af3922534015 --- /dev/null +++ b/ProMoE-L-256/transformer/backbone_promoe_ec.py @@ -0,0 +1,286 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None + return grad_output, grad_loss + + +class SparseMoeBlock(nn.Module): + def __init__( + self, + num_routed_experts, + hidden_size, + moe_intermediate_size, + shared_expert_intermediate_size, + top_k=1, + load_balance_loss_coef=0, + norm_topk_prob=False, + seq_aux=False, + use_shared_expert=True, + use_uncond_expert=True, + router_weight_mode="softmax", + routing_contrastive_lam=0, + use_top_k_for_routing_contrastive=False, + routing_contrastive_temperature=0.1, + **kwargs, + ): + super().__init__() + del load_balance_loss_coef, norm_topk_prob, seq_aux, use_top_k_for_routing_contrastive + self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts + self.num_routed_experts = num_routed_experts + self.hidden_size = hidden_size + self.top_k = top_k + self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size)) + self.use_shared_expert = use_shared_expert + self.use_uncond_expert = use_uncond_expert + self.router_weight_mode = router_weight_mode + self.routing_contrastive_lam = routing_contrastive_lam + self.routing_contrastive_temperature = routing_contrastive_temperature + self.experts = nn.ModuleList( + [MoeMLP(hidden_size=hidden_size, 
intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)] + ) + if use_shared_expert: + self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size) + self._init_weights() + + def compute_router(self, cond_hidden_states): + b_cond, seq_len, _ = cond_hidden_states.shape + num_cond_experts = self.num_routed_experts + input_norm = F.normalize(cond_hidden_states, p=2, dim=-1) + cluster_norm = F.normalize(self.cluster_centers, p=2, dim=-1) + cos_sim = input_norm @ cluster_norm.T + cos_sim_expert_view = cos_sim.transpose(1, 2) + if self.router_weight_mode == "softmax": + cond_weights = F.softmax(cos_sim_expert_view, dim=-1) + elif self.router_weight_mode == "sigmoid": + cond_weights = torch.sigmoid(cos_sim_expert_view) + elif self.router_weight_mode == "identity": + cond_weights = cos_sim_expert_view + else: + raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}") + k = max(1, min(int((seq_len / num_cond_experts) * self.top_k), seq_len)) + router_weights, indices = torch.topk(cond_weights, k=k, dim=-1, sorted=False) + dispatch_mask = F.one_hot(indices, num_classes=seq_len).to(dtype=cond_hidden_states.dtype) + expert_inputs = torch.einsum("becs,bsd->becd", dispatch_mask, cond_hidden_states) + return dispatch_mask, router_weights, expert_inputs + + def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor): + identity = hidden_states + batch_size, _, hidden_dim = hidden_states.shape + final_output = torch.zeros_like(hidden_states) + loss = None + cond_batch_mask = ( + labels.view(-1) != 1000 + ) if self.use_uncond_expert else torch.ones(batch_size, dtype=torch.bool, device=hidden_states.device) + uncond_batch_mask = ~cond_batch_mask + cond_experts = self.experts[:-1] if self.use_uncond_expert else self.experts + + if cond_batch_mask.any(): + cond_hidden_states = hidden_states[cond_batch_mask] + dispatch_mask, gating_scores, expert_inputs = self.compute_router(cond_hidden_states) + 
num_cond_experts = len(cond_experts) + expert_outputs = torch.stack([cond_experts[e](expert_inputs[:, e]) for e in range(num_cond_experts)], dim=1) + cond_output = torch.einsum("becs,bec,becd->bsd", dispatch_mask, gating_scores, expert_outputs).to(hidden_states.dtype) + final_output[cond_batch_mask] = cond_output + if self.training and self.routing_contrastive_lam > 0 and num_cond_experts > 1: + expert_token_means = expert_inputs.mean(dim=2) + routing_contrastive_loss = self.compute_routing_contrastive_loss(expert_token_means) + loss = routing_contrastive_loss * self.routing_contrastive_lam + else: + dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + for expert in cond_experts: + final_output = final_output + expert(dummy_input).sum() * 0 + + if self.use_uncond_expert: + if uncond_batch_mask.any(): + uncond_hidden_states = hidden_states[uncond_batch_mask] + final_output[uncond_batch_mask] = self.experts[-1](uncond_hidden_states).to(final_output.dtype) + else: + dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + final_output = final_output + self.experts[-1](dummy_input).sum() * 0 + + if self.use_shared_expert: + final_output += self.shared_expert(identity).to(hidden_states.dtype) + return final_output, loss + + def compute_routing_contrastive_loss(self, expert_token_means): + batch_size, num_cond_experts, _ = expert_token_means.shape + if num_cond_experts < 2: + return torch.tensor(0.0, device=expert_token_means.device) + centers_norm = F.normalize(self.cluster_centers, p=2, dim=1) + means_norm = F.normalize(expert_token_means, p=2, dim=2) + sim_matrix = torch.einsum("id,bjd->bij", centers_norm, means_norm) + logits = sim_matrix / self.routing_contrastive_temperature + labels = torch.arange(num_cond_experts, device=logits.device).unsqueeze(0).expand(batch_size, -1) + return F.cross_entropy(logits.reshape(batch_size * num_cond_experts, -1), labels.reshape(-1)) + + 
class DiTBlock(nn.Module):
    """Transformer block with adaLN modulation and an optional MoE feed-forward.

    The conditioning vector ``c`` is projected into six chunks (shift, scale
    and gate for the attention branch and for the feed-forward branch).
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        head_dim=None,
        mlp_ratio=4.0,
        use_swiglu=False,
        MoE_config=None,
        use_moe=False,
        **block_kwargs,
    ):
        """Build the attention, feed-forward and modulation sub-modules.

        Args:
            hidden_size: Width of the token embeddings.
            num_heads: Number of attention heads.
            head_dim: Optional per-head width forwarded to ``Attention``.
            mlp_ratio: Feed-forward width as a multiple of ``hidden_size``.
            use_swiglu: Selects ``MoeMLP`` instead of ``Mlp`` for the dense path.
            MoE_config: Mapping expanded into ``SparseMoeBlock`` when ``use_moe``.
            use_moe: Whether the feed-forward is a ``SparseMoeBlock``.
            **block_kwargs: Extra keyword arguments forwarded to ``Attention``.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.use_moe = use_moe
        self.mlp = self._build_mlp(hidden_size, int(hidden_size * mlp_ratio), use_swiglu, MoE_config)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def _build_mlp(self, hidden_size, ff_width, use_swiglu, moe_config):
        """Create the feed-forward module selected by the constructor flags."""
        if self.use_moe:
            return SparseMoeBlock(hidden_size=hidden_size, **moe_config)
        if use_swiglu:
            return MoeMLP(hidden_size=hidden_size, intermediate_size=ff_width)

        def tanh_gelu():
            return nn.GELU(approximate="tanh")

        return Mlp(in_features=hidden_size, hidden_features=ff_width, act_layer=tanh_gelu, drop=0)

    def forward(self, x, c, label):
        """Apply gated attention and gated feed-forward residual updates.

        Args:
            x: Token tensor; the gates are broadcast over its dimension 1.
            c: Conditioning tensor, chunked along dimension 1 into six parts.
            label: Class labels, passed to the MoE block only.

        Returns:
            The updated token tensor.
        """
        chunks = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks

        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out

        ff_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        if not self.use_moe:
            return x + gate_mlp.unsqueeze(1) * self.mlp(ff_in)

        ff_out, aux_loss = self.mlp(ff_in, label)
        if aux_loss is not None:
            # Attach the auxiliary loss so its gradient flows during backward.
            ff_out = AddAuxiliaryLoss.apply(ff_out, aux_loss)
        return x + gate_mlp.unsqueeze(1) * ff_out
learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + 
def unpatchify(self, x):
    """Fold a sequence of flattened patches back into an image-like tensor.

    Args:
        x: Tensor indexed as (batch, tokens, patch * patch * channels); the
            token count is treated as a square grid.

    Returns:
        Tensor of shape (batch, channels, grid * patch, grid * patch).
    """
    channels = self.out_channels
    patch = self.x_embedder.patch_size[0]
    batch = x.shape[0]
    grid = int(x.shape[1] ** 0.5)

    tiles = x.reshape(shape=(batch, grid, grid, patch, patch, channels))
    # Move channels first and interleave grid rows/cols with in-patch rows/cols.
    tiles = torch.einsum("nhwpqc->nchpwq", tiles)
    side = grid * patch
    return tiles.reshape(shape=(batch, channels, side, side))
def compute_router(self, hidden_states, labels):
    """Assign every token its top-k experts (token-choice routing).

    Tokens whose label equals 1000 are sent to the last expert with weight 1
    when ``self.use_uncond_expert`` is set. All other tokens are scored by
    cosine similarity against ``self.cluster_centers``.

    Args:
        hidden_states: Tensor indexed as (batch, seq, hidden); it is flattened
            with ``view(-1, self.hidden_size)``.
        labels: One label per sample (reshaped to ``(batch, 1)``).

    Returns:
        ``(router_weights, expert_indices, load_balance_loss)``. The first two
        have shape (batch, seq, top_k). The loss is ``None`` unless the module
        is training with ``self.alpha > 0`` and at least one token was routed.

    Raises:
        ValueError: if ``self.router_weight_mode`` is unsupported.
    """
    batch_size, seq_len, _ = hidden_states.shape
    device = hidden_states.device
    num_tokens = batch_size * seq_len
    flat_input = hidden_states.view(-1, self.hidden_size)
    flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1)

    if self.use_uncond_expert:
        uncond_mask = flat_labels == 1000
        cond_mask = ~uncond_mask
    else:
        uncond_mask = None
        cond_mask = torch.ones_like(flat_labels, dtype=torch.bool)

    router_weights = torch.zeros(num_tokens, self.top_k, device=device, dtype=hidden_states.dtype)
    expert_indices = torch.zeros(num_tokens, self.top_k, device=device, dtype=torch.long)

    if uncond_mask is not None and uncond_mask.any():
        uncond_positions = torch.where(uncond_mask)[0]
        # Only the first slot carries weight; all slots point at the last expert.
        router_weights[uncond_positions, 0] = 1.0
        expert_indices[uncond_positions] = self.num_experts - 1

    cond_weights = None
    topk_idx = None
    if cond_mask.any():
        cond_positions = torch.where(cond_mask)[0]
        input_norm = F.normalize(flat_input[cond_positions], p=2, dim=1)
        cluster_norm = F.normalize(self.cluster_centers, p=2, dim=1)
        cos_sim = input_norm @ cluster_norm.T
        if self.router_weight_mode == "softmax":
            cond_weights = F.softmax(cos_sim, dim=1)
        elif self.router_weight_mode == "sigmoid":
            cond_weights = torch.sigmoid(cos_sim)
        elif self.router_weight_mode == "identity":
            cond_weights = cos_sim
        else:
            raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}")
        topk_scores, topk_idx = torch.topk(cond_weights, k=self.top_k, dim=1)
        router_weights[cond_positions] = topk_scores.to(router_weights.dtype)
        expert_indices[cond_positions] = topk_idx

    router_weights = router_weights.view(batch_size, seq_len, self.top_k)
    expert_indices = expert_indices.view(batch_size, seq_len, self.top_k)

    load_balance_loss = None
    if self.training and self.alpha > 0.0 and cond_weights is not None and topk_idx is not None:
        # Derive the routed-sample count from the tokens actually routed.
        # Counting `labels != 1000` is wrong when use_uncond_expert is False,
        # because label-1000 samples are still routed in that configuration.
        cond_batch_size = topk_idx.shape[0] // seq_len
        if self.router_weight_mode != "softmax":
            scores_for_aux = F.softmax(cond_weights, dim=1)
        else:
            scores_for_aux = cond_weights
        topk_idx_for_aux_loss = topk_idx.view(cond_batch_size, -1)
        if self.seq_aux:
            # Per-sample balance: expert usage counts, scaled so a uniform
            # assignment gives 1, times the mean routing score.
            scores_for_seq_aux = scores_for_aux.view(cond_batch_size, seq_len, -1)
            ce = torch.zeros(cond_batch_size, self.num_routed_experts, device=device)
            ce.scatter_add_(
                1,
                topk_idx_for_aux_loss,
                torch.ones(cond_batch_size, seq_len * self.top_k, device=device),
            ).div_(seq_len * self.top_k / self.num_routed_experts)
            load_balance_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
        else:
            # Batch-level balance over all routed tokens at once.
            mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.num_routed_experts)
            ce = mask_ce.float().mean(0)
            pi = scores_for_aux.mean(0)
            fi = ce * self.num_routed_experts
            load_balance_loss = (pi * fi).sum() * self.alpha
    return router_weights, expert_indices, load_balance_loss
def compute_routing_contrastive_loss(self, token_embeddings, cluster_assignments, use_top_k=False):
    """Contrast each used expert prototype against the mean of its tokens.

    Args:
        token_embeddings: 2-D tensor with one row per token.
        cluster_assignments: Expert id per token; 2-D (token, slot) when
            ``use_top_k`` is true, otherwise 1-D.
        use_top_k: Treat a token as belonging to an expert if any of its
            slots selects that expert.

    Returns:
        A scalar tensor; a constant ``0.0`` when fewer than two experts
        received tokens.
    """
    centers = self.cluster_centers
    device = centers.device

    used_ids = []
    per_expert_means = []
    for expert_id in range(centers.size(0)):
        members = cluster_assignments == expert_id
        if use_top_k:
            members = members.any(dim=1)
        if members.sum() > 0:
            used_ids.append(expert_id)
            per_expert_means.append(token_embeddings[members].mean(dim=0, keepdim=True))

    if len(used_ids) < 2:
        # A contrastive objective needs at least two populated experts.
        return torch.tensor(0.0, device=device)

    unit_means = F.normalize(torch.cat(per_expert_means, dim=0), p=2, dim=1)
    unit_centers = F.normalize(centers[used_ids], p=2, dim=1)
    similarity = unit_centers @ unit_means.T
    logits = similarity / self.routing_contrastive_temperature
    # The matching (prototype, mean) pair sits on the diagonal.
    targets = torch.arange(similarity.size(0), device=device)
    return F.cross_entropy(logits, targets)
head_dim=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + 
class MoEGate(nn.Module):
    """Linear softmax router that selects ``top_k`` experts per token."""

    def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
        """Create the routing weight matrix of shape (num_experts, embed_dim).

        Args:
            embed_dim: Width of the incoming token embeddings.
            num_experts: Number of routed experts.
            num_experts_per_tok: How many experts each token selects.
            aux_loss_alpha: Scale of the load-balancing auxiliary loss.
        """
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = num_experts
        self.scoring_func = "softmax"
        self.alpha = aux_loss_alpha
        self.seq_aux = False
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the routing matrix with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def _balance_loss(self, scores, topk_idx, bsz, seq_len, device):
        """Return the load-balancing loss for one forward pass."""
        n_experts = self.n_routed_experts
        idx_per_sample = topk_idx.view(bsz, -1)
        if not self.seq_aux:
            # Batch-level: fraction of assignments per expert times mean score.
            one_hot = F.one_hot(idx_per_sample.view(-1), num_classes=n_experts)
            fi = one_hot.float().mean(0) * n_experts
            pi = scores.mean(0)
            return (pi * fi).sum() * self.alpha

        # Sequence-level: usage counts per sample, scaled so uniform usage is 1.
        seq_scores = scores.view(bsz, seq_len, -1)
        ce = torch.zeros(bsz, n_experts, device=device)
        ce.scatter_add_(
            1,
            idx_per_sample,
            torch.ones(bsz, seq_len * self.top_k, device=device),
        ).div_(seq_len * self.top_k / n_experts)
        return (ce * seq_scores.mean(dim=1)).sum(dim=1).mean() * self.alpha

    def forward(self, hidden_states):
        """Score tokens and pick their experts.

        Args:
            hidden_states: Tensor indexed as (batch, seq, hidden).

        Returns:
            ``(topk_idx, topk_weight, aux_loss)``; the first two have one row
            per flattened token, and ``aux_loss`` is ``None`` outside training
            or when ``self.alpha`` is not positive.

        Raises:
            NotImplementedError: if ``self.scoring_func`` is not ``"softmax"``.
        """
        bsz, seq_len, width = hidden_states.shape
        tokens = hidden_states.view(-1, width)
        logits = F.linear(tokens, self.weight, None)
        if self.scoring_func != "softmax":
            raise NotImplementedError(f"Unsupported gating scoring function: {self.scoring_func}")
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        if self.top_k > 1 and self.norm_topk_prob:
            denom = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denom

        aux_loss = None
        if self.training and self.alpha > 0.0:
            aux_loss = self._balance_loss(scores, topk_idx, bsz, seq_len, tokens.device)
        return topk_idx, topk_weight, aux_loss
class SparseMoEBlock(nn.Module):
    """Token-choice mixture-of-experts layer with optional shared experts."""

    def __init__(
        self,
        experts,
        hidden_dim,
        mlp_ratio=4,
        num_experts=16,
        num_experts_per_tok=2,
        pretraining_tp=2,
        n_shared_experts=2,
    ):
        """Wrap the expert modules and create the gate.

        Args:
            experts: Iterable of expert modules, indexed by expert id.
            hidden_dim: Width of the token embeddings.
            mlp_ratio: Accepted for interface compatibility; not used here.
            num_experts: Number of experts the gate routes over.
            num_experts_per_tok: How many experts each token selects.
            pretraining_tp: Forwarded to the shared-expert ``MoeMLP``.
            n_shared_experts: Shared-expert width multiplier; ``0`` disables
                the shared branch.
        """
        super().__init__()
        self.top_k = num_experts_per_tok
        self.experts = nn.ModuleList(experts)
        self.gate = MoEGate(embed_dim=hidden_dim, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)
        self.n_shared_experts = n_shared_experts
        if self.n_shared_experts > 0:
            intermediate_size = hidden_dim * self.n_shared_experts
            self.shared_experts = MoeMLP(
                hidden_size=hidden_dim,
                intermediate_size=intermediate_size,
                pretraining_tp=pretraining_tp,
            )

    def forward(self, hidden_states):
        """Route tokens through their experts and add the shared branch.

        Args:
            hidden_states: Tensor whose last dimension is the hidden width.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One copy of each token per selected expert, in slot order.
            hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
            y = torch.empty_like(hidden_states)
            for i, expert in enumerate(self.experts):
                selected = flat_topk_idx == i
                y[selected] = expert(hidden_states[selected]).float()
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
            # The gate returns None when its auxiliary loss is disabled;
            # AddAuxiliaryLoss reads loss.dtype and would fail on None.
            if aux_loss is not None:
                y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.n_shared_experts > 0:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference path: run each expert once on the tokens assigned to it.

        Args:
            x: Flattened tokens, one row per token.
            flat_expert_indices: Expert id per (token, slot) pair, flattened.
            flat_expert_weights: Matching routing weights, shape (pairs, 1).

        Returns:
            Tensor shaped like ``x`` holding the weighted sum of expert outputs.
        """
        expert_cache = torch.zeros_like(x)
        # Sorting groups the (token, slot) pairs by expert id.
        idxs = flat_expert_indices.argsort()
        # Cumulative counts give the end offset of each expert's group.
        group_ends = flat_expert_indices.bincount().cumsum(0).tolist()
        token_idxs = idxs // self.top_k
        start_idx = 0
        for i, end_idx in enumerate(group_ends):
            if start_idx != end_idx:
                exp_token_idx = token_idxs[start_idx:end_idx]
                expert_out = self.experts[i](x[exp_token_idx])
                expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
                expert_cache = expert_cache.to(expert_out.dtype)
                # A token can appear in several groups, so accumulate.
                expert_cache.scatter_reduce_(
                    0,
                    exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
                    expert_out,
                    reduce="sum",
                )
            start_idx = end_idx
        return expert_cache
bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + pretraining_tp=1, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + pretraining_tp=pretraining_tp, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = 
get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-L-256/transformer/config.json b/ProMoE-L-256/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..3d2226c29a691d6c10d1d08abd50bef1d4008a1c --- /dev/null +++ b/ProMoE-L-256/transformer/config.json @@ -0,0 +1,22 @@ +{ + "_class_name": "ProMoETransformer2DModel", + "architecture": "promoe_tc", + "model_config": { + "MoE_config": { + "init_MoeMLP": false, + "interleave": true, + "moe_intermediate_size": 2048, + "num_routed_experts": 12, + 
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes."""

    def __getattr__(self, item):
        """Return ``self[item]``, raising AttributeError for a missing key."""
        try:
            value = self[item]
        except KeyError as error:
            raise AttributeError(item) from error
        return value

    def __setattr__(self, key, value):
        """Store ``value`` under ``key`` as a dictionary entry."""
        self[key] = value

    @staticmethod
    def from_data(data: Any) -> Any:
        """Recursively convert dicts (also inside lists) into ``AttrDict``.

        Values that are neither ``dict`` nor ``list`` are returned unchanged.
        """
        if isinstance(data, list):
            return [AttrDict.from_data(entry) for entry in data]
        if not isinstance(data, dict):
            return data
        converted = AttrDict()
        for key, value in data.items():
            converted[key] = AttrDict.from_data(value)
        return converted
class MoeMLP(nn.Module):
    """Two-layer feed-forward expert with a tanh-approximated GELU."""

    def __init__(self, hidden_size, intermediate_size):
        """Create the up and down projections.

        Args:
            hidden_size: Input and output width.
            intermediate_size: Width of the hidden activation.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.act_fn = nn.GELU(approximate="tanh")

    def forward(self, x):
        """Apply up-projection, activation and down-projection to ``x``."""
        expanded = self.up_proj(x)
        activated = self.act_fn(expanded)
        return self.down_proj(activated)
intermediate_size, pretraining_tp=2): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = nn.SiLU() + self.pretraining_tp = pretraining_tp + + def forward(self, x): + if self.pretraining_tp > 1: + split_size = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(split_size, dim=0) + up_proj_slices = self.up_proj.weight.split(split_size, dim=0) + down_proj_slices = self.down_proj.weight.split(split_size, dim=1) + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(split_size, dim=-1) + down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] + return sum(down_proj) + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + head_dim=None, + norm_layer: nn.Module = nn.LayerNorm, + ): + super().__init__() + self.num_heads = num_heads + if head_dim is None: + if dim % num_heads != 0: + raise ValueError("dim must be divisible by num_heads") + self.head_dim = dim // num_heads + else: + self.head_dim = head_dim + self.scale = self.head_dim**-0.5 + self.fused_attn = True + self.qkv = nn.Linear(dim, self.head_dim * self.num_heads * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if 
qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.head_dim * self.num_heads, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)).softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(batch_size, seq_len, -1) + x = self.proj(x) + return self.proj_drop(x) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t.float(), self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + return self.mlp(t_freq.to(dtype=weight_dtype)) + + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size, dropout_prob, return_labels=False): + 
super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + self.return_labels = return_labels + + def token_drop(self, labels, force_drop_ids=None): + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + return torch.where(drop_ids, self.num_classes, labels) + + def forward(self, labels, train, force_drop_ids=None): + if (train and self.dropout_prob > 0) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + if self.return_labels: + return embeddings, labels + return embeddings + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + return self.linear(x) + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be 
even") + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + return np.concatenate([emb_h, emb_w], axis=1) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + emb_sin = np.sin(out) + emb_cos = np.cos(out) + return np.concatenate([emb_sin, emb_cos], axis=1) diff --git a/ProMoE-L-256/transformer/transformer_promoe.py b/ProMoE-L-256/transformer/transformer_promoe.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6369fcdcb55c394ef5c7fa8d4d50b7b32ba145 --- /dev/null +++ b/ProMoE-L-256/transformer/transformer_promoe.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except Exception: # pragma: no cover + class BaseOutput(dict): + def __post_init__(self): + self.update(self.__dict__) + + class _Config(dict): + def __getattr__(self, key): + try: + return self[key] + except KeyError as error: + raise AttributeError(key) from error + + class ConfigMixin: + config_name = "config.json" + + class ModelMixin(nn.Module): + pass + + def register_to_config(init): + def wrapper(self, *args, **kwargs): + import inspect + + signature = inspect.signature(init) + bound = signature.bind(self, *args, **kwargs) + bound.apply_defaults() + self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"}) + init(self, *args, **kwargs) + + return wrapper + +from .backbone_diffmoe import DiT as DiffMoEBackbone +from .backbone_dit 
import DiT as DiTBackbone +from .backbone_ecdit import DiT as ECDiTBackbone +from .backbone_promoe_ec import DiT as ProMoEECBackbone +from .backbone_promoe_tc import DiT as ProMoETCBackbone +from .backbone_tcdit import DiT as TCDiTBackbone +from .modeling_promoe_common import AttrDict + + +@dataclass +class ProMoETransformer2DModelOutput(BaseOutput): + sample: torch.FloatTensor + loss_strategy: Optional[str] = None + layer_idx_list: Optional[Tuple[int, ...]] = None + ones_list: Optional[Tuple[torch.FloatTensor, ...]] = None + pred_c_list: Optional[Tuple[torch.FloatTensor, ...]] = None + capacity_pred_loss_weight: Optional[float] = None + + +_BACKBONES = { + "dit": DiTBackbone, + "tcdit": TCDiTBackbone, + "ecdit": ECDiTBackbone, + "diffmoe": DiffMoEBackbone, + "promoe_tc": ProMoETCBackbone, + "promoe_ec": ProMoEECBackbone, +} + + +class ProMoETransformer2DModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__(self, architecture: str = "promoe_tc", model_config: Optional[Dict[str, Any]] = None): + super().__init__() + if architecture not in _BACKBONES: + raise ValueError(f"Unsupported architecture: {architecture}. 
Valid: {sorted(_BACKBONES)}") + model_config = model_config or {} + self.architecture = architecture + self.model_config = model_config + self.backbone = _BACKBONES[architecture](**self._prepare_config(model_config)) + self.in_channels = getattr(self.backbone, "in_channels", model_config.get("in_channels", 4)) + self.out_channels = getattr(self.backbone, "out_channels", model_config.get("in_channels", 4)) + + def _prepare_config(self, model_config: Dict[str, Any]) -> Dict[str, Any]: + prepared = {} + for key, value in model_config.items(): + prepared[key] = AttrDict.from_data(value) + return prepared + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + class_labels: Optional[torch.LongTensor] = None, + context: Optional[torch.LongTensor] = None, + return_dict: bool = True, + **kwargs, + ) -> Union[ProMoETransformer2DModelOutput, Tuple[torch.Tensor, ...]]: + labels = class_labels if class_labels is not None else context + if labels is None: + raise ValueError("Either `class_labels` or `context` must be provided.") + + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=hidden_states.device, dtype=hidden_states.dtype) + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype).flatten() + if timestep.numel() == 1: + timestep = timestep.repeat(labels.shape[0]) + + sample = self.backbone(hidden_states, timestep, labels, **kwargs) + if isinstance(sample, tuple): + if len(sample) == 6 and sample[1] == "Capacity_Pred": + output = ProMoETransformer2DModelOutput( + sample=sample[0], + loss_strategy=sample[1], + layer_idx_list=tuple(sample[2]), + ones_list=tuple(sample[3]), + pred_c_list=tuple(sample[4]), + capacity_pred_loss_weight=float(sample[5]), + ) + else: + output = ProMoETransformer2DModelOutput(sample=sample[0]) + else: + output = ProMoETransformer2DModelOutput(sample=sample) + + if not return_dict: + if output.loss_strategy is None: + return (output.sample,) + return ( 
+ output.sample, + output.loss_strategy, + output.layer_idx_list, + output.ones_list, + output.pred_c_list, + output.capacity_pred_loss_weight, + ) + return output diff --git a/ProMoE-L-256/vae/config.json b/ProMoE-L-256/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..0db26717579be63eb0ddbf15b43faa43700dfe5a --- /dev/null +++ b/ProMoE-L-256/vae/config.json @@ -0,0 +1,29 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.4.2", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ] +} diff --git a/ProMoE-L-256/vae/diffusion_pytorch_model.safetensors b/ProMoE-L-256/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..d6fc2b1f7ae2b1f4f83c25812f819a17473f0c1a --- /dev/null +++ b/ProMoE-L-256/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec +size 334643268 diff --git a/ProMoE-XL-256/model_index.json b/ProMoE-XL-256/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..c38159a94f58f8ec4828be64008655e80b74d2c2 --- /dev/null +++ b/ProMoE-XL-256/model_index.json @@ -0,0 +1,1021 @@ +{ + "_class_name": [ + "pipeline", + "ProMoEPipeline" + ], + "_diffusers_version": "0.36.0", + "id2label": { + "0": "tench, Tinca tinca", + "1": "goldfish, Carassius auratus", + "10": "brambling, Fringilla montifringilla", + "100": "black swan, Cygnus atratus", + "101": "tusker", + "102": "echidna, spiny anteater, anteater", + "103": "platypus, 
duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "104": "wallaby, brush kangaroo", + "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "106": "wombat", + "107": "jellyfish", + "108": "sea anemone, anemone", + "109": "brain coral", + "11": "goldfinch, Carduelis carduelis", + "110": "flatworm, platyhelminth", + "111": "nematode, nematode worm, roundworm", + "112": "conch", + "113": "snail", + "114": "slug", + "115": "sea slug, nudibranch", + "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "117": "chambered nautilus, pearly nautilus, nautilus", + "118": "Dungeness crab, Cancer magister", + "119": "rock crab, Cancer irroratus", + "12": "house finch, linnet, Carpodacus mexicanus", + "120": "fiddler crab", + "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "124": "crayfish, crawfish, crawdad, crawdaddy", + "125": "hermit crab", + "126": "isopod", + "127": "white stork, Ciconia ciconia", + "128": "black stork, Ciconia nigra", + "129": "spoonbill", + "13": "junco, snowbird", + "130": "flamingo", + "131": "little blue heron, Egretta caerulea", + "132": "American egret, great white heron, Egretta albus", + "133": "bittern", + "134": "crane", + "135": "limpkin, Aramus pictus", + "136": "European gallinule, Porphyrio porphyrio", + "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", + "138": "bustard", + "139": "ruddy turnstone, Arenaria interpres", + "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "140": "red-backed sandpiper, dunlin, Erolia alpina", + "141": "redshank, Tringa totanus", + "142": "dowitcher", + "143": "oystercatcher, oyster catcher", + "144": "pelican", + "145": "king penguin, Aptenodytes patagonica", + "146": "albatross, 
mollymawk", + "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "149": "dugong, Dugong dugon", + "15": "robin, American robin, Turdus migratorius", + "150": "sea lion", + "151": "Chihuahua", + "152": "Japanese spaniel", + "153": "Maltese dog, Maltese terrier, Maltese", + "154": "Pekinese, Pekingese, Peke", + "155": "Shih-Tzu", + "156": "Blenheim spaniel", + "157": "papillon", + "158": "toy terrier", + "159": "Rhodesian ridgeback", + "16": "bulbul", + "160": "Afghan hound, Afghan", + "161": "basset, basset hound", + "162": "beagle", + "163": "bloodhound, sleuthhound", + "164": "bluetick", + "165": "black-and-tan coonhound", + "166": "Walker hound, Walker foxhound", + "167": "English foxhound", + "168": "redbone", + "169": "borzoi, Russian wolfhound", + "17": "jay", + "170": "Irish wolfhound", + "171": "Italian greyhound", + "172": "whippet", + "173": "Ibizan hound, Ibizan Podenco", + "174": "Norwegian elkhound, elkhound", + "175": "otterhound, otter hound", + "176": "Saluki, gazelle hound", + "177": "Scottish deerhound, deerhound", + "178": "Weimaraner", + "179": "Staffordshire bullterrier, Staffordshire bull terrier", + "18": "magpie", + "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "181": "Bedlington terrier", + "182": "Border terrier", + "183": "Kerry blue terrier", + "184": "Irish terrier", + "185": "Norfolk terrier", + "186": "Norwich terrier", + "187": "Yorkshire terrier", + "188": "wire-haired fox terrier", + "189": "Lakeland terrier", + "19": "chickadee", + "190": "Sealyham terrier, Sealyham", + "191": "Airedale, Airedale terrier", + "192": "cairn, cairn terrier", + "193": "Australian terrier", + "194": "Dandie Dinmont, Dandie Dinmont terrier", + "195": "Boston bull, Boston terrier", + "196": "miniature schnauzer", + "197": "giant schnauzer", + "198": "standard schnauzer", + "199": 
"Scotch terrier, Scottish terrier, Scottie", + "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "20": "water ouzel, dipper", + "200": "Tibetan terrier, chrysanthemum dog", + "201": "silky terrier, Sydney silky", + "202": "soft-coated wheaten terrier", + "203": "West Highland white terrier", + "204": "Lhasa, Lhasa apso", + "205": "flat-coated retriever", + "206": "curly-coated retriever", + "207": "golden retriever", + "208": "Labrador retriever", + "209": "Chesapeake Bay retriever", + "21": "kite", + "210": "German short-haired pointer", + "211": "vizsla, Hungarian pointer", + "212": "English setter", + "213": "Irish setter, red setter", + "214": "Gordon setter", + "215": "Brittany spaniel", + "216": "clumber, clumber spaniel", + "217": "English springer, English springer spaniel", + "218": "Welsh springer spaniel", + "219": "cocker spaniel, English cocker spaniel, cocker", + "22": "bald eagle, American eagle, Haliaeetus leucocephalus", + "220": "Sussex spaniel", + "221": "Irish water spaniel", + "222": "kuvasz", + "223": "schipperke", + "224": "groenendael", + "225": "malinois", + "226": "briard", + "227": "kelpie", + "228": "komondor", + "229": "Old English sheepdog, bobtail", + "23": "vulture", + "230": "Shetland sheepdog, Shetland sheep dog, Shetland", + "231": "collie", + "232": "Border collie", + "233": "Bouvier des Flandres, Bouviers des Flandres", + "234": "Rottweiler", + "235": "German shepherd, German shepherd dog, German police dog, alsatian", + "236": "Doberman, Doberman pinscher", + "237": "miniature pinscher", + "238": "Greater Swiss Mountain dog", + "239": "Bernese mountain dog", + "24": "great grey owl, great gray owl, Strix nebulosa", + "240": "Appenzeller", + "241": "EntleBucher", + "242": "boxer", + "243": "bull mastiff", + "244": "Tibetan mastiff", + "245": "French bulldog", + "246": "Great Dane", + "247": "Saint Bernard, St Bernard", + "248": "Eskimo dog, husky", + "249": "malamute, malemute, Alaskan 
malamute", + "25": "European fire salamander, Salamandra salamandra", + "250": "Siberian husky", + "251": "dalmatian, coach dog, carriage dog", + "252": "affenpinscher, monkey pinscher, monkey dog", + "253": "basenji", + "254": "pug, pug-dog", + "255": "Leonberg", + "256": "Newfoundland, Newfoundland dog", + "257": "Great Pyrenees", + "258": "Samoyed, Samoyede", + "259": "Pomeranian", + "26": "common newt, Triturus vulgaris", + "260": "chow, chow chow", + "261": "keeshond", + "262": "Brabancon griffon", + "263": "Pembroke, Pembroke Welsh corgi", + "264": "Cardigan, Cardigan Welsh corgi", + "265": "toy poodle", + "266": "miniature poodle", + "267": "standard poodle", + "268": "Mexican hairless", + "269": "timber wolf, grey wolf, gray wolf, Canis lupus", + "27": "eft", + "270": "white wolf, Arctic wolf, Canis lupus tundrarum", + "271": "red wolf, maned wolf, Canis rufus, Canis niger", + "272": "coyote, prairie wolf, brush wolf, Canis latrans", + "273": "dingo, warrigal, warragal, Canis dingo", + "274": "dhole, Cuon alpinus", + "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "276": "hyena, hyaena", + "277": "red fox, Vulpes vulpes", + "278": "kit fox, Vulpes macrotis", + "279": "Arctic fox, white fox, Alopex lagopus", + "28": "spotted salamander, Ambystoma maculatum", + "280": "grey fox, gray fox, Urocyon cinereoargenteus", + "281": "tabby, tabby cat", + "282": "tiger cat", + "283": "Persian cat", + "284": "Siamese cat, Siamese", + "285": "Egyptian cat", + "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "287": "lynx, catamount", + "288": "leopard, Panthera pardus", + "289": "snow leopard, ounce, Panthera uncia", + "29": "axolotl, mud puppy, Ambystoma mexicanum", + "290": "jaguar, panther, Panthera onca, Felis onca", + "291": "lion, king of beasts, Panthera leo", + "292": "tiger, Panthera tigris", + "293": "cheetah, chetah, Acinonyx jubatus", + "294": "brown bear, bruin, Ursus arctos", + "295": "American 
black bear, black bear, Ursus americanus, Euarctos americanus", + "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "297": "sloth bear, Melursus ursinus, Ursus ursinus", + "298": "mongoose", + "299": "meerkat, mierkat", + "3": "tiger shark, Galeocerdo cuvieri", + "30": "bullfrog, Rana catesbeiana", + "300": "tiger beetle", + "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "302": "ground beetle, carabid beetle", + "303": "long-horned beetle, longicorn, longicorn beetle", + "304": "leaf beetle, chrysomelid", + "305": "dung beetle", + "306": "rhinoceros beetle", + "307": "weevil", + "308": "fly", + "309": "bee", + "31": "tree frog, tree-frog", + "310": "ant, emmet, pismire", + "311": "grasshopper, hopper", + "312": "cricket", + "313": "walking stick, walkingstick, stick insect", + "314": "cockroach, roach", + "315": "mantis, mantid", + "316": "cicada, cicala", + "317": "leafhopper", + "318": "lacewing, lacewing fly", + "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "320": "damselfly", + "321": "admiral", + "322": "ringlet, ringlet butterfly", + "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "324": "cabbage butterfly", + "325": "sulphur butterfly, sulfur butterfly", + "326": "lycaenid, lycaenid butterfly", + "327": "starfish, sea star", + "328": "sea urchin", + "329": "sea cucumber, holothurian", + "33": "loggerhead, loggerhead turtle, Caretta caretta", + "330": "wood rabbit, cottontail, cottontail rabbit", + "331": "hare", + "332": "Angora, Angora rabbit", + "333": "hamster", + "334": "porcupine, hedgehog", + "335": "fox squirrel, eastern fox squirrel, Sciurus niger", + "336": "marmot", + "337": "beaver", + "338": "guinea pig, Cavia cobaya", + "339": "sorrel", + "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys 
coriacea", + "340": "zebra", + "341": "hog, pig, grunter, squealer, Sus scrofa", + "342": "wild boar, boar, Sus scrofa", + "343": "warthog", + "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "345": "ox", + "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "347": "bison", + "348": "ram, tup", + "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "35": "mud turtle", + "350": "ibex, Capra ibex", + "351": "hartebeest", + "352": "impala, Aepyceros melampus", + "353": "gazelle", + "354": "Arabian camel, dromedary, Camelus dromedarius", + "355": "llama", + "356": "weasel", + "357": "mink", + "358": "polecat, fitch, foulmart, foumart, Mustela putorius", + "359": "black-footed ferret, ferret, Mustela nigripes", + "36": "terrapin", + "360": "otter", + "361": "skunk, polecat, wood pussy", + "362": "badger", + "363": "armadillo", + "364": "three-toed sloth, ai, Bradypus tridactylus", + "365": "orangutan, orang, orangutang, Pongo pygmaeus", + "366": "gorilla, Gorilla gorilla", + "367": "chimpanzee, chimp, Pan troglodytes", + "368": "gibbon, Hylobates lar", + "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "37": "box turtle, box tortoise", + "370": "guenon, guenon monkey", + "371": "patas, hussar monkey, Erythrocebus patas", + "372": "baboon", + "373": "macaque", + "374": "langur", + "375": "colobus, colobus monkey", + "376": "proboscis monkey, Nasalis larvatus", + "377": "marmoset", + "378": "capuchin, ringtail, Cebus capucinus", + "379": "howler monkey, howler", + "38": "banded gecko", + "380": "titi, titi monkey", + "381": "spider monkey, Ateles geoffroyi", + "382": "squirrel monkey, Saimiri sciureus", + "383": "Madagascar cat, ring-tailed lemur, Lemur catta", + "384": "indri, indris, Indri indri, Indri brevicaudatus", + "385": "Indian elephant, Elephas maximus", + "386": "African elephant, Loxodonta africana", + "387": "lesser panda, red panda, panda, bear 
cat, cat bear, Ailurus fulgens", + "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "389": "barracouta, snoek", + "39": "common iguana, iguana, Iguana iguana", + "390": "eel", + "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "392": "rock beauty, Holocanthus tricolor", + "393": "anemone fish", + "394": "sturgeon", + "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", + "396": "lionfish", + "397": "puffer, pufferfish, blowfish, globefish", + "398": "abacus", + "399": "abaya", + "4": "hammerhead, hammerhead shark", + "40": "American chameleon, anole, Anolis carolinensis", + "400": "academic gown, academic robe, judge robe", + "401": "accordion, piano accordion, squeeze box", + "402": "acoustic guitar", + "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "404": "airliner", + "405": "airship, dirigible", + "406": "altar", + "407": "ambulance", + "408": "amphibian, amphibious vehicle", + "409": "analog clock", + "41": "whiptail, whiptail lizard", + "410": "apiary, bee house", + "411": "apron", + "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "413": "assault rifle, assault gun", + "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "415": "bakery, bakeshop, bakehouse", + "416": "balance beam, beam", + "417": "balloon", + "418": "ballpoint, ballpoint pen, ballpen, Biro", + "419": "Band Aid", + "42": "agama", + "420": "banjo", + "421": "bannister, banister, balustrade, balusters, handrail", + "422": "barbell", + "423": "barber chair", + "424": "barbershop", + "425": "barn", + "426": "barometer", + "427": "barrel, cask", + "428": "barrow, garden cart, lawn cart, wheelbarrow", + "429": "baseball", + "43": "frilled lizard, Chlamydosaurus kingi", + "430": "basketball", + "431": "bassinet", + "432": "bassoon", + "433": "bathing cap, swimming cap", + "434": "bath towel", + "435": "bathtub, bathing 
tub, bath, tub", + "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "437": "beacon, lighthouse, beacon light, pharos", + "438": "beaker", + "439": "bearskin, busby, shako", + "44": "alligator lizard", + "440": "beer bottle", + "441": "beer glass", + "442": "bell cote, bell cot", + "443": "bib", + "444": "bicycle-built-for-two, tandem bicycle, tandem", + "445": "bikini, two-piece", + "446": "binder, ring-binder", + "447": "binoculars, field glasses, opera glasses", + "448": "birdhouse", + "449": "boathouse", + "45": "Gila monster, Heloderma suspectum", + "450": "bobsled, bobsleigh, bob", + "451": "bolo tie, bolo, bola tie, bola", + "452": "bonnet, poke bonnet", + "453": "bookcase", + "454": "bookshop, bookstore, bookstall", + "455": "bottlecap", + "456": "bow", + "457": "bow tie, bow-tie, bowtie", + "458": "brass, memorial tablet, plaque", + "459": "brassiere, bra, bandeau", + "46": "green lizard, Lacerta viridis", + "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "461": "breastplate, aegis, egis", + "462": "broom", + "463": "bucket, pail", + "464": "buckle", + "465": "bulletproof vest", + "466": "bullet train, bullet", + "467": "butcher shop, meat market", + "468": "cab, hack, taxi, taxicab", + "469": "caldron, cauldron", + "47": "African chameleon, Chamaeleo chamaeleon", + "470": "candle, taper, wax light", + "471": "cannon", + "472": "canoe", + "473": "can opener, tin opener", + "474": "cardigan", + "475": "car mirror", + "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", + "477": "carpenters kit, tool kit", + "478": "carton", + "479": "car wheel", + "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "481": "cassette", + "482": "cassette player", + "483": "castle", + "484": "catamaran", + "485": "CD player", 
+ "486": "cello, violoncello", + "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "488": "chain", + "489": "chainlink fence", + "49": "African crocodile, Nile crocodile, Crocodylus niloticus", + "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "491": "chain saw, chainsaw", + "492": "chest", + "493": "chiffonier, commode", + "494": "chime, bell, gong", + "495": "china cabinet, china closet", + "496": "Christmas stocking", + "497": "church, church building", + "498": "cinema, movie theater, movie theatre, movie house, picture palace", + "499": "cleaver, meat cleaver, chopper", + "5": "electric ray, crampfish, numbfish, torpedo", + "50": "American alligator, Alligator mississipiensis", + "500": "cliff dwelling", + "501": "cloak", + "502": "clog, geta, patten, sabot", + "503": "cocktail shaker", + "504": "coffee mug", + "505": "coffeepot", + "506": "coil, spiral, volute, whorl, helix", + "507": "combination lock", + "508": "computer keyboard, keypad", + "509": "confectionery, confectionary, candy store", + "51": "triceratops", + "510": "container ship, containership, container vessel", + "511": "convertible", + "512": "corkscrew, bottle screw", + "513": "cornet, horn, trumpet, trump", + "514": "cowboy boot", + "515": "cowboy hat, ten-gallon hat", + "516": "cradle", + "517": "crane", + "518": "crash helmet", + "519": "crate", + "52": "thunder snake, worm snake, Carphophis amoenus", + "520": "crib, cot", + "521": "Crock Pot", + "522": "croquet ball", + "523": "crutch", + "524": "cuirass", + "525": "dam, dike, dyke", + "526": "desk", + "527": "desktop computer", + "528": "dial telephone, dial phone", + "529": "diaper, nappy, napkin", + "53": "ringneck snake, ring-necked snake, ring snake", + "530": "digital clock", + "531": "digital watch", + "532": "dining table, board", + "533": "dishrag, dishcloth", + "534": "dishwasher, dish washer, dishwashing machine", + "535": "disk brake, disc brake", + "536": 
"dock, dockage, docking facility", + "537": "dogsled, dog sled, dog sleigh", + "538": "dome", + "539": "doormat, welcome mat", + "54": "hognose snake, puff adder, sand viper", + "540": "drilling platform, offshore rig", + "541": "drum, membranophone, tympan", + "542": "drumstick", + "543": "dumbbell", + "544": "Dutch oven", + "545": "electric fan, blower", + "546": "electric guitar", + "547": "electric locomotive", + "548": "entertainment center", + "549": "envelope", + "55": "green snake, grass snake", + "550": "espresso maker", + "551": "face powder", + "552": "feather boa, boa", + "553": "file, file cabinet, filing cabinet", + "554": "fireboat", + "555": "fire engine, fire truck", + "556": "fire screen, fireguard", + "557": "flagpole, flagstaff", + "558": "flute, transverse flute", + "559": "folding chair", + "56": "king snake, kingsnake", + "560": "football helmet", + "561": "forklift", + "562": "fountain", + "563": "fountain pen", + "564": "four-poster", + "565": "freight car", + "566": "French horn, horn", + "567": "frying pan, frypan, skillet", + "568": "fur coat", + "569": "garbage truck, dustcart", + "57": "garter snake, grass snake", + "570": "gasmask, respirator, gas helmet", + "571": "gas pump, gasoline pump, petrol pump, island dispenser", + "572": "goblet", + "573": "go-kart", + "574": "golf ball", + "575": "golfcart, golf cart", + "576": "gondola", + "577": "gong, tam-tam", + "578": "gown", + "579": "grand piano, grand", + "58": "water snake", + "580": "greenhouse, nursery, glasshouse", + "581": "grille, radiator grille", + "582": "grocery store, grocery, food market, market", + "583": "guillotine", + "584": "hair slide", + "585": "hair spray", + "586": "half track", + "587": "hammer", + "588": "hamper", + "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "59": "vine snake", + "590": "hand-held computer, hand-held microcomputer", + "591": "handkerchief, hankie, hanky, hankey", + "592": "hard disc, hard disk, fixed disk", + 
"593": "harmonica, mouth organ, harp, mouth harp", + "594": "harp", + "595": "harvester, reaper", + "596": "hatchet", + "597": "holster", + "598": "home theater, home theatre", + "599": "honeycomb", + "6": "stingray", + "60": "night snake, Hypsiglena torquata", + "600": "hook, claw", + "601": "hoopskirt, crinoline", + "602": "horizontal bar, high bar", + "603": "horse cart, horse-cart", + "604": "hourglass", + "605": "iPod", + "606": "iron, smoothing iron", + "607": "jack-o-lantern", + "608": "jean, blue jean, denim", + "609": "jeep, landrover", + "61": "boa constrictor, Constrictor constrictor", + "610": "jersey, T-shirt, tee shirt", + "611": "jigsaw puzzle", + "612": "jinrikisha, ricksha, rickshaw", + "613": "joystick", + "614": "kimono", + "615": "knee pad", + "616": "knot", + "617": "lab coat, laboratory coat", + "618": "ladle", + "619": "lampshade, lamp shade", + "62": "rock python, rock snake, Python sebae", + "620": "laptop, laptop computer", + "621": "lawn mower, mower", + "622": "lens cap, lens cover", + "623": "letter opener, paper knife, paperknife", + "624": "library", + "625": "lifeboat", + "626": "lighter, light, igniter, ignitor", + "627": "limousine, limo", + "628": "liner, ocean liner", + "629": "lipstick, lip rouge", + "63": "Indian cobra, Naja naja", + "630": "Loafer", + "631": "lotion", + "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "633": "loupe, jewelers loupe", + "634": "lumbermill, sawmill", + "635": "magnetic compass", + "636": "mailbag, postbag", + "637": "mailbox, letter box", + "638": "maillot", + "639": "maillot, tank suit", + "64": "green mamba", + "640": "manhole cover", + "641": "maraca", + "642": "marimba, xylophone", + "643": "mask", + "644": "matchstick", + "645": "maypole", + "646": "maze, labyrinth", + "647": "measuring cup", + "648": "medicine chest, medicine cabinet", + "649": "megalith, megalithic structure", + "65": "sea snake", + "650": "microphone, mike", + "651": "microwave, microwave 
oven", + "652": "military uniform", + "653": "milk can", + "654": "minibus", + "655": "miniskirt, mini", + "656": "minivan", + "657": "missile", + "658": "mitten", + "659": "mixing bowl", + "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "660": "mobile home, manufactured home", + "661": "Model T", + "662": "modem", + "663": "monastery", + "664": "monitor", + "665": "moped", + "666": "mortar", + "667": "mortarboard", + "668": "mosque", + "669": "mosquito net", + "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "670": "motor scooter, scooter", + "671": "mountain bike, all-terrain bike, off-roader", + "672": "mountain tent", + "673": "mouse, computer mouse", + "674": "mousetrap", + "675": "moving van", + "676": "muzzle", + "677": "nail", + "678": "neck brace", + "679": "necklace", + "68": "sidewinder, horned rattlesnake, Crotalus cerastes", + "680": "nipple", + "681": "notebook, notebook computer", + "682": "obelisk", + "683": "oboe, hautboy, hautbois", + "684": "ocarina, sweet potato", + "685": "odometer, hodometer, mileometer, milometer", + "686": "oil filter", + "687": "organ, pipe organ", + "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "689": "overskirt", + "69": "trilobite", + "690": "oxcart", + "691": "oxygen mask", + "692": "packet", + "693": "paddle, boat paddle", + "694": "paddlewheel, paddle wheel", + "695": "padlock", + "696": "paintbrush", + "697": "pajama, pyjama, pjs, jammies", + "698": "palace", + "699": "panpipe, pandean pipe, syrinx", + "7": "cock", + "70": "harvestman, daddy longlegs, Phalangium opilio", + "700": "paper towel", + "701": "parachute, chute", + "702": "parallel bars, bars", + "703": "park bench", + "704": "parking meter", + "705": "passenger car, coach, carriage", + "706": "patio, terrace", + "707": "pay-phone, pay-station", + "708": "pedestal, plinth, footstall", + "709": "pencil box, pencil case", + "71": "scorpion", + "710": "pencil sharpener", + "711": "perfume, 
essence", + "712": "Petri dish", + "713": "photocopier", + "714": "pick, plectrum, plectron", + "715": "pickelhaube", + "716": "picket fence, paling", + "717": "pickup, pickup truck", + "718": "pier", + "719": "piggy bank, penny bank", + "72": "black and gold garden spider, Argiope aurantia", + "720": "pill bottle", + "721": "pillow", + "722": "ping-pong ball", + "723": "pinwheel", + "724": "pirate, pirate ship", + "725": "pitcher, ewer", + "726": "plane, carpenters plane, woodworking plane", + "727": "planetarium", + "728": "plastic bag", + "729": "plate rack", + "73": "barn spider, Araneus cavaticus", + "730": "plow, plough", + "731": "plunger, plumbers helper", + "732": "Polaroid camera, Polaroid Land camera", + "733": "pole", + "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "735": "poncho", + "736": "pool table, billiard table, snooker table", + "737": "pop bottle, soda bottle", + "738": "pot, flowerpot", + "739": "potters wheel", + "74": "garden spider, Aranea diademata", + "740": "power drill", + "741": "prayer rug, prayer mat", + "742": "printer", + "743": "prison, prison house", + "744": "projectile, missile", + "745": "projector", + "746": "puck, hockey puck", + "747": "punching bag, punch bag, punching ball, punchball", + "748": "purse", + "749": "quill, quill pen", + "75": "black widow, Latrodectus mactans", + "750": "quilt, comforter, comfort, puff", + "751": "racer, race car, racing car", + "752": "racket, racquet", + "753": "radiator", + "754": "radio, wireless", + "755": "radio telescope, radio reflector", + "756": "rain barrel", + "757": "recreational vehicle, RV, R.V.", + "758": "reel", + "759": "reflex camera", + "76": "tarantula", + "760": "refrigerator, icebox", + "761": "remote control, remote", + "762": "restaurant, eating house, eating place, eatery", + "763": "revolver, six-gun, six-shooter", + "764": "rifle", + "765": "rocking chair, rocker", + "766": "rotisserie", + "767": "rubber eraser, rubber, pencil 
eraser", + "768": "rugby ball", + "769": "rule, ruler", + "77": "wolf spider, hunting spider", + "770": "running shoe", + "771": "safe", + "772": "safety pin", + "773": "saltshaker, salt shaker", + "774": "sandal", + "775": "sarong", + "776": "sax, saxophone", + "777": "scabbard", + "778": "scale, weighing machine", + "779": "school bus", + "78": "tick", + "780": "schooner", + "781": "scoreboard", + "782": "screen, CRT screen", + "783": "screw", + "784": "screwdriver", + "785": "seat belt, seatbelt", + "786": "sewing machine", + "787": "shield, buckler", + "788": "shoe shop, shoe-shop, shoe store", + "789": "shoji", + "79": "centipede", + "790": "shopping basket", + "791": "shopping cart", + "792": "shovel", + "793": "shower cap", + "794": "shower curtain", + "795": "ski", + "796": "ski mask", + "797": "sleeping bag", + "798": "slide rule, slipstick", + "799": "sliding door", + "8": "hen", + "80": "black grouse", + "800": "slot, one-armed bandit", + "801": "snorkel", + "802": "snowmobile", + "803": "snowplow, snowplough", + "804": "soap dispenser", + "805": "soccer ball", + "806": "sock", + "807": "solar dish, solar collector, solar furnace", + "808": "sombrero", + "809": "soup bowl", + "81": "ptarmigan", + "810": "space bar", + "811": "space heater", + "812": "space shuttle", + "813": "spatula", + "814": "speedboat", + "815": "spider web, spiders web", + "816": "spindle", + "817": "sports car, sport car", + "818": "spotlight, spot", + "819": "stage", + "82": "ruffed grouse, partridge, Bonasa umbellus", + "820": "steam locomotive", + "821": "steel arch bridge", + "822": "steel drum", + "823": "stethoscope", + "824": "stole", + "825": "stone wall", + "826": "stopwatch, stop watch", + "827": "stove", + "828": "strainer", + "829": "streetcar, tram, tramcar, trolley, trolley car", + "83": "prairie chicken, prairie grouse, prairie fowl", + "830": "stretcher", + "831": "studio couch, day bed", + "832": "stupa, tope", + "833": "submarine, pigboat, sub, U-boat", + "834": 
"suit, suit of clothes", + "835": "sundial", + "836": "sunglass", + "837": "sunglasses, dark glasses, shades", + "838": "sunscreen, sunblock, sun blocker", + "839": "suspension bridge", + "84": "peacock", + "840": "swab, swob, mop", + "841": "sweatshirt", + "842": "swimming trunks, bathing trunks", + "843": "swing", + "844": "switch, electric switch, electrical switch", + "845": "syringe", + "846": "table lamp", + "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", + "848": "tape player", + "849": "teapot", + "85": "quail", + "850": "teddy, teddy bear", + "851": "television, television system", + "852": "tennis ball", + "853": "thatch, thatched roof", + "854": "theater curtain, theatre curtain", + "855": "thimble", + "856": "thresher, thrasher, threshing machine", + "857": "throne", + "858": "tile roof", + "859": "toaster", + "86": "partridge", + "860": "tobacco shop, tobacconist shop, tobacconist", + "861": "toilet seat", + "862": "torch", + "863": "totem pole", + "864": "tow truck, tow car, wrecker", + "865": "toyshop", + "866": "tractor", + "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "868": "tray", + "869": "trench coat", + "87": "African grey, African gray, Psittacus erithacus", + "870": "tricycle, trike, velocipede", + "871": "trimaran", + "872": "tripod", + "873": "triumphal arch", + "874": "trolleybus, trolley coach, trackless trolley", + "875": "trombone", + "876": "tub, vat", + "877": "turnstile", + "878": "typewriter keyboard", + "879": "umbrella", + "88": "macaw", + "880": "unicycle, monocycle", + "881": "upright, upright piano", + "882": "vacuum, vacuum cleaner", + "883": "vase", + "884": "vault", + "885": "velvet", + "886": "vending machine", + "887": "vestment", + "888": "viaduct", + "889": "violin, fiddle", + "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "890": "volleyball", + "891": "waffle iron", + "892": "wall clock", + "893": "wallet, billfold, notecase, 
pocketbook", + "894": "wardrobe, closet, press", + "895": "warplane, military plane", + "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "897": "washer, automatic washer, washing machine", + "898": "water bottle", + "899": "water jug", + "9": "ostrich, Struthio camelus", + "90": "lorikeet", + "900": "water tower", + "901": "whiskey jug", + "902": "whistle", + "903": "wig", + "904": "window screen", + "905": "window shade", + "906": "Windsor tie", + "907": "wine bottle", + "908": "wing", + "909": "wok", + "91": "coucal", + "910": "wooden spoon", + "911": "wool, woolen, woollen", + "912": "worm fence, snake fence, snake-rail fence, Virginia fence", + "913": "wreck", + "914": "yawl", + "915": "yurt", + "916": "web site, website, internet site, site", + "917": "comic book", + "918": "crossword puzzle, crossword", + "919": "street sign", + "92": "bee eater", + "920": "traffic light, traffic signal, stoplight", + "921": "book jacket, dust cover, dust jacket, dust wrapper", + "922": "menu", + "923": "plate", + "924": "guacamole", + "925": "consomme", + "926": "hot pot, hotpot", + "927": "trifle", + "928": "ice cream, icecream", + "929": "ice lolly, lolly, lollipop, popsicle", + "93": "hornbill", + "930": "French loaf", + "931": "bagel, beigel", + "932": "pretzel", + "933": "cheeseburger", + "934": "hotdog, hot dog, red hot", + "935": "mashed potato", + "936": "head cabbage", + "937": "broccoli", + "938": "cauliflower", + "939": "zucchini, courgette", + "94": "hummingbird", + "940": "spaghetti squash", + "941": "acorn squash", + "942": "butternut squash", + "943": "cucumber, cuke", + "944": "artichoke, globe artichoke", + "945": "bell pepper", + "946": "cardoon", + "947": "mushroom", + "948": "Granny Smith", + "949": "strawberry", + "95": "jacamar", + "950": "orange", + "951": "lemon", + "952": "fig", + "953": "pineapple, ananas", + "954": "banana", + "955": "jackfruit, jak, jack", + "956": "custard apple", + "957": "pomegranate", + "958": "hay", + 
"959": "carbonara", + "96": "toucan", + "960": "chocolate sauce, chocolate syrup", + "961": "dough", + "962": "meat loaf, meatloaf", + "963": "pizza, pizza pie", + "964": "potpie", + "965": "burrito", + "966": "red wine", + "967": "espresso", + "968": "cup", + "969": "eggnog", + "97": "drake", + "970": "alp", + "971": "bubble", + "972": "cliff, drop, drop-off", + "973": "coral reef", + "974": "geyser", + "975": "lakeside, lakeshore", + "976": "promontory, headland, head, foreland", + "977": "sandbar, sand bar", + "978": "seashore, coast, seacoast, sea-coast", + "979": "valley, vale", + "98": "red-breasted merganser, Mergus serrator", + "980": "volcano", + "981": "ballplayer, baseball player", + "982": "groom, bridegroom", + "983": "scuba diver", + "984": "rapeseed", + "985": "daisy", + "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "987": "corn", + "988": "acorn", + "989": "hip, rose hip, rosehip", + "99": "goose", + "990": "buckeye, horse chestnut, conker", + "991": "coral fungus", + "992": "agaric", + "993": "gyromitra", + "994": "stinkhorn, carrion fungus", + "995": "earthstar", + "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "997": "bolete", + "998": "ear, spike, capitulum", + "999": "toilet tissue, toilet paper, bathroom tissue" + }, + "scheduler": [ + "scheduling_flow_match_promoe", + "ProMoEFlowMatchScheduler" + ], + "transformer": [ + "transformer_promoe", + "ProMoETransformer2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/ProMoE-XL-256/pipeline.py b/ProMoE-XL-256/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a22aa2d52139703430ab9d7e7ebcc78db2d3d777 --- /dev/null +++ b/ProMoE-XL-256/pipeline.py @@ -0,0 +1,259 @@ +"""Hub custom pipeline: ProMoEPipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image + +try: + from diffusers.pipelines.pipeline_utils import DiffusionPipeline +except Exception: # pragma: no cover + class DiffusionPipeline: + def __init__(self): + self._execution_device = torch.device("cpu") + + def register_modules(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def to(self, device): + self._execution_device = torch.device(device) + for module in (getattr(self, "transformer", None), getattr(self, "vae", None)): + if module is not None and hasattr(module, "to"): + module.to(device) + return self + + def progress_bar(self, iterable): + return iterable + + def maybe_free_model_hooks(self): + return None + +@dataclass +class ProMoEPipelineOutput: + images: Union[List[Image.Image], np.ndarray, torch.Tensor] + +class ProMoEPipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with ProMoE. + + Parameters: + transformer ([`ProMoETransformer2DModel`]): + Class-conditional ProMoE transformer for flow-matching in latent space. + scheduler ([`ProMoEFlowMatchScheduler`]): + Flow-matching scheduler used during denoising. + vae ([`AutoencoderKL`], *optional*): + Variational autoencoder used to decode latents to pixels. + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. Values may contain comma-separated synonyms. 
+ """ + + model_cpu_offload_seq = "transformer->vae" + _optional_components = ["vae"] + + def __init__( + self, + transformer, + scheduler, + vae=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + variant_dir = Path(variant_path).resolve() + model_index_path = variant_dir / "model_index.json" + if not model_index_path.exists(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return 
self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings. Each string must match a synonym in `id2label`. + """ + self._ensure_labels_loaded() + label2id = self.labels + if not label2id: + raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.") + + if isinstance(label, str): + label = [label] + + missing = [item for item in label if item not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [label2id[item] for item in label] + + def _get_vae_spatial_downsample(self) -> int: + if self.vae is None: + return 8 + block_out_channels = getattr(getattr(self.vae, "config", None), "block_out_channels", [0, 0, 0, 0]) + return 2 ** (len(block_out_channels) - 1) + + def _normalize_class_labels( + self, + class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor], + device: torch.device, + ) -> torch.LongTensor: + if torch.is_tensor(class_labels): + return class_labels.to(device=device, dtype=torch.long).reshape(-1) + + if isinstance(class_labels, int): + class_label_ids = [class_labels] + elif isinstance(class_labels, str): + class_label_ids = self.get_label_ids(class_labels) + elif class_labels and isinstance(class_labels[0], str): + class_label_ids = self.get_label_ids(class_labels) + else: + class_label_ids = list(class_labels) + + return torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1) + + def _prepare_latents( + self, + batch_size: int, + latent_height: int, + latent_width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + shape = (batch_size, self.transformer.in_channels, latent_height, latent_width) + if 
isinstance(generator, list): + latents = [torch.randn((1, *shape[1:]), generator=g, device=device, dtype=dtype) for g in generator] + return torch.cat(latents, dim=0) + return torch.randn(shape, generator=generator, device=device, dtype=dtype) + + def _decode_latents(self, latents: torch.Tensor, output_type: str): + if output_type == "latent": + return latents + if self.vae is not None: + scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + decode_dtype = next(self.vae.parameters()).dtype + latents = (latents / scaling_factor).to(dtype=decode_dtype) + image = self.vae.decode(latents, return_dict=False)[0] + else: + image = latents + + image = (image / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return image + image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy() + if output_type == "np": + return image + pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image] + return pil_images + + @torch.no_grad() + def __call__( + self, + class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor], + height: int = 256, + width: int = 256, + num_inference_steps: int = 50, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ProMoEPipelineOutput, Tuple]: + r""" + Generate class-conditional images with ProMoE. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`): + ImageNet class indices or human-readable English label strings. 
+ """ + device = self._execution_device if hasattr(self, "_execution_device") else torch.device("cpu") + model_dtype = next(self.transformer.parameters()).dtype + class_labels = self._normalize_class_labels(class_labels, device) + batch_size = class_labels.shape[0] + + vae_scale = self._get_vae_spatial_downsample() + latent_height = height // vae_scale + latent_width = width // vae_scale + latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + null_labels = torch.full_like(class_labels, getattr(self.transformer.backbone.y_embedder, "num_classes", 1000)) + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale > 1.0: + latent_input = torch.cat([latents, latents], dim=0) + labels = torch.cat([class_labels, null_labels], dim=0) + else: + latent_input = latents + labels = class_labels + timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype) + model_output = self.transformer( + hidden_states=latent_input, + timestep=timestep, + class_labels=labels, + return_dict=True, + ).sample + if model_output.shape[1] != latents.shape[1]: + model_output = model_output.chunk(2, dim=1)[0] + if guidance_scale > 1.0: + model_output_cond, model_output_uncond = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond) + latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample + + images = self._decode_latents(latents, output_type) + self.maybe_free_model_hooks() + if not return_dict: + return (images,) + return ProMoEPipelineOutput(images=images) \ No newline at end of file diff --git a/ProMoE-XL-256/scheduler/config.json b/ProMoE-XL-256/scheduler/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b02311af404a07878b2d91f97ee9b4967e435d8d --- /dev/null +++ b/ProMoE-XL-256/scheduler/config.json @@ -0,0 
+1,5 @@ +{ + "_class_name": "ProMoEFlowMatchScheduler", + "num_train_timesteps": 1000, + "shift": 1.0 +} diff --git a/ProMoE-XL-256/scheduler/scheduler_config.json b/ProMoE-XL-256/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..d57a6cefb17ef05cb172b2d55177ab379a67a715 --- /dev/null +++ b/ProMoE-XL-256/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "ProMoEFlowMatchScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "shift": 1.0, + "stochastic_sampling": false +} diff --git a/ProMoE-XL-256/scheduler/scheduling_flow_match_promoe.py b/ProMoE-XL-256/scheduler/scheduling_flow_match_promoe.py new file mode 100644 index 0000000000000000000000000000000000000000..d71fe31541e09779d1afe32a7bcb9418a453e69f --- /dev/null +++ b/ProMoE-XL-256/scheduler/scheduling_flow_match_promoe.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Optional + +import torch + +try: + from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +except Exception: # pragma: no cover + FlowMatchEulerDiscreteScheduler = None + + +@dataclass +class ProMoEFlowMatchSchedulerOutput: + prev_sample: torch.FloatTensor + + +if FlowMatchEulerDiscreteScheduler is not None: + + class ProMoEFlowMatchScheduler(FlowMatchEulerDiscreteScheduler): + pass + +else: + + class ProMoEFlowMatchScheduler: + def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0): + self.config = SimpleNamespace(num_train_timesteps=num_train_timesteps, shift=shift, stochastic_sampling=False) + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1, dtype=torch.float32) + + def set_timesteps(self, num_inference_steps: int, device: Optional[torch.device] = None): + self.timesteps = torch.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=torch.float32, + device=device, + ) + + def step(self, model_output, timestep, sample, 
generator=None): + del generator + dt = 1.0 / max(len(self.timesteps), 1) + prev_sample = sample - dt * model_output + return ProMoEFlowMatchSchedulerOutput(prev_sample=prev_sample) diff --git a/ProMoE-XL-256/transformer/backbone_diffmoe.py b/ProMoE-XL-256/transformer/backbone_diffmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..91f8dfcec6a943fdb985195fa5c706fdb94a4293 --- /dev/null +++ b/ProMoE-XL-256/transformer/backbone_diffmoe.py @@ -0,0 +1,302 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP_DiffMoE as MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class SparseMoEBlock(nn.Module): + def __init__( + self, + experts, + hidden_dim, + num_experts, + n_shared_experts=0, + capacity=2, + mlp_ratio=4.0, + use_diff_expert=False, + ): + super().__init__() + self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim))) + nn.init.normal_(self.gate_weight, std=0.006) + self.experts = nn.ModuleList(experts) + self.capacity = capacity + self.num_experts = num_experts + self.n_shared_experts = n_shared_experts + self.use_diff_expert = use_diff_expert + if use_diff_expert: + self.diff_expert = MoeMLP(hidden_size=hidden_dim, intermediate_size=int(hidden_dim * mlp_ratio)) + + self.capacity_predictor = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim, bias=True), + nn.SiLU(), + nn.Linear(hidden_dim, self.num_experts, bias=True), + ) + + if self.n_shared_experts > 0: + mlp_hidden_dim = int(hidden_dim * mlp_ratio * 2) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.shared_experts = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + + self.register_buffer("expert_threshold", torch.tensor([0.0] * num_experts)) + self.register_buffer("ema_decay", torch.tensor([0.95])) + + def forward(self, x): + if 
self.training: + return self.forward_train(x) + return self.forward_eval(x) + + def update_threshold(self, capacity_pred): + if not self.training: + return + capacity_pred = torch.sigmoid(capacity_pred) + seq_len = capacity_pred.size(0) + topk = int((seq_len / self.num_experts) * self.capacity) + threshold = self.expert_threshold + ema_decay = self.ema_decay + for i in range(self.num_experts): + scores, _ = torch.topk(capacity_pred[:, i], k=topk, dim=-1, sorted=True) + quantile = scores[-1].detach() + threshold[i] = threshold[i] * ema_decay + (1 - ema_decay) * quantile + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(threshold, op=dist.ReduceOp.SUM) + threshold /= dist.get_world_size() + self.expert_threshold = threshold + + def forward_train(self, x): + bsz, seq_len, hidden_dim = x.shape + identity = x + x = x.view(-1, hidden_dim) + total_tokens = x.shape[0] + capacity_pred = self.capacity_predictor(x.detach()) + k = int((total_tokens / self.num_experts) * self.capacity) + logits = F.linear(x, self.gate_weight, None) + scores = logits.softmax(dim=-1).permute(1, 0) + gating, index = torch.topk(scores, k=k, dim=-1, sorted=False) + mask = torch.zeros((self.num_experts, total_tokens), dtype=x.dtype, device=x.device) + mask.scatter_(1, index, 1.0) + expert_inputs = x[index] + expert_outputs = torch.stack([expert(expert_inputs[i]) for i, expert in enumerate(self.experts)]) + gated_outputs = gating.unsqueeze(-1) * expert_outputs + + y = torch.zeros((total_tokens * self.num_experts, hidden_dim), dtype=x.dtype, device=x.device) + offset = torch.arange(0, self.num_experts, device=x.device).unsqueeze(1) * total_tokens + flat_index = (index + offset.long()).view(-1) + y = torch.scatter(y, 0, flat_index.unsqueeze(1).expand(-1, hidden_dim), gated_outputs.view(-1, hidden_dim)) + y = y.view(self.num_experts, total_tokens, hidden_dim).sum(dim=0, keepdim=False) + + self.update_threshold(capacity_pred) + x_out = y.view(bsz, seq_len, hidden_dim) + ones = 
mask.permute(1, 0).view(bsz, seq_len, self.num_experts) + capacity_pred = capacity_pred.view(bsz, seq_len, self.num_experts) + if self.n_shared_experts > 0: + x_out = x_out + self.shared_experts(identity) + if self.use_diff_expert: + x_out = x_out - self.diff_expert(identity) + return x_out, ones, capacity_pred + + def forward_eval(self, x): + bsz, seq_len, hidden_dim = x.shape + identity = x + x = x.view(-1, hidden_dim) + total_tokens = x.shape[0] + capacity_pred = torch.sigmoid(self.capacity_predictor(x.detach())) + threshold = self.expert_threshold + logits = F.linear(x, self.gate_weight, None) + scores = logits.softmax(dim=-1).permute(-1, -2) + y = torch.zeros_like(x, dtype=x.dtype) + for i, expert in enumerate(self.experts): + k_fixed = torch.where(capacity_pred[:, i] > threshold[i], 1, 0).sum() + gating, index = torch.topk(scores[i], k=k_fixed, dim=-1, sorted=False) + y[index, :] += gating.unsqueeze(-1) * expert(x[index, :]) + x_out = y.view(bsz, seq_len, hidden_dim) + if self.n_shared_experts > 0: + x_out = x_out + self.shared_experts(identity) + return x_out, None, None + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + qk_norm=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=qk_norm, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + experts = [ + Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + for _ in range(MoE_config.num_experts) + ] + else: + experts = [MoeMLP(hidden_size=hidden_size, 
intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)] + self.mlp = SparseMoEBlock( + experts=experts, + hidden_dim=hidden_size, + num_experts=MoE_config.num_experts, + capacity=MoE_config.capacity, + n_shared_experts=MoE_config.n_shared_experts, + mlp_ratio=4.0, + ) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + if self.use_moe: + x_mlp, ones, pred_c = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + gate_mlp.unsqueeze(1) * x_mlp + return x, ones, pred_c + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x, None, None + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + CapacityPred_loss_weight=0.01, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.CapacityPred_loss_weight = CapacityPred_loss_weight + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, 
hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + self.capacity_schedule = MoE_config.get("capacity_schedule", None) + if self.capacity_schedule: + self.training_iters = -1 + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = 
self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + + if self.training and self.capacity_schedule: + num_experts = self.MoE_config.num_experts + capacity = self.MoE_config.capacity + stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters + stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters + if self.training_iters <= stage_i: + capacity = num_experts + elif self.training_iters <= stage_ii: + capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i) + for block in self.blocks: + if hasattr(block.mlp, "capacity"): + block.mlp.capacity = capacity + + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + ones_list, pred_c_list, layer_idx_list = [], [], [] + for layer_idx, block in enumerate(self.blocks): + x, ones, pred_c = block(x, c) + if ones is not None: + ones_list.append(ones) + pred_c_list.append(pred_c) + layer_idx_list.append(layer_idx) + x = self.final_layer(x, c) + x = self.unpatchify(x) + return x, "Capacity_Pred", layer_idx_list, ones_list, pred_c_list, self.CapacityPred_loss_weight diff --git a/ProMoE-XL-256/transformer/backbone_dit.py b/ProMoE-XL-256/transformer/backbone_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..d8fde70ff5dc640a9467dfd563e36419f722c7c1 --- /dev/null +++ b/ProMoE-XL-256/transformer/backbone_dit.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class DiTBlock(nn.Module): + def 
__init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + head_dim=None, + use_swiglu=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + 
mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + ) + for _ in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-XL-256/transformer/backbone_ecdit.py 
b/ProMoE-XL-256/transformer/backbone_ecdit.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae2c725ae1bac6c23a23cf467d444bed11b9f3d --- /dev/null +++ b/ProMoE-XL-256/transformer/backbone_ecdit.py @@ -0,0 +1,220 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP_DiffMoE as MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class SparseMoEBlock(nn.Module): + def __init__(self, experts, hidden_dim, num_experts, n_shared_experts=0, capacity=2): + super().__init__() + self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim))) + nn.init.normal_(self.gate_weight, std=0.006) + self.experts = nn.ModuleList(experts) + self.capacity = capacity + self.num_experts = num_experts + self.n_shared_experts = n_shared_experts + if self.n_shared_experts > 0: + intermediate_size = hidden_dim * self.n_shared_experts + self.shared_experts = MoeMLP(hidden_size=hidden_dim, intermediate_size=intermediate_size, pretraining_tp=2) + + def forward(self, x): + identity = x + batch_size, seq_len, _ = x.shape + logits = F.linear(x, self.gate_weight, None) + affinity = logits.softmax(dim=-1) + affinity = torch.einsum("b s e -> b e s", affinity) + k = int((seq_len / self.num_experts) * self.capacity) + gating, index = torch.topk(affinity, k=k, dim=-1, sorted=False) + dispatch = F.one_hot(index, num_classes=seq_len).to(device=x.device, dtype=x.dtype) + x_in = torch.einsum("b e c s, b s d -> b e c d", dispatch, x) + x_e = [self.experts[e](x_in[:, e]) for e in range(self.num_experts)] + x_e = torch.stack(x_e, dim=1) + x_out = torch.einsum("b e c s, b e c, b e c d -> b s d", dispatch, gating, x_e) + if self.n_shared_experts > 0: + x_out = x_out + self.shared_experts(identity) + return x_out + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + 
use_swiglu=False, + MoE_config=None, + use_moe=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if use_moe: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + experts = [ + Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + for _ in range(MoE_config.num_experts) + ] + else: + experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)] + self.mlp = SparseMoEBlock( + experts=experts, + hidden_dim=hidden_size, + num_experts=MoE_config.num_experts, + capacity=MoE_config.capacity, + n_shared_experts=MoE_config.n_shared_experts, + ) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + ): + 
super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + self.capacity_schedule = MoE_config.get("capacity_schedule", None) + if self.capacity_schedule: + self.training_iters = -1 + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in 
self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def init_moe_mlp(module, std=0.006): + nn.init.normal_(module.gate_proj.weight, std=std) + nn.init.normal_(module.up_proj.weight, std=std) + nn.init.normal_(module.down_proj.weight, std=std) + + if self.init_MoeMLP: + for block in self.blocks: + if hasattr(block.mlp, "experts"): + for expert in block.mlp.experts: + if hasattr(expert, "gate_proj"): + init_moe_mlp(expert) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + if self.training and self.capacity_schedule: + num_experts = self.MoE_config.num_experts + capacity = self.MoE_config.capacity + stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters + stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters + if self.training_iters <= stage_i: + capacity = num_experts + elif self.training_iters <= stage_ii: + capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i) + for block in self.blocks: + if hasattr(block.mlp, "capacity"): + block.mlp.capacity = capacity + + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-XL-256/transformer/backbone_promoe_ec.py 
b/ProMoE-XL-256/transformer/backbone_promoe_ec.py new file mode 100644 index 0000000000000000000000000000000000000000..05da901ed601ca8e683ab5d55da0af3922534015 --- /dev/null +++ b/ProMoE-XL-256/transformer/backbone_promoe_ec.py @@ -0,0 +1,286 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None + return grad_output, grad_loss + + +class SparseMoeBlock(nn.Module): + def __init__( + self, + num_routed_experts, + hidden_size, + moe_intermediate_size, + shared_expert_intermediate_size, + top_k=1, + load_balance_loss_coef=0, + norm_topk_prob=False, + seq_aux=False, + use_shared_expert=True, + use_uncond_expert=True, + router_weight_mode="softmax", + routing_contrastive_lam=0, + use_top_k_for_routing_contrastive=False, + routing_contrastive_temperature=0.1, + **kwargs, + ): + super().__init__() + del load_balance_loss_coef, norm_topk_prob, seq_aux, use_top_k_for_routing_contrastive + self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts + self.num_routed_experts = num_routed_experts + self.hidden_size = hidden_size + self.top_k = top_k + self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size)) + self.use_shared_expert = use_shared_expert + self.use_uncond_expert = use_uncond_expert + self.router_weight_mode = router_weight_mode + self.routing_contrastive_lam = routing_contrastive_lam + self.routing_contrastive_temperature = routing_contrastive_temperature + self.experts = nn.ModuleList( 
+ [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)] + ) + if use_shared_expert: + self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size) + self._init_weights() + + def compute_router(self, cond_hidden_states): + b_cond, seq_len, _ = cond_hidden_states.shape + num_cond_experts = self.num_routed_experts + input_norm = F.normalize(cond_hidden_states, p=2, dim=-1) + cluster_norm = F.normalize(self.cluster_centers, p=2, dim=-1) + cos_sim = input_norm @ cluster_norm.T + cos_sim_expert_view = cos_sim.transpose(1, 2) + if self.router_weight_mode == "softmax": + cond_weights = F.softmax(cos_sim_expert_view, dim=-1) + elif self.router_weight_mode == "sigmoid": + cond_weights = torch.sigmoid(cos_sim_expert_view) + elif self.router_weight_mode == "identity": + cond_weights = cos_sim_expert_view + else: + raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}") + k = max(1, min(int((seq_len / num_cond_experts) * self.top_k), seq_len)) + router_weights, indices = torch.topk(cond_weights, k=k, dim=-1, sorted=False) + dispatch_mask = F.one_hot(indices, num_classes=seq_len).to(dtype=cond_hidden_states.dtype) + expert_inputs = torch.einsum("becs,bsd->becd", dispatch_mask, cond_hidden_states) + return dispatch_mask, router_weights, expert_inputs + + def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor): + identity = hidden_states + batch_size, _, hidden_dim = hidden_states.shape + final_output = torch.zeros_like(hidden_states) + loss = None + cond_batch_mask = ( + labels.view(-1) != 1000 + ) if self.use_uncond_expert else torch.ones(batch_size, dtype=torch.bool, device=hidden_states.device) + uncond_batch_mask = ~cond_batch_mask + cond_experts = self.experts[:-1] if self.use_uncond_expert else self.experts + + if cond_batch_mask.any(): + cond_hidden_states = hidden_states[cond_batch_mask] + dispatch_mask, gating_scores, expert_inputs = 
self.compute_router(cond_hidden_states) + num_cond_experts = len(cond_experts) + expert_outputs = torch.stack([cond_experts[e](expert_inputs[:, e]) for e in range(num_cond_experts)], dim=1) + cond_output = torch.einsum("becs,bec,becd->bsd", dispatch_mask, gating_scores, expert_outputs).to(hidden_states.dtype) + final_output[cond_batch_mask] = cond_output + if self.training and self.routing_contrastive_lam > 0 and num_cond_experts > 1: + expert_token_means = expert_inputs.mean(dim=2) + routing_contrastive_loss = self.compute_routing_contrastive_loss(expert_token_means) + loss = routing_contrastive_loss * self.routing_contrastive_lam + else: + dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + for expert in cond_experts: + final_output = final_output + expert(dummy_input).sum() * 0 + + if self.use_uncond_expert: + if uncond_batch_mask.any(): + uncond_hidden_states = hidden_states[uncond_batch_mask] + final_output[uncond_batch_mask] = self.experts[-1](uncond_hidden_states).to(final_output.dtype) + else: + dummy_input = torch.zeros(1, 1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + final_output = final_output + self.experts[-1](dummy_input).sum() * 0 + + if self.use_shared_expert: + final_output += self.shared_expert(identity).to(hidden_states.dtype) + return final_output, loss + + def compute_routing_contrastive_loss(self, expert_token_means): + batch_size, num_cond_experts, _ = expert_token_means.shape + if num_cond_experts < 2: + return torch.tensor(0.0, device=expert_token_means.device) + centers_norm = F.normalize(self.cluster_centers, p=2, dim=1) + means_norm = F.normalize(expert_token_means, p=2, dim=2) + sim_matrix = torch.einsum("id,bjd->bij", centers_norm, means_norm) + logits = sim_matrix / self.routing_contrastive_temperature + labels = torch.arange(num_cond_experts, device=logits.device).unsqueeze(0).expand(batch_size, -1) + return F.cross_entropy(logits.reshape(batch_size * 
num_cond_experts, -1), labels.reshape(-1)) + + def _init_weights(self): + nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02) + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c, label): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + if self.use_moe: + x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label) + if aux_loss is not None: + x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss) + return x + gate_mlp.unsqueeze(1) * x_mlp + return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + 
): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + 
nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def init_moe_mlp(module, std=0.006): + nn.init.normal_(module.up_proj.weight, std=std) + nn.init.normal_(module.down_proj.weight, std=std) + + if self.init_MoeMLP: + for block in self.blocks: + if hasattr(block.mlp, "experts"): + for expert in block.mlp.experts: + init_moe_mlp(expert) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, timestep, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(timestep) + y, labels = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c, labels) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-XL-256/transformer/backbone_promoe_tc.py b/ProMoE-XL-256/transformer/backbone_promoe_tc.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5f0036823de886748b5a375a98b7f0efc6377f --- /dev/null +++ b/ProMoE-XL-256/transformer/backbone_promoe_tc.py @@ -0,0 +1,355 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, 
grad_output): + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None + return grad_output, grad_loss + + +class SparseMoeBlock(nn.Module): + def __init__( + self, + num_routed_experts, + hidden_size, + moe_intermediate_size, + shared_expert_intermediate_size, + top_k=2, + load_balance_loss_coef=0, + norm_topk_prob=False, + seq_aux=False, + use_shared_expert=True, + use_uncond_expert=True, + router_weight_mode="softmax", + routing_contrastive_lam=0, + use_top_k_for_routing_contrastive=False, + routing_contrastive_temperature=0.1, + **kwargs, + ): + super().__init__() + del norm_topk_prob + self.num_experts = num_routed_experts + 1 if use_uncond_expert else num_routed_experts + self.num_routed_experts = num_routed_experts + self.seq_aux = seq_aux + self.hidden_size = hidden_size + self.top_k = top_k + self.cluster_centers = nn.Parameter(torch.randn(num_routed_experts, hidden_size)) + self.alpha = load_balance_loss_coef + self.use_shared_expert = use_shared_expert + self.use_uncond_expert = use_uncond_expert + self.router_weight_mode = router_weight_mode + self.routing_contrastive_lam = routing_contrastive_lam + self.use_top_k_for_routing_contrastive = use_top_k_for_routing_contrastive + self.routing_contrastive_temperature = routing_contrastive_temperature + self.experts = nn.ModuleList( + [MoeMLP(hidden_size=hidden_size, intermediate_size=moe_intermediate_size) for _ in range(self.num_experts)] + ) + if use_shared_expert: + self.shared_expert = MoeMLP(hidden_size=hidden_size, intermediate_size=shared_expert_intermediate_size) + self._init_weights() + + def compute_router(self, hidden_states, labels): + batch_size, seq_len, _ = hidden_states.shape + device = hidden_states.device + flat_input = hidden_states.view(-1, self.hidden_size) + flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1) + if self.use_uncond_expert and flat_labels is not None: + uncond_mask = flat_labels == 1000 + cond_mask = 
~uncond_mask + else: + uncond_mask = None + cond_mask = torch.ones_like(flat_labels, dtype=torch.bool) + + router_weights = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=hidden_states.dtype) + expert_indices = torch.zeros(batch_size * seq_len, self.top_k, device=device, dtype=torch.long) + + if uncond_mask is not None and uncond_mask.any(): + uncond_positions = torch.where(uncond_mask)[0] + router_weights[uncond_positions, 0] = 1.0 + expert_indices[uncond_positions] = self.num_experts - 1 + + cond_weights = None + topk_idx = None + if cond_mask.any(): + cond_positions = torch.where(cond_mask)[0] + cond_input = flat_input[cond_positions] + input_norm = F.normalize(cond_input, p=2, dim=1) + cluster_norm = F.normalize(self.cluster_centers, p=2, dim=1) + cos_sim = input_norm @ cluster_norm.T + if self.router_weight_mode == "softmax": + cond_weights = F.softmax(cos_sim, dim=1) + elif self.router_weight_mode == "sigmoid": + cond_weights = torch.sigmoid(cos_sim) + elif self.router_weight_mode == "identity": + cond_weights = cos_sim + else: + raise ValueError(f"Unsupported router_weight_mode: {self.router_weight_mode}") + topk_scores, topk_idx = torch.topk(cond_weights, k=self.top_k, dim=1) + router_weights[cond_positions] = topk_scores.to(router_weights.dtype) + expert_indices[cond_positions] = topk_idx + + router_weights = router_weights.view(batch_size, seq_len, self.top_k) + expert_indices = expert_indices.view(batch_size, seq_len, self.top_k) + + load_balance_loss = None + if self.training and self.alpha > 0.0 and cond_weights is not None and topk_idx is not None: + cond_batch_size = (labels != 1000).sum() + scores_for_aux = F.softmax(cond_weights, dim=1) if self.router_weight_mode != "softmax" else cond_weights + topk_idx_for_aux_loss = topk_idx.view(cond_batch_size, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(cond_batch_size, seq_len, -1) + ce = torch.zeros(cond_batch_size, self.num_routed_experts, 
device=hidden_states.device) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(cond_batch_size, seq_len * self.top_k, device=hidden_states.device), + ).div_(seq_len * self.top_k / self.num_routed_experts) + load_balance_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.num_routed_experts) + ce = mask_ce.float().mean(0) + pi = scores_for_aux.mean(0) + fi = ce * self.num_routed_experts + load_balance_loss = (pi * fi).sum() * self.alpha + return router_weights, expert_indices, load_balance_loss + + def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor): + router_weights, expert_indices, load_balance_loss = self.compute_router(hidden_states, labels) + batch_size, seq_len, hidden_dim = hidden_states.shape + flat_input = hidden_states.view(-1, hidden_dim) + flat_weights = router_weights.view(-1, self.top_k) + flat_indices = expert_indices.view(-1, self.top_k) + total_tokens = batch_size * seq_len + final_output = torch.zeros(total_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + + for expert_id in range(self.num_experts): + expert_mask = (flat_indices == expert_id).any(dim=1) + token_ids = torch.where(expert_mask)[0] + if token_ids.numel() > 0: + expert_input = flat_input[token_ids] + expert_weight_mask = flat_indices[token_ids] == expert_id + expert_weights = flat_weights[token_ids] * expert_weight_mask.to(dtype=flat_weights.dtype) + combined_weights = expert_weights.sum(dim=1) + expert_output = self.experts[expert_id](expert_input) + weighted_output = expert_output * combined_weights.unsqueeze(1) + final_output.index_add_(0, token_ids, weighted_output) + else: + dummy_input = torch.zeros(1, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype) + final_output[0] += self.experts[expert_id](dummy_input)[0] * 0 + + final_output = final_output.view(batch_size, seq_len, hidden_dim) + if 
self.use_shared_expert: + final_output += self.shared_expert(hidden_states) + + loss = load_balance_loss + if self.training and self.routing_contrastive_lam > 0: + flat_labels = labels.view(batch_size, 1).expand(-1, seq_len).reshape(-1) + cond_mask = ~( + flat_labels == 1000 + ) if self.use_uncond_expert else torch.ones(batch_size * seq_len, dtype=torch.bool, device=hidden_states.device) + cond_token_embeddings = flat_input[cond_mask] + if self.use_top_k_for_routing_contrastive: + cond_cluster_assignments = expert_indices.view(batch_size * seq_len, self.top_k)[cond_mask] + else: + top1_expert_indices = expert_indices.view(batch_size * seq_len, self.top_k)[:, 0] + cond_cluster_assignments = top1_expert_indices[cond_mask] + routing_contrastive_loss = self.compute_routing_contrastive_loss( + cond_token_embeddings, + cond_cluster_assignments, + use_top_k=self.use_top_k_for_routing_contrastive, + ) + routing_contrastive_loss = routing_contrastive_loss * self.routing_contrastive_lam + loss = routing_contrastive_loss if loss is None else loss + routing_contrastive_loss + + return final_output, loss + + def compute_routing_contrastive_loss(self, token_embeddings, cluster_assignments, use_top_k=False): + cluster_centers = self.cluster_centers + num_clusters = cluster_centers.size(0) + device = cluster_centers.device + cluster_means = [] + valid_clusters = [] + for cluster_id in range(num_clusters): + mask = (cluster_assignments == cluster_id).any(dim=1) if use_top_k else cluster_assignments == cluster_id + if mask.sum() > 0: + cluster_means.append(token_embeddings[mask].mean(dim=0, keepdim=True)) + valid_clusters.append(cluster_id) + if len(valid_clusters) < 2: + return torch.tensor(0.0, device=device) + cluster_means = torch.cat(cluster_means, dim=0) + valid_centers = cluster_centers[valid_clusters] + centers_norm = F.normalize(valid_centers, p=2, dim=1) + means_norm = F.normalize(cluster_means, p=2, dim=1) + sim_matrix = centers_norm @ means_norm.T + logits = sim_matrix / 
self.routing_contrastive_temperature + labels = torch.arange(sim_matrix.size(0), device=device) + return F.cross_entropy(logits, labels) + + def _init_weights(self): + nn.init.normal_(self.cluster_centers, mean=0.0, std=0.02) + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + head_dim=None, + mlp_ratio=4.0, + use_swiglu=False, + MoE_config=None, + use_moe=False, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + self.mlp = SparseMoeBlock(hidden_size=hidden_size, **MoE_config) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c, label): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + if self.use_moe: + x_mlp, aux_loss = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), label) + if aux_loss is not None: + x_mlp = AddAuxiliaryLoss.apply(x_mlp, aux_loss) + return x + gate_mlp.unsqueeze(1) * x_mlp + return x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + qk_norm=False, + class_dropout_prob=0.1, + 
num_classes=1000, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + head_dim=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, return_labels=True) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.init_MoeMLP = MoE_config.init_MoeMLP + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in 
self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def init_moe_mlp(module, std=0.006): + nn.init.normal_(module.up_proj.weight, std=std) + nn.init.normal_(module.down_proj.weight, std=std) + + if self.init_MoeMLP: + for block in self.blocks: + if hasattr(block.mlp, "experts"): + for expert in block.mlp.experts: + init_moe_mlp(expert) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, timestep, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(timestep) + y, labels = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c, labels) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-XL-256/transformer/backbone_tcdit.py b/ProMoE-XL-256/transformer/backbone_tcdit.py new file mode 100644 index 0000000000000000000000000000000000000000..18bc64b114caf8c359ff6842ffd54bdf18af2123 --- /dev/null +++ b/ProMoE-XL-256/transformer/backbone_tcdit.py @@ -0,0 +1,304 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modeling_promoe_common import ( + Attention, + FinalLayer, + LabelEmbedder, + Mlp, + MoeMLP_DiffMoE as MoeMLP, + PatchEmbed, + TimestepEmbedder, + get_2d_sincos_pos_embed, + modulate, +) + + +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01): + 
super().__init__() + self.top_k = num_experts_per_tok + self.n_routed_experts = num_experts + self.scoring_func = "softmax" + self.alpha = aux_loss_alpha + self.seq_aux = False + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + if self.scoring_func != "softmax": + raise NotImplementedError(f"Unsupported gating scoring function: {self.scoring_func}") + scores = logits.softmax(dim=-1) + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + if self.top_k > 1 and self.norm_topk_prob: + topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * self.top_k, device=hidden_states.device), + ).div_(seq_len * self.top_k / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + ce = mask_ce.float().mean(0) + pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + 
@staticmethod + def backward(ctx, grad_output): + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) if ctx.required_aux_loss else None + return grad_output, grad_loss + + +class SparseMoEBlock(nn.Module): + def __init__( + self, + experts, + hidden_dim, + mlp_ratio=4, + num_experts=16, + num_experts_per_tok=2, + pretraining_tp=2, + n_shared_experts=2, + ): + super().__init__() + self.top_k = num_experts_per_tok + self.experts = nn.ModuleList(experts) + self.gate = MoEGate(embed_dim=hidden_dim, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok) + self.n_shared_experts = n_shared_experts + if self.n_shared_experts > 0: + intermediate_size = hidden_dim * self.n_shared_experts + self.shared_experts = MoeMLP( + hidden_size=hidden_dim, + intermediate_size=intermediate_size, + pretraining_tp=pretraining_tp, + ) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0) + y = torch.empty_like(hidden_states, dtype=hidden_states.dtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]).float() + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + if self.n_shared_experts > 0: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.top_k + 
for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i - 1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_( + 0, + exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), + expert_out, + reduce="sum", + ) + return expert_cache + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4, + pretraining_tp=2, + use_swiglu=False, + MoE_config=None, + use_moe=True, + **block_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.use_moe = use_moe + if use_moe: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + experts = [ + Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + for _ in range(MoE_config.num_experts) + ] + else: + experts = [ + MoeMLP( + hidden_size=hidden_size, + intermediate_size=mlp_hidden_dim, + pretraining_tp=pretraining_tp, + ) + for _ in range(MoE_config.num_experts) + ] + self.mlp = SparseMoEBlock( + experts=experts, + hidden_dim=hidden_size, + num_experts=MoE_config.num_experts, + num_experts_per_tok=MoE_config.capacity, + n_shared_experts=MoE_config.n_shared_experts, + ) + else: + if not use_swiglu: + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + else: + self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) + + 
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4, + qk_norm=False, + class_dropout_prob=0.1, + num_classes=1000, + pretraining_tp=1, + learn_sigma=True, + use_swiglu=False, + MoE_config=None, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.MoE_config = MoE_config + use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + use_swiglu=use_swiglu, + pretraining_tp=pretraining_tp, + MoE_config=MoE_config, + use_moe=use_moe_flag[i], + ) + for i in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + 
nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def forward(self, x, t, context, **kwargs): + y = context + if len(x.shape) != 4: + x = x.squeeze(2) + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + return self.unpatchify(x) diff --git a/ProMoE-XL-256/transformer/config.json b/ProMoE-XL-256/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..0726a979ddce31c3d4913f3d34907530e88d5ee2 --- /dev/null +++ b/ProMoE-XL-256/transformer/config.json @@ -0,0 +1,22 @@ +{ + "_class_name": "ProMoETransformer2DModel", + "architecture": "promoe_tc", + "model_config": { + "MoE_config": { + "init_MoeMLP": false, + "interleave": true, + 
"moe_intermediate_size": 2304, + "num_routed_experts": 12, + "shared_expert_intermediate_size": 2304, + "top_k": 1, + "use_shared_expert": true, + "use_uncond_expert": true + }, + "depth": 28, + "hidden_size": 1152, + "input_size": 32, + "num_classes": 1000, + "num_heads": 16, + "patch_size": 2 + } +} diff --git a/ProMoE-XL-256/transformer/diffusion_pytorch_model.safetensors b/ProMoE-XL-256/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..8cf35e1fa35a787e826b3e8ed435826210b5175a --- /dev/null +++ b/ProMoE-XL-256/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c02abf475a881cdd445984c14642463ec2679aa8f28967ac20761da09f105580 +size 6271058696 diff --git a/ProMoE-XL-256/transformer/modeling_promoe_common.py b/ProMoE-XL-256/transformer/modeling_promoe_common.py new file mode 100644 index 0000000000000000000000000000000000000000..0a82f2ece8db2dff46a45faafb9731af18f09a34 --- /dev/null +++ b/ProMoE-XL-256/transformer/modeling_promoe_common.py @@ -0,0 +1,291 @@ +import collections.abc +import math +from dataclasses import dataclass +from itertools import repeat +from typing import Any, Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +class AttrDict(dict): + def __getattr__(self, item): + try: + return self[item] + except KeyError as error: + raise AttributeError(item) from error + + def __setattr__(self, key, value): + self[key] = value + + @staticmethod + def from_data(data: Any) -> Any: + if isinstance(data, dict): + return AttrDict({k: AttrDict.from_data(v) for k, v in data.items()}) + if isinstance(data, list): + return [AttrDict.from_data(v) for v in data] + return data + + +class 
PatchEmbed(nn.Module): + def __init__(self, input_size: int, patch_size: int, in_channels: int, embed_dim: int, bias: bool = True): + super().__init__() + self.img_size = to_2tuple(input_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = ( + self.img_size[0] // self.patch_size[0], + self.img_size[1] // self.patch_size[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=bias, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + return hidden_states.flatten(2).transpose(1, 2) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class MoeMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size) + self.act_fn = nn.GELU(approximate="tanh") + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +class 
MoeMLP_DiffMoE(nn.Module): + def __init__(self, hidden_size, intermediate_size, pretraining_tp=2): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = nn.SiLU() + self.pretraining_tp = pretraining_tp + + def forward(self, x): + if self.pretraining_tp > 1: + split_size = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(split_size, dim=0) + up_proj_slices = self.up_proj.weight.split(split_size, dim=0) + down_proj_slices = self.down_proj.weight.split(split_size, dim=1) + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(split_size, dim=-1) + down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] + return sum(down_proj) + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + head_dim=None, + norm_layer: nn.Module = nn.LayerNorm, + ): + super().__init__() + self.num_heads = num_heads + if head_dim is None: + if dim % num_heads != 0: + raise ValueError("dim must be divisible by num_heads") + self.head_dim = dim // num_heads + else: + self.head_dim = head_dim + self.scale = self.head_dim**-0.5 + self.fused_attn = True + self.qkv = nn.Linear(dim, self.head_dim * self.num_heads * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else 
nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.head_dim * self.num_heads, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)).softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(batch_size, seq_len, -1) + x = self.proj(x) + return self.proj_drop(x) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t.float(), self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + return self.mlp(t_freq.to(dtype=weight_dtype)) + + +class LabelEmbedder(nn.Module): + def __init__(self, 
num_classes, hidden_size, dropout_prob, return_labels=False): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + self.return_labels = return_labels + + def token_drop(self, labels, force_drop_ids=None): + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + return torch.where(drop_ids, self.num_classes, labels) + + def forward(self, labels, train, force_drop_ids=None): + if (train and self.dropout_prob > 0) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + if self.return_labels: + return embeddings, labels + return embeddings + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + return self.linear(x) + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + 
if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + return np.concatenate([emb_h, emb_w], axis=1) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + emb_sin = np.sin(out) + emb_cos = np.cos(out) + return np.concatenate([emb_sin, emb_cos], axis=1) diff --git a/ProMoE-XL-256/transformer/transformer_promoe.py b/ProMoE-XL-256/transformer/transformer_promoe.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6369fcdcb55c394ef5c7fa8d4d50b7b32ba145 --- /dev/null +++ b/ProMoE-XL-256/transformer/transformer_promoe.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except Exception: # pragma: no cover + class BaseOutput(dict): + def __post_init__(self): + self.update(self.__dict__) + + class _Config(dict): + def __getattr__(self, key): + try: + return self[key] + except KeyError as error: + raise AttributeError(key) from error + + class ConfigMixin: + config_name = "config.json" + + class ModelMixin(nn.Module): + pass + + def register_to_config(init): + def wrapper(self, *args, **kwargs): + import inspect + + signature = inspect.signature(init) + bound = signature.bind(self, *args, **kwargs) + bound.apply_defaults() + self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"}) + init(self, *args, **kwargs) + + return wrapper + +from 
.backbone_diffmoe import DiT as DiffMoEBackbone +from .backbone_dit import DiT as DiTBackbone +from .backbone_ecdit import DiT as ECDiTBackbone +from .backbone_promoe_ec import DiT as ProMoEECBackbone +from .backbone_promoe_tc import DiT as ProMoETCBackbone +from .backbone_tcdit import DiT as TCDiTBackbone +from .modeling_promoe_common import AttrDict + + +@dataclass +class ProMoETransformer2DModelOutput(BaseOutput): + sample: torch.FloatTensor + loss_strategy: Optional[str] = None + layer_idx_list: Optional[Tuple[int, ...]] = None + ones_list: Optional[Tuple[torch.FloatTensor, ...]] = None + pred_c_list: Optional[Tuple[torch.FloatTensor, ...]] = None + capacity_pred_loss_weight: Optional[float] = None + + +_BACKBONES = { + "dit": DiTBackbone, + "tcdit": TCDiTBackbone, + "ecdit": ECDiTBackbone, + "diffmoe": DiffMoEBackbone, + "promoe_tc": ProMoETCBackbone, + "promoe_ec": ProMoEECBackbone, +} + + +class ProMoETransformer2DModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__(self, architecture: str = "promoe_tc", model_config: Optional[Dict[str, Any]] = None): + super().__init__() + if architecture not in _BACKBONES: + raise ValueError(f"Unsupported architecture: {architecture}. 
def _prepare_config(self, model_config: Dict[str, Any]) -> Dict[str, Any]:
    """Build the keyword arguments passed to the backbone constructor.

    Every value of ``model_config`` is passed through ``AttrDict.from_data``
    and stored under its original key in a new dictionary; the input mapping
    itself is not modified.

    Args:
        model_config: Mapping of backbone constructor argument names to
            their raw configuration values.

    Returns:
        A new dict with the same keys and converted values.
    """
    # NOTE(review): presumably `from_data` wraps nested dicts for attribute
    # access and leaves other values untouched -- confirm against AttrDict.
    return {name: AttrDict.from_data(raw) for name, raw in model_config.items()}


def forward(
    self,
    hidden_states: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    class_labels: Optional[torch.LongTensor] = None,
    context: Optional[torch.LongTensor] = None,
    return_dict: bool = True,
    **kwargs,
) -> Union[ProMoETransformer2DModelOutput, Tuple[torch.Tensor, ...]]:
    """Run the wrapped backbone and package its result.

    Args:
        hidden_states: Input tensor handed to the backbone; its device and
            dtype are also used for the timestep tensor.
        timestep: Tensor or Python scalar. It is flattened, and a single
            value is repeated ``labels.shape[0]`` times.
        class_labels: Conditioning labels; takes precedence over ``context``.
        context: Fallback conditioning used when ``class_labels`` is None.
        return_dict: When True, return a ``ProMoETransformer2DModelOutput``;
            otherwise return a plain tuple.
        **kwargs: Forwarded unchanged to the backbone call.

    Returns:
        The output object, or, when ``return_dict`` is False, either
        ``(sample,)`` or the 6-tuple ``(sample, loss_strategy,
        layer_idx_list, ones_list, pred_c_list, capacity_pred_loss_weight)``
        if the backbone reported a loss strategy.

    Raises:
        ValueError: If both ``class_labels`` and ``context`` are None.
    """
    labels = context if class_labels is None else class_labels
    if labels is None:
        raise ValueError("Either `class_labels` or `context` must be provided.")

    # Bring the timestep onto the same device/dtype as the input.
    placement = {"device": hidden_states.device, "dtype": hidden_states.dtype}
    if torch.is_tensor(timestep):
        steps = timestep
    else:
        steps = torch.tensor([timestep], **placement)
    steps = steps.to(**placement).flatten()
    if steps.numel() == 1:
        # One shared timestep: expand it to one entry per label.
        steps = steps.repeat(labels.shape[0])

    raw = self.backbone(hidden_states, steps, labels, **kwargs)

    if not isinstance(raw, tuple):
        output = ProMoETransformer2DModelOutput(sample=raw)
    elif len(raw) == 6 and raw[1] == "Capacity_Pred":
        # 6-tuple tagged "Capacity_Pred": keep the auxiliary loss fields.
        prediction, strategy, layer_ids, ones, pred_c, loss_weight = raw
        output = ProMoETransformer2DModelOutput(
            sample=prediction,
            loss_strategy=strategy,
            layer_idx_list=tuple(layer_ids),
            ones_list=tuple(ones),
            pred_c_list=tuple(pred_c),
            capacity_pred_loss_weight=float(loss_weight),
        )
    else:
        # Any other tuple: only its first element is kept.
        output = ProMoETransformer2DModelOutput(sample=raw[0])

    if return_dict:
        return output
    if output.loss_strategy is None:
        return (output.sample,)
    return (
        output.sample,
        output.loss_strategy,
        output.layer_idx_list,
        output.ones_list,
        output.pred_c_list,
        output.capacity_pred_loss_weight,
    )
Purpose | +| --- | --- | +| `pipeline.py` | `ProMoEPipeline` | +| `transformer/` | backbone_diffmoe.py, backbone_dit.py, backbone_ecdit.py, backbone_promoe_ec.py, backbone_promoe_tc.py, backbone_tcdit.py, … | +| `scheduler/` | scheduling_flow_match_promoe.py | + + +## ImageNet class labels + +Each variant keeps an English `id2label` map in `model_index.json` (DiT-style). + +- `pipe.id2label` — id → English label (comma-separated synonyms) +- `pipe(class_labels=207, ...)` — class-conditional sampling with integer ids + +Copy the full 1000-class `id2label` block from `BiliSakura/DiT-diffusers` when publishing a model repo. + +## `model_index.json` + +Copy entries from `model_index.json.example` into your model repo after `save_pretrained`. +Set `"_class_name": ["pipeline", "ProMoEPipeline"]` and use custom module stems for each component. + +- FlowMatch scheduler: `"scheduler": ["scheduling_flow_match_promoe", "ProMoEFlowMatchScheduler"]` +- VAE: `"vae": ["diffusers", "AutoencoderKL"]` with `stabilityai/sd-vae-ft-mse` weights or bundled safetensors +- ProMoE-TC presets: `ProMoE_TC_S`, `ProMoE_TC_B`, `ProMoE_TC_L`, `ProMoE_TC_XL` (see convert script) + +Regenerate: `python scripts/build_community_pipelines.py`