"""
backend/app.py
---------------
ImageForensics-Detect β FastAPI Backend
STATUS: COMPLETE
Endpoints:
POST /predict β Accept image upload, run all branches, return JSON result
GET /health β Health check
GET /logs β Summary statistics from prediction log
Run locally:
cd ImageForensics-Detect/
uvicorn backend.app:app --reload --host 0.0.0.0 --port 8000
Test with curl:
curl -X POST "http://localhost:8000/predict" \
-F "file=@path/to/test_image.jpg"
"""
import sys
import os
from pathlib import Path
# ── Add project root to sys.path ─────────────────────────────────
# ROOT is the repository root (one level above backend/).  Inserting it at
# the front of sys.path lets the sibling packages (utils, branches, fusion,
# explainability) import cleanly no matter where uvicorn was launched from.
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))
os.chdir(ROOT)  # Ensure relative paths resolve correctly
import traceback

import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse

from utils.image_utils import load_image_from_bytes
from utils.logger import log_prediction, get_log_summary
from branches.spectral_branch import run_spectral_branch
from branches.edge_branch import run_edge_branch
from branches.cnn_branch import run_cnn_branch
from branches.vit_branch import run_vit_branch
from branches.diffusion_branch import run_diffusion_branch
from fusion.fusion import fuse_branches, format_result_for_display
from explainability.gradcam import compute_gradcam, _fallback_heatmap, _saliency_heatmap
from explainability.spectral_heatmap import (
    render_spectral_heatmap,
    render_noise_map,
    render_edge_map,
)
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# App Setup
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# The FastAPI application object; uvicorn serves this as `backend.app:app`.
app = FastAPI(
    title="ImageForensics-Detect API",
    description="Multi-branch image forensics for real vs. AI-generated image detection.",
    version="1.0.0",
)
# Allow frontend (localhost:3000 / file://) to call the API.
# NOTE(review): "*" origins are fine for local development but should be
# narrowed to the real frontend origin before any public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ── Static file serving ──────────────────────────────────────────
# Generated visualization artifacts are written under outputs/ and served
# back to the frontend via the /outputs route.
OUTPUTS_DIR = ROOT / "outputs"
OUTPUTS_DIR.mkdir(exist_ok=True)
app.mount("/outputs", StaticFiles(directory=str(OUTPUTS_DIR)), name="outputs")

# Serve the bundled frontend (if present) from /static, with index.html at /.
FRONTEND_DIR = ROOT / "frontend"
if FRONTEND_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")


@app.get("/")
async def read_index():
    """Serve the frontend entry page.

    Returns the bundled ``frontend/index.html`` when it exists; otherwise
    raises a 404 instead of letting FileResponse fail with an opaque 500
    on the missing file.
    """
    index_path = FRONTEND_DIR / "index.html"
    if not index_path.is_file():
        raise HTTPException(status_code=404, detail="Frontend not installed.")
    return FileResponse(index_path)


# Upload constraints enforced by /predict.
ALLOWED_MIME = {"image/jpeg", "image/png", "image/webp", "image/bmp"}
MAX_FILE_SIZE = 15 * 1024 * 1024  # 15 MB
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Endpoints
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@app.get("/health")
def health():
    """Liveness probe: report that the service is up, with name and version."""
    payload = {
        "status": "ok",
        "service": "ImageForensics-Detect",
        "version": "1.0.0",
    }
    return payload
@app.get("/logs")
def logs():
    """Expose aggregate statistics computed from the prediction log."""
    summary = get_log_summary()
    return summary
# ── Branch failure policy ────────────────────────────────────────
# Spectral, edge and diffusion are *required* branches: a failure aborts the
# request with HTTP 500.  CNN and ViT are *optional* (their model weights may
# be absent), so a failure degrades to a neutral 0.5 score instead of failing
# the whole prediction.

# Neutral fallback outputs substituted when an optional branch crashes.
_CNN_FALLBACK = {"prob_fake": 0.5, "confidence": 0.0, "feature_model": None,
                 "img_tensor": None, "model_loaded": False}
_VIT_FALLBACK = {"prob_fake": 0.5, "confidence": 0.0, "attn_weights": None,
                 "model_loaded": False}


def _run_required_branch(label, runner, img):
    """Run *runner* on *img*; any failure aborts the request with HTTP 500.

    *label* appears verbatim in the error detail (e.g. "Spectral").
    """
    try:
        return runner(img)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{label} branch error: {e}")


def _run_optional_branch(label, runner, img, fallback):
    """Run *runner* on *img*; on failure log and return a copy of *fallback*.

    A copy is returned so downstream mutation cannot corrupt the shared
    module-level fallback dict.
    """
    try:
        return runner(img)
    except Exception as e:
        traceback.print_exc()
        print(f"[Backend] {label} branch failed (non-fatal): {e}")
        return dict(fallback)


def _compute_gradcam_viz(cnn_out, img):
    """Grad-CAM heatmap for the CNN branch, with a saliency-map fallback.

    Falls back to a plain saliency heatmap when the CNN weights were not
    loaded or Grad-CAM itself raises.
    """
    if cnn_out.get("feature_model") is not None:
        try:
            return compute_gradcam(
                cnn_out["feature_model"],
                cnn_out["img_tensor"],
                target_class=1,
            )
        except Exception:
            pass  # fall through to the saliency fallback
    return _saliency_heatmap(img)


def _build_response(fusion_result, gradcam_data, spectral_viz, noise_b64, edge_b64):
    """Assemble the /predict payload, casting every field to a JSON-safe
    Python primitive (numpy scalars are not JSON-serializable)."""
    return {
        # Core result
        "prediction": str(fusion_result["prediction"]),
        "confidence": float(fusion_result["confidence"]),
        "prob_fake": float(fusion_result["prob_fake"]),
        "low_certainty": bool(fusion_result["low_certainty"]),
        # Branch scorecards — cast each field
        "branches": {
            name: {
                "prob_fake": float(info["prob_fake"]),
                "confidence": float(info["confidence"]),
                "label": str(info["label"]),
            }
            for name, info in fusion_result["branches"].items()
        },
        "fused_weights": {
            k: float(v) for k, v in fusion_result["fused_weight"].items()
        },
        # Visualizations (base64-encoded JPEG strings)
        "gradcam_b64": str(gradcam_data.get("heatmap_b64", "")),
        "gradcam_available": bool(gradcam_data.get("available", False)),
        "spectrum_b64": str(spectral_viz.get("spectrum_b64", "")),
        "spectrum_annotated_b64": str(spectral_viz.get("annotated_b64", "")),
        "noise_map_b64": str(noise_b64) if noise_b64 else "",
        "edge_map_b64": str(edge_b64) if edge_b64 else "",
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Analyze an uploaded image through all 5 forensic branches and return:
    - Final prediction (Real / AI-Generated)
    - Confidence score (%)
    - Per-branch probability and confidence
    - Base64-encoded Grad-CAM heatmap
    - Base64-encoded spectral heatmap with anomaly annotation
    - Base64-encoded residual noise map
    - Base64-encoded edge map

    Raises HTTPException: 415 (unsupported type), 413 (too large),
    400 (undecodable image), 500 (required-branch failure).
    """
    # ── 1. Validate Upload ────────────────────────────────────────
    if file.content_type not in ALLOWED_MIME:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type: {file.content_type}. "
                   f"Accepted: JPEG, PNG, WEBP, BMP"
        )
    raw_bytes = await file.read()
    if len(raw_bytes) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large (max 15 MB).")

    # ── 2. Load & Preprocess Image ────────────────────────────────
    try:
        img = load_image_from_bytes(raw_bytes, size=(224, 224))  # float32 [0,1]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}")

    # ── 3. Run All Branches (order preserved; see policy above) ───
    spectral_out = _run_required_branch("Spectral", run_spectral_branch, img)
    edge_out = _run_required_branch("Edge", run_edge_branch, img)
    cnn_out = _run_optional_branch("CNN", run_cnn_branch, img, _CNN_FALLBACK)
    vit_out = _run_optional_branch("ViT", run_vit_branch, img, _VIT_FALLBACK)
    diffusion_out = _run_required_branch("Diffusion", run_diffusion_branch, img)

    # ── 4. Fuse Branch Outputs ────────────────────────────────────
    branch_outputs = {
        "spectral": spectral_out,
        "edge": edge_out,
        "cnn": cnn_out,
        "vit": vit_out,
        "diffusion": diffusion_out,
    }
    fusion_result = fuse_branches(branch_outputs)
    # Print to server console for debugging
    print(format_result_for_display(fusion_result))

    # ── 5. Explainability Visualizations ─────────────────────────
    gradcam_data = _compute_gradcam_viz(cnn_out, img)
    spectral_viz = render_spectral_heatmap(spectral_out["spectrum_map"], img)
    noise_b64 = render_noise_map(diffusion_out["noise_map"])
    edge_b64 = render_edge_map(edge_out["edge_map"])

    # ── 6. Log Prediction (best-effort) ───────────────────────────
    try:
        log_prediction(file.filename or "unknown", fusion_result)
    except Exception:
        pass  # Logging failure should not affect the response

    # ── 7. Build Response ─────────────────────────────────────────
    response = _build_response(fusion_result, gradcam_data, spectral_viz,
                               noise_b64, edge_b64)
    return JSONResponse(content=response)
|