"""
🧬 Gemma 4 Playground - Demo Space
Dual model (31B / 26B-A4B) · ZeroGPU · Vision · Thinking Mode
"""
import sys
print(f"[BOOT] Python {sys.version}", flush=True)

import os, re, subprocess
from typing import Generator
from collections.abc import Iterator
from pathlib import Path
from threading import Thread

# Install pre-built transformers wheel BEFORE importing transformers
_app_dir = Path(__file__).parent
_whls = sorted(_app_dir.glob("transformers*.whl"))
_installed = False
if _whls:
    _whl = _whls[0]
    print(f"[BOOT] Installing wheel: {_whl.name}", flush=True)
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", str(_whl)])
        _installed = True
        print("[BOOT] βœ“ Wheel installed", flush=True)
    except subprocess.CalledProcessError as e:
        print(f"[BOOT] ⚠ Wheel install failed ({e}), falling back to PyPI", flush=True)

if not _installed:
    print("[BOOT] Installing transformers from PyPI...", flush=True)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers>=4.49"])

try:
    import gradio as gr
    print(f"[BOOT] gradio {gr.__version__}", flush=True)
except ImportError as e:
    print(f"[BOOT] FATAL: {e}", flush=True); sys.exit(1)

import torch
import spaces
from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature
from transformers.generation.streamers import TextIteratorStreamer


# ══════════════════════════════════════════════════════════════════════════════
# 1.  MODEL CONFIG · Gemma 4 Dual Model
# ══════════════════════════════════════════════════════════════════════════════
MODELS = {
    "Gemma-4-31B-it": {
        "id": "google/gemma-4-31b-it",
        "arch": "Dense", "total": "30.7B", "active": "30.7B",
        "ctx": "256K", "vision": True, "audio": False,
        "desc": "Dense 31B β€” 졜고 ν’ˆμ§ˆ, AIME 89.2%, Codeforces 2150",
    },
    "Gemma-4-26B-A4B-it": {
        "id": "google/gemma-4-26B-A4B-it",
        "arch": "MoE", "total": "25.2B", "active": "3.8B",
        "ctx": "256K", "vision": True, "audio": False,
        "desc": "MoE 26B (3.8B active) β€” 31B의 95% μ„±λŠ₯, μΆ”λ‘  ~8λ°° 빠름",
    },
}

DEFAULT_MODEL = "Gemma-4-26B-A4B-it"  # MoEκ°€ ZeroGPUμ—μ„œ 더 적합

PRESETS = {
    "general":   "You are Gemma 4, a highly capable multimodal AI assistant by Google DeepMind. Think step by step for complex questions.",
    "code":      "You are an expert software engineer. Write clean, efficient, well-commented code. Explain your approach before writing. Use modern best practices.",
    "math":      "You are a world-class mathematician. Break problems step-by-step. Show full working. Use LaTeX where helpful.",
    "creative":  "You are a brilliant creative writer. Be imaginative, vivid, and engaging. Adapt tone and style to the request.",
    "translate": "You are a professional translator fluent in 140+ languages. Provide accurate, natural-sounding translations with cultural context.",
    "research":  "You are a rigorous research analyst. Provide structured, well-reasoned analysis. Identify assumptions and acknowledge uncertainty.",
}

IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".avi", ".webm")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10000"))

# Gemma 4 thinking delimiters
THINKING_START = "<|channel>"
THINKING_END = "<channel|>"


# ══════════════════════════════════════════════════════════════════════════════
# 2.  MODEL LOADING · Lazy load with switching
# ══════════════════════════════════════════════════════════════════════════════
_loaded_model_name = None
_model = None
_processor = None
_strip_tokens: list[str] = []

def _load_model(model_name: str):
    """Load model at startup only. ZeroGPU packs tensors once β€” no runtime switching."""
    global _loaded_model_name, _model, _processor, _strip_tokens

    if _loaded_model_name == model_name and _model is not None:
        return

    model_cfg = MODELS[model_name]
    model_id = model_cfg["id"]
    print(f"[MODEL] Loading {model_name} ({model_id})...", flush=True)

    _processor = AutoProcessor.from_pretrained(model_id)
    _model = AutoModelForMultimodalLM.from_pretrained(
        model_id, device_map="auto", dtype=torch.bfloat16,
    )

    _keep = {THINKING_START, THINKING_END}
    _strip_tokens = sorted(
        (t for t in _processor.tokenizer.all_special_tokens if t not in _keep),
        key=len, reverse=True,
    )

    _loaded_model_name = model_name
    print(f"[MODEL] βœ“ {model_name} loaded ({model_cfg['arch']}, {model_cfg['active']} active)", flush=True)


# Load default model at startup (ZeroGPU packs these tensors; switching later forces a slow reload)
_load_model(DEFAULT_MODEL)


def _strip_special_tokens(text: str) -> str:
    for tok in _strip_tokens:
        text = text.replace(tok, "")
    return text


# ══════════════════════════════════════════════════════════════════════════════
# 3.  THINKING MODE HELPERS
# ══════════════════════════════════════════════════════════════════════════════
def parse_think_blocks(text: str) -> tuple[str, str]:
    """Split raw model output into (reasoning_chain, final_answer); chain is "" if absent."""
    m = re.search(r"<\|channel\>(.*?)<channel\|>\s*", text, re.DOTALL)
    if m:
        return (m.group(1).strip(), text[m.end():].strip())
    m = re.search(r"<think>(.*?)</think>\s*", text, re.DOTALL)
    return (m.group(1).strip(), text[m.end():].strip()) if m else ("", text)
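# Worked examples (hypothetical strings, matching the regexes above):
#   parse_think_blocks("<|channel>plan: 2+2<channel|> 4")   -> ("plan: 2+2", "4")
#   parse_think_blocks("<think>check units</think> 42 m/s") -> ("check units", "42 m/s")
#   parse_think_blocks("plain answer")                      -> ("", "plain answer")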


def format_response(raw: str) -> str:
    chain, answer = parse_think_blocks(raw)
    if chain:
        return (
            "<details>\n"
            "<summary>🧠 Reasoning Chain β€” click to expand</summary>\n\n"
            f"{chain}\n\n"
            "</details>\n\n"
            f"{answer}"
        )
    if THINKING_START in raw and THINKING_END not in raw:
        think_len = len(raw) - raw.index(THINKING_START) - len(THINKING_START)
        return f"🧠 Reasoning... ({think_len} chars)"
    return raw
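# Example: format_response("<|channel>steps<channel|>done") yields a collapsible
# <details> block containing "steps" followed by "done"; an unterminated
# "<|channel>" prefix instead yields the "🧠 Reasoning..." progress stub.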


# ══════════════════════════════════════════════════════════════════════════════
# 4.  CLASSIFICATION & MESSAGE BUILDING
# ══════════════════════════════════════════════════════════════════════════════
def _classify_file(path: str) -> str | None:
    lower = path.lower()
    if lower.endswith(IMAGE_FILE_TYPES):
        return "image"
    if lower.endswith(VIDEO_FILE_TYPES):
        return "video"
    return None
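# Example: _classify_file("/tmp/Photo.PNG") -> "image" (the lower() call makes
# the suffix check case-insensitive); _classify_file("notes.txt") -> None.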


def _has_media_type(messages: list[dict], media_type: str) -> bool:
    return any(
        c.get("type") == media_type
        for m in messages
        for c in (m["content"] if isinstance(m["content"], list) else [])
    )
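# Example (hypothetical message): _has_media_type(
#     [{"role": "user", "content": [{"type": "image", "url": "cat.png"}]}],
#     "image",
# ) -> True; plain-string content is skipped rather than inspected.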


# ══════════════════════════════════════════════════════════════════════════════
# 5.  GPU INFERENCE · ZeroGPU
# ══════════════════════════════════════════════════════════════════════════════
@spaces.GPU(duration=180)
@torch.inference_mode()
def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool, temperature: float, top_p: float) -> Iterator[str]:
    inputs = inputs.to(device=_model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(
        _processor,
        timeout=30.0,
        skip_prompt=True,
        skip_special_tokens=not thinking,
    )
    generate_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "disable_compile": True,
    }
    # Wire the UI sampling sliders into generate(); they were previously ignored.
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        generate_kwargs["do_sample"] = False

    exception_holder: list[Exception] = []

    def _generate() -> None:
        try:
            _model.generate(**generate_kwargs)
        except Exception as e:
            exception_holder.append(e)

    thread = Thread(target=_generate)
    thread.start()

    chunks: list[str] = []
    for text in streamer:
        chunks.append(text)
        accumulated = "".join(chunks)
        if thinking:
            yield _strip_special_tokens(accumulated)
        else:
            yield accumulated

    thread.join()
    if exception_holder:
        msg = f"Generation failed: {exception_holder[0]}"
        raise gr.Error(msg)


def generate_reply(
    message:        str,
    history:        list,
    thinking_mode:  str,
    image_input,
    system_prompt:  str,
    max_new_tokens: int,
    temperature:    float,
    top_p:          float,
    model_choice:   str = "",
) -> Generator[str, None, None]:
    """Main generation function."""

    # Model switching (may take 1-2 min on first switch)
    target = model_choice if model_choice in MODELS else DEFAULT_MODEL
    if target != _loaded_model_name:
        yield f"⏳ Loading **{target}**... (졜초 μ „ν™˜ μ‹œ 1-2λΆ„ μ†Œμš”)"
        _load_model(target)

    use_think = "Thinking" in thinking_mode
    max_new_tokens = min(int(max_new_tokens), 8192)

    # ── Build messages ──
    messages: list[dict] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]})

    for turn in history:
        if isinstance(turn, dict):
            role = turn.get("role", "")
            raw = turn.get("content") or ""
            if isinstance(raw, list):
                text = " ".join(p.get("text", "") for p in raw if isinstance(p, dict) and p.get("type") == "text")
            else:
                text = str(raw)
            if role == "user":
                messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
            elif role == "assistant":
                _, clean = parse_think_blocks(text)
                messages.append({"role": "assistant", "content": [{"type": "text", "text": clean}]})

    # ── User message with optional image ──
    user_content: list[dict] = []

    # IMAGE: pass filepath directly as URL (Gemma 4 processor handles it)
    if image_input and isinstance(image_input, str) and os.path.isfile(image_input):
        user_content.append({"type": "image", "url": image_input})
        print(f"[VISION] Image attached: {image_input}", flush=True)

    user_content.append({"type": "text", "text": message})
    messages.append({"role": "user", "content": user_content})

    # ── Apply chat template ──
    try:
        template_kwargs = {
            "tokenize": True,
            "return_dict": True,
            "return_tensors": "pt",
            "add_generation_prompt": True,
            "processor_kwargs": {"images_kwargs": {"max_soft_tokens": 280}},
        }
        if _has_media_type(messages, "video"):
            template_kwargs["load_audio_from_video"] = False
        if use_think:
            template_kwargs["enable_thinking"] = True

        inputs = _processor.apply_chat_template(messages, **template_kwargs)

        n_tokens = inputs["input_ids"].shape[1]
        if n_tokens > MAX_INPUT_TOKENS:
            yield f"**❌ μž…λ ₯이 λ„ˆλ¬΄ κΉλ‹ˆλ‹€ ({n_tokens} tokens). μ΅œλŒ€ {MAX_INPUT_TOKENS} tokens.**"
            return

    except Exception as e:
        yield f"**❌ Template error:** `{e}`"
        return

    # ── Stream from GPU ──
    try:
        for text in _generate_on_gpu(
            inputs=inputs, max_new_tokens=max_new_tokens,
            thinking=use_think, temperature=temperature, top_p=top_p,
        ):
            yield format_response(text)
    except Exception as e:
        yield f"**❌ Generation error:** `{e}`"


# ══════════════════════════════════════════════════════════════════════════════
# 6.  GRADIO UI
# ══════════════════════════════════════════════════════════════════════════════

CSS = """
footer { display: none !important; }
.gradio-container { background: #faf8f5 !important; }
#send-btn { background: linear-gradient(135deg, #6d28d9, #7c3aed) !important; border: none !important; border-radius: 12px !important; color: white !important; font-size: 18px !important; min-width: 48px !important; }
#chatbot { border: 1.5px solid #e4dfd8 !important; border-radius: 14px !important; background: rgba(255,255,255,.65) !important; }
.model-box { padding: 10px 14px; border-radius: 10px; border: 1.5px solid rgba(109,40,217,.2); background: linear-gradient(135deg, rgba(109,40,217,.04), rgba(16,185,129,.03)); font-size: 12px; line-height: 1.6; }
.model-box b { color: #6d28d9; }
.model-box .st { font-size: 10px; color: #78716c; margin-top: 4px; }
"""

def _model_info_html(name):
    m = MODELS.get(name, MODELS[DEFAULT_MODEL])
    icon = "⚑" if m["arch"] == "MoE" else "πŸ†"
    return (
        f'<div class="model-box">'
        f'<b>{icon} {name}</b> '
        f'<span style="font-size:9px;padding:2px 6px;border-radius:6px;background:rgba(109,40,217,.08);color:#6d28d9;font-weight:700">{m["arch"]}</span><br>'
        f'<div class="st">{m["active"]} active / {m["total"]} total Β· πŸ‘οΈ Vision Β· {m["ctx"]} context</div>'
        f'<div class="st">{m["desc"]}</div>'
        f'<div class="st" style="margin-top:6px">'
        f'<a href="https://huggingface.co/{m["id"]}" target="_blank" style="color:#6d28d9;font-weight:700;text-decoration:none">πŸ€— Model Card β†—</a> Β· '
        f'<a href="https://deepmind.google/models/gemma/gemma-4/" target="_blank" style="color:#059669;font-weight:700;text-decoration:none">πŸ”¬ DeepMind β†—</a>'
        f'</div></div>'
    )

with gr.Blocks(title="Gemma 4 Playground", css=CSS) as demo:

    with gr.Row():
        gr.Markdown("## πŸ’Ž Gemma 4 Playground\nGoogle DeepMind Β· Apache 2.0 Β· Vision Β· Thinking")
        with gr.Column(scale=0, min_width=120):
            gr.LoginButton(size="sm")

    with gr.Row():
        # ── Sidebar ──
        with gr.Column(scale=0, min_width=280):
            model_dd = gr.Dropdown(
                choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Model",
                info="⚑MoE=Fast | πŸ†Dense=Best quality (μ „ν™˜ μ‹œ 1-2λΆ„)",
            )
            model_info = gr.HTML(value=_model_info_html(DEFAULT_MODEL))
            image_input = gr.Image(label="👁️ Image (Vision)", type="filepath", height=140)
            thinking_radio = gr.Radio(["⚡ Fast", "🧠 Thinking"], value="⚡ Fast", label="Mode")
            with gr.Accordion("⚙️ Settings", open=False):
                sys_prompt = gr.Textbox(value=PRESETS["general"], label="System Prompt", lines=2)
                preset_dd = gr.Dropdown(choices=list(PRESETS.keys()), value="general", label="Preset")
                max_tok = gr.Slider(64, 8192, value=4096, step=64, label="Max Tokens")
                temp = gr.Slider(0.0, 1.5, value=0.6, step=0.05, label="Temperature")
                topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
            clear_btn = gr.Button("🗑️ Clear", size="sm")

        # ── Chat ──
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(elem_id="chatbot", show_label=False, height=600)
            with gr.Row():
                chat_input = gr.Textbox(
                    placeholder="Message Gemma 4…",
                    show_label=False, scale=7, autofocus=True, lines=1, max_lines=4,
                )
                send_btn = gr.Button("↑", variant="primary", scale=0, min_width=48, elem_id="send-btn")

    # ── Events ──
    model_dd.change(fn=_model_info_html, inputs=[model_dd], outputs=[model_info])
    preset_dd.change(fn=lambda k: PRESETS.get(k, PRESETS["general"]), inputs=[preset_dd], outputs=[sys_prompt])

    def user_msg(msg, hist):
        if not msg.strip(): return "", hist
        return "", hist + [{"role": "user", "content": msg}]

    def bot_reply(hist, think, img, sysp, maxt, tmp, tp, model):
        if not hist or hist[-1]["role"] != "user":
            yield hist
            return
        txt, past = hist[-1]["content"], hist[:-1]
        hist = hist + [{"role": "assistant", "content": ""}]
        for chunk in generate_reply(txt, past, think, img, sysp, maxt, tmp, tp, model):
            hist[-1]["content"] = chunk
            yield hist

    ins = [chatbot, thinking_radio, image_input, sys_prompt, max_tok, temp, topp, model_dd]
    send_btn.click(user_msg, [chat_input, chatbot], [chat_input, chatbot], queue=False).then(bot_reply, ins, chatbot)
    chat_input.submit(user_msg, [chat_input, chatbot], [chat_input, chatbot], queue=False).then(bot_reply, ins, chatbot)
    clear_btn.click(lambda: [], None, chatbot, queue=False)


# ══════════════════════════════════════════════════════════════════════════════
# 7.  LAUNCH
# ══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    print(f"[BOOT] Gemma 4 Playground Β· Model: {DEFAULT_MODEL}", flush=True)
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)