deedax commited on
Commit
952bbca
·
1 Parent(s): a6b83fe

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +36 -54
utils.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import pytorch_lightning as pl
4
  import torch.nn.functional as F
5
  import timm
6
- class CustomModelMain(nn.Module):
7
  def __init__(self, problem_type, n_classes):
8
  super().__init__()
9
  if problem_type == 'Classification' and n_classes == 1:
@@ -41,10 +41,37 @@ class CustomModelMain(nn.Module):
41
  x = self.fc2(x)
42
  x = self.output(x)
43
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  class age_lightningg(pl.LightningModule):
45
  def __init__(self):
46
  super().__init__()
47
- self.model = CustomModelMain('Regression', 1)
48
  def forward(self, x):
49
  return self.model(x)
50
  def training_step(self, batch, batch_idx):
@@ -66,7 +93,7 @@ class age_lightningg(pl.LightningModule):
66
  class gender_lightning(pl.LightningModule):
67
  def __init__(self):
68
  super().__init__()
69
- self.model = CustomModelMain('Classification', 1)
70
  def forward(self, x):
71
  return self.model(x)
72
  def training_step(self, batch, batch_idx):
@@ -91,7 +118,7 @@ class gender_lightning(pl.LightningModule):
91
  class race_lightning(pl.LightningModule):
92
  def __init__(self):
93
  super().__init__()
94
- self.model = CustomModelMain('Classification', 5)
95
  def forward(self, x):
96
  return self.model(x)
97
  def training_step(self, batch, batch_idx):
@@ -110,7 +137,7 @@ class race_lightning(pl.LightningModule):
110
  y_val = y[:, 2]
111
  y_hat = self(x)
112
  y_oh = F.one_hot(y_val, num_classes = 5)
113
- loss = F.cross_entropy(y_hat, y_oh.float())
114
  preds = y_hat.argmax(dim = 1)
115
  acc = torch.eq(y_val, preds).float().mean()
116
  self.log('valid loss', loss, prog_bar = True)
@@ -120,9 +147,9 @@ class race_lightning(pl.LightningModule):
120
  class Ultimate_Lightning(pl.LightningModule):
121
  def __init__(self):
122
  super().__init__()
123
- self.age_model = CustomModelMain('Regression', 1)
124
- self.gender_model = CustomModelMain('Classification', 1)
125
- self.race_model = CustomModelMain('Classification', 5)
126
  def forward(self, x):
127
  return self.age_model(x), self.gender_model(x), self.race_model(x)
128
  def training_step(self, batch, batch_idx):
@@ -177,49 +204,4 @@ class Ultimate_Lightning(pl.LightningModule):
177
  self.log('val race acc', race_acc, prog_bar = True)
178
 
179
  def configure_optimizers(self):
180
- return torch.optim.Adam(self.parameters(), lr=1e-4)
181
-
182
- class CustomModelMain2(nn.Module):
183
- def __init__(self):
184
- super().__init__()
185
- self.backbone = timm.create_model('efficientnet_b0', pretrained = True, num_classes = 1)
186
- for name, param in self.backbone.named_parameters():
187
- if name.startswith('blocks'):
188
- if not 'blocks.6' in name:
189
- param.requires_grad = False
190
- else:
191
- param.requires_grad = True
192
- #if name == 'conv_stem.weight' or 'bn1.weight' or 'bn1.bias':
193
- # param.requires_grad = False
194
- num_features = self.backbone.classifier.in_features
195
- self.classifer = nn.Sequential(
196
- nn.Linear(num_features, 256),
197
- nn.ReLU(inplace = True),
198
- nn.Linear(256 + 8, 1),
199
- nn.ReLU()
200
- )
201
- def forward(self, x):
202
- x = self.backbone(x)
203
- return x
204
- class age_lightning(pl.LightningModule):
205
- def __init__(self):
206
- super().__init__()
207
- self.model = CustomModelMain2()
208
- def forward(self, x):
209
- return self.model(x)
210
- def training_step(self, batch, batch_idx):
211
- x, y = batch
212
- y = y[:, 0]
213
- y_hat = self(x)
214
- loss = F.mse_loss(y_hat, y.unsqueeze(-1).float())
215
- acc = torch.eq((y_hat > 0.5).int().to(torch.int64), y.unsqueeze(-1).int()).all(dim=1).sum() / len(y)
216
- self.log('train loss', loss, prog_bar = True)
217
- return loss
218
- def validation_step(self, batch, batch_idx):
219
- x, y = batch
220
- y_val = y[:, 0]
221
- y_hat = self(x)
222
- loss = F.mse_loss(y_hat, y_val.unsqueeze(-1).float())
223
- self.log('valid loss', loss, prog_bar = True)
224
- def configure_optimizers(self):
225
- return torch.optim.Adam(self.parameters(), lr=1e-4)
 
3
  import pytorch_lightning as pl
4
  import torch.nn.functional as F
5
  import timm
6
+ class CustomModelMain_Old(nn.Module):
7
  def __init__(self, problem_type, n_classes):
8
  super().__init__()
9
  if problem_type == 'Classification' and n_classes == 1:
 
41
  x = self.fc2(x)
42
  x = self.output(x)
43
  return x
44
+ class CustomModelMain_New(nn.Module):
45
+ def __init__(self, problem_type, n_classes):
46
+ super().__init__()
47
+ if problem_type == 'Classification' and n_classes == 1:
48
+ output = nn.Sigmoid()
49
+ elif problem_type == 'Regression' and n_classes == 1:
50
+ output = nn.ReLU()
51
+ elif problem_type == 'Classification' and n_classes > 1:
52
+ output = nn.Softmax(dim = 1)
53
+
54
+ self.backbone = timm.create_model('efficientnet_b0', pretrained = True, num_classes = n_classes)
55
+ for name, param in self.backbone.named_parameters():
56
+ if name.startswith('blocks'):
57
+ if not 'blocks.5' in name:
58
+ param.requires_grad = False
59
+ else:
60
+ param.requires_grad = True
61
+ num_features = self.backbone.classifier.in_features
62
+ self.backbone.classifier = nn.Sequential(
63
+ nn.Linear(num_features, 256),
64
+ nn.ReLU(),
65
+ nn.Linear(256, n_classes),
66
+ output
67
+ )
68
+ def forward(self, x):
69
+ x = self.backbone(x)
70
+ return x
71
  class age_lightningg(pl.LightningModule):
72
  def __init__(self):
73
  super().__init__()
74
+ self.model = CustomModelMain_New('Regression', 1)
75
  def forward(self, x):
76
  return self.model(x)
77
  def training_step(self, batch, batch_idx):
 
93
  class gender_lightning(pl.LightningModule):
94
  def __init__(self):
95
  super().__init__()
96
+ self.model = CustomModelMain_New('Classification', 1)
97
  def forward(self, x):
98
  return self.model(x)
99
  def training_step(self, batch, batch_idx):
 
118
  class race_lightning(pl.LightningModule):
119
  def __init__(self):
120
  super().__init__()
121
+ self.model = CustomModelMain_New('Classification', 5)
122
  def forward(self, x):
123
  return self.model(x)
124
  def training_step(self, batch, batch_idx):
 
137
  y_val = y[:, 2]
138
  y_hat = self(x)
139
  y_oh = F.one_hot(y_val, num_classes = 5)
140
+ loss = F.cross_entropy(y_hat.log(), y_oh.float())
141
  preds = y_hat.argmax(dim = 1)
142
  acc = torch.eq(y_val, preds).float().mean()
143
  self.log('valid loss', loss, prog_bar = True)
 
147
  class Ultimate_Lightning(pl.LightningModule):
148
  def __init__(self):
149
  super().__init__()
150
+ self.age_model = CustomModelMain_New('Regression', 1)
151
+ self.gender_model = CustomModelMain_New('Classification', 1)
152
+ self.race_model = CustomModelMain_New('Classification', 5)
153
  def forward(self, x):
154
  return self.age_model(x), self.gender_model(x), self.race_model(x)
155
  def training_step(self, batch, batch_idx):
 
204
  self.log('val race acc', race_acc, prog_bar = True)
205
 
206
  def configure_optimizers(self):
207
+ return torch.optim.Adam(self.parameters(), lr=1e-4)