# MoTIF/utils/motif.py
# Uploaded via huggingface_hub by P4ddyki (commit 3cf4fff, verified).
from __future__ import annotations
import os
import random
import numpy as np
import torch
import copy
from typing import List, Optional, Dict, Tuple
import cv2
from PIL import Image
import tqdm
import torch.nn as nn
import gc
import torch.nn.functional as F
from torchvision.transforms import (
Compose,
Resize,
CenterCrop,
ToTensor,
Normalize,
InterpolationMode,
)
import math
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import wandb
import re
import pandas as pd
import glob
def init_repro(seed: int = 42, deterministic: bool = True):
    """Seed every RNG family and flip torch's determinism switches.

    Call this before creating any model/processor/device context so the
    environment variables and backend flags take effect ahead of the
    first CUDA op. Returns the seed for convenience.
    """
    # Environment-level knobs: hashing, cuBLAS workspace, BLAS threading.
    os.environ.update(
        {
            "PYTHONHASHSEED": str(seed),
            "CUBLAS_WORKSPACE_CONFIG": ":16:8",  # deterministic cuBLAS on Ampere+, nice default
            "OMP_NUM_THREADS": "1",
            "MKL_NUM_THREADS": "1",
        }
    )
    # Seed Python, NumPy, and torch (CPU + every CUDA device).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Determinism knobs (do this before any CUDA ops)
    if deterministic:
        try:
            torch.use_deterministic_algorithms(True)
        except Exception:
            # Older torch exposes the same switch under a different name.
            torch.set_deterministic(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    # Single-threaded CPU ops reduce threading nondeterminism.
    torch.set_num_threads(1)
    return seed
def get_torch_device(prefer: Optional[str] = None) -> torch.device:
    """Pick a torch device, honoring an optional preference string.

    An unavailable preference silently falls through to the default
    priority: cuda > mps > cpu.
    """
    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if prefer is not None:
        wanted = prefer.lower()
        if wanted == "cuda" and torch.cuda.is_available():
            return torch.device("cuda")
        if wanted == "mps" and mps_available:
            return torch.device("mps")
        if wanted == "cpu":
            return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    if mps_available:
        return torch.device("mps")
    return torch.device("cpu")
def pad_batch_sequences(
    seqs: List[torch.Tensor], device: torch.device
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Pad a list of [T_i, C] tensors into a batch [B, T_max, C] and return
    a key_padding_mask [B, T_max] with True for padded positions.
    """
    if not seqs:
        raise ValueError("pad_batch_sequences received empty sequence list")
    seq_lens = [int(s.shape[0]) for s in seqs]
    n_feat = int(seqs[0].shape[1])
    longest = int(max(seq_lens))
    n_seqs = len(seqs)
    padded = torch.zeros((n_seqs, longest, n_feat), dtype=torch.float32, device=device)
    pad_mask = torch.ones((n_seqs, longest), dtype=torch.bool, device=device)  # True=padded
    for row, (seq, length) in enumerate(zip(seqs, seq_lens)):
        padded[row, :length, :] = seq.to(device)
        pad_mask[row, :length] = False
    return padded, pad_mask
def compute_concept_standardization(seqs: List[torch.Tensor | np.ndarray]):
cat = torch.cat(
[
(
s
if isinstance(s, torch.Tensor)
else torch.tensor(np.array(s), dtype=torch.float32)
)
for s in seqs
],
dim=0,
)
mean = cat.mean(dim=0)
std = cat.std(dim=0).clamp_min(1e-6)
return mean, std
def apply_standardization(
    seqs: List[torch.Tensor | np.ndarray], mean: torch.Tensor, std: torch.Tensor
):
    """Standardize every sequence with the given per-channel mean/std."""
    normalized = []
    for item in seqs:
        if not isinstance(item, torch.Tensor):
            item = torch.tensor(np.array(item), dtype=torch.float32)
        normalized.append((item - mean) / std)
    return normalized
def concepts_over_time_cosine(
    concepts: torch.Tensor,
    all_data_list,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
    chunk_size: int | None = None,
):
    """
    Per-frame cosine similarity of each video against the concept bank.

    Both sides are L2-normalized in fp32 on `device`; the matmul runs in
    fp32 and results come back on CPU in `dtype`. `chunk_size` optionally
    splits the time axis of the matmul to cap peak memory.

    NOTE(review): similarities are clamped to >= 0 here, whereas MoTIF's
    class docstring claims signed (unclamped) sims — confirm which
    contract is intended.
    """
    with torch.no_grad():
        bank = F.normalize(
            concepts.detach().to(device=device, dtype=torch.float32), dim=1
        )  # [K, C]
        n_concepts = bank.shape[0]
        activations, embeddings = [], []
        for vid in all_data_list:
            frames = vid if isinstance(vid, torch.Tensor) else torch.as_tensor(vid)
            # Coerce to [T, C]: promote 1-D inputs, flatten anything deeper.
            if frames.ndim == 1:
                frames = frames.unsqueeze(0)
            elif frames.ndim > 2:
                frames = frames.view(-1, frames.size(-1))
            frames = frames.detach().to(device=device, dtype=torch.float32)
            if frames.numel() == 0:
                sims = torch.empty((0, n_concepts), dtype=torch.float32, device=device)
            else:
                frames = F.normalize(frames, dim=1)
                if chunk_size is not None and frames.shape[0] > chunk_size:
                    # Chunk over T to limit peak memory.
                    pieces = [
                        frames[lo : lo + chunk_size] @ bank.T
                        for lo in range(0, frames.shape[0], chunk_size)
                    ]
                    sims = torch.cat(pieces, dim=0)
                else:
                    sims = frames @ bank.T  # [T, K]
                sims = torch.clamp(sims, min=0.0)
            # Hand back CPU tensors in the requested dtype.
            activations.append(sims.to("cpu", dtype=dtype))
            embeddings.append(vid)  # keep the caller's original object
    return activations, embeddings
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with dropout.

    Accepts either [T, C] or [B, T, C] input; a 2-D input is temporarily
    given a batch dimension and squeezed back on the way out. The encoding
    is broadcast over the batch dimension.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
        freq = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model, dtype=torch.float32)  # [max_len, C]
        table[:, 0::2] = torch.sin(pos * freq)
        if d_model % 2:
            # Odd d_model: the cosine slice has one column fewer than the
            # sine slice, so it needs its own (shorter) frequency vector.
            freq_cos = torch.exp(
                torch.arange(0, d_model - 1, 2, dtype=torch.float32)
                * (-math.log(10000.0) / d_model)
            )
            table[:, 1::2] = torch.cos(pos * freq_cos)
        else:
            table[:, 1::2] = torch.cos(pos * freq)
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor):
        """Add the positional table (broadcast over batch) and apply dropout."""
        was_2d = x.dim() == 2
        if was_2d:  # [T, C] -> [1, T, C]
            x = x.unsqueeze(0)
        x = self.dropout(x + self.pe[: x.size(1), :])
        return x.squeeze(0) if was_2d else x
# -------------------------
# Diagonal (per-channel) Q/K/V + per-channel FFN
# -------------------------
class DiagQKVd(nn.Module):
    """Per-channel Q/K/V projections of width d (no cross-concept mixing)."""

    def __init__(self, C: int, d: int = 8, bias: bool = True):
        super().__init__()
        self.C, self.d = C, d
        # groups=C keeps the concept channels isolated: each channel is
        # projected on its own to d features.
        self.q = nn.Conv1d(C, C * d, 1, groups=C, bias=bias)
        self.k = nn.Conv1d(C, C * d, 1, groups=C, bias=bias)
        self.v = nn.Conv1d(C, C * d, 1, groups=C, bias=bias)

    def forward(self, x):
        """x: [B, T, C] -> three tensors of shape [B, T, C, d]."""
        B, T, C = x.shape
        channels_first = x.transpose(1, 2)  # [B, C, T]

        def project(conv):
            # Conv over time, back to time-major, split the grouped features.
            return conv(channels_first).transpose(1, 2).view(B, T, C, self.d)

        return project(self.q), project(self.k), project(self.v)
class ChannelTimeNorm(nn.Module):
    """LayerNorm over the concept-channel dimension of a [B, T, C] tensor."""

    def __init__(self, C, eps=1e-5, affine=True):
        super().__init__()
        # nn.LayerNorm already normalizes the trailing dim, which is C here.
        self.ln = nn.LayerNorm(C, eps=eps, elementwise_affine=affine)

    def forward(self, x):
        """x: [B, T, C] -> [B, T, C], normalized per time step over C."""
        return self.ln(x)
class PerChannelFFN(nn.Module):
    """Two-layer feed-forward block applied to each concept channel in isolation."""

    def __init__(self, C: int, dropout: float = 0.1):
        super().__init__()
        # groups=C means every channel gets its own 1x1 conv: no concept mixing.
        self.fc1 = nn.Conv1d(C, C, kernel_size=1, groups=C, bias=True)
        self.fc2 = nn.Conv1d(C, C, kernel_size=1, groups=C, bias=True)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """x: [B, T, C] -> [B, T, C]."""
        h = x.transpose(1, 2)  # Conv1d wants [B, C, T]
        h = self.fc1(h)
        h = self.drop(self.act(h))
        h = self.fc2(h)
        return h.transpose(1, 2)  # back to [B, T, C]
class PerChannelTemporalBlock(nn.Module):
    """
    Attention over time for each concept channel independently.
    Stores attn_weights: [B, C, T, T].
    """

    def __init__(self, C: int, d: int = 1, dropout: float = 0.1, T_max: int = 1024):
        # NOTE(review): T_max is accepted but never used in this block, and
        # logit_scale is defined but never applied in forward() — confirm intent.
        super().__init__()
        self.C, self.d = C, d
        self.qkv = DiagQKVd(C, d)
        self.scale = d**-0.5
        self.logit_scale = nn.Parameter(torch.zeros(C))  # per-concept multiplier
        self.norm1 = ChannelTimeNorm(C)
        self.norm2 = ChannelTimeNorm(C)
        self.drop = nn.Dropout(dropout)
        self.ffn = PerChannelFFN(C, dropout=dropout)
        self.act = nn.GELU()
        self.attn_weights = None  # last attention map, detached: [B, C, T, T]

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pre-norm per-channel temporal attention followed by a per-channel FFN.

        x: [B, T, C]; attn_mask: [T, T] (bool True = block, or additive float);
        key_padding_mask: [B, T] with True for padded positions.
        """
        B, T, C = x.shape
        # Pre-attention norm
        y = self.norm1(x)  # [B, T, C]
        # Per-channel QKV: Q/K/V are [B, T, C, d]
        Q, K, V = self.qkv(y)
        # Attention logits per channel: [B, C, T, T]
        scores = torch.einsum("btcd,bucd->bctu", Q, K) * self.scale
        # Optional masks
        if attn_mask is not None:
            # treat bool as additive -inf mask; float as-is
            if attn_mask.dtype == torch.bool:
                am = torch.zeros_like(attn_mask, dtype=scores.dtype)
                am = am.masked_fill(attn_mask, float("-inf"))
            else:
                am = attn_mask.to(dtype=scores.dtype)
            scores = scores + am.view(1, 1, T, T)
        if key_padding_mask is not None:
            kpm = key_padding_mask.view(B, 1, 1, T)  # True = masked
            scores = scores.masked_fill(kpm, float("-inf"))
        # Softmax over source time axis.
        # NOTE(review): a row that is entirely -inf (e.g. a padded query with
        # every key masked) produces NaN weights — callers must ignore those rows.
        w = torch.softmax(scores, dim=-1)  # [B, C, T, T]
        self.attn_weights = w.detach()
        # Weighted sum of values, then reduce d
        out = torch.einsum("bctu,bucd->btcd", w, V).mean(dim=-1)  # [B, T, C]
        # Residual + dropout
        x = x + self.drop(out)
        # Post-attention norm + per-channel FFN (already expects [B,T,C])
        z = self.norm2(x)
        z = self.ffn(z)
        # Residual + dropout
        x = x + self.drop(z)
        return x
def _pick_num_heads(C: int, proposed: Optional[int]) -> int:
if proposed is not None and proposed >= 1 and C % proposed == 0:
return proposed
for h in [8, 6, 4, 3, 2]:
if h <= C and C % h == 0:
return h
return 1
class FullAttentionTemporalBlock(nn.Module):
    """
    Full multi-head self-attention over time with channel mixing (manual implementation).

    Post-norm transformer layer: attention + residual -> LayerNorm,
    then FFN + residual -> LayerNorm. The last (detached) attention map
    is kept in ``attn_weights`` with shape [B, H, T, T].
    """

    def __init__(
        self,
        C: int,
        num_heads: Optional[int] = None,
        dropout: float = 0.1,
        ffn_mult: int = 4,
    ):
        super().__init__()
        self.C = C
        # Head count is auto-selected to divide C when none is proposed.
        self.H = _pick_num_heads(C, num_heads)
        self.d = C // self.H
        assert self.H * self.d == C, "C must be divisible by num_heads"
        # Projections (mix channels)
        self.q_proj = nn.Linear(C, C, bias=True)
        self.k_proj = nn.Linear(C, C, bias=True)
        self.v_proj = nn.Linear(C, C, bias=True)
        self.o_proj = nn.Linear(C, C, bias=True)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(C, ffn_mult * C),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_mult * C, C),
        )
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(C)
        self.norm2 = nn.LayerNorm(C)
        self.attn_weights = None  # [B, H, T, T]

    def _shape_heads(self, x: torch.Tensor) -> torch.Tensor:
        # [B, T, C] -> [B, H, T, d]
        B, T, _ = x.shape
        return x.view(B, T, self.H, self.d).permute(0, 2, 1, 3)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,  # [T, T]
        key_padding_mask: Optional[torch.Tensor] = None,  # [B, T]
    ) -> torch.Tensor:
        """Standard scaled dot-product attention over the time axis.

        x: [B, T, C]; attn_mask: [T, T] (bool True = block, or additive
        float); key_padding_mask: [B, T] with True for padded positions.
        """
        assert x.dim() == 3, "x must be [B, T, C]"
        B, T, C = x.shape
        assert C == self.C
        # Projections
        Q = self._shape_heads(self.q_proj(x))  # [B,H,T,d]
        K = self._shape_heads(self.k_proj(x))  # [B,H,T,d]
        V = self._shape_heads(self.v_proj(x))  # [B,H,T,d]
        # Scaled dot-product attention
        scale = self.d**-0.5
        scores = torch.matmul(Q, K.transpose(-2, -1)) * scale  # [B,H,T,T]
        # Masks
        if attn_mask is not None:
            # bool -> additive mask; float left as-is
            if attn_mask.dtype == torch.bool:
                am = torch.zeros_like(attn_mask, dtype=Q.dtype)  # 0 keep
                am = am.masked_fill(attn_mask, float("-inf"))
            else:
                am = attn_mask.to(dtype=Q.dtype)
            scores = scores + am.view(1, 1, T, T)
        if key_padding_mask is not None:
            kpm = key_padding_mask.to(torch.bool).view(
                B, 1, 1, T
            )  # broadcast on heads & queries
            scores = scores.masked_fill(kpm, float("-inf"))
        weights = F.softmax(scores, dim=-1)  # [B,H,T,T]
        weights = self.attn_drop(weights)
        # NOTE: weights are stored after attention-dropout, so at train time
        # the saved map reflects the dropped-out weights.
        self.attn_weights = weights.detach()
        out = torch.matmul(weights, V)  # [B,H,T,d]
        out = out.permute(0, 2, 1, 3).contiguous()  # [B,T,H,d]
        out = out.view(B, T, C)  # [B,T,C]
        out = self.o_proj(out)
        out = self.proj_drop(out)
        # Residual + norm
        x = self.norm1(x + out)
        # FFN + residual + norm
        ff = self.ffn(x)
        x = self.norm2(x + self.dropout(ff))
        return x
class MoTIF:
    """
    MoTIF model for video classification using concept bottleneck models.
    Assumes:
    - concepts_over_time_cosine returns signed cosine sims (no clamp).
    - self.model(window_embeddings, key_padding_mask) returns (logits, concepts, concepts_t, sharpness)

    NOTE(review): concepts_over_time_cosine as defined in this file clamps
    sims to >= 0, contradicting the first assumption above — confirm.
    """

    @staticmethod
    def _collate_pad(batch):
        """
        batch: list of tuples (seq:[T,C] CPU float32, y:int)
        Returns CPU pinned tensors to enable non_blocking .to(device)

        NOTE(review): the tensors below are plain CPU tensors, not actually
        pinned (no pin_memory) — confirm whether pinning was intended.
        """
        B = len(batch)
        T = max(seq.shape[0] for seq, _ in batch)
        C = batch[0][0].shape[1]
        x = torch.zeros((B, T, C), dtype=torch.float32)
        mask = torch.ones((B, T), dtype=torch.bool)  # True = padded
        y = torch.empty((B,), dtype=torch.long)
        for i, (seq, yi) in enumerate(batch):
            t = seq.shape[0]
            x[i, :t].copy_(seq)  # CPU->CPU copy into pinned
            mask[i, :t] = False
            y[i] = yi
        return x, mask, y

    def __init__(self, embedder, concepts):
        """Compute per-frame concept activations for every embedded video and
        set up bookkeeping; splits/encoder are deferred to preprocess()."""
        self.device = get_torch_device(prefer="cuda")
        self.concepts = concepts
        self.all_data = embedder.video_embeddings  # dict: path -> [T,C]
        self.all_labels = (
            embedder.labels
        )  # list aligned with keys order (non-SSv2 case)
        self.video_paths = list(self.all_data.keys())
        self.video_spans = embedder.video_window_spans
        self.concept_bank = concepts.text_embeddings
        # Per-frame cosine activations against the concept bank.
        self.raw_activations, self.video_embeddings = concepts_over_time_cosine(
            self.concept_bank, list(self.all_data.values())
        )  # list of [T,C]
        # Drop videos whose activation sequence came back empty; keep all
        # parallel lists aligned.
        keep_idx = [
            i
            for i, act in enumerate(self.raw_activations)
            if isinstance(act, torch.Tensor) and act.shape[0] > 0
        ]
        if len(keep_idx) != len(self.raw_activations):
            removed = len(self.raw_activations) - len(keep_idx)
            self.raw_activations = [self.raw_activations[i] for i in keep_idx]
            self.video_paths = [self.video_paths[i] for i in keep_idx]
            self.all_labels = [self.all_labels[i] for i in keep_idx]  # non-SSv2 path
            self.video_embeddings = [self.video_embeddings[i] for i in keep_idx]
            print(f"[MoTIF] Removed {removed} entries with empty activations.")
        # Stable, aligned numeric IDs (for SSv2)
        self.video_ids = [self.path_to_id(p) for p in self.video_paths]
        self.kept_ids = {vid for vid in self.video_ids if vid is not None}
        # Defer LabelEncoder to preprocess()
        self.encoder = LabelEncoder()
        self.class_weights = None
        self.mean_c, self.std_c = None, None
        self.X_train = self.X_val = self.X_test = None
        self.y_train = self.y_val = self.y_test = None
        self.paths_train = self.paths_val = self.paths_test = None
        self.test_zero_shot = None
        # Model attached later
        self.model = None

    @staticmethod
    def path_to_id(p: str):
        """Extract the first integer run in the path's basename, or None."""
        base = os.path.splitext(os.path.basename(p))[0]
        m = re.search(r"(\d+)", base)
        return int(m.group(1)) if m else None

    # -------------------------
    # Zero-shot (vectorized over frames)
    # -------------------------
    @torch.inference_mode()
    def zero_shot(self, concept_embedder, wandb_run=None):
        """Zero-shot classify the test split against class-name prompts.

        Reports three accuracies: mean-pooled video embedding, soft voting
        over per-frame probabilities, and hard per-frame majority voting.
        Requires preprocess() to have populated test_zero_shot / y_test.
        """
        assert (
            self.test_zero_shot is not None and self.y_test is not None
        ), "Call preprocess(...) first."
        # build text prompts and text embeddings
        class_prompts = ["a video of " + c for c in self.encoder.classes_.tolist()]
        # Shallow copy so embedd_text() does not disturb the caller's embedder.
        text_embedder = copy.copy(concept_embedder)
        text_embedder.tokenizer = concept_embedder.tokenizer
        text_embedder.model = concept_embedder.model
        text_embedder.embedd_text(class_prompts)  # keep original method name
        # ensure device + dtype
        text_embeddings = text_embedder.text_embeddings.to(
            self.device, dtype=torch.float32
        )  # [K, C]
        text_embeddings = F.normalize(text_embeddings, dim=-1)
        # check model type for probability transform
        model_name = getattr(text_embedder, "model_name", "").lower()
        use_siglip = "siglip" in model_name
        if use_siglip:
            # SigLIP style scaling/bias (ensure fp32)
            scale = text_embedder.model.logit_scale.exp().to(self.device).float()
            bias = text_embedder.model.logit_bias.to(
                self.device
            ).float()  # shape [K] or [1,K]
        # counters
        correct_pooled = 0
        correct_soft_avg = 0
        correct_hard_majority = 0
        for idx, frames in enumerate(self.test_zero_shot):
            # frames -> frame embeddings [T, C] on device
            frame_emb = torch.as_tensor(
                np.array(frames), device=self.device, dtype=torch.float32
            )
            frame_emb = F.normalize(frame_emb, dim=-1)  # [T, C]
            # pooled embedding (mean over time) [1, C]
            pooled_emb = F.normalize(frame_emb.mean(dim=0, keepdim=True), dim=-1)  # [1, C]
            # raw logits
            if use_siglip:
                logits_pooled = pooled_emb @ text_embeddings.T
                logits_pooled = logits_pooled * scale + bias  # [1, K]
                logits_per_frame = (frame_emb @ text_embeddings.T) * scale + bias  # [T, K]
                probs_per_frame = logits_per_frame.sigmoid()  # for soft average
            else:
                logits_pooled = pooled_emb @ text_embeddings.T  # [1, K]
                logits_per_frame = frame_emb @ text_embeddings.T  # [T, K]
                probs_per_frame = logits_per_frame.softmax(dim=-1)  # for soft average
            # predictions
            pred_pooled = logits_pooled.argmax(dim=-1).item()  # mean-pooled embedding
            pred_soft_avg = probs_per_frame.mean(dim=0).argmax().item()  # soft voting (avg probs)
            per_frame_preds = logits_per_frame.argmax(dim=-1)  # [T]
            counts = torch.bincount(per_frame_preds, minlength=logits_per_frame.size(1))
            pred_hard_majority = counts.argmax().item()  # hard majority (mode)
            # ground truth
            y = int(self.y_test[idx])
            # update counters
            correct_pooled += int(pred_pooled == y)
            correct_soft_avg += int(pred_soft_avg == y)
            correct_hard_majority += int(pred_hard_majority == y)
        n = max(1, len(self.test_zero_shot))
        acc_pooled = correct_pooled / n
        acc_soft_avg = correct_soft_avg / n
        acc_hard_majority = correct_hard_majority / n
        # logging
        if wandb_run is not None:
            wandb_run.log(
                {
                    "zero_shot_acc_pooled": acc_pooled,
                    "zero_shot_acc_soft_avg": acc_soft_avg,
                    "zero_shot_acc_hard_majority": acc_hard_majority,
                }
            )
        print(
            f"[ZS] pooled={acc_pooled:.4f} | soft-avg={acc_soft_avg:.4f} | hard-majority={acc_hard_majority:.4f}"
        )
        return {
            "acc_pooled": acc_pooled,
            "acc_soft_avg": acc_soft_avg,
            "acc_hard_majority": acc_hard_majority,
        }

    # -------------------------
    # Preprocess (unchanged split logic; at end we build datasets)
    # -------------------------
    def preprocess(
        self,
        dataset: str,
        info: Optional[str] = None,
        test_size: float = 0.2,
        random_state: int = 42,
    ):
        """Build train/val/test splits, fit the label encoder, standardize
        activations, and compute class weights.

        dataset: "breakfast", "ucf101", "hmdb51", "something2" use official
        split files; any other name falls back to a stratified random split.
        info: official split identifier (e.g. "s1"); when falsy, a stratified
        random split with `test_size`/`random_state` is used instead.
        """
        # binary_array[i]: True = train, False = test (per-video membership).
        binary_array = []

        def get_index(info):
            # Map split name ("s1".."s3") to a numeric split index; default 1.
            if info == "s1":
                index = 1
            elif info == "s2":
                index = 2
            elif info == "s3":
                index = 3
            else:
                index = 1
            return index

        if info:
            if dataset == "breakfast":
                # Breakfast: subject-ID (Pxx) ranges define each split group.
                RANGES = {
                    "s1": range(3, 16),
                    "s2": range(16, 29),
                    "s3": range(29, 42),
                    "s4": range(42, 54),
                }

                def split_paths_by_group(paths, group_name, ranges=RANGES):
                    # Appends into the enclosing binary_array: False (test) when
                    # the path's subject ID falls in the chosen group.
                    if group_name not in ranges:
                        raise ValueError(
                            f"Unknown group '{group_name}'. Expected one of {list(ranges)}"
                        )
                    target = ranges[group_name]
                    for p in paths:
                        if any(re.search(rf"P{num:02}", p) for num in target):
                            binary_array.append(False)
                        else:
                            binary_array.append(True)
                    return binary_array

                binary_array = split_paths_by_group(self.video_paths, info)
            elif dataset == "ucf101":
                # UCF101: official test list file decides membership.
                index = get_index(info)
                ucf_test_list = (
                    f"../Datasets/UCF101/ucfTrainTestlist/testlist0{index}.txt"
                )
                path_list = pd.read_csv(ucf_test_list, sep=" ", header=None)
                for path in self.video_paths:
                    path_rel = path.split("Video_data/")[1].replace(".mp4", ".avi")
                    binary_array.append(
                        False if path_rel in path_list[0].values else True
                    )
            elif dataset == "hmdb51":
                # HMDB51: per-class split files flag each clip 1=train, 2=test,
                # 0=ignore; ignored clips are dropped from all parallel lists.
                index = get_index(info)
                labels_path = "../Datasets/HMDB/testTrainMulti_7030_splits/"
                path_text_dirs = glob.glob(os.path.join(labels_path, "*.txt"))
                path_text_dirs_idx = [p for p in path_text_dirs if f"split{index}" in p]
                path_text_dirs_idx.sort()
                path_list_test, path_list_train, path_list_ignore = set(), set(), set()
                for txt_path in path_text_dirs_idx:
                    with open(txt_path, "r") as fh:
                        for line in fh:
                            name, flag = line.strip().split()
                            if flag == "2":
                                path_list_test.add(name)
                            elif flag == "0":
                                path_list_ignore.add(name)
                            else:
                                path_list_train.add(name)
                mask = []
                for vp in self.video_paths:
                    basename = os.path.basename(vp).replace(".mp4", ".avi")
                    if basename in path_list_test:
                        mask.append(False)
                    elif basename in path_list_train:
                        mask.append(True)
                    elif basename in path_list_ignore:
                        mask.append(None)
                    else:
                        mask.append(None)
                kept = [
                    (x, y, p, b, m)
                    for x, y, p, b, m in zip(
                        self.raw_activations,
                        self.all_labels,
                        self.video_paths,
                        self.video_embeddings,
                        mask,
                    )
                    if m is not None
                ]
                if not kept:
                    raise ValueError(
                        "HMDB split produced no usable items. Check paths and split lists."
                    )
                (
                    self.raw_activations,
                    self.all_labels,
                    self.video_paths,
                    self.video_embeddings,
                    mask_kept,
                ) = map(list, zip(*kept))
                # Rebuild numeric IDs after filtering (digit-only basenames).
                self.video_ids = [
                    (
                        int(os.path.splitext(os.path.basename(p))[0])
                        if os.path.splitext(os.path.basename(p))[0].isdigit()
                        else None
                    )
                    for p in self.video_paths
                ]
                self.kept_ids = {vid for vid in self.video_ids if vid is not None}
                binary_array = [True if m else False for m in mask_kept]
            elif dataset == "something2":
                # ===== SSv2 handling =====
                def replace_something(text: str) -> str:
                    # Strip the bracket markers from templates: "[x]" -> "x".
                    return re.sub(r"\[(.*?)\]", r"\1", text)

                val_json = "../Datasets/Something2/labels/validation.json"
                train_json = "../Datasets/Something2/labels/train.json"
                test_json = "../Datasets/Something2/labels/test.json"
                test_csv = "../Datasets/Something2/labels/test-answers.csv"
                df_train = pd.read_json(train_json)
                df_val = pd.read_json(val_json)
                df_test = pd.read_json(test_json)
                train_ids = [int(row[0]) for row in df_train.values.tolist()]
                val_ids = [int(row[0]) for row in df_val.values.tolist()]
                test_ids = [int(row[0]) for row in df_test.values.tolist()]
                train_labels = [replace_something(t) for t in df_train["template"]]
                val_labels = [replace_something(t) for t in df_val["template"]]
                test_tbl = pd.read_csv(
                    test_csv, sep=";", header=None, dtype={0: int, 1: str}
                )
                test_labels_map = dict(zip(test_tbl[0].tolist(), test_tbl[1].tolist()))
                test_labels = [test_labels_map[i] for i in test_ids]
                # id -> (split, label) lookup over all three official splits.
                id2split = {}
                id2split.update(
                    {i: ("train", l) for i, l in zip(train_ids, train_labels)}
                )
                id2split.update({i: ("val", l) for i, l in zip(val_ids, val_labels)})
                id2split.update({i: ("test", l) for i, l in zip(test_ids, test_labels)})
                train_x, val_x, test_x = [], [], []
                train_y, val_y, test_y = [], [], []
                self.test_zero_shot = []
                self.paths_train, self.paths_val, self.paths_test = [], [], []
                self.video_ids = [self.path_to_id(p) for p in self.video_paths]
                missed = 0
                for idx, vid in enumerate(self.video_ids):
                    if vid is None:
                        missed += 1
                        continue
                    entry = id2split.get(vid)
                    if entry is None:
                        missed += 1
                        continue
                    split, lab = entry
                    if split == "train":
                        train_x.append(self.raw_activations[idx])
                        train_y.append(lab)
                        self.paths_train.append(self.video_paths[idx])
                    elif split == "val":
                        val_x.append(self.raw_activations[idx])
                        val_y.append(lab)
                        self.paths_val.append(self.video_paths[idx])
                    elif split == "test":
                        test_x.append(self.raw_activations[idx])
                        test_y.append(lab)
                        self.paths_test.append(self.video_paths[idx])
                        self.test_zero_shot.append(self.video_embeddings[idx])
                if missed:
                    print(
                        f"[SSv2] Skipped {missed} items (no parseable ID or not in official splits)."
                    )
                if len(train_x) == 0:
                    raise RuntimeError(
                        "[SSv2] No training samples matched. Check filename-to-ID parsing and dataset paths."
                    )
                self.encoder = self.encoder.fit(train_y)
                self.X_train, self.y_train = train_x, self.encoder.transform(
                    np.array(train_y, dtype=object)
                )
                # NOTE(review): when val/test are empty the y_* attribute is set
                # to the tuple (None, None) while X_* stays [] — downstream
                # checks use `X_val is None`, so confirm this sentinel is intended.
                self.X_val, self.y_val = val_x, (
                    self.encoder.transform(np.array(val_y, dtype=object))
                    if len(val_x)
                    else (None, None)
                )
                self.X_test, self.y_test = test_x, (
                    self.encoder.transform(np.array(test_y, dtype=object))
                    if len(test_x)
                    else (None, None)
                )
                # ===== end SSv2 =====
            if dataset != "something2":
                # Materialize train/test lists from the boolean membership array.
                self.X_train = [
                    self.raw_activations[i]
                    for i in range(len(self.raw_activations))
                    if binary_array[i]
                ]
                self.X_test = [
                    self.raw_activations[i]
                    for i in range(len(self.raw_activations))
                    if not binary_array[i]
                ]
                self.y_train = [
                    self.all_labels[i]
                    for i in range(len(self.all_labels))
                    if binary_array[i]
                ]
                self.y_test = [
                    self.all_labels[i]
                    for i in range(len(self.all_labels))
                    if not binary_array[i]
                ]
                self.paths_train = [
                    self.video_paths[i]
                    for i in range(len(self.video_paths))
                    if binary_array[i]
                ]
                self.paths_test = [
                    self.video_paths[i]
                    for i in range(len(self.video_paths))
                    if not binary_array[i]
                ]
                self.encoder = self.encoder.fit(self.y_train)
                self.y_train = self.encoder.transform(self.y_train)
                self.y_test = self.encoder.transform(self.y_test)
                self.test_zero_shot = [
                    self.video_embeddings[i]
                    for i in range(len(self.video_embeddings))
                    if not binary_array[i]
                ]
        else:
            # Stratified random split (non-SSv2)
            (
                self.X_train,
                self.X_test,
                self.y_train,
                self.y_test,
                self.paths_train,
                self.paths_test,
            ) = train_test_split(
                self.raw_activations,
                self.all_labels,
                self.video_paths,
                test_size=test_size,
                random_state=random_state,
                stratify=self.all_labels,
            )
            self.encoder = self.encoder.fit(self.y_train)
            self.y_train = self.encoder.transform(self.y_train)
            self.y_test = self.encoder.transform(self.y_test)
        # ----- Standardization -----
        # Train-set statistics are applied to every split (no leakage).
        self.mean_c, self.std_c = compute_concept_standardization(self.X_train)
        self.X_train = apply_standardization(self.X_train, self.mean_c, self.std_c)
        self.X_test = apply_standardization(self.X_test, self.mean_c, self.std_c)
        if self.X_val is not None:
            self.X_val = apply_standardization(self.X_val, self.mean_c, self.std_c)
        # ----- Class weights -----
        # Inverse-frequency weights: rarest class gets the largest weight.
        classes, counts = np.unique(self.y_train, return_counts=True)
        self.class_weights = torch.tensor(counts.max() / counts, dtype=torch.float32)
        self.num_concepts = self.X_train[0].shape[-1]
        self.num_classes = len(classes)

    def train_model(
        self,
        num_epochs: int,
        l1_lambda: float,
        lambda_sparse: float,
        batch_size: int = 8,
        lr: float = 1e-4,
        weight_decay: float = 1e-2,
        enforce_nonneg: bool = True,
        class_weights: bool = True,
        # NOTE(review): wandb exposes Run as wandb.sdk.wandb_run.Run, not
        # wandb.WandbRun; harmless at runtime thanks to lazy annotations
        # (from __future__ import annotations), but worth fixing.
        wandb_run: Optional[wandb.WandbRun] = None,
        random_seed: int = 42,
        ckpt_path: Optional[str] = None,
        early_stopping_patience: int = 50,
    ):
        """Train self.model with CE + L1 (classifier weights) + temporal
        sparsity penalty on concepts_t; checkpoint/early-stop on accuracy.

        NOTE(review): when a test split exists, the checkpoint/early-stop
        metric is TEST accuracy (model selection on the test set) — confirm
        this is intentional.
        """
        if wandb_run is not None:
            wandb_run.config.update(
                {
                    "num_epochs": num_epochs,
                    "l1_lambda": l1_lambda,
                    "lambda_sparse": lambda_sparse,
                    "lr": lr,
                    "weight_decay": weight_decay,
                    "batch_size": batch_size,
                    "enforce_nonneg": enforce_nonneg,
                    "class_weights": class_weights,
                    "transformer_layers": self.model.transformer_layers,
                    "lse_tau": self.model.lse_tau,
                    "diagonal_attention": self.model.diagonal_attention,
                    "early_stopping_patience": early_stopping_patience,
                }
            )
        # move model to device
        self.model.to(self.device)
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=weight_decay
        )
        if class_weights:
            criterion = nn.CrossEntropyLoss(
                weight=self.class_weights.to(self.device), label_smoothing=0.1
            )
        else:
            criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        num_train = len(self.X_train)
        best_metric = -float("inf")
        best_state = None
        best_epoch = -1
        epochs_since_improvement = 0
        use_early_stopping = (early_stopping_patience is not None) and (
            len(self.X_test) > 0
        )
        for epoch in range(num_epochs):
            self.model.train()
            correct, total = 0, 0
            last_loss, last_L_sparse = None, None
            epoch_L_sparse_sum, epoch_batches = 0.0, 0
            # Deterministic, epoch-dependent shuffle of the training order.
            base_seed = int(getattr(self, "seed", random_seed))
            g = torch.Generator(device="cpu").manual_seed(base_seed + epoch)
            perm_tensor = torch.randperm(num_train, generator=g)
            perm = perm_tensor.tolist()
            for start in range(0, num_train, batch_size):
                end = min(start + batch_size, num_train)
                idx = perm[start:end]
                batch_seqs = [self.X_train[i] for i in idx]
                batch_labels = torch.tensor(
                    [int(self.y_train[i]) for i in idx],
                    dtype=torch.long,
                    device=self.device,
                )
                inputs, pad_mask = pad_batch_sequences(batch_seqs, device=self.device)
                optimizer.zero_grad()
                # updated forward: now returns sharpness
                logits, concepts_, concepts_t, sharpness = self.model(
                    inputs, key_padding_mask=pad_mask
                )
                # Sparsity penalty: mean |activation| over valid (non-padded)
                # time steps and all concept channels.
                valid = (~pad_mask).unsqueeze(-1).float()
                last_L_sparse = (concepts_t.abs() * valid).sum() / (
                    valid.sum() * concepts_t.shape[-1]
                ).clamp(min=1.0)
                ce = criterion(logits, batch_labels)
                l1 = l1_lambda * self.model.classifier.weight.abs().sum()
                loss = ce + l1 + lambda_sparse * last_L_sparse
                loss.backward()
                optimizer.step()
                last_loss = loss
                # accumulate for epoch-average L_sparse
                epoch_L_sparse_sum += float(last_L_sparse.detach().item())
                epoch_batches += 1
                if enforce_nonneg:
                    # Project classifier weights back onto the non-negative
                    # orthant after each step.
                    with torch.no_grad():
                        self.model.classifier.weight.clamp_(min=0.0)
                preds = logits.argmax(dim=1)
                correct += int((preds == batch_labels).sum().item())
                total += batch_labels.shape[0]
            acc = correct / max(1, total)
            epoch_L_sparse = epoch_L_sparse_sum / max(1, epoch_batches)

            # ===== evaluation =====
            def evaluate(dataset_X, dataset_y):
                # Batched eval: returns (accuracy, mean sharpness stats).
                self.model.eval()
                correct, total = 0, 0
                sharpness_vals = []
                with torch.no_grad():
                    for start in range(0, len(dataset_X), batch_size):
                        end = min(start + batch_size, len(dataset_X))
                        batch_seqs = [dataset_X[i] for i in range(start, end)]
                        batch_labels = torch.tensor(
                            [int(dataset_y[i]) for i in range(start, end)],
                            dtype=torch.long,
                            device=self.device,
                        )
                        inputs, pad_mask = pad_batch_sequences(
                            batch_seqs, device=self.device
                        )
                        logits, _, _, sharpness = self.model(
                            inputs, key_padding_mask=pad_mask
                        )
                        preds = logits.argmax(dim=1)
                        correct += int((preds == batch_labels).sum().item())
                        total += batch_labels.shape[0]
                        # Per-sample sharpness summaries (max / entropy of the
                        # concept and logit distributions).
                        for b in range(logits.shape[0]):
                            sharpness_vals.append(
                                {
                                    "concepts_max": float(
                                        sharpness["concepts"]["max"][b]
                                        .mean()
                                        .detach()
                                        .cpu()
                                        .item()
                                    ),
                                    "concepts_entropy": float(
                                        sharpness["concepts"]["entropy"][b]
                                        .mean()
                                        .detach()
                                        .cpu()
                                        .item()
                                    ),
                                    "logits_max": float(
                                        sharpness["logits"]["max"][b]
                                        .mean()
                                        .detach()
                                        .cpu()
                                        .item()
                                    ),
                                    "logits_entropy": float(
                                        sharpness["logits"]["entropy"][b]
                                        .mean()
                                        .detach()
                                        .cpu()
                                        .item()
                                    ),
                                }
                            )
                acc = correct / max(1, total)
                if sharpness_vals:
                    mean_sharp = {
                        k: float(np.mean([s[k] for s in sharpness_vals]))
                        for k in sharpness_vals[0]
                    }
                else:
                    mean_sharp = {}
                return acc, mean_sharp

            test_acc, test_sharp = (
                (0.0, {})
                if len(self.X_test) == 0
                else evaluate(self.X_test, self.y_test)
            )
            val_acc, val_sharp = (
                (0.0, {}) if self.X_val is None else evaluate(self.X_val, self.y_val)
            )
            metric = test_acc if len(self.X_test) > 0 else acc
            # ===== checkpointing =====
            if metric > best_metric + 1e-8:
                best_metric = metric
                best_epoch = epoch
                epochs_since_improvement = 0
                best_state = {
                    k: v.detach().cpu().clone()
                    for k, v in self.model.state_dict().items()
                }
                if ckpt_path:
                    # Atomic save: write to a temp file, then rename over target.
                    tmp = ckpt_path + ".tmp"
                    torch.save(best_state, tmp)
                    os.replace(tmp, ckpt_path)
            else:
                epochs_since_improvement += 1
            # ===== wandb logging =====
            if wandb_run is not None:
                current_lr = (
                    optimizer.param_groups[0]["lr"] if optimizer.param_groups else None
                )
                log_data = {
                    "epoch": epoch + 1,
                    "train_loss": (
                        float(last_loss.item()) if last_loss is not None else None
                    ),
                    "train_acc": acc,
                    "test_acc": test_acc,
                    "val_acc": val_acc if self.X_val is not None else None,
                    "L_sparse": (
                        float(last_L_sparse.item())
                        if last_L_sparse is not None
                        else None
                    ),
                    "learning_rate": current_lr,
                    "best_val_acc": best_metric,
                    "epochs_since_improvement": epochs_since_improvement,
                }
                # add sharpness metrics
                for prefix, sharp in [("test_", test_sharp), ("val_", val_sharp)]:
                    for k, v in sharp.items():
                        log_data[prefix + "sharp_" + k] = v
                wandb_run.log(log_data)
            if epoch % 10 == 0 or epoch == num_epochs - 1:
                msg_loss = (
                    float(last_loss.item()) if last_loss is not None else float("nan")
                )
                msg_sparse = (
                    float(last_L_sparse.item())
                    if last_L_sparse is not None
                    else float("nan")
                )
                print(
                    f"Epoch {epoch+1}/{num_epochs} | loss {msg_loss:.4f} | test_acc {test_acc:.4f} "
                    f"| train_acc {acc:.4f} | L_sparse {msg_sparse:.4f} "
                    f"| best_val {best_metric:.4f} | epochs_no_improve {epochs_since_improvement}"
                )
            # early stopping
            if (
                use_early_stopping
                and epochs_since_improvement >= early_stopping_patience
            ):
                print(
                    f"[MoTIF] Early stopping triggered (no improvement for {epochs_since_improvement} epochs). Stopping at epoch {epoch+1}."
                )
                if wandb_run is not None:
                    wandb_run.log(
                        {
                            "early_stopped_epoch": epoch + 1,
                            "early_stopping_patience": early_stopping_patience,
                        }
                    )
                break
        # ===== restore best =====
        if best_state is not None:
            self.model.load_state_dict(best_state, strict=True)
            self.model.eval()
            print(
                f"[MoTIF] Restored best weights from epoch {best_epoch+1} (metric={best_metric:.4f})."
            )
        else:
            print("[MoTIF] No best_state captured (empty training?).")
# -------------------------
# PerConceptAffine + CBMTransformer using the per-channel temporal block
# -------------------------
class PerConceptAffine(nn.Module):
    """Per-concept affine rescaling followed by a shifted softplus.

    Each of the ``num_concepts`` channels gets its own learnable scale and
    bias. The output is ``softplus(scale * x + bias) - log(2)``, clamped at
    zero, so an untrained channel (scale=1, bias=0) maps an input of 0 to
    exactly 0 while the activation stays non-negative everywhere.
    """

    def __init__(self, num_concepts: int):
        super().__init__()
        # One independent (scale, bias) pair per concept channel.
        self.scale = nn.Parameter(torch.ones(num_concepts))
        self.bias = nn.Parameter(torch.zeros(num_concepts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the affine part can be disabled to reproduce the
        # no-scaling/no-bias ablation from the paper.
        affine = self.scale * x + self.bias
        shifted = F.softplus(affine) - math.log(2.0)
        return shifted.clamp(min=0.0)
class CBMTransformer(nn.Module):
    """Concept-Bottleneck Model over temporal window embeddings.

    Per forward pass:
      1. positional encoding over the time axis,
      2. a stack of temporal blocks — per-channel "diagonal" attention
         (``PerChannelTemporalBlock``) or full attention
         (``FullAttentionTemporalBlock``), both defined elsewhere in this file,
      3. a per-concept affine + softplus head (``PerConceptAffine``) producing
         non-negative concept activations at every time step,
      4. a linear (optionally non-negative) classifier applied per time step,
      5. temperature-scaled log-sum-exp (soft-max) pooling over time for both
         concepts and class logits.

    Because the classifier sees only the concept activations, every class
    logit is a function of the (optionally intervened-on) concepts.
    """

    def __init__(
        self,
        num_concepts: int,
        num_classes: int,
        transformer_layers: int = 1,
        dropout: float = 0.1,
        lse_tau: float = 1.0,
        nonneg_classifier: bool = False,
        diagonal_attention: bool = True,
        dimension: int = 1,
    ):
        """Build the temporal backbone, concept head and classifier.

        Args:
            num_concepts: number of concept channels C (also the model width).
            num_classes: number of output classes K.
            transformer_layers: number of stacked temporal blocks.
            dropout: dropout rate passed to the positional encoding and blocks.
            lse_tau: temperature for log-sum-exp pooling over time; larger
                values make the pooling closer to a hard max.
            nonneg_classifier: if True, use ``NonNegativeLinear`` so each
                concept contributes only additively to each class.
            diagonal_attention: if True, each concept channel attends only to
                its own history; otherwise attention mixes channels.
            dimension: inner dimension ``d`` forwarded to
                ``PerChannelTemporalBlock`` — semantics defined by that block.
        """
        super().__init__()
        self.lse_tau = lse_tau
        self.diagonal_attention = diagonal_attention
        self.transformer_layers = transformer_layers
        # Positional encoding operates directly in concept space (d_model=C).
        self.posenc = PositionalEncoding(
            d_model=num_concepts, dropout=dropout, max_len=2000
        )
        if diagonal_attention:
            self.layers = nn.ModuleList(
                [
                    PerChannelTemporalBlock(
                        C=num_concepts, dropout=dropout, d=dimension
                    )
                    for _ in range(transformer_layers)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    FullAttentionTemporalBlock(
                        C=num_concepts, num_heads=None, dropout=dropout
                    )
                    for _ in range(transformer_layers)
                ]
            )
        self.norm = nn.LayerNorm(num_concepts)
        self.concept_predictor = PerConceptAffine(num_concepts)
        if nonneg_classifier:
            self.classifier = NonNegativeLinear(num_concepts, num_classes)
        else:
            self.classifier = nn.Linear(num_concepts, num_classes)
        # for introspection: per-time-step importance of the last forward pass
        self.last_time_importance: Optional[torch.Tensor] = None  # [B,T] detached

    def forward(
        self,
        window_embeddings: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        channel_ids: Optional[List[int] | torch.Tensor] = None,
        window_ids: Optional[List[int] | torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Dict[str, torch.Tensor]]]:
        """Classify a (batch of) concept-embedding sequence(s).

        Args:
            window_embeddings: [B,T,C] or [T,C] window-level concept embeddings.
            key_padding_mask: [B,T] boolean mask, True for padded tokens to be
                ignored during pooling (and inside the attention blocks).
            channel_ids: concept indices to zero out (test-time intervention).
            window_ids: time-step indices to zero out (test-time intervention).

        Returns:
            logits: [B,K] LSE-pooled class logits.
            concepts: [B,C] LSE-pooled concept activations.
            concepts_t: [B,T,C] per-time-step concepts (after interventions).
            sharpness: dict with 'concepts' and 'logits' entries, each holding
                per-sample 'max' and 'entropy' of the temporal distribution.
        """
        x = window_embeddings
        # Promote unbatched [T,C] input (and its mask) to batch size 1.
        if x.dim() == 2:
            x = x.unsqueeze(0)  # [1,T,C]
            if key_padding_mask is not None and key_padding_mask.dim() == 1:
                key_padding_mask = key_padding_mask.unsqueeze(0)
        # --- transformer backbone ---
        x = self.posenc(x)  # [B,T,C]
        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)
        x = self.norm(x)  # [B,T,C]
        # --- concept predictions per time step ---
        concepts_t = self.concept_predictor(x)  # [B,T,C]
        # --- concept interventions ---
        # Zeroing happens in place on concepts_t, so the returned per-step
        # concepts reflect the intervention as well.
        # NOTE(review): when BOTH channel_ids and window_ids are given,
        # advanced indexing pairs them elementwise rather than taking the
        # cross product — confirm this matches the intended intervention.
        if channel_ids is not None and window_ids is not None:
            concepts_t[:, window_ids, channel_ids] = 0
        elif channel_ids is not None:
            concepts_t[:, :, channel_ids] = 0
        elif window_ids is not None:
            concepts_t[:, window_ids, :] = 0
        logits_t = self.classifier(concepts_t)  # [B,T,K]
        tau = self.lse_tau
        # --- LSE pooling over time ---
        # logsumexp(tau*x)/tau is a smooth, temperature-controlled surrogate
        # for max over time; padded steps are pushed to -inf so they vanish.
        if key_padding_mask is not None:
            concepts_t_masked = concepts_t.masked_fill(
                key_padding_mask.unsqueeze(-1), float("-inf")
            )
            logits_t_masked = logits_t.masked_fill(
                key_padding_mask.unsqueeze(-1), float("-inf")
            )
            concepts = (concepts_t_masked * tau).logsumexp(dim=1) / tau  # [B,C]
            logits = (logits_t_masked * tau).logsumexp(dim=1) / tau  # [B,K]
        else:
            concepts = (concepts_t * tau).logsumexp(dim=1) / tau
            logits = (logits_t * tau).logsumexp(dim=1) / tau
        # --- temporal importance for explanation ---
        # Soft attribution over time: softmax of the per-step logit of the
        # predicted class (padded steps masked out). Stored detached for
        # later introspection; does not affect gradients.
        with torch.no_grad():
            pred = logits.argmax(dim=1)  # [B]
            sel = torch.gather(logits_t, dim=2, index=pred[:, None, None]).squeeze(
                -1
            )  # [B,T]
            if key_padding_mask is not None:
                sel = sel.masked_fill(key_padding_mask, float("-inf"))
            self.last_time_importance = torch.softmax(
                sel / tau, dim=1
            ).detach()  # softmax importance

        # --- compute sharpness of LSE pooled distributions ---
        def compute_sharpness(x_t, mask=None):
            """Peakedness of the temporal softmax induced by LSE pooling.

            Returns per-sample 'max' (highest time-step probability) and
            'entropy' of softmax(tau * x_t) over the time dimension.
            """
            if mask is not None:
                x_t = x_t.masked_fill(mask.unsqueeze(-1), float("-inf"))
            probs = torch.softmax(tau * x_t, dim=1)
            probs = probs.clamp(min=1e-8)  # avoids log(0)
            max_prob = probs.max(dim=1).values  # [B]
            entropy = -(probs * probs.log()).sum(dim=1)
            return {"max": max_prob, "entropy": entropy}

        sharpness = {
            "concepts": compute_sharpness(concepts_t, key_padding_mask),
            "logits": compute_sharpness(logits_t, key_padding_mask),
        }
        return logits, concepts, concepts_t, sharpness

    def get_attention_maps(self):
        # One entry per layer: [B, C, T, T] attention weights moved to CPU,
        # or None when the layer has not populated .attn_weights yet
        # (i.e. before any forward pass). Tensors are whatever the blocks
        # stored — assumed detached by the blocks themselves.
        return [
            layer.attn_weights.cpu() if layer.attn_weights is not None else None
            for layer in self.layers
        ]
def mean_cbm(model, wandb_run=None):
    """Temporal-mean baseline: collapse each sequence and fit a linear probe.

    Each sample in ``model.X_train`` / ``model.X_test`` (a [T, C] tensor of
    per-window concept activations) is reduced to a single [C] vector by
    averaging over time; a plain linear classifier is then trained on those
    vectors. This yields a temporal-information-free reference accuracy for
    the full MoTIF model.

    Args:
        model: object exposing ``X_train``/``X_test`` (lists of [T, C]
            tensors), ``y_train``/``y_test`` (integer label arrays),
            ``num_concepts``, ``num_classes`` and optionally ``device``.
        wandb_run: optional wandb run; per-epoch train metrics and the final
            test accuracy are logged when provided.

    Returns:
        float: test-set accuracy of the linear probe (also printed).
    """
    X_train, X_test = model.X_train.copy(), model.X_test.copy()
    y_train, y_test = model.y_train.copy(), model.y_test.copy()
    num_classes = model.num_classes
    num_concepts = model.num_concepts
    # Lazily fall back to device autodetection: the original passed
    # get_torch_device() as getattr's default, which evaluated it eagerly
    # even when the model already carried a device.
    device = getattr(model, "device", None)
    if device is None:
        device = get_torch_device()
    # Ablation switch kept from the original experiments: use one random frame
    # per sequence instead of the temporal mean. (Renamed from `random`, which
    # shadowed the imported stdlib module of the same name.)
    use_random_frame = False
    if use_random_frame:
        X_train_mean = [x[np.random.randint(0, len(x))] for x in X_train]
        X_test_mean = [x[np.random.randint(0, len(x))] for x in X_test]
    else:
        X_train_mean = [torch.mean(x, dim=0) for x in X_train]  # [T,C] -> [C]
        X_test_mean = [torch.mean(x, dim=0) for x in X_test]  # [T,C] -> [C]

    def _to_array(items):
        # Unify torch tensors / numpy arrays into one stacked float array.
        return np.stack(
            [
                t.cpu().numpy() if isinstance(t, torch.Tensor) else np.array(t)
                for t in items
            ]
        )

    tensor_train = torch.tensor(
        _to_array(X_train_mean), dtype=torch.float32, device=device
    )
    tensor_test = torch.tensor(
        _to_array(X_test_mean), dtype=torch.float32, device=device
    )
    # Train a linear probe on the mean (or random-frame) embeddings.
    linear_model = nn.Linear(num_concepts, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(linear_model.parameters(), lr=0.001)
    # Hoisted out of the loop: labels do not change between epochs.
    labels_train = torch.tensor(y_train, dtype=torch.long, device=device)
    num_epochs = 200
    for epoch in range(num_epochs):
        linear_model.train()
        optimizer.zero_grad()
        outputs = linear_model(tensor_train)
        loss = criterion(outputs, labels_train)
        loss.backward()
        optimizer.step()
        if wandb_run is not None:
            with torch.no_grad():
                preds = outputs.argmax(dim=1)
                acc = (preds.detach().cpu().numpy() == y_train).mean()
            current_lr = (
                optimizer.param_groups[0]["lr"] if optimizer.param_groups else None
            )
            wandb_run.log(
                {
                    "mean_train_loss": loss.item(),
                    "mean_train_acc": acc,
                    "mean_learning_rate": current_lr,
                }
            )
    linear_model.eval()
    with torch.no_grad():
        outputs = linear_model(tensor_test)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted.detach().cpu().numpy() == y_test).mean()
    print(f"CBM accuracy test: {accuracy:.4f}")
    if wandb_run is not None:
        wandb_run.log({"mean_test_acc": accuracy})
    return float(accuracy)
class NonNegativeLinear(nn.Module):
    """Linear layer whose weight matrix is constrained to be non-negative.

    Used as the concept->class head when purely additive, interpretable
    concept contributions are desired. The constraint is enforced by clamping
    the weights in place at the start of every forward pass.

    Fix: the original did not subclass ``nn.Module`` (and never called
    ``super().__init__()``), so instances were not callable — e.g.
    ``self.classifier(concepts_t)`` in CBMTransformer would raise
    ``TypeError`` — and the wrapped parameters were invisible to optimizers
    and ``state_dict``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project the weights onto the non-negative orthant before use.
        self.linear.weight.data.clamp_(min=0.0)
        return self.linear(x)