# dk2430098's picture
# Upload folder using huggingface_hub
# 928b74f verified
"""
backend/app.py
---------------
ImageForensics-Detect β€” FastAPI Backend
STATUS: COMPLETE
Endpoints:
POST /predict β€” Accept image upload, run all branches, return JSON result
GET /health β€” Health check
GET /logs β€” Summary statistics from prediction log
Run locally:
cd ImageForensics-Detect/
uvicorn backend.app:app --reload --host 0.0.0.0 --port 8000
Test with curl:
curl -X POST "http://localhost:8000/predict" \
-F "file=@path/to/test_image.jpg"
"""
import sys
import os
from pathlib import Path
# ── Add project root to sys.path ─────────────────────────────────
# ROOT is the repository root (one level above backend/). Inserting it at the
# front of sys.path lets `utils`, `branches`, `fusion`, and `explainability`
# be imported below as top-level packages regardless of where uvicorn starts.
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))
# NOTE(review): chdir at import time is a process-wide side effect — it makes
# relative paths resolve from the repo root, but will surprise any other code
# in the same process that depends on the current working directory.
os.chdir(ROOT) # Ensure relative paths resolve correctly
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

from branches.cnn_branch import run_cnn_branch
from branches.diffusion_branch import run_diffusion_branch
from branches.edge_branch import run_edge_branch
from branches.spectral_branch import run_spectral_branch
from branches.vit_branch import run_vit_branch
from explainability.gradcam import compute_gradcam, _fallback_heatmap
from explainability.spectral_heatmap import (
    render_spectral_heatmap,
    render_noise_map,
    render_edge_map,
)
from fusion.fusion import fuse_branches, format_result_for_display
from utils.image_utils import load_image_from_bytes
from utils.logger import log_prediction, get_log_summary
# ─────────────────────────────────────────────────────────────────
# App Setup
# ─────────────────────────────────────────────────────────────────
# FastAPI application object; uvicorn loads this as `backend.app:app`.
app = FastAPI(
    title="ImageForensics-Detect API",
    description="Multi-branch image forensics for real vs. AI-generated image detection.",
    version="1.0.0",
)
# Allow frontend (localhost:3000 / file://) to call the API
# NOTE(review): allow_origins=["*"] permits any origin — acceptable for a
# demo; tighten to the real frontend origin(s) before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# Serve output visualization files at /outputs (directory created if absent).
OUTPUTS_DIR = ROOT / "outputs"
OUTPUTS_DIR.mkdir(exist_ok=True)
app.mount("/outputs", StaticFiles(directory=str(OUTPUTS_DIR)), name="outputs")
# Serve frontend static assets under /static when the bundle is present.
FRONTEND_DIR = ROOT / "frontend"
if FRONTEND_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")


@app.get("/")
async def read_index():
    """Serve the frontend single-page app entry point.

    Returns the frontend's index.html, or a clean 404 when the frontend
    bundle was not shipped alongside the backend (previously the missing
    file only surfaced as an error while sending the response).
    """
    index_path = FRONTEND_DIR / "index.html"
    if not index_path.is_file():
        raise HTTPException(status_code=404, detail="Frontend not available.")
    return FileResponse(index_path)


# Upload validation limits for /predict.
ALLOWED_MIME = {"image/jpeg", "image/png", "image/webp", "image/bmp"}
MAX_FILE_SIZE = 15 * 1024 * 1024  # 15 MB
# ─────────────────────────────────────────────────────────────────
# Endpoints
# ─────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
"""Server health check."""
return {"status": "ok", "service": "ImageForensics-Detect", "version": "1.0.0"}
@app.get("/logs")
def logs():
"""Return prediction log summary statistics."""
return get_log_summary()
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
"""
Analyze an uploaded image through all 5 forensic branches and return:
- Final prediction (Real / AI-Generated)
- Confidence score (%)
- Per-branch probability and confidence
- Base64-encoded Grad-CAM heatmap
- Base64-encoded spectral heatmap with anomaly annotation
- Base64-encoded residual noise map
- Base64-encoded edge map
"""
# ── 1. Validate Upload ────────────────────────────────────────
if file.content_type not in ALLOWED_MIME:
raise HTTPException(
status_code=415,
detail=f"Unsupported file type: {file.content_type}. "
f"Accepted: JPEG, PNG, WEBP, BMP"
)
raw_bytes = await file.read()
if len(raw_bytes) > MAX_FILE_SIZE:
raise HTTPException(status_code=413, detail="File too large (max 15 MB).")
# ── 2. Load & Preprocess Image ────────────────────────────────
try:
img = load_image_from_bytes(raw_bytes, size=(224, 224)) # float32 [0,1]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Could not decode image: {e}")
# ── 3. Run All Branches ───────────────────────────────────────
try:
spectral_out = run_spectral_branch(img)
except Exception as e:
import traceback; traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Spectral branch error: {e}")
try:
edge_out = run_edge_branch(img)
except Exception as e:
import traceback; traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Edge branch error: {e}")
try:
cnn_out = run_cnn_branch(img)
except Exception as e:
import traceback; traceback.print_exc()
cnn_out = {"prob_fake": 0.5, "confidence": 0.0, "feature_model": None,
"img_tensor": None, "model_loaded": False}
print(f"[Backend] CNN branch failed (non-fatal): {e}")
try:
vit_out = run_vit_branch(img)
except Exception as e:
import traceback; traceback.print_exc()
vit_out = {"prob_fake": 0.5, "confidence": 0.0, "attn_weights": None, "model_loaded": False}
print(f"[Backend] ViT branch failed (non-fatal): {e}")
try:
diffusion_out = run_diffusion_branch(img)
except Exception as e:
import traceback; traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Diffusion branch error: {e}")
# ── 4. Fuse Branch Outputs ────────────────────────────────────
branch_outputs = {
"spectral": spectral_out,
"edge": edge_out,
"cnn": cnn_out,
"vit": vit_out,
"diffusion": diffusion_out,
}
fusion_result = fuse_branches(branch_outputs)
# Print to server console for debugging
print(format_result_for_display(fusion_result))
# ── 5. Explainability Visualizations ─────────────────────────
from explainability.gradcam import _saliency_heatmap
# Grad-CAM for CNN branch (or saliency heatmap as fallback)
if cnn_out.get("feature_model") is not None:
try:
gradcam_data = compute_gradcam(
cnn_out["feature_model"],
cnn_out["img_tensor"],
target_class=1
)
except Exception:
gradcam_data = _saliency_heatmap(img)
else:
# CNN weights not loaded β€” generate saliency heatmap from image
gradcam_data = _saliency_heatmap(img)
# Spectral heatmap
spectral_viz = render_spectral_heatmap(spectral_out["spectrum_map"], img)
# Noise map (diffusion branch)
noise_b64 = render_noise_map(diffusion_out["noise_map"])
# Edge map
edge_b64 = render_edge_map(edge_out["edge_map"])
# ── 6. Log Prediction ─────────────────────────────────────────
try:
log_prediction(file.filename or "unknown", fusion_result)
except Exception:
pass # Logging failure should not affect the response
# ── 7. Build Response (all values cast to JSON-safe Python primitives) ──
response = {
# Core result
"prediction": str(fusion_result["prediction"]),
"confidence": float(fusion_result["confidence"]),
"prob_fake": float(fusion_result["prob_fake"]),
"low_certainty": bool(fusion_result["low_certainty"]),
# Branch scorecards β€” cast each field
"branches": {
name: {
"prob_fake": float(info["prob_fake"]),
"confidence": float(info["confidence"]),
"label": str(info["label"]),
}
for name, info in fusion_result["branches"].items()
},
"fused_weights": {
k: float(v) for k, v in fusion_result["fused_weight"].items()
},
# Visualizations (base64-encoded JPEG strings)
"gradcam_b64": str(gradcam_data.get("heatmap_b64", "")),
"gradcam_available": bool(gradcam_data.get("available", False)),
"spectrum_b64": str(spectral_viz.get("spectrum_b64", "")),
"spectrum_annotated_b64": str(spectral_viz.get("annotated_b64", "")),
"noise_map_b64": str(noise_b64) if noise_b64 else "",
"edge_map_b64": str(edge_b64) if edge_b64 else "",
}
return JSONResponse(content=response)