DotCache-Arena / scripts /space_runner_common.py
DeanoCalver's picture
Add live Llama lane and writable cache fallback
e135040 verified
Raw
History Blame Contribute Delete
3.22 kB
from __future__ import annotations
import json
import os
from pathlib import Path
import re
from typing import Any
# Absolute, symlink-resolved path of this module.
_MODULE_PATH = Path(__file__).resolve()
# parents[0] is the directory holding this file (``scripts/``); parents[1] is
# the repository root one level above it.
REPO_ROOT = _MODULE_PATH.parents[1]
def _first_writable_dir(candidates: tuple[Path, ...]) -> Path | None:
    """Return the first of *candidates* that can be created and written to.

    Each candidate is created (with parents) and probed by writing and then
    deleting a small ``.write_test`` file. A candidate for which any of those
    steps raises ``OSError`` is skipped. Returns ``None`` if none is usable.
    """
    for candidate in candidates:
        probe = candidate / ".write_test"
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            probe.write_text("ok", encoding="utf-8")
            probe.unlink(missing_ok=True)
        except OSError:
            continue
        return candidate
    return None


def configure_model_cache_env() -> None:
    """Point the Hugging Face libraries at a writable cache tree.

    The preferred cache root is the first non-empty of
    ``$DOTCACHE_MODEL_CACHE_DIR``, ``$HF_HOME`` and ``~/.cache/huggingface``.
    If it cannot be created or written, ``<repo>/.hf-cache`` is used instead.
    The ``hub``, ``xet``, ``assets``, ``transformers`` and ``modules``
    subdirectories are created under the chosen root and exported through the
    matching ``HF_*`` / ``TRANSFORMERS_CACHE`` variables. ``HF_HUB_DISABLE_XET``
    defaults to ``"1"`` so regular Hub downloads are used.

    Cache-path variables that are already set in the environment are
    normally kept. They are overwritten when the preferred root was rejected,
    because the inherited values would then still point at the unusable
    location.

    Raises:
        OSError: if a cache subdirectory cannot be created. This can happen
            when neither candidate root is writable.
    """
    # ZeroGPU runtime downloads can fail through the Xet path if the default
    # cache layout resolves to an unwritable location inside the Space.
    # Route Hugging Face libraries onto a known writable cache tree and use
    # regular Hub downloads by default.
    default_hf_home = Path.home() / ".cache" / "huggingface"
    preferred_root = Path(
        os.getenv("DOTCACHE_MODEL_CACHE_DIR") or os.getenv("HF_HOME") or default_hf_home
    ).resolve()
    fallback_root = (REPO_ROOT / ".hf-cache").resolve()

    writable_root = _first_writable_dir((preferred_root, fallback_root))
    # If neither root passed the probe, keep the preferred one. The
    # subdirectory mkdir calls below may then surface the underlying OSError.
    cache_root = preferred_root if writable_root is None else writable_root

    hub_cache = cache_root / "hub"
    xet_cache = cache_root / "xet"
    assets_cache = cache_root / "assets"
    # NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE
    # in favour of HF_HOME / HF_HUB_CACHE. Confirm that a separate directory
    # is still wanted here.
    transformers_cache = cache_root / "transformers"
    modules_cache = cache_root / "modules"
    for path in (hub_cache, xet_cache, assets_cache, transformers_cache, modules_cache):
        path.mkdir(parents=True, exist_ok=True)

    cache_env = {
        "HF_HOME": cache_root,
        "HF_HUB_CACHE": hub_cache,
        "HUGGINGFACE_HUB_CACHE": hub_cache,
        "HF_XET_CACHE": xet_cache,
        "HF_ASSETS_CACHE": assets_cache,
        "TRANSFORMERS_CACHE": transformers_cache,
        "HF_MODULES_CACHE": modules_cache,
    }
    # setdefault() respects values the caller exported. When we fell back,
    # though, those inherited values (e.g. an unwritable HF_HOME) are exactly
    # what must not be used, so overwrite them.
    override_existing = cache_root != preferred_root
    for name, value in cache_env.items():
        if override_existing:
            os.environ[name] = str(value)
        else:
            os.environ.setdefault(name, str(value))
    os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
def load_request_from_stdin() -> dict[str, Any]:
    """Parse the runner request from the first line of standard input.

    Exactly one line is consumed (via ``input()``). The JSON document must
    therefore be serialized without embedded newlines, and any further input
    is left unread.

    Returns:
        The decoded JSON object.

    Raises:
        EOFError: if stdin is exhausted before a line is read.
        json.JSONDecodeError: if the line is not valid JSON.
        ValueError: if the decoded JSON value is not an object.
    """
    raw_line = input()
    request = json.loads(raw_line)
    if isinstance(request, dict):
        return request
    raise ValueError("Runner stdin payload must be a JSON object.")
# Reasoning spans such as ``<think>...</think>`` (may span lines, any case).
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
# ChatML turn delimiters that can leak into decoded text.
_CHATML_MARKER_RE = re.compile(r"<\|im_start\|>|<\|im_end\|>")
# Lines that consist only of a chat role label.
_ROLE_LINE_RE = re.compile(r"(?m)^(system|user|assistant)\s*$", re.IGNORECASE)
# Runs of three or more newlines, collapsed to a single blank line.
_EXCESS_NEWLINES_RE = re.compile(r"\n{3,}")


def clean_generated_text(text: str) -> str:
    """Strip chat-template residue from model output.

    Removes ``<think>...</think>`` spans, ChatML ``<|im_start|>`` /
    ``<|im_end|>`` markers and bare ``system`` / ``user`` / ``assistant`` role
    lines. It then collapses runs of three or more newlines into two and trims
    surrounding whitespace. Non-string input is converted with ``str()``
    first.
    """
    result = str(text)
    # Order matters: role labels only become "bare" lines once the ChatML
    # markers around them are gone.
    substitutions = (
        (_THINK_BLOCK_RE, ""),
        (_CHATML_MARKER_RE, ""),
        (_ROLE_LINE_RE, ""),
        (_EXCESS_NEWLINES_RE, "\n\n"),
    )
    for pattern, replacement in substitutions:
        result = pattern.sub(replacement, result)
    return result.strip()
def decode_generated_text(tokenizer: Any, generated_ids: list[int], *, limit: int | None = None) -> str:
ids = list(generated_ids if limit is None else generated_ids[:limit])
if tokenizer is None or not ids:
return ""
raw_text = str(tokenizer.decode(ids, skip_special_tokens=True))
cleaned = clean_generated_text(raw_text)
return cleaned or raw_text.strip()
def tok_per_sec_from_latency(latency_ms_per_token: float) -> float:
    """Convert a per-token latency in milliseconds into tokens per second.

    Returns ``0.0`` for any latency that is not strictly positive, including
    NaN. This keeps callers from dividing by zero or propagating a NaN
    throughput, which ``json.dumps`` would emit as the non-standard ``NaN``
    token.
    """
    # ``not (x > 0.0)`` rather than ``x <= 0.0``: every ordered comparison
    # with NaN is False, so this form also routes NaN to the 0.0 branch.
    if not (latency_ms_per_token > 0.0):
        return 0.0
    return float(1000.0 / latency_ms_per_token)
def print_json(payload: dict[str, Any]) -> None:
    """Write *payload* to stdout as a single line of JSON with sorted keys.

    The line is flushed immediately. When stdout is redirected to a pipe or
    file it is block-buffered, and a result still sitting in that buffer is
    lost if the process ends without normal interpreter shutdown (e.g.
    ``os._exit`` or a native crash during teardown).
    """
    print(json.dumps(payload, sort_keys=True), flush=True)