diamond-in committed on
Commit
30320d1
·
verified ·
1 Parent(s): cb826d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -230
app.py CHANGED
@@ -2,49 +2,53 @@ import gradio as gr
2
  import torch
3
  import spaces
4
  import json
 
5
  import numpy as np
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
7
  from threading import Lock
 
 
8
 
9
- # --- Global Config ---
10
- # Qwen 2.5 32B is the target. We must use a global lock for thread safety on ZeroGPU.
11
- MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"
 
 
 
 
 
 
12
  model = None
13
  tokenizer = None
14
- model_lock = Lock()
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # Storage for capturing live layer data
18
- current_layer_stats = {}
19
 
20
- # --- Frontend Logic (Embedded HTML + Three.js) ---
21
- # This Javascript will handle the 3D visualization in the browser.
22
- frontend_html = """
23
  <!DOCTYPE html>
24
  <html lang="en">
25
  <head>
26
  <meta charset="UTF-8">
27
  <style>
28
- body { margin: 0; background-color: #020617; overflow: hidden; font-family: 'Segoe UI', sans-serif; }
29
- #viz-container { width: 100%; height: 600px; position: relative; border: 1px solid #1e293b; border-radius: 8px; }
30
- #hud { position: absolute; top: 10px; left: 10px; pointer-events: none; z-index: 10; color: #94a3b8; }
31
- .hud-panel { background: rgba(15, 23, 42, 0.85); padding: 12px; border-radius: 6px; border: 1px solid #334155; display: inline-block; backdrop-filter: blur(4px); }
32
- h1 { margin: 0; font-size: 16px; color: #38bdf8; }
33
- p { margin: 4px 0 0 0; font-size: 12px; }
34
- #stream-hidden { display: none; }
35
  </style>
36
- <!-- Load Three.js -->
37
  <script type="importmap">
38
  { "imports": { "three": "https://unpkg.com/three@0.160.0/build/three.module.js", "three/addons/": "https://unpkg.com/three@0.160.0/examples/jsm/" } }
39
  </script>
40
  </head>
41
  <body>
42
- <div id="viz-container">
43
- <div id="hud">
44
- <div class="hud-panel">
45
- <h1>Qwen 32B Activity Graph</h1>
46
- <p>Live Layer Norm Visualization</p>
47
- <p style="color: #fbbf24; margin-top:8px" id="token-display">Waiting...</p>
48
  </div>
49
  </div>
50
  </div>
@@ -53,276 +57,253 @@ frontend_html = """
53
  import * as THREE from 'three';
54
  import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
55
 
56
- // 1. Scene & Camera
57
- const container = document.getElementById('viz-container');
58
  const scene = new THREE.Scene();
59
- scene.fog = new THREE.FogExp2(0x020617, 0.035);
60
 
61
  const camera = new THREE.PerspectiveCamera(50, container.clientWidth / container.clientHeight, 0.1, 100);
62
- camera.position.set(0, 0, 30);
63
 
64
  const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
65
  renderer.setSize(container.clientWidth, container.clientHeight);
66
- renderer.setPixelRatio(window.devicePixelRatio);
67
  container.appendChild(renderer.domElement);
68
 
69
  const controls = new OrbitControls(camera, renderer.domElement);
70
- controls.enableDamping = true;
71
  controls.autoRotate = true;
72
- controls.autoRotateSpeed = 1.0;
 
73
 
74
- // 2. Geometry: Spiral Helix representing layers
75
- const numLayers = 64; // Qwen 32B layer count
76
  const nodes = [];
77
  const group = new THREE.Group();
78
 
79
- const nodeGeo = new THREE.SphereGeometry(0.5, 16, 16);
80
- const nodeMat = new THREE.MeshStandardMaterial({ color: 0x334155, roughness: 0.1, metalness: 0.5, emissive: 0x000000 });
 
 
 
 
 
 
81
 
82
- for(let i=0; i<numLayers; i++) {
83
- const mesh = new THREE.Mesh(nodeGeo, nodeMat.clone());
84
 
85
- // Calculate Helix positions
86
- const theta = i * 0.4;
87
  const y = (i - numLayers/2) * 0.6;
88
- const r = 6;
89
 
90
- mesh.position.set(Math.cos(theta)*r, y, Math.sin(theta)*r);
91
- nodes.push(mesh);
92
- group.add(mesh);
 
 
 
93
  }
94
  scene.add(group);
95
-
96
- // Add lights
97
- const ambient = new THREE.AmbientLight(0x404040);
98
- scene.add(ambient);
99
- const point = new THREE.PointLight(0xffffff, 2, 50);
100
- point.position.set(5, 10, 5);
101
- scene.add(point);
102
-
103
- // 3. Render Loop
 
 
 
 
 
104
  function animate() {
105
  requestAnimationFrame(animate);
106
  controls.update();
107
  renderer.render(scene, camera);
108
  }
109
  animate();
110
-
111
- window.addEventListener('resize', () => {
112
- camera.aspect = container.clientWidth / container.clientHeight;
113
- camera.updateProjectionMatrix();
114
- renderer.setSize(container.clientWidth, container.clientHeight);
115
- });
116
 
117
- // 4. Data Bridge - Listener Logic
118
- // We look for the DOM element with ID 'data-stream' (Gradio Textbox)
119
- // and listen for text changes using MutationObserver.
120
-
121
- let processedLength = 0;
122
-
123
- function updateGraph(json) {
124
- document.getElementById('token-display').innerText = `Gen: "${json.token}"`;
125
-
126
- // Map activation values to the 3D nodes
127
- const vals = json.activations;
128
- const maxVal = Math.max(...vals, 0.1);
129
-
130
  nodes.forEach((node, idx) => {
131
- const val = vals[idx] || 0;
132
- // Normalize 0.0 - 1.0
133
- const norm = val / maxVal;
134
 
135
- // Scale Animation
136
- node.scale.setScalar(1 + norm * 2.5);
 
 
137
 
138
- // Color Animation: Blue -> White -> Red
139
- const color = new THREE.Color().setHSL(0.6 - (norm * 0.6), 1.0, 0.5 + (norm * 0.4));
140
- node.material.color.copy(color);
141
- node.material.emissive.setHSL(0.6 - (norm * 0.6), 1.0, norm);
142
  });
143
  }
144
 
145
- // Logic to setup observer once the Gradio app loads fully
146
- const setupObserver = setInterval(() => {
147
- const target = document.querySelector('#data-stream textarea') || document.getElementById('data-stream');
148
-
149
- if (target) {
150
- console.log("3D Visualizer: Connected to Stream");
151
- clearInterval(setupObserver);
 
 
 
 
 
152
 
153
- const observer = new MutationObserver((mutations) => {
154
- const text = target.value || target.innerText;
155
- // Only process new data chunks
156
- if(text.length > processedLength) {
157
- const newContent = text.substring(processedLength);
158
- processedLength = text.length;
159
-
160
- // Parse JSON lines. Sometimes chunks have multiple lines.
161
- const lines = newContent.trim().split('\\n');
162
- lines.forEach(line => {
163
- try {
164
- if(line.startsWith('{')) updateGraph(JSON.parse(line));
165
- } catch(e) {}
166
- });
167
- }
168
  });
169
- observer.observe(target, { attributes: true, childList: true, subtree: true });
170
  }
171
- }, 1000);
172
 
 
 
 
 
 
 
173
  </script>
174
  </body>
175
  </html>
176
  """
177
 
178
- # --- Backend: Model Logic ---
179
 
180
- def load_qwen_32b():
181
- """
182
- Loads Qwen-32B with 4-bit Quantization (fits ~19GB VRAM).
183
- """
184
  global model, tokenizer
185
- if model is not None: return
186
-
187
- print("LOADING: Qwen 2.5 32B (4-bit)...")
188
 
189
- bnb_config = BitsAndBytesConfig(
190
- load_in_4bit=True,
191
- bnb_4bit_compute_dtype=torch.float16,
192
- bnb_4bit_quant_type="nf4"
193
- )
 
 
 
 
 
 
 
 
194
 
195
- model = AutoModelForCausalLM.from_pretrained(
196
- MODEL_ID,
197
- quantization_config=bnb_config,
198
- device_map="auto",
199
- trust_remote_code=True
200
- )
201
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
202
- print("MODEL LOADED.")
203
-
204
- def make_hook(layer_idx):
205
- """
206
- Creates a PyTorch forward hook that measures the 'activity'
207
- (L2 norm) of a specific layer during inference.
208
- """
209
- def hook(module, input, output):
210
- # Qwen returns (hidden_states, past_key_values)
211
- if isinstance(output, tuple):
212
- hidden = output[0]
213
- else:
214
- hidden = output
215
-
216
- # We calculate the norm of the last token generated
217
- # hidden shape: [Batch, Seq, Dim]
218
- # We access the last token: [:, -1, :]
219
  with torch.no_grad():
220
- activation_val = torch.norm(hidden[:, -1, :], p=2).item()
221
- current_layer_stats[layer_idx] = activation_val
222
-
223
- return hook
224
 
225
- @spaces.GPU(duration=120)
226
- def generate_and_visualize(prompt):
227
- global model, tokenizer
228
 
229
- # Ensure loaded (Lazy loading for faster startup)
230
- if model is None:
231
- load_qwen_32b()
232
-
233
- # 1. Register Hooks (Visualization Data Miners)
234
- # We clear old hooks to be safe
235
  hooks = []
236
- current_layer_stats.clear()
237
-
238
- # Qwen uses 'model.model.layers'
239
  for i, layer in enumerate(model.model.layers):
240
- h = layer.register_forward_hook(make_hook(i))
241
  hooks.append(h)
242
-
243
  # 2. Tokenize
244
- messages = [
245
- {"role": "system", "content": "You are a helpful coding assistant."},
246
- {"role": "user", "content": prompt}
247
- ]
248
- text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
249
- inputs = tokenizer([text], return_tensors="pt").to(device)
250
-
251
- # 3. Manual Generation Loop (Streaming)
252
  input_ids = inputs.input_ids
 
253
 
254
- # Yield initial clear state
255
- yield json.dumps({"token": "", "activations": []}) + "\n"
256
 
257
- # We generate up to 256 tokens for this demo
258
- max_new_tokens = 256
259
- generated_text = ""
260
-
261
- # NOTE: Using a custom loop instead of .generate to get granular access
262
- past_key_values = None
263
 
264
- for _ in range(max_new_tokens):
265
- with torch.no_grad():
266
- if past_key_values is None:
267
- outputs = model(input_ids)
268
- else:
269
- outputs = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)
270
-
271
- logits = outputs.logits[:, -1, :]
272
- past_key_values = outputs.past_key_values
273
-
274
- # Simple Greedy Decoding
275
- next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
276
-
277
- # Update sequence
278
- input_ids = torch.cat([input_ids, next_token], dim=-1)
279
-
280
- # Decode Token
281
- token_str = tokenizer.decode(next_token[0], skip_special_tokens=True)
282
- generated_text += token_str
283
-
284
- # PREPARE DATA PACKET for Visualization
285
- # Collect data from hooks (sorted by layer index)
286
- # 64 layers in 32B model
287
- act_values = [current_layer_stats.get(i, 0.0) for i in range(len(model.model.layers))]
288
-
289
- json_payload = json.dumps({
290
- "token": token_str,
291
- "activations": act_values
292
- })
293
-
294
- # Yield packet (Frontend sees this)
295
- yield json_payload + "\n"
296
-
297
- # Break on EOS
298
- if next_token.item() == tokenizer.eos_token_id:
299
- break
300
 
301
- # Cleanup hooks
302
- for h in hooks: h.remove()
 
303
 
304
- # --- Gradio UI Layout ---
305
 
306
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="gray")) as demo:
307
- gr.Markdown("## Qwen 2.5 32B Real-Time Neural Internals")
308
 
309
  with gr.Row():
310
- with gr.Column(scale=4):
311
- # Input Area
312
- user_input = gr.Textbox(label="Prompt (Coding/Reasoning)", value="Write a Python script for Dijkstra's algorithm.", lines=3)
313
- btn = gr.Button("Generate", variant="primary")
314
-
315
- # THE BRIDGE: This textbox receives the stream from Python
316
- # It is given a specific ID so JS can find it.
317
- # We set visible=True but users won't look at it (css hides it partially).
318
- stream_box = gr.Textbox(label="Raw Data Stream", elem_id="data-stream", visible=False)
319
 
320
- with gr.Column(scale=5):
321
- # Visualization
322
- gr.HTML(frontend_html)
 
 
323
 
324
- # Event Wiring
325
- btn.click(generate_and_visualize, inputs=user_input, outputs=stream_box)
 
 
 
 
 
326
 
327
  if __name__ == "__main__":
328
  demo.launch()
 
2
  import torch
3
  import spaces
4
  import json
5
+ import os
6
  import numpy as np
 
7
  from threading import Lock
8
+ from huggingface_hub import snapshot_download
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
+ # --- 1. PRE-DOWNLOAD STEP ---
12
+ # This runs immediately when the container starts to ensure the model is ready.
13
+ MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
14
+ print(f"⬇️ Downloading {MODEL_ID}...")
15
+ snapshot_download(repo_id=MODEL_ID)
16
+ print("✅ Download complete.")
17
+
18
+ # --- 2. Global State ---
19
+ model_lock = Lock()
20
  model = None
21
  tokenizer = None
 
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
+ # Store layer activations for the visualizer
25
+ current_activations = {}
26
 
27
+ # --- 3. Frontend: HTML & Three.js 3D Visualizer ---
28
+ # We configure this for 28 layers (the size of Qwen 1.5B)
29
+ visualization_html = """
30
  <!DOCTYPE html>
31
  <html lang="en">
32
  <head>
33
  <meta charset="UTF-8">
34
  <style>
35
+ body { margin: 0; background: transparent; overflow: hidden; font-family: monospace; }
36
+ #canvas-wrapper { width: 100%; height: 500px; border-radius: 8px; border: 1px solid #333; background: #0b0f19; position: relative; }
37
+ #overlay { position: absolute; top: 10px; left: 10px; color: #00ffcc; z-index: 10; pointer-events: none; }
38
+ .data-panel { background: rgba(0,0,0,0.5); padding: 5px 10px; border-radius: 4px; }
39
+ #stream_hidden { display: none; }
 
 
40
  </style>
41
+ <!-- Import Three.js -->
42
  <script type="importmap">
43
  { "imports": { "three": "https://unpkg.com/three@0.160.0/build/three.module.js", "three/addons/": "https://unpkg.com/three@0.160.0/examples/jsm/" } }
44
  </script>
45
  </head>
46
  <body>
47
+ <div id="canvas-wrapper">
48
+ <div id="overlay">
49
+ <div class="data-panel">
50
+ <div id="status">INITIATING...</div>
51
+ <div id="token-show" style="color: white; font-weight: bold;"></div>
 
52
  </div>
53
  </div>
54
  </div>
 
57
  import * as THREE from 'three';
58
  import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
59
 
60
+ // 1. Setup Scene
61
+ const container = document.getElementById('canvas-wrapper');
62
  const scene = new THREE.Scene();
63
+ scene.fog = new THREE.FogExp2(0x0b0f19, 0.05);
64
 
65
  const camera = new THREE.PerspectiveCamera(50, container.clientWidth / container.clientHeight, 0.1, 100);
66
+ camera.position.set(0, 0, 20);
67
 
68
  const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
69
  renderer.setSize(container.clientWidth, container.clientHeight);
 
70
  container.appendChild(renderer.domElement);
71
 
72
  const controls = new OrbitControls(camera, renderer.domElement);
 
73
  controls.autoRotate = true;
74
+ controls.autoRotateSpeed = 2.0;
75
+ controls.enableDamping = true;
76
 
77
+ // 2. Build 3D Neural Tower
78
+ const numLayers = 28; // Qwen 1.5B has 28 layers
79
  const nodes = [];
80
  const group = new THREE.Group();
81
 
82
+ // Geometry: Flattened cylinders representing layers
83
+ const geometry = new THREE.CylinderGeometry(2, 2, 0.2, 32);
84
+ const material = new THREE.MeshStandardMaterial({
85
+ color: 0x223344,
86
+ emissive: 0x000000,
87
+ metalness: 0.8,
88
+ roughness: 0.2
89
+ });
90
 
91
+ for (let i = 0; i < numLayers; i++) {
92
+ const node = new THREE.Mesh(geometry, material.clone());
93
 
94
+ // Vertical Stack
 
95
  const y = (i - numLayers/2) * 0.6;
96
+ node.position.set(0, y, 0);
97
 
98
+ // Subtle rotation spiral
99
+ node.rotation.y = i * 0.1;
100
+ node.rotation.x = 0.1;
101
+
102
+ nodes.push(node);
103
+ group.add(node);
104
  }
105
  scene.add(group);
106
+
107
+ // Add connecting central 'axon'
108
+ const coreGeo = new THREE.CylinderGeometry(0.2, 0.2, numLayers * 0.6, 8);
109
+ const coreMat = new THREE.MeshBasicMaterial({ color: 0x0044aa, transparent: true, opacity: 0.5 });
110
+ const core = new THREE.Mesh(coreGeo, coreMat);
111
+ scene.add(core);
112
+
113
+ // Lights
114
+ const light = new THREE.PointLight(0x00ffff, 2, 50);
115
+ light.position.set(5, 5, 10);
116
+ scene.add(light);
117
+ scene.add(new THREE.AmbientLight(0x222222));
118
+
119
+ // Animation Loop
120
  function animate() {
121
  requestAnimationFrame(animate);
122
  controls.update();
123
  renderer.render(scene, camera);
124
  }
125
  animate();
 
 
 
 
 
 
126
 
127
+ // 3. Data Streaming Logic
128
+ function updateVisuals(data) {
129
+ document.getElementById('status').innerText = "BRAIN ACTIVITY: ACTIVE";
130
+ document.getElementById('token-show').innerText = `"${data.token}"`;
131
+
132
+ const acts = data.activations;
133
+ const maxVal = Math.max(...acts, 1.0);
134
+
 
 
 
 
 
135
  nodes.forEach((node, idx) => {
136
+ const val = acts[idx] || 0;
137
+ const normalized = val / maxVal;
 
138
 
139
+ // Color Logic: Blue -> White -> Orange
140
+ const targetColor = new THREE.Color().setHSL(0.6 - (normalized*0.5), 1.0, 0.2 + (normalized*0.5));
141
+ node.material.color.copy(targetColor);
142
+ node.material.emissive.copy(targetColor).multiplyScalar(normalized * 2);
143
 
144
+ // Expansion Logic
145
+ node.scale.set(1 + normalized, 1, 1 + normalized);
 
 
146
  });
147
  }
148
 
149
+ // Bridge to Gradio Textbox
150
+ let lastLen = 0;
151
+ setInterval(() => {
152
+ // Find the invisible stream textbox provided by Python
153
+ const el = document.querySelector('#stream-bridge textarea') || document.getElementById('stream-bridge');
154
+ if(!el) return;
155
+
156
+ const content = el.value || "";
157
+ if(content.length > lastLen) {
158
+ // Parse only the new lines
159
+ const newLines = content.substring(lastLen).trim().split('\\n');
160
+ lastLen = content.length;
161
 
162
+ newLines.forEach(line => {
163
+ try {
164
+ if(line.startsWith('{')) {
165
+ updateVisuals(JSON.parse(line));
166
+ }
167
+ } catch(e) {}
 
 
 
 
 
 
 
 
 
168
  });
 
169
  }
170
+ }, 50); // check every 50ms
171
 
172
+ // Resize handler
173
+ window.addEventListener('resize', () => {
174
+ camera.aspect = container.clientWidth / container.clientHeight;
175
+ camera.updateProjectionMatrix();
176
+ renderer.setSize(container.clientWidth, container.clientHeight);
177
+ });
178
  </script>
179
  </body>
180
  </html>
181
  """
182
 
183
+ # --- 4. Backend Logic ---
184
 
185
+ def get_model():
186
+ """Load model with Torch standard precision (small enough for standard load)"""
 
 
187
  global model, tokenizer
188
+ if model is not None:
189
+ return model, tokenizer
 
190
 
191
+ with model_lock:
192
+ if model is not None: return model, tokenizer
193
+
194
+ print("LOADING Qwen 1.5B (FP16)...")
195
+ # Load in Float16 to fit nicely in 3GB VRAM
196
+ model = AutoModelForCausalLM.from_pretrained(
197
+ MODEL_ID,
198
+ torch_dtype=torch.float16,
199
+ device_map="auto"
200
+ )
201
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
202
+ print("Model Loaded.")
203
+ return model, tokenizer
204
 
205
+ def hook_fn(layer_idx):
206
+ def _hook(module, inp, out):
207
+ # Qwen tuple output: (hidden_states, ...)
208
+ if isinstance(out, tuple): hidden = out[0]
209
+ else: hidden = out
210
+
211
+ # Capture L2 Norm of the *last token*
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  with torch.no_grad():
213
+ # [batch, seq, dim] -> take last sequence element
214
+ norm = hidden[:, -1, :].norm(p=2).item()
215
+ current_activations[layer_idx] = norm
216
+ return _hook
217
 
218
+ @spaces.GPU
219
+ def chat_stream(prompt):
220
+ model, tokenizer = get_model()
221
 
222
+ # 1. Register hooks on all 28 layers
223
+ current_activations.clear()
 
 
 
 
224
  hooks = []
225
+ # model.model.layers is standard for Qwen
 
 
226
  for i, layer in enumerate(model.model.layers):
227
+ h = layer.register_forward_hook(hook_fn(i))
228
  hooks.append(h)
229
+
230
  # 2. Tokenize
231
+ messages = [{"role": "user", "content": prompt}]
232
+ text_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
233
+ inputs = tokenizer([text_input], return_tensors="pt").to(model.device)
234
+
235
+ # 3. Generate Loop
 
 
 
236
  input_ids = inputs.input_ids
237
+ past_key_values = None
238
 
239
+ generated_full_text = ""
240
+ yield "", "" # Reset UI
241
 
242
+ max_tokens = 300
 
 
 
 
 
243
 
244
+ try:
245
+ for _ in range(max_tokens):
246
+ with torch.no_grad():
247
+ if past_key_values is None:
248
+ out = model(input_ids)
249
+ else:
250
+ out = model(input_ids=input_ids[:, -1:], past_key_values=past_key_values)
251
+
252
+ logits = out.logits[:, -1, :]
253
+ past_key_values = out.past_key_values
254
+
255
+ next_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
256
+
257
+ # Check stop
258
+ if next_id.item() == tokenizer.eos_token_id:
259
+ break
260
+
261
+ token_txt = tokenizer.decode(next_id[0], skip_special_tokens=True)
262
+ generated_full_text += token_txt
263
+
264
+ input_ids = torch.cat([input_ids, next_id], dim=-1)
265
+
266
+ # 4. Prepare Stream Data
267
+ # Get stats for all 28 layers
268
+ layer_stats = [current_activations.get(i, 0.0) for i in range(28)]
269
+
270
+ # Viz JSON (goes to hidden box)
271
+ viz_json = json.dumps({
272
+ "token": token_txt,
273
+ "activations": layer_stats
274
+ }) + "\n"
275
+
276
+ # Yield: (Viz Data, Answer Text)
277
+ yield viz_json, generated_full_text
 
 
278
 
279
+ finally:
280
+ # Cleanup
281
+ for h in hooks: h.remove()
282
 
283
+ # --- 5. UI Layout ---
284
 
285
+ with gr.Blocks(theme=gr.themes.Base()) as demo:
286
+ gr.Markdown("## Qwen2.5-1.5B 3D Network Explorer (Fast & Light)")
287
 
288
  with gr.Row():
289
+ with gr.Column(scale=1):
290
+ prompt = gr.Textbox(label="User Question", lines=2, placeholder="Type your query...")
291
+ run_btn = gr.Button("Thinking Process", variant="primary")
292
+ answer_box = gr.Textbox(label="AI Answer", lines=10, interactive=False)
 
 
 
 
 
293
 
294
+ # HIDDEN bridge for 3D data
295
+ stream_bridge = gr.Textbox(elem_id="stream-bridge", visible=False)
296
+
297
+ with gr.Column(scale=1):
298
+ gr.HTML(visualization_html)
299
 
300
+ # Wire it up
301
+ # Output Order must match: yield viz_json, generated_full_text
302
+ run_btn.click(
303
+ fn=chat_stream,
304
+ inputs=prompt,
305
+ outputs=[stream_bridge, answer_box]
306
+ )
307
 
308
  if __name__ == "__main__":
309
  demo.launch()