| | from __future__ import annotations |
| |
|
| | import os |
| | import random |
| | import numpy as np |
| | import torch |
| | import copy |
| |
|
| |
|
| | from typing import List, Optional, Dict, Tuple |
| |
|
| | import cv2 |
| | from PIL import Image |
| | import tqdm |
| |
|
| | import torch.nn as nn |
| | import gc |
| |
|
| | import torch.nn.functional as F |
| | from torchvision.transforms import ( |
| | Compose, |
| | Resize, |
| | CenterCrop, |
| | ToTensor, |
| | Normalize, |
| | InterpolationMode, |
| | ) |
| | import math |
| | from sklearn.preprocessing import LabelEncoder |
| | from sklearn.model_selection import train_test_split |
| | import wandb |
| | import re |
| | import pandas as pd |
| | import glob |
| |
|
def init_repro(seed: int = 42, deterministic: bool = True):
    """Seed every RNG used in the pipeline and optionally force deterministic kernels.

    Call this at the very top of your notebook/script BEFORE creating any
    model/processor/device context.

    Args:
        seed: Seed applied to Python's hash, `random`, NumPy, and all torch RNGs.
        deterministic: When True, request deterministic torch algorithms and
            disable cuDNN autotuning and TF32 matmuls.

    Returns:
        The seed that was applied (handy for logging).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Required for deterministic cuBLAS matmuls; must be set before the first
    # CUDA context is created.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        try:
            torch.use_deterministic_algorithms(True)
        except Exception:
            # Legacy fallback for very old torch versions. Guarded with
            # hasattr: torch.set_deterministic() was removed in torch >= 1.11,
            # so calling it unconditionally would raise AttributeError here
            # and mask the original failure.
            if hasattr(torch, "set_deterministic"):
                torch.set_deterministic(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    # Single-threaded CPU ops remove one more source of nondeterminism.
    torch.set_num_threads(1)

    return seed
| |
|
def get_torch_device(prefer: Optional[str] = None) -> torch.device:
    """Pick a torch device, honouring *prefer* when that backend is usable.

    Falls back through cuda -> mps -> cpu when no (available) preference is
    given; the preference string is case-insensitive.
    """
    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if prefer is not None:
        choice = prefer.lower()
        if choice == "cuda" and torch.cuda.is_available():
            return torch.device("cuda")
        if choice == "mps" and mps_ok:
            return torch.device("mps")
        if choice == "cpu":
            return torch.device("cpu")
    # No usable preference: take the best available backend.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if mps_ok:
        return torch.device("mps")
    return torch.device("cpu")
| |
|
| |
|
def pad_batch_sequences(
    seqs: List[torch.Tensor], device: torch.device
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Pad a list of [T_i, C] tensors into a batch [B, T_max, C] and return
    a key_padding_mask [B, T_max] with True for padded positions.
    """
    if not seqs:
        raise ValueError("pad_batch_sequences received empty sequence list")
    lengths = [int(seq.shape[0]) for seq in seqs]
    n_batch = len(seqs)
    t_max = int(max(lengths))
    n_ch = int(seqs[0].shape[1])
    # Zero-padded content; mask starts all-True and is cleared for real frames.
    padded = torch.zeros((n_batch, t_max, n_ch), dtype=torch.float32, device=device)
    pad_mask = torch.ones((n_batch, t_max), dtype=torch.bool, device=device)
    for row, (seq, t_len) in enumerate(zip(seqs, lengths)):
        padded[row, :t_len, :] = seq.to(device)
        pad_mask[row, :t_len] = False
    return padded, pad_mask
| |
|
| |
|
def compute_concept_standardization(seqs: List[torch.Tensor | np.ndarray]):
    """Per-concept mean/std over all frames of all sequences.

    Concatenates every [T_i, C] sequence along time and returns
    (mean [C], std [C]); std is clamped away from zero so a later
    division is always safe.
    """
    tensors = []
    for seq in seqs:
        if not isinstance(seq, torch.Tensor):
            seq = torch.tensor(np.array(seq), dtype=torch.float32)
        tensors.append(seq)
    stacked = torch.cat(tensors, dim=0)
    return stacked.mean(dim=0), stacked.std(dim=0).clamp_min(1e-6)
| |
|
| |
|
def apply_standardization(
    seqs: List[torch.Tensor | np.ndarray], mean: torch.Tensor, std: torch.Tensor
):
    """Standardize every sequence with the given per-concept mean/std.

    Non-tensor inputs are converted to float32 tensors first; returns a new
    list, inputs are not modified in place.
    """

    def _as_tensor(seq):
        if isinstance(seq, torch.Tensor):
            return seq
        return torch.tensor(np.array(seq), dtype=torch.float32)

    return [(_as_tensor(seq) - mean) / std for seq in seqs]
| |
|
| |
|
def concepts_over_time_cosine(
    concepts: torch.Tensor,
    all_data_list,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
    chunk_size: int | None = None,
):
    """
    Cosine similarity of every frame against every concept vector.

    - Normalizes and computes in fp32 for stability, then moves results to
      CPU in *dtype*.
    - Optional chunked matmul (chunk_size rows at a time) caps peak memory.
    - Negative similarities are clamped to 0 (rectified cosine).

    Returns (activations, embeddings): per-video [T, K] similarity tensors on
    CPU, plus the untouched input items in the same order.
    """
    with torch.no_grad():
        concept_mat = F.normalize(
            concepts.detach().to(device=device, dtype=torch.float32), dim=1
        )
        num_concepts = concept_mat.shape[0]

        activations, embeddings = [], []

        for vid in all_data_list:
            frames = vid if isinstance(vid, torch.Tensor) else torch.as_tensor(vid)
            # Coerce to [T, D]: a single vector becomes one frame, higher-rank
            # inputs are flattened over everything but the feature axis.
            if frames.ndim == 1:
                frames = frames.unsqueeze(0)
            elif frames.ndim > 2:
                frames = frames.view(-1, frames.size(-1))
            frames = frames.detach().to(device=device, dtype=torch.float32)

            if frames.numel() == 0:
                sims = torch.empty((0, num_concepts), dtype=torch.float32, device=device)
            else:
                frames = F.normalize(frames, dim=1)
                if chunk_size is not None and frames.shape[0] > chunk_size:
                    # Chunked matmul keeps the [rows, K] intermediate small.
                    pieces = [
                        frames[lo : lo + chunk_size] @ concept_mat.T
                        for lo in range(0, frames.shape[0], chunk_size)
                    ]
                    sims = torch.cat(pieces, dim=0)
                else:
                    sims = frames @ concept_mat.T
                # NOTE(review): this rectification contradicts the MoTIF class
                # docstring's "signed cosine sims (no clamp)" — confirm which
                # contract is intended before changing either.
                sims = torch.clamp(sims, min=0.0)

            activations.append(sims.to("cpu", dtype=dtype))
            embeddings.append(vid)

    return activations, embeddings
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    Accepts both [T, C] and [B, T, C] inputs; 2D inputs are temporarily
    promoted to a batch of one and squeezed back, and the encoding table is
    broadcast over the batch dimension.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # positions: [max_len, 1]; frequencies follow the standard
        # 1/10000^(2i/d_model) schedule.
        positions = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )

        table = torch.zeros(max_len, d_model, dtype=torch.float32)
        table[:, 0::2] = torch.sin(positions * freqs)
        if d_model % 2 == 0:
            # Even width: sine and cosine columns pair up one-to-one.
            table[:, 1::2] = torch.cos(positions * freqs)
        else:
            # Odd width: one fewer cosine column, so rebuild the frequency
            # vector with the matching length.
            freqs_cos = torch.exp(
                torch.arange(0, d_model - 1, 2, dtype=torch.float32)
                * (-math.log(10000.0) / d_model)
            )
            table[:, 1::2] = torch.cos(positions * freqs_cos)

        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor):
        """Add positional encodings for the first T positions, then dropout."""
        was_2d = x.dim() == 2
        if was_2d:
            x = x.unsqueeze(0)
        x = self.dropout(x + self.pe[: x.size(1), :])
        return x.squeeze(0) if was_2d else x
| |
|
| |
|
| | |
| | |
| | |
class DiagQKVd(nn.Module):
    """Per-channel Q/K/V projections of width d (no cross-concept mixing).

    Implemented as 1x1 grouped convolutions with groups == C, so each of the
    C channels is mapped to its own d-dimensional Q/K/V independently.
    """

    def __init__(self, C: int, d: int = 8, bias: bool = True):
        super().__init__()
        self.C, self.d = C, d
        self.q = nn.Conv1d(C, C * d, 1, groups=C, bias=bias)
        self.k = nn.Conv1d(C, C * d, 1, groups=C, bias=bias)
        self.v = nn.Conv1d(C, C * d, 1, groups=C, bias=bias)

    def forward(self, x):
        """x: [B, T, C] -> (Q, K, V), each of shape [B, T, C, d]."""
        batch, steps, chans = x.shape
        as_conv = x.transpose(1, 2)  # [B, C, T] layout expected by Conv1d

        def _project(conv):
            return conv(as_conv).transpose(1, 2).view(batch, steps, chans, self.d)

        return _project(self.q), _project(self.k), _project(self.v)
| |
|
class ChannelTimeNorm(nn.Module):
    """LayerNorm over the channel (last) dimension of a [B, T, C] tensor."""

    def __init__(self, C, eps=1e-5, affine=True):
        super().__init__()
        self.ln = nn.LayerNorm(C, eps=eps, elementwise_affine=affine)

    def forward(self, x):
        """Normalize the trailing C features at every (batch, time) position."""
        return self.ln(x)
| |
|
| |
|
class PerChannelFFN(nn.Module):
    """Per-channel feed-forward block (no cross-concept mixing).

    Both layers are 1x1 convolutions with groups == C, i.e. every channel has
    its own scalar-to-scalar affine map, with GELU and dropout in between.
    """

    def __init__(self, C: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Conv1d(C, C, kernel_size=1, groups=C, bias=True)
        self.fc2 = nn.Conv1d(C, C, kernel_size=1, groups=C, bias=True)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """x: [B, T, C] -> [B, T, C], each channel processed independently."""
        hidden = self.fc1(x.transpose(1, 2))
        hidden = self.drop(self.act(hidden))
        return self.fc2(hidden).transpose(1, 2)
| |
|
| |
|
class PerChannelTemporalBlock(nn.Module):
    """
    Attention over time for each concept channel independently.
    Stores attn_weights: [B, C, T, T].

    Pre-norm residual layout: x + drop(attn(norm1(x))), then
    x + drop(ffn(norm2(x))).  No mixing across channels anywhere.
    """

    def __init__(self, C: int, d: int = 1, dropout: float = 0.1, T_max: int = 1024):
        # T_max is accepted but unused — presumably a leftover; confirm.
        super().__init__()
        self.C, self.d = C, d
        self.qkv = DiagQKVd(C, d)
        # Standard 1/sqrt(d) attention scaling.
        self.scale = d**-0.5
        # NOTE(review): logit_scale is a parameter but never read in forward()
        # — confirm whether per-channel score scaling was intended.
        self.logit_scale = nn.Parameter(torch.zeros(C))

        self.norm1 = ChannelTimeNorm(C)
        self.norm2 = ChannelTimeNorm(C)
        self.drop = nn.Dropout(dropout)

        self.ffn = PerChannelFFN(C, dropout=dropout)

        # NOTE(review): act is unused here (PerChannelFFN has its own GELU).
        self.act = nn.GELU()

        # Overwritten on every forward with detached weights [B, C, T, T].
        self.attn_weights = None

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # x: [B, T, C]
        B, T, C = x.shape

        # Pre-norm over the channel dimension.
        y = self.norm1(x)

        # Per-channel Q/K/V, each [B, T, C, d] (no cross-channel mixing).
        Q, K, V = self.qkv(y)

        # Scores per channel: [B, C, T, T] (t = query step, u = key step).
        scores = torch.einsum("btcd,bucd->bctu", Q, K) * self.scale

        # Optional [T, T] attention mask, shared across batch and channels.
        if attn_mask is not None:
            # Boolean masks become additive -inf masks; float masks are added
            # as-is.
            if attn_mask.dtype == torch.bool:
                am = torch.zeros_like(attn_mask, dtype=scores.dtype)
                am = am.masked_fill(attn_mask, float("-inf"))
            else:
                am = attn_mask.to(dtype=scores.dtype)
            scores = scores + am.view(1, 1, T, T)

        # key_padding_mask: [B, T], True marks padded key positions.
        if key_padding_mask is not None:
            kpm = key_padding_mask.view(B, 1, 1, T)
            scores = scores.masked_fill(kpm, float("-inf"))

        # Softmax over key positions; keep a detached copy for inspection.
        w = torch.softmax(scores, dim=-1)
        self.attn_weights = w.detach()

        # Weighted sum of values, then mean over the d dimension -> [B, T, C].
        out = torch.einsum("bctu,bucd->btcd", w, V).mean(dim=-1)

        # First residual connection (around attention).
        x = x + self.drop(out)

        # Pre-norm + per-channel FFN.
        z = self.norm2(x)
        z = self.ffn(z)

        # Second residual connection (around the FFN).
        x = x + self.drop(z)
        return x
| |
|
| |
|
| | def _pick_num_heads(C: int, proposed: Optional[int]) -> int: |
| | if proposed is not None and proposed >= 1 and C % proposed == 0: |
| | return proposed |
| | for h in [8, 6, 4, 3, 2]: |
| | if h <= C and C % h == 0: |
| | return h |
| | return 1 |
| |
|
| |
|
class FullAttentionTemporalBlock(nn.Module):
    """
    Full multi-head self-attention over time with channel mixing
    (manual implementation, post-norm residual layout).
    """

    def __init__(
        self,
        C: int,
        num_heads: Optional[int] = None,
        dropout: float = 0.1,
        ffn_mult: int = 4,
    ):
        super().__init__()
        self.C = C
        self.H = _pick_num_heads(C, num_heads)
        self.d = C // self.H
        assert self.H * self.d == C, "C must be divisible by num_heads"

        # Separate Q/K/V/output projections, all C -> C.
        self.q_proj = nn.Linear(C, C, bias=True)
        self.k_proj = nn.Linear(C, C, bias=True)
        self.v_proj = nn.Linear(C, C, bias=True)
        self.o_proj = nn.Linear(C, C, bias=True)

        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

        # Standard transformer FFN with ffn_mult expansion.
        self.ffn = nn.Sequential(
            nn.Linear(C, ffn_mult * C),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_mult * C, C),
        )
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(C)
        self.norm2 = nn.LayerNorm(C)

        # Detached attention weights from the latest forward: [B, H, T, T].
        self.attn_weights = None

    def _shape_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [B, T, C] -> [B, H, T, d]."""
        batch, steps, _ = x.shape
        return x.view(batch, steps, self.H, self.d).permute(0, 2, 1, 3)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attention + FFN with post-norm residuals.

        attn_mask: [T, T], bool (True = block) or additive float.
        key_padding_mask: [B, T], True = padded key position.
        """
        assert x.dim() == 3, "x must be [B, T, C]"
        batch, steps, chans = x.shape
        assert chans == self.C

        queries = self._shape_heads(self.q_proj(x))
        keys = self._shape_heads(self.k_proj(x))
        values = self._shape_heads(self.v_proj(x))

        # Scaled dot-product scores: [B, H, T, T].
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * (self.d**-0.5)

        if attn_mask is not None:
            # Boolean masks become additive -inf masks; float masks are added
            # as-is.
            if attn_mask.dtype == torch.bool:
                additive = torch.zeros_like(attn_mask, dtype=queries.dtype)
                additive = additive.masked_fill(attn_mask, float("-inf"))
            else:
                additive = attn_mask.to(dtype=queries.dtype)
            scores = scores + additive.view(1, 1, steps, steps)

        if key_padding_mask is not None:
            pad = key_padding_mask.to(torch.bool).view(batch, 1, 1, steps)
            scores = scores.masked_fill(pad, float("-inf"))

        weights = self.attn_drop(F.softmax(scores, dim=-1))
        self.attn_weights = weights.detach()

        mixed = torch.matmul(weights, values)
        mixed = mixed.permute(0, 2, 1, 3).contiguous().view(batch, steps, chans)
        mixed = self.proj_drop(self.o_proj(mixed))

        # Post-norm residual around attention.
        x = self.norm1(x + mixed)

        # Post-norm residual around the FFN.
        return self.norm2(x + self.dropout(self.ffn(x)))
| |
|
| |
|
| | class MoTIF: |
| | """ |
| | MoTIF model for video classification using concept bottleneck models. |
| | Assumes: |
| | - concepts_over_time_cosine returns signed cosine sims (no clamp). |
| | - self.model(window_embeddings, key_padding_mask) returns (logits, concepts, concepts_t, sharpness) |
| | """ |
| |
|
| | @staticmethod |
| | def _collate_pad(batch): |
| | """ |
| | batch: list of tuples (seq:[T,C] CPU float32, y:int) |
| | Returns CPU pinned tensors to enable non_blocking .to(device) |
| | """ |
| | B = len(batch) |
| | T = max(seq.shape[0] for seq, _ in batch) |
| | C = batch[0][0].shape[1] |
| | x = torch.zeros((B, T, C), dtype=torch.float32) |
| | mask = torch.ones((B, T), dtype=torch.bool) |
| | y = torch.empty((B,), dtype=torch.long) |
| | for i, (seq, yi) in enumerate(batch): |
| | t = seq.shape[0] |
| | x[i, :t].copy_(seq) |
| | mask[i, :t] = False |
| | y[i] = yi |
| | return x, mask, y |
| |
|
    def __init__(self, embedder, concepts):
        """Precompute per-frame concept activations for every video.

        Args:
            embedder: object exposing `video_embeddings` (dict: path ->
                per-frame embeddings), `labels`, and `video_window_spans`.
                NOTE(review): assumes `labels` is ordered like the
                `video_embeddings` dict keys — confirm against the embedder.
            concepts: object exposing `text_embeddings` (concept bank [K, D]).
        """
        self.device = get_torch_device(prefer="cuda")

        self.concepts = concepts
        self.all_data = embedder.video_embeddings
        self.all_labels = (
            embedder.labels
        )
        self.video_paths = list(self.all_data.keys())
        self.video_spans = embedder.video_window_spans

        # Per-frame cosine activations vs. every concept, one [T, K] tensor
        # per video (CPU), plus the raw embeddings in matching order.
        self.concept_bank = concepts.text_embeddings
        self.raw_activations, self.video_embeddings = concepts_over_time_cosine(
            self.concept_bank, list(self.all_data.values())
        )

        # Drop videos that produced no frames/activations, keeping all the
        # parallel lists (activations, paths, labels, embeddings) aligned.
        keep_idx = [
            i
            for i, act in enumerate(self.raw_activations)
            if isinstance(act, torch.Tensor) and act.shape[0] > 0
        ]
        if len(keep_idx) != len(self.raw_activations):
            removed = len(self.raw_activations) - len(keep_idx)
            self.raw_activations = [self.raw_activations[i] for i in keep_idx]
            self.video_paths = [self.video_paths[i] for i in keep_idx]
            self.all_labels = [self.all_labels[i] for i in keep_idx]
            self.video_embeddings = [self.video_embeddings[i] for i in keep_idx]
            print(f"[MoTIF] Removed {removed} entries with empty activations.")

        # Numeric IDs parsed from filenames (None when no digits found).
        self.video_ids = [self.path_to_id(p) for p in self.video_paths]
        self.kept_ids = {vid for vid in self.video_ids if vid is not None}

        # Fitted lazily in preprocess().
        self.encoder = LabelEncoder()
        self.class_weights = None

        # Split / standardization state, populated by preprocess().
        self.mean_c, self.std_c = None, None
        self.X_train = self.X_val = self.X_test = None
        self.y_train = self.y_val = self.y_test = None
        self.paths_train = self.paths_val = self.paths_test = None
        self.test_zero_shot = None

        # Classifier module; assigned externally before train_model().
        self.model = None
| |
|
| | @staticmethod |
| | def path_to_id(p: str): |
| | base = os.path.splitext(os.path.basename(p))[0] |
| | m = re.search(r"(\d+)", base) |
| | return int(m.group(1)) if m else None |
| |
|
| | |
| | |
| | |
    @torch.inference_mode()
    def zero_shot(self, concept_embedder, wandb_run=None):
        """Zero-shot classify the test split directly from frame embeddings.

        Scores each test video against "a video of <class>" text prompts in
        three ways: (1) mean-pooled video embedding, (2) soft average of
        per-frame probabilities, (3) hard majority vote over per-frame
        argmaxes.  Relies on self.test_zero_shot[i] being aligned with
        self.y_test[i] (both are produced together in preprocess).

        Returns:
            dict with keys acc_pooled, acc_soft_avg, acc_hard_majority.
        """
        assert (
            self.test_zero_shot is not None and self.y_test is not None
        ), "Call preprocess(...) first."

        # One prompt per encoded class, in LabelEncoder order, so the argmax
        # index matches the encoded label directly.
        class_prompts = ["a video of " + c for c in self.encoder.classes_.tolist()]
        # Shallow copy: the heavy model/tokenizer stay shared; only the prompt
        # embeddings live on the copy.
        text_embedder = copy.copy(concept_embedder)
        text_embedder.tokenizer = concept_embedder.tokenizer
        text_embedder.model = concept_embedder.model
        text_embedder.embedd_text(class_prompts)

        # Normalized class text embeddings [num_classes, D] on the device.
        text_embeddings = text_embedder.text_embeddings.to(self.device, dtype=torch.float32)
        text_embeddings = F.normalize(text_embeddings, dim=-1)

        # SigLIP heads score with sigmoid(scale * cos + bias) instead of
        # softmax over classes.
        model_name = getattr(text_embedder, "model_name", "").lower()
        use_siglip = "siglip" in model_name

        if use_siglip:
            scale = text_embedder.model.logit_scale.exp().to(self.device).float()
            bias = text_embedder.model.logit_bias.to(self.device).float()

        correct_pooled = 0
        correct_soft_avg = 0
        correct_hard_majority = 0

        for idx, frames in enumerate(self.test_zero_shot):
            # frames: per-frame embeddings of one video -> [T, D] on device.
            frame_emb = torch.as_tensor(np.array(frames), device=self.device, dtype=torch.float32)
            frame_emb = F.normalize(frame_emb, dim=-1)

            # Video-level embedding: mean over frames, re-normalized.
            pooled_emb = F.normalize(frame_emb.mean(dim=0, keepdim=True), dim=-1)

            if use_siglip:
                logits_pooled = pooled_emb @ text_embeddings.T
                logits_pooled = logits_pooled * scale + bias
                logits_per_frame = (frame_emb @ text_embeddings.T) * scale + bias
                probs_per_frame = logits_per_frame.sigmoid()
            else:
                logits_pooled = pooled_emb @ text_embeddings.T
                logits_per_frame = frame_emb @ text_embeddings.T
                probs_per_frame = logits_per_frame.softmax(dim=-1)

            # (1) pooled and (2) soft-average predictions.
            pred_pooled = logits_pooled.argmax(dim=-1).item()
            pred_soft_avg = probs_per_frame.mean(dim=0).argmax().item()

            # (3) majority vote over per-frame argmax predictions.
            per_frame_preds = logits_per_frame.argmax(dim=-1)
            counts = torch.bincount(per_frame_preds, minlength=logits_per_frame.size(1))
            pred_hard_majority = counts.argmax().item()

            y = int(self.y_test[idx])

            correct_pooled += int(pred_pooled == y)
            correct_soft_avg += int(pred_soft_avg == y)
            correct_hard_majority += int(pred_hard_majority == y)

        # Guard against an empty test split (avoids division by zero).
        n = max(1, len(self.test_zero_shot))
        acc_pooled = correct_pooled / n
        acc_soft_avg = correct_soft_avg / n
        acc_hard_majority = correct_hard_majority / n

        if wandb_run is not None:
            wandb_run.log(
                {
                    "zero_shot_acc_pooled": acc_pooled,
                    "zero_shot_acc_soft_avg": acc_soft_avg,
                    "zero_shot_acc_hard_majority": acc_hard_majority,
                }
            )

        print(
            f"[ZS] pooled={acc_pooled:.4f} | soft-avg={acc_soft_avg:.4f} | hard-majority={acc_hard_majority:.4f}"
        )

        return {
            "acc_pooled": acc_pooled,
            "acc_soft_avg": acc_soft_avg,
            "acc_hard_majority": acc_hard_majority,
        }
| |
|
| | |
| | |
| | |
    def preprocess(self,
        dataset: str,
        info: Optional[str] = None,
        test_size: float = 0.2,
        random_state: int = 42,):
        """Build train/val/test splits of the concept activations.

        Args:
            dataset: "breakfast", "ucf101", "hmdb51" or "something2" select
                the official splits (when `info` is truthy); anything else —
                or info=None — falls back to a stratified random split.
            info: split identifier ("s1".."s4" for breakfast, "s1".."s3" for
                ucf101/hmdb51; ignored for something2 beyond being truthy).
            test_size, random_state: only used by the random-split fallback.

        Side effects: populates X_/y_/paths_ train/val/test, fits
        self.encoder on the training labels, computes the concept
        standardization statistics and per-class weights.
        """
        # Per-video flags: True = train item, False = test item (used by the
        # official-split datasets except something2, which splits inline).
        binary_array = []

        def get_index(info):
            # Map "s1"/"s2"/"s3" to 1/2/3; anything else defaults to split 1.
            if info == "s1":
                index = 1
            elif info == "s2":
                index = 2
            elif info == "s3":
                index = 3
            else:
                index = 1
            return index

        if info:
            if dataset == "breakfast":
                # Breakfast: participants are partitioned into four groups;
                # the selected group becomes the test set.
                RANGES = {
                    "s1": range(3, 16),
                    "s2": range(16, 29),
                    "s3": range(29, 42),
                    "s4": range(42, 54),
                }

                def split_paths_by_group(paths, group_name, ranges=RANGES):
                    # Appends into the enclosing binary_array (closure) and
                    # returns it: False = path matches a test participant.
                    if group_name not in ranges:
                        raise ValueError(
                            f"Unknown group '{group_name}'. Expected one of {list(ranges)}"
                        )
                    target = ranges[group_name]
                    for p in paths:
                        if any(re.search(rf"P{num:02}", p) for num in target):
                            binary_array.append(False)
                        else:
                            binary_array.append(True)
                    return binary_array

                binary_array = split_paths_by_group(self.video_paths, info)

            elif dataset == "ucf101":
                # UCF101: official test lists name .avi files relative to the
                # dataset root; local paths are .mp4 under Video_data/.
                index = get_index(info)
                ucf_test_list = (
                    f"../Datasets/UCF101/ucfTrainTestlist/testlist0{index}.txt"
                )
                path_list = pd.read_csv(ucf_test_list, sep=" ", header=None)
                for path in self.video_paths:
                    path_rel = path.split("Video_data/")[1].replace(".mp4", ".avi")
                    binary_array.append(
                        False if path_rel in path_list[0].values else True
                    )

            elif dataset == "hmdb51":
                # HMDB51 split files flag each clip: 1 = train, 2 = test,
                # 0 = unused.
                index = get_index(info)
                labels_path = "../Datasets/HMDB/testTrainMulti_7030_splits/"
                path_text_dirs = glob.glob(os.path.join(labels_path, "*.txt"))
                path_text_dirs_idx = [p for p in path_text_dirs if f"split{index}" in p]
                path_text_dirs_idx.sort()
                path_list_test, path_list_train, path_list_ignore = set(), set(), set()
                for txt_path in path_text_dirs_idx:
                    with open(txt_path, "r") as fh:
                        for line in fh:
                            name, flag = line.strip().split()
                            if flag == "2":
                                path_list_test.add(name)
                            elif flag == "0":
                                path_list_ignore.add(name)
                            else:
                                path_list_train.add(name)
                # mask per video: True = train, False = test, None = drop
                # (unused in this split or not listed at all).
                mask = []
                for vp in self.video_paths:
                    basename = os.path.basename(vp).replace(".mp4", ".avi")
                    if basename in path_list_test:
                        mask.append(False)
                    elif basename in path_list_train:
                        mask.append(True)
                    elif basename in path_list_ignore:
                        mask.append(None)
                    else:
                        mask.append(None)
                # Drop None entries while keeping all parallel lists aligned.
                kept = [
                    (x, y, p, b, m)
                    for x, y, p, b, m in zip(
                        self.raw_activations,
                        self.all_labels,
                        self.video_paths,
                        self.video_embeddings,
                        mask,
                    )
                    if m is not None
                ]
                if not kept:
                    raise ValueError(
                        "HMDB split produced no usable items. Check paths and split lists."
                    )
                (
                    self.raw_activations,
                    self.all_labels,
                    self.video_paths,
                    self.video_embeddings,
                    mask_kept,
                ) = map(list, zip(*kept))
                # Rebuild numeric IDs from the filtered path list.  NOTE:
                # unlike path_to_id, this requires the whole stem to be digits.
                self.video_ids = [
                    (
                        int(os.path.splitext(os.path.basename(p))[0])
                        if os.path.splitext(os.path.basename(p))[0].isdigit()
                        else None
                    )
                    for p in self.video_paths
                ]
                self.kept_ids = {vid for vid in self.video_ids if vid is not None}
                binary_array = [True if m else False for m in mask_kept]

            elif dataset == "something2":
                # SSv2 labels look like "Doing [something]" — strip brackets.
                def replace_something(text: str) -> str:
                    return re.sub(r"\[(.*?)\]", r"\1", text)

                val_json = "../Datasets/Something2/labels/validation.json"
                train_json = "../Datasets/Something2/labels/train.json"
                test_json = "../Datasets/Something2/labels/test.json"
                test_csv = "../Datasets/Something2/labels/test-answers.csv"

                df_train = pd.read_json(train_json)
                df_val = pd.read_json(val_json)
                df_test = pd.read_json(test_json)
                train_ids = [int(row[0]) for row in df_train.values.tolist()]
                val_ids = [int(row[0]) for row in df_val.values.tolist()]
                test_ids = [int(row[0]) for row in df_test.values.tolist()]
                train_labels = [replace_something(t) for t in df_train["template"]]
                val_labels = [replace_something(t) for t in df_val["template"]]
                # Test labels come from the separate answers CSV.
                test_tbl = pd.read_csv(
                    test_csv, sep=";", header=None, dtype={0: int, 1: str}
                )
                test_labels_map = dict(zip(test_tbl[0].tolist(), test_tbl[1].tolist()))
                test_labels = [test_labels_map[i] for i in test_ids]
                # One lookup: video id -> (split name, label string).
                id2split = {}
                id2split.update(
                    {i: ("train", l) for i, l in zip(train_ids, train_labels)}
                )
                id2split.update({i: ("val", l) for i, l in zip(val_ids, val_labels)})
                id2split.update({i: ("test", l) for i, l in zip(test_ids, test_labels)})

                # Route every local video into the official split by its ID;
                # test items also collect raw embeddings for zero_shot().
                train_x, val_x, test_x = [], [], []
                train_y, val_y, test_y = [], [], []
                self.test_zero_shot = []
                self.paths_train, self.paths_val, self.paths_test = [], [], []
                self.video_ids = [self.path_to_id(p) for p in self.video_paths]
                missed = 0
                for idx, vid in enumerate(self.video_ids):
                    if vid is None:
                        missed += 1
                        continue
                    entry = id2split.get(vid)
                    if entry is None:
                        missed += 1
                        continue
                    split, lab = entry
                    if split == "train":
                        train_x.append(self.raw_activations[idx])
                        train_y.append(lab)
                        self.paths_train.append(self.video_paths[idx])
                    elif split == "val":
                        val_x.append(self.raw_activations[idx])
                        val_y.append(lab)
                        self.paths_val.append(self.video_paths[idx])
                    elif split == "test":
                        test_x.append(self.raw_activations[idx])
                        test_y.append(lab)
                        self.paths_test.append(self.video_paths[idx])
                        self.test_zero_shot.append(self.video_embeddings[idx])
                if missed:
                    print(
                        f"[SSv2] Skipped {missed} items (no parseable ID or not in official splits)."
                    )

                if len(train_x) == 0:
                    raise RuntimeError(
                        "[SSv2] No training samples matched. Check filename-to-ID parsing and dataset paths."
                    )

                # Encode string labels with classes learned from train only.
                # NOTE(review): when val_x/test_x is empty the corresponding
                # y_* becomes the tuple (None, None) — confirm downstream
                # handling.
                self.encoder = self.encoder.fit(train_y)
                self.X_train, self.y_train = train_x, self.encoder.transform(
                    np.array(train_y, dtype=object)
                )
                self.X_val, self.y_val = val_x, (
                    self.encoder.transform(np.array(val_y, dtype=object))
                    if len(val_x)
                    else (None, None)
                )
                self.X_test, self.y_test = test_x, (
                    self.encoder.transform(np.array(test_y, dtype=object))
                    if len(test_x)
                    else (None, None)
                )

            # For everything except SSv2 (which assigned its splits above),
            # materialize train/test from the per-video binary_array flags.
            if dataset != "something2":
                self.X_train = [
                    self.raw_activations[i]
                    for i in range(len(self.raw_activations))
                    if binary_array[i]
                ]
                self.X_test = [
                    self.raw_activations[i]
                    for i in range(len(self.raw_activations))
                    if not binary_array[i]
                ]
                self.y_train = [
                    self.all_labels[i]
                    for i in range(len(self.all_labels))
                    if binary_array[i]
                ]
                self.y_test = [
                    self.all_labels[i]
                    for i in range(len(self.all_labels))
                    if not binary_array[i]
                ]
                self.paths_train = [
                    self.video_paths[i]
                    for i in range(len(self.video_paths))
                    if binary_array[i]
                ]
                self.paths_test = [
                    self.video_paths[i]
                    for i in range(len(self.video_paths))
                    if not binary_array[i]
                ]
                self.encoder = self.encoder.fit(self.y_train)
                self.y_train = self.encoder.transform(self.y_train)
                self.y_test = self.encoder.transform(self.y_test)
                # Raw embeddings of test videos, for zero_shot().
                self.test_zero_shot = [
                    self.video_embeddings[i]
                    for i in range(len(self.video_embeddings))
                    if not binary_array[i]
                ]

        else:
            # Fallback: stratified random train/test split (no val split).
            (
                self.X_train,
                self.X_test,
                self.y_train,
                self.y_test,
                self.paths_train,
                self.paths_test,
            ) = train_test_split(
                self.raw_activations,
                self.all_labels,
                self.video_paths,
                test_size=test_size,
                random_state=random_state,
                stratify=self.all_labels,
            )
            self.encoder = self.encoder.fit(self.y_train)
            self.y_train = self.encoder.transform(self.y_train)
            self.y_test = self.encoder.transform(self.y_test)

        # Standardize activations with train-set statistics only.
        self.mean_c, self.std_c = compute_concept_standardization(self.X_train)
        self.X_train = apply_standardization(self.X_train, self.mean_c, self.std_c)
        self.X_test = apply_standardization(self.X_test, self.mean_c, self.std_c)
        if self.X_val is not None:
            self.X_val = apply_standardization(self.X_val, self.mean_c, self.std_c)

        # Inverse-frequency class weights (majority class gets weight 1.0).
        classes, counts = np.unique(self.y_train, return_counts=True)
        self.class_weights = torch.tensor(counts.max() / counts, dtype=torch.float32)
        self.num_concepts = self.X_train[0].shape[-1]
        self.num_classes = len(classes)
| |
|
| | def train_model( |
| | self, |
| | num_epochs: int, |
| | l1_lambda: float, |
| | lambda_sparse: float, |
| | batch_size: int = 8, |
| | lr: float = 1e-4, |
| | weight_decay: float = 1e-2, |
| | enforce_nonneg: bool = True, |
| | class_weights: bool = True, |
| | wandb_run: Optional[wandb.WandbRun] = None, |
| | random_seed: int = 42, |
| | ckpt_path: Optional[str] = None, |
| | early_stopping_patience: int = 50, |
| | ): |
| |
|
| | if wandb_run is not None: |
| | wandb_run.config.update( |
| | { |
| | "num_epochs": num_epochs, |
| | "l1_lambda": l1_lambda, |
| | "lambda_sparse": lambda_sparse, |
| | "lr": lr, |
| | "weight_decay": weight_decay, |
| | "batch_size": batch_size, |
| | "enforce_nonneg": enforce_nonneg, |
| | "class_weights": class_weights, |
| | "transformer_layers": self.model.transformer_layers, |
| | "lse_tau": self.model.lse_tau, |
| | "diagonal_attention": self.model.diagonal_attention, |
| | "early_stopping_patience": early_stopping_patience, |
| | } |
| | ) |
| |
|
| | |
| | self.model.to(self.device) |
| | optimizer = torch.optim.AdamW( |
| | self.model.parameters(), lr=lr, weight_decay=weight_decay |
| | ) |
| | if class_weights: |
| | criterion = nn.CrossEntropyLoss( |
| | weight=self.class_weights.to(self.device), label_smoothing=0.1 |
| | ) |
| | else: |
| | criterion = nn.CrossEntropyLoss(label_smoothing=0.1) |
| |
|
| | num_train = len(self.X_train) |
| |
|
| | best_metric = -float("inf") |
| | best_state = None |
| | best_epoch = -1 |
| | epochs_since_improvement = 0 |
| | use_early_stopping = (early_stopping_patience is not None) and ( |
| | len(self.X_test) > 0 |
| | ) |
| |
|
| | for epoch in range(num_epochs): |
| | self.model.train() |
| | correct, total = 0, 0 |
| | last_loss, last_L_sparse = None, None |
| | epoch_L_sparse_sum, epoch_batches = 0.0, 0 |
| |
|
| | base_seed = int(getattr(self, "seed", random_seed)) |
| | g = torch.Generator(device="cpu").manual_seed(base_seed + epoch) |
| | perm_tensor = torch.randperm(num_train, generator=g) |
| | perm = perm_tensor.tolist() |
| |
|
| | for start in range(0, num_train, batch_size): |
| | end = min(start + batch_size, num_train) |
| | idx = perm[start:end] |
| | batch_seqs = [self.X_train[i] for i in idx] |
| | batch_labels = torch.tensor( |
| | [int(self.y_train[i]) for i in idx], |
| | dtype=torch.long, |
| | device=self.device, |
| | ) |
| |
|
| | inputs, pad_mask = pad_batch_sequences(batch_seqs, device=self.device) |
| | optimizer.zero_grad() |
| |
|
| | |
| | logits, concepts_, concepts_t, sharpness = self.model( |
| | inputs, key_padding_mask=pad_mask |
| | ) |
| |
|
| | valid = (~pad_mask).unsqueeze(-1).float() |
| | last_L_sparse = (concepts_t.abs() * valid).sum() / ( |
| | valid.sum() * concepts_t.shape[-1] |
| | ).clamp(min=1.0) |
| |
|
| | ce = criterion(logits, batch_labels) |
| | l1 = l1_lambda * self.model.classifier.weight.abs().sum() |
| | loss = ce + l1 + lambda_sparse * last_L_sparse |
| | loss.backward() |
| | optimizer.step() |
| | last_loss = loss |
| |
|
| | |
| | epoch_L_sparse_sum += float(last_L_sparse.detach().item()) |
| | epoch_batches += 1 |
| |
|
| | if enforce_nonneg: |
| | with torch.no_grad(): |
| | self.model.classifier.weight.clamp_(min=0.0) |
| |
|
| | preds = logits.argmax(dim=1) |
| | correct += int((preds == batch_labels).sum().item()) |
| | total += batch_labels.shape[0] |
| |
|
| | acc = correct / max(1, total) |
| | epoch_L_sparse = epoch_L_sparse_sum / max(1, epoch_batches) |
| |
|
| | |
| | def evaluate(dataset_X, dataset_y): |
| | self.model.eval() |
| | correct, total = 0, 0 |
| | sharpness_vals = [] |
| | with torch.no_grad(): |
| | for start in range(0, len(dataset_X), batch_size): |
| | end = min(start + batch_size, len(dataset_X)) |
| | batch_seqs = [dataset_X[i] for i in range(start, end)] |
| | batch_labels = torch.tensor( |
| | [int(dataset_y[i]) for i in range(start, end)], |
| | dtype=torch.long, |
| | device=self.device, |
| | ) |
| | inputs, pad_mask = pad_batch_sequences( |
| | batch_seqs, device=self.device |
| | ) |
| |
|
| | logits, _, _, sharpness = self.model( |
| | inputs, key_padding_mask=pad_mask |
| | ) |
| | preds = logits.argmax(dim=1) |
| | correct += int((preds == batch_labels).sum().item()) |
| | total += batch_labels.shape[0] |
| |
|
| | for b in range(logits.shape[0]): |
| | sharpness_vals.append( |
| | { |
| | "concepts_max": float( |
| | sharpness["concepts"]["max"][b] |
| | .mean() |
| | .detach() |
| | .cpu() |
| | .item() |
| | ), |
| | "concepts_entropy": float( |
| | sharpness["concepts"]["entropy"][b] |
| | .mean() |
| | .detach() |
| | .cpu() |
| | .item() |
| | ), |
| | "logits_max": float( |
| | sharpness["logits"]["max"][b] |
| | .mean() |
| | .detach() |
| | .cpu() |
| | .item() |
| | ), |
| | "logits_entropy": float( |
| | sharpness["logits"]["entropy"][b] |
| | .mean() |
| | .detach() |
| | .cpu() |
| | .item() |
| | ), |
| | } |
| | ) |
| |
|
| | acc = correct / max(1, total) |
| | if sharpness_vals: |
| | mean_sharp = { |
| | k: float(np.mean([s[k] for s in sharpness_vals])) |
| | for k in sharpness_vals[0] |
| | } |
| | else: |
| | mean_sharp = {} |
| | return acc, mean_sharp |
| |
|
| | test_acc, test_sharp = ( |
| | (0.0, {}) |
| | if len(self.X_test) == 0 |
| | else evaluate(self.X_test, self.y_test) |
| | ) |
| | val_acc, val_sharp = ( |
| | (0.0, {}) if self.X_val is None else evaluate(self.X_val, self.y_val) |
| | ) |
| |
|
| | metric = test_acc if len(self.X_test) > 0 else acc |
| |
|
| | |
| | if metric > best_metric + 1e-8: |
| | best_metric = metric |
| | best_epoch = epoch |
| | epochs_since_improvement = 0 |
| | best_state = { |
| | k: v.detach().cpu().clone() |
| | for k, v in self.model.state_dict().items() |
| | } |
| | if ckpt_path: |
| | tmp = ckpt_path + ".tmp" |
| | torch.save(best_state, tmp) |
| | os.replace(tmp, ckpt_path) |
| | else: |
| | epochs_since_improvement += 1 |
| |
|
| | |
| | if wandb_run is not None: |
| | current_lr = ( |
| | optimizer.param_groups[0]["lr"] if optimizer.param_groups else None |
| | ) |
| | log_data = { |
| | "epoch": epoch + 1, |
| | "train_loss": ( |
| | float(last_loss.item()) if last_loss is not None else None |
| | ), |
| | "train_acc": acc, |
| | "test_acc": test_acc, |
| | "val_acc": val_acc if self.X_val is not None else None, |
| | "L_sparse": ( |
| | float(last_L_sparse.item()) |
| | if last_L_sparse is not None |
| | else None |
| | ), |
| | "learning_rate": current_lr, |
| | "best_val_acc": best_metric, |
| | "epochs_since_improvement": epochs_since_improvement, |
| | } |
| | |
| | for prefix, sharp in [("test_", test_sharp), ("val_", val_sharp)]: |
| | for k, v in sharp.items(): |
| | log_data[prefix + "sharp_" + k] = v |
| | wandb_run.log(log_data) |
| |
|
| | if epoch % 10 == 0 or epoch == num_epochs - 1: |
| | msg_loss = ( |
| | float(last_loss.item()) if last_loss is not None else float("nan") |
| | ) |
| | msg_sparse = ( |
| | float(last_L_sparse.item()) |
| | if last_L_sparse is not None |
| | else float("nan") |
| | ) |
| | print( |
| | f"Epoch {epoch+1}/{num_epochs} | loss {msg_loss:.4f} | test_acc {test_acc:.4f} " |
| | f"| train_acc {acc:.4f} | L_sparse {msg_sparse:.4f} " |
| | f"| best_val {best_metric:.4f} | epochs_no_improve {epochs_since_improvement}" |
| | ) |
| |
|
| | |
| | if ( |
| | use_early_stopping |
| | and epochs_since_improvement >= early_stopping_patience |
| | ): |
| | print( |
| | f"[MoTIF] Early stopping triggered (no improvement for {epochs_since_improvement} epochs). Stopping at epoch {epoch+1}." |
| | ) |
| | if wandb_run is not None: |
| | wandb_run.log( |
| | { |
| | "early_stopped_epoch": epoch + 1, |
| | "early_stopping_patience": early_stopping_patience, |
| | } |
| | ) |
| | break |
| |
|
| | |
| | if best_state is not None: |
| | self.model.load_state_dict(best_state, strict=True) |
| | self.model.eval() |
| | print( |
| | f"[MoTIF] Restored best weights from epoch {best_epoch+1} (metric={best_metric:.4f})." |
| | ) |
| | else: |
| | print("[MoTIF] No best_state captured (empty training?).") |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class PerConceptAffine(nn.Module):
    """Learnable per-concept affine map followed by a shifted softplus.

    Each concept channel ``c`` is transformed as
        y_c = max(0, softplus(scale_c * x_c + bias_c) - log 2),
    so the default parameters (scale=1, bias=0) map an input of 0 to
    exactly 0 while keeping all outputs non-negative.
    """

    def __init__(self, num_concepts: int):
        super().__init__()
        # One scale/bias pair per concept channel.
        self.scale = nn.Parameter(torch.ones(num_concepts))
        self.bias = nn.Parameter(torch.zeros(num_concepts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        affine = self.scale * x + self.bias
        # softplus(0) == log 2, so subtracting it zero-centers the default.
        shifted = F.softplus(affine) - math.log(2.0)
        return torch.clamp(shifted, min=0.0)
| |
|
| |
|
class CBMTransformer(nn.Module):
    """Concept-Bottleneck Model over temporal window embeddings.

    Pipeline: positional encoding -> stack of temporal blocks (per-channel
    "diagonal" attention or full attention) -> LayerNorm -> per-concept
    affine head -> linear classifier, with smooth-max (log-sum-exp) pooling
    over time for both concepts and class logits.
    """

    def __init__(
        self,
        num_concepts: int,
        num_classes: int,
        transformer_layers: int = 1,
        dropout: float = 0.1,
        lse_tau: float = 1.0,
        nonneg_classifier: bool = False,
        diagonal_attention: bool = True,
        dimension=1,
    ):
        super().__init__()
        self.lse_tau = lse_tau  # temperature of the log-sum-exp pooling
        self.diagonal_attention = diagonal_attention
        self.transformer_layers = transformer_layers

        self.posenc = PositionalEncoding(
            d_model=num_concepts, dropout=dropout, max_len=2000
        )
        if diagonal_attention:
            # Each concept channel attends only over its own time series.
            self.layers = nn.ModuleList(
                [
                    PerChannelTemporalBlock(
                        C=num_concepts, dropout=dropout, d=dimension
                    )
                    for _ in range(transformer_layers)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    FullAttentionTemporalBlock(
                        C=num_concepts, num_heads=None, dropout=dropout
                    )
                    for _ in range(transformer_layers)
                ]
            )
        self.norm = nn.LayerNorm(num_concepts)
        self.concept_predictor = PerConceptAffine(num_concepts)

        if nonneg_classifier:
            self.classifier = NonNegativeLinear(num_concepts, num_classes)
        else:
            self.classifier = nn.Linear(num_concepts, num_classes)

        # [B, T] softmax-over-time of the predicted class's per-step logit;
        # refreshed on every forward pass (interpretability aid).
        self.last_time_importance = None

    def forward(
        self,
        window_embeddings: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        channel_ids: List[int] | torch.Tensor | None = None,
        window_ids: List[int] | torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
        """
        Args:
            window_embeddings: [B,T,C] or [T,C] window embeddings.
            key_padding_mask: [B,T], True for padded tokens to be ignored.
            channel_ids / window_ids: optional intervention — the selected
                concept channels and/or time steps are zeroed before
                classification.

        Returns:
            logits: [B,K] pooled class logits
            concepts: [B,C] pooled concept activations
            concepts_t: [B,T,C] per-time-step concepts
            sharpness: dict with 'concepts' and 'logits' sharpness per batch
        """
        x = window_embeddings
        if x.dim() == 2:
            # Promote unbatched [T, C] input to a batch of one.
            x = x.unsqueeze(0)
        if key_padding_mask is not None and key_padding_mask.dim() == 1:
            key_padding_mask = key_padding_mask.unsqueeze(0)

        x = self.posenc(x)
        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)
        x = self.norm(x)

        # Per-time-step concept activations (non-negative by construction).
        concepts_t = self.concept_predictor(x)

        # Optional intervention: zero selected concepts and/or time steps.
        # NOTE(review): this mutates concepts_t in place; presumably used at
        # evaluation time for ablations — confirm before intervening in a
        # training graph.
        if channel_ids is not None and window_ids is not None:
            concepts_t[:, window_ids, channel_ids] = 0
        elif channel_ids is not None:
            concepts_t[:, :, channel_ids] = 0
        elif window_ids is not None:
            concepts_t[:, window_ids, :] = 0

        logits_t = self.classifier(concepts_t)

        tau = self.lse_tau

        # Smooth-max pooling over time; padded positions are pushed to -inf
        # so they contribute nothing to the logsumexp.
        if key_padding_mask is not None:
            concepts_t_masked = concepts_t.masked_fill(
                key_padding_mask.unsqueeze(-1), float("-inf")
            )
            logits_t_masked = logits_t.masked_fill(
                key_padding_mask.unsqueeze(-1), float("-inf")
            )

            concepts = (concepts_t_masked * tau).logsumexp(dim=1) / tau
            logits = (logits_t_masked * tau).logsumexp(dim=1) / tau
        else:
            concepts = (concepts_t * tau).logsumexp(dim=1) / tau
            logits = (logits_t * tau).logsumexp(dim=1) / tau

        # Time-importance map for the predicted class.
        with torch.no_grad():
            pred = logits.argmax(dim=1)  # [B]
            # BUGFIX: the gather index must cover every time step. The old
            # [B,1,1] index produced a [B,1] selection, which broke the
            # subsequent masked_fill against the [B,T] padding mask.
            idx = pred[:, None, None].expand(-1, logits_t.shape[1], -1)
            sel = torch.gather(logits_t, dim=2, index=idx).squeeze(-1)  # [B,T]
            if key_padding_mask is not None:
                sel = sel.masked_fill(key_padding_mask, float("-inf"))
            self.last_time_importance = torch.softmax(
                sel / tau, dim=1
            ).detach()

        def compute_sharpness(x_t, mask=None):
            """Max probability and entropy of softmax(tau * x) over time."""
            if mask is not None:
                x_t = x_t.masked_fill(mask.unsqueeze(-1), float("-inf"))
            probs = torch.softmax(tau * x_t, dim=1)
            probs = probs.clamp(min=1e-8)  # guard log(0) on masked positions
            max_prob = probs.max(dim=1).values
            entropy = -(probs * probs.log()).sum(dim=1)
            return {"max": max_prob, "entropy": entropy}

        sharpness = {
            "concepts": compute_sharpness(concepts_t, key_padding_mask),
            "logits": compute_sharpness(logits_t, key_padding_mask),
        }

        return logits, concepts, concepts_t, sharpness

    def get_attention_maps(self):
        """Per-layer attention weights from the last forward pass (None if unset)."""
        return [
            layer.attn_weights.cpu() if layer.attn_weights is not None else None
            for layer in self.layers
        ]
| |
|
| |
|
def mean_cbm(model, wandb_run=None, use_random=False):
    """Train a linear-probe CBM baseline on time-collapsed sequences.

    Each variable-length sequence [T, C] is reduced to a single [C] vector
    (the mean over time, or one uniformly random frame when
    ``use_random=True``), then a plain linear layer is fit from concepts to
    classes with Adam for 200 epochs.

    Args:
        model: object exposing X_train/X_test (lists of [T, C] tensors),
            y_train/y_test (integer label arrays), num_classes,
            num_concepts, and optionally device.
        wandb_run: optional wandb run for per-epoch and final-metric logging.
        use_random: pick one random frame per sequence instead of the mean
            (replaces the old hard-coded ``random = False`` local, which
            also shadowed the imported ``random`` module).

    Returns:
        Test-set accuracy as a float (also printed / logged).
    """
    X_train, X_test = model.X_train.copy(), model.X_test.copy()
    y_train, y_test = model.y_train.copy(), model.y_test.copy()
    num_classes = model.num_classes
    num_concepts = model.num_concepts

    # Only fall back to autodetection when the model carries no device;
    # getattr(..., get_torch_device()) evaluated the fallback eagerly.
    device = model.device if hasattr(model, "device") else get_torch_device()

    if use_random:

        def pick_random_frame(seq):
            # One uniformly random time step per sequence.
            return seq[np.random.randint(0, len(seq))]

        X_train_mean = [pick_random_frame(x) for x in X_train]
        X_test_mean = [pick_random_frame(x) for x in X_test]
    else:
        # Collapse the time axis by averaging.
        X_train_mean = [torch.mean(x, dim=0) for x in X_train]
        X_test_mean = [torch.mean(x, dim=0) for x in X_test]

    def to_array(items):
        # Accept a mix of tensors and array-likes.
        return np.stack(
            [
                t.cpu().numpy() if isinstance(t, torch.Tensor) else np.array(t)
                for t in items
            ]
        )

    tensor_train = torch.tensor(
        to_array(X_train_mean), dtype=torch.float32, device=device
    )
    tensor_test = torch.tensor(
        to_array(X_test_mean), dtype=torch.float32, device=device
    )

    linear_model = nn.Linear(num_concepts, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(linear_model.parameters(), lr=0.001)
    num_epochs = 200

    # Loop-invariant label tensor, hoisted out of the training loop.
    labels_train = torch.tensor(y_train, dtype=torch.long, device=device)

    for epoch in range(num_epochs):
        linear_model.train()
        optimizer.zero_grad()
        outputs = linear_model(tensor_train)
        loss = criterion(outputs, labels_train)
        loss.backward()
        optimizer.step()
        if wandb_run is not None:
            with torch.no_grad():
                preds = outputs.argmax(dim=1)
                acc = (preds.detach().cpu().numpy() == y_train).mean()
            current_lr = (
                optimizer.param_groups[0]["lr"] if optimizer.param_groups else None
            )
            wandb_run.log(
                {
                    "mean_train_loss": loss.item(),
                    "mean_train_acc": acc,
                    "mean_learning_rate": current_lr,
                }
            )

    linear_model.eval()
    with torch.no_grad():
        outputs = linear_model(tensor_test)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted.detach().cpu().numpy() == y_test).mean()
    print(f"CBM accuracy test: {accuracy:.4f}")
    if wandb_run is not None:
        wandb_run.log({"mean_test_acc": accuracy})
    return float(accuracy)
| |
|
| |
|
class NonNegativeLinear(nn.Linear):
    """Linear layer whose weights are constrained to be non-negative.

    BUGFIX: the previous version was a plain class wrapping an nn.Linear —
    it was not an nn.Module, so (a) its parameters were never registered
    with the parent model/optimizer, (b) instances were not callable
    (``self.classifier(x)`` would raise TypeError), and (c) it exposed no
    ``.weight`` attribute, which the training loop reads directly for L1
    regularization and clamping. Subclassing nn.Linear fixes all three
    while keeping the constructor signature.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, x):
        # Project weights onto the non-negative orthant before every use.
        self.weight.data.clamp_(min=0.0)
        return F.linear(x, self.weight, self.bias)
| |
|