Spaces:
Sleeping
Sleeping
Commit ·
ac5551d
0
Parent(s):
Deploy cloud brain to HF Spaces
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +29 -0
- api/__init__.py +0 -0
- api/__pycache__/__init__.cpython-310.pyc +0 -0
- api/routes/__init__.py +0 -0
- api/routes/__pycache__/__init__.cpython-310.pyc +0 -0
- api/routes/__pycache__/benchmark.cpython-310.pyc +0 -0
- api/routes/__pycache__/datasets.cpython-310.pyc +0 -0
- api/routes/__pycache__/inference.cpython-310.pyc +0 -0
- api/routes/__pycache__/jobs.cpython-310.pyc +0 -0
- api/routes/__pycache__/models.cpython-310.pyc +0 -0
- api/routes/__pycache__/projects.cpython-310.pyc +0 -0
- api/routes/__pycache__/sync.cpython-310.pyc +0 -0
- api/routes/__pycache__/system.cpython-310.pyc +0 -0
- api/routes/__pycache__/training.cpython-310.pyc +0 -0
- api/routes/benchmark.py +238 -0
- api/routes/datasets.py +395 -0
- api/routes/inference.py +168 -0
- api/routes/jobs.py +56 -0
- api/routes/models.py +127 -0
- api/routes/projects.py +54 -0
- api/routes/sync.py +73 -0
- api/routes/system.py +97 -0
- api/routes/training.py +428 -0
- config.py +83 -0
- database/__init__.py +0 -0
- database/__pycache__/__init__.cpython-310.pyc +0 -0
- database/__pycache__/connection.cpython-310.pyc +0 -0
- database/benchmark_schema.sql +62 -0
- database/connection.py +106 -0
- database/dataset_schema.sql +117 -0
- database/schema.sql +152 -0
- main.py +149 -0
- middleware/__init__.py +0 -0
- middleware/__pycache__/__init__.cpython-310.pyc +0 -0
- middleware/__pycache__/logging_middleware.cpython-310.pyc +0 -0
- middleware/logging_middleware.py +57 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/benchmark.cpython-310.pyc +0 -0
- models/__pycache__/dataset.cpython-310.pyc +0 -0
- models/__pycache__/inference.cpython-310.pyc +0 -0
- models/__pycache__/job.cpython-310.pyc +0 -0
- models/__pycache__/model.cpython-310.pyc +0 -0
- models/__pycache__/project.cpython-310.pyc +0 -0
- models/__pycache__/system.cpython-310.pyc +0 -0
- models/benchmark.py +223 -0
- models/dataset.py +401 -0
- models/inference.py +142 -0
- models/job.py +51 -0
- models/model.py +129 -0
Dockerfile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use a lightweight Python image
FROM python:3.10-slim

# Set environment variables
# PYTHONUNBUFFERED=1 makes Python write stdout/stderr unbuffered so logs show up immediately.
# NOTE(review): PORT is exported for the app's benefit only; the CMD below hard-codes 7860.
ENV PYTHONUNBUFFERED=1
ENV PORT=7860

# Set working directory
WORKDIR /app

# Install system dependencies
# --no-install-recommends keeps optional packages out of the image; the apt
# package lists are removed in the same layer so they never reach the final image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install
# Note: requirements.txt should be in the same directory as Dockerfile (backend/)
# Copying it before the rest of the code lets Docker reuse the pip layer when
# only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the backend code
COPY . .

# HuggingFace Spaces uses port 7860 by default
EXPOSE 7860

# Run the FastAPI app
# We use 0.0.0.0 to allow external connections within the container
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
api/__init__.py
ADDED
|
File without changes
|
api/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (138 Bytes). View file
|
|
|
api/routes/__init__.py
ADDED
|
File without changes
|
api/routes/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
api/routes/__pycache__/benchmark.cpython-310.pyc
ADDED
|
Binary file (6.16 kB). View file
|
|
|
api/routes/__pycache__/datasets.cpython-310.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
api/routes/__pycache__/inference.cpython-310.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
api/routes/__pycache__/jobs.cpython-310.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
api/routes/__pycache__/models.cpython-310.pyc
ADDED
|
Binary file (3.69 kB). View file
|
|
|
api/routes/__pycache__/projects.cpython-310.pyc
ADDED
|
Binary file (2.2 kB). View file
|
|
|
api/routes/__pycache__/sync.cpython-310.pyc
ADDED
|
Binary file (2.48 kB). View file
|
|
|
api/routes/__pycache__/system.cpython-310.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
api/routes/__pycache__/training.cpython-310.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
api/routes/benchmark.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
api/routes/benchmark.py — Benchmark Bridge REST + WebSocket API.

Routes:
    POST /benchmark/validate        — compatibility check (no job created)
    POST /benchmark/run             — validate + create + enqueue (202)
    POST /benchmark/sync            — index the active project's benchmark records
    GET  /benchmark/jobs            — list jobs (filterable)
    GET  /benchmark/results/all     — list all results
    GET  /benchmark/{job_id}        — single job status + logs
    GET  /benchmark/{job_id}/result — metrics + telemetry for completed job
    WS   /benchmark/live/{job_id}   — real-time progress stream
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from fastapi import APIRouter, HTTPException, Query, WebSocket, WebSocketDisconnect
|
| 19 |
+
|
| 20 |
+
import benchmark.orchestrator as orchestrator
|
| 21 |
+
import benchmark.registry as bench_reg
|
| 22 |
+
from models.benchmark import (
|
| 23 |
+
BenchmarkContext,
|
| 24 |
+
BenchmarkJob,
|
| 25 |
+
BenchmarkResult,
|
| 26 |
+
BenchmarkRunResponse,
|
| 27 |
+
ValidationReport,
|
| 28 |
+
)
|
| 29 |
+
from observability.logger import get_logger
|
| 30 |
+
|
| 31 |
+
log = get_logger("api.benchmark")
|
| 32 |
+
|
| 33 |
+
router = APIRouter(prefix="/benchmark", tags=["benchmark"])
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ── POST /benchmark/validate ──────────────────────────────────────────────────
|
| 37 |
+
|
| 38 |
+
@router.post(
    "/validate",
    response_model=ValidationReport,
    summary="Validate model ↔ dataset compatibility",
    description=(
        "Runs all 5 compatibility gates (task, format, framework×hardware, "
        "VRAM, precision) and returns a structured report. "
        "Does NOT create a benchmark job."
    ),
)
async def validate_benchmark(ctx: BenchmarkContext) -> ValidationReport:
    """Run the compatibility checks for *ctx* without creating a job.

    The work is delegated to ``orchestrator.validate_context``.

    Args:
        ctx: Benchmark request body to validate.

    Returns:
        The report produced by the orchestrator.

    Raises:
        HTTPException: Re-raised unchanged when the orchestrator raises one.
            Any other exception is logged and reported as a 500 whose detail
            is the exception text.
    """
    try:
        return await orchestrator.validate_context(ctx)
    except HTTPException:
        # Already an HTTP-shaped error; let FastAPI render it as-is.
        raise
    except Exception as exc:
        log.exception("validate_error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ── POST /benchmark/run ───────────────────────────────────────────────────────
|
| 59 |
+
|
| 60 |
+
@router.post(
    "/run",
    response_model=BenchmarkRunResponse,
    status_code=202,
    summary="Start a benchmark run",
    description=(
        "Validates compatibility, creates a benchmark job, and starts async "
        "execution. Returns job_id immediately — poll GET /benchmark/{job_id} "
        "or connect to WS /benchmark/live/{job_id} for progress."
    ),
)
async def run_benchmark(ctx: BenchmarkContext) -> BenchmarkRunResponse:
    """Create a benchmark job for *ctx* and answer with its id and status.

    The job is created by ``orchestrator.create_and_run``; this handler only
    wraps the returned job in a ``BenchmarkRunResponse``.

    Args:
        ctx: Benchmark request body.

    Returns:
        Response carrying the job id, its current status and a hint pointing
        at the live WebSocket endpoint.

    Raises:
        HTTPException: Re-raised unchanged when the orchestrator raises one.
            Any other exception is logged and reported as a 500 whose detail
            is the exception text.
    """
    try:
        job = await orchestrator.create_and_run(ctx)
        return BenchmarkRunResponse(
            job_id=job.id,
            status=job.status,
            message=(
                f"Benchmark job {job.id} queued — connect to "
                f"/benchmark/live/{job.id} for live telemetry"
            ),
        )
    except HTTPException:
        raise
    except Exception as exc:
        log.exception("run_benchmark_error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ── POST /benchmark/sync ──────────────────────────────────────────────────────
|
| 87 |
+
|
| 88 |
+
@router.post(
    "/sync",
    summary="Sync benchmarks from active project folder",
    description="Scans the active project's 'benchmarks' folder and ensures all JSON records are indexed in SQLite.",
)
async def sync_benchmarks() -> dict[str, Any]:
    """Trigger ``orchestrator.sync_project_benchmarks`` and report its count.

    Returns:
        ``{"status": "success", "count": <value returned by the orchestrator>}``.

    Raises:
        HTTPException: Re-raised unchanged when the orchestrator raises one
            (matching the other handlers in this module). Any other exception
            is logged and reported as a 500 whose detail is the exception text.
    """
    try:
        count = await orchestrator.sync_project_benchmarks()
    except HTTPException:
        # Keep the orchestrator's own status code instead of masking it as 500.
        raise
    except Exception as exc:
        log.exception("sync_error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"status": "success", "count": count}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ── GET /benchmark/jobs ───────────────────────────────────────────────────────
|
| 103 |
+
|
| 104 |
+
@router.get("/jobs", response_model=list[BenchmarkJob], summary="List benchmark jobs")
async def list_jobs(
    status: str | None = Query(None, description="Filter by status (queued|running|completed|failed)"),
    model_id: str | None = Query(None, description="Filter by model_id"),
    limit: int = Query(100, ge=1, le=500),
) -> list[BenchmarkJob]:
    """Return benchmark jobs from the registry.

    All three query parameters are forwarded as keyword arguments to
    ``bench_reg.list_jobs``; filtering itself happens there.

    Args:
        status: Optional status filter.
        model_id: Optional model filter.
        limit: Maximum number of jobs requested (1-500, default 100).
    """
    filters = {"status": status, "model_id": model_id, "limit": limit}
    return await bench_reg.list_jobs(**filters)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ── GET /benchmark/results/all ────────────────────────────────────────────────
|
| 118 |
+
# Must be declared BEFORE /{job_id} to avoid "results" being treated as a job_id
|
| 119 |
+
|
| 120 |
+
@router.get(
    "/results/all",
    response_model=list[BenchmarkResult],
    summary="List all benchmark results (leaderboard feed)",
)
async def list_results(limit: int = Query(100, ge=1, le=500)) -> list[BenchmarkResult]:
    """Return stored benchmark results from ``bench_reg.list_results``.

    Args:
        limit: Maximum number of results requested (1-500, default 100).
    """
    results = await bench_reg.list_results(limit=limit)
    return results
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ── GET /benchmark/{job_id} ───────────────────────────────────────────────────
|
| 132 |
+
|
| 133 |
+
@router.get("/{job_id}", response_model=BenchmarkJob, summary="Get benchmark job status + logs")
async def get_job(job_id: str) -> BenchmarkJob:
    """Look up one benchmark job by id.

    Raises:
        HTTPException: 404 when the registry returns a falsy value for *job_id*.
    """
    found = await bench_reg.get_job(job_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# ── GET /benchmark/{job_id}/result ────────────────────────────────────────────
|
| 146 |
+
|
| 147 |
+
@router.get(
    "/{job_id}/result",
    response_model=BenchmarkResult,
    summary="Get final metrics + telemetry for a completed job",
)
async def get_result(job_id: str) -> BenchmarkResult:
    """Return the stored result for *job_id*.

    Raises:
        HTTPException: 404 when ``bench_reg.get_result`` returns a falsy
            value, e.g. because no result has been recorded for the job yet.
    """
    result = await bench_reg.get_result(job_id)
    if not result:
        raise HTTPException(
            status_code=404,
            detail=f"No result for job '{job_id}' — job may still be running",
        )
    return result
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ── WS /benchmark/live/{job_id} ───────────────────────────────────────────────
|
| 163 |
+
|
| 164 |
+
@router.websocket("/live/{job_id}")
async def live_telemetry(websocket: WebSocket, job_id: str) -> None:
    """Push benchmark progress for *job_id* over a WebSocket.

    The handler re-reads the job from the registry every 0.5 s and sends one
    JSON message per poll containing status, progress, the log entries added
    since the previous poll, and the latest telemetry snapshot. The loop ends
    when the job is missing, ``completed`` (a final message with the stored
    result is sent if one exists) or ``failed`` (a final message with the
    error is sent). The socket is always closed on exit.
    """
    await websocket.accept()
    log.info("ws_connected", job_id=job_id)

    # Count of log entries already delivered, so each poll sends only the tail.
    sent_logs = 0

    try:
        while True:
            job = await bench_reg.get_job(job_id)
            if not job:
                await websocket.send_json(
                    {"error": f"Job '{job_id}' not found", "job_id": job_id}
                )
                break

            fresh_logs = job.logs[sent_logs:]
            sent_logs = len(job.logs)

            telemetry = job.last_telemetry
            update: dict[str, Any] = {
                "job_id": job.id,
                "status": job.status,
                "progress": round(job.progress, 4),
                "logs": fresh_logs,
                "telemetry": telemetry.model_dump() if telemetry else None,
            }
            # Surface detections at the top level for the UI visualizer when present.
            if telemetry and hasattr(telemetry, "detections"):
                update["detections"] = telemetry.detections

            await websocket.send_json(update)

            if job.status == "completed":
                final = await bench_reg.get_result(job_id)
                if final:
                    await websocket.send_json(
                        {"job_id": job_id, "status": "completed", "result": final.model_dump()}
                    )
                break

            if job.status == "failed":
                await websocket.send_json(
                    {"job_id": job_id, "status": "failed", "error": job.error or "Unknown error"}
                )
                break

            await asyncio.sleep(0.5)  # poll at 2 Hz

    except WebSocketDisconnect:
        log.info("ws_disconnected", job_id=job_id)
    except Exception as exc:
        log.exception("ws_error", job_id=job_id)
        # Best effort: the socket may already be unusable.
        try:
            await websocket.send_json({"error": str(exc), "job_id": job_id})
        except Exception:
            pass
    finally:
        # Best effort: closing an already-closed socket must not raise.
        try:
            await websocket.close()
        except Exception:
            pass
|
api/routes/datasets.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
api/routes/datasets.py — Dataset Manager REST API.

Routes:
    GET    /datasets                       — list/search datasets
    GET    /datasets/{id}                  — dataset detail
    POST   /datasets/search/roboflow       — search Roboflow Universe (real-time)
    POST   /datasets/sync/roboflow         — sync workspace datasets to local DB
    POST   /datasets/{id}/import           — initiate dataset import job
    GET    /datasets/{id}/universal        — paginated universal item viewer
    GET    /datasets/{id}/images           — paginated viewer (images + annotations)
    GET    /datasets/{id}/stats            — dataset statistics
    GET    /datasets/{id}/image/{img}      — serve raw image bytes
    GET    /datasets/{id}/annotations      — class distribution summary
    GET    /datasets/jobs                  — list import jobs
    GET    /datasets/jobs/{job_id}         — single job status
    POST   /datasets/jobs/{job_id}/stop    — mark an import job as cancelled
    POST   /datasets/jobs/{job_id}/pause   — mark an import job as paused
    POST   /datasets/jobs/{job_id}/resume  — mark an import job as running
    POST   /datasets/{id}/star             — toggle starred
    DELETE /datasets/{id}                  — delete dataset record (+ local files)
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Optional
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
|
| 23 |
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 24 |
+
from fastapi.responses import FileResponse, Response
|
| 25 |
+
|
| 26 |
+
from adapters.roboflow_adapter import RoboflowAdapter
|
| 27 |
+
from datasets import registry as ds_reg
|
| 28 |
+
from datasets.import_service import start_import
|
| 29 |
+
from datasets.viewer_service import get_universal_viewer_page, get_viewer_page, resolve_image_path
|
| 30 |
+
from models.dataset import (
|
| 31 |
+
Dataset, DatasetJob, DatasetSummary, DatasetSource, DatasetTask,
|
| 32 |
+
DatasetFormat, DatasetStatus, ImportRequest, ImportResponse,
|
| 33 |
+
RoboflowSearchRequest, ViewerPage, UniversalViewerPage, row_to_dataset,
|
| 34 |
+
)
|
| 35 |
+
from observability.logger import audit, get_logger
|
| 36 |
+
|
| 37 |
+
log = get_logger("datasets_route")
|
| 38 |
+
|
| 39 |
+
router = APIRouter(prefix="/datasets", tags=["datasets"])
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ── List / Search datasets ────────────────────────────────────────────────────
|
| 43 |
+
|
| 44 |
+
@router.get("", response_model=list[DatasetSummary])
async def list_datasets(
    task: Optional[str] = Query(None),
    format: Optional[str] = Query(None),
    source: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    starred: Optional[bool] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """List datasets from the registry as summaries.

    Every query parameter is forwarded unchanged to
    ``ds_reg.get_all_datasets``; each returned dataset is converted with
    ``_to_summary``.

    Raises:
        HTTPException: 500 with the exception text when the lookup or the
            conversion fails (the failure is logged first).
    """
    try:
        found = await ds_reg.get_all_datasets(
            task=task,
            format=format,
            source=source,
            status=status,
            search=search,
            starred=starred,
            limit=limit,
            offset=offset,
        )
        return list(map(_to_summary, found))
    except Exception as exc:
        log.exception("list_datasets_error")
        raise HTTPException(status_code=500, detail=str(exc))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@router.get("/jobs", response_model=list[DatasetJob])
async def list_jobs(limit: int = Query(50, ge=1, le=500)):
    """Return import jobs from ``ds_reg.get_all_jobs`` (at most *limit*, 1-500)."""
    jobs = await ds_reg.get_all_jobs(limit=limit)
    return jobs
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@router.get("/jobs/{job_id}", response_model=DatasetJob)
async def get_job(job_id: str):
    """Return one import job, or 404 when the registry has no such job."""
    found = await ds_reg.get_job(job_id)
    if found:
        return found
    raise HTTPException(404, f"Job {job_id!r} not found")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@router.post("/jobs/{job_id}/stop")
async def stop_job(job_id: str):
    """Cancel a running import job.

    Only the registry record is updated (status ``failed``, error
    ``Cancelled by user``, ``ended_at`` set to the current UTC time); the
    background task itself is not cancelled here.

    Raises:
        HTTPException: 404 when the registry has no job with this id, so an
            unknown id is no longer reported as a successful stop.
    """
    if not await ds_reg.get_job(job_id):
        raise HTTPException(404, f"Job {job_id!r} not found")

    # Logic to cancel the asyncio task would go here
    # For now, we update the status in the DB
    await ds_reg.update_job(
        job_id,
        status="failed",
        error="Cancelled by user",
        ended_at=datetime.utcnow().isoformat(),
    )
    return {"status": "success", "message": "Job stop requested"}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@router.post("/jobs/{job_id}/pause")
async def pause_job(job_id: str):
    """Pause a running import job by setting its registry status to ``paused``.

    Raises:
        HTTPException: 404 when the registry has no job with this id, so an
            unknown id is no longer reported as a successful pause.
    """
    if not await ds_reg.get_job(job_id):
        raise HTTPException(404, f"Job {job_id!r} not found")

    await ds_reg.update_job(job_id, status="paused")
    return {"status": "success", "message": "Job pause requested"}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@router.post("/jobs/{job_id}/resume")
async def resume_job(job_id: str):
    """Resume a paused import job by setting its registry status to ``running``.

    Raises:
        HTTPException: 404 when the registry has no job with this id, so an
            unknown id is no longer reported as a successful resume.
    """
    if not await ds_reg.get_job(job_id):
        raise HTTPException(404, f"Job {job_id!r} not found")

    await ds_reg.update_job(job_id, status="running")
    return {"status": "success", "message": "Job resume requested"}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ── Roboflow Search & Sync ────────────────────────────────────────────────────
|
| 104 |
+
|
| 105 |
+
@router.post("/search/roboflow", response_model=list[DatasetSummary])
async def search_roboflow(req: RoboflowSearchRequest):
    """Search Roboflow Universe and mirror the hits into the local registry.

    The request fields are passed to ``RoboflowAdapter.search_datasets``; the
    returned datasets are bulk-upserted so they also appear under /datasets,
    an audit event is recorded, and the hits are returned as summaries.

    NOTE(review): the previous docstring said results are cached for 1 hour;
    no caching is visible in this handler — confirm in the adapter.

    Raises:
        HTTPException: 502 when the adapter call fails.
    """
    try:
        search_args = {
            "api_key": req.api_key,
            "query": req.query,
            "workspace": req.workspace,
            "page": req.page,
            "page_size": req.page_size,
        }
        hits = await RoboflowAdapter.search_datasets(**search_args)
    except Exception as exc:
        log.error("roboflow_search_error", error=str(exc))
        raise HTTPException(502, f"Roboflow API error: {exc}")

    # Upsert to local DB
    await ds_reg.bulk_upsert_datasets(hits)
    await audit("roboflow_search", {"query": req.query, "count": len(hits)})
    return list(map(_to_summary, hits))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@router.post("/sync/roboflow", response_model=dict)
async def sync_roboflow_workspace(
    api_key: str = Query(..., description="Roboflow API key"),
    workspace: str = Query(..., description="Workspace slug"),
):
    """Copy every dataset of a Roboflow workspace into the local registry.

    Returns:
        ``{"synced": <value returned by bulk_upsert_datasets>, "workspace": workspace}``.

    Raises:
        HTTPException: 502 when listing the workspace fails.
    """
    try:
        remote = await RoboflowAdapter.list_workspace_datasets(api_key, workspace)
    except Exception as exc:
        raise HTTPException(502, f"Roboflow API error: {exc}")

    synced = await ds_reg.bulk_upsert_datasets(remote)
    return {"synced": synced, "workspace": workspace}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ── Dataset detail ────────────────────────────────────────────────────────────
|
| 144 |
+
|
| 145 |
+
@router.get("/{dataset_id}", response_model=Dataset)
async def get_dataset(dataset_id: str):
    """Return the full dataset record, or 404 when the registry has none."""
    record = await ds_reg.get_dataset(dataset_id)
    if record:
        return record
    raise HTTPException(404, f"Dataset {dataset_id!r} not found")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ── Import ────────────────────────────────────────────────────────────────────
|
| 154 |
+
|
| 155 |
+
def _build_import_stub(dataset_id: str, req: ImportRequest) -> Dataset:
    """Build the placeholder registry record for a dataset that is not registered yet."""
    roboflow_ref = None
    if req.source == DatasetSource.huggingface:
        # Stay on the Hugging Face branch even when hf_dataset_id is missing,
        # so the stub is never mislabelled as a roboflow_curl dataset.
        name = req.hf_dataset_id or dataset_id
        roboflow_ref = req.hf_dataset_id
        fmt = DatasetFormat.json
        src = DatasetSource.huggingface
    elif req.source == DatasetSource.local:
        # Prefer the explicit name, then the folder name of local_path, then
        # the id (also used when the path has no final component, e.g. "/").
        name = req.name or Path(req.local_path or "").name or dataset_id
        fmt = DatasetFormat.custom
        src = DatasetSource.local
    else:
        # roboflow_curl: use provided dataset_name or fall back to dataset_id
        name = req.dataset_name or dataset_id
        fmt = _curl_format_to_enum(req.curl_format)
        src = DatasetSource.roboflow_curl

    return Dataset(
        id=dataset_id,
        name=name,
        task=DatasetTask.detection,
        format=fmt,
        source=src,
        status=DatasetStatus.available,
        roboflow_id=roboflow_ref,
        created_at=datetime.utcnow().isoformat(),
    )


@router.post("/{dataset_id}/import", response_model=ImportResponse)
async def import_dataset(dataset_id: str, req: ImportRequest):
    """Initiate a background import job for a dataset.

    Supports sources: roboflow | roboflow_curl | huggingface | local

    The path parameter overrides ``req.dataset_id``. When the dataset is not
    in the registry yet, huggingface / roboflow_curl / local requests get a
    stub record registered automatically; any other source is rejected.

    Returns:
        ``ImportResponse`` with the new job id and status ``queued``.

    Raises:
        HTTPException: 404 when the dataset is unknown and the source cannot
            be auto-registered; 400 when ``start_import`` raises ``ValueError``.
    """
    req.dataset_id = dataset_id  # enforce consistency

    # Sources that are discovered outside the registry must be auto-registered.
    auto_register_sources = {
        DatasetSource.huggingface,
        DatasetSource.roboflow_curl,
        DatasetSource.local,
    }

    if not await ds_reg.get_dataset(dataset_id):
        if req.source not in auto_register_sources:
            raise HTTPException(
                404,
                f"Dataset {dataset_id!r} not found in registry. "
                "Run /datasets/sync/roboflow first.",
            )
        await ds_reg.upsert_dataset(_build_import_stub(dataset_id, req))
        log.info("dataset_auto_registered", dataset_id=dataset_id, source=str(req.source))

    try:
        job_id = await start_import(req)
    except ValueError as exc:
        raise HTTPException(400, str(exc)) from exc

    await audit("dataset_import_requested", {"dataset_id": dataset_id, "source": str(req.source)})
    return ImportResponse(
        job_id=job_id,
        dataset_id=dataset_id,
        status="queued",
        message="Import job created successfully",
    )
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _curl_format_to_enum(curl_format: str | None) -> DatasetFormat:
    """Translate a Roboflow cURL export format string into a ``DatasetFormat``.

    Matching is a case-insensitive substring test. The first matching row of
    the keyword table wins, and ``DatasetFormat.yolo`` is the fallback for an
    empty, missing or unrecognised value.
    """
    if not curl_format:
        return DatasetFormat.yolo

    lowered = curl_format.lower()
    # Order matters: earlier rows take precedence when several keywords occur.
    keyword_table = (
        (("yolo",), DatasetFormat.yolo),
        (("coco",), DatasetFormat.coco),
        (("voc", "pascal"), DatasetFormat.voc),
        (("tfrecord",), DatasetFormat.tfrecord),
        (("csv",), DatasetFormat.csv),
        (("json", "createml"), DatasetFormat.json),
    )
    for keywords, dataset_format in keyword_table:
        if any(keyword in lowered for keyword in keywords):
            return dataset_format
    return DatasetFormat.yolo
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# ── Viewer ────────────────────────────────────────────────────────────────────
|
| 243 |
+
|
| 244 |
+
@router.get("/{dataset_id}/universal", response_model=UniversalViewerPage)
async def get_universal_items(
    dataset_id: str,
    page: int = Query(0, ge=0),
    page_size: int = Query(20, ge=1, le=100),
    split: Optional[str] = Query(None, regex="^(train|val|test)$"),
    class_label: Optional[str] = Query(None),
):
    """Return one page of items from the universal dataset viewer.

    Unlike the image viewer endpoint, this handler does not require the
    dataset to have status ``imported``; it only requires the dataset to
    exist in the registry.

    Raises:
        HTTPException: 404 when the dataset is not registered.
    """
    record = await ds_reg.get_dataset(dataset_id)
    if not record:
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")

    # No import-status gate here: items may be viewable before a full import.
    viewer_page = await get_universal_viewer_page(dataset_id, page, page_size, split, class_label)
    return viewer_page
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
@router.get("/{dataset_id}/images", response_model=ViewerPage)
async def get_images(
    dataset_id: str,
    page: int = Query(0, ge=0),
    page_size: int = Query(20, ge=1, le=100),
    split: Optional[str] = Query(None, regex="^(train|val|test)$"),
    class_label: Optional[str] = Query(None),
):
    """Return one page of images with annotations for the viewer.

    Results can be filtered by split (train/val/test) and by class label;
    the page itself is produced by ``get_viewer_page``.

    Raises:
        HTTPException: 404 when the dataset is not registered; 409 when its
            status is anything other than ``imported``.
    """
    record = await ds_reg.get_dataset(dataset_id)
    if not record:
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")
    if record.status != "imported":
        raise HTTPException(409, f"Dataset is not imported yet (status: {record.status})")

    viewer_page = await get_viewer_page(dataset_id, page, page_size, split, class_label)
    return viewer_page
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@router.get("/{dataset_id}/stats", response_model=dict)
async def get_dataset_stats(dataset_id: str):
    """Return the statistics ``ds_reg.get_dataset_stats`` holds for a dataset.

    Raises:
        HTTPException: 404 when the dataset is not registered.
    """
    if not await ds_reg.get_dataset(dataset_id):
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")

    stats = await ds_reg.get_dataset_stats(dataset_id)
    return stats
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@router.get("/{dataset_id}/image/{image_id}")
async def serve_image(dataset_id: str, image_id: str):
    """Send the raw bytes of one dataset image.

    The file location comes from ``resolve_image_path``. The content type is
    chosen from the file suffix (jpg/jpeg/png/bmp/webp) and defaults to
    ``application/octet-stream``. The response carries a one-day public
    ``Cache-Control`` header.

    Raises:
        HTTPException: 404 when no path can be resolved.
    """
    resolved = await resolve_image_path(dataset_id, image_id)
    if resolved is None:
        raise HTTPException(404, "Image not found or dataset not imported")

    content_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".bmp": "image/bmp",
        ".webp": "image/webp",
    }
    extension = resolved.suffix.lower()
    return FileResponse(
        path=str(resolved),
        media_type=content_types.get(extension, "application/octet-stream"),
        headers={"Cache-Control": "public, max-age=86400"},
    )
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@router.get("/{dataset_id}/annotations", response_model=dict)
async def get_annotations_summary(dataset_id: str):
    """Return class distribution summary from the annotations index.

    Counts rows of ``dataset_annotations`` per label for the given dataset,
    most frequent label first, and also reports the overall total.
    """
    from database.connection import get_db

    db = await get_db()
    async with db.execute(
        """SELECT label, COUNT(*) as count
           FROM dataset_annotations
           WHERE dataset_id=?
           GROUP BY label
           ORDER BY count DESC""",
        (dataset_id,),
    ) as cursor:
        label_rows = await cursor.fetchall()

    distribution = []
    total = 0
    for row in label_rows:
        distribution.append({"label": row["label"], "count": row["count"]})
        total += row["count"]

    return {
        "dataset_id": dataset_id,
        "class_distribution": distribution,
        "total_annotations": total,
    }
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# ── Star / Delete ─────────────────────────────────────────────────────────────
|
| 340 |
+
|
| 341 |
+
@router.post("/{dataset_id}/star", response_model=dict)
async def toggle_star(dataset_id: str):
    """Flip the starred flag of a dataset and report the new value."""
    starred_now = await ds_reg.toggle_starred(dataset_id)
    response = {"dataset_id": dataset_id, "starred": starred_now}
    return response
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
@router.delete("/{dataset_id}", response_model=dict)
async def delete_dataset(
    dataset_id: str,
    delete_files: bool = Query(False, description="Also remove local files from disk"),
):
    """Delete a dataset record and, on request, its local directory.

    Args:
        dataset_id: Identifier of the dataset to remove.
        delete_files: When true and the dataset has a ``local_path`` that is
            an existing directory, that directory tree is removed as well.

    Returns:
        ``{"deleted": <registry result>, "dataset_id": <id>}``.

    Raises:
        HTTPException: 404 when the dataset is not in the registry.
    """
    ds = await ds_reg.get_dataset(dataset_id)
    if not ds:
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")

    if delete_files and ds.local_path:
        import shutil

        local = Path(ds.local_path)
        if local.exists() and local.is_dir():
            # Best-effort removal: errors while deleting are ignored.
            shutil.rmtree(str(local), ignore_errors=True)
            log.info("dataset_files_deleted", path=str(local))

    deleted = await ds_reg.delete_dataset(dataset_id)
    # Pass the details through the ``payload`` keyword, as the other API
    # routes do, so the dict cannot bind to a different positional parameter.
    await audit(
        "dataset_deleted",
        payload={"dataset_id": dataset_id, "files_deleted": delete_files},
    )
    return {"deleted": deleted, "dataset_id": dataset_id}
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# ── Helper ────────────────────────────────────────────────────────────────────
|
| 369 |
+
|
| 370 |
+
def _to_summary(d: Dataset) -> DatasetSummary:
    """Project a full ``Dataset`` onto the lighter ``DatasetSummary`` model.

    ``health_score`` is read from ``d.stats`` when that attribute exists and
    is truthy; in every other case (including any error while reading it)
    the score defaults to ``0.0``. Enum-like fields are passed through
    ``str()``.
    """
    score = 0.0
    try:
        stats = d.stats if hasattr(d, 'stats') else None
        if stats:
            score = getattr(stats, 'health_score', 0.0)
    except Exception:
        # Deliberately best-effort: a broken stats object must not prevent
        # the dataset from being listed.
        pass

    fields = dict(
        id=d.id,
        name=d.name,
        task=str(d.task),
        format=str(d.format),
        source=str(d.source),
        status=str(d.status),
        images=d.images,
        classes=d.classes,
        size_label=d.size_label,
        tags=d.tags,
        starred=d.starred,
        import_progress=d.import_progress,
        health_score=score,
        created_at=d.created_at,
        updated_at=d.updated_at,
    )
    return DatasetSummary(**fields)
|
api/routes/inference.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
api/routes/inference.py — Inference Engine endpoints.
|
| 3 |
+
|
| 4 |
+
POST /inference/run β single synchronous inference pass
|
| 5 |
+
POST /inference/stream β SSE stream (stage-by-stage pipeline events)
|
| 6 |
+
GET /inference/history β session ledger
|
| 7 |
+
DELETE /inference/history β clear session ledger
|
| 8 |
+
GET /inference/cache β currently warm models in memory
|
| 9 |
+
DELETE /inference/cache/{model_id} β evict from cache
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import asyncio
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
from fastapi import APIRouter, HTTPException, Response
|
| 18 |
+
from fastapi.responses import StreamingResponse
|
| 19 |
+
|
| 20 |
+
from inference.engine import InferenceEngine, evict_model, get_cache_status
|
| 21 |
+
from inference.session import clear_history, get_history, record
|
| 22 |
+
from models.inference import (
|
| 23 |
+
InferenceHistoryEntry,
|
| 24 |
+
InferenceRequest,
|
| 25 |
+
InferenceResult,
|
| 26 |
+
)
|
| 27 |
+
from observability.logger import get_logger
|
| 28 |
+
from registry.registry import get_model
|
| 29 |
+
|
| 30 |
+
log = get_logger("api.inference")
|
| 31 |
+
router = APIRouter(prefix="/inference", tags=["inference"])
|
| 32 |
+
|
| 33 |
+
_engine = InferenceEngine()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ── Single run ───────────────────────────────────────────────────────────────
|
| 37 |
+
|
| 38 |
+
@router.post("/run", response_model=InferenceResult)
async def run_inference(body: InferenceRequest) -> InferenceResult:
    """Execute one full inference pass and return the complete result.

    Raises:
        HTTPException: 404 when ``body.model_id`` is not in the registry;
            500 when the engine reports ``status == "error"``.

    Successful results are written to the session ledger via ``record``
    before being returned; failed runs are not recorded.
    """
    target = await get_model(body.model_id)
    if not target:
        raise HTTPException(
            status_code=404, detail=f"Model '{body.model_id}' not found"
        )

    outcome = await _engine.run(body, target)
    failed = outcome.status == "error"
    if failed:
        raise HTTPException(
            status_code=500, detail=outcome.error or "Inference failed"
        )

    await record(body, outcome, target.name)
    return outcome
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ── SSE stream ───────────────────────────────────────────────────────────────
|
| 55 |
+
|
| 56 |
+
# Strong references to in-flight producer tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references can be
# garbage-collected before it finishes.
_stream_tasks: set[asyncio.Task] = set()


@router.post("/stream")
async def stream_inference(body: InferenceRequest) -> StreamingResponse:
    """
    Server-Sent Events stream.

    Emits one JSON event per pipeline stage, then a final 'done' event with
    the full InferenceResult. If the producer raises, a single 'error' event
    carrying the exception text is emitted instead of 'done'.

    Client usage:
        const es = new EventSource('/inference/stream');  // POST via fetch + EventSource polyfill
        es.onmessage = e => console.log(JSON.parse(e.data));

    Raises:
        HTTPException: 404 when ``body.model_id`` is not in the registry.
    """
    model = await get_model(body.model_id)
    if not model:
        raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")

    # Producer -> response generator hand-off; ``None`` marks end of stream.
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def _producer() -> None:
        """Run inference while pushing SSE events into the queue."""
        try:
            result = await _engine_stream(body, model, queue)
            await record(body, result, model.name)
            # Final complete event
            await queue.put(
                f"event: done\ndata: {result.model_dump_json()}\n\n"
            )
        except Exception as exc:
            await queue.put(
                f"event: error\ndata: {json.dumps({'error': str(exc)})}\n\n"
            )
        finally:
            await queue.put(None)  # sentinel

    # Keep the task referenced until it completes, then drop it again.
    producer_task = asyncio.create_task(_producer())
    _stream_tasks.add(producer_task)
    producer_task.add_done_callback(_stream_tasks.discard)

    async def _generator():
        while True:
            msg = await queue.get()
            if msg is None:
                break
            yield msg

    return StreamingResponse(
        _generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
async def _engine_stream(
    req: InferenceRequest,
    model,
    queue: asyncio.Queue,
) -> InferenceResult:
    """
    Run inference and push SSE events describing it onto ``queue``.

    The pipeline runs to completion first; afterwards one 'stage' event is
    queued per entry of ``result.pipeline`` (a replay, not live progress),
    followed by a single 'vitals' event with timing and quality figures.
    The finished result is returned to the caller.
    """
    result = await _engine.run(req, model)

    # Replay the completed stages as individual events.
    for stage in result.pipeline:
        stage_event = {
            "type": "stage",
            "stage": stage.model_dump(),
            "ts": time.time(),
        }
        await queue.put(f"data: {json.dumps(stage_event)}\n\n")
        await asyncio.sleep(0)  # hand control back to the event loop

    vitals_event = {
        "type": "vitals",
        "latency_ms": result.inference_ms,
        "total_ms": result.total_ms,
        "quality": result.quality_score,
    }
    await queue.put(f"data: {json.dumps(vitals_event)}\n\n")

    return result
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ── History ──────────────────────────────────────────────────────────────────
|
| 144 |
+
|
| 145 |
+
@router.get("/history", response_model=list[InferenceHistoryEntry])
async def inference_history(limit: int = 50) -> list[InferenceHistoryEntry]:
    """Return the most recent entries of the inference session ledger.

    Args:
        limit: Requested number of entries; clamped to the range 0..200.

    The lower bound matters as much as the upper one: a negative value would
    slip under the 200 cap, and a negative SQL ``LIMIT`` means "no limit" in
    SQLite (presumably the backing store; confirm against ``get_history``).
    """
    bounded = max(0, min(limit, 200))
    return await get_history(limit=bounded)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@router.delete("/history", status_code=204, response_model=None)
async def clear_inference_history():
    """Clear the inference session ledger and answer 204 No Content."""
    await clear_history()
    no_content = Response(status_code=204)
    return no_content
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# ── Model cache ──────────────────────────────────────────────────────────────
|
| 157 |
+
|
| 158 |
+
@router.get("/cache")
async def cache_status() -> dict[str, bool]:
    """Report the inference engine's model cache status."""
    status = get_cache_status()
    return status
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@router.delete("/cache/{model_id}", status_code=204, response_model=None)
async def evict_from_cache(model_id: str):
    """Evict one model from the inference cache.

    Answers 204 on success and 404 when ``evict_model`` reports that the
    model was not cached.
    """
    if evict_model(model_id):
        return Response(status_code=204)
    raise HTTPException(status_code=404, detail="Model not in cache")
|
api/routes/jobs.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
api/routes/jobs.py — /jobs & /download endpoints.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, HTTPException
|
| 7 |
+
|
| 8 |
+
from download.manager import cancel_job, enqueue_download, get_job, list_jobs
|
| 9 |
+
from models.job import Job, JobCreate
|
| 10 |
+
from observability.logger import audit, get_logger
|
| 11 |
+
from registry.registry import get_model
|
| 12 |
+
|
| 13 |
+
log = get_logger("api.jobs")
|
| 14 |
+
router = APIRouter(tags=["jobs"])
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@router.post("/download", response_model=Job, status_code=202)
async def trigger_download(body: JobCreate) -> Job:
    """Enqueue a model download. Returns the created job immediately.

    Raises:
        HTTPException: 404 when the model is unknown, 409 when it is already
            marked as downloaded, 500 when the freshly enqueued job cannot
            be read back.
    """
    target = await get_model(body.model_id)
    if not target:
        raise HTTPException(
            status_code=404, detail=f"Model '{body.model_id}' not found"
        )
    if target.downloaded:
        raise HTTPException(
            status_code=409, detail="Model is already cached locally"
        )

    new_job_id = await enqueue_download(
        model_id=body.model_id,
        model_name=body.model_name,
        version=body.version,
    )
    created = await get_job(new_job_id)
    if not created:
        raise HTTPException(status_code=500, detail="Job creation failed")

    await audit("api_download_trigger", model_id=body.model_id, job_id=new_job_id)
    return created
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@router.get("/jobs", response_model=list[Job])
async def jobs_list(status: str | None = None, limit: int = 50) -> list[Job]:
    """List download jobs, optionally filtered by ``status``.

    Both query parameters are forwarded unchanged to ``list_jobs``.
    """
    matching = await list_jobs(status=status, limit=limit)
    return matching
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@router.get("/jobs/{job_id}", response_model=Job)
async def job_detail(job_id: str) -> Job:
    """Return a single job, or 404 when the id is unknown."""
    found = await get_job(job_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@router.delete("/jobs/{job_id}", status_code=204, response_model=None)
async def job_cancel(job_id: str) -> None:
    """Cancel a job; answers 409 when ``cancel_job`` reports failure."""
    cancelled = await cancel_job(job_id)
    if cancelled:
        return
    raise HTTPException(status_code=409, detail="Job cannot be cancelled")
|
api/routes/models.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
api/routes/models.py — /models REST endpoints.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import Annotated
|
| 7 |
+
|
| 8 |
+
from fastapi import APIRouter, HTTPException, Query, UploadFile, File, Form
|
| 9 |
+
from models.model import Model
|
| 10 |
+
from observability.logger import audit, get_logger
|
| 11 |
+
from registry.registry import count_models, get_model, list_models
|
| 12 |
+
from projects.service import get_active_project_id, import_local_model
|
| 13 |
+
from projects.registry import get_project
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import os
|
| 16 |
+
import tempfile
|
| 17 |
+
|
| 18 |
+
log = get_logger("api.models")
|
| 19 |
+
router = APIRouter(prefix="/models", tags=["models"])
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _first_file_under(directory: Path) -> str | None:
    """Return the first regular file found below ``directory``, if any.

    Any error raised while walking the tree is treated as "nothing found".
    """
    try:
        return next(
            (str(entry) for entry in directory.rglob("*") if entry.is_file()),
            None,
        )
    except Exception:
        return None


@router.get("", response_model=list[Model])
async def index(
    search: Annotated[str | None, Query()] = None,
    task: Annotated[list[str] | None, Query()] = None,
    framework: Annotated[list[str] | None, Query()] = None,
    hardware: Annotated[list[str] | None, Query()] = None,
    source: Annotated[list[str] | None, Query()] = None,
    downloaded: Annotated[bool | None, Query()] = None,
    sort_by: Annotated[str, Query()] = "downloads",
    sort_dir: Annotated[str, Query()] = "desc",
    limit: Annotated[int, Query(ge=1, le=1000)] = 200,
    offset: Annotated[int, Query(ge=0)] = 0,
    project_id: Annotated[str | None, Query()] = None,
) -> list[Model]:
    """
    List and search models.

    All filter, sort and paging parameters are forwarded to ``list_models``.
    When a project applies (the explicit ``project_id`` or else the active
    project) and it exists, each model's ``downloaded`` / ``local_path`` is
    rewritten to reflect that project's ``models/<model id>`` directory:
    present with at least one file means downloaded, otherwise not.

    Target: < 100ms for up to 5 000 models.
    """
    effective_project_id = project_id or await get_active_project_id()

    models = await list_models(
        search=search,
        tasks=task,
        frameworks=framework,
        hardware=hardware,
        sources=source,
        downloaded=downloaded,
        sort_by=sort_by,
        sort_dir=sort_dir,
        limit=limit,
        offset=offset,
        project_id=effective_project_id,
    )

    # Cache state is per project: derive it from the project's workspace.
    proj = await get_project(effective_project_id) if effective_project_id else None
    if proj:
        models_root = Path(proj.path) / "models"

        rewritten: list[Model] = []
        for m in models:
            candidate_dir = models_root / m.id
            local_file = None
            if candidate_dir.exists() and candidate_dir.is_dir():
                # Best-effort: any file inside the directory counts.
                local_file = _first_file_under(candidate_dir)

            if local_file:
                patch = {"downloaded": True, "local_path": local_file}
            else:
                # Not present in this project: not cached for this project.
                patch = {"downloaded": False, "local_path": None}
            rewritten.append(m.model_copy(update=patch))

        models = rewritten

    await audit("api_list_models", payload={"count": len(models), "search": search})
    return models
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@router.post("/import", response_model=Model)
async def import_model(
    name: Annotated[str, Form()],
    task: Annotated[str, Form()],
    framework: Annotated[str, Form()],
    file: UploadFile = File(...),
) -> Model:
    """Import an uploaded model file into the active project.

    The upload is spooled to a temporary file (keeping the original file
    extension), handed to ``import_local_model``, and the temporary file is
    removed afterwards whether or not the import succeeded.

    Raises:
        HTTPException: re-raised unchanged when the importer raises one;
            any other importer failure is logged and reported as 500.
    """
    suffix = os.path.splitext(file.filename or "")[1]
    chunk_size = 1024 * 1024  # copy in 1 MiB pieces instead of one big read
    tmp_path: str | None = None

    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Remember the path before writing so that a failed write still
            # gets cleaned up by the ``finally`` below.
            tmp_path = tmp.name
            while chunk := await file.read(chunk_size):
                tmp.write(chunk)

        try:
            return await import_local_model(
                name=name,
                task=task,
                framework=framework,
                source_file_path=tmp_path,
            )
        except HTTPException:
            # Keep the status code chosen by the importer.
            raise
        except Exception as e:
            log.error("model_import_failed", error=str(e))
            raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@router.get("/{model_id}", response_model=Model)
async def detail(model_id: str) -> Model:
    """Return one model by id, auditing the access; 404 when unknown."""
    found = await get_model(model_id)
    if not found:
        raise HTTPException(
            status_code=404, detail=f"Model '{model_id}' not found"
        )

    await audit("api_get_model", model_id=model_id)
    return found
|
api/routes/projects.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""api/routes/projects.py β /projects REST endpoints."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 6 |
+
|
| 7 |
+
from models.project import Project
|
| 8 |
+
from observability.logger import audit
|
| 9 |
+
from projects.registry import delete_project, get_project, list_projects, touch_last_opened, upsert_project
|
| 10 |
+
from projects.service import set_active_project
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
router = APIRouter(prefix="/projects", tags=["projects"])
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@router.get("", response_model=list[Project])
async def projects_list(
    limit: int = Query(200, ge=1, le=1000),
    offset: int = Query(0, ge=0),
) -> list[Project]:
    """Return one page of projects and audit how many were returned."""
    page = await list_projects(limit=limit, offset=offset)
    await audit("api_list_projects", payload={"count": len(page)})
    return page
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.post("", response_model=Project)
async def projects_upsert(project: Project) -> Project:
    """Create or update a project record.

    Missing ``created_at`` / ``last_opened`` values are filled with the
    current UTC time in ISO-8601 form before the project is stored.

    Returns:
        The project as stored, including any timestamps filled in here.
    """
    # Imported here because this module has no top-level datetime import;
    # without it the timestamp defaults below fail with NameError.
    from datetime import datetime, timezone

    now_iso = datetime.now(timezone.utc).isoformat()
    if not project.created_at:
        project.created_at = now_iso
    if not project.last_opened:
        project.last_opened = now_iso

    await upsert_project(project)
    await audit("api_upsert_project", payload={"project_id": project.id})
    return project
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@router.post("/{project_id}/open", status_code=204, response_model=None)
async def projects_open(project_id: str) -> None:
    """Mark a project as opened and make it the active project.

    The last-opened timestamp is touched first; the project is activated
    only if it can then be loaded. The audit entry is written either way.
    """
    await touch_last_opened(project_id)

    opened = await get_project(project_id)
    if opened:
        await set_active_project(opened.id, opened.path)

    await audit("api_open_project", payload={"project_id": project_id})
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@router.delete("/{project_id}", status_code=204, response_model=None)
async def projects_delete(project_id: str) -> None:
    """Delete a project; 404 when ``delete_project`` reports nothing removed."""
    removed = await delete_project(project_id)
    if not removed:
        raise HTTPException(
            status_code=404, detail=f"Project '{project_id}' not found"
        )
    await audit("api_delete_project", payload={"project_id": project_id})
|
| 54 |
+
|
api/routes/sync.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
api/routes/sync.py — /sync endpoint: fetch fresh models from all adapters.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, BackgroundTasks
|
| 7 |
+
|
| 8 |
+
from adapters.hf_adapter import HFAdapter
|
| 9 |
+
from adapters.onnx_adapter import ONNXAdapter
|
| 10 |
+
from observability.logger import audit, get_logger
|
| 11 |
+
from registry.registry import bulk_upsert, count_models
|
| 12 |
+
|
| 13 |
+
log = get_logger("api.sync")
|
| 14 |
+
router = APIRouter(tags=["sync"])
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
async def _run_full_sync() -> None:
    """Fetch models from every adapter, store them, and prune unwanted rows.

    Steps, in order:
      1. Fetch and upsert Hugging Face models.
      2. Delete ``source = 'hf'`` rows whose task is outside the allowed set,
         then delete HF generation/embedding rows whose tags contain none of
         the vision pipeline tags; commit.
      3. Fetch and upsert ONNX models.
      4. Log and audit the total number of models fetched.
    """
    log.info("sync_start")
    total = 0

    # NOTE(review): the extent of this ``async with`` block is reconstructed;
    # confirm whether the pruning below is meant to run inside it.
    async with HFAdapter() as hf:
        hf_models = await hf.fetch_models()
        await bulk_upsert(hf_models)
        total += len(hf_models)
        log.info("sync_hf_done", count=len(hf_models))

    # Prune any HF models outside the allowed task set (e.g. legacy NLP entries)
    allowed_tasks = {"detection", "classification", "segmentation", "generation", "embedding"}
    from database.connection import get_db

    db = await get_db()
    # One "?" per allowed task; the values themselves are bound as
    # parameters, only the placeholder list is interpolated into the SQL.
    placeholders = ",".join(["?"] * len(allowed_tasks))
    await db.execute(
        f"DELETE FROM models WHERE source = 'hf' AND task NOT IN ({placeholders})",
        tuple(sorted(allowed_tasks)),
    )
    # Prune non-vision generation/embedding HF models. We rely on the adapter
    # adding the pipeline_tag as a normalised tag (e.g. text_to_image).
    await db.execute(
        """
        DELETE FROM models
        WHERE source = 'hf'
          AND task IN ('generation','embedding')
          AND (
            tags NOT LIKE '%text_to_image%'
            AND tags NOT LIKE '%image_to_image%'
            AND tags NOT LIKE '%image_feature_extraction%'
          )
        """,
    )
    await db.commit()

    onnx = ONNXAdapter()
    onnx_models = await onnx.fetch_models()
    await bulk_upsert(onnx_models)
    total += len(onnx_models)
    log.info("sync_onnx_done", count=len(onnx_models))

    # ``total`` counts fetched models (HF + ONNX), before any pruning.
    log.info("sync_complete", total=total)
    await audit("sync_complete", payload={"total": total})
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@router.post("/sync", status_code=202)
async def trigger_sync(background_tasks: BackgroundTasks) -> dict:
    """
    Schedule a full model sync as a background task.

    Responds immediately with the model count at the time of the request;
    the sync itself runs after the response has been sent.
    """
    background_tasks.add_task(_run_full_sync)

    model_count = await count_models()
    log.info("sync_triggered", current_model_count=model_count)
    await audit("sync_triggered", payload={"current": model_count})

    return {"message": "Sync started", "current_model_count": model_count}
|
api/routes/system.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""api/routes/system.py β System metrics endpoints."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
from fastapi import APIRouter, Query
|
| 9 |
+
from fastapi.responses import StreamingResponse
|
| 10 |
+
|
| 11 |
+
from models.system import SystemMetrics
|
| 12 |
+
from system.metrics import sample_metrics
|
| 13 |
+
|
| 14 |
+
router = APIRouter(prefix="/system", tags=["system"])
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@router.get("/metrics", response_model=SystemMetrics)
async def get_metrics(gpu_index: int = Query(0, ge=0)) -> SystemMetrics:
    """Take one metrics sample and return it as a ``SystemMetrics`` model.

    ``ts``, ``cpu_pct``, ``ram_used_mb`` and ``ram_total_mb`` must be present
    in the sample; the remaining fields are optional, with ``disks`` and
    ``network`` defaulting to empty lists.
    """
    sample = sample_metrics(gpu_index=gpu_index)

    required = {
        key: sample[key]
        for key in ("ts", "cpu_pct", "ram_used_mb", "ram_total_mb")
    }
    optional = {
        key: sample.get(key)
        for key in ("cpu_model", "cpu_freq_mhz", "cpu_count", "gpu")
    }
    return SystemMetrics(
        **required,
        **optional,
        disks=sample.get("disks", []),
        network=sample.get("network", []),
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@router.get("/metrics/stream")
async def stream_metrics(
    gpu_index: int = Query(0, ge=0),
    hz: float = Query(2.0, ge=0.2, le=20.0),
):
    """Server-Sent Events stream of system metrics.

    Args:
        gpu_index: GPU index passed to ``sample_metrics``.
        hz: Samples per second (0.2-20); one event is emitted per sample.

    A failed sample is logged and skipped; the stream itself stays open.
    """
    # Same structured logger the other route modules use, instead of print().
    from observability.logger import get_logger

    log = get_logger("api.system")
    interval = 1.0 / float(hz)

    async def gen():
        # Initial comment helps some proxies establish the stream
        yield ": connected\n\n"
        while True:
            try:
                payload = sample_metrics(gpu_index=gpu_index)
                data = json.dumps(payload)
                yield f"data: {data}\n\n"
            except Exception as e:
                # Report the failure but keep the stream alive.
                log.error("metrics_stream_error", error=str(e))
            await asyncio.sleep(interval)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        },
    )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@router.get("/logs/stream")
async def stream_system_logs():
    """SSE stream of global system and gateway logs.

    A fresh queue is registered in the logger's subscriber list; every entry
    placed on it is forwarded as one ``data:`` event. After 30 seconds
    without an entry a heartbeat comment is sent instead. A ``None`` entry
    ends the stream. The queue is unregistered when the event loop below
    exits.
    """
    from observability.logger import _sys_log_subs

    inbox: asyncio.Queue = asyncio.Queue()
    _sys_log_subs.append(inbox)

    async def _events():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    entry = await asyncio.wait_for(inbox.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield ": heartbeat\n\n"
                else:
                    if entry is None:
                        return
                    yield f"data: {json.dumps(entry)}\n\n"
        finally:
            # Unsubscribe; tolerate the queue having been removed already.
            try:
                _sys_log_subs.remove(inbox)
            except ValueError:
                pass

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        _events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
api/routes/training.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
api/routes/training.py — Training Engine REST + SSE endpoints.
|
| 3 |
+
|
| 4 |
+
POST /train/start β create and launch a training run
|
| 5 |
+
POST /train/stop β cancel a running run
|
| 6 |
+
POST /train/pause β pause a running run
|
| 7 |
+
POST /train/resume β resume a paused run
|
| 8 |
+
GET /train/status β run status + progress snapshot
|
| 9 |
+
GET /train/runs β list all runs
|
| 10 |
+
GET /train/runs/{run_id} β single run detail
|
| 11 |
+
GET /train/schema β UI schema for task/model/dataset combo
|
| 12 |
+
GET /train/checkpoints β checkpoints for a run (stub)
|
| 13 |
+
POST /train/checkpoints/{id}/export β export a checkpoint (stub)
|
| 14 |
+
GET /train/metrics/stream β SSE: real-time metrics ticks
|
| 15 |
+
GET /train/logs/stream β SSE: real-time log entries
|
| 16 |
+
GET /train/resources/stream β SSE: real-time resource ticks
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import asyncio
|
| 21 |
+
import json
|
| 22 |
+
import time
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 26 |
+
from fastapi.responses import StreamingResponse
|
| 27 |
+
|
| 28 |
+
from observability.logger import get_logger
|
| 29 |
+
from training import run_manager
|
| 30 |
+
from training.schema_engine import generate_schema
|
| 31 |
+
from training.schemas import (
|
| 32 |
+
CheckpointOut,
|
| 33 |
+
PauseTrainRequest,
|
| 34 |
+
ResumeTrainRequest,
|
| 35 |
+
StartTrainRequest,
|
| 36 |
+
StartTrainResponse,
|
| 37 |
+
StopTrainRequest,
|
| 38 |
+
TrainRunOut,
|
| 39 |
+
TrainStatusResponse,
|
| 40 |
+
TrainingSchemaResponse,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
log = get_logger("api.training")
|
| 44 |
+
router = APIRouter(prefix="/train", tags=["training"])
|
| 45 |
+
|
| 46 |
+
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
|
| 48 |
+
def _format_duration(seconds: float) -> str:
|
| 49 |
+
h = int(seconds // 3600)
|
| 50 |
+
m = int((seconds % 3600) // 60)
|
| 51 |
+
s = int(seconds % 60)
|
| 52 |
+
return f"{h}h {m}m {s}s"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _run_to_out(run: run_manager.TrainRun) -> TrainRunOut:
    """Convert an in-memory ``TrainRun`` into its ``TrainRunOut`` API model.

    ``duration`` is measured from ``created_at`` to ``completed_at``; while
    ``completed_at`` is falsy, the current wall-clock time is used instead.
    """
    end_time = run.completed_at or time.time()
    duration_label = _format_duration(end_time - run.created_at)
    return TrainRunOut(
        id=run.run_id,
        run_number=run.run_number,
        model_id=run.model_id,
        model_name=run.model_name,
        dataset_id=run.dataset_id,
        dataset_name=run.dataset_name,
        task=run.task,
        status=run.status,
        epochs_done=run.epoch,
        total_epochs=run.total_epochs,
        best_metric=run.best_metric,
        final_loss=run.final_loss,
        duration=duration_label,
        created_at=run.created_at,
        completed_at=run.completed_at,
        hyperparams=run.hyperparams,
    )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ββ Control endpoints βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
|
| 79 |
+
@router.post("/start", response_model=StartTrainResponse)
async def start_training(body: StartTrainRequest) -> StartTrainResponse:
    """Create a training run and launch it immediately.

    Display names for the model and dataset are resolved from their
    registries on a best-effort basis: if a lookup raises or returns
    nothing, the raw id is used as the name. Lookup failures are logged
    rather than silently discarded, and never prevent the run from starting.

    Returns:
        A ``StartTrainResponse`` carrying the new run's id and status.
    """
    # Fall back to the ids if the registries are unavailable.
    model_name = body.model_id
    dataset_name = body.dataset_id

    # The registry imports stay inside the try blocks so that an import
    # failure is treated the same as a failed lookup.
    try:
        from registry.registry import get_model

        m = await get_model(body.model_id)
        if m:
            model_name = m.name
    except Exception as exc:  # best-effort: fall back to the id
        log.warning("model_name_lookup_failed", model_id=body.model_id, error=str(exc))

    try:
        from datasets.registry import get_dataset

        d = await get_dataset(body.dataset_id)
        if d:
            # The registry result may be a dict or an object; handle both.
            if isinstance(d, dict):
                dataset_name = d.get("name", body.dataset_id)
            else:
                dataset_name = getattr(d, "name", body.dataset_id)
    except Exception as exc:  # best-effort: fall back to the id
        log.warning("dataset_name_lookup_failed", dataset_id=body.dataset_id, error=str(exc))

    run = run_manager.create_run(
        model_id=body.model_id,
        model_name=model_name,
        dataset_id=body.dataset_id,
        dataset_name=dataset_name,
        task=body.task,
        hyperparams=body.hyperparams,
        augmentation=body.augmentation,
        scheduler=body.scheduler,
        project_id=body.project_id,
    )
    run_manager.start_run(run)

    log.info("training_started", run_id=run.run_id, model=body.model_id)
    return StartTrainResponse(
        run_id=run.run_id,
        status=run.status,
        message=f"Training run {run.run_id} started.",
    )
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@router.post("/stop", status_code=200)
async def stop_training(body: StopTrainRequest) -> dict:
    """Stop the run named in the request body.

    Raises:
        HTTPException: 404 when no run with ``body.run_id`` is known.
    """
    target = run_manager.get_run(body.run_id)
    if target:
        run_manager.stop_run(target)
        log.info("training_stopped", run_id=body.run_id)
        # The status is read after stop_run so the response reflects the new state.
        return {"run_id": body.run_id, "status": target.status}
    raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@router.post("/pause", status_code=200)
async def pause_training(body: PauseTrainRequest) -> dict:
    """Pause the run named in the request body.

    Raises:
        HTTPException: 404 when no run with ``body.run_id`` is known.
    """
    target = run_manager.get_run(body.run_id)
    if target:
        run_manager.pause_run(target)
        return {"run_id": body.run_id, "status": target.status}
    raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@router.post("/resume", status_code=200)
async def resume_training(body: ResumeTrainRequest) -> dict:
    """Resume the run named in the request body.

    Raises:
        HTTPException: 404 when no run with ``body.run_id`` is known.
    """
    target = run_manager.get_run(body.run_id)
    if target:
        run_manager.resume_run(target)
        return {"run_id": body.run_id, "status": target.status}
    raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@router.get("/status", response_model=TrainStatusResponse)
async def get_train_status(run_id: str = Query(...)) -> TrainStatusResponse:
    """Return a status and progress snapshot for one run.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    # NOTE(review): total_steps assumes a fixed 100 steps per epoch — confirm
    # this matches what the training worker actually reports in run.step.
    step_budget = run.total_epochs * 100
    snapshot = TrainStatusResponse(
        run_id=run.run_id,
        status=run.status,
        epoch=run.epoch,
        total_epochs=run.total_epochs,
        step=run.step,
        total_steps=step_budget,
        eta_seconds=run.eta_seconds,
        elapsed_seconds=run.elapsed_seconds,
    )
    return snapshot
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ββ Run history βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 167 |
+
|
| 168 |
+
@router.get("/runs", response_model=list[TrainRunOut])
async def list_runs() -> list[TrainRunOut]:
    """Return every known run, in the reverse of the run manager's order."""
    known_runs = run_manager.list_runs()
    return list(map(_run_to_out, reversed(known_runs)))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@router.get("/runs/{run_id}", response_model=TrainRunOut)
async def get_run(run_id: str) -> TrainRunOut:
    """Return the detail record for a single run.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown.
    """
    found = run_manager.get_run(run_id)
    if found:
        return _run_to_out(found)
    raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ββ Schema Engine βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 182 |
+
|
| 183 |
+
@router.get("/schema", response_model=TrainingSchemaResponse)
async def get_schema(
    model_id: str = Query(""),
    dataset_id: str = Query(""),
    task: str = Query("detection"),
) -> TrainingSchemaResponse:
    """Return the training UI schema for a task/model/dataset combination.

    The mapping produced by ``generate_schema`` is unpacked directly into
    ``TrainingSchemaResponse``.
    """
    schema_fields = generate_schema(
        task=task,
        model_id=model_id,
        dataset_id=dataset_id,
    )
    return TrainingSchemaResponse(**schema_fields)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ββ Checkpoints (stub β extend when artifact storage is wired) ββββββββββββββββ
|
| 194 |
+
|
| 195 |
+
@router.get("/checkpoints", response_model=list[CheckpointOut])
async def list_checkpoints(run_id: str = Query(...)) -> list[CheckpointOut]:
    """Stub: validate the run id, then return an empty checkpoint list.

    Checkpoint persistence is not implemented yet, so the result is always
    empty for a known run.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown.
    """
    if run_manager.get_run(run_id):
        return []
    raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@router.post("/checkpoints/{checkpoint_id}/export")
async def export_checkpoint(checkpoint_id: str, body: dict | None = None) -> dict:
    """Stub for checkpoint export; always responds with HTTP 501.

    Args:
        checkpoint_id: Identifier of the checkpoint to export (unused for now).
        body: Optional JSON request body (unused for now). The default is
            ``None`` rather than a shared mutable ``{}``.

    Raises:
        HTTPException: always, with status 501.
    """
    raise HTTPException(status_code=501, detail="Checkpoint export not yet implemented")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ββ SSE: Metrics stream ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 210 |
+
|
| 211 |
+
@router.get("/metrics/stream")
async def stream_metrics(run_id: str = Query(...)) -> StreamingResponse:
    """Stream a run's metrics ticks as Server-Sent Events.

    A fresh queue is appended to ``run.metrics_subs``; every item taken from
    it is sent as one JSON ``data:`` event. A comment-line heartbeat is sent
    after 30 seconds without an item. The stream ends when ``None`` is
    dequeued, and the queue is removed from the subscriber list on exit.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    subscriber: asyncio.Queue = asyncio.Queue()
    run.metrics_subs.append(subscriber)

    async def event_source():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    tick = await asyncio.wait_for(subscriber.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    # Heartbeat keeps idle connections alive.
                    yield ": heartbeat\n\n"
                else:
                    if tick is None:
                        break
                    yield f"data: {json.dumps(tick)}\n\n"
        finally:
            if subscriber in run.metrics_subs:
                run.metrics_subs.remove(subscriber)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        event_source(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# ββ SSE: Logs stream ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 250 |
+
|
| 251 |
+
@router.get("/logs/stream")
async def stream_logs(run_id: str = Query(...)) -> StreamingResponse:
    """Stream a run's log entries as Server-Sent Events.

    A fresh queue is appended to ``run.log_subs``; every item taken from it
    is sent as one JSON ``data:`` event. A comment-line heartbeat is sent
    after 30 seconds without an item. The stream ends when ``None`` is
    dequeued, and the queue is removed from the subscriber list on exit.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    subscriber: asyncio.Queue = asyncio.Queue()
    run.log_subs.append(subscriber)

    async def event_source():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    entry = await asyncio.wait_for(subscriber.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield ": heartbeat\n\n"
                else:
                    if entry is None:
                        break
                    yield f"data: {json.dumps(entry)}\n\n"
        finally:
            if subscriber in run.log_subs:
                run.log_subs.remove(subscriber)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        event_source(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@router.get("/runs/{run_id}/history")
async def get_run_history(run_id: str) -> list[dict]:
    """Return the recorded metrics ticks for a run.

    Reads ``telemetry.jsonl`` from the run directory and parses each
    non-blank line as JSON. A missing file yields an empty list.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown; 500 when the file
            exists but cannot be read or parsed.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    from training.persistence import TrainingPersistence

    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
    telemetry_path = os.path.join(run_dir, "telemetry.jsonl")

    try:
        # Explicit UTF-8: JSON text must not depend on the platform locale.
        with open(telemetry_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        # No telemetry recorded yet. Handled here rather than with a prior
        # exists() check, which would race with the file being removed.
        return []
    except Exception as e:
        log.error("history_read_failed", run_id=run_id, error=str(e))
        raise HTTPException(status_code=500, detail="Failed to read telemetry history") from e
|
| 307 |
+
|
| 308 |
+
@router.get("/runs/{run_id}/artifacts")
async def list_run_artifacts(run_id: str) -> dict:
    """List the image artifacts found directly in a run's directory.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` are split into two
    groups: names containing "batch" go to ``batches``, everything else to
    ``artifacts``. Each group is sorted by display title.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    from training.persistence import TrainingPersistence

    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
    if not os.path.exists(run_dir):
        return {"artifacts": [], "batches": []}

    # Friendly titles for the standard YOLO output files.
    known_titles = {
        "confusion_matrix.png": "Confusion Matrix",
        "confusion_matrix_normalized.png": "Confusion Matrix (Norm)",
        "results.png": "Results Summary",
        "F1_curve.png": "F1 Curve",
        "PR_curve.png": "PR Curve",
        "P_curve.png": "Precision Curve",
        "R_curve.png": "Recall Curve",
        "BoxF1_curve.png": "Box F1 Curve",
        "BoxP_curve.png": "Box Precision Curve",
        "BoxPR_curve.png": "Box PR Curve",
        "BoxR_curve.png": "Box Recall Curve",
        "labels.jpg": "Labels Distribution",
        "labels_correlogram.jpg": "Labels Correlogram",
    }
    image_suffixes = ('.png', '.jpg', '.jpeg')

    analysis_items = []
    batch_items = []
    for name in os.listdir(run_dir):
        if not name.endswith(image_suffixes):
            continue

        lowered = name.lower()
        # Unknown files get a title derived from the file name.
        fallback_title = name.replace('_', ' ').title().split('.')[0]
        entry = {
            "title": known_titles.get(name, fallback_title),
            "path": f"/train/runs/{run_id}/files/{name}",
            "type": "Analysis",
        }

        if "batch" in lowered:
            if "val" in lowered:
                entry["type"] = "Batch Preview"
            else:
                entry["type"] = "Augmentation"
            batch_items.append(entry)
            continue

        if "curve" in lowered:
            entry["type"] = "Precision-Recall"
        elif "confusion" in lowered:
            entry["type"] = "Analysis"
        elif "results" in lowered:
            entry["type"] = "Overall"
        analysis_items.append(entry)

    def by_title(entry):
        return entry['title']

    return {
        "artifacts": sorted(analysis_items, key=by_title),
        "batches": sorted(batch_items, key=by_title),
    }
|
| 366 |
+
|
| 367 |
+
@router.get("/runs/{run_id}/files/{filename}")
async def get_run_file(run_id: str, filename: str):
    """Serve one regular file from a run's directory.

    ``filename`` comes straight from the URL, so the resolved path must stay
    inside the run directory. Anything that escapes it (``..``, absolute
    paths, symlinks pointing elsewhere) or is not a regular file gets the
    same 404 as a missing file.

    Raises:
        HTTPException: 404 when the run is unknown, or when the file is
            missing, is not a regular file, or lies outside the run directory.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # The run manager does not expose the run's directory, so it is
    # recalculated through the persistence layer.
    from training.persistence import TrainingPersistence

    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)

    # Resolve symlinks and ".." on both sides before comparing.
    root = os.path.realpath(run_dir)
    file_path = os.path.realpath(os.path.join(root, filename))
    try:
        inside_run_dir = os.path.commonpath([root, file_path]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        inside_run_dir = False

    if not inside_run_dir or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    from fastapi.responses import FileResponse

    return FileResponse(file_path)
|
| 386 |
+
# The frontend uses /system/metrics/stream for resources (already implemented).
|
| 387 |
+
# This alias exists for training-scoped resource monitoring.
|
| 388 |
+
|
| 389 |
+
@router.get("/resources/stream")
async def stream_resources(
    run_id: str = Query(...),
    gpu_index: int = Query(0, ge=0),
    hz: float = Query(1.0, ge=0.2, le=10.0),
) -> StreamingResponse:
    """Stream a run's resource ticks as Server-Sent Events.

    A fresh queue is appended to ``run.resource_subs``; every item taken
    from it is sent as one JSON ``data:`` event. A comment-line heartbeat is
    sent after 30 seconds without an item. The stream ends when ``None`` is
    dequeued, and the queue is removed from the subscriber list on exit.

    ``gpu_index`` and ``hz`` are validated but not used by this handler:
    ticks are forwarded at whatever rate they arrive on the queue.
    NOTE(review): confirm whether the producer is expected to honour them.

    Raises:
        HTTPException: 404 when ``run_id`` is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    q: asyncio.Queue = asyncio.Queue()
    run.resource_subs.append(q)

    async def generator():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    tick = await asyncio.wait_for(q.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    # Heartbeat keeps idle connections alive.
                    yield ": heartbeat\n\n"
                    continue
                if tick is None:
                    break
                yield f"data: {json.dumps(tick)}\n\n"
        finally:
            if q in run.resource_subs:
                run.resource_subs.remove(q)

    return StreamingResponse(
        generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
config.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
config.py — Centralized application settings.
|
| 3 |
+
All tuneable knobs live here; override via environment variables.
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Settings(BaseSettings):
    """Application settings, loaded from the environment and ``.env``.

    Field names are matched case-insensitively.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # -- App -----------------------------------------------------------
    app_name: str = "MLForge Platform"
    version: str = "1.0.0"
    debug: bool = False

    # -- API -----------------------------------------------------------
    host: str = "0.0.0.0"
    port: int = 8005
    cors_origins: list[str] = [
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:2000",
        "http://127.0.0.1:2000",
    ]

    # -- Storage -------------------------------------------------------
    # NOTE(review): the derived defaults below are computed once from the
    # default base_dir when the class body runs; overriding base_dir alone
    # presumably does not move them — confirm before relying on it.
    base_dir: Path = Path(__file__).resolve().parents[1]
    data_dir: Path = base_dir / "data"
    models_dir: Path = data_dir / "models"
    datasets_dir: Path = data_dir / "datasets"  # root for imported datasets
    logs_dir: Path = data_dir / "logs"
    db_path: Path = data_dir / "modelzoo.db"

    # -- Download Manager ----------------------------------------------
    max_concurrent_downloads: int = 5
    download_chunk_size: int = 1024 * 1024  # 1 MB
    download_max_retries: int = 3
    download_retry_delay: float = 2.0  # seconds (base, exponential backoff)

    # -- Search --------------------------------------------------------
    search_max_results: int = 500

    # -- Sync ----------------------------------------------------------
    auto_sync_on_startup: bool = True

    # -- Hugging Face API ----------------------------------------------
    hf_api_base: str = "https://huggingface.co/api"
    hf_hub_url: str = "https://huggingface.co"
    hf_token: str | None = None  # optional: HF_TOKEN env var
    hf_models_per_task: int = 100  # how many to pull per task

    # -- ONNX Zoo ------------------------------------------------------
    onnx_models_url: str = (
        "https://raw.githubusercontent.com/onnx/models/main/README.md"
    )

    # -- Benchmark Bridge ----------------------------------------------
    benchmark_max_concurrent: int = 3  # max parallel benchmark jobs
    benchmark_max_log_lines: int = 500  # log entries kept per job
    benchmark_ws_poll_hz: float = 2.0  # WebSocket telemetry poll rate

    # -- Dataset Manager -----------------------------------------------
    roboflow_api_base: str = "https://api.roboflow.com"
    dataset_import_workers: int = 3  # max concurrent import jobs
    dataset_chunk_size: int = 1024 * 1024 * 4  # 4 MB download chunk
    roboflow_cache_ttl_secs: int = 3600  # 1 hour

    def ensure_dirs(self) -> None:
        """Create the data directories, and the dataset ``_tmp`` folder, if missing."""
        required = (
            self.data_dir,
            self.models_dir,
            self.datasets_dir,
            self.datasets_dir / "_tmp",
            self.logs_dir,
        )
        for directory in required:
            directory.mkdir(parents=True, exist_ok=True)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
settings = Settings()
|
database/__init__.py
ADDED
|
File without changes
|
database/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
database/__pycache__/connection.cpython-310.pyc
ADDED
|
Binary file (3.6 kB). View file
|
|
|
database/benchmark_schema.sql
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- MLForge Benchmark Bridge — SQLite Schema
|
| 3 |
+
-- Version: 1.0.0
|
| 4 |
+
-- ============================================================
|
| 5 |
+
|
| 6 |
+
PRAGMA journal_mode = WAL;
|
| 7 |
+
PRAGMA foreign_keys = ON;
|
| 8 |
+
|
| 9 |
+
-- ββ Benchmark Jobs ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 10 |
+
-- Tracks every benchmark run from queued → running → completed/failed.
|
| 11 |
+
-- config stores the full BenchmarkContext JSON for full reproducibility.
|
| 12 |
+
-- One row per benchmark job; structured fields are stored as JSON text.
CREATE TABLE IF NOT EXISTS benchmark_jobs (
    id          TEXT PRIMARY KEY,
    model_id    TEXT NOT NULL,
    dataset_id  TEXT NOT NULL,
    task        TEXT NOT NULL,
    framework   TEXT NOT NULL,
    hardware    TEXT NOT NULL DEFAULT 'cpu',
    precision   TEXT NOT NULL DEFAULT 'FP32',
    batch_size  INTEGER NOT NULL DEFAULT 1,
    config      TEXT NOT NULL DEFAULT '{}',      -- full BenchmarkContext JSON
    status      TEXT NOT NULL DEFAULT 'queued',  -- queued|running|completed|failed
    progress    REAL NOT NULL DEFAULT 0.0,       -- fraction complete, 0.0–1.0
    logs        TEXT NOT NULL DEFAULT '[]',      -- JSON array of timestamped log strings
    error       TEXT,
    created_at  TEXT NOT NULL DEFAULT (datetime('now')),
    updated_at  TEXT NOT NULL DEFAULT (datetime('now')),
    started_at  TEXT,
    ended_at    TEXT
);
|
| 31 |
+
|
| 32 |
+
-- ββ Benchmark Results βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
-- Stores final computed metrics + telemetry summary after job completion.
|
| 34 |
+
-- Result rows are deleted with their parent job (ON DELETE CASCADE).
CREATE TABLE IF NOT EXISTS benchmark_results (
    id                 TEXT PRIMARY KEY,
    job_id             TEXT NOT NULL REFERENCES benchmark_jobs(id) ON DELETE CASCADE,
    metrics            TEXT NOT NULL DEFAULT '{}',  -- JSON: BenchmarkMetrics
    telemetry_summary  TEXT NOT NULL DEFAULT '{}',  -- JSON: TelemetrySummary
    created_at         TEXT NOT NULL DEFAULT (datetime('now'))
);
|
| 41 |
+
|
| 42 |
+
-- ββ Validation Logs βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
-- Immutable audit trail of every compatibility check performed.
|
| 44 |
+
-- job_id = 'pre-check' for validations that blocked job creation.
|
| 45 |
+
-- job_id carries no foreign key, so it can hold values such as 'pre-check'
-- that match no row in benchmark_jobs.
CREATE TABLE IF NOT EXISTS benchmark_validation_logs (
    id          TEXT PRIMARY KEY,
    job_id      TEXT NOT NULL,
    model_id    TEXT NOT NULL,
    dataset_id  TEXT NOT NULL,
    checks      TEXT NOT NULL DEFAULT '[]',  -- JSON: list[ValidationCheck]
    passed      INTEGER NOT NULL DEFAULT 1,  -- 1=passed, 0=failed
    created_at  TEXT NOT NULL DEFAULT (datetime('now'))
);
|
| 54 |
+
|
| 55 |
+
-- ββ Indexes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 56 |
+
-- Lookup indexes on the status, model, dataset and job columns, plus a
-- descending created_at index on jobs.
CREATE INDEX IF NOT EXISTS idx_bmark_jobs_status ON benchmark_jobs(status);
CREATE INDEX IF NOT EXISTS idx_bmark_jobs_model ON benchmark_jobs(model_id);
CREATE INDEX IF NOT EXISTS idx_bmark_jobs_dataset ON benchmark_jobs(dataset_id);
CREATE INDEX IF NOT EXISTS idx_bmark_jobs_created ON benchmark_jobs(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_bmark_results_job ON benchmark_results(job_id);
CREATE INDEX IF NOT EXISTS idx_bmark_valid_job ON benchmark_validation_logs(job_id);
CREATE INDEX IF NOT EXISTS idx_bmark_valid_model ON benchmark_validation_logs(model_id);
|
database/connection.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
database/connection.py — Async SQLite connection & migration bootstrap.
|
| 3 |
+
Single module responsible for DB lifecycle. All queries use this pool.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import aiosqlite
|
| 11 |
+
|
| 12 |
+
from config import settings
|
| 13 |
+
|
| 14 |
+
# Module-level connection (shared within the process)
|
| 15 |
+
_db: aiosqlite.Connection | None = None
|
| 16 |
+
_lock = asyncio.Lock()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
async def get_db() -> aiosqlite.Connection:
    """Return the process-wide connection, opening it on first use.

    The module lock is held across the check-and-open sequence, so
    concurrent first callers share a single ``_open_connection()`` call.
    """
    global _db
    async with _lock:
        if _db is not None:
            return _db
        _db = await _open_connection()
        return _db
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
async def _open_connection() -> aiosqlite.Connection:
    """Open the SQLite database, apply PRAGMAs and run the migrations.

    If any step after ``connect`` fails, the connection is closed before
    the error propagates, so a failed bootstrap does not leak it.

    Returns:
        A configured connection whose rows are ``aiosqlite.Row`` objects.
    """
    settings.ensure_dirs()
    conn = await aiosqlite.connect(settings.db_path, check_same_thread=False)
    try:
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA foreign_keys=ON")
        await conn.execute("PRAGMA synchronous=NORMAL")
        await conn.execute("PRAGMA cache_size=-65536")  # negative = KiB, i.e. 64 MB page cache
        await _run_migrations(conn)
        await conn.commit()
    except BaseException:
        # Covers cancellation too; the original error is always re-raised.
        await conn.close()
        raise
    return conn
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def _run_migrations(conn: aiosqlite.Connection) -> None:
    """Apply the schema files and legacy column additions idempotently.

    Step 1 executes each bundled ``.sql`` schema file that exists; they use
    ``CREATE ... IF NOT EXISTS``. Step 2 adds columns that older databases
    may lack, checking ``PRAGMA table_info`` first. Step 3 drops temporary
    tables left behind by earlier migrations, then commits.

    Args:
        conn: An open connection; it is committed before returning.
    """
    base = Path(__file__).parent

    async def column_names(table: str) -> set[str]:
        # Only called with literal table names from this function.
        # An empty result means the table does not exist.
        async with conn.execute(f"PRAGMA table_info({table})") as cur:
            return {row[1] for row in await cur.fetchall()}

    # -- STEP 1: ensure the basic tables exist --
    for schema_file in ("schema.sql", "dataset_schema.sql", "benchmark_schema.sql"):
        path = base / schema_file
        if path.exists():
            await conn.executescript(path.read_text(encoding="utf-8"))

    # -- STEP 2: legacy alterations --
    model_cols = await column_names("models")
    if model_cols:
        if "download_url" not in model_cols:
            await conn.execute("ALTER TABLE models ADD COLUMN download_url TEXT")
        if "active_version" not in model_cols:
            await conn.execute("ALTER TABLE models ADD COLUMN active_version TEXT")
        if "metrics" not in model_cols:
            await conn.execute("ALTER TABLE models ADD COLUMN metrics TEXT NOT NULL DEFAULT '{}' ")

    ds_cols = await column_names("datasets")
    if ds_cols:
        if "active_version" not in ds_cols:
            await conn.execute("ALTER TABLE datasets ADD COLUMN active_version TEXT NOT NULL DEFAULT 'v1'")
        if "roboflow_id" not in ds_cols:
            await conn.execute("ALTER TABLE datasets ADD COLUMN roboflow_id TEXT")
        if "health_score" not in ds_cols:
            await conn.execute("ALTER TABLE datasets ADD COLUMN health_score INTEGER NOT NULL DEFAULT 0")

    # The earlier model_cols snapshot is still valid for this check: none of
    # the columns added above is project_id.
    if model_cols and "project_id" not in model_cols:
        await conn.execute(
            "ALTER TABLE models ADD COLUMN project_id TEXT REFERENCES projects(id) ON DELETE CASCADE"
        )

    # -- STEP 3: drop temporary tables left by failed legacy migrations --
    # Best-effort cleanup. IF EXISTS already covers the missing-table case,
    # so only database errors are tolerated here, not cancellation or
    # programming errors.
    for leftover in ("datasets_old", "dataset_jobs_old"):
        try:
            await conn.execute(f"DROP TABLE IF EXISTS {leftover}")
        except aiosqlite.Error:
            pass

    # Commit so background jobs see the clean state immediately.
    await conn.commit()
|
| 100 |
+
|
| 101 |
+
async def close_db() -> None:
    """Close the shared connection, if one is open, and forget it.

    Does nothing when no connection is open.
    """
    global _db
    async with _lock:
        if _db is None:
            return
        await _db.close()
        _db = None
|
database/dataset_schema.sql
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- MLForge Dataset Manager — SQLite Schema Extension
|
| 3 |
+
-- Appended to existing modelzoo.db (CREATE IF NOT EXISTS)
|
| 4 |
+
-- ============================================================
|
| 5 |
+
|
| 6 |
+
-- ββ Datasets ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 7 |
+
-- One row per dataset; list-valued fields are stored as JSON text.
CREATE TABLE IF NOT EXISTS datasets (
    id               TEXT PRIMARY KEY,
    name             TEXT NOT NULL,
    description      TEXT NOT NULL DEFAULT '',
    task             TEXT NOT NULL,
    format           TEXT NOT NULL,
    source           TEXT NOT NULL DEFAULT 'roboflow',
    status           TEXT NOT NULL DEFAULT 'available',
    images           INTEGER NOT NULL DEFAULT 0,
    classes          INTEGER NOT NULL DEFAULT 0,
    class_names      TEXT NOT NULL DEFAULT '[]',  -- JSON array
    size_bytes       INTEGER NOT NULL DEFAULT 0,
    size_label       TEXT NOT NULL DEFAULT '0 B',
    local_path       TEXT,
    import_progress  REAL NOT NULL DEFAULT 0.0,   -- fraction complete, 0.0–1.0
    tags             TEXT NOT NULL DEFAULT '[]',  -- JSON array
    versions         TEXT NOT NULL DEFAULT '[]',  -- JSON array
    active_version   TEXT NOT NULL DEFAULT 'v1',
    starred          INTEGER NOT NULL DEFAULT 0,
    roboflow_id      TEXT,                        -- workspace/project slug
    created_at       TEXT NOT NULL DEFAULT (datetime('now')),
    updated_at       TEXT NOT NULL DEFAULT (datetime('now'))
);
|
| 30 |
+
|
| 31 |
+
-- ββ Dataset Jobs ββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
-- Job rows are deleted with their dataset (ON DELETE CASCADE).
CREATE TABLE IF NOT EXISTS dataset_jobs (
    id            TEXT PRIMARY KEY,
    type          TEXT NOT NULL,                   -- import|extract|validate|analyze
    status        TEXT NOT NULL DEFAULT 'queued',  -- queued|running|completed|failed|cancelled
    dataset_id    TEXT NOT NULL REFERENCES datasets(id) ON DELETE CASCADE,
    dataset_name  TEXT NOT NULL DEFAULT '',
    progress      REAL NOT NULL DEFAULT 0.0,       -- fraction complete, 0.0–1.0
    message       TEXT NOT NULL DEFAULT '',
    error         TEXT,
    meta          TEXT NOT NULL DEFAULT '{}',      -- JSON extra data
    created_at    TEXT NOT NULL DEFAULT (datetime('now')),
    updated_at    TEXT NOT NULL DEFAULT (datetime('now')),
    started_at    TEXT,
    ended_at      TEXT
);
|
| 47 |
+
|
| 48 |
+
-- ββ Dataset Images Index ββββββββββββββββββββββββββββββββββ
|
| 49 |
+
-- Populated after extraction; enables fast paginated viewer queries
|
| 50 |
+
-- Per-image index of a dataset. Rows are removed together with their dataset
-- (ON DELETE CASCADE).
CREATE TABLE IF NOT EXISTS dataset_images (
    id         TEXT    PRIMARY KEY,                  -- sha1 or sequential id
    dataset_id TEXT    NOT NULL
               REFERENCES datasets(id) ON DELETE CASCADE,
    filename   TEXT    NOT NULL,
    rel_path   TEXT    NOT NULL,                     -- relative to dataset local_path
    width      INTEGER NOT NULL DEFAULT 0,
    height     INTEGER NOT NULL DEFAULT 0,
    split      TEXT    NOT NULL DEFAULT 'train',
    ann_count  INTEGER NOT NULL DEFAULT 0,           -- fast count without parsing
    created_at TEXT    NOT NULL DEFAULT (datetime('now'))
);
|
| 61 |
+
|
| 62 |
+
-- ββ Dataset Annotations Cache βββββββββββββββββββββββββββββ
|
| 63 |
+
-- Parsed annotations stored in normalised form for fast retrieval
|
| 64 |
+
-- One row per annotation. Rows are removed together with their image
-- (ON DELETE CASCADE on image_id); dataset_id is a plain column with no
-- foreign-key constraint of its own.
CREATE TABLE IF NOT EXISTS dataset_annotations (
    id           TEXT PRIMARY KEY,
    image_id     TEXT NOT NULL
                 REFERENCES dataset_images(id) ON DELETE CASCADE,
    dataset_id   TEXT NOT NULL,
    label        TEXT NOT NULL,

    -- Bounding box; all four components are nullable
    bbox_x       REAL,
    bbox_y       REAL,
    bbox_w       REAL,
    bbox_h       REAL,
    normalised   INTEGER DEFAULT 1,
    area         REAL,
    confidence   REAL,

    ann_type     TEXT DEFAULT 'detection',
    segmentation TEXT,                               -- JSON array of points [[x,y],...]
    keypoints    TEXT,                               -- JSON array of keypoints [x,y,v,...]
    metadata     TEXT                                -- Extra JSON metadata
);
|
| 81 |
+
|
| 82 |
+
-- ββ Roboflow Metadata Cache βββββββββββββββββββββββββββββββ
|
| 83 |
+
-- Avoids redundant API calls; TTL enforced in Python layer
|
| 84 |
+
-- Key/value cache of Roboflow responses. Each row carries its own fetch time
-- and time-to-live; nothing in this schema expires rows.
CREATE TABLE IF NOT EXISTS roboflow_cache (
    cache_key  TEXT    PRIMARY KEY,                  -- workspace/project or search query hash
    payload    TEXT    NOT NULL,                     -- JSON blob
    fetched_at TEXT    NOT NULL DEFAULT (datetime('now')),
    ttl_secs   INTEGER NOT NULL DEFAULT 3600         -- 1 hour default
);
|
| 90 |
+
|
| 91 |
+
-- ββ Indexes βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 92 |
+
-- datasets: single-column indexes on the filterable attributes
CREATE INDEX IF NOT EXISTS idx_datasets_task    ON datasets (task);
CREATE INDEX IF NOT EXISTS idx_datasets_format  ON datasets (format);
CREATE INDEX IF NOT EXISTS idx_datasets_source  ON datasets (source);
CREATE INDEX IF NOT EXISTS idx_datasets_status  ON datasets (status);
CREATE INDEX IF NOT EXISTS idx_datasets_starred ON datasets (starred);

-- dataset_jobs
CREATE INDEX IF NOT EXISTS idx_djobs_status  ON dataset_jobs (status);
CREATE INDEX IF NOT EXISTS idx_djobs_dataset ON dataset_jobs (dataset_id);

-- dataset_images: by dataset, and by (dataset, split)
CREATE INDEX IF NOT EXISTS idx_dimages_dataset ON dataset_images (dataset_id);
CREATE INDEX IF NOT EXISTS idx_dimages_split   ON dataset_images (dataset_id, split);

-- dataset_annotations: by image, by dataset, and by (dataset, label)
CREATE INDEX IF NOT EXISTS idx_dann_image   ON dataset_annotations (image_id);
CREATE INDEX IF NOT EXISTS idx_dann_dataset ON dataset_annotations (dataset_id);
CREATE INDEX IF NOT EXISTS idx_dann_label   ON dataset_annotations (dataset_id, label);
|
| 107 |
+
|
| 108 |
+
-- ββ Updated-at trigger for datasets ββββββββββββββββββββββ
|
| 109 |
+
-- Stamp datasets.updated_at with the current time after every row update.
-- NOTE: the inner UPDATE would fire this trigger again if
-- PRAGMA recursive_triggers were enabled; it is off by default in SQLite.
CREATE TRIGGER IF NOT EXISTS datasets_updated_at
AFTER UPDATE ON datasets
FOR EACH ROW
BEGIN
    UPDATE datasets
       SET updated_at = datetime('now')
     WHERE id = NEW.id;
END;
|
| 113 |
+
|
| 114 |
+
-- Stamp dataset_jobs.updated_at with the current time after every row update.
-- NOTE: the inner UPDATE would fire this trigger again if
-- PRAGMA recursive_triggers were enabled; it is off by default in SQLite.
CREATE TRIGGER IF NOT EXISTS dataset_jobs_updated_at
AFTER UPDATE ON dataset_jobs
FOR EACH ROW
BEGIN
    UPDATE dataset_jobs
       SET updated_at = datetime('now')
     WHERE id = NEW.id;
END;
|
database/schema.sql
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- MLForge Model Zoo — SQLite Schema
|
| 3 |
+
-- Version: 1.0.0
|
| 4 |
+
-- ============================================================
|
| 5 |
+
|
| 6 |
+
PRAGMA journal_mode = WAL;
|
| 7 |
+
PRAGMA foreign_keys = ON;
|
| 8 |
+
|
| 9 |
+
-- ββ Models ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 10 |
+
-- One row per model. tags and hardware hold JSON arrays, metrics a JSON
-- object, all stored as TEXT.
CREATE TABLE IF NOT EXISTS models (
    id             TEXT    PRIMARY KEY,
    name           TEXT    NOT NULL,
    variant        TEXT,
    task           TEXT    NOT NULL,
    framework      TEXT    NOT NULL,
    source         TEXT    NOT NULL DEFAULT 'hf',
    provider       TEXT    NOT NULL DEFAULT '',
    description    TEXT    NOT NULL DEFAULT '',
    download_url   TEXT,                                -- explicit download source URL
    size           INTEGER NOT NULL DEFAULT 0,
    size_label     TEXT    NOT NULL DEFAULT '0B',
    tags           TEXT    NOT NULL DEFAULT '[]',       -- JSON array
    hardware       TEXT    NOT NULL DEFAULT '[]',       -- JSON array
    status         TEXT    NOT NULL DEFAULT 'available',
    downloaded     INTEGER NOT NULL DEFAULT 0,
    active_version TEXT,
    local_path     TEXT,
    metrics        TEXT    NOT NULL DEFAULT '{}',       -- JSON: latency, mAP, etc.
    downloads      INTEGER DEFAULT 0,
    rating         REAL,
    liked          INTEGER NOT NULL DEFAULT 0,
    created_at     TEXT    NOT NULL DEFAULT (datetime('now')),
    updated_at     TEXT    NOT NULL DEFAULT (datetime('now'))
);
|
| 35 |
+
|
| 36 |
+
-- ββ Model Versions ββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
-- Versions of a model. Rows are removed together with their model
-- (ON DELETE CASCADE).
CREATE TABLE IF NOT EXISTS model_versions (
    version_id   TEXT    PRIMARY KEY,
    model_id     TEXT    NOT NULL
                 REFERENCES models(id) ON DELETE CASCADE,
    version      TEXT    NOT NULL,
    label        TEXT    NOT NULL DEFAULT 'Stable',     -- Latest|Stable|Legacy
    description  TEXT,
    metrics      TEXT    NOT NULL DEFAULT '{}',         -- JSON: latency, mAP, etc.
    local_path   TEXT,
    downloaded   INTEGER NOT NULL DEFAULT 0,
    release_date TEXT,
    changelog    TEXT,
    created_at   TEXT    NOT NULL DEFAULT (datetime('now'))
);
|
| 50 |
+
|
| 51 |
+
-- ββ Jobs βββββββββββββββββββββββββββββββββββββββββββββββββ-
|
| 52 |
+
-- Background jobs. model_id is optional and references models(id) with no
-- ON DELETE action declared.
CREATE TABLE IF NOT EXISTS jobs (
    id         TEXT PRIMARY KEY,
    type       TEXT NOT NULL,                        -- download|benchmark|sync
    status     TEXT NOT NULL DEFAULT 'queued',       -- queued|running|completed|failed|cancelled
    model_id   TEXT REFERENCES models(id),
    model_name TEXT,
    progress   REAL NOT NULL DEFAULT 0.0,            -- 0.0–1.0
    error      TEXT,
    meta       TEXT NOT NULL DEFAULT '{}',           -- JSON extra data
    created_at TEXT NOT NULL DEFAULT (datetime('now')),
    updated_at TEXT NOT NULL DEFAULT (datetime('now')),
    started_at TEXT,
    ended_at   TEXT
);
|
| 66 |
+
|
| 67 |
+
-- ββ Projects βββββββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
+
-- Known projects. created_at and last_opened have no default, so the writer
-- must supply both.
CREATE TABLE IF NOT EXISTS projects (
    id          TEXT PRIMARY KEY,
    name        TEXT NOT NULL,
    path        TEXT NOT NULL,
    created_at  TEXT NOT NULL,
    last_opened TEXT NOT NULL,
    status      TEXT NOT NULL DEFAULT 'idle'
);
|
| 76 |
+
|
| 77 |
+
-- ββ Session βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
-- Stores the currently active project so backend services
|
| 79 |
+
-- (e.g. download manager) can link assets into the workspace.
|
| 80 |
+
-- Generic key/value store for session state.
CREATE TABLE IF NOT EXISTS session (
    key   TEXT PRIMARY KEY,
    value TEXT NOT NULL
);

-- At most one project row per filesystem path.
CREATE UNIQUE INDEX IF NOT EXISTS idx_projects_path
    ON projects (path);
|
| 86 |
+
|
| 87 |
+
-- ββ Audit Log βββββββββββββββββββββββββββββββββββββββββββββ
|
| 88 |
+
-- Append-style event log. model_id and job_id are plain optional columns
-- with no foreign-key constraints.
CREATE TABLE IF NOT EXISTS audit_log (
    id         INTEGER PRIMARY KEY AUTOINCREMENT,
    event_type TEXT    NOT NULL,                     -- api_request|download_start|download_ok|error|sync
    model_id   TEXT,
    job_id     TEXT,
    payload    TEXT    NOT NULL DEFAULT '{}',        -- JSON
    level      TEXT    NOT NULL DEFAULT 'info',
    created_at TEXT    NOT NULL DEFAULT (datetime('now'))
);
|
| 97 |
+
|
| 98 |
+
-- ββ FTS (Full-Text Search) ββββββββββββββββββββββββββββββββ
|
| 99 |
+
-- External-content FTS5 index over `models`: the index stores only the
-- search terms and reads column values back from `models` by rowid.
-- `id` is carried along but not tokenised (UNINDEXED).
CREATE VIRTUAL TABLE IF NOT EXISTS models_fts USING fts5(
    id UNINDEXED,
    name,
    description,
    tags,
    provider,
    task,
    framework,
    content       = 'models',
    content_rowid = 'rowid'
);
|
| 110 |
+
|
| 111 |
+
-- Triggers to keep FTS in sync
|
| 112 |
+
-- models_fts is an external-content FTS5 table (content='models'), so index
-- entries must be removed with the special 'delete' command, passing the
-- values that were indexed. A plain DELETE makes FTS5 re-read the row from
-- `models`; inside an AFTER UPDATE trigger that row already holds the NEW
-- values, so the terms of the OLD values would stay in the index.

-- Index a freshly inserted model.
CREATE TRIGGER IF NOT EXISTS models_fts_insert AFTER INSERT ON models BEGIN
    INSERT INTO models_fts(rowid, id, name, description, tags, provider, task, framework)
    VALUES (new.rowid, new.id, new.name, new.description, new.tags, new.provider, new.task, new.framework);
END;

-- Remove the index entry of a model that is about to be deleted.
CREATE TRIGGER IF NOT EXISTS models_fts_delete BEFORE DELETE ON models BEGIN
    INSERT INTO models_fts(models_fts, rowid, id, name, description, tags, provider, task, framework)
    VALUES ('delete', old.rowid, old.id, old.name, old.description, old.tags, old.provider, old.task, old.framework);
END;

-- Re-index an updated model: drop the entry built from the OLD values, then
-- add one built from the NEW values.
CREATE TRIGGER IF NOT EXISTS models_fts_update AFTER UPDATE ON models BEGIN
    INSERT INTO models_fts(models_fts, rowid, id, name, description, tags, provider, task, framework)
    VALUES ('delete', old.rowid, old.id, old.name, old.description, old.tags, old.provider, old.task, old.framework);
    INSERT INTO models_fts(rowid, id, name, description, tags, provider, task, framework)
    VALUES (new.rowid, new.id, new.name, new.description, new.tags, new.provider, new.task, new.framework);
END;
|
| 126 |
+
|
| 127 |
+
-- ββ Inference History ββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
-- One row per recorded inference call. Rows are removed together with their
-- model (ON DELETE CASCADE).
-- NOTE: the timestamp default uses unixepoch(), which needs SQLite 3.38+.
CREATE TABLE IF NOT EXISTS inference_history (
    id               TEXT PRIMARY KEY,
    model_id         TEXT NOT NULL
                     REFERENCES models(id) ON DELETE CASCADE,
    model_name       TEXT NOT NULL,
    adapter_type     TEXT NOT NULL,
    timestamp        REAL NOT NULL DEFAULT (unixepoch('now')),
    total_ms         REAL NOT NULL DEFAULT 0.0,
    quality_score    REAL,
    status           TEXT NOT NULL DEFAULT 'ok',
    request_snapshot TEXT NOT NULL DEFAULT '{}'      -- JSON
);

-- Lookup by model, and newest-first listing.
CREATE INDEX IF NOT EXISTS idx_inference_model ON inference_history (model_id);
CREATE INDEX IF NOT EXISTS idx_inference_time  ON inference_history (timestamp DESC);
|
| 142 |
+
|
| 143 |
+
-- ββ Indexes βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 144 |
+
-- models: filter columns plus a descending index on the download counter
CREATE INDEX IF NOT EXISTS idx_models_task      ON models (task);
CREATE INDEX IF NOT EXISTS idx_models_framework ON models (framework);
CREATE INDEX IF NOT EXISTS idx_models_source    ON models (source);
CREATE INDEX IF NOT EXISTS idx_models_status    ON models (status);
CREATE INDEX IF NOT EXISTS idx_models_downloads ON models (downloads DESC);

-- jobs
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs (status);
CREATE INDEX IF NOT EXISTS idx_jobs_model  ON jobs (model_id);

-- audit_log: by event type, and newest-first
CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type);
CREATE INDEX IF NOT EXISTS idx_audit_time  ON audit_log (created_at DESC);
|
main.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
main.py — FastAPI application entry point.
|
| 3 |
+
Wires together all modules, registers middleware/routes, manages lifespan.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
# Ensure the backend root is on sys.path so the top-level packages imported below (api, config, database, ...) resolve
|
| 11 |
+
# when running from the 'backend' directory.
|
| 12 |
+
backend_root = os.path.dirname(os.path.abspath(__file__))
|
| 13 |
+
if backend_root not in sys.path:
|
| 14 |
+
sys.path.insert(0, backend_root)
|
| 15 |
+
|
| 16 |
+
import asyncio
|
| 17 |
+
from contextlib import asynccontextmanager
|
| 18 |
+
from typing import AsyncIterator
|
| 19 |
+
|
| 20 |
+
import traceback
|
| 21 |
+
|
| 22 |
+
from fastapi import FastAPI, Request
|
| 23 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 24 |
+
from fastapi.responses import JSONResponse
|
| 25 |
+
|
| 26 |
+
from api.routes import jobs as jobs_router
|
| 27 |
+
from api.routes import models as models_router
|
| 28 |
+
from api.routes import sync as sync_router
|
| 29 |
+
from api.routes import datasets as datasets_router
|
| 30 |
+
from api.routes import benchmark as benchmark_router
|
| 31 |
+
from api.routes import system as system_router
|
| 32 |
+
from api.routes import projects as projects_router
|
| 33 |
+
from api.routes import inference as inference_router
|
| 34 |
+
from api.routes import training as training_router
|
| 35 |
+
from config import settings
|
| 36 |
+
from database.connection import close_db, get_db
|
| 37 |
+
from middleware.logging_middleware import RequestLoggingMiddleware
|
| 38 |
+
from observability.logger import configure_logging, get_logger
|
| 39 |
+
|
| 40 |
+
# ββ Logging bootstrap (must be first) βββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
+
configure_logging()
|
| 42 |
+
log = get_logger("main")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ββ Lifespan ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run startup work, hand control to the application, then shut down.

    Startup ensures the configured directories exist, opens the database,
    attempts stale-job recovery (a failure is logged, not raised) and, when
    ``settings.auto_sync_on_startup`` is set and the registry holds no models,
    starts a full sync as a background task.

    Shutdown closes the database. It runs in ``finally`` so the connection is
    released even when the application exits with an error.
    """
    # Startup
    settings.ensure_dirs()
    log.info("startup", host=settings.host, port=settings.port, version=settings.version)
    await get_db()  # Bootstrap DB / run migrations
    log.info("database_ready", path=str(settings.db_path))

    # Job recovery (cleanup of stale imports/benchmarks) is best-effort.
    try:
        from datasets.import_service import recover_stale_jobs

        await recover_stale_jobs()
    except Exception as e:
        log.error("job_recovery_failed", error=str(e))

    if settings.auto_sync_on_startup:
        from registry.registry import count_models

        if await count_models() == 0:
            from api.routes.sync import _run_full_sync

            log.info("auto_sync_startup_triggered")
            # The event loop keeps only a weak reference to tasks; store a
            # strong one so the sync cannot be garbage-collected mid-run.
            app.state.startup_sync_task = asyncio.create_task(_run_full_sync())

    try:
        yield  # app runs
    finally:
        # Shutdown
        await close_db()
        log.info("shutdown")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ββ Application βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 79 |
+
# The ASGI application. Title and version come from settings; the interactive
# docs are served at /docs and /redoc.
app = FastAPI(
    title=settings.app_name,
    version=settings.version,
    # Em dash restored: the description is shown verbatim in the OpenAPI docs.
    description="Production ML Model Zoo backend — local-first, traceable, extensible.",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Log an unhandled exception with its traceback and return a JSON 500.

    Args:
        request: The request whose handling raised; only its path is logged.
        exc: The exception that reached the top of the application.

    Returns:
        A 500 response with ``detail`` and ``error`` keys.
    """
    # Format from the exception object itself: traceback.format_exc() only
    # works while this coroutine runs inside the active ``except`` context.
    formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    log.error(
        "unhandled_exception",
        path=request.url.path,
        error=str(exc),
        traceback=formatted_tb,
    )
    # NOTE(review): str(exc) is sent to the client and may expose internal
    # details (paths, SQL, ...) — confirm this is acceptable for a public
    # deployment before relying on it.
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal Server Error", "error": str(exc)},
    )
|
| 102 |
+
|
| 103 |
+
# ββ Middleware βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 104 |
+
# CORS: allow the configured origins plus any localhost / 127.0.0.1 origin,
# with or without a port.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    # Raw string, so escapes use a single backslash. The doubled form
    # (r"\\.") matched a literal backslash and so never accepted 127.0.0.1
    # or any origin that carries a port.
    allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(RequestLoggingMiddleware)
|
| 113 |
+
|
| 114 |
+
# ββ Routes ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 115 |
+
# Register every feature router; the tuple order is the registration order.
for _route_module in (
    models_router,
    jobs_router,
    sync_router,
    datasets_router,
    benchmark_router,
    system_router,
    projects_router,
    inference_router,
    training_router,
):
    app.include_router(_route_module.router)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@app.get("/health", tags=["system"])
async def health() -> dict:
    """Return a status payload with the app version and registry counts."""
    # Imported at call time rather than at module import.
    from registry.registry import count_models
    from datasets.registry import count_datasets

    model_total = await count_models()
    dataset_total = await count_datasets()

    payload: dict = {"status": "ok", "version": settings.version}
    payload["model_count"] = model_total
    payload["dataset_count"] = dataset_total
    return payload
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ββ Dev runner ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 141 |
+
if __name__ == "__main__":
    import uvicorn

    # Server options come from settings; auto-reload follows the debug flag.
    run_options = {
        "host": settings.host,
        "port": settings.port,
        "reload": settings.debug,
        "log_config": None,  # We use structlog
    }
    uvicorn.run("main:app", **run_options)
|
middleware/__init__.py
ADDED
|
File without changes
|
middleware/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
middleware/__pycache__/logging_middleware.cpython-310.pyc
ADDED
|
Binary file (1.8 kB). View file
|
|
|
middleware/logging_middleware.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
middleware/logging_middleware.py — Structured request/response logging.
|
| 3 |
+
Attaches a trace_id to every request, logs timing, method, path, status.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from typing import Callable
|
| 10 |
+
|
| 11 |
+
from fastapi import Request, Response
|
| 12 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 13 |
+
|
| 14 |
+
from observability.logger import audit, get_logger, log_system_event
|
| 15 |
+
|
| 16 |
+
log = get_logger("http")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Tag each request with a trace id and log its start, outcome and timing."""

    # Requests slower than this many milliseconds are written to the audit log.
    SLOW_REQUEST_MS = 200

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log the request, run the rest of the stack, then log the outcome.

        Args:
            request: Incoming request; ``request.state.trace_id`` is set on it.
            call_next: Callable that runs the downstream application.

        Returns:
            The downstream response with ``X-Trace-Id`` and
            ``X-Response-Time`` headers added.

        Raises:
            Exception: Whatever the downstream application raises; it is
                logged here with the trace id and re-raised unchanged.
        """
        trace_id = str(uuid.uuid4())[:8]
        request.state.trace_id = trace_id
        start = time.perf_counter()

        log_system_event(
            level="info",
            message=f"API Request: {request.method} {request.url.path}",
            source="gateway",
            payload={"trace_id": trace_id, "query": str(request.url.query)},
        )

        try:
            response = await call_next(request)
        except Exception as exc:
            # Without this branch a crashing request leaves only its
            # "API Request" entry and the trace cannot be followed to the end.
            duration_ms = (time.perf_counter() - start) * 1000
            log_system_event(
                level="error",
                message=f"API Failure: {type(exc).__name__} ({duration_ms:.1f}ms)",
                source="gateway",
                payload={
                    "trace_id": trace_id,
                    "error": str(exc),
                    "latency_ms": round(duration_ms, 2),
                },
            )
            raise

        duration_ms = (time.perf_counter() - start) * 1000

        log_system_event(
            level="info" if response.status_code < 400 else "error",
            message=f"API Response: {response.status_code} ({duration_ms:.1f}ms)",
            source="gateway",
            payload={
                "trace_id": trace_id,
                "status": response.status_code,
                "latency_ms": round(duration_ms, 2),
            },
        )

        response.headers["X-Trace-Id"] = trace_id
        response.headers["X-Response-Time"] = f"{duration_ms:.1f}ms"

        # Audit slow requests
        if duration_ms > self.SLOW_REQUEST_MS:
            await audit(
                "slow_request",
                payload={
                    "path": request.url.path,
                    "duration_ms": round(duration_ms, 2),
                    "trace_id": trace_id,
                },
                level="warning",
            )

        return response
|
models/__init__.py
ADDED
|
File without changes
|
models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (141 Bytes). View file
|
|
|
models/__pycache__/benchmark.cpython-310.pyc
ADDED
|
Binary file (7.09 kB). View file
|
|
|
models/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
models/__pycache__/inference.cpython-310.pyc
ADDED
|
Binary file (5.55 kB). View file
|
|
|
models/__pycache__/job.cpython-310.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
models/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (4.06 kB). View file
|
|
|
models/__pycache__/project.cpython-310.pyc
ADDED
|
Binary file (619 Bytes). View file
|
|
|
models/__pycache__/system.cpython-310.pyc
ADDED
|
Binary file (1.86 kB). View file
|
|
|
models/benchmark.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/benchmark.py — Pydantic domain models for the Benchmark Bridge System.
|
| 3 |
+
Single source of truth for all benchmark-related data shapes across API,
|
| 4 |
+
execution engine, and database layer.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field, ConfigDict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ββ Input βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
+
|
| 16 |
+
class BenchmarkContext(BaseModel):
    """Request payload sent by the UI to start a benchmark run."""

    # Empty protected namespaces so the "model_"-prefixed field is accepted
    # without a Pydantic warning.
    model_config = ConfigDict(protected_namespaces=())

    # What to benchmark
    model_id: str
    dataset_id: str
    task: str
    framework: str

    # How to run it
    hardware: str = Field(default="cpu")
    precision: str = Field(default="FP32")
    batch_size: int = Field(default=1, ge=1, le=512)

    # Task-specific overrides
    max_tokens: int | None = Field(default=512)
    sequence_length: int | None = Field(default=512)
    img_size: int | None = Field(default=640)
    vid_stride: int | None = Field(default=1)
    stream: bool | None = Field(default=False)
    input_source: str | None = Field(default="dataset")
    video_path: str | None = Field(default=None)
    rtsp_url: str | None = Field(default=None)

    # Object Detection live preview data
    detections: list[dict[str, Any]] = Field(default_factory=list)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ββ Validation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
|
| 41 |
+
class ValidationCheck(BaseModel):
    """Outcome of one compatibility gate."""

    name: str
    passed: bool
    detail: str
    suggestion: str | None = Field(default=None)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ValidationReport(BaseModel):
    """All compatibility checks for one model + dataset pair, aggregated."""

    # Empty protected namespaces so the "model_"-prefixed field is accepted
    # without a Pydantic warning.
    model_config = ConfigDict(protected_namespaces=())

    model_id: str
    dataset_id: str
    passed: bool  # True only if ALL checks pass
    checks: list[ValidationCheck]
    errors: list[str]  # details from failed checks
    warnings: list[str] = Field(default_factory=list)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ββ Metrics βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
|
| 62 |
+
class BenchmarkMetrics(BaseModel):
    """Task-specific and hardware performance metrics of a completed run.

    Every field is optional and defaults to ``None``. Keys that are not
    declared here are kept on the model as well (``extra="allow"``).
    """

    # Pydantic v2 configuration, matching the other models in this module;
    # the nested ``class Config`` form is deprecated in v2.
    model_config = ConfigDict(extra="allow")

    # Detection / Segmentation
    mAP: float | None = None
    mAP_50: float | None = None
    mAP_50_95: float | None = None
    # Classification
    accuracy: float | None = None
    top1: float | None = None
    top5: float | None = None
    # Segmentation
    iou_mean: float | None = None
    # NLP / Generation
    rouge_l: float | None = None
    bleu: float | None = None
    perplexity: float | None = None
    tokens_per_sec: float | None = None
    # Throughput & Latency
    fps: float | None = None
    latency_mean_ms: float | None = None
    latency_p95_ms: float | None = None
    latency_p99_ms: float | None = None
    # Memory
    vram_peak_gb: float | None = None
    vram_avg_gb: float | None = None
    # Dataset info
    total_images: int | None = None
    total_tokens: int | None = None
    batch_size: int | None = None
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ββ Telemetry βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 97 |
+
|
| 98 |
+
class TelemetrySample(BaseModel):
    """One hardware reading taken while a benchmark is executing."""

    timestamp: float  # Unix epoch seconds

    # Device readings
    gpu_util_pct: float = Field(default=0.0)  # 0–100
    vram_used_gb: float = Field(default=0.0)
    vram_total_gb: float = Field(default=0.0)
    temp_c: float = Field(default=0.0)
    power_w: float = Field(default=0.0)

    # Position within the run
    batch_idx: int = Field(default=0)
    progress: float = Field(default=0.0)  # 0.0–1.0

    # Optional task-specific live data (e.g. BBoxes for detection)
    live_data: dict[str, Any] = Field(default_factory=dict)
    detections: list[dict[str, Any]] = Field(default_factory=list)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class LayerBreakdown(BaseModel):
    """One layer's entry in a bottleneck analysis."""

    name: str
    time_ms: float
    percent: float
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class TelemetrySummary(BaseModel):
    """Telemetry statistics aggregated over a whole benchmark run."""

    # GPU utilisation
    gpu_util_avg: float = Field(default=0.0)
    gpu_util_peak: float = Field(default=0.0)
    # Video memory
    vram_avg_gb: float = Field(default=0.0)
    vram_peak_gb: float = Field(default=0.0)
    # Temperature
    temp_avg_c: float = Field(default=0.0)
    temp_peak_c: float = Field(default=0.0)
    # Power draw
    power_avg_w: float = Field(default=0.0)
    power_peak_w: float = Field(default=0.0)
    # Per-layer timing entries
    layer_breakdown: list[LayerBreakdown] = Field(default_factory=list)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ── Job & Result ──────────────────────────────────────────────────────────────
|
| 134 |
+
|
| 135 |
+
class BenchmarkJob(BaseModel):
    """A benchmark job: what was requested plus its current execution state."""

    # Empty protected namespaces so the "model_"-prefixed field is accepted
    # without a Pydantic warning.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    model_id: str
    dataset_id: str
    task: str
    framework: str
    hardware: str
    precision: str
    batch_size: int
    config: dict = Field(default_factory=dict)
    status: str = "queued"  # queued|running|completed|failed
    progress: float = 0.0
    logs: list[str] = Field(default_factory=list)
    # row_to_job() passes ``error=``; without this field the value was
    # silently discarded and never reached API consumers.
    error: str | None = None
    created_at: str | None = None
    updated_at: str | None = None
    started_at: str | None = None
    ended_at: str | None = None
    last_telemetry: TelemetrySample | None = None
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class BenchmarkResult(BaseModel):
    """Stored outcome of a benchmark job: metrics plus telemetry summary."""

    # Empty protected namespaces so the "model_"-prefixed field is accepted
    # without a Pydantic warning.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    job_id: str
    metrics: BenchmarkMetrics
    telemetry_summary: TelemetrySummary
    created_at: str | None = Field(default=None)

    # Denormalized from Job for UI efficiency
    model_id: str | None = Field(default=None)
    dataset_id: str | None = Field(default=None)
    task: str | None = Field(default=None)
    framework: str | None = Field(default=None)
    hardware: str | None = Field(default=None)
    precision: str | None = Field(default=None)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ββ API Responses βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
+
|
| 174 |
+
class BenchmarkRunResponse(BaseModel):
    """Response body carrying a job id, its status and a message."""

    job_id: str
    status: str
    message: str
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ββ DB Row helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 181 |
+
|
| 182 |
+
def row_to_job(row: Any) -> BenchmarkJob:
    """Build a ``BenchmarkJob`` from a database row.

    Args:
        row: Any object ``dict()`` accepts, keyed by column name. The columns
            ``config``, ``logs`` and ``last_telemetry`` hold JSON text.

    Returns:
        The job with its JSON columns decoded.

    Raises:
        KeyError: If a required column (``id``, ``model_id``, ``dataset_id``,
            ``task``, ``framework``, ``hardware``, ``precision``,
            ``batch_size``, ``status``) is absent.
        json.JSONDecodeError: If a JSON column holds invalid JSON.
    """
    d = dict(row)
    telemetry_json = d.get("last_telemetry")
    return BenchmarkJob(
        id=d["id"],
        model_id=d["model_id"],
        dataset_id=d["dataset_id"],
        task=d["task"],
        framework=d["framework"],
        hardware=d["hardware"],
        precision=d["precision"],
        batch_size=d["batch_size"],
        # Empty or NULL JSON columns fall back to an empty container.
        config=json.loads(d.get("config") or "{}"),
        status=d["status"],
        # ``or`` also covers a NULL column: .get() then returns None, not the
        # default, and float(None) would raise TypeError.
        progress=float(d.get("progress") or 0.0),
        logs=json.loads(d.get("logs") or "[]"),
        error=d.get("error"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
        started_at=d.get("started_at"),
        ended_at=d.get("ended_at"),
        last_telemetry=TelemetrySample(**json.loads(telemetry_json)) if telemetry_json else None,
    )
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def row_to_result(row: Any) -> BenchmarkResult:
    """Convert a database row into a ``BenchmarkResult``.

    The ``metrics`` and ``telemetry_summary`` columns are decoded from JSON
    (an empty object is used when they are missing or empty); the remaining
    optional columns are copied through as-is.
    """
    record = dict(row)
    metrics_payload = json.loads(record.get("metrics") or "{}")
    telemetry_payload = json.loads(record.get("telemetry_summary") or "{}")

    fields: dict[str, Any] = {
        "id": record["id"],
        "job_id": record["job_id"],
        "metrics": BenchmarkMetrics(**metrics_payload),
        "telemetry_summary": TelemetrySummary(**telemetry_payload),
    }
    # Optional columns: absent keys become None.
    optional_columns = (
        "created_at",
        "model_id",
        "dataset_id",
        "task",
        "framework",
        "hardware",
        "precision",
    )
    for column in optional_columns:
        fields[column] = record.get(column)

    return BenchmarkResult(**fields)
|
models/dataset.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/dataset.py — Pydantic domain models for the Dataset Manager.
|
| 3 |
+
Single source of truth for all dataset-related data shapes.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from enum import Enum
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel, Field, ConfigDict
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ── Universal Dataset Viewer (UDV) Models ──────────────────────────────────
|
| 16 |
+
|
| 17 |
+
class DatasetContentType(str, Enum):
    """Content kinds used for ``UniversalDatasetItem.content_type``."""

    image = "image"
    text = "text"
    audio = "audio"
    tabular = "tabular"
|
| 22 |
+
|
| 23 |
+
class UniversalAnnotationType(str, Enum):
    """Annotation kinds used for ``UniversalAnnotation.type``."""

    detection = "detection"
    segmentation = "segmentation"
    keypoints = "keypoints"
    classification = "classification"
    span = "span"
|
| 29 |
+
|
| 30 |
+
class UniversalAnnotation(BaseModel):
    """One annotation attached to a universal dataset item.

    Only ``label`` and ``type`` are required; every geometry or score field
    defaults to ``None``.
    """

    label: str
    type: UniversalAnnotationType
    # [x, y, w, h] normalized
    bbox: list[float] | None = None
    # [[x1, y1, x2, y2, ...], ...]
    segmentation: list[list[float]] | None = None
    # [x1, y1, v1, ...]
    keypoints: list[float] | None = None
    confidence: float | None = None
    metadata: dict[str, Any] | None = None
|
| 38 |
+
|
| 39 |
+
class UniversalDatasetItem(BaseModel):
    """A single dataset item together with its annotations."""

    id: str
    content_type: DatasetContentType
    content_url: str | None = None
    # For text or raw json
    content_body: str | None = None
    filename: str | None = None
    # default_factory gives each instance its own fresh container.
    metadata: dict[str, Any] = Field(default_factory=dict)
    annotations: list[UniversalAnnotation] = Field(default_factory=list)
|
| 47 |
+
|
| 48 |
+
class UniversalViewerPage(BaseModel):
    """One page of ``UniversalDatasetItem`` records plus paging counters."""

    dataset_id: str
    page: int
    page_size: int
    total: int
    total_pages: int
    items: list[UniversalDatasetItem]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ── Enumerations ──────────────────────────────────────────────────────────────
|
| 58 |
+
|
| 59 |
+
class DatasetTask(str, Enum):
    """Task values accepted by ``Dataset.task``."""

    detection = "detection"
    classification = "classification"
    segmentation = "segmentation"
    nlp = "nlp"
    generation = "generation"
    keypoints = "keypoints"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class DatasetFormat(str, Enum):
    """Format values accepted by ``Dataset.format``."""

    yolo = "yolo"
    coco = "coco"
    voc = "voc"
    csv = "csv"
    json = "json"
    tfrecord = "tfrecord"
    custom = "custom"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DatasetSource(str, Enum):
    """Source values accepted by ``Dataset.source`` and ``ImportRequest.source``."""

    roboflow = "roboflow"
    # direct cURL / pre-signed URL download
    roboflow_curl = "roboflow_curl"
    local = "local"
    huggingface = "huggingface"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class DatasetStatus(str, Enum):
    """Status values accepted by ``Dataset.status``."""

    available = "available"
    queued = "queued"
    importing = "importing"
    extracting = "extracting"
    validating = "validating"
    imported = "imported"
    failed = "failed"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class JobType(str, Enum):
    """Kinds of dataset job."""

    # Trailing underscore because ``import`` is a reserved word in Python.
    import_ = "import"
    extract = "extract"
    validate = "validate"
    analyze = "analyze"
    delete = "delete"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class JobStatus(str, Enum):
    """Status values for a dataset job."""

    queued = "queued"
    running = "running"
    completed = "completed"
    failed = "failed"
    cancelled = "cancelled"
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class AnnotationType(str, Enum):
    """Annotation kinds used for ``Annotation.type``."""

    detection = "detection"
    segmentation = "segmentation"
    classification = "classification"
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ── Sub-models ────────────────────────────────────────────────────────────────
|
| 118 |
+
|
| 119 |
+
class DatasetSplit(BaseModel):
    """Item counts for the train / val / test splits (each defaults to 0)."""

    train: int = 0
    val: int = 0
    test: int = 0

    @property
    def total(self) -> int:
        """Return the sum of the three split counts."""
        return sum((self.train, self.val, self.test))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class DatasetVersion(BaseModel):
    """One entry in a dataset's version list; only ``version`` is required."""

    version: str
    date: str = ""
    changes: str = ""
    images: int = 0
    format: str = ""
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class DatasetStats(BaseModel):
    """Aggregate statistics computed during import/analysis.

    Every field has a zero default, so ``DatasetStats()`` is a valid empty
    statistics object.
    """

    image_count: int = 0
    annotation_count: int = 0
    class_count: int = 0
    avg_objects: float = 0.0
    missing_labels: int = 0
    empty_images: int = 0
    duplicate_count: int = 0
    health_score: float = 0.0
    # default_factory builds a fresh DatasetSplit for each instance.
    split: DatasetSplit = Field(default_factory=DatasetSplit)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# ── Core Domain Models ────────────────────────────────────────────────────────
|
| 151 |
+
|
| 152 |
+
class Dataset(BaseModel):
    """Full dataset record.

    ``use_enum_values=True`` stores the enum-typed fields (``task``,
    ``format``, ``source``, ``status``) as their plain string values.
    """

    model_config = ConfigDict(protected_namespaces=(), use_enum_values=True)

    id: str
    name: str
    description: str = ""
    task: DatasetTask
    format: DatasetFormat
    source: DatasetSource
    # Defaults are not validated unless asked, so without validate_default the
    # default stayed an Enum member while supplied values became strings.
    status: DatasetStatus = Field(
        default=DatasetStatus.available, validate_default=True
    )
    images: int = 0
    classes: int = 0
    class_names: list[str] = Field(default_factory=list)
    size_bytes: int = 0
    size_label: str = "0 B"
    local_path: str | None = None
    import_progress: float = 0.0  # 0.0-1.0
    tags: list[str] = Field(default_factory=list)
    versions: list[DatasetVersion] = Field(default_factory=list)
    active_version: str = "v1"
    stats: DatasetStats = Field(default_factory=DatasetStats)
    starred: bool = False
    roboflow_id: str | None = None  # workspace/project slug
    created_at: str | None = None
    updated_at: str | None = None
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class DatasetSummary(BaseModel):
    """Lightweight projection for list endpoints."""

    # The docstring must be the first statement in the class body; placed
    # after model_config it was a discarded string and __doc__ stayed None.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    name: str
    task: str
    format: str
    source: str
    status: str
    images: int
    classes: int
    size_label: str
    tags: list[str]
    starred: bool
    import_progress: float
    health_score: float = 0.0
    created_at: str | None = None
    updated_at: str | None = None
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ── Annotation Models ─────────────────────────────────────────────────────────
|
| 199 |
+
|
| 200 |
+
class BoundingBox(BaseModel):
    """Axis-aligned box given by its top-left corner, width and height."""

    # top-left x (pixels or normalised)
    x: float
    # top-left y
    y: float
    width: float
    height: float
    # True -> 0-1 range, False -> pixel coords
    normalised: bool = True
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class Annotation(BaseModel):
    """Unified annotation record (format-agnostic).

    Only ``label`` is required; ``type`` defaults to
    ``AnnotationType.detection`` and every other field to ``None``.
    """

    label: str
    bbox: BoundingBox | None = None
    # polygon points
    segmentation: list[list[float]] | None = None
    # [x, y, v, ...]
    keypoints: list[float] | None = None
    metadata: dict[str, Any] | None = None
    confidence: float | None = None
    area: float | None = None
    type: AnnotationType = AnnotationType.detection
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ImageRecord(BaseModel):
    """Image + its parsed annotations - returned by viewer endpoints."""

    image_id: str
    filename: str
    width: int = 0
    height: int = 0
    # relative to dataset root
    path: str
    annotations: list[Annotation] = Field(default_factory=list)
    # train|val|test
    split: str = "train"
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class ViewerPage(BaseModel):
    """Paginated viewer response: one page of ``ImageRecord`` entries."""

    dataset_id: str
    page: int
    page_size: int
    total: int
    total_pages: int
    images: list[ImageRecord]
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ── Job Models ────────────────────────────────────────────────────────────────
|
| 242 |
+
|
| 243 |
+
class DatasetJob(BaseModel):
    """A dataset job record; ``type`` and ``status`` are plain strings here."""

    model_config = ConfigDict(protected_namespaces=())

    id: str
    type: str
    status: str
    dataset_id: str
    dataset_name: str
    # 0.0-1.0
    progress: float = 0.0
    message: str = ""
    error: str | None = None
    created_at: str | None = None
    updated_at: str | None = None
    started_at: str | None = None
    ended_at: str | None = None
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ── Request/Response Schemas ─────────────────────────────────────────────────
|
| 260 |
+
|
| 261 |
+
class ImportRequest(BaseModel):
    """Request body for a dataset import.

    The "required when ..." notes below describe the intended use per source;
    this model itself declares those fields as optional.
    """

    dataset_id: str
    source: DatasetSource

    # required when source=roboflow
    roboflow_key: str | None = None
    roboflow_workspace: str | None = None
    roboflow_project: str | None = None
    roboflow_version: int = 1

    # required when source=huggingface (e.g. "microsoft/coco")
    hf_dataset_id: str | None = None
    format: DatasetFormat = DatasetFormat.yolo

    # required when source=local
    local_path: str | None = None

    # cURL / direct download (source=roboflow_curl)
    # pre-signed or direct download URL
    download_url: str | None = None
    # Custom headers for download
    headers: dict[str, str] = Field(default_factory=dict)
    # human-readable name override
    dataset_name: str | None = None
    # alias for dataset_name (used in local folder import)
    name: str | None = None
    # export format label from Roboflow cURL (e.g. "yolov8")
    curl_format: str | None = None
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class ImportResponse(BaseModel):
    """Response to an import request: job id, dataset id, status and message."""

    job_id: str
    dataset_id: str
    status: str
    message: str
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class RoboflowSearchRequest(BaseModel):
    """Request body for a Roboflow search; only ``api_key`` is required."""

    query: str = ""
    api_key: str
    workspace: str | None = None
    page: int = 0
    page_size: int = 50
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ── DB Row helpers ────────────────────────────────────────────────────────────
|
| 295 |
+
|
| 296 |
+
def _clean_enum_value(value: Any) -> Any:
    """Strip an enum-class prefix, e.g. 'DatasetStatus.failed' -> 'failed'."""
    if isinstance(value, str) and "." in value:
        return value.split(".")[-1]
    return value


def _load_json_list(raw: Any, field: str, logger: Any) -> Any:
    """Decode a JSON list column, using [] for NULL, blank or malformed input."""
    if raw is None:
        return []
    if not isinstance(raw, str):
        # Already decoded by the caller; Pydantic validates it later.
        return raw
    if not raw.strip():
        return []
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        logger.warning("Invalid JSON in dataset column %r; using empty list", field)
        return []
    if not isinstance(parsed, list):
        logger.warning("Dataset column %r is not a JSON list; using empty list", field)
        return []
    return parsed


def _load_stats(raw: Any, logger: Any) -> DatasetStats:
    """Build DatasetStats from a JSON string or dict, else return defaults."""
    if isinstance(raw, str):
        if not raw.strip():
            return DatasetStats()
        try:
            raw = json.loads(raw)
        except (ValueError, RecursionError):
            logger.warning("Invalid JSON in dataset 'stats' column; using defaults")
            return DatasetStats()
    if not isinstance(raw, dict):
        return DatasetStats()
    try:
        return DatasetStats(**raw)
    except (TypeError, ValueError):
        # Pydantic's ValidationError is a ValueError subclass. Stats are
        # best-effort, so report the problem instead of silently dropping it.
        logger.warning("Unusable dataset 'stats' payload; using defaults")
        return DatasetStats()


def row_to_dataset(row: Any) -> Dataset:
    """
    Robustly convert a DB row (sqlite3.Row or dict) to a Dataset model.

    Handles:
    1. Enum string cleaning (stripping prefixes like 'DatasetStatus.')
    2. JSON parsing for nested fields (tags, class_names, versions)
    3. Missing 'stats' object initialization
    4. Columns that are present but NULL (they fall back to the defaults)

    Args:
        row: A mapping, or any object accepted by ``dict()``.

    Returns:
        The populated ``Dataset``.

    Raises:
        Exception: Any conversion failure (for example ``KeyError`` for a
            missing ``id``/``name``/``task``/``format``/``source`` column, or
            a Pydantic validation error) is logged and re-raised unchanged.
    """
    import logging
    logger = logging.getLogger("models.dataset")

    try:
        d = dict(row) if not isinstance(row, dict) else row.copy()

        # Clean enum fields
        for field in ("status", "task", "format", "source"):
            if field in d:
                d[field] = _clean_enum_value(d[field])

        # ``x or default`` (rather than ``d.get(key, default)``) also covers a
        # key that is present with a NULL value.
        return Dataset(
            id=d["id"],
            name=d["name"],
            description=d.get("description") or "",
            task=d["task"],
            format=d["format"],
            source=d["source"],
            status=d.get("status") or "available",
            images=d.get("images") or 0,
            classes=d.get("classes") or 0,
            class_names=_load_json_list(d.get("class_names"), "class_names", logger),
            size_bytes=d.get("size_bytes") or 0,
            size_label=d.get("size_label") or "0 B",
            local_path=d.get("local_path"),
            import_progress=float(d.get("import_progress") or 0.0),
            tags=_load_json_list(d.get("tags"), "tags", logger),
            versions=_load_json_list(d.get("versions"), "versions", logger),
            active_version=d.get("active_version") or "v1",
            stats=_load_stats(d.get("stats"), logger),
            starred=bool(d.get("starred", 0)),
            roboflow_id=d.get("roboflow_id"),
            created_at=d.get("created_at"),
            updated_at=d.get("updated_at"),
        )

    except Exception as e:
        logger.error(
            "Pydantic instantiation error: %s, row keys: %s",
            e,
            list(row.keys()) if hasattr(row, "keys") else "N/A",
        )
        raise
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def row_to_job(row: Any) -> DatasetJob:
    """Convert a database row into a ``DatasetJob``.

    Args:
        row: Any object accepted by ``dict()`` holding the job columns.

    Returns:
        The populated ``DatasetJob``.

    Raises:
        KeyError: If ``id``, ``type`` or ``status`` is missing.
    """
    d = dict(row)
    # ``x or default`` also covers columns that are present but NULL, which
    # ``d.get(key, default)`` would pass through as None.
    return DatasetJob(
        id=d["id"],
        type=d["type"],
        status=d["status"],
        dataset_id=d.get("dataset_id") or "",
        dataset_name=d.get("dataset_name") or "",
        progress=float(d.get("progress") or 0.0),
        message=d.get("message") or "",
        error=d.get("error"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
        started_at=d.get("started_at"),
        ended_at=d.get("ended_at"),
    )
|
models/inference.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/inference.py — Pydantic models for the Inference Engine.
|
| 3 |
+
Covers request, response, session history, and pipeline stage telemetry.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Any, Literal
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
import time
|
| 11 |
+
import uuid
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AdapterType(str, Enum):
    """Adapter kinds selectable through ``InferenceRequest.adapter_type``."""

    YOLO = "yolo"
    TRANSFORMERS = "transformers"
    ONNX = "onnx"
    CUSTOM = "custom"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class InferencePrecision(str, Enum):
    """Precision values selectable through ``InferenceRequest.precision``."""

    FP32 = "FP32"
    FP16 = "FP16"
    INT8 = "INT8"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class YOLOConfig(BaseModel):
    """Options for the YOLO adapter; numeric fields are range-checked."""

    # Default 0.25, allowed 0.0..1.0
    confidence: float = Field(0.25, ge=0.0, le=1.0)
    # Default 0.45, allowed 0.1..0.9
    iou_threshold: float = Field(0.45, ge=0.1, le=0.9)
    class_filter: list[str] = Field(default_factory=list)
    # Default 300, allowed 1..1000
    max_detections: int = Field(300, ge=1, le=1000)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class TransformersConfig(BaseModel):
    """Options for the Transformers adapter; numeric fields are range-checked."""

    # Default 256, allowed 1..4096
    max_new_tokens: int = Field(256, ge=1, le=4096)
    # Default 0.7, allowed 0.0..2.0
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    # Default 0.9, allowed 0.0..1.0
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    # Default 50, allowed 0..200
    top_k: int = Field(50, ge=0, le=200)
    # Default 1, allowed 1..8
    beam_width: int = Field(1, ge=1, le=8)
    do_sample: bool = True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ONNXConfig(BaseModel):
    """Options for the ONNX adapter."""

    # Restricted to the two listed providers; the CUDA one is the default.
    execution_provider: Literal[
        "CUDAExecutionProvider", "CPUExecutionProvider"
    ] = "CUDAExecutionProvider"
    # Default 640, allowed 32..1280
    input_size: int = Field(640, ge=32, le=1280)
    normalize: bool = True
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class CustomConfig(BaseModel):
    """Options for the custom adapter: two script strings, empty by default."""

    preprocess_script: str = ""
    postprocess_script: str = ""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class InferenceRequest(BaseModel):
    """Request payload for one inference run."""

    # ``model_id`` starts with Pydantic's protected "model_" prefix; clearing
    # protected_namespaces matches the other models that declare model_* fields.
    model_config = {"protected_namespaces": ()}

    model_id: str
    adapter_type: AdapterType
    precision: InferencePrecision = InferencePrecision.FP16

    # Input - one of these must be set
    # NOTE(review): nothing in this model enforces that; both default to None.
    image_base64: str | None = None  # base64-encoded image
    text_input: str | None = None  # text/prompt

    # Per-adapter config
    yolo_config: YOLOConfig | None = None
    transformers_config: TransformersConfig | None = None
    onnx_config: ONNXConfig | None = None
    custom_config: CustomConfig | None = None

    # Execution
    run_mode: Literal["single", "stream"] = "single"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class PipelineStage(BaseModel):
    """One named pipeline stage with its status and optional latency/detail."""

    name: str
    status: Literal["pending", "running", "done", "error"] = "pending"
    latency_ms: float | None = None
    detail: str | None = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Detection(BaseModel):
    """One detection: box corners, confidence and class id/name."""

    # Box corners
    x1: float
    y1: float
    x2: float
    y2: float

    confidence: float
    class_id: int
    class_name: str
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class InferenceResult(BaseModel):
    """Result of one inference run: timings, outputs, pipeline trace, status."""

    # ``model_id`` starts with Pydantic's protected "model_" prefix; clearing
    # protected_namespaces matches the other models that declare model_* fields.
    model_config = {"protected_namespaces": ()}

    # Identity
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    model_id: str
    adapter_type: AdapterType
    timestamp: float = Field(default_factory=time.time)

    # Timing
    preprocess_ms: float = 0.0
    inference_ms: float = 0.0
    postprocess_ms: float = 0.0
    total_ms: float = 0.0

    # Output - adapter-specific, all optional
    detections: list[Detection] = Field(default_factory=list)
    text_output: str | None = None
    class_label: str | None = None
    confidence: float | None = None
    embeddings: list[float] | None = None
    raw_output: Any = None  # raw JSON for inspector

    # Pipeline trace
    pipeline: list[PipelineStage] = Field(default_factory=list)

    # Quality score (0-5) derived from confidence mean
    quality_score: float | None = None

    # Error
    error: str | None = None
    status: Literal["ok", "error"] = "ok"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class InferenceHistoryEntry(BaseModel):
    """One inference history record.

    ``quality_score`` and ``status`` have no default and must be supplied
    (``quality_score`` may be ``None``).
    """

    # ``model_id``/``model_name`` start with Pydantic's protected "model_"
    # prefix; clearing protected_namespaces matches the other models here.
    model_config = {"protected_namespaces": ()}

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    model_id: str
    model_name: str
    adapter_type: AdapterType
    timestamp: float = Field(default_factory=time.time)
    total_ms: float
    quality_score: float | None
    status: Literal["ok", "error"]
    # Compact snapshot of result for re-run
    request_snapshot: dict[str, Any] = Field(default_factory=dict)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class SystemVitals(BaseModel):
    """One sample of system vitals; units follow the field-name suffixes."""

    ts: float
    latency_ms: float
    fps: float
    vram_used_gb: float
    vram_total_gb: float
    gpu_temp_c: float | None = None
    cpu_pct: float = 0.0
|
models/job.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/job.py — Job domain models (download / benchmark / sync).
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Job(BaseModel):
    """A job record (download, benchmark or sync)."""

    # Allows the model_id / model_name field names below.
    model_config = {"protected_namespaces": ()}

    id: str
    # download|benchmark|sync
    type: str
    # queued|running|completed|failed|cancelled
    status: str
    model_id: str | None = None
    model_name: str | None = None
    # 0.0-1.0
    progress: float = 0.0
    error: str | None = None
    meta: dict[str, Any] = Field(default_factory=dict)
    created_at: str | None = None
    updated_at: str | None = None
    started_at: str | None = None
    ended_at: str | None = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class JobCreate(BaseModel):
    """Payload for creating a job; ``type`` defaults to ``"download"``."""

    # Allows the model_id / model_name field names below.
    model_config = {"protected_namespaces": ()}

    model_id: str
    model_name: str
    type: str = "download"
    # specific weight file / version to download
    version: str | None = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def row_to_job(row: Any) -> Job:
    """Convert a database row into a ``Job``.

    Args:
        row: Any object accepted by ``dict()`` holding the job columns.

    Returns:
        The populated ``Job``.

    Raises:
        KeyError: If ``id``, ``type`` or ``status`` is missing.
    """
    d = dict(row)
    return Job(
        id=d["id"],
        type=d["type"],
        status=d["status"],
        model_id=d.get("model_id"),
        model_name=d.get("model_name"),
        # ``or 0.0`` also covers a column that is present but NULL, where
        # ``d.get("progress", 0.0)`` would return None and float() would raise.
        progress=float(d.get("progress") or 0.0),
        error=d.get("error"),
        # NULL or empty meta decodes to an empty dict.
        meta=json.loads(d.get("meta") or "{}"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
        started_at=d.get("started_at"),
        ended_at=d.get("ended_at"),
    )
|
models/model.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models/model.py — Pydantic domain models (schema contract for API + internal).
|
| 3 |
+
Single source of truth for data shapes between all modules.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
| 12 |
+
|
| 13 |
+
# ── Enumerations ──────────────────────────────────────────────────────────────
|
| 14 |
+
|
| 15 |
+
# Each name below is a plain alias of ``str``; the comment above it lists the
# values noted for it. The values are not enforced by the alias.

# detection|classification|segmentation|generation|embedding|nlp
ModelTask = str
# pytorch|onnx|tensorflow|tflite|coreml
ModelFramework = str
# hf|onnx|local
ModelSource = str
# available|downloading|cached|error
ModelStatus = str
# gpu|cpu|edge|tpu
HardwareTarget = str
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ── Sub-models ────────────────────────────────────────────────────────────────
|
| 23 |
+
|
| 24 |
+
class ModelMetrics(BaseModel):
    """Optional model metrics; unknown extra keys are accepted and kept."""

    # Pydantic v2 configuration. The nested ``class Config`` form is the
    # deprecated v1 style; ``extra="allow"`` keeps undeclared metric keys.
    model_config = {"extra": "allow"}

    latency_ms: float | None = None
    mAP: float | None = None
    accuracy: float | None = None
    top1: float | None = None
    vram_gb: float | None = None
    fps: float | None = None
    flops: float | None = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ModelVersion(BaseModel):
    """One entry in a model's version list; only ``version`` is required."""

    version: str
    # Latest|Stable|Legacy|Nano|Small|Medium|Large|XLarge
    label: str = "Stable"
    description: str | None = None
    releaseDate: str = ""
    changelog: str | None = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ── Core Model ────────────────────────────────────────────────────────────────
|
| 46 |
+
|
| 47 |
+
class Model(BaseModel):
    """Full model record; ``id``, ``name``, ``task`` and ``framework`` are required."""

    id: str
    name: str
    variant: str | None = None
    task: ModelTask
    framework: ModelFramework
    # bytes
    size: int = 0
    size_label: str = "0 B"
    tags: list[str] = Field(default_factory=list)
    source: ModelSource = "hf"
    provider: str = ""
    description: str = ""
    # explicit download source (HF repo URL, ONNX direct URL, etc.)
    download_url: str | None = None
    local_path: str | None = None
    project_id: str | None = None
    downloaded: bool = False
    status: ModelStatus = "available"
    hardware: list[HardwareTarget] = Field(default_factory=list)
    metrics: ModelMetrics = Field(default_factory=ModelMetrics)
    versions: list[ModelVersion] = Field(default_factory=list)
    active_version: str | None = None
    rating: float | None = None
    downloads: int | None = None
    liked: bool = False
    created_at: str | None = None
    updated_at: str | None = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ModelSummary(BaseModel):
    """Lightweight projection returned in list endpoints.

    Only ``downloads`` and ``rating`` have defaults; every other field is
    required.
    """

    id: str
    name: str
    task: ModelTask
    framework: ModelFramework
    source: ModelSource
    provider: str
    size_label: str
    status: ModelStatus
    downloaded: bool
    downloads: int | None = None
    rating: float | None = None
    tags: list[str]
    hardware: list[HardwareTarget]
    metrics: ModelMetrics
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ── DB Row → Model ────────────────────────────────────────────────────────────
|
| 94 |
+
|
| 95 |
+
def _decode_json_column(value: Any, default: Any) -> Any:
    """Decode a JSON string column; pass decoded values through, default on empty."""
    if value is None or value == "":
        return default
    if isinstance(value, str):
        return json.loads(value)
    # Already decoded (for example a list or dict supplied by the caller).
    return value


def row_to_model(row: Any, versions: list[ModelVersion] | None = None) -> Model:
    """Convert an aiosqlite Row dict to a Model instance.

    Args:
        row: Any object accepted by ``dict()`` holding the model columns.
        versions: Version entries to attach; ``None`` means an empty list.

    Returns:
        The populated ``Model``.

    Raises:
        KeyError: If ``id``, ``name``, ``task`` or ``framework`` is missing.
        ValueError: If a JSON column holds malformed JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    d = dict(row)

    # metrics may come from model_versions join or not exist on models row
    metrics_raw = _decode_json_column(d.get("metrics"), {}) or {}

    # ``x or default`` (rather than ``d.get(key, default)``) also covers
    # columns that are present but NULL.
    return Model(
        id=d["id"],
        name=d["name"],
        variant=d.get("variant"),
        task=d["task"],
        framework=d["framework"],
        source=d.get("source") or "hf",
        provider=d.get("provider") or "",
        description=d.get("description") or "",
        download_url=d.get("download_url"),
        size=d.get("size") or 0,
        size_label=d.get("size_label") or "0 B",
        tags=_decode_json_column(d.get("tags"), []) or [],
        hardware=_decode_json_column(d.get("hardware"), []) or [],
        status=d.get("status") or "available",
        downloaded=bool(d.get("downloaded", 0)),
        local_path=d.get("local_path"),
        project_id=d.get("project_id"),
        downloads=d.get("downloads"),
        rating=d.get("rating"),
        liked=bool(d.get("liked", 0)),
        metrics=ModelMetrics(**metrics_raw),
        versions=versions or [],
        active_version=d.get("active_version"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
    )
|