Spaces:
Running on Zero
Running on Zero
File size: 9,332 Bytes
cbfaae5 b08e50f cbfaae5 32606be b08e50f 32606be b08e50f cbfaae5 8003669 cbfaae5 bb7b69a cbfaae5 bb7b69a cbfaae5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 | """Model load + single-round generation for demo-v3 step 2.
Module-level load on startup (matches the ZeroGPU pattern: model in CPU RAM,
GPU attached per @spaces.GPU call). One blocking generate function that
returns the raw text output. No streaming, no multi-round, no tool dispatch
yet — those land in step 3+.
"""
from __future__ import annotations
import json
import os
import re
import shutil
import sys
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Directory that contains this module; the adapter cache and the harness
# lookup below are resolved relative to it.
HERE = Path(__file__).resolve().parent
# Monorepo: harness/ is at HERE.parent. Space deploy: harness/ is at HERE.
REPO_ROOT = HERE if (HERE / "harness").is_dir() else HERE.parent
# --- Environment-overridable configuration ---------------------------------
# Base model id handed to AutoModelForCausalLM.from_pretrained.
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen3.5-9B")
# Adapter repo id; an empty value disables the adapter (the loader checks
# `if ADAPTER_ID` before localising it).
ADAPTER_ID = os.environ.get("ADAPTER_ID", "continker/Qwen3.5-9B-metro-v23")
# Folder inside the adapter repo that holds the adapter files.
ADAPTER_SUBFOLDER = os.environ.get("ADAPTER_SUBFOLDER", "adapter")
# Per-round generation cap (passed as max_new_tokens); a non-integer value
# raises ValueError at import time.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "2048"))
def _localise_adapter(repo_id: str, subfolder: str) -> str:
    """Download an adapter and remap its keys to the flat Qwen3.5 layout.

    Published adapters use `model.language_model.layers.*` (multimodal-shaped
    Qwen). The current text-only Qwen3.5 has flat `model.layers.*`, so
    `.language_model.layers.` is rewritten to `.layers.` in every safetensors
    key. The result is cached at
    `<module dir>/.adapter_cache/{repo}__{subfolder}/` and reused as soon as
    the `.localised` marker file exists.

    Args:
        repo_id: Hub repository id of the adapter.
        subfolder: Folder inside the repository that holds the adapter files.
            An empty string means the adapter sits at the repository root.

    Returns:
        Path (as a string) of the local directory with the remapped adapter.
    """
    weights_name = "adapter_model.safetensors"
    cache_root = HERE / ".adapter_cache"
    safe = repo_id.replace("/", "__") + "__" + subfolder
    dst = cache_root / safe
    flag = dst / ".localised"
    if flag.exists():
        return str(dst)

    # Imported lazily so a cache hit does not need these packages.
    from huggingface_hub import snapshot_download
    from safetensors.torch import load_file, save_file

    # With an empty subfolder the pattern would be "/*", which matches no
    # file; fetch the whole repository in that case instead.
    patterns = [f"{subfolder}/*"] if subfolder else None
    src_root = snapshot_download(repo_id, allow_patterns=patterns)
    src = Path(src_root) / subfolder
    dst.mkdir(parents=True, exist_ok=True)

    for entry in src.iterdir():
        # Copy plain files only: shutil.copy raises on a directory, and the
        # download pattern may also have pulled in nested folders. The weights
        # file is rewritten below rather than copied.
        if entry.is_file() and entry.name != weights_name:
            shutil.copy(entry, dst)

    sd = load_file(str(src / weights_name))
    remapped = {
        key.replace(".language_model.layers.", ".layers."): tensor
        for key, tensor in sd.items()
    }
    save_file(remapped, str(dst / weights_name))

    # The marker is written last so an interrupted run is redone next time.
    flag.touch()
    print(f"[model] localised adapter ({len(remapped)} keys) → {dst}", flush=True)
    return str(dst)
# --- Device + dtype --------------------------------------------------------
# ZeroGPU emulation accepts .to("cuda") even when no real GPU is present at
# module load. On a developer Mac we fall back to MPS so the same code path
# runs locally; anything else runs on CPU in full precision.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
# --- Tokenizer + model load (runs once, at import time) --------------------
print(f"[model] loading {MODEL_ID} on {DEVICE} ({DTYPE})…", flush=True)
# An empty ADAPTER_ID means "base model only".
_adapter_path = _localise_adapter(ADAPTER_ID, ADAPTER_SUBFOLDER) if ADAPTER_ID else None
# The tokenizer is read from the adapter directory when there is one,
# otherwise from the base model id.
tokenizer = AutoTokenizer.from_pretrained(_adapter_path or MODEL_ID)

# FlashAttention-2 only on CUDA — flash-attn isn't built for MPS/CPU. When
# present (ZeroGPU CUDA) it gives ~2–3× decode throughput; falls back
# silently elsewhere so the same code runs on a Mac dev box.
_attn_impl = None
if DEVICE == "cuda":
    try:
        import flash_attn  # noqa: F401
        _attn_impl = "flash_attention_2"
        print("[model] flash-attn available → attn_implementation=flash_attention_2",
              flush=True)
    except ImportError:
        print("[model] flash-attn not installed; using default SDPA attention",
              flush=True)

# attn_implementation is only passed when FlashAttention-2 was selected, so
# the library default applies in every other case.
_base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, dtype=DTYPE, low_cpu_mem_usage=True,
    **({"attn_implementation": _attn_impl} if _attn_impl else {}),
).to(DEVICE)

if _adapter_path:
    from peft import PeftModel
    print(f"[model] applying LoRA adapter from {_adapter_path}…", flush=True)
    # Fold the adapter weights into the base model and drop the PEFT wrapper.
    model = PeftModel.from_pretrained(_base, _adapter_path).merge_and_unload()
else:
    model = _base
model.eval()
# Diagnostic: report the attention implementation that actually ended up on
# the loaded model (post-PEFT merge). If this shows "sdpa" or "eager" while
# we requested "flash_attention_2", the kwarg got dropped or the architecture
# doesn't support FA2 — which explains a missing speedup despite flash-attn
# being importable.
try:
    # First truthy value wins: config attribute, then model attribute.
    actual_attn = (
        getattr(model.config, "_attn_implementation", None)
        or getattr(model, "_attn_implementation", None)
        or "unknown"
    )
    print(f"[model] runtime attn_implementation={actual_attn}", flush=True)
except Exception as exc:
    # Purely informational probe: never let it break module import.
    print(f"[model] attn_implementation probe failed: {exc}", flush=True)
print(f"[model] ready on {DEVICE}", flush=True)
# --- Tool schema (imported from harness for chat-template parity) ---------
# Put the directory that contains harness/ on sys.path (once) so the imports
# below resolve in both the monorepo and the Space layout.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
try:
    from harness.runner import TOOL_DEFINITIONS # noqa: F401
except ImportError as e:
    # Degrade to "no tools": the generate functions pass
    # `TOOL_DEFINITIONS or None`, so an empty list becomes tools=None.
    print(f"[model] WARN: TOOL_DEFINITIONS unavailable ({e})", flush=True)
    TOOL_DEFINITIONS = []
# --- System prompt builder (shared with harness.runner) -------------------
# NOTE(review): unlike TOOL_DEFINITIONS, this import has no fallback, so a
# missing harness.prompts aborts the module import — confirm that is intended.
from harness.prompts import build_system_prompt # noqa: E402,F401
# --- Single-round generation ----------------------------------------------
def generate_one_round(messages: list[dict]) -> str:
    """Run one blocking, greedy generation round and return the new text.

    The chat template is rendered with the tool schema (or no tools when the
    schema is empty) and a generation prompt, moved to DEVICE, and passed to
    `model.generate`. Decoding stops at `</tool_call>` or `<|im_end|>`, or
    after MAX_NEW_TOKENS tokens.

    Args:
        messages: Chat messages in the form the tokenizer's chat template
            accepts.

    Returns:
        The decoded completion only (prompt tokens removed, special tokens
        skipped).
    """
    template_kwargs = dict(
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
        tools=TOOL_DEFINITIONS or None,
    )
    try:
        encoded = tokenizer.apply_chat_template(
            messages, enable_thinking=True, **template_kwargs
        ).to(DEVICE)
    except TypeError:
        # Retry without `enable_thinking` when the first attempt raises
        # TypeError (presumably a template/tokenizer that rejects the flag).
        encoded = tokenizer.apply_chat_template(messages, **template_kwargs).to(DEVICE)

    prompt_length = encoded["input_ids"].shape[1]
    generated = model.generate(
        **encoded,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        stop_strings=["</tool_call>", "<|im_end|>"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Single prompt in, so row 0 is the only sequence; drop the prompt part.
    completion_ids = generated[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
def stream_one_round(messages: list[dict]):
    """Stream one greedy generation round chunk by chunk.

    Generation runs in a worker thread so that `TextIteratorStreamer` can
    deliver text while this generator yields to the UI. Prompt rendering and
    generation settings are the same as in `generate_one_round`.

    Args:
        messages: Chat messages in the form the tokenizer's chat template
            accepts.

    Yields:
        `(chunk, full)` tuples: the newly emitted text and the text
        accumulated so far.

    Raises:
        Exception: whatever `model.generate` raised in the worker thread,
            re-raised here once the chunks produced before the failure have
            been yielded.
    """
    import threading

    from transformers import TextIteratorStreamer

    try:
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True,
            return_dict=True, tools=TOOL_DEFINITIONS or None,
            enable_thinking=True,
        ).to(DEVICE)
    except TypeError:
        # Retry without `enable_thinking` when the first attempt raises
        # TypeError (presumably a template/tokenizer that rejects the flag).
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True,
            return_dict=True, tools=TOOL_DEFINITIONS or None,
        ).to(DEVICE)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        streamer=streamer,
        stop_strings=["</tool_call>", "<|im_end|>"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
    )

    failures: list[Exception] = []

    def _run_generate() -> None:
        """Run generate(); on failure record the error and close the stream."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            # generate() only signals end-of-stream on success. Without this
            # the consumer loop below would block on the queue forever.
            streamer.end()

    thread = threading.Thread(target=_run_generate, daemon=True)
    thread.start()

    full = ""
    for chunk in streamer:
        full += chunk
        yield chunk, full
    thread.join()

    # Surface a worker failure to the caller instead of ending silently.
    if failures:
        raise failures[0]
# --- Tool-call parsers (Hermes/Qwen XML) ----------------------------------
# All patterns use DOTALL because the tagged bodies usually span lines.
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
_FUNCTION_RE = re.compile(r"<function=(\w+)>(.*?)</function>", re.DOTALL)
_PARAMETER_RE = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL)
# Captures the body of a JSON string value: any run of non-quote,
# non-backslash characters or backslash-escaped characters.
_ASSISTANT_MSG_RE = re.compile(r'"assistant_message"\s*:\s*"((?:[^"\\]|\\.)*)"', re.DOTALL)


def parse_tool_calls(text: str) -> list[dict]:
    """Extract every tool call from the model's text output.

    Recognises blocks of the form
    `<tool_call><function=NAME><parameter=K>V</parameter>...</function></tool_call>`.
    Only the first `<function=...>` inside each `<tool_call>` is used, and a
    `<tool_call>` without one is skipped.

    Args:
        text: Raw model output.

    Returns:
        One `{"name": str, "arguments": dict}` entry per tool call, in order
        of appearance. Each parameter value is JSON-decoded when it is valid
        JSON and kept as the stripped string otherwise.
    """
    calls: list[dict] = []
    for tc in _TOOL_CALL_RE.finditer(text):
        fn = _FUNCTION_RE.search(tc.group(1))
        if fn is None:
            continue
        args: dict = {}
        for param in _PARAMETER_RE.finditer(fn.group(2)):
            raw = param.group(2).strip()
            try:
                args[param.group(1)] = json.loads(raw)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                args[param.group(1)] = raw
        calls.append({"name": fn.group(1), "arguments": args})
    return calls


def extract_assistant_message(text: str) -> str | None:
    """Recover `assistant_message` from a malformed submit_assistant_state body.

    Best-effort fallback for when the structured parse fails: finds the first
    `"assistant_message": "..."` pair in `text` and unescapes its value.

    Args:
        text: Raw model output.

    Returns:
        The unescaped message, the raw captured text when it cannot be
        unescaped, or None when the field is not present.
    """
    m = _ASSISTANT_MSG_RE.search(text)
    if m is None:
        return None
    raw = m.group(1)

    # The capture is the body of a JSON string, so decode it as JSON first.
    # That keeps non-ASCII text intact and joins \uD83D\uDE00-style surrogate
    # pairs. strict=False accepts literal newlines/tabs, which the regex lets
    # through.
    try:
        return json.loads(f'"{raw}"', strict=False)
    except ValueError:
        pass

    # Not valid JSON escapes (e.g. \' or \x41): fall back to Python-style
    # unescaping. Going through latin-1 with backslashreplace keeps non-ASCII
    # characters intact; a plain .encode() would hand UTF-8 bytes to a codec
    # that decodes them as Latin-1.
    try:
        return raw.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except (ValueError, Warning):
        # Warning: invalid escapes become exceptions when run with -W error.
        return raw
|