| | """ |
| | attribution_tracer.py |
| | |
| | Core implementation of the Attribution Tracing module for the glyphs framework. |
| | This module maps token-to-token attribution flows, tracks query-key alignment, |
| | and visualizes attention patterns to reveal latent semantic structures. |
| | """ |
| |
|
| | import logging |
| | import time |
| | import numpy as np |
| | from typing import Dict, List, Optional, Tuple, Union, Any, Set |
| | from dataclasses import dataclass, field |
| | import json |
| | import hashlib |
| | from pathlib import Path |
| | import matplotlib.pyplot as plt |
| | import networkx as nx |
| | from enum import Enum |
| |
|
| | from ..models.adapter import ModelAdapter |
| | from ..utils.visualization_utils import VisualizationEngine |
| |
|
| | |
# Module-level logger for attribution tracing.
# FIX: the previous code called logger.setLevel(logging.INFO) here. A library
# module should not force a level on its loggers — doing so overrides the
# embedding application's logging configuration. The level is now left unset
# (NOTSET) so verbosity is controlled by the application.
logger = logging.getLogger("glyphs.attribution_tracer")
| |
|
class AttributionType(Enum):
    """
    Kinds of attribution link the tracer can record.

    The string values are stable identifiers used when attribution data is
    serialized or matched against glyph mappings.
    """

    DIRECT = "direct"
    INDIRECT = "indirect"
    RESIDUAL = "residual"
    MULTIHEAD = "multihead"
    NULL = "null"
    COMPOSITE = "composite"
    RECURSIVE = "recursive"
    EMERGENT = "emergent"
| |
|
| |
|
@dataclass
class AttributionLink:
    """A link in an attribution chain between source and target tokens."""

    # Indices of the source and target tokens in the traced sequence.
    source_idx: int
    target_idx: int
    # Kind of attribution this link records.
    attribution_type: AttributionType
    # Attribution strength score.
    strength: float
    # Attention heads / layers through which the attribution flows
    # (empty when not resolved).
    attention_heads: List[int] = field(default_factory=list)
    layers: List[int] = field(default_factory=list)
    # Tokens mediating an indirect attribution path.
    intermediate_tokens: List[int] = field(default_factory=list)
    # Optional residue payload attached to this link.
    residue: Optional[Dict[str, Any]] = None
| |
|
| |
|
@dataclass
class AttributionMap:
    """Complete map of attribution across a sequence."""

    # Tokenized prompt and output.
    prompt_tokens: List[str]
    output_tokens: List[str]
    # Individual attribution links between tokens.
    links: List[AttributionLink]
    # Salience scores keyed by token index — presumably output-token
    # indices; confirm against _analyze_token_salience.
    token_salience: Dict[int, float] = field(default_factory=dict)
    # Index pairs marking regions where attribution could not be established.
    attribution_gaps: List[Tuple[int, int]] = field(default_factory=list)
    # Index pairs marking regions detected as collapsed.
    collapsed_regions: List[Tuple[int, int]] = field(default_factory=list)
    # Per-token uncertainty estimates keyed by token index.
    uncertainty: Dict[int, float] = field(default_factory=dict)
    # Free-form trace metadata (prompt, model id, timing, visualization, ...).
    metadata: Dict[str, Any] = field(default_factory=dict)
| |
|
| |
|
@dataclass
class ForkPath:
    """A fork in the attribution path, representing alternative attributions."""

    # Identifier for this fork (e.g. "fork_1").
    id: str
    # Human-readable description of the alternative path.
    description: str
    # Attribution links that make up this fork.
    links: List[AttributionLink]
    # Confidence score for this alternative.
    confidence: float
    # Token indices at which this fork diverges from the base attribution.
    conflict_points: List[int] = field(default_factory=list)
    # Optional residue payload attached to this fork.
    residue: Optional[Dict[str, Any]] = None
| |
|
| |
|
@dataclass
class AttentionHead:
    """Representation of an attention head's behavior."""

    # Layer and head index identifying this head within the model.
    layer: int
    head: int
    # Label for the attention pattern this head exhibits.
    pattern_type: str
    # Token indices this head focuses on.
    focus_tokens: List[int]
    # Attention strength score.
    strength: float
    # Optional inferred function / attribution role labels.
    function: Optional[str] = None
    attribution_role: Optional[str] = None
| |
|
| |
|
| | class AttributionTracer: |
| | """ |
| | Core attribution tracing system for the glyphs framework. |
| | |
| | This class implements attribution tracing between tokens, mapping |
| | how information flows through transformer architectures from inputs |
| | to outputs. It provides insights into the causal relationships |
| | between tokens and the formation of semantic structures. |
| | """ |
| | |
| | def __init__( |
| | self, |
| | model: ModelAdapter, |
| | config: Optional[Dict[str, Any]] = None, |
| | visualizer: Optional[VisualizationEngine] = None |
| | ): |
| | """ |
| | Initialize the attribution tracer. |
| | |
| | Parameters: |
| | ----------- |
| | model : ModelAdapter |
| | Model adapter for the target model |
| | config : Optional[Dict[str, Any]] |
| | Configuration parameters for the tracer |
| | visualizer : Optional[VisualizationEngine] |
| | Visualization engine for attribution visualization |
| | """ |
| | self.model = model |
| | self.config = config or {} |
| | self.visualizer = visualizer |
| | |
| | |
| | self.trace_depth = self.config.get("trace_depth", 5) |
| | self.min_attribution_strength = self.config.get("min_attribution_strength", 0.1) |
| | self.include_indirect = self.config.get("include_indirect", True) |
| | self.trace_residual = self.config.get("trace_residual", True) |
| | self.collapse_threshold = self.config.get("collapse_threshold", 0.05) |
| | |
| | |
| | self.attribution_history = [] |
| | |
| | |
| | self._init_attribution_glyphs() |
| | |
| | logger.info(f"Attribution tracer initialized for model: {model.model_id}") |
| | |
| | def _init_attribution_glyphs(self): |
| | """Initialize glyph mappings for attribution visualization.""" |
| | |
| | self.strength_glyphs = { |
| | "very_strong": "🔍", |
| | "strong": "🔗", |
| | "moderate": "🧩", |
| | "weak": "🌫️", |
| | "very_weak": "💤", |
| | } |
| | |
| | |
| | self.type_glyphs = { |
| | AttributionType.DIRECT: "⮕", |
| | AttributionType.INDIRECT: "⤑", |
| | AttributionType.RESIDUAL: "↝", |
| | AttributionType.MULTIHEAD: "⥇", |
| | AttributionType.NULL: "⊘", |
| | AttributionType.COMPOSITE: "⬥", |
| | AttributionType.RECURSIVE: "↻", |
| | AttributionType.EMERGENT: "⇞", |
| | } |
| | |
| | |
| | self.pattern_glyphs = { |
| | "attribution_chain": "🔗", |
| | "attribution_fork": "🔀", |
| | "attribution_loop": "🔄", |
| | "attribution_gap": "⊟", |
| | "attribution_cluster": "☷", |
| | "attribution_decay": "🌊", |
| | "attribution_conflict": "⚡", |
| | } |
| | |
| | |
| | self.meta_glyphs = { |
| | "attribution_focus": "🎯", |
| | "uncertainty": "❓", |
| | "recursive_reference": "🜏", |
| | "collapse_point": "🝚", |
| | } |
| | |
| | def trace( |
| | self, |
| | prompt: str, |
| | output: Optional[str] = None, |
| | depth: Optional[int] = None, |
| | include_confidence: bool = True, |
| | visualize: bool = False |
| | ) -> AttributionMap: |
| | """ |
| | Trace attribution between prompt and output. |
| | |
| | Parameters: |
| | ----------- |
| | prompt : str |
| | Input prompt |
| | output : Optional[str] |
| | Output to trace attribution for. If None, will generate output. |
| | depth : Optional[int] |
| | Depth of attribution tracing. If None, uses default. |
| | include_confidence : bool |
| | Whether to include confidence scores |
| | visualize : bool |
| | Whether to generate visualization |
| | |
| | Returns: |
| | -------- |
| | AttributionMap |
| | Map of attribution between tokens |
| | """ |
| | trace_start = time.time() |
| | depth = depth or self.trace_depth |
| | |
| | logger.info(f"Tracing attribution with depth {depth}") |
| | |
| | |
| | if output is None: |
| | output = self.model.generate(prompt=prompt, max_tokens=800) |
| | |
| | |
| | prompt_tokens = self._tokenize(prompt) |
| | output_tokens = self._tokenize(output) |
| | |
| | |
| | attribution_map = AttributionMap( |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | links=[], |
| | metadata={ |
| | "prompt": prompt, |
| | "output": output, |
| | "model_id": self.model.model_id, |
| | "trace_depth": depth, |
| | "timestamp": time.time() |
| | } |
| | ) |
| | |
| | |
| | if hasattr(self.model, "get_attribution"): |
| | try: |
| | logger.info("Getting attribution directly from model API") |
| | api_attribution = self.model.get_attribution( |
| | prompt=prompt, |
| | output=output, |
| | include_confidence=include_confidence |
| | ) |
| | attribution_map = self._process_api_attribution( |
| | api_attribution, |
| | prompt_tokens, |
| | output_tokens |
| | ) |
| | logger.info("Successfully processed API attribution") |
| | except Exception as e: |
| | logger.warning(f"Failed to get attribution from API: {e}") |
| | logger.info("Falling back to inference-based attribution") |
| | attribution_map = self._infer_attribution( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | depth=depth |
| | ) |
| | else: |
| | |
| | logger.info("Using inference-based attribution") |
| | attribution_map = self._infer_attribution( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | depth=depth |
| | ) |
| | |
| | |
| | attribution_map.token_salience = self._analyze_token_salience( |
| | attribution_map.links |
| | ) |
| | |
| | |
| | attribution_map.attribution_gaps = self._find_attribution_gaps( |
| | attribution_map.links, |
| | len(prompt_tokens), |
| | len(output_tokens) |
| | ) |
| | |
| | |
| | attribution_map.collapsed_regions = self._detect_collapsed_regions( |
| | attribution_map.links, |
| | len(prompt_tokens), |
| | len(output_tokens) |
| | ) |
| | |
| | |
| | if include_confidence: |
| | attribution_map.uncertainty = self._calculate_attribution_uncertainty( |
| | attribution_map.links |
| | ) |
| | |
| | |
| | if visualize and self.visualizer: |
| | visualization = self.visualizer.visualize_attribution(attribution_map) |
| | attribution_map.metadata["visualization"] = visualization |
| | |
| | |
| | trace_time = time.time() - trace_start |
| | attribution_map.metadata["trace_time"] = trace_time |
| | |
| | |
| | self.attribution_history.append(attribution_map) |
| | |
| | logger.info(f"Attribution tracing completed in {trace_time:.2f}s") |
| | return attribution_map |
| | |
| | def trace_with_forks( |
| | self, |
| | prompt: str, |
| | output: Optional[str] = None, |
| | fork_factor: int = 3, |
| | include_confidence: bool = True, |
| | visualize: bool = False |
| | ) -> Dict[str, Any]: |
| | """ |
| | Trace attribution with multiple fork paths. |
| | |
| | Parameters: |
| | ----------- |
| | prompt : str |
| | Input prompt |
| | output : Optional[str] |
| | Output to trace attribution for. If None, will generate output. |
| | fork_factor : int |
| | Number of alternative attribution paths to generate |
| | include_confidence : bool |
| | Whether to include confidence scores |
| | visualize : bool |
| | Whether to generate visualization |
| | |
| | Returns: |
| | -------- |
| | Dict[str, Any] |
| | Dictionary containing attribution map and fork paths |
| | """ |
| | trace_start = time.time() |
| | |
| | logger.info(f"Tracing attribution with {fork_factor} fork paths") |
| | |
| | |
| | if output is None: |
| | output = self.model.generate(prompt=prompt, max_tokens=800) |
| | |
| | |
| | prompt_tokens = self._tokenize(prompt) |
| | output_tokens = self._tokenize(output) |
| | |
| | |
| | base_attribution = self.trace( |
| | prompt=prompt, |
| | output=output, |
| | include_confidence=include_confidence, |
| | visualize=False |
| | ) |
| | |
| | |
| | fork_paths = [] |
| | |
| | |
| | conflict_points = self._identify_conflict_points(base_attribution) |
| | |
| | |
| | for i in range(fork_factor): |
| | |
| | fork_path = self._generate_fork_path( |
| | base_attribution=base_attribution, |
| | conflict_points=conflict_points, |
| | fork_id=f"fork_{i+1}", |
| | fork_index=i |
| | ) |
| | fork_paths.append(fork_path) |
| | |
| | |
| | fork_result = { |
| | "base_attribution": base_attribution, |
| | "fork_paths": fork_paths, |
| | "conflict_points": conflict_points, |
| | "metadata": { |
| | "prompt": prompt, |
| | "output": output, |
| | "model_id": self.model.model_id, |
| | "fork_factor": fork_factor, |
| | "timestamp": time.time() |
| | } |
| | } |
| | |
| | |
| | if visualize and self.visualizer: |
| | visualization = self.visualizer.visualize_attribution_forks(fork_result) |
| | fork_result["metadata"]["visualization"] = visualization |
| | |
| | |
| | trace_time = time.time() - trace_start |
| | fork_result["metadata"]["trace_time"] = trace_time |
| | |
| | logger.info(f"Attribution fork tracing completed in {trace_time:.2f}s") |
| | return fork_result |
| | |
| | def trace_attention_heads( |
| | self, |
| | prompt: str, |
| | output: Optional[str] = None, |
| | layer_range: Optional[Tuple[int, int]] = None, |
| | head_threshold: float = 0.1, |
| | visualize: bool = False |
| | ) -> Dict[str, Any]: |
| | """ |
| | Trace attribution through specific attention heads. |
| | |
| | Parameters: |
| | ----------- |
| | prompt : str |
| | Input prompt |
| | output : Optional[str] |
| | Output to trace attention for. If None, will generate output. |
| | layer_range : Optional[Tuple[int, int]] |
| | Range of layers to analyze. If None, analyzes all layers. |
| | head_threshold : float |
| | Minimum attention strength to include head |
| | visualize : bool |
| | Whether to generate visualization |
| | |
| | Returns: |
| | -------- |
| | Dict[str, Any] |
| | Dictionary containing attention head analysis |
| | """ |
| | trace_start = time.time() |
| | |
| | |
| | if output is None: |
| | output = self.model.generate(prompt=prompt, max_tokens=800) |
| | |
| | |
| | prompt_tokens = self._tokenize(prompt) |
| | output_tokens = self._tokenize(output) |
| | |
| | |
| | model_info = self.model.get_model_info() |
| | num_layers = model_info.get("num_layers", 12) |
| | num_heads = model_info.get("num_heads", 12) |
| | |
| | |
| | if layer_range is None: |
| | layer_range = (0, num_layers - 1) |
| | else: |
| | layer_range = ( |
| | max(0, layer_range[0]), |
| | min(num_layers - 1, layer_range[1]) |
| | ) |
| | |
| | |
| | attention_heads = [] |
| | |
| | |
| | if hasattr(self.model, "get_attention_patterns"): |
| | try: |
| | logger.info("Getting attention patterns directly from model API") |
| | attention_patterns = self.model.get_attention_patterns( |
| | prompt=prompt, |
| | output=output, |
| | layer_range=layer_range |
| | ) |
| | attention_heads = self._process_api_attention( |
| | attention_patterns, |
| | prompt_tokens, |
| | output_tokens, |
| | layer_range, |
| | head_threshold |
| | ) |
| | logger.info("Successfully processed API attention patterns") |
| | except Exception as e: |
| | logger.warning(f"Failed to get attention patterns from API: {e}") |
| | logger.info("Falling back to inference-based attention analysis") |
| | attention_heads = self._infer_attention_behavior( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | layer_range=layer_range, |
| | num_heads=num_heads, |
| | head_threshold=head_threshold |
| | ) |
| | else: |
| | |
| | logger.info("Using inference-based attention analysis") |
| | attention_heads = self._infer_attention_behavior( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | layer_range=layer_range, |
| | num_heads=num_heads, |
| | head_threshold=head_threshold |
| | ) |
| | |
| | |
| | head_patterns = self._analyze_attention_patterns(attention_heads) |
| | |
| | |
| | attention_result = { |
| | "prompt_tokens": prompt_tokens, |
| | "output_tokens": output_tokens, |
| | "attention_heads": attention_heads, |
| | "head_patterns": head_patterns, |
| | "metadata": { |
| | "prompt": prompt, |
| | "output": output, |
| | "model_id": self.model.model_id, |
| | "layer_range": layer_range, |
| | "head_threshold": head_threshold, |
| | "timestamp": time.time() |
| | } |
| | } |
| | |
| | |
| | if visualize and self.visualizer: |
| | visualization = self.visualizer.visualize_attention_heads(attention_result) |
| | attention_result["metadata"]["visualization"] = visualization |
| | |
| | |
| | trace_time = time.time() - trace_start |
| | attention_result["metadata"]["trace_time"] = trace_time |
| | |
| | logger.info(f"Attention head tracing completed in {trace_time:.2f}s") |
| | return attention_result |
| | |
| | def trace_qk_alignment( |
| | self, |
| | prompt: str, |
| | output: Optional[str] = None, |
| | layer_indices: Optional[List[int]] = None, |
| | visualize: bool = False |
| | ) -> Dict[str, Any]: |
| | """ |
| | Trace query-key alignment in attention mechanisms. |
| | |
| | Parameters: |
| | ----------- |
| | prompt : str |
| | Input prompt |
| | output : Optional[str] |
| | Output to trace alignment for. If None, will generate output. |
| | layer_indices : Optional[List[int]] |
| | Specific layers to analyze. If None, analyzes representative layers. |
| | visualize : bool |
| | Whether to generate visualization |
| | |
| | Returns: |
| | -------- |
| | Dict[str, Any] |
| | Dictionary containing QK alignment analysis |
| | """ |
| | trace_start = time.time() |
| | |
| | |
| | if output is None: |
| | output = self.model.generate(prompt=prompt, max_tokens=800) |
| | |
| | |
| | prompt_tokens = self._tokenize(prompt) |
| | output_tokens = self._tokenize(output) |
| | |
| | |
| | model_info = self.model.get_model_info() |
| | num_layers = model_info.get("num_layers", 12) |
| | |
| | |
| | if layer_indices is None: |
| | |
| | layer_indices = [ |
| | 0, |
| | num_layers // 2, |
| | num_layers - 1 |
| | ] |
| | |
| | |
| | qk_alignments = [] |
| | |
| | |
| | if hasattr(self.model, "get_qk_values"): |
| | try: |
| | logger.info("Getting QK values directly from model API") |
| | qk_values = self.model.get_qk_values( |
| | prompt=prompt, |
| | output=output, |
| | layer_indices=layer_indices |
| | ) |
| | qk_alignments = self._process_api_qk_values( |
| | qk_values, |
| | prompt_tokens, |
| | output_tokens, |
| | layer_indices |
| | ) |
| | logger.info("Successfully processed API QK values") |
| | except Exception as e: |
| | logger.warning(f"Failed to get QK values from API: {e}") |
| | logger.info("Falling back to inference-based QK alignment") |
| | qk_alignments = self._infer_qk_alignment( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | layer_indices=layer_indices |
| | ) |
| | else: |
| | |
| | logger.info("Using inference-based QK alignment") |
| | qk_alignments = self._infer_qk_alignment( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | layer_indices=layer_indices |
| | ) |
| | |
| | |
| | qk_patterns = self._analyze_qk_patterns(qk_alignments) |
| | |
| | |
| | qk_result = { |
| | "prompt_tokens": prompt_tokens, |
| | "output_tokens": output_tokens, |
| | "qk_alignments": qk_alignments, |
| | "qk_patterns": qk_patterns, |
| | "metadata": { |
| | "prompt": prompt, |
| | "output": output, |
| | "model_id": self.model.model_id, |
| | "layer_indices": layer_indices, |
| | "timestamp": time.time() |
| | } |
| | } |
| | |
| | |
| | if visualize and self.visualizer: |
| | visualization = self.visualizer.visualize_qk_alignment(qk_result) |
| | qk_result["metadata"]["visualization"] = visualization |
| | |
| | |
| | trace_time = time.time() - trace_start |
| | qk_result["metadata"]["trace_time"] = trace_time |
| | |
| | logger.info(f"QK alignment tracing completed in {trace_time:.2f}s") |
| | return qk_result |
| | |
| | def trace_ov_projection( |
| | self, |
| | prompt: str, |
| | output: Optional[str] = None, |
| | layer_indices: Optional[List[int]] = None, |
| | visualize: bool = False |
| | ) -> Dict[str, Any]: |
| | """ |
| | Trace output-value projection in attention mechanisms. |
| | |
| | Parameters: |
| | ----------- |
| | prompt : str |
| | Input prompt |
| | output : Optional[str] |
| | Output to trace projection for. If None, will generate output. |
| | layer_indices : Optional[List[int]] |
| | Specific layers to analyze. If None, analyzes representative layers. |
| | visualize : bool |
| | Whether to generate visualization |
| | |
| | Returns: |
| | -------- |
| | Dict[str, Any] |
| | Dictionary containing OV projection analysis |
| | """ |
| | trace_start = time.time() |
| | |
| | |
| | if output is None: |
| | output = self.model.generate(prompt=prompt, max_tokens=800) |
| | |
| | |
| | prompt_tokens = self._tokenize(prompt) |
| | output_tokens = self._tokenize(output) |
| | |
| | |
| | model_info = self.model.get_model_info() |
| | num_layers = model_info.get("num_layers", 12) |
| | |
| | |
| | if layer_indices is None: |
| | |
| | layer_indices = [ |
| | 0, |
| | num_layers // 2, |
| | num_layers - 1 |
| | ] |
| | |
| | |
| | ov_projections = [] |
| | |
| | |
| | if hasattr(self.model, "get_ov_values"): |
| | try: |
| | logger.info("Getting OV values directly from model API") |
| | ov_values = self.model.get_ov_values( |
| | prompt=prompt, |
| | output=output, |
| | layer_indices=layer_indices |
| | ) |
| | ov_projections = self._process_api_ov_values( |
| | ov_values, |
| | prompt_tokens, |
| | output_tokens, |
| | layer_indices |
| | ) |
| | logger.info("Successfully processed API OV values") |
| | except Exception as e: |
| | logger.warning(f"Failed to get OV values from API: {e}") |
| | logger.info("Falling back to inference-based OV projection") |
| | ov_projections = self._infer_ov_projection( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | layer_indices=layer_indices |
| | ) |
| | else: |
| | |
| | logger.info("Using inference-based OV projection") |
| | ov_projections = self._infer_ov_projection( |
| | prompt=prompt, |
| | output=output, |
| | prompt_tokens=prompt_tokens, |
| | output_tokens=output_tokens, |
| | layer_indices=layer_indices |
| | ) |
| | |
| | |
| | ov_patterns = self._analyze_ov_patterns(ov_projections) |
| | |
| | |
| | ov_result = { |
| | "prompt_tokens": prompt_tokens, |
| | "output_tokens": output_tokens, |
| | "ov_projections": ov_projections, |
| | "ov_patterns": ov_patterns, |
| | "metadata": { |
| | "prompt": prompt, |
| | "output": output, |
| | "model_id": self.model.model_id, |
| | "layer_indices": layer_indices, |
| | "timestamp": time.time() |
| | } |
| | } |
| | |
| | |
| | if visualize and self.visualizer: |
| | visualization = self.visualizer.visualize_ov_projection(ov_result) |
| | ov_result["metadata"]["visualization"] = visualization |
| | |
| | |
| | trace_time = time.time() - trace_start |
| | ov_result["metadata"]["trace_time"] = trace_time |
| | |
| | logger.info(f"OV projection tracing completed in {trace_time:.2f}s") |
| | return ov_result |
| | |
| | def compare_attribution( |
| | self, |
| | prompt: str, |
| | outputs: List[str], |
| | include_confidence: bool = True, |
| | visualize: bool = False |
| | ) -> Dict[str, Any]: |
| | """ |
| | Compare attribution across multiple outputs for the same prompt. |
| | |
| | Parameters: |
| | ----------- |
| | prompt : str |
| | Input prompt |
| | outputs : List[str] |
| | List of outputs to compare attribution for |
| | include_confidence : bool |
| | Whether to include confidence scores |
| | visualize : bool |
| | Whether to generate visualization |
| | |
| | Returns: |
| | -------- |
| | Dict[str, Any] |
| | Dictionary containing attribution comparison |
| | """ |
| | trace_start = time.time() |
| | |
| | logger.info(f"Comparing attribution across {len(outputs)} outputs") |
| | |
| | |
| | prompt_tokens = self._tokenize(prompt) |
| | |
| | |
| | attribution_maps = [] |
| | for i, output in enumerate(outputs): |
| | logger.info(f"Tracing attribution for output {i+1}/{len(outputs)}") |
| | attribution_map = self.trace( |
| | prompt=prompt, |
| | output=output, |
| | include_confidence=include_confidence, |
| | visualize=False |
| | ) |
| | attribution_maps.append(attribution_map) |
| | |
| | |
| | comparison = self._compare_attribution_maps(attribution_maps) |
| | |
| | |
| | comparison_result = { |
| | "prompt": prompt, |
| | "prompt_tokens": prompt_tokens, |
| | "outputs": outputs, |
| | "attribution_maps": attribution_maps, |
| | "comparison": comparison, |
| | "metadata": { |
| | "model_id": self.model.model_id, |
| | "num_outputs": len(outputs), |
| | "timestamp": time.time() |
| | } |
| | } |
| | |
| | |
| | if visualize and self.visualizer: |
| | visualization = self.visualizer.visualize_attribution_comparison(comparison_result) |
| | comparison_result["metadata"]["visualization"] = visualization |
| | |
| | |
| | trace_time = time.time() - trace_start |
| | comparison_result["metadata"]["trace_time"] = trace_time |
| | |
| | logger.info(f"Attribution comparison completed in {trace_time:.2f}s") |
| | return comparison_result |
| | |
| | def visualize_attribution( |
| | self, |
| | attribution_map: AttributionMap, |
| | visualization_type: str = "network", |
| | highlight_tokens: Optional[List[str]] = None, |
| | output_path: Optional[str] = None |
| | ) -> Dict[str, Any]: |
| | """ |
| | Visualize attribution map. |
| | |
| | Parameters: |
| | ----------- |
| | attribution_map : AttributionMap |
| | Attribution map to visualize |
| | visualization_type : str |
| | Type of visualization to generate |
| | highlight_tokens : Optional[List[str]] |
| | Tokens to highlight in visualization |
| | output_path : Optional[str] |
| | Path to save visualization to |
| | |
| | Returns: |
| | -------- |
| | Dict[str, Any] |
| | Visualization result |
| | """ |
| | if self.visualizer: |
| | return self.visualizer.visualize_attribution( |
| | attribution_map=attribution_map, |
| | visualization_type=visualization_type, |
| | highlight_tokens=highlight_tokens, |
| | output_path=output_path |
| | ) |
| | else: |
| | |
| | return self._simple_visualization( |
| | attribution_map=attribution_map, |
| | visualization_type=visualization_type, |
| | highlight_tokens=highlight_tokens, |
| | output_path=output_path |
| | ) |
| | |
| | def _tokenize(self, text: str) -> List[str]: |
| | """Tokenize text using model tokenizer.""" |
| | if hasattr(self.model, "tokenize"): |
| | return self.model.tokenize(text) |
| | else: |
| | |
| | return text.split() |
| | |
| | def _infer_attribution( |
| | self, |
| | prompt: str, |
| | output: str, |
| | prompt_tokens: |
| | |