""" CNN Branch for Hybrid Food Classifier Uses ResNet50 as backbone with adaptive pooling """ import torch import torch.nn as nn import torchvision.models as models from typing import Tuple class CNNBranch(nn.Module): """CNN branch using ResNet50 backbone""" def __init__( self, backbone: str = "resnet50", pretrained: bool = True, freeze_early_layers: bool = True, dropout: float = 0.3, feature_dim: int = 2048 ): super(CNNBranch, self).__init__() self.feature_dim = feature_dim # Load backbone if backbone == "resnet50": self.backbone = models.resnet50(pretrained=pretrained) # Remove the final classification layer self.backbone = nn.Sequential(*list(self.backbone.children())[:-2]) backbone_dim = 2048 else: raise ValueError(f"Unsupported backbone: {backbone}") # Freeze early layers if specified if freeze_early_layers: self._freeze_early_layers() # Adaptive pooling to get consistent feature size self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7)) # 7x7 spatial features # Feature projection self.feature_proj = nn.Sequential( nn.Conv2d(backbone_dim, feature_dim, kernel_size=1), nn.BatchNorm2d(feature_dim), nn.ReLU(inplace=True), nn.Dropout2d(dropout) ) # Global average pooling for final features self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) # Additional feature processing self.feature_head = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.BatchNorm1d(feature_dim), nn.ReLU(inplace=True), nn.Dropout(dropout) ) def _freeze_early_layers(self): """Freeze early layers of the backbone""" # Freeze first 6 layers (conv1, bn1, relu, maxpool, layer1, layer2) layers_to_freeze = 6 for i, child in enumerate(self.backbone.children()): if i < layers_to_freeze: for param in child.parameters(): param.requires_grad = False def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass Args: x: Input tensor [B, 3, H, W] Returns: spatial_features: Spatial features [B, feature_dim, 7, 7] global_features: Global features [B, feature_dim] """ # Extract features from backbone features = self.backbone(x) # [B, 2048, H', W'] # Adaptive pooling features = self.adaptive_pool(features) # [B, 2048, 7, 7] # Project features spatial_features = self.feature_proj(features) # [B, feature_dim, 7, 7] # Global pooling for classification features global_features = self.global_pool(spatial_features) # [B, feature_dim, 1, 1] global_features = global_features.flatten(1) # [B, feature_dim] # Additional processing global_features = self.feature_head(global_features) # [B, feature_dim] return spatial_features, global_features def get_feature_dim(self) -> int: """Get feature dimension""" return self.feature_dim