import os

import torch
class ModelSaver:
    """Write model ``state_dict`` snapshots below ``./weights/<id>/``.

    Passing ``id=None`` disables saving: ``save_model`` then only prints a
    notice and touches nothing on disk.

    Attributes:
        no_log: ``True`` when saving is disabled (``id`` was ``None``).
        save_path: Directory the snapshots are written to, or ``None`` when
            saving is disabled.
    """

    def __init__(self, id):
        """Configure the target directory for this run.

        Args:
            id: Run identifier used as the sub-directory name under
                ``./weights``. ``None`` disables saving.
        """
        # NOTE: `id` shadows the builtin, but the parameter name is part of
        # the public signature (keyword callers), so it is kept as is.
        self.no_log = id is None
        if self.no_log:
            # Always define the attribute so that reading it on a disabled
            # saver yields None instead of raising AttributeError.
            self.save_path = None
            print('Saving has been disabled.')
        else:
            # str() instead of "'%s' % id": the %-operator unpacks a tuple
            # operand, which fails (or silently changes) for tuple ids.
            self.save_path = os.path.join('.', 'weights', str(id))

    def save_model(self, model, step):
        """Save ``model.state_dict()`` to ``<save_path>/model_<step>``.

        Does nothing except print a notice when saving is disabled.

        Args:
            model: Object exposing ``state_dict()``; the returned value is
                handed to ``torch.save``.
            step: Value formatted into the file name (``model_<step>``).
        """
        if self.no_log:
            print('Saving has been disabled.')
            return

        # Created lazily, so a saver that never saves leaves no directory.
        os.makedirs(self.save_path, exist_ok=True)
        # Wrap the operand in a 1-tuple so a tuple-valued step is formatted
        # as a whole rather than unpacked by the %-operator.
        model_path = os.path.join(self.save_path, 'model_%s' % (step,))
        torch.save(model.state_dict(), model_path)
        print('Model saved to %s.' % (model_path,))