import sys
import os
# Add the local utils/ directory to sys.path for local imports
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils"))
import math
import time
import pickle
import wandb
import numpy as np
import torch
import torch.nn as nn
from typing import Optional
from transformers import CLIPProcessor, CLIPModel
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPTokenizer,
    CLIPTextModelWithProjection,
)
from transformers import AutoProcessor, AutoModel # siglip
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as pe_transformer
import clip
from video_embedder import VideoEmbedder
# Set random seed
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)
# Map dataset keys to readable names
DATASET_MAP = {
    "breakfast": "Breakfast",
    "ucf101": "UCF101",
    "hmdb": "HMDB",
    "ssv2": "Something2",
    "jester": "Jester",
}
def process_dataset(dataset_key, clip_model, window_size=16, random=True,
                    batch_size: int = 256,
                    pe_video_batch_size: Optional[int] = None,
                    pe_target_T: Optional[int] = None,
                    enable_tf32: bool = True):
    """Embed one dataset's videos with the chosen backbone and cache the
    resulting VideoEmbedder as a pickle; reuses the cache when present."""
    dataset_name = DATASET_MAP.get(dataset_key.lower())
    if dataset_name is None:
        raise ValueError(f"Unknown dataset: {dataset_key}")
    folder_path = [f"../Datasets/{dataset_name}/Video_data"]
    output_dir = "../Embeddings/Datasets"
    embedd_path = f"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_{clip_model}.pkl"
    # Optional: enable TF32 for faster matmul on Ampere+ GPUs
    if enable_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # Load model & processor for the requested backbone
    if clip_model == "b32":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32", use_fast=True
        )
    elif clip_model == "b16":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").eval()
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch16", use_fast=True
        )
    elif clip_model == "l14":
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-large-patch14", use_fast=True
        )
    elif clip_model == "res50":
        # use the 'clip' module (imported at top) to load RN50
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, processor = clip.load("RN50", device=device)
    elif clip_model == "clip4clip":
        # avoid naming any variable `clip`, which would shadow the imported module
        model = CLIPVisionModelWithProjection.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        ).eval()
        clip_full = CLIPModel.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        )  # full model, renamed to avoid shadowing
        model_text = CLIPTextModelWithProjection.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        )  # text tower
        processor = CLIPTokenizer.from_pretrained(
            "Searchium-ai/clip4clip-webvid150k"
        )  # tokenizer for text
    elif clip_model == "siglip":
        model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
    elif clip_model == "siglip2":
        model = AutoModel.from_pretrained("google/siglip2-base-patch32-256")
        processor = AutoProcessor.from_pretrained("google/siglip2-base-patch32-256")
    elif clip_model == "pe-l14":
        model = pe.CLIP.from_config("PE-Core-L14-336")
        processor = pe_transformer.get_image_transform(model.image_size)
        tokenizer = pe_transformer.get_text_tokenizer(model.context_length)
    else:
        raise ValueError(f"Unknown CLIP model: {clip_model}")
    # Create embedder
    if clip_model in ("clip4clip", "siglip", "siglip2", "res50", "pe-l14"):
        embedder = VideoEmbedder(
            clip_model, model, processor,
            pe_video_batch_size=pe_video_batch_size,
            pe_target_T=pe_target_T,
        )
    else:
        embedder = VideoEmbedder("clip", model, processor)
    embedder.dataset_name = dataset_key
    # Reuse the cached embedder if a readable pickle exists; otherwise (or if
    # the cache is corrupt) process the videos and write a fresh cache.
    loaded = False
    if os.path.exists(embedd_path):
        try:
            with open(embedd_path, "rb") as f:
                embedder = pickle.load(f)
            print(f"Loaded existing embedder from {embedd_path}")
            loaded = True
        except (OSError, pickle.UnpicklingError, EOFError):
            print("Embedder cache could not be read, creating a new one.")
    if not loaded:
        embedder.process_data(
            folder_path,
            window_size=window_size,
            output_path=output_dir,
            random=random,
            save_intermediate=True,
            batch_size=batch_size,
        )
        os.makedirs(os.path.dirname(embedd_path), exist_ok=True)
        with open(embedd_path, "wb") as f:
            pickle.dump(embedder, f)
    return embedder
# Example usage
window_size = 32  # example: match your window size
clip_model = "pe-l14"
# Faster defaults for video models: a smaller target T and a larger video batch
process_dataset(
    "breakfast",
    clip_model,
    window_size=window_size,
    random=True,
    batch_size=256,  # ignored for the PE path except as an upper bound
    pe_video_batch_size=24,  # lower to 8-16 if VRAM is tight
    pe_target_T=8,  # uniformly sample each window down to 8 frames
    enable_tf32=True,
)
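# A minimal batch sketch (commented out so the single-dataset call above stays
# the default): embed every dataset in DATASET_MAP with the same backbone.
# Assumes each ../Datasets/<Name>/Video_data folder exists locally.
# for key in DATASET_MAP:
#     embedder = process_dataset(
#         key,
#         clip_model,
#         window_size=window_size,
#         random=True,
#         pe_video_batch_size=24,
#         pe_target_T=8,
#     )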