"""Model load + single-round generation for demo-v3 step 2.
Module-level load on startup (matches the ZeroGPU pattern: model in CPU RAM,
GPU attached per @spaces.GPU call). One blocking generate function returns
the raw text output and a streaming variant yields it chunk by chunk; the
tool-call parsers at the bottom operate on that text. No multi-round loop
and no tool dispatch yet — those land in step 3+.
"""
from __future__ import annotations
import json
import os
import re
import shutil
import sys
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Directory holding this file; anchors the harness lookup and the adapter cache.
HERE = Path(__file__).resolve().parent
# Monorepo: harness/ is at HERE.parent. Space deploy: harness/ is at HERE.
if (HERE / "harness").is_dir():
    REPO_ROOT = HERE
else:
    REPO_ROOT = HERE.parent

# Runtime configuration; every value can be overridden through the environment.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3.5-9B")
ADAPTER_ID = os.getenv("ADAPTER_ID", "continker/Qwen3.5-9B-metro-v23")
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "adapter")
# Upper bound on tokens generated per round; must parse as an integer.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))
def _localise_adapter(repo_id: str, subfolder: str) -> str:
    """Download an adapter and rewrite its weight keys for the flat layer layout.

    Published adapters use `model.language_model.layers.*` (multimodal-shaped
    Qwen). The current text-only Qwen3.5 has flat `model.layers.*`, so every
    safetensors key has `.language_model.layers.` replaced by `.layers.`.
    The result is cached at `HERE/.adapter_cache/{repo}__{subfolder}/`; a
    `.localised` marker file written last signals a complete cache entry, so
    an interrupted run is redone on the next call.

    Args:
        repo_id: Hugging Face repo id of the adapter (slashes become `__` in
            the cache directory name).
        subfolder: Folder inside the repo that holds the adapter files.

    Returns:
        Path of the local directory containing the remapped adapter.
    """
    cache_root = HERE / ".adapter_cache"
    safe = repo_id.replace("/", "__") + "__" + subfolder
    dst = cache_root / safe
    flag = dst / ".localised"
    if flag.exists():
        return str(dst)

    # Imported lazily so the cache-hit path does not need these packages.
    from huggingface_hub import snapshot_download
    from safetensors.torch import load_file, save_file

    weights_name = "adapter_model.safetensors"
    src_root = snapshot_download(repo_id, allow_patterns=[f"{subfolder}/*"])
    src = Path(src_root) / subfolder
    dst.mkdir(parents=True, exist_ok=True)

    # Copy everything except the weights verbatim. The download pattern can
    # also match nested files, so directories are copied recursively rather
    # than handed to shutil.copy (which raises on a directory).
    for entry in src.iterdir():
        if entry.name == weights_name:
            continue
        if entry.is_dir():
            shutil.copytree(entry, dst / entry.name, dirs_exist_ok=True)
        else:
            shutil.copy(entry, dst)

    sd = load_file(str(src / weights_name))
    remapped = {
        key.replace(".language_model.layers.", ".layers."): tensor
        for key, tensor in sd.items()
    }
    save_file(remapped, str(dst / weights_name))

    # Marker goes last: its presence means every file above was written.
    flag.touch()
    print(f"[model] localised adapter ({len(remapped)} keys) → {dst}", flush=True)
    return str(dst)
# --- Device + dtype --------------------------------------------------------
# ZeroGPU emulation accepts .to("cuda") even when no real GPU is present at
# module load. On a developer Mac we fall back to MPS so the same code path
# runs locally.
# Preference order: CUDA (bfloat16), then MPS (float16), then CPU (float32).
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
print(f"[model] loading {MODEL_ID} on {DEVICE} ({DTYPE})…", flush=True)

# An empty ADAPTER_ID disables the adapter; the tokenizer is then loaded
# from the base model id instead of the localised adapter directory.
if ADAPTER_ID:
    _adapter_path = _localise_adapter(ADAPTER_ID, ADAPTER_SUBFOLDER)
else:
    _adapter_path = None
tokenizer = AutoTokenizer.from_pretrained(_adapter_path or MODEL_ID)
# FlashAttention-2 only on CUDA — flash-attn isn't built for MPS/CPU. When
# present (ZeroGPU CUDA) it gives ~2–3× decode throughput; falls back
# silently elsewhere so the same code runs on a Mac dev box.
_attn_impl = None
if DEVICE == "cuda":
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        print("[model] flash-attn not installed; using default SDPA attention",
              flush=True)
    else:
        _attn_impl = "flash_attention_2"
        print("[model] flash-attn available → attn_implementation=flash_attention_2",
              flush=True)

# Only pass attn_implementation when one was selected, so the library
# default applies otherwise.
_load_kwargs = {"dtype": DTYPE, "low_cpu_mem_usage": True}
if _attn_impl:
    _load_kwargs["attn_implementation"] = _attn_impl
_base = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_load_kwargs).to(DEVICE)

if not _adapter_path:
    model = _base
else:
    from peft import PeftModel

    print(f"[model] applying LoRA adapter from {_adapter_path}…", flush=True)
    # Merge the adapter into the base weights and drop the PEFT wrapper.
    model = PeftModel.from_pretrained(_base, _adapter_path).merge_and_unload()
model.eval()
# Diagnostic: confirm the attention implementation that actually
# attached to the loaded model (post-PEFT merge). If this shows "sdpa"
# or "eager" while we requested "flash_attention_2", the kwarg got
# dropped or the architecture doesn't support FA2 — explains a missing
# speedup despite flash-attn being importable.
try:
    # Look on the config first, then on the model object, then give up.
    actual_attn = getattr(model.config, "_attn_implementation", None)
    if not actual_attn:
        actual_attn = getattr(model, "_attn_implementation", None)
    if not actual_attn:
        actual_attn = "unknown"
    print(f"[model] runtime attn_implementation={actual_attn}", flush=True)
except Exception as exc:
    # Purely informational: a failed probe must never block startup.
    print(f"[model] attn_implementation probe failed: {exc}", flush=True)
print(f"[model] ready on {DEVICE}", flush=True)
# --- Tool schema (imported from harness for chat-template parity) ---------
# Make the directory that contains harness/ importable before touching it.
_repo_root_str = str(REPO_ROOT)
if _repo_root_str not in sys.path:
    sys.path.insert(0, _repo_root_str)
try:
    from harness.runner import TOOL_DEFINITIONS  # noqa: F401
except ImportError as exc:
    # Degrade to "no tools" instead of failing module import.
    print(f"[model] WARN: TOOL_DEFINITIONS unavailable ({exc})", flush=True)
    TOOL_DEFINITIONS = []
# --- System prompt builder (shared with harness.runner) -------------------
# NOTE(review): unlike TOOL_DEFINITIONS above, this import has no fallback,
# so a missing harness package aborts module import — confirm that is intended.
from harness.prompts import build_system_prompt  # noqa: E402,F401
# --- Single-round generation ----------------------------------------------
def generate_one_round(messages: list[dict]) -> str:
    """Run one blocking generation round and return the newly generated text.

    The chat template is rendered with the tool definitions and, when the
    tokenizer accepts it, `enable_thinking=True`. Decoding is greedy and
    capped at MAX_NEW_TOKENS, and stops early at the end of a tool call or
    at `<|im_end|>`.

    The caller is expected to wrap this in @spaces.GPU.
    NOTE(review): the chat template is rendered and tokenised inside this
    function, so that work runs within whatever the caller wraps, not
    before it.

    Args:
        messages: Chat messages in the format `apply_chat_template` expects.

    Returns:
        Decoded text of the generated tokens only (prompt excluded), with
        special tokens skipped.
    """
    template_kwargs = dict(
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
        tools=TOOL_DEFINITIONS or None,
    )
    try:
        inputs = tokenizer.apply_chat_template(
            messages, enable_thinking=True, **template_kwargs
        ).to(DEVICE)
    except TypeError:
        # Retry without the flag if the call rejected `enable_thinking`.
        inputs = tokenizer.apply_chat_template(messages, **template_kwargs).to(DEVICE)

    output = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        # Stop right after a complete tool call so it can be parsed and
        # dispatched. The first entry was an empty string, which marks no
        # boundary at all.
        stop_strings=["</tool_call>", "<|im_end|>"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
def stream_one_round(messages: list[dict]):
    """Stream one generation round.

    Generation runs in a worker thread feeding a `TextIteratorStreamer`, so
    chunks can be yielded to the UI while decoding continues. Prompt
    rendering and generation settings mirror `generate_one_round`.

    Args:
        messages: Chat messages in the format `apply_chat_template` expects.

    Yields:
        `(chunk, full)` tuples: the newly emitted text and everything
        emitted so far in this round.

    Raises:
        Exception: whatever `model.generate` raised in the worker thread,
            re-raised here after the chunks already produced were yielded.
    """
    import threading
    from transformers import TextIteratorStreamer

    template_kwargs = dict(
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
        tools=TOOL_DEFINITIONS or None,
    )
    try:
        inputs = tokenizer.apply_chat_template(
            messages, enable_thinking=True, **template_kwargs
        ).to(DEVICE)
    except TypeError:
        # Retry without the flag if the call rejected `enable_thinking`.
        inputs = tokenizer.apply_chat_template(messages, **template_kwargs).to(DEVICE)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        streamer=streamer,
        # Stop right after a complete tool call; the first entry was an
        # empty string, which marks no boundary at all.
        stop_strings=["</tool_call>", "<|im_end|>"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
    )

    failure: list[Exception] = []

    def _run_generate() -> None:
        # If generate() dies, the streamer would never be closed and the
        # consumer loop below would block forever; close it and keep the
        # exception so it can be re-raised on the consumer side.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            failure.append(exc)
            streamer.end()

    thread = threading.Thread(target=_run_generate, daemon=True)
    thread.start()

    full = ""
    for chunk in streamer:
        full += chunk
        yield chunk, full
    thread.join()
    if failure:
        raise failure[0]
# --- Tool-call parsers (Hermes/Qwen XML) ----------------------------------
# Expected markup, one block per call:
#   <tool_call>
#   <function=NAME>
#   <parameter=KEY>
#   VALUE
#   </parameter>
#   </function>
#   </tool_call>
# The function and parameter patterns need two groups each (name/key, then
# body/value) because the parser below reads group(1) and group(2).
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
_FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
# Body of a JSON string value: any run of non-quote characters or backslash pairs.
_ASSISTANT_MSG_RE = re.compile(r'"assistant_message"\s*:\s*"((?:[^"\\]|\\.)*)"', re.DOTALL)


def parse_tool_calls(text: str) -> list[dict]:
    """Extract every `<tool_call>...</tool_call>` block from the model output.

    Args:
        text: Raw text produced by the model.

    Returns:
        One `{"name": str, "arguments": dict}` entry per tool-call block that
        contains a function element, in order of appearance. Parameter values
        that parse as JSON are decoded; anything else is kept as the stripped
        string. Blocks without a function element are skipped.
    """
    calls: list[dict] = []
    for block in _TOOL_CALL_RE.finditer(text):
        fn = _FUNCTION_RE.search(block.group(1))
        if fn is None:
            continue
        args: dict = {}
        for param in _PARAMETER_RE.finditer(fn.group(2)):
            raw = param.group(2).strip()
            try:
                value = json.loads(raw)
            except ValueError:  # json.JSONDecodeError is a ValueError
                value = raw
            args[param.group(1).strip()] = value
        calls.append({"name": fn.group(1).strip(), "arguments": args})
    return calls


def extract_assistant_message(text: str) -> str | None:
    """Best-effort recovery of `assistant_message` from a malformed
    submit_assistant_state body when the structured parse fails.

    Args:
        text: Raw text that may contain an `"assistant_message": "..."` pair.

    Returns:
        The message with its escapes decoded, the undecoded capture if no
        decoding strategy succeeds, or None when the key is not found.
    """
    m = _ASSISTANT_MSG_RE.search(text)
    if m is None:
        return None
    raw = m.group(1)

    # The capture is the body of a JSON string, so decode it as JSON:
    # that handles \n, \", \uXXXX and surrogate pairs and leaves non-ASCII
    # text intact. strict=False tolerates literal newlines in the capture.
    try:
        return json.loads(f'"{raw}"', strict=False)
    except ValueError:
        pass

    # Fallback for escapes JSON rejects (e.g. \x41). Going through latin-1
    # with backslashreplace keeps non-ASCII characters intact, which a plain
    # UTF-8 encode followed by unicode_escape would turn into mojibake.
    try:
        return raw.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except UnicodeDecodeError:
        return raw