File size: 7,460 Bytes
dbb41ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
"""
Hybrid CNN-ViT Food Classifier
Combines ResNet50 and DeiT-Base with adaptive fusion
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Optional
from .cnn_branch import CNNBranch
from .vit_branch import ViTBranch
from .fusion_module import AdaptiveFusionModule
class HybridFoodClassifier(nn.Module):
    """Hybrid CNN-ViT model for food classification.

    The same input image is fed to a CNN branch and a ViT branch. Each branch
    returns a (spatial, global) feature pair. ``AdaptiveFusionModule`` combines
    the four tensors into a fused (spatial, global) pair. The fused global
    feature is classified by an MLP head. Two auxiliary heads, one per branch,
    classify each branch's global feature on its own. They are only evaluated
    in training mode, for auxiliary losses.

    Args:
        num_classes: Number of output classes for all three heads.
        feature_dim: Feature width passed to both branches and to the fusion
            module. It is also the input width of the auxiliary heads.
        hidden_dim: Width of the fused global feature consumed by the main
            classification head. The head's hidden layer is ``hidden_dim // 2``.
        dropout: Dropout probability forwarded to the branches and the fusion
            module, and used inside the heads.
        pretrained: Forwarded to both branches.
        freeze_early_layers: Forwarded to both branches.
    """

    def __init__(
        self,
        num_classes: int = 101,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        dropout: float = 0.2,
        pretrained: bool = True,
        freeze_early_layers: bool = True
    ):
        super().__init__()
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        head_dim = hidden_dim // 2

        # CNN branch (ResNet50 according to the module docstring).
        self.cnn_branch = CNNBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )
        # ViT branch (DeiT-Base according to the module docstring).
        self.vit_branch = ViTBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )
        # Fuses the branch outputs; its global output must be hidden_dim wide,
        # because that is the classifier's input width below.
        self.fusion_module = AdaptiveFusionModule(
            feature_dim=feature_dim,
            hidden_dim=hidden_dim,
            dropout=dropout
        )
        # Main classification head on the fused global feature.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_dim, num_classes)
        )
        # Auxiliary per-branch heads for training stability. The architecture
        # is identical but the weights are separate. Construction order is
        # kept (classifier, cnn aux, vit aux) so RNG consumption is unchanged.
        self.cnn_aux_classifier = self._make_aux_classifier(
            feature_dim, head_dim, num_classes, dropout
        )
        self.vit_aux_classifier = self._make_aux_classifier(
            feature_dim, head_dim, num_classes, dropout
        )

        self._initialize_weights()

    @staticmethod
    def _make_aux_classifier(
        in_dim: int, mid_dim: int, num_classes: int, dropout: float
    ) -> nn.Sequential:
        """Build one auxiliary head: Linear -> ReLU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mid_dim, num_classes)
        )

    def _initialize_weights(self):
        """Xavier-initialise every Linear layer in the three heads and zero its bias.

        Backbone and fusion-module weights are left as their own modules set them.
        """
        for head in (self.classifier, self.cnn_aux_classifier, self.vit_aux_classifier):
            for module in head.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False,
        use_aux_loss: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Run both branches, fuse them and classify.

        Args:
            x: Input batch, expected as [B, 3, H, W].
            return_features: If True, also return the branch and fused features.
            use_aux_loss: If True *and* the module is in training mode, also
                return the auxiliary heads' logits.

        Returns:
            A dict that always contains ``'logits'`` (main head output).

            - ``'cnn_aux_logits'`` and ``'vit_aux_logits'`` are present only
              when ``use_aux_loss and self.training``.
            - ``'cnn_spatial'``, ``'cnn_global'``, ``'vit_spatial'``,
              ``'vit_global'``, ``'fused_spatial'`` and ``'fused_global'`` are
              present only when ``return_features``.
        """
        cnn_spatial, cnn_global = self.cnn_branch(x)
        vit_spatial, vit_global = self.vit_branch(x)

        fused_spatial, fused_global = self.fusion_module(
            cnn_spatial, cnn_global, vit_spatial, vit_global
        )

        output = {'logits': self.classifier(fused_global)}

        # Auxiliary logits are only needed for the training loss.
        if use_aux_loss and self.training:
            output['cnn_aux_logits'] = self.cnn_aux_classifier(cnn_global)
            output['vit_aux_logits'] = self.vit_aux_classifier(vit_global)

        if return_features:
            output.update({
                'cnn_spatial': cnn_spatial,
                'cnn_global': cnn_global,
                'vit_spatial': vit_spatial,
                'vit_global': vit_global,
                'fused_spatial': fused_spatial,
                'fused_global': fused_global
            })

        return output

    @staticmethod
    def _upsample_map(saliency: torch.Tensor, size: tuple) -> torch.Tensor:
        """Bilinearly resize a [B, 1, h, w] map to ``size`` (align_corners=False)."""
        return F.interpolate(saliency, size=size, mode='bilinear', align_corners=False)

    def get_attention_maps(
        self,
        x: torch.Tensor,
        output_size: Optional[tuple] = None
    ) -> Dict[str, torch.Tensor]:
        """Return coarse per-branch saliency maps for visualisation.

        These are feature-magnitude heuristics, not learned attention weights:

        - CNN map: the channel-wise mean of the CNN spatial feature map.
        - ViT map: the per-token feature mean of the ViT spatial tokens
          (excluding the first token), reshaped to a square patch grid.

        Args:
            x: Input batch, expected as [B, 3, H, W].
            output_size: (height, width) to upsample both maps to. Defaults to
                the input's spatial size ``x.shape[-2:]``, so the maps line up
                with the image. Pass ``(224, 224)`` to force a fixed size.

        Returns:
            A dict with ``'cnn_attention'`` and ``'vit_attention'``, each of
            shape [B, 1, *output_size].

        Raises:
            ValueError: If the ViT tokens after the first one do not form a
                non-empty square grid.

        Note:
            ``forward`` runs under ``torch.no_grad()`` in the module's current
            train/eval mode. Call ``eval()`` first for inference-mode maps.
        """
        if output_size is None:
            output_size = tuple(x.shape[-2:])
        else:
            output_size = tuple(output_size)

        with torch.no_grad():
            features = self.forward(x, return_features=True, use_aux_loss=False)

            # Assumes cnn_spatial is channels-first [B, C, h, w]
            # (e.g. [B, feature_dim, 7, 7] for a 224 input). TODO confirm
            # against CNNBranch.
            cnn_map = features['cnn_spatial'].mean(dim=1, keepdim=True)  # [B, 1, h, w]

            # Assumes vit_spatial is [B, 1 + N, D] with a leading CLS token
            # (e.g. 197 = 1 + 14*14 for DeiT-Base at 224). TODO confirm
            # against ViTBranch.
            patches = features['vit_spatial'][:, 1:]
            batch_size, num_patches = patches.shape[0], patches.shape[1]
            side = int(round(num_patches ** 0.5))
            if num_patches == 0 or side * side != num_patches:
                raise ValueError(
                    "Expected the ViT patch tokens to form a square grid, got "
                    f"{num_patches} tokens after dropping the first one"
                )
            # Reshape with an explicit batch size so tokens can never be folded
            # across batch items.
            vit_map = patches.mean(dim=-1).reshape(batch_size, 1, side, side)

            cnn_attention = self._upsample_map(cnn_map, output_size)
            vit_attention = self._upsample_map(vit_map, output_size)

        return {
            'cnn_attention': cnn_attention,
            'vit_attention': vit_attention
        }

    def _set_backbone_requires_grad(self, requires_grad: bool) -> None:
        """Set ``requires_grad`` on every parameter of both backbone modules."""
        for backbone in (self.cnn_branch.backbone, self.vit_branch.vit):
            backbone.requires_grad_(requires_grad)

    def freeze_backbone(self):
        """Stop gradients through ``cnn_branch.backbone`` and ``vit_branch.vit``.

        Other branch parameters, the fusion module and the heads are untouched.
        """
        self._set_backbone_requires_grad(False)

    def unfreeze_backbone(self):
        """Re-enable gradients for every parameter of both backbone modules.

        NOTE(review): this sets ``requires_grad=True`` on *all* backbone
        parameters. It therefore also undoes any freezing the branches applied
        themselves (e.g. via ``freeze_early_layers``). Presumably that is
        intended for full fine-tuning; confirm.
        """
        self._set_backbone_requires_grad(True)

    def get_model_size(self) -> Dict[str, int]:
        """Count parameters overall and per component.

        Returns:
            A dict of element counts (``numel``):

            - ``total_params`` and ``trainable_params`` (those with
              ``requires_grad``) cover the whole model.
            - ``cnn_params``, ``vit_params``, ``fusion_params`` and
              ``classifier_params`` cover one submodule each.
            - ``aux_classifier_params`` covers both auxiliary heads combined,
              so the component counts account for the whole model.
        """
        def _count(params) -> int:
            return sum(p.numel() for p in params)

        return {
            'total_params': _count(self.parameters()),
            'trainable_params': _count(p for p in self.parameters() if p.requires_grad),
            'cnn_params': _count(self.cnn_branch.parameters()),
            'vit_params': _count(self.vit_branch.parameters()),
            'fusion_params': _count(self.fusion_module.parameters()),
            'classifier_params': _count(self.classifier.parameters()),
            'aux_classifier_params': (
                _count(self.cnn_aux_classifier.parameters())
                + _count(self.vit_aux_classifier.parameters())
            )
        }