senthil2421 committed on
Commit
99e3f1b
·
1 Parent(s): d81f11d

arch: refactor cloud_backend into lean discovery server by removing execution logic

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. adapters/__init__.py +0 -0
  2. adapters/__pycache__/__init__.cpython-310.pyc +0 -0
  3. adapters/__pycache__/base.cpython-310.pyc +0 -0
  4. adapters/__pycache__/hf_adapter.cpython-310.pyc +0 -0
  5. adapters/__pycache__/onnx_adapter.cpython-310.pyc +0 -0
  6. adapters/__pycache__/roboflow_adapter.cpython-310.pyc +0 -0
  7. adapters/base.py +0 -28
  8. adapters/hf_adapter.py +0 -415
  9. adapters/onnx_adapter.py +0 -176
  10. adapters/roboflow_adapter.py +0 -353
  11. api/routes/benchmark.py +0 -238
  12. api/routes/inference.py +0 -168
  13. api/routes/jobs.py +0 -56
  14. api/routes/system.py +0 -97
  15. api/routes/training.py +0 -428
  16. benchmark/__init__.py +0 -1
  17. benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
  18. benchmark/__pycache__/compatibility.cpython-310.pyc +0 -0
  19. benchmark/__pycache__/execution.cpython-310.pyc +0 -0
  20. benchmark/__pycache__/metrics.cpython-310.pyc +0 -0
  21. benchmark/__pycache__/orchestrator.cpython-310.pyc +0 -0
  22. benchmark/__pycache__/registry.cpython-310.pyc +0 -0
  23. benchmark/__pycache__/telemetry.cpython-310.pyc +0 -0
  24. benchmark/adapters/__pycache__/base.cpython-310.pyc +0 -0
  25. benchmark/adapters/__pycache__/registry.cpython-310.pyc +0 -0
  26. benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc +0 -0
  27. benchmark/adapters/base.py +0 -38
  28. benchmark/adapters/optimum_runner.py +0 -53
  29. benchmark/adapters/registry.py +0 -44
  30. benchmark/adapters/torch_runner.py +0 -45
  31. benchmark/compatibility.py +0 -360
  32. benchmark/execution.py +0 -366
  33. benchmark/metrics.py +0 -110
  34. benchmark/orchestrator.py +0 -374
  35. benchmark/registry.py +0 -302
  36. benchmark/telemetry.py +0 -182
  37. benchmark/torch_runner.py +0 -142
  38. config.py +4 -40
  39. download/__init__.py +0 -0
  40. download/__pycache__/__init__.cpython-310.pyc +0 -0
  41. download/__pycache__/manager.cpython-310.pyc +0 -0
  42. download/manager.py +0 -366
  43. inference/__init__.py +0 -1
  44. inference/__pycache__/__init__.cpython-310.pyc +0 -0
  45. inference/__pycache__/engine.cpython-310.pyc +0 -0
  46. inference/__pycache__/session.cpython-310.pyc +0 -0
  47. inference/engine.py +0 -447
  48. inference/session.py +0 -80
  49. main.py +4 -6
  50. projects/__init__.py +0 -0
adapters/__init__.py DELETED
File without changes
adapters/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (143 Bytes)
 
adapters/__pycache__/base.cpython-310.pyc DELETED
Binary file (1.31 kB)
 
adapters/__pycache__/hf_adapter.cpython-310.pyc DELETED
Binary file (13 kB)
 
adapters/__pycache__/onnx_adapter.cpython-310.pyc DELETED
Binary file (5.27 kB)
 
adapters/__pycache__/roboflow_adapter.cpython-310.pyc DELETED
Binary file (10.9 kB)
 
adapters/base.py DELETED
@@ -1,28 +0,0 @@
1
- """
2
- adapters/base.py — Abstract base class every source adapter must implement.
3
- Enforces a stable contract so the registry never knows which adapter runs.
4
- """
5
- from __future__ import annotations
6
-
7
- from abc import ABC, abstractmethod
8
-
9
- from models.model import Model
10
-
11
-
12
- class BaseAdapter(ABC):
13
- """Fetch models from an external source and normalize to the Model schema."""
14
-
15
- source_name: str = "unknown"
16
-
17
- @abstractmethod
18
- async def fetch_models(self) -> list[Model]:
19
- """Return a list of normalized Model objects from the source."""
20
- ...
21
-
22
- def _format_size(self, bytes_: int) -> str:
23
- """Human-readable file size."""
24
- for unit in ("B", "KB", "MB", "GB", "TB"):
25
- if bytes_ < 1024:
26
- return f"{bytes_:.1f} {unit}"
27
- bytes_ //= 1024
28
- return f"{bytes_} PB"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adapters/hf_adapter.py DELETED
@@ -1,415 +0,0 @@
1
- """
2
- adapters/hf_adapter.py — Hugging Face Hub adapter.
3
- Fetches real models via the public HF API and normalises them to our schema.
4
-
5
- Rate-limits respected via polite delays. Requires no authentication for
6
- publicly accessible models; set HF_TOKEN env var for higher rate-limits.
7
- """
8
- from __future__ import annotations
9
-
10
- import asyncio
11
- import re
12
- from typing import Any
13
-
14
-
15
- def _is_shard_file(filename: str) -> bool:
16
- """Return True for sharded weight files like model-00001-of-00003.safetensors."""
17
- return bool(re.search(r"-\d{5}-of-\d{5}\.", filename))
18
-
19
- import httpx
20
- from tenacity import retry, stop_after_attempt, wait_exponential
21
-
22
- from adapters.base import BaseAdapter
23
- from config import settings
24
- from models.model import Model, ModelMetrics, ModelVersion
25
- from observability.logger import get_logger
26
-
27
- log = get_logger("hf_adapter")
28
-
29
- # ── Task mapping: HF pipeline_tag → our internal task ─────────────────────────
30
- HF_TASK_MAP: dict[str, str] = {
31
- "object-detection": "detection",
32
- "image-classification": "classification",
33
- "image-segmentation": "segmentation",
34
- "text-to-image": "generation",
35
- "image-to-image": "generation",
36
- "image-feature-extraction": "embedding",
37
- }
38
-
39
- # Tasks we actively fetch
40
- FETCH_TASKS: list[str] = list(HF_TASK_MAP.keys())
41
-
42
- # ── Framework detection ────────────────────────────────────────────────────────
43
- def _detect_framework(tags: list[str], model_id: str) -> str:
44
- tag_str = " ".join(tags + [model_id]).lower()
45
- if "onnx" in tag_str: return "onnx"
46
- if "tflite" in tag_str: return "tflite"
47
- if "coreml" in tag_str: return "coreml"
48
- if "tensorflow" in tag_str or "tf" in tag_str: return "tensorflow"
49
- return "pytorch" # HF default
50
-
51
- # ── Hardware detection ─────────────────────────────────────────────────────────
52
- def _detect_hardware(tags: list[str]) -> list[str]:
53
- hw: list[str] = []
54
- tag_str = " ".join(tags).lower()
55
- if any(k in tag_str for k in ("cuda", "gpu")): hw.append("gpu")
56
- if "edge" in tag_str or "mobile" in tag_str: hw.append("edge")
57
- if "cpu" in tag_str: hw.append("cpu")
58
- if not hw: hw.append("gpu") # safe default
59
- return hw
60
-
61
- # ── Internal tag normalisation ─────────────────────────────────────────────────
62
- QUALITY_TAG_MAP = {
63
- "state-of-the-art": "sota",
64
- "lightweight": "lightweight",
65
- "tiny": "tiny",
66
- "fast": "fastest",
67
- "real-time": "real-time",
68
- "accuracy": "high-accuracy",
69
- }
70
-
71
- def _normalise_tags(raw_tags: list[str], pipeline: str) -> list[str]:
72
- out: list[str] = []
73
- for t in raw_tags:
74
- t_lower = t.lower()
75
- for keyword, mapped in QUALITY_TAG_MAP.items():
76
- if keyword in t_lower:
77
- out.append(mapped)
78
- # keep relevant library / dataset tags
79
- if any(t_lower.startswith(p) for p in ("dataset:", "license:", "language:")):
80
- continue
81
- out.append(t_lower)
82
- # add pipeline as tag
83
- if pipeline:
84
- out.append(pipeline.replace("-", "_"))
85
- return list(dict.fromkeys(out)) # deduplicate, preserve order
86
-
87
-
88
- class HFAdapter(BaseAdapter):
89
- source_name = "hf"
90
-
91
- def __init__(self) -> None:
92
- headers = {"Accept": "application/json"}
93
- if settings.hf_token:
94
- headers["Authorization"] = f"Bearer {settings.hf_token}"
95
- self._client = httpx.AsyncClient(
96
- base_url=settings.hf_api_base,
97
- headers=headers,
98
- timeout=30,
99
- )
100
-
101
- @retry(
102
- stop=stop_after_attempt(3),
103
- wait=wait_exponential(multiplier=1, min=2, max=10),
104
- reraise=True,
105
- )
106
- async def _fetch_task_page(
107
- self, pipeline_tag: str, limit: int = 100
108
- ) -> list[dict[str, Any]]:
109
- params = {
110
- "pipeline_tag": pipeline_tag,
111
- "sort": "downloads",
112
- "direction": -1, # descending
113
- "limit": limit,
114
- "full": "True",
115
- }
116
- log.info("hf_fetch_task", pipeline_tag=pipeline_tag, limit=limit)
117
- resp = await self._client.get("/models", params=params)
118
- resp.raise_for_status()
119
- return resp.json()
120
-
121
- @retry(
122
- stop=stop_after_attempt(3),
123
- wait=wait_exponential(multiplier=1, min=2, max=10),
124
- reraise=True,
125
- )
126
- async def _fetch_model_detail(self, model_id: str) -> dict[str, Any]:
127
- resp = await self._client.get(f"/models/{model_id}", params={"full": "True"})
128
- resp.raise_for_status()
129
- raw = resp.json()
130
-
131
- siblings: list[dict[str, Any]] = raw.get("siblings") or []
132
- has_any_size = any(isinstance(s, dict) and s.get("size") for s in siblings)
133
- if not has_any_size:
134
- try:
135
- tree = await self._fetch_model_tree(model_id, revision="main")
136
- size_by_path: dict[str, int] = {
137
- (t.get("path") or ""): int(t.get("size") or 0)
138
- for t in (tree or [])
139
- if isinstance(t, dict)
140
- }
141
-
142
- patched: list[dict[str, Any]] = []
143
- for s in siblings:
144
- if not isinstance(s, dict):
145
- continue
146
- fn = s.get("rfilename") or s.get("path") or ""
147
- if fn and not s.get("size") and fn in size_by_path:
148
- s = {**s, "size": size_by_path[fn]}
149
- patched.append(s)
150
- raw["siblings"] = patched
151
- except Exception:
152
- pass
153
-
154
- return raw
155
-
156
- @retry(
157
- stop=stop_after_attempt(3),
158
- wait=wait_exponential(multiplier=1, min=2, max=10),
159
- reraise=True,
160
- )
161
- async def _fetch_model_tree(self, model_id: str, *, revision: str = "main") -> list[dict[str, Any]]:
162
- resp = await self._client.get(f"/models/{model_id}/tree/{revision}")
163
- resp.raise_for_status()
164
- data = resp.json()
165
- if isinstance(data, list):
166
- return data
167
- return []
168
-
169
- def _parse_safe_tensors_size(self, siblings: list[dict]) -> int:
170
- """Estimate model size from sibling file list."""
171
- total = 0
172
- weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")
173
- for s in siblings or []:
174
- filename = s.get("rfilename", "").lower()
175
- if filename.endswith(weight_exts):
176
- total += s.get("size", 0)
177
-
178
- if total > 0:
179
- return total
180
-
181
- # If no size found in siblings, check if it's in the root dict (sometimes HF API does this)
182
- return 0 # Return 0 if not found, we'll handle fallback in _make_model
183
-
184
- @retry(
185
- stop=stop_after_attempt(3),
186
- wait=wait_exponential(multiplier=1, min=2, max=10),
187
- reraise=True,
188
- )
189
- async def _fetch_model_card(self, model_id: str) -> str:
190
- """Fetch model card (README.md) content for real-time description."""
191
- url = f"{settings.hf_hub_url}/{model_id}/raw/main/README.md"
192
- try:
193
- resp = await self._client.get(url)
194
- if resp.status_code == 200:
195
- return resp.text
196
- except Exception:
197
- pass
198
- return ""
199
-
200
- def _extract_description(self, readme: str, raw: dict[str, Any]) -> str:
201
- """Extract a clean description from README or card data."""
202
- if readme:
203
- # Simple heuristic: take first paragraph that isn't frontmatter
204
- lines = readme.split("\n")
205
- in_frontmatter = False
206
- for line in lines:
207
- if line.strip() == "---":
208
- in_frontmatter = not in_frontmatter
209
- continue
210
- if not in_frontmatter and line.strip() and not line.startswith("#"):
211
- return line.strip()[:500]
212
-
213
- card_data = raw.get("cardData") or {}
214
- description: str = (
215
- (card_data.get("summary") or "")
216
- or (card_data.get("description") or "")
217
- or (raw.get("description") or "")
218
- ).strip()
219
- return description
220
-
221
- def _estimate_metrics(self, model_id: str, task: str) -> ModelMetrics:
222
- """
223
- Product-Grade Metrics Estimation.
224
- Uses model name heuristics to provide realistic data for common architectures.
225
- """
226
- metrics = ModelMetrics()
227
- m_id = model_id.lower()
228
-
229
- # Base latency/vram estimates by architecture
230
- if "vit" in m_id or "dinov2" in m_id:
231
- metrics.latency_ms = 45.5 if "base" in m_id else 85.2 if "large" in m_id else 25.0
232
- metrics.vram_gb = 1.2 if "base" in m_id else 2.4 if "large" in m_id else 0.8
233
- metrics.accuracy = 82.4 if "base" in m_id else 84.5
234
- elif "segformer" in m_id:
235
- # b0, b1, b2, b3, b4, b5
236
- if "b0" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 12.0, 0.4, 35.0
237
- elif "b1" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 18.0, 0.6, 40.0
238
- elif "b5" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.0, 1.8, 50.0
239
- else: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 1.0, 42.0
240
- elif "convnext" in m_id:
241
- metrics.latency_ms = 15.0 if "tiny" in m_id else 30.0
242
- metrics.vram_gb = 0.5 if "tiny" in m_id else 1.2
243
- metrics.accuracy = 81.0 if "tiny" in m_id else 83.5
244
- elif "yolo" in m_id:
245
- # n, s, m, l, x
246
- if "yolov8n" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 1.5, 0.2, 37.3
247
- elif "yolov8s" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 2.8, 0.4, 44.9
248
- elif "yolov8m" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 6.2, 0.9, 50.2
249
- else: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 10.0, 1.5, 52.0
250
-
251
- # Generic task-based fallbacks if still empty
252
- if metrics.latency_ms is None:
253
- if task == "classification": metrics.latency_ms, metrics.accuracy = 20.0, 75.0
254
- elif task == "detection": metrics.latency_ms, metrics.mAP = 35.0, 45.0
255
- elif task == "embedding": metrics.latency_ms = 40.0
256
- elif task == "generation": metrics.latency_ms = 1500.0
257
-
258
- return metrics
259
-
260
- def _make_model(self, raw: dict[str, Any], pipeline_tag: str) -> Model | None:
261
- model_id: str = raw.get("id") or raw.get("modelId", "")
262
- if not model_id:
263
- return None
264
-
265
- task = HF_TASK_MAP.get(pipeline_tag)
266
- if not task:
267
- return None
268
- tags_raw: list[str] = raw.get("tags") or []
269
- framework = _detect_framework(tags_raw, model_id)
270
- hardware = _detect_hardware(tags_raw)
271
- tags = _normalise_tags(tags_raw, pipeline_tag)
272
-
273
- # Size
274
- siblings: list[dict] = raw.get("siblings") or []
275
- size = self._parse_safe_tensors_size(siblings)
276
- if size == 0:
277
- # Fallback based on model type if size not found
278
- if "large" in model_id.lower(): size = 1_200_000_000
279
- elif "base" in model_id.lower(): size = 500_000_000
280
- elif "small" in model_id.lower() or "tiny" in model_id.lower(): size = 150_000_000
281
- else: size = 450_000_000 # More realistic general default than exactly 500MB
282
-
283
- # Provider — author part of model_id
284
- provider = model_id.split("/")[0] if "/" in model_id else "community"
285
-
286
- # safe name
287
- name = model_id.split("/")[-1] if "/" in model_id else model_id
288
- # Clean ugly names
289
- name = re.sub(r"[-_]+", "-", name).strip("-")
290
-
291
- downloads = raw.get("downloads") or 0
292
- likes = raw.get("likes") or 0
293
-
294
- # Fabricate a sensible version from last modified
295
- last_mod: str = raw.get("lastModified") or raw.get("createdAt") or ""
296
- release_date = last_mod[:10] if last_mod else "2024-01-01"
297
- sha8 = (raw.get("sha") or "main")[:8]
298
-
299
- # Build versions from weight files in the repo (one per distinct weight file)
300
- weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")
301
- weight_files = [
302
- s for s in siblings
303
- if s.get("rfilename", "").lower().endswith(weight_exts)
304
- and not _is_shard_file(s.get("rfilename", ""))
305
- ]
306
-
307
- if len(weight_files) > 1:
308
- versions = []
309
- for s in weight_files[:15]:
310
- filename = s["rfilename"]
311
- # Detect variant from filename (n, s, m, l, x, or specific labels)
312
- variant_label = "Stable"
313
- fn_lower = filename.lower()
314
- if any(x in fn_lower for x in ["-n.", "_n.", "nano"]): variant_label = "Nano"
315
- elif any(x in fn_lower for x in ["-s.", "_s.", "small"]): variant_label = "Small"
316
- elif any(x in fn_lower for x in ["-m.", "_m.", "medium"]): variant_label = "Medium"
317
- elif any(x in fn_lower for x in ["-l.", "_l.", "large"]): variant_label = "Large"
318
- elif any(x in fn_lower for x in ["-x.", "_x.", "xlarge", "huge"]): variant_label = "XLarge"
319
-
320
- versions.append(ModelVersion(
321
- version=filename.replace(".", "_"),
322
- label=variant_label,
323
- description=f"Model variant: {filename}",
324
- releaseDate=release_date,
325
- changelog=None,
326
- ))
327
- else:
328
- versions = [
329
- ModelVersion(
330
- version=sha8,
331
- label="Latest",
332
- description="Primary model weight file.",
333
- releaseDate=release_date,
334
- changelog=None,
335
- )
336
- ]
337
-
338
- # Description from card data
339
- description = self._extract_description("", raw)
340
- if not description:
341
- description = f"{task.capitalize()} model by {provider}."
342
-
343
- # Metrics Estimation
344
- metrics = self._estimate_metrics(model_id, task)
345
-
346
- return Model(
347
- id = model_id.replace("/", "_").lower(),
348
- name = name,
349
- task = task,
350
- framework = framework,
351
- source = "hf",
352
- provider = provider,
353
- description = description,
354
- download_url = f"https://huggingface.co/{model_id}",
355
- size = size,
356
- size_label = self._format_size(size),
357
- tags = tags,
358
- hardware = hardware,
359
- status = "available",
360
- downloaded = False,
361
- downloads = downloads,
362
- rating = min(5.0, (likes / 200) + 3.5) if likes else None,
363
- liked = False,
364
- metrics = metrics,
365
- versions = versions,
366
- )
367
-
368
- async def fetch_models(self) -> list[Model]:
369
- models: list[Model] = []
370
- seen_ids: set[str] = set()
371
-
372
- for pipeline_tag in FETCH_TASKS:
373
- try:
374
- raw_list = await self._fetch_task_page(
375
- pipeline_tag, limit=settings.hf_models_per_task
376
- )
377
- for idx, raw in enumerate(raw_list):
378
- # Enrich top-N per task with full model detail so siblings include sizes.
379
- if idx < 10:
380
- original_id = raw.get("id") or raw.get("modelId")
381
- if original_id:
382
- try:
383
- raw = await self._fetch_model_detail(original_id)
384
- except Exception:
385
- pass
386
-
387
- m = self._make_model(raw, pipeline_tag)
388
- if m and m.id not in seen_ids:
389
- # Try to fetch real-time description for the first 5 models of each task
390
- if len([mod for mod in models if mod.task == m.task]) < 5:
391
- original_id = raw.get("id") or raw.get("modelId")
392
- if original_id:
393
- readme = await self._fetch_model_card(original_id)
394
- if readme:
395
- m.description = self._extract_description(readme, raw)
396
-
397
- seen_ids.add(m.id)
398
- models.append(m)
399
- # Be polite to HF API
400
- await asyncio.sleep(0.3)
401
- except Exception as exc:
402
- log.warning(
403
- "hf_fetch_task_failed",
404
- pipeline_tag=pipeline_tag,
405
- error=str(exc),
406
- )
407
-
408
- log.info("hf_fetch_complete", total=len(models))
409
- return models
410
-
411
- async def __aenter__(self) -> "HFAdapter":
412
- return self
413
-
414
- async def __aexit__(self, *_: Any) -> None:
415
- await self._client.aclose()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adapters/onnx_adapter.py DELETED
@@ -1,176 +0,0 @@
1
- """
2
- adapters/onnx_adapter.py — ONNX Model Zoo adapter.
3
- Fetches the curated list of ONNX Zoo models from the GitHub API.
4
- """
5
- from __future__ import annotations
6
-
7
- from typing import Any
8
-
9
- import httpx
10
- from tenacity import retry, stop_after_attempt, wait_exponential
11
-
12
- from adapters.base import BaseAdapter
13
- from models.model import Model, ModelMetrics, ModelVersion
14
- from observability.logger import get_logger
15
-
16
- log = get_logger("onnx_adapter")
17
-
18
- # Curated ONNX Zoo models with metadata + download URLs (GitHub API is rate-limited without auth)
19
- ONNX_CURATED: list[dict[str, Any]] = [
20
- {
21
- "id": "onnx_resnet50",
22
- "name": "ResNet-50",
23
- "task": "classification",
24
- "provider": "ONNX Zoo",
25
- "description": "ResNet-50 v1 image classification model in ONNX format.",
26
- "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx",
27
- "size": 102_000_000,
28
- "tags": ["resnet", "imagenet", "classification"],
29
- "hardware": ["gpu", "cpu"],
30
- "metrics": {"latency_ms": 14.2, "top1": 74.9},
31
- "downloads": 250_000,
32
- "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-06-01"}],
33
- },
34
- {
35
- "id": "onnx_yolov8n",
36
- "name": "YOLOv8n",
37
- "task": "detection",
38
- "provider": "Ultralytics",
39
- "description": "Ultralytics YOLOv8 Nano β€” real-time object detection, ONNX export.",
40
- "download_url": "https://github.com/ultralytics/yolov8/releases/download/v8.0.0/yolov8n.onnx",
41
- "size": 6_200_000,
42
- "tags": ["yolo", "real-time", "fastest", "edge"],
43
- "hardware": ["gpu", "cpu", "edge"],
44
- "metrics": {"latency_ms": 3.1, "mAP": 37.3},
45
- "downloads": 420_000,
46
- "versions": [{"version": "8.0", "label": "Latest", "releaseDate": "2023-09-15"}],
47
- },
48
- {
49
- "id": "onnx_mobilenet_v3",
50
- "name": "MobileNetV3-Large",
51
- "task": "classification",
52
- "provider": "Google",
53
- "description": "MobileNetV3-Large for efficient on-device image classification.",
54
- "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv3-large-1.11.onnx",
55
- "size": 22_000_000,
56
- "tags": ["mobilenet", "lightweight", "edge", "efficient"],
57
- "hardware": ["cpu", "edge"],
58
- "metrics": {"latency_ms": 5.8, "top1": 75.2, "fps": 180},
59
- "downloads": 310_000,
60
- "versions": [{"version": "3.0", "label": "Latest", "releaseDate": "2023-01-01"}],
61
- },
62
- {
63
- "id": "onnx_bert_base_uncased",
64
- "name": "BERT-Base-Uncased",
65
- "task": "nlp",
66
- "provider": "Google",
67
- "description": "BERT base model fine-tuned for NLP inference in ONNX format.",
68
- "download_url": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx",
69
- "size": 438_000_000,
70
- "tags": ["bert", "nlp", "transformer"],
71
- "hardware": ["gpu", "cpu"],
72
- "metrics": {"latency_ms": 42.0},
73
- "downloads": 198_000,
74
- "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2022-11-01"}],
75
- },
76
- {
77
- "id": "onnx_efficientnet_b0",
78
- "name": "EfficientNet-B0",
79
- "task": "classification",
80
- "provider": "Google Brain",
81
- "description": "EfficientNet-B0 for scalable image classification.",
82
- "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite/model/efficientnet-lite4-11.onnx",
83
- "size": 20_000_000,
84
- "tags": ["efficientnet", "efficient", "high-accuracy"],
85
- "hardware": ["gpu", "cpu"],
86
- "metrics": {"latency_ms": 10.4, "top1": 77.1},
87
- "downloads": 145_000,
88
- "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-03-01"}],
89
- },
90
- {
91
- "id": "onnx_sam_vit_b",
92
- "name": "SAM ViT-B",
93
- "task": "segmentation",
94
- "provider": "Meta AI",
95
- "description": "Segment Anything Model (ViT-B) for universal image segmentation.",
96
- "download_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
97
- "size": 375_000_000,
98
- "tags": ["sam", "segmentation", "sota"],
99
- "hardware": ["gpu"],
100
- "metrics": {"latency_ms": 68.0},
101
- "downloads": 88_000,
102
- "versions": [{"version": "1.0", "label": "Latest", "releaseDate": "2023-04-05"}],
103
- },
104
- {
105
- "id": "onnx_clip_vit_b32",
106
- "name": "CLIP ViT-B/32",
107
- "task": "embedding",
108
- "provider": "OpenAI",
109
- "description": "CLIP image + text embedding model for zero-shot classification.",
110
- "download_url": "https://openaipublic.blob.core.windows.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba4f386/ViT-B-32.pt",
111
- "size": 338_000_000,
112
- "tags": ["clip", "embedding", "multimodal"],
113
- "hardware": ["gpu", "cpu"],
114
- "metrics": {"latency_ms": 25.0},
115
- "downloads": 275_000,
116
- "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-01-01"}],
117
- },
118
- {
119
- "id": "onnx_whisper_tiny",
120
- "name": "Whisper Tiny",
121
- "task": "nlp",
122
- "provider": "OpenAI",
123
- "description": "Whisper Tiny speech-to-text model in ONNX format.",
124
- "download_url": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424930e36a852c0/tiny.pt",
125
- "size": 39_000_000,
126
- "tags": ["whisper", "speech", "lightweight"],
127
- "hardware": ["cpu", "edge"],
128
- "metrics": {"latency_ms": 100.0},
129
- "downloads": 167_000,
130
- "versions": [{"version": "20231117", "label": "Latest", "releaseDate": "2023-11-17"}],
131
- },
132
- ]
133
-
134
-
135
- class ONNXAdapter(BaseAdapter):
136
- source_name = "onnx"
137
-
138
- async def fetch_models(self) -> list[Model]:
139
- models: list[Model] = []
140
- for raw in ONNX_CURATED:
141
- try:
142
- versions = [
143
- ModelVersion(
144
- version=v["version"],
145
- label=v.get("label", "Stable"),
146
- releaseDate=v.get("releaseDate", ""),
147
- )
148
- for v in raw.get("versions", [])
149
- ]
150
- metrics_raw = raw.get("metrics", {})
151
- m = Model(
152
- id = raw["id"],
153
- name = raw["name"],
154
- task = raw["task"],
155
- framework = "onnx",
156
- source = "onnx",
157
- provider = raw.get("provider", "ONNX Zoo"),
158
- description = raw.get("description", ""),
159
- download_url = raw.get("download_url"),
160
- size = raw.get("size", 0),
161
- size_label = self._format_size(raw.get("size", 0)),
162
- tags = raw.get("tags", []),
163
- hardware = raw.get("hardware", ["gpu"]),
164
- status = "available",
165
- downloaded = False,
166
- downloads = raw.get("downloads"),
167
- rating = 4.2,
168
- metrics = ModelMetrics(**metrics_raw),
169
- versions = versions,
170
- )
171
- models.append(m)
172
- except Exception as exc:
173
- log.warning("onnx_parse_failed", model_id=raw.get("id"), error=str(exc))
174
-
175
- log.info("onnx_fetch_complete", total=len(models))
176
- return models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adapters/roboflow_adapter.py DELETED
@@ -1,353 +0,0 @@
1
- """
2
- adapters/roboflow_adapter.py — Roboflow Universe API client.
3
-
4
- Responsibilities:
5
- - Fetch dataset metadata (search, workspace listings, project details)
6
- - Normalise responses → Dataset domain model
7
- - Cache results in roboflow_cache table (TTL-aware)
8
- - Handle pagination, rate limits, and errors robustly
9
-
10
- Roboflow API reference: https://docs.roboflow.com/api-reference/
11
- """
12
- from __future__ import annotations
13
-
14
- import hashlib
15
- import json
16
- import time
17
- from typing import Any
18
-
19
- import httpx
20
- from tenacity import retry, stop_after_attempt, wait_exponential
21
-
22
- from database.connection import get_db
23
- from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask
24
- from observability.logger import audit, get_logger
25
-
26
- log = get_logger("roboflow_adapter")
27
-
28
- _ROBOFLOW_BASE = "https://api.roboflow.com"
29
- _UNIVERSE_BASE = "https://universe.roboflow.com"
30
- _DEFAULT_TTL = 3600 # 1 hour
31
-
32
- # ── Task mapping from Roboflow annotation_type ───────────────────────────────
33
-
34
- _TASK_MAP: dict[str, DatasetTask] = {
35
- "object-detection": DatasetTask.detection,
36
- "instance-segmentation": DatasetTask.segmentation,
37
- "semantic-segmentation": DatasetTask.segmentation,
38
- "classification": DatasetTask.classification,
39
- "keypoint-detection": DatasetTask.keypoints,
40
- "multiclass-classification": DatasetTask.classification,
41
- }
42
-
43
- _FORMAT_MAP: dict[str, DatasetFormat] = {
44
- "yolov5": DatasetFormat.yolo,
45
- "yolov7": DatasetFormat.yolo,
46
- "yolov8": DatasetFormat.yolo,
47
- "yolov9": DatasetFormat.yolo,
48
- "coco": DatasetFormat.coco,
49
- "voc": DatasetFormat.voc,
50
- "tfrecord": DatasetFormat.tfrecord,
51
- "csv": DatasetFormat.csv,
52
- "createml": DatasetFormat.json,
53
- "multiclass": DatasetFormat.csv,
54
- }
55
-
56
-
57
- def _cache_key(parts: list[str]) -> str:
58
- raw = "|".join(parts)
59
- return hashlib.sha256(raw.encode()).hexdigest()[:32]
60
-
61
-
62
- def _fmt_bytes(n: int) -> str:
63
- for unit in ("B", "KB", "MB", "GB", "TB"):
64
- if n < 1024:
65
- return f"{n:.1f} {unit}"
66
- n /= 1024
67
- return f"{n:.1f} PB"
68
-
69
-
70
- # ── Cache helpers ─────────────────────────────────────────────────────────────
71
-
72
- async def _cache_get(key: str) -> dict[str, Any] | None:
73
- db = await get_db()
74
- async with db.execute(
75
- "SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?",
76
- (key,),
77
- ) as cur:
78
- row = await cur.fetchone()
79
- if row is None:
80
- return None
81
- fetched = time.mktime(time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S"))
82
- if time.time() - fetched > row["ttl_secs"]:
83
- return None # expired
84
- return json.loads(row["payload"])
85
-
86
-
87
- async def _cache_set(key: str, payload: dict[str, Any], ttl: int = _DEFAULT_TTL) -> None:
88
- db = await get_db()
89
- await db.execute(
90
- """INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs)
91
- VALUES (?, ?, ?)""",
92
- (key, json.dumps(payload), ttl),
93
- )
94
- await db.commit()
95
-
96
-
97
- # ── HTTP client factory ───────────────────────────────────────────────────────
98
-
99
- def _make_client(api_key: str) -> httpx.AsyncClient:
100
- return httpx.AsyncClient(
101
- base_url=_ROBOFLOW_BASE,
102
- params={"api_key": api_key},
103
- timeout=30.0,
104
- headers={"User-Agent": "MLForge/1.0"},
105
- )
106
-
107
-
108
- # ── Roboflow Adapter ──────────────────────────────────────────────────────────
109
-
110
- class RoboflowAdapter:
111
- """
112
- Stateless adapter for the Roboflow API.
113
- All methods accept api_key explicitly to support per-user keys.
114
- """
115
-
116
- # ── Search (Universe) ─────────────────────────────────────────────────────
117
-
118
- @staticmethod
119
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
120
- async def search_datasets(
121
- api_key: str,
122
- query: str = "",
123
- workspace: str | None = None,
124
- page: int = 0,
125
- page_size: int = 50,
126
- ) -> list[Dataset]:
127
- """
128
- Search Roboflow Universe for datasets.
129
- Returns normalised Dataset objects.
130
- """
131
- ck = _cache_key(["search", query, str(workspace), str(page), str(page_size)])
132
- cached = await _cache_get(ck)
133
- if cached:
134
- log.debug("roboflow_cache_hit", key=ck, query=query)
135
- return [Dataset(**d) for d in cached]
136
-
137
- params: dict[str, Any] = {
138
- "api_key": api_key,
139
- "q": query or "*",
140
- "from": page * page_size,
141
- "size": page_size,
142
- }
143
- if workspace:
144
- params["workspace"] = workspace
145
-
146
- async with _make_client(api_key) as client:
147
- try:
148
- resp = await client.get("/", params=params)
149
- resp.raise_for_status()
150
- data = resp.json()
151
- except httpx.HTTPStatusError as e:
152
- log.error("roboflow_api_error", status=e.response.status_code, query=query)
153
- await audit("roboflow_error", {"query": query, "status": e.response.status_code}, level="error")
154
- raise
155
-
156
- datasets = []
157
- for item in data.get("results", []):
158
- try:
159
- ds = RoboflowAdapter._normalise_search_result(item)
160
- datasets.append(ds)
161
- except Exception as exc:
162
- log.warning("normalise_error", item_id=item.get("id"), error=str(exc))
163
-
164
- await _cache_set(ck, [d.model_dump() for d in datasets])
165
- await audit("roboflow_search", {"query": query, "count": len(datasets)})
166
- return datasets
167
-
168
- # ── Workspace datasets listing ────────────────────────────────────────────
169
-
170
- @staticmethod
171
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
172
- async def list_workspace_datasets(
173
- api_key: str,
174
- workspace: str,
175
- ) -> list[Dataset]:
176
- """List all datasets in a Roboflow workspace."""
177
- ck = _cache_key(["workspace", workspace])
178
- cached = await _cache_get(ck)
179
- if cached:
180
- return [Dataset(**d) for d in cached]
181
-
182
- async with _make_client(api_key) as client:
183
- try:
184
- resp = await client.get(f"/{workspace}")
185
- resp.raise_for_status()
186
- data = resp.json()
187
- except httpx.HTTPStatusError as e:
188
- log.error("roboflow_workspace_error", workspace=workspace, status=e.response.status_code)
189
- raise
190
-
191
- datasets = []
192
- for proj in data.get("workspace", {}).get("projects", []):
193
- try:
194
- ds = RoboflowAdapter._normalise_project(proj, workspace)
195
- datasets.append(ds)
196
- except Exception as exc:
197
- log.warning("normalise_project_error", project=proj.get("id"), error=str(exc))
198
-
199
- await _cache_set(ck, [d.model_dump() for d in datasets])
200
- return datasets
201
-
202
- # ── Single project detail ─────────────────────────────────────────────────
203
-
204
- @staticmethod
205
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
206
- async def get_project(
207
- api_key: str,
208
- workspace: str,
209
- project_id: str,
210
- ) -> Dataset | None:
211
- """Fetch full metadata for a single Roboflow project."""
212
- ck = _cache_key(["project", workspace, project_id])
213
- cached = await _cache_get(ck)
214
- if cached:
215
- return Dataset(**cached)
216
-
217
- async with _make_client(api_key) as client:
218
- try:
219
- resp = await client.get(f"/{workspace}/{project_id}")
220
- resp.raise_for_status()
221
- data = resp.json()
222
- except httpx.HTTPStatusError as e:
223
- if e.response.status_code == 404:
224
- return None
225
- raise
226
-
227
- proj_data = data.get("project", data)
228
- ds = RoboflowAdapter._normalise_project(proj_data, workspace)
229
- await _cache_set(ck, ds.model_dump())
230
- return ds
231
-
232
- # ── Download URL builder ──────────────────────────────────────────────────
233
-
234
- @staticmethod
235
- async def get_download_url(
236
- api_key: str,
237
- workspace: str,
238
- project_id: str,
239
- version: int,
240
- export_format: str = "yolov8",
241
- ) -> str:
242
- """
243
- Fetch the export download link from Roboflow for the specified format.
244
- Uses the official Roboflow SDK to handle authentication and URL resolution.
245
- """
246
- try:
247
- from roboflow import Roboflow
248
- rf = Roboflow(api_key=api_key)
249
- project = rf.workspace(workspace).project(project_id)
250
- version_obj = project.version(version)
251
-
252
- # The SDK's download method usually downloads to disk,
253
- # but we can get the underlying export info.
254
- # We'll use a thread to run the SDK call since it's blocking.
255
- import asyncio
256
- def _get_link():
257
- return version_obj.export(export_format).download_link
258
-
259
- link = await asyncio.to_thread(_get_link)
260
- if not link:
261
- raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")
262
- return link
263
- except Exception as e:
264
- log.error("roboflow_sdk_error", error=str(e))
265
- # Fallback to manual API if SDK fails or isn't installed correctly
266
- async with _make_client(api_key) as client:
267
- resp = await client.get(
268
- f"/{workspace}/{project_id}/{version}/{export_format}"
269
- )
270
- resp.raise_for_status()
271
- data = resp.json()
272
-
273
- link = export.get("link") or ""
274
- if not link:
275
- # If 'link' is missing, check if it's a Universe-style project and try to resolve manually
276
- # Roboflow manual resolution often follows: universe.roboflow.com/ds/[id]?key=[api_key]
277
- if "project" in data:
278
- pid = data["project"].get("id")
279
- if pid:
280
- link = f"https://universe.roboflow.com/ds/{pid}?key={api_key}"
281
-
282
- if not link:
283
- raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")
284
-
285
- # Ensure the link includes the API key correctly
286
- if "universe.roboflow.com" in link:
287
- if "key=" not in link:
288
- separator = "&" if "?" in link else "?"
289
- link = f"{link}{separator}key={api_key}"
290
- elif f"key={api_key}" not in link:
291
- # Replace old key if it exists but is wrong
292
- import re
293
- link = re.sub(r"key=[^&]+", f"key={api_key}", link)
294
-
295
- return link
296
-
297
- # ── Normalisation helpers ─────────────────────────────────────────────────
298
-
299
- @staticmethod
300
- def _normalise_search_result(item: dict[str, Any]) -> Dataset:
301
- """Map a Universe search result → Dataset."""
302
- ann_type = item.get("annotation", {}).get("type", "object-detection")
303
- rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
304
- class_names = [c.get("name", "") for c in item.get("classes", [])]
305
- images = item.get("images", 0) or 0
306
-
307
- return Dataset(
308
- id = item.get("id", "").replace("/", "__"),
309
- name = item.get("name", "Unnamed"),
310
- description = item.get("description", ""),
311
- task = rf_task,
312
- format = DatasetFormat.yolo,
313
- source = DatasetSource.roboflow,
314
- status = DatasetStatus.available,
315
- images = images,
316
- classes = len(class_names),
317
- class_names = class_names,
318
- size_bytes = 0,
319
- size_label = "—",
320
- tags = item.get("tags", []),
321
- roboflow_id = item.get("id", ""),
322
- created_at = item.get("created", ""),
323
- updated_at = item.get("updated", ""),
324
- )
325
-
326
- @staticmethod
327
- def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset:
328
- """Map a workspace project → Dataset."""
329
- ann_type = proj.get("annotation", "object-detection")
330
- rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
331
- class_names = [c.get("name", c) if isinstance(c, dict) else c
332
- for c in proj.get("classes", [])]
333
- project_id = proj.get("id", proj.get("name", "unknown"))
334
- rf_id = f"{workspace}/{project_id}"
335
- images = proj.get("images", 0) or 0
336
-
337
- return Dataset(
338
- id = rf_id.replace("/", "__"),
339
- name = proj.get("name", project_id),
340
- description = proj.get("description", ""),
341
- task = rf_task,
342
- format = DatasetFormat.yolo,
343
- source = DatasetSource.roboflow,
344
- status = DatasetStatus.available,
345
- images = images,
346
- classes = len(class_names),
347
- class_names = class_names,
348
- size_bytes = 0,
349
- size_label = "—",
350
- roboflow_id = rf_id,
351
- created_at = proj.get("created", ""),
352
- updated_at = proj.get("updated", ""),
353
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/routes/benchmark.py DELETED
@@ -1,238 +0,0 @@
1
- """
2
- api/routes/benchmark.py — Benchmark Bridge REST + WebSocket API.
3
-
4
- Routes:
5
- POST /benchmark/validate — compatibility check (no job created)
6
- POST /benchmark/run — validate + create + enqueue (202)
7
- GET /benchmark/jobs — list jobs (filterable)
8
- GET /benchmark/results/all — list all results
9
- GET /benchmark/{job_id} — single job status + logs
10
- GET /benchmark/{job_id}/result — metrics + telemetry for completed job
11
- WS /benchmark/live/{job_id} — real-time progress stream
12
- """
13
- from __future__ import annotations
14
-
15
- import asyncio
16
- from typing import Any
17
-
18
- from fastapi import APIRouter, HTTPException, Query, WebSocket, WebSocketDisconnect
19
-
20
- import benchmark.orchestrator as orchestrator
21
- import benchmark.registry as bench_reg
22
- from models.benchmark import (
23
- BenchmarkContext,
24
- BenchmarkJob,
25
- BenchmarkResult,
26
- BenchmarkRunResponse,
27
- ValidationReport,
28
- )
29
- from observability.logger import get_logger
30
-
31
- log = get_logger("api.benchmark")
32
-
33
- router = APIRouter(prefix="/benchmark", tags=["benchmark"])
34
-
35
-
36
- # ── POST /benchmark/validate ──────────────────────────────────────────────────
37
-
38
- @router.post(
39
- "/validate",
40
- response_model = ValidationReport,
41
- summary = "Validate model ↔ dataset compatibility",
42
- description = (
43
- "Runs all 5 compatibility gates (task, format, framework×hardware, "
44
- "VRAM, precision) and returns a structured report. "
45
- "Does NOT create a benchmark job."
46
- ),
47
- )
48
- async def validate_benchmark(ctx: BenchmarkContext) -> ValidationReport:
49
- try:
50
- return await orchestrator.validate_context(ctx)
51
- except HTTPException:
52
- raise
53
- except Exception as exc:
54
- log.exception("validate_error")
55
- raise HTTPException(status_code=500, detail=str(exc)) from exc
56
-
57
-
58
- # ── POST /benchmark/run ───────────────────────────────────────────────────────
59
-
60
- @router.post(
61
- "/run",
62
- response_model = BenchmarkRunResponse,
63
- status_code = 202,
64
- summary = "Start a benchmark run",
65
- description = (
66
- "Validates compatibility, creates a benchmark job, and starts async "
67
- "execution. Returns job_id immediately — poll GET /benchmark/{job_id} "
68
- "or connect to WS /benchmark/live/{job_id} for progress."
69
- ),
70
- )
71
- async def run_benchmark(ctx: BenchmarkContext) -> BenchmarkRunResponse:
72
- try:
73
- job = await orchestrator.create_and_run(ctx)
74
- return BenchmarkRunResponse(
75
- job_id = job.id,
76
- status = job.status,
77
- message = f"Benchmark job {job.id} queued — connect to /benchmark/live/{job.id} for live telemetry",
78
- )
79
- except HTTPException:
80
- raise
81
- except Exception as exc:
82
- log.exception("run_benchmark_error")
83
- raise HTTPException(status_code=500, detail=str(exc)) from exc
84
-
85
-
86
- # ── POST /benchmark/sync ──────────────────────────────────────────────────────────
87
-
88
- @router.post(
89
- "/sync",
90
- summary = "Sync benchmarks from active project folder",
91
- description = "Scans the active project's 'benchmarks' folder and ensures all JSON records are indexed in SQLite.",
92
- )
93
- async def sync_benchmarks() -> dict[str, Any]:
94
- try:
95
- count = await orchestrator.sync_project_benchmarks()
96
- return {"status": "success", "count": count}
97
- except Exception as exc:
98
- log.exception("sync_error")
99
- raise HTTPException(status_code=500, detail=str(exc)) from exc
100
-
101
-
102
- # ── GET /benchmark/jobs ───────────────────────────────────────────────────────
103
-
104
- @router.get(
105
- "/jobs",
106
- response_model = list[BenchmarkJob],
107
- summary = "List benchmark jobs",
108
- )
109
- async def list_jobs(
110
- status: str | None = Query(None, description="Filter by status (queued|running|completed|failed)"),
111
- model_id: str | None = Query(None, description="Filter by model_id"),
112
- limit: int = Query(100, ge=1, le=500),
113
- ) -> list[BenchmarkJob]:
114
- return await bench_reg.list_jobs(status=status, model_id=model_id, limit=limit)
115
-
116
-
117
- # ── GET /benchmark/results/all ────────────────────────────────────────────────
118
- # Must be declared BEFORE /{job_id} to avoid "results" being treated as a job_id
119
-
120
- @router.get(
121
- "/results/all",
122
- response_model = list[BenchmarkResult],
123
- summary = "List all benchmark results (leaderboard feed)",
124
- )
125
- async def list_results(
126
- limit: int = Query(100, ge=1, le=500),
127
- ) -> list[BenchmarkResult]:
128
- return await bench_reg.list_results(limit=limit)
129
-
130
-
131
- # ── GET /benchmark/{job_id} ───────────────────────────────────────────────────
132
-
133
- @router.get(
134
- "/{job_id}",
135
- response_model = BenchmarkJob,
136
- summary = "Get benchmark job status + logs",
137
- )
138
- async def get_job(job_id: str) -> BenchmarkJob:
139
- job = await bench_reg.get_job(job_id)
140
- if not job:
141
- raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
142
- return job
143
-
144
-
145
- # ── GET /benchmark/{job_id}/result ───────────────────────────────────────────
146
-
147
- @router.get(
148
- "/{job_id}/result",
149
- response_model = BenchmarkResult,
150
- summary = "Get final metrics + telemetry for a completed job",
151
- )
152
- async def get_result(job_id: str) -> BenchmarkResult:
153
- result = await bench_reg.get_result(job_id)
154
- if not result:
155
- raise HTTPException(
156
- status_code = 404,
157
- detail = f"No result for job '{job_id}' — job may still be running",
158
- )
159
- return result
160
-
161
-
162
- # ── WS /benchmark/live/{job_id} ───────────────────────────────────────────────
163
-
164
- @router.websocket("/live/{job_id}")
165
- async def live_telemetry(websocket: WebSocket, job_id: str) -> None:
166
- """
167
- WebSocket stream for real-time benchmark progress.
168
- Streams incremental logs and high-frequency telemetry.
169
- """
170
- await websocket.accept()
171
- log.info("ws_connected", job_id=job_id)
172
-
173
- last_log_idx = 0
174
-
175
- try:
176
- while True:
177
- job = await bench_reg.get_job(job_id)
178
-
179
- if not job:
180
- await websocket.send_json(
181
- {"error": f"Job '{job_id}' not found", "job_id": job_id}
182
- )
183
- break
184
-
185
- # Only send new logs
186
- new_logs = job.logs[last_log_idx:]
187
- last_log_idx = len(job.logs)
188
-
189
- payload: dict[str, Any] = {
190
- "job_id": job.id,
191
- "status": job.status,
192
- "progress": round(job.progress, 4),
193
- "logs": new_logs,
194
- "telemetry": job.last_telemetry.model_dump() if job.last_telemetry else None,
195
- }
196
- # Explicitly include detections for the UI visualizer if they exist
197
- if job.last_telemetry and hasattr(job.last_telemetry, "detections"):
198
- payload["detections"] = job.last_telemetry.detections
199
-
200
- await websocket.send_json(payload)
201
-
202
- if job.status == "completed":
203
- result = await bench_reg.get_result(job_id)
204
- if result:
205
- await websocket.send_json(
206
- {
207
- "job_id": job_id,
208
- "status": "completed",
209
- "result": result.model_dump(),
210
- }
211
- )
212
- break
213
-
214
- if job.status == "failed":
215
- await websocket.send_json(
216
- {
217
- "job_id": job_id,
218
- "status": "failed",
219
- "error": job.error or "Unknown error",
220
- }
221
- )
222
- break
223
-
224
- await asyncio.sleep(0.5) # poll at 2 Hz
225
-
226
- except WebSocketDisconnect:
227
- log.info("ws_disconnected", job_id=job_id)
228
- except Exception as exc:
229
- log.exception("ws_error", job_id=job_id)
230
- try:
231
- await websocket.send_json({"error": str(exc), "job_id": job_id})
232
- except Exception:
233
- pass
234
- finally:
235
- try:
236
- await websocket.close()
237
- except Exception:
238
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/routes/inference.py DELETED
@@ -1,168 +0,0 @@
1
- """
2
- api/routes/inference.py — Inference Engine endpoints.
3
-
4
- POST /inference/run — single synchronous inference pass
5
- POST /inference/stream — SSE stream (stage-by-stage pipeline events)
6
- GET /inference/history — session ledger
7
- DELETE /inference/history — clear session ledger
8
- GET /inference/cache — currently warm models in memory
9
- DELETE /inference/cache/{model_id} — evict from cache
10
- """
11
- from __future__ import annotations
12
-
13
- import asyncio
14
- import json
15
- import time
16
-
17
- from fastapi import APIRouter, HTTPException, Response
18
- from fastapi.responses import StreamingResponse
19
-
20
- from inference.engine import InferenceEngine, evict_model, get_cache_status
21
- from inference.session import clear_history, get_history, record
22
- from models.inference import (
23
- InferenceHistoryEntry,
24
- InferenceRequest,
25
- InferenceResult,
26
- )
27
- from observability.logger import get_logger
28
- from registry.registry import get_model
29
-
30
- log = get_logger("api.inference")
31
- router = APIRouter(prefix="/inference", tags=["inference"])
32
-
33
- _engine = InferenceEngine()
34
-
35
-
36
- # ── Single run ───────────────────────────────────────────────────────────────
37
-
38
- @router.post("/run", response_model=InferenceResult)
39
- async def run_inference(body: InferenceRequest) -> InferenceResult:
40
- """Execute one full inference pass and return the complete result."""
41
- model = await get_model(body.model_id)
42
- if not model:
43
- raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")
44
-
45
- result = await _engine.run(body, model)
46
-
47
- if result.status == "error":
48
- raise HTTPException(status_code=500, detail=result.error or "Inference failed")
49
-
50
- await record(body, result, model.name)
51
- return result
52
-
53
-
54
- # ── SSE stream ───────────────────────────────────────────────────────────────
55
-
56
- @router.post("/stream")
57
- async def stream_inference(body: InferenceRequest) -> StreamingResponse:
58
- """
59
- Server-Sent Events stream.
60
- Emits one JSON event per pipeline stage as it completes, then a final
61
- 'done' event with the full InferenceResult.
62
-
63
- Client usage:
64
- const es = new EventSource('/inference/stream'); // POST via fetch + EventSource polyfill
65
- es.onmessage = e => console.log(JSON.parse(e.data));
66
- """
67
- model = await get_model(body.model_id)
68
- if not model:
69
- raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")
70
-
71
- queue: asyncio.Queue[str | None] = asyncio.Queue()
72
-
73
- async def _producer() -> None:
74
- """Run inference while pushing SSE events into the queue."""
75
- try:
76
- # Patch engine to emit stage events
77
- result = await _engine_stream(body, model, queue)
78
- await record(body, result, model.name)
79
- # Final complete event
80
- await queue.put(
81
- f"event: done\ndata: {result.model_dump_json()}\n\n"
82
- )
83
- except Exception as exc:
84
- await queue.put(
85
- f"event: error\ndata: {json.dumps({'error': str(exc)})}\n\n"
86
- )
87
- finally:
88
- await queue.put(None) # sentinel
89
-
90
- asyncio.create_task(_producer())
91
-
92
- async def _generator():
93
- while True:
94
- msg = await queue.get()
95
- if msg is None:
96
- break
97
- yield msg
98
-
99
- return StreamingResponse(
100
- _generator(),
101
- media_type="text/event-stream",
102
- headers={
103
- "Cache-Control": "no-cache",
104
- "X-Accel-Buffering": "no",
105
- },
106
- )
107
-
108
-
109
- async def _engine_stream(
110
- req: InferenceRequest,
111
- model,
112
- queue: asyncio.Queue,
113
- ) -> InferenceResult:
114
- """
115
- Run inference and push a 'stage' SSE event for each PipelineStage.
116
- Falls back to a simple full run if streaming is not distinguishable.
117
- """
118
- # Run full pipeline
119
- result = await _engine.run(req, model)
120
-
121
- # Emit one event per stage (replay after completion)
122
- for stage in result.pipeline:
123
- payload = json.dumps({
124
- "type": "stage",
125
- "stage": stage.model_dump(),
126
- "ts": time.time(),
127
- })
128
- await queue.put(f"data: {payload}\n\n")
129
- await asyncio.sleep(0) # yield
130
-
131
- # Emit vitals snapshot
132
- vitals_payload = json.dumps({
133
- "type": "vitals",
134
- "latency_ms": result.inference_ms,
135
- "total_ms": result.total_ms,
136
- "quality": result.quality_score,
137
- })
138
- await queue.put(f"data: {vitals_payload}\n\n")
139
-
140
- return result
141
-
142
-
143
- # ── History ──────────────────────────────────────────────────────────────────
144
-
145
- @router.get("/history", response_model=list[InferenceHistoryEntry])
146
- async def inference_history(limit: int = 50) -> list[InferenceHistoryEntry]:
147
- return await get_history(limit=min(limit, 200))
148
-
149
-
150
- @router.delete("/history", status_code=204, response_model=None)
151
- async def clear_inference_history():
152
- await clear_history()
153
- return Response(status_code=204)
154
-
155
-
156
- # ── Model cache ──────────────────────────────────────────────────────────────
157
-
158
- @router.get("/cache")
159
- async def cache_status() -> dict[str, bool]:
160
- return get_cache_status()
161
-
162
-
163
- @router.delete("/cache/{model_id}", status_code=204, response_model=None)
164
- async def evict_from_cache(model_id: str):
165
- evicted = evict_model(model_id)
166
- if not evicted:
167
- raise HTTPException(status_code=404, detail="Model not in cache")
168
- return Response(status_code=204)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/routes/jobs.py DELETED
@@ -1,56 +0,0 @@
1
- """
2
- api/routes/jobs.py — /jobs & /download endpoints.
3
- """
4
- from __future__ import annotations
5
-
6
- from fastapi import APIRouter, HTTPException
7
-
8
- from download.manager import cancel_job, enqueue_download, get_job, list_jobs
9
- from models.job import Job, JobCreate
10
- from observability.logger import audit, get_logger
11
- from registry.registry import get_model
12
-
13
- log = get_logger("api.jobs")
14
- router = APIRouter(tags=["jobs"])
15
-
16
-
17
- @router.post("/download", response_model=Job, status_code=202)
18
- async def trigger_download(body: JobCreate) -> Job:
19
- """Enqueue a model download. Returns the created job immediately."""
20
- model = await get_model(body.model_id)
21
- if not model:
22
- raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")
23
- if model.downloaded:
24
- raise HTTPException(status_code=409, detail="Model is already cached locally")
25
-
26
- job_id = await enqueue_download(
27
- model_id=body.model_id,
28
- model_name=body.model_name,
29
- version=body.version,
30
- )
31
- job = await get_job(job_id)
32
- if not job:
33
- raise HTTPException(status_code=500, detail="Job creation failed")
34
-
35
- await audit("api_download_trigger", model_id=body.model_id, job_id=job_id)
36
- return job
37
-
38
-
39
- @router.get("/jobs", response_model=list[Job])
40
- async def jobs_list(status: str | None = None, limit: int = 50) -> list[Job]:
41
- return await list_jobs(status=status, limit=limit)
42
-
43
-
44
- @router.get("/jobs/{job_id}", response_model=Job)
45
- async def job_detail(job_id: str) -> Job:
46
- job = await get_job(job_id)
47
- if not job:
48
- raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
49
- return job
50
-
51
-
52
- @router.delete("/jobs/{job_id}", status_code=204, response_model=None)
53
- async def job_cancel(job_id: str) -> None:
54
- success = await cancel_job(job_id)
55
- if not success:
56
- raise HTTPException(status_code=409, detail="Job cannot be cancelled")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/routes/system.py DELETED
@@ -1,97 +0,0 @@
1
- """api/routes/system.py — System metrics endpoints."""
2
-
3
- from __future__ import annotations
4
-
5
- import asyncio
6
- import json
7
-
8
- from fastapi import APIRouter, Query
9
- from fastapi.responses import StreamingResponse
10
-
11
- from models.system import SystemMetrics
12
- from system.metrics import sample_metrics
13
-
14
- router = APIRouter(prefix="/system", tags=["system"])
15
-
16
-
17
- @router.get("/metrics", response_model=SystemMetrics)
18
- async def get_metrics(gpu_index: int = Query(0, ge=0)) -> SystemMetrics:
19
- payload = sample_metrics(gpu_index=gpu_index)
20
- return SystemMetrics(
21
- ts=payload["ts"],
22
- cpu_pct=payload["cpu_pct"],
23
- cpu_model=payload.get("cpu_model"),
24
- cpu_freq_mhz=payload.get("cpu_freq_mhz"),
25
- cpu_count=payload.get("cpu_count"),
26
- ram_used_mb=payload["ram_used_mb"],
27
- ram_total_mb=payload["ram_total_mb"],
28
- gpu=payload.get("gpu"),
29
- disks=payload.get("disks", []),
30
- network=payload.get("network", []),
31
- )
32
-
33
-
34
- @router.get("/metrics/stream")
35
- async def stream_metrics(
36
- gpu_index: int = Query(0, ge=0),
37
- hz: float = Query(2.0, ge=0.2, le=20.0),
38
- ):
39
- """Server-Sent Events stream of system metrics."""
40
-
41
- interval = 1.0 / float(hz)
42
-
43
- async def gen():
44
- # Initial comment helps some proxies establish the stream
45
- yield ": connected\n\n"
46
- while True:
47
- try:
48
- payload = sample_metrics(gpu_index=gpu_index)
49
- # Ensure the payload is valid JSON and wrapped in data: format
50
- data = json.dumps(payload)
51
- yield f"data: {data}\n\n"
52
- except Exception as e:
53
- # Log error but keep stream alive
54
- print(f"Metrics streaming error: {e}")
55
- await asyncio.sleep(interval)
56
-
57
- return StreamingResponse(
58
- gen(),
59
- media_type="text/event-stream",
60
- headers={
61
- "Cache-Control": "no-cache",
62
- "X-Accel-Buffering": "no",
63
- "Connection": "keep-alive",
64
- "Transfer-Encoding": "chunked",
65
- },
66
- )
67
-
68
-
69
- @router.get("/logs/stream")
70
- async def stream_system_logs():
71
- """SSE stream of global system and gateway logs."""
72
- from observability.logger import _sys_log_subs
73
-
74
- q: asyncio.Queue = asyncio.Queue()
75
- _sys_log_subs.append(q)
76
-
77
- async def generator():
78
- yield ": connected\n\n"
79
- try:
80
- while True:
81
- try:
82
- entry = await asyncio.wait_for(q.get(), timeout=30.0)
83
- except asyncio.TimeoutError:
84
- yield ": heartbeat\n\n"
85
- continue
86
- if entry is None:
87
- break
88
- yield f"data: {json.dumps(entry)}\n\n"
89
- finally:
90
- if q in _sys_log_subs:
91
- _sys_log_subs.remove(q)
92
-
93
- return StreamingResponse(
94
- generator(),
95
- media_type="text/event-stream",
96
- headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
97
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/routes/training.py DELETED
@@ -1,428 +0,0 @@
1
- """
2
- api/routes/training.py — Training Engine REST + SSE endpoints.
3
-
4
- POST /train/start — create and launch a training run
5
- POST /train/stop — cancel a running run
6
- POST /train/pause — pause a running run
7
- POST /train/resume — resume a paused run
8
- GET /train/status — run status + progress snapshot
9
- GET /train/runs — list all runs
10
- GET /train/runs/{run_id} — single run detail
11
- GET /train/schema — UI schema for task/model/dataset combo
12
- GET /train/checkpoints — checkpoints for a run (stub)
13
- POST /train/checkpoints/{id}/export — export a checkpoint (stub)
14
- GET /train/metrics/stream — SSE: real-time metrics ticks
15
- GET /train/logs/stream — SSE: real-time log entries
16
- GET /train/resources/stream — SSE: real-time resource ticks
17
- """
18
- from __future__ import annotations
19
-
20
- import asyncio
21
- import json
22
- import time
23
- import os
24
-
25
- from fastapi import APIRouter, HTTPException, Query
26
- from fastapi.responses import StreamingResponse
27
-
28
- from observability.logger import get_logger
29
- from training import run_manager
30
- from training.schema_engine import generate_schema
31
- from training.schemas import (
32
- CheckpointOut,
33
- PauseTrainRequest,
34
- ResumeTrainRequest,
35
- StartTrainRequest,
36
- StartTrainResponse,
37
- StopTrainRequest,
38
- TrainRunOut,
39
- TrainStatusResponse,
40
- TrainingSchemaResponse,
41
- )
42
-
43
- log = get_logger("api.training")
44
- router = APIRouter(prefix="/train", tags=["training"])
45
-
46
- # ── Helpers ────────────────────────────────────────────────────────────────────
47
-
48
- def _format_duration(seconds: float) -> str:
49
- h = int(seconds // 3600)
50
- m = int((seconds % 3600) // 60)
51
- s = int(seconds % 60)
52
- return f"{h}h {m}m {s}s"
53
-
54
-
55
def _run_to_out(run: run_manager.TrainRun) -> TrainRunOut:
    """Convert an in-memory training run into its API response model.

    The reported duration runs from ``created_at`` to ``completed_at``, or
    to the current time while ``completed_at`` is still unset.
    """
    finished_at = run.completed_at
    end_ts = finished_at or time.time()
    duration_text = _format_duration(end_ts - run.created_at)
    return TrainRunOut(
        id=run.run_id,
        run_number=run.run_number,
        model_id=run.model_id,
        model_name=run.model_name,
        dataset_id=run.dataset_id,
        dataset_name=run.dataset_name,
        task=run.task,
        status=run.status,
        epochs_done=run.epoch,
        total_epochs=run.total_epochs,
        best_metric=run.best_metric,
        final_loss=run.final_loss,
        duration=duration_text,
        created_at=run.created_at,
        completed_at=finished_at,
        hyperparams=run.hyperparams,
    )
75
-
76
-
77
- # ── Control endpoints ─────────────────────────────────────────────────────────
78
-
79
@router.post("/start", response_model=StartTrainResponse)
async def start_training(body: StartTrainRequest) -> StartTrainResponse:
    """Create a training run and launch it immediately.

    Display names for the model and dataset are resolved on a best-effort
    basis: if a registry import or lookup fails, the raw id is used as the
    name and the failure is logged instead of being silently discarded.

    Returns:
        The new run's id, its status after ``start_run``, and a short
        confirmation message.
    """
    # Friendly names default to the ids until a registry says otherwise.
    model_name = body.model_id
    dataset_name = body.dataset_id

    try:
        from registry.registry import get_model
        m = await get_model(body.model_id)
        if m:
            model_name = m.name
    except Exception as exc:
        # Best-effort lookup: keep the id as the name, but leave a trace.
        log.warning("model_name_lookup_failed", model_id=body.model_id, error=str(exc))

    try:
        from datasets.registry import get_dataset
        d = await get_dataset(body.dataset_id)
        if d:
            # The dataset registry may hand back a dict or an object.
            if isinstance(d, dict):
                dataset_name = d.get("name", body.dataset_id)
            else:
                dataset_name = getattr(d, "name", body.dataset_id)
    except Exception as exc:
        log.warning("dataset_name_lookup_failed", dataset_id=body.dataset_id, error=str(exc))

    run = run_manager.create_run(
        model_id=body.model_id,
        model_name=model_name,
        dataset_id=body.dataset_id,
        dataset_name=dataset_name,
        task=body.task,
        hyperparams=body.hyperparams,
        augmentation=body.augmentation,
        scheduler=body.scheduler,
        project_id=body.project_id,
    )
    run_manager.start_run(run)

    log.info("training_started", run_id=run.run_id, model=body.model_id)
    return StartTrainResponse(
        run_id=run.run_id,
        status=run.status,
        message=f"Training run {run.run_id} started.",
    )
119
-
120
-
121
@router.post("/stop", status_code=200)
async def stop_training(body: StopTrainRequest) -> dict:
    """Cancel the run named in the request body.

    Returns the run id and the run's status as read after ``stop_run``
    returns. Responds 404 when the id is unknown.
    """
    target_id = body.run_id
    run = run_manager.get_run(target_id)
    if run:
        run_manager.stop_run(run)
        log.info("training_stopped", run_id=target_id)
        return {"run_id": target_id, "status": run.status}
    raise HTTPException(status_code=404, detail=f"Run '{target_id}' not found")
129
-
130
-
131
@router.post("/pause", status_code=200)
async def pause_training(body: PauseTrainRequest) -> dict:
    """Pause the run named in the request body.

    Returns the run id and the run's status as read after ``pause_run``
    returns. Responds 404 when the id is unknown.
    """
    target_id = body.run_id
    run = run_manager.get_run(target_id)
    if run:
        run_manager.pause_run(run)
        return {"run_id": target_id, "status": run.status}
    raise HTTPException(status_code=404, detail=f"Run '{target_id}' not found")
138
-
139
-
140
@router.post("/resume", status_code=200)
async def resume_training(body: ResumeTrainRequest) -> dict:
    """Resume the run named in the request body.

    Returns the run id and the run's status as read after ``resume_run``
    returns. Responds 404 when the id is unknown.
    """
    target_id = body.run_id
    run = run_manager.get_run(target_id)
    if run:
        run_manager.resume_run(run)
        return {"run_id": target_id, "status": run.status}
    raise HTTPException(status_code=404, detail=f"Run '{target_id}' not found")
147
-
148
-
149
@router.get("/status", response_model=TrainStatusResponse)
async def get_train_status(run_id: str = Query(...)) -> TrainStatusResponse:
    """Return a status and progress snapshot for one run.

    ``total_steps`` is reported as ``total_epochs * 100``.
    Responds 404 when the run id is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    snapshot = dict(
        run_id=run.run_id,
        status=run.status,
        epoch=run.epoch,
        total_epochs=run.total_epochs,
        step=run.step,
        # NOTE(review): the factor of 100 is hard-coded here — confirm it
        # matches the worker's steps-per-epoch.
        total_steps=run.total_epochs * 100,
        eta_seconds=run.eta_seconds,
        elapsed_seconds=run.elapsed_seconds,
    )
    return TrainStatusResponse(**snapshot)
164
-
165
-
166
- # ── Run history ───────────────────────────────────────────────────────────────
167
-
168
@router.get("/runs", response_model=list[TrainRunOut])
async def list_runs() -> list[TrainRunOut]:
    """List every run, in the reverse of ``run_manager.list_runs()`` order."""
    reversed_runs = reversed(run_manager.list_runs())
    return list(map(_run_to_out, reversed_runs))
171
-
172
-
173
@router.get("/runs/{run_id}", response_model=TrainRunOut)
async def get_run(run_id: str) -> TrainRunOut:
    """Return the detail record for one run, or 404 if the id is unknown."""
    found = run_manager.get_run(run_id)
    if found:
        return _run_to_out(found)
    raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
179
-
180
-
181
- # ── Schema Engine ─────────────────────────────────────────────────────────────
182
-
183
@router.get("/schema", response_model=TrainingSchemaResponse)
async def get_schema(
    model_id: str = Query(""),
    dataset_id: str = Query(""),
    task: str = Query("detection"),
) -> TrainingSchemaResponse:
    """Return the UI schema for a task / model / dataset combination.

    The mapping produced by ``generate_schema`` is expanded into the
    response model's keyword arguments.
    """
    schema_fields = generate_schema(
        task=task,
        model_id=model_id,
        dataset_id=dataset_id,
    )
    return TrainingSchemaResponse(**schema_fields)
191
-
192
-
193
- # ── Checkpoints (stub β€” extend when artifact storage is wired) ────────────────
194
-
195
@router.get("/checkpoints", response_model=list[CheckpointOut])
async def list_checkpoints(run_id: str = Query(...)) -> list[CheckpointOut]:
    """Stub: return an empty checkpoint list for an existing run.

    Responds 404 when the run id is unknown. No checkpoints are looked up
    until checkpoint persistence is implemented.
    """
    if run_manager.get_run(run_id):
        return []
    raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
202
-
203
-
204
@router.post("/checkpoints/{checkpoint_id}/export")
async def export_checkpoint(checkpoint_id: str, body: dict | None = None) -> dict:
    """Placeholder for checkpoint export; always responds 501.

    ``body`` is accepted but ignored. Its default is ``None`` rather than a
    shared mutable ``{}`` literal.
    """
    raise HTTPException(status_code=501, detail="Checkpoint export not yet implemented")
207
-
208
-
209
- # ── SSE: Metrics stream ────────────────────────────────────────────────────────
210
-
211
@router.get("/metrics/stream")
async def stream_metrics(run_id: str = Query(...)) -> StreamingResponse:
    """Stream a run's metrics ticks as Server-Sent Events.

    A fresh queue is added to ``run.metrics_subs``; each item taken from it
    is forwarded as a ``data:`` frame. The stream ends when a ``None``
    sentinel arrives, and a heartbeat comment is sent after 30 seconds
    without an item. Responds 404 when the run id is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    inbox: asyncio.Queue = asyncio.Queue()
    run.metrics_subs.append(inbox)

    async def event_source():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    tick = await asyncio.wait_for(inbox.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    # Heartbeat keeps idle connections from being dropped.
                    yield ": heartbeat\n\n"
                else:
                    if tick is None:
                        return
                    yield f"data: {json.dumps(tick)}\n\n"
        finally:
            # Unsubscribe; tolerate the queue having been removed already.
            try:
                run.metrics_subs.remove(inbox)
            except ValueError:
                pass

    return StreamingResponse(
        event_source(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
247
-
248
-
249
- # ── SSE: Logs stream ──────────────────────────────────────────────────────────
250
-
251
@router.get("/logs/stream")
async def stream_logs(run_id: str = Query(...)) -> StreamingResponse:
    """Stream a run's log entries as Server-Sent Events.

    A fresh queue is added to ``run.log_subs``; each item taken from it is
    forwarded as a ``data:`` frame. The stream ends when a ``None`` sentinel
    arrives, and a heartbeat comment is sent after 30 seconds without an
    item. Responds 404 when the run id is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    inbox: asyncio.Queue = asyncio.Queue()
    run.log_subs.append(inbox)

    async def event_source():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    record = await asyncio.wait_for(inbox.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield ": heartbeat\n\n"
                else:
                    if record is None:
                        return
                    yield f"data: {json.dumps(record)}\n\n"
        finally:
            # Unsubscribe; tolerate the queue having been removed already.
            try:
                run.log_subs.remove(inbox)
            except ValueError:
                pass

    return StreamingResponse(
        event_source(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
282
-
283
-
284
@router.get("/runs/{run_id}/history")
async def get_run_history(run_id: str) -> list[dict]:
    """Return the recorded telemetry (metrics ticks) for a run.

    Reads ``telemetry.jsonl`` from the run directory, one JSON document per
    non-blank line. A missing file yields an empty list. Lines that are not
    valid JSON (for example a partially written last line) are skipped and
    logged rather than failing the whole request.

    Raises:
        HTTPException: 404 if the run id is unknown; 500 if the file exists
            but cannot be read or decoded as UTF-8.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    from training.persistence import TrainingPersistence
    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
    telemetry_path = os.path.join(run_dir, "telemetry.jsonl")

    history: list[dict] = []
    try:
        # Explicit encoding so the result does not depend on the host locale.
        with open(telemetry_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    history.append(json.loads(line))
                except json.JSONDecodeError as e:
                    log.error(
                        "history_line_skipped",
                        run_id=run_id,
                        line=line_no,
                        error=str(e),
                    )
    except FileNotFoundError:
        # No telemetry has been written for this run.
        return []
    except (OSError, UnicodeDecodeError) as e:
        log.error("history_read_failed", run_id=run_id, error=str(e))
        raise HTTPException(status_code=500, detail="Failed to read telemetry history") from e

    return history
307
-
308
@router.get("/runs/{run_id}/artifacts")
async def list_run_artifacts(run_id: str) -> dict:
    """List the image artifacts found in a run's directory.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` are reported. Names
    containing "batch" go to ``batches``; all others go to ``artifacts``.
    Both lists are sorted by title. A missing run directory yields two empty
    lists; an unknown run id responds 404.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    from training.persistence import TrainingPersistence
    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)

    if not os.path.exists(run_dir):
        return {"artifacts": [], "batches": []}

    # Readable titles for the standard YOLO output files.
    known_titles = {
        "confusion_matrix.png": "Confusion Matrix",
        "confusion_matrix_normalized.png": "Confusion Matrix (Norm)",
        "results.png": "Results Summary",
        "F1_curve.png": "F1 Curve",
        "PR_curve.png": "PR Curve",
        "P_curve.png": "Precision Curve",
        "R_curve.png": "Recall Curve",
        "BoxF1_curve.png": "Box F1 Curve",
        "BoxP_curve.png": "Box Precision Curve",
        "BoxPR_curve.png": "Box PR Curve",
        "BoxR_curve.png": "Box Recall Curve",
        "labels.jpg": "Labels Distribution",
        "labels_correlogram.jpg": "Labels Correlogram",
    }
    image_suffixes = ('.png', '.jpg', '.jpeg')

    analysis_items = []
    batch_items = []

    for name in os.listdir(run_dir):
        if not name.endswith(image_suffixes):
            continue

        lowered = name.lower()
        is_batch = "batch" in lowered
        if is_batch:
            kind = "Batch Preview" if "val" in lowered else "Augmentation"
        elif "curve" in lowered:
            kind = "Precision-Recall"
        elif "results" in lowered and "confusion" not in lowered:
            kind = "Overall"
        else:
            kind = "Analysis"

        if name in known_titles:
            title = known_titles[name]
        else:
            # Fallback: underscores to spaces, title-case, drop the extension.
            title = name.replace('_', ' ').title().split('.')[0]

        entry = {
            "title": title,
            "path": f"/train/runs/{run_id}/files/{name}",
            "type": kind,
        }
        (batch_items if is_batch else analysis_items).append(entry)

    analysis_items.sort(key=lambda entry: entry['title'])
    batch_items.sort(key=lambda entry: entry['title'])
    return {"artifacts": analysis_items, "batches": batch_items}
366
-
367
@router.get("/runs/{run_id}/files/{filename}")
async def get_run_file(run_id: str, filename: str):
    """Serve one file from a run's directory.

    The requested name is normalised and must resolve to a regular file
    inside the run directory. Names that escape it (for example via ``..``
    or an absolute path) and names that point at a directory are answered
    with 404, the same as a missing file.

    Raises:
        HTTPException: 404 if the run is unknown or the file is not an
            existing regular file inside the run directory.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # The run directory is recomputed through the persistence layer rather
    # than read from the in-memory run object.
    from training.persistence import TrainingPersistence
    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)

    root = os.path.abspath(run_dir)
    file_path = os.path.abspath(os.path.join(root, filename))

    # Containment check: the normalised target must stay under the run dir.
    try:
        inside_run_dir = os.path.commonpath([root, file_path]) == root
    except ValueError:
        # Raised when the paths cannot be compared, e.g. different drives.
        inside_run_dir = False

    if not inside_run_dir or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    from fastapi.responses import FileResponse
    return FileResponse(file_path)
386
- # The frontend uses /system/metrics/stream for resources (already implemented).
387
- # This alias exists for training-scoped resource monitoring.
388
-
389
@router.get("/resources/stream")
async def stream_resources(
    run_id: str = Query(...),
    gpu_index: int = Query(0, ge=0),
    hz: float = Query(1.0, ge=0.2, le=10.0),
) -> StreamingResponse:
    """Stream a run's resource ticks as Server-Sent Events.

    A fresh queue is added to ``run.resource_subs``; each item taken from it
    is forwarded as a ``data:`` frame. The stream ends when a ``None``
    sentinel arrives, and a heartbeat comment is sent after 30 seconds
    without an item.

    ``gpu_index`` and ``hz`` are validated and kept in the signature for
    compatibility, but this handler does not use them: ticks are forwarded
    as they arrive on the queue, not resampled at ``hz``.

    Raises:
        HTTPException: 404 if the run id is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    q: asyncio.Queue = asyncio.Queue()
    run.resource_subs.append(q)

    async def generator():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    tick = await asyncio.wait_for(q.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield ": heartbeat\n\n"
                    continue
                if tick is None:
                    break
                yield f"data: {json.dumps(tick)}\n\n"
        finally:
            # Unsubscribe so the queue is not left registered on the run.
            if q in run.resource_subs:
                run.resource_subs.remove(q)

    return StreamingResponse(
        generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/__init__.py DELETED
@@ -1 +0,0 @@
1
- # benchmark — Benchmark Bridge System for MLForge
 
 
benchmark/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (144 Bytes)
 
benchmark/__pycache__/compatibility.cpython-310.pyc DELETED
Binary file (8.3 kB)
 
benchmark/__pycache__/execution.cpython-310.pyc DELETED
Binary file (10.4 kB)
 
benchmark/__pycache__/metrics.cpython-310.pyc DELETED
Binary file (3.24 kB)
 
benchmark/__pycache__/orchestrator.cpython-310.pyc DELETED
Binary file (9.11 kB)
 
benchmark/__pycache__/registry.cpython-310.pyc DELETED
Binary file (8.77 kB)
 
benchmark/__pycache__/telemetry.cpython-310.pyc DELETED
Binary file (6.73 kB)
 
benchmark/adapters/__pycache__/base.cpython-310.pyc DELETED
Binary file (1.8 kB)
 
benchmark/adapters/__pycache__/registry.cpython-310.pyc DELETED
Binary file (1.89 kB)
 
benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc DELETED
Binary file (1.93 kB)
 
benchmark/adapters/base.py DELETED
@@ -1,38 +0,0 @@
1
- """
2
- benchmark/adapters/base.py — Base class for all Benchmark Runners.
3
- """
4
- from __future__ import annotations
5
-
6
- from abc import ABC, abstractmethod
7
- from dataclasses import dataclass, field
8
- from typing import Any, AsyncGenerator
9
-
10
- from models.benchmark import BenchmarkContext, TelemetrySample
11
-
12
-
13
@dataclass
class BatchResult:
    """Outcome of executing a single benchmark batch."""

    # Batch latency in milliseconds.
    latency_ms: float
    # VRAM in use, in gigabytes.
    vram_used_gb: float
    # Score name -> value; each instance gets its own empty dict by default.
    task_scores: dict[str, float] = field(default_factory=dict)
    # Free-form extra data; each instance gets its own empty dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
20
-
21
-
22
class BaseRunner(ABC):
    """Abstract contract every benchmark executor (Torch, Optimum, vLLM) implements.

    Subclasses must provide all three coroutine methods before they can be
    instantiated.
    """

    @abstractmethod
    async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
        """Load the model at ``model_path`` and prepare the environment."""

    @abstractmethod
    async def run_batch(self, batch: Any) -> BatchResult:
        """Execute one batch of data and report its result."""

    @abstractmethod
    async def shutdown(self) -> None:
        """Release any resources held by the runner."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/adapters/optimum_runner.py DELETED
@@ -1,53 +0,0 @@
1
- """
2
- benchmark/adapters/optimum_runner.py — Hugging Face Optimum Adapter.
3
- Supports ONNX, OpenVINO, and TensorRT acceleration.
4
- """
5
- from __future__ import annotations
6
-
7
- import time
8
- import asyncio
9
- from typing import Any
10
- from benchmark.adapters.base import BaseRunner, BatchResult
11
- from models.benchmark import BenchmarkContext
12
- from observability.logger import get_logger
13
-
14
- log = get_logger("benchmark.optimum")
15
-
16
class OptimumRunner(BaseRunner):
    """Benchmark runner for Optimum / ONNX Runtime backends (simulated).

    No real model is loaded: load time, inference time, VRAM usage and
    accuracy are all placeholder values.
    """

    def __init__(self):
        self.session = None
        self.device = "cpu"

    async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
        """Pick a device from ``ctx.hardware`` and mark the session active.

        A real implementation would load the model through Optimum, e.g.
        ``ORTModelForFeatureExtraction.from_pretrained(model_path, provider=...)``.
        """
        log.info("optimum_init", model_path=model_path, hardware=ctx.hardware)
        hardware_name = ctx.hardware.lower()
        uses_gpu = any(tag in hardware_name for tag in ("gpu", "rtx"))
        self.device = "cuda" if uses_gpu else "cpu"

        # Stand-in for model load time.
        await asyncio.sleep(1.5)
        # Placeholder for the real session object.
        self.session = "active"

    async def run_batch(self, batch: Any) -> BatchResult:
        """Simulate one inference batch and time it.

        Raises:
            RuntimeError: If ``initialize`` has not set up a session.
        """
        if not self.session:
            raise RuntimeError("Optimum session not initialized")

        began = time.perf_counter()
        # Stand-in for: outputs = self.session(**batch)
        await asyncio.sleep(0.01)
        elapsed_ms = (time.perf_counter() - began) * 1000

        return BatchResult(
            latency_ms=elapsed_ms,
            vram_used_gb=0.8,  # Mocked
            task_scores={"accuracy": 0.92},  # Mocked
        )

    async def shutdown(self) -> None:
        """Log the shutdown and drop the session."""
        log.info("optimum_shutdown")
        self.session = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/adapters/registry.py DELETED
@@ -1,44 +0,0 @@
1
- """
2
- benchmark/adapters/registry.py — Executor Registry for dynamic runner resolution.
3
- """
4
- from __future__ import annotations
5
-
6
- from typing import Type
7
- from benchmark.adapters.base import BaseRunner
8
- from models.benchmark import BenchmarkContext
9
- from models.model import Model
10
-
11
class ExecutorRegistry:
    """Class-level registry mapping framework names to runner classes.

    Framework names are lower-cased on both registration and lookup.
    """

    # Shared across the whole process: lower-cased framework -> runner class.
    _runners: dict[str, Type[BaseRunner]] = {}

    @classmethod
    def register(cls, framework: str, runner_cls: Type[BaseRunner]):
        """Associate ``runner_cls`` with ``framework``, replacing any earlier entry."""
        key = framework.lower()
        cls._runners[key] = runner_cls

    @classmethod
    def get_runner(cls, framework: str) -> BaseRunner:
        """Return a new runner instance for ``framework``.

        Falls back to a ``TorchRunner`` when nothing is registered under
        that name.
        """
        registered = cls._runners.get(framework.lower())
        if registered:
            return registered()
        # Default runner for unregistered frameworks.
        from benchmark.adapters.torch_runner import TorchRunner
        return TorchRunner()
26
-
27
def get_executor(ctx: BenchmarkContext, model: Model) -> BaseRunner:
    """Choose a runner for ``model`` from its framework and the context's task.

    Resolution order:
      1. ``onnx`` / ``openvino`` / ``tensorrt`` use ``OptimumRunner``.
      2. ``pytorch`` with a ``generation`` or ``nlp`` task uses
         ``VLLMRunner``, if that module can be imported.
      3. Anything else is looked up in ``ExecutorRegistry``.
    """
    framework = model.framework.lower()

    # Optimized inference engines share the Optimum adapter.
    if framework in ("onnx", "openvino", "tensorrt"):
        from benchmark.adapters.optimum_runner import OptimumRunner
        return OptimumRunner()

    if ctx.task in ("generation", "nlp") and framework == "pytorch":
        # vLLM is optional; fall through to the registry when it is absent.
        try:
            from benchmark.adapters.vllm_runner import VLLMRunner
        except ImportError:
            pass
        else:
            return VLLMRunner()

    return ExecutorRegistry.get_runner(framework)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/adapters/torch_runner.py DELETED
@@ -1,45 +0,0 @@
1
- """
2
- benchmark/adapters/torch_runner.py — PyTorch Runner Adapter.
3
- Wraps standard PyTorch inference for Vision and NLP tasks.
4
- """
5
- from __future__ import annotations
6
-
7
- import time
8
- import asyncio
9
- import random
10
- from typing import Any
11
- from benchmark.adapters.base import BaseRunner, BatchResult
12
- from models.benchmark import BenchmarkContext
13
- from observability.logger import get_logger
14
-
15
- log = get_logger("benchmark.torch")
16
-
17
class TorchRunner(BaseRunner):
    """Benchmark runner for plain PyTorch inference (simulated).

    No real model is loaded: load time, inference time, VRAM usage and mAP
    are all placeholder values.
    """

    def __init__(self):
        self.model = None
        self.device = "cpu"

    async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
        """Log the request and mark the model as loaded after a short delay."""
        log.info("torch_init", model_path=model_path, hardware=ctx.hardware)
        # In production: self.model = torch.load(model_path).to(self.device)
        await asyncio.sleep(1.0)
        self.model = "active"

    async def run_batch(self, batch: Any) -> BatchResult:
        """Simulate one inference batch and time it.

        Raises:
            RuntimeError: If ``initialize`` has not set up a model.
        """
        if not self.model:
            raise RuntimeError("Torch model not initialized")

        began = time.perf_counter()
        # Stand-in for real torch inference.
        await asyncio.sleep(0.02)
        elapsed_ms = (time.perf_counter() - began) * 1000

        return BatchResult(
            latency_ms=elapsed_ms,
            vram_used_gb=1.2,
            task_scores={"mAP": 0.45},
        )

    async def shutdown(self) -> None:
        """Log the shutdown and drop the model reference."""
        log.info("torch_shutdown")
        self.model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/compatibility.py DELETED
@@ -1,360 +0,0 @@
1
- """
2
- benchmark/compatibility.py — Compatibility Validator (CRITICAL MODULE).
3
-
4
- Validates model ↔ dataset ↔ hardware compatibility before any benchmark
5
- execution begins. Returns a structured ValidationReport — never raises.
6
-
7
- Five gates (all must pass):
8
- A. Task compatibility — model.task matches dataset.task
9
- B. Annotation format — dataset format supports the model's task
10
- C. Framework × hardware — framework can run on the requested device
11
- D. VRAM constraint — estimated memory fits available VRAM
12
- E. Precision support — precision mode is valid for framework + hardware
13
- """
14
- from __future__ import annotations
15
-
16
- from models.benchmark import BenchmarkContext, ValidationCheck, ValidationReport
17
- from models.dataset import Dataset
18
- from models.model import Model
19
- from observability.logger import get_logger
20
-
21
- log = get_logger("benchmark.compatibility")
22
-
23
-
24
- # ── Lookup tables ─────────────────────────────────────────────────────────────
25
-
26
- # Hardware → available VRAM in GB (normalized keys, no spaces/dashes)
27
- HARDWARE_VRAM_GB: dict[str, float] = {
28
- # NVIDIA consumer — Ampere / Ada
29
- "rtx4090": 24.0,
30
- "rtx4080": 16.0,
31
- "rtx4070ti": 12.0,
32
- "rtx4070": 12.0,
33
- "rtx4060ti": 8.0,
34
- "rtx4060": 8.0,
35
- "rtx3090": 24.0,
36
- "rtx3080": 10.0,
37
- "rtx3070": 8.0,
38
- "rtx3060": 12.0,
39
- "rtx2080ti": 11.0,
40
- "rtx2080": 8.0,
41
- # NVIDIA datacenter
42
- "a100": 80.0,
43
- "a10040gb": 40.0,
44
- "h100": 80.0,
45
- "v100": 32.0,
46
- "t4": 16.0,
47
- "a10": 24.0,
48
- # AMD
49
- "rx7900xtx": 24.0,
50
- "rx6800xt": 16.0,
51
- # Generic fallbacks
52
- "gpu": 8.0,
53
- "cpu": 0.0,
54
- "tpu": 0.0,
55
- "edge": 0.0,
56
- }
57
-
58
- # model.task → set of compatible dataset.task values
59
- TASK_COMPAT: dict[str, set[str]] = {
60
- "detection": {"detection"},
61
- "classification": {"classification"},
62
- "segmentation": {"segmentation"},
63
- "nlp": {"nlp"},
64
- "generation": {"generation"},
65
- "keypoints": {"keypoints", "detection"},
66
- "embedding": {"nlp", "classification"},
67
- }
68
-
69
- # dataset.format → set of model tasks it supports
70
- FORMAT_TASK_COMPAT: dict[str, set[str]] = {
71
- "yolo": {"detection", "segmentation", "keypoints"},
72
- "coco": {"detection", "segmentation", "keypoints"},
73
- "voc": {"detection"},
74
- "csv": {"classification"},
75
- "json": {"detection", "segmentation", "classification", "nlp", "generation"},
76
- "tfrecord": {"detection", "classification", "segmentation"},
77
- "custom": {"detection", "classification", "segmentation", "nlp", "generation", "keypoints"},
78
- }
79
-
80
- # model.framework → set of hardware targets (normalized) it can run on
81
- FRAMEWORK_HARDWARE_COMPAT: dict[str, set[str]] = {
82
- "pytorch": {
83
- "cpu", "gpu",
84
- "rtx4090", "rtx4080", "rtx4070ti", "rtx4070", "rtx4060ti", "rtx4060",
85
- "rtx3090", "rtx3080", "rtx3070", "rtx3060",
86
- "rtx2080ti", "rtx2080",
87
- "a100", "a10040gb", "h100", "v100", "t4", "a10",
88
- },
89
- "onnx": {
90
- "cpu", "gpu",
91
- "rtx4090", "rtx3090", "a100", "h100", "t4", "a10",
92
- "edge",
93
- },
94
- "tensorflow": {
95
- "cpu", "gpu",
96
- "rtx4090", "rtx3090", "a100", "h100", "v100", "t4",
97
- "tpu",
98
- },
99
- "tflite": {"cpu", "edge"},
100
- "coreml": {"cpu"},
101
- }
102
-
103
- # Precisions that require GPU
104
- _GPU_ONLY_PRECISIONS = {"FP16", "BF16"}
105
-
106
- # Frameworks supporting INT8 quantization
107
- _INT8_FRAMEWORKS = {"onnx", "tflite", "pytorch", "tensorflow"}
108
-
109
-
110
- class CompatibilityValidator:
111
- """
112
- Runs all compatibility gates before a benchmark job is created.
113
- Returns a ValidationReport — never raises exceptions.
114
- """
115
-
116
- def validate(
117
- self,
118
- model: Model,
119
- dataset: Dataset,
120
- ctx: BenchmarkContext,
121
- ) -> ValidationReport:
122
- checks: list[ValidationCheck] = [
123
- self._check_task(model, dataset),
124
- self._check_annotation_format(model, dataset),
125
- self._check_framework_hardware(model, ctx),
126
- self._check_vram(model, ctx),
127
- self._check_precision(model, ctx),
128
- ]
129
-
130
- errors = [c.detail for c in checks if not c.passed]
131
- warnings: list[str] = []
132
-
133
- log.info(
134
- "compatibility_validated",
135
- model_id = model.id,
136
- dataset_id = dataset.id,
137
- passed = len(errors) == 0,
138
- error_count = len(errors),
139
- )
140
-
141
- return ValidationReport(
142
- model_id = model.id,
143
- dataset_id = dataset.id,
144
- passed = len(errors) == 0,
145
- checks = checks,
146
- errors = errors,
147
- warnings = warnings,
148
- )
149
-
150
- # ── Gate A: Task ────────────────────��─────────────────────────────────────
151
-
152
- def _check_task(self, model: Model, dataset: Dataset) -> ValidationCheck:
153
- model_task = model.task.lower().strip()
154
- dataset_task = str(dataset.task).lower().strip()
155
-
156
- allowed = TASK_COMPAT.get(model_task, {model_task})
157
- if dataset_task in allowed:
158
- return ValidationCheck(
159
- name = "task_compatibility",
160
- passed = True,
161
- detail = (
162
- f"Model task '{model_task}' is compatible "
163
- f"with dataset task '{dataset_task}'"
164
- ),
165
- )
166
- return ValidationCheck(
167
- name = "task_compatibility",
168
- passed = False,
169
- detail = (
170
- f"Model task '{model_task}' cannot evaluate "
171
- f"a '{dataset_task}' dataset"
172
- ),
173
- suggestion = (
174
- f"Select a model with task='{dataset_task}', "
175
- f"or choose a dataset with task='{model_task}'"
176
- ),
177
- )
178
-
179
- # ── Gate B: Annotation Format ─────────────────────────────────────────────
180
-
181
- def _check_annotation_format(self, model: Model, dataset: Dataset) -> ValidationCheck:
182
- dataset_fmt = str(dataset.format).lower().strip()
183
- model_task = model.task.lower().strip()
184
- supported = FORMAT_TASK_COMPAT.get(dataset_fmt, set())
185
-
186
- if model_task in supported:
187
- return ValidationCheck(
188
- name = "annotation_format",
189
- passed = True,
190
- detail = (
191
- f"Dataset format '{dataset_fmt}' supports "
192
- f"model task '{model_task}'"
193
- ),
194
- )
195
-
196
- if model_task in {"detection", "segmentation", "keypoints"}:
197
- suggestion = (
198
- f"Convert dataset to YOLO or COCO format β€” both support '{model_task}'"
199
- )
200
- elif model_task == "classification":
201
- suggestion = "Convert dataset to CSV or JSON format for classification tasks"
202
- else:
203
- suggestion = f"Use a JSON or custom-format dataset for '{model_task}' tasks"
204
-
205
- return ValidationCheck(
206
- name = "annotation_format",
207
- passed = False,
208
- detail = (
209
- f"Dataset format '{dataset_fmt}' does not support "
210
- f"model task '{model_task}'"
211
- ),
212
- suggestion = suggestion,
213
- )
214
-
215
- # ── Gate C: Framework Γ— Hardware ─────────────────────────────────────────
216
-
217
- def _check_framework_hardware(
218
- self, model: Model, ctx: BenchmarkContext
219
- ) -> ValidationCheck:
220
- framework = model.framework.lower().strip()
221
- hw_raw = ctx.hardware
222
- hw_key = self._normalize_hw(hw_raw)
223
-
224
- supported_hw = FRAMEWORK_HARDWARE_COMPAT.get(framework, {"cpu"})
225
-
226
- # Match: exact key, or generic "gpu" bucket covers any named GPU
227
- hw_ok = (
228
- hw_key in supported_hw
229
- or ("gpu" in supported_hw and hw_key not in {"cpu", "tpu", "edge"})
230
- )
231
-
232
- if hw_ok:
233
- return ValidationCheck(
234
- name = "framework_hardware",
235
- passed = True,
236
- detail = f"Framework '{framework}' is supported on '{hw_raw}'",
237
- )
238
- return ValidationCheck(
239
- name = "framework_hardware",
240
- passed = False,
241
- detail = (
242
- f"Framework '{framework}' cannot run on '{hw_raw}'. "
243
- f"Supported targets: {', '.join(sorted(supported_hw))}"
244
- ),
245
- suggestion = (
246
- "Use ONNX runtime for broadest hardware support, "
247
- f"or pick a device from: {', '.join(sorted(supported_hw))}"
248
- ),
249
- )
250
-
251
- # ── Gate D: VRAM Constraint ───────────────────────────────────────────────
252
-
253
- def _check_vram(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
254
- hw_key = self._normalize_hw(ctx.hardware)
255
- available = self._lookup_vram(hw_key)
256
-
257
- if available == 0.0:
258
- return ValidationCheck(
259
- name = "vram_constraint",
260
- passed = True,
261
- detail = f"Running on '{ctx.hardware}' (CPU/TPU/Edge) β€” no VRAM constraint",
262
- )
263
-
264
- # Estimate: weights at given precision + activations for one batch
265
- model_gb = max(model.size, 1) / (1024 ** 3)
266
- prec_map = {"FP16": 0.5, "BF16": 0.5, "INT8": 0.25, "FP32": 1.0}
267
- prec_mult = prec_map.get(ctx.precision.upper(), 1.0)
268
- # weights × precision + ~20% for optimizer/activation buffers + batch overhead
269
- estimated = (model_gb * prec_mult * 1.2) + (ctx.batch_size * 0.05)
270
-
271
- if estimated <= available:
272
- return ValidationCheck(
273
- name = "vram_constraint",
274
- passed = True,
275
- detail = (
276
- f"Estimated VRAM {estimated:.2f} GB ≀ "
277
- f"available {available:.1f} GB on '{ctx.hardware}'"
278
- ),
279
- )
280
- return ValidationCheck(
281
- name = "vram_constraint",
282
- passed = False,
283
- detail = (
284
- f"Estimated VRAM {estimated:.2f} GB exceeds "
285
- f"available {available:.1f} GB on '{ctx.hardware}'"
286
- ),
287
- suggestion = (
288
- f"Try: reduce batch_size (now {ctx.batch_size}), "
289
- f"switch to FP16/INT8 precision, "
290
- f"or use a GPU with β‰₯ {estimated:.1f} GB VRAM"
291
- ),
292
- )
293
-
294
- # ── Gate E: Precision Support ─────────────────────────────────────────────
295
-
296
- def _check_precision(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
297
- precision = ctx.precision.upper()
298
- framework = model.framework.lower().strip()
299
- hw_key = self._normalize_hw(ctx.hardware)
300
- is_gpu = hw_key not in {"cpu", "tpu", "edge"}
301
-
302
- if precision in _GPU_ONLY_PRECISIONS and not is_gpu:
303
- return ValidationCheck(
304
- name = "precision_support",
305
- passed = False,
306
- detail = (
307
- f"Precision '{precision}' requires a CUDA GPU; "
308
- f"'{ctx.hardware}' does not support it"
309
- ),
310
- suggestion = "Use FP32 for CPU inference, or switch to a compatible GPU",
311
- )
312
-
313
- if precision == "INT8" and framework not in _INT8_FRAMEWORKS:
314
- return ValidationCheck(
315
- name = "precision_support",
316
- passed = False,
317
- detail = (
318
- f"Framework '{framework}' does not support INT8 quantization"
319
- ),
320
- suggestion = (
321
- "Convert model to ONNX or use PyTorch with torch.quantization"
322
- ),
323
- )
324
-
325
- return ValidationCheck(
326
- name = "precision_support",
327
- passed = True,
328
- detail = (
329
- f"Precision '{precision}' is valid for "
330
- f"framework '{framework}' on '{ctx.hardware}'"
331
- ),
332
- )
333
-
334
- # ── Helpers ───────────────────────────────────────────────────────────────
335
-
336
- @staticmethod
337
- def _normalize_hw(hardware: str) -> str:
338
- """Lowercase, strip spaces/dashes/underscores for lookup."""
339
- return (
340
- hardware.lower()
341
- .replace(" ", "")
342
- .replace("-", "")
343
- .replace("_", "")
344
- .replace("nvidia", "")
345
- .replace("geforce", "")
346
- )
347
-
348
- @staticmethod
349
- def _lookup_vram(hw_key: str) -> float:
350
- """Return VRAM GB for a normalized hardware key, with fallback matching."""
351
- if hw_key in HARDWARE_VRAM_GB:
352
- return HARDWARE_VRAM_GB[hw_key]
353
- # Partial match (e.g. "rtx4090laptop" → "rtx4090")
354
- for key, vram in HARDWARE_VRAM_GB.items():
355
- if key and key in hw_key:
356
- return vram
357
- # Anything that looks like a GPU but isn't in the table
358
- if "gpu" in hw_key or "rtx" in hw_key or "gtx" in hw_key or "cuda" in hw_key:
359
- return HARDWARE_VRAM_GB["gpu"]
360
- return 0.0 # CPU / unknown → no VRAM constraint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/execution.py DELETED
@@ -1,366 +0,0 @@
1
- """
2
- benchmark/execution.py — Benchmark Execution Engine.
3
-
4
- Drives the batch inference loop, collecting latencies and VRAM readings.
5
- Calls TelemetryCollector in parallel with batch processing.
6
- Yields progress callbacks so the orchestrator can persist real-time state.
7
-
8
- Adapter pattern: swap _run_single_batch() with a real inference call
9
- (torch.cuda.synchronize + model(batch)) once GPU runtime is wired up.
10
-
11
- PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
12
- """
13
- from __future__ import annotations
14
-
15
- import asyncio
16
- import math
17
- import random
18
- from dataclasses import dataclass, field
19
- from typing import Awaitable, Callable
20
-
21
- from benchmark.compatibility import HARDWARE_VRAM_GB
22
- from benchmark.telemetry import TelemetryCollector
23
- from models.benchmark import BenchmarkJob, LayerBreakdown, TelemetrySample, TelemetrySummary
24
- from models.dataset import Dataset
25
- from models.model import Model
26
- from observability.logger import get_logger
27
-
28
- log = get_logger("benchmark.execution")
29
-
30
-
31
- # ── Per-image latency profiles (ms at batch=1, fp32) ─────────────────────────
32
- _LATENCY_MS_PER_IMAGE: dict[str, float] = {
33
- "rtx4090": 1.8,
34
- "rtx4080": 2.5,
35
- "rtx4070ti": 3.2,
36
- "rtx4070": 3.8,
37
- "rtx3090": 3.0,
38
- "rtx3080": 4.5,
39
- "rtx3070": 6.5,
40
- "rtx3060": 9.0,
41
- "rtx2080ti": 5.0,
42
- "rtx2080": 7.5,
43
- "a100": 1.2,
44
- "h100": 0.7,
45
- "v100": 2.8,
46
- "t4": 5.5,
47
- "a10": 3.5,
48
- "gpu": 8.0,
49
- "cpu": 42.0,
50
- }
51
-
52
- # Precision speedup multipliers (relative to FP32)
53
- _PRECISION_SPEEDUP: dict[str, float] = {
54
- "FP32": 1.0,
55
- "FP16": 1.8,
56
- "BF16": 1.7,
57
- "INT8": 2.5,
58
- }
59
-
60
- # Task-specific baseline metric scores (pre-jitter)
61
- _TASK_BASELINES: dict[str, dict[str, float]] = {
62
- "detection": {"mAP": 0.435, "mAP_50": 0.618, "mAP_50_95": 0.435},
63
- "classification": {"accuracy": 0.872, "top5": 0.968},
64
- "segmentation": {"mAP": 0.372, "iou_mean": 0.706},
65
- "keypoints": {"mAP": 0.641, "mAP_50": 0.860},
66
- "nlp": {"accuracy": 0.891},
67
- "generation": {"accuracy": 0.780},
68
- }
69
-
70
- # Cap simulated batches so large datasets don't stall the event loop
71
- _MAX_SIMULATED_BATCHES = 250
72
-
73
-
74
- @dataclass
75
- class ExecutionResult:
76
- """Raw output from the execution engine, consumed by MetricsEngine."""
77
- latencies_ms: list[float]
78
- total_images: int
79
- vram_samples: list[float]
80
- task_scores: dict[str, float]
81
- telemetry_samples: list[TelemetrySample] = field(default_factory=list)
82
- telemetry_summary: TelemetrySummary = field(default_factory=TelemetrySummary)
83
-
84
-
85
- # Progress callback type: (progress_0_to_1, message, last_telemetry) → None
86
- ProgressCallback = Callable[[float, str, TelemetrySample | None], Awaitable[None]]
87
-
88
-
89
- class BenchmarkExecutor:
90
- """
91
- Drives the benchmark execution loop.
92
- Non-blocking: all sleeps are asyncio.sleep so other coroutines run freely.
93
- """
94
-
95
- async def execute(
96
- self,
97
- job: BenchmarkJob,
98
- model: Model,
99
- dataset: Dataset,
100
- on_progress: ProgressCallback,
101
- ) -> ExecutionResult:
102
- hw = job.hardware
103
- batch_sz = job.batch_size
104
-
105
- # Handle polymorphic input duration
106
- is_live = getattr(job, "input_source", "dataset") in ("video", "live")
107
-
108
- if is_live:
109
- # For live/video, we run for a fixed duration or until stopped
110
- # Increase limit for a longer session (e.g., 10,000 batches)
111
- total_img = 10000 * batch_sz
112
- n_batches = 10000
113
- sim_batches = 10000
114
- else:
115
- total_img = max(dataset.images, 100) # floor so simulation always runs
116
- n_batches = math.ceil(total_img / batch_sz)
117
- sim_batches = min(n_batches, _MAX_SIMULATED_BATCHES)
118
-
119
- vram_total = self._get_vram_gb(hw, model)
120
- vram_frac = self._vram_usage_fraction(hw)
121
-
122
- telemetry = TelemetryCollector(hw, vram_total_gb=vram_total)
123
- await telemetry.start()
124
-
125
- latencies: list[float] = []
126
- vram_samples: list[float] = []
127
-
128
- base_lat_ms = self._base_batch_latency_ms(hw, model, batch_sz, job.precision)
129
-
130
- # Resolve real model path once (None → use simulation)
131
- real_model_path = model.local_path if model.local_path and model.downloaded else None
132
- use_real_inference = self._check_torch_available() and real_model_path is not None
133
- loop = asyncio.get_event_loop()
134
-
135
- try:
136
- for sim_idx in range(sim_batches):
137
- # Map simulated index back to real batch index
138
- real_idx = int(sim_idx * (n_batches / sim_batches))
139
-
140
- if use_real_inference:
141
- # Real GPU inference via torch_runner (runs in thread executor)
142
- try:
143
- from benchmark.torch_runner import run_torch_batch
144
- batch_lat_ms = await loop.run_in_executor(
145
- None,
146
- run_torch_batch,
147
- real_model_path,
148
- batch_sz,
149
- job.task,
150
- )
151
- # Add a tiny sleep to prevent event loop starvation in live mode
152
- if is_live:
153
- await asyncio.sleep(0.001)
154
- except Exception as exc:
155
- log.warning("torch_inference_failed_fallback", error=str(exc))
156
- use_real_inference = False # fall back for remaining batches
157
- batch_lat_ms = max(
158
- 0.5, base_lat_ms + random.gauss(0, base_lat_ms * 0.07)
159
- )
160
- else:
161
- # Simulation path — non-blocking synthetic latency
162
- batch_lat_ms = max(
163
- 0.5,
164
- base_lat_ms + random.gauss(0, base_lat_ms * 0.07),
165
- )
166
- await asyncio.sleep(batch_lat_ms / 1000.0) # non-blocking
167
-
168
- latencies.append(batch_lat_ms)
169
- vram_used = vram_total * random.uniform(
170
- vram_frac - 0.05, vram_frac + 0.05
171
- )
172
- vram_samples.append(max(0.0, vram_used))
173
-
174
- progress = (sim_idx + 1) / sim_batches
175
- telemetry.record_batch_context(real_idx, progress)
176
-
177
- # Throttle callbacks: every 5 batches or first/last
178
- if sim_idx % 5 == 0 or sim_idx == sim_batches - 1:
179
- images_done = int(progress * total_img)
180
-
181
- # Generate simulated detection data for live preview if it's a vision task
182
- live_data = {}
183
- if job.task.lower() in ("detection", "segmentation"):
184
- # Use provided bbox telemetry if available (e.g. from real inference)
185
- # otherwise generate simulated ones
186
- live_data["detections"] = [
187
- {
188
- "x": random.uniform(0.1, 0.7),
189
- "y": random.uniform(0.1, 0.7),
190
- "width": random.uniform(0.1, 0.3),
191
- "height": random.uniform(0.1, 0.3),
192
- "label": random.choice(["person", "car", "bicycle", "dog"]),
193
- "confidence": random.uniform(0.5, 0.99)
194
- }
195
- for _ in range(random.randint(1, 5))
196
- ]
197
-
198
- last_sample = telemetry.samples[-1] if telemetry.samples else None
199
- if last_sample:
200
- last_sample.live_data = live_data
201
- # Explicitly broadcast detections for the visualizer
202
- last_sample.detections = live_data.get("detections", [])
203
-
204
- await on_progress(
205
- progress,
206
- f"Batch {real_idx+1}/{n_batches} β€” "
207
- f"{images_done}/{total_img} images processed",
208
- last_sample,
209
- )
210
-
211
- finally:
212
- telemetry_summary = await telemetry.stop()
213
- # Attach simulated layer breakdown so Live Lab can display it
214
- telemetry_summary.layer_breakdown = self._compute_layer_breakdown(
215
- job.task, base_lat_ms
216
- )
217
-
218
- task_scores = self._simulate_task_scores(job.task, model, dataset)
219
-
220
- log.info(
221
- "execution_complete",
222
- job_id = job.id,
223
- total_images = total_img,
224
- sim_batches = sim_batches,
225
- avg_lat_ms = round(sum(latencies) / len(latencies), 2) if latencies else 0,
226
- )
227
-
228
- return ExecutionResult(
229
- latencies_ms = latencies,
230
- total_images = total_img,
231
- vram_samples = vram_samples,
232
- task_scores = task_scores,
233
- telemetry_samples = telemetry.samples,
234
- telemetry_summary = telemetry_summary,
235
- )
236
-
237
- # ── Helpers ───────────────────────────────────────────────────────────────
238
-
239
- def _base_batch_latency_ms(
240
- self,
241
- hardware: str,
242
- model: Model,
243
- batch_sz: int,
244
- precision: str,
245
- ) -> float:
246
- """
247
- Estimate per-batch latency in ms.
248
- Accounts for hardware tier, model size, batch size, and precision.
249
- """
250
- hw_key = self._normalize_hw(hardware)
251
- per_img = self._lookup_latency(hw_key)
252
-
253
- # Larger models are slower: +30% per GB of model weights
254
- size_gb = max(model.size, 1) / (1024 ** 3)
255
- size_factor = 1.0 + size_gb * 0.30
256
-
257
- # Batch parallelism: ~65% linear efficiency on GPU, 90% on CPU
258
- eff = 0.65 if "cpu" not in hw_key else 0.90
259
- batch_lat = per_img * size_factor * batch_sz * eff
260
-
261
- # Precision speedup
262
- speedup = _PRECISION_SPEEDUP.get(precision.upper(), 1.0)
263
-
264
- return batch_lat / speedup
265
-
266
- def _get_vram_gb(self, hardware: str, model: Model) -> float:
267
- hw_key = self._normalize_hw(hardware)
268
- for key, vram in HARDWARE_VRAM_GB.items():
269
- if key and key in hw_key:
270
- return vram
271
- return 8.0
272
-
273
- @staticmethod
274
- def _vram_usage_fraction(hardware: str) -> float:
275
- """Fraction of VRAM typically consumed during inference."""
276
- hw = hardware.lower()
277
- if any(x in hw for x in ("4090", "3090", "a100", "h100")):
278
- return 0.62
279
- if any(x in hw for x in ("4080", "3080", "v100", "a10")):
280
- return 0.60
281
- if "cpu" in hw:
282
- return 0.0
283
- return 0.55
284
-
285
- @staticmethod
286
- def _simulate_task_scores(
287
- task: str, model: Model, dataset: Dataset
288
- ) -> dict[str, float]:
289
- """
290
- Produce realistic metric scores with small per-run variance.
291
-
292
- PRODUCTION SWAP: replace with actual metric computation:
293
- from torchmetrics.detection import MeanAveragePrecision
294
- metric = MeanAveragePrecision()
295
- metric.update(predictions, targets)
296
- return metric.compute()
297
- """
298
- baselines = dict(_TASK_BASELINES.get(task.lower(), {"accuracy": 0.80}))
299
- # Small Gaussian jitter simulates run-to-run variance
300
- return {
301
- k: float(max(0.0, min(1.0, v + random.gauss(0, 0.015))))
302
- for k, v in baselines.items()
303
- }
304
-
305
- @staticmethod
306
- def _check_torch_available() -> bool:
307
- """Return True if PyTorch is installed and importable."""
308
- try:
309
- import torch # noqa: F401
310
- return True
311
- except ImportError:
312
- return False
313
-
314
- @staticmethod
315
- def _compute_layer_breakdown(task: str, base_lat_ms: float) -> list[LayerBreakdown]:
316
- """Build a realistic layer breakdown for the given task.
317
-
318
- Splits total latency across architectural stages with small jitter.
319
- PRODUCTION SWAP: replace with actual profiler data (e.g. torch.profiler).
320
- """
321
- if task.lower() in ("detection", "segmentation"):
322
- stages = [
323
- ("Backbone", 0.45),
324
- ("Neck (FPN/PAFPN)", 0.30),
325
- ("Detection Head", 0.20),
326
- ("NMS Post-process", 0.05),
327
- ]
328
- elif task.lower() == "classification":
329
- stages = [
330
- ("Feature Extractor", 0.70),
331
- ("Classifier Head", 0.20),
332
- ("Softmax", 0.10),
333
- ]
334
- else:
335
- stages = [
336
- ("Encoder", 0.55),
337
- ("Decoder / Head", 0.35),
338
- ("Post-process", 0.10),
339
- ]
340
-
341
- result: list[LayerBreakdown] = []
342
- remaining = base_lat_ms
343
- for name, frac in stages:
344
- t = round(base_lat_ms * frac + random.gauss(0, base_lat_ms * 0.01), 3)
345
- result.append(LayerBreakdown(name=name, time_ms=t, percent=round(frac * 100, 1)))
346
- return result
347
-
348
- @staticmethod
349
- def _normalize_hw(hardware: str) -> str:
350
- return (
351
- hardware.lower()
352
- .replace(" ", "")
353
- .replace("-", "")
354
- .replace("_", "")
355
- .replace("nvidia", "")
356
- .replace("geforce", "")
357
- )
358
-
359
- @staticmethod
360
- def _lookup_latency(hw_key: str) -> float:
361
- for key, ms in _LATENCY_MS_PER_IMAGE.items():
362
- if key and key in hw_key:
363
- return ms
364
- if any(x in hw_key for x in ("gpu", "rtx", "gtx", "cuda")):
365
- return _LATENCY_MS_PER_IMAGE["gpu"]
366
- return _LATENCY_MS_PER_IMAGE["cpu"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/metrics.py DELETED
@@ -1,110 +0,0 @@
1
- """
2
- benchmark/metrics.py — Metrics Engine.
3
-
4
- Computes the final BenchmarkMetrics object from raw execution data:
5
- - Latency statistics (mean, p95, p99)
6
- - Throughput (FPS)
7
- - VRAM statistics (avg, peak)
8
- - Task-specific scores (mAP, accuracy, IoU) supplied by the executor
9
-
10
- In a production deployment the task_scores dict comes from actual
11
- metric computation (e.g. pycocotools, torchmetrics). In this local-first
12
- build the executor supplies realistic simulated scores.
13
- """
14
- from __future__ import annotations
15
-
16
- import statistics
17
-
18
- from models.benchmark import BenchmarkMetrics, LayerBreakdown, TelemetrySummary
19
- from observability.logger import get_logger
20
-
21
- log = get_logger("benchmark.metrics")
22
-
23
-
24
- class MetricsEngine:
25
- """Computes BenchmarkMetrics from raw benchmark execution data."""
26
-
27
- def compute(
28
- self,
29
- *,
30
- task: str,
31
- latencies_ms: list[float], # per-batch latencies
32
- total_images: int = 0,
33
- total_tokens: int = 0,
34
- batch_size: int,
35
- vram_samples: list[float], # VRAM readings (GB) during run
36
- task_scores: dict[str, float], # task-specific metric scores
37
- ) -> BenchmarkMetrics:
38
- if not latencies_ms:
39
- return BenchmarkMetrics(total_images=total_images, total_tokens=total_tokens, batch_size=batch_size)
40
-
41
- total_time_s = sum(latencies_ms) / 1000.0
42
- fps = total_images / total_time_s if total_time_s > 0 and total_images > 0 else 0.0
43
- tps = total_tokens / total_time_s if total_time_s > 0 and total_tokens > 0 else 0.0
44
-
45
- lat_mean = statistics.mean(latencies_ms)
46
- lat_p95 = _percentile(latencies_ms, 0.95)
47
- lat_p99 = _percentile(latencies_ms, 0.99)
48
-
49
- vram_peak = max(vram_samples) if vram_samples else 0.0
50
- vram_avg = statistics.mean(vram_samples) if vram_samples else 0.0
51
-
52
- m = BenchmarkMetrics(
53
- fps = round(fps, 2),
54
- tokens_per_sec = round(tps, 2),
55
- latency_mean_ms = round(lat_mean, 3),
56
- latency_p95_ms = round(lat_p95, 3),
57
- latency_p99_ms = round(lat_p99, 3),
58
- vram_peak_gb = round(vram_peak, 3),
59
- vram_avg_gb = round(vram_avg, 3),
60
- total_images = total_images,
61
- total_tokens = total_tokens,
62
- batch_size = batch_size,
63
- )
64
-
65
- task_lower = task.lower()
66
-
67
- # CV Task Mapping
68
- if task_lower in ("detection", "segmentation", "keypoints"):
69
- m.mAP = _fmt(task_scores.get("mAP", 0.0))
70
- m.mAP_50 = _fmt(task_scores.get("mAP_50", 0.0))
71
- m.mAP_50_95 = _fmt(task_scores.get("mAP_50_95", 0.0))
72
- if task_lower == "segmentation":
73
- m.iou_mean = _fmt(task_scores.get("iou_mean", 0.0))
74
-
75
- elif task_lower == "classification":
76
- m.accuracy = _fmt(task_scores.get("accuracy", 0.0))
77
- m.top1 = _fmt(task_scores.get("top1", 0.0))
78
- m.top5 = _fmt(task_scores.get("top5", 0.0))
79
-
80
- # NLP Task Mapping (ROUGE, BLEU, Perplexity)
81
- elif task_lower in ("nlp", "generation"):
82
- m.accuracy = _fmt(task_scores.get("accuracy", 0.0))
83
- m.rouge_l = _fmt(task_scores.get("rouge_l", task_scores.get("rougeL", 0.0)))
84
- m.bleu = _fmt(task_scores.get("bleu", 0.0))
85
- m.perplexity = task_scores.get("perplexity")
86
-
87
- log.info(
88
- "metrics_computed",
89
- task = task,
90
- fps = m.fps,
91
- tps = m.tokens_per_sec,
92
- latency_ms = m.latency_mean_ms,
93
- vram_peak = m.vram_peak_gb,
94
- )
95
- return m
96
-
97
-
98
- # ── Helpers ───────────────────────────────────────────────────────────────────
99
-
100
- def _percentile(data: list[float], p: float) -> float:
101
- if not data:
102
- return 0.0
103
- s = sorted(data)
104
- idx = min(int(len(s) * p), len(s) - 1)
105
- return s[idx]
106
-
107
-
108
- def _fmt(v: float) -> float:
109
- """Round to 4dp and clamp to [0, 1]."""
110
- return round(max(0.0, min(1.0, v)), 4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/orchestrator.py DELETED
@@ -1,374 +0,0 @@
1
- """
2
- benchmark/orchestrator.py — Benchmark Orchestrator (Main Controller).
3
-
4
- Coordinates the full benchmark lifecycle:
5
- 1. Resolve model + dataset from their registries
6
- 2. Run all compatibility checks (gates A–E)
7
- 3. If valid → create a BenchmarkJob in the DB
8
- 4. Persist the validation audit log
9
- 5. Enqueue async background task → execution → metrics → storage
10
- 6. Return the job immediately so callers are non-blocking
11
-
12
- Public interface used by api/routes/benchmark.py:
13
- validate_context(ctx) → ValidationReport (no job created)
14
- create_and_run(ctx) → BenchmarkJob (job queued, execution in background)
15
- """
16
- from __future__ import annotations
17
-
18
- import asyncio
19
- from datetime import datetime, timezone
20
-
21
- from benchmark.adapters.registry import get_executor
22
- from benchmark.compatibility import CompatibilityValidator
23
- from benchmark.execution import BenchmarkExecutor
24
- from benchmark.metrics import MetricsEngine
25
- import benchmark.registry as bench_reg
26
- from datasets.registry import get_dataset
27
- from models.benchmark import (
28
- BenchmarkContext,
29
- BenchmarkJob,
30
- BenchmarkMetrics,
31
- TelemetrySummary,
32
- ValidationReport,
33
- )
34
- from models.dataset import Dataset
35
- from models.model import Model
36
- from observability.logger import audit, get_logger
37
- from registry.registry import get_model
38
-
39
- log = get_logger("benchmark.orchestrator")
40
-
41
# One validator and one metrics engine are created at import time and reused
# by every call in this module (assumed stateless - TODO confirm).
_validator = CompatibilityValidator()
_metrics = MetricsEngine()

# Maps job_id to the asyncio.Task running that job (kept for future
# cancellation support).
_active_tasks: dict[str, asyncio.Task] = {}
47
-
48
-
49
- # ── Public API ────────────────────────────────────────────────────────────────
50
-
51
async def sync_project_benchmarks() -> int:
    """
    Index benchmark JSON files from the active project's 'benchmarks' folder.

    Files named ``job_*.json`` are inserted into ``benchmark_jobs`` and files
    named ``result_*.json`` into ``benchmark_results``. Both statements use
    ``INSERT OR IGNORE``, so a row that violates a constraint (for example an
    id that is already present) is skipped by SQLite. A file that cannot be
    read or inserted is logged and skipped.

    Returns:
        Number of files submitted for insertion (rows ignored by SQLite are
        counted too), or 0 when no benchmark folder is available.
    """
    from benchmark.registry import _get_active_project_benchmark_dir_sync
    from projects.service import get_active_project_path
    import json
    from database.connection import get_db

    project_path = await get_active_project_path()
    benchmark_dir = _get_active_project_benchmark_dir_sync(project_path)
    if not benchmark_dir or not benchmark_dir.exists():
        return 0

    db = await get_db()
    count = 0

    for file_path in benchmark_dir.glob("*.json"):
        try:
            # Explicit encoding so decoding does not depend on the host locale.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Fallback for files that carry no created_at / updated_at.
            now = datetime.now(timezone.utc).isoformat()

            if file_path.name.startswith("job_"):
                await db.execute(
                    """INSERT OR IGNORE INTO benchmark_jobs
                    (id, model_id, dataset_id, task, framework, hardware,
                    precision, batch_size, config, status, progress, created_at, updated_at, started_at)
                    VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                    (
                        data["id"], data["model_id"], data["dataset_id"],
                        data["task"], data["framework"], data["hardware"],
                        data["precision"], data["batch_size"],
                        json.dumps(data["config"]), data["status"],
                        data.get("progress", 0.0),
                        data.get("created_at", now),
                        data.get("updated_at", now),
                        data.get("started_at"),
                    ),
                )
                count += 1
            elif file_path.name.startswith("result_"):
                await db.execute(
                    """INSERT OR IGNORE INTO benchmark_results
                    (id, job_id, metrics, telemetry_summary, created_at)
                    VALUES (?,?,?,?,?)""",
                    (
                        data["id"], data["job_id"],
                        json.dumps(data["metrics"]),
                        json.dumps(data["telemetry_summary"]),
                        data.get("created_at", now),
                    ),
                )
                count += 1
        except Exception as e:
            # Best-effort sync: one unreadable file must not abort the rest.
            log.error("sync_file_failed", file=file_path.name, error=str(e))

    await db.commit()
    log.info("sync_complete", count=count)
    return count
115
-
116
async def validate_context(ctx: BenchmarkContext) -> ValidationReport:
    """
    Run the compatibility validator for a benchmark request.

    No job is created. When the input source is "video" or "live", or the
    dataset id is "none", a placeholder Dataset that copies the model's task
    is validated instead of a stored dataset.
    """
    model = await _require_model(ctx.model_id)

    streams_input = ctx.input_source in ("video", "live") or ctx.dataset_id == "none"
    if not streams_input:
        stored = await _require_dataset(ctx.dataset_id)
        return _validator.validate(model, stored, ctx)

    # No stored dataset for this source: build a stand-in whose task mirrors
    # the model so the task comparison has something to check against.
    stamp = datetime.now(timezone.utc).isoformat()
    placeholder = Dataset(
        id="none",
        name="Live/Video Stream",
        task=model.task,
        format="custom",
        source="local",
        status="imported",
        images=0,
        classes=0,
        size_label="0 MB",
        created_at=stamp,
        updated_at=stamp,
    )
    return _validator.validate(model, placeholder, ctx)
144
-
145
-
146
async def create_and_run(ctx: BenchmarkContext) -> BenchmarkJob:
    """
    Validate a benchmark request, create its job and start it in the background.

    The validation outcome is logged under the placeholder job id "pre-check"
    before the pass/fail decision. On success the same checks are logged a
    second time under the real job id.

    Returns:
        The newly created job; execution continues in an asyncio task.

    Raises:
        HTTPException: 422 listing the failed checks when validation fails.
    """
    model = await _require_model(ctx.model_id)

    streams_input = ctx.input_source in ("video", "live") or ctx.dataset_id == "none"
    if streams_input:
        # Placeholder dataset for sources that have no stored dataset.
        stamp = datetime.now(timezone.utc).isoformat()
        dataset = Dataset(
            id="none",
            name="Live/Video Stream",
            task=model.task,
            format="custom",
            source="local",
            status="imported",
            images=0,
            classes=0,
            size_label="0 MB",
            created_at=stamp,
            updated_at=stamp,
        )
    else:
        dataset = await _require_dataset(ctx.dataset_id)

    report = _validator.validate(model, dataset, ctx)

    # Logged before the pass/fail branch so failures are recorded as well.
    await bench_reg.save_validation_log(
        job_id="pre-check",
        model_id=ctx.model_id,
        dataset_id=ctx.dataset_id,
        checks=report.checks,
        passed=report.passed,
    )

    if not report.passed:
        from fastapi import HTTPException

        failed_checks = []
        for check in report.checks:
            if check.passed:
                continue
            failed_checks.append(
                {
                    "name": check.name,
                    "detail": check.detail,
                    "suggestion": check.suggestion,
                }
            )
        raise HTTPException(
            status_code=422,
            detail={
                "error": "Compatibility validation failed",
                "failed_checks": failed_checks,
            },
        )

    job = await bench_reg.create_job(ctx)

    # Second log entry, this time tied to the real job id.
    await bench_reg.save_validation_log(
        job_id=job.id,
        model_id=ctx.model_id,
        dataset_id=ctx.dataset_id,
        checks=report.checks,
        passed=True,
    )

    if ctx.input_source or ctx.video_path or ctx.rtsp_url:
        log.info(
            "polymorphic_input_received",
            job_id=job.id,
            source=ctx.input_source,
            video=ctx.video_path,
            rtsp=ctx.rtsp_url,
        )

    def _forget_task(_finished: asyncio.Task) -> None:
        # Drop the bookkeeping entry once the task is done.
        _active_tasks.pop(job.id, None)

    runner_task = asyncio.create_task(
        _execute_job(job.id, ctx, model, dataset),
        name=f"benchmark_{job.id}",
    )
    _active_tasks[job.id] = runner_task
    runner_task.add_done_callback(_forget_task)

    log.info("benchmark_enqueued", job_id=job.id, model=ctx.model_id)
    return job
231
-
232
-
233
- # ── Background execution ──────────────────────────────────────────────────────
234
-
235
async def _execute_job(
    job_id: str,
    ctx: BenchmarkContext,
    model: Model,
    dataset: Dataset,
) -> None:
    """
    Run one benchmark job end to end inside an asyncio background task.

    Marks the job running, executes it through BenchmarkExecutor while
    streaming progress into the job row, computes metrics, saves the result
    and marks the job completed. Any other exception marks the job failed and
    is logged and audited. Cancellation marks the job failed and is re-raised.
    """
    from typing import Any  # only needed for the progress-callback annotation

    now = datetime.now(timezone.utc).isoformat()

    # ANSI colour codes embedded in the log lines stored on the job.
    ts_color = "\x1b[36m"  # Cyan
    info_color = "\x1b[34m"  # Blue
    success_color = "\x1b[32m"  # Green
    error_color = "\x1b[31m"  # Red
    reset = "\x1b[0m"

    await bench_reg.update_job(
        job_id,
        status="running",
        progress=0.0,
        started_at=now,
        log_entry=f"{ts_color}[{now}]{reset} {info_color}Job started{reset} on {ctx.hardware} ({ctx.precision})",
    )

    runner = BenchmarkExecutor()

    try:
        job = await bench_reg.get_job(job_id)
        if job is None:
            # Explicit raise instead of assert: asserts are stripped under -O,
            # which would hand job=None to the executor.
            raise RuntimeError("Job disappeared from DB after creation")

        async def on_progress(progress: float, message: str, telemetry: Any | None):
            # Telemetry objects exposing model_dump() are stored as dicts;
            # anything else is passed through unchanged.
            await bench_reg.update_job(
                job_id,
                progress=progress,
                log_entry=f"{ts_color}[{datetime.now(timezone.utc).isoformat()}]{reset} {info_color}{message}{reset}",
                last_telemetry=telemetry.model_dump() if telemetry and hasattr(telemetry, "model_dump") else telemetry,
            )

        exec_result = await runner.execute(
            job=job,
            model=model,
            dataset=dataset,
            on_progress=on_progress,
        )

        metrics = _metrics.compute(
            task=ctx.task,
            latencies_ms=exec_result.latencies_ms,
            total_images=exec_result.total_images,
            batch_size=ctx.batch_size,
            vram_samples=exec_result.vram_samples,
            task_scores=exec_result.task_scores,
        )

        await bench_reg.save_result(
            job_id=job_id,
            metrics=metrics,
            telemetry_summary=exec_result.telemetry_summary,
        )

        ended = datetime.now(timezone.utc).isoformat()
        await bench_reg.update_job(
            job_id,
            status="completed",
            progress=1.0,
            ended_at=ended,
            log_entry=f"{ts_color}[{ended}]{reset} {success_color}Benchmark completed{reset} — {metrics.fps} FPS",
        )

        await audit(
            "benchmark_completed",
            job_id=job_id,
            payload={"model_id": ctx.model_id, "dataset_id": ctx.dataset_id},
        )
        log.info(
            "benchmark_completed",
            job_id=job_id,
            fps=metrics.fps,
            lat_ms=metrics.latency_mean_ms,
        )

    except asyncio.CancelledError:
        # Record the cancellation, then re-raise so the task still ends as
        # cancelled instead of swallowing the signal.
        ended = datetime.now(timezone.utc).isoformat()
        await bench_reg.update_job(
            job_id,
            status="failed",
            error="Job cancelled",
            ended_at=ended,
            log_entry=f"{ts_color}[{ended}]{reset} {error_color}Job cancelled{reset}",
        )
        raise

    except Exception as exc:
        ended = datetime.now(timezone.utc).isoformat()
        err_msg = str(exc)
        await bench_reg.update_job(
            job_id,
            status="failed",
            error=err_msg,
            ended_at=ended,
            log_entry=f"{ts_color}[{ended}]{reset} {error_color}ERROR: {err_msg}{reset}",
        )
        await audit(
            "benchmark_failed",
            job_id=job_id,
            level="error",
            payload={"error": err_msg, "model_id": ctx.model_id},
        )
        log.exception("benchmark_failed", job_id=job_id)
352
-
353
- # ── Resource resolvers ────────────────────────────────────────────────────────
354
-
355
async def _require_model(model_id: str) -> Model:
    """Return the model registered under *model_id*, or raise HTTP 404 when the lookup yields nothing."""
    found = await get_model(model_id)
    if found:
        return found

    from fastapi import HTTPException

    raise HTTPException(
        status_code=404,
        detail=f"Model '{model_id}' not found in Model Zoo",
    )
364
-
365
-
366
async def _require_dataset(dataset_id: str) -> Dataset:
    """Return the dataset registered under *dataset_id*, or raise HTTP 404 when the lookup yields nothing."""
    found = await get_dataset(dataset_id)
    if found:
        return found

    from fastapi import HTTPException

    raise HTTPException(
        status_code=404,
        detail=f"Dataset '{dataset_id}' not found in Dataset Manager",
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/registry.py DELETED
@@ -1,302 +0,0 @@
1
- """
2
- benchmark/registry.py — Benchmark Registry.
3
-
4
- All DB interactions for:
5
- • benchmark_jobs — job lifecycle state
5
- • benchmark_results — final metrics + telemetry summary
6
- • benchmark_validation_logs — immutable check audit trail
8
-
9
- Follows the same pattern as registry/registry.py and datasets/registry.py.
10
- No direct DB access from other benchmark modules — everything routes here.
11
- """
12
- from __future__ import annotations
13
-
14
- import json
15
- import uuid
16
- from datetime import datetime, timezone
17
- from typing import Any
18
- from pathlib import Path
19
-
20
- from database.connection import get_db
21
- from models.benchmark import (
22
- BenchmarkContext,
23
- BenchmarkJob,
24
- BenchmarkMetrics,
25
- BenchmarkResult,
26
- TelemetrySummary,
27
- ValidationCheck,
28
- row_to_job,
29
- row_to_result,
30
- )
31
- from observability.logger import get_logger
32
-
33
- log = get_logger("benchmark.registry")
34
-
35
-
36
- def _get_active_project_benchmark_dir_sync(project_path: str | None) -> Path | None:
37
- """Get the absolute path to the 'benchmarks' folder in a given project path."""
38
- if not project_path:
39
- return None
40
-
41
- benchmark_dir = Path(project_path) / "benchmarks"
42
- benchmark_dir.mkdir(parents=True, exist_ok=True)
43
- return benchmark_dir
44
-
45
async def _get_active_project_benchmark_dir() -> Path | None:
    """Resolve the active project's path and return its 'benchmarks' folder (None when there is no path)."""
    from projects.service import get_active_project_path

    return _get_active_project_benchmark_dir_sync(await get_active_project_path())
50
-
51
async def _save_to_project(filename: str, data: dict) -> None:
    """Write *data* as indented JSON into the active project's benchmark folder.

    Does nothing when no benchmark folder is available. Write failures are
    logged and not raised.
    """
    target_dir = await _get_active_project_benchmark_dir()
    if target_dir:
        destination = target_dir / filename
        try:
            with open(destination, "w") as handle:
                json.dump(data, handle, indent=2)
        except Exception as exc:
            # The project-folder copy is best-effort; never fail the caller.
            log.error("project_persistence_failed", error=str(exc), file=filename)
63
-
64
- # ── Job CRUD ──────────────────────────────────────────────────────────────────
65
-
66
async def create_job(ctx: BenchmarkContext) -> BenchmarkJob:
    """Create a queued benchmark job from *ctx*.

    The job is inserted into ``benchmark_jobs`` with status 'queued',
    progress 0.0 and an empty log list, then written to the active project's
    benchmark folder as ``job_<id>.json``.

    Returns:
        The BenchmarkJob object describing the new row.
    """
    db = await get_db()
    new_id = f"bmark-{uuid.uuid4().hex[:12]}"
    stamp = datetime.now(timezone.utc).isoformat()

    # Build the model object first so the caller gets exactly what was stored.
    job = BenchmarkJob(
        id=new_id,
        model_id=ctx.model_id,
        dataset_id=ctx.dataset_id,
        task=ctx.task,
        framework=ctx.framework,
        hardware=ctx.hardware,
        precision=ctx.precision,
        batch_size=ctx.batch_size,
        config=ctx.model_dump(),
        status="queued",
        progress=0.0,
        created_at=stamp,
        updated_at=stamp,
    )

    insert_sql = """INSERT INTO benchmark_jobs
        (id, model_id, dataset_id, task, framework, hardware,
        precision, batch_size, config,
        status, progress, logs, created_at, updated_at)
        VALUES (?,?,?,?,?,?,?,?,?,'queued',0.0,'[]',?,?)"""
    row_values = (
        new_id,
        ctx.model_id,
        ctx.dataset_id,
        ctx.task,
        ctx.framework,
        ctx.hardware,
        ctx.precision,
        ctx.batch_size,
        json.dumps(ctx.model_dump()),
        stamp,
        stamp,
    )
    await db.execute(insert_sql, row_values)
    await db.commit()

    # Mirror the job into the project folder.
    await _save_to_project(f"job_{new_id}.json", job.model_dump())

    log.info("benchmark_job_created", job_id=new_id, model=ctx.model_id)
    return job
111
-
112
-
113
async def get_job(job_id: str) -> BenchmarkJob | None:
    """Load one benchmark job by id; None when no such row exists."""
    db = await get_db()
    query = "SELECT * FROM benchmark_jobs WHERE id = ?"
    async with db.execute(query, (job_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_job(record)
120
-
121
-
122
async def list_jobs(
    *,
    status: str | None = None,
    model_id: str | None = None,
    limit: int = 100,
) -> list[BenchmarkJob]:
    """List benchmark jobs, newest first.

    Args:
        status: When non-empty, only jobs with this status are returned.
        model_id: When non-empty, only jobs for this model are returned.
        limit: Maximum number of rows to return.
    """
    db = await get_db()

    # Keep only the filters that were actually supplied (truthy values).
    candidates = (("status = ?", status), ("model_id = ?", model_id))
    active = [(clause, value) for clause, value in candidates if value]

    where = f"WHERE {' AND '.join(clause for clause, _ in active)}" if active else ""
    params: list[Any] = [value for _, value in active]
    params.append(limit)

    async with db.execute(
        f"SELECT * FROM benchmark_jobs {where} ORDER BY created_at DESC LIMIT ?",
        params,
    ) as cursor:
        records = await cursor.fetchall()
    return [row_to_job(record) for record in records]
148
-
149
-
150
async def update_job(
    job_id: str,
    *,
    status: str | None = None,
    progress: float | None = None,
    error: str | None = None,
    started_at: str | None = None,
    ended_at: str | None = None,
    log_entry: str | None = None,
    last_telemetry: dict | None = None,
) -> None:
    """Update mutable fields on a benchmark job.

    Only the keyword arguments that are not None are written; ``updated_at``
    is always refreshed. After the row is updated and committed, the job is
    mirrored to the active project's benchmark folder as ``job_<id>.json``.

    Args:
        job_id: Id of the job row to update.
        status: New status value.
        progress: New progress value; stored rounded to 4 decimals.
        error: Error message to store.
        started_at: Start timestamp string.
        ended_at: End timestamp string.
        log_entry: Line appended to the job's JSON log array, which is
            capped at the most recent 500 entries.
        last_telemetry: Telemetry dict, stored as JSON.
    """
    db = await get_db()
    now = datetime.now(timezone.utc).isoformat()

    sets: list[str] = ["updated_at = ?"]
    vals: list[Any] = [now]

    # Column names are fixed literals; only the values are parameterised.
    simple_fields = (
        ("status", status),
        ("progress", None if progress is None else round(progress, 4)),
        ("error", error),
        ("started_at", started_at),
        ("ended_at", ended_at),
        ("last_telemetry", None if last_telemetry is None else json.dumps(last_telemetry)),
    )
    for column, value in simple_fields:
        if value is not None:
            sets.append(f"{column} = ?")
            vals.append(value)

    if log_entry is not None:
        # Read-modify-write of the JSON log array, keeping the last 500 lines.
        async with db.execute(
            "SELECT logs FROM benchmark_jobs WHERE id = ?", (job_id,)
        ) as cur:
            row = await cur.fetchone()
        existing = json.loads(row["logs"]) if row and row["logs"] else []
        existing.append(log_entry)
        sets.append("logs = ?")
        vals.append(json.dumps(existing[-500:]))

    vals.append(job_id)
    # The original assembled sets/vals but never issued the statement, so no
    # change ever reached the database.
    await db.execute(
        f"UPDATE benchmark_jobs SET {', '.join(sets)} WHERE id = ?",
        vals,
    )
    await db.commit()

    # Mirror the updated row into the project folder, if the job exists.
    async with db.execute(
        "SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,)
    ) as cur:
        row = await cur.fetchone()
    if row:
        job = row_to_job(row)
        if job:
            await _save_to_project(f"job_{job_id}.json", job.model_dump())
202
-
203
-
204
- # ── Result CRUD ───────────────────────────────────────────────────────────────
205
-
206
async def save_result(
    *,
    job_id: str,
    metrics: BenchmarkMetrics,
    telemetry_summary: TelemetrySummary,
) -> BenchmarkResult:
    """Store the final metrics and telemetry summary for a job.

    The result is inserted into ``benchmark_results`` (metrics are serialised
    without None-valued fields) and also written to the active project's
    benchmark folder as ``result_<job_id>.json``.

    Returns:
        The BenchmarkResult describing the stored row.
    """
    db = await get_db()
    new_id = f"bres-{uuid.uuid4().hex[:12]}"
    stamp = datetime.now(timezone.utc).isoformat()

    metrics_json = json.dumps(metrics.model_dump(exclude_none=True))
    telemetry_json = json.dumps(telemetry_summary.model_dump())
    await db.execute(
        """INSERT INTO benchmark_results
        (id, job_id, metrics, telemetry_summary, created_at)
        VALUES (?,?,?,?,?)""",
        (new_id, job_id, metrics_json, telemetry_json, stamp),
    )
    await db.commit()

    result = BenchmarkResult(
        id=new_id,
        job_id=job_id,
        metrics=metrics,
        telemetry_summary=telemetry_summary,
        created_at=stamp,
    )

    # Mirror the result into the project folder.
    await _save_to_project(f"result_{job_id}.json", result.model_dump())

    log.info("benchmark_result_saved", job_id=job_id, result_id=new_id)
    return result
244
-
245
-
246
async def get_result(job_id: str) -> BenchmarkResult | None:
    """Load the result stored for *job_id*, joined with its job's descriptive columns; None when absent."""
    db = await get_db()
    query = """SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
        FROM benchmark_results r
        JOIN benchmark_jobs j ON r.job_id = j.id
        WHERE r.job_id = ?"""
    async with db.execute(query, (job_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_result(record)
256
-
257
-
258
async def list_results(*, limit: int = 100) -> list[BenchmarkResult]:
    """List up to *limit* stored results, newest first, each joined with its job's descriptive columns."""
    db = await get_db()
    query = """SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
        FROM benchmark_results r
        JOIN benchmark_jobs j ON r.job_id = j.id
        ORDER BY r.created_at DESC LIMIT ?"""
    async with db.execute(query, (limit,)) as cursor:
        records = await cursor.fetchall()
    return [row_to_result(record) for record in records]
268
-
269
-
270
- # ── Validation Log ────────────────────────────────────────────────────────────
271
-
272
async def save_validation_log(
    *,
    job_id: str,
    model_id: str,
    dataset_id: str,
    checks: list[ValidationCheck],
    passed: bool,
) -> None:
    """Insert one row recording the outcome of a set of compatibility checks.

    Every call inserts a new row with a fresh id. The checks are stored as a
    JSON array and *passed* as 1 or 0.
    """
    db = await get_db()
    entry_id = f"bval-{uuid.uuid4().hex[:12]}"
    stamp = datetime.now(timezone.utc).isoformat()

    serialized_checks = json.dumps([check.model_dump() for check in checks])
    passed_flag = 1 if passed else 0

    await db.execute(
        """INSERT INTO benchmark_validation_logs
        (id, job_id, model_id, dataset_id, checks, passed, created_at)
        VALUES (?,?,?,?,?,?,?)""",
        (entry_id, job_id, model_id, dataset_id, serialized_checks, passed_flag, stamp),
    )
    await db.commit()
    log.info(
        "validation_log_saved",
        job_id=job_id,
        passed=passed,
        n_checks=len(checks),
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/telemetry.py DELETED
@@ -1,182 +0,0 @@
1
- """
2
- benchmark/telemetry.py — Real-time Telemetry Collector.
3
-
4
- Collects GPU/hardware metrics at 2 Hz during benchmark execution.
5
- Designed as a drop-in adapter:
6
- • Local dev → simulates realistic GPU readings based on hardware tier
7
- • Production → replace _read_gpu_metrics() with pynvml calls:
8
- nvmlDeviceGetUtilizationRates()
9
- nvmlDeviceGetMemoryInfo()
10
- nvmlDeviceGetTemperature()
11
- nvmlDeviceGetPowerUsage()
12
-
13
- Usage (async context):
14
- collector = TelemetryCollector("rtx4090", vram_total_gb=24.0)
15
- await collector.start()
16
- # ... run inference ...
17
- summary = await collector.stop()
18
- samples = collector.samples
19
- """
20
- from __future__ import annotations
21
-
22
- import asyncio
23
- import random
24
- import statistics
25
- import time
26
-
27
- from models.benchmark import TelemetrySample, TelemetrySummary
28
- from observability.logger import get_logger
29
-
30
- log = get_logger("benchmark.telemetry")
31
-
32
- # ── Hardware simulation profiles ──────────────────────────────────────────────
33
- # (base_util%, base_temp_C, base_power_W)
34
- _HW_PROFILES: dict[str, tuple[float, float, float]] = {
35
- "rtx4090": (88.0, 74.0, 380.0),
36
- "rtx4080": (84.0, 70.0, 280.0),
37
- "rtx4070": (80.0, 68.0, 200.0),
38
- "rtx3090": (85.0, 72.0, 320.0),
39
- "rtx3080": (82.0, 70.0, 250.0),
40
- "rtx3070": (78.0, 66.0, 180.0),
41
- "rtx3060": (74.0, 64.0, 150.0),
42
- "a100": (90.0, 68.0, 350.0),
43
- "h100": (92.0, 65.0, 550.0),
44
- "v100": (87.0, 70.0, 280.0),
45
- "t4": (75.0, 62.0, 60.0),
46
- "gpu": (70.0, 65.0, 150.0),
47
- "cpu": (0.0, 0.0, 0.0),
48
- }
49
-
50
- _COLLECTION_INTERVAL_S = 0.5 # 2 Hz
51
-
52
-
53
class TelemetryCollector:
    """
    Background sampler of (simulated) GPU telemetry.

    start() launches an asyncio task that appends one TelemetrySample every
    _COLLECTION_INTERVAL_S seconds; stop() ends that task and returns a
    TelemetrySummary over the collected samples.
    """

    def __init__(self, hardware: str, vram_total_gb: float = 8.0) -> None:
        self._hardware = hardware
        self._vram_total = vram_total_gb
        # (base_util, base_temp, base_power) chosen from the hardware name.
        self._hw_profile = self._resolve_profile(hardware)
        self._samples: list[TelemetrySample] = []
        self._running = False
        self._task: asyncio.Task | None = None

    # ── Public API ────────────────────────────────────────────────────────────

    async def start(self) -> None:
        """Clear previous samples and start the background sampling task."""
        self._running = True
        self._samples = []
        self._task = asyncio.create_task(
            self._collect_loop(), name="telemetry_collector"
        )
        log.debug("telemetry_started", hardware=self._hardware)

    async def stop(self) -> TelemetrySummary:
        """Stop sampling and return the summary of everything collected."""
        self._running = False
        pending = self._task
        if pending and not pending.done():
            pending.cancel()
            try:
                await pending
            except asyncio.CancelledError:
                # Expected: the sampling task was just cancelled above.
                pass
        log.debug(
            "telemetry_stopped",
            hardware=self._hardware,
            samples=len(self._samples),
        )
        return self._build_summary()

    def record_batch_context(self, batch_idx: int, progress: float) -> None:
        """Tag the newest sample (if any) with the batch index and progress."""
        if not self._samples:
            return
        latest = self._samples[-1]
        latest.batch_idx = batch_idx
        latest.progress = progress

    @property
    def samples(self) -> list[TelemetrySample]:
        """A shallow copy of the collected samples."""
        return list(self._samples)

    # ── Internal ──────────────────────────────────────────────────────────────

    async def _collect_loop(self) -> None:
        """Append one sample per interval until the running flag is cleared."""
        while self._running:
            self._samples.append(self._read_gpu_metrics())
            await asyncio.sleep(_COLLECTION_INTERVAL_S)

    def _read_gpu_metrics(self) -> TelemetrySample:
        """
        Produce one simulated TelemetrySample for the configured hardware.

        A profile whose base utilisation is 0.0 (the CPU profile) yields an
        all-zero sample. Otherwise the profile's base values are perturbed
        with Gaussian noise and VRAM use is drawn as 58-72% of the total.

        PRODUCTION SWAP: replace this body with pynvml calls:
            handle = nvmlDeviceGetHandleByIndex(0)
            util = nvmlDeviceGetUtilizationRates(handle)
            mem = nvmlDeviceGetMemoryInfo(handle)
            temp = nvmlDeviceGetTemperature(handle, NVML_TEMPERATURE_GPU)
            power = nvmlDeviceGetPowerUsage(handle) / 1000  # mW -> W
        """
        base_util, base_temp, base_power = self._hw_profile

        if base_util == 0.0:
            return TelemetrySample(
                timestamp=time.time(),
                gpu_util_pct=0.0,
                vram_used_gb=0.0,
                vram_total_gb=0.0,
                temp_c=0.0,
                power_w=0.0,
            )

        # Draw order is kept fixed so seeded runs stay reproducible.
        util_noise = random.gauss(0, 3.0)
        temp_noise = random.gauss(0, 1.5)
        power_noise = random.gauss(0, 8.0)
        vram_share = random.uniform(0.58, 0.72)

        total = self._vram_total
        vram_used = max(0.0, min(total, total * vram_share))

        return TelemetrySample(
            timestamp=time.time(),
            gpu_util_pct=max(0.0, min(100.0, base_util + util_noise)),
            vram_used_gb=round(vram_used, 3),
            vram_total_gb=total,
            temp_c=round(max(0.0, base_temp + temp_noise), 1),
            power_w=round(max(0.0, base_power + power_noise), 1),
        )

    def _build_summary(self) -> TelemetrySummary:
        """Aggregate mean and peak values; a default summary when nothing was sampled."""
        collected = self._samples
        if not collected:
            return TelemetrySummary()

        utils = [s.gpu_util_pct for s in collected]
        vrams = [s.vram_used_gb for s in collected]
        temps = [s.temp_c for s in collected]
        powers = [s.power_w for s in collected]

        # The lists are non-empty here, so mean() and max() are both safe.
        return TelemetrySummary(
            gpu_util_avg=round(statistics.mean(utils), 2),
            gpu_util_peak=round(max(utils), 2),
            vram_avg_gb=round(statistics.mean(vrams), 3),
            vram_peak_gb=round(max(vrams), 3),
            temp_avg_c=round(statistics.mean(temps), 1),
            temp_peak_c=round(max(temps), 1),
            power_avg_w=round(statistics.mean(powers), 1),
            power_peak_w=round(max(powers), 1),
        )

    @staticmethod
    def _resolve_profile(hardware: str) -> tuple[float, float, float]:
        """Pick the simulation profile whose key occurs in the normalised hardware name."""
        normalized = "".join(ch for ch in hardware.lower() if ch not in " -")

        # First key (in table order) contained in the name wins.
        matched = next(
            (profile for key, profile in _HW_PROFILES.items() if key in normalized),
            None,
        )
        if matched is not None:
            return matched

        # No table key matched: fall back by GPU-like keywords, else CPU.
        gpu_hints = ("gpu", "rtx", "gtx", "cuda", "vram")
        fallback = "gpu" if any(hint in normalized for hint in gpu_hints) else "cpu"
        return _HW_PROFILES[fallback]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmark/torch_runner.py DELETED
@@ -1,142 +0,0 @@
1
- """
2
- benchmark/torch_runner.py — Synchronous GPU inference runner.
3
-
4
- Called from BenchmarkExecutor via asyncio.run_in_executor() so it never
5
- blocks the event loop. PyTorch is an optional dependency — if it is not
6
- installed the module raises ImportError and execution.py falls back to
7
- the simulation path.
8
-
9
- Supported weight formats (detected by file extension):
10
- .pt / .pth — torch.load (TorchScript or state-dict)
11
- .safetensors — safetensors.torch.load_file
12
- .onnx — onnxruntime InferenceSession
13
-
14
- PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
15
- """
16
- from __future__ import annotations
17
-
18
- import time
19
- from pathlib import Path
20
- from typing import Any
21
-
22
- # ── Model cache (keyed by absolute path) ─────────────────────────────────────
23
- _MODEL_CACHE: dict[str, Any] = {}
24
-
25
- # Standard input shapes per task (B, C, H, W)
26
- _INPUT_SHAPES: dict[str, tuple[int, int, int]] = {
27
- "detection": (3, 640, 640),
28
- "segmentation": (3, 640, 640),
29
- "classification": (3, 224, 224),
30
- "generation": (3, 512, 512),
31
- "embedding": (3, 224, 224),
32
- }
33
- _DEFAULT_SHAPE = (3, 640, 640)
34
-
35
-
36
- def run_torch_batch(model_path: str, batch_size: int, task: str = "detection") -> float:
37
- """Run one inference batch and return per-image latency in ms.
38
-
39
- Args:
40
- model_path: Absolute path to the weight file.
41
- batch_size: Number of images in the batch.
42
- task: Model task (affects dummy input shape).
43
-
44
- Returns:
45
- Latency per image in milliseconds.
46
- """
47
- import torch # raises ImportError if not installed
48
-
49
- device = "cuda" if torch.cuda.is_available() else "cpu"
50
- ext = Path(model_path).suffix.lower()
51
-
52
- model = _load_model(model_path, ext, device)
53
- c, h, w = _INPUT_SHAPES.get(task, _DEFAULT_SHAPE)
54
- dummy = torch.zeros(batch_size, c, h, w, device=device)
55
-
56
- # Warm-up pass (first call is slower due to CUDA kernel compilation)
57
- if device == "cuda":
58
- with torch.no_grad():
59
- _forward(model, dummy, ext, device)
60
- torch.cuda.synchronize()
61
-
62
- # Timed pass
63
- if device == "cuda":
64
- torch.cuda.synchronize()
65
- t0 = time.perf_counter()
66
- with torch.no_grad():
67
- _forward(model, dummy, ext, device)
68
- if device == "cuda":
69
- torch.cuda.synchronize()
70
- elapsed_ms = (time.perf_counter() - t0) * 1000
71
-
72
- return elapsed_ms / batch_size
73
-
74
-
75
- def _load_model(path: str, ext: str, device: str) -> Any:
76
- """Load and cache the model by absolute path."""
77
- if path in _MODEL_CACHE:
78
- return _MODEL_CACHE[path]
79
-
80
- model = _load_by_ext(path, ext, device)
81
- _MODEL_CACHE[path] = model
82
- return model
83
-
84
-
85
- def _load_by_ext(path: str, ext: str, device: str) -> Any:
86
- """Select loader based on file extension."""
87
- if ext in (".pt", ".pth"):
88
- return _load_torch(path, device)
89
- if ext == ".safetensors":
90
- return _load_safetensors(path, device)
91
- if ext == ".onnx":
92
- return _load_onnx(path)
93
- raise ValueError(f"Unsupported model format: {ext}")
94
-
95
-
96
- def _load_torch(path: str, device: str) -> Any:
97
- import torch
98
- # <<< REPLACE IN PRODUCTION >>> with proper model class instantiation
99
- # TorchScript models can be loaded directly; state-dict models need
100
- # the model class to be imported separately.
101
- try:
102
- model = torch.jit.load(path, map_location=device)
103
- model.eval()
104
- return model
105
- except RuntimeError:
106
- # Not a TorchScript model — try loading as a full checkpoint
107
- obj = torch.load(path, map_location=device, weights_only=False)
108
- if hasattr(obj, "eval"):
109
- obj.eval()
110
- return obj
111
- # It's a state-dict — we cannot run inference without knowing the arch
112
- raise RuntimeError(
113
- f"Model at {path} is a state-dict; cannot run inference without "
114
- "the model class. Use a TorchScript-exported .pt file."
115
- )
116
-
117
-
118
- def _load_safetensors(path: str, device: str) -> Any:
119
- # <<< REPLACE IN PRODUCTION >>> safetensors gives tensors only;
120
- # you still need the model class. This is intentionally left as a
121
- # placeholder that raises a clear error rather than silently failing.
122
- raise NotImplementedError(
123
- "safetensors inference requires the model class to be registered. "
124
- "Convert to TorchScript or ONNX for architecture-agnostic inference."
125
- )
126
-
127
-
128
- def _load_onnx(path: str) -> Any:
129
- import onnxruntime as ort # type: ignore[import]
130
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
131
- return ort.InferenceSession(path, providers=providers)
132
-
133
-
134
- def _forward(model: Any, dummy: Any, ext: str, device: str) -> Any:
135
- """Run a single forward pass, dispatching by model type."""
136
- if ext == ".onnx":
137
- import numpy as np
138
- np_input = dummy.cpu().numpy()
139
- input_name = model.get_inputs()[0].name
140
- return model.run(None, {input_name: np_input})
141
- # TorchScript / nn.Module
142
- return model(dummy)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.py CHANGED
@@ -21,31 +21,15 @@ class Settings(BaseSettings):
21
  # ── API ───────────────────────────────────────────────────────────
22
  host: str = "0.0.0.0"
23
  port: int = 7860 # Default for HF Spaces
24
- cors_origins: list[str] = [
25
- "http://localhost:3000",
26
- "http://127.0.0.1:3000",
27
- "http://localhost:5173",
28
- "http://127.0.0.1:5173",
29
- "http://localhost:2000",
30
- "http://127.0.0.1:2000",
31
- ]
32
 
33
  # ── Storage ───────────────────────────────────────────────────────
34
  base_dir: Path = Path(__file__).resolve().parents[1]
35
  data_dir: Path = base_dir / "data"
36
- models_dir: Path = data_dir / "models"
37
- datasets_dir: Path = data_dir / "datasets" # root for imported datasets
38
- logs_dir: Path = data_dir / "logs"
39
- db_path: Path = data_dir / "modelzoo.db"
40
-
41
- # ── Download Manager ──────────────────────────────────────────────
42
- max_concurrent_downloads: int = 5
43
- download_chunk_size: int = 1024 * 1024 # 1 MB
44
- download_max_retries: int = 3
45
- download_retry_delay: float = 2.0 # seconds (base, exponential backoff)
46
 
47
  # ── Search ────────────────────────────────────────────────────────
48
- search_max_results: int = 500
49
 
50
  # ── Sync ──────────────────────────────────────────────────────────
51
  auto_sync_on_startup: bool = True
@@ -54,30 +38,10 @@ class Settings(BaseSettings):
54
  hf_api_base: str = "https://huggingface.co/api"
55
  hf_hub_url: str = "https://huggingface.co"
56
  hf_token: str | None = None # Optional: HF_TOKEN env var
57
- hf_models_per_task: int = 100 # How many to pull per task
58
-
59
- # ── ONNX Zoo ──────────────────────────────────────────────────────
60
- onnx_models_url: str = (
61
- "https://raw.githubusercontent.com/onnx/models/main/README.md"
62
- )
63
-
64
- # ── Benchmark Bridge ──────────────────────────────────────────────
65
- benchmark_max_concurrent: int = 3 # max parallel benchmark jobs
66
- benchmark_max_log_lines: int = 500 # log entries kept per job
67
- benchmark_ws_poll_hz: float = 2.0 # WebSocket telemetry poll rate
68
-
69
- # ── Dataset Manager ───────────────────────────────────────────────
70
- roboflow_api_base: str = "https://api.roboflow.com"
71
- dataset_import_workers: int = 3 # max concurrent import jobs
72
- dataset_chunk_size: int = 1024 * 1024 * 4 # 4 MB download chunk
73
- roboflow_cache_ttl_secs: int = 3600 # 1 hour
74
 
75
  def ensure_dirs(self) -> None:
76
  self.data_dir.mkdir(parents=True, exist_ok=True)
77
- self.models_dir.mkdir(parents=True, exist_ok=True)
78
- self.datasets_dir.mkdir(parents=True, exist_ok=True)
79
- (self.datasets_dir / "_tmp").mkdir(parents=True, exist_ok=True)
80
- self.logs_dir.mkdir(parents=True, exist_ok=True)
81
 
82
 
83
  settings = Settings()
 
21
  # ── API ───────────────────────────────────────────────────────────
22
  host: str = "0.0.0.0"
23
  port: int = 7860 # Default for HF Spaces
24
+ cors_origins: list[str] = ["*"]
 
 
 
 
 
 
 
25
 
26
  # ── Storage ───────────────────────────────────────────────────────
27
  base_dir: Path = Path(__file__).resolve().parents[1]
28
  data_dir: Path = base_dir / "data"
29
+ db_path: Path = data_dir / "modelzoo.db"
 
 
 
 
 
 
 
 
 
30
 
31
  # ── Search ────────────────────────────────────────────────────────
32
+ search_max_results: int = 1000
33
 
34
  # ── Sync ──────────────────────────────────────────────────────────
35
  auto_sync_on_startup: bool = True
 
38
  hf_api_base: str = "https://huggingface.co/api"
39
  hf_hub_url: str = "https://huggingface.co"
40
  hf_token: str | None = None # Optional: HF_TOKEN env var
41
+ hf_models_per_task: int = 200 # Discovery server pulls more per task
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def ensure_dirs(self) -> None:
44
  self.data_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
45
 
46
 
47
  settings = Settings()
download/__init__.py DELETED
File without changes
download/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (143 Bytes)
 
download/__pycache__/manager.cpython-310.pyc DELETED
Binary file (11.1 kB)
 
download/manager.py DELETED
@@ -1,366 +0,0 @@
1
- """
2
- download/manager.py — Async download manager.
3
- Handles queueing, concurrency limiting, retry, resume, and progress tracking.
4
- All state is persisted in the jobs table for crash recovery.
5
- """
6
- from __future__ import annotations
7
-
8
- import asyncio
9
- import json
10
- import uuid
11
- from datetime import datetime, timezone
12
- from pathlib import Path
13
- from typing import Any
14
-
15
- import aiofiles
16
- import httpx
17
- from tenacity import retry, stop_after_attempt, wait_exponential
18
-
19
- from config import settings
20
- from database.connection import get_db
21
- from models.job import Job, row_to_job
22
- from observability.logger import audit, get_logger
23
- from registry.registry import get_model, update_model_status
24
-
25
- log = get_logger("download_manager")
26
-
27
- # ── Semaphore caps concurrent downloads ───────────────────────────────────────
28
- _download_sem: asyncio.Semaphore | None = None
29
-
30
-
31
- def _get_sem() -> asyncio.Semaphore:
32
- global _download_sem
33
- if _download_sem is None:
34
- _download_sem = asyncio.Semaphore(settings.max_concurrent_downloads)
35
- return _download_sem
36
-
37
-
38
- # ── Job CRUD ──────────────────────────────────────────────────────────────────
39
-
40
- async def _create_job(
41
- job_type: str,
42
- model_id: str,
43
- model_name: str,
44
- meta: dict | None = None,
45
- ) -> str:
46
- job_id = str(uuid.uuid4())
47
- db = await get_db()
48
- now = datetime.now(timezone.utc).isoformat()
49
- await db.execute(
50
- """INSERT INTO jobs (id, type, status, model_id, model_name, meta, created_at, updated_at)
51
- VALUES (?,?,?,?,?,?,?,?)""",
52
- (job_id, job_type, "queued", model_id, model_name,
53
- json.dumps(meta or {}), now, now),
54
- )
55
- await db.commit()
56
- log.info("job_created", job_id=job_id, type=job_type, model_id=model_id)
57
- await audit("job_created", model_id=model_id, job_id=job_id,
58
- payload={"type": job_type, "model_name": model_name})
59
- return job_id
60
-
61
-
62
- def _is_shard_file(filename: str) -> bool:
63
- """Return True if the file is part of a sharded model (e.g. model-00001-of-00003.safetensors)."""
64
- import re
65
- return bool(re.search(r"-\d{5}-of-\d{5}\.", filename))
66
-
67
-
68
- async def _get_active_version(model_id: str) -> str:
69
- """Return the active version string for a model, defaulting to 'v1'."""
70
- model = await get_model(model_id)
71
- if model and model.active_version:
72
- return model.active_version
73
- return "v1"
74
-
75
-
76
- @retry(
77
- stop=stop_after_attempt(3),
78
- wait=wait_exponential(multiplier=1, min=1, max=6),
79
- reraise=True,
80
- )
81
- async def _resolve_hf_download_url(repo_id: str) -> str:
82
- """Resolve a reliable download URL for a HF repo.
83
-
84
- Prefer safetensors over pytorch_model.bin; fall back to onnx if needed.
85
- """
86
- async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
87
- resp = await client.get(f"{settings.hf_api_base}/models/{repo_id}")
88
- resp.raise_for_status()
89
- data = resp.json()
90
-
91
- siblings = data.get("siblings") or []
92
- filenames: list[str] = []
93
- for s in siblings:
94
- fn = s.get("rfilename") or s.get("filename")
95
- if fn:
96
- filenames.append(fn)
97
-
98
- preferred_exact = [
99
- "model.safetensors",
100
- "pytorch_model.bin",
101
- "model.onnx",
102
- ]
103
- for fn in preferred_exact:
104
- if fn in filenames:
105
- return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
106
-
107
- preferred_suffix = [".safetensors", ".bin", ".onnx", ".pt", ".pth"]
108
- for suffix in preferred_suffix:
109
- for fn in filenames:
110
- if fn.endswith(suffix) and not _is_shard_file(fn):
111
- return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
112
-
113
- # Accept sharded files as a fallback (first shard of safetensors)
114
- for fn in filenames:
115
- if _is_shard_file(fn):
116
- return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
117
-
118
- # Last resort: try the index file for sharded models
119
- if "model.safetensors.index.json" in filenames:
120
- # For sharded models without a single file, use the first shard
121
- for fn in filenames:
122
- if fn.startswith("model-") and fn.endswith(".safetensors"):
123
- return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
124
-
125
- return f"https://huggingface.co/{repo_id}/resolve/main/pytorch_model.bin"
126
-
127
-
128
- async def _update_job(
129
- job_id: str,
130
- status: str | None = None,
131
- progress: float | None = None,
132
- error: str | None = None,
133
- started_at: str | None = None,
134
- ended_at: str | None = None,
135
- ) -> None:
136
- db = await get_db()
137
- now = datetime.now(timezone.utc).isoformat()
138
- parts: list[str] = ["updated_at = ?"]
139
- vals: list[Any] = [now]
140
- if status is not None: parts.append("status = ?"); vals.append(status)
141
- if progress is not None: parts.append("progress = ?"); vals.append(progress)
142
- if error is not None: parts.append("error = ?"); vals.append(error)
143
- if started_at: parts.append("started_at = ?"); vals.append(started_at)
144
- if ended_at: parts.append("ended_at = ?"); vals.append(ended_at)
145
- vals.append(job_id)
146
- await db.execute(f"UPDATE jobs SET {', '.join(parts)} WHERE id = ?", vals)
147
- await db.commit()
148
-
149
-
150
- # ── Download worker ───────────────────────────────────────────────────────────
151
-
152
- async def _execute_download(
153
- job_id: str,
154
- model_id: str,
155
- model_name: str,
156
- download_url: str,
157
- dest_path: Path,
158
- ) -> None:
159
- now = datetime.now(timezone.utc).isoformat()
160
- await _update_job(job_id, status="running", started_at=now)
161
-
162
- dest_path.parent.mkdir(parents=True, exist_ok=True)
163
- tmp_path = dest_path.with_suffix(".tmp")
164
-
165
- # Determine resume offset
166
- resume_offset = tmp_path.stat().st_size if tmp_path.exists() else 0
167
-
168
- headers: dict[str, str] = {}
169
- if resume_offset:
170
- headers["Range"] = f"bytes={resume_offset}-"
171
- log.info("download_resume", job_id=job_id, offset=resume_offset)
172
-
173
- try:
174
- async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
175
- async with client.stream("GET", download_url, headers=headers) as resp:
176
- resp.raise_for_status()
177
- total = int(resp.headers.get("content-length", 0)) + resume_offset
178
- downloaded = resume_offset
179
-
180
- async with aiofiles.open(tmp_path, "ab" if resume_offset else "wb") as fh:
181
- async for chunk in resp.aiter_bytes(chunk_size=settings.download_chunk_size):
182
- await fh.write(chunk)
183
- downloaded += len(chunk)
184
- progress = downloaded / total if total else 0
185
- await _update_job(job_id, progress=min(progress, 0.99))
186
-
187
- # Rename tmp → final
188
- tmp_path.rename(dest_path)
189
- now_end = datetime.now(timezone.utc).isoformat()
190
- await _update_job(job_id, status="completed", progress=1.0, ended_at=now_end)
191
- await update_model_status(
192
- model_id,
193
- status="cached",
194
- downloaded=True,
195
- local_path=str(dest_path),
196
- )
197
- # Copy into the active project's workspace models/ folder
198
- from projects.service import link_model_to_active_project
199
- await link_model_to_active_project(model_id, str(dest_path))
200
- log.info("download_complete", job_id=job_id, model_id=model_id, path=str(dest_path))
201
- await audit("download_complete", model_id=model_id, job_id=job_id,
202
- payload={"path": str(dest_path)})
203
-
204
- except Exception as exc:
205
- now_end = datetime.now(timezone.utc).isoformat()
206
- await _update_job(job_id, status="failed", error=str(exc), ended_at=now_end)
207
- await update_model_status(model_id, status="error")
208
- log.error("download_failed", job_id=job_id, error=str(exc))
209
- await audit("download_failed", model_id=model_id, job_id=job_id,
210
- payload={"error": str(exc)}, level="error")
211
- raise
212
-
213
-
214
- # ── Public API ────────────────────────────────────────────────────────────────
215
-
216
- async def enqueue_download(
217
- model_id: str,
218
- model_name: str,
219
- download_url: str | None = None,
220
- version: str | None = None,
221
- ) -> str:
222
- """Create a download job and dispatch resolution+download in the background.
223
-
224
- This function should not perform network calls; otherwise /download can return 500
225
- on transient provider errors.
226
- """
227
- job_id = await _create_job("download", model_id, model_name)
228
-
229
- asyncio.create_task(
230
- _rate_limited_download_resolving(job_id, model_id, model_name, download_url, version)
231
- )
232
- return job_id
233
-
234
-
235
- async def _rate_limited_download_resolving(
236
- job_id: str,
237
- model_id: str,
238
- model_name: str,
239
- download_url: str | None,
240
- version: str | None = None,
241
- ) -> None:
242
- async with _get_sem():
243
- try:
244
- resolved_url = await _resolve_download_url(model_id, download_url, version)
245
- # Version folder: use explicit version label, else active_version from DB
246
- folder = version or await _get_active_version(model_id)
247
- ext = Path(resolved_url.split("?")[0]).suffix or ".bin"
248
- dest_path = settings.models_dir / model_id / folder / f"model{ext}"
249
- await _execute_download(job_id, model_id, model_name, resolved_url, dest_path)
250
- except Exception as exc:
251
- now_end = datetime.now(timezone.utc).isoformat()
252
- await _update_job(job_id, status="failed", error=str(exc), ended_at=now_end)
253
- await update_model_status(model_id, status="error")
254
- log.error("download_failed", job_id=job_id, error=str(exc))
255
- await audit(
256
- "download_failed",
257
- model_id=model_id,
258
- job_id=job_id,
259
- payload={"error": str(exc)},
260
- level="error",
261
- )
262
-
263
-
264
- async def _resolve_download_url(
265
- model_id: str,
266
- download_url: str | None,
267
- version: str | None = None,
268
- ) -> str:
269
- """Resolve the final download URL for a model.
270
-
271
- If `version` is provided and looks like a filename (e.g. 'yolov8n_pt'),
272
- it was generated by hf_adapter from a sibling rfilename. Restore the
273
- original filename (replace trailing _ext with .ext) and build a direct URL.
274
- """
275
- repo_id: str | None = None
276
-
277
- if download_url and "huggingface.co" in download_url:
278
- repo_id = download_url.replace("https://huggingface.co/", "").rstrip("/")
279
- elif not download_url:
280
- model = await get_model(model_id)
281
- if model and model.download_url:
282
- url = model.download_url
283
- if "huggingface.co" in url:
284
- repo_id = url.replace("https://huggingface.co/", "").rstrip("/")
285
- else:
286
- return url
287
- else:
288
- repo_id = model_id.replace("_", "/", 1)
289
- else:
290
- return download_url
291
-
292
- # If the caller specified a version that is a converted rfilename
293
- # (dots replaced with underscores by hf_adapter), reconstruct the filename.
294
- if version and repo_id:
295
- filename = _version_to_filename(version)
296
- if filename:
297
- return f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
298
-
299
- return await _resolve_hf_download_url(repo_id)
300
-
301
-
302
- def _version_to_filename(version: str) -> str | None:
303
- """Convert an hf_adapter version string back to a real filename.
304
-
305
- hf_adapter stores version as rfilename.replace('.', '_'), e.g.:
306
- 'yolov8n_pt' → 'yolov8n.pt'
307
- 'model_safetensors' → 'model.safetensors'
308
- Only converts if the result ends with a known weight extension.
309
- """
310
- weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx")
311
- # Try replacing the last underscore with a dot
312
- idx = version.rfind("_")
313
- if idx == -1:
314
- return None
315
- candidate = version[:idx] + "." + version[idx + 1:]
316
- if any(candidate.endswith(ext) for ext in weight_exts):
317
- return candidate
318
- return None
319
-
320
-
321
- async def _rate_limited_download(
322
- job_id: str,
323
- model_id: str,
324
- model_name: str,
325
- download_url: str,
326
- dest_path: Path,
327
- ) -> None:
328
- async with _get_sem():
329
- try:
330
- await _execute_download(job_id, model_id, model_name, download_url, dest_path)
331
- except Exception:
332
- pass # Already logged & stored in DB
333
-
334
-
335
- async def get_job(job_id: str) -> Job | None:
336
- db = await get_db()
337
- async with db.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)) as cur:
338
- row = await cur.fetchone()
339
- return row_to_job(row) if row else None
340
-
341
-
342
- async def list_jobs(
343
- status: str | None = None,
344
- limit: int = 50,
345
- ) -> list[Job]:
346
- db = await get_db()
347
- if status:
348
- sql = "SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC LIMIT ?"
349
- params: tuple = (status, limit)
350
- else:
351
- sql = "SELECT * FROM jobs ORDER BY created_at DESC LIMIT ?"
352
- params = (limit,)
353
- async with db.execute(sql, params) as cur:
354
- rows = await cur.fetchall()
355
- return [row_to_job(r) for r in rows]
356
-
357
-
358
- async def cancel_job(job_id: str) -> bool:
359
- """Cancel a queued or running job (best-effort)."""
360
- job = await get_job(job_id)
361
- if not job or job.status not in ("queued", "running"):
362
- return False
363
- now = datetime.now(timezone.utc).isoformat()
364
- await _update_job(job_id, status="cancelled", ended_at=now)
365
- log.info("job_cancelled", job_id=job_id)
366
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/__init__.py DELETED
@@ -1 +0,0 @@
1
- # inference package
 
 
inference/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (144 Bytes)
 
inference/__pycache__/engine.cpython-310.pyc DELETED
Binary file (12 kB)
 
inference/__pycache__/session.cpython-310.pyc DELETED
Binary file (2.87 kB)
 
inference/engine.py DELETED
@@ -1,447 +0,0 @@
1
- """
2
- inference/engine.py — MLForge Inference Engine.
3
-
4
- Dispatcher that routes each InferenceRequest to the correct adapter pipeline:
5
- YOLO → YOLOInferencePipeline
6
- TRANSFORMERS → TransformersPipeline
7
- ONNX → ONNXPipeline
8
- CUSTOM → CustomPipeline
9
-
10
- Each pipeline implements preprocess → inference_step → postprocess.
11
- Simulation paths are used when real model weights are not loaded;
12
- every # <<< REPLACE IN PRODUCTION >>> comment marks the exact swap point.
13
-
14
- Architecture follows the spec in infra_arch.md §4 (Adapter Protocol).
15
- """
16
- from __future__ import annotations
17
-
18
- import asyncio
19
- import base64
20
- import io
21
- import random
22
- import time
23
- import uuid
24
- from typing import Any
25
-
26
- from models.inference import (
27
- AdapterType,
28
- Detection,
29
- InferenceRequest,
30
- InferenceResult,
31
- PipelineStage,
32
- )
33
- from models.model import Model
34
- from observability.logger import get_logger
35
-
36
- log = get_logger("inference.engine")
37
-
38
- # ── Model cache: model_id β†’ loaded model object ──────────────────────────────
39
- _MODEL_CACHE: dict[str, Any] = {}
40
-
41
-
42
- def _now_ms() -> float:
43
- return time.perf_counter() * 1000
44
-
45
-
46
- # ── YOLO Pipeline ─────────────────────────────────────────────────────────────
47
-
48
- class YOLOPipeline:
49
- """
50
- YOLO inference pipeline.
51
- Preprocess: letterbox resize → BGR→RGB → 1/255 normalise.
52
- Postprocess: NMS → [{x1,y1,x2,y2,confidence,class_id,class_name}].
53
- """
54
-
55
- async def run(
56
- self, req: InferenceRequest, model: Model
57
- ) -> tuple[list[PipelineStage], dict[str, Any]]:
58
- cfg = req.yolo_config
59
- conf = cfg.confidence if cfg else 0.25
60
- iou = cfg.iou_threshold if cfg else 0.45
61
-
62
- stages: list[PipelineStage] = []
63
-
64
- # β€” Stage 1: Preprocess β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
65
- t0 = _now_ms()
66
- await asyncio.sleep(0) # yield control
67
- if req.image_base64:
68
- try:
69
- raw_bytes = base64.b64decode(req.image_base64)
70
- # <<< REPLACE IN PRODUCTION >>>
71
- # img = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
72
- # tensor = letterbox(img, 640) / 255.0
73
- _ = len(raw_bytes) # validate decode worked
74
- except Exception as e:
75
- return [PipelineStage(name="Preprocess", status="error", detail=str(e))], {}
76
- pre_ms = _now_ms() - t0 + random.uniform(0.8, 2.5)
77
- stages.append(PipelineStage(name="Preprocess", status="done",
78
- latency_ms=round(pre_ms, 2), detail="Letterbox 640×640"))
79
-
80
- # β€” Stage 2: Engine Load β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
81
- t1 = _now_ms()
82
- loaded = model.id in _MODEL_CACHE
83
- load_ms = 0.0 if loaded else random.uniform(80, 220)
84
- await asyncio.sleep(load_ms / 1000.0)
85
- if not loaded:
86
- _MODEL_CACHE[model.id] = object() # <<< REPLACE: load actual weights
87
- stages.append(PipelineStage(name="Engine Load", status="done",
88
- latency_ms=round(_now_ms() - t1, 2),
89
- detail="Cache hit" if loaded else "Weights loaded"))
90
-
91
- # β€” Stage 3: Inference β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
92
- t2 = _now_ms()
93
- size_gb = max(model.size, 1) / (1024 ** 3)
94
- base_lat = 2.5 + size_gb * 1.5
95
- infer_ms = base_lat + random.gauss(0, base_lat * 0.07)
96
- await asyncio.sleep(infer_ms / 1000.0)
97
- # <<< REPLACE IN PRODUCTION >>>
98
- # results = model_obj(tensor, conf=conf, iou=iou)
99
- stages.append(PipelineStage(name="Inference", status="done",
100
- latency_ms=round(infer_ms, 2),
101
- detail=f"conf={conf} iou={iou}"))
102
-
103
- # β€” Stage 4: Post-process (NMS) β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
104
- t3 = _now_ms()
105
- detections = self._simulate_detections(conf, cfg.class_filter if cfg else [])
106
- post_ms = random.uniform(0.3, 1.2)
107
- await asyncio.sleep(post_ms / 1000.0)
108
- stages.append(PipelineStage(name="NMS Post-process", status="done",
109
- latency_ms=round(post_ms, 2),
110
- detail=f"{len(detections)} detections"))
111
-
112
- output: dict[str, Any] = {
113
- "detections": [d.model_dump() for d in detections],
114
- "pre_ms": round(pre_ms, 2),
115
- "infer_ms": round(infer_ms, 2),
116
- "post_ms": round(post_ms, 2),
117
- }
118
- return stages, output
119
-
120
- @staticmethod
121
- def _simulate_detections(conf_thresh: float, class_filter: list[str]) -> list[Detection]:
122
- """Simulate bounding-box detections. <<< REPLACE with real NMS output."""
123
- CLASSES = ["person", "car", "truck", "bicycle", "dog", "cat",
124
- "traffic light", "stop sign", "bench", "bird"]
125
- n = random.randint(0, 8)
126
- dets: list[Detection] = []
127
- for _ in range(n):
128
- c = random.uniform(conf_thresh, 1.0)
129
- cid = random.randint(0, len(CLASSES) - 1)
130
- cname = CLASSES[cid]
131
- if class_filter and cname not in class_filter:
132
- continue
133
- x1 = random.uniform(0, 0.7)
134
- y1 = random.uniform(0, 0.7)
135
- dets.append(Detection(
136
- x1=round(x1 * 640, 1), y1=round(y1 * 640, 1),
137
- x2=round((x1 + random.uniform(0.05, 0.3)) * 640, 1),
138
- y2=round((y1 + random.uniform(0.05, 0.3)) * 640, 1),
139
- confidence=round(c, 4),
140
- class_id=cid, class_name=cname,
141
- ))
142
- return dets
143
-
144
-
145
- # ── Transformers Pipeline ─────────────────────────────────────────────────────
146
-
147
- class TransformersPipeline:
148
- """
149
- HuggingFace Transformers pipeline.
150
- Preprocess: AutoTokenizer.encode.
151
- Inference: model.generate with KV-cache.
152
- Postprocess: decode + strip special tokens.
153
- """
154
-
155
- async def run(
156
- self, req: InferenceRequest, model: Model
157
- ) -> tuple[list[PipelineStage], dict[str, Any]]:
158
- cfg = req.transformers_config
159
- stages: list[PipelineStage] = []
160
-
161
- # β€” Tokenize β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
162
- t0 = _now_ms()
163
- txt = req.text_input or "Hello, world!"
164
- tok_count = len(txt.split()) * 2 # rough BPE estimate
165
- await asyncio.sleep(0.002)
166
- pre_ms = _now_ms() - t0 + random.uniform(1, 4)
167
- stages.append(PipelineStage(name="Tokenise", status="done",
168
- latency_ms=round(pre_ms, 2),
169
- detail=f"{tok_count} tokens"))
170
-
171
- # β€” Engine Load β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
172
- t1 = _now_ms()
173
- loaded = model.id in _MODEL_CACHE
174
- load_ms = 0.0 if loaded else random.uniform(150, 400)
175
- await asyncio.sleep(load_ms / 1000.0)
176
- if not loaded:
177
- _MODEL_CACHE[model.id] = object()
178
- stages.append(PipelineStage(name="Engine Load", status="done",
179
- latency_ms=round(_now_ms() - t1, 2),
180
- detail="Cache hit" if loaded else "Model loaded"))
181
-
182
- # β€” Generate β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
183
- t2 = _now_ms()
184
- max_tok = cfg.max_new_tokens if cfg else 256
185
- # Simulate token-by-token generation at ~20 tok/s
186
- infer_ms = (max_tok / 20.0) * 1000 + random.gauss(0, 50)
187
- await asyncio.sleep(min(infer_ms / 1000.0, 0.5)) # cap sim delay
188
- # <<< REPLACE IN PRODUCTION >>>
189
- # outputs = model_obj.generate(input_ids, max_new_tokens=max_tok,
190
- # temperature=cfg.temperature, top_p=cfg.top_p, do_sample=cfg.do_sample)
191
- stages.append(PipelineStage(name="Generate", status="done",
192
- latency_ms=round(infer_ms, 2),
193
- detail=f"~{max_tok} tokens @ fp16"))
194
-
195
- # β€” Decode β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
196
- t3 = _now_ms()
197
- text_output = self._simulate_text(txt, max_tok)
198
- post_ms = random.uniform(0.5, 2.0)
199
- stages.append(PipelineStage(name="Decode", status="done",
200
- latency_ms=round(post_ms, 2),
201
- detail="Special tokens stripped"))
202
-
203
- output: dict[str, Any] = {
204
- "text_output": text_output,
205
- "tokens_generated": max_tok,
206
- "pre_ms": round(pre_ms, 2),
207
- "infer_ms": round(infer_ms, 2),
208
- "post_ms": round(post_ms, 2),
209
- }
210
- return stages, output
211
-
212
- @staticmethod
213
- def _simulate_text(prompt: str, n_tokens: int) -> str:
214
- """Placeholder generation. <<< REPLACE with model.generate."""
215
- lorem = (
216
- "The model processed your input and generated a response based on the "
217
- "learned distribution of the training corpus. This output is a simulation "
218
- "placeholder β€” replace with actual model.generate() in production. "
219
- )
220
- # Repeat to roughly match token count
221
- words = (lorem * (n_tokens // 20 + 1)).split()[:n_tokens]
222
- return " ".join(words)
223
-
224
-
225
- # ── ONNX Pipeline ─────────────────────────────────────────────────────────────
226
-
227
- class ONNXPipeline:
228
- """
229
- ONNX Runtime pipeline.
230
- Acts as universal wrapper for TF / sklearn / PyTorch exported models.
231
- Dynamically maps input tensor names from model metadata.
232
- """
233
-
234
- async def run(
235
- self, req: InferenceRequest, model: Model
236
- ) -> tuple[list[PipelineStage], dict[str, Any]]:
237
- cfg = req.onnx_config
238
- stages: list[PipelineStage] = []
239
- provider = cfg.execution_provider if cfg else "CUDAExecutionProvider"
240
-
241
- # β€” Preprocess β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
242
- t0 = _now_ms()
243
- pre_ms = random.uniform(1.0, 3.5)
244
- await asyncio.sleep(pre_ms / 1000.0)
245
- stages.append(PipelineStage(name="Preprocess", status="done",
246
- latency_ms=round(pre_ms, 2),
247
- detail="Normalise + reshape tensor"))
248
-
249
- # β€” ONNX Runtime β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
250
- t1 = _now_ms()
251
- loaded = model.id in _MODEL_CACHE
252
- load_ms = 0.0 if loaded else random.uniform(50, 150)
253
- await asyncio.sleep(load_ms / 1000.0)
254
- if not loaded:
255
- _MODEL_CACHE[model.id] = object()
256
- # <<< REPLACE IN PRODUCTION >>>
257
- # import onnxruntime as ort
258
- # sess_opts = ort.SessionOptions()
259
- # _MODEL_CACHE[model.id] = ort.InferenceSession(
260
- # model.local_path, sess_options=sess_opts,
261
- # providers=[provider])
262
- stages.append(PipelineStage(name="ONNX Runtime", status="done",
263
- latency_ms=round(_now_ms() - t1, 2),
264
- detail=provider.replace("ExecutionProvider", "")))
265
-
266
- # β€” Inference β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
267
- t2 = _now_ms()
268
- infer_ms = random.uniform(3.0, 12.0)
269
- await asyncio.sleep(infer_ms / 1000.0)
270
- # <<< REPLACE IN PRODUCTION >>>
271
- # ort_inputs = {sess.get_inputs()[0].name: tensor.numpy()}
272
- # raw = sess.run(None, ort_inputs)
273
- stages.append(PipelineStage(name="Inference", status="done",
274
- latency_ms=round(infer_ms, 2),
275
- detail="session.run()"))
276
-
277
- # β€” Format Output β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
278
- t3 = _now_ms()
279
- post_ms = random.uniform(0.2, 0.8)
280
- raw_out = {"output_0": [round(random.random(), 4) for _ in range(10)]}
281
- stages.append(PipelineStage(name="Format Output", status="done",
282
- latency_ms=round(post_ms, 2),
283
- detail="Tensor → JSON"))
284
-
285
- output: dict[str, Any] = {
286
- "raw_output": raw_out,
287
- "pre_ms": round(pre_ms, 2),
288
- "infer_ms": round(infer_ms, 2),
289
- "post_ms": round(post_ms, 2),
290
- }
291
- return stages, output
292
-
293
-
294
- # ── Custom Python Pipeline ────────────────────────────────────────────────────
295
-
296
- class CustomPipeline:
297
- """
298
- Sandboxed custom Python pipeline.
299
- Executes user-supplied pre/postprocess scripts in a restricted namespace.
300
- Only numpy, the input tensor, and the model's raw output are accessible.
301
- """
302
-
303
- FORBIDDEN = ("import os", "import sys", "subprocess", "open(", "__import__",
304
- "eval(", "exec(", "globals(", "locals(")
305
-
306
- def _validate_script(self, script: str) -> str | None:
307
- for tok in self.FORBIDDEN:
308
- if tok in script:
309
- return f"Forbidden token in script: {tok!r}"
310
- return None
311
-
312
- async def run(
313
- self, req: InferenceRequest, model: Model
314
- ) -> tuple[list[PipelineStage], dict[str, Any]]:
315
- cfg = req.custom_config
316
- stages: list[PipelineStage] = []
317
-
318
- # β€” Validate scripts β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
319
- if cfg:
320
- for label, script in [("preprocess", cfg.preprocess_script),
321
- ("postprocess", cfg.postprocess_script)]:
322
- if script:
323
- err = self._validate_script(script)
324
- if err:
325
- return [PipelineStage(name=label.capitalize(),
326
- status="error", detail=err)], {}
327
-
328
- # β€” Transform Input β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
329
- pre_ms = random.uniform(1.0, 5.0)
330
- await asyncio.sleep(pre_ms / 1000.0)
331
- stages.append(PipelineStage(name="Transform Input", status="done",
332
- latency_ms=round(pre_ms, 2),
333
- detail="Custom preprocess script"))
334
-
335
- # β€” Run Inference β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
336
- infer_ms = random.uniform(5.0, 30.0)
337
- await asyncio.sleep(infer_ms / 1000.0)
338
- # <<< REPLACE IN PRODUCTION >>>
339
- # namespace = {"input": tensor, "model": raw_model}
340
- # exec(compile(cfg.preprocess_script, "<pre>", "exec"), namespace)
341
- # tensor = namespace.get("output", tensor)
342
- stages.append(PipelineStage(name="Run Inference", status="done",
343
- latency_ms=round(infer_ms, 2),
344
- detail="Custom runtime"))
345
-
346
- # β€” Format Result β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
347
- post_ms = random.uniform(0.5, 3.0)
348
- stages.append(PipelineStage(name="Format Result", status="done",
349
- latency_ms=round(post_ms, 2),
350
- detail="Custom postprocess script"))
351
-
352
- output: dict[str, Any] = {
353
- "raw_output": {"custom_result": round(random.random(), 4)},
354
- "pre_ms": round(pre_ms, 2),
355
- "infer_ms": round(infer_ms, 2),
356
- "post_ms": round(post_ms, 2),
357
- }
358
- return stages, output
359
-
360
-
361
- # ── Master Dispatcher ─────────────────────────────────────────────────────────
362
-
363
- _PIPELINE_MAP = {
364
- AdapterType.YOLO: YOLOPipeline,
365
- AdapterType.TRANSFORMERS: TransformersPipeline,
366
- AdapterType.ONNX: ONNXPipeline,
367
- AdapterType.CUSTOM: CustomPipeline,
368
- }
369
-
370
-
371
- class InferenceEngine:
372
- """
373
- Central inference dispatcher.
374
- Resolves the correct pipeline, executes it, and wraps the result
375
- into a fully-populated InferenceResult.
376
- """
377
-
378
- async def run(self, req: InferenceRequest, model: Model) -> InferenceResult:
379
- t_start = _now_ms()
380
- pipeline_cls = _PIPELINE_MAP.get(req.adapter_type)
381
- if pipeline_cls is None:
382
- return InferenceResult(
383
- request_id=str(uuid.uuid4()),
384
- model_id=req.model_id,
385
- adapter_type=req.adapter_type,
386
- status="error",
387
- error=f"Unknown adapter type: {req.adapter_type}",
388
- )
389
-
390
- try:
391
- stages, output = await pipeline_cls().run(req, model)
392
-
393
- total_ms = _now_ms() - t_start
394
- pre_ms = output.get("pre_ms", 0.0)
395
- infer_ms = output.get("infer_ms", 0.0)
396
- post_ms = output.get("post_ms", 0.0)
397
-
398
- # Quality score: mean confidence of detections (0–5 scale)
399
- detections = [Detection(**d) for d in output.get("detections", [])]
400
- if detections:
401
- mean_conf = sum(d.confidence for d in detections) / len(detections)
402
- quality = round(mean_conf * 5.0, 2)
403
- else:
404
- quality = round(random.uniform(3.2, 4.8), 2)
405
-
406
- result = InferenceResult(
407
- model_id = req.model_id,
408
- adapter_type = req.adapter_type,
409
- preprocess_ms = pre_ms,
410
- inference_ms = infer_ms,
411
- postprocess_ms= post_ms,
412
- total_ms = round(total_ms, 2),
413
- pipeline = stages,
414
- detections = detections,
415
- text_output = output.get("text_output"),
416
- raw_output = output.get("raw_output"),
417
- quality_score = quality,
418
- status = "ok",
419
- )
420
-
421
- log.info("inference_complete",
422
- model_id=req.model_id,
423
- adapter=req.adapter_type,
424
- total_ms=round(total_ms, 2))
425
- return result
426
-
427
- except Exception as exc:
428
- log.error("inference_error", model_id=req.model_id, error=str(exc))
429
- return InferenceResult(
430
- model_id=req.model_id,
431
- adapter_type=req.adapter_type,
432
- status="error",
433
- error=str(exc),
434
- )
435
-
436
-
437
- def get_cache_status() -> dict[str, bool]:
438
- """Return which model IDs are currently warm in cache."""
439
- return {k: True for k in _MODEL_CACHE}
440
-
441
-
442
- def evict_model(model_id: str) -> bool:
443
- """Evict a model from the in-process cache (free VRAM sim)."""
444
- if model_id in _MODEL_CACHE:
445
- del _MODEL_CACHE[model_id]
446
- return True
447
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/session.py DELETED
@@ -1,80 +0,0 @@
1
- """
2
- inference/session.py — In-memory inference session ledger.
3
-
4
- Keeps the last MAX_HISTORY inference results per process lifetime.
5
- Persisted to the SQLite `inference_history` table on each write
6
- (non-blocking via aiosqlite).
7
- """
8
- from __future__ import annotations
9
-
10
- import asyncio
11
- import json
12
- import uuid
13
- from collections import deque
14
- from typing import Deque
15
-
16
- from models.inference import InferenceHistoryEntry, InferenceRequest, InferenceResult
17
- from observability.logger import get_logger
18
-
19
- log = get_logger("inference.session")
20
-
21
- MAX_HISTORY = 200
22
-
23
- _history: Deque[InferenceHistoryEntry] = deque(maxlen=MAX_HISTORY)
24
- _lock = asyncio.Lock()
25
-
26
-
27
- async def record(req: InferenceRequest, result: InferenceResult, model_name: str) -> None:
28
- """Append a completed inference run to the ledger."""
29
- entry = InferenceHistoryEntry(
30
- model_id = req.model_id,
31
- model_name = model_name,
32
- adapter_type = req.adapter_type,
33
- total_ms = result.total_ms,
34
- quality_score = result.quality_score,
35
- status = result.status,
36
- request_snapshot = req.model_dump(exclude={"image_base64"}),
37
- )
38
- async with _lock:
39
- _history.appendleft(entry)
40
-
41
- # Persist to DB (fire-and-forget)
42
- asyncio.create_task(_persist(entry))
43
-
44
-
45
- async def _persist(entry: InferenceHistoryEntry) -> None:
46
- try:
47
- from database.connection import get_db
48
- async with get_db() as db:
49
- await db.execute(
50
- """
51
- INSERT OR REPLACE INTO inference_history
52
- (id, model_id, model_name, adapter_type, timestamp,
53
- total_ms, quality_score, status, request_snapshot)
54
- VALUES (?,?,?,?,?,?,?,?,?)
55
- """,
56
- (
57
- entry.id,
58
- entry.model_id,
59
- entry.model_name,
60
- entry.adapter_type.value,
61
- entry.timestamp,
62
- entry.total_ms,
63
- entry.quality_score,
64
- entry.status,
65
- json.dumps(entry.request_snapshot),
66
- ),
67
- )
68
- await db.commit()
69
- except Exception as exc:
70
- log.warning("inference_persist_failed", error=str(exc))
71
-
72
-
73
- async def get_history(limit: int = 50) -> list[InferenceHistoryEntry]:
74
- async with _lock:
75
- return list(_history)[:limit]
76
-
77
-
78
- async def clear_history() -> None:
79
- async with _lock:
80
- _history.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -26,7 +26,6 @@ from fastapi.responses import JSONResponse
26
  from api.routes import models as models_router
27
  from api.routes import sync as sync_router
28
  from api.routes import datasets as datasets_router
29
- from api.routes import projects as projects_router
30
  from config import settings
31
  from database.connection import close_db, get_db
32
  from middleware.logging_middleware import RequestLoggingMiddleware
@@ -65,9 +64,9 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
65
 
66
  # ── Application ───────────────────────────────────────────────────────────────
67
  app = FastAPI(
68
- title=settings.app_name,
69
  version=settings.version,
70
- description="Production ML Model Zoo backend — local-first, traceable, extensible.",
71
  docs_url="/docs",
72
  redoc_url="/redoc",
73
  lifespan=lifespan,
@@ -91,8 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
91
  # ── Middleware ─────────────────────────────────────────────────────────────────
92
  app.add_middleware(
93
  CORSMiddleware,
94
- allow_origins=settings.cors_origins,
95
- allow_origin_regex=r"^https?://(localhost|127\\.0\\.0\\.1)(:\\d+)?$",
96
  allow_credentials=True,
97
  allow_methods=["*"],
98
  allow_headers=["*"],
@@ -103,7 +101,6 @@ app.add_middleware(RequestLoggingMiddleware)
103
  app.include_router(models_router.router)
104
  app.include_router(sync_router.router)
105
  app.include_router(datasets_router.router)
106
- app.include_router(projects_router.router)
107
 
108
 
109
  @app.get("/health", tags=["system"])
@@ -114,6 +111,7 @@ async def health() -> dict:
114
  n_datasets = await count_datasets()
115
  return {
116
  "status": "ok",
 
117
  "version": settings.version,
118
  "model_count": n_models,
119
  "dataset_count": n_datasets,
 
26
  from api.routes import models as models_router
27
  from api.routes import sync as sync_router
28
  from api.routes import datasets as datasets_router
 
29
  from config import settings
30
  from database.connection import close_db, get_db
31
  from middleware.logging_middleware import RequestLoggingMiddleware
 
64
 
65
  # ── Application ───────────────────────────────────────────────────────────────
66
  app = FastAPI(
67
+ title="MLForge Cloud Registry",
68
  version=settings.version,
69
+ description="Global Model and Dataset Discovery Service — The Brain of MLForge.",
70
  docs_url="/docs",
71
  redoc_url="/redoc",
72
  lifespan=lifespan,
 
90
  # ── Middleware ─────────────────────────────────────────────────────────────────
91
  app.add_middleware(
92
  CORSMiddleware,
93
+ allow_origins=["*"], # Allow all origins for the cloud registry to support SDK/CLI/UI
 
94
  allow_credentials=True,
95
  allow_methods=["*"],
96
  allow_headers=["*"],
 
101
  app.include_router(models_router.router)
102
  app.include_router(sync_router.router)
103
  app.include_router(datasets_router.router)
 
104
 
105
 
106
  @app.get("/health", tags=["system"])
 
111
  n_datasets = await count_datasets()
112
  return {
113
  "status": "ok",
114
+ "service": "cloud_registry",
115
  "version": settings.version,
116
  "model_count": n_models,
117
  "dataset_count": n_datasets,
projects/__init__.py DELETED
File without changes