diamond-in commited on
Commit
4c74fbe
·
verified ·
1 Parent(s): a7aae6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -151
app.py CHANGED
@@ -1,200 +1,309 @@
1
  import gradio as gr
2
  import torch
3
  import spaces
4
- import json
5
  import numpy as np
6
  import plotly.graph_objects as go
7
  from threading import Lock
8
  from huggingface_hub import snapshot_download
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
10
 
11
- # --- 1. MODEL DOWNLOAD (Immediate) ---
12
  MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
 
13
  print(f"⬇️ Downloading {MODEL_ID}...")
14
  try:
15
  snapshot_download(repo_id=MODEL_ID)
16
- print("✅ Model downloaded.")
17
  except Exception as e:
18
- print(f"⚠️ Download check ignored: {e}")
19
 
20
- # --- 2. GLOBAL SETUP & COORDINATES ---
21
  model_lock = Lock()
22
  model = None
23
  tokenizer = None
24
- current_activations = {}
25
-
26
- # Pre-calculate 3D Coordinates for the Neural Spiral (28 Layers)
27
- # We calculate this once so we don't waste CPU during generation
28
- num_layers = 28
29
- t_vals = np.linspace(0, 4 * np.pi, num_layers) # 2 loops
30
- radius = 5
31
- node_x = radius * np.cos(t_vals)
32
- node_y = radius * np.sin(t_vals)
33
- node_z = np.linspace(0, 15, num_layers) # Height
34
-
35
- # --- 3. PLOTLY VISUALIZATION FUNCTION ---
36
- def get_neural_plot(token_text, layer_data):
37
- """
38
- Creates an interactive 3D Plotly figure.
39
- """
40
- # 1. Prepare Data
41
- # Get activations for all 28 layers (default 0.1)
42
- acts = [layer_data.get(i, 0.0) for i in range(num_layers)]
43
-
44
- # Normalize for visuals
45
- max_val = max(acts) if acts and max(acts) > 0 else 1.0
46
- norm_acts = [val / max_val for val in acts]
47
-
48
- # 2. Determine Sizes and Colors
49
- # Base size 10, grow up to 25 based on activity
50
- sizes = [10 + (n * 20) for n in norm_acts]
51
-
52
- # 3. Create Scatter3D Trace
53
- trace = go.Scatter3d(
54
- x=node_x,
55
- y=node_y,
56
- z=node_z,
57
- mode='markers+lines', # Nodes connected by lines
58
- marker=dict(
59
- size=sizes,
60
- color=norm_acts, # Color by intensity
61
- colorscale='Viridis', # Cool -> Hot colors
62
- cmin=0, cmax=1,
63
- opacity=0.9,
64
- line=dict(width=1, color='white')
65
- ),
66
- line=dict(
67
- color='#444444',
68
- width=2
69
- ),
70
- hovertext=[f"Layer {i}: {a:.2f}" for i, a in enumerate(acts)],
71
- hoverinfo="text"
72
- )
73
 
74
- # 4. Layout
75
- layout = go.Layout(
76
- title=dict(
77
- text=f"Token Processing: '{token_text}'",
78
- font=dict(color="#00ffcc", size=20)
79
- ),
80
- paper_bgcolor='#0b0f19', # Dark Background
81
- plot_bgcolor='#0b0f19',
82
- scene=dict(
83
- xaxis=dict(visible=False),
84
- yaxis=dict(visible=False),
85
- zaxis=dict(title="Layer Depth", color="white"),
86
- bgcolor='#0b0f19',
87
- camera=dict(
88
- eye=dict(x=1.5, y=1.5, z=0.5) # Initial Camera angle
89
- )
90
- ),
91
- margin=dict(l=0, r=0, b=0, t=40),
92
- template="plotly_dark"
93
- )
94
 
95
- return go.Figure(data=[trace], layout=layout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # --- 4. BACKEND LOGIC ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def load_model():
99
  global model, tokenizer
100
  if model is not None: return
101
-
102
  with model_lock:
103
- print("Loading Model...")
104
  model = AutoModelForCausalLM.from_pretrained(
105
- MODEL_ID,
106
- torch_dtype=torch.float16,
107
- device_map="auto"
108
  )
109
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
110
- print("Loaded.")
111
-
112
- def hook_fn(layer_idx):
113
- def hook(module, inp, out):
114
- if isinstance(out, tuple): h = out[0]
115
- else: h = out
116
- with torch.no_grad():
117
- # L2 Norm of last token
118
- val = torch.norm(h[:, -1, :]).item()
119
- current_activations[layer_idx] = val
120
- return hook
121
-
122
- @spaces.GPU(duration=120)
123
- def generate(prompt):
124
  load_model()
125
 
126
- # Hook Setup
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  hooks = []
128
- current_activations.clear()
129
  for i, layer in enumerate(model.model.layers):
130
- h = layer.register_forward_hook(hook_fn(i))
131
- hooks.append(h)
132
-
133
- # Tokenize
134
  msgs = [{"role": "user", "content": prompt}]
135
  inputs = tokenizer.apply_chat_template(msgs, return_tensors="pt", add_generation_prompt=True).to(model.device)
136
-
137
  input_ids = inputs
 
 
 
 
 
 
138
  past_key_values = None
139
- accum_text = ""
140
-
141
- # Initial Plot (Empty)
142
- yield "", get_neural_plot("Waiting...", {})
143
-
144
- # Generator
145
- for step in range(256):
146
- with torch.no_grad():
147
- if past_key_values is None:
148
- out = model(input_ids)
149
- else:
150
- out = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)
151
-
152
- logits = out.logits[:, -1, :]
153
- past_key_values = out.past_key_values
154
-
155
- next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
156
- token_str = tokenizer.decode(next_token[0], skip_special_tokens=True)
157
- accum_text += token_str
158
-
159
- input_ids = torch.cat([input_ids, next_token], dim=-1)
160
-
161
- # --- YIELD LOGIC ---
162
- # Plotly is slightly heavy to generate every single token (might lag).
163
- # We yield the updated Plot every 4 tokens to keep the UI buttery smooth.
164
- if step % 4 == 0 or next_token.item() == tokenizer.eos_token_id:
165
- fig = get_neural_plot(token_str, current_activations)
166
- yield accum_text, fig
167
- else:
168
- # Use gr.update() effectively skips sending the heavy plot
169
- # Just update text
170
- yield accum_text, gr.skip()
171
 
172
- if next_token.item() == tokenizer.eos_token_id:
173
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- # Cleanup
176
- for h in hooks: h.remove()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- # --- 5. UI LAYOUT ---
179
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="cyan")) as demo:
180
- gr.Markdown("# 🧠 Qwen 1.5B - Interactive Neural Spiral")
181
- gr.Markdown("*Zoom, Pan, and Rotate with your mouse. Nodes pulse based on AI thought process.*")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  with gr.Row():
184
  with gr.Column(scale=1):
185
- prompt = gr.Textbox(label="User Prompt", value="Write a poem about neural networks.", lines=3)
186
- btn = gr.Button("Generate", variant="primary")
187
- output = gr.Textbox(label="AI Response", lines=10)
 
 
 
 
188
 
189
- with gr.Column(scale=2):
190
- # GRADIO PLOT Component (Supports Plotly Interactivity)
191
- plot_component = gr.Plot(label="Live Neural Activations")
192
 
193
- btn.click(
194
- fn=generate,
 
 
 
 
195
  inputs=prompt,
196
- outputs=[output, plot_component]
 
 
 
 
 
 
 
197
  )
 
 
 
198
 
199
  if __name__ == "__main__":
200
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import spaces
 
4
  import numpy as np
5
  import plotly.graph_objects as go
6
  from threading import Lock
7
  from huggingface_hub import snapshot_download
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ import random
10
 
11
+ # --- 1. CONFIG & SETUP ---
12
  MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
13
+
14
  print(f"⬇️ Downloading {MODEL_ID}...")
15
  try:
16
  snapshot_download(repo_id=MODEL_ID)
17
+ print("✅ Download Ready.")
18
  except Exception as e:
19
+ print(f"⚠️ Warning: {e}")
20
 
 
21
  model_lock = Lock()
22
  model = None
23
  tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # We use 28 layers for Qwen 1.5B
26
+ NUM_LAYERS = 28
27
+ # Visual settings
28
+ NODES_PER_LAYER = 10 # Represent each layer as 10 visual nodes (abstract representation)
29
+ LINES_PER_LAYER = 15 # Lines between layers to create the "Dense" look
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # Pre-calculate Network Geometry (X, Y, Z coords for nodes)
32
+ # Structure: Layers spread along X axis. Nodes spread on Y/Z plane.
33
+ node_coords_x = []
34
+ node_coords_y = []
35
+ node_coords_z = []
36
+
37
+ # Generate positions
38
+ for layer_i in range(NUM_LAYERS):
39
+ x_pos = layer_i * 2 # Spacing between layers
40
+
41
+ # create a ring or grid of nodes for this layer
42
+ for n in range(NODES_PER_LAYER):
43
+ # Circle arrangement
44
+ theta = (2 * np.pi * n) / NODES_PER_LAYER
45
+ radius = 4
46
+ y_pos = radius * np.cos(theta)
47
+ z_pos = radius * np.sin(theta)
48
+
49
+ node_coords_x.append(x_pos)
50
+ node_coords_y.append(y_pos)
51
+ node_coords_z.append(z_pos)
52
 
53
+ # Pre-calculate Connections (Edges)
54
+ # List of (x1, y1, z1, x2, y2, z2) for lines
55
+ edge_x, edge_y, edge_z = [], [], []
56
+
57
+ for layer_i in range(NUM_LAYERS - 1):
58
+ curr_start_idx = layer_i * NODES_PER_LAYER
59
+ next_start_idx = (layer_i + 1) * NODES_PER_LAYER
60
+
61
+ # Create random dense connections
62
+ for _ in range(LINES_PER_LAYER):
63
+ # Pick random start node in current layer
64
+ n1 = random.randint(0, NODES_PER_LAYER - 1)
65
+ # Pick random end node in next layer
66
+ n2 = random.randint(0, NODES_PER_LAYER - 1)
67
+
68
+ idx1 = curr_start_idx + n1
69
+ idx2 = next_start_idx + n2
70
+
71
+ edge_x.extend([node_coords_x[idx1], node_coords_x[idx2], None])
72
+ edge_y.extend([node_coords_y[idx1], node_coords_y[idx2], None])
73
+ edge_z.extend([node_coords_z[idx1], node_coords_z[idx2], None])
74
+
75
+ # --- 2. BACKEND LOGIC ---
76
  def load_model():
77
  global model, tokenizer
78
  if model is not None: return
 
79
  with model_lock:
80
+ print("Loading weights...")
81
  model = AutoModelForCausalLM.from_pretrained(
82
+ MODEL_ID, torch_dtype=torch.float16, device_map="auto"
 
 
83
  )
84
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
85
+
86
+ # Session State will store: {'tokens': [], 'activations': [[layer0_val, ...], [layer0_val...]]}
87
+ def run_inference(prompt):
 
 
 
 
 
 
 
 
 
 
 
88
  load_model()
89
 
90
+ # 1. Setup Hooks
91
+ # We will capture the MEAN activation of each layer for the current token
92
+ current_step_activations = {}
93
+
94
+ def hook_fn(layer_idx):
95
+ def _hook(mod, inp, out):
96
+ if isinstance(out, tuple): h = out[0]
97
+ else: h = out
98
+ # Capture Norm of the last token processed
99
+ with torch.no_grad():
100
+ val = torch.norm(h[:, -1, :]).item()
101
+ current_step_activations[layer_idx] = val
102
+ return _hook
103
+
104
  hooks = []
 
105
  for i, layer in enumerate(model.model.layers):
106
+ hooks.append(layer.register_forward_hook(hook_fn(i)))
107
+
108
+ # 2. Tokenize
 
109
  msgs = [{"role": "user", "content": prompt}]
110
  inputs = tokenizer.apply_chat_template(msgs, return_tensors="pt", add_generation_prompt=True).to(model.device)
 
111
  input_ids = inputs
112
+
113
+ # Storage for history
114
+ history_tokens = []
115
+ history_acts = [] # List of Lists
116
+
117
+ # 3. Generate Loop
118
  past_key_values = None
119
+ max_new_tokens = 100
120
+
121
+ yield "Thinking...", gr.update(visible=False), gr.update(visible=False) # Status update
122
+
123
+ accumulated_text = ""
124
+
125
+ try:
126
+ for _ in range(max_new_tokens):
127
+ current_step_activations.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ with torch.no_grad():
130
+ if past_key_values is None:
131
+ out = model(input_ids)
132
+ else:
133
+ out = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)
134
+
135
+ logits = out.logits[:, -1, :]
136
+ past_key_values = out.past_key_values
137
+
138
+ next_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
139
+ token_str = tokenizer.decode(next_id[0], skip_special_tokens=True)
140
+
141
+ # Store Data
142
+ accumulated_text += token_str
143
+ history_tokens.append(token_str)
144
 
145
+ # Sort activations by layer index and store
146
+ step_acts = [current_step_activations.get(i, 0.0) for i in range(NUM_LAYERS)]
147
+ history_acts.append(step_acts)
148
+
149
+ input_ids = torch.cat([input_ids, next_id], dim=-1)
150
+
151
+ yield accumulated_text, gr.update(visible=False), gr.update(visible=False)
152
+
153
+ if next_id.item() == tokenizer.eos_token_id:
154
+ break
155
+
156
+ # FINISHED
157
+ # Enable Slider and Return Data
158
+ # Max slider value = number of generated tokens - 1
159
+ print(f"Generated {len(history_tokens)} tokens.")
160
+
161
+ # Package history for the state
162
+ session_data = {
163
+ "tokens": history_tokens,
164
+ "activations": history_acts
165
+ }
166
+
167
+ # Return: Text, Slider Update, Session JSON
168
+ yield accumulated_text, gr.update(minimum=0, maximum=len(history_tokens)-1, value=0, visible=True, label=f"Time Travel (0-{len(history_tokens)-1})"), session_data
169
 
170
+ finally:
171
+ for h in hooks: h.remove()
172
+
173
+ # --- 3. VISUALIZER FUNCTION ---
174
+ def render_network_at_step(step_idx, session_data):
175
+ if not session_data or step_idx is None:
176
+ return None
177
+
178
+ tokens = session_data["tokens"]
179
+ acts_history = session_data["activations"]
180
+
181
+ # Safety checks
182
+ if step_idx >= len(tokens): step_idx = len(tokens) - 1
183
+ if step_idx < 0: step_idx = 0
184
+
185
+ current_token = tokens[step_idx]
186
+ current_acts = acts_history[step_idx] # Size: 28 (layers)
187
+
188
+ # --- Prepare Visual Attributes ---
189
+ # We map 28 layer values to (28 * NODES_PER_LAYER) visual nodes
190
+ # If Layer 1 is active, all 10 nodes in Layer 1 light up
191
+
192
+ node_colors = []
193
+ node_sizes = []
194
+
195
+ # Normalize current step
196
+ max_act = max(current_acts) if current_acts else 1.0
197
+
198
+ for layer_i in range(NUM_LAYERS):
199
+ intensity = current_acts[layer_i] / max_act if max_act > 0 else 0
200
+
201
+ # Color mapping (Dark Blue -> Bright Cyan/White)
202
+ for _ in range(NODES_PER_LAYER):
203
+ node_sizes.append(4 + (intensity * 8)) # Size varies 4 to 12
204
+ node_colors.append(intensity)
205
 
206
+ # --- Construct Plotly Figure ---
207
+ fig = go.Figure()
208
+
209
+ # 1. Edges (Static wires)
210
+ fig.add_trace(go.Scatter3d(
211
+ x=edge_x, y=edge_y, z=edge_z,
212
+ mode='lines',
213
+ line=dict(color='rgba(100, 150, 255, 0.15)', width=1), # Faint blue lines
214
+ hoverinfo='none'
215
+ ))
216
+
217
+ # 2. Nodes (Dynamic Lights)
218
+ fig.add_trace(go.Scatter3d(
219
+ x=node_coords_x,
220
+ y=node_coords_y,
221
+ z=node_coords_z,
222
+ mode='markers',
223
+ marker=dict(
224
+ size=node_sizes,
225
+ color=node_colors,
226
+ colorscale='Electric', # Distinct AI look
227
+ cmin=0, cmax=1,
228
+ opacity=0.9
229
+ ),
230
+ text=[f"Layer {i//NODES_PER_LAYER}" for i in range(len(node_coords_x))],
231
+ hoverinfo='text'
232
+ ))
233
+
234
+ # Layout styling to match the reference image (Dark Void)
235
+ camera = dict(
236
+ up=dict(x=0, y=1, z=0),
237
+ eye=dict(x=0.5, y=2.5, z=0.5) # Side view
238
+ )
239
+
240
+ fig.update_layout(
241
+ title=dict(
242
+ text=f"Token: '{current_token}'",
243
+ font=dict(color="white", size=24)
244
+ ),
245
+ template="plotly_dark",
246
+ paper_bgcolor='black',
247
+ plot_bgcolor='black',
248
+ scene=dict(
249
+ xaxis=dict(visible=False),
250
+ yaxis=dict(visible=False),
251
+ zaxis=dict(visible=False),
252
+ bgcolor='black',
253
+ camera=camera
254
+ ),
255
+ margin=dict(l=0, r=0, b=0, t=50),
256
+ )
257
+
258
+ return fig
259
+
260
+ # Wrapper to handle slider change
261
+ @spaces.GPU
262
+ def on_slider_change(step, session_state):
263
+ return render_network_at_step(step, session_state)
264
+
265
+ # --- 4. UI BUILD ---
266
+ with gr.Blocks(theme=gr.themes.Base()) as demo:
267
+
268
+ # Store history data here
269
+ session_state = gr.State()
270
+
271
+ gr.Markdown("# 🕸️ Neural Time-Traveler")
272
+ gr.Markdown("1. **Generate** text. 2. **Use the Slider** to travel through time and see the network state for each token.")
273
+
274
  with gr.Row():
275
  with gr.Column(scale=1):
276
+ prompt = gr.Textbox(label="Input", value="Explain how neural networks learn.", lines=2)
277
+ gen_btn = gr.Button("RUN GENERATION", variant="primary")
278
+
279
+ # This is the Time Slider - initially hidden
280
+ time_slider = gr.Slider(label="Timeline (Tokens)", minimum=0, maximum=10, step=1, visible=False)
281
+
282
+ output_text = gr.Textbox(label="Full Output", lines=8, interactive=False)
283
 
284
+ with gr.Column(scale=3):
285
+ # Large visualization area
286
+ network_plot = gr.Plot(label="Internal State Visualization", container=True)
287
 
288
+ # Logic:
289
+ # 1. Click Button -> Run Model -> Update Text + Unhide Slider + Save State
290
+ # 2. Slider Change -> Read State -> Update Plot
291
+
292
+ gen_btn.click(
293
+ fn=run_inference,
294
  inputs=prompt,
295
+ outputs=[output_text, time_slider, session_state]
296
+ )
297
+
298
+ # When generation finishes (or slider moves), show the last/current frame
299
+ time_slider.change(
300
+ fn=on_slider_change,
301
+ inputs=[time_slider, session_state],
302
+ outputs=network_plot
303
  )
304
+
305
+ # Initial trigger to ensure clean state
306
+ # (Optional)
307
 
308
  if __name__ == "__main__":
309
  demo.launch()