| import os |
| import torch |
| import torch.optim as optim |
| from tqdm import tqdm |
|
|
| from torch.autograd import Variable |
|
|
| from network_v0.model import PointModel |
| from loss_function import KeypointLoss |
|
|
|
|
class Trainer(object):
    """Run the training loop for a :class:`PointModel`.

    Responsibilities: build the model, loss, optimizer and LR scheduler from
    ``config``; iterate the train loader; log progress; and save/restore
    checkpoints.

    Args:
        config: namespace-like object exposing the hyper-parameters read in
            ``__init__`` (``max_epoch``, ``init_lr``, ``lr_factor``,
            ``ckpt_dir``, ...).
        train_loader: iterable of batches; each batch is a dict with keys
            ``"image"``, ``"image_aug"`` and ``"homography"``.
    """

    def __init__(self, config, train_loader=None):
        self.config = config

        if train_loader is None:
            # Fail fast with a clear message instead of the later
            # ``TypeError: object of type 'NoneType' has no len()``.
            raise ValueError("Trainer requires a train_loader")
        self.train_loader = train_loader
        # NOTE: len(loader) is the number of *batches* per epoch, not samples.
        self.num_train = len(self.train_loader)

        # Optimization hyper-parameters.
        self.max_epoch = config.max_epoch
        self.start_epoch = config.start_epoch
        self.momentum = config.momentum  # kept for config compatibility; Adam ignores it
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display  # log every `display` iterations

        # Device and checkpoint bookkeeping.
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)

        # Build the model in training mode.
        self.model = PointModel(is_test=False)
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        print(
            "Number of model parameters: {:,}".format(
                sum(p.numel() for p in self.model.parameters())
            )
        )

        self.loss_func = KeypointLoss(config)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor
        )

        # Optionally resume from a previous checkpoint.  The loaded epoch is
        # stored in self.start_epoch, which is what train() actually reads
        # (the original only updated config.start_epoch); config is kept in
        # sync for any external readers.
        if int(self.config.start_epoch) > 0:
            (
                self.start_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            ) = self.load_checkpoint(
                int(self.config.start_epoch),
                self.model,
                self.optimizer,
                self.lr_scheduler,
            )
            self.config.start_epoch = self.start_epoch

    def train(self):
        """Train from ``start_epoch`` to ``max_epoch``, checkpointing every epoch."""
        print("\nTrain on {} batches".format(self.num_train))
        self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # Report the LR actually in effect: the scheduler mutates the
            # optimizer's param groups, never self.lr.
            cur_lr = self.optimizer.param_groups[0]["lr"]
            print(
                "\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, cur_lr)
            )

            self.train_one_epoch(epoch)
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.save_checkpoint(
                epoch + 1, self.model, self.optimizer, self.lr_scheduler
            )

    def train_one_epoch(self, epoch):
        """Run one full pass over ``train_loader`` (forward / backward / step)."""
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            # Unpack the batch unconditionally so CPU-only runs work too
            # (the original only bound these names under use_gpu, raising
            # NameError on CPU).  Deprecated Variable wrappers removed:
            # they are no-ops since PyTorch 0.4.
            source_img = data["image_aug"]
            target_img = data["image"]
            homography = data["homography"]
            if self.use_gpu:
                source_img = source_img.cuda()
                target_img = target_img.cuda()
                homography = homography.cuda()

            # Forward pass and multi-term keypoint loss.
            output = self.model(source_img, target_img, homography)
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # Backward pass and parameter update.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Periodic console logging; .item() extracts Python scalars
            # (the deprecated .data printed raw tensors).
            if (i % self.display) == 0:
                cur_lr = self.optimizer.param_groups[0]["lr"]
                msg_batch = (
                    "Epoch:{} Iter:{} lr:{:.4f} "
                    "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                    "loss={:.4f} ".format(
                        (epoch + 1),
                        i,
                        cur_lr,
                        loc_loss.item(),
                        desc_loss.item(),
                        score_loss.item(),
                        corres_loss.item(),
                        loss.item(),
                    )
                )
                print(msg_batch)
        # NOTE(review): the original contained a stray bare ``return`` that
        # could cut the epoch short after the first logged batch; the loop
        # now always runs to completion.

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Serialize model/optimizer/scheduler state to ``ckpt_dir``."""
        # Create the directory on first use instead of failing with
        # FileNotFoundError.
        os.makedirs(self.ckpt_dir, exist_ok=True)
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
            },
            os.path.join(self.ckpt_dir, filename),
        )

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore the state written by :meth:`save_checkpoint`.

        Returns:
            Tuple ``(epoch, model, optimizer, lr_scheduler)`` with the
            loaded state applied in place.
        """
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        # map_location lets a CPU-only machine load GPU-trained checkpoints.
        ckpt = torch.load(
            os.path.join(self.ckpt_dir, filename),
            map_location=None if self.use_gpu else "cpu",
        )
        epoch = ckpt["epoch"]
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])

        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))

        return epoch, model, optimizer, lr_scheduler
|
|