Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 11,368 Bytes
920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 |
"""
Real Attention Extraction for In-Context Learning Analysis
This module hooks into transformer models to extract actual attention weights
during generation, providing real data for ICL analysis.
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class AttentionData:
    """Container for attention data collected during model generation.

    Attributes:
        layer_attentions: Attention tensors captured from each transformer layer.
        token_positions: Position index of each generated token.
        example_boundaries: (start, end) token positions of each in-context example.
    """
    layer_attentions: List[torch.Tensor]
    token_positions: List[int]
    example_boundaries: List[Tuple[int, int]]
class AttentionExtractor:
    """Extracts real attention patterns from transformer models during generation."""

    def __init__(self, model, tokenizer, adapter=None):
        """
        Args:
            model: HuggingFace-style causal LM; must accept ``output_attentions=True``.
            tokenizer: Matching tokenizer (only ``eos_token_id`` is used here).
            adapter: Optional model adapter exposing ``get_num_layers()`` and
                ``get_attention_module(i)`` for multi-architecture support.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter  # Model adapter for multi-architecture support
        self.device = next(model.parameters()).device
        # Storage for attention captured during generation
        self.attention_weights = []
        self.handles = []

    def register_hooks(self):
        """Register forward hooks on every attention module to capture weights."""
        self.clear_hooks()
        if self.adapter:
            # Adapter path: works for any architecture the adapter describes.
            num_layers = self.adapter.get_num_layers()
            for i in range(num_layers):
                attn_module = self.adapter.get_attention_module(i)
                if attn_module:
                    handle = attn_module.register_forward_hook(
                        # layer_idx=i binds the loop variable at definition time,
                        # avoiding the classic late-binding closure bug.
                        lambda module, input, output, layer_idx=i:
                            self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)
        elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            # Fallback for GPT-2/CodeGen-style models without an adapter.
            for i, layer in enumerate(self.model.transformer.h):
                if hasattr(layer, 'attn'):
                    handle = layer.attn.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                            self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)
        logger.info(f"Registered {len(self.handles)} attention hooks")

    def _attention_hook(self, module, input, output, layer_idx):
        """Forward-hook callback: store attention weights when present in output."""
        # For CodeGen-style attention modules, output is a tuple whose second
        # element is the attention weights (may be None when not requested).
        if isinstance(output, tuple) and len(output) >= 2:
            attention = output[1]
            if attention is not None:
                self.attention_weights.append({
                    'layer': layer_idx,
                    'attention': attention.detach().cpu()
                })

    def clear_hooks(self):
        """Remove all registered hooks and reset captured attention storage."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        self.attention_weights = []

    def extract_attention_with_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.7
    ) -> Tuple[torch.Tensor, List[Dict], List[torch.Tensor]]:
        """Generate text token-by-token while extracting attention patterns.

        Args:
            input_ids: Prompt token ids, shape (1, seq_len) — batch size 1 assumed.
            attention_mask: Mask matching ``input_ids``.
            max_new_tokens: Upper bound on generated tokens.
            temperature: > 0 samples from softmax(logits/T); 0 is greedy argmax.

        Returns:
            (generated_ids tensor of shape (1, n), attention records, per-step logits)
        """
        self.register_hooks()
        self.attention_weights = []
        try:
            generated_ids = []
            all_scores = []  # Per-step logits for downstream confidence analysis
            current_input_ids = input_ids.clone()
            current_attention_mask = attention_mask.clone()
            for _ in range(max_new_tokens):
                # Remember how many records the hooks held before this forward
                # pass so we can tell whether the hooks captured this step.
                records_before = len(self.attention_weights)
                with torch.no_grad():
                    outputs = self.model(
                        input_ids=current_input_ids,
                        attention_mask=current_attention_mask,
                        use_cache=False,  # full (uncached) attention every step
                        output_attentions=True,
                        return_dict=True
                    )
                # Fall back to outputs.attentions ONLY when the hooks captured
                # nothing this step. Appending both would record every layer
                # twice per token and corrupt the per-token layer bookkeeping
                # in aggregate_attention_to_examples.
                if (len(self.attention_weights) == records_before
                        and getattr(outputs, 'attentions', None) is not None):
                    for layer_idx, attn in enumerate(outputs.attentions):
                        self.attention_weights.append({
                            'layer': layer_idx,
                            'attention': attn.detach().cpu()
                        })
                next_token_logits = outputs.logits[:, -1, :]
                all_scores.append(next_token_logits)
                if temperature > 0:
                    probs = F.softmax(next_token_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                # Stop on EOS (do not include it in the output)
                if next_token.item() == self.tokenizer.eos_token_id:
                    break
                generated_ids.append(next_token.item())
                current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    # Match the existing mask's dtype to avoid an implicit upcast
                    torch.ones((1, 1), device=self.device,
                               dtype=current_attention_mask.dtype)
                ], dim=1)
            if generated_ids:
                generated_tensor = torch.tensor(generated_ids, device=self.device).unsqueeze(0)
            else:
                # Empty (1, 0) long tensor when generation stopped immediately
                generated_tensor = torch.tensor([[]], device=self.device, dtype=torch.long)
            # NOTE: clear_hooks() in finally REBINDS self.attention_weights, so
            # the list object returned here keeps its contents.
            return generated_tensor, self.attention_weights, all_scores
        finally:
            # Always detach hooks, even if generation raised
            self.clear_hooks()

    def aggregate_attention_to_examples(
        self,
        attention_data: List[Dict],
        example_boundaries: List[Tuple[int, int]],
        prompt_length: int
    ) -> Dict[str, List[float]]:
        """
        Aggregate attention from generated tokens back to example regions.

        Args:
            attention_data: Records of {'layer': int, 'attention': tensor},
                appended layer-by-layer for each generated token in order.
            example_boundaries: (start, end) token spans of each example.
            prompt_length: Length of the prompt (kept for interface compatibility).

        Returns:
            Dict mapping example_id -> list of attention weights per generated
            token, normalized to sum to 1 across examples at each position.
        """
        if not attention_data or not example_boundaries:
            return {}
        attention_to_examples = {}
        # Infer the layer count from the records themselves rather than
        # hard-coding an architecture-specific constant (the old code assumed
        # 20 layers, which broke num_generated for any other model).
        num_layers = max(record.get('layer', 0) for record in attention_data) + 1
        num_generated = len(attention_data) // num_layers
        logger.info(f"Processing {len(attention_data)} attention records for {num_generated} generated tokens")
        for example_idx, (start, end) in enumerate(example_boundaries):
            example_id = str(example_idx + 1)
            example_attention = []
            for gen_idx in range(num_generated):
                total_attention = 0.0
                layer_count = 0
                # Records are appended one-per-layer for each generated token,
                # so this slice holds exactly the layers of step gen_idx
                # (direct slice instead of the old full scan per token).
                step_records = attention_data[gen_idx * num_layers:(gen_idx + 1) * num_layers]
                for attn_record in step_records:
                    attn_tensor = attn_record.get('attention')
                    # The indexing below needs [batch, heads, seq, seq]; a
                    # lower-rank tensor would raise, so require 4 dims.
                    if attn_tensor is None or attn_tensor.dim() != 4:
                        continue
                    seq_len = attn_tensor.shape[-1]
                    # Last query position corresponds to the newly generated
                    # token; average its attention to the example span over heads.
                    if end <= seq_len:
                        total_attention += attn_tensor[0, :, -1, start:end].mean().item()
                        layer_count += 1
                # Mean across layers; 0.0 when no usable record existed
                example_attention.append(total_attention / layer_count if layer_count else 0.0)
            attention_to_examples[example_id] = example_attention
        # Normalize across examples at each generated position
        for gen_idx in range(num_generated):
            total = sum(
                weights[gen_idx]
                for weights in attention_to_examples.values()
                if gen_idx < len(weights)
            )
            if total > 0:
                for weights in attention_to_examples.values():
                    if gen_idx < len(weights):
                        weights[gen_idx] /= total
        return attention_to_examples

    def calculate_example_influences(
        self,
        attention_to_examples: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Calculate overall influence of each example based on attention patterns.

        Returns:
            Dict mapping example_id -> influence score (0-1); scores are
            normalized to sum to 1 whenever any influence is non-zero.
        """
        influences = {}
        for example_id, attention_weights in attention_to_examples.items():
            # Overall influence is the mean attention across generated tokens;
            # an example with no recorded attention scores 0.0.
            if attention_weights:
                influences[example_id] = float(np.mean(attention_weights))
            else:
                influences[example_id] = 0.0
        total = sum(influences.values())
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}
        return influences