|
|
""" |
|
|
CBM models and utilities consolidated from the Video_cbm.ipynb notebook. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import glob
import math
import os
import random
import re
from typing import Dict, List, Optional, Tuple, Union

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
|
|
|
@torch.no_grad()
def explain_instance(
    model: nn.Module,
    window_embeddings: torch.Tensor,
    key_padding_mask: Optional[torch.Tensor] = None,
    channel_ids: Optional[Union[List[int], torch.Tensor]] = None,
    window_ids: Optional[Union[List[int], torch.Tensor]] = None,
    target_class: Optional[int] = None,
    window_spans: Optional[List[Tuple[int, int]]] = None,
    fps: Optional[float] = None,
):
    """Run the CBM model on a single instance and collect explanation artifacts.

    Returns a dict (all tensors detached, on CPU) containing: the explained
    class, global and per-timestep logits/concepts, softmax time-importance
    weights, per-timestep and globally aggregated concept contributions,
    the model's sharpness stats, and — when inputs allow — frame/second
    spans per timestep plus per-layer attention maps.

    NOTE(review): assumes `model` exposes `classifier` (a linear layer with
    `.weight`/`.bias`), `lse_tau`, and `layers`, and that its forward returns
    (logits, concepts, concepts_t, sharpness) — confirm against the model
    definition, which is not visible in this file.
    """
    # Fall back to the dummy tensor's (CPU) device when the model has no parameters.
    device = next(model.parameters(), torch.empty(0)).device
    x = window_embeddings.to(device)
    # Promote an unbatched instance to a batch of one; the rest of the
    # function only inspects batch element 0.
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if key_padding_mask is not None and key_padding_mask.dim() == 1:
        key_padding_mask = key_padding_mask.unsqueeze(0)

    logits, concepts, concepts_t, sharpness = model(
        x,
        key_padding_mask=key_padding_mask,
        channel_ids=channel_ids,
        window_ids=window_ids,
    )

    # Per-timestep logits obtained by pushing the per-timestep concept
    # activations through the same linear classifier head.
    logits_t = model.classifier(concepts_t)

    # Default to explaining the model's predicted class.
    if target_class is None:
        target_class = int(logits[0].argmax().item())

    # Work on the first (only) batch element from here on.
    concepts_t_1 = concepts_t[0]
    logits_t_1 = logits_t[0]

    # Classifier weight row and bias for the explained class.
    w = model.classifier.weight[target_class]
    b = (
        0.0
        if model.classifier.bias is None
        else float(model.classifier.bias[target_class].item())
    )

    # Per-timestep, per-concept additive contribution to the class score:
    # contrib_t[t, j] = concepts_t[t, j] * w[j]; summing over concepts
    # (plus bias) recovers the per-timestep class score.
    contrib_t = concepts_t_1 * w.unsqueeze(0)
    score_per_time = contrib_t.sum(dim=1) + b

    # Time importance: softmax over per-timestep class logits at the
    # model's LSE temperature; padded timesteps are forced to -inf so
    # they receive zero weight.
    tau = float(model.lse_tau)
    time_scores = logits_t_1[:, target_class]
    if key_padding_mask is not None:
        time_scores = time_scores.masked_fill(key_padding_mask[0], float("-inf"))
    time_importance = torch.softmax(time_scores / tau, dim=0)

    # Importance-weighted aggregation of per-timestep contributions into
    # one global contribution per concept.
    contrib_global = (time_importance.unsqueeze(1) * contrib_t).sum(dim=0)

    res = {
        "target_class": torch.tensor(target_class),
        "logits": logits[0].detach().cpu(),
        "logits_per_time": logits_t_1.detach().cpu(),
        "concepts": concepts[0].detach().cpu(),
        "concepts_per_time": concepts_t_1.detach().cpu(),
        "time_importance": time_importance.detach().cpu(),
        "score_per_time": score_per_time.detach().cpu(),
        "concept_contributions_per_time": contrib_t.detach().cpu(),
        "concept_contributions_global": contrib_global.detach().cpu(),
        # sharpness is a nested dict of tensors; detach/CPU each leaf.
        "sharpness": {
            k: {m: v.detach().cpu() for m, v in d.items()} for k, d in sharpness.items()
        },
    }

    # Attach frame spans only when they align 1:1 with the timesteps;
    # convert to seconds as well when an fps is provided.
    if window_spans is not None and len(window_spans) == concepts_t_1.shape[0]:
        res["frame_spans"] = torch.tensor(window_spans, dtype=torch.long)
        if fps is not None and fps > 0:
            res["second_spans"] = torch.tensor(
                [(s / fps, e / fps) for (s, e) in window_spans], dtype=torch.float32
            )

    # Expose cached per-layer attention (batch element 0) when any layer
    # stored an 'attn_weights' attribute during the forward pass.
    attn = [getattr(layer, "attn_weights", None) for layer in model.layers]
    if any(a is not None for a in attn):
        res["attn_per_layer"] = [
            a[0].detach().cpu() if a is not None else None for a in attn
        ]

    return res
|
|
|
|
|
|
|
|
def _bar(x, width=20): |
|
|
|
|
|
n = int(round(x * width)) |
|
|
return "█" * n + "·" * (width - n) |
|
|
|
|
|
|
|
|
def print_explanation(
    res: dict,
    fps_frame: dict,
    concepts_list: Optional[List[str]] = None,
    top_k_times: int = 3,
    top_k_concepts: int = 8,
    by_abs: bool = True,
    positive_only: bool = True,
):
    """Pretty-print an explanation dict produced by `explain_instance`.

    Prints the globally most contributing concepts, then the most
    important timesteps with their per-timestep top concepts, using
    unicode bars for relative magnitudes.

    NOTE(review): `fps_frame` is indexed as `fps_frame[t]` and unpacked
    into `(start, end)` seconds per timestep — presumably a mapping or
    sequence of time spans; confirm against the caller.
    """
    # Detach/move tensors to CPU; pass non-tensors through unchanged.
    def td(x):
        return x.detach().cpu() if isinstance(x, torch.Tensor) else x

    ti = td(res["time_importance"]).flatten()
    spt = td(res["score_per_time"]).flatten()
    cpt = td(res["concept_contributions_per_time"])
    cglob = td(res["concept_contributions_global"]).flatten()
    tgt = res["target_class"]
    target_class = int(tgt.item()) if hasattr(tgt, "item") else int(tgt)

    T, C = ti.shape[0], cglob.shape[0]
    if concepts_list is None:
        concepts_list = [f"c{j}" for j in range(C)]

    # Optional span annotations produced by explain_instance.
    frame_spans = res.get("frame_spans", None)
    second_spans = res.get("second_spans", None)

    # Min-max normalize for bar rendering (epsilon guards constant vectors).
    ti_norm = (ti - ti.min()) / (ti.max() - ti.min() + 1e-8)
    spt_norm = (spt - spt.min()) / (spt.max() - spt.min() + 1e-8)

    # Rank global concepts by magnitude (or signed value); in positive_only
    # mode, zero out non-positive contributions and shrink k accordingly.
    rank_vals = cglob.abs() if by_abs else cglob
    if positive_only:
        rank_vals = torch.where(cglob > 0, rank_vals, torch.zeros_like(rank_vals))
        top_k_concepts = min(top_k_concepts, int((rank_vals > 0).sum().item()))
    topc_vals, topc_idx = torch.topk(rank_vals, k=min(top_k_concepts, C))
    print(f"Target class: {target_class}\n")
    print("Top concepts (global):")
    for _, j in zip(topc_vals, topc_idx):
        j = int(j)
        name = concepts_list[j] if j < len(concepts_list) else f"c{j}"
        val = float(cglob[j])
        # Bar length is magnitude relative to the largest global contribution.
        mag = abs(val)
        mag_norm = mag / (float(cglob.abs().max()) + 1e-8)
        print(f" {name:30s} {val:+.3f} {_bar(mag_norm)}")

    # Most important timesteps, ordered by descending time importance.
    _, topt_idx = torch.topk(ti, k=min(top_k_times, T))
    topt_idx = sorted(topt_idx.tolist(), key=lambda t: float(ti[t]), reverse=True)

    print("\nImportant time steps:")
    for t in topt_idx:
        t_imp = float(ti[t])
        extras = []
        if frame_spans is not None:
            fs = frame_spans[t]
            extras.append(f"frames=[{int(fs[0])},{int(fs[1])}]")
        if second_spans is not None:
            ss = second_spans[t]
            extras.append(f"sec=[{float(ss[0]):.2f},{float(ss[1]):.2f}]")
        extra_str = (" " + " ".join(extras)) if extras else ""
        start, end = fps_frame[t]
        print(
            f" t=[{int(start//60):02d}:{start%60:05.2f} - {int(end//60):02d}:{end%60:05.2f}] time_importance={t_imp:.3f} TI[{_bar(float(ti_norm[t]))}] Score[{_bar(float(spt_norm[t]))}]"
            + extra_str
        )
        ct = cpt[t]
        # Per-timestep concept ranking, same scheme as the global ranking.
        rank_vals_t = ct.abs() if by_abs else ct
        if positive_only:
            rank_vals_t = torch.where(ct > 0, rank_vals_t, torch.zeros_like(rank_vals_t))
            k = min(top_k_concepts, int((rank_vals_t > 0).sum().item()), C)
        else:
            k = min(top_k_concepts, C)
        vals, idxs = torch.topk(rank_vals_t, k=k)

        # Bars are scaled to this timestep's largest contribution magnitude.
        denom = float(ct.abs().max()) + 1e-8
        for j_rank in idxs:
            j = int(j_rank)
            name = concepts_list[j] if j < len(concepts_list) else f"c{j}"
            val = float(ct[j])
            print(f" - {name:30s} {val:+.3f} {_bar(abs(val)/denom)}")
|
|
|
|
|
|
|
|
def print_explanation_with_labels(
    res: dict,
    fps_frame: dict,
    label_decoder: LabelEncoder,
    true_label_idx: int,
    positive_only: bool = True,
    **kwargs
):
    """Print a predicted-vs-true header with decoded class names, then the
    full textual explanation via `print_explanation`."""
    raw_pred = res["target_class"]
    predicted = int(raw_pred.item() if hasattr(raw_pred, "item") else raw_pred)
    actual = int(true_label_idx)
    # Decode both indices in one call to the label encoder.
    decoded = label_decoder.inverse_transform([predicted, actual])
    print(f"Predicted: {predicted} ({decoded[0]}) | True: {actual} ({decoded[1]})")
    print_explanation(res, fps_frame, positive_only=positive_only, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_batch_sequences(
    seqs: List[torch.Tensor], device: torch.device
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Stack variable-length [T_i, C] tensors into a zero-padded batch of
    shape [B, T_max, C], together with a boolean key_padding_mask of
    shape [B, T_max] that is True exactly at the padded positions.
    """
    if not seqs:
        raise ValueError("pad_batch_sequences received empty sequence list")
    seq_lens = [int(s.shape[0]) for s in seqs]
    feat_dim = int(seqs[0].shape[1])
    max_len = int(max(seq_lens))
    n_seqs = len(seqs)
    padded = torch.zeros((n_seqs, max_len, feat_dim), dtype=torch.float32, device=device)
    pad_mask = torch.ones((n_seqs, max_len), dtype=torch.bool, device=device)
    for row, (seq, n) in enumerate(zip(seqs, seq_lens)):
        padded[row, :n, :] = seq.to(device)
        pad_mask[row, :n] = False
    return padded, pad_mask
|
|
|
|
|
|
|
|
def _cv_bar_img(frac: float, width: int = 160, height: int = 8) -> np.ndarray: |
|
|
frac = float(max(0.0, min(1.0, frac))) |
|
|
w = max(1, int(round(frac * width))) |
|
|
bar = np.zeros((height, width, 3), dtype=np.uint8) |
|
|
bar[:, :w, :] = 255 |
|
|
return bar |
|
|
|
|
|
|
|
|
def _put_text_multiline(
    img,
    lines,
    org,
    line_h,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    font_scale=0.40,
    thickness=1,
    color=(255, 255, 255),
):
    """Draw each string in *lines* onto *img*, stacked top-to-bottom from
    baseline origin *org* with a fixed vertical step of *line_h* pixels."""
    x0, baseline = org
    for text in lines:
        cv2.putText(
            img,
            text,
            (x0, baseline),
            font,
            font_scale,
            color,
            thickness,
            cv2.LINE_AA,
        )
        baseline += line_h
|
|
|
|
|
|
|
|
def _safe_paste_bar(frame: np.ndarray, x: int, y: int, bar: np.ndarray) -> None: |
|
|
H, W = frame.shape[:2] |
|
|
bh, bw = bar.shape[:2] |
|
|
x1 = max(0, x) |
|
|
y1 = max(0, y) |
|
|
x2 = min(W, x + bw) |
|
|
y2 = min(H, y + bh) |
|
|
if x1 >= x2 or y1 >= y2: |
|
|
return |
|
|
bx1 = x1 - x |
|
|
by1 = y1 - y |
|
|
bx2 = bx1 + (x2 - x1) |
|
|
by2 = by1 + (y2 - y1) |
|
|
roi = frame[y1:y2, x1:x2] |
|
|
bar_crop = bar[by1:by2, bx1:bx2] |
|
|
np.maximum(roi, bar_crop, out=roi) |
|
|
|
|
|
|
|
|
@torch.no_grad()
def render_explained_video_small_tl(
    vid_path: str,
    out_path: str,
    res: dict,
    fps_frame_seconds: List[Tuple[float, float]],
    label_decoder,
    true_label_idx: int,
    concepts_list: Optional[List[str]] = None,
    top_k_times: int = 3,
    top_k_concepts: int = 4,
    by_abs: bool = True,
    up_scale: float = 2.0,
    margin: int = 10,
    panel_w_px: int = 300,
    panel_alpha: float = 0.70,
    font_scale: float = 0.40,
    thickness: int = 1,
    codec: str = "mp4v",
) -> str:
    """Re-encode a video with a small top-left explanation panel overlaid
    on frames belonging to the most important timesteps.

    The panel shows the predicted/true class names, the timestep's time
    importance and second span, and bars for that timestep's top concept
    contributions. Returns *out_path*.

    Raises RuntimeError when the input video or the output writer cannot
    be opened. `fps_frame_seconds` maps each model timestep t to a
    (start_sec, end_sec) span in video time.
    """
    cap = cv2.VideoCapture(vid_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video: {vid_path}")
    # `or 30.0` also covers the 0.0 fps some containers report.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    F = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Output is upscaled so the small panel text stays legible.
    outW = int(round(W * up_scale))
    outH = int(round(H * up_scale))
    writer = cv2.VideoWriter(
        out_path, cv2.VideoWriter_fourcc(*codec), fps, (outW, outH)
    )
    if not writer.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open writer for: {out_path}")

    ti = res["time_importance"].detach().cpu().float()
    cpt = res["concept_contributions_per_time"].detach().cpu()
    tgt = res["target_class"]
    pred_idx = int(tgt.item()) if hasattr(tgt, "item") else int(tgt)

    T = ti.shape[0]
    C = cpt.shape[1]
    if concepts_list is None:
        concepts_list = [f"c{j}" for j in range(C)]

    # Best-effort label decoding; fall back to raw indices as strings.
    try:
        pred_name = label_decoder.inverse_transform([pred_idx])[0]
        true_name = label_decoder.inverse_transform([int(true_label_idx)])[0]
    except Exception:
        pred_name = str(pred_idx)
        true_name = str(true_label_idx)

    # top_k_times == 0 is a sentinel for "annotate every timestep".
    if top_k_times == 0:
        top_k_times = T
    kT = min(top_k_times, T)
    _, topt_idx = torch.topk(ti, k=kT, largest=True, sorted=True)
    important_t = set(int(i) for i in topt_idx.tolist())

    # Precompute, per timestep, its top concepts as (name, value, bar-fraction).
    per_t_top = []
    for t in range(T):
        ct = cpt[t]
        rank_vals = ct.abs() if by_abs else ct
        kk = min(top_k_concepts, C)
        _, idxs = torch.topk(rank_vals, k=kk, largest=True, sorted=True)
        denom = float(ct.abs().max().item()) + 1e-8
        entries = []
        for j in idxs.tolist():
            name = concepts_list[j] if j < len(concepts_list) else f"c{j}"
            sval = float(ct[j].item())
            frac = min(1.0, abs(sval) / denom) if denom > 0 else 0.0
            entries.append((name, sval, frac))
        per_t_top.append(entries)

    # Map every video frame index to its model timestep (None = no timestep).
    frame_to_t = [None] * F
    for t, (ss, es) in enumerate(fps_frame_seconds):
        fs = max(0, int(round(ss * fps)))
        fe = min(F - 1, int(round(es * fps)))
        for f in range(fs, fe + 1):
            frame_to_t[f] = t

    # Fixed panel geometry: header (2 rows) + one row per concept bar.
    line_h = 16
    rows = 2 + top_k_concepts
    panel_h_px = 18 + rows * line_h + 12
    x0, y0 = margin, margin
    panel_rect = (x0, y0, panel_w_px, panel_h_px)

    fidx = 0
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            frame = cv2.resize(frame, (outW, outH), interpolation=cv2.INTER_CUBIC)

            t = frame_to_t[fidx] if fidx < len(frame_to_t) else None
            if (t is not None) and (t in important_t):
                # Darken the panel region by alpha-blending a black rectangle.
                overlay = frame.copy()
                x, y, pw, ph = panel_rect
                cv2.rectangle(overlay, (x, y), (x + pw, y + ph), (0, 0, 0), -1)
                cv2.addWeighted(overlay, panel_alpha, frame, 1 - panel_alpha, 0, frame)

                sec = fidx / fps  # NOTE(review): computed but unused
                ss, es = fps_frame_seconds[t]
                header = [
                    f"Pred:{pred_name} | True:{true_name}",
                    f"t={t} TI={float(ti[t]):.3f} [{ss:.2f}-{es:.2f}]s",
                ]
                _put_text_multiline(
                    frame,
                    header,
                    (x + 8, y + 18),
                    line_h,
                    font_scale=font_scale,
                    thickness=thickness,
                )

                # Concept rows: bar image on the left, "name value" text on the right.
                y_cursor = y + 18 + line_h * len(header) + 2
                for name, sval, frac in per_t_top[t]:
                    bar = _cv_bar_img(frac, width=120, height=8)
                    bx, by = x + 8, int(y_cursor - 8)
                    _safe_paste_bar(frame, bx, by, bar)
                    cv2.putText(
                        frame,
                        f"{name[:16]:16s} {sval:+.2f}",
                        (bx + bar.shape[1] + 8, int(y_cursor + 4)),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        font_scale,
                        (255, 255, 255),
                        thickness,
                        cv2.LINE_AA,
                    )
                    y_cursor += line_h

            writer.write(frame)
            fidx += 1
    finally:
        # Always release capture and writer, even on error mid-stream.
        cap.release()
        writer.release()

    return out_path
|
|
|
|
|
|
|
|
@torch.no_grad()
def print_temporal_dependencies(
    res: dict,
    top_k_times: int = 5,
    top_k_links: int = 5,
    concept_idx: Optional[int] = None,
    layer_agg: str = "mean",
    head_or_concept_agg: str = "mean",
    focus_times: Optional[List[int]] = None,
    by_abs: bool = False,
):
    """
    Print temporal dependencies (attention) between timesteps.

    Handles both per-channel attention [C,T,T] and full-attention [H,T,T].

    Strategy:
    1) Load attention maps per layer.
    2) If shape is [C,T,T] (per-channel), either select 'concept_idx' or aggregate across concepts.
       If shape is [H,T,T] (full), aggregate across heads.
    3) Aggregate across layers via mean/max.
    4) Choose which timesteps to display:
       - 'focus_times' if given,
       - else top 'top_k_times' by res["time_importance"] (if available),
       - else first 'top_k_times'.
    5) For each chosen timestep t, print top 'top_k_links' target timesteps u with largest attention weight.
    """
    attn_layers = res.get("attn_per_layer", None)
    if not attn_layers or all(a is None for a in attn_layers):
        print(
            "[temporal] No attention maps available in 'res'. Ensure your model layers store 'attn_weights'."
        )
        return

    # Collect per-layer maps, coercing to float tensors.
    mats = []
    for a in attn_layers:
        if a is None:
            continue
        if not torch.is_tensor(a):
            a = torch.as_tensor(a)
        mats.append(a.float())

    if len(mats) == 0:
        print("[temporal] No attention maps available after filtering.")
        return

    G, T, T2 = mats[0].shape
    assert T == T2, f"Expected square attention [G,T,T], got {mats[0].shape}"

    def agg_g(x: torch.Tensor) -> torch.Tensor:
        # Select one concept channel if requested (and in range), else
        # reduce over the leading head/concept dimension.
        if concept_idx is not None and x.shape[0] > concept_idx:
            return x[concept_idx]
        if head_or_concept_agg == "max":
            return x.max(dim=0).values
        return x.mean(dim=0)

    mats_agg_g = [agg_g(a) for a in mats]

    # Aggregate across layers into a single [T,T] matrix.
    # BUGFIX: the previous code did `stack.mean(dim=0).values` guarded by
    # `hasattr(stack, "values")`; dense tensors expose a `.values` *method*
    # (sparse accessor), so `A` became a bound method and `A[t]` raised
    # TypeError whenever layer_agg == "mean" (the default).
    stack = torch.stack(mats_agg_g, dim=0)
    if layer_agg == "max":
        A = stack.max(dim=0).values
    else:
        A = stack.mean(dim=0)

    # Query timesteps: explicit focus list, else ranked by time importance,
    # else simply the first few.
    if focus_times is not None and len(focus_times) > 0:
        query_times = [t for t in focus_times if 0 <= t < T]
    else:
        ti = res.get("time_importance", None)
        if isinstance(ti, torch.Tensor) and ti.numel() == T:
            vals, idx = torch.topk(ti, k=min(top_k_times, T))
            query_times = sorted(
                idx.tolist(), key=lambda t: float(ti[t]), reverse=True
            )
        else:
            query_times = list(range(min(top_k_times, T)))

    second_spans = res.get("second_spans", None)

    def _fmt_time(ti_):
        # Append the [start-end] second span when spans align with T.
        if (
            second_spans is not None
            and hasattr(second_spans, "__len__")
            and len(second_spans) == T
        ):
            ss, es = second_spans[ti_]
            return f"t={ti_} [{float(ss):.2f}-{float(es):.2f}s]"
        return f"t={ti_}"

    tgt = res.get("target_class", None)
    if tgt is not None:
        tc = int(tgt.item()) if hasattr(tgt, "item") else int(tgt)
        print(f"[temporal] Target class: {tc}")
    if concept_idx is not None:
        print(f"[temporal] Using per-channel attention for concept c={concept_idx}")
    else:
        print(
            f"[temporal] Aggregation over {'concepts' if G==A.shape[0] else 'heads'}: {head_or_concept_agg}, layers: {layer_agg}"
        )

    for t in query_times:
        row = A[t]

        # Rank target timesteps by (abs) attention weight.
        rank_vals = row.abs() if by_abs else row
        k = min(top_k_links, T)
        vals, idxs = torch.topk(rank_vals, k=k, largest=True, sorted=True)

        print(f"\n{_fmt_time(t)} (row-softmaxed attention to other timesteps)")

        # Bars are relative to this row's top-ranked weight.
        denom = float(rank_vals[idxs[0]] + 1e-12)
        for j, v in zip(idxs.tolist(), vals.tolist()):
            w = float(row[j])
            rel = max(0.0, min(1.0, float(abs(v) / denom)))
            # BUGFIX: the original checked `" _bar" in globals()` (leading
            # space typo), so the bar renderer was never actually used.
            bar = _bar(rel) if "_bar" in globals() else f"{rel:.2f}"
            if second_spans is not None and len(second_spans) == T:
                ss, es = second_spans[j]
                target_str = f"u={j} [{float(ss):.2f}-{float(es):.2f}s]"
            else:
                target_str = f"u={j}"
            print(f" -> {target_str:18s} w={w:+.4f} {bar}")
|
|
|
|
|
|
|
|
def _fmt_sec(sec: float) -> str: |
|
|
|
|
|
m = int(sec // 60) |
|
|
s = sec - 60 * m |
|
|
return f"{m}:{s:05.2f}s" if m else f"{s:.2f}s" |
|
|
|
|
|
|
|
|
@torch.no_grad()
def plot_attention_heatmaps(
    res: dict,
    concept_idx: Optional[int] = None,
    concept_names: Optional[List[str]] = None,
    layer_idxs: Optional[List[int]] = None,
    layer_agg: Optional[str] = None,
    head_or_concept_agg: str = "mean",
    normalize_rows: bool = True,
    show_seconds: bool = True,
    cmap: str = "magma",
    figsize: Tuple[int, int] = (5, 4),
    savepath: Optional[str] = None,
    title_prefix: str = "Attention",
):
    """Render the attention maps in ``res["attn_per_layer"]`` as heatmaps.

    Each stored map must be [G, T, T] (G = heads or concepts). One figure
    is produced per layer, or a single aggregated figure when `layer_agg`
    is 'mean'/'max'. Returns the list of matplotlib figures, or None when
    no usable maps exist (a message is printed instead).

    NOTE(review): `title_prefix` is accepted but currently unused — the
    axes title is built only from the concept name / aggregation mode.
    """
    # Serif/Times styling, applied only inside the rc_context below.
    rc = {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "DejaVu Serif", "Liberation Serif"],
        "mathtext.fontset": "stix",
    }

    attn_layers = res.get("attn_per_layer", None)
    if not attn_layers or all(a is None for a in attn_layers):
        print("[heatmap] No attention maps in 'res'.")
        return

    # Coerce per-layer maps to float tensors; each must be square [G,T,T].
    mats = []
    for a in attn_layers:
        if a is None:
            continue
        a = torch.as_tensor(a, dtype=torch.float32)
        assert (
            a.ndim == 3 and a.shape[-1] == a.shape[-2]
        ), f"Expected [G,T,T], got {tuple(a.shape)}"
        mats.append(a)
    if not mats:
        print("[heatmap] No usable attention maps.")
        return

    # Optionally restrict to a subset of layers (out-of-range indices skipped).
    if layer_idxs is not None:
        mats = [mats[i] for i in layer_idxs if 0 <= i < len(mats)]
        if not mats:
            print("[heatmap] Selected layer_idxs produced empty set.")
            return

    G, T, _ = mats[0].shape
    second_spans = res.get("second_spans", None)

    def agg_g(x: torch.Tensor) -> torch.Tensor:
        # Select one concept channel, or reduce over the head/concept dim.
        if concept_idx is not None:
            if not (0 <= concept_idx < x.shape[0]):
                raise IndexError(
                    f"concept_idx={concept_idx} out of range [0,{x.shape[0]-1}]."
                )
            return x[concept_idx]
        return x.max(dim=0).values if head_or_concept_agg == "max" else x.mean(dim=0)

    per_layer = [agg_g(L) for L in mats]

    # Build the plot list: (tag, [T,T] matrix) per layer, or one aggregate.
    plots = []
    if layer_agg in (None, ""):
        for Li, A in enumerate(per_layer):
            plots.append((Li, A))
    elif layer_agg == "mean":
        plots.append(("mean", torch.stack(per_layer, dim=0).mean(dim=0)))
    elif layer_agg == "max":
        plots.append(("max", torch.stack(per_layer, dim=0).max(dim=0).values))
    else:
        raise ValueError("layer_agg must be None, 'mean', or 'max'.")

    def row_norm(A: torch.Tensor) -> torch.Tensor:
        # Re-normalize each query row to sum to 1 (epsilon guards zero rows).
        if not normalize_rows:
            return A
        denom = A.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        return A / denom

    def make_ticks(T: int):
        # ~8 evenly spaced tick indices, always including the last one;
        # labels include the mid-span time when second spans align with T.
        step = max(1, T // 8)
        idxs = list(range(0, T, step))
        if idxs[-1] != T - 1:
            idxs.append(T - 1)
        if (
            show_seconds
            and isinstance(second_spans, torch.Tensor)
            and second_spans.shape[0] == T
        ):
            lbls = []
            for i in idxs:
                ss, es = second_spans[i].tolist()
                mid = 0.5 * (float(ss) + float(es))
                lbls.append(f"u={i} · {_fmt_sec(mid)}")
        else:
            lbls = [f"u={i}" for i in idxs]
        return idxs, lbls

    figs = []

    with mpl.rc_context(rc):
        for tag, A in plots:
            A = row_norm(A.detach().cpu())
            fig, ax = plt.subplots(figsize=figsize)
            im = ax.imshow(
                A,
                origin="lower",
                interpolation="nearest",
                cmap=cmap,
                vmin=0.0,
                # falls back to autoscale (None) when the map is all zeros
                vmax=float(A.max().item()) or None,
            )

            ax.set_xlabel("Key time u (source/context)", fontsize=16)
            ax.set_ylabel("Query time t (target/current)", fontsize=16)

            xt, xl = make_ticks(T)
            yt, yl = make_ticks(T)
            # Same tick positions on y, relabelled as query times.
            yl = [lbl.replace("u=", "t=") for lbl in yl]

            ax.set_xticks(xt)
            ax.set_xticklabels(xl, rotation=45, ha="right", fontsize=13)
            ax.set_yticks(yt)
            ax.set_yticklabels(yl, fontsize=13)

            # Resolve an optional human-readable concept name for the title.
            cname = None
            if (
                concept_idx is not None
                and concept_names
                and 0 <= concept_idx < len(concept_names)
            ):
                cname = concept_names[concept_idx]
            tag_str = f"layer={tag}" if isinstance(tag, (int, str)) else str(tag)
            if concept_idx is not None:
                title = f" ({cname})" if cname else ""
            else:
                title = f" (agg over {'concepts' if G==A.shape[0] else 'heads'})"
            ax.set_title(title, fontsize=18)

            cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.set_label("Attention weight", fontsize=16)

            fig.tight_layout()
            if savepath:
                p = savepath
                if len(plots) > 1:
                    # Per-layer filenames: "<stem>_<tag>.<ext>" (default ext png).
                    stem, ext = (savepath.rsplit(".", 1) + ["png"])[:2]
                    p = f"{stem}_{tag_str}.{ext}"
                fig.savefig(p, dpi=150, bbox_inches="tight")
            figs.append(fig)

    plt.show()

    return figs
|
|
|