from __future__ import annotations

import json
import os
from pathlib import Path
import re
from typing import Any


REPO_ROOT = Path(__file__).resolve().parents[1]

# Patterns used by ``clean_generated_text``; compiled once at import time.
# NOTE: the reasoning-block pattern must carry its literal ``<think>`` tags.
# A bare lazy ``.*?`` passed to ``re.sub`` matches (and deletes) every
# character of the input on Python 3.7+, wiping the whole generation.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
_CHAT_MARKER_RE = re.compile(r"<\|im_start\|>|<\|im_end\|>")
_ROLE_LINE_RE = re.compile(r"^(system|user|assistant)\s*$", re.MULTILINE | re.IGNORECASE)
_BLANK_RUN_RE = re.compile(r"\n{3,}")


def _is_writable_dir(directory: Path) -> bool:
    """Return True if *directory* can be created and a probe file written in it."""
    try:
        directory.mkdir(parents=True, exist_ok=True)
        probe = directory / ".write_test"
        probe.write_text("ok", encoding="utf-8")
        probe.unlink(missing_ok=True)
    except OSError:
        return False
    return True


def configure_model_cache_env() -> None:
    """Point Hugging Face cache environment variables at a writable directory tree.

    The preferred root is ``$DOTCACHE_MODEL_CACHE_DIR``, then ``$HF_HOME``, then
    ``~/.cache/huggingface``; ``REPO_ROOT / ".hf-cache"`` is the fallback when
    the preferred root is not writable. If neither candidate is writable the
    preferred root is still used, so creating the sub-directories may raise.

    All variables are set with ``os.environ.setdefault``, so values already
    present in the environment are left untouched.

    Raises:
        OSError: If the cache sub-directories cannot be created.
    """
    # ZeroGPU runtime downloads can fail through the Xet path if the default
    # cache layout resolves to an unwritable location inside the Space.
    # Force Hugging Face libraries onto a known writable cache tree and use
    # regular Hub downloads by default.
    default_hf_home = Path.home() / ".cache" / "huggingface"
    preferred_root = Path(
        os.getenv("DOTCACHE_MODEL_CACHE_DIR") or os.getenv("HF_HOME") or default_hf_home
    ).resolve()
    fallback_root = (REPO_ROOT / ".hf-cache").resolve()

    # First candidate that passes the write probe wins; default to the
    # preferred root when none does.
    cache_root = next(
        (root for root in (preferred_root, fallback_root) if _is_writable_dir(root)),
        preferred_root,
    )

    hub_cache = cache_root / "hub"
    xet_cache = cache_root / "xet"
    assets_cache = cache_root / "assets"
    transformers_cache = cache_root / "transformers"
    modules_cache = cache_root / "modules"
    for path in (hub_cache, xet_cache, assets_cache, transformers_cache, modules_cache):
        path.mkdir(parents=True, exist_ok=True)

    env_defaults = {
        "HF_HOME": cache_root,
        "HF_HUB_CACHE": hub_cache,
        "HUGGINGFACE_HUB_CACHE": hub_cache,
        "HF_XET_CACHE": xet_cache,
        "HF_ASSETS_CACHE": assets_cache,
        "TRANSFORMERS_CACHE": transformers_cache,
        "HF_MODULES_CACHE": modules_cache,
    }
    for name, location in env_defaults.items():
        os.environ.setdefault(name, str(location))
    os.environ.setdefault("HF_HUB_DISABLE_XET", "1")


def load_request_from_stdin() -> dict[str, Any]:
    """Read one line from stdin and parse it as a JSON object.

    Only a single line is consumed (via ``input()``), so the payload must be
    serialized on one line.

    Returns:
        The decoded JSON object.

    Raises:
        EOFError: If stdin is already at end-of-file.
        json.JSONDecodeError: If the line is not valid JSON.
        ValueError: If the decoded JSON value is not an object.
    """
    payload = json.loads(input())
    if not isinstance(payload, dict):
        raise ValueError("Runner stdin payload must be a JSON object.")
    return payload


def clean_generated_text(text: str) -> str:
    """Strip reasoning blocks and chat-template residue from generated text.

    Removes, in order: ``<think>...</think>`` blocks (case-insensitive, may
    span lines), ``<|im_start|>`` / ``<|im_end|>`` markers, and lines that
    consist solely of a role name (``system``, ``user``, ``assistant``).
    Runs of three or more newlines are then collapsed to a single blank line
    and surrounding whitespace is trimmed.

    Args:
        text: Raw decoded model output; coerced with ``str()``.

    Returns:
        The cleaned text, possibly empty.
    """
    cleaned = str(text)
    cleaned = _THINK_BLOCK_RE.sub("", cleaned)
    cleaned = _CHAT_MARKER_RE.sub("", cleaned)
    cleaned = _ROLE_LINE_RE.sub("", cleaned)
    cleaned = _BLANK_RUN_RE.sub("\n\n", cleaned)
    return cleaned.strip()


def decode_generated_text(tokenizer: Any, generated_ids: list[int], *, limit: int | None = None) -> str:
    """Decode token ids to cleaned text.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=True)``;
            ``None`` yields an empty string.
        generated_ids: Token ids to decode.
        limit: If given, only the first ``limit`` ids are decoded.

    Returns:
        The cleaned text, or the stripped raw decode when cleaning leaves
        nothing, or ``""`` when there is no tokenizer or no ids.
    """
    ids = list(generated_ids if limit is None else generated_ids[:limit])
    if tokenizer is None or not ids:
        return ""
    raw_text = str(tokenizer.decode(ids, skip_special_tokens=True))
    cleaned = clean_generated_text(raw_text)
    # Prefer showing the raw decode over showing nothing at all.
    return cleaned or raw_text.strip()


def tok_per_sec_from_latency(latency_ms_per_token: float) -> float:
    """Convert per-token latency in milliseconds to tokens per second.

    Non-positive latencies return ``0.0`` instead of dividing by zero.
    """
    if latency_ms_per_token <= 0.0:
        return 0.0
    return float(1000.0 / latency_ms_per_token)


def print_json(payload: dict[str, Any]) -> None:
    """Print *payload* to stdout as a single JSON line with sorted keys."""
    print(json.dumps(payload, sort_keys=True))