Spaces:
Sleeping
Sleeping
| """ | |
| Shared HF Inference Client + Cooldown | |
| ====================================== | |
| Lightweight client for the Hugging Face serverless Inference API (direct HTTP calls via `httpx`) with: | |
| - Per-project cooldown between calls to prevent credit burn on live HF Spaces | |
| - Simple synchronous API (each call blocks on one HTTP request) | |
| - Raises on failure, so callers can fall back to procedural/story-template engines | |
| - Environment-driven config (works in HF Spaces and locally) | |
| The cooldown model: | |
| - Each project has its own cooldown window (default 8s for cheap inference APIs) | |
| - Per project (process-wide), after a successful inference no new call can run until the cooldown expires | |
| - Failed inference does not start a cooldown (allow quick retry) | |
| - `cooldown_active()` is the public check; FastAPI handlers short-circuit on active cooldown | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import time | |
| import logging | |
| import threading | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Dict, Any, Callable, List | |
log = logging.getLogger("inference")

# ── Environment knobs ─────────────────────────────────────────────────────────
# Every value below can be overridden from the Space's
# "Settings → Variables and secrets" (or the local environment).

# Hugging Face model id used for text generation.
# Default: a small 1.5B instruct model that is fast and free-tier friendly.
INFERENCE_MODEL = os.environ.get("INFERENCE_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")

# Provider label, e.g. "hf-inference" (free serverless), "together", "fal-ai",
# "replicate". Free HF inference covers many small models; others may need a
# paid provider.
# NOTE(review): this value is reported in InferenceResult.provider; the request
# URL in generate() does not route on it — confirm that is intended.
INFERENCE_PROVIDER = os.environ.get("INFERENCE_PROVIDER", "hf-inference")

# API token from the Space secrets. HF_TOKEN wins when non-empty; otherwise the
# HUGGINGFACEHUB_API_TOKEN variable is used. None when neither is set.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Default cooldown window, in seconds, applied to projects without an override.
COOLDOWN_SECONDS = float(os.environ.get("INFERENCE_COOLDOWN_SECONDS", "8"))

# Per-project cooldown windows, keyed by app name, each read from its own
# environment variable with a built-in default.
PROJECT_COOLDOWN_OVERRIDES = {
    app: float(os.environ.get(env_name, fallback))
    for app, env_name, fallback in (
        ("tinybard", "TINYBARD_COOLDOWN_SECONDS", "6"),
        ("focusfriend", "FOCUSFRIEND_COOLDOWN_SECONDS", "10"),
        ("crittercalm", "CRITTERCALM_COOLDOWN_SECONDS", "12"),
    )
}

# Upper bound on generated tokens per request (keeps costs bounded).
MAX_NEW_TOKENS = int(os.environ.get("INFERENCE_MAX_TOKENS", "220"))
| # ── Cooldown registry ──────────────────────────────────────────────────────── | |
| class _CooldownState: | |
| last_call: float = 0.0 | |
| lock: threading.Lock = field(default_factory=threading.Lock) | |
# Process-wide registry of per-project cooldown state, filled lazily by _state().
_states: Dict[str, _CooldownState] = {}


def _state(project: str) -> _CooldownState:
    """Return the cooldown state for `project`, creating it on first use.

    The insert goes through `dict.setdefault`, so two threads that race on a
    new project name end up sharing one state object. A plain check-then-assign
    would let the second thread overwrite the first one's entry, dropping its
    timestamp and lock. (setdefault is a single atomic step under CPython's GIL
    for str keys.)
    """
    state = _states.get(project)
    if state is None:
        state = _states.setdefault(project, _CooldownState())
    return state
def cooldown_seconds_for(project: str) -> float:
    """Length of `project`'s cooldown window, in seconds.

    Uses the project's entry in PROJECT_COOLDOWN_OVERRIDES when there is one,
    otherwise the global COOLDOWN_SECONDS default.
    """
    if project in PROJECT_COOLDOWN_OVERRIDES:
        return PROJECT_COOLDOWN_OVERRIDES[project]
    return COOLDOWN_SECONDS
def cooldown_active(project: str) -> bool:
    """Return True while `project` is inside its cooldown window (inference blocked).

    The window is open while fewer than cooldown_seconds_for(project) seconds
    have passed since the last recorded call.
    """
    project_state = _state(project)
    since_last = time.time() - project_state.last_call
    return since_last < cooldown_seconds_for(project)
def cooldown_remaining(project: str) -> float:
    """Seconds until `project`'s cooldown window closes; 0.0 once it has closed."""
    project_state = _state(project)
    since_last = time.time() - project_state.last_call
    left = cooldown_seconds_for(project) - since_last
    # Clamp: a window that has already expired reports zero, never a negative.
    return left if left > 0.0 else 0.0
def cooldown_status(project: str) -> dict:
    """Build a UI-friendly snapshot of `project`'s cooldown.

    Keys:
        active: whether inference is currently blocked.
        remaining_seconds: seconds left in the window, rounded to 2 decimals.
        window_seconds: configured window length for the project.
    """
    status = {}
    status["active"] = cooldown_active(project)
    status["remaining_seconds"] = round(cooldown_remaining(project), 2)
    status["window_seconds"] = cooldown_seconds_for(project)
    return status
def _mark_called(project: str) -> None:
    """Start `project`'s cooldown window now (generate() calls this after a successful response)."""
    target = _state(project)
    # Serialize the write through the per-project lock.
    with target.lock:
        target.last_call = time.time()
| # ── Inference client wrapper ───────────────────────────────────────────────── | |
class InferenceResult:
    """Outcome of one inference call, so callers don't need to know which API produced the text.

    Attributes:
        text: the generated text.
        model: model id that was queried.
        provider: provider label reported for the call.
        latency_s: seconds the call took.
    """

    # Number of characters of `text` shown by __repr__.
    _PREVIEW_CHARS = 50

    def __init__(self, text: str, model: str, provider: str, latency_s: float):
        self.text = text
        self.model = model
        self.provider = provider
        self.latency_s = latency_s

    def __repr__(self) -> str:
        preview = self.text[: self._PREVIEW_CHARS]
        # Mark truncation only when the preview actually cuts the text short.
        ellipsis = "…" if len(self.text) > self._PREVIEW_CHARS else ""
        return (
            f"{type(self).__name__}(text={preview!r}{ellipsis}, model={self.model!r}, "
            f"provider={self.provider!r}, latency={self.latency_s:.2f}s)"
        )
| # We use direct HTTP requests via httpx to bypass huggingface_hub library routing bugs | |
| # and force the use of the free serverless Inference API. | |
| import httpx | |
# Base URL of the serverless HF Inference API ("hf-inference" provider); the
# model id is appended as a path segment, e.g. .../models/Qwen/Qwen2.5-1.5B-Instruct.
_HF_INFERENCE_BASE_URL = "https://router.huggingface.co/hf-inference/models"

# Prompt label for each supported chat role; messages with any other role are skipped.
_ROLE_LABELS = {
    "system": "System Instructions",
    "user": "User",
    "assistant": "Assistant",
}


def _build_prompt(messages: List[Dict[str, str]]) -> str:
    """Flatten OpenAI-style chat messages into a plain-text dialogue prompt."""
    parts: List[str] = []
    for msg in messages:
        label = _ROLE_LABELS.get(msg.get("role", "user"))
        if label is None:
            continue  # unknown roles are dropped
        # Treat missing/None content as empty instead of crashing on .strip().
        content_text = (msg.get("content") or "").strip()
        parts.append(f"{label}:\n{content_text}\n\n")
    # Trailing cue so the model continues speaking as the assistant.
    parts.append("Assistant:\n")
    return "".join(parts)


def _extract_text(data: Any) -> str:
    """Return the generated text from a text-generation JSON response.

    Raises RuntimeError when the payload carries no usable text (e.g. an empty
    list). A malformed response is then reported as a failure, not returned as
    the model's answer.
    """
    if isinstance(data, list) and data and isinstance(data[0], dict):
        text = data[0].get("generated_text", "")
    elif isinstance(data, dict):
        text = data.get("generated_text", "")
    elif isinstance(data, str):
        text = data
    else:
        raise RuntimeError(f"HF Inference API returned an unexpected payload: {data!r:.200}")
    if not isinstance(text, str):
        raise RuntimeError(f"HF Inference API returned non-text generated_text: {text!r:.200}")
    return text


def generate(
    project: str,
    messages: List[Dict[str, str]],
    *,
    max_new_tokens: Optional[int] = None,
    temperature: float = 0.7,
    token: Optional[str] = None,
    model: Optional[str] = None,
    custom_endpoint: Optional[str] = None,
) -> InferenceResult:
    """Run a chat-style inference call, with cooldown enforcement.

    `messages` follows OpenAI chat format:
    [{"role": "user|assistant|system", "content": "..."}]. The messages are
    flattened into a plain-text prompt and POSTed to a text-generation endpoint.

    Args:
        project: cooldown key; the call is refused while this project's
            cooldown window is open.
        messages: chat turns to send; messages with other roles are ignored.
        max_new_tokens: generation cap; None/0 falls back to MAX_NEW_TOKENS.
        temperature: sampling temperature passed through to the API.
        token: bearer token; defaults to HF_TOKEN. No Authorization header is
            sent when neither is set.
        model: model id; defaults to INFERENCE_MODEL. Used to build the URL
            unless `custom_endpoint` is given; it is reported in the result either way.
        custom_endpoint: full URL to POST to instead of the serverless HF
            Inference API. It presumably points at a dedicated endpoint that
            accepts the same payload; confirm for your deployment.

    Returns:
        InferenceResult whose `.text` is the stripped generated text.

    Raises:
        RuntimeError: cooldown active, non-200 response, or a response without
            generated text.
        httpx.HTTPError: transport failures, including the 30 s timeout.
        ValueError: the response body is not valid JSON.

    A cooldown starts only after a successful call. The caller is responsible
    for fallback handling.
    """
    if cooldown_active(project):
        remaining = cooldown_remaining(project)
        raise RuntimeError(
            f"cooldown active for {project!r}: {remaining:.1f}s remaining. "
            f"This protects your HF/Modal credit budget."
        )
    # NOTE(review): the check above and _mark_called() below are not atomic.
    # Concurrent calls for the same project can all pass the check while a
    # request is still in flight.

    max_new_tokens = max_new_tokens or MAX_NEW_TOKENS
    start = time.time()

    prompt = _build_prompt(messages)
    use_model = model or INFERENCE_MODEL
    use_token = token or HF_TOKEN
    # A caller-supplied endpoint is used verbatim; otherwise target the
    # serverless API for the chosen model.
    url = custom_endpoint or f"{_HF_INFERENCE_BASE_URL}/{use_model}"

    headers = {"Authorization": f"Bearer {use_token}"} if use_token else {}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "return_full_text": False,
        },
    }

    # trust_env=True lets proxy/certificate settings from the environment apply.
    with httpx.Client(trust_env=True) as http_client:
        resp = http_client.post(url, json=payload, headers=headers, timeout=30.0)

    if resp.status_code != 200:
        raise RuntimeError(f"HF Inference API Error {resp.status_code}: {resp.text}")

    text = _extract_text(resp.json()).strip()
    latency = time.time() - start

    # Only a successful call opens the cooldown window.
    _mark_called(project)
    return InferenceResult(
        text=text,
        model=use_model,
        provider=INFERENCE_PROVIDER,
        latency_s=latency,
    )
def force_clear_cooldown(project: str) -> None:
    """Reset `project`'s cooldown so the next generate() call is not blocked.

    Manual escape hatch (e.g. for testing or admin overrides). It writes the
    timestamp directly, without taking the state's lock.
    """
    project_state = _state(project)
    project_state.last_call = 0.0
| # ── Convenience: build messages + format result ────────────────────────────── | |
def chat_messages(system: str, user: str, history: Optional[List[Dict[str, str]]] = None) -> List[Dict[str, str]]:
    """Assemble an OpenAI-style message list: system prompt, prior turns, new user turn.

    `history` uses the same [{"role": ..., "content": ...}, ...] format. Its
    entries are placed between the system message and the new user message.
    The caller's list itself is not modified.
    """
    prior_turns = history or []
    return [
        {"role": "system", "content": system},
        *prior_turns,
        {"role": "user", "content": user},
    ]
# Public API of this module (order matters only for `from ... import *`).
__all__ = [
    # Result type
    "InferenceResult",
    # Cooldown helpers
    "cooldown_active",
    "cooldown_remaining",
    "cooldown_seconds_for",
    "cooldown_status",
    "force_clear_cooldown",
    # Inference entry points
    "generate",
    "chat_messages",
    # Configuration values
    "INFERENCE_MODEL",
    "INFERENCE_PROVIDER",
    "MAX_NEW_TOKENS",
]
if __name__ == "__main__":
    # Smoke test: print each known project's cooldown snapshot (makes no network calls).
    for app_name in ("tinybard", "focusfriend", "crittercalm"):
        snapshot = cooldown_status(app_name)
        print(app_name, "cooldown:", snapshot)