from utils.motif import init_repro seed = 42 init_repro(seed, deterministic=True) import sys, os, time, pickle, copy, math sys.path.append("./utils") import torch from transformers import ( CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPTokenizer, CLIPTextModelWithProjection, AutoProcessor, AutoModel ) import clip import wandb from utils.video_embedder import VideoEmbedder, Create_Concepts from utils.motif import MoTIF, CBMTransformer, mean_cbm import core.vision_encoder.pe as pe import core.vision_encoder.transforms as pe_transformer def run_experiment(hparams): """Run one CBM training experiment with given hyperparameters.""" dataset = hparams["dataset"] clip_model = hparams["clip_model"] window_size = hparams["window_size"] random = hparams["random"] # dataset name mapping dataset_map = { "breakfast": "Breakfast", "ucf101": "UCF101", "hmdb51": "HMDB", "something2": "Something2" } dataset_name = dataset_map[dataset] # load CLIP or related models if clip_model == "b32": model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval() processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False) embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_b32.pkl" clip_name = "clip" elif clip_model == "b16": model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").eval() processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16", use_fast=False) embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_b16.pkl" clip_name = "clip" elif clip_model == "l14": model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval() processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_fast=False) embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_l14.pkl" clip_name = "clip" elif clip_model == "res50": model, preprocess = clip.load("RN50", device="cpu") processor = preprocess embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_res50.pkl" clip_name = "res50" elif clip_model == "clip4clip": model = CLIPVisionModelWithProjection.from_pretrained("Searchium-ai/clip4clip-webvid150k").eval() model_text = CLIPTextModelWithProjection.from_pretrained("Searchium-ai/clip4clip-webvid150k") processor = CLIPTokenizer.from_pretrained("Searchium-ai/clip4clip-webvid150k") embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_clip4clip.pkl" clip_name = "clip4clip" elif clip_model == "siglip": model = AutoModel.from_pretrained("google/siglip-base-patch16-224") processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") clip_name = "siglip" embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_siglip.pkl" elif clip_model == "siglipl14": model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384") processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384") clip_name = "siglipl14" embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_siglipl14.pkl" elif clip_model == "pe-l14": model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=True) processor = pe_transformer.get_image_transform(model.image_size) tokenizer = pe_transformer.get_text_tokenizer(model.context_length) clip_name = "pe-l14" embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_pe-l14.pkl" else: model = None processor = None model_text = None raise ValueError(f"Unknown clip_model {clip_model}") # embedder embedder = VideoEmbedder(clip_name, model, processor) embedder.dataset_name = dataset if os.path.exists(embedd_path): with open(embedd_path, "rb") as f: embedder = pickle.load(f) print("Loaded existing embedder from", embedd_path) else: folder_path = [f"../Datasets/{dataset_name}/Video_data"] embedder.process_data(folder_path, window_size=window_size, output_path="../Embeddings/Datasets") with open(embedd_path, "wb") as f: pickle.dump(embedder, f) # concepts if clip_model == "clip4clip": concepts = Create_Concepts(clip_name, model_text, processor) elif clip_model == "pe-l14": concepts = Create_Concepts(clip_name, model, tokenizer) else: concepts = Create_Concepts(clip_name, model, processor) if dataset == "breakfast": text_concepts = ["grind, fill, boil, pour, steep, brew, tamp, insert, steam, froth, stir, sip, add, slice, toast, butter, spread, cut, assemble, grate, chop, peel, core, squeeze, pit, mash, crack, whisk, beat, fry, scramble, flip, mix, cook, drizzle, serve, drain, grill, preheat, bake, warm, wash, rinse, blend, measure, set, open, close, take, put, remove, pack, dry, wipe, sit, stand, carry, pick, blow, taste, adjust, reach, place, seal, unwrap, unscrew, scoop, zest, juice, start, stop, turn, heat, cool, toss, shake, tap, knock, press, release, slide, rotate, fold, unfold, wring, sprinkle, arrange, sort, stack, unstack, hide, reveal, cover, uncover, balance, tilt, catch, throw, drop, roll, toss, spin, twist, poke, pinch, pull, push, drag, scrub, brush, comb, shave, zip, button, tie, untie, snap, clap, wave, point, nod, gesture, smile, frown, laugh, coffee, kettle, water, tea, milk, sugar, cereal, yogurt, granola, fruit, bread, bagel, cheese, tomato, cucumber, onion, herb, banana, apple, orange, avocado, egg, bacon, sausage, ham, pan, stove, oven, pastry, croissant, strawberry, blender, ice, batter, syrup, cinnamon, honey, jar, plate, cup, spoon, fork, knife, tongs, lid, package, container, carton, bottle, pantry, fridge, cupboard, counter, sink, dish, towel, timer, mug, bowl, spatula, ladle, grater, peeler, colander, sieve, cuttingboard, tray, ovenmitt, scale, thermometer, stool, chair, table, napkin, freezer, hood, burner, flame, plug, socket, switch, knob, handle, cover, stirrer, measuringcup, measuringspoon, recipe, cookbook, ingredient, serving, leftover, waste, soap, sponge, detergent, faucet, garbage, recycle, bin"] # shorter for test elif dataset == "ucf101": text_concepts = ["jump, swing, skip, throw, catch, dribble, bounce, kick, pass, hit, serve, smash, block, spike, dive, swim, climb, grab, pull, hang, push, sit, ride, pedal, balance, stop, start, steer, mount, dismount, gallop, control, lift, curl, press, squat, deadlift, jab, hook, uppercut, dodge, wrestle, grapple, flip, perform, walk, handstand, run, sprint, shoot, turn, grind, row, paddle, surf, stand, tuck, enter, splash, wave, clap, raise, squat, spin, dance, breakdance, strike, parry, fight, reload, aim, release, bowl, swing, pitch, hit, catch, skateboard, snowboard, ski, trampoline, yoga, sword, gun, archery, hockey, basketball, volleyball, soccer, rugby, baseball, cricket, rope, ball, bat, racket, puck, stick, net, goal, pool, lane, wall, ladder, bar, dumbbell, barbell, mat, beam, hurdle, bicycle, helmet, horse, reins, rail, snowboard, skis, kayak, canoe, paddle, surfboard, gloves, boxing, stage, microphone, instrument, music, sheet, player, opponent, teammate, referee, coach, dancer, athlete, gymnast, swimmer, skater, snowboarder, skateboarder, rower, surfer, archer, shooter, bow, club, frisbee, arrow, target, goalpost, jersey, uniform, cap, helmet, pad, netting, court, field, track, floor, platform, water, sand, snow, ice, gym, stadium, arena, ring, mat, beam, hoop, basket, scoreboard, timer"] elif dataset == "hmdb51": text_concepts = ["bow, fight, sword, walk, run, sprint, jog, stand, up, sit, down, jump, hop, leap, fall, roll, crouch, bend, stretch, turn, around, look, up, look, down, look, left, look, right, nod, head, shake, head, smile, laugh, frown, yawn, talk, mouth, words, sing, chew, eat, with, hands, eat, with, utensils, drink, from, cup, drink, from, bottle, sip, blow, kiss, hug, wave, hand, point, reach, grab, object, release, object, throw, object, catch, object, toss, ball, kick, ball, hit, with, hand, punch, block, push, pull, lift, object, carry, object, drag, object, drop, object, catch, fall, climb, up, climb, down, crawl, swim, dive, surface, float, balance, ride, bicycle, pedal, bicycle, brake, bicycle, steer, bicycle, mount, horse, dismount, horse, gallop, ride, skateboard, skate, sled, ski, snowboard, slide, skate, backward, turn, skateboard, shoot, basketball, dribble, ball, bounce, ball, serve, tennis, swing, racket, hit, tennis, ball, swing, bat, hit, baseball, throw, frisbee, catch, frisbee, juggle, spin, object, roll, ball, kick, leg, high, kick, leg, low, flip, somersault, cartwheel, handstand, headstand, touch, head, touch, face, wash, face, comb, hair, brush, hair, brush, teeth, shave, apply, makeup, put, on, hat, take, off, hat, put, on, jacket, take, off, jacket, button, shirt, zip, jacket, tie, shoelace, untie, shoelace, open, door, close, door, knock, door, enter, room, exit, room, sit, on, chair, stand, from, chair, lie, down, wake, up, sleep, sprint, start, cross, finish, line"] elif dataset == "something2": text_concepts = ["push, pull, lift, drop, hold, carry, throw, catch, slide, drag, roll, spin, rotate, flip, fold, unfold, wrap, unwrap, tie, untie, fasten, unfasten, tighten, loosen, break, cut, slice, chop, tear, peel, crumple, flatten, bend, stretch, shake, stir, pour, scoop, sprinkle, stack, unstack, assemble, disassemble, open, close, lock, unlock, press, tap, swipe, scroll, zoom in, zoom out, point, touch, wave, clap, knock, snap, swing, juggle, bounce, balance, topple, insert, remove, fill, empty, mix, separate, spill, scatter, gather, cover, uncover, hide, reveal, lean, tilt, climb, crawl, jump, hop, walk, run, sprint, stumble, fall, get up, sit, stand, kneel, crouch, bow, dance, spin dance, nod, shake head, smile, frown, laugh, cry, shout, whisper, speak, yawn, sneeze, cough, sleep, wake, eat, chew, bite, sip, drink, spit, blow, smell, taste, write, draw, erase, paint, type, click, drag mouse, plug, unplug, connect, disconnect, turn on, turn off, start, stop, accelerate, decelerate, pretend to push, pretend to pull, pretend to pour, pretend to eat, pretend to drink, pretend to throw, pretend to catch, pretend to type, pretend to swipe, pretend to scroll, pretend to climb, pretend to fall, pretend to hug, pretend to kiss, pretend to wave, pretend to play guitar, pretend to drive, pretend to steer, pretend to read, pretend to sleep, pretend to wake, pretend to write, pretend to draw, pretend to paint, pretend to clean, pretend to cook, pretend to stir, pretend to measure, pretend to weigh, pretend to look around, pretend to search, pretend to point, pretend to balance, pretend to open, pretend to close, pretend to lock, pretend to unlock, pretend to kick, pretend to punch, pretend to block, pretend to dodge, pretend to jump rope, pretend to row, pretend to paddle, pretend to shoot arrow, pretend to load gun, pretend to fire gun, pretend to throw ball, pretend to dribble, pretend to shoot basket, pretend to swing bat, pretend to serve, pretend to catch fish, pretend to steer wheel, pretend to honk, pretend to use controller, pretend to play piano, pretend to play drums, pretend to dance, pretend to sing, pretend to clap, pretend to salute, pretend to bow, pretend to shake hands, pretend to hug, pretend to kiss, object, container, box, cup, bowl, plate, spoon, knife, fork, chopstick, pen, pencil, paper, book, phone, remote, laptop, keyboard, mouse, bag, backpack, toy, ball, fruit, apple, orange, banana, grape, vegetable, carrot, cucumber, tomato, bottle, can, lid, cap, key, lock, door, window, wall, floor, table, chair, shelf, hand, finger, arm, face, person, other, background, surface, inside, outside, top, bottom, left, right, upward, downward, hot, cold, wet, dry, clean, dirty, empty, full, broken, fixed, smooth, rough, heavy, light, fragile, durable, rollable, stackable, squeezable, pourable, spillable, openable, closeable, edible, drinkable"] else: print("Unknown dataset", dataset) text_concepts = [] concepts.embedd_text(text_concepts) # model cbm_model = MoTIF(embedder, concepts) cbm_model.preprocess(dataset, info=hparams["test_split"], random_state=seed) cbm_model.model = CBMTransformer( cbm_model.num_concepts, num_classes=cbm_model.num_classes, transformer_layers=hparams["transformer_layers"], lse_tau=hparams["lse_tau"], dimension=hparams["d"], diagonal_attention=hparams["diagonal_attention"], ) # wandb time_now = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) run_name = f'{dataset}_{clip_model}_{time_now}' wandb_run = wandb.init(project="motif", name=run_name, config=hparams) wandb_run.log({ 'test_split': hparams["test_split"], }) cbm_model.train_model( num_epochs=hparams["num_epochs"], l1_lambda=hparams["l1_lambda"], lambda_sparse=hparams["lambda_sparse"], lr=hparams["lr"], batch_size=hparams["batch_size"], enforce_nonneg=hparams["enforce_nonneg"], class_weights=hparams["class_weights"], wandb_run=wandb_run, random_seed=seed, ) cbm_model.zero_shot(concepts, wandb_run=wandb_run) mean_cbm(cbm_model, wandb_run=wandb_run) wandb_run.finish() model_name = f"./Models/checkpoint_{clip_model}_{dataset_name}.pkl" os.makedirs(os.path.dirname(model_name), exist_ok=True) with open(model_name, "wb") as f: pickle.dump(cbm_model, f) print("cbm_model and class saved to", model_name) if __name__ == "__main__": # define hyperparameter grid with descriptions # Hyperparameter descriptions: # num_epochs: Number of training epochs. # batch_size: Number of samples per training batch. # lse_tau: Temperature parameter for log-sum-exp pooling. # l1_lambda: L1 regularization strength. # lambda_sparse: Sparsity regularization strength. # lr: Learning rate for optimizer. # transformer_layers: Number of transformer layers in the CBM model. # diagonal_attention: If True, restricts attention to diagonal (self-attention only). # enforce_nonneg: If True, enforces non-negative concept activations. # class_weights: If True, uses class weights to balance loss. # weight_decay: Weight decay (L2 regularization) for optimizer. # d: Model dimension. Always 1, can be set higher to express more representations after Conv1d. # test_split: Which test split to use (e.g., "s1"). # window_size: Temporal window size for video embedding. # dataset: Dataset to use (e.g., "hmdb51", "breakfast", "something2"). # random: If True, uses random seed for image selection in window. # clip_model: CLIP model variant to use (e.g., "pe-l14", "b16", "res50", "clip4clip"). search_space = { "num_epochs": [100], "batch_size": [32], "lse_tau": [10.0], "l1_lambda": [1e-3], "lambda_sparse": [1e-3], "lr": [1e-3], "transformer_layers": [1], "diagonal_attention": [True], "enforce_nonneg": [True], "class_weights": [True], "weight_decay": [1e-2], "d": [1], "test_split": ["s1"], "window_size": [8], "dataset": ["hmdb51"], "random": [True], "clip_model": ["pe-l14"], } # grid search import itertools keys, values = zip(*search_space.items()) for v in itertools.product(*values): hparams = dict(zip(keys, v)) run_experiment(hparams)