Spaces:
Sleeping
Sleeping
| """ | |
| adapters/roboflow_adapter.py — Roboflow Universe API client. | |
| Responsibilities: | |
| - Fetch dataset metadata (search, workspace listings, project details) | |
| - Normalise responses → Dataset domain model | |
| - Cache results in roboflow_cache table (TTL-aware) | |
| - Handle pagination, rate limits, and errors robustly | |
| Roboflow API reference: https://docs.roboflow.com/api-reference/ | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import time | |
| from typing import Any | |
| import httpx | |
| from tenacity import retry, stop_after_attempt, wait_exponential | |
| from database.connection import get_db | |
| from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask | |
| from observability.logger import audit, get_logger | |
# Module-level structured logger shared by every function in this adapter.
log = get_logger("roboflow_adapter")

# Base URL handed to the httpx client for REST calls.
_ROBOFLOW_BASE = "https://api.roboflow.com"
# Base URL of the public Universe site.
_UNIVERSE_BASE = "https://universe.roboflow.com"
# Default cache lifetime in seconds.
_DEFAULT_TTL = 3600  # 1 hour

# ── Task mapping from Roboflow annotation_type ───────────────────────────────
# Keys are Roboflow annotation-type strings; lookups elsewhere in this module
# fall back to DatasetTask.detection for anything not listed here.
_TASK_MAP: dict[str, DatasetTask] = {
    "object-detection": DatasetTask.detection,
    "instance-segmentation": DatasetTask.segmentation,
    "semantic-segmentation": DatasetTask.segmentation,
    "classification": DatasetTask.classification,
    "keypoint-detection": DatasetTask.keypoints,
    "multiclass-classification": DatasetTask.classification,
}

# Export-format identifiers mapped onto the internal DatasetFormat enum.
# Every YOLO flavour collapses onto the single DatasetFormat.yolo member.
_FORMAT_MAP: dict[str, DatasetFormat] = {
    **dict.fromkeys(("yolov5", "yolov7", "yolov8", "yolov9"), DatasetFormat.yolo),
    "coco": DatasetFormat.coco,
    "voc": DatasetFormat.voc,
    "tfrecord": DatasetFormat.tfrecord,
    "csv": DatasetFormat.csv,
    "createml": DatasetFormat.json,
    "multiclass": DatasetFormat.csv,
}
def _cache_key(parts: list[str]) -> str:
    """Derive a short, deterministic cache key from an ordered list of parts.

    The parts are joined with ``|``, hashed with SHA-256, and the first 32
    hexadecimal characters of the digest are returned.
    """
    joined = "|".join(parts)
    digest = hashlib.sha256(joined.encode())
    return digest.hexdigest()[:32]
def _fmt_bytes(n: int) -> str:
    """Render a byte count as a human-readable string with one decimal place.

    The value is divided by 1024 until it drops below 1024 or the units
    B, KB, MB, GB, TB are exhausted; anything larger is reported in PB.
    """
    scaled = n
    suffix = "PB"  # used only when every smaller unit has been exhausted
    for candidate in ("B", "KB", "MB", "GB", "TB"):
        if scaled < 1024:
            suffix = candidate
            break
        scaled = scaled / 1024
    return f"{scaled:.1f} {suffix}"
| # ── Cache helpers ───────────────────────────────────────────────────────────── | |
async def _cache_get(key: str) -> dict[str, Any] | list[Any] | None:
    """Return the cached payload stored under *key*, or ``None``.

    ``None`` is returned both when no row exists and when the row is older
    than its own ``ttl_secs``. The payload is whatever JSON value was stored
    (callers in this module cache both dicts and lists).

    Args:
        key: Cache key, normally produced by ``_cache_key``.

    Raises:
        ValueError: If ``fetched_at`` is not in ``YYYY-MM-DD HH:MM:SS`` form.
    """
    import calendar  # local import keeps this fix self-contained

    db = await get_db()
    async with db.execute(
        "SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?",
        (key,),
    ) as cur:
        row = await cur.fetchone()
    if row is None:
        return None

    # NOTE(review): fetched_at is never written by this module, so it is
    # presumably a column default. SQLite's CURRENT_TIMESTAMP / datetime('now')
    # are UTC, so the value must be converted with timegm(); mktime() would
    # read it as local time and skew the TTL by the host's UTC offset.
    fetched = calendar.timegm(time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S"))
    if time.time() - fetched > row["ttl_secs"]:
        return None  # expired
    return json.loads(row["payload"])
async def _cache_set(
    key: str,
    payload: dict[str, Any] | list[Any],
    ttl: int = _DEFAULT_TTL,
) -> None:
    """Store *payload* under *key*, replacing any existing row, and commit.

    Args:
        key: Cache key, normally produced by ``_cache_key``.
        payload: JSON-serialisable value. Callers in this module pass a
            single dumped model (dict) or a list of them.
        ttl: Lifetime of the entry in seconds.

    Raises:
        TypeError: If *payload* cannot be serialised by ``json.dumps``.
    """
    db = await get_db()
    # fetched_at is not set here; presumably the column default stamps it,
    # and OR REPLACE re-inserts the row so the timestamp is refreshed.
    await db.execute(
        """INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs)
        VALUES (?, ?, ?)""",
        (key, json.dumps(payload), ttl),
    )
    await db.commit()
| # ── HTTP client factory ─────────────────────────────────────────────────────── | |
def _make_client(api_key: str) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient`` pre-configured for the Roboflow API.

    The client targets ``_ROBOFLOW_BASE``, sends *api_key* as the ``api_key``
    query parameter on every request, uses a 30 second timeout, and
    identifies itself with the ``MLForge/1.0`` user agent. The caller is
    responsible for closing it (typically via ``async with``).
    """
    default_query = {"api_key": api_key}
    default_headers = {"User-Agent": "MLForge/1.0"}
    return httpx.AsyncClient(
        base_url=_ROBOFLOW_BASE,
        params=default_query,
        headers=default_headers,
        timeout=30.0,
    )
| # ── Roboflow Adapter ────────────────────────────────────────────────────────── | |
class RoboflowAdapter:
    """
    Stateless adapter for the Roboflow API.

    Every method is a static method and accepts ``api_key`` explicitly to
    support per-user keys; the class holds no instance state.
    """

    # ── Search (Universe) ─────────────────────────────────────────────────────

    @staticmethod
    async def search_datasets(
        api_key: str,
        query: str = "",
        workspace: str | None = None,
        page: int = 0,
        page_size: int = 50,
    ) -> list[Dataset]:
        """
        Search Roboflow Universe for datasets.

        Args:
            api_key: Roboflow API key used for the request.
            query: Free-text query; an empty string is sent as ``*``.
            workspace: Optional workspace filter.
            page: Zero-based page index.
            page_size: Number of results per page.

        Returns:
            Normalised Dataset objects. Results that fail normalisation are
            logged and skipped rather than failing the whole search.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        # The key is part of the (hashed) cache key so that results fetched
        # with one user's key are never served to a different user.
        ck = _cache_key(
            ["search", api_key, query, str(workspace), str(page), str(page_size)]
        )
        cached = await _cache_get(ck)
        # `is not None`: a cached empty result list is still a valid hit.
        if cached is not None:
            log.debug("roboflow_cache_hit", key=ck, query=query)
            return [Dataset(**d) for d in cached]

        params: dict[str, Any] = {
            "api_key": api_key,
            "q": query or "*",
            "from": page * page_size,
            "size": page_size,
        }
        if workspace:
            params["workspace"] = workspace

        async with _make_client(api_key) as client:
            try:
                resp = await client.get("/", params=params)
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                log.error("roboflow_api_error", status=e.response.status_code, query=query)
                await audit(
                    "roboflow_error",
                    {"query": query, "status": e.response.status_code},
                    level="error",
                )
                raise

        datasets = []
        for item in data.get("results") or []:
            try:
                datasets.append(RoboflowAdapter._normalise_search_result(item))
            except Exception as exc:
                # Best effort: one malformed item must not sink the whole page.
                log.warning("normalise_error", item_id=item.get("id"), error=str(exc))

        await _cache_set(ck, [d.model_dump() for d in datasets])
        await audit("roboflow_search", {"query": query, "count": len(datasets)})
        return datasets

    # ── Workspace datasets listing ────────────────────────────────────────────

    @staticmethod
    async def list_workspace_datasets(
        api_key: str,
        workspace: str,
    ) -> list[Dataset]:
        """List all datasets in a Roboflow workspace.

        Args:
            api_key: Roboflow API key used for the request.
            workspace: Workspace identifier (URL slug).

        Returns:
            Normalised Dataset objects; projects that fail normalisation are
            logged and skipped.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        # Scoped by API key: workspace listings may contain private projects.
        ck = _cache_key(["workspace", api_key, workspace])
        cached = await _cache_get(ck)
        if cached is not None:
            return [Dataset(**d) for d in cached]

        async with _make_client(api_key) as client:
            try:
                resp = await client.get(f"/{workspace}")
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                log.error(
                    "roboflow_workspace_error",
                    workspace=workspace,
                    status=e.response.status_code,
                )
                raise

        datasets = []
        # `or` guards tolerate an explicit null for either level.
        for proj in (data.get("workspace") or {}).get("projects") or []:
            try:
                datasets.append(RoboflowAdapter._normalise_project(proj, workspace))
            except Exception as exc:
                log.warning("normalise_project_error", project=proj.get("id"), error=str(exc))

        await _cache_set(ck, [d.model_dump() for d in datasets])
        return datasets

    # ── Single project detail ─────────────────────────────────────────────────

    @staticmethod
    async def get_project(
        api_key: str,
        workspace: str,
        project_id: str,
    ) -> Dataset | None:
        """Fetch full metadata for a single Roboflow project.

        Returns:
            The normalised Dataset, or ``None`` when the API answers 404.

        Raises:
            httpx.HTTPStatusError: For any error status other than 404.
        """
        # Scoped by API key: the project may be private to that key's owner.
        ck = _cache_key(["project", api_key, workspace, project_id])
        cached = await _cache_get(ck)
        if cached is not None:
            return Dataset(**cached)

        async with _make_client(api_key) as client:
            try:
                resp = await client.get(f"/{workspace}/{project_id}")
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 404:
                    return None
                raise

        # Some responses nest the project under "project"; others are flat.
        proj_data = data.get("project", data)
        ds = RoboflowAdapter._normalise_project(proj_data, workspace)
        await _cache_set(ck, ds.model_dump())
        return ds

    # ── Download URL builder ──────────────────────────────────────────────────

    @staticmethod
    async def get_download_url(
        api_key: str,
        workspace: str,
        project_id: str,
        version: int,
        export_format: str = "yolov8",
    ) -> str:
        """
        Fetch the export download link from Roboflow for the specified format.

        The official Roboflow SDK is tried first. If it is not installed,
        raises, or returns no link, the REST export endpoint is queried
        instead and, as a last resort, a Universe ``/ds/<id>`` link is built
        from the project id in that response.

        Returns:
            The download URL. Links pointing at universe.roboflow.com always
            carry ``key=<api_key>``.

        Raises:
            httpx.HTTPStatusError: If the REST fallback gets an error status.
            ValueError: If no link could be resolved by any route.
        """
        import asyncio

        def _resolve_via_sdk() -> str:
            # Every SDK call does blocking network I/O, so the whole chain
            # (not only the final export) runs inside the worker thread.
            from roboflow import Roboflow

            rf = Roboflow(api_key=api_key)
            version_obj = rf.workspace(workspace).project(project_id).version(version)
            return version_obj.export(export_format).download_link

        missing = f"No download link returned for {workspace}/{project_id} v{version}"

        try:
            sdk_link = await asyncio.to_thread(_resolve_via_sdk)
        except Exception as e:
            # Deliberately broad: any SDK problem (missing package, auth,
            # API drift) falls through to the manual REST route below.
            log.error("roboflow_sdk_error", error=str(e))
        else:
            if sdk_link:
                return sdk_link
            log.error("roboflow_sdk_error", error=missing)

        # Fallback to manual API if SDK fails or isn't installed correctly
        async with _make_client(api_key) as client:
            resp = await client.get(
                f"/{workspace}/{project_id}/{version}/{export_format}"
            )
            resp.raise_for_status()
            data = resp.json()

        # The link is read from the "export" object of the response; this
        # lookup was previously made on an undefined name.
        export = data.get("export") or {}
        link = export.get("link") or ""

        if not link:
            # If 'link' is missing, check if it's a Universe-style project and
            # resolve manually; the key is appended by _with_api_key below.
            pid = (data.get("project") or {}).get("id")
            if pid:
                link = f"https://universe.roboflow.com/ds/{pid}"

        if not link:
            raise ValueError(missing)

        return RoboflowAdapter._with_api_key(link, api_key)

    @staticmethod
    def _with_api_key(link: str, api_key: str) -> str:
        """Ensure a Universe download link carries ``key=<api_key>``.

        Links on other hosts are returned unchanged. An existing ``key``
        query parameter is overwritten; otherwise one is appended.
        """
        import re
        from urllib.parse import quote

        if "universe.roboflow.com" not in link:
            return link

        key_param = f"key={quote(api_key, safe='')}"
        # Anchored on ?/& so parameters such as `apikey=` are left alone.
        pattern = r"(?<=[?&])key=[^&#]*"
        if re.search(pattern, link):
            # Callable replacement: the key is inserted literally, never
            # interpreted as a regex replacement template.
            return re.sub(pattern, lambda _m: key_param, link)
        separator = "&" if "?" in link else "?"
        return f"{link}{separator}{key_param}"

    # ── Normalisation helpers ─────────────────────────────────────────────────

    @staticmethod
    def _class_names(raw: Any) -> list[Any]:
        """Extract class names from a ``classes`` payload.

        Accepts an iterable whose entries are either dicts with a ``name``
        key or the names themselves (iterating a dict yields its keys).
        ``None`` yields an empty list.
        """
        return [
            entry.get("name", "") if isinstance(entry, dict) else entry
            for entry in raw or []
        ]

    @staticmethod
    def _normalise_search_result(item: dict[str, Any]) -> Dataset:
        """Map a Universe search result → Dataset."""
        ann_type = item.get("annotation", {}).get("type", "object-detection")
        rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
        class_names = RoboflowAdapter._class_names(item.get("classes"))
        images = item.get("images", 0) or 0
        return Dataset(
            # "/" is replaced so the id is a single path-safe token.
            id=item.get("id", "").replace("/", "__"),
            name=item.get("name", "Unnamed"),
            description=item.get("description", ""),
            task=rf_task,
            format=DatasetFormat.yolo,
            source=DatasetSource.roboflow,
            status=DatasetStatus.available,
            images=images,
            classes=len(class_names),
            class_names=class_names,
            # Size is not taken from the search payload; placeholder values.
            size_bytes=0,
            size_label="—",
            tags=item.get("tags", []),
            roboflow_id=item.get("id", ""),
            created_at=item.get("created", ""),
            updated_at=item.get("updated", ""),
        )

    @staticmethod
    def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset:
        """Map a workspace project → Dataset."""
        ann_type = proj.get("annotation", "object-detection")
        rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
        class_names = RoboflowAdapter._class_names(proj.get("classes"))
        project_id = proj.get("id", proj.get("name", "unknown"))
        rf_id = f"{workspace}/{project_id}"
        images = proj.get("images", 0) or 0
        return Dataset(
            id=rf_id.replace("/", "__"),
            name=proj.get("name", project_id),
            description=proj.get("description", ""),
            task=rf_task,
            format=DatasetFormat.yolo,
            source=DatasetSource.roboflow,
            status=DatasetStatus.available,
            images=images,
            classes=len(class_names),
            class_names=class_names,
            # Size is not taken from the project payload; placeholder values.
            size_bytes=0,
            size_label="—",
            roboflow_id=rf_id,
            created_at=proj.get("created", ""),
            updated_at=proj.get("updated", ""),
        )