Spaces:
Sleeping
Sleeping
| """ | |
| datasets/registry.py — Dataset Registry: persistent CRUD against the datasets table. | |
| All DB interactions for datasets and dataset_jobs live here. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import uuid | |
| from datetime import datetime | |
| from typing import Any | |
| from database.connection import get_db | |
| from models.dataset import Dataset, DatasetJob, DatasetStatus, row_to_dataset, row_to_job | |
| from observability.logger import get_logger | |
| log = get_logger("dataset_registry") | |
| # ── Dataset CRUD ────────────────────────────────────────────────────────────── | |
async def get_all_datasets(
    task: str | None = None,
    format: str | None = None,
    source: str | None = None,
    status: str | None = None,
    search: str | None = None,
    starred: bool | None = None,
    limit: int = 500,
    offset: int = 0,
) -> list[Dataset]:
    """List datasets, optionally filtered, newest ``updated_at`` first.

    Args:
        task, format, source, status: Exact-match filters; a falsy value
            (``None`` or ``""``) disables the corresponding filter.
        search: Substring matched with ``LIKE`` against the ``name``,
            ``description`` and ``tags`` columns.
        starred: ``True``/``False`` filters on the flag; ``None`` disables it.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip.

    Returns:
        The matching rows converted with ``row_to_dataset``.
    """
    conn = await get_db()
    conditions: list[str] = []
    bindings: list[Any] = []

    # Exact-match filters share one shape, so drive them from a table.
    exact_filters = (
        ("task", task),
        ("format", format),
        ("source", source),
        ("status", status),
    )
    for column, wanted in exact_filters:
        if wanted:
            conditions.append(f"{column} = ?")
            bindings.append(wanted)

    if starred is not None:
        conditions.append("starred = ?")
        bindings.append(int(bool(starred)))

    if search:
        pattern = f"%{search}%"
        conditions.append("(name LIKE ? OR description LIKE ? OR tags LIKE ?)")
        bindings += [pattern, pattern, pattern]

    where_sql = "WHERE " + " AND ".join(conditions) if conditions else ""
    query = f"SELECT * FROM datasets {where_sql} ORDER BY updated_at DESC LIMIT ? OFFSET ?"
    bindings += [limit, offset]

    async with conn.execute(query, bindings) as cursor:
        fetched = await cursor.fetchall()
    return [row_to_dataset(record) for record in fetched]
async def get_dataset_stats(dataset_id: str) -> dict:
    """Return annotation counts per label and image counts per split.

    Both distributions are counted with ``GROUP BY`` queries over the
    ``dataset_annotations`` and ``dataset_images`` tables for ``dataset_id``.

    Returns:
        ``{"class_distribution": {label: count}, "split_distribution": {split: count}}``.
    """
    conn = await get_db()
    key = (dataset_id,)

    # Labels come back most frequent first.
    label_sql = (
        "SELECT label, COUNT(*) as count FROM dataset_annotations "
        "WHERE dataset_id=? GROUP BY label ORDER BY count DESC"
    )
    async with conn.execute(label_sql, key) as cursor:
        label_rows = await cursor.fetchall()

    split_sql = (
        "SELECT split, COUNT(*) as count FROM dataset_images "
        "WHERE dataset_id=? GROUP BY split"
    )
    async with conn.execute(split_sql, key) as cursor:
        split_rows = await cursor.fetchall()

    class_distribution = {}
    for record in label_rows:
        class_distribution[record["label"]] = record["count"]

    split_distribution = {}
    for record in split_rows:
        split_distribution[record["split"]] = record["count"]

    return {
        "class_distribution": class_distribution,
        "split_distribution": split_distribution,
    }
async def get_dataset(dataset_id: str) -> Dataset | None:
    """Fetch one dataset by id, or ``None`` when no row matches."""
    conn = await get_db()
    async with conn.execute(
        "SELECT * FROM datasets WHERE id = ?", (dataset_id,)
    ) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_dataset(record)
async def count_datasets() -> int:
    """Return the number of rows in the ``datasets`` table (0 if no row comes back)."""
    conn = await get_db()
    async with conn.execute("SELECT COUNT(*) FROM datasets") as cursor:
        record = await cursor.fetchone()
    if not record:
        return 0
    return record[0]
async def upsert_dataset(ds: Dataset) -> None:
    """Write ``ds`` to the ``datasets`` table with ``INSERT OR REPLACE`` and commit.

    Enum-like fields (``task``, ``format``, ``source``, ``status``) are stored
    via their ``.value`` when they have one, otherwise as-is. List fields are
    stored as JSON text. ``updated_at`` is set by SQLite to ``datetime('now')``.

    NOTE(review): ``INSERT OR REPLACE`` rewrites the whole row, so columns not
    listed here (e.g. ``stats``/``health_score`` written by
    ``update_dataset_stats``) are presumably reset to their defaults — confirm
    against the schema.
    """

    def plain(field: Any) -> Any:
        # Accept both enum members and already-plain values.
        return getattr(field, "value", field)

    conn = await get_db()

    serialised_versions = []
    for version in ds.versions:
        if hasattr(version, "model_dump"):
            serialised_versions.append(version.model_dump())
        else:
            serialised_versions.append(version)

    created_at = ds.created_at or datetime.utcnow().isoformat()
    values = (
        ds.id,
        ds.name,
        ds.description,
        plain(ds.task),
        plain(ds.format),
        plain(ds.source),
        plain(ds.status),
        ds.images,
        ds.classes,
        json.dumps(ds.class_names),
        ds.size_bytes,
        ds.size_label,
        ds.local_path,
        ds.import_progress,
        json.dumps(ds.tags),
        json.dumps(serialised_versions),
        ds.active_version,
        1 if ds.starred else 0,
        ds.roboflow_id,
        created_at,
    )
    await conn.execute(
        """INSERT OR REPLACE INTO datasets
        (id, name, description, task, format, source, status,
        images, classes, class_names, size_bytes, size_label,
        local_path, import_progress, tags, versions, active_version,
        starred, roboflow_id, created_at, updated_at)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,datetime('now'))""",
        values,
    )
    await conn.commit()
async def update_dataset_status(
    dataset_id: str,
    status: DatasetStatus,
    progress: float | None = None,
    local_path: str | None = None,
) -> None:
    """Update a dataset's status and, when given, its progress and local path.

    Args:
        dataset_id: Id of the row in ``datasets`` to update.
        status: New status; its ``.value`` is stored.
        progress: New ``import_progress``; left untouched when ``None``.
        local_path: New ``local_path``; left untouched when ``None``.

    Each optional column is handled independently. Previously a
    ``local_path`` passed without ``progress`` was silently discarded,
    because only the status-only UPDATE ran in that case.
    """
    db = await get_db()

    assignments = ["status=?"]
    params: list[Any] = [status.value]
    if progress is not None:
        assignments.append("import_progress=?")
        params.append(progress)
    if local_path is not None:
        assignments.append("local_path=?")
        params.append(local_path)
    params.append(dataset_id)

    # Column names are fixed literals above; only values are bound.
    await db.execute(
        f"UPDATE datasets SET {', '.join(assignments)} WHERE id=?",
        params,
    )
    await db.commit()
async def update_dataset_stats(
    dataset_id: str,
    images: int,
    classes: int,
    class_names: list[str],
    size_bytes: int,
    stats: dict | None = None
) -> None:
    """Store image/class counts, size and optional stats for a dataset, then commit.

    ``size_label`` is derived from ``size_bytes`` with ``_fmt_bytes``.
    ``health_score`` is read from ``stats["health_score"]`` and defaults to
    ``0.0``. A falsy ``stats`` (``None`` or ``{}``) is stored as NULL.
    """
    conn = await get_db()

    score = stats.get("health_score", 0.0) if stats else 0.0
    stats_json = json.dumps(stats) if stats else None

    values = (
        images,
        classes,
        json.dumps(class_names),
        size_bytes,
        _fmt_bytes(size_bytes),
        stats_json,
        score,
        dataset_id,
    )
    await conn.execute(
        """UPDATE datasets
        SET images=?, classes=?, class_names=?, size_bytes=?,
        size_label=?, stats=?, health_score=?
        WHERE id=?""",
        values,
    )
    await conn.commit()
async def delete_dataset(dataset_id: str) -> bool:
    """Delete the dataset row with this id.

    Returns:
        ``True`` if a row existed and was deleted, ``False`` if none matched.
    """
    conn = await get_db()
    async with conn.execute(
        "SELECT 1 FROM datasets WHERE id=?", (dataset_id,)
    ) as cursor:
        found = await cursor.fetchone()

    if found:
        await conn.execute("DELETE FROM datasets WHERE id=?", (dataset_id,))
        await conn.commit()
        return True
    return False
async def toggle_starred(dataset_id: str) -> bool:
    """Flip the ``starred`` flag of a dataset and return the new value.

    Returns ``False`` when no dataset has this id, which is
    indistinguishable from "now un-starred".
    """
    conn = await get_db()
    async with conn.execute(
        "SELECT starred FROM datasets WHERE id=?", (dataset_id,)
    ) as cursor:
        current = await cursor.fetchone()
    if not current:
        return False

    flipped = not current["starred"]
    await conn.execute(
        "UPDATE datasets SET starred=? WHERE id=?",
        (int(flipped), dataset_id),
    )
    await conn.commit()
    return flipped
| # ── Bulk dataset upsert from Roboflow ──────────────────────────────────────── | |
async def bulk_upsert_datasets(datasets: list[Dataset]) -> int:
    """Insert many datasets with one ``executemany`` call and a single commit.

    The statement is ``INSERT OR IGNORE``: a row that violates a constraint
    (such as an id that already exists) is skipped, NOT updated. ``versions``
    is always written as an empty JSON list and ``starred`` as 0, regardless
    of the values on each dataset.

    Args:
        datasets: Datasets to insert; an empty list is a no-op.

    Returns:
        The number of datasets submitted (``len(datasets)``), which may be
        larger than the number of rows actually inserted.
    """
    if not datasets:
        return 0
    db = await get_db()
    now = datetime.utcnow().isoformat()
    rows = [
        (
            ds.id, ds.name, ds.description,
            # Accept enum members and plain values alike, the same way
            # upsert_dataset does; bare `.value` raised AttributeError on
            # fields that were already plain strings.
            getattr(ds.task, "value", ds.task),
            getattr(ds.format, "value", ds.format),
            getattr(ds.source, "value", ds.source),
            getattr(ds.status, "value", ds.status),
            ds.images, ds.classes,
            json.dumps(ds.class_names), ds.size_bytes, ds.size_label,
            ds.local_path, ds.import_progress,
            json.dumps(ds.tags), json.dumps([]),
            ds.active_version, 0, ds.roboflow_id,
            ds.created_at or now,
        )
        for ds in datasets
    ]
    await db.executemany(
        """INSERT OR IGNORE INTO datasets
        (id, name, description, task, format, source, status,
        images, classes, class_names, size_bytes, size_label,
        local_path, import_progress, tags, versions, active_version,
        starred, roboflow_id, created_at)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
        rows,
    )
    await db.commit()
    return len(datasets)
| # ── Dataset Jobs ────────────────────────────────────────────────────────────── | |
async def create_job(
    dataset_id: str,
    dataset_name: str,
    job_type: str,
) -> DatasetJob:
    """Insert a new ``queued`` job row for a dataset and return it.

    The id is ``djob-`` followed by 12 hex characters of a random UUID.
    ``created_at`` is the current naive UTC time in ISO format.
    """
    conn = await get_db()
    new_id = "djob-" + uuid.uuid4().hex[:12]
    created = datetime.utcnow().isoformat()

    insert_sql = """INSERT INTO dataset_jobs
        (id, type, status, dataset_id, dataset_name, progress, message, created_at)
        VALUES (?, ?, 'queued', ?, ?, 0.0, '', ?)"""
    await conn.execute(
        insert_sql,
        (new_id, job_type, dataset_id, dataset_name, created),
    )
    await conn.commit()

    return DatasetJob(
        id=new_id,
        type=job_type,
        status="queued",
        dataset_id=dataset_id,
        dataset_name=dataset_name,
        created_at=created,
    )
async def update_job(
    job_id: str,
    status: str | None = None,
    progress: float | None = None,
    message: str | None = None,
    error: str | None = None,
    started_at: str | None = None,
    ended_at: str | None = None,
) -> None:
    """Update the given columns of one ``dataset_jobs`` row and commit.

    Only arguments that are not ``None`` are written. When every optional
    argument is ``None`` nothing is executed.
    """
    conn = await get_db()

    # Insertion order fixes the order of the SET assignments.
    candidates: dict[str, Any] = {
        "status": status,
        "progress": progress,
        "message": message,
        "error": error,
        "started_at": started_at,
        "ended_at": ended_at,
    }
    changes = {
        column: value
        for column, value in candidates.items()
        if value is not None
    }
    if not changes:
        return

    set_clause = ", ".join(f"{column}=?" for column in changes)
    bindings = [*changes.values(), job_id]
    await conn.execute(f"UPDATE dataset_jobs SET {set_clause} WHERE id=?", bindings)
    await conn.commit()
async def get_job(job_id: str) -> DatasetJob | None:
    """Fetch one job by id, or ``None`` when no row matches."""
    conn = await get_db()
    async with conn.execute(
        "SELECT * FROM dataset_jobs WHERE id=?", (job_id,)
    ) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return row_to_job(record)
async def get_all_jobs(limit: int = 100) -> list[DatasetJob]:
    """Return up to ``limit`` jobs, newest ``created_at`` first."""
    conn = await get_db()
    query = "SELECT * FROM dataset_jobs ORDER BY created_at DESC LIMIT ?"
    async with conn.execute(query, (limit,)) as cursor:
        fetched = await cursor.fetchall()

    jobs = []
    for record in fetched:
        jobs.append(row_to_job(record))
    return jobs
| # ── Image Index ─────────────────────────────────────────────────────────────── | |
async def index_images(
    dataset_id: str,
    records: list[dict],  # [{id, filename, rel_path, width, height, split, ann_count}]
) -> int:
    """Insert image records for a dataset with ``INSERT OR IGNORE`` and commit.

    Each record is bound by name, with ``dataset_id`` added to it. A
    ``dataset_id`` key already present in a record takes precedence over
    the argument.

    Returns:
        ``len(records)`` — the number submitted, not necessarily inserted.
    """
    conn = await get_db()

    payload = []
    for record in records:
        bound = {"dataset_id": dataset_id}
        bound.update(record)  # record keys win, matching {"dataset_id": ..., **record}
        payload.append(bound)

    insert_sql = """INSERT OR IGNORE INTO dataset_images
        (id, dataset_id, filename, rel_path, width, height, split, ann_count)
        VALUES (:id, :dataset_id, :filename, :rel_path, :width, :height, :split, :ann_count)"""
    await conn.executemany(insert_sql, payload)
    await conn.commit()
    return len(records)
async def get_image_page(
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> tuple[int, list[dict]]:
    """Return one page of a dataset's images plus the total match count.

    Args:
        dataset_id: Dataset whose images are listed.
        page: Zero-based page index; the offset is ``page * page_size``.
        page_size: Maximum number of rows on the page.
        split: Optional exact-match filter on ``split``.
        class_label: When set, only images whose id appears in
            ``dataset_annotations`` with this label are returned.

    Returns:
        ``(total, rows)`` where ``total`` counts all rows matching the
        filters and ``rows`` is the requested page, ordered by filename,
        as plain dicts.
    """
    conn = await get_db()

    filters = ["dataset_id=?"]
    args: list[Any] = [dataset_id]
    if split:
        filters.append("split=?")
        args.append(split)

    where_sql = f"WHERE {' AND '.join(filters)}"
    if class_label:
        # Restrict to images that carry at least one annotation with this label.
        where_sql += " AND id IN (SELECT image_id FROM dataset_annotations WHERE label=?)"
        args.append(class_label)

    async with conn.execute(
        f"SELECT COUNT(*) FROM dataset_images {where_sql}", args
    ) as cursor:
        total = (await cursor.fetchone())[0]

    paged_args = [*args, page_size, page * page_size]
    async with conn.execute(
        f"SELECT * FROM dataset_images {where_sql} ORDER BY filename LIMIT ? OFFSET ?",
        paged_args,
    ) as cursor:
        fetched = await cursor.fetchall()

    return total, [dict(record) for record in fetched]
async def get_annotations_for_image(image_id: str) -> list[dict]:
    """Return every ``dataset_annotations`` row for an image as plain dicts."""
    conn = await get_db()
    query = "SELECT * FROM dataset_annotations WHERE image_id=?"
    async with conn.execute(query, (image_id,)) as cursor:
        fetched = await cursor.fetchall()

    annotations = []
    for record in fetched:
        annotations.append(dict(record))
    return annotations
async def bulk_insert_annotations(records: list[dict]) -> int:
    """Insert annotation records with ``INSERT OR IGNORE`` and commit.

    Each record is bound by name and must provide every key named in the
    VALUES clause. An empty list is a no-op that returns 0.

    Returns:
        ``len(records)`` — the number submitted, not necessarily inserted.
    """
    submitted = len(records) if records else 0
    if submitted == 0:
        return 0

    conn = await get_db()
    insert_sql = """INSERT OR IGNORE INTO dataset_annotations
        (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
        normalised, area, confidence, ann_type)
        VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
        :normalised,:area,:confidence,:ann_type)"""
    await conn.executemany(insert_sql, records)
    await conn.commit()
    return submitted
| # ── Universal Dataset Items ────────────────────────────────────────────── | |
async def get_universal_items(
    self,
    dataset_id: str,
    page: int = 0,
    page_size: int = 20,
    split: str | None = None,
    class_label: str | None = None,
) -> tuple[int, list[dict]]:
    """Fetch one page of dataset items by delegating to ``self.get_image_page``.

    This is a bridge until the universal schema is in place: the
    ``(total, items)`` pair from ``get_image_page`` is returned unchanged,
    with no conversion applied yet.

    Args:
        self: Object exposing an async ``get_image_page`` method.
        dataset_id, page, page_size, split, class_label: Forwarded
            positionally to ``get_image_page``.

    NOTE(review): this is written like a method (it takes ``self`` and
    calls ``self.get_image_page``) although no enclosing class is visible
    in this file; the signature is kept as-is for existing callers —
    confirm the intended owner.
    """
    # The unused `db = await get_db()` was dropped: this function never
    # queries the database itself.
    total, items = await self.get_image_page(
        dataset_id, page, page_size, split, class_label
    )
    return total, items
async def bulk_insert_universal_annotations(self, records: list[dict]) -> int:
    """Insert annotations including the extended columns, then commit.

    Same ``INSERT OR IGNORE`` as the basic annotation insert, plus the
    ``segmentation``, ``keypoints`` and ``metadata`` columns. Each record is
    bound by name and must provide every key named in the VALUES clause.
    An empty list is a no-op that returns 0.

    NOTE(review): ``self`` is accepted but never used — confirm whether
    this was meant to be a method.

    Returns:
        ``len(records)`` — the number submitted, not necessarily inserted.
    """
    submitted = len(records) if records else 0
    if submitted == 0:
        return 0

    conn = await get_db()
    insert_sql = """INSERT OR IGNORE INTO dataset_annotations
        (id, image_id, dataset_id, label, bbox_x, bbox_y, bbox_w, bbox_h,
        normalised, area, confidence, ann_type, segmentation, keypoints, metadata)
        VALUES (:id,:image_id,:dataset_id,:label,:bbox_x,:bbox_y,:bbox_w,:bbox_h,
        :normalised,:area,:confidence,:ann_type,:segmentation,:keypoints,:metadata)"""
    await conn.executemany(insert_sql, records)
    await conn.commit()
    return submitted
async def update_dataset_task(dataset_id: str, task: str) -> None:
    """Set the ``task`` column of one dataset row and commit."""
    conn = await get_db()
    statement = "UPDATE datasets SET task=? WHERE id=?"
    bindings = (task, dataset_id)
    await conn.execute(statement, bindings)
    await conn.commit()
async def cleanup_stale_jobs() -> None:
    """Mark every ``running`` or ``queued`` job as ``failed`` and commit.

    The ``error`` column of each affected job is set to ``'System restart'``.
    Intended to be called on startup.
    """
    conn = await get_db()
    statement = (
        "UPDATE dataset_jobs SET status='failed', error='System restart' "
        "WHERE status IN ('running', 'queued')"
    )
    await conn.execute(statement)
    await conn.commit()
| def _fmt_bytes(n: int) -> str: | |
| for unit in ("B", "KB", "MB", "GB", "TB"): | |
| if n < 1024: | |
| return f"{n:.1f} {unit}" | |
| n /= 1024 | |
| return f"{n:.1f} PB" | |