File size: 3,439 Bytes
84c468a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
"""
CNN Branch for Hybrid Food Classifier
Uses ResNet50 as backbone with adaptive pooling
"""
import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple
class CNNBranch(nn.Module):
    """CNN branch that maps an image batch to spatial and global features.

    A truncated ResNet50 (everything up to and including ``layer4``) produces a
    feature map, which is resized to a fixed 7x7 grid, projected to
    ``feature_dim`` channels with a 1x1 convolution, and finally pooled and
    passed through a small fully-connected head to obtain one vector per image.

    Args:
        backbone: Name of the backbone network. Only ``"resnet50"`` is
            supported.
        pretrained: If true, initialise the backbone with ImageNet weights.
        freeze_early_layers: If true, disable gradients for the first
            ``_NUM_FROZEN_CHILDREN`` children of the truncated backbone.
        dropout: Dropout probability used in both the spatial projection and
            the global feature head.
        feature_dim: Number of channels of the returned features.

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """

    # Number of leading backbone children whose parameters get frozen. For the
    # truncated ResNet50 these are conv1, bn1, relu, maxpool, layer1, layer2.
    _NUM_FROZEN_CHILDREN = 6

    def __init__(
        self,
        backbone: str = "resnet50",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.3,
        feature_dim: int = 2048
    ) -> None:
        super().__init__()
        self.feature_dim = feature_dim

        if backbone != "resnet50":
            raise ValueError(f"Unsupported backbone: {backbone}")

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; IMAGENET1K_V1 is the checkpoint ``pretrained=True``
        # used to load. Older releases lack the enum, so fall back to the
        # legacy keyword there.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            resnet = models.resnet50(pretrained=pretrained)
        else:
            resnet = models.resnet50(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )

        # Drop the last two children (global average pool and the fc
        # classifier) so the backbone returns a spatial feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        backbone_dim = 2048  # channel count of ResNet50's layer4 output

        if freeze_early_layers:
            self._freeze_early_layers()

        # Resize whatever the backbone emits to a fixed 7x7 grid so the
        # output shape does not depend on the input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # 1x1 convolution: per-location channel projection to feature_dim.
        self.feature_proj = nn.Sequential(
            nn.Conv2d(backbone_dim, feature_dim, kernel_size=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout)
        )

        # Collapses the 7x7 grid to a single vector per image.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Extra non-linear processing of the pooled vector.
        self.feature_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )

    def _freeze_early_layers(self) -> None:
        """Disable gradients for the first children of the backbone.

        Only ``requires_grad`` is changed.
        NOTE(review): BatchNorm layers inside the frozen children keep
        updating their running statistics while the module is in train mode;
        confirm whether that is intended.
        """
        frozen_children = list(self.backbone.children())[:self._NUM_FROZEN_CHILDREN]
        for child in frozen_children:
            for param in child.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute spatial and global features for a batch of images.

        Args:
            x: Input tensor [B, 3, H, W]

        Returns:
            spatial_features: Spatial features [B, feature_dim, 7, 7]
            global_features: Global features [B, feature_dim]
        """
        features = self.backbone(x)  # [B, 2048, H', W']
        features = self.adaptive_pool(features)  # [B, 2048, 7, 7]
        spatial_features = self.feature_proj(features)  # [B, feature_dim, 7, 7]

        # Pool the projected map to one vector per image, then refine it.
        global_features = self.global_pool(spatial_features)  # [B, feature_dim, 1, 1]
        global_features = global_features.flatten(1)  # [B, feature_dim]
        global_features = self.feature_head(global_features)  # [B, feature_dim]

        return spatial_features, global_features

    def get_feature_dim(self) -> int:
        """Return the channel dimension of the features produced by forward."""
        return self.feature_dim