File size: 7,400 Bytes
"""
Shared HF Inference Client + Cooldown
======================================
Lightweight wrapper around `huggingface_hub.InferenceClient` with:
- Per-call cooldown to prevent credit burn on live HF Spaces
- Async-friendly API
- Auto-fallback to procedural/story-template engines when inference fails
- Environment-driven config (works in HF Spaces and local)
The cooldown model:
- Each project has its own cooldown window (default 8s for cheap inference APIs)
- Within a session, after a successful inference, no new call can run until cooldown expires
- Failed inference does not start a cooldown (allow quick retry)
- `cooldown_active()` is the public check; FastAPI handlers short-circuit on active cooldown
"""
from __future__ import annotations
import os
import time
import logging
import threading
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable, List
log = logging.getLogger("inference")

# ── Environment knobs ────────────────────────────────────────────────────────
# Every value below is read from the process environment once, at import time.
# Override them in your Space's "Settings → Variables and secrets".

# HF model id used for text generation (VibeThinker 1.5B, Gemma 4 12B, etc.)
INFERENCE_MODEL = os.getenv(
    "INFERENCE_MODEL",
    "meta-llama/Llama-3.2-1B-Instruct",  # 1B, free-tier, great prose
)

# Provider: "featherless-ai" (supports small instruct models),
# "hf-inference" (free serverless), "together", "fal-ai", "replicate".
# Free HF inference works for many small models; otherwise use a paid provider.
INFERENCE_PROVIDER = os.getenv("INFERENCE_PROVIDER", "featherless-ai")

# Token — read from HF Space secrets at runtime. The first non-empty variable
# wins; the result is None when neither is set.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Default cooldown between inferences, in seconds.
COOLDOWN_SECONDS = float(os.getenv("INFERENCE_COOLDOWN_SECONDS", "8"))

# Per-project override (keyed by app name), each with its own env variable.
PROJECT_COOLDOWN_OVERRIDES = {
    "tinybard": float(os.getenv("TINYBARD_COOLDOWN_SECONDS", "6")),
    "focusfriend": float(os.getenv("FOCUSFRIEND_COOLDOWN_SECONDS", "10")),
    "crittercalm": float(os.getenv("CRITTERCALM_COOLDOWN_SECONDS", "12")),
}

# Max tokens to request (keeps costs bounded).
MAX_NEW_TOKENS = int(os.getenv("INFERENCE_MAX_TOKENS", "220"))
# ── Cooldown registry ────────────────────────────────────────────────────────
@dataclass
class _CooldownState:
    """Cooldown bookkeeping for a single project.

    Stores the ``time.time()`` value recorded for the project's most recent
    inference and a lock that writers of that value can hold.
    """

    # 0.0 means "no call recorded yet".
    last_call: float = field(default=0.0)
    # default_factory gives every instance its own lock object.
    lock: threading.Lock = field(default_factory=threading.Lock)
# Module-level registry mapping a project name to its cooldown state.
_states: Dict[str, _CooldownState] = {}


def _state(project: str) -> _CooldownState:
    """Return the cooldown state for ``project``, creating it on first use.

    Args:
        project: App/project name used as the registry key.

    Returns:
        The single ``_CooldownState`` registered for that project.
    """
    try:
        return _states[project]
    except KeyError:
        # A separate "not in / assign" pair lets two threads that race on a
        # new project each install (and keep using) their own state and lock.
        # setdefault does the insert-or-fetch in one dict call, so both
        # callers get back whichever object was stored first.
        return _states.setdefault(project, _CooldownState())
def cooldown_seconds_for(project: str) -> float:
    """Return the cooldown window, in seconds, that applies to ``project``.

    Projects listed in ``PROJECT_COOLDOWN_OVERRIDES`` use their own value;
    every other project falls back to ``COOLDOWN_SECONDS``.
    """
    if project in PROJECT_COOLDOWN_OVERRIDES:
        return PROJECT_COOLDOWN_OVERRIDES[project]
    return COOLDOWN_SECONDS
def cooldown_active(project: str) -> bool:
    """Tell whether ``project`` is still inside its cooldown window.

    Returns:
        True when less time than the project's window has passed since its
        recorded last call (inference must not run); False otherwise.
    """
    entry = _state(project)
    elapsed = time.time() - entry.last_call
    return elapsed < cooldown_seconds_for(project)
def cooldown_remaining(project: str) -> float:
    """Return how many seconds of cooldown are left for ``project``.

    Returns:
        The unexpired part of the project's window, or 0.0 once the window
        has fully elapsed (never a negative number).
    """
    entry = _state(project)
    since_last = time.time() - entry.last_call
    left = cooldown_seconds_for(project) - since_last
    return left if left > 0.0 else 0.0
def cooldown_status(project: str) -> dict:
    """Snapshot of cooldown state for the UI.

    Args:
        project: App/project name whose cooldown is being reported.

    Returns:
        A dict with three keys:
        ``active`` (bool) - whether any cooldown time is left,
        ``remaining_seconds`` (float) - time left, rounded to 2 decimals,
        ``window_seconds`` (float) - the project's full cooldown window.
    """
    # Take one reading and derive both fields from it. Querying the active
    # flag and the remaining time separately reads the clock twice, so a
    # window expiring between the reads could report active=True together
    # with 0 seconds remaining.
    remaining = cooldown_remaining(project)
    return {
        "active": remaining > 0.0,
        "remaining_seconds": round(remaining, 2),
        "window_seconds": cooldown_seconds_for(project),
    }
def _mark_called(project: str) -> None:
    """Stamp the current time as ``project``'s last call, holding its lock."""
    entry = _state(project)
    entry.lock.acquire()
    try:
        entry.last_call = time.time()
    finally:
        # Always release, even if the assignment above were to raise.
        entry.lock.release()
# ── Inference client wrapper ─────────────────────────────────────────────────
class InferenceResult:
    """A small wrapper so callers don't need to know which API returned text.

    Attributes:
        text: The generated text.
        model: Model id the request was sent to.
        provider: Provider name used for the request.
        latency_s: Duration of the inference call, in seconds.
    """

    # Number of characters of ``text`` shown in ``repr``.
    _PREVIEW_CHARS = 50

    def __init__(self, text: str, model: str, provider: str, latency_s: float):
        self.text = text
        self.model = model
        self.provider = provider
        self.latency_s = latency_s

    def __repr__(self) -> str:
        preview = self.text[: self._PREVIEW_CHARS]
        # Only flag truncation when text was actually cut. The ellipsis is
        # written as an escape so it cannot be mangled by a re-encoding of
        # this source file.
        marker = "\u2026" if len(self.text) > self._PREVIEW_CHARS else ""
        return (
            f"InferenceResult(text={preview!r}{marker}, "
            f"model={self.model!r}, latency={self.latency_s:.2f}s)"
        )
def _get_client():
    """Build a fresh ``InferenceClient`` from the module's env-driven settings.

    ``huggingface_hub`` is imported inside the function so that importing this
    module does not pull it in. The provider is passed only when
    ``INFERENCE_PROVIDER`` is non-empty.
    """
    from huggingface_hub import InferenceClient

    if INFERENCE_PROVIDER:
        return InferenceClient(token=HF_TOKEN, provider=INFERENCE_PROVIDER)
    return InferenceClient(token=HF_TOKEN)
def generate(
    project: str,
    messages: List[Dict[str, str]],
    *,
    max_new_tokens: Optional[int] = None,
    temperature: float = 0.7,
) -> InferenceResult:
    """Send a chat-format request for ``project``, honouring its cooldown.

    Args:
        project: App/project name; selects which cooldown window applies.
        messages: OpenAI chat format,
            ``[{"role": "user|assistant|system", "content": "..."}]``.
        max_new_tokens: Token cap for the reply; a falsy value (None or 0)
            means "use ``MAX_NEW_TOKENS``".
        temperature: Sampling temperature forwarded to the client.

    Returns:
        An ``InferenceResult`` whose ``.text`` is the stripped reply
        (an empty string when the reply content is empty/None).

    Raises:
        RuntimeError: The project is still in cooldown.
        Exception: Anything raised by the client call propagates unchanged;
            the caller is responsible for fallback handling.
    """
    if cooldown_active(project):
        remaining = cooldown_remaining(project)
        raise RuntimeError(
            f"cooldown active for {project!r}: {remaining:.1f}s remaining. "
            f"This protects your HF/Modal credit budget."
        )

    token_budget = max_new_tokens or MAX_NEW_TOKENS
    client = _get_client()

    started = time.time()
    response = client.chat_completion(
        model=INFERENCE_MODEL,
        messages=messages,
        max_tokens=token_budget,
        temperature=temperature,
    )
    latency = time.time() - started

    reply = (response.choices[0].message.content or "").strip()

    # The cooldown starts only here, after the call and response parsing have
    # succeeded, so a failed attempt can be retried immediately.
    _mark_called(project)
    return InferenceResult(
        text=reply,
        model=INFERENCE_MODEL,
        provider=INFERENCE_PROVIDER,
        latency_s=latency,
    )
def force_clear_cooldown(project: str) -> None:
    """Manual escape hatch (e.g. for testing or admin overrides).

    Resets the project's recorded last-call time to 0.0, which ends any
    cooldown currently in effect.

    Args:
        project: App/project name whose cooldown should be cleared.
    """
    state = _state(project)
    # Hold the same lock _mark_called takes when it writes last_call, so the
    # reset and a concurrent "mark called" are serialized.
    with state.lock:
        state.last_call = 0.0
# ── Convenience: build messages + format result ──────────────────────────────
def chat_messages(system: str, user: str, history: Optional[List[Dict[str, str]]] = None) -> List[Dict[str, str]]:
    """Assemble an OpenAI-style message list: system, prior turns, new user turn.

    Args:
        system: Content of the leading system message.
        user: Content of the trailing user message.
        history: Optional earlier turns in the same ``[{role, content}, ...]``
            format; they are placed between the system and user messages.

    Returns:
        A new list. ``history`` itself is not modified, but its dict entries
        are reused (not copied) in the result.
    """
    prior_turns = history if history else []
    return [
        {"role": "system", "content": system},
        *prior_turns,
        {"role": "user", "content": user},
    ]
# Names exported by `from <module> import *`: the result type, the cooldown
# helpers, the inference entry points, and three of the config constants.
__all__ = [
    "InferenceResult",
    "cooldown_active",
    "cooldown_remaining",
    "cooldown_seconds_for",
    "cooldown_status",
    "force_clear_cooldown",
    "generate",
    "chat_messages",
    "INFERENCE_MODEL",
    "INFERENCE_PROVIDER",
    "MAX_NEW_TOKENS",
]
if __name__ == "__main__":
    # Smoke test: print the cooldown snapshot of each project that has an
    # override. No inference call is made.
    for project_name in ("tinybard", "focusfriend", "crittercalm"):
        print(f"{project_name} cooldown: {cooldown_status(project_name)}")
|