from __future__ import annotations

import os
import random
import numpy as np
import torch
import copy
from typing import List, Optional, Dict, Tuple
import cv2
from PIL import Image
import tqdm
import torch.nn as nn
import gc
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    Resize,
    CenterCrop,
    ToTensor,
    Normalize,
    InterpolationMode,
)
import math
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import wandb
import re
import pandas as pd
import glob


def init_repro(seed: int = 42, deterministic: bool = True) -> int:
    """Seed every RNG and enable determinism knobs for reproducible runs.

    Call this at the very top of your notebook/script BEFORE creating any
    model/processor/device context.

    Returns the seed so callers can stash it (e.g. ``self.seed = init_repro(...)``).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = (
        ":16:8"  # deterministic cuBLAS on Ampere+, nice default
    )
    # Single-threaded BLAS to avoid thread-count-dependent reductions.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # No-ops when CUDA is unavailable, safe to call unconditionally.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Determinism knobs (do this before any CUDA ops)
    if deterministic:
        try:
            torch.use_deterministic_algorithms(True)
        except Exception:
            # older torch may not support signature
            torch.set_deterministic(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Disable TF32 so matmul results are bit-stable across runs.
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    # Reduce threading nondeterminism
    torch.set_num_threads(1)
    return seed


def get_torch_device(prefer: Optional[str] = None) -> torch.device:
    """Pick a torch device, honoring ``prefer`` ("cuda"/"mps"/"cpu") when available.

    Falls back to cuda > mps > cpu when ``prefer`` is None or unavailable.
    """
    if prefer is not None:
        pref = prefer.lower()
        if pref == "cuda" and torch.cuda.is_available():
            return torch.device("cuda")
        if (
            pref == "mps"
            and hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
        ):
            return torch.device("mps")
        if pref == "cpu":
            return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def pad_batch_sequences(
    seqs: List[torch.Tensor], device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pad a list of [T_i, C] tensors into a batch [B, T_max, C] and return a
    key_padding_mask [B, T_max] with True for padded positions.

    Raises ValueError on an empty input list.
    (Annotation fixed: the mask is always returned, never None.)
    """
    if len(seqs) == 0:
        raise ValueError("pad_batch_sequences received empty sequence list")
    lengths = [int(s.shape[0]) for s in seqs]
    C = int(seqs[0].shape[1])
    T_max = int(max(lengths))
    B = len(seqs)
    batch = torch.zeros((B, T_max, C), dtype=torch.float32, device=device)
    mask = torch.ones((B, T_max), dtype=torch.bool, device=device)  # True=padded
    for i, s in enumerate(seqs):
        t = lengths[i]
        batch[i, :t, :] = s.to(device)
        mask[i, :t] = False
    return batch, mask


def compute_concept_standardization(seqs: List[torch.Tensor | np.ndarray]):
    """Per-channel mean/std over all time steps of all sequences.

    Std is clamped to 1e-6 to avoid division by zero for constant channels.
    """
    cat = torch.cat(
        [
            (
                s
                if isinstance(s, torch.Tensor)
                else torch.tensor(np.array(s), dtype=torch.float32)
            )
            for s in seqs
        ],
        dim=0,
    )
    mean = cat.mean(dim=0)
    std = cat.std(dim=0).clamp_min(1e-6)
    return mean, std


def apply_standardization(
    seqs: List[torch.Tensor | np.ndarray], mean: torch.Tensor, std: torch.Tensor
):
    """Apply (x - mean) / std per channel to each sequence; returns new tensors."""
    out = []
    for s in seqs:
        s_t = (
            s
            if isinstance(s, torch.Tensor)
            else torch.tensor(np.array(s), dtype=torch.float32)
        )
        out.append((s_t - mean) / std)
    return out


def concepts_over_time_cosine(
    concepts: torch.Tensor,
    all_data_list,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
    chunk_size: int | None = None,
):
    """
    Cosine-sim per frame vs concepts.
    - Normalizes in fp32 for stability, computes in fp32, then returns on CPU.
    - Optional chunked matmul to cap peak memory.
""" with torch.no_grad(): # normalize concepts in fp32 on target device c = F.normalize( concepts.detach().to(device=device, dtype=torch.float32), dim=1 ) # [K,C] K = c.shape[0] activations, embeddings = [], [] for vid in all_data_list: x = vid if isinstance(vid, torch.Tensor) else torch.as_tensor(vid) if x.ndim == 1: x = x.unsqueeze(0) elif x.ndim > 2: x = x.view(-1, x.size(-1)) x = x.detach().to(device=device, dtype=torch.float32) # [T,C] if x.numel() == 0: sim = torch.empty((0, K), dtype=torch.float32, device=device) else: x = F.normalize(x, dim=1) if chunk_size is None or x.shape[0] <= chunk_size: sim = x @ c.T # [T,K] else: # chunk over T to limit peak memory outs = [] for s in range(0, x.shape[0], chunk_size): outs.append(x[s : s + chunk_size] @ c.T) sim = torch.cat(outs, dim=0) sim = torch.clamp(sim, min=0.0) # return CPU fp32 activations.append(sim.to("cpu", dtype=dtype)) embeddings.append(vid) # keep original reference if needed return activations, embeddings class PositionalEncoding(nn.Module): """ Supports both [T, C] and [B, T, C] input tensors, automatically unsqueezing and squeezing as needed for 2D input. 
""" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000): super().__init__() self.dropout = nn.Dropout(p=dropout) position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze( 1 ) # [max_len,1] div_term = torch.exp( torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model) ) pe = torch.zeros(max_len, d_model, dtype=torch.float32) # [max_len, C] # Handle even and odd indices separately to avoid dimension mismatch pe[:, 0::2] = torch.sin(position * div_term) if d_model % 2 == 0: # Even d_model: use same div_term for cosine pe[:, 1::2] = torch.cos(position * div_term) else: # Odd d_model: need one more element for cosine div_term_cos = torch.exp( torch.arange(0, d_model - 1, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model) ) pe[:, 1::2] = torch.cos(position * div_term_cos) self.register_buffer("pe", pe) def forward(self, x: torch.Tensor): """ Handles both 2D and 3D input, automatically unsqueezing and squeezing for [T, C] input. Positional encoding is broadcast over the batch dimension. 
""" squeeze_back = False if x.dim() == 2: # [T, C] -> [1, T, C] x = x.unsqueeze(0) squeeze_back = True seq_len = x.size(1) x = x + self.pe[:seq_len, :] # broadcast over batch x = self.dropout(x) if squeeze_back: x = x.squeeze(0) return x # ------------------------- # Diagonal (per-channel) Q/K/V + per-channel FFN # ------------------------- class DiagQKVd(nn.Module): """Per-channel Q/K/V with width d (no cross-concept mixing).""" def __init__(self, C: int, d: int = 8, bias: bool = True): super().__init__() self.C, self.d = C, d # groups=C keeps channels isolated; each channel gets d features self.q = nn.Conv1d(C, C * d, 1, groups=C, bias=bias) self.k = nn.Conv1d(C, C * d, 1, groups=C, bias=bias) self.v = nn.Conv1d(C, C * d, 1, groups=C, bias=bias) def forward(self, x): # x: [B,T,C] B, T, C = x.shape xc = x.transpose(1, 2) # [B,C,T] Q = self.q(xc).transpose(1, 2).view(B, T, C, self.d) # [B,T,C,d] K = self.k(xc).transpose(1, 2).view(B, T, C, self.d) V = self.v(xc).transpose(1, 2).view(B, T, C, self.d) return Q, K, V class ChannelTimeNorm(nn.Module): def __init__(self, C, eps=1e-5, affine=True): super().__init__() self.ln = nn.LayerNorm(C, eps=eps, elementwise_affine=affine) def forward(self, x): # x: [B,T,C] return self.ln(x) class PerChannelFFN(nn.Module): """Per-channel FFN (no cross-concept mixing).""" def __init__(self, C: int, dropout: float = 0.1): super().__init__() self.fc1 = nn.Conv1d( C, C, kernel_size=1, groups=C, bias=True ) # group equals C to have no channel mixing! self.fc2 = nn.Conv1d(C, C, kernel_size=1, groups=C, bias=True) self.act = nn.GELU() self.drop = nn.Dropout(dropout) def forward(self, x: torch.Tensor): # x: [B, T, C] xc = x.transpose(1, 2) # [B, C, T] y = self.fc2(self.drop(self.act(self.fc1(xc)))) return y.transpose(1, 2) # [B, T, C] class PerChannelTemporalBlock(nn.Module): """ Attention over time for each concept channel independently. Stores attn_weights: [B, C, T, T]. 
""" def __init__(self, C: int, d: int = 1, dropout: float = 0.1, T_max: int = 1024): super().__init__() self.C, self.d = C, d self.qkv = DiagQKVd(C, d) self.scale = d**-0.5 self.logit_scale = nn.Parameter(torch.zeros(C)) # per-concept multiplier self.norm1 = ChannelTimeNorm(C) self.norm2 = ChannelTimeNorm(C) self.drop = nn.Dropout(dropout) self.ffn = PerChannelFFN(C, dropout=dropout) self.act = nn.GELU() self.attn_weights = None def forward( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: B, T, C = x.shape # Pre-attention norm y = self.norm1(x) # [B, T, C] # Per-channel QKV: Q/K/V are [B, T, C, d] Q, K, V = self.qkv(y) # Attention logits per channel: [B, C, T, T] scores = torch.einsum("btcd,bucd->bctu", Q, K) * self.scale # Optional masks if attn_mask is not None: # treat bool as additive -inf mask; float as-is if attn_mask.dtype == torch.bool: am = torch.zeros_like(attn_mask, dtype=scores.dtype) am = am.masked_fill(attn_mask, float("-inf")) else: am = attn_mask.to(dtype=scores.dtype) scores = scores + am.view(1, 1, T, T) if key_padding_mask is not None: kpm = key_padding_mask.view(B, 1, 1, T) # True = masked scores = scores.masked_fill(kpm, float("-inf")) # Softmax over source time axis w = torch.softmax(scores, dim=-1) # [B, C, T, T] self.attn_weights = w.detach() # Weighted sum of values, then reduce d out = torch.einsum("bctu,bucd->btcd", w, V).mean(dim=-1) # [B, T, C] # Residual + dropout x = x + self.drop(out) # Post-attention norm + per-channel FFN (already expects [B,T,C]) z = self.norm2(x) z = self.ffn(z) # Residual + dropout x = x + self.drop(z) return x def _pick_num_heads(C: int, proposed: Optional[int]) -> int: if proposed is not None and proposed >= 1 and C % proposed == 0: return proposed for h in [8, 6, 4, 3, 2]: if h <= C and C % h == 0: return h return 1 class FullAttentionTemporalBlock(nn.Module): """ Full multi-head self-attention over time with channel mixing 
(manual implementation). """ def __init__( self, C: int, num_heads: Optional[int] = None, dropout: float = 0.1, ffn_mult: int = 4, ): super().__init__() self.C = C self.H = _pick_num_heads(C, num_heads) self.d = C // self.H assert self.H * self.d == C, "C must be divisible by num_heads" # Projections (mix channels) self.q_proj = nn.Linear(C, C, bias=True) self.k_proj = nn.Linear(C, C, bias=True) self.v_proj = nn.Linear(C, C, bias=True) self.o_proj = nn.Linear(C, C, bias=True) self.attn_drop = nn.Dropout(dropout) self.proj_drop = nn.Dropout(dropout) self.ffn = nn.Sequential( nn.Linear(C, ffn_mult * C), nn.GELU(), nn.Dropout(dropout), nn.Linear(ffn_mult * C, C), ) self.dropout = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(C) self.norm2 = nn.LayerNorm(C) self.attn_weights = None # [B, H, T, T] def _shape_heads(self, x: torch.Tensor) -> torch.Tensor: # [B, T, C] -> [B, H, T, d] B, T, _ = x.shape return x.view(B, T, self.H, self.d).permute(0, 2, 1, 3) def forward( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, # [T, T] key_padding_mask: Optional[torch.Tensor] = None, # [B, T] ) -> torch.Tensor: assert x.dim() == 3, "x must be [B, T, C]" B, T, C = x.shape assert C == self.C # Projections Q = self._shape_heads(self.q_proj(x)) # [B,H,T,d] K = self._shape_heads(self.k_proj(x)) # [B,H,T,d] V = self._shape_heads(self.v_proj(x)) # [B,H,T,d] # Scaled dot-product attention scale = self.d**-0.5 scores = torch.matmul(Q, K.transpose(-2, -1)) * scale # [B,H,T,T] # Masks if attn_mask is not None: # bool -> additive mask; float left as-is if attn_mask.dtype == torch.bool: am = torch.zeros_like(attn_mask, dtype=Q.dtype) # 0 keep am = am.masked_fill(attn_mask, float("-inf")) else: am = attn_mask.to(dtype=Q.dtype) scores = scores + am.view(1, 1, T, T) if key_padding_mask is not None: kpm = key_padding_mask.to(torch.bool).view( B, 1, 1, T ) # broadcast on heads & queries scores = scores.masked_fill(kpm, float("-inf")) weights = F.softmax(scores, dim=-1) # [B,H,T,T] 
weights = self.attn_drop(weights) self.attn_weights = weights.detach() out = torch.matmul(weights, V) # [B,H,T,d] out = out.permute(0, 2, 1, 3).contiguous() # [B,T,H,d] out = out.view(B, T, C) # [B,T,C] out = self.o_proj(out) out = self.proj_drop(out) # Residual + norm x = self.norm1(x + out) # FFN + residual + norm ff = self.ffn(x) x = self.norm2(x + self.dropout(ff)) return x class MoTIF: """ MoTIF model for video classification using concept bottleneck models. Assumes: - concepts_over_time_cosine returns signed cosine sims (no clamp). - self.model(window_embeddings, key_padding_mask) returns (logits, concepts, concepts_t, sharpness) """ @staticmethod def _collate_pad(batch): """ batch: list of tuples (seq:[T,C] CPU float32, y:int) Returns CPU pinned tensors to enable non_blocking .to(device) """ B = len(batch) T = max(seq.shape[0] for seq, _ in batch) C = batch[0][0].shape[1] x = torch.zeros((B, T, C), dtype=torch.float32) mask = torch.ones((B, T), dtype=torch.bool) # True = padded y = torch.empty((B,), dtype=torch.long) for i, (seq, yi) in enumerate(batch): t = seq.shape[0] x[i, :t].copy_(seq) # CPU->CPU copy into pinned mask[i, :t] = False y[i] = yi return x, mask, y def __init__(self, embedder, concepts): self.device = get_torch_device(prefer="cuda") self.concepts = concepts self.all_data = embedder.video_embeddings # dict: path -> [T,C] self.all_labels = ( embedder.labels ) # list aligned with keys order (non-SSv2 case) self.video_paths = list(self.all_data.keys()) self.video_spans = embedder.video_window_spans self.concept_bank = concepts.text_embeddings self.raw_activations, self.video_embeddings = concepts_over_time_cosine( self.concept_bank, list(self.all_data.values()) ) # list of [T,C] keep_idx = [ i for i, act in enumerate(self.raw_activations) if isinstance(act, torch.Tensor) and act.shape[0] > 0 ] if len(keep_idx) != len(self.raw_activations): removed = len(self.raw_activations) - len(keep_idx) self.raw_activations = [self.raw_activations[i] for i 
in keep_idx] self.video_paths = [self.video_paths[i] for i in keep_idx] self.all_labels = [self.all_labels[i] for i in keep_idx] # non-SSv2 path self.video_embeddings = [self.video_embeddings[i] for i in keep_idx] print(f"[MoTIF] Removed {removed} entries with empty activations.") # Stable, aligned numeric IDs (for SSv2) self.video_ids = [self.path_to_id(p) for p in self.video_paths] self.kept_ids = {vid for vid in self.video_ids if vid is not None} # Defer LabelEncoder to preprocess() self.encoder = LabelEncoder() self.class_weights = None self.mean_c, self.std_c = None, None self.X_train = self.X_val = self.X_test = None self.y_train = self.y_val = self.y_test = None self.paths_train = self.paths_val = self.paths_test = None self.test_zero_shot = None # Model attached later self.model = None @staticmethod def path_to_id(p: str): base = os.path.splitext(os.path.basename(p))[0] m = re.search(r"(\d+)", base) return int(m.group(1)) if m else None # ------------------------- # Zero-shot (vectorized over frames) # ------------------------- @torch.inference_mode() def zero_shot(self, concept_embedder, wandb_run=None): assert ( self.test_zero_shot is not None and self.y_test is not None ), "Call preprocess(...) first." 
# build text prompts and text embeddings class_prompts = ["a video of " + c for c in self.encoder.classes_.tolist()] text_embedder = copy.copy(concept_embedder) text_embedder.tokenizer = concept_embedder.tokenizer text_embedder.model = concept_embedder.model text_embedder.embedd_text(class_prompts) # keep original method name # ensure device + dtype text_embeddings = text_embedder.text_embeddings.to(self.device, dtype=torch.float32) # [K, C] text_embeddings = F.normalize(text_embeddings, dim=-1) # check model type for probability transform model_name = getattr(text_embedder, "model_name", "").lower() use_siglip = "siglip" in model_name if use_siglip: # SigLIP style scaling/bias (ensure fp32) scale = text_embedder.model.logit_scale.exp().to(self.device).float() bias = text_embedder.model.logit_bias.to(self.device).float() # shape [K] or [1,K] # counters correct_pooled = 0 correct_soft_avg = 0 correct_hard_majority = 0 for idx, frames in enumerate(self.test_zero_shot): # frames -> frame embeddings [T, C] on device frame_emb = torch.as_tensor(np.array(frames), device=self.device, dtype=torch.float32) frame_emb = F.normalize(frame_emb, dim=-1) # [T, C] # pooled embedding (mean over time) [1, C] pooled_emb = F.normalize(frame_emb.mean(dim=0, keepdim=True), dim=-1) # [1, C] # raw logits if use_siglip: logits_pooled = pooled_emb @ text_embeddings.T logits_pooled = logits_pooled * scale + bias # [1, K] logits_per_frame = (frame_emb @ text_embeddings.T) * scale + bias # [T, K] probs_per_frame = logits_per_frame.sigmoid() # for soft average else: logits_pooled = pooled_emb @ text_embeddings.T # [1, K] logits_per_frame = frame_emb @ text_embeddings.T # [T, K] probs_per_frame = logits_per_frame.softmax(dim=-1) # for soft average # predictions pred_pooled = logits_pooled.argmax(dim=-1).item() # mean-pooled embedding pred_soft_avg = probs_per_frame.mean(dim=0).argmax().item() # soft voting (avg probs) per_frame_preds = logits_per_frame.argmax(dim=-1) # [T] counts = 
                torch.bincount(per_frame_preds, minlength=logits_per_frame.size(1))
            pred_hard_majority = counts.argmax().item()  # hard majority (mode)
            # ground truth
            y = int(self.y_test[idx])
            # update counters
            correct_pooled += int(pred_pooled == y)
            correct_soft_avg += int(pred_soft_avg == y)
            correct_hard_majority += int(pred_hard_majority == y)
        n = max(1, len(self.test_zero_shot))
        acc_pooled = correct_pooled / n
        acc_soft_avg = correct_soft_avg / n
        acc_hard_majority = correct_hard_majority / n
        # logging
        if wandb_run is not None:
            wandb_run.log(
                {
                    "zero_shot_acc_pooled": acc_pooled,
                    "zero_shot_acc_soft_avg": acc_soft_avg,
                    "zero_shot_acc_hard_majority": acc_hard_majority,
                }
            )
        print(
            f"[ZS] pooled={acc_pooled:.4f} | soft-avg={acc_soft_avg:.4f} | hard-majority={acc_hard_majority:.4f}"
        )
        return {
            "acc_pooled": acc_pooled,
            "acc_soft_avg": acc_soft_avg,
            "acc_hard_majority": acc_hard_majority,
        }

    # -------------------------
    # Preprocess (unchanged split logic; at end we build datasets)
    # -------------------------
    def preprocess(
        self,
        dataset: str,
        info: Optional[str] = None,
        test_size: float = 0.2,
        random_state: int = 42,
    ):
        """Build train/val/test splits, fit the label encoder, standardize
        activations, and compute class weights.

        dataset: "breakfast" | "ucf101" | "hmdb51" | "something2"; info selects
        the official split (e.g. "s1"). With info=None a stratified random
        split is used instead.
        """
        # binary_array[i] == True  -> sample i goes to train
        # binary_array[i] == False -> sample i goes to test
        binary_array = []

        def get_index(info):
            # Map split name ("s1".."s3") to the official split index; default 1.
            if info == "s1":
                index = 1
            elif info == "s2":
                index = 2
            elif info == "s3":
                index = 3
            else:
                index = 1
            return index

        if info:
            if dataset == "breakfast":
                # Participant-ID ranges per Breakfast split group.
                RANGES = {
                    "s1": range(3, 16),
                    "s2": range(16, 29),
                    "s3": range(29, 42),
                    "s4": range(42, 54),
                }

                def split_paths_by_group(paths, group_name, ranges=RANGES):
                    if group_name not in ranges:
                        raise ValueError(
                            f"Unknown group '{group_name}'. Expected one of {list(ranges)}"
                        )
                    target = ranges[group_name]
                    for p in paths:
                        # Paths matching the held-out participant IDs go to test.
                        if any(re.search(rf"P{num:02}", p) for num in target):
                            binary_array.append(False)
                        else:
                            binary_array.append(True)
                    return binary_array

                binary_array = split_paths_by_group(self.video_paths, info)
            elif dataset == "ucf101":
                index = get_index(info)
                ucf_test_list = (
                    f"../Datasets/UCF101/ucfTrainTestlist/testlist0{index}.txt"
                )
                path_list = pd.read_csv(ucf_test_list, sep=" ", header=None)
                for path in self.video_paths:
                    # Official lists reference .avi names relative to the class dir.
                    path_rel = path.split("Video_data/")[1].replace(".mp4", ".avi")
                    binary_array.append(
                        False if path_rel in path_list[0].values else True
                    )
            elif dataset == "hmdb51":
                index = get_index(info)
                labels_path = "../Datasets/HMDB/testTrainMulti_7030_splits/"
                path_text_dirs = glob.glob(os.path.join(labels_path, "*.txt"))
                path_text_dirs_idx = [p for p in path_text_dirs if f"split{index}" in p]
                path_text_dirs_idx.sort()
                # Official HMDB flags: 1 = train, 2 = test, 0 = unused.
                path_list_test, path_list_train, path_list_ignore = set(), set(), set()
                for txt_path in path_text_dirs_idx:
                    with open(txt_path, "r") as fh:
                        for line in fh:
                            name, flag = line.strip().split()
                            if flag == "2":
                                path_list_test.add(name)
                            elif flag == "0":
                                path_list_ignore.add(name)
                            else:
                                path_list_train.add(name)
                # mask: True=train, False=test, None=drop (unused or unknown).
                mask = []
                for vp in self.video_paths:
                    basename = os.path.basename(vp).replace(".mp4", ".avi")
                    if basename in path_list_test:
                        mask.append(False)
                    elif basename in path_list_train:
                        mask.append(True)
                    elif basename in path_list_ignore:
                        mask.append(None)
                    else:
                        mask.append(None)
                kept = [
                    (x, y, p, b, m)
                    for x, y, p, b, m in zip(
                        self.raw_activations,
                        self.all_labels,
                        self.video_paths,
                        self.video_embeddings,
                        mask,
                    )
                    if m is not None
                ]
                if not kept:
                    raise ValueError(
                        "HMDB split produced no usable items. Check paths and split lists."
) ( self.raw_activations, self.all_labels, self.video_paths, self.video_embeddings, mask_kept, ) = map(list, zip(*kept)) self.video_ids = [ ( int(os.path.splitext(os.path.basename(p))[0]) if os.path.splitext(os.path.basename(p))[0].isdigit() else None ) for p in self.video_paths ] self.kept_ids = {vid for vid in self.video_ids if vid is not None} binary_array = [True if m else False for m in mask_kept] elif dataset == "something2": # ===== SSv2 handling ===== def replace_something(text: str) -> str: return re.sub(r"\[(.*?)\]", r"\1", text) val_json = "../Datasets/Something2/labels/validation.json" train_json = "../Datasets/Something2/labels/train.json" test_json = "../Datasets/Something2/labels/test.json" test_csv = "../Datasets/Something2/labels/test-answers.csv" df_train = pd.read_json(train_json) df_val = pd.read_json(val_json) df_test = pd.read_json(test_json) train_ids = [int(row[0]) for row in df_train.values.tolist()] val_ids = [int(row[0]) for row in df_val.values.tolist()] test_ids = [int(row[0]) for row in df_test.values.tolist()] train_labels = [replace_something(t) for t in df_train["template"]] val_labels = [replace_something(t) for t in df_val["template"]] test_tbl = pd.read_csv( test_csv, sep=";", header=None, dtype={0: int, 1: str} ) test_labels_map = dict(zip(test_tbl[0].tolist(), test_tbl[1].tolist())) test_labels = [test_labels_map[i] for i in test_ids] id2split = {} id2split.update( {i: ("train", l) for i, l in zip(train_ids, train_labels)} ) id2split.update({i: ("val", l) for i, l in zip(val_ids, val_labels)}) id2split.update({i: ("test", l) for i, l in zip(test_ids, test_labels)}) train_x, val_x, test_x = [], [], [] train_y, val_y, test_y = [], [], [] self.test_zero_shot = [] self.paths_train, self.paths_val, self.paths_test = [], [], [] self.video_ids = [self.path_to_id(p) for p in self.video_paths] missed = 0 for idx, vid in enumerate(self.video_ids): if vid is None: missed += 1 continue entry = id2split.get(vid) if entry is None: missed += 
1 continue split, lab = entry if split == "train": train_x.append(self.raw_activations[idx]) train_y.append(lab) self.paths_train.append(self.video_paths[idx]) elif split == "val": val_x.append(self.raw_activations[idx]) val_y.append(lab) self.paths_val.append(self.video_paths[idx]) elif split == "test": test_x.append(self.raw_activations[idx]) test_y.append(lab) self.paths_test.append(self.video_paths[idx]) self.test_zero_shot.append(self.video_embeddings[idx]) if missed: print( f"[SSv2] Skipped {missed} items (no parseable ID or not in official splits)." ) if len(train_x) == 0: raise RuntimeError( "[SSv2] No training samples matched. Check filename-to-ID parsing and dataset paths." ) self.encoder = self.encoder.fit(train_y) self.X_train, self.y_train = train_x, self.encoder.transform( np.array(train_y, dtype=object) ) self.X_val, self.y_val = val_x, ( self.encoder.transform(np.array(val_y, dtype=object)) if len(val_x) else (None, None) ) self.X_test, self.y_test = test_x, ( self.encoder.transform(np.array(test_y, dtype=object)) if len(test_x) else (None, None) ) # ===== end SSv2 ===== if dataset != "something2": self.X_train = [ self.raw_activations[i] for i in range(len(self.raw_activations)) if binary_array[i] ] self.X_test = [ self.raw_activations[i] for i in range(len(self.raw_activations)) if not binary_array[i] ] self.y_train = [ self.all_labels[i] for i in range(len(self.all_labels)) if binary_array[i] ] self.y_test = [ self.all_labels[i] for i in range(len(self.all_labels)) if not binary_array[i] ] self.paths_train = [ self.video_paths[i] for i in range(len(self.video_paths)) if binary_array[i] ] self.paths_test = [ self.video_paths[i] for i in range(len(self.video_paths)) if not binary_array[i] ] self.encoder = self.encoder.fit(self.y_train) self.y_train = self.encoder.transform(self.y_train) self.y_test = self.encoder.transform(self.y_test) self.test_zero_shot = [ self.video_embeddings[i] for i in range(len(self.video_embeddings)) if not binary_array[i] 
] else: # Stratified random split (non-SSv2) ( self.X_train, self.X_test, self.y_train, self.y_test, self.paths_train, self.paths_test, ) = train_test_split( self.raw_activations, self.all_labels, self.video_paths, test_size=test_size, random_state=random_state, stratify=self.all_labels, ) self.encoder = self.encoder.fit(self.y_train) self.y_train = self.encoder.transform(self.y_train) self.y_test = self.encoder.transform(self.y_test) # ----- Standardization ----- self.mean_c, self.std_c = compute_concept_standardization(self.X_train) self.X_train = apply_standardization(self.X_train, self.mean_c, self.std_c) self.X_test = apply_standardization(self.X_test, self.mean_c, self.std_c) if self.X_val is not None: self.X_val = apply_standardization(self.X_val, self.mean_c, self.std_c) # ----- Class weights ----- classes, counts = np.unique(self.y_train, return_counts=True) self.class_weights = torch.tensor(counts.max() / counts, dtype=torch.float32) self.num_concepts = self.X_train[0].shape[-1] self.num_classes = len(classes) def train_model( self, num_epochs: int, l1_lambda: float, lambda_sparse: float, batch_size: int = 8, lr: float = 1e-4, weight_decay: float = 1e-2, enforce_nonneg: bool = True, class_weights: bool = True, wandb_run: Optional[wandb.WandbRun] = None, random_seed: int = 42, ckpt_path: Optional[str] = None, early_stopping_patience: int = 50, ): if wandb_run is not None: wandb_run.config.update( { "num_epochs": num_epochs, "l1_lambda": l1_lambda, "lambda_sparse": lambda_sparse, "lr": lr, "weight_decay": weight_decay, "batch_size": batch_size, "enforce_nonneg": enforce_nonneg, "class_weights": class_weights, "transformer_layers": self.model.transformer_layers, "lse_tau": self.model.lse_tau, "diagonal_attention": self.model.diagonal_attention, "early_stopping_patience": early_stopping_patience, } ) # move model to device self.model.to(self.device) optimizer = torch.optim.AdamW( self.model.parameters(), lr=lr, weight_decay=weight_decay ) if class_weights: 
criterion = nn.CrossEntropyLoss( weight=self.class_weights.to(self.device), label_smoothing=0.1 ) else: criterion = nn.CrossEntropyLoss(label_smoothing=0.1) num_train = len(self.X_train) best_metric = -float("inf") best_state = None best_epoch = -1 epochs_since_improvement = 0 use_early_stopping = (early_stopping_patience is not None) and ( len(self.X_test) > 0 ) for epoch in range(num_epochs): self.model.train() correct, total = 0, 0 last_loss, last_L_sparse = None, None epoch_L_sparse_sum, epoch_batches = 0.0, 0 base_seed = int(getattr(self, "seed", random_seed)) g = torch.Generator(device="cpu").manual_seed(base_seed + epoch) perm_tensor = torch.randperm(num_train, generator=g) perm = perm_tensor.tolist() for start in range(0, num_train, batch_size): end = min(start + batch_size, num_train) idx = perm[start:end] batch_seqs = [self.X_train[i] for i in idx] batch_labels = torch.tensor( [int(self.y_train[i]) for i in idx], dtype=torch.long, device=self.device, ) inputs, pad_mask = pad_batch_sequences(batch_seqs, device=self.device) optimizer.zero_grad() # updated forward: now returns sharpness logits, concepts_, concepts_t, sharpness = self.model( inputs, key_padding_mask=pad_mask ) valid = (~pad_mask).unsqueeze(-1).float() last_L_sparse = (concepts_t.abs() * valid).sum() / ( valid.sum() * concepts_t.shape[-1] ).clamp(min=1.0) ce = criterion(logits, batch_labels) l1 = l1_lambda * self.model.classifier.weight.abs().sum() loss = ce + l1 + lambda_sparse * last_L_sparse loss.backward() optimizer.step() last_loss = loss # accumulate for epoch-average L_sparse epoch_L_sparse_sum += float(last_L_sparse.detach().item()) epoch_batches += 1 if enforce_nonneg: with torch.no_grad(): self.model.classifier.weight.clamp_(min=0.0) preds = logits.argmax(dim=1) correct += int((preds == batch_labels).sum().item()) total += batch_labels.shape[0] acc = correct / max(1, total) epoch_L_sparse = epoch_L_sparse_sum / max(1, epoch_batches) # ===== evaluation ===== def evaluate(dataset_X, 
dataset_y): self.model.eval() correct, total = 0, 0 sharpness_vals = [] with torch.no_grad(): for start in range(0, len(dataset_X), batch_size): end = min(start + batch_size, len(dataset_X)) batch_seqs = [dataset_X[i] for i in range(start, end)] batch_labels = torch.tensor( [int(dataset_y[i]) for i in range(start, end)], dtype=torch.long, device=self.device, ) inputs, pad_mask = pad_batch_sequences( batch_seqs, device=self.device ) logits, _, _, sharpness = self.model( inputs, key_padding_mask=pad_mask ) preds = logits.argmax(dim=1) correct += int((preds == batch_labels).sum().item()) total += batch_labels.shape[0] for b in range(logits.shape[0]): sharpness_vals.append( { "concepts_max": float( sharpness["concepts"]["max"][b] .mean() .detach() .cpu() .item() ), "concepts_entropy": float( sharpness["concepts"]["entropy"][b] .mean() .detach() .cpu() .item() ), "logits_max": float( sharpness["logits"]["max"][b] .mean() .detach() .cpu() .item() ), "logits_entropy": float( sharpness["logits"]["entropy"][b] .mean() .detach() .cpu() .item() ), } ) acc = correct / max(1, total) if sharpness_vals: mean_sharp = { k: float(np.mean([s[k] for s in sharpness_vals])) for k in sharpness_vals[0] } else: mean_sharp = {} return acc, mean_sharp test_acc, test_sharp = ( (0.0, {}) if len(self.X_test) == 0 else evaluate(self.X_test, self.y_test) ) val_acc, val_sharp = ( (0.0, {}) if self.X_val is None else evaluate(self.X_val, self.y_val) ) metric = test_acc if len(self.X_test) > 0 else acc # ===== checkpointing ===== if metric > best_metric + 1e-8: best_metric = metric best_epoch = epoch epochs_since_improvement = 0 best_state = { k: v.detach().cpu().clone() for k, v in self.model.state_dict().items() } if ckpt_path: tmp = ckpt_path + ".tmp" torch.save(best_state, tmp) os.replace(tmp, ckpt_path) else: epochs_since_improvement += 1 # ===== wandb logging ===== if wandb_run is not None: current_lr = ( optimizer.param_groups[0]["lr"] if optimizer.param_groups else None ) log_data = { 
"epoch": epoch + 1, "train_loss": ( float(last_loss.item()) if last_loss is not None else None ), "train_acc": acc, "test_acc": test_acc, "val_acc": val_acc if self.X_val is not None else None, "L_sparse": ( float(last_L_sparse.item()) if last_L_sparse is not None else None ), "learning_rate": current_lr, "best_val_acc": best_metric, "epochs_since_improvement": epochs_since_improvement, } # add sharpness metrics for prefix, sharp in [("test_", test_sharp), ("val_", val_sharp)]: for k, v in sharp.items(): log_data[prefix + "sharp_" + k] = v wandb_run.log(log_data) if epoch % 10 == 0 or epoch == num_epochs - 1: msg_loss = ( float(last_loss.item()) if last_loss is not None else float("nan") ) msg_sparse = ( float(last_L_sparse.item()) if last_L_sparse is not None else float("nan") ) print( f"Epoch {epoch+1}/{num_epochs} | loss {msg_loss:.4f} | test_acc {test_acc:.4f} " f"| train_acc {acc:.4f} | L_sparse {msg_sparse:.4f} " f"| best_val {best_metric:.4f} | epochs_no_improve {epochs_since_improvement}" ) # early stopping if ( use_early_stopping and epochs_since_improvement >= early_stopping_patience ): print( f"[MoTIF] Early stopping triggered (no improvement for {epochs_since_improvement} epochs). Stopping at epoch {epoch+1}." ) if wandb_run is not None: wandb_run.log( { "early_stopped_epoch": epoch + 1, "early_stopping_patience": early_stopping_patience, } ) break # ===== restore best ===== if best_state is not None: self.model.load_state_dict(best_state, strict=True) self.model.eval() print( f"[MoTIF] Restored best weights from epoch {best_epoch+1} (metric={best_metric:.4f})." 
            )
        else:
            print("[MoTIF] No best_state captured (empty training?).")


# -------------------------
# PerConceptAffine + CBMTransformer using the per-channel temporal block
# -------------------------
class PerConceptAffine(nn.Module):
    """Per-concept affine rescale + shifted softplus, clamped to be non-negative.

    Each of the ``num_concepts`` channels has its own learnable ``scale`` and
    ``bias``. The softplus output is shifted by ``log(2)`` so that a zero
    pre-activation maps exactly to zero (``softplus(0) == log(2)``).
    """

    def __init__(self, num_concepts: int):
        super().__init__()
        # One scale/bias pair per concept channel.
        self.scale = nn.Parameter(torch.ones(num_concepts))
        self.bias = nn.Parameter(torch.zeros(num_concepts))

    def forward(self, x: torch.Tensor):
        ## Comment out to test no scaling and bias ablation for paper
        # softplus(x*scale + bias) - log(2): zero pre-activation -> zero output.
        y = F.softplus(x * self.scale + self.bias) - math.log(2.0)
        # The shift makes small negative values possible; clamp back to >= 0.
        return y.clamp(min=0.0)


class CBMTransformer(nn.Module):
    """Concept-bottleneck transformer over per-frame concept sequences.

    Pipeline: positional encoding -> stack of temporal blocks (per-channel
    "diagonal" attention or full attention) -> LayerNorm -> per-concept affine
    head -> linear (optionally non-negative) classifier, with log-sum-exp
    pooling over time for both concepts and logits.
    """

    def __init__(
        self,
        num_concepts: int,
        num_classes: int,
        transformer_layers: int = 1,
        dropout: float = 0.1,
        lse_tau: float = 1.0,
        nonneg_classifier: bool = False,
        diagonal_attention: bool = True,
        dimension=1,
    ):
        super().__init__()
        self.lse_tau = lse_tau
        self.diagonal_attention = diagonal_attention
        self.transformer_layers = transformer_layers
        # d_model equals the number of concepts: tokens are concept vectors.
        self.posenc = PositionalEncoding(
            d_model=num_concepts, dropout=dropout, max_len=2000
        )
        if diagonal_attention:
            # Each concept channel attends over time independently.
            self.layers = nn.ModuleList(
                [
                    PerChannelTemporalBlock(
                        C=num_concepts, dropout=dropout, d=dimension
                    )
                    for _ in range(transformer_layers)
                ]
            )
        else:
            # Standard attention mixing all channels.
            self.layers = nn.ModuleList(
                [
                    FullAttentionTemporalBlock(
                        C=num_concepts, num_heads=None, dropout=dropout
                    )
                    for _ in range(transformer_layers)
                ]
            )
        self.norm = nn.LayerNorm(num_concepts)
        self.concept_predictor = PerConceptAffine(num_concepts)
        if nonneg_classifier:
            self.classifier = NonNegativeLinear(num_concepts, num_classes)
        else:
            self.classifier = nn.Linear(num_concepts, num_classes)
        # for introspection
        self.last_time_importance = None  # [B,T] detached

    def forward(
        self,
        window_embeddings: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        channel_ids: Optional[Union[List[int], torch.Tensor]] = None,
        window_ids: Optional[Union[List[int], torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        window_embeddings: [B,T,C] or [T,C]
        key_padding_mask: [B,T] with True for padded tokens to be ignored
        channel_ids / window_ids: optional concept-intervention indices; the
            selected concept activations are zeroed before classification.
        Returns:
            logits: [B,K] pooled class logits
            concepts: [B,C] pooled concept activations
            concepts_t: [B,T,C] per-time-step concepts
            sharpness: dict with 'concepts' and 'logits' sharpness per batch
        """
        x = window_embeddings
        # Accept a single unbatched sequence and promote it to batch size 1.
        if x.dim() == 2:
            x = x.unsqueeze(0)  # [1,T,C]
        if key_padding_mask is not None and key_padding_mask.dim() == 1:
            key_padding_mask = key_padding_mask.unsqueeze(0)

        # --- transformer backbone ---
        x = self.posenc(x)  # [B,T,C]
        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)
        x = self.norm(x)  # [B,T,C]

        # --- concept predictions per time step ---
        concepts_t = self.concept_predictor(x)  # [B,T,C]

        # --- concept interventions ---
        # Zero out the intervened (time, channel) activations in place.
        if channel_ids is not None and window_ids is not None:
            concepts_t[:, window_ids, channel_ids] = 0
        elif channel_ids is not None:
            concepts_t[:, :, channel_ids] = 0
        elif window_ids is not None:
            concepts_t[:, window_ids, :] = 0

        logits_t = self.classifier(concepts_t)  # [B,T,K]
        tau = self.lse_tau

        # --- LSE pooling over time ---
        # Padded positions are set to -inf so they contribute nothing to LSE;
        # tau interpolates between mean-like (small tau) and max (large tau).
        if key_padding_mask is not None:
            concepts_t_masked = concepts_t.masked_fill(
                key_padding_mask.unsqueeze(-1), float("-inf")
            )
            logits_t_masked = logits_t.masked_fill(
                key_padding_mask.unsqueeze(-1), float("-inf")
            )
            concepts = (concepts_t_masked * tau).logsumexp(dim=1) / tau  # [B,C]
            logits = (logits_t_masked * tau).logsumexp(dim=1) / tau  # [B,K]
        else:
            concepts = (concepts_t * tau).logsumexp(dim=1) / tau
            logits = (logits_t * tau).logsumexp(dim=1) / tau

        # --- temporal importance for explanation ---
        with torch.no_grad():
            pred = logits.argmax(dim=1)  # [B]
            # NOTE(review): the gather index is [B,1,1]; torch.gather allows
            # index.size(d) <= input.size(d) for d != dim, so this yields a
            # [B,1] result, not the annotated [B,T]. If a per-time profile is
            # intended, the index likely needs .expand(-1, T, -1) — confirm.
            sel = torch.gather(logits_t, dim=2, index=pred[:, None, None]).squeeze(
                -1
            )  # [B,T]
            if key_padding_mask is not None:
                sel = sel.masked_fill(key_padding_mask, float("-inf"))
            self.last_time_importance = torch.softmax(
                sel / tau, dim=1
            ).detach()  # softmax importance

        # --- compute sharpness of LSE pooled distributions ---
        def compute_sharpness(x_t, mask=None):
            """Compute max / entropy as
            sharpness metric for batch"""
            if mask is not None:
                x_t = x_t.masked_fill(mask.unsqueeze(-1), float("-inf"))
            # Softmax over the time dimension (dim=1).
            probs = torch.softmax(tau * x_t, dim=1)
            probs = probs.clamp(min=1e-8)  # avoids log(0)
            max_prob = probs.max(dim=1).values  # [B]
            entropy = -(probs * probs.log()).sum(dim=1)
            return {"max": max_prob, "entropy": entropy}

        sharpness = {
            "concepts": compute_sharpness(concepts_t, key_padding_mask),
            "logits": compute_sharpness(logits_t, key_padding_mask),
        }
        return logits, concepts, concepts_t, sharpness

    def get_attention_maps(self):
        # list of [B, C, T, T] (detached)
        # Layers that have not recorded weights yet contribute None.
        return [
            layer.attn_weights.cpu() if layer.attn_weights is not None else None
            for layer in self.layers
        ]


def mean_cbm(model, wandb_run=None):
    """Baseline: collapse each sequence to a single vector (temporal mean, or a
    random frame when the debug flag is set) and fit a linear probe on it.

    Reads X_train/X_test, y_train/y_test, num_classes, num_concepts (and
    optionally device) from ``model``.
    """
    X_train, X_test = model.X_train.copy(), model.X_test.copy()
    y_train, y_test = model.y_train.copy(), model.y_test.copy()
    num_classes = model.num_classes
    num_concepts = model.num_concepts
    batch_size = 1
    device = getattr(model, "device", get_torch_device())

    # NOTE(review): this local shadows the stdlib `random` module within this
    # function — harmless here, but rename if `random.*` is ever needed below.
    random = False  # was for testing
    if random:

        def get_random_image(x):
            # Pick one frame uniformly at random from the sequence.
            idx = np.random.randint(0, len(x))
            return x[idx]

        # Replace each video with a random frame (as np array)
        X_train_random = [get_random_image(x) for x in X_train]
        X_test_random = [get_random_image(x) for x in X_test]
        X_train_mean = X_train_random
        X_test_mean = X_test_random
    else:
        # take mean
        X_train_mean = [torch.mean(x, axis=0) for x in X_train]  # [T,C] -> [C]
        X_test_mean = [torch.mean(x, axis=0) for x in X_test]  # [T,C] -> [C]

    # Stack into arrays before converting to torch tensors
    X_train_arr = np.stack(
        [
            t.cpu().numpy() if isinstance(t, torch.Tensor) else np.array(t)
            for t in X_train_mean
        ]
    )
    X_test_arr = np.stack(
        [
            t.cpu().numpy() if isinstance(t, torch.Tensor) else np.array(t)
            for t in X_test_mean
        ]
    )
    tensor_train = torch.tensor(X_train_arr, dtype=torch.float32, device=device)
    tensor_test = torch.tensor(X_test_arr, dtype=torch.float32, device=device)

    # train a linear model on the random/mean frames
    linear_model = nn.Linear(num_concepts,
num_classes).to(device) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(linear_model.parameters(), lr=0.001) num_epochs = 200 for epoch in range(num_epochs): linear_model.train() optimizer.zero_grad() outputs = linear_model(tensor_train) loss = criterion( outputs, torch.tensor(y_train, dtype=torch.long, device=device) ) loss.backward() optimizer.step() if wandb_run is not None: with torch.no_grad(): preds = outputs.argmax(dim=1) acc = (preds.detach().cpu().numpy() == y_train).mean() current_lr = ( optimizer.param_groups[0]["lr"] if optimizer.param_groups else None ) wandb_run.log( { "mean_train_loss": loss.item(), "mean_train_acc": acc, "mean_learning_rate": current_lr, } ) linear_model.eval() with torch.no_grad(): outputs = linear_model(tensor_test) _, predicted = torch.max(outputs, 1) accuracy = (predicted.detach().cpu().numpy() == y_test).mean() print(f"CBM accuracy test: {accuracy:.4f}") if wandb_run is not None: wandb_run.log({"mean_test_acc": accuracy}) class NonNegativeLinear: def __init__(self, in_features, out_features, bias=True): self.linear = nn.Linear(in_features, out_features, bias=bias) def forward(self, x): self.linear.weight.data.clamp_(min=0.0) return self.linear(x)