| import json |
| from typing import List, Dict, Any, Optional, Tuple |
| from .utils import clean_label, scale_weight_to_width, scale_weight_to_opacity |
|
|
class SimpleSVGVisualizer:
    """Render token attention as a static, clickable SVG diagram (no D3).

    The ``config`` object handed to the constructor is read for
    ``PLOT_WIDTH``, ``PLOT_HEIGHT``, ``NODE_SIZE``, ``FONT_SIZE``,
    ``INPUT_COLOR`` and ``OUTPUT_COLOR``.
    """

    # Distance (px) kept free between the plot border and the token columns.
    _MARGIN = 100

    # Static click handler appended after the SVG. It writes the clicked
    # token as JSON into the textarea under ``#clicked-token-d3`` and fires
    # an ``input`` event on it.
    _CLICK_HANDLER_JS = """
<script>
function handleTokenClick(index, type) {
    console.log('Token clicked:', index, type);
    const hiddenInput = document.querySelector('#clicked-token-d3 textarea');
    if (hiddenInput) {
        const clickData = JSON.stringify({index: index, type: type});
        hiddenInput.value = clickData;
        hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
    }
}
</script>
"""

    def __init__(self, config):
        self.config = config

    @staticmethod
    def _spread_positions(count: int, start: float, span: float) -> List[float]:
        """Return ``count`` coordinates evenly spread over ``span`` from ``start``."""
        if count <= 0:
            return []
        # A single node sits at ``start``; max() avoids dividing by zero.
        step = span / max(1, count - 1)
        return [start + k * step for k in range(count)]

    def create_visualization_html(
        self,
        input_tokens: List[str],
        output_tokens: List[str],
        attention_matrices: List[Dict],
        threshold: float = 0.05,
        initial_step: int = 0,
        selected_token: Optional[int] = None,
        selected_type: Optional[str] = None
    ) -> str:
        """Create a simple SVG visualization without D3.

        Input tokens are drawn as a column on the left, output tokens as a
        column on the right, and one line is drawn per (input, output) pair
        whose attention weight exceeds ``threshold``.

        Args:
            input_tokens: Raw input tokens; each is passed through
                ``clean_label`` before display.
            output_tokens: Raw output tokens, cleaned the same way.
            attention_matrices: One mapping per output step. Entry ``j`` is
                indexed as ``attention_matrices[j]['input_attention'][i]``
                and the element must expose ``.item()``.
            threshold: Weights at or below this value are not drawn.
            initial_step: Index of the last output step to show; later
                output nodes are greyed out and get no lines.
            selected_token: Index of the highlighted token, or ``None``.
            selected_type: ``'input'`` or ``'output'``; says which column
                ``selected_token`` refers to.

        Returns:
            An HTML fragment containing the SVG and the click-handler script.
        """
        # Function-scope import keeps this block self-contained.
        from html import escape

        # Labels end up inside SVG <text> nodes, so markup characters in a
        # token (e.g. "<pad>", "&") must be escaped to keep the SVG valid.
        # NOTE(review): assumes clean_label does not already escape - confirm.
        input_labels = [
            escape(str(clean_label(token)), quote=False) for token in input_tokens
        ]
        output_labels = [
            escape(str(clean_label(token)), quote=False) for token in output_tokens
        ]

        cfg = self.config
        width = cfg.PLOT_WIDTH
        height = cfg.PLOT_HEIGHT
        margin = self._MARGIN

        input_x = margin
        output_x = width - margin
        usable_height = height - 2 * margin

        input_y_positions = self._spread_positions(
            len(input_labels), margin, usable_height
        )
        output_y_positions = self._spread_positions(
            len(output_labels), margin, usable_height
        )

        # Background and title.
        svg_elements = [
            f'<rect width="{width}" height="{height}" fill="white" stroke="#ddd"/>',
            f'<text x="{width/2}" y="30" text-anchor="middle" font-size="16" '
            f'font-weight="bold">Token Attention Flow</text>',
        ]

        only_input = selected_token is not None and selected_type == 'input'
        only_output = selected_token is not None and selected_type == 'output'

        # Attention lines are appended first so the nodes paint on top.
        # Only steps that are both revealed and backed by a matrix are drawn.
        visible_steps = min(
            initial_step + 1, len(output_labels), len(attention_matrices)
        )
        for j in range(visible_steps):
            if only_output and j != selected_token:
                continue
            row = attention_matrices[j]['input_attention']
            y_out = output_y_positions[j]
            for i, y_in in enumerate(input_y_positions):
                # Filter before reading the weight so .item() is only
                # called for pairs that can actually be drawn.
                if only_input and i != selected_token:
                    continue
                weight = row[i].item()
                if weight > threshold:
                    opacity = scale_weight_to_opacity(weight, threshold)
                    width_val = scale_weight_to_width(weight)
                    svg_elements.append(
                        f'<line x1="{input_x}" y1="{y_in}" '
                        f'x2="{output_x}" y2="{y_out}" '
                        f'stroke="blue" stroke-width="{width_val}" opacity="{opacity}"/>'
                    )

        radius = cfg.NODE_SIZE / 2

        # Input column: node circle plus right-aligned label to its left.
        for i, label in enumerate(input_labels):
            y = input_y_positions[i]
            color = "yellow" if only_input and selected_token == i else cfg.INPUT_COLOR
            svg_elements.append(
                f'<circle cx="{input_x}" cy="{y}" r="{radius}" '
                f'fill="{color}" stroke="darkblue" stroke-width="2" '
                f'style="cursor: pointer" '
                f'onclick="handleTokenClick({i}, \'input\')"/>'
            )
            svg_elements.append(
                f'<text x="{input_x - radius - 10}" y="{y + 5}" '
                f'text-anchor="end" font-size="{cfg.FONT_SIZE}">{label}</text>'
            )

        # Output column: steps after ``initial_step`` are greyed out.
        for j, label in enumerate(output_labels):
            y = output_y_positions[j]
            if only_output and selected_token == j:
                color = "yellow"
            elif j <= initial_step:
                color = cfg.OUTPUT_COLOR
            else:
                color = "#e6e6e6"
            svg_elements.append(
                f'<circle cx="{output_x}" cy="{y}" r="{radius}" '
                f'fill="{color}" stroke="darkred" stroke-width="2" '
                f'style="cursor: pointer" '
                f'onclick="handleTokenClick({j}, \'output\')"/>'
            )
            svg_elements.append(
                f'<text x="{output_x + radius + 10}" y="{y + 5}" '
                f'text-anchor="start" font-size="{cfg.FONT_SIZE}">{label}</text>'
            )

        # Footer. A negative step must not wrap around to the end of the
        # list (or raise on an empty one), and with no output tokens the
        # last step index is shown as 0 rather than -1.
        if 0 <= initial_step < len(output_labels):
            current_label = output_labels[initial_step]
        else:
            current_label = ""
        last_step = max(len(output_labels) - 1, 0)
        svg_elements.append(
            f'<text x="{width/2}" y="{height - 20}" text-anchor="middle" '
            f'font-size="12" fill="darkred">'
            f'Step {initial_step} / {last_step}: Generating "{current_label}"'
            f'</text>'
        )

        svg_body = ''.join(svg_elements)
        markup = (
            f'\n<div style="width: 100%; overflow-x: auto;">\n'
            f'<svg width="{width}" height="{height}" style="border: 1px solid #ddd;">\n'
            f'{svg_body}\n'
            f'</svg>\n'
            f'</div>\n'
            f'{self._CLICK_HANDLER_JS}'
        )

        return markup