"""
CNN Branch for Hybrid Food Classifier
Uses ResNet50 as backbone with adaptive pooling
"""
import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple

class CNNBranch(nn.Module):
    """CNN branch using a ResNet50 backbone.

    Produces two views of the input image for the hybrid classifier:
    a 7x7 spatial feature map (for downstream attention/fusion) and a
    pooled global feature vector (for classification).
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.3,
        feature_dim: int = 2048
    ):
        """
        Args:
            backbone: Backbone architecture name; only "resnet50" is supported.
            pretrained: If True, load ImageNet-pretrained backbone weights.
            freeze_early_layers: If True, freeze the first 6 backbone children
                (conv1, bn1, relu, maxpool, layer1, layer2).
            dropout: Dropout probability for the projection and feature head.
            feature_dim: Channel dimension of the emitted features.

        Raises:
            ValueError: If ``backbone`` is not a supported architecture.
        """
        super().__init__()

        self.feature_dim = feature_dim

        # Load backbone trunk (everything up to, but excluding, avgpool + fc).
        if backbone == "resnet50":
            # torchvision >= 0.13 deprecates (and >= 0.15 removes) the
            # `pretrained=` flag in favor of the `weights=` enum API.
            # IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load,
            # so behavior is unchanged.
            try:
                weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
                resnet = models.resnet50(weights=weights)
            except AttributeError:
                # Older torchvision without the weights enum: legacy API.
                resnet = models.resnet50(pretrained=pretrained)
            self.backbone = nn.Sequential(*list(resnet.children())[:-2])
            backbone_dim = 2048  # channel count of ResNet50's final conv stage
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # Freeze early layers if specified (keeps generic low-level filters).
        if freeze_early_layers:
            self._freeze_early_layers()

        # Force a consistent 7x7 spatial grid regardless of input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # 1x1-conv projection from backbone channels to `feature_dim`.
        self.feature_proj = nn.Sequential(
            nn.Conv2d(backbone_dim, feature_dim, kernel_size=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout)
        )

        # Collapse the 7x7 grid into one global descriptor per image.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Additional processing of the global descriptor.
        # NOTE(review): BatchNorm1d raises on batch size 1 in training mode —
        # confirm the training loop never yields a final batch of one sample.
        self.feature_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )

    def _freeze_early_layers(self) -> None:
        """Freeze the first 6 backbone children (conv1 .. layer2).

        layer3/layer4 stay trainable so high-level features can adapt.
        """
        layers_to_freeze = 6
        for i, child in enumerate(self.backbone.children()):
            if i >= layers_to_freeze:
                break  # children are ordered; nothing past the cutoff freezes
            for param in child.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input image batch [B, 3, H, W].

        Returns:
            spatial_features: Spatial features [B, feature_dim, 7, 7].
            global_features: Global features [B, feature_dim].
        """
        # Backbone trunk output: [B, 2048, H', W'] (H', W' depend on input size).
        features = self.backbone(x)

        # Normalize spatial size to 7x7.
        features = self.adaptive_pool(features)

        # Project channels: [B, feature_dim, 7, 7].
        spatial_features = self.feature_proj(features)

        # Pool and flatten for the classification path: [B, feature_dim].
        global_features = self.global_pool(spatial_features).flatten(1)

        # Final non-linear head over the global descriptor.
        global_features = self.feature_head(global_features)

        return spatial_features, global_features

    def get_feature_dim(self) -> int:
        """Return the channel dimension of the emitted features."""
        return self.feature_dim