"""
Adaptive Fusion Module for Hybrid Food Classifier
Combines CNN and ViT features using a cross-attention mechanism
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class AdaptiveFusionModule(nn.Module):
"""Adaptive fusion module with cross-attention"""
def __init__(
self,
feature_dim: int = 768,
hidden_dim: int = 512,
num_heads: int = 8,
dropout: float = 0.2,
spatial_size: int = 7 # 7x7 for CNN spatial features
):
super(AdaptiveFusionModule, self).__init__()
self.feature_dim = feature_dim
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.spatial_size = spatial_size
# Cross-attention for CNN -> ViT
self.cnn_to_vit_attention = nn.MultiheadAttention(
embed_dim=feature_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Cross-attention for ViT -> CNN
self.vit_to_cnn_attention = nn.MultiheadAttention(
embed_dim=feature_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Self-attention for fused features
self.self_attention = nn.MultiheadAttention(
embed_dim=feature_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Feature projection layers
self.cnn_spatial_proj = nn.Sequential(
nn.Linear(feature_dim, feature_dim),
nn.LayerNorm(feature_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.vit_spatial_proj = nn.Sequential(
nn.Linear(feature_dim, feature_dim),
nn.LayerNorm(feature_dim),
nn.GELU(),
nn.Dropout(dropout)
)
# Global feature fusion
self.global_fusion = nn.Sequential(
nn.Linear(feature_dim * 2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, feature_dim),
nn.LayerNorm(feature_dim),
nn.GELU(),
nn.Dropout(dropout)
)
# Adaptive weighting
self.adaptive_weight = nn.Sequential(
nn.Linear(feature_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2),
nn.Softmax(dim=-1)
)
# Final projection
self.final_proj = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout)
)
def forward(
self,
cnn_spatial: torch.Tensor, # [B, feature_dim, 7, 7]
cnn_global: torch.Tensor, # [B, feature_dim]
vit_spatial: torch.Tensor, # [B, num_patches, feature_dim]
vit_global: torch.Tensor # [B, feature_dim]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass
Args:
cnn_spatial: CNN spatial features [B, feature_dim, 7, 7]
cnn_global: CNN global features [B, feature_dim]
vit_spatial: ViT patch features [B, num_patches, feature_dim]
vit_global: ViT CLS token features [B, feature_dim]
Returns:
            fused_spatial: Fused spatial features [B, 49 + num_patches, feature_dim]
            fused_global: Fused global features [B, hidden_dim]
"""
batch_size = cnn_spatial.size(0)
# Reshape CNN spatial features to sequence format
cnn_spatial_seq = cnn_spatial.flatten(2).transpose(1, 2) # [B, 49, feature_dim]
# Project spatial features
cnn_spatial_proj = self.cnn_spatial_proj(cnn_spatial_seq) # [B, 49, feature_dim]
vit_spatial_proj = self.vit_spatial_proj(vit_spatial) # [B, 196, feature_dim]
# Cross-attention: CNN attends to ViT
cnn_attended, _ = self.cnn_to_vit_attention(
query=cnn_spatial_proj,
key=vit_spatial_proj,
value=vit_spatial_proj
) # [B, 49, feature_dim]
# Cross-attention: ViT attends to CNN
vit_attended, _ = self.vit_to_cnn_attention(
query=vit_spatial_proj,
key=cnn_spatial_proj,
value=cnn_spatial_proj
) # [B, 196, feature_dim]
# Combine attended features
# Concatenate CNN and ViT spatial features
combined_spatial = torch.cat([
cnn_attended + cnn_spatial_proj, # Residual connection
vit_attended + vit_spatial_proj # Residual connection
], dim=1) # [B, 245, feature_dim]
# Self-attention on combined features
fused_spatial, _ = self.self_attention(
query=combined_spatial,
key=combined_spatial,
value=combined_spatial
) # [B, 245, feature_dim]
# Global feature fusion
global_concat = torch.cat([cnn_global, vit_global], dim=-1) # [B, feature_dim*2]
fused_global_base = self.global_fusion(global_concat) # [B, feature_dim]
# Adaptive weighting for global features
weights = self.adaptive_weight(global_concat) # [B, 2]
cnn_weight = weights[:, 0:1] # [B, 1]
vit_weight = weights[:, 1:2] # [B, 1]
# Weighted combination
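        # cnn_weight + vit_weight = 1 (softmax), so dividing the sum by 2 below
        # averages the weighted backbone globals with the MLP-fused features.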
fused_global = (cnn_weight * cnn_global +
vit_weight * vit_global +
fused_global_base) / 2 # [B, feature_dim]
# Final projection
fused_global = self.final_proj(fused_global) # [B, hidden_dim]
return fused_spatial, fused_global
def get_output_dim(self) -> int:
"""Get output feature dimension"""
return self.hidden_dim
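

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch: runs the module on random tensors, assuming the
# default 7x7 CNN grid and a 196-patch ViT (e.g. 224x224 input with 16x16
# patches). Shapes follow the forward() docstring above; the values here are
# illustrative, not tied to any particular backbone or checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    fusion = AdaptiveFusionModule(feature_dim=768, hidden_dim=512)
    fusion.eval()  # disable dropout for a deterministic check

    batch = 2
    cnn_spatial = torch.randn(batch, 768, 7, 7)   # CNN backbone feature map
    cnn_global = torch.randn(batch, 768)          # pooled CNN features
    vit_spatial = torch.randn(batch, 196, 768)    # ViT patch tokens
    vit_global = torch.randn(batch, 768)          # ViT CLS token features

    with torch.no_grad():
        fused_spatial, fused_global = fusion(
            cnn_spatial, cnn_global, vit_spatial, vit_global
        )

    print(fused_spatial.shape)  # torch.Size([2, 245, 768]) -> 49 + 196 tokens
    print(fused_global.shape)   # torch.Size([2, 512])      -> hidden_dim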