"""
Prediction routes.
"""
import base64
from typing import Optional
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from app.core.errors import (
DeepFakeDetectorError,
ImageProcessingError,
InferenceError,
FusionError,
ModelNotFoundError,
ModelNotLoadedError
)
from app.core.logging import get_logger
from app.schemas.predict import (
PredictResponse,
PredictionResult,
TimingInfo,
ErrorResponse,
FusionMeta,
ModelDisplayInfo,
ExplainModelResponse,
SingleModelInsight
)
from app.services.inference_service import get_inference_service
from app.services.fusion_service import get_fusion_service
from app.services.preprocess_service import get_preprocess_service
from app.services.model_registry import get_model_registry
from app.services.llm_service import get_llm_service, get_model_display_info
from app.utils.timing import Timer

logger = get_logger(__name__)
router = APIRouter(tags=["predict"])


@router.post(
"/predict",
response_model=PredictResponse,
summary="Predict if image is real or fake",
description="Upload an image to get a deepfake detection prediction",
    responses={
        400: {"model": ErrorResponse, "description": "Invalid image or request"},
        404: {"model": ErrorResponse, "description": "Model not found"},
        500: {"model": ErrorResponse, "description": "Inference or fusion error"},
        503: {"model": ErrorResponse, "description": "Models not loaded"}
    }
)
async def predict(
image: UploadFile = File(..., description="Image file to analyze"),
use_fusion: bool = Query(
True,
description="Use fusion model (majority vote) across all submodels"
),
model: Optional[str] = Query(
None,
description="Specific submodel to use (name or repo_id). Only used when use_fusion=false"
),
    return_submodels: Optional[bool] = Query(
        None,
        description="Include individual submodel predictions in the response. Defaults to the value of use_fusion"
    ),
explain: bool = Query(
True,
description="Generate explainability heatmaps (Grad-CAM for CNNs, attention rollout for transformers)"
)
) -> PredictResponse:
"""
Predict if an uploaded image is real or fake.
When use_fusion=true (default):
- Runs all submodels on the image
- Combines predictions using majority vote fusion
- Returns the fused result plus optionally individual submodel results
When use_fusion=false:
- Runs only the specified submodel (or the first available if not specified)
- Returns just that model's prediction
Response includes timing information for each step.
"""
timer = Timer()
timer.start_total()
# Determine if we should return submodel results
should_return_submodels = return_submodels if return_submodels is not None else use_fusion
try:
# Read image bytes
with timer.measure("download"):
image_bytes = await image.read()
# Validate and preprocess
with timer.measure("preprocess"):
preprocess_service = get_preprocess_service()
preprocess_service.validate_image(image_bytes)
inference_service = get_inference_service()
fusion_service = get_fusion_service()
registry = get_model_registry()
if use_fusion:
# Run all submodels
with timer.measure("inference"):
submodel_outputs = inference_service.predict_all_submodels(
image_bytes=image_bytes,
explain=explain
)
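            # Each value in submodel_outputs is expected to carry at least
            # "pred", "pred_int", and "prob_fake", plus optional explainability
            # fields ("heatmap_base64", "explainability_type", "focus_summary");
            # this shape is inferred from how the outputs are read below.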
# Run fusion
with timer.measure("fusion"):
final_result = fusion_service.fuse(submodel_outputs=submodel_outputs)
timer.stop_total()
# Extract fusion meta (contribution percentages)
fusion_meta_dict = final_result.get("meta", {})
contribution_percentages = fusion_meta_dict.get("contribution_percentages", {})
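            # contribution_percentages is assumed to map each submodel name to
            # its share of the fused decision (as produced by fusion_service.fuse());
            # it is echoed back per submodel in the response below.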
# Build fusion meta object
fusion_meta = FusionMeta(
submodel_weights=fusion_meta_dict.get("submodel_weights", {}),
weighted_contributions=fusion_meta_dict.get("weighted_contributions", {}),
contribution_percentages=contribution_percentages
) if fusion_meta_dict else None
# Build model display info for frontend
model_display_info = {
name: ModelDisplayInfo(**get_model_display_info(name))
for name in submodel_outputs.keys()
}
# Build response
return PredictResponse(
final=PredictionResult(
pred=final_result["pred"],
pred_int=final_result["pred_int"],
prob_fake=final_result["prob_fake"]
),
fusion_used=True,
submodels={
name: PredictionResult(
pred=output["pred"],
pred_int=output["pred_int"],
prob_fake=output["prob_fake"],
heatmap_base64=output.get("heatmap_base64"),
explainability_type=output.get("explainability_type"),
focus_summary=output.get("focus_summary"),
contribution_percentage=contribution_percentages.get(name)
)
for name, output in submodel_outputs.items()
} if should_return_submodels else None,
fusion_meta=fusion_meta,
model_display_info=model_display_info if should_return_submodels else None,
timing_ms=TimingInfo(**timer.get_timings())
)
else:
# Single model prediction
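            # Fall back to the first registered submodel when none is specified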
model_key = model or registry.get_submodel_names()[0]
with timer.measure("inference"):
result = inference_service.predict_single(
model_key=model_key,
image_bytes=image_bytes,
explain=explain
)
timer.stop_total()
return PredictResponse(
final=PredictionResult(
pred=result["pred"],
pred_int=result["pred_int"],
prob_fake=result["prob_fake"],
heatmap_base64=result.get("heatmap_base64"),
explainability_type=result.get("explainability_type"),
focus_summary=result.get("focus_summary")
),
fusion_used=False,
submodels=None,
timing_ms=TimingInfo(**timer.get_timings())
)
except ImageProcessingError as e:
logger.warning(f"Image processing error: {e.message}")
raise HTTPException(
status_code=400,
detail={"error": "ImageProcessingError", "message": e.message, "details": e.details}
)
except ModelNotFoundError as e:
logger.warning(f"Model not found: {e.message}")
raise HTTPException(
status_code=404,
detail={"error": "ModelNotFoundError", "message": e.message, "details": e.details}
)
except ModelNotLoadedError as e:
logger.error(f"Models not loaded: {e.message}")
raise HTTPException(
status_code=503,
detail={"error": "ModelNotLoadedError", "message": e.message, "details": e.details}
)
except (InferenceError, FusionError) as e:
logger.error(f"Inference/Fusion error: {e.message}")
raise HTTPException(
status_code=500,
detail={"error": type(e).__name__, "message": e.message, "details": e.details}
)
except Exception as e:
logger.exception(f"Unexpected error in predict endpoint: {e}")
raise HTTPException(
status_code=500,
detail={"error": "InternalError", "message": str(e)}
)
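

# Usage sketch for /predict (assumptions, not part of this module: the app is
# served at http://localhost:8000 with no route prefix, "face.jpg" exists, and
# the requests package is installed):
#
#   import requests
#
#   with open("face.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/predict",
#           params={"use_fusion": "true", "explain": "true"},
#           files={"image": ("face.jpg", f, "image/jpeg")},
#       )
#   resp.raise_for_status()
#   body = resp.json()
#   print(body["final"])       # fused verdict: pred, pred_int, prob_fake
#   print(body["timing_ms"])   # per-step timings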
@router.post("/explain-model", response_model=ExplainModelResponse)
async def explain_model(
    image: UploadFile = File(...),
    model_name: str = Form(...),
    prob_fake: float = Form(...),
    contribution_percentage: Optional[float] = Form(None),
    heatmap_base64: Optional[str] = Form(None),
    focus_summary: Optional[str] = Form(None)
) -> ExplainModelResponse:
"""
Generate an on-demand LLM explanation for a single model's prediction.
This endpoint is token-efficient - only called when user requests insights.
"""
try:
# Read and validate image
image_bytes = await image.read()
if len(image_bytes) == 0:
raise HTTPException(status_code=400, detail="Empty image file")
# Encode image to base64 for LLM
original_b64 = base64.b64encode(image_bytes).decode('utf-8')
# Get LLM service
llm_service = get_llm_service()
if not llm_service.enabled:
raise HTTPException(
status_code=503,
detail="LLM service is not enabled. Set GEMINI_API_KEY environment variable."
)
# Generate explanation
result = llm_service.generate_single_model_explanation(
model_name=model_name,
original_image_b64=original_b64,
prob_fake=prob_fake,
heatmap_b64=heatmap_base64,
contribution_percentage=contribution_percentage,
focus_summary=focus_summary
)
if result is None:
raise HTTPException(
status_code=500,
detail="Failed to generate explanation from LLM"
)
return ExplainModelResponse(
model_name=model_name,
insight=SingleModelInsight(
key_finding=result["key_finding"],
what_model_saw=result["what_model_saw"],
important_regions=result["important_regions"],
confidence_qualifier=result["confidence_qualifier"]
)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"Error generating model explanation: {e}")
raise HTTPException(
status_code=500,
detail={"error": "ExplanationError", "message": str(e)}
)
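

# Usage sketch for /explain-model (same assumptions as the /predict sketch;
# the model name is hypothetical, and in practice prob_fake and the optional
# form fields would come from a prior /predict response):
#
#   import requests
#
#   with open("face.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/explain-model",
#           data={"model_name": "xception", "prob_fake": "0.93"},
#           files={"image": ("face.jpg", f, "image/jpeg")},
#       )
#   print(resp.json()["insight"]["key_finding"])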