codealchemist01 commited on
Commit
84c468a
·
verified ·
1 Parent(s): dbb41ac

Upload models/cnn_branch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/cnn_branch.py +100 -0
models/cnn_branch.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CNN Branch for Hybrid Food Classifier
3
+ Uses ResNet50 as backbone with adaptive pooling
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.models as models
8
+ from typing import Tuple
9
+
10
+ class CNNBranch(nn.Module):
11
+ """CNN branch using ResNet50 backbone"""
12
+
13
+ def __init__(
14
+ self,
15
+ backbone: str = "resnet50",
16
+ pretrained: bool = True,
17
+ freeze_early_layers: bool = True,
18
+ dropout: float = 0.3,
19
+ feature_dim: int = 2048
20
+ ):
21
+ super(CNNBranch, self).__init__()
22
+
23
+ self.feature_dim = feature_dim
24
+
25
+ # Load backbone
26
+ if backbone == "resnet50":
27
+ self.backbone = models.resnet50(pretrained=pretrained)
28
+ # Remove the final classification layer
29
+ self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
30
+ backbone_dim = 2048
31
+ else:
32
+ raise ValueError(f"Unsupported backbone: {backbone}")
33
+
34
+ # Freeze early layers if specified
35
+ if freeze_early_layers:
36
+ self._freeze_early_layers()
37
+
38
+ # Adaptive pooling to get consistent feature size
39
+ self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7)) # 7x7 spatial features
40
+
41
+ # Feature projection
42
+ self.feature_proj = nn.Sequential(
43
+ nn.Conv2d(backbone_dim, feature_dim, kernel_size=1),
44
+ nn.BatchNorm2d(feature_dim),
45
+ nn.ReLU(inplace=True),
46
+ nn.Dropout2d(dropout)
47
+ )
48
+
49
+ # Global average pooling for final features
50
+ self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
51
+
52
+ # Additional feature processing
53
+ self.feature_head = nn.Sequential(
54
+ nn.Linear(feature_dim, feature_dim),
55
+ nn.BatchNorm1d(feature_dim),
56
+ nn.ReLU(inplace=True),
57
+ nn.Dropout(dropout)
58
+ )
59
+
60
+ def _freeze_early_layers(self):
61
+ """Freeze early layers of the backbone"""
62
+ # Freeze first 6 layers (conv1, bn1, relu, maxpool, layer1, layer2)
63
+ layers_to_freeze = 6
64
+ for i, child in enumerate(self.backbone.children()):
65
+ if i < layers_to_freeze:
66
+ for param in child.parameters():
67
+ param.requires_grad = False
68
+
69
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
70
+ """
71
+ Forward pass
72
+
73
+ Args:
74
+ x: Input tensor [B, 3, H, W]
75
+
76
+ Returns:
77
+ spatial_features: Spatial features [B, feature_dim, 7, 7]
78
+ global_features: Global features [B, feature_dim]
79
+ """
80
+ # Extract features from backbone
81
+ features = self.backbone(x) # [B, 2048, H', W']
82
+
83
+ # Adaptive pooling
84
+ features = self.adaptive_pool(features) # [B, 2048, 7, 7]
85
+
86
+ # Project features
87
+ spatial_features = self.feature_proj(features) # [B, feature_dim, 7, 7]
88
+
89
+ # Global pooling for classification features
90
+ global_features = self.global_pool(spatial_features) # [B, feature_dim, 1, 1]
91
+ global_features = global_features.flatten(1) # [B, feature_dim]
92
+
93
+ # Additional processing
94
+ global_features = self.feature_head(global_features) # [B, feature_dim]
95
+
96
+ return spatial_features, global_features
97
+
98
+ def get_feature_dim(self) -> int:
99
+ """Get feature dimension"""
100
+ return self.feature_dim