File size: 5,993 Bytes
7cd02bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Adaptive Fusion Module for Hybrid Food Classifier
Combines CNN and ViT features using cross-attention mechanism
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class AdaptiveFusionModule(nn.Module):
    """Adaptive fusion module with cross-attention.

    Fuses CNN and ViT features along two paths:

    * Spatial path: bidirectional cross-attention (CNN tokens attend to
      ViT tokens and vice versa) with residual connections, then
      self-attention over the concatenated token sequence.
    * Global path: an MLP fusion of the concatenated global vectors,
      combined with a per-sample learned softmax weighting of the two
      global features, then projected to ``hidden_dim``.
    """

    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.2,
        spatial_size: int = 7  # side length of the CNN feature map (e.g. 7 for 7x7)
    ):
        """
        Args:
            feature_dim: Channel dimension shared by CNN and ViT features.
                Must be divisible by ``num_heads``.
            hidden_dim: Dimension of the fused global output (see
                :meth:`get_output_dim`).
            num_heads: Number of attention heads in all attention layers.
            dropout: Dropout probability used throughout the module.
            spatial_size: Expected CNN feature-map side length. NOTE(review):
                stored but not read in ``forward`` — the flatten there is
                size-agnostic, so this is informational only.
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.spatial_size = spatial_size

        # Cross-attention: CNN tokens query ViT tokens.
        self.cnn_to_vit_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Cross-attention: ViT tokens query CNN tokens.
        self.vit_to_cnn_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Self-attention over the concatenated (CNN + ViT) token sequence.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Per-branch token projections applied before cross-attention.
        self.cnn_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        self.vit_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # MLP fusion of the concatenated global vectors -> feature_dim.
        self.global_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Predicts per-sample softmax weights (sum to 1) over the two branches.
        self.adaptive_weight = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1)
        )

        # Final projection of the fused global vector: feature_dim -> hidden_dim.
        self.final_proj = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        cnn_spatial: torch.Tensor,  # [B, feature_dim, H, W] (typically 7x7)
        cnn_global: torch.Tensor,   # [B, feature_dim]
        vit_spatial: torch.Tensor,  # [B, num_patches, feature_dim]
        vit_global: torch.Tensor    # [B, feature_dim]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fuse CNN and ViT features.

        Args:
            cnn_spatial: CNN spatial features [B, feature_dim, H, W].
            cnn_global: CNN global features [B, feature_dim].
            vit_spatial: ViT patch features [B, num_patches, feature_dim].
            vit_global: ViT CLS token features [B, feature_dim].

        Returns:
            fused_spatial: Fused spatial tokens
                [B, H*W + num_patches, feature_dim].
            fused_global: Fused global features [B, hidden_dim]
                (projected by ``final_proj``; see ``get_output_dim``).
        """
        # Reshape CNN spatial map to a token sequence: [B, H*W, feature_dim].
        cnn_spatial_seq = cnn_spatial.flatten(2).transpose(1, 2)

        # Project each branch's tokens into a common space.
        cnn_spatial_proj = self.cnn_spatial_proj(cnn_spatial_seq)  # [B, H*W, feature_dim]
        vit_spatial_proj = self.vit_spatial_proj(vit_spatial)      # [B, num_patches, feature_dim]

        # Cross-attention: CNN attends to ViT.
        cnn_attended, _ = self.cnn_to_vit_attention(
            query=cnn_spatial_proj,
            key=vit_spatial_proj,
            value=vit_spatial_proj
        )  # [B, H*W, feature_dim]

        # Cross-attention: ViT attends to CNN.
        vit_attended, _ = self.vit_to_cnn_attention(
            query=vit_spatial_proj,
            key=cnn_spatial_proj,
            value=cnn_spatial_proj
        )  # [B, num_patches, feature_dim]

        # Concatenate both attended streams (with residual connections)
        # along the token dimension.
        combined_spatial = torch.cat([
            cnn_attended + cnn_spatial_proj,  # residual connection
            vit_attended + vit_spatial_proj   # residual connection
        ], dim=1)  # [B, H*W + num_patches, feature_dim]

        # Self-attention over the combined token sequence.
        fused_spatial, _ = self.self_attention(
            query=combined_spatial,
            key=combined_spatial,
            value=combined_spatial
        )  # [B, H*W + num_patches, feature_dim]

        # Global fusion: MLP over the concatenated global vectors.
        global_concat = torch.cat([cnn_global, vit_global], dim=-1)  # [B, feature_dim*2]
        fused_global_base = self.global_fusion(global_concat)  # [B, feature_dim]

        # Per-sample adaptive weights over the two branches (softmaxed, sum to 1).
        weights = self.adaptive_weight(global_concat)  # [B, 2]
        cnn_weight = weights[:, 0:1]  # [B, 1]
        vit_weight = weights[:, 1:2]  # [B, 1]

        # Average of the weighted branch mix and the MLP-fused vector
        # (the weighted terms together sum to one "unit", hence /2).
        fused_global = (cnn_weight * cnn_global +
                        vit_weight * vit_global +
                        fused_global_base) / 2  # [B, feature_dim]

        # Project the fused global vector to the module's output dimension.
        fused_global = self.final_proj(fused_global)  # [B, hidden_dim]

        return fused_spatial, fused_global

    def get_output_dim(self) -> int:
        """Return the dimension of ``fused_global`` (i.e. ``hidden_dim``)."""
        return self.hidden_dim