# Source: Hugging Face Spaces page (Space status: Paused).
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import re | |
| from typing import Any | |
# Repository root: the parent of the directory that contains this file, after
# resolving symlinks (``parents[0]`` is this file's directory, ``parents[1]``
# is that directory's parent). configure_model_cache_env() uses it as the base
# of its ``.hf-cache`` fallback directory.
REPO_ROOT = Path(__file__).resolve().parents[1]
def _first_writable_dir(candidates: list[Path]) -> Path | None:
    """Return the first of *candidates* that can be created and written to.

    Each candidate is created (with parents) if missing, then probed by writing
    and deleting a small ``.write_test`` file. A candidate that raises
    ``OSError`` at any of these steps is skipped.

    Returns:
        The first writable candidate, or ``None`` if none of them is writable.
    """
    for candidate in candidates:
        probe = candidate / ".write_test"
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            probe.write_text("ok", encoding="utf-8")
            probe.unlink(missing_ok=True)
        except OSError:
            continue
        return candidate
    return None


def configure_model_cache_env() -> None:
    """Point Hugging Face libraries at a writable cache tree.

    ZeroGPU runtime downloads can fail through the Xet path if the default
    cache layout resolves to an unwritable location inside the Space. This
    function picks a writable cache root, exports cache-location environment
    variables under it, and disables Xet downloads by default.

    The root is the first writable directory among:

    1. ``$DOTCACHE_MODEL_CACHE_DIR``, else ``$HF_HOME``, else
       ``~/.cache/huggingface`` (``~`` is expanded);
    2. ``REPO_ROOT/.hf-cache``;
    3. ``<system temp dir>/hf-cache``.

    The ``hub``, ``xet``, ``assets``, ``transformers`` and ``modules``
    subdirectories are created under the chosen root. Variables already present
    in the environment are left alone. The one exception is ``HF_HOME``, which
    is replaced when it is empty or names a root that was just found
    unwritable.

    Raises:
        OSError: If no candidate root is writable, or a cache subdirectory
            cannot be created.
    """
    import tempfile

    default_hf_home = Path.home() / ".cache" / "huggingface"
    preferred_raw = os.getenv("DOTCACHE_MODEL_CACHE_DIR") or os.getenv("HF_HOME") or default_hf_home
    # expanduser(): a value such as "~/hf" would otherwise resolve to a literal
    # "~" directory under the current working directory.
    preferred_root = Path(preferred_raw).expanduser().resolve()
    fallback_root = (REPO_ROOT / ".hf-cache").resolve()
    # Last resort when neither the configured location nor the repo is writable.
    temp_root = (Path(tempfile.gettempdir()) / "hf-cache").resolve()
    candidates = [preferred_root, fallback_root, temp_root]

    cache_root = _first_writable_dir(candidates)
    if cache_root is None:
        tried = ", ".join(str(path) for path in candidates)
        raise OSError(f"No writable Hugging Face cache directory found (tried: {tried}).")
    # Every candidate ahead of the chosen one failed the write probe.
    unwritable = candidates[: candidates.index(cache_root)]

    hub_cache = cache_root / "hub"
    xet_cache = cache_root / "xet"
    assets_cache = cache_root / "assets"
    transformers_cache = cache_root / "transformers"
    modules_cache = cache_root / "modules"
    for path in (hub_cache, xet_cache, assets_cache, transformers_cache, modules_cache):
        path.mkdir(parents=True, exist_ok=True)

    # setdefault() alone would keep an HF_HOME that is empty or was just found
    # unwritable. Anything that derives paths from HF_HOME and is not overridden
    # below would then still point at it. A working user-provided HF_HOME is
    # respected.
    current_home = os.environ.get("HF_HOME", "")
    if not current_home or Path(current_home).expanduser().resolve() in unwritable:
        os.environ["HF_HOME"] = str(cache_root)
    os.environ.setdefault("HF_HUB_CACHE", str(hub_cache))
    os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hub_cache))
    os.environ.setdefault("HF_XET_CACHE", str(xet_cache))
    os.environ.setdefault("HF_ASSETS_CACHE", str(assets_cache))
    # NOTE(review): recent transformers releases treat TRANSFORMERS_CACHE as
    # deprecated in favour of HF_HOME/HF_HUB_CACHE. Confirm whether pointing it
    # at a directory separate from hub/ splits the model cache.
    os.environ.setdefault("TRANSFORMERS_CACHE", str(transformers_cache))
    os.environ.setdefault("HF_MODULES_CACHE", str(modules_cache))
    # Use regular Hub downloads instead of the Xet path unless told otherwise.
    os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
def load_request_from_stdin() -> dict[str, Any]:
    """Read a single line from standard input and decode it as a JSON object.

    Only one line is consumed, so the caller must send the request as
    single-line JSON.

    Returns:
        The decoded request mapping.

    Raises:
        EOFError: If standard input is already exhausted.
        ValueError: If the line is not valid JSON (``json.JSONDecodeError``)
            or decodes to something other than a JSON object.
    """
    line = input()
    request = json.loads(line)
    if isinstance(request, dict):
        return request
    raise ValueError("Runner stdin payload must be a JSON object.")
def clean_generated_text(text: str) -> str:
    """Strip chat-template and reasoning artifacts from decoded model output.

    Removes, in this order:

    * complete ``<think>...</think>`` blocks;
    * everything up to and including a remaining ``</think>`` whose opening tag
      is not in *text* (the opening tag was emitted as part of the prompt);
    * a trailing ``<think>`` that was never closed (generation stopped while
      still reasoning);
    * the ChatML markers ``<|im_start|>`` and ``<|im_end|>``;
    * lines consisting only of a role name (``system``/``user``/``assistant``).

    Tag and role matching is case-insensitive. Runs of three or more newlines
    are collapsed to one blank line, and the result is stripped.

    Args:
        text: Decoded text; non-``str`` values are converted with ``str()``.

    Returns:
        The cleaned text, which may be empty.
    """
    cleaned = str(text)
    cleaned = re.sub(r"<think>.*?</think>", "", cleaned, flags=re.DOTALL | re.IGNORECASE)
    # Some reasoning chat templates open the block inside the prompt, so the
    # completion carries only the closing tag: keep what follows the last one.
    cleaned = re.split(r"</think>", cleaned, flags=re.IGNORECASE)[-1]
    # An opening tag with no closing tag means the reasoning was cut off;
    # nothing after it is answer text.
    cleaned = re.sub(r"<think>.*\Z", "", cleaned, flags=re.DOTALL | re.IGNORECASE)
    cleaned = re.sub(r"<\|im_start\|>|<\|im_end\|>", "", cleaned)
    cleaned = re.sub(r"(?m)^(system|user|assistant)\s*$", "", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    return cleaned.strip()
| def decode_generated_text(tokenizer: Any, generated_ids: list[int], *, limit: int | None = None) -> str: | |
| ids = list(generated_ids if limit is None else generated_ids[:limit]) | |
| if tokenizer is None or not ids: | |
| return "" | |
| raw_text = str(tokenizer.decode(ids, skip_special_tokens=True)) | |
| cleaned = clean_generated_text(raw_text) | |
| return cleaned or raw_text.strip() | |
def tok_per_sec_from_latency(latency_ms_per_token: float) -> float:
    """Convert a per-token latency in milliseconds into tokens per second.

    Args:
        latency_ms_per_token: Average milliseconds spent per generated token.

    Returns:
        ``1000 / latency_ms_per_token`` as a float. Returns ``0.0`` for zero,
        negative, or NaN latencies (and for infinite latency, where the
        division itself yields ``0.0``).
    """
    # ``not x > 0`` rather than ``x <= 0``: NaN fails every comparison, so this
    # also rejects NaN. Otherwise NaN would propagate and json.dumps would
    # serialize it as the non-standard token ``NaN``.
    if not latency_ms_per_token > 0.0:
        return 0.0
    return float(1000.0 / latency_ms_per_token)
def print_json(payload: dict[str, Any]) -> None:
    """Write *payload* to standard output as one line of key-sorted JSON."""
    serialized = json.dumps(payload, sort_keys=True)
    print(serialized)