mlforge / adapters /roboflow_adapter.py
senthil2421
fix: remove project dependencies and redundant imports to fix server startup
8302f42
"""
adapters/roboflow_adapter.py — Roboflow Universe API client.
Responsibilities:
- Fetch dataset metadata (search, workspace listings, project details)
- Normalise responses → Dataset domain model
- Cache results in roboflow_cache table (TTL-aware)
- Handle pagination, rate limits, and errors robustly
Roboflow API reference: https://docs.roboflow.com/api-reference/
"""
from __future__ import annotations
import hashlib
import json
import time
from typing import Any
import httpx
from tenacity import retry, stop_after_attempt, wait_exponential
from database.connection import get_db
from models.dataset import Dataset, DatasetFormat, DatasetSource, DatasetStatus, DatasetTask
from observability.logger import audit, get_logger
# Logger for this adapter, obtained from the project's observability layer.
log = get_logger("roboflow_adapter")
# Base URL passed to every httpx client built by _make_client().
_ROBOFLOW_BASE = "https://api.roboflow.com"
# Roboflow Universe site root (same host as the download links handled in
# get_download_url).
_UNIVERSE_BASE = "https://universe.roboflow.com"
# Default cache lifetime in seconds; used by _cache_set() when no ttl is given.
_DEFAULT_TTL = 3600 # 1 hour
# ── Task mapping from Roboflow annotation_type ───────────────────────────────
# Translates the annotation-type string found in Roboflow payloads into a
# DatasetTask. The normalisation helpers fall back to DatasetTask.detection
# for any string that is not listed here.
_TASK_MAP: dict[str, DatasetTask] = {
    "object-detection": DatasetTask.detection,
    "instance-segmentation": DatasetTask.segmentation,
    "semantic-segmentation": DatasetTask.segmentation,
    "classification": DatasetTask.classification,
    "keypoint-detection": DatasetTask.keypoints,
    "multiclass-classification": DatasetTask.classification,
}
# Translates Roboflow export-format names into a DatasetFormat.
# NOTE(review): nothing in the visible part of this module reads this table —
# the normalisation helpers hard-code DatasetFormat.yolo. Confirm whether it is
# used elsewhere before relying on it.
_FORMAT_MAP: dict[str, DatasetFormat] = {
    "yolov5": DatasetFormat.yolo,
    "yolov7": DatasetFormat.yolo,
    "yolov8": DatasetFormat.yolo,
    "yolov9": DatasetFormat.yolo,
    "coco": DatasetFormat.coco,
    "voc": DatasetFormat.voc,
    "tfrecord": DatasetFormat.tfrecord,
    "csv": DatasetFormat.csv,
    "createml": DatasetFormat.json,
    "multiclass": DatasetFormat.csv,
}
def _cache_key(parts: list[str]) -> str:
raw = "|".join(parts)
return hashlib.sha256(raw.encode()).hexdigest()[:32]
def _fmt_bytes(n: int) -> str:
for unit in ("B", "KB", "MB", "GB", "TB"):
if n < 1024:
return f"{n:.1f} {unit}"
n /= 1024
return f"{n:.1f} PB"
# ── Cache helpers ─────────────────────────────────────────────────────────────
async def _cache_get(key: str) -> dict[str, Any] | list[Any] | None:
    """Return the cached payload stored under ``key``, or ``None`` on a miss.

    A lookup is a miss when no row exists for ``key``, when the row is older
    than its own ``ttl_secs``, or when its ``fetched_at`` value cannot be
    parsed.

    Args:
        key: Cache key, as produced by ``_cache_key``.

    Returns:
        The JSON-decoded payload (callers in this module store both dicts and
        lists of dicts), or ``None`` on a miss.
    """
    import calendar  # local import so the module's import block stays untouched

    db = await get_db()
    async with db.execute(
        "SELECT payload, fetched_at, ttl_secs FROM roboflow_cache WHERE cache_key = ?",
        (key,),
    ) as cur:
        row = await cur.fetchone()
    if row is None:
        return None
    try:
        fetched_struct = time.strptime(row["fetched_at"], "%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError):
        # A missing or malformed timestamp is treated as a stale entry rather
        # than turning a cache read into a hard failure.
        return None
    # _cache_set never writes fetched_at, so the column is filled by a database
    # default. SQLite's CURRENT_TIMESTAMP / datetime('now') produce UTC, so the
    # value is converted with timegm (UTC) rather than mktime (local time),
    # which would skew every TTL by the host's UTC offset.
    # NOTE(review): confirm the roboflow_cache.fetched_at default is UTC.
    fetched = calendar.timegm(fetched_struct)
    if time.time() - fetched > row["ttl_secs"]:
        return None  # expired
    return json.loads(row["payload"])
async def _cache_set(key: str, payload: dict[str, Any], ttl: int = _DEFAULT_TTL) -> None:
    """Store ``payload`` under ``key`` in the ``roboflow_cache`` table.

    The payload is serialised with ``json.dumps``; an existing row with the
    same key is replaced. Only ``cache_key``, ``payload`` and ``ttl_secs`` are
    written here, and the change is committed before returning.

    Args:
        key: Cache key, as produced by ``_cache_key``.
        payload: JSON-serialisable value to store.
        ttl: Lifetime of the entry in seconds.
    """
    conn = await get_db()
    statement = (
        "INSERT OR REPLACE INTO roboflow_cache (cache_key, payload, ttl_secs)\n"
        "VALUES (?, ?, ?)"
    )
    serialised = json.dumps(payload)
    await conn.execute(statement, (key, serialised, ttl))
    await conn.commit()
# ── HTTP client factory ───────────────────────────────────────────────────────
def _make_client(api_key: str) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient`` pre-configured for the Roboflow API.

    The client uses ``_ROBOFLOW_BASE`` as its base URL, sends ``api_key`` as a
    default query parameter, applies a 30-second timeout and identifies itself
    with the ``MLForge/1.0`` user agent. The caller is responsible for closing
    it (the methods in this module use it as an async context manager).
    """
    client_options: dict[str, Any] = {
        "base_url": _ROBOFLOW_BASE,
        "params": {"api_key": api_key},
        "timeout": 30.0,
        "headers": {"User-Agent": "MLForge/1.0"},
    }
    return httpx.AsyncClient(**client_options)
# ── Roboflow Adapter ──────────────────────────────────────────────────────────
class RoboflowAdapter:
    """
    Stateless adapter for the Roboflow API.

    All methods accept api_key explicitly to support per-user keys.

    NOTE(review): cache keys are built from the query/workspace/project only,
    not from the api_key, so a cached entry is served to every caller no matter
    which key they pass. Confirm this is acceptable for private workspaces.

    NOTE(review): the ``retry`` decorators below retry on every exception
    (HTTP 4xx included) and do not pass ``reraise=True``, so after the final
    attempt tenacity is expected to surface its own ``RetryError`` rather than
    the original exception. Confirm what callers catch before changing this.
    """

    # ── Search (Universe) ─────────────────────────────────────────────────────
    @staticmethod
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    async def search_datasets(
        api_key: str,
        query: str = "",
        workspace: str | None = None,
        page: int = 0,
        page_size: int = 50,
    ) -> list[Dataset]:
        """
        Search Roboflow Universe for datasets.

        Args:
            api_key: Roboflow API key used for the request.
            query: Free-text query; an empty string is sent as ``*``.
            workspace: Optional workspace filter, forwarded as a query parameter.
            page: Zero-based page index; sent as the offset ``page * page_size``.
            page_size: Number of results requested per page.

        Returns:
            Normalised Dataset objects. Items that fail normalisation are
            logged and skipped.

        Raises:
            httpx.HTTPStatusError: re-raised from the request after being
                logged and audited (subject to the retry decorator).
        """
        ck = _cache_key(["search", query, str(workspace), str(page), str(page_size)])
        cached = await _cache_get(ck)
        # "is not None" so that a cached *empty* result list is still a hit.
        if cached is not None:
            log.debug("roboflow_cache_hit", key=ck, query=query)
            return [Dataset(**d) for d in cached]

        params: dict[str, Any] = {
            "api_key": api_key,
            "q": query or "*",
            "from": page * page_size,
            "size": page_size,
        }
        if workspace:
            params["workspace"] = workspace

        async with _make_client(api_key) as client:
            try:
                resp = await client.get("/", params=params)
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                log.error("roboflow_api_error", status=e.response.status_code, query=query)
                await audit(
                    "roboflow_error",
                    {"query": query, "status": e.response.status_code},
                    level="error",
                )
                raise

        datasets: list[Dataset] = []
        for item in data.get("results", []):
            try:
                datasets.append(RoboflowAdapter._normalise_search_result(item))
            except Exception as exc:
                # Best effort: one malformed item must not sink the whole page.
                log.warning("normalise_error", item_id=item.get("id"), error=str(exc))

        await _cache_set(ck, [d.model_dump() for d in datasets])
        await audit("roboflow_search", {"query": query, "count": len(datasets)})
        return datasets

    # ── Workspace datasets listing ────────────────────────────────────────────
    @staticmethod
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    async def list_workspace_datasets(
        api_key: str,
        workspace: str,
    ) -> list[Dataset]:
        """
        List all datasets in a Roboflow workspace.

        Args:
            api_key: Roboflow API key used for the request.
            workspace: Workspace identifier; requested as ``/{workspace}``.

        Returns:
            One Dataset per entry of ``workspace.projects`` in the response.
            Projects that fail normalisation are logged and skipped.

        Raises:
            httpx.HTTPStatusError: re-raised from the request after being
                logged (subject to the retry decorator).
        """
        ck = _cache_key(["workspace", workspace])
        cached = await _cache_get(ck)
        # "is not None" so that a cached *empty* workspace is still a hit.
        if cached is not None:
            return [Dataset(**d) for d in cached]

        async with _make_client(api_key) as client:
            try:
                resp = await client.get(f"/{workspace}")
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                log.error(
                    "roboflow_workspace_error",
                    workspace=workspace,
                    status=e.response.status_code,
                )
                raise

        datasets: list[Dataset] = []
        for proj in data.get("workspace", {}).get("projects", []):
            try:
                datasets.append(RoboflowAdapter._normalise_project(proj, workspace))
            except Exception as exc:
                # Best effort: skip the project that could not be normalised.
                log.warning("normalise_project_error", project=proj.get("id"), error=str(exc))

        await _cache_set(ck, [d.model_dump() for d in datasets])
        return datasets

    # ── Single project detail ─────────────────────────────────────────────────
    @staticmethod
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    async def get_project(
        api_key: str,
        workspace: str,
        project_id: str,
    ) -> Dataset | None:
        """
        Fetch full metadata for a single Roboflow project.

        Args:
            api_key: Roboflow API key used for the request.
            workspace: Workspace identifier.
            project_id: Project identifier; requested as
                ``/{workspace}/{project_id}``.

        Returns:
            The normalised Dataset, or ``None`` when the API answers 404.

        Raises:
            httpx.HTTPStatusError: for any non-404 error status (subject to
                the retry decorator).
        """
        ck = _cache_key(["project", workspace, project_id])
        cached = await _cache_get(ck)
        if cached is not None:
            return Dataset(**cached)

        async with _make_client(api_key) as client:
            try:
                resp = await client.get(f"/{workspace}/{project_id}")
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 404:
                    return None
                raise

        # Use the nested "project" object when present, else the whole payload.
        proj_data = data.get("project", data)
        ds = RoboflowAdapter._normalise_project(proj_data, workspace)
        await _cache_set(ck, ds.model_dump())
        return ds

    # ── Download URL builder ──────────────────────────────────────────────────
    @staticmethod
    async def get_download_url(
        api_key: str,
        workspace: str,
        project_id: str,
        version: int,
        export_format: str = "yolov8",
    ) -> str:
        """
        Fetch the export download link from Roboflow for the specified format.

        The official Roboflow SDK is tried first; if it is not installed or
        fails for any reason, the error is logged and the REST endpoint
        ``/{workspace}/{project_id}/{version}/{export_format}`` is used instead.

        Args:
            api_key: Roboflow API key.
            workspace: Workspace identifier.
            project_id: Project identifier.
            version: Dataset version number.
            export_format: Roboflow export format name.

        Returns:
            The download URL. Links on universe.roboflow.com are adjusted so
            that their ``key`` parameter carries ``api_key``.

        Raises:
            ValueError: if no link can be determined from the REST response.
            httpx.HTTPStatusError: if the REST fallback returns an error status.
        """
        import asyncio
        import re

        try:
            from roboflow import Roboflow

            def _get_link() -> str:
                # Every SDK call is blocking, so the whole chain runs in the
                # worker thread instead of stalling the event loop.
                # NOTE(review): relies on export() returning an object with a
                # download_link attribute — confirm against the installed SDK.
                rf = Roboflow(api_key=api_key)
                version_obj = rf.workspace(workspace).project(project_id).version(version)
                return version_obj.export(export_format).download_link

            link = await asyncio.to_thread(_get_link)
            if not link:
                raise ValueError(
                    f"No download link returned for {workspace}/{project_id} v{version}"
                )
            return link
        except Exception as e:
            # Deliberate best effort: any SDK problem falls through to the
            # manual API call below.
            log.error("roboflow_sdk_error", error=str(e))

        # Fallback to manual API if SDK fails or isn't installed correctly
        async with _make_client(api_key) as client:
            resp = await client.get(f"/{workspace}/{project_id}/{version}/{export_format}")
            resp.raise_for_status()
            data = resp.json()

        # The link lives under the "export" object of the response. The original
        # code read an undefined name here, so this path always raised NameError.
        export = data.get("export") or {}
        link = export.get("link") or ""

        if not link:
            # If 'link' is missing, check if it's a Universe-style project and
            # try to resolve manually. Roboflow manual resolution often follows:
            # universe.roboflow.com/ds/[id]?key=[api_key]
            project = data.get("project")
            pid = project.get("id") if isinstance(project, dict) else None
            if pid:
                link = f"{_UNIVERSE_BASE}/ds/{pid}?key={api_key}"

        if not link:
            raise ValueError(f"No download link returned for {workspace}/{project_id} v{version}")

        # Ensure the link includes the API key correctly
        if "universe.roboflow.com" in link:
            if "key=" not in link:
                separator = "&" if "?" in link else "?"
                link = f"{link}{separator}key={api_key}"
            elif f"key={api_key}" not in link:
                # Replace old key if it exists but is wrong. A callable
                # replacement keeps backslashes in the key from being
                # interpreted as regex escapes.
                link = re.sub(r"key=[^&]+", lambda _m: f"key={api_key}", link)
        return link

    # ── Normalisation helpers ─────────────────────────────────────────────────
    @staticmethod
    def _normalise_search_result(item: dict[str, Any]) -> Dataset:
        """
        Map a Universe search result → Dataset.

        Unknown annotation types fall back to ``DatasetTask.detection``; the
        format is always reported as YOLO and the size as unknown. The Dataset
        id is the Roboflow id with ``/`` replaced by ``__``.
        """
        ann_type = item.get("annotation", {}).get("type", "object-detection")
        rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
        # Accept both {"name": ...} objects and plain strings, matching
        # _normalise_project; "or []" tolerates an explicit null.
        class_names = [
            c.get("name", "") if isinstance(c, dict) else c
            for c in item.get("classes") or []
        ]
        images = item.get("images", 0) or 0
        return Dataset(
            id=item.get("id", "").replace("/", "__"),
            name=item.get("name", "Unnamed"),
            description=item.get("description", ""),
            task=rf_task,
            format=DatasetFormat.yolo,
            source=DatasetSource.roboflow,
            status=DatasetStatus.available,
            images=images,
            classes=len(class_names),
            class_names=class_names,
            size_bytes=0,
            size_label="—",  # size is not known at this point
            tags=item.get("tags", []),
            roboflow_id=item.get("id", ""),
            created_at=item.get("created", ""),
            updated_at=item.get("updated", ""),
        )

    @staticmethod
    def _normalise_project(proj: dict[str, Any], workspace: str) -> Dataset:
        """
        Map a workspace project → Dataset.

        The Roboflow id is ``{workspace}/{project_id}``, where the project id
        is the payload's ``id`` (falling back to ``name``, then ``"unknown"``).
        The Dataset id is that value with ``/`` replaced by ``__``. Unknown
        annotation types fall back to ``DatasetTask.detection``.
        """
        ann_type = proj.get("annotation", "object-detection")
        rf_task = _TASK_MAP.get(ann_type, DatasetTask.detection)
        # Entries may be {"name": ...} objects or plain names; "or []"
        # tolerates an explicit null.
        class_names = [
            c.get("name", c) if isinstance(c, dict) else c
            for c in proj.get("classes") or []
        ]
        project_id = proj.get("id", proj.get("name", "unknown"))
        rf_id = f"{workspace}/{project_id}"
        images = proj.get("images", 0) or 0
        return Dataset(
            id=rf_id.replace("/", "__"),
            name=proj.get("name", project_id),
            description=proj.get("description", ""),
            task=rf_task,
            format=DatasetFormat.yolo,
            source=DatasetSource.roboflow,
            status=DatasetStatus.available,
            images=images,
            classes=len(class_names),
            class_names=class_names,
            size_bytes=0,
            size_label="—",  # size is not known at this point
            roboflow_id=rf_id,
            created_at=proj.get("created", ""),
            updated_at=proj.get("updated", ""),
        )