"""Prediction endpoints: single text, batch, video comments, and stored-prediction listing."""

import time

from fastapi import APIRouter, HTTPException, Query

from src.api.schemas import (
    BatchPredictRequest,
    BatchPredictResponse,
    PredictRequest,
    PredictResponse,
    VideoRequest,
    VideoResponse,
)
from src.api.services import predict_single, to_predict_response
from src.api.state import get_state
from src.api.youtube import CommentsFetchError, extract_video_id, fetch_comments
from src.db.supabase_client import list_predictions, save_prediction

# NOTE(review): the handlers below are ``async def`` but call helpers
# (fetch_comments, save_prediction, model inference) directly without
# ``await``. If those helpers are synchronous, each call blocks the event
# loop. Consider plain ``def`` handlers or offloading to a thread.
# Confirm before changing.

router = APIRouter(tags=["Prediction"])


def _elapsed_ms(start: float) -> float:
    """Return milliseconds elapsed since ``start`` (a ``time.perf_counter()`` value), rounded to 2 dp."""
    return round((time.perf_counter() - start) * 1000, 2)


@router.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """Score a single text and persist the result.

    Args:
        request: Body carrying the ``text`` to score and the ``threshold``
            forwarded to ``predict_single``.

    Returns:
        The ``PredictResponse`` produced by ``predict_single``. It is also
        stored via ``save_prediction`` with ``source="api_direct"``.
    """
    response = predict_single(request.text, request.threshold)
    save_prediction(
        text=request.text,
        result=response,
        source="api_direct",
        threshold=request.threshold,
        latency_ms=response.latency_ms,
    )
    return response


@router.post("/predict-batch", response_model=BatchPredictResponse)
async def predict_batch(request: BatchPredictRequest):
    """Score several texts, persisting each prediction.

    Entries that are empty or whitespace-only are skipped. Each remaining
    text is stripped, scored with ``predict_single``, and saved with
    ``source="api_direct"``.

    Args:
        request: Body carrying ``texts`` and the shared ``threshold``.

    Returns:
        A ``BatchPredictResponse`` containing:
        - the per-text results;
        - the number of texts actually scored;
        - how many were flagged toxic;
        - the wall-clock time of the whole loop in ms.
    """
    t0 = time.perf_counter()
    results: list[PredictResponse] = []
    for text in request.texts:
        cleaned = text.strip()
        if not cleaned:
            continue
        single = predict_single(cleaned, request.threshold)
        results.append(single)
        save_prediction(
            text=cleaned,
            result=single,
            source="api_direct",
            threshold=request.threshold,
            latency_ms=single.latency_ms,
        )
    total_ms = _elapsed_ms(t0)
    toxic_count = sum(1 for r in results if r.is_toxic)
    return BatchPredictResponse(
        results=results,
        total=len(results),
        toxic_count=toxic_count,
        latency_ms=total_ms,
    )


@router.post("/predict-video", response_model=VideoResponse)
async def predict_video(request: VideoRequest):
    """Fetch a video's comments, score each one, and persist the predictions.

    Args:
        request: Body carrying the video ``url``, ``max_comments`` for the
            fetch, and the ``threshold`` used to build each response.

    Returns:
        A ``VideoResponse`` with the per-comment results, the toxic count and
        rate, and the ``source`` reported by ``fetch_comments``.
        ``total_fetched`` counts only the non-blank comments that were scored.

    Raises:
        HTTPException: With status
        - 503 if no model service is loaded;
        - 422 if fetching comments fails;
        - 404 if the fetch returns no comments.
    """
    # Check the model before fetching, so a missing model does not cost a
    # remote comment fetch. ``.get`` turns an absent key into a clean 503
    # instead of a KeyError (500).
    state = get_state()
    service = state.get("service")
    if service is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        comments, source = fetch_comments(request.url, request.max_comments)
    except CommentsFetchError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except Exception as exc:
        # Boundary handler: any other fetch failure is reported to the client
        # rather than surfacing as a 500.
        raise HTTPException(status_code=422, detail=f"Failed to fetch comments: {exc}") from exc

    if not comments:
        raise HTTPException(status_code=404, detail="No comments found for this video")

    video_id = extract_video_id(request.url)
    t0 = time.perf_counter()
    results: list[PredictResponse] = []
    for text in comments:
        # Strip like predict_batch does, so stored text and model input agree.
        cleaned = text.strip()
        if not cleaned:
            continue
        item_t0 = time.perf_counter()
        raw = service.predict(cleaned)
        # Record the real per-comment inference time; it was previously hard-coded to 0.0.
        response = to_predict_response(cleaned, raw, _elapsed_ms(item_t0), request.threshold)
        results.append(response)
        save_prediction(
            text=cleaned,
            result=response,
            source="video_fetch",
            video_id=video_id,
            video_url=request.url,
            threshold=request.threshold,
            latency_ms=response.latency_ms,
        )
    total_ms = _elapsed_ms(t0)  # whole-loop timing, kept as in the original (not returned)
    toxic_count = sum(1 for r in results if r.is_toxic)
    state["predictions_served"] = state.get("predictions_served", 0) + len(results)
    return VideoResponse(
        video_url=request.url,
        total_fetched=len(results),
        toxic_count=toxic_count,
        toxic_rate=round(toxic_count / len(results), 4) if results else 0.0,
        results=results,
        source=source,
    )


@router.get("/predictions")
async def get_predictions(
    video_id: str | None = Query(default=None),
    limit: int = Query(default=50, ge=1, le=200),
):
    """Return stored prediction rows from ``list_predictions``.

    Args:
        video_id: Optional filter, passed through unchanged.
        limit: Row cap, validated by FastAPI to 1–200 (default 50).

    Returns:
        Whatever ``list_predictions`` returns, unmodified.
    """
    return list_predictions(video_id=video_id, limit=limit)