senthil2421 committed on
Commit
e10cda2
·
1 Parent(s): ee35993

Refactor cloud_backend: remove local execution routes and fix missing modules

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. __pycache__/config.cpython-310.pyc +0 -0
  2. __pycache__/main.cpython-310.pyc +0 -0
  3. adapters/__init__.py +0 -0
  4. adapters/__pycache__/__init__.cpython-310.pyc +0 -0
  5. adapters/__pycache__/base.cpython-310.pyc +0 -0
  6. adapters/__pycache__/hf_adapter.cpython-310.pyc +0 -0
  7. adapters/__pycache__/onnx_adapter.cpython-310.pyc +0 -0
  8. adapters/__pycache__/roboflow_adapter.cpython-310.pyc +0 -0
  9. adapters/base.py +28 -0
  10. adapters/hf_adapter.py +415 -0
  11. adapters/onnx_adapter.py +176 -0
  12. adapters/roboflow_adapter.py +353 -0
  13. benchmark/__init__.py +1 -0
  14. benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
  15. benchmark/__pycache__/compatibility.cpython-310.pyc +0 -0
  16. benchmark/__pycache__/execution.cpython-310.pyc +0 -0
  17. benchmark/__pycache__/metrics.cpython-310.pyc +0 -0
  18. benchmark/__pycache__/orchestrator.cpython-310.pyc +0 -0
  19. benchmark/__pycache__/registry.cpython-310.pyc +0 -0
  20. benchmark/__pycache__/telemetry.cpython-310.pyc +0 -0
  21. benchmark/adapters/__pycache__/base.cpython-310.pyc +0 -0
  22. benchmark/adapters/__pycache__/registry.cpython-310.pyc +0 -0
  23. benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc +0 -0
  24. benchmark/adapters/base.py +38 -0
  25. benchmark/adapters/optimum_runner.py +53 -0
  26. benchmark/adapters/registry.py +44 -0
  27. benchmark/adapters/torch_runner.py +45 -0
  28. benchmark/compatibility.py +360 -0
  29. benchmark/execution.py +366 -0
  30. benchmark/metrics.py +110 -0
  31. benchmark/orchestrator.py +374 -0
  32. benchmark/registry.py +302 -0
  33. benchmark/telemetry.py +182 -0
  34. benchmark/torch_runner.py +142 -0
  35. datasets/__init__.py +1 -0
  36. datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  37. datasets/__pycache__/annotation_parser.cpython-310.pyc +0 -0
  38. datasets/__pycache__/base_adapter.cpython-310.pyc +0 -0
  39. datasets/__pycache__/format_adapters.cpython-310.pyc +0 -0
  40. datasets/__pycache__/import_service.cpython-310.pyc +0 -0
  41. datasets/__pycache__/registry.cpython-310.pyc +0 -0
  42. datasets/__pycache__/viewer_service.cpython-310.pyc +0 -0
  43. datasets/annotation_parser.py +576 -0
  44. datasets/base_adapter.py +37 -0
  45. datasets/format_adapters.py +235 -0
  46. datasets/import_service.py +589 -0
  47. datasets/registry.py +452 -0
  48. datasets/viewer_service.py +320 -0
  49. download/__init__.py +0 -0
  50. download/__pycache__/__init__.cpython-310.pyc +0 -0
__pycache__/config.cpython-310.pyc ADDED
Binary file (2.55 kB). View file
 
__pycache__/main.cpython-310.pyc ADDED
Binary file (3.43 kB). View file
 
adapters/__init__.py ADDED
File without changes
adapters/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
adapters/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.31 kB). View file
 
adapters/__pycache__/hf_adapter.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
adapters/__pycache__/onnx_adapter.cpython-310.pyc ADDED
Binary file (5.27 kB). View file
 
adapters/__pycache__/roboflow_adapter.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
adapters/base.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ adapters/base.py — Abstract base class every source adapter must implement.
3
+ Enforces a stable contract so the registry never knows which adapter runs.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from abc import ABC, abstractmethod
8
+
9
+ from models.model import Model
10
+
11
+
12
class BaseAdapter(ABC):
    """Fetch models from an external source and normalize to the Model schema.

    Subclasses override ``source_name`` and implement :meth:`fetch_models`.
    """

    # Short identifier of the upstream source; subclasses override it.
    source_name: str = "unknown"

    @abstractmethod
    async def fetch_models(self) -> list[Model]:
        """Return a list of normalized Model objects from the source."""
        ...

    def _format_size(self, bytes_: int) -> str:
        """Return *bytes_* as a human-readable size with one decimal place.

        Units step by 1024 from ``B`` up to ``TB``; anything larger is
        reported in ``PB``.

        Args:
            bytes_: Size in bytes.

        Returns:
            A string such as ``"1.5 KB"`` or ``"429.2 MB"``.
        """
        # True division keeps the fractional part that the ``.1f`` format
        # is meant to display (floor division always produced ``x.0``).
        size = float(bytes_)
        for unit in ("B", "KB", "MB", "GB", "TB"):
            if size < 1024:
                return f"{size:.1f} {unit}"
            size /= 1024
        return f"{size:.1f} PB"
adapters/hf_adapter.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ adapters/hf_adapter.py — Hugging Face Hub adapter.
3
+ Fetches real models via the public HF API and normalises them to our schema.
4
+
5
+ Rate-limits respected via polite delays. Requires no authentication for
6
+ publicly accessible models; set HF_TOKEN env var for higher rate-limits.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import re
12
+ from typing import Any
13
+
14
+
15
+ def _is_shard_file(filename: str) -> bool:
16
+ """Return True for sharded weight files like model-00001-of-00003.safetensors."""
17
+ return bool(re.search(r"-\d{5}-of-\d{5}\.", filename))
18
+
19
+ import httpx
20
+ from tenacity import retry, stop_after_attempt, wait_exponential
21
+
22
+ from adapters.base import BaseAdapter
23
+ from config import settings
24
+ from models.model import Model, ModelMetrics, ModelVersion
25
+ from observability.logger import get_logger
26
+
27
+ log = get_logger("hf_adapter")
28
+
29
+ # ── Task mapping: HF pipeline_tag → our internal task ─────────────────────────
30
# Keys are Hugging Face ``pipeline_tag`` values; values are the internal task
# names stored on ``Model.task``.  Both image-generation pipelines share the
# single internal task "generation".
HF_TASK_MAP: dict[str, str] = {
    "object-detection": "detection",
    "image-classification": "classification",
    "image-segmentation": "segmentation",
    "text-to-image": "generation",
    "image-to-image": "generation",
    "image-feature-extraction": "embedding",
}

# Pipeline tags that are actively fetched: every tag we know how to map,
# in the mapping's declaration order.
FETCH_TASKS: list[str] = [*HF_TASK_MAP]
41
+
42
+ # ── Framework detection ────────────────────────────────────────────────────────
43
+ def _detect_framework(tags: list[str], model_id: str) -> str:
44
+ tag_str = " ".join(tags + [model_id]).lower()
45
+ if "onnx" in tag_str: return "onnx"
46
+ if "tflite" in tag_str: return "tflite"
47
+ if "coreml" in tag_str: return "coreml"
48
+ if "tensorflow" in tag_str or "tf" in tag_str: return "tensorflow"
49
+ return "pytorch" # HF default
50
+
51
+ # ── Hardware detection ─────────────────────────────────────────────────────────
52
+ def _detect_hardware(tags: list[str]) -> list[str]:
53
+ hw: list[str] = []
54
+ tag_str = " ".join(tags).lower()
55
+ if any(k in tag_str for k in ("cuda", "gpu")): hw.append("gpu")
56
+ if "edge" in tag_str or "mobile" in tag_str: hw.append("edge")
57
+ if "cpu" in tag_str: hw.append("cpu")
58
+ if not hw: hw.append("gpu") # safe default
59
+ return hw
60
+
61
+ # ── Internal tag normalisation ─────────────────────────────────────────────────
62
+ QUALITY_TAG_MAP = {
63
+ "state-of-the-art": "sota",
64
+ "lightweight": "lightweight",
65
+ "tiny": "tiny",
66
+ "fast": "fastest",
67
+ "real-time": "real-time",
68
+ "accuracy": "high-accuracy",
69
+ }
70
+
71
+ def _normalise_tags(raw_tags: list[str], pipeline: str) -> list[str]:
72
+ out: list[str] = []
73
+ for t in raw_tags:
74
+ t_lower = t.lower()
75
+ for keyword, mapped in QUALITY_TAG_MAP.items():
76
+ if keyword in t_lower:
77
+ out.append(mapped)
78
+ # keep relevant library / dataset tags
79
+ if any(t_lower.startswith(p) for p in ("dataset:", "license:", "language:")):
80
+ continue
81
+ out.append(t_lower)
82
+ # add pipeline as tag
83
+ if pipeline:
84
+ out.append(pipeline.replace("-", "_"))
85
+ return list(dict.fromkeys(out)) # deduplicate, preserve order
86
+
87
+
88
class HFAdapter(BaseAdapter):
    """Source adapter that lists Hugging Face Hub models and maps them to ``Model``.

    One listing request is issued per pipeline tag in ``FETCH_TASKS``.  The
    first results of each listing are enriched with a per-model detail call,
    and the first few models of each internal task get a README-derived
    description.  Supports ``async with``; leaving the context closes the
    underlying HTTP client.
    """

    source_name = "hf"

    # File extensions treated as model weights, both when summing the repo
    # size and when deriving one version per weight file.
    _WEIGHT_EXTS = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")

    # Listing entries per pipeline tag that get a detail (sibling size) lookup.
    _DETAIL_ENRICH_LIMIT = 10

    # Accepted models per internal task that get a README description lookup.
    _README_LIMIT = 5

    # Maximum number of per-file versions attached to one model.
    _MAX_VERSIONS = 15

    # (label, filename fragments) probed in order; the first match wins.
    # "XLarge" is probed before "Large" because "xlarge" contains "large".
    _VARIANT_MARKERS = (
        ("Nano", ("-n.", "_n.", "nano")),
        ("Small", ("-s.", "_s.", "small")),
        ("Medium", ("-m.", "_m.", "medium")),
        ("XLarge", ("-x.", "_x.", "xlarge", "huge")),
        ("Large", ("-l.", "_l.", "large")),
    )

    def __init__(self) -> None:
        """Create the shared async HTTP client, adding a bearer token if configured."""
        headers = {"Accept": "application/json"}
        if settings.hf_token:
            headers["Authorization"] = f"Bearer {settings.hf_token}"
        self._client = httpx.AsyncClient(
            base_url=settings.hf_api_base,
            headers=headers,
            timeout=30,
        )

    # ── HTTP calls ─────────────────────────────────────────────────────────

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _fetch_task_page(
        self, pipeline_tag: str, limit: int = 100
    ) -> list[dict[str, Any]]:
        """Return up to *limit* models for *pipeline_tag*, most downloaded first.

        Raises:
            httpx.HTTPStatusError: On a non-success response (after retries).
        """
        params = {
            "pipeline_tag": pipeline_tag,
            "sort": "downloads",
            "direction": -1,  # descending
            "limit": limit,
            "full": "True",
        }
        log.info("hf_fetch_task", pipeline_tag=pipeline_tag, limit=limit)
        resp = await self._client.get("/models", params=params)
        resp.raise_for_status()
        return resp.json()

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _fetch_model_detail(self, model_id: str) -> dict[str, Any]:
        """Return the detail payload for *model_id*.

        When no sibling entry in the payload carries a ``size``, the repo tree
        is fetched and sizes are copied onto the siblings by file path.  That
        patching is best-effort: any failure leaves the payload untouched.

        Raises:
            httpx.HTTPStatusError: If the detail request itself fails.
        """
        resp = await self._client.get(f"/models/{model_id}", params={"full": "True"})
        resp.raise_for_status()
        raw = resp.json()

        siblings: list[dict[str, Any]] = raw.get("siblings") or []
        has_any_size = any(isinstance(s, dict) and s.get("size") for s in siblings)
        if not has_any_size:
            try:
                tree = await self._fetch_model_tree(model_id, revision="main")
                size_by_path: dict[str, int] = {
                    (t.get("path") or ""): int(t.get("size") or 0)
                    for t in (tree or [])
                    if isinstance(t, dict)
                }

                patched: list[dict[str, Any]] = []
                for s in siblings:
                    if not isinstance(s, dict):
                        continue
                    fn = s.get("rfilename") or s.get("path") or ""
                    if fn and not s.get("size") and fn in size_by_path:
                        s = {**s, "size": size_by_path[fn]}
                    patched.append(s)
                raw["siblings"] = patched
            except Exception as exc:
                # Sizes are optional; _make_model falls back to an estimate.
                log.debug("hf_fetch_tree_failed", model_id=model_id, error=str(exc))

        return raw

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _fetch_model_tree(self, model_id: str, *, revision: str = "main") -> list[dict[str, Any]]:
        """Return the file tree of *model_id* at *revision* (``[]`` if not a list)."""
        resp = await self._client.get(f"/models/{model_id}/tree/{revision}")
        resp.raise_for_status()
        data = resp.json()
        if isinstance(data, list):
            return data
        return []

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _get_text_if_ok(self, url: str) -> str:
        """GET *url*; return the body on HTTP 200, ``""`` on any other status.

        Exceptions raised by the request propagate so that the retry
        decorator can act on them.
        """
        resp = await self._client.get(url)
        return resp.text if resp.status_code == 200 else ""

    async def _fetch_model_card(self, model_id: str) -> str:
        """Fetch model card (README.md) content for real-time description.

        Best-effort: returns ``""`` when the card is missing or the request
        keeps failing after the retries in :meth:`_get_text_if_ok`.
        (Previously the retry decorator sat on this method, whose body
        swallowed every exception, so no retry could ever happen.)
        """
        url = f"{settings.hf_hub_url}/{model_id}/raw/main/README.md"
        try:
            return await self._get_text_if_ok(url)
        except Exception as exc:
            log.debug("hf_fetch_card_failed", model_id=model_id, error=str(exc))
            return ""

    # ── Pure helpers ───────────────────────────────────────────────────────

    def _parse_safe_tensors_size(self, siblings: list[dict]) -> int:
        """Estimate model size by summing the sizes of weight files in *siblings*.

        Entries that are not dicts are skipped and a missing or null ``size``
        counts as 0.  Returns 0 when nothing could be summed; the caller
        (``_make_model``) then applies a name-based fallback.
        """
        total = 0
        for s in siblings or []:
            if not isinstance(s, dict):
                continue
            filename = (s.get("rfilename") or "").lower()
            if filename.endswith(self._WEIGHT_EXTS):
                # "size" may be present but null, hence ``or 0``.
                total += s.get("size") or 0
        return total

    def _extract_description(self, readme: str, raw: dict[str, Any]) -> str:
        """Extract a clean description from README or card data.

        From *readme*, the first non-blank line that is outside ``---``
        front matter and does not start with ``#`` is returned, stripped and
        capped at 500 characters.  Otherwise the first non-empty of
        ``cardData.summary``, ``cardData.description`` and ``description`` in
        *raw* is returned; the result may be ``""``.
        """
        if readme:
            in_frontmatter = False
            for line in readme.split("\n"):
                stripped = line.strip()
                if stripped == "---":
                    in_frontmatter = not in_frontmatter
                    continue
                if in_frontmatter or not stripped or line.startswith("#"):
                    continue
                return stripped[:500]

        card_data = raw.get("cardData") or {}
        description: str = (
            (card_data.get("summary") or "")
            or (card_data.get("description") or "")
            or (raw.get("description") or "")
        ).strip()
        return description

    def _estimate_metrics(self, model_id: str, task: str) -> ModelMetrics:
        """Return heuristic metrics chosen from the model name, else from *task*.

        The figures are hard-coded estimates keyed on architecture names found
        in *model_id*; they are not measurements.  If no architecture rule
        sets a latency, a generic per-task default is used.
        """
        metrics = ModelMetrics()
        m_id = model_id.lower()

        # Architecture-specific estimates.
        if "vit" in m_id or "dinov2" in m_id:
            if "base" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.5, 1.2, 82.4
            elif "large" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 85.2, 2.4, 84.5
            else:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 0.8, 84.5
        elif "segformer" in m_id:
            # Variants b0..b5; only b0, b1 and b5 have dedicated figures.
            if "b0" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 12.0, 0.4, 35.0
            elif "b1" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 18.0, 0.6, 40.0
            elif "b5" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.0, 1.8, 50.0
            else:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 1.0, 42.0
        elif "convnext" in m_id:
            if "tiny" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 15.0, 0.5, 81.0
            else:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 30.0, 1.2, 83.5
        elif "yolo" in m_id:
            # Sizes n, s, m; everything else shares the last row.
            if "yolov8n" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 1.5, 0.2, 37.3
            elif "yolov8s" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 2.8, 0.4, 44.9
            elif "yolov8m" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 6.2, 0.9, 50.2
            else:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 10.0, 1.5, 52.0

        # Generic task-based fallbacks if still empty.
        if metrics.latency_ms is None:
            if task == "classification":
                metrics.latency_ms, metrics.accuracy = 20.0, 75.0
            elif task == "detection":
                metrics.latency_ms, metrics.mAP = 35.0, 45.0
            elif task == "embedding":
                metrics.latency_ms = 40.0
            elif task == "generation":
                metrics.latency_ms = 1500.0

        return metrics

    @classmethod
    def _variant_label(cls, filename: str) -> str:
        """Return the size label implied by *filename*, or ``"Stable"``."""
        fn_lower = filename.lower()
        for label, fragments in cls._VARIANT_MARKERS:
            if any(fragment in fn_lower for fragment in fragments):
                return label
        return "Stable"

    @staticmethod
    def _fallback_size(model_id: str) -> int:
        """Return a rough size in bytes guessed from size words in *model_id*."""
        lowered = model_id.lower()
        if "large" in lowered:
            return 1_200_000_000
        if "base" in lowered:
            return 500_000_000
        if "small" in lowered or "tiny" in lowered:
            return 150_000_000
        return 450_000_000  # general default

    def _make_model(self, raw: dict[str, Any], pipeline_tag: str) -> Model | None:
        """Build a ``Model`` from one HF API entry.

        Args:
            raw: A listing or detail payload for a single model.
            pipeline_tag: The HF pipeline tag the entry was listed under.

        Returns:
            The normalised model, or ``None`` when the entry has no id or the
            pipeline tag is not in ``HF_TASK_MAP``.
        """
        model_id: str = raw.get("id") or raw.get("modelId", "")
        if not model_id:
            return None

        task = HF_TASK_MAP.get(pipeline_tag)
        if not task:
            return None

        tags_raw: list[str] = raw.get("tags") or []
        framework = _detect_framework(tags_raw, model_id)
        hardware = _detect_hardware(tags_raw)
        tags = _normalise_tags(tags_raw, pipeline_tag)

        # Size: sum of weight files, else a name-based guess.
        siblings: list[dict] = raw.get("siblings") or []
        size = self._parse_safe_tensors_size(siblings) or self._fallback_size(model_id)

        # Provider is the author part of the id; the name is the last part
        # with runs of "-" / "_" collapsed into single dashes.
        provider = model_id.split("/")[0] if "/" in model_id else "community"
        name = model_id.split("/")[-1] if "/" in model_id else model_id
        name = re.sub(r"[-_]+", "-", name).strip("-")

        downloads = raw.get("downloads") or 0
        likes = raw.get("likes") or 0

        # Version metadata derived from the last-modified date and commit sha.
        last_mod: str = raw.get("lastModified") or raw.get("createdAt") or ""
        release_date = last_mod[:10] if last_mod else "2024-01-01"
        sha8 = (raw.get("sha") or "main")[:8]

        # One version per distinct (non-shard) weight file in the repo.
        weight_files = [
            s for s in siblings
            if isinstance(s, dict)
            and (s.get("rfilename") or "").lower().endswith(self._WEIGHT_EXTS)
            and not _is_shard_file(s.get("rfilename") or "")
        ]

        if len(weight_files) > 1:
            versions = [
                ModelVersion(
                    version=s["rfilename"].replace(".", "_"),
                    label=self._variant_label(s["rfilename"]),
                    description=f"Model variant: {s['rfilename']}",
                    releaseDate=release_date,
                    changelog=None,
                )
                for s in weight_files[: self._MAX_VERSIONS]
            ]
        else:
            versions = [
                ModelVersion(
                    version=sha8,
                    label="Latest",
                    description="Primary model weight file.",
                    releaseDate=release_date,
                    changelog=None,
                )
            ]

        # Description from card data, with a generated fallback.
        description = self._extract_description("", raw)
        if not description:
            description = f"{task.capitalize()} model by {provider}."

        metrics = self._estimate_metrics(model_id, task)

        return Model(
            id=model_id.replace("/", "_").lower(),
            name=name,
            task=task,
            framework=framework,
            source="hf",
            provider=provider,
            description=description,
            download_url=f"https://huggingface.co/{model_id}",
            size=size,
            size_label=self._format_size(size),
            tags=tags,
            hardware=hardware,
            status="available",
            downloaded=False,
            downloads=downloads,
            # 3.5 baseline plus one point per 200 likes, capped at 5.0.
            rating=min(5.0, (likes / 200) + 3.5) if likes else None,
            liked=False,
            metrics=metrics,
            versions=versions,
        )

    # ── Public API ─────────────────────────────────────────────────────────

    async def fetch_models(self) -> list[Model]:
        """Fetch and normalise models for every pipeline tag in ``FETCH_TASKS``.

        A failure while processing one pipeline tag is logged and the
        remaining tags are still processed.  Models are de-duplicated by
        their normalised id across tags.

        Returns:
            The accepted models in fetch order.
        """
        models: list[Model] = []
        seen_ids: set[str] = set()
        # Accepted models per internal task; replaces rescanning ``models``
        # for every candidate.
        accepted_per_task: dict[str, int] = {}

        for pipeline_tag in FETCH_TASKS:
            try:
                raw_list = await self._fetch_task_page(
                    pipeline_tag, limit=settings.hf_models_per_task
                )
                for idx, raw in enumerate(raw_list):
                    # Enrich top-N per task with full model detail so siblings include sizes.
                    if idx < self._DETAIL_ENRICH_LIMIT:
                        repo_id = raw.get("id") or raw.get("modelId")
                        if repo_id:
                            try:
                                raw = await self._fetch_model_detail(repo_id)
                            except Exception as exc:
                                # Keep the listing entry when enrichment fails.
                                log.debug(
                                    "hf_fetch_detail_failed",
                                    model_id=repo_id,
                                    error=str(exc),
                                )

                    m = self._make_model(raw, pipeline_tag)
                    if not m or m.id in seen_ids:
                        continue

                    # Try to fetch a README description for the first models of each task.
                    if accepted_per_task.get(m.task, 0) < self._README_LIMIT:
                        repo_id = raw.get("id") or raw.get("modelId")
                        if repo_id:
                            readme = await self._fetch_model_card(repo_id)
                            if readme:
                                extracted = self._extract_description(readme, raw)
                                # Never replace the existing description with
                                # an empty string.
                                if extracted:
                                    m.description = extracted

                    seen_ids.add(m.id)
                    models.append(m)
                    accepted_per_task[m.task] = accepted_per_task.get(m.task, 0) + 1

                # Be polite to HF API: short pause between task pages.
                await asyncio.sleep(0.3)
            except Exception as exc:
                log.warning(
                    "hf_fetch_task_failed",
                    pipeline_tag=pipeline_tag,
                    error=str(exc),
                )

        log.info("hf_fetch_complete", total=len(models))
        return models

    async def __aenter__(self) -> "HFAdapter":
        """Enter the async context; returns the adapter itself."""
        return self

    async def __aexit__(self, *_: Any) -> None:
        """Close the HTTP client when the async context exits."""
        await self._client.aclose()
adapters/onnx_adapter.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ adapters/onnx_adapter.py — ONNX Model Zoo adapter.
3
+ Fetches the curated list of ONNX Zoo models from the GitHub API.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from typing import Any
8
+
9
+ import httpx
10
+ from tenacity import retry, stop_after_attempt, wait_exponential
11
+
12
+ from adapters.base import BaseAdapter
13
+ from models.model import Model, ModelMetrics, ModelVersion
14
+ from observability.logger import get_logger
15
+
16
+ log = get_logger("onnx_adapter")
17
+
18
+ # Curated ONNX Zoo models with metadata + download URLs (GitHub API is rate-limited without auth)
19
# Static catalogue consumed by ONNXAdapter.fetch_models.  Each entry supplies
# the Model fields directly; "metrics" is expanded into ModelMetrics(**...)
# and each item of "versions" into a ModelVersion.
#
# NOTE(review): several entries look internally inconsistent — confirm:
#   * onnx_resnet50: description says "v1" but the URL is resnet50-v2-7.onnx.
#   * onnx_bert_base_uncased: URL points at sentence-transformers/all-MiniLM-L6-v2.
#   * onnx_efficientnet_b0: URL points at efficientnet-lite4-11.onnx.
#   * onnx_sam_vit_b / onnx_clip_vit_b32 / onnx_whisper_tiny: URLs end in
#     .pth / .pt although the adapter reports framework "onnx".
#   * "top1" and "fps" metric keys are only valid if ModelMetrics declares them.
ONNX_CURATED: list[dict[str, Any]] = [
    {
        "id": "onnx_resnet50",
        "name": "ResNet-50",
        "task": "classification",
        "provider": "ONNX Zoo",
        "description": "ResNet-50 v1 image classification model in ONNX format.",
        "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx",
        "size": 102_000_000,
        "tags": ["resnet", "imagenet", "classification"],
        "hardware": ["gpu", "cpu"],
        "metrics": {"latency_ms": 14.2, "top1": 74.9},
        "downloads": 250_000,
        "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-06-01"}],
    },
    {
        "id": "onnx_yolov8n",
        "name": "YOLOv8n",
        "task": "detection",
        "provider": "Ultralytics",
        "description": "Ultralytics YOLOv8 Nano — real-time object detection, ONNX export.",
        "download_url": "https://github.com/ultralytics/yolov8/releases/download/v8.0.0/yolov8n.onnx",
        "size": 6_200_000,
        "tags": ["yolo", "real-time", "fastest", "edge"],
        "hardware": ["gpu", "cpu", "edge"],
        "metrics": {"latency_ms": 3.1, "mAP": 37.3},
        "downloads": 420_000,
        "versions": [{"version": "8.0", "label": "Latest", "releaseDate": "2023-09-15"}],
    },
    {
        "id": "onnx_mobilenet_v3",
        "name": "MobileNetV3-Large",
        "task": "classification",
        "provider": "Google",
        "description": "MobileNetV3-Large for efficient on-device image classification.",
        "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv3-large-1.11.onnx",
        "size": 22_000_000,
        "tags": ["mobilenet", "lightweight", "edge", "efficient"],
        "hardware": ["cpu", "edge"],
        "metrics": {"latency_ms": 5.8, "top1": 75.2, "fps": 180},
        "downloads": 310_000,
        "versions": [{"version": "3.0", "label": "Latest", "releaseDate": "2023-01-01"}],
    },
    {
        "id": "onnx_bert_base_uncased",
        "name": "BERT-Base-Uncased",
        "task": "nlp",
        "provider": "Google",
        "description": "BERT base model fine-tuned for NLP inference in ONNX format.",
        "download_url": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx",
        "size": 438_000_000,
        "tags": ["bert", "nlp", "transformer"],
        "hardware": ["gpu", "cpu"],
        "metrics": {"latency_ms": 42.0},
        "downloads": 198_000,
        "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2022-11-01"}],
    },
    {
        "id": "onnx_efficientnet_b0",
        "name": "EfficientNet-B0",
        "task": "classification",
        "provider": "Google Brain",
        "description": "EfficientNet-B0 for scalable image classification.",
        "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite/model/efficientnet-lite4-11.onnx",
        "size": 20_000_000,
        "tags": ["efficientnet", "efficient", "high-accuracy"],
        "hardware": ["gpu", "cpu"],
        "metrics": {"latency_ms": 10.4, "top1": 77.1},
        "downloads": 145_000,
        "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-03-01"}],
    },
    {
        "id": "onnx_sam_vit_b",
        "name": "SAM ViT-B",
        "task": "segmentation",
        "provider": "Meta AI",
        "description": "Segment Anything Model (ViT-B) for universal image segmentation.",
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "size": 375_000_000,
        "tags": ["sam", "segmentation", "sota"],
        "hardware": ["gpu"],
        "metrics": {"latency_ms": 68.0},
        "downloads": 88_000,
        "versions": [{"version": "1.0", "label": "Latest", "releaseDate": "2023-04-05"}],
    },
    {
        "id": "onnx_clip_vit_b32",
        "name": "CLIP ViT-B/32",
        "task": "embedding",
        "provider": "OpenAI",
        "description": "CLIP image + text embedding model for zero-shot classification.",
        "download_url": "https://openaipublic.blob.core.windows.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba4f386/ViT-B-32.pt",
        "size": 338_000_000,
        "tags": ["clip", "embedding", "multimodal"],
        "hardware": ["gpu", "cpu"],
        "metrics": {"latency_ms": 25.0},
        "downloads": 275_000,
        "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-01-01"}],
    },
    {
        "id": "onnx_whisper_tiny",
        "name": "Whisper Tiny",
        "task": "nlp",
        "provider": "OpenAI",
        "description": "Whisper Tiny speech-to-text model in ONNX format.",
        "download_url": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424930e36a852c0/tiny.pt",
        "size": 39_000_000,
        "tags": ["whisper", "speech", "lightweight"],
        "hardware": ["cpu", "edge"],
        "metrics": {"latency_ms": 100.0},
        "downloads": 167_000,
        "versions": [{"version": "20231117", "label": "Latest", "releaseDate": "2023-11-17"}],
    },
]
133
+
134
+
135
class ONNXAdapter(BaseAdapter):
    """Adapter that serves the static ``ONNX_CURATED`` catalogue as ``Model`` objects."""

    source_name = "onnx"

    async def fetch_models(self) -> list[Model]:
        """Convert every curated entry into a ``Model``.

        No network request is made here.  An entry that fails to convert is
        logged with its id and skipped; the others are still returned.
        """
        catalogue: list[Model] = []
        for entry in ONNX_CURATED:
            try:
                catalogue.append(self._entry_to_model(entry))
            except Exception as exc:
                log.warning("onnx_parse_failed", model_id=entry.get("id"), error=str(exc))

        log.info("onnx_fetch_complete", total=len(catalogue))
        return catalogue

    def _entry_to_model(self, entry: dict[str, Any]) -> Model:
        """Map one curated dict onto the ``Model`` schema.

        ``id``, ``name``, ``task`` and each version's ``version`` are
        required; a missing one raises ``KeyError``.
        """
        version_list = []
        for spec in entry.get("versions", []):
            version_list.append(
                ModelVersion(
                    version=spec["version"],
                    label=spec.get("label", "Stable"),
                    releaseDate=spec.get("releaseDate", ""),
                )
            )

        metric_values = entry.get("metrics", {})
        fields: dict[str, Any] = {
            "id": entry["id"],
            "name": entry["name"],
            "task": entry["task"],
            "framework": "onnx",
            "source": "onnx",
            "provider": entry.get("provider", "ONNX Zoo"),
            "description": entry.get("description", ""),
            "download_url": entry.get("download_url"),
            "size": entry.get("size", 0),
            "size_label": self._format_size(entry.get("size", 0)),
            "tags": entry.get("tags", []),
            "hardware": entry.get("hardware", ["gpu"]),
            "status": "available",
            "downloaded": False,
            "downloads": entry.get("downloads"),
            # Every curated entry gets the same fixed rating.
            "rating": 4.2,
            "metrics": ModelMetrics(**metric_values),
            "versions": version_list,
        }
        return Model(**fields)
adapters/roboflow_adapter.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ adapters/roboflow_adapter.py — Roboflow Universe API client.
3
+
4
+ Responsibilities:
5
+ - Fetch dataset metadata (search, workspace listings, project details)
6
+ - Normalise responses → Dataset domain model
7
+ - Cache results in roboflow_cache table (TTL-aware)
8
+ - Handle pagination, rate limits, and errors robustly
9
+
10
+ Roboflow API reference: https://docs.roboflow.com/api-reference/
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import hashlib
15
+ import json
16
+ import time
17
+ from typing import Any
18
+
19
+ import httpx
20
+ from tenacity import retry, stop_after_attempt, wait_exponential
21
+
22
+ from database.connection import get_db
23
+ from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask
24
+ from observability.logger import audit, get_logger
25
+
26
+ log = get_logger("roboflow_adapter")
27
+
28
# Base URL that _make_client configures on the HTTP client.
_ROBOFLOW_BASE = "https://api.roboflow.com"
# Roboflow Universe web root (not referenced in the visible part of this module).
_UNIVERSE_BASE = "https://universe.roboflow.com"
# Default cache time-to-live in seconds, used by _cache_set.
_DEFAULT_TTL = 3600  # 1 hour
31
+
32
+ # ── Task mapping from Roboflow annotation_type ───────────────────────────────
33
+
34
# Roboflow annotation type -> internal DatasetTask.  Both segmentation
# flavours collapse to DatasetTask.segmentation, and both classification
# flavours to DatasetTask.classification.
_TASK_MAP: dict[str, DatasetTask] = {
    "object-detection": DatasetTask.detection,
    "instance-segmentation": DatasetTask.segmentation,
    "semantic-segmentation": DatasetTask.segmentation,
    "classification": DatasetTask.classification,
    "keypoint-detection": DatasetTask.keypoints,
    "multiclass-classification": DatasetTask.classification,
}

# Roboflow export format name -> internal DatasetFormat.  Every YOLO version
# maps to DatasetFormat.yolo.
_FORMAT_MAP: dict[str, DatasetFormat] = {
    "yolov5": DatasetFormat.yolo,
    "yolov7": DatasetFormat.yolo,
    "yolov8": DatasetFormat.yolo,
    "yolov9": DatasetFormat.yolo,
    "coco": DatasetFormat.coco,
    "voc": DatasetFormat.voc,
    "tfrecord": DatasetFormat.tfrecord,
    "csv": DatasetFormat.csv,
    "createml": DatasetFormat.json,
    "multiclass": DatasetFormat.csv,
}
55
+
56
+
57
+ def _cache_key(parts: list[str]) -> str:
58
+ raw = "|".join(parts)
59
+ return hashlib.sha256(raw.encode()).hexdigest()[:32]
60
+
61
+
62
+ def _fmt_bytes(n: int) -> str:
63
+ for unit in ("B", "KB", "MB", "GB", "TB"):
64
+ if n < 1024:
65
+ return f"{n:.1f} {unit}"
66
+ n /= 1024
67
+ return f"{n:.1f} PB"
68
+
69
+
70
+ # ── Cache helpers ─────────────────────────────────────────────────────────────
71
+
72
async def _cache_get(key: str) -> dict[str, Any] | None:
    """Return the cached payload stored under *key*, or ``None``.

    ``None`` is returned when no row exists for the key or when the row is
    older than its own ``ttl_secs``.

    Args:
        key: Cache key, as produced by ``_cache_key``.

    Returns:
        The JSON-decoded payload of a fresh row, otherwise ``None``.
    """
    import calendar  # function-scope import: only this helper needs it

    db = await get_db()
    async with db.execute(
        "SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?",
        (key,),
    ) as cur:
        row = await cur.fetchone()

    if row is None:
        return None

    # _cache_set never writes fetched_at, so the value comes from the table
    # itself.  NOTE(review): this assumes a SQLite default such as
    # CURRENT_TIMESTAMP, which is expressed in UTC — confirm against the
    # schema.  The string is therefore converted with timegm (UTC) instead of
    # mktime (local time), which skewed the TTL by the host's UTC offset.
    fetched = calendar.timegm(time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S"))
    if time.time() - fetched > row["ttl_secs"]:
        return None  # expired
    return json.loads(row["payload"])
85
+
86
+
87
async def _cache_set(key: str, payload: dict[str, Any], ttl: int = _DEFAULT_TTL) -> None:
    """Insert or replace the cache row for *key* and commit.

    Args:
        key: Cache key, as produced by ``_cache_key``.
        payload: JSON-serialisable data to store.
        ttl: Lifetime of the entry in seconds.
    """
    # fetched_at is not supplied here; presumably the table provides a
    # default for it — confirm against the schema.
    statement = """INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs)
           VALUES (?, ?, ?)"""
    connection = await get_db()
    row_values = (key, json.dumps(payload), ttl)
    await connection.execute(statement, row_values)
    await connection.commit()
95
+
96
+
97
+ # ── HTTP client factory ───────────────────────────────────────────────────────
98
+
99
def _make_client(api_key: str) -> httpx.AsyncClient:
    """Create an async HTTP client bound to the Roboflow API base URL.

    *api_key* is installed as a default query parameter, so it accompanies
    every request made through the returned client.  The caller is
    responsible for closing the client (e.g. via ``async with``).
    """
    client_options: dict[str, Any] = {
        "base_url": _ROBOFLOW_BASE,
        "params": {"api_key": api_key},
        "timeout": 30.0,
        "headers": {"User-Agent": "MLForge/1.0"},
    }
    return httpx.AsyncClient(**client_options)
106
+
107
+
108
+ # ── Roboflow Adapter ──────────────────────────────────────────────────────────
109
+
110
class RoboflowAdapter:
    """
    Stateless adapter for the Roboflow API.
    All methods accept api_key explicitly to support per-user keys.

    NOTE(review): cache keys do not include the API key, so cached results are
    shared between callers — confirm this is acceptable for private workspaces.
    """

    # ── Search (Universe) ─────────────────────────────────────────────────────

    @staticmethod
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    async def search_datasets(
        api_key: str,
        query: str = "",
        workspace: str | None = None,
        page: int = 0,
        page_size: int = 50,
    ) -> list[Dataset]:
        """
        Search Roboflow Universe for datasets.

        Returns normalised Dataset objects. Results are served from the cache
        when a non-expired entry exists. Items that fail normalisation are
        logged and skipped. HTTP status errors are logged, audited and
        re-raised (the retry decorator then applies).
        """
        ck = _cache_key(["search", query, str(workspace), str(page), str(page_size)])
        cached = await _cache_get(ck)
        # `is not None` so that a cached *empty* result list counts as a hit
        # instead of triggering a fresh API call on every request.
        if cached is not None:
            log.debug("roboflow_cache_hit", key=ck, query=query)
            return [Dataset(**d) for d in cached]

        params: dict[str, Any] = {
            "api_key": api_key,
            "q": query or "*",
            "from": page * page_size,
            "size": page_size,
        }
        if workspace:
            params["workspace"] = workspace

        async with _make_client(api_key) as client:
            try:
                resp = await client.get("/", params=params)
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                log.error("roboflow_api_error", status=e.response.status_code, query=query)
                await audit("roboflow_error", {"query": query, "status": e.response.status_code}, level="error")
                raise

        datasets = []
        for item in data.get("results", []):
            try:
                datasets.append(RoboflowAdapter._normalise_search_result(item))
            except Exception as exc:
                # Best effort: one malformed result must not fail the whole search.
                log.warning("normalise_error", item_id=item.get("id"), error=str(exc))

        await _cache_set(ck, [d.model_dump() for d in datasets])
        await audit("roboflow_search", {"query": query, "count": len(datasets)})
        return datasets

    # ── Workspace datasets listing ────────────────────────────────────────────

    @staticmethod
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    async def list_workspace_datasets(
        api_key: str,
        workspace: str,
    ) -> list[Dataset]:
        """List all datasets in a Roboflow workspace (cached)."""
        ck = _cache_key(["workspace", workspace])
        cached = await _cache_get(ck)
        if cached is not None:  # an empty workspace listing is still a valid hit
            return [Dataset(**d) for d in cached]

        async with _make_client(api_key) as client:
            try:
                resp = await client.get(f"/{workspace}")
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                log.error("roboflow_workspace_error", workspace=workspace, status=e.response.status_code)
                raise

        datasets = []
        for proj in data.get("workspace", {}).get("projects", []):
            try:
                datasets.append(RoboflowAdapter._normalise_project(proj, workspace))
            except Exception as exc:
                # Best effort: skip projects that cannot be normalised.
                log.warning("normalise_project_error", project=proj.get("id"), error=str(exc))

        await _cache_set(ck, [d.model_dump() for d in datasets])
        return datasets

    # ── Single project detail ─────────────────────────────────────────────────

    @staticmethod
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    async def get_project(
        api_key: str,
        workspace: str,
        project_id: str,
    ) -> Dataset | None:
        """Fetch full metadata for a single Roboflow project.

        Returns ``None`` when the API answers 404; other HTTP status errors
        are re-raised.
        """
        ck = _cache_key(["project", workspace, project_id])
        cached = await _cache_get(ck)
        if cached:
            return Dataset(**cached)

        async with _make_client(api_key) as client:
            try:
                resp = await client.get(f"/{workspace}/{project_id}")
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 404:
                    return None
                raise

        # Some responses wrap the project under a "project" key; accept both.
        proj_data = data.get("project", data)
        ds = RoboflowAdapter._normalise_project(proj_data, workspace)
        await _cache_set(ck, ds.model_dump())
        return ds

    # ── Download URL builder ──────────────────────────────────────────────────

    @staticmethod
    async def get_download_url(
        api_key: str,
        workspace: str,
        project_id: str,
        version: int,
        export_format: str = "yolov8",
    ) -> str:
        """
        Fetch the export download link from Roboflow for the specified format.

        Tries the official Roboflow SDK first; if the SDK is missing, fails, or
        yields no link, falls back to the REST export endpoint.

        Raises:
            ValueError: if neither path produces a download link.
            httpx.HTTPStatusError: if the fallback REST call returns an error status.
        """
        import asyncio
        import re

        def _sdk_link():
            # Everything SDK-related is blocking (client construction and the
            # workspace/project/version lookups included), so it all runs in
            # the worker thread rather than on the event loop.
            from roboflow import Roboflow

            rf = Roboflow(api_key=api_key)
            version_obj = rf.workspace(workspace).project(project_id).version(version)
            return version_obj.export(export_format).download_link

        try:
            link = await asyncio.to_thread(_sdk_link)
            if not link:
                raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")
            return link
        except Exception as e:
            # Deliberately broad: any SDK problem (including ImportError)
            # switches to the manual API path below.
            log.error("roboflow_sdk_error", error=str(e))

        # Fallback to manual API if SDK fails or isn't installed correctly
        async with _make_client(api_key) as client:
            resp = await client.get(
                f"/{workspace}/{project_id}/{version}/{export_format}"
            )
            resp.raise_for_status()
            data = resp.json()

        # The original code read an undefined name `export` here, so the
        # fallback always died with NameError. The link is taken from the
        # response's "export" object.
        export = data.get("export") or {}
        link = export.get("link") or ""
        if not link:
            # If 'link' is missing, check if it's a Universe-style project and try to resolve manually
            # Roboflow manual resolution often follows: universe.roboflow.com/ds/[id]?key=[api_key]
            project_info = data.get("project")
            if isinstance(project_info, dict):
                pid = project_info.get("id")
                if pid:
                    link = f"https://universe.roboflow.com/ds/{pid}?key={api_key}"

        if not link:
            raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")

        # Ensure the link includes the API key correctly
        if "universe.roboflow.com" in link:
            if "key=" not in link:
                separator = "&" if "?" in link else "?"
                link = f"{link}{separator}key={api_key}"
            elif f"key={api_key}" not in link:
                # Replace old key if it exists but is wrong. A callable
                # replacement keeps backslashes in the key from being treated
                # as regex template escapes.
                link = re.sub(r"key=[^&]+", lambda _m: f"key={api_key}", link)

        return link

    # ── Normalisation helpers ─────────────────────────────────────────────────

    @staticmethod
    def _class_names(raw: Any) -> list[str]:
        """Extract class names from a list of dicts, a list of strings, or a mapping.

        Iterating a mapping yields its keys, which are used as the names.
        """
        return [c.get("name", "") if isinstance(c, dict) else c for c in (raw or [])]

    @staticmethod
    def _normalise_search_result(item: dict[str, Any]) -> Dataset:
        """Map a Universe search result → Dataset."""
        annotation = item.get("annotation", {})
        # Accept both {"type": "..."} and a bare string, as project payloads use.
        if isinstance(annotation, dict):
            ann_type = annotation.get("type", "object-detection")
        else:
            ann_type = annotation or "object-detection"
        rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
        class_names = RoboflowAdapter._class_names(item.get("classes", []))
        images = item.get("images", 0) or 0

        return Dataset(
            id=item.get("id", "").replace("/", "__"),
            name=item.get("name", "Unnamed"),
            description=item.get("description", ""),
            task=rf_task,
            format=DatasetFormat.yolo,
            source=DatasetSource.roboflow,
            status=DatasetStatus.available,
            images=images,
            classes=len(class_names),
            class_names=class_names,
            size_bytes=0,
            size_label="—",
            tags=item.get("tags", []),
            roboflow_id=item.get("id", ""),
            created_at=item.get("created", ""),
            updated_at=item.get("updated", ""),
        )

    @staticmethod
    def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset:
        """Map a workspace project → Dataset."""
        ann_type = proj.get("annotation", "object-detection")
        rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
        class_names = [c.get("name", c) if isinstance(c, dict) else c
                       for c in proj.get("classes", [])]
        project_id = proj.get("id", proj.get("name", "unknown"))
        rf_id = f"{workspace}/{project_id}"
        images = proj.get("images", 0) or 0

        return Dataset(
            id=rf_id.replace("/", "__"),
            name=proj.get("name", project_id),
            description=proj.get("description", ""),
            task=rf_task,
            format=DatasetFormat.yolo,
            source=DatasetSource.roboflow,
            status=DatasetStatus.available,
            images=images,
            classes=len(class_names),
            class_names=class_names,
            size_bytes=0,
            size_label="—",
            roboflow_id=rf_id,
            created_at=proj.get("created", ""),
            updated_at=proj.get("updated", ""),
        )
benchmark/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # benchmark — Benchmark Bridge System for MLForge
benchmark/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (144 Bytes). View file
 
benchmark/__pycache__/compatibility.cpython-310.pyc ADDED
Binary file (8.3 kB). View file
 
benchmark/__pycache__/execution.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
benchmark/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (3.24 kB). View file
 
benchmark/__pycache__/orchestrator.cpython-310.pyc ADDED
Binary file (9.11 kB). View file
 
benchmark/__pycache__/registry.cpython-310.pyc ADDED
Binary file (8.77 kB). View file
 
benchmark/__pycache__/telemetry.cpython-310.pyc ADDED
Binary file (6.73 kB). View file
 
benchmark/adapters/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.8 kB). View file
 
benchmark/adapters/__pycache__/registry.cpython-310.pyc ADDED
Binary file (1.89 kB). View file
 
benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
benchmark/adapters/base.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/adapters/base.py — Base class for all Benchmark Runners.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass, field
8
+ from typing import Any, AsyncGenerator
9
+
10
+ from models.benchmark import BenchmarkContext, TelemetrySample
11
+
12
+
13
@dataclass
class BatchResult:
    """Outcome of running one batch through a benchmark runner."""

    # Batch latency in milliseconds.
    latency_ms: float
    # VRAM used while processing the batch, in gigabytes.
    vram_used_gb: float
    # Metric name → score; a fresh empty dict per instance by default.
    task_scores: dict[str, float] = field(default_factory=dict)
    # Free-form extra information; a fresh empty dict per instance by default.
    metadata: dict[str, Any] = field(default_factory=dict)
20
+
21
+
22
class BaseRunner(ABC):
    """Abstract interface for benchmark executors (Torch, Optimum, vLLM).

    Concrete runners must implement all three coroutines below.
    """

    @abstractmethod
    async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
        """Load model and prepare environment."""

    @abstractmethod
    async def run_batch(self, batch: Any) -> BatchResult:
        """Execute a single batch of data."""

    @abstractmethod
    async def shutdown(self) -> None:
        """Release resources."""
benchmark/adapters/optimum_runner.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/adapters/optimum_runner.py — Hugging Face Optimum Adapter.
3
+ Supports ONNX, OpenVINO, and TensorRT acceleration.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import time
8
+ import asyncio
9
+ from typing import Any
10
+ from benchmark.adapters.base import BaseRunner, BatchResult
11
+ from models.benchmark import BenchmarkContext
12
+ from observability.logger import get_logger
13
+
14
+ log = get_logger("benchmark.optimum")
15
+
16
class OptimumRunner(BaseRunner):
    """Placeholder Optimum runner: simulates loading and inference with sleeps.

    No real Optimum / ONNX Runtime session is created; latency is measured
    around a fixed sleep, and VRAM and accuracy are mocked constants.
    """

    def __init__(self):
        self.session = None
        self.device = "cpu"

    async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
        """
        Load model using Optimum's ORTModel or equivalent.
        In a real implementation, this would detect the framework and use:
        ORTModelForFeatureExtraction.from_pretrained(model_path, provider=...)
        """
        log.info("optimum_init", model_path=model_path, hardware=ctx.hardware)

        # Device is guessed from the hardware label: "gpu" or "rtx" → CUDA.
        hardware_label = ctx.hardware.lower()
        looks_like_gpu = any(marker in hardware_label for marker in ("gpu", "rtx"))
        self.device = "cuda" if looks_like_gpu else "cpu"

        # Simulate load time
        await asyncio.sleep(1.5)
        self.session = "active"  # Placeholder for the real session object

    async def run_batch(self, batch: Any) -> BatchResult:
        """Execute inference using the Optimum/ONNX Runtime session."""
        if not self.session:
            raise RuntimeError("Optimum session not initialized")

        started = time.perf_counter()
        # Mocked inference: the real call would be made on self.session.
        await asyncio.sleep(0.01)  # Simulated inference time
        elapsed_ms = (time.perf_counter() - started) * 1000

        return BatchResult(
            latency_ms=elapsed_ms,
            vram_used_gb=0.8,  # Mocked
            task_scores={"accuracy": 0.92},  # Mocked
        )

    async def shutdown(self) -> None:
        """Log the shutdown and drop the session reference."""
        log.info("optimum_shutdown")
        self.session = None
benchmark/adapters/registry.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/adapters/registry.py — Executor Registry for dynamic runner resolution.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from typing import Type
7
+ from benchmark.adapters.base import BaseRunner
8
+ from models.benchmark import BenchmarkContext
9
+ from models.model import Model
10
+
11
class ExecutorRegistry:
    """Class-level registry mapping a lower-cased framework name to a runner class."""

    _runners: dict[str, Type[BaseRunner]] = {}

    @classmethod
    def register(cls, framework: str, runner_cls: Type[BaseRunner]):
        """Register *runner_cls* for *framework* (stored case-insensitively)."""
        key = framework.lower()
        cls._runners[key] = runner_cls

    @classmethod
    def get_runner(cls, framework: str) -> BaseRunner:
        """Return a new runner instance for *framework*.

        Falls back to a ``TorchRunner`` when nothing is registered for it.
        """
        registered = cls._runners.get(framework.lower())
        if registered:
            return registered()

        # Fallback or default runner (imported lazily, only when needed).
        from benchmark.adapters.torch_runner import TorchRunner
        return TorchRunner()
26
+
27
def get_executor(ctx: BenchmarkContext, model: Model) -> BaseRunner:
    """Resolve the appropriate executor based on framework and task.

    Resolution order:
      1. onnx / openvino / tensorrt → ``OptimumRunner``
      2. pytorch with a generation or nlp task → ``VLLMRunner`` if importable
      3. otherwise whatever ``ExecutorRegistry`` resolves for the framework
    """
    framework = model.framework.lower()

    # Special cases for optimized engines
    if framework in ("onnx", "openvino", "tensorrt"):
        from benchmark.adapters.optimum_runner import OptimumRunner
        return OptimumRunner()

    vllm_candidate = ctx.task in ("generation", "nlp") and framework == "pytorch"
    if vllm_candidate:
        # Potential for vLLM if configured; silently skipped when unavailable.
        try:
            from benchmark.adapters.vllm_runner import VLLMRunner
            return VLLMRunner()
        except ImportError:
            pass

    return ExecutorRegistry.get_runner(framework)
benchmark/adapters/torch_runner.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/adapters/torch_runner.py — PyTorch Runner Adapter.
3
+ Wraps standard PyTorch inference for Vision and NLP tasks.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import time
8
+ import asyncio
9
+ import random
10
+ from typing import Any
11
+ from benchmark.adapters.base import BaseRunner, BatchResult
12
+ from models.benchmark import BenchmarkContext
13
+ from observability.logger import get_logger
14
+
15
+ log = get_logger("benchmark.torch")
16
+
17
class TorchRunner(BaseRunner):
    """Placeholder PyTorch runner: simulates loading and inference with sleeps.

    No real model is loaded; latency is measured around a fixed sleep, and
    VRAM and mAP are mocked constants.
    """

    def __init__(self):
        self.model = None
        self.device = "cpu"

    async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None:
        """Log the request, wait for a simulated load, and mark the model active."""
        log.info("torch_init", model_path=model_path, hardware=ctx.hardware)
        # In production: self.model = torch.load(model_path).to(self.device)
        await asyncio.sleep(1.0)
        self.model = "active"

    async def run_batch(self, batch: Any) -> BatchResult:
        """Run one simulated batch and return its measured latency.

        Raises:
            RuntimeError: if ``initialize`` has not been called.
        """
        if not self.model:
            raise RuntimeError("Torch model not initialized")

        started = time.perf_counter()
        # Mocking torch inference
        await asyncio.sleep(0.02)
        elapsed_ms = (time.perf_counter() - started) * 1000

        return BatchResult(
            latency_ms=elapsed_ms,
            vram_used_gb=1.2,
            task_scores={"mAP": 0.45},
        )

    async def shutdown(self) -> None:
        """Log the shutdown and drop the model reference."""
        log.info("torch_shutdown")
        self.model = None
benchmark/compatibility.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/compatibility.py — Compatibility Validator (CRITICAL MODULE).
3
+
4
+ Validates model ↔ dataset ↔ hardware compatibility before any benchmark
5
+ execution begins. Returns a structured ValidationReport — never raises.
6
+
7
+ Five gates (all must pass):
8
+ A. Task compatibility — model.task matches dataset.task
9
+ B. Annotation format — dataset format supports the model's task
10
+ C. Framework × hardware — framework can run on the requested device
11
+ D. VRAM constraint — estimated memory fits available VRAM
12
+ E. Precision support — precision mode is valid for framework + hardware
13
+ """
14
+ from __future__ import annotations
15
+
16
+ from models.benchmark import BenchmarkContext, ValidationCheck, ValidationReport
17
+ from models.dataset import Dataset
18
+ from models.model import Model
19
+ from observability.logger import get_logger
20
+
21
+ log = get_logger("benchmark.compatibility")
22
+
23
+
24
+ # ── Lookup tables ─────────────────────────────────────────────────────────────
25
+
26
# Hardware → available VRAM in GB (normalized keys, no spaces/dashes)
HARDWARE_VRAM_GB: dict[str, float] = {
    # NVIDIA consumer — Ampere / Ada
    "rtx4090": 24.0,
    "rtx4080": 16.0,
    "rtx4070ti": 12.0,
    "rtx4070": 12.0,
    "rtx4060ti": 8.0,
    "rtx4060": 8.0,
    "rtx3090": 24.0,
    "rtx3080": 10.0,
    "rtx3070": 8.0,
    "rtx3060": 12.0,
    "rtx2080ti": 11.0,
    "rtx2080": 8.0,
    # NVIDIA datacenter
    "a100": 80.0,
    "a10040gb": 40.0,
    "h100": 80.0,
    "v100": 32.0,
    "t4": 16.0,
    "a10": 24.0,
    # AMD
    "rx7900xtx": 24.0,
    "rx6800xt": 16.0,
    # Generic fallbacks
    "gpu": 8.0,
    "cpu": 0.0,
    "tpu": 0.0,
    "edge": 0.0,
}

# model.task → set of compatible dataset.task values
TASK_COMPAT: dict[str, set[str]] = {
    "detection": {"detection"},
    "classification": {"classification"},
    "segmentation": {"segmentation"},
    "nlp": {"nlp"},
    "generation": {"generation"},
    "keypoints": {"keypoints", "detection"},
    "embedding": {"nlp", "classification"},
}

# dataset.format → set of model tasks it supports
FORMAT_TASK_COMPAT: dict[str, set[str]] = {
    "yolo": {"detection", "segmentation", "keypoints"},
    "coco": {"detection", "segmentation", "keypoints"},
    "voc": {"detection"},
    "csv": {"classification"},
    "json": {"detection", "segmentation", "classification", "nlp", "generation"},
    "tfrecord": {"detection", "classification", "segmentation"},
    "custom": {"detection", "classification", "segmentation", "nlp", "generation", "keypoints"},
}

# model.framework → set of hardware targets (normalized) it can run on
# NOTE(review): frameworks missing from this table are treated as CPU-only by
# Gate C — confirm that is intended for every framework the project accepts.
FRAMEWORK_HARDWARE_COMPAT: dict[str, set[str]] = {
    "pytorch": {
        "cpu", "gpu",
        "rtx4090", "rtx4080", "rtx4070ti", "rtx4070", "rtx4060ti", "rtx4060",
        "rtx3090", "rtx3080", "rtx3070", "rtx3060",
        "rtx2080ti", "rtx2080",
        "a100", "a10040gb", "h100", "v100", "t4", "a10",
    },
    "onnx": {
        "cpu", "gpu",
        "rtx4090", "rtx3090", "a100", "h100", "t4", "a10",
        "edge",
    },
    "tensorflow": {
        "cpu", "gpu",
        "rtx4090", "rtx3090", "a100", "h100", "v100", "t4",
        "tpu",
    },
    "tflite": {"cpu", "edge"},
    "coreml": {"cpu"},
}

# Precisions that require GPU
_GPU_ONLY_PRECISIONS = {"FP16", "BF16"}

# Frameworks supporting INT8 quantization
_INT8_FRAMEWORKS = {"onnx", "tflite", "pytorch", "tensorflow"}


class CompatibilityValidator:
    """
    Runs all compatibility gates before a benchmark job is created.

    Each gate returns a ValidationCheck; ``validate`` aggregates them into a
    ValidationReport whose ``errors`` are the details of the failed checks.
    """

    def validate(
        self,
        model: Model,
        dataset: Dataset,
        ctx: BenchmarkContext,
    ) -> ValidationReport:
        """Run gates A–E and return the aggregated report.

        The report passes only when every gate passes. ``warnings`` is
        currently always empty.
        """
        checks: list[ValidationCheck] = [
            self._check_task(model, dataset),
            self._check_annotation_format(model, dataset),
            self._check_framework_hardware(model, ctx),
            self._check_vram(model, ctx),
            self._check_precision(model, ctx),
        ]

        errors = [c.detail for c in checks if not c.passed]
        warnings: list[str] = []
        passed = len(errors) == 0

        log.info(
            "compatibility_validated",
            model_id=model.id,
            dataset_id=dataset.id,
            passed=passed,
            error_count=len(errors),
        )

        return ValidationReport(
            model_id=model.id,
            dataset_id=dataset.id,
            passed=passed,
            checks=checks,
            errors=errors,
            warnings=warnings,
        )

    # ── Gate A: Task ──────────────────────────────────────────────────────────

    def _check_task(self, model: Model, dataset: Dataset) -> ValidationCheck:
        """Gate A: the dataset task must be in the set allowed for the model task.

        Model tasks missing from TASK_COMPAT are only compatible with themselves.
        """
        model_task = model.task.lower().strip()
        dataset_task = self._enum_text(dataset.task)

        allowed = TASK_COMPAT.get(model_task, {model_task})
        if dataset_task in allowed:
            return ValidationCheck(
                name="task_compatibility",
                passed=True,
                detail=(
                    f"Model task '{model_task}' is compatible "
                    f"with dataset task '{dataset_task}'"
                ),
            )
        return ValidationCheck(
            name="task_compatibility",
            passed=False,
            detail=(
                f"Model task '{model_task}' cannot evaluate "
                f"a '{dataset_task}' dataset"
            ),
            suggestion=(
                f"Select a model with task='{dataset_task}', "
                f"or choose a dataset with task='{model_task}'"
            ),
        )

    # ── Gate B: Annotation Format ─────────────────────────────────────────────

    def _check_annotation_format(self, model: Model, dataset: Dataset) -> ValidationCheck:
        """Gate B: the dataset format must support the model task.

        Formats missing from FORMAT_TASK_COMPAT support nothing, so they fail.
        """
        dataset_fmt = self._enum_text(dataset.format)
        model_task = model.task.lower().strip()
        supported = FORMAT_TASK_COMPAT.get(dataset_fmt, set())

        if model_task in supported:
            return ValidationCheck(
                name="annotation_format",
                passed=True,
                detail=(
                    f"Dataset format '{dataset_fmt}' supports "
                    f"model task '{model_task}'"
                ),
            )

        if model_task in {"detection", "segmentation", "keypoints"}:
            suggestion = (
                f"Convert dataset to YOLO or COCO format — both support '{model_task}'"
            )
        elif model_task == "classification":
            suggestion = "Convert dataset to CSV or JSON format for classification tasks"
        else:
            suggestion = f"Use a JSON or custom-format dataset for '{model_task}' tasks"

        return ValidationCheck(
            name="annotation_format",
            passed=False,
            detail=(
                f"Dataset format '{dataset_fmt}' does not support "
                f"model task '{model_task}'"
            ),
            suggestion=suggestion,
        )

    # ── Gate C: Framework × Hardware ─────────────────────────────────────────

    def _check_framework_hardware(
        self, model: Model, ctx: BenchmarkContext
    ) -> ValidationCheck:
        """Gate C: the framework must be able to run on the requested hardware.

        Frameworks missing from FRAMEWORK_HARDWARE_COMPAT are treated as CPU-only.
        """
        framework = model.framework.lower().strip()
        hw_raw = ctx.hardware
        hw_key = self._normalize_hw(hw_raw)

        supported_hw = FRAMEWORK_HARDWARE_COMPAT.get(framework, {"cpu"})

        # Match: exact key, or generic "gpu" bucket covers any named GPU
        hw_ok = (
            hw_key in supported_hw
            or ("gpu" in supported_hw and hw_key not in {"cpu", "tpu", "edge"})
        )

        if hw_ok:
            return ValidationCheck(
                name="framework_hardware",
                passed=True,
                detail=f"Framework '{framework}' is supported on '{hw_raw}'",
            )
        return ValidationCheck(
            name="framework_hardware",
            passed=False,
            detail=(
                f"Framework '{framework}' cannot run on '{hw_raw}'. "
                f"Supported targets: {', '.join(sorted(supported_hw))}"
            ),
            suggestion=(
                "Use ONNX runtime for broadest hardware support, "
                f"or pick a device from: {', '.join(sorted(supported_hw))}"
            ),
        )

    # ── Gate D: VRAM Constraint ───────────────────────────────────────────────

    def _check_vram(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
        """Gate D: the estimated memory footprint must fit the device's VRAM.

        Devices whose looked-up VRAM is 0.0 always pass.
        """
        hw_key = self._normalize_hw(ctx.hardware)
        available = self._lookup_vram(hw_key)

        if available == 0.0:
            return ValidationCheck(
                name="vram_constraint",
                passed=True,
                detail=f"Running on '{ctx.hardware}' (CPU/TPU/Edge) — no VRAM constraint",
            )

        # Estimate: weights at given precision + activations for one batch
        # NOTE(review): assumes model.size is a byte count — confirm against Model.
        model_gb = max(model.size, 1) / (1024 ** 3)
        prec_map = {"FP16": 0.5, "BF16": 0.5, "INT8": 0.25, "FP32": 1.0}
        prec_mult = prec_map.get(ctx.precision.upper(), 1.0)
        # weights × precision + ~20% for optimizer/activation buffers + batch overhead
        estimated = (model_gb * prec_mult * 1.2) + (ctx.batch_size * 0.05)

        if estimated <= available:
            return ValidationCheck(
                name="vram_constraint",
                passed=True,
                detail=(
                    f"Estimated VRAM {estimated:.2f} GB ≤ "
                    f"available {available:.1f} GB on '{ctx.hardware}'"
                ),
            )
        return ValidationCheck(
            name="vram_constraint",
            passed=False,
            detail=(
                f"Estimated VRAM {estimated:.2f} GB exceeds "
                f"available {available:.1f} GB on '{ctx.hardware}'"
            ),
            suggestion=(
                f"Try: reduce batch_size (now {ctx.batch_size}), "
                f"switch to FP16/INT8 precision, "
                f"or use a GPU with ≥ {estimated:.1f} GB VRAM"
            ),
        )

    # ── Gate E: Precision Support ─────────────────────────────────────────────

    def _check_precision(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck:
        """Gate E: the precision must be valid for the framework and hardware.

        FP16/BF16 are rejected on cpu/tpu/edge; INT8 is rejected for frameworks
        outside _INT8_FRAMEWORKS.
        """
        precision = ctx.precision.upper()
        framework = model.framework.lower().strip()
        hw_key = self._normalize_hw(ctx.hardware)
        is_gpu = hw_key not in {"cpu", "tpu", "edge"}

        if precision in _GPU_ONLY_PRECISIONS and not is_gpu:
            return ValidationCheck(
                name="precision_support",
                passed=False,
                detail=(
                    f"Precision '{precision}' requires a CUDA GPU; "
                    f"'{ctx.hardware}' does not support it"
                ),
                suggestion="Use FP32 for CPU inference, or switch to a compatible GPU",
            )

        if precision == "INT8" and framework not in _INT8_FRAMEWORKS:
            return ValidationCheck(
                name="precision_support",
                passed=False,
                detail=(
                    f"Framework '{framework}' does not support INT8 quantization"
                ),
                suggestion=(
                    "Convert model to ONNX or use PyTorch with torch.quantization"
                ),
            )

        return ValidationCheck(
            name="precision_support",
            passed=True,
            detail=(
                f"Precision '{precision}' is valid for "
                f"framework '{framework}' on '{ctx.hardware}'"
            ),
        )

    # ── Helpers ───────────────────────────────────────────────────────────────

    @staticmethod
    def _enum_text(value: object) -> str:
        """Return *value* as a lowercase, stripped lookup string.

        Enum members contribute their ``.value``: ``str()`` on a
        ``(str, Enum)`` member yields "ClassName.member", which would never
        match the lookup tables. Plain strings pass through unchanged.
        """
        return str(getattr(value, "value", value)).lower().strip()

    @staticmethod
    def _normalize_hw(hardware: str) -> str:
        """Lowercase and drop spaces, dashes, underscores and vendor words for lookup."""
        key = hardware.lower()
        for noise in (" ", "-", "_", "nvidia", "geforce"):
            key = key.replace(noise, "")
        return key

    @staticmethod
    def _lookup_vram(hw_key: str) -> float:
        """Return VRAM GB for a normalized hardware key, with fallback matching."""
        if hw_key in HARDWARE_VRAM_GB:
            return HARDWARE_VRAM_GB[hw_key]
        # Partial match (e.g. "rtx4090laptop" → "rtx4090"). The longest key is
        # tried first so the most specific entry wins: "a10040gbsxm" must
        # resolve to "a10040gb" (40 GB), not the shorter "a100" (80 GB).
        for key in sorted(HARDWARE_VRAM_GB, key=len, reverse=True):
            if key and key in hw_key:
                return HARDWARE_VRAM_GB[key]
        # Anything that looks like a GPU but isn't in the table
        if "gpu" in hw_key or "rtx" in hw_key or "gtx" in hw_key or "cuda" in hw_key:
            return HARDWARE_VRAM_GB["gpu"]
        return 0.0  # CPU / unknown → no VRAM constraint
benchmark/execution.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/execution.py — Benchmark Execution Engine.
3
+
4
+ Drives the batch inference loop, collecting latencies and VRAM readings.
5
+ Calls TelemetryCollector in parallel with batch processing.
6
+ Yields progress callbacks so the orchestrator can persist real-time state.
7
+
8
+ Adapter pattern: swap _run_single_batch() with a real inference call
9
+ (torch.cuda.synchronize + model(batch)) once GPU runtime is wired up.
10
+
11
+ PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import asyncio
16
+ import math
17
+ import random
18
+ from dataclasses import dataclass, field
19
+ from typing import Awaitable, Callable
20
+
21
+ from benchmark.compatibility import HARDWARE_VRAM_GB
22
+ from benchmark.telemetry import TelemetryCollector
23
+ from models.benchmark import BenchmarkJob, LayerBreakdown, TelemetrySample, TelemetrySummary
24
+ from models.dataset import Dataset
25
+ from models.model import Model
26
+ from observability.logger import get_logger
27
+
28
+ log = get_logger("benchmark.execution")
29
+
30
+
31
+ # ── Per-image latency profiles (ms at batch=1, fp32) ─────────────────────────
32
+ _LATENCY_MS_PER_IMAGE: dict[str, float] = {
33
+ "rtx4090": 1.8,
34
+ "rtx4080": 2.5,
35
+ "rtx4070ti": 3.2,
36
+ "rtx4070": 3.8,
37
+ "rtx3090": 3.0,
38
+ "rtx3080": 4.5,
39
+ "rtx3070": 6.5,
40
+ "rtx3060": 9.0,
41
+ "rtx2080ti": 5.0,
42
+ "rtx2080": 7.5,
43
+ "a100": 1.2,
44
+ "h100": 0.7,
45
+ "v100": 2.8,
46
+ "t4": 5.5,
47
+ "a10": 3.5,
48
+ "gpu": 8.0,
49
+ "cpu": 42.0,
50
+ }
51
+
52
+ # Precision speedup multipliers (relative to FP32)
53
+ _PRECISION_SPEEDUP: dict[str, float] = {
54
+ "FP32": 1.0,
55
+ "FP16": 1.8,
56
+ "BF16": 1.7,
57
+ "INT8": 2.5,
58
+ }
59
+
60
+ # Task-specific baseline metric scores (pre-jitter)
61
+ _TASK_BASELINES: dict[str, dict[str, float]] = {
62
+ "detection": {"mAP": 0.435, "mAP_50": 0.618, "mAP_50_95": 0.435},
63
+ "classification": {"accuracy": 0.872, "top5": 0.968},
64
+ "segmentation": {"mAP": 0.372, "iou_mean": 0.706},
65
+ "keypoints": {"mAP": 0.641, "mAP_50": 0.860},
66
+ "nlp": {"accuracy": 0.891},
67
+ "generation": {"accuracy": 0.780},
68
+ }
69
+
70
+ # Cap simulated batches so large datasets don't stall the event loop
71
+ _MAX_SIMULATED_BATCHES = 250
72
+
73
+
74
@dataclass
class ExecutionResult:
    """Raw measurements produced by BenchmarkExecutor.execute().

    The orchestrator forwards these values to MetricsEngine.compute() and
    persists the telemetry summary alongside the computed metrics.
    """

    # One latency reading (ms) per executed batch.
    latencies_ms: list[float]
    # Image count the run is reported against.
    total_images: int
    # One VRAM usage reading (GB) per executed batch.
    vram_samples: list[float]
    # Task-specific metric scores, e.g. {"mAP": ..., "accuracy": ...}.
    task_scores: dict[str, float]
    # Samples gathered by the telemetry collector during the run.
    telemetry_samples: list[TelemetrySample] = field(default_factory=list)
    # Aggregate returned by the telemetry collector when it is stopped.
    telemetry_summary: TelemetrySummary = field(default_factory=TelemetrySummary)
83
+
84
+
85
# Coroutine invoked by the executor to report progress.
# Arguments: (progress in the range 0..1, human-readable status message,
# most recent telemetry sample or None when none has been collected yet).
ProgressCallback = Callable[
    [float, str, TelemetrySample | None],
    Awaitable[None],
]
87
+
88
+
89
class BenchmarkExecutor:
    """
    Drives the benchmark execution loop.

    Non-blocking: simulated batches wait with asyncio.sleep and real
    inference is dispatched to the default thread executor, so other
    coroutines keep running while a benchmark is in flight.
    """

    # Batch count used for video/live sessions, which have no dataset size
    # from which a batch count could be derived.
    _LIVE_SESSION_BATCHES = 10000

    async def execute(
        self,
        job: BenchmarkJob,
        model: Model,
        dataset: Dataset,
        on_progress: ProgressCallback,
    ) -> ExecutionResult:
        """Run the batch loop for ``job`` and collect raw measurements.

        Args:
            job: Supplies hardware, batch size, precision, task and id. An
                optional ``input_source`` of "video" or "live" selects a
                fixed-length session instead of a dataset-sized run.
            model: Model under test. Real inference is attempted only when it
                has a downloaded ``local_path`` and PyTorch is importable;
                otherwise latencies are simulated.
            dataset: Source of the image count for dataset runs.
            on_progress: Awaited every 5th batch and on the last batch with
                (progress, message, latest telemetry sample or None).

        Returns:
            ExecutionResult holding per-batch latencies, VRAM samples,
            simulated task scores and the telemetry collected during the run.

        Raises:
            ValueError: If ``job.batch_size`` is not positive.
        """
        hw = job.hardware
        batch_sz = job.batch_size
        if batch_sz <= 0:
            # Without this guard a dataset run dies with ZeroDivisionError and
            # a live run would loop over empty batches.
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_sz!r}"
            )

        # Handle polymorphic input duration
        is_live = getattr(job, "input_source", "dataset") in ("video", "live")

        if is_live:
            # Live/video sessions run for a fixed number of batches.
            n_batches = sim_batches = self._LIVE_SESSION_BATCHES
            total_img = n_batches * batch_sz
        else:
            total_img = max(dataset.images, 100)  # floor so simulation always runs
            n_batches = math.ceil(total_img / batch_sz)
            sim_batches = min(n_batches, _MAX_SIMULATED_BATCHES)

        vram_total = self._get_vram_gb(hw, model)
        vram_frac = self._vram_usage_fraction(hw)
        base_lat_ms = self._base_batch_latency_ms(hw, model, batch_sz, job.precision)

        # Resolve real model path once (None → use simulation). The path is
        # checked first so torch is not imported when it cannot be used.
        real_model_path = model.local_path if model.local_path and model.downloaded else None
        use_real_inference = (
            real_model_path is not None and self._check_torch_available()
        )
        loop = asyncio.get_running_loop()
        emits_detections = job.task.lower() in ("detection", "segmentation")

        latencies: list[float] = []
        vram_samples: list[float] = []

        # Start telemetry last so that a failure in the set-up above cannot
        # leave a collector running.
        telemetry = TelemetryCollector(hw, vram_total_gb=vram_total)
        await telemetry.start()

        try:
            for sim_idx in range(sim_batches):
                # Map simulated index back to real batch index
                real_idx = int(sim_idx * (n_batches / sim_batches))

                if use_real_inference:
                    # Real inference via torch_runner (runs in thread executor)
                    try:
                        from benchmark.torch_runner import run_torch_batch
                        batch_lat_ms = await loop.run_in_executor(
                            None,
                            run_torch_batch,
                            real_model_path,
                            batch_sz,
                            job.task,
                        )
                        # Tiny sleep to prevent event loop starvation in live mode
                        if is_live:
                            await asyncio.sleep(0.001)
                    except Exception as exc:
                        log.warning("torch_inference_failed_fallback", error=str(exc))
                        use_real_inference = False  # fall back for remaining batches
                        batch_lat_ms = self._jittered_latency_ms(base_lat_ms)
                else:
                    # Simulation path — non-blocking synthetic latency
                    batch_lat_ms = self._jittered_latency_ms(base_lat_ms)
                    await asyncio.sleep(batch_lat_ms / 1000.0)

                latencies.append(batch_lat_ms)
                vram_used = vram_total * random.uniform(
                    vram_frac - 0.05, vram_frac + 0.05
                )
                vram_samples.append(max(0.0, vram_used))

                progress = (sim_idx + 1) / sim_batches
                telemetry.record_batch_context(real_idx, progress)

                # Throttle callbacks: every 5 batches or first/last
                if sim_idx % 5 == 0 or sim_idx == sim_batches - 1:
                    images_done = int(progress * total_img)

                    # Simulated detections feed the live preview for vision tasks.
                    live_data = {}
                    if emits_detections:
                        live_data["detections"] = self._simulated_detections()

                    last_sample = telemetry.samples[-1] if telemetry.samples else None
                    if last_sample:
                        last_sample.live_data = live_data
                        # Explicitly broadcast detections for the visualizer
                        last_sample.detections = live_data.get("detections", [])

                    await on_progress(
                        progress,
                        f"Batch {real_idx+1}/{n_batches} — "
                        f"{images_done}/{total_img} images processed",
                        last_sample,
                    )

        finally:
            telemetry_summary = await telemetry.stop()
            # Attach simulated layer breakdown so Live Lab can display it
            telemetry_summary.layer_breakdown = self._compute_layer_breakdown(
                job.task, base_lat_ms
            )

        task_scores = self._simulate_task_scores(job.task, model, dataset)

        log.info(
            "execution_complete",
            job_id=job.id,
            total_images=total_img,
            sim_batches=sim_batches,
            avg_lat_ms=round(sum(latencies) / len(latencies), 2) if latencies else 0,
        )

        return ExecutionResult(
            latencies_ms=latencies,
            total_images=total_img,
            vram_samples=vram_samples,
            task_scores=task_scores,
            telemetry_samples=telemetry.samples,
            telemetry_summary=telemetry_summary,
        )

    # ── Helpers ───────────────────────────────────────────────────────────────

    @staticmethod
    def _jittered_latency_ms(base_lat_ms: float) -> float:
        """Return a synthetic batch latency: base plus 7% Gaussian noise, floored at 0.5 ms."""
        return max(0.5, base_lat_ms + random.gauss(0, base_lat_ms * 0.07))

    @staticmethod
    def _simulated_detections() -> list[dict]:
        """Return 1-5 random boxes (fractional x/y/width/height, label, confidence) for the live preview."""
        return [
            {
                "x": random.uniform(0.1, 0.7),
                "y": random.uniform(0.1, 0.7),
                "width": random.uniform(0.1, 0.3),
                "height": random.uniform(0.1, 0.3),
                "label": random.choice(["person", "car", "bicycle", "dog"]),
                "confidence": random.uniform(0.5, 0.99),
            }
            for _ in range(random.randint(1, 5))
        ]

    def _base_batch_latency_ms(
        self,
        hardware: str,
        model: Model,
        batch_sz: int,
        precision: str,
    ) -> float:
        """
        Estimate per-batch latency in ms.
        Accounts for hardware tier, model size, batch size, and precision.
        """
        hw_key = self._normalize_hw(hardware)
        per_img = self._lookup_latency(hw_key)

        # Larger models are slower: +30% per GB of model weights
        # (model.size is divided by 1024**3, i.e. treated as a byte count).
        size_gb = max(model.size, 1) / (1024 ** 3)
        size_factor = 1.0 + size_gb * 0.30

        # Batch scaling factor: 0.65 per image on GPU, 0.90 on CPU
        eff = 0.65 if "cpu" not in hw_key else 0.90
        batch_lat = per_img * size_factor * batch_sz * eff

        # Unknown precision strings are treated as FP32 (no speedup).
        speedup = _PRECISION_SPEEDUP.get(precision.upper(), 1.0)

        return batch_lat / speedup

    def _get_vram_gb(self, hardware: str, model: Model) -> float:
        """Return total VRAM (GB) for ``hardware``; 8.0 when no table key matches.

        ``model`` is accepted for interface compatibility and is not used.
        """
        hw_key = self._normalize_hw(hardware)
        for key, vram in HARDWARE_VRAM_GB.items():
            if key and key in hw_key:
                return vram
        return 8.0

    @staticmethod
    def _vram_usage_fraction(hardware: str) -> float:
        """Fraction of VRAM the simulation assumes is in use during inference."""
        hw = hardware.lower()
        if any(x in hw for x in ("4090", "3090", "a100", "h100")):
            return 0.62
        if any(x in hw for x in ("4080", "3080", "v100", "a10")):
            return 0.60
        if "cpu" in hw:
            return 0.0
        return 0.55

    @staticmethod
    def _simulate_task_scores(
        task: str, model: Model, dataset: Dataset
    ) -> dict[str, float]:
        """
        Produce baseline metric scores with small per-run variance.

        Unknown tasks fall back to a single "accuracy" score of 0.80.
        ``model`` and ``dataset`` are currently unused.

        PRODUCTION SWAP: replace with actual metric computation:
            from torchmetrics.detection import MeanAveragePrecision
            metric = MeanAveragePrecision()
            metric.update(predictions, targets)
            return metric.compute()
        """
        baselines = dict(_TASK_BASELINES.get(task.lower(), {"accuracy": 0.80}))
        # Small Gaussian jitter simulates run-to-run variance; clamp to [0, 1].
        return {
            k: float(max(0.0, min(1.0, v + random.gauss(0, 0.015))))
            for k, v in baselines.items()
        }

    @staticmethod
    def _check_torch_available() -> bool:
        """Return True if PyTorch is installed and importable."""
        try:
            import torch  # noqa: F401
            return True
        except ImportError:
            return False

    @staticmethod
    def _compute_layer_breakdown(task: str, base_lat_ms: float) -> list[LayerBreakdown]:
        """Build a simulated layer breakdown for the given task.

        Splits ``base_lat_ms`` across architectural stages with small jitter.
        ``percent`` is the nominal share of each stage, not recomputed from
        the jittered times.
        PRODUCTION SWAP: replace with actual profiler data (e.g. torch.profiler).
        """
        task_key = task.lower()
        if task_key in ("detection", "segmentation"):
            stages = [
                ("Backbone", 0.45),
                ("Neck (FPN/PAFPN)", 0.30),
                ("Detection Head", 0.20),
                ("NMS Post-process", 0.05),
            ]
        elif task_key == "classification":
            stages = [
                ("Feature Extractor", 0.70),
                ("Classifier Head", 0.20),
                ("Softmax", 0.10),
            ]
        else:
            stages = [
                ("Encoder", 0.55),
                ("Decoder / Head", 0.35),
                ("Post-process", 0.10),
            ]

        return [
            LayerBreakdown(
                name=name,
                time_ms=round(
                    base_lat_ms * frac + random.gauss(0, base_lat_ms * 0.01), 3
                ),
                percent=round(frac * 100, 1),
            )
            for name, frac in stages
        ]

    @staticmethod
    def _normalize_hw(hardware: str) -> str:
        """Lowercase and strip spaces, dashes, underscores and vendor words for lookup."""
        return (
            hardware.lower()
            .replace(" ", "")
            .replace("-", "")
            .replace("_", "")
            .replace("nvidia", "")
            .replace("geforce", "")
        )

    @staticmethod
    def _lookup_latency(hw_key: str) -> float:
        """Return per-image latency (ms) for a normalized hardware key.

        The first table key contained in ``hw_key`` wins; unlisted GPU-like
        names use the generic "gpu" entry and anything else uses "cpu".
        """
        for key, ms in _LATENCY_MS_PER_IMAGE.items():
            if key and key in hw_key:
                return ms
        if any(x in hw_key for x in ("gpu", "rtx", "gtx", "cuda")):
            return _LATENCY_MS_PER_IMAGE["gpu"]
        return _LATENCY_MS_PER_IMAGE["cpu"]
benchmark/metrics.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/metrics.py — Metrics Engine.
3
+
4
+ Computes the final BenchmarkMetrics object from raw execution data:
5
+ - Latency statistics (mean, p95, p99)
6
+ - Throughput (FPS)
7
+ - VRAM statistics (avg, peak)
8
+ - Task-specific scores (mAP, accuracy, IoU) supplied by the executor
9
+
10
+ In a production deployment the task_scores dict comes from actual
11
+ metric computation (e.g. pycocotools, torchmetrics). In this local-first
12
+ build the executor supplies realistic simulated scores.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import statistics
17
+
18
+ from models.benchmark import BenchmarkMetrics, LayerBreakdown, TelemetrySummary
19
+ from observability.logger import get_logger
20
+
21
+ log = get_logger("benchmark.metrics")
22
+
23
+
24
class MetricsEngine:
    """Computes BenchmarkMetrics from raw benchmark execution data."""

    def compute(
        self,
        *,
        task: str,
        latencies_ms: list[float],      # per-batch latencies
        total_images: int = 0,
        total_tokens: int = 0,
        batch_size: int,
        vram_samples: list[float],      # VRAM readings (GB) during run
        task_scores: dict[str, float],  # task-specific metric scores
    ) -> BenchmarkMetrics:
        """Aggregate raw execution data into a BenchmarkMetrics object.

        Args:
            task: Task name (case-insensitive); selects which score fields
                are populated.
            latencies_ms: One latency per timed batch, in milliseconds.
            total_images: Image count reported for the run.
            total_tokens: Token count reported for the run.
            batch_size: Images per batch.
            vram_samples: VRAM readings in GB taken during the run.
            task_scores: Raw scores; each is clamped to [0, 1] and rounded,
                except ``perplexity`` which is passed through unchanged.

        Returns:
            BenchmarkMetrics with throughput, latency, VRAM and task scores.
            When ``latencies_ms`` is empty only the counts and batch size
            are set.
        """
        if not latencies_ms:
            return BenchmarkMetrics(total_images=total_images, total_tokens=total_tokens, batch_size=batch_size)

        total_time_s = sum(latencies_ms) / 1000.0

        # The executor may time only a sample of the batches while reporting
        # the full dataset size. Throughput must be based on the images the
        # timed batches actually cover, otherwise FPS is inflated by the
        # sampling ratio. When every batch was timed this equals total_images.
        timed_images = total_images
        if batch_size > 0:
            timed_images = min(total_images, len(latencies_ms) * batch_size)

        fps = timed_images / total_time_s if total_time_s > 0 and timed_images > 0 else 0.0
        # NOTE(review): tokens per timed batch are not known here, so
        # tokens/sec still assumes latencies_ms covers all of total_tokens.
        tps = total_tokens / total_time_s if total_time_s > 0 and total_tokens > 0 else 0.0

        lat_mean = statistics.mean(latencies_ms)
        lat_p95 = _percentile(latencies_ms, 0.95)
        lat_p99 = _percentile(latencies_ms, 0.99)

        vram_peak = max(vram_samples) if vram_samples else 0.0
        vram_avg = statistics.mean(vram_samples) if vram_samples else 0.0

        m = BenchmarkMetrics(
            fps=round(fps, 2),
            tokens_per_sec=round(tps, 2),
            latency_mean_ms=round(lat_mean, 3),
            latency_p95_ms=round(lat_p95, 3),
            latency_p99_ms=round(lat_p99, 3),
            vram_peak_gb=round(vram_peak, 3),
            vram_avg_gb=round(vram_avg, 3),
            total_images=total_images,
            total_tokens=total_tokens,
            batch_size=batch_size,
        )

        task_lower = task.lower()

        # CV task mapping
        if task_lower in ("detection", "segmentation", "keypoints"):
            m.mAP = _fmt(task_scores.get("mAP", 0.0))
            m.mAP_50 = _fmt(task_scores.get("mAP_50", 0.0))
            m.mAP_50_95 = _fmt(task_scores.get("mAP_50_95", 0.0))
            if task_lower == "segmentation":
                m.iou_mean = _fmt(task_scores.get("iou_mean", 0.0))

        elif task_lower == "classification":
            accuracy = task_scores.get("accuracy", 0.0)
            m.accuracy = _fmt(accuracy)
            # Producers that report only "accuracy" would otherwise show a
            # top-1 of 0.0; fall back to accuracy when "top1" is absent.
            m.top1 = _fmt(task_scores.get("top1", accuracy))
            m.top5 = _fmt(task_scores.get("top5", 0.0))

        # NLP task mapping (ROUGE, BLEU, perplexity)
        elif task_lower in ("nlp", "generation"):
            m.accuracy = _fmt(task_scores.get("accuracy", 0.0))
            # Accept either spelling of the ROUGE-L key.
            m.rouge_l = _fmt(task_scores.get("rouge_l", task_scores.get("rougeL", 0.0)))
            m.bleu = _fmt(task_scores.get("bleu", 0.0))
            # Perplexity is not bounded by 1, so it is not clamped.
            m.perplexity = task_scores.get("perplexity")

        log.info(
            "metrics_computed",
            task=task,
            fps=m.fps,
            tps=m.tokens_per_sec,
            latency_ms=m.latency_mean_ms,
            vram_peak=m.vram_peak_gb,
        )
        return m
96
+
97
+
98
+ # ── Helpers ───────────────────────────────────────────────────────────────────
99
+
100
+ def _percentile(data: list[float], p: float) -> float:
101
+ if not data:
102
+ return 0.0
103
+ s = sorted(data)
104
+ idx = min(int(len(s) * p), len(s) - 1)
105
+ return s[idx]
106
+
107
+
108
+ def _fmt(v: float) -> float:
109
+ """Round to 4dp and clamp to [0, 1]."""
110
+ return round(max(0.0, min(1.0, v)), 4)
benchmark/orchestrator.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/orchestrator.py — Benchmark Orchestrator (Main Controller).
3
+
4
+ Coordinates the full benchmark lifecycle:
5
+ 1. Resolve model + dataset from their registries
6
+ 2. Run all compatibility checks (gates A–E)
7
+ 3. If valid → create a BenchmarkJob in the DB
8
+ 4. Persist the validation audit log
9
+ 5. Enqueue async background task → execution → metrics → storage
10
+ 6. Return the job immediately so callers are non-blocking
11
+
12
+ Public interface used by api/routes/benchmark.py:
13
+ validate_context(ctx) → ValidationReport (no job created)
14
+ create_and_run(ctx) → BenchmarkJob (job queued, execution in background)
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ from datetime import datetime, timezone
20
+
21
+ from benchmark.adapters.registry import get_executor
22
+ from benchmark.compatibility import CompatibilityValidator
23
+ from benchmark.execution import BenchmarkExecutor
24
+ from benchmark.metrics import MetricsEngine
25
+ import benchmark.registry as bench_reg
26
+ from datasets.registry import get_dataset
27
+ from models.benchmark import (
28
+ BenchmarkContext,
29
+ BenchmarkJob,
30
+ BenchmarkMetrics,
31
+ TelemetrySummary,
32
+ ValidationReport,
33
+ )
34
+ from models.dataset import Dataset
35
+ from models.model import Model
36
+ from observability.logger import audit, get_logger
37
+ from registry.registry import get_model
38
+
39
+ log = get_logger("benchmark.orchestrator")
40
+
41
# Shared helper instances — stateless, so one of each serves every request.
_validator = CompatibilityValidator()
_metrics = MetricsEngine()

# Background benchmark tasks that are still running, keyed by job id.
# Entries are kept so that cancellation can be supported in the future.
_active_tasks: dict[str, asyncio.Task] = {}
47
+
48
+
49
+ # ── Public API ────────────────────────────────────────────────────────────────
50
+
51
async def sync_project_benchmarks() -> int:
    """
    Index benchmark jobs and results found in the active project's
    'benchmarks' folder.

    Every ``job_*.json`` / ``result_*.json`` file is inserted into
    ``benchmark_jobs`` / ``benchmark_results`` with INSERT OR IGNORE, so rows
    whose id already exists are left untouched. A file that cannot be read
    or inserted is logged and skipped; it does not abort the sync.

    Returns:
        Number of job/result files processed without error. Files whose row
        already existed are included in the count.
    """
    from benchmark.registry import _get_active_project_benchmark_dir_sync
    from projects.service import get_active_project_path
    import json
    from database.connection import get_db

    project_path = await get_active_project_path()
    benchmark_dir = _get_active_project_benchmark_dir_sync(project_path)
    if not benchmark_dir or not benchmark_dir.exists():
        return 0

    db = await get_db()
    count = 0

    # Sorted for a deterministic order; it also places "job_*" files ahead of
    # "result_*" files, so a job is inserted before the results that refer to it.
    for file_path in sorted(benchmark_dir.glob("*.json")):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Fallback timestamp for files that lack one.
            now_iso = datetime.now(timezone.utc).isoformat()

            # The file name prefix tells jobs and results apart.
            if file_path.name.startswith("job_"):
                # Insert into benchmark_jobs unless the id is already present
                await db.execute(
                    """INSERT OR IGNORE INTO benchmark_jobs
                    (id, model_id, dataset_id, task, framework, hardware,
                    precision, batch_size, config, status, progress, created_at, updated_at, started_at)
                    VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                    (
                        data["id"], data["model_id"], data["dataset_id"],
                        data["task"], data["framework"], data["hardware"],
                        data["precision"], data["batch_size"],
                        json.dumps(data["config"]), data["status"],
                        data.get("progress", 0.0),
                        data.get("created_at", now_iso),
                        data.get("updated_at", now_iso),
                        data.get("started_at"),
                    ),
                )
                count += 1
            elif file_path.name.startswith("result_"):
                # Insert into benchmark_results unless the id is already present
                await db.execute(
                    """INSERT OR IGNORE INTO benchmark_results
                    (id, job_id, metrics, telemetry_summary, created_at)
                    VALUES (?,?,?,?,?)""",
                    (
                        data["id"], data["job_id"],
                        json.dumps(data["metrics"]),
                        json.dumps(data["telemetry_summary"]),
                        data.get("created_at", now_iso),
                    ),
                )
                count += 1
        except Exception as e:
            # Best-effort sync: one malformed file must not block the rest.
            log.error("sync_file_failed", file=file_path.name, error=str(e))

    await db.commit()
    log.info("sync_complete", count=count)
    return count
115
+
116
async def validate_context(ctx: BenchmarkContext) -> ValidationReport:
    """
    Validate model ↔ dataset ↔ hardware compatibility.
    Does NOT create a job. Safe to call repeatedly from the UI.

    Raises:
        HTTPException: 404 when the model or the dataset cannot be found.
    """
    model = await _require_model(ctx.model_id)
    dataset = await _resolve_dataset(ctx, model)
    return _validator.validate(model, dataset, ctx)


async def create_and_run(ctx: BenchmarkContext) -> BenchmarkJob:
    """
    Full benchmark initiation:
      1. Resolve the model and the dataset (or a placeholder for video/live).
      2. Run the compatibility checks and persist the validation log.
      3. Create the job record.
      4. Schedule execution as a background asyncio task.

    Returns:
        The newly created job; execution continues in the background.

    Raises:
        HTTPException: 404 when the model or dataset cannot be found, 422
            (with the failed checks in ``detail``) when validation fails.
    """
    model = await _require_model(ctx.model_id)
    dataset = await _resolve_dataset(ctx, model)

    # ── Compatibility check ───────────────────────────────────────────────────
    report = _validator.validate(model, dataset, ctx)

    # Always persist the validation log (even for failures)
    await bench_reg.save_validation_log(
        job_id="pre-check",
        model_id=ctx.model_id,
        dataset_id=ctx.dataset_id,
        checks=report.checks,
        passed=report.passed,
    )

    if not report.passed:
        from fastapi import HTTPException
        failed = [c for c in report.checks if not c.passed]
        raise HTTPException(
            status_code=422,
            detail={
                "error": "Compatibility validation failed",
                "failed_checks": [
                    {
                        "name": c.name,
                        "detail": c.detail,
                        "suggestion": c.suggestion,
                    }
                    for c in failed
                ],
            },
        )

    # ── Create job ────────────────────────────────────────────────────────────
    job = await bench_reg.create_job(ctx)

    # Record the validation log under the real job_id.
    # NOTE(review): whether this replaces the 'pre-check' entry or adds a
    # second one depends on save_validation_log — confirm.
    await bench_reg.save_validation_log(
        job_id=job.id,
        model_id=ctx.model_id,
        dataset_id=ctx.dataset_id,
        checks=report.checks,
        passed=True,
    )

    # ── Log the polymorphic input params ─────────────────────────────────────
    if ctx.input_source or ctx.video_path or ctx.rtsp_url:
        log.info(
            "polymorphic_input_received",
            job_id=job.id,
            source=ctx.input_source,
            video=ctx.video_path,
            rtsp=ctx.rtsp_url,
        )

    # ── Enqueue background execution ──────────────────────────────────────────
    task = asyncio.create_task(
        _execute_job(job.id, ctx, model, dataset),
        name=f"benchmark_{job.id}",
    )
    # Keep a reference until the task finishes, then drop it.
    _active_tasks[job.id] = task
    task.add_done_callback(lambda _t: _active_tasks.pop(job.id, None))

    log.info("benchmark_enqueued", job_id=job.id, model=ctx.model_id)
    return job


async def _resolve_dataset(ctx: BenchmarkContext, model: Model) -> Dataset:
    """Return the dataset a benchmark context refers to.

    Video/live input (or a dataset id of "none") has no stored dataset, so a
    synthetic placeholder is built whose task mirrors the model's task; this
    lets the task-compatibility check pass. Otherwise the dataset is loaded
    through _require_dataset, which raises a 404 HTTPException when missing.
    """
    if ctx.input_source in ["video", "live"] or ctx.dataset_id == "none":
        now = datetime.now(timezone.utc).isoformat()
        return Dataset(
            id="none",
            name="Live/Video Stream",
            task=model.task,  # match model task to pass task check
            format="custom",
            source="local",
            status="imported",
            images=0,
            classes=0,
            size_label="0 MB",
            created_at=now,
            updated_at=now,
        )
    return await _require_dataset(ctx.dataset_id)
231
+
232
+
233
+ # ── Background execution ──────────────────────────────────────────────────────
234
+
235
async def _execute_job(
    job_id: str,
    ctx: BenchmarkContext,
    model: Model,
    dataset: Dataset,
) -> None:
    """Full benchmark lifecycle — runs in an asyncio background task.

    Marks the job as running, executes it, computes and stores the metrics,
    then marks the job completed. Any failure is recorded on the job as
    status "failed" and audited; it is not re-raised. Cancellation is also
    recorded as "failed" and then re-raised so the task ends as cancelled.
    """
    now = datetime.now(timezone.utc).isoformat()

    # ANSI colour codes used in the job's log entries.
    ts_color = "\x1b[36m"       # Cyan
    info_color = "\x1b[34m"     # Blue
    success_color = "\x1b[32m"  # Green
    error_color = "\x1b[31m"    # Red
    reset = "\x1b[0m"

    # Transition → running
    await bench_reg.update_job(
        job_id,
        status="running",
        progress=0.0,
        started_at=now,
        log_entry=f"{ts_color}[{now}]{reset} {info_color}Job started{reset} on {ctx.hardware} ({ctx.precision})",
    )

    runner = BenchmarkExecutor()

    try:
        # ── Fetch the persisted job (for executor) ────────────────────────────
        job = await bench_reg.get_job(job_id)
        if job is None:
            # Raised explicitly rather than asserted: asserts are stripped
            # when Python runs with -O.
            raise RuntimeError("Job disappeared from DB after creation")

        # ── Define progress callback ──────────────────────────────────────────
        async def on_progress(progress: float, message: str, telemetry: object | None):
            # Telemetry objects exposing model_dump() are serialised to a
            # dict; anything else (including None) is passed through as is.
            await bench_reg.update_job(
                job_id,
                progress=progress,
                log_entry=f"{ts_color}[{datetime.now(timezone.utc).isoformat()}]{reset} {info_color}{message}{reset}",
                last_telemetry=telemetry.model_dump() if telemetry and hasattr(telemetry, "model_dump") else telemetry,
            )

        # ── Execution loop ────────────────────────────────────────────────────
        exec_result = await runner.execute(
            job=job,
            model=model,
            dataset=dataset,
            on_progress=on_progress,
        )

        # ── Compute metrics ───────────────────────────────────────────────────
        metrics = _metrics.compute(
            task=ctx.task,
            latencies_ms=exec_result.latencies_ms,
            total_images=exec_result.total_images,
            batch_size=ctx.batch_size,
            vram_samples=exec_result.vram_samples,
            task_scores=exec_result.task_scores,
        )

        # ── Persist result ────────────────────────────────────────────────────
        await bench_reg.save_result(
            job_id=job_id,
            metrics=metrics,
            telemetry_summary=exec_result.telemetry_summary,
        )

        ended = datetime.now(timezone.utc).isoformat()
        await bench_reg.update_job(
            job_id,
            status="completed",
            progress=1.0,
            ended_at=ended,
            log_entry=f"{ts_color}[{ended}]{reset} {success_color}Benchmark completed{reset} — {metrics.fps} FPS",
        )

        await audit(
            "benchmark_completed",
            job_id=job_id,
            payload={"model_id": ctx.model_id, "dataset_id": ctx.dataset_id},
        )
        log.info(
            "benchmark_completed",
            job_id=job_id,
            fps=metrics.fps,
            lat_ms=metrics.latency_mean_ms,
        )

    except asyncio.CancelledError:
        # Task cancelled externally (e.g. server shutdown) — don't swallow
        ended = datetime.now(timezone.utc).isoformat()
        await bench_reg.update_job(
            job_id,
            status="failed",
            error="Job cancelled",
            ended_at=ended,
            log_entry=f"{ts_color}[{ended}]{reset} {error_color}Job cancelled{reset}",
        )
        raise

    except Exception as exc:
        # Top-level boundary of the background task: record the failure on
        # the job instead of letting it surface as an unretrieved exception.
        ended = datetime.now(timezone.utc).isoformat()
        err_msg = str(exc)
        await bench_reg.update_job(
            job_id,
            status="failed",
            error=err_msg,
            ended_at=ended,
            log_entry=f"{ts_color}[{ended}]{reset} {error_color}ERROR: {err_msg}{reset}",
        )
        await audit(
            "benchmark_failed",
            job_id=job_id,
            level="error",
            payload={"error": err_msg, "model_id": ctx.model_id},
        )
        log.exception("benchmark_failed", job_id=job_id)
352
+
353
+ # ── Resource resolvers ────────────────────────────────────────────────────────
354
+
355
async def _require_model(model_id: str) -> Model:
    """Load a model from the registry, or raise a 404 HTTPException if it is missing."""
    found = await get_model(model_id)
    if found:
        return found

    from fastapi import HTTPException
    raise HTTPException(
        status_code=404,
        detail=f"Model '{model_id}' not found in Model Zoo",
    )
364
+
365
+
366
async def _require_dataset(dataset_id: str) -> Dataset:
    """Load a dataset from the registry, or raise a 404 HTTPException if it is missing."""
    found = await get_dataset(dataset_id)
    if found:
        return found

    from fastapi import HTTPException
    raise HTTPException(
        status_code=404,
        detail=f"Dataset '{dataset_id}' not found in Dataset Manager",
    )
benchmark/registry.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/registry.py — Benchmark Registry.
3
+
4
+ All DB interactions for:
5
+ • benchmark_jobs — job lifecycle state
6
+ • benchmark_results — final metrics + telemetry summary
7
+ • benchmark_validation_logs — immutable check audit trail
8
+
9
+ Follows the same pattern as registry/registry.py and datasets/registry.py.
10
+ No direct DB access from other benchmark modules — everything routes here.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import uuid
16
+ from datetime import datetime, timezone
17
+ from typing import Any
18
+ from pathlib import Path
19
+
20
+ from database.connection import get_db
21
+ from models.benchmark import (
22
+ BenchmarkContext,
23
+ BenchmarkJob,
24
+ BenchmarkMetrics,
25
+ BenchmarkResult,
26
+ TelemetrySummary,
27
+ ValidationCheck,
28
+ row_to_job,
29
+ row_to_result,
30
+ )
31
+ from observability.logger import get_logger
32
+
33
+ log = get_logger("benchmark.registry")
34
+
35
+
36
+ def _get_active_project_benchmark_dir_sync(project_path: str | None) -> Path | None:
37
+ """Get the absolute path to the 'benchmarks' folder in a given project path."""
38
+ if not project_path:
39
+ return None
40
+
41
+ benchmark_dir = Path(project_path) / "benchmarks"
42
+ benchmark_dir.mkdir(parents=True, exist_ok=True)
43
+ return benchmark_dir
44
+
45
async def _get_active_project_benchmark_dir() -> Path | None:
    """Return the ``benchmarks`` folder of the active project.

    The active project path comes from ``projects.service``; the folder
    itself is resolved (and created) by
    ``_get_active_project_benchmark_dir_sync``.

    Returns:
        The benchmark directory, or ``None`` when the sync helper returns
        ``None`` (i.e. the active project path is falsy).
    """
    # Imported inside the function, exactly as the module originally did.
    from projects.service import get_active_project_path

    return _get_active_project_benchmark_dir_sync(await get_active_project_path())
50
+
51
async def _save_to_project(filename: str, data: dict) -> None:
    """Write ``data`` as JSON into the active project's benchmark folder.

    This is best-effort persistence: any failure is logged and swallowed so
    that a broken project folder never fails the database operation that
    triggered it.

    Args:
        filename: Name of the file to create inside the benchmark folder.
        data: JSON-serialisable mapping to store.
    """
    benchmark_dir = await _get_active_project_benchmark_dir()
    if not benchmark_dir:
        return

    file_path = benchmark_dir / filename
    # Write to a sibling temp file and rename it over the target so readers
    # never observe a half-written JSON document if serialisation or the
    # write fails midway.
    tmp_path = file_path.with_name(f"{file_path.name}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        tmp_path.replace(file_path)
    except Exception as e:
        log.error("project_persistence_failed", error=str(e), file=filename)
        # Do not leave the partial temp file behind.
        try:
            tmp_path.unlink()
        except OSError:
            pass
63
+
64
+ # ── Job CRUD ──────────────────────────────────────────────────────────────────
65
+
66
async def create_job(ctx: BenchmarkContext) -> BenchmarkJob:
    """Create a benchmark job in the ``queued`` state.

    The job is inserted into ``benchmark_jobs``, committed, and then mirrored
    to ``job_<id>.json`` in the active project's benchmark folder.

    Args:
        ctx: Benchmark context supplying model, dataset, task, framework,
            hardware, precision and batch size; its full dump is stored as
            the job config.

    Returns:
        The newly created ``BenchmarkJob``.
    """
    db = await get_db()
    job_id = f"bmark-{uuid.uuid4().hex[:12]}"
    now = datetime.now(timezone.utc).isoformat()

    # In-memory representation returned to the caller.
    job = BenchmarkJob(
        id=job_id,
        model_id=ctx.model_id,
        dataset_id=ctx.dataset_id,
        task=ctx.task,
        framework=ctx.framework,
        hardware=ctx.hardware,
        precision=ctx.precision,
        batch_size=ctx.batch_size,
        config=ctx.model_dump(),
        status="queued",
        progress=0.0,
        created_at=now,
        updated_at=now,
    )

    # SQLite row: status, progress and the empty log array are literals in
    # the statement; created_at and updated_at both receive ``now``.
    insert_sql = """INSERT INTO benchmark_jobs
        (id, model_id, dataset_id, task, framework, hardware,
        precision, batch_size, config,
        status, progress, logs, created_at, updated_at)
        VALUES (?,?,?,?,?,?,?,?,?,'queued',0.0,'[]',?,?)"""
    row_values = (
        job_id,
        ctx.model_id,
        ctx.dataset_id,
        ctx.task,
        ctx.framework,
        ctx.hardware,
        ctx.precision,
        ctx.batch_size,
        json.dumps(ctx.model_dump()),
        now,
        now,
    )
    await db.execute(insert_sql, row_values)
    await db.commit()

    # Mirror the job into the project folder.
    await _save_to_project(f"job_{job_id}.json", job.model_dump())

    log.info("benchmark_job_created", job_id=job_id, model=ctx.model_id)
    return job
111
+
112
+
113
async def get_job(job_id: str) -> BenchmarkJob | None:
    """Fetch a single benchmark job by id.

    Args:
        job_id: Primary key in ``benchmark_jobs``.

    Returns:
        The job converted by ``row_to_job``, or ``None`` when no row matches.
    """
    db = await get_db()
    query = "SELECT * FROM benchmark_jobs WHERE id = ?"
    async with db.execute(query, (job_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_job(record)
120
+
121
+
122
async def list_jobs(
    *,
    status: str | None = None,
    model_id: str | None = None,
    limit: int = 100,
) -> list[BenchmarkJob]:
    """List benchmark jobs, newest first.

    Args:
        status: When truthy, only jobs with this status are returned.
        model_id: When truthy, only jobs for this model are returned.
        limit: Maximum number of rows to return.

    Returns:
        Jobs ordered by ``created_at`` descending.
    """
    db = await get_db()

    # Keep only the filters the caller actually supplied.
    candidates = (("status = ?", status), ("model_id = ?", model_id))
    active = [(clause, value) for clause, value in candidates if value]

    where = ""
    if active:
        where = "WHERE " + " AND ".join(clause for clause, _ in active)
    params: list[Any] = [value for _, value in active]
    params.append(limit)

    async with db.execute(
        f"SELECT * FROM benchmark_jobs {where} ORDER BY created_at DESC LIMIT ?",
        params,
    ) as cursor:
        records = await cursor.fetchall()
    return [row_to_job(record) for record in records]
148
+
149
+
150
async def update_job(
    job_id: str,
    *,
    status: str | None = None,
    progress: float | None = None,
    error: str | None = None,
    started_at: str | None = None,
    ended_at: str | None = None,
    log_entry: str | None = None,
    last_telemetry: dict | None = None,
) -> None:
    """Update mutable fields on a benchmark job.

    Only the arguments that are not ``None`` are written; ``updated_at`` is
    always refreshed. After the commit the job row is re-read and mirrored to
    ``job_<id>.json`` in the active project's benchmark folder.

    Args:
        job_id: Primary key of the job to update.
        status: New job status.
        progress: New progress value; stored rounded to 4 decimals.
        error: Error message to record.
        started_at: Start timestamp string.
        ended_at: End timestamp string.
        log_entry: Line appended to the job's JSON log array, which is capped
            at the most recent 500 entries.
        last_telemetry: Mapping stored as JSON in ``last_telemetry``.
    """
    db = await get_db()
    now = datetime.now(timezone.utc).isoformat()

    sets: list[str] = ["updated_at = ?"]
    vals: list[Any] = [now]

    if status is not None:
        sets.append("status = ?")
        vals.append(status)
    if progress is not None:
        sets.append("progress = ?")
        vals.append(round(progress, 4))
    if error is not None:
        sets.append("error = ?")
        vals.append(error)
    if started_at is not None:
        sets.append("started_at = ?")
        vals.append(started_at)
    if ended_at is not None:
        sets.append("ended_at = ?")
        vals.append(ended_at)
    if last_telemetry is not None:
        sets.append("last_telemetry = ?")
        vals.append(json.dumps(last_telemetry))

    if log_entry is not None:
        # Append new entry to the JSON log array (capped at 500 lines).
        # NOTE(review): this is a read-modify-write; two coroutines logging to
        # the same job concurrently could lose an entry — confirm whether
        # callers serialise updates per job.
        async with db.execute(
            "SELECT logs FROM benchmark_jobs WHERE id = ?", (job_id,)
        ) as cur:
            row = await cur.fetchone()
        existing = json.loads(row["logs"]) if row and row["logs"] else []
        existing.append(log_entry)
        sets.append("logs = ?")
        vals.append(json.dumps(existing[-500:]))

    vals.append(job_id)

    # Apply the collected assignments. The column names in ``sets`` are fixed
    # literals defined above; every value is bound as a parameter.
    await db.execute(
        f"UPDATE benchmark_jobs SET {', '.join(sets)} WHERE id = ?",
        vals,
    )
    await db.commit()

    # Mirror the committed state into the project folder.
    async with db.execute(
        "SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,)
    ) as cur:
        row = await cur.fetchone()
    if row:
        job = row_to_job(row)
        if job:
            await _save_to_project(f"job_{job_id}.json", job.model_dump())
202
+
203
+
204
+ # ── Result CRUD ───────────────────────────────────────────────────────────────
205
+
206
async def save_result(
    *,
    job_id: str,
    metrics: BenchmarkMetrics,
    telemetry_summary: TelemetrySummary,
) -> BenchmarkResult:
    """Store the final result of a benchmark job.

    The result is inserted into ``benchmark_results``, committed, and then
    mirrored to ``result_<job_id>.json`` in the active project's benchmark
    folder.

    Args:
        job_id: Job the result belongs to.
        metrics: Final metrics; ``None`` fields are omitted from the stored
            JSON.
        telemetry_summary: Aggregated telemetry for the run.

    Returns:
        The ``BenchmarkResult`` describing the stored row.
    """
    db = await get_db()
    result_id = f"bres-{uuid.uuid4().hex[:12]}"
    now = datetime.now(timezone.utc).isoformat()

    # SQLite row.
    insert_sql = """INSERT INTO benchmark_results
        (id, job_id, metrics, telemetry_summary, created_at)
        VALUES (?,?,?,?,?)"""
    row_values = (
        result_id,
        job_id,
        json.dumps(metrics.model_dump(exclude_none=True)),
        json.dumps(telemetry_summary.model_dump()),
        now,
    )
    await db.execute(insert_sql, row_values)
    await db.commit()

    stored = BenchmarkResult(
        id=result_id,
        job_id=job_id,
        metrics=metrics,
        telemetry_summary=telemetry_summary,
        created_at=now,
    )

    # Project-folder mirror.
    await _save_to_project(f"result_{job_id}.json", stored.model_dump())

    log.info("benchmark_result_saved", job_id=job_id, result_id=result_id)
    return stored
244
+
245
+
246
async def get_result(job_id: str) -> BenchmarkResult | None:
    """Fetch the stored result of a job, joined with its job columns.

    Args:
        job_id: Job whose result is wanted.

    Returns:
        The result converted by ``row_to_result``, or ``None`` when the job
        has no result row.
    """
    db = await get_db()
    query = """SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
        FROM benchmark_results r
        JOIN benchmark_jobs j ON r.job_id = j.id
        WHERE r.job_id = ?"""
    async with db.execute(query, (job_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_result(record)
256
+
257
+
258
async def list_results(*, limit: int = 100) -> list[BenchmarkResult]:
    """List stored benchmark results, newest first.

    Args:
        limit: Maximum number of rows to return.

    Returns:
        Results (each joined with its job columns) ordered by the result's
        ``created_at`` descending.
    """
    db = await get_db()
    query = """SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision
        FROM benchmark_results r
        JOIN benchmark_jobs j ON r.job_id = j.id
        ORDER BY r.created_at DESC LIMIT ?"""
    async with db.execute(query, (limit,)) as cursor:
        records = await cursor.fetchall()
    return [row_to_result(record) for record in records]
268
+
269
+
270
+ # ── Validation Log ────────────────────────────────────────────────────────────
271
+
272
async def save_validation_log(
    *,
    job_id: str,
    model_id: str,
    dataset_id: str,
    checks: list[ValidationCheck],
    passed: bool,
) -> None:
    """Record the compatibility checks that were run for a job.

    Args:
        job_id: Job the checks belong to.
        model_id: Model that was validated.
        dataset_id: Dataset that was validated.
        checks: Individual check outcomes; each is dumped to JSON.
        passed: Overall verdict, stored as 1 or 0.
    """
    db = await get_db()
    log_id = f"bval-{uuid.uuid4().hex[:12]}"
    now = datetime.now(timezone.utc).isoformat()

    serialised_checks = json.dumps([check.model_dump() for check in checks])
    passed_flag = 1 if passed else 0

    insert_sql = """INSERT INTO benchmark_validation_logs
        (id, job_id, model_id, dataset_id, checks, passed, created_at)
        VALUES (?,?,?,?,?,?,?)"""
    await db.execute(
        insert_sql,
        (log_id, job_id, model_id, dataset_id, serialised_checks, passed_flag, now),
    )
    await db.commit()

    log.info(
        "validation_log_saved",
        job_id=job_id,
        passed=passed,
        n_checks=len(checks),
    )
benchmark/telemetry.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/telemetry.py — Real-time Telemetry Collector.
3
+
4
+ Collects GPU/hardware metrics at 2 Hz during benchmark execution.
5
+ Designed as a drop-in adapter:
6
+ • Local dev → simulates realistic GPU readings based on hardware tier
7
+ • Production → replace _read_gpu_metrics() with pynvml calls:
8
+ nvmlDeviceGetUtilizationRates()
9
+ nvmlDeviceGetMemoryInfo()
10
+ nvmlDeviceGetTemperature()
11
+ nvmlDeviceGetPowerUsage()
12
+
13
+ Usage (async context):
14
+ collector = TelemetryCollector("rtx4090", vram_total_gb=24.0)
15
+ await collector.start()
16
+ # ... run inference ...
17
+ summary = await collector.stop()
18
+ samples = collector.samples
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import asyncio
23
+ import random
24
+ import statistics
25
+ import time
26
+
27
+ from models.benchmark import TelemetrySample, TelemetrySummary
28
+ from observability.logger import get_logger
29
+
30
+ log = get_logger("benchmark.telemetry")
31
+
32
+ # ── Hardware simulation profiles ──────────────────────────────────────────────
33
+ # (base_util%, base_temp_C, base_power_W)
34
+ _HW_PROFILES: dict[str, tuple[float, float, float]] = {
35
+ "rtx4090": (88.0, 74.0, 380.0),
36
+ "rtx4080": (84.0, 70.0, 280.0),
37
+ "rtx4070": (80.0, 68.0, 200.0),
38
+ "rtx3090": (85.0, 72.0, 320.0),
39
+ "rtx3080": (82.0, 70.0, 250.0),
40
+ "rtx3070": (78.0, 66.0, 180.0),
41
+ "rtx3060": (74.0, 64.0, 150.0),
42
+ "a100": (90.0, 68.0, 350.0),
43
+ "h100": (92.0, 65.0, 550.0),
44
+ "v100": (87.0, 70.0, 280.0),
45
+ "t4": (75.0, 62.0, 60.0),
46
+ "gpu": (70.0, 65.0, 150.0),
47
+ "cpu": (0.0, 0.0, 0.0),
48
+ }
49
+
50
+ _COLLECTION_INTERVAL_S = 0.5 # 2 Hz
51
+
52
+
53
class TelemetryCollector:
    """
    Async telemetry collector. Call start() before inference, stop() after.

    Sampling runs as an asyncio task on the current event loop; readings are
    simulated from the hardware profile resolved in ``__init__``.
    """

    def __init__(self, hardware: str, vram_total_gb: float = 8.0) -> None:
        """
        Args:
            hardware: Free-form hardware name used to pick a simulation profile.
            vram_total_gb: Total VRAM reported in every GPU sample.
        """
        self._hardware = hardware
        self._vram_total = vram_total_gb
        self._hw_profile = self._resolve_profile(hardware)
        self._samples: list[TelemetrySample] = []
        self._running = False
        self._task: asyncio.Task | None = None

    # ── Public API ────────────────────────────────────────────────────────────

    async def start(self) -> None:
        """Begin (or restart) background sampling with an empty sample list."""
        # A second start() without stop() used to leave the previous loop
        # alive, so two loops appended samples at once. Cancel it first.
        if self._task is not None and not self._task.done():
            self._task.cancel()
        self._running = True
        self._samples = []
        self._task = asyncio.create_task(
            self._collect_loop(), name="telemetry_collector"
        )
        log.debug("telemetry_started", hardware=self._hardware)

    async def stop(self) -> TelemetrySummary:
        """Stop sampling and return the summary of everything collected."""
        self._running = False
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                # Expected: we cancelled the sampling task ourselves.
                pass
        log.debug(
            "telemetry_stopped",
            hardware=self._hardware,
            samples=len(self._samples),
        )
        return self._build_summary()

    def record_batch_context(self, batch_idx: int, progress: float) -> None:
        """Annotate the most recent sample with the current batch context."""
        if self._samples:
            self._samples[-1].batch_idx = batch_idx
            self._samples[-1].progress = progress

    @property
    def samples(self) -> list[TelemetrySample]:
        """A shallow copy of the samples collected so far."""
        return list(self._samples)

    # ── Internal ──────────────────────────────────────────────────────────────

    async def _collect_loop(self) -> None:
        """Append one sample per collection interval until stopped."""
        while self._running:
            sample = self._read_gpu_metrics()
            self._samples.append(sample)
            await asyncio.sleep(_COLLECTION_INTERVAL_S)

    def _read_gpu_metrics(self) -> TelemetrySample:
        """
        Returns a TelemetrySample for the current hardware state.

        PRODUCTION SWAP: Replace this body with pynvml calls:
            handle = nvmlDeviceGetHandleByIndex(0)
            util   = nvmlDeviceGetUtilizationRates(handle)
            mem    = nvmlDeviceGetMemoryInfo(handle)
            temp   = nvmlDeviceGetTemperature(handle, NVML_TEMPERATURE_GPU)
            power  = nvmlDeviceGetPowerUsage(handle) / 1000  # mW → W
        """
        base_util, base_temp, base_power = self._hw_profile

        if base_util == 0.0:  # CPU path — no meaningful GPU readings
            return TelemetrySample(
                timestamp=time.time(),
                gpu_util_pct=0.0,
                vram_used_gb=0.0,
                vram_total_gb=0.0,
                temp_c=0.0,
                power_w=0.0,
            )

        # Simulated Gaussian jitter (sigma: 3 % util, 1.5 °C, 8 W) plus a
        # VRAM occupancy drawn uniformly between 58 % and 72 % of the total.
        jitter_util = random.gauss(0, 3.0)
        jitter_temp = random.gauss(0, 1.5)
        jitter_power = random.gauss(0, 8.0)
        vram_frac = random.uniform(0.58, 0.72)

        return TelemetrySample(
            timestamp=time.time(),
            gpu_util_pct=max(0.0, min(100.0, base_util + jitter_util)),
            vram_used_gb=round(
                max(0.0, min(self._vram_total, self._vram_total * vram_frac)), 3
            ),
            vram_total_gb=self._vram_total,
            temp_c=round(max(0.0, base_temp + jitter_temp), 1),
            power_w=round(max(0.0, base_power + jitter_power), 1),
        )

    def _build_summary(self) -> TelemetrySummary:
        """Aggregate the collected samples into average and peak values."""
        if not self._samples:
            return TelemetrySummary()

        utils = [s.gpu_util_pct for s in self._samples]
        vrams = [s.vram_used_gb for s in self._samples]
        temps = [s.temp_c for s in self._samples]
        powers = [s.power_w for s in self._samples]

        # The lists are non-empty here, so mean()/max() cannot fail on them.
        return TelemetrySummary(
            gpu_util_avg=round(statistics.mean(utils), 2),
            gpu_util_peak=round(max(utils), 2),
            vram_avg_gb=round(statistics.mean(vrams), 3),
            vram_peak_gb=round(max(vrams), 3),
            temp_avg_c=round(statistics.mean(temps), 1),
            temp_peak_c=round(max(temps), 1),
            power_avg_w=round(statistics.mean(powers), 1),
            power_peak_w=round(max(powers), 1),
        )

    @staticmethod
    def _resolve_profile(hardware: str) -> tuple[float, float, float]:
        """Pick the simulation profile whose key occurs in the hardware name.

        The name is lower-cased with spaces and hyphens removed, then matched
        against the ``_HW_PROFILES`` keys in declaration order.
        """
        hw = hardware.lower().replace(" ", "").replace("-", "")
        for key, profile in _HW_PROFILES.items():
            if key in hw:
                return profile
        # Default for unknown GPU-class hardware
        if any(x in hw for x in ("gpu", "rtx", "gtx", "cuda", "vram")):
            return _HW_PROFILES["gpu"]
        return _HW_PROFILES["cpu"]
benchmark/torch_runner.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ benchmark/torch_runner.py — Synchronous GPU inference runner.
3
+
4
+ Called from BenchmarkExecutor via asyncio.run_in_executor() so it never
5
+ blocks the event loop. PyTorch is an optional dependency — if it is not
6
+ installed the module raises ImportError and execution.py falls back to
7
+ the simulation path.
8
+
9
+ Supported weight formats (detected by file extension):
10
+ .pt / .pth — torch.jit.load, falling back to torch.load for fully pickled modules (bare state-dicts are rejected)
11
+ .safetensors — not supported yet (the loader raises NotImplementedError)
12
+ .onnx — onnxruntime InferenceSession
13
+
14
+ PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>>
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import time
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ # ── Model cache (keyed by absolute path) ─────────────────────────────────────
23
+ _MODEL_CACHE: dict[str, Any] = {}
24
+
25
+ # Standard input shapes per task (C, H, W); the batch dimension is prepended at call time
26
+ _INPUT_SHAPES: dict[str, tuple[int, int, int]] = {
27
+ "detection": (3, 640, 640),
28
+ "segmentation": (3, 640, 640),
29
+ "classification": (3, 224, 224),
30
+ "generation": (3, 512, 512),
31
+ "embedding": (3, 224, 224),
32
+ }
33
+ _DEFAULT_SHAPE = (3, 640, 640)
34
+
35
+
36
def run_torch_batch(model_path: str, batch_size: int, task: str = "detection") -> float:
    """Run one inference batch and return per-image latency in ms.

    Args:
        model_path: Absolute path to the weight file.
        batch_size: Number of images in the batch; must be at least 1.
        task: Model task (affects dummy input shape).

    Returns:
        Latency per image in milliseconds.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        ImportError: If PyTorch is not installed.
    """
    # Reject unusable batch sizes before any model is loaded; previously 0
    # ended in a ZeroDivisionError after the forward pass had already run.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    import torch  # raises ImportError if not installed

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ext = Path(model_path).suffix.lower()

    model = _load_model(model_path, ext, device)
    c, h, w = _INPUT_SHAPES.get(task, _DEFAULT_SHAPE)
    dummy = torch.zeros(batch_size, c, h, w, device=device)

    on_cuda = device == "cuda"

    # Warm-up pass (first call is slower due to CUDA kernel compilation)
    if on_cuda:
        with torch.no_grad():
            _forward(model, dummy, ext, device)
        torch.cuda.synchronize()

    # Timed pass. CUDA work is asynchronous, so synchronise on both sides of
    # the measurement to time the kernels rather than just their launch.
    if on_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        _forward(model, dummy, ext, device)
    if on_cuda:
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - t0) * 1000

    return elapsed_ms / batch_size
73
+
74
+
75
def _load_model(path: str, ext: str, device: str) -> Any:
    """Return the model for ``path``, loading it on first use.

    Loaded models are kept in ``_MODEL_CACHE`` keyed by ``path`` alone, so a
    later call with the same path returns the cached object whatever ``ext``
    and ``device`` are.
    """
    try:
        return _MODEL_CACHE[path]
    except KeyError:
        loaded = _load_by_ext(path, ext, device)
        _MODEL_CACHE[path] = loaded
        return loaded
83
+
84
+
85
+ def _load_by_ext(path: str, ext: str, device: str) -> Any:
86
+ """Select loader based on file extension."""
87
+ if ext in (".pt", ".pth"):
88
+ return _load_torch(path, device)
89
+ if ext == ".safetensors":
90
+ return _load_safetensors(path, device)
91
+ if ext == ".onnx":
92
+ return _load_onnx(path)
93
+ raise ValueError(f"Unsupported model format: {ext}")
94
+
95
+
96
+ def _load_torch(path: str, device: str) -> Any:
97
+ import torch
98
+ # <<< REPLACE IN PRODUCTION >>> with proper model class instantiation
99
+ # TorchScript models can be loaded directly; state-dict models need
100
+ # the model class to be imported separately.
101
+ try:
102
+ model = torch.jit.load(path, map_location=device)
103
+ model.eval()
104
+ return model
105
+ except RuntimeError:
106
+ # Not a TorchScript model — try loading as a full checkpoint
107
+ obj = torch.load(path, map_location=device, weights_only=False)
108
+ if hasattr(obj, "eval"):
109
+ obj.eval()
110
+ return obj
111
+ # It's a state-dict — we cannot run inference without knowing the arch
112
+ raise RuntimeError(
113
+ f"Model at {path} is a state-dict; cannot run inference without "
114
+ "the model class. Use a TorchScript-exported .pt file."
115
+ )
116
+
117
+
118
+ def _load_safetensors(path: str, device: str) -> Any:
119
+ # <<< REPLACE IN PRODUCTION >>> safetensors gives tensors only;
120
+ # you still need the model class. This is intentionally left as a
121
+ # placeholder that raises a clear error rather than silently failing.
122
+ raise NotImplementedError(
123
+ "safetensors inference requires the model class to be registered. "
124
+ "Convert to TorchScript or ONNX for architecture-agnostic inference."
125
+ )
126
+
127
+
128
def _load_onnx(path: str) -> Any:
    """Open an ONNX model as an onnxruntime ``InferenceSession``.

    The session is requested with the CUDA execution provider listed first
    and the CPU provider second.
    """
    import onnxruntime as ort  # type: ignore[import]

    return ort.InferenceSession(
        path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
132
+
133
+
134
+ def _forward(model: Any, dummy: Any, ext: str, device: str) -> Any:
135
+ """Run a single forward pass, dispatching by model type."""
136
+ if ext == ".onnx":
137
+ import numpy as np
138
+ np_input = dummy.cpu().numpy()
139
+ input_name = model.get_inputs()[0].name
140
+ return model.run(None, {input_name: np_input})
141
+ # TorchScript / nn.Module
142
+ return model(dummy)
datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # datasets package
datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
datasets/__pycache__/annotation_parser.cpython-310.pyc ADDED
Binary file (15.6 kB). View file
 
datasets/__pycache__/base_adapter.cpython-310.pyc ADDED
Binary file (2.04 kB). View file
 
datasets/__pycache__/format_adapters.cpython-310.pyc ADDED
Binary file (9.18 kB). View file
 
datasets/__pycache__/import_service.cpython-310.pyc ADDED
Binary file (16.8 kB). View file
 
datasets/__pycache__/registry.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
datasets/__pycache__/viewer_service.cpython-310.pyc ADDED
Binary file (8.22 kB). View file
 
datasets/annotation_parser.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ datasets/annotation_parser.py — Multi-format annotation parser.
3
+
4
+ Supports:
5
+ - YOLO (darknet .txt + classes.txt / data.yaml)
6
+ - COCO (instances_*.json / _annotations.coco.json)
7
+ - Pascal VOC (*.xml)
8
+
9
+ All formats normalise to the unified Annotation schema with
10
+ normalised bounding boxes (0–1 range, x_topleft, y_topleft, w, h).
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import csv
15
+ import json
16
+ import re
17
+ import uuid
18
+ import xml.etree.ElementTree as ET
19
+ from pathlib import Path
20
+ from typing import Iterator, Optional
21
+
22
+ from observability.logger import get_logger
23
+
24
+ log = get_logger("annotation_parser")
25
+
26
+
27
+ # ── Unified Output ────────────────────────────────────────────────────────────
28
+
29
+ def _make_ann(
30
+ image_id: str,
31
+ dataset_id: str,
32
+ label: str,
33
+ bbox: tuple[float, float, float, float] | None = None, # x, y, w, h (normalised)
34
+ normalised: bool = True,
35
+ area: float | None = None,
36
+ confidence: float | None = None,
37
+ ann_type: str = "detection",
38
+ segmentation: list[list[float]] | None = None,
39
+ keypoints: list[float] | None = None,
40
+ metadata: dict | None = None,
41
+ ) -> dict:
42
+ return {
43
+ "id": f"ann-{uuid.uuid4().hex[:12]}",
44
+ "image_id": image_id,
45
+ "dataset_id": dataset_id,
46
+ "label": label,
47
+ "bbox_x": bbox[0] if bbox else None,
48
+ "bbox_y": bbox[1] if bbox else None,
49
+ "bbox_w": bbox[2] if bbox else None,
50
+ "bbox_h": bbox[3] if bbox else None,
51
+ "normalised": 1 if normalised else 0,
52
+ "area": area,
53
+ "confidence": confidence,
54
+ "ann_type": ann_type,
55
+ "segmentation": json.dumps(segmentation) if segmentation else None,
56
+ "keypoints": json.dumps(keypoints) if keypoints else None,
57
+ "metadata": json.dumps(metadata) if metadata else None,
58
+ }
59
+
60
+
61
+ # ── YOLO Parser ───────────────────────────────────────────────────────────────
62
+
63
class YOLOParser:
    """
    Reads YOLO darknet annotation files (.txt) + class map.
    Each line: <class_id> <cx> <cy> <w> <h>  (all normalised 0–1)
    """

    # File extensions (lower-case) treated as images when walking a dataset.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

    # Canonical split name → folder names accepted for that split.
    _SPLIT_FOLDERS = {
        "train": ["train", "training"],
        "val": ["valid", "val", "validation"],
        "test": ["test", "testing"],
    }

    @staticmethod
    def load_class_map(dataset_root: Path) -> list[str]:
        """Attempt to load class names from data.yaml or classes.txt.

        Every ``data.yaml`` under ``dataset_root`` is tried first (via PyYAML,
        or a regex fallback when importing or parsing with it fails), then
        every ``classes.txt``.

        Returns:
            The first class list found, or ``[]`` when none could be read.
        """
        # Try data.yaml first
        for yaml_file in dataset_root.rglob("data.yaml"):
            try:
                import yaml

                with open(yaml_file, "r", encoding="utf-8", errors="replace") as f:
                    data = yaml.safe_load(f)
                if data and "names" in data:
                    names = data["names"]
                    if isinstance(names, list):
                        return names
                    if isinstance(names, dict):
                        # Handle dict format: {0: 'class_a', 1: 'class_b'}
                        return [names[i] for i in sorted(names.keys())]
            except Exception:
                # Fallback to regex if yaml import fails or parsing fails.
                # Only the block-list form ("names:" followed by "- item"
                # lines) is recognised here.
                try:
                    text = yaml_file.read_text(encoding="utf-8", errors="replace")
                    m = re.search(r"names\s*:\s*\n((?:\s*-\s*.+\n?)+)", text)
                    if m:
                        # Drop the list marker and any YAML quotes around the
                        # scalar so "- 'cat'" yields "cat".
                        return [
                            line.strip().lstrip("- ").strip().strip("'\"")
                            for line in m.group(1).splitlines()
                            if line.strip()
                        ]
                except Exception as exc:
                    # Best effort: record the failure and try the next source.
                    log.warning(
                        "yolo_class_map_parse_error",
                        file=str(yaml_file),
                        error=str(exc),
                    )

        # Try classes.txt
        for cls_file in dataset_root.rglob("classes.txt"):
            try:
                lines = cls_file.read_text(encoding="utf-8", errors="replace").splitlines()
            except Exception as exc:
                log.warning(
                    "yolo_class_map_parse_error", file=str(cls_file), error=str(exc)
                )
                continue
            return [entry.strip() for entry in lines if entry.strip()]

        return []

    @staticmethod
    def _resolve_label(cls_id: int, class_map: list[str]) -> str:
        """Map a class id to its name, or to the id as text when unknown.

        Negative ids are treated as unknown instead of indexing ``class_map``
        from the end, which would silently pick the wrong class name.
        """
        if 0 <= cls_id < len(class_map):
            return class_map[cls_id]
        return str(cls_id)

    @staticmethod
    def parse_file(
        txt_path: Path,
        image_id: str,
        dataset_id: str,
        class_map: list[str],
    ) -> list[dict]:
        """Parse one YOLO label file into unified annotation dicts.

        Lines with fewer than five fields, or whose first five fields are not
        numeric, are skipped. An unreadable file yields an empty list.
        Fields after the fifth are ignored.
        """
        annotations: list[dict] = []
        try:
            text = txt_path.read_text(encoding="utf-8", errors="replace")
        except OSError:
            return annotations

        for line in text.splitlines():
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            try:
                cls_id = int(parts[0])
                cx, cy, w, h = (float(value) for value in parts[1:5])
            except ValueError:
                continue
            # YOLO cx,cy → top-left x,y
            x = cx - w / 2
            y = cy - h / 2
            label = YOLOParser._resolve_label(cls_id, class_map)
            annotations.append(
                _make_ann(image_id, dataset_id, label, (x, y, w, h), area=w * h)
            )

        return annotations

    @staticmethod
    def _is_image(path: Path) -> bool:
        """True when the path's extension is a recognised image extension."""
        return path.suffix.lower() in YOLOParser._IMAGE_SUFFIXES

    @staticmethod
    def _load_annotations(
        img_path: Path,
        labels_dir: Path | None,
        image_id: str,
        dataset_id: str,
        class_map: list[str],
    ) -> list[dict]:
        """Parse the first existing label file for an image.

        Candidates, in order: ``<labels_dir>/<image stem>.txt`` (only when
        ``labels_dir`` is given and exists), then the ``.txt`` file next to
        the image. Returns ``[]`` when neither exists.
        """
        sibling = img_path.with_suffix(".txt")
        candidates = []
        if labels_dir is not None and labels_dir.exists():
            candidates.append(labels_dir / sibling.name)
        candidates.append(sibling)

        for label_file in candidates:
            if label_file.exists():
                return YOLOParser.parse_file(label_file, image_id, dataset_id, class_map)
        return []

    @staticmethod
    def iter_dataset(
        dataset_root: Path,
        dataset_id: str,
        class_map: list[str],
    ) -> Iterator[tuple[str, str, str, list[dict]]]:
        """
        Yield (image_rel_path, image_id, split, annotations) for every image in the dataset.
        Walks train/valid/test directories; when none of them exists, every
        image under the root is yielded as part of the "train" split.
        """
        found_any = False
        for split_name, folder_names in YOLOParser._SPLIT_FOLDERS.items():
            for folder_name in folder_names:
                split_dir = dataset_root / folder_name
                images_dir = split_dir / "images"

                # Support both split/images and split/ (if images are direct)
                search_dir = images_dir if images_dir.exists() else split_dir
                if not search_dir.exists():
                    continue

                found_any = True
                labels_dir = split_dir / "labels"

                for img_path in sorted(search_dir.rglob("*")):
                    if not YOLOParser._is_image(img_path):
                        continue

                    image_id = f"img-{uuid.uuid4().hex[:12]}"
                    anns = YOLOParser._load_annotations(
                        img_path, labels_dir, image_id, dataset_id, class_map
                    )
                    rel_path = str(img_path.relative_to(dataset_root))
                    yield rel_path, image_id, split_name, anns

        # Fallback: if no split folders found, scan the root
        if not found_any:
            for img_path in sorted(dataset_root.rglob("*")):
                if not YOLOParser._is_image(img_path):
                    continue
                image_id = f"img-{uuid.uuid4().hex[:12]}"
                anns = YOLOParser._load_annotations(
                    img_path, None, image_id, dataset_id, class_map
                )
                rel_path = str(img_path.relative_to(dataset_root))
                yield rel_path, image_id, "train", anns
207
+
208
+
209
+ # ── COCO Parser ───────────────────────────────────────────────────────────────
210
+
211
class COCOParser:
    """
    Reads COCO JSON annotation files.
    Supports: instances_train.json, instances_val.json, _annotations.coco.json
    """

    # Folder names that identify a split when the file name itself does not
    # (e.g. Roboflow exports "<split>/_annotations.coco.json").
    _SPLIT_FOLDERS = {
        "train": "train", "training": "train",
        "val": "val", "valid": "val", "validation": "val",
        "test": "test", "testing": "test",
    }

    @staticmethod
    def find_annotation_files(dataset_root: Path) -> list[Path]:
        """
        Return candidate annotation JSON files below ``dataset_root``.

        Files matching the more specific name patterns come first; files whose
        name contains "label" or "class" are ignored. Each path appears once.
        """
        patterns = ["instances_*.json", "_annotations.coco.json", "*.json"]
        found = []
        for pat in patterns:
            # sorted() keeps the result stable across runs and filesystems.
            for f in sorted(dataset_root.rglob(pat)):
                lowered = f.name.lower()
                if "label" not in lowered and "class" not in lowered:
                    found.append(f)
        return list(dict.fromkeys(found))  # deduplicate, keep first occurrence

    @staticmethod
    def _infer_split(json_path: Path) -> str:
        """Guess the split from the file name, then from the parent folder name."""
        fname = json_path.stem.lower()
        if "train" in fname:
            return "train"
        if "val" in fname:
            return "val"
        if "test" in fname:
            return "test"
        # The file name is neutral; fall back to the containing folder.
        return COCOParser._SPLIT_FOLDERS.get(json_path.parent.name.lower(), "train")

    @staticmethod
    def _normalize_polygons(segmentation, img_w, img_h):
        """
        Scale polygon coordinates to the 0..1 range.

        Returns None when ``segmentation`` is not a non-empty list of
        coordinate lists (for instance an RLE dict).
        """
        if not isinstance(segmentation, list) or not segmentation:
            return None
        polygons = []
        for poly in segmentation:
            if not isinstance(poly, (list, tuple)):
                continue
            scaled = []
            # Stop before a dangling coordinate so an odd-length list cannot
            # raise IndexError.
            for i in range(0, len(poly) - 1, 2):
                scaled.append(poly[i] / img_w)
                scaled.append(poly[i + 1] / img_h)
            polygons.append(scaled)
        return polygons or None

    @staticmethod
    def parse_file(
        json_path: Path,
        dataset_id: str,
    ) -> tuple[list[str], list[tuple[str, str, str, list[dict]]]]:
        """
        Returns: (class_names, [(rel_image_path, image_id, split, annotations)])

        An unreadable file, invalid JSON, or a top-level value that is not a
        JSON object yields ``([], [])`` and a warning in the log.
        """
        try:
            data = json.loads(json_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            log.warning("coco_parse_error", file=str(json_path), error=str(e))
            return [], []
        if not isinstance(data, dict):
            log.warning(
                "coco_parse_error",
                file=str(json_path),
                error="top-level JSON value is not an object",
            )
            return [], []

        categories = {
            c["id"]: c["name"]
            for c in data.get("categories", [])
            if isinstance(c, dict) and "id" in c and "name" in c
        }
        class_names = list(categories.values())
        split = COCOParser._infer_split(json_path)

        # Build image map
        image_map: dict[int, dict] = {
            img["id"]: img
            for img in data.get("images", [])
            if isinstance(img, dict) and "id" in img
        }

        # Group annotations by image
        ann_by_image: dict[int, list] = {}
        for ann in data.get("annotations", []):
            if isinstance(ann, dict) and "image_id" in ann:
                ann_by_image.setdefault(ann["image_id"], []).append(ann)

        results = []
        for coco_img_id, img_meta in image_map.items():
            image_id = f"img-{uuid.uuid4().hex[:12]}"
            rel_path = img_meta.get("file_name", "")
            # A missing or zero size falls back to 1 to avoid dividing by zero.
            img_w = img_meta.get("width", 1) or 1
            img_h = img_meta.get("height", 1) or 1
            anns = []
            for coco_ann in ann_by_image.get(coco_img_id, []):
                label = categories.get(coco_ann.get("category_id", -1), "unknown")
                bbox = coco_ann.get("bbox", [])
                if len(bbox) != 4:
                    continue
                # COCO: [x_topleft, y_topleft, w, h] in pixel coords
                bx = bbox[0] / img_w
                by = bbox[1] / img_h
                bw = bbox[2] / img_w
                bh = bbox[3] / img_h
                area_pct = (bbox[2] * bbox[3]) / (img_w * img_h)

                poly_data = COCOParser._normalize_polygons(
                    coco_ann.get("segmentation"), img_w, img_h
                )
                anns.append(
                    _make_ann(
                        image_id,
                        dataset_id,
                        label,
                        (bx, by, bw, bh),
                        area=area_pct,
                        segmentation=poly_data,
                        ann_type="segmentation" if poly_data else "detection"
                    )
                )
            results.append((rel_path, image_id, split, anns))

        return class_names, results
311
+
312
+
313
+ # ── VOC Parser ────────────────────────────────────────────────────────────────
314
+
315
class VOCParser:
    """Reads Pascal VOC XML annotation files."""

    # Directory names recognised as split markers.
    _SPLIT_FOLDERS = {
        "train": "train", "training": "train",
        "val": "val", "valid": "val", "validation": "val",
        "test": "test", "testing": "test",
    }

    @staticmethod
    def _read_dim(size, tag: str) -> int:
        """Read an integer dimension from <size>; 1 when absent or unparsable."""
        if size is None:
            return 1
        raw = size.findtext(tag)
        if not raw:
            return 1
        try:
            # Some exporters write dimensions as floats ("640.0").
            return int(float(raw))
        except (ValueError, OverflowError):
            return 1

    @staticmethod
    def parse_file(
        xml_path: Path,
        image_id: str,
        dataset_id: str,
    ) -> tuple[str, int, int, list[dict]]:
        """
        Returns (filename, width, height, annotations).

        Boxes are normalised to [x, y, w, h] fractions of the image size.
        A file that cannot be read or parsed yields ``("", 0, 0, [])``.
        """
        try:
            tree = ET.parse(str(xml_path))
        except (ET.ParseError, OSError) as e:
            log.warning("voc_parse_error", file=str(xml_path), error=str(e))
            return "", 0, 0, []

        root = tree.getroot()
        filename = root.findtext("filename") or ""
        size = root.find("size")
        img_w = VOCParser._read_dim(size, "width")
        img_h = VOCParser._read_dim(size, "height")
        # A declared size of 0 must not be used as a divisor.
        div_w = img_w or 1
        div_h = img_h or 1

        anns = []
        for obj in root.findall("object"):
            label = obj.findtext("name") or "unknown"
            bndbox = obj.find("bndbox")
            if bndbox is None:
                continue
            try:
                xmin, ymin, xmax, ymax = (
                    float(bndbox.findtext(tag) or 0)
                    for tag in ("xmin", "ymin", "xmax", "ymax")
                )
            except ValueError:
                # Skip a box with non-numeric coordinates, keep the rest.
                continue
            # Normalise
            bx = xmin / div_w
            by = ymin / div_h
            bw = (xmax - xmin) / div_w
            bh = (ymax - ymin) / div_h
            anns.append(_make_ann(image_id, dataset_id, label, (bx, by, bw, bh)))

        return filename, img_w, img_h, anns

    @staticmethod
    def iter_dataset(
        dataset_root: Path,
        dataset_id: str,
    ) -> Iterator[tuple[str, str, str, int, int, list[dict]]]:
        """
        Yield (rel_path, image_id, split, w, h, annotations).

        The split comes from the first recognised folder name, looking only at
        the dataset folder and below. ``rel_path`` is relative to
        ``dataset_root`` when the image named in the XML sits next to it.
        """
        for xml_path in sorted(dataset_root.rglob("*.xml")):
            image_id = f"img-{uuid.uuid4().hex[:12]}"
            filename, w, h, anns = VOCParser.parse_file(xml_path, image_id, dataset_id)

            # Ignore directories above the dataset so an unrelated ancestor
            # (e.g. /home/test/...) cannot decide the split.
            folders = (dataset_root.name, *xml_path.relative_to(dataset_root).parts[:-1])
            split = "train"
            for part in folders:
                if part in VOCParser._SPLIT_FOLDERS:
                    split = VOCParser._SPLIT_FOLDERS[part]
                    break

            if filename:
                sibling = xml_path.parent / filename
                if sibling.is_file():
                    rel_path = str(sibling.relative_to(dataset_root))
                else:
                    # Image location unknown; keep the name from the XML.
                    rel_path = filename
            else:
                rel_path = str(xml_path.with_suffix(".jpg").relative_to(dataset_root))
            yield rel_path, image_id, split, w, h, anns
375
+
376
+
377
+ # ── Roboflow TXT Parser ───────────────────────────────────────────────────────
378
+
379
class RoboflowTXTParser:
    """
    Reads Roboflow classification TXT formats.
    1. Folder-based: split/class_name/image.jpg
    2. Label-file: split/_annotations.txt (format: filename,class_name)
    """

    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

    @staticmethod
    def _read_label_rows(ann_file: Path):
        """
        Return [(filename, label), ...] from an _annotations.txt file.

        Returns None when the file cannot be read, so the caller can fall back
        to the folder layout. Lines with fewer than two fields are skipped.
        """
        rows = []
        try:
            with open(ann_file, "r", encoding="utf-8") as f:
                # Format is usually: filename,class_name
                for line in f:
                    parts = line.strip().split(",")
                    if len(parts) >= 2:
                        rows.append((parts[0].strip(), parts[1].strip()))
        except (OSError, UnicodeDecodeError) as e:
            log.warning("roboflow_txt_parse_error", file=str(ann_file), error=str(e))
            return None
        return rows

    @staticmethod
    def iter_dataset(
        dataset_root: Path,
        dataset_id: str,
    ) -> Iterator[tuple[str, str, str, list[dict]]]:
        """
        Yield (rel_path, image_id, split, annotations) for every labelled image.

        Each image gets exactly one "classification" annotation. When no split
        folder exists, the whole tree is scanned, the parent folder name is
        used as the label, and the split is "train".
        """
        split_map = {
            "train": ["train", "training"],
            "val": ["valid", "val", "validation"],
            "test": ["test", "testing"]
        }

        found_any = False
        for split_name, folder_names in split_map.items():
            for folder_name in folder_names:
                split_dir = dataset_root / folder_name
                if not split_dir.is_dir():
                    continue

                found_any = True

                # Check for _annotations.txt (Roboflow's flat format)
                ann_file = split_dir / "_annotations.txt"
                if ann_file.exists():
                    # The file is read completely before anything is yielded,
                    # so a read error cannot produce a half-yielded split that
                    # the folder scan would then repeat.
                    rows = RoboflowTXTParser._read_label_rows(ann_file)
                    if rows is not None:
                        for fname, label in rows:
                            img_path = split_dir / fname
                            # is_file() also rejects a blank name, which
                            # would resolve to the split directory itself.
                            if fname and img_path.is_file():
                                image_id = f"img-{uuid.uuid4().hex[:12]}"
                                anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")]
                                rel_path = str(img_path.relative_to(dataset_root))
                                yield rel_path, image_id, split_name, anns
                        continue  # Processed via file, skip folder logic

                # Fallback to Folder-based: split/class_name/image.jpg
                for class_dir in sorted(split_dir.iterdir()):
                    if not class_dir.is_dir() or class_dir.name.lower() in ("images", "labels"):
                        continue
                    label = class_dir.name
                    for img_path in sorted(class_dir.rglob("*")):
                        if img_path.suffix.lower() in RoboflowTXTParser._IMAGE_SUFFIXES:
                            image_id = f"img-{uuid.uuid4().hex[:12]}"
                            anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")]
                            rel_path = str(img_path.relative_to(dataset_root))
                            yield rel_path, image_id, split_name, anns

        # Fallback to root scan if no split folders found
        if not found_any:
            for img_path in sorted(dataset_root.rglob("*")):
                if img_path.suffix.lower() not in RoboflowTXTParser._IMAGE_SUFFIXES:
                    continue
                # Simple heuristic: parent folder is class name
                label = img_path.parent.name if img_path.parent != dataset_root else "unknown"
                image_id = f"img-{uuid.uuid4().hex[:12]}"
                anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")]
                rel_path = str(img_path.relative_to(dataset_root))
                yield rel_path, image_id, "train", anns
448
+
449
class CSVParser:
    """
    Reads CSV files for NLP (classification, NER) or Tabular data.
    """

    @staticmethod
    def detect_delimiter(file_path: Path) -> str:
        """
        Guess the delimiter from the header line.

        The candidate (';', tab, ',') that occurs most often in the first line
        wins; ties keep that order of preference. Returns ',' when no
        candidate occurs or the file cannot be read.
        """
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                header = f.readline()
        except OSError:
            return ','
        best, best_count = ',', 0
        # Counting avoids picking ';' just because one column name contains it.
        for candidate in (';', '\t', ','):
            count = header.count(candidate)
            if count > best_count:
                best, best_count = candidate, count
        return best

    @staticmethod
    def parse_file(
        csv_path: Path,
        dataset_id: str,
        text_column: str = "text",
        label_column: str = "label",
    ) -> list[dict]:
        """
        Build one "nlp_classification" annotation per row with non-empty text.

        Rows whose ``text_column`` is missing or empty are skipped; a missing
        or empty label becomes "unknown". Read or CSV errors are logged and
        the annotations collected so far are returned.
        """
        annotations = []
        delimiter = CSVParser.detect_delimiter(csv_path)
        try:
            # newline='' lets the csv module handle line endings inside
            # quoted fields itself.
            with open(csv_path, mode='r', encoding='utf-8', errors='replace', newline='') as f:
                reader = csv.DictReader(f, delimiter=delimiter)
                for row in reader:
                    text = row.get(text_column, "")
                    if not text:
                        continue
                    # DictReader fills short rows with None.
                    label = row.get(label_column) or "unknown"
                    annotations.append(
                        _make_ann(
                            image_id=f"txt-{uuid.uuid4().hex[:12]}",
                            dataset_id=dataset_id,
                            label=label,
                            bbox=(0, 0, 0, 0),
                            ann_type="nlp_classification"
                        )
                    )
        except (OSError, csv.Error) as e:
            log.error("csv_parse_error", file=str(csv_path), error=str(e))
        return annotations
494
+
495
+
496
+ # ── Utilities ────────────────────────────────────────────────────────────────
497
+
498
def _img_dimensions(path: Path) -> tuple[int, int]:
    """
    Read (width, height) from an image header without decoding the image.

    Handles PNG, BMP and JPEG. Returns (0, 0) for any other format, for a
    truncated header, or when the file cannot be read.
    """
    import struct

    # SOF0..SOF15 carry the frame size. 0xC4 (DHT), 0xC8 (JPG) and 0xCC (DAC)
    # fall in the same numeric range but are not frame headers.
    sof_markers = {m for m in range(0xC0, 0xD0) if m not in (0xC4, 0xC8, 0xCC)}

    try:
        with open(path, "rb") as f:
            head = f.read(32)

            if head[:8] == b"\x89PNG\r\n\x1a\n" and len(head) >= 24:
                # IHDR is the first chunk: width and height at bytes 16..23.
                width, height = struct.unpack(">II", head[16:24])
                return width, height

            if head[:2] == b"BM" and len(head) >= 26:
                # BITMAPINFOHEADER: signed 32-bit width/height; a negative
                # height marks a top-down bitmap.
                width, height = struct.unpack("<ii", head[18:26])
                return abs(width), abs(height)

            if head[:2] == b"\xff\xd8":
                # Walk the segment chain using each segment's length instead
                # of scanning for FF C0 bytes, which also occur inside
                # EXIF/thumbnail payloads and may sit beyond any fixed prefix.
                f.seek(2)
                while True:
                    byte = f.read(1)
                    if not byte:
                        break
                    if byte != b"\xff":
                        continue
                    marker = f.read(1)
                    while marker == b"\xff":  # fill bytes before a marker
                        marker = f.read(1)
                    if not marker:
                        break
                    code = marker[0]
                    if code in (0x00, 0x01) or 0xD0 <= code <= 0xD8:
                        continue  # markers without a length field
                    if code in (0xD9, 0xDA):
                        break  # end of image / start of scan: no frame header
                    raw_len = f.read(2)
                    if len(raw_len) < 2:
                        break
                    seg_len = struct.unpack(">H", raw_len)[0]
                    if code in sof_markers:
                        frame = f.read(5)  # precision(1) height(2) width(2)
                        if len(frame) < 5:
                            break
                        height, width = struct.unpack(">HH", frame[1:5])
                        return int(width), int(height)
                    if seg_len < 2:
                        break  # corrupt length; avoid seeking backwards
                    f.seek(seg_len - 2, 1)
    except (OSError, struct.error):
        pass
    return 0, 0
517
+
518
+
519
+ # ── Format Detector ───────────────────────────────────────────────────────────
520
+
521
def _file_contains_all(path: Path, needles: tuple, chunk_size: int = 65536) -> bool:
    """Return True when every byte string in ``needles`` occurs in the file.

    The file is streamed in chunks so large annotation files are never held
    in memory; a short overlap is kept so a needle split across two chunks is
    still found. Raises OSError when the file cannot be read.
    """
    remaining = set(needles)
    overlap = max(len(n) for n in needles) - 1
    tail = b""
    with open(path, "rb") as fh:
        while remaining:
            chunk = fh.read(chunk_size)
            if not chunk:
                break
            window = tail + chunk
            remaining = {n for n in remaining if n not in window}
            tail = window[-overlap:] if overlap > 0 else b""
    return not remaining


def detect_format(dataset_root: Path) -> str:
    """
    Heuristically detect the annotation format in a dataset directory.

    Returns one of "coco", "voc", "yolo", "txt", "csv" or "custom"; the checks
    run in that order and the first match wins.
    """
    # COCO: a JSON file that mentions both "images" and "annotations". The
    # whole file is scanned because "annotations" usually follows a long
    # "images" array and is rarely near the start.
    for jf in dataset_root.rglob("*.json"):
        try:
            if _file_contains_all(jf, (b'"images"', b'"annotations"')):
                return "coco"
        except OSError:
            pass

    # VOC: look for XML files with an <annotation> root (attributes allowed)
    for xf in dataset_root.rglob("*.xml"):
        try:
            with open(xf, "rb") as fh:
                head = fh.read(512)
        except OSError:
            continue
        if b"<annotation" in head:
            return "voc"

    # YOLO: check for .txt label files and data.yaml
    if any(dataset_root.rglob("data.yaml")):
        return "yolo"

    # Filter out common non-label files
    non_label_names = ("classes.txt", "obj.names", "README.txt", "LICENSE.txt", "README.roboflow.txt")
    yolo_line = re.compile(r"^\d+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+")
    probed = 0
    for txt in dataset_root.rglob("*.txt"):
        if txt.name in non_label_names:
            continue
        # Probe a handful of files: the first one may be an empty label file
        # for a background image.
        if probed >= 20:
            break
        probed += 1
        try:
            first_line = txt.read_text(encoding="utf-8").strip().split('\n')[0]
        except (OSError, UnicodeDecodeError):
            continue
        # YOLO rows look like: <int> <float> <float> <float> <float>
        if yolo_line.match(first_line):
            return "yolo"

    # Roboflow Classification TXT: check for _annotations.txt
    if any(dataset_root.rglob("_annotations.txt")):
        return "txt"

    # Check for folder-based classification (split/class_name/img.jpg)
    # If we see folders that aren't 'images' or 'labels' inside train/val/test
    for split in ["train", "valid", "test"]:
        split_dir = dataset_root / split
        if split_dir.exists() and split_dir.is_dir():
            subdirs = [d for d in split_dir.iterdir() if d.is_dir()]
            if subdirs and not any(d.name.lower() in ["images", "labels"] for d in subdirs):
                return "txt"

    # CSV/NLP: check for csv files
    if any(dataset_root.rglob("*.csv")):
        return "csv"

    return "custom"
datasets/base_adapter.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Iterator, Dict, Any, Optional
4
+ from models.dataset import UniversalDatasetItem, DatasetTask
5
+
6
class DatasetAdapter(ABC):
    """
    Contract implemented by every dataset format adapter.

    Keeps format-specific parsing separate from the import orchestration:
    the pipeline asks an adapter whether it recognises a directory and then
    pulls the task type, class names and items from it.
    """

    @abstractmethod
    def detect(self, dataset_path: Path) -> bool:
        """Report whether this adapter recognises the dataset at ``dataset_path``."""
        ...

    @abstractmethod
    def get_task(self, dataset_path: Path) -> DatasetTask:
        """Return the primary task type (detection, classification, ...) of the dataset."""
        ...

    @abstractmethod
    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """
        Stream the dataset as ``(image_record, annotations)`` pairs.

        Implementations yield one pair per item instead of building a full
        list, which keeps memory use low for large datasets.
        """
        ...

    @abstractmethod
    def get_class_names(self, dataset_path: Path) -> List[str]:
        """Return the class names read or derived from the dataset."""
        ...

    def get_metadata(self, dataset_path: Path) -> Dict[str, Any]:
        """Return extra format-specific metadata; the default is an empty dict."""
        return dict()
datasets/format_adapters.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import json
3
+ import re
4
+ from typing import Any, List, Tuple, Iterator, Dict
5
+ from .base_adapter import DatasetAdapter
6
+ from models.dataset import UniversalDatasetItem, DatasetContentType, UniversalAnnotation, UniversalAnnotationType, DatasetTask
7
+ from .annotation_parser import YOLOParser, COCOParser, VOCParser, RoboflowTXTParser, _img_dimensions
8
+
9
class YOLOAdapter(DatasetAdapter):
    """Adapter for YOLO datasets (data.yaml and/or per-image .txt label files)."""

    # .txt files that commonly sit in a dataset but are not label files.
    _NON_LABEL_TXT = ("classes.txt", "obj.names", "README.txt", "LICENSE.txt", "README.roboflow.txt")
    # A label row: <int> <float> <float> <float> <float>
    _LABEL_LINE = re.compile(r"^\d+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+")
    # Upper bound on how many .txt files detect() opens.
    _MAX_PROBED_FILES = 20

    def detect(self, dataset_path: Path) -> bool:
        """Return True when a data.yaml exists or a .txt file starts with a YOLO row."""
        if any(dataset_path.rglob("data.yaml")):
            return True
        probed = 0
        for txt in dataset_path.rglob("*.txt"):
            if txt.name in self._NON_LABEL_TXT:
                continue
            # Several files are probed because the first label file may be
            # empty (background image) or unreadable.
            if probed >= self._MAX_PROBED_FILES:
                break
            probed += 1
            try:
                first_line = txt.read_text(encoding="utf-8").strip().split('\n')[0]
            except (OSError, UnicodeDecodeError):
                continue
            if self._LABEL_LINE.match(first_line):
                return True
        return False

    def get_task(self, dataset_path: Path) -> DatasetTask:
        """YOLO datasets are always reported as detection."""
        return DatasetTask.detection

    def get_class_names(self, dataset_path: Path) -> List[str]:
        """Return the class map loaded by ``YOLOParser.load_class_map``."""
        return YOLOParser.load_class_map(dataset_path)

    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """Yield (image_record, annotations); dimensions come from the image header."""
        class_map = self.get_class_names(dataset_path)
        for rel_path, image_id, split, anns in YOLOParser.iter_dataset(dataset_path, dataset_id, class_map):
            abs_path = dataset_path / rel_path
            w, h = _img_dimensions(abs_path)
            img_rec = {
                "id": image_id, "filename": Path(rel_path).name,
                "rel_path": str(rel_path), "width": w, "height": h,
                "split": split, "ann_count": len(anns),
            }
            yield img_rec, anns
40
+
41
class COCOAdapter(DatasetAdapter):
    """Adapter for COCO JSON datasets."""

    @staticmethod
    def _mentions_all(path: Path, needles: tuple, chunk_size: int = 65536) -> bool:
        """Return True when every byte string in ``needles`` occurs in the file.

        The file is streamed in chunks with a short overlap, so a key that
        appears late in a large file is found without loading it whole.
        """
        remaining = set(needles)
        overlap = max(len(n) for n in needles) - 1
        tail = b""
        with open(path, "rb") as fh:
            while remaining:
                chunk = fh.read(chunk_size)
                if not chunk:
                    break
                window = tail + chunk
                remaining = {n for n in remaining if n not in window}
                tail = window[-overlap:] if overlap > 0 else b""
        return not remaining

    def detect(self, dataset_path: Path) -> bool:
        """Return True when some JSON file mentions both "images" and "annotations"."""
        for jf in dataset_path.rglob("*.json"):
            try:
                # "annotations" usually follows the (long) "images" array, so
                # a fixed-size prefix is not enough to see both keys.
                if self._mentions_all(jf, (b'"images"', b'"annotations"')):
                    return True
            except OSError:
                continue
        return False

    def get_task(self, dataset_path: Path) -> DatasetTask:
        return DatasetTask.segmentation  # Roboflow COCO often implies segmentation

    def get_class_names(self, dataset_path: Path) -> List[str]:
        """Return the class names of all annotation files, first occurrence first."""
        ann_files = COCOParser.find_annotation_files(dataset_path)
        all_classes = []
        for ann_file in ann_files:
            classes, _ = COCOParser.parse_file(ann_file, "dummy")
            all_classes = list(dict.fromkeys(all_classes + classes))
        return all_classes

    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """
        Yield (image_record, annotations) for every image in every annotation file.

        ``file_name`` is tried relative to the annotation file's folder first
        (the layout of per-split exports) and then relative to the dataset
        root. ``rel_path`` is always expressed relative to the dataset root.
        """
        ann_files = COCOParser.find_annotation_files(dataset_path)
        for ann_file in ann_files:
            _, coco_results = COCOParser.parse_file(ann_file, dataset_id)
            for rel_path, image_id, split, anns in coco_results:
                abs_path = ann_file.parent / rel_path
                if not abs_path.is_file():
                    abs_path = dataset_path / rel_path
                try:
                    rel_str = str(abs_path.relative_to(dataset_path))
                except ValueError:
                    # file_name points outside the dataset; keep it verbatim.
                    rel_str = str(rel_path)
                w, h = _img_dimensions(abs_path)
                img_rec = {
                    "id": image_id, "filename": Path(rel_path).name,
                    "rel_path": rel_str, "width": w, "height": h,
                    "split": split, "ann_count": len(anns),
                }
                yield img_rec, anns
75
+
76
class VOCAdapter(DatasetAdapter):
    """Adapter for Pascal VOC XML datasets."""

    def detect(self, dataset_path: Path) -> bool:
        """Return True when some XML file opens with an <annotation> element."""
        for xf in dataset_path.rglob("*.xml"):
            try:
                # Only the head of the file is needed for the probe.
                with open(xf, "rb") as fh:
                    head = fh.read(512)
            except OSError:
                continue
            # Match without the closing '>' so a root element with
            # attributes is recognised too.
            if b"<annotation" in head:
                return True
        return False

    def get_task(self, dataset_path: Path) -> DatasetTask:
        """VOC datasets are always reported as detection."""
        return DatasetTask.detection

    def get_class_names(self, dataset_path: Path) -> List[str]:
        """Return the sorted set of labels found across all XML files."""
        classes = set()
        for _, _, _, _, _, anns in VOCParser.iter_dataset(dataset_path, "dummy"):
            for ann in anns:
                classes.add(ann["label"])
        return sorted(classes)

    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """Yield (image_record, annotations); width and height come from the XML."""
        for rel_path, image_id, split, w, h, anns in VOCParser.iter_dataset(dataset_path, dataset_id):
            img_rec = {
                "id": image_id, "filename": Path(rel_path).name,
                "rel_path": str(rel_path), "width": w, "height": h,
                "split": split, "ann_count": len(anns),
            }
            yield img_rec, anns
104
+
105
class CreateMLAdapter(DatasetAdapter):
    """Adapter for CreateML JSON: a top-level list of {"image", "annotations"} items."""

    @staticmethod
    def _load_items(json_file: Path) -> list:
        """Return the file's top-level list, or [] when unreadable or not a list."""
        try:
            data = json.loads(json_file.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers invalid JSON and undecodable bytes.
            return []
        return data if isinstance(data, list) else []

    def detect(self, dataset_path: Path) -> bool:
        """Return True when a JSON file's first 1024 characters look like CreateML."""
        for jf in dataset_path.rglob("*.json"):
            try:
                snippet = jf.read_text(encoding="utf-8", errors="replace")[:1024]
            except OSError:
                continue
            if '"image"' in snippet and '"annotations"' in snippet and "[" in snippet:
                return True
        return False

    def get_task(self, dataset_path: Path) -> DatasetTask:
        """CreateML datasets are always reported as detection."""
        return DatasetTask.detection

    def get_class_names(self, dataset_path: Path) -> List[str]:
        """Return the sorted set of annotation labels across all JSON files."""
        classes = set()
        for jf in dataset_path.rglob("*.json"):
            for item in self._load_items(jf):
                if not isinstance(item, dict):
                    continue
                for ann in item.get("annotations", []):
                    if isinstance(ann, dict) and "label" in ann:
                        classes.add(ann["label"])
        return sorted(classes)

    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """
        Yield (image_record, annotations) for every item whose image file exists.

        Boxes are converted from centre-based pixel coordinates to normalised
        top-left [x, y, w, h]; they are skipped when the image size is unknown.
        """
        # uuid is not imported at module level in this file.
        import uuid
        from .annotation_parser import _make_ann

        for jf in dataset_path.rglob("*.json"):
            # Parse outside any handler that surrounds a yield, so closing
            # the generator is never intercepted.
            items = self._load_items(jf)
            if not items:
                continue

            # Determine split from the dataset folder and below.
            scope = (dataset_path.name, *jf.relative_to(dataset_path).parts)
            split = "train"
            if "val" in scope or "valid" in scope:
                split = "val"
            elif "test" in scope:
                split = "test"

            for item in items:
                if not isinstance(item, dict):
                    continue
                rel_img_path = item.get("image")
                if not rel_img_path:
                    continue

                # Try to find the image relative to JSON or root
                img_path = jf.parent / rel_img_path
                if not img_path.exists():
                    img_path = dataset_path / rel_img_path
                if not img_path.exists():
                    continue
                try:
                    rel_path = str(img_path.relative_to(dataset_path))
                except ValueError:
                    # The reference escapes the dataset folder; skip it.
                    continue

                image_id = f"img-{uuid.uuid4().hex[:12]}"
                w, h = _img_dimensions(img_path)

                anns = []
                for ca in item.get("annotations", []):
                    if not isinstance(ca, dict):
                        continue
                    label = ca.get("label", "unknown")
                    coord = ca.get("coordinates") or {}
                    # CreateML coords are usually center-based pixels: {x, y, width, height}
                    if w <= 0 or h <= 0 or not all(k in coord for k in ("x", "y", "width", "height")):
                        continue
                    cx, cy, bw, bh = coord["x"], coord["y"], coord["width"], coord["height"]
                    # Convert to top-left normalized
                    nx = (cx - bw / 2) / w
                    ny = (cy - bh / 2) / h
                    anns.append(_make_ann(image_id, dataset_id, label, (nx, ny, bw / w, bh / h)))

                img_rec = {
                    "id": image_id, "filename": img_path.name,
                    "rel_path": rel_path,
                    "width": w, "height": h, "split": split, "ann_count": len(anns)
                }
                yield img_rec, anns
176
+
177
class NLPAdapter(DatasetAdapter):
    """Adapter for text datasets stored as CSV/TSV files (item parsing not implemented yet)."""

    def detect(self, dataset_path: Path) -> bool:
        """Return True when the dataset contains at least one .csv or .tsv file."""
        return any(dataset_path.rglob("*.csv")) or any(dataset_path.rglob("*.tsv"))

    def get_task(self, dataset_path: Path) -> DatasetTask:
        return DatasetTask.nlp

    def get_class_names(self, dataset_path: Path) -> List[str]:
        # TODO: derive class names from the label column.
        return []

    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """
        Yield nothing until NLP item parsing is implemented.

        An empty placeholder record must not be yielded: consumers read keys
        such as "rel_path" and "filename" from every record.
        """
        # TODO: emit one record per text row or data file.
        yield from ()
191
+
192
class TabularAdapter(DatasetAdapter):
    """Placeholder adapter for tabular datasets; it never claims a dataset."""

    def detect(self, dataset_path: Path) -> bool:
        return False  # Placeholder

    def get_task(self, dataset_path: Path) -> DatasetTask:
        return DatasetTask.classification

    def get_class_names(self, dataset_path: Path) -> List[str]:
        return []

    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """Yield nothing; an empty placeholder record would lack the keys consumers read."""
        yield from ()
204
+
205
class RoboflowClassificationAdapter(DatasetAdapter):
    """Adapter for Roboflow classification exports (label file or class folders)."""

    def detect(self, dataset_path: Path) -> bool:
        """True when an _annotations.txt exists or a split folder holds class folders."""
        has_label_file = next(dataset_path.rglob("_annotations.txt"), None) is not None
        if has_label_file:
            return True
        for split_name in ("train", "valid", "test"):
            candidate = dataset_path / split_name
            if not (candidate.exists() and candidate.is_dir()):
                continue
            folder_names = [entry.name.lower() for entry in candidate.iterdir() if entry.is_dir()]
            # Class folders only: "images"/"labels" indicate a detection layout.
            if folder_names and all(name not in ["images", "labels"] for name in folder_names):
                return True
        return False

    def get_task(self, dataset_path: Path) -> DatasetTask:
        """This adapter always reports classification."""
        return DatasetTask.classification

    def get_class_names(self, dataset_path: Path) -> List[str]:
        """Return the sorted labels seen while iterating the dataset."""
        labels = {
            ann["label"]
            for _, _, _, anns in RoboflowTXTParser.iter_dataset(dataset_path, "dummy")
            for ann in anns
        }
        return sorted(list(labels))

    def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]:
        """Yield (image_record, annotations); dimensions come from the image header."""
        for rel_path, image_id, split, anns in RoboflowTXTParser.iter_dataset(dataset_path, dataset_id):
            width, height = _img_dimensions(dataset_path / rel_path)
            record = dict(
                id=image_id,
                filename=Path(rel_path).name,
                rel_path=str(rel_path),
                width=width,
                height=height,
                split=split,
                ann_count=len(anns),
            )
            yield record, anns
datasets/import_service.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ datasets/import_service.py — Dataset Import Pipeline.
3
+
4
+ Pipeline stages:
5
+ 1. Create job record
6
+ 2. Download dataset zip (chunked, progress-tracked)
7
+ 3. Extract zip safely (path-traversal protected)
8
+ 4. Detect annotation format & task type
9
+ 5. Index images into dataset_images table
10
+ 6. Parse & store metadata (Stats only, annotations are read-on-demand)
11
+ 7. Update dataset stats (images, classes, size)
12
+ 8. Mark job completed / failed
13
+
14
+ All stages run as background tasks.
15
+ Supports Roboflow, HuggingFace, and local file/folder imports.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import asyncio
20
+ import hashlib
21
+ import os
22
+ import shutil
23
+ import uuid
24
+ import zipfile
25
+ from datetime import datetime
26
+ from pathlib import Path
27
+ from typing import Optional, List, Dict, Any, Tuple
28
+
29
+ import aiofiles
30
+ import httpx
31
+ from huggingface_hub import snapshot_download
32
+
33
+ from config import settings
34
+ from . import registry as ds_reg
35
+ from .format_adapters import (
36
+ YOLOAdapter, COCOAdapter, VOCAdapter, CreateMLAdapter,
37
+ RoboflowClassificationAdapter, NLPAdapter, TabularAdapter
38
+ )
39
+ from .base_adapter import DatasetAdapter
40
+ from .annotation_parser import _img_dimensions
41
+ from observability.logger import audit, get_logger
42
+ from models.dataset import DatasetStatus, DatasetTask, ImportRequest, Dataset
43
+
44
+ log = get_logger("import_service")
45
+
46
+ ADAPTERS: List[DatasetAdapter] = [
47
+ YOLOAdapter(),
48
+ COCOAdapter(),
49
+ VOCAdapter(),
50
+ CreateMLAdapter(),
51
+ RoboflowClassificationAdapter(),
52
+ NLPAdapter(),
53
+ TabularAdapter(),
54
+ ]
55
+
56
def get_adapter_for_path(path: Path) -> DatasetAdapter | None:
    """Return the first adapter in ADAPTERS whose detect() accepts ``path``, else None."""
    matching = (candidate for candidate in ADAPTERS if candidate.detect(path))
    return next(matching, None)
61
+
62
async def recover_stale_jobs() -> None:
    """Ask the dataset registry to clean up import jobs left in 'running' or 'queued' state."""
    cleanup = ds_reg.cleanup_stale_jobs
    await cleanup()
65
+
66
def _dataset_path(dataset_id: str) -> Path:
    """Return the directory of ``dataset_id`` below ``settings.datasets_dir``."""
    datasets_root = settings.datasets_dir
    return datasets_root / dataset_id
68
+
69
+ # ── Entry Point ──────────────────────────────────────────────────────────────
70
+
71
# Strong references to in-flight import tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_IMPORT_TASKS: set = set()


async def start_import(req: ImportRequest) -> str:
    """
    Entry point to initiate a background import job.

    Records a "queued" job, schedules the pipeline as a background task and
    returns the new job id without waiting for the import to finish.
    """
    job_id = f"job-{uuid.uuid4().hex[:8]}"

    # Create initial job record
    await ds_reg.update_job(
        job_id,
        dataset_id=req.dataset_id,
        status="queued",
        progress=0,
        message="Import queued",
        type=str(req.source)
    )

    # Launch background task and keep it referenced until it is done.
    task = asyncio.create_task(_run_pipeline(job_id, req, req.dataset_name or req.dataset_id))
    _IMPORT_TASKS.add(task)
    task.add_done_callback(_IMPORT_TASKS.discard)

    return job_id
89
+
90
+
91
+ # ── Pipeline orchestrator ────────────────────────────────────────────────────
92
+
93
def _collect_adapter_items(adapter: DatasetAdapter, dataset_id: str, extract_dir: Path):
    """Drain ``adapter.iter_items`` and gather the records plus health counters.

    Returns ``(image_records, annotations, duplicates, empty_images)``. Files
    are hashed to find duplicates; a duplicate record gets a JSON "metadata"
    entry naming the first file with the same hash. This function does blocking
    file I/O and is meant to be run in a worker thread.
    """
    # json is not part of this module's visible top-level imports.
    import json

    image_records = []
    annotations = []
    first_seen = {}  # hash -> filename
    duplicates = 0
    empty_images = 0

    for img_rec, anns in adapter.iter_items(dataset_id, extract_dir):
        # Duplicate detection via content hash
        abs_path = extract_dir / img_rec["rel_path"]
        if abs_path.exists():
            img_hash = _calculate_hash(abs_path)
            if img_hash in first_seen:
                duplicates += 1
                img_rec["metadata"] = json.dumps({"is_duplicate": True, "original": first_seen[img_hash]})
            else:
                first_seen[img_hash] = img_rec["filename"]

        if not anns:
            empty_images += 1

        image_records.append(img_rec)
        annotations.extend(anns)

    return image_records, annotations, duplicates, empty_images


async def _run_pipeline(job_id: str, req: ImportRequest, dataset_name: str) -> None:
    """
    Run one import job from acquisition to completion.

    Stages: acquire source, extract, detect adapter and task, parse items,
    index images and annotations, compute stats and a health score, link the
    dataset to the active project, then mark the job completed. Any failure
    marks the job and the dataset as failed; cancellation is re-raised after
    the job has been marked failed. ``dataset_name`` is accepted for callers
    but not used here.
    """
    started = datetime.utcnow().isoformat()
    await ds_reg.update_job(job_id, status="running", started_at=started, message="Starting import")
    await ds_reg.update_dataset_status(req.dataset_id, DatasetStatus.importing, progress=0.01)

    try:
        # Stage 1 – Resolve download URL or local path
        source_path = await _stage_acquire(job_id, req)

        # Stage 2 – Extract / Prepare Directory
        extract_dir = await _stage_extract(job_id, req.dataset_id, source_path)

        # Stage 3 – Detect adapter and Task
        await ds_reg.update_job(job_id, progress=0.55, message="Detecting dataset format...")
        adapter = await asyncio.to_thread(get_adapter_for_path, extract_dir)

        # Defaults shared by both branches, so the generic fallback reaches
        # the indexing and stats stages with every name defined.
        all_annotations = []
        duplicates = 0
        empty_images = 0  # not computed for the generic scan

        if not adapter:
            log.warning("no_adapter_found_generic_fallback", dataset_id=req.dataset_id)
            image_records = await asyncio.to_thread(_scan_images_generic, req.dataset_id, extract_dir)
            class_names = []
            task = DatasetTask.classification
            fmt_name = "custom"
        else:
            task = adapter.get_task(extract_dir)
            fmt_name = adapter.__class__.__name__.replace("Adapter", "").lower()

            log.info("adapter_detected", job_id=job_id, format=fmt_name, task=task)
            await ds_reg.update_job(job_id, progress=0.60, message=f"Parsing {fmt_name.upper()} {task.upper()}")

            # Stage 4 – Parse Metadata & Annotations
            class_names = await asyncio.to_thread(adapter.get_class_names, extract_dir)
            # Parsing and hashing are blocking file I/O; keep them off the
            # event loop.
            image_records, all_annotations, duplicates, empty_images = await asyncio.to_thread(
                _collect_adapter_items, adapter, req.dataset_id, extract_dir
            )

        total_ann_count = len(all_annotations)

        if not image_records:
            raise ValueError(f"No valid data files found in {extract_dir}")

        # Stage 5 – Indexing
        await ds_reg.update_job(job_id, progress=0.80, message=f"Indexing {len(image_records)} items")
        await ds_reg.index_images(req.dataset_id, image_records)

        if all_annotations:
            await ds_reg.update_job(job_id, progress=0.85, message=f"Indexing {len(all_annotations)} annotations")
            await ds_reg.bulk_insert_annotations(all_annotations)

        # Stage 6 – Stats & Health Analysis
        size_bytes = await asyncio.to_thread(_dir_size, extract_dir)

        # Calculate Health Score (0-100)
        # Factors: duplicates, empty images (for detection), class balance (TODO)
        dup_penalty = (duplicates / len(image_records)) * 50
        empty_penalty = (empty_images / len(image_records)) * 20 if task == DatasetTask.detection else 0
        score = max(0.0, 100.0 - dup_penalty - empty_penalty)

        stats_payload = {
            "image_count": len(image_records),
            "annotation_count": total_ann_count,
            "class_count": len(class_names),
            "empty_images": empty_images,
            "duplicate_count": duplicates,
            "health_score": round(score, 1),
            "avg_objects": round(total_ann_count / len(image_records), 2) if image_records else 0
        }

        await ds_reg.update_dataset_stats(
            req.dataset_id,
            len(image_records),
            len(class_names),
            class_names,
            size_bytes,
            stats=stats_payload
        )
        await ds_reg.update_dataset_task(req.dataset_id, task)

        # Cleanup temp zip if applicable
        if source_path.is_file() and source_path.suffix.lower() == ".zip" and "_tmp" in str(source_path):
            source_path.unlink(missing_ok=True)

        # Stage 7 – Project Linking (Integration point)
        local_path = str(extract_dir)
        from projects.service import link_dataset_to_active_project
        project_ds_root = await link_dataset_to_active_project(req.dataset_id, local_path)
        final_local_path = str(project_ds_root) if project_ds_root and project_ds_root.exists() else local_path

        # Completion
        await ds_reg.update_job(
            job_id, status="completed", progress=1.0,
            message="Import complete", ended_at=datetime.utcnow().isoformat(),
        )
        await ds_reg.update_dataset_status(req.dataset_id, DatasetStatus.imported, progress=1.0, local_path=final_local_path)
        await audit("dataset_import_complete", {"job_id": job_id, "path": final_local_path}, job_id=job_id)
        log.info("import_complete", job_id=job_id, dataset_id=req.dataset_id)

    except asyncio.CancelledError:
        await _fail_job(job_id, req.dataset_id, "Import cancelled by user or system")
        raise
    except Exception as exc:
        # Top-level boundary of the background task: record the failure
        # instead of letting it disappear with the task.
        log.error("import_failed", job_id=job_id, error=str(exc))
        await _fail_job(job_id, req.dataset_id, str(exc))
        await audit("dataset_import_error", {"job_id": job_id, "error": str(exc)}, job_id=job_id, level="error")
219
+
220
+
221
async def _fail_job(job_id: str, dataset_id: str, error: str) -> None:
    """Record a failed import on both the job row and the dataset row.

    The job is marked ``failed`` with *error* and an end timestamp; the
    dataset status becomes ``DatasetStatus.failed`` with progress reset to 0.
    """
    finished_at = datetime.utcnow().isoformat()
    await ds_reg.update_job(
        job_id,
        status="failed",
        message="Import failed",
        error=error,
        ended_at=finished_at,
    )
    await ds_reg.update_dataset_status(dataset_id, DatasetStatus.failed, progress=0.0)
228
+
229
+
230
+ # ── Stage 1: Acquire source ──────────────────────────────────────────────────
231
+
232
async def _stage_acquire(job_id: str, req: ImportRequest) -> Path:
    """Route the import request to the acquirer matching ``req.source``.

    Reports 5% progress first, then delegates to the Roboflow, Hugging Face
    or local-path acquirer.

    Raises:
        ValueError: if ``req.source`` is not one of the supported providers.
    """
    await ds_reg.update_job(job_id, progress=0.05, message="Acquiring source...")

    provider = req.source
    if provider == "local":
        return await _acquire_local(job_id, req)
    if provider == "huggingface":
        return await _acquire_huggingface(job_id, req)
    if provider in ("roboflow", "roboflow_curl"):
        return await _acquire_roboflow(job_id, req)

    raise ValueError(f"Unsupported source provider: {req.source}")
246
+
247
+
248
async def _acquire_roboflow(job_id: str, req: ImportRequest) -> Path:
    """Download a Roboflow dataset, preferring the SDK and falling back to HTTP.

    SDK path: used when an API key, workspace and project are all available.
    The key comes from ``req.roboflow_key`` or, failing that, from an
    ``Authorization`` header in ``req.headers`` (a ``Bearer `` prefix is
    stripped). Returns the temporary directory the SDK downloaded into.

    HTTP path: used when the SDK path is skipped or raises. Uses
    ``req.download_url`` or, for source ``"roboflow"``, a URL resolved through
    ``RoboflowAdapter.get_download_url``. Returns the downloaded zip path.

    Raises:
        ValueError: if no download URL can be resolved for the fallback.
    """
    try:
        from roboflow import Roboflow

        api_key = req.roboflow_key or (req.headers.get("Authorization") if req.headers else None)
        if api_key and "Bearer " in str(api_key):
            api_key = api_key.split("Bearer ")[-1].strip()

        if api_key and req.roboflow_workspace and req.roboflow_project:
            tmp_target = DATASETS_ROOT / "_tmp" / f"rf-{uuid.uuid4().hex[:8]}"
            export_slug = _format_to_rf_slug(str(req.format))
            await ds_reg.update_job(job_id, progress=0.10, message="Downloading via Roboflow SDK...")

            def _sdk_download() -> None:
                # The whole SDK sequence is synchronous; the client, workspace,
                # project and version lookups presumably hit the network too,
                # so all of it runs off the event loop, not just download().
                rf = Roboflow(api_key=api_key)
                project = rf.workspace(req.roboflow_workspace).project(req.roboflow_project)
                version_obj = project.version(req.roboflow_version or 1)
                version_obj.download(export_slug, location=str(tmp_target))

            await asyncio.to_thread(_sdk_download)
            return tmp_target
    except Exception as e:
        # Deliberate best-effort: any SDK problem (including the package not
        # being installed) drops through to the plain HTTP download below.
        log.warning("roboflow_sdk_fallback", error=str(e))

    # Fallback to direct HTTP download
    url = req.download_url
    if not url and req.source == "roboflow":
        from adapters.roboflow_adapter import RoboflowAdapter

        url = await RoboflowAdapter.get_download_url(
            api_key=req.roboflow_key,
            workspace=req.roboflow_workspace,
            project_id=req.roboflow_project,
            version=req.roboflow_version,
            export_format=_format_to_rf_slug(str(req.format)),
        )

    if not url:
        raise ValueError("Could not resolve Roboflow download URL")

    return await _download_zip(job_id, req.dataset_id, url, req.headers)
291
+
292
+
293
async def _acquire_huggingface(job_id: str, req: ImportRequest) -> Path:
    """Snapshot a Hugging Face dataset repo straight into the dataset folder.

    The blocking ``snapshot_download`` call runs in a worker thread.

    Raises:
        ValueError: if ``req.hf_dataset_id`` is empty.
    """
    if not req.hf_dataset_id:
        raise ValueError("hf_dataset_id is missing")

    target_dir = _dataset_path(req.dataset_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    await ds_reg.update_job(job_id, progress=0.10, message=f"Cloning {req.hf_dataset_id} from HF...")

    download_kwargs = {
        "repo_id": req.hf_dataset_id,
        "repo_type": "dataset",
        "local_dir": str(target_dir),
        "token": settings.hf_token,
        "local_dir_use_symlinks": False,
    }
    await asyncio.to_thread(snapshot_download, **download_kwargs)
    return target_dir
311
+
312
+
313
+ async def _acquire_local(job_id: str, req: ImportRequest) -> Path:
314
+ if not req.local_path:
315
+ raise ValueError("local_path is missing for local import")
316
+
317
+ path = Path(os.path.normpath(req.local_path.strip().strip('"').strip("'")))
318
+ if not path.exists():
319
+ raise FileNotFoundError(f"Local path does not exist: {path}")
320
+
321
+ return path
322
+
323
+
324
+ # ── Stage 2: Extraction ──────────────────────────────────────────────────────
325
+
326
async def _stage_extract(job_id: str, dataset_id: str, source_path: Path) -> Path:
    """Materialise the acquired source inside the dataset's own directory.

    A directory source is copied in (or left alone when it already is the
    dataset directory). Anything else is treated as a zip archive and
    extracted. Returns the dataset directory in every case.
    """
    target = _dataset_path(dataset_id)
    target.mkdir(parents=True, exist_ok=True)

    if not source_path.is_dir():
        # Non-directory sources are assumed to be zip archives.
        await ds_reg.update_job(job_id, progress=0.45, message="Extracting archive...")
        await ds_reg.update_dataset_status(dataset_id, DatasetStatus.extracting, progress=0.45)
        await asyncio.to_thread(_safe_extract, source_path, target)
        return target

    if source_path != target:
        await ds_reg.update_job(job_id, progress=0.45, message="Copying local files...")
        await asyncio.to_thread(_copy_dir_contents, source_path, target)
    return target
342
+
343
+
344
+ # ── Stage 3: Parsing (Memory-Safe) ───────────────────────────────────────────
345
+
346
def _heuristic_task_detection(fmt: str, root: Path) -> DatasetTask:
    """Guess the dataset task purely from the annotation format name.

    ``csv`` maps to NLP, ``coco`` to segmentation (assumed to be the common
    modern COCO use case), ``yolo`` and ``voc`` to detection; everything else
    is treated as classification. ``root`` is currently not inspected.
    """
    format_rules = (
        (("csv",), DatasetTask.nlp),
        (("coco",), DatasetTask.segmentation),
        (("yolo", "voc"), DatasetTask.detection),
    )
    for formats, task in format_rules:
        if fmt in formats:
            return task
    return DatasetTask.classification
360
+
361
+
362
def _parse_yolo(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
    """Parse a YOLO dataset into ``(class_names, [(image_record, annotations)])``.

    Entries are consumed one at a time from ``YOLOParser.iter_dataset``, but
    every result is accumulated in the returned list.
    """
    class_map = YOLOParser.load_class_map(root)
    parsed = []
    for rel_path, image_id, split, anns in YOLOParser.iter_dataset(root, dataset_id, class_map):
        width, height = _img_dimensions(root / rel_path)
        record = dict(
            id=image_id,
            filename=Path(rel_path).name,
            rel_path=str(rel_path),
            width=width,
            height=height,
            split=split,
            ann_count=len(anns),
        )
        parsed.append((record, anns))
    return class_map, parsed
376
+
377
+
378
def _parse_coco(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
    """Parse every COCO annotation file under *root*.

    Class names from all files are merged in first-seen order without
    duplicates. Returns ``(class_names, [(image_record, annotations)])``.
    """
    merged_classes: list[str] = []
    parsed = []
    for ann_file in COCOParser.find_annotation_files(root):
        file_classes, entries = COCOParser.parse_file(ann_file, dataset_id)
        # dict.fromkeys de-duplicates while keeping insertion order.
        merged_classes = list(dict.fromkeys(merged_classes + file_classes))
        for rel_path, image_id, split, anns in entries:
            width, height = _img_dimensions(root / rel_path)
            record = dict(
                id=image_id,
                filename=Path(rel_path).name,
                rel_path=str(rel_path),
                width=width,
                height=height,
                split=split,
                ann_count=len(anns),
            )
            parsed.append((record, anns))
    return merged_classes, parsed
395
+
396
+
397
def _parse_voc(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
    """Parse a Pascal VOC dataset.

    Image dimensions come from ``VOCParser.iter_dataset`` itself, so no image
    file is opened here. Class names are the sorted set of annotation labels.
    """
    labels = set()
    parsed = []
    for rel_path, image_id, split, width, height, anns in VOCParser.iter_dataset(root, dataset_id):
        record = dict(
            id=image_id,
            filename=Path(rel_path).name,
            rel_path=str(rel_path),
            width=width,
            height=height,
            split=split,
            ann_count=len(anns),
        )
        parsed.append((record, anns))
        labels.update(ann["label"] for ann in anns)
    return sorted(labels), parsed
410
+
411
+
412
def _parse_csv(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
    """Parse every ``*.csv`` file under *root* as a text dataset.

    Each CSV row is one annotation. Rows are grouped by their ``image_id``
    (the text entry id), and one item record is emitted per group with zero
    dimensions and split ``"train"``.
    """
    labels = set()
    parsed = []
    for csv_path in root.rglob("*.csv"):
        grouped: Dict[str, List[Dict]] = {}
        for ann in CSVParser.parse_file(csv_path, dataset_id):
            labels.add(ann["label"])
            key = ann["image_id"]
            if key not in grouped:
                grouped[key] = []
            grouped[key].append(ann)

        for text_id, rows in grouped.items():
            record = dict(
                id=text_id,
                filename=csv_path.name,
                rel_path=str(csv_path.relative_to(root)),
                width=0,
                height=0,
                split="train",
                ann_count=len(rows),
            )
            parsed.append((record, rows))
    return sorted(labels), parsed
431
+
432
+
433
def _parse_txt(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
    """Parse a Roboflow TXT-format dataset.

    Returns the sorted set of annotation labels and one
    ``(image_record, annotations)`` pair per item yielded by
    ``RoboflowTXTParser.iter_dataset``.
    """
    from datasets.annotation_parser import RoboflowTXTParser

    labels = set()
    parsed = []
    for rel_path, image_id, split, anns in RoboflowTXTParser.iter_dataset(root, dataset_id):
        width, height = _img_dimensions(root / rel_path)
        record = dict(
            id=image_id,
            filename=Path(rel_path).name,
            rel_path=str(rel_path),
            width=width,
            height=height,
            split=split,
            ann_count=len(anns),
        )
        parsed.append((record, anns))
        labels.update(ann["label"] for ann in anns)

    return sorted(labels), parsed
451
+
452
+
453
def _parse_generic_folder(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]:
    """Index an image folder that has no annotation files.

    Supported layouts:
      1. ``root/class_name/img.jpg``        -> label ``class_name``, split ``train``
      2. ``root/train/class_name/img.jpg``  -> label ``class_name``, split from folder
      3. ``root/images/img.jpg``            -> unlabelled

    The class label is always taken from a directory component, never from
    the file name. A labelled image gets one virtual ``classification``
    annotation; an unlabelled image gets none.

    Returns:
        ``(sorted_class_names, [(image_record, annotations), ...])``.
    """
    # Imported once per call rather than once per image.
    from datasets.annotation_parser import _make_ann

    exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}

    # Structural folder names that must never be interpreted as class names.
    ignore = {"images", "labels", "train", "val", "test", "validation", "training", "valid", "testing", "unknown", "annotations"}
    split_aliases = {
        "train": "train", "training": "train",
        "val": "val", "valid": "val", "validation": "val",
        "test": "test", "testing": "test",
    }

    results = []
    class_set = set()

    for img_path in sorted(root.rglob("*")):
        # Skip non-images, and directories that merely look like image names.
        if img_path.suffix.lower() not in exts or not img_path.is_file():
            continue

        rel_path = img_path.relative_to(root)
        dirs = rel_path.parts[:-1]  # directory components only, file name excluded

        label = "unknown"
        split = "train"
        if dirs:
            top = dirs[0].lower()
            if top in ignore:
                # Structural top folder: may name the split; class is one level down.
                split = split_aliases.get(top, "train")
                if len(dirs) > 1 and dirs[1].lower() not in ignore:
                    label = dirs[1]
            else:
                label = dirs[0]

        image_id = f"img-{uuid.uuid4().hex[:12]}"
        anns = []
        if label != "unknown":
            class_set.add(label)
            # Virtual annotation so classification datasets carry their label.
            anns.append(_make_ann(image_id, dataset_id, label, ann_type="classification"))

        w, h = _img_dimensions(img_path)
        img_rec = {
            "id": image_id,
            "filename": img_path.name,
            "rel_path": str(rel_path),
            "width": w, "height": h,
            "split": split,
            "ann_count": len(anns),
        }
        results.append((img_rec, anns))

    return sorted(class_set), results
514
+
515
+
516
+ # ── Utilities ────────────────────────────────────────────────────────────────
517
+
518
async def _download_zip(job_id: str, dataset_id: str, url: str, custom_headers: dict = None) -> Path:
    """Stream *url* to a uniquely named zip file under ``DATASETS_ROOT/_tmp``.

    Args:
        job_id: job whose progress is updated while downloading (10% -> 45%).
        dataset_id: used as the prefix of the temporary file name.
        url: location to GET; redirects are followed.
        custom_headers: optional headers merged over the defaults.

    Returns:
        Path of the downloaded zip file.

    Raises:
        httpx.HTTPStatusError: on a non-success response
            (via ``raise_for_status``).

    Progress is only written when it has advanced by at least one percentage
    point (or the download is complete), so the job table is not updated once
    per chunk. A partially written file is removed if the download fails or
    is cancelled.
    """
    tmp_dir = DATASETS_ROOT / "_tmp"
    tmp_dir.mkdir(parents=True, exist_ok=True)
    zip_path = tmp_dir / f"{dataset_id}-{uuid.uuid4().hex[:8]}.zip"

    headers = {
        "User-Agent": "Mozilla/5.0 (MLForge Workbench)",
        "Accept": "application/zip, application/octet-stream, */*",
    }
    if custom_headers:
        headers.update(custom_headers)

    last_reported = 0.0
    try:
        async with httpx.AsyncClient(follow_redirects=True, timeout=600.0, headers=headers) as client:
            async with client.stream("GET", url) as resp:
                resp.raise_for_status()
                total = int(resp.headers.get("content-length", 0)) or None
                downloaded = 0
                async with aiofiles.open(zip_path, "wb") as f:
                    async for chunk in resp.aiter_bytes(chunk_size=settings.download_chunk_size):
                        await f.write(chunk)
                        downloaded += len(chunk)
                        if not total:
                            continue
                        # Clamp: decoded bytes can exceed content-length when
                        # the response is content-encoded.
                        pct = 0.10 + min(downloaded / total, 1.0) * 0.35  # 10% -> 45%
                        if pct - last_reported >= 0.01 or downloaded >= total:
                            last_reported = pct
                            await ds_reg.update_job(job_id, progress=round(pct, 3), message=f"Downloading: {_fmt_bytes(downloaded)} / {_fmt_bytes(total)}")
    except BaseException:
        # Includes cancellation: never leave a truncated archive behind.
        zip_path.unlink(missing_ok=True)
        raise

    return zip_path
543
+
544
+
545
+ def _safe_extract(zip_path: Path, dest: Path) -> None:
546
+ with zipfile.ZipFile(str(zip_path), "r") as zf:
547
+ for member in zf.namelist():
548
+ if os.path.isabs(member) or ".." in Path(member).parts: continue
549
+ zf.extract(member, str(dest))
550
+
551
+
552
+ def _copy_dir_contents(src: Path, dest: Path) -> None:
553
+ for item in src.iterdir():
554
+ s, d = src / item.name, dest / item.name
555
+ if s.is_dir(): shutil.copytree(s, d, dirs_exist_ok=True)
556
+ else: shutil.copy2(s, d)
557
+
558
+
559
+ def _scan_images_generic(dataset_id: str, root: Path) -> list[dict]:
560
+ records = []
561
+ exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
562
+ for img_path in sorted(root.rglob("*")):
563
+ if img_path.suffix.lower() in exts:
564
+ w, h = _img_dimensions(img_path)
565
+ records.append({
566
+ "id": f"img-{uuid.uuid4().hex[:12]}",
567
+ "filename": img_path.name,
568
+ "rel_path": str(img_path.relative_to(root)),
569
+ "width": w, "height": h, "split": "train", "ann_count": 0,
570
+ })
571
+ return records
572
+
573
+
574
+ def _dir_size(path: Path) -> int:
575
+ return sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
576
+
577
+
578
+ def _fmt_bytes(n: int) -> str:
579
+ for unit in ("B", "KB", "MB", "GB", "TB"):
580
+ if n < 1024: return f"{n:.1f} {unit}"
581
+ n /= 1024
582
+ return f"{n:.1f} PB"
583
+
584
+
585
+ def _format_to_rf_slug(fmt: str) -> str:
586
+ return {"yolo": "yolov8", "coco": "coco", "voc": "voc"}.get(fmt, "yolov8")
587
+
588
+ def _format_to_rf_slug(fmt: str) -> str:
589
+ return {"yolo": "yolov8", "coco": "coco", "voc": "voc"}.get(fmt, "yolov8")
datasets/registry.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ datasets/registry.py — Dataset Registry: persistent CRUD against datasets table.
3
+ All DB interactions for datasets and dataset_jobs live here.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ import uuid
9
+ from datetime import datetime
10
+ from typing import Any
11
+
12
+ from database.connection import get_db
13
+ from models.dataset import Dataset, DatasetJob, DatasetStatus, row_to_dataset, row_to_job
14
+ from observability.logger import get_logger
15
+
16
+ log = get_logger("dataset_registry")
17
+
18
+
19
+ # ── Dataset CRUD ──────────────────────────────────────────────────────────────
20
+
21
async def get_all_datasets(
    task: str | None = None,
    format: str | None = None,
    source: str | None = None,
    status: str | None = None,
    search: str | None = None,
    starred: bool | None = None,
    limit: int = 500,
    offset: int = 0,
) -> list[Dataset]:
    """List datasets, newest ``updated_at`` first, with optional filters.

    Truthy ``task``/``format``/``source``/``status`` values become equality
    filters. ``starred`` filters when it is not ``None``. ``search`` is a
    substring (LIKE) match over name, description and tags. All filters are
    combined with AND; every value is bound as a parameter.
    """
    db = await get_db()
    clauses = []
    params: list[Any] = []

    equality_filters = (
        ("task", task),
        ("format", format),
        ("source", source),
        ("status", status),
    )
    for column, wanted in equality_filters:
        if wanted:
            clauses.append(f"{column} = ?")
            params.append(wanted)

    if starred is not None:
        clauses.append("starred = ?")
        params.append(1 if starred else 0)

    if search:
        pattern = f"%{search}%"
        clauses.append("(name LIKE ? OR description LIKE ? OR tags LIKE ?)")
        params.extend([pattern, pattern, pattern])

    where = ""
    if clauses:
        where = f"WHERE {' AND '.join(clauses)}"
    params.extend([limit, offset])
    sql = f"SELECT * FROM datasets {where} ORDER BY updated_at DESC LIMIT ? OFFSET ?"

    async with db.execute(sql, params) as cursor:
        fetched = await cursor.fetchall()
    return [row_to_dataset(record) for record in fetched]
62
+
63
+
64
async def get_dataset_stats(dataset_id: str) -> dict:
    """Aggregate per-class and per-split counts for one dataset.

    Returns a dict with ``class_distribution`` (label -> annotation count,
    from ``dataset_annotations``) and ``split_distribution`` (split -> image
    count, from ``dataset_images``).
    """
    db = await get_db()

    class_sql = (
        "SELECT label, COUNT(*) as count FROM dataset_annotations "
        "WHERE dataset_id=? GROUP BY label ORDER BY count DESC"
    )
    async with db.execute(class_sql, (dataset_id,)) as cursor:
        class_rows = await cursor.fetchall()

    split_sql = "SELECT split, COUNT(*) as count FROM dataset_images WHERE dataset_id=? GROUP BY split"
    async with db.execute(split_sql, (dataset_id,)) as cursor:
        split_rows = await cursor.fetchall()

    class_distribution = {}
    for record in class_rows:
        class_distribution[record["label"]] = record["count"]

    split_distribution = {}
    for record in split_rows:
        split_distribution[record["split"]] = record["count"]

    return {
        "class_distribution": class_distribution,
        "split_distribution": split_distribution,
    }
86
+
87
+
88
async def get_dataset(dataset_id: str) -> Dataset | None:
    """Fetch one dataset by primary key, or ``None`` when no row matches."""
    db = await get_db()
    async with db.execute("SELECT * FROM datasets WHERE id = ?", (dataset_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_dataset(record)
93
+
94
+
95
async def count_datasets() -> int:
    """Return the total number of rows in the ``datasets`` table."""
    db = await get_db()
    async with db.execute("SELECT COUNT(*) FROM datasets") as cursor:
        record = await cursor.fetchone()
    if not record:
        return 0
    return record[0]
100
+
101
+
102
async def upsert_dataset(ds: Dataset) -> None:
    """Insert a dataset row, replacing any existing row with the same id.

    Enum-like fields (task, format, source, status) are stored by ``.value``
    when they have one, otherwise as given. ``updated_at`` is always set to
    the database's current time.

    NOTE(review): ``INSERT OR REPLACE`` rewrites the whole row, so columns
    not listed here presumably fall back to their defaults — confirm against
    the schema.
    """
    db = await get_db()

    def _plain(field):
        # Accept both Enum members and already-plain strings.
        return getattr(field, "value", field)

    versions_payload = []
    for version in ds.versions:
        versions_payload.append(version.model_dump() if hasattr(version, "model_dump") else version)

    params = (
        ds.id,
        ds.name,
        ds.description,
        _plain(ds.task),
        _plain(ds.format),
        _plain(ds.source),
        _plain(ds.status),
        ds.images,
        ds.classes,
        json.dumps(ds.class_names),
        ds.size_bytes,
        ds.size_label,
        ds.local_path,
        ds.import_progress,
        json.dumps(ds.tags),
        json.dumps(versions_payload),
        ds.active_version,
        1 if ds.starred else 0,
        ds.roboflow_id,
        ds.created_at or datetime.utcnow().isoformat(),
    )
    await db.execute(
        """INSERT OR REPLACE INTO datasets
        (id, name, description, task, format, source, status,
        images, classes, class_names, size_bytes, size_label,
        local_path, import_progress, tags, versions, active_version,
        starred, roboflow_id, created_at, updated_at)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,datetime('now'))""",
        params,
    )
    await db.commit()
132
+
133
+
134
async def update_dataset_status(
    dataset_id: str,
    status: DatasetStatus,
    progress: float | None = None,
    local_path: str | None = None,
) -> None:
    """Update a dataset's status and, optionally, its progress and local path.

    Args:
        dataset_id: primary key of the dataset row.
        status: new status; stored by ``.value`` when it has one.
        progress: written to ``import_progress`` when not ``None``.
        local_path: written to ``local_path`` when not ``None``.

    ``progress`` and ``local_path`` are applied independently. Previously a
    ``local_path`` passed without ``progress`` was silently ignored.
    """
    assignments = ["status=?"]
    params: list[Any] = [getattr(status, "value", status)]
    if progress is not None:
        assignments.append("import_progress=?")
        params.append(progress)
    if local_path is not None:
        assignments.append("local_path=?")
        params.append(local_path)
    params.append(dataset_id)

    db = await get_db()
    # Only fixed column names are interpolated; every value is bound.
    await db.execute(f"UPDATE datasets SET {', '.join(assignments)} WHERE id=?", params)
    await db.commit()
157
+
158
+
159
async def update_dataset_stats(
    dataset_id: str,
    images: int,
    classes: int,
    class_names: list[str],
    size_bytes: int,
    stats: dict | None = None
) -> None:
    """Persist post-import statistics for a dataset.

    Writes the image/class counts, class names (JSON), size (raw and
    human-readable), the ``stats`` dict as JSON (NULL when empty or missing)
    and ``health_score`` taken from ``stats["health_score"]`` (0.0 if absent).
    """
    db = await get_db()

    # The score is read from the payload, not recomputed here.
    health_score = stats.get("health_score", 0.0) if stats else 0.0
    stats_json = json.dumps(stats) if stats else None

    params = (
        images,
        classes,
        json.dumps(class_names),
        size_bytes,
        _fmt_bytes(size_bytes),
        stats_json,
        health_score,
        dataset_id,
    )
    await db.execute(
        """UPDATE datasets
        SET images=?, classes=?, class_names=?, size_bytes=?,
        size_label=?, stats=?, health_score=?
        WHERE id=?""",
        params,
    )
    await db.commit()
188
+
189
+
190
async def delete_dataset(dataset_id: str) -> bool:
    """Delete a dataset row; return ``False`` when it did not exist."""
    db = await get_db()
    async with db.execute("SELECT 1 FROM datasets WHERE id=?", (dataset_id,)) as cursor:
        found = await cursor.fetchone()
    if found:
        await db.execute("DELETE FROM datasets WHERE id=?", (dataset_id,))
        await db.commit()
        return True
    return False
199
+
200
+
201
async def toggle_starred(dataset_id: str) -> bool:
    """Flip the ``starred`` flag and return the new state.

    Also returns ``False`` when the dataset does not exist, so a ``False``
    result alone does not distinguish "unstarred" from "not found".
    """
    db = await get_db()
    async with db.execute("SELECT starred FROM datasets WHERE id=?", (dataset_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return False

    now_starred = not record["starred"]
    await db.execute(
        "UPDATE datasets SET starred=? WHERE id=?",
        (1 if now_starred else 0, dataset_id),
    )
    await db.commit()
    return now_starred
212
+
213
+
214
+ # ── Bulk dataset upsert from Roboflow ────────────────────────────────────────
215
+
216
async def bulk_upsert_datasets(datasets: list[Dataset]) -> int:
    """Insert many datasets with one ``executemany`` call and one commit.

    Rows whose id already exists are left untouched (``INSERT OR IGNORE``);
    nothing is updated despite the function name. New rows are stored
    unstarred with an empty version list.

    Enum-like fields (task, format, source, status) are stored by ``.value``
    when they have one and as-is otherwise, matching ``upsert_dataset``.

    Returns:
        The number of datasets passed in, not the number actually inserted.
    """
    if not datasets:
        return 0

    def _plain(field):
        # Accept both Enum members and already-plain strings.
        return getattr(field, "value", field)

    db = await get_db()
    now = datetime.utcnow().isoformat()
    rows = [
        (
            ds.id, ds.name, ds.description, _plain(ds.task), _plain(ds.format),
            _plain(ds.source), _plain(ds.status),
            ds.images, ds.classes,
            json.dumps(ds.class_names), ds.size_bytes, ds.size_label,
            ds.local_path, ds.import_progress,
            json.dumps(ds.tags), json.dumps([]),
            ds.active_version, 0, ds.roboflow_id,
            ds.created_at or now,
        )
        for ds in datasets
    ]
    await db.executemany(
        """INSERT OR IGNORE INTO datasets
        (id, name, description, task, format, source, status,
        images, classes, class_names, size_bytes, size_label,
        local_path, import_progress, tags, versions, active_version,
        starred, roboflow_id, created_at)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
        rows,
    )
    await db.commit()
    return len(datasets)
246
+
247
+
248
+ # ── Dataset Jobs ──────────────────────────────────────────────────────────────
249
+
250
async def create_job(
    dataset_id: str,
    dataset_name: str,
    job_type: str,
) -> DatasetJob:
    """Create a ``queued`` dataset job and return its model.

    The job gets a random ``djob-`` id, zero progress, an empty message and
    the current UTC time as ``created_at``.
    """
    db = await get_db()
    created_at = datetime.utcnow().isoformat()
    new_id = f"djob-{uuid.uuid4().hex[:12]}"

    insert_sql = """INSERT INTO dataset_jobs
        (id, type, status, dataset_id, dataset_name, progress, message, created_at)
        VALUES (?, ?, 'queued', ?, ?, 0.0, '', ?)"""
    await db.execute(insert_sql, (new_id, job_type, dataset_id, dataset_name, created_at))
    await db.commit()

    return DatasetJob(
        id=new_id,
        type=job_type,
        status="queued",
        dataset_id=dataset_id,
        dataset_name=dataset_name,
        created_at=created_at,
    )
270
+
271
+
272
async def update_job(
    job_id: str,
    status: str | None = None,
    progress: float | None = None,
    message: str | None = None,
    error: str | None = None,
    started_at: str | None = None,
    ended_at: str | None = None,
) -> None:
    """Update only the job columns whose argument is not ``None``.

    Does nothing (no query, no commit) when every optional argument is
    ``None``.
    """
    db = await get_db()

    candidates = (
        ("status", status),
        ("progress", progress),
        ("message", message),
        ("error", error),
        ("started_at", started_at),
        ("ended_at", ended_at),
    )
    supplied = [(column, value) for column, value in candidates if value is not None]
    if not supplied:
        return

    assignments = ", ".join(f"{column}=?" for column, _ in supplied)
    params: list[Any] = [value for _, value in supplied]
    params.append(job_id)
    await db.execute(f"UPDATE dataset_jobs SET {assignments} WHERE id=?", params)
    await db.commit()
301
+
302
+
303
async def get_job(job_id: str) -> DatasetJob | None:
    """Fetch one dataset job by id, or ``None`` when no row matches."""
    db = await get_db()
    async with db.execute("SELECT * FROM dataset_jobs WHERE id=?", (job_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_job(record)
308
+
309
+
310
async def get_all_jobs(limit: int = 100) -> list[DatasetJob]:
    """Return up to *limit* dataset jobs, most recently created first."""
    db = await get_db()
    query = "SELECT * FROM dataset_jobs ORDER BY created_at DESC LIMIT ?"
    async with db.execute(query, (limit,)) as cursor:
        fetched = await cursor.fetchall()
    jobs = []
    for record in fetched:
        jobs.append(row_to_job(record))
    return jobs
317
+
318
+
319
+ # ── Image Index ───────────────────────────────────────────────────────────────
320
+
321
async def index_images(
    dataset_id: str,
    records: list[dict],  # [{id, filename, rel_path, width, height, split, ann_count}]
) -> int:
    """Bulk-insert image records for a dataset.

    Rows are inserted with ``INSERT OR IGNORE``. Returns ``len(records)``,
    not the number of rows actually written.
    """
    db = await get_db()
    # A "dataset_id" key already present in a record takes precedence.
    rows = [{"dataset_id": dataset_id, **record} for record in records]
    insert_sql = """INSERT OR IGNORE INTO dataset_images
        (id, dataset_id, filename, rel_path, width, height, split, ann_count)
        VALUES (:id, :dataset_id, :filename, :rel_path, :width, :height, :split, :ann_count)"""
    await db.executemany(insert_sql, rows)
    await db.commit()
    return len(records)
334
+
335
+
336
async def get_image_page(
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> tuple[int, list[dict]]:
    """Return one page of image rows plus the total matching count.

    Rows are ordered by filename; ``page`` is zero-based. ``split`` filters
    by split, and ``class_label`` keeps only images that have at least one
    annotation with that label.

    Returns:
        ``(total_matching, rows_as_dicts)``.
    """
    db = await get_db()

    filters = ["dataset_id=?"]
    args: list[Any] = [dataset_id]
    if split:
        filters.append("split=?")
        args.append(split)

    where = f"WHERE {' AND '.join(filters)}"
    if class_label:
        # Sub-select on the annotations table to filter by class.
        where += " AND id IN (SELECT image_id FROM dataset_annotations WHERE label=?)"
        args = args + [class_label]

    async with db.execute(f"SELECT COUNT(*) FROM dataset_images {where}", args) as cursor:
        count_row = await cursor.fetchone()
    total = count_row[0]

    page_args = args + [page_size, page * page_size]
    page_sql = f"SELECT * FROM dataset_images {where} ORDER BY filename LIMIT ? OFFSET ?"
    async with db.execute(page_sql, page_args) as cursor:
        fetched = await cursor.fetchall()
    return total, [dict(record) for record in fetched]
369
+
370
+
371
async def get_annotations_for_image(image_id: str) -> list[dict]:
    """Return every annotation row for *image_id* as a plain dict."""
    db = await get_db()
    query = "SELECT * FROM dataset_annotations WHERE image_id=?"
    async with db.execute(query, (image_id,)) as cursor:
        fetched = await cursor.fetchall()
    annotations = []
    for record in fetched:
        annotations.append(dict(record))
    return annotations
378
+
379
+
380
async def bulk_insert_annotations(records: list[dict]) -> int:
    """Bulk-insert bounding-box style annotations.

    Rows are inserted with ``INSERT OR IGNORE``. Returns ``len(records)``
    (0 for an empty list, in which case the database is not touched).
    """
    count = len(records)
    if count == 0:
        return 0

    db = await get_db()
    insert_sql = """INSERT OR IGNORE INTO dataset_annotations
        (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
        normalised, area, confidence, ann_type)
        VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
        :normalised,:area,:confidence,:ann_type)"""
    await db.executemany(insert_sql, records)
    await db.commit()
    return count
394
+
395
+
396
+ # ── Universal Dataset Items ──────────────────────────────────────────────
397
+
398
async def get_universal_items(
    self,
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> tuple[int, list[dict]]:
    """Fetch polymorphic dataset items (images, text rows, etc.) for one page.

    Currently a thin bridge over ``get_image_page``: the rows are returned
    unchanged until a universal schema exists.

    NOTE(review): ``self`` is a leftover from a class-based version. This is
    a module-level function, so the parameter is kept only to avoid breaking
    existing positional callers, and its value is ignored. The module-level
    ``get_image_page`` is called directly instead of ``self.get_image_page``,
    which failed for any ``self`` lacking that attribute.

    Returns:
        ``(total_matching, item_dicts)``.
    """
    total, items = await get_image_page(dataset_id, page, page_size, split, class_label)
    return total, items
415
+
416
async def bulk_insert_universal_annotations(self, records: list[dict]) -> int:
    """Bulk-insert annotations including the extended columns.

    Same as the basic annotation insert, plus ``segmentation``, ``keypoints``
    and ``metadata``. Rows are inserted with ``INSERT OR IGNORE``. Returns
    ``len(records)`` (0 for an empty list, without touching the database).

    NOTE(review): ``self`` is never used; this is a module-level function and
    the parameter looks like a leftover from a class-based version.
    """
    count = len(records)
    if count == 0:
        return 0

    db = await get_db()
    insert_sql = """INSERT OR IGNORE INTO dataset_annotations
        (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
        normalised, area, confidence, ann_type, segmentation, keypoints, metadata)
        VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
        :normalised,:area,:confidence,:ann_type,:segmentation,:keypoints,:metadata)"""
    await db.executemany(insert_sql, records)
    await db.commit()
    return count
431
+
432
async def update_dataset_task(dataset_id: str, task: str) -> None:
    """Set the ``task`` column of one dataset.

    ``task`` may be a plain string or an enum member; an enum is stored by
    its ``.value``, the same normalisation ``upsert_dataset`` applies.
    """
    task_value = getattr(task, "value", task)
    db = await get_db()
    await db.execute("UPDATE datasets SET task=? WHERE id=?", (task_value, dataset_id))
    await db.commit()
436
+
437
+
438
async def cleanup_stale_jobs() -> None:
    """Fail every job still marked ``running`` or ``queued``.

    Intended for startup. Such jobs get status ``failed`` and error
    ``"System restart"``.
    """
    stale_sql = (
        "UPDATE dataset_jobs SET status='failed', error='System restart' "
        "WHERE status IN ('running', 'queued')"
    )
    db = await get_db()
    await db.execute(stale_sql)
    await db.commit()
445
+
446
+
447
+ def _fmt_bytes(n: int) -> str:
448
+ for unit in ("B", "KB", "MB", "GB", "TB"):
449
+ if n < 1024:
450
+ return f"{n:.1f} {unit}"
451
+ n /= 1024
452
+ return f"{n:.1f} PB"
datasets/viewer_service.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ datasets/viewer_service.py — Dataset Viewer Service.
3
+
4
+ Provides paginated image + annotation serving for the Dataset Viewer UI.
5
+ All paths are resolved relative to the dataset's local_path for security.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from pathlib import Path
10
+
11
+ from datasets import registry as ds_reg
12
+ from models.dataset import (
13
+ Annotation, AnnotationType, BoundingBox, Dataset,
14
+ ImageRecord, ViewerPage, DatasetFormat
15
+ )
16
+ from datasets.annotation_parser import YOLOParser, COCOParser, VOCParser, CSVParser
17
+ from observability.logger import get_logger
18
+
19
+ log = get_logger("viewer_service")
20
+
21
+
22
+ from .format_adapters import NLPAdapter, TabularAdapter
23
+ from models.dataset import UniversalViewerPage, UniversalDatasetItem, UniversalAnnotation, DatasetContentType, DatasetTask
24
+
25
async def get_universal_viewer_page(
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> UniversalViewerPage:
    """Return one viewer page in the universal item format, chosen by dataset task.

    * Detection / segmentation / keypoints datasets are served through
      ``get_viewer_page`` and each image record is converted to a
      ``UniversalDatasetItem``.
    * NLP and tabular datasets with a ``local_path`` are served by
      ``NLPAdapter`` / ``TabularAdapter`` respectively.
    * Anything else yields an empty page (``total=0``, ``total_pages=0``).

    Args:
        dataset_id: Identifier passed to the dataset registry.
        page: Page index forwarded unchanged to the underlying reader.
        page_size: Requested number of items per page.
        split: Optional split filter (only used by the vision path).
        class_label: Optional class filter (only used by the vision path).

    Returns:
        A ``UniversalViewerPage``. Its ``page_size`` is the size that was
        actually applied, which can differ from the requested one.

    Raises:
        ValueError: If the registry has no dataset for ``dataset_id``.
    """
    ds = await ds_reg.get_dataset(dataset_id)
    if not ds:
        raise ValueError("Dataset not found")

    ds_root = Path(ds.local_path) if ds.local_path else None

    # 1. Vision tasks -> reuse the image-centric viewer and convert its records.
    if ds.task in (DatasetTask.detection, DatasetTask.segmentation, DatasetTask.keypoints):
        old_page = await get_viewer_page(dataset_id, page, page_size, split, class_label)

        items = [
            UniversalDatasetItem(
                id=img.image_id,
                content_type=DatasetContentType.image,
                filename=img.filename,
                metadata={"width": img.width, "height": img.height, "split": img.split},
                annotations=[
                    UniversalAnnotation(
                        label=ann.label,
                        # ann.type may be an enum or an already-plain value.
                        type=ann.type.value if hasattr(ann.type, "value") else str(ann.type),
                        bbox=[ann.bbox.x, ann.bbox.y, ann.bbox.width, ann.bbox.height] if ann.bbox else None,
                        segmentation=ann.segmentation,
                        keypoints=ann.keypoints,
                        confidence=ann.confidence,
                        metadata=ann.metadata,
                    )
                    for ann in img.annotations
                ],
            )
            for img in old_page.images
        ]

        return UniversalViewerPage(
            dataset_id=dataset_id,
            page=page,
            # get_viewer_page may cap the page size; report the size it used so
            # that page_size stays consistent with total_pages.
            page_size=old_page.page_size,
            total=old_page.total,
            total_pages=old_page.total_pages,
            items=items,
        )

    # 2./3. NLP and tabular tasks share the same adapter-driven flow.
    adapter = None
    if ds_root is not None:
        if ds.task == DatasetTask.nlp:
            adapter = NLPAdapter()
        elif ds.task == DatasetTask.tabular:
            adapter = TabularAdapter()

    if adapter is not None:
        # A non-positive size would make the total_pages division below fail.
        effective_size = max(1, page_size)
        total, items = await adapter.get_items(ds_root, page, effective_size)
        total_pages = max(1, (total + effective_size - 1) // effective_size)
        return UniversalViewerPage(
            dataset_id=dataset_id,
            page=page,
            page_size=effective_size,
            total=total,
            total_pages=total_pages,
            items=items,
        )

    # Fallback: unsupported task, or a text/tabular dataset without a local path.
    return UniversalViewerPage(
        dataset_id=dataset_id,
        page=page,
        page_size=page_size,
        total=0,
        total_pages=0,
        items=[],
    )
110
+
111
async def get_viewer_page(
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> ViewerPage:
    """Return a paginated viewer page for the dataset.

    Image rows come from the registry index. Annotations are read from the
    ``dataset_annotations`` table first; only when that yields nothing for the
    whole page, and the dataset has a ``local_path``, are they parsed from the
    annotation files on disk (YOLO, COCO, VOC or CSV, chosen by ``ds.format``).

    Both annotation sources are best-effort: a failure is logged and the page
    is still returned, with whatever annotations were collected.

    Args:
        dataset_id: Identifier passed to the dataset registry.
        page: Page index forwarded to ``ds_reg.get_image_page``.
        page_size: Requested page size; clamped to the range 1..100.
        split: Optional split filter forwarded to the registry.
        class_label: Optional class filter forwarded to the registry.

    Returns:
        A ``ViewerPage`` whose ``page_size`` is the clamped value and whose
        ``total_pages`` is at least 1.
    """
    # Upper bound prevents huge payloads; lower bound prevents a division by
    # zero when total_pages is computed below.
    page_size = max(1, min(page_size, 100))

    total, image_rows = await ds_reg.get_image_page(dataset_id, page, page_size, split, class_label)
    ds = await ds_reg.get_dataset(dataset_id)

    image_ids = [row["id"] for row in image_rows]
    dynamic_anns: dict[str, list[Annotation]] = {img_id: [] for img_id in image_ids}

    # 1. Try loading from the DB index (authoritative for analytics).
    if image_ids:
        try:
            # Imported locally, as in the rest of this module.
            from database.connection import get_db
            db = await get_db()
            # One query for the whole page. Only "?" markers are interpolated
            # into the SQL text; the ids themselves are bound as parameters.
            placeholders = ",".join("?" * len(image_ids))
            async with db.execute(
                f"SELECT * FROM dataset_annotations WHERE image_id IN ({placeholders})",
                image_ids,
            ) as cur:
                rows = await cur.fetchall()
            for r in rows:
                dynamic_anns[r["image_id"]].append(_row_to_annotation(dict(r)))
        except Exception as e:
            log.warning("db_annotation_read_failed", error=str(e), dataset_id=dataset_id)

    # 2. Fall back to the filesystem if the DB produced no annotations at all.
    #    This keeps older datasets (never indexed into the DB) viewable.
    #    Skipped for an empty page: there is nothing to attach annotations to.
    if image_rows and not any(dynamic_anns.values()) and ds and ds.local_path:
        # ds.local_path is the authoritative project-local path.
        ds_root = Path(ds.local_path)
        fmt = ds.format.value if hasattr(ds.format, "value") else str(ds.format)

        try:
            if fmt == DatasetFormat.yolo.value or fmt == "yolo":
                class_map = YOLOParser.load_class_map(ds_root)
                for row in image_rows:
                    rel_path = Path(row["rel_path"])
                    parts = list(rel_path.parts)

                    candidates: list[Path] = []
                    # Parallel "labels" folder, e.g. the Roboflow layout
                    # train/images/img.jpg -> train/labels/img.txt
                    if "images" in parts:
                        label_parts = list(parts)
                        label_parts[parts.index("images")] = "labels"
                        candidates.append(Path(*label_parts).with_suffix(".txt"))
                    # Otherwise: label file next to the image.
                    candidates.append(rel_path.with_suffix(".txt"))

                    for cand_rel in candidates:
                        label_file = ds_root / cand_rel
                        if label_file.exists():
                            anns = YOLOParser.parse_file(label_file, row["id"], ds.id, class_map)
                            dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in anns]
                            break

            elif fmt == DatasetFormat.coco.value or fmt == "coco":
                jsons = COCOParser.find_annotation_files(ds_root)
                # COCO entries are matched to this page's rows by bare filename.
                img_map = {row["filename"]: row["id"] for row in image_rows}
                for jf in jsons:
                    _, parsed = COCOParser.parse_file(jf, ds.id)
                    for p_rel, _, _, anns in parsed:
                        img_id = img_map.get(Path(p_rel).name)
                        if img_id is not None:
                            dynamic_anns[img_id].extend(_row_to_annotation(a) for a in anns)

            elif fmt == DatasetFormat.voc.value or fmt == "voc":
                for row in image_rows:
                    img_abs = ds_root / row["rel_path"]
                    # XML next to the image first, then the classic
                    # JPEGImages/ -> Annotations/ sibling folder.
                    xml_candidates = [img_abs.with_suffix(".xml")]
                    parts = list(Path(row["rel_path"]).parts)
                    if "JPEGImages" in parts:
                        parts[parts.index("JPEGImages")] = "Annotations"
                        xml_candidates.append(ds_root.joinpath(*parts).with_suffix(".xml"))

                    for cand in xml_candidates:
                        if cand.exists():
                            _, _, _, anns = VOCParser.parse_file(cand, row["id"], ds.id)
                            dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in anns]
                            break

            elif fmt == "csv":
                # Rows on a page usually point at the same CSV file, so each
                # file is parsed once and reused instead of once per row.
                parsed_by_path: dict[Path, list[dict]] = {}
                for row in image_rows:
                    csv_path = ds_root / row["rel_path"]
                    if not csv_path.exists():
                        continue
                    if csv_path not in parsed_by_path:
                        parsed_by_path[csv_path] = list(CSVParser.parse_file(csv_path, ds.id))
                    # "image_id" here is the id of the text entry for this row.
                    dynamic_anns[row["id"]] = [
                        _row_to_annotation(a)
                        for a in parsed_by_path[csv_path]
                        if a["image_id"] == row["id"]
                    ]

        except Exception as e:
            log.error("dynamic_annotation_read_failed", error=str(e), dataset_id=dataset_id)

    images: list[ImageRecord] = [
        ImageRecord(
            image_id=row["id"],
            filename=row["filename"],
            width=row["width"],
            height=row["height"],
            path=row["rel_path"],
            annotations=dynamic_anns.get(row["id"], []),
            split=row["split"],
        )
        for row in image_rows
    ]

    total_pages = max(1, (total + page_size - 1) // page_size)

    return ViewerPage(
        dataset_id=dataset_id,
        page=page,
        page_size=page_size,
        total=total,
        total_pages=total_pages,
        images=images,
    )
254
+
255
+
256
def _row_to_annotation(row: dict) -> Annotation:
    """Build an ``Annotation`` from a flat annotation record.

    ``row`` is a mapping using the ``dataset_annotations`` column names; in
    this module it comes either from a DB row or from one of the annotation
    parsers.

    * A bounding box is created only when ``bbox_x`` is present and not None.
    * ``segmentation`` is JSON-decoded when it is text; a value that is
      already decoded is passed through unchanged. Text that cannot be decoded
      is logged and treated as "no segmentation" rather than failing the row.
    * A missing, NULL or empty ``ann_type`` is treated as ``"detection"``.

    Raises:
        KeyError: If ``label`` (or a ``bbox_*`` value when ``bbox_x`` is set)
            is missing from ``row``.
        ValueError: If ``ann_type`` is not a valid ``AnnotationType`` value.
    """
    import json  # local import: json is not among this module's top-level imports

    bbox = None
    if row.get("bbox_x") is not None:
        bbox = BoundingBox(
            x=row["bbox_x"],
            y=row["bbox_y"],
            width=row["bbox_w"],
            height=row["bbox_h"],
            normalised=bool(row.get("normalised", 1)),
        )

    segmentation = None
    raw_segmentation = row.get("segmentation")
    if raw_segmentation:
        if isinstance(raw_segmentation, (str, bytes, bytearray)):
            try:
                segmentation = json.loads(raw_segmentation)
            except ValueError as e:  # json.JSONDecodeError is a ValueError
                log.warning("segmentation_decode_failed", error=str(e), label=row.get("label"))
        else:
            # Already a decoded structure; keep it instead of dropping it.
            segmentation = raw_segmentation

    # NOTE(review): keypoints/metadata columns exist in dataset_annotations but
    # are not mapped here — confirm the Annotation field types before adding.
    return Annotation(
        label=row["label"],
        bbox=bbox,
        segmentation=segmentation,
        confidence=row.get("confidence"),
        area=row.get("area"),
        # `or` also covers a key that is present but NULL, which .get()'s
        # default alone would not.
        type=AnnotationType(row.get("ann_type") or "detection"),
    )
283
+
284
+
285
async def resolve_image_path(dataset_id: str, image_id: str) -> Path | None:
    """Resolve the filesystem path of one indexed image.

    The path is ``ds.local_path`` joined with the image's ``rel_path`` from
    the ``dataset_images`` table. ``ds.local_path`` is the only root that is
    consulted; there is no fallback location.

    Args:
        dataset_id: Dataset the image must belong to.
        image_id: Id of the row in ``dataset_images``.

    Returns:
        The joined path, or ``None`` when the dataset is unknown or has no
        ``local_path``, the image is not indexed for that dataset, the path
        escapes the dataset root, or no regular file exists at that path.
    """
    ds = await ds_reg.get_dataset(dataset_id)
    if ds is None or not ds.local_path:
        return None

    base_root = Path(ds.local_path)

    # Imported locally, as in the rest of this module.
    from database.connection import get_db
    db = await get_db()
    async with db.execute(
        "SELECT rel_path FROM dataset_images WHERE id=? AND dataset_id=?",
        (image_id, dataset_id),
    ) as cur:
        row = await cur.fetchone()
    if not row:
        return None

    abs_path = base_root / row["rel_path"]

    # Security: verify containment *before* probing the filesystem, so a
    # rel_path that escapes base_root (via "..", an absolute path or a
    # symlink) is rejected and logged whether or not its target exists.
    try:
        abs_path.resolve().relative_to(base_root.resolve())
    except ValueError:
        log.warning("path_traversal_attempt", dataset_id=dataset_id, image_id=image_id)
        return None

    # An image must be a regular file; a directory (e.g. an empty rel_path
    # resolving to the dataset root itself) is not a valid result.
    if not abs_path.is_file():
        return None

    return abs_path
download/__init__.py ADDED
File without changes
download/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file