gary-boon committed · Commit 992dc8c · 1 Parent(s): c2f6135
Capture complete attention patterns after generation

- Modified attention capture to occur after all tokens are generated
- Added tokens field to TraceData model for token display
- Increased sampling threshold from 20 to 100 tokens
- Capture all layers instead of sampling every Nth layer
- Include full token list (prompt + generated) in attention traces
- This ensures complete attention matrices that match token count

Files changed: backend/model_service.py (+45, -19)
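For orientation, here is a minimal standalone sketch of the flow this commit adopts: generate first, then run one extra forward pass over the complete sequence with output_attentions=True so the attention matrices match the final token count. The model name and variable names below are illustrative assumptions, not taken from the Space's code.

# Sketch only: any small Hugging Face causal LM works; "sshleifer/tiny-gpt2" is an arbitrary choice.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

prompt = tokenizer("The quick brown fox", return_tensors="pt")

# 1) Generate normally, without capturing attention during decoding.
generated = model.generate(**prompt, max_new_tokens=8, do_sample=False)

# 2) One forward pass over the *complete* sequence (prompt + generated tokens)
#    to get full attention matrices for every layer.
with torch.no_grad():
    outputs = model(generated, output_attentions=True)

# outputs.attentions is a tuple with one (batch, heads, seq_len, seq_len) tensor per layer.
seq_len = generated.shape[1]
assert all(a.shape[-1] == seq_len for a in outputs.attentions)

# Decoding token by token keeps token boundaries aligned with the attention axes.
tokens = [tokenizer.decode([tid], skip_special_tokens=False) for tid in generated[0].tolist()]
print(f"{len(tokens)} tokens, {len(outputs.attentions)} layers captured")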
backend/model_service.py
CHANGED
@@ -75,6 +75,7 @@ class TraceData(BaseModel):
     type: str
     layer: Optional[str] = None
     weights: Optional[List[List[float]]] = None
+    tokens: Optional[List[str]] = None  # Add tokens field
     max_weight: Optional[float] = None
     entropy: Optional[float] = None
     mean: Optional[float] = None
@@ -128,7 +129,7 @@ class ModelManager:
             logger.error(f"Failed to load model: {e}")
             raise
 
-    def extract_attention_trace(self, layer_idx: int, attention_weights) -> TraceData:
+    def extract_attention_trace(self, layer_idx: int, attention_weights, tokens: Optional[List[str]] = None) -> TraceData:
         """Extract attention pattern trace from a layer"""
         # attention_weights is a tuple of tensors, one for each layer
         # Each tensor has shape (batch_size, num_heads, seq_len, seq_len)
@@ -138,10 +139,13 @@
         # Shape: (batch_size, num_heads, seq_len, seq_len) -> (seq_len, seq_len)
         avg_attention = layer_attention[0].mean(dim=0).detach().cpu().numpy()
 
-        #
-        if
-
+        # Don't sample if we have complete attention - we want the full matrix
+        # Only sample if the matrix is very large (>100x100)
+        if avg_attention.shape[0] > 100:
+            indices = np.random.choice(avg_attention.shape[0], 100, replace=False)
             avg_attention = avg_attention[indices][:, indices]
+            if tokens:
+                tokens = [tokens[i] for i in indices]
 
         # Ensure values are finite
         avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)
@@ -163,6 +167,7 @@
             type="attention",
             layer=f"layer.{layer_idx}",
             weights=avg_attention.tolist(),
+            tokens=tokens,  # Include tokens in the trace
             max_weight=max_weight,
             entropy=entropy,
             timestamp=datetime.now().timestamp()
@@ -521,20 +526,9 @@
                 output_hidden_states=True
             )
 
-            #
-
-
-            if outputs.attentions and len(outputs.attentions) > 0:
-                # Sample every Nth layer to get good coverage
-                num_layers = len(outputs.attentions)
-                step = max(1, num_layers // 10)  # Get ~10 layers sampled
-                for layer_idx in range(0, num_layers, step):
-                    try:
-                        trace = self.extract_attention_trace(layer_idx, outputs.attentions)
-                        traces.append(trace)
-                        await self.broadcast_trace(trace)
-                    except Exception as e:
-                        logger.warning(f"Failed to extract attention trace from layer {layer_idx}: {e}")
+            # Skip mid-generation attention capture - we'll capture complete attention at the end
+            # This ensures we get the full attention matrix for all generated tokens
+            pass  # Removed mid-generation attention capture
 
             # Extract activation traces periodically (not every token to avoid overflow)
             if outputs.hidden_states and len(outputs.hidden_states) > 0 and np.random.random() < 0.3:
@@ -633,8 +627,40 @@
             if next_token.item() == self.tokenizer.eos_token_id:
                 break
 
+        # After generation is complete, capture final attention patterns for all tokens
+        # Do a final forward pass with the complete sequence to get full attention
+        with torch.no_grad():
+            final_outputs = self.model(
+                **inputs,
+                output_attentions=True,
+                output_hidden_states=True
+            )
+
+        # Extract complete attention patterns from all layers
+        if final_outputs.attentions and len(final_outputs.attentions) > 0:
+            num_layers = len(final_outputs.attentions)
+
+            # Clear previous partial traces and add complete ones
+            traces = []  # Reset traces to only include complete attention patterns
+
+            # Capture ALL layers for complete visualization
+            for layer_idx in range(num_layers):
+                try:
+                    # Get all token IDs (prompt + generated)
+                    all_token_ids = inputs["input_ids"][0].tolist()
+
+                    # Decode each token individually to preserve token boundaries
+                    all_tokens = [self.tokenizer.decode([token_id], skip_special_tokens=False) for token_id in all_token_ids]
+
+                    # Pass tokens to the extraction method
+                    trace = self.extract_attention_trace(layer_idx, final_outputs.attentions, all_tokens)
+                    traces.append(trace)
+                    await self.broadcast_trace(trace)
+                except Exception as e:
+                    logger.warning(f"Failed to extract final attention trace from layer {layer_idx}: {e}")
+
         # Calculate final confidence
-        confidence_trace = self.calculate_confidence(logits)
+        confidence_trace = self.calculate_confidence(final_outputs.logits)
         traces.append(confidence_trace)
         await self.broadcast_trace(confidence_trace)
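As a side note on the new size guard in extract_attention_trace, here is a self-contained sketch of the >100-token downsampling on synthetic data; the array sizes and the index sort are illustrative choices made for this sketch, not part of the diff.

import numpy as np

# Synthetic attention matrix for a 150-token sequence (head dimension already averaged out).
seq_len = 150
avg_attention = np.random.rand(seq_len, seq_len)
tokens = [f"tok{i}" for i in range(seq_len)]  # placeholder token strings

max_positions = 100  # threshold introduced by this commit
if avg_attention.shape[0] > max_positions:
    # Sample a subset of positions; sorting keeps the surviving tokens in sequence order
    # (the committed code does not sort - that is an extra choice made for this sketch).
    indices = np.sort(np.random.choice(avg_attention.shape[0], max_positions, replace=False))
    avg_attention = avg_attention[indices][:, indices]
    tokens = [tokens[i] for i in indices]

print(avg_attention.shape, len(tokens))  # (100, 100) 100 -> rows/columns stay aligned with tokens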
|
|
| 75 |
type: str
|
| 76 |
layer: Optional[str] = None
|
| 77 |
weights: Optional[List[List[float]]] = None
|
| 78 |
+
tokens: Optional[List[str]] = None # Add tokens field
|
| 79 |
max_weight: Optional[float] = None
|
| 80 |
entropy: Optional[float] = None
|
| 81 |
mean: Optional[float] = None
|
|
|
|
| 129 |
logger.error(f"Failed to load model: {e}")
|
| 130 |
raise
|
| 131 |
|
| 132 |
+
def extract_attention_trace(self, layer_idx: int, attention_weights, tokens: Optional[List[str]] = None) -> TraceData:
|
| 133 |
"""Extract attention pattern trace from a layer"""
|
| 134 |
# attention_weights is a tuple of tensors, one for each layer
|
| 135 |
# Each tensor has shape (batch_size, num_heads, seq_len, seq_len)
|
|
|
|
| 139 |
# Shape: (batch_size, num_heads, seq_len, seq_len) -> (seq_len, seq_len)
|
| 140 |
avg_attention = layer_attention[0].mean(dim=0).detach().cpu().numpy()
|
| 141 |
|
| 142 |
+
# Don't sample if we have complete attention - we want the full matrix
|
| 143 |
+
# Only sample if the matrix is very large (>100x100)
|
| 144 |
+
if avg_attention.shape[0] > 100:
|
| 145 |
+
indices = np.random.choice(avg_attention.shape[0], 100, replace=False)
|
| 146 |
avg_attention = avg_attention[indices][:, indices]
|
| 147 |
+
if tokens:
|
| 148 |
+
tokens = [tokens[i] for i in indices]
|
| 149 |
|
| 150 |
# Ensure values are finite
|
| 151 |
avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)
|
|
|
|
| 167 |
type="attention",
|
| 168 |
layer=f"layer.{layer_idx}",
|
| 169 |
weights=avg_attention.tolist(),
|
| 170 |
+
tokens=tokens, # Include tokens in the trace
|
| 171 |
max_weight=max_weight,
|
| 172 |
entropy=entropy,
|
| 173 |
timestamp=datetime.now().timestamp()
|
|
|
|
| 526 |
output_hidden_states=True
|
| 527 |
)
|
| 528 |
|
| 529 |
+
# Skip mid-generation attention capture - we'll capture complete attention at the end
|
| 530 |
+
# This ensures we get the full attention matrix for all generated tokens
|
| 531 |
+
pass # Removed mid-generation attention capture
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
# Extract activation traces periodically (not every token to avoid overflow)
|
| 534 |
if outputs.hidden_states and len(outputs.hidden_states) > 0 and np.random.random() < 0.3:
|
|
|
|
| 627 |
if next_token.item() == self.tokenizer.eos_token_id:
|
| 628 |
break
|
| 629 |
|
| 630 |
+
# After generation is complete, capture final attention patterns for all tokens
|
| 631 |
+
# Do a final forward pass with the complete sequence to get full attention
|
| 632 |
+
with torch.no_grad():
|
| 633 |
+
final_outputs = self.model(
|
| 634 |
+
**inputs,
|
| 635 |
+
output_attentions=True,
|
| 636 |
+
output_hidden_states=True
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
# Extract complete attention patterns from all layers
|
| 640 |
+
if final_outputs.attentions and len(final_outputs.attentions) > 0:
|
| 641 |
+
num_layers = len(final_outputs.attentions)
|
| 642 |
+
|
| 643 |
+
# Clear previous partial traces and add complete ones
|
| 644 |
+
traces = [] # Reset traces to only include complete attention patterns
|
| 645 |
+
|
| 646 |
+
# Capture ALL layers for complete visualization
|
| 647 |
+
for layer_idx in range(num_layers):
|
| 648 |
+
try:
|
| 649 |
+
# Get all token IDs (prompt + generated)
|
| 650 |
+
all_token_ids = inputs["input_ids"][0].tolist()
|
| 651 |
+
|
| 652 |
+
# Decode each token individually to preserve token boundaries
|
| 653 |
+
all_tokens = [self.tokenizer.decode([token_id], skip_special_tokens=False) for token_id in all_token_ids]
|
| 654 |
+
|
| 655 |
+
# Pass tokens to the extraction method
|
| 656 |
+
trace = self.extract_attention_trace(layer_idx, final_outputs.attentions, all_tokens)
|
| 657 |
+
traces.append(trace)
|
| 658 |
+
await self.broadcast_trace(trace)
|
| 659 |
+
except Exception as e:
|
| 660 |
+
logger.warning(f"Failed to extract final attention trace from layer {layer_idx}: {e}")
|
| 661 |
+
|
| 662 |
# Calculate final confidence
|
| 663 |
+
confidence_trace = self.calculate_confidence(final_outputs.logits)
|
| 664 |
traces.append(confidence_trace)
|
| 665 |
await self.broadcast_trace(confidence_trace)
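Finally, a sketch of how a client might line up the new tokens field with the weights matrix in a received trace; the field names mirror the TraceData model above, while the values are invented for illustration.

# A trace roughly as it might arrive over the broadcast channel (values invented).
trace = {
    "type": "attention",
    "layer": "layer.0",
    "tokens": ["The", " quick", " brown", " fox"],
    "weights": [
        [1.0, 0.0, 0.0, 0.0],
        [0.6, 0.4, 0.0, 0.0],
        [0.3, 0.3, 0.4, 0.0],
        [0.2, 0.2, 0.3, 0.3],
    ],
}

# Tokens and weights come from the same final forward pass, so row i and
# column j of weights correspond to tokens[i] and tokens[j].
for i, row in enumerate(trace["weights"]):
    top = max(range(len(row)), key=row.__getitem__)
    print(f"{trace['tokens'][i]!r} attends most to {trace['tokens'][top]!r} ({row[top]:.2f})")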