Ruperth committed on
Commit
7ba2f95
·
1 Parent(s): 0ac8b84

feat: persist every prediction in supabase and expose history endpoint

Browse files
src/api/routes/predict.py CHANGED
@@ -1,6 +1,6 @@
1
  import time
2
 
3
- from fastapi import APIRouter, HTTPException
4
 
5
  from src.api.schemas import (
6
  BatchPredictRequest,
@@ -12,13 +12,23 @@ from src.api.schemas import (
12
  )
13
  from src.api.services import predict_single, to_predict_response
14
  from src.api.state import get_state
15
- from src.api.youtube import CommentsFetchError, fetch_comments
 
 
16
  router = APIRouter(tags=["Prediction"])
17
 
18
 
19
  @router.post("/predict", response_model=PredictResponse)
20
  async def predict(request: PredictRequest):
21
- return predict_single(request.text, request.threshold)
 
 
 
 
 
 
 
 
22
 
23
 
24
  @router.post("/predict-batch", response_model=BatchPredictResponse)
@@ -28,7 +38,15 @@ async def predict_batch(request: BatchPredictRequest):
28
  for text in request.texts:
29
  if not text.strip():
30
  continue
31
- results.append(predict_single(text.strip(), request.threshold))
 
 
 
 
 
 
 
 
32
  total_ms = round((time.perf_counter() - t0) * 1000, 2)
33
  toxic_count = sum(1 for r in results if r.is_toxic)
34
  return BatchPredictResponse(
@@ -51,6 +69,8 @@ async def predict_video(request: VideoRequest):
51
  if not comments:
52
  raise HTTPException(status_code=404, detail="No comments found for this video")
53
 
 
 
54
  t0 = time.perf_counter()
55
  results: list[PredictResponse] = []
56
  service = get_state()["service"]
@@ -61,7 +81,17 @@ async def predict_video(request: VideoRequest):
61
  if not text.strip():
62
  continue
63
  raw = service.predict(text)
64
- results.append(to_predict_response(text, raw, 0.0, request.threshold))
 
 
 
 
 
 
 
 
 
 
65
 
66
  total_ms = round((time.perf_counter() - t0) * 1000, 2)
67
  toxic_count = sum(1 for r in results if r.is_toxic)
@@ -75,3 +105,12 @@ async def predict_video(request: VideoRequest):
75
  results=results,
76
  source=source,
77
  )
 
 
 
 
 
 
 
 
 
 
1
  import time
2
 
3
+ from fastapi import APIRouter, HTTPException, Query
4
 
5
  from src.api.schemas import (
6
  BatchPredictRequest,
 
12
  )
13
  from src.api.services import predict_single, to_predict_response
14
  from src.api.state import get_state
15
+ from src.api.youtube import CommentsFetchError, extract_video_id, fetch_comments
16
+ from src.db.supabase_client import list_predictions, save_prediction
17
+
18
  router = APIRouter(tags=["Prediction"])
19
 
20
 
21
  @router.post("/predict", response_model=PredictResponse)
22
  async def predict(request: PredictRequest):
23
+ response = predict_single(request.text, request.threshold)
24
+ save_prediction(
25
+ text=request.text,
26
+ result=response,
27
+ source="api_direct",
28
+ threshold=request.threshold,
29
+ latency_ms=response.latency_ms,
30
+ )
31
+ return response
32
 
33
 
34
  @router.post("/predict-batch", response_model=BatchPredictResponse)
 
38
  for text in request.texts:
39
  if not text.strip():
40
  continue
41
+ single = predict_single(text.strip(), request.threshold)
42
+ results.append(single)
43
+ save_prediction(
44
+ text=text.strip(),
45
+ result=single,
46
+ source="api_direct",
47
+ threshold=request.threshold,
48
+ latency_ms=single.latency_ms,
49
+ )
50
  total_ms = round((time.perf_counter() - t0) * 1000, 2)
51
  toxic_count = sum(1 for r in results if r.is_toxic)
52
  return BatchPredictResponse(
 
69
  if not comments:
70
  raise HTTPException(status_code=404, detail="No comments found for this video")
71
 
72
+ video_id = extract_video_id(request.url)
73
+
74
  t0 = time.perf_counter()
75
  results: list[PredictResponse] = []
76
  service = get_state()["service"]
 
81
  if not text.strip():
82
  continue
83
  raw = service.predict(text)
84
+ response = to_predict_response(text, raw, 0.0, request.threshold)
85
+ results.append(response)
86
+ save_prediction(
87
+ text=text,
88
+ result=response,
89
+ source="video_fetch",
90
+ video_id=video_id,
91
+ video_url=request.url,
92
+ threshold=request.threshold,
93
+ latency_ms=response.latency_ms,
94
+ )
95
 
96
  total_ms = round((time.perf_counter() - t0) * 1000, 2)
97
  toxic_count = sum(1 for r in results if r.is_toxic)
 
105
  results=results,
106
  source=source,
107
  )
108
+
109
+
110
@router.get("/predictions")
async def get_predictions(
    video_id: str | None = Query(default=None),
    limit: int = Query(default=50, ge=1, le=200),
):
    """Return stored predictions, optionally restricted to one video.

    Args:
        video_id: When given, it is forwarded to ``list_predictions`` as the
            video filter.
        limit: Maximum number of rows to return; validated to 1..200 by the
            ``Query`` constraints.

    Returns:
        Whatever ``list_predictions`` returns for the given arguments.

    ``list_predictions`` issues a synchronous database query, so it is run in
    a worker thread instead of directly on the event loop, which would
    otherwise be blocked for the duration of the round trip.
    """
    import asyncio  # local import: keeps this block self-contained

    return await asyncio.to_thread(list_predictions, video_id=video_id, limit=limit)
src/db/__init__.py ADDED
File without changes
src/db/supabase_client.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optional Supabase persistence for predictions.
2
+
3
+ The API works fine without credentials — all functions degrade
4
+ gracefully when ``SUPABASE_URL`` / ``SUPABASE_KEY`` are missing.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from functools import lru_cache
11
+ from typing import Any
12
+
13
+ from src.utils.logger import get_logger
14
+
15
+ logger = get_logger(__name__)
16
+
17
+ try:
18
+ from supabase import Client, create_client
19
+ except ImportError: # pragma: no cover - dep listed in pyproject
20
+ Client = None # type: ignore[assignment,misc]
21
+ create_client = None # type: ignore[assignment]
22
+
23
+
24
+ _TABLE = "predictions"
25
+
26
+
27
@lru_cache(maxsize=1)
def get_client() -> "Client | None":
    """Build the Supabase client once and hand back the cached result.

    ``SUPABASE_URL`` and ``SUPABASE_KEY`` are read from the environment with
    surrounding whitespace removed. ``None`` is returned when either value is
    empty, when ``create_client`` is unavailable (the ``supabase`` import
    failed), or when creating the client raises. Because of ``lru_cache``,
    the first outcome -- including ``None`` -- is reused by later calls.
    """
    supabase_url, supabase_key = (
        os.getenv(name, "").strip() for name in ("SUPABASE_URL", "SUPABASE_KEY")
    )
    if not (supabase_url and supabase_key):
        # Persistence is optional: missing credentials simply disable it.
        return None

    if create_client is None:
        logger.warning("supabase package not available; persistence disabled")
        return None

    try:
        supabase = create_client(supabase_url, supabase_key)
        logger.info("Supabase client initialized")
    except Exception as exc:  # pragma: no cover - network/config errors
        logger.warning("Failed to initialize Supabase client: %s", exc)
        return None
    return supabase
44
+
45
+
46
def save_prediction(
    text: str,
    result: Any,
    source: str,
    video_id: str | None = None,
    video_url: str | None = None,
    threshold: float | None = None,
    latency_ms: float | None = None,
) -> None:
    """Insert one prediction row; do nothing when no client is available.

    ``result`` is read in one of three ways: via ``model_dump()`` when the
    object offers it, as-is when it is a ``dict``, and otherwise attribute by
    attribute (``probability``, ``is_toxic``, ``labels``, ``model_used``,
    ``latency_ms``) with fallbacks for missing attributes.

    An explicit ``latency_ms`` argument wins over the value carried by
    ``result``. Any exception raised while building or inserting the row is
    logged as a warning and suppressed, so callers never see a failure.
    """
    client = get_client()
    if client is None:
        return

    try:
        if hasattr(result, "model_dump"):
            fields = result.model_dump()
        elif isinstance(result, dict):
            fields = result
        else:
            # Arbitrary object: pull the known attributes with safe fallbacks.
            fallbacks = (
                ("probability", None),
                ("is_toxic", None),
                ("labels", []),
                ("model_used", ""),
                ("latency_ms", None),
            )
            fields = {attr: getattr(result, attr, default) for attr, default in fallbacks}

        if latency_ms is not None:
            recorded_latency = latency_ms
        else:
            recorded_latency = fields.get("latency_ms")

        # A falsy ``labels`` value (None, empty) is stored as an empty list.
        labels = fields.get("labels", []) or []

        row = {
            "text": text,
            "video_id": video_id,
            "video_url": video_url,
            "probability": fields.get("probability"),
            "is_toxic": fields.get("is_toxic"),
            "labels": labels,
            "model_used": fields.get("model_used", ""),
            "threshold": threshold,
            "latency_ms": recorded_latency,
            "source": source,
        }
        insert_request = client.table(_TABLE).insert(row)
        insert_request.execute()
    except Exception as exc:
        logger.warning("save_prediction failed (non-critical): %s", exc)
94
+
95
+
96
def list_predictions(
    video_id: str | None = None,
    limit: int = 50,
) -> list[dict]:
    """Fetch stored predictions, newest ``created_at`` first.

    A truthy ``video_id`` restricts the rows to that video. ``limit`` is
    clamped to the range 1..200 before it is applied. An empty list is
    returned when no client is available or when the query raises (the error
    is logged as a warning).
    """
    client = get_client()
    if client is None:
        return []

    try:
        newest_first = client.table(_TABLE).select("*").order("created_at", desc=True)
        filtered = newest_first.eq("video_id", video_id) if video_id else newest_first
        row_cap = max(1, min(limit, 200))
        response = filtered.limit(row_cap).execute()
        rows = getattr(response, "data", []) or []
        return list(rows)
    except Exception as exc:
        logger.warning("list_predictions failed: %s", exc)
        return []
supabase/predictions_setup.sql ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
-- =====================================================================
-- SignalMod: setup script for the predictions table
-- How to run (Supabase SQL Editor):
--   Dashboard -> SQL Editor -> New query -> paste this script -> Run
-- Every statement uses IF [NOT] EXISTS, so the script can be re-run.
-- =====================================================================

-- 1. Table: one row per stored prediction
CREATE TABLE IF NOT EXISTS public.predictions (
    id          BIGSERIAL PRIMARY KEY,
    created_at  TIMESTAMPTZ NOT NULL DEFAULT now(),
    text        TEXT NOT NULL,
    video_id    TEXT,
    video_url   TEXT,
    probability DOUBLE PRECISION,
    is_toxic    BOOLEAN,
    labels      TEXT[] DEFAULT '{}',
    model_used  TEXT,
    threshold   DOUBLE PRECISION,
    latency_ms  DOUBLE PRECISION,
    source      TEXT  -- "api_direct" | "video_fetch" | "user_comment"
);

-- 2. Indexes backing the API queries (newest-first listing, filter by video)
CREATE INDEX IF NOT EXISTS predictions_created_at_idx
    ON public.predictions (created_at DESC);

CREATE INDEX IF NOT EXISTS predictions_video_id_idx
    ON public.predictions (video_id);

-- 3. Row Level Security: the anon role may insert and select every row
--    (the publishable key is used from the frontend / backend with no auth)
ALTER TABLE public.predictions ENABLE ROW LEVEL SECURITY;

DROP POLICY IF EXISTS "anon_insert" ON public.predictions;
CREATE POLICY "anon_insert"
    ON public.predictions
    FOR INSERT
    TO anon
    WITH CHECK (true);

DROP POLICY IF EXISTS "anon_select" ON public.predictions;
CREATE POLICY "anon_select"
    ON public.predictions
    FOR SELECT
    TO anon
    USING (true);

-- 4. Optional sanity check (run separately)
-- select count(*) from public.predictions;