Spaces:
Sleeping
Sleeping
senthil2421 commited on
Commit Β·
99e3f1b
1
Parent(s): d81f11d
arch: refactor cloud_backend into lean discovery server by removing execution logic
Browse filesThis view is limited to 50 files because it contains too many changes. Β See raw diff
- adapters/__init__.py +0 -0
- adapters/__pycache__/__init__.cpython-310.pyc +0 -0
- adapters/__pycache__/base.cpython-310.pyc +0 -0
- adapters/__pycache__/hf_adapter.cpython-310.pyc +0 -0
- adapters/__pycache__/onnx_adapter.cpython-310.pyc +0 -0
- adapters/__pycache__/roboflow_adapter.cpython-310.pyc +0 -0
- adapters/base.py +0 -28
- adapters/hf_adapter.py +0 -415
- adapters/onnx_adapter.py +0 -176
- adapters/roboflow_adapter.py +0 -353
- api/routes/benchmark.py +0 -238
- api/routes/inference.py +0 -168
- api/routes/jobs.py +0 -56
- api/routes/system.py +0 -97
- api/routes/training.py +0 -428
- benchmark/__init__.py +0 -1
- benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
- benchmark/__pycache__/compatibility.cpython-310.pyc +0 -0
- benchmark/__pycache__/execution.cpython-310.pyc +0 -0
- benchmark/__pycache__/metrics.cpython-310.pyc +0 -0
- benchmark/__pycache__/orchestrator.cpython-310.pyc +0 -0
- benchmark/__pycache__/registry.cpython-310.pyc +0 -0
- benchmark/__pycache__/telemetry.cpython-310.pyc +0 -0
- benchmark/adapters/__pycache__/base.cpython-310.pyc +0 -0
- benchmark/adapters/__pycache__/registry.cpython-310.pyc +0 -0
- benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc +0 -0
- benchmark/adapters/base.py +0 -38
- benchmark/adapters/optimum_runner.py +0 -53
- benchmark/adapters/registry.py +0 -44
- benchmark/adapters/torch_runner.py +0 -45
- benchmark/compatibility.py +0 -360
- benchmark/execution.py +0 -366
- benchmark/metrics.py +0 -110
- benchmark/orchestrator.py +0 -374
- benchmark/registry.py +0 -302
- benchmark/telemetry.py +0 -182
- benchmark/torch_runner.py +0 -142
- config.py +4 -40
- download/__init__.py +0 -0
- download/__pycache__/__init__.cpython-310.pyc +0 -0
- download/__pycache__/manager.cpython-310.pyc +0 -0
- download/manager.py +0 -366
- inference/__init__.py +0 -1
- inference/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/__pycache__/engine.cpython-310.pyc +0 -0
- inference/__pycache__/session.cpython-310.pyc +0 -0
- inference/engine.py +0 -447
- inference/session.py +0 -80
- main.py +4 -6
- projects/__init__.py +0 -0
adapters/__init__.py
DELETED
|
File without changes
|
adapters/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (143 Bytes)
|
|
|
adapters/__pycache__/base.cpython-310.pyc
DELETED
|
Binary file (1.31 kB)
|
|
|
adapters/__pycache__/hf_adapter.cpython-310.pyc
DELETED
|
Binary file (13 kB)
|
|
|
adapters/__pycache__/onnx_adapter.cpython-310.pyc
DELETED
|
Binary file (5.27 kB)
|
|
|
adapters/__pycache__/roboflow_adapter.cpython-310.pyc
DELETED
|
Binary file (10.9 kB)
|
|
|
adapters/base.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
adapters/base.py β Abstract base class every source adapter must implement.
|
| 3 |
-
Enforces a stable contract so the registry never knows which adapter runs.
|
| 4 |
-
"""
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
-
from abc import ABC, abstractmethod
|
| 8 |
-
|
| 9 |
-
from models.model import Model
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class BaseAdapter(ABC):
|
| 13 |
-
"""Fetch models from an external source and normalize to the Model schema."""
|
| 14 |
-
|
| 15 |
-
source_name: str = "unknown"
|
| 16 |
-
|
| 17 |
-
@abstractmethod
|
| 18 |
-
async def fetch_models(self) -> list[Model]:
|
| 19 |
-
"""Return a list of normalized Model objects from the source."""
|
| 20 |
-
...
|
| 21 |
-
|
| 22 |
-
def _format_size(self, bytes_: int) -> str:
|
| 23 |
-
"""Human-readable file size."""
|
| 24 |
-
for unit in ("B", "KB", "MB", "GB", "TB"):
|
| 25 |
-
if bytes_ < 1024:
|
| 26 |
-
return f"{bytes_:.1f} {unit}"
|
| 27 |
-
bytes_ //= 1024
|
| 28 |
-
return f"{bytes_} PB"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
adapters/hf_adapter.py
DELETED
|
@@ -1,415 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
adapters/hf_adapter.py β Hugging Face Hub adapter.
|
| 3 |
-
Fetches real models via the public HF API and normalises them to our schema.
|
| 4 |
-
|
| 5 |
-
Rate-limits respected via polite delays. Requires no authentication for
|
| 6 |
-
publicly accessible models; set HF_TOKEN env var for higher rate-limits.
|
| 7 |
-
"""
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
|
| 10 |
-
import asyncio
|
| 11 |
-
import re
|
| 12 |
-
from typing import Any
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def _is_shard_file(filename: str) -> bool:
|
| 16 |
-
"""Return True for sharded weight files like model-00001-of-00003.safetensors."""
|
| 17 |
-
return bool(re.search(r"-\d{5}-of-\d{5}\.", filename))
|
| 18 |
-
|
| 19 |
-
import httpx
|
| 20 |
-
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 21 |
-
|
| 22 |
-
from adapters.base import BaseAdapter
|
| 23 |
-
from config import settings
|
| 24 |
-
from models.model import Model, ModelMetrics, ModelVersion
|
| 25 |
-
from observability.logger import get_logger
|
| 26 |
-
|
| 27 |
-
log = get_logger("hf_adapter")
|
| 28 |
-
|
| 29 |
-
# ββ Task mapping: HF pipeline_tag β our internal task βββββββββββββββββββββββββ
|
| 30 |
-
HF_TASK_MAP: dict[str, str] = {
|
| 31 |
-
"object-detection": "detection",
|
| 32 |
-
"image-classification": "classification",
|
| 33 |
-
"image-segmentation": "segmentation",
|
| 34 |
-
"text-to-image": "generation",
|
| 35 |
-
"image-to-image": "generation",
|
| 36 |
-
"image-feature-extraction": "embedding",
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
# Tasks we actively fetch
|
| 40 |
-
FETCH_TASKS: list[str] = list(HF_TASK_MAP.keys())
|
| 41 |
-
|
| 42 |
-
# ββ Framework detection ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
-
def _detect_framework(tags: list[str], model_id: str) -> str:
|
| 44 |
-
tag_str = " ".join(tags + [model_id]).lower()
|
| 45 |
-
if "onnx" in tag_str: return "onnx"
|
| 46 |
-
if "tflite" in tag_str: return "tflite"
|
| 47 |
-
if "coreml" in tag_str: return "coreml"
|
| 48 |
-
if "tensorflow" in tag_str or "tf" in tag_str: return "tensorflow"
|
| 49 |
-
return "pytorch" # HF default
|
| 50 |
-
|
| 51 |
-
# ββ Hardware detection βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
-
def _detect_hardware(tags: list[str]) -> list[str]:
|
| 53 |
-
hw: list[str] = []
|
| 54 |
-
tag_str = " ".join(tags).lower()
|
| 55 |
-
if any(k in tag_str for k in ("cuda", "gpu")): hw.append("gpu")
|
| 56 |
-
if "edge" in tag_str or "mobile" in tag_str: hw.append("edge")
|
| 57 |
-
if "cpu" in tag_str: hw.append("cpu")
|
| 58 |
-
if not hw: hw.append("gpu") # safe default
|
| 59 |
-
return hw
|
| 60 |
-
|
| 61 |
-
# ββ Internal tag normalisation βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
-
QUALITY_TAG_MAP = {
|
| 63 |
-
"state-of-the-art": "sota",
|
| 64 |
-
"lightweight": "lightweight",
|
| 65 |
-
"tiny": "tiny",
|
| 66 |
-
"fast": "fastest",
|
| 67 |
-
"real-time": "real-time",
|
| 68 |
-
"accuracy": "high-accuracy",
|
| 69 |
-
}
|
| 70 |
-
|
| 71 |
-
def _normalise_tags(raw_tags: list[str], pipeline: str) -> list[str]:
|
| 72 |
-
out: list[str] = []
|
| 73 |
-
for t in raw_tags:
|
| 74 |
-
t_lower = t.lower()
|
| 75 |
-
for keyword, mapped in QUALITY_TAG_MAP.items():
|
| 76 |
-
if keyword in t_lower:
|
| 77 |
-
out.append(mapped)
|
| 78 |
-
# keep relevant library / dataset tags
|
| 79 |
-
if any(t_lower.startswith(p) for p in ("dataset:", "license:", "language:")):
|
| 80 |
-
continue
|
| 81 |
-
out.append(t_lower)
|
| 82 |
-
# add pipeline as tag
|
| 83 |
-
if pipeline:
|
| 84 |
-
out.append(pipeline.replace("-", "_"))
|
| 85 |
-
return list(dict.fromkeys(out)) # deduplicate, preserve order
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
class HFAdapter(BaseAdapter):
|
| 89 |
-
source_name = "hf"
|
| 90 |
-
|
| 91 |
-
def __init__(self) -> None:
|
| 92 |
-
headers = {"Accept": "application/json"}
|
| 93 |
-
if settings.hf_token:
|
| 94 |
-
headers["Authorization"] = f"Bearer {settings.hf_token}"
|
| 95 |
-
self._client = httpx.AsyncClient(
|
| 96 |
-
base_url=settings.hf_api_base,
|
| 97 |
-
headers=headers,
|
| 98 |
-
timeout=30,
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
@retry(
|
| 102 |
-
stop=stop_after_attempt(3),
|
| 103 |
-
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 104 |
-
reraise=True,
|
| 105 |
-
)
|
| 106 |
-
async def _fetch_task_page(
|
| 107 |
-
self, pipeline_tag: str, limit: int = 100
|
| 108 |
-
) -> list[dict[str, Any]]:
|
| 109 |
-
params = {
|
| 110 |
-
"pipeline_tag": pipeline_tag,
|
| 111 |
-
"sort": "downloads",
|
| 112 |
-
"direction": -1, # descending
|
| 113 |
-
"limit": limit,
|
| 114 |
-
"full": "True",
|
| 115 |
-
}
|
| 116 |
-
log.info("hf_fetch_task", pipeline_tag=pipeline_tag, limit=limit)
|
| 117 |
-
resp = await self._client.get("/models", params=params)
|
| 118 |
-
resp.raise_for_status()
|
| 119 |
-
return resp.json()
|
| 120 |
-
|
| 121 |
-
@retry(
|
| 122 |
-
stop=stop_after_attempt(3),
|
| 123 |
-
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 124 |
-
reraise=True,
|
| 125 |
-
)
|
| 126 |
-
async def _fetch_model_detail(self, model_id: str) -> dict[str, Any]:
|
| 127 |
-
resp = await self._client.get(f"/models/{model_id}", params={"full": "True"})
|
| 128 |
-
resp.raise_for_status()
|
| 129 |
-
raw = resp.json()
|
| 130 |
-
|
| 131 |
-
siblings: list[dict[str, Any]] = raw.get("siblings") or []
|
| 132 |
-
has_any_size = any(isinstance(s, dict) and s.get("size") for s in siblings)
|
| 133 |
-
if not has_any_size:
|
| 134 |
-
try:
|
| 135 |
-
tree = await self._fetch_model_tree(model_id, revision="main")
|
| 136 |
-
size_by_path: dict[str, int] = {
|
| 137 |
-
(t.get("path") or ""): int(t.get("size") or 0)
|
| 138 |
-
for t in (tree or [])
|
| 139 |
-
if isinstance(t, dict)
|
| 140 |
-
}
|
| 141 |
-
|
| 142 |
-
patched: list[dict[str, Any]] = []
|
| 143 |
-
for s in siblings:
|
| 144 |
-
if not isinstance(s, dict):
|
| 145 |
-
continue
|
| 146 |
-
fn = s.get("rfilename") or s.get("path") or ""
|
| 147 |
-
if fn and not s.get("size") and fn in size_by_path:
|
| 148 |
-
s = {**s, "size": size_by_path[fn]}
|
| 149 |
-
patched.append(s)
|
| 150 |
-
raw["siblings"] = patched
|
| 151 |
-
except Exception:
|
| 152 |
-
pass
|
| 153 |
-
|
| 154 |
-
return raw
|
| 155 |
-
|
| 156 |
-
@retry(
|
| 157 |
-
stop=stop_after_attempt(3),
|
| 158 |
-
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 159 |
-
reraise=True,
|
| 160 |
-
)
|
| 161 |
-
async def _fetch_model_tree(self, model_id: str, *, revision: str = "main") -> list[dict[str, Any]]:
|
| 162 |
-
resp = await self._client.get(f"/models/{model_id}/tree/{revision}")
|
| 163 |
-
resp.raise_for_status()
|
| 164 |
-
data = resp.json()
|
| 165 |
-
if isinstance(data, list):
|
| 166 |
-
return data
|
| 167 |
-
return []
|
| 168 |
-
|
| 169 |
-
def _parse_safe_tensors_size(self, siblings: list[dict]) -> int:
|
| 170 |
-
"""Estimate model size from sibling file list."""
|
| 171 |
-
total = 0
|
| 172 |
-
weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")
|
| 173 |
-
for s in siblings or []:
|
| 174 |
-
filename = s.get("rfilename", "").lower()
|
| 175 |
-
if filename.endswith(weight_exts):
|
| 176 |
-
total += s.get("size", 0)
|
| 177 |
-
|
| 178 |
-
if total > 0:
|
| 179 |
-
return total
|
| 180 |
-
|
| 181 |
-
# If no size found in siblings, check if it's in the root dict (sometimes HF API does this)
|
| 182 |
-
return 0 # Return 0 if not found, we'll handle fallback in _make_model
|
| 183 |
-
|
| 184 |
-
@retry(
|
| 185 |
-
stop=stop_after_attempt(3),
|
| 186 |
-
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 187 |
-
reraise=True,
|
| 188 |
-
)
|
| 189 |
-
async def _fetch_model_card(self, model_id: str) -> str:
|
| 190 |
-
"""Fetch model card (README.md) content for real-time description."""
|
| 191 |
-
url = f"{settings.hf_hub_url}/{model_id}/raw/main/README.md"
|
| 192 |
-
try:
|
| 193 |
-
resp = await self._client.get(url)
|
| 194 |
-
if resp.status_code == 200:
|
| 195 |
-
return resp.text
|
| 196 |
-
except Exception:
|
| 197 |
-
pass
|
| 198 |
-
return ""
|
| 199 |
-
|
| 200 |
-
def _extract_description(self, readme: str, raw: dict[str, Any]) -> str:
|
| 201 |
-
"""Extract a clean description from README or card data."""
|
| 202 |
-
if readme:
|
| 203 |
-
# Simple heuristic: take first paragraph that isn't frontmatter
|
| 204 |
-
lines = readme.split("\n")
|
| 205 |
-
in_frontmatter = False
|
| 206 |
-
for line in lines:
|
| 207 |
-
if line.strip() == "---":
|
| 208 |
-
in_frontmatter = not in_frontmatter
|
| 209 |
-
continue
|
| 210 |
-
if not in_frontmatter and line.strip() and not line.startswith("#"):
|
| 211 |
-
return line.strip()[:500]
|
| 212 |
-
|
| 213 |
-
card_data = raw.get("cardData") or {}
|
| 214 |
-
description: str = (
|
| 215 |
-
(card_data.get("summary") or "")
|
| 216 |
-
or (card_data.get("description") or "")
|
| 217 |
-
or (raw.get("description") or "")
|
| 218 |
-
).strip()
|
| 219 |
-
return description
|
| 220 |
-
|
| 221 |
-
def _estimate_metrics(self, model_id: str, task: str) -> ModelMetrics:
|
| 222 |
-
"""
|
| 223 |
-
Product-Grade Metrics Estimation.
|
| 224 |
-
Uses model name heuristics to provide realistic data for common architectures.
|
| 225 |
-
"""
|
| 226 |
-
metrics = ModelMetrics()
|
| 227 |
-
m_id = model_id.lower()
|
| 228 |
-
|
| 229 |
-
# Base latency/vram estimates by architecture
|
| 230 |
-
if "vit" in m_id or "dinov2" in m_id:
|
| 231 |
-
metrics.latency_ms = 45.5 if "base" in m_id else 85.2 if "large" in m_id else 25.0
|
| 232 |
-
metrics.vram_gb = 1.2 if "base" in m_id else 2.4 if "large" in m_id else 0.8
|
| 233 |
-
metrics.accuracy = 82.4 if "base" in m_id else 84.5
|
| 234 |
-
elif "segformer" in m_id:
|
| 235 |
-
# b0, b1, b2, b3, b4, b5
|
| 236 |
-
if "b0" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 12.0, 0.4, 35.0
|
| 237 |
-
elif "b1" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 18.0, 0.6, 40.0
|
| 238 |
-
elif "b5" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.0, 1.8, 50.0
|
| 239 |
-
else: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 1.0, 42.0
|
| 240 |
-
elif "convnext" in m_id:
|
| 241 |
-
metrics.latency_ms = 15.0 if "tiny" in m_id else 30.0
|
| 242 |
-
metrics.vram_gb = 0.5 if "tiny" in m_id else 1.2
|
| 243 |
-
metrics.accuracy = 81.0 if "tiny" in m_id else 83.5
|
| 244 |
-
elif "yolo" in m_id:
|
| 245 |
-
# n, s, m, l, x
|
| 246 |
-
if "yolov8n" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 1.5, 0.2, 37.3
|
| 247 |
-
elif "yolov8s" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 2.8, 0.4, 44.9
|
| 248 |
-
elif "yolov8m" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 6.2, 0.9, 50.2
|
| 249 |
-
else: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 10.0, 1.5, 52.0
|
| 250 |
-
|
| 251 |
-
# Generic task-based fallbacks if still empty
|
| 252 |
-
if metrics.latency_ms is None:
|
| 253 |
-
if task == "classification": metrics.latency_ms, metrics.accuracy = 20.0, 75.0
|
| 254 |
-
elif task == "detection": metrics.latency_ms, metrics.mAP = 35.0, 45.0
|
| 255 |
-
elif task == "embedding": metrics.latency_ms = 40.0
|
| 256 |
-
elif task == "generation": metrics.latency_ms = 1500.0
|
| 257 |
-
|
| 258 |
-
return metrics
|
| 259 |
-
|
| 260 |
-
def _make_model(self, raw: dict[str, Any], pipeline_tag: str) -> Model | None:
|
| 261 |
-
model_id: str = raw.get("id") or raw.get("modelId", "")
|
| 262 |
-
if not model_id:
|
| 263 |
-
return None
|
| 264 |
-
|
| 265 |
-
task = HF_TASK_MAP.get(pipeline_tag)
|
| 266 |
-
if not task:
|
| 267 |
-
return None
|
| 268 |
-
tags_raw: list[str] = raw.get("tags") or []
|
| 269 |
-
framework = _detect_framework(tags_raw, model_id)
|
| 270 |
-
hardware = _detect_hardware(tags_raw)
|
| 271 |
-
tags = _normalise_tags(tags_raw, pipeline_tag)
|
| 272 |
-
|
| 273 |
-
# Size
|
| 274 |
-
siblings: list[dict] = raw.get("siblings") or []
|
| 275 |
-
size = self._parse_safe_tensors_size(siblings)
|
| 276 |
-
if size == 0:
|
| 277 |
-
# Fallback based on model type if size not found
|
| 278 |
-
if "large" in model_id.lower(): size = 1_200_000_000
|
| 279 |
-
elif "base" in model_id.lower(): size = 500_000_000
|
| 280 |
-
elif "small" in model_id.lower() or "tiny" in model_id.lower(): size = 150_000_000
|
| 281 |
-
else: size = 450_000_000 # More realistic general default than exactly 500MB
|
| 282 |
-
|
| 283 |
-
# Provider β author part of model_id
|
| 284 |
-
provider = model_id.split("/")[0] if "/" in model_id else "community"
|
| 285 |
-
|
| 286 |
-
# safe name
|
| 287 |
-
name = model_id.split("/")[-1] if "/" in model_id else model_id
|
| 288 |
-
# Clean ugly names
|
| 289 |
-
name = re.sub(r"[-_]+", "-", name).strip("-")
|
| 290 |
-
|
| 291 |
-
downloads = raw.get("downloads") or 0
|
| 292 |
-
likes = raw.get("likes") or 0
|
| 293 |
-
|
| 294 |
-
# Fabricate a sensible version from last modified
|
| 295 |
-
last_mod: str = raw.get("lastModified") or raw.get("createdAt") or ""
|
| 296 |
-
release_date = last_mod[:10] if last_mod else "2024-01-01"
|
| 297 |
-
sha8 = (raw.get("sha") or "main")[:8]
|
| 298 |
-
|
| 299 |
-
# Build versions from weight files in the repo (one per distinct weight file)
|
| 300 |
-
weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")
|
| 301 |
-
weight_files = [
|
| 302 |
-
s for s in siblings
|
| 303 |
-
if s.get("rfilename", "").lower().endswith(weight_exts)
|
| 304 |
-
and not _is_shard_file(s.get("rfilename", ""))
|
| 305 |
-
]
|
| 306 |
-
|
| 307 |
-
if len(weight_files) > 1:
|
| 308 |
-
versions = []
|
| 309 |
-
for s in weight_files[:15]:
|
| 310 |
-
filename = s["rfilename"]
|
| 311 |
-
# Detect variant from filename (n, s, m, l, x, or specific labels)
|
| 312 |
-
variant_label = "Stable"
|
| 313 |
-
fn_lower = filename.lower()
|
| 314 |
-
if any(x in fn_lower for x in ["-n.", "_n.", "nano"]): variant_label = "Nano"
|
| 315 |
-
elif any(x in fn_lower for x in ["-s.", "_s.", "small"]): variant_label = "Small"
|
| 316 |
-
elif any(x in fn_lower for x in ["-m.", "_m.", "medium"]): variant_label = "Medium"
|
| 317 |
-
elif any(x in fn_lower for x in ["-l.", "_l.", "large"]): variant_label = "Large"
|
| 318 |
-
elif any(x in fn_lower for x in ["-x.", "_x.", "xlarge", "huge"]): variant_label = "XLarge"
|
| 319 |
-
|
| 320 |
-
versions.append(ModelVersion(
|
| 321 |
-
version=filename.replace(".", "_"),
|
| 322 |
-
label=variant_label,
|
| 323 |
-
description=f"Model variant: {filename}",
|
| 324 |
-
releaseDate=release_date,
|
| 325 |
-
changelog=None,
|
| 326 |
-
))
|
| 327 |
-
else:
|
| 328 |
-
versions = [
|
| 329 |
-
ModelVersion(
|
| 330 |
-
version=sha8,
|
| 331 |
-
label="Latest",
|
| 332 |
-
description="Primary model weight file.",
|
| 333 |
-
releaseDate=release_date,
|
| 334 |
-
changelog=None,
|
| 335 |
-
)
|
| 336 |
-
]
|
| 337 |
-
|
| 338 |
-
# Description from card data
|
| 339 |
-
description = self._extract_description("", raw)
|
| 340 |
-
if not description:
|
| 341 |
-
description = f"{task.capitalize()} model by {provider}."
|
| 342 |
-
|
| 343 |
-
# Metrics Estimation
|
| 344 |
-
metrics = self._estimate_metrics(model_id, task)
|
| 345 |
-
|
| 346 |
-
return Model(
|
| 347 |
-
id = model_id.replace("/", "_").lower(),
|
| 348 |
-
name = name,
|
| 349 |
-
task = task,
|
| 350 |
-
framework = framework,
|
| 351 |
-
source = "hf",
|
| 352 |
-
provider = provider,
|
| 353 |
-
description = description,
|
| 354 |
-
download_url = f"https://huggingface.co/{model_id}",
|
| 355 |
-
size = size,
|
| 356 |
-
size_label = self._format_size(size),
|
| 357 |
-
tags = tags,
|
| 358 |
-
hardware = hardware,
|
| 359 |
-
status = "available",
|
| 360 |
-
downloaded = False,
|
| 361 |
-
downloads = downloads,
|
| 362 |
-
rating = min(5.0, (likes / 200) + 3.5) if likes else None,
|
| 363 |
-
liked = False,
|
| 364 |
-
metrics = metrics,
|
| 365 |
-
versions = versions,
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
async def fetch_models(self) -> list[Model]:
|
| 369 |
-
models: list[Model] = []
|
| 370 |
-
seen_ids: set[str] = set()
|
| 371 |
-
|
| 372 |
-
for pipeline_tag in FETCH_TASKS:
|
| 373 |
-
try:
|
| 374 |
-
raw_list = await self._fetch_task_page(
|
| 375 |
-
pipeline_tag, limit=settings.hf_models_per_task
|
| 376 |
-
)
|
| 377 |
-
for idx, raw in enumerate(raw_list):
|
| 378 |
-
# Enrich top-N per task with full model detail so siblings include sizes.
|
| 379 |
-
if idx < 10:
|
| 380 |
-
original_id = raw.get("id") or raw.get("modelId")
|
| 381 |
-
if original_id:
|
| 382 |
-
try:
|
| 383 |
-
raw = await self._fetch_model_detail(original_id)
|
| 384 |
-
except Exception:
|
| 385 |
-
pass
|
| 386 |
-
|
| 387 |
-
m = self._make_model(raw, pipeline_tag)
|
| 388 |
-
if m and m.id not in seen_ids:
|
| 389 |
-
# Try to fetch real-time description for the first 5 models of each task
|
| 390 |
-
if len([mod for mod in models if mod.task == m.task]) < 5:
|
| 391 |
-
original_id = raw.get("id") or raw.get("modelId")
|
| 392 |
-
if original_id:
|
| 393 |
-
readme = await self._fetch_model_card(original_id)
|
| 394 |
-
if readme:
|
| 395 |
-
m.description = self._extract_description(readme, raw)
|
| 396 |
-
|
| 397 |
-
seen_ids.add(m.id)
|
| 398 |
-
models.append(m)
|
| 399 |
-
# Be polite to HF API
|
| 400 |
-
await asyncio.sleep(0.3)
|
| 401 |
-
except Exception as exc:
|
| 402 |
-
log.warning(
|
| 403 |
-
"hf_fetch_task_failed",
|
| 404 |
-
pipeline_tag=pipeline_tag,
|
| 405 |
-
error=str(exc),
|
| 406 |
-
)
|
| 407 |
-
|
| 408 |
-
log.info("hf_fetch_complete", total=len(models))
|
| 409 |
-
return models
|
| 410 |
-
|
| 411 |
-
async def __aenter__(self) -> "HFAdapter":
|
| 412 |
-
return self
|
| 413 |
-
|
| 414 |
-
async def __aexit__(self, *_: Any) -> None:
|
| 415 |
-
await self._client.aclose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
adapters/onnx_adapter.py
DELETED
|
@@ -1,176 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
adapters/onnx_adapter.py β ONNX Model Zoo adapter.
|
| 3 |
-
Fetches the curated list of ONNX Zoo models from the GitHub API.
|
| 4 |
-
"""
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
-
from typing import Any
|
| 8 |
-
|
| 9 |
-
import httpx
|
| 10 |
-
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 11 |
-
|
| 12 |
-
from adapters.base import BaseAdapter
|
| 13 |
-
from models.model import Model, ModelMetrics, ModelVersion
|
| 14 |
-
from observability.logger import get_logger
|
| 15 |
-
|
| 16 |
-
log = get_logger("onnx_adapter")
|
| 17 |
-
|
| 18 |
-
# Curated ONNX Zoo models with metadata + download URLs (GitHub API is rate-limited without auth)
|
| 19 |
-
ONNX_CURATED: list[dict[str, Any]] = [
|
| 20 |
-
{
|
| 21 |
-
"id": "onnx_resnet50",
|
| 22 |
-
"name": "ResNet-50",
|
| 23 |
-
"task": "classification",
|
| 24 |
-
"provider": "ONNX Zoo",
|
| 25 |
-
"description": "ResNet-50 v1 image classification model in ONNX format.",
|
| 26 |
-
"download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx",
|
| 27 |
-
"size": 102_000_000,
|
| 28 |
-
"tags": ["resnet", "imagenet", "classification"],
|
| 29 |
-
"hardware": ["gpu", "cpu"],
|
| 30 |
-
"metrics": {"latency_ms": 14.2, "top1": 74.9},
|
| 31 |
-
"downloads": 250_000,
|
| 32 |
-
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-06-01"}],
|
| 33 |
-
},
|
| 34 |
-
{
|
| 35 |
-
"id": "onnx_yolov8n",
|
| 36 |
-
"name": "YOLOv8n",
|
| 37 |
-
"task": "detection",
|
| 38 |
-
"provider": "Ultralytics",
|
| 39 |
-
"description": "Ultralytics YOLOv8 Nano β real-time object detection, ONNX export.",
|
| 40 |
-
"download_url": "https://github.com/ultralytics/yolov8/releases/download/v8.0.0/yolov8n.onnx",
|
| 41 |
-
"size": 6_200_000,
|
| 42 |
-
"tags": ["yolo", "real-time", "fastest", "edge"],
|
| 43 |
-
"hardware": ["gpu", "cpu", "edge"],
|
| 44 |
-
"metrics": {"latency_ms": 3.1, "mAP": 37.3},
|
| 45 |
-
"downloads": 420_000,
|
| 46 |
-
"versions": [{"version": "8.0", "label": "Latest", "releaseDate": "2023-09-15"}],
|
| 47 |
-
},
|
| 48 |
-
{
|
| 49 |
-
"id": "onnx_mobilenet_v3",
|
| 50 |
-
"name": "MobileNetV3-Large",
|
| 51 |
-
"task": "classification",
|
| 52 |
-
"provider": "Google",
|
| 53 |
-
"description": "MobileNetV3-Large for efficient on-device image classification.",
|
| 54 |
-
"download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv3-large-1.11.onnx",
|
| 55 |
-
"size": 22_000_000,
|
| 56 |
-
"tags": ["mobilenet", "lightweight", "edge", "efficient"],
|
| 57 |
-
"hardware": ["cpu", "edge"],
|
| 58 |
-
"metrics": {"latency_ms": 5.8, "top1": 75.2, "fps": 180},
|
| 59 |
-
"downloads": 310_000,
|
| 60 |
-
"versions": [{"version": "3.0", "label": "Latest", "releaseDate": "2023-01-01"}],
|
| 61 |
-
},
|
| 62 |
-
{
|
| 63 |
-
"id": "onnx_bert_base_uncased",
|
| 64 |
-
"name": "BERT-Base-Uncased",
|
| 65 |
-
"task": "nlp",
|
| 66 |
-
"provider": "Google",
|
| 67 |
-
"description": "BERT base model fine-tuned for NLP inference in ONNX format.",
|
| 68 |
-
"download_url": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx",
|
| 69 |
-
"size": 438_000_000,
|
| 70 |
-
"tags": ["bert", "nlp", "transformer"],
|
| 71 |
-
"hardware": ["gpu", "cpu"],
|
| 72 |
-
"metrics": {"latency_ms": 42.0},
|
| 73 |
-
"downloads": 198_000,
|
| 74 |
-
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2022-11-01"}],
|
| 75 |
-
},
|
| 76 |
-
{
|
| 77 |
-
"id": "onnx_efficientnet_b0",
|
| 78 |
-
"name": "EfficientNet-B0",
|
| 79 |
-
"task": "classification",
|
| 80 |
-
"provider": "Google Brain",
|
| 81 |
-
"description": "EfficientNet-B0 for scalable image classification.",
|
| 82 |
-
"download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite/model/efficientnet-lite4-11.onnx",
|
| 83 |
-
"size": 20_000_000,
|
| 84 |
-
"tags": ["efficientnet", "efficient", "high-accuracy"],
|
| 85 |
-
"hardware": ["gpu", "cpu"],
|
| 86 |
-
"metrics": {"latency_ms": 10.4, "top1": 77.1},
|
| 87 |
-
"downloads": 145_000,
|
| 88 |
-
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-03-01"}],
|
| 89 |
-
},
|
| 90 |
-
{
|
| 91 |
-
"id": "onnx_sam_vit_b",
|
| 92 |
-
"name": "SAM ViT-B",
|
| 93 |
-
"task": "segmentation",
|
| 94 |
-
"provider": "Meta AI",
|
| 95 |
-
"description": "Segment Anything Model (ViT-B) for universal image segmentation.",
|
| 96 |
-
"download_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
|
| 97 |
-
"size": 375_000_000,
|
| 98 |
-
"tags": ["sam", "segmentation", "sota"],
|
| 99 |
-
"hardware": ["gpu"],
|
| 100 |
-
"metrics": {"latency_ms": 68.0},
|
| 101 |
-
"downloads": 88_000,
|
| 102 |
-
"versions": [{"version": "1.0", "label": "Latest", "releaseDate": "2023-04-05"}],
|
| 103 |
-
},
|
| 104 |
-
{
|
| 105 |
-
"id": "onnx_clip_vit_b32",
|
| 106 |
-
"name": "CLIP ViT-B/32",
|
| 107 |
-
"task": "embedding",
|
| 108 |
-
"provider": "OpenAI",
|
| 109 |
-
"description": "CLIP image + text embedding model for zero-shot classification.",
|
| 110 |
-
"download_url": "https://openaipublic.blob.core.windows.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba4f386/ViT-B-32.pt",
|
| 111 |
-
"size": 338_000_000,
|
| 112 |
-
"tags": ["clip", "embedding", "multimodal"],
|
| 113 |
-
"hardware": ["gpu", "cpu"],
|
| 114 |
-
"metrics": {"latency_ms": 25.0},
|
| 115 |
-
"downloads": 275_000,
|
| 116 |
-
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-01-01"}],
|
| 117 |
-
},
|
| 118 |
-
{
|
| 119 |
-
"id": "onnx_whisper_tiny",
|
| 120 |
-
"name": "Whisper Tiny",
|
| 121 |
-
"task": "nlp",
|
| 122 |
-
"provider": "OpenAI",
|
| 123 |
-
"description": "Whisper Tiny speech-to-text model in ONNX format.",
|
| 124 |
-
"download_url": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424930e36a852c0/tiny.pt",
|
| 125 |
-
"size": 39_000_000,
|
| 126 |
-
"tags": ["whisper", "speech", "lightweight"],
|
| 127 |
-
"hardware": ["cpu", "edge"],
|
| 128 |
-
"metrics": {"latency_ms": 100.0},
|
| 129 |
-
"downloads": 167_000,
|
| 130 |
-
"versions": [{"version": "20231117", "label": "Latest", "releaseDate": "2023-11-17"}],
|
| 131 |
-
},
|
| 132 |
-
]
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
class ONNXAdapter(BaseAdapter):
|
| 136 |
-
source_name = "onnx"
|
| 137 |
-
|
| 138 |
-
async def fetch_models(self) -> list[Model]:
|
| 139 |
-
models: list[Model] = []
|
| 140 |
-
for raw in ONNX_CURATED:
|
| 141 |
-
try:
|
| 142 |
-
versions = [
|
| 143 |
-
ModelVersion(
|
| 144 |
-
version=v["version"],
|
| 145 |
-
label=v.get("label", "Stable"),
|
| 146 |
-
releaseDate=v.get("releaseDate", ""),
|
| 147 |
-
)
|
| 148 |
-
for v in raw.get("versions", [])
|
| 149 |
-
]
|
| 150 |
-
metrics_raw = raw.get("metrics", {})
|
| 151 |
-
m = Model(
|
| 152 |
-
id = raw["id"],
|
| 153 |
-
name = raw["name"],
|
| 154 |
-
task = raw["task"],
|
| 155 |
-
framework = "onnx",
|
| 156 |
-
source = "onnx",
|
| 157 |
-
provider = raw.get("provider", "ONNX Zoo"),
|
| 158 |
-
description = raw.get("description", ""),
|
| 159 |
-
download_url = raw.get("download_url"),
|
| 160 |
-
size = raw.get("size", 0),
|
| 161 |
-
size_label = self._format_size(raw.get("size", 0)),
|
| 162 |
-
tags = raw.get("tags", []),
|
| 163 |
-
hardware = raw.get("hardware", ["gpu"]),
|
| 164 |
-
status = "available",
|
| 165 |
-
downloaded = False,
|
| 166 |
-
downloads = raw.get("downloads"),
|
| 167 |
-
rating = 4.2,
|
| 168 |
-
metrics = ModelMetrics(**metrics_raw),
|
| 169 |
-
versions = versions,
|
| 170 |
-
)
|
| 171 |
-
models.append(m)
|
| 172 |
-
except Exception as exc:
|
| 173 |
-
log.warning("onnx_parse_failed", model_id=raw.get("id"), error=str(exc))
|
| 174 |
-
|
| 175 |
-
log.info("onnx_fetch_complete", total=len(models))
|
| 176 |
-
return models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
adapters/roboflow_adapter.py
DELETED
|
@@ -1,353 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
adapters/roboflow_adapter.py β Roboflow Universe API client.
|
| 3 |
-
|
| 4 |
-
Responsibilities:
|
| 5 |
-
- Fetch dataset metadata (search, workspace listings, project details)
|
| 6 |
-
- Normalise responses β Dataset domain model
|
| 7 |
-
- Cache results in roboflow_cache table (TTL-aware)
|
| 8 |
-
- Handle pagination, rate limits, and errors robustly
|
| 9 |
-
|
| 10 |
-
Roboflow API reference: https://docs.roboflow.com/api-reference/
|
| 11 |
-
"""
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
import hashlib
|
| 15 |
-
import json
|
| 16 |
-
import time
|
| 17 |
-
from typing import Any
|
| 18 |
-
|
| 19 |
-
import httpx
|
| 20 |
-
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 21 |
-
|
| 22 |
-
from database.connection import get_db
|
| 23 |
-
from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask
|
| 24 |
-
from observability.logger import audit, get_logger
|
| 25 |
-
|
| 26 |
-
log = get_logger("roboflow_adapter")
|
| 27 |
-
|
| 28 |
-
_ROBOFLOW_BASE = "https://api.roboflow.com"
|
| 29 |
-
_UNIVERSE_BASE = "https://universe.roboflow.com"
|
| 30 |
-
_DEFAULT_TTL = 3600 # 1 hour
|
| 31 |
-
|
| 32 |
-
# ββ Task mapping from Roboflow annotation_type βββββββββββββββββββββββββββββββ
|
| 33 |
-
|
| 34 |
-
_TASK_MAP: dict[str, DatasetTask] = {
|
| 35 |
-
"object-detection": DatasetTask.detection,
|
| 36 |
-
"instance-segmentation": DatasetTask.segmentation,
|
| 37 |
-
"semantic-segmentation": DatasetTask.segmentation,
|
| 38 |
-
"classification": DatasetTask.classification,
|
| 39 |
-
"keypoint-detection": DatasetTask.keypoints,
|
| 40 |
-
"multiclass-classification": DatasetTask.classification,
|
| 41 |
-
}
|
| 42 |
-
|
| 43 |
-
_FORMAT_MAP: dict[str, DatasetFormat] = {
|
| 44 |
-
"yolov5": DatasetFormat.yolo,
|
| 45 |
-
"yolov7": DatasetFormat.yolo,
|
| 46 |
-
"yolov8": DatasetFormat.yolo,
|
| 47 |
-
"yolov9": DatasetFormat.yolo,
|
| 48 |
-
"coco": DatasetFormat.coco,
|
| 49 |
-
"voc": DatasetFormat.voc,
|
| 50 |
-
"tfrecord": DatasetFormat.tfrecord,
|
| 51 |
-
"csv": DatasetFormat.csv,
|
| 52 |
-
"createml": DatasetFormat.json,
|
| 53 |
-
"multiclass": DatasetFormat.csv,
|
| 54 |
-
}
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def _cache_key(parts: list[str]) -> str:
|
| 58 |
-
raw = "|".join(parts)
|
| 59 |
-
return hashlib.sha256(raw.encode()).hexdigest()[:32]
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def _fmt_bytes(n: int) -> str:
|
| 63 |
-
for unit in ("B", "KB", "MB", "GB", "TB"):
|
| 64 |
-
if n < 1024:
|
| 65 |
-
return f"{n:.1f} {unit}"
|
| 66 |
-
n /= 1024
|
| 67 |
-
return f"{n:.1f} PB"
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
# ββ Cache helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
-
|
| 72 |
-
async def _cache_get(key: str) -> dict[str, Any] | None:
|
| 73 |
-
db = await get_db()
|
| 74 |
-
async with db.execute(
|
| 75 |
-
"SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?",
|
| 76 |
-
(key,),
|
| 77 |
-
) as cur:
|
| 78 |
-
row = await cur.fetchone()
|
| 79 |
-
if row is None:
|
| 80 |
-
return None
|
| 81 |
-
fetched = time.mktime(time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S"))
|
| 82 |
-
if time.time() - fetched > row["ttl_secs"]:
|
| 83 |
-
return None # expired
|
| 84 |
-
return json.loads(row["payload"])
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
async def _cache_set(key: str, payload: dict[str, Any], ttl: int = _DEFAULT_TTL) -> None:
|
| 88 |
-
db = await get_db()
|
| 89 |
-
await db.execute(
|
| 90 |
-
"""INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs)
|
| 91 |
-
VALUES (?, ?, ?)""",
|
| 92 |
-
(key, json.dumps(payload), ttl),
|
| 93 |
-
)
|
| 94 |
-
await db.commit()
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# ββ HTTP client factory βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 98 |
-
|
| 99 |
-
def _make_client(api_key: str) -> httpx.AsyncClient:
|
| 100 |
-
return httpx.AsyncClient(
|
| 101 |
-
base_url=_ROBOFLOW_BASE,
|
| 102 |
-
params={"api_key": api_key},
|
| 103 |
-
timeout=30.0,
|
| 104 |
-
headers={"User-Agent": "MLForge/1.0"},
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# ββ Roboflow Adapter ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
-
|
| 110 |
-
class RoboflowAdapter:
|
| 111 |
-
"""
|
| 112 |
-
Stateless adapter for the Roboflow API.
|
| 113 |
-
All methods accept api_key explicitly to support per-user keys.
|
| 114 |
-
"""
|
| 115 |
-
|
| 116 |
-
# ββ Search (Universe) βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
-
|
| 118 |
-
@staticmethod
|
| 119 |
-
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
|
| 120 |
-
async def search_datasets(
|
| 121 |
-
api_key: str,
|
| 122 |
-
query: str = "",
|
| 123 |
-
workspace: str | None = None,
|
| 124 |
-
page: int = 0,
|
| 125 |
-
page_size: int = 50,
|
| 126 |
-
) -> list[Dataset]:
|
| 127 |
-
"""
|
| 128 |
-
Search Roboflow Universe for datasets.
|
| 129 |
-
Returns normalised Dataset objects.
|
| 130 |
-
"""
|
| 131 |
-
ck = _cache_key(["search", query, str(workspace), str(page), str(page_size)])
|
| 132 |
-
cached = await _cache_get(ck)
|
| 133 |
-
if cached:
|
| 134 |
-
log.debug("roboflow_cache_hit", key=ck, query=query)
|
| 135 |
-
return [Dataset(**d) for d in cached]
|
| 136 |
-
|
| 137 |
-
params: dict[str, Any] = {
|
| 138 |
-
"api_key": api_key,
|
| 139 |
-
"q": query or "*",
|
| 140 |
-
"from": page * page_size,
|
| 141 |
-
"size": page_size,
|
| 142 |
-
}
|
| 143 |
-
if workspace:
|
| 144 |
-
params["workspace"] = workspace
|
| 145 |
-
|
| 146 |
-
async with _make_client(api_key) as client:
|
| 147 |
-
try:
|
| 148 |
-
resp = await client.get("/", params=params)
|
| 149 |
-
resp.raise_for_status()
|
| 150 |
-
data = resp.json()
|
| 151 |
-
except httpx.HTTPStatusError as e:
|
| 152 |
-
log.error("roboflow_api_error", status=e.response.status_code, query=query)
|
| 153 |
-
await audit("roboflow_error", {"query": query, "status": e.response.status_code}, level="error")
|
| 154 |
-
raise
|
| 155 |
-
|
| 156 |
-
datasets = []
|
| 157 |
-
for item in data.get("results", []):
|
| 158 |
-
try:
|
| 159 |
-
ds = RoboflowAdapter._normalise_search_result(item)
|
| 160 |
-
datasets.append(ds)
|
| 161 |
-
except Exception as exc:
|
| 162 |
-
log.warning("normalise_error", item_id=item.get("id"), error=str(exc))
|
| 163 |
-
|
| 164 |
-
await _cache_set(ck, [d.model_dump() for d in datasets])
|
| 165 |
-
await audit("roboflow_search", {"query": query, "count": len(datasets)})
|
| 166 |
-
return datasets
|
| 167 |
-
|
| 168 |
-
# ββ Workspace datasets listing ββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
-
|
| 170 |
-
@staticmethod
|
| 171 |
-
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
|
| 172 |
-
async def list_workspace_datasets(
|
| 173 |
-
api_key: str,
|
| 174 |
-
workspace: str,
|
| 175 |
-
) -> list[Dataset]:
|
| 176 |
-
"""List all datasets in a Roboflow workspace."""
|
| 177 |
-
ck = _cache_key(["workspace", workspace])
|
| 178 |
-
cached = await _cache_get(ck)
|
| 179 |
-
if cached:
|
| 180 |
-
return [Dataset(**d) for d in cached]
|
| 181 |
-
|
| 182 |
-
async with _make_client(api_key) as client:
|
| 183 |
-
try:
|
| 184 |
-
resp = await client.get(f"/{workspace}")
|
| 185 |
-
resp.raise_for_status()
|
| 186 |
-
data = resp.json()
|
| 187 |
-
except httpx.HTTPStatusError as e:
|
| 188 |
-
log.error("roboflow_workspace_error", workspace=workspace, status=e.response.status_code)
|
| 189 |
-
raise
|
| 190 |
-
|
| 191 |
-
datasets = []
|
| 192 |
-
for proj in data.get("workspace", {}).get("projects", []):
|
| 193 |
-
try:
|
| 194 |
-
ds = RoboflowAdapter._normalise_project(proj, workspace)
|
| 195 |
-
datasets.append(ds)
|
| 196 |
-
except Exception as exc:
|
| 197 |
-
log.warning("normalise_project_error", project=proj.get("id"), error=str(exc))
|
| 198 |
-
|
| 199 |
-
await _cache_set(ck, [d.model_dump() for d in datasets])
|
| 200 |
-
return datasets
|
| 201 |
-
|
| 202 |
-
# ββ Single project detail βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 203 |
-
|
| 204 |
-
@staticmethod
|
| 205 |
-
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
|
| 206 |
-
async def get_project(
|
| 207 |
-
api_key: str,
|
| 208 |
-
workspace: str,
|
| 209 |
-
project_id: str,
|
| 210 |
-
) -> Dataset | None:
|
| 211 |
-
"""Fetch full metadata for a single Roboflow project."""
|
| 212 |
-
ck = _cache_key(["project", workspace, project_id])
|
| 213 |
-
cached = await _cache_get(ck)
|
| 214 |
-
if cached:
|
| 215 |
-
return Dataset(**cached)
|
| 216 |
-
|
| 217 |
-
async with _make_client(api_key) as client:
|
| 218 |
-
try:
|
| 219 |
-
resp = await client.get(f"/{workspace}/{project_id}")
|
| 220 |
-
resp.raise_for_status()
|
| 221 |
-
data = resp.json()
|
| 222 |
-
except httpx.HTTPStatusError as e:
|
| 223 |
-
if e.response.status_code == 404:
|
| 224 |
-
return None
|
| 225 |
-
raise
|
| 226 |
-
|
| 227 |
-
proj_data = data.get("project", data)
|
| 228 |
-
ds = RoboflowAdapter._normalise_project(proj_data, workspace)
|
| 229 |
-
await _cache_set(ck, ds.model_dump())
|
| 230 |
-
return ds
|
| 231 |
-
|
| 232 |
-
# ββ Download URL builder ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 233 |
-
|
| 234 |
-
@staticmethod
|
| 235 |
-
async def get_download_url(
|
| 236 |
-
api_key: str,
|
| 237 |
-
workspace: str,
|
| 238 |
-
project_id: str,
|
| 239 |
-
version: int,
|
| 240 |
-
export_format: str = "yolov8",
|
| 241 |
-
) -> str:
|
| 242 |
-
"""
|
| 243 |
-
Fetch the export download link from Roboflow for the specified format.
|
| 244 |
-
Uses the official Roboflow SDK to handle authentication and URL resolution.
|
| 245 |
-
"""
|
| 246 |
-
try:
|
| 247 |
-
from roboflow import Roboflow
|
| 248 |
-
rf = Roboflow(api_key=api_key)
|
| 249 |
-
project = rf.workspace(workspace).project(project_id)
|
| 250 |
-
version_obj = project.version(version)
|
| 251 |
-
|
| 252 |
-
# The SDK's download method usually downloads to disk,
|
| 253 |
-
# but we can get the underlying export info.
|
| 254 |
-
# We'll use a thread to run the SDK call since it's blocking.
|
| 255 |
-
import asyncio
|
| 256 |
-
def _get_link():
|
| 257 |
-
return version_obj.export(export_format).download_link
|
| 258 |
-
|
| 259 |
-
link = await asyncio.to_thread(_get_link)
|
| 260 |
-
if not link:
|
| 261 |
-
raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")
|
| 262 |
-
return link
|
| 263 |
-
except Exception as e:
|
| 264 |
-
log.error("roboflow_sdk_error", error=str(e))
|
| 265 |
-
# Fallback to manual API if SDK fails or isn't installed correctly
|
| 266 |
-
async with _make_client(api_key) as client:
|
| 267 |
-
resp = await client.get(
|
| 268 |
-
f"/{workspace}/{project_id}/{version}/{export_format}"
|
| 269 |
-
)
|
| 270 |
-
resp.raise_for_status()
|
| 271 |
-
data = resp.json()
|
| 272 |
-
|
| 273 |
-
link = export.get("link") or ""
|
| 274 |
-
if not link:
|
| 275 |
-
# If 'link' is missing, check if it's a Universe-style project and try to resolve manually
|
| 276 |
-
# Roboflow manual resolution often follows: universe.roboflow.com/ds/[id]?key=[api_key]
|
| 277 |
-
if "project" in data:
|
| 278 |
-
pid = data["project"].get("id")
|
| 279 |
-
if pid:
|
| 280 |
-
link = f"https://universe.roboflow.com/ds/{pid}?key={api_key}"
|
| 281 |
-
|
| 282 |
-
if not link:
|
| 283 |
-
raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")
|
| 284 |
-
|
| 285 |
-
# Ensure the link includes the API key correctly
|
| 286 |
-
if "universe.roboflow.com" in link:
|
| 287 |
-
if "key=" not in link:
|
| 288 |
-
separator = "&" if "?" in link else "?"
|
| 289 |
-
link = f"{link}{separator}key={api_key}"
|
| 290 |
-
elif f"key={api_key}" not in link:
|
| 291 |
-
# Replace old key if it exists but is wrong
|
| 292 |
-
import re
|
| 293 |
-
link = re.sub(r"key=[^&]+", f"key={api_key}", link)
|
| 294 |
-
|
| 295 |
-
return link
|
| 296 |
-
|
| 297 |
-
# ββ Normalisation helpers βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 298 |
-
|
| 299 |
-
@staticmethod
|
| 300 |
-
def _normalise_search_result(item: dict[str, Any]) -> Dataset:
|
| 301 |
-
"""Map a Universe search result β Dataset."""
|
| 302 |
-
ann_type = item.get("annotation", {}).get("type", "object-detection")
|
| 303 |
-
rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
|
| 304 |
-
class_names = [c.get("name", "") for c in item.get("classes", [])]
|
| 305 |
-
images = item.get("images", 0) or 0
|
| 306 |
-
|
| 307 |
-
return Dataset(
|
| 308 |
-
id = item.get("id", "").replace("/", "__"),
|
| 309 |
-
name = item.get("name", "Unnamed"),
|
| 310 |
-
description = item.get("description", ""),
|
| 311 |
-
task = rf_task,
|
| 312 |
-
format = DatasetFormat.yolo,
|
| 313 |
-
source = DatasetSource.roboflow,
|
| 314 |
-
status = DatasetStatus.available,
|
| 315 |
-
images = images,
|
| 316 |
-
classes = len(class_names),
|
| 317 |
-
class_names = class_names,
|
| 318 |
-
size_bytes = 0,
|
| 319 |
-
size_label = "β",
|
| 320 |
-
tags = item.get("tags", []),
|
| 321 |
-
roboflow_id = item.get("id", ""),
|
| 322 |
-
created_at = item.get("created", ""),
|
| 323 |
-
updated_at = item.get("updated", ""),
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
@staticmethod
|
| 327 |
-
def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset:
|
| 328 |
-
"""Map a workspace project β Dataset."""
|
| 329 |
-
ann_type = proj.get("annotation", "object-detection")
|
| 330 |
-
rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
|
| 331 |
-
class_names = [c.get("name", c) if isinstance(c, dict) else c
|
| 332 |
-
for c in proj.get("classes", [])]
|
| 333 |
-
project_id = proj.get("id", proj.get("name", "unknown"))
|
| 334 |
-
rf_id = f"{workspace}/{project_id}"
|
| 335 |
-
images = proj.get("images", 0) or 0
|
| 336 |
-
|
| 337 |
-
return Dataset(
|
| 338 |
-
id = rf_id.replace("/", "__"),
|
| 339 |
-
name = proj.get("name", project_id),
|
| 340 |
-
description = proj.get("description", ""),
|
| 341 |
-
task = rf_task,
|
| 342 |
-
format = DatasetFormat.yolo,
|
| 343 |
-
source = DatasetSource.roboflow,
|
| 344 |
-
status = DatasetStatus.available,
|
| 345 |
-
images = images,
|
| 346 |
-
classes = len(class_names),
|
| 347 |
-
class_names = class_names,
|
| 348 |
-
size_bytes = 0,
|
| 349 |
-
size_label = "β",
|
| 350 |
-
roboflow_id = rf_id,
|
| 351 |
-
created_at = proj.get("created", ""),
|
| 352 |
-
updated_at = proj.get("updated", ""),
|
| 353 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/benchmark.py
DELETED
|
@@ -1,238 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
api/routes/benchmark.py β Benchmark Bridge REST + WebSocket API.
|
| 3 |
-
|
| 4 |
-
Routes:
|
| 5 |
-
POST /benchmark/validate β compatibility check (no job created)
|
| 6 |
-
POST /benchmark/run β validate + create + enqueue (202)
|
| 7 |
-
GET /benchmark/jobs β list jobs (filterable)
|
| 8 |
-
GET /benchmark/results/all β list all results
|
| 9 |
-
GET /benchmark/{job_id} β single job status + logs
|
| 10 |
-
GET /benchmark/{job_id}/result β metrics + telemetry for completed job
|
| 11 |
-
WS /benchmark/live/{job_id} β real-time progress stream
|
| 12 |
-
"""
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
-
import asyncio
|
| 16 |
-
from typing import Any
|
| 17 |
-
|
| 18 |
-
from fastapi import APIRouter, HTTPException, Query, WebSocket, WebSocketDisconnect
|
| 19 |
-
|
| 20 |
-
import benchmark.orchestrator as orchestrator
|
| 21 |
-
import benchmark.registry as bench_reg
|
| 22 |
-
from models.benchmark import (
|
| 23 |
-
BenchmarkContext,
|
| 24 |
-
BenchmarkJob,
|
| 25 |
-
BenchmarkResult,
|
| 26 |
-
BenchmarkRunResponse,
|
| 27 |
-
ValidationReport,
|
| 28 |
-
)
|
| 29 |
-
from observability.logger import get_logger
|
| 30 |
-
|
| 31 |
-
log = get_logger("api.benchmark")
|
| 32 |
-
|
| 33 |
-
router = APIRouter(prefix="/benchmark", tags=["benchmark"])
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# ββ POST /benchmark/validate ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
-
|
| 38 |
-
@router.post(
|
| 39 |
-
"/validate",
|
| 40 |
-
response_model = ValidationReport,
|
| 41 |
-
summary = "Validate model β dataset compatibility",
|
| 42 |
-
description = (
|
| 43 |
-
"Runs all 5 compatibility gates (task, format, frameworkΓhardware, "
|
| 44 |
-
"VRAM, precision) and returns a structured report. "
|
| 45 |
-
"Does NOT create a benchmark job."
|
| 46 |
-
),
|
| 47 |
-
)
|
| 48 |
-
async def validate_benchmark(ctx: BenchmarkContext) -> ValidationReport:
|
| 49 |
-
try:
|
| 50 |
-
return await orchestrator.validate_context(ctx)
|
| 51 |
-
except HTTPException:
|
| 52 |
-
raise
|
| 53 |
-
except Exception as exc:
|
| 54 |
-
log.exception("validate_error")
|
| 55 |
-
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# ββ POST /benchmark/run βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 59 |
-
|
| 60 |
-
@router.post(
|
| 61 |
-
"/run",
|
| 62 |
-
response_model = BenchmarkRunResponse,
|
| 63 |
-
status_code = 202,
|
| 64 |
-
summary = "Start a benchmark run",
|
| 65 |
-
description = (
|
| 66 |
-
"Validates compatibility, creates a benchmark job, and starts async "
|
| 67 |
-
"execution. Returns job_id immediately β poll GET /benchmark/{job_id} "
|
| 68 |
-
"or connect to WS /benchmark/live/{job_id} for progress."
|
| 69 |
-
),
|
| 70 |
-
)
|
| 71 |
-
async def run_benchmark(ctx: BenchmarkContext) -> BenchmarkRunResponse:
|
| 72 |
-
try:
|
| 73 |
-
job = await orchestrator.create_and_run(ctx)
|
| 74 |
-
return BenchmarkRunResponse(
|
| 75 |
-
job_id = job.id,
|
| 76 |
-
status = job.status,
|
| 77 |
-
message = f"Benchmark job {job.id} queued β connect to /benchmark/live/{job.id} for live telemetry",
|
| 78 |
-
)
|
| 79 |
-
except HTTPException:
|
| 80 |
-
raise
|
| 81 |
-
except Exception as exc:
|
| 82 |
-
log.exception("run_benchmark_error")
|
| 83 |
-
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
# ββ POST /benchmark/sync ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 87 |
-
|
| 88 |
-
@router.post(
|
| 89 |
-
"/sync",
|
| 90 |
-
summary = "Sync benchmarks from active project folder",
|
| 91 |
-
description = "Scans the active project's 'benchmarks' folder and ensures all JSON records are indexed in SQLite.",
|
| 92 |
-
)
|
| 93 |
-
async def sync_benchmarks() -> dict[str, Any]:
|
| 94 |
-
try:
|
| 95 |
-
count = await orchestrator.sync_project_benchmarks()
|
| 96 |
-
return {"status": "success", "count": count}
|
| 97 |
-
except Exception as exc:
|
| 98 |
-
log.exception("sync_error")
|
| 99 |
-
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
# ββ GET /benchmark/jobs βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 103 |
-
|
| 104 |
-
@router.get(
|
| 105 |
-
"/jobs",
|
| 106 |
-
response_model = list[BenchmarkJob],
|
| 107 |
-
summary = "List benchmark jobs",
|
| 108 |
-
)
|
| 109 |
-
async def list_jobs(
|
| 110 |
-
status: str | None = Query(None, description="Filter by status (queued|running|completed|failed)"),
|
| 111 |
-
model_id: str | None = Query(None, description="Filter by model_id"),
|
| 112 |
-
limit: int = Query(100, ge=1, le=500),
|
| 113 |
-
) -> list[BenchmarkJob]:
|
| 114 |
-
return await bench_reg.list_jobs(status=status, model_id=model_id, limit=limit)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# ββ GET /benchmark/results/all ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 118 |
-
# Must be declared BEFORE /{job_id} to avoid "results" being treated as a job_id
|
| 119 |
-
|
| 120 |
-
@router.get(
|
| 121 |
-
"/results/all",
|
| 122 |
-
response_model = list[BenchmarkResult],
|
| 123 |
-
summary = "List all benchmark results (leaderboard feed)",
|
| 124 |
-
)
|
| 125 |
-
async def list_results(
|
| 126 |
-
limit: int = Query(100, ge=1, le=500),
|
| 127 |
-
) -> list[BenchmarkResult]:
|
| 128 |
-
return await bench_reg.list_results(limit=limit)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
# ββ GET /benchmark/{job_id} βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 132 |
-
|
| 133 |
-
@router.get(
|
| 134 |
-
"/{job_id}",
|
| 135 |
-
response_model = BenchmarkJob,
|
| 136 |
-
summary = "Get benchmark job status + logs",
|
| 137 |
-
)
|
| 138 |
-
async def get_job(job_id: str) -> BenchmarkJob:
|
| 139 |
-
job = await bench_reg.get_job(job_id)
|
| 140 |
-
if not job:
|
| 141 |
-
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
| 142 |
-
return job
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
# ββ GET /benchmark/{job_id}/result βββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
-
|
| 147 |
-
@router.get(
|
| 148 |
-
"/{job_id}/result",
|
| 149 |
-
response_model = BenchmarkResult,
|
| 150 |
-
summary = "Get final metrics + telemetry for a completed job",
|
| 151 |
-
)
|
| 152 |
-
async def get_result(job_id: str) -> BenchmarkResult:
|
| 153 |
-
result = await bench_reg.get_result(job_id)
|
| 154 |
-
if not result:
|
| 155 |
-
raise HTTPException(
|
| 156 |
-
status_code = 404,
|
| 157 |
-
detail = f"No result for job '{job_id}' β job may still be running",
|
| 158 |
-
)
|
| 159 |
-
return result
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
# ββ WS /benchmark/live/{job_id} βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 163 |
-
|
| 164 |
-
@router.websocket("/live/{job_id}")
|
| 165 |
-
async def live_telemetry(websocket: WebSocket, job_id: str) -> None:
|
| 166 |
-
"""
|
| 167 |
-
WebSocket stream for real-time benchmark progress.
|
| 168 |
-
Streams incremental logs and high-frequency telemetry.
|
| 169 |
-
"""
|
| 170 |
-
await websocket.accept()
|
| 171 |
-
log.info("ws_connected", job_id=job_id)
|
| 172 |
-
|
| 173 |
-
last_log_idx = 0
|
| 174 |
-
|
| 175 |
-
try:
|
| 176 |
-
while True:
|
| 177 |
-
job = await bench_reg.get_job(job_id)
|
| 178 |
-
|
| 179 |
-
if not job:
|
| 180 |
-
await websocket.send_json(
|
| 181 |
-
{"error": f"Job '{job_id}' not found", "job_id": job_id}
|
| 182 |
-
)
|
| 183 |
-
break
|
| 184 |
-
|
| 185 |
-
# Only send new logs
|
| 186 |
-
new_logs = job.logs[last_log_idx:]
|
| 187 |
-
last_log_idx = len(job.logs)
|
| 188 |
-
|
| 189 |
-
payload: dict[str, Any] = {
|
| 190 |
-
"job_id": job.id,
|
| 191 |
-
"status": job.status,
|
| 192 |
-
"progress": round(job.progress, 4),
|
| 193 |
-
"logs": new_logs,
|
| 194 |
-
"telemetry": job.last_telemetry.model_dump() if job.last_telemetry else None,
|
| 195 |
-
}
|
| 196 |
-
# Explicitly include detections for the UI visualizer if they exist
|
| 197 |
-
if job.last_telemetry and hasattr(job.last_telemetry, "detections"):
|
| 198 |
-
payload["detections"] = job.last_telemetry.detections
|
| 199 |
-
|
| 200 |
-
await websocket.send_json(payload)
|
| 201 |
-
|
| 202 |
-
if job.status == "completed":
|
| 203 |
-
result = await bench_reg.get_result(job_id)
|
| 204 |
-
if result:
|
| 205 |
-
await websocket.send_json(
|
| 206 |
-
{
|
| 207 |
-
"job_id": job_id,
|
| 208 |
-
"status": "completed",
|
| 209 |
-
"result": result.model_dump(),
|
| 210 |
-
}
|
| 211 |
-
)
|
| 212 |
-
break
|
| 213 |
-
|
| 214 |
-
if job.status == "failed":
|
| 215 |
-
await websocket.send_json(
|
| 216 |
-
{
|
| 217 |
-
"job_id": job_id,
|
| 218 |
-
"status": "failed",
|
| 219 |
-
"error": job.error or "Unknown error",
|
| 220 |
-
}
|
| 221 |
-
)
|
| 222 |
-
break
|
| 223 |
-
|
| 224 |
-
await asyncio.sleep(0.5) # poll at 2 Hz
|
| 225 |
-
|
| 226 |
-
except WebSocketDisconnect:
|
| 227 |
-
log.info("ws_disconnected", job_id=job_id)
|
| 228 |
-
except Exception as exc:
|
| 229 |
-
log.exception("ws_error", job_id=job_id)
|
| 230 |
-
try:
|
| 231 |
-
await websocket.send_json({"error": str(exc), "job_id": job_id})
|
| 232 |
-
except Exception:
|
| 233 |
-
pass
|
| 234 |
-
finally:
|
| 235 |
-
try:
|
| 236 |
-
await websocket.close()
|
| 237 |
-
except Exception:
|
| 238 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/inference.py
DELETED
|
@@ -1,168 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
api/routes/inference.py β Inference Engine endpoints.
|
| 3 |
-
|
| 4 |
-
POST /inference/run β single synchronous inference pass
|
| 5 |
-
POST /inference/stream β SSE stream (stage-by-stage pipeline events)
|
| 6 |
-
GET /inference/history β session ledger
|
| 7 |
-
DELETE /inference/history β clear session ledger
|
| 8 |
-
GET /inference/cache β currently warm models in memory
|
| 9 |
-
DELETE /inference/cache/{model_id} β evict from cache
|
| 10 |
-
"""
|
| 11 |
-
from __future__ import annotations
|
| 12 |
-
|
| 13 |
-
import asyncio
|
| 14 |
-
import json
|
| 15 |
-
import time
|
| 16 |
-
|
| 17 |
-
from fastapi import APIRouter, HTTPException, Response
|
| 18 |
-
from fastapi.responses import StreamingResponse
|
| 19 |
-
|
| 20 |
-
from inference.engine import InferenceEngine, evict_model, get_cache_status
|
| 21 |
-
from inference.session import clear_history, get_history, record
|
| 22 |
-
from models.inference import (
|
| 23 |
-
InferenceHistoryEntry,
|
| 24 |
-
InferenceRequest,
|
| 25 |
-
InferenceResult,
|
| 26 |
-
)
|
| 27 |
-
from observability.logger import get_logger
|
| 28 |
-
from registry.registry import get_model
|
| 29 |
-
|
| 30 |
-
log = get_logger("api.inference")
|
| 31 |
-
router = APIRouter(prefix="/inference", tags=["inference"])
|
| 32 |
-
|
| 33 |
-
_engine = InferenceEngine()
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# ββ Single run βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
-
|
| 38 |
-
@router.post("/run", response_model=InferenceResult)
|
| 39 |
-
async def run_inference(body: InferenceRequest) -> InferenceResult:
|
| 40 |
-
"""Execute one full inference pass and return the complete result."""
|
| 41 |
-
model = await get_model(body.model_id)
|
| 42 |
-
if not model:
|
| 43 |
-
raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")
|
| 44 |
-
|
| 45 |
-
result = await _engine.run(body, model)
|
| 46 |
-
|
| 47 |
-
if result.status == "error":
|
| 48 |
-
raise HTTPException(status_code=500, detail=result.error or "Inference failed")
|
| 49 |
-
|
| 50 |
-
await record(body, result, model.name)
|
| 51 |
-
return result
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
# ββ SSE stream βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
-
|
| 56 |
-
@router.post("/stream")
|
| 57 |
-
async def stream_inference(body: InferenceRequest) -> StreamingResponse:
|
| 58 |
-
"""
|
| 59 |
-
Server-Sent Events stream.
|
| 60 |
-
Emits one JSON event per pipeline stage as it completes, then a final
|
| 61 |
-
'done' event with the full InferenceResult.
|
| 62 |
-
|
| 63 |
-
Client usage:
|
| 64 |
-
const es = new EventSource('/inference/stream'); // POST via fetch + EventSource polyfill
|
| 65 |
-
es.onmessage = e => console.log(JSON.parse(e.data));
|
| 66 |
-
"""
|
| 67 |
-
model = await get_model(body.model_id)
|
| 68 |
-
if not model:
|
| 69 |
-
raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")
|
| 70 |
-
|
| 71 |
-
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
| 72 |
-
|
| 73 |
-
async def _producer() -> None:
|
| 74 |
-
"""Run inference while pushing SSE events into the queue."""
|
| 75 |
-
try:
|
| 76 |
-
# Patch engine to emit stage events
|
| 77 |
-
result = await _engine_stream(body, model, queue)
|
| 78 |
-
await record(body, result, model.name)
|
| 79 |
-
# Final complete event
|
| 80 |
-
await queue.put(
|
| 81 |
-
f"event: done\ndata: {result.model_dump_json()}\n\n"
|
| 82 |
-
)
|
| 83 |
-
except Exception as exc:
|
| 84 |
-
await queue.put(
|
| 85 |
-
f"event: error\ndata: {json.dumps({'error': str(exc)})}\n\n"
|
| 86 |
-
)
|
| 87 |
-
finally:
|
| 88 |
-
await queue.put(None) # sentinel
|
| 89 |
-
|
| 90 |
-
asyncio.create_task(_producer())
|
| 91 |
-
|
| 92 |
-
async def _generator():
|
| 93 |
-
while True:
|
| 94 |
-
msg = await queue.get()
|
| 95 |
-
if msg is None:
|
| 96 |
-
break
|
| 97 |
-
yield msg
|
| 98 |
-
|
| 99 |
-
return StreamingResponse(
|
| 100 |
-
_generator(),
|
| 101 |
-
media_type="text/event-stream",
|
| 102 |
-
headers={
|
| 103 |
-
"Cache-Control": "no-cache",
|
| 104 |
-
"X-Accel-Buffering": "no",
|
| 105 |
-
},
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
async def _engine_stream(
|
| 110 |
-
req: InferenceRequest,
|
| 111 |
-
model,
|
| 112 |
-
queue: asyncio.Queue,
|
| 113 |
-
) -> InferenceResult:
|
| 114 |
-
"""
|
| 115 |
-
Run inference and push a 'stage' SSE event for each PipelineStage.
|
| 116 |
-
Falls back to a simple full run if streaming is not distinguishable.
|
| 117 |
-
"""
|
| 118 |
-
# Run full pipeline
|
| 119 |
-
result = await _engine.run(req, model)
|
| 120 |
-
|
| 121 |
-
# Emit one event per stage (replay after completion)
|
| 122 |
-
for stage in result.pipeline:
|
| 123 |
-
payload = json.dumps({
|
| 124 |
-
"type": "stage",
|
| 125 |
-
"stage": stage.model_dump(),
|
| 126 |
-
"ts": time.time(),
|
| 127 |
-
})
|
| 128 |
-
await queue.put(f"data: {payload}\n\n")
|
| 129 |
-
await asyncio.sleep(0) # yield
|
| 130 |
-
|
| 131 |
-
# Emit vitals snapshot
|
| 132 |
-
vitals_payload = json.dumps({
|
| 133 |
-
"type": "vitals",
|
| 134 |
-
"latency_ms": result.inference_ms,
|
| 135 |
-
"total_ms": result.total_ms,
|
| 136 |
-
"quality": result.quality_score,
|
| 137 |
-
})
|
| 138 |
-
await queue.put(f"data: {vitals_payload}\n\n")
|
| 139 |
-
|
| 140 |
-
return result
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
# ββ History ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 144 |
-
|
| 145 |
-
@router.get("/history", response_model=list[InferenceHistoryEntry])
|
| 146 |
-
async def inference_history(limit: int = 50) -> list[InferenceHistoryEntry]:
|
| 147 |
-
return await get_history(limit=min(limit, 200))
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
@router.delete("/history", status_code=204, response_model=None)
|
| 151 |
-
async def clear_inference_history():
|
| 152 |
-
await clear_history()
|
| 153 |
-
return Response(status_code=204)
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
# ββ Model cache ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 157 |
-
|
| 158 |
-
@router.get("/cache")
|
| 159 |
-
async def cache_status() -> dict[str, bool]:
|
| 160 |
-
return get_cache_status()
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
@router.delete("/cache/{model_id}", status_code=204, response_model=None)
|
| 164 |
-
async def evict_from_cache(model_id: str):
|
| 165 |
-
evicted = evict_model(model_id)
|
| 166 |
-
if not evicted:
|
| 167 |
-
raise HTTPException(status_code=404, detail="Model not in cache")
|
| 168 |
-
return Response(status_code=204)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/jobs.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
api/routes/jobs.py β /jobs & /download endpoints.
|
| 3 |
-
"""
|
| 4 |
-
from __future__ import annotations
|
| 5 |
-
|
| 6 |
-
from fastapi import APIRouter, HTTPException
|
| 7 |
-
|
| 8 |
-
from download.manager import cancel_job, enqueue_download, get_job, list_jobs
|
| 9 |
-
from models.job import Job, JobCreate
|
| 10 |
-
from observability.logger import audit, get_logger
|
| 11 |
-
from registry.registry import get_model
|
| 12 |
-
|
| 13 |
-
log = get_logger("api.jobs")
|
| 14 |
-
router = APIRouter(tags=["jobs"])
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@router.post("/download", response_model=Job, status_code=202)
|
| 18 |
-
async def trigger_download(body: JobCreate) -> Job:
|
| 19 |
-
"""Enqueue a model download. Returns the created job immediately."""
|
| 20 |
-
model = await get_model(body.model_id)
|
| 21 |
-
if not model:
|
| 22 |
-
raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")
|
| 23 |
-
if model.downloaded:
|
| 24 |
-
raise HTTPException(status_code=409, detail="Model is already cached locally")
|
| 25 |
-
|
| 26 |
-
job_id = await enqueue_download(
|
| 27 |
-
model_id=body.model_id,
|
| 28 |
-
model_name=body.model_name,
|
| 29 |
-
version=body.version,
|
| 30 |
-
)
|
| 31 |
-
job = await get_job(job_id)
|
| 32 |
-
if not job:
|
| 33 |
-
raise HTTPException(status_code=500, detail="Job creation failed")
|
| 34 |
-
|
| 35 |
-
await audit("api_download_trigger", model_id=body.model_id, job_id=job_id)
|
| 36 |
-
return job
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
@router.get("/jobs", response_model=list[Job])
|
| 40 |
-
async def jobs_list(status: str | None = None, limit: int = 50) -> list[Job]:
|
| 41 |
-
return await list_jobs(status=status, limit=limit)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
@router.get("/jobs/{job_id}", response_model=Job)
|
| 45 |
-
async def job_detail(job_id: str) -> Job:
|
| 46 |
-
job = await get_job(job_id)
|
| 47 |
-
if not job:
|
| 48 |
-
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
| 49 |
-
return job
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
@router.delete("/jobs/{job_id}", status_code=204, response_model=None)
|
| 53 |
-
async def job_cancel(job_id: str) -> None:
|
| 54 |
-
success = await cancel_job(job_id)
|
| 55 |
-
if not success:
|
| 56 |
-
raise HTTPException(status_code=409, detail="Job cannot be cancelled")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/system.py
DELETED
|
@@ -1,97 +0,0 @@
|
|
| 1 |
-
"""api/routes/system.py β System metrics endpoints."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import asyncio
|
| 6 |
-
import json
|
| 7 |
-
|
| 8 |
-
from fastapi import APIRouter, Query
|
| 9 |
-
from fastapi.responses import StreamingResponse
|
| 10 |
-
|
| 11 |
-
from models.system import SystemMetrics
|
| 12 |
-
from system.metrics import sample_metrics
|
| 13 |
-
|
| 14 |
-
router = APIRouter(prefix="/system", tags=["system"])
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@router.get("/metrics", response_model=SystemMetrics)
|
| 18 |
-
async def get_metrics(gpu_index: int = Query(0, ge=0)) -> SystemMetrics:
|
| 19 |
-
payload = sample_metrics(gpu_index=gpu_index)
|
| 20 |
-
return SystemMetrics(
|
| 21 |
-
ts=payload["ts"],
|
| 22 |
-
cpu_pct=payload["cpu_pct"],
|
| 23 |
-
cpu_model=payload.get("cpu_model"),
|
| 24 |
-
cpu_freq_mhz=payload.get("cpu_freq_mhz"),
|
| 25 |
-
cpu_count=payload.get("cpu_count"),
|
| 26 |
-
ram_used_mb=payload["ram_used_mb"],
|
| 27 |
-
ram_total_mb=payload["ram_total_mb"],
|
| 28 |
-
gpu=payload.get("gpu"),
|
| 29 |
-
disks=payload.get("disks", []),
|
| 30 |
-
network=payload.get("network", []),
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
@router.get("/metrics/stream")
|
| 35 |
-
async def stream_metrics(
|
| 36 |
-
gpu_index: int = Query(0, ge=0),
|
| 37 |
-
hz: float = Query(2.0, ge=0.2, le=20.0),
|
| 38 |
-
):
|
| 39 |
-
"""Server-Sent Events stream of system metrics."""
|
| 40 |
-
|
| 41 |
-
interval = 1.0 / float(hz)
|
| 42 |
-
|
| 43 |
-
async def gen():
|
| 44 |
-
# Initial comment helps some proxies establish the stream
|
| 45 |
-
yield ": connected\n\n"
|
| 46 |
-
while True:
|
| 47 |
-
try:
|
| 48 |
-
payload = sample_metrics(gpu_index=gpu_index)
|
| 49 |
-
# Ensure the payload is valid JSON and wrapped in data: format
|
| 50 |
-
data = json.dumps(payload)
|
| 51 |
-
yield f"data: {data}\n\n"
|
| 52 |
-
except Exception as e:
|
| 53 |
-
# Log error but keep stream alive
|
| 54 |
-
print(f"Metrics streaming error: {e}")
|
| 55 |
-
await asyncio.sleep(interval)
|
| 56 |
-
|
| 57 |
-
return StreamingResponse(
|
| 58 |
-
gen(),
|
| 59 |
-
media_type="text/event-stream",
|
| 60 |
-
headers={
|
| 61 |
-
"Cache-Control": "no-cache",
|
| 62 |
-
"X-Accel-Buffering": "no",
|
| 63 |
-
"Connection": "keep-alive",
|
| 64 |
-
"Transfer-Encoding": "chunked",
|
| 65 |
-
},
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
@router.get("/logs/stream")
|
| 70 |
-
async def stream_system_logs():
|
| 71 |
-
"""SSE stream of global system and gateway logs."""
|
| 72 |
-
from observability.logger import _sys_log_subs
|
| 73 |
-
|
| 74 |
-
q: asyncio.Queue = asyncio.Queue()
|
| 75 |
-
_sys_log_subs.append(q)
|
| 76 |
-
|
| 77 |
-
async def generator():
|
| 78 |
-
yield ": connected\n\n"
|
| 79 |
-
try:
|
| 80 |
-
while True:
|
| 81 |
-
try:
|
| 82 |
-
entry = await asyncio.wait_for(q.get(), timeout=30.0)
|
| 83 |
-
except asyncio.TimeoutError:
|
| 84 |
-
yield ": heartbeat\n\n"
|
| 85 |
-
continue
|
| 86 |
-
if entry is None:
|
| 87 |
-
break
|
| 88 |
-
yield f"data: {json.dumps(entry)}\n\n"
|
| 89 |
-
finally:
|
| 90 |
-
if q in _sys_log_subs:
|
| 91 |
-
_sys_log_subs.remove(q)
|
| 92 |
-
|
| 93 |
-
return StreamingResponse(
|
| 94 |
-
generator(),
|
| 95 |
-
media_type="text/event-stream",
|
| 96 |
-
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 97 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/training.py
DELETED
|
@@ -1,428 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
api/routes/training.py β Training Engine REST + SSE endpoints.
|
| 3 |
-
|
| 4 |
-
POST /train/start β create and launch a training run
|
| 5 |
-
POST /train/stop β cancel a running run
|
| 6 |
-
POST /train/pause β pause a running run
|
| 7 |
-
POST /train/resume β resume a paused run
|
| 8 |
-
GET /train/status β run status + progress snapshot
|
| 9 |
-
GET /train/runs β list all runs
|
| 10 |
-
GET /train/runs/{run_id} β single run detail
|
| 11 |
-
GET /train/schema β UI schema for task/model/dataset combo
|
| 12 |
-
GET /train/checkpoints β checkpoints for a run (stub)
|
| 13 |
-
POST /train/checkpoints/{id}/export β export a checkpoint (stub)
|
| 14 |
-
GET /train/metrics/stream β SSE: real-time metrics ticks
|
| 15 |
-
GET /train/logs/stream β SSE: real-time log entries
|
| 16 |
-
GET /train/resources/stream β SSE: real-time resource ticks
|
| 17 |
-
"""
|
| 18 |
-
from __future__ import annotations
|
| 19 |
-
|
| 20 |
-
import asyncio
|
| 21 |
-
import json
|
| 22 |
-
import time
|
| 23 |
-
import os
|
| 24 |
-
|
| 25 |
-
from fastapi import APIRouter, HTTPException, Query
|
| 26 |
-
from fastapi.responses import StreamingResponse
|
| 27 |
-
|
| 28 |
-
from observability.logger import get_logger
|
| 29 |
-
from training import run_manager
|
| 30 |
-
from training.schema_engine import generate_schema
|
| 31 |
-
from training.schemas import (
|
| 32 |
-
CheckpointOut,
|
| 33 |
-
PauseTrainRequest,
|
| 34 |
-
ResumeTrainRequest,
|
| 35 |
-
StartTrainRequest,
|
| 36 |
-
StartTrainResponse,
|
| 37 |
-
StopTrainRequest,
|
| 38 |
-
TrainRunOut,
|
| 39 |
-
TrainStatusResponse,
|
| 40 |
-
TrainingSchemaResponse,
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
log = get_logger("api.training")
|
| 44 |
-
router = APIRouter(prefix="/train", tags=["training"])
|
| 45 |
-
|
| 46 |
-
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
-
|
| 48 |
-
def _format_duration(seconds: float) -> str:
|
| 49 |
-
h = int(seconds // 3600)
|
| 50 |
-
m = int((seconds % 3600) // 60)
|
| 51 |
-
s = int(seconds % 60)
|
| 52 |
-
return f"{h}h {m}m {s}s"
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _run_to_out(run: run_manager.TrainRun) -> TrainRunOut:
|
| 56 |
-
elapsed = (run.completed_at or time.time()) - run.created_at
|
| 57 |
-
return TrainRunOut(
|
| 58 |
-
id=run.run_id,
|
| 59 |
-
run_number=run.run_number,
|
| 60 |
-
model_id=run.model_id,
|
| 61 |
-
model_name=run.model_name,
|
| 62 |
-
dataset_id=run.dataset_id,
|
| 63 |
-
dataset_name=run.dataset_name,
|
| 64 |
-
task=run.task,
|
| 65 |
-
status=run.status,
|
| 66 |
-
epochs_done=run.epoch,
|
| 67 |
-
total_epochs=run.total_epochs,
|
| 68 |
-
best_metric=run.best_metric,
|
| 69 |
-
final_loss=run.final_loss,
|
| 70 |
-
duration=_format_duration(elapsed),
|
| 71 |
-
created_at=run.created_at,
|
| 72 |
-
completed_at=run.completed_at,
|
| 73 |
-
hyperparams=run.hyperparams,
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
# ββ Control endpoints βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
-
|
| 79 |
-
@router.post("/start", response_model=StartTrainResponse)
|
| 80 |
-
async def start_training(body: StartTrainRequest) -> StartTrainResponse:
|
| 81 |
-
"""Create and immediately launch a training run."""
|
| 82 |
-
# Resolve friendly names (fall back to ids if registries unavailable)
|
| 83 |
-
model_name = body.model_id
|
| 84 |
-
dataset_name = body.dataset_id
|
| 85 |
-
try:
|
| 86 |
-
from registry.registry import get_model
|
| 87 |
-
m = await get_model(body.model_id)
|
| 88 |
-
if m:
|
| 89 |
-
model_name = m.name
|
| 90 |
-
except Exception:
|
| 91 |
-
pass
|
| 92 |
-
try:
|
| 93 |
-
from datasets.registry import get_dataset
|
| 94 |
-
d = await get_dataset(body.dataset_id)
|
| 95 |
-
if d:
|
| 96 |
-
dataset_name = d.get("name", body.dataset_id) if isinstance(d, dict) else getattr(d, "name", body.dataset_id)
|
| 97 |
-
except Exception:
|
| 98 |
-
pass
|
| 99 |
-
|
| 100 |
-
run = run_manager.create_run(
|
| 101 |
-
model_id=body.model_id,
|
| 102 |
-
model_name=model_name,
|
| 103 |
-
dataset_id=body.dataset_id,
|
| 104 |
-
dataset_name=dataset_name,
|
| 105 |
-
task=body.task,
|
| 106 |
-
hyperparams=body.hyperparams,
|
| 107 |
-
augmentation=body.augmentation,
|
| 108 |
-
scheduler=body.scheduler,
|
| 109 |
-
project_id=body.project_id
|
| 110 |
-
)
|
| 111 |
-
run_manager.start_run(run)
|
| 112 |
-
|
| 113 |
-
log.info("training_started", run_id=run.run_id, model=body.model_id)
|
| 114 |
-
return StartTrainResponse(
|
| 115 |
-
run_id=run.run_id,
|
| 116 |
-
status=run.status,
|
| 117 |
-
message=f"Training run {run.run_id} started.",
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
@router.post("/stop", status_code=200)
|
| 122 |
-
async def stop_training(body: StopTrainRequest) -> dict:
|
| 123 |
-
run = run_manager.get_run(body.run_id)
|
| 124 |
-
if not run:
|
| 125 |
-
raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")
|
| 126 |
-
run_manager.stop_run(run)
|
| 127 |
-
log.info("training_stopped", run_id=body.run_id)
|
| 128 |
-
return {"run_id": body.run_id, "status": run.status}
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
@router.post("/pause", status_code=200)
|
| 132 |
-
async def pause_training(body: PauseTrainRequest) -> dict:
|
| 133 |
-
run = run_manager.get_run(body.run_id)
|
| 134 |
-
if not run:
|
| 135 |
-
raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")
|
| 136 |
-
run_manager.pause_run(run)
|
| 137 |
-
return {"run_id": body.run_id, "status": run.status}
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
@router.post("/resume", status_code=200)
|
| 141 |
-
async def resume_training(body: ResumeTrainRequest) -> dict:
|
| 142 |
-
run = run_manager.get_run(body.run_id)
|
| 143 |
-
if not run:
|
| 144 |
-
raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")
|
| 145 |
-
run_manager.resume_run(run)
|
| 146 |
-
return {"run_id": body.run_id, "status": run.status}
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
@router.get("/status", response_model=TrainStatusResponse)
|
| 150 |
-
async def get_train_status(run_id: str = Query(...)) -> TrainStatusResponse:
|
| 151 |
-
run = run_manager.get_run(run_id)
|
| 152 |
-
if not run:
|
| 153 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 154 |
-
return TrainStatusResponse(
|
| 155 |
-
run_id=run.run_id,
|
| 156 |
-
status=run.status,
|
| 157 |
-
epoch=run.epoch,
|
| 158 |
-
total_epochs=run.total_epochs,
|
| 159 |
-
step=run.step,
|
| 160 |
-
total_steps=run.total_epochs * 100,
|
| 161 |
-
eta_seconds=run.eta_seconds,
|
| 162 |
-
elapsed_seconds=run.elapsed_seconds,
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# ββ Run history βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 167 |
-
|
| 168 |
-
@router.get("/runs", response_model=list[TrainRunOut])
|
| 169 |
-
async def list_runs() -> list[TrainRunOut]:
|
| 170 |
-
return [_run_to_out(r) for r in reversed(run_manager.list_runs())]
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
@router.get("/runs/{run_id}", response_model=TrainRunOut)
|
| 174 |
-
async def get_run(run_id: str) -> TrainRunOut:
|
| 175 |
-
run = run_manager.get_run(run_id)
|
| 176 |
-
if not run:
|
| 177 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 178 |
-
return _run_to_out(run)
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
# ββ Schema Engine βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 182 |
-
|
| 183 |
-
@router.get("/schema", response_model=TrainingSchemaResponse)
|
| 184 |
-
async def get_schema(
|
| 185 |
-
model_id: str = Query(""),
|
| 186 |
-
dataset_id: str = Query(""),
|
| 187 |
-
task: str = Query("detection"),
|
| 188 |
-
) -> TrainingSchemaResponse:
|
| 189 |
-
schema = generate_schema(task=task, model_id=model_id, dataset_id=dataset_id)
|
| 190 |
-
return TrainingSchemaResponse(**schema)
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
# ββ Checkpoints (stub β extend when artifact storage is wired) ββββββββββββββββ
|
| 194 |
-
|
| 195 |
-
@router.get("/checkpoints", response_model=list[CheckpointOut])
|
| 196 |
-
async def list_checkpoints(run_id: str = Query(...)) -> list[CheckpointOut]:
|
| 197 |
-
"""Returns an empty list until checkpoint persistence is implemented."""
|
| 198 |
-
run = run_manager.get_run(run_id)
|
| 199 |
-
if not run:
|
| 200 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 201 |
-
return []
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
@router.post("/checkpoints/{checkpoint_id}/export")
|
| 205 |
-
async def export_checkpoint(checkpoint_id: str, body: dict = {}) -> dict:
|
| 206 |
-
raise HTTPException(status_code=501, detail="Checkpoint export not yet implemented")
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
# ββ SSE: Metrics stream ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 210 |
-
|
| 211 |
-
@router.get("/metrics/stream")
|
| 212 |
-
async def stream_metrics(run_id: str = Query(...)) -> StreamingResponse:
|
| 213 |
-
"""
|
| 214 |
-
Server-Sent Events stream of TrainMetricsTick objects.
|
| 215 |
-
Connects to the run's metrics queue and forwards each tick as SSE.
|
| 216 |
-
Stream closes when the run finishes (sentinel None pushed by worker).
|
| 217 |
-
"""
|
| 218 |
-
run = run_manager.get_run(run_id)
|
| 219 |
-
if not run:
|
| 220 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 221 |
-
|
| 222 |
-
q: asyncio.Queue = asyncio.Queue()
|
| 223 |
-
run.metrics_subs.append(q)
|
| 224 |
-
|
| 225 |
-
async def generator():
|
| 226 |
-
yield ": connected\n\n"
|
| 227 |
-
try:
|
| 228 |
-
while True:
|
| 229 |
-
try:
|
| 230 |
-
tick = await asyncio.wait_for(q.get(), timeout=30.0)
|
| 231 |
-
except asyncio.TimeoutError:
|
| 232 |
-
# Heartbeat to keep connection alive
|
| 233 |
-
yield ": heartbeat\n\n"
|
| 234 |
-
continue
|
| 235 |
-
if tick is None:
|
| 236 |
-
break
|
| 237 |
-
yield f"data: {json.dumps(tick)}\n\n"
|
| 238 |
-
finally:
|
| 239 |
-
if q in run.metrics_subs:
|
| 240 |
-
run.metrics_subs.remove(q)
|
| 241 |
-
|
| 242 |
-
return StreamingResponse(
|
| 243 |
-
generator(),
|
| 244 |
-
media_type="text/event-stream",
|
| 245 |
-
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
# ββ SSE: Logs stream ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 250 |
-
|
| 251 |
-
@router.get("/logs/stream")
|
| 252 |
-
async def stream_logs(run_id: str = Query(...)) -> StreamingResponse:
|
| 253 |
-
"""Server-Sent Events stream of LogEntry objects."""
|
| 254 |
-
run = run_manager.get_run(run_id)
|
| 255 |
-
if not run:
|
| 256 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 257 |
-
|
| 258 |
-
q: asyncio.Queue = asyncio.Queue()
|
| 259 |
-
run.log_subs.append(q)
|
| 260 |
-
|
| 261 |
-
async def generator():
|
| 262 |
-
yield ": connected\n\n"
|
| 263 |
-
try:
|
| 264 |
-
while True:
|
| 265 |
-
try:
|
| 266 |
-
entry = await asyncio.wait_for(q.get(), timeout=30.0)
|
| 267 |
-
except asyncio.TimeoutError:
|
| 268 |
-
yield ": heartbeat\n\n"
|
| 269 |
-
continue
|
| 270 |
-
if entry is None:
|
| 271 |
-
break
|
| 272 |
-
yield f"data: {json.dumps(entry)}\n\n"
|
| 273 |
-
finally:
|
| 274 |
-
if q in run.log_subs:
|
| 275 |
-
run.log_subs.remove(q)
|
| 276 |
-
|
| 277 |
-
return StreamingResponse(
|
| 278 |
-
generator(),
|
| 279 |
-
media_type="text/event-stream",
|
| 280 |
-
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 281 |
-
)
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
@router.get("/runs/{run_id}/history")
|
| 285 |
-
async def get_run_history(run_id: str) -> list[dict]:
|
| 286 |
-
"""Retrieves the full historical telemetry (metrics ticks) for a run."""
|
| 287 |
-
run = run_manager.get_run(run_id)
|
| 288 |
-
if not run:
|
| 289 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 290 |
-
|
| 291 |
-
from training.persistence import TrainingPersistence
|
| 292 |
-
run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
|
| 293 |
-
telemetry_path = os.path.join(run_dir, "telemetry.jsonl")
|
| 294 |
-
|
| 295 |
-
history = []
|
| 296 |
-
if os.path.exists(telemetry_path):
|
| 297 |
-
try:
|
| 298 |
-
with open(telemetry_path, "r") as f:
|
| 299 |
-
for line in f:
|
| 300 |
-
if line.strip():
|
| 301 |
-
history.append(json.loads(line))
|
| 302 |
-
except Exception as e:
|
| 303 |
-
log.error("history_read_failed", run_id=run_id, error=str(e))
|
| 304 |
-
raise HTTPException(status_code=500, detail="Failed to read telemetry history")
|
| 305 |
-
|
| 306 |
-
return history
|
| 307 |
-
|
| 308 |
-
@router.get("/runs/{run_id}/artifacts")
|
| 309 |
-
async def list_run_artifacts(run_id: str) -> dict:
|
| 310 |
-
"""Lists available artifacts (images) for a specific run by scanning the directory."""
|
| 311 |
-
run = run_manager.get_run(run_id)
|
| 312 |
-
if not run:
|
| 313 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 314 |
-
|
| 315 |
-
from training.persistence import TrainingPersistence
|
| 316 |
-
run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
|
| 317 |
-
|
| 318 |
-
if not os.path.exists(run_dir):
|
| 319 |
-
return {"artifacts": [], "batches": []}
|
| 320 |
-
|
| 321 |
-
artifacts = []
|
| 322 |
-
batches = []
|
| 323 |
-
|
| 324 |
-
# Standard YOLO artifact mappings for better UI titles
|
| 325 |
-
titles = {
|
| 326 |
-
"confusion_matrix.png": "Confusion Matrix",
|
| 327 |
-
"confusion_matrix_normalized.png": "Confusion Matrix (Norm)",
|
| 328 |
-
"results.png": "Results Summary",
|
| 329 |
-
"F1_curve.png": "F1 Curve",
|
| 330 |
-
"PR_curve.png": "PR Curve",
|
| 331 |
-
"P_curve.png": "Precision Curve",
|
| 332 |
-
"R_curve.png": "Recall Curve",
|
| 333 |
-
"BoxF1_curve.png": "Box F1 Curve",
|
| 334 |
-
"BoxP_curve.png": "Box Precision Curve",
|
| 335 |
-
"BoxPR_curve.png": "Box PR Curve",
|
| 336 |
-
"BoxR_curve.png": "Box Recall Curve",
|
| 337 |
-
"labels.jpg": "Labels Distribution",
|
| 338 |
-
"labels_correlogram.jpg": "Labels Correlogram"
|
| 339 |
-
}
|
| 340 |
-
|
| 341 |
-
for f in os.listdir(run_dir):
|
| 342 |
-
path = f"/train/runs/{run_id}/files/{f}"
|
| 343 |
-
if f.endswith(('.png', '.jpg', '.jpeg')):
|
| 344 |
-
item = {
|
| 345 |
-
"title": titles.get(f, f.replace('_', ' ').title().split('.')[0]),
|
| 346 |
-
"path": path,
|
| 347 |
-
"type": "Analysis"
|
| 348 |
-
}
|
| 349 |
-
|
| 350 |
-
if "batch" in f.lower():
|
| 351 |
-
item["type"] = "Batch Preview" if "val" in f.lower() else "Augmentation"
|
| 352 |
-
batches.append(item)
|
| 353 |
-
else:
|
| 354 |
-
if "curve" in f.lower():
|
| 355 |
-
item["type"] = "Precision-Recall"
|
| 356 |
-
elif "confusion" in f.lower():
|
| 357 |
-
item["type"] = "Analysis"
|
| 358 |
-
elif "results" in f.lower():
|
| 359 |
-
item["type"] = "Overall"
|
| 360 |
-
artifacts.append(item)
|
| 361 |
-
|
| 362 |
-
return {
|
| 363 |
-
"artifacts": sorted(artifacts, key=lambda x: x['title']),
|
| 364 |
-
"batches": sorted(batches, key=lambda x: x['title'])
|
| 365 |
-
}
|
| 366 |
-
|
| 367 |
-
@router.get("/runs/{run_id}/files/{filename}")
|
| 368 |
-
async def get_run_file(run_id: str, filename: str):
|
| 369 |
-
"""Serves a specific file from the run directory."""
|
| 370 |
-
run = run_manager.get_run(run_id)
|
| 371 |
-
if not run:
|
| 372 |
-
raise HTTPException(status_code=404, detail="Run not found")
|
| 373 |
-
|
| 374 |
-
# We need to find the project to get the run_dir
|
| 375 |
-
# Since run_manager doesn't easily expose the full path in memory,
|
| 376 |
-
# we recalculate it using persistence
|
| 377 |
-
from training.persistence import TrainingPersistence
|
| 378 |
-
run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
|
| 379 |
-
file_path = os.path.join(run_dir, filename)
|
| 380 |
-
|
| 381 |
-
if not os.path.exists(file_path):
|
| 382 |
-
raise HTTPException(status_code=404, detail="File not found")
|
| 383 |
-
|
| 384 |
-
from fastapi.responses import FileResponse
|
| 385 |
-
return FileResponse(file_path)
|
| 386 |
-
# The frontend uses /system/metrics/stream for resources (already implemented).
|
| 387 |
-
# This alias exists for training-scoped resource monitoring.
|
| 388 |
-
|
| 389 |
-
@router.get("/resources/stream")
|
| 390 |
-
async def stream_resources(
|
| 391 |
-
run_id: str = Query(...),
|
| 392 |
-
gpu_index: int = Query(0, ge=0),
|
| 393 |
-
hz: float = Query(1.0, ge=0.2, le=10.0),
|
| 394 |
-
) -> StreamingResponse:
|
| 395 |
-
"""
|
| 396 |
-
SSE stream of ResourceTick objects for a specific training run.
|
| 397 |
-
Forwards system metrics at the requested hz rate.
|
| 398 |
-
"""
|
| 399 |
-
run = run_manager.get_run(run_id)
|
| 400 |
-
if not run:
|
| 401 |
-
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 402 |
-
|
| 403 |
-
q: asyncio.Queue = asyncio.Queue()
|
| 404 |
-
run.resource_subs.append(q)
|
| 405 |
-
|
| 406 |
-
interval = 1.0 / hz
|
| 407 |
-
|
| 408 |
-
async def generator():
|
| 409 |
-
yield ": connected\n\n"
|
| 410 |
-
try:
|
| 411 |
-
while True:
|
| 412 |
-
try:
|
| 413 |
-
tick = await asyncio.wait_for(q.get(), timeout=30.0)
|
| 414 |
-
except asyncio.TimeoutError:
|
| 415 |
-
yield ": heartbeat\n\n"
|
| 416 |
-
continue
|
| 417 |
-
if tick is None:
|
| 418 |
-
break
|
| 419 |
-
yield f"data: {json.dumps(tick)}\n\n"
|
| 420 |
-
finally:
|
| 421 |
-
if q in run.resource_subs:
|
| 422 |
-
run.resource_subs.remove(q)
|
| 423 |
-
|
| 424 |
-
return StreamingResponse(
|
| 425 |
-
generator(),
|
| 426 |
-
media_type="text/event-stream",
|
| 427 |
-
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 428 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# benchmark β Benchmark Bridge System for MLForge
|
|
|
|
|
|
benchmark/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (144 Bytes)
|
|
|
benchmark/__pycache__/compatibility.cpython-310.pyc
DELETED
|
Binary file (8.3 kB)
|
|
|
benchmark/__pycache__/execution.cpython-310.pyc
DELETED
|
Binary file (10.4 kB)
|
|
|
benchmark/__pycache__/metrics.cpython-310.pyc
DELETED
|
Binary file (3.24 kB)
|
|
|
benchmark/__pycache__/orchestrator.cpython-310.pyc
DELETED
|
Binary file (9.11 kB)
|
|
|
benchmark/__pycache__/registry.cpython-310.pyc
DELETED
|
Binary file (8.77 kB)
|
|
|
benchmark/__pycache__/telemetry.cpython-310.pyc
DELETED
|
Binary file (6.73 kB)
|
|
|
benchmark/adapters/__pycache__/base.cpython-310.pyc
DELETED
|
Binary file (1.8 kB)
|
|
|
benchmark/adapters/__pycache__/registry.cpython-310.pyc
DELETED
|
Binary file (1.89 kB)
|
|
|
benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc
DELETED
|
Binary file (1.93 kB)
|
|
|
benchmark/adapters/base.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/adapters/base.py β Base class for all Benchmark Runners.
|
| 3 |
-
"""
|
| 4 |
-
from __future__ import annotations
|
| 5 |
-
|
| 6 |
-
from abc import ABC, abstractmethod
|
| 7 |
-
from dataclasses import dataclass, field
|
| 8 |
-
from typing import Any, AsyncGenerator
|
| 9 |
-
|
| 10 |
-
from models.benchmark import BenchmarkContext, TelemetrySample
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
@dataclass
|
| 14 |
-
class BatchResult:
|
| 15 |
-
"""Result of a single batch execution."""
|
| 16 |
-
latency_ms: float
|
| 17 |
-
vram_used_gb: float
|
| 18 |
-
task_scores: dict[str, float] = field(default_factory=dict)
|
| 19 |
-
metadata: dict[str, Any] = field(default_factory=dict)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class BaseRunner(ABC):
|
| 23 |
-
"""Abstract interface for benchmark executors (Torch, Optimum, vLLM)."""
|
| 24 |
-
|
| 25 |
-
@abstractmethod
|
| 26 |
-
async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
|
| 27 |
-
"""Load model and prepare environment."""
|
| 28 |
-
pass
|
| 29 |
-
|
| 30 |
-
@abstractmethod
|
| 31 |
-
async def run_batch(self, batch: Any) -> BatchResult:
|
| 32 |
-
"""Execute a single batch of data."""
|
| 33 |
-
pass
|
| 34 |
-
|
| 35 |
-
@abstractmethod
|
| 36 |
-
async def shutdown(self) -> None:
|
| 37 |
-
"""Release resources."""
|
| 38 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/adapters/optimum_runner.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/adapters/optimum_runner.py β Hugging Face Optimum Adapter.
|
| 3 |
-
Supports ONNX, OpenVINO, and TensorRT acceleration.
|
| 4 |
-
"""
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
-
import time
|
| 8 |
-
import asyncio
|
| 9 |
-
from typing import Any
|
| 10 |
-
from benchmark.adapters.base import BaseRunner, BatchResult
|
| 11 |
-
from models.benchmark import BenchmarkContext
|
| 12 |
-
from observability.logger import get_logger
|
| 13 |
-
|
| 14 |
-
log = get_logger("benchmark.optimum")
|
| 15 |
-
|
| 16 |
-
class OptimumRunner(BaseRunner):
|
| 17 |
-
def __init__(self):
|
| 18 |
-
self.session = None
|
| 19 |
-
self.device = "cpu"
|
| 20 |
-
|
| 21 |
-
async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
|
| 22 |
-
"""
|
| 23 |
-
Load model using Optimum's ORTModel or equivalent.
|
| 24 |
-
In a real implementation, this would detect the framework and use:
|
| 25 |
-
ORTModelForFeatureExtraction.from_pretrained(model_path, provider=...)
|
| 26 |
-
"""
|
| 27 |
-
log.info("optimum_init", model_path=model_path, hardware=ctx.hardware)
|
| 28 |
-
self.device = "cuda" if "gpu" in ctx.hardware.lower() or "rtx" in ctx.hardware.lower() else "cpu"
|
| 29 |
-
|
| 30 |
-
# Simulate load time
|
| 31 |
-
await asyncio.sleep(1.5)
|
| 32 |
-
self.session = "active" # Placeholder for the real session object
|
| 33 |
-
|
| 34 |
-
async def run_batch(self, batch: Any) -> BatchResult:
|
| 35 |
-
"""Execute inference using the Optimum/ONNX Runtime session."""
|
| 36 |
-
if not self.session:
|
| 37 |
-
raise RuntimeError("Optimum session not initialized")
|
| 38 |
-
|
| 39 |
-
start_time = time.perf_counter()
|
| 40 |
-
# Mocking inference logic
|
| 41 |
-
# outputs = self.session(**batch)
|
| 42 |
-
await asyncio.sleep(0.01) # Simulated inference time
|
| 43 |
-
latency = (time.perf_counter() - start_time) * 1000
|
| 44 |
-
|
| 45 |
-
return BatchResult(
|
| 46 |
-
latency_ms=latency,
|
| 47 |
-
vram_used_gb=0.8, # Mocked
|
| 48 |
-
task_scores={"accuracy": 0.92} # Mocked
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
async def shutdown(self) -> None:
|
| 52 |
-
log.info("optimum_shutdown")
|
| 53 |
-
self.session = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/adapters/registry.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/adapters/registry.py β Executor Registry for dynamic runner resolution.
|
| 3 |
-
"""
|
| 4 |
-
from __future__ import annotations
|
| 5 |
-
|
| 6 |
-
from typing import Type
|
| 7 |
-
from benchmark.adapters.base import BaseRunner
|
| 8 |
-
from models.benchmark import BenchmarkContext
|
| 9 |
-
from models.model import Model
|
| 10 |
-
|
| 11 |
-
class ExecutorRegistry:
|
| 12 |
-
_runners: dict[str, Type[BaseRunner]] = {}
|
| 13 |
-
|
| 14 |
-
@classmethod
|
| 15 |
-
def register(cls, framework: str, runner_cls: Type[BaseRunner]):
|
| 16 |
-
cls._runners[framework.lower()] = runner_cls
|
| 17 |
-
|
| 18 |
-
@classmethod
|
| 19 |
-
def get_runner(cls, framework: str) -> BaseRunner:
|
| 20 |
-
runner_cls = cls._runners.get(framework.lower())
|
| 21 |
-
if not runner_cls:
|
| 22 |
-
# Fallback or default runner
|
| 23 |
-
from benchmark.adapters.torch_runner import TorchRunner
|
| 24 |
-
return TorchRunner()
|
| 25 |
-
return runner_cls()
|
| 26 |
-
|
| 27 |
-
def get_executor(ctx: BenchmarkContext, model: Model) -> BaseRunner:
|
| 28 |
-
"""Resolve the appropriate executor based on framework and task."""
|
| 29 |
-
framework = model.framework.lower()
|
| 30 |
-
|
| 31 |
-
# Special cases for optimized engines
|
| 32 |
-
if framework == "onnx" or framework == "openvino" or framework == "tensorrt":
|
| 33 |
-
from benchmark.adapters.optimum_runner import OptimumRunner
|
| 34 |
-
return OptimumRunner()
|
| 35 |
-
|
| 36 |
-
if ctx.task in ("generation", "nlp") and framework == "pytorch":
|
| 37 |
-
# Potential for vLLM if configured
|
| 38 |
-
try:
|
| 39 |
-
from benchmark.adapters.vllm_runner import VLLMRunner
|
| 40 |
-
return VLLMRunner()
|
| 41 |
-
except ImportError:
|
| 42 |
-
pass
|
| 43 |
-
|
| 44 |
-
return ExecutorRegistry.get_runner(framework)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/adapters/torch_runner.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/adapters/torch_runner.py β PyTorch Runner Adapter.
|
| 3 |
-
Wraps standard PyTorch inference for Vision and NLP tasks.
|
| 4 |
-
"""
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
-
import time
|
| 8 |
-
import asyncio
|
| 9 |
-
import random
|
| 10 |
-
from typing import Any
|
| 11 |
-
from benchmark.adapters.base import BaseRunner, BatchResult
|
| 12 |
-
from models.benchmark import BenchmarkContext
|
| 13 |
-
from observability.logger import get_logger
|
| 14 |
-
|
| 15 |
-
log = get_logger("benchmark.torch")
|
| 16 |
-
|
| 17 |
-
class TorchRunner(BaseRunner):
|
| 18 |
-
def __init__(self):
|
| 19 |
-
self.model = None
|
| 20 |
-
self.device = "cpu"
|
| 21 |
-
|
| 22 |
-
async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
|
| 23 |
-
log.info("torch_init", model_path=model_path, hardware=ctx.hardware)
|
| 24 |
-
# In production: self.model = torch.load(model_path).to(self.device)
|
| 25 |
-
await asyncio.sleep(1.0)
|
| 26 |
-
self.model = "active"
|
| 27 |
-
|
| 28 |
-
async def run_batch(self, batch: Any) -> BatchResult:
|
| 29 |
-
if not self.model:
|
| 30 |
-
raise RuntimeError("Torch model not initialized")
|
| 31 |
-
|
| 32 |
-
start_time = time.perf_counter()
|
| 33 |
-
# Mocking torch inference
|
| 34 |
-
await asyncio.sleep(0.02)
|
| 35 |
-
latency = (time.perf_counter() - start_time) * 1000
|
| 36 |
-
|
| 37 |
-
return BatchResult(
|
| 38 |
-
latency_ms=latency,
|
| 39 |
-
vram_used_gb=1.2,
|
| 40 |
-
task_scores={"mAP": 0.45}
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
async def shutdown(self) -> None:
|
| 44 |
-
log.info("torch_shutdown")
|
| 45 |
-
self.model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/compatibility.py
DELETED
|
@@ -1,360 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/compatibility.py β Compatibility Validator (CRITICAL MODULE).
|
| 3 |
-
|
| 4 |
-
Validates model β dataset β hardware compatibility before any benchmark
|
| 5 |
-
execution begins. Returns a structured ValidationReport β never raises.
|
| 6 |
-
|
| 7 |
-
Five gates (all must pass):
|
| 8 |
-
A. Task compatibility β model.task matches dataset.task
|
| 9 |
-
B. Annotation format β dataset format supports the model's task
|
| 10 |
-
C. Framework Γ hardware β framework can run on the requested device
|
| 11 |
-
D. VRAM constraint β estimated memory fits available VRAM
|
| 12 |
-
E. Precision support β precision mode is valid for framework + hardware
|
| 13 |
-
"""
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
from models.benchmark import BenchmarkContext, ValidationCheck, ValidationReport
|
| 17 |
-
from models.dataset import Dataset
|
| 18 |
-
from models.model import Model
|
| 19 |
-
from observability.logger import get_logger
|
| 20 |
-
|
| 21 |
-
log = get_logger("benchmark.compatibility")
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
# ββ Lookup tables βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
-
|
| 26 |
-
# Hardware β available VRAM in GB (normalized keys, no spaces/dashes)
|
| 27 |
-
HARDWARE_VRAM_GB: dict[str, float] = {
|
| 28 |
-
# NVIDIA consumer β Ampere / Ada
|
| 29 |
-
"rtx4090": 24.0,
|
| 30 |
-
"rtx4080": 16.0,
|
| 31 |
-
"rtx4070ti": 12.0,
|
| 32 |
-
"rtx4070": 12.0,
|
| 33 |
-
"rtx4060ti": 8.0,
|
| 34 |
-
"rtx4060": 8.0,
|
| 35 |
-
"rtx3090": 24.0,
|
| 36 |
-
"rtx3080": 10.0,
|
| 37 |
-
"rtx3070": 8.0,
|
| 38 |
-
"rtx3060": 12.0,
|
| 39 |
-
"rtx2080ti": 11.0,
|
| 40 |
-
"rtx2080": 8.0,
|
| 41 |
-
# NVIDIA datacenter
|
| 42 |
-
"a100": 80.0,
|
| 43 |
-
"a10040gb": 40.0,
|
| 44 |
-
"h100": 80.0,
|
| 45 |
-
"v100": 32.0,
|
| 46 |
-
"t4": 16.0,
|
| 47 |
-
"a10": 24.0,
|
| 48 |
-
# AMD
|
| 49 |
-
"rx7900xtx": 24.0,
|
| 50 |
-
"rx6800xt": 16.0,
|
| 51 |
-
# Generic fallbacks
|
| 52 |
-
"gpu": 8.0,
|
| 53 |
-
"cpu": 0.0,
|
| 54 |
-
"tpu": 0.0,
|
| 55 |
-
"edge": 0.0,
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
# model.task β set of compatible dataset.task values
|
| 59 |
-
TASK_COMPAT: dict[str, set[str]] = {
|
| 60 |
-
"detection": {"detection"},
|
| 61 |
-
"classification": {"classification"},
|
| 62 |
-
"segmentation": {"segmentation"},
|
| 63 |
-
"nlp": {"nlp"},
|
| 64 |
-
"generation": {"generation"},
|
| 65 |
-
"keypoints": {"keypoints", "detection"},
|
| 66 |
-
"embedding": {"nlp", "classification"},
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
# dataset.format β set of model tasks it supports
|
| 70 |
-
FORMAT_TASK_COMPAT: dict[str, set[str]] = {
|
| 71 |
-
"yolo": {"detection", "segmentation", "keypoints"},
|
| 72 |
-
"coco": {"detection", "segmentation", "keypoints"},
|
| 73 |
-
"voc": {"detection"},
|
| 74 |
-
"csv": {"classification"},
|
| 75 |
-
"json": {"detection", "segmentation", "classification", "nlp", "generation"},
|
| 76 |
-
"tfrecord": {"detection", "classification", "segmentation"},
|
| 77 |
-
"custom": {"detection", "classification", "segmentation", "nlp", "generation", "keypoints"},
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
# model.framework β set of hardware targets (normalized) it can run on
|
| 81 |
-
FRAMEWORK_HARDWARE_COMPAT: dict[str, set[str]] = {
|
| 82 |
-
"pytorch": {
|
| 83 |
-
"cpu", "gpu",
|
| 84 |
-
"rtx4090", "rtx4080", "rtx4070ti", "rtx4070", "rtx4060ti", "rtx4060",
|
| 85 |
-
"rtx3090", "rtx3080", "rtx3070", "rtx3060",
|
| 86 |
-
"rtx2080ti", "rtx2080",
|
| 87 |
-
"a100", "a10040gb", "h100", "v100", "t4", "a10",
|
| 88 |
-
},
|
| 89 |
-
"onnx": {
|
| 90 |
-
"cpu", "gpu",
|
| 91 |
-
"rtx4090", "rtx3090", "a100", "h100", "t4", "a10",
|
| 92 |
-
"edge",
|
| 93 |
-
},
|
| 94 |
-
"tensorflow": {
|
| 95 |
-
"cpu", "gpu",
|
| 96 |
-
"rtx4090", "rtx3090", "a100", "h100", "v100", "t4",
|
| 97 |
-
"tpu",
|
| 98 |
-
},
|
| 99 |
-
"tflite": {"cpu", "edge"},
|
| 100 |
-
"coreml": {"cpu"},
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
# Precisions that require GPU
|
| 104 |
-
_GPU_ONLY_PRECISIONS = {"FP16", "BF16"}
|
| 105 |
-
|
| 106 |
-
# Frameworks supporting INT8 quantization
|
| 107 |
-
_INT8_FRAMEWORKS = {"onnx", "tflite", "pytorch", "tensorflow"}
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class CompatibilityValidator:
|
| 111 |
-
"""
|
| 112 |
-
Runs all compatibility gates before a benchmark job is created.
|
| 113 |
-
Returns a ValidationReport β never raises exceptions.
|
| 114 |
-
"""
|
| 115 |
-
|
| 116 |
-
def validate(
|
| 117 |
-
self,
|
| 118 |
-
model: Model,
|
| 119 |
-
dataset: Dataset,
|
| 120 |
-
ctx: BenchmarkContext,
|
| 121 |
-
) -> ValidationReport:
|
| 122 |
-
checks: list[ValidationCheck] = [
|
| 123 |
-
self._check_task(model, dataset),
|
| 124 |
-
self._check_annotation_format(model, dataset),
|
| 125 |
-
self._check_framework_hardware(model, ctx),
|
| 126 |
-
self._check_vram(model, ctx),
|
| 127 |
-
self._check_precision(model, ctx),
|
| 128 |
-
]
|
| 129 |
-
|
| 130 |
-
errors = [c.detail for c in checks if not c.passed]
|
| 131 |
-
warnings: list[str] = []
|
| 132 |
-
|
| 133 |
-
log.info(
|
| 134 |
-
"compatibility_validated",
|
| 135 |
-
model_id = model.id,
|
| 136 |
-
dataset_id = dataset.id,
|
| 137 |
-
passed = len(errors) == 0,
|
| 138 |
-
error_count = len(errors),
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
return ValidationReport(
|
| 142 |
-
model_id = model.id,
|
| 143 |
-
dataset_id = dataset.id,
|
| 144 |
-
passed = len(errors) == 0,
|
| 145 |
-
checks = checks,
|
| 146 |
-
errors = errors,
|
| 147 |
-
warnings = warnings,
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
# ββ Gate A: Task ββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββββββββ
|
| 151 |
-
|
| 152 |
-
def _check_task(self, model: Model, dataset: Dataset) -> ValidationCheck:
|
| 153 |
-
model_task = model.task.lower().strip()
|
| 154 |
-
dataset_task = str(dataset.task).lower().strip()
|
| 155 |
-
|
| 156 |
-
allowed = TASK_COMPAT.get(model_task, {model_task})
|
| 157 |
-
if dataset_task in allowed:
|
| 158 |
-
return ValidationCheck(
|
| 159 |
-
name = "task_compatibility",
|
| 160 |
-
passed = True,
|
| 161 |
-
detail = (
|
| 162 |
-
f"Model task '{model_task}' is compatible "
|
| 163 |
-
f"with dataset task '{dataset_task}'"
|
| 164 |
-
),
|
| 165 |
-
)
|
| 166 |
-
return ValidationCheck(
|
| 167 |
-
name = "task_compatibility",
|
| 168 |
-
passed = False,
|
| 169 |
-
detail = (
|
| 170 |
-
f"Model task '{model_task}' cannot evaluate "
|
| 171 |
-
f"a '{dataset_task}' dataset"
|
| 172 |
-
),
|
| 173 |
-
suggestion = (
|
| 174 |
-
f"Select a model with task='{dataset_task}', "
|
| 175 |
-
f"or choose a dataset with task='{model_task}'"
|
| 176 |
-
),
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
# ββ Gate B: Annotation Format βββββββββββββββββββββββββββββββββββββββββββββ
|
| 180 |
-
|
| 181 |
-
def _check_annotation_format(self, model: Model, dataset: Dataset) -> ValidationCheck:
|
| 182 |
-
dataset_fmt = str(dataset.format).lower().strip()
|
| 183 |
-
model_task = model.task.lower().strip()
|
| 184 |
-
supported = FORMAT_TASK_COMPAT.get(dataset_fmt, set())
|
| 185 |
-
|
| 186 |
-
if model_task in supported:
|
| 187 |
-
return ValidationCheck(
|
| 188 |
-
name = "annotation_format",
|
| 189 |
-
passed = True,
|
| 190 |
-
detail = (
|
| 191 |
-
f"Dataset format '{dataset_fmt}' supports "
|
| 192 |
-
f"model task '{model_task}'"
|
| 193 |
-
),
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
if model_task in {"detection", "segmentation", "keypoints"}:
|
| 197 |
-
suggestion = (
|
| 198 |
-
f"Convert dataset to YOLO or COCO format β both support '{model_task}'"
|
| 199 |
-
)
|
| 200 |
-
elif model_task == "classification":
|
| 201 |
-
suggestion = "Convert dataset to CSV or JSON format for classification tasks"
|
| 202 |
-
else:
|
| 203 |
-
suggestion = f"Use a JSON or custom-format dataset for '{model_task}' tasks"
|
| 204 |
-
|
| 205 |
-
return ValidationCheck(
|
| 206 |
-
name = "annotation_format",
|
| 207 |
-
passed = False,
|
| 208 |
-
detail = (
|
| 209 |
-
f"Dataset format '{dataset_fmt}' does not support "
|
| 210 |
-
f"model task '{model_task}'"
|
| 211 |
-
),
|
| 212 |
-
suggestion = suggestion,
|
| 213 |
-
)
|
| 214 |
-
|
| 215 |
-
# ββ Gate C: Framework Γ Hardware βββββββββββββββββββββββββββββββββββββββββ
|
| 216 |
-
|
| 217 |
-
def _check_framework_hardware(
|
| 218 |
-
self, model: Model, ctx: BenchmarkContext
|
| 219 |
-
) -> ValidationCheck:
|
| 220 |
-
framework = model.framework.lower().strip()
|
| 221 |
-
hw_raw = ctx.hardware
|
| 222 |
-
hw_key = self._normalize_hw(hw_raw)
|
| 223 |
-
|
| 224 |
-
supported_hw = FRAMEWORK_HARDWARE_COMPAT.get(framework, {"cpu"})
|
| 225 |
-
|
| 226 |
-
# Match: exact key, or generic "gpu" bucket covers any named GPU
|
| 227 |
-
hw_ok = (
|
| 228 |
-
hw_key in supported_hw
|
| 229 |
-
or ("gpu" in supported_hw and hw_key not in {"cpu", "tpu", "edge"})
|
| 230 |
-
)
|
| 231 |
-
|
| 232 |
-
if hw_ok:
|
| 233 |
-
return ValidationCheck(
|
| 234 |
-
name = "framework_hardware",
|
| 235 |
-
passed = True,
|
| 236 |
-
detail = f"Framework '{framework}' is supported on '{hw_raw}'",
|
| 237 |
-
)
|
| 238 |
-
return ValidationCheck(
|
| 239 |
-
name = "framework_hardware",
|
| 240 |
-
passed = False,
|
| 241 |
-
detail = (
|
| 242 |
-
f"Framework '{framework}' cannot run on '{hw_raw}'. "
|
| 243 |
-
f"Supported targets: {', '.join(sorted(supported_hw))}"
|
| 244 |
-
),
|
| 245 |
-
suggestion = (
|
| 246 |
-
"Use ONNX runtime for broadest hardware support, "
|
| 247 |
-
f"or pick a device from: {', '.join(sorted(supported_hw))}"
|
| 248 |
-
),
|
| 249 |
-
)
|
| 250 |
-
|
| 251 |
-
# ββ Gate D: VRAM Constraint βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 252 |
-
|
| 253 |
-
def _check_vram(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
|
| 254 |
-
hw_key = self._normalize_hw(ctx.hardware)
|
| 255 |
-
available = self._lookup_vram(hw_key)
|
| 256 |
-
|
| 257 |
-
if available == 0.0:
|
| 258 |
-
return ValidationCheck(
|
| 259 |
-
name = "vram_constraint",
|
| 260 |
-
passed = True,
|
| 261 |
-
detail = f"Running on '{ctx.hardware}' (CPU/TPU/Edge) β no VRAM constraint",
|
| 262 |
-
)
|
| 263 |
-
|
| 264 |
-
# Estimate: weights at given precision + activations for one batch
|
| 265 |
-
model_gb = max(model.size, 1) / (1024 ** 3)
|
| 266 |
-
prec_map = {"FP16": 0.5, "BF16": 0.5, "INT8": 0.25, "FP32": 1.0}
|
| 267 |
-
prec_mult = prec_map.get(ctx.precision.upper(), 1.0)
|
| 268 |
-
# weights Γ precision + ~20% for optimizer/activation buffers + batch overhead
|
| 269 |
-
estimated = (model_gb * prec_mult * 1.2) + (ctx.batch_size * 0.05)
|
| 270 |
-
|
| 271 |
-
if estimated <= available:
|
| 272 |
-
return ValidationCheck(
|
| 273 |
-
name = "vram_constraint",
|
| 274 |
-
passed = True,
|
| 275 |
-
detail = (
|
| 276 |
-
f"Estimated VRAM {estimated:.2f} GB β€ "
|
| 277 |
-
f"available {available:.1f} GB on '{ctx.hardware}'"
|
| 278 |
-
),
|
| 279 |
-
)
|
| 280 |
-
return ValidationCheck(
|
| 281 |
-
name = "vram_constraint",
|
| 282 |
-
passed = False,
|
| 283 |
-
detail = (
|
| 284 |
-
f"Estimated VRAM {estimated:.2f} GB exceeds "
|
| 285 |
-
f"available {available:.1f} GB on '{ctx.hardware}'"
|
| 286 |
-
),
|
| 287 |
-
suggestion = (
|
| 288 |
-
f"Try: reduce batch_size (now {ctx.batch_size}), "
|
| 289 |
-
f"switch to FP16/INT8 precision, "
|
| 290 |
-
f"or use a GPU with β₯ {estimated:.1f} GB VRAM"
|
| 291 |
-
),
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
# ββ Gate E: Precision Support βββββββββββββββββββββββββββββββββββββββββββββ
|
| 295 |
-
|
| 296 |
-
def _check_precision(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
|
| 297 |
-
precision = ctx.precision.upper()
|
| 298 |
-
framework = model.framework.lower().strip()
|
| 299 |
-
hw_key = self._normalize_hw(ctx.hardware)
|
| 300 |
-
is_gpu = hw_key not in {"cpu", "tpu", "edge"}
|
| 301 |
-
|
| 302 |
-
if precision in _GPU_ONLY_PRECISIONS and not is_gpu:
|
| 303 |
-
return ValidationCheck(
|
| 304 |
-
name = "precision_support",
|
| 305 |
-
passed = False,
|
| 306 |
-
detail = (
|
| 307 |
-
f"Precision '{precision}' requires a CUDA GPU; "
|
| 308 |
-
f"'{ctx.hardware}' does not support it"
|
| 309 |
-
),
|
| 310 |
-
suggestion = "Use FP32 for CPU inference, or switch to a compatible GPU",
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
if precision == "INT8" and framework not in _INT8_FRAMEWORKS:
|
| 314 |
-
return ValidationCheck(
|
| 315 |
-
name = "precision_support",
|
| 316 |
-
passed = False,
|
| 317 |
-
detail = (
|
| 318 |
-
f"Framework '{framework}' does not support INT8 quantization"
|
| 319 |
-
),
|
| 320 |
-
suggestion = (
|
| 321 |
-
"Convert model to ONNX or use PyTorch with torch.quantization"
|
| 322 |
-
),
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
return ValidationCheck(
|
| 326 |
-
name = "precision_support",
|
| 327 |
-
passed = True,
|
| 328 |
-
detail = (
|
| 329 |
-
f"Precision '{precision}' is valid for "
|
| 330 |
-
f"framework '{framework}' on '{ctx.hardware}'"
|
| 331 |
-
),
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 335 |
-
|
| 336 |
-
@staticmethod
|
| 337 |
-
def _normalize_hw(hardware: str) -> str:
|
| 338 |
-
"""Lowercase, strip spaces/dashes/underscores for lookup."""
|
| 339 |
-
return (
|
| 340 |
-
hardware.lower()
|
| 341 |
-
.replace(" ", "")
|
| 342 |
-
.replace("-", "")
|
| 343 |
-
.replace("_", "")
|
| 344 |
-
.replace("nvidia", "")
|
| 345 |
-
.replace("geforce", "")
|
| 346 |
-
)
|
| 347 |
-
|
| 348 |
-
@staticmethod
|
| 349 |
-
def _lookup_vram(hw_key: str) -> float:
|
| 350 |
-
"""Return VRAM GB for a normalized hardware key, with fallback matching."""
|
| 351 |
-
if hw_key in HARDWARE_VRAM_GB:
|
| 352 |
-
return HARDWARE_VRAM_GB[hw_key]
|
| 353 |
-
# Partial match (e.g. "rtx4090laptop" β "rtx4090")
|
| 354 |
-
for key, vram in HARDWARE_VRAM_GB.items():
|
| 355 |
-
if key and key in hw_key:
|
| 356 |
-
return vram
|
| 357 |
-
# Anything that looks like a GPU but isn't in the table
|
| 358 |
-
if "gpu" in hw_key or "rtx" in hw_key or "gtx" in hw_key or "cuda" in hw_key:
|
| 359 |
-
return HARDWARE_VRAM_GB["gpu"]
|
| 360 |
-
return 0.0 # CPU / unknown β no VRAM constraint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/execution.py
DELETED
|
@@ -1,366 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/execution.py β Benchmark Execution Engine.
|
| 3 |
-
|
| 4 |
-
Drives the batch inference loop, collecting latencies and VRAM readings.
|
| 5 |
-
Calls TelemetryCollector in parallel with batch processing.
|
| 6 |
-
Yields progress callbacks so the orchestrator can persist real-time state.
|
| 7 |
-
|
| 8 |
-
Adapter pattern: swap _run_single_batch() with a real inference call
|
| 9 |
-
(torch.cuda.synchronize + model(batch)) once GPU runtime is wired up.
|
| 10 |
-
|
| 11 |
-
PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
|
| 12 |
-
"""
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
-
import asyncio
|
| 16 |
-
import math
|
| 17 |
-
import random
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from typing import Awaitable, Callable
|
| 20 |
-
|
| 21 |
-
from benchmark.compatibility import HARDWARE_VRAM_GB
|
| 22 |
-
from benchmark.telemetry import TelemetryCollector
|
| 23 |
-
from models.benchmark import BenchmarkJob, LayerBreakdown, TelemetrySample, TelemetrySummary
|
| 24 |
-
from models.dataset import Dataset
|
| 25 |
-
from models.model import Model
|
| 26 |
-
from observability.logger import get_logger
|
| 27 |
-
|
| 28 |
-
log = get_logger("benchmark.execution")
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# ββ Per-image latency profiles (ms at batch=1, fp32) βββββββββββββββββββββββββ
|
| 32 |
-
_LATENCY_MS_PER_IMAGE: dict[str, float] = {
|
| 33 |
-
"rtx4090": 1.8,
|
| 34 |
-
"rtx4080": 2.5,
|
| 35 |
-
"rtx4070ti": 3.2,
|
| 36 |
-
"rtx4070": 3.8,
|
| 37 |
-
"rtx3090": 3.0,
|
| 38 |
-
"rtx3080": 4.5,
|
| 39 |
-
"rtx3070": 6.5,
|
| 40 |
-
"rtx3060": 9.0,
|
| 41 |
-
"rtx2080ti": 5.0,
|
| 42 |
-
"rtx2080": 7.5,
|
| 43 |
-
"a100": 1.2,
|
| 44 |
-
"h100": 0.7,
|
| 45 |
-
"v100": 2.8,
|
| 46 |
-
"t4": 5.5,
|
| 47 |
-
"a10": 3.5,
|
| 48 |
-
"gpu": 8.0,
|
| 49 |
-
"cpu": 42.0,
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
# Precision speedup multipliers (relative to FP32)
|
| 53 |
-
_PRECISION_SPEEDUP: dict[str, float] = {
|
| 54 |
-
"FP32": 1.0,
|
| 55 |
-
"FP16": 1.8,
|
| 56 |
-
"BF16": 1.7,
|
| 57 |
-
"INT8": 2.5,
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
# Task-specific baseline metric scores (pre-jitter)
|
| 61 |
-
_TASK_BASELINES: dict[str, dict[str, float]] = {
|
| 62 |
-
"detection": {"mAP": 0.435, "mAP_50": 0.618, "mAP_50_95": 0.435},
|
| 63 |
-
"classification": {"accuracy": 0.872, "top5": 0.968},
|
| 64 |
-
"segmentation": {"mAP": 0.372, "iou_mean": 0.706},
|
| 65 |
-
"keypoints": {"mAP": 0.641, "mAP_50": 0.860},
|
| 66 |
-
"nlp": {"accuracy": 0.891},
|
| 67 |
-
"generation": {"accuracy": 0.780},
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
# Cap simulated batches so large datasets don't stall the event loop
|
| 71 |
-
_MAX_SIMULATED_BATCHES = 250
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
@dataclass
|
| 75 |
-
class ExecutionResult:
|
| 76 |
-
"""Raw output from the execution engine, consumed by MetricsEngine."""
|
| 77 |
-
latencies_ms: list[float]
|
| 78 |
-
total_images: int
|
| 79 |
-
vram_samples: list[float]
|
| 80 |
-
task_scores: dict[str, float]
|
| 81 |
-
telemetry_samples: list[TelemetrySample] = field(default_factory=list)
|
| 82 |
-
telemetry_summary: TelemetrySummary = field(default_factory=TelemetrySummary)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
# Progress callback type: (progress_0_to_1, message, last_telemetry) β None
|
| 86 |
-
ProgressCallback = Callable[[float, str, TelemetrySample | None], Awaitable[None]]
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class BenchmarkExecutor:
|
| 90 |
-
"""
|
| 91 |
-
Drives the benchmark execution loop.
|
| 92 |
-
Non-blocking: all sleeps are asyncio.sleep so other coroutines run freely.
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
async def execute(
|
| 96 |
-
self,
|
| 97 |
-
job: BenchmarkJob,
|
| 98 |
-
model: Model,
|
| 99 |
-
dataset: Dataset,
|
| 100 |
-
on_progress: ProgressCallback,
|
| 101 |
-
) -> ExecutionResult:
|
| 102 |
-
hw = job.hardware
|
| 103 |
-
batch_sz = job.batch_size
|
| 104 |
-
|
| 105 |
-
# Handle polymorphic input duration
|
| 106 |
-
is_live = getattr(job, "input_source", "dataset") in ("video", "live")
|
| 107 |
-
|
| 108 |
-
if is_live:
|
| 109 |
-
# For live/video, we run for a fixed duration or until stopped
|
| 110 |
-
# Increase limit for a longer session (e.g., 10,000 batches)
|
| 111 |
-
total_img = 10000 * batch_sz
|
| 112 |
-
n_batches = 10000
|
| 113 |
-
sim_batches = 10000
|
| 114 |
-
else:
|
| 115 |
-
total_img = max(dataset.images, 100) # floor so simulation always runs
|
| 116 |
-
n_batches = math.ceil(total_img / batch_sz)
|
| 117 |
-
sim_batches = min(n_batches, _MAX_SIMULATED_BATCHES)
|
| 118 |
-
|
| 119 |
-
vram_total = self._get_vram_gb(hw, model)
|
| 120 |
-
vram_frac = self._vram_usage_fraction(hw)
|
| 121 |
-
|
| 122 |
-
telemetry = TelemetryCollector(hw, vram_total_gb=vram_total)
|
| 123 |
-
await telemetry.start()
|
| 124 |
-
|
| 125 |
-
latencies: list[float] = []
|
| 126 |
-
vram_samples: list[float] = []
|
| 127 |
-
|
| 128 |
-
base_lat_ms = self._base_batch_latency_ms(hw, model, batch_sz, job.precision)
|
| 129 |
-
|
| 130 |
-
# Resolve real model path once (None β use simulation)
|
| 131 |
-
real_model_path = model.local_path if model.local_path and model.downloaded else None
|
| 132 |
-
use_real_inference = self._check_torch_available() and real_model_path is not None
|
| 133 |
-
loop = asyncio.get_event_loop()
|
| 134 |
-
|
| 135 |
-
try:
|
| 136 |
-
for sim_idx in range(sim_batches):
|
| 137 |
-
# Map simulated index back to real batch index
|
| 138 |
-
real_idx = int(sim_idx * (n_batches / sim_batches))
|
| 139 |
-
|
| 140 |
-
if use_real_inference:
|
| 141 |
-
# Real GPU inference via torch_runner (runs in thread executor)
|
| 142 |
-
try:
|
| 143 |
-
from benchmark.torch_runner import run_torch_batch
|
| 144 |
-
batch_lat_ms = await loop.run_in_executor(
|
| 145 |
-
None,
|
| 146 |
-
run_torch_batch,
|
| 147 |
-
real_model_path,
|
| 148 |
-
batch_sz,
|
| 149 |
-
job.task,
|
| 150 |
-
)
|
| 151 |
-
# Add a tiny sleep to prevent event loop starvation in live mode
|
| 152 |
-
if is_live:
|
| 153 |
-
await asyncio.sleep(0.001)
|
| 154 |
-
except Exception as exc:
|
| 155 |
-
log.warning("torch_inference_failed_fallback", error=str(exc))
|
| 156 |
-
use_real_inference = False # fall back for remaining batches
|
| 157 |
-
batch_lat_ms = max(
|
| 158 |
-
0.5, base_lat_ms + random.gauss(0, base_lat_ms * 0.07)
|
| 159 |
-
)
|
| 160 |
-
else:
|
| 161 |
-
# Simulation path β non-blocking synthetic latency
|
| 162 |
-
batch_lat_ms = max(
|
| 163 |
-
0.5,
|
| 164 |
-
base_lat_ms + random.gauss(0, base_lat_ms * 0.07),
|
| 165 |
-
)
|
| 166 |
-
await asyncio.sleep(batch_lat_ms / 1000.0) # non-blocking
|
| 167 |
-
|
| 168 |
-
latencies.append(batch_lat_ms)
|
| 169 |
-
vram_used = vram_total * random.uniform(
|
| 170 |
-
vram_frac - 0.05, vram_frac + 0.05
|
| 171 |
-
)
|
| 172 |
-
vram_samples.append(max(0.0, vram_used))
|
| 173 |
-
|
| 174 |
-
progress = (sim_idx + 1) / sim_batches
|
| 175 |
-
telemetry.record_batch_context(real_idx, progress)
|
| 176 |
-
|
| 177 |
-
# Throttle callbacks: every 5 batches or first/last
|
| 178 |
-
if sim_idx % 5 == 0 or sim_idx == sim_batches - 1:
|
| 179 |
-
images_done = int(progress * total_img)
|
| 180 |
-
|
| 181 |
-
# Generate simulated detection data for live preview if it's a vision task
|
| 182 |
-
live_data = {}
|
| 183 |
-
if job.task.lower() in ("detection", "segmentation"):
|
| 184 |
-
# Use provided bbox telemetry if available (e.g. from real inference)
|
| 185 |
-
# otherwise generate simulated ones
|
| 186 |
-
live_data["detections"] = [
|
| 187 |
-
{
|
| 188 |
-
"x": random.uniform(0.1, 0.7),
|
| 189 |
-
"y": random.uniform(0.1, 0.7),
|
| 190 |
-
"width": random.uniform(0.1, 0.3),
|
| 191 |
-
"height": random.uniform(0.1, 0.3),
|
| 192 |
-
"label": random.choice(["person", "car", "bicycle", "dog"]),
|
| 193 |
-
"confidence": random.uniform(0.5, 0.99)
|
| 194 |
-
}
|
| 195 |
-
for _ in range(random.randint(1, 5))
|
| 196 |
-
]
|
| 197 |
-
|
| 198 |
-
last_sample = telemetry.samples[-1] if telemetry.samples else None
|
| 199 |
-
if last_sample:
|
| 200 |
-
last_sample.live_data = live_data
|
| 201 |
-
# Explicitly broadcast detections for the visualizer
|
| 202 |
-
last_sample.detections = live_data.get("detections", [])
|
| 203 |
-
|
| 204 |
-
await on_progress(
|
| 205 |
-
progress,
|
| 206 |
-
f"Batch {real_idx+1}/{n_batches} β "
|
| 207 |
-
f"{images_done}/{total_img} images processed",
|
| 208 |
-
last_sample,
|
| 209 |
-
)
|
| 210 |
-
|
| 211 |
-
finally:
|
| 212 |
-
telemetry_summary = await telemetry.stop()
|
| 213 |
-
# Attach simulated layer breakdown so Live Lab can display it
|
| 214 |
-
telemetry_summary.layer_breakdown = self._compute_layer_breakdown(
|
| 215 |
-
job.task, base_lat_ms
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
task_scores = self._simulate_task_scores(job.task, model, dataset)
|
| 219 |
-
|
| 220 |
-
log.info(
|
| 221 |
-
"execution_complete",
|
| 222 |
-
job_id = job.id,
|
| 223 |
-
total_images = total_img,
|
| 224 |
-
sim_batches = sim_batches,
|
| 225 |
-
avg_lat_ms = round(sum(latencies) / len(latencies), 2) if latencies else 0,
|
| 226 |
-
)
|
| 227 |
-
|
| 228 |
-
return ExecutionResult(
|
| 229 |
-
latencies_ms = latencies,
|
| 230 |
-
total_images = total_img,
|
| 231 |
-
vram_samples = vram_samples,
|
| 232 |
-
task_scores = task_scores,
|
| 233 |
-
telemetry_samples = telemetry.samples,
|
| 234 |
-
telemetry_summary = telemetry_summary,
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
-
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 238 |
-
|
| 239 |
-
def _base_batch_latency_ms(
|
| 240 |
-
self,
|
| 241 |
-
hardware: str,
|
| 242 |
-
model: Model,
|
| 243 |
-
batch_sz: int,
|
| 244 |
-
precision: str,
|
| 245 |
-
) -> float:
|
| 246 |
-
"""
|
| 247 |
-
Estimate per-batch latency in ms.
|
| 248 |
-
Accounts for hardware tier, model size, batch size, and precision.
|
| 249 |
-
"""
|
| 250 |
-
hw_key = self._normalize_hw(hardware)
|
| 251 |
-
per_img = self._lookup_latency(hw_key)
|
| 252 |
-
|
| 253 |
-
# Larger models are slower: +30% per GB of model weights
|
| 254 |
-
size_gb = max(model.size, 1) / (1024 ** 3)
|
| 255 |
-
size_factor = 1.0 + size_gb * 0.30
|
| 256 |
-
|
| 257 |
-
# Batch parallelism: ~65% linear efficiency on GPU, 90% on CPU
|
| 258 |
-
eff = 0.65 if "cpu" not in hw_key else 0.90
|
| 259 |
-
batch_lat = per_img * size_factor * batch_sz * eff
|
| 260 |
-
|
| 261 |
-
# Precision speedup
|
| 262 |
-
speedup = _PRECISION_SPEEDUP.get(precision.upper(), 1.0)
|
| 263 |
-
|
| 264 |
-
return batch_lat / speedup
|
| 265 |
-
|
| 266 |
-
def _get_vram_gb(self, hardware: str, model: Model) -> float:
|
| 267 |
-
hw_key = self._normalize_hw(hardware)
|
| 268 |
-
for key, vram in HARDWARE_VRAM_GB.items():
|
| 269 |
-
if key and key in hw_key:
|
| 270 |
-
return vram
|
| 271 |
-
return 8.0
|
| 272 |
-
|
| 273 |
-
@staticmethod
|
| 274 |
-
def _vram_usage_fraction(hardware: str) -> float:
|
| 275 |
-
"""Fraction of VRAM typically consumed during inference."""
|
| 276 |
-
hw = hardware.lower()
|
| 277 |
-
if any(x in hw for x in ("4090", "3090", "a100", "h100")):
|
| 278 |
-
return 0.62
|
| 279 |
-
if any(x in hw for x in ("4080", "3080", "v100", "a10")):
|
| 280 |
-
return 0.60
|
| 281 |
-
if "cpu" in hw:
|
| 282 |
-
return 0.0
|
| 283 |
-
return 0.55
|
| 284 |
-
|
| 285 |
-
@staticmethod
|
| 286 |
-
def _simulate_task_scores(
|
| 287 |
-
task: str, model: Model, dataset: Dataset
|
| 288 |
-
) -> dict[str, float]:
|
| 289 |
-
"""
|
| 290 |
-
Produce realistic metric scores with small per-run variance.
|
| 291 |
-
|
| 292 |
-
PRODUCTION SWAP: replace with actual metric computation:
|
| 293 |
-
from torchmetrics.detection import MeanAveragePrecision
|
| 294 |
-
metric = MeanAveragePrecision()
|
| 295 |
-
metric.update(predictions, targets)
|
| 296 |
-
return metric.compute()
|
| 297 |
-
"""
|
| 298 |
-
baselines = dict(_TASK_BASELINES.get(task.lower(), {"accuracy": 0.80}))
|
| 299 |
-
# Small Gaussian jitter simulates run-to-run variance
|
| 300 |
-
return {
|
| 301 |
-
k: float(max(0.0, min(1.0, v + random.gauss(0, 0.015))))
|
| 302 |
-
for k, v in baselines.items()
|
| 303 |
-
}
|
| 304 |
-
|
| 305 |
-
@staticmethod
|
| 306 |
-
def _check_torch_available() -> bool:
|
| 307 |
-
"""Return True if PyTorch is installed and importable."""
|
| 308 |
-
try:
|
| 309 |
-
import torch # noqa: F401
|
| 310 |
-
return True
|
| 311 |
-
except ImportError:
|
| 312 |
-
return False
|
| 313 |
-
|
| 314 |
-
@staticmethod
|
| 315 |
-
def _compute_layer_breakdown(task: str, base_lat_ms: float) -> list[LayerBreakdown]:
|
| 316 |
-
"""Build a realistic layer breakdown for the given task.
|
| 317 |
-
|
| 318 |
-
Splits total latency across architectural stages with small jitter.
|
| 319 |
-
PRODUCTION SWAP: replace with actual profiler data (e.g. torch.profiler).
|
| 320 |
-
"""
|
| 321 |
-
if task.lower() in ("detection", "segmentation"):
|
| 322 |
-
stages = [
|
| 323 |
-
("Backbone", 0.45),
|
| 324 |
-
("Neck (FPN/PAFPN)", 0.30),
|
| 325 |
-
("Detection Head", 0.20),
|
| 326 |
-
("NMS Post-process", 0.05),
|
| 327 |
-
]
|
| 328 |
-
elif task.lower() == "classification":
|
| 329 |
-
stages = [
|
| 330 |
-
("Feature Extractor", 0.70),
|
| 331 |
-
("Classifier Head", 0.20),
|
| 332 |
-
("Softmax", 0.10),
|
| 333 |
-
]
|
| 334 |
-
else:
|
| 335 |
-
stages = [
|
| 336 |
-
("Encoder", 0.55),
|
| 337 |
-
("Decoder / Head", 0.35),
|
| 338 |
-
("Post-process", 0.10),
|
| 339 |
-
]
|
| 340 |
-
|
| 341 |
-
result: list[LayerBreakdown] = []
|
| 342 |
-
remaining = base_lat_ms
|
| 343 |
-
for name, frac in stages:
|
| 344 |
-
t = round(base_lat_ms * frac + random.gauss(0, base_lat_ms * 0.01), 3)
|
| 345 |
-
result.append(LayerBreakdown(name=name, time_ms=t, percent=round(frac * 100, 1)))
|
| 346 |
-
return result
|
| 347 |
-
|
| 348 |
-
@staticmethod
|
| 349 |
-
def _normalize_hw(hardware: str) -> str:
|
| 350 |
-
return (
|
| 351 |
-
hardware.lower()
|
| 352 |
-
.replace(" ", "")
|
| 353 |
-
.replace("-", "")
|
| 354 |
-
.replace("_", "")
|
| 355 |
-
.replace("nvidia", "")
|
| 356 |
-
.replace("geforce", "")
|
| 357 |
-
)
|
| 358 |
-
|
| 359 |
-
@staticmethod
|
| 360 |
-
def _lookup_latency(hw_key: str) -> float:
|
| 361 |
-
for key, ms in _LATENCY_MS_PER_IMAGE.items():
|
| 362 |
-
if key and key in hw_key:
|
| 363 |
-
return ms
|
| 364 |
-
if any(x in hw_key for x in ("gpu", "rtx", "gtx", "cuda")):
|
| 365 |
-
return _LATENCY_MS_PER_IMAGE["gpu"]
|
| 366 |
-
return _LATENCY_MS_PER_IMAGE["cpu"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/metrics.py
DELETED
|
@@ -1,110 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/metrics.py β Metrics Engine.
|
| 3 |
-
|
| 4 |
-
Computes the final BenchmarkMetrics object from raw execution data:
|
| 5 |
-
- Latency statistics (mean, p95, p99)
|
| 6 |
-
- Throughput (FPS)
|
| 7 |
-
- VRAM statistics (avg, peak)
|
| 8 |
-
- Task-specific scores (mAP, accuracy, IoU) supplied by the executor
|
| 9 |
-
|
| 10 |
-
In a production deployment the task_scores dict comes from actual
|
| 11 |
-
metric computation (e.g. pycocotools, torchmetrics). In this local-first
|
| 12 |
-
build the executor supplies realistic simulated scores.
|
| 13 |
-
"""
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
import statistics
|
| 17 |
-
|
| 18 |
-
from models.benchmark import BenchmarkMetrics, LayerBreakdown, TelemetrySummary
|
| 19 |
-
from observability.logger import get_logger
|
| 20 |
-
|
| 21 |
-
log = get_logger("benchmark.metrics")
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class MetricsEngine:
|
| 25 |
-
"""Computes BenchmarkMetrics from raw benchmark execution data."""
|
| 26 |
-
|
| 27 |
-
def compute(
|
| 28 |
-
self,
|
| 29 |
-
*,
|
| 30 |
-
task: str,
|
| 31 |
-
latencies_ms: list[float], # per-batch latencies
|
| 32 |
-
total_images: int = 0,
|
| 33 |
-
total_tokens: int = 0,
|
| 34 |
-
batch_size: int,
|
| 35 |
-
vram_samples: list[float], # VRAM readings (GB) during run
|
| 36 |
-
task_scores: dict[str, float], # task-specific metric scores
|
| 37 |
-
) -> BenchmarkMetrics:
|
| 38 |
-
if not latencies_ms:
|
| 39 |
-
return BenchmarkMetrics(total_images=total_images, total_tokens=total_tokens, batch_size=batch_size)
|
| 40 |
-
|
| 41 |
-
total_time_s = sum(latencies_ms) / 1000.0
|
| 42 |
-
fps = total_images / total_time_s if total_time_s > 0 and total_images > 0 else 0.0
|
| 43 |
-
tps = total_tokens / total_time_s if total_time_s > 0 and total_tokens > 0 else 0.0
|
| 44 |
-
|
| 45 |
-
lat_mean = statistics.mean(latencies_ms)
|
| 46 |
-
lat_p95 = _percentile(latencies_ms, 0.95)
|
| 47 |
-
lat_p99 = _percentile(latencies_ms, 0.99)
|
| 48 |
-
|
| 49 |
-
vram_peak = max(vram_samples) if vram_samples else 0.0
|
| 50 |
-
vram_avg = statistics.mean(vram_samples) if vram_samples else 0.0
|
| 51 |
-
|
| 52 |
-
m = BenchmarkMetrics(
|
| 53 |
-
fps = round(fps, 2),
|
| 54 |
-
tokens_per_sec = round(tps, 2),
|
| 55 |
-
latency_mean_ms = round(lat_mean, 3),
|
| 56 |
-
latency_p95_ms = round(lat_p95, 3),
|
| 57 |
-
latency_p99_ms = round(lat_p99, 3),
|
| 58 |
-
vram_peak_gb = round(vram_peak, 3),
|
| 59 |
-
vram_avg_gb = round(vram_avg, 3),
|
| 60 |
-
total_images = total_images,
|
| 61 |
-
total_tokens = total_tokens,
|
| 62 |
-
batch_size = batch_size,
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
task_lower = task.lower()
|
| 66 |
-
|
| 67 |
-
# CV Task Mapping
|
| 68 |
-
if task_lower in ("detection", "segmentation", "keypoints"):
|
| 69 |
-
m.mAP = _fmt(task_scores.get("mAP", 0.0))
|
| 70 |
-
m.mAP_50 = _fmt(task_scores.get("mAP_50", 0.0))
|
| 71 |
-
m.mAP_50_95 = _fmt(task_scores.get("mAP_50_95", 0.0))
|
| 72 |
-
if task_lower == "segmentation":
|
| 73 |
-
m.iou_mean = _fmt(task_scores.get("iou_mean", 0.0))
|
| 74 |
-
|
| 75 |
-
elif task_lower == "classification":
|
| 76 |
-
m.accuracy = _fmt(task_scores.get("accuracy", 0.0))
|
| 77 |
-
m.top1 = _fmt(task_scores.get("top1", 0.0))
|
| 78 |
-
m.top5 = _fmt(task_scores.get("top5", 0.0))
|
| 79 |
-
|
| 80 |
-
# NLP Task Mapping (ROUGE, BLEU, Perplexity)
|
| 81 |
-
elif task_lower in ("nlp", "generation"):
|
| 82 |
-
m.accuracy = _fmt(task_scores.get("accuracy", 0.0))
|
| 83 |
-
m.rouge_l = _fmt(task_scores.get("rouge_l", task_scores.get("rougeL", 0.0)))
|
| 84 |
-
m.bleu = _fmt(task_scores.get("bleu", 0.0))
|
| 85 |
-
m.perplexity = task_scores.get("perplexity")
|
| 86 |
-
|
| 87 |
-
log.info(
|
| 88 |
-
"metrics_computed",
|
| 89 |
-
task = task,
|
| 90 |
-
fps = m.fps,
|
| 91 |
-
tps = m.tokens_per_sec,
|
| 92 |
-
latency_ms = m.latency_mean_ms,
|
| 93 |
-
vram_peak = m.vram_peak_gb,
|
| 94 |
-
)
|
| 95 |
-
return m
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 99 |
-
|
| 100 |
-
def _percentile(data: list[float], p: float) -> float:
|
| 101 |
-
if not data:
|
| 102 |
-
return 0.0
|
| 103 |
-
s = sorted(data)
|
| 104 |
-
idx = min(int(len(s) * p), len(s) - 1)
|
| 105 |
-
return s[idx]
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def _fmt(v: float) -> float:
|
| 109 |
-
"""Round to 4dp and clamp to [0, 1]."""
|
| 110 |
-
return round(max(0.0, min(1.0, v)), 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/orchestrator.py
DELETED
|
@@ -1,374 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/orchestrator.py β Benchmark Orchestrator (Main Controller).
|
| 3 |
-
|
| 4 |
-
Coordinates the full benchmark lifecycle:
|
| 5 |
-
1. Resolve model + dataset from their registries
|
| 6 |
-
2. Run all compatibility checks (gates AβE)
|
| 7 |
-
3. If valid β create a BenchmarkJob in the DB
|
| 8 |
-
4. Persist the validation audit log
|
| 9 |
-
5. Enqueue async background task β execution β metrics β storage
|
| 10 |
-
6. Return the job immediately so callers are non-blocking
|
| 11 |
-
|
| 12 |
-
Public interface used by api/routes/benchmark.py:
|
| 13 |
-
validate_context(ctx) β ValidationReport (no job created)
|
| 14 |
-
create_and_run(ctx) β BenchmarkJob (job queued, execution in background)
|
| 15 |
-
"""
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
import asyncio
|
| 19 |
-
from datetime import datetime, timezone
|
| 20 |
-
|
| 21 |
-
from benchmark.adapters.registry import get_executor
|
| 22 |
-
from benchmark.compatibility import CompatibilityValidator
|
| 23 |
-
from benchmark.execution import BenchmarkExecutor
|
| 24 |
-
from benchmark.metrics import MetricsEngine
|
| 25 |
-
import benchmark.registry as bench_reg
|
| 26 |
-
from datasets.registry import get_dataset
|
| 27 |
-
from models.benchmark import (
|
| 28 |
-
BenchmarkContext,
|
| 29 |
-
BenchmarkJob,
|
| 30 |
-
BenchmarkMetrics,
|
| 31 |
-
TelemetrySummary,
|
| 32 |
-
ValidationReport,
|
| 33 |
-
)
|
| 34 |
-
from models.dataset import Dataset
|
| 35 |
-
from models.model import Model
|
| 36 |
-
from observability.logger import audit, get_logger
|
| 37 |
-
from registry.registry import get_model
|
| 38 |
-
|
| 39 |
-
log = get_logger("benchmark.orchestrator")
|
| 40 |
-
|
| 41 |
-
# Module-level singletons β stateless, safe to share
|
| 42 |
-
_validator = CompatibilityValidator()
|
| 43 |
-
_metrics = MetricsEngine()
|
| 44 |
-
|
| 45 |
-
# job_id β asyncio.Task (for future cancellation support)
|
| 46 |
-
_active_tasks: dict[str, asyncio.Task] = {}
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 50 |
-
|
| 51 |
-
async def sync_project_benchmarks() -> int:
|
| 52 |
-
"""
|
| 53 |
-
Sync benchmark jobs and results from the active project's 'benchmarks' folder.
|
| 54 |
-
This ensures that benchmarks created in different sessions or projects are indexed.
|
| 55 |
-
"""
|
| 56 |
-
from benchmark.registry import _get_active_project_benchmark_dir_sync
|
| 57 |
-
from projects.service import get_active_project_path
|
| 58 |
-
import json
|
| 59 |
-
import os
|
| 60 |
-
from database.connection import get_db
|
| 61 |
-
|
| 62 |
-
project_path = await get_active_project_path()
|
| 63 |
-
benchmark_dir = _get_active_project_benchmark_dir_sync(project_path)
|
| 64 |
-
if not benchmark_dir or not benchmark_dir.exists():
|
| 65 |
-
return 0
|
| 66 |
-
|
| 67 |
-
db = await get_db()
|
| 68 |
-
count = 0
|
| 69 |
-
|
| 70 |
-
for file_path in benchmark_dir.glob("*.json"):
|
| 71 |
-
try:
|
| 72 |
-
with open(file_path, "r") as f:
|
| 73 |
-
data = json.load(f)
|
| 74 |
-
|
| 75 |
-
# Check if it's a job or a result
|
| 76 |
-
if file_path.name.startswith("job_"):
|
| 77 |
-
# Upsert into benchmark_jobs
|
| 78 |
-
await db.execute(
|
| 79 |
-
"""INSERT OR IGNORE INTO benchmark_jobs
|
| 80 |
-
(id, model_id, dataset_id, task, framework, hardware,
|
| 81 |
-
precision, batch_size, config, status, progress, created_at, updated_at, started_at)
|
| 82 |
-
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
|
| 83 |
-
(
|
| 84 |
-
data["id"], data["model_id"], data["dataset_id"],
|
| 85 |
-
data["task"], data["framework"], data["hardware"],
|
| 86 |
-
data["precision"], data["batch_size"],
|
| 87 |
-
json.dumps(data["config"]), data["status"],
|
| 88 |
-
data.get("progress", 0.0),
|
| 89 |
-
data.get("created_at", datetime.now(timezone.utc).isoformat()),
|
| 90 |
-
data.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
| 91 |
-
data.get("started_at")
|
| 92 |
-
)
|
| 93 |
-
)
|
| 94 |
-
count += 1
|
| 95 |
-
elif file_path.name.startswith("result_"):
|
| 96 |
-
# Upsert into benchmark_results
|
| 97 |
-
await db.execute(
|
| 98 |
-
"""INSERT OR IGNORE INTO benchmark_results
|
| 99 |
-
(id, job_id, metrics, telemetry_summary, created_at)
|
| 100 |
-
VALUES (?,?,?,?,?)""",
|
| 101 |
-
(
|
| 102 |
-
data["id"], data["job_id"],
|
| 103 |
-
json.dumps(data["metrics"]),
|
| 104 |
-
json.dumps(data["telemetry_summary"]),
|
| 105 |
-
data.get("created_at", datetime.now(timezone.utc).isoformat())
|
| 106 |
-
)
|
| 107 |
-
)
|
| 108 |
-
count += 1
|
| 109 |
-
except Exception as e:
|
| 110 |
-
log.error("sync_file_failed", file=file_path.name, error=str(e))
|
| 111 |
-
|
| 112 |
-
await db.commit()
|
| 113 |
-
log.info("sync_complete", count=count)
|
| 114 |
-
return count
|
| 115 |
-
|
| 116 |
-
async def validate_context(ctx: BenchmarkContext) -> ValidationReport:
|
| 117 |
-
"""
|
| 118 |
-
Validate model β dataset β hardware compatibility.
|
| 119 |
-
Does NOT create a job. Safe to call repeatedly from the UI.
|
| 120 |
-
"""
|
| 121 |
-
model = await _require_model(ctx.model_id)
|
| 122 |
-
|
| 123 |
-
# ββ Handle Polymorphic Input (Video/Live) οΏ½οΏ½βββββββββββββββββββββββββββββββ
|
| 124 |
-
if ctx.input_source in ["video", "live"] or ctx.dataset_id == "none":
|
| 125 |
-
# Create a synthetic dataset object for non-dataset sources
|
| 126 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 127 |
-
dataset = Dataset(
|
| 128 |
-
id="none",
|
| 129 |
-
name="Live/Video Stream",
|
| 130 |
-
task=model.task, # Match model task to pass task check
|
| 131 |
-
format="custom",
|
| 132 |
-
source="local",
|
| 133 |
-
status="imported",
|
| 134 |
-
images=0,
|
| 135 |
-
classes=0,
|
| 136 |
-
size_label="0 MB",
|
| 137 |
-
created_at=now,
|
| 138 |
-
updated_at=now
|
| 139 |
-
)
|
| 140 |
-
else:
|
| 141 |
-
dataset = await _require_dataset(ctx.dataset_id)
|
| 142 |
-
|
| 143 |
-
return _validator.validate(model, dataset, ctx)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
async def create_and_run(ctx: BenchmarkContext) -> BenchmarkJob:
|
| 147 |
-
"""
|
| 148 |
-
Full benchmark initiation:
|
| 149 |
-
"""
|
| 150 |
-
model = await _require_model(ctx.model_id)
|
| 151 |
-
|
| 152 |
-
# ββ Handle Polymorphic Input (Video/Live) ββββββββββββββββββββββββββββββββ
|
| 153 |
-
if ctx.input_source in ["video", "live"] or ctx.dataset_id == "none":
|
| 154 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 155 |
-
dataset = Dataset(
|
| 156 |
-
id="none",
|
| 157 |
-
name="Live/Video Stream",
|
| 158 |
-
task=model.task,
|
| 159 |
-
format="custom",
|
| 160 |
-
source="local",
|
| 161 |
-
status="imported",
|
| 162 |
-
images=0,
|
| 163 |
-
classes=0,
|
| 164 |
-
size_label="0 MB",
|
| 165 |
-
created_at=now,
|
| 166 |
-
updated_at=now
|
| 167 |
-
)
|
| 168 |
-
else:
|
| 169 |
-
dataset = await _require_dataset(ctx.dataset_id)
|
| 170 |
-
|
| 171 |
-
# ββ Compatibility check βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 172 |
-
report = _validator.validate(model, dataset, ctx)
|
| 173 |
-
|
| 174 |
-
# Always persist the validation log (even for failures)
|
| 175 |
-
await bench_reg.save_validation_log(
|
| 176 |
-
job_id = "pre-check",
|
| 177 |
-
model_id = ctx.model_id,
|
| 178 |
-
dataset_id = ctx.dataset_id,
|
| 179 |
-
checks = report.checks,
|
| 180 |
-
passed = report.passed,
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
if not report.passed:
|
| 184 |
-
from fastapi import HTTPException
|
| 185 |
-
failed = [c for c in report.checks if not c.passed]
|
| 186 |
-
raise HTTPException(
|
| 187 |
-
status_code = 422,
|
| 188 |
-
detail = {
|
| 189 |
-
"error": "Compatibility validation failed",
|
| 190 |
-
"failed_checks": [
|
| 191 |
-
{
|
| 192 |
-
"name": c.name,
|
| 193 |
-
"detail": c.detail,
|
| 194 |
-
"suggestion": c.suggestion,
|
| 195 |
-
}
|
| 196 |
-
for c in failed
|
| 197 |
-
],
|
| 198 |
-
},
|
| 199 |
-
)
|
| 200 |
-
|
| 201 |
-
# ββ Create job ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 202 |
-
job = await bench_reg.create_job(ctx)
|
| 203 |
-
|
| 204 |
-
# Overwrite 'pre-check' validation log with the real job_id
|
| 205 |
-
await bench_reg.save_validation_log(
|
| 206 |
-
job_id = job.id,
|
| 207 |
-
model_id = ctx.model_id,
|
| 208 |
-
dataset_id = ctx.dataset_id,
|
| 209 |
-
checks = report.checks,
|
| 210 |
-
passed = True,
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
# ββ Log the Polymorphic Input params βββββββββββββββββββββββββββββββββββββ
|
| 214 |
-
if ctx.input_source or ctx.video_path or ctx.rtsp_url:
|
| 215 |
-
log.info("polymorphic_input_received",
|
| 216 |
-
job_id=job.id,
|
| 217 |
-
source=ctx.input_source,
|
| 218 |
-
video=ctx.video_path,
|
| 219 |
-
rtsp=ctx.rtsp_url)
|
| 220 |
-
|
| 221 |
-
# ββ Enqueue background execution ββββββββββββββββββββββββββββββββββββββββββ
|
| 222 |
-
task = asyncio.create_task(
|
| 223 |
-
_execute_job(job.id, ctx, model, dataset),
|
| 224 |
-
name = f"benchmark_{job.id}",
|
| 225 |
-
)
|
| 226 |
-
_active_tasks[job.id] = task
|
| 227 |
-
task.add_done_callback(lambda _t: _active_tasks.pop(job.id, None))
|
| 228 |
-
|
| 229 |
-
log.info("benchmark_enqueued", job_id=job.id, model=ctx.model_id)
|
| 230 |
-
return job
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
# ββ Background execution ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 234 |
-
|
| 235 |
-
async def _execute_job(
|
| 236 |
-
job_id: str,
|
| 237 |
-
ctx: BenchmarkContext,
|
| 238 |
-
model: Model,
|
| 239 |
-
dataset: Dataset,
|
| 240 |
-
) -> None:
|
| 241 |
-
"""Full benchmark lifecycle β runs in an asyncio background task."""
|
| 242 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 243 |
-
|
| 244 |
-
# Transition β running
|
| 245 |
-
ts_color = "\x1b[36m" # Cyan
|
| 246 |
-
info_color = "\x1b[34m" # Blue
|
| 247 |
-
success_color = "\x1b[32m" # Green
|
| 248 |
-
reset = "\x1b[0m"
|
| 249 |
-
|
| 250 |
-
await bench_reg.update_job(
|
| 251 |
-
job_id,
|
| 252 |
-
status = "running",
|
| 253 |
-
progress = 0.0,
|
| 254 |
-
started_at = now,
|
| 255 |
-
log_entry = f"{ts_color}[{now}]{reset} {info_color}Job started{reset} on {ctx.hardware} ({ctx.precision})",
|
| 256 |
-
)
|
| 257 |
-
|
| 258 |
-
runner = BenchmarkExecutor()
|
| 259 |
-
|
| 260 |
-
try:
|
| 261 |
-
# ββ Fetch the persisted job (for executor) ββββββββββββββββββββββββββββ
|
| 262 |
-
job = await bench_reg.get_job(job_id)
|
| 263 |
-
assert job is not None, "Job disappeared from DB after creation"
|
| 264 |
-
|
| 265 |
-
# ββ Define Progress Callback ββββββββββββββββββββββββββββββββββββββββββ
|
| 266 |
-
async def on_progress(progress: float, message: str, telemetry: Any | None):
|
| 267 |
-
await bench_reg.update_job(
|
| 268 |
-
job_id,
|
| 269 |
-
progress=progress,
|
| 270 |
-
log_entry=f"{ts_color}[{datetime.now(timezone.utc).isoformat()}]{reset} {info_color}{message}{reset}",
|
| 271 |
-
last_telemetry=telemetry.model_dump() if telemetry and hasattr(telemetry, "model_dump") else telemetry
|
| 272 |
-
)
|
| 273 |
-
|
| 274 |
-
# ββ Execution Loop ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 275 |
-
exec_result = await runner.execute(
|
| 276 |
-
job=job,
|
| 277 |
-
model=model,
|
| 278 |
-
dataset=dataset,
|
| 279 |
-
on_progress=on_progress
|
| 280 |
-
)
|
| 281 |
-
|
| 282 |
-
# ββ Compute metrics βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 283 |
-
metrics = _metrics.compute(
|
| 284 |
-
task = ctx.task,
|
| 285 |
-
latencies_ms = exec_result.latencies_ms,
|
| 286 |
-
total_images = exec_result.total_images,
|
| 287 |
-
batch_size = ctx.batch_size,
|
| 288 |
-
vram_samples = exec_result.vram_samples,
|
| 289 |
-
task_scores = exec_result.task_scores,
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
# ββ Persist result ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 293 |
-
await bench_reg.save_result(
|
| 294 |
-
job_id = job_id,
|
| 295 |
-
metrics = metrics,
|
| 296 |
-
telemetry_summary = exec_result.telemetry_summary,
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
ended = datetime.now(timezone.utc).isoformat()
|
| 300 |
-
await bench_reg.update_job(
|
| 301 |
-
job_id,
|
| 302 |
-
status = "completed",
|
| 303 |
-
progress = 1.0,
|
| 304 |
-
ended_at = ended,
|
| 305 |
-
log_entry = f"{ts_color}[{ended}]{reset} {success_color}Benchmark completed{reset} β {metrics.fps} FPS",
|
| 306 |
-
)
|
| 307 |
-
|
| 308 |
-
await audit(
|
| 309 |
-
"benchmark_completed",
|
| 310 |
-
job_id = job_id,
|
| 311 |
-
payload = {"model_id": ctx.model_id, "dataset_id": ctx.dataset_id},
|
| 312 |
-
)
|
| 313 |
-
log.info(
|
| 314 |
-
"benchmark_completed",
|
| 315 |
-
job_id = job_id,
|
| 316 |
-
fps = metrics.fps,
|
| 317 |
-
lat_ms = metrics.latency_mean_ms,
|
| 318 |
-
)
|
| 319 |
-
|
| 320 |
-
except asyncio.CancelledError:
|
| 321 |
-
# Task cancelled externally (e.g. server shutdown) β don't swallow
|
| 322 |
-
ended = datetime.now(timezone.utc).isoformat()
|
| 323 |
-
await bench_reg.update_job(
|
| 324 |
-
job_id,
|
| 325 |
-
status = "failed",
|
| 326 |
-
error = "Job cancelled",
|
| 327 |
-
ended_at = ended,
|
| 328 |
-
log_entry = f"{ts_color}[{ended}]{reset} \x1b[31mJob cancelled\x1b[0m",
|
| 329 |
-
)
|
| 330 |
-
raise
|
| 331 |
-
|
| 332 |
-
except Exception as exc:
|
| 333 |
-
ended = datetime.now(timezone.utc).isoformat()
|
| 334 |
-
err_msg = str(exc)
|
| 335 |
-
error_color = "\x1b[31m" # Red
|
| 336 |
-
await bench_reg.update_job(
|
| 337 |
-
job_id,
|
| 338 |
-
status = "failed",
|
| 339 |
-
error = err_msg,
|
| 340 |
-
ended_at = ended,
|
| 341 |
-
log_entry = f"{ts_color}[{ended}]{reset} {error_color}ERROR: {err_msg}{reset}",
|
| 342 |
-
)
|
| 343 |
-
await audit(
|
| 344 |
-
"benchmark_failed",
|
| 345 |
-
job_id = job_id,
|
| 346 |
-
level = "error",
|
| 347 |
-
payload = {"error": err_msg, "model_id": ctx.model_id},
|
| 348 |
-
)
|
| 349 |
-
log.exception("benchmark_failed", job_id=job_id)
|
| 350 |
-
finally:
|
| 351 |
-
pass
|
| 352 |
-
|
| 353 |
-
# ββ Resource resolvers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 354 |
-
|
| 355 |
-
async def _require_model(model_id: str) -> Model:
|
| 356 |
-
model = await get_model(model_id)
|
| 357 |
-
if not model:
|
| 358 |
-
from fastapi import HTTPException
|
| 359 |
-
raise HTTPException(
|
| 360 |
-
status_code = 404,
|
| 361 |
-
detail = f"Model '{model_id}' not found in Model Zoo",
|
| 362 |
-
)
|
| 363 |
-
return model
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
async def _require_dataset(dataset_id: str) -> Dataset:
|
| 367 |
-
dataset = await get_dataset(dataset_id)
|
| 368 |
-
if not dataset:
|
| 369 |
-
from fastapi import HTTPException
|
| 370 |
-
raise HTTPException(
|
| 371 |
-
status_code = 404,
|
| 372 |
-
detail = f"Dataset '{dataset_id}' not found in Dataset Manager",
|
| 373 |
-
)
|
| 374 |
-
return dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/registry.py
DELETED
|
@@ -1,302 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/registry.py β Benchmark Registry.
|
| 3 |
-
|
| 4 |
-
All DB interactions for:
|
| 5 |
-
β’ benchmark_jobs β job lifecycle state
|
| 6 |
-
β’ benchmark_results β final metrics + telemetry summary
|
| 7 |
-
β’ benchmark_validation_logs β immutable check audit trail
|
| 8 |
-
|
| 9 |
-
Follows the same pattern as registry/registry.py and datasets/registry.py.
|
| 10 |
-
No direct DB access from other benchmark modules β everything routes here.
|
| 11 |
-
"""
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
import json
|
| 15 |
-
import uuid
|
| 16 |
-
from datetime import datetime, timezone
|
| 17 |
-
from typing import Any
|
| 18 |
-
from pathlib import Path
|
| 19 |
-
|
| 20 |
-
from database.connection import get_db
|
| 21 |
-
from models.benchmark import (
|
| 22 |
-
BenchmarkContext,
|
| 23 |
-
BenchmarkJob,
|
| 24 |
-
BenchmarkMetrics,
|
| 25 |
-
BenchmarkResult,
|
| 26 |
-
TelemetrySummary,
|
| 27 |
-
ValidationCheck,
|
| 28 |
-
row_to_job,
|
| 29 |
-
row_to_result,
|
| 30 |
-
)
|
| 31 |
-
from observability.logger import get_logger
|
| 32 |
-
|
| 33 |
-
log = get_logger("benchmark.registry")
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def _get_active_project_benchmark_dir_sync(project_path: str | None) -> Path | None:
|
| 37 |
-
"""Get the absolute path to the 'benchmarks' folder in a given project path."""
|
| 38 |
-
if not project_path:
|
| 39 |
-
return None
|
| 40 |
-
|
| 41 |
-
benchmark_dir = Path(project_path) / "benchmarks"
|
| 42 |
-
benchmark_dir.mkdir(parents=True, exist_ok=True)
|
| 43 |
-
return benchmark_dir
|
| 44 |
-
|
| 45 |
-
async def _get_active_project_benchmark_dir() -> Path | None:
|
| 46 |
-
"""Get the absolute path to the 'benchmarks' folder in the active project."""
|
| 47 |
-
from projects.service import get_active_project_path
|
| 48 |
-
project_path = await get_active_project_path()
|
| 49 |
-
return _get_active_project_benchmark_dir_sync(project_path)
|
| 50 |
-
|
| 51 |
-
async def _save_to_project(filename: str, data: dict) -> None:
|
| 52 |
-
"""Save data to a JSON file in the active project's benchmark folder."""
|
| 53 |
-
benchmark_dir = await _get_active_project_benchmark_dir()
|
| 54 |
-
if not benchmark_dir:
|
| 55 |
-
return
|
| 56 |
-
|
| 57 |
-
file_path = benchmark_dir / filename
|
| 58 |
-
try:
|
| 59 |
-
with open(file_path, "w") as f:
|
| 60 |
-
json.dump(data, f, indent=2)
|
| 61 |
-
except Exception as e:
|
| 62 |
-
log.error("project_persistence_failed", error=str(e), file=filename)
|
| 63 |
-
|
| 64 |
-
# ββ Job CRUD ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 65 |
-
|
| 66 |
-
async def create_job(ctx: BenchmarkContext) -> BenchmarkJob:
|
| 67 |
-
db = await get_db()
|
| 68 |
-
job_id = f"bmark-{uuid.uuid4().hex[:12]}"
|
| 69 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 70 |
-
|
| 71 |
-
# Create job object
|
| 72 |
-
job = BenchmarkJob(
|
| 73 |
-
id = job_id,
|
| 74 |
-
model_id = ctx.model_id,
|
| 75 |
-
dataset_id = ctx.dataset_id,
|
| 76 |
-
task = ctx.task,
|
| 77 |
-
framework = ctx.framework,
|
| 78 |
-
hardware = ctx.hardware,
|
| 79 |
-
precision = ctx.precision,
|
| 80 |
-
batch_size = ctx.batch_size,
|
| 81 |
-
config = ctx.model_dump(),
|
| 82 |
-
status = "queued",
|
| 83 |
-
progress = 0.0,
|
| 84 |
-
created_at = now,
|
| 85 |
-
updated_at = now,
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
# Persist to SQLite
|
| 89 |
-
await db.execute(
|
| 90 |
-
"""INSERT INTO benchmark_jobs
|
| 91 |
-
(id, model_id, dataset_id, task, framework, hardware,
|
| 92 |
-
precision, batch_size, config,
|
| 93 |
-
status, progress, logs, created_at, updated_at)
|
| 94 |
-
VALUES (?,?,?,?,?,?,?,?,?,'queued',0.0,'[]',?,?)""",
|
| 95 |
-
(
|
| 96 |
-
job_id,
|
| 97 |
-
ctx.model_id, ctx.dataset_id,
|
| 98 |
-
ctx.task, ctx.framework, ctx.hardware,
|
| 99 |
-
ctx.precision, ctx.batch_size,
|
| 100 |
-
json.dumps(ctx.model_dump()),
|
| 101 |
-
now, now,
|
| 102 |
-
),
|
| 103 |
-
)
|
| 104 |
-
await db.commit()
|
| 105 |
-
|
| 106 |
-
# Persist to project folder
|
| 107 |
-
await _save_to_project(f"job_{job_id}.json", job.model_dump())
|
| 108 |
-
|
| 109 |
-
log.info("benchmark_job_created", job_id=job_id, model=ctx.model_id)
|
| 110 |
-
return job
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
async def get_job(job_id: str) -> BenchmarkJob | None:
|
| 114 |
-
db = await get_db()
|
| 115 |
-
async with db.execute(
|
| 116 |
-
"SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,)
|
| 117 |
-
) as cur:
|
| 118 |
-
row = await cur.fetchone()
|
| 119 |
-
return row_to_job(row) if row else None
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
async def list_jobs(
|
| 123 |
-
*,
|
| 124 |
-
status: str | None = None,
|
| 125 |
-
model_id: str | None = None,
|
| 126 |
-
limit: int = 100,
|
| 127 |
-
) -> list[BenchmarkJob]:
|
| 128 |
-
db = await get_db()
|
| 129 |
-
clauses: list[str] = []
|
| 130 |
-
params: list[Any] = []
|
| 131 |
-
|
| 132 |
-
if status:
|
| 133 |
-
clauses.append("status = ?")
|
| 134 |
-
params.append(status)
|
| 135 |
-
if model_id:
|
| 136 |
-
clauses.append("model_id = ?")
|
| 137 |
-
params.append(model_id)
|
| 138 |
-
|
| 139 |
-
where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
|
| 140 |
-
params.append(limit)
|
| 141 |
-
|
| 142 |
-
async with db.execute(
|
| 143 |
-
f"SELECT * FROM benchmark_jobs {where} ORDER BY created_at DESC LIMIT ?",
|
| 144 |
-
params,
|
| 145 |
-
) as cur:
|
| 146 |
-
rows = await cur.fetchall()
|
| 147 |
-
return [row_to_job(r) for r in rows]
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
async def update_job(
|
| 151 |
-
job_id: str,
|
| 152 |
-
*,
|
| 153 |
-
status: str | None = None,
|
| 154 |
-
progress: float | None = None,
|
| 155 |
-
error: str | None = None,
|
| 156 |
-
started_at: str | None = None,
|
| 157 |
-
ended_at: str | None = None,
|
| 158 |
-
log_entry: str | None = None,
|
| 159 |
-
last_telemetry: dict | None = None,
|
| 160 |
-
) -> None:
|
| 161 |
-
"""Update mutable fields on a benchmark job atomically."""
|
| 162 |
-
db = await get_db()
|
| 163 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 164 |
-
|
| 165 |
-
sets: list[str] = ["updated_at = ?"]
|
| 166 |
-
vals: list[Any] = [now]
|
| 167 |
-
|
| 168 |
-
if status is not None:
|
| 169 |
-
sets.append("status = ?"); vals.append(status)
|
| 170 |
-
if progress is not None:
|
| 171 |
-
sets.append("progress = ?"); vals.append(round(progress, 4))
|
| 172 |
-
if error is not None:
|
| 173 |
-
sets.append("error = ?"); vals.append(error)
|
| 174 |
-
if started_at is not None:
|
| 175 |
-
sets.append("started_at = ?"); vals.append(started_at)
|
| 176 |
-
if ended_at is not None:
|
| 177 |
-
sets.append("ended_at = ?"); vals.append(ended_at)
|
| 178 |
-
if last_telemetry is not None:
|
| 179 |
-
sets.append("last_telemetry = ?"); vals.append(json.dumps(last_telemetry))
|
| 180 |
-
|
| 181 |
-
if log_entry is not None:
|
| 182 |
-
# Append new entry to the JSON log array (capped at 500 lines)
|
| 183 |
-
async with db.execute(
|
| 184 |
-
"SELECT logs FROM benchmark_jobs WHERE id = ?", (job_id,)
|
| 185 |
-
) as cur:
|
| 186 |
-
row = await cur.fetchone()
|
| 187 |
-
existing = json.loads(row["logs"]) if row and row["logs"] else []
|
| 188 |
-
existing.append(log_entry)
|
| 189 |
-
sets.append("logs = ?")
|
| 190 |
-
vals.append(json.dumps(existing[-500:]))
|
| 191 |
-
|
| 192 |
-
vals.append(job_id)
|
| 193 |
-
# Persist to project folder if we have the job info
|
| 194 |
-
async with db.execute("SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,)) as cur:
|
| 195 |
-
row = await cur.fetchone()
|
| 196 |
-
if row:
|
| 197 |
-
job = row_to_job(row)
|
| 198 |
-
if job:
|
| 199 |
-
await _save_to_project(f"job_{job_id}.json", job.model_dump())
|
| 200 |
-
|
| 201 |
-
await db.commit()
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
# ββ Result CRUD βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 205 |
-
|
| 206 |
-
async def save_result(
|
| 207 |
-
*,
|
| 208 |
-
job_id: str,
|
| 209 |
-
metrics: BenchmarkMetrics,
|
| 210 |
-
telemetry_summary: TelemetrySummary,
|
| 211 |
-
) -> BenchmarkResult:
|
| 212 |
-
db = await get_db()
|
| 213 |
-
result_id = f"bres-{uuid.uuid4().hex[:12]}"
|
| 214 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 215 |
-
|
| 216 |
-
# Persist result to SQLite
|
| 217 |
-
await db.execute(
|
| 218 |
-
"""INSERT INTO benchmark_results
|
| 219 |
-
(id, job_id, metrics, telemetry_summary, created_at)
|
| 220 |
-
VALUES (?,?,?,?,?)""",
|
| 221 |
-
(
|
| 222 |
-
result_id,
|
| 223 |
-
job_id,
|
| 224 |
-
json.dumps(metrics.model_dump(exclude_none=True)),
|
| 225 |
-
json.dumps(telemetry_summary.model_dump()),
|
| 226 |
-
now,
|
| 227 |
-
),
|
| 228 |
-
)
|
| 229 |
-
await db.commit()
|
| 230 |
-
|
| 231 |
-
result = BenchmarkResult(
|
| 232 |
-
id = result_id,
|
| 233 |
-
job_id = job_id,
|
| 234 |
-
metrics = metrics,
|
| 235 |
-
telemetry_summary = telemetry_summary,
|
| 236 |
-
created_at = now,
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
# Persist result to project folder
|
| 240 |
-
await _save_to_project(f"result_{job_id}.json", result.model_dump())
|
| 241 |
-
|
| 242 |
-
log.info("benchmark_result_saved", job_id=job_id, result_id=result_id)
|
| 243 |
-
return result
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
async def get_result(job_id: str) -> BenchmarkResult | None:
|
| 247 |
-
db = await get_db()
|
| 248 |
-
async with db.execute(
|
| 249 |
-
"""SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
|
| 250 |
-
FROM benchmark_results r
|
| 251 |
-
JOIN benchmark_jobs j ON r.job_id = j.id
|
| 252 |
-
WHERE r.job_id = ?""", (job_id,)
|
| 253 |
-
) as cur:
|
| 254 |
-
row = await cur.fetchone()
|
| 255 |
-
return row_to_result(row) if row else None
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
async def list_results(*, limit: int = 100) -> list[BenchmarkResult]:
|
| 259 |
-
db = await get_db()
|
| 260 |
-
async with db.execute(
|
| 261 |
-
"""SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
|
| 262 |
-
FROM benchmark_results r
|
| 263 |
-
JOIN benchmark_jobs j ON r.job_id = j.id
|
| 264 |
-
ORDER BY r.created_at DESC LIMIT ?""", (limit,)
|
| 265 |
-
) as cur:
|
| 266 |
-
rows = await cur.fetchall()
|
| 267 |
-
return [row_to_result(r) for r in rows]
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
# ββ Validation Log ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 271 |
-
|
| 272 |
-
async def save_validation_log(
|
| 273 |
-
*,
|
| 274 |
-
job_id: str,
|
| 275 |
-
model_id: str,
|
| 276 |
-
dataset_id: str,
|
| 277 |
-
checks: list[ValidationCheck],
|
| 278 |
-
passed: bool,
|
| 279 |
-
) -> None:
|
| 280 |
-
"""Persist an immutable record of all compatibility checks."""
|
| 281 |
-
db = await get_db()
|
| 282 |
-
log_id = f"bval-{uuid.uuid4().hex[:12]}"
|
| 283 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 284 |
-
|
| 285 |
-
await db.execute(
|
| 286 |
-
"""INSERT INTO benchmark_validation_logs
|
| 287 |
-
(id, job_id, model_id, dataset_id, checks, passed, created_at)
|
| 288 |
-
VALUES (?,?,?,?,?,?,?)""",
|
| 289 |
-
(
|
| 290 |
-
log_id, job_id, model_id, dataset_id,
|
| 291 |
-
json.dumps([c.model_dump() for c in checks]),
|
| 292 |
-
1 if passed else 0,
|
| 293 |
-
now,
|
| 294 |
-
),
|
| 295 |
-
)
|
| 296 |
-
await db.commit()
|
| 297 |
-
log.info(
|
| 298 |
-
"validation_log_saved",
|
| 299 |
-
job_id = job_id,
|
| 300 |
-
passed = passed,
|
| 301 |
-
n_checks = len(checks),
|
| 302 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/telemetry.py
DELETED
|
@@ -1,182 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/telemetry.py β Real-time Telemetry Collector.
|
| 3 |
-
|
| 4 |
-
Collects GPU/hardware metrics at 2 Hz during benchmark execution.
|
| 5 |
-
Designed as a drop-in adapter:
|
| 6 |
-
β’ Local dev β simulates realistic GPU readings based on hardware tier
|
| 7 |
-
β’ Production β replace _read_gpu_metrics() with pynvml calls:
|
| 8 |
-
nvmlDeviceGetUtilizationRates()
|
| 9 |
-
nvmlDeviceGetMemoryInfo()
|
| 10 |
-
nvmlDeviceGetTemperature()
|
| 11 |
-
nvmlDeviceGetPowerUsage()
|
| 12 |
-
|
| 13 |
-
Usage (async context):
|
| 14 |
-
collector = TelemetryCollector("rtx4090", vram_total_gb=24.0)
|
| 15 |
-
await collector.start()
|
| 16 |
-
# ... run inference ...
|
| 17 |
-
summary = await collector.stop()
|
| 18 |
-
samples = collector.samples
|
| 19 |
-
"""
|
| 20 |
-
from __future__ import annotations
|
| 21 |
-
|
| 22 |
-
import asyncio
|
| 23 |
-
import random
|
| 24 |
-
import statistics
|
| 25 |
-
import time
|
| 26 |
-
|
| 27 |
-
from models.benchmark import TelemetrySample, TelemetrySummary
|
| 28 |
-
from observability.logger import get_logger
|
| 29 |
-
|
| 30 |
-
log = get_logger("benchmark.telemetry")
|
| 31 |
-
|
| 32 |
-
# ββ Hardware simulation profiles ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
-
# (base_util%, base_temp_C, base_power_W)
|
| 34 |
-
_HW_PROFILES: dict[str, tuple[float, float, float]] = {
|
| 35 |
-
"rtx4090": (88.0, 74.0, 380.0),
|
| 36 |
-
"rtx4080": (84.0, 70.0, 280.0),
|
| 37 |
-
"rtx4070": (80.0, 68.0, 200.0),
|
| 38 |
-
"rtx3090": (85.0, 72.0, 320.0),
|
| 39 |
-
"rtx3080": (82.0, 70.0, 250.0),
|
| 40 |
-
"rtx3070": (78.0, 66.0, 180.0),
|
| 41 |
-
"rtx3060": (74.0, 64.0, 150.0),
|
| 42 |
-
"a100": (90.0, 68.0, 350.0),
|
| 43 |
-
"h100": (92.0, 65.0, 550.0),
|
| 44 |
-
"v100": (87.0, 70.0, 280.0),
|
| 45 |
-
"t4": (75.0, 62.0, 60.0),
|
| 46 |
-
"gpu": (70.0, 65.0, 150.0),
|
| 47 |
-
"cpu": (0.0, 0.0, 0.0),
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
_COLLECTION_INTERVAL_S = 0.5 # 2 Hz
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
class TelemetryCollector:
|
| 54 |
-
"""
|
| 55 |
-
Async telemetry collector. Call start() before inference, stop() after.
|
| 56 |
-
Thread-safe via asyncio (single-threaded event loop).
|
| 57 |
-
"""
|
| 58 |
-
|
| 59 |
-
def __init__(self, hardware: str, vram_total_gb: float = 8.0) -> None:
|
| 60 |
-
self._hardware = hardware
|
| 61 |
-
self._vram_total = vram_total_gb
|
| 62 |
-
self._hw_profile = self._resolve_profile(hardware)
|
| 63 |
-
self._samples: list[TelemetrySample] = []
|
| 64 |
-
self._running = False
|
| 65 |
-
self._task: asyncio.Task | None = None
|
| 66 |
-
|
| 67 |
-
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
-
|
| 69 |
-
async def start(self) -> None:
|
| 70 |
-
self._running = True
|
| 71 |
-
self._samples = []
|
| 72 |
-
self._task = asyncio.create_task(
|
| 73 |
-
self._collect_loop(), name="telemetry_collector"
|
| 74 |
-
)
|
| 75 |
-
log.debug("telemetry_started", hardware=self._hardware)
|
| 76 |
-
|
| 77 |
-
async def stop(self) -> TelemetrySummary:
|
| 78 |
-
self._running = False
|
| 79 |
-
if self._task and not self._task.done():
|
| 80 |
-
self._task.cancel()
|
| 81 |
-
try:
|
| 82 |
-
await self._task
|
| 83 |
-
except asyncio.CancelledError:
|
| 84 |
-
pass
|
| 85 |
-
log.debug(
|
| 86 |
-
"telemetry_stopped",
|
| 87 |
-
hardware = self._hardware,
|
| 88 |
-
samples = len(self._samples),
|
| 89 |
-
)
|
| 90 |
-
return self._build_summary()
|
| 91 |
-
|
| 92 |
-
def record_batch_context(self, batch_idx: int, progress: float) -> None:
|
| 93 |
-
"""Annotate the most recent sample with the current batch context."""
|
| 94 |
-
if self._samples:
|
| 95 |
-
self._samples[-1].batch_idx = batch_idx
|
| 96 |
-
self._samples[-1].progress = progress
|
| 97 |
-
|
| 98 |
-
@property
|
| 99 |
-
def samples(self) -> list[TelemetrySample]:
|
| 100 |
-
return list(self._samples)
|
| 101 |
-
|
| 102 |
-
# ββ Internal ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 103 |
-
|
| 104 |
-
async def _collect_loop(self) -> None:
|
| 105 |
-
while self._running:
|
| 106 |
-
sample = self._read_gpu_metrics()
|
| 107 |
-
self._samples.append(sample)
|
| 108 |
-
await asyncio.sleep(_COLLECTION_INTERVAL_S)
|
| 109 |
-
|
| 110 |
-
def _read_gpu_metrics(self) -> TelemetrySample:
|
| 111 |
-
"""
|
| 112 |
-
Returns a TelemetrySample for the current hardware state.
|
| 113 |
-
|
| 114 |
-
PRODUCTION SWAP: Replace this body with pynvml calls:
|
| 115 |
-
handle = nvmlDeviceGetHandleByIndex(0)
|
| 116 |
-
util = nvmlDeviceGetUtilizationRates(handle)
|
| 117 |
-
mem = nvmlDeviceGetMemoryInfo(handle)
|
| 118 |
-
temp = nvmlDeviceGetTemperature(handle, NVML_TEMPERATURE_GPU)
|
| 119 |
-
power = nvmlDeviceGetPowerUsage(handle) / 1000 # mW β W
|
| 120 |
-
"""
|
| 121 |
-
base_util, base_temp, base_power = self._hw_profile
|
| 122 |
-
|
| 123 |
-
if base_util == 0.0: # CPU path β no meaningful GPU readings
|
| 124 |
-
return TelemetrySample(
|
| 125 |
-
timestamp = time.time(),
|
| 126 |
-
gpu_util_pct = 0.0,
|
| 127 |
-
vram_used_gb = 0.0,
|
| 128 |
-
vram_total_gb = 0.0,
|
| 129 |
-
temp_c = 0.0,
|
| 130 |
-
power_w = 0.0,
|
| 131 |
-
)
|
| 132 |
-
|
| 133 |
-
# Simulate realistic jitter (Β±5% util, Β±3Β°C, Β±10W)
|
| 134 |
-
jitter_util = random.gauss(0, 3.0)
|
| 135 |
-
jitter_temp = random.gauss(0, 1.5)
|
| 136 |
-
jitter_power = random.gauss(0, 8.0)
|
| 137 |
-
vram_frac = random.uniform(0.58, 0.72)
|
| 138 |
-
|
| 139 |
-
return TelemetrySample(
|
| 140 |
-
timestamp = time.time(),
|
| 141 |
-
gpu_util_pct = max(0.0, min(100.0, base_util + jitter_util)),
|
| 142 |
-
vram_used_gb = round(
|
| 143 |
-
max(0.0, min(self._vram_total, self._vram_total * vram_frac)), 3
|
| 144 |
-
),
|
| 145 |
-
vram_total_gb = self._vram_total,
|
| 146 |
-
temp_c = round(max(0.0, base_temp + jitter_temp), 1),
|
| 147 |
-
power_w = round(max(0.0, base_power + jitter_power), 1),
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
def _build_summary(self) -> TelemetrySummary:
|
| 151 |
-
if not self._samples:
|
| 152 |
-
return TelemetrySummary()
|
| 153 |
-
|
| 154 |
-
utils = [s.gpu_util_pct for s in self._samples]
|
| 155 |
-
vrams = [s.vram_used_gb for s in self._samples]
|
| 156 |
-
temps = [s.temp_c for s in self._samples]
|
| 157 |
-
powers = [s.power_w for s in self._samples]
|
| 158 |
-
|
| 159 |
-
def _safe_mean(lst: list[float]) -> float:
|
| 160 |
-
return statistics.mean(lst) if lst else 0.0
|
| 161 |
-
|
| 162 |
-
return TelemetrySummary(
|
| 163 |
-
gpu_util_avg = round(_safe_mean(utils), 2),
|
| 164 |
-
gpu_util_peak = round(max(utils), 2),
|
| 165 |
-
vram_avg_gb = round(_safe_mean(vrams), 3),
|
| 166 |
-
vram_peak_gb = round(max(vrams), 3),
|
| 167 |
-
temp_avg_c = round(_safe_mean(temps), 1),
|
| 168 |
-
temp_peak_c = round(max(temps), 1),
|
| 169 |
-
power_avg_w = round(_safe_mean(powers), 1),
|
| 170 |
-
power_peak_w = round(max(powers), 1),
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
@staticmethod
|
| 174 |
-
def _resolve_profile(hardware: str) -> tuple[float, float, float]:
|
| 175 |
-
hw = hardware.lower().replace(" ", "").replace("-", "")
|
| 176 |
-
for key, profile in _HW_PROFILES.items():
|
| 177 |
-
if key in hw:
|
| 178 |
-
return profile
|
| 179 |
-
# Default for unknown GPU-class hardware
|
| 180 |
-
if any(x in hw for x in ("gpu", "rtx", "gtx", "cuda", "vram")):
|
| 181 |
-
return _HW_PROFILES["gpu"]
|
| 182 |
-
return _HW_PROFILES["cpu"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmark/torch_runner.py
DELETED
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
benchmark/torch_runner.py β Synchronous GPU inference runner.
|
| 3 |
-
|
| 4 |
-
Called from BenchmarkExecutor via asyncio.run_in_executor() so it never
|
| 5 |
-
blocks the event loop. PyTorch is an optional dependency β if it is not
|
| 6 |
-
installed the module raises ImportError and execution.py falls back to
|
| 7 |
-
the simulation path.
|
| 8 |
-
|
| 9 |
-
Supported weight formats (detected by file extension):
|
| 10 |
-
.pt / .pth β torch.load (TorchScript or state-dict)
|
| 11 |
-
.safetensors β safetensors.torch.load_file
|
| 12 |
-
.onnx β onnxruntime InferenceSession
|
| 13 |
-
|
| 14 |
-
PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
|
| 15 |
-
"""
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
import time
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
from typing import Any
|
| 21 |
-
|
| 22 |
-
# ββ Model cache (keyed by absolute path) βββββββββββββββββββββββββββββββββββββ
|
| 23 |
-
_MODEL_CACHE: dict[str, Any] = {}
|
| 24 |
-
|
| 25 |
-
# Standard input shapes per task (B, C, H, W)
|
| 26 |
-
_INPUT_SHAPES: dict[str, tuple[int, int, int]] = {
|
| 27 |
-
"detection": (3, 640, 640),
|
| 28 |
-
"segmentation": (3, 640, 640),
|
| 29 |
-
"classification": (3, 224, 224),
|
| 30 |
-
"generation": (3, 512, 512),
|
| 31 |
-
"embedding": (3, 224, 224),
|
| 32 |
-
}
|
| 33 |
-
_DEFAULT_SHAPE = (3, 640, 640)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def run_torch_batch(model_path: str, batch_size: int, task: str = "detection") -> float:
|
| 37 |
-
"""Run one inference batch and return per-image latency in ms.
|
| 38 |
-
|
| 39 |
-
Args:
|
| 40 |
-
model_path: Absolute path to the weight file.
|
| 41 |
-
batch_size: Number of images in the batch.
|
| 42 |
-
task: Model task (affects dummy input shape).
|
| 43 |
-
|
| 44 |
-
Returns:
|
| 45 |
-
Latency per image in milliseconds.
|
| 46 |
-
"""
|
| 47 |
-
import torch # raises ImportError if not installed
|
| 48 |
-
|
| 49 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 50 |
-
ext = Path(model_path).suffix.lower()
|
| 51 |
-
|
| 52 |
-
model = _load_model(model_path, ext, device)
|
| 53 |
-
c, h, w = _INPUT_SHAPES.get(task, _DEFAULT_SHAPE)
|
| 54 |
-
dummy = torch.zeros(batch_size, c, h, w, device=device)
|
| 55 |
-
|
| 56 |
-
# Warm-up pass (first call is slower due to CUDA kernel compilation)
|
| 57 |
-
if device == "cuda":
|
| 58 |
-
with torch.no_grad():
|
| 59 |
-
_forward(model, dummy, ext, device)
|
| 60 |
-
torch.cuda.synchronize()
|
| 61 |
-
|
| 62 |
-
# Timed pass
|
| 63 |
-
if device == "cuda":
|
| 64 |
-
torch.cuda.synchronize()
|
| 65 |
-
t0 = time.perf_counter()
|
| 66 |
-
with torch.no_grad():
|
| 67 |
-
_forward(model, dummy, ext, device)
|
| 68 |
-
if device == "cuda":
|
| 69 |
-
torch.cuda.synchronize()
|
| 70 |
-
elapsed_ms = (time.perf_counter() - t0) * 1000
|
| 71 |
-
|
| 72 |
-
return elapsed_ms / batch_size
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def _load_model(path: str, ext: str, device: str) -> Any:
|
| 76 |
-
"""Load and cache the model by absolute path."""
|
| 77 |
-
if path in _MODEL_CACHE:
|
| 78 |
-
return _MODEL_CACHE[path]
|
| 79 |
-
|
| 80 |
-
model = _load_by_ext(path, ext, device)
|
| 81 |
-
_MODEL_CACHE[path] = model
|
| 82 |
-
return model
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def _load_by_ext(path: str, ext: str, device: str) -> Any:
|
| 86 |
-
"""Select loader based on file extension."""
|
| 87 |
-
if ext in (".pt", ".pth"):
|
| 88 |
-
return _load_torch(path, device)
|
| 89 |
-
if ext == ".safetensors":
|
| 90 |
-
return _load_safetensors(path, device)
|
| 91 |
-
if ext == ".onnx":
|
| 92 |
-
return _load_onnx(path)
|
| 93 |
-
raise ValueError(f"Unsupported model format: {ext}")
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def _load_torch(path: str, device: str) -> Any:
|
| 97 |
-
import torch
|
| 98 |
-
# <<< REPLACE IN PRODUCTION >>> with proper model class instantiation
|
| 99 |
-
# TorchScript models can be loaded directly; state-dict models need
|
| 100 |
-
# the model class to be imported separately.
|
| 101 |
-
try:
|
| 102 |
-
model = torch.jit.load(path, map_location=device)
|
| 103 |
-
model.eval()
|
| 104 |
-
return model
|
| 105 |
-
except RuntimeError:
|
| 106 |
-
# Not a TorchScript model β try loading as a full checkpoint
|
| 107 |
-
obj = torch.load(path, map_location=device, weights_only=False)
|
| 108 |
-
if hasattr(obj, "eval"):
|
| 109 |
-
obj.eval()
|
| 110 |
-
return obj
|
| 111 |
-
# It's a state-dict β we cannot run inference without knowing the arch
|
| 112 |
-
raise RuntimeError(
|
| 113 |
-
f"Model at {path} is a state-dict; cannot run inference without "
|
| 114 |
-
"the model class. Use a TorchScript-exported .pt file."
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def _load_safetensors(path: str, device: str) -> Any:
|
| 119 |
-
# <<< REPLACE IN PRODUCTION >>> safetensors gives tensors only;
|
| 120 |
-
# you still need the model class. This is intentionally left as a
|
| 121 |
-
# placeholder that raises a clear error rather than silently failing.
|
| 122 |
-
raise NotImplementedError(
|
| 123 |
-
"safetensors inference requires the model class to be registered. "
|
| 124 |
-
"Convert to TorchScript or ONNX for architecture-agnostic inference."
|
| 125 |
-
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def _load_onnx(path: str) -> Any:
|
| 129 |
-
import onnxruntime as ort # type: ignore[import]
|
| 130 |
-
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 131 |
-
return ort.InferenceSession(path, providers=providers)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def _forward(model: Any, dummy: Any, ext: str, device: str) -> Any:
|
| 135 |
-
"""Run a single forward pass, dispatching by model type."""
|
| 136 |
-
if ext == ".onnx":
|
| 137 |
-
import numpy as np
|
| 138 |
-
np_input = dummy.cpu().numpy()
|
| 139 |
-
input_name = model.get_inputs()[0].name
|
| 140 |
-
return model.run(None, {input_name: np_input})
|
| 141 |
-
# TorchScript / nn.Module
|
| 142 |
-
return model(dummy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.py
CHANGED
|
@@ -21,31 +21,15 @@ class Settings(BaseSettings):
|
|
| 21 |
# ββ API βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
host: str = "0.0.0.0"
|
| 23 |
port: int = 7860 # Default for HF Spaces
|
| 24 |
-
cors_origins: list[str] = [
|
| 25 |
-
"http://localhost:3000",
|
| 26 |
-
"http://127.0.0.1:3000",
|
| 27 |
-
"http://localhost:5173",
|
| 28 |
-
"http://127.0.0.1:5173",
|
| 29 |
-
"http://localhost:2000",
|
| 30 |
-
"http://127.0.0.1:2000",
|
| 31 |
-
]
|
| 32 |
|
| 33 |
# ββ Storage βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
base_dir: Path = Path(__file__).resolve().parents[1]
|
| 35 |
data_dir: Path = base_dir / "data"
|
| 36 |
-
|
| 37 |
-
datasets_dir: Path = data_dir / "datasets" # root for imported datasets
|
| 38 |
-
logs_dir: Path = data_dir / "logs"
|
| 39 |
-
db_path: Path = data_dir / "modelzoo.db"
|
| 40 |
-
|
| 41 |
-
# ββ Download Manager ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
-
max_concurrent_downloads: int = 5
|
| 43 |
-
download_chunk_size: int = 1024 * 1024 # 1 MB
|
| 44 |
-
download_max_retries: int = 3
|
| 45 |
-
download_retry_delay: float = 2.0 # seconds (base, exponential backoff)
|
| 46 |
|
| 47 |
# ββ Search ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
-
search_max_results: int =
|
| 49 |
|
| 50 |
# ββ Sync ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
auto_sync_on_startup: bool = True
|
|
@@ -54,30 +38,10 @@ class Settings(BaseSettings):
|
|
| 54 |
hf_api_base: str = "https://huggingface.co/api"
|
| 55 |
hf_hub_url: str = "https://huggingface.co"
|
| 56 |
hf_token: str | None = None # Optional: HF_TOKEN env var
|
| 57 |
-
hf_models_per_task: int =
|
| 58 |
-
|
| 59 |
-
# ββ ONNX Zoo ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 60 |
-
onnx_models_url: str = (
|
| 61 |
-
"https://raw.githubusercontent.com/onnx/models/main/README.md"
|
| 62 |
-
)
|
| 63 |
-
|
| 64 |
-
# ββ Benchmark Bridge ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 65 |
-
benchmark_max_concurrent: int = 3 # max parallel benchmark jobs
|
| 66 |
-
benchmark_max_log_lines: int = 500 # log entries kept per job
|
| 67 |
-
benchmark_ws_poll_hz: float = 2.0 # WebSocket telemetry poll rate
|
| 68 |
-
|
| 69 |
-
# ββ Dataset Manager βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
-
roboflow_api_base: str = "https://api.roboflow.com"
|
| 71 |
-
dataset_import_workers: int = 3 # max concurrent import jobs
|
| 72 |
-
dataset_chunk_size: int = 1024 * 1024 * 4 # 4 MB download chunk
|
| 73 |
-
roboflow_cache_ttl_secs: int = 3600 # 1 hour
|
| 74 |
|
| 75 |
def ensure_dirs(self) -> None:
|
| 76 |
self.data_dir.mkdir(parents=True, exist_ok=True)
|
| 77 |
-
self.models_dir.mkdir(parents=True, exist_ok=True)
|
| 78 |
-
self.datasets_dir.mkdir(parents=True, exist_ok=True)
|
| 79 |
-
(self.datasets_dir / "_tmp").mkdir(parents=True, exist_ok=True)
|
| 80 |
-
self.logs_dir.mkdir(parents=True, exist_ok=True)
|
| 81 |
|
| 82 |
|
| 83 |
settings = Settings()
|
|
|
|
| 21 |
# ββ API βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
host: str = "0.0.0.0"
|
| 23 |
port: int = 7860 # Default for HF Spaces
|
| 24 |
+
cors_origins: list[str] = ["*"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# ββ Storage βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
base_dir: Path = Path(__file__).resolve().parents[1]
|
| 28 |
data_dir: Path = base_dir / "data"
|
| 29 |
+
db_path: Path = data_dir / "modelzoo.db"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# ββ Search ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
search_max_results: int = 1000
|
| 33 |
|
| 34 |
# ββ Sync ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
auto_sync_on_startup: bool = True
|
|
|
|
| 38 |
hf_api_base: str = "https://huggingface.co/api"
|
| 39 |
hf_hub_url: str = "https://huggingface.co"
|
| 40 |
hf_token: str | None = None # Optional: HF_TOKEN env var
|
| 41 |
+
hf_models_per_task: int = 200 # Discovery server pulls more per task
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def ensure_dirs(self) -> None:
|
| 44 |
self.data_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
settings = Settings()
|
download/__init__.py
DELETED
|
File without changes
|
download/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (143 Bytes)
|
|
|
download/__pycache__/manager.cpython-310.pyc
DELETED
|
Binary file (11.1 kB)
|
|
|
download/manager.py
DELETED
|
@@ -1,366 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
download/manager.py β Async download manager.
|
| 3 |
-
Handles queueing, concurrency limiting, retry, resume, and progress tracking.
|
| 4 |
-
All state is persisted in the jobs table for crash recovery.
|
| 5 |
-
"""
|
| 6 |
-
from __future__ import annotations
|
| 7 |
-
|
| 8 |
-
import asyncio
|
| 9 |
-
import json
|
| 10 |
-
import uuid
|
| 11 |
-
from datetime import datetime, timezone
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
from typing import Any
|
| 14 |
-
|
| 15 |
-
import aiofiles
|
| 16 |
-
import httpx
|
| 17 |
-
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 18 |
-
|
| 19 |
-
from config import settings
|
| 20 |
-
from database.connection import get_db
|
| 21 |
-
from models.job import Job, row_to_job
|
| 22 |
-
from observability.logger import audit, get_logger
|
| 23 |
-
from registry.registry import get_model, update_model_status
|
| 24 |
-
|
| 25 |
-
log = get_logger("download_manager")
|
| 26 |
-
|
| 27 |
-
# ββ Semaphore caps concurrent downloads βββββββββββββββββββββββββββββββββββββββ
|
| 28 |
-
_download_sem: asyncio.Semaphore | None = None
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def _get_sem() -> asyncio.Semaphore:
|
| 32 |
-
global _download_sem
|
| 33 |
-
if _download_sem is None:
|
| 34 |
-
_download_sem = asyncio.Semaphore(settings.max_concurrent_downloads)
|
| 35 |
-
return _download_sem
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# ββ Job CRUD ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
-
|
| 40 |
-
async def _create_job(
|
| 41 |
-
job_type: str,
|
| 42 |
-
model_id: str,
|
| 43 |
-
model_name: str,
|
| 44 |
-
meta: dict | None = None,
|
| 45 |
-
) -> str:
|
| 46 |
-
job_id = str(uuid.uuid4())
|
| 47 |
-
db = await get_db()
|
| 48 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 49 |
-
await db.execute(
|
| 50 |
-
"""INSERT INTO jobs (id, type, status, model_id, model_name, meta, created_at, updated_at)
|
| 51 |
-
VALUES (?,?,?,?,?,?,?,?)""",
|
| 52 |
-
(job_id, job_type, "queued", model_id, model_name,
|
| 53 |
-
json.dumps(meta or {}), now, now),
|
| 54 |
-
)
|
| 55 |
-
await db.commit()
|
| 56 |
-
log.info("job_created", job_id=job_id, type=job_type, model_id=model_id)
|
| 57 |
-
await audit("job_created", model_id=model_id, job_id=job_id,
|
| 58 |
-
payload={"type": job_type, "model_name": model_name})
|
| 59 |
-
return job_id
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def _is_shard_file(filename: str) -> bool:
|
| 63 |
-
"""Return True if the file is part of a sharded model (e.g. model-00001-of-00003.safetensors)."""
|
| 64 |
-
import re
|
| 65 |
-
return bool(re.search(r"-\d{5}-of-\d{5}\.", filename))
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
async def _get_active_version(model_id: str) -> str:
|
| 69 |
-
"""Return the active version string for a model, defaulting to 'v1'."""
|
| 70 |
-
model = await get_model(model_id)
|
| 71 |
-
if model and model.active_version:
|
| 72 |
-
return model.active_version
|
| 73 |
-
return "v1"
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
@retry(
|
| 77 |
-
stop=stop_after_attempt(3),
|
| 78 |
-
wait=wait_exponential(multiplier=1, min=1, max=6),
|
| 79 |
-
reraise=True,
|
| 80 |
-
)
|
| 81 |
-
async def _resolve_hf_download_url(repo_id: str) -> str:
|
| 82 |
-
"""Resolve a reliable download URL for a HF repo.
|
| 83 |
-
|
| 84 |
-
Prefer safetensors over pytorch_model.bin; fall back to onnx if needed.
|
| 85 |
-
"""
|
| 86 |
-
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
|
| 87 |
-
resp = await client.get(f"{settings.hf_api_base}/models/{repo_id}")
|
| 88 |
-
resp.raise_for_status()
|
| 89 |
-
data = resp.json()
|
| 90 |
-
|
| 91 |
-
siblings = data.get("siblings") or []
|
| 92 |
-
filenames: list[str] = []
|
| 93 |
-
for s in siblings:
|
| 94 |
-
fn = s.get("rfilename") or s.get("filename")
|
| 95 |
-
if fn:
|
| 96 |
-
filenames.append(fn)
|
| 97 |
-
|
| 98 |
-
preferred_exact = [
|
| 99 |
-
"model.safetensors",
|
| 100 |
-
"pytorch_model.bin",
|
| 101 |
-
"model.onnx",
|
| 102 |
-
]
|
| 103 |
-
for fn in preferred_exact:
|
| 104 |
-
if fn in filenames:
|
| 105 |
-
return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
|
| 106 |
-
|
| 107 |
-
preferred_suffix = [".safetensors", ".bin", ".onnx", ".pt", ".pth"]
|
| 108 |
-
for suffix in preferred_suffix:
|
| 109 |
-
for fn in filenames:
|
| 110 |
-
if fn.endswith(suffix) and not _is_shard_file(fn):
|
| 111 |
-
return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
|
| 112 |
-
|
| 113 |
-
# Accept sharded files as a fallback (first shard of safetensors)
|
| 114 |
-
for fn in filenames:
|
| 115 |
-
if _is_shard_file(fn):
|
| 116 |
-
return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
|
| 117 |
-
|
| 118 |
-
# Last resort: try the index file for sharded models
|
| 119 |
-
if "model.safetensors.index.json" in filenames:
|
| 120 |
-
# For sharded models without a single file, use the first shard
|
| 121 |
-
for fn in filenames:
|
| 122 |
-
if fn.startswith("model-") and fn.endswith(".safetensors"):
|
| 123 |
-
return f"https://huggingface.co/{repo_id}/resolve/main/{fn}"
|
| 124 |
-
|
| 125 |
-
return f"https://huggingface.co/{repo_id}/resolve/main/pytorch_model.bin"
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
async def _update_job(
|
| 129 |
-
job_id: str,
|
| 130 |
-
status: str | None = None,
|
| 131 |
-
progress: float | None = None,
|
| 132 |
-
error: str | None = None,
|
| 133 |
-
started_at: str | None = None,
|
| 134 |
-
ended_at: str | None = None,
|
| 135 |
-
) -> None:
|
| 136 |
-
db = await get_db()
|
| 137 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 138 |
-
parts: list[str] = ["updated_at = ?"]
|
| 139 |
-
vals: list[Any] = [now]
|
| 140 |
-
if status is not None: parts.append("status = ?"); vals.append(status)
|
| 141 |
-
if progress is not None: parts.append("progress = ?"); vals.append(progress)
|
| 142 |
-
if error is not None: parts.append("error = ?"); vals.append(error)
|
| 143 |
-
if started_at: parts.append("started_at = ?"); vals.append(started_at)
|
| 144 |
-
if ended_at: parts.append("ended_at = ?"); vals.append(ended_at)
|
| 145 |
-
vals.append(job_id)
|
| 146 |
-
await db.execute(f"UPDATE jobs SET {', '.join(parts)} WHERE id = ?", vals)
|
| 147 |
-
await db.commit()
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
# ββ Download worker βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 151 |
-
|
| 152 |
-
async def _execute_download(
|
| 153 |
-
job_id: str,
|
| 154 |
-
model_id: str,
|
| 155 |
-
model_name: str,
|
| 156 |
-
download_url: str,
|
| 157 |
-
dest_path: Path,
|
| 158 |
-
) -> None:
|
| 159 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 160 |
-
await _update_job(job_id, status="running", started_at=now)
|
| 161 |
-
|
| 162 |
-
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
| 163 |
-
tmp_path = dest_path.with_suffix(".tmp")
|
| 164 |
-
|
| 165 |
-
# Determine resume offset
|
| 166 |
-
resume_offset = tmp_path.stat().st_size if tmp_path.exists() else 0
|
| 167 |
-
|
| 168 |
-
headers: dict[str, str] = {}
|
| 169 |
-
if resume_offset:
|
| 170 |
-
headers["Range"] = f"bytes={resume_offset}-"
|
| 171 |
-
log.info("download_resume", job_id=job_id, offset=resume_offset)
|
| 172 |
-
|
| 173 |
-
try:
|
| 174 |
-
async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
|
| 175 |
-
async with client.stream("GET", download_url, headers=headers) as resp:
|
| 176 |
-
resp.raise_for_status()
|
| 177 |
-
total = int(resp.headers.get("content-length", 0)) + resume_offset
|
| 178 |
-
downloaded = resume_offset
|
| 179 |
-
|
| 180 |
-
async with aiofiles.open(tmp_path, "ab" if resume_offset else "wb") as fh:
|
| 181 |
-
async for chunk in resp.aiter_bytes(chunk_size=settings.download_chunk_size):
|
| 182 |
-
await fh.write(chunk)
|
| 183 |
-
downloaded += len(chunk)
|
| 184 |
-
progress = downloaded / total if total else 0
|
| 185 |
-
await _update_job(job_id, progress=min(progress, 0.99))
|
| 186 |
-
|
| 187 |
-
# Rename tmp β final
|
| 188 |
-
tmp_path.rename(dest_path)
|
| 189 |
-
now_end = datetime.now(timezone.utc).isoformat()
|
| 190 |
-
await _update_job(job_id, status="completed", progress=1.0, ended_at=now_end)
|
| 191 |
-
await update_model_status(
|
| 192 |
-
model_id,
|
| 193 |
-
status="cached",
|
| 194 |
-
downloaded=True,
|
| 195 |
-
local_path=str(dest_path),
|
| 196 |
-
)
|
| 197 |
-
# Copy into the active project's workspace models/ folder
|
| 198 |
-
from projects.service import link_model_to_active_project
|
| 199 |
-
await link_model_to_active_project(model_id, str(dest_path))
|
| 200 |
-
log.info("download_complete", job_id=job_id, model_id=model_id, path=str(dest_path))
|
| 201 |
-
await audit("download_complete", model_id=model_id, job_id=job_id,
|
| 202 |
-
payload={"path": str(dest_path)})
|
| 203 |
-
|
| 204 |
-
except Exception as exc:
|
| 205 |
-
now_end = datetime.now(timezone.utc).isoformat()
|
| 206 |
-
await _update_job(job_id, status="failed", error=str(exc), ended_at=now_end)
|
| 207 |
-
await update_model_status(model_id, status="error")
|
| 208 |
-
log.error("download_failed", job_id=job_id, error=str(exc))
|
| 209 |
-
await audit("download_failed", model_id=model_id, job_id=job_id,
|
| 210 |
-
payload={"error": str(exc)}, level="error")
|
| 211 |
-
raise
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 215 |
-
|
| 216 |
-
async def enqueue_download(
|
| 217 |
-
model_id: str,
|
| 218 |
-
model_name: str,
|
| 219 |
-
download_url: str | None = None,
|
| 220 |
-
version: str | None = None,
|
| 221 |
-
) -> str:
|
| 222 |
-
"""Create a download job and dispatch resolution+download in the background.
|
| 223 |
-
|
| 224 |
-
This function should not perform network calls; otherwise /download can return 500
|
| 225 |
-
on transient provider errors.
|
| 226 |
-
"""
|
| 227 |
-
job_id = await _create_job("download", model_id, model_name)
|
| 228 |
-
|
| 229 |
-
asyncio.create_task(
|
| 230 |
-
_rate_limited_download_resolving(job_id, model_id, model_name, download_url, version)
|
| 231 |
-
)
|
| 232 |
-
return job_id
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
async def _rate_limited_download_resolving(
|
| 236 |
-
job_id: str,
|
| 237 |
-
model_id: str,
|
| 238 |
-
model_name: str,
|
| 239 |
-
download_url: str | None,
|
| 240 |
-
version: str | None = None,
|
| 241 |
-
) -> None:
|
| 242 |
-
async with _get_sem():
|
| 243 |
-
try:
|
| 244 |
-
resolved_url = await _resolve_download_url(model_id, download_url, version)
|
| 245 |
-
# Version folder: use explicit version label, else active_version from DB
|
| 246 |
-
folder = version or await _get_active_version(model_id)
|
| 247 |
-
ext = Path(resolved_url.split("?")[0]).suffix or ".bin"
|
| 248 |
-
dest_path = settings.models_dir / model_id / folder / f"model{ext}"
|
| 249 |
-
await _execute_download(job_id, model_id, model_name, resolved_url, dest_path)
|
| 250 |
-
except Exception as exc:
|
| 251 |
-
now_end = datetime.now(timezone.utc).isoformat()
|
| 252 |
-
await _update_job(job_id, status="failed", error=str(exc), ended_at=now_end)
|
| 253 |
-
await update_model_status(model_id, status="error")
|
| 254 |
-
log.error("download_failed", job_id=job_id, error=str(exc))
|
| 255 |
-
await audit(
|
| 256 |
-
"download_failed",
|
| 257 |
-
model_id=model_id,
|
| 258 |
-
job_id=job_id,
|
| 259 |
-
payload={"error": str(exc)},
|
| 260 |
-
level="error",
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
async def _resolve_download_url(
|
| 265 |
-
model_id: str,
|
| 266 |
-
download_url: str | None,
|
| 267 |
-
version: str | None = None,
|
| 268 |
-
) -> str:
|
| 269 |
-
"""Resolve the final download URL for a model.
|
| 270 |
-
|
| 271 |
-
If `version` is provided and looks like a filename (e.g. 'yolov8n_pt'),
|
| 272 |
-
it was generated by hf_adapter from a sibling rfilename. Restore the
|
| 273 |
-
original filename (replace trailing _ext with .ext) and build a direct URL.
|
| 274 |
-
"""
|
| 275 |
-
repo_id: str | None = None
|
| 276 |
-
|
| 277 |
-
if download_url and "huggingface.co" in download_url:
|
| 278 |
-
repo_id = download_url.replace("https://huggingface.co/", "").rstrip("/")
|
| 279 |
-
elif not download_url:
|
| 280 |
-
model = await get_model(model_id)
|
| 281 |
-
if model and model.download_url:
|
| 282 |
-
url = model.download_url
|
| 283 |
-
if "huggingface.co" in url:
|
| 284 |
-
repo_id = url.replace("https://huggingface.co/", "").rstrip("/")
|
| 285 |
-
else:
|
| 286 |
-
return url
|
| 287 |
-
else:
|
| 288 |
-
repo_id = model_id.replace("_", "/", 1)
|
| 289 |
-
else:
|
| 290 |
-
return download_url
|
| 291 |
-
|
| 292 |
-
# If the caller specified a version that is a converted rfilename
|
| 293 |
-
# (dots replaced with underscores by hf_adapter), reconstruct the filename.
|
| 294 |
-
if version and repo_id:
|
| 295 |
-
filename = _version_to_filename(version)
|
| 296 |
-
if filename:
|
| 297 |
-
return f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
|
| 298 |
-
|
| 299 |
-
return await _resolve_hf_download_url(repo_id)
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
def _version_to_filename(version: str) -> str | None:
|
| 303 |
-
"""Convert an hf_adapter version string back to a real filename.
|
| 304 |
-
|
| 305 |
-
hf_adapter stores version as rfilename.replace('.', '_'), e.g.:
|
| 306 |
-
'yolov8n_pt' β 'yolov8n.pt'
|
| 307 |
-
'model_safetensors' β 'model.safetensors'
|
| 308 |
-
Only converts if the result ends with a known weight extension.
|
| 309 |
-
"""
|
| 310 |
-
weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx")
|
| 311 |
-
# Try replacing the last underscore with a dot
|
| 312 |
-
idx = version.rfind("_")
|
| 313 |
-
if idx == -1:
|
| 314 |
-
return None
|
| 315 |
-
candidate = version[:idx] + "." + version[idx + 1:]
|
| 316 |
-
if any(candidate.endswith(ext) for ext in weight_exts):
|
| 317 |
-
return candidate
|
| 318 |
-
return None
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
async def _rate_limited_download(
|
| 322 |
-
job_id: str,
|
| 323 |
-
model_id: str,
|
| 324 |
-
model_name: str,
|
| 325 |
-
download_url: str,
|
| 326 |
-
dest_path: Path,
|
| 327 |
-
) -> None:
|
| 328 |
-
async with _get_sem():
|
| 329 |
-
try:
|
| 330 |
-
await _execute_download(job_id, model_id, model_name, download_url, dest_path)
|
| 331 |
-
except Exception:
|
| 332 |
-
pass # Already logged & stored in DB
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
async def get_job(job_id: str) -> Job | None:
|
| 336 |
-
db = await get_db()
|
| 337 |
-
async with db.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)) as cur:
|
| 338 |
-
row = await cur.fetchone()
|
| 339 |
-
return row_to_job(row) if row else None
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
async def list_jobs(
|
| 343 |
-
status: str | None = None,
|
| 344 |
-
limit: int = 50,
|
| 345 |
-
) -> list[Job]:
|
| 346 |
-
db = await get_db()
|
| 347 |
-
if status:
|
| 348 |
-
sql = "SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC LIMIT ?"
|
| 349 |
-
params: tuple = (status, limit)
|
| 350 |
-
else:
|
| 351 |
-
sql = "SELECT * FROM jobs ORDER BY created_at DESC LIMIT ?"
|
| 352 |
-
params = (limit,)
|
| 353 |
-
async with db.execute(sql, params) as cur:
|
| 354 |
-
rows = await cur.fetchall()
|
| 355 |
-
return [row_to_job(r) for r in rows]
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
async def cancel_job(job_id: str) -> bool:
|
| 359 |
-
"""Cancel a queued or running job (best-effort)."""
|
| 360 |
-
job = await get_job(job_id)
|
| 361 |
-
if not job or job.status not in ("queued", "running"):
|
| 362 |
-
return False
|
| 363 |
-
now = datetime.now(timezone.utc).isoformat()
|
| 364 |
-
await _update_job(job_id, status="cancelled", ended_at=now)
|
| 365 |
-
log.info("job_cancelled", job_id=job_id)
|
| 366 |
-
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# inference package
|
|
|
|
|
|
inference/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (144 Bytes)
|
|
|
inference/__pycache__/engine.cpython-310.pyc
DELETED
|
Binary file (12 kB)
|
|
|
inference/__pycache__/session.cpython-310.pyc
DELETED
|
Binary file (2.87 kB)
|
|
|
inference/engine.py
DELETED
|
@@ -1,447 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
inference/engine.py β MLForge Inference Engine.
|
| 3 |
-
|
| 4 |
-
Dispatcher that routes each InferenceRequest to the correct adapter pipeline:
|
| 5 |
-
YOLO β YOLOInferencePipeline
|
| 6 |
-
TRANSFORMERS β TransformersPipeline
|
| 7 |
-
ONNX β ONNXPipeline
|
| 8 |
-
CUSTOM β CustomPipeline
|
| 9 |
-
|
| 10 |
-
Each pipeline implements preprocess β inference_step β postprocess.
|
| 11 |
-
Simulation paths are used when real model weights are not loaded;
|
| 12 |
-
every # <<< REPLACE IN PRODUCTION >>> comment marks the exact swap point.
|
| 13 |
-
|
| 14 |
-
Architecture follows the spec in infra_arch.md Β§4 (Adapter Protocol).
|
| 15 |
-
"""
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
import asyncio
|
| 19 |
-
import base64
|
| 20 |
-
import io
|
| 21 |
-
import random
|
| 22 |
-
import time
|
| 23 |
-
import uuid
|
| 24 |
-
from typing import Any
|
| 25 |
-
|
| 26 |
-
from models.inference import (
|
| 27 |
-
AdapterType,
|
| 28 |
-
Detection,
|
| 29 |
-
InferenceRequest,
|
| 30 |
-
InferenceResult,
|
| 31 |
-
PipelineStage,
|
| 32 |
-
)
|
| 33 |
-
from models.model import Model
|
| 34 |
-
from observability.logger import get_logger
|
| 35 |
-
|
| 36 |
-
log = get_logger("inference.engine")
|
| 37 |
-
|
| 38 |
-
# ββ Model cache: model_id β loaded model object ββββββββββββββββββββββββββββββ
|
| 39 |
-
_MODEL_CACHE: dict[str, Any] = {}
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def _now_ms() -> float:
|
| 43 |
-
return time.perf_counter() * 1000
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
# ββ YOLO Pipeline βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
-
|
| 48 |
-
class YOLOPipeline:
|
| 49 |
-
"""
|
| 50 |
-
YOLO inference pipeline.
|
| 51 |
-
Preprocess: letterbox resize β BGRβRGB β 1/255 normalise.
|
| 52 |
-
Postprocess: NMS β [{x1,y1,x2,y2,confidence,class_id,class_name}].
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
async def run(
|
| 56 |
-
self, req: InferenceRequest, model: Model
|
| 57 |
-
) -> tuple[list[PipelineStage], dict[str, Any]]:
|
| 58 |
-
cfg = req.yolo_config
|
| 59 |
-
conf = cfg.confidence if cfg else 0.25
|
| 60 |
-
iou = cfg.iou_threshold if cfg else 0.45
|
| 61 |
-
|
| 62 |
-
stages: list[PipelineStage] = []
|
| 63 |
-
|
| 64 |
-
# β Stage 1: Preprocess ββββββββββββββββββββββββββββ
|
| 65 |
-
t0 = _now_ms()
|
| 66 |
-
await asyncio.sleep(0) # yield control
|
| 67 |
-
if req.image_base64:
|
| 68 |
-
try:
|
| 69 |
-
raw_bytes = base64.b64decode(req.image_base64)
|
| 70 |
-
# <<< REPLACE IN PRODUCTION >>>
|
| 71 |
-
# img = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
|
| 72 |
-
# tensor = letterbox(img, 640) / 255.0
|
| 73 |
-
_ = len(raw_bytes) # validate decode worked
|
| 74 |
-
except Exception as e:
|
| 75 |
-
return [PipelineStage(name="Preprocess", status="error", detail=str(e))], {}
|
| 76 |
-
pre_ms = _now_ms() - t0 + random.uniform(0.8, 2.5)
|
| 77 |
-
stages.append(PipelineStage(name="Preprocess", status="done",
|
| 78 |
-
latency_ms=round(pre_ms, 2), detail="Letterbox 640Γ640"))
|
| 79 |
-
|
| 80 |
-
# β Stage 2: Engine Load βββββββββββββββββββββββββββ
|
| 81 |
-
t1 = _now_ms()
|
| 82 |
-
loaded = model.id in _MODEL_CACHE
|
| 83 |
-
load_ms = 0.0 if loaded else random.uniform(80, 220)
|
| 84 |
-
await asyncio.sleep(load_ms / 1000.0)
|
| 85 |
-
if not loaded:
|
| 86 |
-
_MODEL_CACHE[model.id] = object() # <<< REPLACE: load actual weights
|
| 87 |
-
stages.append(PipelineStage(name="Engine Load", status="done",
|
| 88 |
-
latency_ms=round(_now_ms() - t1, 2),
|
| 89 |
-
detail="Cache hit" if loaded else "Weights loaded"))
|
| 90 |
-
|
| 91 |
-
# β Stage 3: Inference ββββββββββββββββββββββββββββ
|
| 92 |
-
t2 = _now_ms()
|
| 93 |
-
size_gb = max(model.size, 1) / (1024 ** 3)
|
| 94 |
-
base_lat = 2.5 + size_gb * 1.5
|
| 95 |
-
infer_ms = base_lat + random.gauss(0, base_lat * 0.07)
|
| 96 |
-
await asyncio.sleep(infer_ms / 1000.0)
|
| 97 |
-
# <<< REPLACE IN PRODUCTION >>>
|
| 98 |
-
# results = model_obj(tensor, conf=conf, iou=iou)
|
| 99 |
-
stages.append(PipelineStage(name="Inference", status="done",
|
| 100 |
-
latency_ms=round(infer_ms, 2),
|
| 101 |
-
detail=f"conf={conf} iou={iou}"))
|
| 102 |
-
|
| 103 |
-
# β Stage 4: Post-process (NMS) ββββββββββββββββββ
|
| 104 |
-
t3 = _now_ms()
|
| 105 |
-
detections = self._simulate_detections(conf, cfg.class_filter if cfg else [])
|
| 106 |
-
post_ms = random.uniform(0.3, 1.2)
|
| 107 |
-
await asyncio.sleep(post_ms / 1000.0)
|
| 108 |
-
stages.append(PipelineStage(name="NMS Post-process", status="done",
|
| 109 |
-
latency_ms=round(post_ms, 2),
|
| 110 |
-
detail=f"{len(detections)} detections"))
|
| 111 |
-
|
| 112 |
-
output: dict[str, Any] = {
|
| 113 |
-
"detections": [d.model_dump() for d in detections],
|
| 114 |
-
"pre_ms": round(pre_ms, 2),
|
| 115 |
-
"infer_ms": round(infer_ms, 2),
|
| 116 |
-
"post_ms": round(post_ms, 2),
|
| 117 |
-
}
|
| 118 |
-
return stages, output
|
| 119 |
-
|
| 120 |
-
@staticmethod
|
| 121 |
-
def _simulate_detections(conf_thresh: float, class_filter: list[str]) -> list[Detection]:
|
| 122 |
-
"""Simulate bounding-box detections. <<< REPLACE with real NMS output."""
|
| 123 |
-
CLASSES = ["person", "car", "truck", "bicycle", "dog", "cat",
|
| 124 |
-
"traffic light", "stop sign", "bench", "bird"]
|
| 125 |
-
n = random.randint(0, 8)
|
| 126 |
-
dets: list[Detection] = []
|
| 127 |
-
for _ in range(n):
|
| 128 |
-
c = random.uniform(conf_thresh, 1.0)
|
| 129 |
-
cid = random.randint(0, len(CLASSES) - 1)
|
| 130 |
-
cname = CLASSES[cid]
|
| 131 |
-
if class_filter and cname not in class_filter:
|
| 132 |
-
continue
|
| 133 |
-
x1 = random.uniform(0, 0.7)
|
| 134 |
-
y1 = random.uniform(0, 0.7)
|
| 135 |
-
dets.append(Detection(
|
| 136 |
-
x1=round(x1 * 640, 1), y1=round(y1 * 640, 1),
|
| 137 |
-
x2=round((x1 + random.uniform(0.05, 0.3)) * 640, 1),
|
| 138 |
-
y2=round((y1 + random.uniform(0.05, 0.3)) * 640, 1),
|
| 139 |
-
confidence=round(c, 4),
|
| 140 |
-
class_id=cid, class_name=cname,
|
| 141 |
-
))
|
| 142 |
-
return dets
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
# ββ Transformers Pipeline βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
-
|
| 147 |
-
class TransformersPipeline:
|
| 148 |
-
"""
|
| 149 |
-
HuggingFace Transformers pipeline.
|
| 150 |
-
Preprocess: AutoTokenizer.encode.
|
| 151 |
-
Inference: model.generate with KV-cache.
|
| 152 |
-
Postprocess: decode + strip special tokens.
|
| 153 |
-
"""
|
| 154 |
-
|
| 155 |
-
async def run(
|
| 156 |
-
self, req: InferenceRequest, model: Model
|
| 157 |
-
) -> tuple[list[PipelineStage], dict[str, Any]]:
|
| 158 |
-
cfg = req.transformers_config
|
| 159 |
-
stages: list[PipelineStage] = []
|
| 160 |
-
|
| 161 |
-
# β Tokenize ββββββββββββββββββββββββββββββββββββββ
|
| 162 |
-
t0 = _now_ms()
|
| 163 |
-
txt = req.text_input or "Hello, world!"
|
| 164 |
-
tok_count = len(txt.split()) * 2 # rough BPE estimate
|
| 165 |
-
await asyncio.sleep(0.002)
|
| 166 |
-
pre_ms = _now_ms() - t0 + random.uniform(1, 4)
|
| 167 |
-
stages.append(PipelineStage(name="Tokenise", status="done",
|
| 168 |
-
latency_ms=round(pre_ms, 2),
|
| 169 |
-
detail=f"{tok_count} tokens"))
|
| 170 |
-
|
| 171 |
-
# β Engine Load βββββββββββββββββββββββββββββββββ
|
| 172 |
-
t1 = _now_ms()
|
| 173 |
-
loaded = model.id in _MODEL_CACHE
|
| 174 |
-
load_ms = 0.0 if loaded else random.uniform(150, 400)
|
| 175 |
-
await asyncio.sleep(load_ms / 1000.0)
|
| 176 |
-
if not loaded:
|
| 177 |
-
_MODEL_CACHE[model.id] = object()
|
| 178 |
-
stages.append(PipelineStage(name="Engine Load", status="done",
|
| 179 |
-
latency_ms=round(_now_ms() - t1, 2),
|
| 180 |
-
detail="Cache hit" if loaded else "Model loaded"))
|
| 181 |
-
|
| 182 |
-
# β Generate ββββββββββββββββββββββββββββββββββββββ
|
| 183 |
-
t2 = _now_ms()
|
| 184 |
-
max_tok = cfg.max_new_tokens if cfg else 256
|
| 185 |
-
# Simulate token-by-token generation at ~20 tok/s
|
| 186 |
-
infer_ms = (max_tok / 20.0) * 1000 + random.gauss(0, 50)
|
| 187 |
-
await asyncio.sleep(min(infer_ms / 1000.0, 0.5)) # cap sim delay
|
| 188 |
-
# <<< REPLACE IN PRODUCTION >>>
|
| 189 |
-
# outputs = model_obj.generate(input_ids, max_new_tokens=max_tok,
|
| 190 |
-
# temperature=cfg.temperature, top_p=cfg.top_p, do_sample=cfg.do_sample)
|
| 191 |
-
stages.append(PipelineStage(name="Generate", status="done",
|
| 192 |
-
latency_ms=round(infer_ms, 2),
|
| 193 |
-
detail=f"~{max_tok} tokens @ fp16"))
|
| 194 |
-
|
| 195 |
-
# β Decode ββββββββββββββββββββββββββββββββββββββββ
|
| 196 |
-
t3 = _now_ms()
|
| 197 |
-
text_output = self._simulate_text(txt, max_tok)
|
| 198 |
-
post_ms = random.uniform(0.5, 2.0)
|
| 199 |
-
stages.append(PipelineStage(name="Decode", status="done",
|
| 200 |
-
latency_ms=round(post_ms, 2),
|
| 201 |
-
detail="Special tokens stripped"))
|
| 202 |
-
|
| 203 |
-
output: dict[str, Any] = {
|
| 204 |
-
"text_output": text_output,
|
| 205 |
-
"tokens_generated": max_tok,
|
| 206 |
-
"pre_ms": round(pre_ms, 2),
|
| 207 |
-
"infer_ms": round(infer_ms, 2),
|
| 208 |
-
"post_ms": round(post_ms, 2),
|
| 209 |
-
}
|
| 210 |
-
return stages, output
|
| 211 |
-
|
| 212 |
-
@staticmethod
|
| 213 |
-
def _simulate_text(prompt: str, n_tokens: int) -> str:
|
| 214 |
-
"""Placeholder generation. <<< REPLACE with model.generate."""
|
| 215 |
-
lorem = (
|
| 216 |
-
"The model processed your input and generated a response based on the "
|
| 217 |
-
"learned distribution of the training corpus. This output is a simulation "
|
| 218 |
-
"placeholder β replace with actual model.generate() in production. "
|
| 219 |
-
)
|
| 220 |
-
# Repeat to roughly match token count
|
| 221 |
-
words = (lorem * (n_tokens // 20 + 1)).split()[:n_tokens]
|
| 222 |
-
return " ".join(words)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
# ββ ONNX Pipeline βββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββ
|
| 226 |
-
|
| 227 |
-
class ONNXPipeline:
|
| 228 |
-
"""
|
| 229 |
-
ONNX Runtime pipeline.
|
| 230 |
-
Acts as universal wrapper for TF / sklearn / PyTorch exported models.
|
| 231 |
-
Dynamically maps input tensor names from model metadata.
|
| 232 |
-
"""
|
| 233 |
-
|
| 234 |
-
async def run(
|
| 235 |
-
self, req: InferenceRequest, model: Model
|
| 236 |
-
) -> tuple[list[PipelineStage], dict[str, Any]]:
|
| 237 |
-
cfg = req.onnx_config
|
| 238 |
-
stages: list[PipelineStage] = []
|
| 239 |
-
provider = cfg.execution_provider if cfg else "CUDAExecutionProvider"
|
| 240 |
-
|
| 241 |
-
# β Preprocess ββββββββββββββββββββββββββββββββββββ
|
| 242 |
-
t0 = _now_ms()
|
| 243 |
-
pre_ms = random.uniform(1.0, 3.5)
|
| 244 |
-
await asyncio.sleep(pre_ms / 1000.0)
|
| 245 |
-
stages.append(PipelineStage(name="Preprocess", status="done",
|
| 246 |
-
latency_ms=round(pre_ms, 2),
|
| 247 |
-
detail="Normalise + reshape tensor"))
|
| 248 |
-
|
| 249 |
-
# β ONNX Runtime ββββββββββββββββββββββββββββββββββ
|
| 250 |
-
t1 = _now_ms()
|
| 251 |
-
loaded = model.id in _MODEL_CACHE
|
| 252 |
-
load_ms = 0.0 if loaded else random.uniform(50, 150)
|
| 253 |
-
await asyncio.sleep(load_ms / 1000.0)
|
| 254 |
-
if not loaded:
|
| 255 |
-
_MODEL_CACHE[model.id] = object()
|
| 256 |
-
# <<< REPLACE IN PRODUCTION >>>
|
| 257 |
-
# import onnxruntime as ort
|
| 258 |
-
# sess_opts = ort.SessionOptions()
|
| 259 |
-
# _MODEL_CACHE[model.id] = ort.InferenceSession(
|
| 260 |
-
# model.local_path, sess_options=sess_opts,
|
| 261 |
-
# providers=[provider])
|
| 262 |
-
stages.append(PipelineStage(name="ONNX Runtime", status="done",
|
| 263 |
-
latency_ms=round(_now_ms() - t1, 2),
|
| 264 |
-
detail=provider.replace("ExecutionProvider", "")))
|
| 265 |
-
|
| 266 |
-
# β Inference ββββββββββββββββββββββββββββββββββββ
|
| 267 |
-
t2 = _now_ms()
|
| 268 |
-
infer_ms = random.uniform(3.0, 12.0)
|
| 269 |
-
await asyncio.sleep(infer_ms / 1000.0)
|
| 270 |
-
# <<< REPLACE IN PRODUCTION >>>
|
| 271 |
-
# ort_inputs = {sess.get_inputs()[0].name: tensor.numpy()}
|
| 272 |
-
# raw = sess.run(None, ort_inputs)
|
| 273 |
-
stages.append(PipelineStage(name="Inference", status="done",
|
| 274 |
-
latency_ms=round(infer_ms, 2),
|
| 275 |
-
detail="session.run()"))
|
| 276 |
-
|
| 277 |
-
# β Format Output ββββββββββββββββββββββββββββββββ
|
| 278 |
-
t3 = _now_ms()
|
| 279 |
-
post_ms = random.uniform(0.2, 0.8)
|
| 280 |
-
raw_out = {"output_0": [round(random.random(), 4) for _ in range(10)]}
|
| 281 |
-
stages.append(PipelineStage(name="Format Output", status="done",
|
| 282 |
-
latency_ms=round(post_ms, 2),
|
| 283 |
-
detail="Tensor β JSON"))
|
| 284 |
-
|
| 285 |
-
output: dict[str, Any] = {
|
| 286 |
-
"raw_output": raw_out,
|
| 287 |
-
"pre_ms": round(pre_ms, 2),
|
| 288 |
-
"infer_ms": round(infer_ms, 2),
|
| 289 |
-
"post_ms": round(post_ms, 2),
|
| 290 |
-
}
|
| 291 |
-
return stages, output
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
# ββ Custom Python Pipeline ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 295 |
-
|
| 296 |
-
class CustomPipeline:
|
| 297 |
-
"""
|
| 298 |
-
Sandboxed custom Python pipeline.
|
| 299 |
-
Executes user-supplied pre/postprocess scripts in a restricted namespace.
|
| 300 |
-
Only numpy, the input tensor, and the model's raw output are accessible.
|
| 301 |
-
"""
|
| 302 |
-
|
| 303 |
-
FORBIDDEN = ("import os", "import sys", "subprocess", "open(", "__import__",
|
| 304 |
-
"eval(", "exec(", "globals(", "locals(")
|
| 305 |
-
|
| 306 |
-
def _validate_script(self, script: str) -> str | None:
|
| 307 |
-
for tok in self.FORBIDDEN:
|
| 308 |
-
if tok in script:
|
| 309 |
-
return f"Forbidden token in script: {tok!r}"
|
| 310 |
-
return None
|
| 311 |
-
|
| 312 |
-
async def run(
|
| 313 |
-
self, req: InferenceRequest, model: Model
|
| 314 |
-
) -> tuple[list[PipelineStage], dict[str, Any]]:
|
| 315 |
-
cfg = req.custom_config
|
| 316 |
-
stages: list[PipelineStage] = []
|
| 317 |
-
|
| 318 |
-
# β Validate scripts ββββββββββββββββββββββββββββββ
|
| 319 |
-
if cfg:
|
| 320 |
-
for label, script in [("preprocess", cfg.preprocess_script),
|
| 321 |
-
("postprocess", cfg.postprocess_script)]:
|
| 322 |
-
if script:
|
| 323 |
-
err = self._validate_script(script)
|
| 324 |
-
if err:
|
| 325 |
-
return [PipelineStage(name=label.capitalize(),
|
| 326 |
-
status="error", detail=err)], {}
|
| 327 |
-
|
| 328 |
-
# β Transform Input βββββββββββββββββββββββββββββββ
|
| 329 |
-
pre_ms = random.uniform(1.0, 5.0)
|
| 330 |
-
await asyncio.sleep(pre_ms / 1000.0)
|
| 331 |
-
stages.append(PipelineStage(name="Transform Input", status="done",
|
| 332 |
-
latency_ms=round(pre_ms, 2),
|
| 333 |
-
detail="Custom preprocess script"))
|
| 334 |
-
|
| 335 |
-
# β Run Inference ββββββββββββββββββββββββββββββββ
|
| 336 |
-
infer_ms = random.uniform(5.0, 30.0)
|
| 337 |
-
await asyncio.sleep(infer_ms / 1000.0)
|
| 338 |
-
# <<< REPLACE IN PRODUCTION >>>
|
| 339 |
-
# namespace = {"input": tensor, "model": raw_model}
|
| 340 |
-
# exec(compile(cfg.preprocess_script, "<pre>", "exec"), namespace)
|
| 341 |
-
# tensor = namespace.get("output", tensor)
|
| 342 |
-
stages.append(PipelineStage(name="Run Inference", status="done",
|
| 343 |
-
latency_ms=round(infer_ms, 2),
|
| 344 |
-
detail="Custom runtime"))
|
| 345 |
-
|
| 346 |
-
# β Format Result ββββββββββββββββββββββββββββββββ
|
| 347 |
-
post_ms = random.uniform(0.5, 3.0)
|
| 348 |
-
stages.append(PipelineStage(name="Format Result", status="done",
|
| 349 |
-
latency_ms=round(post_ms, 2),
|
| 350 |
-
detail="Custom postprocess script"))
|
| 351 |
-
|
| 352 |
-
output: dict[str, Any] = {
|
| 353 |
-
"raw_output": {"custom_result": round(random.random(), 4)},
|
| 354 |
-
"pre_ms": round(pre_ms, 2),
|
| 355 |
-
"infer_ms": round(infer_ms, 2),
|
| 356 |
-
"post_ms": round(post_ms, 2),
|
| 357 |
-
}
|
| 358 |
-
return stages, output
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
# ββ Master Dispatcher βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 362 |
-
|
| 363 |
-
_PIPELINE_MAP = {
|
| 364 |
-
AdapterType.YOLO: YOLOPipeline,
|
| 365 |
-
AdapterType.TRANSFORMERS: TransformersPipeline,
|
| 366 |
-
AdapterType.ONNX: ONNXPipeline,
|
| 367 |
-
AdapterType.CUSTOM: CustomPipeline,
|
| 368 |
-
}
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
class InferenceEngine:
|
| 372 |
-
"""
|
| 373 |
-
Central inference dispatcher.
|
| 374 |
-
Resolves the correct pipeline, executes it, and wraps the result
|
| 375 |
-
into a fully-populated InferenceResult.
|
| 376 |
-
"""
|
| 377 |
-
|
| 378 |
-
async def run(self, req: InferenceRequest, model: Model) -> InferenceResult:
|
| 379 |
-
t_start = _now_ms()
|
| 380 |
-
pipeline_cls = _PIPELINE_MAP.get(req.adapter_type)
|
| 381 |
-
if pipeline_cls is None:
|
| 382 |
-
return InferenceResult(
|
| 383 |
-
request_id=str(uuid.uuid4()),
|
| 384 |
-
model_id=req.model_id,
|
| 385 |
-
adapter_type=req.adapter_type,
|
| 386 |
-
status="error",
|
| 387 |
-
error=f"Unknown adapter type: {req.adapter_type}",
|
| 388 |
-
)
|
| 389 |
-
|
| 390 |
-
try:
|
| 391 |
-
stages, output = await pipeline_cls().run(req, model)
|
| 392 |
-
|
| 393 |
-
total_ms = _now_ms() - t_start
|
| 394 |
-
pre_ms = output.get("pre_ms", 0.0)
|
| 395 |
-
infer_ms = output.get("infer_ms", 0.0)
|
| 396 |
-
post_ms = output.get("post_ms", 0.0)
|
| 397 |
-
|
| 398 |
-
# Quality score: mean confidence of detections (0β5 scale)
|
| 399 |
-
detections = [Detection(**d) for d in output.get("detections", [])]
|
| 400 |
-
if detections:
|
| 401 |
-
mean_conf = sum(d.confidence for d in detections) / len(detections)
|
| 402 |
-
quality = round(mean_conf * 5.0, 2)
|
| 403 |
-
else:
|
| 404 |
-
quality = round(random.uniform(3.2, 4.8), 2)
|
| 405 |
-
|
| 406 |
-
result = InferenceResult(
|
| 407 |
-
model_id = req.model_id,
|
| 408 |
-
adapter_type = req.adapter_type,
|
| 409 |
-
preprocess_ms = pre_ms,
|
| 410 |
-
inference_ms = infer_ms,
|
| 411 |
-
postprocess_ms= post_ms,
|
| 412 |
-
total_ms = round(total_ms, 2),
|
| 413 |
-
pipeline = stages,
|
| 414 |
-
detections = detections,
|
| 415 |
-
text_output = output.get("text_output"),
|
| 416 |
-
raw_output = output.get("raw_output"),
|
| 417 |
-
quality_score = quality,
|
| 418 |
-
status = "ok",
|
| 419 |
-
)
|
| 420 |
-
|
| 421 |
-
log.info("inference_complete",
|
| 422 |
-
model_id=req.model_id,
|
| 423 |
-
adapter=req.adapter_type,
|
| 424 |
-
total_ms=round(total_ms, 2))
|
| 425 |
-
return result
|
| 426 |
-
|
| 427 |
-
except Exception as exc:
|
| 428 |
-
log.error("inference_error", model_id=req.model_id, error=str(exc))
|
| 429 |
-
return InferenceResult(
|
| 430 |
-
model_id=req.model_id,
|
| 431 |
-
adapter_type=req.adapter_type,
|
| 432 |
-
status="error",
|
| 433 |
-
error=str(exc),
|
| 434 |
-
)
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
def get_cache_status() -> dict[str, bool]:
|
| 438 |
-
"""Return which model IDs are currently warm in cache."""
|
| 439 |
-
return {k: True for k in _MODEL_CACHE}
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
def evict_model(model_id: str) -> bool:
|
| 443 |
-
"""Evict a model from the in-process cache (free VRAM sim)."""
|
| 444 |
-
if model_id in _MODEL_CACHE:
|
| 445 |
-
del _MODEL_CACHE[model_id]
|
| 446 |
-
return True
|
| 447 |
-
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/session.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
inference/session.py β In-memory inference session ledger.
|
| 3 |
-
|
| 4 |
-
Keeps the last MAX_HISTORY inference results per process lifetime.
|
| 5 |
-
Persisted to the SQLite `inference_history` table on each write
|
| 6 |
-
(non-blocking via aiosqlite).
|
| 7 |
-
"""
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
|
| 10 |
-
import asyncio
|
| 11 |
-
import json
|
| 12 |
-
import uuid
|
| 13 |
-
from collections import deque
|
| 14 |
-
from typing import Deque
|
| 15 |
-
|
| 16 |
-
from models.inference import InferenceHistoryEntry, InferenceRequest, InferenceResult
|
| 17 |
-
from observability.logger import get_logger
|
| 18 |
-
|
| 19 |
-
log = get_logger("inference.session")
|
| 20 |
-
|
| 21 |
-
MAX_HISTORY = 200
|
| 22 |
-
|
| 23 |
-
_history: Deque[InferenceHistoryEntry] = deque(maxlen=MAX_HISTORY)
|
| 24 |
-
_lock = asyncio.Lock()
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
async def record(req: InferenceRequest, result: InferenceResult, model_name: str) -> None:
|
| 28 |
-
"""Append a completed inference run to the ledger."""
|
| 29 |
-
entry = InferenceHistoryEntry(
|
| 30 |
-
model_id = req.model_id,
|
| 31 |
-
model_name = model_name,
|
| 32 |
-
adapter_type = req.adapter_type,
|
| 33 |
-
total_ms = result.total_ms,
|
| 34 |
-
quality_score = result.quality_score,
|
| 35 |
-
status = result.status,
|
| 36 |
-
request_snapshot = req.model_dump(exclude={"image_base64"}),
|
| 37 |
-
)
|
| 38 |
-
async with _lock:
|
| 39 |
-
_history.appendleft(entry)
|
| 40 |
-
|
| 41 |
-
# Persist to DB (fire-and-forget)
|
| 42 |
-
asyncio.create_task(_persist(entry))
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
async def _persist(entry: InferenceHistoryEntry) -> None:
|
| 46 |
-
try:
|
| 47 |
-
from database.connection import get_db
|
| 48 |
-
async with get_db() as db:
|
| 49 |
-
await db.execute(
|
| 50 |
-
"""
|
| 51 |
-
INSERT OR REPLACE INTO inference_history
|
| 52 |
-
(id, model_id, model_name, adapter_type, timestamp,
|
| 53 |
-
total_ms, quality_score, status, request_snapshot)
|
| 54 |
-
VALUES (?,?,?,?,?,?,?,?,?)
|
| 55 |
-
""",
|
| 56 |
-
(
|
| 57 |
-
entry.id,
|
| 58 |
-
entry.model_id,
|
| 59 |
-
entry.model_name,
|
| 60 |
-
entry.adapter_type.value,
|
| 61 |
-
entry.timestamp,
|
| 62 |
-
entry.total_ms,
|
| 63 |
-
entry.quality_score,
|
| 64 |
-
entry.status,
|
| 65 |
-
json.dumps(entry.request_snapshot),
|
| 66 |
-
),
|
| 67 |
-
)
|
| 68 |
-
await db.commit()
|
| 69 |
-
except Exception as exc:
|
| 70 |
-
log.warning("inference_persist_failed", error=str(exc))
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
async def get_history(limit: int = 50) -> list[InferenceHistoryEntry]:
|
| 74 |
-
async with _lock:
|
| 75 |
-
return list(_history)[:limit]
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
async def clear_history() -> None:
|
| 79 |
-
async with _lock:
|
| 80 |
-
_history.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
|
@@ -26,7 +26,6 @@ from fastapi.responses import JSONResponse
|
|
| 26 |
from api.routes import models as models_router
|
| 27 |
from api.routes import sync as sync_router
|
| 28 |
from api.routes import datasets as datasets_router
|
| 29 |
-
from api.routes import projects as projects_router
|
| 30 |
from config import settings
|
| 31 |
from database.connection import close_db, get_db
|
| 32 |
from middleware.logging_middleware import RequestLoggingMiddleware
|
|
@@ -65,9 +64,9 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
|
| 65 |
|
| 66 |
# ββ Application βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
app = FastAPI(
|
| 68 |
-
title=
|
| 69 |
version=settings.version,
|
| 70 |
-
description="
|
| 71 |
docs_url="/docs",
|
| 72 |
redoc_url="/redoc",
|
| 73 |
lifespan=lifespan,
|
|
@@ -91,8 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|
| 91 |
# ββ Middleware βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 92 |
app.add_middleware(
|
| 93 |
CORSMiddleware,
|
| 94 |
-
allow_origins=
|
| 95 |
-
allow_origin_regex=r"^https?://(localhost|127\\.0\\.0\\.1)(:\\d+)?$",
|
| 96 |
allow_credentials=True,
|
| 97 |
allow_methods=["*"],
|
| 98 |
allow_headers=["*"],
|
|
@@ -103,7 +101,6 @@ app.add_middleware(RequestLoggingMiddleware)
|
|
| 103 |
app.include_router(models_router.router)
|
| 104 |
app.include_router(sync_router.router)
|
| 105 |
app.include_router(datasets_router.router)
|
| 106 |
-
app.include_router(projects_router.router)
|
| 107 |
|
| 108 |
|
| 109 |
@app.get("/health", tags=["system"])
|
|
@@ -114,6 +111,7 @@ async def health() -> dict:
|
|
| 114 |
n_datasets = await count_datasets()
|
| 115 |
return {
|
| 116 |
"status": "ok",
|
|
|
|
| 117 |
"version": settings.version,
|
| 118 |
"model_count": n_models,
|
| 119 |
"dataset_count": n_datasets,
|
|
|
|
| 26 |
from api.routes import models as models_router
|
| 27 |
from api.routes import sync as sync_router
|
| 28 |
from api.routes import datasets as datasets_router
|
|
|
|
| 29 |
from config import settings
|
| 30 |
from database.connection import close_db, get_db
|
| 31 |
from middleware.logging_middleware import RequestLoggingMiddleware
|
|
|
|
| 64 |
|
| 65 |
# ββ Application βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 66 |
app = FastAPI(
|
| 67 |
+
title="MLForge Cloud Registry",
|
| 68 |
version=settings.version,
|
| 69 |
+
description="Global Model and Dataset Discovery Service β The Brain of MLForge.",
|
| 70 |
docs_url="/docs",
|
| 71 |
redoc_url="/redoc",
|
| 72 |
lifespan=lifespan,
|
|
|
|
| 90 |
# ββ Middleware βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 91 |
app.add_middleware(
|
| 92 |
CORSMiddleware,
|
| 93 |
+
allow_origins=["*"], # Allow all origins for the cloud registry to support SDK/CLI/UI
|
|
|
|
| 94 |
allow_credentials=True,
|
| 95 |
allow_methods=["*"],
|
| 96 |
allow_headers=["*"],
|
|
|
|
| 101 |
app.include_router(models_router.router)
|
| 102 |
app.include_router(sync_router.router)
|
| 103 |
app.include_router(datasets_router.router)
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
@app.get("/health", tags=["system"])
|
|
|
|
| 111 |
n_datasets = await count_datasets()
|
| 112 |
return {
|
| 113 |
"status": "ok",
|
| 114 |
+
"service": "cloud_registry",
|
| 115 |
"version": settings.version,
|
| 116 |
"model_count": n_models,
|
| 117 |
"dataset_count": n_datasets,
|
projects/__init__.py
DELETED
|
File without changes
|