|
|
import sys
import os

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils"))

import math
import time
import pickle
import wandb
import numpy as np
import torch
import torch.nn as nn
from typing import Optional

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import (
    CLIPProcessor,
    CLIPModel,
    CLIPVisionModelWithProjection,
    CLIPTokenizer,
    CLIPTextModelWithProjection,
    AutoProcessor,
    AutoModel,
)

import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as pe_transformer

import clip

from video_embedder import VideoEmbedder

def set_seed(seed=42): |
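    """Seed the NumPy and PyTorch (including CUDA) RNGs so runs are repeatable."""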
|
|
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(42)


# Maps dataset keys (as passed to process_dataset) to folder names under ../Datasets/.
DATASET_MAP = {
    "breakfast": "Breakfast",
    "ucf101": "UCF101",
    "hmdb": "HMDB",
    "ssv2": "Something2",
    "jester": "Jester",
}
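# A new dataset only needs a key -> folder-name entry here, e.g. (hypothetical):
# DATASET_MAP["kinetics"] = "Kinetics"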
|
|
|
|
|
|
|
|
def process_dataset(
    dataset_key,
    clip_model,
    window_size=16,
    random=True,
    batch_size: int = 256,
    pe_video_batch_size: Optional[int] = None,
    pe_target_T: Optional[int] = None,
    enable_tf32: bool = True,
):
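    """Compute (or load cached) video embeddings for one dataset with one encoder.

    Looks up the dataset folder via DATASET_MAP, instantiates the requested
    CLIP-style encoder, wraps it in a VideoEmbedder, and either loads a
    previously pickled embedder from ../Embeddings/Videos/ or runs
    process_data() and pickles the result.
    """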
|
|
    dataset_name = DATASET_MAP.get(dataset_key.lower())
    if dataset_name is None:
        raise ValueError(f"Unknown dataset: {dataset_key}")

    folder_path = [f"../Datasets/{dataset_name}/Video_data"]
    output_dir = "../Embeddings/Datasets"
    # Cache file for the pickled embedder, keyed by sampling mode, window size and encoder.
    embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_{clip_model}.pkl"

    # Allow TF32 matmuls/convolutions on Ampere+ GPUs for faster inference
    # at slightly reduced precision.
    if enable_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Load the requested encoder and its matching preprocessing.
    if clip_model == "b32":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32", use_fast=True
        )
    elif clip_model == "b16":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").eval()
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch16", use_fast=True
        )
    elif clip_model == "l14":
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-large-patch14", use_fast=True
        )
    elif clip_model == "res50":
        # Original OpenAI CLIP ResNet-50 via the `clip` package.
        model, processor = clip.load("RN50", device="cuda")
    elif clip_model == "clip4clip":
        # CLIP4Clip checkpoint: vision tower for frame embeddings, plus the
        # full model and text tower from the same checkpoint.
        model = CLIPVisionModelWithProjection.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        ).eval()
        clip_full = CLIPModel.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        )
        model_text = CLIPTextModelWithProjection.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        )
        processor = CLIPTokenizer.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        )
    elif clip_model == "siglip":
        model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
    elif clip_model == "siglip2":
        model = AutoModel.from_pretrained("google/siglip2-base-patch32-256")
        processor = AutoProcessor.from_pretrained("google/siglip2-base-patch32-256")
    elif clip_model == "pe-l14":
        # Perception Encoder (PE-Core L14-336) with its own image transform
        # and text tokenizer.
        model = pe.CLIP.from_config("PE-Core-L14-336")
        processor = pe_transformer.get_image_transform(model.image_size)
        tokenizer = pe_transformer.get_text_tokenizer(model.context_length)
    else:
        raise ValueError(f"Unknown CLIP model: {clip_model}")

    # Backend-specific encoders are passed by name; plain OpenAI CLIP variants
    # are wrapped under the generic "clip" key.
    if clip_model in ("clip4clip", "siglip", "siglip2", "res50", "pe-l14"):
        embedder = VideoEmbedder(
            clip_model, model, processor,
            pe_video_batch_size=pe_video_batch_size,
            pe_target_T=pe_target_T,
        )
    else:
        embedder = VideoEmbedder("clip", model, processor)
    embedder.dataset_name = dataset_key

    # Reuse cached embeddings when available; otherwise compute and pickle them.
    embedder_loaded = False
    if os.path.exists(embedd_path):
        try:
            with open(embedd_path, "rb") as f:
                embedder = pickle.load(f)
            embedder_loaded = True
            print(f"Loaded existing embedder from {embedd_path}")
        except FileNotFoundError:
            print("Embedder file not found, creating a new one.")

    if not embedder_loaded:
        embedder.process_data(
            folder_path,
            window_size=window_size,
            output_path=output_dir,
            random=random,
            save_intermediate=True,
            batch_size=batch_size,
        )
        os.makedirs(os.path.dirname(embedd_path), exist_ok=True)
        with open(embedd_path, "wb") as f:
            pickle.dump(embedder, f)
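    # Note: pickling the embedder serializes whatever it references; if
    # VideoEmbedder keeps the model as an attribute, the checkpoint weights end
    # up inside the .pkl as well (assumption about VideoEmbedder internals).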
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example run: embed the Breakfast dataset with the PE-Core L14 encoder.
    window_size = 32
    clip_model = "pe-l14"

    process_dataset(
        "breakfast",
        clip_model,
        window_size=window_size,
        random=True,
        batch_size=256,
        pe_video_batch_size=24,
        pe_target_T=8,
        enable_tf32=True,
    )
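    # Other supported combinations run the same way, e.g. (illustrative):
    # process_dataset("ucf101", "siglip2", window_size=16, random=True, batch_size=256)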
|
|
|