""" adapters/roboflow_adapter.py — Roboflow Universe API client. Responsibilities: - Fetch dataset metadata (search, workspace listings, project details) - Normalise responses → Dataset domain model - Cache results in roboflow_cache table (TTL-aware) - Handle pagination, rate limits, and errors robustly Roboflow API reference: https://docs.roboflow.com/api-reference/ """ from __future__ import annotations import hashlib import json import time from typing import Any import httpx from tenacity import retry, stop_after_attempt, wait_exponential from database.connection import get_db from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask from observability.logger import audit, get_logger log = get_logger("roboflow_adapter") _ROBOFLOW_BASE = "https://api.roboflow.com" _UNIVERSE_BASE = "https://universe.roboflow.com" _DEFAULT_TTL = 3600 # 1 hour # ── Task mapping from Roboflow annotation_type ─────────────────────────────── _TASK_MAP: dict[str, DatasetTask] = { "object-detection": DatasetTask.detection, "instance-segmentation": DatasetTask.segmentation, "semantic-segmentation": DatasetTask.segmentation, "classification": DatasetTask.classification, "keypoint-detection": DatasetTask.keypoints, "multiclass-classification": DatasetTask.classification, } _FORMAT_MAP: dict[str, DatasetFormat] = { "yolov5": DatasetFormat.yolo, "yolov7": DatasetFormat.yolo, "yolov8": DatasetFormat.yolo, "yolov9": DatasetFormat.yolo, "coco": DatasetFormat.coco, "voc": DatasetFormat.voc, "tfrecord": DatasetFormat.tfrecord, "csv": DatasetFormat.csv, "createml": DatasetFormat.json, "multiclass": DatasetFormat.csv, } def _cache_key(parts: list[str]) -> str: raw = "|".join(parts) return hashlib.sha256(raw.encode()).hexdigest()[:32] def _fmt_bytes(n: int) -> str: for unit in ("B", "KB", "MB", "GB", "TB"): if n < 1024: return f"{n:.1f} {unit}" n /= 1024 return f"{n:.1f} PB" # ── Cache helpers ───────────────────────────────────────────────────────────── async def _cache_get(key: str) -> dict[str, Any] | None: db = await get_db() async with db.execute( "SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?", (key,), ) as cur: row = await cur.fetchone() if row is None: return None fetched = time.mktime(time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S")) if time.time() - fetched > row["ttl_secs"]: return None # expired return json.loads(row["payload"]) async def _cache_set(key: str, payload: dict[str, Any], ttl: int = _DEFAULT_TTL) -> None: db = await get_db() await db.execute( """INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs) VALUES (?, ?, ?)""", (key, json.dumps(payload), ttl), ) await db.commit() # ── HTTP client factory ─────────────────────────────────────────────────────── def _make_client(api_key: str) -> httpx.AsyncClient: return httpx.AsyncClient( base_url=_ROBOFLOW_BASE, params={"api_key": api_key}, timeout=30.0, headers={"User-Agent": "MLForge/1.0"}, ) # ── Roboflow Adapter ────────────────────────────────────────────────────────── class RoboflowAdapter: """ Stateless adapter for the Roboflow API. All methods accept api_key explicitly to support per-user keys. """ # ── Search (Universe) ───────────────────────────────────────────────────── @staticmethod @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8)) async def search_datasets( api_key: str, query: str = "", workspace: str | None = None, page: int = 0, page_size: int = 50, ) -> list[Dataset]: """ Search Roboflow Universe for datasets. Returns normalised Dataset objects. """ ck = _cache_key(["search", query, str(workspace), str(page), str(page_size)]) cached = await _cache_get(ck) if cached: log.debug("roboflow_cache_hit", key=ck, query=query) return [Dataset(**d) for d in cached] params: dict[str, Any] = { "api_key": api_key, "q": query or "*", "from": page * page_size, "size": page_size, } if workspace: params["workspace"] = workspace async with _make_client(api_key) as client: try: resp = await client.get("/", params=params) resp.raise_for_status() data = resp.json() except httpx.HTTPStatusError as e: log.error("roboflow_api_error", status=e.response.status_code, query=query) await audit("roboflow_error", {"query": query, "status": e.response.status_code}, level="error") raise datasets = [] for item in data.get("results", []): try: ds = RoboflowAdapter._normalise_search_result(item) datasets.append(ds) except Exception as exc: log.warning("normalise_error", item_id=item.get("id"), error=str(exc)) await _cache_set(ck, [d.model_dump() for d in datasets]) await audit("roboflow_search", {"query": query, "count": len(datasets)}) return datasets # ── Workspace datasets listing ──────────────────────────────────────────── @staticmethod @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8)) async def list_workspace_datasets( api_key: str, workspace: str, ) -> list[Dataset]: """List all datasets in a Roboflow workspace.""" ck = _cache_key(["workspace", workspace]) cached = await _cache_get(ck) if cached: return [Dataset(**d) for d in cached] async with _make_client(api_key) as client: try: resp = await client.get(f"/{workspace}") resp.raise_for_status() data = resp.json() except httpx.HTTPStatusError as e: log.error("roboflow_workspace_error", workspace=workspace, status=e.response.status_code) raise datasets = [] for proj in data.get("workspace", {}).get("projects", []): try: ds = RoboflowAdapter._normalise_project(proj, workspace) datasets.append(ds) except Exception as exc: log.warning("normalise_project_error", project=proj.get("id"), error=str(exc)) await _cache_set(ck, [d.model_dump() for d in datasets]) return datasets # ── Single project detail ───────────────────────────────────────────────── @staticmethod @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8)) async def get_project( api_key: str, workspace: str, project_id: str, ) -> Dataset | None: """Fetch full metadata for a single Roboflow project.""" ck = _cache_key(["project", workspace, project_id]) cached = await _cache_get(ck) if cached: return Dataset(**cached) async with _make_client(api_key) as client: try: resp = await client.get(f"/{workspace}/{project_id}") resp.raise_for_status() data = resp.json() except httpx.HTTPStatusError as e: if e.response.status_code == 404: return None raise proj_data = data.get("project", data) ds = RoboflowAdapter._normalise_project(proj_data, workspace) await _cache_set(ck, ds.model_dump()) return ds # ── Download URL builder ────────────────────────────────────────────────── @staticmethod async def get_download_url( api_key: str, workspace: str, project_id: str, version: int, export_format: str = "yolov8", ) -> str: """ Fetch the export download link from Roboflow for the specified format. Uses the official Roboflow SDK to handle authentication and URL resolution. """ try: from roboflow import Roboflow rf = Roboflow(api_key=api_key) project = rf.workspace(workspace).project(project_id) version_obj = project.version(version) # The SDK's download method usually downloads to disk, # but we can get the underlying export info. # We'll use a thread to run the SDK call since it's blocking. import asyncio def _get_link(): return version_obj.export(export_format).download_link link = await asyncio.to_thread(_get_link) if not link: raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}") return link except Exception as e: log.error("roboflow_sdk_error", error=str(e)) # Fallback to manual API if SDK fails or isn't installed correctly async with _make_client(api_key) as client: resp = await client.get( f"/{workspace}/{project_id}/{version}/{export_format}" ) resp.raise_for_status() data = resp.json() link = export.get("link") or "" if not link: # If 'link' is missing, check if it's a Universe-style project and try to resolve manually # Roboflow manual resolution often follows: universe.roboflow.com/ds/[id]?key=[api_key] if "project" in data: pid = data["project"].get("id") if pid: link = f"https://universe.roboflow.com/ds/{pid}?key={api_key}" if not link: raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}") # Ensure the link includes the API key correctly if "universe.roboflow.com" in link: if "key=" not in link: separator = "&" if "?" in link else "?" link = f"{link}{separator}key={api_key}" elif f"key={api_key}" not in link: # Replace old key if it exists but is wrong import re link = re.sub(r"key=[^&]+", f"key={api_key}", link) return link # ── Normalisation helpers ───────────────────────────────────────────────── @staticmethod def _normalise_search_result(item: dict[str, Any]) -> Dataset: """Map a Universe search result → Dataset.""" ann_type = item.get("annotation", {}).get("type", "object-detection") rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection) class_names = [c.get("name", "") for c in item.get("classes", [])] images = item.get("images", 0) or 0 return Dataset( id = item.get("id", "").replace("/", "__"), name = item.get("name", "Unnamed"), description = item.get("description", ""), task = rf_task, format = DatasetFormat.yolo, source = DatasetSource.roboflow, status = DatasetStatus.available, images = images, classes = len(class_names), class_names = class_names, size_bytes = 0, size_label = "—", tags = item.get("tags", []), roboflow_id = item.get("id", ""), created_at = item.get("created", ""), updated_at = item.get("updated", ""), ) @staticmethod def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset: """Map a workspace project → Dataset.""" ann_type = proj.get("annotation", "object-detection") rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection) class_names = [c.get("name", c) if isinstance(c, dict) else c for c in proj.get("classes", [])] project_id = proj.get("id", proj.get("name", "unknown")) rf_id = f"{workspace}/{project_id}" images = proj.get("images", 0) or 0 return Dataset( id = rf_id.replace("/", "__"), name = proj.get("name", project_id), description = proj.get("description", ""), task = rf_task, format = DatasetFormat.yolo, source = DatasetSource.roboflow, status = DatasetStatus.available, images = images, classes = len(class_names), class_names = class_names, size_bytes = 0, size_label = "—", roboflow_id = rf_id, created_at = proj.get("created", ""), updated_at = proj.get("updated", ""), )