senthil2421 committed on
Commit
ac5551d
·
0 Parent(s):

Deploy cloud brain to HF Spaces

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. Dockerfile +29 -0
  2. api/__init__.py +0 -0
  3. api/__pycache__/__init__.cpython-310.pyc +0 -0
  4. api/routes/__init__.py +0 -0
  5. api/routes/__pycache__/__init__.cpython-310.pyc +0 -0
  6. api/routes/__pycache__/benchmark.cpython-310.pyc +0 -0
  7. api/routes/__pycache__/datasets.cpython-310.pyc +0 -0
  8. api/routes/__pycache__/inference.cpython-310.pyc +0 -0
  9. api/routes/__pycache__/jobs.cpython-310.pyc +0 -0
  10. api/routes/__pycache__/models.cpython-310.pyc +0 -0
  11. api/routes/__pycache__/projects.cpython-310.pyc +0 -0
  12. api/routes/__pycache__/sync.cpython-310.pyc +0 -0
  13. api/routes/__pycache__/system.cpython-310.pyc +0 -0
  14. api/routes/__pycache__/training.cpython-310.pyc +0 -0
  15. api/routes/benchmark.py +238 -0
  16. api/routes/datasets.py +395 -0
  17. api/routes/inference.py +168 -0
  18. api/routes/jobs.py +56 -0
  19. api/routes/models.py +127 -0
  20. api/routes/projects.py +54 -0
  21. api/routes/sync.py +73 -0
  22. api/routes/system.py +97 -0
  23. api/routes/training.py +428 -0
  24. config.py +83 -0
  25. database/__init__.py +0 -0
  26. database/__pycache__/__init__.cpython-310.pyc +0 -0
  27. database/__pycache__/connection.cpython-310.pyc +0 -0
  28. database/benchmark_schema.sql +62 -0
  29. database/connection.py +106 -0
  30. database/dataset_schema.sql +117 -0
  31. database/schema.sql +152 -0
  32. main.py +149 -0
  33. middleware/__init__.py +0 -0
  34. middleware/__pycache__/__init__.cpython-310.pyc +0 -0
  35. middleware/__pycache__/logging_middleware.cpython-310.pyc +0 -0
  36. middleware/logging_middleware.py +57 -0
  37. models/__init__.py +0 -0
  38. models/__pycache__/__init__.cpython-310.pyc +0 -0
  39. models/__pycache__/benchmark.cpython-310.pyc +0 -0
  40. models/__pycache__/dataset.cpython-310.pyc +0 -0
  41. models/__pycache__/inference.cpython-310.pyc +0 -0
  42. models/__pycache__/job.cpython-310.pyc +0 -0
  43. models/__pycache__/model.cpython-310.pyc +0 -0
  44. models/__pycache__/project.cpython-310.pyc +0 -0
  45. models/__pycache__/system.cpython-310.pyc +0 -0
  46. models/benchmark.py +223 -0
  47. models/dataset.py +401 -0
  48. models/inference.py +142 -0
  49. models/job.py +51 -0
  50. models/model.py +129 -0
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use a lightweight Python image
FROM python:3.10-slim

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PORT=7860

# Set working directory
WORKDIR /app

# Install system dependencies.
# --no-install-recommends skips optional packages that the build toolchain
# does not need, which keeps the image smaller.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install
# Note: requirements.txt should be in the same directory as Dockerfile (backend/)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the backend code
COPY . .

# HuggingFace Spaces uses port 7860 by default
EXPOSE 7860

# Run the FastAPI app
# We use 0.0.0.0 to allow external connections within the container.
# NOTE(review): the port is hard-coded here; the PORT variable set above is
# not read by this exec-form CMD.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
api/__init__.py ADDED
File without changes
api/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (138 Bytes). View file
 
api/routes/__init__.py ADDED
File without changes
api/routes/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes). View file
 
api/routes/__pycache__/benchmark.cpython-310.pyc ADDED
Binary file (6.16 kB). View file
 
api/routes/__pycache__/datasets.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
api/routes/__pycache__/inference.cpython-310.pyc ADDED
Binary file (5.1 kB). View file
 
api/routes/__pycache__/jobs.cpython-310.pyc ADDED
Binary file (2.19 kB). View file
 
api/routes/__pycache__/models.cpython-310.pyc ADDED
Binary file (3.69 kB). View file
 
api/routes/__pycache__/projects.cpython-310.pyc ADDED
Binary file (2.2 kB). View file
 
api/routes/__pycache__/sync.cpython-310.pyc ADDED
Binary file (2.48 kB). View file
 
api/routes/__pycache__/system.cpython-310.pyc ADDED
Binary file (2.78 kB). View file
 
api/routes/__pycache__/training.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
api/routes/benchmark.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ api/routes/benchmark.py — Benchmark Bridge REST + WebSocket API.
3
+
4
+ Routes:
5
+ POST /benchmark/validate β€” compatibility check (no job created)
6
+ POST /benchmark/run β€” validate + create + enqueue (202)
7
+ GET /benchmark/jobs β€” list jobs (filterable)
8
+ GET /benchmark/results/all β€” list all results
9
+ GET /benchmark/{job_id} β€” single job status + logs
10
+ GET /benchmark/{job_id}/result β€” metrics + telemetry for completed job
11
+ WS /benchmark/live/{job_id} β€” real-time progress stream
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import asyncio
16
+ from typing import Any
17
+
18
+ from fastapi import APIRouter, HTTPException, Query, WebSocket, WebSocketDisconnect
19
+
20
+ import benchmark.orchestrator as orchestrator
21
+ import benchmark.registry as bench_reg
22
+ from models.benchmark import (
23
+ BenchmarkContext,
24
+ BenchmarkJob,
25
+ BenchmarkResult,
26
+ BenchmarkRunResponse,
27
+ ValidationReport,
28
+ )
29
+ from observability.logger import get_logger
30
+
31
+ log = get_logger("api.benchmark")
32
+
33
+ router = APIRouter(prefix="/benchmark", tags=["benchmark"])
34
+
35
+
36
+ # ── POST /benchmark/validate ──────────────────────────────────────────────────
37
+
38
@router.post(
    "/validate",
    response_model=ValidationReport,
    summary="Validate model ↔ dataset compatibility",
    description=(
        "Runs all 5 compatibility gates (task, format, frameworkΓ—hardware, "
        "VRAM, precision) and returns a structured report. "
        "Does NOT create a benchmark job."
    ),
)
async def validate_benchmark(ctx: BenchmarkContext) -> ValidationReport:
    """Return the orchestrator's compatibility report for *ctx*.

    Raises:
        HTTPException: passed through unchanged when the orchestrator raises
            one; otherwise a 500 whose detail is ``str(exc)``.
    """
    try:
        report = await orchestrator.validate_context(ctx)
    except HTTPException:
        # Already an HTTP-shaped error: let FastAPI render it as-is.
        raise
    except Exception as exc:
        log.exception("validate_error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return report
56
+
57
+
58
+ # ── POST /benchmark/run ───────────────────────────────────────────────────────
59
+
60
@router.post(
    "/run",
    response_model=BenchmarkRunResponse,
    status_code=202,
    summary="Start a benchmark run",
    description=(
        "Validates compatibility, creates a benchmark job, and starts async "
        "execution. Returns job_id immediately β€” poll GET /benchmark/{job_id} "
        "or connect to WS /benchmark/live/{job_id} for progress."
    ),
)
async def run_benchmark(ctx: BenchmarkContext) -> BenchmarkRunResponse:
    """Create and start a benchmark job via the orchestrator.

    Returns:
        A response carrying the new job's id and status plus a message that
        points at the live-telemetry WebSocket route.

    Raises:
        HTTPException: passed through unchanged when raised below; any other
            exception is logged and converted to a 500 with ``str(exc)``.
    """
    try:
        job = await orchestrator.create_and_run(ctx)
        queued_message = (
            f"Benchmark job {job.id} queued β€” connect to "
            f"/benchmark/live/{job.id} for live telemetry"
        )
        # Built inside the try so a response-model failure is also a 500.
        return BenchmarkRunResponse(
            job_id=job.id,
            status=job.status,
            message=queued_message,
        )
    except HTTPException:
        raise
    except Exception as exc:
        log.exception("run_benchmark_error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
84
+
85
+
86
+ # ── POST /benchmark/sync ──────────────────────────────────────────────────────────
87
+
88
@router.post(
    "/sync",
    summary="Sync benchmarks from active project folder",
    description="Scans the active project's 'benchmarks' folder and ensures all JSON records are indexed in SQLite.",
)
async def sync_benchmarks() -> dict[str, Any]:
    """Ask the orchestrator to sync project benchmarks and report the count.

    Returns:
        ``{"status": "success", "count": <value returned by the orchestrator>}``.

    Raises:
        HTTPException: passed through unchanged when the orchestrator raises
            one (previously it was rewrapped as a 500); any other exception
            is logged and converted to a 500 with ``str(exc)``.
    """
    try:
        count = await orchestrator.sync_project_benchmarks()
    except HTTPException:
        # Keep the orchestrator's own status code, matching the other
        # benchmark routes in this module.
        raise
    except Exception as exc:
        log.exception("sync_error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"status": "success", "count": count}
100
+
101
+
102
+ # ── GET /benchmark/jobs ───────────────────────────────────────────────────────
103
+
104
@router.get(
    "/jobs",
    response_model=list[BenchmarkJob],
    summary="List benchmark jobs",
)
async def list_jobs(
    status: str | None = Query(None, description="Filter by status (queued|running|completed|failed)"),
    model_id: str | None = Query(None, description="Filter by model_id"),
    limit: int = Query(100, ge=1, le=500),
) -> list[BenchmarkJob]:
    """Return benchmark jobs from the registry, forwarding the query filters."""
    jobs = await bench_reg.list_jobs(status=status, model_id=model_id, limit=limit)
    return jobs
115
+
116
+
117
+ # ── GET /benchmark/results/all ────────────────────────────────────────────────
118
+ # Must be declared BEFORE /{job_id} to avoid "results" being treated as a job_id
119
+
120
@router.get(
    "/results/all",
    response_model=list[BenchmarkResult],
    summary="List all benchmark results (leaderboard feed)",
)
async def list_results(
    limit: int = Query(100, ge=1, le=500),
) -> list[BenchmarkResult]:
    """Return up to *limit* benchmark results from the registry."""
    results = await bench_reg.list_results(limit=limit)
    return results
129
+
130
+
131
+ # ── GET /benchmark/{job_id} ───────────────────────────────────────────────────
132
+
133
@router.get(
    "/{job_id}",
    response_model=BenchmarkJob,
    summary="Get benchmark job status + logs",
)
async def get_job(job_id: str) -> BenchmarkJob:
    """Look up one benchmark job in the registry; 404 when it is not found."""
    job = await bench_reg.get_job(job_id)
    if job:
        return job
    raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
143
+
144
+
145
+ # ── GET /benchmark/{job_id}/result ───────────────────────────────────────────
146
+
147
@router.get(
    "/{job_id}/result",
    response_model=BenchmarkResult,
    summary="Get final metrics + telemetry for a completed job",
)
async def get_result(job_id: str) -> BenchmarkResult:
    """Return the stored result for *job_id*; 404 when the registry has none."""
    result = await bench_reg.get_result(job_id)
    if result:
        return result
    raise HTTPException(
        status_code=404,
        detail=f"No result for job '{job_id}' β€” job may still be running",
    )
160
+
161
+
162
+ # ── WS /benchmark/live/{job_id} ───────────────────────────────────────────────
163
+
164
@router.websocket("/live/{job_id}")
async def live_telemetry(websocket: WebSocket, job_id: str) -> None:
    """
    Push benchmark progress for *job_id* over a WebSocket.

    The job is re-read from the registry every 0.5 s. Each tick sends the
    status, rounded progress, the log entries added since the previous tick,
    and the latest telemetry. The loop ends when the job is missing,
    completed (the final result is sent if one exists) or failed. The socket
    is always closed on exit.
    """
    await websocket.accept()
    log.info("ws_connected", job_id=job_id)

    # Number of log entries already delivered to this client.
    sent_logs = 0

    try:
        while True:
            job = await bench_reg.get_job(job_id)
            if not job:
                await websocket.send_json(
                    {"error": f"Job '{job_id}' not found", "job_id": job_id}
                )
                break

            # Slice off only what the client has not seen yet.
            fresh_logs = job.logs[sent_logs:]
            sent_logs = len(job.logs)

            telemetry = job.last_telemetry
            update: dict[str, Any] = {
                "job_id": job.id,
                "status": job.status,
                "progress": round(job.progress, 4),
                "logs": fresh_logs,
                "telemetry": telemetry.model_dump() if telemetry else None,
            }
            # Surface detections at the top level for the UI visualizer.
            if telemetry and hasattr(telemetry, "detections"):
                update["detections"] = telemetry.detections

            await websocket.send_json(update)

            if job.status == "completed":
                result = await bench_reg.get_result(job_id)
                if result:
                    final_message = {
                        "job_id": job_id,
                        "status": "completed",
                        "result": result.model_dump(),
                    }
                    await websocket.send_json(final_message)
                break

            if job.status == "failed":
                failure_message = {
                    "job_id": job_id,
                    "status": "failed",
                    "error": job.error or "Unknown error",
                }
                await websocket.send_json(failure_message)
                break

            await asyncio.sleep(0.5)  # poll at 2 Hz

    except WebSocketDisconnect:
        log.info("ws_disconnected", job_id=job_id)
    except Exception as exc:
        log.exception("ws_error", job_id=job_id)
        # Best effort: the socket may already be unusable.
        try:
            await websocket.send_json({"error": str(exc), "job_id": job_id})
        except Exception:
            pass
    finally:
        # Closing an already-closed socket may raise; that is ignored.
        try:
            await websocket.close()
        except Exception:
            pass
api/routes/datasets.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ api/routes/datasets.py — Dataset Manager REST API.
3
+
4
+ Routes:
5
+ GET /datasets β€” list/search datasets
6
+ GET /datasets/{id} β€” dataset detail
7
+ POST /datasets/search/roboflow β€” search Roboflow Universe (real-time)
8
+ POST /datasets/sync/roboflow β€” sync workspace datasets to local DB
9
+ POST /datasets/{id}/import β€” initiate dataset import job
10
+ GET /datasets/{id}/images β€” paginated viewer (images + annotations)
11
+ GET /datasets/{id}/image/{img} β€” serve raw image bytes
12
+ GET /datasets/jobs β€” list import jobs
13
+ GET /datasets/jobs/{job_id} β€” single job status
14
+ POST /datasets/{id}/star β€” toggle starred
15
+ DELETE /datasets/{id} β€” delete dataset record (+ local files)
16
+ """
17
+ from __future__ import annotations
18
+
19
+ from pathlib import Path
20
+ from typing import Optional
21
+ from datetime import datetime
22
+
23
+ from fastapi import APIRouter, Depends, HTTPException, Query
24
+ from fastapi.responses import FileResponse, Response
25
+
26
+ from adapters.roboflow_adapter import RoboflowAdapter
27
+ from datasets import registry as ds_reg
28
+ from datasets.import_service import start_import
29
+ from datasets.viewer_service import get_universal_viewer_page, get_viewer_page, resolve_image_path
30
+ from models.dataset import (
31
+ Dataset, DatasetJob, DatasetSummary, DatasetSource, DatasetTask,
32
+ DatasetFormat, DatasetStatus, ImportRequest, ImportResponse,
33
+ RoboflowSearchRequest, ViewerPage, UniversalViewerPage, row_to_dataset,
34
+ )
35
+ from observability.logger import audit, get_logger
36
+
37
+ log = get_logger("datasets_route")
38
+
39
+ router = APIRouter(prefix="/datasets", tags=["datasets"])
40
+
41
+
42
+ # ── List / Search datasets ────────────────────────────────────────────────────
43
+
44
@router.get("", response_model=list[DatasetSummary])
async def list_datasets(
    task: Optional[str] = Query(None),
    format: Optional[str] = Query(None),
    source: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    starred: Optional[bool] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """List datasets from the registry as summaries.

    All query parameters are forwarded unchanged to
    ``ds_reg.get_all_datasets``. Any failure is logged and reported as a 500
    whose detail is ``str(exc)``.
    """
    filters = {
        "task": task,
        "format": format,
        "source": source,
        "status": status,
        "search": search,
        "starred": starred,
        "limit": limit,
        "offset": offset,
    }
    try:
        records = await ds_reg.get_all_datasets(**filters)
        return [_to_summary(record) for record in records]
    except Exception as exc:
        log.exception("list_datasets_error")
        raise HTTPException(status_code=500, detail=str(exc))
65
+
66
+
67
@router.get("/jobs", response_model=list[DatasetJob])
async def list_jobs(limit: int = Query(50, ge=1, le=500)):
    """Return up to *limit* dataset import jobs from the registry."""
    jobs = await ds_reg.get_all_jobs(limit=limit)
    return jobs
70
+
71
+
72
@router.get("/jobs/{job_id}", response_model=DatasetJob)
async def get_job(job_id: str):
    """Return one dataset import job; 404 when the registry has no such job."""
    job = await ds_reg.get_job(job_id)
    if job:
        return job
    raise HTTPException(404, f"Job {job_id!r} not found")
78
+
79
+
80
@router.post("/jobs/{job_id}/stop")
async def stop_job(job_id: str):
    """Cancel a running import job.

    Only the registry row is updated (status ``failed`` with a
    "Cancelled by user" error and an end timestamp); cancelling the
    underlying asyncio task is not implemented here.

    Raises:
        HTTPException: 404 when the job does not exist. Previously an unknown
            id still produced a "success" response.
    """
    if not await ds_reg.get_job(job_id):
        raise HTTPException(404, f"Job {job_id!r} not found")

    await ds_reg.update_job(
        job_id,
        status="failed",
        error="Cancelled by user",
        ended_at=datetime.utcnow().isoformat(),
    )
    return {"status": "success", "message": "Job stop requested"}
87
+
88
+
89
@router.post("/jobs/{job_id}/pause")
async def pause_job(job_id: str):
    """Pause a running import job by setting its registry status to ``paused``.

    Raises:
        HTTPException: 404 when the job does not exist. Previously an unknown
            id still produced a "success" response.
    """
    if not await ds_reg.get_job(job_id):
        raise HTTPException(404, f"Job {job_id!r} not found")

    await ds_reg.update_job(job_id, status="paused")
    return {"status": "success", "message": "Job pause requested"}
94
+
95
+
96
@router.post("/jobs/{job_id}/resume")
async def resume_job(job_id: str):
    """Resume a paused import job by setting its registry status to ``running``.

    Raises:
        HTTPException: 404 when the job does not exist. Previously an unknown
            id still produced a "success" response.
    """
    if not await ds_reg.get_job(job_id):
        raise HTTPException(404, f"Job {job_id!r} not found")

    await ds_reg.update_job(job_id, status="running")
    return {"status": "success", "message": "Job resume requested"}
101
+
102
+
103
+ # ── Roboflow Search & Sync ────────────────────────────────────────────────────
104
+
105
@router.post("/search/roboflow", response_model=list[DatasetSummary])
async def search_roboflow(req: RoboflowSearchRequest):
    """
    Search Roboflow Universe through the adapter and mirror the hits locally.

    The hits are upserted into the local registry, so they also show up in
    /datasets, and an audit entry is written before the summaries are
    returned. An adapter failure is logged and reported as a 502.
    """
    search_args = {
        "api_key": req.api_key,
        "query": req.query,
        "workspace": req.workspace,
        "page": req.page,
        "page_size": req.page_size,
    }
    try:
        hits = await RoboflowAdapter.search_datasets(**search_args)
    except Exception as exc:
        log.error("roboflow_search_error", error=str(exc))
        raise HTTPException(502, f"Roboflow API error: {exc}")

    # Persist the hits locally before answering.
    await ds_reg.bulk_upsert_datasets(hits)
    await audit("roboflow_search", {"query": req.query, "count": len(hits)})
    return [_to_summary(hit) for hit in hits]
127
+
128
+
129
@router.post("/sync/roboflow", response_model=dict)
async def sync_roboflow_workspace(
    api_key: str = Query(..., description="Roboflow API key"),
    workspace: str = Query(..., description="Workspace slug"),
):
    """Sync all datasets from a Roboflow workspace into local DB.

    Returns:
        ``{"synced": <value returned by the bulk upsert>, "workspace": workspace}``.

    Raises:
        HTTPException: 502 when the Roboflow adapter call fails.
    """
    # NOTE(review): the API key travels in the query string, where it can end
    # up in access logs; moving it to a header or body would change the
    # route's interface, so it is only flagged here.
    try:
        datasets = await RoboflowAdapter.list_workspace_datasets(api_key, workspace)
    except Exception as exc:
        # Log the upstream failure the same way the search route does; the
        # key itself is deliberately not logged.
        log.error("roboflow_sync_error", workspace=workspace, error=str(exc))
        raise HTTPException(502, f"Roboflow API error: {exc}") from exc

    count = await ds_reg.bulk_upsert_datasets(datasets)
    return {"synced": count, "workspace": workspace}
141
+
142
+
143
+ # ── Dataset detail ────────────────────────────────────────────────────────────
144
+
145
@router.get("/{dataset_id}", response_model=Dataset)
async def get_dataset(dataset_id: str):
    """Return the full dataset record; 404 when the registry has no such id."""
    record = await ds_reg.get_dataset(dataset_id)
    if record:
        return record
    raise HTTPException(404, f"Dataset {dataset_id!r} not found")
151
+
152
+
153
+ # ── Import ────────────────────────────────────────────────────────────────────
154
+
155
def _build_import_stub(dataset_id: str, req: ImportRequest) -> Dataset:
    """Build the placeholder registry record for a not-yet-registered dataset.

    The branch is chosen by ``req.source`` alone, so a HuggingFace request
    that omits ``hf_dataset_id`` is still registered as a HuggingFace dataset
    (it used to fall through and be registered as ``roboflow_curl``).
    """
    if req.source == DatasetSource.huggingface:
        name = req.hf_dataset_id or dataset_id
        roboflow_ref = req.hf_dataset_id or None
        fmt = DatasetFormat.json
        src = DatasetSource.huggingface
    elif req.source == DatasetSource.local:
        # Prefer the explicit name, then the folder name of the local path,
        # then the dataset id (also covers paths with an empty basename).
        folder_name = Path(req.local_path).name if req.local_path else ""
        name = req.name or folder_name or dataset_id
        roboflow_ref = None
        fmt = DatasetFormat.custom
        src = DatasetSource.local
    else:
        # roboflow_curl: use provided dataset_name or fall back to dataset_id
        name = req.dataset_name or dataset_id
        roboflow_ref = None
        fmt = _curl_format_to_enum(req.curl_format)
        src = DatasetSource.roboflow_curl

    # NOTE(review): the task is always recorded as detection, whatever the
    # source; confirm this is intended for non-vision datasets.
    return Dataset(
        id=dataset_id,
        name=name,
        task=DatasetTask.detection,
        format=fmt,
        source=src,
        status=DatasetStatus.available,
        roboflow_id=roboflow_ref,
        created_at=datetime.utcnow().isoformat(),
    )


@router.post("/{dataset_id}/import", response_model=ImportResponse)
async def import_dataset(dataset_id: str, req: ImportRequest):
    """
    Initiate a background import job for a dataset.
    Supports sources: roboflow | roboflow_curl | huggingface | local

    Datasets from huggingface, roboflow_curl and local sources are
    auto-registered with a stub record when missing from the registry. Any
    other source must already be registered, otherwise a 404 is returned.

    Raises:
        HTTPException: 404 for an unregistered dataset whose source is not
            auto-registered; 400 when ``start_import`` raises ``ValueError``.
    """
    req.dataset_id = dataset_id  # enforce consistency

    # Sources that are discovered outside the registry must be auto-registered.
    auto_register_sources = {
        DatasetSource.huggingface,
        DatasetSource.roboflow_curl,
        DatasetSource.local,
    }

    existing = await ds_reg.get_dataset(dataset_id)
    if not existing:
        if req.source not in auto_register_sources:
            raise HTTPException(404, f"Dataset {dataset_id!r} not found in registry. "
                                     "Run /datasets/sync/roboflow first.")
        await ds_reg.upsert_dataset(_build_import_stub(dataset_id, req))
        log.info("dataset_auto_registered", dataset_id=dataset_id, source=str(req.source))

    try:
        job_id = await start_import(req)
    except ValueError as exc:
        raise HTTPException(400, str(exc)) from exc

    await audit("dataset_import_requested", {"dataset_id": dataset_id, "source": str(req.source)})
    return ImportResponse(
        job_id=job_id,
        dataset_id=dataset_id,
        status="queued",
        message="Import job created successfully",
    )
220
+
221
+
222
def _curl_format_to_enum(curl_format: str | None) -> DatasetFormat:
    """Translate a Roboflow cURL export-format string into a DatasetFormat.

    Matching is a case-insensitive substring test, tried in the order listed
    below; the first hit wins. A missing, empty or unrecognised format maps
    to ``DatasetFormat.yolo``.
    """
    if not curl_format:
        return DatasetFormat.yolo

    lowered = curl_format.lower()
    # Order matters: earlier rows take precedence over later ones.
    keyword_table = (
        (("yolo",), DatasetFormat.yolo),
        (("coco",), DatasetFormat.coco),
        (("voc", "pascal"), DatasetFormat.voc),
        (("tfrecord",), DatasetFormat.tfrecord),
        (("csv",), DatasetFormat.csv),
        (("json", "createml"), DatasetFormat.json),
    )
    for keywords, matched_format in keyword_table:
        if any(keyword in lowered for keyword in keywords):
            return matched_format
    return DatasetFormat.yolo
240
+
241
+
242
+ # ── Viewer ──────────────────────────────────────────────────────────────────���─
243
+
244
@router.get("/{dataset_id}/universal", response_model=UniversalViewerPage)
async def get_universal_items(
    dataset_id: str,
    page: int = Query(0, ge=0),
    page_size: int = Query(20, ge=1, le=100),
    # `pattern` replaces the deprecated `regex` keyword of Query.
    split: Optional[str] = Query(None, pattern="^(train|val|test)$"),
    class_label: Optional[str] = Query(None),
):
    """
    Polymorphic dataset item viewer (UDV).
    Supports Vision, NLP, and Tabular data via the Universal Dataset Item schema.

    Unlike the image viewer route, the dataset's import status is not checked
    here; only its existence is.

    Raises:
        HTTPException: 404 when the dataset is not in the registry.
    """
    ds = await ds_reg.get_dataset(dataset_id)
    if not ds:
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")

    # Allow viewing even if not fully imported for NLP/Tabular if files exist,
    # but for Vision we usually need the index.
    return await get_universal_viewer_page(dataset_id, page, page_size, split, class_label)
263
+
264
+
265
@router.get("/{dataset_id}/images", response_model=ViewerPage)
async def get_images(
    dataset_id: str,
    page: int = Query(0, ge=0),
    page_size: int = Query(20, ge=1, le=100),
    # `pattern` replaces the deprecated `regex` keyword of Query.
    split: Optional[str] = Query(None, pattern="^(train|val|test)$"),
    class_label: Optional[str] = Query(None),
):
    """
    Paginated image + annotation data for the viewer.
    Annotations are returned in normalised [0–1] coordinates.
    Supports filtering by split and class label.

    Raises:
        HTTPException: 404 when the dataset is not in the registry; 409 when
            its status is not ``imported``.
    """
    ds = await ds_reg.get_dataset(dataset_id)
    if not ds:
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")
    if ds.status != "imported":
        raise HTTPException(409, f"Dataset is not imported yet (status: {ds.status})")

    return await get_viewer_page(dataset_id, page, page_size, split, class_label)
285
+
286
+
287
@router.get("/{dataset_id}/stats", response_model=dict)
async def get_dataset_stats(dataset_id: str):
    """Return the registry's statistics for a dataset; 404 when it is unknown."""
    if not await ds_reg.get_dataset(dataset_id):
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")

    stats = await ds_reg.get_dataset_stats(dataset_id)
    return stats
295
+
296
+
297
@router.get("/{dataset_id}/image/{image_id}")
async def serve_image(dataset_id: str, image_id: str):
    """Send one dataset image as a file response with a one-day cache header.

    The media type is chosen from the file extension; unknown extensions are
    served as ``application/octet-stream``. A 404 is returned when the image
    path cannot be resolved.
    """
    image_path = await resolve_image_path(dataset_id, image_id)
    if image_path is None:
        raise HTTPException(404, "Image not found or dataset not imported")

    known_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".bmp": "image/bmp",
        ".webp": "image/webp",
    }
    extension = image_path.suffix.lower()
    return FileResponse(
        path=str(image_path),
        media_type=known_types.get(extension, "application/octet-stream"),
        headers={"Cache-Control": "public, max-age=86400"},
    )
316
+
317
+
318
@router.get("/{dataset_id}/annotations", response_model=dict)
async def get_annotations_summary(dataset_id: str):
    """Count annotations per label for a dataset, most frequent label first.

    Returns the dataset id, the per-label counts and their grand total. The
    dataset's existence is not checked; an unknown id gives an empty
    distribution.
    """
    from database.connection import get_db

    label_counts_sql = """SELECT label, COUNT(*) as count
           FROM dataset_annotations
           WHERE dataset_id=?
           GROUP BY label
           ORDER BY count DESC"""

    db = await get_db()
    async with db.execute(label_counts_sql, (dataset_id,)) as cur:
        rows = await cur.fetchall()

    distribution = []
    total = 0
    for row in rows:
        distribution.append({"label": row["label"], "count": row["count"]})
        total += row["count"]

    return {
        "dataset_id": dataset_id,
        "class_distribution": distribution,
        "total_annotations": total,
    }
337
+
338
+
339
+ # ── Star / Delete ─────────────────────────────────────────────────────────────
340
+
341
@router.post("/{dataset_id}/star", response_model=dict)
async def toggle_star(dataset_id: str):
    """Flip the starred flag through the registry and return the new value."""
    starred_now = await ds_reg.toggle_starred(dataset_id)
    response = {"dataset_id": dataset_id, "starred": starred_now}
    return response
345
+
346
+
347
@router.delete("/{dataset_id}", response_model=dict)
async def delete_dataset(
    dataset_id: str,
    delete_files: bool = Query(False, description="Also remove local files from disk"),
):
    """Delete a dataset record and, on request, its local directory.

    The directory is removed only when ``delete_files`` is set, the record
    has a ``local_path``, and that path is an existing directory. Removal
    errors are ignored. An audit entry is written after the record is
    deleted.

    Raises:
        HTTPException: 404 when the dataset is not in the registry.
    """
    ds = await ds_reg.get_dataset(dataset_id)
    if not ds:
        raise HTTPException(404, f"Dataset {dataset_id!r} not found")

    if delete_files and ds.local_path:
        import shutil

        target_dir = Path(ds.local_path)
        # is_dir() is already False for a path that does not exist.
        if target_dir.is_dir():
            shutil.rmtree(str(target_dir), ignore_errors=True)
            log.info("dataset_files_deleted", path=str(target_dir))

    was_deleted = await ds_reg.delete_dataset(dataset_id)
    await audit("dataset_deleted", {"dataset_id": dataset_id, "files_deleted": delete_files})
    return {"deleted": was_deleted, "dataset_id": dataset_id}
366
+
367
+
368
+ # ── Helper ────────────────────────────────────────────────────────────────────
369
+
370
def _to_summary(d: Dataset) -> DatasetSummary:
    """Project a full Dataset onto the lighter DatasetSummary model.

    ``task``, ``format``, ``source`` and ``status`` are passed through
    ``str()``; the remaining fields are copied as-is. ``health_score`` is
    read from ``d.stats`` when present and defaults to 0.0 otherwise,
    including when reading it raises.
    """
    score = 0.0
    try:
        stats = d.stats if hasattr(d, "stats") else None
        if stats:
            score = getattr(stats, "health_score", 0.0)
    except Exception:
        # Best effort: a broken stats object must not break the listing.
        pass

    copied_fields = (
        "id", "name", "images", "classes", "size_label", "tags",
        "starred", "import_progress", "created_at", "updated_at",
    )
    stringified_fields = ("task", "format", "source", "status")

    summary_kwargs = {attr: getattr(d, attr) for attr in copied_fields}
    for attr in stringified_fields:
        summary_kwargs[attr] = str(getattr(d, attr))

    return DatasetSummary(health_score=score, **summary_kwargs)
api/routes/inference.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ api/routes/inference.py — Inference Engine endpoints.
3
+
4
+ POST /inference/run β€” single synchronous inference pass
5
+ POST /inference/stream β€” SSE stream (stage-by-stage pipeline events)
6
+ GET /inference/history β€” session ledger
7
+ DELETE /inference/history β€” clear session ledger
8
+ GET /inference/cache β€” currently warm models in memory
9
+ DELETE /inference/cache/{model_id} β€” evict from cache
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import json
15
+ import time
16
+
17
+ from fastapi import APIRouter, HTTPException, Response
18
+ from fastapi.responses import StreamingResponse
19
+
20
+ from inference.engine import InferenceEngine, evict_model, get_cache_status
21
+ from inference.session import clear_history, get_history, record
22
+ from models.inference import (
23
+ InferenceHistoryEntry,
24
+ InferenceRequest,
25
+ InferenceResult,
26
+ )
27
+ from observability.logger import get_logger
28
+ from registry.registry import get_model
29
+
30
+ log = get_logger("api.inference")
31
+ router = APIRouter(prefix="/inference", tags=["inference"])
32
+
33
+ _engine = InferenceEngine()
34
+
35
+
36
+ # ── Single run ───────────────────────────────────────────────────────────────
37
+
38
@router.post("/run", response_model=InferenceResult)
async def run_inference(body: InferenceRequest) -> InferenceResult:
    """Run one inference pass for the requested model and return its result.

    The run is recorded in the session history only when it did not end in
    the ``error`` status.

    Raises:
        HTTPException: 404 when the model is unknown; 500 when the engine
            reports an ``error`` status.
    """
    model = await get_model(body.model_id)
    if not model:
        raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")

    outcome = await _engine.run(body, model)
    if outcome.status == "error":
        failure_detail = outcome.error or "Inference failed"
        raise HTTPException(status_code=500, detail=failure_detail)

    await record(body, outcome, model.name)
    return outcome
52
+
53
+
54
+ # ── SSE stream ───────────────────────────────────────────────────────────────
55
+
56
@router.post("/stream")
async def stream_inference(body: InferenceRequest) -> StreamingResponse:
    """
    Server-Sent Events stream.
    Emits one JSON event per pipeline stage, then a final 'done' event with
    the full InferenceResult, or an 'error' event when the run raises.

    A background producer task runs the inference and pushes SSE frames into
    a queue; the response generator drains that queue until it sees the
    ``None`` sentinel. The task is kept referenced for its whole lifetime and
    is cancelled if the client stops consuming before it finishes.

    Client usage:
    const es = new EventSource('/inference/stream'); // POST via fetch + EventSource polyfill
    es.onmessage = e => console.log(JSON.parse(e.data));

    Raises:
        HTTPException: 404 when the model is unknown.
    """
    model = await get_model(body.model_id)
    if not model:
        raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")

    queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def _producer() -> None:
        """Run inference while pushing SSE events into the queue."""
        try:
            result = await _engine_stream(body, model, queue)
            await record(body, result, model.name)
            # Final complete event
            await queue.put(
                f"event: done\ndata: {result.model_dump_json()}\n\n"
            )
        except Exception as exc:
            # Keep a server-side trace; the client only gets the message.
            log.exception("stream_inference_error", model_id=body.model_id)
            await queue.put(
                f"event: error\ndata: {json.dumps({'error': str(exc)})}\n\n"
            )
        finally:
            await queue.put(None)  # sentinel

    # The event loop only holds a weak reference to tasks, so keep a strong
    # one here; the generator below closes over it.
    producer_task = asyncio.create_task(_producer())

    async def _generator():
        try:
            while (msg := await queue.get()) is not None:
                yield msg
        finally:
            # Client went away (or the stream ended): do not leave the
            # producer running unobserved.
            if not producer_task.done():
                producer_task.cancel()

    return StreamingResponse(
        _generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
107
+
108
+
109
async def _engine_stream(
    req: InferenceRequest,
    model,
    queue: asyncio.Queue,
) -> InferenceResult:
    """
    Run the full pipeline, then replay it into *queue* as SSE ``data:`` frames.

    The engine runs to completion first. Afterwards one 'stage' frame is
    queued per entry in ``result.pipeline``, followed by a single 'vitals'
    frame, so the stages are replayed rather than streamed live.
    """
    result = await _engine.run(req, model)

    async def _emit(event: dict) -> None:
        # Serialise one event as an SSE data frame and queue it.
        await queue.put(f"data: {json.dumps(event)}\n\n")

    for stage in result.pipeline:
        await _emit({
            "type": "stage",
            "stage": stage.model_dump(),
            "ts": time.time(),
        })
        await asyncio.sleep(0)  # give the consumer a chance to run

    await _emit({
        "type": "vitals",
        "latency_ms": result.inference_ms,
        "total_ms": result.total_ms,
        "quality": result.quality_score,
    })

    return result
141
+
142
+
143
+ # ── History ──────────────────────────────────────────────────────────────────
144
+
145
@router.get("/history", response_model=list[InferenceHistoryEntry])
async def inference_history(limit: int = 50) -> list[InferenceHistoryEntry]:
    """Return the most recent inference history entries.

    *limit* is clamped to the range 0..200 before it reaches the history
    store, so a negative value can no longer slip past the upper cap.
    """
    bounded_limit = max(0, min(limit, 200))
    return await get_history(limit=bounded_limit)
148
+
149
+
150
@router.delete("/history", status_code=204, response_model=None)
async def clear_inference_history():
    """Clear the inference session history and answer 204 No Content."""
    await clear_history()
    no_content = Response(status_code=204)
    return no_content
154
+
155
+
156
+ # ── Model cache ──────────────────────────────────────────────────────────────
157
+
158
@router.get("/cache")
async def cache_status() -> dict[str, bool]:
    """Return the engine's model-cache status mapping."""
    status_by_model = get_cache_status()
    return status_by_model
161
+
162
+
163
@router.delete("/cache/{model_id}", status_code=204, response_model=None)
async def evict_from_cache(model_id: str):
    """Evict one model from the cache; 404 when the engine reports no eviction."""
    if evict_model(model_id):
        return Response(status_code=204)
    raise HTTPException(status_code=404, detail="Model not in cache")
api/routes/jobs.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ api/routes/jobs.py — /jobs & /download endpoints.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from fastapi import APIRouter, HTTPException
7
+
8
+ from download.manager import cancel_job, enqueue_download, get_job, list_jobs
9
+ from models.job import Job, JobCreate
10
+ from observability.logger import audit, get_logger
11
+ from registry.registry import get_model
12
+
13
+ log = get_logger("api.jobs")
14
+ router = APIRouter(tags=["jobs"])
15
+
16
+
17
@router.post("/download", response_model=Job, status_code=202)
async def trigger_download(body: JobCreate) -> Job:
    """Queue a download for a registered model and return the new job.

    Raises:
        HTTPException: 404 if the model id is unknown, 409 if the model is
            already flagged as downloaded, 500 if the job cannot be read back
            after being enqueued.
    """
    registry_entry = await get_model(body.model_id)

    # Guard clauses: unknown model first, then already-downloaded model.
    if not registry_entry:
        raise HTTPException(status_code=404, detail=f"Model '{body.model_id}' not found")
    if registry_entry.downloaded:
        raise HTTPException(status_code=409, detail="Model is already cached locally")

    new_job_id = await enqueue_download(
        model_id=body.model_id,
        model_name=body.model_name,
        version=body.version,
    )

    created = await get_job(new_job_id)
    if not created:
        raise HTTPException(status_code=500, detail="Job creation failed")

    await audit("api_download_trigger", model_id=body.model_id, job_id=new_job_id)
    return created
37
+
38
+
39
@router.get("/jobs", response_model=list[Job])
async def jobs_list(status: str | None = None, limit: int = 50) -> list[Job]:
    """List jobs, forwarding the optional status filter and limit to ``list_jobs``."""
    matching_jobs = await list_jobs(status=status, limit=limit)
    return matching_jobs
42
+
43
+
44
@router.get("/jobs/{job_id}", response_model=Job)
async def job_detail(job_id: str) -> Job:
    """Fetch a single job by id; respond 404 when ``get_job`` finds nothing."""
    found = await get_job(job_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
50
+
51
+
52
@router.delete("/jobs/{job_id}", status_code=204, response_model=None)
async def job_cancel(job_id: str) -> None:
    """Request cancellation of a job; respond 409 when ``cancel_job`` declines."""
    cancelled = await cancel_job(job_id)
    if cancelled:
        return None
    raise HTTPException(status_code=409, detail="Job cannot be cancelled")
api/routes/models.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ api/routes/models.py — /models REST endpoints.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from typing import Annotated
7
+
8
+ from fastapi import APIRouter, HTTPException, Query, UploadFile, File, Form
9
+ from models.model import Model
10
+ from observability.logger import audit, get_logger
11
+ from registry.registry import count_models, get_model, list_models
12
+ from projects.service import get_active_project_id, import_local_model
13
+ from projects.registry import get_project
14
+ from pathlib import Path
15
+ import os
16
+ import tempfile
17
+
18
+ log = get_logger("api.models")
19
+ router = APIRouter(prefix="/models", tags=["models"])
20
+
21
+
22
def _first_file_under(directory: Path) -> str | None:
    """Return the path of the first regular file found below *directory*.

    Best effort: any error raised while walking the tree yields ``None``.
    """
    try:
        return next((str(entry) for entry in directory.rglob("*") if entry.is_file()), None)
    except Exception:
        return None


@router.get("", response_model=list[Model])
async def index(
    search: Annotated[str | None, Query()] = None,
    task: Annotated[list[str] | None, Query()] = None,
    framework: Annotated[list[str] | None, Query()] = None,
    hardware: Annotated[list[str] | None, Query()] = None,
    source: Annotated[list[str] | None, Query()] = None,
    downloaded: Annotated[bool | None, Query()] = None,
    sort_by: Annotated[str, Query()] = "downloads",
    sort_dir: Annotated[str, Query()] = "desc",
    limit: Annotated[int, Query(ge=1, le=1000)] = 200,
    offset: Annotated[int, Query(ge=0)] = 0,
    project_id: Annotated[str | None, Query()] = None,
) -> list[Model]:
    """
    List and search models.

    All filter, sort and paging arguments are forwarded to ``list_models``.
    When a project is in effect (explicit ``project_id`` or the active
    project), each model's ``downloaded`` / ``local_path`` fields are
    overwritten to reflect whether ``<project path>/models/<model id>``
    contains at least one file.
    """
    effective_project_id = project_id or await get_active_project_id()

    models = await list_models(
        search=search,
        tasks=task,
        frameworks=framework,
        hardware=hardware,
        sources=source,
        downloaded=downloaded,
        sort_by=sort_by,
        sort_dir=sort_dir,
        limit=limit,
        offset=offset,
        project_id=effective_project_id,
    )

    # Resolve the project (if any) whose workspace defines the cache state.
    proj = await get_project(effective_project_id) if effective_project_id else None
    if proj:
        workspace_models = Path(proj.path) / "models"
        rewritten: list[Model] = []
        for entry in models:
            candidate_dir = workspace_models / entry.id
            # is_dir() is False for missing paths, so it also covers existence.
            local_file = _first_file_under(candidate_dir) if candidate_dir.is_dir() else None
            if local_file:
                patch = {"downloaded": True, "local_path": local_file}
            else:
                # Not present in this project → treat as not cached for this project.
                patch = {"downloaded": False, "local_path": None}
            rewritten.append(entry.model_copy(update=patch))
        models = rewritten

    await audit("api_list_models", payload={"count": len(models), "search": search})
    return models
89
+
90
+
91
@router.post("/import", response_model=Model)
async def import_model(
    name: Annotated[str, Form()],
    task: Annotated[str, Form()],
    framework: Annotated[str, Form()],
    file: UploadFile = File(...),
) -> Model:
    """Import an uploaded model file via ``import_local_model``.

    The upload is spooled to a temporary file (keeping the original file
    extension) in fixed-size chunks so large model files are never held in
    memory as a whole.  The temporary file is always removed afterwards.

    Raises:
        HTTPException: re-raised untouched if the import itself raises one;
            any other failure is logged and reported as a 500.
    """
    chunk_size = 1024 * 1024  # 1 MiB per read keeps memory use flat
    suffix = os.path.splitext(file.filename or "")[1]
    tmp_path: str | None = None

    try:
        # delete=False: the file must outlive the `with` so the importer can
        # open it by path; cleanup happens in `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            while chunk := await file.read(chunk_size):
                tmp.write(chunk)

        return await import_local_model(
            name=name,
            task=task,
            framework=framework,
            source_file_path=tmp_path,
        )
    except HTTPException:
        # Preserve deliberate status codes instead of masking them as 500.
        raise
    except Exception as e:
        log.error("model_import_failed", error=str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
119
+
120
+
121
@router.get("/{model_id}", response_model=Model)
async def detail(model_id: str) -> Model:
    """Return one model from the registry, or 404 if ``get_model`` finds none.

    A successful lookup is recorded through ``audit`` before returning.
    """
    entry = await get_model(model_id)
    if entry:
        await audit("api_get_model", model_id=model_id)
        return entry
    raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
api/routes/projects.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """api/routes/projects.py — /projects REST endpoints."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from fastapi import APIRouter, HTTPException, Query
6
+
7
+ from models.project import Project
8
+ from observability.logger import audit
9
+ from projects.registry import delete_project, get_project, list_projects, touch_last_opened, upsert_project
10
+ from projects.service import set_active_project
11
+
12
+
13
+ router = APIRouter(prefix="/projects", tags=["projects"])
14
+
15
+
16
@router.get("", response_model=list[Project])
async def projects_list(
    limit: int = Query(200, ge=1, le=1000),
    offset: int = Query(0, ge=0),
) -> list[Project]:
    """Return one page of projects and audit how many were returned."""
    page = await list_projects(limit=limit, offset=offset)
    await audit("api_list_projects", payload={"count": len(page)})
    return page
24
+
25
+
26
@router.post("", response_model=Project)
async def projects_upsert(project: Project) -> Project:
    """Create or update a project record.

    Missing ``created_at`` / ``last_opened`` values are filled with the
    current UTC time (ISO-8601) before the project is passed to
    ``upsert_project``.  The (possibly completed) project is returned.
    """
    # Imported here because this module has no top-level datetime import;
    # without it the fallback below raised NameError.
    from datetime import datetime, timezone

    now_iso = datetime.now(timezone.utc).isoformat()
    if not project.created_at:
        project.created_at = now_iso
    if not project.last_opened:
        project.last_opened = now_iso

    await upsert_project(project)
    await audit("api_upsert_project", payload={"project_id": project.id})
    return project
37
+
38
+
39
@router.post("/{project_id}/open", status_code=204, response_model=None)
async def projects_open(project_id: str) -> None:
    """Mark a project as opened and make it the active project.

    ``touch_last_opened`` is called first; the project is activated only if
    ``get_project`` returns it.  The audit entry is written in either case.
    """
    await touch_last_opened(project_id)
    if opened := await get_project(project_id):
        await set_active_project(opened.id, opened.path)
    await audit("api_open_project", payload={"project_id": project_id})
46
+
47
+
48
@router.delete("/{project_id}", status_code=204, response_model=None)
async def projects_delete(project_id: str) -> None:
    """Delete a project; respond 404 when ``delete_project`` reports failure."""
    removed = await delete_project(project_id)
    if removed:
        await audit("api_delete_project", payload={"project_id": project_id})
        return None
    raise HTTPException(status_code=404, detail=f"Project '{project_id}' not found")
54
+
api/routes/sync.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ api/routes/sync.py — /sync endpoint: fetch fresh models from all adapters.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from fastapi import APIRouter, BackgroundTasks
7
+
8
+ from adapters.hf_adapter import HFAdapter
9
+ from adapters.onnx_adapter import ONNXAdapter
10
+ from observability.logger import audit, get_logger
11
+ from registry.registry import bulk_upsert, count_models
12
+
13
+ log = get_logger("api.sync")
14
+ router = APIRouter(tags=["sync"])
15
+
16
+
17
async def _run_full_sync() -> None:
    """Pull models from the HF and ONNX adapters into the registry.

    Steps, in order: fetch + upsert HF models, prune HF rows whose task is
    outside the allowed set, prune HF generation/embedding rows that carry
    none of the vision-related tags, commit, then fetch + upsert ONNX models.
    The combined number of fetched models is logged and audited.
    """
    from database.connection import get_db

    log.info("sync_start")

    async with HFAdapter() as hf:
        hf_models = await hf.fetch_models()
        await bulk_upsert(hf_models)
        hf_count = len(hf_models)
        log.info("sync_hf_done", count=hf_count)

    # Prune any HF models outside the allowed task set (e.g. legacy NLP entries)
    allowed_tasks = {"detection", "classification", "segmentation", "generation", "embedding"}
    task_params = tuple(sorted(allowed_tasks))
    placeholders = ",".join("?" for _ in task_params)

    db = await get_db()
    await db.execute(
        f"DELETE FROM models WHERE source = 'hf' AND task NOT IN ({placeholders})",
        task_params,
    )
    # Prune non-vision generation/embedding HF models. We rely on the adapter
    # adding the pipeline_tag as a normalised tag (e.g. text_to_image).
    await db.execute(
        """
        DELETE FROM models
        WHERE source = 'hf'
        AND task IN ('generation','embedding')
        AND (
        tags NOT LIKE '%text_to_image%'
        AND tags NOT LIKE '%image_to_image%'
        AND tags NOT LIKE '%image_feature_extraction%'
        )
        """,
    )
    await db.commit()

    onnx_models = await ONNXAdapter().fetch_models()
    await bulk_upsert(onnx_models)
    onnx_count = len(onnx_models)
    log.info("sync_onnx_done", count=onnx_count)

    total = hf_count + onnx_count
    log.info("sync_complete", total=total)
    await audit("sync_complete", payload={"total": total})
61
+
62
+
63
@router.post("/sync", status_code=202)
async def trigger_sync(background_tasks: BackgroundTasks) -> dict:
    """
    Schedule ``_run_full_sync`` as a background task and return immediately.

    The response carries the model count at the time of the request.
    """
    background_tasks.add_task(_run_full_sync)

    model_count = await count_models()
    log.info("sync_triggered", current_model_count=model_count)
    await audit("sync_triggered", payload={"current": model_count})

    response = {"message": "Sync started", "current_model_count": model_count}
    return response
api/routes/system.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """api/routes/system.py — System metrics endpoints."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+
8
+ from fastapi import APIRouter, Query
9
+ from fastapi.responses import StreamingResponse
10
+
11
+ from models.system import SystemMetrics
12
+ from system.metrics import sample_metrics
13
+
14
+ router = APIRouter(prefix="/system", tags=["system"])
15
+
16
+
17
@router.get("/metrics", response_model=SystemMetrics)
async def get_metrics(gpu_index: int = Query(0, ge=0)) -> SystemMetrics:
    """Take one ``sample_metrics`` reading and wrap it in ``SystemMetrics``.

    ``ts``, ``cpu_pct``, ``ram_used_mb`` and ``ram_total_mb`` must be present
    in the sample; the remaining fields are optional.
    """
    sample = sample_metrics(gpu_index=gpu_index)

    required = ("ts", "cpu_pct", "ram_used_mb", "ram_total_mb")
    optional = ("cpu_model", "cpu_freq_mhz", "cpu_count", "gpu")

    fields = {key: sample[key] for key in required}
    fields.update((key, sample.get(key)) for key in optional)
    # List-valued fields default to empty lists rather than None.
    fields["disks"] = sample.get("disks", [])
    fields["network"] = sample.get("network", [])
    return SystemMetrics(**fields)
32
+
33
+
34
@router.get("/metrics/stream")
async def stream_metrics(
    gpu_index: int = Query(0, ge=0),
    hz: float = Query(2.0, ge=0.2, le=20.0),
):
    """Server-Sent Events stream of system metrics.

    Emits one ``data:`` frame per sample, sleeping ``1 / hz`` seconds between
    samples.  Sampling/serialisation errors are printed and the stream
    continues.
    """
    period = 1.0 / float(hz)
    sse_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }

    async def metric_events():
        # Initial comment helps some proxies establish the stream
        yield ": connected\n\n"
        while True:
            try:
                frame = "data: " + json.dumps(sample_metrics(gpu_index=gpu_index)) + "\n\n"
                yield frame
            except Exception as exc:
                # Report the problem but keep the stream alive
                print(f"Metrics streaming error: {exc}")
            await asyncio.sleep(period)

    return StreamingResponse(
        metric_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
67
+
68
+
69
@router.get("/logs/stream")
async def stream_system_logs():
    """SSE stream of global system and gateway logs.

    Registers a fresh queue in ``_sys_log_subs``, forwards every queued entry
    as a ``data:`` frame, sends a heartbeat comment after 30 s of silence and
    stops when a ``None`` sentinel arrives.  The queue is unregistered when
    the generator finishes.
    """
    from observability.logger import _sys_log_subs

    inbox: asyncio.Queue = asyncio.Queue()
    _sys_log_subs.append(inbox)

    async def log_events():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    record = await asyncio.wait_for(inbox.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield ": heartbeat\n\n"
                else:
                    if record is None:
                        return
                    yield "data: " + json.dumps(record) + "\n\n"
        finally:
            try:
                _sys_log_subs.remove(inbox)
            except ValueError:
                pass  # already unregistered

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(log_events(), media_type="text/event-stream", headers=sse_headers)
api/routes/training.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ api/routes/training.py — Training Engine REST + SSE endpoints.
3
+
4
+ POST /train/start — create and launch a training run
5
+ POST /train/stop — cancel a running run
6
+ POST /train/pause — pause a running run
7
+ POST /train/resume — resume a paused run
8
+ GET /train/status — run status + progress snapshot
9
+ GET /train/runs — list all runs
10
+ GET /train/runs/{run_id} — single run detail
11
+ GET /train/schema — UI schema for task/model/dataset combo
12
+ GET /train/checkpoints — checkpoints for a run (stub)
13
+ POST /train/checkpoints/{id}/export — export a checkpoint (stub)
14
+ GET /train/metrics/stream — SSE: real-time metrics ticks
15
+ GET /train/logs/stream — SSE: real-time log entries
16
+ GET /train/resources/stream — SSE: real-time resource ticks
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import asyncio
21
+ import json
22
+ import time
23
+ import os
24
+
25
+ from fastapi import APIRouter, HTTPException, Query
26
+ from fastapi.responses import StreamingResponse
27
+
28
+ from observability.logger import get_logger
29
+ from training import run_manager
30
+ from training.schema_engine import generate_schema
31
+ from training.schemas import (
32
+ CheckpointOut,
33
+ PauseTrainRequest,
34
+ ResumeTrainRequest,
35
+ StartTrainRequest,
36
+ StartTrainResponse,
37
+ StopTrainRequest,
38
+ TrainRunOut,
39
+ TrainStatusResponse,
40
+ TrainingSchemaResponse,
41
+ )
42
+
43
+ log = get_logger("api.training")
44
+ router = APIRouter(prefix="/train", tags=["training"])
45
+
46
+ # ── Helpers ────────────────────────────────────────────────────────────────────
47
+
48
+ def _format_duration(seconds: float) -> str:
49
+ h = int(seconds // 3600)
50
+ m = int((seconds % 3600) // 60)
51
+ s = int(seconds % 60)
52
+ return f"{h}h {m}m {s}s"
53
+
54
+
55
def _run_to_out(run: run_manager.TrainRun) -> TrainRunOut:
    """Convert an in-memory ``TrainRun`` into its ``TrainRunOut`` API shape.

    The duration is measured from ``created_at`` to ``completed_at``, or to
    the current time while ``completed_at`` is unset.
    """
    end_ts = run.completed_at or time.time()
    duration_text = _format_duration(end_ts - run.created_at)

    payload = dict(
        id=run.run_id,
        run_number=run.run_number,
        model_id=run.model_id,
        model_name=run.model_name,
        dataset_id=run.dataset_id,
        dataset_name=run.dataset_name,
        task=run.task,
        status=run.status,
        epochs_done=run.epoch,
        total_epochs=run.total_epochs,
        best_metric=run.best_metric,
        final_loss=run.final_loss,
        duration=duration_text,
        created_at=run.created_at,
        completed_at=run.completed_at,
        hyperparams=run.hyperparams,
    )
    return TrainRunOut(**payload)
75
+
76
+
77
+ # ── Control endpoints ─────────────────────────────────────────────────────────
78
+
79
@router.post("/start", response_model=StartTrainResponse)
async def start_training(body: StartTrainRequest) -> StartTrainResponse:
    """Create and immediately launch a training run.

    Friendly model/dataset names are resolved on a best-effort basis: if a
    registry cannot be imported or queried, the raw id is used as the name
    and the failure is logged instead of being dropped silently.
    """
    # Defaults used whenever a lookup fails or returns nothing.
    model_name = body.model_id
    dataset_name = body.dataset_id

    try:
        from registry.registry import get_model

        m = await get_model(body.model_id)
        if m:
            model_name = m.name
    except Exception as exc:
        # Best effort only: never block a training run on a display name.
        log.info("model_name_lookup_failed", model_id=body.model_id, error=str(exc))

    try:
        from datasets.registry import get_dataset

        d = await get_dataset(body.dataset_id)
        if d:
            # The registry may hand back either a mapping or an object.
            if isinstance(d, dict):
                dataset_name = d.get("name", body.dataset_id)
            else:
                dataset_name = getattr(d, "name", body.dataset_id)
    except Exception as exc:
        log.info("dataset_name_lookup_failed", dataset_id=body.dataset_id, error=str(exc))

    run = run_manager.create_run(
        model_id=body.model_id,
        model_name=model_name,
        dataset_id=body.dataset_id,
        dataset_name=dataset_name,
        task=body.task,
        hyperparams=body.hyperparams,
        augmentation=body.augmentation,
        scheduler=body.scheduler,
        project_id=body.project_id,
    )
    run_manager.start_run(run)

    log.info("training_started", run_id=run.run_id, model=body.model_id)
    return StartTrainResponse(
        run_id=run.run_id,
        status=run.status,
        message=f"Training run {run.run_id} started.",
    )
119
+
120
+
121
@router.post("/stop", status_code=200)
async def stop_training(body: StopTrainRequest) -> dict:
    """Stop the run named in the request; 404 if the run id is unknown."""
    target = run_manager.get_run(body.run_id)
    if not target:
        raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")

    run_manager.stop_run(target)
    log.info("training_stopped", run_id=body.run_id)

    # Status is read after stop_run so it reflects whatever stop_run set.
    outcome = {"run_id": body.run_id, "status": target.status}
    return outcome
129
+
130
+
131
@router.post("/pause", status_code=200)
async def pause_training(body: PauseTrainRequest) -> dict:
    """Pause the run named in the request; 404 if the run id is unknown."""
    target = run_manager.get_run(body.run_id)
    if not target:
        raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")

    run_manager.pause_run(target)
    outcome = {"run_id": body.run_id, "status": target.status}
    return outcome
138
+
139
+
140
@router.post("/resume", status_code=200)
async def resume_training(body: ResumeTrainRequest) -> dict:
    """Resume the run named in the request; 404 if the run id is unknown."""
    target = run_manager.get_run(body.run_id)
    if not target:
        raise HTTPException(status_code=404, detail=f"Run '{body.run_id}' not found")

    run_manager.resume_run(target)
    outcome = {"run_id": body.run_id, "status": target.status}
    return outcome
147
+
148
+
149
@router.get("/status", response_model=TrainStatusResponse)
async def get_train_status(run_id: str = Query(...)) -> TrainStatusResponse:
    """Return a progress snapshot for one run; 404 if the run id is unknown.

    ``total_steps`` is reported as ``total_epochs * 100``.
    """
    current = run_manager.get_run(run_id)
    if not current:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    snapshot = dict(
        run_id=current.run_id,
        status=current.status,
        epoch=current.epoch,
        total_epochs=current.total_epochs,
        step=current.step,
        total_steps=current.total_epochs * 100,
        eta_seconds=current.eta_seconds,
        elapsed_seconds=current.elapsed_seconds,
    )
    return TrainStatusResponse(**snapshot)
164
+
165
+
166
+ # ── Run history ───────────────────────────────────────────────────────────────
167
+
168
@router.get("/runs", response_model=list[TrainRunOut])
async def list_runs() -> list[TrainRunOut]:
    """Return every known run, in the reverse of ``run_manager.list_runs()`` order."""
    newest_first = reversed(run_manager.list_runs())
    return list(map(_run_to_out, newest_first))
171
+
172
+
173
@router.get("/runs/{run_id}", response_model=TrainRunOut)
async def get_run(run_id: str) -> TrainRunOut:
    """Return the detail view of one run; 404 if the run id is unknown."""
    found = run_manager.get_run(run_id)
    if found:
        return _run_to_out(found)
    raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
179
+
180
+
181
+ # ── Schema Engine ─────────────────────────────────────────────────────────────
182
+
183
@router.get("/schema", response_model=TrainingSchemaResponse)
async def get_schema(
    model_id: str = Query(""),
    dataset_id: str = Query(""),
    task: str = Query("detection"),
) -> TrainingSchemaResponse:
    """Build the training UI schema for a task/model/dataset combination.

    The mapping returned by ``generate_schema`` is expanded into a
    ``TrainingSchemaResponse``.
    """
    generated = generate_schema(task=task, model_id=model_id, dataset_id=dataset_id)
    response = TrainingSchemaResponse(**generated)
    return response
191
+
192
+
193
+ # ── Checkpoints (stub β€” extend when artifact storage is wired) ────────────────
194
+
195
@router.get("/checkpoints", response_model=list[CheckpointOut])
async def list_checkpoints(run_id: str = Query(...)) -> list[CheckpointOut]:
    """Stub: validates the run id (404 if unknown) and returns an empty list."""
    if run_manager.get_run(run_id):
        no_checkpoints: list[CheckpointOut] = []
        return no_checkpoints
    raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
202
+
203
+
204
@router.post("/checkpoints/{checkpoint_id}/export")
async def export_checkpoint(checkpoint_id: str, body: dict = {}) -> dict:
    """Stub endpoint: always answers 501 Not Implemented."""
    not_implemented = HTTPException(
        status_code=501,
        detail="Checkpoint export not yet implemented",
    )
    raise not_implemented
207
+
208
+
209
+ # ── SSE: Metrics stream ────────────────────────────────────────────────────────
210
+
211
@router.get("/metrics/stream")
async def stream_metrics(run_id: str = Query(...)) -> StreamingResponse:
    """
    SSE stream of metrics ticks for one run (404 if the run id is unknown).

    A new queue is appended to ``run.metrics_subs``; every item taken from it
    is sent as a ``data:`` frame.  After 30 s without an item a heartbeat
    comment is sent.  A ``None`` item ends the stream, and the queue is
    removed from ``run.metrics_subs`` when the generator finishes.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    inbox: asyncio.Queue = asyncio.Queue()
    run.metrics_subs.append(inbox)

    async def tick_events():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    item = await asyncio.wait_for(inbox.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    # Heartbeat to keep connection alive
                    yield ": heartbeat\n\n"
                else:
                    if item is None:
                        return
                    yield "data: " + json.dumps(item) + "\n\n"
        finally:
            try:
                run.metrics_subs.remove(inbox)
            except ValueError:
                pass  # already unsubscribed

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(tick_events(), media_type="text/event-stream", headers=sse_headers)
247
+
248
+
249
+ # ── SSE: Logs stream ──────────────────────────────────────────────────────────
250
+
251
@router.get("/logs/stream")
async def stream_logs(run_id: str = Query(...)) -> StreamingResponse:
    """SSE stream of log entries for one run (404 if the run id is unknown).

    Subscribes a new queue to ``run.log_subs``, forwards each entry as a
    ``data:`` frame, emits a heartbeat comment after 30 s of silence and ends
    on a ``None`` entry.  The queue is unsubscribed when the generator ends.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    inbox: asyncio.Queue = asyncio.Queue()
    run.log_subs.append(inbox)

    async def log_events():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    record = await asyncio.wait_for(inbox.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield ": heartbeat\n\n"
                else:
                    if record is None:
                        return
                    yield "data: " + json.dumps(record) + "\n\n"
        finally:
            try:
                run.log_subs.remove(inbox)
            except ValueError:
                pass  # already unsubscribed

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(log_events(), media_type="text/event-stream", headers=sse_headers)
282
+
283
+
284
@router.get("/runs/{run_id}/history")
async def get_run_history(run_id: str) -> list[dict]:
    """Return every record stored in the run's ``telemetry.jsonl`` file.

    Responds 404 for an unknown run, an empty list when the telemetry file
    does not exist, and 500 (after logging) when the file cannot be read or
    parsed.
    """
    from training.persistence import TrainingPersistence

    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
    telemetry_path = os.path.join(run_dir, "telemetry.jsonl")

    if not os.path.exists(telemetry_path):
        return []

    try:
        with open(telemetry_path, "r") as handle:
            # One JSON document per non-blank line.
            return [json.loads(row) for row in handle if row.strip()]
    except Exception as exc:
        log.error("history_read_failed", run_id=run_id, error=str(exc))
        raise HTTPException(status_code=500, detail="Failed to read telemetry history")
307
+
308
@router.get("/runs/{run_id}/artifacts")
async def list_run_artifacts(run_id: str) -> dict:
    """List image artifacts found directly inside a run's directory.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` are split into two groups:
    names containing "batch" go to ``batches``, everything else to
    ``artifacts``.  Both lists are sorted by title.  Responds 404 for an
    unknown run and returns empty lists when the directory does not exist.
    """
    from training.persistence import TrainingPersistence

    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)
    if not os.path.exists(run_dir):
        return {"artifacts": [], "batches": []}

    # Standard YOLO artifact mappings for better UI titles
    known_titles = {
        "confusion_matrix.png": "Confusion Matrix",
        "confusion_matrix_normalized.png": "Confusion Matrix (Norm)",
        "results.png": "Results Summary",
        "F1_curve.png": "F1 Curve",
        "PR_curve.png": "PR Curve",
        "P_curve.png": "Precision Curve",
        "R_curve.png": "Recall Curve",
        "BoxF1_curve.png": "Box F1 Curve",
        "BoxP_curve.png": "Box Precision Curve",
        "BoxPR_curve.png": "Box PR Curve",
        "BoxR_curve.png": "Box Recall Curve",
        "labels.jpg": "Labels Distribution",
        "labels_correlogram.jpg": "Labels Correlogram",
    }
    # First matching keyword decides the type of a non-batch artifact.
    keyword_types = (
        ("curve", "Precision-Recall"),
        ("confusion", "Analysis"),
        ("results", "Overall"),
    )
    image_suffixes = ('.png', '.jpg', '.jpeg')

    grouped: dict[str, list] = {"artifacts": [], "batches": []}
    for filename in os.listdir(run_dir):
        if not filename.endswith(image_suffixes):
            continue

        lowered = filename.lower()
        fallback_title = filename.replace('_', ' ').title().split('.')[0]
        entry = {
            "title": known_titles.get(filename, fallback_title),
            "path": f"/train/runs/{run_id}/files/{filename}",
            "type": "Analysis",
        }

        if "batch" in lowered:
            entry["type"] = "Batch Preview" if "val" in lowered else "Augmentation"
            grouped["batches"].append(entry)
            continue

        entry["type"] = next(
            (label for keyword, label in keyword_types if keyword in lowered),
            entry["type"],
        )
        grouped["artifacts"].append(entry)

    def by_title(item: dict) -> str:
        return item['title']

    return {
        "artifacts": sorted(grouped["artifacts"], key=by_title),
        "batches": sorted(grouped["batches"], key=by_title),
    }
366
+
367
@router.get("/runs/{run_id}/files/{filename}")
async def get_run_file(run_id: str, filename: str):
    """Serve one file from a run's directory.

    ``filename`` comes straight from the request URL, so the resolved target
    must (a) stay inside the run directory and (b) be a regular file.
    Anything else is answered with 404, as is an unknown run id.
    """
    from fastapi.responses import FileResponse
    from training.persistence import TrainingPersistence

    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # run_manager does not expose the on-disk location, so it is recomputed
    # through the persistence layer.
    run_dir = await TrainingPersistence.get_run_dir(run.project_id or "default", run_id)

    base_dir = os.path.realpath(run_dir)
    target = os.path.realpath(os.path.join(base_dir, filename))

    try:
        inside_run_dir = os.path.commonpath([base_dir, target]) == base_dir
    except ValueError:
        # commonpath raises when the paths cannot be compared
        # (e.g. different drives on Windows) - treat as outside.
        inside_run_dir = False

    # isfile() rejects both missing paths and directories such as "..".
    if not inside_run_dir or not os.path.isfile(target):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(target)
386
+ # The frontend uses /system/metrics/stream for resources (already implemented).
387
+ # This alias exists for training-scoped resource monitoring.
388
+
389
@router.get("/resources/stream")
async def stream_resources(
    run_id: str = Query(...),
    gpu_index: int = Query(0, ge=0),
    hz: float = Query(1.0, ge=0.2, le=10.0),
) -> StreamingResponse:
    """
    SSE stream of resource ticks for a specific training run.

    A new queue is appended to ``run.resource_subs`` and every item taken from
    it is forwarded as a ``data:`` frame; a heartbeat comment is sent after
    30 s without an item and a ``None`` item ends the stream.

    NOTE: ``gpu_index`` and ``hz`` are accepted for API compatibility but are
    not used here - ticks are forwarded as they arrive on the queue, so the
    emission rate is decided by whatever fills ``run.resource_subs``.

    Raises:
        HTTPException: 404 if the run id is unknown.
    """
    run = run_manager.get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    q: asyncio.Queue = asyncio.Queue()
    run.resource_subs.append(q)

    async def generator():
        yield ": connected\n\n"
        try:
            while True:
                try:
                    tick = await asyncio.wait_for(q.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield ": heartbeat\n\n"
                    continue
                if tick is None:
                    break
                yield f"data: {json.dumps(tick)}\n\n"
        finally:
            # Always unsubscribe so the run does not keep feeding a dead queue.
            if q in run.resource_subs:
                run.resource_subs.remove(q)

    return StreamingResponse(
        generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
config.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ config.py — Centralized application settings.
3
+ All tuneable knobs live here; override via environment variables.
4
+ """
5
+ from pathlib import Path
6
+ from pydantic_settings import BaseSettings, SettingsConfigDict
7
+
8
+
9
class Settings(BaseSettings):
    """Application settings with defaults; values are also read from ``.env``.

    Environment variable names are matched case-insensitively.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # --- App -----------------------------------------------------------
    app_name: str = "MLForge Platform"
    version: str = "1.0.0"
    debug: bool = False

    # --- API -----------------------------------------------------------
    host: str = "0.0.0.0"
    port: int = 8005
    cors_origins: list[str] = [
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:2000",
        "http://127.0.0.1:2000",
    ]

    # --- Storage -------------------------------------------------------
    # The derived defaults below are computed once, when the class body is
    # evaluated, from the default base_dir.
    base_dir: Path = Path(__file__).resolve().parents[1]
    data_dir: Path = base_dir / "data"
    models_dir: Path = data_dir / "models"
    datasets_dir: Path = data_dir / "datasets"  # root for imported datasets
    logs_dir: Path = data_dir / "logs"
    db_path: Path = data_dir / "modelzoo.db"

    # --- Download Manager ----------------------------------------------
    max_concurrent_downloads: int = 5
    download_chunk_size: int = 1024 * 1024  # 1 MB
    download_max_retries: int = 3
    download_retry_delay: float = 2.0  # seconds (base, exponential backoff)

    # --- Search --------------------------------------------------------
    search_max_results: int = 500

    # --- Sync ----------------------------------------------------------
    auto_sync_on_startup: bool = True

    # --- Hugging Face API ----------------------------------------------
    hf_api_base: str = "https://huggingface.co/api"
    hf_hub_url: str = "https://huggingface.co"
    hf_token: str | None = None  # Optional: HF_TOKEN env var
    hf_models_per_task: int = 100  # How many to pull per task

    # --- ONNX Zoo ------------------------------------------------------
    onnx_models_url: str = (
        "https://raw.githubusercontent.com/onnx/models/main/README.md"
    )

    # --- Benchmark Bridge ----------------------------------------------
    benchmark_max_concurrent: int = 3  # max parallel benchmark jobs
    benchmark_max_log_lines: int = 500  # log entries kept per job
    benchmark_ws_poll_hz: float = 2.0  # WebSocket telemetry poll rate

    # --- Dataset Manager -----------------------------------------------
    roboflow_api_base: str = "https://api.roboflow.com"
    dataset_import_workers: int = 3  # max concurrent import jobs
    dataset_chunk_size: int = 1024 * 1024 * 4  # 4 MB download chunk
    roboflow_cache_ttl_secs: int = 3600  # 1 hour

    def ensure_dirs(self) -> None:
        """Create the data, models, datasets (plus ``_tmp``) and logs directories."""
        required_dirs = (
            self.data_dir,
            self.models_dir,
            self.datasets_dir,
            self.datasets_dir / "_tmp",
            self.logs_dir,
        )
        for directory in required_dirs:
            directory.mkdir(parents=True, exist_ok=True)


# Module-level singleton imported by the rest of the application.
settings = Settings()
database/__init__.py ADDED
File without changes
database/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes). View file
 
database/__pycache__/connection.cpython-310.pyc ADDED
Binary file (3.6 kB). View file
 
database/benchmark_schema.sql ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ============================================================
2
+ -- MLForge Benchmark Bridge — SQLite Schema
3
+ -- Version: 1.0.0
4
+ -- ============================================================
5
+
6
+ PRAGMA journal_mode = WAL;
7
+ PRAGMA foreign_keys = ON;
8
+
9
+ -- ── Benchmark Jobs ────────────────────────────────────────────────────────────
10
+ -- Tracks every benchmark run from queued → running → completed/failed.
11
+ -- config stores the full BenchmarkContext JSON for full reproducibility.
12
+ CREATE TABLE IF NOT EXISTS benchmark_jobs (
13
+ id TEXT PRIMARY KEY,
14
+ model_id TEXT NOT NULL,
15
+ dataset_id TEXT NOT NULL,
16
+ task TEXT NOT NULL,
17
+ framework TEXT NOT NULL,
18
+ hardware TEXT NOT NULL DEFAULT 'cpu',
19
+ precision TEXT NOT NULL DEFAULT 'FP32',
20
+ batch_size INTEGER NOT NULL DEFAULT 1,
21
+ config TEXT NOT NULL DEFAULT '{}', -- full BenchmarkContext JSON
22
+ status TEXT NOT NULL DEFAULT 'queued', -- queued|running|completed|failed
23
+ progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0
24
+ logs TEXT NOT NULL DEFAULT '[]', -- JSON array of timestamped log strings
25
+ error TEXT,
26
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
27
+ updated_at TEXT NOT NULL DEFAULT (datetime('now')),
28
+ started_at TEXT,
29
+ ended_at TEXT
30
+ );
31
+
32
+ -- ── Benchmark Results ─────────────────────────────────────────────────────────
33
+ -- Stores final computed metrics + telemetry summary after job completion.
34
+ CREATE TABLE IF NOT EXISTS benchmark_results (
35
+ id TEXT PRIMARY KEY,
36
+ job_id TEXT NOT NULL REFERENCES benchmark_jobs(id) ON DELETE CASCADE,
37
+ metrics TEXT NOT NULL DEFAULT '{}', -- JSON: BenchmarkMetrics
38
+ telemetry_summary TEXT NOT NULL DEFAULT '{}', -- JSON: TelemetrySummary
39
+ created_at TEXT NOT NULL DEFAULT (datetime('now'))
40
+ );
41
+
42
+ -- ── Validation Logs ───────────────────────────────────────────────────────────
43
+ -- Immutable audit trail of every compatibility check performed.
44
+ -- job_id = 'pre-check' for validations that blocked job creation.
45
+ CREATE TABLE IF NOT EXISTS benchmark_validation_logs (
46
+ id TEXT PRIMARY KEY,
47
+ job_id TEXT NOT NULL,
48
+ model_id TEXT NOT NULL,
49
+ dataset_id TEXT NOT NULL,
50
+ checks TEXT NOT NULL DEFAULT '[]', -- JSON: list[ValidationCheck]
51
+ passed INTEGER NOT NULL DEFAULT 1, -- 1=passed, 0=failed
52
+ created_at TEXT NOT NULL DEFAULT (datetime('now'))
53
+ );
54
+
55
+ -- ── Indexes ───────────────────────────────────────────────────────────────────
56
+ CREATE INDEX IF NOT EXISTS idx_bmark_jobs_status ON benchmark_jobs(status);
57
+ CREATE INDEX IF NOT EXISTS idx_bmark_jobs_model ON benchmark_jobs(model_id);
58
+ CREATE INDEX IF NOT EXISTS idx_bmark_jobs_dataset ON benchmark_jobs(dataset_id);
59
+ CREATE INDEX IF NOT EXISTS idx_bmark_jobs_created ON benchmark_jobs(created_at DESC);
60
+ CREATE INDEX IF NOT EXISTS idx_bmark_results_job ON benchmark_results(job_id);
61
+ CREATE INDEX IF NOT EXISTS idx_bmark_valid_job ON benchmark_validation_logs(job_id);
62
+ CREATE INDEX IF NOT EXISTS idx_bmark_valid_model ON benchmark_validation_logs(model_id);
database/connection.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ database/connection.py — Async SQLite connection & migration bootstrap.
3
+ Single module responsible for DB lifecycle. All queries share this one connection.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import asyncio
8
+ from pathlib import Path
9
+
10
+ import aiosqlite
11
+
12
+ from config import settings
13
+
14
+ # Module-level connection (shared within the process)
15
+ _db: aiosqlite.Connection | None = None
16
+ _lock = asyncio.Lock()
17
+
18
+
19
async def get_db() -> aiosqlite.Connection:
    """Return the process-wide database connection, opening it on first use.

    The module lock serialises callers so that only one of several
    concurrent first calls actually opens (and migrates) the database.
    """
    global _db
    async with _lock:
        connection = _db
        if connection is None:
            connection = await _open_connection()
            _db = connection
        return connection
26
+
27
+
28
async def _open_connection() -> aiosqlite.Connection:
    """Open, configure and migrate the SQLite database.

    Returns:
        A connection with ``aiosqlite.Row`` rows, the PRAGMAs below applied
        and all schema migrations committed.

    Raises:
        Whatever the PRAGMAs or migrations raise. The half-initialised
        connection is closed first, so a failed bootstrap does not leave an
        open connection behind.
    """
    settings.ensure_dirs()
    conn = await aiosqlite.connect(settings.db_path, check_same_thread=False)
    try:
        conn.row_factory = aiosqlite.Row
        for pragma in (
            "PRAGMA journal_mode=WAL",
            "PRAGMA foreign_keys=ON",
            "PRAGMA synchronous=NORMAL",
            "PRAGMA cache_size=-65536",  # 64 MB page cache
        ):
            await conn.execute(pragma)
        await _run_migrations(conn)
        await conn.commit()
    except BaseException:
        # BaseException so that a cancelled startup also releases the
        # connection; the original error is always re-raised.
        await conn.close()
        raise
    return conn
39
+
40
+
41
# Columns introduced after the first release, as (name, column DDL) pairs.
# They are added, in this order, to databases whose tables predate them.
_MODEL_COLUMN_MIGRATIONS = (
    ("download_url", "TEXT"),
    ("active_version", "TEXT"),
    ("metrics", "TEXT NOT NULL DEFAULT '{}'"),
    ("project_id", "TEXT REFERENCES projects(id) ON DELETE CASCADE"),
)
_DATASET_COLUMN_MIGRATIONS = (
    ("active_version", "TEXT NOT NULL DEFAULT 'v1'"),
    ("roboflow_id", "TEXT"),
    ("health_score", "INTEGER NOT NULL DEFAULT 0"),
)
# Temporary tables that failed legacy migrations may have left behind.
_LEGACY_TEMP_TABLES = ("datasets_old", "dataset_jobs_old")


async def _table_columns(conn: aiosqlite.Connection, table: str) -> set[str]:
    """Return the column names of *table* (empty set if it does not exist)."""
    async with conn.execute(f"PRAGMA table_info({table})") as cur:
        return {row[1] for row in await cur.fetchall()}


async def _add_missing_columns(
    conn: aiosqlite.Connection,
    table: str,
    migrations: tuple[tuple[str, str], ...],
) -> None:
    """Add each migration column that *table* lacks; skip absent tables."""
    existing = await _table_columns(conn, table)
    if not existing:  # table does not exist, nothing to alter
        return
    for name, ddl in migrations:
        if name not in existing:
            await conn.execute(f"ALTER TABLE {table} ADD COLUMN {name} {ddl}")


async def _run_migrations(conn: aiosqlite.Connection) -> None:
    """Apply all schema files idempotently, then patch legacy databases.

    Steps:
        1. Execute every schema file that exists next to this module
           (they only use CREATE ... IF NOT EXISTS).
        2. Add columns that older ``models`` / ``datasets`` tables lack.
        3. Drop temporary tables left over from failed legacy migrations.
        4. Commit, so other users of the database see the final state.

    Args:
        conn: Open connection to migrate.

    Raises:
        aiosqlite.Error: If a schema script or an ALTER statement fails.
            Failures while dropping the legacy tables are logged and
            ignored, because that cleanup is best-effort.
    """
    import logging

    base = Path(__file__).parent

    # Step 1: make sure the base tables exist.
    for schema_file in ("schema.sql", "dataset_schema.sql", "benchmark_schema.sql"):
        path = base / schema_file
        if path.exists():
            await conn.executescript(path.read_text(encoding="utf-8"))

    # Step 2: legacy alterations.
    await _add_missing_columns(conn, "models", _MODEL_COLUMN_MIGRATIONS)
    await _add_missing_columns(conn, "datasets", _DATASET_COLUMN_MIGRATIONS)

    # Step 3: best-effort cleanup. DROP ... IF EXISTS already tolerates a
    # missing table, so only genuine database errors reach the handler.
    for table in _LEGACY_TEMP_TABLES:
        try:
            await conn.execute(f"DROP TABLE IF EXISTS {table}")
        except aiosqlite.Error as exc:
            logging.getLogger(__name__).warning(
                "could not drop legacy table %s: %s", table, exc
            )

    await conn.commit()
100
+
101
async def close_db() -> None:
    """Close the shared connection, if one is open, and forget it.

    Runs under the module lock; calling it when no connection is open
    does nothing.
    """
    global _db
    async with _lock:
        if _db is None:
            return
        await _db.close()
        _db = None
database/dataset_schema.sql ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ============================================================
2
+ -- MLForge Dataset Manager — SQLite Schema Extension
3
+ -- Appended to existing modelzoo.db (CREATE IF NOT EXISTS)
4
+ -- ============================================================
5
+
6
+ -- ── Datasets ──────────────────────────────────────────────
7
+ CREATE TABLE IF NOT EXISTS datasets (
8
+ id TEXT PRIMARY KEY,
9
+ name TEXT NOT NULL,
10
+ description TEXT NOT NULL DEFAULT '',
11
+ task TEXT NOT NULL,
12
+ format TEXT NOT NULL,
13
+ source TEXT NOT NULL DEFAULT 'roboflow',
14
+ status TEXT NOT NULL DEFAULT 'available',
15
+ images INTEGER NOT NULL DEFAULT 0,
16
+ classes INTEGER NOT NULL DEFAULT 0,
17
+ class_names TEXT NOT NULL DEFAULT '[]', -- JSON array
18
+ size_bytes INTEGER NOT NULL DEFAULT 0,
19
+ size_label TEXT NOT NULL DEFAULT '0 B',
20
+ local_path TEXT,
21
+ import_progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0
22
+ tags TEXT NOT NULL DEFAULT '[]', -- JSON array
23
+ versions TEXT NOT NULL DEFAULT '[]', -- JSON array
24
+ active_version TEXT NOT NULL DEFAULT 'v1',
25
+ starred INTEGER NOT NULL DEFAULT 0,
26
+ roboflow_id TEXT, -- workspace/project slug
27
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
28
+ updated_at TEXT NOT NULL DEFAULT (datetime('now'))
29
+ );
30
+
31
+ -- ── Dataset Jobs ──────────────────────────────────────────
32
+ CREATE TABLE IF NOT EXISTS dataset_jobs (
33
+ id TEXT PRIMARY KEY,
34
+ type TEXT NOT NULL, -- import|extract|validate|analyze
35
+ status TEXT NOT NULL DEFAULT 'queued', -- queued|running|completed|failed|cancelled
36
+ dataset_id TEXT NOT NULL REFERENCES datasets(id) ON DELETE CASCADE,
37
+ dataset_name TEXT NOT NULL DEFAULT '',
38
+ progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0
39
+ message TEXT NOT NULL DEFAULT '',
40
+ error TEXT,
41
+ meta TEXT NOT NULL DEFAULT '{}', -- JSON extra data
42
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
43
+ updated_at TEXT NOT NULL DEFAULT (datetime('now')),
44
+ started_at TEXT,
45
+ ended_at TEXT
46
+ );
47
+
48
+ -- ── Dataset Images Index ──────────────────────────────────
49
+ -- Populated after extraction; enables fast paginated viewer queries
50
+ CREATE TABLE IF NOT EXISTS dataset_images (
51
+ id TEXT PRIMARY KEY, -- sha1 or sequential id
52
+ dataset_id TEXT NOT NULL REFERENCES datasets(id) ON DELETE CASCADE,
53
+ filename TEXT NOT NULL,
54
+ rel_path TEXT NOT NULL, -- relative to dataset local_path
55
+ width INTEGER NOT NULL DEFAULT 0,
56
+ height INTEGER NOT NULL DEFAULT 0,
57
+ split TEXT NOT NULL DEFAULT 'train',
58
+ ann_count INTEGER NOT NULL DEFAULT 0, -- fast count without parsing
59
+ created_at TEXT NOT NULL DEFAULT (datetime('now'))
60
+ );
61
+
62
+ -- ── Dataset Annotations Cache ─────────────────────────────
63
+ -- Parsed annotations stored in normalised form for fast retrieval
64
+ CREATE TABLE IF NOT EXISTS dataset_annotations (
65
+ id TEXT PRIMARY KEY,
66
+ image_id TEXT NOT NULL REFERENCES dataset_images(id) ON DELETE CASCADE,
67
+ dataset_id TEXT NOT NULL,
68
+ label TEXT NOT NULL,
69
+ bbox_x REAL,
70
+ bbox_y REAL,
71
+ bbox_w REAL,
72
+ bbox_h REAL,
73
+ normalised INTEGER DEFAULT 1,
74
+ area REAL,
75
+ confidence REAL,
76
+ ann_type TEXT DEFAULT 'detection',
77
+ segmentation TEXT, -- JSON array of points [[x,y],...]
78
+ keypoints TEXT, -- JSON array of keypoints [x,y,v,...]
79
+ metadata TEXT -- Extra JSON metadata
80
+ );
81
+
82
+ -- ── Roboflow Metadata Cache ───────────────────────────────
83
+ -- Avoids redundant API calls; TTL enforced in Python layer
84
+ CREATE TABLE IF NOT EXISTS roboflow_cache (
85
+ cache_key TEXT PRIMARY KEY, -- workspace/project or search query hash
86
+ payload TEXT NOT NULL, -- JSON blob
87
+ fetched_at TEXT NOT NULL DEFAULT (datetime('now')),
88
+ ttl_secs INTEGER NOT NULL DEFAULT 3600 -- 1 hour default
89
+ );
90
+
91
+ -- ── Indexes ───────────────────────────────────────────────
92
+ CREATE INDEX IF NOT EXISTS idx_datasets_task ON datasets(task);
93
+ CREATE INDEX IF NOT EXISTS idx_datasets_format ON datasets(format);
94
+ CREATE INDEX IF NOT EXISTS idx_datasets_source ON datasets(source);
95
+ CREATE INDEX IF NOT EXISTS idx_datasets_status ON datasets(status);
96
+ CREATE INDEX IF NOT EXISTS idx_datasets_starred ON datasets(starred);
97
+
98
+ CREATE INDEX IF NOT EXISTS idx_djobs_status ON dataset_jobs(status);
99
+ CREATE INDEX IF NOT EXISTS idx_djobs_dataset ON dataset_jobs(dataset_id);
100
+
101
+ CREATE INDEX IF NOT EXISTS idx_dimages_dataset ON dataset_images(dataset_id);
102
+ CREATE INDEX IF NOT EXISTS idx_dimages_split ON dataset_images(dataset_id, split);
103
+
104
+ CREATE INDEX IF NOT EXISTS idx_dann_image ON dataset_annotations(image_id);
105
+ CREATE INDEX IF NOT EXISTS idx_dann_dataset ON dataset_annotations(dataset_id);
106
+ CREATE INDEX IF NOT EXISTS idx_dann_label ON dataset_annotations(dataset_id, label);
107
+
108
+ -- ── Updated-at trigger for datasets ──────────────────────
109
+ CREATE TRIGGER IF NOT EXISTS datasets_updated_at
110
+ AFTER UPDATE ON datasets BEGIN
111
+ UPDATE datasets SET updated_at = datetime('now') WHERE id = NEW.id;
112
+ END;
113
+
114
+ CREATE TRIGGER IF NOT EXISTS dataset_jobs_updated_at
115
+ AFTER UPDATE ON dataset_jobs BEGIN
116
+ UPDATE dataset_jobs SET updated_at = datetime('now') WHERE id = NEW.id;
117
+ END;
database/schema.sql ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ============================================================
2
+ -- MLForge Model Zoo — SQLite Schema
3
+ -- Version: 1.0.0
4
+ -- ============================================================
5
+
6
+ PRAGMA journal_mode = WAL;
7
+ PRAGMA foreign_keys = ON;
8
+
9
+ -- ── Models ────────────────────────────────────────────────
10
+ CREATE TABLE IF NOT EXISTS models (
11
+ id TEXT PRIMARY KEY,
12
+ name TEXT NOT NULL,
13
+ variant TEXT,
14
+ task TEXT NOT NULL,
15
+ framework TEXT NOT NULL,
16
+ source TEXT NOT NULL DEFAULT 'hf',
17
+ provider TEXT NOT NULL DEFAULT '',
18
+ description TEXT NOT NULL DEFAULT '',
19
+ download_url TEXT, -- explicit download source URL
20
+ size INTEGER NOT NULL DEFAULT 0,
21
+ size_label TEXT NOT NULL DEFAULT '0B',
22
+ tags TEXT NOT NULL DEFAULT '[]', -- JSON array
23
+ hardware TEXT NOT NULL DEFAULT '[]', -- JSON array
24
+ status TEXT NOT NULL DEFAULT 'available',
25
+ downloaded INTEGER NOT NULL DEFAULT 0,
26
+ active_version TEXT,
27
+ local_path TEXT,
28
+ metrics TEXT NOT NULL DEFAULT '{}', -- JSON: latency, mAP, etc.
29
+ downloads INTEGER DEFAULT 0,
30
+ rating REAL,
31
+ liked INTEGER NOT NULL DEFAULT 0,
32
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
33
+ updated_at TEXT NOT NULL DEFAULT (datetime('now'))
34
+ );
35
+
36
+ -- ── Model Versions ────────────────────────────────────────
37
+ CREATE TABLE IF NOT EXISTS model_versions (
38
+ version_id TEXT PRIMARY KEY,
39
+ model_id TEXT NOT NULL REFERENCES models(id) ON DELETE CASCADE,
40
+ version TEXT NOT NULL,
41
+ label TEXT NOT NULL DEFAULT 'Stable', -- Latest|Stable|Legacy
42
+ description TEXT,
43
+ metrics TEXT NOT NULL DEFAULT '{}', -- JSON: latency, mAP, etc.
44
+ local_path TEXT,
45
+ downloaded INTEGER NOT NULL DEFAULT 0,
46
+ release_date TEXT,
47
+ changelog TEXT,
48
+ created_at TEXT NOT NULL DEFAULT (datetime('now'))
49
+ );
50
+
51
+ -- ── Jobs ─────────────────────────────────────────────────-
52
+ CREATE TABLE IF NOT EXISTS jobs (
53
+ id TEXT PRIMARY KEY,
54
+ type TEXT NOT NULL, -- download|benchmark|sync
55
+ status TEXT NOT NULL DEFAULT 'queued', -- queued|running|completed|failed|cancelled
56
+ model_id TEXT REFERENCES models(id),
57
+ model_name TEXT,
58
+ progress REAL NOT NULL DEFAULT 0.0, -- 0.0–1.0
59
+ error TEXT,
60
+ meta TEXT NOT NULL DEFAULT '{}', -- JSON extra data
61
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
62
+ updated_at TEXT NOT NULL DEFAULT (datetime('now')),
63
+ started_at TEXT,
64
+ ended_at TEXT
65
+ );
66
+
67
+ -- ── Projects ─────────────────────────────────────────────
68
+ CREATE TABLE IF NOT EXISTS projects (
69
+ id TEXT PRIMARY KEY,
70
+ name TEXT NOT NULL,
71
+ path TEXT NOT NULL,
72
+ created_at TEXT NOT NULL,
73
+ last_opened TEXT NOT NULL,
74
+ status TEXT NOT NULL DEFAULT 'idle'
75
+ );
76
+
77
+ -- ── Session ───────────────────────────────────────────────
78
+ -- Stores the currently active project so backend services
79
+ -- (e.g. download manager) can link assets into the workspace.
80
+ CREATE TABLE IF NOT EXISTS session (
81
+ key TEXT PRIMARY KEY,
82
+ value TEXT NOT NULL
83
+ );
84
+
85
+ CREATE UNIQUE INDEX IF NOT EXISTS idx_projects_path ON projects(path);
86
+
87
+ -- ── Audit Log ─────────────────────────────────────────────
88
+ CREATE TABLE IF NOT EXISTS audit_log (
89
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
90
+ event_type TEXT NOT NULL, -- api_request|download_start|download_ok|error|sync
91
+ model_id TEXT,
92
+ job_id TEXT,
93
+ payload TEXT NOT NULL DEFAULT '{}', -- JSON
94
+ level TEXT NOT NULL DEFAULT 'info',
95
+ created_at TEXT NOT NULL DEFAULT (datetime('now'))
96
+ );
97
+
98
+ -- ── FTS (Full-Text Search) ────────────────────────────────
99
+ CREATE VIRTUAL TABLE IF NOT EXISTS models_fts USING fts5(
100
+ id UNINDEXED,
101
+ name,
102
+ description,
103
+ tags,
104
+ provider,
105
+ task,
106
+ framework,
107
+ content='models',
108
+ content_rowid='rowid'
109
+ );
110
+
111
+ -- Triggers to keep FTS in sync
112
+ CREATE TRIGGER IF NOT EXISTS models_fts_insert AFTER INSERT ON models BEGIN
113
+ INSERT INTO models_fts(rowid, id, name, description, tags, provider, task, framework)
114
+ VALUES (new.rowid, new.id, new.name, new.description, new.tags, new.provider, new.task, new.framework);
115
+ END;
116
+
117
+ CREATE TRIGGER IF NOT EXISTS models_fts_delete BEFORE DELETE ON models BEGIN
118
+ DELETE FROM models_fts WHERE rowid = old.rowid;
119
+ END;
120
+
121
-- models_fts is an external-content FTS5 table (content='models'). A plain
-- DELETE on it looks the row up in `models` to learn which tokens to remove;
-- inside an AFTER UPDATE trigger that row already holds the NEW values, so
-- the old tokens would stay in the index. The FTS5 'delete' command takes the
-- OLD values explicitly and removes exactly those index entries.
-- The trigger is dropped first so databases created with the previous
-- definition pick up the corrected one.
DROP TRIGGER IF EXISTS models_fts_update;

CREATE TRIGGER IF NOT EXISTS models_fts_update AFTER UPDATE ON models BEGIN
    INSERT INTO models_fts(models_fts, rowid, id, name, description, tags, provider, task, framework)
    VALUES ('delete', old.rowid, old.id, old.name, old.description, old.tags, old.provider, old.task, old.framework);
    INSERT INTO models_fts(rowid, id, name, description, tags, provider, task, framework)
    VALUES (new.rowid, new.id, new.name, new.description, new.tags, new.provider, new.task, new.framework);
END;
126
+
127
+ -- ── Inference History ────────────────────────────────────
128
+ CREATE TABLE IF NOT EXISTS inference_history (
129
+ id TEXT PRIMARY KEY,
130
+ model_id TEXT NOT NULL REFERENCES models(id) ON DELETE CASCADE,
131
+ model_name TEXT NOT NULL,
132
+ adapter_type TEXT NOT NULL,
133
+ timestamp REAL NOT NULL DEFAULT (unixepoch('now')),
134
+ total_ms REAL NOT NULL DEFAULT 0.0,
135
+ quality_score REAL,
136
+ status TEXT NOT NULL DEFAULT 'ok',
137
+ request_snapshot TEXT NOT NULL DEFAULT '{}' -- JSON
138
+ );
139
+
140
+ CREATE INDEX IF NOT EXISTS idx_inference_model ON inference_history(model_id);
141
+ CREATE INDEX IF NOT EXISTS idx_inference_time ON inference_history(timestamp DESC);
142
+
143
+ -- ── Indexes ───────────────────────────────────────────────
144
+ CREATE INDEX IF NOT EXISTS idx_models_task ON models(task);
145
+ CREATE INDEX IF NOT EXISTS idx_models_framework ON models(framework);
146
+ CREATE INDEX IF NOT EXISTS idx_models_source ON models(source);
147
+ CREATE INDEX IF NOT EXISTS idx_models_status ON models(status);
148
+ CREATE INDEX IF NOT EXISTS idx_models_downloads ON models(downloads DESC);
149
+ CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
150
+ CREATE INDEX IF NOT EXISTS idx_jobs_model ON jobs(model_id);
151
+ CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log(event_type);
152
+ CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log(created_at DESC);
main.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ main.py — FastAPI application entry point.
3
+ Wires together all modules, registers middleware/routes, manages lifespan.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import sys
9
+
10
+ # Put this file's directory on sys.path so the top-level packages imported below
11
+ # (api, config, database, middleware, ...) resolve when running from the 'backend' directory.
12
+ backend_root = os.path.dirname(os.path.abspath(__file__))
13
+ if backend_root not in sys.path:
14
+ sys.path.insert(0, backend_root)
15
+
16
+ import asyncio
17
+ from contextlib import asynccontextmanager
18
+ from typing import AsyncIterator
19
+
20
+ import traceback
21
+
22
+ from fastapi import FastAPI, Request
23
+ from fastapi.middleware.cors import CORSMiddleware
24
+ from fastapi.responses import JSONResponse
25
+
26
+ from api.routes import jobs as jobs_router
27
+ from api.routes import models as models_router
28
+ from api.routes import sync as sync_router
29
+ from api.routes import datasets as datasets_router
30
+ from api.routes import benchmark as benchmark_router
31
+ from api.routes import system as system_router
32
+ from api.routes import projects as projects_router
33
+ from api.routes import inference as inference_router
34
+ from api.routes import training as training_router
35
+ from config import settings
36
+ from database.connection import close_db, get_db
37
+ from middleware.logging_middleware import RequestLoggingMiddleware
38
+ from observability.logger import configure_logging, get_logger
39
+
40
+ # ── Logging bootstrap (must be first) ─────────────────────────────────────────
41
+ configure_logging()
42
+ log = get_logger("main")
43
+
44
+
45
+ # ── Lifespan ──────────────────────────────────────────────────────────────────
46
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run startup work before the app serves requests and cleanup after.

    Startup: create directories, open/migrate the database, try to recover
    stale jobs (failures are logged, not fatal) and, when auto-sync is
    enabled and the registry is empty, start a full sync in the background.

    Shutdown: cancel a still-running startup sync, then close the database.
    The cleanup sits in ``finally`` so it also runs when the lifespan is
    torn down by an exception.
    """
    # Startup
    settings.ensure_dirs()
    log.info("startup", host=settings.host, port=settings.port, version=settings.version)
    await get_db()  # Bootstrap DB / run migrations
    log.info("database_ready", path=str(settings.db_path))

    # Job Recovery (Cleanup stale imports/benchmarks)
    try:
        from datasets.import_service import recover_stale_jobs
        await recover_stale_jobs()
    except Exception as e:
        log.error("job_recovery_failed", error=str(e))

    # The event loop only keeps weak references to tasks, so hold the sync
    # task here for the app's lifetime to stop it being garbage-collected
    # mid-run.
    sync_task: asyncio.Task | None = None
    if settings.auto_sync_on_startup:
        from registry.registry import count_models

        if await count_models() == 0:
            from api.routes.sync import _run_full_sync

            log.info("auto_sync_startup_triggered")
            sync_task = asyncio.create_task(_run_full_sync())

    try:
        yield  # app runs
    finally:
        # Shutdown
        if sync_task is not None and not sync_task.done():
            # Stop the sync before its database connection is closed.
            # return_exceptions=True turns the task's CancelledError (or
            # any failure) into a result instead of raising it here.
            sync_task.cancel()
            await asyncio.gather(sync_task, return_exceptions=True)
        await close_db()
        log.info("shutdown")
76
+
77
+
78
+ # ── Application ───────────────────────────────────────────────────────────────
79
# The ASGI application object. Title and version come from settings, the
# interactive docs are served at /docs and /redoc, and startup/shutdown
# work is delegated to the lifespan context manager.
app = FastAPI(
    title=settings.app_name,
    version=settings.version,
    description="Production ML Model Zoo backend β€” local-first, traceable, extensible.",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)
87
+
88
+
89
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Turn any otherwise unhandled exception into a JSON 500 response.

    The request path, the error text and the formatted traceback are
    logged; the response body carries a generic detail plus the error text.
    """
    reason = str(exc)
    # Log full traceback for debugging 500s.
    log.error(
        "unhandled_exception",
        path=request.url.path,
        error=reason,
        traceback=traceback.format_exc(),
    )
    body = {"detail": "Internal Server Error", "error": reason}
    return JSONResponse(status_code=500, content=body)
102
+
103
+ # ── Middleware ─────────────────────────────────────────────────────────────────
104
# CORS: allow the configured origins plus any http(s) origin on localhost or
# 127.0.0.1, with or without a port. In a raw string each regex escape needs a
# single backslash; with doubled backslashes the pattern demands literal
# backslash characters and so matches neither 127.0.0.1 nor any origin that
# carries a port.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Structured request/response logging with a per-request trace id.
app.add_middleware(RequestLoggingMiddleware)
113
+
114
+ # ── Routes ────────────────────────────────────────────────────────────────────
115
# Register every route module's router, in this order.
for _route_module in (
    models_router,
    jobs_router,
    sync_router,
    datasets_router,
    benchmark_router,
    system_router,
    projects_router,
    inference_router,
    training_router,
):
    app.include_router(_route_module.router)
124
+
125
+
126
@app.get("/health", tags=["system"])
async def health() -> dict:
    """Report liveness together with the current model and dataset counts."""
    # Imported here rather than at module top, as in the rest of this file's
    # registry usage.
    from registry.registry import count_models
    from datasets.registry import count_datasets

    model_total = await count_models()
    dataset_total = await count_datasets()
    report = {"status": "ok", "version": settings.version}
    report["model_count"] = model_total
    report["dataset_count"] = dataset_total
    return report
138
+
139
+
140
+ # ── Dev runner ────────────────────────────────────────────────────────────────
141
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn when this file is
    # run directly. Host, port and auto-reload all come from settings.
    import uvicorn
    uvicorn.run(
        "main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
        log_config=None,  # We use structlog
    )
middleware/__init__.py ADDED
File without changes
middleware/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes). View file
 
middleware/__pycache__/logging_middleware.cpython-310.pyc ADDED
Binary file (1.8 kB). View file
 
middleware/logging_middleware.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ middleware/logging_middleware.py — Structured request/response logging.
3
+ Attaches a trace_id to every request, logs timing, method, path, status.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import time
8
+ import uuid
9
+ from typing import Callable
10
+
11
+ from fastapi import Request, Response
12
+ from starlette.middleware.base import BaseHTTPMiddleware
13
+
14
+ from observability.logger import audit, get_logger, log_system_event
15
+
16
+ log = get_logger("http")
17
+
18
+
19
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log every request and response and tag both with a short trace id."""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Handle one request.

        Stores an 8-character trace id on ``request.state``, emits a system
        event before and after the downstream call, adds ``X-Trace-Id`` and
        ``X-Response-Time`` headers to the response, and writes an audit
        entry for requests slower than 200 ms.
        """
        trace_id = str(uuid.uuid4())[:8]
        request.state.trace_id = trace_id
        started = time.perf_counter()

        log_system_event(
            level="info",
            message=f"API Request: {request.method} {request.url.path}",
            source="gateway",
            payload={"trace_id": trace_id, "query": str(request.url.query)},
        )

        response = await call_next(request)
        duration_ms = (time.perf_counter() - started) * 1000

        # 4xx and 5xx responses are logged at error level.
        outcome_level = "error" if response.status_code >= 400 else "info"
        log_system_event(
            level=outcome_level,
            message=f"API Response: {response.status_code} ({duration_ms:.1f}ms)",
            source="gateway",
            payload={
                "trace_id": trace_id,
                "status": response.status_code,
                "latency_ms": round(duration_ms, 2),
            },
        )

        response.headers["X-Trace-Id"] = trace_id
        response.headers["X-Response-Time"] = f"{duration_ms:.1f}ms"

        if duration_ms <= 200:
            return response

        # Audit slow requests
        slow_details = {
            "path": request.url.path,
            "duration_ms": round(duration_ms, 2),
            "trace_id": trace_id,
        }
        await audit("slow_request", payload=slow_details, level="warning")
        return response
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (141 Bytes). View file
 
models/__pycache__/benchmark.cpython-310.pyc ADDED
Binary file (7.09 kB). View file
 
models/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (12 kB). View file
 
models/__pycache__/inference.cpython-310.pyc ADDED
Binary file (5.55 kB). View file
 
models/__pycache__/job.cpython-310.pyc ADDED
Binary file (1.72 kB). View file
 
models/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.06 kB). View file
 
models/__pycache__/project.cpython-310.pyc ADDED
Binary file (619 Bytes). View file
 
models/__pycache__/system.cpython-310.pyc ADDED
Binary file (1.86 kB). View file
 
models/benchmark.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models/benchmark.py — Pydantic domain models for the Benchmark Bridge System.
3
+ Single source of truth for all benchmark-related data shapes across API,
4
+ execution engine, and database layer.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ from typing import Any
10
+
11
+ from pydantic import BaseModel, Field, ConfigDict
12
+
13
+
14
+ # ── Input ─────────────────────────────────────────────────────────────────────
15
+
16
class BenchmarkContext(BaseModel):
    """Payload the UI sends to initiate a benchmark run."""
    # Empty protected_namespaces lets the model declare fields that start
    # with "model_" (model_id) without pydantic's protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())
    model_id: str
    dataset_id: str
    task: str
    framework: str
    hardware: str = "cpu"
    precision: str = "FP32"
    # Validated to the range 1..512.
    batch_size: int = Field(1, ge=1, le=512)
    # Task-specific overrides
    max_tokens: int | None = 512
    sequence_length: int | None = 512
    img_size: int | None = 640
    vid_stride: int | None = 1
    stream: bool | None = False
    input_source: str | None = "dataset"
    video_path: str | None = None
    rtsp_url: str | None = None
    # Object Detection live preview data
    detections: list[dict[str, Any]] = Field(default_factory=list)
37
+
38
+
39
+ # ── Validation ────────────────────────────────────────────────────────────────
40
+
41
class ValidationCheck(BaseModel):
    """Result of a single compatibility gate."""
    name: str
    passed: bool
    detail: str
    # Optional; defaults to None when no suggestion is given.
    suggestion: str | None = None
47
+
48
+
49
class ValidationReport(BaseModel):
    """Aggregated result of all compatibility checks for a model+dataset pair."""
    # Empty protected_namespaces permits the "model_id" field name without
    # pydantic's protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())
    model_id: str
    dataset_id: str
    passed: bool                    # True only if ALL checks pass
    checks: list[ValidationCheck]
    errors: list[str]               # details from failed checks
    warnings: list[str] = Field(default_factory=list)
58
+
59
+
60
+ # ── Metrics ───────────────────────────────────────────────────────────────────
61
+
62
class BenchmarkMetrics(BaseModel):
    """Task-specific + hardware performance metrics from a completed run.

    Every field is optional because a given task only fills the metrics
    that apply to it. Unknown keys are accepted and kept (extra="allow").
    """
    # Declared with ConfigDict like the other models in this module; the
    # nested `class Config` form is the deprecated pydantic v1 spelling.
    model_config = ConfigDict(extra="allow")

    # Detection / Segmentation
    mAP: float | None = None
    mAP_50: float | None = None
    mAP_50_95: float | None = None
    # Classification
    accuracy: float | None = None
    top1: float | None = None
    top5: float | None = None
    # Segmentation
    iou_mean: float | None = None
    # NLP / Generation
    rouge_l: float | None = None
    bleu: float | None = None
    perplexity: float | None = None
    tokens_per_sec: float | None = None
    # Throughput & Latency
    fps: float | None = None
    latency_mean_ms: float | None = None
    latency_p95_ms: float | None = None
    latency_p99_ms: float | None = None
    # Memory
    vram_peak_gb: float | None = None
    vram_avg_gb: float | None = None
    # Dataset info
    total_images: int | None = None
    total_tokens: int | None = None
    batch_size: int | None = None
94
+
95
+
96
+ # ── Telemetry ─────────────────────────────────────────────────────────────────
97
+
98
class TelemetrySample(BaseModel):
    """Single hardware reading captured during benchmark execution."""
    timestamp: float                # Unix epoch seconds
    gpu_util_pct: float = 0.0       # 0–100
    vram_used_gb: float = 0.0
    vram_total_gb: float = 0.0
    temp_c: float = 0.0
    power_w: float = 0.0
    batch_idx: int = 0
    progress: float = 0.0           # 0.0–1.0
    # Optional task-specific live data (e.g. BBoxes for detection)
    live_data: dict[str, Any] = Field(default_factory=dict)
    detections: list[dict[str, Any]] = Field(default_factory=list)
111
+
112
+
113
class LayerBreakdown(BaseModel):
    """Single layer entry in a bottleneck analysis."""
    name: str
    time_ms: float
    # NOTE(review): presumably this layer's share of total time; confirm
    # whether the scale is 0-100 or 0-1 against the producer.
    percent: float
118
+
119
+
120
class TelemetrySummary(BaseModel):
    """Aggregated telemetry statistics over the full benchmark run."""
    # Every statistic defaults to 0.0, so an empty summary is valid.
    gpu_util_avg: float = 0.0
    gpu_util_peak: float = 0.0
    vram_avg_gb: float = 0.0
    vram_peak_gb: float = 0.0
    temp_avg_c: float = 0.0
    temp_peak_c: float = 0.0
    power_avg_w: float = 0.0
    power_peak_w: float = 0.0
    # Per-layer timing entries; empty when no breakdown was recorded.
    layer_breakdown: list[LayerBreakdown] = Field(default_factory=list)
131
+
132
+
133
+ # ── Job & Result ───��──────────────────────────────────────────────────────────
134
+
135
class BenchmarkJob(BaseModel):
    """A benchmark job as stored in the ``benchmark_jobs`` table.

    Built from database rows by ``row_to_job``.
    """
    # Empty protected_namespaces permits the "model_id" field name without
    # pydantic's protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())
    id: str
    model_id: str
    dataset_id: str
    task: str
    framework: str
    hardware: str
    precision: str
    batch_size: int
    config: dict = Field(default_factory=dict)
    status: str = "queued"          # queued|running|completed|failed
    progress: float = 0.0
    logs: list[str] = Field(default_factory=list)
    # Mirrors the table's `error` column. row_to_job already passes this
    # value; without a declared field pydantic silently discarded it.
    error: str | None = None
    created_at: str | None = None
    updated_at: str | None = None
    started_at: str | None = None
    ended_at: str | None = None
    last_telemetry: TelemetrySample | None = None
154
+
155
+
156
class BenchmarkResult(BaseModel):
    """Final metrics and telemetry summary recorded for a benchmark job."""
    # Empty protected_namespaces permits the "model_id" field name without
    # pydantic's protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())
    id: str
    job_id: str
    metrics: BenchmarkMetrics
    telemetry_summary: TelemetrySummary
    created_at: str | None = None
    # Denormalized from Job for UI efficiency
    model_id: str | None = None
    dataset_id: str | None = None
    task: str | None = None
    framework: str | None = None
    hardware: str | None = None
    precision: str | None = None
170
+
171
+
172
+ # ── API Responses ─────────────────────────────────────────────────────────────
173
+
174
class BenchmarkRunResponse(BaseModel):
    """API response identifying a benchmark job with its status and a message."""
    job_id: str
    status: str
    message: str
178
+
179
+
180
+ # ── DB Row helpers ────────────────────────────────────────────────────────────
181
+
182
def row_to_job(row: Any) -> BenchmarkJob:
    """Convert a ``benchmark_jobs`` row into a ``BenchmarkJob``.

    Args:
        row: Any mapping-like row accepted by ``dict()``.

    Returns:
        The job, with the JSON text columns (config, logs, last_telemetry)
        decoded. Empty or missing JSON columns fall back to ``{}``, ``[]``
        and ``None`` respectively.

    Raises:
        KeyError: If a required column is absent from the row.
    """
    record = dict(row)
    fields: dict[str, Any] = {}

    config = json.loads(record.get("config") or "{}")
    for key in (
        "id", "model_id", "dataset_id", "task",
        "framework", "hardware", "precision", "batch_size",
    ):
        fields[key] = record[key]
    fields["config"] = config
    fields["status"] = record["status"]
    fields["progress"] = float(record.get("progress", 0.0))
    fields["logs"] = json.loads(record.get("logs") or "[]")
    for key in ("error", "created_at", "updated_at", "started_at", "ended_at"):
        fields[key] = record.get(key)

    # Only build a sample when the row actually carries telemetry JSON.
    telemetry_json = record.get("last_telemetry")
    if telemetry_json:
        fields["last_telemetry"] = TelemetrySample(**json.loads(telemetry_json))
    else:
        fields["last_telemetry"] = None

    return BenchmarkJob(**fields)
205
+
206
+
207
def row_to_result(row: Any) -> BenchmarkResult:
    """Convert a result row into a ``BenchmarkResult``.

    Args:
        row: Any mapping-like row accepted by ``dict()``. The job columns
            (model_id, dataset_id, task, framework, hardware, precision)
            are optional and become ``None`` when absent.

    Returns:
        The result, with the metrics and telemetry_summary JSON decoded
        into their models.

    Raises:
        KeyError: If ``id`` or ``job_id`` is absent from the row.
    """
    record = dict(row)
    metrics_payload = json.loads(record.get("metrics") or "{}")
    summary_payload = json.loads(record.get("telemetry_summary") or "{}")

    fields: dict[str, Any] = {
        "id": record["id"],
        "job_id": record["job_id"],
        "metrics": BenchmarkMetrics(**metrics_payload),
        "telemetry_summary": TelemetrySummary(**summary_payload),
    }
    for key in (
        "created_at", "model_id", "dataset_id", "task",
        "framework", "hardware", "precision",
    ):
        fields[key] = record.get(key)
    return BenchmarkResult(**fields)
models/dataset.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models/dataset.py — Pydantic domain models for the Dataset Manager.
3
+ Single source of truth for all dataset-related data shapes.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ from typing import Any, Optional
11
+
12
+ from pydantic import BaseModel, Field, ConfigDict
13
+
14
+
15
+ # ── Universal Dataset Viewer (UDV) Models ──────────────────────────────────
16
+
17
class DatasetContentType(str, Enum):
    """Kind of content a universal dataset item carries."""

    image = "image"
    text = "text"
    audio = "audio"
    tabular = "tabular"
22
+
23
class UniversalAnnotationType(str, Enum):
    """Annotation kinds understood by the universal viewer models."""

    detection = "detection"
    segmentation = "segmentation"
    keypoints = "keypoints"
    classification = "classification"
    span = "span"
29
+
30
class UniversalAnnotation(BaseModel):
    """A single labelled annotation attached to a universal dataset item."""

    label: str
    type: UniversalAnnotationType
    bbox: list[float] | None = None                 # [x, y, w, h] normalized
    segmentation: list[list[float]] | None = None   # [[x1, y1, x2, y2, ...], ...]
    keypoints: list[float] | None = None            # [x1, y1, v1, ...]
    confidence: float | None = None
    metadata: dict[str, Any] | None = None
38
+
39
class UniversalDatasetItem(BaseModel):
    """One dataset entry (any content type) together with its annotations."""

    id: str
    content_type: DatasetContentType
    content_url: str | None = None
    content_body: str | None = None   # For text or raw json
    filename: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    annotations: list[UniversalAnnotation] = Field(default_factory=list)
47
+
48
class UniversalViewerPage(BaseModel):
    """One page of ``UniversalDatasetItem`` entries plus paging counters."""

    dataset_id: str
    page: int
    page_size: int
    total: int
    total_pages: int
    items: list[UniversalDatasetItem]
55
+
56
+
57
+ # ── Enumerations ──────────────────────────────────────────────────────────────
58
+
59
class DatasetTask(str, Enum):
    """Task a dataset is labelled for."""

    detection = "detection"
    classification = "classification"
    segmentation = "segmentation"
    nlp = "nlp"
    generation = "generation"
    keypoints = "keypoints"
66
+
67
+
68
class DatasetFormat(str, Enum):
    """On-disk / export format of a dataset."""

    yolo = "yolo"
    coco = "coco"
    voc = "voc"
    csv = "csv"
    json = "json"   # class-level member name; does not rebind the json module
    tfrecord = "tfrecord"
    custom = "custom"
76
+
77
+
78
class DatasetSource(str, Enum):
    """Origin a dataset is imported from."""

    roboflow = "roboflow"
    roboflow_curl = "roboflow_curl"   # direct cURL / pre-signed URL download
    local = "local"
    huggingface = "huggingface"
83
+
84
+
85
class DatasetStatus(str, Enum):
    """Status values a dataset record can hold."""

    available = "available"
    queued = "queued"
    importing = "importing"
    extracting = "extracting"
    validating = "validating"
    imported = "imported"
    failed = "failed"
93
+
94
+
95
class JobType(str, Enum):
    """Kinds of dataset job."""

    import_ = "import"   # trailing underscore: ``import`` is a Python keyword
    extract = "extract"
    validate = "validate"
    analyze = "analyze"
    delete = "delete"
101
+
102
+
103
class JobStatus(str, Enum):
    """Status values a dataset job can hold."""

    queued = "queued"
    running = "running"
    completed = "completed"
    failed = "failed"
    cancelled = "cancelled"
109
+
110
+
111
class AnnotationType(str, Enum):
    """Annotation kinds used by the ``Annotation`` model."""

    detection = "detection"
    segmentation = "segmentation"
    classification = "classification"
115
+
116
+
117
+ # ── Sub-models ────────────────────────────────────────────────────────────────
118
+
119
class DatasetSplit(BaseModel):
    """Per-split item counts (train / val / test)."""

    train: int = 0
    val: int = 0
    test: int = 0

    @property
    def total(self) -> int:
        """Combined count across the three splits."""
        return sum((self.train, self.val, self.test))
127
+
128
+
129
class DatasetVersion(BaseModel):
    """Entry in a dataset's version list; only ``version`` is required."""

    version: str
    date: str = ""
    changes: str = ""
    images: int = 0
    format: str = ""
135
+
136
+
137
class DatasetStats(BaseModel):
    """Aggregate dataset statistics, computed during import/analysis.

    Every field has a zero default, so ``DatasetStats()`` is a valid
    "nothing computed yet" value.
    """

    image_count: int = 0
    annotation_count: int = 0
    class_count: int = 0
    avg_objects: float = 0.0
    missing_labels: int = 0
    empty_images: int = 0
    duplicate_count: int = 0
    health_score: float = 0.0
    split: DatasetSplit = Field(default_factory=DatasetSplit)
148
+
149
+
150
+ # ── Core Domain Models ────────────────────────────────────────────────────────
151
+
152
class Dataset(BaseModel):
    """Full dataset record: identity, source, import state, versions, stats."""

    # use_enum_values stores the enum-typed fields below as their plain
    # string values rather than as Enum members.
    model_config = ConfigDict(protected_namespaces=(), use_enum_values=True)

    id: str
    name: str
    description: str = ""
    task: DatasetTask
    format: DatasetFormat
    source: DatasetSource
    # Pydantic does not validate defaults unless asked to, so without
    # validate_default a defaulted status would stay a DatasetStatus member
    # while explicitly supplied values are stored as plain strings.
    status: DatasetStatus = Field(default=DatasetStatus.available, validate_default=True)
    images: int = 0
    classes: int = 0
    class_names: list[str] = Field(default_factory=list)
    size_bytes: int = 0
    size_label: str = "0 B"
    local_path: str | None = None
    import_progress: float = 0.0        # 0.0–1.0
    tags: list[str] = Field(default_factory=list)
    versions: list[DatasetVersion] = Field(default_factory=list)
    active_version: str = "v1"
    stats: DatasetStats = Field(default_factory=DatasetStats)
    starred: bool = False
    roboflow_id: str | None = None      # workspace/project slug
    created_at: str | None = None
    updated_at: str | None = None
176
+
177
+
178
class DatasetSummary(BaseModel):
    """Lightweight projection for list endpoints."""

    # The docstring must be the first statement in the class body; placed
    # after model_config it is just a discarded expression and __doc__ is None.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    name: str
    task: str
    format: str
    source: str
    status: str
    images: int
    classes: int
    size_label: str
    tags: list[str]
    starred: bool
    import_progress: float
    health_score: float = 0.0
    created_at: str | None = None
    updated_at: str | None = None
196
+
197
+
198
+ # ── Annotation Models ─────────────────────────────────────────────────────────
199
+
200
class BoundingBox(BaseModel):
    """Axis-aligned box given by its top-left corner and its size."""

    x: float                  # top-left x (pixels or normalised)
    y: float                  # top-left y
    width: float
    height: float
    normalised: bool = True   # True → 0–1 range, False → pixel coords
206
+
207
+
208
class Annotation(BaseModel):
    """Format-agnostic annotation record; only ``label`` is required."""

    label: str
    bbox: BoundingBox | None = None
    segmentation: list[list[float]] | None = None   # polygon points
    keypoints: list[float] | None = None            # [x, y, v, ...]
    metadata: dict[str, Any] | None = None
    confidence: float | None = None
    area: float | None = None
    type: AnnotationType = AnnotationType.detection
218
+
219
+
220
class ImageRecord(BaseModel):
    """An image and its parsed annotations, as returned by viewer endpoints."""

    image_id: str
    filename: str
    width: int = 0
    height: int = 0
    path: str                # relative to dataset root
    annotations: list[Annotation] = Field(default_factory=list)
    split: str = "train"     # train|val|test
229
+
230
+
231
class ViewerPage(BaseModel):
    """One page of ``ImageRecord`` entries plus paging counters."""

    dataset_id: str
    page: int
    page_size: int
    total: int
    total_pages: int
    images: list[ImageRecord]
239
+
240
+
241
+ # ── Job Models ────────────────────────────────────────────────────────────────
242
+
243
class DatasetJob(BaseModel):
    """A dataset job record: type, status, progress and timestamps."""

    model_config = ConfigDict(protected_namespaces=())

    id: str
    type: str
    status: str
    dataset_id: str
    dataset_name: str
    progress: float = 0.0            # 0.0–1.0
    message: str = ""
    error: str | None = None
    created_at: str | None = None
    updated_at: str | None = None
    started_at: str | None = None
    ended_at: str | None = None
257
+
258
+
259
+ # ── Request/Response Schemas ─────────────────────────────────────────────────
260
+
261
class ImportRequest(BaseModel):
    """Request body for importing a dataset.

    Which optional fields are needed depends on ``source`` (see the
    per-field comments); this model itself declares no cross-field validator.
    """

    dataset_id: str
    source: DatasetSource

    # source=roboflow
    roboflow_key: str | None = None          # required when source=roboflow
    roboflow_workspace: str | None = None
    roboflow_project: str | None = None
    roboflow_version: int = 1

    # source=huggingface
    hf_dataset_id: str | None = None         # required when source=huggingface (e.g. "microsoft/coco")

    format: DatasetFormat = DatasetFormat.yolo

    # source=local
    local_path: str | None = None            # required when source=local

    # cURL / direct download (source=roboflow_curl)
    download_url: str | None = None          # pre-signed or direct download URL
    headers: dict[str, str] = Field(default_factory=dict)  # Custom headers for download
    dataset_name: str | None = None          # human-readable name override
    name: str | None = None                  # alias for dataset_name (used in local folder import)
    curl_format: str | None = None           # export format label from Roboflow cURL (e.g. "yolov8")
277
+
278
+
279
class ImportResponse(BaseModel):
    """Response body for an import request: job id, dataset id, status, message."""

    job_id: str
    dataset_id: str
    status: str
    message: str
284
+
285
+
286
class RoboflowSearchRequest(BaseModel):
    """Search parameters for a Roboflow lookup; ``api_key`` is the only required field."""

    query: str = ""
    api_key: str
    workspace: str | None = None
    page: int = 0
    page_size: int = 50
292
+
293
+
294
+ # ── DB Row helpers ────────────────────────────────────────────────────────────
295
+
296
def row_to_dataset(row: Any) -> Dataset:
    """
    Robustly convert a DB row (sqlite3.Row or dict) to a Dataset model.

    Handles:
      1. Enum strings stored with a class prefix
         ('DatasetStatus.available' becomes 'available').
      2. JSON-encoded list columns (class_names, tags, versions); unreadable
         text, NULL and JSON ``null`` all become ``[]``.
      3. A 'stats' column that is missing, NULL, a JSON string or a dict;
         anything unusable falls back to a default ``DatasetStats``.
      4. NULL scalar columns, which fall back to the model defaults.

    Raises:
        Exception: any failure (missing required column, validation error) is
            logged together with the row's keys and re-raised unchanged.
    """
    import logging

    logger = logging.getLogger("models.dataset")

    def strip_enum_prefix(val: Any) -> Any:
        # 'DatasetStatus.available' -> 'available'; other values pass through.
        if isinstance(val, str) and "." in val:
            return val.split(".")[-1]
        return val

    def as_list(raw: Any) -> Any:
        # Decode a JSON text column; bad text, NULL and JSON null -> [].
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:  # json.JSONDecodeError is a ValueError
                return []
        return [] if raw is None else raw

    try:
        d = dict(row) if not isinstance(row, dict) else row.copy()

        # 'stats' is best-effort: keep the default object when the column is
        # absent or unusable, but leave a trace instead of failing silently.
        stats_obj = DatasetStats()
        stats_raw = d.get("stats")
        try:
            if isinstance(stats_raw, str):
                stats_raw = json.loads(stats_raw)
            if isinstance(stats_raw, dict):
                stats_obj = DatasetStats(**stats_raw)
        except (ValueError, TypeError) as exc:
            logger.warning(
                "Ignoring unreadable stats for dataset %s: %s", d.get("id"), exc
            )

        # NULL columns are present in the dict with value None, so the
        # ``.get(key, default)`` form never applies its default; ``or`` does.
        return Dataset(
            id=d["id"],
            name=d["name"],
            description=d.get("description") or "",
            task=strip_enum_prefix(d["task"]),
            format=strip_enum_prefix(d["format"]),
            source=strip_enum_prefix(d["source"]),
            status=strip_enum_prefix(d.get("status")) or "available",
            images=d.get("images") or 0,
            classes=d.get("classes") or 0,
            class_names=as_list(d.get("class_names")),
            size_bytes=d.get("size_bytes") or 0,
            size_label=d.get("size_label") or "0 B",
            local_path=d.get("local_path"),
            import_progress=float(d.get("import_progress") or 0.0),
            tags=as_list(d.get("tags")),
            versions=as_list(d.get("versions")),
            active_version=d.get("active_version") or "v1",
            stats=stats_obj,
            starred=bool(d.get("starred")),
            roboflow_id=d.get("roboflow_id"),
            created_at=d.get("created_at"),
            updated_at=d.get("updated_at"),
        )

    except Exception as e:
        # Log with traceback and the available columns, then let the caller
        # decide how to handle the failure.
        logger.exception(
            "Pydantic instantiation error: %s, row keys: %s",
            e,
            list(row.keys()) if hasattr(row, "keys") else "N/A",
        )
        raise
384
+
385
+
386
def row_to_job(row: Any) -> DatasetJob:
    """Convert a dataset-job DB row into a ``DatasetJob``.

    ``id``, ``type`` and ``status`` are required columns (``KeyError`` if
    absent).  NULL or missing values for the non-optional string and float
    fields fall back to ``""`` / ``0.0``; a NULL column is present in the dict
    with value None, so a plain ``.get(key, default)`` would not cover it.
    """
    d = dict(row)
    return DatasetJob(
        id=d["id"],
        type=d["type"],
        status=d["status"],
        dataset_id=d.get("dataset_id") or "",
        dataset_name=d.get("dataset_name") or "",
        progress=float(d.get("progress") or 0.0),
        message=d.get("message") or "",
        error=d.get("error"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
        started_at=d.get("started_at"),
        ended_at=d.get("ended_at"),
    )
models/inference.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models/inference.py — Pydantic models for the Inference Engine.
3
+ Covers request, response, session history, and pipeline stage telemetry.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from enum import Enum
8
+ from typing import Any, Literal
9
+ from pydantic import BaseModel, Field
10
+ import time
11
+ import uuid
12
+
13
+
14
class AdapterType(str, Enum):
    """Inference adapter selected by a request."""

    YOLO = "yolo"
    TRANSFORMERS = "transformers"
    ONNX = "onnx"
    CUSTOM = "custom"
19
+
20
+
21
class InferencePrecision(str, Enum):
    """Numeric precision options for an inference request."""

    FP32 = "FP32"
    FP16 = "FP16"
    INT8 = "INT8"
25
+
26
+
27
class YOLOConfig(BaseModel):
    """YOLO adapter options; each numeric field is range-checked by ``Field``."""

    confidence: float = Field(default=0.25, ge=0.0, le=1.0)
    iou_threshold: float = Field(default=0.45, ge=0.1, le=0.9)
    class_filter: list[str] = Field(default_factory=list)
    max_detections: int = Field(default=300, ge=1, le=1000)
32
+
33
+
34
class TransformersConfig(BaseModel):
    """Transformers adapter generation options with range-checked defaults."""

    max_new_tokens: int = Field(default=256, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    top_k: int = Field(default=50, ge=0, le=200)
    beam_width: int = Field(default=1, ge=1, le=8)
    do_sample: bool = True
41
+
42
+
43
class ONNXConfig(BaseModel):
    """ONNX adapter options: execution provider, input size, normalisation flag."""

    # Only these two provider names are accepted; CUDA is the default.
    execution_provider: Literal["CUDAExecutionProvider", "CPUExecutionProvider"] = "CUDAExecutionProvider"
    input_size: int = Field(default=640, ge=32, le=1280)
    normalize: bool = True
47
+
48
+
49
class CustomConfig(BaseModel):
    """Custom adapter options: pre- and post-processing script strings."""

    preprocess_script: str = ""
    postprocess_script: str = ""
52
+
53
+
54
class InferenceRequest(BaseModel):
    """Request to run one inference with a given model and adapter."""

    model_id: str
    adapter_type: AdapterType
    precision: InferencePrecision = InferencePrecision.FP16

    # Input — one of these must be set
    # NOTE(review): no validator in this class enforces that; confirm the
    # check happens in the caller.
    image_base64: str | None = None   # base64-encoded image
    text_input: str | None = None     # text/prompt

    # Per-adapter config
    yolo_config: YOLOConfig | None = None
    transformers_config: TransformersConfig | None = None
    onnx_config: ONNXConfig | None = None
    custom_config: CustomConfig | None = None

    # Execution
    run_mode: Literal["single", "stream"] = "single"
71
+
72
+
73
class PipelineStage(BaseModel):
    """One named stage of the inference pipeline trace."""

    name: str
    status: Literal["pending", "running", "done", "error"] = "pending"
    latency_ms: float | None = None
    detail: str | None = None
78
+
79
+
80
class Detection(BaseModel):
    """One detected object: box corners, confidence and class."""

    # Box corners (x1, y1) and (x2, y2).
    x1: float
    y1: float
    x2: float
    y2: float
    confidence: float
    class_id: int
    class_name: str
88
+
89
+
90
class InferenceResult(BaseModel):
    """Result of one inference run: identity, timings, outputs and trace."""

    # Identity
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))  # fresh UUID4 per instance
    model_id: str
    adapter_type: AdapterType
    timestamp: float = Field(default_factory=time.time)                 # time.time() at creation

    # Timing
    preprocess_ms: float = 0.0
    inference_ms: float = 0.0
    postprocess_ms: float = 0.0
    total_ms: float = 0.0

    # Output — adapter-specific, all optional
    detections: list[Detection] = Field(default_factory=list)
    text_output: str | None = None
    class_label: str | None = None
    confidence: float | None = None
    embeddings: list[float] | None = None
    raw_output: Any = None   # raw JSON for inspector

    # Pipeline trace
    pipeline: list[PipelineStage] = Field(default_factory=list)

    # Quality score (0–5) derived from confidence mean
    quality_score: float | None = None

    # Error
    error: str | None = None
    status: Literal["ok", "error"] = "ok"
120
+
121
+
122
class InferenceHistoryEntry(BaseModel):
    """Compact history record of one inference run."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))   # fresh UUID4 per entry
    model_id: str
    model_name: str
    adapter_type: AdapterType
    timestamp: float = Field(default_factory=time.time)
    total_ms: float
    quality_score: float | None   # no default: must be passed, may be None
    status: Literal["ok", "error"]
    # Compact snapshot of result for re-run
    request_snapshot: dict[str, Any] = Field(default_factory=dict)
133
+
134
+
135
class SystemVitals(BaseModel):
    """One sample of latency, throughput and hardware-usage readings."""

    ts: float
    latency_ms: float
    fps: float
    vram_used_gb: float
    vram_total_gb: float
    gpu_temp_c: float | None = None
    cpu_pct: float = 0.0
models/job.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models/job.py — Job domain models (download / benchmark / sync).
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ from typing import Any
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
class Job(BaseModel):
    """A background job record (download, benchmark or sync)."""

    # Fields below start with "model_", which Pydantic's default protected
    # namespace would flag; an empty tuple turns that check off.
    model_config = {"protected_namespaces": ()}

    id: str
    type: str                        # download|benchmark|sync
    status: str                      # queued|running|completed|failed|cancelled
    model_id: str | None = None
    model_name: str | None = None
    progress: float = 0.0            # 0.0–1.0
    error: str | None = None
    meta: dict[str, Any] = Field(default_factory=dict)
    created_at: str | None = None
    updated_at: str | None = None
    started_at: str | None = None
    ended_at: str | None = None
26
+
27
+
28
class JobCreate(BaseModel):
    """Payload for creating a job; ``type`` defaults to ``"download"``."""

    # Allows the "model_"-prefixed field names below.
    model_config = {"protected_namespaces": ()}

    model_id: str
    model_name: str
    type: str = "download"
    version: str | None = None   # specific weight file / version to download
34
+
35
+
36
def row_to_job(row: Any) -> Job:
    """Convert a job DB row into a ``Job``.

    ``id``, ``type`` and ``status`` are required columns (``KeyError`` if
    absent).  ``meta`` is JSON text and decodes to ``{}`` when empty or NULL.

    Raises:
        json.JSONDecodeError: if ``meta`` holds malformed JSON.
    """
    d = dict(row)
    return Job(
        id=d["id"],
        type=d["type"],
        status=d["status"],
        model_id=d.get("model_id"),
        model_name=d.get("model_name"),
        # A NULL column is present with value None, so the ``.get`` default
        # would not apply and float(None) would raise; ``or`` covers both.
        progress=float(d.get("progress") or 0.0),
        error=d.get("error"),
        meta=json.loads(d.get("meta") or "{}"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
        started_at=d.get("started_at"),
        ended_at=d.get("ended_at"),
    )
models/model.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models/model.py — Pydantic domain models (schema contract for API + internal).
3
+ Single source of truth for data shapes between all modules.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from datetime import datetime
9
+ from typing import Any
10
+
11
+ from pydantic import BaseModel, Field, field_validator, model_validator
12
+
13
+ # ── Enumerations ──────────────────────────────────────────────────────────────
14
+
15
# Plain ``str`` aliases: they name the intent of a field but add no
# validation.  The trailing comments list the values in use.
ModelTask = str        # detection|classification|segmentation|generation|embedding|nlp
ModelFramework = str   # pytorch|onnx|tensorflow|tflite|coreml
ModelSource = str      # hf|onnx|local
ModelStatus = str      # available|downloading|cached|error
HardwareTarget = str   # gpu|cpu|edge|tpu
20
+
21
+
22
+ # ── Sub-models ────────────────────────────────────────────────────────────────
23
+
24
class ModelMetrics(BaseModel):
    """Optional performance metrics for a model.

    All declared fields default to ``None``; keys that are not declared here
    are kept as well (``extra="allow"``).
    """

    # Pydantic v2 configuration, replacing the deprecated nested
    # ``class Config``.  The plain-dict form needs no extra import.
    model_config = {"extra": "allow"}

    latency_ms: float | None = None
    mAP: float | None = None
    accuracy: float | None = None
    top1: float | None = None
    vram_gb: float | None = None
    fps: float | None = None
    flops: float | None = None
35
+
36
+
37
class ModelVersion(BaseModel):
    """One entry in a model's version list; only ``version`` is required."""

    version: str
    label: str = "Stable"             # Latest|Stable|Legacy|Nano|Small|Medium|Large|XLarge
    description: str | None = None
    releaseDate: str = ""
    changelog: str | None = None
43
+
44
+
45
+ # ── Core Model ────────────────────────────────────────────────────────────────
46
+
47
class Model(BaseModel):
    """Full model record: identity, source, download state, metrics, versions."""

    id: str
    name: str
    variant: str | None = None
    task: ModelTask
    framework: ModelFramework
    size: int = 0                      # bytes
    size_label: str = "0 B"
    tags: list[str] = Field(default_factory=list)
    source: ModelSource = "hf"
    provider: str = ""
    description: str = ""
    download_url: str | None = None    # explicit download source (HF repo URL, ONNX direct URL, etc.)
    local_path: str | None = None
    project_id: str | None = None
    downloaded: bool = False
    status: ModelStatus = "available"
    hardware: list[HardwareTarget] = Field(default_factory=list)
    metrics: ModelMetrics = Field(default_factory=ModelMetrics)
    versions: list[ModelVersion] = Field(default_factory=list)
    active_version: str | None = None
    rating: float | None = None
    downloads: int | None = None
    liked: bool = False
    created_at: str | None = None
    updated_at: str | None = None
73
+
74
+
75
class ModelSummary(BaseModel):
    """Lightweight model projection returned by list endpoints."""

    id: str
    name: str
    task: ModelTask
    framework: ModelFramework
    source: ModelSource
    provider: str
    size_label: str
    status: ModelStatus
    downloaded: bool
    downloads: int | None = None
    rating: float | None = None
    tags: list[str]
    hardware: list[HardwareTarget]
    metrics: ModelMetrics
91
+
92
+
93
+ # ── DB Row β†’ Model ────────────────────────────────────────────────────────────
94
+
95
def row_to_model(row: Any, versions: list[ModelVersion] | None = None) -> Model:
    """Convert an aiosqlite Row dict to a Model instance.

    Args:
        row: anything ``dict()`` accepts; ``id``, ``name``, ``task`` and
            ``framework`` are required columns.
        versions: version entries to attach; ``None`` means no versions.

    Raises:
        KeyError: if a required column is absent.
        json.JSONDecodeError: if ``metrics``, ``tags`` or ``hardware`` holds
            malformed JSON text.
    """
    d = dict(row)

    # metrics may come from model_versions join or not exist on models row
    metrics_raw = d.get("metrics") or "{}"
    if isinstance(metrics_raw, str):
        metrics_raw = json.loads(metrics_raw)
    # JSON "null" decodes to None, which cannot be **-unpacked below.
    metrics_raw = metrics_raw or {}

    # A NULL column is present in the dict with value None, so a plain
    # ``.get(key, default)`` would hand None to the non-optional fields and
    # fail validation; ``or`` applies the default in that case too.
    return Model(
        id=d["id"],
        name=d["name"],
        variant=d.get("variant"),
        task=d["task"],
        framework=d["framework"],
        source=d.get("source") or "hf",
        provider=d.get("provider") or "",
        description=d.get("description") or "",
        download_url=d.get("download_url"),
        size=d.get("size") or 0,
        size_label=d.get("size_label") or "0 B",
        tags=json.loads(d.get("tags") or "[]"),
        hardware=json.loads(d.get("hardware") or "[]"),
        status=d.get("status") or "available",
        downloaded=bool(d.get("downloaded")),
        local_path=d.get("local_path"),
        project_id=d.get("project_id"),
        downloads=d.get("downloads"),
        rating=d.get("rating"),
        liked=bool(d.get("liked")),
        metrics=ModelMetrics(**metrics_raw),
        versions=versions or [],
        active_version=d.get("active_version"),
        created_at=d.get("created_at"),
        updated_at=d.get("updated_at"),
    )