Spaces:
Running
Running
| from __future__ import annotations | |
| from uuid import UUID | |
| from fastapi import APIRouter, Depends, HTTPException | |
| from pydantic import BaseModel | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from ..db.models import FeedbackRow, PredictionRow | |
| from ..db.session import get_session | |
| router = APIRouter() | |
| class FeedbackIn(BaseModel): | |
| prediction_id: UUID | |
| is_correct: bool | None = None | |
| user_label: str | None = None | |
| comment: str | None = None | |
| async def post_feedback(payload: FeedbackIn, session: AsyncSession = Depends(get_session)) -> dict: | |
| pred = ( | |
| await session.execute(select(PredictionRow).where(PredictionRow.id == payload.prediction_id)) | |
| ).scalar_one_or_none() | |
| if pred is None: | |
| raise HTTPException(status_code=404, detail="prediction not found") | |
| row = FeedbackRow( | |
| prediction_id=payload.prediction_id, | |
| is_correct=payload.is_correct, | |
| user_label=payload.user_label, | |
| comment=payload.comment, | |
| ) | |
| session.add(row) | |
| await session.commit() | |
| return {"ok": True, "feedback_id": row.id} | |
| async def stats(session: AsyncSession = Depends(get_session)) -> dict: | |
| total = (await session.execute(select(FeedbackRow))).scalars().all() | |
| n = len(total) | |
| n_correct = sum(1 for r in total if r.is_correct is True) | |
| n_wrong = sum(1 for r in total if r.is_correct is False) | |
| return { | |
| "n_total": n, | |
| "n_correct": n_correct, | |
| "n_wrong": n_wrong, | |
| "accuracy_on_user_uploads": (n_correct / max(1, n_correct + n_wrong)), | |
| } | |