"""
Hybrid CNN-ViT Food Classifier
Combines ResNet50 and DeiT-Base with adaptive fusion
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Optional

from .cnn_branch import CNNBranch
from .vit_branch import ViTBranch
from .fusion_module import AdaptiveFusionModule


class HybridFoodClassifier(nn.Module):
    """Hybrid CNN-ViT model for food classification.

    The input is passed through a CNN branch and a ViT branch; their spatial
    and global features are merged by ``AdaptiveFusionModule`` and the fused
    global feature is classified by ``self.classifier``. Two auxiliary heads
    classify each branch's global feature on its own; their logits are only
    produced in training mode (see ``forward``).
    """

    def __init__(
        self,
        num_classes: int = 101,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        dropout: float = 0.2,
        pretrained: bool = True,
        freeze_early_layers: bool = True,
    ):
        """
        Args:
            num_classes: Number of output classes for all classifier heads.
            feature_dim: Feature width requested from both branches and the
                fusion module; also the input width of the auxiliary heads.
            hidden_dim: Hidden width of the fusion module. The fused global
                feature fed to ``self.classifier`` must have this last
                dimension, since the classifier's first layer is
                ``nn.Linear(hidden_dim, ...)``.
            dropout: Dropout probability passed to branches, fusion and heads.
            pretrained: Forwarded to both branches.
            freeze_early_layers: Forwarded to both branches.
        """
        super().__init__()

        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # CNN Branch (ResNet50)
        self.cnn_branch = CNNBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim,
        )

        # ViT Branch (DeiT-Base)
        self.vit_branch = ViTBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim,
        )

        # Fusion Module
        self.fusion_module = AdaptiveFusionModule(
            feature_dim=feature_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        # Main classification head applied to the fused global feature.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )

        # Auxiliary classifiers for training stability: one per branch, each
        # consuming that branch's global feature directly.
        self.cnn_aux_classifier = self._make_aux_classifier(
            feature_dim, hidden_dim, num_classes, dropout
        )
        self.vit_aux_classifier = self._make_aux_classifier(
            feature_dim, hidden_dim, num_classes, dropout
        )

        self._initialize_weights()

    @staticmethod
    def _make_aux_classifier(
        in_dim: int, hidden_dim: int, num_classes: int, dropout: float
    ) -> nn.Sequential:
        """Build one auxiliary head: Linear -> ReLU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def _initialize_weights(self):
        """Xavier-uniform init (zero bias) for every Linear in the three heads.

        Only the classifier heads are touched; branch and fusion weights keep
        whatever initialisation their own modules give them.
        """
        for head in (self.classifier, self.cnn_aux_classifier, self.vit_aux_classifier):
            for layer in head:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False,
        use_aux_loss: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor [B, 3, H, W]
            return_features: Whether to also return branch and fused features.
            use_aux_loss: Whether to compute the auxiliary-head logits. Only
                honoured in training mode (``self.training``); ignored in eval.

        Returns:
            Dict containing ``'logits'``; additionally ``'cnn_aux_logits'`` and
            ``'vit_aux_logits'`` when ``use_aux_loss`` is set and the model is
            training; additionally ``'cnn_spatial'``, ``'cnn_global'``,
            ``'vit_spatial'``, ``'vit_global'``, ``'fused_spatial'`` and
            ``'fused_global'`` when ``return_features`` is set.
        """
        cnn_spatial, cnn_global = self.cnn_branch(x)
        vit_spatial, vit_global = self.vit_branch(x)

        fused_spatial, fused_global = self.fusion_module(
            cnn_spatial, cnn_global, vit_spatial, vit_global
        )

        output = {'logits': self.classifier(fused_global)}

        # Auxiliary heads are a training-time regulariser only.
        if use_aux_loss and self.training:
            output['cnn_aux_logits'] = self.cnn_aux_classifier(cnn_global)
            output['vit_aux_logits'] = self.vit_aux_classifier(vit_global)

        if return_features:
            output.update({
                'cnn_spatial': cnn_spatial,
                'cnn_global': cnn_global,
                'vit_spatial': vit_spatial,
                'vit_global': vit_global,
                'fused_spatial': fused_spatial,
                'fused_global': fused_global,
            })

        return output

    def get_attention_maps(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Build coarse per-branch saliency maps for visualization.

        Each map is the channel-wise mean of a branch's spatial features,
        bilinearly upsampled to the spatial size of ``x`` so it can be overlaid
        on the input image. Runs under ``torch.no_grad()``; the model's
        train/eval mode is not changed, so call ``eval()`` first if dropout
        inside the branches should be disabled.

        Args:
            x: Input tensor [B, 3, H, W].

        Returns:
            Dict with ``'cnn_attention'`` and ``'vit_attention'``, each
            [B, 1, H, W].

        Raises:
            ValueError: If the ViT patch tokens (all tokens after the first)
                cannot be arranged in a square grid.
        """
        # Match the input resolution instead of assuming 224x224.
        output_size = tuple(x.shape[-2:])

        with torch.no_grad():
            output = self.forward(x, return_features=True, use_aux_loss=False)

            # CNN map: mean over the channel axis (dim=1).
            # assumes cnn_spatial is [B, C, h, w], e.g. [B, feature_dim, 7, 7] -- TODO confirm
            cnn_spatial = output['cnn_spatial']
            cnn_attention = cnn_spatial.mean(dim=1, keepdim=True)  # [B, 1, h, w]
            cnn_attention = F.interpolate(
                cnn_attention, size=output_size, mode='bilinear', align_corners=False
            )

            # ViT map: drop the first token, average each patch token's features.
            # assumes vit_spatial is [B, 1 + num_patches, feature_dim] with a
            # single leading CLS token -- TODO confirm against ViTBranch
            vit_spatial = output['vit_spatial']
            vit_patches = vit_spatial[:, 1:]
            vit_attention = vit_patches.mean(dim=-1)  # [B, num_patches]

            num_patches = vit_attention.shape[1]
            grid = math.isqrt(num_patches)
            if grid * grid != num_patches:
                raise ValueError(
                    f"Cannot arrange {num_patches} ViT patch tokens into a square grid"
                )
            vit_attention = vit_attention.reshape(-1, 1, grid, grid)  # [B, 1, g, g]
            vit_attention = F.interpolate(
                vit_attention, size=output_size, mode='bilinear', align_corners=False
            )

        return {
            'cnn_attention': cnn_attention,
            'vit_attention': vit_attention,
        }

    def freeze_backbone(self):
        """Set ``requires_grad=False`` on every parameter of both backbones.

        Targets ``self.cnn_branch.backbone`` and ``self.vit_branch.vit``; the
        branches' other submodules, the fusion module and the heads stay
        trainable.
        """
        for backbone in (self.cnn_branch.backbone, self.vit_branch.vit):
            for param in backbone.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self):
        """Set ``requires_grad=True`` on every parameter of both backbones.

        NOTE(review): this also re-enables any backbone parameters frozen at
        construction via ``freeze_early_layers`` (if the branches freeze these
        same parameters -- confirm in CNNBranch / ViTBranch).
        """
        for backbone in (self.cnn_branch.backbone, self.vit_branch.vit):
            for param in backbone.parameters():
                param.requires_grad = True

    def get_model_size(self) -> Dict[str, int]:
        """Return parameter counts for the whole model and each component.

        ``aux_classifier_params`` covers both auxiliary heads together.
        """
        def count(module: nn.Module, trainable_only: bool = False) -> int:
            return sum(
                p.numel() for p in module.parameters()
                if p.requires_grad or not trainable_only
            )

        return {
            'total_params': count(self),
            'trainable_params': count(self, trainable_only=True),
            'cnn_params': count(self.cnn_branch),
            'vit_params': count(self.vit_branch),
            'fusion_params': count(self.fusion_module),
            'classifier_params': count(self.classifier),
            'aux_classifier_params': (
                count(self.cnn_aux_classifier) + count(self.vit_aux_classifier)
            ),
        }