File size: 9,635 Bytes
928b74f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
backend/app.py
---------------
ImageForensics-Detect β€” FastAPI Backend
STATUS: COMPLETE

Endpoints:
  POST /predict    β€” Accept image upload, run all branches, return JSON result
  GET  /health     β€” Health check
  GET  /logs       β€” Summary statistics from prediction log

Run locally:
  cd ImageForensics-Detect/
  uvicorn backend.app:app --reload --host 0.0.0.0 --port 8000

Test with curl:
  curl -X POST "http://localhost:8000/predict" \
       -F "file=@path/to/test_image.jpg"
"""

import os
import sys
import traceback
from pathlib import Path

# ── Add project root to sys.path ─────────────────────────────────
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))
os.chdir(ROOT)   # Ensure relative paths resolve correctly

import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse

from utils.image_utils import load_image_from_bytes
from utils.logger import log_prediction, get_log_summary
from branches.spectral_branch import run_spectral_branch
from branches.edge_branch import run_edge_branch
from branches.cnn_branch import run_cnn_branch
from branches.vit_branch import run_vit_branch
from branches.diffusion_branch import run_diffusion_branch
from fusion.fusion import fuse_branches, format_result_for_display
from explainability.gradcam import compute_gradcam, _fallback_heatmap
from explainability.spectral_heatmap import (
    render_spectral_heatmap,
    render_noise_map,
    render_edge_map,
)

# ─────────────────────────────────────────────────────────────────
# App Setup
# ─────────────────────────────────────────────────────────────────

app = FastAPI(
    title="ImageForensics-Detect API",
    description="Multi-branch image forensics for real vs. AI-generated image detection.",
    version="1.0.0",
)

# CORS: any origin may call the API (this is what lets a dev frontend on
# localhost:3000 or a page opened via file:// reach it). Only GET and POST
# requests are permitted; all request headers are accepted.
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)

# Generated visualization files are published read-only under /outputs.
# The directory is created up front because StaticFiles expects it to exist.
OUTPUTS_DIR = ROOT.joinpath("outputs")
OUTPUTS_DIR.mkdir(exist_ok=True)
app.mount(
    "/outputs",
    StaticFiles(directory=str(OUTPUTS_DIR)),
    name="outputs",
)

# Serve the bundled frontend when it is present: its assets under /static and
# its index.html at the site root. is_dir() (not exists()) so a stray regular
# file named "frontend" does not make StaticFiles fail at startup.
FRONTEND_DIR = ROOT / "frontend"
if FRONTEND_DIR.is_dir():
    from fastapi.responses import FileResponse

    app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")

    @app.get("/")
    async def read_index():
        """Return the frontend's index.html, or 404 if it is missing."""
        index_path = FRONTEND_DIR / "index.html"
        if not index_path.is_file():
            # Without this guard FileResponse fails while sending and the
            # client sees an opaque 500.
            raise HTTPException(status_code=404, detail="Frontend index.html not found.")
        return FileResponse(index_path)

# Content types /predict accepts; any other upload is rejected with HTTP 415.
ALLOWED_MIME = {
    "image/jpeg",
    "image/png",
    "image/webp",
    "image/bmp",
}
# Largest upload /predict accepts (15 MiB); bigger files get HTTP 413.
MAX_FILE_SIZE = 15 * 1024 ** 2


# ─────────────────────────────────────────────────────────────────
# Endpoints
# ─────────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Server health check."""
    return {"status": "ok", "service": "ImageForensics-Detect", "version": "1.0.0"}


@app.get("/logs")
def logs():
    """Return prediction log summary statistics."""
    return get_log_summary()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Analyze an uploaded image through all 5 forensic branches and return:
      - Final prediction (Real / AI-Generated)
      - Confidence score (%)
      - Per-branch probability and confidence
      - Base64-encoded Grad-CAM heatmap
      - Base64-encoded spectral heatmap with anomaly annotation
      - Base64-encoded residual noise map
      - Base64-encoded edge map
    """
    # ── 1. Validate Upload ────────────────────────────────────────
    if file.content_type not in ALLOWED_MIME:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type: {file.content_type}. "
                   f"Accepted: JPEG, PNG, WEBP, BMP"
        )

    raw_bytes = await file.read()
    if len(raw_bytes) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large (max 15 MB).")

    # ── 2. Load & Preprocess Image ────────────────────────────────
    try:
        img = load_image_from_bytes(raw_bytes, size=(224, 224))  # float32 [0,1]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}")

    # ── 3. Run All Branches ───────────────────────────────────────
    try:
        spectral_out  = run_spectral_branch(img)
    except Exception as e:
        import traceback; traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Spectral branch error: {e}")
    try:
        edge_out      = run_edge_branch(img)
    except Exception as e:
        import traceback; traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Edge branch error: {e}")
    try:
        cnn_out       = run_cnn_branch(img)
    except Exception as e:
        import traceback; traceback.print_exc()
        cnn_out = {"prob_fake": 0.5, "confidence": 0.0, "feature_model": None,
                   "img_tensor": None, "model_loaded": False}
        print(f"[Backend] CNN branch failed (non-fatal): {e}")
    try:
        vit_out       = run_vit_branch(img)
    except Exception as e:
        import traceback; traceback.print_exc()
        vit_out = {"prob_fake": 0.5, "confidence": 0.0, "attn_weights": None, "model_loaded": False}
        print(f"[Backend] ViT branch failed (non-fatal): {e}")
    try:
        diffusion_out = run_diffusion_branch(img)
    except Exception as e:
        import traceback; traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Diffusion branch error: {e}")

    # ── 4. Fuse Branch Outputs ────────────────────────────────────
    branch_outputs = {
        "spectral":  spectral_out,
        "edge":      edge_out,
        "cnn":       cnn_out,
        "vit":       vit_out,
        "diffusion": diffusion_out,
    }
    fusion_result = fuse_branches(branch_outputs)

    # Print to server console for debugging
    print(format_result_for_display(fusion_result))

    # ── 5. Explainability Visualizations ─────────────────────────
    from explainability.gradcam import _saliency_heatmap
    # Grad-CAM for CNN branch (or saliency heatmap as fallback)
    if cnn_out.get("feature_model") is not None:
        try:
            gradcam_data = compute_gradcam(
                cnn_out["feature_model"],
                cnn_out["img_tensor"],
                target_class=1
            )
        except Exception:
            gradcam_data = _saliency_heatmap(img)
    else:
        # CNN weights not loaded β€” generate saliency heatmap from image
        gradcam_data = _saliency_heatmap(img)

    # Spectral heatmap
    spectral_viz = render_spectral_heatmap(spectral_out["spectrum_map"], img)

    # Noise map (diffusion branch)
    noise_b64 = render_noise_map(diffusion_out["noise_map"])

    # Edge map
    edge_b64 = render_edge_map(edge_out["edge_map"])

    # ── 6. Log Prediction ─────────────────────────────────────────
    try:
        log_prediction(file.filename or "unknown", fusion_result)
    except Exception:
        pass  # Logging failure should not affect the response

    # ── 7. Build Response (all values cast to JSON-safe Python primitives) ──
    response = {
        # Core result
        "prediction":    str(fusion_result["prediction"]),
        "confidence":    float(fusion_result["confidence"]),
        "prob_fake":     float(fusion_result["prob_fake"]),
        "low_certainty": bool(fusion_result["low_certainty"]),

        # Branch scorecards β€” cast each field
        "branches": {
            name: {
                "prob_fake":  float(info["prob_fake"]),
                "confidence": float(info["confidence"]),
                "label":      str(info["label"]),
            }
            for name, info in fusion_result["branches"].items()
        },
        "fused_weights": {
            k: float(v) for k, v in fusion_result["fused_weight"].items()
        },

        # Visualizations (base64-encoded JPEG strings)
        "gradcam_b64":            str(gradcam_data.get("heatmap_b64", "")),
        "gradcam_available":      bool(gradcam_data.get("available", False)),
        "spectrum_b64":           str(spectral_viz.get("spectrum_b64", "")),
        "spectrum_annotated_b64": str(spectral_viz.get("annotated_b64", "")),
        "noise_map_b64":          str(noise_b64) if noise_b64 else "",
        "edge_map_b64":           str(edge_b64) if edge_b64 else "",
    }

    return JSONResponse(content=response)