nbiish committed on
Commit
c6edde0
·
verified ·
1 Parent(s): 834887b

Upload shared/inference_client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. shared/inference_client.py +209 -0
shared/inference_client.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared HF Inference Client + Cooldown
3
+ ======================================
4
+ Lightweight wrapper around `huggingface_hub.InferenceClient` with:
5
+
6
+ - Per-call cooldown to prevent credit burn on live HF Spaces
7
+ - Async-friendly API
8
+ - Auto-fallback to procedural/story-template engines when inference fails
9
+ - Environment-driven config (works in HF Spaces and local)
10
+
11
+ The cooldown model:
12
+ - Each project has its own cooldown window (default 8s for cheap inference APIs)
13
+ - Within a session, after a successful inference, no new call can run until cooldown expires
14
+ - Failed inference does not start a cooldown (allow quick retry)
15
+ - `cooldown_active()` is the public check; FastAPI handlers short-circuit on active cooldown
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import time
21
+ import logging
22
+ import threading
23
+ from dataclasses import dataclass, field
24
+ from typing import Optional, Dict, Any, Callable, List
25
+
26
log = logging.getLogger("inference")


def _env_float(name: str, default: float) -> float:
    """Read env var ``name`` as a float, falling back to ``default``.

    An unset variable yields ``default``. A value that ``float()`` cannot
    parse (including the empty string) is logged and also yields ``default``,
    so a typo in the Space settings cannot crash the module at import time.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        log.warning("ignoring invalid %s=%r; using default %s", name, raw, default)
        return default


def _env_int(name: str, default: int) -> int:
    """Read env var ``name`` as an int, falling back to ``default``.

    Same fallback rules as ``_env_float``: unset or unparsable values yield
    ``default`` (unparsable ones are logged first).
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        log.warning("ignoring invalid %s=%r; using default %s", name, raw, default)
        return default


# ── Environment knobs ─────────────────────────────────────────────────────────
# Override these in your Space's "Settings → Variables and secrets".

# The HF model id used for text generation (VibeThinker 1.5B, Gemma 4 12B, etc.)
INFERENCE_MODEL = os.environ.get(
    "INFERENCE_MODEL",
    "meta-llama/Llama-3.2-1B-Instruct",  # 1B, free-tier, great prose
)

# Provider: "featherless-ai" (supports small instruct models), "hf-inference" (free serverless), "together", "fal-ai", "replicate"
# Free HF inference works for many small models; otherwise use a paid provider.
INFERENCE_PROVIDER = os.environ.get("INFERENCE_PROVIDER", "featherless-ai")

# Token — read from HF Space secrets at runtime.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Default cooldown between inferences, in seconds.
COOLDOWN_SECONDS = _env_float("INFERENCE_COOLDOWN_SECONDS", 8.0)

# Per-project override (keyed by app name)
PROJECT_COOLDOWN_OVERRIDES = {
    "tinybard": _env_float("TINYBARD_COOLDOWN_SECONDS", 6.0),
    "focusfriend": _env_float("FOCUSFRIEND_COOLDOWN_SECONDS", 10.0),
    "crittercalm": _env_float("CRITTERCALM_COOLDOWN_SECONDS", 12.0),
}

# Max tokens to request (keeps costs bounded)
MAX_NEW_TOKENS = _env_int("INFERENCE_MAX_TOKENS", 220)
56
+
57
+
58
+ # ── Cooldown registry ────────────────────────────────────────────────────────
59
@dataclass
class _CooldownState:
    """Mutable cooldown bookkeeping for a single project."""

    # Timestamp of the most recent recorded call; the 0.0 default stands for
    # "never called".
    last_call: float = 0.0
    # Per-project lock; `_mark_called` holds it while updating `last_call`.
    lock: threading.Lock = field(default_factory=threading.Lock)


# Registry of cooldown state, keyed by project name.
_states: Dict[str, _CooldownState] = {}
# Serialises first-time creation of registry entries so every thread sees the
# same `_CooldownState` object for a given project.
_states_lock = threading.Lock()


def _state(project: str) -> _CooldownState:
    """Return the cooldown state for `project`, creating it on first use.

    Creation happens under `_states_lock`: without it, two threads racing on
    an unseen project could each build their own state object, and a
    `last_call` written to the losing object would be silently dropped.
    """
    state = _states.get(project)
    if state is None:
        with _states_lock:
            # Re-check under the lock: another thread may have created the
            # entry between the lookup above and acquiring the lock.
            state = _states.setdefault(project, _CooldownState())
    return state
72
+
73
+
74
def cooldown_seconds_for(project: str) -> float:
    """Return the cooldown window, in seconds, that applies to `project`.

    Projects listed in `PROJECT_COOLDOWN_OVERRIDES` use their own window;
    every other project uses `COOLDOWN_SECONDS`.
    """
    if project in PROJECT_COOLDOWN_OVERRIDES:
        return PROJECT_COOLDOWN_OVERRIDES[project]
    return COOLDOWN_SECONDS
76
+
77
+
78
def cooldown_active(project: str) -> bool:
    """Tell whether `project` is still inside its cooldown window.

    Returns True when less time has passed since the project's recorded
    `last_call` than its cooldown window allows, i.e. inference must not run.
    """
    entry = _state(project)
    since_last_call = time.time() - entry.last_call
    return since_last_call < cooldown_seconds_for(project)
85
+
86
+
87
def cooldown_remaining(project: str) -> float:
    """Return how many seconds of cooldown `project` still has to wait.

    The result is never negative: once the window has elapsed it is 0.0.
    """
    entry = _state(project)
    since_last_call = time.time() - entry.last_call
    left = cooldown_seconds_for(project) - since_last_call
    # Clamp at zero so callers never see a negative wait.
    return left if left > 0.0 else 0.0
93
+
94
+
95
def cooldown_status(project: str) -> dict:
    """Snapshot of cooldown state for the UI.

    Returns a dict with:
      - "active": True while the project is in cooldown
      - "remaining_seconds": seconds left in the window, rounded to 2 places
      - "window_seconds": the project's full cooldown window

    "active" and "remaining_seconds" are derived from a single reading of the
    remaining time, so the two fields always describe the same instant.
    (Two separate clock reads could report an active cooldown with zero
    seconds left, or the reverse, right at the window boundary.)
    """
    remaining = cooldown_remaining(project)
    return {
        # Equivalent to cooldown_active(): time left > 0 means still cooling.
        "active": remaining > 0.0,
        "remaining_seconds": round(remaining, 2),
        "window_seconds": cooldown_seconds_for(project),
    }
102
+
103
+
104
def _mark_called(project: str) -> None:
    """Record "now" as the time of `project`'s latest call.

    The timestamp is taken and stored while holding the project's lock.
    """
    entry = _state(project)
    guard = entry.lock
    with guard:
        entry.last_call = time.time()
108
+
109
+
110
+ # ── Inference client wrapper ─────────────────────────────────────────────────
111
class InferenceResult:
    """A small wrapper so callers don't need to know which API returned text.

    Attributes:
        text: the generated text.
        model: id of the model that produced it.
        provider: name of the inference provider used.
        latency_s: duration of the inference call, in seconds.
    """

    # Maximum number of characters of `text` shown by `__repr__`.
    _REPR_PREVIEW_CHARS = 50

    def __init__(self, text: str, model: str, provider: str, latency_s: float):
        self.text = text
        self.model = model
        self.provider = provider
        self.latency_s = latency_s

    def __repr__(self) -> str:
        preview = self.text[: self._REPR_PREVIEW_CHARS]
        # Only show the truncation marker when text was actually cut short.
        marker = "…" if len(self.text) > len(preview) else ""
        return f"InferenceResult(text={preview!r}{marker}, model={self.model!r}, latency={self.latency_s:.2f}s)"
121
+
122
+
123
def _get_client():
    """Build an `InferenceClient` from the module-level configuration.

    `huggingface_hub` is imported here rather than at module top so that
    importing this module stays cheap. The provider argument is only passed
    when `INFERENCE_PROVIDER` is non-empty.
    """
    from huggingface_hub import InferenceClient

    if INFERENCE_PROVIDER:
        return InferenceClient(token=HF_TOKEN, provider=INFERENCE_PROVIDER)
    return InferenceClient(token=HF_TOKEN)
130
+
131
+
132
def generate(
    project: str,
    messages: List[Dict[str, str]],
    *,
    max_new_tokens: Optional[int] = None,
    temperature: float = 0.7,
) -> InferenceResult:
    """Run a chat-style inference call, with cooldown enforcement.

    `messages` follows OpenAI chat format: [{"role": "user|assistant|system", "content": "..."}].
    Returns InferenceResult with `.text` (string) on success, or raises on failure.
    Caller is responsible for fallback handling.

    Args:
        project: name used to look up the cooldown window and state.
        messages: the chat turns to send.
        max_new_tokens: completion token cap; `None` or 0 falls back to
            `MAX_NEW_TOKENS`.
        temperature: sampling temperature forwarded to the API.

    Raises:
        RuntimeError: if the project is in cooldown, or if another inference
            for the same project is already in flight.
        Exception: anything raised by the client call propagates unchanged;
            in that case no cooldown is started, so a retry is allowed.
    """
    state = _state(project)
    # The project's lock is held for the whole check -> call -> mark sequence.
    # Otherwise concurrent requests could all pass the cooldown check before
    # any of them records its call, and every one would hit the paid API.
    # Non-blocking acquire: a competing request fails fast instead of queueing.
    if not state.lock.acquire(blocking=False):
        raise RuntimeError(
            f"inference already in flight for {project!r}. "
            f"This protects your HF/Modal credit budget."
        )
    try:
        if cooldown_active(project):
            remaining = cooldown_remaining(project)
            raise RuntimeError(
                f"cooldown active for {project!r}: {remaining:.1f}s remaining. "
                f"This protects your HF/Modal credit budget."
            )

        max_new_tokens = max_new_tokens or MAX_NEW_TOKENS
        client = _get_client()
        # perf_counter is monotonic, so the measured latency cannot be skewed
        # by wall-clock adjustments during the call.
        start = time.perf_counter()
        response = client.chat_completion(
            model=INFERENCE_MODEL,
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
        )
        latency = time.perf_counter() - start
        text = (response.choices[0].message.content or "").strip()
        # Start the cooldown only after a successful call. Written directly
        # rather than via _mark_called(): that helper takes state.lock, which
        # this thread already holds, and threading.Lock is not re-entrant.
        state.last_call = time.time()
    finally:
        state.lock.release()

    return InferenceResult(
        text=text,
        model=INFERENCE_MODEL,
        provider=INFERENCE_PROVIDER,
        latency_s=latency,
    )
171
+
172
+
173
def force_clear_cooldown(project: str) -> None:
    """Reset `project`'s cooldown so the next call is allowed immediately.

    Intended as a manual escape hatch, e.g. for tests or admin overrides.
    """
    entry = _state(project)
    # 0.0 is the "never called" value a fresh state starts with.
    entry.last_call = 0.0
176
+
177
+
178
+ # ── Convenience: build messages + format result ──────────────────────────────
179
def chat_messages(system: str, user: str, history: Optional[List[Dict[str, str]]] = None) -> List[Dict[str, str]]:
    """Assemble an OpenAI-style message list for one request.

    The result is: the system prompt, then any prior turns from `history`
    (same [{role, content}, ...] format, in the order given), then the new
    user turn. `history` itself is not modified; its entries are reused, not
    copied.
    """
    opening = {"role": "system", "content": system}
    closing = {"role": "user", "content": user}
    prior_turns = history if history else []
    return [opening, *prior_turns, closing]
189
+
190
+
191
# Public API of this module (what `from ... import *` exposes).
__all__ = [
    # result type
    "InferenceResult",
    # cooldown inspection and control
    "cooldown_active",
    "cooldown_remaining",
    "cooldown_seconds_for",
    "cooldown_status",
    "force_clear_cooldown",
    # inference entry points
    "generate",
    "chat_messages",
    # configuration constants
    "INFERENCE_MODEL",
    "INFERENCE_PROVIDER",
    "MAX_NEW_TOKENS",
]


if __name__ == "__main__":
    # Smoke test: print the cooldown snapshot of each known project.
    _demo_projects = ("tinybard", "focusfriend", "crittercalm")
    for _name in _demo_projects:
        print(_name, "cooldown:", cooldown_status(_name))