Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import spaces | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from threading import Lock | |
| from huggingface_hub import snapshot_download | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import random | |
# --- 1. CONFIG & SETUP ---
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Shared handles. Both start empty and are filled in by load_model(), which
# takes model_lock while loading.
model_lock = Lock()
model, tokenizer = None, None

# Fetch the model repository at import time. Any failure is only printed as a
# warning, so the module still finishes importing.
print(f"⬇️ Downloading {MODEL_ID}...")
try:
    snapshot_download(repo_id=MODEL_ID)
    print("✅ Download Ready.")
except Exception as download_error:
    print(f"⚠️ Warning: {download_error}")
# Qwen 1.5B is drawn as 28 layers.
NUM_LAYERS = 28

# Visual settings
NODES_PER_LAYER = 10  # Each layer is drawn as 10 nodes (abstract representation)
LINES_PER_LAYER = 15  # Random wires drawn between each pair of adjacent layers

# --- Node geometry ---
# Layers are spaced 2 units apart along X; the nodes of one layer sit evenly
# on a circle of radius 4 in the Y/Z plane. Nodes are stored layer by layer,
# so node k of layer j lives at flat index j * NODES_PER_LAYER + k.
node_coords_x = []
node_coords_y = []
node_coords_z = []

for flat_idx in range(NUM_LAYERS * NODES_PER_LAYER):
    layer_no, slot = divmod(flat_idx, NODES_PER_LAYER)
    angle = (2 * np.pi * slot) / NODES_PER_LAYER
    node_coords_x.append(layer_no * 2)
    node_coords_y.append(4 * np.cos(angle))
    node_coords_z.append(4 * np.sin(angle))

# --- Edge geometry ---
# Each wire contributes (start, end, None) to every axis list; the None
# separates one line segment from the next inside a single Plotly trace.
edge_x, edge_y, edge_z = [], [], []

for src_layer in range(NUM_LAYERS - 1):
    src_base = src_layer * NODES_PER_LAYER
    dst_base = src_base + NODES_PER_LAYER
    for _ in range(LINES_PER_LAYER):
        # Random node in this layer, then a random node in the next one.
        src = src_base + random.randint(0, NODES_PER_LAYER - 1)
        dst = dst_base + random.randint(0, NODES_PER_LAYER - 1)
        for axis_coords, axis_edges in (
            (node_coords_x, edge_x),
            (node_coords_y, edge_y),
            (node_coords_z, edge_z),
        ):
            axis_edges.extend([axis_coords[src], axis_coords[dst], None])
# --- 2. BACKEND LOGIC ---
def load_model():
    """Load the model and tokenizer into the module globals, at most once.

    Safe to call on every request: after the first successful load it returns
    immediately. Loading happens under ``model_lock`` so concurrent callers do
    not load the weights twice.

    Raises:
        Whatever ``from_pretrained`` raises. In that case neither global is
        assigned, so a later call retries from scratch.
    """
    global model, tokenizer

    # Fast path: both handles already published.
    if model is not None and tokenizer is not None:
        return

    with model_lock:
        # Re-check under the lock: another caller may have finished loading
        # while this one was waiting to acquire it.
        if model is not None and tokenizer is not None:
            return

        print("Loading weights...")
        loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float16, device_map="auto"
        )

        # Publish only after both loads succeeded, tokenizer first, so no
        # caller can observe a model without its tokenizer.
        tokenizer = loaded_tokenizer
        model = loaded_model
# Session State will store: {'tokens': [str, ...], 'activations': [[layer0_val, layer1_val, ...], ...]}
def run_inference(prompt):
    """Greedily generate a reply to *prompt* while recording per-layer activity.

    This is a generator meant to be used as a Gradio event handler. Every
    yield is a 3-tuple ``(text, slider_update, session_state)``:

    * while generating: the text produced so far, with the slider kept hidden
      and a ``gr.update`` in the state slot;
    * last yield: the full text, a slider update spanning the generated
      tokens, and the session dict ``{"tokens": [...], "activations": [...]}``
      holding one row of ``NUM_LAYERS`` values per generated token.

    For each forward pass, the recorded value of a layer is the L2 norm of
    that layer's output at the last sequence position.

    Args:
        prompt: User message; it is wrapped in the tokenizer's chat template.

    Yields:
        ``(text, slider_update, session_state)`` tuples as described above.
    """
    load_model()

    # Filled by the forward hooks on every model call: layer index -> norm.
    current_step_activations = {}

    def hook_fn(layer_idx):
        def _hook(mod, inp, out):
            h = out[0] if isinstance(out, tuple) else out
            with torch.no_grad():
                # The model is loaded in float16, whose largest finite value
                # is 65504; reducing in float32 keeps the norm from
                # overflowing to inf when hidden-state values are large.
                last_token_state = h[:, -1, :].float()
                current_step_activations[layer_idx] = torch.linalg.vector_norm(
                    last_token_state
                ).item()
        return _hook

    max_new_tokens = 100
    hooks = []
    try:
        # 1. Setup hooks. Registration lives inside the try so the hooks are
        # removed from the shared model on every exit path, including a
        # tokenization error or the generator being closed at its first yield.
        for i, layer in enumerate(model.model.layers):
            hooks.append(layer.register_forward_hook(hook_fn(i)))

        # 2. Tokenize
        msgs = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            msgs, return_tensors="pt", add_generation_prompt=True
        ).to(model.device)

        # Storage for history
        history_tokens = []  # decoded string of each generated token
        history_acts = []    # one list of NUM_LAYERS floats per generated token
        generated_ids = []   # raw ids, used to decode the running text

        # 3. Generate loop (greedy decoding with a KV cache)
        past_key_values = None
        accumulated_text = ""

        yield "Thinking...", gr.update(visible=False), gr.update(visible=False)  # Status update

        for _ in range(max_new_tokens):
            current_step_activations.clear()

            with torch.no_grad():
                if past_key_values is None:
                    # First pass: run the whole prompt.
                    out = model(input_ids)
                else:
                    # Later passes: feed only the newest token plus the cache.
                    out = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)

            logits = out.logits[:, -1, :]
            past_key_values = out.past_key_values
            next_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
            token_id = next_id.item()
            token_str = tokenizer.decode(next_id[0], skip_special_tokens=True)

            # Store data
            generated_ids.append(token_id)
            history_tokens.append(token_str)

            # Decode the whole generated sequence rather than concatenating
            # per-token strings: a character whose bytes are split across
            # several tokens only decodes correctly once they are together.
            accumulated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

            # Order activations by layer index; a layer whose hook did not
            # fire during this pass is recorded as 0.0.
            step_acts = [current_step_activations.get(i, 0.0) for i in range(NUM_LAYERS)]
            history_acts.append(step_acts)

            input_ids = torch.cat([input_ids, next_id], dim=-1)

            yield accumulated_text, gr.update(visible=False), gr.update(visible=False)

            if token_id == tokenizer.eos_token_id:
                break

        # FINISHED
        print(f"Generated {len(history_tokens)} tokens.")

        # Package history for the state
        session_data = {
            "tokens": history_tokens,
            "activations": history_acts
        }

        # Final yield: text, slider update (max = number of tokens - 1), session data
        last_step = len(history_tokens) - 1
        yield accumulated_text, gr.update(minimum=0, maximum=last_step, value=0, visible=True, label=f"Time Travel (0-{last_step})"), session_data
    finally:
        for h in hooks:
            h.remove()
# --- 3. VISUALIZER FUNCTION ---
def render_network_at_step(step_idx, session_data):
    """Build the 3D network figure for one generated token.

    Args:
        step_idx: Position on the timeline (index of the generated token).
            Values outside the recorded range are clamped; a float such as
            ``3.0`` is truncated to an int. ``None`` yields no figure.
        session_data: Dict with ``"tokens"`` (list of token strings) and
            ``"activations"`` (one list of per-layer values per token), as
            produced by ``run_inference``. Falsy values yield no figure.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when there is nothing
        to draw (no session, no step, or an empty history).
    """
    if not session_data or step_idx is None:
        return None

    tokens = session_data["tokens"]
    acts_history = session_data["activations"]

    # An empty history has no valid index to clamp to, so there is no frame.
    last_step = min(len(tokens), len(acts_history)) - 1
    if last_step < 0:
        return None

    # int() because list indices must be integers and a slider may hand back
    # a float; then clamp into the recorded range.
    step_idx = min(max(int(step_idx), 0), last_step)

    current_token = tokens[step_idx]
    current_acts = acts_history[step_idx]  # One value per layer

    # --- Prepare Visual Attributes ---
    # Each layer value is copied to all NODES_PER_LAYER nodes of that layer,
    # so a strongly active layer lights up as a whole ring.
    node_colors = []
    node_sizes = []

    # Normalize against the strongest layer of the current step.
    max_act = max(current_acts) if current_acts else 1.0

    for layer_i in range(NUM_LAYERS):
        # A row shorter than NUM_LAYERS is treated as 0.0 for missing layers.
        layer_act = current_acts[layer_i] if layer_i < len(current_acts) else 0.0
        intensity = layer_act / max_act if max_act > 0 else 0
        node_sizes.extend([4 + (intensity * 8)] * NODES_PER_LAYER)  # Size varies 4 to 12
        node_colors.extend([intensity] * NODES_PER_LAYER)  # Position on the colorscale

    # --- Construct Plotly Figure ---
    fig = go.Figure()

    # 1. Edges (Static wires)
    fig.add_trace(go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        mode='lines',
        line=dict(color='rgba(100, 150, 255, 0.15)', width=1),  # Faint blue lines
        hoverinfo='none'
    ))

    # 2. Nodes (Dynamic Lights)
    fig.add_trace(go.Scatter3d(
        x=node_coords_x,
        y=node_coords_y,
        z=node_coords_z,
        mode='markers',
        marker=dict(
            size=node_sizes,
            color=node_colors,
            colorscale='Electric',  # Distinct AI look
            cmin=0, cmax=1,
            opacity=0.9
        ),
        text=[f"Layer {i//NODES_PER_LAYER}" for i in range(len(node_coords_x))],
        hoverinfo='text'
    ))

    # Layout styling to match the reference image (Dark Void)
    camera = dict(
        up=dict(x=0, y=1, z=0),
        eye=dict(x=0.5, y=2.5, z=0.5)  # Side view
    )
    fig.update_layout(
        title=dict(
            text=f"Token: '{current_token}'",
            font=dict(color="white", size=24)
        ),
        template="plotly_dark",
        paper_bgcolor='black',
        plot_bgcolor='black',
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            bgcolor='black',
            camera=camera
        ),
        margin=dict(l=0, r=0, b=0, t=50),
    )
    return fig
# Slider callback: thin adapter around the renderer.
def on_slider_change(step, session_state):
    """Return the network figure for timeline position *step*.

    Both arguments are forwarded unchanged to ``render_network_at_step`` and
    its result is returned as is.
    """
    figure = render_network_at_step(step, session_state)
    return figure
# --- 4. UI BUILD ---
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Per-session storage for the generation history (tokens + activations)
    session_state = gr.State()

    gr.Markdown("# 🕸️ Neural Time-Traveler")
    gr.Markdown("1. **Generate** text. 2. **Use the Slider** to travel through time and see the network state for each token.")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Input", value="Explain how neural networks learn.", lines=2)
            gen_btn = gr.Button("RUN GENERATION", variant="primary")
            # This is the Time Slider - initially hidden; run_inference
            # reveals it and sets its range once generation has finished.
            time_slider = gr.Slider(label="Timeline (Tokens)", minimum=0, maximum=10, step=1, visible=False)
            output_text = gr.Textbox(label="Full Output", lines=8, interactive=False)
        with gr.Column(scale=3):
            # Large visualization area
            network_plot = gr.Plot(label="Internal State Visualization", container=True)

    # Logic:
    # 1. Click Button -> Run Model -> Update Text + Unhide Slider + Save State
    # 2. Generation finished -> Render the frame at the slider position
    # 3. Slider Change -> Read State -> Update Plot
    generation_event = gen_btn.click(
        fn=run_inference,
        inputs=prompt,
        outputs=[output_text, time_slider, session_state]
    )

    # Render explicitly once generation completes. The final slider update
    # resets the value to 0, and if the slider was already at 0 its change
    # event cannot be relied on to fire, which would leave the plot empty or
    # showing the previous run. If the change event does fire as well, the
    # same frame is simply rendered twice.
    generation_event.then(
        fn=on_slider_change,
        inputs=[time_slider, session_state],
        outputs=network_plot
    )

    # When the slider moves, show the frame for the selected token
    time_slider.change(
        fn=on_slider_change,
        inputs=[time_slider, session_state],
        outputs=network_plot
    )

if __name__ == "__main__":
    demo.launch()