# api/backend/icl_attention_extractor.py
"""
Real Attention Extraction for In-Context Learning Analysis
This module hooks into transformer models to extract actual attention weights
during generation, providing real data for ICL analysis.
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
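
# Typical flow (sketch): an AttentionExtractor wraps a loaded model/tokenizer;
# extract_attention_with_generation() registers forward hooks and decodes step
# by step while recording per-layer attention, then
# aggregate_attention_to_examples() / calculate_example_influences() map that
# attention back onto the in-context examples in the prompt.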


@dataclass
class AttentionData:
    """Stores attention data from model generation"""
    layer_attentions: List[torch.Tensor]        # Attention from each layer
    token_positions: List[int]                  # Position of each generated token
    example_boundaries: List[Tuple[int, int]]   # Start/end positions of examples


class AttentionExtractor:
    """Extracts real attention patterns from transformer models during generation"""

    def __init__(self, model, tokenizer, adapter=None):
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter  # Model adapter for multi-architecture support
        self.device = next(model.parameters()).device
        # Storage for attention during generation
        self.attention_weights = []
        self.handles = []

    def register_hooks(self):
        """Register forward hooks to capture attention weights"""
        self.clear_hooks()

        # Use adapter if available for multi-architecture support
        if self.adapter:
            num_layers = self.adapter.get_num_layers()
            for i in range(num_layers):
                attn_module = self.adapter.get_attention_module(i)
                if attn_module:
                    handle = attn_module.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                            self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)
        # Fallback for CodeGen models without adapter
        elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            # Hook into each transformer layer
            for i, layer in enumerate(self.model.transformer.h):
                if hasattr(layer, 'attn'):
                    handle = layer.attn.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                            self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)

        logger.info(f"Registered {len(self.handles)} attention hooks")

    def _attention_hook(self, module, input, output, layer_idx):
        """Hook function to capture attention weights"""
        # Attention modules generally return a tuple; for CodeGen-style layers the
        # attention weights (shape [batch, heads, seq_len, seq_len]) are only
        # included when the forward pass runs with output_attentions=True.
        if isinstance(output, tuple) and len(output) >= 2:
            attention = output[1]
            # Guard against None and against non-attention entries (e.g. KV cache)
            if isinstance(attention, torch.Tensor) and attention.dim() == 4:
                # Store attention weights on CPU to avoid holding GPU memory
                self.attention_weights.append({
                    'layer': layer_idx,
                    'attention': attention.detach().cpu()
                })

    def clear_hooks(self):
        """Remove all hooks"""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        self.attention_weights = []

    def extract_attention_with_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.7
    ) -> Tuple[torch.Tensor, List[Dict], List[torch.Tensor]]:
        """Generate text while extracting attention patterns"""
        # Register hooks before generation
        self.register_hooks()
        self.attention_weights = []

        try:
            # Generate token by token to capture attention at each step
            generated_ids = []
            all_scores = []  # Store scores for confidence calculation
            current_input_ids = input_ids.clone()
            current_attention_mask = attention_mask.clone()

            for _ in range(max_new_tokens):
                hooks_captured_before = len(self.attention_weights)
                with torch.no_grad():
                    # Forward pass through the model
                    outputs = self.model(
                        input_ids=current_input_ids,
                        attention_mask=current_attention_mask,
                        use_cache=False,  # Don't use the KV cache so we get full attention
                        output_attentions=True,
                        return_dict=True
                    )

                # Fall back to outputs.attentions only if the hooks captured nothing
                # for this step; otherwise the same layers would be recorded twice
                if len(self.attention_weights) == hooks_captured_before and \
                        getattr(outputs, 'attentions', None) is not None:
                    for layer_idx, attn in enumerate(outputs.attentions):
                        self.attention_weights.append({
                            'layer': layer_idx,
                            'attention': attn.detach().cpu()
                        })

                # Get next-token logits and store them for confidence calculation
                next_token_logits = outputs.logits[:, -1, :]
                all_scores.append(next_token_logits)

                # Sample with temperature, or take the argmax for greedy decoding
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Stop at the EOS token
                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                # Append the new token and extend the attention mask
                generated_ids.append(next_token.item())
                current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones((1, 1), device=self.device, dtype=current_attention_mask.dtype)
                ], dim=1)

            # Convert generated token ids to a tensor
            if generated_ids:
                generated_tensor = torch.tensor(generated_ids, device=self.device).unsqueeze(0)
            else:
                generated_tensor = torch.tensor([[]], device=self.device, dtype=torch.long)

            return generated_tensor, self.attention_weights, all_scores
        finally:
            # Always clear hooks after generation
            self.clear_hooks()

    def aggregate_attention_to_examples(
        self,
        attention_data: List[Dict],
        example_boundaries: List[Tuple[int, int]],
        prompt_length: int
    ) -> Dict[str, List[float]]:
        """
        Aggregate attention from generated tokens back to example regions

        Returns:
            Dict mapping example_id -> list of attention weights per generated token
        """
        if not attention_data or not example_boundaries:
            return {}

        attention_to_examples = {}

        # There is one attention record per layer per generated token, so derive
        # the layer count from the adapter/model config rather than hardcoding it
        # (CodeGen-350M has 20 layers, but e.g. Code Llama 7B has more)
        if self.adapter:
            num_layers = self.adapter.get_num_layers()
        else:
            num_layers = getattr(self.model.config, 'num_hidden_layers', 20)
        num_generated = len(attention_data) // num_layers if num_layers else 0
        logger.info(f"Processing {len(attention_data)} attention records for {num_generated} generated tokens")

        for example_idx, (start, end) in enumerate(example_boundaries):
            example_id = str(example_idx + 1)
            example_attention = []

            # For each generated token, aggregate attention across all layers
            for gen_idx in range(num_generated):
                total_attention = 0.0
                layer_count = 0

                # Records [gen_idx * num_layers : (gen_idx + 1) * num_layers] belong
                # to this generated position (one record per layer)
                for attn_record in attention_data[gen_idx * num_layers:(gen_idx + 1) * num_layers]:
                    if 'attention' in attn_record:
                        attn_tensor = attn_record['attention']
                        # Shape: [batch, heads, seq_len, seq_len]; the last query
                        # position corresponds to the newly generated token
                        if attn_tensor.dim() >= 3:
                            seq_len = attn_tensor.shape[-1]
                            if end <= seq_len:
                                # Average across heads: attention from the last
                                # position to the example region
                                attn_to_example = attn_tensor[0, :, -1, start:end].mean().item()
                                total_attention += attn_to_example
                                layer_count += 1

                # Average across layers
                if layer_count > 0:
                    example_attention.append(total_attention / layer_count)
                else:
                    example_attention.append(0.0)

            attention_to_examples[example_id] = example_attention

        # Normalize attention across examples for each generated token
        for gen_idx in range(num_generated):
            total = sum(
                attention_to_examples[ex_id][gen_idx]
                for ex_id in attention_to_examples
                if gen_idx < len(attention_to_examples[ex_id])
            )
            if total > 0:
                for ex_id in attention_to_examples:
                    if gen_idx < len(attention_to_examples[ex_id]):
                        attention_to_examples[ex_id][gen_idx] /= total

        return attention_to_examples
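
    # Illustrative (hypothetical) return value of aggregate_attention_to_examples
    # for two examples and three generated tokens:
    #   {"1": [0.62, 0.55, 0.71], "2": [0.38, 0.45, 0.29]}
    # After the per-token normalization above, the weights at each generated
    # position sum to ~1.0 across examples.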

    def calculate_example_influences(
        self,
        attention_to_examples: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Calculate overall influence of each example based on attention patterns

        Returns:
            Dict mapping example_id -> influence score (0-1)
        """
        influences = {}
        for example_id, attention_weights in attention_to_examples.items():
            # Overall influence is the mean attention across all generated tokens
            if attention_weights:
                influences[example_id] = float(np.mean(attention_weights))
            else:
                influences[example_id] = 0.0

        # Normalize so influences sum to 1
        total = sum(influences.values())
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}

        return influences
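

if __name__ == "__main__":
    # Minimal usage sketch (not part of the production API): load a small
    # causal-LM checkpoint and run the extraction pipeline end to end. The
    # checkpoint name, the prompt, and the example boundaries below are
    # illustrative assumptions, not values used elsewhere in this repo.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Salesforce/codegen-350M-mono"  # assumed checkpoint for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Two few-shot examples followed by a query
    prompt = (
        "# add(2, 3) -> 5\n"
        "# add(10, 4) -> 14\n"
        "# add(7, 8) ->"
    )
    inputs = tokenizer(prompt, return_tensors="pt")

    extractor = AttentionExtractor(model, tokenizer)
    generated, attention_records, _scores = extractor.extract_attention_with_generation(
        inputs["input_ids"],
        inputs["attention_mask"],
        max_new_tokens=8,
        temperature=0.0,  # greedy decoding for a deterministic sketch
    )

    # Example boundaries would normally come from the prompt builder; these are
    # rough token spans for the two in-context examples above.
    example_boundaries = [(0, 10), (10, 20)]
    per_token = extractor.aggregate_attention_to_examples(
        attention_records, example_boundaries, prompt_length=inputs["input_ids"].shape[1]
    )
    influences = extractor.calculate_example_influences(per_token)
    print("Generated:", tokenizer.decode(generated[0]))
    print("Example influences:", influences)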