"""Model load + single-round generation for demo-v3 step 2. Module-level load on startup (matches the ZeroGPU pattern: model in CPU RAM, GPU attached per @spaces.GPU call). One blocking generate function that returns the raw text output. No streaming, no multi-round, no tool dispatch yet — those land in step 3+. """ from __future__ import annotations import json import os import re import shutil import sys from pathlib import Path import torch from transformers import AutoModelForCausalLM, AutoTokenizer HERE = Path(__file__).resolve().parent # Monorepo: harness/ is at HERE.parent. Space deploy: harness/ is at HERE. REPO_ROOT = HERE if (HERE / "harness").is_dir() else HERE.parent MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen3.5-9B") ADAPTER_ID = os.environ.get("ADAPTER_ID", "continker/Qwen3.5-9B-metro-v23") ADAPTER_SUBFOLDER = os.environ.get("ADAPTER_SUBFOLDER", "adapter") MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "2048")) def _localise_adapter(repo_id: str, subfolder: str) -> str: """Download adapter + remap key names to current Qwen3.5 architecture. Published adapters use `model.language_model.layers.*` (multimodal-shaped Qwen). The current text-only Qwen3.5 has flat `model.layers.*`, so we strip `.language_model.` from each safetensors key. Cached at `demo-v3/.adapter_cache/{repo}__{subfolder}/`. 
""" cache_root = HERE / ".adapter_cache" safe = repo_id.replace("/", "__") + "__" + subfolder dst = cache_root / safe flag = dst / ".localised" if flag.exists(): return str(dst) from huggingface_hub import snapshot_download from safetensors.torch import load_file, save_file src_root = snapshot_download(repo_id, allow_patterns=[f"{subfolder}/*"]) src = Path(src_root) / subfolder dst.mkdir(parents=True, exist_ok=True) for fname in os.listdir(src): if fname != "adapter_model.safetensors": shutil.copy(src / fname, dst) sd = load_file(str(src / "adapter_model.safetensors")) remapped = {k.replace(".language_model.layers.", ".layers."): v for k, v in sd.items()} save_file(remapped, str(dst / "adapter_model.safetensors")) flag.touch() print(f"[model] localised adapter ({len(remapped)} keys) → {dst}", flush=True) return str(dst) # --- Device + dtype -------------------------------------------------------- # ZeroGPU emulation accepts .to("cuda") even when no real GPU is present at # module load. On a developer Mac we fall back to MPS so the same code path # runs locally. if torch.cuda.is_available(): DEVICE, DTYPE = "cuda", torch.bfloat16 elif torch.backends.mps.is_available(): DEVICE, DTYPE = "mps", torch.float16 else: DEVICE, DTYPE = "cpu", torch.float32 print(f"[model] loading {MODEL_ID} on {DEVICE} ({DTYPE})…", flush=True) _adapter_path = _localise_adapter(ADAPTER_ID, ADAPTER_SUBFOLDER) if ADAPTER_ID else None tokenizer = AutoTokenizer.from_pretrained(_adapter_path or MODEL_ID) # FlashAttention-2 only on CUDA — flash-attn isn't built for MPS/CPU. When # present (ZeroGPU CUDA) it gives ~2–3× decode throughput; falls back # silently elsewhere so the same code runs on a Mac dev box. 
_attn_impl = None
if DEVICE == "cuda":
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        print("[model] flash-attn not installed; using default SDPA attention", flush=True)
    else:
        _attn_impl = "flash_attention_2"
        print("[model] flash-attn available → attn_implementation=flash_attention_2", flush=True)

# Only pass attn_implementation when FA2 was selected; otherwise let
# transformers choose its default.
_load_kwargs = {"dtype": DTYPE, "low_cpu_mem_usage": True}
if _attn_impl:
    _load_kwargs["attn_implementation"] = _attn_impl
_base = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_load_kwargs).to(DEVICE)

if not _adapter_path:
    model = _base
else:
    from peft import PeftModel

    print(f"[model] applying LoRA adapter from {_adapter_path}…", flush=True)
    # Merge the LoRA weights into the base model and drop the PEFT wrapper in
    # one expression, so no module-level reference to the wrapper remains.
    model = PeftModel.from_pretrained(_base, _adapter_path).merge_and_unload()

model.eval()

# Diagnostic: confirm the attention implementation that actually
# attached to the loaded model (post-PEFT merge). If this shows "sdpa"
# or "eager" while we requested "flash_attention_2", the kwarg got
# dropped or the architecture doesn't support FA2 — explains a missing
# speedup despite flash-attn being importable.
try:
    actual_attn = (
        getattr(model.config, "_attn_implementation", None)
        or getattr(model, "_attn_implementation", None)
        or "unknown"
    )
    print(f"[model] runtime attn_implementation={actual_attn}", flush=True)
except Exception as e:
    print(f"[model] attn_implementation probe failed: {e}", flush=True)

print(f"[model] ready on {DEVICE}", flush=True)


# --- Tool schema (imported from harness for chat-template parity) ---------
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
try:
    from harness.runner import TOOL_DEFINITIONS  # noqa: F401
except ImportError as e:
    print(f"[model] WARN: TOOL_DEFINITIONS unavailable ({e})", flush=True)
    TOOL_DEFINITIONS = []

# --- System prompt builder (shared with harness.runner) -------------------
from harness.prompts import build_system_prompt  # noqa: E402,F401


# --- Single-round generation ----------------------------------------------
# Decoding stops at the end of the first complete tool call or at the chat
# end-of-turn marker. An empty stop string is never valid here.
_STOP_STRINGS = ("</tool_call>", "<|im_end|>")


def _build_inputs(messages: list[dict]):
    """Render `messages` through the chat template and move the tensors to DEVICE.

    Thinking mode is requested with `enable_thinking=True`. If
    `apply_chat_template` raises TypeError on that call, it is retried
    without the kwarg. Tool schemas are passed only when TOOL_DEFINITIONS is
    non-empty.
    """
    template_kwargs = dict(
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
        tools=TOOL_DEFINITIONS or None,
    )
    try:
        inputs = tokenizer.apply_chat_template(messages, enable_thinking=True, **template_kwargs)
    except TypeError:
        inputs = tokenizer.apply_chat_template(messages, **template_kwargs)
    return inputs.to(DEVICE)


def _generation_kwargs(inputs, **extra) -> dict:
    """Build the greedy-decoding kwargs for `model.generate`, shared by both entry points."""
    return dict(
        inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        stop_strings=list(_STOP_STRINGS),
        tokenizer=tokenizer,  # required by generate() when stop_strings is set
        pad_token_id=tokenizer.eos_token_id,
        **extra,
    )


def generate_one_round(messages: list[dict]) -> str:
    """Run one blocking, greedy generation round and return only the new text.

    Chat-template rendering and tokenization happen inside this function. If
    the caller wraps it in @spaces.GPU, they therefore also run inside the
    GPU window.

    Args:
        messages: Chat messages passed to the tokenizer's chat template.

    Returns:
        The decoded continuation (prompt excluded, special tokens skipped).
    """
    inputs = _build_inputs(messages)
    output = model.generate(**_generation_kwargs(inputs))
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)


def stream_one_round(messages: list[dict]):
    """Yield `(chunk_str, accumulated_full_text)` for each text chunk of one round.

    Generation runs in a daemon worker thread that feeds a
    `TextIteratorStreamer`, so chunks can be yielded to the UI as they arrive.
    If generation raises, the worker ends the streamer so this loop finishes
    instead of blocking forever. The exception is then re-raised here, after
    any partial output has been yielded.
    """
    import threading
    from transformers import TextIteratorStreamer

    inputs = _build_inputs(messages)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = _generation_kwargs(inputs, streamer=streamer)
    errors: list[Exception] = []

    def _run() -> None:
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            # If generate() raises, nothing else ends the streamer. End it here
            # so the consumer loop below can exit.
            streamer.end()

    thread = threading.Thread(target=_run, daemon=True)
    thread.start()
    full = ""
    for chunk in streamer:
        full += chunk
        yield chunk, full
    thread.join()
    if errors:
        raise errors[0]


# --- Tool-call parsers (Hermes/Qwen XML) ----------------------------------
# Shape matched by the regexes below:
#   <tool_call>
#   <function=NAME>
#   <parameter=KEY>
#   VALUE
#   </parameter>
#   </function>
#   </tool_call>
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
_FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
_ASSISTANT_MSG_RE = re.compile(r'"assistant_message"\s*:\s*"((?:[^"\\]|\\.)*)"', re.DOTALL)
# Simple JSON escapes, used only when the captured string is not valid JSON.
_SIMPLE_ESCAPES = {'"': '"', "\\": "\\", "/": "/", "b": "\b", "f": "\f", "n": "\n", "r": "\r", "t": "\t"}
_ESCAPE_RE = re.compile(r"\\(.)", re.DOTALL)


def parse_tool_calls(text: str) -> list[dict]:
    """Extract every `<tool_call><function=NAME>...</function></tool_call>` block.

    Each `<parameter=KEY>VALUE</parameter>` inside the function body becomes
    one argument. VALUE is whitespace-stripped, then JSON-decoded when
    possible and otherwise kept as the raw string. Tool-call blocks without a
    `<function=...>` element, and unterminated blocks, are skipped.

    Returns:
        `[{"name": str, "arguments": dict}, ...]` in order of appearance.
    """
    calls: list[dict] = []
    for tc in _TOOL_CALL_RE.finditer(text):
        fn = _FUNCTION_RE.search(tc.group(1))
        if not fn:
            continue
        args: dict = {}
        for p in _PARAMETER_RE.finditer(fn.group(2)):
            key = p.group(1).strip()
            raw = p.group(2).strip()
            try:
                args[key] = json.loads(raw)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                args[key] = raw
        calls.append({"name": fn.group(1).strip(), "arguments": args})
    return calls


def extract_assistant_message(text: str) -> str | None:
    """Best-effort recovery of `assistant_message` from a malformed
    submit_assistant_state body when the structured parse fails.

    The captured group is the body of a JSON string literal, so it is decoded
    as JSON. This keeps non-ASCII text intact and handles `\\uXXXX`
    (including surrogate pairs). If that fails, e.g. because of a non-JSON
    escape, only the simple one-character escapes are undone.

    Returns:
        The unescaped message, or None if no `"assistant_message": "..."` pair
        is found.
    """
    m = _ASSISTANT_MSG_RE.search(text)
    if not m:
        return None
    raw = m.group(1)
    try:
        # strict=False tolerates literal newlines/tabs inside the string.
        return json.loads(f'"{raw}"', strict=False)
    except ValueError:
        return _ESCAPE_RE.sub(lambda e: _SIMPLE_ESCAPES.get(e.group(1), e.group(0)), raw)