|
|
""" |
|
|
Graph Visualization per Circuit Tracer |
|
|
|
|
|
Versione UNIFIED: merge di graph_visualization.py e graph_visualization_fixed.py |
|
|
|
|
|
Layout intelligente: |
|
|
- Se Feature hanno layer/pos: usa layout Layer × Position (improved) |
|
|
- Altrimenti: fallback a layout semplice grid-based |
|
|
|
|
|
Usage: |
|
|
from scripts.visualization.graph_visualization import ( |
|
|
create_graph_visualization, |
|
|
Supernode, |
|
|
InterventionGraph, |
|
|
Feature |
|
|
) |
|
|
"""
|
|
|
|
|
from collections import namedtuple, defaultdict |
|
|
from typing import List, Optional, Tuple, Dict |
|
|
import math |
|
|
import html |
|
|
|
|
|
import torch |
|
|
from IPython.display import SVG |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Immutable, hashable record for a single feature. The layout code reads
# `layer` (vertical axis) and `pos` (token position, horizontal axis);
# `feature_idx` is carried along as an identifier and is not used for layout.
# Instances are also used as keys into activation mappings.
Feature = namedtuple("Feature", "layer pos feature_idx")
|
|
|
|
|
|
|
|
class InterventionGraph:
    """Intervention graph: a prompt plus the supernodes drawn for it.

    Attributes:
        prompt: the prompt text shown alongside the graph.
        ordered_nodes: the supernodes to lay out. NOTE(review): the rendering
            code in this module iterates it as rows (a list of lists of
            supernodes), despite the ``List['Supernode']`` annotation.
        nodes: supernodes registered via ``initialize_node``, keyed by name.
    """

    prompt: str
    ordered_nodes: List['Supernode']
    nodes: Dict[str, 'Supernode']

    def __init__(self, ordered_nodes: List['Supernode'], prompt: str):
        self.prompt = prompt
        self.ordered_nodes = ordered_nodes
        self.nodes = {}

    def initialize_node(self, node, activations):
        """Register *node* and store its baseline activations.

        Args:
            node: supernode to register under ``node.name``.
            activations: mapping indexed by the node's features; one value is
                read per feature and the values are packed into a tensor
                stored on ``node.default_activations`` (``None`` when the node
                has no features).
        """
        self.nodes[node.name] = node
        if not node.features:
            node.default_activations = None
            return
        baseline = [activations[feat] for feat in node.features]
        node.default_activations = torch.tensor(baseline)

    def set_node_activation_fractions(self, current_activations):
        """Update every registered node's activation relative to its baseline.

        For a node with features, ``node.activation`` becomes the mean of
        current / baseline over its features (a plain float); nodes without
        features get ``None``. Every node's ``intervention`` and
        ``replacement_node`` are cleared afterwards.

        Args:
            current_activations: mapping indexed by feature, like the one given
                to ``initialize_node``.
        """
        for node in self.nodes.values():
            fraction = None
            if node.features:
                now = torch.tensor([current_activations[feat] for feat in node.features])
                fraction = (now / node.default_activations).mean().item()
            node.activation = fraction
            node.intervention = None
            node.replacement_node = None
|
|
|
|
|
|
|
|
class Supernode: |
|
|
"""Nodo del grafo rappresentante un gruppo di feature""" |
|
|
name: str |
|
|
activation: float|None |
|
|
default_activations: torch.Tensor|None |
|
|
children: List['Supernode'] |
|
|
intervention: None |
|
|
replacement_node: Optional['Supernode'] |
|
|
|
|
|
def __init__(self, name: str, features: List[Feature], children: List['Supernode'] = [], |
|
|
intervention: Optional[str] = None, replacement_node: Optional['Supernode'] = None): |
|
|
self.name = name |
|
|
self.features = features |
|
|
self.activation = None |
|
|
self.default_activations = None |
|
|
self.children = children |
|
|
self.intervention = intervention |
|
|
self.replacement_node = replacement_node |
|
|
|
|
|
def __repr__(self): |
|
|
return f"Node(name={self.name}, activation={self.activation}, children={self.children}, intervention={self.intervention}, replacement_node={self.replacement_node})" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_node_positions_improved(nodes: List[List['Supernode']]):
    """Place nodes on a Layer x Token-position grid derived from their features.

    Each node with features is anchored at the rounded mean layer (y axis,
    larger layers drawn higher) and rounded mean token position (x axis) of
    its features. Nodes that share an anchor cell are fanned out horizontally
    25 px apart. Nodes without features are parked on layer 0, in one of
    three position slots. Replacement nodes that are not already placed are
    drawn 30 px right of and 40 px above the node they replace.

    Args:
        nodes: rows of supernodes; only the flattened sequence matters here.

    Returns:
        ``(node_data, (min_layer, max_layer), (min_pos, max_pos))``, where
        ``node_data`` maps node name -> ``{'x', 'y', 'node', 'layer', 'pos'}``
        (``x``/``y`` are the top-left corner of the 80x30 node box) and the
        ranges span all feature layers/positions. Returns ``None`` when no node
        has any features, so the caller can fall back to the simple layout.
    """
    container_height = 1600
    node_width = 80
    node_height = 30
    x_spacing = 100
    y_spacing = 30
    # The grid drawn by _create_improved_svg uses the same origin and spacing.
    base_x = 100
    base_y = container_height - 150

    all_nodes = [node for row in nodes for node in row]

    layers = [f.layer for node in all_nodes for f in (node.features or [])]
    positions = [f.pos for node in all_nodes for f in (node.features or [])]
    if not layers:
        return None

    # Group nodes by the (rounded) mean layer/position of their features.
    cells = defaultdict(list)
    for node in all_nodes:
        if node.features:
            count = len(node.features)
            layer_key = int(round(sum(f.layer for f in node.features) / count))
            pos_key = int(round(sum(f.pos for f in node.features) / count))
            cells[(layer_key, pos_key)].append(node)
        else:
            # Feature-less nodes: cycle through three slots on layer 0.
            cells[(0, len(cells) % 3)].append(node)

    node_data = {}
    for (layer, pos), cell_nodes in cells.items():
        x_center = base_x + pos * x_spacing
        y_center = base_y - layer * y_spacing
        n_cell = len(cell_nodes)
        offset_step = 25 if n_cell > 1 else 0
        for idx, node in enumerate(cell_nodes):
            # Spread nodes sharing a cell symmetrically around its centre.
            x_offset = (idx - (n_cell - 1) / 2) * offset_step
            node_data[node.name] = {
                'x': x_center + x_offset - node_width / 2,
                'y': y_center - node_height / 2,
                'node': node,
                'layer': layer,
                'pos': pos,
            }

    # First-seen order (instead of a set) keeps the result deterministic when
    # several nodes share the same replacement node.
    for node in dict.fromkeys(all_nodes):
        replacement = node.replacement_node
        if not replacement or replacement.name in node_data:
            continue
        anchor = node_data.get(node.name)
        if anchor:
            node_data[replacement.name] = {
                'x': anchor['x'] + 30,
                'y': anchor['y'] - 40,
                'node': replacement,
                'layer': anchor.get('layer', 0),
                'pos': anchor.get('pos', 0),
            }

    # Ranges use the true extrema; the old 0-initialised maxima were wrong
    # whenever every layer/position was negative.
    return (node_data,
            (int(min(layers)), int(max(layers))),
            (int(min(positions)), int(max(positions))))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_node_positions_simple(nodes: List[List['Supernode']]):
    """Place nodes on a plain grid: one horizontally centred row per entry.

    Row ``i`` is drawn higher than row ``i - 1``. Nodes within a row are
    100 px wide with 50 px gaps, centred in a 600 px wide container.
    Replacement nodes that are not already placed are drawn 30 px right of
    and 35 px above the node they replace. Replacements are visited right
    after the node they replace, so a replacement's own replacement is placed
    as well.

    Args:
        nodes: rows of supernodes.

    Returns:
        ``(node_data, (0, n_rows - 1), (0, widest_row - 1))`` where
        ``node_data`` maps node name -> ``{'x', 'y', 'node', 'layer', 'pos'}``
        (``x``/``y`` are the box's top-left corner; ``layer``/``pos`` are the
        row and column indices). The position range is ``(0, 0)`` when
        *nodes* is empty.
    """
    container_width = 600
    container_height = 250
    node_width = 100
    gap = 50

    node_data = {}
    row_pitch = container_height / (len(nodes) + 0.5)
    for row_index, row in enumerate(nodes):
        row_y = container_height - row_index * row_pitch
        # Row width is invariant across the row's columns; compute it once.
        row_width = len(row) * node_width + (len(row) - 1) * gap
        start_x = (container_width - row_width) / 2
        for col_index, node in enumerate(row):
            node_data[node.name] = {
                'x': start_x + col_index * (node_width + gap),
                'y': row_y,
                'node': node,
                'layer': row_index,
                'pos': col_index,
            }

    # Ordered de-duplication instead of a set: each replacement is queued
    # right after the node it replaces. This makes the result deterministic
    # (set order decided whether a replacement's own replacement got placed).
    ordered = {}
    for row in nodes:
        for node in row:
            ordered[node] = None
            if node.replacement_node:
                ordered[node.replacement_node] = None

    for node in ordered:
        replacement = node.replacement_node
        if not replacement or replacement.name in node_data:
            continue
        anchor = node_data.get(node.name)
        if anchor:
            node_data[replacement.name] = {
                'x': anchor['x'] + 30,
                'y': anchor['y'] - 35,
                'node': replacement,
                'layer': anchor.get('layer', 0),
                'pos': anchor.get('pos', 0),
            }

    max_col = max(len(row) for row in nodes) - 1 if nodes else 0
    return node_data, (0, len(nodes) - 1), (0, max_col)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_node_center(node_data, node_name, simple_mode=False):
    """Return the centre point ``{'x', 'y'}`` of a placed node's box.

    ``node_data`` entries hold the box's top-left corner. The box is 100x35
    in simple mode and 80x30 otherwise. A missing (or empty) entry yields
    ``{'x': 0, 'y': 0}``.
    """
    entry = node_data.get(node_name)
    if not entry:
        return {'x': 0, 'y': 0}
    half_w, half_h = (50, 17.5) if simple_mode else (40, 15)
    return {'x': entry['x'] + half_w, 'y': entry['y'] + half_h}
|
|
|
|
|
|
|
|
def create_connection_svg(node_data, connections, simple_mode=False):
    """Render every connection as an SVG line plus an arrowhead.

    Args:
        node_data: mapping node name -> placement dict (``x``/``y`` top-left
            corner, optional ``layer``), as produced by the layout functions.
        connections: dicts with ``'from'``/``'to'`` node names and an optional
            truthy ``'replacement'`` flag.
        simple_mode: selects the simple-layout palette and node box size.

    Returns:
        The SVG fragments joined with newlines. Connections with an endpoint
        that has no placement are skipped.

    Colours: replacement edges are orange. In the improved layout, edges
    towards a higher layer are blue ("forward") and all others crimson. In
    the simple layout, every other edge is brown.
    """
    # Half the node box size; must agree with get_node_center/create_nodes_svg.
    half_w, half_h = (50, 17.5) if simple_mode else (40, 15)
    opacity = "0.6" if not simple_mode else "1.0"
    arrow_size = 8
    svg_parts = []

    for conn in connections:
        from_data = node_data.get(conn['from'])
        to_data = node_data.get(conn['to'])
        # Skip edges whose endpoints were not placed. (The old test "centre
        # x == 0" also dropped real edges, e.g. when a wide simple-layout row
        # starts at x = -50 and its first node is centred on x = 0.)
        if not from_data or not to_data:
            continue
        from_center = get_node_center(node_data, conn['from'], simple_mode)
        to_center = get_node_center(node_data, conn['to'], simple_mode)

        if conn.get('replacement'):
            stroke_color, stroke_width = "#D2691E", "4"
        elif simple_mode:
            stroke_color, stroke_width = "#8B4513", "3"
        elif to_data.get('layer', 0) > from_data.get('layer', 0):
            stroke_color, stroke_width = "#4169E1", "2.5"
        else:
            stroke_color, stroke_width = "#DC143C", "2"

        svg_parts.append(f'<line x1="{from_center["x"]}" y1="{from_center["y"]}" '
                         f'x2="{to_center["x"]}" y2="{to_center["y"]}" '
                         f'stroke="{stroke_color}" stroke-width="{stroke_width}" '
                         f'opacity="{opacity}"/>')

        dx = to_center['x'] - from_center['x']
        dy = to_center['y'] - from_center['y']
        length = math.hypot(dx, dy)
        if length == 0:
            continue
        ux, uy = dx / length, dy / length

        # Put the arrow tip on the border of the target box. At the box
        # centre the 8 px arrowhead was entirely covered by the opaque node
        # rectangle, which is drawn after the connections.
        limits = []
        if ux:
            limits.append(half_w / abs(ux))
        if uy:
            limits.append(half_h / abs(uy))
        inset = min(min(limits), length)  # overlapping boxes: stop at source
        tip_x = to_center['x'] - inset * ux
        tip_y = to_center['y'] - inset * uy

        base_x = tip_x - arrow_size * ux
        base_y = tip_y - arrow_size * uy
        perp_x = -uy * (arrow_size / 2)
        perp_y = ux * (arrow_size / 2)
        left_x, left_y = base_x + perp_x, base_y + perp_y
        right_x, right_y = base_x - perp_x, base_y - perp_y

        svg_parts.append(f'<polygon points="{tip_x},{tip_y} {left_x},{left_y} {right_x},{right_y}" '
                         f'fill="{stroke_color}" opacity="{opacity}"/>')

    return '\n'.join(svg_parts)
|
|
|
|
|
|
|
|
def create_nodes_svg(node_data, simple_mode=False):
    """Render every placed node as an SVG box with its label and badges.

    Each box shows the node name, truncated to at most 12 characters (15 in
    simple mode) including the ellipsis, and HTML-escaped. In addition:
    a layer badge (improved layout, layer > 0), the activation as a
    percentage when ``node.activation`` is set, and, in simple mode, the
    intervention label.

    Styling: nodes with activation <= 0.25 or an intervention containing '-'
    are greyed out. Replacement nodes are cream with an orange border.
    Otherwise, improved-layout nodes are tinted by layer and simple-layout
    nodes are grey.

    Args:
        node_data: mapping node name -> placement dict (``x``/``y`` top-left
            corner, ``node``, optional ``layer``).
        simple_mode: selects box size, font size and simple-layout styling.

    Returns:
        The SVG fragments joined with newlines.
    """
    if simple_mode:
        width, height, font_size, corner, max_len = 100, 35, 12, "8", 15
    else:
        width, height, font_size, corner, max_len = 80, 30, 10, "6", 12

    # Names that some placed node points to as its replacement.
    replacement_names = {
        data['node'].replacement_node.name
        for data in node_data.values()
        if data['node'].replacement_node
    }

    svg_parts = []
    for name, data in node_data.items():
        node = data['node']
        x, y = data['x'], data['y']
        layer = data.get('layer', 0)

        is_low_activation = node.activation is not None and node.activation <= 0.25
        has_negative_intervention = node.intervention and '-' in node.intervention
        if is_low_activation or has_negative_intervention:
            fill_color, text_color, stroke_color = "#f0f0f0", "#bbb", "#ddd"
        elif name in replacement_names:
            fill_color, text_color, stroke_color = "#FFF8DC", "#333", "#D2691E"
        elif not simple_mode:
            layer_hue = (layer * 30) % 360
            fill_color = f"hsl({layer_hue}, 70%, 85%)"
            text_color, stroke_color = "#333", "#999"
        else:
            fill_color, text_color, stroke_color = "#e8e8e8", "#333", "#999"

        svg_parts.append(f'<rect x="{x}" y="{y}" width="{width}" height="{height}" '
                         f'fill="{fill_color}" stroke="{stroke_color}" stroke-width="2" rx="{corner}"/>')

        # Keep at most max_len characters including the ellipsis. The old
        # slice kept max_len - 2 characters plus '...' (max_len + 1 in total),
        # so a name one character over the limit was not shortened at all.
        shown = name if len(name) <= max_len else name[:max_len - 3] + '...'
        text_x = x + width / 2
        text_y = y + height / 2 + font_size / 3
        svg_parts.append(f'<text x="{text_x}" y="{text_y}" text-anchor="middle" '
                         f'fill="{text_color}" font-family="Arial, sans-serif" font-size="{font_size}" font-weight="bold">{html.escape(shown)}</text>')

        if not simple_mode and layer > 0:
            badge_x, badge_y = x - 8, y - 8
            svg_parts.append(f'<circle cx="{badge_x}" cy="{badge_y}" r="10" '
                             f'fill="white" stroke="#666" stroke-width="1"/>')
            svg_parts.append(f'<text x="{badge_x}" y="{badge_y + 4}" text-anchor="middle" '
                             f'fill="#666" font-family="Arial, sans-serif" font-size="8">{layer}</text>')

        if node.activation is not None:
            # NaN/inf can arise e.g. from a zero baseline activation in
            # InterventionGraph.set_node_activation_fractions; round() raises
            # on those, so show "n/a" instead of aborting the whole render.
            if math.isfinite(node.activation):
                activation_label = f"{round(node.activation * 100)}%"
            else:
                activation_label = "n/a"

            if simple_mode:
                label_x, label_y = x - 15, y - 5
                svg_parts.append(f'<rect x="{label_x}" y="{label_y}" width="30" height="16" '
                                 f'fill="white" stroke="#ccc" stroke-width="1" rx="4"/>')
                svg_parts.append(f'<text x="{label_x + 15}" y="{label_y + 12}" text-anchor="middle" '
                                 f'fill="#8B4513" font-family="Arial, sans-serif" font-size="10" font-weight="bold">{activation_label}</text>')
            else:
                label_x, label_y = x + 75, y - 5
                svg_parts.append(f'<text x="{label_x}" y="{label_y}" text-anchor="end" '
                                 f'fill="#8B4513" font-family="Arial, sans-serif" font-size="9">{activation_label}</text>')

        if simple_mode and node.intervention:
            intervention_x, intervention_y = x - 20, y - 5
            # Rough width estimate: ~8 px per character plus padding.
            text_width = len(node.intervention) * 8 + 10
            svg_parts.append(f'<rect x="{intervention_x}" y="{intervention_y}" width="{text_width}" height="16" '
                             f'fill="#D2691E" stroke="none" rx="12"/>')
            svg_parts.append(f'<text x="{intervention_x + text_width/2}" y="{intervention_y + 12}" text-anchor="middle" '
                             f'fill="white" font-family="Arial, sans-serif" font-size="10" font-weight="bold">{html.escape(node.intervention)}</text>')

    return '\n'.join(svg_parts)
|
|
|
|
|
|
|
|
def build_connections_data(nodes: List[List['Supernode']]):
    """Collect parent -> child edges for every node reachable from *nodes*.

    Reachability follows both ``children`` and ``replacement_node`` links.
    A node that has a replacement contributes no edges of its own. Edges
    leaving a node that serves as another node's replacement carry
    ``'replacement': True``.

    Args:
        nodes: rows of root supernodes.

    Returns:
        A list of ``{'from': name, 'to': name}`` dicts (plus
        ``'replacement': True`` where applicable), in depth-first discovery
        order of the source nodes.
    """
    # Insertion-ordered "set": makes the output order deterministic.
    reachable = {}

    def visit(node):
        # Stop at nodes already seen. Without this, shared descendants were
        # re-walked once per path (exponential on DAGs) and cycles recursed
        # until RecursionError.
        if node in reachable:
            return
        reachable[node] = None
        if node.replacement_node:
            visit(node.replacement_node)
        for child in node.children:
            visit(child)

    for row in nodes:
        for node in row:
            visit(node)

    replacement_names = {node.replacement_node.name
                         for node in reachable if node.replacement_node}

    connections = []
    for node in reachable:
        if node.replacement_node:
            # Replaced nodes are not drawn as edge sources.
            continue
        is_replacement = node.name in replacement_names
        for child in node.children:
            connection = {'from': node.name, 'to': child.name}
            if is_replacement:
                connection['replacement'] = True
            connections.append(connection)
    return connections
|
|
|
|
|
|
|
|
def wrap_text_for_svg(text, max_width=80):
    """Greedily wrap *text* into lines of at most *max_width* characters.

    Text that already fits is returned unchanged as a single line. Otherwise
    it is split on whitespace and words are packed into lines joined by
    single spaces. A word longer than *max_width* gets its own line and is
    not broken.
    """
    if len(text) <= max_width:
        return [text]

    wrapped = []
    line = ""
    for token in text.split():
        if not line:
            line = token
        elif len(line) + 1 + len(token) > max_width:
            wrapped.append(line)
            line = token
        else:
            line = f"{line} {token}"
    if line:
        wrapped.append(line)
    return wrapped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_graph_visualization(intervention_graph: InterventionGraph,
                               top_outputs: List[Tuple[str, float]],
                               force_simple: bool = False):
    """Render *intervention_graph* as an SVG.

    The Layer x Position layout is used unless *force_simple* is set or no
    node has any features; in either of those cases the simple grid layout is
    used instead.

    Args:
        intervention_graph: graph whose ``ordered_nodes`` (rows of supernodes)
            and ``prompt`` are drawn.
        top_outputs: ``(token, probability)`` pairs for the outputs panel.
        force_simple: always use the simple grid layout.

    Returns:
        An ``IPython.display.SVG`` object.
    """
    rows = intervention_graph.ordered_nodes

    # None means "use the simple layout", either forced or because the
    # improved layout found no features to position nodes by.
    improved = None if force_simple else calculate_node_positions_improved(rows)
    use_simple = improved is None
    if use_simple:
        node_data, layer_range, pos_range = calculate_node_positions_simple(rows)
    else:
        node_data, layer_range, pos_range = improved

    connections_svg = create_connection_svg(node_data, build_connections_data(rows), use_simple)
    nodes_svg = create_nodes_svg(node_data, use_simple)

    if use_simple:
        svg_content = _create_simple_svg(intervention_graph.prompt, top_outputs,
                                         connections_svg, nodes_svg)
    else:
        svg_content = _create_improved_svg(intervention_graph.prompt, top_outputs,
                                           connections_svg, nodes_svg,
                                           layer_range, pos_range)
    return SVG(svg_content)
|
|
|
|
|
|
|
|
def _create_simple_svg(prompt, top_outputs, connections_svg, nodes_svg):
    """Assemble the complete 700x400 SVG document for the simple grid layout.

    Args:
        prompt: raw prompt text; wrapped at 80 characters, then HTML-escaped.
        top_outputs: ``(token, probability)`` pairs; at most six are shown.
        connections_svg: pre-rendered connection fragments.
        nodes_svg: pre-rendered node fragments.

    Returns:
        The SVG document as a string. It must be well-formed XML:
        ``IPython.display.SVG`` parses it.
    """
    output_y_start = 350
    output_items = []
    current_x = 40
    for text, probability in top_outputs[:6]:
        display_text = text if text else "(empty)"
        percentage_text = f"{round(probability * 100)}%"
        # Rough pill width: ~8 px per token char, ~6 px per percentage char.
        item_width = len(display_text) * 8 + len(percentage_text) * 6 + 20
        output_items.append(f'<rect x="{current_x}" y="{output_y_start}" width="{item_width}" height="20" '
                            f'fill="#e8e8e8" stroke="none" rx="6"/>')
        output_items.append(f'<text x="{current_x + 5}" y="{output_y_start + 14}" '
                            f'fill="#333" font-family="Arial, sans-serif" font-size="11" font-weight="bold">'
                            f'{html.escape(display_text)} <tspan fill="#555" font-size="10">{percentage_text}</tspan></text>')
        current_x += item_width + 10
    output_items_svg_str = '\n'.join(output_items)

    # Wrap the raw prompt and escape each line afterwards, so the 80-char
    # limit counts visible characters rather than entity text like "&amp;".
    prompt_text_svg_str = '\n'.join(
        f'<text x="40" y="{325 + i * 15}" fill="#333" font-family="Arial, sans-serif" font-size="12">{html.escape(line)}</text>'
        for i, line in enumerate(wrap_text_for_svg(prompt, max_width=80))
    )

    # NOTE: a literal "&" is not allowed in XML text; the title uses "&amp;".
    return f'''<svg width="700" height="400" xmlns="http://www.w3.org/2000/svg">
    <!-- Background -->
    <rect width="700" height="400" fill="#f5f5f5"/>
    <rect x="20" y="20" width="660" height="360" fill="white" stroke="none" rx="12"/>

    <!-- Title -->
    <text x="40" y="45" fill="#666" font-family="Arial, sans-serif" font-size="14" font-weight="bold"
          text-transform="uppercase" letter-spacing="1px">Graph &amp; Interventions</text>

    <!-- Graph area -->
    <g transform="translate(50, 0)">
        {connections_svg}
        {nodes_svg}
    </g>

    <!-- Prompt section -->
    <line x1="40" y1="290" x2="660" y2="290" stroke="#ddd" stroke-width="1"/>
    <text x="40" y="310" fill="#666" font-family="Arial, sans-serif" font-size="12" font-weight="bold"
          text-transform="uppercase" letter-spacing="0.5px">Prompt</text>

    {prompt_text_svg_str}

    <!-- Top outputs section -->
    <text x="40" y="350" fill="#666" font-family="Arial, sans-serif" font-size="10" font-weight="bold"
          text-transform="uppercase" letter-spacing="0.5px">Top Outputs</text>

    <g transform="translate(0, 5)">
        {output_items_svg_str}
    </g>
</svg>'''
|
|
|
|
|
|
|
|
def _create_improved_svg(prompt, top_outputs, connections_svg, nodes_svg, layer_range, pos_range): |
|
|
"""Generate improved layout SVG with Layer × Position""" |
|
|
|
|
|
min_layer, max_layer = layer_range |
|
|
min_pos, max_pos = pos_range |
|
|
|
|
|
|
|
|
grid_svg_parts = [] |
|
|
base_y = 1450 |
|
|
y_spacing = 30 |
|
|
|
|
|
for layer in range(min_layer, max_layer + 1): |
|
|
y = base_y - layer * y_spacing |
|
|
grid_svg_parts.append(f'<line x1="50" y1="{y}" x2="1150" y2="{y}" ' |
|
|
f'stroke="#ddd" stroke-width="0.5" stroke-dasharray="5,5"/>') |
|
|
grid_svg_parts.append(f'<text x="30" y="{y + 5}" fill="#999" ' |
|
|
f'font-family="monospace" font-size="10">L{layer}</text>') |
|
|
|
|
|
grid_svg = '\n'.join(grid_svg_parts) |
|
|
|
|
|
|
|
|
base_x = 100 |
|
|
x_spacing = 100 |
|
|
token_markers = [] |
|
|
|
|
|
for pos in range(min_pos, min(max_pos + 1, 15)): |
|
|
x = base_x + pos * x_spacing |
|
|
token_markers.append(f'<text x="{x}" y="1530" fill="#666" text-anchor="middle" ' |
|
|
f'font-family="monospace" font-size="10">T{pos}</text>') |
|
|
|
|
|
tokens_svg = '\n'.join(token_markers) |
|
|
|
|
|
|
|
|
legend_svg = ''' |
|
|
<g transform="translate(950, 50)"> |
|
|
<text x="0" y="0" fill="#666" font-size="12" font-weight="bold">Legenda:</text> |
|
|
<line x1="0" y1="15" x2="40" y2="15" stroke="#4169E1" stroke-width="2.5"/> |
|
|
<text x="45" y="20" fill="#666" font-size="10">Forward (layer up)</text> |
|
|
<line x1="0" y1="35" x2="40" y2="35" stroke="#DC143C" stroke-width="2"/> |
|
|
<text x="45" y="40" fill="#666" font-size="10">Backward/Lateral</text> |
|
|
</g> |
|
|
''' |
|
|
|
|
|
|
|
|
output_items_svg = [] |
|
|
current_x = 950 |
|
|
current_y = 100 |
|
|
|
|
|
for i, (text, percentage) in enumerate(top_outputs[:4]): |
|
|
display_text = text if text else "(empty)" |
|
|
escaped_display_text = html.escape(display_text) |
|
|
percentage_text = f"{round(percentage * 100)}%" |
|
|
|
|
|
output_items_svg.append(f'<text x="{current_x}" y="{current_y + i*20}" fill="#333" ' |
|
|
f'font-family="Arial, sans-serif" font-size="10">' |
|
|
f'{escaped_display_text}: <tspan fill="#666">{percentage_text}</tspan></text>') |
|
|
|
|
|
outputs_svg = '\n'.join(output_items_svg) |
|
|
|
|
|
return f'''<svg width="1300" height="1650" xmlns="http://www.w3.org/2000/svg"> |
|
|
<rect width="1300" height="1650" fill="#fafafa"/> |
|
|
|
|
|
<text x="650" y="30" text-anchor="middle" fill="#333" |
|
|
font-family="Arial, sans-serif" font-size="18" font-weight="bold"> |
|
|
Attribution Graph - Layer × Token Position Layout |
|
|
</text> |
|
|
|
|
|
<text x="650" y="50" text-anchor="middle" fill="#666" |
|
|
font-family="Arial, sans-serif" font-size="12"> |
|
|
{html.escape(prompt)} |
|
|
</text> |
|
|
|
|
|
<!-- Grid --> |
|
|
{grid_svg} |
|
|
|
|
|
<!-- Token markers --> |
|
|
{tokens_svg} |
|
|
|
|
|
<!-- Connections --> |
|
|
{connections_svg} |
|
|
|
|
|
<!-- Nodes --> |
|
|
{nodes_svg} |
|
|
|
|
|
<!-- Legend --> |
|
|
{legend_svg} |
|
|
|
|
|
<!-- Top outputs --> |
|
|
<text x="950" y="85" fill="#666" font-size="12" font-weight="bold">Top Outputs:</text> |
|
|
{outputs_svg} |
|
|
|
|
|
<!-- Axis labels --> |
|
|
<text x="650" y="1570" text-anchor="middle" fill="#666" font-weight="bold" font-size="12"> |
|
|
Token Position → |
|
|
</text> |
|
|
<text x="10" y="825" text-anchor="middle" fill="#666" font-weight="bold" font-size="12" |
|
|
transform="rotate(-90 10 825)"> |
|
|
Layer → |
|
|
</text> |
|
|
</svg>''' |
|
|
|