# Source: tinybard/shared/inference_client.py (Hugging Face Hub file view)
# Last commit: 5941741 (unverified), by Hermes Bot —
#   "feat: add support for custom OpenAI-compatible endpoints in config UI"
# File size: 8.64 kB
"""
Shared HF Inference Client + Cooldown
======================================
Lightweight inference client that calls the Hugging Face serverless Inference
API directly over HTTP (via `httpx`), with:
- Per-project cooldown to prevent credit burn on live HF Spaces
- A synchronous, blocking API (from async code, run it in a worker thread)
- Errors raised to the caller, so apps can fall back to procedural/story-template
  engines when inference fails
- Environment-driven config (works in HF Spaces and local)
The cooldown model:
- Each project has its own cooldown window (global default 8s; per-project
  overrides in `PROJECT_COOLDOWN_OVERRIDES`)
- Within a process, after a successful inference, no new call can run until the
  cooldown expires (the check and the call are not atomic, so concurrent
  requests can both get through)
- Failed inference does not start a cooldown (allow quick retry)
- `cooldown_active()` is the public check; callers such as FastAPI handlers are
  expected to short-circuit on an active cooldown
"""
from __future__ import annotations
import os
import time
import logging
import threading
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable, List
# Module-level logger on the "inference" channel.
# NOTE(review): not referenced anywhere else in this module; keep it for
# importers or remove it.
log = logging.getLogger("inference")
# ── Environment knobs ─────────────────────────────────────────────────────────
# Every knob below is read once, at import time. Override them in your Space's
# "Settings → Variables and secrets" (or in the local environment).


def _env_float(name: str, default: str) -> float:
    """Parse environment variable *name* (or *default* when unset) as a float."""
    return float(os.environ.get(name, default))


# HF model id used for text generation (VibeThinker 1.5B, Gemma 4 12B, etc.).
# The default is a 1.5B instruct model: fast and free-tier friendly.
INFERENCE_MODEL = os.environ.get("INFERENCE_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")

# Provider label: "hf-inference" (free serverless), "together", "fal-ai",
# "replicate". Free HF inference works for many small models; otherwise use a
# paid provider. NOTE: this value is reported on results; it does not choose
# the request URL.
INFERENCE_PROVIDER = os.environ.get("INFERENCE_PROVIDER", "hf-inference")

# API token from the Space secrets; the alternative HUGGINGFACEHUB_API_TOKEN
# name is used when HF_TOKEN is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Global cooldown between inferences, in seconds.
COOLDOWN_SECONDS = _env_float("INFERENCE_COOLDOWN_SECONDS", "8")

# Per-project cooldown windows (keyed by app name); these take precedence over
# COOLDOWN_SECONDS for the listed projects.
PROJECT_COOLDOWN_OVERRIDES = {
    "tinybard": _env_float("TINYBARD_COOLDOWN_SECONDS", "6"),
    "focusfriend": _env_float("FOCUSFRIEND_COOLDOWN_SECONDS", "10"),
    "crittercalm": _env_float("CRITTERCALM_COOLDOWN_SECONDS", "12"),
}

# Upper bound on generated tokens per request (keeps costs bounded).
MAX_NEW_TOKENS = int(os.environ.get("INFERENCE_MAX_TOKENS", "220"))
# ── Cooldown registry ────────────────────────────────────────────────────────
@dataclass
class _CooldownState:
    """Mutable cooldown bookkeeping for a single project."""

    # time.time() stamp of the last successful inference; 0.0 means "never".
    last_call: float = 0.0
    # Per-project lock; _mark_called holds it while updating last_call.
    lock: threading.Lock = field(default_factory=threading.Lock)


# project name -> its cooldown state; entries are created lazily by _state().
_states: Dict[str, _CooldownState] = {}


def _state(project: str) -> _CooldownState:
    """Return the cooldown state for *project*, creating it on first use.

    Creation goes through ``dict.setdefault``, a single dict operation in
    CPython. Two threads racing on a new project therefore always receive the
    same state object. With the previous check-then-insert, each thread could
    get its own object, silently dropping one thread's ``last_call`` update.
    """
    state = _states.get(project)
    if state is None:
        # Only allocate when missing; setdefault keeps whichever entry won.
        state = _states.setdefault(project, _CooldownState())
    return state
def cooldown_seconds_for(project: str) -> float:
    """Return *project*'s cooldown window in seconds.

    A project listed in ``PROJECT_COOLDOWN_OVERRIDES`` uses its own value;
    any other project gets the global ``COOLDOWN_SECONDS``.
    """
    if project in PROJECT_COOLDOWN_OVERRIDES:
        return PROJECT_COOLDOWN_OVERRIDES[project]
    return COOLDOWN_SECONDS
def cooldown_active(project: str) -> bool:
    """Report whether *project* is still inside its cooldown window.

    True means fewer than ``cooldown_seconds_for(project)`` seconds have
    elapsed since the project's recorded ``last_call``, so inference must not run.
    """
    last_call = _state(project).last_call
    elapsed = time.time() - last_call
    return elapsed < cooldown_seconds_for(project)
def cooldown_remaining(project: str) -> float:
    """Seconds until *project* may run inference again; never negative (0.0 when idle)."""
    window = cooldown_seconds_for(project)
    since_last = time.time() - _state(project).last_call
    return max(0.0, window - since_last)
def cooldown_status(project: str) -> dict:
    """Return a snapshot of *project*'s cooldown state for the UI.

    Keys:
        active: whether inference is currently blocked for the project.
        remaining_seconds: seconds left in the window, rounded to 2 decimals
            (0.0 when not in cooldown).
        window_seconds: the full cooldown window configured for the project.
    """
    # Read the clock once and derive both fields from that single reading.
    # Separate calls to cooldown_active() and cooldown_remaining() each read
    # time.time(), so the window could close between them and the snapshot
    # would report "active" with 0s remaining.
    remaining = cooldown_remaining(project)
    return {
        "active": remaining > 0.0,
        "remaining_seconds": round(remaining, 2),
        "window_seconds": cooldown_seconds_for(project),
    }
def _mark_called(project: str) -> None:
    """Record the current time as *project*'s last call, starting a new cooldown window."""
    entry = _state(project)
    with entry.lock:
        # The timestamp is taken and stored while holding the per-project lock.
        entry.last_call = time.time()
# ── Inference client wrapper ─────────────────────────────────────────────────
class InferenceResult:
    """Text produced by one inference call, plus where and how fast it came back.

    A small wrapper so callers don't need to know which API returned text.

    Attributes:
        text: the generated completion.
        model: model id the request was made with.
        provider: label of the backend reported for the request.
        latency_s: seconds the inference call took.
    """

    # Number of characters of `text` shown by __repr__.
    _REPR_PREVIEW_CHARS = 50

    def __init__(self, text: str, model: str, provider: str, latency_s: float):
        self.text = text
        self.model = model
        self.provider = provider
        self.latency_s = latency_s

    def __repr__(self) -> str:
        preview = self.text[: self._REPR_PREVIEW_CHARS]
        # Only mark the preview as truncated when text was actually cut off.
        ellipsis = "…" if len(self.text) > len(preview) else ""
        return (
            f"{type(self).__name__}(text={preview!r}{ellipsis}, model={self.model!r}, "
            f"provider={self.provider!r}, latency={self.latency_s:.2f}s)"
        )
# We use direct HTTP requests via httpx to bypass huggingface_hub library routing bugs
# and force the use of the free serverless Inference API.
import httpx
# Free serverless HF Inference API ("hf-inference" provider), addressed per model.
_HF_SERVERLESS_URL = "https://router.huggingface.co/hf-inference/models/{model}"
# Per-request HTTP timeout, in seconds.
_HTTP_TIMEOUT_S = 30.0
# Prompt label for each chat role; messages with any other role are skipped.
_ROLE_LABELS = {
    "system": "System Instructions",
    "user": "User",
    "assistant": "Assistant",
}


def _messages_to_prompt(messages: List[Dict[str, str]]) -> str:
    """Flatten OpenAI-style chat messages into a plain-text dialogue prompt.

    Each message becomes a ``Label:`` line, its stripped content, and a blank
    line. A missing role counts as "user". The prompt ends with an open
    ``Assistant:`` turn so the model continues as the assistant.
    """
    parts: List[str] = []
    for msg in messages:
        label = _ROLE_LABELS.get(msg.get("role", "user"))
        if label is None:
            continue
        # `or ""` also covers an explicit None content, which would break .strip().
        content_text = (msg.get("content") or "").strip()
        parts.append(f"{label}:\n{content_text}\n\n")
    parts.append("Assistant:\n")
    return "".join(parts)


def _post_json(url: str, payload: Dict[str, Any], headers: Dict[str, str], error_label: str) -> Any:
    """POST *payload* as JSON to *url* and return the decoded JSON body.

    Raises:
        RuntimeError: if the server answers with any status other than 200;
            the message starts with *error_label*.
    """
    with httpx.Client(trust_env=True) as http_client:
        resp = http_client.post(url, json=payload, headers=headers, timeout=_HTTP_TIMEOUT_S)
        if resp.status_code != 200:
            raise RuntimeError(f"{error_label} {resp.status_code}: {resp.text}")
        return resp.json()


def _hf_generated_text(data: Any) -> str:
    """Pull the completion out of an HF text-generation response.

    The model endpoint normally returns a list of completions; the first one is
    used. A bare dict is also accepted, and any other shape is stringified.
    """
    if isinstance(data, list) and data:
        data = data[0]
    if isinstance(data, dict):
        return str(data.get("generated_text") or "")
    return str(data)


def _openai_chat_url(endpoint: str) -> str:
    """Resolve *endpoint* (base URL such as ``https://host/v1``, or the full
    ``.../chat/completions`` URL) to its chat-completions URL."""
    base = endpoint.rstrip("/")
    if base.endswith("/chat/completions"):
        return base
    return f"{base}/chat/completions"


def _openai_generated_text(data: Any) -> str:
    """Pull the assistant reply out of an OpenAI-style chat-completions response.

    Raises:
        RuntimeError: if the payload lacks ``choices[0].message.content``.
    """
    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as err:
        raise RuntimeError(f"Unexpected response from custom endpoint: {data!r}") from err
    return str(content or "")


def generate(
    project: str,
    messages: List[Dict[str, str]],
    *,
    max_new_tokens: Optional[int] = None,
    temperature: float = 0.7,
    token: Optional[str] = None,
    model: Optional[str] = None,
    custom_endpoint: Optional[str] = None,
) -> InferenceResult:
    """Run a chat-style inference call, with cooldown enforcement.

    Args:
        project: cooldown bucket that is checked first and restarted on success.
        messages: OpenAI chat format,
            ``[{"role": "user|assistant|system", "content": "..."}]``.
        max_new_tokens: generation budget; falsy values fall back to
            ``MAX_NEW_TOKENS``.
        temperature: sampling temperature passed through to the backend.
        token: bearer token. For the HF API it defaults to ``HF_TOKEN``. For a
            custom endpoint only an explicitly passed token is sent.
        model: model id; defaults to ``INFERENCE_MODEL``.
        custom_endpoint: base URL (or full ``/chat/completions`` URL) of an
            OpenAI-compatible server. When set (non-blank), the messages are sent
            there unchanged instead of to the HF serverless API.

    Returns:
        InferenceResult whose ``.text`` is the stripped completion. Its
        ``provider`` is ``INFERENCE_PROVIDER`` for HF calls and ``"custom"`` for
        custom endpoints.

    Raises:
        RuntimeError: if the project is in cooldown, the backend answers with a
            non-200 status, or a custom endpoint returns an unexpected payload.
            Transport errors from httpx propagate unchanged. A failed call does
            not start a cooldown; the caller is responsible for fallback handling.
    """
    if cooldown_active(project):
        remaining = cooldown_remaining(project)
        raise RuntimeError(
            f"cooldown active for {project!r}: {remaining:.1f}s remaining. "
            f"This protects your HF/Modal credit budget."
        )

    max_new_tokens = max_new_tokens or MAX_NEW_TOKENS
    use_model = model or INFERENCE_MODEL
    endpoint = (custom_endpoint or "").strip()

    start = time.time()
    if endpoint:
        # Never forward the Space's HF_TOKEN to a third-party server: only a
        # token supplied explicitly by the caller goes to a custom endpoint.
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        payload: Dict[str, Any] = {
            "model": use_model,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
        }
        data = _post_json(_openai_chat_url(endpoint), payload, headers, "Custom endpoint error")
        text = _openai_generated_text(data)
        provider = "custom"
    else:
        use_token = token or HF_TOKEN
        headers = {"Authorization": f"Bearer {use_token}"} if use_token else {}
        payload = {
            "inputs": _messages_to_prompt(messages),
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "return_full_text": False,
            },
        }
        url = _HF_SERVERLESS_URL.format(model=use_model)
        data = _post_json(url, payload, headers, "HF Inference API Error")
        text = _hf_generated_text(data)
        provider = INFERENCE_PROVIDER
    latency = time.time() - start

    # Only a successful call starts the cooldown window.
    _mark_called(project)
    return InferenceResult(
        text=text.strip(),
        model=use_model,
        provider=provider,
        latency_s=latency,
    )
def force_clear_cooldown(project: str) -> None:
    """Reset *project*'s cooldown so the next inference call is not blocked.

    Manual escape hatch (e.g. for testing or admin overrides).
    """
    state = _state(project)
    # Write under the per-project lock, the same way _mark_called updates it.
    with state.lock:
        state.last_call = 0.0
# ── Convenience: build chat message lists ────────────────────────────────────
def chat_messages(system: str, user: str, history: Optional[List[Dict[str, str]]] = None) -> List[Dict[str, str]]:
    """Assemble a chat message list: system prompt, prior turns, then the new user turn.

    `history` uses the same ``[{"role": ..., "content": ...}, ...]`` format. Its
    entries are inserted in order (the same dict objects). The caller's list
    itself is never modified.
    """
    system_turn = {"role": "system", "content": system}
    user_turn = {"role": "user", "content": user}
    return [system_turn, *(history or ()), user_turn]
# Public API of this module (what `from ... import *` exposes). Underscore
# helpers, HF_TOKEN, COOLDOWN_SECONDS and PROJECT_COOLDOWN_OVERRIDES are not
# exported.
__all__ = [
    "InferenceResult",
    "cooldown_active",
    "cooldown_remaining",
    "cooldown_seconds_for",
    "cooldown_status",
    "force_clear_cooldown",
    "generate",
    "chat_messages",
    "INFERENCE_MODEL",
    "INFERENCE_PROVIDER",
    "MAX_NEW_TOKENS",
]
if __name__ == "__main__":
    # Smoke test: print the cooldown snapshot of every known project.
    for name in ("tinybard", "focusfriend", "crittercalm"):
        print(f"{name} cooldown: {cooldown_status(name)}")