StrokeMitra-API / src /explainability /attention_rollout.py
DhruvB1906's picture
Upload folder using huggingface_hub
4e9a3bc verified
"""Attention rollout for transformer models."""
import logging
import numpy as np
import torch
logger = logging.getLogger(__name__)


class AttentionRollout:
    """
    Attention rollout for visualizing transformer attention patterns.

    Forward hooks are attached to every sub-module whose qualified name
    contains "attention". Each forward pass run by ``compute_rollout``
    collects the per-layer attention weights those modules emit, and the
    per-layer matrices are then chained into a single token-to-token map
    (Abnar & Zuidema, 2020).
    """

    # Head-fusion strategies understood by ``compute_rollout``; any other
    # value falls back to "mean" (with a warning).
    _FUSIONS = ("mean", "max", "min")

    def __init__(self, model, num_layers=24, num_heads=16):
        """
        Initialize attention rollout.

        Args:
            model: Transformer model (e.g., HuBERT). Must expose
                ``named_modules()``, ``eval()`` and be callable.
            num_layers: Number of transformer layers. Stored for reference
                only; nothing in this class reads it.
            num_heads: Number of attention heads. Stored for reference
                only; nothing in this class reads it.
        """
        self.model = model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.attention_maps = []
        # Handles returned by register_forward_hook, kept so the hooks can
        # be detached again with remove_hooks().
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register hooks to capture attention weights."""

        def attention_hook(module, inputs, output):
            # Preferred form: an output object carrying ``.attentions``.
            candidate = getattr(output, "attentions", None)
            if candidate is None and isinstance(output, (tuple, list)):
                # Attention modules commonly return
                # (attn_output, attn_weights, ...); only the second slot
                # is inspected. NOTE(review): confirm for your model.
                candidate = output[1:2]
            self._store_attention(candidate)

        # Drop hooks from a previous registration so they never stack up.
        self.remove_hooks()

        # Register hooks on transformer attention layers.
        # This is model-specific; adjust for your architecture.
        for name, module in self.model.named_modules():
            if "attention" in name.lower():
                self._hook_handles.append(
                    module.register_forward_hook(attention_hook)
                )

        if self._hook_handles:
            logger.info("Attention hooks registered")
        else:
            logger.warning(
                "No module name contains 'attention'; no hooks registered"
            )

    def remove_hooks(self):
        """Detach every forward hook this instance registered on the model."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def _store_attention(self, candidate):
        """
        Append usable attention tensors from *candidate* to the capture list.

        Args:
            candidate: A tensor, a tuple/list of tensors, or anything else
                (ignored). Only 4-D tensors whose last two dimensions are
                equal are kept, i.e. (batch, heads, seq_len, seq_len).
        """
        if torch.is_tensor(candidate):
            candidate = (candidate,)
        if not isinstance(candidate, (tuple, list)):
            return

        # A wrapper module may forward the very same weights tensor as the
        # attention module it wraps; comparing storage pointers keeps each
        # layer from being counted twice.
        seen = {t.data_ptr() for t in self.attention_maps}
        for tensor in candidate:
            if not torch.is_tensor(tensor) or tensor.dim() != 4:
                continue
            if tensor.shape[-1] != tensor.shape[-2]:
                continue
            if tensor.data_ptr() in seen:
                continue
            seen.add(tensor.data_ptr())
            self.attention_maps.append(tensor.detach())

    @staticmethod
    def _fuse_heads(attn, head_fusion):
        """Collapse the head axis of the first batch item to (seq_len, seq_len)."""
        # float() so half-precision maps can be converted to numpy.
        sample = attn[0].float()  # (heads, seq_len, seq_len)
        if head_fusion == "max":
            fused = sample.amax(dim=0)
        elif head_fusion == "min":
            fused = sample.amin(dim=0)
        else:
            fused = sample.mean(dim=0)
        return fused.cpu().numpy().astype(np.float64)

    @staticmethod
    def _rollout(layers, add_residual):
        """Chain per-layer (seq_len, seq_len) matrices, first layer first."""
        identity = np.eye(layers[0].shape[0])
        rollout = identity
        for attn in layers:
            if add_residual:
                # Model the skip connection as an equal mix of attention
                # and identity, then make each row sum to one again.
                attn = 0.5 * attn + 0.5 * identity
                attn = attn / (attn.sum(axis=-1, keepdims=True) + 1e-8)
            # Later layers multiply from the left: rollout_l = A_l @ rollout_{l-1}.
            rollout = attn @ rollout
        return rollout

    def compute_rollout(
        self,
        input_tensor: torch.Tensor,
        head_fusion: str = "mean",
        add_residual: bool = True,
    ) -> np.ndarray:
        """
        Compute attention rollout.

        Puts the model in eval mode, runs one forward pass under
        ``torch.no_grad()`` and chains the captured attention maps in the
        order the hooks fired (assumed to be layer order).

        Args:
            input_tensor: Model input
            head_fusion: How to fuse attention heads ('mean', 'max', 'min').
                Unknown values fall back to 'mean'.
            add_residual: When True, mix each layer's attention with the
                identity matrix (0.5 / 0.5) and re-normalize rows before
                chaining. Pass False to chain the raw fused matrices.

        Returns:
            Attention rollout array (seq_len,) showing frame importance,
            normalized to sum to (approximately) one. When the batch holds
            several samples only the first one is used. When no attention
            map was captured a uniform vector is returned instead.
        """
        self.attention_maps = []
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_tensor)

        if not self.attention_maps:
            # The hooks saw nothing; fall back to attentions reported by
            # the model's own output object, if it has any.
            self._store_attention(getattr(outputs, "attentions", None))

        if not self.attention_maps:
            logger.warning("No attention maps captured. Using uniform attention.")
            seq_len = input_tensor.shape[1] if input_tensor.ndim > 1 else 100
            return np.ones(seq_len) / seq_len

        if head_fusion not in self._FUSIONS:
            logger.warning(
                "Unknown head_fusion %r; falling back to 'mean'", head_fusion
            )
        if self.attention_maps[0].shape[0] > 1:
            logger.warning(
                "Batch of %d samples given; rollout uses the first sample only",
                self.attention_maps[0].shape[0],
            )

        # Fuse attention heads: one (seq_len, seq_len) matrix per layer.
        fused_attentions = [
            self._fuse_heads(attn, head_fusion) for attn in self.attention_maps
        ]

        # Multiply attention across layers (rollout).
        rollout = self._rollout(fused_attentions, add_residual)

        # Row i of the rollout is how output position i distributes over
        # the input positions; averaging over rows gives the attention each
        # input position receives.
        importance = rollout.mean(axis=0)  # (seq_len,)

        # Normalize
        importance = importance / (importance.sum() + 1e-8)
        logger.info("Attention rollout computed: %d layers", len(fused_attentions))
        return importance

    def get_top_k_frames(
        self,
        importance: np.ndarray,
        k: int = 5,
        hop_length: int = 512,
        sr: int = 16000,
    ) -> list[tuple[float, float, float]]:
        """
        Get top-k most important time frames.

        Args:
            importance: Frame importance scores (seq_len,)
            k: Number of top frames to return. Values larger than the
                number of frames are clamped; ``k <= 0`` yields ``[]``.
            hop_length: Hop length used in feature extraction, in samples.
                NOTE(review): confirm the default matches the frame stride
                of the model that produced *importance*.
            sr: Sample rate

        Returns:
            List of (start_sec, end_sec, importance_score) as built-in
            floats, ordered from most to least important.
        """
        scores = np.asarray(importance)
        k = min(int(k), len(scores))
        if k <= 0:
            # A bare ``[-0:]`` slice would select every frame.
            return []

        # Ascending argsort: the last k entries are the largest scores.
        top_indices = np.argsort(scores)[-k:][::-1]
        segments = []
        for idx in top_indices:
            frame = int(idx)
            start_sec = (frame * hop_length) / sr
            end_sec = ((frame + 1) * hop_length) / sr
            segments.append((float(start_sec), float(end_sec), float(scores[frame])))
        return segments