import gradio as gr
import torch
import spaces
import numpy as np
import plotly.graph_objects as go
from threading import Lock
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
import random

# --- 1. CONFIG & SETUP ---
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

print(f"⬇️ Downloading {MODEL_ID}...")
try:
    snapshot_download(repo_id=MODEL_ID)
    print("✅ Download Ready.")
except Exception as e:
    # Best-effort pre-download: from_pretrained() below will fetch again if needed.
    print(f"⚠️ Warning: {e}")

model_lock = Lock()
model = None
tokenizer = None

# We use 28 layers for Qwen 1.5B
NUM_LAYERS = 28

# Visual settings
NODES_PER_LAYER = 10  # Represent each layer as 10 visual nodes (abstract representation)
LINES_PER_LAYER = 15  # Lines between layers to create the "Dense" look

# Pre-calculate Network Geometry (X, Y, Z coords for nodes).
# Structure: layers spread along the X axis; each layer is a ring of
# NODES_PER_LAYER nodes on the Y/Z plane.
node_coords_x, node_coords_y, node_coords_z = [], [], []
for layer_i in range(NUM_LAYERS):
    x_pos = layer_i * 2  # spacing between layers
    for n in range(NODES_PER_LAYER):
        theta = (2 * np.pi * n) / NODES_PER_LAYER
        radius = 4
        node_coords_x.append(x_pos)
        node_coords_y.append(radius * np.cos(theta))
        node_coords_z.append(radius * np.sin(theta))

# Pre-calculate Connections (Edges): flattened coordinate lists with `None`
# separators so Plotly draws many disjoint segments in a single trace.
edge_x, edge_y, edge_z = [], [], []
for layer_i in range(NUM_LAYERS - 1):
    curr_start_idx = layer_i * NODES_PER_LAYER
    next_start_idx = (layer_i + 1) * NODES_PER_LAYER
    # Random dense connections between consecutive layers (visual only).
    for _ in range(LINES_PER_LAYER):
        idx1 = curr_start_idx + random.randint(0, NODES_PER_LAYER - 1)
        idx2 = next_start_idx + random.randint(0, NODES_PER_LAYER - 1)
        edge_x.extend([node_coords_x[idx1], node_coords_x[idx2], None])
        edge_y.extend([node_coords_y[idx1], node_coords_y[idx2], None])
        edge_z.extend([node_coords_z[idx1], node_coords_z[idx2], None])


# --- 2. BACKEND LOGIC ---
def load_model():
    """Lazily load the model and tokenizer exactly once, thread-safely.

    BUGFIX: the original checked `model is not None` only outside the lock,
    so two concurrent callers could both load the weights; it also published
    `model` before `tokenizer`, letting a racing caller observe a loaded
    model with a still-None tokenizer. We re-check under the lock and assign
    `tokenizer` before `model` so the fast-path check is safe.
    """
    global model, tokenizer
    if model is not None:
        return
    with model_lock:
        if model is not None:  # re-check: another thread may have won the race
            return
        print("Loading weights...")
        loaded = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float16, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = loaded  # publish last so `model is not None` implies both are ready


# Session State will store:
# {'tokens': [str, ...], 'activations': [[layer0_val, ..., layer27_val], ...]}
# BUGFIX: decorated with @spaces.GPU — this function runs the CUDA forward
# passes, and on ZeroGPU Spaces the GPU is only attached inside decorated
# functions (the pure-plotting callback had the decorator instead).
@spaces.GPU
def run_inference(prompt):
    """Greedily generate up to 100 tokens, recording per-layer activation norms.

    Gradio streaming generator. Each yield is a 3-tuple for
    (output_text, time_slider, session_state):
      - intermediate yields stream partial text and keep the slider hidden;
      - the final yield unhides/configures the slider and returns the
        recorded history dict as the new session state.
    Hooks are always removed in the `finally` block, even on error.
    """
    load_model()

    # 1. Hooks: capture the L2 norm of each decoder layer's hidden state for
    # the token currently being processed (last position in the sequence).
    current_step_activations = {}

    def hook_fn(layer_idx):
        def _hook(mod, inp, out):
            # Decoder layers may return a tuple; hidden states come first.
            h = out[0] if isinstance(out, tuple) else out
            with torch.no_grad():
                current_step_activations[layer_idx] = torch.norm(h[:, -1, :]).item()
        return _hook

    hooks = [
        layer.register_forward_hook(hook_fn(i))
        for i, layer in enumerate(model.model.layers)
    ]

    # 2. Tokenize through the chat template.
    msgs = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        msgs, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)

    history_tokens = []   # decoded token strings, one per step
    history_acts = []     # one list of NUM_LAYERS norms per step

    # 3. Greedy decode loop with KV-cache.
    past_key_values = None
    max_new_tokens = 100
    yield "Thinking...", gr.update(visible=False), gr.update(visible=False)  # status update
    accumulated_text = ""
    try:
        for _ in range(max_new_tokens):
            current_step_activations.clear()
            with torch.no_grad():
                if past_key_values is None:
                    out = model(input_ids)  # prefill: whole prompt
                else:
                    # Incremental decode: feed only the newest token.
                    out = model(
                        input_ids=input_ids[:, -1:],
                        past_key_values=past_key_values,
                    )
            logits = out.logits[:, -1, :]
            past_key_values = out.past_key_values
            next_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
            token_str = tokenizer.decode(next_id[0], skip_special_tokens=True)

            # Store data for this step.
            accumulated_text += token_str
            history_tokens.append(token_str)
            # Order activations by layer index; a layer missing from the dict
            # (hook not fired) falls back to 0.0.
            history_acts.append(
                [current_step_activations.get(i, 0.0) for i in range(NUM_LAYERS)]
            )

            input_ids = torch.cat([input_ids, next_id], dim=-1)
            yield accumulated_text, gr.update(visible=False), gr.update(visible=False)
            if next_id.item() == tokenizer.eos_token_id:
                break

        # FINISHED: enable the slider sized to the generated history.
        print(f"Generated {len(history_tokens)} tokens.")
        session_data = {
            "tokens": history_tokens,
            "activations": history_acts,
        }
        # BUGFIX: clamp to 0 so an empty generation cannot produce maximum=-1.
        last = max(len(history_tokens) - 1, 0)
        yield accumulated_text, gr.update(
            minimum=0,
            maximum=last,
            value=0,
            visible=True,
            label=f"Time Travel (0-{last})",
        ), session_data
    finally:
        # Always detach hooks so repeated runs don't stack them.
        for h in hooks:
            h.remove()
# --- 3. VISUALIZER FUNCTION ---
def render_network_at_step(step_idx, session_data):
    """Build the 3D Plotly figure for one generation step.

    Args:
        step_idx: slider position. BUGFIX: Gradio sliders can deliver floats,
            which would raise TypeError as a list index — coerced to int here.
        session_data: dict with 'tokens' and 'activations' recorded by
            run_inference, or None before any generation has run.

    Returns:
        A plotly Figure, or None when there is nothing to draw.
    """
    if not session_data or step_idx is None:
        return None

    tokens = session_data["tokens"]
    acts_history = session_data["activations"]

    # BUGFIX: an empty history previously fell through the clamping logic to
    # tokens[0] and raised IndexError; bail out instead.
    if not tokens:
        return None

    # Clamp the (possibly float) slider value into the valid index range.
    step_idx = min(max(int(step_idx), 0), len(tokens) - 1)

    current_token = tokens[step_idx]
    current_acts = acts_history[step_idx]  # one value per layer (NUM_LAYERS)

    # --- Prepare Visual Attributes ---
    # Map each layer's scalar activation to all NODES_PER_LAYER visual nodes
    # of that layer: if a layer is active, its whole ring lights up.
    node_colors = []
    node_sizes = []
    # Normalize against this step's peak activation.
    max_act = max(current_acts) if current_acts else 1.0
    for layer_i in range(NUM_LAYERS):
        intensity = current_acts[layer_i] / max_act if max_act > 0 else 0
        for _ in range(NODES_PER_LAYER):
            node_sizes.append(4 + (intensity * 8))  # size varies 4 to 12
            node_colors.append(intensity)           # colorscale input in [0, 1]

    # --- Construct Plotly Figure ---
    fig = go.Figure()

    # 1. Edges (static wires between layers)
    fig.add_trace(go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        mode='lines',
        line=dict(color='rgba(100, 150, 255, 0.15)', width=1),  # faint blue lines
        hoverinfo='none'
    ))

    # 2. Nodes (dynamic lights)
    fig.add_trace(go.Scatter3d(
        x=node_coords_x, y=node_coords_y, z=node_coords_z,
        mode='markers',
        marker=dict(
            size=node_sizes,
            color=node_colors,
            colorscale='Electric',  # distinct AI look
            cmin=0,
            cmax=1,
            opacity=0.9
        ),
        text=[f"Layer {i//NODES_PER_LAYER}" for i in range(len(node_coords_x))],
        hoverinfo='text'
    ))

    # Layout styling to match the reference image (dark void), side-on camera.
    camera = dict(
        up=dict(x=0, y=1, z=0),
        eye=dict(x=0.5, y=2.5, z=0.5)
    )
    fig.update_layout(
        title=dict(
            text=f"Token: '{current_token}'",
            font=dict(color="white", size=24)
        ),
        template="plotly_dark",
        paper_bgcolor='black',
        plot_bgcolor='black',
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            bgcolor='black',
            camera=camera
        ),
        margin=dict(l=0, r=0, b=0, t=50),
    )
    return fig


# NOTE(review): rendering is pure Plotly and does not itself need the GPU;
# the decorator is retained so a GPU worker remains attached for this event
# on ZeroGPU Spaces — confirm whether it belongs on run_inference instead.
@spaces.GPU
def on_slider_change(step, session_state):
    """Slider callback: re-render the network for the selected step."""
    return render_network_at_step(step, session_state)


# --- 4. UI BUILD ---
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Per-session store for the generation history dict.
    session_state = gr.State()

    gr.Markdown("# 🕸️ Neural Time-Traveler")
    gr.Markdown("1. **Generate** text. 2. **Use the Slider** to travel through time and see the network state for each token.")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Input", value="Explain how neural networks learn.", lines=2)
            gen_btn = gr.Button("RUN GENERATION", variant="primary")
            # Time slider — hidden until a generation finishes.
            time_slider = gr.Slider(label="Timeline (Tokens)", minimum=0, maximum=10, step=1, visible=False)
            output_text = gr.Textbox(label="Full Output", lines=8, interactive=False)
        with gr.Column(scale=3):
            # Large visualization area.
            network_plot = gr.Plot(label="Internal State Visualization", container=True)

    # Logic:
    # 1. Click Button -> run model -> stream text, then unhide slider + save state.
    gen_btn.click(
        fn=run_inference,
        inputs=prompt,
        outputs=[output_text, time_slider, session_state]
    )
    # 2. Slider change -> read state -> update plot.
    time_slider.change(
        fn=on_slider_change,
        inputs=[time_slider, session_state],
        outputs=network_plot
    )

if __name__ == "__main__":
    demo.launch()