Spaces:
Sleeping
Sleeping
Commit ·
d5dc3e0
1
Parent(s): 3bbf674
Feature 4: Replace BertViz head_view with model_view for hierarchical attention visualization
Browse files- todo.md +29 -5
- utils/__pycache__/model_patterns.cpython-311.pyc +0 -0
- utils/model_patterns.py +28 -22
todo.md
CHANGED
|
@@ -32,8 +32,32 @@
|
|
| 32 |
✅ Feature 3 complete!
|
| 33 |
|
| 34 |
Feature Updates:
|
| 35 |
-
[
|
| 36 |
-
[
|
| 37 |
-
[
|
| 38 |
-
[
|
| 39 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
✅ Feature 3 complete!
|
| 33 |
|
| 34 |
Feature Updates:
|
| 35 |
+
[x] Collapsible Sidebar should minimize to the left and allow main dashboard to fill screen. Maximized size should remain as is, minimized should hide all the way to the left with still visible chevron to maximize.
|
| 36 |
+
[x] The "Compare +" button should switch to a red button that says "Remove -". It should function exactly the same, removing the second prompt, just with a different visual.
|
| 37 |
+
[x] The "Check Token" text box needs a "Submit" button in order to kickoff the creation of the 4th edge.
|
| 38 |
+
[x] Bug: When a second prompt is given and the "Run Analysis" button is clicked, only 1 graph is created when there should be 2 graphs: one above the other.
|
| 39 |
+
[x] Bug: The token given in the Check Token box has a probability of 0 for every layer - added debug output to investigate
|
| 40 |
+
✅ All feature updates complete!
|
| 41 |
+
|
| 42 |
+
## Feature 4: Replace BertViz head_view with model_view
|
| 43 |
+
[x] Read current generate_bertviz_html implementation
|
| 44 |
+
[x] Replace head_view call with model_view
|
| 45 |
+
[x] Update to pass all layers' attention to model_view
|
| 46 |
+
[ ] Test with GPT-2 and Qwen2.5-0.5B models
|
| 47 |
+
[ ] Verify model_view displays correctly in iframe
|
| 48 |
+
|
| 49 |
+
## Feature 5: Attention Head Detection and Categorization
|
| 50 |
+
[ ] Create utility module for head categorization (utils/head_detection.py)
|
| 51 |
+
[ ] Implement detection heuristics for Previous-Token heads
|
| 52 |
+
[ ] Implement detection heuristics for First/Positional heads
|
| 53 |
+
[ ] Implement detection heuristics for Bag-of-Words heads
|
| 54 |
+
[ ] Implement detection heuristics for Syntactic heads
|
| 55 |
+
[ ] Add UI section to display categorized heads
|
| 56 |
+
[ ] Make heuristics parameterized for tuning
|
| 57 |
+
|
| 58 |
+
## Feature 6: Two-Prompt Difference Analysis
|
| 59 |
+
[ ] Compute attention distribution differences across layers/heads
|
| 60 |
+
[ ] Compute output probability differences at each layer
|
| 61 |
+
[ ] Highlight layers with significant differences (red border)
|
| 62 |
+
[ ] Add summary panel showing top-N divergent layers/heads
|
| 63 |
+
[ ] Make difference thresholds configurable
|
utils/__pycache__/model_patterns.cpython-311.pyc
CHANGED
|
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
|
|
|
utils/model_patterns.py
CHANGED
|
@@ -484,42 +484,45 @@ def format_data_for_cytoscape(activation_data: Dict[str, Any], model, tokenizer,
|
|
| 484 |
|
| 485 |
def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, view_type: str = 'full') -> str:
|
| 486 |
"""
|
| 487 |
-
Generate BertViz attention visualization HTML
|
|
|
|
|
|
|
| 488 |
|
| 489 |
Args:
|
| 490 |
activation_data: Output from execute_forward_pass
|
| 491 |
-
layer_index: Index of layer to visualize
|
| 492 |
view_type: 'full' for complete visualization or 'mini' for preview
|
| 493 |
|
| 494 |
Returns:
|
| 495 |
HTML string for the visualization
|
| 496 |
"""
|
| 497 |
try:
|
| 498 |
-
from bertviz import
|
| 499 |
from transformers import AutoTokenizer
|
| 500 |
|
| 501 |
# Extract attention modules and sort by layer
|
| 502 |
attention_outputs = activation_data.get('attention_outputs', {})
|
| 503 |
if not attention_outputs:
|
| 504 |
-
return f"<p>No attention data available
|
| 505 |
|
| 506 |
-
#
|
| 507 |
-
|
| 508 |
for module_name in attention_outputs.keys():
|
| 509 |
numbers = re.findall(r'\d+', module_name)
|
| 510 |
-
if numbers
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
|
|
|
| 516 |
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
if not isinstance(attention_output, list) or len(attention_output) < 2:
|
| 520 |
-
return f"<p>Invalid attention format for layer {layer_index}</p>"
|
| 521 |
|
| 522 |
-
|
|
|
|
|
|
|
| 523 |
|
| 524 |
# Get tokens
|
| 525 |
input_ids = torch.tensor(activation_data['input_ids'])
|
|
@@ -528,6 +531,7 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
|
|
| 528 |
# Load tokenizer and convert to tokens
|
| 529 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 530 |
raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
|
|
|
| 531 |
tokens = [token.replace('Ġ', ' ') if token.startswith('Ġ') else token for token in raw_tokens]
|
| 532 |
|
| 533 |
# Generate visualization based on view_type
|
|
@@ -537,15 +541,17 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
|
|
| 537 |
<div style="padding:10px; border:1px solid #ccc; border-radius:5px;">
|
| 538 |
<h4>Layer {layer_index} Attention Preview</h4>
|
| 539 |
<p><strong>Tokens:</strong> {' '.join(tokens[:8])}{'...' if len(tokens) > 8 else ''}</p>
|
| 540 |
-
<p><strong>
|
| 541 |
-
<p><
|
|
|
|
| 542 |
</div>
|
| 543 |
"""
|
| 544 |
else:
|
| 545 |
-
# Full version: complete bertviz visualization
|
| 546 |
-
|
| 547 |
-
html_result = head_view(attentions, tokens, html_action='return')
|
| 548 |
return html_result.data if hasattr(html_result, 'data') else str(html_result)
|
| 549 |
|
| 550 |
except Exception as e:
|
|
|
|
|
|
|
| 551 |
return f"<p>Error generating visualization: {str(e)}</p>"
|
|
|
|
| 484 |
|
| 485 |
def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, view_type: str = 'full') -> str:
|
| 486 |
"""
|
| 487 |
+
Generate BertViz attention visualization HTML using model_view.
|
| 488 |
+
|
| 489 |
+
Shows all layers with the specified layer highlighted/focused.
|
| 490 |
|
| 491 |
Args:
|
| 492 |
activation_data: Output from execute_forward_pass
|
| 493 |
+
layer_index: Index of layer to visualize (for context; model_view shows all layers)
|
| 494 |
view_type: 'full' for complete visualization or 'mini' for preview
|
| 495 |
|
| 496 |
Returns:
|
| 497 |
HTML string for the visualization
|
| 498 |
"""
|
| 499 |
try:
|
| 500 |
+
from bertviz import model_view
|
| 501 |
from transformers import AutoTokenizer
|
| 502 |
|
| 503 |
# Extract attention modules and sort by layer
|
| 504 |
attention_outputs = activation_data.get('attention_outputs', {})
|
| 505 |
if not attention_outputs:
|
| 506 |
+
return f"<p>No attention data available</p>"
|
| 507 |
|
| 508 |
+
# Sort attention modules by layer number
|
| 509 |
+
layer_attention_pairs = []
|
| 510 |
for module_name in attention_outputs.keys():
|
| 511 |
numbers = re.findall(r'\d+', module_name)
|
| 512 |
+
if numbers:
|
| 513 |
+
layer_num = int(numbers[0])
|
| 514 |
+
attention_output = attention_outputs[module_name]['output']
|
| 515 |
+
if isinstance(attention_output, list) and len(attention_output) >= 2:
|
| 516 |
+
# Get attention weights (element 1 of the output tuple)
|
| 517 |
+
attention_weights = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
|
| 518 |
+
layer_attention_pairs.append((layer_num, attention_weights))
|
| 519 |
|
| 520 |
+
if not layer_attention_pairs:
|
| 521 |
+
return f"<p>No valid attention data found</p>"
|
|
|
|
|
|
|
| 522 |
|
| 523 |
+
# Sort by layer number and extract attention tensors
|
| 524 |
+
layer_attention_pairs.sort(key=lambda x: x[0])
|
| 525 |
+
attentions = tuple(attn for _, attn in layer_attention_pairs)
|
| 526 |
|
| 527 |
# Get tokens
|
| 528 |
input_ids = torch.tensor(activation_data['input_ids'])
|
|
|
|
| 531 |
# Load tokenizer and convert to tokens
|
| 532 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 533 |
raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
| 534 |
+
# Clean up tokens (remove special tokenizer artifacts like Ġ for GPT-2)
|
| 535 |
tokens = [token.replace('Ġ', ' ') if token.startswith('Ġ') else token for token in raw_tokens]
|
| 536 |
|
| 537 |
# Generate visualization based on view_type
|
|
|
|
| 541 |
<div style="padding:10px; border:1px solid #ccc; border-radius:5px;">
|
| 542 |
<h4>Layer {layer_index} Attention Preview</h4>
|
| 543 |
<p><strong>Tokens:</strong> {' '.join(tokens[:8])}{'...' if len(tokens) > 8 else ''}</p>
|
| 544 |
+
<p><strong>Total Layers:</strong> {len(attentions)}</p>
|
| 545 |
+
<p><strong>Heads per Layer:</strong> {attentions[0].shape[1] if attentions else 'N/A'}</p>
|
| 546 |
+
<p><em>Click for full model_view visualization</em></p>
|
| 547 |
</div>
|
| 548 |
"""
|
| 549 |
else:
|
| 550 |
+
# Full version: complete bertviz model_view visualization (shows all layers)
|
| 551 |
+
html_result = model_view(attentions, tokens, html_action='return')
|
|
|
|
| 552 |
return html_result.data if hasattr(html_result, 'data') else str(html_result)
|
| 553 |
|
| 554 |
except Exception as e:
|
| 555 |
+
import traceback
|
| 556 |
+
traceback.print_exc()
|
| 557 |
return f"<p>Error generating visualization: {str(e)}</p>"
|