venkyvicky commited on
Commit
102c4fc
·
verified ·
1 Parent(s): d6b2969

Update ResNet_for_CC.py

Browse files
Files changed (1) hide show
  1. ResNet_for_CC.py +25 -59
ResNet_for_CC.py CHANGED
@@ -2,92 +2,58 @@ import torch
2
  import torch.nn as nn
3
  import torchvision.models as models
4
 
 
5
  class ResClassifier(nn.Module):
6
- """
7
- A classifier with two fully connected layers followed by a final linear layer.
8
- Uses BatchNorm, ReLU activations, and Dropout for better generalization.
9
- """
10
- def __init__(self, num_classes=14):
11
  super(ResClassifier, self).__init__()
12
-
13
- # First fully connected layer: reduces 128D features to 64D
14
  self.fc1 = nn.Sequential(
15
  nn.Linear(128, 64),
16
  nn.BatchNorm1d(64, affine=True),
17
  nn.ReLU(inplace=True),
18
  nn.Dropout()
19
  )
20
-
21
- # Second fully connected layer: retains 64D features
22
  self.fc2 = nn.Sequential(
23
  nn.Linear(64, 64),
24
  nn.BatchNorm1d(64, affine=True),
25
  nn.ReLU(inplace=True),
26
  nn.Dropout()
27
  )
28
-
29
- # Final classification layer mapping 64D features to class logits
30
- self.fc3 = nn.Linear(64, num_classes)
31
 
32
  def forward(self, x):
33
- """
34
- Forward pass through the classifier.
35
- Returns class logits after two hidden layers.
36
- """
37
- x = self.fc1(x) # First FC layer
38
- x = self.fc2(x) # Second FC layer
39
- output = self.fc3(x) # Final classification layer
40
- return output
41
-
42
 
43
  class CC_model(nn.Module):
44
- """
45
- Clothing Classification Model based on ResNet50.
46
- Extracts deep features and uses two independent classifiers for predictions.
47
- """
48
  def __init__(self, num_classes1=14, num_classes2=None):
 
 
 
 
49
  super(CC_model, self).__init__()
50
-
51
- # If num_classes2 is not specified, default to num_classes1
52
- num_classes2 = num_classes2 if num_classes2 else num_classes1
53
- assert num_classes1 == num_classes2 # Ensure both classifiers predict the same categories
54
-
55
  self.num_classes = num_classes1
56
-
57
- # Load a pretrained ResNet-50 model as the feature extractor
58
  self.model_resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT')
59
-
60
- # Remove ResNet's original classification layer to use as a feature extractor
61
  num_ftrs = self.model_resnet.fc.in_features
62
- self.model_resnet.fc = nn.Identity() # Identity layer keeps feature dimensions
63
-
64
- # Additional transformation layer reducing feature size to 128D
65
  self.dr = nn.Linear(num_ftrs, 128)
66
-
67
- # Two independent classifiers
68
  self.fc1 = ResClassifier(num_classes1)
69
  self.fc2 = ResClassifier(num_classes1)
70
 
71
  def forward(self, x, detach_feature=False):
72
- """
73
- Forward pass through the model.
74
- Extracts deep features from ResNet and processes them through classifiers.
75
- """
76
  with torch.no_grad():
77
- # Extract deep features using ResNet-50 (without its original classification head)
78
  feature = self.model_resnet(x)
79
-
80
- # Generate transformed features (128D) using the custom linear layer
81
- dr_feature = self.dr(feature)
82
-
83
- if detach_feature:
84
- dr_feature = dr_feature.detach() # Detach feature for non-trainable forward pass
85
-
86
- # Pass features through two independent classifiers
87
- out1 = self.fc1(dr_feature)
88
- out2 = self.fc2(dr_feature)
89
-
90
- # Compute the mean prediction from both classifiers
91
- output_mean = (out1 + out2) / 2
92
-
93
- return dr_feature, output_mean # Returning feature embeddings and final prediction
 
2
  import torch.nn as nn
3
  import torchvision.models as models
4
 
5
+
6
  class ResClassifier(nn.Module):
7
+ def __init__(self, class_num=14):
 
 
 
 
8
  super(ResClassifier, self).__init__()
 
 
9
  self.fc1 = nn.Sequential(
10
  nn.Linear(128, 64),
11
  nn.BatchNorm1d(64, affine=True),
12
  nn.ReLU(inplace=True),
13
  nn.Dropout()
14
  )
 
 
15
  self.fc2 = nn.Sequential(
16
  nn.Linear(64, 64),
17
  nn.BatchNorm1d(64, affine=True),
18
  nn.ReLU(inplace=True),
19
  nn.Dropout()
20
  )
21
+ self.fc3 = nn.Linear(64, class_num)
 
 
22
 
23
  def forward(self, x):
24
+ fc1_emb = self.fc1(x)
25
+ fc2_emb = self.fc2(fc1_emb)
26
+ logit = self.fc3(fc2_emb)
27
+ return logit
 
 
 
 
 
28
 
29
  class CC_model(nn.Module):
 
 
 
 
30
  def __init__(self, num_classes1=14, num_classes2=None):
31
+
32
+ if num_classes2 is None:
33
+ num_classes2 = num_classes1
34
+
35
  super(CC_model, self).__init__()
36
+ assert num_classes1 == num_classes2
 
 
 
 
37
  self.num_classes = num_classes1
 
 
38
  self.model_resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT')
 
 
39
  num_ftrs = self.model_resnet.fc.in_features
40
+ self.model_resnet.fc = nn.Identity()
41
+ self.classification_fc = nn.Linear(num_ftrs, num_classes1)
 
42
  self.dr = nn.Linear(num_ftrs, 128)
 
 
43
  self.fc1 = ResClassifier(num_classes1)
44
  self.fc2 = ResClassifier(num_classes1)
45
 
46
  def forward(self, x, detach_feature=False):
47
+
 
 
 
48
  with torch.no_grad():
 
49
  feature = self.model_resnet(x)
50
+ res_out = self.classification_fc(feature)
51
+ if detach_feature:
52
+ feature = feature.detach()
53
+ dr_feature = self.dr(feature)
54
+ out1 = self.fc1(dr_feature)
55
+ out2 = self.fc2(dr_feature)
56
+ output_mean = (out1 + out2) / 2
57
+ return dr_feature, output_mean
58
+
59
+ #return dr_feature