File size: 10,501 Bytes
df4a21a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
"""
Prediction routes.
"""

import base64
from typing import Optional

from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile

from app.core.errors import (
    DeepFakeDetectorError,
    ImageProcessingError,
    InferenceError,
    FusionError,
    ModelNotFoundError,
    ModelNotLoadedError
)
from app.core.logging import get_logger
from app.schemas.predict import (
    PredictResponse,
    PredictionResult,
    TimingInfo,
    ErrorResponse,
    FusionMeta,
    ModelDisplayInfo,
    ExplainModelResponse,
    SingleModelInsight
)
from app.services.inference_service import get_inference_service
from app.services.fusion_service import get_fusion_service
from app.services.preprocess_service import get_preprocess_service
from app.services.model_registry import get_model_registry
from app.services.llm_service import get_llm_service, get_model_display_info, MODEL_DISPLAY_INFO
from app.utils.timing import Timer

# Module-level logger, obtained from the project's logging helper and named
# after this module.
logger = get_logger(__name__)
# Router that the endpoints below register themselves on; every route added
# to it is tagged "predict".
router = APIRouter(tags=["predict"])


def _domain_error_response(status_code: int, error: str, exc: Exception) -> HTTPException:
    """
    Build the HTTPException used to report a handled domain error.

    The payload has the ``{"error", "message", "details"}`` shape shared by
    all handled errors of the predict endpoint. ``exc`` must expose
    ``message`` and ``details`` attributes.
    """
    return HTTPException(
        status_code=status_code,
        detail={"error": error, "message": exc.message, "details": exc.details}
    )


def _build_fusion_response(
    submodel_outputs: dict,
    final_result: dict,
    include_submodels: bool,
    timer: Timer
) -> PredictResponse:
    """
    Assemble the PredictResponse for a fused prediction.

    Per-submodel results and display info are only built when
    ``include_submodels`` is true; otherwise both fields are None.
    """
    # "meta" may be absent or explicitly None; both mean "no fusion metadata".
    fusion_meta_dict = final_result.get("meta") or {}
    contribution_percentages = fusion_meta_dict.get("contribution_percentages", {})

    fusion_meta = None
    if fusion_meta_dict:
        fusion_meta = FusionMeta(
            submodel_weights=fusion_meta_dict.get("submodel_weights", {}),
            weighted_contributions=fusion_meta_dict.get("weighted_contributions", {}),
            contribution_percentages=contribution_percentages
        )

    submodels = None
    model_display_info = None
    if include_submodels:
        submodels = {
            name: PredictionResult(
                pred=output["pred"],
                pred_int=output["pred_int"],
                prob_fake=output["prob_fake"],
                heatmap_base64=output.get("heatmap_base64"),
                explainability_type=output.get("explainability_type"),
                focus_summary=output.get("focus_summary"),
                contribution_percentage=contribution_percentages.get(name)
            )
            for name, output in submodel_outputs.items()
        }
        # Display info is only consumed together with the submodel results.
        model_display_info = {
            name: ModelDisplayInfo(**get_model_display_info(name))
            for name in submodel_outputs
        }

    return PredictResponse(
        final=PredictionResult(
            pred=final_result["pred"],
            pred_int=final_result["pred_int"],
            prob_fake=final_result["prob_fake"]
        ),
        fusion_used=True,
        submodels=submodels,
        fusion_meta=fusion_meta,
        model_display_info=model_display_info,
        timing_ms=TimingInfo(**timer.get_timings())
    )


@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Predict if image is real or fake",
    description="Upload an image to get a deepfake detection prediction",
    responses={
        400: {"model": ErrorResponse, "description": "Invalid image or request"},
        404: {"model": ErrorResponse, "description": "Model not found"},
        500: {"model": ErrorResponse, "description": "Inference error"},
        503: {"model": ErrorResponse, "description": "Models not loaded"}
    }
)
async def predict(
    image: UploadFile = File(..., description="Image file to analyze"),
    use_fusion: bool = Query(
        True,
        description="Use fusion model (majority vote) across all submodels"
    ),
    model: Optional[str] = Query(
        None,
        description="Specific submodel to use (name or repo_id). Only used when use_fusion=false"
    ),
    return_submodels: Optional[bool] = Query(
        None,
        description="Include individual submodel predictions in response. Defaults to true when use_fusion=true"
    ),
    explain: bool = Query(
        True,
        description="Generate explainability heatmaps (Grad-CAM for CNNs, attention rollout for transformers)"
    )
) -> PredictResponse:
    """
    Predict if an uploaded image is real or fake.

    When use_fusion=true (default):
    - Runs all submodels on the image
    - Combines the submodel outputs through the fusion service
    - Returns the fused result plus optionally individual submodel results

    When use_fusion=false:
    - Runs only the specified submodel (or the first registered one if
      ``model`` is not given)
    - Returns just that model's prediction

    Args:
        image: Uploaded image file; its bytes are validated before inference.
        use_fusion: Whether to run every submodel and fuse the results.
        model: Submodel key for the single-model path; ignored when fusing.
        return_submodels: Whether to include per-submodel results in a fused
            response. When None, it follows ``use_fusion``.
        explain: Forwarded to the inference service to request heatmaps.

    Returns:
        A PredictResponse carrying the final prediction and the timings
        recorded for each measured step.

    Raises:
        HTTPException: 400 for image processing errors, 404 for an unknown
            model, 503 when models are not loaded (or no submodel is
            registered), 500 for inference, fusion or unexpected errors.
    """
    timer = Timer()
    timer.start_total()

    # An explicit return_submodels wins; otherwise follow use_fusion.
    should_return_submodels = use_fusion if return_submodels is None else return_submodels

    try:
        with timer.measure("download"):
            image_bytes = await image.read()

        with timer.measure("preprocess"):
            preprocess_service = get_preprocess_service()
            preprocess_service.validate_image(image_bytes)

        inference_service = get_inference_service()

        if use_fusion:
            # Looked up outside the timed section so "fusion" measures only fuse().
            fusion_service = get_fusion_service()

            with timer.measure("inference"):
                submodel_outputs = inference_service.predict_all_submodels(
                    image_bytes=image_bytes,
                    explain=explain
                )

            with timer.measure("fusion"):
                final_result = fusion_service.fuse(submodel_outputs=submodel_outputs)

            timer.stop_total()

            return _build_fusion_response(
                submodel_outputs=submodel_outputs,
                final_result=final_result,
                include_submodels=should_return_submodels,
                timer=timer
            )

        # Single model prediction. An empty/missing ``model`` falls back to
        # the first registered submodel.
        model_key = model
        if not model_key:
            available = get_model_registry().get_submodel_names()
            if not available:
                # Without this guard the [0] lookup fails with an opaque
                # IndexError that surfaces as a generic 500.
                raise HTTPException(
                    status_code=503,
                    detail={
                        "error": "ModelNotLoadedError",
                        "message": "No submodels are available for prediction",
                        "details": {}
                    }
                )
            model_key = available[0]

        with timer.measure("inference"):
            result = inference_service.predict_single(
                model_key=model_key,
                image_bytes=image_bytes,
                explain=explain
            )

        timer.stop_total()

        return PredictResponse(
            final=PredictionResult(
                pred=result["pred"],
                pred_int=result["pred_int"],
                prob_fake=result["prob_fake"],
                heatmap_base64=result.get("heatmap_base64"),
                explainability_type=result.get("explainability_type"),
                focus_summary=result.get("focus_summary")
            ),
            fusion_used=False,
            submodels=None,
            timing_ms=TimingInfo(**timer.get_timings())
        )

    except HTTPException:
        # Already a well-formed HTTP error; do not let the catch-all below
        # rewrap it as a 500.
        raise

    except ImageProcessingError as e:
        logger.warning(f"Image processing error: {e.message}")
        raise _domain_error_response(400, "ImageProcessingError", e) from e

    except ModelNotFoundError as e:
        logger.warning(f"Model not found: {e.message}")
        raise _domain_error_response(404, "ModelNotFoundError", e) from e

    except ModelNotLoadedError as e:
        logger.error(f"Models not loaded: {e.message}")
        raise _domain_error_response(503, "ModelNotLoadedError", e) from e

    except (InferenceError, FusionError) as e:
        logger.error(f"Inference/Fusion error: {e.message}")
        raise _domain_error_response(500, type(e).__name__, e) from e

    except Exception as e:
        logger.exception(f"Unexpected error in predict endpoint: {e}")
        # NOTE(review): str(e) is returned to the client as-is; confirm that
        # exposing internal exception text is acceptable for this API.
        raise HTTPException(
            status_code=500,
            detail={"error": "InternalError", "message": str(e)}
        ) from e


@router.post(
    "/explain-model",
    response_model=ExplainModelResponse,
    responses={
        400: {"description": "Empty image file"},
        500: {"description": "Explanation could not be generated"},
        503: {"description": "LLM service is not enabled"}
    }
)
async def explain_model(
    image: UploadFile = File(...),
    model_name: str = Form(...),
    prob_fake: float = Form(...),
    contribution_percentage: Optional[float] = Form(None),
    heatmap_base64: Optional[str] = Form(None),
    focus_summary: Optional[str] = Form(None)
) -> ExplainModelResponse:
    """
    Generate an on-demand LLM explanation for a single model's prediction.
    This endpoint is token-efficient - only called when user requests insights.
    \f
    Args:
        image: Uploaded image; sent to the LLM service base64-encoded.
        model_name: Name of the model whose prediction is being explained.
        prob_fake: The model's fake-probability, forwarded to the LLM service.
        contribution_percentage: Optional value forwarded to the LLM service.
        heatmap_base64: Optional base64 heatmap forwarded to the LLM service.
        focus_summary: Optional focus summary forwarded to the LLM service.

    Returns:
        An ExplainModelResponse wrapping the insight fields returned by the
        LLM service.

    Raises:
        HTTPException: 400 if the upload is empty, 503 if the LLM service is
            disabled, 500 if no explanation is produced or an unexpected
            error occurs.
    """
    # The form feed above keeps the generated OpenAPI description limited to
    # the summary lines; everything after it is for maintainers only.
    try:
        # Read and validate image
        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image file")

        # Check availability before doing any encoding work.
        llm_service = get_llm_service()
        if not llm_service.enabled:
            raise HTTPException(
                status_code=503, 
                detail="LLM service is not enabled. Set GEMINI_API_KEY environment variable."
            )

        # Encode image to base64 for LLM
        original_b64 = base64.b64encode(image_bytes).decode('utf-8')

        # NOTE(review): this call is synchronous inside an async handler; if
        # it performs network I/O it blocks the event loop while it runs -
        # consider offloading it to a worker thread once thread-safety of the
        # LLM service is confirmed.
        result = llm_service.generate_single_model_explanation(
            model_name=model_name,
            original_image_b64=original_b64,
            prob_fake=prob_fake,
            heatmap_b64=heatmap_base64,
            contribution_percentage=contribution_percentage,
            focus_summary=focus_summary
        )

        if result is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to generate explanation from LLM"
            )

        return ExplainModelResponse(
            model_name=model_name,
            insight=SingleModelInsight(
                key_finding=result["key_finding"],
                what_model_saw=result["what_model_saw"],
                important_regions=result["important_regions"],
                confidence_qualifier=result["confidence_qualifier"]
            )
        )

    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"Error generating model explanation: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": "ExplanationError", "message": str(e)}
        ) from e