| import gradio as gr |
| import torch |
| import sys |
| import os |
| import json |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from core.model_handler import ModelHandler |
| from core.attention import AttentionProcessor |
| from core.cache import AttentionCache |
| from config import Config |
| from visualization.d3_viz import create_d3_visualization |
|
|
class TokenVisualizerApp:
    """Application state for the attention visualizer.

    Owns the model handler, the attention-result cache, and the payload of
    the most recent generation (tokens + per-step attention matrices) that
    the D3 visualization is rendered from.
    """

    def __init__(self):
        self.config = Config()
        self.model_handler = ModelHandler(config=self.config)
        self.cache = AttentionCache(max_size=self.config.CACHE_SIZE)
        # Most recent generation payload; None until a generation succeeds.
        self.current_data = None
        self.model_loaded = False

    def load_model(self, model_name: str = None) -> str:
        """Load a model and return a human-readable status message.

        Args:
            model_name: Hugging Face model id. Falls back to the configured
                default when empty or None.

        Returns:
            A status string describing success (with parameter count and
            device) or failure (with the handler's error message).
        """
        if not model_name:
            model_name = self.config.DEFAULT_MODEL

        success, message = self.model_handler.load_model(model_name)
        self.model_loaded = success

        if success:
            model_info = self.model_handler.get_model_info()
            return (
                f"✅ Model loaded: {model_name}\n"
                f"📊 Parameters: {model_info['num_parameters']:,}\n"
                f"🖥️ Device: {model_info['device']}"
            )
        else:
            return f"❌ Failed to load model: {message}"

    def generate_and_visualize(
        self,
        prompt: str,
        max_tokens: int,
        threshold: float,
        temperature: float,
        normalization: str,
        progress=gr.Progress()
    ):
        """Generate text with attention capture and store it in the cache.

        Always returns a 1-tuple ``(message,)``: the generation summary on
        success, or an error description on failure.  (Earlier revisions
        returned 3-tuples on the error paths, which broke the single-value
        unpacking done by the UI callback.)

        Args:
            prompt: Input text fed to the model.
            max_tokens: Maximum number of tokens to generate.
            threshold: Accepted for interface compatibility; the threshold
                is applied later when building the visualization, not here.
            temperature: Sampling temperature.
            normalization: "separate" for per-source normalization,
                anything else for joint normalization.
            progress: Gradio progress reporter (injected by Gradio).
        """
        if not self.model_loaded:
            return ("Please load a model first!",)

        if not prompt.strip():
            return ("Please enter a prompt!",)

        progress(0.2, desc="Checking cache...")

        # The cache key covers everything that influences the output.
        cache_key = self.cache.get_key(
            prompt, max_tokens,
            self.model_handler.model_name,
            temperature
        )
        cached = self.cache.get(cache_key)

        if cached:
            progress(0.5, desc="Using cached data...")
            self.current_data = cached
        else:
            progress(0.3, desc="Generating text...")

            attention_data, output_tokens, input_tokens, generated_text = \
                self.model_handler.generate_with_attention(
                    prompt, max_tokens, temperature
                )

            if attention_data is None:
                # On failure, generated_text carries the error description.
                return (f"Generation failed: {generated_text}",)

            progress(0.6, desc="Processing attention...")

            if normalization == "separate":
                attention_matrices = AttentionProcessor.process_attention_separate(
                    attention_data, input_tokens, output_tokens
                )
            else:
                attention_matrices = AttentionProcessor.process_attention_joint(
                    attention_data, input_tokens, output_tokens
                )

            self.current_data = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'attention_matrices': attention_matrices,
                'generated_text': generated_text,
                'attention_data': attention_data
            }

            self.cache.set(cache_key, self.current_data)

        progress(1.0, desc="Complete!")

        info_text = f"📝 Generated: {self.current_data['generated_text']}\n"
        info_text += f"🔢 Input tokens: {len(self.current_data['input_tokens'])}\n"
        info_text += f"🔢 Output tokens: {len(self.current_data['output_tokens'])}"

        return (
            info_text,
        )

    def update_step(self, step_idx: int, threshold: float):
        """No-op placeholder after removing visualization."""
        return None

    def update_threshold(self, threshold: float, normalization: str):
        """No-op placeholder after removing visualization."""
        return None

    def filter_token_connections(self, token_idx: int, token_type: str, threshold: float):
        """Removed visualization; keep placeholder."""
        return None

    def reset_view(self, threshold: float):
        """Removed visualization; keep placeholder."""
        return None

    def on_d3_token_click(self, click_data: str, threshold: float):
        """Removed visualization; keep placeholder for compatibility."""
        return None, gr.update()

    def on_input_token_select(self, token_label: str, threshold: float):
        """Removed visualization; keep placeholder for compatibility."""
        return None

    def prepare_d3_data(self, step_idx: int, threshold: float = 0.01, filter_token: str = None):
        """
        Convert attention data to D3.js-friendly JSON format.

        Args:
            step_idx: Generation step to visualize (0-based); clamped into
                the valid range instead of raising on out-of-range values.
            threshold: Minimum attention weight to include.
            filter_token: Token to filter by (format: "[IN] token" or
                "[OUT] token" or "All tokens").

        Returns:
            dict: JSON structure with nodes and links for D3.js.
        """
        if not self.current_data:
            return {"nodes": [], "links": []}

        input_tokens = self.current_data['input_tokens']
        output_tokens = self.current_data['output_tokens']
        attention_matrices = self.current_data['attention_matrices']

        if not attention_matrices:
            # No steps recorded: nothing to draw (previously this indexed
            # attention_matrices[-1] and raised IndexError).
            return {
                "nodes": [],
                "links": [],
                "step": 0,
                "total_steps": 0,
                "input_count": len(input_tokens),
                "output_count": 0
            }

        # Clamp into [0, total_steps - 1]; negative or oversized indices
        # from the UI slider must not crash the render.
        step_idx = max(0, min(step_idx, len(attention_matrices) - 1))

        # Build nodes: all input tokens, plus outputs generated so far.
        nodes = []

        for i, token in enumerate(input_tokens):
            nodes.append({
                "id": f"input_{i}",
                "token": token,
                "type": "input",
                "index": i
            })

        for i in range(step_idx + 1):
            if i < len(output_tokens):
                nodes.append({
                    "id": f"output_{i}",
                    "token": output_tokens[i],
                    "type": "output",
                    "index": i
                })

        # Resolve the dropdown filter into (type, index).  NOTE(review):
        # matching is by token text, so duplicate tokens resolve to the
        # first occurrence only.
        filter_type = None
        filter_idx = None
        if filter_token and filter_token != "All tokens":
            if filter_token.startswith("[IN] "):
                filter_type = "input"
                filter_token_text = filter_token[5:]
                filter_idx = next((i for i, token in enumerate(input_tokens) if token == filter_token_text), None)
            elif filter_token.startswith("[OUT] "):
                filter_type = "output"
                filter_token_text = filter_token[6:]
                filter_idx = next((i for i, token in enumerate(output_tokens) if token == filter_token_text), None)

        links = []

        # Accumulate links for every step up to and including step_idx.
        for current_step in range(step_idx + 1):
            if current_step < len(attention_matrices):
                step_attention = attention_matrices[current_step]

                # Input -> output attention for this step.
                input_attention = step_attention['input_attention']
                if input_attention is not None:
                    for input_idx in range(len(input_tokens)):
                        if input_idx < len(input_attention):
                            weight = float(input_attention[input_idx])
                            if weight >= threshold:
                                show_link = True
                                if filter_type == "input" and filter_idx is not None:
                                    # Keep only links leaving the filtered input token.
                                    show_link = (input_idx == filter_idx)
                                elif filter_type == "output" and filter_idx is not None:
                                    # Keep only links arriving at the filtered output token.
                                    show_link = (current_step == filter_idx)

                                if show_link:
                                    links.append({
                                        "source": f"input_{input_idx}",
                                        "target": f"output_{current_step}",
                                        "weight": weight,
                                        "type": "input_to_output"
                                    })

                # Output -> output attention (previous outputs only).
                output_attention = step_attention['output_attention']
                if output_attention is not None and current_step > 0:
                    for prev_output_idx in range(current_step):
                        if prev_output_idx < len(output_attention):
                            weight = float(output_attention[prev_output_idx])
                            if weight >= threshold:
                                show_link = True
                                if filter_type == "input" and filter_idx is not None:
                                    # Input filter hides output-to-output links entirely.
                                    show_link = False
                                elif filter_type == "output" and filter_idx is not None:
                                    # Keep links touching the filtered output on either end.
                                    show_link = (prev_output_idx == filter_idx or current_step == filter_idx)

                                if show_link:
                                    links.append({
                                        "source": f"output_{prev_output_idx}",
                                        "target": f"output_{current_step}",
                                        "weight": weight,
                                        "type": "output_to_output"
                                    })

        return {
            "nodes": nodes,
            "links": links,
            "step": step_idx,
            "total_steps": len(attention_matrices),
            "input_count": len(input_tokens),
            "output_count": step_idx + 1
        }

    def create_d3_visualization_html(self, step_idx: int = 0, threshold: float = 0.01, filter_token: str = None):
        """
        Create D3.js visualization HTML for the current data.

        Args:
            step_idx: Generation step to visualize (0-based)
            threshold: Minimum attention weight to include
            filter_token: Token to filter by (format: "[IN] token" or "[OUT] token")

        Returns:
            str: HTML string for D3.js visualization
        """
        if not self.current_data:
            return "<div>No data available. Generate text first!</div>"

        d3_data = self.prepare_d3_data(step_idx, threshold, filter_token)

        viz_html = create_d3_visualization(d3_data)
        return viz_html

    def get_token_choices(self):
        """
        Get list of token choices for dropdown.

        Returns:
            list: List of token strings for dropdown options.  Duplicate
            token texts produce duplicate labels; filtering resolves to
            the first occurrence (see prepare_d3_data).
        """
        if not self.current_data:
            return []

        input_tokens = self.current_data['input_tokens']
        output_tokens = self.current_data['output_tokens']

        choices = ["All tokens"]
        choices.extend([f"[IN] {token}" for token in input_tokens])
        choices.extend([f"[OUT] {token}" for token in output_tokens])

        return choices
|
|
|
|
def create_gradio_interface():
    """Create the Gradio interface.

    Builds the Blocks layout (model loading, generation controls, and the
    D3 attention visualization panel) and wires all event handlers.

    Returns:
        gr.Blocks: the assembled demo, ready for ``.launch()``.
    """
    app = TokenVisualizerApp()

    with gr.Blocks(
        title="Token Attention Visualizer",
        css="""
        /* Default/Light mode styles */
        .main-header {
            text-align: center;
            padding: 2rem 0 3rem 0;
            background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
            border-radius: 1rem;
            margin-bottom: 2rem;
            border: 1px solid #e2e8f0;
        }

        .main-title {
            font-size: 2.5rem;
            font-weight: 700;
            color: #1e293b;
            margin-bottom: 0.5rem;
            background: linear-gradient(135deg, #1e293b 0%, #3b82f6 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
        }

        .main-subtitle {
            font-size: 1.125rem;
            color: #64748b;
            font-weight: 400;
        }

        .section-title {
            font-size: 1.25rem;
            font-weight: 600;
            color: #1e293b;
            margin-bottom: 1.5rem;
            padding-bottom: 0.5rem;
            border-bottom: 2px solid #e2e8f0;
        }

        /* Explicit light mode overrides */
        .light .main-header,
        [data-theme="light"] .main-header {
            background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
            border: 1px solid #e2e8f0;
        }

        .light .main-title,
        [data-theme="light"] .main-title {
            color: #1e293b;
            background: linear-gradient(135deg, #1e293b 0%, #3b82f6 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
        }

        .light .main-subtitle,
        [data-theme="light"] .main-subtitle {
            color: #64748b;
        }

        .light .section-title,
        [data-theme="light"] .section-title {
            color: #1e293b;
            border-bottom: 2px solid #e2e8f0;
        }

        /* Dark mode styles with higher specificity */
        .dark .main-header,
        [data-theme="dark"] .main-header {
            background: linear-gradient(135deg, #1e293b 0%, #334155 100%) !important;
            border: 1px solid #475569 !important;
        }

        .dark .main-title,
        [data-theme="dark"] .main-title {
            color: #f1f5f9 !important;
            background: linear-gradient(135deg, #f1f5f9 0%, #60a5fa 100%) !important;
            -webkit-background-clip: text !important;
            -webkit-text-fill-color: transparent !important;
            background-clip: text !important;
        }

        .dark .main-subtitle,
        [data-theme="dark"] .main-subtitle {
            color: #cbd5e1 !important;
        }

        .dark .section-title,
        [data-theme="dark"] .section-title {
            color: #f1f5f9 !important;
            border-bottom: 2px solid #475569 !important;
        }

        /* System dark mode - only apply when no explicit theme is set */
        @media (prefers-color-scheme: dark) {
            :root:not([data-theme="light"]) .main-header {
                background: linear-gradient(135deg, #1e293b 0%, #334155 100%);
                border: 1px solid #475569;
            }

            :root:not([data-theme="light"]) .main-title {
                color: #f1f5f9;
                background: linear-gradient(135deg, #f1f5f9 0%, #60a5fa 100%);
                -webkit-background-clip: text;
                -webkit-text-fill-color: transparent;
                background-clip: text;
            }

            :root:not([data-theme="light"]) .main-subtitle {
                color: #cbd5e1;
            }

            :root:not([data-theme="light"]) .section-title {
                color: #f1f5f9;
                border-bottom: 2px solid #475569;
            }
        }

        .load-model-btn {
            background: linear-gradient(135deg, #f97316 0%, #ea580c 100%) !important;
            color: white !important;
            border: none !important;
            font-weight: 600 !important;
            padding: 0.75rem 2rem !important;
            border-radius: 0.5rem !important;
            box-shadow: 0 4px 6px -1px rgba(249, 115, 22, 0.25) !important;
            transition: all 0.2s ease !important;
        }

        .load-model-btn:hover {
            background: linear-gradient(135deg, #ea580c 0%, #dc2626 100%) !important;
            transform: translateY(-1px) !important;
            box-shadow: 0 6px 8px -1px rgba(249, 115, 22, 0.35) !important;
        }
        """
    ) as demo:
        gr.HTML("""
            <div class="main-header">
                <h1 class="main-title">Token Attention Visualizer</h1>
                <p class="main-subtitle">Interactive visualization of attention patterns in Large Language Models</p>
            </div>
        """)

        with gr.Row():
            # Left column: model loading + generation controls.
            with gr.Column(scale=1):
                gr.HTML('<h2 class="section-title">Model & Generation</h2>')

                model_input = gr.Textbox(
                    label="Model Name",
                    value=app.config.DEFAULT_MODEL,
                    placeholder="Enter Hugging Face model name..."
                )
                load_model_btn = gr.Button("Load Model", variant="primary", elem_classes=["load-model-btn"])

                model_status = gr.Textbox(
                    label="Model Status",
                    value="No model loaded",
                    interactive=False,
                    lines=2
                )

                prompt_input = gr.Textbox(
                    label="Prompt",
                    value=app.config.DEFAULT_PROMPT,
                    lines=3,
                    placeholder="Enter your prompt here..."
                )

                max_tokens_input = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=app.config.DEFAULT_MAX_TOKENS,
                    step=1,
                    label="Max Tokens"
                )

                temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=app.config.DEFAULT_TEMPERATURE,
                    step=0.1,
                    label="Temperature"
                )

                generate_btn = gr.Button("Generate", variant="primary", size="lg")

                generated_info = gr.Textbox(
                    label="Generation Info",
                    interactive=False,
                    lines=4
                )

                gr.HTML('<h2 class="section-title">Visualization Controls</h2>')

                step_slider = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=0,
                    step=1,
                    label="Generation Step",
                    info="Navigate through generation steps"
                )

                threshold_slider = gr.Slider(
                    minimum=0.001,
                    maximum=0.5,
                    value=0.01,
                    step=0.001,
                    label="Attention Threshold",
                    info="Filter weak connections"
                )

                token_dropdown = gr.Dropdown(
                    choices=["All tokens"],
                    value="All tokens",
                    label="Filter by Token",
                    info="Select a token to highlight"
                )

            # Right column: the D3 visualization panel.
            with gr.Column(scale=2):
                gr.HTML('<h2 class="section-title">Attention Visualization</h2>')

                d3_visualization = gr.HTML(
                    value="""<div style='height: 700px; display: flex; align-items: center; justify-content: center; font-size: 16px;'>
                        <div style='text-align: center;'>
                            <div style='font-size: 3rem; margin-bottom: 16px; opacity: 0.5;'>โช</div>
                            <div style='font-weight: 500; margin-bottom: 8px;'>Ready to visualize</div>
                            <div>Generate text to see attention patterns</div>
                        </div>
                    </div>"""
                )

        with gr.Accordion("๐ How to Use", open=False):
            gr.Markdown(
                """
                ### Instructions:
                1. **Load a model** from Hugging Face (default: Llama-3.2-1B)
                2. **Enter a prompt** and configure generation settings
                3. **Click Generate** to create text and visualize attention
                4. **Interact with the visualization:**
                   - Use the **step slider** to navigate through generation steps
                   - Adjust the **threshold** to filter weak connections
                   - Click on **tokens** in the plot to filter their connections
                   - Click **Reset View** to show all connections

                ### Understanding the Visualization:
                - **Blue lines**: Attention from input to output tokens
                - **Orange curves**: Attention between output tokens
                - **Line thickness**: Represents attention weight strength
                - **Node colors**: Blue = input tokens, Coral = generated tokens
                """
            )

        load_model_btn.click(
            fn=app.load_model,
            inputs=[model_input],
            outputs=[model_status]
        )

        def _generate(prompt, max_tokens, threshold, temperature):
            """Run generation, then refresh viz, dropdown, and step slider."""
            result = app.generate_and_visualize(
                prompt, max_tokens, threshold, temperature, "separate"
            )
            # Success returns a 1-tuple (info,); some revisions return
            # (None, message, None) on the error paths.  Handle both shapes
            # instead of `info, = ...`, which raised ValueError on errors.
            info = result[0] if len(result) == 1 else result[1]

            # Show the final step by default; clamp to 0 when no data exists.
            max_steps = max(0, len(app.current_data['attention_matrices']) - 1) if app.current_data else 0
            # Threshold 0.01 matches the threshold_slider's initial value.
            viz_html = app.create_d3_visualization_html(step_idx=max_steps, threshold=0.01)
            token_choices = app.get_token_choices()

            return info, viz_html, gr.update(choices=token_choices, value="All tokens"), gr.update(maximum=max_steps, value=max_steps)

        generate_btn.click(
            fn=_generate,
            inputs=[
                prompt_input,
                max_tokens_input,
                gr.State(app.config.DEFAULT_THRESHOLD),
                temperature_input
            ],
            outputs=[generated_info, d3_visualization, token_dropdown, step_slider]
        )

        def _update_visualization(step_idx, threshold, filter_token="All tokens"):
            """Update visualization when step or threshold changes."""
            viz_html = app.create_d3_visualization_html(step_idx=int(step_idx), threshold=threshold, filter_token=filter_token)
            return viz_html

        def _filter_by_token(selected_token, step_idx, threshold):
            """Update visualization when token filter changes."""
            viz_html = app.create_d3_visualization_html(step_idx=int(step_idx), threshold=threshold, filter_token=selected_token)
            return viz_html

        step_slider.change(
            fn=_update_visualization,
            inputs=[step_slider, threshold_slider, token_dropdown],
            outputs=[d3_visualization]
        )

        threshold_slider.change(
            fn=_update_visualization,
            inputs=[step_slider, threshold_slider, token_dropdown],
            outputs=[d3_visualization]
        )

        token_dropdown.change(
            fn=_filter_by_token,
            inputs=[token_dropdown, step_slider, threshold_slider],
            outputs=[d3_visualization]
        )

        # Auto-load the default model when the page opens.
        demo.load(
            fn=app.load_model,
            inputs=[gr.State(app.config.DEFAULT_MODEL)],
            outputs=[model_status]
        )

    return demo
|
|
if __name__ == "__main__":
    # Report the compute device before building the UI.
    if torch.cuda.is_available():
        print(f"✅ CUDA available: {torch.cuda.get_device_name(0)}")
    else:
        print("⚠️ CUDA not available, using CPU")

    demo = create_gradio_interface()
    # For deployment, launch() accepts share=True (public URL),
    # server_name="0.0.0.0" (external connections), server_port=7860,
    # and inbrowser=False; the defaults are fine for local use.
    demo.launch()