diamond-in committed on
Commit
a7aae6c
·
verified ·
1 Parent(s): bac3a3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -116
app.py CHANGED
@@ -3,180 +3,197 @@ import torch
3
  import spaces
4
  import json
5
  import numpy as np
6
- import matplotlib
7
- import matplotlib.pyplot as plt
8
- from mpl_toolkits.mplot3d import Axes3D
9
  from threading import Lock
10
  from huggingface_hub import snapshot_download
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
 
13
- # Set Matplotlib backend to Agg (non-interactive) for server-side rendering
14
- matplotlib.use('Agg')
15
-
16
- # --- 1. DOWNLOAD MODEL FIRST ---
17
  MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
18
  print(f"⬇️ Downloading {MODEL_ID}...")
19
  try:
20
  snapshot_download(repo_id=MODEL_ID)
21
- print("✅ Model downloaded successfully.")
22
  except Exception as e:
23
- print(f"⚠️ Warning during download: {e}")
24
 
25
- # --- 2. GLOBAL STATE ---
26
  model_lock = Lock()
27
  model = None
28
  tokenizer = None
29
- current_activations = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # --- 3. BACKEND: LOAD MODEL ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def load_model():
33
  global model, tokenizer
34
- if model is not None:
35
- return
36
 
37
  with model_lock:
38
- print("LOADING Model into Memory...")
39
  model = AutoModelForCausalLM.from_pretrained(
40
  MODEL_ID,
41
  torch_dtype=torch.float16,
42
- device_map="auto",
43
- trust_remote_code=True
44
  )
45
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
46
- print("Model Loaded!")
47
 
48
- # --- 4. BACKEND: HOOKS ---
49
- def get_activation_hook(layer_idx):
50
- def hook(module, input, output):
51
- if isinstance(output, tuple):
52
- hidden_states = output[0]
53
- else:
54
- hidden_states = output
55
-
56
  with torch.no_grad():
57
- val = torch.norm(hidden_states[:, -1, :]).item()
 
58
  current_activations[layer_idx] = val
59
  return hook
60
 
61
- # --- 5. VISUALIZATION FUNCTION (MATPLOTLIB) ---
62
- def create_3d_plot(token_text):
63
- plt.close('all') # Close previous figures to prevent memory leaks
64
- plt.style.use('dark_background')
65
-
66
- fig = plt.figure(figsize=(8, 6))
67
- ax = fig.add_subplot(111, projection='3d')
68
-
69
- layers = list(range(28))
70
- values = [current_activations.get(i, 0.1) for i in layers]
71
-
72
- # Normalize
73
- max_val = max(values) if values and max(values) > 0 else 1
74
- norm_values = [v / max_val for v in values]
75
-
76
- # Bar Data
77
- x_pos = np.arange(28)
78
- y_pos = np.zeros(28)
79
- z_pos = np.zeros(28)
80
- dx = np.ones(28) * 0.8
81
- dy = np.ones(28) * 0.5
82
- dz = values
83
-
84
- # Colors
85
- colormap = plt.cm.plasma
86
- colors = colormap(norm_values)
87
-
88
- # Draw Bars
89
- ax.bar3d(x_pos, y_pos, z_pos, dx, dy, dz, color=colors, shade=True)
90
-
91
- # Styling
92
- ax.set_title(f"Live Activations: '{token_text}'", color='cyan', fontsize=12)
93
- ax.set_xlabel('Layer')
94
- ax.set_zlabel('Intensity')
95
- ax.set_yticks([])
96
-
97
- # --- ERROR FIX HERE: Use xaxis directly, not w_xaxis ---
98
- dark_gray = (0.1, 0.1, 0.1, 1.0)
99
- ax.xaxis.set_pane_color(dark_gray)
100
- ax.yaxis.set_pane_color(dark_gray)
101
- ax.zaxis.set_pane_color(dark_gray)
102
-
103
- ax.grid(color='gray', linestyle=':', linewidth=0.3)
104
-
105
- plt.tight_layout()
106
- return fig
107
-
108
- # --- 6. INFERENCE GENERATOR ---
109
  @spaces.GPU(duration=120)
110
- def generate_response(user_prompt):
111
  load_model()
112
 
 
113
  hooks = []
114
  current_activations.clear()
115
  for i, layer in enumerate(model.model.layers):
116
- h = layer.register_forward_hook(get_activation_hook(i))
117
  hooks.append(h)
118
-
119
- messages = [{"role": "user", "content": user_prompt}]
120
- text_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
121
- inputs = tokenizer([text_input], return_tensors="pt").to(model.device)
122
 
123
- input_ids = inputs.input_ids
124
- past_key_values = None
125
- accumulated_text = ""
126
 
127
- yield "", create_3d_plot("Init")
 
 
128
 
129
- step_count = 0
130
- max_tokens = 200
131
 
132
- for _ in range(max_tokens):
 
133
  with torch.no_grad():
134
  if past_key_values is None:
135
- outputs = model(input_ids)
136
  else:
137
- outputs = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)
138
 
139
- logits = outputs.logits[:, -1, :]
140
- past_key_values = outputs.past_key_values
141
 
142
  next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
143
  token_str = tokenizer.decode(next_token[0], skip_special_tokens=True)
144
- accumulated_text += token_str
145
- input_ids = torch.cat([input_ids, next_token], dim=-1)
146
 
147
- step_count += 1
148
 
149
- # Update plot every 3 tokens
150
- if step_count % 3 == 0 or next_token.item() == tokenizer.eos_token_id:
151
- fig = create_3d_plot(token_str)
152
- yield accumulated_text, fig
 
 
153
  else:
154
- # Use gr.Skip() properly to avoid re-rendering
155
- yield accumulated_text, gr.update()
 
156
 
157
  if next_token.item() == tokenizer.eos_token_id:
158
  break
159
-
 
160
  for h in hooks: h.remove()
161
- plt.close('all')
162
 
163
- # --- 7. UI LAYOUT ---
164
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="cyan")) as demo:
165
- gr.Markdown("# 🧠 Qwen 2.5 (1.5B) Live Visualization")
166
-
 
167
  with gr.Row():
168
  with gr.Column(scale=1):
169
- prompt_input = gr.Textbox(label="Prompt", lines=3, value="Explain quantum computing briefly.")
170
- generate_btn = gr.Button("Generate", variant="primary")
171
- output_text = gr.Textbox(label="Response", lines=8)
172
 
173
- with gr.Column(scale=1):
174
- viz_plot = gr.Plot(label="Real-Time Activation Topology")
175
-
176
- generate_btn.click(
177
- fn=generate_response,
178
- inputs=prompt_input,
179
- outputs=[output_text, viz_plot]
 
180
  )
181
 
182
  if __name__ == "__main__":
 
3
  import spaces
4
  import json
5
  import numpy as np
6
+ import plotly.graph_objects as go
 
 
7
  from threading import Lock
8
  from huggingface_hub import snapshot_download
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
# --- 1. MODEL DOWNLOAD (Immediate) ---
# Pre-fetch the model weights at import time so the first GPU request
# does not pay the download cost.
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
print(f"⬇️ Downloading {MODEL_ID}...")
try:
    snapshot_download(repo_id=MODEL_ID)
    print("✅ Model downloaded.")
except Exception as e:
    # Best-effort: a failed pre-download is only logged, not fatal —
    # presumably from_pretrained() in load_model() fetches anything
    # missing later; verify against the transformers cache behavior.
    print(f"⚠️ Download check ignored: {e}")
19
 
20
# --- 2. GLOBAL SETUP & COORDINATES ---
model_lock = Lock()        # guards the lazy load in load_model()
model = None               # set once by load_model()
tokenizer = None           # set once by load_model()
current_activations = {}   # layer_idx -> L2 norm of the last token's hidden state (written by hook_fn)

# Pre-calculate 3D Coordinates for the Neural Spiral (28 Layers)
# We calculate this once so we don't waste CPU during generation
num_layers = 28  # assumed decoder layer count for Qwen2.5-1.5B — TODO confirm against model config
t_vals = np.linspace(0, 4 * np.pi, num_layers)  # 4*pi radians = 2 full loops of the spiral
radius = 5
node_x = radius * np.cos(t_vals)
node_y = radius * np.sin(t_vals)
node_z = np.linspace(0, 15, num_layers)  # Height: layer depth maps to z
35
+ # --- 3. PLOTLY VISUALIZATION FUNCTION ---
36
def get_neural_plot(token_text, layer_data):
    """Build an interactive 3D Plotly figure of per-layer activations.

    Args:
        token_text: Token (or status label) shown in the figure title.
        layer_data: Mapping of layer index -> activation magnitude
            (e.g. the module-level ``current_activations`` dict).

    Returns:
        A ``plotly.graph_objects.Figure`` with a single Scatter3d trace
        placed on the precomputed spiral (``node_x``/``node_y``/``node_z``).
    """
    # 1. Prepare Data
    # Get activations for all num_layers layers (missing entries default to 0.0).
    acts = [layer_data.get(i, 0.0) for i in range(num_layers)]

    # Normalize to [0, 1] for marker size/colour; guard the divisor so an
    # all-zero frame (e.g. the initial "Waiting..." plot) never divides by zero.
    max_val = max(acts) if acts and max(acts) > 0 else 1.0
    norm_acts = [val / max_val for val in acts]

    # 2. Determine Sizes and Colors
    # Base size 10, growing up to 30 for the most active layer.
    sizes = [10 + (n * 20) for n in norm_acts]

    # 3. Create Scatter3D Trace
    trace = go.Scatter3d(
        x=node_x,
        y=node_y,
        z=node_z,
        mode='markers+lines',  # Nodes connected by lines
        marker=dict(
            size=sizes,
            color=norm_acts,        # Color by normalized intensity
            colorscale='Viridis',
            cmin=0, cmax=1,
            opacity=0.9,
            line=dict(width=1, color='white')
        ),
        line=dict(
            color='#444444',
            width=2
        ),
        hovertext=[f"Layer {i}: {a:.2f}" for i, a in enumerate(acts)],
        hoverinfo="text"
    )

    # 4. Layout
    layout = go.Layout(
        title=dict(
            text=f"Token Processing: '{token_text}'",
            font=dict(color="#00ffcc", size=20)
        ),
        paper_bgcolor='#0b0f19',  # Dark Background
        plot_bgcolor='#0b0f19',
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(title="Layer Depth", color="white"),
            bgcolor='#0b0f19',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=0.5)  # Initial camera angle
            )
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        template="plotly_dark"
    )

    return go.Figure(data=[trace], layout=layout)
96
+
97
+ # --- 4. BACKEND LOGIC ---
98
def load_model():
    """Lazily load the model and tokenizer into module globals.

    Safe to call from multiple threads: the first caller performs the
    load, every other caller returns once ``model`` is set.
    """
    global model, tokenizer
    if model is not None: return  # fast path, already loaded

    with model_lock:
        # Re-check under the lock: a thread that waited here while
        # another thread loaded must not reload (double-checked locking).
        if model is not None:
            return
        print("Loading Model...")
        # Load the tokenizer BEFORE assigning `model`, so that `model`
        # only becomes non-None once both globals are usable — callers
        # gate on `model` alone.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        print("Loaded.")
111
 
112
def hook_fn(layer_idx):
    """Create a forward hook that records activation strength for one layer.

    The returned hook stores the L2 norm of the last token's hidden state
    into the module-level ``current_activations`` dict under ``layer_idx``.
    """
    def hook(module, inp, out):
        # Decoder layers may return a tuple; the hidden states come first.
        hidden = out[0] if isinstance(out, tuple) else out
        with torch.no_grad():
            # Magnitude of the most recent token's representation.
            norm_val = torch.norm(hidden[:, -1, :]).item()
        current_activations[layer_idx] = norm_val
    return hook
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
@spaces.GPU(duration=120)
def generate(prompt):
    """Stream (text, figure) pairs while greedily decoding a reply.

    Yields the accumulated response text together with either a fresh
    Plotly figure of the current layer activations, or ``gr.skip()`` to
    leave the plot unchanged on intermediate steps.
    """
    load_model()

    # Hook Setup: one forward hook per decoder layer writes its last-token
    # activation norm into the shared `current_activations` dict.
    hooks = []
    current_activations.clear()
    for i, layer in enumerate(model.model.layers):
        h = layer.register_forward_hook(hook_fn(i))
        hooks.append(h)

    # Tokenize the chat-formatted prompt.
    msgs = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(msgs, return_tensors="pt", add_generation_prompt=True).to(model.device)

    input_ids = inputs
    past_key_values = None  # KV cache, populated by the first forward pass
    accum_text = ""

    # Initial Plot (Empty)
    yield "", get_neural_plot("Waiting...", {})

    # Greedy decoding loop, capped at 256 new tokens.
    for step in range(256):
        with torch.no_grad():
            if past_key_values is None:
                # First step: forward the full prompt (fills the KV cache).
                out = model(input_ids)
            else:
                # Later steps: feed only the newest token plus the cache.
                out = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)

        logits = out.logits[:, -1, :]
        past_key_values = out.past_key_values

        next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)  # greedy pick
        token_str = tokenizer.decode(next_token[0], skip_special_tokens=True)
        accum_text += token_str

        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # --- YIELD LOGIC ---
        # Plotly is slightly heavy to generate every single token (might lag).
        # We yield the updated Plot every 4 tokens to keep the UI buttery smooth.
        if step % 4 == 0 or next_token.item() == tokenizer.eos_token_id:
            fig = get_neural_plot(token_str, current_activations)
            yield accum_text, fig
        else:
            # gr.skip() leaves the plot component untouched so only the
            # text output re-renders on this step.
            yield accum_text, gr.skip()

        if next_token.item() == tokenizer.eos_token_id:
            break

    # Cleanup: detach hooks so later calls don't double-register.
    for h in hooks: h.remove()
 
177
 
178
# --- 5. UI LAYOUT ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="cyan")) as demo:
    gr.Markdown("# 🧠 Qwen 1.5B - Interactive Neural Spiral")
    gr.Markdown("*Zoom, Pan, and Rotate with your mouse. Nodes pulse based on AI thought process.*")

    with gr.Row():
        # Left column: prompt input and streamed text output.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="User Prompt", value="Write a poem about neural networks.", lines=3)
            btn = gr.Button("Generate", variant="primary")
            output = gr.Textbox(label="AI Response", lines=10)

        # Right column (wider): interactive Plotly activation view.
        with gr.Column(scale=2):
            # GRADIO PLOT Component (Supports Plotly Interactivity)
            plot_component = gr.Plot(label="Live Neural Activations")

    # `generate` is a generator, so these outputs stream as it yields.
    btn.click(
        fn=generate,
        inputs=prompt,
        outputs=[output, plot_component]
    )
198
 
199
  if __name__ == "__main__":