File size: 3,223 Bytes
751ad26
 
 
 
 
 
 
 
 
 
 
 
 
b36db4f
 
 
 
 
e135040
 
 
 
 
 
 
 
 
 
 
 
 
 
b36db4f
 
 
 
 
 
e135040
b36db4f
 
 
 
 
 
 
 
 
 
751ad26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

import json
import os
from pathlib import Path
import re
from typing import Any


# Absolute path of the directory one level above the folder containing this
# file (parents[0] is the containing folder, parents[1] is its parent).
# Presumably the repository root, per the name; it is the base for the
# ".hf-cache" fallback cache directory used in configure_model_cache_env().
REPO_ROOT = Path(__file__).resolve().parents[1]


def configure_model_cache_env() -> None:
    """Point Hugging Face cache environment variables at a writable directory tree.

    The cache root is taken from ``DOTCACHE_MODEL_CACHE_DIR``, then ``HF_HOME``,
    then ``~/.cache/huggingface``. If that directory cannot be created and
    written to, ``REPO_ROOT/.hf-cache`` is tried instead. When neither passes
    the write probe, the preferred root is still used, so the subdirectory
    creation below may raise ``OSError``.

    Every variable is set with ``setdefault``: values already present in the
    environment are left untouched.
    """
    # Rationale from the original author: ZeroGPU runtime downloads can fail
    # through the Xet path if the default cache layout resolves to an
    # unwritable location inside the Space, so use a known writable cache tree
    # and regular Hub downloads by default.
    home_default = Path.home() / ".cache" / "huggingface"
    requested = os.getenv("DOTCACHE_MODEL_CACHE_DIR") or os.getenv("HF_HOME") or home_default
    primary = Path(requested).resolve()
    secondary = (REPO_ROOT / ".hf-cache").resolve()

    def _is_writable(directory: Path) -> bool:
        # Create the directory if needed, then write and remove a probe file.
        try:
            directory.mkdir(parents=True, exist_ok=True)
            probe = directory / ".write_test"
            probe.write_text("ok", encoding="utf-8")
            probe.unlink(missing_ok=True)
        except OSError:
            return False
        return True

    # First candidate that passes the probe wins; otherwise keep the preferred root.
    root = next((d for d in (primary, secondary) if _is_writable(d)), primary)

    # Create every cache subdirectory before touching the environment.
    subdirs = {name: root / name for name in ("hub", "xet", "assets", "transformers", "modules")}
    for directory in subdirs.values():
        directory.mkdir(parents=True, exist_ok=True)

    env_defaults = {
        "HF_HOME": root,
        "HF_HUB_CACHE": subdirs["hub"],
        "HUGGINGFACE_HUB_CACHE": subdirs["hub"],
        "HF_XET_CACHE": subdirs["xet"],
        "HF_ASSETS_CACHE": subdirs["assets"],
        "TRANSFORMERS_CACHE": subdirs["transformers"],
        "HF_MODULES_CACHE": subdirs["modules"],
        "HF_HUB_DISABLE_XET": "1",
    }
    for variable, value in env_defaults.items():
        os.environ.setdefault(variable, str(value))


def load_request_from_stdin() -> dict[str, Any]:
    """Read one line from standard input and parse it as a JSON object.

    Returns:
        The decoded JSON object as a dictionary.

    Raises:
        ValueError: If the decoded JSON value is not an object (dict).
            ``json.loads`` also raises its own ``ValueError`` subclass on
            malformed JSON.
        EOFError: Propagated from ``input()`` when stdin has no data.
    """
    # input() consumes a single line only, so the payload must fit on one line.
    request = json.loads(input())
    if isinstance(request, dict):
        return request
    raise ValueError("Runner stdin payload must be a JSON object.")


def clean_generated_text(text: str) -> str:
    """Strip reasoning blocks and chat-template residue from generated text.

    Applied in order:
      1. remove ``<think>...</think>`` spans (case-insensitive, may span lines);
      2. remove literal ``<|im_start|>`` and ``<|im_end|>`` markers;
      3. blank out lines consisting solely of ``system``, ``user`` or
         ``assistant`` (case-insensitive, trailing whitespace allowed);
      4. collapse runs of three or more newlines into exactly two.

    The input is coerced with ``str()`` and the result is stripped of leading
    and trailing whitespace.
    """
    # (pattern, replacement, flags) -- order matters, later rules see the
    # output of earlier ones.
    substitutions = (
        (r"<think>.*?</think>", "", re.DOTALL | re.IGNORECASE),
        (r"<\|im_start\|>|<\|im_end\|>", "", 0),
        (r"(?m)^(system|user|assistant)\s*$", "", re.IGNORECASE),
        (r"\n{3,}", "\n\n", 0),
    )
    result = str(text)
    for pattern, replacement, flags in substitutions:
        result = re.sub(pattern, replacement, result, flags=flags)
    return result.strip()


def decode_generated_text(tokenizer: Any, generated_ids: list[int], *, limit: int | None = None) -> str:
    ids = list(generated_ids if limit is None else generated_ids[:limit])
    if tokenizer is None or not ids:
        return ""
    raw_text = str(tokenizer.decode(ids, skip_special_tokens=True))
    cleaned = clean_generated_text(raw_text)
    return cleaned or raw_text.strip()


def tok_per_sec_from_latency(latency_ms_per_token: float) -> float:
    """Convert a per-token latency in milliseconds to tokens per second.

    Returns ``0.0`` when the latency is zero or negative; otherwise
    ``1000.0 / latency_ms_per_token`` as a float.
    """
    ms_per_second = 1000.0
    return 0.0 if latency_ms_per_token <= 0.0 else float(ms_per_second / latency_ms_per_token)


def print_json(payload: dict[str, Any]) -> None:
    """Write *payload* to standard output as one line of JSON with sorted keys."""
    serialized = json.dumps(payload, sort_keys=True)
    print(serialized)