Spaces:
Sleeping
Sleeping
senthil2421 commited on
Commit ·
e10cda2
1
Parent(s): ee35993
Refactor cloud_backend: remove local execution routes and fix missing modules
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- __pycache__/config.cpython-310.pyc +0 -0
- __pycache__/main.cpython-310.pyc +0 -0
- adapters/__init__.py +0 -0
- adapters/__pycache__/__init__.cpython-310.pyc +0 -0
- adapters/__pycache__/base.cpython-310.pyc +0 -0
- adapters/__pycache__/hf_adapter.cpython-310.pyc +0 -0
- adapters/__pycache__/onnx_adapter.cpython-310.pyc +0 -0
- adapters/__pycache__/roboflow_adapter.cpython-310.pyc +0 -0
- adapters/base.py +28 -0
- adapters/hf_adapter.py +415 -0
- adapters/onnx_adapter.py +176 -0
- adapters/roboflow_adapter.py +353 -0
- benchmark/__init__.py +1 -0
- benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
- benchmark/__pycache__/compatibility.cpython-310.pyc +0 -0
- benchmark/__pycache__/execution.cpython-310.pyc +0 -0
- benchmark/__pycache__/metrics.cpython-310.pyc +0 -0
- benchmark/__pycache__/orchestrator.cpython-310.pyc +0 -0
- benchmark/__pycache__/registry.cpython-310.pyc +0 -0
- benchmark/__pycache__/telemetry.cpython-310.pyc +0 -0
- benchmark/adapters/__pycache__/base.cpython-310.pyc +0 -0
- benchmark/adapters/__pycache__/registry.cpython-310.pyc +0 -0
- benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc +0 -0
- benchmark/adapters/base.py +38 -0
- benchmark/adapters/optimum_runner.py +53 -0
- benchmark/adapters/registry.py +44 -0
- benchmark/adapters/torch_runner.py +45 -0
- benchmark/compatibility.py +360 -0
- benchmark/execution.py +366 -0
- benchmark/metrics.py +110 -0
- benchmark/orchestrator.py +374 -0
- benchmark/registry.py +302 -0
- benchmark/telemetry.py +182 -0
- benchmark/torch_runner.py +142 -0
- datasets/__init__.py +1 -0
- datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- datasets/__pycache__/annotation_parser.cpython-310.pyc +0 -0
- datasets/__pycache__/base_adapter.cpython-310.pyc +0 -0
- datasets/__pycache__/format_adapters.cpython-310.pyc +0 -0
- datasets/__pycache__/import_service.cpython-310.pyc +0 -0
- datasets/__pycache__/registry.cpython-310.pyc +0 -0
- datasets/__pycache__/viewer_service.cpython-310.pyc +0 -0
- datasets/annotation_parser.py +576 -0
- datasets/base_adapter.py +37 -0
- datasets/format_adapters.py +235 -0
- datasets/import_service.py +589 -0
- datasets/registry.py +452 -0
- datasets/viewer_service.py +320 -0
- download/__init__.py +0 -0
- download/__pycache__/__init__.cpython-310.pyc +0 -0
__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (2.55 kB). View file
|
|
|
__pycache__/main.cpython-310.pyc
ADDED
|
Binary file (3.43 kB). View file
|
|
|
adapters/__init__.py
ADDED
|
File without changes
|
adapters/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
adapters/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
adapters/__pycache__/hf_adapter.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
adapters/__pycache__/onnx_adapter.cpython-310.pyc
ADDED
|
Binary file (5.27 kB). View file
|
|
|
adapters/__pycache__/roboflow_adapter.cpython-310.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
adapters/base.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
adapters/base.py — Abstract base class every source adapter must implement.
|
| 3 |
+
Enforces a stable contract so the registry never knows which adapter runs.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
|
| 9 |
+
from models.model import Model
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseAdapter(ABC):
|
| 13 |
+
"""Fetch models from an external source and normalize to the Model schema."""
|
| 14 |
+
|
| 15 |
+
source_name: str = "unknown"
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
async def fetch_models(self) -> list[Model]:
|
| 19 |
+
"""Return a list of normalized Model objects from the source."""
|
| 20 |
+
...
|
| 21 |
+
|
| 22 |
+
def _format_size(self, bytes_: int) -> str:
|
| 23 |
+
"""Human-readable file size."""
|
| 24 |
+
for unit in ("B", "KB", "MB", "GB", "TB"):
|
| 25 |
+
if bytes_ < 1024:
|
| 26 |
+
return f"{bytes_:.1f} {unit}"
|
| 27 |
+
bytes_ //= 1024
|
| 28 |
+
return f"{bytes_} PB"
|
adapters/hf_adapter.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
adapters/hf_adapter.py — Hugging Face Hub adapter.
|
| 3 |
+
Fetches real models via the public HF API and normalises them to our schema.
|
| 4 |
+
|
| 5 |
+
Rate-limits respected via polite delays. Requires no authentication for
|
| 6 |
+
publicly accessible models; set HF_TOKEN env var for higher rate-limits.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import asyncio
|
| 11 |
+
import re
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _is_shard_file(filename: str) -> bool:
|
| 16 |
+
"""Return True for sharded weight files like model-00001-of-00003.safetensors."""
|
| 17 |
+
return bool(re.search(r"-\d{5}-of-\d{5}\.", filename))
|
| 18 |
+
|
| 19 |
+
import httpx
|
| 20 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 21 |
+
|
| 22 |
+
from adapters.base import BaseAdapter
|
| 23 |
+
from config import settings
|
| 24 |
+
from models.model import Model, ModelMetrics, ModelVersion
|
| 25 |
+
from observability.logger import get_logger
|
| 26 |
+
|
| 27 |
+
log = get_logger("hf_adapter")
|
| 28 |
+
|
| 29 |
+
# ── Task mapping: HF pipeline_tag → our internal task ─────────────────────────
|
| 30 |
+
HF_TASK_MAP: dict[str, str] = {
|
| 31 |
+
"object-detection": "detection",
|
| 32 |
+
"image-classification": "classification",
|
| 33 |
+
"image-segmentation": "segmentation",
|
| 34 |
+
"text-to-image": "generation",
|
| 35 |
+
"image-to-image": "generation",
|
| 36 |
+
"image-feature-extraction": "embedding",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# Tasks we actively fetch
|
| 40 |
+
FETCH_TASKS: list[str] = list(HF_TASK_MAP.keys())
|
| 41 |
+
|
| 42 |
+
# ── Framework detection ────────────────────────────────────────────────────────
|
| 43 |
+
def _detect_framework(tags: list[str], model_id: str) -> str:
|
| 44 |
+
tag_str = " ".join(tags + [model_id]).lower()
|
| 45 |
+
if "onnx" in tag_str: return "onnx"
|
| 46 |
+
if "tflite" in tag_str: return "tflite"
|
| 47 |
+
if "coreml" in tag_str: return "coreml"
|
| 48 |
+
if "tensorflow" in tag_str or "tf" in tag_str: return "tensorflow"
|
| 49 |
+
return "pytorch" # HF default
|
| 50 |
+
|
| 51 |
+
# ── Hardware detection ─────────────────────────────────────────────────────────
|
| 52 |
+
def _detect_hardware(tags: list[str]) -> list[str]:
|
| 53 |
+
hw: list[str] = []
|
| 54 |
+
tag_str = " ".join(tags).lower()
|
| 55 |
+
if any(k in tag_str for k in ("cuda", "gpu")): hw.append("gpu")
|
| 56 |
+
if "edge" in tag_str or "mobile" in tag_str: hw.append("edge")
|
| 57 |
+
if "cpu" in tag_str: hw.append("cpu")
|
| 58 |
+
if not hw: hw.append("gpu") # safe default
|
| 59 |
+
return hw
|
| 60 |
+
|
| 61 |
+
# ── Internal tag normalisation ─────────────────────────────────────────────────
|
| 62 |
+
QUALITY_TAG_MAP = {
|
| 63 |
+
"state-of-the-art": "sota",
|
| 64 |
+
"lightweight": "lightweight",
|
| 65 |
+
"tiny": "tiny",
|
| 66 |
+
"fast": "fastest",
|
| 67 |
+
"real-time": "real-time",
|
| 68 |
+
"accuracy": "high-accuracy",
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
def _normalise_tags(raw_tags: list[str], pipeline: str) -> list[str]:
|
| 72 |
+
out: list[str] = []
|
| 73 |
+
for t in raw_tags:
|
| 74 |
+
t_lower = t.lower()
|
| 75 |
+
for keyword, mapped in QUALITY_TAG_MAP.items():
|
| 76 |
+
if keyword in t_lower:
|
| 77 |
+
out.append(mapped)
|
| 78 |
+
# keep relevant library / dataset tags
|
| 79 |
+
if any(t_lower.startswith(p) for p in ("dataset:", "license:", "language:")):
|
| 80 |
+
continue
|
| 81 |
+
out.append(t_lower)
|
| 82 |
+
# add pipeline as tag
|
| 83 |
+
if pipeline:
|
| 84 |
+
out.append(pipeline.replace("-", "_"))
|
| 85 |
+
return list(dict.fromkeys(out)) # deduplicate, preserve order
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class HFAdapter(BaseAdapter):
|
| 89 |
+
source_name = "hf"
|
| 90 |
+
|
| 91 |
+
def __init__(self) -> None:
|
| 92 |
+
headers = {"Accept": "application/json"}
|
| 93 |
+
if settings.hf_token:
|
| 94 |
+
headers["Authorization"] = f"Bearer {settings.hf_token}"
|
| 95 |
+
self._client = httpx.AsyncClient(
|
| 96 |
+
base_url=settings.hf_api_base,
|
| 97 |
+
headers=headers,
|
| 98 |
+
timeout=30,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
@retry(
|
| 102 |
+
stop=stop_after_attempt(3),
|
| 103 |
+
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 104 |
+
reraise=True,
|
| 105 |
+
)
|
| 106 |
+
async def _fetch_task_page(
|
| 107 |
+
self, pipeline_tag: str, limit: int = 100
|
| 108 |
+
) -> list[dict[str, Any]]:
|
| 109 |
+
params = {
|
| 110 |
+
"pipeline_tag": pipeline_tag,
|
| 111 |
+
"sort": "downloads",
|
| 112 |
+
"direction": -1, # descending
|
| 113 |
+
"limit": limit,
|
| 114 |
+
"full": "True",
|
| 115 |
+
}
|
| 116 |
+
log.info("hf_fetch_task", pipeline_tag=pipeline_tag, limit=limit)
|
| 117 |
+
resp = await self._client.get("/models", params=params)
|
| 118 |
+
resp.raise_for_status()
|
| 119 |
+
return resp.json()
|
| 120 |
+
|
| 121 |
+
@retry(
|
| 122 |
+
stop=stop_after_attempt(3),
|
| 123 |
+
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 124 |
+
reraise=True,
|
| 125 |
+
)
|
| 126 |
+
async def _fetch_model_detail(self, model_id: str) -> dict[str, Any]:
|
| 127 |
+
resp = await self._client.get(f"/models/{model_id}", params={"full": "True"})
|
| 128 |
+
resp.raise_for_status()
|
| 129 |
+
raw = resp.json()
|
| 130 |
+
|
| 131 |
+
siblings: list[dict[str, Any]] = raw.get("siblings") or []
|
| 132 |
+
has_any_size = any(isinstance(s, dict) and s.get("size") for s in siblings)
|
| 133 |
+
if not has_any_size:
|
| 134 |
+
try:
|
| 135 |
+
tree = await self._fetch_model_tree(model_id, revision="main")
|
| 136 |
+
size_by_path: dict[str, int] = {
|
| 137 |
+
(t.get("path") or ""): int(t.get("size") or 0)
|
| 138 |
+
for t in (tree or [])
|
| 139 |
+
if isinstance(t, dict)
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
patched: list[dict[str, Any]] = []
|
| 143 |
+
for s in siblings:
|
| 144 |
+
if not isinstance(s, dict):
|
| 145 |
+
continue
|
| 146 |
+
fn = s.get("rfilename") or s.get("path") or ""
|
| 147 |
+
if fn and not s.get("size") and fn in size_by_path:
|
| 148 |
+
s = {**s, "size": size_by_path[fn]}
|
| 149 |
+
patched.append(s)
|
| 150 |
+
raw["siblings"] = patched
|
| 151 |
+
except Exception:
|
| 152 |
+
pass
|
| 153 |
+
|
| 154 |
+
return raw
|
| 155 |
+
|
| 156 |
+
@retry(
|
| 157 |
+
stop=stop_after_attempt(3),
|
| 158 |
+
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 159 |
+
reraise=True,
|
| 160 |
+
)
|
| 161 |
+
async def _fetch_model_tree(self, model_id: str, *, revision: str = "main") -> list[dict[str, Any]]:
|
| 162 |
+
resp = await self._client.get(f"/models/{model_id}/tree/{revision}")
|
| 163 |
+
resp.raise_for_status()
|
| 164 |
+
data = resp.json()
|
| 165 |
+
if isinstance(data, list):
|
| 166 |
+
return data
|
| 167 |
+
return []
|
| 168 |
+
|
| 169 |
+
def _parse_safe_tensors_size(self, siblings: list[dict]) -> int:
|
| 170 |
+
"""Estimate model size from sibling file list."""
|
| 171 |
+
total = 0
|
| 172 |
+
weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")
|
| 173 |
+
for s in siblings or []:
|
| 174 |
+
filename = s.get("rfilename", "").lower()
|
| 175 |
+
if filename.endswith(weight_exts):
|
| 176 |
+
total += s.get("size", 0)
|
| 177 |
+
|
| 178 |
+
if total > 0:
|
| 179 |
+
return total
|
| 180 |
+
|
| 181 |
+
# If no size found in siblings, check if it's in the root dict (sometimes HF API does this)
|
| 182 |
+
return 0 # Return 0 if not found, we'll handle fallback in _make_model
|
| 183 |
+
|
| 184 |
+
@retry(
|
| 185 |
+
stop=stop_after_attempt(3),
|
| 186 |
+
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 187 |
+
reraise=True,
|
| 188 |
+
)
|
| 189 |
+
async def _fetch_model_card(self, model_id: str) -> str:
|
| 190 |
+
"""Fetch model card (README.md) content for real-time description."""
|
| 191 |
+
url = f"{settings.hf_hub_url}/{model_id}/raw/main/README.md"
|
| 192 |
+
try:
|
| 193 |
+
resp = await self._client.get(url)
|
| 194 |
+
if resp.status_code == 200:
|
| 195 |
+
return resp.text
|
| 196 |
+
except Exception:
|
| 197 |
+
pass
|
| 198 |
+
return ""
|
| 199 |
+
|
| 200 |
+
def _extract_description(self, readme: str, raw: dict[str, Any]) -> str:
|
| 201 |
+
"""Extract a clean description from README or card data."""
|
| 202 |
+
if readme:
|
| 203 |
+
# Simple heuristic: take first paragraph that isn't frontmatter
|
| 204 |
+
lines = readme.split("\n")
|
| 205 |
+
in_frontmatter = False
|
| 206 |
+
for line in lines:
|
| 207 |
+
if line.strip() == "---":
|
| 208 |
+
in_frontmatter = not in_frontmatter
|
| 209 |
+
continue
|
| 210 |
+
if not in_frontmatter and line.strip() and not line.startswith("#"):
|
| 211 |
+
return line.strip()[:500]
|
| 212 |
+
|
| 213 |
+
card_data = raw.get("cardData") or {}
|
| 214 |
+
description: str = (
|
| 215 |
+
(card_data.get("summary") or "")
|
| 216 |
+
or (card_data.get("description") or "")
|
| 217 |
+
or (raw.get("description") or "")
|
| 218 |
+
).strip()
|
| 219 |
+
return description
|
| 220 |
+
|
| 221 |
+
def _estimate_metrics(self, model_id: str, task: str) -> ModelMetrics:
|
| 222 |
+
"""
|
| 223 |
+
Product-Grade Metrics Estimation.
|
| 224 |
+
Uses model name heuristics to provide realistic data for common architectures.
|
| 225 |
+
"""
|
| 226 |
+
metrics = ModelMetrics()
|
| 227 |
+
m_id = model_id.lower()
|
| 228 |
+
|
| 229 |
+
# Base latency/vram estimates by architecture
|
| 230 |
+
if "vit" in m_id or "dinov2" in m_id:
|
| 231 |
+
metrics.latency_ms = 45.5 if "base" in m_id else 85.2 if "large" in m_id else 25.0
|
| 232 |
+
metrics.vram_gb = 1.2 if "base" in m_id else 2.4 if "large" in m_id else 0.8
|
| 233 |
+
metrics.accuracy = 82.4 if "base" in m_id else 84.5
|
| 234 |
+
elif "segformer" in m_id:
|
| 235 |
+
# b0, b1, b2, b3, b4, b5
|
| 236 |
+
if "b0" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 12.0, 0.4, 35.0
|
| 237 |
+
elif "b1" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 18.0, 0.6, 40.0
|
| 238 |
+
elif "b5" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.0, 1.8, 50.0
|
| 239 |
+
else: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 1.0, 42.0
|
| 240 |
+
elif "convnext" in m_id:
|
| 241 |
+
metrics.latency_ms = 15.0 if "tiny" in m_id else 30.0
|
| 242 |
+
metrics.vram_gb = 0.5 if "tiny" in m_id else 1.2
|
| 243 |
+
metrics.accuracy = 81.0 if "tiny" in m_id else 83.5
|
| 244 |
+
elif "yolo" in m_id:
|
| 245 |
+
# n, s, m, l, x
|
| 246 |
+
if "yolov8n" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 1.5, 0.2, 37.3
|
| 247 |
+
elif "yolov8s" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 2.8, 0.4, 44.9
|
| 248 |
+
elif "yolov8m" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 6.2, 0.9, 50.2
|
| 249 |
+
else: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 10.0, 1.5, 52.0
|
| 250 |
+
|
| 251 |
+
# Generic task-based fallbacks if still empty
|
| 252 |
+
if metrics.latency_ms is None:
|
| 253 |
+
if task == "classification": metrics.latency_ms, metrics.accuracy = 20.0, 75.0
|
| 254 |
+
elif task == "detection": metrics.latency_ms, metrics.mAP = 35.0, 45.0
|
| 255 |
+
elif task == "embedding": metrics.latency_ms = 40.0
|
| 256 |
+
elif task == "generation": metrics.latency_ms = 1500.0
|
| 257 |
+
|
| 258 |
+
return metrics
|
| 259 |
+
|
| 260 |
+
def _make_model(self, raw: dict[str, Any], pipeline_tag: str) -> Model | None:
|
| 261 |
+
model_id: str = raw.get("id") or raw.get("modelId", "")
|
| 262 |
+
if not model_id:
|
| 263 |
+
return None
|
| 264 |
+
|
| 265 |
+
task = HF_TASK_MAP.get(pipeline_tag)
|
| 266 |
+
if not task:
|
| 267 |
+
return None
|
| 268 |
+
tags_raw: list[str] = raw.get("tags") or []
|
| 269 |
+
framework = _detect_framework(tags_raw, model_id)
|
| 270 |
+
hardware = _detect_hardware(tags_raw)
|
| 271 |
+
tags = _normalise_tags(tags_raw, pipeline_tag)
|
| 272 |
+
|
| 273 |
+
# Size
|
| 274 |
+
siblings: list[dict] = raw.get("siblings") or []
|
| 275 |
+
size = self._parse_safe_tensors_size(siblings)
|
| 276 |
+
if size == 0:
|
| 277 |
+
# Fallback based on model type if size not found
|
| 278 |
+
if "large" in model_id.lower(): size = 1_200_000_000
|
| 279 |
+
elif "base" in model_id.lower(): size = 500_000_000
|
| 280 |
+
elif "small" in model_id.lower() or "tiny" in model_id.lower(): size = 150_000_000
|
| 281 |
+
else: size = 450_000_000 # More realistic general default than exactly 500MB
|
| 282 |
+
|
| 283 |
+
# Provider — author part of model_id
|
| 284 |
+
provider = model_id.split("/")[0] if "/" in model_id else "community"
|
| 285 |
+
|
| 286 |
+
# safe name
|
| 287 |
+
name = model_id.split("/")[-1] if "/" in model_id else model_id
|
| 288 |
+
# Clean ugly names
|
| 289 |
+
name = re.sub(r"[-_]+", "-", name).strip("-")
|
| 290 |
+
|
| 291 |
+
downloads = raw.get("downloads") or 0
|
| 292 |
+
likes = raw.get("likes") or 0
|
| 293 |
+
|
| 294 |
+
# Fabricate a sensible version from last modified
|
| 295 |
+
last_mod: str = raw.get("lastModified") or raw.get("createdAt") or ""
|
| 296 |
+
release_date = last_mod[:10] if last_mod else "2024-01-01"
|
| 297 |
+
sha8 = (raw.get("sha") or "main")[:8]
|
| 298 |
+
|
| 299 |
+
# Build versions from weight files in the repo (one per distinct weight file)
|
| 300 |
+
weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")
|
| 301 |
+
weight_files = [
|
| 302 |
+
s for s in siblings
|
| 303 |
+
if s.get("rfilename", "").lower().endswith(weight_exts)
|
| 304 |
+
and not _is_shard_file(s.get("rfilename", ""))
|
| 305 |
+
]
|
| 306 |
+
|
| 307 |
+
if len(weight_files) > 1:
|
| 308 |
+
versions = []
|
| 309 |
+
for s in weight_files[:15]:
|
| 310 |
+
filename = s["rfilename"]
|
| 311 |
+
# Detect variant from filename (n, s, m, l, x, or specific labels)
|
| 312 |
+
variant_label = "Stable"
|
| 313 |
+
fn_lower = filename.lower()
|
| 314 |
+
if any(x in fn_lower for x in ["-n.", "_n.", "nano"]): variant_label = "Nano"
|
| 315 |
+
elif any(x in fn_lower for x in ["-s.", "_s.", "small"]): variant_label = "Small"
|
| 316 |
+
elif any(x in fn_lower for x in ["-m.", "_m.", "medium"]): variant_label = "Medium"
|
| 317 |
+
elif any(x in fn_lower for x in ["-l.", "_l.", "large"]): variant_label = "Large"
|
| 318 |
+
elif any(x in fn_lower for x in ["-x.", "_x.", "xlarge", "huge"]): variant_label = "XLarge"
|
| 319 |
+
|
| 320 |
+
versions.append(ModelVersion(
|
| 321 |
+
version=filename.replace(".", "_"),
|
| 322 |
+
label=variant_label,
|
| 323 |
+
description=f"Model variant: {filename}",
|
| 324 |
+
releaseDate=release_date,
|
| 325 |
+
changelog=None,
|
| 326 |
+
))
|
| 327 |
+
else:
|
| 328 |
+
versions = [
|
| 329 |
+
ModelVersion(
|
| 330 |
+
version=sha8,
|
| 331 |
+
label="Latest",
|
| 332 |
+
description="Primary model weight file.",
|
| 333 |
+
releaseDate=release_date,
|
| 334 |
+
changelog=None,
|
| 335 |
+
)
|
| 336 |
+
]
|
| 337 |
+
|
| 338 |
+
# Description from card data
|
| 339 |
+
description = self._extract_description("", raw)
|
| 340 |
+
if not description:
|
| 341 |
+
description = f"{task.capitalize()} model by {provider}."
|
| 342 |
+
|
| 343 |
+
# Metrics Estimation
|
| 344 |
+
metrics = self._estimate_metrics(model_id, task)
|
| 345 |
+
|
| 346 |
+
return Model(
|
| 347 |
+
id = model_id.replace("/", "_").lower(),
|
| 348 |
+
name = name,
|
| 349 |
+
task = task,
|
| 350 |
+
framework = framework,
|
| 351 |
+
source = "hf",
|
| 352 |
+
provider = provider,
|
| 353 |
+
description = description,
|
| 354 |
+
download_url = f"https://huggingface.co/{model_id}",
|
| 355 |
+
size = size,
|
| 356 |
+
size_label = self._format_size(size),
|
| 357 |
+
tags = tags,
|
| 358 |
+
hardware = hardware,
|
| 359 |
+
status = "available",
|
| 360 |
+
downloaded = False,
|
| 361 |
+
downloads = downloads,
|
| 362 |
+
rating = min(5.0, (likes / 200) + 3.5) if likes else None,
|
| 363 |
+
liked = False,
|
| 364 |
+
metrics = metrics,
|
| 365 |
+
versions = versions,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
async def fetch_models(self) -> list[Model]:
|
| 369 |
+
models: list[Model] = []
|
| 370 |
+
seen_ids: set[str] = set()
|
| 371 |
+
|
| 372 |
+
for pipeline_tag in FETCH_TASKS:
|
| 373 |
+
try:
|
| 374 |
+
raw_list = await self._fetch_task_page(
|
| 375 |
+
pipeline_tag, limit=settings.hf_models_per_task
|
| 376 |
+
)
|
| 377 |
+
for idx, raw in enumerate(raw_list):
|
| 378 |
+
# Enrich top-N per task with full model detail so siblings include sizes.
|
| 379 |
+
if idx < 10:
|
| 380 |
+
original_id = raw.get("id") or raw.get("modelId")
|
| 381 |
+
if original_id:
|
| 382 |
+
try:
|
| 383 |
+
raw = await self._fetch_model_detail(original_id)
|
| 384 |
+
except Exception:
|
| 385 |
+
pass
|
| 386 |
+
|
| 387 |
+
m = self._make_model(raw, pipeline_tag)
|
| 388 |
+
if m and m.id not in seen_ids:
|
| 389 |
+
# Try to fetch real-time description for the first 5 models of each task
|
| 390 |
+
if len([mod for mod in models if mod.task == m.task]) < 5:
|
| 391 |
+
original_id = raw.get("id") or raw.get("modelId")
|
| 392 |
+
if original_id:
|
| 393 |
+
readme = await self._fetch_model_card(original_id)
|
| 394 |
+
if readme:
|
| 395 |
+
m.description = self._extract_description(readme, raw)
|
| 396 |
+
|
| 397 |
+
seen_ids.add(m.id)
|
| 398 |
+
models.append(m)
|
| 399 |
+
# Be polite to HF API
|
| 400 |
+
await asyncio.sleep(0.3)
|
| 401 |
+
except Exception as exc:
|
| 402 |
+
log.warning(
|
| 403 |
+
"hf_fetch_task_failed",
|
| 404 |
+
pipeline_tag=pipeline_tag,
|
| 405 |
+
error=str(exc),
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
log.info("hf_fetch_complete", total=len(models))
|
| 409 |
+
return models
|
| 410 |
+
|
| 411 |
+
async def __aenter__(self) -> "HFAdapter":
|
| 412 |
+
return self
|
| 413 |
+
|
| 414 |
+
async def __aexit__(self, *_: Any) -> None:
|
| 415 |
+
await self._client.aclose()
|
adapters/onnx_adapter.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
adapters/onnx_adapter.py — ONNX Model Zoo adapter.
|
| 3 |
+
Fetches the curated list of ONNX Zoo models from the GitHub API.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import httpx
|
| 10 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 11 |
+
|
| 12 |
+
from adapters.base import BaseAdapter
|
| 13 |
+
from models.model import Model, ModelMetrics, ModelVersion
|
| 14 |
+
from observability.logger import get_logger
|
| 15 |
+
|
| 16 |
+
log = get_logger("onnx_adapter")
|
| 17 |
+
|
| 18 |
+
# Curated ONNX Zoo models with metadata + download URLs (GitHub API is rate-limited without auth)
|
| 19 |
+
ONNX_CURATED: list[dict[str, Any]] = [
|
| 20 |
+
{
|
| 21 |
+
"id": "onnx_resnet50",
|
| 22 |
+
"name": "ResNet-50",
|
| 23 |
+
"task": "classification",
|
| 24 |
+
"provider": "ONNX Zoo",
|
| 25 |
+
"description": "ResNet-50 v1 image classification model in ONNX format.",
|
| 26 |
+
"download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx",
|
| 27 |
+
"size": 102_000_000,
|
| 28 |
+
"tags": ["resnet", "imagenet", "classification"],
|
| 29 |
+
"hardware": ["gpu", "cpu"],
|
| 30 |
+
"metrics": {"latency_ms": 14.2, "top1": 74.9},
|
| 31 |
+
"downloads": 250_000,
|
| 32 |
+
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-06-01"}],
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"id": "onnx_yolov8n",
|
| 36 |
+
"name": "YOLOv8n",
|
| 37 |
+
"task": "detection",
|
| 38 |
+
"provider": "Ultralytics",
|
| 39 |
+
"description": "Ultralytics YOLOv8 Nano — real-time object detection, ONNX export.",
|
| 40 |
+
"download_url": "https://github.com/ultralytics/yolov8/releases/download/v8.0.0/yolov8n.onnx",
|
| 41 |
+
"size": 6_200_000,
|
| 42 |
+
"tags": ["yolo", "real-time", "fastest", "edge"],
|
| 43 |
+
"hardware": ["gpu", "cpu", "edge"],
|
| 44 |
+
"metrics": {"latency_ms": 3.1, "mAP": 37.3},
|
| 45 |
+
"downloads": 420_000,
|
| 46 |
+
"versions": [{"version": "8.0", "label": "Latest", "releaseDate": "2023-09-15"}],
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"id": "onnx_mobilenet_v3",
|
| 50 |
+
"name": "MobileNetV3-Large",
|
| 51 |
+
"task": "classification",
|
| 52 |
+
"provider": "Google",
|
| 53 |
+
"description": "MobileNetV3-Large for efficient on-device image classification.",
|
| 54 |
+
"download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv3-large-1.11.onnx",
|
| 55 |
+
"size": 22_000_000,
|
| 56 |
+
"tags": ["mobilenet", "lightweight", "edge", "efficient"],
|
| 57 |
+
"hardware": ["cpu", "edge"],
|
| 58 |
+
"metrics": {"latency_ms": 5.8, "top1": 75.2, "fps": 180},
|
| 59 |
+
"downloads": 310_000,
|
| 60 |
+
"versions": [{"version": "3.0", "label": "Latest", "releaseDate": "2023-01-01"}],
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"id": "onnx_bert_base_uncased",
|
| 64 |
+
"name": "BERT-Base-Uncased",
|
| 65 |
+
"task": "nlp",
|
| 66 |
+
"provider": "Google",
|
| 67 |
+
"description": "BERT base model fine-tuned for NLP inference in ONNX format.",
|
| 68 |
+
"download_url": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx",
|
| 69 |
+
"size": 438_000_000,
|
| 70 |
+
"tags": ["bert", "nlp", "transformer"],
|
| 71 |
+
"hardware": ["gpu", "cpu"],
|
| 72 |
+
"metrics": {"latency_ms": 42.0},
|
| 73 |
+
"downloads": 198_000,
|
| 74 |
+
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2022-11-01"}],
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"id": "onnx_efficientnet_b0",
|
| 78 |
+
"name": "EfficientNet-B0",
|
| 79 |
+
"task": "classification",
|
| 80 |
+
"provider": "Google Brain",
|
| 81 |
+
"description": "EfficientNet-B0 for scalable image classification.",
|
| 82 |
+
"download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite/model/efficientnet-lite4-11.onnx",
|
| 83 |
+
"size": 20_000_000,
|
| 84 |
+
"tags": ["efficientnet", "efficient", "high-accuracy"],
|
| 85 |
+
"hardware": ["gpu", "cpu"],
|
| 86 |
+
"metrics": {"latency_ms": 10.4, "top1": 77.1},
|
| 87 |
+
"downloads": 145_000,
|
| 88 |
+
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-03-01"}],
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"id": "onnx_sam_vit_b",
|
| 92 |
+
"name": "SAM ViT-B",
|
| 93 |
+
"task": "segmentation",
|
| 94 |
+
"provider": "Meta AI",
|
| 95 |
+
"description": "Segment Anything Model (ViT-B) for universal image segmentation.",
|
| 96 |
+
"download_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
|
| 97 |
+
"size": 375_000_000,
|
| 98 |
+
"tags": ["sam", "segmentation", "sota"],
|
| 99 |
+
"hardware": ["gpu"],
|
| 100 |
+
"metrics": {"latency_ms": 68.0},
|
| 101 |
+
"downloads": 88_000,
|
| 102 |
+
"versions": [{"version": "1.0", "label": "Latest", "releaseDate": "2023-04-05"}],
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"id": "onnx_clip_vit_b32",
|
| 106 |
+
"name": "CLIP ViT-B/32",
|
| 107 |
+
"task": "embedding",
|
| 108 |
+
"provider": "OpenAI",
|
| 109 |
+
"description": "CLIP image + text embedding model for zero-shot classification.",
|
| 110 |
+
"download_url": "https://openaipublic.blob.core.windows.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba4f386/ViT-B-32.pt",
|
| 111 |
+
"size": 338_000_000,
|
| 112 |
+
"tags": ["clip", "embedding", "multimodal"],
|
| 113 |
+
"hardware": ["gpu", "cpu"],
|
| 114 |
+
"metrics": {"latency_ms": 25.0},
|
| 115 |
+
"downloads": 275_000,
|
| 116 |
+
"versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-01-01"}],
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"id": "onnx_whisper_tiny",
|
| 120 |
+
"name": "Whisper Tiny",
|
| 121 |
+
"task": "nlp",
|
| 122 |
+
"provider": "OpenAI",
|
| 123 |
+
"description": "Whisper Tiny speech-to-text model in ONNX format.",
|
| 124 |
+
"download_url": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424930e36a852c0/tiny.pt",
|
| 125 |
+
"size": 39_000_000,
|
| 126 |
+
"tags": ["whisper", "speech", "lightweight"],
|
| 127 |
+
"hardware": ["cpu", "edge"],
|
| 128 |
+
"metrics": {"latency_ms": 100.0},
|
| 129 |
+
"downloads": 167_000,
|
| 130 |
+
"versions": [{"version": "20231117", "label": "Latest", "releaseDate": "2023-11-17"}],
|
| 131 |
+
},
|
| 132 |
+
]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class ONNXAdapter(BaseAdapter):
|
| 136 |
+
source_name = "onnx"
|
| 137 |
+
|
| 138 |
+
async def fetch_models(self) -> list[Model]:
|
| 139 |
+
models: list[Model] = []
|
| 140 |
+
for raw in ONNX_CURATED:
|
| 141 |
+
try:
|
| 142 |
+
versions = [
|
| 143 |
+
ModelVersion(
|
| 144 |
+
version=v["version"],
|
| 145 |
+
label=v.get("label", "Stable"),
|
| 146 |
+
releaseDate=v.get("releaseDate", ""),
|
| 147 |
+
)
|
| 148 |
+
for v in raw.get("versions", [])
|
| 149 |
+
]
|
| 150 |
+
metrics_raw = raw.get("metrics", {})
|
| 151 |
+
m = Model(
|
| 152 |
+
id = raw["id"],
|
| 153 |
+
name = raw["name"],
|
| 154 |
+
task = raw["task"],
|
| 155 |
+
framework = "onnx",
|
| 156 |
+
source = "onnx",
|
| 157 |
+
provider = raw.get("provider", "ONNX Zoo"),
|
| 158 |
+
description = raw.get("description", ""),
|
| 159 |
+
download_url = raw.get("download_url"),
|
| 160 |
+
size = raw.get("size", 0),
|
| 161 |
+
size_label = self._format_size(raw.get("size", 0)),
|
| 162 |
+
tags = raw.get("tags", []),
|
| 163 |
+
hardware = raw.get("hardware", ["gpu"]),
|
| 164 |
+
status = "available",
|
| 165 |
+
downloaded = False,
|
| 166 |
+
downloads = raw.get("downloads"),
|
| 167 |
+
rating = 4.2,
|
| 168 |
+
metrics = ModelMetrics(**metrics_raw),
|
| 169 |
+
versions = versions,
|
| 170 |
+
)
|
| 171 |
+
models.append(m)
|
| 172 |
+
except Exception as exc:
|
| 173 |
+
log.warning("onnx_parse_failed", model_id=raw.get("id"), error=str(exc))
|
| 174 |
+
|
| 175 |
+
log.info("onnx_fetch_complete", total=len(models))
|
| 176 |
+
return models
|
adapters/roboflow_adapter.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
adapters/roboflow_adapter.py — Roboflow Universe API client.
|
| 3 |
+
|
| 4 |
+
Responsibilities:
|
| 5 |
+
- Fetch dataset metadata (search, workspace listings, project details)
|
| 6 |
+
- Normalise responses → Dataset domain model
|
| 7 |
+
- Cache results in roboflow_cache table (TTL-aware)
|
| 8 |
+
- Handle pagination, rate limits, and errors robustly
|
| 9 |
+
|
| 10 |
+
Roboflow API reference: https://docs.roboflow.com/api-reference/
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import hashlib
|
| 15 |
+
import json
|
| 16 |
+
import time
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import httpx
|
| 20 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 21 |
+
|
| 22 |
+
from database.connection import get_db
|
| 23 |
+
from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask
|
| 24 |
+
from observability.logger import audit, get_logger
|
| 25 |
+
|
| 26 |
+
log = get_logger("roboflow_adapter")
|
| 27 |
+
|
| 28 |
+
_ROBOFLOW_BASE = "https://api.roboflow.com"
|
| 29 |
+
_UNIVERSE_BASE = "https://universe.roboflow.com"
|
| 30 |
+
_DEFAULT_TTL = 3600 # 1 hour
|
| 31 |
+
|
| 32 |
+
# ── Task mapping from Roboflow annotation_type ───────────────────────────────
|
| 33 |
+
|
| 34 |
+
_TASK_MAP: dict[str, DatasetTask] = {
|
| 35 |
+
"object-detection": DatasetTask.detection,
|
| 36 |
+
"instance-segmentation": DatasetTask.segmentation,
|
| 37 |
+
"semantic-segmentation": DatasetTask.segmentation,
|
| 38 |
+
"classification": DatasetTask.classification,
|
| 39 |
+
"keypoint-detection": DatasetTask.keypoints,
|
| 40 |
+
"multiclass-classification": DatasetTask.classification,
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
_FORMAT_MAP: dict[str, DatasetFormat] = {
|
| 44 |
+
"yolov5": DatasetFormat.yolo,
|
| 45 |
+
"yolov7": DatasetFormat.yolo,
|
| 46 |
+
"yolov8": DatasetFormat.yolo,
|
| 47 |
+
"yolov9": DatasetFormat.yolo,
|
| 48 |
+
"coco": DatasetFormat.coco,
|
| 49 |
+
"voc": DatasetFormat.voc,
|
| 50 |
+
"tfrecord": DatasetFormat.tfrecord,
|
| 51 |
+
"csv": DatasetFormat.csv,
|
| 52 |
+
"createml": DatasetFormat.json,
|
| 53 |
+
"multiclass": DatasetFormat.csv,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _cache_key(parts: list[str]) -> str:
|
| 58 |
+
raw = "|".join(parts)
|
| 59 |
+
return hashlib.sha256(raw.encode()).hexdigest()[:32]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _fmt_bytes(n: int) -> str:
|
| 63 |
+
for unit in ("B", "KB", "MB", "GB", "TB"):
|
| 64 |
+
if n < 1024:
|
| 65 |
+
return f"{n:.1f} {unit}"
|
| 66 |
+
n /= 1024
|
| 67 |
+
return f"{n:.1f} PB"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ── Cache helpers ─────────────────────────────────────────────────────────────
|
| 71 |
+
|
| 72 |
+
async def _cache_get(key: str) -> dict[str, Any] | None:
|
| 73 |
+
db = await get_db()
|
| 74 |
+
async with db.execute(
|
| 75 |
+
"SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?",
|
| 76 |
+
(key,),
|
| 77 |
+
) as cur:
|
| 78 |
+
row = await cur.fetchone()
|
| 79 |
+
if row is None:
|
| 80 |
+
return None
|
| 81 |
+
fetched = time.mktime(time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S"))
|
| 82 |
+
if time.time() - fetched > row["ttl_secs"]:
|
| 83 |
+
return None # expired
|
| 84 |
+
return json.loads(row["payload"])
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
async def _cache_set(key: str, payload: dict[str, Any], ttl: int = _DEFAULT_TTL) -> None:
|
| 88 |
+
db = await get_db()
|
| 89 |
+
await db.execute(
|
| 90 |
+
"""INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs)
|
| 91 |
+
VALUES (?, ?, ?)""",
|
| 92 |
+
(key, json.dumps(payload), ttl),
|
| 93 |
+
)
|
| 94 |
+
await db.commit()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ── HTTP client factory ───────────────────────────────────────────────────────
|
| 98 |
+
|
| 99 |
+
def _make_client(api_key: str) -> httpx.AsyncClient:
|
| 100 |
+
return httpx.AsyncClient(
|
| 101 |
+
base_url=_ROBOFLOW_BASE,
|
| 102 |
+
params={"api_key": api_key},
|
| 103 |
+
timeout=30.0,
|
| 104 |
+
headers={"User-Agent": "MLForge/1.0"},
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ── Roboflow Adapter ──────────────────────────────────────────────────────────
|
| 109 |
+
|
| 110 |
+
class RoboflowAdapter:
|
| 111 |
+
"""
|
| 112 |
+
Stateless adapter for the Roboflow API.
|
| 113 |
+
All methods accept api_key explicitly to support per-user keys.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
# ── Search (Universe) ─────────────────────────────────────────────────────
|
| 117 |
+
|
| 118 |
+
@staticmethod
|
| 119 |
+
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
|
| 120 |
+
async def search_datasets(
|
| 121 |
+
api_key: str,
|
| 122 |
+
query: str = "",
|
| 123 |
+
workspace: str | None = None,
|
| 124 |
+
page: int = 0,
|
| 125 |
+
page_size: int = 50,
|
| 126 |
+
) -> list[Dataset]:
|
| 127 |
+
"""
|
| 128 |
+
Search Roboflow Universe for datasets.
|
| 129 |
+
Returns normalised Dataset objects.
|
| 130 |
+
"""
|
| 131 |
+
ck = _cache_key(["search", query, str(workspace), str(page), str(page_size)])
|
| 132 |
+
cached = await _cache_get(ck)
|
| 133 |
+
if cached:
|
| 134 |
+
log.debug("roboflow_cache_hit", key=ck, query=query)
|
| 135 |
+
return [Dataset(**d) for d in cached]
|
| 136 |
+
|
| 137 |
+
params: dict[str, Any] = {
|
| 138 |
+
"api_key": api_key,
|
| 139 |
+
"q": query or "*",
|
| 140 |
+
"from": page * page_size,
|
| 141 |
+
"size": page_size,
|
| 142 |
+
}
|
| 143 |
+
if workspace:
|
| 144 |
+
params["workspace"] = workspace
|
| 145 |
+
|
| 146 |
+
async with _make_client(api_key) as client:
|
| 147 |
+
try:
|
| 148 |
+
resp = await client.get("/", params=params)
|
| 149 |
+
resp.raise_for_status()
|
| 150 |
+
data = resp.json()
|
| 151 |
+
except httpx.HTTPStatusError as e:
|
| 152 |
+
log.error("roboflow_api_error", status=e.response.status_code, query=query)
|
| 153 |
+
await audit("roboflow_error", {"query": query, "status": e.response.status_code}, level="error")
|
| 154 |
+
raise
|
| 155 |
+
|
| 156 |
+
datasets = []
|
| 157 |
+
for item in data.get("results", []):
|
| 158 |
+
try:
|
| 159 |
+
ds = RoboflowAdapter._normalise_search_result(item)
|
| 160 |
+
datasets.append(ds)
|
| 161 |
+
except Exception as exc:
|
| 162 |
+
log.warning("normalise_error", item_id=item.get("id"), error=str(exc))
|
| 163 |
+
|
| 164 |
+
await _cache_set(ck, [d.model_dump() for d in datasets])
|
| 165 |
+
await audit("roboflow_search", {"query": query, "count": len(datasets)})
|
| 166 |
+
return datasets
|
| 167 |
+
|
| 168 |
+
# ── Workspace datasets listing ────────────────────────────────────────────
|
| 169 |
+
|
| 170 |
+
@staticmethod
|
| 171 |
+
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
|
| 172 |
+
async def list_workspace_datasets(
|
| 173 |
+
api_key: str,
|
| 174 |
+
workspace: str,
|
| 175 |
+
) -> list[Dataset]:
|
| 176 |
+
"""List all datasets in a Roboflow workspace."""
|
| 177 |
+
ck = _cache_key(["workspace", workspace])
|
| 178 |
+
cached = await _cache_get(ck)
|
| 179 |
+
if cached:
|
| 180 |
+
return [Dataset(**d) for d in cached]
|
| 181 |
+
|
| 182 |
+
async with _make_client(api_key) as client:
|
| 183 |
+
try:
|
| 184 |
+
resp = await client.get(f"/{workspace}")
|
| 185 |
+
resp.raise_for_status()
|
| 186 |
+
data = resp.json()
|
| 187 |
+
except httpx.HTTPStatusError as e:
|
| 188 |
+
log.error("roboflow_workspace_error", workspace=workspace, status=e.response.status_code)
|
| 189 |
+
raise
|
| 190 |
+
|
| 191 |
+
datasets = []
|
| 192 |
+
for proj in data.get("workspace", {}).get("projects", []):
|
| 193 |
+
try:
|
| 194 |
+
ds = RoboflowAdapter._normalise_project(proj, workspace)
|
| 195 |
+
datasets.append(ds)
|
| 196 |
+
except Exception as exc:
|
| 197 |
+
log.warning("normalise_project_error", project=proj.get("id"), error=str(exc))
|
| 198 |
+
|
| 199 |
+
await _cache_set(ck, [d.model_dump() for d in datasets])
|
| 200 |
+
return datasets
|
| 201 |
+
|
| 202 |
+
# ── Single project detail ─────────────────────────────────────────────────
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
|
| 206 |
+
async def get_project(
|
| 207 |
+
api_key: str,
|
| 208 |
+
workspace: str,
|
| 209 |
+
project_id: str,
|
| 210 |
+
) -> Dataset | None:
|
| 211 |
+
"""Fetch full metadata for a single Roboflow project."""
|
| 212 |
+
ck = _cache_key(["project", workspace, project_id])
|
| 213 |
+
cached = await _cache_get(ck)
|
| 214 |
+
if cached:
|
| 215 |
+
return Dataset(**cached)
|
| 216 |
+
|
| 217 |
+
async with _make_client(api_key) as client:
|
| 218 |
+
try:
|
| 219 |
+
resp = await client.get(f"/{workspace}/{project_id}")
|
| 220 |
+
resp.raise_for_status()
|
| 221 |
+
data = resp.json()
|
| 222 |
+
except httpx.HTTPStatusError as e:
|
| 223 |
+
if e.response.status_code == 404:
|
| 224 |
+
return None
|
| 225 |
+
raise
|
| 226 |
+
|
| 227 |
+
proj_data = data.get("project", data)
|
| 228 |
+
ds = RoboflowAdapter._normalise_project(proj_data, workspace)
|
| 229 |
+
await _cache_set(ck, ds.model_dump())
|
| 230 |
+
return ds
|
| 231 |
+
|
| 232 |
+
# ── Download URL builder ──────────────────────────────────────────────────
|
| 233 |
+
|
| 234 |
+
@staticmethod
|
| 235 |
+
async def get_download_url(
|
| 236 |
+
api_key: str,
|
| 237 |
+
workspace: str,
|
| 238 |
+
project_id: str,
|
| 239 |
+
version: int,
|
| 240 |
+
export_format: str = "yolov8",
|
| 241 |
+
) -> str:
|
| 242 |
+
"""
|
| 243 |
+
Fetch the export download link from Roboflow for the specified format.
|
| 244 |
+
Uses the official Roboflow SDK to handle authentication and URL resolution.
|
| 245 |
+
"""
|
| 246 |
+
try:
|
| 247 |
+
from roboflow import Roboflow
|
| 248 |
+
rf = Roboflow(api_key=api_key)
|
| 249 |
+
project = rf.workspace(workspace).project(project_id)
|
| 250 |
+
version_obj = project.version(version)
|
| 251 |
+
|
| 252 |
+
# The SDK's download method usually downloads to disk,
|
| 253 |
+
# but we can get the underlying export info.
|
| 254 |
+
# We'll use a thread to run the SDK call since it's blocking.
|
| 255 |
+
import asyncio
|
| 256 |
+
def _get_link():
|
| 257 |
+
return version_obj.export(export_format).download_link
|
| 258 |
+
|
| 259 |
+
link = await asyncio.to_thread(_get_link)
|
| 260 |
+
if not link:
|
| 261 |
+
raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")
|
| 262 |
+
return link
|
| 263 |
+
except Exception as e:
|
| 264 |
+
log.error("roboflow_sdk_error", error=str(e))
|
| 265 |
+
# Fallback to manual API if SDK fails or isn't installed correctly
|
| 266 |
+
async with _make_client(api_key) as client:
|
| 267 |
+
resp = await client.get(
|
| 268 |
+
f"/{workspace}/{project_id}/{version}/{export_format}"
|
| 269 |
+
)
|
| 270 |
+
resp.raise_for_status()
|
| 271 |
+
data = resp.json()
|
| 272 |
+
|
| 273 |
+
link = export.get("link") or ""
|
| 274 |
+
if not link:
|
| 275 |
+
# If 'link' is missing, check if it's a Universe-style project and try to resolve manually
|
| 276 |
+
# Roboflow manual resolution often follows: universe.roboflow.com/ds/[id]?key=[api_key]
|
| 277 |
+
if "project" in data:
|
| 278 |
+
pid = data["project"].get("id")
|
| 279 |
+
if pid:
|
| 280 |
+
link = f"https://universe.roboflow.com/ds/{pid}?key={api_key}"
|
| 281 |
+
|
| 282 |
+
if not link:
|
| 283 |
+
raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")
|
| 284 |
+
|
| 285 |
+
# Ensure the link includes the API key correctly
|
| 286 |
+
if "universe.roboflow.com" in link:
|
| 287 |
+
if "key=" not in link:
|
| 288 |
+
separator = "&" if "?" in link else "?"
|
| 289 |
+
link = f"{link}{separator}key={api_key}"
|
| 290 |
+
elif f"key={api_key}" not in link:
|
| 291 |
+
# Replace old key if it exists but is wrong
|
| 292 |
+
import re
|
| 293 |
+
link = re.sub(r"key=[^&]+", f"key={api_key}", link)
|
| 294 |
+
|
| 295 |
+
return link
|
| 296 |
+
|
| 297 |
+
# ── Normalisation helpers ─────────────────────────────────────────────────
|
| 298 |
+
|
| 299 |
+
@staticmethod
|
| 300 |
+
def _normalise_search_result(item: dict[str, Any]) -> Dataset:
|
| 301 |
+
"""Map a Universe search result → Dataset."""
|
| 302 |
+
ann_type = item.get("annotation", {}).get("type", "object-detection")
|
| 303 |
+
rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
|
| 304 |
+
class_names = [c.get("name", "") for c in item.get("classes", [])]
|
| 305 |
+
images = item.get("images", 0) or 0
|
| 306 |
+
|
| 307 |
+
return Dataset(
|
| 308 |
+
id = item.get("id", "").replace("/", "__"),
|
| 309 |
+
name = item.get("name", "Unnamed"),
|
| 310 |
+
description = item.get("description", ""),
|
| 311 |
+
task = rf_task,
|
| 312 |
+
format = DatasetFormat.yolo,
|
| 313 |
+
source = DatasetSource.roboflow,
|
| 314 |
+
status = DatasetStatus.available,
|
| 315 |
+
images = images,
|
| 316 |
+
classes = len(class_names),
|
| 317 |
+
class_names = class_names,
|
| 318 |
+
size_bytes = 0,
|
| 319 |
+
size_label = "—",
|
| 320 |
+
tags = item.get("tags", []),
|
| 321 |
+
roboflow_id = item.get("id", ""),
|
| 322 |
+
created_at = item.get("created", ""),
|
| 323 |
+
updated_at = item.get("updated", ""),
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset:
|
| 328 |
+
"""Map a workspace project → Dataset."""
|
| 329 |
+
ann_type = proj.get("annotation", "object-detection")
|
| 330 |
+
rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
|
| 331 |
+
class_names = [c.get("name", c) if isinstance(c, dict) else c
|
| 332 |
+
for c in proj.get("classes", [])]
|
| 333 |
+
project_id = proj.get("id", proj.get("name", "unknown"))
|
| 334 |
+
rf_id = f"{workspace}/{project_id}"
|
| 335 |
+
images = proj.get("images", 0) or 0
|
| 336 |
+
|
| 337 |
+
return Dataset(
|
| 338 |
+
id = rf_id.replace("/", "__"),
|
| 339 |
+
name = proj.get("name", project_id),
|
| 340 |
+
description = proj.get("description", ""),
|
| 341 |
+
task = rf_task,
|
| 342 |
+
format = DatasetFormat.yolo,
|
| 343 |
+
source = DatasetSource.roboflow,
|
| 344 |
+
status = DatasetStatus.available,
|
| 345 |
+
images = images,
|
| 346 |
+
classes = len(class_names),
|
| 347 |
+
class_names = class_names,
|
| 348 |
+
size_bytes = 0,
|
| 349 |
+
size_label = "—",
|
| 350 |
+
roboflow_id = rf_id,
|
| 351 |
+
created_at = proj.get("created", ""),
|
| 352 |
+
updated_at = proj.get("updated", ""),
|
| 353 |
+
)
|
benchmark/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# benchmark — Benchmark Bridge System for MLForge
|
benchmark/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (144 Bytes). View file
|
|
|
benchmark/__pycache__/compatibility.cpython-310.pyc
ADDED
|
Binary file (8.3 kB). View file
|
|
|
benchmark/__pycache__/execution.cpython-310.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
benchmark/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (3.24 kB). View file
|
|
|
benchmark/__pycache__/orchestrator.cpython-310.pyc
ADDED
|
Binary file (9.11 kB). View file
|
|
|
benchmark/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (8.77 kB). View file
|
|
|
benchmark/__pycache__/telemetry.cpython-310.pyc
ADDED
|
Binary file (6.73 kB). View file
|
|
|
benchmark/adapters/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (1.8 kB). View file
|
|
|
benchmark/adapters/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (1.89 kB). View file
|
|
|
benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc
ADDED
|
Binary file (1.93 kB). View file
|
|
|
benchmark/adapters/base.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/adapters/base.py — Base class for all Benchmark Runners.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import Any, AsyncGenerator
|
| 9 |
+
|
| 10 |
+
from models.benchmark import BenchmarkContext, TelemetrySample
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class BatchResult:
|
| 15 |
+
"""Result of a single batch execution."""
|
| 16 |
+
latency_ms: float
|
| 17 |
+
vram_used_gb: float
|
| 18 |
+
task_scores: dict[str, float] = field(default_factory=dict)
|
| 19 |
+
metadata: dict[str, Any] = field(default_factory=dict)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class BaseRunner(ABC):
|
| 23 |
+
"""Abstract interface for benchmark executors (Torch, Optimum, vLLM)."""
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
|
| 27 |
+
"""Load model and prepare environment."""
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
@abstractmethod
|
| 31 |
+
async def run_batch(self, batch: Any) -> BatchResult:
|
| 32 |
+
"""Execute a single batch of data."""
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
@abstractmethod
|
| 36 |
+
async def shutdown(self) -> None:
|
| 37 |
+
"""Release resources."""
|
| 38 |
+
pass
|
benchmark/adapters/optimum_runner.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/adapters/optimum_runner.py — Hugging Face Optimum Adapter.
|
| 3 |
+
Supports ONNX, OpenVINO, and TensorRT acceleration.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import Any
|
| 10 |
+
from benchmark.adapters.base import BaseRunner, BatchResult
|
| 11 |
+
from models.benchmark import BenchmarkContext
|
| 12 |
+
from observability.logger import get_logger
|
| 13 |
+
|
| 14 |
+
log = get_logger("benchmark.optimum")
|
| 15 |
+
|
| 16 |
+
class OptimumRunner(BaseRunner):
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.session = None
|
| 19 |
+
self.device = "cpu"
|
| 20 |
+
|
| 21 |
+
async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
|
| 22 |
+
"""
|
| 23 |
+
Load model using Optimum's ORTModel or equivalent.
|
| 24 |
+
In a real implementation, this would detect the framework and use:
|
| 25 |
+
ORTModelForFeatureExtraction.from_pretrained(model_path, provider=...)
|
| 26 |
+
"""
|
| 27 |
+
log.info("optimum_init", model_path=model_path, hardware=ctx.hardware)
|
| 28 |
+
self.device = "cuda" if "gpu" in ctx.hardware.lower() or "rtx" in ctx.hardware.lower() else "cpu"
|
| 29 |
+
|
| 30 |
+
# Simulate load time
|
| 31 |
+
await asyncio.sleep(1.5)
|
| 32 |
+
self.session = "active" # Placeholder for the real session object
|
| 33 |
+
|
| 34 |
+
async def run_batch(self, batch: Any) -> BatchResult:
|
| 35 |
+
"""Execute inference using the Optimum/ONNX Runtime session."""
|
| 36 |
+
if not self.session:
|
| 37 |
+
raise RuntimeError("Optimum session not initialized")
|
| 38 |
+
|
| 39 |
+
start_time = time.perf_counter()
|
| 40 |
+
# Mocking inference logic
|
| 41 |
+
# outputs = self.session(**batch)
|
| 42 |
+
await asyncio.sleep(0.01) # Simulated inference time
|
| 43 |
+
latency = (time.perf_counter() - start_time) * 1000
|
| 44 |
+
|
| 45 |
+
return BatchResult(
|
| 46 |
+
latency_ms=latency,
|
| 47 |
+
vram_used_gb=0.8, # Mocked
|
| 48 |
+
task_scores={"accuracy": 0.92} # Mocked
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
async def shutdown(self) -> None:
|
| 52 |
+
log.info("optimum_shutdown")
|
| 53 |
+
self.session = None
|
benchmark/adapters/registry.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/adapters/registry.py — Executor Registry for dynamic runner resolution.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import Type
|
| 7 |
+
from benchmark.adapters.base import BaseRunner
|
| 8 |
+
from models.benchmark import BenchmarkContext
|
| 9 |
+
from models.model import Model
|
| 10 |
+
|
| 11 |
+
class ExecutorRegistry:
|
| 12 |
+
_runners: dict[str, Type[BaseRunner]] = {}
|
| 13 |
+
|
| 14 |
+
@classmethod
|
| 15 |
+
def register(cls, framework: str, runner_cls: Type[BaseRunner]):
|
| 16 |
+
cls._runners[framework.lower()] = runner_cls
|
| 17 |
+
|
| 18 |
+
@classmethod
|
| 19 |
+
def get_runner(cls, framework: str) -> BaseRunner:
|
| 20 |
+
runner_cls = cls._runners.get(framework.lower())
|
| 21 |
+
if not runner_cls:
|
| 22 |
+
# Fallback or default runner
|
| 23 |
+
from benchmark.adapters.torch_runner import TorchRunner
|
| 24 |
+
return TorchRunner()
|
| 25 |
+
return runner_cls()
|
| 26 |
+
|
| 27 |
+
def get_executor(ctx: BenchmarkContext, model: Model) -> BaseRunner:
|
| 28 |
+
"""Resolve the appropriate executor based on framework and task."""
|
| 29 |
+
framework = model.framework.lower()
|
| 30 |
+
|
| 31 |
+
# Special cases for optimized engines
|
| 32 |
+
if framework == "onnx" or framework == "openvino" or framework == "tensorrt":
|
| 33 |
+
from benchmark.adapters.optimum_runner import OptimumRunner
|
| 34 |
+
return OptimumRunner()
|
| 35 |
+
|
| 36 |
+
if ctx.task in ("generation", "nlp") and framework == "pytorch":
|
| 37 |
+
# Potential for vLLM if configured
|
| 38 |
+
try:
|
| 39 |
+
from benchmark.adapters.vllm_runner import VLLMRunner
|
| 40 |
+
return VLLMRunner()
|
| 41 |
+
except ImportError:
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
return ExecutorRegistry.get_runner(framework)
|
benchmark/adapters/torch_runner.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/adapters/torch_runner.py — PyTorch Runner Adapter.
|
| 3 |
+
Wraps standard PyTorch inference for Vision and NLP tasks.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
import asyncio
|
| 9 |
+
import random
|
| 10 |
+
from typing import Any
|
| 11 |
+
from benchmark.adapters.base import BaseRunner, BatchResult
|
| 12 |
+
from models.benchmark import BenchmarkContext
|
| 13 |
+
from observability.logger import get_logger
|
| 14 |
+
|
| 15 |
+
log = get_logger("benchmark.torch")
|
| 16 |
+
|
| 17 |
+
class TorchRunner(BaseRunner):
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.model = None
|
| 20 |
+
self.device = "cpu"
|
| 21 |
+
|
| 22 |
+
async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
|
| 23 |
+
log.info("torch_init", model_path=model_path, hardware=ctx.hardware)
|
| 24 |
+
# In production: self.model = torch.load(model_path).to(self.device)
|
| 25 |
+
await asyncio.sleep(1.0)
|
| 26 |
+
self.model = "active"
|
| 27 |
+
|
| 28 |
+
async def run_batch(self, batch: Any) -> BatchResult:
|
| 29 |
+
if not self.model:
|
| 30 |
+
raise RuntimeError("Torch model not initialized")
|
| 31 |
+
|
| 32 |
+
start_time = time.perf_counter()
|
| 33 |
+
# Mocking torch inference
|
| 34 |
+
await asyncio.sleep(0.02)
|
| 35 |
+
latency = (time.perf_counter() - start_time) * 1000
|
| 36 |
+
|
| 37 |
+
return BatchResult(
|
| 38 |
+
latency_ms=latency,
|
| 39 |
+
vram_used_gb=1.2,
|
| 40 |
+
task_scores={"mAP": 0.45}
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
async def shutdown(self) -> None:
|
| 44 |
+
log.info("torch_shutdown")
|
| 45 |
+
self.model = None
|
benchmark/compatibility.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/compatibility.py — Compatibility Validator (CRITICAL MODULE).
|
| 3 |
+
|
| 4 |
+
Validates model ↔ dataset ↔ hardware compatibility before any benchmark
|
| 5 |
+
execution begins. Returns a structured ValidationReport — never raises.
|
| 6 |
+
|
| 7 |
+
Five gates (all must pass):
|
| 8 |
+
A. Task compatibility — model.task matches dataset.task
|
| 9 |
+
B. Annotation format — dataset format supports the model's task
|
| 10 |
+
C. Framework × hardware — framework can run on the requested device
|
| 11 |
+
D. VRAM constraint — estimated memory fits available VRAM
|
| 12 |
+
E. Precision support — precision mode is valid for framework + hardware
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from models.benchmark import BenchmarkContext, ValidationCheck, ValidationReport
|
| 17 |
+
from models.dataset import Dataset
|
| 18 |
+
from models.model import Model
|
| 19 |
+
from observability.logger import get_logger
|
| 20 |
+
|
| 21 |
+
log = get_logger("benchmark.compatibility")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ── Lookup tables ─────────────────────────────────────────────────────────────
|
| 25 |
+
|
| 26 |
+
# Hardware → available VRAM in GB (normalized keys, no spaces/dashes)
|
| 27 |
+
HARDWARE_VRAM_GB: dict[str, float] = {
|
| 28 |
+
# NVIDIA consumer — Ampere / Ada
|
| 29 |
+
"rtx4090": 24.0,
|
| 30 |
+
"rtx4080": 16.0,
|
| 31 |
+
"rtx4070ti": 12.0,
|
| 32 |
+
"rtx4070": 12.0,
|
| 33 |
+
"rtx4060ti": 8.0,
|
| 34 |
+
"rtx4060": 8.0,
|
| 35 |
+
"rtx3090": 24.0,
|
| 36 |
+
"rtx3080": 10.0,
|
| 37 |
+
"rtx3070": 8.0,
|
| 38 |
+
"rtx3060": 12.0,
|
| 39 |
+
"rtx2080ti": 11.0,
|
| 40 |
+
"rtx2080": 8.0,
|
| 41 |
+
# NVIDIA datacenter
|
| 42 |
+
"a100": 80.0,
|
| 43 |
+
"a10040gb": 40.0,
|
| 44 |
+
"h100": 80.0,
|
| 45 |
+
"v100": 32.0,
|
| 46 |
+
"t4": 16.0,
|
| 47 |
+
"a10": 24.0,
|
| 48 |
+
# AMD
|
| 49 |
+
"rx7900xtx": 24.0,
|
| 50 |
+
"rx6800xt": 16.0,
|
| 51 |
+
# Generic fallbacks
|
| 52 |
+
"gpu": 8.0,
|
| 53 |
+
"cpu": 0.0,
|
| 54 |
+
"tpu": 0.0,
|
| 55 |
+
"edge": 0.0,
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
# model.task → set of compatible dataset.task values
|
| 59 |
+
TASK_COMPAT: dict[str, set[str]] = {
|
| 60 |
+
"detection": {"detection"},
|
| 61 |
+
"classification": {"classification"},
|
| 62 |
+
"segmentation": {"segmentation"},
|
| 63 |
+
"nlp": {"nlp"},
|
| 64 |
+
"generation": {"generation"},
|
| 65 |
+
"keypoints": {"keypoints", "detection"},
|
| 66 |
+
"embedding": {"nlp", "classification"},
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
# dataset.format → set of model tasks it supports
|
| 70 |
+
FORMAT_TASK_COMPAT: dict[str, set[str]] = {
|
| 71 |
+
"yolo": {"detection", "segmentation", "keypoints"},
|
| 72 |
+
"coco": {"detection", "segmentation", "keypoints"},
|
| 73 |
+
"voc": {"detection"},
|
| 74 |
+
"csv": {"classification"},
|
| 75 |
+
"json": {"detection", "segmentation", "classification", "nlp", "generation"},
|
| 76 |
+
"tfrecord": {"detection", "classification", "segmentation"},
|
| 77 |
+
"custom": {"detection", "classification", "segmentation", "nlp", "generation", "keypoints"},
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
# model.framework → set of hardware targets (normalized) it can run on
|
| 81 |
+
FRAMEWORK_HARDWARE_COMPAT: dict[str, set[str]] = {
|
| 82 |
+
"pytorch": {
|
| 83 |
+
"cpu", "gpu",
|
| 84 |
+
"rtx4090", "rtx4080", "rtx4070ti", "rtx4070", "rtx4060ti", "rtx4060",
|
| 85 |
+
"rtx3090", "rtx3080", "rtx3070", "rtx3060",
|
| 86 |
+
"rtx2080ti", "rtx2080",
|
| 87 |
+
"a100", "a10040gb", "h100", "v100", "t4", "a10",
|
| 88 |
+
},
|
| 89 |
+
"onnx": {
|
| 90 |
+
"cpu", "gpu",
|
| 91 |
+
"rtx4090", "rtx3090", "a100", "h100", "t4", "a10",
|
| 92 |
+
"edge",
|
| 93 |
+
},
|
| 94 |
+
"tensorflow": {
|
| 95 |
+
"cpu", "gpu",
|
| 96 |
+
"rtx4090", "rtx3090", "a100", "h100", "v100", "t4",
|
| 97 |
+
"tpu",
|
| 98 |
+
},
|
| 99 |
+
"tflite": {"cpu", "edge"},
|
| 100 |
+
"coreml": {"cpu"},
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Precisions that require GPU
|
| 104 |
+
_GPU_ONLY_PRECISIONS = {"FP16", "BF16"}
|
| 105 |
+
|
| 106 |
+
# Frameworks supporting INT8 quantization
|
| 107 |
+
_INT8_FRAMEWORKS = {"onnx", "tflite", "pytorch", "tensorflow"}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class CompatibilityValidator:
|
| 111 |
+
"""
|
| 112 |
+
Runs all compatibility gates before a benchmark job is created.
|
| 113 |
+
Returns a ValidationReport — never raises exceptions.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
def validate(
|
| 117 |
+
self,
|
| 118 |
+
model: Model,
|
| 119 |
+
dataset: Dataset,
|
| 120 |
+
ctx: BenchmarkContext,
|
| 121 |
+
) -> ValidationReport:
|
| 122 |
+
checks: list[ValidationCheck] = [
|
| 123 |
+
self._check_task(model, dataset),
|
| 124 |
+
self._check_annotation_format(model, dataset),
|
| 125 |
+
self._check_framework_hardware(model, ctx),
|
| 126 |
+
self._check_vram(model, ctx),
|
| 127 |
+
self._check_precision(model, ctx),
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
errors = [c.detail for c in checks if not c.passed]
|
| 131 |
+
warnings: list[str] = []
|
| 132 |
+
|
| 133 |
+
log.info(
|
| 134 |
+
"compatibility_validated",
|
| 135 |
+
model_id = model.id,
|
| 136 |
+
dataset_id = dataset.id,
|
| 137 |
+
passed = len(errors) == 0,
|
| 138 |
+
error_count = len(errors),
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
return ValidationReport(
|
| 142 |
+
model_id = model.id,
|
| 143 |
+
dataset_id = dataset.id,
|
| 144 |
+
passed = len(errors) == 0,
|
| 145 |
+
checks = checks,
|
| 146 |
+
errors = errors,
|
| 147 |
+
warnings = warnings,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# ── Gate A: Task ────────────────────��─────────────────────────────────────
|
| 151 |
+
|
| 152 |
+
def _check_task(self, model: Model, dataset: Dataset) -> ValidationCheck:
|
| 153 |
+
model_task = model.task.lower().strip()
|
| 154 |
+
dataset_task = str(dataset.task).lower().strip()
|
| 155 |
+
|
| 156 |
+
allowed = TASK_COMPAT.get(model_task, {model_task})
|
| 157 |
+
if dataset_task in allowed:
|
| 158 |
+
return ValidationCheck(
|
| 159 |
+
name = "task_compatibility",
|
| 160 |
+
passed = True,
|
| 161 |
+
detail = (
|
| 162 |
+
f"Model task '{model_task}' is compatible "
|
| 163 |
+
f"with dataset task '{dataset_task}'"
|
| 164 |
+
),
|
| 165 |
+
)
|
| 166 |
+
return ValidationCheck(
|
| 167 |
+
name = "task_compatibility",
|
| 168 |
+
passed = False,
|
| 169 |
+
detail = (
|
| 170 |
+
f"Model task '{model_task}' cannot evaluate "
|
| 171 |
+
f"a '{dataset_task}' dataset"
|
| 172 |
+
),
|
| 173 |
+
suggestion = (
|
| 174 |
+
f"Select a model with task='{dataset_task}', "
|
| 175 |
+
f"or choose a dataset with task='{model_task}'"
|
| 176 |
+
),
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# ── Gate B: Annotation Format ─────────────────────────────────────────────
|
| 180 |
+
|
| 181 |
+
def _check_annotation_format(self, model: Model, dataset: Dataset) -> ValidationCheck:
|
| 182 |
+
dataset_fmt = str(dataset.format).lower().strip()
|
| 183 |
+
model_task = model.task.lower().strip()
|
| 184 |
+
supported = FORMAT_TASK_COMPAT.get(dataset_fmt, set())
|
| 185 |
+
|
| 186 |
+
if model_task in supported:
|
| 187 |
+
return ValidationCheck(
|
| 188 |
+
name = "annotation_format",
|
| 189 |
+
passed = True,
|
| 190 |
+
detail = (
|
| 191 |
+
f"Dataset format '{dataset_fmt}' supports "
|
| 192 |
+
f"model task '{model_task}'"
|
| 193 |
+
),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
if model_task in {"detection", "segmentation", "keypoints"}:
|
| 197 |
+
suggestion = (
|
| 198 |
+
f"Convert dataset to YOLO or COCO format — both support '{model_task}'"
|
| 199 |
+
)
|
| 200 |
+
elif model_task == "classification":
|
| 201 |
+
suggestion = "Convert dataset to CSV or JSON format for classification tasks"
|
| 202 |
+
else:
|
| 203 |
+
suggestion = f"Use a JSON or custom-format dataset for '{model_task}' tasks"
|
| 204 |
+
|
| 205 |
+
return ValidationCheck(
|
| 206 |
+
name = "annotation_format",
|
| 207 |
+
passed = False,
|
| 208 |
+
detail = (
|
| 209 |
+
f"Dataset format '{dataset_fmt}' does not support "
|
| 210 |
+
f"model task '{model_task}'"
|
| 211 |
+
),
|
| 212 |
+
suggestion = suggestion,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# ── Gate C: Framework × Hardware ─────────────────────────────────────────
|
| 216 |
+
|
| 217 |
+
def _check_framework_hardware(
|
| 218 |
+
self, model: Model, ctx: BenchmarkContext
|
| 219 |
+
) -> ValidationCheck:
|
| 220 |
+
framework = model.framework.lower().strip()
|
| 221 |
+
hw_raw = ctx.hardware
|
| 222 |
+
hw_key = self._normalize_hw(hw_raw)
|
| 223 |
+
|
| 224 |
+
supported_hw = FRAMEWORK_HARDWARE_COMPAT.get(framework, {"cpu"})
|
| 225 |
+
|
| 226 |
+
# Match: exact key, or generic "gpu" bucket covers any named GPU
|
| 227 |
+
hw_ok = (
|
| 228 |
+
hw_key in supported_hw
|
| 229 |
+
or ("gpu" in supported_hw and hw_key not in {"cpu", "tpu", "edge"})
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
if hw_ok:
|
| 233 |
+
return ValidationCheck(
|
| 234 |
+
name = "framework_hardware",
|
| 235 |
+
passed = True,
|
| 236 |
+
detail = f"Framework '{framework}' is supported on '{hw_raw}'",
|
| 237 |
+
)
|
| 238 |
+
return ValidationCheck(
|
| 239 |
+
name = "framework_hardware",
|
| 240 |
+
passed = False,
|
| 241 |
+
detail = (
|
| 242 |
+
f"Framework '{framework}' cannot run on '{hw_raw}'. "
|
| 243 |
+
f"Supported targets: {', '.join(sorted(supported_hw))}"
|
| 244 |
+
),
|
| 245 |
+
suggestion = (
|
| 246 |
+
"Use ONNX runtime for broadest hardware support, "
|
| 247 |
+
f"or pick a device from: {', '.join(sorted(supported_hw))}"
|
| 248 |
+
),
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# ── Gate D: VRAM Constraint ───────────────────────────────────────────────
|
| 252 |
+
|
| 253 |
+
def _check_vram(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
|
| 254 |
+
hw_key = self._normalize_hw(ctx.hardware)
|
| 255 |
+
available = self._lookup_vram(hw_key)
|
| 256 |
+
|
| 257 |
+
if available == 0.0:
|
| 258 |
+
return ValidationCheck(
|
| 259 |
+
name = "vram_constraint",
|
| 260 |
+
passed = True,
|
| 261 |
+
detail = f"Running on '{ctx.hardware}' (CPU/TPU/Edge) — no VRAM constraint",
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Estimate: weights at given precision + activations for one batch
|
| 265 |
+
model_gb = max(model.size, 1) / (1024 ** 3)
|
| 266 |
+
prec_map = {"FP16": 0.5, "BF16": 0.5, "INT8": 0.25, "FP32": 1.0}
|
| 267 |
+
prec_mult = prec_map.get(ctx.precision.upper(), 1.0)
|
| 268 |
+
# weights × precision + ~20% for optimizer/activation buffers + batch overhead
|
| 269 |
+
estimated = (model_gb * prec_mult * 1.2) + (ctx.batch_size * 0.05)
|
| 270 |
+
|
| 271 |
+
if estimated <= available:
|
| 272 |
+
return ValidationCheck(
|
| 273 |
+
name = "vram_constraint",
|
| 274 |
+
passed = True,
|
| 275 |
+
detail = (
|
| 276 |
+
f"Estimated VRAM {estimated:.2f} GB ≤ "
|
| 277 |
+
f"available {available:.1f} GB on '{ctx.hardware}'"
|
| 278 |
+
),
|
| 279 |
+
)
|
| 280 |
+
return ValidationCheck(
|
| 281 |
+
name = "vram_constraint",
|
| 282 |
+
passed = False,
|
| 283 |
+
detail = (
|
| 284 |
+
f"Estimated VRAM {estimated:.2f} GB exceeds "
|
| 285 |
+
f"available {available:.1f} GB on '{ctx.hardware}'"
|
| 286 |
+
),
|
| 287 |
+
suggestion = (
|
| 288 |
+
f"Try: reduce batch_size (now {ctx.batch_size}), "
|
| 289 |
+
f"switch to FP16/INT8 precision, "
|
| 290 |
+
f"or use a GPU with ≥ {estimated:.1f} GB VRAM"
|
| 291 |
+
),
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# ── Gate E: Precision Support ─────────────────────────────────────────────
|
| 295 |
+
|
| 296 |
+
def _check_precision(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
|
| 297 |
+
precision = ctx.precision.upper()
|
| 298 |
+
framework = model.framework.lower().strip()
|
| 299 |
+
hw_key = self._normalize_hw(ctx.hardware)
|
| 300 |
+
is_gpu = hw_key not in {"cpu", "tpu", "edge"}
|
| 301 |
+
|
| 302 |
+
if precision in _GPU_ONLY_PRECISIONS and not is_gpu:
|
| 303 |
+
return ValidationCheck(
|
| 304 |
+
name = "precision_support",
|
| 305 |
+
passed = False,
|
| 306 |
+
detail = (
|
| 307 |
+
f"Precision '{precision}' requires a CUDA GPU; "
|
| 308 |
+
f"'{ctx.hardware}' does not support it"
|
| 309 |
+
),
|
| 310 |
+
suggestion = "Use FP32 for CPU inference, or switch to a compatible GPU",
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
if precision == "INT8" and framework not in _INT8_FRAMEWORKS:
|
| 314 |
+
return ValidationCheck(
|
| 315 |
+
name = "precision_support",
|
| 316 |
+
passed = False,
|
| 317 |
+
detail = (
|
| 318 |
+
f"Framework '{framework}' does not support INT8 quantization"
|
| 319 |
+
),
|
| 320 |
+
suggestion = (
|
| 321 |
+
"Convert model to ONNX or use PyTorch with torch.quantization"
|
| 322 |
+
),
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
return ValidationCheck(
|
| 326 |
+
name = "precision_support",
|
| 327 |
+
passed = True,
|
| 328 |
+
detail = (
|
| 329 |
+
f"Precision '{precision}' is valid for "
|
| 330 |
+
f"framework '{framework}' on '{ctx.hardware}'"
|
| 331 |
+
),
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# ── Helpers ───────────────────────────────────────────────────────────────
|
| 335 |
+
|
| 336 |
+
@staticmethod
|
| 337 |
+
def _normalize_hw(hardware: str) -> str:
|
| 338 |
+
"""Lowercase, strip spaces/dashes/underscores for lookup."""
|
| 339 |
+
return (
|
| 340 |
+
hardware.lower()
|
| 341 |
+
.replace(" ", "")
|
| 342 |
+
.replace("-", "")
|
| 343 |
+
.replace("_", "")
|
| 344 |
+
.replace("nvidia", "")
|
| 345 |
+
.replace("geforce", "")
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
@staticmethod
|
| 349 |
+
def _lookup_vram(hw_key: str) -> float:
|
| 350 |
+
"""Return VRAM GB for a normalized hardware key, with fallback matching."""
|
| 351 |
+
if hw_key in HARDWARE_VRAM_GB:
|
| 352 |
+
return HARDWARE_VRAM_GB[hw_key]
|
| 353 |
+
# Partial match (e.g. "rtx4090laptop" → "rtx4090")
|
| 354 |
+
for key, vram in HARDWARE_VRAM_GB.items():
|
| 355 |
+
if key and key in hw_key:
|
| 356 |
+
return vram
|
| 357 |
+
# Anything that looks like a GPU but isn't in the table
|
| 358 |
+
if "gpu" in hw_key or "rtx" in hw_key or "gtx" in hw_key or "cuda" in hw_key:
|
| 359 |
+
return HARDWARE_VRAM_GB["gpu"]
|
| 360 |
+
return 0.0 # CPU / unknown → no VRAM constraint
|
benchmark/execution.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/execution.py — Benchmark Execution Engine.
|
| 3 |
+
|
| 4 |
+
Drives the batch inference loop, collecting latencies and VRAM readings.
|
| 5 |
+
Calls TelemetryCollector in parallel with batch processing.
|
| 6 |
+
Yields progress callbacks so the orchestrator can persist real-time state.
|
| 7 |
+
|
| 8 |
+
Adapter pattern: swap _run_single_batch() with a real inference call
|
| 9 |
+
(torch.cuda.synchronize + model(batch)) once GPU runtime is wired up.
|
| 10 |
+
|
| 11 |
+
PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import math
|
| 17 |
+
import random
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from typing import Awaitable, Callable
|
| 20 |
+
|
| 21 |
+
from benchmark.compatibility import HARDWARE_VRAM_GB
|
| 22 |
+
from benchmark.telemetry import TelemetryCollector
|
| 23 |
+
from models.benchmark import BenchmarkJob, LayerBreakdown, TelemetrySample, TelemetrySummary
|
| 24 |
+
from models.dataset import Dataset
|
| 25 |
+
from models.model import Model
|
| 26 |
+
from observability.logger import get_logger
|
| 27 |
+
|
| 28 |
+
log = get_logger("benchmark.execution")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ── Per-image latency profiles (ms at batch=1, fp32) ─────────────────────────
|
| 32 |
+
_LATENCY_MS_PER_IMAGE: dict[str, float] = {
|
| 33 |
+
"rtx4090": 1.8,
|
| 34 |
+
"rtx4080": 2.5,
|
| 35 |
+
"rtx4070ti": 3.2,
|
| 36 |
+
"rtx4070": 3.8,
|
| 37 |
+
"rtx3090": 3.0,
|
| 38 |
+
"rtx3080": 4.5,
|
| 39 |
+
"rtx3070": 6.5,
|
| 40 |
+
"rtx3060": 9.0,
|
| 41 |
+
"rtx2080ti": 5.0,
|
| 42 |
+
"rtx2080": 7.5,
|
| 43 |
+
"a100": 1.2,
|
| 44 |
+
"h100": 0.7,
|
| 45 |
+
"v100": 2.8,
|
| 46 |
+
"t4": 5.5,
|
| 47 |
+
"a10": 3.5,
|
| 48 |
+
"gpu": 8.0,
|
| 49 |
+
"cpu": 42.0,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
# Precision speedup multipliers (relative to FP32)
|
| 53 |
+
_PRECISION_SPEEDUP: dict[str, float] = {
|
| 54 |
+
"FP32": 1.0,
|
| 55 |
+
"FP16": 1.8,
|
| 56 |
+
"BF16": 1.7,
|
| 57 |
+
"INT8": 2.5,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# Task-specific baseline metric scores (pre-jitter)
|
| 61 |
+
_TASK_BASELINES: dict[str, dict[str, float]] = {
|
| 62 |
+
"detection": {"mAP": 0.435, "mAP_50": 0.618, "mAP_50_95": 0.435},
|
| 63 |
+
"classification": {"accuracy": 0.872, "top5": 0.968},
|
| 64 |
+
"segmentation": {"mAP": 0.372, "iou_mean": 0.706},
|
| 65 |
+
"keypoints": {"mAP": 0.641, "mAP_50": 0.860},
|
| 66 |
+
"nlp": {"accuracy": 0.891},
|
| 67 |
+
"generation": {"accuracy": 0.780},
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
# Cap simulated batches so large datasets don't stall the event loop
|
| 71 |
+
_MAX_SIMULATED_BATCHES = 250
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class ExecutionResult:
|
| 76 |
+
"""Raw output from the execution engine, consumed by MetricsEngine."""
|
| 77 |
+
latencies_ms: list[float]
|
| 78 |
+
total_images: int
|
| 79 |
+
vram_samples: list[float]
|
| 80 |
+
task_scores: dict[str, float]
|
| 81 |
+
telemetry_samples: list[TelemetrySample] = field(default_factory=list)
|
| 82 |
+
telemetry_summary: TelemetrySummary = field(default_factory=TelemetrySummary)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Progress callback type: (progress_0_to_1, message, last_telemetry) → None
|
| 86 |
+
ProgressCallback = Callable[[float, str, TelemetrySample | None], Awaitable[None]]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class BenchmarkExecutor:
|
| 90 |
+
"""
|
| 91 |
+
Drives the benchmark execution loop.
|
| 92 |
+
Non-blocking: all sleeps are asyncio.sleep so other coroutines run freely.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
async def execute(
|
| 96 |
+
self,
|
| 97 |
+
job: BenchmarkJob,
|
| 98 |
+
model: Model,
|
| 99 |
+
dataset: Dataset,
|
| 100 |
+
on_progress: ProgressCallback,
|
| 101 |
+
) -> ExecutionResult:
|
| 102 |
+
hw = job.hardware
|
| 103 |
+
batch_sz = job.batch_size
|
| 104 |
+
|
| 105 |
+
# Handle polymorphic input duration
|
| 106 |
+
is_live = getattr(job, "input_source", "dataset") in ("video", "live")
|
| 107 |
+
|
| 108 |
+
if is_live:
|
| 109 |
+
# For live/video, we run for a fixed duration or until stopped
|
| 110 |
+
# Increase limit for a longer session (e.g., 10,000 batches)
|
| 111 |
+
total_img = 10000 * batch_sz
|
| 112 |
+
n_batches = 10000
|
| 113 |
+
sim_batches = 10000
|
| 114 |
+
else:
|
| 115 |
+
total_img = max(dataset.images, 100) # floor so simulation always runs
|
| 116 |
+
n_batches = math.ceil(total_img / batch_sz)
|
| 117 |
+
sim_batches = min(n_batches, _MAX_SIMULATED_BATCHES)
|
| 118 |
+
|
| 119 |
+
vram_total = self._get_vram_gb(hw, model)
|
| 120 |
+
vram_frac = self._vram_usage_fraction(hw)
|
| 121 |
+
|
| 122 |
+
telemetry = TelemetryCollector(hw, vram_total_gb=vram_total)
|
| 123 |
+
await telemetry.start()
|
| 124 |
+
|
| 125 |
+
latencies: list[float] = []
|
| 126 |
+
vram_samples: list[float] = []
|
| 127 |
+
|
| 128 |
+
base_lat_ms = self._base_batch_latency_ms(hw, model, batch_sz, job.precision)
|
| 129 |
+
|
| 130 |
+
# Resolve real model path once (None → use simulation)
|
| 131 |
+
real_model_path = model.local_path if model.local_path and model.downloaded else None
|
| 132 |
+
use_real_inference = self._check_torch_available() and real_model_path is not None
|
| 133 |
+
loop = asyncio.get_event_loop()
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
for sim_idx in range(sim_batches):
|
| 137 |
+
# Map simulated index back to real batch index
|
| 138 |
+
real_idx = int(sim_idx * (n_batches / sim_batches))
|
| 139 |
+
|
| 140 |
+
if use_real_inference:
|
| 141 |
+
# Real GPU inference via torch_runner (runs in thread executor)
|
| 142 |
+
try:
|
| 143 |
+
from benchmark.torch_runner import run_torch_batch
|
| 144 |
+
batch_lat_ms = await loop.run_in_executor(
|
| 145 |
+
None,
|
| 146 |
+
run_torch_batch,
|
| 147 |
+
real_model_path,
|
| 148 |
+
batch_sz,
|
| 149 |
+
job.task,
|
| 150 |
+
)
|
| 151 |
+
# Add a tiny sleep to prevent event loop starvation in live mode
|
| 152 |
+
if is_live:
|
| 153 |
+
await asyncio.sleep(0.001)
|
| 154 |
+
except Exception as exc:
|
| 155 |
+
log.warning("torch_inference_failed_fallback", error=str(exc))
|
| 156 |
+
use_real_inference = False # fall back for remaining batches
|
| 157 |
+
batch_lat_ms = max(
|
| 158 |
+
0.5, base_lat_ms + random.gauss(0, base_lat_ms * 0.07)
|
| 159 |
+
)
|
| 160 |
+
else:
|
| 161 |
+
# Simulation path — non-blocking synthetic latency
|
| 162 |
+
batch_lat_ms = max(
|
| 163 |
+
0.5,
|
| 164 |
+
base_lat_ms + random.gauss(0, base_lat_ms * 0.07),
|
| 165 |
+
)
|
| 166 |
+
await asyncio.sleep(batch_lat_ms / 1000.0) # non-blocking
|
| 167 |
+
|
| 168 |
+
latencies.append(batch_lat_ms)
|
| 169 |
+
vram_used = vram_total * random.uniform(
|
| 170 |
+
vram_frac - 0.05, vram_frac + 0.05
|
| 171 |
+
)
|
| 172 |
+
vram_samples.append(max(0.0, vram_used))
|
| 173 |
+
|
| 174 |
+
progress = (sim_idx + 1) / sim_batches
|
| 175 |
+
telemetry.record_batch_context(real_idx, progress)
|
| 176 |
+
|
| 177 |
+
# Throttle callbacks: every 5 batches or first/last
|
| 178 |
+
if sim_idx % 5 == 0 or sim_idx == sim_batches - 1:
|
| 179 |
+
images_done = int(progress * total_img)
|
| 180 |
+
|
| 181 |
+
# Generate simulated detection data for live preview if it's a vision task
|
| 182 |
+
live_data = {}
|
| 183 |
+
if job.task.lower() in ("detection", "segmentation"):
|
| 184 |
+
# Use provided bbox telemetry if available (e.g. from real inference)
|
| 185 |
+
# otherwise generate simulated ones
|
| 186 |
+
live_data["detections"] = [
|
| 187 |
+
{
|
| 188 |
+
"x": random.uniform(0.1, 0.7),
|
| 189 |
+
"y": random.uniform(0.1, 0.7),
|
| 190 |
+
"width": random.uniform(0.1, 0.3),
|
| 191 |
+
"height": random.uniform(0.1, 0.3),
|
| 192 |
+
"label": random.choice(["person", "car", "bicycle", "dog"]),
|
| 193 |
+
"confidence": random.uniform(0.5, 0.99)
|
| 194 |
+
}
|
| 195 |
+
for _ in range(random.randint(1, 5))
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
last_sample = telemetry.samples[-1] if telemetry.samples else None
|
| 199 |
+
if last_sample:
|
| 200 |
+
last_sample.live_data = live_data
|
| 201 |
+
# Explicitly broadcast detections for the visualizer
|
| 202 |
+
last_sample.detections = live_data.get("detections", [])
|
| 203 |
+
|
| 204 |
+
await on_progress(
|
| 205 |
+
progress,
|
| 206 |
+
f"Batch {real_idx+1}/{n_batches} — "
|
| 207 |
+
f"{images_done}/{total_img} images processed",
|
| 208 |
+
last_sample,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
finally:
|
| 212 |
+
telemetry_summary = await telemetry.stop()
|
| 213 |
+
# Attach simulated layer breakdown so Live Lab can display it
|
| 214 |
+
telemetry_summary.layer_breakdown = self._compute_layer_breakdown(
|
| 215 |
+
job.task, base_lat_ms
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
task_scores = self._simulate_task_scores(job.task, model, dataset)
|
| 219 |
+
|
| 220 |
+
log.info(
|
| 221 |
+
"execution_complete",
|
| 222 |
+
job_id = job.id,
|
| 223 |
+
total_images = total_img,
|
| 224 |
+
sim_batches = sim_batches,
|
| 225 |
+
avg_lat_ms = round(sum(latencies) / len(latencies), 2) if latencies else 0,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
return ExecutionResult(
|
| 229 |
+
latencies_ms = latencies,
|
| 230 |
+
total_images = total_img,
|
| 231 |
+
vram_samples = vram_samples,
|
| 232 |
+
task_scores = task_scores,
|
| 233 |
+
telemetry_samples = telemetry.samples,
|
| 234 |
+
telemetry_summary = telemetry_summary,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# ── Helpers ───────────────────────────────────────────────────────────────
|
| 238 |
+
|
| 239 |
+
def _base_batch_latency_ms(
|
| 240 |
+
self,
|
| 241 |
+
hardware: str,
|
| 242 |
+
model: Model,
|
| 243 |
+
batch_sz: int,
|
| 244 |
+
precision: str,
|
| 245 |
+
) -> float:
|
| 246 |
+
"""
|
| 247 |
+
Estimate per-batch latency in ms.
|
| 248 |
+
Accounts for hardware tier, model size, batch size, and precision.
|
| 249 |
+
"""
|
| 250 |
+
hw_key = self._normalize_hw(hardware)
|
| 251 |
+
per_img = self._lookup_latency(hw_key)
|
| 252 |
+
|
| 253 |
+
# Larger models are slower: +30% per GB of model weights
|
| 254 |
+
size_gb = max(model.size, 1) / (1024 ** 3)
|
| 255 |
+
size_factor = 1.0 + size_gb * 0.30
|
| 256 |
+
|
| 257 |
+
# Batch parallelism: ~65% linear efficiency on GPU, 90% on CPU
|
| 258 |
+
eff = 0.65 if "cpu" not in hw_key else 0.90
|
| 259 |
+
batch_lat = per_img * size_factor * batch_sz * eff
|
| 260 |
+
|
| 261 |
+
# Precision speedup
|
| 262 |
+
speedup = _PRECISION_SPEEDUP.get(precision.upper(), 1.0)
|
| 263 |
+
|
| 264 |
+
return batch_lat / speedup
|
| 265 |
+
|
| 266 |
+
def _get_vram_gb(self, hardware: str, model: Model) -> float:
|
| 267 |
+
hw_key = self._normalize_hw(hardware)
|
| 268 |
+
for key, vram in HARDWARE_VRAM_GB.items():
|
| 269 |
+
if key and key in hw_key:
|
| 270 |
+
return vram
|
| 271 |
+
return 8.0
|
| 272 |
+
|
| 273 |
+
@staticmethod
|
| 274 |
+
def _vram_usage_fraction(hardware: str) -> float:
|
| 275 |
+
"""Fraction of VRAM typically consumed during inference."""
|
| 276 |
+
hw = hardware.lower()
|
| 277 |
+
if any(x in hw for x in ("4090", "3090", "a100", "h100")):
|
| 278 |
+
return 0.62
|
| 279 |
+
if any(x in hw for x in ("4080", "3080", "v100", "a10")):
|
| 280 |
+
return 0.60
|
| 281 |
+
if "cpu" in hw:
|
| 282 |
+
return 0.0
|
| 283 |
+
return 0.55
|
| 284 |
+
|
| 285 |
+
@staticmethod
|
| 286 |
+
def _simulate_task_scores(
|
| 287 |
+
task: str, model: Model, dataset: Dataset
|
| 288 |
+
) -> dict[str, float]:
|
| 289 |
+
"""
|
| 290 |
+
Produce realistic metric scores with small per-run variance.
|
| 291 |
+
|
| 292 |
+
PRODUCTION SWAP: replace with actual metric computation:
|
| 293 |
+
from torchmetrics.detection import MeanAveragePrecision
|
| 294 |
+
metric = MeanAveragePrecision()
|
| 295 |
+
metric.update(predictions, targets)
|
| 296 |
+
return metric.compute()
|
| 297 |
+
"""
|
| 298 |
+
baselines = dict(_TASK_BASELINES.get(task.lower(), {"accuracy": 0.80}))
|
| 299 |
+
# Small Gaussian jitter simulates run-to-run variance
|
| 300 |
+
return {
|
| 301 |
+
k: float(max(0.0, min(1.0, v + random.gauss(0, 0.015))))
|
| 302 |
+
for k, v in baselines.items()
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
@staticmethod
|
| 306 |
+
def _check_torch_available() -> bool:
|
| 307 |
+
"""Return True if PyTorch is installed and importable."""
|
| 308 |
+
try:
|
| 309 |
+
import torch # noqa: F401
|
| 310 |
+
return True
|
| 311 |
+
except ImportError:
|
| 312 |
+
return False
|
| 313 |
+
|
| 314 |
+
@staticmethod
|
| 315 |
+
def _compute_layer_breakdown(task: str, base_lat_ms: float) -> list[LayerBreakdown]:
|
| 316 |
+
"""Build a realistic layer breakdown for the given task.
|
| 317 |
+
|
| 318 |
+
Splits total latency across architectural stages with small jitter.
|
| 319 |
+
PRODUCTION SWAP: replace with actual profiler data (e.g. torch.profiler).
|
| 320 |
+
"""
|
| 321 |
+
if task.lower() in ("detection", "segmentation"):
|
| 322 |
+
stages = [
|
| 323 |
+
("Backbone", 0.45),
|
| 324 |
+
("Neck (FPN/PAFPN)", 0.30),
|
| 325 |
+
("Detection Head", 0.20),
|
| 326 |
+
("NMS Post-process", 0.05),
|
| 327 |
+
]
|
| 328 |
+
elif task.lower() == "classification":
|
| 329 |
+
stages = [
|
| 330 |
+
("Feature Extractor", 0.70),
|
| 331 |
+
("Classifier Head", 0.20),
|
| 332 |
+
("Softmax", 0.10),
|
| 333 |
+
]
|
| 334 |
+
else:
|
| 335 |
+
stages = [
|
| 336 |
+
("Encoder", 0.55),
|
| 337 |
+
("Decoder / Head", 0.35),
|
| 338 |
+
("Post-process", 0.10),
|
| 339 |
+
]
|
| 340 |
+
|
| 341 |
+
result: list[LayerBreakdown] = []
|
| 342 |
+
remaining = base_lat_ms
|
| 343 |
+
for name, frac in stages:
|
| 344 |
+
t = round(base_lat_ms * frac + random.gauss(0, base_lat_ms * 0.01), 3)
|
| 345 |
+
result.append(LayerBreakdown(name=name, time_ms=t, percent=round(frac * 100, 1)))
|
| 346 |
+
return result
|
| 347 |
+
|
| 348 |
+
@staticmethod
|
| 349 |
+
def _normalize_hw(hardware: str) -> str:
|
| 350 |
+
return (
|
| 351 |
+
hardware.lower()
|
| 352 |
+
.replace(" ", "")
|
| 353 |
+
.replace("-", "")
|
| 354 |
+
.replace("_", "")
|
| 355 |
+
.replace("nvidia", "")
|
| 356 |
+
.replace("geforce", "")
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
@staticmethod
|
| 360 |
+
def _lookup_latency(hw_key: str) -> float:
|
| 361 |
+
for key, ms in _LATENCY_MS_PER_IMAGE.items():
|
| 362 |
+
if key and key in hw_key:
|
| 363 |
+
return ms
|
| 364 |
+
if any(x in hw_key for x in ("gpu", "rtx", "gtx", "cuda")):
|
| 365 |
+
return _LATENCY_MS_PER_IMAGE["gpu"]
|
| 366 |
+
return _LATENCY_MS_PER_IMAGE["cpu"]
|
benchmark/metrics.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/metrics.py — Metrics Engine.
|
| 3 |
+
|
| 4 |
+
Computes the final BenchmarkMetrics object from raw execution data:
|
| 5 |
+
- Latency statistics (mean, p95, p99)
|
| 6 |
+
- Throughput (FPS)
|
| 7 |
+
- VRAM statistics (avg, peak)
|
| 8 |
+
- Task-specific scores (mAP, accuracy, IoU) supplied by the executor
|
| 9 |
+
|
| 10 |
+
In a production deployment the task_scores dict comes from actual
|
| 11 |
+
metric computation (e.g. pycocotools, torchmetrics). In this local-first
|
| 12 |
+
build the executor supplies realistic simulated scores.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import statistics
|
| 17 |
+
|
| 18 |
+
from models.benchmark import BenchmarkMetrics, LayerBreakdown, TelemetrySummary
|
| 19 |
+
from observability.logger import get_logger
|
| 20 |
+
|
| 21 |
+
log = get_logger("benchmark.metrics")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MetricsEngine:
|
| 25 |
+
"""Computes BenchmarkMetrics from raw benchmark execution data."""
|
| 26 |
+
|
| 27 |
+
def compute(
|
| 28 |
+
self,
|
| 29 |
+
*,
|
| 30 |
+
task: str,
|
| 31 |
+
latencies_ms: list[float], # per-batch latencies
|
| 32 |
+
total_images: int = 0,
|
| 33 |
+
total_tokens: int = 0,
|
| 34 |
+
batch_size: int,
|
| 35 |
+
vram_samples: list[float], # VRAM readings (GB) during run
|
| 36 |
+
task_scores: dict[str, float], # task-specific metric scores
|
| 37 |
+
) -> BenchmarkMetrics:
|
| 38 |
+
if not latencies_ms:
|
| 39 |
+
return BenchmarkMetrics(total_images=total_images, total_tokens=total_tokens, batch_size=batch_size)
|
| 40 |
+
|
| 41 |
+
total_time_s = sum(latencies_ms) / 1000.0
|
| 42 |
+
fps = total_images / total_time_s if total_time_s > 0 and total_images > 0 else 0.0
|
| 43 |
+
tps = total_tokens / total_time_s if total_time_s > 0 and total_tokens > 0 else 0.0
|
| 44 |
+
|
| 45 |
+
lat_mean = statistics.mean(latencies_ms)
|
| 46 |
+
lat_p95 = _percentile(latencies_ms, 0.95)
|
| 47 |
+
lat_p99 = _percentile(latencies_ms, 0.99)
|
| 48 |
+
|
| 49 |
+
vram_peak = max(vram_samples) if vram_samples else 0.0
|
| 50 |
+
vram_avg = statistics.mean(vram_samples) if vram_samples else 0.0
|
| 51 |
+
|
| 52 |
+
m = BenchmarkMetrics(
|
| 53 |
+
fps = round(fps, 2),
|
| 54 |
+
tokens_per_sec = round(tps, 2),
|
| 55 |
+
latency_mean_ms = round(lat_mean, 3),
|
| 56 |
+
latency_p95_ms = round(lat_p95, 3),
|
| 57 |
+
latency_p99_ms = round(lat_p99, 3),
|
| 58 |
+
vram_peak_gb = round(vram_peak, 3),
|
| 59 |
+
vram_avg_gb = round(vram_avg, 3),
|
| 60 |
+
total_images = total_images,
|
| 61 |
+
total_tokens = total_tokens,
|
| 62 |
+
batch_size = batch_size,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
task_lower = task.lower()
|
| 66 |
+
|
| 67 |
+
# CV Task Mapping
|
| 68 |
+
if task_lower in ("detection", "segmentation", "keypoints"):
|
| 69 |
+
m.mAP = _fmt(task_scores.get("mAP", 0.0))
|
| 70 |
+
m.mAP_50 = _fmt(task_scores.get("mAP_50", 0.0))
|
| 71 |
+
m.mAP_50_95 = _fmt(task_scores.get("mAP_50_95", 0.0))
|
| 72 |
+
if task_lower == "segmentation":
|
| 73 |
+
m.iou_mean = _fmt(task_scores.get("iou_mean", 0.0))
|
| 74 |
+
|
| 75 |
+
elif task_lower == "classification":
|
| 76 |
+
m.accuracy = _fmt(task_scores.get("accuracy", 0.0))
|
| 77 |
+
m.top1 = _fmt(task_scores.get("top1", 0.0))
|
| 78 |
+
m.top5 = _fmt(task_scores.get("top5", 0.0))
|
| 79 |
+
|
| 80 |
+
# NLP Task Mapping (ROUGE, BLEU, Perplexity)
|
| 81 |
+
elif task_lower in ("nlp", "generation"):
|
| 82 |
+
m.accuracy = _fmt(task_scores.get("accuracy", 0.0))
|
| 83 |
+
m.rouge_l = _fmt(task_scores.get("rouge_l", task_scores.get("rougeL", 0.0)))
|
| 84 |
+
m.bleu = _fmt(task_scores.get("bleu", 0.0))
|
| 85 |
+
m.perplexity = task_scores.get("perplexity")
|
| 86 |
+
|
| 87 |
+
log.info(
|
| 88 |
+
"metrics_computed",
|
| 89 |
+
task = task,
|
| 90 |
+
fps = m.fps,
|
| 91 |
+
tps = m.tokens_per_sec,
|
| 92 |
+
latency_ms = m.latency_mean_ms,
|
| 93 |
+
vram_peak = m.vram_peak_gb,
|
| 94 |
+
)
|
| 95 |
+
return m
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
| 99 |
+
|
| 100 |
+
def _percentile(data: list[float], p: float) -> float:
|
| 101 |
+
if not data:
|
| 102 |
+
return 0.0
|
| 103 |
+
s = sorted(data)
|
| 104 |
+
idx = min(int(len(s) * p), len(s) - 1)
|
| 105 |
+
return s[idx]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _fmt(v: float) -> float:
|
| 109 |
+
"""Round to 4dp and clamp to [0, 1]."""
|
| 110 |
+
return round(max(0.0, min(1.0, v)), 4)
|
benchmark/orchestrator.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/orchestrator.py — Benchmark Orchestrator (Main Controller).
|
| 3 |
+
|
| 4 |
+
Coordinates the full benchmark lifecycle:
|
| 5 |
+
1. Resolve model + dataset from their registries
|
| 6 |
+
2. Run all compatibility checks (gates A–E)
|
| 7 |
+
3. If valid → create a BenchmarkJob in the DB
|
| 8 |
+
4. Persist the validation audit log
|
| 9 |
+
5. Enqueue async background task → execution → metrics → storage
|
| 10 |
+
6. Return the job immediately so callers are non-blocking
|
| 11 |
+
|
| 12 |
+
Public interface used by api/routes/benchmark.py:
|
| 13 |
+
validate_context(ctx) → ValidationReport (no job created)
|
| 14 |
+
create_and_run(ctx) → BenchmarkJob (job queued, execution in background)
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import asyncio
|
| 19 |
+
from datetime import datetime, timezone
|
| 20 |
+
|
| 21 |
+
from benchmark.adapters.registry import get_executor
|
| 22 |
+
from benchmark.compatibility import CompatibilityValidator
|
| 23 |
+
from benchmark.execution import BenchmarkExecutor
|
| 24 |
+
from benchmark.metrics import MetricsEngine
|
| 25 |
+
import benchmark.registry as bench_reg
|
| 26 |
+
from datasets.registry import get_dataset
|
| 27 |
+
from models.benchmark import (
|
| 28 |
+
BenchmarkContext,
|
| 29 |
+
BenchmarkJob,
|
| 30 |
+
BenchmarkMetrics,
|
| 31 |
+
TelemetrySummary,
|
| 32 |
+
ValidationReport,
|
| 33 |
+
)
|
| 34 |
+
from models.dataset import Dataset
|
| 35 |
+
from models.model import Model
|
| 36 |
+
from observability.logger import audit, get_logger
|
| 37 |
+
from registry.registry import get_model
|
| 38 |
+
|
| 39 |
+
log = get_logger("benchmark.orchestrator")
|
| 40 |
+
|
| 41 |
+
# Module-level singletons — stateless, safe to share
|
| 42 |
+
_validator = CompatibilityValidator()
|
| 43 |
+
_metrics = MetricsEngine()
|
| 44 |
+
|
| 45 |
+
# job_id → asyncio.Task (for future cancellation support)
|
| 46 |
+
_active_tasks: dict[str, asyncio.Task] = {}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ── Public API ────────────────────────────────────────────────────────────────
|
| 50 |
+
|
| 51 |
+
async def sync_project_benchmarks() -> int:
|
| 52 |
+
"""
|
| 53 |
+
Sync benchmark jobs and results from the active project's 'benchmarks' folder.
|
| 54 |
+
This ensures that benchmarks created in different sessions or projects are indexed.
|
| 55 |
+
"""
|
| 56 |
+
from benchmark.registry import _get_active_project_benchmark_dir_sync
|
| 57 |
+
from projects.service import get_active_project_path
|
| 58 |
+
import json
|
| 59 |
+
import os
|
| 60 |
+
from database.connection import get_db
|
| 61 |
+
|
| 62 |
+
project_path = await get_active_project_path()
|
| 63 |
+
benchmark_dir = _get_active_project_benchmark_dir_sync(project_path)
|
| 64 |
+
if not benchmark_dir or not benchmark_dir.exists():
|
| 65 |
+
return 0
|
| 66 |
+
|
| 67 |
+
db = await get_db()
|
| 68 |
+
count = 0
|
| 69 |
+
|
| 70 |
+
for file_path in benchmark_dir.glob("*.json"):
|
| 71 |
+
try:
|
| 72 |
+
with open(file_path, "r") as f:
|
| 73 |
+
data = json.load(f)
|
| 74 |
+
|
| 75 |
+
# Check if it's a job or a result
|
| 76 |
+
if file_path.name.startswith("job_"):
|
| 77 |
+
# Upsert into benchmark_jobs
|
| 78 |
+
await db.execute(
|
| 79 |
+
"""INSERT OR IGNORE INTO benchmark_jobs
|
| 80 |
+
(id, model_id, dataset_id, task, framework, hardware,
|
| 81 |
+
precision, batch_size, config, status, progress, created_at, updated_at, started_at)
|
| 82 |
+
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
|
| 83 |
+
(
|
| 84 |
+
data["id"], data["model_id"], data["dataset_id"],
|
| 85 |
+
data["task"], data["framework"], data["hardware"],
|
| 86 |
+
data["precision"], data["batch_size"],
|
| 87 |
+
json.dumps(data["config"]), data["status"],
|
| 88 |
+
data.get("progress", 0.0),
|
| 89 |
+
data.get("created_at", datetime.now(timezone.utc).isoformat()),
|
| 90 |
+
data.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
| 91 |
+
data.get("started_at")
|
| 92 |
+
)
|
| 93 |
+
)
|
| 94 |
+
count += 1
|
| 95 |
+
elif file_path.name.startswith("result_"):
|
| 96 |
+
# Upsert into benchmark_results
|
| 97 |
+
await db.execute(
|
| 98 |
+
"""INSERT OR IGNORE INTO benchmark_results
|
| 99 |
+
(id, job_id, metrics, telemetry_summary, created_at)
|
| 100 |
+
VALUES (?,?,?,?,?)""",
|
| 101 |
+
(
|
| 102 |
+
data["id"], data["job_id"],
|
| 103 |
+
json.dumps(data["metrics"]),
|
| 104 |
+
json.dumps(data["telemetry_summary"]),
|
| 105 |
+
data.get("created_at", datetime.now(timezone.utc).isoformat())
|
| 106 |
+
)
|
| 107 |
+
)
|
| 108 |
+
count += 1
|
| 109 |
+
except Exception as e:
|
| 110 |
+
log.error("sync_file_failed", file=file_path.name, error=str(e))
|
| 111 |
+
|
| 112 |
+
await db.commit()
|
| 113 |
+
log.info("sync_complete", count=count)
|
| 114 |
+
return count
|
| 115 |
+
|
| 116 |
+
async def validate_context(ctx: BenchmarkContext) -> ValidationReport:
|
| 117 |
+
"""
|
| 118 |
+
Validate model ↔ dataset ↔ hardware compatibility.
|
| 119 |
+
Does NOT create a job. Safe to call repeatedly from the UI.
|
| 120 |
+
"""
|
| 121 |
+
model = await _require_model(ctx.model_id)
|
| 122 |
+
|
| 123 |
+
# ── Handle Polymorphic Input (Video/Live) ��───────────────────────────────
|
| 124 |
+
if ctx.input_source in ["video", "live"] or ctx.dataset_id == "none":
|
| 125 |
+
# Create a synthetic dataset object for non-dataset sources
|
| 126 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 127 |
+
dataset = Dataset(
|
| 128 |
+
id="none",
|
| 129 |
+
name="Live/Video Stream",
|
| 130 |
+
task=model.task, # Match model task to pass task check
|
| 131 |
+
format="custom",
|
| 132 |
+
source="local",
|
| 133 |
+
status="imported",
|
| 134 |
+
images=0,
|
| 135 |
+
classes=0,
|
| 136 |
+
size_label="0 MB",
|
| 137 |
+
created_at=now,
|
| 138 |
+
updated_at=now
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
dataset = await _require_dataset(ctx.dataset_id)
|
| 142 |
+
|
| 143 |
+
return _validator.validate(model, dataset, ctx)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
async def create_and_run(ctx: BenchmarkContext) -> BenchmarkJob:
|
| 147 |
+
"""
|
| 148 |
+
Full benchmark initiation:
|
| 149 |
+
"""
|
| 150 |
+
model = await _require_model(ctx.model_id)
|
| 151 |
+
|
| 152 |
+
# ── Handle Polymorphic Input (Video/Live) ────────────────────────────────
|
| 153 |
+
if ctx.input_source in ["video", "live"] or ctx.dataset_id == "none":
|
| 154 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 155 |
+
dataset = Dataset(
|
| 156 |
+
id="none",
|
| 157 |
+
name="Live/Video Stream",
|
| 158 |
+
task=model.task,
|
| 159 |
+
format="custom",
|
| 160 |
+
source="local",
|
| 161 |
+
status="imported",
|
| 162 |
+
images=0,
|
| 163 |
+
classes=0,
|
| 164 |
+
size_label="0 MB",
|
| 165 |
+
created_at=now,
|
| 166 |
+
updated_at=now
|
| 167 |
+
)
|
| 168 |
+
else:
|
| 169 |
+
dataset = await _require_dataset(ctx.dataset_id)
|
| 170 |
+
|
| 171 |
+
# ── Compatibility check ───────────────────────────────────────────────────
|
| 172 |
+
report = _validator.validate(model, dataset, ctx)
|
| 173 |
+
|
| 174 |
+
# Always persist the validation log (even for failures)
|
| 175 |
+
await bench_reg.save_validation_log(
|
| 176 |
+
job_id = "pre-check",
|
| 177 |
+
model_id = ctx.model_id,
|
| 178 |
+
dataset_id = ctx.dataset_id,
|
| 179 |
+
checks = report.checks,
|
| 180 |
+
passed = report.passed,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
if not report.passed:
|
| 184 |
+
from fastapi import HTTPException
|
| 185 |
+
failed = [c for c in report.checks if not c.passed]
|
| 186 |
+
raise HTTPException(
|
| 187 |
+
status_code = 422,
|
| 188 |
+
detail = {
|
| 189 |
+
"error": "Compatibility validation failed",
|
| 190 |
+
"failed_checks": [
|
| 191 |
+
{
|
| 192 |
+
"name": c.name,
|
| 193 |
+
"detail": c.detail,
|
| 194 |
+
"suggestion": c.suggestion,
|
| 195 |
+
}
|
| 196 |
+
for c in failed
|
| 197 |
+
],
|
| 198 |
+
},
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# ── Create job ────────────────────────────────────────────────────────────
|
| 202 |
+
job = await bench_reg.create_job(ctx)
|
| 203 |
+
|
| 204 |
+
# Overwrite 'pre-check' validation log with the real job_id
|
| 205 |
+
await bench_reg.save_validation_log(
|
| 206 |
+
job_id = job.id,
|
| 207 |
+
model_id = ctx.model_id,
|
| 208 |
+
dataset_id = ctx.dataset_id,
|
| 209 |
+
checks = report.checks,
|
| 210 |
+
passed = True,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# ── Log the Polymorphic Input params ─────────────────────────────────────
|
| 214 |
+
if ctx.input_source or ctx.video_path or ctx.rtsp_url:
|
| 215 |
+
log.info("polymorphic_input_received",
|
| 216 |
+
job_id=job.id,
|
| 217 |
+
source=ctx.input_source,
|
| 218 |
+
video=ctx.video_path,
|
| 219 |
+
rtsp=ctx.rtsp_url)
|
| 220 |
+
|
| 221 |
+
# ── Enqueue background execution ──────────────────────────────────────────
|
| 222 |
+
task = asyncio.create_task(
|
| 223 |
+
_execute_job(job.id, ctx, model, dataset),
|
| 224 |
+
name = f"benchmark_{job.id}",
|
| 225 |
+
)
|
| 226 |
+
_active_tasks[job.id] = task
|
| 227 |
+
task.add_done_callback(lambda _t: _active_tasks.pop(job.id, None))
|
| 228 |
+
|
| 229 |
+
log.info("benchmark_enqueued", job_id=job.id, model=ctx.model_id)
|
| 230 |
+
return job
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# ── Background execution ──────────────────────────────────────────────────────
|
| 234 |
+
|
| 235 |
+
async def _execute_job(
|
| 236 |
+
job_id: str,
|
| 237 |
+
ctx: BenchmarkContext,
|
| 238 |
+
model: Model,
|
| 239 |
+
dataset: Dataset,
|
| 240 |
+
) -> None:
|
| 241 |
+
"""Full benchmark lifecycle — runs in an asyncio background task."""
|
| 242 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 243 |
+
|
| 244 |
+
# Transition → running
|
| 245 |
+
ts_color = "\x1b[36m" # Cyan
|
| 246 |
+
info_color = "\x1b[34m" # Blue
|
| 247 |
+
success_color = "\x1b[32m" # Green
|
| 248 |
+
reset = "\x1b[0m"
|
| 249 |
+
|
| 250 |
+
await bench_reg.update_job(
|
| 251 |
+
job_id,
|
| 252 |
+
status = "running",
|
| 253 |
+
progress = 0.0,
|
| 254 |
+
started_at = now,
|
| 255 |
+
log_entry = f"{ts_color}[{now}]{reset} {info_color}Job started{reset} on {ctx.hardware} ({ctx.precision})",
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
runner = BenchmarkExecutor()
|
| 259 |
+
|
| 260 |
+
try:
|
| 261 |
+
# ── Fetch the persisted job (for executor) ────────────────────────────
|
| 262 |
+
job = await bench_reg.get_job(job_id)
|
| 263 |
+
assert job is not None, "Job disappeared from DB after creation"
|
| 264 |
+
|
| 265 |
+
# ── Define Progress Callback ──────────────────────────────────────────
|
| 266 |
+
async def on_progress(progress: float, message: str, telemetry: Any | None):
|
| 267 |
+
await bench_reg.update_job(
|
| 268 |
+
job_id,
|
| 269 |
+
progress=progress,
|
| 270 |
+
log_entry=f"{ts_color}[{datetime.now(timezone.utc).isoformat()}]{reset} {info_color}{message}{reset}",
|
| 271 |
+
last_telemetry=telemetry.model_dump() if telemetry and hasattr(telemetry, "model_dump") else telemetry
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# ── Execution Loop ────────────────────────────────────────────────────
|
| 275 |
+
exec_result = await runner.execute(
|
| 276 |
+
job=job,
|
| 277 |
+
model=model,
|
| 278 |
+
dataset=dataset,
|
| 279 |
+
on_progress=on_progress
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# ── Compute metrics ───────────────────────────────────────────────────
|
| 283 |
+
metrics = _metrics.compute(
|
| 284 |
+
task = ctx.task,
|
| 285 |
+
latencies_ms = exec_result.latencies_ms,
|
| 286 |
+
total_images = exec_result.total_images,
|
| 287 |
+
batch_size = ctx.batch_size,
|
| 288 |
+
vram_samples = exec_result.vram_samples,
|
| 289 |
+
task_scores = exec_result.task_scores,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# ── Persist result ────────────────────────────────────────────────────
|
| 293 |
+
await bench_reg.save_result(
|
| 294 |
+
job_id = job_id,
|
| 295 |
+
metrics = metrics,
|
| 296 |
+
telemetry_summary = exec_result.telemetry_summary,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
ended = datetime.now(timezone.utc).isoformat()
|
| 300 |
+
await bench_reg.update_job(
|
| 301 |
+
job_id,
|
| 302 |
+
status = "completed",
|
| 303 |
+
progress = 1.0,
|
| 304 |
+
ended_at = ended,
|
| 305 |
+
log_entry = f"{ts_color}[{ended}]{reset} {success_color}Benchmark completed{reset} — {metrics.fps} FPS",
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
await audit(
|
| 309 |
+
"benchmark_completed",
|
| 310 |
+
job_id = job_id,
|
| 311 |
+
payload = {"model_id": ctx.model_id, "dataset_id": ctx.dataset_id},
|
| 312 |
+
)
|
| 313 |
+
log.info(
|
| 314 |
+
"benchmark_completed",
|
| 315 |
+
job_id = job_id,
|
| 316 |
+
fps = metrics.fps,
|
| 317 |
+
lat_ms = metrics.latency_mean_ms,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
except asyncio.CancelledError:
|
| 321 |
+
# Task cancelled externally (e.g. server shutdown) — don't swallow
|
| 322 |
+
ended = datetime.now(timezone.utc).isoformat()
|
| 323 |
+
await bench_reg.update_job(
|
| 324 |
+
job_id,
|
| 325 |
+
status = "failed",
|
| 326 |
+
error = "Job cancelled",
|
| 327 |
+
ended_at = ended,
|
| 328 |
+
log_entry = f"{ts_color}[{ended}]{reset} \x1b[31mJob cancelled\x1b[0m",
|
| 329 |
+
)
|
| 330 |
+
raise
|
| 331 |
+
|
| 332 |
+
except Exception as exc:
|
| 333 |
+
ended = datetime.now(timezone.utc).isoformat()
|
| 334 |
+
err_msg = str(exc)
|
| 335 |
+
error_color = "\x1b[31m" # Red
|
| 336 |
+
await bench_reg.update_job(
|
| 337 |
+
job_id,
|
| 338 |
+
status = "failed",
|
| 339 |
+
error = err_msg,
|
| 340 |
+
ended_at = ended,
|
| 341 |
+
log_entry = f"{ts_color}[{ended}]{reset} {error_color}ERROR: {err_msg}{reset}",
|
| 342 |
+
)
|
| 343 |
+
await audit(
|
| 344 |
+
"benchmark_failed",
|
| 345 |
+
job_id = job_id,
|
| 346 |
+
level = "error",
|
| 347 |
+
payload = {"error": err_msg, "model_id": ctx.model_id},
|
| 348 |
+
)
|
| 349 |
+
log.exception("benchmark_failed", job_id=job_id)
|
| 350 |
+
finally:
|
| 351 |
+
pass
|
| 352 |
+
|
| 353 |
+
# ── Resource resolvers ────────────────────────────────────────────────────────
|
| 354 |
+
|
| 355 |
+
async def _require_model(model_id: str) -> Model:
|
| 356 |
+
model = await get_model(model_id)
|
| 357 |
+
if not model:
|
| 358 |
+
from fastapi import HTTPException
|
| 359 |
+
raise HTTPException(
|
| 360 |
+
status_code = 404,
|
| 361 |
+
detail = f"Model '{model_id}' not found in Model Zoo",
|
| 362 |
+
)
|
| 363 |
+
return model
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
async def _require_dataset(dataset_id: str) -> Dataset:
|
| 367 |
+
dataset = await get_dataset(dataset_id)
|
| 368 |
+
if not dataset:
|
| 369 |
+
from fastapi import HTTPException
|
| 370 |
+
raise HTTPException(
|
| 371 |
+
status_code = 404,
|
| 372 |
+
detail = f"Dataset '{dataset_id}' not found in Dataset Manager",
|
| 373 |
+
)
|
| 374 |
+
return dataset
|
benchmark/registry.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/registry.py — Benchmark Registry.
|
| 3 |
+
|
| 4 |
+
All DB interactions for:
|
| 5 |
+
• benchmark_jobs — job lifecycle state
|
| 6 |
+
• benchmark_results — final metrics + telemetry summary
|
| 7 |
+
• benchmark_validation_logs — immutable check audit trail
|
| 8 |
+
|
| 9 |
+
Follows the same pattern as registry/registry.py and datasets/registry.py.
|
| 10 |
+
No direct DB access from other benchmark modules — everything routes here.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import uuid
|
| 16 |
+
from datetime import datetime, timezone
|
| 17 |
+
from typing import Any
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
from database.connection import get_db
|
| 21 |
+
from models.benchmark import (
|
| 22 |
+
BenchmarkContext,
|
| 23 |
+
BenchmarkJob,
|
| 24 |
+
BenchmarkMetrics,
|
| 25 |
+
BenchmarkResult,
|
| 26 |
+
TelemetrySummary,
|
| 27 |
+
ValidationCheck,
|
| 28 |
+
row_to_job,
|
| 29 |
+
row_to_result,
|
| 30 |
+
)
|
| 31 |
+
from observability.logger import get_logger
|
| 32 |
+
|
| 33 |
+
log = get_logger("benchmark.registry")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _get_active_project_benchmark_dir_sync(project_path: str | None) -> Path | None:
|
| 37 |
+
"""Get the absolute path to the 'benchmarks' folder in a given project path."""
|
| 38 |
+
if not project_path:
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
benchmark_dir = Path(project_path) / "benchmarks"
|
| 42 |
+
benchmark_dir.mkdir(parents=True, exist_ok=True)
|
| 43 |
+
return benchmark_dir
|
| 44 |
+
|
| 45 |
+
async def _get_active_project_benchmark_dir() -> Path | None:
|
| 46 |
+
"""Get the absolute path to the 'benchmarks' folder in the active project."""
|
| 47 |
+
from projects.service import get_active_project_path
|
| 48 |
+
project_path = await get_active_project_path()
|
| 49 |
+
return _get_active_project_benchmark_dir_sync(project_path)
|
| 50 |
+
|
| 51 |
+
async def _save_to_project(filename: str, data: dict) -> None:
|
| 52 |
+
"""Save data to a JSON file in the active project's benchmark folder."""
|
| 53 |
+
benchmark_dir = await _get_active_project_benchmark_dir()
|
| 54 |
+
if not benchmark_dir:
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
file_path = benchmark_dir / filename
|
| 58 |
+
try:
|
| 59 |
+
with open(file_path, "w") as f:
|
| 60 |
+
json.dump(data, f, indent=2)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
log.error("project_persistence_failed", error=str(e), file=filename)
|
| 63 |
+
|
| 64 |
+
# ── Job CRUD ──────────────────────────────────────────────────────────────────
|
| 65 |
+
|
| 66 |
+
async def create_job(ctx: BenchmarkContext) -> BenchmarkJob:
|
| 67 |
+
db = await get_db()
|
| 68 |
+
job_id = f"bmark-{uuid.uuid4().hex[:12]}"
|
| 69 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 70 |
+
|
| 71 |
+
# Create job object
|
| 72 |
+
job = BenchmarkJob(
|
| 73 |
+
id = job_id,
|
| 74 |
+
model_id = ctx.model_id,
|
| 75 |
+
dataset_id = ctx.dataset_id,
|
| 76 |
+
task = ctx.task,
|
| 77 |
+
framework = ctx.framework,
|
| 78 |
+
hardware = ctx.hardware,
|
| 79 |
+
precision = ctx.precision,
|
| 80 |
+
batch_size = ctx.batch_size,
|
| 81 |
+
config = ctx.model_dump(),
|
| 82 |
+
status = "queued",
|
| 83 |
+
progress = 0.0,
|
| 84 |
+
created_at = now,
|
| 85 |
+
updated_at = now,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Persist to SQLite
|
| 89 |
+
await db.execute(
|
| 90 |
+
"""INSERT INTO benchmark_jobs
|
| 91 |
+
(id, model_id, dataset_id, task, framework, hardware,
|
| 92 |
+
precision, batch_size, config,
|
| 93 |
+
status, progress, logs, created_at, updated_at)
|
| 94 |
+
VALUES (?,?,?,?,?,?,?,?,?,'queued',0.0,'[]',?,?)""",
|
| 95 |
+
(
|
| 96 |
+
job_id,
|
| 97 |
+
ctx.model_id, ctx.dataset_id,
|
| 98 |
+
ctx.task, ctx.framework, ctx.hardware,
|
| 99 |
+
ctx.precision, ctx.batch_size,
|
| 100 |
+
json.dumps(ctx.model_dump()),
|
| 101 |
+
now, now,
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
await db.commit()
|
| 105 |
+
|
| 106 |
+
# Persist to project folder
|
| 107 |
+
await _save_to_project(f"job_{job_id}.json", job.model_dump())
|
| 108 |
+
|
| 109 |
+
log.info("benchmark_job_created", job_id=job_id, model=ctx.model_id)
|
| 110 |
+
return job
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
async def get_job(job_id: str) -> BenchmarkJob | None:
|
| 114 |
+
db = await get_db()
|
| 115 |
+
async with db.execute(
|
| 116 |
+
"SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,)
|
| 117 |
+
) as cur:
|
| 118 |
+
row = await cur.fetchone()
|
| 119 |
+
return row_to_job(row) if row else None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
async def list_jobs(
|
| 123 |
+
*,
|
| 124 |
+
status: str | None = None,
|
| 125 |
+
model_id: str | None = None,
|
| 126 |
+
limit: int = 100,
|
| 127 |
+
) -> list[BenchmarkJob]:
|
| 128 |
+
db = await get_db()
|
| 129 |
+
clauses: list[str] = []
|
| 130 |
+
params: list[Any] = []
|
| 131 |
+
|
| 132 |
+
if status:
|
| 133 |
+
clauses.append("status = ?")
|
| 134 |
+
params.append(status)
|
| 135 |
+
if model_id:
|
| 136 |
+
clauses.append("model_id = ?")
|
| 137 |
+
params.append(model_id)
|
| 138 |
+
|
| 139 |
+
where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
|
| 140 |
+
params.append(limit)
|
| 141 |
+
|
| 142 |
+
async with db.execute(
|
| 143 |
+
f"SELECT * FROM benchmark_jobs {where} ORDER BY created_at DESC LIMIT ?",
|
| 144 |
+
params,
|
| 145 |
+
) as cur:
|
| 146 |
+
rows = await cur.fetchall()
|
| 147 |
+
return [row_to_job(r) for r in rows]
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
async def update_job(
|
| 151 |
+
job_id: str,
|
| 152 |
+
*,
|
| 153 |
+
status: str | None = None,
|
| 154 |
+
progress: float | None = None,
|
| 155 |
+
error: str | None = None,
|
| 156 |
+
started_at: str | None = None,
|
| 157 |
+
ended_at: str | None = None,
|
| 158 |
+
log_entry: str | None = None,
|
| 159 |
+
last_telemetry: dict | None = None,
|
| 160 |
+
) -> None:
|
| 161 |
+
"""Update mutable fields on a benchmark job atomically."""
|
| 162 |
+
db = await get_db()
|
| 163 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 164 |
+
|
| 165 |
+
sets: list[str] = ["updated_at = ?"]
|
| 166 |
+
vals: list[Any] = [now]
|
| 167 |
+
|
| 168 |
+
if status is not None:
|
| 169 |
+
sets.append("status = ?"); vals.append(status)
|
| 170 |
+
if progress is not None:
|
| 171 |
+
sets.append("progress = ?"); vals.append(round(progress, 4))
|
| 172 |
+
if error is not None:
|
| 173 |
+
sets.append("error = ?"); vals.append(error)
|
| 174 |
+
if started_at is not None:
|
| 175 |
+
sets.append("started_at = ?"); vals.append(started_at)
|
| 176 |
+
if ended_at is not None:
|
| 177 |
+
sets.append("ended_at = ?"); vals.append(ended_at)
|
| 178 |
+
if last_telemetry is not None:
|
| 179 |
+
sets.append("last_telemetry = ?"); vals.append(json.dumps(last_telemetry))
|
| 180 |
+
|
| 181 |
+
if log_entry is not None:
|
| 182 |
+
# Append new entry to the JSON log array (capped at 500 lines)
|
| 183 |
+
async with db.execute(
|
| 184 |
+
"SELECT logs FROM benchmark_jobs WHERE id = ?", (job_id,)
|
| 185 |
+
) as cur:
|
| 186 |
+
row = await cur.fetchone()
|
| 187 |
+
existing = json.loads(row["logs"]) if row and row["logs"] else []
|
| 188 |
+
existing.append(log_entry)
|
| 189 |
+
sets.append("logs = ?")
|
| 190 |
+
vals.append(json.dumps(existing[-500:]))
|
| 191 |
+
|
| 192 |
+
vals.append(job_id)
|
| 193 |
+
# Persist to project folder if we have the job info
|
| 194 |
+
async with db.execute("SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,)) as cur:
|
| 195 |
+
row = await cur.fetchone()
|
| 196 |
+
if row:
|
| 197 |
+
job = row_to_job(row)
|
| 198 |
+
if job:
|
| 199 |
+
await _save_to_project(f"job_{job_id}.json", job.model_dump())
|
| 200 |
+
|
| 201 |
+
await db.commit()
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ── Result CRUD ───────────────────────────────────────────────────────────────
|
| 205 |
+
|
| 206 |
+
async def save_result(
|
| 207 |
+
*,
|
| 208 |
+
job_id: str,
|
| 209 |
+
metrics: BenchmarkMetrics,
|
| 210 |
+
telemetry_summary: TelemetrySummary,
|
| 211 |
+
) -> BenchmarkResult:
|
| 212 |
+
db = await get_db()
|
| 213 |
+
result_id = f"bres-{uuid.uuid4().hex[:12]}"
|
| 214 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 215 |
+
|
| 216 |
+
# Persist result to SQLite
|
| 217 |
+
await db.execute(
|
| 218 |
+
"""INSERT INTO benchmark_results
|
| 219 |
+
(id, job_id, metrics, telemetry_summary, created_at)
|
| 220 |
+
VALUES (?,?,?,?,?)""",
|
| 221 |
+
(
|
| 222 |
+
result_id,
|
| 223 |
+
job_id,
|
| 224 |
+
json.dumps(metrics.model_dump(exclude_none=True)),
|
| 225 |
+
json.dumps(telemetry_summary.model_dump()),
|
| 226 |
+
now,
|
| 227 |
+
),
|
| 228 |
+
)
|
| 229 |
+
await db.commit()
|
| 230 |
+
|
| 231 |
+
result = BenchmarkResult(
|
| 232 |
+
id = result_id,
|
| 233 |
+
job_id = job_id,
|
| 234 |
+
metrics = metrics,
|
| 235 |
+
telemetry_summary = telemetry_summary,
|
| 236 |
+
created_at = now,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Persist result to project folder
|
| 240 |
+
await _save_to_project(f"result_{job_id}.json", result.model_dump())
|
| 241 |
+
|
| 242 |
+
log.info("benchmark_result_saved", job_id=job_id, result_id=result_id)
|
| 243 |
+
return result
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
async def get_result(job_id: str) -> BenchmarkResult | None:
|
| 247 |
+
db = await get_db()
|
| 248 |
+
async with db.execute(
|
| 249 |
+
"""SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
|
| 250 |
+
FROM benchmark_results r
|
| 251 |
+
JOIN benchmark_jobs j ON r.job_id = j.id
|
| 252 |
+
WHERE r.job_id = ?""", (job_id,)
|
| 253 |
+
) as cur:
|
| 254 |
+
row = await cur.fetchone()
|
| 255 |
+
return row_to_result(row) if row else None
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
async def list_results(*, limit: int = 100) -> list[BenchmarkResult]:
|
| 259 |
+
db = await get_db()
|
| 260 |
+
async with db.execute(
|
| 261 |
+
"""SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
|
| 262 |
+
FROM benchmark_results r
|
| 263 |
+
JOIN benchmark_jobs j ON r.job_id = j.id
|
| 264 |
+
ORDER BY r.created_at DESC LIMIT ?""", (limit,)
|
| 265 |
+
) as cur:
|
| 266 |
+
rows = await cur.fetchall()
|
| 267 |
+
return [row_to_result(r) for r in rows]
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ── Validation Log ────────────────────────────────────────────────────────────
|
| 271 |
+
|
| 272 |
+
async def save_validation_log(
|
| 273 |
+
*,
|
| 274 |
+
job_id: str,
|
| 275 |
+
model_id: str,
|
| 276 |
+
dataset_id: str,
|
| 277 |
+
checks: list[ValidationCheck],
|
| 278 |
+
passed: bool,
|
| 279 |
+
) -> None:
|
| 280 |
+
"""Persist an immutable record of all compatibility checks."""
|
| 281 |
+
db = await get_db()
|
| 282 |
+
log_id = f"bval-{uuid.uuid4().hex[:12]}"
|
| 283 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 284 |
+
|
| 285 |
+
await db.execute(
|
| 286 |
+
"""INSERT INTO benchmark_validation_logs
|
| 287 |
+
(id, job_id, model_id, dataset_id, checks, passed, created_at)
|
| 288 |
+
VALUES (?,?,?,?,?,?,?)""",
|
| 289 |
+
(
|
| 290 |
+
log_id, job_id, model_id, dataset_id,
|
| 291 |
+
json.dumps([c.model_dump() for c in checks]),
|
| 292 |
+
1 if passed else 0,
|
| 293 |
+
now,
|
| 294 |
+
),
|
| 295 |
+
)
|
| 296 |
+
await db.commit()
|
| 297 |
+
log.info(
|
| 298 |
+
"validation_log_saved",
|
| 299 |
+
job_id = job_id,
|
| 300 |
+
passed = passed,
|
| 301 |
+
n_checks = len(checks),
|
| 302 |
+
)
|
benchmark/telemetry.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/telemetry.py — Real-time Telemetry Collector.
|
| 3 |
+
|
| 4 |
+
Collects GPU/hardware metrics at 2 Hz during benchmark execution.
|
| 5 |
+
Designed as a drop-in adapter:
|
| 6 |
+
• Local dev → simulates realistic GPU readings based on hardware tier
|
| 7 |
+
• Production → replace _read_gpu_metrics() with pynvml calls:
|
| 8 |
+
nvmlDeviceGetUtilizationRates()
|
| 9 |
+
nvmlDeviceGetMemoryInfo()
|
| 10 |
+
nvmlDeviceGetTemperature()
|
| 11 |
+
nvmlDeviceGetPowerUsage()
|
| 12 |
+
|
| 13 |
+
Usage (async context):
|
| 14 |
+
collector = TelemetryCollector("rtx4090", vram_total_gb=24.0)
|
| 15 |
+
await collector.start()
|
| 16 |
+
# ... run inference ...
|
| 17 |
+
summary = await collector.stop()
|
| 18 |
+
samples = collector.samples
|
| 19 |
+
"""
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import asyncio
|
| 23 |
+
import random
|
| 24 |
+
import statistics
|
| 25 |
+
import time
|
| 26 |
+
|
| 27 |
+
from models.benchmark import TelemetrySample, TelemetrySummary
|
| 28 |
+
from observability.logger import get_logger
|
| 29 |
+
|
| 30 |
+
log = get_logger("benchmark.telemetry")
|
| 31 |
+
|
| 32 |
+
# ── Hardware simulation profiles ──────────────────────────────────────────────
|
| 33 |
+
# (base_util%, base_temp_C, base_power_W)
|
| 34 |
+
_HW_PROFILES: dict[str, tuple[float, float, float]] = {
|
| 35 |
+
"rtx4090": (88.0, 74.0, 380.0),
|
| 36 |
+
"rtx4080": (84.0, 70.0, 280.0),
|
| 37 |
+
"rtx4070": (80.0, 68.0, 200.0),
|
| 38 |
+
"rtx3090": (85.0, 72.0, 320.0),
|
| 39 |
+
"rtx3080": (82.0, 70.0, 250.0),
|
| 40 |
+
"rtx3070": (78.0, 66.0, 180.0),
|
| 41 |
+
"rtx3060": (74.0, 64.0, 150.0),
|
| 42 |
+
"a100": (90.0, 68.0, 350.0),
|
| 43 |
+
"h100": (92.0, 65.0, 550.0),
|
| 44 |
+
"v100": (87.0, 70.0, 280.0),
|
| 45 |
+
"t4": (75.0, 62.0, 60.0),
|
| 46 |
+
"gpu": (70.0, 65.0, 150.0),
|
| 47 |
+
"cpu": (0.0, 0.0, 0.0),
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
_COLLECTION_INTERVAL_S = 0.5 # 2 Hz
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TelemetryCollector:
|
| 54 |
+
"""
|
| 55 |
+
Async telemetry collector. Call start() before inference, stop() after.
|
| 56 |
+
Thread-safe via asyncio (single-threaded event loop).
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, hardware: str, vram_total_gb: float = 8.0) -> None:
|
| 60 |
+
self._hardware = hardware
|
| 61 |
+
self._vram_total = vram_total_gb
|
| 62 |
+
self._hw_profile = self._resolve_profile(hardware)
|
| 63 |
+
self._samples: list[TelemetrySample] = []
|
| 64 |
+
self._running = False
|
| 65 |
+
self._task: asyncio.Task | None = None
|
| 66 |
+
|
| 67 |
+
# ── Public API ────────────────────────────────────────────────────────────
|
| 68 |
+
|
| 69 |
+
async def start(self) -> None:
|
| 70 |
+
self._running = True
|
| 71 |
+
self._samples = []
|
| 72 |
+
self._task = asyncio.create_task(
|
| 73 |
+
self._collect_loop(), name="telemetry_collector"
|
| 74 |
+
)
|
| 75 |
+
log.debug("telemetry_started", hardware=self._hardware)
|
| 76 |
+
|
| 77 |
+
async def stop(self) -> TelemetrySummary:
|
| 78 |
+
self._running = False
|
| 79 |
+
if self._task and not self._task.done():
|
| 80 |
+
self._task.cancel()
|
| 81 |
+
try:
|
| 82 |
+
await self._task
|
| 83 |
+
except asyncio.CancelledError:
|
| 84 |
+
pass
|
| 85 |
+
log.debug(
|
| 86 |
+
"telemetry_stopped",
|
| 87 |
+
hardware = self._hardware,
|
| 88 |
+
samples = len(self._samples),
|
| 89 |
+
)
|
| 90 |
+
return self._build_summary()
|
| 91 |
+
|
| 92 |
+
def record_batch_context(self, batch_idx: int, progress: float) -> None:
|
| 93 |
+
"""Annotate the most recent sample with the current batch context."""
|
| 94 |
+
if self._samples:
|
| 95 |
+
self._samples[-1].batch_idx = batch_idx
|
| 96 |
+
self._samples[-1].progress = progress
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def samples(self) -> list[TelemetrySample]:
|
| 100 |
+
return list(self._samples)
|
| 101 |
+
|
| 102 |
+
# ── Internal ──────────────────────────────────────────────────────────────
|
| 103 |
+
|
| 104 |
+
async def _collect_loop(self) -> None:
|
| 105 |
+
while self._running:
|
| 106 |
+
sample = self._read_gpu_metrics()
|
| 107 |
+
self._samples.append(sample)
|
| 108 |
+
await asyncio.sleep(_COLLECTION_INTERVAL_S)
|
| 109 |
+
|
| 110 |
+
def _read_gpu_metrics(self) -> TelemetrySample:
|
| 111 |
+
"""
|
| 112 |
+
Returns a TelemetrySample for the current hardware state.
|
| 113 |
+
|
| 114 |
+
PRODUCTION SWAP: Replace this body with pynvml calls:
|
| 115 |
+
handle = nvmlDeviceGetHandleByIndex(0)
|
| 116 |
+
util = nvmlDeviceGetUtilizationRates(handle)
|
| 117 |
+
mem = nvmlDeviceGetMemoryInfo(handle)
|
| 118 |
+
temp = nvmlDeviceGetTemperature(handle, NVML_TEMPERATURE_GPU)
|
| 119 |
+
power = nvmlDeviceGetPowerUsage(handle) / 1000 # mW → W
|
| 120 |
+
"""
|
| 121 |
+
base_util, base_temp, base_power = self._hw_profile
|
| 122 |
+
|
| 123 |
+
if base_util == 0.0: # CPU path — no meaningful GPU readings
|
| 124 |
+
return TelemetrySample(
|
| 125 |
+
timestamp = time.time(),
|
| 126 |
+
gpu_util_pct = 0.0,
|
| 127 |
+
vram_used_gb = 0.0,
|
| 128 |
+
vram_total_gb = 0.0,
|
| 129 |
+
temp_c = 0.0,
|
| 130 |
+
power_w = 0.0,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Simulate realistic jitter (±5% util, ±3°C, ±10W)
|
| 134 |
+
jitter_util = random.gauss(0, 3.0)
|
| 135 |
+
jitter_temp = random.gauss(0, 1.5)
|
| 136 |
+
jitter_power = random.gauss(0, 8.0)
|
| 137 |
+
vram_frac = random.uniform(0.58, 0.72)
|
| 138 |
+
|
| 139 |
+
return TelemetrySample(
|
| 140 |
+
timestamp = time.time(),
|
| 141 |
+
gpu_util_pct = max(0.0, min(100.0, base_util + jitter_util)),
|
| 142 |
+
vram_used_gb = round(
|
| 143 |
+
max(0.0, min(self._vram_total, self._vram_total * vram_frac)), 3
|
| 144 |
+
),
|
| 145 |
+
vram_total_gb = self._vram_total,
|
| 146 |
+
temp_c = round(max(0.0, base_temp + jitter_temp), 1),
|
| 147 |
+
power_w = round(max(0.0, base_power + jitter_power), 1),
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def _build_summary(self) -> TelemetrySummary:
|
| 151 |
+
if not self._samples:
|
| 152 |
+
return TelemetrySummary()
|
| 153 |
+
|
| 154 |
+
utils = [s.gpu_util_pct for s in self._samples]
|
| 155 |
+
vrams = [s.vram_used_gb for s in self._samples]
|
| 156 |
+
temps = [s.temp_c for s in self._samples]
|
| 157 |
+
powers = [s.power_w for s in self._samples]
|
| 158 |
+
|
| 159 |
+
def _safe_mean(lst: list[float]) -> float:
|
| 160 |
+
return statistics.mean(lst) if lst else 0.0
|
| 161 |
+
|
| 162 |
+
return TelemetrySummary(
|
| 163 |
+
gpu_util_avg = round(_safe_mean(utils), 2),
|
| 164 |
+
gpu_util_peak = round(max(utils), 2),
|
| 165 |
+
vram_avg_gb = round(_safe_mean(vrams), 3),
|
| 166 |
+
vram_peak_gb = round(max(vrams), 3),
|
| 167 |
+
temp_avg_c = round(_safe_mean(temps), 1),
|
| 168 |
+
temp_peak_c = round(max(temps), 1),
|
| 169 |
+
power_avg_w = round(_safe_mean(powers), 1),
|
| 170 |
+
power_peak_w = round(max(powers), 1),
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
@staticmethod
|
| 174 |
+
def _resolve_profile(hardware: str) -> tuple[float, float, float]:
|
| 175 |
+
hw = hardware.lower().replace(" ", "").replace("-", "")
|
| 176 |
+
for key, profile in _HW_PROFILES.items():
|
| 177 |
+
if key in hw:
|
| 178 |
+
return profile
|
| 179 |
+
# Default for unknown GPU-class hardware
|
| 180 |
+
if any(x in hw for x in ("gpu", "rtx", "gtx", "cuda", "vram")):
|
| 181 |
+
return _HW_PROFILES["gpu"]
|
| 182 |
+
return _HW_PROFILES["cpu"]
|
benchmark/torch_runner.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark/torch_runner.py — Synchronous GPU inference runner.
|
| 3 |
+
|
| 4 |
+
Called from BenchmarkExecutor via asyncio.run_in_executor() so it never
|
| 5 |
+
blocks the event loop. PyTorch is an optional dependency — if it is not
|
| 6 |
+
installed the module raises ImportError and execution.py falls back to
|
| 7 |
+
the simulation path.
|
| 8 |
+
|
| 9 |
+
Supported weight formats (detected by file extension):
|
| 10 |
+
.pt / .pth — torch.load (TorchScript or state-dict)
|
| 11 |
+
.safetensors — safetensors.torch.load_file
|
| 12 |
+
.onnx — onnxruntime InferenceSession
|
| 13 |
+
|
| 14 |
+
PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
# ── Model cache (keyed by absolute path) ─────────────────────────────────────
|
| 23 |
+
_MODEL_CACHE: dict[str, Any] = {}
|
| 24 |
+
|
| 25 |
+
# Standard input shapes per task (B, C, H, W)
|
| 26 |
+
_INPUT_SHAPES: dict[str, tuple[int, int, int]] = {
|
| 27 |
+
"detection": (3, 640, 640),
|
| 28 |
+
"segmentation": (3, 640, 640),
|
| 29 |
+
"classification": (3, 224, 224),
|
| 30 |
+
"generation": (3, 512, 512),
|
| 31 |
+
"embedding": (3, 224, 224),
|
| 32 |
+
}
|
| 33 |
+
_DEFAULT_SHAPE = (3, 640, 640)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def run_torch_batch(model_path: str, batch_size: int, task: str = "detection") -> float:
|
| 37 |
+
"""Run one inference batch and return per-image latency in ms.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model_path: Absolute path to the weight file.
|
| 41 |
+
batch_size: Number of images in the batch.
|
| 42 |
+
task: Model task (affects dummy input shape).
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Latency per image in milliseconds.
|
| 46 |
+
"""
|
| 47 |
+
import torch # raises ImportError if not installed
|
| 48 |
+
|
| 49 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 50 |
+
ext = Path(model_path).suffix.lower()
|
| 51 |
+
|
| 52 |
+
model = _load_model(model_path, ext, device)
|
| 53 |
+
c, h, w = _INPUT_SHAPES.get(task, _DEFAULT_SHAPE)
|
| 54 |
+
dummy = torch.zeros(batch_size, c, h, w, device=device)
|
| 55 |
+
|
| 56 |
+
# Warm-up pass (first call is slower due to CUDA kernel compilation)
|
| 57 |
+
if device == "cuda":
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
_forward(model, dummy, ext, device)
|
| 60 |
+
torch.cuda.synchronize()
|
| 61 |
+
|
| 62 |
+
# Timed pass
|
| 63 |
+
if device == "cuda":
|
| 64 |
+
torch.cuda.synchronize()
|
| 65 |
+
t0 = time.perf_counter()
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
_forward(model, dummy, ext, device)
|
| 68 |
+
if device == "cuda":
|
| 69 |
+
torch.cuda.synchronize()
|
| 70 |
+
elapsed_ms = (time.perf_counter() - t0) * 1000
|
| 71 |
+
|
| 72 |
+
return elapsed_ms / batch_size
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _load_model(path: str, ext: str, device: str) -> Any:
|
| 76 |
+
"""Load and cache the model by absolute path."""
|
| 77 |
+
if path in _MODEL_CACHE:
|
| 78 |
+
return _MODEL_CACHE[path]
|
| 79 |
+
|
| 80 |
+
model = _load_by_ext(path, ext, device)
|
| 81 |
+
_MODEL_CACHE[path] = model
|
| 82 |
+
return model
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _load_by_ext(path: str, ext: str, device: str) -> Any:
|
| 86 |
+
"""Select loader based on file extension."""
|
| 87 |
+
if ext in (".pt", ".pth"):
|
| 88 |
+
return _load_torch(path, device)
|
| 89 |
+
if ext == ".safetensors":
|
| 90 |
+
return _load_safetensors(path, device)
|
| 91 |
+
if ext == ".onnx":
|
| 92 |
+
return _load_onnx(path)
|
| 93 |
+
raise ValueError(f"Unsupported model format: {ext}")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _load_torch(path: str, device: str) -> Any:
|
| 97 |
+
import torch
|
| 98 |
+
# <<< REPLACE IN PRODUCTION >>> with proper model class instantiation
|
| 99 |
+
# TorchScript models can be loaded directly; state-dict models need
|
| 100 |
+
# the model class to be imported separately.
|
| 101 |
+
try:
|
| 102 |
+
model = torch.jit.load(path, map_location=device)
|
| 103 |
+
model.eval()
|
| 104 |
+
return model
|
| 105 |
+
except RuntimeError:
|
| 106 |
+
# Not a TorchScript model — try loading as a full checkpoint
|
| 107 |
+
obj = torch.load(path, map_location=device, weights_only=False)
|
| 108 |
+
if hasattr(obj, "eval"):
|
| 109 |
+
obj.eval()
|
| 110 |
+
return obj
|
| 111 |
+
# It's a state-dict — we cannot run inference without knowing the arch
|
| 112 |
+
raise RuntimeError(
|
| 113 |
+
f"Model at {path} is a state-dict; cannot run inference without "
|
| 114 |
+
"the model class. Use a TorchScript-exported .pt file."
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _load_safetensors(path: str, device: str) -> Any:
|
| 119 |
+
# <<< REPLACE IN PRODUCTION >>> safetensors gives tensors only;
|
| 120 |
+
# you still need the model class. This is intentionally left as a
|
| 121 |
+
# placeholder that raises a clear error rather than silently failing.
|
| 122 |
+
raise NotImplementedError(
|
| 123 |
+
"safetensors inference requires the model class to be registered. "
|
| 124 |
+
"Convert to TorchScript or ONNX for architecture-agnostic inference."
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _load_onnx(path: str) -> Any:
|
| 129 |
+
import onnxruntime as ort # type: ignore[import]
|
| 130 |
+
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 131 |
+
return ort.InferenceSession(path, providers=providers)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _forward(model: Any, dummy: Any, ext: str, device: str) -> Any:
|
| 135 |
+
"""Run a single forward pass, dispatching by model type."""
|
| 136 |
+
if ext == ".onnx":
|
| 137 |
+
import numpy as np
|
| 138 |
+
np_input = dummy.cpu().numpy()
|
| 139 |
+
input_name = model.get_inputs()[0].name
|
| 140 |
+
return model.run(None, {input_name: np_input})
|
| 141 |
+
# TorchScript / nn.Module
|
| 142 |
+
return model(dummy)
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# datasets package
|
datasets/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
datasets/__pycache__/annotation_parser.cpython-310.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|
datasets/__pycache__/base_adapter.cpython-310.pyc
ADDED
|
Binary file (2.04 kB). View file
|
|
|
datasets/__pycache__/format_adapters.cpython-310.pyc
ADDED
|
Binary file (9.18 kB). View file
|
|
|
datasets/__pycache__/import_service.cpython-310.pyc
ADDED
|
Binary file (16.8 kB). View file
|
|
|
datasets/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
datasets/__pycache__/viewer_service.cpython-310.pyc
ADDED
|
Binary file (8.22 kB). View file
|
|
|
datasets/annotation_parser.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
datasets/annotation_parser.py — Multi-format annotation parser.
|
| 3 |
+
|
| 4 |
+
Supports:
|
| 5 |
+
- YOLO (darknet .txt + classes.txt / data.yaml)
|
| 6 |
+
- COCO (instances_*.json / _annotations.coco.json)
|
| 7 |
+
- Pascal VOC (*.xml)
|
| 8 |
+
|
| 9 |
+
All formats normalise to the unified Annotation schema with
|
| 10 |
+
normalised bounding boxes (0–1 range, x_topleft, y_topleft, w, h).
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import csv
|
| 15 |
+
import json
|
| 16 |
+
import re
|
| 17 |
+
import uuid
|
| 18 |
+
import xml.etree.ElementTree as ET
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Iterator, Optional
|
| 21 |
+
|
| 22 |
+
from observability.logger import get_logger
|
| 23 |
+
|
| 24 |
+
log = get_logger("annotation_parser")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ── Unified Output ────────────────────────────────────────────────────────────
|
| 28 |
+
|
| 29 |
+
def _make_ann(
|
| 30 |
+
image_id: str,
|
| 31 |
+
dataset_id: str,
|
| 32 |
+
label: str,
|
| 33 |
+
bbox: tuple[float, float, float, float] | None = None, # x, y, w, h (normalised)
|
| 34 |
+
normalised: bool = True,
|
| 35 |
+
area: float | None = None,
|
| 36 |
+
confidence: float | None = None,
|
| 37 |
+
ann_type: str = "detection",
|
| 38 |
+
segmentation: list[list[float]] | None = None,
|
| 39 |
+
keypoints: list[float] | None = None,
|
| 40 |
+
metadata: dict | None = None,
|
| 41 |
+
) -> dict:
|
| 42 |
+
return {
|
| 43 |
+
"id": f"ann-{uuid.uuid4().hex[:12]}",
|
| 44 |
+
"image_id": image_id,
|
| 45 |
+
"dataset_id": dataset_id,
|
| 46 |
+
"label": label,
|
| 47 |
+
"bbox_x": bbox[0] if bbox else None,
|
| 48 |
+
"bbox_y": bbox[1] if bbox else None,
|
| 49 |
+
"bbox_w": bbox[2] if bbox else None,
|
| 50 |
+
"bbox_h": bbox[3] if bbox else None,
|
| 51 |
+
"normalised": 1 if normalised else 0,
|
| 52 |
+
"area": area,
|
| 53 |
+
"confidence": confidence,
|
| 54 |
+
"ann_type": ann_type,
|
| 55 |
+
"segmentation": json.dumps(segmentation) if segmentation else None,
|
| 56 |
+
"keypoints": json.dumps(keypoints) if keypoints else None,
|
| 57 |
+
"metadata": json.dumps(metadata) if metadata else None,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ── YOLO Parser ───────────────────────────────────────────────────────────────
|
| 62 |
+
|
| 63 |
+
class YOLOParser:
|
| 64 |
+
"""
|
| 65 |
+
Reads YOLO darknet annotation files (.txt) + class map.
|
| 66 |
+
Each line: <class_id> <cx> <cy> <w> <h> (all normalised 0–1)
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
@staticmethod
|
| 70 |
+
def load_class_map(dataset_root: Path) -> list[str]:
|
| 71 |
+
"""Attempt to load class names from data.yaml or classes.txt."""
|
| 72 |
+
# Try data.yaml first
|
| 73 |
+
for yaml_file in dataset_root.rglob("data.yaml"):
|
| 74 |
+
try:
|
| 75 |
+
import yaml
|
| 76 |
+
with open(yaml_file, 'r', encoding='utf-8', errors='replace') as f:
|
| 77 |
+
data = yaml.safe_load(f)
|
| 78 |
+
if data and 'names' in data:
|
| 79 |
+
names = data['names']
|
| 80 |
+
if isinstance(names, list):
|
| 81 |
+
return names
|
| 82 |
+
elif isinstance(names, dict):
|
| 83 |
+
# Handle dict format: {0: 'class_a', 1: 'class_b'}
|
| 84 |
+
return [names[i] for i in sorted(names.keys())]
|
| 85 |
+
except Exception:
|
| 86 |
+
# Fallback to regex if yaml import fails or parsing fails
|
| 87 |
+
try:
|
| 88 |
+
text = yaml_file.read_text(encoding="utf-8", errors="replace")
|
| 89 |
+
import re as _re
|
| 90 |
+
m = _re.search(r"names\s*:\s*\n((?:\s*-\s*.+\n?)+)", text)
|
| 91 |
+
if m:
|
| 92 |
+
return [line.strip().lstrip("- ").strip() for line in m.group(1).splitlines() if line.strip()]
|
| 93 |
+
except Exception:
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
# Try classes.txt
|
| 97 |
+
for cls_file in dataset_root.rglob("classes.txt"):
|
| 98 |
+
try:
|
| 99 |
+
lines = cls_file.read_text(encoding="utf-8", errors="replace").splitlines()
|
| 100 |
+
return [l.strip() for l in lines if l.strip()]
|
| 101 |
+
except Exception:
|
| 102 |
+
pass
|
| 103 |
+
|
| 104 |
+
return []
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def parse_file(
|
| 108 |
+
txt_path: Path,
|
| 109 |
+
image_id: str,
|
| 110 |
+
dataset_id: str,
|
| 111 |
+
class_map: list[str],
|
| 112 |
+
) -> list[dict]:
|
| 113 |
+
annotations = []
|
| 114 |
+
try:
|
| 115 |
+
text = txt_path.read_text(encoding="utf-8", errors="replace")
|
| 116 |
+
except OSError:
|
| 117 |
+
return annotations
|
| 118 |
+
|
| 119 |
+
for line in text.splitlines():
|
| 120 |
+
parts = line.strip().split()
|
| 121 |
+
if len(parts) < 5:
|
| 122 |
+
continue
|
| 123 |
+
try:
|
| 124 |
+
cls_id = int(parts[0])
|
| 125 |
+
cx, cy, w, h = float(parts[1]), float(parts[2]), float(parts[3]), float(parts[4])
|
| 126 |
+
# YOLO cx,cy → top-left x,y
|
| 127 |
+
x = cx - w / 2
|
| 128 |
+
y = cy - h / 2
|
| 129 |
+
label = class_map[cls_id] if cls_id < len(class_map) else str(cls_id)
|
| 130 |
+
annotations.append(
|
| 131 |
+
_make_ann(image_id, dataset_id, label, (x, y, w, h), area=w * h)
|
| 132 |
+
)
|
| 133 |
+
except (ValueError, IndexError):
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
return annotations
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def iter_dataset(
|
| 140 |
+
dataset_root: Path,
|
| 141 |
+
dataset_id: str,
|
| 142 |
+
class_map: list[str],
|
| 143 |
+
) -> Iterator[tuple[str, str, str, list[dict]]]:
|
| 144 |
+
"""
|
| 145 |
+
Yield (image_rel_path, image_id, split, annotations) for every image in the dataset.
|
| 146 |
+
Walks train/valid/test directories.
|
| 147 |
+
"""
|
| 148 |
+
# Supported subfolder names for splits
|
| 149 |
+
split_map = {
|
| 150 |
+
"train": ["train", "training"],
|
| 151 |
+
"val": ["valid", "val", "validation"],
|
| 152 |
+
"test": ["test", "testing"]
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
found_any = False
|
| 156 |
+
for split_name, folder_names in split_map.items():
|
| 157 |
+
for folder_name in folder_names:
|
| 158 |
+
split_dir = dataset_root / folder_name
|
| 159 |
+
images_dir = split_dir / "images"
|
| 160 |
+
|
| 161 |
+
# Support both split/images and split/ (if images are direct)
|
| 162 |
+
search_dir = images_dir if images_dir.exists() else split_dir
|
| 163 |
+
if not search_dir.exists():
|
| 164 |
+
continue
|
| 165 |
+
|
| 166 |
+
found_any = True
|
| 167 |
+
labels_dir = split_dir / "labels"
|
| 168 |
+
|
| 169 |
+
for img_path in sorted(search_dir.rglob("*")):
|
| 170 |
+
if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 174 |
+
|
| 175 |
+
# Resolve label path
|
| 176 |
+
# 1. split/labels/img.txt
|
| 177 |
+
# 2. split/img.txt
|
| 178 |
+
# 3. img_path.with_suffix(".txt")
|
| 179 |
+
label_candidates = []
|
| 180 |
+
if labels_dir.exists():
|
| 181 |
+
label_candidates.append(labels_dir / img_path.with_suffix(".txt").name)
|
| 182 |
+
label_candidates.append(img_path.with_suffix(".txt"))
|
| 183 |
+
|
| 184 |
+
anns: list[dict] = []
|
| 185 |
+
for label_file in label_candidates:
|
| 186 |
+
if label_file.exists():
|
| 187 |
+
anns = YOLOParser.parse_file(label_file, image_id, dataset_id, class_map)
|
| 188 |
+
break
|
| 189 |
+
|
| 190 |
+
rel_path = str(img_path.relative_to(dataset_root))
|
| 191 |
+
yield rel_path, image_id, split_name, anns
|
| 192 |
+
|
| 193 |
+
# Fallback: if no split folders found, scan the root
|
| 194 |
+
if not found_any:
|
| 195 |
+
for img_path in sorted(dataset_root.rglob("*")):
|
| 196 |
+
if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
|
| 197 |
+
continue
|
| 198 |
+
# Skip files inside already processed folders if we had any
|
| 199 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 200 |
+
label_file = img_path.with_suffix(".txt")
|
| 201 |
+
anns = []
|
| 202 |
+
if label_file.exists():
|
| 203 |
+
anns = YOLOParser.parse_file(label_file, image_id, dataset_id, class_map)
|
| 204 |
+
|
| 205 |
+
rel_path = str(img_path.relative_to(dataset_root))
|
| 206 |
+
yield rel_path, image_id, "train", anns
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ── COCO Parser ───────────────────────────────────────────────────────────────
|
| 210 |
+
|
| 211 |
+
class COCOParser:
|
| 212 |
+
"""
|
| 213 |
+
Reads COCO JSON annotation files.
|
| 214 |
+
Supports: instances_train.json, instances_val.json, _annotations.coco.json
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
@staticmethod
|
| 218 |
+
def find_annotation_files(dataset_root: Path) -> list[Path]:
|
| 219 |
+
patterns = ["instances_*.json", "_annotations.coco.json", "*.json"]
|
| 220 |
+
found = []
|
| 221 |
+
for pat in patterns:
|
| 222 |
+
for f in dataset_root.rglob(pat):
|
| 223 |
+
if "label" not in f.name.lower() and "class" not in f.name.lower():
|
| 224 |
+
found.append(f)
|
| 225 |
+
return list(dict.fromkeys(found)) # deduplicate
|
| 226 |
+
|
| 227 |
+
@staticmethod
|
| 228 |
+
def parse_file(
|
| 229 |
+
json_path: Path,
|
| 230 |
+
dataset_id: str,
|
| 231 |
+
) -> tuple[list[str], list[tuple[str, str, str, list[dict]]]]:
|
| 232 |
+
"""
|
| 233 |
+
Returns: (class_names, [(rel_image_path, image_id, split, annotations)])
|
| 234 |
+
"""
|
| 235 |
+
try:
|
| 236 |
+
data = json.loads(json_path.read_text(encoding="utf-8"))
|
| 237 |
+
except (OSError, json.JSONDecodeError) as e:
|
| 238 |
+
log.warning("coco_parse_error", file=str(json_path), error=str(e))
|
| 239 |
+
return [], []
|
| 240 |
+
|
| 241 |
+
categories = {c["id"]: c["name"] for c in data.get("categories", [])}
|
| 242 |
+
class_names = list(categories.values())
|
| 243 |
+
|
| 244 |
+
# Determine split from filename
|
| 245 |
+
fname = json_path.stem.lower()
|
| 246 |
+
if "train" in fname:
|
| 247 |
+
split = "train"
|
| 248 |
+
elif "val" in fname or "valid" in fname:
|
| 249 |
+
split = "val"
|
| 250 |
+
elif "test" in fname:
|
| 251 |
+
split = "test"
|
| 252 |
+
else:
|
| 253 |
+
split = "train"
|
| 254 |
+
|
| 255 |
+
# Build image map
|
| 256 |
+
image_map: dict[int, dict] = {
|
| 257 |
+
img["id"]: img for img in data.get("images", [])
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
# Group annotations by image
|
| 261 |
+
ann_by_image: dict[int, list] = {}
|
| 262 |
+
for ann in data.get("annotations", []):
|
| 263 |
+
ann_by_image.setdefault(ann["image_id"], []).append(ann)
|
| 264 |
+
|
| 265 |
+
results = []
|
| 266 |
+
for coco_img_id, img_meta in image_map.items():
|
| 267 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 268 |
+
rel_path = img_meta.get("file_name", "")
|
| 269 |
+
anns = []
|
| 270 |
+
for coco_ann in ann_by_image.get(coco_img_id, []):
|
| 271 |
+
label = categories.get(coco_ann.get("category_id", -1), "unknown")
|
| 272 |
+
bbox = coco_ann.get("bbox", [])
|
| 273 |
+
if len(bbox) == 4:
|
| 274 |
+
# COCO: [x_topleft, y_topleft, w, h] in pixel coords
|
| 275 |
+
img_w = img_meta.get("width", 1) or 1
|
| 276 |
+
img_h = img_meta.get("height", 1) or 1
|
| 277 |
+
bx = bbox[0] / img_w
|
| 278 |
+
by = bbox[1] / img_h
|
| 279 |
+
bw = bbox[2] / img_w
|
| 280 |
+
bh = bbox[3] / img_h
|
| 281 |
+
area_pct = (bbox[2] * bbox[3]) / (img_w * img_h)
|
| 282 |
+
|
| 283 |
+
# Extract segmentation if available
|
| 284 |
+
segmentation = coco_ann.get("segmentation")
|
| 285 |
+
# COCO segmentation can be a list of polygons or RLE
|
| 286 |
+
poly_data = None
|
| 287 |
+
if isinstance(segmentation, list) and len(segmentation) > 0:
|
| 288 |
+
# Normalize polygon coordinates
|
| 289 |
+
poly_data = []
|
| 290 |
+
for poly in segmentation:
|
| 291 |
+
normalized_poly = []
|
| 292 |
+
for i in range(0, len(poly), 2):
|
| 293 |
+
normalized_poly.append(poly[i] / img_w)
|
| 294 |
+
normalized_poly.append(poly[i+1] / img_h)
|
| 295 |
+
poly_data.append(normalized_poly)
|
| 296 |
+
|
| 297 |
+
anns.append(
|
| 298 |
+
_make_ann(
|
| 299 |
+
image_id,
|
| 300 |
+
dataset_id,
|
| 301 |
+
label,
|
| 302 |
+
(bx, by, bw, bh),
|
| 303 |
+
area=area_pct,
|
| 304 |
+
segmentation=poly_data,
|
| 305 |
+
ann_type="segmentation" if poly_data else "detection"
|
| 306 |
+
)
|
| 307 |
+
)
|
| 308 |
+
results.append((rel_path, image_id, split, anns))
|
| 309 |
+
|
| 310 |
+
return class_names, results
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# ── VOC Parser ────────────────────────────────────────────────────────────────
|
| 314 |
+
|
| 315 |
+
class VOCParser:
|
| 316 |
+
"""Reads Pascal VOC XML annotation files."""
|
| 317 |
+
|
| 318 |
+
@staticmethod
|
| 319 |
+
def parse_file(
|
| 320 |
+
xml_path: Path,
|
| 321 |
+
image_id: str,
|
| 322 |
+
dataset_id: str,
|
| 323 |
+
) -> tuple[str, int, int, list[dict]]:
|
| 324 |
+
"""Returns (filename, width, height, annotations)."""
|
| 325 |
+
try:
|
| 326 |
+
tree = ET.parse(str(xml_path))
|
| 327 |
+
except ET.ParseError as e:
|
| 328 |
+
log.warning("voc_parse_error", file=str(xml_path), error=str(e))
|
| 329 |
+
return "", 0, 0, []
|
| 330 |
+
|
| 331 |
+
root = tree.getroot()
|
| 332 |
+
filename = root.findtext("filename") or ""
|
| 333 |
+
size = root.find("size")
|
| 334 |
+
img_w = int(size.findtext("width") or 1) if size is not None else 1
|
| 335 |
+
img_h = int(size.findtext("height") or 1) if size is not None else 1
|
| 336 |
+
|
| 337 |
+
anns = []
|
| 338 |
+
for obj in root.findall("object"):
|
| 339 |
+
label = obj.findtext("name") or "unknown"
|
| 340 |
+
bndbox = obj.find("bndbox")
|
| 341 |
+
if bndbox is None:
|
| 342 |
+
continue
|
| 343 |
+
xmin = float(bndbox.findtext("xmin") or 0)
|
| 344 |
+
ymin = float(bndbox.findtext("ymin") or 0)
|
| 345 |
+
xmax = float(bndbox.findtext("xmax") or 0)
|
| 346 |
+
ymax = float(bndbox.findtext("ymax") or 0)
|
| 347 |
+
# Normalise
|
| 348 |
+
bx = xmin / img_w
|
| 349 |
+
by = ymin / img_h
|
| 350 |
+
bw = (xmax - xmin) / img_w
|
| 351 |
+
bh = (ymax - ymin) / img_h
|
| 352 |
+
anns.append(_make_ann(image_id, dataset_id, label, (bx, by, bw, bh)))
|
| 353 |
+
|
| 354 |
+
return filename, img_w, img_h, anns
|
| 355 |
+
|
| 356 |
+
@staticmethod
|
| 357 |
+
def iter_dataset(
|
| 358 |
+
dataset_root: Path,
|
| 359 |
+
dataset_id: str,
|
| 360 |
+
) -> Iterator[tuple[str, str, str, int, int, list[dict]]]:
|
| 361 |
+
"""Yield (rel_path, image_id, split, w, h, annotations)."""
|
| 362 |
+
for xml_path in sorted(dataset_root.rglob("*.xml")):
|
| 363 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 364 |
+
filename, w, h, anns = VOCParser.parse_file(xml_path, image_id, dataset_id)
|
| 365 |
+
split = "train"
|
| 366 |
+
for part in xml_path.parts:
|
| 367 |
+
if part in ("train", "training"):
|
| 368 |
+
split = "train"; break
|
| 369 |
+
if part in ("val", "valid", "validation"):
|
| 370 |
+
split = "val"; break
|
| 371 |
+
if part in ("test", "testing"):
|
| 372 |
+
split = "test"; break
|
| 373 |
+
rel_path = filename or str(xml_path.with_suffix(".jpg").relative_to(dataset_root))
|
| 374 |
+
yield rel_path, image_id, split, w, h, anns
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# ── Roboflow TXT Parser ───────────────────────────────────────────────────────
|
| 378 |
+
|
| 379 |
+
class RoboflowTXTParser:
|
| 380 |
+
"""
|
| 381 |
+
Reads Roboflow classification TXT formats.
|
| 382 |
+
1. Folder-based: split/class_name/image.jpg
|
| 383 |
+
2. Label-file: split/_annotations.txt (format: filename,class_name)
|
| 384 |
+
"""
|
| 385 |
+
|
| 386 |
+
@staticmethod
|
| 387 |
+
def iter_dataset(
|
| 388 |
+
dataset_root: Path,
|
| 389 |
+
dataset_id: str,
|
| 390 |
+
) -> Iterator[tuple[str, str, str, list[dict]]]:
|
| 391 |
+
split_map = {
|
| 392 |
+
"train": ["train", "training"],
|
| 393 |
+
"val": ["valid", "val", "validation"],
|
| 394 |
+
"test": ["test", "testing"]
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
found_any = False
|
| 398 |
+
for split_name, folder_names in split_map.items():
|
| 399 |
+
for folder_name in folder_names:
|
| 400 |
+
split_dir = dataset_root / folder_name
|
| 401 |
+
if not split_dir.exists():
|
| 402 |
+
continue
|
| 403 |
+
|
| 404 |
+
found_any = True
|
| 405 |
+
|
| 406 |
+
# Check for _annotations.txt (Roboflow's flat format)
|
| 407 |
+
ann_file = split_dir / "_annotations.txt"
|
| 408 |
+
if ann_file.exists():
|
| 409 |
+
try:
|
| 410 |
+
with open(ann_file, "r", encoding="utf-8") as f:
|
| 411 |
+
# Format is usually: filename,class_name
|
| 412 |
+
for line in f:
|
| 413 |
+
parts = line.strip().split(",")
|
| 414 |
+
if len(parts) >= 2:
|
| 415 |
+
fname, label = parts[0], parts[1]
|
| 416 |
+
img_path = split_dir / fname
|
| 417 |
+
if img_path.exists():
|
| 418 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 419 |
+
anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")]
|
| 420 |
+
rel_path = str(img_path.relative_to(dataset_root))
|
| 421 |
+
yield rel_path, image_id, split_name, anns
|
| 422 |
+
continue # Processed via file, skip folder logic
|
| 423 |
+
except Exception:
|
| 424 |
+
pass
|
| 425 |
+
|
| 426 |
+
# Fallback to Folder-based: split/class_name/image.jpg
|
| 427 |
+
for class_dir in split_dir.iterdir():
|
| 428 |
+
if class_dir.is_dir() and class_dir.name.lower() not in ["images", "labels"]:
|
| 429 |
+
label = class_dir.name
|
| 430 |
+
for img_path in class_dir.rglob("*"):
|
| 431 |
+
if img_path.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
|
| 432 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 433 |
+
anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")]
|
| 434 |
+
rel_path = str(img_path.relative_to(dataset_root))
|
| 435 |
+
yield rel_path, image_id, split_name, anns
|
| 436 |
+
|
| 437 |
+
# Fallback to root scan if no split folders found
|
| 438 |
+
if not found_any:
|
| 439 |
+
for img_path in sorted(dataset_root.rglob("*")):
|
| 440 |
+
if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
|
| 441 |
+
continue
|
| 442 |
+
# Simple heuristic: parent folder is class name
|
| 443 |
+
label = img_path.parent.name if img_path.parent != dataset_root else "unknown"
|
| 444 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 445 |
+
anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")]
|
| 446 |
+
rel_path = str(img_path.relative_to(dataset_root))
|
| 447 |
+
yield rel_path, image_id, "train", anns
|
| 448 |
+
|
| 449 |
+
class CSVParser:
|
| 450 |
+
"""
|
| 451 |
+
Reads CSV files for NLP (classification, NER) or Tabular data.
|
| 452 |
+
"""
|
| 453 |
+
|
| 454 |
+
@staticmethod
|
| 455 |
+
def detect_delimiter(file_path: Path) -> str:
|
| 456 |
+
try:
|
| 457 |
+
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
| 458 |
+
header = f.readline()
|
| 459 |
+
if ';' in header: return ';'
|
| 460 |
+
if '\t' in header: return '\t'
|
| 461 |
+
return ','
|
| 462 |
+
except Exception:
|
| 463 |
+
return ','
|
| 464 |
+
|
| 465 |
+
@staticmethod
|
| 466 |
+
def parse_file(
|
| 467 |
+
csv_path: Path,
|
| 468 |
+
dataset_id: str,
|
| 469 |
+
text_column: str = "text",
|
| 470 |
+
label_column: str = "label",
|
| 471 |
+
) -> list[dict]:
|
| 472 |
+
annotations = []
|
| 473 |
+
delimiter = CSVParser.detect_delimiter(csv_path)
|
| 474 |
+
try:
|
| 475 |
+
with open(csv_path, mode='r', encoding='utf-8', errors='replace') as f:
|
| 476 |
+
reader = csv.DictReader(f, delimiter=delimiter)
|
| 477 |
+
for row in reader:
|
| 478 |
+
image_id = f"txt-{uuid.uuid4().hex[:12]}"
|
| 479 |
+
text = row.get(text_column, "")
|
| 480 |
+
label = row.get(label_column, "unknown")
|
| 481 |
+
if text:
|
| 482 |
+
annotations.append(
|
| 483 |
+
_make_ann(
|
| 484 |
+
image_id=image_id,
|
| 485 |
+
dataset_id=dataset_id,
|
| 486 |
+
label=label,
|
| 487 |
+
bbox=(0, 0, 0, 0),
|
| 488 |
+
ann_type="nlp_classification"
|
| 489 |
+
)
|
| 490 |
+
)
|
| 491 |
+
except Exception as e:
|
| 492 |
+
log.error("csv_parse_error", file=str(csv_path), error=str(e))
|
| 493 |
+
return annotations
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# ── Utilities ────────────────────────────────────────────────────────────────
|
| 497 |
+
|
| 498 |
+
def _img_dimensions(path: Path) -> tuple[int, int]:
|
| 499 |
+
"""Fast dimension detection via struct."""
|
| 500 |
+
try:
|
| 501 |
+
import struct
|
| 502 |
+
with open(path, "rb") as f:
|
| 503 |
+
data = f.read(24)
|
| 504 |
+
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
| 505 |
+
return struct.unpack(">II", data[16:24])
|
| 506 |
+
if data[:2] == b"\xff\xd8":
|
| 507 |
+
f.seek(0)
|
| 508 |
+
full = f.read(2048) # Read more for JPEG header
|
| 509 |
+
i = 2
|
| 510 |
+
while i < len(full) - 9:
|
| 511 |
+
if full[i] == 0xFF and full[i + 1] in (0xC0, 0xC1, 0xC2):
|
| 512 |
+
h, w = struct.unpack(">HH", full[i + 5:i + 9])
|
| 513 |
+
return int(w), int(h)
|
| 514 |
+
i += 1
|
| 515 |
+
except: pass
|
| 516 |
+
return 0, 0
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
# ── Format Detector ───────────────────────────────────────────────────────────
|
| 520 |
+
|
| 521 |
+
def detect_format(dataset_root: Path) -> str:
|
| 522 |
+
"""Heuristically detect the annotation format in a dataset directory."""
|
| 523 |
+
# COCO: look for JSON with 'images' and 'annotations' keys
|
| 524 |
+
for jf in dataset_root.rglob("*.json"):
|
| 525 |
+
try:
|
| 526 |
+
snippet = jf.read_text(encoding="utf-8", errors="replace")[:2048]
|
| 527 |
+
if '"images"' in snippet and '"annotations"' in snippet:
|
| 528 |
+
return "coco"
|
| 529 |
+
except OSError:
|
| 530 |
+
pass
|
| 531 |
+
|
| 532 |
+
# VOC: look for XML files with <annotation> root
|
| 533 |
+
for xf in dataset_root.rglob("*.xml"):
|
| 534 |
+
try:
|
| 535 |
+
snippet = xf.read_text(encoding="utf-8", errors="replace")[:512]
|
| 536 |
+
if "<annotation>" in snippet:
|
| 537 |
+
return "voc"
|
| 538 |
+
except OSError:
|
| 539 |
+
pass
|
| 540 |
+
|
| 541 |
+
# YOLO: check for .txt label files and data.yaml
|
| 542 |
+
if list(dataset_root.rglob("data.yaml")):
|
| 543 |
+
return "yolo"
|
| 544 |
+
|
| 545 |
+
txt_files = list(dataset_root.rglob("*.txt"))
|
| 546 |
+
# Filter out common non-label files
|
| 547 |
+
label_txts = [f for f in txt_files if f.name not in ("classes.txt", "obj.names", "README.txt", "LICENSE.txt", "README.roboflow.txt")]
|
| 548 |
+
if label_txts:
|
| 549 |
+
# Check if first line looks like YOLO (<int> <float> <float> <float> <float>)
|
| 550 |
+
try:
|
| 551 |
+
first_txt = label_txts[0]
|
| 552 |
+
content = first_txt.read_text(encoding="utf-8").strip().split('\n')[0]
|
| 553 |
+
if re.match(r"^\d+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+", content):
|
| 554 |
+
return "yolo"
|
| 555 |
+
except Exception:
|
| 556 |
+
pass
|
| 557 |
+
|
| 558 |
+
# Roboflow Classification TXT: check for split folders containing only subfolders (class names)
|
| 559 |
+
# or check for _annotations.txt
|
| 560 |
+
if list(dataset_root.rglob("_annotations.txt")):
|
| 561 |
+
return "txt"
|
| 562 |
+
|
| 563 |
+
# Check for folder-based classification (split/class_name/img.jpg)
|
| 564 |
+
# If we see folders that aren't 'images' or 'labels' inside train/val/test
|
| 565 |
+
for split in ["train", "valid", "test"]:
|
| 566 |
+
split_dir = dataset_root / split
|
| 567 |
+
if split_dir.exists() and split_dir.is_dir():
|
| 568 |
+
subdirs = [d for d in split_dir.iterdir() if d.is_dir()]
|
| 569 |
+
if subdirs and not any(d.name.lower() in ["images", "labels"] for d in subdirs):
|
| 570 |
+
return "txt"
|
| 571 |
+
|
| 572 |
+
# CSV/NLP: check for csv files
|
| 573 |
+
if list(dataset_root.rglob("*.csv")):
|
| 574 |
+
return "csv"
|
| 575 |
+
|
| 576 |
+
return "custom"
|
datasets/base_adapter.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Tuple, Iterator, Dict, Any, Optional
|
| 4 |
+
from models.dataset import UniversalDatasetItem, DatasetTask
|
| 5 |
+
|
| 6 |
+
class DatasetAdapter(ABC):
|
| 7 |
+
"""
|
| 8 |
+
Base interface for all dataset format adapters.
|
| 9 |
+
Following the senior architect pattern: decoupling format logic from import orchestration.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
@abstractmethod
|
| 13 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 14 |
+
"""Return True if this adapter can handle the dataset at the given path."""
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 19 |
+
"""Identify the primary task type (detection, classification, etc.) for this dataset."""
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
@abstractmethod
|
| 23 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 24 |
+
"""
|
| 25 |
+
Yield (image_record, annotations) for each item in the dataset.
|
| 26 |
+
Memory-efficient streaming for large Roboflow datasets.
|
| 27 |
+
"""
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
@abstractmethod
|
| 31 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 32 |
+
"""Extract or derive the list of class names from the dataset."""
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
def get_metadata(self, dataset_path: Path) -> Dict[str, Any]:
|
| 36 |
+
"""Optional: Extract additional format-specific metadata."""
|
| 37 |
+
return {}
|
datasets/format_adapters.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any, List, Tuple, Iterator, Dict
|
| 5 |
+
from .base_adapter import DatasetAdapter
|
| 6 |
+
from models.dataset import UniversalDatasetItem, DatasetContentType, UniversalAnnotation, UniversalAnnotationType, DatasetTask
|
| 7 |
+
from .annotation_parser import YOLOParser, COCOParser, VOCParser, RoboflowTXTParser, _img_dimensions
|
| 8 |
+
|
| 9 |
+
class YOLOAdapter(DatasetAdapter):
|
| 10 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 11 |
+
if list(dataset_path.rglob("data.yaml")):
|
| 12 |
+
return True
|
| 13 |
+
txt_files = list(dataset_path.rglob("*.txt"))
|
| 14 |
+
label_txts = [f for f in txt_files if f.name not in ("classes.txt", "obj.names", "README.txt", "LICENSE.txt", "README.roboflow.txt")]
|
| 15 |
+
if label_txts:
|
| 16 |
+
try:
|
| 17 |
+
content = label_txts[0].read_text(encoding="utf-8").strip().split('\n')[0]
|
| 18 |
+
if re.match(r"^\d+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+", content):
|
| 19 |
+
return True
|
| 20 |
+
except: pass
|
| 21 |
+
return False
|
| 22 |
+
|
| 23 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 24 |
+
return DatasetTask.detection
|
| 25 |
+
|
| 26 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 27 |
+
return YOLOParser.load_class_map(dataset_path)
|
| 28 |
+
|
| 29 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 30 |
+
class_map = self.get_class_names(dataset_path)
|
| 31 |
+
for rel_path, image_id, split, anns in YOLOParser.iter_dataset(dataset_path, dataset_id, class_map):
|
| 32 |
+
abs_path = dataset_path / rel_path
|
| 33 |
+
w, h = _img_dimensions(abs_path)
|
| 34 |
+
img_rec = {
|
| 35 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 36 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 37 |
+
"split": split, "ann_count": len(anns),
|
| 38 |
+
}
|
| 39 |
+
yield img_rec, anns
|
| 40 |
+
|
| 41 |
+
class COCOAdapter(DatasetAdapter):
|
| 42 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 43 |
+
for jf in dataset_path.rglob("*.json"):
|
| 44 |
+
try:
|
| 45 |
+
snippet = jf.read_text(encoding="utf-8", errors="replace")[:2048]
|
| 46 |
+
if '"images"' in snippet and '"annotations"' in snippet:
|
| 47 |
+
return True
|
| 48 |
+
except: pass
|
| 49 |
+
return False
|
| 50 |
+
|
| 51 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 52 |
+
return DatasetTask.segmentation # Roboflow COCO often implies segmentation
|
| 53 |
+
|
| 54 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 55 |
+
ann_files = COCOParser.find_annotation_files(dataset_path)
|
| 56 |
+
all_classes = []
|
| 57 |
+
for ann_file in ann_files:
|
| 58 |
+
classes, _ = COCOParser.parse_file(ann_file, "dummy")
|
| 59 |
+
all_classes = list(dict.fromkeys(all_classes + classes))
|
| 60 |
+
return all_classes
|
| 61 |
+
|
| 62 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 63 |
+
ann_files = COCOParser.find_annotation_files(dataset_path)
|
| 64 |
+
for ann_file in ann_files:
|
| 65 |
+
_, coco_results = COCOParser.parse_file(ann_file, dataset_id)
|
| 66 |
+
for rel_path, image_id, split, anns in coco_results:
|
| 67 |
+
abs_path = dataset_path / rel_path
|
| 68 |
+
w, h = _img_dimensions(abs_path)
|
| 69 |
+
img_rec = {
|
| 70 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 71 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 72 |
+
"split": split, "ann_count": len(anns),
|
| 73 |
+
}
|
| 74 |
+
yield img_rec, anns
|
| 75 |
+
|
| 76 |
+
class VOCAdapter(DatasetAdapter):
|
| 77 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 78 |
+
for xf in dataset_path.rglob("*.xml"):
|
| 79 |
+
try:
|
| 80 |
+
snippet = xf.read_text(encoding="utf-8", errors="replace")[:512]
|
| 81 |
+
if "<annotation>" in snippet:
|
| 82 |
+
return True
|
| 83 |
+
except: pass
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 87 |
+
return DatasetTask.detection
|
| 88 |
+
|
| 89 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 90 |
+
classes = set()
|
| 91 |
+
for _, _, _, _, _, anns in VOCParser.iter_dataset(dataset_path, "dummy"):
|
| 92 |
+
for ann in anns:
|
| 93 |
+
classes.add(ann["label"])
|
| 94 |
+
return sorted(list(classes))
|
| 95 |
+
|
| 96 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 97 |
+
for rel_path, image_id, split, w, h, anns in VOCParser.iter_dataset(dataset_path, dataset_id):
|
| 98 |
+
img_rec = {
|
| 99 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 100 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 101 |
+
"split": split, "ann_count": len(anns),
|
| 102 |
+
}
|
| 103 |
+
yield img_rec, anns
|
| 104 |
+
|
| 105 |
+
class CreateMLAdapter(DatasetAdapter):
|
| 106 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 107 |
+
for jf in dataset_path.rglob("*.json"):
|
| 108 |
+
try:
|
| 109 |
+
snippet = jf.read_text(encoding="utf-8", errors="replace")[:1024]
|
| 110 |
+
if '"image"' in snippet and '"annotations"' in snippet and "[" in snippet:
|
| 111 |
+
return True
|
| 112 |
+
except: pass
|
| 113 |
+
return False
|
| 114 |
+
|
| 115 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 116 |
+
return DatasetTask.detection
|
| 117 |
+
|
| 118 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 119 |
+
classes = set()
|
| 120 |
+
for jf in dataset_path.rglob("*.json"):
|
| 121 |
+
try:
|
| 122 |
+
data = json.loads(jf.read_text(encoding="utf-8"))
|
| 123 |
+
if isinstance(data, list):
|
| 124 |
+
for item in data:
|
| 125 |
+
for ann in item.get("annotations", []):
|
| 126 |
+
if "label" in ann: classes.add(ann["label"])
|
| 127 |
+
except: pass
|
| 128 |
+
return sorted(list(classes))
|
| 129 |
+
|
| 130 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 131 |
+
from .annotation_parser import _make_ann
|
| 132 |
+
for jf in dataset_path.rglob("*.json"):
|
| 133 |
+
try:
|
| 134 |
+
data = json.loads(jf.read_text(encoding="utf-8"))
|
| 135 |
+
if not isinstance(data, list): continue
|
| 136 |
+
|
| 137 |
+
# Determine split from path
|
| 138 |
+
split = "train"
|
| 139 |
+
if "val" in jf.parts or "valid" in jf.parts: split = "val"
|
| 140 |
+
elif "test" in jf.parts: split = "test"
|
| 141 |
+
|
| 142 |
+
for item in data:
|
| 143 |
+
rel_img_path = item.get("image")
|
| 144 |
+
if not rel_img_path: continue
|
| 145 |
+
|
| 146 |
+
# Try to find the image relative to JSON or root
|
| 147 |
+
img_path = jf.parent / rel_img_path
|
| 148 |
+
if not img_path.exists():
|
| 149 |
+
img_path = dataset_path / rel_img_path
|
| 150 |
+
|
| 151 |
+
if img_path.exists():
|
| 152 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 153 |
+
w, h = _img_dimensions(img_path)
|
| 154 |
+
|
| 155 |
+
anns = []
|
| 156 |
+
for ca in item.get("annotations", []):
|
| 157 |
+
label = ca.get("label", "unknown")
|
| 158 |
+
coord = ca.get("coordinates", {})
|
| 159 |
+
# CreateML coords are usually center-based pixels: {x, y, width, height}
|
| 160 |
+
if "x" in coord and "y" in coord and w > 0 and h > 0:
|
| 161 |
+
cx, cy, bw, bh = coord["x"], coord["y"], coord["width"], coord["height"]
|
| 162 |
+
# Convert to top-left normalized
|
| 163 |
+
nx = (cx - bw/2) / w
|
| 164 |
+
ny = (cy - bh/2) / h
|
| 165 |
+
nw = bw / w
|
| 166 |
+
nh = bh / h
|
| 167 |
+
anns.append(_make_ann(image_id, dataset_id, label, (nx, ny, nw, nh)))
|
| 168 |
+
|
| 169 |
+
img_rec = {
|
| 170 |
+
"id": image_id, "filename": img_path.name,
|
| 171 |
+
"rel_path": str(img_path.relative_to(dataset_path)),
|
| 172 |
+
"width": w, "height": h, "split": split, "ann_count": len(anns)
|
| 173 |
+
}
|
| 174 |
+
yield img_rec, anns
|
| 175 |
+
except: pass
|
| 176 |
+
|
| 177 |
+
class NLPAdapter(DatasetAdapter):
|
| 178 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 179 |
+
return any(dataset_path.rglob("*.csv")) or any(dataset_path.rglob("*.tsv"))
|
| 180 |
+
|
| 181 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 182 |
+
return DatasetTask.nlp
|
| 183 |
+
|
| 184 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 185 |
+
# Implementation for NLP class names
|
| 186 |
+
return []
|
| 187 |
+
|
| 188 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 189 |
+
# Implementation for NLP items
|
| 190 |
+
yield {}, []
|
| 191 |
+
|
| 192 |
+
class TabularAdapter(DatasetAdapter):
|
| 193 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 194 |
+
return False # Placeholder
|
| 195 |
+
|
| 196 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 197 |
+
return DatasetTask.classification
|
| 198 |
+
|
| 199 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 200 |
+
return []
|
| 201 |
+
|
| 202 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 203 |
+
yield {}, []
|
| 204 |
+
|
| 205 |
+
class RoboflowClassificationAdapter(DatasetAdapter):
|
| 206 |
+
def detect(self, dataset_path: Path) -> bool:
|
| 207 |
+
# Check for _annotations.txt or folder-based classification
|
| 208 |
+
if list(dataset_path.rglob("_annotations.txt")): return True
|
| 209 |
+
for split in ["train", "valid", "test"]:
|
| 210 |
+
split_dir = dataset_path / split
|
| 211 |
+
if split_dir.exists() and split_dir.is_dir():
|
| 212 |
+
subdirs = [d for d in split_dir.iterdir() if d.is_dir()]
|
| 213 |
+
if subdirs and not any(d.name.lower() in ["images", "labels"] for d in subdirs):
|
| 214 |
+
return True
|
| 215 |
+
return False
|
| 216 |
+
|
| 217 |
+
def get_task(self, dataset_path: Path) -> DatasetTask:
|
| 218 |
+
return DatasetTask.classification
|
| 219 |
+
|
| 220 |
+
def get_class_names(self, dataset_path: Path) -> List[str]:
|
| 221 |
+
classes = set()
|
| 222 |
+
for _, _, _, anns in RoboflowTXTParser.iter_dataset(dataset_path, "dummy"):
|
| 223 |
+
for ann in anns: classes.add(ann["label"])
|
| 224 |
+
return sorted(list(classes))
|
| 225 |
+
|
| 226 |
+
def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
|
| 227 |
+
for rel_path, image_id, split, anns in RoboflowTXTParser.iter_dataset(dataset_path, dataset_id):
|
| 228 |
+
abs_path = dataset_path / rel_path
|
| 229 |
+
w, h = _img_dimensions(abs_path)
|
| 230 |
+
img_rec = {
|
| 231 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 232 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 233 |
+
"split": split, "ann_count": len(anns),
|
| 234 |
+
}
|
| 235 |
+
yield img_rec, anns
|
datasets/import_service.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
datasets/import_service.py — Dataset Import Pipeline.
|
| 3 |
+
|
| 4 |
+
Pipeline stages:
|
| 5 |
+
1. Create job record
|
| 6 |
+
2. Download dataset zip (chunked, progress-tracked)
|
| 7 |
+
3. Extract zip safely (path-traversal protected)
|
| 8 |
+
4. Detect annotation format & task type
|
| 9 |
+
5. Index images into dataset_images table
|
| 10 |
+
6. Parse & store metadata (Stats only, annotations are read-on-demand)
|
| 11 |
+
7. Update dataset stats (images, classes, size)
|
| 12 |
+
8. Mark job completed / failed
|
| 13 |
+
|
| 14 |
+
All stages run as background tasks.
|
| 15 |
+
Supports Roboflow, HuggingFace, and local file/folder imports.
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import asyncio
|
| 20 |
+
import hashlib
|
| 21 |
+
import os
|
| 22 |
+
import shutil
|
| 23 |
+
import uuid
|
| 24 |
+
import zipfile
|
| 25 |
+
from datetime import datetime
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 28 |
+
|
| 29 |
+
import aiofiles
|
| 30 |
+
import httpx
|
| 31 |
+
from huggingface_hub import snapshot_download
|
| 32 |
+
|
| 33 |
+
from config import settings
|
| 34 |
+
from . import registry as ds_reg
|
| 35 |
+
from .format_adapters import (
|
| 36 |
+
YOLOAdapter, COCOAdapter, VOCAdapter, CreateMLAdapter,
|
| 37 |
+
RoboflowClassificationAdapter, NLPAdapter, TabularAdapter
|
| 38 |
+
)
|
| 39 |
+
from .base_adapter import DatasetAdapter
|
| 40 |
+
from .annotation_parser import _img_dimensions
|
| 41 |
+
from observability.logger import audit, get_logger
|
| 42 |
+
from models.dataset import DatasetStatus, DatasetTask, ImportRequest, Dataset
|
| 43 |
+
|
| 44 |
+
log = get_logger("import_service")
|
| 45 |
+
|
| 46 |
+
ADAPTERS: List[DatasetAdapter] = [
|
| 47 |
+
YOLOAdapter(),
|
| 48 |
+
COCOAdapter(),
|
| 49 |
+
VOCAdapter(),
|
| 50 |
+
CreateMLAdapter(),
|
| 51 |
+
RoboflowClassificationAdapter(),
|
| 52 |
+
NLPAdapter(),
|
| 53 |
+
TabularAdapter(),
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
def get_adapter_for_path(path: Path) -> DatasetAdapter | None:
|
| 57 |
+
for adapter in ADAPTERS:
|
| 58 |
+
if adapter.detect(path):
|
| 59 |
+
return adapter
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
async def recover_stale_jobs() -> None:
|
| 63 |
+
"""Cleanup dataset import jobs that were left in 'running' or 'queued' state."""
|
| 64 |
+
await ds_reg.cleanup_stale_jobs()
|
| 65 |
+
|
| 66 |
+
def _dataset_path(dataset_id: str) -> Path:
|
| 67 |
+
return settings.datasets_dir / dataset_id
|
| 68 |
+
|
| 69 |
+
# ── Entry Point ──────────────────────────────────────────────────────────────
|
| 70 |
+
|
| 71 |
+
async def start_import(req: ImportRequest) -> str:
|
| 72 |
+
"""Entry point to initiate a background import job."""
|
| 73 |
+
job_id = f"job-{uuid.uuid4().hex[:8]}"
|
| 74 |
+
|
| 75 |
+
# Create initial job record
|
| 76 |
+
await ds_reg.update_job(
|
| 77 |
+
job_id,
|
| 78 |
+
dataset_id=req.dataset_id,
|
| 79 |
+
status="queued",
|
| 80 |
+
progress=0,
|
| 81 |
+
message="Import queued",
|
| 82 |
+
type=str(req.source)
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Launch background task
|
| 86 |
+
asyncio.create_task(_run_pipeline(job_id, req, req.dataset_name or req.dataset_id))
|
| 87 |
+
|
| 88 |
+
return job_id
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ── Pipeline orchestrator ────────────────────────────────────────────────────
|
| 92 |
+
|
| 93 |
+
async def _run_pipeline(job_id: str, req: ImportRequest, dataset_name: str) -> None:
|
| 94 |
+
started = datetime.utcnow().isoformat()
|
| 95 |
+
await ds_reg.update_job(job_id, status="running", started_at=started, message="Starting import")
|
| 96 |
+
await ds_reg.update_dataset_status(req.dataset_id, DatasetStatus.importing, progress=0.01)
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
# Stage 1 – Resolve download URL or local path
|
| 100 |
+
source_path = await _stage_acquire(job_id, req)
|
| 101 |
+
|
| 102 |
+
# Stage 2 – Extract / Prepare Directory
|
| 103 |
+
extract_dir = await _stage_extract(job_id, req.dataset_id, source_path)
|
| 104 |
+
|
| 105 |
+
# Stage 3 – Detect adapter and Task
|
| 106 |
+
await ds_reg.update_job(job_id, progress=0.55, message="Detecting dataset format...")
|
| 107 |
+
adapter = await asyncio.to_thread(get_adapter_for_path, extract_dir)
|
| 108 |
+
|
| 109 |
+
if not adapter:
|
| 110 |
+
log.warning("no_adapter_found_generic_fallback", dataset_id=req.dataset_id)
|
| 111 |
+
image_records = await asyncio.to_thread(_scan_images_generic, req.dataset_id, extract_dir)
|
| 112 |
+
class_names = []
|
| 113 |
+
task = DatasetTask.classification
|
| 114 |
+
fmt_name = "custom"
|
| 115 |
+
else:
|
| 116 |
+
task = adapter.get_task(extract_dir)
|
| 117 |
+
fmt_name = adapter.__class__.__name__.replace("Adapter", "").lower()
|
| 118 |
+
|
| 119 |
+
log.info("adapter_detected", job_id=job_id, format=fmt_name, task=task)
|
| 120 |
+
await ds_reg.update_job(job_id, progress=0.60, message=f"Parsing {fmt_name.upper()} {task.upper()}")
|
| 121 |
+
|
| 122 |
+
# Stage 4 – Parse Metadata & Annotations (Streaming)
|
| 123 |
+
class_names = await asyncio.to_thread(adapter.get_class_names, extract_dir)
|
| 124 |
+
image_records = []
|
| 125 |
+
all_annotations = []
|
| 126 |
+
|
| 127 |
+
# Health metrics tracking
|
| 128 |
+
hashes = {} # hash -> filename
|
| 129 |
+
duplicates = 0
|
| 130 |
+
empty_images = 0
|
| 131 |
+
total_ann_count = 0
|
| 132 |
+
|
| 133 |
+
for img_rec, anns in adapter.iter_items(req.dataset_id, extract_dir):
|
| 134 |
+
# Duplicate detection via MD5 hash
|
| 135 |
+
abs_path = extract_dir / img_rec["rel_path"]
|
| 136 |
+
if abs_path.exists():
|
| 137 |
+
img_hash = _calculate_hash(abs_path)
|
| 138 |
+
if img_hash in hashes:
|
| 139 |
+
duplicates += 1
|
| 140 |
+
img_rec["metadata"] = json.dumps({"is_duplicate": True, "original": hashes[img_hash]})
|
| 141 |
+
else:
|
| 142 |
+
hashes[img_hash] = img_rec["filename"]
|
| 143 |
+
|
| 144 |
+
if not anns:
|
| 145 |
+
empty_images += 1
|
| 146 |
+
|
| 147 |
+
total_ann_count += len(anns)
|
| 148 |
+
image_records.append(img_rec)
|
| 149 |
+
all_annotations.extend(anns)
|
| 150 |
+
|
| 151 |
+
if not image_records:
|
| 152 |
+
raise ValueError(f"No valid data files found in {extract_dir}")
|
| 153 |
+
|
| 154 |
+
# Stage 5 – Indexing
|
| 155 |
+
await ds_reg.update_job(job_id, progress=0.80, message=f"Indexing {len(image_records)} items")
|
| 156 |
+
await ds_reg.index_images(req.dataset_id, image_records)
|
| 157 |
+
|
| 158 |
+
if all_annotations:
|
| 159 |
+
await ds_reg.update_job(job_id, progress=0.85, message=f"Indexing {len(all_annotations)} annotations")
|
| 160 |
+
await ds_reg.bulk_insert_annotations(all_annotations)
|
| 161 |
+
|
| 162 |
+
# Stage 6 – Stats & Health Analysis
|
| 163 |
+
size_bytes = await asyncio.to_thread(_dir_size, extract_dir)
|
| 164 |
+
|
| 165 |
+
# Calculate Health Score (0-100)
|
| 166 |
+
# Factors: duplicates, empty images (for detection), class balance (TODO)
|
| 167 |
+
score = 100.0
|
| 168 |
+
if len(image_records) > 0:
|
| 169 |
+
dup_penalty = (duplicates / len(image_records)) * 50
|
| 170 |
+
empty_penalty = (empty_images / len(image_records)) * 20 if task == DatasetTask.detection else 0
|
| 171 |
+
score = max(0.0, 100.0 - dup_penalty - empty_penalty)
|
| 172 |
+
|
| 173 |
+
stats_payload = {
|
| 174 |
+
"image_count": len(image_records),
|
| 175 |
+
"annotation_count": total_ann_count,
|
| 176 |
+
"class_count": len(class_names),
|
| 177 |
+
"empty_images": empty_images,
|
| 178 |
+
"duplicate_count": duplicates,
|
| 179 |
+
"health_score": round(score, 1),
|
| 180 |
+
"avg_objects": round(total_ann_count / len(image_records), 2) if image_records else 0
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
await ds_reg.update_dataset_stats(
|
| 184 |
+
req.dataset_id,
|
| 185 |
+
len(image_records),
|
| 186 |
+
len(class_names),
|
| 187 |
+
class_names,
|
| 188 |
+
size_bytes,
|
| 189 |
+
stats=stats_payload
|
| 190 |
+
)
|
| 191 |
+
await ds_reg.update_dataset_task(req.dataset_id, task)
|
| 192 |
+
|
| 193 |
+
# Cleanup temp zip if applicable
|
| 194 |
+
if source_path.is_file() and source_path.suffix.lower() == ".zip" and "_tmp" in str(source_path):
|
| 195 |
+
source_path.unlink(missing_ok=True)
|
| 196 |
+
|
| 197 |
+
# Stage 7 – Project Linking (Integration point)
|
| 198 |
+
local_path = str(extract_dir)
|
| 199 |
+
from projects.service import link_dataset_to_active_project
|
| 200 |
+
project_ds_root = await link_dataset_to_active_project(req.dataset_id, local_path)
|
| 201 |
+
final_local_path = str(project_ds_root) if project_ds_root and project_ds_root.exists() else local_path
|
| 202 |
+
|
| 203 |
+
# Completion
|
| 204 |
+
await ds_reg.update_job(
|
| 205 |
+
job_id, status="completed", progress=1.0,
|
| 206 |
+
message="Import complete", ended_at=datetime.utcnow().isoformat(),
|
| 207 |
+
)
|
| 208 |
+
await ds_reg.update_dataset_status(req.dataset_id, DatasetStatus.imported, progress=1.0, local_path=final_local_path)
|
| 209 |
+
await audit("dataset_import_complete", {"job_id": job_id, "path": final_local_path}, job_id=job_id)
|
| 210 |
+
log.info("import_complete", job_id=job_id, dataset_id=req.dataset_id)
|
| 211 |
+
|
| 212 |
+
except asyncio.CancelledError:
|
| 213 |
+
await _fail_job(job_id, req.dataset_id, "Import cancelled by user or system")
|
| 214 |
+
raise
|
| 215 |
+
except Exception as exc:
|
| 216 |
+
log.error("import_failed", job_id=job_id, error=str(exc))
|
| 217 |
+
await _fail_job(job_id, req.dataset_id, str(exc))
|
| 218 |
+
await audit("dataset_import_error", {"job_id": job_id, "error": str(exc)}, job_id=job_id, level="error")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
async def _fail_job(job_id: str, dataset_id: str, error: str) -> None:
|
| 222 |
+
await ds_reg.update_job(
|
| 223 |
+
job_id, status="failed", error=error,
|
| 224 |
+
ended_at=datetime.utcnow().isoformat(),
|
| 225 |
+
message="Import failed",
|
| 226 |
+
)
|
| 227 |
+
await ds_reg.update_dataset_status(dataset_id, DatasetStatus.failed, progress=0.0)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ── Stage 1: Acquire source ──────────────────────────────────────────────────
|
| 231 |
+
|
| 232 |
+
async def _stage_acquire(job_id: str, req: ImportRequest) -> Path:
|
| 233 |
+
"""Resolves the source (Download URL, HF Repo, or Local Path)."""
|
| 234 |
+
await ds_reg.update_job(job_id, progress=0.05, message="Acquiring source...")
|
| 235 |
+
|
| 236 |
+
if req.source in ("roboflow", "roboflow_curl"):
|
| 237 |
+
return await _acquire_roboflow(job_id, req)
|
| 238 |
+
|
| 239 |
+
if req.source == "huggingface":
|
| 240 |
+
return await _acquire_huggingface(job_id, req)
|
| 241 |
+
|
| 242 |
+
if req.source == "local":
|
| 243 |
+
return await _acquire_local(job_id, req)
|
| 244 |
+
|
| 245 |
+
raise ValueError(f"Unsupported source provider: {req.source}")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
async def _acquire_roboflow(job_id: str, req: ImportRequest) -> Path:
|
| 249 |
+
"""Specialized Roboflow downloader using SDK or direct link."""
|
| 250 |
+
# Attempt SDK first (more reliable for Universe)
|
| 251 |
+
try:
|
| 252 |
+
from roboflow import Roboflow
|
| 253 |
+
api_key = req.roboflow_key or (req.headers.get("Authorization") if req.headers else None)
|
| 254 |
+
if api_key and "Bearer " in str(api_key):
|
| 255 |
+
api_key = api_key.split("Bearer ")[-1].strip()
|
| 256 |
+
|
| 257 |
+
if api_key and req.roboflow_workspace and req.roboflow_project:
|
| 258 |
+
rf = Roboflow(api_key=api_key)
|
| 259 |
+
project = rf.workspace(req.roboflow_workspace).project(req.roboflow_project)
|
| 260 |
+
version_obj = project.version(req.roboflow_version or 1)
|
| 261 |
+
|
| 262 |
+
tmp_target = DATASETS_ROOT / "_tmp" / f"rf-{uuid.uuid4().hex[:8]}"
|
| 263 |
+
await ds_reg.update_job(job_id, progress=0.10, message="Downloading via Roboflow SDK...")
|
| 264 |
+
|
| 265 |
+
# Threaded SDK call
|
| 266 |
+
await asyncio.to_thread(
|
| 267 |
+
version_obj.download,
|
| 268 |
+
_format_to_rf_slug(str(req.format)),
|
| 269 |
+
location=str(tmp_target)
|
| 270 |
+
)
|
| 271 |
+
return tmp_target
|
| 272 |
+
except Exception as e:
|
| 273 |
+
log.warning("roboflow_sdk_fallback", error=str(e))
|
| 274 |
+
|
| 275 |
+
# Fallback to direct HTTP download
|
| 276 |
+
url = req.download_url
|
| 277 |
+
if not url and req.source == "roboflow":
|
| 278 |
+
from adapters.roboflow_adapter import RoboflowAdapter
|
| 279 |
+
url = await RoboflowAdapter.get_download_url(
|
| 280 |
+
api_key=req.roboflow_key,
|
| 281 |
+
workspace=req.roboflow_workspace,
|
| 282 |
+
project_id=req.roboflow_project,
|
| 283 |
+
version=req.roboflow_version,
|
| 284 |
+
export_format=_format_to_rf_slug(str(req.format)),
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
if not url:
|
| 288 |
+
raise ValueError("Could not resolve Roboflow download URL")
|
| 289 |
+
|
| 290 |
+
return await _download_zip(job_id, req.dataset_id, url, req.headers)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
async def _acquire_huggingface(job_id: str, req: ImportRequest) -> Path:
|
| 294 |
+
if not req.hf_dataset_id:
|
| 295 |
+
raise ValueError("hf_dataset_id is missing")
|
| 296 |
+
|
| 297 |
+
dest_dir = _dataset_path(req.dataset_id)
|
| 298 |
+
dest_dir.mkdir(parents=True, exist_ok=True)
|
| 299 |
+
|
| 300 |
+
await ds_reg.update_job(job_id, progress=0.10, message=f"Cloning {req.hf_dataset_id} from HF...")
|
| 301 |
+
|
| 302 |
+
await asyncio.to_thread(
|
| 303 |
+
snapshot_download,
|
| 304 |
+
repo_id=req.hf_dataset_id,
|
| 305 |
+
repo_type="dataset",
|
| 306 |
+
local_dir=str(dest_dir),
|
| 307 |
+
token=settings.hf_token,
|
| 308 |
+
local_dir_use_symlinks=False
|
| 309 |
+
)
|
| 310 |
+
return dest_dir
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
async def _acquire_local(job_id: str, req: ImportRequest) -> Path:
|
| 314 |
+
if not req.local_path:
|
| 315 |
+
raise ValueError("local_path is missing for local import")
|
| 316 |
+
|
| 317 |
+
path = Path(os.path.normpath(req.local_path.strip().strip('"').strip("'")))
|
| 318 |
+
if not path.exists():
|
| 319 |
+
raise FileNotFoundError(f"Local path does not exist: {path}")
|
| 320 |
+
|
| 321 |
+
return path
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# ── Stage 2: Extraction ──────────────────────────────────────────────────────
|
| 325 |
+
|
| 326 |
+
async def _stage_extract(job_id: str, dataset_id: str, source_path: Path) -> Path:
|
| 327 |
+
dest = _dataset_path(dataset_id)
|
| 328 |
+
dest.mkdir(parents=True, exist_ok=True)
|
| 329 |
+
|
| 330 |
+
if source_path.is_dir():
|
| 331 |
+
if source_path == dest:
|
| 332 |
+
return dest
|
| 333 |
+
await ds_reg.update_job(job_id, progress=0.45, message="Copying local files...")
|
| 334 |
+
await asyncio.to_thread(_copy_dir_contents, source_path, dest)
|
| 335 |
+
return dest
|
| 336 |
+
|
| 337 |
+
# It's a zip
|
| 338 |
+
await ds_reg.update_job(job_id, progress=0.45, message="Extracting archive...")
|
| 339 |
+
await ds_reg.update_dataset_status(dataset_id, DatasetStatus.extracting, progress=0.45)
|
| 340 |
+
await asyncio.to_thread(_safe_extract, source_path, dest)
|
| 341 |
+
return dest
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
# ── Stage 3: Parsing (Memory-Safe) ───────────────────────────────────────────
|
| 345 |
+
|
| 346 |
+
def _heuristic_task_detection(fmt: str, root: Path) -> DatasetTask:
|
| 347 |
+
"""Improved task detection based on file content."""
|
| 348 |
+
if fmt == "csv":
|
| 349 |
+
return DatasetTask.nlp
|
| 350 |
+
|
| 351 |
+
# Check for segmentation in COCO
|
| 352 |
+
if fmt == "coco":
|
| 353 |
+
# Sample first few lines of JSON if possible or check file size
|
| 354 |
+
return DatasetTask.segmentation # Heuristic: most modern COCO use cases
|
| 355 |
+
|
| 356 |
+
if fmt in ("yolo", "voc"):
|
| 357 |
+
return DatasetTask.detection
|
| 358 |
+
|
| 359 |
+
return DatasetTask.classification
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def _parse_yolo(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
|
| 363 |
+
class_map = YOLOParser.load_class_map(root)
|
| 364 |
+
results = []
|
| 365 |
+
# Generator approach to keep memory low
|
| 366 |
+
for rel_path, image_id, split, anns in YOLOParser.iter_dataset(root, dataset_id, class_map):
|
| 367 |
+
abs_path = root / rel_path
|
| 368 |
+
w, h = _img_dimensions(abs_path)
|
| 369 |
+
img_rec = {
|
| 370 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 371 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 372 |
+
"split": split, "ann_count": len(anns),
|
| 373 |
+
}
|
| 374 |
+
results.append((img_rec, anns))
|
| 375 |
+
return class_map, results
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def _parse_coco(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
|
| 379 |
+
ann_files = COCOParser.find_annotation_files(root)
|
| 380 |
+
all_classes: list[str] = []
|
| 381 |
+
results = []
|
| 382 |
+
for ann_file in ann_files:
|
| 383 |
+
classes, coco_results = COCOParser.parse_file(ann_file, dataset_id)
|
| 384 |
+
all_classes = list(dict.fromkeys(all_classes + classes))
|
| 385 |
+
for rel_path, image_id, split, anns in coco_results:
|
| 386 |
+
abs_path = root / rel_path
|
| 387 |
+
w, h = _img_dimensions(abs_path)
|
| 388 |
+
img_rec = {
|
| 389 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 390 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 391 |
+
"split": split, "ann_count": len(anns),
|
| 392 |
+
}
|
| 393 |
+
results.append((img_rec, anns))
|
| 394 |
+
return all_classes, results
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def _parse_voc(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
|
| 398 |
+
class_set = set()
|
| 399 |
+
results = []
|
| 400 |
+
for rel_path, image_id, split, w, h, anns in VOCParser.iter_dataset(root, dataset_id):
|
| 401 |
+
img_rec = {
|
| 402 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 403 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 404 |
+
"split": split, "ann_count": len(anns),
|
| 405 |
+
}
|
| 406 |
+
results.append((img_rec, anns))
|
| 407 |
+
for ann in anns:
|
| 408 |
+
class_set.add(ann["label"])
|
| 409 |
+
return sorted(list(class_set)), results
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def _parse_csv(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
|
| 413 |
+
all_classes = set()
|
| 414 |
+
results = []
|
| 415 |
+
for csv_path in root.rglob("*.csv"):
|
| 416 |
+
anns = CSVParser.parse_file(csv_path, dataset_id)
|
| 417 |
+
# For CSV, each annotation is a row. We group by text entry id (image_id)
|
| 418 |
+
anns_by_id: Dict[str, List[Dict]] = {}
|
| 419 |
+
for ann in anns:
|
| 420 |
+
all_classes.add(ann["label"])
|
| 421 |
+
anns_by_id.setdefault(ann["image_id"], []).append(ann)
|
| 422 |
+
|
| 423 |
+
for text_id, grouped_anns in anns_by_id.items():
|
| 424 |
+
img_rec = {
|
| 425 |
+
"id": text_id, "filename": csv_path.name,
|
| 426 |
+
"rel_path": str(csv_path.relative_to(root)),
|
| 427 |
+
"width": 0, "height": 0, "split": "train", "ann_count": len(grouped_anns),
|
| 428 |
+
}
|
| 429 |
+
results.append((img_rec, grouped_anns))
|
| 430 |
+
return sorted(list(all_classes)), results
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def _parse_txt(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
|
| 434 |
+
from datasets.annotation_parser import RoboflowTXTParser
|
| 435 |
+
results = []
|
| 436 |
+
class_set = set()
|
| 437 |
+
|
| 438 |
+
for rel_path, image_id, split, anns in RoboflowTXTParser.iter_dataset(root, dataset_id):
|
| 439 |
+
abs_path = root / rel_path
|
| 440 |
+
w, h = _img_dimensions(abs_path)
|
| 441 |
+
img_rec = {
|
| 442 |
+
"id": image_id, "filename": Path(rel_path).name,
|
| 443 |
+
"rel_path": str(rel_path), "width": w, "height": h,
|
| 444 |
+
"split": split, "ann_count": len(anns),
|
| 445 |
+
}
|
| 446 |
+
results.append((img_rec, anns))
|
| 447 |
+
for ann in anns:
|
| 448 |
+
class_set.add(ann["label"])
|
| 449 |
+
|
| 450 |
+
return sorted(list(class_set)), results
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def _parse_generic_folder(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
|
| 454 |
+
"""
|
| 455 |
+
Enhanced generic folder parser. Supports:
|
| 456 |
+
1. root/class_name/img.jpg
|
| 457 |
+
2. root/train/class_name/img.jpg
|
| 458 |
+
3. root/images/img.jpg
|
| 459 |
+
"""
|
| 460 |
+
results = []
|
| 461 |
+
class_set = set()
|
| 462 |
+
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
|
| 463 |
+
|
| 464 |
+
# Structural keywords to ignore as classes
|
| 465 |
+
ignore = {"images", "labels", "train", "val", "test", "validation", "training", "valid", "testing", "unknown", "annotations"}
|
| 466 |
+
|
| 467 |
+
for img_path in sorted(root.rglob("*")):
|
| 468 |
+
if img_path.suffix.lower() not in exts:
|
| 469 |
+
continue
|
| 470 |
+
|
| 471 |
+
rel_path = img_path.relative_to(root)
|
| 472 |
+
parts = rel_path.parts
|
| 473 |
+
|
| 474 |
+
# Heuristic for class detection
|
| 475 |
+
label = "unknown"
|
| 476 |
+
split = "train"
|
| 477 |
+
|
| 478 |
+
# Detect split if first folder is a split keyword
|
| 479 |
+
if parts[0].lower() in ignore and len(parts) > 1:
|
| 480 |
+
if parts[0].lower() in ("train", "training"): split = "train"
|
| 481 |
+
elif parts[0].lower() in ("val", "valid", "validation"): split = "val"
|
| 482 |
+
elif parts[0].lower() in ("test", "testing"): split = "test"
|
| 483 |
+
|
| 484 |
+
# Check if next part is class name
|
| 485 |
+
if len(parts) > 2 and parts[1].lower() not in ignore:
|
| 486 |
+
label = parts[1]
|
| 487 |
+
elif len(parts) > 1 and parts[1].lower() not in ignore:
|
| 488 |
+
label = parts[1]
|
| 489 |
+
elif len(parts) > 1 and parts[0].lower() not in ignore:
|
| 490 |
+
label = parts[0]
|
| 491 |
+
|
| 492 |
+
anns = []
|
| 493 |
+
if label != "unknown":
|
| 494 |
+
class_set.add(label)
|
| 495 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 496 |
+
# Create a virtual annotation for classification
|
| 497 |
+
from datasets.annotation_parser import _make_ann
|
| 498 |
+
anns.append(_make_ann(image_id, dataset_id, label, ann_type="classification"))
|
| 499 |
+
else:
|
| 500 |
+
image_id = f"img-{uuid.uuid4().hex[:12]}"
|
| 501 |
+
|
| 502 |
+
w, h = _img_dimensions(img_path)
|
| 503 |
+
img_rec = {
|
| 504 |
+
"id": image_id,
|
| 505 |
+
"filename": img_path.name,
|
| 506 |
+
"rel_path": str(rel_path),
|
| 507 |
+
"width": w, "height": h,
|
| 508 |
+
"split": split,
|
| 509 |
+
"ann_count": len(anns),
|
| 510 |
+
}
|
| 511 |
+
results.append((img_rec, anns))
|
| 512 |
+
|
| 513 |
+
return sorted(list(class_set)), results
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# ── Utilities ────────────────────────────────────────────────────────────────
|
| 517 |
+
|
| 518 |
+
async def _download_zip(job_id: str, dataset_id: str, url: str, custom_headers: dict = None) -> Path:
|
| 519 |
+
tmp_dir = DATASETS_ROOT / "_tmp"
|
| 520 |
+
tmp_dir.mkdir(parents=True, exist_ok=True)
|
| 521 |
+
zip_path = tmp_dir / f"{dataset_id}-{uuid.uuid4().hex[:8]}.zip"
|
| 522 |
+
|
| 523 |
+
headers = {
|
| 524 |
+
"User-Agent": "Mozilla/5.0 (MLForge Workbench)",
|
| 525 |
+
"Accept": "application/zip, application/octet-stream, */*",
|
| 526 |
+
}
|
| 527 |
+
if custom_headers: headers.update(custom_headers)
|
| 528 |
+
|
| 529 |
+
async with httpx.AsyncClient(follow_redirects=True, timeout=600.0, headers=headers) as client:
|
| 530 |
+
async with client.stream("GET", url) as resp:
|
| 531 |
+
resp.raise_for_status()
|
| 532 |
+
total = int(resp.headers.get("content-length", 0)) or None
|
| 533 |
+
downloaded = 0
|
| 534 |
+
async with aiofiles.open(zip_path, "wb") as f:
|
| 535 |
+
async for chunk in resp.aiter_bytes(chunk_size=settings.download_chunk_size):
|
| 536 |
+
await f.write(chunk)
|
| 537 |
+
downloaded += len(chunk)
|
| 538 |
+
if total:
|
| 539 |
+
pct = 0.10 + (downloaded / total) * 0.35 # 10% -> 45%
|
| 540 |
+
await ds_reg.update_job(job_id, progress=round(pct, 3), message=f"Downloading: {_fmt_bytes(downloaded)} / {_fmt_bytes(total)}")
|
| 541 |
+
|
| 542 |
+
return zip_path
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def _safe_extract(zip_path: Path, dest: Path) -> None:
|
| 546 |
+
with zipfile.ZipFile(str(zip_path), "r") as zf:
|
| 547 |
+
for member in zf.namelist():
|
| 548 |
+
if os.path.isabs(member) or ".." in Path(member).parts: continue
|
| 549 |
+
zf.extract(member, str(dest))
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def _copy_dir_contents(src: Path, dest: Path) -> None:
|
| 553 |
+
for item in src.iterdir():
|
| 554 |
+
s, d = src / item.name, dest / item.name
|
| 555 |
+
if s.is_dir(): shutil.copytree(s, d, dirs_exist_ok=True)
|
| 556 |
+
else: shutil.copy2(s, d)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def _scan_images_generic(dataset_id: str, root: Path) -> list[dict]:
|
| 560 |
+
records = []
|
| 561 |
+
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
|
| 562 |
+
for img_path in sorted(root.rglob("*")):
|
| 563 |
+
if img_path.suffix.lower() in exts:
|
| 564 |
+
w, h = _img_dimensions(img_path)
|
| 565 |
+
records.append({
|
| 566 |
+
"id": f"img-{uuid.uuid4().hex[:12]}",
|
| 567 |
+
"filename": img_path.name,
|
| 568 |
+
"rel_path": str(img_path.relative_to(root)),
|
| 569 |
+
"width": w, "height": h, "split": "train", "ann_count": 0,
|
| 570 |
+
})
|
| 571 |
+
return records
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def _dir_size(path: Path) -> int:
|
| 575 |
+
return sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def _fmt_bytes(n: int) -> str:
|
| 579 |
+
for unit in ("B", "KB", "MB", "GB", "TB"):
|
| 580 |
+
if n < 1024: return f"{n:.1f} {unit}"
|
| 581 |
+
n /= 1024
|
| 582 |
+
return f"{n:.1f} PB"
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def _format_to_rf_slug(fmt: str) -> str:
|
| 586 |
+
return {"yolo": "yolov8", "coco": "coco", "voc": "voc"}.get(fmt, "yolov8")
|
| 587 |
+
|
| 588 |
+
def _format_to_rf_slug(fmt: str) -> str:
|
| 589 |
+
return {"yolo": "yolov8", "coco": "coco", "voc": "voc"}.get(fmt, "yolov8")
|
datasets/registry.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
datasets/registry.py — Dataset Registry: persistent CRUD against datasets table.
|
| 3 |
+
All DB interactions for datasets and dataset_jobs live here.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import uuid
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
from database.connection import get_db
|
| 13 |
+
from models.dataset import Dataset, DatasetJob, DatasetStatus, row_to_dataset, row_to_job
|
| 14 |
+
from observability.logger import get_logger
|
| 15 |
+
|
| 16 |
+
log = get_logger("dataset_registry")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ── Dataset CRUD ──────────────────────────────────────────────────────────────
|
| 20 |
+
|
| 21 |
+
async def get_all_datasets(
|
| 22 |
+
task: str | None = None,
|
| 23 |
+
format: str | None = None,
|
| 24 |
+
source: str | None = None,
|
| 25 |
+
status: str | None = None,
|
| 26 |
+
search: str | None = None,
|
| 27 |
+
starred: bool | None = None,
|
| 28 |
+
limit: int = 500,
|
| 29 |
+
offset: int = 0,
|
| 30 |
+
) -> list[Dataset]:
|
| 31 |
+
db = await get_db()
|
| 32 |
+
clauses = []
|
| 33 |
+
params: list[Any] = []
|
| 34 |
+
|
| 35 |
+
if task:
|
| 36 |
+
clauses.append("task = ?")
|
| 37 |
+
params.append(task)
|
| 38 |
+
if format:
|
| 39 |
+
clauses.append("format = ?")
|
| 40 |
+
params.append(format)
|
| 41 |
+
if source:
|
| 42 |
+
clauses.append("source = ?")
|
| 43 |
+
params.append(source)
|
| 44 |
+
if status:
|
| 45 |
+
clauses.append("status = ?")
|
| 46 |
+
params.append(status)
|
| 47 |
+
if starred is not None:
|
| 48 |
+
clauses.append("starred = ?")
|
| 49 |
+
params.append(1 if starred else 0)
|
| 50 |
+
if search:
|
| 51 |
+
clauses.append("(name LIKE ? OR description LIKE ? OR tags LIKE ?)")
|
| 52 |
+
q = f"%{search}%"
|
| 53 |
+
params.extend([q, q, q])
|
| 54 |
+
|
| 55 |
+
where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
|
| 56 |
+
sql = f"SELECT * FROM datasets {where} ORDER BY updated_at DESC LIMIT ? OFFSET ?"
|
| 57 |
+
params.extend([limit, offset])
|
| 58 |
+
|
| 59 |
+
async with db.execute(sql, params) as cur:
|
| 60 |
+
rows = await cur.fetchall()
|
| 61 |
+
return [row_to_dataset(r) for r in rows]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
async def get_dataset_stats(dataset_id: str) -> dict:
|
| 65 |
+
"""Get pre-computed class distributions and statistics from the indexed annotations."""
|
| 66 |
+
db = await get_db()
|
| 67 |
+
|
| 68 |
+
# Class distribution (from dataset_annotations table)
|
| 69 |
+
async with db.execute(
|
| 70 |
+
"SELECT label, COUNT(*) as count FROM dataset_annotations WHERE dataset_id=? GROUP BY label ORDER BY count DESC",
|
| 71 |
+
(dataset_id,)
|
| 72 |
+
) as cur:
|
| 73 |
+
dist = await cur.fetchall()
|
| 74 |
+
|
| 75 |
+
# Split distribution (from dataset_images table)
|
| 76 |
+
async with db.execute(
|
| 77 |
+
"SELECT split, COUNT(*) as count FROM dataset_images WHERE dataset_id=? GROUP BY split",
|
| 78 |
+
(dataset_id,)
|
| 79 |
+
) as cur:
|
| 80 |
+
splits = await cur.fetchall()
|
| 81 |
+
|
| 82 |
+
return {
|
| 83 |
+
"class_distribution": {row["label"]: row["count"] for row in dist},
|
| 84 |
+
"split_distribution": {row["split"]: row["count"] for row in splits}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
async def get_dataset(dataset_id: str) -> Dataset | None:
|
| 89 |
+
db = await get_db()
|
| 90 |
+
async with db.execute("SELECT * FROM datasets WHERE id = ?", (dataset_id,)) as cur:
|
| 91 |
+
row = await cur.fetchone()
|
| 92 |
+
return row_to_dataset(row) if row else None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
async def count_datasets() -> int:
|
| 96 |
+
db = await get_db()
|
| 97 |
+
async with db.execute("SELECT COUNT(*) FROM datasets") as cur:
|
| 98 |
+
row = await cur.fetchone()
|
| 99 |
+
return row[0] if row else 0
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
async def upsert_dataset(ds: Dataset) -> None:
|
| 103 |
+
"""Insert or replace a dataset record."""
|
| 104 |
+
db = await get_db()
|
| 105 |
+
|
| 106 |
+
task = getattr(ds.task, "value", ds.task)
|
| 107 |
+
fmt = getattr(ds.format, "value", ds.format)
|
| 108 |
+
src = getattr(ds.source, "value", ds.source)
|
| 109 |
+
status = getattr(ds.status, "value", ds.status)
|
| 110 |
+
await db.execute(
|
| 111 |
+
"""INSERT OR REPLACE INTO datasets
|
| 112 |
+
(id, name, description, task, format, source, status,
|
| 113 |
+
images, classes, class_names, size_bytes, size_label,
|
| 114 |
+
local_path, import_progress, tags, versions, active_version,
|
| 115 |
+
starred, roboflow_id, created_at, updated_at)
|
| 116 |
+
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,datetime('now'))""",
|
| 117 |
+
(
|
| 118 |
+
ds.id, ds.name, ds.description, task, fmt,
|
| 119 |
+
src, status,
|
| 120 |
+
ds.images, ds.classes,
|
| 121 |
+
json.dumps(ds.class_names), ds.size_bytes, ds.size_label,
|
| 122 |
+
ds.local_path, ds.import_progress,
|
| 123 |
+
json.dumps(ds.tags),
|
| 124 |
+
json.dumps([v.model_dump() if hasattr(v, "model_dump") else v for v in ds.versions]),
|
| 125 |
+
ds.active_version,
|
| 126 |
+
1 if ds.starred else 0,
|
| 127 |
+
ds.roboflow_id,
|
| 128 |
+
ds.created_at or datetime.utcnow().isoformat(),
|
| 129 |
+
),
|
| 130 |
+
)
|
| 131 |
+
await db.commit()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
async def update_dataset_status(
|
| 135 |
+
dataset_id: str,
|
| 136 |
+
status: DatasetStatus,
|
| 137 |
+
progress: float | None = None,
|
| 138 |
+
local_path: str | None = None,
|
| 139 |
+
) -> None:
|
| 140 |
+
db = await get_db()
|
| 141 |
+
if progress is not None and local_path is not None:
|
| 142 |
+
await db.execute(
|
| 143 |
+
"UPDATE datasets SET status=?, import_progress=?, local_path=? WHERE id=?",
|
| 144 |
+
(status.value, progress, local_path, dataset_id),
|
| 145 |
+
)
|
| 146 |
+
elif progress is not None:
|
| 147 |
+
await db.execute(
|
| 148 |
+
"UPDATE datasets SET status=?, import_progress=? WHERE id=?",
|
| 149 |
+
(status.value, progress, dataset_id),
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
await db.execute(
|
| 153 |
+
"UPDATE datasets SET status=? WHERE id=?",
|
| 154 |
+
(status.value, dataset_id),
|
| 155 |
+
)
|
| 156 |
+
await db.commit()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
async def update_dataset_stats(
|
| 160 |
+
dataset_id: str,
|
| 161 |
+
images: int,
|
| 162 |
+
classes: int,
|
| 163 |
+
class_names: list[str],
|
| 164 |
+
size_bytes: int,
|
| 165 |
+
stats: dict | None = None
|
| 166 |
+
) -> None:
|
| 167 |
+
db = await get_db()
|
| 168 |
+
|
| 169 |
+
# Calculate health score if stats provided
|
| 170 |
+
health_score = 0.0
|
| 171 |
+
if stats:
|
| 172 |
+
health_score = stats.get("health_score", 0.0)
|
| 173 |
+
|
| 174 |
+
await db.execute(
|
| 175 |
+
"""UPDATE datasets
|
| 176 |
+
SET images=?, classes=?, class_names=?, size_bytes=?,
|
| 177 |
+
size_label=?, stats=?, health_score=?
|
| 178 |
+
WHERE id=?""",
|
| 179 |
+
(
|
| 180 |
+
images, classes, json.dumps(class_names),
|
| 181 |
+
size_bytes, _fmt_bytes(size_bytes),
|
| 182 |
+
json.dumps(stats) if stats else None,
|
| 183 |
+
health_score,
|
| 184 |
+
dataset_id,
|
| 185 |
+
),
|
| 186 |
+
)
|
| 187 |
+
await db.commit()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
async def delete_dataset(dataset_id: str) -> bool:
|
| 191 |
+
db = await get_db()
|
| 192 |
+
async with db.execute("SELECT 1 FROM datasets WHERE id=?", (dataset_id,)) as cur:
|
| 193 |
+
exists = await cur.fetchone()
|
| 194 |
+
if not exists:
|
| 195 |
+
return False
|
| 196 |
+
await db.execute("DELETE FROM datasets WHERE id=?", (dataset_id,))
|
| 197 |
+
await db.commit()
|
| 198 |
+
return True
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
async def toggle_starred(dataset_id: str) -> bool:
|
| 202 |
+
"""Toggle starred flag, return new value."""
|
| 203 |
+
db = await get_db()
|
| 204 |
+
async with db.execute("SELECT starred FROM datasets WHERE id=?", (dataset_id,)) as cur:
|
| 205 |
+
row = await cur.fetchone()
|
| 206 |
+
if not row:
|
| 207 |
+
return False
|
| 208 |
+
new_val = 0 if row["starred"] else 1
|
| 209 |
+
await db.execute("UPDATE datasets SET starred=? WHERE id=?", (new_val, dataset_id))
|
| 210 |
+
await db.commit()
|
| 211 |
+
return bool(new_val)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# ── Bulk dataset upsert from Roboflow ────────────────────────────────────────
|
| 215 |
+
|
| 216 |
+
async def bulk_upsert_datasets(datasets: list[Dataset]) -> int:
|
| 217 |
+
"""Insert/update many datasets in a single transaction."""
|
| 218 |
+
if not datasets:
|
| 219 |
+
return 0
|
| 220 |
+
db = await get_db()
|
| 221 |
+
now = datetime.utcnow().isoformat()
|
| 222 |
+
rows = [
|
| 223 |
+
(
|
| 224 |
+
ds.id, ds.name, ds.description, ds.task.value, ds.format.value,
|
| 225 |
+
ds.source.value, ds.status.value,
|
| 226 |
+
ds.images, ds.classes,
|
| 227 |
+
json.dumps(ds.class_names), ds.size_bytes, ds.size_label,
|
| 228 |
+
ds.local_path, ds.import_progress,
|
| 229 |
+
json.dumps(ds.tags), json.dumps([]),
|
| 230 |
+
ds.active_version, 0, ds.roboflow_id,
|
| 231 |
+
ds.created_at or now,
|
| 232 |
+
)
|
| 233 |
+
for ds in datasets
|
| 234 |
+
]
|
| 235 |
+
await db.executemany(
|
| 236 |
+
"""INSERT OR IGNORE INTO datasets
|
| 237 |
+
(id, name, description, task, format, source, status,
|
| 238 |
+
images, classes, class_names, size_bytes, size_label,
|
| 239 |
+
local_path, import_progress, tags, versions, active_version,
|
| 240 |
+
starred, roboflow_id, created_at)
|
| 241 |
+
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
|
| 242 |
+
rows,
|
| 243 |
+
)
|
| 244 |
+
await db.commit()
|
| 245 |
+
return len(datasets)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ── Dataset Jobs ──────────────────────────────────────────────────────────────
|
| 249 |
+
|
| 250 |
+
async def create_job(
|
| 251 |
+
dataset_id: str,
|
| 252 |
+
dataset_name: str,
|
| 253 |
+
job_type: str,
|
| 254 |
+
) -> DatasetJob:
|
| 255 |
+
db = await get_db()
|
| 256 |
+
job_id = f"djob-{uuid.uuid4().hex[:12]}"
|
| 257 |
+
now = datetime.utcnow().isoformat()
|
| 258 |
+
await db.execute(
|
| 259 |
+
"""INSERT INTO dataset_jobs
|
| 260 |
+
(id, type, status, dataset_id, dataset_name, progress, message, created_at)
|
| 261 |
+
VALUES (?, ?, 'queued', ?, ?, 0.0, '', ?)""",
|
| 262 |
+
(job_id, job_type, dataset_id, dataset_name, now),
|
| 263 |
+
)
|
| 264 |
+
await db.commit()
|
| 265 |
+
return DatasetJob(
|
| 266 |
+
id=job_id, type=job_type, status="queued",
|
| 267 |
+
dataset_id=dataset_id, dataset_name=dataset_name,
|
| 268 |
+
created_at=now,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
async def update_job(
|
| 273 |
+
job_id: str,
|
| 274 |
+
status: str | None = None,
|
| 275 |
+
progress: float | None = None,
|
| 276 |
+
message: str | None = None,
|
| 277 |
+
error: str | None = None,
|
| 278 |
+
started_at: str | None = None,
|
| 279 |
+
ended_at: str | None = None,
|
| 280 |
+
) -> None:
|
| 281 |
+
db = await get_db()
|
| 282 |
+
parts = []
|
| 283 |
+
params: list[Any] = []
|
| 284 |
+
if status is not None:
|
| 285 |
+
parts.append("status=?"); params.append(status)
|
| 286 |
+
if progress is not None:
|
| 287 |
+
parts.append("progress=?"); params.append(progress)
|
| 288 |
+
if message is not None:
|
| 289 |
+
parts.append("message=?"); params.append(message)
|
| 290 |
+
if error is not None:
|
| 291 |
+
parts.append("error=?"); params.append(error)
|
| 292 |
+
if started_at is not None:
|
| 293 |
+
parts.append("started_at=?"); params.append(started_at)
|
| 294 |
+
if ended_at is not None:
|
| 295 |
+
parts.append("ended_at=?"); params.append(ended_at)
|
| 296 |
+
if not parts:
|
| 297 |
+
return
|
| 298 |
+
params.append(job_id)
|
| 299 |
+
await db.execute(f"UPDATE dataset_jobs SET {', '.join(parts)} WHERE id=?", params)
|
| 300 |
+
await db.commit()
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
async def get_job(job_id: str) -> DatasetJob | None:
|
| 304 |
+
db = await get_db()
|
| 305 |
+
async with db.execute("SELECT * FROM dataset_jobs WHERE id=?", (job_id,)) as cur:
|
| 306 |
+
row = await cur.fetchone()
|
| 307 |
+
return row_to_job(row) if row else None
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
async def get_all_jobs(limit: int = 100) -> list[DatasetJob]:
|
| 311 |
+
db = await get_db()
|
| 312 |
+
async with db.execute(
|
| 313 |
+
"SELECT * FROM dataset_jobs ORDER BY created_at DESC LIMIT ?", (limit,)
|
| 314 |
+
) as cur:
|
| 315 |
+
rows = await cur.fetchall()
|
| 316 |
+
return [row_to_job(r) for r in rows]
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ── Image Index ───────────────────────────────────────────────────────────────
|
| 320 |
+
|
| 321 |
+
async def index_images(
|
| 322 |
+
dataset_id: str,
|
| 323 |
+
records: list[dict], # [{id, filename, rel_path, width, height, split, ann_count}]
|
| 324 |
+
) -> int:
|
| 325 |
+
db = await get_db()
|
| 326 |
+
await db.executemany(
|
| 327 |
+
"""INSERT OR IGNORE INTO dataset_images
|
| 328 |
+
(id, dataset_id, filename, rel_path, width, height, split, ann_count)
|
| 329 |
+
VALUES (:id, :dataset_id, :filename, :rel_path, :width, :height, :split, :ann_count)""",
|
| 330 |
+
[{"dataset_id": dataset_id, **r} for r in records],
|
| 331 |
+
)
|
| 332 |
+
await db.commit()
|
| 333 |
+
return len(records)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
async def get_image_page(
|
| 337 |
+
dataset_id: str,
|
| 338 |
+
page: int = 0,
|
| 339 |
+
page_size: int = 20,
|
| 340 |
+
split: str | None = None,
|
| 341 |
+
class_label: str | None = None,
|
| 342 |
+
) -> tuple[int, list[dict]]:
|
| 343 |
+
db = await get_db()
|
| 344 |
+
|
| 345 |
+
clauses = ["dataset_id=?"]
|
| 346 |
+
params: list[Any] = [dataset_id]
|
| 347 |
+
|
| 348 |
+
if split:
|
| 349 |
+
clauses.append("split=?")
|
| 350 |
+
params.append(split)
|
| 351 |
+
|
| 352 |
+
if class_label:
|
| 353 |
+
# Join with annotations table to filter by class
|
| 354 |
+
where = f"WHERE {' AND '.join(clauses)} AND id IN (SELECT image_id FROM dataset_annotations WHERE label=?)"
|
| 355 |
+
count_params = params + [class_label]
|
| 356 |
+
else:
|
| 357 |
+
where = f"WHERE {' AND '.join(clauses)}"
|
| 358 |
+
count_params = params
|
| 359 |
+
|
| 360 |
+
async with db.execute(f"SELECT COUNT(*) FROM dataset_images {where}", count_params) as cur:
|
| 361 |
+
total = (await cur.fetchone())[0]
|
| 362 |
+
|
| 363 |
+
params_final = count_params + [page_size, page * page_size]
|
| 364 |
+
async with db.execute(
|
| 365 |
+
f"SELECT * FROM dataset_images {where} ORDER BY filename LIMIT ? OFFSET ?", params_final
|
| 366 |
+
) as cur:
|
| 367 |
+
rows = await cur.fetchall()
|
| 368 |
+
return total, [dict(r) for r in rows]
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
async def get_annotations_for_image(image_id: str) -> list[dict]:
|
| 372 |
+
db = await get_db()
|
| 373 |
+
async with db.execute(
|
| 374 |
+
"SELECT * FROM dataset_annotations WHERE image_id=?", (image_id,)
|
| 375 |
+
) as cur:
|
| 376 |
+
rows = await cur.fetchall()
|
| 377 |
+
return [dict(r) for r in rows]
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
async def bulk_insert_annotations(records: list[dict]) -> int:
|
| 381 |
+
if not records:
|
| 382 |
+
return 0
|
| 383 |
+
db = await get_db()
|
| 384 |
+
await db.executemany(
|
| 385 |
+
"""INSERT OR IGNORE INTO dataset_annotations
|
| 386 |
+
(id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
|
| 387 |
+
normalised, area, confidence, ann_type)
|
| 388 |
+
VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
|
| 389 |
+
:normalised,:area,:confidence,:ann_type)""",
|
| 390 |
+
records,
|
| 391 |
+
)
|
| 392 |
+
await db.commit()
|
| 393 |
+
return len(records)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# ── Universal Dataset Items ──────────────────────────────────────────────
|
| 397 |
+
|
| 398 |
+
async def get_universal_items(
|
| 399 |
+
self,
|
| 400 |
+
dataset_id: str,
|
| 401 |
+
page: int = 0,
|
| 402 |
+
page_size: int = 20,
|
| 403 |
+
split: str | None = None,
|
| 404 |
+
class_label: str | None = None,
|
| 405 |
+
) -> tuple[int, list[dict]]:
|
| 406 |
+
"""Fetch polymorphic dataset items (images, text rows, etc.) and their annotations."""
|
| 407 |
+
db = await get_db()
|
| 408 |
+
|
| 409 |
+
# 1. Get total and base item records
|
| 410 |
+
total, items = await self.get_image_page(dataset_id, page, page_size, split, class_label)
|
| 411 |
+
|
| 412 |
+
# 2. Convert to universal format
|
| 413 |
+
# This is a bridge until we fully move to the universal schema
|
| 414 |
+
return total, items
|
| 415 |
+
|
| 416 |
+
async def bulk_insert_universal_annotations(self, records: list[dict]) -> int:
|
| 417 |
+
"""Insert universal annotations into the extended schema."""
|
| 418 |
+
if not records:
|
| 419 |
+
return 0
|
| 420 |
+
db = await get_db()
|
| 421 |
+
await db.executemany(
|
| 422 |
+
"""INSERT OR IGNORE INTO dataset_annotations
|
| 423 |
+
(id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
|
| 424 |
+
normalised, area, confidence, ann_type, segmentation, keypoints, metadata)
|
| 425 |
+
VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
|
| 426 |
+
:normalised,:area,:confidence,:ann_type,:segmentation,:keypoints,:metadata)""",
|
| 427 |
+
records,
|
| 428 |
+
)
|
| 429 |
+
await db.commit()
|
| 430 |
+
return len(records)
|
| 431 |
+
|
| 432 |
+
async def update_dataset_task(dataset_id: str, task: str) -> None:
|
| 433 |
+
db = await get_db()
|
| 434 |
+
await db.execute("UPDATE datasets SET task=? WHERE id=?", (task, dataset_id))
|
| 435 |
+
await db.commit()
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
async def cleanup_stale_jobs() -> None:
|
| 439 |
+
"""Mark running/queued jobs as failed on startup."""
|
| 440 |
+
db = await get_db()
|
| 441 |
+
await db.execute(
|
| 442 |
+
"UPDATE dataset_jobs SET status='failed', error='System restart' WHERE status IN ('running', 'queued')"
|
| 443 |
+
)
|
| 444 |
+
await db.commit()
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def _fmt_bytes(n: int) -> str:
|
| 448 |
+
for unit in ("B", "KB", "MB", "GB", "TB"):
|
| 449 |
+
if n < 1024:
|
| 450 |
+
return f"{n:.1f} {unit}"
|
| 451 |
+
n /= 1024
|
| 452 |
+
return f"{n:.1f} PB"
|
datasets/viewer_service.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
datasets/viewer_service.py — Dataset Viewer Service.
|
| 3 |
+
|
| 4 |
+
Provides paginated image + annotation serving for the Dataset Viewer UI.
|
| 5 |
+
All paths are resolved relative to the dataset's local_path for security.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from datasets import registry as ds_reg
|
| 12 |
+
from models.dataset import (
|
| 13 |
+
Annotation, AnnotationType, BoundingBox, Dataset,
|
| 14 |
+
ImageRecord, ViewerPage, DatasetFormat
|
| 15 |
+
)
|
| 16 |
+
from datasets.annotation_parser import YOLOParser, COCOParser, VOCParser, CSVParser
|
| 17 |
+
from observability.logger import get_logger
|
| 18 |
+
|
| 19 |
+
log = get_logger("viewer_service")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
from .format_adapters import NLPAdapter, TabularAdapter
|
| 23 |
+
from models.dataset import UniversalViewerPage, UniversalDatasetItem, UniversalAnnotation, DatasetContentType, DatasetTask
|
| 24 |
+
|
| 25 |
+
async def get_universal_viewer_page(
|
| 26 |
+
dataset_id: str,
|
| 27 |
+
page: int = 0,
|
| 28 |
+
page_size: int = 20,
|
| 29 |
+
split: str | None = None,
|
| 30 |
+
class_label: str | None = None,
|
| 31 |
+
) -> UniversalViewerPage:
|
| 32 |
+
"""Polymorphic viewer endpoint that adapts based on dataset task."""
|
| 33 |
+
ds = await ds_reg.get_dataset(dataset_id)
|
| 34 |
+
if not ds:
|
| 35 |
+
raise ValueError("Dataset not found")
|
| 36 |
+
|
| 37 |
+
ds_root = Path(ds.local_path) if ds.local_path else None
|
| 38 |
+
|
| 39 |
+
# 1. Vision Tasks (Detection, Seg, Pose) -> Use existing image-centric logic
|
| 40 |
+
if ds.task in (DatasetTask.detection, DatasetTask.segmentation, DatasetTask.keypoints):
|
| 41 |
+
# We wrap the existing get_viewer_page and transform to UniversalDatasetItem
|
| 42 |
+
old_page = await get_viewer_page(dataset_id, page, page_size, split, class_label)
|
| 43 |
+
|
| 44 |
+
items = []
|
| 45 |
+
for img in old_page.images:
|
| 46 |
+
items.append(UniversalDatasetItem(
|
| 47 |
+
id=img.image_id,
|
| 48 |
+
content_type=DatasetContentType.image,
|
| 49 |
+
filename=img.filename,
|
| 50 |
+
metadata={"width": img.width, "height": img.height, "split": img.split},
|
| 51 |
+
annotations=[
|
| 52 |
+
UniversalAnnotation(
|
| 53 |
+
label=ann.label,
|
| 54 |
+
type=ann.type.value if hasattr(ann.type, 'value') else str(ann.type),
|
| 55 |
+
bbox=[ann.bbox.x, ann.bbox.y, ann.bbox.width, ann.bbox.height] if ann.bbox else None,
|
| 56 |
+
segmentation=ann.segmentation,
|
| 57 |
+
keypoints=ann.keypoints,
|
| 58 |
+
confidence=ann.confidence,
|
| 59 |
+
metadata=ann.metadata
|
| 60 |
+
) for ann in img.annotations
|
| 61 |
+
]
|
| 62 |
+
))
|
| 63 |
+
|
| 64 |
+
return UniversalViewerPage(
|
| 65 |
+
dataset_id=dataset_id,
|
| 66 |
+
page=page,
|
| 67 |
+
page_size=page_size,
|
| 68 |
+
total=old_page.total,
|
| 69 |
+
total_pages=old_page.total_pages,
|
| 70 |
+
items=items
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# 2. NLP Tasks (CSV, JSONL)
|
| 74 |
+
elif ds.task == DatasetTask.nlp and ds_root:
|
| 75 |
+
adapter = NLPAdapter()
|
| 76 |
+
total, items = await adapter.get_items(ds_root, page, page_size)
|
| 77 |
+
total_pages = max(1, (total + page_size - 1) // page_size)
|
| 78 |
+
return UniversalViewerPage(
|
| 79 |
+
dataset_id=dataset_id,
|
| 80 |
+
page=page,
|
| 81 |
+
page_size=page_size,
|
| 82 |
+
total=total,
|
| 83 |
+
total_pages=total_pages,
|
| 84 |
+
items=items
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# 3. Tabular Tasks (CSV, Parquet)
|
| 88 |
+
elif ds.task == DatasetTask.tabular and ds_root:
|
| 89 |
+
adapter = TabularAdapter()
|
| 90 |
+
total, items = await adapter.get_items(ds_root, page, page_size)
|
| 91 |
+
total_pages = max(1, (total + page_size - 1) // page_size)
|
| 92 |
+
return UniversalViewerPage(
|
| 93 |
+
dataset_id=dataset_id,
|
| 94 |
+
page=page,
|
| 95 |
+
page_size=page_size,
|
| 96 |
+
total=total,
|
| 97 |
+
total_pages=total_pages,
|
| 98 |
+
items=items
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Fallback / Empty
|
| 102 |
+
return UniversalViewerPage(
|
| 103 |
+
dataset_id=dataset_id,
|
| 104 |
+
page=page,
|
| 105 |
+
page_size=page_size,
|
| 106 |
+
total=0,
|
| 107 |
+
total_pages=0,
|
| 108 |
+
items=[]
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
async def get_viewer_page(
|
| 112 |
+
dataset_id: str,
|
| 113 |
+
page: int = 0,
|
| 114 |
+
page_size: int = 20,
|
| 115 |
+
split: str | None = None,
|
| 116 |
+
class_label: str | None = None,
|
| 117 |
+
) -> ViewerPage:
|
| 118 |
+
"""
|
| 119 |
+
Return a paginated viewer page for the dataset.
|
| 120 |
+
Images come from the index; annotations are loaded per-image.
|
| 121 |
+
"""
|
| 122 |
+
if page_size > 100:
|
| 123 |
+
page_size = 100 # cap to prevent huge payloads
|
| 124 |
+
|
| 125 |
+
total, image_rows = await ds_reg.get_image_page(dataset_id, page, page_size, split, class_label)
|
| 126 |
+
ds = await ds_reg.get_dataset(dataset_id)
|
| 127 |
+
|
| 128 |
+
# Check if we have an active project and if the dataset exists there
|
| 129 |
+
from projects.service import get_active_project_path
|
| 130 |
+
project_path = await get_active_project_path()
|
| 131 |
+
|
| 132 |
+
# Dynamically load annotations from database first, fallback to filesystem if needed
|
| 133 |
+
image_ids = [row["id"] for row in image_rows]
|
| 134 |
+
dynamic_anns: dict[str, list[Annotation]] = {img_id: [] for img_id in image_ids}
|
| 135 |
+
|
| 136 |
+
# 1. Try loading from DB index (Authoritative for analytics)
|
| 137 |
+
try:
|
| 138 |
+
from database.connection import get_db
|
| 139 |
+
db = await get_db()
|
| 140 |
+
# Fetch all annotations for these images in one go
|
| 141 |
+
placeholders = ",".join(["?"] * len(image_ids))
|
| 142 |
+
async with db.execute(
|
| 143 |
+
f"SELECT * FROM dataset_annotations WHERE image_id IN ({placeholders})",
|
| 144 |
+
image_ids
|
| 145 |
+
) as cur:
|
| 146 |
+
rows = await cur.fetchall()
|
| 147 |
+
for r in rows:
|
| 148 |
+
dynamic_anns[r["image_id"]].append(_row_to_annotation(dict(r)))
|
| 149 |
+
except Exception as e:
|
| 150 |
+
log.warning("db_annotation_read_failed", error=str(e), dataset_id=dataset_id)
|
| 151 |
+
|
| 152 |
+
# 2. Fallback to filesystem if no annotations found in DB and we have a path
|
| 153 |
+
# This maintains compatibility with old datasets or specific live-read needs
|
| 154 |
+
if all(not anns for anns in dynamic_anns.values()) and ds and ds.local_path:
|
| 155 |
+
ds_root = Path(ds.local_path)
|
| 156 |
+
# Use ds.local_path directly as it is now authoritative project-local path
|
| 157 |
+
# Fallback to global removed per user request
|
| 158 |
+
|
| 159 |
+
fmt = ds.format.value if hasattr(ds.format, 'value') else str(ds.format)
|
| 160 |
+
|
| 161 |
+
try:
|
| 162 |
+
if fmt == DatasetFormat.yolo.value or fmt == "yolo":
|
| 163 |
+
class_map = YOLOParser.load_class_map(ds_root)
|
| 164 |
+
for row in image_rows:
|
| 165 |
+
rel_path = Path(row["rel_path"])
|
| 166 |
+
# For YOLO, the label file is usually in a parallel 'labels' folder
|
| 167 |
+
# or in the same folder as the image.
|
| 168 |
+
# Roboflow structure: train/images/img.jpg -> train/labels/img.txt
|
| 169 |
+
parts = list(rel_path.parts)
|
| 170 |
+
|
| 171 |
+
label_rel = None
|
| 172 |
+
if "images" in parts:
|
| 173 |
+
idx = parts.index("images")
|
| 174 |
+
parts_labels = list(parts)
|
| 175 |
+
parts_labels[idx] = "labels"
|
| 176 |
+
label_rel = Path(*parts_labels).with_suffix(".txt")
|
| 177 |
+
|
| 178 |
+
# Fallback: same folder
|
| 179 |
+
label_same_folder = rel_path.with_suffix(".txt")
|
| 180 |
+
|
| 181 |
+
for cand_rel in [label_rel, label_same_folder]:
|
| 182 |
+
if not cand_rel: continue
|
| 183 |
+
label_file = ds_root / cand_rel
|
| 184 |
+
if label_file.exists():
|
| 185 |
+
anns = YOLOParser.parse_file(label_file, row["id"], ds.id, class_map)
|
| 186 |
+
dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in anns]
|
| 187 |
+
break
|
| 188 |
+
|
| 189 |
+
elif fmt == DatasetFormat.coco.value or fmt == "coco":
|
| 190 |
+
jsons = COCOParser.find_annotation_files(ds_root)
|
| 191 |
+
img_map = {row["filename"]: row["id"] for row in image_rows}
|
| 192 |
+
for jf in jsons:
|
| 193 |
+
_, parsed = COCOParser.parse_file(jf, ds.id)
|
| 194 |
+
for p_rel, _, _, anns in parsed:
|
| 195 |
+
fname = Path(p_rel).name
|
| 196 |
+
if fname in img_map:
|
| 197 |
+
img_id = img_map[fname]
|
| 198 |
+
dynamic_anns[img_id].extend([_row_to_annotation(a) for a in anns])
|
| 199 |
+
|
| 200 |
+
elif fmt == DatasetFormat.voc.value or fmt == "voc":
|
| 201 |
+
for row in image_rows:
|
| 202 |
+
img_abs = ds_root / row["rel_path"]
|
| 203 |
+
xml_candidates = [img_abs.with_suffix(".xml")]
|
| 204 |
+
parts = list(Path(row["rel_path"]).parts)
|
| 205 |
+
if "JPEGImages" in parts:
|
| 206 |
+
idx = parts.index("JPEGImages")
|
| 207 |
+
parts[idx] = "Annotations"
|
| 208 |
+
xml_candidates.append(ds_root.joinpath(*parts).with_suffix(".xml"))
|
| 209 |
+
|
| 210 |
+
for cand in xml_candidates:
|
| 211 |
+
if cand.exists():
|
| 212 |
+
_, _, _, anns = VOCParser.parse_file(cand, row["id"], ds.id)
|
| 213 |
+
dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in anns]
|
| 214 |
+
break
|
| 215 |
+
|
| 216 |
+
elif fmt == "csv":
|
| 217 |
+
for row in image_rows:
|
| 218 |
+
csv_path = ds_root / row["rel_path"]
|
| 219 |
+
if csv_path.exists():
|
| 220 |
+
# For CSV/NLP, we might need a more specific way to find the exact row,
|
| 221 |
+
# but for now we reload the file or use a cached version.
|
| 222 |
+
# Since get_viewer_page is paginated, we'll parse the file.
|
| 223 |
+
anns = CSVParser.parse_file(csv_path, ds.id)
|
| 224 |
+
# Find the annotation matching this "image_id" (which is the text entry id)
|
| 225 |
+
matching_anns = [a for a in anns if a["image_id"] == row["id"]]
|
| 226 |
+
dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in matching_anns]
|
| 227 |
+
|
| 228 |
+
except Exception as e:
|
| 229 |
+
log.error("dynamic_annotation_read_failed", error=str(e), dataset_id=dataset_id)
|
| 230 |
+
|
| 231 |
+
images: list[ImageRecord] = []
|
| 232 |
+
for row in image_rows:
|
| 233 |
+
annotations = dynamic_anns.get(row["id"], [])
|
| 234 |
+
images.append(ImageRecord(
|
| 235 |
+
image_id = row["id"],
|
| 236 |
+
filename = row["filename"],
|
| 237 |
+
width = row["width"],
|
| 238 |
+
height = row["height"],
|
| 239 |
+
path = row["rel_path"],
|
| 240 |
+
annotations = annotations,
|
| 241 |
+
split = row["split"],
|
| 242 |
+
))
|
| 243 |
+
|
| 244 |
+
total_pages = max(1, (total + page_size - 1) // page_size)
|
| 245 |
+
|
| 246 |
+
return ViewerPage(
|
| 247 |
+
dataset_id = dataset_id,
|
| 248 |
+
page = page,
|
| 249 |
+
page_size = page_size,
|
| 250 |
+
total = total,
|
| 251 |
+
total_pages = total_pages,
|
| 252 |
+
images = images,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _row_to_annotation(row: dict) -> Annotation:
|
| 257 |
+
bbox = None
|
| 258 |
+
if row.get("bbox_x") is not None:
|
| 259 |
+
bbox = BoundingBox(
|
| 260 |
+
x = row["bbox_x"],
|
| 261 |
+
y = row["bbox_y"],
|
| 262 |
+
width = row["bbox_w"],
|
| 263 |
+
height = row["bbox_h"],
|
| 264 |
+
normalised = bool(row.get("normalised", 1)),
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
segmentation = None
|
| 268 |
+
if row.get("segmentation"):
|
| 269 |
+
try:
|
| 270 |
+
import json
|
| 271 |
+
segmentation = json.loads(row["segmentation"])
|
| 272 |
+
except:
|
| 273 |
+
pass
|
| 274 |
+
|
| 275 |
+
return Annotation(
|
| 276 |
+
label = row["label"],
|
| 277 |
+
bbox = bbox,
|
| 278 |
+
segmentation = segmentation,
|
| 279 |
+
confidence = row.get("confidence"),
|
| 280 |
+
area = row.get("area"),
|
| 281 |
+
type = AnnotationType(row.get("ann_type", "detection")),
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
async def resolve_image_path(dataset_id: str, image_id: str) -> Path | None:
|
| 286 |
+
"""
|
| 287 |
+
Resolve the absolute filesystem path for an image.
|
| 288 |
+
Prioritizes the active project's dataset folder, falling back to the global cache.
|
| 289 |
+
Returns None if dataset not imported or image not found.
|
| 290 |
+
"""
|
| 291 |
+
ds = await ds_reg.get_dataset(dataset_id)
|
| 292 |
+
if ds is None or not ds.local_path:
|
| 293 |
+
return None
|
| 294 |
+
|
| 295 |
+
base_root = Path(ds.local_path)
|
| 296 |
+
# ds.local_path is now authoritative project-local path
|
| 297 |
+
# Fallback removed per user request
|
| 298 |
+
|
| 299 |
+
from database.connection import get_db
|
| 300 |
+
db = await get_db()
|
| 301 |
+
async with db.execute(
|
| 302 |
+
"SELECT rel_path FROM dataset_images WHERE id=? AND dataset_id=?",
|
| 303 |
+
(image_id, dataset_id),
|
| 304 |
+
) as cur:
|
| 305 |
+
row = await cur.fetchone()
|
| 306 |
+
if not row:
|
| 307 |
+
return None
|
| 308 |
+
|
| 309 |
+
abs_path = base_root / row["rel_path"]
|
| 310 |
+
if not abs_path.exists():
|
| 311 |
+
return None
|
| 312 |
+
|
| 313 |
+
# Security: ensure path is under base_root
|
| 314 |
+
try:
|
| 315 |
+
abs_path.resolve().relative_to(base_root.resolve())
|
| 316 |
+
except ValueError:
|
| 317 |
+
log.warning("path_traversal_attempt", dataset_id=dataset_id, image_id=image_id)
|
| 318 |
+
return None
|
| 319 |
+
|
| 320 |
+
return abs_path
|
download/__init__.py
ADDED
|
File without changes
|
download/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|