gary-boon committed
Commit 992dc8c · 1 Parent(s): c2f6135

Capture complete attention patterns after generation


- Modified attention capture to occur after all tokens are generated
- Added a tokens field to the TraceData model for token display
- Increased the sampling threshold from 20 to 100 tokens
- Capture all layers instead of sampling every Nth layer
- Include the full token list (prompt + generated) in attention traces
- This ensures complete attention matrices that match the token count (see the sketch below)
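
The approach, as a minimal standalone sketch (assuming a Hugging Face `transformers` causal LM; the model name and variable names below are illustrative, not taken from this repo): generate first, then run one forward pass over the finished sequence with `output_attentions=True`. Each layer then yields a `(batch, heads, seq_len, seq_len)` tensor whose `seq_len` covers prompt + generated tokens, so the head-averaged matrix lines up one-to-one with the decoded token list.

```python
# Illustrative sketch only - not the repo's code. Assumes transformers + a small causal LM.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=8, do_sample=False)

# One full-sequence pass after generation captures attention for every token at once.
with torch.no_grad():
    outputs = model(generated, output_attentions=True)

tokens = [tokenizer.decode([t], skip_special_tokens=False) for t in generated[0].tolist()]
for layer_attn in outputs.attentions:
    avg = layer_attn[0].mean(dim=0)      # (heads, seq, seq) -> (seq, seq)
    assert avg.shape[0] == len(tokens)   # matrix size matches the token list
```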

Files changed (1)
  1. backend/model_service.py +45 -19
backend/model_service.py CHANGED
```diff
@@ -75,6 +75,7 @@ class TraceData(BaseModel):
     type: str
     layer: Optional[str] = None
     weights: Optional[List[List[float]]] = None
+    tokens: Optional[List[str]] = None  # Add tokens field
     max_weight: Optional[float] = None
     entropy: Optional[float] = None
     mean: Optional[float] = None
@@ -128,7 +129,7 @@ class ModelManager:
             logger.error(f"Failed to load model: {e}")
             raise
 
-    def extract_attention_trace(self, layer_idx: int, attention_weights) -> TraceData:
+    def extract_attention_trace(self, layer_idx: int, attention_weights, tokens: Optional[List[str]] = None) -> TraceData:
         """Extract attention pattern trace from a layer"""
         # attention_weights is a tuple of tensors, one for each layer
         # Each tensor has shape (batch_size, num_heads, seq_len, seq_len)
@@ -138,10 +139,13 @@
         # Shape: (batch_size, num_heads, seq_len, seq_len) -> (seq_len, seq_len)
         avg_attention = layer_attention[0].mean(dim=0).detach().cpu().numpy()
 
-        # Sample the weights for efficiency
-        if avg_attention.shape[0] > 20:
-            indices = np.random.choice(avg_attention.shape[0], 20, replace=False)
+        # Don't sample if we have complete attention - we want the full matrix
+        # Only sample if the matrix is very large (>100x100)
+        if avg_attention.shape[0] > 100:
+            indices = np.random.choice(avg_attention.shape[0], 100, replace=False)
             avg_attention = avg_attention[indices][:, indices]
+            if tokens:
+                tokens = [tokens[i] for i in indices]
 
         # Ensure values are finite
         avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)
@@ -163,6 +167,7 @@
             type="attention",
             layer=f"layer.{layer_idx}",
             weights=avg_attention.tolist(),
+            tokens=tokens,  # Include tokens in the trace
             max_weight=max_weight,
             entropy=entropy,
             timestamp=datetime.now().timestamp()
@@ -521,20 +526,9 @@
                 output_hidden_states=True
             )
 
-            # Sample traces based on sampling rate
-            if np.random.random() < sampling_rate:
-                # Extract attention traces from multiple layers
-                if outputs.attentions and len(outputs.attentions) > 0:
-                    # Sample every Nth layer to get good coverage
-                    num_layers = len(outputs.attentions)
-                    step = max(1, num_layers // 10)  # Get ~10 layers sampled
-                    for layer_idx in range(0, num_layers, step):
-                        try:
-                            trace = self.extract_attention_trace(layer_idx, outputs.attentions)
-                            traces.append(trace)
-                            await self.broadcast_trace(trace)
-                        except Exception as e:
-                            logger.warning(f"Failed to extract attention trace from layer {layer_idx}: {e}")
+            # Skip mid-generation attention capture - we'll capture complete attention at the end
+            # This ensures we get the full attention matrix for all generated tokens
+            pass  # Removed mid-generation attention capture
 
             # Extract activation traces periodically (not every token to avoid overflow)
             if outputs.hidden_states and len(outputs.hidden_states) > 0 and np.random.random() < 0.3:
@@ -633,8 +627,40 @@
             if next_token.item() == self.tokenizer.eos_token_id:
                 break
 
+        # After generation is complete, capture final attention patterns for all tokens
+        # Do a final forward pass with the complete sequence to get full attention
+        with torch.no_grad():
+            final_outputs = self.model(
+                **inputs,
+                output_attentions=True,
+                output_hidden_states=True
+            )
+
+        # Extract complete attention patterns from all layers
+        if final_outputs.attentions and len(final_outputs.attentions) > 0:
+            num_layers = len(final_outputs.attentions)
+
+            # Clear previous partial traces and add complete ones
+            traces = []  # Reset traces to only include complete attention patterns
+
+            # Capture ALL layers for complete visualization
+            for layer_idx in range(num_layers):
+                try:
+                    # Get all token IDs (prompt + generated)
+                    all_token_ids = inputs["input_ids"][0].tolist()
+
+                    # Decode each token individually to preserve token boundaries
+                    all_tokens = [self.tokenizer.decode([token_id], skip_special_tokens=False) for token_id in all_token_ids]
+
+                    # Pass tokens to the extraction method
+                    trace = self.extract_attention_trace(layer_idx, final_outputs.attentions, all_tokens)
+                    traces.append(trace)
+                    await self.broadcast_trace(trace)
+                except Exception as e:
+                    logger.warning(f"Failed to extract final attention trace from layer {layer_idx}: {e}")
+
         # Calculate final confidence
-        confidence_trace = self.calculate_confidence(logits)
+        confidence_trace = self.calculate_confidence(final_outputs.logits)
         traces.append(confidence_trace)
         await self.broadcast_trace(confidence_trace)
 
```
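
On the `TraceData` change: because the new field is `Optional` with a `None` default, existing trace payloads without `tokens` still validate. A minimal pydantic sketch of the updated model (reconstructed from the hunks above; only the fields shown in the diff):

```python
from typing import List, Optional
from pydantic import BaseModel

class TraceData(BaseModel):
    type: str
    layer: Optional[str] = None
    weights: Optional[List[List[float]]] = None
    tokens: Optional[List[str]] = None  # new field; defaults to None, so old payloads still parse
    max_weight: Optional[float] = None
    entropy: Optional[float] = None
    mean: Optional[float] = None
```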
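
On the per-token decode in the final hunk: decoding the whole id list at once returns a single string and loses token boundaries, while decoding each id separately yields one label per attention row/column. A quick illustration (the split shown is hypothetical; the actual pieces depend on the tokenizer):

```python
ids = tokenizer("unbelievable", add_special_tokens=False)["input_ids"]
print(tokenizer.decode(ids))                 # "unbelievable" - boundaries lost
print([tokenizer.decode([i]) for i in ids])  # e.g. ["un", "believ", "able"] - one label per matrix row
```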
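
And on the >100-token sampling path: rows, columns, and token labels are subset with the same index array, which keeps them aligned. One observation (not part of the commit itself): `np.random.choice` returns indices in random order, so sorting them first would preserve the sequence's reading order in the downsampled matrix:

```python
import numpy as np

seq_len = 150
attn = np.random.rand(seq_len, seq_len)       # stand-in attention matrix
tokens = [f"tok{i}" for i in range(seq_len)]  # placeholder labels

indices = np.sort(np.random.choice(seq_len, 100, replace=False))  # sorted keeps temporal order
attn_small = attn[indices][:, indices]        # subset rows, then columns
tokens_small = [tokens[i] for i in indices]   # same indices keep labels aligned
assert attn_small.shape == (100, 100) and len(tokens_small) == 100
```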