"""
Shared OpenAI-compatible runtime helpers for LLM-backed benchmark features.
"""
from __future__ import annotations
import asyncio
import json
import os
import re
from dataclasses import dataclass
from typing import Any
from openai import OpenAI
DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
@dataclass(frozen=True, slots=True)
class JsonCallResult:
    """Immutable result of a JSON-mode LLM call.

    Token counts come from the provider's usage block and may be None when
    the provider omits usage reporting.
    """

    # Parsed top-level JSON object returned by the model.
    data: dict[str, Any]
    # Prompt-side token count reported by the provider, if available.
    prompt_tokens: int | None
    # Completion-side token count reported by the provider, if available.
    completion_tokens: int | None
def model_name() -> str:
    """Return the chat model identifier, honoring the MODEL_NAME env override."""
    override = os.getenv("MODEL_NAME")
    return override if override is not None else DEFAULT_MODEL_NAME
def resolve_llm_credentials() -> tuple[str | None, str | None, str | None]:
    """Resolve ``(base_url, api_key, auth_mode)`` from the environment.

    Prefers ``API_KEY`` ("proxy" mode) over the legacy ``HF_TOKEN``
    ("legacy" mode); returns ``(None, None, None)`` when neither is set.
    """
    base_url = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
    # Check credential sources in priority order; first non-empty one wins.
    for env_name, mode in (("API_KEY", "proxy"), ("HF_TOKEN", "legacy")):
        token = os.getenv(env_name)
        if token:
            return base_url, token, mode
    return None, None, None
def llm_configured() -> bool:
    """Return True when a usable API key can be resolved from the environment."""
    # resolve_llm_credentials() only returns a key when it is non-empty,
    # so a None check is equivalent to the truthiness test.
    _base_url, api_key, _mode = resolve_llm_credentials()
    return api_key is not None
def _extract_json_object(text: str) -> dict[str, Any]:
payload = text.strip()
try:
return json.loads(payload)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", payload, re.DOTALL)
if not match:
raise
return json.loads(match.group(0))
async def call_json(
    *,
    system_prompt: str,
    user_payload: dict[str, Any] | list[Any] | str,
    temperature: float = 0.0,
    max_output_tokens: int = 400,
) -> JsonCallResult:
    """Run a JSON-mode chat completion and return the parsed object plus usage.

    Raises:
        RuntimeError: when no API credentials are configured.
        ValueError / json.JSONDecodeError: when the reply is not valid JSON.
    """
    base_url, key, _mode = resolve_llm_credentials()
    if not (base_url and key):
        raise RuntimeError("llm_credentials_missing")
    # Non-string payloads are serialized so they travel as a plain user message.
    if isinstance(user_payload, str):
        serialized = user_payload
    else:
        serialized = json.dumps(user_payload, ensure_ascii=True)
    llm = OpenAI(base_url=base_url, api_key=key)

    def _request():
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": serialized},
        ]
        return llm.chat.completions.create(
            model=model_name(),
            messages=messages,
            response_format={"type": "json_object"},
            temperature=temperature,
            max_tokens=max_output_tokens,
        )

    # The OpenAI client is synchronous; push the blocking call off the loop.
    reply = await asyncio.to_thread(_request)
    content = reply.choices[0].message.content or "{}"
    usage = getattr(reply, "usage", None)
    return JsonCallResult(
        data=_extract_json_object(content),
        prompt_tokens=getattr(usage, "prompt_tokens", None),
        completion_tokens=getattr(usage, "completion_tokens", None),
    )
async def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text* for the current chat model.

    Returns 0 for blank input.  Falls back to a characters/4 heuristic when
    no LLM credentials are configured; otherwise asks the model itself and
    clamps the answer to at least 1.
    """
    stripped = text.strip()
    if not stripped:
        return 0
    if not llm_configured():
        # Heuristic fallback: roughly four characters per token.
        return max(1, len(stripped) // 4)
    outcome = await call_json(
        system_prompt=(
            "You are TOKEN_ESTIMATOR. Estimate how many model tokens the provided text would use "
            "for the current chat model. Return JSON with exactly one integer field: "
            '{"token_count": 123}'
        ),
        user_payload={"text": stripped},
        temperature=0.0,
        max_output_tokens=32,
    )
    count = int(outcome.data.get("token_count", 0))
    return max(1, count)
|