# Darwin-4B-david / app.py
# SeaWolf-AI -- commit c8a5e69:
# Fix gemma4 runtime error: switch to Transformers backend + Darwin-4B-David
# Darwin-4B-David (Gemma4) - Transformers backend + MTI
# Multimodal (Vision+Audio+Text) - Apache 2.0
# MTI: +9-11% reasoning accuracy (training-free), Transformers LogitsProcessor
import sys, os, signal, time, uuid
print(f"[BOOT] Python {sys.version}", flush=True)
import base64, re, json
from typing import Generator, Optional
from threading import Thread
from queue import Queue
import torch
import gradio as gr
print(f"[BOOT] gradio {gr.__version__}, torch {torch.__version__}", flush=True)
import requests, httpx, uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from urllib.parse import urlencode
import pathlib, secrets
# ==============================================================================
# 1. CONFIG
# ==============================================================================
# Hub repository the weights are loaded from, and the short name shown in the UI.
MODEL_ID = "FINAL-Bench/Darwin-4B-David"
MODEL_NAME = "Darwin-4B-David"

# Capability card for the model. "max_tokens" and "temp_max" are the upper
# bounds applied to the generation sliders and request parameters;
# "max_tokens" may be lowered once the model is loaded and its real context
# length is known.
MODEL_CAP = dict(
    arch="Gemma4",
    active="4B",
    total="4B",
    ctx="128K",
    thinking=True,
    vision=True,
    audio=True,
    max_tokens=16384,
    temp_max=2.0,
)

# System-prompt presets keyed by task name.
PRESETS = dict(
    general="You are a highly capable multimodal AI assistant. Think deeply and provide thorough, insightful responses.",
    code="You are an expert software engineer. Write clean, efficient, well-commented code.",
    math="You are a world-class mathematician. Break problems step-by-step. Show full working.",
    creative="You are a brilliant creative writer. Be imaginative, vivid, and engaging.",
    vision="You are an expert at analyzing images. Describe what you see in detail, extract text, and answer questions about visual content.",
)
# ==============================================================================
# 2. MTI -- Minimal Test-Time Intervention (arxiv 2510.13940)
# Transformers LogitsProcessor API: __call__(input_ids, scores) -> scores
# ==============================================================================
from transformers import LogitsProcessor, LogitsProcessorList
class MTILogitsProcessor(LogitsProcessor):
    """Entropy-gated logit sharpening for use in a ``LogitsProcessorList``.

    For every row of ``scores`` whose softmax entropy exceeds
    ``entropy_threshold``, the logits are pushed away from their mean:
    ``scores + cfg_scale * (scores - mean)``. Rows at or below the threshold
    are returned unchanged. The instance also counts how many rows it has
    seen and how many it modified.

    Args:
        cfg_scale: Strength of the sharpening applied to high-entropy rows.
        entropy_threshold: Entropy (in nats) above which a row is sharpened.
    """

    def __init__(self, cfg_scale: float = 1.5, entropy_threshold: float = 2.0):
        self.cfg_scale = cfg_scale
        self.entropy_threshold = entropy_threshold
        self._interventions = 0  # rows that were sharpened
        self._total = 0          # rows seen overall

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with high-entropy rows sharpened.

        Args:
            input_ids: Token ids generated so far (unused).
            scores: Next-token logits, indexed as (batch_size, vocab_size).

        Returns:
            A tensor shaped like ``scores``.
        """
        self._total += int(scores.shape[0])
        probs = torch.softmax(scores, dim=-1)
        # clamp_min keeps log() finite for zero-probability entries.
        entropy = -(probs * torch.log(probs.clamp_min(1e-10))).sum(dim=-1)
        mask = entropy > self.entropy_threshold
        hits = int(mask.sum().item())
        if hits:
            # Earlier processors may have written -inf into masked-out tokens.
            # Including those in the mean would make it -inf and turn the
            # guided logits into inf/NaN, so average finite entries only and
            # leave non-finite entries exactly as they were.
            finite = torch.isfinite(scores)
            finite_count = finite.sum(dim=-1, keepdim=True).clamp_min(1)
            finite_sum = torch.where(finite, scores, torch.zeros_like(scores)).sum(dim=-1, keepdim=True)
            mean_logit = finite_sum / finite_count
            guided = scores + self.cfg_scale * (scores - mean_logit)
            scores = torch.where(mask.unsqueeze(-1) & finite, guided, scores)
            self._interventions += hits
        return scores

    @property
    def intervention_rate(self):
        """Fraction of processed rows that were sharpened (0 when none seen)."""
        return self._interventions / max(self._total, 1)
# Boot banner only: the cfg/threshold values are hard-coded text that mirrors
# the MTILogitsProcessor constructor defaults; no processor is created here.
print("[MTI] MTILogitsProcessor ready (cfg=1.5, threshold=2.0)", flush=True)
# ==============================================================================
# 3. TOKENIZER + MODEL LOAD (Transformers from source)
# ==============================================================================
from transformers import (
AutoTokenizer,
Gemma4ForConditionalGeneration,
TextIteratorStreamer,
)
from huggingface_hub import hf_hub_download
import tempfile, shutil
# ---- Tokenizer with extra_special_tokens patch ----
# Transformers 5.5.x (git) has a regression where tokenizer_config.json with
# extra_special_tokens stored as a list crashes during load (.keys() call on
# a list). We pre-download, patch if needed, then load from the local copy.
# Copy the tokenizer files into a private temp directory so that
# tokenizer_config.json can be patched before AutoTokenizer reads it.
_tok_source = MODEL_ID
_tok_dir = tempfile.mkdtemp(prefix="darwin_tok_")
for _fname in ["tokenizer_config.json", "tokenizer.json", "tokenizer.model",
               "special_tokens_map.json", "chat_template.jinja"]:
    try:
        _p = hf_hub_download(_tok_source, _fname)
        shutil.copy(_p, os.path.join(_tok_dir, _fname))
    except Exception as e:
        # Best-effort: a file in this list may be absent from the repo. Log the
        # skip instead of hiding it so a later load failure can be traced.
        print(f"[Tokenizer] Skipped {_fname}: {type(e).__name__}", flush=True)

_tc_path = os.path.join(_tok_dir, "tokenizer_config.json")
if os.path.exists(_tc_path):
    try:
        # Explicit UTF-8: the config is JSON and must not depend on the
        # platform's default text encoding.
        with open(_tc_path, encoding="utf-8") as f:
            _tc = json.load(f)
        est = _tc.get("extra_special_tokens", None)
        if isinstance(est, list):
            # Convert the list form into the mapping form (token -> token).
            _tc["extra_special_tokens"] = {tok: tok for tok in est} if est else {}
            with open(_tc_path, "w", encoding="utf-8") as f:
                json.dump(_tc, f, indent=2)
            print(f"[Tokenizer] Patched extra_special_tokens: list({len(est)}) -> dict", flush=True)
    except Exception as e:
        # The patch is optional; fall through and load the unpatched copy.
        print(f"[Tokenizer] Patch skipped: {e}", flush=True)

tokenizer = AutoTokenizer.from_pretrained(_tok_dir)
print(f"[Tokenizer] Loaded (vocab={len(tokenizer)}) from {_tok_source}", flush=True)
# ---- Model ----
print(f"[Transformers] Loading {MODEL_ID} (this may take a while for a 16GB checkpoint)...", flush=True)

# Keyword arguments for from_pretrained; "dtype" is tried first and swapped
# for "torch_dtype" if the installed transformers rejects it.
_load_kwargs = {
    "dtype": torch.bfloat16,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
}
try:
    model = Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, **_load_kwargs)
except TypeError:
    # Older transformers signatures used torch_dtype instead of dtype.
    _load_kwargs["torch_dtype"] = _load_kwargs.pop("dtype")
    model = Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, **_load_kwargs)
model.eval()

# Device of the first parameter; prompt tensors are moved here before generate().
_first_param = next(model.parameters())
_device = _first_param.device
print(f"[Transformers] Model loaded on {_device}", flush=True)

# Resolve max model length (text config for multimodal Gemma4).
try:
    _text_cfg = model.config.get_text_config()
except AttributeError:
    _text_cfg = getattr(model.config, "text_config", model.config)
MAX_MODEL_LEN = int(getattr(_text_cfg, "max_position_embeddings", 16384))

# Clamp generation max_tokens to what the runtime can actually hold.
if MAX_MODEL_LEN < MODEL_CAP["max_tokens"]:
    MODEL_CAP["max_tokens"] = MAX_MODEL_LEN
print(f"[Transformers] max_position_embeddings={MAX_MODEL_LEN}, "
      f"max_tokens={MODEL_CAP['max_tokens']}", flush=True)

BACKEND_NAME = "Transformers"
# ==============================================================================
# 4. THINKING MODE HELPERS
# ==============================================================================
def parse_think_blocks(text: str) -> tuple[str, str]:
    """Split model output into ``(reasoning, answer)``.

    Two marker styles are recognised, tried in this order:

    1. ``<|channel|>thought\\n ... <channel|>`` (Gemma 4 thinking format)
    2. ``<think> ... </think>``

    For the first style that matches, the stripped text between the markers
    is the reasoning and the stripped text after the closing marker is the
    answer; anything before the opening marker is dropped. When neither
    matches, the reasoning is empty and ``text`` is returned unchanged.
    """
    marker_patterns = (
        r"<\|channel\|>thought\s*\n(.*?)<channel\|>",
        r"<think>(.*?)</think>\s*",
    )
    for pattern in marker_patterns:
        found = re.search(pattern, text, re.DOTALL)
        if found is None:
            continue
        reasoning = found.group(1).strip()
        remainder = text[found.end():].strip()
        return reasoning, remainder
    return "", text
def format_response(raw: str) -> str:
    """Render raw model output as the Markdown/HTML shown in the chat.

    * Completed reasoning block with content: the reasoning goes into a
      collapsible ``<details>`` element followed by the answer.
    * Completed reasoning block that is empty: only the answer is returned,
      so the bare markers are not shown to the user.
    * Reasoning block opened but not yet closed: a progress line with the
      number of reasoning characters received so far.
    * Anything else: ``raw`` unchanged.
    """
    chain, answer = parse_think_blocks(raw)
    if chain:
        return (
            "<details>\n<summary>🧠 Reasoning Chain -- click to expand</summary>\n\n"
            f"{chain}\n\n</details>\n\n{answer}"
        )
    # parse_think_blocks hands back its input untouched when nothing matched,
    # so a different answer means a block was found whose reasoning is empty.
    if answer != raw:
        return answer
    # Gemma 4 thinking in progress
    gemma_open = "<|channel|>thought"
    if gemma_open in raw and "<channel|>" not in raw:
        think_len = len(raw) - raw.index(gemma_open) - len(gemma_open)
        return f"🧠 Thinking... ({think_len} chars)"
    think_open = "<think>"
    if think_open in raw and "</think>" not in raw:
        think_len = len(raw) - raw.index(think_open) - len(think_open)
        return f"🧠 Thinking... ({think_len} chars)"
    return raw
# ==============================================================================
# 5. GENERATION -- Transformers TextIteratorStreamer + MTI
# ==============================================================================
def _engine_generate(prompt_text: str, gen_kwargs: dict, mti: MTILogitsProcessor, queue: Queue):
    """Run ``model.generate`` in a background thread and stream text into ``queue``.

    Decoded text chunks are put on ``queue`` as they arrive. The stream always
    ends with a ``None`` sentinel; on failure an error message chunk is put
    on the queue before the sentinel.

    Args:
        prompt_text: Fully rendered prompt (chat template already applied).
        gen_kwargs: Extra keyword arguments forwarded to ``model.generate``;
            these override the defaults assembled here.
        mti: Logits processor applied during generation.
        queue: Destination for text chunks and the final ``None``.
    """
    try:
        # The chat template may already have written the BOS token into the
        # prompt text; adding special tokens again would duplicate it.
        bos = getattr(tokenizer, "bos_token", None)
        add_special = not (bos and prompt_text.startswith(bos))
        inputs = tokenizer(
            prompt_text, return_tensors="pt", add_special_tokens=add_special,
        ).to(_device)
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=120.0,
        )
        # Explicit None check: a pad token id of 0 is valid and must not fall
        # back to the EOS id.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        full_kwargs = {
            **inputs,
            "streamer": streamer,
            "logits_processor": LogitsProcessorList([mti]),
            "pad_token_id": pad_id,
            **gen_kwargs,
        }
        gen_errors: list = []

        def _run_generate():
            # If generate() raises, close the streamer so the consumer loop
            # below stops right away instead of waiting for its timeout, and
            # keep the exception so it can be reported.
            try:
                model.generate(**full_kwargs)
            except Exception as exc:
                gen_errors.append(exc)
                streamer.end()

        gen_thread = Thread(target=_run_generate)
        gen_thread.start()
        for chunk in streamer:
            if chunk:
                queue.put(chunk)
        gen_thread.join()
        if gen_errors:
            raise gen_errors[0]
        queue.put(None)
    except Exception as e:
        # Some exceptions (e.g. a queue timeout) have an empty message.
        detail = str(e) or type(e).__name__
        queue.put(f"\n\n**❌ Engine error:** `{detail}`")
        queue.put(None)
# Matches the collapsible reasoning wrapper that format_response puts in front
# of an answer, so it can be removed from assistant turns read back from history.
_REASONING_DETAILS_RE = re.compile(
    r"\A\s*<details>\s*<summary>[^<]*Reasoning Chain[^<]*</summary>.*?</details>\s*",
    re.DOTALL,
)


def _content_to_text(content) -> str:
    """Flatten one chat ``content`` value into plain text.

    A list is treated as a sequence of parts: the ``text`` of every dict part
    whose ``type`` is ``"text"`` is joined with spaces. Any other value is
    converted with ``str``.
    """
    if isinstance(content, list):
        return " ".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def _strip_reasoning(text: str) -> str:
    """Return an assistant turn with its reasoning removed.

    Handles both raw reasoning markers (via ``parse_think_blocks``) and the
    rendered ``<details>`` wrapper that earlier replies carry in the history.
    """
    _, clean = parse_think_blocks(text)
    return _REASONING_DETAILS_RE.sub("", clean, count=1)


def _history_to_messages(history) -> list[dict]:
    """Convert chat history into a list of ``{"role", "content"}`` messages.

    Accepts dict turns (``role``/``content``) as well as pair turns
    (``[user, assistant]``). Turns that cannot be unpacked are skipped.
    """
    messages: list[dict] = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role", "")
            text = _content_to_text(turn.get("content") or "")
            if role == "user":
                messages.append({"role": "user", "content": text})
            elif role == "assistant":
                messages.append({"role": "assistant", "content": _strip_reasoning(text)})
            continue
        try:
            user_part = turn[0] or None
            bot_part = turn[1] if len(turn) > 1 else None
        except (TypeError, IndexError, KeyError):
            # Not a usable pair (wrong type or empty); ignore this turn.
            continue
        user_text = None if user_part is None else _content_to_text(user_part)
        bot_text = None if bot_part is None else _content_to_text(bot_part)
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if bot_text:
            messages.append({"role": "assistant", "content": _strip_reasoning(bot_text)})
    return messages


def generate_reply(
    message, history, thinking_mode, image_input,
    system_prompt, max_new_tokens, temperature, top_p,
) -> Generator[str, None, None]:
    """Stream a chat reply for ``message`` given the prior ``history``.

    Yields the formatted reply repeatedly as it grows; the last value yielded
    is the final reply or an error/empty-response notice.

    Args:
        message: The new user message.
        history: Previous turns (dict or pair format).
        thinking_mode: Accepted for interface compatibility; not used here.
        image_input: Accepted for interface compatibility; not used here.
        system_prompt: Optional system prompt; blank or None means no system turn.
        max_new_tokens: Requested generation length, capped by
            ``MODEL_CAP["max_tokens"]`` and by the space left in the context window.
        temperature: Sampling temperature, capped by ``MODEL_CAP["temp_max"]``;
            values at or below 0.01 select greedy decoding.
        top_p: Nucleus sampling parameter.
    """
    max_new_tokens = max(1, min(int(max_new_tokens), MODEL_CAP["max_tokens"]))
    temperature = min(float(temperature), MODEL_CAP["temp_max"])

    messages: list[dict] = []
    system_text = (system_prompt or "").strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    try:
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
    except Exception as e:
        yield f"**❌ Template error:** `{e}`"
        return

    input_len = len(tokenizer.encode(prompt_text))
    # Prompt and completion share one context window; never request more new
    # tokens than still fit.
    room = MAX_MODEL_LEN - input_len
    if room <= 0:
        yield (
            f"**❌ Prompt too long:** {input_len} tokens leave no room in the "
            f"{MAX_MODEL_LEN}-token context window."
        )
        return
    max_new_tokens = min(max_new_tokens, room)
    print(f"[GEN] tokens={input_len}, max_new={max_new_tokens}, "
          f"temp={temperature}, MTI=on, Backend={BACKEND_NAME}", flush=True)

    mti = MTILogitsProcessor(cfg_scale=1.5, entropy_threshold=2.0)
    do_sample = float(temperature) > 0.01
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=max(float(temperature), 0.01) if do_sample else 1.0,
        top_p=float(top_p),
    )

    queue: Queue = Queue()
    thread = Thread(target=_engine_generate, args=(prompt_text, gen_kwargs, mti, queue))
    thread.start()

    output = ""
    stream_error = None
    try:
        while True:
            chunk = queue.get(timeout=120)
            if chunk is None:
                break
            output += chunk
            yield format_response(output)
    except Exception as e:
        # A queue timeout carries no message, so fall back to the type name.
        stream_error = str(e) or type(e).__name__
    thread.join(timeout=5)

    if output:
        if stream_error:
            print(f"[GEN] Stream interrupted: {stream_error}", flush=True)
        mti_rate = f"{mti.intervention_rate*100:.1f}%"
        print(f"[GEN] Done -- {len(output)} chars, MTI={mti_rate} "
              f"({mti._interventions}/{mti._total})", flush=True)
        yield format_response(output)
    elif stream_error:
        # Make the error the final value so it is not replaced by the generic
        # empty-response notice.
        yield f"**❌ Streaming error:** `{stream_error}`"
    else:
        yield "**⚠️ 모델이 빈 응답을 반환했습니다.** 다시 시도해 주세요."
# ==============================================================================
# 6. GRADIO BLOCKS
# ==============================================================================
# Gradio app exposing generate_reply as a chat endpoint. The extra inputs are
# hidden components whose values are passed to generate_reply after
# (message, history), in the order listed in additional_inputs.
with gr.Blocks(title=MODEL_NAME) as gradio_demo:
    thinking_toggle = gr.Radio(
        choices=["⚡ Fast Mode", "🧠 Thinking Mode"],
        value="⚡ Fast Mode", visible=False,
    )
    image_input = gr.Textbox(value="", visible=False)
    system_prompt = gr.Textbox(value=PRESETS["general"], visible=False)
    # MODEL_CAP["max_tokens"] can be lowered at load time, so keep both the
    # minimum and the default inside the slider's range.
    max_new_tokens = gr.Slider(
        minimum=min(64, MODEL_CAP["max_tokens"]),
        maximum=MODEL_CAP["max_tokens"],
        value=min(4096, MODEL_CAP["max_tokens"]),
        visible=False,
    )
    # Same upper bound that generate_reply enforces.
    temperature = gr.Slider(minimum=0.0, maximum=MODEL_CAP["temp_max"], value=0.6, visible=False)
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, visible=False)
    gr.ChatInterface(
        fn=generate_reply, api_name="chat",
        additional_inputs=[
            thinking_toggle, image_input,
            system_prompt, max_new_tokens, temperature, top_p,
        ],
    )
# ==============================================================================
# 7. FASTAPI
# ==============================================================================
fapp = FastAPI()

# In-memory login sessions keyed by the random id stored in the mc_session cookie.
SESSIONS: dict[str, dict] = {}

# Front-end page served at "/", expected next to this file.
HTML = pathlib.Path(__file__).with_name("index.html")

# OAuth settings, read from the environment.
CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
SPACE_HOST = os.environ.get("SPACE_HOST", "localhost:7860")
SCOPES = os.environ.get("OAUTH_SCOPES", "openid profile")
REDIRECT_URI = f"https://{SPACE_HOST}/login/callback"

# Hugging Face OAuth endpoints.
HF_AUTH_URL = "https://huggingface.co/oauth/authorize"
HF_TOKEN_URL = "https://huggingface.co/oauth/token"
HF_USER_URL = "https://huggingface.co/oauth/userinfo"

# Only whether a client id is set is logged, never its value.
print(f"[OAuth] CLIENT_ID={bool(CLIENT_ID)}, SPACE_HOST={SPACE_HOST}")
def _sid(req): return req.cookies.get("mc_session")
def _user(req):
    """Return the session record for the request's session cookie, or None.

    None is returned both when the cookie is missing/empty and when its value
    is not a key in SESSIONS.
    """
    session_id = _sid(req)
    if not session_id:
        return None
    return SESSIONS.get(session_id)
@fapp.get("/")
async def root(request: Request):
    """Serve index.html, or a short placeholder page when the file is absent."""
    if HTML.exists():
        page = HTML.read_text(encoding="utf-8")
    else:
        page = "<h2>index.html missing</h2>"
    return HTMLResponse(page)
@fapp.get("/oauth/user")
async def oauth_user(request: Request):
    """Return the caller's session record, or a 401 with ``logged_in: False``."""
    session = _user(request)
    if session:
        return JSONResponse(session)
    return JSONResponse({"logged_in": False}, status_code=401)
@fapp.get("/oauth/login")
async def oauth_login(request: Request):
    """Start the OAuth authorization-code flow.

    Redirects to the authorization endpoint with a fresh random ``state`` and
    stores the same value in a short-lived ``mc_oauth_state`` cookie so the
    callback can confirm the response belongs to this browser.
    """
    if not CLIENT_ID:
        return RedirectResponse("/?oauth_error=not_configured")
    state = secrets.token_urlsafe(16)
    params = {
        "response_type": "code",
        "client_id": CLIENT_ID,
        "redirect_uri": REDIRECT_URI,
        "scope": SCOPES,
        "state": state,
    }
    resp = RedirectResponse(f"{HF_AUTH_URL}?{urlencode(params)}", status_code=302)
    resp.set_cookie("mc_oauth_state", state, httponly=True, samesite="lax", secure=True, max_age=600)
    return resp


@fapp.get("/login/callback")
async def oauth_callback(request: Request, code: str = "", error: str = "", state: str = ""):
    """Finish the OAuth flow: verify ``state``, exchange ``code``, open a session.

    Every failure (provider error, missing code, state mismatch, token or
    userinfo request failing) redirects to ``/?auth_error=1``. On success a
    new entry is added to SESSIONS and its id is set as the ``mc_session``
    cookie.
    """
    def _fail():
        resp = RedirectResponse("/?auth_error=1")
        resp.delete_cookie("mc_oauth_state")
        return resp

    if error or not code:
        return _fail()

    # The state returned by the provider must equal the one issued to this
    # browser by oauth_login; otherwise the callback could be triggered with
    # someone else's authorization code. Compare as bytes so non-ASCII input
    # cannot raise.
    expected = request.cookies.get("mc_oauth_state", "")
    if not (state and expected and secrets.compare_digest(state.encode(), expected.encode())):
        return _fail()

    basic = base64.b64encode(f"{CLIENT_ID}:{CLIENT_SECRET}".encode()).decode()
    try:
        async with httpx.AsyncClient() as client:
            tok = await client.post(
                HF_TOKEN_URL,
                data={"grant_type": "authorization_code", "code": code, "redirect_uri": REDIRECT_URI},
                headers={"Accept": "application/json", "Authorization": f"Basic {basic}"},
            )
            if tok.status_code != 200:
                return _fail()
            access_token = tok.json().get("access_token", "")
            if not access_token:
                return _fail()
            uinfo = await client.get(HF_USER_URL, headers={"Authorization": f"Bearer {access_token}"})
            if uinfo.status_code != 200:
                return _fail()
            user = uinfo.json()
    except (httpx.HTTPError, ValueError):
        # Network failure or a body that is not valid JSON.
        return _fail()

    sid = secrets.token_urlsafe(32)
    SESSIONS[sid] = {
        "logged_in": True,
        "username": user.get("preferred_username", user.get("name", "User")),
        "name": user.get("name", ""),
        "avatar": user.get("picture", ""),
        "profile": f"https://huggingface.co/{user.get('preferred_username', '')}",
    }
    resp = RedirectResponse("/")
    resp.set_cookie("mc_session", sid, httponly=True, samesite="lax", secure=True, max_age=60*60*24*7)
    resp.delete_cookie("mc_oauth_state")
    return resp
@fapp.get("/oauth/logout")
async def oauth_logout(request: Request):
    """Drop the caller's session (if any), clear the cookie, redirect to "/"."""
    session_id = _sid(request)
    if session_id:
        # pop with a default is a no-op for ids that are not in SESSIONS.
        SESSIONS.pop(session_id, None)
    response = RedirectResponse("/")
    response.delete_cookie("mc_session")
    return response
@fapp.get("/health")
async def health():
    """Report static service status plus the active model limits."""
    payload = dict(
        status="ok",
        model=MODEL_ID,
        backend=BACKEND_NAME,
        mti="enabled",
        max_tokens=MODEL_CAP["max_tokens"],
        max_model_len=MAX_MODEL_LEN,
        multimodal="vision+audio",
    )
    return payload
BRAVE_API_KEY = os.getenv("BRAVE_API_KEY", "")


@fapp.post("/api/search")
async def api_search(request: Request):
    """Proxy a web search to the Brave Search API.

    Expects a JSON object with a ``query`` string. Responds with
    ``{"results": [{"title", "desc", "url"}, ...]}`` (at most five entries),
    or ``{"error": ...}`` with status 400 for a bad request and 500 when the
    key is missing or the upstream call fails.
    """
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse({"error": "invalid json"}, 400)
    query = body.get("query", "") if isinstance(body, dict) else ""
    query = query.strip() if isinstance(query, str) else ""
    if not query:
        return JSONResponse({"error": "empty"}, 400)
    if not BRAVE_API_KEY:
        return JSONResponse({"error": "no key"}, 500)
    try:
        # Async client: a blocking HTTP call inside this coroutine would stall
        # the event loop (and every other request) until it returned.
        async with httpx.AsyncClient(timeout=10, follow_redirects=True) as client:
            r = await client.get(
                "https://api.search.brave.com/res/v1/web/search",
                headers={"X-Subscription-Token": BRAVE_API_KEY, "Accept": "application/json"},
                params={"q": query, "count": 5},
            )
        r.raise_for_status()
        results = r.json().get("web", {}).get("results", [])
        return JSONResponse({"results": [
            {"title": i.get("title", ""), "desc": i.get("description", ""), "url": i.get("url", "")}
            for i in results[:5]
        ]})
    except Exception as e:
        # Timeouts can have an empty message; fall back to the type name.
        return JSONResponse({"error": str(e) or type(e).__name__}, 500)
@fapp.post("/api/extract-pdf")
async def api_extract_pdf(request: Request):
    """Extract text from a base64-encoded PDF.

    Expects a JSON object whose ``data`` field is base64 text, optionally
    prefixed by a data-URL header (everything up to the first comma is
    dropped). Responds with the first 8000 characters of the stripped text
    and the total extracted length, or ``{"error": ...}`` with status 500.
    """
    try:
        body = await request.json()
        b64 = body.get("data", "")
        if "," in b64:
            b64 = b64.split(",", 1)[1]
        pdf_bytes = base64.b64decode(b64)
        try:
            import fitz
        except ImportError:
            # PDF library not installed: fall back to decoding the raw bytes.
            text = pdf_bytes.decode("utf-8", errors="ignore")
        else:
            doc = fitz.open(stream=pdf_bytes, filetype="pdf")
            try:
                text = "".join(page.get_text() + "\n" for page in doc)
            finally:
                # Release the document even if a page fails to extract.
                doc.close()
        return JSONResponse({"text": text.strip()[:8000], "chars": len(text)})
    except Exception as e:
        return JSONResponse({"error": str(e)}, 500)
# ==============================================================================
# 8. MOUNT & RUN
# ==============================================================================
# Mount the Gradio UI on the FastAPI app under /gradio; `app` is the object
# handed to uvicorn.run at the bottom of the file.
app = gr.mount_gradio_app(fapp, gradio_demo, path="/gradio")
def _shutdown(sig, frame):
print("[BOOT] Shutdown", flush=True)
sys.exit(0)
# Route both termination signals to the same handler.
for _signum in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_signum, _shutdown)

if __name__ == "__main__":
    # Run the combined FastAPI + Gradio app on all interfaces, port 7860.
    print(f"[BOOT] {MODEL_NAME} - {BACKEND_NAME} - MTI - max_len={MAX_MODEL_LEN} - Ready", flush=True)
    uvicorn.run(app, host="0.0.0.0", port=7860)