""" Adaptive Fusion Module for Hybrid Food Classifier Combines CNN and ViT features using cross-attention mechanism """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Tuple class AdaptiveFusionModule(nn.Module): """Adaptive fusion module with cross-attention""" def __init__( self, feature_dim: int = 768, hidden_dim: int = 512, num_heads: int = 8, dropout: float = 0.2, spatial_size: int = 7 # 7x7 for CNN spatial features ): super(AdaptiveFusionModule, self).__init__() self.feature_dim = feature_dim self.hidden_dim = hidden_dim self.num_heads = num_heads self.spatial_size = spatial_size # Cross-attention for CNN -> ViT self.cnn_to_vit_attention = nn.MultiheadAttention( embed_dim=feature_dim, num_heads=num_heads, dropout=dropout, batch_first=True ) # Cross-attention for ViT -> CNN self.vit_to_cnn_attention = nn.MultiheadAttention( embed_dim=feature_dim, num_heads=num_heads, dropout=dropout, batch_first=True ) # Self-attention for fused features self.self_attention = nn.MultiheadAttention( embed_dim=feature_dim, num_heads=num_heads, dropout=dropout, batch_first=True ) # Feature projection layers self.cnn_spatial_proj = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.LayerNorm(feature_dim), nn.GELU(), nn.Dropout(dropout) ) self.vit_spatial_proj = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.LayerNorm(feature_dim), nn.GELU(), nn.Dropout(dropout) ) # Global feature fusion self.global_fusion = nn.Sequential( nn.Linear(feature_dim * 2, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, feature_dim), nn.LayerNorm(feature_dim), nn.GELU(), nn.Dropout(dropout) ) # Adaptive weighting self.adaptive_weight = nn.Sequential( nn.Linear(feature_dim * 2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 2), nn.Softmax(dim=-1) ) # Final projection self.final_proj = nn.Sequential( nn.Linear(feature_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout) ) def forward( self, cnn_spatial: torch.Tensor, # [B, feature_dim, 7, 7] cnn_global: torch.Tensor, # [B, feature_dim] vit_spatial: torch.Tensor, # [B, num_patches, feature_dim] vit_global: torch.Tensor # [B, feature_dim] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass Args: cnn_spatial: CNN spatial features [B, feature_dim, 7, 7] cnn_global: CNN global features [B, feature_dim] vit_spatial: ViT patch features [B, num_patches, feature_dim] vit_global: ViT CLS token features [B, feature_dim] Returns: fused_spatial: Fused spatial features [B, seq_len, feature_dim] fused_global: Fused global features [B, feature_dim] """ batch_size = cnn_spatial.size(0) # Reshape CNN spatial features to sequence format cnn_spatial_seq = cnn_spatial.flatten(2).transpose(1, 2) # [B, 49, feature_dim] # Project spatial features cnn_spatial_proj = self.cnn_spatial_proj(cnn_spatial_seq) # [B, 49, feature_dim] vit_spatial_proj = self.vit_spatial_proj(vit_spatial) # [B, 196, feature_dim] # Cross-attention: CNN attends to ViT cnn_attended, _ = self.cnn_to_vit_attention( query=cnn_spatial_proj, key=vit_spatial_proj, value=vit_spatial_proj ) # [B, 49, feature_dim] # Cross-attention: ViT attends to CNN vit_attended, _ = self.vit_to_cnn_attention( query=vit_spatial_proj, key=cnn_spatial_proj, value=cnn_spatial_proj ) # [B, 196, feature_dim] # Combine attended features # Concatenate CNN and ViT spatial features combined_spatial = torch.cat([ cnn_attended + cnn_spatial_proj, # Residual connection vit_attended + vit_spatial_proj # Residual connection ], dim=1) # [B, 245, 
feature_dim] # Self-attention on combined features fused_spatial, _ = self.self_attention( query=combined_spatial, key=combined_spatial, value=combined_spatial ) # [B, 245, feature_dim] # Global feature fusion global_concat = torch.cat([cnn_global, vit_global], dim=-1) # [B, feature_dim*2] fused_global_base = self.global_fusion(global_concat) # [B, feature_dim] # Adaptive weighting for global features weights = self.adaptive_weight(global_concat) # [B, 2] cnn_weight = weights[:, 0:1] # [B, 1] vit_weight = weights[:, 1:2] # [B, 1] # Weighted combination fused_global = (cnn_weight * cnn_global + vit_weight * vit_global + fused_global_base) / 2 # [B, feature_dim] # Final projection fused_global = self.final_proj(fused_global) # [B, hidden_dim] return fused_spatial, fused_global def get_output_dim(self) -> int: """Get output feature dimension""" return self.hidden_dim
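

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module API): instantiates the fusion
# module with its default sizes and runs random tensors shaped as the
# docstrings describe (a 7x7 CNN feature map and a 196-patch ViT at
# feature_dim=768). The batch size and tensor values are illustrative
# assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    fusion = AdaptiveFusionModule(feature_dim=768, hidden_dim=512, num_heads=8)
    fusion.eval()

    batch_size = 2
    cnn_spatial = torch.randn(batch_size, 768, 7, 7)   # CNN backbone feature map
    cnn_global = torch.randn(batch_size, 768)          # pooled CNN features
    vit_spatial = torch.randn(batch_size, 196, 768)    # ViT patch tokens
    vit_global = torch.randn(batch_size, 768)          # ViT CLS token

    with torch.no_grad():
        fused_spatial, fused_global = fusion(
            cnn_spatial, cnn_global, vit_spatial, vit_global
        )

    print(fused_spatial.shape)  # torch.Size([2, 245, 768]) -> 49 CNN + 196 ViT tokens
    print(fused_global.shape)   # torch.Size([2, 512]) -> matches get_output_dim()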