Spaces:
Sleeping
Sleeping
| # Darwin-4B-David (Gemma4) - Transformers backend + MTI | |
| # Multimodal (Vision+Audio+Text) - Apache 2.0 | |
| # MTI: +9-11% reasoning accuracy (training-free), Transformers LogitsProcessor | |
| import sys, os, signal, time, uuid | |
| print(f"[BOOT] Python {sys.version}", flush=True) | |
| import base64, re, json | |
| from typing import Generator, Optional | |
| from threading import Thread | |
| from queue import Queue | |
| import torch | |
| import gradio as gr | |
| print(f"[BOOT] gradio {gr.__version__}, torch {torch.__version__}", flush=True) | |
| import requests, httpx, uvicorn | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse | |
| from urllib.parse import urlencode | |
| import pathlib, secrets | |
| # ============================================================================== | |
| # 1. CONFIG | |
| # ============================================================================== | |
# Hub repository id passed to from_pretrained, and the short name used as the UI title.
MODEL_ID = "FINAL-Bench/Darwin-4B-David"
MODEL_NAME = "Darwin-4B-David"

# Descriptive capability flags plus the two limits applied to every request
# ("max_tokens" and "temp_max").
MODEL_CAP = dict(
    arch="Gemma4",
    active="4B",
    total="4B",
    ctx="128K",
    thinking=True,
    vision=True,
    audio=True,
    max_tokens=16384,
    temp_max=2.0,
)

# System-prompt presets keyed by task name.
PRESETS = dict(
    general="You are a highly capable multimodal AI assistant. Think deeply and provide thorough, insightful responses.",
    code="You are an expert software engineer. Write clean, efficient, well-commented code.",
    math="You are a world-class mathematician. Break problems step-by-step. Show full working.",
    creative="You are a brilliant creative writer. Be imaginative, vivid, and engaging.",
    vision="You are an expert at analyzing images. Describe what you see in detail, extract text, and answer questions about visual content.",
)
| # ============================================================================== | |
| # 2. MTI -- Minimal Test-Time Intervention (arxiv 2510.13940) | |
| # Transformers LogitsProcessor API: __call__(input_ids, scores) -> scores | |
| # ============================================================================== | |
| from transformers import LogitsProcessor, LogitsProcessorList | |
class MTILogitsProcessor(LogitsProcessor):
    """Sharpen the next-token distribution only at high-entropy steps.

    For every row of ``scores`` whose softmax entropy exceeds
    ``entropy_threshold``, each logit is pushed away from the row mean by
    ``cfg_scale`` times its distance from that mean. Rows at or below the
    threshold are returned unchanged. Two counters record how many rows were
    seen and how many were modified.

    Args:
        cfg_scale: Multiplier applied to ``scores - mean`` before it is added
            back onto ``scores``.
        entropy_threshold: Entropy (natural log) above which a row is modified.
    """

    def __init__(self, cfg_scale: float = 1.5, entropy_threshold: float = 2.0):
        self.cfg_scale = cfg_scale
        self.entropy_threshold = entropy_threshold
        self._interventions = 0  # rows that were sharpened
        self._total = 0          # rows seen

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with high-entropy rows sharpened.

        Args:
            input_ids: Tokens generated so far; not read here.
            scores: Next-token logits, indexed as (batch_size, vocab_size).

        Returns:
            A tensor shaped like ``scores``.
        """
        self._total += int(scores.shape[0])
        probs = torch.softmax(scores, dim=-1)
        entropy = -(probs * torch.log(probs.clamp_min(1e-10))).sum(dim=-1)  # (batch_size,)
        mask = entropy > self.entropy_threshold  # (batch_size,)
        if bool(mask.any()):
            # A logit that is already -inf (masked by an earlier processor)
            # would drag a plain mean to -inf and turn the whole row into
            # inf/nan. Average the finite logits only and leave the
            # non-finite ones exactly as they were.
            finite = torch.isfinite(scores)
            finite_count = finite.sum(dim=-1, keepdim=True).clamp_min(1)
            finite_sum = torch.where(finite, scores, torch.zeros_like(scores)).sum(
                dim=-1, keepdim=True, dtype=torch.float32,
            )
            mean_logit = (finite_sum / finite_count).to(scores.dtype)
            guided = scores + self.cfg_scale * (scores - mean_logit)
            scores = torch.where(mask.unsqueeze(-1) & finite, guided, scores)
            self._interventions += int(mask.sum().item())
        return scores

    def intervention_rate(self):
        """Return the fraction of rows seen so far that were sharpened (0 when none seen)."""
        return self._interventions / max(self._total, 1)


print("[MTI] MTILogitsProcessor ready (cfg=1.5, threshold=2.0)", flush=True)
| # ============================================================================== | |
| # 3. TOKENIZER + MODEL LOAD (Transformers from source) | |
| # ============================================================================== | |
| from transformers import ( | |
| AutoTokenizer, | |
| Gemma4ForConditionalGeneration, | |
| TextIteratorStreamer, | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| import tempfile, shutil | |
| # ---- Tokenizer with extra_special_tokens patch ---- | |
| # Transformers 5.5.x (git) has a regression where tokenizer_config.json with | |
| # extra_special_tokens stored as a list crashes during load (.keys() call on | |
| # a list). We pre-download, patch if needed, then load from the local copy. | |
# Stage the tokenizer files in a private temp directory so that
# tokenizer_config.json can be rewritten before AutoTokenizer reads it.
_tok_source = MODEL_ID
_tok_dir = tempfile.mkdtemp(prefix="darwin_tok_")
for _fname in ["tokenizer_config.json", "tokenizer.json", "tokenizer.model",
               "special_tokens_map.json", "chat_template.jinja"]:
    try:
        _p = hf_hub_download(_tok_source, _fname)
        shutil.copy(_p, os.path.join(_tok_dir, _fname))
    except Exception as _exc:
        # Best effort: a file that cannot be fetched or copied is skipped and
        # the load below decides whether what is left is enough. Log the skip
        # so a missing file is visible in the boot output.
        print(f"[Tokenizer] Skipped {_fname}: {type(_exc).__name__}: {_exc}", flush=True)

_tc_path = os.path.join(_tok_dir, "tokenizer_config.json")
if os.path.exists(_tc_path):
    try:
        # Explicit UTF-8: the config is JSON and must not depend on the locale.
        with open(_tc_path, encoding="utf-8") as f:
            _tc = json.load(f)
        est = _tc.get("extra_special_tokens", None)
        if isinstance(est, list):
            # Rewrite the list as a token -> token mapping (empty list -> {}).
            _tc["extra_special_tokens"] = {tok: tok for tok in est}
            with open(_tc_path, "w", encoding="utf-8") as f:
                json.dump(_tc, f, indent=2)
            print(f"[Tokenizer] Patched extra_special_tokens: list({len(est)}) -> dict", flush=True)
    except Exception as e:
        # The patch is optional; fall through and try loading the files as-is.
        print(f"[Tokenizer] Patch skipped: {e}", flush=True)

tokenizer = AutoTokenizer.from_pretrained(_tok_dir)
print(f"[Tokenizer] Loaded (vocab={len(tokenizer)}) from {_tok_source}", flush=True)
| # ---- Model ---- | |
print(f"[Transformers] Loading {MODEL_ID} (this may take a while for a 16GB checkpoint)...", flush=True)

# Keyword arguments for from_pretrained; the dtype keyword is tried first.
_load_kwargs = {
    "dtype": torch.bfloat16,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
}
try:
    model = Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, **_load_kwargs)
except TypeError:
    # Retry once with the keyword spelled torch_dtype instead of dtype.
    _load_kwargs = {
        ("torch_dtype" if _key == "dtype" else _key): _value
        for _key, _value in _load_kwargs.items()
    }
    model = Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, **_load_kwargs)

model.eval()
# Inputs are moved to the device of the first parameter before generation.
_device = next(model.parameters()).device
print(f"[Transformers] Model loaded on {_device}", flush=True)

# Look up the text sub-config; fall back to a text_config attribute, then to
# the top-level config, when get_text_config is not available.
try:
    _text_cfg = model.config.get_text_config()
except AttributeError:
    _text_cfg = getattr(model.config, "text_config", model.config)

MAX_MODEL_LEN = int(getattr(_text_cfg, "max_position_embeddings", 16384))

# Never advertise more new tokens than the configured position limit.
if MODEL_CAP["max_tokens"] > MAX_MODEL_LEN:
    MODEL_CAP["max_tokens"] = MAX_MODEL_LEN

print(f"[Transformers] max_position_embeddings={MAX_MODEL_LEN}, "
      f"max_tokens={MODEL_CAP['max_tokens']}", flush=True)
BACKEND_NAME = "Transformers"
| # ============================================================================== | |
| # 4. THINKING MODE HELPERS | |
| # ============================================================================== | |
def parse_think_blocks(text: str) -> tuple[str, str]:
    """Split model output into ``(reasoning, answer)``.

    The ``<|channel|>thought ... <channel|>`` form is tried first, then
    ``<think>...</think>``. On a match, the reasoning is the stripped text
    between the markers and the answer is the stripped text after the match;
    anything before the opening marker is not returned. Without a match the
    result is ``("", text)`` with ``text`` untouched.
    """
    marker_patterns = (
        r"<\|channel\|>thought\s*\n(.*?)<channel\|>",
        r"<think>(.*?)</think>\s*",
    )
    for pattern in marker_patterns:
        found = re.search(pattern, text, re.DOTALL)
        if found is not None:
            reasoning = found.group(1).strip()
            answer = text[found.end():].strip()
            return reasoning, answer
    return "", text
def format_response(raw: str) -> str:
    """Render raw model output for display in the chat.

    * Completed reasoning plus answer: the reasoning goes inside a collapsed
      ``<details>`` element, followed by the answer.
    * Reasoning opened but not yet closed: a progress line with the number of
      characters after the opening marker.
    * Anything else: ``raw`` unchanged.
    """
    reasoning, answer = parse_think_blocks(raw)
    if reasoning:
        header = "<details>\n<summary>🧠 Reasoning Chain -- click to expand</summary>\n\n"
        return header + f"{reasoning}\n\n</details>\n\n{answer}"

    # (opening marker, closing marker), checked in the same order as parsing.
    unfinished_markers = (
        ("<|channel|>thought", "<channel|>"),
        ("<think>", "</think>"),
    )
    for opener, closer in unfinished_markers:
        if opener in raw and closer not in raw:
            think_len = len(raw) - raw.index(opener) - len(opener)
            return f"🧠 Thinking... ({think_len} chars)"
    return raw
| # ============================================================================== | |
| # 5. GENERATION -- Transformers TextIteratorStreamer + MTI | |
| # ============================================================================== | |
def _engine_generate(prompt_text: str, gen_kwargs: dict, mti: MTILogitsProcessor, queue: Queue):
    """Run ``model.generate`` on a helper thread and forward text to ``queue``.

    Every non-empty chunk produced by the streamer is put on ``queue`` as a
    string. A final ``None`` marks the end of the stream. On failure a
    formatted error string is put first, followed by ``None``.

    Args:
        prompt_text: Fully rendered prompt (chat template already applied).
        gen_kwargs: Extra keyword arguments for ``model.generate``; they
            override the defaults assembled here.
        mti: Logits processor instance attached to this generation.
        queue: Destination for text chunks and the ``None`` sentinel.
    """
    try:
        # The rendered chat template may already start with the BOS token;
        # adding special tokens again would then produce a second BOS.
        bos = getattr(tokenizer, "bos_token", None)
        add_special = not (bos and prompt_text.startswith(bos))
        inputs = tokenizer(
            prompt_text, return_tensors="pt", add_special_tokens=add_special,
        ).to(_device)
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=120.0,
        )
        # Compare against None explicitly: a pad id of 0 is valid, and `or`
        # would silently replace it with the EOS id.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        full_kwargs = {
            **inputs,
            "streamer": streamer,
            "logits_processor": LogitsProcessorList([mti]),
            "pad_token_id": pad_id,
            **gen_kwargs,
        }
        failures: list = []

        def _run_generate():
            # Capture a failure and close the streamer so the consumer loop
            # below stops right away instead of waiting for its timeout.
            try:
                model.generate(**full_kwargs)
            except Exception as exc:
                failures.append(exc)
                streamer.end()

        gen_thread = Thread(target=_run_generate)
        gen_thread.start()
        for chunk in streamer:
            if chunk:
                queue.put(chunk)
        gen_thread.join()
        if failures:
            raise failures[0]
        queue.put(None)
    except Exception as e:
        queue.put(f"\n\n**❌ Engine error:** `{e}`")
        queue.put(None)
# Matches the collapsed reasoning wrapper that format_response puts in front
# of an answer, so it can be removed from earlier assistant turns.
_REASONING_DETAILS_RE = re.compile(
    r"\A\s*<details>\n<summary>🧠 Reasoning Chain -- click to expand</summary>\n\n"
    r".*?\n\n</details>\s*",
    re.DOTALL,
)


def _message_text(content) -> str:
    """Return the text of a message content value (plain value or list of parts)."""
    if isinstance(content, list):
        return " ".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def _strip_reasoning(text: str) -> str:
    """Drop raw think markers and a leading rendered reasoning block from ``text``."""
    _, clean = parse_think_blocks(text)
    return _REASONING_DETAILS_RE.sub("", clean, count=1)


def _history_to_messages(history) -> list[dict]:
    """Convert chat history (role dicts or ``[user, assistant]`` pairs) to messages."""
    messages: list[dict] = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role", "")
            text = _message_text(turn.get("content") or "")
            if role == "user":
                messages.append({"role": "user", "content": text})
            elif role == "assistant":
                messages.append({"role": "assistant", "content": _strip_reasoning(text)})
            continue
        try:
            user_part = turn[0] or None
            bot_part = turn[1] if len(turn) > 1 else None
        except (TypeError, IndexError, KeyError):
            # Not an indexable pair; ignore this entry.
            continue
        user_text = None if user_part is None else _message_text(user_part)
        bot_text = None if bot_part is None else _message_text(bot_part)
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if bot_text:
            messages.append({"role": "assistant", "content": _strip_reasoning(bot_text)})
    return messages


def generate_reply(
    message, history, thinking_mode, image_input,
    system_prompt, max_new_tokens, temperature, top_p,
) -> Generator[str, None, None]:
    """Stream a formatted reply for ``message`` given the prior ``history``.

    Args:
        message: The new user message.
        history: Earlier turns, as role dicts or ``[user, assistant]`` pairs.
        thinking_mode: Accepted for interface compatibility; not read here.
        image_input: Accepted for interface compatibility; not read here.
        system_prompt: Optional system message; ignored when blank.
        max_new_tokens: Requested budget, capped at ``MODEL_CAP["max_tokens"]``.
        temperature: Capped at ``MODEL_CAP["temp_max"]``; values of 0.01 or
            below switch sampling off.
        top_p: Nucleus sampling parameter, passed through.

    Yields:
        The formatted response so far after every received chunk, then the
        final formatted response, or a single error or warning message.
    """
    max_new_tokens = min(int(max_new_tokens), MODEL_CAP["max_tokens"])
    temperature = min(float(temperature), MODEL_CAP["temp_max"])

    messages: list[dict] = []
    system_text = (system_prompt or "").strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    try:
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
    except Exception as e:
        yield f"**❌ Template error:** `{e}`"
        return

    input_len = len(tokenizer.encode(prompt_text))
    print(f"[GEN] tokens={input_len}, max_new={max_new_tokens}, "
          f"temp={temperature}, MTI=on, Backend={BACKEND_NAME}", flush=True)

    mti = MTILogitsProcessor(cfg_scale=1.5, entropy_threshold=2.0)
    do_sample = float(temperature) > 0.01
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=max(float(temperature), 0.01) if do_sample else 1.0,
        top_p=float(top_p),
    )

    queue: Queue = Queue()
    thread = Thread(target=_engine_generate, args=(prompt_text, gen_kwargs, mti, queue))
    thread.start()

    output = ""
    stream_error: Optional[Exception] = None
    try:
        while True:
            chunk = queue.get(timeout=120)
            if chunk is None:
                break
            output += chunk
            yield format_response(output)
    except Exception as e:
        stream_error = e
    thread.join(timeout=5)

    if output:
        # intervention_rate is a method; it has to be called before scaling.
        mti_rate = f"{mti.intervention_rate() * 100:.1f}%"
        print(f"[GEN] Done -- {len(output)} chars, MTI={mti_rate} "
              f"({mti._interventions}/{mti._total})", flush=True)
        yield format_response(output)
    elif stream_error is not None:
        # Report the failure on its own so that the empty-response warning
        # below does not replace it.
        yield f"**❌ Streaming error:** `{stream_error}`"
    else:
        yield "**⚠️ 모델이 빈 응답을 반환했습니다.** 다시 시도해 주세요."
| # ============================================================================== | |
| # 6. GRADIO BLOCKS | |
| # ============================================================================== | |
with gr.Blocks(title=MODEL_NAME) as gradio_demo:
    # Every control below is created with visible=False; the values reach
    # generate_reply as the chat interface's additional inputs.
    thinking_toggle = gr.Radio(
        choices=["⚡ Fast Mode", "🧠 Thinking Mode"],
        value="⚡ Fast Mode",
        visible=False,
    )
    image_input = gr.Textbox(value="", visible=False)
    system_prompt = gr.Textbox(value=PRESETS["general"], visible=False)
    max_new_tokens = gr.Slider(
        minimum=64,
        maximum=MODEL_CAP["max_tokens"],
        value=4096,
        visible=False,
    )
    temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, visible=False)
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, visible=False)

    # Order matches the parameters of generate_reply after (message, history).
    _chat_inputs = [
        thinking_toggle,
        image_input,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
    ]
    gr.ChatInterface(
        fn=generate_reply,
        api_name="chat",
        additional_inputs=_chat_inputs,
    )
| # ============================================================================== | |
| # 7. FASTAPI | |
| # ============================================================================== | |
fapp = FastAPI()

# In-memory session store: session id (cookie value) -> user info dict.
SESSIONS: dict[str, dict] = {}

# Landing page expected next to this file.
HTML = pathlib.Path(__file__).with_name("index.html")

# OAuth client settings, read from the environment.
CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
SPACE_HOST = os.environ.get("SPACE_HOST", "localhost:7860")
REDIRECT_URI = "https://" + SPACE_HOST + "/login/callback"

# Hugging Face OAuth endpoints.
HF_AUTH_URL = "https://huggingface.co/oauth/authorize"
HF_TOKEN_URL = "https://huggingface.co/oauth/token"
HF_USER_URL = "https://huggingface.co/oauth/userinfo"
SCOPES = os.environ.get("OAUTH_SCOPES", "openid profile")

# Log only whether a client id is set, not its value.
print(f"[OAuth] CLIENT_ID={bool(CLIENT_ID)}, SPACE_HOST={SPACE_HOST}")
| def _sid(req): return req.cookies.get("mc_session") | |
def _user(req):
    """Return the stored session dict for the request's session cookie, or ``None``."""
    session_id = _sid(req)
    if not session_id:
        return None
    return SESSIONS.get(session_id)
async def root(request: Request):
    """Serve ``index.html``, or a short placeholder page when the file is missing.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm how it is registered on ``fapp``.
    """
    if HTML.exists():
        page = HTML.read_text(encoding="utf-8")
    else:
        page = "<h2>index.html missing</h2>"
    return HTMLResponse(page)
async def oauth_user(request: Request):
    """Return the current session's user info as JSON, or a 401 when there is none."""
    session = _user(request)
    if not session:
        return JSONResponse({"logged_in": False}, status_code=401)
    return JSONResponse(session)
async def oauth_login(request: Request):
    """Redirect the browser to the Hugging Face authorization endpoint.

    A random ``state`` value is sent with the request and also stored in a
    short-lived ``mc_oauth_state`` cookie so that ``oauth_callback`` can check
    that the response belongs to a login started from this browser.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm how it is registered on ``fapp``.
    """
    if not CLIENT_ID:
        return RedirectResponse("/?oauth_error=not_configured")
    state = secrets.token_urlsafe(16)
    params = {
        "response_type": "code",
        "client_id": CLIENT_ID,
        "redirect_uri": REDIRECT_URI,
        "scope": SCOPES,
        "state": state,
    }
    resp = RedirectResponse(f"{HF_AUTH_URL}?{urlencode(params)}", status_code=302)
    resp.set_cookie("mc_oauth_state", state, httponly=True, samesite="lax",
                    secure=True, max_age=600)
    return resp


async def oauth_callback(code: str = "", error: str = "", state: str = "",
                         request: Request = None):
    """Finish the OAuth flow: verify ``state``, fetch the user, create a session.

    Args:
        code: Authorization code returned by the provider.
        error: Error indicator returned by the provider, if any.
        state: The ``state`` value echoed back by the provider.
        request: Incoming request, used only to read the ``mc_oauth_state``
            cookie set by ``oauth_login``.

    Returns:
        A redirect to ``/`` that sets the ``mc_session`` cookie on success, or
        a redirect to ``/?auth_error=1`` on any failure.
    """
    if error or not code:
        return RedirectResponse("/?auth_error=1")

    # Reject the callback unless the echoed state matches the value stored at
    # login time; without this check the flow is open to login CSRF.
    expected = request.cookies.get("mc_oauth_state", "") if request is not None else ""
    if not state or not expected or not secrets.compare_digest(
        state.encode("utf-8"), expected.encode("utf-8")
    ):
        return RedirectResponse("/?auth_error=1")

    basic = base64.b64encode(f"{CLIENT_ID}:{CLIENT_SECRET}".encode()).decode()
    async with httpx.AsyncClient() as client:
        tok = await client.post(
            HF_TOKEN_URL,
            data={"grant_type": "authorization_code", "code": code, "redirect_uri": REDIRECT_URI},
            headers={"Accept": "application/json", "Authorization": f"Basic {basic}"},
        )
        if tok.status_code != 200:
            return RedirectResponse("/?auth_error=1")
        access_token = tok.json().get("access_token", "")
        if not access_token:
            return RedirectResponse("/?auth_error=1")
        uinfo = await client.get(HF_USER_URL, headers={"Authorization": f"Bearer {access_token}"})
        if uinfo.status_code != 200:
            return RedirectResponse("/?auth_error=1")
        user = uinfo.json()

    sid = secrets.token_urlsafe(32)
    SESSIONS[sid] = {
        "logged_in": True,
        "username": user.get("preferred_username", user.get("name", "User")),
        "name": user.get("name", ""),
        "avatar": user.get("picture", ""),
        "profile": f"https://huggingface.co/{user.get('preferred_username', '')}",
    }
    resp = RedirectResponse("/")
    resp.set_cookie("mc_session", sid, httponly=True, samesite="lax", secure=True, max_age=60*60*24*7)
    # The state value is single-use.
    resp.delete_cookie("mc_oauth_state")
    return resp
async def oauth_logout(request: Request):
    """Forget the caller's session, clear its cookie, and redirect to ``/``."""
    session_id = _sid(request)
    if session_id:
        SESSIONS.pop(session_id, None)
    response = RedirectResponse("/")
    response.delete_cookie("mc_session")
    return response
async def health():
    """Return a static status report describing the loaded model and its limits."""
    report = dict(
        status="ok",
        model=MODEL_ID,
        backend=BACKEND_NAME,
        mti="enabled",
    )
    report["max_tokens"] = MODEL_CAP["max_tokens"]
    report["max_model_len"] = MAX_MODEL_LEN
    report["multimodal"] = "vision+audio"
    return report
BRAVE_API_KEY = os.getenv("BRAVE_API_KEY", "")


async def api_search(request: Request):
    """Proxy a web search to the Brave Search API.

    Expects a JSON object body with a ``query`` string. Responds with
    ``{"results": [{"title", "desc", "url"}, ...]}`` (at most five entries),
    ``{"error": "empty"}`` with status 400 when no usable query was sent, or
    ``{"error": ...}`` with status 500 when the key is missing or the upstream
    call fails.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm how it is registered on ``fapp``.
    """
    # A malformed or non-object body is treated like a missing query rather
    # than escaping as an unhandled exception.
    try:
        body = await request.json()
    except ValueError:
        body = None
    query = body.get("query", "") if isinstance(body, dict) else ""
    query = query.strip() if isinstance(query, str) else ""
    if not query:
        return JSONResponse({"error": "empty"}, 400)
    if not BRAVE_API_KEY:
        return JSONResponse({"error": "no key"}, 500)
    try:
        # Use the async client so the event loop is not blocked while the
        # upstream request is in flight.
        async with httpx.AsyncClient(timeout=10) as client:
            r = await client.get(
                "https://api.search.brave.com/res/v1/web/search",
                headers={"X-Subscription-Token": BRAVE_API_KEY, "Accept": "application/json"},
                params={"q": query, "count": 5},
            )
        r.raise_for_status()
        results = r.json().get("web", {}).get("results", [])
        return JSONResponse({"results": [
            {"title": i.get("title", ""), "desc": i.get("description", ""), "url": i.get("url", "")}
            for i in results[:5]
        ]})
    except Exception as e:
        return JSONResponse({"error": str(e)}, 500)
async def api_extract_pdf(request: Request):
    """Extract text from a base64-encoded PDF sent in the JSON body.

    The body's ``data`` field may be bare base64 or carry a prefix ending in a
    comma; everything up to the first comma is discarded. When the ``fitz``
    module cannot be imported, the decoded bytes are interpreted as UTF-8 text
    instead.

    Returns:
        ``{"text": <first 8000 chars of the stripped text>, "chars": <length
        of the unstripped text>}``, or ``{"error": ...}`` with status 500 on
        any failure.
    """
    try:
        body = await request.json()
        b64 = body.get("data", "")
        if "," in b64:
            b64 = b64.split(",", 1)[1]
        pdf_bytes = base64.b64decode(b64)
        # Only the import itself is guarded, so an ImportError raised while
        # parsing is not mistaken for a missing module.
        try:
            import fitz
        except ImportError:
            text = pdf_bytes.decode("utf-8", errors="ignore")
        else:
            doc = fitz.open(stream=pdf_bytes, filetype="pdf")
            try:
                text = "".join(page.get_text() + "\n" for page in doc)
            finally:
                # Release the document even when a page fails to parse.
                doc.close()
        return JSONResponse({"text": text.strip()[:8000], "chars": len(text)})
    except Exception as e:
        return JSONResponse({"error": str(e)}, 500)
| # ============================================================================== | |
| # 8. MOUNT & RUN | |
| # ============================================================================== | |
# Serve the Gradio UI under /gradio on the FastAPI application.
app = gr.mount_gradio_app(fapp, gradio_demo, path="/gradio")


def _shutdown(sig, frame):
    """Signal handler: log the shutdown and exit the process with status 0."""
    print("[BOOT] Shutdown", flush=True)
    raise SystemExit(0)


# Install the same handler for termination and interrupt signals.
for _signum in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_signum, _shutdown)

if __name__ == "__main__":
    print(f"[BOOT] {MODEL_NAME} - {BACKEND_NAME} - MTI - max_len={MAX_MODEL_LEN} - Ready", flush=True)
    uvicorn.run(app, host="0.0.0.0", port=7860)