""" Wrapper for ViT Base submodel. """ import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from PIL import Image from torchvision import transforms try: import timm TIMM_AVAILABLE = True except ImportError: TIMM_AVAILABLE = False from app.core.errors import InferenceError, ConfigurationError from app.core.logging import get_logger from app.models.wrappers.base_wrapper import BaseSubmodelWrapper from app.services.explainability import attention_rollout, heatmap_to_base64, compute_focus_summary logger = get_logger(__name__) class ViTWithMLPHead(nn.Module): """ ViT model wrapper matching the training checkpoint format. The checkpoint was saved with: - self.vit = timm ViT backbone (num_classes=0) - self.fc1 = Linear(768, hidden) - self.fc2 = Linear(hidden, num_classes) """ def __init__(self, arch: str = "vit_base_patch16_224", num_classes: int = 2, hidden_dim: int = 512): super().__init__() # Create backbone without classification head self.vit = timm.create_model(arch, pretrained=False, num_classes=0) embed_dim = self.vit.embed_dim # 768 for ViT-Base self.fc1 = nn.Linear(embed_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: features = self.vit(x) # [B, embed_dim] x = F.relu(self.fc1(features)) logits = self.fc2(x) return logits class ViTBaseWrapper(BaseSubmodelWrapper): """ Wrapper for ViT Base model (Vision Transformer). Model expects 224x224 RGB images with ImageNet normalization. """ def __init__( self, repo_id: str, config: Dict[str, Any], local_path: str ): super().__init__(repo_id, config, local_path) self._model: Optional[nn.Module] = None self._transform: Optional[transforms.Compose] = None self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._threshold = config.get("threshold", 0.5) logger.info(f"Initialized ViTBaseWrapper for {repo_id}") def load(self) -> None: """Load the ViT Base model with trained weights.""" if not TIMM_AVAILABLE: raise ConfigurationError( message="timm package not installed. Run: pip install timm", details={"repo_id": self.repo_id} ) weights_path = Path(self.local_path) / "deepfake_vit_finetuned_wildfake.pth" preprocess_path = Path(self.local_path) / "preprocess.json" if not weights_path.exists(): raise ConfigurationError( message=f"deepfake_vit_finetuned_wildfake.pth not found in {self.local_path}", details={"repo_id": self.repo_id, "expected_path": str(weights_path)} ) try: # Load preprocessing config preprocess_config = {} if preprocess_path.exists(): with open(preprocess_path, "r") as f: preprocess_config = json.load(f) # Build transform pipeline input_size = preprocess_config.get("input_size", 224) if isinstance(input_size, list): input_size = input_size[0] normalize_config = preprocess_config.get("normalize", {}) mean = normalize_config.get("mean", [0.485, 0.456, 0.406]) std = normalize_config.get("std", [0.229, 0.224, 0.225]) # Use bicubic interpolation as specified interpolation = preprocess_config.get("interpolation", "bicubic") interp_mode = transforms.InterpolationMode.BICUBIC if interpolation == "bicubic" else transforms.InterpolationMode.BILINEAR self._transform = transforms.Compose([ transforms.Resize((input_size, input_size), interpolation=interp_mode), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ]) # Create model architecture matching the training checkpoint format arch = self.config.get("arch", "vit_base_patch16_224") num_classes = self.config.get("num_classes", 2) # MLP hidden dim is 512 per training notebook (fc1: 768->512, fc2: 512->2) # Note: config.hidden_dim (768) is ViT embedding dim, not MLP hidden dim mlp_hidden_dim = self.config.get("mlp_hidden_dim", 512) # Use custom wrapper that matches checkpoint structure (vit.* + fc1/fc2) self._model = ViTWithMLPHead(arch=arch, num_classes=num_classes, hidden_dim=mlp_hidden_dim) # Load trained weights checkpoint = torch.load(weights_path, map_location=self._device, weights_only=False) # Handle training checkpoint format (has "model", "optimizer_state", "epoch" keys) if isinstance(checkpoint, dict) and "model" in checkpoint: state_dict = checkpoint["model"] else: state_dict = checkpoint self._model.load_state_dict(state_dict) self._model.to(self._device) self._model.eval() # Mark as loaded self._predict_fn = self._run_inference logger.info(f"Loaded ViT Base model from {self.repo_id}") except ConfigurationError: raise except Exception as e: logger.error(f"Failed to load ViT Base model: {e}") raise ConfigurationError( message=f"Failed to load model: {e}", details={"repo_id": self.repo_id, "error": str(e)} ) def _run_inference( self, image_tensor: torch.Tensor, explain: bool = False ) -> Dict[str, Any]: """Run model inference on preprocessed tensor.""" heatmap = None if explain: # Collect attention weights from all blocks attentions: List[torch.Tensor] = [] handles = [] def get_attention_hook(module, input, output): # For timm ViT, the attention forward returns (attn @ v) # We need to hook into the softmax to get raw attention weights # Alternative: access module's internal attn variable if available pass # Hook into attention modules to capture weights # timm ViT blocks structure: blocks[i].attn # We'll use a forward hook that computes attention manually def create_attn_hook(): stored_attn = [] def hook(module, inputs, outputs): # Get q, k from the module's forward computation # inputs[0] is x of shape [B, N, C] x = inputs[0] B, N, C = x.shape # Access the attention module's parameters qkv = module.qkv(x) # [B, N, 3*dim] qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads) qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, heads, N, dim_head] q, k, v = qkv[0], qkv[1], qkv[2] # Compute attention weights scale = (C // module.num_heads) ** -0.5 attn = (q @ k.transpose(-2, -1)) * scale attn = attn.softmax(dim=-1) # [B, heads, N, N] # Average over heads attn_avg = attn.mean(dim=1) # [B, N, N] stored_attn.append(attn_avg.detach()) return hook, stored_attn all_stored_attns = [] for block in self._model.vit.blocks: hook_fn, stored = create_attn_hook() all_stored_attns.append(stored) handle = block.attn.register_forward_hook(hook_fn) handles.append(handle) try: with torch.no_grad(): logits = self._model(image_tensor) probs = F.softmax(logits, dim=1) prob_fake = probs[0, 1].item() pred_int = 1 if prob_fake >= self._threshold else 0 # Get attention from hooks attention_list = [stored[0] for stored in all_stored_attns if len(stored) > 0] if attention_list: # Stack: [num_layers, B, N, N] attention_stack = torch.stack(attention_list, dim=0) # Compute rollout - returns (grid_size, grid_size) heatmap attention_map = attention_rollout( attention_stack[:, 0], # [num_layers, N, N] head_fusion="mean", # Already averaged discard_ratio=0.0, num_prefix_tokens=1 # ViT has 1 CLS token ) # Returns (14, 14) for ViT-Base # Resize to image size from PIL import Image as PILImage heatmap_img = PILImage.fromarray( (attention_map * 255).astype(np.uint8) ).resize((224, 224), PILImage.BILINEAR) heatmap = np.array(heatmap_img).astype(np.float32) / 255.0 finally: for handle in handles: handle.remove() else: with torch.no_grad(): logits = self._model(image_tensor) probs = F.softmax(logits, dim=1) prob_fake = probs[0, 1].item() pred_int = 1 if prob_fake >= self._threshold else 0 result = { "logits": logits[0].cpu().numpy().tolist(), "prob_fake": prob_fake, "pred_int": pred_int } if heatmap is not None: result["heatmap"] = heatmap return result def predict( self, image: Optional[Image.Image] = None, image_bytes: Optional[bytes] = None, explain: bool = False, **kwargs ) -> Dict[str, Any]: """ Run prediction on an image. Args: image: PIL Image object image_bytes: Raw image bytes (will be converted to PIL Image) explain: If True, compute attention rollout heatmap Returns: Standardized prediction dictionary with optional heatmap """ if self._model is None or self._transform is None: raise InferenceError( message="Model not loaded", details={"repo_id": self.repo_id} ) try: # Convert bytes to PIL Image if needed if image is None and image_bytes is not None: import io image = Image.open(io.BytesIO(image_bytes)).convert("RGB") elif image is not None: image = image.convert("RGB") else: raise InferenceError( message="No image provided", details={"repo_id": self.repo_id} ) # Preprocess image_tensor = self._transform(image).unsqueeze(0).to(self._device) # Run inference result = self._run_inference(image_tensor, explain=explain) # Standardize output labels = self.config.get("labels", {"0": "real", "1": "fake"}) pred_int = result["pred_int"] output = { "pred_int": pred_int, "pred": labels.get(str(pred_int), "unknown"), "prob_fake": result["prob_fake"], "meta": { "model": self.name, "threshold": self._threshold, "logits": result["logits"] } } # Add heatmap if requested if explain and "heatmap" in result: heatmap = result["heatmap"] output["heatmap_base64"] = heatmap_to_base64(heatmap) output["explainability_type"] = "attention_rollout" output["focus_summary"] = compute_focus_summary(heatmap) return output except InferenceError: raise except Exception as e: logger.error(f"Prediction failed for {self.repo_id}: {e}") raise InferenceError( message=f"Prediction failed: {e}", details={"repo_id": self.repo_id, "error": str(e)} )