File size: 5,993 Bytes
7cd02bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
"""
Adaptive Fusion Module for Hybrid Food Classifier
Combines CNN and ViT features using cross-attention mechanism
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class AdaptiveFusionModule(nn.Module):
    """Fuse CNN and ViT features via bidirectional cross-attention.

    The two spatial streams attend to each other (CNN->ViT and ViT->CNN),
    are concatenated with residual connections, and refined by a final
    self-attention pass.  The two global vectors are fused by an MLP plus
    an adaptively weighted sum, then projected to ``hidden_dim``.

    Args:
        feature_dim: Channel dimension shared by both backbones.
        hidden_dim: Width of the fusion MLPs and of the global output.
        num_heads: Number of attention heads (must divide ``feature_dim``).
        dropout: Dropout probability used throughout.
        spatial_size: Side length of the CNN feature map (e.g. 7 for 7x7).
            Stored for reference; the forward pass infers the actual size
            from the input tensor.
    """

    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.2,
        spatial_size: int = 7,  # 7x7 for CNN spatial features
    ):
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.spatial_size = spatial_size

        # Cross-attention: CNN tokens query ViT tokens.
        self.cnn_to_vit_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # Cross-attention: ViT tokens query CNN tokens.
        self.vit_to_cnn_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # Self-attention over the concatenated (CNN + ViT) token sequence.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Per-stream projections applied before cross-attention.
        self.cnn_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.vit_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # MLP fusing the concatenated global vectors back to feature_dim.
        self.global_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Predicts softmax weights balancing the two global streams.
        self.adaptive_weight = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1),
        )

        # Maps the fused global vector to the module's output width.
        self.final_proj = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        cnn_spatial: torch.Tensor,   # [B, feature_dim, H, W]
        cnn_global: torch.Tensor,    # [B, feature_dim]
        vit_spatial: torch.Tensor,   # [B, num_patches, feature_dim]
        vit_global: torch.Tensor,    # [B, feature_dim]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fuse spatial and global features from the two backbones.

        Args:
            cnn_spatial: CNN spatial features [B, feature_dim, H, W]
                (H = W = spatial_size in the default configuration).
            cnn_global: CNN global features [B, feature_dim].
            vit_spatial: ViT patch features [B, num_patches, feature_dim].
            vit_global: ViT CLS token features [B, feature_dim].

        Returns:
            fused_spatial: Fused token sequence
                [B, H*W + num_patches, feature_dim].
            fused_global: Fused global features [B, hidden_dim]
                (projected by ``final_proj``; see ``get_output_dim``).
        """
        # Reshape CNN spatial map to a token sequence: [B, H*W, feature_dim].
        cnn_spatial_seq = cnn_spatial.flatten(2).transpose(1, 2)

        # Project both spatial streams into a common space.
        cnn_spatial_proj = self.cnn_spatial_proj(cnn_spatial_seq)  # [B, H*W, feature_dim]
        vit_spatial_proj = self.vit_spatial_proj(vit_spatial)      # [B, num_patches, feature_dim]

        # Cross-attention: CNN tokens attend to ViT tokens.
        cnn_attended, _ = self.cnn_to_vit_attention(
            query=cnn_spatial_proj,
            key=vit_spatial_proj,
            value=vit_spatial_proj,
        )  # [B, H*W, feature_dim]

        # Cross-attention: ViT tokens attend to CNN tokens.
        vit_attended, _ = self.vit_to_cnn_attention(
            query=vit_spatial_proj,
            key=cnn_spatial_proj,
            value=cnn_spatial_proj,
        )  # [B, num_patches, feature_dim]

        # Concatenate the two attended streams, each with a residual
        # connection back to its projected input.
        combined_spatial = torch.cat(
            [
                cnn_attended + cnn_spatial_proj,
                vit_attended + vit_spatial_proj,
            ],
            dim=1,
        )  # [B, H*W + num_patches, feature_dim]

        # Refine the combined sequence with self-attention.
        fused_spatial, _ = self.self_attention(
            query=combined_spatial,
            key=combined_spatial,
            value=combined_spatial,
        )  # [B, H*W + num_patches, feature_dim]

        # Global fusion: MLP over the concatenated global vectors.
        global_concat = torch.cat([cnn_global, vit_global], dim=-1)  # [B, feature_dim*2]
        fused_global_base = self.global_fusion(global_concat)        # [B, feature_dim]

        # Adaptive per-sample weighting of the two global streams.
        weights = self.adaptive_weight(global_concat)  # [B, 2]
        cnn_weight = weights[:, 0:1]  # [B, 1]
        vit_weight = weights[:, 1:2]  # [B, 1]

        # Average the weighted sum with the MLP-fused baseline.
        fused_global = (
            cnn_weight * cnn_global
            + vit_weight * vit_global
            + fused_global_base
        ) / 2  # [B, feature_dim]

        # Project to the module's output width.
        fused_global = self.final_proj(fused_global)  # [B, hidden_dim]

        return fused_spatial, fused_global

    def get_output_dim(self) -> int:
        """Return the width of ``fused_global`` (i.e. ``hidden_dim``)."""
        return self.hidden_dim