File size: 528 Bytes
c9e0c1d
 
06c8a6d
c9e0c1d
 
 
 
 
0e63e05
06c8a6d
 
 
 
 
 
 
 
c9e0c1d
06c8a6d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import os

import torch


class ModelConfig:
    """Hyperparameters and Weights & Biases (wandb) settings for a training run.

    Every setting is a plain instance attribute assigned in ``__init__``.
    The wandb API key is taken from the ``WANDB_API_KEY`` environment
    variable instead of being written into this file.
    """

    def __init__(self):
        # Optimisation / training-loop settings.
        self.learning_rate = 0.001
        self.batch_size = 32
        # Pick the GPU when torch reports one is available, otherwise the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.epochs = 5
        # Log every `log_interval` batches, i.e. every
        # `log_interval * batch_size` samples (2 * 32 = 64 with the defaults).
        self.log_interval = 2

        # Weights & Biases settings.
        self.wandb = True  # presumably toggles wandb logging in the caller — confirm
        self.wandb_project = "template-pytorch-model"
        self.wandb_entity = "nguyen"
        # Read the API key from the environment so credentials never get
        # committed with the source. Falls back to "" (the previous
        # hard-coded default) when the variable is not set.
        self.wandb_api_key = os.environ.get("WANDB_API_KEY", "")

    def get_config(self):
        """Return this configuration object itself (not a copy)."""
        return self