File size: 9,442 Bytes
ef5ede7 7ed1dd3 ef5ede7 5046041 dbea7ff 5046041 dbea7ff 5046041 dbea7ff 5046041 dbea7ff 5046041 ef5ede7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 | from fastapi import FastAPI, APIRouter, HTTPException, UploadFile, File, Form, Depends
from dotenv import load_dotenv
from starlette.middleware.cors import CORSMiddleware
from motor.motor_asyncio import AsyncIOMotorClient
import os
import logging
import asyncio
from pathlib import Path
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional
import uuid
from datetime import datetime, timezone
from model import load_model, predict
from auth import auth_router, init_auth_db, get_current_user
# Directory containing this module; used to locate .env and the default model.
ROOT_DIR = Path(__file__).parent
# Load variables from the .env file next to this module before the
# os.environ lookups below.
load_dotenv(ROOT_DIR / '.env')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# MongoDB connection. MONGO_URL and DB_NAME are required: indexing
# os.environ directly raises KeyError at import time if either is missing.
mongo_url = os.environ['MONGO_URL']
mongo_client = AsyncIOMotorClient(mongo_url)
db = mongo_client[os.environ['DB_NAME']]
# Model checkpoint path, overridable via MODEL_PATH. The default is the
# checkpoint file in the same directory as this module (ROOT_DIR).
MODEL_PATH = os.environ.get(
    'MODEL_PATH',
    str(ROOT_DIR / 'best_deepfake_model_tensor_finetuned.pt'),
)
# ---------- Lifespan (load model once) ----------
@asynccontextmanager
async def lifespan(application: FastAPI):
    """Load the ML model at startup and release resources at shutdown.

    Startup shares the module-level MongoDB handle with the auth module and
    loads the detector once, storing it on ``application.state`` for the
    request handlers. Teardown sits in a ``finally`` block, so the Mongo
    client is closed even if startup fails or an exception is thrown into
    the context manager while the app is running.

    Args:
        application: The FastAPI app whose ``state`` receives the model and
            feature extractor.
    """
    try:
        # Share DB with auth module
        init_auth_db(db)
        logger.info("Loading deepfake detection model …")
        # Blocking load; acceptable because it happens once, during startup.
        model, feature_extractor = load_model(MODEL_PATH, device="cpu")
        application.state.model = model
        application.state.feature_extractor = feature_extractor
        logger.info("Model ready.")
        yield
    finally:
        mongo_client.close()
        logger.info("Shutdown complete.")
# The ASGI application; `lifespan` loads the model and wires up the DB.
app = FastAPI(lifespan=lifespan, title="SADA API")
# Every JSON endpoint declared on this router is served under /api.
api_router = APIRouter(prefix="/api")
# ---------- Models ----------
class DetectionRequest(BaseModel):
    """Metadata describing an audio sample submitted for detection.

    NOTE(review): no route in this file references this model (``/detect``
    takes multipart form fields instead); confirm it is used elsewhere
    before relying on or removing it.
    """
    filename: str
    duration_seconds: float = 0.0
    source: str = "upload"  # "upload" | "record"
    size_bytes: int = 0
    mime_type: Optional[str] = None
class DetectionResult(BaseModel):
    """One detection record: the upload's metadata plus the model's verdict.

    Returned by the /detect and /history endpoints. It is stored in the
    ``detections`` collection after ``_serialize`` converts ``created_at``
    to an ISO string.
    """
    # Unknown keys are ignored rather than rejected when building from
    # stored documents.
    model_config = ConfigDict(extra="ignore")
    # Random UUID4 string; the public identifier used by /history/{id}.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Owner's user id; the history and stats queries filter on it.
    user_id: Optional[str] = None
    filename: str
    duration_seconds: float = 0.0
    source: str = "upload"
    size_bytes: int = 0
    mime_type: Optional[str] = None
    label: str  # "ai" | "human"
    confidence: float  # 0..100
    breakdown: dict  # {"ai": float, "human": float, "noise": float}
    # NOTE(review): this default predates the real model; /detect always
    # sets model_used explicitly, so it only applies to other constructors.
    model_used: str = "SADA-Mock-v1"
    # Timezone-aware creation time in UTC.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class StatsResponse(BaseModel):
    """Per-user detection summary returned by GET /api/stats."""
    total: int
    ai_count: int
    human_count: int
    ai_ratio: float  # percentage of total (0..100); 0.0 when total is 0
    human_ratio: float  # percentage of total (0..100); 0.0 when total is 0
    avg_confidence: float  # mean of the stored `confidence` values
    # Seven {"date": "YYYY-MM-DD", "ai": int, "human": int} entries,
    # oldest day first and ending with today (UTC).
    last_7_days: List[dict]
# ---------- Helpers ----------
def _serialize(doc: dict) -> dict:
if isinstance(doc.get("created_at"), datetime):
doc["created_at"] = doc["created_at"].isoformat()
return doc
def _deserialize(doc: dict) -> dict:
if isinstance(doc.get("created_at"), str):
try:
doc["created_at"] = datetime.fromisoformat(doc["created_at"])
except Exception:
pass
return doc
# (_mock_detect removed – using real model inference)
# ---------- Routes ----------
@api_router.get("/")
async def root():
    """Liveness probe: identify the service and report that it is up."""
    return dict(service="SADA", status="ok")
@api_router.post("/detect", response_model=DetectionResult)
async def detect_audio(
    file: UploadFile = File(...),
    duration_seconds: float = Form(0.0),
    source: str = Form("upload"),
    current_user: dict = Depends(get_current_user),
):
    """Run deepfake detection on an uploaded clip and store the result.

    Args:
        file: Uploaded audio (multipart form field ``file``).
        duration_seconds: Client-reported duration; used only when the
            model result has no ``duration_seconds`` of its own.
        source: Client-reported origin of the clip ("upload" | "record");
            stored as given.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        The DetectionResult that was inserted into ``detections``.

    Raises:
        HTTPException: 400 for an empty file or when ``predict`` raises
            ValueError; 500 for any other inference failure.
    """
    # NOTE(review): the entire upload is buffered in memory with no size cap;
    # consider enforcing a maximum size for this untrusted input.
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty audio file")

    # Run real inference in a thread pool to avoid blocking the event loop.
    try:
        result = await asyncio.to_thread(
            predict,
            audio_bytes,
            app.state.model,
            app.state.feature_extractor,
            "cpu",
        )
    except ValueError as exc:
        # A ValueError from predict is treated as bad input -> client error.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Inference failed")
        raise HTTPException(status_code=500, detail="Inference error") from exc

    obj = DetectionResult(
        user_id=current_user["id"],
        filename=file.filename or "unknown",
        duration_seconds=result.get("duration_seconds", duration_seconds),
        source=source,
        size_bytes=len(audio_bytes),
        mime_type=file.content_type,
        label=result["label"],
        confidence=result["confidence"],
        breakdown=result["breakdown"],
        model_used="SADA-Wav2Vec2-v1",
    )
    # Store a serialized copy; insert_one adds "_id" to that copy, not to obj.
    await db.detections.insert_one(_serialize(obj.model_dump()))
    return obj
@api_router.get("/history", response_model=List[DetectionResult])
async def get_history(
    limit: int = 50,
    label: Optional[str] = None,
    current_user: dict = Depends(get_current_user),
):
    """List the caller's detections, newest first.

    Args:
        limit: Maximum number of records to return; must be at least 1.
        label: Optional filter. Only "ai" or "human" are honoured; any other
            value is ignored and all labels are returned.
        current_user: Authenticated user injected by ``get_current_user``.

    Raises:
        HTTPException: 400 if ``limit`` is less than 1.
    """
    # limit(0) means "no limit" to MongoDB and negative values are not a
    # meaningful page size, so reject them up front.
    if limit < 1:
        raise HTTPException(status_code=400, detail="limit must be at least 1")
    query = {"user_id": current_user["id"]}
    if label in {"ai", "human"}:
        query["label"] = label
    # created_at is stored as a UTC ISO-8601 string (via _serialize), so the
    # string sort is also chronological.
    cursor = db.detections.find(query, {"_id": 0}).sort("created_at", -1).limit(limit)
    items = await cursor.to_list(length=limit)
    return [DetectionResult(**_deserialize(item)) for item in items]
@api_router.get("/history/{detection_id}", response_model=DetectionResult)
async def get_detection(
    detection_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Fetch one detection owned by the caller, or respond 404.

    Filtering on both ``id`` and ``user_id`` means another user's record
    is reported as not found.
    """
    owner_filter = {"id": detection_id, "user_id": current_user["id"]}
    found = await db.detections.find_one(owner_filter, {"_id": 0})
    if found:
        return DetectionResult(**_deserialize(found))
    raise HTTPException(status_code=404, detail="Detection not found")
@api_router.delete("/history/{detection_id}")
async def delete_detection(
    detection_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Delete one of the caller's detections; 404 if nothing matched."""
    owner_filter = {"id": detection_id, "user_id": current_user["id"]}
    outcome = await db.detections.delete_one(owner_filter)
    if outcome.deleted_count:
        return {"deleted": True, "id": detection_id}
    raise HTTPException(status_code=404, detail="Detection not found")
@api_router.delete("/history")
async def clear_history(current_user: dict = Depends(get_current_user)):
    """Remove every detection belonging to the caller and report the count."""
    owner_id = current_user["id"]
    outcome = await db.detections.delete_many({"user_id": owner_id})
    return {"deleted": outcome.deleted_count}
@api_router.get("/stats", response_model=StatsResponse)
async def get_stats(current_user: dict = Depends(get_current_user)):
    """Summarise the caller's detections.

    Computes the total count, AI/human counts and percentages, the mean
    stored confidence, and AI/human counts for each of the last seven UTC
    calendar days (oldest first, today last).

    Args:
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        StatsResponse with the aggregates above. The ratios and the average
        are 0.0 when the user has no detections.
    """
    from datetime import timedelta

    today = datetime.now(timezone.utc).date()
    # Ordered oldest -> today; this insertion order becomes last_7_days.
    per_day = {
        today - timedelta(days=offset): {"ai": 0, "human": 0}
        for offset in range(6, -1, -1)
    }

    total = ai_count = human_count = 0
    confidence_sum = 0.0

    # Stream only the needed fields instead of materialising a fixed-length
    # list, which would silently truncate stats for users with many records.
    cursor = db.detections.find(
        {"user_id": current_user["id"]},
        {"_id": 0, "label": 1, "confidence": 1, "created_at": 1},
    )
    async for item in cursor:
        total += 1
        label = item.get("label")
        if label == "ai":
            ai_count += 1
        elif label == "human":
            human_count += 1
        # A missing or null confidence counts as 0 instead of crashing float().
        confidence_sum += float(item.get("confidence") or 0)

        created = _deserialize(item).get("created_at")
        # Records with an unknown label or unparseable timestamp are left out
        # of the daily buckets (unknown labels previously raised KeyError).
        if label not in ("ai", "human") or not isinstance(created, datetime):
            continue
        if created.tzinfo is not None:
            # Bucket by UTC date regardless of the stored offset.
            created = created.astimezone(timezone.utc)
        counts = per_day.get(created.date())
        if counts is not None:
            counts[label] += 1

    avg_conf = confidence_sum / total if total else 0.0
    last_7 = [
        {"date": day.isoformat(), "ai": counts["ai"], "human": counts["human"]}
        for day, counts in per_day.items()
    ]
    return StatsResponse(
        total=total,
        ai_count=ai_count,
        human_count=human_count,
        ai_ratio=round((ai_count / total) * 100, 2) if total else 0.0,
        human_ratio=round((human_count / total) * 100, 2) if total else 0.0,
        avg_confidence=round(avg_conf, 2),
        last_7_days=last_7,
    )
@api_router.get("/global-stats")
async def get_global_stats():
    """Public summary across all users' detections (no authentication).

    Returns:
        A dict with the total number of detections, AI/human percentages
        (rounded to one decimal), a fixed headline accuracy figure, and the
        labels of the most recent detections (newest first) for the live
        waveform visual.
    """
    # Number of recent labels for the live waveform visual.
    recent_window = 56
    # Hardcoded global accuracy representing the SADA model (not computed).
    avg_accuracy = 79.8

    # Let MongoDB count and select the newest rows instead of pulling up to
    # 100k documents into memory. This endpoint is unauthenticated, so it
    # must stay cheap, and the totals are no longer capped.
    detections = db.detections
    total_found, ai_count, human_count, recent_docs = await asyncio.gather(
        detections.count_documents({}),
        detections.count_documents({"label": "ai"}),
        detections.count_documents({"label": "human"}),
        detections.find({}, {"_id": 0, "label": 1})
        .sort("created_at", -1)
        .limit(recent_window)
        .to_list(length=recent_window),
    )

    if total_found == 0:
        return {
            "total": total_found,
            "ai_ratio": 0.0,
            "human_ratio": 0.0,
            "avg_accuracy": avg_accuracy,
            "recent_labels": [],
        }
    return {
        "total": total_found,
        "ai_ratio": round((ai_count / total_found) * 100, 1),
        "human_ratio": round((human_count / total_found) * 100, 1),
        "avg_accuracy": avg_accuracy,
        "recent_labels": [doc.get("label", "human") for doc in recent_docs],
    }
app.include_router(api_router)
app.include_router(auth_router)

# CORS_ORIGINS is a comma-separated allow-list. Whitespace around entries is
# stripped and empty entries (e.g. from a trailing comma) are dropped; an
# unstripped " https://x" would otherwise never match a browser Origin.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get('CORS_ORIGINS', '*').split(',')
    if origin.strip()
]
# NOTE(review): with the "*" default and allow_credentials=True, Starlette
# reflects the requesting Origin, so any site can make credentialed calls.
# Set CORS_ORIGINS explicitly in deployed environments.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|