Spaces:
Running
Running
| import torch | |
| import os | |
| import torch.nn.functional as F | |
def save_checkpoint(model, filelocation, save_parallel=True):
    """Save ``model``'s state dict to ``filelocation`` with ``torch.save``.

    Args:
        model: Module whose parameters are saved. When ``save_parallel`` is
            true and the object exposes the real network as ``model.module``
            (wrappers such as ``torch.nn.DataParallel`` do), the inner
            module's state dict is saved instead of the wrapper's.
        filelocation: Destination passed straight to ``torch.save``.
        save_parallel: Unwrap ``model.module`` before saving. If ``model``
            has no ``module`` attribute, its own state dict is saved rather
            than raising ``AttributeError`` (the previous behaviour, which
            made the default unusable for unwrapped models).
    """
    # getattr fallback: unwrap only when there is actually a wrapper.
    target = getattr(model, "module", model) if save_parallel else model
    torch.save(target.state_dict(), filelocation)
def load_Checkpoint(fileLocation, model, load_cpu=False):
    """Load a state dict from ``fileLocation`` into ``model`` and return it.

    Args:
        fileLocation: Path (or file object) handed to ``torch.load``.
        model: Module whose ``load_state_dict`` receives the loaded object.
        load_cpu: When true, pass an identity ``map_location`` so every
            storage stays where deserialization placed it (CPU) instead of
            being moved back to the device it was saved from.

    Returns:
        The same ``model`` object, after ``load_state_dict``.
    """
    # None is torch.load's own default for map_location.
    relocate = (lambda storage, loc: storage) if load_cpu else None
    checkpoint = torch.load(fileLocation, map_location=relocate)
    model.load_state_dict(checkpoint)
    return model
def writeLog(logList, filename):
    """Overwrite ``filename`` with the entries of ``logList``, newline-separated.

    Entries must already be strings. No newline is written after the last
    entry. The file is opened (and therefore truncated) before the entries
    are joined.
    """
    with open(filename, 'w') as handle:
        payload = "\n".join(logList)
        handle.write(payload)
def kl_loss(mu, logvar):
    """Return the mean of ``-0.5 * (1 + logvar - mu**2 - exp(logvar))``.

    Treating ``logvar`` as a log-variance, each element is the closed-form
    KL divergence of N(mu, exp(logvar)) from N(0, 1). Note the reduction is
    a mean over *every* element (not a sum over latent dims), so the scale
    depends on the tensor sizes.
    """
    variance = logvar.exp()
    # Same left-to-right evaluation as ((1 + logvar) - mu**2) - variance.
    elementwise = 1 + logvar - mu.pow(2) - variance
    return -0.5 * elementwise.mean()