import argparse

import numpy as np

import os

import shutil

import torch
import torch.optim as optim

from torch.utils.data import DataLoader

from tqdm import tqdm

import warnings

from lib.dataset import MegaDepthDataset
from lib.exceptions import NoGradientError
from lib.loss import loss_function
from lib.model import D2Net

# CUDA setup: run on the first GPU if one is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Seed all RNGs for reproducibility.
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)
np.random.seed(1)

# Argument parsing.
parser = argparse.ArgumentParser(description="Training script")

parser.add_argument(
    "--dataset_path", type=str, required=True, help="path to the dataset"
)
parser.add_argument(
    "--scene_info_path", type=str, required=True, help="path to the processed scenes"
)

parser.add_argument(
    "--preprocessing",
    type=str,
    default="caffe",
    help="image preprocessing (caffe or torch)",
)
parser.add_argument(
    "--model_file", type=str, default="models/d2_ots.pth", help="path to the full model"
)

parser.add_argument(
    "--num_epochs", type=int, default=10, help="number of training epochs"
)
parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument(
    "--num_workers", type=int, default=4, help="number of workers for data loading"
)

parser.add_argument(
    "--use_validation",
    dest="use_validation",
    action="store_true",
    help="use the validation split",
)
parser.set_defaults(use_validation=False)

parser.add_argument(
    "--log_interval", type=int, default=250, help="loss logging interval"
)

parser.add_argument("--log_file", type=str, default="log.txt", help="loss logging file")

parser.add_argument(
    "--plot", dest="plot", action="store_true", help="plot training pairs"
)
parser.set_defaults(plot=False)

parser.add_argument(
    "--checkpoint_directory",
    type=str,
    default="checkpoints",
    help="directory for training checkpoints",
)
parser.add_argument(
    "--checkpoint_prefix",
    type=str,
    default="d2",
    help="prefix for training checkpoints",
)

args = parser.parse_args()

print(args)
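# Example invocation (a sketch: assumes this file is saved as train.py, and the
# paths below are placeholders for your local MegaDepth layout):
#   python train.py --dataset_path /path/to/MegaDepth \
#       --scene_info_path /path/to/scene_info --use_validation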

# Create the plotting directory if needed.
if args.plot:
    plot_path = "train_vis"
    if os.path.isdir(plot_path):
        print("[Warning] Plotting directory already exists.")
    else:
        os.mkdir(plot_path)

# Create the model, loading the weights given by --model_file.
model = D2Net(model_file=args.model_file, use_cuda=use_cuda)

# Optimize only the parameters that require gradients.
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
)

# Datasets and data loaders.
if args.use_validation:
    validation_dataset = MegaDepthDataset(
        scene_list_path="megadepth_utils/valid_scenes.txt",
        scene_info_path=args.scene_info_path,
        base_path=args.dataset_path,
        train=False,
        preprocessing=args.preprocessing,
        pairs_per_scene=25,
    )
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=args.batch_size, num_workers=args.num_workers
    )

training_dataset = MegaDepthDataset(
    scene_list_path="megadepth_utils/train_scenes.txt",
    scene_info_path=args.scene_info_path,
    base_path=args.dataset_path,
    preprocessing=args.preprocessing,
)
training_dataloader = DataLoader(
    training_dataset, batch_size=args.batch_size, num_workers=args.num_workers
)
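# Note: the datasets expose build_dataset(), which samples a fresh set of image
# pairs; below, the training set is rebuilt once per epoch and the validation
# set once up front.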


def process_epoch(
    epoch_idx,
    model,
    loss_function,
    optimizer,
    dataloader,
    device,
    log_file,
    args,
    train=True,
):
    # Run one pass over the dataloader; the model is updated only when
    # train=True, otherwise this is a pure evaluation pass.
    epoch_losses = []

    torch.set_grad_enabled(train)

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_idx, batch in progress_bar:
        if train:
            optimizer.zero_grad()

        batch["train"] = train
        batch["epoch_idx"] = epoch_idx
        batch["batch_idx"] = batch_idx
        batch["batch_size"] = args.batch_size
        batch["preprocessing"] = args.preprocessing
        batch["log_interval"] = args.log_interval

        try:
            loss = loss_function(model, batch, device, plot=args.plot)
        except NoGradientError:
            # Skip batches that yield no usable correspondences.
            continue

        # loss is a 0-dim tensor; .item() extracts the Python scalar.
        current_loss = loss.item()
        epoch_losses.append(current_loss)

        progress_bar.set_postfix(loss=("%.4f" % np.mean(epoch_losses)))

        if batch_idx % args.log_interval == 0:
            log_file.write(
                "[%s] epoch %d - batch %d / %d - avg_loss: %f\n"
                % (
                    "train" if train else "valid",
                    epoch_idx,
                    batch_idx,
                    len(dataloader),
                    np.mean(epoch_losses),
                )
            )

        if train:
            loss.backward()
            optimizer.step()

    log_file.write(
        "[%s] epoch %d - avg_loss: %f\n"
        % ("train" if train else "valid", epoch_idx, np.mean(epoch_losses))
    )
    log_file.flush()

    return np.mean(epoch_losses)


# Create the checkpoint directory.
if os.path.isdir(args.checkpoint_directory):
    print("[Warning] Checkpoint directory already exists.")
else:
    os.mkdir(args.checkpoint_directory)

# Open the log file in append mode.
if os.path.exists(args.log_file):
    print("[Warning] Log file already exists.")
log_file = open(args.log_file, "a+")

# Initialize the loss histories; with validation enabled, epoch 0 is an
# evaluation-only pass that sets the initial best validation loss.
train_loss_history = []
validation_loss_history = []
if args.use_validation:
    validation_dataset.build_dataset()
    min_validation_loss = process_epoch(
        0,
        model,
        loss_function,
        optimizer,
        validation_dataloader,
        device,
        log_file,
        args,
        train=False,
    )

# Training loop.
for epoch_idx in range(1, args.num_epochs + 1):
    # Resample the training pairs for this epoch.
    training_dataset.build_dataset()
    train_loss_history.append(
        process_epoch(
            epoch_idx,
            model,
            loss_function,
            optimizer,
            training_dataloader,
            device,
            log_file,
            args,
        )
    )

    if args.use_validation:
        validation_loss_history.append(
            process_epoch(
                epoch_idx,
                model,
                loss_function,
                optimizer,
                validation_dataloader,
                device,
                log_file,
                args,
                train=False,
            )
        )

    # Save a checkpoint for this epoch and keep a copy of the best one so far.
    checkpoint_path = os.path.join(
        args.checkpoint_directory, "%s.%02d.pth" % (args.checkpoint_prefix, epoch_idx)
    )
    checkpoint = {
        "args": args,
        "epoch_idx": epoch_idx,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_loss_history": train_loss_history,
        "validation_loss_history": validation_loss_history,
    }
    torch.save(checkpoint, checkpoint_path)
    if args.use_validation and validation_loss_history[-1] < min_validation_loss:
        min_validation_loss = validation_loss_history[-1]
        best_checkpoint_path = os.path.join(
            args.checkpoint_directory, "%s.best.pth" % args.checkpoint_prefix
        )
        shutil.copy(checkpoint_path, best_checkpoint_path)

# Close the log file.
log_file.close()
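# Sketch: resuming from a saved checkpoint (not part of the training flow
# above; assumes the model and optimizer were constructed exactly as in this
# script):
#   checkpoint = torch.load(checkpoint_path)
#   model.load_state_dict(checkpoint["model"])
#   optimizer.load_state_dict(checkpoint["optimizer"])
#   start_epoch = checkpoint["epoch_idx"] + 1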