"""
Shared HF Inference Client + Cooldown
======================================
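Lightweight wrapper around the Hugging Face serverless Inference API (called directly over HTTP with `httpx` rather than through `huggingface_hub.InferenceClient`) with: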
- Per-call cooldown to prevent credit burn on live HF Spaces
- Synchronous `generate()` call (offload it to a worker thread when used from async handlers)
- Raises on inference failure so callers can fall back to procedural/story-template engines
- Environment-driven config (works in HF Spaces and local)
The cooldown model:
- Each project has its own cooldown window (default 8s for cheap inference APIs)
- The window is tracked per project for the whole process: after a successful inference, no new call for that project can run until its cooldown expires
- Failed inference does not start a cooldown (allow quick retry)
- `cooldown_active()` is the public check; FastAPI handlers short-circuit on active cooldown
"""
from __future__ import annotations
import os
import time
import logging
import threading
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable, List
log = logging.getLogger("inference")
# ── Environment knobs ─────────────────────────────────────────────────────────
# Every value below is read once, at import time, from the process environment.
# Override them in your Space's "Settings → Variables and secrets".

# Model id used for text generation (VibeThinker 1.5B, Gemma 4 12B, etc.).
# The fallback is a 1.5B model: fast and free-tier friendly.
INFERENCE_MODEL = os.getenv("INFERENCE_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")

# Provider label: "hf-inference" (free serverless), "together", "fal-ai", "replicate".
# Free HF inference works for many small models; otherwise use a paid provider.
INFERENCE_PROVIDER = os.getenv("INFERENCE_PROVIDER", "hf-inference")

# API token, read from HF Space secrets at runtime. HF_TOKEN wins when both are
# set; the result is None when neither variable is present (or both are empty).
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Default cooldown between inferences, in seconds.
COOLDOWN_SECONDS = float(os.getenv("INFERENCE_COOLDOWN_SECONDS", "8"))

# Per-project cooldown override, keyed by app name. Each app reads its own
# "<APPNAME>_COOLDOWN_SECONDS" variable and falls back to the default listed here.
PROJECT_COOLDOWN_OVERRIDES = {
    app: float(os.getenv(f"{app.upper()}_COOLDOWN_SECONDS", fallback))
    for app, fallback in (
        ("tinybard", "6"),
        ("focusfriend", "10"),
        ("crittercalm", "12"),
    )
}

# Max tokens to request per call (keeps costs bounded).
MAX_NEW_TOKENS = int(os.getenv("INFERENCE_MAX_TOKENS", "220"))
# ── Cooldown registry ────────────────────────────────────────────────────────
@dataclass
class _CooldownState:
    """Mutable cooldown bookkeeping for a single project."""

    # time.time() stamp of the most recent recorded call; 0.0 means "never".
    last_call: float = 0.0
    # Each instance gets its own lock (default_factory), taken when stamping last_call.
    lock: threading.Lock = field(default_factory=threading.Lock)
# Process-wide registry: project name -> its cooldown state, filled lazily.
_states: Dict[str, _CooldownState] = {}


def _state(project: str) -> _CooldownState:
    """Return the cooldown state for `project`, creating it on first use."""
    try:
        return _states[project]
    except KeyError:
        # First time this project is seen: register a fresh state.
        created = _states[project] = _CooldownState()
        return created
def cooldown_seconds_for(project: str) -> float:
    """Return the cooldown window (seconds) that applies to `project`.

    A project listed in PROJECT_COOLDOWN_OVERRIDES uses its own value; every
    other project uses the global COOLDOWN_SECONDS default.
    """
    if project in PROJECT_COOLDOWN_OVERRIDES:
        return PROJECT_COOLDOWN_OVERRIDES[project]
    return COOLDOWN_SECONDS
def cooldown_active(project: str) -> bool:
    """Tell whether `project` is still inside its cooldown window.

    True means less time than the project's window has passed since its last
    recorded call, so inference must not run yet.
    """
    entry = _state(project)
    since_last = time.time() - entry.last_call
    return since_last < cooldown_seconds_for(project)
def cooldown_remaining(project: str) -> float:
    """Return how many seconds of cooldown are left for `project`.

    The result is never negative: once the window has elapsed it is 0.0.
    """
    entry = _state(project)
    since_last = time.time() - entry.last_call
    left = cooldown_seconds_for(project) - since_last
    return left if left > 0.0 else 0.0
def cooldown_status(project: str) -> dict:
    """Snapshot of cooldown state for the UI.

    Returns a dict with:
      - "active": True while the project is still in cooldown
      - "remaining_seconds": seconds left in the window, rounded to 2 decimals
      - "window_seconds": the cooldown window configured for the project

    "active" and "remaining_seconds" are derived from a single clock reading
    so they cannot disagree because time passed between two separate lookups.
    Rounding can still show 0.0 for the last few milliseconds of an active
    window.
    """
    remaining = cooldown_remaining(project)
    return {
        # remaining > 0 is the same condition as "elapsed < window".
        "active": remaining > 0.0,
        "remaining_seconds": round(remaining, 2),
        "window_seconds": cooldown_seconds_for(project),
    }
def _mark_called(project: str) -> None:
    """Stamp the current time as the last call for `project`.

    The write happens while holding that project's lock.
    """
    entry = _state(project)
    stamp_lock = entry.lock
    with stamp_lock:
        entry.last_call = time.time()
# ── Inference client wrapper ─────────────────────────────────────────────────
class InferenceResult:
    """Plain holder for one inference outcome.

    Keeps the generated text together with the model id, the provider label
    and the measured latency, so callers don't need to know which API
    produced the text.
    """

    def __init__(self, text: str, model: str, provider: str, latency_s: float):
        self.text, self.model = text, model
        self.provider, self.latency_s = provider, latency_s

    def __repr__(self) -> str:
        # Only the first 50 characters of the text are shown; the ellipsis is
        # always appended, whether or not the text was actually cut.
        preview = repr(self.text[:50])
        return (
            f"InferenceResult(text={preview}…, "
            f"model={self.model!r}, latency={self.latency_s:.2f}s)"
        )
# We use direct HTTP requests via httpx to bypass huggingface_hub library routing bugs
# and force the use of the free serverless Inference API.
import httpx
def generate(
    project: str,
    messages: List[Dict[str, str]],
    *,
    max_new_tokens: Optional[int] = None,
    temperature: float = 0.7,
    token: Optional[str] = None,
    model: Optional[str] = None,
    custom_endpoint: Optional[str] = None,
) -> InferenceResult:
    """Run a chat-style inference call, with cooldown enforcement.

    `messages` follows OpenAI chat format:
    [{"role": "user|assistant|system", "content": "..."}]. The turns are
    flattened into a plain-text dialogue prompt ending with "Assistant:\n";
    messages with any other role are skipped.

    Args:
        project: App name used to look up and update the cooldown state.
        messages: Conversation so far, oldest first.
        max_new_tokens: Generation cap; falls back to MAX_NEW_TOKENS when
            None or 0.
        temperature: Sampling temperature sent to the API.
        token: Bearer token override; falls back to HF_TOKEN. No
            Authorization header is sent when neither is set.
        model: Model id override; falls back to INFERENCE_MODEL.
        custom_endpoint: Full URL to POST to instead of the default
            per-model URL. It receives the same JSON payload.

    Returns:
        InferenceResult with `.text` (stripped string) on success.

    Raises:
        RuntimeError: If the project is in cooldown, or the API answers with
            a status other than 200. Errors raised by httpx itself (for
            example timeouts) are not caught here.

    A failed call never starts a cooldown: the call is only recorded after a
    response has been parsed. Caller is responsible for fallback handling.
    """
    if cooldown_active(project):
        remaining = cooldown_remaining(project)
        raise RuntimeError(
            f"cooldown active for {project!r}: {remaining:.1f}s remaining. "
            f"This protects your HF/Modal credit budget."
        )

    max_new_tokens = max_new_tokens or MAX_NEW_TOKENS
    start = time.time()

    # Format the messages list into a plain text dialogue prompt.
    role_headers = {
        "system": "System Instructions",
        "user": "User",
        "assistant": "Assistant",
    }
    turns = []
    for msg in messages:
        header = role_headers.get(msg.get("role", "user"))
        if header is None:
            continue  # unknown roles contribute nothing to the prompt
        # `or ""` tolerates an explicit `"content": None` (e.g. tool-call turns).
        content_text = (msg.get("content") or "").strip()
        turns.append(f"{header}:\n{content_text}\n\n")
    prompt = "".join(turns) + "Assistant:\n"

    # Use overrides if provided.
    use_model = model or INFERENCE_MODEL
    use_token = token or HF_TOKEN

    # Call the HF serverless Inference API directly, unless the caller
    # supplied an endpoint of their own.
    # NOTE(review): the serverless API has historically been documented under
    # the host api-inference.huggingface.co — confirm the default host below.
    url = custom_endpoint or f"https://api.huggingface.co/models/{use_model}"
    headers = {}
    if use_token:
        headers["Authorization"] = f"Bearer {use_token}"
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "return_full_text": False,
        },
    }

    with httpx.Client(trust_env=True) as http_client:
        resp = http_client.post(url, json=payload, headers=headers, timeout=30.0)
    if resp.status_code != 200:
        raise RuntimeError(f"HF Inference API Error {resp.status_code}: {resp.text}")
    data = resp.json()

    # The direct model endpoint returns a list of completions; a bare dict is
    # accepted too. Anything else is stringified rather than crashing.
    candidate = data[0] if isinstance(data, list) and len(data) > 0 else data
    if isinstance(candidate, dict):
        text = candidate.get("generated_text") or ""
    else:
        text = str(data)

    latency = time.time() - start
    text = text.strip()

    # Only a completed call starts the cooldown window.
    _mark_called(project)
    return InferenceResult(
        text=text,
        model=use_model,
        provider=INFERENCE_PROVIDER,
        latency_s=latency,
    )
def force_clear_cooldown(project: str) -> None:
    """Manual escape hatch (e.g. for testing or admin overrides).

    Resets the project's last-call stamp to 0.0 so the next cooldown check
    sees the window as elapsed. The write is done under the project's lock,
    the same way the stamp is set after a call.
    """
    state = _state(project)
    with state.lock:
        state.last_call = 0.0
# ── Convenience: build messages + format result ──────────────────────────────
def chat_messages(system: str, user: str, history: Optional[List[Dict[str, str]]] = None) -> List[Dict[str, str]]:
    """Assemble an OpenAI-style message list: system prompt, prior turns, new user turn.

    `history` uses the same [{role, content}, ...] format and is placed,
    untouched, between the system message and the new user message. A new
    list is returned; `history` itself is not modified.
    """
    prior_turns = history if history else []
    return [
        {"role": "system", "content": system},
        *prior_turns,
        {"role": "user", "content": user},
    ]
# Names exported by `from <this module> import *`.
__all__ = [
    # result type
    "InferenceResult",
    # cooldown helpers
    "cooldown_active",
    "cooldown_remaining",
    "cooldown_seconds_for",
    "cooldown_status",
    "force_clear_cooldown",
    # inference entry point and message builder
    "generate",
    "chat_messages",
    # configuration constants
    "INFERENCE_MODEL",
    "INFERENCE_PROVIDER",
    "MAX_NEW_TOKENS",
]
if __name__ == "__main__":
    # Smoke test: print the cooldown snapshot of each known app.
    for app_name in ("tinybard", "focusfriend", "crittercalm"):
        snapshot = cooldown_status(app_name)
        print(app_name, "cooldown:", snapshot)
|