Spaces:
Running on Zero
Running on Zero
"""Model load + single-round generation for demo-v3 step 2.

Module-level load on startup (matches the ZeroGPU pattern: model in CPU RAM,
GPU attached per @spaces.GPU call). Exposes one blocking generate function
that returns the raw text output, a streaming variant that yields text
chunks as they are produced, and parsers for the model's tool-call output.
No multi-round loop and no tool dispatch in this module.
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import shutil | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Directory that contains this file.
HERE = Path(__file__).resolve().parent

# Monorepo: harness/ is at HERE.parent. Space deploy: harness/ is at HERE.
if (HERE / "harness").is_dir():
    REPO_ROOT = HERE
else:
    REPO_ROOT = HERE.parent

# Runtime configuration. Each value is read from the environment variable of
# the same name and falls back to the default shown here.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3.5-9B")
ADAPTER_ID = os.getenv("ADAPTER_ID", "continker/Qwen3.5-9B-metro-v23")
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "adapter")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))
def _localise_adapter(repo_id: str, subfolder: str) -> str:
    """Download a LoRA adapter and remap its keys to the current Qwen3.5 layout.

    Published adapters use `model.language_model.layers.*` (multimodal-shaped
    Qwen). The current text-only Qwen3.5 has flat `model.layers.*`, so every
    safetensors key has `.language_model.layers.` rewritten to `.layers.`.
    The result is cached at `<HERE>/.adapter_cache/{repo}__{subfolder}/`; a
    `.localised` marker file written last signals a complete cache entry, so
    later calls return immediately without touching the network.

    Args:
        repo_id: Hugging Face Hub repo id of the adapter (e.g. "org/name").
        subfolder: Folder inside the repo that holds the adapter files.

    Returns:
        Path (as a string) of the local directory with the remapped adapter.
    """
    cache_root = HERE / ".adapter_cache"
    safe = repo_id.replace("/", "__") + "__" + subfolder
    dst = cache_root / safe
    flag = dst / ".localised"
    if flag.exists():
        return str(dst)

    # Imported lazily so a warm cache does not need these packages loaded.
    from huggingface_hub import snapshot_download
    from safetensors.torch import load_file, save_file

    src_root = snapshot_download(repo_id, allow_patterns=[f"{subfolder}/*"])
    src = Path(src_root) / subfolder
    dst.mkdir(parents=True, exist_ok=True)

    weights_name = "adapter_model.safetensors"
    for entry in src.iterdir():
        if entry.name == weights_name:
            # The weights file is rewritten below with remapped keys.
            continue
        if entry.is_dir():
            # The allow pattern can match nested paths, so the snapshot may
            # contain directories; shutil.copy() would raise on those.
            shutil.copytree(entry, dst / entry.name, dirs_exist_ok=True)
        else:
            shutil.copy(entry, dst)

    sd = load_file(str(src / weights_name))
    remapped = {k.replace(".language_model.layers.", ".layers."): v for k, v in sd.items()}
    save_file(remapped, str(dst / weights_name))
    # Marker goes last: an interrupted run leaves no flag and is redone.
    flag.touch()
    print(f"[model] localised adapter ({len(remapped)} keys) → {dst}", flush=True)
    return str(dst)
# --- Device + dtype --------------------------------------------------------
# Preference order is CUDA, then MPS, then CPU. Per the original note, ZeroGPU
# emulation accepts .to("cuda") even when no real GPU is present at module
# load, and the MPS branch lets the same code path run on a developer Mac.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
print(f"[model] loading {MODEL_ID} on {DEVICE} ({DTYPE})…", flush=True)

# An empty ADAPTER_ID disables the adapter; the tokenizer then comes from the
# base model repo instead of the localised adapter directory.
_adapter_path = None
if ADAPTER_ID:
    _adapter_path = _localise_adapter(ADAPTER_ID, ADAPTER_SUBFOLDER)
tokenizer = AutoTokenizer.from_pretrained(_adapter_path or MODEL_ID)

# FlashAttention-2 is requested only on CUDA and only when the flash_attn
# package imports cleanly; otherwise the library default is used so the same
# code still runs on MPS/CPU. (The original note cites ~2–3× decode throughput
# on ZeroGPU CUDA when it is present.)
_attn_impl = None
if DEVICE == "cuda":
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        print("[model] flash-attn not installed; using default SDPA attention",
              flush=True)
    else:
        _attn_impl = "flash_attention_2"
        print("[model] flash-attn available → attn_implementation=flash_attention_2",
              flush=True)

# Pass attn_implementation only when one was selected above.
_load_kwargs = {"dtype": DTYPE, "low_cpu_mem_usage": True}
if _attn_impl:
    _load_kwargs["attn_implementation"] = _attn_impl
_base = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_load_kwargs).to(DEVICE)

# With an adapter, fold the LoRA weights into the base model so `model` is a
# plain merged model; without one, use the base model as-is.
model = _base
if _adapter_path:
    from peft import PeftModel
    print(f"[model] applying LoRA adapter from {_adapter_path}…", flush=True)
    model = PeftModel.from_pretrained(_base, _adapter_path).merge_and_unload()
model.eval()

# Diagnostic: report the attention implementation recorded on the loaded
# model (post-PEFT merge). If this shows "sdpa" or "eager" while
# "flash_attention_2" was requested, the kwarg got dropped or the
# architecture doesn't support FA2 — which would explain a missing speedup
# despite flash-attn being importable. Failures here are logged, never fatal.
try:
    actual_attn = (
        getattr(model.config, "_attn_implementation", None)
        or getattr(model, "_attn_implementation", None)
        or "unknown"
    )
    print(f"[model] runtime attn_implementation={actual_attn}", flush=True)
except Exception as exc:
    print(f"[model] attn_implementation probe failed: {exc}", flush=True)
print(f"[model] ready on {DEVICE}", flush=True)
# --- Tool schema (imported from harness for chat-template parity) ---------
# Make the directory that holds harness/ importable before importing from it.
_repo_root = str(REPO_ROOT)
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)

# If the harness tool definitions cannot be imported, warn and continue with
# an empty list rather than failing module import.
try:
    from harness.runner import TOOL_DEFINITIONS  # noqa: F401
except ImportError as exc:
    TOOL_DEFINITIONS = []
    print(f"[model] WARN: TOOL_DEFINITIONS unavailable ({exc})", flush=True)

# --- System prompt builder (shared with harness.runner) -------------------
# Re-exported from this module; unlike the tool schema above, this import is
# not guarded, so a missing harness.prompts raises at import time.
from harness.prompts import build_system_prompt  # noqa: E402,F401
# --- Single-round generation ----------------------------------------------
def generate_one_round(messages: list[dict]) -> str:
    """Run one blocking, greedy generation round and return the new text.

    The chat template is applied to `messages` (with the tool schema when one
    is loaded), the encoded prompt is moved to DEVICE, and `model.generate`
    runs until a stop string, EOS, or MAX_NEW_TOKENS. Only the tokens after
    the prompt are decoded, with special tokens stripped.

    Callers are expected to wrap this in @spaces.GPU (per the original note).

    Args:
        messages: Chat messages in the format `apply_chat_template` accepts.

    Returns:
        The decoded completion text for this round.
    """
    template_args = {
        "return_tensors": "pt",
        "add_generation_prompt": True,
        "return_dict": True,
        "tools": TOOL_DEFINITIONS or None,
    }
    try:
        encoded = tokenizer.apply_chat_template(
            messages, enable_thinking=True, **template_args
        )
        inputs = encoded.to(DEVICE)
    except TypeError:
        # Retry without `enable_thinking` (presumably for tokenizers that
        # reject that keyword).
        encoded = tokenizer.apply_chat_template(messages, **template_args)
        inputs = encoded.to(DEVICE)

    prompt_len = inputs["input_ids"].shape[1]
    generated = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        stop_strings=["</tool_call>", "<|im_end|>"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the echoed prompt; keep only what was generated.
    completion_ids = generated[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
def stream_one_round(messages: list[dict]):
    """Stream one greedy generation round as incremental text chunks.

    Generation runs in a daemon worker thread that feeds a
    `TextIteratorStreamer`; this generator drains the streamer so chunks can
    be handed to the UI while decoding is still in progress.

    Args:
        messages: Chat messages in the format `apply_chat_template` accepts.

    Yields:
        `(chunk, full)` tuples: the newly emitted text and the text
        accumulated so far in this round.

    Raises:
        Exception: Whatever `model.generate` raised in the worker thread,
            re-raised here once the stream has been drained.
    """
    import threading

    from transformers import TextIteratorStreamer

    try:
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True,
            return_dict=True, tools=TOOL_DEFINITIONS or None,
            enable_thinking=True,
        ).to(DEVICE)
    except TypeError:
        # Retry without `enable_thinking` (presumably for tokenizers that
        # reject that keyword).
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True,
            return_dict=True, tools=TOOL_DEFINITIONS or None,
        ).to(DEVICE)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        streamer=streamer,
        stop_strings=["</tool_call>", "<|im_end|>"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Filled by the worker if generate() fails, so the error can surface in
    # the consuming thread instead of being lost.
    failure: list[Exception] = []

    def _run_generate() -> None:
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            failure.append(exc)
            # generate() died before signalling end-of-stream. Without this
            # the `for chunk in streamer` loop below would block forever.
            streamer.end()

    thread = threading.Thread(target=_run_generate, daemon=True)
    thread.start()

    full = ""
    for chunk in streamer:
        full += chunk
        yield chunk, full
    thread.join()
    if failure:
        raise failure[0]
# --- Tool-call parsers (Hermes/Qwen XML) ----------------------------------
# <tool_call>…</tool_call> wrapper around one call.
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
# <function=NAME>…</function> inside a tool call.
_FUNCTION_RE = re.compile(r"<function=(\w+)>(.*?)</function>", re.DOTALL)
# <parameter=KEY>VALUE</parameter> inside a function body.
_PARAMETER_RE = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL)
# Body of the JSON string stored under "assistant_message", escapes included.
_ASSISTANT_MSG_RE = re.compile(r'"assistant_message"\s*:\s*"((?:[^"\\]|\\.)*)"', re.DOTALL)


def parse_tool_calls(text: str) -> list[dict]:
    """Extract every <tool_call><function=NAME><parameter=K>V</parameter>...
    block from the model's text output.

    Each parameter value is stripped and parsed as JSON when possible (so
    numbers, booleans, lists and objects come back typed); otherwise the
    stripped string is kept. A <tool_call> without a complete
    <function=…>…</function> element is skipped.

    Args:
        text: Raw model output.

    Returns:
        A list of `{"name": str, "arguments": dict}` in order of appearance.
    """
    calls: list[dict] = []
    for tc in _TOOL_CALL_RE.finditer(text):
        fn = _FUNCTION_RE.search(tc.group(1))
        if not fn:
            continue
        args: dict = {}
        for p in _PARAMETER_RE.finditer(fn.group(2)):
            raw = p.group(2).strip()
            try:
                args[p.group(1)] = json.loads(raw)
            except ValueError:  # json.JSONDecodeError is a ValueError
                args[p.group(1)] = raw
        calls.append({"name": fn.group(1), "arguments": args})
    return calls


def extract_assistant_message(text: str) -> str | None:
    """Best-effort recovery of `assistant_message` from a malformed
    submit_assistant_state body when the structured parse fails.

    The captured value is decoded as a JSON string first, which keeps
    non-ASCII characters intact and joins `\\uD83D\\uDE00`-style surrogate
    pairs. If it is not valid JSON (e.g. it contains `\\x41`), a Latin-1-safe
    `unicode_escape` decode is tried, and the raw capture is the last resort.

    Args:
        text: Raw model output that may contain `"assistant_message": "…"`.

    Returns:
        The unescaped message, or None when the key is not found.
    """
    m = _ASSISTANT_MSG_RE.search(text)
    if not m:
        return None
    raw = m.group(1)
    try:
        # strict=False tolerates literal newlines/tabs inside the value,
        # which the regex (DOTALL) also accepts.
        return json.loads(f'"{raw}"', strict=False)
    except ValueError:
        pass
    try:
        # unicode_escape reads bytes as Latin-1, so non-Latin-1 characters
        # are turned into \uXXXX escapes first instead of being mangled the
        # way a plain UTF-8 encode would.
        return raw.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except ValueError:
        return raw