| import torch | |
| import torch.nn as nn | |
| import os | |
| import numpy as np | |
| from models import spinal_net | |
| import decoder | |
| import loss | |
| from dataset import BaseDataset | |
def collater(data):
    """Collate a list of per-sample dicts into one dict of batched tensors.

    Args:
        data: Non-empty sequence of dicts mapping a field name to a NumPy
            array. The keys of the first sample define the output fields.

    Returns:
        Dict with the same keys (in the first sample's order); each value is
        the per-sample arrays converted with ``torch.from_numpy`` and stacked
        along a new leading dimension.

    Raises:
        KeyError: If a later sample contains a key the first sample lacks.
    """
    # Buckets are keyed by the first sample only, so an unexpected key in a
    # later sample fails loudly instead of being silently dropped.
    buckets = {field: [] for field in data[0]}
    for sample in data:
        for field, array in sample.items():
            buckets[field].append(torch.from_numpy(array))
    return {field: torch.stack(tensors, dim=0) for field, tensors in buckets.items()}
class Network(object):
    """Training driver that owns a ``SpineNet`` model, its decoder and datasets.

    Provides checkpoint save/load helpers, a full training loop
    (``train_network``) and a single-pass train/validation routine
    (``run_epoch``).
    """

    def __init__(self, args):
        """Seed torch, pick the device and build the model and decoder.

        Args:
            args: Namespace-like object. Reads ``num_classes``, ``down_ratio``,
                ``K`` and ``conf_thresh``.
        """
        torch.manual_seed(317)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Output channels per prediction head.
        # NOTE(review): presumably heatmap / offset / size heads - confirm
        # against spinal_net.SpineNet.
        heads = {
            'hm': args.num_classes,
            'reg': 2 * args.num_classes,
            'wh': 2 * 4,
        }
        self.model = spinal_net.SpineNet(heads=heads,
                                         pretrained=True,
                                         down_ratio=args.down_ratio,
                                         final_kernel=1,
                                         head_conv=256)
        self.num_classes = args.num_classes
        self.decoder = decoder.DecDecoder(K=args.K, conf_thresh=args.conf_thresh)
        # Maps the value of ``args.dataset`` to the dataset class to instantiate.
        self.dataset = {'spinal': BaseDataset}

    def save_model(self, path, epoch, model):
        """Write ``{'epoch', 'state_dict'}`` for ``model`` to ``path``.

        A ``torch.nn.DataParallel`` wrapper is unwrapped first, so the saved
        keys carry no ``module.`` prefix.

        Args:
            path: Destination file path passed to ``torch.save``.
            epoch: Epoch number stored alongside the weights.
            model: Module (optionally wrapped in ``DataParallel``) to save.
        """
        if isinstance(model, torch.nn.DataParallel):
            model = model.module
        torch.save({'epoch': epoch, 'state_dict': model.state_dict()}, path)

    def load_model(self, model, resume, strict=True):
        """Load weights from the checkpoint at ``resume`` into ``model``.

        Tensors are mapped to CPU storage while loading. A leading
        ``module.`` prefix (as produced by saving a ``DataParallel`` wrapper
        directly) is removed from checkpoint keys.

        Args:
            model: Module that receives the weights.
            resume: Path of a checkpoint with ``'epoch'`` and ``'state_dict'``
                entries.
            strict: When False, shape-mismatched and missing parameters are
                reported and replaced by the model's current values before
                loading. Note that ``load_state_dict`` is always called with
                ``strict=False``, so missing or unexpected keys never raise,
                whatever this flag is.

        Returns:
            The same ``model`` object, with weights loaded.
        """
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))

        # Strip only the exact 'module.' prefix. Matching the bare word
        # 'module' would also mangle keys such as 'modules_x.weight', which
        # would then be silently ignored by the non-strict load below.
        prefix = 'module.'
        state_dict = {}
        for key, value in checkpoint['state_dict'].items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            state_dict[key] = value

        model_state_dict = model.state_dict()
        if not strict:
            for key in state_dict:
                if key not in model_state_dict:
                    # Only reported: the key stays and is ignored on load.
                    print('Drop parameter {}.'.format(key))
                elif state_dict[key].shape != model_state_dict[key].shape:
                    print('Skip loading parameter {}, required shape{}, '
                          'loaded shape{}.'.format(key,
                                                   model_state_dict[key].shape,
                                                   state_dict[key].shape))
                    state_dict[key] = model_state_dict[key]
            for key in model_state_dict:
                if key not in state_dict:
                    print('No param {}.'.format(key))
                    state_dict[key] = model_state_dict[key]
        model.load_state_dict(state_dict, strict=False)
        return model

    def train_network(self, args):
        """Train and validate the model, writing checkpoints and loss logs.

        Outputs go to the directory ``'weights_' + args.dataset``:
        ``train_loss.txt`` / ``val_loss.txt`` (rewritten every epoch),
        ``model_<epoch>.pth`` at epoch 1 and every 10th epoch, and
        ``model_last.pth`` whenever the validation loss is the lowest so far.

        Args:
            args: Namespace-like object. Reads ``dataset``, ``init_lr``,
                ``ngpus``, ``data_dir``, ``input_h``, ``input_w``,
                ``down_ratio``, ``batch_size``, ``num_workers`` and
                ``num_epoch``.
        """
        save_path = 'weights_' + args.dataset
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(save_path, exist_ok=True)

        self.optimizer = torch.optim.Adam(self.model.parameters(), args.init_lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.96, last_epoch=-1)
        if args.ngpus > 0 and torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        criterion = loss.LossAll()
        print('Setting up data...')

        dataset_module = self.dataset[args.dataset]
        dsets = {phase: dataset_module(data_dir=args.data_dir,
                                       phase=phase,
                                       input_h=args.input_h,
                                       input_w=args.input_w,
                                       down_ratio=args.down_ratio)
                 for phase in ['train', 'val']}
        dsets_loader = {
            'train': torch.utils.data.DataLoader(dsets['train'],
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.num_workers,
                                                 pin_memory=True,
                                                 drop_last=True,
                                                 collate_fn=collater),
            'val': torch.utils.data.DataLoader(dsets['val'],
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=1,
                                               pin_memory=True,
                                               collate_fn=collater),
        }

        print('Starting training...')
        train_loss = []
        val_loss = []
        for epoch in range(1, args.num_epoch + 1):
            print('-' * 10)
            print('Epoch: {}/{} '.format(epoch, args.num_epoch))
            train_loss.append(self.run_epoch(phase='train',
                                             data_loader=dsets_loader['train'],
                                             criterion=criterion))
            # One decay step per epoch. Passing the epoch number to step() is
            # deprecated and yields the same schedule.
            scheduler.step()
            val_loss.append(self.run_epoch(phase='val',
                                           data_loader=dsets_loader['val'],
                                           criterion=criterion))

            np.savetxt(os.path.join(save_path, 'train_loss.txt'), train_loss, fmt='%.6f')
            np.savetxt(os.path.join(save_path, 'val_loss.txt'), val_loss, fmt='%.6f')

            if epoch % 10 == 0 or epoch == 1:
                self.save_model(os.path.join(save_path, 'model_{}.pth'.format(epoch)), epoch, self.model)

            # The first epoch counts as best-so-far; otherwise model_last.pth
            # would never be written when epoch 1 has the lowest loss.
            is_best = len(val_loss) == 1 or val_loss[-1] < np.min(val_loss[:-1])
            if is_best:
                self.save_model(os.path.join(save_path, 'model_last.pth'), epoch, self.model)

    def run_epoch(self, phase, data_loader, criterion):
        """Run one pass over ``data_loader`` and return the mean batch loss.

        Args:
            phase: ``'train'`` enables training mode, gradients and optimizer
                steps (requires ``self.optimizer``); any other value runs the
                model in eval mode under ``torch.no_grad()``.
            data_loader: Sized iterable of dicts of tensors; each dict must
                contain an ``'input'`` entry. The dicts are updated in place
                with tensors moved to ``self.device``.
            criterion: Callable ``(predictions, batch_dict) -> scalar tensor``.

        Returns:
            Sum of per-batch losses divided by ``len(data_loader)``.

        Raises:
            ZeroDivisionError: If ``data_loader`` has no batches.
        """
        training = phase == 'train'
        if training:
            self.model.train()
        else:
            self.model.eval()

        running_loss = 0.
        for data_dict in data_loader:
            for name in data_dict:
                data_dict[name] = data_dict[name].to(device=self.device)

            if training:
                self.optimizer.zero_grad()
                with torch.enable_grad():
                    pr_decs = self.model(data_dict['input'])
                    # Named batch_loss so the imported ``loss`` module is not shadowed.
                    batch_loss = criterion(pr_decs, data_dict)
                    batch_loss.backward()
                    self.optimizer.step()
            else:
                with torch.no_grad():
                    pr_decs = self.model(data_dict['input'])
                    batch_loss = criterion(pr_decs, data_dict)

            running_loss += batch_loss.item()

        epoch_loss = running_loss / len(data_loader)
        print('{} loss: {}'.format(phase, epoch_loss))
        return epoch_loss