|
|
""" |
|
|
CNN Branch for Hybrid Food Classifier |
|
|
Uses ResNet50 as backbone with adaptive pooling |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.models as models |
|
|
from typing import Tuple |
|
|
|
|
|
class CNNBranch(nn.Module):
    """CNN branch using a truncated ResNet50 backbone.

    The backbone's classification head (global avgpool + fc) is removed. The
    remaining feature map is resized to a 7x7 grid, projected with a 1x1 conv
    to ``feature_dim`` channels, and additionally pooled into a global
    feature vector.
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.3,
        feature_dim: int = 2048,
    ):
        """Build the branch.

        Args:
            backbone: Backbone architecture; only ``"resnet50"`` is supported.
            pretrained: If True, initialise the backbone with torchvision's
                ImageNet-1k V1 weights (the weights the legacy
                ``pretrained=True`` flag selected).
            freeze_early_layers: If True, freeze the first six backbone
                children (conv1, bn1, relu, maxpool, layer1, layer2),
                including the running statistics of their BatchNorm layers.
            dropout: Drop probability for the projection (``Dropout2d``) and
                the global head (``Dropout``).
            feature_dim: Channel count of the spatial output and length of
                the global output.

        Raises:
            ValueError: If ``backbone`` is not supported.
        """
        super().__init__()

        self.feature_dim = feature_dim
        # Number of leading backbone children frozen by _freeze_early_layers;
        # train() keeps the BatchNorm layers inside them in eval mode.
        self._frozen_children = 0

        if backbone == "resnet50":
            resnet = self._build_resnet50(pretrained)
            # Drop avgpool and fc so the backbone returns a 2048-channel map
            self.backbone = nn.Sequential(*list(resnet.children())[:-2])
            backbone_dim = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        if freeze_early_layers:
            self._freeze_early_layers()

        # Fixed 7x7 spatial grid regardless of the backbone's output size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # 1x1 conv projection from backbone channels to feature_dim
        self.feature_proj = nn.Sequential(
            nn.Conv2d(backbone_dim, feature_dim, kernel_size=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
        )

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # NOTE: BatchNorm1d raises in training mode for a batch of size 1
        self.feature_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )

    @staticmethod
    def _build_resnet50(pretrained: bool) -> nn.Module:
        """Create torchvision's ResNet50, optionally with ImageNet weights.

        Uses the ``weights=`` API (torchvision >= 0.13), where the boolean
        ``pretrained`` argument is deprecated. Falls back to ``pretrained=``
        on older torchvision releases that lack ``ResNet50_Weights``.
        """
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is what the legacy pretrained=True resolved to
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet50(weights=weights)

    def _freeze_early_layers(self, layers_to_freeze: int = 6):
        """Freeze the first ``layers_to_freeze`` children of the backbone.

        For ResNet50 the default of 6 covers conv1, bn1, relu, maxpool,
        layer1 and layer2; layer3 and layer4 remain trainable. BatchNorm
        layers in the frozen children are also kept in eval mode (see
        ``train``).
        """
        for child in list(self.backbone.children())[:layers_to_freeze]:
            for param in child.parameters():
                param.requires_grad = False
        self._frozen_children = layers_to_freeze
        # Re-apply the current mode so frozen BatchNorm layers switch to eval now
        self.train(self.training)

    def train(self, mode: bool = True) -> "CNNBranch":
        """Set train/eval mode, keeping frozen BatchNorm layers in eval mode.

        ``requires_grad = False`` does not stop BatchNorm from normalizing
        with batch statistics and updating its running mean/var in training
        mode. BatchNorm layers inside frozen backbone children are therefore
        forced to eval so their statistics stay fixed.

        Returns:
            self
        """
        super().train(mode)
        # getattr guards instances unpickled from before this attribute existed
        frozen = getattr(self, "_frozen_children", 0)
        if mode and frozen:
            for child in list(self.backbone.children())[:frozen]:
                for module in child.modules():
                    if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                        module.eval()
        return self

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass

        Args:
            x: Input tensor [B, 3, H, W]

        Returns:
            spatial_features: Spatial features [B, feature_dim, 7, 7]
            global_features: Global features [B, feature_dim]
        """
        features = self.backbone(x)

        # Resize whatever the backbone produced to the fixed 7x7 grid
        features = self.adaptive_pool(features)

        spatial_features = self.feature_proj(features)

        # [B, feature_dim, 1, 1] -> [B, feature_dim]
        global_features = self.global_pool(spatial_features).flatten(1)

        global_features = self.feature_head(global_features)

        return spatial_features, global_features

    def get_feature_dim(self) -> int:
        """Get feature dimension"""
        return self.feature_dim