# DeepFakeDetectorBackend/app/models/wrappers/cnn_transfer_wrapper.py
# Deployed as part of "DeepFake Detector API" (commit df4a21a, 2026-03-07).
"""
Wrapper for CNN Transfer (EfficientNet-B0) submodel.
"""
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from app.core.errors import InferenceError, ConfigurationError
from app.core.logging import get_logger
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
from app.services.explainability import GradCAM, heatmap_to_base64, compute_focus_summary
logger = get_logger(__name__)
class CNNTransferWrapper(BaseSubmodelWrapper):
    """
    Wrapper for CNN Transfer model using EfficientNet-B0 backbone.

    Loads trained weights from ``model.pth`` and an optional
    ``preprocess.json`` inside ``local_path``. The model expects RGB
    images (default 224x224) with ImageNet normalization.
    """

    # Fallback preprocessing values (standard ImageNet statistics) used
    # when preprocess.json is missing or incomplete.
    _DEFAULT_INPUT_SIZE = [224, 224]
    _DEFAULT_MEAN = [0.485, 0.456, 0.406]
    _DEFAULT_STD = [0.229, 0.224, 0.225]

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Decision threshold on prob_fake; configurable per deployment.
        self._threshold = config.get("threshold", 0.5)
        # Effective spatial size (H, W); overwritten by load() from
        # preprocess.json. Used for both Resize and GradCAM output.
        self._input_size: Tuple[int, int] = tuple(self._DEFAULT_INPUT_SIZE)
        logger.info(f"Initialized CNNTransferWrapper for {repo_id}")

    def load(self) -> None:
        """Load the EfficientNet-B0 model with trained weights.

        Raises:
            ConfigurationError: If model.pth is missing or loading fails.
        """
        weights_path = Path(self.local_path) / "model.pth"
        preprocess_path = Path(self.local_path) / "preprocess.json"
        if not weights_path.exists():
            raise ConfigurationError(
                message=f"model.pth not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )
        try:
            # Preprocessing config is optional; fall back to ImageNet defaults.
            preprocess_config: Dict[str, Any] = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            input_size = preprocess_config.get("input_size", self._DEFAULT_INPUT_SIZE)
            if isinstance(input_size, int):
                input_size = [input_size, input_size]
            # Remember the configured size so GradCAM heatmaps match it
            # (previously hard-coded to 224x224 in _run_inference).
            self._input_size = tuple(input_size)

            normalize_config = preprocess_config.get("normalize", {})
            mean = normalize_config.get("mean", self._DEFAULT_MEAN)
            std = normalize_config.get("std", self._DEFAULT_STD)
            self._transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Build the backbone without pretrained weights; the trained
            # state dict below supplies every parameter.
            num_classes = self.config.get("num_classes", 2)
            self._model = efficientnet_b0(weights=None)
            # Replace the classifier head to match the trained checkpoint.
            in_features = self._model.classifier[1].in_features
            self._model.classifier = nn.Sequential(
                nn.Dropout(p=0.2, inplace=True),
                nn.Linear(in_features, num_classes)
            )
            # weights_only=True refuses arbitrary pickled objects in the
            # checkpoint (safer torch.load).
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=True)
            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()
            # Mark as loaded
            self._predict_fn = self._run_inference
            logger.info(f"Loaded CNN Transfer model from {self.repo_id}")
        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load CNN Transfer model: {e}")
            # Chain the original exception so the root cause survives.
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """Run model inference on a preprocessed, batched input tensor.

        Args:
            image_tensor: Batched (1, C, H, W) tensor on ``self._device``.
            explain: If True, also compute a GradCAM heatmap for the
                predicted class.

        Returns:
            Dict with ``logits``, ``prob_fake``, ``pred_int`` and, when
            ``explain`` is set, ``heatmap``.
        """
        heatmap = None
        if explain:
            # Hook the last feature block of EfficientNet-B0 for GradCAM.
            target_layer = self._model.features[-1]
            gradcam = GradCAM(self._model, target_layer)
            try:
                # GradCAM needs gradients, so do NOT wrap in no_grad.
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                # Class index 1 is "fake" per the default label map.
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0
                # Heatmap sized to the configured input resolution so it
                # overlays the preprocessed image exactly (was fixed 224x224).
                heatmap = gradcam(
                    image_tensor.clone(),
                    target_class=pred_int,
                    output_size=tuple(self._input_size)
                )
            finally:
                # Always detach hooks, even if the forward pass fails.
                gradcam.remove_hooks()
        else:
            with torch.no_grad():
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0
        result = {
            "logits": logits[0].detach().cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }
        if heatmap is not None:
            result["heatmap"] = heatmap
        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object.
            image_bytes: Raw image bytes (decoded when ``image`` is None).
            explain: If True, compute GradCAM heatmap.

        Returns:
            Standardized prediction dictionary with optional heatmap.

        Raises:
            InferenceError: If the model is not loaded, no image was
                provided, or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )
        try:
            # Accept either a PIL image or raw bytes; always force RGB.
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            # Preprocess into a single-image batch on the target device.
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)
            result = self._run_inference(image_tensor, explain=explain)

            # Map the integer prediction to a label. Keys are normalized
            # to strings so int-keyed label configs also resolve instead
            # of silently falling through to "unknown".
            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            labels = {str(k): v for k, v in labels.items()}
            pred_int = result["pred_int"]
            output = {
                "pred_int": pred_int,
                "pred": labels.get(str(pred_int), "unknown"),
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }
            # Attach explainability artifacts only when requested and
            # a heatmap was actually produced.
            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "grad_cam"
                output["focus_summary"] = compute_focus_summary(heatmap)
            return output
        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            # Chain the original exception for debuggability.
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e