import argparse import os import sys import logging import numpy import numpy as np import torch import torch.utils.data import torchvision from torch.utils.data import DataLoader from tensorboardX import SummaryWriter from tqdm import tqdm # Only if the files are in example folder. BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if BASE_DIR[-8:] == 'examples': sys.path.append(os.path.join(BASE_DIR, os.pardir)) os.chdir(os.path.join(BASE_DIR, os.pardir)) from learning3d.models import PRNet from learning3d.data_utils import RegistrationData, ModelNet40Data def _init_(args): if not os.path.exists('checkpoints'): os.makedirs('checkpoints') if not os.path.exists('checkpoints/' + args.exp_name): os.makedirs('checkpoints/' + args.exp_name) if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'): os.makedirs('checkpoints/' + args.exp_name + '/' + 'models') os.system('cp train_dcp.py checkpoints' + '/' + args.exp_name + '/' + 'train.py.backup') class IOStream: def __init__(self, path): self.f = open(path, 'a') def cprint(self, text): print(text) self.f.write(text + '\n') self.f.flush() def close(self): self.f.close() def get_transformations(igt): R_ba = igt[:, 0:3, 0:3] # Ps = R_ba * Pt translation_ba = igt[:, 0:3, 3].unsqueeze(2) # Ps = Pt + t_ba R_ab = R_ba.permute(0, 2, 1) # Pt = R_ab * Ps translation_ab = -torch.bmm(R_ab, translation_ba) # Pt = Ps + t_ab return R_ab, translation_ab, R_ba, translation_ba def test_one_epoch(device, model, test_loader): model.eval() test_loss = 0.0 pred = 0.0 count = 0 for i, data in enumerate(tqdm(test_loader)): template, source, igt = data transformations = get_transformations(igt) transformations = [t.to(device) for t in transformations] R_ab, translation_ab, R_ba, translation_ba = transformations template = template.to(device) source = source.to(device) igt = igt.to(device) output = model(template, source, R_ab, translation_ab.squeeze(2)) loss_val = output['loss'] test_loss += loss_val.item() count += 1 test_loss = float(test_loss)/count return test_loss def test(args, model, test_loader, textio): test_loss = test_one_epoch(args.device, model, test_loader) textio.cprint('Validation Loss: %f & Validation Accuracy: %f'%(test_loss, test_accuracy)) def train_one_epoch(device, model, train_loader, optimizer): model.train() train_loss = 0.0 pred = 0.0 count = 0 for i, data in enumerate(tqdm(train_loader)): template, source, igt = data transformations = get_transformations(igt) transformations = [t.to(device) for t in transformations] R_ab, translation_ab, R_ba, translation_ba = transformations template = template.to(device) source = source.to(device) igt = igt.to(device) output = model(template, source, R_ab, translation_ab.squeeze(2)) loss_val = output['loss'] # forward + backward + optimize optimizer.zero_grad() loss_val.backward() optimizer.step() train_loss += loss_val.item() count += 1 train_loss = float(train_loss)/count return train_loss def train(args, model, train_loader, test_loader, boardio, textio, checkpoint): learnable_params = filter(lambda p: p.requires_grad, model.parameters()) if args.optimizer == 'Adam': optimizer = torch.optim.Adam(learnable_params) else: optimizer = torch.optim.SGD(learnable_params, lr=0.1) if checkpoint is not None: min_loss = checkpoint['min_loss'] optimizer.load_state_dict(checkpoint['optimizer']) best_test_loss = np.inf for epoch in range(args.start_epoch, args.epochs): train_loss = train_one_epoch(args.device, model, train_loader, optimizer) test_loss = test_one_epoch(args.device, model, test_loader) if test_loss