Spaces:
Sleeping
Sleeping
File size: 3,186 Bytes
ff0c419 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 | from __future__ import annotations
from pathlib import Path
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from starlette.concurrency import run_in_threadpool
from src.ai_image_detector.inference import (
CalibrationConfig,
PredictionResult,
load_trained_model,
predict_image_bytes,
)
# Directory holding this module; static assets live in its "static" subfolder.
BASE_DIR = Path(__file__).resolve().parent
STATIC_DIR = BASE_DIR.joinpath("static")

# Inference presets, selected per request through the ``mode`` form field.
# Each preset pairs a CalibrationConfig with the ``orientation_conservative`` flag;
# both are forwarded to predict_image_bytes by the request handlers.
# NOTE(review): uncertain_low/uncertain_high presumably bound an "uncertain"
# probability band around the threshold — confirm against CalibrationConfig.
MODE_CONFIGS = dict(
    default=dict(
        calibration=CalibrationConfig(
            threshold=0.65,
            uncertain_low=0.45,
            uncertain_high=0.70,
        ),
        orientation_conservative=True,
    ),
    sensitive=dict(
        calibration=CalibrationConfig(
            threshold=0.40,
            uncertain_low=0.30,
            uncertain_high=0.50,
        ),
        orientation_conservative=False,
    ),
)
# ASGI application; files under STATIC_DIR are exposed beneath the /static prefix.
app = FastAPI(title="SENTINEL_AI")
app.mount(path="/static", app=StaticFiles(directory=STATIC_DIR), name="static")
@app.on_event("startup")
def cache_model() -> None:
    """Load the trained model once at startup and store it on ``app.state.model``.

    Request handlers read the cached instance from ``app.state`` instead of
    loading the model on every request.
    """
    model = load_trained_model()
    app.state.model = model
@app.get("/health")
async def health() -> dict[str, str]:
    """Report that the service is responding.

    No model or dependency checks are performed; the payload is always
    ``{"status": "ok"}``.
    """
    status = "ok"
    return {"status": status}
def get_mode_settings(mode: str) -> dict:
    """Return the preset registered in ``MODE_CONFIGS`` under *mode*.

    The returned dict is the shared preset object itself, not a copy.

    Raises:
        HTTPException: 400 when *mode* is not a known preset name.
    """
    if (preset := MODE_CONFIGS.get(mode)) is not None:
        return preset
    raise HTTPException(status_code=400, detail=f"Unsupported mode: {mode}")
def serialize_prediction(result: PredictionResult) -> dict[str, float | str]:
    """Convert a PredictionResult into a JSON-friendly dict.

    The label is passed through unchanged. Both numeric fields are coerced with
    ``float()`` so that non-builtin numeric types become plain floats.
    """
    label = result.label
    probability = float(result.ai_probability)
    confidence = float(result.confidence)
    return {
        "label": label,
        "ai_probability": probability,
        "confidence": confidence,
    }
async def run_prediction(upload: UploadFile, mode: str) -> dict[str, float | str]:
    """Run the cached model on one uploaded image and return the serialized result.

    Args:
        upload: The uploaded file. Its full contents are read into memory.
        mode: Key into ``MODE_CONFIGS`` selecting the calibration preset and the
            ``orientation_conservative`` flag passed to ``predict_image_bytes``.

    Returns:
        The dict built by ``serialize_prediction`` (label, ai_probability, confidence).

    Raises:
        HTTPException: 400 for an unsupported mode, an empty upload, or a payload
            the predictor fails on; 503 if no model is cached on ``app.state``.
    """
    import logging

    # Validate the mode first so a bad request is rejected before buffering the upload.
    settings = get_mode_settings(mode)
    name = upload.filename or "upload"

    # NOTE(review): the whole upload is buffered with no size cap; consider a limit.
    payload = await upload.read()
    if not payload:
        # Name the file so batch callers can tell which upload was empty.
        raise HTTPException(status_code=400, detail=f"Uploaded file '{name}' is empty.")

    # Resolve the model outside the try: a missing model is a server-side fault and
    # must not be reported to the client as an unreadable image (400).
    model = getattr(app.state, "model", None)
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    try:
        # Inference is blocking, so it runs in the threadpool to keep the event loop free.
        result = await run_in_threadpool(
            predict_image_bytes,
            model,
            payload,
            settings["calibration"],
            settings["orientation_conservative"],
        )
    except Exception as exc:  # noqa: BLE001
        # predict_image_bytes' failure modes are not visible here, so any error is
        # treated as a bad image. Log it so genuine server faults stay diagnosable.
        logging.getLogger(__name__).warning(
            "Prediction failed for upload %r", name, exc_info=exc
        )
        raise HTTPException(
            status_code=400,
            detail=f"Unable to process '{name}' as an image.",
        ) from exc
    return serialize_prediction(result)
@app.get("/")
async def serve_index() -> FileResponse:
    """Serve ``index.html`` from the static directory as the site root."""
    index_path = STATIC_DIR / "index.html"
    return FileResponse(index_path)
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    mode: str = Form("default"),
) -> dict[str, float | str]:
    """Classify a single uploaded image using the preset named by *mode*.

    All validation and error reporting is delegated to ``run_prediction``.
    """
    prediction = await run_prediction(file, mode)
    return prediction
@app.post("/predict/batch")
async def predict_batch(
    files: list[UploadFile] = File(...),
    mode: str = Form("default"),
) -> list[dict[str, float | str]]:
    """Classify several uploaded images with the same preset.

    Files are processed one after another in upload order. The first file that
    fails aborts the whole batch with the HTTPException raised by
    ``run_prediction``.

    Returns:
        One serialized prediction per file, each extended with a ``filename`` key
        (``"upload"`` when the client sent no filename).

    Raises:
        HTTPException: 400 when no files were sent or *mode* is unsupported, plus
            anything ``run_prediction`` raises for an individual file.
    """
    if not files:
        raise HTTPException(status_code=400, detail="Upload at least one image.")
    # Reject an unsupported mode once, before any upload is read or inferred.
    get_mode_settings(mode)

    results: list[dict[str, float | str]] = []
    # Sequential on purpose: at most one inference per request is in flight.
    # NOTE(review): model thread-safety is not visible here; confirm before parallelizing.
    for upload in files:
        row = await run_prediction(upload, mode)
        row["filename"] = upload.filename or "upload"
        results.append(row)
    return results
|