|
|
from utils.motif import init_repro |
|
|
|
|
|
# Global seed for the whole script; run_experiment() below passes it on as
# `random_state` (preprocessing) and `random_seed` (training).
seed = 42

# Called before the remaining imports (torch, transformers, ...) are executed.
# NOTE(review): presumably seeds the RNGs and switches on deterministic
# behaviour -- confirm against utils.motif.init_repro.
init_repro(seed, deterministic=True)
|
|
|
|
|
import sys, os, time, pickle, copy, math |
|
|
sys.path.append("./utils") |
|
|
|
|
|
import torch |
|
|
from transformers import ( |
|
|
CLIPProcessor, CLIPModel, |
|
|
CLIPVisionModelWithProjection, |
|
|
CLIPTokenizer, CLIPTextModelWithProjection, |
|
|
AutoProcessor, AutoModel |
|
|
) |
|
|
import clip |
|
|
import wandb |
|
|
|
|
|
from utils.video_embedder import VideoEmbedder, Create_Concepts |
|
|
from utils.motif import MoTIF, CBMTransformer, mean_cbm |
|
|
|
|
|
import core.vision_encoder.pe as pe |
|
|
import core.vision_encoder.transforms as pe_transformer |
|
|
|
|
|
# Dataset key (as used in hparams) -> directory name used on disk.
_DATASET_DIRS = {
    "breakfast": "Breakfast",
    "ucf101": "UCF101",
    "hmdb51": "HMDB",
    "something2": "Something2",
}

# OpenAI CLIP variants that are loaded through Hugging Face `CLIPModel`.
_OPENAI_CLIP_CHECKPOINTS = {
    "b32": "openai/clip-vit-base-patch32",
    "b16": "openai/clip-vit-base-patch16",
    "l14": "openai/clip-vit-large-patch14",
}

# SigLIP variants that are loaded through Hugging Face `AutoModel`.
_SIGLIP_CHECKPOINTS = {
    "siglip": "google/siglip-base-patch16-224",
    "siglipl14": "google/siglip-so400m-patch14-384",
}

# Per-dataset concept vocabulary. Each entry is ONE comma-separated string; it
# is handed to `Create_Concepts.embedd_text` wrapped in a single-element list.
# NOTE(review): presumably `embedd_text` splits on commas -- confirm in
# utils.video_embedder before changing the format.
_CONCEPT_VOCABULARIES = {
    "breakfast": "grind, fill, boil, pour, steep, brew, tamp, insert, steam, froth, stir, sip, add, slice, toast, butter, spread, cut, assemble, grate, chop, peel, core, squeeze, pit, mash, crack, whisk, beat, fry, scramble, flip, mix, cook, drizzle, serve, drain, grill, preheat, bake, warm, wash, rinse, blend, measure, set, open, close, take, put, remove, pack, dry, wipe, sit, stand, carry, pick, blow, taste, adjust, reach, place, seal, unwrap, unscrew, scoop, zest, juice, start, stop, turn, heat, cool, toss, shake, tap, knock, press, release, slide, rotate, fold, unfold, wring, sprinkle, arrange, sort, stack, unstack, hide, reveal, cover, uncover, balance, tilt, catch, throw, drop, roll, toss, spin, twist, poke, pinch, pull, push, drag, scrub, brush, comb, shave, zip, button, tie, untie, snap, clap, wave, point, nod, gesture, smile, frown, laugh, coffee, kettle, water, tea, milk, sugar, cereal, yogurt, granola, fruit, bread, bagel, cheese, tomato, cucumber, onion, herb, banana, apple, orange, avocado, egg, bacon, sausage, ham, pan, stove, oven, pastry, croissant, strawberry, blender, ice, batter, syrup, cinnamon, honey, jar, plate, cup, spoon, fork, knife, tongs, lid, package, container, carton, bottle, pantry, fridge, cupboard, counter, sink, dish, towel, timer, mug, bowl, spatula, ladle, grater, peeler, colander, sieve, cuttingboard, tray, ovenmitt, scale, thermometer, stool, chair, table, napkin, freezer, hood, burner, flame, plug, socket, switch, knob, handle, cover, stirrer, measuringcup, measuringspoon, recipe, cookbook, ingredient, serving, leftover, waste, soap, sponge, detergent, faucet, garbage, recycle, bin",
    "ucf101": "jump, swing, skip, throw, catch, dribble, bounce, kick, pass, hit, serve, smash, block, spike, dive, swim, climb, grab, pull, hang, push, sit, ride, pedal, balance, stop, start, steer, mount, dismount, gallop, control, lift, curl, press, squat, deadlift, jab, hook, uppercut, dodge, wrestle, grapple, flip, perform, walk, handstand, run, sprint, shoot, turn, grind, row, paddle, surf, stand, tuck, enter, splash, wave, clap, raise, squat, spin, dance, breakdance, strike, parry, fight, reload, aim, release, bowl, swing, pitch, hit, catch, skateboard, snowboard, ski, trampoline, yoga, sword, gun, archery, hockey, basketball, volleyball, soccer, rugby, baseball, cricket, rope, ball, bat, racket, puck, stick, net, goal, pool, lane, wall, ladder, bar, dumbbell, barbell, mat, beam, hurdle, bicycle, helmet, horse, reins, rail, snowboard, skis, kayak, canoe, paddle, surfboard, gloves, boxing, stage, microphone, instrument, music, sheet, player, opponent, teammate, referee, coach, dancer, athlete, gymnast, swimmer, skater, snowboarder, skateboarder, rower, surfer, archer, shooter, bow, club, frisbee, arrow, target, goalpost, jersey, uniform, cap, helmet, pad, netting, court, field, track, floor, platform, water, sand, snow, ice, gym, stadium, arena, ring, mat, beam, hoop, basket, scoreboard, timer",
    "hmdb51": "bow, fight, sword, walk, run, sprint, jog, stand, up, sit, down, jump, hop, leap, fall, roll, crouch, bend, stretch, turn, around, look, up, look, down, look, left, look, right, nod, head, shake, head, smile, laugh, frown, yawn, talk, mouth, words, sing, chew, eat, with, hands, eat, with, utensils, drink, from, cup, drink, from, bottle, sip, blow, kiss, hug, wave, hand, point, reach, grab, object, release, object, throw, object, catch, object, toss, ball, kick, ball, hit, with, hand, punch, block, push, pull, lift, object, carry, object, drag, object, drop, object, catch, fall, climb, up, climb, down, crawl, swim, dive, surface, float, balance, ride, bicycle, pedal, bicycle, brake, bicycle, steer, bicycle, mount, horse, dismount, horse, gallop, ride, skateboard, skate, sled, ski, snowboard, slide, skate, backward, turn, skateboard, shoot, basketball, dribble, ball, bounce, ball, serve, tennis, swing, racket, hit, tennis, ball, swing, bat, hit, baseball, throw, frisbee, catch, frisbee, juggle, spin, object, roll, ball, kick, leg, high, kick, leg, low, flip, somersault, cartwheel, handstand, headstand, touch, head, touch, face, wash, face, comb, hair, brush, hair, brush, teeth, shave, apply, makeup, put, on, hat, take, off, hat, put, on, jacket, take, off, jacket, button, shirt, zip, jacket, tie, shoelace, untie, shoelace, open, door, close, door, knock, door, enter, room, exit, room, sit, on, chair, stand, from, chair, lie, down, wake, up, sleep, sprint, start, cross, finish, line",
    "something2": "push, pull, lift, drop, hold, carry, throw, catch, slide, drag, roll, spin, rotate, flip, fold, unfold, wrap, unwrap, tie, untie, fasten, unfasten, tighten, loosen, break, cut, slice, chop, tear, peel, crumple, flatten, bend, stretch, shake, stir, pour, scoop, sprinkle, stack, unstack, assemble, disassemble, open, close, lock, unlock, press, tap, swipe, scroll, zoom in, zoom out, point, touch, wave, clap, knock, snap, swing, juggle, bounce, balance, topple, insert, remove, fill, empty, mix, separate, spill, scatter, gather, cover, uncover, hide, reveal, lean, tilt, climb, crawl, jump, hop, walk, run, sprint, stumble, fall, get up, sit, stand, kneel, crouch, bow, dance, spin dance, nod, shake head, smile, frown, laugh, cry, shout, whisper, speak, yawn, sneeze, cough, sleep, wake, eat, chew, bite, sip, drink, spit, blow, smell, taste, write, draw, erase, paint, type, click, drag mouse, plug, unplug, connect, disconnect, turn on, turn off, start, stop, accelerate, decelerate, pretend to push, pretend to pull, pretend to pour, pretend to eat, pretend to drink, pretend to throw, pretend to catch, pretend to type, pretend to swipe, pretend to scroll, pretend to climb, pretend to fall, pretend to hug, pretend to kiss, pretend to wave, pretend to play guitar, pretend to drive, pretend to steer, pretend to read, pretend to sleep, pretend to wake, pretend to write, pretend to draw, pretend to paint, pretend to clean, pretend to cook, pretend to stir, pretend to measure, pretend to weigh, pretend to look around, pretend to search, pretend to point, pretend to balance, pretend to open, pretend to close, pretend to lock, pretend to unlock, pretend to kick, pretend to punch, pretend to block, pretend to dodge, pretend to jump rope, pretend to row, pretend to paddle, pretend to shoot arrow, pretend to load gun, pretend to fire gun, pretend to throw ball, pretend to dribble, pretend to shoot basket, pretend to swing bat, pretend to serve, pretend to catch fish, pretend to steer wheel, pretend to honk, pretend to use controller, pretend to play piano, pretend to play drums, pretend to dance, pretend to sing, pretend to clap, pretend to salute, pretend to bow, pretend to shake hands, pretend to hug, pretend to kiss, object, container, box, cup, bowl, plate, spoon, knife, fork, chopstick, pen, pencil, paper, book, phone, remote, laptop, keyboard, mouse, bag, backpack, toy, ball, fruit, apple, orange, banana, grape, vegetable, carrot, cucumber, tomato, bottle, can, lid, cap, key, lock, door, window, wall, floor, table, chair, shelf, hand, finger, arm, face, person, other, background, surface, inside, outside, top, bottom, left, right, upward, downward, hot, cold, wet, dry, clean, dirty, empty, full, broken, fixed, smooth, rough, heavy, light, fragile, durable, rollable, stackable, squeezable, pourable, spillable, openable, closeable, edible, drinkable",
}


def _load_backbone(clip_model):
    """Load the pretrained backbone selected by ``clip_model``.

    Args:
        clip_model: One of ``"b32"``, ``"b16"``, ``"l14"``, ``"res50"``,
            ``"clip4clip"``, ``"siglip"``, ``"siglipl14"`` or ``"pe-l14"``.

    Returns:
        A 5-tuple ``(clip_name, model, processor, concept_model,
        concept_processor)``. The first three are passed to ``VideoEmbedder``;
        ``concept_model`` / ``concept_processor`` are what ``Create_Concepts``
        receives for the text side (they differ from ``model`` / ``processor``
        only for ``"clip4clip"`` and ``"pe-l14"``).

    Raises:
        ValueError: If ``clip_model`` is not one of the supported keys.
    """
    if clip_model in _OPENAI_CLIP_CHECKPOINTS:
        checkpoint = _OPENAI_CLIP_CHECKPOINTS[clip_model]
        model = CLIPModel.from_pretrained(checkpoint).eval()
        processor = CLIPProcessor.from_pretrained(checkpoint, use_fast=False)
        return "clip", model, processor, model, processor

    if clip_model == "res50":
        # OpenAI `clip` package; its preprocess transform acts as processor.
        model, preprocess = clip.load("RN50", device="cpu")
        return "res50", model, preprocess, model, preprocess

    if clip_model == "clip4clip":
        checkpoint = "Searchium-ai/clip4clip-webvid150k"
        model = CLIPVisionModelWithProjection.from_pretrained(checkpoint).eval()
        text_model = CLIPTextModelWithProjection.from_pretrained(checkpoint)
        tokenizer = CLIPTokenizer.from_pretrained(checkpoint)
        # The tokenizer doubles as the "processor" of the video embedder here,
        # exactly as in the original branch.
        return "clip4clip", model, tokenizer, text_model, tokenizer

    if clip_model in _SIGLIP_CHECKPOINTS:
        checkpoint = _SIGLIP_CHECKPOINTS[clip_model]
        model = AutoModel.from_pretrained(checkpoint)
        processor = AutoProcessor.from_pretrained(checkpoint)
        # For SigLIP the embedder name equals the hparams key.
        return clip_model, model, processor, model, processor

    if clip_model == "pe-l14":
        model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=True)
        processor = pe_transformer.get_image_transform(model.image_size)
        tokenizer = pe_transformer.get_text_tokenizer(model.context_length)
        # Images go through the transform, concept text through the tokenizer.
        return "pe-l14", model, processor, model, tokenizer

    raise ValueError(f"Unknown clip_model {clip_model}")


def run_experiment(hparams):
    """Run one CBM training experiment with given hyperparameters.

    The run loads the backbone, builds (or loads from cache) the video
    embedder, embeds the dataset's concept vocabulary, trains a
    ``CBMTransformer`` inside a ``MoTIF`` wrapper while logging to Weights &
    Biases, and finally pickles the trained wrapper.

    Args:
        hparams: Mapping of hyperparameters. Keys read here: ``dataset``,
            ``clip_model``, ``window_size``, ``random``, ``test_split``,
            ``transformer_layers``, ``lse_tau``, ``d``,
            ``diagonal_attention``, ``num_epochs``, ``l1_lambda``,
            ``lambda_sparse``, ``lr``, ``batch_size``, ``enforce_nonneg`` and
            ``class_weights``. The whole mapping is also stored as the wandb
            run config.

    Raises:
        KeyError: If ``hparams["dataset"]`` is not a supported dataset (or a
            required hyperparameter is missing).
        ValueError: If ``hparams["clip_model"]`` is not a supported backbone.

    Side effects:
        Writes ``../Embeddings/Videos/<Dataset>/<random>_<window>_clip_<model>.pkl``
        when no cached embedder exists, creates a wandb run in project
        ``"motif"``, and writes
        ``./Models/checkpoint_<clip_model>_<Dataset>.pkl``.
    """
    dataset = hparams["dataset"]
    clip_model = hparams["clip_model"]
    window_size = hparams["window_size"]
    # Only used to key the cache file name; renamed locally so it does not
    # shadow the stdlib `random` module name.
    random_flag = hparams["random"]

    # Unsupported datasets fail here with KeyError, before any model is loaded.
    dataset_name = _DATASET_DIRS[dataset]

    clip_name, model, processor, concept_model, concept_processor = _load_backbone(clip_model)

    # The file suffix is always the hparams key of the backbone.
    embedd_path = f"../Embeddings/Videos/{dataset_name}/{random_flag}_{window_size}_clip_{clip_model}.pkl"

    embedder = VideoEmbedder(clip_name, model, processor)
    embedder.dataset_name = dataset

    if os.path.exists(embedd_path):
        # SECURITY NOTE: pickle.load executes arbitrary code from the file;
        # only load caches produced by this project.
        # NOTE(review): the cached object replaces the freshly configured
        # embedder, including its `dataset_name` attribute.
        with open(embedd_path, "rb") as f:
            embedder = pickle.load(f)
        print("Loaded existing embedder from", embedd_path)
    else:
        # Create the cache directory up front so the expensive embedding pass
        # is not lost to a FileNotFoundError when the result is written.
        os.makedirs(os.path.dirname(embedd_path), exist_ok=True)
        folder_path = [f"../Datasets/{dataset_name}/Video_data"]
        embedder.process_data(folder_path, window_size=window_size, output_path="../Embeddings/Datasets")
        with open(embedd_path, "wb") as f:
            pickle.dump(embedder, f)

    concepts = Create_Concepts(clip_name, concept_model, concept_processor)
    # A fresh single-element list per run, as in the original code.
    text_concepts = [_CONCEPT_VOCABULARIES[dataset]]
    concepts.embedd_text(text_concepts)

    cbm_model = MoTIF(embedder, concepts)
    cbm_model.preprocess(dataset, info=hparams["test_split"], random_state=seed)

    cbm_model.model = CBMTransformer(
        cbm_model.num_concepts,
        num_classes=cbm_model.num_classes,
        transformer_layers=hparams["transformer_layers"],
        lse_tau=hparams["lse_tau"],
        dimension=hparams["d"],
        diagonal_attention=hparams["diagonal_attention"],
    )

    time_now = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    run_name = f"{dataset}_{clip_model}_{time_now}"

    # Using the run as a context manager guarantees it is finished even when
    # training or evaluation raises, instead of leaving a dangling run.
    with wandb.init(project="motif", name=run_name, config=hparams) as wandb_run:
        wandb_run.log({
            "test_split": hparams["test_split"],
        })

        cbm_model.train_model(
            num_epochs=hparams["num_epochs"],
            l1_lambda=hparams["l1_lambda"],
            lambda_sparse=hparams["lambda_sparse"],
            lr=hparams["lr"],
            batch_size=hparams["batch_size"],
            enforce_nonneg=hparams["enforce_nonneg"],
            class_weights=hparams["class_weights"],
            wandb_run=wandb_run,
            random_seed=seed,
        )
        cbm_model.zero_shot(concepts, wandb_run=wandb_run)
        mean_cbm(cbm_model, wandb_run=wandb_run)

    model_name = f"./Models/checkpoint_{clip_model}_{dataset_name}.pkl"
    os.makedirs(os.path.dirname(model_name), exist_ok=True)
    with open(model_name, "wb") as f:
        pickle.dump(cbm_model, f)
    print("cbm_model and class saved to", model_name)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import itertools

    # Hyperparameter grid: each key maps to the list of values to sweep.
    # With one value per key this expands to exactly one experiment.
    search_space = {
        "num_epochs": [100],
        "batch_size": [32],
        "lse_tau": [10.0],
        "l1_lambda": [1e-3],
        "lambda_sparse": [1e-3],
        "lr": [1e-3],
        "transformer_layers": [1],
        "diagonal_attention": [True],
        "enforce_nonneg": [True],
        "class_weights": [True],
        "weight_decay": [1e-2],
        "d": [1],
        "test_split": ["s1"],
        "window_size": [8],
        "dataset": ["hmdb51"],
        "random": [True],
        "clip_model": ["pe-l14"],
    }

    # Walk the Cartesian product of all value lists (in dict order) and start
    # one experiment per combination.
    param_names = tuple(search_space.keys())
    value_grid = tuple(search_space.values())
    for combination in itertools.product(*value_grid):
        run_experiment(dict(zip(param_names, combination)))
|
|
|