# Provenance: uploaded by codealchemist01 to the Hugging Face Hub
# (models/cnn_branch.py, commit 84c468a verified, 3.44 kB).
"""
CNN Branch for Hybrid Food Classifier
Uses ResNet50 as backbone with adaptive pooling
"""
import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple
class CNNBranch(nn.Module):
    """CNN branch using a ResNet50 backbone.

    Produces both spatial feature maps (for downstream attention/fusion)
    and a pooled global feature vector (for classification heads).

    Args:
        backbone: Backbone architecture name; only ``"resnet50"`` is supported.
        pretrained: If True, load ImageNet-pretrained weights.
        freeze_early_layers: If True, freeze conv1/bn1/relu/maxpool/layer1/layer2
            so only the deeper layers are fine-tuned.
        dropout: Dropout probability used in both the spatial projection
            and the global feature head.
        feature_dim: Output channel/feature dimension of both outputs.
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.3,
        feature_dim: int = 2048
    ):
        super().__init__()
        self.feature_dim = feature_dim

        # Load backbone
        if backbone == "resnet50":
            # torchvision >= 0.13 deprecated the boolean `pretrained` flag in
            # favor of explicit weight enums; fall back for older versions.
            try:
                weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
                resnet = models.resnet50(weights=weights)
            except AttributeError:
                # torchvision < 0.13: weight enums do not exist yet.
                resnet = models.resnet50(pretrained=pretrained)
            # Drop the final avgpool + fc layers; keep the convolutional trunk
            # so the output is a spatial feature map [B, 2048, H', W'].
            self.backbone = nn.Sequential(*list(resnet.children())[:-2])
            backbone_dim = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # Freeze early layers if specified
        if freeze_early_layers:
            self._freeze_early_layers()

        # Adaptive pooling to get a consistent 7x7 spatial feature size
        # regardless of input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # 1x1 conv projection from backbone channels to `feature_dim`.
        self.feature_proj = nn.Sequential(
            nn.Conv2d(backbone_dim, feature_dim, kernel_size=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout)
        )

        # Global average pooling for the final feature vector.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Additional feature processing on the pooled vector.
        # NOTE: BatchNorm1d requires batch size > 1 in training mode.
        self.feature_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )

    def _freeze_early_layers(self) -> None:
        """Freeze the first 6 backbone children (conv1, bn1, relu, maxpool,
        layer1, layer2) by disabling their gradients."""
        layers_to_freeze = 6
        for i, child in enumerate(self.backbone.children()):
            if i < layers_to_freeze:
                for param in child.parameters():
                    param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass.

        Args:
            x: Input tensor [B, 3, H, W]
        Returns:
            spatial_features: Spatial features [B, feature_dim, 7, 7]
            global_features: Global features [B, feature_dim]
        """
        # Extract features from backbone
        features = self.backbone(x)  # [B, 2048, H', W']
        # Normalize spatial size to 7x7
        features = self.adaptive_pool(features)  # [B, 2048, 7, 7]
        # Project channels to feature_dim
        spatial_features = self.feature_proj(features)  # [B, feature_dim, 7, 7]
        # Global pooling for classification features
        global_features = self.global_pool(spatial_features)  # [B, feature_dim, 1, 1]
        global_features = global_features.flatten(1)  # [B, feature_dim]
        # Additional processing (linear + BN + ReLU + dropout)
        global_features = self.feature_head(global_features)  # [B, feature_dim]
        return spatial_features, global_features

    def get_feature_dim(self) -> int:
        """Get the configured output feature dimension."""
        return self.feature_dim