| | """ |
| | Wrapper for CNN Transfer (EfficientNet-B0) submodel. |
| | """ |
| |
|
| | import json |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from pathlib import Path |
| | from typing import Any, Dict, Optional, Tuple |
| | from PIL import Image |
| | from torchvision import transforms |
| | from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights |
| |
|
| | from app.core.errors import InferenceError, ConfigurationError |
| | from app.core.logging import get_logger |
| | from app.models.wrappers.base_wrapper import BaseSubmodelWrapper |
| | from app.services.explainability import GradCAM, heatmap_to_base64, compute_focus_summary |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
class CNNTransferWrapper(BaseSubmodelWrapper):
    """
    Wrapper for CNN Transfer model using EfficientNet-B0 backbone.

    Classifies images as real (class 0) or fake (class 1). Inputs are resized
    and normalized according to an optional ``preprocess.json`` next to the
    weights; defaults are 224x224 with ImageNet mean/std.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Args:
            repo_id: Repository identifier, used for logging and error details.
            config: Submodel configuration. Recognized keys: "threshold"
                (decision threshold on prob_fake, default 0.5), "num_classes"
                (default 2), "labels" (mapping of class index to name).
            local_path: Directory containing model.pth and optional preprocess.json.
        """
        super().__init__(repo_id, config, local_path)
        self._model: Optional[nn.Module] = None
        self._transform: Optional[transforms.Compose] = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._threshold = config.get("threshold", 0.5)
        # Effective network input size (H, W); overwritten by load() from
        # preprocess.json so GradCAM heatmaps match the real input resolution.
        self._input_size: Tuple[int, int] = (224, 224)
        logger.info(f"Initialized CNNTransferWrapper for {repo_id}")

    def load(self) -> None:
        """
        Load the EfficientNet-B0 model with trained weights.

        Reads optional preprocessing parameters from preprocess.json
        (input_size, normalize.mean, normalize.std), builds the torchvision
        transform pipeline, rebuilds the classifier head for the configured
        number of classes, and loads the state dict onto the selected device.

        Raises:
            ConfigurationError: If model.pth is missing or loading fails.
        """
        weights_path = Path(self.local_path) / "model.pth"
        preprocess_path = Path(self.local_path) / "preprocess.json"

        if not weights_path.exists():
            raise ConfigurationError(
                message=f"model.pth not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(weights_path)}
            )

        try:
            preprocess_config: Dict[str, Any] = {}
            if preprocess_path.exists():
                with open(preprocess_path, "r") as f:
                    preprocess_config = json.load(f)

            # Accept either a scalar (square) or [H, W] input size.
            input_size = preprocess_config.get("input_size", [224, 224])
            if isinstance(input_size, int):
                input_size = [input_size, input_size]
            # Remember the size so GradCAM output matches the network input
            # (previously hard-coded to 224x224 regardless of config).
            self._input_size = (int(input_size[0]), int(input_size[1]))

            normalize_config = preprocess_config.get("normalize", {})
            # Defaults are the standard ImageNet statistics.
            mean = normalize_config.get("mean", [0.485, 0.456, 0.406])
            std = normalize_config.get("std", [0.229, 0.224, 0.225])

            self._transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

            # Build the backbone without pretrained weights; the trained
            # checkpoint from disk supplies all parameters below.
            num_classes = self.config.get("num_classes", 2)
            self._model = efficientnet_b0(weights=None)

            # Replace the stock classifier head to match our class count.
            in_features = self._model.classifier[1].in_features
            self._model.classifier = nn.Sequential(
                nn.Dropout(p=0.2, inplace=True),
                nn.Linear(in_features, num_classes)
            )

            # weights_only=True prevents arbitrary code execution from a
            # malicious/pickled checkpoint.
            state_dict = torch.load(weights_path, map_location=self._device, weights_only=True)
            self._model.load_state_dict(state_dict)
            self._model.to(self._device)
            self._model.eval()

            self._predict_fn = self._run_inference
            logger.info(f"Loaded CNN Transfer model from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load CNN Transfer model: {e}")
            # Chain the original exception so the root cause is preserved.
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _run_inference(
        self,
        image_tensor: torch.Tensor,
        explain: bool = False
    ) -> Dict[str, Any]:
        """
        Run model inference on a preprocessed (1, C, H, W) tensor.

        Args:
            image_tensor: Batch-of-one tensor already on self._device.
            explain: If True, also compute a GradCAM heatmap for the
                predicted class (forward pass runs with gradients enabled).

        Returns:
            Dict with "logits" (list of floats), "prob_fake", "pred_int",
            and "heatmap" when explain is True.
        """
        heatmap = None

        if explain:
            # Hook the last feature block; GradCAM needs gradients, so no
            # torch.no_grad() on this path.
            target_layer = self._model.features[-1]
            gradcam = GradCAM(self._model, target_layer)
            try:
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

                # Upsample the heatmap to the configured network input size.
                heatmap = gradcam(
                    image_tensor.clone(),
                    target_class=pred_int,
                    output_size=self._input_size
                )
            finally:
                # Always detach hooks, even if the forward/backward pass fails.
                gradcam.remove_hooks()
        else:
            with torch.no_grad():
                logits = self._model(image_tensor)
                probs = F.softmax(logits, dim=1)
                prob_fake = probs[0, 1].item()
                pred_int = 1 if prob_fake >= self._threshold else 0

        result = {
            "logits": logits[0].detach().cpu().numpy().tolist(),
            "prob_fake": prob_fake,
            "pred_int": pred_int
        }

        if heatmap is not None:
            result["heatmap"] = heatmap

        return result

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object.
            image_bytes: Raw image bytes (decoded to a PIL Image if image
                is not given).
            explain: If True, compute GradCAM heatmap and focus summary.

        Returns:
            Standardized prediction dictionary: pred_int, pred (label name),
            prob_fake, meta (model/threshold/logits), plus heatmap_base64,
            explainability_type and focus_summary when explain is True.

        Raises:
            InferenceError: If the model is not loaded, no image is
                provided, or inference fails.
        """
        if self._model is None or self._transform is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Normalize input to an RGB PIL image.
            if image is None and image_bytes is not None:
                import io
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            elif image is not None:
                image = image.convert("RGB")
            else:
                raise InferenceError(
                    message="No image provided",
                    details={"repo_id": self.repo_id}
                )

            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            result = self._run_inference(image_tensor, explain=explain)

            labels = self.config.get("labels", {"0": "real", "1": "fake"})
            pred_int = result["pred_int"]
            # Tolerate both string and integer keys in the configured labels.
            label = labels.get(str(pred_int), labels.get(pred_int, "unknown"))

            output = {
                "pred_int": pred_int,
                "pred": label,
                "prob_fake": result["prob_fake"],
                "meta": {
                    "model": self.name,
                    "threshold": self._threshold,
                    "logits": result["logits"]
                }
            }

            if explain and "heatmap" in result:
                heatmap = result["heatmap"]
                output["heatmap_base64"] = heatmap_to_base64(heatmap)
                output["explainability_type"] = "grad_cam"
                output["focus_summary"] = compute_focus_summary(heatmap)

            return output

        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            # Chain the original exception so the root cause is preserved.
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e
| |
|