File size: 641 Bytes
352cafd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import torch

class ModelSaver:
    """Write model ``state_dict`` checkpoints under ``./weights/<id>/``.

    Passing ``id=None`` disables saving: ``save_model`` then only prints a
    notice and returns without touching the filesystem.

    Attributes:
        no_log: ``True`` when saving is disabled, ``False`` otherwise.
        save_path: Directory that checkpoints are written to, or ``None``
            when saving is disabled.
    """

    def __init__(self, id):
        """Configure the saver.

        Args:
            id: Identifier used as the sub-directory name below
                ``./weights``. ``None`` disables saving.
        """
        if id is None:
            self.no_log = True
            # Always define the attribute so instances have a uniform shape.
            self.save_path = None
            print('Saving has been disabled.')
        else:
            self.no_log = False
            # str() rather than '%s' % id: the latter raises TypeError when
            # id is a multi-element tuple.
            self.save_path = os.path.join('.', 'weights', str(id))

    def save_model(self, model, step):
        """Save ``model.state_dict()`` to ``<save_path>/model_<step>``.

        The target directory is created on demand. When saving is disabled
        a notice is printed and nothing is written.

        Args:
            model: Object exposing ``state_dict()``; its result is passed to
                ``torch.save``.
            step: Value embedded in the checkpoint file name.
        """
        if self.no_log:
            print('Saving has been disabled.')
            return

        os.makedirs(self.save_path, exist_ok=True)

        # Wrap step in a 1-tuple so tuple-valued steps format safely.
        model_path = os.path.join(self.save_path, 'model_%s' % (step,))
        torch.save(model.state_dict(), model_path)
        print('Model saved to %s.' % model_path)