jbheeman commited on
Commit
61e3ebe
·
verified ·
1 Parent(s): e7fa458

Upload asdklasjkladkladskladskjladsklj/AgeClassificationModel with huggingface_hub

Browse files
asdklasjkladkladskladskjladsklj/AgeClassificationModel ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_feature_classification_model import BaseFeatureClassificationModel
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class ComplexNN(nn.Module):
6
+ def __init__(self,n_classes,input_dim):
7
+ super(ComplexNN, self).__init__()
8
+ self.fc1 = nn.Linear(input_dim, 8)
9
+ self.fc2 = nn.Linear(8, n_classes)
10
+
11
+ def forward(self, x):
12
+ x = F.relu(self.fc1(x))
13
+ x = self.fc2(x)
14
+ return x
15
+
16
+ class AgeClassificationModel(BaseFeatureClassificationModel):
17
+ def __init__(self, ckpt, batch_size=64, val_batch_size=128,
18
+ learning_rate=1e-4, epochs=500, es_paiteince=5, num_workers=8, train_ratio=0.85,device='gpu',gpu_ids=[0]):
19
+ super().__init__(ckpt, batch_size, val_batch_size,
20
+ learning_rate, epochs, es_paiteince, num_workers, train_ratio,device,gpu_ids)
21
+
22
+ def input_dataset(self, dataset):
23
+ model = ComplexNN(n_classes=dataset.n_classes, input_dim=dataset.num_features)
24
+ super().input_dataset(dataset,model)
25
+
26
+
27
+