Spaces:
Running
Running
File size: 3,707 Bytes
9c4b1c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# from pix2pix
import os
import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler
class BaseModel(nn.Module):
    """Checkpointing / mode-switching base class for trainable models.

    Subclasses are expected to set ``self.model`` (and, when training,
    ``self.optimizer``) before using :meth:`save_networks` /
    :meth:`load_networks`, and to define ``forward`` for :meth:`test`.
    """

    def __init__(self, opt):
        super(BaseModel, self).__init__()
        self.opt = opt
        self.total_steps = 0  # advanced externally by the training loop
        self.isTrain = opt.isTrain
        self.lr = opt.lr
        # Checkpoints live under ./checkpoint/<name>/weights/
        self.save_dir = os.path.join(f'./checkpoint/{opt.name}/weights/')
        os.makedirs(self.save_dir, exist_ok=True)
        # Fall back to CPU whenever CUDA is unavailable, regardless of opt.device.
        self.device = torch.device(opt.device if torch.cuda.is_available() else 'cpu')

    def save_networks(self, epoch):
        """Serialize ``self.model``'s state_dict to ``<save_dir>/<epoch>.pt``."""
        save_filename = f'{epoch}.pt'
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(self.model.state_dict(), save_path)
        print(f'Saving model {save_path}')

    def load_networks(self, epoch):
        """Load ``<save_dir>/<epoch>.pt`` into ``self.model``.

        Supports both checkpoint formats found in this codebase:
        a legacy dict bundling ``{'model', 'optimizer', 'total_steps'}``,
        and the bare ``state_dict`` written by :meth:`save_networks`.
        (The original code only handled the legacy dict, so checkpoints
        written by ``save_networks`` could not be read back.)
        """
        load_filename = f'{epoch}.pt'
        load_path = os.path.join(self.save_dir, load_filename)
        print('loading the model from %s' % load_path)
        state_dict = torch.load(load_path, map_location=self.device)
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata
        if 'model' in state_dict:
            # Legacy format: dict with model/optimizer/step counter.
            self.model.load_state_dict(state_dict['model'])
            self.total_steps = state_dict['total_steps']
            if self.isTrain and not self.opt.new_optim:
                self.optimizer.load_state_dict(state_dict['optimizer'])
                # Move restored optimizer state onto the active device.
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(self.device)
                for g in self.optimizer.param_groups:
                    g['lr'] = self.opt.lr
        else:
            # Bare model state_dict, as written by save_networks above.
            self.model.load_state_dict(state_dict)

    def eval(self):
        """Put the wrapped model in eval mode; returns self (nn.Module convention)."""
        self.model.eval()
        return self

    def train(self, mode=True):
        """Put the wrapped model in train (or eval, mode=False) mode.

        Accepting ``mode`` mirrors ``nn.Module.train``'s signature; calling
        ``train()`` with no argument behaves exactly as before.
        """
        self.model.train(mode)
        return self

    def test(self):
        """Run a gradient-free forward pass (subclass must define ``forward``)."""
        with torch.no_grad():
            self.forward()
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize net's Conv/Linear and BatchNorm2d layers in place.

    ``init_type`` selects the scheme for Conv/Linear weights
    ('normal' | 'xavier' | 'kaiming' | 'orthogonal'); biases are zeroed.
    BatchNorm2d weights are drawn from N(1, gain) with zero bias.
    Raises NotImplementedError for an unknown ``init_type``.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=gain),
    }

    def init_func(m):
        layer_name = m.__class__.__name__
        is_conv_or_linear = 'Conv' in layer_name or 'Linear' in layer_name
        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type not in weight_initializers:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_initializers[init_type](m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in layer_name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
|