diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..81ec9858a74d9f535dadd3f7055b3337fd02ef94 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +# Use a lightweight Python image +FROM python:3.10-slim + +# Set environment variables +ENV PYTHONUNBUFFERED=1 +ENV PORT=7860 + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements and install +# Note: requirements.txt should be in the same directory as Dockerfile (backend/) +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the rest of the backend code +COPY . . + +# HuggingFace Spaces uses port 7860 by default +EXPOSE 7860 + +# Run the FastAPI app +# We use 0.0.0.0 to allow external connections within the container +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/__pycache__/__init__.cpython-310.pyc b/api/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30603ebe88441542b5191229ffa494a527119654 Binary files /dev/null and b/api/__pycache__/__init__.cpython-310.pyc differ diff --git a/api/routes/__init__.py b/api/routes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/routes/__pycache__/__init__.cpython-310.pyc b/api/routes/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63d7a75e3ec6ebc14ca64541348f9d5ca1448741 Binary files /dev/null and b/api/routes/__pycache__/__init__.cpython-310.pyc differ diff --git a/api/routes/__pycache__/benchmark.cpython-310.pyc b/api/routes/__pycache__/benchmark.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..4a861e124c9179113ad392d95e6b57a676c86c61 Binary files /dev/null and b/api/routes/__pycache__/benchmark.cpython-310.pyc differ diff --git a/api/routes/__pycache__/datasets.cpython-310.pyc b/api/routes/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcbc069f868c7bece7bb97332d3a7e517f33a91f Binary files /dev/null and b/api/routes/__pycache__/datasets.cpython-310.pyc differ diff --git a/api/routes/__pycache__/inference.cpython-310.pyc b/api/routes/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2141a177928cf7b6053f8a8e636b272af1f9556d Binary files /dev/null and b/api/routes/__pycache__/inference.cpython-310.pyc differ diff --git a/api/routes/__pycache__/jobs.cpython-310.pyc b/api/routes/__pycache__/jobs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad9a6244b3beceeb76fcb6f29fd0bbe5adf1077f Binary files /dev/null and b/api/routes/__pycache__/jobs.cpython-310.pyc differ diff --git a/api/routes/__pycache__/models.cpython-310.pyc b/api/routes/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc5f6d4d777224e1eb209dfbd2c14809ce41dbe0 Binary files /dev/null and b/api/routes/__pycache__/models.cpython-310.pyc differ diff --git a/api/routes/__pycache__/projects.cpython-310.pyc b/api/routes/__pycache__/projects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2453882aeb818938927b99793a30935f933f29 Binary files /dev/null and b/api/routes/__pycache__/projects.cpython-310.pyc differ diff --git a/api/routes/__pycache__/sync.cpython-310.pyc b/api/routes/__pycache__/sync.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598a3b8eb2e6a6ca87f368ba9588cea2b23bebe0 Binary files /dev/null and b/api/routes/__pycache__/sync.cpython-310.pyc differ diff 
--git a/api/routes/__pycache__/system.cpython-310.pyc b/api/routes/__pycache__/system.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72bcbbc36efe9a80a09a2447fe9b96222d15f3fa Binary files /dev/null and b/api/routes/__pycache__/system.cpython-310.pyc differ diff --git a/api/routes/__pycache__/training.cpython-310.pyc b/api/routes/__pycache__/training.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9495edf337ff405bbb4a38ac4d33ba748a869bd Binary files /dev/null and b/api/routes/__pycache__/training.cpython-310.pyc differ diff --git a/api/routes/benchmark.py b/api/routes/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..63d8a1b9b078bafdc4e152c01cf823c4c7f85db1 --- /dev/null +++ b/api/routes/benchmark.py @@ -0,0 +1,238 @@ +""" +api/routes/benchmark.py — Benchmark Bridge REST + WebSocket API. + +Routes: + POST /benchmark/validate — compatibility check (no job created) + POST /benchmark/run — validate + create + enqueue (202) + GET /benchmark/jobs — list jobs (filterable) + GET /benchmark/results/all — list all results + GET /benchmark/{job_id} — single job status + logs + GET /benchmark/{job_id}/result — metrics + telemetry for completed job + WS /benchmark/live/{job_id} — real-time progress stream +""" +from __future__ import annotations + +import asyncio +from typing import Any + +from fastapi import APIRouter, HTTPException, Query, WebSocket, WebSocketDisconnect + +import benchmark.orchestrator as orchestrator +import benchmark.registry as bench_reg +from models.benchmark import ( + BenchmarkContext, + BenchmarkJob, + BenchmarkResult, + BenchmarkRunResponse, + ValidationReport, +) +from observability.logger import get_logger + +log = get_logger("api.benchmark") + +router = APIRouter(prefix="/benchmark", tags=["benchmark"]) + + +# ── POST /benchmark/validate ────────────────────────────────────────────────── + +@router.post( + "/validate", + response_model = 
ValidationReport, + summary = "Validate model ↔ dataset compatibility", + description = ( + "Runs all 5 compatibility gates (task, format, framework×hardware, " + "VRAM, precision) and returns a structured report. " + "Does NOT create a benchmark job." + ), +) +async def validate_benchmark(ctx: BenchmarkContext) -> ValidationReport: + try: + return await orchestrator.validate_context(ctx) + except HTTPException: + raise + except Exception as exc: + log.exception("validate_error") + raise HTTPException(status_code=500, detail=str(exc)) from exc + + +# ── POST /benchmark/run ─────────────────────────────────────────────────────── + +@router.post( + "/run", + response_model = BenchmarkRunResponse, + status_code = 202, + summary = "Start a benchmark run", + description = ( + "Validates compatibility, creates a benchmark job, and starts async " + "execution. Returns job_id immediately — poll GET /benchmark/{job_id} " + "or connect to WS /benchmark/live/{job_id} for progress." + ), +) +async def run_benchmark(ctx: BenchmarkContext) -> BenchmarkRunResponse: + try: + job = await orchestrator.create_and_run(ctx) + return BenchmarkRunResponse( + job_id = job.id, + status = job.status, + message = f"Benchmark job {job.id} queued — connect to /benchmark/live/{job.id} for live telemetry", + ) + except HTTPException: + raise + except Exception as exc: + log.exception("run_benchmark_error") + raise HTTPException(status_code=500, detail=str(exc)) from exc + + +# ── POST /benchmark/sync ────────────────────────────────────────────────────────── + +@router.post( + "/sync", + summary = "Sync benchmarks from active project folder", + description = "Scans the active project's 'benchmarks' folder and ensures all JSON records are indexed in SQLite.", +) +async def sync_benchmarks() -> dict[str, Any]: + try: + count = await orchestrator.sync_project_benchmarks() + return {"status": "success", "count": count} + except Exception as exc: + log.exception("sync_error") + raise 
HTTPException(status_code=500, detail=str(exc)) from exc + + +# ── GET /benchmark/jobs ─────────────────────────────────────────────────────── + +@router.get( + "/jobs", + response_model = list[BenchmarkJob], + summary = "List benchmark jobs", +) +async def list_jobs( + status: str | None = Query(None, description="Filter by status (queued|running|completed|failed)"), + model_id: str | None = Query(None, description="Filter by model_id"), + limit: int = Query(100, ge=1, le=500), +) -> list[BenchmarkJob]: + return await bench_reg.list_jobs(status=status, model_id=model_id, limit=limit) + + +# ── GET /benchmark/results/all ──────────────────────────────────────────────── +# Must be declared BEFORE /{job_id} to avoid "results" being treated as a job_id + +@router.get( + "/results/all", + response_model = list[BenchmarkResult], + summary = "List all benchmark results (leaderboard feed)", +) +async def list_results( + limit: int = Query(100, ge=1, le=500), +) -> list[BenchmarkResult]: + return await bench_reg.list_results(limit=limit) + + +# ── GET /benchmark/{job_id} ─────────────────────────────────────────────────── + +@router.get( + "/{job_id}", + response_model = BenchmarkJob, + summary = "Get benchmark job status + logs", +) +async def get_job(job_id: str) -> BenchmarkJob: + job = await bench_reg.get_job(job_id) + if not job: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + return job + + +# ── GET /benchmark/{job_id}/result ─────────────────────────────────────────── + +@router.get( + "/{job_id}/result", + response_model = BenchmarkResult, + summary = "Get final metrics + telemetry for a completed job", +) +async def get_result(job_id: str) -> BenchmarkResult: + result = await bench_reg.get_result(job_id) + if not result: + raise HTTPException( + status_code = 404, + detail = f"No result for job '{job_id}' — job may still be running", + ) + return result + + +# ── WS /benchmark/live/{job_id} 
─────────────────────────────────────────────── + +@router.websocket("/live/{job_id}") +async def live_telemetry(websocket: WebSocket, job_id: str) -> None: + """ + WebSocket stream for real-time benchmark progress. + Streams incremental logs and high-frequency telemetry. + """ + await websocket.accept() + log.info("ws_connected", job_id=job_id) + + last_log_idx = 0 + + try: + while True: + job = await bench_reg.get_job(job_id) + + if not job: + await websocket.send_json( + {"error": f"Job '{job_id}' not found", "job_id": job_id} + ) + break + + # Only send new logs + new_logs = job.logs[last_log_idx:] + last_log_idx = len(job.logs) + + payload: dict[str, Any] = { + "job_id": job.id, + "status": job.status, + "progress": round(job.progress, 4), + "logs": new_logs, + "telemetry": job.last_telemetry.model_dump() if job.last_telemetry else None, + } + # Explicitly include detections for the UI visualizer if they exist + if job.last_telemetry and hasattr(job.last_telemetry, "detections"): + payload["detections"] = job.last_telemetry.detections + + await websocket.send_json(payload) + + if job.status == "completed": + result = await bench_reg.get_result(job_id) + if result: + await websocket.send_json( + { + "job_id": job_id, + "status": "completed", + "result": result.model_dump(), + } + ) + break + + if job.status == "failed": + await websocket.send_json( + { + "job_id": job_id, + "status": "failed", + "error": job.error or "Unknown error", + } + ) + break + + await asyncio.sleep(0.5) # poll at 2 Hz + + except WebSocketDisconnect: + log.info("ws_disconnected", job_id=job_id) + except Exception as exc: + log.exception("ws_error", job_id=job_id) + try: + await websocket.send_json({"error": str(exc), "job_id": job_id}) + except Exception: + pass + finally: + try: + await websocket.close() + except Exception: + pass diff --git a/api/routes/datasets.py b/api/routes/datasets.py new file mode 100644 index 
0000000000000000000000000000000000000000..5a7eba8f03a4e12ad05c19f46b76c1ec0539d1f3 --- /dev/null +++ b/api/routes/datasets.py @@ -0,0 +1,395 @@ +""" +api/routes/datasets.py — Dataset Manager REST API. + +Routes: + GET /datasets — list/search datasets + GET /datasets/{id} — dataset detail + POST /datasets/search/roboflow — search Roboflow Universe (real-time) + POST /datasets/sync/roboflow — sync workspace datasets to local DB + POST /datasets/{id}/import — initiate dataset import job + GET /datasets/{id}/images — paginated viewer (images + annotations) + GET /datasets/{id}/image/{img} — serve raw image bytes + GET /datasets/jobs — list import jobs + GET /datasets/jobs/{job_id} — single job status + POST /datasets/{id}/star — toggle starred + DELETE /datasets/{id} — delete dataset record (+ local files) +""" +from __future__ import annotations + +from pathlib import Path +from typing import Optional +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import FileResponse, Response + +from adapters.roboflow_adapter import RoboflowAdapter +from datasets import registry as ds_reg +from datasets.import_service import start_import +from datasets.viewer_service import get_universal_viewer_page, get_viewer_page, resolve_image_path +from models.dataset import ( + Dataset, DatasetJob, DatasetSummary, DatasetSource, DatasetTask, + DatasetFormat, DatasetStatus, ImportRequest, ImportResponse, + RoboflowSearchRequest, ViewerPage, UniversalViewerPage, row_to_dataset, +) +from observability.logger import audit, get_logger + +log = get_logger("datasets_route") + +router = APIRouter(prefix="/datasets", tags=["datasets"]) + + +# ── List / Search datasets ──────────────────────────────────────────────────── + +@router.get("", response_model=list[DatasetSummary]) +async def list_datasets( + task: Optional[str] = Query(None), + format: Optional[str] = Query(None), + source: Optional[str] = Query(None), + status: 
Optional[str] = Query(None), + search: Optional[str] = Query(None), + starred: Optional[bool] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = Query(0, ge=0), +): + try: + datasets = await ds_reg.get_all_datasets( + task=task, format=format, source=source, + status=status, search=search, starred=starred, + limit=limit, offset=offset, + ) + return [_to_summary(d) for d in datasets] + except Exception as exc: + log.exception("list_datasets_error") + raise HTTPException(status_code=500, detail=str(exc)) + + +@router.get("/jobs", response_model=list[DatasetJob]) +async def list_jobs(limit: int = Query(50, ge=1, le=500)): + return await ds_reg.get_all_jobs(limit=limit) + + +@router.get("/jobs/{job_id}", response_model=DatasetJob) +async def get_job(job_id: str): + job = await ds_reg.get_job(job_id) + if not job: + raise HTTPException(404, f"Job {job_id!r} not found") + return job + + +@router.post("/jobs/{job_id}/stop") +async def stop_job(job_id: str): + """Cancel a running import job.""" + # Logic to cancel the asyncio task would go here + # For now, we update the status in the DB + await ds_reg.update_job(job_id, status="failed", error="Cancelled by user", ended_at=datetime.utcnow().isoformat()) + return {"status": "success", "message": "Job stop requested"} + + +@router.post("/jobs/{job_id}/pause") +async def pause_job(job_id: str): + """Pause a running import job.""" + await ds_reg.update_job(job_id, status="paused") + return {"status": "success", "message": "Job pause requested"} + + +@router.post("/jobs/{job_id}/resume") +async def resume_job(job_id: str): + """Resume a paused import job.""" + await ds_reg.update_job(job_id, status="running") + return {"status": "success", "message": "Job resume requested"} + + +# ── Roboflow Search & Sync ──────────────────────────────────────────────────── + +@router.post("/search/roboflow", response_model=list[DatasetSummary]) +async def search_roboflow(req: RoboflowSearchRequest): + """ + Live search 
Roboflow Universe. Results are cached for 1 hour. + Also upserts results into local DB so they appear in /datasets. + """ + try: + datasets = await RoboflowAdapter.search_datasets( + api_key = req.api_key, + query = req.query, + workspace = req.workspace, + page = req.page, + page_size = req.page_size, + ) + except Exception as exc: + log.error("roboflow_search_error", error=str(exc)) + raise HTTPException(502, f"Roboflow API error: {exc}") + + # Upsert to local DB + await ds_reg.bulk_upsert_datasets(datasets) + await audit("roboflow_search", {"query": req.query, "count": len(datasets)}) + return [_to_summary(d) for d in datasets] + + +@router.post("/sync/roboflow", response_model=dict) +async def sync_roboflow_workspace( + api_key: str = Query(..., description="Roboflow API key"), + workspace: str = Query(..., description="Workspace slug"), +): + """Sync all datasets from a Roboflow workspace into local DB.""" + try: + datasets = await RoboflowAdapter.list_workspace_datasets(api_key, workspace) + except Exception as exc: + raise HTTPException(502, f"Roboflow API error: {exc}") + count = await ds_reg.bulk_upsert_datasets(datasets) + return {"synced": count, "workspace": workspace} + + +# ── Dataset detail ──────────────────────────────────────────────────────────── + +@router.get("/{dataset_id}", response_model=Dataset) +async def get_dataset(dataset_id: str): + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + raise HTTPException(404, f"Dataset {dataset_id!r} not found") + return ds + + +# ── Import ──────────────────────────────────────────────────────────────────── + +@router.post("/{dataset_id}/import", response_model=ImportResponse) +async def import_dataset(dataset_id: str, req: ImportRequest): + """ + Initiate a background import job for a dataset. + Supports sources: roboflow | roboflow_curl | huggingface | local + """ + req.dataset_id = dataset_id # enforce consistency + + # Sources that are discovered outside the registry must be auto-registered. 
+ auto_register_sources = {DatasetSource.huggingface, DatasetSource.roboflow_curl, DatasetSource.local} + if req.source in auto_register_sources: + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + # Determine human-readable name + if req.source == DatasetSource.huggingface and req.hf_dataset_id: + name = req.hf_dataset_id + roboflow_ref = req.hf_dataset_id + fmt = DatasetFormat.json + src = DatasetSource.huggingface + + elif req.source == DatasetSource.local: + # local: use provided name or folder name from path + # Try req.local_path first, then req.name, then fallback to dataset_id + path_to_use = req.local_path or req.name or "" + name = req.name or (Path(path_to_use).name if path_to_use else dataset_id) + roboflow_ref = None + fmt = DatasetFormat.custom + src = DatasetSource.local + else: + # roboflow_curl: use provided dataset_name or fall back to dataset_id + name = req.dataset_name or dataset_id + roboflow_ref = None + fmt = _curl_format_to_enum(req.curl_format) + src = DatasetSource.roboflow_curl + + stub = Dataset( + id=dataset_id, + name=name, + task=DatasetTask.detection, + format=fmt, + source=src, + status=DatasetStatus.available, + roboflow_id=roboflow_ref, + created_at=datetime.utcnow().isoformat(), + ) + await ds_reg.upsert_dataset(stub) + log.info("dataset_auto_registered", dataset_id=dataset_id, source=str(req.source)) + else: + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + raise HTTPException(404, f"Dataset {dataset_id!r} not found in registry. 
" + "Run /datasets/sync/roboflow first.") + + try: + job_id = await start_import(req) + except ValueError as exc: + raise HTTPException(400, str(exc)) + + await audit("dataset_import_requested", {"dataset_id": dataset_id, "source": str(req.source)}) + return ImportResponse( + job_id = job_id, + dataset_id = dataset_id, + status = "queued", + message = "Import job created successfully", + ) + + +def _curl_format_to_enum(curl_format: str | None) -> DatasetFormat: + """Map Roboflow export format string from cURL to DatasetFormat enum.""" + if not curl_format: + return DatasetFormat.yolo + fmt = curl_format.lower() + if "yolo" in fmt: + return DatasetFormat.yolo + if "coco" in fmt: + return DatasetFormat.coco + if "voc" in fmt or "pascal" in fmt: + return DatasetFormat.voc + if "tfrecord" in fmt: + return DatasetFormat.tfrecord + if "csv" in fmt: + return DatasetFormat.csv + if "json" in fmt or "createml" in fmt: + return DatasetFormat.json + return DatasetFormat.yolo + + +# ── Viewer ──────────────────────────────────────────────────────────────────── + +@router.get("/{dataset_id}/universal", response_model=UniversalViewerPage) +async def get_universal_items( + dataset_id: str, + page: int = Query(0, ge=0), + page_size: int = Query(20, ge=1, le=100), + split: Optional[str] = Query(None, regex="^(train|val|test)$"), + class_label: Optional[str] = Query(None), +): + """ + Polymorphic dataset item viewer (UDV). + Supports Vision, NLP, and Tabular data via the Universal Dataset Item schema. + """ + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + raise HTTPException(404, f"Dataset {dataset_id!r} not found") + + # Allow viewing even if not fully imported for NLP/Tabular if files exist, + # but for Vision we usually need the index. 
+ return await get_universal_viewer_page(dataset_id, page, page_size, split, class_label) + + +@router.get("/{dataset_id}/images", response_model=ViewerPage) +async def get_images( + dataset_id: str, + page: int = Query(0, ge=0), + page_size: int = Query(20, ge=1, le=100), + split: Optional[str] = Query(None, regex="^(train|val|test)$"), + class_label: Optional[str] = Query(None), +): + """ + Paginated image + annotation data for the viewer. + Annotations are returned in normalised [0–1] coordinates. + Supports filtering by split and class label. + """ + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + raise HTTPException(404, f"Dataset {dataset_id!r} not found") + if ds.status != "imported": + raise HTTPException(409, f"Dataset is not imported yet (status: {ds.status})") + + return await get_viewer_page(dataset_id, page, page_size, split, class_label) + + +@router.get("/{dataset_id}/stats", response_model=dict) +async def get_dataset_stats(dataset_id: str): + """Return pre-computed class distributions and split statistics.""" + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + raise HTTPException(404, f"Dataset {dataset_id!r} not found") + + return await ds_reg.get_dataset_stats(dataset_id) + + +@router.get("/{dataset_id}/image/{image_id}") +async def serve_image(dataset_id: str, image_id: str): + """Serve raw image bytes for the viewer (cached by browser).""" + path = await resolve_image_path(dataset_id, image_id) + if path is None: + raise HTTPException(404, "Image not found or dataset not imported") + + suffix = path.suffix.lower() + media_types = { + ".jpg": "image/jpeg", ".jpeg": "image/jpeg", + ".png": "image/png", ".bmp": "image/bmp", + ".webp": "image/webp", + } + media_type = media_types.get(suffix, "application/octet-stream") + return FileResponse( + path = str(path), + media_type = media_type, + headers = {"Cache-Control": "public, max-age=86400"}, + ) + + +@router.get("/{dataset_id}/annotations", response_model=dict) +async def 
get_annotations_summary(dataset_id: str): + """Return class distribution summary from the annotations index.""" + from database.connection import get_db + db = await get_db() + async with db.execute( + """SELECT label, COUNT(*) as count + FROM dataset_annotations + WHERE dataset_id=? + GROUP BY label + ORDER BY count DESC""", + (dataset_id,), + ) as cur: + rows = await cur.fetchall() + return { + "dataset_id": dataset_id, + "class_distribution": [{"label": r["label"], "count": r["count"]} for r in rows], + "total_annotations": sum(r["count"] for r in rows), + } + + +# ── Star / Delete ───────────────────────────────────────────────────────────── + +@router.post("/{dataset_id}/star", response_model=dict) +async def toggle_star(dataset_id: str): + new_val = await ds_reg.toggle_starred(dataset_id) + return {"dataset_id": dataset_id, "starred": new_val} + + +@router.delete("/{dataset_id}", response_model=dict) +async def delete_dataset( + dataset_id: str, + delete_files: bool = Query(False, description="Also remove local files from disk"), +): + ds = await ds_reg.get_dataset(dataset_id) + if not ds: + raise HTTPException(404, f"Dataset {dataset_id!r} not found") + + if delete_files and ds.local_path: + import shutil + local = Path(ds.local_path) + if local.exists() and local.is_dir(): + shutil.rmtree(str(local), ignore_errors=True) + log.info("dataset_files_deleted", path=str(local)) + + deleted = await ds_reg.delete_dataset(dataset_id) + await audit("dataset_deleted", {"dataset_id": dataset_id, "files_deleted": delete_files}) + return {"deleted": deleted, "dataset_id": dataset_id} + + +# ── Helper ──────────────────────────────────────────────────────────────────── + +def _to_summary(d: Dataset) -> DatasetSummary: + # Use 0.0 as default health_score if stats is missing or health_score is not present + health_score = 0.0 + try: + if hasattr(d, 'stats') and d.stats: + health_score = getattr(d.stats, 'health_score', 0.0) + except Exception: + pass + + return 
DatasetSummary( + id = d.id, + name = d.name, + task = str(d.task), + format = str(d.format), + source = str(d.source), + status = str(d.status), + images = d.images, + classes = d.classes, + size_label = d.size_label, + tags = d.tags, + starred = d.starred, + import_progress = d.import_progress, + health_score = health_score, + created_at = d.created_at, + updated_at = d.updated_at, + ) diff --git a/api/routes/inference.py b/api/routes/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb7039a84aa747d69a63ece340e94c4885b47af --- /dev/null +++ b/api/routes/inference.py @@ -0,0 +1,168 @@ +""" +api/routes/inference.py — Inference Engine endpoints. + +POST /inference/run — single synchronous inference pass +POST /inference/stream — SSE stream (stage-by-stage pipeline events) +GET /inference/history — session ledger +DELETE /inference/history — clear session ledger +GET /inference/cache — currently warm models in memory +DELETE /inference/cache/{model_id} — evict from cache +""" +from __future__ import annotations + +import asyncio +import json +import time + +from fastapi import APIRouter, HTTPException, Response +from fastapi.responses import StreamingResponse + +from inference.engine import InferenceEngine, evict_model, get_cache_status +from inference.session import clear_history, get_history, record +from models.inference import ( + InferenceHistoryEntry, + InferenceRequest, + InferenceResult, +) +from observability.logger import get_logger +from registry.registry import get_model + +log = get_logger("api.inference") +router = APIRouter(prefix="/inference", tags=["inference"]) + +_engine = InferenceEngine() + + +# ── Single run ─────────────────────────────────────────────────────────────── + +@router.post("/run", response_model=InferenceResult) +async def run_inference(body: InferenceRequest) -> InferenceResult: + """Execute one full inference pass and return the complete result.""" + model = await get_model(body.model_id) + if not 
model: + raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found") + + result = await _engine.run(body, model) + + if result.status == "error": + raise HTTPException(status_code=500, detail=result.error or "Inference failed") + + await record(body, result, model.name) + return result + + +# ── SSE stream ─────────────────────────────────────────────────────────────── + +@router.post("/stream") +async def stream_inference(body: InferenceRequest) -> StreamingResponse: + """ + Server-Sent Events stream. + Emits one JSON event per pipeline stage as it completes, then a final + 'done' event with the full InferenceResult. + + Client usage: + const es = new EventSource('/inference/stream'); // POST via fetch + EventSource polyfill + es.onmessage = e => console.log(JSON.parse(e.data)); + """ + model = await get_model(body.model_id) + if not model: + raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found") + + queue: asyncio.Queue[str | None] = asyncio.Queue() + + async def _producer() -> None: + """Run inference while pushing SSE events into the queue.""" + try: + # Patch engine to emit stage events + result = await _engine_stream(body, model, queue) + await record(body, result, model.name) + # Final complete event + await queue.put( + f"event: done\ndata: {result.model_dump_json()}\n\n" + ) + except Exception as exc: + await queue.put( + f"event: error\ndata: {json.dumps({'error': str(exc)})}\n\n" + ) + finally: + await queue.put(None) # sentinel + + asyncio.create_task(_producer()) + + async def _generator(): + while True: + msg = await queue.get() + if msg is None: + break + yield msg + + return StreamingResponse( + _generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + +async def _engine_stream( + req: InferenceRequest, + model, + queue: asyncio.Queue, +) -> InferenceResult: + """ + Run inference and push a 'stage' SSE event for each 
PipelineStage. + Falls back to a simple full run if streaming is not distinguishable. + """ + # Run full pipeline + result = await _engine.run(req, model) + + # Emit one event per stage (replay after completion) + for stage in result.pipeline: + payload = json.dumps({ + "type": "stage", + "stage": stage.model_dump(), + "ts": time.time(), + }) + await queue.put(f"data: {payload}\n\n") + await asyncio.sleep(0) # yield + + # Emit vitals snapshot + vitals_payload = json.dumps({ + "type": "vitals", + "latency_ms": result.inference_ms, + "total_ms": result.total_ms, + "quality": result.quality_score, + }) + await queue.put(f"data: {vitals_payload}\n\n") + + return result + + +# ── History ────────────────────────────────────────────────────────────────── + +@router.get("/history", response_model=list[InferenceHistoryEntry]) +async def inference_history(limit: int = 50) -> list[InferenceHistoryEntry]: + return await get_history(limit=min(limit, 200)) + + +@router.delete("/history", status_code=204, response_model=None) +async def clear_inference_history(): + await clear_history() + return Response(status_code=204) + + +# ── Model cache ────────────────────────────────────────────────────────────── + +@router.get("/cache") +async def cache_status() -> dict[str, bool]: + return get_cache_status() + + +@router.delete("/cache/{model_id}", status_code=204, response_model=None) +async def evict_from_cache(model_id: str): + evicted = evict_model(model_id) + if not evicted: + raise HTTPException(status_code=404, detail="Model not in cache") + return Response(status_code=204) diff --git a/api/routes/jobs.py b/api/routes/jobs.py new file mode 100644 index 0000000000000000000000000000000000000000..065f0e1ef1f075ec8d097ee56613eb0e0a5a3346 --- /dev/null +++ b/api/routes/jobs.py @@ -0,0 +1,56 @@ +""" +api/routes/jobs.py — /jobs & /download endpoints. 
+""" +from __future__ import annotations + +from fastapi import APIRouter, HTTPException + +from download.manager import cancel_job, enqueue_download, get_job, list_jobs +from models.job import Job, JobCreate +from observability.logger import audit, get_logger +from registry.registry import get_model + +log = get_logger("api.jobs") +router = APIRouter(tags=["jobs"]) + + +@router.post("/download", response_model=Job, status_code=202) +async def trigger_download(body: JobCreate) -> Job: + """Enqueue a model download. Returns the created job immediately.""" + model = await get_model(body.model_id) + if not model: + raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found") + if model.downloaded: + raise HTTPException(status_code=409, detail="Model is already cached locally") + + job_id = await enqueue_download( + model_id=body.model_id, + model_name=body.model_name, + version=body.version, + ) + job = await get_job(job_id) + if not job: + raise HTTPException(status_code=500, detail="Job creation failed") + + await audit("api_download_trigger", model_id=body.model_id, job_id=job_id) + return job + + +@router.get("/jobs", response_model=list[Job]) +async def jobs_list(status: str | None = None, limit: int = 50) -> list[Job]: + return await list_jobs(status=status, limit=limit) + + +@router.get("/jobs/{job_id}", response_model=Job) +async def job_detail(job_id: str) -> Job: + job = await get_job(job_id) + if not job: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + return job + + +@router.delete("/jobs/{job_id}", status_code=204, response_model=None) +async def job_cancel(job_id: str) -> None: + success = await cancel_job(job_id) + if not success: + raise HTTPException(status_code=409, detail="Job cannot be cancelled") diff --git a/api/routes/models.py b/api/routes/models.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9e5fe504e1b9fd2aefdee657d201c0f7824740 --- /dev/null +++ 
b/api/routes/models.py @@ -0,0 +1,127 @@ +""" +api/routes/models.py — /models REST endpoints. +""" +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, HTTPException, Query, UploadFile, File, Form +from models.model import Model +from observability.logger import audit, get_logger +from registry.registry import count_models, get_model, list_models +from projects.service import get_active_project_id, import_local_model +from projects.registry import get_project +from pathlib import Path +import os +import tempfile + +log = get_logger("api.models") +router = APIRouter(prefix="/models", tags=["models"]) + + +@router.get("", response_model=list[Model]) +async def index( + search: Annotated[str | None, Query()] = None, + task: Annotated[list[str] | None, Query()] = None, + framework: Annotated[list[str] | None, Query()] = None, + hardware: Annotated[list[str] | None, Query()] = None, + source: Annotated[list[str] | None, Query()] = None, + downloaded: Annotated[bool | None, Query()] = None, + sort_by: Annotated[str, Query()] = "downloads", + sort_dir: Annotated[str, Query()] = "desc", + limit: Annotated[int, Query(ge=1, le=1000)] = 200, + offset: Annotated[int, Query(ge=0)] = 0, + project_id: Annotated[str | None, Query()] = None, +) -> list[Model]: + """ + List and search models. + Supports FTS search + server-side filtering. + Target: < 100ms for up to 5 000 models. + """ + effective_project_id = project_id or await get_active_project_id() + + models = await list_models( + search=search, + tasks=task, + frameworks=framework, + hardware=hardware, + sources=source, + downloaded=downloaded, + sort_by=sort_by, + sort_dir=sort_dir, + limit=limit, + offset=offset, + project_id=effective_project_id, + ) + + # If we have an active project, derive cache state from its workspace. + # This makes "downloaded" and "local_path" reflect the *current project*. 
@router.post("/import", response_model=Model)
async def import_model(
    name: Annotated[str, Form()],
    task: Annotated[str, Form()],
    framework: Annotated[str, Form()],
    file: UploadFile = File(...),
) -> Model:
    """Import an uploaded model file via ``import_local_model``.

    The upload is spooled to a temporary file (keeping the original file
    extension), passed to ``import_local_model`` by path, and the temporary
    file is removed afterwards whether or not the import succeeded.

    Args:
        name: Display name for the imported model.
        task: Task label forwarded to ``import_local_model``.
        framework: Framework label forwarded to ``import_local_model``.
        file: The uploaded model file.

    Raises:
        HTTPException: re-raised unchanged when the import layer raises one;
            otherwise 500 with the underlying error message.
    """
    chunk_bytes = 1024 * 1024  # copy the upload in 1 MiB pieces, not one big read
    suffix = os.path.splitext(file.filename or "")[1]
    tmp_path: str | None = None

    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path first so the finally-clause can clean up even
            # when reading the upload fails half-way through.
            tmp_path = tmp.name
            while chunk := await file.read(chunk_bytes):
                tmp.write(chunk)

        return await import_local_model(
            name=name,
            task=task,
            framework=framework,
            source_file_path=tmp_path,
        )
    except HTTPException:
        # Keep deliberate HTTP errors (status + detail) from the import layer
        # instead of flattening them into a generic 500.
        raise
    except Exception as e:
        log.error("model_import_failed", error=str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
@router.post("", response_model=Project)
async def projects_upsert(project: Project) -> Project:
    """Create or update a project record and return it.

    Missing ``created_at`` / ``last_opened`` values are filled with the
    current UTC time in ISO-8601 form before the project is stored.

    Args:
        project: The project payload to persist.

    Returns:
        The same project object, with any missing timestamps filled in.
    """
    # This module's import list does not bring in datetime/timezone, so the
    # original body raised NameError whenever a timestamp was missing. Import
    # locally so this block is correct on its own.
    from datetime import datetime, timezone

    if not project.created_at or not project.last_opened:
        # One timestamp for both fields keeps them consistent on first insert.
        now_iso = datetime.now(timezone.utc).isoformat()
        if not project.created_at:
            project.created_at = now_iso
        if not project.last_opened:
            project.last_opened = now_iso

    await upsert_project(project)
    await audit("api_upsert_project", payload={"project_id": project.id})
    return project
async def _run_full_sync() -> None:
    """Fetch models from the HF and ONNX adapters and store them in the registry.

    Steps, in order:
      1. Fetch Hugging Face models and bulk-upsert them.
      2. Delete HF rows whose task is outside the allowed set.
      3. Delete HF generation/embedding rows whose tags match none of the
         vision-related patterns below.
      4. Fetch ONNX models and bulk-upsert them.
      5. Log and audit the total.

    ``total`` counts models as fetched, before pruning, so it can be larger
    than the number of rows that remain in the table.
    """
    log.info("sync_start")
    total = 0

    async with HFAdapter() as hf:
        hf_models = await hf.fetch_models()
        await bulk_upsert(hf_models)
        total += len(hf_models)
        log.info("sync_hf_done", count=len(hf_models))

    # Prune any HF models outside the allowed task set (e.g. legacy NLP entries)
    allowed_tasks = {"detection", "classification", "segmentation", "generation", "embedding"}
    from database.connection import get_db

    db = await get_db()
    # Only the "?" placeholders are interpolated into the SQL text; the task
    # names themselves are passed as bound parameters.
    placeholders = ",".join(["?"] * len(allowed_tasks))
    await db.execute(
        f"DELETE FROM models WHERE source = 'hf' AND task NOT IN ({placeholders})",
        tuple(sorted(allowed_tasks)),
    )
    # Prune non-vision generation/embedding HF models. We rely on the adapter
    # adding the pipeline_tag as a normalised tag (e.g. text_to_image).
    # NOTE(review): "_" is a single-character wildcard in LIKE, so these
    # patterns match slightly more than the literal tag text. Also, if `tags`
    # can be NULL, NOT LIKE evaluates to NULL and such rows are kept --
    # confirm that is intended.
    await db.execute(
        """
        DELETE FROM models
        WHERE source = 'hf'
          AND task IN ('generation','embedding')
          AND (
            tags NOT LIKE '%text_to_image%'
            AND tags NOT LIKE '%image_to_image%'
            AND tags NOT LIKE '%image_feature_extraction%'
          )
        """,
    )
    # Both DELETE statements are committed together.
    await db.commit()

    onnx = ONNXAdapter()
    onnx_models = await onnx.fetch_models()
    await bulk_upsert(onnx_models)
    total += len(onnx_models)
    log.info("sync_onnx_done", count=len(onnx_models))

    log.info("sync_complete", total=total)
    await audit("sync_complete", payload={"total": total})
@router.get("/metrics/stream")
async def stream_metrics(
    gpu_index: int = Query(0, ge=0),
    hz: float = Query(2.0, ge=0.2, le=20.0),
):
    """Server-Sent Events stream of system metrics.

    Args:
        gpu_index: Index passed to ``sample_metrics`` to select the GPU.
        hz: Samples per second; the loop sleeps ``1 / hz`` seconds between
            samples.

    Returns:
        A ``text/event-stream`` response that emits one ``data:`` frame per
        successful sample. It has no termination condition of its own.
    """
    # Local import: this block previously reported errors with print(), which
    # bypasses the project's structured logging used by the other routes.
    from observability.logger import get_logger

    log = get_logger("api.system")
    interval = 1.0 / float(hz)

    async def gen():
        # Initial comment helps some proxies establish the stream
        yield ": connected\n\n"
        while True:
            try:
                payload = sample_metrics(gpu_index=gpu_index)
                # Serialise inside the try so a non-JSON-able payload is
                # treated like a failed sample rather than killing the stream.
                data = json.dumps(payload)
            except Exception as e:
                # Log the failure but keep the stream alive for the next tick.
                log.error("metrics_stream_sample_failed", error=str(e))
            else:
                yield f"data: {data}\n\n"
            await asyncio.sleep(interval)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        },
    )
/dev/null +++ b/api/routes/training.py @@ -0,0 +1,428 @@ +""" +api/routes/training.py — Training Engine REST + SSE endpoints. + +POST /train/start — create and launch a training run +POST /train/stop — cancel a running run +POST /train/pause — pause a running run +POST /train/resume — resume a paused run +GET /train/status — run status + progress snapshot +GET /train/runs — list all runs +GET /train/runs/{run_id} — single run detail +GET /train/schema — UI schema for task/model/dataset combo +GET /train/checkpoints — checkpoints for a run (stub) +POST /train/checkpoints/{id}/export — export a checkpoint (stub) +GET /train/metrics/stream — SSE: real-time metrics ticks +GET /train/logs/stream — SSE: real-time log entries +GET /train/resources/stream — SSE: real-time resource ticks +""" +from __future__ import annotations + +import asyncio +import json +import time +import os + +from fastapi import APIRouter, HTTPException, Query +from fastapi.responses import StreamingResponse + +from observability.logger import get_logger +from training import run_manager +from training.schema_engine import generate_schema +from training.schemas import ( + CheckpointOut, + PauseTrainRequest, + ResumeTrainRequest, + StartTrainRequest, + StartTrainResponse, + StopTrainRequest, + TrainRunOut, + TrainStatusResponse, + TrainingSchemaResponse, +) + +log = get_logger("api.training") +router = APIRouter(prefix="/train", tags=["training"]) + +# ── Helpers ──────────────────────────────────────────────────────────────────── + +def _format_duration(seconds: float) -> str: + h = int(seconds // 3600) + m = int((seconds % 3600) // 60) + s = int(seconds % 60) + return f"{h}h {m}m {s}s" + + +def _run_to_out(run: run_manager.TrainRun) -> TrainRunOut: + elapsed = (run.completed_at or time.time()) - run.created_at + return TrainRunOut( + id=run.run_id, + run_number=run.run_number, + model_id=run.model_id, + model_name=run.model_name, + dataset_id=run.dataset_id, + dataset_name=run.dataset_name, + 
@router.post("/start", response_model=StartTrainResponse)
async def start_training(body: StartTrainRequest) -> StartTrainResponse:
    """Create and immediately launch a training run.

    Friendly model and dataset names are looked up on a best-effort basis.
    If a lookup fails for any reason, the raw id is used as the name and the
    failure is logged rather than silently discarded.

    Args:
        body: Start request carrying ids, task, hyperparameters, augmentation,
            scheduler and project id.

    Returns:
        The new run's id, its status after ``start_run``, and a message.
    """
    # Resolve friendly names (fall back to ids if registries unavailable)
    model_name = body.model_id
    dataset_name = body.dataset_id
    try:
        from registry.registry import get_model
        m = await get_model(body.model_id)
        if m:
            model_name = m.name
    except Exception as exc:
        # Deliberate best-effort: keep the id as the name, but record why.
        log.info("training_model_name_fallback", model_id=body.model_id, error=str(exc))
    try:
        from datasets.registry import get_dataset
        d = await get_dataset(body.dataset_id)
        if d:
            # The dataset may be a dict or an object, so handle both shapes.
            dataset_name = d.get("name", body.dataset_id) if isinstance(d, dict) else getattr(d, "name", body.dataset_id)
    except Exception as exc:
        log.info("training_dataset_name_fallback", dataset_id=body.dataset_id, error=str(exc))

    run = run_manager.create_run(
        model_id=body.model_id,
        model_name=model_name,
        dataset_id=body.dataset_id,
        dataset_name=dataset_name,
        task=body.task,
        hyperparams=body.hyperparams,
        augmentation=body.augmentation,
        scheduler=body.scheduler,
        project_id=body.project_id
    )
    run_manager.start_run(run)

    log.info("training_started", run_id=run.run_id, model=body.model_id)
    return StartTrainResponse(
        run_id=run.run_id,
        status=run.status,
        message=f"Training run {run.run_id} started.",
    )
@router.get("/status", response_model=TrainStatusResponse)
async def get_train_status(run_id: str = Query(...)) -> TrainStatusResponse:
    """Report the current status and progress counters of one training run.

    Raises:
        HTTPException: 404 when no run with ``run_id`` is known.
    """
    tracked = run_manager.get_run(run_id)
    if tracked:
        snapshot = {
            "run_id": tracked.run_id,
            "status": tracked.status,
            "epoch": tracked.epoch,
            "total_epochs": tracked.total_epochs,
            "step": tracked.step,
            # total_steps is derived here as total_epochs times a fixed 100.
            "total_steps": tracked.total_epochs * 100,
            "eta_seconds": tracked.eta_seconds,
            "elapsed_seconds": tracked.elapsed_seconds,
        }
        return TrainStatusResponse(**snapshot)

    raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
@router.post("/checkpoints/{checkpoint_id}/export")
async def export_checkpoint(checkpoint_id: str, body: dict | None = None) -> dict:
    """Export a checkpoint (stub).

    Args:
        checkpoint_id: Identifier of the checkpoint to export (currently unused).
        body: Optional request body (currently unused). Defaults to ``None``
            instead of a shared mutable ``{}`` literal.

    Raises:
        HTTPException: always 501, because export is not implemented yet.
    """
    raise HTTPException(status_code=501, detail="Checkpoint export not yet implemented")
@router.get("/runs/{run_id}/history")
async def get_run_history(run_id: str) -> list[dict]:
    """Retrieves the full historical telemetry (metrics ticks) for a run.

    Reads ``telemetry.jsonl`` from the run directory, one JSON document per
    non-blank line.

    Returns:
        The decoded entries in file order; an empty list when the telemetry
        file does not exist.

    Raises:
        HTTPException: 404 when the run is unknown; 500 when the file exists
            but cannot be read.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    from training.persistence import TrainingPersistence
    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
    telemetry_path = os.path.join(run_dir, "telemetry.jsonl")

    history: list[dict] = []
    skipped = 0
    try:
        # Open directly rather than exists()-then-open, so a file removed in
        # between is treated as "no history" instead of an error.
        with open(telemetry_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    history.append(json.loads(line))
                except json.JSONDecodeError:
                    # A single malformed line (e.g. one still being appended)
                    # should not fail the whole history request.
                    skipped += 1
    except FileNotFoundError:
        return history
    except (OSError, UnicodeDecodeError) as e:
        log.error("history_read_failed", run_id=run_id, error=str(e))
        raise HTTPException(status_code=500, detail="Failed to read telemetry history") from e

    if skipped:
        log.info("history_lines_skipped", run_id=run_id, skipped=skipped)

    return history
@router.get("/runs/{run_id}/files/{filename}")
async def get_run_file(run_id: str, filename: str):
    """Serves a specific file from the run directory.

    ``filename`` comes from the request URL, so the resolved path is checked
    to lie inside the run directory and to be a regular file before it is
    served.

    Raises:
        HTTPException: 404 when the run is unknown, when the name resolves
            outside the run directory, or when no such regular file exists.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # We need to find the project to get the run_dir
    # Since run_manager doesn't easily expose the full path in memory,
    # we recalculate it using persistence
    from training.persistence import TrainingPersistence
    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)

    # Lexical containment check: normalise both paths and require the
    # requested file to sit strictly below the run directory ("..", absolute
    # names, and the directory itself are rejected).
    root = os.path.abspath(run_dir)
    file_path = os.path.abspath(os.path.join(root, filename))
    try:
        contained = file_path != root and os.path.commonpath([root, file_path]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different
        # drives on Windows); treat that as "outside".
        contained = False

    # isfile() also rejects directories, which exists() used to let through.
    if not contained or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    from fastapi.responses import FileResponse
    return FileResponse(file_path)
class Settings(BaseSettings):
    """Application settings, loaded from defaults, a ``.env`` file and the environment.

    Environment variable names are matched case-insensitively
    (``case_sensitive=False``).

    NOTE(review): the storage defaults below are computed from the *default*
    value of ``base_dir`` when the class body is executed. Overriding
    ``base_dir`` (or ``data_dir``) through the environment therefore does not
    move the paths derived from it; each derived path would need its own
    override. Confirm this is the intended behaviour.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # ── App ───────────────────────────────────────────────────────────
    app_name: str = "MLForge Platform"
    version: str = "1.0.0"
    debug: bool = False

    # ── API ───────────────────────────────────────────────────────────
    host: str = "0.0.0.0"
    port: int = 8005
    # Browser origins allowed by default (local dev servers on ports 3000, 5173, 2000).
    cors_origins: list[str] = [
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:2000",
        "http://127.0.0.1:2000",
    ]

    # ── Storage ───────────────────────────────────────────────────────
    # parents[1] is the directory one level above the folder containing this file.
    base_dir: Path = Path(__file__).resolve().parents[1]
    data_dir: Path = base_dir / "data"
    models_dir: Path = data_dir / "models"
    datasets_dir: Path = data_dir / "datasets"   # root for imported datasets
    logs_dir: Path = data_dir / "logs"
    db_path: Path = data_dir / "modelzoo.db"

    # ── Download Manager ──────────────────────────────────────────────
    max_concurrent_downloads: int = 5
    download_chunk_size: int = 1024 * 1024  # 1 MB
    download_max_retries: int = 3
    download_retry_delay: float = 2.0  # seconds (base, exponential backoff)

    # ── Search ────────────────────────────────────────────────────────
    search_max_results: int = 500

    # ── Sync ──────────────────────────────────────────────────────────
    auto_sync_on_startup: bool = True

    # ── Hugging Face API ──────────────────────────────────────────────
    hf_api_base: str = "https://huggingface.co/api"
    hf_hub_url: str = "https://huggingface.co"
    hf_token: str | None = None  # Optional: HF_TOKEN env var
    hf_models_per_task: int = 100  # How many to pull per task

    # ── ONNX Zoo ──────────────────────────────────────────────────────
    onnx_models_url: str = (
        "https://raw.githubusercontent.com/onnx/models/main/README.md"
    )

    # ── Benchmark Bridge ──────────────────────────────────────────────
    benchmark_max_concurrent: int = 3       # max parallel benchmark jobs
    benchmark_max_log_lines: int = 500      # log entries kept per job
    benchmark_ws_poll_hz: float = 2.0       # WebSocket telemetry poll rate

    # ── Dataset Manager ───────────────────────────────────────────────
    roboflow_api_base: str = "https://api.roboflow.com"
    dataset_import_workers: int = 3             # max concurrent import jobs
    dataset_chunk_size: int = 1024 * 1024 * 4   # 4 MB download chunk
    roboflow_cache_ttl_secs: int = 3600         # 1 hour

    def ensure_dirs(self) -> None:
        """Create the data, models, datasets (plus its ``_tmp`` child) and logs directories.

        Existing directories are left alone (``exist_ok=True``) and missing
        parents are created (``parents=True``).
        """
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.models_dir.mkdir(parents=True, exist_ok=True)
        self.datasets_dir.mkdir(parents=True, exist_ok=True)
        (self.datasets_dir / "_tmp").mkdir(parents=True, exist_ok=True)
        self.logs_dir.mkdir(parents=True, exist_ok=True)
@@ +-- ============================================================ +-- MLForge Benchmark Bridge — SQLite Schema +-- Version: 1.0.0 +-- ============================================================ + +PRAGMA journal_mode = WAL; +PRAGMA foreign_keys = ON; + +-- ── Benchmark Jobs ──────────────────────────────────────────────────────────── +-- Tracks every benchmark run from queued → running → completed/failed. +-- config stores the full BenchmarkContext JSON for full reproducibility. +CREATE TABLE IF NOT EXISTS benchmark_jobs ( + id TEXT PRIMARY KEY, + model_id TEXT NOT NULL, + dataset_id TEXT NOT NULL, + task TEXT NOT NULL, + framework TEXT NOT NULL, + hardware TEXT NOT NULL DEFAULT 'cpu', + precision TEXT NOT NULL DEFAULT 'FP32', + batch_size INTEGER NOT NULL DEFAULT 1, + config TEXT NOT NULL DEFAULT '{}', -- full BenchmarkContext JSON + status TEXT NOT NULL DEFAULT 'queued', -- queued|running|completed|failed + progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0 + logs TEXT NOT NULL DEFAULT '[]', -- JSON array of timestamped log strings + error TEXT, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + started_at TEXT, + ended_at TEXT +); + +-- ── Benchmark Results ───────────────────────────────────────────────────────── +-- Stores final computed metrics + telemetry summary after job completion. +CREATE TABLE IF NOT EXISTS benchmark_results ( + id TEXT PRIMARY KEY, + job_id TEXT NOT NULL REFERENCES benchmark_jobs(id) ON DELETE CASCADE, + metrics TEXT NOT NULL DEFAULT '{}', -- JSON: BenchmarkMetrics + telemetry_summary TEXT NOT NULL DEFAULT '{}', -- JSON: TelemetrySummary + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- ── Validation Logs ─────────────────────────────────────────────────────────── +-- Immutable audit trail of every compatibility check performed. +-- job_id = 'pre-check' for validations that blocked job creation. 
+CREATE TABLE IF NOT EXISTS benchmark_validation_logs ( + id TEXT PRIMARY KEY, + job_id TEXT NOT NULL, + model_id TEXT NOT NULL, + dataset_id TEXT NOT NULL, + checks TEXT NOT NULL DEFAULT '[]', -- JSON: list[ValidationCheck] + passed INTEGER NOT NULL DEFAULT 1, -- 1=passed, 0=failed + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- ── Indexes ─────────────────────────────────────────────────────────────────── +CREATE INDEX IF NOT EXISTS idx_bmark_jobs_status ON benchmark_jobs(status); +CREATE INDEX IF NOT EXISTS idx_bmark_jobs_model ON benchmark_jobs(model_id); +CREATE INDEX IF NOT EXISTS idx_bmark_jobs_dataset ON benchmark_jobs(dataset_id); +CREATE INDEX IF NOT EXISTS idx_bmark_jobs_created ON benchmark_jobs(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_bmark_results_job ON benchmark_results(job_id); +CREATE INDEX IF NOT EXISTS idx_bmark_valid_job ON benchmark_validation_logs(job_id); +CREATE INDEX IF NOT EXISTS idx_bmark_valid_model ON benchmark_validation_logs(model_id); diff --git a/database/connection.py b/database/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..a44bd0f60fa76184be1417aa7e55f97cba7cbf10 --- /dev/null +++ b/database/connection.py @@ -0,0 +1,106 @@ +""" +database/connection.py — Async SQLite connection & migration bootstrap. +Single module responsible for DB lifecycle. All queries use this pool. 
# Module-level connection (shared within the process)
_db: aiosqlite.Connection | None = None
_lock = asyncio.Lock()

# Schema scripts applied on every start-up, in this order. Each script is
# written with CREATE ... IF NOT EXISTS, so re-running it is harmless.
_SCHEMA_FILES = ("schema.sql", "dataset_schema.sql", "benchmark_schema.sql")

# Columns that were added after a table first shipped. Databases created by
# an older build get them through ALTER TABLE; fresh databases already have
# them from the schema scripts. Table/column names are constants defined
# here, never user input, so interpolating them into SQL is safe.
_LEGACY_COLUMNS: dict[str, tuple[tuple[str, str], ...]] = {
    "models": (
        ("download_url", "TEXT"),
        ("active_version", "TEXT"),
        ("metrics", "TEXT NOT NULL DEFAULT '{}'"),
        ("project_id", "TEXT REFERENCES projects(id) ON DELETE CASCADE"),
    ),
    "datasets": (
        ("active_version", "TEXT NOT NULL DEFAULT 'v1'"),
        ("roboflow_id", "TEXT"),
        ("health_score", "INTEGER NOT NULL DEFAULT 0"),
    ),
}

# Temporary tables that failed legacy migrations may have left behind.
_LEFTOVER_TABLES = ("datasets_old", "dataset_jobs_old")


async def get_db() -> aiosqlite.Connection:
    """Return the singleton async database connection, opening it on first use.

    The module lock serialises concurrent first callers so the database is
    opened and migrated only once.
    """
    global _db
    async with _lock:
        if _db is None:
            _db = await _open_connection()
        return _db


async def _open_connection() -> aiosqlite.Connection:
    """Open the SQLite file, apply pragmas and bring the schema up to date.

    Returns:
        A ready-to-use connection whose rows are ``aiosqlite.Row`` objects.

    Raises:
        Whatever the pragma or migration step raised. The half-initialised
        connection is closed first so it does not leak.
    """
    settings.ensure_dirs()
    conn = await aiosqlite.connect(settings.db_path, check_same_thread=False)
    try:
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA foreign_keys=ON")
        await conn.execute("PRAGMA synchronous=NORMAL")
        await conn.execute("PRAGMA cache_size=-65536")  # 64 MB page cache
        await _run_migrations(conn)
        await conn.commit()
    except BaseException:
        # Includes cancellation: never leave an open handle behind.
        await conn.close()
        raise
    return conn


async def _table_columns(conn: aiosqlite.Connection, table: str) -> set[str]:
    """Return the column names of *table*; empty when the table does not exist."""
    async with conn.execute(f"PRAGMA table_info({table})") as cur:
        return {row[1] for row in await cur.fetchall()}


async def _run_migrations(conn: aiosqlite.Connection) -> None:
    """Apply all schema files idempotently, then patch up legacy databases.

    Steps:
        1. Run every schema script that exists next to this module.
        2. Add columns from ``_LEGACY_COLUMNS`` that an existing table lacks.
        3. Drop temporary tables left over from failed legacy migrations.
        4. Commit, so background jobs see the clean state immediately.
    """
    base = Path(__file__).parent

    # ── STEP 1: Ensure basic tables exist ──
    for schema_file in _SCHEMA_FILES:
        path = base / schema_file
        if path.exists():
            await conn.executescript(path.read_text(encoding="utf-8"))

    # ── STEP 2: Legacy alterations ──
    for table, columns in _LEGACY_COLUMNS.items():
        present = await _table_columns(conn, table)
        if not present:  # table does not exist; nothing to alter
            continue
        for column, ddl in columns:
            if column not in present:
                await conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {ddl}")

    # ── STEP 3: Clean up lingering temporary tables ──
    # IF EXISTS already covers the "no such table" case. Only genuine
    # database errors are tolerated here, because a failed cleanup must not
    # block start-up; cancellation and programming errors still propagate.
    for leftover in _LEFTOVER_TABLES:
        try:
            await conn.execute(f"DROP TABLE IF EXISTS {leftover}")
        except aiosqlite.Error:
            pass

    await conn.commit()


async def close_db() -> None:
    """Close the shared connection, if one is open, and forget it."""
    global _db
    async with _lock:
        if _db is not None:
            await _db.close()
            _db = None
datasets ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + task TEXT NOT NULL, + format TEXT NOT NULL, + source TEXT NOT NULL DEFAULT 'roboflow', + status TEXT NOT NULL DEFAULT 'available', + images INTEGER NOT NULL DEFAULT 0, + classes INTEGER NOT NULL DEFAULT 0, + class_names TEXT NOT NULL DEFAULT '[]', -- JSON array + size_bytes INTEGER NOT NULL DEFAULT 0, + size_label TEXT NOT NULL DEFAULT '0 B', + local_path TEXT, + import_progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0 + tags TEXT NOT NULL DEFAULT '[]', -- JSON array + versions TEXT NOT NULL DEFAULT '[]', -- JSON array + active_version TEXT NOT NULL DEFAULT 'v1', + starred INTEGER NOT NULL DEFAULT 0, + roboflow_id TEXT, -- workspace/project slug + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- ── Dataset Jobs ────────────────────────────────────────── +CREATE TABLE IF NOT EXISTS dataset_jobs ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, -- import|extract|validate|analyze + status TEXT NOT NULL DEFAULT 'queued', -- queued|running|completed|failed|cancelled + dataset_id TEXT NOT NULL REFERENCES datasets(id) ON DELETE CASCADE, + dataset_name TEXT NOT NULL DEFAULT '', + progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0 + message TEXT NOT NULL DEFAULT '', + error TEXT, + meta TEXT NOT NULL DEFAULT '{}', -- JSON extra data + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + started_at TEXT, + ended_at TEXT +); + +-- ── Dataset Images Index ────────────────────────────────── +-- Populated after extraction; enables fast paginated viewer queries +CREATE TABLE IF NOT EXISTS dataset_images ( + id TEXT PRIMARY KEY, -- sha1 or sequential id + dataset_id TEXT NOT NULL REFERENCES datasets(id) ON DELETE CASCADE, + filename TEXT NOT NULL, + rel_path TEXT NOT NULL, -- relative to dataset local_path + width INTEGER NOT NULL DEFAULT 0, + height INTEGER NOT NULL DEFAULT 
0, + split TEXT NOT NULL DEFAULT 'train', + ann_count INTEGER NOT NULL DEFAULT 0, -- fast count without parsing + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- ── Dataset Annotations Cache ───────────────────────────── +-- Parsed annotations stored in normalised form for fast retrieval +CREATE TABLE IF NOT EXISTS dataset_annotations ( + id TEXT PRIMARY KEY, + image_id TEXT NOT NULL REFERENCES dataset_images(id) ON DELETE CASCADE, + dataset_id TEXT NOT NULL, + label TEXT NOT NULL, + bbox_x REAL, + bbox_y REAL, + bbox_w REAL, + bbox_h REAL, + normalised INTEGER DEFAULT 1, + area REAL, + confidence REAL, + ann_type TEXT DEFAULT 'detection', + segmentation TEXT, -- JSON array of points [[x,y],...] + keypoints TEXT, -- JSON array of keypoints [x,y,v,...] + metadata TEXT -- Extra JSON metadata +); + +-- ── Roboflow Metadata Cache ─────────────────────────────── +-- Avoids redundant API calls; TTL enforced in Python layer +CREATE TABLE IF NOT EXISTS roboflow_cache ( + cache_key TEXT PRIMARY KEY, -- workspace/project or search query hash + payload TEXT NOT NULL, -- JSON blob + fetched_at TEXT NOT NULL DEFAULT (datetime('now')), + ttl_secs INTEGER NOT NULL DEFAULT 3600 -- 1 hour default +); + +-- ── Indexes ─────────────────────────────────────────────── +CREATE INDEX IF NOT EXISTS idx_datasets_task ON datasets(task); +CREATE INDEX IF NOT EXISTS idx_datasets_format ON datasets(format); +CREATE INDEX IF NOT EXISTS idx_datasets_source ON datasets(source); +CREATE INDEX IF NOT EXISTS idx_datasets_status ON datasets(status); +CREATE INDEX IF NOT EXISTS idx_datasets_starred ON datasets(starred); + +CREATE INDEX IF NOT EXISTS idx_djobs_status ON dataset_jobs(status); +CREATE INDEX IF NOT EXISTS idx_djobs_dataset ON dataset_jobs(dataset_id); + +CREATE INDEX IF NOT EXISTS idx_dimages_dataset ON dataset_images(dataset_id); +CREATE INDEX IF NOT EXISTS idx_dimages_split ON dataset_images(dataset_id, split); + +CREATE INDEX IF NOT EXISTS idx_dann_image ON 
dataset_annotations(image_id); +CREATE INDEX IF NOT EXISTS idx_dann_dataset ON dataset_annotations(dataset_id); +CREATE INDEX IF NOT EXISTS idx_dann_label ON dataset_annotations(dataset_id, label); + +-- ── Updated-at trigger for datasets ────────────────────── +CREATE TRIGGER IF NOT EXISTS datasets_updated_at +AFTER UPDATE ON datasets BEGIN + UPDATE datasets SET updated_at = datetime('now') WHERE id = NEW.id; +END; + +CREATE TRIGGER IF NOT EXISTS dataset_jobs_updated_at +AFTER UPDATE ON dataset_jobs BEGIN + UPDATE dataset_jobs SET updated_at = datetime('now') WHERE id = NEW.id; +END; diff --git a/database/schema.sql b/database/schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..d5c39618f571c5c1faabf44aab29ad4692e074ff --- /dev/null +++ b/database/schema.sql @@ -0,0 +1,152 @@ +-- ============================================================ +-- MLForge Model Zoo — SQLite Schema +-- Version: 1.0.0 +-- ============================================================ + +PRAGMA journal_mode = WAL; +PRAGMA foreign_keys = ON; + +-- ── Models ──────────────────────────────────────────────── +CREATE TABLE IF NOT EXISTS models ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + variant TEXT, + task TEXT NOT NULL, + framework TEXT NOT NULL, + source TEXT NOT NULL DEFAULT 'hf', + provider TEXT NOT NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '', + download_url TEXT, -- explicit download source URL + size INTEGER NOT NULL DEFAULT 0, + size_label TEXT NOT NULL DEFAULT '0B', + tags TEXT NOT NULL DEFAULT '[]', -- JSON array + hardware TEXT NOT NULL DEFAULT '[]', -- JSON array + status TEXT NOT NULL DEFAULT 'available', + downloaded INTEGER NOT NULL DEFAULT 0, + active_version TEXT, + local_path TEXT, + metrics TEXT NOT NULL DEFAULT '{}', -- JSON: latency, mAP, etc. 
+ downloads INTEGER DEFAULT 0, + rating REAL, + liked INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- ── Model Versions ──────────────────────────────────────── +CREATE TABLE IF NOT EXISTS model_versions ( + version_id TEXT PRIMARY KEY, + model_id TEXT NOT NULL REFERENCES models(id) ON DELETE CASCADE, + version TEXT NOT NULL, + label TEXT NOT NULL DEFAULT 'Stable', -- Latest|Stable|Legacy + description TEXT, + metrics TEXT NOT NULL DEFAULT '{}', -- JSON: latency, mAP, etc. + local_path TEXT, + downloaded INTEGER NOT NULL DEFAULT 0, + release_date TEXT, + changelog TEXT, + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- ── Jobs ─────────────────────────────────────────────────- +CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, -- download|benchmark|sync + status TEXT NOT NULL DEFAULT 'queued', -- queued|running|completed|failed|cancelled + model_id TEXT REFERENCES models(id), + model_name TEXT, + progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0 + error TEXT, + meta TEXT NOT NULL DEFAULT '{}', -- JSON extra data + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + started_at TEXT, + ended_at TEXT +); + +-- ── Projects ───────────────────────────────────────────── +CREATE TABLE IF NOT EXISTS projects ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + path TEXT NOT NULL, + created_at TEXT NOT NULL, + last_opened TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'idle' +); + +-- ── Session ─────────────────────────────────────────────── +-- Stores the currently active project so backend services +-- (e.g. download manager) can link assets into the workspace. 
CREATE TABLE IF NOT EXISTS session (
    key         TEXT PRIMARY KEY,
    value       TEXT NOT NULL
);

CREATE UNIQUE INDEX IF NOT EXISTS idx_projects_path ON projects(path);

-- ── Audit Log ─────────────────────────────────────────────
CREATE TABLE IF NOT EXISTS audit_log (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    event_type  TEXT NOT NULL,                  -- api_request|download_start|download_ok|error|sync
    model_id    TEXT,
    job_id      TEXT,
    payload     TEXT NOT NULL DEFAULT '{}',     -- JSON
    level       TEXT NOT NULL DEFAULT 'info',
    created_at  TEXT NOT NULL DEFAULT (datetime('now'))
);

-- ── FTS (Full-Text Search) ────────────────────────────────
-- External-content FTS5 index over `models`: the index stores tokens only
-- and reads column values from `models` by rowid.
-- NOTE(review): `models` has a TEXT primary key, so its implicit rowid may
-- be renumbered by VACUUM; confirm VACUUM is never run, or rebuild the
-- index afterwards.
CREATE VIRTUAL TABLE IF NOT EXISTS models_fts USING fts5(
    id UNINDEXED,
    name,
    description,
    tags,
    provider,
    task,
    framework,
    content='models',
    content_rowid='rowid'
);

-- Triggers to keep FTS in sync
CREATE TRIGGER IF NOT EXISTS models_fts_insert AFTER INSERT ON models BEGIN
    INSERT INTO models_fts(rowid, id, name, description, tags, provider, task, framework)
    VALUES (new.rowid, new.id, new.name, new.description, new.tags, new.provider, new.task, new.framework);
END;

-- BEFORE DELETE: the content row still exists, so FTS5 can read the values
-- it has to remove from the index.
CREATE TRIGGER IF NOT EXISTS models_fts_delete BEFORE DELETE ON models BEGIN
    DELETE FROM models_fts WHERE rowid = old.rowid;
END;

-- AFTER UPDATE: the content row already holds the NEW values, so a plain
-- DELETE on the FTS table would remove the wrong tokens and leave the old
-- ones behind. The FTS5 'delete' command takes the OLD values explicitly.
-- The trigger is dropped first so databases created with the previous
-- definition pick up the corrected one (CREATE ... IF NOT EXISTS alone
-- would keep the old trigger).
DROP TRIGGER IF EXISTS models_fts_update;
CREATE TRIGGER IF NOT EXISTS models_fts_update AFTER UPDATE ON models BEGIN
    INSERT INTO models_fts(models_fts, rowid, id, name, description, tags, provider, task, framework)
    VALUES ('delete', old.rowid, old.id, old.name, old.description, old.tags, old.provider, old.task, old.framework);
    INSERT INTO models_fts(rowid, id, name, description, tags, provider, task, framework)
    VALUES (new.rowid, new.id, new.name, new.description, new.tags, new.provider, new.task, new.framework);
END;

-- ── Inference History ────────────────────────────────────
CREATE TABLE IF NOT EXISTS inference_history (
    id               TEXT PRIMARY KEY,
    model_id         TEXT NOT NULL REFERENCES models(id) ON DELETE CASCADE,
    model_name       TEXT NOT NULL,
    adapter_type     TEXT NOT NULL,
    timestamp        REAL NOT NULL DEFAULT (unixepoch('now')),
    total_ms         REAL NOT NULL DEFAULT 0.0,
    quality_score    REAL,
    status           TEXT NOT NULL DEFAULT 'ok',
    request_snapshot TEXT NOT NULL DEFAULT '{}'  -- JSON
);

CREATE INDEX IF NOT EXISTS idx_inference_model ON inference_history(model_id);
CREATE INDEX IF NOT EXISTS idx_inference_time  ON inference_history(timestamp DESC);

-- ── Indexes ───────────────────────────────────────────────
CREATE INDEX IF NOT EXISTS idx_models_task      ON models(task);
CREATE INDEX IF NOT EXISTS idx_models_framework ON models(framework);
CREATE INDEX IF NOT EXISTS idx_models_source    ON models(source);
CREATE INDEX IF NOT EXISTS idx_models_status    ON models(status);
CREATE INDEX IF NOT EXISTS idx_models_downloads ON models(downloads DESC);
CREATE INDEX IF NOT EXISTS idx_jobs_status      ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_model       ON jobs(model_id);
CREATE INDEX IF NOT EXISTS idx_audit_event      ON audit_log(event_type);
CREATE INDEX IF NOT EXISTS idx_audit_time       ON audit_log(created_at DESC);
# ── Logging bootstrap (must be first) ─────────────────────────────────────────
configure_logging()
log = get_logger("main")

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references, so an unreferenced task can be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()

# Browser origins on the local machine, with or without an explicit port.
# Single backslashes are required in a raw string: the doubled form matches
# a literal backslash and therefore never matches a real origin.
_LOCAL_ORIGIN_REGEX = r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$"


# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run start-up work before the app serves traffic and clean up afterwards.

    Start-up: create data directories, open/migrate the database, attempt
    stale-job recovery (failures are logged, not fatal), and kick off a full
    registry sync in the background when enabled and no models are registered.
    Shutdown: close the database connection.
    """
    # Startup
    settings.ensure_dirs()
    log.info("startup", host=settings.host, port=settings.port, version=settings.version)
    await get_db()  # Bootstrap DB / run migrations
    log.info("database_ready", path=str(settings.db_path))

    # Job Recovery (Cleanup stale imports/benchmarks)
    try:
        from datasets.import_service import recover_stale_jobs
        await recover_stale_jobs()
    except Exception as e:
        # Best-effort: a failed recovery must not prevent the server from starting.
        log.error("job_recovery_failed", error=str(e))

    if settings.auto_sync_on_startup:
        from registry.registry import count_models

        current = await count_models()
        if current == 0:
            from api.routes.sync import _run_full_sync

            log.info("auto_sync_startup_triggered")
            sync_task = asyncio.create_task(_run_full_sync())
            _background_tasks.add(sync_task)
            sync_task.add_done_callback(_background_tasks.discard)

    try:
        yield  # ← app runs
    finally:
        # Shutdown
        await close_db()
        log.info("shutdown")


# ── Application ───────────────────────────────────────────────────────────────
app = FastAPI(
    title=settings.app_name,
    version=settings.version,
    description="Production ML Model Zoo backend — local-first, traceable, extensible.",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception with its traceback and answer with a 500."""
    # Format from the exception object itself rather than relying on the
    # handler being invoked while the exception is still "current".
    log.error(
        "unhandled_exception",
        path=request.url.path,
        error=str(exc),
        traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
    )
    # NOTE(review): str(exc) is returned to the client and may expose internal
    # details; confirm this is acceptable outside local/desktop deployments.
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal Server Error", "error": str(exc)},
    )

# ── Middleware ─────────────────────────────────────────────────────────────────
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_origin_regex=_LOCAL_ORIGIN_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(RequestLoggingMiddleware)

# ── Routes ────────────────────────────────────────────────────────────────────
app.include_router(models_router.router)
app.include_router(jobs_router.router)
app.include_router(sync_router.router)
app.include_router(datasets_router.router)
app.include_router(benchmark_router.router)
app.include_router(system_router.router)
app.include_router(projects_router.router)
app.include_router(inference_router.router)
app.include_router(training_router.router)


@app.get("/health", tags=["system"])
async def health() -> dict:
    """Report service status, version, and the model and dataset counts."""
    from registry.registry import count_models
    from datasets.registry import count_datasets
    n_models = await count_models()
    n_datasets = await count_datasets()
    return {
        "status": "ok",
        "version": settings.version,
        "model_count": n_models,
        "dataset_count": n_datasets,
    }


# ── Dev runner ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
        log_config=None,  # We use structlog
    )
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Tag each request with a short trace id and log its outcome and timing."""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log the request, run the downstream handler, then log the response.

        The trace id is stored on ``request.state`` and echoed in the
        ``X-Trace-Id`` response header; the elapsed time is echoed in
        ``X-Response-Time``. Requests slower than 200 ms are also written
        to the audit trail as a warning.
        """
        trace_id = uuid.uuid4().hex[:8]
        request.state.trace_id = trace_id
        began = time.perf_counter()

        request_payload = {"trace_id": trace_id, "query": str(request.url.query)}
        log_system_event(
            level="info",
            message=f"API Request: {request.method} {request.url.path}",
            source="gateway",
            payload=request_payload,
        )

        response = await call_next(request)
        elapsed_ms = 1000 * (time.perf_counter() - began)
        status = response.status_code

        # Any 4xx/5xx answer is reported at error level.
        if status < 400:
            level = "info"
        else:
            level = "error"
        log_system_event(
            level=level,
            message=f"API Response: {status} ({elapsed_ms:.1f}ms)",
            source="gateway",
            payload={
                "trace_id": trace_id,
                "status": status,
                "latency_ms": round(elapsed_ms, 2),
            },
        )

        response.headers["X-Trace-Id"] = trace_id
        response.headers["X-Response-Time"] = f"{elapsed_ms:.1f}ms"

        # Audit slow requests
        if elapsed_ms > 200:
            slow_payload = {
                "path": request.url.path,
                "duration_ms": round(elapsed_ms, 2),
                "trace_id": trace_id,
            }
            await audit("slow_request", payload=slow_payload, level="warning")

        return response
b/models/__pycache__/benchmark.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea70ae3d85f9e2efb637120e8e022599437e9920 Binary files /dev/null and b/models/__pycache__/benchmark.cpython-310.pyc differ diff --git a/models/__pycache__/dataset.cpython-310.pyc b/models/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c536820915c1e5c44a508733ecae98f2f0488151 Binary files /dev/null and b/models/__pycache__/dataset.cpython-310.pyc differ diff --git a/models/__pycache__/inference.cpython-310.pyc b/models/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a083fb483a60b1986bc4acb8254e61ce6158218 Binary files /dev/null and b/models/__pycache__/inference.cpython-310.pyc differ diff --git a/models/__pycache__/job.cpython-310.pyc b/models/__pycache__/job.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8384d8a35d1a236b96a93800a30969d919ef721e Binary files /dev/null and b/models/__pycache__/job.cpython-310.pyc differ diff --git a/models/__pycache__/model.cpython-310.pyc b/models/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..837e92a89becb77993eab08089932699ea5f44fd Binary files /dev/null and b/models/__pycache__/model.cpython-310.pyc differ diff --git a/models/__pycache__/project.cpython-310.pyc b/models/__pycache__/project.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddf6d60fa8f7af302a7b9ee33f452f3e1d3280a3 Binary files /dev/null and b/models/__pycache__/project.cpython-310.pyc differ diff --git a/models/__pycache__/system.cpython-310.pyc b/models/__pycache__/system.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ec1ffbe9126a6ec88cf39433e9aa35f2ab1fba Binary files /dev/null and b/models/__pycache__/system.cpython-310.pyc differ diff --git a/models/benchmark.py 
class BenchmarkContext(BaseModel):
    """Payload the UI sends to initiate a benchmark run.

    Identifies the model/dataset pair and the execution settings; the
    optional fields carry task-specific overrides and input-source details.
    """
    # Empty protected namespaces: lets fields such as `model_id` start with
    # "model_" without pydantic's protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())
    model_id: str
    dataset_id: str
    task: str
    framework: str
    hardware: str = "cpu"
    precision: str = "FP32"
    # Validated to the inclusive range 1..512.
    batch_size: int = Field(1, ge=1, le=512)
    # Task-specific overrides
    max_tokens: int | None = 512
    sequence_length: int | None = 512
    img_size: int | None = 640
    vid_stride: int | None = 1
    stream: bool | None = False
    # NOTE(review): presumably selects between the dataset, `video_path` and
    # `rtsp_url` as the input; confirm the accepted values against the executor.
    input_source: str | None = "dataset"
    video_path: str | None = None
    rtsp_url: str | None = None
    # Object Detection live preview data
    detections: list[dict[str, Any]] = Field(default_factory=list)


# ── Validation ────────────────────────────────────────────────────────────────

class ValidationCheck(BaseModel):
    """Result of a single compatibility gate."""
    name: str
    passed: bool
    detail: str
    # Optional hint for the user; None when there is nothing to suggest.
    suggestion: str | None = None


class ValidationReport(BaseModel):
    """Aggregated result of all compatibility checks for a model+dataset pair."""
    # Same reason as BenchmarkContext: `model_id` would otherwise trigger the
    # protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())
    model_id: str
    dataset_id: str
    passed: bool                    # True only if ALL checks pass
    checks: list[ValidationCheck]
    errors: list[str]               # details from failed checks
    warnings: list[str] = Field(default_factory=list)
class BenchmarkMetrics(BaseModel):
    """Task-specific + hardware performance metrics from a completed run.

    Every field is optional because a run only fills in the metrics that
    apply to its task. Unknown keys are kept (``extra="allow"``), so metrics
    not declared here survive a round-trip.
    """
    # Pydantic v2 configuration style, matching the other models in this
    # module (replaces the legacy inner `class Config`).
    model_config = ConfigDict(extra="allow")

    # Detection / Segmentation
    mAP: float | None = None
    mAP_50: float | None = None
    mAP_50_95: float | None = None
    # Classification
    accuracy: float | None = None
    top1: float | None = None
    top5: float | None = None
    # Segmentation
    iou_mean: float | None = None
    # NLP / Generation
    rouge_l: float | None = None
    bleu: float | None = None
    perplexity: float | None = None
    tokens_per_sec: float | None = None
    # Throughput & Latency
    fps: float | None = None
    latency_mean_ms: float | None = None
    latency_p95_ms: float | None = None
    latency_p99_ms: float | None = None
    # Memory
    vram_peak_gb: float | None = None
    vram_avg_gb: float | None = None
    # Dataset info
    total_images: int | None = None
    total_tokens: int | None = None
    batch_size: int | None = None


# ── Telemetry ─────────────────────────────────────────────────────────────────

class TelemetrySample(BaseModel):
    """Single hardware reading captured during benchmark execution."""
    timestamp: float                # Unix epoch seconds
    gpu_util_pct: float = 0.0       # 0–100
    vram_used_gb: float = 0.0
    vram_total_gb: float = 0.0
    temp_c: float = 0.0
    power_w: float = 0.0
    batch_idx: int = 0
    progress: float = 0.0           # 0.0–1.0
    # Optional task-specific live data (e.g. BBoxes for detection)
    live_data: dict[str, Any] = Field(default_factory=dict)
    detections: list[dict[str, Any]] = Field(default_factory=list)


class LayerBreakdown(BaseModel):
    """Single layer entry in a bottleneck analysis."""
    name: str
    time_ms: float
    percent: float


class TelemetrySummary(BaseModel):
    """Aggregated telemetry statistics over the full benchmark run."""
    gpu_util_avg: float = 0.0
    gpu_util_peak: float = 0.0
    vram_avg_gb: float = 0.0
    vram_peak_gb: float = 0.0
    temp_avg_c: float = 0.0
    temp_peak_c: float = 0.0
    power_avg_w: float = 0.0
    power_peak_w: float = 0.0
    layer_breakdown: list[LayerBreakdown] = Field(default_factory=list)


# ── Job & Result ──────────────────────────────────────────────────────────────

class BenchmarkJob(BaseModel):
    """A benchmark run and its lifecycle state, as stored in `benchmark_jobs`."""
    # Allow `model_id` without the protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())
    id: str
    model_id: str
    dataset_id: str
    task: str
    framework: str
    hardware: str
    precision: str
    batch_size: int
    config: dict = Field(default_factory=dict)
    status: str = "queued"          # queued|running|completed|failed
    progress: float = 0.0
    logs: list[str] = Field(default_factory=list)
    # Failure message from the `error` column. Without this field the value
    # passed by `row_to_job` was silently discarded as an unknown key.
    error: str | None = None
    created_at: str | None = None
    updated_at: str | None = None
    started_at: str | None = None
    ended_at: str | None = None
    last_telemetry: TelemetrySample | None = None


class BenchmarkResult(BaseModel):
    """Final metrics and telemetry summary for one completed job."""
    model_config = ConfigDict(protected_namespaces=())
    id: str
    job_id: str
    metrics: BenchmarkMetrics
    telemetry_summary: TelemetrySummary
    created_at: str | None = None
    # Denormalized from Job for UI efficiency
    model_id: str | None = None
    dataset_id: str | None = None
    task: str | None = None
    framework: str | None = None
    hardware: str | None = None
    precision: str | None = None


# ── API Responses ─────────────────────────────────────────────────────────────

class BenchmarkRunResponse(BaseModel):
    """Response body carrying the job id, its status and a message."""
    job_id: str
    status: str
    message: str
def row_to_job(row: Any) -> BenchmarkJob:
    """Build a BenchmarkJob from a database row.

    Args:
        row: Any object accepted by ``dict(row)``.

    Returns:
        The job, with the ``config``, ``logs`` and ``last_telemetry`` columns
        decoded from JSON text. Empty or NULL JSON columns become ``{}``,
        ``[]`` and ``None`` respectively.

    Raises:
        KeyError: if a required column is absent from the row.
        json.JSONDecodeError: if a JSON column holds malformed text.
    """
    d = dict(row)

    # Read the telemetry column once; an empty/NULL value means "no sample yet".
    telemetry_raw = d.get("last_telemetry")
    last_telemetry = (
        TelemetrySample(**json.loads(telemetry_raw)) if telemetry_raw else None
    )

    return BenchmarkJob(
        id=d["id"],
        model_id=d["model_id"],
        dataset_id=d["dataset_id"],
        task=d["task"],
        framework=d["framework"],
        hardware=d["hardware"],
        precision=d["precision"],
        batch_size=d["batch_size"],
        config=json.loads(d.get("config") or "{}"),
        status=d["status"],
        # `.get(key, default)` only covers a missing key; a NULL column is
        # present with value None and float(None) raises TypeError.
        progress=float(d.get("progress") or 0.0),
        logs=json.loads(d.get("logs") or "[]"),
        error=d.get("error"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
        started_at=d.get("started_at"),
        ended_at=d.get("ended_at"),
        last_telemetry=last_telemetry,
    )
+""" +from __future__ import annotations + +import json +from datetime import datetime +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, Field, ConfigDict + + +# ── Universal Dataset Viewer (UDV) Models ────────────────────────────────── + +class DatasetContentType(str, Enum): + image = "image" + text = "text" + audio = "audio" + tabular = "tabular" + +class UniversalAnnotationType(str, Enum): + detection = "detection" + segmentation = "segmentation" + keypoints = "keypoints" + classification = "classification" + span = "span" + +class UniversalAnnotation(BaseModel): + label: str + type: UniversalAnnotationType + bbox: Optional[list[float]] = None # [x, y, w, h] normalized + segmentation: Optional[list[list[float]]] = None # [[x1, y1, x2, y2, ...], ...] + keypoints: Optional[list[float]] = None # [x1, y1, v1, ...] + confidence: Optional[float] = None + metadata: Optional[dict[str, Any]] = None + +class UniversalDatasetItem(BaseModel): + id: str + content_type: DatasetContentType + content_url: Optional[str] = None + content_body: Optional[str] = None # For text or raw json + filename: Optional[str] = None + metadata: dict[str, Any] = Field(default_factory=dict) + annotations: list[UniversalAnnotation] = Field(default_factory=list) + +class UniversalViewerPage(BaseModel): + dataset_id: str + page: int + page_size: int + total: int + total_pages: int + items: list[UniversalDatasetItem] + + +# ── Enumerations ────────────────────────────────────────────────────────────── + +class DatasetTask(str, Enum): + detection = "detection" + classification = "classification" + segmentation = "segmentation" + nlp = "nlp" + generation = "generation" + keypoints = "keypoints" + + +class DatasetFormat(str, Enum): + yolo = "yolo" + coco = "coco" + voc = "voc" + csv = "csv" + json = "json" + tfrecord = "tfrecord" + custom = "custom" + + +class DatasetSource(str, Enum): + roboflow = "roboflow" + roboflow_curl = "roboflow_curl" # direct cURL / 
class Dataset(BaseModel):
    """Full dataset record.

    ``use_enum_values=True`` stores the enum-typed fields (``task``,
    ``format``, ``source``, ``status``) as their plain string values.
    """
    model_config = ConfigDict(protected_namespaces=(), use_enum_values=True)

    id: str
    name: str
    description: str = ""
    task: DatasetTask
    format: DatasetFormat
    source: DatasetSource
    # Pydantic does not validate defaults unless asked to, so without
    # validate_default the default would stay an Enum member while supplied
    # values become strings; str() of that member is "DatasetStatus.available".
    status: DatasetStatus = Field(
        default=DatasetStatus.available, validate_default=True
    )
    images: int = 0
    classes: int = 0
    class_names: list[str] = Field(default_factory=list)
    size_bytes: int = 0
    size_label: str = "0 B"
    local_path: str | None = None
    import_progress: float = 0.0  # 0.0–1.0
    tags: list[str] = Field(default_factory=list)
    versions: list[DatasetVersion] = Field(default_factory=list)
    active_version: str = "v1"
    stats: DatasetStats = Field(default_factory=DatasetStats)
    starred: bool = False
    roboflow_id: str | None = None  # workspace/project slug
    created_at: str | None = None
    updated_at: str | None = None
+ metadata: dict[str, Any] | None = None + confidence: float | None = None + area: float | None = None + type: AnnotationType = AnnotationType.detection + + +class ImageRecord(BaseModel): + """Image + its parsed annotations — returned by viewer endpoints.""" + image_id: str + filename: str + width: int = 0 + height: int = 0 + path: str # relative to dataset root + annotations: list[Annotation] = Field(default_factory=list) + split: str = "train" # train|val|test + + +class ViewerPage(BaseModel): + """Paginated viewer response.""" + dataset_id: str + page: int + page_size: int + total: int + total_pages: int + images: list[ImageRecord] + + +# ── Job Models ──────────────────────────────────────────────────────────────── + +class DatasetJob(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + id: str + type: str + status: str + dataset_id: str + dataset_name: str + progress: float = 0.0 # 0.0–1.0 + message: str = "" + error: str | None = None + created_at: str | None = None + updated_at: str | None = None + started_at: str | None = None + ended_at: str | None = None + + +# ── Request/Response Schemas ───────────────────────────────────────────────── + +class ImportRequest(BaseModel): + dataset_id: str + source: DatasetSource + roboflow_key: str | None = None # required when source=roboflow + roboflow_workspace: str | None = None + roboflow_project: str | None = None + roboflow_version: int = 1 + hf_dataset_id: str | None = None # required when source=huggingface (e.g. 
def row_to_dataset(row: Any) -> Dataset:
    """Convert a DB row (mapping-like object or dict) to a Dataset model.

    Handles:
      1. Enum string cleaning (stripping prefixes like 'DatasetStatus.').
      2. JSON parsing for the list fields (class_names, tags, versions);
         unparsable or non-list payloads become an empty list.
      3. A 'stats' value that is a JSON string, a dict, or missing.
      4. NULL columns, which are replaced by the same defaults used for
         missing columns.

    Raises:
        KeyError: if 'id', 'name', 'task', 'format' or 'source' is absent.
        Exception: any model validation error is logged and re-raised.
    """
    import logging
    logger = logging.getLogger("models.dataset")

    def clean_enum(val: Any) -> Any:
        # "DatasetStatus.available" -> "available"; other values pass through.
        if isinstance(val, str) and "." in val:
            return val.split(".")[-1]
        return val

    def parse_list(raw: Any) -> Any:
        # JSON text -> list; NULL or bad JSON -> []; other values pass through.
        if raw is None:
            return []
        if not isinstance(raw, str):
            return raw
        try:
            parsed = json.loads(raw)
        except ValueError:
            return []
        return parsed if isinstance(parsed, list) else []

    def parse_stats(raw: Any) -> DatasetStats:
        # Best effort: fall back to empty stats, but leave a trace in the log
        # instead of discarding the failure silently.
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError as exc:
                logger.warning(f"Ignoring unparsable dataset stats: {exc}")
                return DatasetStats()
        if isinstance(raw, dict):
            try:
                return DatasetStats(**raw)
            except (TypeError, ValueError) as exc:
                logger.warning(f"Ignoring invalid dataset stats: {exc}")
        return DatasetStats()

    try:
        d = dict(row) if not isinstance(row, dict) else row.copy()

        for field in ("status", "task", "format", "source"):
            if field in d:
                d[field] = clean_enum(d[field])

        # `x or default` (not `.get(key, default)`) so that a column that is
        # present but NULL also receives the default.
        clean_data = {
            "id": d["id"],
            "name": d["name"],
            "description": d.get("description") or "",
            "task": d["task"],
            "format": d["format"],
            "source": d["source"],
            "status": d.get("status") or "available",
            "images": d.get("images") or 0,
            "classes": d.get("classes") or 0,
            "class_names": parse_list(d.get("class_names")),
            "size_bytes": d.get("size_bytes") or 0,
            "size_label": d.get("size_label") or "0 B",
            "local_path": d.get("local_path"),
            "import_progress": float(d.get("import_progress") or 0.0),
            "tags": parse_list(d.get("tags")),
            "versions": parse_list(d.get("versions")),
            "active_version": d.get("active_version") or "v1",
            "stats": parse_stats(d.get("stats")),
            "starred": bool(d.get("starred", 0)),
            "roboflow_id": d.get("roboflow_id"),
            "created_at": d.get("created_at"),
            "updated_at": d.get("updated_at"),
        }

        return Dataset(**clean_data)

    except Exception as e:
        logger.error(f"Pydantic instantiation error: {e}, row keys: {list(row.keys()) if hasattr(row, 'keys') else 'N/A'}")
        raise
CustomConfig(BaseModel): + preprocess_script: str = "" + postprocess_script: str = "" + + +class InferenceRequest(BaseModel): + model_id: str + adapter_type: AdapterType + precision: InferencePrecision = InferencePrecision.FP16 + + # Input — one of these must be set + image_base64: str | None = None # base64-encoded image + text_input: str | None = None # text/prompt + + # Per-adapter config + yolo_config: YOLOConfig | None = None + transformers_config: TransformersConfig | None = None + onnx_config: ONNXConfig | None = None + custom_config: CustomConfig | None = None + + # Execution + run_mode: Literal["single", "stream"] = "single" + + +class PipelineStage(BaseModel): + name: str + status: Literal["pending", "running", "done", "error"] = "pending" + latency_ms: float | None = None + detail: str | None = None + + +class Detection(BaseModel): + x1: float + y1: float + x2: float + y2: float + confidence: float + class_id: int + class_name: str + + +class InferenceResult(BaseModel): + # Identity + request_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + model_id: str + adapter_type: AdapterType + timestamp: float = Field(default_factory=time.time) + + # Timing + preprocess_ms: float = 0.0 + inference_ms: float = 0.0 + postprocess_ms: float = 0.0 + total_ms: float = 0.0 + + # Output — adapter-specific, all optional + detections: list[Detection] = Field(default_factory=list) + text_output: str | None = None + class_label: str | None = None + confidence: float | None = None + embeddings: list[float] | None = None + raw_output: Any = None # raw JSON for inspector + + # Pipeline trace + pipeline: list[PipelineStage] = Field(default_factory=list) + + # Quality score (0–5) derived from confidence mean + quality_score: float | None = None + + # Error + error: str | None = None + status: Literal["ok", "error"] = "ok" + + +class InferenceHistoryEntry(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + model_id: str + model_name: str + 
def row_to_job(row: Any) -> Job:
    """Build a Job from a database row.

    Args:
        row: Any object accepted by ``dict(row)``.

    Returns:
        The job, with the ``meta`` column decoded from JSON text (an empty or
        NULL column, or a JSON ``null``, becomes ``{}``).

    Raises:
        KeyError: if ``id``, ``type`` or ``status`` is absent from the row.
        json.JSONDecodeError: if ``meta`` holds malformed JSON.
    """
    d = dict(row)
    return Job(
        id=d["id"],
        type=d["type"],
        status=d["status"],
        model_id=d.get("model_id"),
        model_name=d.get("model_name"),
        # A NULL column is present with value None, which `.get(key, 0.0)`
        # returns as-is; float(None) would raise TypeError.
        progress=float(d.get("progress") or 0.0),
        error=d.get("error"),
        meta=json.loads(d.get("meta") or "{}") or {},
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
        started_at=d.get("started_at"),
        ended_at=d.get("ended_at"),
    )
class ModelMetrics(BaseModel):
    """Optional headline metrics attached to a model.

    All declared fields default to ``None``; undeclared keys are accepted and
    kept on the instance (``extra="allow"``).
    """
    # Dict-style config avoids the nested ``class Config`` form, which is
    # deprecated in Pydantic v2, and needs no additional import.
    model_config = {"extra": "allow"}

    latency_ms: float | None = None
    mAP: float | None = None
    accuracy: float | None = None
    top1: float | None = None
    vram_gb: float | None = None
    fps: float | None = None
    flops: float | None = None
def row_to_model(row: Any, versions: list[ModelVersion] | None = None) -> Model:
    """Convert a database row to a Model instance.

    Args:
        row: Any object accepted by ``dict(row)``.
        versions: Version entries to attach; ``None`` means no versions.

    Returns:
        The model, with ``tags``, ``hardware`` and ``metrics`` decoded from
        JSON text. NULL columns receive the same defaults as missing ones.

    Raises:
        KeyError: if ``id``, ``name``, ``task`` or ``framework`` is absent.
        json.JSONDecodeError: if a JSON column holds malformed text.
    """
    d = dict(row)

    # metrics may come from model_versions join or not exist on models row
    metrics_raw = d.get("metrics") or "{}"
    if isinstance(metrics_raw, str):
        # A stored JSON `null` decodes to None, which cannot be **-expanded.
        metrics_raw = json.loads(metrics_raw) or {}

    # `x or default` is used below because `.get(key, default)` returns None
    # for a column that is present but NULL.
    return Model(
        id=d["id"],
        name=d["name"],
        variant=d.get("variant"),
        task=d["task"],
        framework=d["framework"],
        source=d.get("source") or "hf",
        provider=d.get("provider") or "",
        description=d.get("description") or "",
        download_url=d.get("download_url"),
        size=d.get("size") or 0,
        size_label=d.get("size_label") or "0 B",
        tags=json.loads(d.get("tags") or "[]"),
        hardware=json.loads(d.get("hardware") or "[]"),
        status=d.get("status") or "available",
        downloaded=bool(d.get("downloaded", 0)),
        local_path=d.get("local_path"),
        project_id=d.get("project_id"),
        downloads=d.get("downloads"),
        rating=d.get("rating"),
        liked=bool(d.get("liked", 0)),
        metrics=ModelMetrics(**metrics_raw),
        versions=versions or [],
        active_version=d.get("active_version"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
    )
list[NetworkMetrics] = [] diff --git a/observability/__init__.py b/observability/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/observability/__pycache__/__init__.cpython-310.pyc b/observability/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..968ea6ed4a7ac32deb8aa855134bb652c8fede47 Binary files /dev/null and b/observability/__pycache__/__init__.cpython-310.pyc differ diff --git a/observability/__pycache__/logger.cpython-310.pyc b/observability/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d0e6a04ad4c619fefbe432b951ae910cefdd178 Binary files /dev/null and b/observability/__pycache__/logger.cpython-310.pyc differ diff --git a/observability/logger.py b/observability/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..4ce97a0e611ff56fe1155fd5b1d5cce4350cbccd --- /dev/null +++ b/observability/logger.py @@ -0,0 +1,149 @@ +""" +observability/logger.py — Structured JSON logging & audit trail. +Every module imports get_logger(); every API call writes to audit_log. 
+""" +from __future__ import annotations + +import logging +import sys +import traceback +import random +import asyncio +from datetime import datetime, timezone +from typing import Any + +import structlog +from structlog.types import EventDict, WrappedLogger + +from config import settings + + +# ── Processors ──────────────────────────────────────────────────────────────── + +def _add_timestamp( + _logger: WrappedLogger, _method: str, event_dict: EventDict +) -> EventDict: + event_dict["timestamp"] = datetime.now(timezone.utc).isoformat() + return event_dict + + +def _add_service( + _logger: WrappedLogger, _method: str, event_dict: EventDict +) -> EventDict: + event_dict["service"] = "mlforge-backend" + event_dict["version"] = settings.version + return event_dict + + +# ── Bootstrap ───────────────────────────────────────────────────────────────── + +def configure_logging() -> None: + """Call once at startup before any log is emitted.""" + settings.ensure_dirs() + + shared_processors: list[Any] = [ + structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + _add_timestamp, + _add_service, + structlog.processors.StackInfoRenderer(), + structlog.processors.format_exc_info, + ] + + # File handler — JSON + file_handler = logging.FileHandler( + settings.logs_dir / "backend.log", encoding="utf-8" + ) + file_handler.setLevel(logging.DEBUG) + + # Console handler — pretty in dev, JSON in prod + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.DEBUG if settings.debug else logging.INFO) + + logging.basicConfig( + level=logging.DEBUG, + handlers=[file_handler, console_handler], + format="%(message)s", + ) + + structlog.configure( + processors=shared_processors + + [ + structlog.stdlib.ProcessorFormatter.wrap_for_formatter, + ], + wrapper_class=structlog.stdlib.BoundLogger, + context_class=dict, + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True, + ) + + # Attach JSON renderer to handlers + 
renderer = structlog.processors.JSONRenderer() + formatter = structlog.stdlib.ProcessorFormatter( + processor=renderer, + foreign_pre_chain=shared_processors, + ) + for handler in logging.getLogger().handlers: + handler.setFormatter(formatter) + + +# ── Global System Log Queue (for Unified Dashboard) ────────────────────────── + +_sys_log_queue: asyncio.Queue = asyncio.Queue(maxsize=1000) +_sys_log_subs: list[asyncio.Queue] = [] + +def log_system_event( + level: str, + message: str, + source: str = "system", + payload: dict[str, Any] | None = None +) -> None: + """Push a structured log into the global system queue.""" + import time + + event = { + "id": f"sys-{time.time()}-{random.random()}", + "ts": time.strftime("%H:%M:%S"), + "timestamp": int(time.time() * 1000), + "level": level.upper(), + "source": source, + "message": message, + "metrics": payload, + "source_type": "system" + } + + # Broadcast to all active SSE subscribers + dead = [] + for q in _sys_log_subs: + try: + if q.qsize() >= 100: q.get_nowait() + q.put_nowait(event) + except Exception: + dead.append(q) + for d in dead: + if d in _sys_log_subs: + _sys_log_subs.remove(d) + + +def get_logger(name: str = "mlforge") -> structlog.stdlib.BoundLogger: + return structlog.get_logger(name) + + +async def audit( + event_type: str, + payload: dict[str, Any] | None = None, + model_id: str | None = None, + job_id: str | None = None, + level: str = "info", +) -> None: + """Write a structured audit record to the audit_log table.""" + import json + from database.connection import get_db + + db = await get_db() + await db.execute( + """INSERT INTO audit_log (event_type, model_id, job_id, payload, level) + VALUES (?, ?, ?, ?, ?)""", + (event_type, model_id, job_id, json.dumps(payload or {}), level), + ) + await db.commit() diff --git a/projects/__init__.py b/projects/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/projects/__pycache__/__init__.cpython-310.pyc b/projects/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32012d58ef7624a78017691a0749ea3e665d812b Binary files /dev/null and b/projects/__pycache__/__init__.cpython-310.pyc differ diff --git a/projects/__pycache__/registry.cpython-310.pyc b/projects/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67d54c8623bf91bec987d8cd7815bc46f2095fa8 Binary files /dev/null and b/projects/__pycache__/registry.cpython-310.pyc differ diff --git a/projects/__pycache__/service.cpython-310.pyc b/projects/__pycache__/service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001fc8db63f17e75d6ef644b43fe80388712d0f6 Binary files /dev/null and b/projects/__pycache__/service.cpython-310.pyc differ diff --git a/projects/registry.py b/projects/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..e541ae8fe4bdebbce6ce1a3b15f85aa61ddd0e4d --- /dev/null +++ b/projects/registry.py @@ -0,0 +1,73 @@ +"""projects/registry.py — Project persistence in SQLite.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from database.connection import get_db +from models.project import Project + + +async def upsert_project(project: Project) -> None: + db = await get_db() + await db.execute( + """INSERT INTO projects (id, name, path, created_at, last_opened, status) + VALUES (?,?,?,?,?,?) 
async def delete_project(project_id: str) -> bool:
    """Delete the project row with the given id.

    Returns:
        True when the DELETE statement reports at least one affected row,
        False otherwise.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))
    await conn.commit()
    # rowcount may be falsy; treat that as zero rows affected.
    affected = cursor.rowcount or 0
    return affected > 0
def _format_size_label(size_bytes: int) -> str:
    """Render a byte count as "<x.y> MB" from 1 MiB upwards, "<x.y> KB" below."""
    mib = 1024 * 1024
    # ">=" so that exactly 1 MiB reads "1.0 MB" rather than "1024.0 KB".
    if size_bytes >= mib:
        return f"{size_bytes / mib:.1f} MB"
    return f"{size_bytes / 1024:.1f} KB"


async def import_local_model(
    name: str,
    task: str,
    framework: str,
    source_file_path: str
) -> Model:
    """Import a local model file into the active project.

    The file is copied to ``<project>/models/<generated id>/<file name>`` and a
    matching ``Model`` record (source ``"local"``, status ``"cached"``) is
    upserted into the registry.

    Args:
        name: Display name stored on the model record.
        task: Task label stored on the model record.
        framework: Framework label stored on the model record.
        source_file_path: Path of the model file to copy into the project.

    Returns:
        The newly registered ``Model``.

    Raises:
        ValueError: No active project id/path is stored in the session.
        FileNotFoundError: ``source_file_path`` does not exist.
        IsADirectoryError: ``source_file_path`` is a directory.
        OSError: The copy into the project folder failed.
    """
    # Function-scope import: only this block needs asyncio.
    import asyncio

    project_id = await get_active_project_id()
    project_path = await get_active_project_path()

    if not project_id or not project_path:
        raise ValueError("No active project found. Please open a project first.")

    src = Path(source_file_path)
    if not src.exists():
        raise FileNotFoundError(f"Source model file not found: {source_file_path}")
    if src.is_dir():
        # Reject before touching the project tree; copy2 would fail on a
        # directory anyway, but only after the destination folder was created.
        raise IsADirectoryError(
            f"Source model path is a directory, not a file: {source_file_path}"
        )

    # Create destination directory in project
    model_id = f"local-{uuid.uuid4().hex[:12]}"
    dest_dir = Path(project_path) / "models" / model_id
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest_path = dest_dir / src.name

    try:
        # Copy in a worker thread so a large file does not stall the event loop.
        await asyncio.to_thread(shutil.copy2, src, dest_path)
    except OSError:
        # dest_dir is unique to this import, so removing it cannot hit other data.
        shutil.rmtree(dest_dir, ignore_errors=True)
        raise

    # Calculate size
    size_bytes = dest_path.stat().st_size
    size_label = _format_size_label(size_bytes)

    # Create model entry
    now = datetime.now(timezone.utc).isoformat()
    model = Model(
        id=model_id,
        name=name,
        task=task,
        framework=framework,
        source="local",
        provider="Local Import",
        size=size_bytes,
        size_label=size_label,
        local_path=str(dest_path),
        project_id=project_id,
        downloaded=True,
        status="cached",
        created_at=now,
        updated_at=now,
        metrics=ModelMetrics(),
    )

    await upsert_model(model)
    await audit("model_imported_locally", model_id=model_id, payload={"name": name, "project_id": project_id})
    log.info("model_imported_locally", model_id=model_id, name=name, path=str(dest_path))

    return model
# ── Session helpers ───────────────────────────────────────────────────────────

async def set_active_project(project_id: str, project_path: str) -> None:
    """Persist the currently open project in the session table.

    Writes the ``active_project_id`` and ``active_project_path`` keys (insert
    or overwrite) and commits once after both writes.
    """
    db = await get_db()
    upsert_sql = (
        "INSERT INTO session (key, value) VALUES (?, ?)"
        " ON CONFLICT(key) DO UPDATE SET value=excluded.value"
    )
    entries = (
        ("active_project_id", project_id),
        ("active_project_path", project_path),
    )
    for key, value in entries:
        await db.execute(upsert_sql, (key, value))
    await db.commit()
    log.info("active_project_set", project_id=project_id, path=project_path)


async def _get_session_value(key: str) -> str | None:
    """Return the ``value`` stored under *key* in the session table, or None."""
    db = await get_db()
    async with db.execute("SELECT value FROM session WHERE key = ?", (key,)) as cur:
        row = await cur.fetchone()
    return row["value"] if row else None


async def get_active_project_id() -> str | None:
    """Return the ID of the currently open project, or None."""
    return await _get_session_value("active_project_id")


async def get_active_project_path() -> str | None:
    """Return the filesystem path of the currently open project, or None."""
    return await _get_session_value("active_project_path")


# ── Workspace model linking ───────────────────────────────────────────────────

def _discard_partial_copy(target: Path) -> None:
    """Best-effort removal of whatever a failed copy left behind at *target*."""
    try:
        if target.is_dir() and not target.is_symlink():
            shutil.rmtree(target, ignore_errors=True)
        else:
            target.unlink(missing_ok=True)
    except OSError as exc:
        # Cleanup must never turn a best-effort link into a failure.
        log.debug("link_cleanup_failed", path=str(target), error=str(exc))


async def link_model_to_active_project(model_id: str, source_path: str) -> None:
    """Copy the downloaded model file into the active project's models/ folder.

    This is a best-effort operation — if no project is open, or if creating the
    folder or copying fails for any reason, we log and continue rather than
    failing the download.

    Args:
        model_id: Used as the sub-folder name under ``<project>/models/``.
        source_path: Path of the downloaded model file.
    """
    # Function-scope import: only the linking helpers need asyncio.
    import asyncio

    project_path = await get_active_project_path()
    if not project_path:
        log.debug("link_model_skipped_no_project", model_id=model_id)
        return

    src = Path(source_path)
    if not src.exists():
        log.warning("link_model_source_missing", model_id=model_id, path=source_path)
        return

    dest_dir = Path(project_path) / "models" / model_id
    dest = dest_dir / src.name

    if dest.exists():
        log.debug("link_model_already_exists", model_id=model_id, dest=str(dest))
        return

    try:
        # mkdir sits inside the try so that an unwritable project folder is
        # logged like any other copy failure instead of propagating.
        dest_dir.mkdir(parents=True, exist_ok=True)
        # Worker thread: model files can be large and would block the loop.
        await asyncio.to_thread(shutil.copy2, src, dest)
    except OSError as exc:
        # A truncated file would otherwise satisfy the dest.exists() check
        # above on the next call and never be repaired.
        _discard_partial_copy(dest)
        log.warning("link_model_copy_failed", model_id=model_id, error=str(exc))
        return

    log.info("model_linked_to_project", model_id=model_id, project=project_path, dest=str(dest))


async def link_dataset_to_active_project(dataset_id: str, source_path: str) -> Path | None:
    """Copy the imported dataset folder into the active project's datasets/ folder.

    This is a best-effort operation — if no project is open, or if the copy
    fails for any reason, we log and continue rather than failing the import.

    Args:
        dataset_id: Used as the entry name under ``<project>/datasets/``.
        source_path: Dataset directory, or a single file (e.g. an archive).

    Returns:
        The destination path when a copy was made by this call; None when
        nothing was copied (no project, missing source, destination already
        present, or copy failure).
    """
    # Function-scope import: only the linking helpers need asyncio.
    import asyncio

    project_path = await get_active_project_path()
    if not project_path:
        log.debug("link_dataset_skipped_no_project", dataset_id=dataset_id)
        return None

    src = Path(source_path)
    if not src.exists():
        log.warning("link_dataset_source_missing", dataset_id=dataset_id, path=source_path)
        return None

    dest_dir = Path(project_path) / "datasets" / dataset_id

    if dest_dir.exists():
        log.debug("link_dataset_already_exists", dataset_id=dataset_id, dest=str(dest_dir))
        return None

    try:
        if src.is_dir():
            await asyncio.to_thread(shutil.copytree, src, dest_dir, dirs_exist_ok=True)
        else:
            # If it's a file (e.g. zip that wasn't extracted yet), just copy it
            dest_dir.parent.mkdir(parents=True, exist_ok=True)
            await asyncio.to_thread(shutil.copy2, src, dest_dir)
    except OSError as exc:
        # dest_dir did not exist before this call, so anything there now is a
        # partial copy; leaving it would make later calls skip the dataset.
        _discard_partial_copy(dest_dir)
        log.warning("link_dataset_copy_failed", dataset_id=dataset_id, error=str(exc))
        return None

    log.info("dataset_linked_to_project", dataset_id=dataset_id, project=project_path, dest=str(dest_dir))
    return dest_dir
def _fts_prefix_query(search: str) -> str:
    """Build an FTS5 prefix query that matches *search* as one quoted string.

    Embedded double quotes are doubled, which is how FTS5 escapes them inside a
    quoted string; without this a search containing ``"`` is a MATCH syntax error.
    """
    escaped = search.strip().replace('"', '""')
    return f'"{escaped}"*'


def _row_to_version(r: Any) -> ModelVersion:
    """Convert one ``model_versions`` row into a ``ModelVersion``.

    Optional columns are looked up defensively via ``r.keys()`` so rows that
    lack them still convert.
    """
    columns = r.keys()
    return ModelVersion(
        version=r["version"],
        label=r["label"],
        description=r["description"] if "description" in columns else None,
        releaseDate=r["release_date"] if "release_date" in columns and r["release_date"] else "",
        changelog=r["changelog"] if "changelog" in columns else None,
    )


async def upsert_model(model: Model) -> None:
    """Insert or update a model record together with its version rows.

    ``tags`` and ``hardware`` are stored as JSON text, ``metrics`` as the JSON
    dump of the metrics object. One commit covers the model and all versions.
    """
    db = await get_db()
    now = datetime.now(timezone.utc).isoformat()

    # NOTE(review): the conflict branch does not touch downloaded, liked or
    # created_at, but it does overwrite local_path and project_id with the
    # incoming values. Presumably intentional; confirm that sync callers do
    # not pass local_path=None for models that are already downloaded.
    await db.execute(
        """INSERT INTO models
           (id, name, variant, task, framework, source, provider, description,
            download_url, size, size_label, tags, hardware, status, downloaded, local_path, project_id,
            active_version, metrics,
            downloads, rating, liked, created_at, updated_at)
           VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
           ON CONFLICT(id) DO UPDATE SET
             name=excluded.name,
             variant=excluded.variant,
             task=excluded.task,
             framework=excluded.framework,
             source=excluded.source,
             provider=excluded.provider,
             description=excluded.description,
             download_url=excluded.download_url,
             size=excluded.size,
             size_label=excluded.size_label,
             tags=excluded.tags,
             hardware=excluded.hardware,
             status=excluded.status,
             downloads=excluded.downloads,
             rating=excluded.rating,
             active_version=excluded.active_version,
             metrics=excluded.metrics,
             local_path=excluded.local_path,
             project_id=excluded.project_id,
             updated_at=excluded.updated_at""",
        (
            model.id, model.name, model.variant, model.task, model.framework,
            model.source, model.provider, model.description, model.download_url,
            model.size, model.size_label, json.dumps(model.tags), json.dumps(model.hardware),
            model.status, int(model.downloaded), model.local_path, model.project_id,
            model.active_version, model.metrics.model_dump_json(),
            model.downloads, model.rating, int(model.liked),
            model.created_at or now, now,
        ),
    )

    # Upsert versions; the row key is "<model id>_<version>".
    for v in model.versions:
        version_id = f"{model.id}_{v.version}"
        await db.execute(
            """INSERT INTO model_versions
               (version_id, model_id, version, label, description, metrics, release_date, changelog)
               VALUES (?,?,?,?,?,?,?,?)
               ON CONFLICT(version_id) DO UPDATE SET
                 label=excluded.label, description=excluded.description,
                 release_date=excluded.release_date, changelog=excluded.changelog""",
            (
                version_id, model.id, v.version, v.label, v.description,
                json.dumps({}), v.releaseDate, v.changelog,
            ),
        )
    await db.commit()


async def bulk_upsert(models: list[Model]) -> None:
    """Batch upsert for sync operations.

    Each model goes through ``upsert_model`` (one commit per model); the total
    is then logged and written to the audit trail.
    """
    inserted = 0
    for model in models:
        await upsert_model(model)
        inserted += 1
    log.info("registry_bulk_upsert", total=inserted)
    await audit("registry_sync", payload={"count": inserted})


async def get_model(model_id: str) -> Model | None:
    """Return the model with *model_id* including its versions, or None if absent."""
    db = await get_db()
    async with db.execute("SELECT * FROM models WHERE id = ?", (model_id,)) as cur:
        row = await cur.fetchone()
    if not row:
        return None

    # Fetch versions
    async with db.execute(
        "SELECT * FROM model_versions WHERE model_id = ? ORDER BY created_at DESC",
        (model_id,),
    ) as cur:
        version_rows = await cur.fetchall()

    versions = [_row_to_version(r) for r in version_rows]
    return row_to_model(row, versions)


async def list_models(
    *,
    tasks: list[str] | None = None,
    frameworks: list[str] | None = None,
    hardware: list[str] | None = None,
    sources: list[str] | None = None,
    downloaded: bool | None = None,
    sort_by: str = "downloads",
    sort_dir: str = "desc",
    limit: int = 500,
    offset: int = 0,
    search: str | None = None,
    project_id: str | None = None,
) -> list[Model]:
    """Return one page of models matching every supplied filter.

    Args:
        tasks, frameworks, sources: Keep models whose column equals any listed value.
        hardware: Keep models whose stored hardware text contains any listed value.
        downloaded: When not None, keep only models with that download flag.
        sort_by: One of ``downloads``, ``name``, ``size``, ``rating``,
            ``created``; anything else falls back to ``downloads``.
        sort_dir: ``"desc"`` (any letter case) for descending, otherwise ascending.
        limit, offset: Page size and start.
        search: Free text matched as a prefix query against ``models_fts``.
        project_id: Keep models of that project plus models with no project.

    Returns:
        The models of the page, each with its versions attached.
    """
    db = await get_db()

    # ── WHERE conditions ──────────────────────────────────────────────
    conditions: list[str] = []
    params: list[Any] = []

    # FTS5 subquery — valid SQLite syntax
    if search and search.strip():
        conditions.append(
            "m.id IN (SELECT id FROM models_fts WHERE models_fts MATCH ?)"
        )
        params.append(_fts_prefix_query(search))

    # Plain "column IN (...)" filters share one shape.
    for column, values in (
        ("m.task", tasks),
        ("m.framework", frameworks),
        ("m.source", sources),
    ):
        if values:
            placeholders = ",".join(["?"] * len(values))
            conditions.append(f"{column} IN ({placeholders})")
            params.extend(values)

    if hardware:
        # hardware is stored as JSON text, so membership is a substring match.
        hw_conds = ["m.hardware LIKE ?" for _ in hardware]
        conditions.append(f"({' OR '.join(hw_conds)})")
        params.extend([f"%{h}%" for h in hardware])

    if downloaded is not None:
        conditions.append("m.downloaded = ?")
        params.append(int(downloaded))

    if project_id:
        conditions.append("(m.project_id = ? OR m.project_id IS NULL)")
        params.append(project_id)

    where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    # ── Sort ──────────────────────────────────────────────────────────
    # Only whitelisted column names are ever interpolated into the SQL text.
    sort_col_map = {
        "downloads": "m.downloads",
        "name": "m.name",
        "size": "m.size",
        "rating": "m.rating",
        "created": "m.created_at",
    }
    col = sort_col_map.get(sort_by, "m.downloads")
    # Case-insensitive so that "DESC" is not silently treated as ascending.
    direction = "DESC" if sort_dir.lower() == "desc" else "ASC"

    sql = f"""
        SELECT m.* FROM models m
        {where_clause}
        ORDER BY {col} {direction} NULLS LAST
        LIMIT ? OFFSET ?
    """

    async with db.execute(sql, params + [limit, offset]) as cur:
        rows = await cur.fetchall()

    models = [row_to_model(row, []) for row in rows]
    if not models:
        return models

    # One extra query loads the versions of every model on the page.
    ids = [m.id for m in models]
    placeholders = ",".join(["?"] * len(ids))
    async with db.execute(
        f"SELECT * FROM model_versions WHERE model_id IN ({placeholders}) ORDER BY created_at DESC",
        ids,
    ) as cur:
        vrows = await cur.fetchall()

    by_model: dict[str, list[ModelVersion]] = {}
    for r in vrows:
        by_model.setdefault(r["model_id"], []).append(_row_to_version(r))

    return [m.model_copy(update={"versions": by_model.get(m.id, [])}) for m in models]


async def update_model_status(
    model_id: str,
    *,
    status: str | None = None,
    downloaded: bool | None = None,
    local_path: str | None = None,
) -> None:
    """Update the given status fields of one model and bump ``updated_at``.

    Arguments left as None are not written, so this function cannot reset a
    column to NULL.
    """
    db = await get_db()
    now = datetime.now(timezone.utc).isoformat()
    parts: list[str] = ["updated_at = ?"]
    vals: list[Any] = [now]
    if status is not None:
        parts.append("status = ?")
        vals.append(status)
    if downloaded is not None:
        parts.append("downloaded = ?")
        vals.append(int(downloaded))
    if local_path is not None:
        parts.append("local_path = ?")
        vals.append(local_path)
    vals.append(model_id)
    # parts holds only the fixed fragments above; every value is bound as a parameter.
    await db.execute(f"UPDATE models SET {', '.join(parts)} WHERE id = ?", vals)
    await db.commit()


async def count_models() -> int:
    """Return the total number of rows in the models table."""
    db = await get_db()
    async with db.execute("SELECT COUNT(*) FROM models") as cur:
        row = await cur.fetchone()
    return row[0] if row else 0