Update train_only_GEGLU.py
Browse files- train_only_GEGLU.py +3 -5
train_only_GEGLU.py
CHANGED
|
@@ -162,7 +162,7 @@ def test(dataloader, model, loss_fn):
|
|
| 162 |
|
| 163 |
# apply train and test
|
| 164 |
|
| 165 |
-
logname = "/
|
| 166 |
if not os.path.exists(logname):
|
| 167 |
with open(logname, 'w') as logfile:
|
| 168 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
@@ -174,9 +174,7 @@ epochs = 100
|
|
| 174 |
for epoch in range(epochs):
|
| 175 |
print(f"Epoch {epoch+1}\n-----------------------------------")
|
| 176 |
train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
|
| 177 |
-
|
| 178 |
-
#if scheduler is not None:
|
| 179 |
-
# scheduler.step()
|
| 180 |
test_loss, test_acc = test(test_dataloader, model, loss_fn)
|
| 181 |
with open(logname, 'a') as logfile:
|
| 182 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
@@ -186,7 +184,7 @@ print("Done!")
|
|
| 186 |
|
| 187 |
# saving trained model
|
| 188 |
|
| 189 |
-
path = "/
|
| 190 |
model_name = "ACTIVATOR_only_GEGLUImageClassification_cifar10"
|
| 191 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 192 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
|
|
|
| 162 |
|
| 163 |
# apply train and test
|
| 164 |
|
| 165 |
+
logname = "/PATH/Activator/Experiments_cifar10/logs_activator/logs_cifar10_only_geglu.csv"
|
| 166 |
if not os.path.exists(logname):
|
| 167 |
with open(logname, 'w') as logfile:
|
| 168 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
|
| 174 |
for epoch in range(epochs):
|
| 175 |
print(f"Epoch {epoch+1}\n-----------------------------------")
|
| 176 |
train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
|
| 177 |
+
|
|
|
|
|
|
|
| 178 |
test_loss, test_acc = test(test_dataloader, model, loss_fn)
|
| 179 |
with open(logname, 'a') as logfile:
|
| 180 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
|
| 184 |
|
| 185 |
# saving trained model
|
| 186 |
|
| 187 |
+
path = "/PATH/Activator/Experiments_cifar10/weights_activator"
|
| 188 |
model_name = "ACTIVATOR_only_GEGLUImageClassification_cifar10"
|
| 189 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 190 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|