codealchemist01 committed on
Commit
dbb41ac
·
verified ·
1 Parent(s): 7049b0e

Upload models/hybrid_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/hybrid_model.py +211 -0
models/hybrid_model.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid CNN-ViT Food Classifier
3
+ Combines ResNet50 and DeiT-Base with adaptive fusion
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Dict, Any, Optional
9
+
10
+ from .cnn_branch import CNNBranch
11
+ from .vit_branch import ViTBranch
12
+ from .fusion_module import AdaptiveFusionModule
13
+
14
class HybridFoodClassifier(nn.Module):
    """Hybrid CNN-ViT model for food classification.

    Runs a CNN branch (ResNet50) and a ViT branch (DeiT-Base) in parallel,
    fuses their global/spatial features with an adaptive fusion module, and
    classifies the fused global representation. During training, auxiliary
    per-branch classifiers provide extra supervision for stability.
    """

    def __init__(
        self,
        num_classes: int = 101,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        dropout: float = 0.2,
        pretrained: bool = True,
        freeze_early_layers: bool = True
    ):
        """
        Args:
            num_classes: Number of output classes (101 matches Food-101).
            feature_dim: Dimensionality of each branch's feature vectors.
            hidden_dim: Width of the fusion output / classifier hidden layer.
            dropout: Dropout probability used in all classifier heads.
            pretrained: Whether the branch backbones load pretrained weights.
            freeze_early_layers: Whether the branches freeze their early layers.
        """
        super().__init__()  # zero-arg form; identical to the legacy 2-arg call

        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # CNN Branch (ResNet50)
        self.cnn_branch = CNNBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )

        # ViT Branch (DeiT-Base)
        self.vit_branch = ViTBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )

        # Fusion Module: combines (spatial, global) features from both branches
        self.fusion_module = AdaptiveFusionModule(
            feature_dim=feature_dim,
            hidden_dim=hidden_dim,
            dropout=dropout
        )

        # Main classification head over the fused global feature
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Auxiliary classifiers give each branch its own loss signal during
        # training, which stabilizes optimization of the two backbones.
        self.cnn_aux_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self.vit_aux_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialize the Linear layers of all three classifier heads.

        Backbone weights are left untouched (they may be pretrained).
        """
        for head in (self.classifier, self.cnn_aux_classifier, self.vit_aux_classifier):
            for layer in head:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False,
        use_aux_loss: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor [B, 3, H, W]
            return_features: Whether to return intermediate features
            use_aux_loss: Whether to compute auxiliary logits (training only)

        Returns:
            Dictionary with 'logits' [B, num_classes]; additionally
            'cnn_aux_logits'/'vit_aux_logits' when use_aux_loss is set and the
            module is in training mode, and the six intermediate feature
            tensors when return_features is set.
        """
        # Per-branch features: each branch returns (spatial, global)
        cnn_spatial, cnn_global = self.cnn_branch(x)
        vit_spatial, vit_global = self.vit_branch(x)

        # Fuse the two branches into a single (spatial, global) pair
        fused_spatial, fused_global = self.fusion_module(
            cnn_spatial, cnn_global, vit_spatial, vit_global
        )

        # Main classification on the fused global feature
        logits = self.classifier(fused_global)

        output = {'logits': logits}

        # Auxiliary logits only while training (they only serve the aux loss)
        if use_aux_loss and self.training:
            output.update({
                'cnn_aux_logits': self.cnn_aux_classifier(cnn_global),
                'vit_aux_logits': self.vit_aux_classifier(vit_global)
            })

        if return_features:
            output.update({
                'cnn_spatial': cnn_spatial,
                'cnn_global': cnn_global,
                'vit_spatial': vit_spatial,
                'vit_global': vit_global,
                'fused_spatial': fused_spatial,
                'fused_global': fused_global
            })

        return output

    def get_attention_maps(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute coarse per-branch saliency maps for visualization.

        Uses channel-mean activation magnitude as an attention proxy for both
        branches and upsamples each map to the input's spatial resolution
        (previously hardcoded to 224x224 / a 14x14 patch grid; now derived
        from the actual input size and token count, so any square patch grid
        and input resolution work).

        NOTE(review): dropout in the heads is still active if the module is
        in train mode — callers should .eval() first for deterministic maps.

        Args:
            x: Input tensor [B, 3, H, W]

        Returns:
            {'cnn_attention': [B, 1, H, W], 'vit_attention': [B, 1, H, W]}

        Raises:
            ValueError: If the ViT patch tokens do not form a square grid.
        """
        with torch.no_grad():
            output = self.forward(x, return_features=True, use_aux_loss=False)
            target_size = x.shape[-2:]  # upsample maps to the input resolution

            # CNN saliency: mean over channels of the spatial feature map
            cnn_spatial = output['cnn_spatial']  # [B, feature_dim, h, w]
            cnn_attention = cnn_spatial.mean(dim=1, keepdim=True)  # [B, 1, h, w]
            cnn_attention = F.interpolate(
                cnn_attention,
                size=target_size,
                mode='bilinear',
                align_corners=False
            )  # [B, 1, H, W]

            # ViT saliency: mean feature magnitude per patch token.
            # Assumes token 0 is the CLS token — TODO confirm; distilled DeiT
            # variants carry an extra distillation token that would also need
            # stripping (and would trip the square-grid check below).
            vit_spatial = output['vit_spatial']  # [B, tokens, feature_dim]
            vit_patches = vit_spatial[:, 1:]  # drop leading CLS token
            num_patches = vit_patches.shape[1]
            grid = int(round(num_patches ** 0.5))
            if grid * grid != num_patches:
                # Explicit error instead of the opaque .view() failure the
                # hardcoded 14x14 reshape used to produce.
                raise ValueError(
                    f"Cannot arrange {num_patches} patch tokens into a square grid"
                )
            vit_attention = vit_patches.mean(dim=-1)  # [B, num_patches]
            vit_attention = vit_attention.view(-1, grid, grid).unsqueeze(1)  # [B, 1, g, g]
            vit_attention = F.interpolate(
                vit_attention,
                size=target_size,
                mode='bilinear',
                align_corners=False
            )  # [B, 1, H, W]

            return {
                'cnn_attention': cnn_attention,
                'vit_attention': vit_attention
            }

    def freeze_backbone(self):
        """Freeze both backbone networks (their params stop receiving grads)."""
        for param in self.cnn_branch.backbone.parameters():
            param.requires_grad = False
        for param in self.vit_branch.vit.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        """Unfreeze both backbone networks."""
        for param in self.cnn_branch.backbone.parameters():
            param.requires_grad = True
        for param in self.vit_branch.vit.parameters():
            param.requires_grad = True

    def get_model_size(self) -> Dict[str, int]:
        """Return parameter counts for the whole model and its sub-modules.

        Note: 'total_params' covers everything (including aux classifiers),
        so it exceeds the sum of the four listed component counts.
        """
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'total_params': total_params,
            'trainable_params': trainable_params,
            'cnn_params': sum(p.numel() for p in self.cnn_branch.parameters()),
            'vit_params': sum(p.numel() for p in self.vit_branch.parameters()),
            'fusion_params': sum(p.numel() for p in self.fusion_module.parameters()),
            'classifier_params': sum(p.numel() for p in self.classifier.parameters())
        }