File size: 8,636 Bytes
140c4d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780d3c3
140c4d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a5d3d0
 
 
140c4d5
 
 
 
 
 
1a5d3d0
 
5941741
140c4d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a059de
 
 
 
 
 
 
 
 
 
 
 
 
 
1a5d3d0
 
 
 
 
445e1fc
1a5d3d0
 
 
 
 
 
 
 
 
 
 
 
 
167678e
 
1a5d3d0
 
 
 
 
 
 
 
 
 
 
 
140c4d5
 
 
 
 
1a5d3d0
140c4d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
Shared HF Inference Client + Cooldown
======================================
Lightweight client for the Hugging Face serverless text-generation API.
It calls the API directly over HTTP with `httpx`, not through
`huggingface_hub.InferenceClient`. It provides:

- A per-project cooldown to prevent credit burn on live HF Spaces
- A synchronous `generate()` call that raises on failure, so callers can fall
  back to their own procedural/story-template engines
- Environment-driven config (works in HF Spaces and locally)

The cooldown model:
- Each project has its own cooldown window (default 8s, overridable per project)
- Within a process, after a successful inference, no new call for that project
  can run until its cooldown expires
- A failed inference does not start a cooldown, so a quick retry is allowed
- `cooldown_active()` is the public check; FastAPI handlers are expected to
  short-circuit while a cooldown is active
"""
from __future__ import annotations

import os
import time
import logging
import threading
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable, List

log = logging.getLogger("inference")

# ── Environment knobs ─────────────────────────────────────────────────────────
# Each setting below can be overridden from the Space's
# "Settings → Variables and secrets" page, or from the local environment.


def _env_float(name: str, default: str) -> float:
    """Return environment variable *name* parsed as a float (*default* if unset)."""
    return float(os.environ.get(name, default))


# Hugging Face model id used for text generation (VibeThinker 1.5B, Gemma 4
# 12B, ...). The default is a fast 1.5B instruct model suited to the free tier.
INFERENCE_MODEL = os.environ.get("INFERENCE_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")

# Provider name, e.g. "hf-inference" (free serverless), "together", "fal-ai",
# "replicate". Many small models run on the free serverless tier; others need
# a paid provider.
INFERENCE_PROVIDER = os.environ.get("INFERENCE_PROVIDER", "hf-inference")

# Access token, read from the Space secrets at runtime. HUGGINGFACEHUB_API_TOKEN
# is consulted when HF_TOKEN is unset/empty; the result may be None.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Default minimum gap between successful inferences, in seconds.
COOLDOWN_SECONDS = _env_float("INFERENCE_COOLDOWN_SECONDS", "8")

# Per-project cooldown windows (keyed by app name); these take precedence over
# COOLDOWN_SECONDS for the listed projects.
PROJECT_COOLDOWN_OVERRIDES = {
    "tinybard": _env_float("TINYBARD_COOLDOWN_SECONDS", "6"),
    "focusfriend": _env_float("FOCUSFRIEND_COOLDOWN_SECONDS", "10"),
    "crittercalm": _env_float("CRITTERCALM_COOLDOWN_SECONDS", "12"),
}

# Upper bound on generated tokens per request (keeps costs bounded).
MAX_NEW_TOKENS = int(os.environ.get("INFERENCE_MAX_TOKENS", "220"))

# ── Cooldown registry ────────────────────────────────────────────────────────
@dataclass
class _CooldownState:
    last_call: float = 0.0
    lock: threading.Lock = field(default_factory=threading.Lock)
_states: Dict[str, _CooldownState] = {}
def _state(project: str) -> _CooldownState:
    if project not in _states:
        _states[project] = _CooldownState()
    return _states[project]
def cooldown_seconds_for(project: str) -> float:
    """Return the cooldown window, in seconds, that applies to *project*.

    Projects listed in ``PROJECT_COOLDOWN_OVERRIDES`` use their own window;
    every other project falls back to ``COOLDOWN_SECONDS``.
    """
    if project in PROJECT_COOLDOWN_OVERRIDES:
        return PROJECT_COOLDOWN_OVERRIDES[project]
    return COOLDOWN_SECONDS
def cooldown_active(project: str) -> bool:
    """Report whether *project* is still inside its cooldown window.

    True while fewer than ``cooldown_seconds_for(project)`` seconds have
    passed since the project's recorded ``last_call``.
    """
    since_last = time.time() - _state(project).last_call
    return since_last < cooldown_seconds_for(project)
def cooldown_remaining(project: str) -> float:
    """Return how many seconds of cooldown are left for *project*.

    The value is clamped to 0.0 once the window has expired, so it is never
    negative.
    """
    window = cooldown_seconds_for(project)
    since_last = time.time() - _state(project).last_call
    left = window - since_last
    return left if left > 0.0 else 0.0
def cooldown_status(project: str) -> dict:
    """Return a snapshot of *project*'s cooldown for the UI.

    Keys:
        active: True while the project is inside its cooldown window.
        remaining_seconds: seconds left in the window, rounded to 2 places.
        window_seconds: the configured window length for this project.

    ``active`` is derived from the same remaining-time reading, not from a
    second clock read. Both fields therefore describe the same instant
    (``active`` is True exactly when the unrounded remaining time is > 0).
    """
    remaining = cooldown_remaining(project)
    return {
        "active": remaining > 0.0,
        "remaining_seconds": round(remaining, 2),
        "window_seconds": cooldown_seconds_for(project),
    }
def _mark_called(project: str) -> None:
    """Record *project*'s last successful inference as happening right now.

    This (re)starts the window that ``cooldown_active`` checks.
    """
    entry = _state(project)
    with entry.lock:  # serialize writers of ``last_call``
        entry.last_call = time.time()
# ── Inference client wrapper ─────────────────────────────────────────────────
class InferenceResult:
    """Completion text plus call metadata, independent of which API produced it.

    Attributes:
        text: The generated completion.
        model: Model id the request was made against.
        provider: Provider name reported for the call.
        latency_s: Call duration in seconds, as measured by the caller.
    """

    def __init__(self, text: str, model: str, provider: str, latency_s: float):
        self.text = text
        self.model = model
        self.provider = provider
        self.latency_s = latency_s

    def __repr__(self) -> str:
        # Show at most 50 characters of text; append an ellipsis only when
        # something was actually cut off (previously it was always appended).
        preview = repr(self.text[:50])
        if len(self.text) > 50:
            preview += "…"
        return f"InferenceResult(text={preview}, model={self.model!r}, latency={self.latency_s:.2f}s)"
# We use direct HTTP requests via httpx to bypass huggingface_hub library routing bugs
# and force the use of the free serverless Inference API.
import httpx
# Default serverless text-generation endpoint; ``{model}`` is the HF model id.
# The serverless Inference API is served from router.huggingface.co/hf-inference
# (formerly api-inference.huggingface.co); api.huggingface.co is not an
# inference host.
_HF_INFERENCE_URL_TEMPLATE = "https://router.huggingface.co/hf-inference/models/{model}"


def _messages_to_prompt(messages: List[Dict[str, str]]) -> str:
    """Flatten OpenAI-style chat messages into a plain-text dialogue prompt.

    Messages with a role other than system/user/assistant are skipped. The
    prompt ends with an ``Assistant:`` cue for the model to continue from.
    """
    labels = {"system": "System Instructions", "user": "User", "assistant": "Assistant"}
    parts = []
    for msg in messages:
        label = labels.get(msg.get("role", "user"))
        if label is None:
            continue
        # ``or ""`` also tolerates an explicit ``"content": None``.
        content_text = (msg.get("content") or "").strip()
        parts.append(f"{label}:\n{content_text}\n\n")
    parts.append("Assistant:\n")
    return "".join(parts)


def _extract_generated_text(data: Any) -> str:
    """Pull the stripped completion string out of a decoded JSON response.

    Accepts a list of completions (first one wins) or a single dict. An empty
    list yields "" rather than the literal text "[]". A dict carrying an
    ``"error"`` key raises instead of being treated as an empty completion.
    """
    if isinstance(data, list):
        if not data:
            return ""
        first = data[0]
        text = first.get("generated_text", "") if isinstance(first, dict) else str(first)
    elif isinstance(data, dict):
        if "error" in data:
            raise RuntimeError(f"HF Inference API Error: {data['error']}")
        text = data.get("generated_text", "")
    else:
        text = str(data)
    return (text or "").strip()


def generate(
    project: str,
    messages: List[Dict[str, str]],
    *,
    max_new_tokens: Optional[int] = None,
    temperature: float = 0.7,
    token: Optional[str] = None,
    model: Optional[str] = None,
    custom_endpoint: Optional[str] = None,
) -> InferenceResult:
    """Run a chat-style inference call, with cooldown enforcement.

    Args:
        project: Cooldown key (e.g. ``"tinybard"``); selects the window length.
        messages: OpenAI chat format,
            ``[{"role": "user|assistant|system", "content": "..."}]``.
        max_new_tokens: Generation cap; falsy values fall back to ``MAX_NEW_TOKENS``.
        temperature: Sampling temperature sent with the request.
        token: Bearer token overriding ``HF_TOKEN``; no auth header when neither is set.
        model: Model id overriding ``INFERENCE_MODEL``.
        custom_endpoint: Full URL to POST to instead of the default serverless
            endpoint for the model.

    Returns:
        InferenceResult whose ``.text`` is the stripped completion.

    Raises:
        RuntimeError: the project is in cooldown, another call for the same
            project is already in flight, or the API returned a non-200
            status or an error payload.
        httpx.HTTPError: transport failures (timeouts, connection errors)
            propagate unchanged.
        ValueError: the response body is not valid JSON.

    Only a successful call starts the cooldown; failures leave it untouched so
    the caller can retry or fall back immediately. Fallback is the caller's job.
    """
    state = _state(project)
    # Hold the project's lock for the whole call. It is acquired without
    # blocking so an overlapping request fails fast instead of queueing.
    # Without it, two concurrent requests could both pass the cooldown check
    # below and both spend credits.
    if not state.lock.acquire(blocking=False):
        raise RuntimeError(
            f"inference already in progress for {project!r}. "
            f"This protects your HF/Modal credit budget."
        )
    try:
        if cooldown_active(project):
            remaining = cooldown_remaining(project)
            raise RuntimeError(
                f"cooldown active for {project!r}: {remaining:.1f}s remaining. "
                f"This protects your HF/Modal credit budget."
            )

        use_model = model or INFERENCE_MODEL
        use_token = token or HF_TOKEN
        # ``custom_endpoint`` was previously accepted but silently ignored.
        url = custom_endpoint or _HF_INFERENCE_URL_TEMPLATE.format(model=use_model)
        headers = {"Authorization": f"Bearer {use_token}"} if use_token else {}
        payload = {
            "inputs": _messages_to_prompt(messages),
            "parameters": {
                "max_new_tokens": max_new_tokens or MAX_NEW_TOKENS,
                "temperature": temperature,
                "return_full_text": False,
            },
        }

        start = time.perf_counter()
        with httpx.Client(trust_env=True) as http_client:
            resp = http_client.post(url, json=payload, headers=headers, timeout=30.0)
        if resp.status_code != 200:
            raise RuntimeError(f"HF Inference API Error {resp.status_code}: {resp.text}")
        text = _extract_generated_text(resp.json())
        latency = time.perf_counter() - start

        # Success only: start the cooldown. Written directly because
        # ``_mark_called`` would try to re-acquire the lock we already hold.
        state.last_call = time.time()
    finally:
        state.lock.release()

    # NOTE(review): ``provider`` comes from config even though the request
    # always goes to ``url`` above; confirm this is what callers expect.
    return InferenceResult(
        text=text,
        model=use_model,
        provider=INFERENCE_PROVIDER,
        latency_s=latency,
    )
def force_clear_cooldown(project: str) -> None:
    """Reset *project*'s cooldown so the next inference may run immediately.

    A manual escape hatch (tests, admin overrides). It overwrites
    ``last_call`` directly without taking the state's lock.
    """
    entry = _state(project)
    entry.last_call = 0.0
# ── Convenience: build messages + format result ──────────────────────────────
def chat_messages(system: str, user: str, history: Optional[List[Dict[str, str]]] = None) -> List[Dict[str, str]]:
    """Assemble an OpenAI-style chat message list.

    The result is the system prompt, then every prior turn from *history*
    (same ``{"role", "content"}`` shape) in order, then the new user turn.
    *history* itself is not modified.
    """
    return [
        {"role": "system", "content": system},
        *(history or []),
        {"role": "user", "content": user},
    ]
# Names exported by ``from <module> import *``.
__all__ = [
    # result type
    "InferenceResult",
    # cooldown inspection and control
    "cooldown_active",
    "cooldown_remaining",
    "cooldown_seconds_for",
    "cooldown_status",
    "force_clear_cooldown",
    # inference entry point and message helper
    "generate",
    "chat_messages",
    # configuration values
    "INFERENCE_MODEL",
    "INFERENCE_PROVIDER",
    "MAX_NEW_TOKENS",
]
if __name__ == "__main__":
    # Smoke test: print the cooldown snapshot for each built-in project.
    _smoke_projects = ("tinybard", "focusfriend", "crittercalm")
    for _name in _smoke_projects:
        print(_name, "cooldown:", cooldown_status(_name))