drozdgk's picture
chore: vendor third_party (remove submodules, ignore artifacts)
352cafd
import os
import torch
class ModelSaver:
def __init__(self, id):
if id is None:
self.no_log = True
print('Saving has been disbaled.')
else:
self.no_log = False
self.save_path = os.path.join('.', 'weights', '%s' % id )
def save_model(self, model, step):
if self.no_log:
print('Saving has been disabled.')
return
os.makedirs(self.save_path, exist_ok=True)
model_path = os.path.join(self.save_path, 'model_%s' % step)
torch.save(model.state_dict(), model_path)
print('Model saved to %s.' % model_path)