# Source: 3dgraphllm / app.py by diamond-in
# Last change: "Update app.py" (commit 4c74fbe, verified)
import gradio as gr
import torch
import spaces
import numpy as np
import plotly.graph_objects as go
from threading import Lock
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
import random
# --- 1. CONFIG & SETUP ---
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"


def _prefetch_weights(repo_id):
    """Best-effort download of *repo_id* into the local Hugging Face cache.

    Failures are only reported, never raised: the later from_pretrained() calls
    in load_model() can still fetch the weights themselves.
    """
    print(f"⬇️ Downloading {repo_id}...")
    try:
        snapshot_download(repo_id=repo_id)
        print("✅ Download Ready.")
    except Exception as err:
        print(f"⚠️ Warning: {err}")


_prefetch_weights(MODEL_ID)

# Lazily-populated globals: load_model() fills `model` and `tokenizer` on first
# use, serialising the load through `model_lock`.
model_lock = Lock()
model = None
tokenizer = None
# Qwen 1.5B has 28 decoder layers; one visual "ring" of nodes is drawn per layer.
NUM_LAYERS = 28
# Visual settings
NODES_PER_LAYER = 10  # nodes drawn per layer (abstract representation, all share the layer's value)
LINES_PER_LAYER = 15  # random edges drawn between consecutive layers for a "dense" look

# Pre-calculate network geometry (one X/Y/Z triple per node, layer-major order).
# Layers are stacked along X, 2 units apart; within a layer the nodes sit evenly
# spaced on a circle of radius 4 in the Y/Z plane.
_RING_RADIUS = 4
_ring_angles = [(2 * np.pi * slot) / NODES_PER_LAYER for slot in range(NODES_PER_LAYER)]
node_coords_x = [layer * 2 for layer in range(NUM_LAYERS) for _ in _ring_angles]
node_coords_y = [_RING_RADIUS * np.cos(angle) for _ in range(NUM_LAYERS) for angle in _ring_angles]
node_coords_z = [_RING_RADIUS * np.sin(angle) for _ in range(NUM_LAYERS) for angle in _ring_angles]
# Pre-calculate connections (edges) between each pair of consecutive layers.
# All edges go into one Plotly line trace; a None entry breaks the polyline, so
# each edge contributes (start, end, None) to every coordinate list.
edge_x, edge_y, edge_z = [], [], []
for src_layer in range(NUM_LAYERS - 1):
    src_base = src_layer * NODES_PER_LAYER
    dst_base = src_base + NODES_PER_LAYER
    for _ in range(LINES_PER_LAYER):
        # Endpoints come from the module-level `random` RNG (no seed is set here),
        # so the wiring is fixed for the process lifetime but differs per start.
        src = src_base + random.randint(0, NODES_PER_LAYER - 1)
        dst = dst_base + random.randint(0, NODES_PER_LAYER - 1)
        for coords, target in (
            (node_coords_x, edge_x),
            (node_coords_y, edge_y),
            (node_coords_z, edge_z),
        ):
            target.extend((coords[src], coords[dst], None))
# --- 2. BACKEND LOGIC ---
def load_model():
    """Load the global ``model`` and ``tokenizer`` once, on first use.

    Safe to call from concurrent requests. Once the model is published, the fast
    path returns without locking. Otherwise ``model_lock`` ensures only one caller
    downloads and loads the weights; the others wait and then reuse them.
    """
    global model, tokenizer
    if model is not None:  # fast path: already loaded
        return
    with model_lock:
        # Re-check under the lock: another request may have finished loading while
        # this one was waiting, and loading again would duplicate the weights.
        if model is not None:
            return
        print("Loading weights...")
        new_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float16, device_map="auto"
        )
        new_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        # Publish the tokenizer before the model: the unlocked fast path above only
        # tests `model`, so it must never observe a model without its tokenizer.
        tokenizer = new_tokenizer
        model = new_model
# Session State will store: {'tokens': [], 'activations': [[layer0_val, ...], [layer0_val...]]}
def run_inference(prompt):
    """Greedily generate up to 100 tokens for *prompt*, recording per-layer activity.

    A forward hook on every layer in ``model.model.layers`` records the L2 norm of
    that layer's output hidden state at the last sequence position, once per
    decoding step.

    Yields ``(text, slider_update, state)`` triples for the Gradio outputs
    ``[output_text, time_slider, session_state]``:

    * while generating, ``text`` is the decoded output so far and both other
      outputs receive ``gr.update(visible=False)``;
    * the final yield makes the slider visible with range ``0..len(tokens)-1``
      and delivers ``{"tokens": [...], "activations": [[...], ...]}``, where each
      activation row holds ``NUM_LAYERS`` values for one generated token.
    """
    load_model()

    # layer index -> activation norm for the decoding step currently running
    current_step_activations = {}

    def hook_fn(layer_idx):
        def _hook(mod, inp, out):
            h = out[0] if isinstance(out, tuple) else out
            with torch.no_grad():
                # Upcast before reducing: the model runs in float16, and the norm of
                # a large hidden-state vector can exceed float16's range (inf).
                current_step_activations[layer_idx] = torch.norm(h[:, -1, :].float()).item()
        return _hook

    # Stop on the tokenizer's eos id and on any extra eos ids in the generation
    # config (which may hold an int or a list of ids).
    stop_ids = {tokenizer.eos_token_id}
    config_eos = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
    if isinstance(config_eos, int):
        stop_ids.add(config_eos)
    elif config_eos:
        stop_ids.update(config_eos)

    max_new_tokens = 100
    history_tokens = []  # decoded text of each generated token (slider frames)
    history_acts = []    # one list of NUM_LAYERS norms per generated token
    generated_ids = []   # raw ids of the continuation, decoded together for the text box
    accumulated_text = ""
    past_key_values = None

    hooks = []
    # Everything after the first hook registration lives inside try/finally, so the
    # hooks are removed even if tokenization fails or the consumer closes the
    # generator at any yield (including the initial "Thinking..." one).
    try:
        for i, layer in enumerate(model.model.layers):
            hooks.append(layer.register_forward_hook(hook_fn(i)))

        msgs = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            msgs, return_tensors="pt", add_generation_prompt=True
        ).to(model.device)

        yield "Thinking...", gr.update(visible=False), gr.update(visible=False)  # Status update

        for _ in range(max_new_tokens):
            current_step_activations.clear()
            with torch.no_grad():
                if past_key_values is None:
                    out = model(input_ids)  # first step: run the whole prompt
                else:
                    # later steps: feed only the newest token and reuse the KV cache
                    out = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)
            logits = out.logits[:, -1, :]
            past_key_values = out.past_key_values
            next_id = torch.argmax(logits, dim=-1).unsqueeze(-1)  # greedy pick, (batch, 1)
            token_id = next_id.item()

            generated_ids.append(token_id)
            history_tokens.append(tokenizer.decode([token_id], skip_special_tokens=True))
            # Order activations by layer index; a layer whose hook did not fire is 0.0.
            history_acts.append([current_step_activations.get(i, 0.0) for i in range(NUM_LAYERS)])
            # Decode the whole continuation rather than concatenating per-token strings:
            # a multi-byte UTF-8 character can span several tokens, and decoding those
            # tokens one at a time yields U+FFFD replacement characters.
            accumulated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

            input_ids = torch.cat([input_ids, next_id], dim=-1)
            yield accumulated_text, gr.update(visible=False), gr.update(visible=False)
            if token_id in stop_ids:
                break

        # FINISHED: reveal the slider (one position per generated token) and hand the
        # recorded history to the session state.
        print(f"Generated {len(history_tokens)} tokens.")
        last_idx = len(history_tokens) - 1
        session_data = {
            "tokens": history_tokens,
            "activations": history_acts,
        }
        yield accumulated_text, gr.update(minimum=0, maximum=last_idx, value=0, visible=True, label=f"Time Travel (0-{last_idx})"), session_data
    finally:
        for handle in hooks:
            handle.remove()
# --- 3. VISUALIZER FUNCTION ---
def render_network_at_step(step_idx, session_data):
    """Build the 3D network figure for one generated token.

    Args:
        step_idx: Index of the generated token to show (the slider value). It is
            rounded to an int and clamped into ``[0, len(tokens) - 1]``.
        session_data: Dict with keys ``"tokens"`` (decoded token strings) and
            ``"activations"`` (one list of per-layer values per token), or a
            falsy value before any generation has run.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when there is nothing to draw
        (no session data, no step index, or an empty token history).
    """
    if not session_data or step_idx is None:
        return None
    tokens = session_data["tokens"]
    acts_history = session_data["activations"]
    if not tokens:
        return None
    # Slider values may arrive as floats (e.g. 3.0); list indexing needs an int.
    step_idx = min(max(int(round(step_idx)), 0), len(tokens) - 1)

    current_token = tokens[step_idx]
    current_acts = acts_history[step_idx]  # one value per layer (NUM_LAYERS entries)

    # --- Prepare Visual Attributes ---
    # Each layer's single value is broadcast to all NODES_PER_LAYER nodes of its ring,
    # so when a layer is active every node in that layer lights up together.
    node_colors = []
    node_sizes = []
    # Normalize by this step's strongest layer so the brightest layer maps to 1.0.
    max_act = max(current_acts) if current_acts else 1.0
    for layer_i in range(NUM_LAYERS):
        intensity = current_acts[layer_i] / max_act if max_act > 0 else 0
        node_sizes.extend([4 + (intensity * 8)] * NODES_PER_LAYER)  # marker size 4..12
        node_colors.extend([intensity] * NODES_PER_LAYER)

    # --- Construct Plotly Figure ---
    fig = go.Figure()
    # 1. Edges (static wires, identical for every step)
    fig.add_trace(go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        mode='lines',
        line=dict(color='rgba(100, 150, 255, 0.15)', width=1),  # Faint blue lines
        hoverinfo='none'
    ))
    # 2. Nodes (colour and size driven by the normalized layer intensity)
    fig.add_trace(go.Scatter3d(
        x=node_coords_x,
        y=node_coords_y,
        z=node_coords_z,
        mode='markers',
        marker=dict(
            size=node_sizes,
            color=node_colors,
            colorscale='Electric',
            cmin=0, cmax=1,
            opacity=0.9
        ),
        text=[f"Layer {i//NODES_PER_LAYER}" for i in range(len(node_coords_x))],
        hoverinfo='text'
    ))
    # Dark, axis-free scene viewed from the side.
    camera = dict(
        up=dict(x=0, y=1, z=0),
        eye=dict(x=0.5, y=2.5, z=0.5)  # Side view
    )
    fig.update_layout(
        title=dict(
            text=f"Token: '{current_token}'",
            font=dict(color="white", size=24)
        ),
        template="plotly_dark",
        paper_bgcolor='black',
        plot_bgcolor='black',
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            bgcolor='black',
            camera=camera
        ),
        margin=dict(l=0, r=0, b=0, t=50),
    )
    return fig
# Gradio callback for slider moves: redraw the network at the selected token.
@spaces.GPU
def on_slider_change(step, session_state):
    """Return the network figure for slider position *step* of *session_state*.

    NOTE(review): this callback only builds a Plotly figure and does no model work,
    while run_inference (which runs the model) carries no @spaces.GPU decorator —
    confirm the decorator placement is intentional.
    """
    figure = render_network_at_step(step_idx=step, session_data=session_state)
    return figure
# --- 4. UI BUILD ---
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Per-session generation history produced by run_inference:
    # {"tokens": [...], "activations": [[...], ...]}
    session_state = gr.State()
    gr.Markdown("# 🕸️ Neural Time-Traveler")
    gr.Markdown("1. **Generate** text. 2. **Use the Slider** to travel through time and see the network state for each token.")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Input", value="Explain how neural networks learn.", lines=2)
            gen_btn = gr.Button("RUN GENERATION", variant="primary")
            # Time slider over the generated tokens; hidden until a generation finishes.
            time_slider = gr.Slider(label="Timeline (Tokens)", minimum=0, maximum=10, step=1, visible=False)
            output_text = gr.Textbox(label="Full Output", lines=8, interactive=False)
        with gr.Column(scale=3):
            # Large visualization area
            network_plot = gr.Plot(label="Internal State Visualization", container=True)

    # Event wiring:
    # 1. Click -> stream the generation, then reveal the slider and store the history.
    # 2. Once generation completes, render frame 0 explicitly. The final update resets
    #    the slider to 0, and when it already was 0 (e.g. on the first run) its value
    #    may not change, so a .change listener alone could leave the plot empty.
    gen_btn.click(
        fn=run_inference,
        inputs=prompt,
        outputs=[output_text, time_slider, session_state],
    ).then(
        fn=on_slider_change,
        inputs=[time_slider, session_state],
        outputs=network_plot,
    )
    # 3. Moving the slider re-renders the network for the selected token.
    time_slider.change(
        fn=on_slider_change,
        inputs=[time_slider, session_state],
        outputs=network_plot,
    )

if __name__ == "__main__":
    demo.launch()