# Darwin-4B-David (Gemma4) - Transformers backend + MTI
# Multimodal (Vision+Audio+Text) - Apache 2.0
# MTI: +9-11% reasoning accuracy (training-free), Transformers LogitsProcessor
import sys, os, signal, time, uuid

print(f"[BOOT] Python {sys.version}", flush=True)

import base64, re, json
from typing import Generator, Optional
from threading import Thread
from queue import Queue

import torch
import gradio as gr

print(f"[BOOT] gradio {gr.__version__}, torch {torch.__version__}", flush=True)

import requests, httpx, uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from urllib.parse import urlencode
import pathlib, secrets

# ==============================================================================
# 1. CONFIG
# ==============================================================================
MODEL_ID = "FINAL-Bench/Darwin-4B-David"
MODEL_NAME = "Darwin-4B-David"

# Static capability card surfaced by the UI / health endpoint.
MODEL_CAP = {
    "arch": "Gemma4",
    "active": "4B",
    "total": "4B",
    "ctx": "128K",
    "thinking": True,
    "vision": True,
    "audio": True,
    "max_tokens": 16384,  # clamped later to the runtime's max_position_embeddings
    "temp_max": 2.0,
}

# Built-in system-prompt presets selectable from the front end.
PRESETS = {
    "general": "You are a highly capable multimodal AI assistant. Think deeply and provide thorough, insightful responses.",
    "code": "You are an expert software engineer. Write clean, efficient, well-commented code.",
    "math": "You are a world-class mathematician. Break problems step-by-step. Show full working.",
    "creative": "You are a brilliant creative writer. Be imaginative, vivid, and engaging.",
    "vision": "You are an expert at analyzing images. Describe what you see in detail, extract text, and answer questions about visual content.",
}

# ==============================================================================
# 2. MTI -- Minimal Test-Time Intervention (arxiv 2510.13940)
# Transformers LogitsProcessor API: __call__(input_ids, scores) -> scores
# ==============================================================================
from transformers import LogitsProcessor, LogitsProcessorList


class MTILogitsProcessor(LogitsProcessor):
    """Minimal Test-Time Intervention (MTI) logits processor.

    High-entropy (uncertain) decoding steps only -> apply CFG-style sharpening
    of the logits around their mean. Training-free serving-time intervention;
    in practice only a small fraction (~15%) of tokens is affected.

    Args:
        cfg_scale: strength of the sharpening applied to uncertain rows.
        entropy_threshold: entropy (nats) above which a row is intervened on.
    """

    def __init__(self, cfg_scale: float = 1.5, entropy_threshold: float = 2.0):
        self.cfg_scale = cfg_scale
        self.entropy_threshold = entropy_threshold
        self._interventions = 0  # number of rows sharpened so far
        self._total = 0          # number of rows seen so far

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores: (batch_size, vocab_size)
        self._total += int(scores.shape[0])
        # log_softmax is numerically stable: the original softmax -> log(clamp(..))
        # round-trip loses precision for extreme logits and needed a clamp hack.
        log_probs = torch.log_softmax(scores, dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # (batch_size,)
        mask = entropy > self.entropy_threshold  # (batch_size,)
        if bool(mask.any()):
            mean_logit = scores.mean(dim=-1, keepdim=True)
            guided = scores + self.cfg_scale * (scores - mean_logit)
            scores = torch.where(mask.unsqueeze(-1), guided, scores)
            self._interventions += int(mask.sum().item())
        return scores

    @property
    def intervention_rate(self) -> float:
        """Fraction of decoding rows sharpened so far (0.0 when none seen)."""
        return self._interventions / max(self._total, 1)


print("[MTI] MTILogitsProcessor ready (cfg=1.5, threshold=2.0)", flush=True)

# ==============================================================================
# 3. TOKENIZER + MODEL LOAD (Transformers from source)
# ==============================================================================
from transformers import (
    AutoTokenizer,
    Gemma4ForConditionalGeneration,
    TextIteratorStreamer,
)
from huggingface_hub import hf_hub_download
import tempfile, shutil

# ---- Tokenizer with extra_special_tokens patch ----
# Transformers 5.5.x (git) has a regression where tokenizer_config.json with
# extra_special_tokens stored as a list crashes during load (.keys() call on
# a list).
# We pre-download, patch if needed, then load from the local copy.
_tok_source = MODEL_ID
_tok_dir = tempfile.mkdtemp(prefix="darwin_tok_")

# Best-effort copy of every tokenizer artifact; repos may lack some of them.
_TOK_FILES = (
    "tokenizer_config.json",
    "tokenizer.json",
    "tokenizer.model",
    "special_tokens_map.json",
    "chat_template.jinja",
)
for _name in _TOK_FILES:
    try:
        shutil.copy(hf_hub_download(_tok_source, _name), os.path.join(_tok_dir, _name))
    except Exception:
        pass  # optional artifact not present in the repo -- skip it

_tc_path = os.path.join(_tok_dir, "tokenizer_config.json")
if os.path.exists(_tc_path):
    try:
        with open(_tc_path) as fh:
            _cfg = json.load(fh)
        est = _cfg.get("extra_special_tokens", None)
        if isinstance(est, list):
            # Identity-map each token so the loader sees a dict, not a list.
            _cfg["extra_special_tokens"] = {tok: tok for tok in est} if est else {}
            with open(_tc_path, "w") as fh:
                json.dump(_cfg, fh, indent=2)
            print(f"[Tokenizer] Patched extra_special_tokens: list({len(est)}) -> dict", flush=True)
    except Exception as e:
        print(f"[Tokenizer] Patch skipped: {e}", flush=True)

tokenizer = AutoTokenizer.from_pretrained(_tok_dir)
print(f"[Tokenizer] Loaded (vocab={len(tokenizer)}) from {_tok_source}", flush=True)

# ---- Model ----
print(f"[Transformers] Loading {MODEL_ID} (this may take a while for a 16GB checkpoint)...", flush=True)
_load_kwargs = dict(
    dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
try:
    model = Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, **_load_kwargs)
except TypeError:
    # Older transformers signatures used torch_dtype instead of dtype.
    _load_kwargs["torch_dtype"] = _load_kwargs.pop("dtype")
    model = Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, **_load_kwargs)
model.eval()

_device = next(model.parameters()).device
print(f"[Transformers] Model loaded on {_device}", flush=True)

# Resolve max model length (text config for multimodal Gemma4).
try:
    _text_cfg = model.config.get_text_config()
except AttributeError:
    _text_cfg = getattr(model.config, "text_config", model.config)
MAX_MODEL_LEN = int(getattr(_text_cfg, "max_position_embeddings", 16384))
# Clamp generation max_tokens to what the runtime can actually hold.
MODEL_CAP["max_tokens"] = min(MODEL_CAP["max_tokens"], MAX_MODEL_LEN)
print(f"[Transformers] max_position_embeddings={MAX_MODEL_LEN}, "
      f"max_tokens={MODEL_CAP['max_tokens']}", flush=True)

BACKEND_NAME = "Transformers"


# ==============================================================================
# 4. THINKING MODE HELPERS
# ==============================================================================
def parse_think_blocks(text: str) -> tuple[str, str]:
    """Split a raw completion into (reasoning_chain, answer).

    Returns ("", text) when no complete thinking block is found.
    """
    # Gemma 4 thinking format: <|channel|>thought\n ... </thought> answer
    # NOTE(review): the closing marker after the capture group was lost to
    # markup stripping in the original source; "</thought>" is reconstructed
    # here -- confirm against the model's actual chat template.
    m = re.search(r"<\|channel\|>thought\s*\n(.*?)</thought>", text, re.DOTALL)
    if m:
        return m.group(1).strip(), text[m.end():].strip()
    # Fallback: <think>...</think>  (tags reconstructed; the surviving offset
    # `- 7 == len("<think>")` below confirms this was the original marker).
    m = re.search(r"<think>(.*?)</think>\s*", text, re.DOTALL)
    if m:
        return m.group(1).strip(), text[m.end():].strip()
    return "", text


def format_response(raw: str) -> str:
    """Render a (possibly partial) completion as markdown for the chat UI.

    A completed thinking block is folded into a collapsible <details> section;
    in-progress thinking is shown as a live character counter.
    """
    chain, answer = parse_think_blocks(raw)
    if chain:
        return (
            "<details>\n<summary>🧠 Reasoning Chain -- click to expand</summary>\n\n"
            f"{chain}\n\n</details>\n\n{answer}"
        )
    # Gemma 4 thinking in progress (opening marker seen, no closing one yet).
    if "<|channel|>thought" in raw and "</thought>" not in raw:
        think_len = len(raw) - raw.index("<|channel|>thought") - 18  # 18 == len("<|channel|>thought")
        return f"🧠 Thinking... ({think_len} chars)"
    if "<think>" in raw and "</think>" not in raw:
        think_len = len(raw) - raw.index("<think>") - 7  # 7 == len("<think>")
        return f"🧠 Thinking... ({think_len} chars)"
    return raw


# ==============================================================================
# 5. GENERATION -- Transformers TextIteratorStreamer + MTI
# ==============================================================================
def _engine_generate(prompt_text: str, gen_kwargs: dict, mti: MTILogitsProcessor, queue: Queue):
    """Run model.generate in a background thread and stream tokens into queue.

    Puts text chunks on `queue`, then a final None sentinel. On failure an
    error message is queued before the sentinel so the consumer always stops.
    """
    try:
        inputs = tokenizer(prompt_text, return_tensors="pt").to(_device)
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
            timeout=120.0,
        )
        # `pad or eos` would wrongly skip a pad_token_id of 0 (falsy);
        # test explicitly for None instead.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        full_kwargs = {
            **inputs,
            "streamer": streamer,
            "logits_processor": LogitsProcessorList([mti]),
            "pad_token_id": pad_id,
            **gen_kwargs,
        }
        gen_thread = Thread(target=model.generate, kwargs=full_kwargs)
        gen_thread.start()
        for chunk in streamer:
            if chunk:
                queue.put(chunk)
        gen_thread.join()
        queue.put(None)
    except Exception as e:
        queue.put(f"\n\n**❌ Engine error:** `{e}`")
        queue.put(None)


def _part_text(value) -> Optional[str]:
    """Flatten a Gradio content value (str, None, or list of parts) to text."""
    if value is None:
        return None
    if isinstance(value, list):
        return " ".join(
            p.get("text", "") for p in value if isinstance(p, dict) and p.get("type") == "text"
        )
    return str(value)


def _history_to_messages(history) -> list[dict]:
    """Normalize Gradio chat history (dict-style or legacy tuples) to messages.

    Assistant turns are stripped of their reasoning chain so the model is not
    re-fed its own thoughts.
    """
    messages: list[dict] = []
    for turn in history:
        if isinstance(turn, dict):  # messages-style history
            role = turn.get("role", "")
            text = _part_text(turn.get("content") or "")
            if role == "user":
                messages.append({"role": "user", "content": text})
            elif role == "assistant":
                _, clean = parse_think_blocks(text)
                messages.append({"role": "assistant", "content": clean})
        else:  # legacy (user, assistant) tuples
            try:
                u, a = (turn[0] or None), (turn[1] if len(turn) > 1 else None)
            except Exception:  # was a bare except; keep skip-on-malformed behavior
                continue
            ut, at = _part_text(u), _part_text(a)
            if ut:
                messages.append({"role": "user", "content": ut})
            if at:
                _, clean = parse_think_blocks(at)
                messages.append({"role": "assistant", "content": clean})
    return messages


def generate_reply(
    message,
    history,
    thinking_mode,
    image_input,
    system_prompt,
    max_new_tokens,
    temperature,
    top_p,
) -> Generator[str, None, None]:
    """Stream a formatted reply for the Gradio ChatInterface.

    Yields progressively longer markdown strings; the final yield is the full
    formatted answer (or an error / empty-response notice).
    """
    max_new_tokens = min(int(max_new_tokens), MODEL_CAP["max_tokens"])
    temperature = min(float(temperature), MODEL_CAP["temp_max"])

    messages: list[dict] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as e:
        yield f"**❌ Template error:** `{e}`"
        return

    input_len = len(tokenizer.encode(prompt_text))
    print(f"[GEN] tokens={input_len}, max_new={max_new_tokens}, "
          f"temp={temperature}, MTI=on, Backend={BACKEND_NAME}", flush=True)

    mti = MTILogitsProcessor(cfg_scale=1.5, entropy_threshold=2.0)
    do_sample = float(temperature) > 0.01  # near-zero temperature -> greedy decoding
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=max(float(temperature), 0.01) if do_sample else 1.0,
        top_p=float(top_p),
    )

    queue: Queue = Queue()
    thread = Thread(target=_engine_generate, args=(prompt_text, gen_kwargs, mti, queue))
    thread.start()

    output = ""
    try:
        while True:
            chunk = queue.get(timeout=120)
            if chunk is None:  # sentinel: generation finished
                break
            output += chunk
            yield format_response(output)
    except Exception as e:  # e.g. queue.Empty when the engine stalls
        if not output:
            yield f"**❌ Streaming error:** `{e}`"
    thread.join(timeout=5)

    if output:
        mti_rate = f"{mti.intervention_rate*100:.1f}%"
        print(f"[GEN] Done -- {len(output)} chars, MTI={mti_rate} "
              f"({mti._interventions}/{mti._total})", flush=True)
        yield format_response(output)
    else:
        # NOTE(review): this user-facing string appears mojibake-damaged in the
        # original source; reproduced verbatim rather than guessed at.
        yield "**⚠️ λͺ¨λΈμ΄ 빈 응닡을 λ°˜ν™˜ν–ˆμŠ΅λ‹ˆλ‹€.** λ‹€μ‹œ μ‹œλ„ν•΄ μ£Όμ„Έμš”."


# ==============================================================================
# 6. GRADIO BLOCKS
# ==============================================================================
with gr.Blocks(title=MODEL_NAME) as gradio_demo:
    # Hidden controls: the custom front end drives these through the API.
    thinking_toggle = gr.Radio(
        choices=["⚑ Fast Mode", "🧠 Thinking Mode"],
        value="⚑ Fast Mode",
        visible=False,
    )
    image_input = gr.Textbox(value="", visible=False)
    system_prompt = gr.Textbox(value=PRESETS["general"], visible=False)
    max_new_tokens = gr.Slider(minimum=64, maximum=MODEL_CAP["max_tokens"], value=4096, visible=False)
    temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, visible=False)
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, visible=False)
    gr.ChatInterface(
        fn=generate_reply,
        api_name="chat",
        additional_inputs=[
            thinking_toggle,
            image_input,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
        ],
    )

# ==============================================================================
# 7. FASTAPI
# ==============================================================================
fapp = FastAPI()
# NOTE(review): sessions are in-memory only -- lost on restart, unbounded growth.
SESSIONS: dict[str, dict] = {}
HTML = pathlib.Path(__file__).parent / "index.html"

CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "")
CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "")
SPACE_HOST = os.getenv("SPACE_HOST", "localhost:7860")
REDIRECT_URI = f"https://{SPACE_HOST}/login/callback"
HF_AUTH_URL = "https://huggingface.co/oauth/authorize"
HF_TOKEN_URL = "https://huggingface.co/oauth/token"
HF_USER_URL = "https://huggingface.co/oauth/userinfo"
SCOPES = os.getenv("OAUTH_SCOPES", "openid profile")
print(f"[OAuth] CLIENT_ID={bool(CLIENT_ID)}, SPACE_HOST={SPACE_HOST}")


def _sid(req):
    """Session-cookie value for this request, or None."""
    return req.cookies.get("mc_session")


def _user(req):
    """Session dict for the request's cookie, or None when not logged in."""
    sid = _sid(req)
    return SESSIONS.get(sid) if sid else None


@fapp.get("/")
async def root(request: Request):
    # The fallback markup was stripped by sanitization in the original source;
    # reconstructed here as a minimal <h1> page.
    html = HTML.read_text(encoding="utf-8") if HTML.exists() else "<h1>index.html missing</h1>"
    return HTMLResponse(html)


@fapp.get("/oauth/user")
async def oauth_user(request: Request):
    """Return the logged-in user's session, or 401 when not authenticated."""
    u = _user(request)
    return JSONResponse(u) if u else JSONResponse({"logged_in": False}, status_code=401)


@fapp.get("/oauth/login")
async def oauth_login(request: Request):
    """Redirect the browser to the Hugging Face OAuth authorize endpoint."""
    if not CLIENT_ID:
        return RedirectResponse("/?oauth_error=not_configured")
    state = secrets.token_urlsafe(16)
    params = {
        "response_type": "code",
        "client_id": CLIENT_ID,
        "redirect_uri": REDIRECT_URI,
        "scope": SCOPES,
        "state": state,
    }
    return RedirectResponse(f"{HF_AUTH_URL}?{urlencode(params)}", status_code=302)


@fapp.get("/login/callback")
async def oauth_callback(code: str = "", error: str = "", state: str = ""):
    """Exchange the OAuth code for a token, fetch userinfo, set session cookie."""
    # NOTE(review): `state` is never compared with the value issued at login
    # (it is not stored anywhere), so CSRF protection is incomplete.
    if error or not code:
        return RedirectResponse("/?auth_error=1")
    basic = base64.b64encode(f"{CLIENT_ID}:{CLIENT_SECRET}".encode()).decode()
    async with httpx.AsyncClient() as client:
        tok = await client.post(
            HF_TOKEN_URL,
            data={"grant_type": "authorization_code", "code": code, "redirect_uri": REDIRECT_URI},
            headers={"Accept": "application/json", "Authorization": f"Basic {basic}"},
        )
        if tok.status_code != 200:
            return RedirectResponse("/?auth_error=1")
        access_token = tok.json().get("access_token", "")
        if not access_token:
            return RedirectResponse("/?auth_error=1")
        uinfo = await client.get(HF_USER_URL, headers={"Authorization": f"Bearer {access_token}"})
        if uinfo.status_code != 200:
            return RedirectResponse("/?auth_error=1")
        user = uinfo.json()
    sid = secrets.token_urlsafe(32)
    SESSIONS[sid] = {
        "logged_in": True,
        "username": user.get("preferred_username", user.get("name", "User")),
        "name": user.get("name", ""),
        "avatar": user.get("picture", ""),
        "profile": f"https://huggingface.co/{user.get('preferred_username', '')}",
    }
    resp = RedirectResponse("/")
    resp.set_cookie("mc_session", sid, httponly=True, samesite="lax", secure=True, max_age=60 * 60 * 24 * 7)
    return resp


@fapp.get("/oauth/logout")
async def oauth_logout(request: Request):
    """Drop the server-side session and clear the session cookie."""
    sid = _sid(request)
    if sid and sid in SESSIONS:
        del SESSIONS[sid]
    resp = RedirectResponse("/")
    resp.delete_cookie("mc_session")
    return resp


@fapp.get("/health")
async def health():
    """Liveness / capability probe used by the Space infrastructure."""
    return {
        "status": "ok",
        "model": MODEL_ID,
        "backend": BACKEND_NAME,
        "mti": "enabled",
        "max_tokens": MODEL_CAP["max_tokens"],
        "max_model_len": MAX_MODEL_LEN,
        "multimodal": "vision+audio",
    }


BRAVE_API_KEY = os.getenv("BRAVE_API_KEY", "")


@fapp.post("/api/search")
async def api_search(request: Request):
    """Proxy a web search to the Brave Search API (top 5 results)."""
    body = await request.json()
    query = body.get("query", "").strip()
    if not query:
        return JSONResponse({"error": "empty"}, status_code=400)
    if not BRAVE_API_KEY:
        return JSONResponse({"error": "no key"}, status_code=500)
    try:
        # Async client so the upstream call does not block the event loop
        # (the original used synchronous `requests` inside this async route).
        async with httpx.AsyncClient(timeout=10) as client:
            r = await client.get(
                "https://api.search.brave.com/res/v1/web/search",
                headers={"X-Subscription-Token": BRAVE_API_KEY, "Accept": "application/json"},
                params={"q": query, "count": 5},
            )
        r.raise_for_status()
        results = r.json().get("web", {}).get("results", [])
        return JSONResponse({
            "results": [
                {"title": i.get("title", ""), "desc": i.get("description", ""), "url": i.get("url", "")}
                for i in results[:5]
            ]
        })
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)


@fapp.post("/api/extract-pdf")
async def api_extract_pdf(request: Request):
    """Extract text from a base64-encoded PDF (PyMuPDF when available)."""
    try:
        body = await request.json()
        b64 = body.get("data", "")
        if "," in b64:  # strip a data-URL prefix such as "data:application/pdf;base64,"
            b64 = b64.split(",", 1)[1]
        pdf_bytes = base64.b64decode(b64)
        text = ""
        try:
            import fitz  # PyMuPDF
            doc = fitz.open(stream=pdf_bytes, filetype="pdf")
            for page in doc:
                text += page.get_text() + "\n"
        except ImportError:
            # Crude fallback: salvage whatever bytes decode as UTF-8.
            text = pdf_bytes.decode("utf-8", errors="ignore")
        return JSONResponse({"text": text.strip()[:8000], "chars": len(text)})
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

# ==============================================================================
# 8. MOUNT & RUN
# ==============================================================================
app = gr.mount_gradio_app(fapp, gradio_demo, path="/gradio")


def _shutdown(sig, frame):
    """Signal handler: log and exit the process cleanly."""
    print("[BOOT] Shutdown", flush=True)
    sys.exit(0)


# Handle both container-stop (SIGTERM) and Ctrl-C (SIGINT) the same way.
for _signum in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_signum, _shutdown)

if __name__ == "__main__":
    print(f"[BOOT] {MODEL_NAME} - {BACKEND_NAME} - MTI - max_len={MAX_MODEL_LEN} - Ready", flush=True)
    uvicorn.run(app, host="0.0.0.0", port=7860)