Yingtao-Zheng committed on
Commit
9216f6f
·
1 Parent(s): a12e47d

Add ClearML experiment tracking to MLP training

Browse files
Files changed (2) hide show
  1. models/mlp/train.py +28 -1
  2. requirements.txt +1 -0
models/mlp/train.py CHANGED
@@ -1,5 +1,5 @@
1
  import json
2
- import os
3
  import random
4
 
5
  import numpy as np
@@ -7,6 +7,8 @@ import torch
7
  import torch.nn as nn
8
  import torch.optim as optim
9
 
 
 
10
  from models.prepare_dataset import get_dataloaders
11
 
12
  CFG = {
@@ -24,6 +26,21 @@ CFG = {
24
  }
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def set_seed(seed: int):
28
  random.seed(seed)
29
  np.random.seed(seed)
@@ -152,6 +169,16 @@ def main():
152
  history["val_loss"].append(round(val_loss, 4))
153
  history["val_acc"].append(round(val_acc, 4))
154
 
 
 
 
 
 
 
 
 
 
 
155
  marker = ""
156
  if val_acc > best_val_acc:
157
  best_val_acc = val_acc
 
1
  import json
2
+ import os, sys
3
  import random
4
 
5
  import numpy as np
 
7
  import torch.nn as nn
8
  import torch.optim as optim
9
 
10
+ from clearml import Task
11
+
12
  from models.prepare_dataset import get_dataloaders
13
 
14
  CFG = {
 
26
  }
27
 
28
 
29
+ # ==== ClearML Initialisation =============================================
30
+ task = Task.init(
31
+ project_name="Focus Guard",
32
+ task_name=f"MLP Model Training",
33
+ tags=["training", "mlp_model"]
34
+ )
35
+
36
+ prefix = 'checkpoints/'+task.name+'_'+task.id+'/'
37
+ os.makedirs(prefix, exist_ok=True)
38
+
39
+ task.connect(CFG)
40
+
41
+
42
+
43
+ # ==== Model =============================================
44
  def set_seed(seed: int):
45
  random.seed(seed)
46
  np.random.seed(seed)
 
169
  history["val_loss"].append(round(val_loss, 4))
170
  history["val_acc"].append(round(val_acc, 4))
171
 
172
+
173
+ # Log scalars to ClearML
174
+ current_lr = optimizer.param_groups[0]['lr']
175
+ task.logger.report_scalar("Loss", "Train", float(train_loss), iteration=epoch)
176
+ task.logger.report_scalar("Accuracy", "Train", float(train_acc), iteration=epoch)
177
+ task.logger.report_scalar("Loss", "Val", float(val_loss), iteration=epoch)
178
+ task.logger.report_scalar("Accuracy", "Val", float(val_acc), iteration=epoch)
179
+ task.logger.report_scalar("Learning Rate", "LR", float(current_lr), iteration=epoch)
180
+ task.logger.flush()
181
+
182
  marker = ""
183
  if val_acc > best_val_acc:
184
  best_val_acc = val_acc
requirements.txt CHANGED
@@ -5,3 +5,4 @@ torch>=2.0.0
5
  torchvision>=0.15.0
6
  scikit-learn>=1.2.0
7
  joblib>=1.2.0
 
 
5
  torchvision>=0.15.0
6
  scikit-learn>=1.2.0
7
  joblib>=1.2.0
8
+ clearml>=2.0.2