import argparse
import os
import shutil
import sys
import logging
import numpy
import numpy as np
import torch
import torch.utils.data
import torchvision
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm

# When this script is run from inside the `examples` folder, put the parent
# (repository root) on sys.path so the `learning3d` imports below resolve, and
# chdir there so relative paths such as 'checkpoints/' are created at the root.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR[-8:] == 'examples':
    sys.path.append(os.path.join(BASE_DIR, os.pardir))
    os.chdir(os.path.join(BASE_DIR, os.pardir))

from learning3d.models import PointNet
from learning3d.models import iPCRNet
from learning3d.losses import ChamferDistanceLoss
from learning3d.data_utils import RegistrationData, ModelNet40Data


def _init_(args):
    """Create the checkpoint directory tree for an experiment and back up sources.

    Ensures ``checkpoints/<exp_name>/models`` exists (creating any missing
    parents), then copies ``main.py`` and ``model.py`` from the current working
    directory to ``checkpoints/<exp_name>/main.py.backup`` and
    ``checkpoints/<exp_name>/model.py.backup``.

    Args:
        args: parsed command-line namespace; only ``args.exp_name`` is read.
    """
    exp_dir = os.path.join('checkpoints', args.exp_name)
    # exist_ok replaces the three separate exists()/makedirs() checks.
    os.makedirs(os.path.join(exp_dir, 'models'), exist_ok=True)
    # Backups are best-effort, as they were with the previous `cp` shell calls
    # (which printed an error and continued). Copying directly avoids passing
    # exp_name through a shell, where it could inject commands.
    for filename in ('main.py', 'model.py'):
        try:
            shutil.copy(filename, os.path.join(exp_dir, filename + '.backup'))
        except OSError as err:
            print('Could not back up %s: %s' % (filename, err))


class IOStream:
    """Write text both to stdout and to a log file opened in append mode.

    Can be used as a context manager so the log file is closed on exit.
    """

    def __init__(self, path):
        # Append mode: any existing content at `path` is preserved.
        self.f = open(path, 'a')

    def cprint(self, text):
        """Print ``text`` and append it plus a newline to the log, flushing immediately."""
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


def test_one_epoch(device, model, test_loader):
    """Evaluate ``model`` over ``test_loader`` and return the mean Chamfer loss.

    Each batch is unpacked as ``(template, source, igt)``. The model is called
    as ``model(template, source)`` and the loss compares ``template`` with
    ``output['transformed_source']``; ``igt`` is not used here.

    Args:
        device: target passed to ``Tensor.to`` for the batch tensors.
        model: registration network; switched to eval mode.
        test_loader: iterable of ``(template, source, igt)`` batches.

    Returns:
        float: mean of the per-batch loss values.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    loss_fn = ChamferDistanceLoss()  # construct once instead of per batch
    total_loss = 0.0
    num_batches = 0
    # Evaluation needs no gradients; skipping graph construction saves
    # memory and time.
    with torch.no_grad():
        for template, source, _igt in tqdm(test_loader):
            template = template.to(device)
            source = source.to(device)
            output = model(template, source)
            loss_val = loss_fn(template, output['transformed_source'])
            total_loss += loss_val.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError('test_loader yielded no batches; cannot compute mean loss')
    return float(total_loss) / num_batches


def test(args, model, test_loader, textio):
    """Run one evaluation pass and log the validation loss through ``textio``.

    ``test_one_epoch`` returns only a loss value, so only the loss is reported.
    """
    test_loss = test_one_epoch(args.device, model, test_loader)
    textio.cprint('Validation Loss: %f' % test_loss)
train_one_epoch(device, model, train_loader, optimizer): model.train() train_loss = 0.0 pred = 0.0 count = 0 for i, data in enumerate(tqdm(train_loader)): template, source, igt = data template = template.to(device) source = source.to(device) igt = igt.to(device) output = model(template, source) loss_val = ChamferDistanceLoss()(template, output['transformed_source']) # print(loss_val.item()) # forward + backward + optimize optimizer.zero_grad() loss_val.backward() optimizer.step() train_loss += loss_val.item() count += 1 train_loss = float(train_loss)/count return train_loss def train(args, model, train_loader, test_loader, boardio, textio, checkpoint): learnable_params = filter(lambda p: p.requires_grad, model.parameters()) if args.optimizer == 'Adam': optimizer = torch.optim.Adam(learnable_params) else: optimizer = torch.optim.SGD(learnable_params, lr=0.1) if checkpoint is not None: min_loss = checkpoint['min_loss'] optimizer.load_state_dict(checkpoint['optimizer']) best_test_loss = np.inf for epoch in range(args.start_epoch, args.epochs): train_loss = train_one_epoch(args.device, model, train_loader, optimizer) test_loss = test_one_epoch(args.device, model, test_loader) if test_loss