codealchemist01 committed
Commit 7cd02bc · verified · 1 Parent(s): 28b51fd

Upload models/fusion_module.py with huggingface_hub

Files changed (1)
models/fusion_module.py +173 -0
models/fusion_module.py ADDED
@@ -0,0 +1,173 @@
"""
Adaptive Fusion Module for Hybrid Food Classifier

Combines CNN and ViT features using a cross-attention mechanism.
"""
import torch
import torch.nn as nn
from typing import Tuple

class AdaptiveFusionModule(nn.Module):
    """Adaptive fusion module with cross-attention between CNN and ViT branches."""

    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.2,
        spatial_size: int = 7,  # 7x7 grid of CNN spatial features
    ):
        super().__init__()

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.spatial_size = spatial_size

        # Cross-attention: CNN tokens attend to ViT tokens
        self.cnn_to_vit_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Cross-attention: ViT tokens attend to CNN tokens
        self.vit_to_cnn_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Self-attention over the fused token sequence
        self.self_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Feature projection layers
        self.cnn_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.vit_spatial_proj = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Global feature fusion
        self.global_fusion = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Adaptive weighting: predicts per-sample mixing weights for the two branches
        self.adaptive_weight = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1),
        )

        # Final projection
        self.final_proj = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        cnn_spatial: torch.Tensor,   # [B, feature_dim, 7, 7]
        cnn_global: torch.Tensor,    # [B, feature_dim]
        vit_spatial: torch.Tensor,   # [B, num_patches, feature_dim]
        vit_global: torch.Tensor,    # [B, feature_dim]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            cnn_spatial: CNN spatial features [B, feature_dim, 7, 7]
            cnn_global: CNN global features [B, feature_dim]
            vit_spatial: ViT patch features [B, num_patches, feature_dim]
            vit_global: ViT CLS token features [B, feature_dim]

        Returns:
            fused_spatial: Fused spatial features [B, 49 + num_patches, feature_dim]
            fused_global: Fused global features [B, hidden_dim]
        """
        # Reshape CNN spatial features to sequence format
        cnn_spatial_seq = cnn_spatial.flatten(2).transpose(1, 2)  # [B, 49, feature_dim]

        # Project spatial features
        cnn_spatial_proj = self.cnn_spatial_proj(cnn_spatial_seq)  # [B, 49, feature_dim]
        vit_spatial_proj = self.vit_spatial_proj(vit_spatial)  # [B, num_patches, feature_dim]; 196 for ViT-B/16 @ 224

        # Cross-attention: CNN attends to ViT
        cnn_attended, _ = self.cnn_to_vit_attention(
            query=cnn_spatial_proj,
            key=vit_spatial_proj,
            value=vit_spatial_proj,
        )  # [B, 49, feature_dim]

        # Cross-attention: ViT attends to CNN
        vit_attended, _ = self.vit_to_cnn_attention(
            query=vit_spatial_proj,
            key=cnn_spatial_proj,
            value=cnn_spatial_proj,
        )  # [B, num_patches, feature_dim]

        # Concatenate the two attended streams, each with a residual connection
        combined_spatial = torch.cat([
            cnn_attended + cnn_spatial_proj,
            vit_attended + vit_spatial_proj,
        ], dim=1)  # [B, 49 + num_patches, feature_dim]; 245 for 196 patches

        # Self-attention on the combined sequence
        fused_spatial, _ = self.self_attention(
            query=combined_spatial,
            key=combined_spatial,
            value=combined_spatial,
        )  # [B, 49 + num_patches, feature_dim]

        # Global feature fusion
        global_concat = torch.cat([cnn_global, vit_global], dim=-1)  # [B, feature_dim * 2]
        fused_global_base = self.global_fusion(global_concat)  # [B, feature_dim]

        # Adaptive weighting for global features (weights sum to 1 via softmax)
        weights = self.adaptive_weight(global_concat)  # [B, 2]
        cnn_weight = weights[:, 0:1]  # [B, 1]
        vit_weight = weights[:, 1:2]  # [B, 1]

        # Average the adaptively weighted branch mix with the learned fusion
        fused_global = (cnn_weight * cnn_global +
                        vit_weight * vit_global +
                        fused_global_base) / 2  # [B, feature_dim]

        # Final projection
        fused_global = self.final_proj(fused_global)  # [B, hidden_dim]

        return fused_spatial, fused_global

    def get_output_dim(self) -> int:
        """Return the output feature dimension (hidden_dim) of the fused global vector."""
        return self.hidden_dim
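
For reference, a minimal shape check for the module above. This is a sketch, not part of the commit: it assumes the file is importable as models.fusion_module, a CNN branch pooled to a 7x7 grid, and a ViT-B/16 backbone at 224x224 resolution (196 patch tokens).

import torch

from models.fusion_module import AdaptiveFusionModule  # assumed import path

fusion = AdaptiveFusionModule(feature_dim=768, hidden_dim=512, num_heads=8)
fusion.eval()

B = 2
cnn_spatial = torch.randn(B, 768, 7, 7)  # hypothetical CNN feature map (7x7 grid)
cnn_global = torch.randn(B, 768)         # hypothetical pooled CNN features
vit_spatial = torch.randn(B, 196, 768)   # hypothetical patch tokens (196 assumes ViT-B/16 @ 224)
vit_global = torch.randn(B, 768)         # hypothetical CLS token features

with torch.no_grad():
    fused_spatial, fused_global = fusion(cnn_spatial, cnn_global, vit_spatial, vit_global)

print(fused_spatial.shape)  # torch.Size([2, 245, 768]): 49 CNN tokens + 196 ViT tokens
print(fused_global.shape)   # torch.Size([2, 512]): hidden_dim, matches get_output_dim()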