SeaWolf-AI committed on
Commit
c98aa0c
·
verified ·
1 Parent(s): b8875ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -129
app.py CHANGED
@@ -85,34 +85,23 @@ THINKING_END = "<channel|>"
85
  _loaded_model_name = None
86
  _model = None
87
  _processor = None
88
- _strip_tokens = []
89
-
90
 
91
  def _load_model(model_name: str):
92
- """Load or switch model. Unloads previous model first."""
93
  global _loaded_model_name, _model, _processor, _strip_tokens
94
 
95
  if _loaded_model_name == model_name and _model is not None:
96
- return # Already loaded
97
 
98
  model_cfg = MODELS[model_name]
99
  model_id = model_cfg["id"]
100
  print(f"[MODEL] Loading {model_name} ({model_id})...", flush=True)
101
 
102
- # Unload previous model
103
- if _model is not None:
104
- del _model
105
- _model = None
106
- torch.cuda.empty_cache()
107
- import gc; gc.collect()
108
- print(f"[MODEL] Unloaded previous model", flush=True)
109
-
110
  _processor = AutoProcessor.from_pretrained(model_id)
111
  _model = AutoModelForMultimodalLM.from_pretrained(
112
  model_id, device_map="auto", dtype=torch.bfloat16,
113
  )
114
 
115
- # Build strip tokens list (keep thinking delimiters)
116
  _keep = {THINKING_START, THINKING_END}
117
  _strip_tokens = sorted(
118
  (t for t in _processor.tokenizer.all_special_tokens if t not in _keep),
@@ -123,7 +112,7 @@ def _load_model(model_name: str):
123
  print(f"[MODEL] βœ“ {model_name} loaded ({model_cfg['arch']}, {model_cfg['active']} active)", flush=True)
124
 
125
 
126
- # Load default model at startup
127
  _load_model(DEFAULT_MODEL)
128
 
129
 
@@ -137,17 +126,14 @@ def _strip_special_tokens(text: str) -> str:
137
  # 3. THINKING MODE HELPERS
138
  # ══════════════════════════════════════════════════════════════════════════════
139
  def parse_think_blocks(text: str) -> tuple[str, str]:
140
- """Parse <|channel>...<channel|> thinking blocks"""
141
  m = re.search(r"<\|channel\>(.*?)<channel\|>\s*", text, re.DOTALL)
142
  if m:
143
  return (m.group(1).strip(), text[m.end():].strip())
144
- # Fallback: <think>...</think>
145
  m = re.search(r"<think>(.*?)</think>\s*", text, re.DOTALL)
146
  return (m.group(1).strip(), text[m.end():].strip()) if m else ("", text)
147
 
148
 
149
  def format_response(raw: str) -> str:
150
- """Format response with thinking blocks collapsed"""
151
  chain, answer = parse_think_blocks(raw)
152
  if chain:
153
  return (
@@ -157,7 +143,6 @@ def format_response(raw: str) -> str:
157
  "</details>\n\n"
158
  f"{answer}"
159
  )
160
- # Thinking in progress
161
  if THINKING_START in raw and THINKING_END not in raw:
162
  think_len = len(raw) - raw.index(THINKING_START) - len(THINKING_START)
163
  return f"🧠 Reasoning... ({think_len} chars)"
@@ -240,13 +225,8 @@ def generate_reply(
240
  max_new_tokens: int,
241
  temperature: float,
242
  top_p: float,
243
- model_choice: str,
244
  ) -> Generator[str, None, None]:
245
- """Main generation function β€” builds messages, calls GPU inference."""
246
-
247
- # ── Model switching ──
248
- target_model = model_choice if model_choice in MODELS else DEFAULT_MODEL
249
- _load_model(target_model)
250
 
251
  use_think = "Thinking" in thinking_mode
252
  max_new_tokens = min(int(max_new_tokens), 8192)
@@ -256,17 +236,12 @@ def generate_reply(
256
  if system_prompt.strip():
257
  messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]})
258
 
259
- # Process history
260
  for turn in history:
261
  if isinstance(turn, dict):
262
  role = turn.get("role", "")
263
  raw = turn.get("content") or ""
264
  if isinstance(raw, list):
265
- text_parts = []
266
- for p in raw:
267
- if isinstance(p, dict) and p.get("type") == "text":
268
- text_parts.append(p.get("text", ""))
269
- text = " ".join(text_parts)
270
  else:
271
  text = str(raw)
272
  if role == "user":
@@ -275,34 +250,13 @@ def generate_reply(
275
  _, clean = parse_think_blocks(text)
276
  messages.append({"role": "assistant", "content": [{"type": "text", "text": clean}]})
277
 
278
- # ── User message with optional vision ──
279
  user_content: list[dict] = []
280
 
281
- if image_input:
282
- try:
283
- import io
284
- from PIL import Image as PILImage
285
-
286
- if isinstance(image_input, str) and image_input.startswith("data:"):
287
- _, b64_data = image_input.split(",", 1)
288
- img_bytes = base64.b64decode(b64_data)
289
- elif isinstance(image_input, str) and os.path.isfile(image_input):
290
- with open(image_input, "rb") as f:
291
- img_bytes = f.read()
292
- else:
293
- buf = io.BytesIO()
294
- if not isinstance(image_input, PILImage.Image):
295
- image_input = PILImage.fromarray(image_input)
296
- image_input.save(buf, format="JPEG")
297
- img_bytes = buf.getvalue()
298
-
299
- b64 = base64.b64encode(img_bytes).decode()
300
- user_content.append({
301
- "type": "image",
302
- "url": f"data:image/jpeg;base64,{b64}",
303
- })
304
- except Exception as e:
305
- print(f"[VISION] Image processing error: {e}", flush=True)
306
 
307
  user_content.append({"type": "text", "text": message})
308
  messages.append({"role": "user", "content": user_content})
@@ -340,8 +294,6 @@ def generate_reply(
340
  yield f"**❌ Generation error:** `{e}`"
341
 
342
 
343
- # ══════════════════════════════════════════════════════════════════════════════
344
-
345
  # ══════════════════════════════════════════════════════════════════════════════
346
  # 6. GRADIO UI
347
  # ══════════════════════════════════════════════════════════════════════════════
@@ -351,109 +303,76 @@ footer { display: none !important; }
351
  .gradio-container { background: #faf8f5 !important; }
352
  #send-btn { background: linear-gradient(135deg, #6d28d9, #7c3aed) !important; border: none !important; border-radius: 12px !important; color: white !important; font-size: 18px !important; min-width: 48px !important; }
353
  #chatbot { border: 1.5px solid #e4dfd8 !important; border-radius: 14px !important; background: rgba(255,255,255,.65) !important; }
354
- .model-info-box { padding: 10px 14px; border-radius: 10px; border: 1.5px solid rgba(109,40,217,.2); background: linear-gradient(135deg, rgba(109,40,217,.04), rgba(16,185,129,.03)); font-size: 12px; line-height: 1.6; }
355
- .model-info-box b { color: #6d28d9; }
356
- .model-info-box .stats { font-size: 10px; color: #78716c; margin-top: 4px; }
357
  """
358
 
359
- # Model info display (updates when dropdown changes)
360
- def _model_info_html(name):
361
- m = MODELS.get(name, MODELS[DEFAULT_MODEL])
362
- return (
363
- f'<div class="model-info-box">'
364
- f'<b>{"⚑" if m["arch"]=="MoE" else "πŸ†"} {name}</b> '
365
- f'<span style="font-size:9px;padding:2px 6px;border-radius:6px;background:rgba(109,40,217,.08);color:#6d28d9;font-weight:700">{m["arch"]}</span><br>'
366
- f'<div class="stats">{m["active"]} active / {m["total"]} total Β· πŸ‘οΈ Vision Β· {m["ctx"]} context<br>{m["desc"]}</div>'
367
- f'</div>'
368
- )
 
 
369
 
370
  with gr.Blocks(title="Gemma 4 Playground") as demo:
371
 
372
- gr.Markdown("## πŸ’Ž Gemma 4 Playground\nGoogle DeepMind Β· Dense 31B or MoE 26B-A4B Β· Vision Β· Thinking Β· Apache 2.0")
373
 
374
  with gr.Row():
375
- # ══ Sidebar ══
376
  with gr.Column(scale=0, min_width=300):
377
-
378
- gr.Markdown("#### Select Model")
379
- model_dd = gr.Dropdown(
380
- choices=list(MODELS.keys()), value=DEFAULT_MODEL,
381
- label="Model", elem_id="model-dd",
382
- info="MoE=Fast inference | Dense=Best quality",
383
- )
384
- model_info = gr.HTML(value=_model_info_html(DEFAULT_MODEL))
385
 
386
  gr.Markdown("---")
387
- gr.Markdown("#### πŸ‘οΈ Vision")
388
- image_input = gr.Image(label="Upload image", type="filepath", height=150)
389
 
390
  gr.Markdown("---")
391
  gr.Markdown("#### Settings")
392
- thinking_radio = gr.Radio(
393
- ["⚑ Fast", "🧠 Thinking"], value="⚑ Fast", label="Mode",
394
- )
395
- sys_prompt = gr.Textbox(
396
- value=PRESETS["general"], label="System Prompt", lines=2,
397
- )
398
- preset_dd = gr.Dropdown(
399
- choices=list(PRESETS.keys()), value="general", label="Preset",
400
- )
401
  max_tok = gr.Slider(64, 8192, value=4096, step=64, label="Max Tokens")
402
  temp = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
403
  topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
404
  clear_btn = gr.Button("πŸ—‘οΈ Clear conversation", size="sm")
405
 
406
- # ══ Chat ══
407
  with gr.Column(scale=3):
408
  chatbot = gr.Chatbot(elem_id="chatbot", show_label=False, height=600)
409
  with gr.Row():
410
  chat_input = gr.Textbox(
411
- placeholder="Message Gemma 4…", show_label=False,
412
- scale=7, autofocus=True, lines=1, max_lines=4,
413
  )
414
  send_btn = gr.Button("↑", variant="primary", scale=0, min_width=48, elem_id="send-btn")
415
 
416
- # ── Events: model info update ──
417
- model_dd.change(
418
- fn=_model_info_html,
419
- inputs=[model_dd],
420
- outputs=[model_info],
421
- )
422
-
423
- # ── Events: preset β†’ system prompt ──
424
- preset_dd.change(
425
- fn=lambda k: PRESETS.get(k, PRESETS["general"]),
426
- inputs=[preset_dd],
427
- outputs=[sys_prompt],
428
- )
429
 
430
- # ── Chat logic ──
431
  def user_msg(msg, hist):
432
- if not msg.strip():
433
- return "", hist
434
  return "", hist + [{"role": "user", "content": msg}]
435
 
436
- def bot_reply(hist, think, img, sysp, maxt, tmp, tp, model):
437
- if not hist or hist[-1]["role"] != "user":
438
- return hist
439
  txt, past = hist[-1]["content"], hist[:-1]
440
  hist = hist + [{"role": "assistant", "content": ""}]
441
- for chunk in generate_reply(txt, past, think, img, sysp, maxt, tmp, tp, model):
442
  hist[-1]["content"] = chunk
443
  yield hist
444
 
445
- ins = [chatbot, thinking_radio, image_input, sys_prompt, max_tok, temp, topp, model_dd]
446
-
447
- send_btn.click(
448
- user_msg, [chat_input, chatbot], [chat_input, chatbot], queue=False
449
- ).then(
450
- bot_reply, ins, chatbot
451
- )
452
- chat_input.submit(
453
- user_msg, [chat_input, chatbot], [chat_input, chatbot], queue=False
454
- ).then(
455
- bot_reply, ins, chatbot
456
- )
457
  clear_btn.click(lambda: [], None, chatbot, queue=False)
458
 
459
 
@@ -461,5 +380,5 @@ with gr.Blocks(title="Gemma 4 Playground") as demo:
461
  # 7. LAUNCH
462
  # ══════════════════════════════════════════════════════════════════════════════
463
  if __name__ == "__main__":
464
- print(f"[BOOT] Gemma 4 Playground Β· Default: {DEFAULT_MODEL}", flush=True)
465
- demo.launch(server_name="0.0.0.0", server_port=7860, css=CSS)
 
85
  _loaded_model_name = None
86
  _model = None
87
  _processor = None
 
 
88
 
89
  def _load_model(model_name: str):
90
+ """Load model at startup only. ZeroGPU packs tensors once β€” no runtime switching."""
91
  global _loaded_model_name, _model, _processor, _strip_tokens
92
 
93
  if _loaded_model_name == model_name and _model is not None:
94
+ return
95
 
96
  model_cfg = MODELS[model_name]
97
  model_id = model_cfg["id"]
98
  print(f"[MODEL] Loading {model_name} ({model_id})...", flush=True)
99
 
 
 
 
 
 
 
 
 
100
  _processor = AutoProcessor.from_pretrained(model_id)
101
  _model = AutoModelForMultimodalLM.from_pretrained(
102
  model_id, device_map="auto", dtype=torch.bfloat16,
103
  )
104
 
 
105
  _keep = {THINKING_START, THINKING_END}
106
  _strip_tokens = sorted(
107
  (t for t in _processor.tokenizer.all_special_tokens if t not in _keep),
 
112
  print(f"[MODEL] βœ“ {model_name} loaded ({model_cfg['arch']}, {model_cfg['active']} active)", flush=True)
113
 
114
 
115
+ # Load default model at startup (ZeroGPU will pack tensors — cannot switch later)
116
  _load_model(DEFAULT_MODEL)
117
 
118
 
 
126
  # 3. THINKING MODE HELPERS
127
  # ══════════════════════════════════════════════════════════════════════════════
128
  def parse_think_blocks(text: str) -> tuple[str, str]:
 
129
  m = re.search(r"<\|channel\>(.*?)<channel\|>\s*", text, re.DOTALL)
130
  if m:
131
  return (m.group(1).strip(), text[m.end():].strip())
 
132
  m = re.search(r"<think>(.*?)</think>\s*", text, re.DOTALL)
133
  return (m.group(1).strip(), text[m.end():].strip()) if m else ("", text)
134
 
135
 
136
  def format_response(raw: str) -> str:
 
137
  chain, answer = parse_think_blocks(raw)
138
  if chain:
139
  return (
 
143
  "</details>\n\n"
144
  f"{answer}"
145
  )
 
146
  if THINKING_START in raw and THINKING_END not in raw:
147
  think_len = len(raw) - raw.index(THINKING_START) - len(THINKING_START)
148
  return f"🧠 Reasoning... ({think_len} chars)"
 
225
  max_new_tokens: int,
226
  temperature: float,
227
  top_p: float,
 
228
  ) -> Generator[str, None, None]:
229
+ """Main generation function."""
 
 
 
 
230
 
231
  use_think = "Thinking" in thinking_mode
232
  max_new_tokens = min(int(max_new_tokens), 8192)
 
236
  if system_prompt.strip():
237
  messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]})
238
 
 
239
  for turn in history:
240
  if isinstance(turn, dict):
241
  role = turn.get("role", "")
242
  raw = turn.get("content") or ""
243
  if isinstance(raw, list):
244
+ text = " ".join(p.get("text", "") for p in raw if isinstance(p, dict) and p.get("type") == "text")
 
 
 
 
245
  else:
246
  text = str(raw)
247
  if role == "user":
 
250
  _, clean = parse_think_blocks(text)
251
  messages.append({"role": "assistant", "content": [{"type": "text", "text": clean}]})
252
 
253
+ # ── User message with optional image ──
254
  user_content: list[dict] = []
255
 
256
+ # IMAGE: pass filepath directly as URL (Gemma 4 processor handles it)
257
+ if image_input and isinstance(image_input, str) and os.path.isfile(image_input):
258
+ user_content.append({"type": "image", "url": image_input})
259
+ print(f"[VISION] Image attached: {image_input}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  user_content.append({"type": "text", "text": message})
262
  messages.append({"role": "user", "content": user_content})
 
294
  yield f"**❌ Generation error:** `{e}`"
295
 
296
 
 
 
297
  # ══════════════════════════════════════════════════════════════════════════════
298
  # 6. GRADIO UI
299
  # ══════════════════════════════════════════════════════════════════════════════
 
303
  .gradio-container { background: #faf8f5 !important; }
304
  #send-btn { background: linear-gradient(135deg, #6d28d9, #7c3aed) !important; border: none !important; border-radius: 12px !important; color: white !important; font-size: 18px !important; min-width: 48px !important; }
305
  #chatbot { border: 1.5px solid #e4dfd8 !important; border-radius: 14px !important; background: rgba(255,255,255,.65) !important; }
306
+ .model-box { padding: 10px 14px; border-radius: 10px; border: 1.5px solid rgba(109,40,217,.2); background: linear-gradient(135deg, rgba(109,40,217,.04), rgba(16,185,129,.03)); font-size: 12px; line-height: 1.6; }
307
+ .model-box b { color: #6d28d9; }
308
+ .model-box .st { font-size: 10px; color: #78716c; margin-top: 4px; }
309
  """
310
 
311
+ _mcfg = MODELS[DEFAULT_MODEL]
312
+ MODEL_INFO_HTML = (
313
+ f'<div class="model-box">'
314
+ f'<b>{"⚑" if _mcfg["arch"]=="MoE" else "πŸ†"} {DEFAULT_MODEL}</b> '
315
+ f'<span style="font-size:9px;padding:2px 6px;border-radius:6px;background:rgba(109,40,217,.08);color:#6d28d9;font-weight:700">{_mcfg["arch"]}</span><br>'
316
+ f'<div class="st">{_mcfg["active"]} active / {_mcfg["total"]} total Β· πŸ‘οΈ Vision Β· {_mcfg["ctx"]} context</div>'
317
+ f'<div class="st">{_mcfg["desc"]}</div>'
318
+ f'<div class="st" style="margin-top:6px">'
319
+ f'<a href="https://huggingface.co/{MODELS[DEFAULT_MODEL]["id"]}" target="_blank" style="color:#6d28d9;font-weight:700;text-decoration:none">πŸ€— Model Card β†—</a> Β· '
320
+ f'<a href="https://deepmind.google/models/gemma/gemma-4/" target="_blank" style="color:#059669;font-weight:700;text-decoration:none">πŸ”¬ DeepMind β†—</a>'
321
+ f'</div></div>'
322
+ )
323
 
324
  with gr.Blocks(title="Gemma 4 Playground") as demo:
325
 
326
+ gr.Markdown("## πŸ’Ž Gemma 4 Playground\nGoogle DeepMind Β· Apache 2.0 Β· Vision Β· Thinking")
327
 
328
  with gr.Row():
329
+ # ── Sidebar ──
330
  with gr.Column(scale=0, min_width=300):
331
+ gr.Markdown("#### Current Model")
332
+ gr.HTML(MODEL_INFO_HTML)
 
 
 
 
 
 
333
 
334
  gr.Markdown("---")
335
+ gr.Markdown("#### πŸ‘οΈ Upload Image")
336
+ image_input = gr.Image(label=None, type="filepath", height=160)
337
 
338
  gr.Markdown("---")
339
  gr.Markdown("#### Settings")
340
+ thinking_radio = gr.Radio(["⚑ Fast", "🧠 Thinking"], value="⚑ Fast", label="Mode")
341
+ sys_prompt = gr.Textbox(value=PRESETS["general"], label="System Prompt", lines=2)
342
+ preset_dd = gr.Dropdown(choices=list(PRESETS.keys()), value="general", label="Preset")
 
 
 
 
 
 
343
  max_tok = gr.Slider(64, 8192, value=4096, step=64, label="Max Tokens")
344
  temp = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
345
  topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
346
  clear_btn = gr.Button("πŸ—‘οΈ Clear conversation", size="sm")
347
 
348
+ # ── Chat ──
349
  with gr.Column(scale=3):
350
  chatbot = gr.Chatbot(elem_id="chatbot", show_label=False, height=600)
351
  with gr.Row():
352
  chat_input = gr.Textbox(
353
+ placeholder="Message Gemma 4… (upload image in sidebar for vision)",
354
+ show_label=False, scale=7, autofocus=True, lines=1, max_lines=4,
355
  )
356
  send_btn = gr.Button("↑", variant="primary", scale=0, min_width=48, elem_id="send-btn")
357
 
358
+ # ── Events ──
359
+ preset_dd.change(fn=lambda k: PRESETS.get(k, PRESETS["general"]), inputs=[preset_dd], outputs=[sys_prompt])
 
 
 
 
 
 
 
 
 
 
 
360
 
 
361
  def user_msg(msg, hist):
362
+ if not msg.strip(): return "", hist
 
363
  return "", hist + [{"role": "user", "content": msg}]
364
 
365
+ def bot_reply(hist, think, img, sysp, maxt, tmp, tp):
366
+ if not hist or hist[-1]["role"] != "user": return hist
 
367
  txt, past = hist[-1]["content"], hist[:-1]
368
  hist = hist + [{"role": "assistant", "content": ""}]
369
+ for chunk in generate_reply(txt, past, think, img, sysp, maxt, tmp, tp):
370
  hist[-1]["content"] = chunk
371
  yield hist
372
 
373
+ ins = [chatbot, thinking_radio, image_input, sys_prompt, max_tok, temp, topp]
374
+ send_btn.click(user_msg, [chat_input, chatbot], [chat_input, chatbot], queue=False).then(bot_reply, ins, chatbot)
375
+ chat_input.submit(user_msg, [chat_input, chatbot], [chat_input, chatbot], queue=False).then(bot_reply, ins, chatbot)
 
 
 
 
 
 
 
 
 
376
  clear_btn.click(lambda: [], None, chatbot, queue=False)
377
 
378
 
 
380
  # 7. LAUNCH
381
  # ══════════════════════════════════════════════════════════════════════════════
382
  if __name__ == "__main__":
383
+ print(f"[BOOT] Gemma 4 Playground Β· Model: {DEFAULT_MODEL}", flush=True)
384
+ demo.launch(server_name="0.0.0.0", server_port=7860, css=CSS, ssr_mode=False)