"""Optional Supabase persistence for predictions.

The API works fine without credentials — all functions degrade
gracefully when ``SUPABASE_URL`` / ``SUPABASE_KEY`` are missing.
"""
from __future__ import annotations
import os
from functools import lru_cache
from typing import Any
from src.utils.logger import get_logger
logger = get_logger(__name__)
try:
from supabase import Client, create_client
except ImportError: # pragma: no cover - dep listed in pyproject
Client = None # type: ignore[assignment,misc]
create_client = None # type: ignore[assignment]
# Name of the Supabase table that prediction rows are inserted into and
# queried from by the functions in this module.
_TABLE = "predictions"
@lru_cache(maxsize=1)
def get_client() -> "Client | None":
    """Build the Supabase client once and return the memoized result.

    ``SUPABASE_URL`` and ``SUPABASE_KEY`` are read from the environment with
    surrounding whitespace stripped. ``None`` is returned when either value
    is empty, when the ``supabase`` package could not be imported, or when
    constructing the client raises.

    The function is wrapped in ``lru_cache``, so whatever the first call
    returns (including ``None``) is reused by later calls until
    ``get_client.cache_clear()`` is invoked.
    """
    settings = {
        name: os.getenv(name, "").strip()
        for name in ("SUPABASE_URL", "SUPABASE_KEY")
    }
    if not all(settings.values()):
        # Persistence is optional: missing credentials are not an error.
        return None

    if create_client is None:
        logger.warning("supabase package not available; persistence disabled")
        return None

    supabase = None
    try:
        candidate = create_client(
            settings["SUPABASE_URL"], settings["SUPABASE_KEY"]
        )
        logger.info("Supabase client initialized")
        supabase = candidate
    except Exception as exc:  # pragma: no cover - network/config errors
        logger.warning("Failed to initialize Supabase client: %s", exc)
    return supabase
def save_prediction(
    text: str,
    result: Any,
    source: str,
    video_id: str | None = None,
    video_url: str | None = None,
    threshold: float | None = None,
    latency_ms: float | None = None,
) -> None:
    """Insert one prediction row; do nothing when no client is configured.

    ``result`` is read in one of three ways: via ``model_dump()`` when the
    object provides it (e.g. a Pydantic ``PredictResponse``), directly when
    it is a dict, or attribute by attribute otherwise. The fields consumed
    are ``probability``, ``is_toxic``, ``labels``, ``model_used`` and
    ``latency_ms``.

    An explicit ``latency_ms`` argument takes precedence over the value
    carried by ``result``. Any exception raised while building or inserting
    the row is logged as a warning and swallowed, so persistence problems
    never propagate to the caller.
    """
    supabase = get_client()
    if supabase is None:
        return

    try:
        if hasattr(result, "model_dump"):
            fields = result.model_dump()
        elif isinstance(result, dict):
            fields = result
        else:
            # Arbitrary object: read each attribute with the same fallback
            # value that a missing key would get.
            fallbacks = (
                ("probability", None),
                ("is_toxic", None),
                ("labels", []),
                ("model_used", ""),
                ("latency_ms", None),
            )
            fields = {
                attr: getattr(result, attr, fallback)
                for attr, fallback in fallbacks
            }

        row = {"text": text, "video_id": video_id, "video_url": video_url}
        row["probability"] = fields.get("probability")
        row["is_toxic"] = fields.get("is_toxic")
        # A falsy ``labels`` value (``None`` or empty) becomes an empty list.
        row["labels"] = fields.get("labels", []) or []
        row["model_used"] = fields.get("model_used", "")
        row["threshold"] = threshold
        if latency_ms is not None:
            row["latency_ms"] = latency_ms
        else:
            row["latency_ms"] = fields.get("latency_ms")
        row["source"] = source

        insert_request = supabase.table(_TABLE).insert(row)
        insert_request.execute()
    except Exception as exc:
        logger.warning("save_prediction failed (non-critical): %s", exc)
def list_predictions(
    video_id: str | None = None,
    limit: int = 50,
) -> list[dict]:
    """Fetch the most recent predictions, newest ``created_at`` first.

    When ``video_id`` is truthy, only rows with that ``video_id`` are
    returned. ``limit`` is clamped to the range 1..200 before it is sent
    with the query.

    An empty list is returned when no client is configured, and also when
    the query raises (the error is logged as a warning).
    """
    supabase = get_client()
    if supabase is None:
        return []

    try:
        builder = supabase.table(_TABLE).select("*")
        builder = builder.order("created_at", desc=True)
        if video_id:
            builder = builder.eq("video_id", video_id)

        page_size = max(1, min(limit, 200))
        builder = builder.limit(page_size)

        reply = builder.execute()
        # ``data`` may be absent or ``None``; both cases yield an empty list.
        rows = getattr(reply, "data", [])
        return list(rows or [])
    except Exception as exc:
        logger.warning("list_predictions failed: %s", exc)
        return []
|