food-classifier-space / models / hybrid_model.py
"""
Hybrid CNN-ViT Food Classifier
Combines ResNet50 and DeiT-Base with adaptive fusion
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict
from .cnn_branch import CNNBranch
from .vit_branch import ViTBranch
from .fusion_module import AdaptiveFusionModule
class HybridFoodClassifier(nn.Module):
"""Hybrid CNN-ViT model for food classification"""
def __init__(
self,
num_classes: int = 101,
feature_dim: int = 768,
hidden_dim: int = 512,
dropout: float = 0.2,
pretrained: bool = True,
freeze_early_layers: bool = True
):
super(HybridFoodClassifier, self).__init__()
self.num_classes = num_classes
self.feature_dim = feature_dim
self.hidden_dim = hidden_dim
# CNN Branch (ResNet50)
self.cnn_branch = CNNBranch(
pretrained=pretrained,
freeze_early_layers=freeze_early_layers,
dropout=dropout,
feature_dim=feature_dim
)
# ViT Branch (DeiT-Base)
self.vit_branch = ViTBranch(
pretrained=pretrained,
freeze_early_layers=freeze_early_layers,
dropout=dropout,
feature_dim=feature_dim
)
# Fusion Module
self.fusion_module = AdaptiveFusionModule(
feature_dim=feature_dim,
hidden_dim=hidden_dim,
dropout=dropout
)
# Classification Head
self.classifier = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, num_classes)
)
# Auxiliary classifiers for training stability
self.cnn_aux_classifier = nn.Sequential(
nn.Linear(feature_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, num_classes)
)
self.vit_aux_classifier = nn.Sequential(
nn.Linear(feature_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, num_classes)
)
# Initialize weights
self._initialize_weights()
def _initialize_weights(self):
"""Initialize classifier weights"""
for m in [self.classifier, self.cnn_aux_classifier, self.vit_aux_classifier]:
for layer in m:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
if layer.bias is not None:
nn.init.constant_(layer.bias, 0)
def forward(
self,
x: torch.Tensor,
return_features: bool = False,
use_aux_loss: bool = True
) -> Dict[str, torch.Tensor]:
"""
Forward pass
Args:
x: Input tensor [B, 3, H, W]
return_features: Whether to return intermediate features
use_aux_loss: Whether to compute auxiliary losses
Returns:
Dictionary containing logits and optionally features/aux_logits
"""
# CNN Branch
cnn_spatial, cnn_global = self.cnn_branch(x)
# ViT Branch
vit_spatial, vit_global = self.vit_branch(x)
# Fusion
fused_spatial, fused_global = self.fusion_module(
cnn_spatial, cnn_global, vit_spatial, vit_global
)
# Main classification
logits = self.classifier(fused_global)
# Prepare output
output = {'logits': logits}
# Auxiliary losses for training
if use_aux_loss and self.training:
cnn_aux_logits = self.cnn_aux_classifier(cnn_global)
vit_aux_logits = self.vit_aux_classifier(vit_global)
output.update({
'cnn_aux_logits': cnn_aux_logits,
'vit_aux_logits': vit_aux_logits
})
# Return features if requested
if return_features:
output.update({
'cnn_spatial': cnn_spatial,
'cnn_global': cnn_global,
'vit_spatial': vit_spatial,
'vit_global': vit_global,
'fused_spatial': fused_spatial,
'fused_global': fused_global
})
return output
def get_attention_maps(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Get attention maps for visualization"""
with torch.no_grad():
# Get features
output = self.forward(x, return_features=True, use_aux_loss=False)
# CNN attention (channel-wise mean of the spatial feature map)
cnn_spatial = output['cnn_spatial'] # [B, feature_dim, 7, 7]
cnn_attention = torch.mean(cnn_spatial, dim=1, keepdim=True) # [B, 1, 7, 7]
cnn_attention = F.interpolate(
cnn_attention,
size=(224, 224),
mode='bilinear',
align_corners=False
) # [B, 1, 224, 224]
# ViT attention (mean activation per patch token)
vit_spatial = output['vit_spatial'] # [B, 197, feature_dim] (196 patches + 1 CLS)
vit_patches = vit_spatial[:, 1:] # Remove CLS token, get [B, 196, feature_dim]
vit_attention = torch.mean(vit_patches, dim=-1) # [B, 196]
vit_attention = vit_attention.view(-1, 14, 14).unsqueeze(1) # [B, 1, 14, 14]
vit_attention = F.interpolate(
vit_attention,
size=(224, 224),
mode='bilinear',
align_corners=False
) # [B, 1, 224, 224]
return {
'cnn_attention': cnn_attention,
'vit_attention': vit_attention
}
def freeze_backbone(self):
"""Freeze backbone networks"""
for param in self.cnn_branch.backbone.parameters():
param.requires_grad = False
for param in self.vit_branch.vit.parameters():
param.requires_grad = False
def unfreeze_backbone(self):
"""Unfreeze backbone networks"""
for param in self.cnn_branch.backbone.parameters():
param.requires_grad = True
for param in self.vit_branch.vit.parameters():
param.requires_grad = True
def get_model_size(self) -> Dict[str, int]:
"""Get model size information"""
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
cnn_params = sum(p.numel() for p in self.cnn_branch.parameters())
vit_params = sum(p.numel() for p in self.vit_branch.parameters())
fusion_params = sum(p.numel() for p in self.fusion_module.parameters())
classifier_params = sum(p.numel() for p in self.classifier.parameters())
return {
'total_params': total_params,
'trainable_params': trainable_params,
'cnn_params': cnn_params,
'vit_params': vit_params,
'fusion_params': fusion_params,
'classifier_params': classifier_params
}
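if __name__ == "__main__":
    # Minimal smoke-test sketch (not prescribed by this repo). Run with
    # `python -m models.hybrid_model` so the relative imports resolve; it assumes
    # the sibling CNNBranch, ViTBranch and AdaptiveFusionModule modules are
    # importable and follow the shape comments above.
    model = HybridFoodClassifier(num_classes=101, pretrained=False)
    model.train()
    dummy = torch.randn(2, 3, 224, 224)  # [B, 3, H, W]
    out = model(dummy, return_features=True, use_aux_loss=True)
    print({k: tuple(v.shape) for k, v in out.items()})

    # Example combined loss: main head plus down-weighted auxiliary heads.
    # The 0.3 weight is illustrative, not a value taken from this repo.
    targets = torch.randint(0, 101, (2,))
    loss = F.cross_entropy(out['logits'], targets)
    if 'cnn_aux_logits' in out:
        loss = loss + 0.3 * F.cross_entropy(out['cnn_aux_logits'], targets)
        loss = loss + 0.3 * F.cross_entropy(out['vit_aux_logits'], targets)
    print('combined loss:', loss.item())

    # Attention maps for visualization and a parameter-count summary.
    model.eval()
    maps = model.get_attention_maps(dummy)
    print({k: tuple(v.shape) for k, v in maps.items()})
    print(model.get_model_size())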