# mlforge / datasets /registry.py
# senthil2421
# Refactor cloud_backend: remove local execution routes and fix missing modules
# e10cda2
"""
datasets/registry.py — Dataset Registry: persistent CRUD against the datasets table.
All DB interactions for datasets and dataset_jobs live here.
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime
from typing import Any
from database.connection import get_db
from models.dataset import Dataset, DatasetJob, DatasetStatus, row_to_dataset, row_to_job
from observability.logger import get_logger
log = get_logger("dataset_registry")
# ── Dataset CRUD ──────────────────────────────────────────────────────────────
async def get_all_datasets(
    task: str | None = None,
    format: str | None = None,
    source: str | None = None,
    status: str | None = None,
    search: str | None = None,
    starred: bool | None = None,
    limit: int = 500,
    offset: int = 0,
) -> list[Dataset]:
    """Return datasets matching the given filters, most recently updated first.

    Args:
        task, format, source, status: Exact-match filters; a falsy value
            (``None`` or ``""``) disables the corresponding filter.
        search: Substring matched with ``LIKE`` against name, description
            and tags. The text is not escaped, so ``%`` and ``_`` act as
            LIKE wildcards.
        starred: ``True``/``False`` filters on the starred flag; ``None``
            disables the filter.
        limit: Maximum number of rows returned.
        offset: Number of rows skipped before the first returned row.

    Returns:
        The matching rows converted with ``row_to_dataset``.
    """
    db = await get_db()
    conditions: list[str] = []
    bind: list[Any] = []

    # Plain equality filters, in the same order the columns are checked.
    equality_filters = (
        ("task", task),
        ("format", format),
        ("source", source),
        ("status", status),
    )
    for column, wanted in equality_filters:
        if wanted:
            conditions.append(f"{column} = ?")
            bind.append(wanted)

    # `starred=False` is a meaningful filter, so only None disables it.
    if starred is not None:
        conditions.append("starred = ?")
        bind.append(1 if starred else 0)

    if search:
        pattern = f"%{search}%"
        conditions.append("(name LIKE ? OR description LIKE ? OR tags LIKE ?)")
        bind += [pattern, pattern, pattern]

    where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
    query = f"SELECT * FROM datasets {where} ORDER BY updated_at DESC LIMIT ? OFFSET ?"
    bind += [limit, offset]

    async with db.execute(query, bind) as cursor:
        fetched = await cursor.fetchall()
    return [row_to_dataset(record) for record in fetched]
async def get_dataset_stats(dataset_id: str) -> dict:
    """Aggregate annotation and image counts for one dataset.

    Returns:
        A dict with two keys:
        ``"class_distribution"`` maps each annotation label to its row count
        in ``dataset_annotations``; ``"split_distribution"`` maps each split
        to its row count in ``dataset_images``.
    """
    db = await get_db()
    key = (dataset_id,)

    # Per-label annotation counts, largest first.
    label_sql = "SELECT label, COUNT(*) as count FROM dataset_annotations WHERE dataset_id=? GROUP BY label ORDER BY count DESC"
    async with db.execute(label_sql, key) as cursor:
        label_rows = await cursor.fetchall()

    # Per-split image counts.
    split_sql = "SELECT split, COUNT(*) as count FROM dataset_images WHERE dataset_id=? GROUP BY split"
    async with db.execute(split_sql, key) as cursor:
        split_rows = await cursor.fetchall()

    class_distribution = {}
    for record in label_rows:
        class_distribution[record["label"]] = record["count"]

    split_distribution = {}
    for record in split_rows:
        split_distribution[record["split"]] = record["count"]

    return {
        "class_distribution": class_distribution,
        "split_distribution": split_distribution,
    }
async def get_dataset(dataset_id: str) -> Dataset | None:
    """Fetch a single dataset by id, or ``None`` when no row matches."""
    db = await get_db()
    lookup = "SELECT * FROM datasets WHERE id = ?"
    async with db.execute(lookup, (dataset_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_dataset(record)
async def count_datasets() -> int:
    """Return the number of rows in the ``datasets`` table."""
    db = await get_db()
    async with db.execute("SELECT COUNT(*) FROM datasets") as cursor:
        record = await cursor.fetchone()
    if not record:
        return 0
    return record[0]
async def upsert_dataset(ds: Dataset) -> None:
    """Write a dataset row with ``INSERT OR REPLACE`` and commit.

    Enum-like fields (task, format, source, status) are stored via their
    ``.value`` when they have one, otherwise as-is. List fields are stored
    as JSON text. ``updated_at`` is set by SQLite's ``datetime('now')``;
    ``created_at`` falls back to the current UTC time when the dataset
    carries none.
    """

    def plain(field: Any) -> Any:
        # Accept both enum members and already-plain values.
        return getattr(field, "value", field)

    serialised_versions = json.dumps(
        [
            version.model_dump() if hasattr(version, "model_dump") else version
            for version in ds.versions
        ]
    )
    created_at = ds.created_at or datetime.utcnow().isoformat()

    values = (
        ds.id,
        ds.name,
        ds.description,
        plain(ds.task),
        plain(ds.format),
        plain(ds.source),
        plain(ds.status),
        ds.images,
        ds.classes,
        json.dumps(ds.class_names),
        ds.size_bytes,
        ds.size_label,
        ds.local_path,
        ds.import_progress,
        json.dumps(ds.tags),
        serialised_versions,
        ds.active_version,
        1 if ds.starred else 0,
        ds.roboflow_id,
        created_at,
    )

    db = await get_db()
    await db.execute(
        """INSERT OR REPLACE INTO datasets
        (id, name, description, task, format, source, status,
        images, classes, class_names, size_bytes, size_label,
        local_path, import_progress, tags, versions, active_version,
        starred, roboflow_id, created_at, updated_at)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,datetime('now'))""",
        values,
    )
    await db.commit()
async def update_dataset_status(
    dataset_id: str,
    status: DatasetStatus,
    progress: float | None = None,
    local_path: str | None = None,
) -> None:
    """Update a dataset's status and, optionally, import progress and path.

    Args:
        dataset_id: Id of the row to update.
        status: New status; an enum member (its ``.value`` is stored) or a
            plain string.
        progress: New ``import_progress``; left untouched when ``None``.
        local_path: New ``local_path``; left untouched when ``None``.

    Each optional column is written independently. Previously ``local_path``
    was only written when ``progress`` was also given and was silently
    dropped otherwise.
    """
    db = await get_db()

    assignments = ["status=?"]
    values: list[Any] = [getattr(status, "value", status)]

    if progress is not None:
        assignments.append("import_progress=?")
        values.append(progress)
    if local_path is not None:
        assignments.append("local_path=?")
        values.append(local_path)

    values.append(dataset_id)
    await db.execute(
        f"UPDATE datasets SET {', '.join(assignments)} WHERE id=?",
        values,
    )
    await db.commit()
async def update_dataset_stats(
    dataset_id: str,
    images: int,
    classes: int,
    class_names: list[str],
    size_bytes: int,
    stats: dict | None = None
) -> None:
    """Store image/class counts, size and optional statistics for a dataset.

    ``size_label`` is derived from ``size_bytes`` via ``_fmt_bytes``.
    ``stats`` is stored as JSON text, or NULL when it is ``None`` or empty.
    ``health_score`` is read from ``stats["health_score"]`` and defaults to
    0.0 when ``stats`` is missing or has no such key.
    """
    db = await get_db()

    score = stats.get("health_score", 0.0) if stats else 0.0
    stats_json = json.dumps(stats) if stats else None

    values = (
        images,
        classes,
        json.dumps(class_names),
        size_bytes,
        _fmt_bytes(size_bytes),
        stats_json,
        score,
        dataset_id,
    )
    await db.execute(
        """UPDATE datasets
        SET images=?, classes=?, class_names=?, size_bytes=?,
        size_label=?, stats=?, health_score=?
        WHERE id=?""",
        values,
    )
    await db.commit()
async def delete_dataset(dataset_id: str) -> bool:
    """Delete a dataset row.

    Returns:
        ``True`` when the row existed and was deleted, ``False`` when no row
        with that id was found (nothing is written in that case).
    """
    db = await get_db()
    async with db.execute("SELECT 1 FROM datasets WHERE id=?", (dataset_id,)) as cursor:
        found = await cursor.fetchone()

    if found:
        await db.execute("DELETE FROM datasets WHERE id=?", (dataset_id,))
        await db.commit()
        return True
    return False
async def toggle_starred(dataset_id: str) -> bool:
    """Flip the starred flag of a dataset and return the new value.

    Returns ``False`` without writing anything when the dataset does not
    exist, so a ``False`` result alone does not distinguish "now unstarred"
    from "not found".
    """
    db = await get_db()
    async with db.execute("SELECT starred FROM datasets WHERE id=?", (dataset_id,)) as cursor:
        current = await cursor.fetchone()
    if not current:
        return False

    flipped = not current["starred"]
    await db.execute(
        "UPDATE datasets SET starred=? WHERE id=?",
        (1 if flipped else 0, dataset_id),
    )
    await db.commit()
    return flipped
# ── Bulk dataset upsert from Roboflow ────────────────────────────────────────
async def bulk_upsert_datasets(datasets: list[Dataset]) -> int:
    """Insert many datasets with one ``executemany`` call and one commit.

    The statement is ``INSERT OR IGNORE``: rows whose insert would conflict
    with an existing row are skipped, not updated. New rows are always
    stored with ``starred = 0`` and an empty ``versions`` list, regardless
    of the values on the ``Dataset`` object.

    Enum-like fields (task, format, source, status) are stored via their
    ``.value`` when they have one and as-is otherwise, matching
    ``upsert_dataset``.

    Returns:
        ``len(datasets)`` — the number of datasets submitted, which may be
        larger than the number of rows actually inserted.
    """
    if not datasets:
        return 0

    def plain(field: Any) -> Any:
        # Accept both enum members and already-plain values.
        return getattr(field, "value", field)

    db = await get_db()
    # One shared fallback timestamp for every row lacking created_at.
    now = datetime.utcnow().isoformat()
    rows = [
        (
            ds.id, ds.name, ds.description, plain(ds.task), plain(ds.format),
            plain(ds.source), plain(ds.status),
            ds.images, ds.classes,
            json.dumps(ds.class_names), ds.size_bytes, ds.size_label,
            ds.local_path, ds.import_progress,
            json.dumps(ds.tags), json.dumps([]),
            ds.active_version, 0, ds.roboflow_id,
            ds.created_at or now,
        )
        for ds in datasets
    ]
    await db.executemany(
        """INSERT OR IGNORE INTO datasets
        (id, name, description, task, format, source, status,
        images, classes, class_names, size_bytes, size_label,
        local_path, import_progress, tags, versions, active_version,
        starred, roboflow_id, created_at)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
        rows,
    )
    await db.commit()
    return len(datasets)
# ── Dataset Jobs ──────────────────────────────────────────────────────────────
async def create_job(
    dataset_id: str,
    dataset_name: str,
    job_type: str,
) -> DatasetJob:
    """Insert a new job row in the ``queued`` state and return it.

    The job id is ``"djob-"`` followed by the first 12 hex characters of a
    random UUID; ``created_at`` is the current UTC time in ISO format.
    Progress starts at 0.0 with an empty message.
    """
    db = await get_db()
    new_id = f"djob-{uuid.uuid4().hex[:12]}"
    created = datetime.utcnow().isoformat()

    insert_sql = """INSERT INTO dataset_jobs
        (id, type, status, dataset_id, dataset_name, progress, message, created_at)
        VALUES (?, ?, 'queued', ?, ?, 0.0, '', ?)"""
    await db.execute(
        insert_sql,
        (new_id, job_type, dataset_id, dataset_name, created),
    )
    await db.commit()

    return DatasetJob(
        id=new_id,
        type=job_type,
        status="queued",
        dataset_id=dataset_id,
        dataset_name=dataset_name,
        created_at=created,
    )
async def update_job(
    job_id: str,
    status: str | None = None,
    progress: float | None = None,
    message: str | None = None,
    error: str | None = None,
    started_at: str | None = None,
    ended_at: str | None = None,
) -> None:
    """Update the given columns of a job row and commit.

    Only arguments that are not ``None`` are written. When every optional
    argument is ``None`` the function returns without touching the table.
    """
    db = await get_db()

    # Column name -> candidate value, in the order they appear in the SQL.
    candidates = (
        ("status", status),
        ("progress", progress),
        ("message", message),
        ("error", error),
        ("started_at", started_at),
        ("ended_at", ended_at),
    )
    changes = [(column, value) for column, value in candidates if value is not None]
    if not changes:
        return

    set_clause = ", ".join(f"{column}=?" for column, _ in changes)
    values: list[Any] = [value for _, value in changes]
    values.append(job_id)

    await db.execute(f"UPDATE dataset_jobs SET {set_clause} WHERE id=?", values)
    await db.commit()
async def get_job(job_id: str) -> DatasetJob | None:
    """Fetch a single job by id, or ``None`` when no row matches."""
    db = await get_db()
    lookup = "SELECT * FROM dataset_jobs WHERE id=?"
    async with db.execute(lookup, (job_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_job(record)
async def get_all_jobs(limit: int = 100) -> list[DatasetJob]:
    """Return up to ``limit`` jobs, newest ``created_at`` first."""
    db = await get_db()
    listing = "SELECT * FROM dataset_jobs ORDER BY created_at DESC LIMIT ?"
    async with db.execute(listing, (limit,)) as cursor:
        fetched = await cursor.fetchall()
    return [row_to_job(record) for record in fetched]
# ── Image Index ───────────────────────────────────────────────────────────────
async def index_images(
    dataset_id: str,
    records: list[dict],  # [{id, filename, rel_path, width, height, split, ann_count}]
) -> int:
    """Insert image index rows for a dataset and commit.

    Args:
        dataset_id: Dataset the images belong to; merged into every record
            as its ``dataset_id``. A ``dataset_id`` key already present in a
            record takes precedence, because the record is unpacked last.
        records: One dict per image supplying the named SQL parameters
            ``id, filename, rel_path, width, height, split, ann_count``.

    Returns:
        ``len(records)``. The statement is ``INSERT OR IGNORE``, so this is
        the number of records submitted, not necessarily rows inserted.
    """
    # Nothing to write: skip the round-trip and the commit, matching the
    # other bulk insert helpers in this module.
    if not records:
        return 0

    db = await get_db()
    await db.executemany(
        """INSERT OR IGNORE INTO dataset_images
        (id, dataset_id, filename, rel_path, width, height, split, ann_count)
        VALUES (:id, :dataset_id, :filename, :rel_path, :width, :height, :split, :ann_count)""",
        [{"dataset_id": dataset_id, **r} for r in records],
    )
    await db.commit()
    return len(records)
async def get_image_page(
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> tuple[int, list[dict]]:
    """Return one page of indexed images for a dataset.

    Args:
        dataset_id: Dataset whose images are listed.
        page: Zero-based page index; the row offset is ``page * page_size``.
        page_size: Maximum number of rows in the page.
        split: When truthy, only images of this split are returned.
        class_label: When truthy, only images that have at least one
            annotation with this label are returned.

    Returns:
        ``(total, rows)`` where ``total`` is the number of images matching
        the filters across all pages and ``rows`` is the requested page as
        plain dicts.
    """
    db = await get_db()

    conditions = ["dataset_id=?"]
    bind: list[Any] = [dataset_id]
    if split:
        conditions.append("split=?")
        bind.append(split)
    if class_label:
        # Restrict to images referenced by an annotation carrying the label.
        conditions.append(
            "id IN (SELECT image_id FROM dataset_annotations WHERE label=?)"
        )
        bind.append(class_label)
    where = f"WHERE {' AND '.join(conditions)}"

    async with db.execute(f"SELECT COUNT(*) FROM dataset_images {where}", bind) as cur:
        total = (await cur.fetchone())[0]

    # `id` breaks ties between equal filenames (e.g. the same name in two
    # splits); without a total order, LIMIT/OFFSET pages may repeat or skip
    # rows.
    async with db.execute(
        f"SELECT * FROM dataset_images {where} ORDER BY filename, id LIMIT ? OFFSET ?",
        [*bind, page_size, page * page_size],
    ) as cur:
        rows = await cur.fetchall()
    return total, [dict(r) for r in rows]
async def get_annotations_for_image(image_id: str) -> list[dict]:
    """Return every annotation row of one image as a plain dict."""
    db = await get_db()
    lookup = "SELECT * FROM dataset_annotations WHERE image_id=?"
    async with db.execute(lookup, (image_id,)) as cursor:
        fetched = await cursor.fetchall()
    return [dict(record) for record in fetched]
async def bulk_insert_annotations(records: list[dict]) -> int:
    """Insert annotation rows with ``INSERT OR IGNORE`` and commit.

    Each record supplies the named SQL parameters ``id, image_id,
    dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h, normalised, area,
    confidence, ann_type``.

    Returns:
        ``len(records)`` (0 for an empty list, in which case the database
        is not touched).
    """
    if not records:
        return 0

    insert_sql = """INSERT OR IGNORE INTO dataset_annotations
        (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
        normalised, area, confidence, ann_type)
        VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
        :normalised,:area,:confidence,:ann_type)"""

    db = await get_db()
    await db.executemany(insert_sql, records)
    await db.commit()
    return len(records)
# ── Universal Dataset Items ──────────────────────────────────────────────
async def get_universal_items(
    self,
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> tuple[int, list[dict]]:
    """Fetch one page of dataset items.

    Currently a thin bridge over ``get_image_page``: it returns that
    function's ``(total, rows)`` result unchanged. No conversion to a
    universal format happens yet and annotations are not attached.

    Args:
        self: Unused placeholder, kept only so existing call sites keep
            working. NOTE(review): this is a module-level function, so the
            parameter looks like a leftover from a class-based version —
            confirm with callers before removing it.
        dataset_id: Dataset whose items are listed.
        page: Zero-based page index.
        page_size: Maximum number of items in the page.
        split: Optional split filter.
        class_label: Optional annotation-label filter.

    Raises:
        TypeError: If ``dataset_id`` is not a string. This catches callers
            that omitted the placeholder and shifted every positional
            argument by one, which would otherwise query the wrong dataset
            silently.
    """
    if not isinstance(dataset_id, str):
        raise TypeError(
            "dataset_id must be a str; the first positional argument of "
            "get_universal_items is an unused placeholder"
        )

    # Call the module-level function directly: `self.get_image_page` only
    # resolved when the caller happened to pass an object exposing it.
    total, items = await get_image_page(dataset_id, page, page_size, split, class_label)

    # This is a bridge until we fully move to the universal schema.
    return total, items
async def bulk_insert_universal_annotations(self, records: list[dict]) -> int:
    """Insert annotation rows including the extended columns and commit.

    On top of the basic annotation fields, each record must supply
    ``segmentation``, ``keypoints`` and ``metadata`` as named SQL
    parameters. Rows are written with ``INSERT OR IGNORE``.

    ``self`` is accepted but never used by the body.

    Returns:
        ``len(records)`` (0 for an empty list, in which case the database
        is not touched).
    """
    if not records:
        return 0

    insert_sql = """INSERT OR IGNORE INTO dataset_annotations
        (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
        normalised, area, confidence, ann_type, segmentation, keypoints, metadata)
        VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
        :normalised,:area,:confidence,:ann_type,:segmentation,:keypoints,:metadata)"""

    db = await get_db()
    await db.executemany(insert_sql, records)
    await db.commit()
    return len(records)
async def update_dataset_task(dataset_id: str, task: str) -> None:
    """Set the ``task`` column of one dataset and commit."""
    db = await get_db()
    statement = "UPDATE datasets SET task=? WHERE id=?"
    await db.execute(statement, (task, dataset_id))
    await db.commit()
async def cleanup_stale_jobs() -> None:
    """Fail every job still marked running or queued.

    Intended to be called on startup: such jobs are set to ``failed`` with
    the error text ``System restart``.
    """
    statement = "UPDATE dataset_jobs SET status='failed', error='System restart' WHERE status IN ('running', 'queued')"
    db = await get_db()
    await db.execute(statement)
    await db.commit()
def _fmt_bytes(n: int) -> str:
for unit in ("B", "KB", "MB", "GB", "TB"):
if n < 1024:
return f"{n:.1f} {unit}"
n /= 1024
return f"{n:.1f} PB"