feat: persist every prediction in supabase and expose history endpoint
Browse files- src/api/routes/predict.py +44 -5
- src/db/__init__.py +0 -0
- src/db/supabase_client.py +117 -0
- supabase/predictions_setup.sql +49 -0
src/api/routes/predict.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import time
|
| 2 |
|
| 3 |
-
from fastapi import APIRouter, HTTPException
|
| 4 |
|
| 5 |
from src.api.schemas import (
|
| 6 |
BatchPredictRequest,
|
|
@@ -12,13 +12,23 @@ from src.api.schemas import (
|
|
| 12 |
)
|
| 13 |
from src.api.services import predict_single, to_predict_response
|
| 14 |
from src.api.state import get_state
|
| 15 |
-
from src.api.youtube import CommentsFetchError, fetch_comments
|
|
|
|
|
|
|
| 16 |
router = APIRouter(tags=["Prediction"])
|
| 17 |
|
| 18 |
|
| 19 |
@router.post("/predict", response_model=PredictResponse)
|
| 20 |
async def predict(request: PredictRequest):
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
@router.post("/predict-batch", response_model=BatchPredictResponse)
|
|
@@ -28,7 +38,15 @@ async def predict_batch(request: BatchPredictRequest):
|
|
| 28 |
for text in request.texts:
|
| 29 |
if not text.strip():
|
| 30 |
continue
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
total_ms = round((time.perf_counter() - t0) * 1000, 2)
|
| 33 |
toxic_count = sum(1 for r in results if r.is_toxic)
|
| 34 |
return BatchPredictResponse(
|
|
@@ -51,6 +69,8 @@ async def predict_video(request: VideoRequest):
|
|
| 51 |
if not comments:
|
| 52 |
raise HTTPException(status_code=404, detail="No comments found for this video")
|
| 53 |
|
|
|
|
|
|
|
| 54 |
t0 = time.perf_counter()
|
| 55 |
results: list[PredictResponse] = []
|
| 56 |
service = get_state()["service"]
|
|
@@ -61,7 +81,17 @@ async def predict_video(request: VideoRequest):
|
|
| 61 |
if not text.strip():
|
| 62 |
continue
|
| 63 |
raw = service.predict(text)
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
total_ms = round((time.perf_counter() - t0) * 1000, 2)
|
| 67 |
toxic_count = sum(1 for r in results if r.is_toxic)
|
|
@@ -75,3 +105,12 @@ async def predict_video(request: VideoRequest):
|
|
| 75 |
results=results,
|
| 76 |
source=source,
|
| 77 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import time
|
| 2 |
|
| 3 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 4 |
|
| 5 |
from src.api.schemas import (
|
| 6 |
BatchPredictRequest,
|
|
|
|
| 12 |
)
|
| 13 |
from src.api.services import predict_single, to_predict_response
|
| 14 |
from src.api.state import get_state
|
| 15 |
+
from src.api.youtube import CommentsFetchError, extract_video_id, fetch_comments
|
| 16 |
+
from src.db.supabase_client import list_predictions, save_prediction
|
| 17 |
+
|
| 18 |
router = APIRouter(tags=["Prediction"])
|
| 19 |
|
| 20 |
|
| 21 |
@router.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """Classify a single text and record the prediction.

    The prediction is computed with ``predict_single`` using the text and
    threshold from the request, then handed to ``save_prediction`` with
    ``source="api_direct"`` before being returned to the caller.

    Args:
        request: Request body carrying ``text`` and ``threshold``.

    Returns:
        The response object produced by ``predict_single``.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    response = predict_single(request.text, request.threshold)

    # save_prediction is a synchronous function; calling it directly inside
    # this coroutine would stall the event loop for every other request
    # while the insert is in flight. Running it in a worker thread keeps the
    # loop responsive. The response is still only returned after the save
    # attempt has finished.
    await asyncio.to_thread(
        save_prediction,
        text=request.text,
        result=response,
        source="api_direct",
        threshold=request.threshold,
        latency_ms=response.latency_ms,
    )
    return response
|
| 32 |
|
| 33 |
|
| 34 |
@router.post("/predict-batch", response_model=BatchPredictResponse)
|
|
|
|
| 38 |
for text in request.texts:
|
| 39 |
if not text.strip():
|
| 40 |
continue
|
| 41 |
+
single = predict_single(text.strip(), request.threshold)
|
| 42 |
+
results.append(single)
|
| 43 |
+
save_prediction(
|
| 44 |
+
text=text.strip(),
|
| 45 |
+
result=single,
|
| 46 |
+
source="api_direct",
|
| 47 |
+
threshold=request.threshold,
|
| 48 |
+
latency_ms=single.latency_ms,
|
| 49 |
+
)
|
| 50 |
total_ms = round((time.perf_counter() - t0) * 1000, 2)
|
| 51 |
toxic_count = sum(1 for r in results if r.is_toxic)
|
| 52 |
return BatchPredictResponse(
|
|
|
|
| 69 |
if not comments:
|
| 70 |
raise HTTPException(status_code=404, detail="No comments found for this video")
|
| 71 |
|
| 72 |
+
video_id = extract_video_id(request.url)
|
| 73 |
+
|
| 74 |
t0 = time.perf_counter()
|
| 75 |
results: list[PredictResponse] = []
|
| 76 |
service = get_state()["service"]
|
|
|
|
| 81 |
if not text.strip():
|
| 82 |
continue
|
| 83 |
raw = service.predict(text)
|
| 84 |
+
response = to_predict_response(text, raw, 0.0, request.threshold)
|
| 85 |
+
results.append(response)
|
| 86 |
+
save_prediction(
|
| 87 |
+
text=text,
|
| 88 |
+
result=response,
|
| 89 |
+
source="video_fetch",
|
| 90 |
+
video_id=video_id,
|
| 91 |
+
video_url=request.url,
|
| 92 |
+
threshold=request.threshold,
|
| 93 |
+
latency_ms=response.latency_ms,
|
| 94 |
+
)
|
| 95 |
|
| 96 |
total_ms = round((time.perf_counter() - t0) * 1000, 2)
|
| 97 |
toxic_count = sum(1 for r in results if r.is_toxic)
|
|
|
|
| 105 |
results=results,
|
| 106 |
source=source,
|
| 107 |
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@router.get("/predictions")
async def get_predictions(
    video_id: str | None = Query(default=None),
    limit: int = Query(default=50, ge=1, le=200),
):
    """Return stored predictions, optionally filtered by video.

    Args:
        video_id: Optional query parameter; forwarded to ``list_predictions``
            as its ``video_id`` filter.
        limit: Maximum number of rows to return; validated to 1..200 with a
            default of 50.

    Returns:
        Whatever ``list_predictions`` returns for the given filter and limit.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    # list_predictions is a synchronous call; run it in a worker thread so
    # the event loop is not blocked while the query executes.
    rows = await asyncio.to_thread(
        list_predictions,
        video_id=video_id,
        limit=limit,
    )
    return rows
|
src/db/__init__.py
ADDED
|
File without changes
|
src/db/supabase_client.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optional Supabase persistence for predictions.
|
| 2 |
+
|
| 3 |
+
The API works fine without credentials — all functions degrade
|
| 4 |
+
gracefully when ``SUPABASE_URL`` / ``SUPABASE_KEY`` are missing.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from functools import lru_cache
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from src.utils.logger import get_logger
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from supabase import Client, create_client
|
| 19 |
+
except ImportError: # pragma: no cover - dep listed in pyproject
|
| 20 |
+
Client = None # type: ignore[assignment,misc]
|
| 21 |
+
create_client = None # type: ignore[assignment]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
_TABLE = "predictions"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@lru_cache(maxsize=1)
def get_client() -> "Client | None":
    """Build the Supabase client once and reuse it on later calls.

    Reads ``SUPABASE_URL`` and ``SUPABASE_KEY`` from the environment. The
    outcome of the first call is memoized by ``lru_cache``, and that includes
    a ``None`` result.

    Returns:
        The client created by ``create_client``, or ``None`` when either
        environment variable is empty, the ``supabase`` package could not be
        imported, or client creation raised.
    """
    supabase_url = os.environ.get("SUPABASE_URL", "").strip()
    supabase_key = os.environ.get("SUPABASE_KEY", "").strip()

    # Both settings are required; a missing one means persistence is off.
    if not (supabase_url and supabase_key):
        return None

    # The module-level import fallback sets create_client to None.
    if create_client is None:
        logger.warning("supabase package not available; persistence disabled")
        return None

    handle = None
    try:
        handle = create_client(supabase_url, supabase_key)
        logger.info("Supabase client initialized")
    except Exception as err:  # pragma: no cover - network/config errors
        logger.warning("Failed to initialize Supabase client: %s", err)
        return None
    return handle
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def save_prediction(
    text: str,
    result: Any,
    source: str,
    video_id: str | None = None,
    video_url: str | None = None,
    threshold: float | None = None,
    latency_ms: float | None = None,
) -> None:
    """Insert one prediction row; do nothing when no client is available.

    Any exception raised while building or inserting the row is logged as a
    warning and swallowed, so a persistence failure never propagates to the
    caller.

    Args:
        text: The text that was classified.
        result: An object exposing ``model_dump()``, a dict, or any object
            with ``probability``, ``is_toxic``, ``labels``, ``model_used``
            and ``latency_ms`` attributes.
        source: Stored in the row's ``source`` column.
        video_id: Stored in the row's ``video_id`` column.
        video_url: Stored in the row's ``video_url`` column.
        threshold: Stored in the row's ``threshold`` column.
        latency_ms: Explicit latency; when ``None`` the value from
            ``result`` is used instead.
    """
    db = get_client()
    if db is None:
        return

    try:
        # Normalize the three accepted result shapes to a plain mapping.
        if hasattr(result, "model_dump"):
            fields = result.model_dump()
        elif isinstance(result, dict):
            fields = result
        else:
            attribute_defaults = (
                ("probability", None),
                ("is_toxic", None),
                ("labels", []),
                ("model_used", ""),
                ("latency_ms", None),
            )
            fields = {
                name: getattr(result, name, fallback)
                for name, fallback in attribute_defaults
            }

        record = {
            "text": text,
            "video_id": video_id,
            "video_url": video_url,
            "probability": fields.get("probability"),
            "is_toxic": fields.get("is_toxic"),
            # A missing or falsy labels value is stored as an empty list.
            "labels": fields.get("labels", []) or [],
            "model_used": fields.get("model_used", ""),
            "threshold": threshold,
            # The explicit argument wins over the value carried by result.
            "latency_ms": fields.get("latency_ms") if latency_ms is None else latency_ms,
            "source": source,
        }
        db.table(_TABLE).insert(record).execute()
    except Exception as err:
        logger.warning("save_prediction failed (non-critical): %s", err)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def list_predictions(
    video_id: str | None = None,
    limit: int = 50,
) -> list[dict]:
    """Fetch the most recent prediction rows, newest first.

    Args:
        video_id: When truthy, only rows whose ``video_id`` equals this value
            are returned.
        limit: Requested row count; clamped to the range 1..200.

    Returns:
        A list of row dicts. An empty list is returned when no client is
        available or when the query raises (the error is logged as a
        warning).
    """
    db = get_client()
    if db is None:
        return []

    try:
        request = db.table(_TABLE).select("*").order("created_at", desc=True)
        if video_id:
            request = request.eq("video_id", video_id)
        row_cap = max(1, min(limit, 200))
        reply = request.limit(row_cap).execute()
        # A missing or falsy ``data`` attribute is treated as no rows.
        rows = getattr(reply, "data", []) or []
        return list(rows)
    except Exception as err:
        logger.warning("list_predictions failed: %s", err)
        return []
|
supabase/predictions_setup.sql
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- =====================================================================
-- SignalMod — Predictions table setup
-- How to apply: open the Supabase dashboard, go to SQL Editor, create a
-- new query, paste this script and run it.
-- =====================================================================

-- 1. Table holding one row per stored prediction
CREATE TABLE IF NOT EXISTS public.predictions (
    id          BIGSERIAL PRIMARY KEY,
    created_at  TIMESTAMPTZ NOT NULL DEFAULT now(),
    text        TEXT NOT NULL,
    video_id    TEXT,
    video_url   TEXT,
    probability DOUBLE PRECISION,
    is_toxic    BOOLEAN,
    labels      TEXT[] DEFAULT '{}',
    model_used  TEXT,
    threshold   DOUBLE PRECISION,
    latency_ms  DOUBLE PRECISION,
    source      TEXT  -- "api_direct" | "video_fetch" | "user_comment"
);

-- 2. Indexes backing the API's "newest first" ordering and video_id filter
CREATE INDEX IF NOT EXISTS predictions_created_at_idx
    ON public.predictions (created_at DESC);

CREATE INDEX IF NOT EXISTS predictions_video_id_idx
    ON public.predictions (video_id);

-- 3. Row Level Security: the anon role may insert and select any row
--    (the publishable key is used from the frontend / backend with no auth)
ALTER TABLE public.predictions ENABLE ROW LEVEL SECURITY;

DROP POLICY IF EXISTS "anon_insert" ON public.predictions;
CREATE POLICY "anon_insert"
    ON public.predictions
    FOR INSERT
    TO anon
    WITH CHECK (true);

DROP POLICY IF EXISTS "anon_select" ON public.predictions;
CREATE POLICY "anon_select"
    ON public.predictions
    FOR SELECT
    TO anon
    USING (true);

-- 4. Optional sanity check (run separately)
-- select count(*) from public.predictions;
|