Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from pathlib import Path | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from starlette.concurrency import run_in_threadpool | |
| from src.ai_image_detector.inference import ( | |
| CalibrationConfig, | |
| PredictionResult, | |
| load_trained_model, | |
| predict_image_bytes, | |
| ) | |
# Paths are anchored to this file, not the current working directory.
BASE_DIR = Path(__file__).resolve().parent
STATIC_DIR = BASE_DIR.joinpath("static")

# Inference presets selectable through the ``mode`` form field. Each preset
# holds the CalibrationConfig and orientation_conservative flag that
# run_prediction forwards to predict_image_bytes.
# NOTE(review): presumably ``threshold`` is the AI-probability cut-off and
# [uncertain_low, uncertain_high] an "uncertain" band — confirm against
# CalibrationConfig.
MODE_CONFIGS = dict(
    default=dict(
        calibration=CalibrationConfig(
            threshold=0.65,
            uncertain_low=0.45,
            uncertain_high=0.70,
        ),
        orientation_conservative=True,
    ),
    sensitive=dict(
        calibration=CalibrationConfig(
            threshold=0.40,
            uncertain_low=0.30,
            uncertain_high=0.50,
        ),
        orientation_conservative=False,
    ),
)

app = FastAPI(title="SENTINEL_AI")
# Frontend assets (index.html is read from this same directory) are exposed under /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
def cache_model() -> None:
    """Load the trained model once and keep it on ``app.state.model``.

    run_prediction reads ``app.state.model`` for every request.
    NOTE(review): no startup hook/decorator is visible in this view; presumably
    this is registered as a startup handler — confirm, otherwise the model is
    never loaded.
    """
    loaded_model = load_trained_model()
    app.state.model = loaded_model
async def health() -> dict[str, str]:
    """Return a static status payload.

    This does not check whether the model has been loaded.
    """
    status = "ok"
    return dict(status=status)
def get_mode_settings(mode: str) -> dict:
    """Return the MODE_CONFIGS entry for ``mode``.

    Raises:
        HTTPException: 400 when ``mode`` has no entry in MODE_CONFIGS.
    """
    if (preset := MODE_CONFIGS.get(mode)) is None:
        raise HTTPException(status_code=400, detail=f"Unsupported mode: {mode}")
    return preset
def serialize_prediction(result: PredictionResult) -> dict[str, float | str]:
    """Flatten a PredictionResult into a JSON-ready dict.

    The two probability fields are passed through ``float()``, presumably so
    numpy/torch scalar types serialize as plain floats — confirm against
    predict_image_bytes.
    """
    label = result.label
    ai_prob = float(result.ai_probability)
    conf = float(result.confidence)
    return {"label": label, "ai_probability": ai_prob, "confidence": conf}
async def run_prediction(upload: UploadFile, mode: str) -> dict[str, float | str]:
    """Read one uploaded image, score it with the cached model, and serialize it.

    Args:
        upload: The uploaded file; its whole body is read into memory.
        mode: Key into MODE_CONFIGS selecting the calibration and
            orientation settings passed to predict_image_bytes.

    Returns:
        The dict built by serialize_prediction (label, ai_probability, confidence).

    Raises:
        HTTPException: 400 for an unknown mode, an empty upload, or any
            exception raised while predict_image_bytes processes the payload;
            503 when no model is stored on ``app.state.model``.
    """
    # Cheap validation first, so a request that will be rejected anyway does
    # not buffer the upload.
    settings = get_mode_settings(mode)

    # Resolve the model outside the broad try/except below. Otherwise a missing
    # model (AttributeError) would be swallowed and misreported to the client
    # as "Unable to process ... as an image".
    model = getattr(app.state, "model", None)
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    payload = await upload.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # predict_image_bytes is called synchronously, so run it in a worker
        # thread to keep the event loop free.
        result = await run_in_threadpool(
            predict_image_bytes,
            model,
            payload,
            settings["calibration"],
            settings["orientation_conservative"],
        )
    except Exception as exc:  # noqa: BLE001
        # Any decode/inference failure is reported as an unprocessable upload.
        # NOTE(review): this also maps genuine server-side inference errors to 400.
        raise HTTPException(
            status_code=400,
            detail=f"Unable to process '{upload.filename or 'upload'}' as an image.",
        ) from exc
    return serialize_prediction(result)
async def serve_index() -> FileResponse:
    """Serve ``index.html`` from the static directory."""
    index_path = STATIC_DIR.joinpath("index.html")
    return FileResponse(index_path)
async def predict(
    file: UploadFile = File(...),
    mode: str = Form("default"),
) -> dict[str, float | str]:
    """Classify a single uploaded image.

    Form fields: ``file`` (required) and ``mode`` (``"default"`` when omitted).
    All validation and inference is delegated to run_prediction.
    """
    prediction = await run_prediction(file, mode)
    return prediction
async def predict_batch(
    files: list[UploadFile] = File(...),
    mode: str = Form("default"),
) -> list[dict[str, float | str]]:
    """Classify several uploaded images with the same mode, one after another.

    Each result is run_prediction's dict plus a ``filename`` key. The first
    file that fails raises its HTTPException, so no partial batch is returned.

    Raises:
        HTTPException: 400 when no files were sent, or propagated from
            run_prediction.
    """
    if not files:
        raise HTTPException(status_code=400, detail="Upload at least one image.")

    async def _score(item: UploadFile) -> dict[str, float | str]:
        # Attach the client-supplied name (or a placeholder) to the prediction.
        prediction = await run_prediction(item, mode)
        return {**prediction, "filename": item.filename or "upload"}

    # Awaited sequentially, in upload order.
    return [await _score(item) for item in files]