|
|
""" |
|
|
Hybrid CNN-ViT Food Classifier |
|
|
Combines ResNet50 and DeiT-Base with adaptive fusion |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Dict, Any, Optional |
|
|
|
|
|
from .cnn_branch import CNNBranch |
|
|
from .vit_branch import ViTBranch |
|
|
from .fusion_module import AdaptiveFusionModule |
|
|
|
|
|
class HybridFoodClassifier(nn.Module):
    """Hybrid CNN-ViT model for food classification.

    An image is encoded in parallel by a CNN branch and a ViT branch. Their
    spatial and global features are combined by ``AdaptiveFusionModule``, and
    the fused global feature is classified by an MLP head. Two auxiliary heads
    classify each branch's global feature on its own. They are evaluated only
    in training mode (see ``forward``).

    Args:
        num_classes: Number of output classes.
        feature_dim: Feature width passed to both branches and to the fusion
            module. It is also the input width of the auxiliary heads.
        hidden_dim: Width passed to the fusion module. The main head expects
            the fused global feature to have this width.
        dropout: Dropout probability shared by the branches, the fusion
            module and all heads.
        pretrained: Forwarded to both branches.
        freeze_early_layers: Forwarded to both branches.
    """

    def __init__(
        self,
        num_classes: int = 101,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        dropout: float = 0.2,
        pretrained: bool = True,
        freeze_early_layers: bool = True
    ):
        super().__init__()

        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # Convolutional feature extractor (ResNet50 per the module docstring).
        self.cnn_branch = CNNBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )

        # Transformer feature extractor (DeiT-Base per the module docstring).
        self.vit_branch = ViTBranch(
            pretrained=pretrained,
            freeze_early_layers=freeze_early_layers,
            dropout=dropout,
            feature_dim=feature_dim
        )

        # Combines both branches' spatial and global features.
        self.fusion_module = AdaptiveFusionModule(
            feature_dim=feature_dim,
            hidden_dim=hidden_dim,
            dropout=dropout
        )

        # Main head: fused global feature (width hidden_dim) -> class logits.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Auxiliary heads give each branch its own supervision signal
        # during training.
        self.cnn_aux_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        self.vit_aux_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform init for every Linear in the three heads, zero biases.

        Branch and fusion-module weights are left untouched.
        """
        for head in (self.classifier, self.cnn_aux_classifier, self.vit_aux_classifier):
            # modules() also reaches nested layers, not just direct children.
            for layer in head.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False,
        use_aux_loss: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Run both branches, fuse their features and classify.

        Args:
            x: Input tensor [B, 3, H, W].
            return_features: If True, also return the intermediate branch and
                fused features.
            use_aux_loss: If True *and* the module is in training mode, also
                return the auxiliary heads' logits.

        Returns:
            A dict that always contains ``'logits'``. It contains
            ``'cnn_aux_logits'`` and ``'vit_aux_logits'`` when auxiliary
            outputs are requested in training mode. It contains
            ``'cnn_spatial'``, ``'cnn_global'``, ``'vit_spatial'``,
            ``'vit_global'``, ``'fused_spatial'`` and ``'fused_global'``
            when ``return_features`` is True.
        """
        cnn_spatial, cnn_global = self.cnn_branch(x)
        vit_spatial, vit_global = self.vit_branch(x)

        fused_spatial, fused_global = self.fusion_module(
            cnn_spatial, cnn_global, vit_spatial, vit_global
        )

        logits = self.classifier(fused_global)
        output = {'logits': logits}

        # Auxiliary logits are only useful as extra loss terms, so skip them
        # outside training.
        if use_aux_loss and self.training:
            output.update({
                'cnn_aux_logits': self.cnn_aux_classifier(cnn_global),
                'vit_aux_logits': self.vit_aux_classifier(vit_global)
            })

        if return_features:
            output.update({
                'cnn_spatial': cnn_spatial,
                'cnn_global': cnn_global,
                'vit_spatial': vit_spatial,
                'vit_global': vit_global,
                'fused_spatial': fused_spatial,
                'fused_global': fused_global
            })

        return output

    def get_attention_maps(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return coarse per-branch saliency maps for visualization.

        These are not true attention weights. Each map is the channel-wise
        mean of a branch's spatial features, bilinearly upsampled to the
        input's spatial size (the last two dims of ``x``).

        The model is put in eval mode for this pass, so dropout is disabled and
        batch-norm running statistics (if any) are not updated. The previous
        train/eval mode is restored afterwards, even if an error is raised.

        Args:
            x: Input batch; its last two dimensions set the output map size.

        Returns:
            ``{'cnn_attention': [B, 1, H, W], 'vit_attention': [B, 1, H, W]}``.

        Raises:
            ValueError: If the ViT spatial features contain no tokens.
        """
        import math

        out_size = tuple(x.shape[-2:])
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x, return_features=True, use_aux_loss=False)

                # CNN map. cnn_spatial must be 4-D for bilinear interpolation;
                # presumably [B, C, h, w] - TODO confirm against CNNBranch.
                cnn_attention = torch.mean(output['cnn_spatial'], dim=1, keepdim=True)
                cnn_attention = F.interpolate(
                    cnn_attention,
                    size=out_size,
                    mode='bilinear',
                    align_corners=False
                )

                # ViT map. vit_spatial is assumed to be [B, tokens, D], with
                # non-patch tokens (CLS, and the distillation token for
                # distilled DeiT) first - TODO confirm against ViTBranch.
                # The trailing perfect-square count of tokens is the patch
                # grid, e.g. 197 -> 1 prefix + 14x14, 198 -> 2 prefix + 14x14.
                vit_spatial = output['vit_spatial']
                n_tokens = vit_spatial.shape[1]
                side = math.isqrt(n_tokens)
                if side == 0:
                    raise ValueError("vit_spatial contains no tokens")
                n_prefix = n_tokens - side * side

                vit_attention = torch.mean(vit_spatial[:, n_prefix:], dim=-1)
                vit_attention = vit_attention.reshape(-1, 1, side, side)
                vit_attention = F.interpolate(
                    vit_attention,
                    size=out_size,
                    mode='bilinear',
                    align_corners=False
                )
        finally:
            self.train(was_training)

        return {
            'cnn_attention': cnn_attention,
            'vit_attention': vit_attention
        }

    def _set_backbone_trainable(self, trainable: bool) -> None:
        """Set ``requires_grad`` on every CNN backbone and ViT parameter."""
        for module in (self.cnn_branch.backbone, self.vit_branch.vit):
            for param in module.parameters():
                param.requires_grad = trainable

    def freeze_backbone(self):
        """Freeze the CNN backbone and ViT parameters (stop gradient updates)."""
        self._set_backbone_trainable(False)

    def unfreeze_backbone(self):
        """Unfreeze all CNN backbone and ViT parameters.

        This also unfreezes layers frozen at construction via
        ``freeze_early_layers``.
        """
        self._set_backbone_trainable(True)

    def get_model_size(self) -> Dict[str, int]:
        """Return parameter counts for the whole model and for each component.

        ``aux_classifier_params`` covers both auxiliary heads. With it, the
        component counts add up to ``total_params`` (assuming no parameters
        are shared between components).
        """
        def count(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        return {
            'total_params': count(self),
            'trainable_params': sum(p.numel() for p in self.parameters() if p.requires_grad),
            'cnn_params': count(self.cnn_branch),
            'vit_params': count(self.vit_branch),
            'fusion_params': count(self.fusion_module),
            'classifier_params': count(self.classifier),
            'aux_classifier_params': count(self.cnn_aux_classifier) + count(self.vit_aux_classifier)
        }