Spaces:
Sleeping
Sleeping
| """ | |
| registry/registry.py β Model Registry. | |
| Responsible for persisting, reading, and updating model metadata in SQLite. | |
| All callers go through this module; no direct DB access from other modules. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from datetime import datetime, timezone | |
| from typing import Any | |
| from database.connection import get_db | |
| from models.model import Model, ModelVersion, row_to_model | |
| from observability.logger import audit, get_logger | |
# Logger for this module, named "registry"; bulk_upsert reports sync totals through it.
log = get_logger("registry")
async def upsert_model(model: Model) -> None:
    """Insert or update a model record together with its versions.

    On an ``id`` conflict, every column listed in the ``DO UPDATE SET`` clause
    is overwritten from *model*. ``downloaded``, ``liked`` and ``created_at``
    are not in that list, so an existing row keeps its stored values for them
    (``created_at`` falls back to "now" only when the row is first inserted).

    Each entry of ``model.versions`` is upserted into ``model_versions`` under
    the key ``"{model.id}_{version}"``. The version ``metrics`` column is
    always written as an empty JSON object and is not updated on conflict.

    All statements are committed together. If any statement fails (or the
    task is cancelled), the pending changes are rolled back before the
    exception propagates. Without the rollback, a later ``commit()`` on the
    same connection could persist a half-written model (e.g. the model row
    without its versions). NOTE(review): this assumes ``get_db()`` hands out a
    connection that other callers may also commit on; confirm.

    Args:
        model: The model to persist.

    Raises:
        Any error raised by the database driver; the transaction is rolled
        back first.
    """
    db = await get_db()
    now = datetime.now(timezone.utc).isoformat()
    try:
        await db.execute(
            """INSERT INTO models
                (id, name, variant, task, framework, source, provider, description,
                 download_url, size, size_label, tags, hardware, status, downloaded, local_path, project_id,
                 active_version, metrics,
                 downloads, rating, liked, created_at, updated_at)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
            ON CONFLICT(id) DO UPDATE SET
                name=excluded.name,
                variant=excluded.variant,
                task=excluded.task,
                framework=excluded.framework,
                source=excluded.source,
                provider=excluded.provider,
                description=excluded.description,
                download_url=excluded.download_url,
                size=excluded.size,
                size_label=excluded.size_label,
                tags=excluded.tags,
                hardware=excluded.hardware,
                status=excluded.status,
                downloads=excluded.downloads,
                rating=excluded.rating,
                active_version=excluded.active_version,
                metrics=excluded.metrics,
                local_path=excluded.local_path,
                project_id=excluded.project_id,
                updated_at=excluded.updated_at""",
            (
                model.id, model.name, model.variant, model.task, model.framework,
                model.source, model.provider, model.description, model.download_url,
                model.size, model.size_label, json.dumps(model.tags), json.dumps(model.hardware),
                model.status, int(model.downloaded), model.local_path, model.project_id,
                model.active_version, model.metrics.model_dump_json(),
                model.downloads, model.rating, int(model.liked),
                model.created_at or now, now,
            ),
        )
        # Upsert versions; the composite id makes re-syncs idempotent per version.
        for v in model.versions:
            version_id = f"{model.id}_{v.version}"
            await db.execute(
                """INSERT INTO model_versions
                    (version_id, model_id, version, label, description, metrics, release_date, changelog)
                VALUES (?,?,?,?,?,?,?,?)
                ON CONFLICT(version_id) DO UPDATE SET
                    label=excluded.label, description=excluded.description,
                    release_date=excluded.release_date, changelog=excluded.changelog""",
                (
                    version_id, model.id, v.version, v.label, v.description,
                    json.dumps({}), v.releaseDate, v.changelog,
                ),
            )
        await db.commit()
    except BaseException:
        # BaseException so a cancelled task also discards its partial writes.
        await db.rollback()
        raise
async def bulk_upsert(models: list[Model]) -> None:
    """Batch upsert for sync operations.

    Each model is written in order through ``upsert_model``, which commits per
    model. Once all are written, the total is logged and recorded as a
    ``registry_sync`` audit event. If any upsert raises, the exception
    propagates and neither the log line nor the audit event is emitted.
    """
    count = 0
    for count, model in enumerate(models, start=1):
        await upsert_model(model)
    log.info("registry_bulk_upsert", total=count)
    await audit("registry_sync", payload={"count": count})
# Maximum number of model ids bound into a single ``IN (...)`` clause when
# loading versions. SQLite builds older than 3.32.0 reject statements with
# more than 999 host parameters, so longer id lists are queried in chunks.
_VERSION_FETCH_CHUNK = 500


def _fts_prefix_term(search: str) -> str:
    """Quote *search* as a single FTS5 string and turn it into a prefix query.

    Surrounding whitespace is stripped. Embedded double quotes are doubled,
    which is how FTS5 escapes them inside a string; without this, a search
    such as ``7" screen`` is an FTS5 syntax error instead of a query.
    """
    return '"' + search.strip().replace('"', '""') + '"*'


def _row_to_version(r: Any) -> ModelVersion:
    """Build a ModelVersion from a ``model_versions`` row.

    ``description``, ``release_date`` and ``changelog`` are read only if the
    row exposes those columns (presumably to tolerate older schemas; verify).
    A missing or empty ``release_date`` becomes ``""``.
    """
    keys = r.keys()
    return ModelVersion(
        version=r["version"],
        label=r["label"],
        description=r["description"] if "description" in keys else None,
        releaseDate=(r["release_date"] or "") if "release_date" in keys else "",
        changelog=r["changelog"] if "changelog" in keys else None,
    )


async def _fetch_versions(db: Any, model_ids: list[str]) -> dict[str, list[ModelVersion]]:
    """Load the versions of every id in *model_ids*, newest ``created_at`` first.

    Ids without any version are absent from the returned mapping.
    """
    by_model: dict[str, list[ModelVersion]] = {}
    # Chunks partition the ids, so each model's versions come from one query
    # and keep their ORDER BY order.
    for start in range(0, len(model_ids), _VERSION_FETCH_CHUNK):
        chunk = model_ids[start:start + _VERSION_FETCH_CHUNK]
        placeholders = ",".join("?" * len(chunk))
        async with db.execute(
            f"SELECT * FROM model_versions WHERE model_id IN ({placeholders}) "
            "ORDER BY created_at DESC",
            chunk,
        ) as cur:
            rows = await cur.fetchall()
        for r in rows:
            by_model.setdefault(r["model_id"], []).append(_row_to_version(r))
    return by_model


async def get_model(model_id: str) -> Model | None:
    """Return the model with *model_id* and its versions, or None if it does not exist.

    Versions are ordered newest ``created_at`` first.
    """
    db = await get_db()
    async with db.execute("SELECT * FROM models WHERE id = ?", (model_id,)) as cur:
        row = await cur.fetchone()
    if not row:
        return None
    versions = (await _fetch_versions(db, [model_id])).get(model_id, [])
    return row_to_model(row, versions)


async def list_models(
    *,
    tasks: list[str] | None = None,
    frameworks: list[str] | None = None,
    hardware: list[str] | None = None,
    sources: list[str] | None = None,
    downloaded: bool | None = None,
    sort_by: str = "downloads",
    sort_dir: str = "desc",
    limit: int = 500,
    offset: int = 0,
    search: str | None = None,
    project_id: str | None = None,
) -> list[Model]:
    """Return one page of models matching all supplied filters, with their versions.

    Filters are combined with AND; a None or empty filter is skipped.

    Args:
        tasks, frameworks, sources: keep models whose ``task``/``framework``/
            ``source`` is one of the given values.
        hardware: keep models whose stored ``hardware`` text contains any of
            these substrings (SQL ``LIKE '%value%'``).
        downloaded: when not None, keep only models with this flag.
        sort_by: "downloads", "name", "size", "rating" or "created"; any other
            value sorts by downloads. NULLs always sort last.
        sort_dir: "desc" sorts descending; any other value sorts ascending.
        limit, offset: pagination applied after filtering and sorting.
        search: full-text prefix search through ``models_fts``; blank or
            whitespace-only strings are ignored.
        project_id: keep models of this project plus models with no project.

    Returns:
        The matching models, each with ``versions`` populated (newest first).
    """
    db = await get_db()

    # -- WHERE conditions -------------------------------------------------
    conditions: list[str] = []
    params: list[Any] = []

    # FTS5 match expressed as an IN-subquery (valid SQLite syntax).
    if search and search.strip():
        conditions.append(
            "m.id IN (SELECT id FROM models_fts WHERE models_fts MATCH ?)"
        )
        params.append(_fts_prefix_term(search))
    if tasks:
        placeholders = ",".join(["?"] * len(tasks))
        conditions.append(f"m.task IN ({placeholders})")
        params.extend(tasks)
    if frameworks:
        placeholders = ",".join(["?"] * len(frameworks))
        conditions.append(f"m.framework IN ({placeholders})")
        params.extend(frameworks)
    if sources:
        placeholders = ",".join(["?"] * len(sources))
        conditions.append(f"m.source IN ({placeholders})")
        params.extend(sources)
    if hardware:
        hw_conds = ["m.hardware LIKE ?" for _ in hardware]
        conditions.append(f"({' OR '.join(hw_conds)})")
        params.extend([f"%{h}%" for h in hardware])
    if downloaded is not None:
        conditions.append("m.downloaded = ?")
        params.append(int(downloaded))
    if project_id:
        # Models not tied to any project are visible from every project.
        conditions.append("(m.project_id = ? OR m.project_id IS NULL)")
        params.append(project_id)
    where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    # -- Sort (column and direction come from fixed whitelists) -----------
    sort_col_map = {
        "downloads": "m.downloads",
        "name": "m.name",
        "size": "m.size",
        "rating": "m.rating",
        "created": "m.created_at",
    }
    col = sort_col_map.get(sort_by, "m.downloads")
    direction = "DESC" if sort_dir == "desc" else "ASC"
    sql = f"""
        SELECT m.* FROM models m
        {where_clause}
        ORDER BY {col} {direction} NULLS LAST
        LIMIT ? OFFSET ?
    """
    async with db.execute(sql, params + [limit, offset]) as cur:
        rows = await cur.fetchall()
    models = [row_to_model(row, []) for row in rows]
    if not models:
        return models

    by_model = await _fetch_versions(db, [m.id for m in models])
    return [m.model_copy(update={"versions": by_model.get(m.id, [])}) for m in models]
async def update_model_status(
    model_id: str,
    *,
    status: str | None = None,
    downloaded: bool | None = None,
    local_path: str | None = None,
) -> None:
    """Patch the status-related columns of one model and commit.

    ``updated_at`` is always refreshed to the current UTC time. Any argument
    left as None keeps its column unchanged, so this function cannot clear a
    column to NULL. An unknown *model_id* matches no rows and raises nothing.
    """
    db = await get_db()
    # Insertion order fixes the SET column order: updated_at first, then any
    # supplied fields in parameter order.
    changes: dict[str, Any] = {
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    if status is not None:
        changes["status"] = status
    if downloaded is not None:
        changes["downloaded"] = int(downloaded)
    if local_path is not None:
        changes["local_path"] = local_path

    assignments = ", ".join(f"{column} = ?" for column in changes)
    await db.execute(
        f"UPDATE models SET {assignments} WHERE id = ?",
        [*changes.values(), model_id],
    )
    await db.commit()
async def count_models() -> int:
    """Return the number of rows in the ``models`` table (0 if no row comes back)."""
    db = await get_db()
    async with db.execute("SELECT COUNT(*) FROM models") as cursor:
        result = await cursor.fetchone()
    if not result:
        return 0
    return result[0]