# mlforge / registry/registry.py
# Last change: commit ac5551d by senthil2421 — "Deploy cloud brain to HF Spaces"
"""
registry/registry.py — Model Registry.
Responsible for persisting, reading, and updating model metadata in SQLite.
All callers go through this module; no direct DB access from other modules.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
from database.connection import get_db
from models.model import Model, ModelVersion, row_to_model
from observability.logger import audit, get_logger
# Module-level logger named "registry"; bulk_upsert emits its
# "registry_bulk_upsert" event through it.
log = get_logger("registry")
async def upsert_model(model: Model) -> None:
    """Insert or update a model record and all of its versions.

    The ``models`` row is inserted. On an ``id`` conflict, its catalog fields
    are overwritten with the new values. The ``downloaded``, ``liked`` and
    ``created_at`` columns are deliberately left out of the conflict update,
    so existing values for them survive a re-sync.

    Every entry of ``model.versions`` is then upserted into
    ``model_versions`` under the key ``"{model.id}_{version}"``.

    All statements run in one transaction that is committed at the end. If
    any statement (or the commit) fails, the transaction is rolled back and
    the error is re-raised. This ensures a model row without its versions is
    never left pending for some later, unrelated ``commit()`` to persist.

    Args:
        model: The model to persist.
    """
    db = await get_db()
    now = datetime.now(timezone.utc).isoformat()
    try:
        await db.execute(
            """INSERT INTO models
            (id, name, variant, task, framework, source, provider, description,
            download_url, size, size_label, tags, hardware, status, downloaded, local_path, project_id,
            active_version, metrics,
            downloads, rating, liked, created_at, updated_at)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
            ON CONFLICT(id) DO UPDATE SET
            name=excluded.name,
            variant=excluded.variant,
            task=excluded.task,
            framework=excluded.framework,
            source=excluded.source,
            provider=excluded.provider,
            description=excluded.description,
            download_url=excluded.download_url,
            size=excluded.size,
            size_label=excluded.size_label,
            tags=excluded.tags,
            hardware=excluded.hardware,
            status=excluded.status,
            downloads=excluded.downloads,
            rating=excluded.rating,
            active_version=excluded.active_version,
            metrics=excluded.metrics,
            local_path=excluded.local_path,
            project_id=excluded.project_id,
            updated_at=excluded.updated_at""",
            (
                model.id, model.name, model.variant, model.task, model.framework,
                model.source, model.provider, model.description, model.download_url,
                model.size, model.size_label, json.dumps(model.tags), json.dumps(model.hardware),
                model.status, int(model.downloaded), model.local_path, model.project_id,
                model.active_version, model.metrics.model_dump_json(),
                model.downloads, model.rating, int(model.liked),
                # Keep the caller's creation time; stamp "now" only when absent.
                model.created_at or now, now,
            ),
        )
        # Upsert versions
        for v in model.versions:
            version_id = f"{model.id}_{v.version}"
            await db.execute(
                """INSERT INTO model_versions
                (version_id, model_id, version, label, description, metrics, release_date, changelog)
                VALUES (?,?,?,?,?,?,?,?)
                ON CONFLICT(version_id) DO UPDATE SET
                label=excluded.label, description=excluded.description,
                release_date=excluded.release_date, changelog=excluded.changelog""",
                (
                    version_id, model.id, v.version, v.label, v.description,
                    # NOTE(review): version metrics are always stored as "{}";
                    # confirm whether per-version metrics should be persisted.
                    json.dumps({}), v.releaseDate, v.changelog,
                ),
            )
        await db.commit()
    except BaseException:
        # Includes task cancellation. Discard the partial write before
        # propagating; otherwise the still-open transaction would be committed
        # by the next commit() issued anywhere on this connection.
        await db.rollback()
        raise
async def bulk_upsert(models: list[Model]) -> None:
    """Persist a batch of models, then log and audit how many were written.

    Each model goes through ``upsert_model`` individually, in order. An
    exception from any of them aborts the batch before the log/audit step.
    """
    processed = 0
    for processed, item in enumerate(models, start=1):
        await upsert_model(item)
    log.info("registry_bulk_upsert", total=processed)
    await audit("registry_sync", payload={"count": processed})
async def get_model(model_id: str) -> Model | None:
    """Load one model by id, with its versions attached.

    Returns ``None`` when no ``models`` row has that id. Versions are
    ordered by their ``created_at`` column, newest first.

    The optional version columns are tolerated when absent from the row:
    ``description`` and ``changelog`` become ``None``. ``release_date``
    becomes ``""`` when it is absent or empty.
    """
    conn = await get_db()
    async with conn.execute("SELECT * FROM models WHERE id = ?", (model_id,)) as cursor:
        model_row = await cursor.fetchone()
    if not model_row:
        return None

    async with conn.execute(
        "SELECT * FROM model_versions WHERE model_id = ? ORDER BY created_at DESC",
        (model_id,),
    ) as cursor:
        fetched = await cursor.fetchall()

    history: list[ModelVersion] = []
    for rec in fetched:
        columns = rec.keys()
        has_release = "release_date" in columns and rec["release_date"]
        history.append(
            ModelVersion(
                version=rec["version"],
                label=rec["label"],
                description=rec["description"] if "description" in columns else None,
                releaseDate=rec["release_date"] if has_release else "",
                changelog=rec["changelog"] if "changelog" in columns else None,
            )
        )
    return row_to_model(model_row, history)
async def list_models(
*,
tasks: list[str] | None = None,
frameworks: list[str] | None = None,
hardware: list[str] | None = None,
sources: list[str] | None = None,
downloaded: bool | None = None,
sort_by: str = "downloads",
sort_dir: str = "desc",
limit: int = 500,
offset: int = 0,
search: str | None = None,
project_id: str | None = None,
) -> list[Model]:
db = await get_db()
# ── WHERE conditions ──────────────────────────────────────────────
conditions: list[str] = []
params: list[Any] = []
# FTS5 subquery β€” valid SQLite syntax
if search and search.strip():
fts_term = f'"{search.strip()}"*'
conditions.append(
"m.id IN (SELECT id FROM models_fts WHERE models_fts MATCH ?)"
)
params.append(fts_term)
if tasks:
placeholders = ",".join(["?"] * len(tasks))
conditions.append(f"m.task IN ({placeholders})")
params.extend(tasks)
if frameworks:
placeholders = ",".join(["?"] * len(frameworks))
conditions.append(f"m.framework IN ({placeholders})")
params.extend(frameworks)
if sources:
placeholders = ",".join(["?"] * len(sources))
conditions.append(f"m.source IN ({placeholders})")
params.extend(sources)
if hardware:
hw_conds = ["m.hardware LIKE ?" for _ in hardware]
conditions.append(f"({' OR '.join(hw_conds)})")
params.extend([f"%{h}%" for h in hardware])
if downloaded is not None:
conditions.append("m.downloaded = ?")
params.append(int(downloaded))
if project_id:
conditions.append("(m.project_id = ? OR m.project_id IS NULL)")
params.append(project_id)
where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""
# ── Sort ──────────────────────────────────────────────────────────
sort_col_map = {
"downloads": "m.downloads",
"name": "m.name",
"size": "m.size",
"rating": "m.rating",
"created": "m.created_at",
}
col = sort_col_map.get(sort_by, "m.downloads")
direction = "DESC" if sort_dir == "desc" else "ASC"
sql = f"""
SELECT m.* FROM models m
{where_clause}
ORDER BY {col} {direction} NULLS LAST
LIMIT ? OFFSET ?
"""
async with db.execute(sql, params + [limit, offset]) as cur:
rows = await cur.fetchall()
models = [row_to_model(row, []) for row in rows]
if not models:
return models
ids = [m.id for m in models]
placeholders = ",".join(["?"] * len(ids))
async with db.execute(
f"SELECT * FROM model_versions WHERE model_id IN ({placeholders}) ORDER BY created_at DESC",
ids,
) as cur:
vrows = await cur.fetchall()
by_model: dict[str, list[ModelVersion]] = {}
for r in vrows:
mv = ModelVersion(
version=r["version"],
label=r["label"],
description=r["description"] if "description" in r.keys() else None,
releaseDate=r["release_date"] if "release_date" in r.keys() and r["release_date"] else "",
changelog=r["changelog"] if "changelog" in r.keys() else None,
)
by_model.setdefault(r["model_id"], []).append(mv)
return [m.model_copy(update={"versions": by_model.get(m.id, [])}) for m in models]
async def update_model_status(
    model_id: str,
    *,
    status: str | None = None,
    downloaded: bool | None = None,
    local_path: str | None = None,
) -> None:
    """Update the lifecycle columns of a single model row and commit.

    Only arguments that are not ``None`` are written. ``updated_at`` is
    always refreshed to the current UTC time in ISO-8601 form. An unknown
    ``model_id`` matches no row, and nothing is raised.
    """
    conn = await get_db()
    stamp = datetime.now(timezone.utc).isoformat()

    # (column, value) pairs in a fixed order; None means "leave unchanged".
    optional = (
        ("status", status),
        ("downloaded", None if downloaded is None else int(downloaded)),
        ("local_path", local_path),
    )
    assignments = [("updated_at", stamp)]
    assignments.extend(pair for pair in optional if pair[1] is not None)

    set_clause = ", ".join(f"{column} = ?" for column, _ in assignments)
    bound = [value for _, value in assignments]
    bound.append(model_id)

    await conn.execute(f"UPDATE models SET {set_clause} WHERE id = ?", bound)
    await conn.commit()
async def count_models() -> int:
    """Return the number of rows in the ``models`` table (0 if no row is returned)."""
    conn = await get_db()
    async with conn.execute("SELECT COUNT(*) FROM models") as cursor:
        result = await cursor.fetchone()
    if not result:
        return 0
    return result[0]