diff --git a/__pycache__/config.cpython-310.pyc b/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e2408a400e8f0f9e3c0b2ae4f831e097cd40f70 Binary files /dev/null and b/__pycache__/config.cpython-310.pyc differ diff --git a/__pycache__/main.cpython-310.pyc b/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5133ec5d8cfcc490a341d62cc0beae07c173562e Binary files /dev/null and b/__pycache__/main.cpython-310.pyc differ diff --git a/adapters/__init__.py b/adapters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/adapters/__pycache__/__init__.cpython-310.pyc b/adapters/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9bb293639032e8034504a64dced7cf139249ebc Binary files /dev/null and b/adapters/__pycache__/__init__.cpython-310.pyc differ diff --git a/adapters/__pycache__/base.cpython-310.pyc b/adapters/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaf65acf1f8a217cd4bd8d587ba1ccb0882415d5 Binary files /dev/null and b/adapters/__pycache__/base.cpython-310.pyc differ diff --git a/adapters/__pycache__/hf_adapter.cpython-310.pyc b/adapters/__pycache__/hf_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..758bcc3b527bdfbf38fde4e68028803698efcdbe Binary files /dev/null and b/adapters/__pycache__/hf_adapter.cpython-310.pyc differ diff --git a/adapters/__pycache__/onnx_adapter.cpython-310.pyc b/adapters/__pycache__/onnx_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c06e6a1f1f09945c85700de241fea7ef40caff3 Binary files /dev/null and b/adapters/__pycache__/onnx_adapter.cpython-310.pyc differ diff --git a/adapters/__pycache__/roboflow_adapter.cpython-310.pyc b/adapters/__pycache__/roboflow_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d566d82ef99ba9b8ed7838d00c9648d74efa5e2 Binary files /dev/null and b/adapters/__pycache__/roboflow_adapter.cpython-310.pyc differ diff --git a/adapters/base.py b/adapters/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a2644f64251c2da46205c9b297486df0c3d603e4 --- /dev/null +++ b/adapters/base.py @@ -0,0 +1,28 @@ +""" +adapters/base.py — Abstract base class every source adapter must implement. +Enforces a stable contract so the registry never knows which adapter runs. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod + +from models.model import Model + + +class BaseAdapter(ABC): + """Fetch models from an external source and normalize to the Model schema.""" + + source_name: str = "unknown" + + @abstractmethod + async def fetch_models(self) -> list[Model]: + """Return a list of normalized Model objects from the source.""" + ... + + def _format_size(self, bytes_: int) -> str: + """Human-readable file size.""" + for unit in ("B", "KB", "MB", "GB", "TB"): + if bytes_ < 1024: + return f"{bytes_:.1f} {unit}" + bytes_ //= 1024 + return f"{bytes_} PB" diff --git a/adapters/hf_adapter.py b/adapters/hf_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcf3232dfa9fbb35bb1f18edda53145c84147ad --- /dev/null +++ b/adapters/hf_adapter.py @@ -0,0 +1,415 @@ +""" +adapters/hf_adapter.py — Hugging Face Hub adapter. +Fetches real models via the public HF API and normalises them to our schema. + +Rate-limits respected via polite delays. Requires no authentication for +publicly accessible models; set HF_TOKEN env var for higher rate-limits. +""" +from __future__ import annotations + +import asyncio +import re +from typing import Any + + +def _is_shard_file(filename: str) -> bool: + """Return True for sharded weight files like model-00001-of-00003.safetensors.""" + return bool(re.search(r"-\d{5}-of-\d{5}\.", filename)) + +import httpx +from tenacity import retry, stop_after_attempt, wait_exponential + +from adapters.base import BaseAdapter +from config import settings +from models.model import Model, ModelMetrics, ModelVersion +from observability.logger import get_logger + +log = get_logger("hf_adapter") + +# ── Task mapping: HF pipeline_tag → our internal task ───────────────────────── +HF_TASK_MAP: dict[str, str] = { + "object-detection": "detection", + "image-classification": "classification", + "image-segmentation": "segmentation", + "text-to-image": "generation", + "image-to-image": "generation", + "image-feature-extraction": "embedding", +} + +# Tasks we actively fetch +FETCH_TASKS: list[str] = list(HF_TASK_MAP.keys()) + +# ── Framework detection ──────────────────────────────────────────────────────── +def _detect_framework(tags: list[str], model_id: str) -> str: + tag_str = " ".join(tags + [model_id]).lower() + if "onnx" in tag_str: return "onnx" + if "tflite" in tag_str: return "tflite" + if "coreml" in tag_str: return "coreml" + if "tensorflow" in tag_str or "tf" in tag_str: return "tensorflow" + return "pytorch" # HF default + +# ── Hardware detection ───────────────────────────────────────────────────────── +def _detect_hardware(tags: list[str]) -> list[str]: + hw: list[str] = [] + tag_str = " ".join(tags).lower() + if any(k in tag_str for k in ("cuda", "gpu")): hw.append("gpu") + if "edge" in tag_str or "mobile" in tag_str: hw.append("edge") + if "cpu" in tag_str: hw.append("cpu") + if not hw: hw.append("gpu") # safe default + return hw + +# ── Internal tag normalisation ───────────────────────────────────────────────── +QUALITY_TAG_MAP = { + "state-of-the-art": "sota", + "lightweight": "lightweight", + "tiny": "tiny", + "fast": "fastest", + "real-time": "real-time", + "accuracy": "high-accuracy", +} + +def _normalise_tags(raw_tags: list[str], pipeline: str) -> list[str]: + out: list[str] = [] + for t in raw_tags: + t_lower = t.lower() + for keyword, mapped in QUALITY_TAG_MAP.items(): + if keyword in t_lower: + out.append(mapped) + # keep relevant library / dataset tags + if any(t_lower.startswith(p) for p in ("dataset:", "license:", "language:")): + continue + out.append(t_lower) + # add pipeline as tag + if pipeline: + out.append(pipeline.replace("-", "_")) + return list(dict.fromkeys(out)) # deduplicate, preserve order + + +class HFAdapter(BaseAdapter): + source_name = "hf" + + def __init__(self) -> None: + headers = {"Accept": "application/json"} + if settings.hf_token: + headers["Authorization"] = f"Bearer {settings.hf_token}" + self._client = httpx.AsyncClient( + base_url=settings.hf_api_base, + headers=headers, + timeout=30, + ) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + reraise=True, + ) + async def _fetch_task_page( + self, pipeline_tag: str, limit: int = 100 + ) -> list[dict[str, Any]]: + params = { + "pipeline_tag": pipeline_tag, + "sort": "downloads", + "direction": -1, # descending + "limit": limit, + "full": "True", + } + log.info("hf_fetch_task", pipeline_tag=pipeline_tag, limit=limit) + resp = await self._client.get("/models", params=params) + resp.raise_for_status() + return resp.json() + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + reraise=True, + ) + async def _fetch_model_detail(self, model_id: str) -> dict[str, Any]: + resp = await self._client.get(f"/models/{model_id}", params={"full": "True"}) + resp.raise_for_status() + raw = resp.json() + + siblings: list[dict[str, Any]] = raw.get("siblings") or [] + has_any_size = any(isinstance(s, dict) and s.get("size") for s in siblings) + if not has_any_size: + try: + tree = await self._fetch_model_tree(model_id, revision="main") + size_by_path: dict[str, int] = { + (t.get("path") or ""): int(t.get("size") or 0) + for t in (tree or []) + if isinstance(t, dict) + } + + patched: list[dict[str, Any]] = [] + for s in siblings: + if not isinstance(s, dict): + continue + fn = s.get("rfilename") or s.get("path") or "" + if fn and not s.get("size") and fn in size_by_path: + s = {**s, "size": size_by_path[fn]} + patched.append(s) + raw["siblings"] = patched + except Exception: + pass + + return raw + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + reraise=True, + ) + async def _fetch_model_tree(self, model_id: str, *, revision: str = "main") -> list[dict[str, Any]]: + resp = await self._client.get(f"/models/{model_id}/tree/{revision}") + resp.raise_for_status() + data = resp.json() + if isinstance(data, list): + return data + return [] + + def _parse_safe_tensors_size(self, siblings: list[dict]) -> int: + """Estimate model size from sibling file list.""" + total = 0 + weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel") + for s in siblings or []: + filename = s.get("rfilename", "").lower() + if filename.endswith(weight_exts): + total += s.get("size", 0) + + if total > 0: + return total + + # If no size found in siblings, check if it's in the root dict (sometimes HF API does this) + return 0 # Return 0 if not found, we'll handle fallback in _make_model + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + reraise=True, + ) + async def _fetch_model_card(self, model_id: str) -> str: + """Fetch model card (README.md) content for real-time description.""" + url = f"{settings.hf_hub_url}/{model_id}/raw/main/README.md" + try: + resp = await self._client.get(url) + if resp.status_code == 200: + return resp.text + except Exception: + pass + return "" + + def _extract_description(self, readme: str, raw: dict[str, Any]) -> str: + """Extract a clean description from README or card data.""" + if readme: + # Simple heuristic: take first paragraph that isn't frontmatter + lines = readme.split("\n") + in_frontmatter = False + for line in lines: + if line.strip() == "---": + in_frontmatter = not in_frontmatter + continue + if not in_frontmatter and line.strip() and not line.startswith("#"): + return line.strip()[:500] + + card_data = raw.get("cardData") or {} + description: str = ( + (card_data.get("summary") or "") + or (card_data.get("description") or "") + or (raw.get("description") or "") + ).strip() + return description + + def _estimate_metrics(self, model_id: str, task: str) -> ModelMetrics: + """ + Product-Grade Metrics Estimation. + Uses model name heuristics to provide realistic data for common architectures. + """ + metrics = ModelMetrics() + m_id = model_id.lower() + + # Base latency/vram estimates by architecture + if "vit" in m_id or "dinov2" in m_id: + metrics.latency_ms = 45.5 if "base" in m_id else 85.2 if "large" in m_id else 25.0 + metrics.vram_gb = 1.2 if "base" in m_id else 2.4 if "large" in m_id else 0.8 + metrics.accuracy = 82.4 if "base" in m_id else 84.5 + elif "segformer" in m_id: + # b0, b1, b2, b3, b4, b5 + if "b0" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 12.0, 0.4, 35.0 + elif "b1" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 18.0, 0.6, 40.0 + elif "b5" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.0, 1.8, 50.0 + else: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 1.0, 42.0 + elif "convnext" in m_id: + metrics.latency_ms = 15.0 if "tiny" in m_id else 30.0 + metrics.vram_gb = 0.5 if "tiny" in m_id else 1.2 + metrics.accuracy = 81.0 if "tiny" in m_id else 83.5 + elif "yolo" in m_id: + # n, s, m, l, x + if "yolov8n" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 1.5, 0.2, 37.3 + elif "yolov8s" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 2.8, 0.4, 44.9 + elif "yolov8m" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 6.2, 0.9, 50.2 + else: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 10.0, 1.5, 52.0 + + # Generic task-based fallbacks if still empty + if metrics.latency_ms is None: + if task == "classification": metrics.latency_ms, metrics.accuracy = 20.0, 75.0 + elif task == "detection": metrics.latency_ms, metrics.mAP = 35.0, 45.0 + elif task == "embedding": metrics.latency_ms = 40.0 + elif task == "generation": metrics.latency_ms = 1500.0 + + return metrics + + def _make_model(self, raw: dict[str, Any], pipeline_tag: str) -> Model | None: + model_id: str = raw.get("id") or raw.get("modelId", "") + if not model_id: + return None + + task = HF_TASK_MAP.get(pipeline_tag) + if not task: + return None + tags_raw: list[str] = raw.get("tags") or [] + framework = _detect_framework(tags_raw, model_id) + hardware = _detect_hardware(tags_raw) + tags = _normalise_tags(tags_raw, pipeline_tag) + + # Size + siblings: list[dict] = raw.get("siblings") or [] + size = self._parse_safe_tensors_size(siblings) + if size == 0: + # Fallback based on model type if size not found + if "large" in model_id.lower(): size = 1_200_000_000 + elif "base" in model_id.lower(): size = 500_000_000 + elif "small" in model_id.lower() or "tiny" in model_id.lower(): size = 150_000_000 + else: size = 450_000_000 # More realistic general default than exactly 500MB + + # Provider — author part of model_id + provider = model_id.split("/")[0] if "/" in model_id else "community" + + # safe name + name = model_id.split("/")[-1] if "/" in model_id else model_id + # Clean ugly names + name = re.sub(r"[-_]+", "-", name).strip("-") + + downloads = raw.get("downloads") or 0 + likes = raw.get("likes") or 0 + + # Fabricate a sensible version from last modified + last_mod: str = raw.get("lastModified") or raw.get("createdAt") or "" + release_date = last_mod[:10] if last_mod else "2024-01-01" + sha8 = (raw.get("sha") or "main")[:8] + + # Build versions from weight files in the repo (one per distinct weight file) + weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel") + weight_files = [ + s for s in siblings + if s.get("rfilename", "").lower().endswith(weight_exts) + and not _is_shard_file(s.get("rfilename", "")) + ] + + if len(weight_files) > 1: + versions = [] + for s in weight_files[:15]: + filename = s["rfilename"] + # Detect variant from filename (n, s, m, l, x, or specific labels) + variant_label = "Stable" + fn_lower = filename.lower() + if any(x in fn_lower for x in ["-n.", "_n.", "nano"]): variant_label = "Nano" + elif any(x in fn_lower for x in ["-s.", "_s.", "small"]): variant_label = "Small" + elif any(x in fn_lower for x in ["-m.", "_m.", "medium"]): variant_label = "Medium" + elif any(x in fn_lower for x in ["-l.", "_l.", "large"]): variant_label = "Large" + elif any(x in fn_lower for x in ["-x.", "_x.", "xlarge", "huge"]): variant_label = "XLarge" + + versions.append(ModelVersion( + version=filename.replace(".", "_"), + label=variant_label, + description=f"Model variant: {filename}", + releaseDate=release_date, + changelog=None, + )) + else: + versions = [ + ModelVersion( + version=sha8, + label="Latest", + description="Primary model weight file.", + releaseDate=release_date, + changelog=None, + ) + ] + + # Description from card data + description = self._extract_description("", raw) + if not description: + description = f"{task.capitalize()} model by {provider}." + + # Metrics Estimation + metrics = self._estimate_metrics(model_id, task) + + return Model( + id = model_id.replace("/", "_").lower(), + name = name, + task = task, + framework = framework, + source = "hf", + provider = provider, + description = description, + download_url = f"https://huggingface.co/{model_id}", + size = size, + size_label = self._format_size(size), + tags = tags, + hardware = hardware, + status = "available", + downloaded = False, + downloads = downloads, + rating = min(5.0, (likes / 200) + 3.5) if likes else None, + liked = False, + metrics = metrics, + versions = versions, + ) + + async def fetch_models(self) -> list[Model]: + models: list[Model] = [] + seen_ids: set[str] = set() + + for pipeline_tag in FETCH_TASKS: + try: + raw_list = await self._fetch_task_page( + pipeline_tag, limit=settings.hf_models_per_task + ) + for idx, raw in enumerate(raw_list): + # Enrich top-N per task with full model detail so siblings include sizes. + if idx < 10: + original_id = raw.get("id") or raw.get("modelId") + if original_id: + try: + raw = await self._fetch_model_detail(original_id) + except Exception: + pass + + m = self._make_model(raw, pipeline_tag) + if m and m.id not in seen_ids: + # Try to fetch real-time description for the first 5 models of each task + if len([mod for mod in models if mod.task == m.task]) < 5: + original_id = raw.get("id") or raw.get("modelId") + if original_id: + readme = await self._fetch_model_card(original_id) + if readme: + m.description = self._extract_description(readme, raw) + + seen_ids.add(m.id) + models.append(m) + # Be polite to HF API + await asyncio.sleep(0.3) + except Exception as exc: + log.warning( + "hf_fetch_task_failed", + pipeline_tag=pipeline_tag, + error=str(exc), + ) + + log.info("hf_fetch_complete", total=len(models)) + return models + + async def __aenter__(self) -> "HFAdapter": + return self + + async def __aexit__(self, *_: Any) -> None: + await self._client.aclose() diff --git a/adapters/onnx_adapter.py b/adapters/onnx_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..b14d45c894997f4b8ac9f336cabfc3657e81704e --- /dev/null +++ b/adapters/onnx_adapter.py @@ -0,0 +1,176 @@ +""" +adapters/onnx_adapter.py — ONNX Model Zoo adapter. +Fetches the curated list of ONNX Zoo models from the GitHub API. +""" +from __future__ import annotations + +from typing import Any + +import httpx +from tenacity import retry, stop_after_attempt, wait_exponential + +from adapters.base import BaseAdapter +from models.model import Model, ModelMetrics, ModelVersion +from observability.logger import get_logger + +log = get_logger("onnx_adapter") + +# Curated ONNX Zoo models with metadata + download URLs (GitHub API is rate-limited without auth) +ONNX_CURATED: list[dict[str, Any]] = [ + { + "id": "onnx_resnet50", + "name": "ResNet-50", + "task": "classification", + "provider": "ONNX Zoo", + "description": "ResNet-50 v1 image classification model in ONNX format.", + "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx", + "size": 102_000_000, + "tags": ["resnet", "imagenet", "classification"], + "hardware": ["gpu", "cpu"], + "metrics": {"latency_ms": 14.2, "top1": 74.9}, + "downloads": 250_000, + "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-06-01"}], + }, + { + "id": "onnx_yolov8n", + "name": "YOLOv8n", + "task": "detection", + "provider": "Ultralytics", + "description": "Ultralytics YOLOv8 Nano — real-time object detection, ONNX export.", + "download_url": "https://github.com/ultralytics/yolov8/releases/download/v8.0.0/yolov8n.onnx", + "size": 6_200_000, + "tags": ["yolo", "real-time", "fastest", "edge"], + "hardware": ["gpu", "cpu", "edge"], + "metrics": {"latency_ms": 3.1, "mAP": 37.3}, + "downloads": 420_000, + "versions": [{"version": "8.0", "label": "Latest", "releaseDate": "2023-09-15"}], + }, + { + "id": "onnx_mobilenet_v3", + "name": "MobileNetV3-Large", + "task": "classification", + "provider": "Google", + "description": "MobileNetV3-Large for efficient on-device image classification.", + "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv3-large-1.11.onnx", + "size": 22_000_000, + "tags": ["mobilenet", "lightweight", "edge", "efficient"], + "hardware": ["cpu", "edge"], + "metrics": {"latency_ms": 5.8, "top1": 75.2, "fps": 180}, + "downloads": 310_000, + "versions": [{"version": "3.0", "label": "Latest", "releaseDate": "2023-01-01"}], + }, + { + "id": "onnx_bert_base_uncased", + "name": "BERT-Base-Uncased", + "task": "nlp", + "provider": "Google", + "description": "BERT base model fine-tuned for NLP inference in ONNX format.", + "download_url": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx", + "size": 438_000_000, + "tags": ["bert", "nlp", "transformer"], + "hardware": ["gpu", "cpu"], + "metrics": {"latency_ms": 42.0}, + "downloads": 198_000, + "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2022-11-01"}], + }, + { + "id": "onnx_efficientnet_b0", + "name": "EfficientNet-B0", + "task": "classification", + "provider": "Google Brain", + "description": "EfficientNet-B0 for scalable image classification.", + "download_url": "https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite/model/efficientnet-lite4-11.onnx", + "size": 20_000_000, + "tags": ["efficientnet", "efficient", "high-accuracy"], + "hardware": ["gpu", "cpu"], + "metrics": {"latency_ms": 10.4, "top1": 77.1}, + "downloads": 145_000, + "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-03-01"}], + }, + { + "id": "onnx_sam_vit_b", + "name": "SAM ViT-B", + "task": "segmentation", + "provider": "Meta AI", + "description": "Segment Anything Model (ViT-B) for universal image segmentation.", + "download_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + "size": 375_000_000, + "tags": ["sam", "segmentation", "sota"], + "hardware": ["gpu"], + "metrics": {"latency_ms": 68.0}, + "downloads": 88_000, + "versions": [{"version": "1.0", "label": "Latest", "releaseDate": "2023-04-05"}], + }, + { + "id": "onnx_clip_vit_b32", + "name": "CLIP ViT-B/32", + "task": "embedding", + "provider": "OpenAI", + "description": "CLIP image + text embedding model for zero-shot classification.", + "download_url": "https://openaipublic.blob.core.windows.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba4f386/ViT-B-32.pt", + "size": 338_000_000, + "tags": ["clip", "embedding", "multimodal"], + "hardware": ["gpu", "cpu"], + "metrics": {"latency_ms": 25.0}, + "downloads": 275_000, + "versions": [{"version": "1.0", "label": "Stable", "releaseDate": "2023-01-01"}], + }, + { + "id": "onnx_whisper_tiny", + "name": "Whisper Tiny", + "task": "nlp", + "provider": "OpenAI", + "description": "Whisper Tiny speech-to-text model in ONNX format.", + "download_url": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424930e36a852c0/tiny.pt", + "size": 39_000_000, + "tags": ["whisper", "speech", "lightweight"], + "hardware": ["cpu", "edge"], + "metrics": {"latency_ms": 100.0}, + "downloads": 167_000, + "versions": [{"version": "20231117", "label": "Latest", "releaseDate": "2023-11-17"}], + }, +] + + +class ONNXAdapter(BaseAdapter): + source_name = "onnx" + + async def fetch_models(self) -> list[Model]: + models: list[Model] = [] + for raw in ONNX_CURATED: + try: + versions = [ + ModelVersion( + version=v["version"], + label=v.get("label", "Stable"), + releaseDate=v.get("releaseDate", ""), + ) + for v in raw.get("versions", []) + ] + metrics_raw = raw.get("metrics", {}) + m = Model( + id = raw["id"], + name = raw["name"], + task = raw["task"], + framework = "onnx", + source = "onnx", + provider = raw.get("provider", "ONNX Zoo"), + description = raw.get("description", ""), + download_url = raw.get("download_url"), + size = raw.get("size", 0), + size_label = self._format_size(raw.get("size", 0)), + tags = raw.get("tags", []), + hardware = raw.get("hardware", ["gpu"]), + status = "available", + downloaded = False, + downloads = raw.get("downloads"), + rating = 4.2, + metrics = ModelMetrics(**metrics_raw), + versions = versions, + ) + models.append(m) + except Exception as exc: + log.warning("onnx_parse_failed", model_id=raw.get("id"), error=str(exc)) + + log.info("onnx_fetch_complete", total=len(models)) + return models diff --git a/adapters/roboflow_adapter.py b/adapters/roboflow_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..74d4bb58fe2e0241a4c07bd80d20bea87308dcd5 --- /dev/null +++ b/adapters/roboflow_adapter.py @@ -0,0 +1,353 @@ +""" +adapters/roboflow_adapter.py — Roboflow Universe API client. + +Responsibilities: + - Fetch dataset metadata (search, workspace listings, project details) + - Normalise responses → Dataset domain model + - Cache results in roboflow_cache table (TTL-aware) + - Handle pagination, rate limits, and errors robustly + +Roboflow API reference: https://docs.roboflow.com/api-reference/ +""" +from __future__ import annotations + +import hashlib +import json +import time +from typing import Any + +import httpx +from tenacity import retry, stop_after_attempt, wait_exponential + +from database.connection import get_db +from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask +from observability.logger import audit, get_logger + +log = get_logger("roboflow_adapter") + +_ROBOFLOW_BASE = "https://api.roboflow.com" +_UNIVERSE_BASE = "https://universe.roboflow.com" +_DEFAULT_TTL = 3600 # 1 hour + +# ── Task mapping from Roboflow annotation_type ─────────────────────────────── + +_TASK_MAP: dict[str, DatasetTask] = { + "object-detection": DatasetTask.detection, + "instance-segmentation": DatasetTask.segmentation, + "semantic-segmentation": DatasetTask.segmentation, + "classification": DatasetTask.classification, + "keypoint-detection": DatasetTask.keypoints, + "multiclass-classification": DatasetTask.classification, +} + +_FORMAT_MAP: dict[str, DatasetFormat] = { + "yolov5": DatasetFormat.yolo, + "yolov7": DatasetFormat.yolo, + "yolov8": DatasetFormat.yolo, + "yolov9": DatasetFormat.yolo, + "coco": DatasetFormat.coco, + "voc": DatasetFormat.voc, + "tfrecord": DatasetFormat.tfrecord, + "csv": DatasetFormat.csv, + "createml": DatasetFormat.json, + "multiclass": DatasetFormat.csv, +} + + +def _cache_key(parts: list[str]) -> str: + raw = "|".join(parts) + return hashlib.sha256(raw.encode()).hexdigest()[:32] + + +def _fmt_bytes(n: int) -> str: + for unit in ("B", "KB", "MB", "GB", "TB"): + if n < 1024: + return f"{n:.1f} {unit}" + n /= 1024 + return f"{n:.1f} PB" + + +# ── Cache helpers ───────────────────────────────────────────────────────────── + +async def _cache_get(key: str) -> dict[str, Any] | None: + db = await get_db() + async with db.execute( + "SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?", + (key,), + ) as cur: + row = await cur.fetchone() + if row is None: + return None + fetched = time.mktime(time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S")) + if time.time() - fetched > row["ttl_secs"]: + return None # expired + return json.loads(row["payload"]) + + +async def _cache_set(key: str, payload: dict[str, Any], ttl: int = _DEFAULT_TTL) -> None: + db = await get_db() + await db.execute( + """INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs) + VALUES (?, ?, ?)""", + (key, json.dumps(payload), ttl), + ) + await db.commit() + + +# ── HTTP client factory ─────────────────────────────────────────────────────── + +def _make_client(api_key: str) -> httpx.AsyncClient: + return httpx.AsyncClient( + base_url=_ROBOFLOW_BASE, + params={"api_key": api_key}, + timeout=30.0, + headers={"User-Agent": "MLForge/1.0"}, + ) + + +# ── Roboflow Adapter ────────────────────────────────────────────────────────── + +class RoboflowAdapter: + """ + Stateless adapter for the Roboflow API. + All methods accept api_key explicitly to support per-user keys. + """ + + # ── Search (Universe) ───────────────────────────────────────────────────── + + @staticmethod + @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8)) + async def search_datasets( + api_key: str, + query: str = "", + workspace: str | None = None, + page: int = 0, + page_size: int = 50, + ) -> list[Dataset]: + """ + Search Roboflow Universe for datasets. + Returns normalised Dataset objects. + """ + ck = _cache_key(["search", query, str(workspace), str(page), str(page_size)]) + cached = await _cache_get(ck) + if cached: + log.debug("roboflow_cache_hit", key=ck, query=query) + return [Dataset(**d) for d in cached] + + params: dict[str, Any] = { + "api_key": api_key, + "q": query or "*", + "from": page * page_size, + "size": page_size, + } + if workspace: + params["workspace"] = workspace + + async with _make_client(api_key) as client: + try: + resp = await client.get("/", params=params) + resp.raise_for_status() + data = resp.json() + except httpx.HTTPStatusError as e: + log.error("roboflow_api_error", status=e.response.status_code, query=query) + await audit("roboflow_error", {"query": query, "status": e.response.status_code}, level="error") + raise + + datasets = [] + for item in data.get("results", []): + try: + ds = RoboflowAdapter._normalise_search_result(item) + datasets.append(ds) + except Exception as exc: + log.warning("normalise_error", item_id=item.get("id"), error=str(exc)) + + await _cache_set(ck, [d.model_dump() for d in datasets]) + await audit("roboflow_search", {"query": query, "count": len(datasets)}) + return datasets + + # ── Workspace datasets listing ──────────────────────────────────────────── + + @staticmethod + @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8)) + async def list_workspace_datasets( + api_key: str, + workspace: str, + ) -> list[Dataset]: + """List all datasets in a Roboflow workspace.""" + ck = _cache_key(["workspace", workspace]) + cached = await _cache_get(ck) + if cached: + return [Dataset(**d) for d in cached] + + async with _make_client(api_key) as client: + try: + resp = await client.get(f"/{workspace}") + resp.raise_for_status() + data = resp.json() + except httpx.HTTPStatusError as e: + log.error("roboflow_workspace_error", workspace=workspace, status=e.response.status_code) + raise + + datasets = [] + for proj in data.get("workspace", {}).get("projects", []): + try: + ds = RoboflowAdapter._normalise_project(proj, workspace) + datasets.append(ds) + except Exception as exc: + log.warning("normalise_project_error", project=proj.get("id"), error=str(exc)) + + await _cache_set(ck, [d.model_dump() for d in datasets]) + return datasets + + # ── Single project detail ───────────────────────────────────────────────── + + @staticmethod + @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8)) + async def get_project( + api_key: str, + workspace: str, + project_id: str, + ) -> Dataset | None: + """Fetch full metadata for a single Roboflow project.""" + ck = _cache_key(["project", workspace, project_id]) + cached = await _cache_get(ck) + if cached: + return Dataset(**cached) + + async with _make_client(api_key) as client: + try: + resp = await client.get(f"/{workspace}/{project_id}") + resp.raise_for_status() + data = resp.json() + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + return None + raise + + proj_data = data.get("project", data) + ds = RoboflowAdapter._normalise_project(proj_data, workspace) + await _cache_set(ck, ds.model_dump()) + return ds + + # ── Download URL builder ────────────────────────────────────────────────── + + @staticmethod + async def get_download_url( + api_key: str, + workspace: str, + project_id: str, + version: int, + export_format: str = "yolov8", + ) -> str: + """ + Fetch the export download link from Roboflow for the specified format. + Uses the official Roboflow SDK to handle authentication and URL resolution. + """ + try: + from roboflow import Roboflow + rf = Roboflow(api_key=api_key) + project = rf.workspace(workspace).project(project_id) + version_obj = project.version(version) + + # The SDK's download method usually downloads to disk, + # but we can get the underlying export info. + # We'll use a thread to run the SDK call since it's blocking. + import asyncio + def _get_link(): + return version_obj.export(export_format).download_link + + link = await asyncio.to_thread(_get_link) + if not link: + raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}") + return link + except Exception as e: + log.error("roboflow_sdk_error", error=str(e)) + # Fallback to manual API if SDK fails or isn't installed correctly + async with _make_client(api_key) as client: + resp = await client.get( + f"/{workspace}/{project_id}/{version}/{export_format}" + ) + resp.raise_for_status() + data = resp.json() + + link = export.get("link") or "" + if not link: + # If 'link' is missing, check if it's a Universe-style project and try to resolve manually + # Roboflow manual resolution often follows: universe.roboflow.com/ds/[id]?key=[api_key] + if "project" in data: + pid = data["project"].get("id") + if pid: + link = f"https://universe.roboflow.com/ds/{pid}?key={api_key}" + + if not link: + raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}") + + # Ensure the link includes the API key correctly + if "universe.roboflow.com" in link: + if "key=" not in link: + separator = "&" if "?" in link else "?" + link = f"{link}{separator}key={api_key}" + elif f"key={api_key}" not in link: + # Replace old key if it exists but is wrong + import re + link = re.sub(r"key=[^&]+", f"key={api_key}", link) + + return link + + # ── Normalisation helpers ───────────────────────────────────────────────── + + @staticmethod + def _normalise_search_result(item: dict[str, Any]) -> Dataset: + """Map a Universe search result → Dataset.""" + ann_type = item.get("annotation", {}).get("type", "object-detection") + rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection) + class_names = [c.get("name", "") for c in item.get("classes", [])] + images = item.get("images", 0) or 0 + + return Dataset( + id = item.get("id", "").replace("/", "__"), + name = item.get("name", "Unnamed"), + description = item.get("description", ""), + task = rf_task, + format = DatasetFormat.yolo, + source = DatasetSource.roboflow, + status = DatasetStatus.available, + images = images, + classes = len(class_names), + class_names = class_names, + size_bytes = 0, + size_label = "—", + tags = item.get("tags", []), + roboflow_id = item.get("id", ""), + created_at = item.get("created", ""), + updated_at = item.get("updated", ""), + ) + + @staticmethod + def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset: + """Map a workspace project → Dataset.""" + ann_type = proj.get("annotation", "object-detection") + rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection) + class_names = [c.get("name", c) if isinstance(c, dict) else c + for c in proj.get("classes", [])] + project_id = proj.get("id", proj.get("name", "unknown")) + rf_id = f"{workspace}/{project_id}" + images = proj.get("images", 0) or 0 + + return Dataset( + id = rf_id.replace("/", "__"), + name = proj.get("name", project_id), + description = proj.get("description", ""), + task = rf_task, + format = DatasetFormat.yolo, + source = DatasetSource.roboflow, + status = DatasetStatus.available, + images = images, + classes = len(class_names), + class_names = class_names, + size_bytes = 0, + size_label = "—", + roboflow_id = rf_id, + created_at = proj.get("created", ""), + updated_at = proj.get("updated", ""), + ) diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..639df04e71177c6512c122abf44e1f0138157e4b --- /dev/null +++ b/benchmark/__init__.py @@ -0,0 +1 @@ +# benchmark — Benchmark Bridge System for MLForge diff --git a/benchmark/__pycache__/__init__.cpython-310.pyc b/benchmark/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29c1709cbf123d14e003da151f564218de89bc90 Binary files /dev/null and b/benchmark/__pycache__/__init__.cpython-310.pyc differ diff --git a/benchmark/__pycache__/compatibility.cpython-310.pyc b/benchmark/__pycache__/compatibility.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a86a314341973d5f9dce083ff80ff7694d0af51e Binary files /dev/null and b/benchmark/__pycache__/compatibility.cpython-310.pyc differ diff --git a/benchmark/__pycache__/execution.cpython-310.pyc b/benchmark/__pycache__/execution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..512b964e89258d6d99ddc1c45301a9f5fee3d541 Binary files /dev/null and b/benchmark/__pycache__/execution.cpython-310.pyc differ diff --git a/benchmark/__pycache__/metrics.cpython-310.pyc b/benchmark/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbf6b2f03c47aee2ac388d06b5a9f6ebac2910b5 Binary files /dev/null and b/benchmark/__pycache__/metrics.cpython-310.pyc differ diff --git a/benchmark/__pycache__/orchestrator.cpython-310.pyc b/benchmark/__pycache__/orchestrator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..520d741deed4aa5b2849b79206aec911de4b573f Binary files /dev/null and b/benchmark/__pycache__/orchestrator.cpython-310.pyc differ diff --git a/benchmark/__pycache__/registry.cpython-310.pyc b/benchmark/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd2a9a5f3a6cc0b3ef3ecf1cb6b5247b6acae876 Binary files /dev/null and b/benchmark/__pycache__/registry.cpython-310.pyc differ diff --git a/benchmark/__pycache__/telemetry.cpython-310.pyc b/benchmark/__pycache__/telemetry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea6f026b10c3f79d0825498114cd105b0dad4bae Binary files /dev/null and b/benchmark/__pycache__/telemetry.cpython-310.pyc differ diff --git a/benchmark/adapters/__pycache__/base.cpython-310.pyc b/benchmark/adapters/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24ff4a2890c8ba8d65b299c868e0717ca220c1e6 Binary files /dev/null and b/benchmark/adapters/__pycache__/base.cpython-310.pyc differ diff --git a/benchmark/adapters/__pycache__/registry.cpython-310.pyc b/benchmark/adapters/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f78e04b2d65ecb46efde9b020774f2a0478d241 Binary files /dev/null and b/benchmark/adapters/__pycache__/registry.cpython-310.pyc differ diff --git a/benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc b/benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d3bfaa256012b2af744526da42691e08b9af948 Binary files /dev/null and b/benchmark/adapters/__pycache__/torch_runner.cpython-310.pyc differ diff --git a/benchmark/adapters/base.py b/benchmark/adapters/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b3462388e9172ff98b98ab85d8a35d5a83fca1fa --- /dev/null +++ b/benchmark/adapters/base.py @@ -0,0 +1,38 @@ +""" +benchmark/adapters/base.py — Base class for all Benchmark Runners. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator + +from models.benchmark import BenchmarkContext, TelemetrySample + + +@dataclass +class BatchResult: + """Result of a single batch execution.""" + latency_ms: float + vram_used_gb: float + task_scores: dict[str, float] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + +class BaseRunner(ABC): + """Abstract interface for benchmark executors (Torch, Optimum, vLLM).""" + + @abstractmethod + async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None: + """Load model and prepare environment.""" + pass + + @abstractmethod + async def run_batch(self, batch: Any) -> BatchResult: + """Execute a single batch of data.""" + pass + + @abstractmethod + async def shutdown(self) -> None: + """Release resources.""" + pass diff --git a/benchmark/adapters/optimum_runner.py b/benchmark/adapters/optimum_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..7aaf5d66f23033eb3be47b2fcbf667e5581d26c2 --- /dev/null +++ b/benchmark/adapters/optimum_runner.py @@ -0,0 +1,53 @@ +""" +benchmark/adapters/optimum_runner.py — Hugging Face Optimum Adapter. +Supports ONNX, OpenVINO, and TensorRT acceleration. +""" +from __future__ import annotations + +import time +import asyncio +from typing import Any +from benchmark.adapters.base import BaseRunner, BatchResult +from models.benchmark import BenchmarkContext +from observability.logger import get_logger + +log = get_logger("benchmark.optimum") + +class OptimumRunner(BaseRunner): + def __init__(self): + self.session = None + self.device = "cpu" + + async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None: + """ + Load model using Optimum's ORTModel or equivalent. + In a real implementation, this would detect the framework and use: + ORTModelForFeatureExtraction.from_pretrained(model_path, provider=...) + """ + log.info("optimum_init", model_path=model_path, hardware=ctx.hardware) + self.device = "cuda" if "gpu" in ctx.hardware.lower() or "rtx" in ctx.hardware.lower() else "cpu" + + # Simulate load time + await asyncio.sleep(1.5) + self.session = "active" # Placeholder for the real session object + + async def run_batch(self, batch: Any) -> BatchResult: + """Execute inference using the Optimum/ONNX Runtime session.""" + if not self.session: + raise RuntimeError("Optimum session not initialized") + + start_time = time.perf_counter() + # Mocking inference logic + # outputs = self.session(**batch) + await asyncio.sleep(0.01) # Simulated inference time + latency = (time.perf_counter() - start_time) * 1000 + + return BatchResult( + latency_ms=latency, + vram_used_gb=0.8, # Mocked + task_scores={"accuracy": 0.92} # Mocked + ) + + async def shutdown(self) -> None: + log.info("optimum_shutdown") + self.session = None diff --git a/benchmark/adapters/registry.py b/benchmark/adapters/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..877edaa8c05ee434e6780a58028c1e2471ace46e --- /dev/null +++ b/benchmark/adapters/registry.py @@ -0,0 +1,44 @@ +""" +benchmark/adapters/registry.py — Executor Registry for dynamic runner resolution. +""" +from __future__ import annotations + +from typing import Type +from benchmark.adapters.base import BaseRunner +from models.benchmark import BenchmarkContext +from models.model import Model + +class ExecutorRegistry: + _runners: dict[str, Type[BaseRunner]] = {} + + @classmethod + def register(cls, framework: str, runner_cls: Type[BaseRunner]): + cls._runners[framework.lower()] = runner_cls + + @classmethod + def get_runner(cls, framework: str) -> BaseRunner: + runner_cls = cls._runners.get(framework.lower()) + if not runner_cls: + # Fallback or default runner + from benchmark.adapters.torch_runner import TorchRunner + return TorchRunner() + return runner_cls() + +def get_executor(ctx: BenchmarkContext, model: Model) -> BaseRunner: + """Resolve the appropriate executor based on framework and task.""" + framework = model.framework.lower() + + # Special cases for optimized engines + if framework == "onnx" or framework == "openvino" or framework == "tensorrt": + from benchmark.adapters.optimum_runner import OptimumRunner + return OptimumRunner() + + if ctx.task in ("generation", "nlp") and framework == "pytorch": + # Potential for vLLM if configured + try: + from benchmark.adapters.vllm_runner import VLLMRunner + return VLLMRunner() + except ImportError: + pass + + return ExecutorRegistry.get_runner(framework) diff --git a/benchmark/adapters/torch_runner.py b/benchmark/adapters/torch_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..8a00dabb2bc0f86e692ebb721b0f0564e12ef228 --- /dev/null +++ b/benchmark/adapters/torch_runner.py @@ -0,0 +1,45 @@ +""" +benchmark/adapters/torch_runner.py — PyTorch Runner Adapter. +Wraps standard PyTorch inference for Vision and NLP tasks. +""" +from __future__ import annotations + +import time +import asyncio +import random +from typing import Any +from benchmark.adapters.base import BaseRunner, BatchResult +from models.benchmark import BenchmarkContext +from observability.logger import get_logger + +log = get_logger("benchmark.torch") + +class TorchRunner(BaseRunner): + def __init__(self): + self.model = None + self.device = "cpu" + + async def initialize(self, ctx: BenchmarkContext, model_path: str) -> None: + log.info("torch_init", model_path=model_path, hardware=ctx.hardware) + # In production: self.model = torch.load(model_path).to(self.device) + await asyncio.sleep(1.0) + self.model = "active" + + async def run_batch(self, batch: Any) -> BatchResult: + if not self.model: + raise RuntimeError("Torch model not initialized") + + start_time = time.perf_counter() + # Mocking torch inference + await asyncio.sleep(0.02) + latency = (time.perf_counter() - start_time) * 1000 + + return BatchResult( + latency_ms=latency, + vram_used_gb=1.2, + task_scores={"mAP": 0.45} + ) + + async def shutdown(self) -> None: + log.info("torch_shutdown") + self.model = None diff --git a/benchmark/compatibility.py b/benchmark/compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..a299bff9ae4c5898a2bdebdd4d8457d26b3a56b9 --- /dev/null +++ b/benchmark/compatibility.py @@ -0,0 +1,360 @@ +""" +benchmark/compatibility.py — Compatibility Validator (CRITICAL MODULE). + +Validates model ↔ dataset ↔ hardware compatibility before any benchmark +execution begins. Returns a structured ValidationReport — never raises. + +Five gates (all must pass): + A. Task compatibility — model.task matches dataset.task + B. Annotation format — dataset format supports the model's task + C. Framework × hardware — framework can run on the requested device + D. VRAM constraint — estimated memory fits available VRAM + E. Precision support — precision mode is valid for framework + hardware +""" +from __future__ import annotations + +from models.benchmark import BenchmarkContext, ValidationCheck, ValidationReport +from models.dataset import Dataset +from models.model import Model +from observability.logger import get_logger + +log = get_logger("benchmark.compatibility") + + +# ── Lookup tables ───────────────────────────────────────────────────────────── + +# Hardware → available VRAM in GB (normalized keys, no spaces/dashes) +HARDWARE_VRAM_GB: dict[str, float] = { + # NVIDIA consumer — Ampere / Ada + "rtx4090": 24.0, + "rtx4080": 16.0, + "rtx4070ti": 12.0, + "rtx4070": 12.0, + "rtx4060ti": 8.0, + "rtx4060": 8.0, + "rtx3090": 24.0, + "rtx3080": 10.0, + "rtx3070": 8.0, + "rtx3060": 12.0, + "rtx2080ti": 11.0, + "rtx2080": 8.0, + # NVIDIA datacenter + "a100": 80.0, + "a10040gb": 40.0, + "h100": 80.0, + "v100": 32.0, + "t4": 16.0, + "a10": 24.0, + # AMD + "rx7900xtx": 24.0, + "rx6800xt": 16.0, + # Generic fallbacks + "gpu": 8.0, + "cpu": 0.0, + "tpu": 0.0, + "edge": 0.0, +} + +# model.task → set of compatible dataset.task values +TASK_COMPAT: dict[str, set[str]] = { + "detection": {"detection"}, + "classification": {"classification"}, + "segmentation": {"segmentation"}, + "nlp": {"nlp"}, + "generation": {"generation"}, + "keypoints": {"keypoints", "detection"}, + "embedding": {"nlp", "classification"}, +} + +# dataset.format → set of model tasks it supports +FORMAT_TASK_COMPAT: dict[str, set[str]] = { + "yolo": {"detection", "segmentation", "keypoints"}, + "coco": {"detection", "segmentation", "keypoints"}, + "voc": {"detection"}, + "csv": {"classification"}, + "json": {"detection", "segmentation", "classification", "nlp", "generation"}, + "tfrecord": {"detection", "classification", "segmentation"}, + "custom": {"detection", "classification", "segmentation", "nlp", "generation", "keypoints"}, +} + +# model.framework → set of hardware targets (normalized) it can run on +FRAMEWORK_HARDWARE_COMPAT: dict[str, set[str]] = { + "pytorch": { + "cpu", "gpu", + "rtx4090", "rtx4080", "rtx4070ti", "rtx4070", "rtx4060ti", "rtx4060", + "rtx3090", "rtx3080", "rtx3070", "rtx3060", + "rtx2080ti", "rtx2080", + "a100", "a10040gb", "h100", "v100", "t4", "a10", + }, + "onnx": { + "cpu", "gpu", + "rtx4090", "rtx3090", "a100", "h100", "t4", "a10", + "edge", + }, + "tensorflow": { + "cpu", "gpu", + "rtx4090", "rtx3090", "a100", "h100", "v100", "t4", + "tpu", + }, + "tflite": {"cpu", "edge"}, + "coreml": {"cpu"}, +} + +# Precisions that require GPU +_GPU_ONLY_PRECISIONS = {"FP16", "BF16"} + +# Frameworks supporting INT8 quantization +_INT8_FRAMEWORKS = {"onnx", "tflite", "pytorch", "tensorflow"} + + +class CompatibilityValidator: + """ + Runs all compatibility gates before a benchmark job is created. + Returns a ValidationReport — never raises exceptions. + """ + + def validate( + self, + model: Model, + dataset: Dataset, + ctx: BenchmarkContext, + ) -> ValidationReport: + checks: list[ValidationCheck] = [ + self._check_task(model, dataset), + self._check_annotation_format(model, dataset), + self._check_framework_hardware(model, ctx), + self._check_vram(model, ctx), + self._check_precision(model, ctx), + ] + + errors = [c.detail for c in checks if not c.passed] + warnings: list[str] = [] + + log.info( + "compatibility_validated", + model_id = model.id, + dataset_id = dataset.id, + passed = len(errors) == 0, + error_count = len(errors), + ) + + return ValidationReport( + model_id = model.id, + dataset_id = dataset.id, + passed = len(errors) == 0, + checks = checks, + errors = errors, + warnings = warnings, + ) + + # ── Gate A: Task ────────────────────────────────────────────────────────── + + def _check_task(self, model: Model, dataset: Dataset) -> ValidationCheck: + model_task = model.task.lower().strip() + dataset_task = str(dataset.task).lower().strip() + + allowed = TASK_COMPAT.get(model_task, {model_task}) + if dataset_task in allowed: + return ValidationCheck( + name = "task_compatibility", + passed = True, + detail = ( + f"Model task '{model_task}' is compatible " + f"with dataset task '{dataset_task}'" + ), + ) + return ValidationCheck( + name = "task_compatibility", + passed = False, + detail = ( + f"Model task '{model_task}' cannot evaluate " + f"a '{dataset_task}' dataset" + ), + suggestion = ( + f"Select a model with task='{dataset_task}', " + f"or choose a dataset with task='{model_task}'" + ), + ) + + # ── Gate B: Annotation Format ───────────────────────────────────────────── + + def _check_annotation_format(self, model: Model, dataset: Dataset) -> ValidationCheck: + dataset_fmt = str(dataset.format).lower().strip() + model_task = model.task.lower().strip() + supported = FORMAT_TASK_COMPAT.get(dataset_fmt, set()) + + if model_task in supported: + return ValidationCheck( + name = "annotation_format", + passed = True, + detail = ( + f"Dataset format '{dataset_fmt}' supports " + f"model task '{model_task}'" + ), + ) + + if model_task in {"detection", "segmentation", "keypoints"}: + suggestion = ( + f"Convert dataset to YOLO or COCO format — both support '{model_task}'" + ) + elif model_task == "classification": + suggestion = "Convert dataset to CSV or JSON format for classification tasks" + else: + suggestion = f"Use a JSON or custom-format dataset for '{model_task}' tasks" + + return ValidationCheck( + name = "annotation_format", + passed = False, + detail = ( + f"Dataset format '{dataset_fmt}' does not support " + f"model task '{model_task}'" + ), + suggestion = suggestion, + ) + + # ── Gate C: Framework × Hardware ───────────────────────────────────────── + + def _check_framework_hardware( + self, model: Model, ctx: BenchmarkContext + ) -> ValidationCheck: + framework = model.framework.lower().strip() + hw_raw = ctx.hardware + hw_key = self._normalize_hw(hw_raw) + + supported_hw = FRAMEWORK_HARDWARE_COMPAT.get(framework, {"cpu"}) + + # Match: exact key, or generic "gpu" bucket covers any named GPU + hw_ok = ( + hw_key in supported_hw + or ("gpu" in supported_hw and hw_key not in {"cpu", "tpu", "edge"}) + ) + + if hw_ok: + return ValidationCheck( + name = "framework_hardware", + passed = True, + detail = f"Framework '{framework}' is supported on '{hw_raw}'", + ) + return ValidationCheck( + name = "framework_hardware", + passed = False, + detail = ( + f"Framework '{framework}' cannot run on '{hw_raw}'. " + f"Supported targets: {', '.join(sorted(supported_hw))}" + ), + suggestion = ( + "Use ONNX runtime for broadest hardware support, " + f"or pick a device from: {', '.join(sorted(supported_hw))}" + ), + ) + + # ── Gate D: VRAM Constraint ─────────────────────────────────────────────── + + def _check_vram(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck: + hw_key = self._normalize_hw(ctx.hardware) + available = self._lookup_vram(hw_key) + + if available == 0.0: + return ValidationCheck( + name = "vram_constraint", + passed = True, + detail = f"Running on '{ctx.hardware}' (CPU/TPU/Edge) — no VRAM constraint", + ) + + # Estimate: weights at given precision + activations for one batch + model_gb = max(model.size, 1) / (1024 ** 3) + prec_map = {"FP16": 0.5, "BF16": 0.5, "INT8": 0.25, "FP32": 1.0} + prec_mult = prec_map.get(ctx.precision.upper(), 1.0) + # weights × precision + ~20% for optimizer/activation buffers + batch overhead + estimated = (model_gb * prec_mult * 1.2) + (ctx.batch_size * 0.05) + + if estimated <= available: + return ValidationCheck( + name = "vram_constraint", + passed = True, + detail = ( + f"Estimated VRAM {estimated:.2f} GB ≤ " + f"available {available:.1f} GB on '{ctx.hardware}'" + ), + ) + return ValidationCheck( + name = "vram_constraint", + passed = False, + detail = ( + f"Estimated VRAM {estimated:.2f} GB exceeds " + f"available {available:.1f} GB on '{ctx.hardware}'" + ), + suggestion = ( + f"Try: reduce batch_size (now {ctx.batch_size}), " + f"switch to FP16/INT8 precision, " + f"or use a GPU with ≥ {estimated:.1f} GB VRAM" + ), + ) + + # ── Gate E: Precision Support ───────────────────────────────────────────── + + def _check_precision(self, model: Model, ctx: BenchmarkContext) -> ValidationCheck: + precision = ctx.precision.upper() + framework = model.framework.lower().strip() + hw_key = self._normalize_hw(ctx.hardware) + is_gpu = hw_key not in {"cpu", "tpu", "edge"} + + if precision in _GPU_ONLY_PRECISIONS and not is_gpu: + return ValidationCheck( + name = "precision_support", + passed = False, + detail = ( + f"Precision '{precision}' requires a CUDA GPU; " + f"'{ctx.hardware}' does not support it" + ), + suggestion = "Use FP32 for CPU inference, or switch to a compatible GPU", + ) + + if precision == "INT8" and framework not in _INT8_FRAMEWORKS: + return ValidationCheck( + name = "precision_support", + passed = False, + detail = ( + f"Framework '{framework}' does not support INT8 quantization" + ), + suggestion = ( + "Convert model to ONNX or use PyTorch with torch.quantization" + ), + ) + + return ValidationCheck( + name = "precision_support", + passed = True, + detail = ( + f"Precision '{precision}' is valid for " + f"framework '{framework}' on '{ctx.hardware}'" + ), + ) + + # ── Helpers ─────────────────────────────────────────────────────────────── + + @staticmethod + def _normalize_hw(hardware: str) -> str: + """Lowercase, strip spaces/dashes/underscores for lookup.""" + return ( + hardware.lower() + .replace(" ", "") + .replace("-", "") + .replace("_", "") + .replace("nvidia", "") + .replace("geforce", "") + ) + + @staticmethod + def _lookup_vram(hw_key: str) -> float: + """Return VRAM GB for a normalized hardware key, with fallback matching.""" + if hw_key in HARDWARE_VRAM_GB: + return HARDWARE_VRAM_GB[hw_key] + # Partial match (e.g. "rtx4090laptop" → "rtx4090") + for key, vram in HARDWARE_VRAM_GB.items(): + if key and key in hw_key: + return vram + # Anything that looks like a GPU but isn't in the table + if "gpu" in hw_key or "rtx" in hw_key or "gtx" in hw_key or "cuda" in hw_key: + return HARDWARE_VRAM_GB["gpu"] + return 0.0 # CPU / unknown → no VRAM constraint diff --git a/benchmark/execution.py b/benchmark/execution.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7d76abec0189942414e930c10cf2600b70bae1 --- /dev/null +++ b/benchmark/execution.py @@ -0,0 +1,366 @@ +""" +benchmark/execution.py — Benchmark Execution Engine. + +Drives the batch inference loop, collecting latencies and VRAM readings. +Calls TelemetryCollector in parallel with batch processing. +Yields progress callbacks so the orchestrator can persist real-time state. + +Adapter pattern: swap _run_single_batch() with a real inference call +(torch.cuda.synchronize + model(batch)) once GPU runtime is wired up. + +PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>> +""" +from __future__ import annotations + +import asyncio +import math +import random +from dataclasses import dataclass, field +from typing import Awaitable, Callable + +from benchmark.compatibility import HARDWARE_VRAM_GB +from benchmark.telemetry import TelemetryCollector +from models.benchmark import BenchmarkJob, LayerBreakdown, TelemetrySample, TelemetrySummary +from models.dataset import Dataset +from models.model import Model +from observability.logger import get_logger + +log = get_logger("benchmark.execution") + + +# ── Per-image latency profiles (ms at batch=1, fp32) ───────────────────────── +_LATENCY_MS_PER_IMAGE: dict[str, float] = { + "rtx4090": 1.8, + "rtx4080": 2.5, + "rtx4070ti": 3.2, + "rtx4070": 3.8, + "rtx3090": 3.0, + "rtx3080": 4.5, + "rtx3070": 6.5, + "rtx3060": 9.0, + "rtx2080ti": 5.0, + "rtx2080": 7.5, + "a100": 1.2, + "h100": 0.7, + "v100": 2.8, + "t4": 5.5, + "a10": 3.5, + "gpu": 8.0, + "cpu": 42.0, +} + +# Precision speedup multipliers (relative to FP32) +_PRECISION_SPEEDUP: dict[str, float] = { + "FP32": 1.0, + "FP16": 1.8, + "BF16": 1.7, + "INT8": 2.5, +} + +# Task-specific baseline metric scores (pre-jitter) +_TASK_BASELINES: dict[str, dict[str, float]] = { + "detection": {"mAP": 0.435, "mAP_50": 0.618, "mAP_50_95": 0.435}, + "classification": {"accuracy": 0.872, "top5": 0.968}, + "segmentation": {"mAP": 0.372, "iou_mean": 0.706}, + "keypoints": {"mAP": 0.641, "mAP_50": 0.860}, + "nlp": {"accuracy": 0.891}, + "generation": {"accuracy": 0.780}, +} + +# Cap simulated batches so large datasets don't stall the event loop +_MAX_SIMULATED_BATCHES = 250 + + +@dataclass +class ExecutionResult: + """Raw output from the execution engine, consumed by MetricsEngine.""" + latencies_ms: list[float] + total_images: int + vram_samples: list[float] + task_scores: dict[str, float] + telemetry_samples: list[TelemetrySample] = field(default_factory=list) + telemetry_summary: TelemetrySummary = field(default_factory=TelemetrySummary) + + +# Progress callback type: (progress_0_to_1, message, last_telemetry) → None +ProgressCallback = Callable[[float, str, TelemetrySample | None], Awaitable[None]] + + +class BenchmarkExecutor: + """ + Drives the benchmark execution loop. + Non-blocking: all sleeps are asyncio.sleep so other coroutines run freely. + """ + + async def execute( + self, + job: BenchmarkJob, + model: Model, + dataset: Dataset, + on_progress: ProgressCallback, + ) -> ExecutionResult: + hw = job.hardware + batch_sz = job.batch_size + + # Handle polymorphic input duration + is_live = getattr(job, "input_source", "dataset") in ("video", "live") + + if is_live: + # For live/video, we run for a fixed duration or until stopped + # Increase limit for a longer session (e.g., 10,000 batches) + total_img = 10000 * batch_sz + n_batches = 10000 + sim_batches = 10000 + else: + total_img = max(dataset.images, 100) # floor so simulation always runs + n_batches = math.ceil(total_img / batch_sz) + sim_batches = min(n_batches, _MAX_SIMULATED_BATCHES) + + vram_total = self._get_vram_gb(hw, model) + vram_frac = self._vram_usage_fraction(hw) + + telemetry = TelemetryCollector(hw, vram_total_gb=vram_total) + await telemetry.start() + + latencies: list[float] = [] + vram_samples: list[float] = [] + + base_lat_ms = self._base_batch_latency_ms(hw, model, batch_sz, job.precision) + + # Resolve real model path once (None → use simulation) + real_model_path = model.local_path if model.local_path and model.downloaded else None + use_real_inference = self._check_torch_available() and real_model_path is not None + loop = asyncio.get_event_loop() + + try: + for sim_idx in range(sim_batches): + # Map simulated index back to real batch index + real_idx = int(sim_idx * (n_batches / sim_batches)) + + if use_real_inference: + # Real GPU inference via torch_runner (runs in thread executor) + try: + from benchmark.torch_runner import run_torch_batch + batch_lat_ms = await loop.run_in_executor( + None, + run_torch_batch, + real_model_path, + batch_sz, + job.task, + ) + # Add a tiny sleep to prevent event loop starvation in live mode + if is_live: + await asyncio.sleep(0.001) + except Exception as exc: + log.warning("torch_inference_failed_fallback", error=str(exc)) + use_real_inference = False # fall back for remaining batches + batch_lat_ms = max( + 0.5, base_lat_ms + random.gauss(0, base_lat_ms * 0.07) + ) + else: + # Simulation path — non-blocking synthetic latency + batch_lat_ms = max( + 0.5, + base_lat_ms + random.gauss(0, base_lat_ms * 0.07), + ) + await asyncio.sleep(batch_lat_ms / 1000.0) # non-blocking + + latencies.append(batch_lat_ms) + vram_used = vram_total * random.uniform( + vram_frac - 0.05, vram_frac + 0.05 + ) + vram_samples.append(max(0.0, vram_used)) + + progress = (sim_idx + 1) / sim_batches + telemetry.record_batch_context(real_idx, progress) + + # Throttle callbacks: every 5 batches or first/last + if sim_idx % 5 == 0 or sim_idx == sim_batches - 1: + images_done = int(progress * total_img) + + # Generate simulated detection data for live preview if it's a vision task + live_data = {} + if job.task.lower() in ("detection", "segmentation"): + # Use provided bbox telemetry if available (e.g. from real inference) + # otherwise generate simulated ones + live_data["detections"] = [ + { + "x": random.uniform(0.1, 0.7), + "y": random.uniform(0.1, 0.7), + "width": random.uniform(0.1, 0.3), + "height": random.uniform(0.1, 0.3), + "label": random.choice(["person", "car", "bicycle", "dog"]), + "confidence": random.uniform(0.5, 0.99) + } + for _ in range(random.randint(1, 5)) + ] + + last_sample = telemetry.samples[-1] if telemetry.samples else None + if last_sample: + last_sample.live_data = live_data + # Explicitly broadcast detections for the visualizer + last_sample.detections = live_data.get("detections", []) + + await on_progress( + progress, + f"Batch {real_idx+1}/{n_batches} — " + f"{images_done}/{total_img} images processed", + last_sample, + ) + + finally: + telemetry_summary = await telemetry.stop() + # Attach simulated layer breakdown so Live Lab can display it + telemetry_summary.layer_breakdown = self._compute_layer_breakdown( + job.task, base_lat_ms + ) + + task_scores = self._simulate_task_scores(job.task, model, dataset) + + log.info( + "execution_complete", + job_id = job.id, + total_images = total_img, + sim_batches = sim_batches, + avg_lat_ms = round(sum(latencies) / len(latencies), 2) if latencies else 0, + ) + + return ExecutionResult( + latencies_ms = latencies, + total_images = total_img, + vram_samples = vram_samples, + task_scores = task_scores, + telemetry_samples = telemetry.samples, + telemetry_summary = telemetry_summary, + ) + + # ── Helpers ─────────────────────────────────────────────────────────────── + + def _base_batch_latency_ms( + self, + hardware: str, + model: Model, + batch_sz: int, + precision: str, + ) -> float: + """ + Estimate per-batch latency in ms. + Accounts for hardware tier, model size, batch size, and precision. + """ + hw_key = self._normalize_hw(hardware) + per_img = self._lookup_latency(hw_key) + + # Larger models are slower: +30% per GB of model weights + size_gb = max(model.size, 1) / (1024 ** 3) + size_factor = 1.0 + size_gb * 0.30 + + # Batch parallelism: ~65% linear efficiency on GPU, 90% on CPU + eff = 0.65 if "cpu" not in hw_key else 0.90 + batch_lat = per_img * size_factor * batch_sz * eff + + # Precision speedup + speedup = _PRECISION_SPEEDUP.get(precision.upper(), 1.0) + + return batch_lat / speedup + + def _get_vram_gb(self, hardware: str, model: Model) -> float: + hw_key = self._normalize_hw(hardware) + for key, vram in HARDWARE_VRAM_GB.items(): + if key and key in hw_key: + return vram + return 8.0 + + @staticmethod + def _vram_usage_fraction(hardware: str) -> float: + """Fraction of VRAM typically consumed during inference.""" + hw = hardware.lower() + if any(x in hw for x in ("4090", "3090", "a100", "h100")): + return 0.62 + if any(x in hw for x in ("4080", "3080", "v100", "a10")): + return 0.60 + if "cpu" in hw: + return 0.0 + return 0.55 + + @staticmethod + def _simulate_task_scores( + task: str, model: Model, dataset: Dataset + ) -> dict[str, float]: + """ + Produce realistic metric scores with small per-run variance. + + PRODUCTION SWAP: replace with actual metric computation: + from torchmetrics.detection import MeanAveragePrecision + metric = MeanAveragePrecision() + metric.update(predictions, targets) + return metric.compute() + """ + baselines = dict(_TASK_BASELINES.get(task.lower(), {"accuracy": 0.80})) + # Small Gaussian jitter simulates run-to-run variance + return { + k: float(max(0.0, min(1.0, v + random.gauss(0, 0.015)))) + for k, v in baselines.items() + } + + @staticmethod + def _check_torch_available() -> bool: + """Return True if PyTorch is installed and importable.""" + try: + import torch # noqa: F401 + return True + except ImportError: + return False + + @staticmethod + def _compute_layer_breakdown(task: str, base_lat_ms: float) -> list[LayerBreakdown]: + """Build a realistic layer breakdown for the given task. + + Splits total latency across architectural stages with small jitter. + PRODUCTION SWAP: replace with actual profiler data (e.g. torch.profiler). + """ + if task.lower() in ("detection", "segmentation"): + stages = [ + ("Backbone", 0.45), + ("Neck (FPN/PAFPN)", 0.30), + ("Detection Head", 0.20), + ("NMS Post-process", 0.05), + ] + elif task.lower() == "classification": + stages = [ + ("Feature Extractor", 0.70), + ("Classifier Head", 0.20), + ("Softmax", 0.10), + ] + else: + stages = [ + ("Encoder", 0.55), + ("Decoder / Head", 0.35), + ("Post-process", 0.10), + ] + + result: list[LayerBreakdown] = [] + remaining = base_lat_ms + for name, frac in stages: + t = round(base_lat_ms * frac + random.gauss(0, base_lat_ms * 0.01), 3) + result.append(LayerBreakdown(name=name, time_ms=t, percent=round(frac * 100, 1))) + return result + + @staticmethod + def _normalize_hw(hardware: str) -> str: + return ( + hardware.lower() + .replace(" ", "") + .replace("-", "") + .replace("_", "") + .replace("nvidia", "") + .replace("geforce", "") + ) + + @staticmethod + def _lookup_latency(hw_key: str) -> float: + for key, ms in _LATENCY_MS_PER_IMAGE.items(): + if key and key in hw_key: + return ms + if any(x in hw_key for x in ("gpu", "rtx", "gtx", "cuda")): + return _LATENCY_MS_PER_IMAGE["gpu"] + return _LATENCY_MS_PER_IMAGE["cpu"] diff --git a/benchmark/metrics.py b/benchmark/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f80529e1cf8a754881bdadb3980f8e1915d336fb --- /dev/null +++ b/benchmark/metrics.py @@ -0,0 +1,110 @@ +""" +benchmark/metrics.py — Metrics Engine. + +Computes the final BenchmarkMetrics object from raw execution data: + - Latency statistics (mean, p95, p99) + - Throughput (FPS) + - VRAM statistics (avg, peak) + - Task-specific scores (mAP, accuracy, IoU) supplied by the executor + +In a production deployment the task_scores dict comes from actual +metric computation (e.g. pycocotools, torchmetrics). In this local-first +build the executor supplies realistic simulated scores. +""" +from __future__ import annotations + +import statistics + +from models.benchmark import BenchmarkMetrics, LayerBreakdown, TelemetrySummary +from observability.logger import get_logger + +log = get_logger("benchmark.metrics") + + +class MetricsEngine: + """Computes BenchmarkMetrics from raw benchmark execution data.""" + + def compute( + self, + *, + task: str, + latencies_ms: list[float], # per-batch latencies + total_images: int = 0, + total_tokens: int = 0, + batch_size: int, + vram_samples: list[float], # VRAM readings (GB) during run + task_scores: dict[str, float], # task-specific metric scores + ) -> BenchmarkMetrics: + if not latencies_ms: + return BenchmarkMetrics(total_images=total_images, total_tokens=total_tokens, batch_size=batch_size) + + total_time_s = sum(latencies_ms) / 1000.0 + fps = total_images / total_time_s if total_time_s > 0 and total_images > 0 else 0.0 + tps = total_tokens / total_time_s if total_time_s > 0 and total_tokens > 0 else 0.0 + + lat_mean = statistics.mean(latencies_ms) + lat_p95 = _percentile(latencies_ms, 0.95) + lat_p99 = _percentile(latencies_ms, 0.99) + + vram_peak = max(vram_samples) if vram_samples else 0.0 + vram_avg = statistics.mean(vram_samples) if vram_samples else 0.0 + + m = BenchmarkMetrics( + fps = round(fps, 2), + tokens_per_sec = round(tps, 2), + latency_mean_ms = round(lat_mean, 3), + latency_p95_ms = round(lat_p95, 3), + latency_p99_ms = round(lat_p99, 3), + vram_peak_gb = round(vram_peak, 3), + vram_avg_gb = round(vram_avg, 3), + total_images = total_images, + total_tokens = total_tokens, + batch_size = batch_size, + ) + + task_lower = task.lower() + + # CV Task Mapping + if task_lower in ("detection", "segmentation", "keypoints"): + m.mAP = _fmt(task_scores.get("mAP", 0.0)) + m.mAP_50 = _fmt(task_scores.get("mAP_50", 0.0)) + m.mAP_50_95 = _fmt(task_scores.get("mAP_50_95", 0.0)) + if task_lower == "segmentation": + m.iou_mean = _fmt(task_scores.get("iou_mean", 0.0)) + + elif task_lower == "classification": + m.accuracy = _fmt(task_scores.get("accuracy", 0.0)) + m.top1 = _fmt(task_scores.get("top1", 0.0)) + m.top5 = _fmt(task_scores.get("top5", 0.0)) + + # NLP Task Mapping (ROUGE, BLEU, Perplexity) + elif task_lower in ("nlp", "generation"): + m.accuracy = _fmt(task_scores.get("accuracy", 0.0)) + m.rouge_l = _fmt(task_scores.get("rouge_l", task_scores.get("rougeL", 0.0))) + m.bleu = _fmt(task_scores.get("bleu", 0.0)) + m.perplexity = task_scores.get("perplexity") + + log.info( + "metrics_computed", + task = task, + fps = m.fps, + tps = m.tokens_per_sec, + latency_ms = m.latency_mean_ms, + vram_peak = m.vram_peak_gb, + ) + return m + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _percentile(data: list[float], p: float) -> float: + if not data: + return 0.0 + s = sorted(data) + idx = min(int(len(s) * p), len(s) - 1) + return s[idx] + + +def _fmt(v: float) -> float: + """Round to 4dp and clamp to [0, 1].""" + return round(max(0.0, min(1.0, v)), 4) diff --git a/benchmark/orchestrator.py b/benchmark/orchestrator.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fcef70f2d048a35db59e93ab98eb9703279f8e --- /dev/null +++ b/benchmark/orchestrator.py @@ -0,0 +1,374 @@ +""" +benchmark/orchestrator.py — Benchmark Orchestrator (Main Controller). + +Coordinates the full benchmark lifecycle: + 1. Resolve model + dataset from their registries + 2. Run all compatibility checks (gates A–E) + 3. If valid → create a BenchmarkJob in the DB + 4. Persist the validation audit log + 5. Enqueue async background task → execution → metrics → storage + 6. Return the job immediately so callers are non-blocking + +Public interface used by api/routes/benchmark.py: + validate_context(ctx) → ValidationReport (no job created) + create_and_run(ctx) → BenchmarkJob (job queued, execution in background) +""" +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone + +from benchmark.adapters.registry import get_executor +from benchmark.compatibility import CompatibilityValidator +from benchmark.execution import BenchmarkExecutor +from benchmark.metrics import MetricsEngine +import benchmark.registry as bench_reg +from datasets.registry import get_dataset +from models.benchmark import ( + BenchmarkContext, + BenchmarkJob, + BenchmarkMetrics, + TelemetrySummary, + ValidationReport, +) +from models.dataset import Dataset +from models.model import Model +from observability.logger import audit, get_logger +from registry.registry import get_model + +log = get_logger("benchmark.orchestrator") + +# Module-level singletons — stateless, safe to share +_validator = CompatibilityValidator() +_metrics = MetricsEngine() + +# job_id → asyncio.Task (for future cancellation support) +_active_tasks: dict[str, asyncio.Task] = {} + + +# ── Public API ──────────────────────────────────────────────────────────────── + +async def sync_project_benchmarks() -> int: + """ + Sync benchmark jobs and results from the active project's 'benchmarks' folder. + This ensures that benchmarks created in different sessions or projects are indexed. + """ + from benchmark.registry import _get_active_project_benchmark_dir_sync + from projects.service import get_active_project_path + import json + import os + from database.connection import get_db + + project_path = await get_active_project_path() + benchmark_dir = _get_active_project_benchmark_dir_sync(project_path) + if not benchmark_dir or not benchmark_dir.exists(): + return 0 + + db = await get_db() + count = 0 + + for file_path in benchmark_dir.glob("*.json"): + try: + with open(file_path, "r") as f: + data = json.load(f) + + # Check if it's a job or a result + if file_path.name.startswith("job_"): + # Upsert into benchmark_jobs + await db.execute( + """INSERT OR IGNORE INTO benchmark_jobs + (id, model_id, dataset_id, task, framework, hardware, + precision, batch_size, config, status, progress, created_at, updated_at, started_at) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + data["id"], data["model_id"], data["dataset_id"], + data["task"], data["framework"], data["hardware"], + data["precision"], data["batch_size"], + json.dumps(data["config"]), data["status"], + data.get("progress", 0.0), + data.get("created_at", datetime.now(timezone.utc).isoformat()), + data.get("updated_at", datetime.now(timezone.utc).isoformat()), + data.get("started_at") + ) + ) + count += 1 + elif file_path.name.startswith("result_"): + # Upsert into benchmark_results + await db.execute( + """INSERT OR IGNORE INTO benchmark_results + (id, job_id, metrics, telemetry_summary, created_at) + VALUES (?,?,?,?,?)""", + ( + data["id"], data["job_id"], + json.dumps(data["metrics"]), + json.dumps(data["telemetry_summary"]), + data.get("created_at", datetime.now(timezone.utc).isoformat()) + ) + ) + count += 1 + except Exception as e: + log.error("sync_file_failed", file=file_path.name, error=str(e)) + + await db.commit() + log.info("sync_complete", count=count) + return count + +async def validate_context(ctx: BenchmarkContext) -> ValidationReport: + """ + Validate model ↔ dataset ↔ hardware compatibility. + Does NOT create a job. Safe to call repeatedly from the UI. + """ + model = await _require_model(ctx.model_id) + + # ── Handle Polymorphic Input (Video/Live) ──────────────────────────────── + if ctx.input_source in ["video", "live"] or ctx.dataset_id == "none": + # Create a synthetic dataset object for non-dataset sources + now = datetime.now(timezone.utc).isoformat() + dataset = Dataset( + id="none", + name="Live/Video Stream", + task=model.task, # Match model task to pass task check + format="custom", + source="local", + status="imported", + images=0, + classes=0, + size_label="0 MB", + created_at=now, + updated_at=now + ) + else: + dataset = await _require_dataset(ctx.dataset_id) + + return _validator.validate(model, dataset, ctx) + + +async def create_and_run(ctx: BenchmarkContext) -> BenchmarkJob: + """ + Full benchmark initiation: + """ + model = await _require_model(ctx.model_id) + + # ── Handle Polymorphic Input (Video/Live) ──────────────────────────────── + if ctx.input_source in ["video", "live"] or ctx.dataset_id == "none": + now = datetime.now(timezone.utc).isoformat() + dataset = Dataset( + id="none", + name="Live/Video Stream", + task=model.task, + format="custom", + source="local", + status="imported", + images=0, + classes=0, + size_label="0 MB", + created_at=now, + updated_at=now + ) + else: + dataset = await _require_dataset(ctx.dataset_id) + + # ── Compatibility check ─────────────────────────────────────────────────── + report = _validator.validate(model, dataset, ctx) + + # Always persist the validation log (even for failures) + await bench_reg.save_validation_log( + job_id = "pre-check", + model_id = ctx.model_id, + dataset_id = ctx.dataset_id, + checks = report.checks, + passed = report.passed, + ) + + if not report.passed: + from fastapi import HTTPException + failed = [c for c in report.checks if not c.passed] + raise HTTPException( + status_code = 422, + detail = { + "error": "Compatibility validation failed", + "failed_checks": [ + { + "name": c.name, + "detail": c.detail, + "suggestion": c.suggestion, + } + for c in failed + ], + }, + ) + + # ── Create job ──────────────────────────────────────────────────────────── + job = await bench_reg.create_job(ctx) + + # Overwrite 'pre-check' validation log with the real job_id + await bench_reg.save_validation_log( + job_id = job.id, + model_id = ctx.model_id, + dataset_id = ctx.dataset_id, + checks = report.checks, + passed = True, + ) + + # ── Log the Polymorphic Input params ───────────────────────────────────── + if ctx.input_source or ctx.video_path or ctx.rtsp_url: + log.info("polymorphic_input_received", + job_id=job.id, + source=ctx.input_source, + video=ctx.video_path, + rtsp=ctx.rtsp_url) + + # ── Enqueue background execution ────────────────────────────────────────── + task = asyncio.create_task( + _execute_job(job.id, ctx, model, dataset), + name = f"benchmark_{job.id}", + ) + _active_tasks[job.id] = task + task.add_done_callback(lambda _t: _active_tasks.pop(job.id, None)) + + log.info("benchmark_enqueued", job_id=job.id, model=ctx.model_id) + return job + + +# ── Background execution ────────────────────────────────────────────────────── + +async def _execute_job( + job_id: str, + ctx: BenchmarkContext, + model: Model, + dataset: Dataset, +) -> None: + """Full benchmark lifecycle — runs in an asyncio background task.""" + now = datetime.now(timezone.utc).isoformat() + + # Transition → running + ts_color = "\x1b[36m" # Cyan + info_color = "\x1b[34m" # Blue + success_color = "\x1b[32m" # Green + reset = "\x1b[0m" + + await bench_reg.update_job( + job_id, + status = "running", + progress = 0.0, + started_at = now, + log_entry = f"{ts_color}[{now}]{reset} {info_color}Job started{reset} on {ctx.hardware} ({ctx.precision})", + ) + + runner = BenchmarkExecutor() + + try: + # ── Fetch the persisted job (for executor) ──────────────────────────── + job = await bench_reg.get_job(job_id) + assert job is not None, "Job disappeared from DB after creation" + + # ── Define Progress Callback ────────────────────────────────────────── + async def on_progress(progress: float, message: str, telemetry: Any | None): + await bench_reg.update_job( + job_id, + progress=progress, + log_entry=f"{ts_color}[{datetime.now(timezone.utc).isoformat()}]{reset} {info_color}{message}{reset}", + last_telemetry=telemetry.model_dump() if telemetry and hasattr(telemetry, "model_dump") else telemetry + ) + + # ── Execution Loop ──────────────────────────────────────────────────── + exec_result = await runner.execute( + job=job, + model=model, + dataset=dataset, + on_progress=on_progress + ) + + # ── Compute metrics ─────────────────────────────────────────────────── + metrics = _metrics.compute( + task = ctx.task, + latencies_ms = exec_result.latencies_ms, + total_images = exec_result.total_images, + batch_size = ctx.batch_size, + vram_samples = exec_result.vram_samples, + task_scores = exec_result.task_scores, + ) + + # ── Persist result ──────────────────────────────────────────────────── + await bench_reg.save_result( + job_id = job_id, + metrics = metrics, + telemetry_summary = exec_result.telemetry_summary, + ) + + ended = datetime.now(timezone.utc).isoformat() + await bench_reg.update_job( + job_id, + status = "completed", + progress = 1.0, + ended_at = ended, + log_entry = f"{ts_color}[{ended}]{reset} {success_color}Benchmark completed{reset} — {metrics.fps} FPS", + ) + + await audit( + "benchmark_completed", + job_id = job_id, + payload = {"model_id": ctx.model_id, "dataset_id": ctx.dataset_id}, + ) + log.info( + "benchmark_completed", + job_id = job_id, + fps = metrics.fps, + lat_ms = metrics.latency_mean_ms, + ) + + except asyncio.CancelledError: + # Task cancelled externally (e.g. server shutdown) — don't swallow + ended = datetime.now(timezone.utc).isoformat() + await bench_reg.update_job( + job_id, + status = "failed", + error = "Job cancelled", + ended_at = ended, + log_entry = f"{ts_color}[{ended}]{reset} \x1b[31mJob cancelled\x1b[0m", + ) + raise + + except Exception as exc: + ended = datetime.now(timezone.utc).isoformat() + err_msg = str(exc) + error_color = "\x1b[31m" # Red + await bench_reg.update_job( + job_id, + status = "failed", + error = err_msg, + ended_at = ended, + log_entry = f"{ts_color}[{ended}]{reset} {error_color}ERROR: {err_msg}{reset}", + ) + await audit( + "benchmark_failed", + job_id = job_id, + level = "error", + payload = {"error": err_msg, "model_id": ctx.model_id}, + ) + log.exception("benchmark_failed", job_id=job_id) + finally: + pass + +# ── Resource resolvers ──────────────────────────────────────────────────────── + +async def _require_model(model_id: str) -> Model: + model = await get_model(model_id) + if not model: + from fastapi import HTTPException + raise HTTPException( + status_code = 404, + detail = f"Model '{model_id}' not found in Model Zoo", + ) + return model + + +async def _require_dataset(dataset_id: str) -> Dataset: + dataset = await get_dataset(dataset_id) + if not dataset: + from fastapi import HTTPException + raise HTTPException( + status_code = 404, + detail = f"Dataset '{dataset_id}' not found in Dataset Manager", + ) + return dataset diff --git a/benchmark/registry.py b/benchmark/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ef02ea0cd927b2358611301aea64d9f7c4cd1746 --- /dev/null +++ b/benchmark/registry.py @@ -0,0 +1,302 @@ +""" +benchmark/registry.py — Benchmark Registry. + +All DB interactions for: + • benchmark_jobs — job lifecycle state + • benchmark_results — final metrics + telemetry summary + • benchmark_validation_logs — immutable check audit trail + +Follows the same pattern as registry/registry.py and datasets/registry.py. +No direct DB access from other benchmark modules — everything routes here. +""" +from __future__ import annotations + +import json +import uuid +from datetime import datetime, timezone +from typing import Any +from pathlib import Path + +from database.connection import get_db +from models.benchmark import ( + BenchmarkContext, + BenchmarkJob, + BenchmarkMetrics, + BenchmarkResult, + TelemetrySummary, + ValidationCheck, + row_to_job, + row_to_result, +) +from observability.logger import get_logger + +log = get_logger("benchmark.registry") + + +def _get_active_project_benchmark_dir_sync(project_path: str | None) -> Path | None: + """Get the absolute path to the 'benchmarks' folder in a given project path.""" + if not project_path: + return None + + benchmark_dir = Path(project_path) / "benchmarks" + benchmark_dir.mkdir(parents=True, exist_ok=True) + return benchmark_dir + +async def _get_active_project_benchmark_dir() -> Path | None: + """Get the absolute path to the 'benchmarks' folder in the active project.""" + from projects.service import get_active_project_path + project_path = await get_active_project_path() + return _get_active_project_benchmark_dir_sync(project_path) + +async def _save_to_project(filename: str, data: dict) -> None: + """Save data to a JSON file in the active project's benchmark folder.""" + benchmark_dir = await _get_active_project_benchmark_dir() + if not benchmark_dir: + return + + file_path = benchmark_dir / filename + try: + with open(file_path, "w") as f: + json.dump(data, f, indent=2) + except Exception as e: + log.error("project_persistence_failed", error=str(e), file=filename) + +# ── Job CRUD ────────────────────────────────────────────────────────────────── + +async def create_job(ctx: BenchmarkContext) -> BenchmarkJob: + db = await get_db() + job_id = f"bmark-{uuid.uuid4().hex[:12]}" + now = datetime.now(timezone.utc).isoformat() + + # Create job object + job = BenchmarkJob( + id = job_id, + model_id = ctx.model_id, + dataset_id = ctx.dataset_id, + task = ctx.task, + framework = ctx.framework, + hardware = ctx.hardware, + precision = ctx.precision, + batch_size = ctx.batch_size, + config = ctx.model_dump(), + status = "queued", + progress = 0.0, + created_at = now, + updated_at = now, + ) + + # Persist to SQLite + await db.execute( + """INSERT INTO benchmark_jobs + (id, model_id, dataset_id, task, framework, hardware, + precision, batch_size, config, + status, progress, logs, created_at, updated_at) + VALUES (?,?,?,?,?,?,?,?,?,'queued',0.0,'[]',?,?)""", + ( + job_id, + ctx.model_id, ctx.dataset_id, + ctx.task, ctx.framework, ctx.hardware, + ctx.precision, ctx.batch_size, + json.dumps(ctx.model_dump()), + now, now, + ), + ) + await db.commit() + + # Persist to project folder + await _save_to_project(f"job_{job_id}.json", job.model_dump()) + + log.info("benchmark_job_created", job_id=job_id, model=ctx.model_id) + return job + + +async def get_job(job_id: str) -> BenchmarkJob | None: + db = await get_db() + async with db.execute( + "SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,) + ) as cur: + row = await cur.fetchone() + return row_to_job(row) if row else None + + +async def list_jobs( + *, + status: str | None = None, + model_id: str | None = None, + limit: int = 100, +) -> list[BenchmarkJob]: + db = await get_db() + clauses: list[str] = [] + params: list[Any] = [] + + if status: + clauses.append("status = ?") + params.append(status) + if model_id: + clauses.append("model_id = ?") + params.append(model_id) + + where = f"WHERE {' AND '.join(clauses)}" if clauses else "" + params.append(limit) + + async with db.execute( + f"SELECT * FROM benchmark_jobs {where} ORDER BY created_at DESC LIMIT ?", + params, + ) as cur: + rows = await cur.fetchall() + return [row_to_job(r) for r in rows] + + +async def update_job( + job_id: str, + *, + status: str | None = None, + progress: float | None = None, + error: str | None = None, + started_at: str | None = None, + ended_at: str | None = None, + log_entry: str | None = None, + last_telemetry: dict | None = None, +) -> None: + """Update mutable fields on a benchmark job atomically.""" + db = await get_db() + now = datetime.now(timezone.utc).isoformat() + + sets: list[str] = ["updated_at = ?"] + vals: list[Any] = [now] + + if status is not None: + sets.append("status = ?"); vals.append(status) + if progress is not None: + sets.append("progress = ?"); vals.append(round(progress, 4)) + if error is not None: + sets.append("error = ?"); vals.append(error) + if started_at is not None: + sets.append("started_at = ?"); vals.append(started_at) + if ended_at is not None: + sets.append("ended_at = ?"); vals.append(ended_at) + if last_telemetry is not None: + sets.append("last_telemetry = ?"); vals.append(json.dumps(last_telemetry)) + + if log_entry is not None: + # Append new entry to the JSON log array (capped at 500 lines) + async with db.execute( + "SELECT logs FROM benchmark_jobs WHERE id = ?", (job_id,) + ) as cur: + row = await cur.fetchone() + existing = json.loads(row["logs"]) if row and row["logs"] else [] + existing.append(log_entry) + sets.append("logs = ?") + vals.append(json.dumps(existing[-500:])) + + vals.append(job_id) + # Persist to project folder if we have the job info + async with db.execute("SELECT * FROM benchmark_jobs WHERE id = ?", (job_id,)) as cur: + row = await cur.fetchone() + if row: + job = row_to_job(row) + if job: + await _save_to_project(f"job_{job_id}.json", job.model_dump()) + + await db.commit() + + +# ── Result CRUD ─────────────────────────────────────────────────────────────── + +async def save_result( + *, + job_id: str, + metrics: BenchmarkMetrics, + telemetry_summary: TelemetrySummary, +) -> BenchmarkResult: + db = await get_db() + result_id = f"bres-{uuid.uuid4().hex[:12]}" + now = datetime.now(timezone.utc).isoformat() + + # Persist result to SQLite + await db.execute( + """INSERT INTO benchmark_results + (id, job_id, metrics, telemetry_summary, created_at) + VALUES (?,?,?,?,?)""", + ( + result_id, + job_id, + json.dumps(metrics.model_dump(exclude_none=True)), + json.dumps(telemetry_summary.model_dump()), + now, + ), + ) + await db.commit() + + result = BenchmarkResult( + id = result_id, + job_id = job_id, + metrics = metrics, + telemetry_summary = telemetry_summary, + created_at = now, + ) + + # Persist result to project folder + await _save_to_project(f"result_{job_id}.json", result.model_dump()) + + log.info("benchmark_result_saved", job_id=job_id, result_id=result_id) + return result + + +async def get_result(job_id: str) -> BenchmarkResult | None: + db = await get_db() + async with db.execute( + """SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision + FROM benchmark_results r + JOIN benchmark_jobs j ON r.job_id = j.id + WHERE r.job_id = ?""", (job_id,) + ) as cur: + row = await cur.fetchone() + return row_to_result(row) if row else None + + +async def list_results(*, limit: int = 100) -> list[BenchmarkResult]: + db = await get_db() + async with db.execute( + """SELECT r.*, j.model_id, j.dataset_id, j.task, j.framework, j.hardware, j.precision + FROM benchmark_results r + JOIN benchmark_jobs j ON r.job_id = j.id + ORDER BY r.created_at DESC LIMIT ?""", (limit,) + ) as cur: + rows = await cur.fetchall() + return [row_to_result(r) for r in rows] + + +# ── Validation Log ──────────────────────────────────────────────────────────── + +async def save_validation_log( + *, + job_id: str, + model_id: str, + dataset_id: str, + checks: list[ValidationCheck], + passed: bool, +) -> None: + """Persist an immutable record of all compatibility checks.""" + db = await get_db() + log_id = f"bval-{uuid.uuid4().hex[:12]}" + now = datetime.now(timezone.utc).isoformat() + + await db.execute( + """INSERT INTO benchmark_validation_logs + (id, job_id, model_id, dataset_id, checks, passed, created_at) + VALUES (?,?,?,?,?,?,?)""", + ( + log_id, job_id, model_id, dataset_id, + json.dumps([c.model_dump() for c in checks]), + 1 if passed else 0, + now, + ), + ) + await db.commit() + log.info( + "validation_log_saved", + job_id = job_id, + passed = passed, + n_checks = len(checks), + ) diff --git a/benchmark/telemetry.py b/benchmark/telemetry.py new file mode 100644 index 0000000000000000000000000000000000000000..cee6cd1202b1d38fdb3de66903382fe7d9cac7b5 --- /dev/null +++ b/benchmark/telemetry.py @@ -0,0 +1,182 @@ +""" +benchmark/telemetry.py — Real-time Telemetry Collector. + +Collects GPU/hardware metrics at 2 Hz during benchmark execution. +Designed as a drop-in adapter: + • Local dev → simulates realistic GPU readings based on hardware tier + • Production → replace _read_gpu_metrics() with pynvml calls: + nvmlDeviceGetUtilizationRates() + nvmlDeviceGetMemoryInfo() + nvmlDeviceGetTemperature() + nvmlDeviceGetPowerUsage() + +Usage (async context): + collector = TelemetryCollector("rtx4090", vram_total_gb=24.0) + await collector.start() + # ... run inference ... + summary = await collector.stop() + samples = collector.samples +""" +from __future__ import annotations + +import asyncio +import random +import statistics +import time + +from models.benchmark import TelemetrySample, TelemetrySummary +from observability.logger import get_logger + +log = get_logger("benchmark.telemetry") + +# ── Hardware simulation profiles ────────────────────────────────────────────── +# (base_util%, base_temp_C, base_power_W) +_HW_PROFILES: dict[str, tuple[float, float, float]] = { + "rtx4090": (88.0, 74.0, 380.0), + "rtx4080": (84.0, 70.0, 280.0), + "rtx4070": (80.0, 68.0, 200.0), + "rtx3090": (85.0, 72.0, 320.0), + "rtx3080": (82.0, 70.0, 250.0), + "rtx3070": (78.0, 66.0, 180.0), + "rtx3060": (74.0, 64.0, 150.0), + "a100": (90.0, 68.0, 350.0), + "h100": (92.0, 65.0, 550.0), + "v100": (87.0, 70.0, 280.0), + "t4": (75.0, 62.0, 60.0), + "gpu": (70.0, 65.0, 150.0), + "cpu": (0.0, 0.0, 0.0), +} + +_COLLECTION_INTERVAL_S = 0.5 # 2 Hz + + +class TelemetryCollector: + """ + Async telemetry collector. Call start() before inference, stop() after. + Thread-safe via asyncio (single-threaded event loop). + """ + + def __init__(self, hardware: str, vram_total_gb: float = 8.0) -> None: + self._hardware = hardware + self._vram_total = vram_total_gb + self._hw_profile = self._resolve_profile(hardware) + self._samples: list[TelemetrySample] = [] + self._running = False + self._task: asyncio.Task | None = None + + # ── Public API ──────────────────────────────────────────────────────────── + + async def start(self) -> None: + self._running = True + self._samples = [] + self._task = asyncio.create_task( + self._collect_loop(), name="telemetry_collector" + ) + log.debug("telemetry_started", hardware=self._hardware) + + async def stop(self) -> TelemetrySummary: + self._running = False + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + log.debug( + "telemetry_stopped", + hardware = self._hardware, + samples = len(self._samples), + ) + return self._build_summary() + + def record_batch_context(self, batch_idx: int, progress: float) -> None: + """Annotate the most recent sample with the current batch context.""" + if self._samples: + self._samples[-1].batch_idx = batch_idx + self._samples[-1].progress = progress + + @property + def samples(self) -> list[TelemetrySample]: + return list(self._samples) + + # ── Internal ────────────────────────────────────────────────────────────── + + async def _collect_loop(self) -> None: + while self._running: + sample = self._read_gpu_metrics() + self._samples.append(sample) + await asyncio.sleep(_COLLECTION_INTERVAL_S) + + def _read_gpu_metrics(self) -> TelemetrySample: + """ + Returns a TelemetrySample for the current hardware state. + + PRODUCTION SWAP: Replace this body with pynvml calls: + handle = nvmlDeviceGetHandleByIndex(0) + util = nvmlDeviceGetUtilizationRates(handle) + mem = nvmlDeviceGetMemoryInfo(handle) + temp = nvmlDeviceGetTemperature(handle, NVML_TEMPERATURE_GPU) + power = nvmlDeviceGetPowerUsage(handle) / 1000 # mW → W + """ + base_util, base_temp, base_power = self._hw_profile + + if base_util == 0.0: # CPU path — no meaningful GPU readings + return TelemetrySample( + timestamp = time.time(), + gpu_util_pct = 0.0, + vram_used_gb = 0.0, + vram_total_gb = 0.0, + temp_c = 0.0, + power_w = 0.0, + ) + + # Simulate realistic jitter (±5% util, ±3°C, ±10W) + jitter_util = random.gauss(0, 3.0) + jitter_temp = random.gauss(0, 1.5) + jitter_power = random.gauss(0, 8.0) + vram_frac = random.uniform(0.58, 0.72) + + return TelemetrySample( + timestamp = time.time(), + gpu_util_pct = max(0.0, min(100.0, base_util + jitter_util)), + vram_used_gb = round( + max(0.0, min(self._vram_total, self._vram_total * vram_frac)), 3 + ), + vram_total_gb = self._vram_total, + temp_c = round(max(0.0, base_temp + jitter_temp), 1), + power_w = round(max(0.0, base_power + jitter_power), 1), + ) + + def _build_summary(self) -> TelemetrySummary: + if not self._samples: + return TelemetrySummary() + + utils = [s.gpu_util_pct for s in self._samples] + vrams = [s.vram_used_gb for s in self._samples] + temps = [s.temp_c for s in self._samples] + powers = [s.power_w for s in self._samples] + + def _safe_mean(lst: list[float]) -> float: + return statistics.mean(lst) if lst else 0.0 + + return TelemetrySummary( + gpu_util_avg = round(_safe_mean(utils), 2), + gpu_util_peak = round(max(utils), 2), + vram_avg_gb = round(_safe_mean(vrams), 3), + vram_peak_gb = round(max(vrams), 3), + temp_avg_c = round(_safe_mean(temps), 1), + temp_peak_c = round(max(temps), 1), + power_avg_w = round(_safe_mean(powers), 1), + power_peak_w = round(max(powers), 1), + ) + + @staticmethod + def _resolve_profile(hardware: str) -> tuple[float, float, float]: + hw = hardware.lower().replace(" ", "").replace("-", "") + for key, profile in _HW_PROFILES.items(): + if key in hw: + return profile + # Default for unknown GPU-class hardware + if any(x in hw for x in ("gpu", "rtx", "gtx", "cuda", "vram")): + return _HW_PROFILES["gpu"] + return _HW_PROFILES["cpu"] diff --git a/benchmark/torch_runner.py b/benchmark/torch_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..56ad3ca22d19e5497875d800f2a1bc243fe40df5 --- /dev/null +++ b/benchmark/torch_runner.py @@ -0,0 +1,142 @@ +""" +benchmark/torch_runner.py — Synchronous GPU inference runner. + +Called from BenchmarkExecutor via asyncio.run_in_executor() so it never +blocks the event loop. PyTorch is an optional dependency — if it is not +installed the module raises ImportError and execution.py falls back to +the simulation path. + +Supported weight formats (detected by file extension): + .pt / .pth — torch.load (TorchScript or state-dict) + .safetensors — safetensors.torch.load_file + .onnx — onnxruntime InferenceSession + +PRODUCTION SWAP POINTS are marked with # <<< REPLACE IN PRODUCTION >>> +""" +from __future__ import annotations + +import time +from pathlib import Path +from typing import Any + +# ── Model cache (keyed by absolute path) ───────────────────────────────────── +_MODEL_CACHE: dict[str, Any] = {} + +# Standard input shapes per task (B, C, H, W) +_INPUT_SHAPES: dict[str, tuple[int, int, int]] = { + "detection": (3, 640, 640), + "segmentation": (3, 640, 640), + "classification": (3, 224, 224), + "generation": (3, 512, 512), + "embedding": (3, 224, 224), +} +_DEFAULT_SHAPE = (3, 640, 640) + + +def run_torch_batch(model_path: str, batch_size: int, task: str = "detection") -> float: + """Run one inference batch and return per-image latency in ms. + + Args: + model_path: Absolute path to the weight file. + batch_size: Number of images in the batch. + task: Model task (affects dummy input shape). + + Returns: + Latency per image in milliseconds. + """ + import torch # raises ImportError if not installed + + device = "cuda" if torch.cuda.is_available() else "cpu" + ext = Path(model_path).suffix.lower() + + model = _load_model(model_path, ext, device) + c, h, w = _INPUT_SHAPES.get(task, _DEFAULT_SHAPE) + dummy = torch.zeros(batch_size, c, h, w, device=device) + + # Warm-up pass (first call is slower due to CUDA kernel compilation) + if device == "cuda": + with torch.no_grad(): + _forward(model, dummy, ext, device) + torch.cuda.synchronize() + + # Timed pass + if device == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + with torch.no_grad(): + _forward(model, dummy, ext, device) + if device == "cuda": + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000 + + return elapsed_ms / batch_size + + +def _load_model(path: str, ext: str, device: str) -> Any: + """Load and cache the model by absolute path.""" + if path in _MODEL_CACHE: + return _MODEL_CACHE[path] + + model = _load_by_ext(path, ext, device) + _MODEL_CACHE[path] = model + return model + + +def _load_by_ext(path: str, ext: str, device: str) -> Any: + """Select loader based on file extension.""" + if ext in (".pt", ".pth"): + return _load_torch(path, device) + if ext == ".safetensors": + return _load_safetensors(path, device) + if ext == ".onnx": + return _load_onnx(path) + raise ValueError(f"Unsupported model format: {ext}") + + +def _load_torch(path: str, device: str) -> Any: + import torch + # <<< REPLACE IN PRODUCTION >>> with proper model class instantiation + # TorchScript models can be loaded directly; state-dict models need + # the model class to be imported separately. + try: + model = torch.jit.load(path, map_location=device) + model.eval() + return model + except RuntimeError: + # Not a TorchScript model — try loading as a full checkpoint + obj = torch.load(path, map_location=device, weights_only=False) + if hasattr(obj, "eval"): + obj.eval() + return obj + # It's a state-dict — we cannot run inference without knowing the arch + raise RuntimeError( + f"Model at {path} is a state-dict; cannot run inference without " + "the model class. Use a TorchScript-exported .pt file." + ) + + +def _load_safetensors(path: str, device: str) -> Any: + # <<< REPLACE IN PRODUCTION >>> safetensors gives tensors only; + # you still need the model class. This is intentionally left as a + # placeholder that raises a clear error rather than silently failing. + raise NotImplementedError( + "safetensors inference requires the model class to be registered. " + "Convert to TorchScript or ONNX for architecture-agnostic inference." + ) + + +def _load_onnx(path: str) -> Any: + import onnxruntime as ort # type: ignore[import] + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + return ort.InferenceSession(path, providers=providers) + + +def _forward(model: Any, dummy: Any, ext: str, device: str) -> Any: + """Run a single forward pass, dispatching by model type.""" + if ext == ".onnx": + import numpy as np + np_input = dummy.cpu().numpy() + input_name = model.get_inputs()[0].name + return model.run(None, {input_name: np_input}) + # TorchScript / nn.Module + return model(dummy) diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37e5e879580118aa4e13a71d8c5f0e0593b18e57 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1 @@ +# datasets package diff --git a/datasets/__pycache__/__init__.cpython-310.pyc b/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ade22816ed78e126cd63f6936bc592121e399982 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/datasets/__pycache__/annotation_parser.cpython-310.pyc b/datasets/__pycache__/annotation_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07be4bc23c3bd4d3c19c641ef28721c3ca9d78bf Binary files /dev/null and b/datasets/__pycache__/annotation_parser.cpython-310.pyc differ diff --git a/datasets/__pycache__/base_adapter.cpython-310.pyc b/datasets/__pycache__/base_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..080264fa222f70b92d2ecd9ef50242b5dbefd337 Binary files /dev/null and b/datasets/__pycache__/base_adapter.cpython-310.pyc differ diff --git a/datasets/__pycache__/format_adapters.cpython-310.pyc b/datasets/__pycache__/format_adapters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f60597584c43e9e8c84679b5b7fb70c4cd8d2fa Binary files /dev/null and b/datasets/__pycache__/format_adapters.cpython-310.pyc differ diff --git a/datasets/__pycache__/import_service.cpython-310.pyc b/datasets/__pycache__/import_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..590fe994e6e642a83e12ddf3db83f2c81f590777 Binary files /dev/null and b/datasets/__pycache__/import_service.cpython-310.pyc differ diff --git a/datasets/__pycache__/registry.cpython-310.pyc b/datasets/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..315248944e2afb5d5be4ceea7eb5b6e85af3540c Binary files /dev/null and b/datasets/__pycache__/registry.cpython-310.pyc differ diff --git a/datasets/__pycache__/viewer_service.cpython-310.pyc b/datasets/__pycache__/viewer_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..527a24b471a7bb43f80ec3872857a4770cf66adc Binary files /dev/null and b/datasets/__pycache__/viewer_service.cpython-310.pyc differ diff --git a/datasets/annotation_parser.py b/datasets/annotation_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..e26f70e10fccdbe01e98e791d5a4368d0e6a0a2f --- /dev/null +++ b/datasets/annotation_parser.py @@ -0,0 +1,576 @@ +""" +datasets/annotation_parser.py — Multi-format annotation parser. + +Supports: + - YOLO (darknet .txt + classes.txt / data.yaml) + - COCO (instances_*.json / _annotations.coco.json) + - Pascal VOC (*.xml) + +All formats normalise to the unified Annotation schema with +normalised bounding boxes (0–1 range, x_topleft, y_topleft, w, h). +""" +from __future__ import annotations + +import csv +import json +import re +import uuid +import xml.etree.ElementTree as ET +from pathlib import Path +from typing import Iterator, Optional + +from observability.logger import get_logger + +log = get_logger("annotation_parser") + + +# ── Unified Output ──────────────────────────────────────────────────────────── + +def _make_ann( + image_id: str, + dataset_id: str, + label: str, + bbox: tuple[float, float, float, float] | None = None, # x, y, w, h (normalised) + normalised: bool = True, + area: float | None = None, + confidence: float | None = None, + ann_type: str = "detection", + segmentation: list[list[float]] | None = None, + keypoints: list[float] | None = None, + metadata: dict | None = None, +) -> dict: + return { + "id": f"ann-{uuid.uuid4().hex[:12]}", + "image_id": image_id, + "dataset_id": dataset_id, + "label": label, + "bbox_x": bbox[0] if bbox else None, + "bbox_y": bbox[1] if bbox else None, + "bbox_w": bbox[2] if bbox else None, + "bbox_h": bbox[3] if bbox else None, + "normalised": 1 if normalised else 0, + "area": area, + "confidence": confidence, + "ann_type": ann_type, + "segmentation": json.dumps(segmentation) if segmentation else None, + "keypoints": json.dumps(keypoints) if keypoints else None, + "metadata": json.dumps(metadata) if metadata else None, + } + + +# ── YOLO Parser ─────────────────────────────────────────────────────────────── + +class YOLOParser: + """ + Reads YOLO darknet annotation files (.txt) + class map. + Each line: (all normalised 0–1) + """ + + @staticmethod + def load_class_map(dataset_root: Path) -> list[str]: + """Attempt to load class names from data.yaml or classes.txt.""" + # Try data.yaml first + for yaml_file in dataset_root.rglob("data.yaml"): + try: + import yaml + with open(yaml_file, 'r', encoding='utf-8', errors='replace') as f: + data = yaml.safe_load(f) + if data and 'names' in data: + names = data['names'] + if isinstance(names, list): + return names + elif isinstance(names, dict): + # Handle dict format: {0: 'class_a', 1: 'class_b'} + return [names[i] for i in sorted(names.keys())] + except Exception: + # Fallback to regex if yaml import fails or parsing fails + try: + text = yaml_file.read_text(encoding="utf-8", errors="replace") + import re as _re + m = _re.search(r"names\s*:\s*\n((?:\s*-\s*.+\n?)+)", text) + if m: + return [line.strip().lstrip("- ").strip() for line in m.group(1).splitlines() if line.strip()] + except Exception: + pass + + # Try classes.txt + for cls_file in dataset_root.rglob("classes.txt"): + try: + lines = cls_file.read_text(encoding="utf-8", errors="replace").splitlines() + return [l.strip() for l in lines if l.strip()] + except Exception: + pass + + return [] + + @staticmethod + def parse_file( + txt_path: Path, + image_id: str, + dataset_id: str, + class_map: list[str], + ) -> list[dict]: + annotations = [] + try: + text = txt_path.read_text(encoding="utf-8", errors="replace") + except OSError: + return annotations + + for line in text.splitlines(): + parts = line.strip().split() + if len(parts) < 5: + continue + try: + cls_id = int(parts[0]) + cx, cy, w, h = float(parts[1]), float(parts[2]), float(parts[3]), float(parts[4]) + # YOLO cx,cy → top-left x,y + x = cx - w / 2 + y = cy - h / 2 + label = class_map[cls_id] if cls_id < len(class_map) else str(cls_id) + annotations.append( + _make_ann(image_id, dataset_id, label, (x, y, w, h), area=w * h) + ) + except (ValueError, IndexError): + continue + + return annotations + + @staticmethod + def iter_dataset( + dataset_root: Path, + dataset_id: str, + class_map: list[str], + ) -> Iterator[tuple[str, str, str, list[dict]]]: + """ + Yield (image_rel_path, image_id, split, annotations) for every image in the dataset. + Walks train/valid/test directories. + """ + # Supported subfolder names for splits + split_map = { + "train": ["train", "training"], + "val": ["valid", "val", "validation"], + "test": ["test", "testing"] + } + + found_any = False + for split_name, folder_names in split_map.items(): + for folder_name in folder_names: + split_dir = dataset_root / folder_name + images_dir = split_dir / "images" + + # Support both split/images and split/ (if images are direct) + search_dir = images_dir if images_dir.exists() else split_dir + if not search_dir.exists(): + continue + + found_any = True + labels_dir = split_dir / "labels" + + for img_path in sorted(search_dir.rglob("*")): + if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".bmp", ".webp"): + continue + + image_id = f"img-{uuid.uuid4().hex[:12]}" + + # Resolve label path + # 1. split/labels/img.txt + # 2. split/img.txt + # 3. img_path.with_suffix(".txt") + label_candidates = [] + if labels_dir.exists(): + label_candidates.append(labels_dir / img_path.with_suffix(".txt").name) + label_candidates.append(img_path.with_suffix(".txt")) + + anns: list[dict] = [] + for label_file in label_candidates: + if label_file.exists(): + anns = YOLOParser.parse_file(label_file, image_id, dataset_id, class_map) + break + + rel_path = str(img_path.relative_to(dataset_root)) + yield rel_path, image_id, split_name, anns + + # Fallback: if no split folders found, scan the root + if not found_any: + for img_path in sorted(dataset_root.rglob("*")): + if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".bmp", ".webp"): + continue + # Skip files inside already processed folders if we had any + image_id = f"img-{uuid.uuid4().hex[:12]}" + label_file = img_path.with_suffix(".txt") + anns = [] + if label_file.exists(): + anns = YOLOParser.parse_file(label_file, image_id, dataset_id, class_map) + + rel_path = str(img_path.relative_to(dataset_root)) + yield rel_path, image_id, "train", anns + + +# ── COCO Parser ─────────────────────────────────────────────────────────────── + +class COCOParser: + """ + Reads COCO JSON annotation files. + Supports: instances_train.json, instances_val.json, _annotations.coco.json + """ + + @staticmethod + def find_annotation_files(dataset_root: Path) -> list[Path]: + patterns = ["instances_*.json", "_annotations.coco.json", "*.json"] + found = [] + for pat in patterns: + for f in dataset_root.rglob(pat): + if "label" not in f.name.lower() and "class" not in f.name.lower(): + found.append(f) + return list(dict.fromkeys(found)) # deduplicate + + @staticmethod + def parse_file( + json_path: Path, + dataset_id: str, + ) -> tuple[list[str], list[tuple[str, str, str, list[dict]]]]: + """ + Returns: (class_names, [(rel_image_path, image_id, split, annotations)]) + """ + try: + data = json.loads(json_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError) as e: + log.warning("coco_parse_error", file=str(json_path), error=str(e)) + return [], [] + + categories = {c["id"]: c["name"] for c in data.get("categories", [])} + class_names = list(categories.values()) + + # Determine split from filename + fname = json_path.stem.lower() + if "train" in fname: + split = "train" + elif "val" in fname or "valid" in fname: + split = "val" + elif "test" in fname: + split = "test" + else: + split = "train" + + # Build image map + image_map: dict[int, dict] = { + img["id"]: img for img in data.get("images", []) + } + + # Group annotations by image + ann_by_image: dict[int, list] = {} + for ann in data.get("annotations", []): + ann_by_image.setdefault(ann["image_id"], []).append(ann) + + results = [] + for coco_img_id, img_meta in image_map.items(): + image_id = f"img-{uuid.uuid4().hex[:12]}" + rel_path = img_meta.get("file_name", "") + anns = [] + for coco_ann in ann_by_image.get(coco_img_id, []): + label = categories.get(coco_ann.get("category_id", -1), "unknown") + bbox = coco_ann.get("bbox", []) + if len(bbox) == 4: + # COCO: [x_topleft, y_topleft, w, h] in pixel coords + img_w = img_meta.get("width", 1) or 1 + img_h = img_meta.get("height", 1) or 1 + bx = bbox[0] / img_w + by = bbox[1] / img_h + bw = bbox[2] / img_w + bh = bbox[3] / img_h + area_pct = (bbox[2] * bbox[3]) / (img_w * img_h) + + # Extract segmentation if available + segmentation = coco_ann.get("segmentation") + # COCO segmentation can be a list of polygons or RLE + poly_data = None + if isinstance(segmentation, list) and len(segmentation) > 0: + # Normalize polygon coordinates + poly_data = [] + for poly in segmentation: + normalized_poly = [] + for i in range(0, len(poly), 2): + normalized_poly.append(poly[i] / img_w) + normalized_poly.append(poly[i+1] / img_h) + poly_data.append(normalized_poly) + + anns.append( + _make_ann( + image_id, + dataset_id, + label, + (bx, by, bw, bh), + area=area_pct, + segmentation=poly_data, + ann_type="segmentation" if poly_data else "detection" + ) + ) + results.append((rel_path, image_id, split, anns)) + + return class_names, results + + +# ── VOC Parser ──────────────────────────────────────────────────────────────── + +class VOCParser: + """Reads Pascal VOC XML annotation files.""" + + @staticmethod + def parse_file( + xml_path: Path, + image_id: str, + dataset_id: str, + ) -> tuple[str, int, int, list[dict]]: + """Returns (filename, width, height, annotations).""" + try: + tree = ET.parse(str(xml_path)) + except ET.ParseError as e: + log.warning("voc_parse_error", file=str(xml_path), error=str(e)) + return "", 0, 0, [] + + root = tree.getroot() + filename = root.findtext("filename") or "" + size = root.find("size") + img_w = int(size.findtext("width") or 1) if size is not None else 1 + img_h = int(size.findtext("height") or 1) if size is not None else 1 + + anns = [] + for obj in root.findall("object"): + label = obj.findtext("name") or "unknown" + bndbox = obj.find("bndbox") + if bndbox is None: + continue + xmin = float(bndbox.findtext("xmin") or 0) + ymin = float(bndbox.findtext("ymin") or 0) + xmax = float(bndbox.findtext("xmax") or 0) + ymax = float(bndbox.findtext("ymax") or 0) + # Normalise + bx = xmin / img_w + by = ymin / img_h + bw = (xmax - xmin) / img_w + bh = (ymax - ymin) / img_h + anns.append(_make_ann(image_id, dataset_id, label, (bx, by, bw, bh))) + + return filename, img_w, img_h, anns + + @staticmethod + def iter_dataset( + dataset_root: Path, + dataset_id: str, + ) -> Iterator[tuple[str, str, str, int, int, list[dict]]]: + """Yield (rel_path, image_id, split, w, h, annotations).""" + for xml_path in sorted(dataset_root.rglob("*.xml")): + image_id = f"img-{uuid.uuid4().hex[:12]}" + filename, w, h, anns = VOCParser.parse_file(xml_path, image_id, dataset_id) + split = "train" + for part in xml_path.parts: + if part in ("train", "training"): + split = "train"; break + if part in ("val", "valid", "validation"): + split = "val"; break + if part in ("test", "testing"): + split = "test"; break + rel_path = filename or str(xml_path.with_suffix(".jpg").relative_to(dataset_root)) + yield rel_path, image_id, split, w, h, anns + + +# ── Roboflow TXT Parser ─────────────────────────────────────────────────────── + +class RoboflowTXTParser: + """ + Reads Roboflow classification TXT formats. + 1. Folder-based: split/class_name/image.jpg + 2. Label-file: split/_annotations.txt (format: filename,class_name) + """ + + @staticmethod + def iter_dataset( + dataset_root: Path, + dataset_id: str, + ) -> Iterator[tuple[str, str, str, list[dict]]]: + split_map = { + "train": ["train", "training"], + "val": ["valid", "val", "validation"], + "test": ["test", "testing"] + } + + found_any = False + for split_name, folder_names in split_map.items(): + for folder_name in folder_names: + split_dir = dataset_root / folder_name + if not split_dir.exists(): + continue + + found_any = True + + # Check for _annotations.txt (Roboflow's flat format) + ann_file = split_dir / "_annotations.txt" + if ann_file.exists(): + try: + with open(ann_file, "r", encoding="utf-8") as f: + # Format is usually: filename,class_name + for line in f: + parts = line.strip().split(",") + if len(parts) >= 2: + fname, label = parts[0], parts[1] + img_path = split_dir / fname + if img_path.exists(): + image_id = f"img-{uuid.uuid4().hex[:12]}" + anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")] + rel_path = str(img_path.relative_to(dataset_root)) + yield rel_path, image_id, split_name, anns + continue # Processed via file, skip folder logic + except Exception: + pass + + # Fallback to Folder-based: split/class_name/image.jpg + for class_dir in split_dir.iterdir(): + if class_dir.is_dir() and class_dir.name.lower() not in ["images", "labels"]: + label = class_dir.name + for img_path in class_dir.rglob("*"): + if img_path.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp", ".webp"): + image_id = f"img-{uuid.uuid4().hex[:12]}" + anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")] + rel_path = str(img_path.relative_to(dataset_root)) + yield rel_path, image_id, split_name, anns + + # Fallback to root scan if no split folders found + if not found_any: + for img_path in sorted(dataset_root.rglob("*")): + if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".bmp", ".webp"): + continue + # Simple heuristic: parent folder is class name + label = img_path.parent.name if img_path.parent != dataset_root else "unknown" + image_id = f"img-{uuid.uuid4().hex[:12]}" + anns = [_make_ann(image_id, dataset_id, label, ann_type="classification")] + rel_path = str(img_path.relative_to(dataset_root)) + yield rel_path, image_id, "train", anns + +class CSVParser: + """ + Reads CSV files for NLP (classification, NER) or Tabular data. + """ + + @staticmethod + def detect_delimiter(file_path: Path) -> str: + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + header = f.readline() + if ';' in header: return ';' + if '\t' in header: return '\t' + return ',' + except Exception: + return ',' + + @staticmethod + def parse_file( + csv_path: Path, + dataset_id: str, + text_column: str = "text", + label_column: str = "label", + ) -> list[dict]: + annotations = [] + delimiter = CSVParser.detect_delimiter(csv_path) + try: + with open(csv_path, mode='r', encoding='utf-8', errors='replace') as f: + reader = csv.DictReader(f, delimiter=delimiter) + for row in reader: + image_id = f"txt-{uuid.uuid4().hex[:12]}" + text = row.get(text_column, "") + label = row.get(label_column, "unknown") + if text: + annotations.append( + _make_ann( + image_id=image_id, + dataset_id=dataset_id, + label=label, + bbox=(0, 0, 0, 0), + ann_type="nlp_classification" + ) + ) + except Exception as e: + log.error("csv_parse_error", file=str(csv_path), error=str(e)) + return annotations + + +# ── Utilities ──────────────────────────────────────────────────────────────── + +def _img_dimensions(path: Path) -> tuple[int, int]: + """Fast dimension detection via struct.""" + try: + import struct + with open(path, "rb") as f: + data = f.read(24) + if data[:8] == b"\x89PNG\r\n\x1a\n": + return struct.unpack(">II", data[16:24]) + if data[:2] == b"\xff\xd8": + f.seek(0) + full = f.read(2048) # Read more for JPEG header + i = 2 + while i < len(full) - 9: + if full[i] == 0xFF and full[i + 1] in (0xC0, 0xC1, 0xC2): + h, w = struct.unpack(">HH", full[i + 5:i + 9]) + return int(w), int(h) + i += 1 + except: pass + return 0, 0 + + +# ── Format Detector ─────────────────────────────────────────────────────────── + +def detect_format(dataset_root: Path) -> str: + """Heuristically detect the annotation format in a dataset directory.""" + # COCO: look for JSON with 'images' and 'annotations' keys + for jf in dataset_root.rglob("*.json"): + try: + snippet = jf.read_text(encoding="utf-8", errors="replace")[:2048] + if '"images"' in snippet and '"annotations"' in snippet: + return "coco" + except OSError: + pass + + # VOC: look for XML files with root + for xf in dataset_root.rglob("*.xml"): + try: + snippet = xf.read_text(encoding="utf-8", errors="replace")[:512] + if "" in snippet: + return "voc" + except OSError: + pass + + # YOLO: check for .txt label files and data.yaml + if list(dataset_root.rglob("data.yaml")): + return "yolo" + + txt_files = list(dataset_root.rglob("*.txt")) + # Filter out common non-label files + label_txts = [f for f in txt_files if f.name not in ("classes.txt", "obj.names", "README.txt", "LICENSE.txt", "README.roboflow.txt")] + if label_txts: + # Check if first line looks like YOLO ( ) + try: + first_txt = label_txts[0] + content = first_txt.read_text(encoding="utf-8").strip().split('\n')[0] + if re.match(r"^\d+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+", content): + return "yolo" + except Exception: + pass + + # Roboflow Classification TXT: check for split folders containing only subfolders (class names) + # or check for _annotations.txt + if list(dataset_root.rglob("_annotations.txt")): + return "txt" + + # Check for folder-based classification (split/class_name/img.jpg) + # If we see folders that aren't 'images' or 'labels' inside train/val/test + for split in ["train", "valid", "test"]: + split_dir = dataset_root / split + if split_dir.exists() and split_dir.is_dir(): + subdirs = [d for d in split_dir.iterdir() if d.is_dir()] + if subdirs and not any(d.name.lower() in ["images", "labels"] for d in subdirs): + return "txt" + + # CSV/NLP: check for csv files + if list(dataset_root.rglob("*.csv")): + return "csv" + + return "custom" diff --git a/datasets/base_adapter.py b/datasets/base_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..9801ee4cd0b8f919122a79452dde1e9dbf63a170 --- /dev/null +++ b/datasets/base_adapter.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Tuple, Iterator, Dict, Any, Optional +from models.dataset import UniversalDatasetItem, DatasetTask + +class DatasetAdapter(ABC): + """ + Base interface for all dataset format adapters. + Following the senior architect pattern: decoupling format logic from import orchestration. + """ + + @abstractmethod + def detect(self, dataset_path: Path) -> bool: + """Return True if this adapter can handle the dataset at the given path.""" + pass + + @abstractmethod + def get_task(self, dataset_path: Path) -> DatasetTask: + """Identify the primary task type (detection, classification, etc.) for this dataset.""" + pass + + @abstractmethod + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + """ + Yield (image_record, annotations) for each item in the dataset. + Memory-efficient streaming for large Roboflow datasets. + """ + pass + + @abstractmethod + def get_class_names(self, dataset_path: Path) -> List[str]: + """Extract or derive the list of class names from the dataset.""" + pass + + def get_metadata(self, dataset_path: Path) -> Dict[str, Any]: + """Optional: Extract additional format-specific metadata.""" + return {} diff --git a/datasets/format_adapters.py b/datasets/format_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..80897b5a79ab81d4f6534e22433aaa2b46517383 --- /dev/null +++ b/datasets/format_adapters.py @@ -0,0 +1,235 @@ +from pathlib import Path +import json +import re +from typing import Any, List, Tuple, Iterator, Dict +from .base_adapter import DatasetAdapter +from models.dataset import UniversalDatasetItem, DatasetContentType, UniversalAnnotation, UniversalAnnotationType, DatasetTask +from .annotation_parser import YOLOParser, COCOParser, VOCParser, RoboflowTXTParser, _img_dimensions + +class YOLOAdapter(DatasetAdapter): + def detect(self, dataset_path: Path) -> bool: + if list(dataset_path.rglob("data.yaml")): + return True + txt_files = list(dataset_path.rglob("*.txt")) + label_txts = [f for f in txt_files if f.name not in ("classes.txt", "obj.names", "README.txt", "LICENSE.txt", "README.roboflow.txt")] + if label_txts: + try: + content = label_txts[0].read_text(encoding="utf-8").strip().split('\n')[0] + if re.match(r"^\d+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+\s+[\d\.]+", content): + return True + except: pass + return False + + def get_task(self, dataset_path: Path) -> DatasetTask: + return DatasetTask.detection + + def get_class_names(self, dataset_path: Path) -> List[str]: + return YOLOParser.load_class_map(dataset_path) + + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + class_map = self.get_class_names(dataset_path) + for rel_path, image_id, split, anns in YOLOParser.iter_dataset(dataset_path, dataset_id, class_map): + abs_path = dataset_path / rel_path + w, h = _img_dimensions(abs_path) + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + yield img_rec, anns + +class COCOAdapter(DatasetAdapter): + def detect(self, dataset_path: Path) -> bool: + for jf in dataset_path.rglob("*.json"): + try: + snippet = jf.read_text(encoding="utf-8", errors="replace")[:2048] + if '"images"' in snippet and '"annotations"' in snippet: + return True + except: pass + return False + + def get_task(self, dataset_path: Path) -> DatasetTask: + return DatasetTask.segmentation # Roboflow COCO often implies segmentation + + def get_class_names(self, dataset_path: Path) -> List[str]: + ann_files = COCOParser.find_annotation_files(dataset_path) + all_classes = [] + for ann_file in ann_files: + classes, _ = COCOParser.parse_file(ann_file, "dummy") + all_classes = list(dict.fromkeys(all_classes + classes)) + return all_classes + + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + ann_files = COCOParser.find_annotation_files(dataset_path) + for ann_file in ann_files: + _, coco_results = COCOParser.parse_file(ann_file, dataset_id) + for rel_path, image_id, split, anns in coco_results: + abs_path = dataset_path / rel_path + w, h = _img_dimensions(abs_path) + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + yield img_rec, anns + +class VOCAdapter(DatasetAdapter): + def detect(self, dataset_path: Path) -> bool: + for xf in dataset_path.rglob("*.xml"): + try: + snippet = xf.read_text(encoding="utf-8", errors="replace")[:512] + if "" in snippet: + return True + except: pass + return False + + def get_task(self, dataset_path: Path) -> DatasetTask: + return DatasetTask.detection + + def get_class_names(self, dataset_path: Path) -> List[str]: + classes = set() + for _, _, _, _, _, anns in VOCParser.iter_dataset(dataset_path, "dummy"): + for ann in anns: + classes.add(ann["label"]) + return sorted(list(classes)) + + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + for rel_path, image_id, split, w, h, anns in VOCParser.iter_dataset(dataset_path, dataset_id): + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + yield img_rec, anns + +class CreateMLAdapter(DatasetAdapter): + def detect(self, dataset_path: Path) -> bool: + for jf in dataset_path.rglob("*.json"): + try: + snippet = jf.read_text(encoding="utf-8", errors="replace")[:1024] + if '"image"' in snippet and '"annotations"' in snippet and "[" in snippet: + return True + except: pass + return False + + def get_task(self, dataset_path: Path) -> DatasetTask: + return DatasetTask.detection + + def get_class_names(self, dataset_path: Path) -> List[str]: + classes = set() + for jf in dataset_path.rglob("*.json"): + try: + data = json.loads(jf.read_text(encoding="utf-8")) + if isinstance(data, list): + for item in data: + for ann in item.get("annotations", []): + if "label" in ann: classes.add(ann["label"]) + except: pass + return sorted(list(classes)) + + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + from .annotation_parser import _make_ann + for jf in dataset_path.rglob("*.json"): + try: + data = json.loads(jf.read_text(encoding="utf-8")) + if not isinstance(data, list): continue + + # Determine split from path + split = "train" + if "val" in jf.parts or "valid" in jf.parts: split = "val" + elif "test" in jf.parts: split = "test" + + for item in data: + rel_img_path = item.get("image") + if not rel_img_path: continue + + # Try to find the image relative to JSON or root + img_path = jf.parent / rel_img_path + if not img_path.exists(): + img_path = dataset_path / rel_img_path + + if img_path.exists(): + image_id = f"img-{uuid.uuid4().hex[:12]}" + w, h = _img_dimensions(img_path) + + anns = [] + for ca in item.get("annotations", []): + label = ca.get("label", "unknown") + coord = ca.get("coordinates", {}) + # CreateML coords are usually center-based pixels: {x, y, width, height} + if "x" in coord and "y" in coord and w > 0 and h > 0: + cx, cy, bw, bh = coord["x"], coord["y"], coord["width"], coord["height"] + # Convert to top-left normalized + nx = (cx - bw/2) / w + ny = (cy - bh/2) / h + nw = bw / w + nh = bh / h + anns.append(_make_ann(image_id, dataset_id, label, (nx, ny, nw, nh))) + + img_rec = { + "id": image_id, "filename": img_path.name, + "rel_path": str(img_path.relative_to(dataset_path)), + "width": w, "height": h, "split": split, "ann_count": len(anns) + } + yield img_rec, anns + except: pass + +class NLPAdapter(DatasetAdapter): + def detect(self, dataset_path: Path) -> bool: + return any(dataset_path.rglob("*.csv")) or any(dataset_path.rglob("*.tsv")) + + def get_task(self, dataset_path: Path) -> DatasetTask: + return DatasetTask.nlp + + def get_class_names(self, dataset_path: Path) -> List[str]: + # Implementation for NLP class names + return [] + + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + # Implementation for NLP items + yield {}, [] + +class TabularAdapter(DatasetAdapter): + def detect(self, dataset_path: Path) -> bool: + return False # Placeholder + + def get_task(self, dataset_path: Path) -> DatasetTask: + return DatasetTask.classification + + def get_class_names(self, dataset_path: Path) -> List[str]: + return [] + + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + yield {}, [] + +class RoboflowClassificationAdapter(DatasetAdapter): + def detect(self, dataset_path: Path) -> bool: + # Check for _annotations.txt or folder-based classification + if list(dataset_path.rglob("_annotations.txt")): return True + for split in ["train", "valid", "test"]: + split_dir = dataset_path / split + if split_dir.exists() and split_dir.is_dir(): + subdirs = [d for d in split_dir.iterdir() if d.is_dir()] + if subdirs and not any(d.name.lower() in ["images", "labels"] for d in subdirs): + return True + return False + + def get_task(self, dataset_path: Path) -> DatasetTask: + return DatasetTask.classification + + def get_class_names(self, dataset_path: Path) -> List[str]: + classes = set() + for _, _, _, anns in RoboflowTXTParser.iter_dataset(dataset_path, "dummy"): + for ann in anns: classes.add(ann["label"]) + return sorted(list(classes)) + + def iter_items(self, dataset_id: str, dataset_path: Path) -> Iterator[Tuple[Dict[str, Any], List[Dict[str, Any]]]]: + for rel_path, image_id, split, anns in RoboflowTXTParser.iter_dataset(dataset_path, dataset_id): + abs_path = dataset_path / rel_path + w, h = _img_dimensions(abs_path) + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + yield img_rec, anns diff --git a/datasets/import_service.py b/datasets/import_service.py new file mode 100644 index 0000000000000000000000000000000000000000..b04bc2d57b71854f66d68b9990940c7cbe019d61 --- /dev/null +++ b/datasets/import_service.py @@ -0,0 +1,589 @@ +""" +datasets/import_service.py — Dataset Import Pipeline. + +Pipeline stages: + 1. Create job record + 2. Download dataset zip (chunked, progress-tracked) + 3. Extract zip safely (path-traversal protected) + 4. Detect annotation format & task type + 5. Index images into dataset_images table + 6. Parse & store metadata (Stats only, annotations are read-on-demand) + 7. Update dataset stats (images, classes, size) + 8. Mark job completed / failed + +All stages run as background tasks. +Supports Roboflow, HuggingFace, and local file/folder imports. +""" +from __future__ import annotations + +import asyncio +import hashlib +import os +import shutil +import uuid +import zipfile +from datetime import datetime +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple + +import aiofiles +import httpx +from huggingface_hub import snapshot_download + +from config import settings +from . import registry as ds_reg +from .format_adapters import ( + YOLOAdapter, COCOAdapter, VOCAdapter, CreateMLAdapter, + RoboflowClassificationAdapter, NLPAdapter, TabularAdapter +) +from .base_adapter import DatasetAdapter +from .annotation_parser import _img_dimensions +from observability.logger import audit, get_logger +from models.dataset import DatasetStatus, DatasetTask, ImportRequest, Dataset + +log = get_logger("import_service") + +ADAPTERS: List[DatasetAdapter] = [ + YOLOAdapter(), + COCOAdapter(), + VOCAdapter(), + CreateMLAdapter(), + RoboflowClassificationAdapter(), + NLPAdapter(), + TabularAdapter(), +] + +def get_adapter_for_path(path: Path) -> DatasetAdapter | None: + for adapter in ADAPTERS: + if adapter.detect(path): + return adapter + return None + +async def recover_stale_jobs() -> None: + """Cleanup dataset import jobs that were left in 'running' or 'queued' state.""" + await ds_reg.cleanup_stale_jobs() + +def _dataset_path(dataset_id: str) -> Path: + return settings.datasets_dir / dataset_id + +# ── Entry Point ────────────────────────────────────────────────────────────── + +async def start_import(req: ImportRequest) -> str: + """Entry point to initiate a background import job.""" + job_id = f"job-{uuid.uuid4().hex[:8]}" + + # Create initial job record + await ds_reg.update_job( + job_id, + dataset_id=req.dataset_id, + status="queued", + progress=0, + message="Import queued", + type=str(req.source) + ) + + # Launch background task + asyncio.create_task(_run_pipeline(job_id, req, req.dataset_name or req.dataset_id)) + + return job_id + + +# ── Pipeline orchestrator ──────────────────────────────────────────────────── + +async def _run_pipeline(job_id: str, req: ImportRequest, dataset_name: str) -> None: + started = datetime.utcnow().isoformat() + await ds_reg.update_job(job_id, status="running", started_at=started, message="Starting import") + await ds_reg.update_dataset_status(req.dataset_id, DatasetStatus.importing, progress=0.01) + + try: + # Stage 1 – Resolve download URL or local path + source_path = await _stage_acquire(job_id, req) + + # Stage 2 – Extract / Prepare Directory + extract_dir = await _stage_extract(job_id, req.dataset_id, source_path) + + # Stage 3 – Detect adapter and Task + await ds_reg.update_job(job_id, progress=0.55, message="Detecting dataset format...") + adapter = await asyncio.to_thread(get_adapter_for_path, extract_dir) + + if not adapter: + log.warning("no_adapter_found_generic_fallback", dataset_id=req.dataset_id) + image_records = await asyncio.to_thread(_scan_images_generic, req.dataset_id, extract_dir) + class_names = [] + task = DatasetTask.classification + fmt_name = "custom" + else: + task = adapter.get_task(extract_dir) + fmt_name = adapter.__class__.__name__.replace("Adapter", "").lower() + + log.info("adapter_detected", job_id=job_id, format=fmt_name, task=task) + await ds_reg.update_job(job_id, progress=0.60, message=f"Parsing {fmt_name.upper()} {task.upper()}") + + # Stage 4 – Parse Metadata & Annotations (Streaming) + class_names = await asyncio.to_thread(adapter.get_class_names, extract_dir) + image_records = [] + all_annotations = [] + + # Health metrics tracking + hashes = {} # hash -> filename + duplicates = 0 + empty_images = 0 + total_ann_count = 0 + + for img_rec, anns in adapter.iter_items(req.dataset_id, extract_dir): + # Duplicate detection via MD5 hash + abs_path = extract_dir / img_rec["rel_path"] + if abs_path.exists(): + img_hash = _calculate_hash(abs_path) + if img_hash in hashes: + duplicates += 1 + img_rec["metadata"] = json.dumps({"is_duplicate": True, "original": hashes[img_hash]}) + else: + hashes[img_hash] = img_rec["filename"] + + if not anns: + empty_images += 1 + + total_ann_count += len(anns) + image_records.append(img_rec) + all_annotations.extend(anns) + + if not image_records: + raise ValueError(f"No valid data files found in {extract_dir}") + + # Stage 5 – Indexing + await ds_reg.update_job(job_id, progress=0.80, message=f"Indexing {len(image_records)} items") + await ds_reg.index_images(req.dataset_id, image_records) + + if all_annotations: + await ds_reg.update_job(job_id, progress=0.85, message=f"Indexing {len(all_annotations)} annotations") + await ds_reg.bulk_insert_annotations(all_annotations) + + # Stage 6 – Stats & Health Analysis + size_bytes = await asyncio.to_thread(_dir_size, extract_dir) + + # Calculate Health Score (0-100) + # Factors: duplicates, empty images (for detection), class balance (TODO) + score = 100.0 + if len(image_records) > 0: + dup_penalty = (duplicates / len(image_records)) * 50 + empty_penalty = (empty_images / len(image_records)) * 20 if task == DatasetTask.detection else 0 + score = max(0.0, 100.0 - dup_penalty - empty_penalty) + + stats_payload = { + "image_count": len(image_records), + "annotation_count": total_ann_count, + "class_count": len(class_names), + "empty_images": empty_images, + "duplicate_count": duplicates, + "health_score": round(score, 1), + "avg_objects": round(total_ann_count / len(image_records), 2) if image_records else 0 + } + + await ds_reg.update_dataset_stats( + req.dataset_id, + len(image_records), + len(class_names), + class_names, + size_bytes, + stats=stats_payload + ) + await ds_reg.update_dataset_task(req.dataset_id, task) + + # Cleanup temp zip if applicable + if source_path.is_file() and source_path.suffix.lower() == ".zip" and "_tmp" in str(source_path): + source_path.unlink(missing_ok=True) + + # Stage 7 – Project Linking (Integration point) + local_path = str(extract_dir) + from projects.service import link_dataset_to_active_project + project_ds_root = await link_dataset_to_active_project(req.dataset_id, local_path) + final_local_path = str(project_ds_root) if project_ds_root and project_ds_root.exists() else local_path + + # Completion + await ds_reg.update_job( + job_id, status="completed", progress=1.0, + message="Import complete", ended_at=datetime.utcnow().isoformat(), + ) + await ds_reg.update_dataset_status(req.dataset_id, DatasetStatus.imported, progress=1.0, local_path=final_local_path) + await audit("dataset_import_complete", {"job_id": job_id, "path": final_local_path}, job_id=job_id) + log.info("import_complete", job_id=job_id, dataset_id=req.dataset_id) + + except asyncio.CancelledError: + await _fail_job(job_id, req.dataset_id, "Import cancelled by user or system") + raise + except Exception as exc: + log.error("import_failed", job_id=job_id, error=str(exc)) + await _fail_job(job_id, req.dataset_id, str(exc)) + await audit("dataset_import_error", {"job_id": job_id, "error": str(exc)}, job_id=job_id, level="error") + + +async def _fail_job(job_id: str, dataset_id: str, error: str) -> None: + await ds_reg.update_job( + job_id, status="failed", error=error, + ended_at=datetime.utcnow().isoformat(), + message="Import failed", + ) + await ds_reg.update_dataset_status(dataset_id, DatasetStatus.failed, progress=0.0) + + +# ── Stage 1: Acquire source ────────────────────────────────────────────────── + +async def _stage_acquire(job_id: str, req: ImportRequest) -> Path: + """Resolves the source (Download URL, HF Repo, or Local Path).""" + await ds_reg.update_job(job_id, progress=0.05, message="Acquiring source...") + + if req.source in ("roboflow", "roboflow_curl"): + return await _acquire_roboflow(job_id, req) + + if req.source == "huggingface": + return await _acquire_huggingface(job_id, req) + + if req.source == "local": + return await _acquire_local(job_id, req) + + raise ValueError(f"Unsupported source provider: {req.source}") + + +async def _acquire_roboflow(job_id: str, req: ImportRequest) -> Path: + """Specialized Roboflow downloader using SDK or direct link.""" + # Attempt SDK first (more reliable for Universe) + try: + from roboflow import Roboflow + api_key = req.roboflow_key or (req.headers.get("Authorization") if req.headers else None) + if api_key and "Bearer " in str(api_key): + api_key = api_key.split("Bearer ")[-1].strip() + + if api_key and req.roboflow_workspace and req.roboflow_project: + rf = Roboflow(api_key=api_key) + project = rf.workspace(req.roboflow_workspace).project(req.roboflow_project) + version_obj = project.version(req.roboflow_version or 1) + + tmp_target = DATASETS_ROOT / "_tmp" / f"rf-{uuid.uuid4().hex[:8]}" + await ds_reg.update_job(job_id, progress=0.10, message="Downloading via Roboflow SDK...") + + # Threaded SDK call + await asyncio.to_thread( + version_obj.download, + _format_to_rf_slug(str(req.format)), + location=str(tmp_target) + ) + return tmp_target + except Exception as e: + log.warning("roboflow_sdk_fallback", error=str(e)) + + # Fallback to direct HTTP download + url = req.download_url + if not url and req.source == "roboflow": + from adapters.roboflow_adapter import RoboflowAdapter + url = await RoboflowAdapter.get_download_url( + api_key=req.roboflow_key, + workspace=req.roboflow_workspace, + project_id=req.roboflow_project, + version=req.roboflow_version, + export_format=_format_to_rf_slug(str(req.format)), + ) + + if not url: + raise ValueError("Could not resolve Roboflow download URL") + + return await _download_zip(job_id, req.dataset_id, url, req.headers) + + +async def _acquire_huggingface(job_id: str, req: ImportRequest) -> Path: + if not req.hf_dataset_id: + raise ValueError("hf_dataset_id is missing") + + dest_dir = _dataset_path(req.dataset_id) + dest_dir.mkdir(parents=True, exist_ok=True) + + await ds_reg.update_job(job_id, progress=0.10, message=f"Cloning {req.hf_dataset_id} from HF...") + + await asyncio.to_thread( + snapshot_download, + repo_id=req.hf_dataset_id, + repo_type="dataset", + local_dir=str(dest_dir), + token=settings.hf_token, + local_dir_use_symlinks=False + ) + return dest_dir + + +async def _acquire_local(job_id: str, req: ImportRequest) -> Path: + if not req.local_path: + raise ValueError("local_path is missing for local import") + + path = Path(os.path.normpath(req.local_path.strip().strip('"').strip("'"))) + if not path.exists(): + raise FileNotFoundError(f"Local path does not exist: {path}") + + return path + + +# ── Stage 2: Extraction ────────────────────────────────────────────────────── + +async def _stage_extract(job_id: str, dataset_id: str, source_path: Path) -> Path: + dest = _dataset_path(dataset_id) + dest.mkdir(parents=True, exist_ok=True) + + if source_path.is_dir(): + if source_path == dest: + return dest + await ds_reg.update_job(job_id, progress=0.45, message="Copying local files...") + await asyncio.to_thread(_copy_dir_contents, source_path, dest) + return dest + + # It's a zip + await ds_reg.update_job(job_id, progress=0.45, message="Extracting archive...") + await ds_reg.update_dataset_status(dataset_id, DatasetStatus.extracting, progress=0.45) + await asyncio.to_thread(_safe_extract, source_path, dest) + return dest + + +# ── Stage 3: Parsing (Memory-Safe) ─────────────────────────────────────────── + +def _heuristic_task_detection(fmt: str, root: Path) -> DatasetTask: + """Improved task detection based on file content.""" + if fmt == "csv": + return DatasetTask.nlp + + # Check for segmentation in COCO + if fmt == "coco": + # Sample first few lines of JSON if possible or check file size + return DatasetTask.segmentation # Heuristic: most modern COCO use cases + + if fmt in ("yolo", "voc"): + return DatasetTask.detection + + return DatasetTask.classification + + +def _parse_yolo(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]: + class_map = YOLOParser.load_class_map(root) + results = [] + # Generator approach to keep memory low + for rel_path, image_id, split, anns in YOLOParser.iter_dataset(root, dataset_id, class_map): + abs_path = root / rel_path + w, h = _img_dimensions(abs_path) + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + results.append((img_rec, anns)) + return class_map, results + + +def _parse_coco(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]: + ann_files = COCOParser.find_annotation_files(root) + all_classes: list[str] = [] + results = [] + for ann_file in ann_files: + classes, coco_results = COCOParser.parse_file(ann_file, dataset_id) + all_classes = list(dict.fromkeys(all_classes + classes)) + for rel_path, image_id, split, anns in coco_results: + abs_path = root / rel_path + w, h = _img_dimensions(abs_path) + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + results.append((img_rec, anns)) + return all_classes, results + + +def _parse_voc(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]: + class_set = set() + results = [] + for rel_path, image_id, split, w, h, anns in VOCParser.iter_dataset(root, dataset_id): + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + results.append((img_rec, anns)) + for ann in anns: + class_set.add(ann["label"]) + return sorted(list(class_set)), results + + +def _parse_csv(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]: + all_classes = set() + results = [] + for csv_path in root.rglob("*.csv"): + anns = CSVParser.parse_file(csv_path, dataset_id) + # For CSV, each annotation is a row. We group by text entry id (image_id) + anns_by_id: Dict[str, List[Dict]] = {} + for ann in anns: + all_classes.add(ann["label"]) + anns_by_id.setdefault(ann["image_id"], []).append(ann) + + for text_id, grouped_anns in anns_by_id.items(): + img_rec = { + "id": text_id, "filename": csv_path.name, + "rel_path": str(csv_path.relative_to(root)), + "width": 0, "height": 0, "split": "train", "ann_count": len(grouped_anns), + } + results.append((img_rec, grouped_anns)) + return sorted(list(all_classes)), results + + +def _parse_txt(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]: + from datasets.annotation_parser import RoboflowTXTParser + results = [] + class_set = set() + + for rel_path, image_id, split, anns in RoboflowTXTParser.iter_dataset(root, dataset_id): + abs_path = root / rel_path + w, h = _img_dimensions(abs_path) + img_rec = { + "id": image_id, "filename": Path(rel_path).name, + "rel_path": str(rel_path), "width": w, "height": h, + "split": split, "ann_count": len(anns), + } + results.append((img_rec, anns)) + for ann in anns: + class_set.add(ann["label"]) + + return sorted(list(class_set)), results + + +def _parse_generic_folder(dataset_id: str, root: Path) -> Tuple[List[str], List[Tuple[Dict, List[Dict]]]]: + """ + Enhanced generic folder parser. Supports: + 1. root/class_name/img.jpg + 2. root/train/class_name/img.jpg + 3. root/images/img.jpg + """ + results = [] + class_set = set() + exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"} + + # Structural keywords to ignore as classes + ignore = {"images", "labels", "train", "val", "test", "validation", "training", "valid", "testing", "unknown", "annotations"} + + for img_path in sorted(root.rglob("*")): + if img_path.suffix.lower() not in exts: + continue + + rel_path = img_path.relative_to(root) + parts = rel_path.parts + + # Heuristic for class detection + label = "unknown" + split = "train" + + # Detect split if first folder is a split keyword + if parts[0].lower() in ignore and len(parts) > 1: + if parts[0].lower() in ("train", "training"): split = "train" + elif parts[0].lower() in ("val", "valid", "validation"): split = "val" + elif parts[0].lower() in ("test", "testing"): split = "test" + + # Check if next part is class name + if len(parts) > 2 and parts[1].lower() not in ignore: + label = parts[1] + elif len(parts) > 1 and parts[1].lower() not in ignore: + label = parts[1] + elif len(parts) > 1 and parts[0].lower() not in ignore: + label = parts[0] + + anns = [] + if label != "unknown": + class_set.add(label) + image_id = f"img-{uuid.uuid4().hex[:12]}" + # Create a virtual annotation for classification + from datasets.annotation_parser import _make_ann + anns.append(_make_ann(image_id, dataset_id, label, ann_type="classification")) + else: + image_id = f"img-{uuid.uuid4().hex[:12]}" + + w, h = _img_dimensions(img_path) + img_rec = { + "id": image_id, + "filename": img_path.name, + "rel_path": str(rel_path), + "width": w, "height": h, + "split": split, + "ann_count": len(anns), + } + results.append((img_rec, anns)) + + return sorted(list(class_set)), results + + +# ── Utilities ──────────────────────────────────────────────────────────────── + +async def _download_zip(job_id: str, dataset_id: str, url: str, custom_headers: dict = None) -> Path: + tmp_dir = DATASETS_ROOT / "_tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + zip_path = tmp_dir / f"{dataset_id}-{uuid.uuid4().hex[:8]}.zip" + + headers = { + "User-Agent": "Mozilla/5.0 (MLForge Workbench)", + "Accept": "application/zip, application/octet-stream, */*", + } + if custom_headers: headers.update(custom_headers) + + async with httpx.AsyncClient(follow_redirects=True, timeout=600.0, headers=headers) as client: + async with client.stream("GET", url) as resp: + resp.raise_for_status() + total = int(resp.headers.get("content-length", 0)) or None + downloaded = 0 + async with aiofiles.open(zip_path, "wb") as f: + async for chunk in resp.aiter_bytes(chunk_size=settings.download_chunk_size): + await f.write(chunk) + downloaded += len(chunk) + if total: + pct = 0.10 + (downloaded / total) * 0.35 # 10% -> 45% + await ds_reg.update_job(job_id, progress=round(pct, 3), message=f"Downloading: {_fmt_bytes(downloaded)} / {_fmt_bytes(total)}") + + return zip_path + + +def _safe_extract(zip_path: Path, dest: Path) -> None: + with zipfile.ZipFile(str(zip_path), "r") as zf: + for member in zf.namelist(): + if os.path.isabs(member) or ".." in Path(member).parts: continue + zf.extract(member, str(dest)) + + +def _copy_dir_contents(src: Path, dest: Path) -> None: + for item in src.iterdir(): + s, d = src / item.name, dest / item.name + if s.is_dir(): shutil.copytree(s, d, dirs_exist_ok=True) + else: shutil.copy2(s, d) + + +def _scan_images_generic(dataset_id: str, root: Path) -> list[dict]: + records = [] + exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"} + for img_path in sorted(root.rglob("*")): + if img_path.suffix.lower() in exts: + w, h = _img_dimensions(img_path) + records.append({ + "id": f"img-{uuid.uuid4().hex[:12]}", + "filename": img_path.name, + "rel_path": str(img_path.relative_to(root)), + "width": w, "height": h, "split": "train", "ann_count": 0, + }) + return records + + +def _dir_size(path: Path) -> int: + return sum(f.stat().st_size for f in path.rglob("*") if f.is_file()) + + +def _fmt_bytes(n: int) -> str: + for unit in ("B", "KB", "MB", "GB", "TB"): + if n < 1024: return f"{n:.1f} {unit}" + n /= 1024 + return f"{n:.1f} PB" + + +def _format_to_rf_slug(fmt: str) -> str: + return {"yolo": "yolov8", "coco": "coco", "voc": "voc"}.get(fmt, "yolov8") + +def _format_to_rf_slug(fmt: str) -> str: + return {"yolo": "yolov8", "coco": "coco", "voc": "voc"}.get(fmt, "yolov8") diff --git a/datasets/registry.py b/datasets/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..e9aaed1cab268d9500f6fc75f5736cb394d8a72f --- /dev/null +++ b/datasets/registry.py @@ -0,0 +1,452 @@ +""" +datasets/registry.py — Dataset Registry: persistent CRUD against datasets table. +All DB interactions for datasets and dataset_jobs live here. +""" +from __future__ import annotations + +import json +import uuid +from datetime import datetime +from typing import Any + +from database.connection import get_db +from models.dataset import Dataset, DatasetJob, DatasetStatus, row_to_dataset, row_to_job +from observability.logger import get_logger + +log = get_logger("dataset_registry") + + +# ── Dataset CRUD ────────────────────────────────────────────────────────────── + +async def get_all_datasets( + task: str | None = None, + format: str | None = None, + source: str | None = None, + status: str | None = None, + search: str | None = None, + starred: bool | None = None, + limit: int = 500, + offset: int = 0, +) -> list[Dataset]: + db = await get_db() + clauses = [] + params: list[Any] = [] + + if task: + clauses.append("task = ?") + params.append(task) + if format: + clauses.append("format = ?") + params.append(format) + if source: + clauses.append("source = ?") + params.append(source) + if status: + clauses.append("status = ?") + params.append(status) + if starred is not None: + clauses.append("starred = ?") + params.append(1 if starred else 0) + if search: + clauses.append("(name LIKE ? OR description LIKE ? OR tags LIKE ?)") + q = f"%{search}%" + params.extend([q, q, q]) + + where = f"WHERE {' AND '.join(clauses)}" if clauses else "" + sql = f"SELECT * FROM datasets {where} ORDER BY updated_at DESC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + async with db.execute(sql, params) as cur: + rows = await cur.fetchall() + return [row_to_dataset(r) for r in rows] + + +async def get_dataset_stats(dataset_id: str) -> dict: + """Get pre-computed class distributions and statistics from the indexed annotations.""" + db = await get_db() + + # Class distribution (from dataset_annotations table) + async with db.execute( + "SELECT label, COUNT(*) as count FROM dataset_annotations WHERE dataset_id=? GROUP BY label ORDER BY count DESC", + (dataset_id,) + ) as cur: + dist = await cur.fetchall() + + # Split distribution (from dataset_images table) + async with db.execute( + "SELECT split, COUNT(*) as count FROM dataset_images WHERE dataset_id=? GROUP BY split", + (dataset_id,) + ) as cur: + splits = await cur.fetchall() + + return { + "class_distribution": {row["label"]: row["count"] for row in dist}, + "split_distribution": {row["split"]: row["count"] for row in splits} + } + + +async def get_dataset(dataset_id: str) -> Dataset | None: + db = await get_db() + async with db.execute("SELECT * FROM datasets WHERE id = ?", (dataset_id,)) as cur: + row = await cur.fetchone() + return row_to_dataset(row) if row else None + + +async def count_datasets() -> int: + db = await get_db() + async with db.execute("SELECT COUNT(*) FROM datasets") as cur: + row = await cur.fetchone() + return row[0] if row else 0 + + +async def upsert_dataset(ds: Dataset) -> None: + """Insert or replace a dataset record.""" + db = await get_db() + + task = getattr(ds.task, "value", ds.task) + fmt = getattr(ds.format, "value", ds.format) + src = getattr(ds.source, "value", ds.source) + status = getattr(ds.status, "value", ds.status) + await db.execute( + """INSERT OR REPLACE INTO datasets + (id, name, description, task, format, source, status, + images, classes, class_names, size_bytes, size_label, + local_path, import_progress, tags, versions, active_version, + starred, roboflow_id, created_at, updated_at) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,datetime('now'))""", + ( + ds.id, ds.name, ds.description, task, fmt, + src, status, + ds.images, ds.classes, + json.dumps(ds.class_names), ds.size_bytes, ds.size_label, + ds.local_path, ds.import_progress, + json.dumps(ds.tags), + json.dumps([v.model_dump() if hasattr(v, "model_dump") else v for v in ds.versions]), + ds.active_version, + 1 if ds.starred else 0, + ds.roboflow_id, + ds.created_at or datetime.utcnow().isoformat(), + ), + ) + await db.commit() + + +async def update_dataset_status( + dataset_id: str, + status: DatasetStatus, + progress: float | None = None, + local_path: str | None = None, +) -> None: + db = await get_db() + if progress is not None and local_path is not None: + await db.execute( + "UPDATE datasets SET status=?, import_progress=?, local_path=? WHERE id=?", + (status.value, progress, local_path, dataset_id), + ) + elif progress is not None: + await db.execute( + "UPDATE datasets SET status=?, import_progress=? WHERE id=?", + (status.value, progress, dataset_id), + ) + else: + await db.execute( + "UPDATE datasets SET status=? WHERE id=?", + (status.value, dataset_id), + ) + await db.commit() + + +async def update_dataset_stats( + dataset_id: str, + images: int, + classes: int, + class_names: list[str], + size_bytes: int, + stats: dict | None = None +) -> None: + db = await get_db() + + # Calculate health score if stats provided + health_score = 0.0 + if stats: + health_score = stats.get("health_score", 0.0) + + await db.execute( + """UPDATE datasets + SET images=?, classes=?, class_names=?, size_bytes=?, + size_label=?, stats=?, health_score=? + WHERE id=?""", + ( + images, classes, json.dumps(class_names), + size_bytes, _fmt_bytes(size_bytes), + json.dumps(stats) if stats else None, + health_score, + dataset_id, + ), + ) + await db.commit() + + +async def delete_dataset(dataset_id: str) -> bool: + db = await get_db() + async with db.execute("SELECT 1 FROM datasets WHERE id=?", (dataset_id,)) as cur: + exists = await cur.fetchone() + if not exists: + return False + await db.execute("DELETE FROM datasets WHERE id=?", (dataset_id,)) + await db.commit() + return True + + +async def toggle_starred(dataset_id: str) -> bool: + """Toggle starred flag, return new value.""" + db = await get_db() + async with db.execute("SELECT starred FROM datasets WHERE id=?", (dataset_id,)) as cur: + row = await cur.fetchone() + if not row: + return False + new_val = 0 if row["starred"] else 1 + await db.execute("UPDATE datasets SET starred=? WHERE id=?", (new_val, dataset_id)) + await db.commit() + return bool(new_val) + + +# ── Bulk dataset upsert from Roboflow ──────────────────────────────────────── + +async def bulk_upsert_datasets(datasets: list[Dataset]) -> int: + """Insert/update many datasets in a single transaction.""" + if not datasets: + return 0 + db = await get_db() + now = datetime.utcnow().isoformat() + rows = [ + ( + ds.id, ds.name, ds.description, ds.task.value, ds.format.value, + ds.source.value, ds.status.value, + ds.images, ds.classes, + json.dumps(ds.class_names), ds.size_bytes, ds.size_label, + ds.local_path, ds.import_progress, + json.dumps(ds.tags), json.dumps([]), + ds.active_version, 0, ds.roboflow_id, + ds.created_at or now, + ) + for ds in datasets + ] + await db.executemany( + """INSERT OR IGNORE INTO datasets + (id, name, description, task, format, source, status, + images, classes, class_names, size_bytes, size_label, + local_path, import_progress, tags, versions, active_version, + starred, roboflow_id, created_at) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + rows, + ) + await db.commit() + return len(datasets) + + +# ── Dataset Jobs ────────────────────────────────────────────────────────────── + +async def create_job( + dataset_id: str, + dataset_name: str, + job_type: str, +) -> DatasetJob: + db = await get_db() + job_id = f"djob-{uuid.uuid4().hex[:12]}" + now = datetime.utcnow().isoformat() + await db.execute( + """INSERT INTO dataset_jobs + (id, type, status, dataset_id, dataset_name, progress, message, created_at) + VALUES (?, ?, 'queued', ?, ?, 0.0, '', ?)""", + (job_id, job_type, dataset_id, dataset_name, now), + ) + await db.commit() + return DatasetJob( + id=job_id, type=job_type, status="queued", + dataset_id=dataset_id, dataset_name=dataset_name, + created_at=now, + ) + + +async def update_job( + job_id: str, + status: str | None = None, + progress: float | None = None, + message: str | None = None, + error: str | None = None, + started_at: str | None = None, + ended_at: str | None = None, +) -> None: + db = await get_db() + parts = [] + params: list[Any] = [] + if status is not None: + parts.append("status=?"); params.append(status) + if progress is not None: + parts.append("progress=?"); params.append(progress) + if message is not None: + parts.append("message=?"); params.append(message) + if error is not None: + parts.append("error=?"); params.append(error) + if started_at is not None: + parts.append("started_at=?"); params.append(started_at) + if ended_at is not None: + parts.append("ended_at=?"); params.append(ended_at) + if not parts: + return + params.append(job_id) + await db.execute(f"UPDATE dataset_jobs SET {', '.join(parts)} WHERE id=?", params) + await db.commit() + + +async def get_job(job_id: str) -> DatasetJob | None: + db = await get_db() + async with db.execute("SELECT * FROM dataset_jobs WHERE id=?", (job_id,)) as cur: + row = await cur.fetchone() + return row_to_job(row) if row else None + + +async def get_all_jobs(limit: int = 100) -> list[DatasetJob]: + db = await get_db() + async with db.execute( + "SELECT * FROM dataset_jobs ORDER BY created_at DESC LIMIT ?", (limit,) + ) as cur: + rows = await cur.fetchall() + return [row_to_job(r) for r in rows] + + +# ── Image Index ─────────────────────────────────────────────────────────────── + +async def index_images( + dataset_id: str, + records: list[dict], # [{id, filename, rel_path, width, height, split, ann_count}] +) -> int: + db = await get_db() + await db.executemany( + """INSERT OR IGNORE INTO dataset_images + (id, dataset_id, filename, rel_path, width, height, split, ann_count) + VALUES (:id, :dataset_id, :filename, :rel_path, :width, :height, :split, :ann_count)""", + [{"dataset_id": dataset_id, **r} for r in records], + ) + await db.commit() + return len(records) + + +async def get_image_page( + dataset_id: str, + page: int = 0, + page_size: int = 20, + split: str | None = None, + class_label: str | None = None, +) -> tuple[int, list[dict]]: + db = await get_db() + + clauses = ["dataset_id=?"] + params: list[Any] = [dataset_id] + + if split: + clauses.append("split=?") + params.append(split) + + if class_label: + # Join with annotations table to filter by class + where = f"WHERE {' AND '.join(clauses)} AND id IN (SELECT image_id FROM dataset_annotations WHERE label=?)" + count_params = params + [class_label] + else: + where = f"WHERE {' AND '.join(clauses)}" + count_params = params + + async with db.execute(f"SELECT COUNT(*) FROM dataset_images {where}", count_params) as cur: + total = (await cur.fetchone())[0] + + params_final = count_params + [page_size, page * page_size] + async with db.execute( + f"SELECT * FROM dataset_images {where} ORDER BY filename LIMIT ? OFFSET ?", params_final + ) as cur: + rows = await cur.fetchall() + return total, [dict(r) for r in rows] + + +async def get_annotations_for_image(image_id: str) -> list[dict]: + db = await get_db() + async with db.execute( + "SELECT * FROM dataset_annotations WHERE image_id=?", (image_id,) + ) as cur: + rows = await cur.fetchall() + return [dict(r) for r in rows] + + +async def bulk_insert_annotations(records: list[dict]) -> int: + if not records: + return 0 + db = await get_db() + await db.executemany( + """INSERT OR IGNORE INTO dataset_annotations + (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h, + normalised, area, confidence, ann_type) + VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h, + :normalised,:area,:confidence,:ann_type)""", + records, + ) + await db.commit() + return len(records) + + + # ── Universal Dataset Items ────────────────────────────────────────────── + +async def get_universal_items( + self, + dataset_id: str, + page: int = 0, + page_size: int = 20, + split: str | None = None, + class_label: str | None = None, + ) -> tuple[int, list[dict]]: + """Fetch polymorphic dataset items (images, text rows, etc.) and their annotations.""" + db = await get_db() + + # 1. Get total and base item records + total, items = await self.get_image_page(dataset_id, page, page_size, split, class_label) + + # 2. Convert to universal format + # This is a bridge until we fully move to the universal schema + return total, items + +async def bulk_insert_universal_annotations(self, records: list[dict]) -> int: + """Insert universal annotations into the extended schema.""" + if not records: + return 0 + db = await get_db() + await db.executemany( + """INSERT OR IGNORE INTO dataset_annotations + (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h, + normalised, area, confidence, ann_type, segmentation, keypoints, metadata) + VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h, + :normalised,:area,:confidence,:ann_type,:segmentation,:keypoints,:metadata)""", + records, + ) + await db.commit() + return len(records) + +async def update_dataset_task(dataset_id: str, task: str) -> None: + db = await get_db() + await db.execute("UPDATE datasets SET task=? WHERE id=?", (task, dataset_id)) + await db.commit() + + +async def cleanup_stale_jobs() -> None: + """Mark running/queued jobs as failed on startup.""" + db = await get_db() + await db.execute( + "UPDATE dataset_jobs SET status='failed', error='System restart' WHERE status IN ('running', 'queued')" + ) + await db.commit() + + +def _fmt_bytes(n: int) -> str: + for unit in ("B", "KB", "MB", "GB", "TB"): + if n < 1024: + return f"{n:.1f} {unit}" + n /= 1024 + return f"{n:.1f} PB" diff --git a/datasets/viewer_service.py b/datasets/viewer_service.py new file mode 100644 index 0000000000000000000000000000000000000000..7f29f2239bf772b1587d4a65d32c4a30b4b2515a --- /dev/null +++ b/datasets/viewer_service.py @@ -0,0 +1,320 @@ +""" +datasets/viewer_service.py — Dataset Viewer Service. + +Provides paginated image + annotation serving for the Dataset Viewer UI. +All paths are resolved relative to the dataset's local_path for security. +""" +from __future__ import annotations + +from pathlib import Path + +from datasets import registry as ds_reg +from models.dataset import ( + Annotation, AnnotationType, BoundingBox, Dataset, + ImageRecord, ViewerPage, DatasetFormat +) +from datasets.annotation_parser import YOLOParser, COCOParser, VOCParser, CSVParser +from observability.logger import get_logger + +log = get_logger("viewer_service") + + +from .format_adapters import NLPAdapter, TabularAdapter +from models.dataset import UniversalViewerPage, UniversalDatasetItem, UniversalAnnotation, DatasetContentType, DatasetTask + +async def get_universal_viewer_page( + dataset_id: str, + page: int = 0, + page_size: int = 20, + split: str | None = None, + class_label: str | None = None, +) -> UniversalViewerPage: + """Polymorphic viewer endpoint that adapts based on dataset task.""" + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + raise ValueError("Dataset not found") + + ds_root = Path(ds.local_path) if ds.local_path else None + + # 1. Vision Tasks (Detection, Seg, Pose) -> Use existing image-centric logic + if ds.task in (DatasetTask.detection, DatasetTask.segmentation, DatasetTask.keypoints): + # We wrap the existing get_viewer_page and transform to UniversalDatasetItem + old_page = await get_viewer_page(dataset_id, page, page_size, split, class_label) + + items = [] + for img in old_page.images: + items.append(UniversalDatasetItem( + id=img.image_id, + content_type=DatasetContentType.image, + filename=img.filename, + metadata={"width": img.width, "height": img.height, "split": img.split}, + annotations=[ + UniversalAnnotation( + label=ann.label, + type=ann.type.value if hasattr(ann.type, 'value') else str(ann.type), + bbox=[ann.bbox.x, ann.bbox.y, ann.bbox.width, ann.bbox.height] if ann.bbox else None, + segmentation=ann.segmentation, + keypoints=ann.keypoints, + confidence=ann.confidence, + metadata=ann.metadata + ) for ann in img.annotations + ] + )) + + return UniversalViewerPage( + dataset_id=dataset_id, + page=page, + page_size=page_size, + total=old_page.total, + total_pages=old_page.total_pages, + items=items + ) + + # 2. NLP Tasks (CSV, JSONL) + elif ds.task == DatasetTask.nlp and ds_root: + adapter = NLPAdapter() + total, items = await adapter.get_items(ds_root, page, page_size) + total_pages = max(1, (total + page_size - 1) // page_size) + return UniversalViewerPage( + dataset_id=dataset_id, + page=page, + page_size=page_size, + total=total, + total_pages=total_pages, + items=items + ) + + # 3. Tabular Tasks (CSV, Parquet) + elif ds.task == DatasetTask.tabular and ds_root: + adapter = TabularAdapter() + total, items = await adapter.get_items(ds_root, page, page_size) + total_pages = max(1, (total + page_size - 1) // page_size) + return UniversalViewerPage( + dataset_id=dataset_id, + page=page, + page_size=page_size, + total=total, + total_pages=total_pages, + items=items + ) + + # Fallback / Empty + return UniversalViewerPage( + dataset_id=dataset_id, + page=page, + page_size=page_size, + total=0, + total_pages=0, + items=[] + ) + +async def get_viewer_page( + dataset_id: str, + page: int = 0, + page_size: int = 20, + split: str | None = None, + class_label: str | None = None, +) -> ViewerPage: + """ + Return a paginated viewer page for the dataset. + Images come from the index; annotations are loaded per-image. + """ + if page_size > 100: + page_size = 100 # cap to prevent huge payloads + + total, image_rows = await ds_reg.get_image_page(dataset_id, page, page_size, split, class_label) + ds = await ds_reg.get_dataset(dataset_id) + + # Check if we have an active project and if the dataset exists there + from projects.service import get_active_project_path + project_path = await get_active_project_path() + + # Dynamically load annotations from database first, fallback to filesystem if needed + image_ids = [row["id"] for row in image_rows] + dynamic_anns: dict[str, list[Annotation]] = {img_id: [] for img_id in image_ids} + + # 1. Try loading from DB index (Authoritative for analytics) + try: + from database.connection import get_db + db = await get_db() + # Fetch all annotations for these images in one go + placeholders = ",".join(["?"] * len(image_ids)) + async with db.execute( + f"SELECT * FROM dataset_annotations WHERE image_id IN ({placeholders})", + image_ids + ) as cur: + rows = await cur.fetchall() + for r in rows: + dynamic_anns[r["image_id"]].append(_row_to_annotation(dict(r))) + except Exception as e: + log.warning("db_annotation_read_failed", error=str(e), dataset_id=dataset_id) + + # 2. Fallback to filesystem if no annotations found in DB and we have a path + # This maintains compatibility with old datasets or specific live-read needs + if all(not anns for anns in dynamic_anns.values()) and ds and ds.local_path: + ds_root = Path(ds.local_path) + # Use ds.local_path directly as it is now authoritative project-local path + # Fallback to global removed per user request + + fmt = ds.format.value if hasattr(ds.format, 'value') else str(ds.format) + + try: + if fmt == DatasetFormat.yolo.value or fmt == "yolo": + class_map = YOLOParser.load_class_map(ds_root) + for row in image_rows: + rel_path = Path(row["rel_path"]) + # For YOLO, the label file is usually in a parallel 'labels' folder + # or in the same folder as the image. + # Roboflow structure: train/images/img.jpg -> train/labels/img.txt + parts = list(rel_path.parts) + + label_rel = None + if "images" in parts: + idx = parts.index("images") + parts_labels = list(parts) + parts_labels[idx] = "labels" + label_rel = Path(*parts_labels).with_suffix(".txt") + + # Fallback: same folder + label_same_folder = rel_path.with_suffix(".txt") + + for cand_rel in [label_rel, label_same_folder]: + if not cand_rel: continue + label_file = ds_root / cand_rel + if label_file.exists(): + anns = YOLOParser.parse_file(label_file, row["id"], ds.id, class_map) + dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in anns] + break + + elif fmt == DatasetFormat.coco.value or fmt == "coco": + jsons = COCOParser.find_annotation_files(ds_root) + img_map = {row["filename"]: row["id"] for row in image_rows} + for jf in jsons: + _, parsed = COCOParser.parse_file(jf, ds.id) + for p_rel, _, _, anns in parsed: + fname = Path(p_rel).name + if fname in img_map: + img_id = img_map[fname] + dynamic_anns[img_id].extend([_row_to_annotation(a) for a in anns]) + + elif fmt == DatasetFormat.voc.value or fmt == "voc": + for row in image_rows: + img_abs = ds_root / row["rel_path"] + xml_candidates = [img_abs.with_suffix(".xml")] + parts = list(Path(row["rel_path"]).parts) + if "JPEGImages" in parts: + idx = parts.index("JPEGImages") + parts[idx] = "Annotations" + xml_candidates.append(ds_root.joinpath(*parts).with_suffix(".xml")) + + for cand in xml_candidates: + if cand.exists(): + _, _, _, anns = VOCParser.parse_file(cand, row["id"], ds.id) + dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in anns] + break + + elif fmt == "csv": + for row in image_rows: + csv_path = ds_root / row["rel_path"] + if csv_path.exists(): + # For CSV/NLP, we might need a more specific way to find the exact row, + # but for now we reload the file or use a cached version. + # Since get_viewer_page is paginated, we'll parse the file. + anns = CSVParser.parse_file(csv_path, ds.id) + # Find the annotation matching this "image_id" (which is the text entry id) + matching_anns = [a for a in anns if a["image_id"] == row["id"]] + dynamic_anns[row["id"]] = [_row_to_annotation(a) for a in matching_anns] + + except Exception as e: + log.error("dynamic_annotation_read_failed", error=str(e), dataset_id=dataset_id) + + images: list[ImageRecord] = [] + for row in image_rows: + annotations = dynamic_anns.get(row["id"], []) + images.append(ImageRecord( + image_id = row["id"], + filename = row["filename"], + width = row["width"], + height = row["height"], + path = row["rel_path"], + annotations = annotations, + split = row["split"], + )) + + total_pages = max(1, (total + page_size - 1) // page_size) + + return ViewerPage( + dataset_id = dataset_id, + page = page, + page_size = page_size, + total = total, + total_pages = total_pages, + images = images, + ) + + +def _row_to_annotation(row: dict) -> Annotation: + bbox = None + if row.get("bbox_x") is not None: + bbox = BoundingBox( + x = row["bbox_x"], + y = row["bbox_y"], + width = row["bbox_w"], + height = row["bbox_h"], + normalised = bool(row.get("normalised", 1)), + ) + + segmentation = None + if row.get("segmentation"): + try: + import json + segmentation = json.loads(row["segmentation"]) + except: + pass + + return Annotation( + label = row["label"], + bbox = bbox, + segmentation = segmentation, + confidence = row.get("confidence"), + area = row.get("area"), + type = AnnotationType(row.get("ann_type", "detection")), + ) + + +async def resolve_image_path(dataset_id: str, image_id: str) -> Path | None: + """ + Resolve the absolute filesystem path for an image. + Prioritizes the active project's dataset folder, falling back to the global cache. + Returns None if dataset not imported or image not found. + """ + ds = await ds_reg.get_dataset(dataset_id) + if ds is None or not ds.local_path: + return None + + base_root = Path(ds.local_path) + # ds.local_path is now authoritative project-local path + # Fallback removed per user request + + from database.connection import get_db + db = await get_db() + async with db.execute( + "SELECT rel_path FROM dataset_images WHERE id=? AND dataset_id=?", + (image_id, dataset_id), + ) as cur: + row = await cur.fetchone() + if not row: + return None + + abs_path = base_root / row["rel_path"] + if not abs_path.exists(): + return None + + # Security: ensure path is under base_root + try: + abs_path.resolve().relative_to(base_root.resolve()) + except ValueError: + log.warning("path_traversal_attempt", dataset_id=dataset_id, image_id=image_id) + return None + + return abs_path diff --git a/download/__init__.py b/download/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/download/__pycache__/__init__.cpython-310.pyc b/download/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a282a59643aafb3112beb51caec1cb3e635b694b Binary files /dev/null and b/download/__pycache__/__init__.cpython-310.pyc differ diff --git a/download/__pycache__/manager.cpython-310.pyc b/download/__pycache__/manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f01cd1376113a32b4920523bdf08796c359683d2 Binary files /dev/null and b/download/__pycache__/manager.cpython-310.pyc differ diff --git a/download/manager.py b/download/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0605d4b6ac6e2c73fa2d7ee4333ffa62983e364d --- /dev/null +++ b/download/manager.py @@ -0,0 +1,366 @@ +""" +download/manager.py — Async download manager. +Handles queueing, concurrency limiting, retry, resume, and progress tracking. +All state is persisted in the jobs table for crash recovery. +""" +from __future__ import annotations + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import aiofiles +import httpx +from tenacity import retry, stop_after_attempt, wait_exponential + +from config import settings +from database.connection import get_db +from models.job import Job, row_to_job +from observability.logger import audit, get_logger +from registry.registry import get_model, update_model_status + +log = get_logger("download_manager") + +# ── Semaphore caps concurrent downloads ─────────────────────────────────────── +_download_sem: asyncio.Semaphore | None = None + + +def _get_sem() -> asyncio.Semaphore: + global _download_sem + if _download_sem is None: + _download_sem = asyncio.Semaphore(settings.max_concurrent_downloads) + return _download_sem + + +# ── Job CRUD ────────────────────────────────────────────────────────────────── + +async def _create_job( + job_type: str, + model_id: str, + model_name: str, + meta: dict | None = None, +) -> str: + job_id = str(uuid.uuid4()) + db = await get_db() + now = datetime.now(timezone.utc).isoformat() + await db.execute( + """INSERT INTO jobs (id, type, status, model_id, model_name, meta, created_at, updated_at) + VALUES (?,?,?,?,?,?,?,?)""", + (job_id, job_type, "queued", model_id, model_name, + json.dumps(meta or {}), now, now), + ) + await db.commit() + log.info("job_created", job_id=job_id, type=job_type, model_id=model_id) + await audit("job_created", model_id=model_id, job_id=job_id, + payload={"type": job_type, "model_name": model_name}) + return job_id + + +def _is_shard_file(filename: str) -> bool: + """Return True if the file is part of a sharded model (e.g. model-00001-of-00003.safetensors).""" + import re + return bool(re.search(r"-\d{5}-of-\d{5}\.", filename)) + + +async def _get_active_version(model_id: str) -> str: + """Return the active version string for a model, defaulting to 'v1'.""" + model = await get_model(model_id) + if model and model.active_version: + return model.active_version + return "v1" + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=6), + reraise=True, +) +async def _resolve_hf_download_url(repo_id: str) -> str: + """Resolve a reliable download URL for a HF repo. + + Prefer safetensors over pytorch_model.bin; fall back to onnx if needed. + """ + async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: + resp = await client.get(f"{settings.hf_api_base}/models/{repo_id}") + resp.raise_for_status() + data = resp.json() + + siblings = data.get("siblings") or [] + filenames: list[str] = [] + for s in siblings: + fn = s.get("rfilename") or s.get("filename") + if fn: + filenames.append(fn) + + preferred_exact = [ + "model.safetensors", + "pytorch_model.bin", + "model.onnx", + ] + for fn in preferred_exact: + if fn in filenames: + return f"https://huggingface.co/{repo_id}/resolve/main/{fn}" + + preferred_suffix = [".safetensors", ".bin", ".onnx", ".pt", ".pth"] + for suffix in preferred_suffix: + for fn in filenames: + if fn.endswith(suffix) and not _is_shard_file(fn): + return f"https://huggingface.co/{repo_id}/resolve/main/{fn}" + + # Accept sharded files as a fallback (first shard of safetensors) + for fn in filenames: + if _is_shard_file(fn): + return f"https://huggingface.co/{repo_id}/resolve/main/{fn}" + + # Last resort: try the index file for sharded models + if "model.safetensors.index.json" in filenames: + # For sharded models without a single file, use the first shard + for fn in filenames: + if fn.startswith("model-") and fn.endswith(".safetensors"): + return f"https://huggingface.co/{repo_id}/resolve/main/{fn}" + + return f"https://huggingface.co/{repo_id}/resolve/main/pytorch_model.bin" + + +async def _update_job( + job_id: str, + status: str | None = None, + progress: float | None = None, + error: str | None = None, + started_at: str | None = None, + ended_at: str | None = None, +) -> None: + db = await get_db() + now = datetime.now(timezone.utc).isoformat() + parts: list[str] = ["updated_at = ?"] + vals: list[Any] = [now] + if status is not None: parts.append("status = ?"); vals.append(status) + if progress is not None: parts.append("progress = ?"); vals.append(progress) + if error is not None: parts.append("error = ?"); vals.append(error) + if started_at: parts.append("started_at = ?"); vals.append(started_at) + if ended_at: parts.append("ended_at = ?"); vals.append(ended_at) + vals.append(job_id) + await db.execute(f"UPDATE jobs SET {', '.join(parts)} WHERE id = ?", vals) + await db.commit() + + +# ── Download worker ─────────────────────────────────────────────────────────── + +async def _execute_download( + job_id: str, + model_id: str, + model_name: str, + download_url: str, + dest_path: Path, +) -> None: + now = datetime.now(timezone.utc).isoformat() + await _update_job(job_id, status="running", started_at=now) + + dest_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = dest_path.with_suffix(".tmp") + + # Determine resume offset + resume_offset = tmp_path.stat().st_size if tmp_path.exists() else 0 + + headers: dict[str, str] = {} + if resume_offset: + headers["Range"] = f"bytes={resume_offset}-" + log.info("download_resume", job_id=job_id, offset=resume_offset) + + try: + async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client: + async with client.stream("GET", download_url, headers=headers) as resp: + resp.raise_for_status() + total = int(resp.headers.get("content-length", 0)) + resume_offset + downloaded = resume_offset + + async with aiofiles.open(tmp_path, "ab" if resume_offset else "wb") as fh: + async for chunk in resp.aiter_bytes(chunk_size=settings.download_chunk_size): + await fh.write(chunk) + downloaded += len(chunk) + progress = downloaded / total if total else 0 + await _update_job(job_id, progress=min(progress, 0.99)) + + # Rename tmp → final + tmp_path.rename(dest_path) + now_end = datetime.now(timezone.utc).isoformat() + await _update_job(job_id, status="completed", progress=1.0, ended_at=now_end) + await update_model_status( + model_id, + status="cached", + downloaded=True, + local_path=str(dest_path), + ) + # Copy into the active project's workspace models/ folder + from projects.service import link_model_to_active_project + await link_model_to_active_project(model_id, str(dest_path)) + log.info("download_complete", job_id=job_id, model_id=model_id, path=str(dest_path)) + await audit("download_complete", model_id=model_id, job_id=job_id, + payload={"path": str(dest_path)}) + + except Exception as exc: + now_end = datetime.now(timezone.utc).isoformat() + await _update_job(job_id, status="failed", error=str(exc), ended_at=now_end) + await update_model_status(model_id, status="error") + log.error("download_failed", job_id=job_id, error=str(exc)) + await audit("download_failed", model_id=model_id, job_id=job_id, + payload={"error": str(exc)}, level="error") + raise + + +# ── Public API ──────────────────────────────────────────────────────────────── + +async def enqueue_download( + model_id: str, + model_name: str, + download_url: str | None = None, + version: str | None = None, +) -> str: + """Create a download job and dispatch resolution+download in the background. + + This function should not perform network calls; otherwise /download can return 500 + on transient provider errors. + """ + job_id = await _create_job("download", model_id, model_name) + + asyncio.create_task( + _rate_limited_download_resolving(job_id, model_id, model_name, download_url, version) + ) + return job_id + + +async def _rate_limited_download_resolving( + job_id: str, + model_id: str, + model_name: str, + download_url: str | None, + version: str | None = None, +) -> None: + async with _get_sem(): + try: + resolved_url = await _resolve_download_url(model_id, download_url, version) + # Version folder: use explicit version label, else active_version from DB + folder = version or await _get_active_version(model_id) + ext = Path(resolved_url.split("?")[0]).suffix or ".bin" + dest_path = settings.models_dir / model_id / folder / f"model{ext}" + await _execute_download(job_id, model_id, model_name, resolved_url, dest_path) + except Exception as exc: + now_end = datetime.now(timezone.utc).isoformat() + await _update_job(job_id, status="failed", error=str(exc), ended_at=now_end) + await update_model_status(model_id, status="error") + log.error("download_failed", job_id=job_id, error=str(exc)) + await audit( + "download_failed", + model_id=model_id, + job_id=job_id, + payload={"error": str(exc)}, + level="error", + ) + + +async def _resolve_download_url( + model_id: str, + download_url: str | None, + version: str | None = None, +) -> str: + """Resolve the final download URL for a model. + + If `version` is provided and looks like a filename (e.g. 'yolov8n_pt'), + it was generated by hf_adapter from a sibling rfilename. Restore the + original filename (replace trailing _ext with .ext) and build a direct URL. + """ + repo_id: str | None = None + + if download_url and "huggingface.co" in download_url: + repo_id = download_url.replace("https://huggingface.co/", "").rstrip("/") + elif not download_url: + model = await get_model(model_id) + if model and model.download_url: + url = model.download_url + if "huggingface.co" in url: + repo_id = url.replace("https://huggingface.co/", "").rstrip("/") + else: + return url + else: + repo_id = model_id.replace("_", "/", 1) + else: + return download_url + + # If the caller specified a version that is a converted rfilename + # (dots replaced with underscores by hf_adapter), reconstruct the filename. + if version and repo_id: + filename = _version_to_filename(version) + if filename: + return f"https://huggingface.co/{repo_id}/resolve/main/{filename}" + + return await _resolve_hf_download_url(repo_id) + + +def _version_to_filename(version: str) -> str | None: + """Convert an hf_adapter version string back to a real filename. + + hf_adapter stores version as rfilename.replace('.', '_'), e.g.: + 'yolov8n_pt' → 'yolov8n.pt' + 'model_safetensors' → 'model.safetensors' + Only converts if the result ends with a known weight extension. + """ + weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx") + # Try replacing the last underscore with a dot + idx = version.rfind("_") + if idx == -1: + return None + candidate = version[:idx] + "." + version[idx + 1:] + if any(candidate.endswith(ext) for ext in weight_exts): + return candidate + return None + + +async def _rate_limited_download( + job_id: str, + model_id: str, + model_name: str, + download_url: str, + dest_path: Path, +) -> None: + async with _get_sem(): + try: + await _execute_download(job_id, model_id, model_name, download_url, dest_path) + except Exception: + pass # Already logged & stored in DB + + +async def get_job(job_id: str) -> Job | None: + db = await get_db() + async with db.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)) as cur: + row = await cur.fetchone() + return row_to_job(row) if row else None + + +async def list_jobs( + status: str | None = None, + limit: int = 50, +) -> list[Job]: + db = await get_db() + if status: + sql = "SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC LIMIT ?" + params: tuple = (status, limit) + else: + sql = "SELECT * FROM jobs ORDER BY created_at DESC LIMIT ?" + params = (limit,) + async with db.execute(sql, params) as cur: + rows = await cur.fetchall() + return [row_to_job(r) for r in rows] + + +async def cancel_job(job_id: str) -> bool: + """Cancel a queued or running job (best-effort).""" + job = await get_job(job_id) + if not job or job.status not in ("queued", "running"): + return False + now = datetime.now(timezone.utc).isoformat() + await _update_job(job_id, status="cancelled", ended_at=now) + log.info("job_cancelled", job_id=job_id) + return True diff --git a/inference/__init__.py b/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a111050304c40dcb721f14fd1515588ba0cc0376 --- /dev/null +++ b/inference/__init__.py @@ -0,0 +1 @@ +# inference package diff --git a/inference/__pycache__/__init__.cpython-310.pyc b/inference/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85505a07772b135ba47d22c8cf9200943fd2f983 Binary files /dev/null and b/inference/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/__pycache__/engine.cpython-310.pyc b/inference/__pycache__/engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b9abe12755a60406c2f1c9de2034e05a24ef885 Binary files /dev/null and b/inference/__pycache__/engine.cpython-310.pyc differ diff --git a/inference/__pycache__/session.cpython-310.pyc b/inference/__pycache__/session.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f1494ed9e28264fd4d522f56a20c20599fbce7 Binary files /dev/null and b/inference/__pycache__/session.cpython-310.pyc differ diff --git a/inference/engine.py b/inference/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..61ff7c379a8ea7591abec730078418878239862a --- /dev/null +++ b/inference/engine.py @@ -0,0 +1,447 @@ +""" +inference/engine.py — MLForge Inference Engine. + +Dispatcher that routes each InferenceRequest to the correct adapter pipeline: + YOLO → YOLOInferencePipeline + TRANSFORMERS → TransformersPipeline + ONNX → ONNXPipeline + CUSTOM → CustomPipeline + +Each pipeline implements preprocess → inference_step → postprocess. +Simulation paths are used when real model weights are not loaded; +every # <<< REPLACE IN PRODUCTION >>> comment marks the exact swap point. + +Architecture follows the spec in infra_arch.md §4 (Adapter Protocol). +""" +from __future__ import annotations + +import asyncio +import base64 +import io +import random +import time +import uuid +from typing import Any + +from models.inference import ( + AdapterType, + Detection, + InferenceRequest, + InferenceResult, + PipelineStage, +) +from models.model import Model +from observability.logger import get_logger + +log = get_logger("inference.engine") + +# ── Model cache: model_id → loaded model object ────────────────────────────── +_MODEL_CACHE: dict[str, Any] = {} + + +def _now_ms() -> float: + return time.perf_counter() * 1000 + + +# ── YOLO Pipeline ───────────────────────────────────────────────────────────── + +class YOLOPipeline: + """ + YOLO inference pipeline. + Preprocess: letterbox resize → BGR→RGB → 1/255 normalise. + Postprocess: NMS → [{x1,y1,x2,y2,confidence,class_id,class_name}]. + """ + + async def run( + self, req: InferenceRequest, model: Model + ) -> tuple[list[PipelineStage], dict[str, Any]]: + cfg = req.yolo_config + conf = cfg.confidence if cfg else 0.25 + iou = cfg.iou_threshold if cfg else 0.45 + + stages: list[PipelineStage] = [] + + # — Stage 1: Preprocess ———————————————————————————— + t0 = _now_ms() + await asyncio.sleep(0) # yield control + if req.image_base64: + try: + raw_bytes = base64.b64decode(req.image_base64) + # <<< REPLACE IN PRODUCTION >>> + # img = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR) + # tensor = letterbox(img, 640) / 255.0 + _ = len(raw_bytes) # validate decode worked + except Exception as e: + return [PipelineStage(name="Preprocess", status="error", detail=str(e))], {} + pre_ms = _now_ms() - t0 + random.uniform(0.8, 2.5) + stages.append(PipelineStage(name="Preprocess", status="done", + latency_ms=round(pre_ms, 2), detail="Letterbox 640×640")) + + # — Stage 2: Engine Load ——————————————————————————— + t1 = _now_ms() + loaded = model.id in _MODEL_CACHE + load_ms = 0.0 if loaded else random.uniform(80, 220) + await asyncio.sleep(load_ms / 1000.0) + if not loaded: + _MODEL_CACHE[model.id] = object() # <<< REPLACE: load actual weights + stages.append(PipelineStage(name="Engine Load", status="done", + latency_ms=round(_now_ms() - t1, 2), + detail="Cache hit" if loaded else "Weights loaded")) + + # — Stage 3: Inference ———————————————————————————— + t2 = _now_ms() + size_gb = max(model.size, 1) / (1024 ** 3) + base_lat = 2.5 + size_gb * 1.5 + infer_ms = base_lat + random.gauss(0, base_lat * 0.07) + await asyncio.sleep(infer_ms / 1000.0) + # <<< REPLACE IN PRODUCTION >>> + # results = model_obj(tensor, conf=conf, iou=iou) + stages.append(PipelineStage(name="Inference", status="done", + latency_ms=round(infer_ms, 2), + detail=f"conf={conf} iou={iou}")) + + # — Stage 4: Post-process (NMS) —————————————————— + t3 = _now_ms() + detections = self._simulate_detections(conf, cfg.class_filter if cfg else []) + post_ms = random.uniform(0.3, 1.2) + await asyncio.sleep(post_ms / 1000.0) + stages.append(PipelineStage(name="NMS Post-process", status="done", + latency_ms=round(post_ms, 2), + detail=f"{len(detections)} detections")) + + output: dict[str, Any] = { + "detections": [d.model_dump() for d in detections], + "pre_ms": round(pre_ms, 2), + "infer_ms": round(infer_ms, 2), + "post_ms": round(post_ms, 2), + } + return stages, output + + @staticmethod + def _simulate_detections(conf_thresh: float, class_filter: list[str]) -> list[Detection]: + """Simulate bounding-box detections. <<< REPLACE with real NMS output.""" + CLASSES = ["person", "car", "truck", "bicycle", "dog", "cat", + "traffic light", "stop sign", "bench", "bird"] + n = random.randint(0, 8) + dets: list[Detection] = [] + for _ in range(n): + c = random.uniform(conf_thresh, 1.0) + cid = random.randint(0, len(CLASSES) - 1) + cname = CLASSES[cid] + if class_filter and cname not in class_filter: + continue + x1 = random.uniform(0, 0.7) + y1 = random.uniform(0, 0.7) + dets.append(Detection( + x1=round(x1 * 640, 1), y1=round(y1 * 640, 1), + x2=round((x1 + random.uniform(0.05, 0.3)) * 640, 1), + y2=round((y1 + random.uniform(0.05, 0.3)) * 640, 1), + confidence=round(c, 4), + class_id=cid, class_name=cname, + )) + return dets + + +# ── Transformers Pipeline ───────────────────────────────────────────────────── + +class TransformersPipeline: + """ + HuggingFace Transformers pipeline. + Preprocess: AutoTokenizer.encode. + Inference: model.generate with KV-cache. + Postprocess: decode + strip special tokens. + """ + + async def run( + self, req: InferenceRequest, model: Model + ) -> tuple[list[PipelineStage], dict[str, Any]]: + cfg = req.transformers_config + stages: list[PipelineStage] = [] + + # — Tokenize —————————————————————————————————————— + t0 = _now_ms() + txt = req.text_input or "Hello, world!" + tok_count = len(txt.split()) * 2 # rough BPE estimate + await asyncio.sleep(0.002) + pre_ms = _now_ms() - t0 + random.uniform(1, 4) + stages.append(PipelineStage(name="Tokenise", status="done", + latency_ms=round(pre_ms, 2), + detail=f"{tok_count} tokens")) + + # — Engine Load ————————————————————————————————— + t1 = _now_ms() + loaded = model.id in _MODEL_CACHE + load_ms = 0.0 if loaded else random.uniform(150, 400) + await asyncio.sleep(load_ms / 1000.0) + if not loaded: + _MODEL_CACHE[model.id] = object() + stages.append(PipelineStage(name="Engine Load", status="done", + latency_ms=round(_now_ms() - t1, 2), + detail="Cache hit" if loaded else "Model loaded")) + + # — Generate —————————————————————————————————————— + t2 = _now_ms() + max_tok = cfg.max_new_tokens if cfg else 256 + # Simulate token-by-token generation at ~20 tok/s + infer_ms = (max_tok / 20.0) * 1000 + random.gauss(0, 50) + await asyncio.sleep(min(infer_ms / 1000.0, 0.5)) # cap sim delay + # <<< REPLACE IN PRODUCTION >>> + # outputs = model_obj.generate(input_ids, max_new_tokens=max_tok, + # temperature=cfg.temperature, top_p=cfg.top_p, do_sample=cfg.do_sample) + stages.append(PipelineStage(name="Generate", status="done", + latency_ms=round(infer_ms, 2), + detail=f"~{max_tok} tokens @ fp16")) + + # — Decode ———————————————————————————————————————— + t3 = _now_ms() + text_output = self._simulate_text(txt, max_tok) + post_ms = random.uniform(0.5, 2.0) + stages.append(PipelineStage(name="Decode", status="done", + latency_ms=round(post_ms, 2), + detail="Special tokens stripped")) + + output: dict[str, Any] = { + "text_output": text_output, + "tokens_generated": max_tok, + "pre_ms": round(pre_ms, 2), + "infer_ms": round(infer_ms, 2), + "post_ms": round(post_ms, 2), + } + return stages, output + + @staticmethod + def _simulate_text(prompt: str, n_tokens: int) -> str: + """Placeholder generation. <<< REPLACE with model.generate.""" + lorem = ( + "The model processed your input and generated a response based on the " + "learned distribution of the training corpus. This output is a simulation " + "placeholder — replace with actual model.generate() in production. " + ) + # Repeat to roughly match token count + words = (lorem * (n_tokens // 20 + 1)).split()[:n_tokens] + return " ".join(words) + + +# ── ONNX Pipeline ───────────────────────────────────────────────────────────── + +class ONNXPipeline: + """ + ONNX Runtime pipeline. + Acts as universal wrapper for TF / sklearn / PyTorch exported models. + Dynamically maps input tensor names from model metadata. + """ + + async def run( + self, req: InferenceRequest, model: Model + ) -> tuple[list[PipelineStage], dict[str, Any]]: + cfg = req.onnx_config + stages: list[PipelineStage] = [] + provider = cfg.execution_provider if cfg else "CUDAExecutionProvider" + + # — Preprocess ———————————————————————————————————— + t0 = _now_ms() + pre_ms = random.uniform(1.0, 3.5) + await asyncio.sleep(pre_ms / 1000.0) + stages.append(PipelineStage(name="Preprocess", status="done", + latency_ms=round(pre_ms, 2), + detail="Normalise + reshape tensor")) + + # — ONNX Runtime —————————————————————————————————— + t1 = _now_ms() + loaded = model.id in _MODEL_CACHE + load_ms = 0.0 if loaded else random.uniform(50, 150) + await asyncio.sleep(load_ms / 1000.0) + if not loaded: + _MODEL_CACHE[model.id] = object() + # <<< REPLACE IN PRODUCTION >>> + # import onnxruntime as ort + # sess_opts = ort.SessionOptions() + # _MODEL_CACHE[model.id] = ort.InferenceSession( + # model.local_path, sess_options=sess_opts, + # providers=[provider]) + stages.append(PipelineStage(name="ONNX Runtime", status="done", + latency_ms=round(_now_ms() - t1, 2), + detail=provider.replace("ExecutionProvider", ""))) + + # — Inference ———————————————————————————————————— + t2 = _now_ms() + infer_ms = random.uniform(3.0, 12.0) + await asyncio.sleep(infer_ms / 1000.0) + # <<< REPLACE IN PRODUCTION >>> + # ort_inputs = {sess.get_inputs()[0].name: tensor.numpy()} + # raw = sess.run(None, ort_inputs) + stages.append(PipelineStage(name="Inference", status="done", + latency_ms=round(infer_ms, 2), + detail="session.run()")) + + # — Format Output ———————————————————————————————— + t3 = _now_ms() + post_ms = random.uniform(0.2, 0.8) + raw_out = {"output_0": [round(random.random(), 4) for _ in range(10)]} + stages.append(PipelineStage(name="Format Output", status="done", + latency_ms=round(post_ms, 2), + detail="Tensor → JSON")) + + output: dict[str, Any] = { + "raw_output": raw_out, + "pre_ms": round(pre_ms, 2), + "infer_ms": round(infer_ms, 2), + "post_ms": round(post_ms, 2), + } + return stages, output + + +# ── Custom Python Pipeline ──────────────────────────────────────────────────── + +class CustomPipeline: + """ + Sandboxed custom Python pipeline. + Executes user-supplied pre/postprocess scripts in a restricted namespace. + Only numpy, the input tensor, and the model's raw output are accessible. + """ + + FORBIDDEN = ("import os", "import sys", "subprocess", "open(", "__import__", + "eval(", "exec(", "globals(", "locals(") + + def _validate_script(self, script: str) -> str | None: + for tok in self.FORBIDDEN: + if tok in script: + return f"Forbidden token in script: {tok!r}" + return None + + async def run( + self, req: InferenceRequest, model: Model + ) -> tuple[list[PipelineStage], dict[str, Any]]: + cfg = req.custom_config + stages: list[PipelineStage] = [] + + # — Validate scripts —————————————————————————————— + if cfg: + for label, script in [("preprocess", cfg.preprocess_script), + ("postprocess", cfg.postprocess_script)]: + if script: + err = self._validate_script(script) + if err: + return [PipelineStage(name=label.capitalize(), + status="error", detail=err)], {} + + # — Transform Input ——————————————————————————————— + pre_ms = random.uniform(1.0, 5.0) + await asyncio.sleep(pre_ms / 1000.0) + stages.append(PipelineStage(name="Transform Input", status="done", + latency_ms=round(pre_ms, 2), + detail="Custom preprocess script")) + + # — Run Inference ———————————————————————————————— + infer_ms = random.uniform(5.0, 30.0) + await asyncio.sleep(infer_ms / 1000.0) + # <<< REPLACE IN PRODUCTION >>> + # namespace = {"input": tensor, "model": raw_model} + # exec(compile(cfg.preprocess_script, "
", "exec"), namespace)
+        # tensor = namespace.get("output", tensor)
+        stages.append(PipelineStage(name="Run Inference", status="done",
+                                    latency_ms=round(infer_ms, 2),
+                                    detail="Custom runtime"))
+
+        # — Format Result ————————————————————————————————
+        post_ms = random.uniform(0.5, 3.0)
+        stages.append(PipelineStage(name="Format Result", status="done",
+                                    latency_ms=round(post_ms, 2),
+                                    detail="Custom postprocess script"))
+
+        output: dict[str, Any] = {
+            "raw_output": {"custom_result": round(random.random(), 4)},
+            "pre_ms":   round(pre_ms, 2),
+            "infer_ms": round(infer_ms, 2),
+            "post_ms":  round(post_ms, 2),
+        }
+        return stages, output
+
+
+# ── Master Dispatcher ─────────────────────────────────────────────────────────
+
+_PIPELINE_MAP = {
+    AdapterType.YOLO:         YOLOPipeline,
+    AdapterType.TRANSFORMERS: TransformersPipeline,
+    AdapterType.ONNX:         ONNXPipeline,
+    AdapterType.CUSTOM:       CustomPipeline,
+}
+
+
+class InferenceEngine:
+    """
+    Central inference dispatcher.
+    Resolves the correct pipeline, executes it, and wraps the result
+    into a fully-populated InferenceResult.
+    """
+
+    async def run(self, req: InferenceRequest, model: Model) -> InferenceResult:
+        t_start = _now_ms()
+        pipeline_cls = _PIPELINE_MAP.get(req.adapter_type)
+        if pipeline_cls is None:
+            return InferenceResult(
+                request_id=str(uuid.uuid4()),
+                model_id=req.model_id,
+                adapter_type=req.adapter_type,
+                status="error",
+                error=f"Unknown adapter type: {req.adapter_type}",
+            )
+
+        try:
+            stages, output = await pipeline_cls().run(req, model)
+
+            total_ms = _now_ms() - t_start
+            pre_ms   = output.get("pre_ms", 0.0)
+            infer_ms = output.get("infer_ms", 0.0)
+            post_ms  = output.get("post_ms", 0.0)
+
+            # Quality score: mean confidence of detections (0–5 scale)
+            detections = [Detection(**d) for d in output.get("detections", [])]
+            if detections:
+                mean_conf = sum(d.confidence for d in detections) / len(detections)
+                quality   = round(mean_conf * 5.0, 2)
+            else:
+                quality = round(random.uniform(3.2, 4.8), 2)
+
+            result = InferenceResult(
+                model_id      = req.model_id,
+                adapter_type  = req.adapter_type,
+                preprocess_ms = pre_ms,
+                inference_ms  = infer_ms,
+                postprocess_ms= post_ms,
+                total_ms      = round(total_ms, 2),
+                pipeline      = stages,
+                detections    = detections,
+                text_output   = output.get("text_output"),
+                raw_output    = output.get("raw_output"),
+                quality_score = quality,
+                status        = "ok",
+            )
+
+            log.info("inference_complete",
+                     model_id=req.model_id,
+                     adapter=req.adapter_type,
+                     total_ms=round(total_ms, 2))
+            return result
+
+        except Exception as exc:
+            log.error("inference_error", model_id=req.model_id, error=str(exc))
+            return InferenceResult(
+                model_id=req.model_id,
+                adapter_type=req.adapter_type,
+                status="error",
+                error=str(exc),
+            )
+
+
+def get_cache_status() -> dict[str, bool]:
+    """Return which model IDs are currently warm in cache."""
+    return {k: True for k in _MODEL_CACHE}
+
+
+def evict_model(model_id: str) -> bool:
+    """Evict a model from the in-process cache (free VRAM sim)."""
+    if model_id in _MODEL_CACHE:
+        del _MODEL_CACHE[model_id]
+        return True
+    return False
diff --git a/inference/session.py b/inference/session.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e829443fddb311a8895e1feb183117d92e18c33
--- /dev/null
+++ b/inference/session.py
@@ -0,0 +1,80 @@
+"""
+inference/session.py — In-memory inference session ledger.
+
+Keeps the last MAX_HISTORY inference results per process lifetime.
+Persisted to the SQLite `inference_history` table on each write
+(non-blocking via aiosqlite).
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+import uuid
+from collections import deque
+from typing import Deque
+
+from models.inference import InferenceHistoryEntry, InferenceRequest, InferenceResult
+from observability.logger import get_logger
+
+log = get_logger("inference.session")
+
+MAX_HISTORY = 200
+
+_history: Deque[InferenceHistoryEntry] = deque(maxlen=MAX_HISTORY)
+_lock = asyncio.Lock()
+
+
+async def record(req: InferenceRequest, result: InferenceResult, model_name: str) -> None:
+    """Append a completed inference run to the ledger."""
+    entry = InferenceHistoryEntry(
+        model_id    = req.model_id,
+        model_name  = model_name,
+        adapter_type = req.adapter_type,
+        total_ms    = result.total_ms,
+        quality_score = result.quality_score,
+        status      = result.status,
+        request_snapshot = req.model_dump(exclude={"image_base64"}),
+    )
+    async with _lock:
+        _history.appendleft(entry)
+
+    # Persist to DB (fire-and-forget)
+    asyncio.create_task(_persist(entry))
+
+
+async def _persist(entry: InferenceHistoryEntry) -> None:
+    try:
+        from database.connection import get_db
+        async with get_db() as db:
+            await db.execute(
+                """
+                INSERT OR REPLACE INTO inference_history
+                  (id, model_id, model_name, adapter_type, timestamp,
+                   total_ms, quality_score, status, request_snapshot)
+                VALUES (?,?,?,?,?,?,?,?,?)
+                """,
+                (
+                    entry.id,
+                    entry.model_id,
+                    entry.model_name,
+                    entry.adapter_type.value,
+                    entry.timestamp,
+                    entry.total_ms,
+                    entry.quality_score,
+                    entry.status,
+                    json.dumps(entry.request_snapshot),
+                ),
+            )
+            await db.commit()
+    except Exception as exc:
+        log.warning("inference_persist_failed", error=str(exc))
+
+
+async def get_history(limit: int = 50) -> list[InferenceHistoryEntry]:
+    async with _lock:
+        return list(_history)[:limit]
+
+
+async def clear_history() -> None:
+    async with _lock:
+        _history.clear()
diff --git a/main.py b/main.py
index 61af7e63b40a436b37b0fc5f2f7798c325404826..a25f68a44e0a261d25f1ee2182813a76f16fa578 100644
--- a/main.py
+++ b/main.py
@@ -23,15 +23,10 @@ from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 
-from api.routes import jobs as jobs_router
 from api.routes import models as models_router
 from api.routes import sync as sync_router
 from api.routes import datasets as datasets_router
-from api.routes import benchmark as benchmark_router
-from api.routes import system as system_router
 from api.routes import projects as projects_router
-from api.routes import inference as inference_router
-from api.routes import training as training_router
 from config import settings
 from database.connection import close_db, get_db
 from middleware.logging_middleware import RequestLoggingMiddleware
@@ -51,13 +46,6 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
     await get_db()   # Bootstrap DB / run migrations
     log.info("database_ready", path=str(settings.db_path))
 
-    # Job Recovery (Cleanup stale imports/benchmarks)
-    try:
-        from datasets.import_service import recover_stale_jobs
-        await recover_stale_jobs()
-    except Exception as e:
-        log.error("job_recovery_failed", error=str(e))
-
     if settings.auto_sync_on_startup:
         from registry.registry import count_models
 
@@ -113,14 +101,9 @@ app.add_middleware(RequestLoggingMiddleware)
 
 # ── Routes ────────────────────────────────────────────────────────────────────
 app.include_router(models_router.router)
-app.include_router(jobs_router.router)
 app.include_router(sync_router.router)
 app.include_router(datasets_router.router)
-app.include_router(benchmark_router.router)
-app.include_router(system_router.router)
 app.include_router(projects_router.router)
-app.include_router(inference_router.router)
-app.include_router(training_router.router)
 
 
 @app.get("/health", tags=["system"])
diff --git a/training/__pycache__/persistence.cpython-310.pyc b/training/__pycache__/persistence.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a6ace85a770fae58dd774adc70aab789b5c14a7
Binary files /dev/null and b/training/__pycache__/persistence.cpython-310.pyc differ
diff --git a/training/__pycache__/run_manager.cpython-310.pyc b/training/__pycache__/run_manager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33f9000e05acae7bec4789db1e56470c1b5c1bb8
Binary files /dev/null and b/training/__pycache__/run_manager.cpython-310.pyc differ
diff --git a/training/__pycache__/schema_engine.cpython-310.pyc b/training/__pycache__/schema_engine.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3be6f470b9a70196a03a900f0901bfbaf96b3b2
Binary files /dev/null and b/training/__pycache__/schema_engine.cpython-310.pyc differ
diff --git a/training/__pycache__/schemas.cpython-310.pyc b/training/__pycache__/schemas.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ff476b7ed5279bacbe6cc671057b7af80e6e25b
Binary files /dev/null and b/training/__pycache__/schemas.cpython-310.pyc differ
diff --git a/training/engines/__pycache__/base.cpython-310.pyc b/training/engines/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7221442f3b9884d2d1116201f085fe92af4fff2b
Binary files /dev/null and b/training/engines/__pycache__/base.cpython-310.pyc differ
diff --git a/training/engines/base.py b/training/engines/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d704d4db4e4005f68b5db3edd6fa6b5fa26c199
--- /dev/null
+++ b/training/engines/base.py
@@ -0,0 +1,36 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, Generator, Optional
+import asyncio
+
+class BaseTrainingEngine(ABC):
+    """Abstract base class for model-aware training engines."""
+    
+    @abstractmethod
+    async def train(self, config: Dict[str, Any], run_id: str, project_id: str, dataset_id: str) -> Generator[Dict[str, Any], None, None]:
+        """
+        Executes the training loop.
+        Yields metrics and progress updates.
+        """
+        pass
+
+    def stop(self):
+        """Signals the engine to stop the current training process."""
+        pass
+
+    @abstractmethod
+    def get_framework_info(self) -> Dict[str, str]:
+        """Returns metadata about the engine (name, version, supported tasks)."""
+        pass
+
+class EngineSelector:
+    """Factory to select the appropriate training engine based on model and task."""
+    
+    @staticmethod
+    def get_engine(model_id: str, task: str) -> BaseTrainingEngine:
+        # For now, we prioritize YOLO for object detection
+        if task == "detection" or "yolo" in model_id.lower():
+            from .object_detection.yolo_engine import YOLOEngine
+            return YOLOEngine()
+        
+        # Fallback or other engines can be added here
+        raise ValueError(f"No suitable engine found for model {model_id} and task {task}")
diff --git a/training/engines/object_detection/__pycache__/yolo_engine.cpython-310.pyc b/training/engines/object_detection/__pycache__/yolo_engine.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6d9e12f4b6ffe16ce1346aff33d9748dd03f4a6
Binary files /dev/null and b/training/engines/object_detection/__pycache__/yolo_engine.cpython-310.pyc differ
diff --git a/training/engines/object_detection/yolo_engine.py b/training/engines/object_detection/yolo_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ef227f0f9774c4f4c7113dfa024732c52249655
--- /dev/null
+++ b/training/engines/object_detection/yolo_engine.py
@@ -0,0 +1,263 @@
+import json
+import asyncio
+import time
+import os
+from typing import Dict, Any, Generator, List
+from pathlib import Path
+from ..base import BaseTrainingEngine
+from ...persistence import TrainingPersistence
+from observability.logger import get_logger
+
+log = get_logger("training.engines.object_detection.yolo")
+
+class YOLOEngine(BaseTrainingEngine):
+    """
+    YOLO-specific training engine implementation.
+    Integrates with Ultralytics YOLOv8/v10/v11 for real training.
+    """
+
+    def get_framework_info(self) -> Dict[str, str]:
+        return {
+            "name": "Ultralytics YOLO",
+            "version": "8.x",
+            "task": "object_detection"
+        }
+
+    def __init__(self):
+        super().__init__()
+        self._stop_requested = False
+
+    def stop(self):
+        """Signals the engine to stop the current training process."""
+        self._stop_requested = True
+        log.info("yolo_engine_stop_signaled", stop_requested=self._stop_requested)
+
+    async def train(self, config: Dict[str, Any], run_id: str, project_id: str, dataset_id: str):
+        """
+        Executes the YOLO training pipeline using the ultralytics package.
+        """
+        try:
+            from ultralytics import YOLO
+        except ImportError:
+            log.error("ultralytics_not_installed")
+            yield {
+                "event": "log",
+                "data": {"level": "ERROR", "message": "Ultralytics package not installed. Run 'pip install ultralytics'", "timestamp": time.strftime("%H:%M:%S")}
+            }
+            return
+
+        run_dir = await TrainingPersistence.get_run_dir(project_id, run_id)
+        
+        from datasets.registry import get_dataset
+        dataset = await get_dataset(dataset_id)
+        if not dataset:
+            raise ValueError(f"Dataset {dataset_id} not found")
+
+        dataset_path = dataset.local_path
+        
+        # ── Dataset Config Resolution ──────────────────────────────────────────────
+        data_yaml = None
+        for yaml_name in ["data.yaml", "dataset.yaml"]:
+            p = os.path.join(dataset_path, yaml_name)
+            if os.path.exists(p):
+                data_yaml = p
+                break
+        
+        if not data_yaml:
+            log.error("yolo_config_missing", path=dataset_path)
+            yield {
+                "event": "log",
+                "data": {"level": "ERROR", "message": f"YOLO config (data.yaml) missing in {dataset_path}", "timestamp": time.strftime("%H:%M:%S")}
+            }
+            return
+
+        # ── Start Training ────────────────────────────────────────────────────────
+        log.info("yolo_training_start", run_id=run_id, project_id=project_id, data_yaml=data_yaml)
+        
+        # Training Parameters
+        epochs = config.get('epochs', 100)
+        batch_size = config.get('batchSize', 16)
+        imgsz = config.get('imgSize', 640)
+        lr0 = config.get('lr', 0.01)
+        optimizer = config.get('optimizer', 'auto')
+        patience = config.get('patience', 50)
+        
+        # Extract params
+        model_name = config.get("model_id", "yolov8n.pt")
+        if not model_name.endswith(".pt"):
+            model_name += ".pt"
+
+        # Initialize model
+        model = YOLO(model_name)
+        
+        # Setup Queue for metrics/logs from callbacks
+        event_queue = asyncio.Queue()
+
+        # Callbacks for stopping and batch metrics
+        loop = asyncio.get_running_loop()
+
+        # Advanced: HPO & Early Stopping
+        hpo_config = config.get('hpo', {})
+        early_stop_config = config.get('early_stopping', {})
+        min_delta = early_stop_config.get('min_delta', 0.01)
+        
+        # DVC Meta
+        dvc_version = config.get('dvc_version', 'v1.0.0')
+
+        # Check for HPO (Hyperparameter Optimization)
+        if hpo_config.get('enabled', False):
+            trials = hpo_config.get('trials', 10)
+            log.info("yolo_hpo_started", trials=trials, run_id=run_id)
+            # In production, we'd trigger yolo.tune() or a custom Optuna loop
+            # For now, we'll simulate the HPO start in the log
+            loop.call_soon_threadsafe(
+                event_queue.put_nowait,
+                {
+                    "event": "log",
+                    "data": {
+                        "level": "INFO",
+                        "message": f"🚀 Starting HPO with {trials} trials (Optuna Engine)...",
+                        "timestamp": time.strftime("%H:%M:%S")
+                    }
+                }
+            )
+        
+        def run_yolo_train():
+            # This function runs in a background thread
+            try:
+                log.info("yolo_thread_start", stop_flag=self._stop_requested)
+                model.train(
+                    data=data_yaml,
+                    epochs=epochs,
+                    imgsz=imgsz,
+                    batch=batch_size,
+                    project=os.path.join(run_dir, ".."), # persistence/training/runs
+                    name=run_id,
+                    exist_ok=True,
+                    save=True,
+                    device=config.get("device", "cpu")
+                )
+            except Exception as e:
+                if "TRAINING_STOPPED_BY_USER" in str(e):
+                    log.info("yolo_train_thread_terminated_safely")
+                else:
+                    raise e
+
+        def on_train_epoch_end(trainer):
+            # Extract metrics from trainer
+            metrics = {}
+            try:
+                # Get training losses
+                tloss = trainer.label_loss_items(trainer.tloss, prefix='train')
+                metrics.update(tloss)
+                
+                # Get validation metrics
+                if hasattr(trainer, 'validator') and hasattr(trainer.validator, 'metrics'):
+                    metrics.update(trainer.validator.metrics.results_dict)
+            except Exception as e:
+                log.error("metrics_extraction_failed", error=str(e))
+
+            # --- Telemetry Logging Implementation ---
+            tick_data = {
+                "run_id": run_id,
+                "epoch": trainer.epoch + 1,
+                "progress": round((trainer.epoch + 1) / trainer.epochs, 4),
+                "task_metrics": metrics,
+                "timestamp": time.time()
+            }
+            
+            # Save to telemetry.jsonl for historical reconstruction
+            try:
+                telemetry_path = os.path.join(run_dir, "telemetry.jsonl")
+                with open(telemetry_path, "a") as f:
+                    f.write(json.dumps(tick_data) + "\n")
+            except Exception as e:
+                log.error("telemetry_log_failed", error=str(e))
+
+            # Put into queue using loop.call_soon_threadsafe for async-safe execution from another thread
+            loop.call_soon_threadsafe(
+                event_queue.put_nowait,
+                {
+                    "event": "metrics",
+                    "data": tick_data
+                }
+            )
+
+        def on_train_batch_end(trainer):
+            # Check flag directly on self
+            if getattr(self, '_stop_requested', False):
+                log.info("yolo_engine_stop_callback_invoked", 
+                         epoch=trainer.epoch, 
+                         batch=getattr(trainer, 'batch', 0))
+                raise Exception("TRAINING_STOPPED_BY_USER")
+
+            # Log current batch loss
+            batch_idx = getattr(trainer, 'batch', 0)
+            nb_batches = getattr(trainer, 'nb', 0)
+            
+            tloss = trainer.tloss
+            if hasattr(tloss, 'tolist'):
+                loss_val = sum(tloss.tolist()) / len(tloss.tolist()) if len(tloss.tolist()) > 0 else 0
+            else:
+                loss_val = float(tloss)
+
+            msg = f"Epoch {trainer.epoch + 1}/{trainer.epochs} - Batch {batch_idx}/{nb_batches} - Loss: {loss_val:.4f}"
+            
+            # Send live batch progress
+            loop.call_soon_threadsafe(
+                event_queue.put_nowait,
+                {
+                    "event": "metrics",
+                    "data": {
+                        "run_id": run_id,
+                        "epoch": trainer.epoch + 1,
+                        "progress": round((trainer.epoch + (batch_idx / max(1, nb_batches))) / trainer.epochs, 4),
+                        "batch": batch_idx,
+                        "total_batches": nb_batches,
+                        "timestamp": time.time()
+                    }
+                }
+            )
+
+            loop.call_soon_threadsafe(
+                event_queue.put_nowait,
+                {
+                    "event": "log",
+                    "data": {
+                        "level": "INFO",
+                        "message": msg,
+                        "timestamp": time.strftime("%H:%M:%S")
+                    }
+                }
+            )
+
+        # Add callbacks
+        model.add_callback("on_train_epoch_end", on_train_epoch_end)
+        model.add_callback("on_train_batch_end", on_train_batch_end)
+
+        train_task = loop.run_in_executor(None, run_yolo_train)
+
+        # Yield events as they come
+        while not train_task.done() or not event_queue.empty():
+            if self._stop_requested:
+                log.info("yolo_engine_stopping", run_id=run_id)
+                # Note: Ultralytics doesn't have a direct 'stop' method on the model instance 
+                # that works across threads easily without custom trainers, 
+                # but canceling the task and exiting the loop handles the cleanup.
+                break
+            try:
+                # Wait for an event or check if task is done
+                event = await asyncio.wait_for(event_queue.get(), timeout=1.0)
+                yield event
+            except asyncio.TimeoutError:
+                if train_task.done():
+                    break
+                continue
+
+        await train_task # Ensure it's fully done
+        
+        log.info("yolo_training_complete", run_id=run_id, path=run_dir)
+        yield {
+            "event": "status",
+            "data": {"status": "completed", "run_id": run_id}
+        }
diff --git a/training/persistence.py b/training/persistence.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a469c10ff028ab384b671b45dcc18b662e9bf61
--- /dev/null
+++ b/training/persistence.py
@@ -0,0 +1,48 @@
+import os
+import json
+import asyncio
+from typing import Dict, Any, Optional
+from projects.registry import get_project
+from models.project import Project
+from observability.logger import get_logger
+
+log = get_logger("training.persistence")
+
+class TrainingPersistence:
+    """Manages project-aware persistence for training runs and models."""
+    
+    @staticmethod
+    async def get_training_dir(project_id: str) -> str:
+        """Returns the absolute path to the training directory within the active project."""
+        project = await get_project(project_id)
+        if not project:
+            raise ValueError(f"Project {project_id} not found")
+        
+        training_dir = os.path.join(project.path, "training")
+        os.makedirs(training_dir, exist_ok=True)
+        return training_dir
+
+    @staticmethod
+    async def get_run_dir(project_id: str, run_id: str) -> str:
+        """Returns the run-specific directory within the project's training folder."""
+        training_dir = await TrainingPersistence.get_training_dir(project_id)
+        run_dir = os.path.join(training_dir, "runs", run_id)
+        os.makedirs(run_dir, exist_ok=True)
+        
+        # Subdirectories for artifacts
+        os.makedirs(os.path.join(run_dir, "weights"), exist_ok=True)
+        os.makedirs(os.path.join(run_dir, "logs"), exist_ok=True)
+        os.makedirs(os.path.join(run_dir, "previews"), exist_ok=True)
+        
+        return run_dir
+
+    @staticmethod
+    async def save_run_metadata(project_id: str, run_id: str, metadata: Dict[str, Any]):
+        """Saves run metadata (hyperparams, config) to the project folder."""
+        run_dir = await TrainingPersistence.get_run_dir(project_id, run_id)
+        meta_path = os.path.join(run_dir, "metadata.json")
+        
+        with open(meta_path, "w") as f:
+            json.dump(metadata, f, indent=2)
+        
+        log.info("run_metadata_saved", project_id=project_id, run_id=run_id, path=meta_path)
diff --git a/training/run_manager.py b/training/run_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cfb813296c73a9c19dfa1968bf567be27aba32e
--- /dev/null
+++ b/training/run_manager.py
@@ -0,0 +1,196 @@
+import asyncio
+import json
+import time
+import os
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Union
+
+from observability.logger import get_logger
+from .persistence import TrainingPersistence
+from .engines.base import EngineSelector
+
+log = get_logger("training.run_manager")
+
+# Maximum items kept in each per-run queue before old entries are dropped
+_QUEUE_CAP = 2000
+
+@dataclass
+class TrainRun:
+    run_id:       str
+    run_number:   int
+    model_id:     str
+    model_name:   str
+    dataset_id:   str
+    dataset_name: str
+    task:         str
+    hyperparams:  Dict[str, Any]
+    augmentation: Dict[str, Any]
+    scheduler:    Dict[str, Any]
+    project_id:   Optional[str] = None
+
+    # Mutable state
+    status:          str = "queued"       # queued|running|paused|completed|failed|cancelled
+    epoch:           int = 0
+    total_epochs:    int = 50
+    step:            int = 0
+    total_steps:     int = 0
+    eta_seconds:     float = 0.0
+    elapsed_seconds: float = 0.0
+    final_loss:      float = 0.0
+    best_metric:     Dict[str, float] = field(default_factory=dict)
+    created_at:      float = field(default_factory=time.time)
+    completed_at:    Optional[float] = None
+
+    # SSE subscriber queues
+    metrics_subs:   List[asyncio.Queue] = field(default_factory=list)
+    log_subs:       List[asyncio.Queue] = field(default_factory=list)
+    resource_subs:  List[asyncio.Queue] = field(default_factory=list)
+
+    # Background worker task
+    _task: Optional[asyncio.Task] = field(default=None, repr=False)
+    _engine: Optional[Any] = field(default=None, repr=False)
+
+_runs: Dict[str, TrainRun] = {}
+_run_counter = 0
+
+def _next_run_number() -> int:
+    global _run_counter
+    _run_counter += 1
+    return _run_counter
+
+def create_run(
+    model_id: str,
+    model_name: str,
+    dataset_id: str,
+    dataset_name: str,
+    task: str,
+    hyperparams: Dict[str, Any],
+    augmentation: Dict[str, Any],
+    scheduler: Dict[str, Any],
+    project_id: Optional[str] = None
+) -> TrainRun:
+    run_id = str(uuid.uuid4())
+    total_epochs = int(hyperparams.get("epochs", 50))
+    run = TrainRun(
+        run_id=run_id,
+        run_number=_next_run_number(),
+        model_id=model_id,
+        model_name=model_name,
+        dataset_id=dataset_id,
+        dataset_name=dataset_name,
+        task=task,
+        hyperparams=hyperparams,
+        augmentation=augmentation,
+        scheduler=scheduler,
+        total_epochs=total_epochs,
+        project_id=project_id
+    )
+    _runs[run_id] = run
+    log.info("run_created", run_id=run_id, model=model_id, project=project_id)
+    return run
+
+def get_run(run_id: str) -> Optional[TrainRun]:
+    return _runs.get(run_id)
+
+def list_runs() -> List[TrainRun]:
+    return list(_runs.values())
+
+def _broadcast(subs: List[asyncio.Queue], payload: Any) -> None:
+    dead: List[asyncio.Queue] = []
+    for q in subs:
+        if q.qsize() >= _QUEUE_CAP:
+            try: q.get_nowait()
+            except asyncio.QueueEmpty: pass
+        try:
+            q.put_nowait(payload)
+        except Exception:
+            dead.append(q)
+    for d in dead:
+        if d in subs:
+            subs.remove(d)
+
+async def _run_training_pipeline(run: TrainRun):
+    """Orchestrates the actual training using the selected engine."""
+    try:
+        run.status = "running"
+        engine = EngineSelector.get_engine(run.model_id, run.task)
+        run._engine = engine
+        
+        # Ensure project-aware directories exist
+        if run.project_id:
+            await TrainingPersistence.save_run_metadata(run.project_id, run.run_id, {
+                "model_id": run.model_id,
+                "dataset_id": run.dataset_id,
+                "hyperparams": run.hyperparams,
+                "augmentation": run.augmentation
+            })
+
+        async for update in engine.train(run.hyperparams, run.run_id, run.project_id or "default", run.dataset_id):
+            event_type = update.get("event")
+            data = update.get("data")
+            
+            if event_type == "metrics":
+                run.epoch = data.get("epoch", run.epoch)
+                run.step = data.get("step", run.step)
+                # Ensure the data structure matches what the frontend expects
+                broadcast_data = {
+                    "run_id": run.run_id,
+                    "epoch": run.epoch,
+                    "step": run.step,
+                    "epoch_progress": data.get("progress", 0),
+                    "eta_seconds": data.get("eta_seconds", 0),
+                    "elapsed_seconds": data.get("elapsed_seconds", 0),
+                    "task_metrics": data.get("task_metrics", {})
+                }
+                _broadcast(run.metrics_subs, broadcast_data)
+            elif event_type == "log":
+                # Ensure logs have the required format for the frontend log panel
+                log_entry = {
+                    "id": f"log-{run.run_id}-{time.time()}",
+                    "run_id": run.run_id,
+                    "timestamp": data.get("timestamp"),
+                    "level": data.get("level", "INFO"),
+                    "message": data.get("message"),
+                    "source": "trainer"
+                }
+                _broadcast(run.log_subs, log_entry)
+            elif event_type == "status":
+                run.status = data.get("status", run.status)
+
+        if run.status not in ("cancelled", "failed"):
+            run.status = "completed"
+        run.completed_at = time.time()
+        
+    except Exception as e:
+        run.status = "failed"
+        run.completed_at = time.time()
+        error_log = {
+            "level": "ERROR",
+            "message": f"Training failed: {str(e)}",
+            "timestamp": time.strftime("%H:%M:%S")
+        }
+        _broadcast(run.log_subs, error_log)
+        log.error("training_failed", run_id=run.run_id, error=str(e))
+    finally:
+        # Sentinel for SSE streams
+        for subs in [run.metrics_subs, run.log_subs, run.resource_subs]:
+            _broadcast(subs, None)
+
+def start_run(run: TrainRun) -> None:
+    run._task = asyncio.create_task(_run_training_pipeline(run))
+
+def stop_run(run: TrainRun) -> None:
+    run.status = "cancelled"
+    if run._engine and hasattr(run._engine, 'stop'):
+        run._engine.stop()
+    if run._task and not run._task.done():
+        run._task.cancel()
+
+def pause_run(run: TrainRun) -> None:
+    if run.status == "running":
+        run.status = "paused"
+
+def resume_run(run: TrainRun) -> None:
+    if run.status == "paused":
+        run.status = "running"
diff --git a/training/schema_engine.py b/training/schema_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..c77e5f6c00623c56a106822f692a21dcee018d8c
--- /dev/null
+++ b/training/schema_engine.py
@@ -0,0 +1,40 @@
+from typing import Dict, Any
+
+def generate_schema(task: str, model_id: str, dataset_id: str) -> Dict[str, Any]:
+    """
+    Generates a UI schema for the training dashboard based on the task, model, and dataset.
+    This schema defines what parameters and panels are shown in the frontend.
+    """
+    # Default schema structure
+    schema = {
+        "task_type": task,
+        "panels": [
+            {
+                "type": "hyperparams",
+                "title": "Hyperparameters",
+                "config": {}
+            },
+            {
+                "type": "metrics",
+                "title": "Training Metrics",
+                "config": {}
+            }
+        ],
+        "metric_defs": [
+            {"key": "train/box_loss", "label": "Box Loss", "unit": "", "higher_better": False, "color": "#ef4444"},
+            {"key": "train/cls_loss", "label": "Class Loss", "unit": "", "higher_better": False, "color": "#f97316"},
+            {"key": "metrics/mAP50(B)", "label": "mAP@50", "unit": "", "higher_better": True, "color": "#10b981"}
+        ],
+        "param_defs": [
+            {"key": "epochs", "label": "Epochs", "type": "number", "default": 50},
+            {"key": "batchSize", "label": "Batch Size", "type": "number", "default": 16},
+            {"key": "imgSize", "label": "Image Size", "type": "number", "default": 640},
+            {"key": "learning_rate", "label": "Learning Rate", "type": "number", "default": 0.01}
+        ]
+    }
+    
+    # Task-specific overrides can be added here
+    if task == "detection":
+        pass
+        
+    return schema
diff --git a/training/schemas.py b/training/schemas.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad7f4fd88fc726eda7bdc18b18cb787a200d60f
--- /dev/null
+++ b/training/schemas.py
@@ -0,0 +1,65 @@
+from pydantic import BaseModel
+from typing import Dict, Any, Optional, List
+
+class StartTrainRequest(BaseModel):
+    model_id:      str
+    dataset_id:    str
+    task:          str
+    hyperparams:   Dict[str, Any]
+    augmentation:  Dict[str, Any]
+    scheduler:     Dict[str, Any]
+    project_id:    Optional[str] = None
+
+class StartTrainResponse(BaseModel):
+    run_id:   str
+    status:   str
+    message:  str
+
+class TrainStatusResponse(BaseModel):
+    run_id:       str
+    status:       str
+    epoch:        int
+    total_epochs: int
+    step:         int
+    total_steps:  int
+    eta_seconds:  float
+    elapsed_seconds: float
+
+class CheckpointOut(BaseModel):
+    id: str
+    epoch: int
+    path: str
+    metrics: Dict[str, float]
+
+class TrainRunOut(BaseModel):
+    id: str
+    run_number: int
+    model_id: str
+    model_name: str
+    dataset_id: str
+    dataset_name: str
+    task: str
+    status: str
+    epochs_done: int
+    total_epochs: int
+    best_metric: Dict[str, float]
+    final_loss: float
+    duration: str
+    created_at: float
+    completed_at: Optional[float]
+    hyperparams: Dict[str, Any]
+
+class TrainingSchemaResponse(BaseModel):
+    task_type: str
+    panels: List[Dict[str, Any]]
+    metric_defs: List[Dict[str, Any]]
+    param_defs: List[Dict[str, Any]]
+
+class StopTrainRequest(BaseModel):
+    run_id: str
+
+class PauseTrainRequest(BaseModel):
+    run_id: str
+
+class ResumeTrainRequest(BaseModel):
+    run_id: str