johnbridges committed on
Commit
9f87c0c
·
1 Parent(s): d757694
Files changed (1) hide show
  1. app.py +256 -87
app.py CHANGED
@@ -1,90 +1,259 @@
1
- # app.py
2
- import asyncio, logging
3
- import gradio as gr
 
 
4
 
 
 
 
5
  from config import settings
6
- from rabbit_base import RabbitBase
7
- from listener import RabbitListenerBase
8
- from rabbit_repo import RabbitRepo
9
- from oa_server import OpenAIServers
10
- #from vllm_backend import VLLMChatBackend, StubImagesBackend
11
- #from transformers_backend import TransformersChatBackend, StubImagesBackend
12
- #from hf_backend import HFChatBackend, StubImagesBackend
13
- from hf_backend import StubImagesBackend
14
- from timesfm_backend import TimesFMBackend
15
-
16
-
17
- logging.basicConfig(
18
- level=logging.INFO,
19
- format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
20
- )
21
- log = logging.getLogger("app")
22
-
23
- # ----------------- Hugging Face Spaces helpers -----------------
24
  try:
25
- import spaces
26
-
27
- @spaces.GPU(duration=60)
28
- def gpu_entrypoint() -> str:
29
- return "gpu: ready"
30
-
31
- except Exception:
32
- def gpu_entrypoint() -> str:
33
- return "gpu: not available (CPU only)"
34
-
35
- # ----------------- RabbitMQ wiring -----------------
36
- publisher = RabbitRepo(external_source="openai.mq.server")
37
- resolver = (lambda name: "direct" if name.startswith("oa.") else settings.RABBIT_EXCHANGE_TYPE)
38
- base = RabbitBase(exchange_type_resolver=resolver)
39
-
40
- servers = OpenAIServers(
41
- publisher,
42
- chat_backend=TimesFMBackend(),
43
- images_backend=StubImagesBackend()
44
- )
45
-
46
- handlers = {
47
- "oaChatCreate": servers.handle_chat_create,
48
- "oaImagesGenerate": servers.handle_images_generate,
49
- }
50
-
51
- DECLS = [
52
- {"ExchangeName": "oa.chat.create", "FuncName": "oaChatCreate",
53
- "MessageTimeout": 600_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
54
- {"ExchangeName": "oa.images.generate", "FuncName": "oaImagesGenerate",
55
- "MessageTimeout": 600_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
56
- ]
57
-
58
- listener = RabbitListenerBase(base, instance_name=settings.RABBIT_INSTANCE_NAME, handlers=handlers)
59
-
60
- # ----------------- Startup init -----------------
61
- async def _startup_init():
62
- try:
63
- await base.connect() # connect to RabbitMQ
64
- await listener.start(DECLS) # start queue listeners
65
- return "OpenAI MQ + vLLM: ready"
66
- except Exception as e:
67
- log.exception("Startup init failed")
68
- return f"ERROR: {e}"
69
-
70
- async def ping():
71
- return "ok"
72
-
73
- # ----------------- Gradio UI -----------------
74
- with gr.Blocks(title="OpenAI over RabbitMQ (local vLLM)", theme=gr.themes.Soft()) as demo:
75
- gr.Markdown("## OpenAI-compatible over RabbitMQ using vLLM locally inside Space")
76
- with gr.Tabs():
77
- with gr.Tab("Service"):
78
- btn = gr.Button("Ping")
79
- out = gr.Textbox(label="Ping result")
80
- btn.click(ping, inputs=None, outputs=out)
81
- init_status = gr.Textbox(label="Startup status", interactive=False)
82
- demo.load(fn=_startup_init, inputs=None, outputs=init_status)
83
-
84
- with gr.Tab("@spaces.GPU Probe"):
85
- gpu_btn = gr.Button("GPU Ready Probe", variant="primary")
86
- gpu_out = gr.Textbox(label="GPU Probe Result", interactive=False)
87
- gpu_btn.click(gpu_entrypoint, inputs=None, outputs=gpu_out)
88
-
89
- if __name__ == "__main__":
90
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, debug=True, mcp_server=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # timesfm_backend.py
2
+ import time
3
+ import json
4
+ import logging
5
+ from typing import Any, Dict, List, Optional, Tuple
6
 
7
+ import numpy as np
8
+
9
+ from backends_base import ChatBackend, ImagesBackend # ChatBackend for OA server
10
  from config import settings
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Try to import TimesFM. If not present, we fall back to a naive forecaster.
15
+ _TIMESFM_AVAILABLE = False
16
+ _TFM = None
 
 
 
 
 
 
 
 
 
 
 
 
17
  try:
18
+ # google timesfm 2.5 requires `pip install timesfm`
19
+ # model class name can be TimesFm (library-dependent)
20
+ from timesfm import TimesFm # type: ignore
21
+ _TIMESFM_AVAILABLE = True
22
+ except Exception as e:
23
+ logger.warning("timesfm not available (%s) — will use naive fallback.", e)
24
+
25
+
26
def _parse_series(series: Any) -> np.ndarray:
    """
    Normalize a user-supplied time series into a 1D float32 numpy array.

    Accepted shapes:
      - list/tuple of numbers (anything ``float()`` accepts)
      - list/tuple of dicts carrying the observation under 'y' or 'value'
        ('y' wins when both are present)
      - a list/tuple mixing the two forms above
      - dict wrapping any of the above under 'values' or 'y'
        ('values' wins when both are present)

    Raises ValueError when the series is missing, empty, of an unsupported
    container type, or contains an element that cannot be converted to float
    (including a dict element with neither 'y' nor 'value').
    """
    if series is None:
        raise ValueError("series is required")

    # Unwrap {"values": [...]} / {"y": [...]}; any other dict falls through
    # to the container-type error below.
    if isinstance(series, dict):
        if "values" in series:
            series = series["values"]
        elif "y" in series:
            series = series["y"]

    if not isinstance(series, (list, tuple)):
        raise ValueError("series must be a list/tuple or dict with 'values'/'y'")

    vals: List[float] = []
    for idx, item in enumerate(series):
        raw = item
        if isinstance(item, dict):
            # e.g. {"t": "...", "y": 1.2} or {"value": 1.2}
            if "y" in item:
                raw = item["y"]
            elif "value" in item:
                raw = item["value"]
            else:
                # Dropping the point silently would shift every later
                # observation, so reject the input instead.
                raise ValueError(f"series[{idx}] has neither 'y' nor 'value'")
        try:
            vals.append(float(raw))
        except (TypeError, ValueError) as exc:
            # float(None) / float({}) raise TypeError; surface one error type.
            raise ValueError(f"series[{idx}] is not numeric: {raw!r}") from exc

    if not vals:
        raise ValueError("series is empty")
    return np.asarray(vals, dtype=np.float32)
58
+
59
+
60
def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
    """
    Dependency-free flat forecast used when TimesFM cannot be used.

    The mean of the trailing window of ``y`` (the last 4 points, or every
    point when fewer than 4 are available) is repeated ``horizon`` times.
    A non-positive ``horizon`` yields an empty float32 array.
    """
    if horizon <= 0:
        return np.zeros((0,), dtype=np.float32)

    window = min(4, y.shape[0])
    tail = y[-window:]
    level = float(np.mean(tail))
    return np.full((horizon,), level, dtype=np.float32)
71
+
72
+
73
class TimesFMBackend(ChatBackend):
    """
    Chat-compatible backend (for oa_server) wrapping TimesFM (if installed).

    If TimesFM is missing, fails to initialize, or fails to predict, a naive
    statistical fallback (_fallback_forecast) is used instead and the result's
    "note" field says so.
    """

    def __init__(self,
                 model_id: Optional[str] = None,
                 device: Optional[str] = None):
        """
        model_id: optional identifier reported in results and chunks
                  (defaults to "google/timesfm-2.5-200m-pytorch")
        device:   stored on the instance (defaults to "cpu"); it is not
                  currently forwarded to the TimesFm constructor
        """
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or "cpu"
        self._model = None  # lazy init

    # ---------- internal ----------
    def _ensure_model(self):
        """Construct the TimesFM model on first use; leave it None on failure."""
        if self._model is not None or not _TIMESFM_AVAILABLE:
            return
        try:
            # minimal init; adjust kwargs if your installed version needs different args
            # NOTE(review): the installed timesfm release may require constructor
            # arguments (hparams/checkpoint) — confirm against the library docs.
            self._model = TimesFm()  # type: ignore
            logger.info("TimesFM model initialized.")
        except Exception as e:
            logger.exception("Failed to initialize TimesFM; will use fallback. %s", e)
            self._model = None

    def _chunk(self, rid: str, created: int, content: str) -> Dict[str, Any]:
        """Build the single, final OpenAI-style chat chunk carrying ``content``."""
        return {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": self.model_id,
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
        }

    @staticmethod
    def _payload_from_messages(msgs: Any) -> Dict[str, Any]:
        """
        Return the JSON object found in the last user message, or {}.

        Only the most recent message with role "user" is inspected; its
        content is used when it is a string holding a JSON object.
        """
        if not isinstance(msgs, list):
            return {}
        for m in reversed(msgs):
            if not (isinstance(m, dict) and m.get("role") == "user"):
                continue
            c = m.get("content")
            if isinstance(c, str):
                c_str = c.strip()
                # Only JSON objects can contribute keys, so arrays are not parsed.
                if c_str.startswith("{") and c_str.endswith("}"):
                    try:
                        parsed = json.loads(c_str)
                    except ValueError:
                        # non-fatal: plain text that merely looks like JSON
                        logger.debug("Last user message is not valid JSON; ignoring it.")
                    else:
                        if isinstance(parsed, dict):
                            return parsed
            # Stop at the last user message whether or not it parsed.
            break
        return {}

    # ---------- public helpers ----------
    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Unified forecast entrypoint.

        Expected keys (directly in payload OR nested under 'data' OR 'timeseries';
        nested values override top-level ones):
          - series: list of numbers (or list of dicts holding 'y'/'value')
          - horizon: int (>0)
          - freq: optional string for metadata only
        Returns:
          {
            "model": "...",
            "horizon": int,
            "freq": str|None,
            "forecast": [floats],   # exactly `horizon` values
            "note": str|None        # set when the fallback was used
          }
        Raises:
          ValueError if the series is missing/invalid or horizon is not a
          positive integer.
        """
        # unwrap if nested
        if "data" in payload and isinstance(payload["data"], dict):
            payload = {**payload, **payload["data"]}
        if "timeseries" in payload and isinstance(payload["timeseries"], dict):
            payload = {**payload, **payload["timeseries"]}

        series = payload.get("series")
        try:
            horizon = int(payload.get("horizon", 0))
        except (TypeError, ValueError, OverflowError) as exc:
            # None, "abc", [] or inf would otherwise leak a cryptic int() error.
            raise ValueError("horizon must be a positive integer") from exc
        freq = payload.get("freq")

        y = _parse_series(series)
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")

        self._ensure_model()

        if _TIMESFM_AVAILABLE and self._model is not None:
            # Use real TimesFM
            try:
                # Most TimesFM APIs are batch-oriented; we add a batch dim and remove it later
                # If your installed version differs (e.g., .predict with signature),
                # change these two lines accordingly:
                # NOTE(review): this call runs synchronously inside an async method.
                y_batch = y[None, :]
                preds = self._model.predict(y_batch, horizon=horizon)  # type: ignore
                # preds shape => (1, horizon); coercing to a float array rejects
                # ragged/object outputs that json.dumps could not serialize later.
                arr = np.asarray(preds, dtype=np.float64).reshape(-1)
                if arr.shape[0] != horizon:
                    raise ValueError(
                        f"expected {horizon} forecast values, got {arr.shape[0]}"
                    )
                fc = arr.tolist()
                note = None
            except Exception as e:
                logger.exception("TimesFM predict failed; falling back. %s", e)
                fc = _fallback_forecast(y, horizon).tolist()
                note = "fallback_used_due_to_predict_error"
        else:
            # Fallback path (library not importable, or model init failed)
            fc = _fallback_forecast(y, horizon).tolist()
            note = "fallback_used_timesfm_missing"

        return {
            "model": self.model_id,
            "horizon": horizon,
            "freq": freq,
            "forecast": fc,
            "note": note,
        }

    # ---------- ChatBackend interface (for oa_server) ----------
    async def stream(self, request: Dict[str, Any]):
        """
        OA-compatible streaming shim:
          - Extracts forecast inputs from request (or from last user message JSON;
            keys from the message override keys from the request).
          - Runs forecast() and yields ONE OpenAI-style chat chunk whose content
            is a compact JSON string with the forecast result, or
            {"error": "..."} if forecasting raised.
        """
        # Nanosecond resolution so requests arriving in the same second do not
        # share an id; "created" stays in whole seconds.
        rid = f"chatcmpl-timesfm-{time.time_ns()}"
        now = int(time.time())

        # 1) allow direct shape: {series, horizon, ...} / or under 'data'/'timeseries'
        payload: Dict[str, Any] = dict(request) if isinstance(request, dict) else {}

        # 2) optionally parse last user message if it's JSON
        payload.update(self._payload_from_messages(payload.get("messages")))

        # run forecast
        try:
            result = await self.forecast(payload)
        except Exception as e:
            # return an error chunk in OpenAI shape
            content = json.dumps({"error": str(e)}, separators=(",", ":"), ensure_ascii=False)
            yield self._chunk(rid, now, content)
            return

        # success: compact JSON content so your .NET can parse
        content = json.dumps(
            {
                "model": result.get("model"),
                "horizon": result.get("horizon"),
                "freq": result.get("freq"),
                "forecast": result.get("forecast"),
                "note": result.get("note"),
                "backend": "timesfm",
            },
            separators=(",", ":"),
            ensure_ascii=False,
        )

        yield self._chunk(rid, now, content)
251
+
252
+
253
+ # Optional: keep an images stub to satisfy oa_server wiring if needed elsewhere
254
class StubImagesBackend(ImagesBackend):
    """Images backend placeholder: this module does not generate images."""

    # Fixed base64 payload handed back for every request
    # (decodes to a PNG header declaring a 1x1 image).
    _PLACEHOLDER_B64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
    )

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Log a warning and return the fixed placeholder; ``request`` is ignored."""
        logger.warning("Image generation not supported in TimesFM backend.")
        return self._PLACEHOLDER_B64