Spaces:
Build error
Build error
| """Implements training new models""" | |
| import time | |
| import copy | |
| from collections import defaultdict | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as transforms | |
| from cirtorch.layers.loss import ContrastiveLoss | |
| from cirtorch.datasets.datahelpers import collate_tuples | |
| from cirtorch.datasets.traindataset import TuplesDataset | |
| from cirtorch.datasets.genericdataset import ImagesFromList | |
| from ..networks import how_net | |
| from ..utils import data_helpers, io_helpers, logging, plots | |
| from . import evaluate | |
def train(demo_train, training, validation, model, globals):
    """Demo training a network

    :param dict demo_train: Demo-related options
    :param dict training: Training options
    :param dict validation: Validation options
    :param dict model: Model options
    :param dict globals: Global options
    """
    logger = globals["logger"]
    (globals["exp_path"] / "epochs").mkdir(exist_ok=True)
    # If the final checkpoint already exists, there is nothing to do
    final_checkpoint = globals["exp_path"] / f"epochs/model_epoch{training['epochs']}.pth"
    if final_checkpoint.exists():
        logger.info("Skipping network training, already trained")
        return

    # Global setup
    set_seed(0)
    device = torch.device("cpu")
    if demo_train['gpu_id'] is not None:
        device = torch.device("cuda:%s" % demo_train['gpu_id'])
    globals["device"] = device

    # Initialize network and the input transform matching its expected statistics
    net = how_net.init_network(**model).to(globals["device"])
    mean_std = dict(zip(["mean", "std"], net.runtime['mean_std']))
    globals["transform"] = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**mean_std),
    ])
    with logging.LoggingStopwatch("initializing network whitening", logger.info, logger.debug):
        initialize_dim_reduction(net, globals, **training['initialize_dim_reduction'])

    # Initialize training
    param_groups = net.parameter_groups(training["optimizer"])
    optimizer, scheduler, criterion, train_loader = \
            initialize_training(param_groups, training, globals)
    validation = Validation(validation, globals)

    for epoch in range(training['epochs']):
        epoch1 = epoch + 1  # 1-based epoch used for logging and checkpoints
        set_seed(epoch1)    # re-seed per epoch for reproducible tuple mining
        time0 = time.time()

        train_loss = train_epoch(train_loader, net, globals, criterion, optimizer, epoch1)
        validation.add_train_loss(train_loss, epoch1)
        validation.validate(net, epoch1)
        scheduler.step()

        io_helpers.save_checkpoint({
            'epoch': epoch1, 'meta': net.meta, 'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(), 'best_score': validation.best_score[1],
            'scores': validation.scores, 'net_params': model, '_version': 'how/2020',
        }, validation.best_score[0] == epoch1, epoch1 == training['epochs'],
            globals["exp_path"] / "epochs")

        logger.info(f"Epoch {epoch1} finished in {time.time() - time0:.1f}s")
def train_epoch(train_loader, net, globals, criterion, optimizer, epoch1):
    """Train for one epoch

    Returns the average tuple loss over the epoch.
    """
    logger = globals['logger']
    device = globals["device"]
    batch_time = data_helpers.AverageMeter()
    data_time = data_helpers.AverageMeter()
    losses = data_helpers.AverageMeter()

    # Prepare epoch (mine hard tuples with the current network)
    train_loader.dataset.create_epoch_tuples(net)
    net.train()

    end = time.time()
    for i, (tuples, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        tuple_size = len(tuples[0])  # number of images per tuple
        for images, label in zip(tuples, targets):
            # Descriptor matrix (outputdim x tuple_size), one column per image
            descriptors = torch.zeros(net.meta['outputdim'], tuple_size).to(device)
            for j in range(tuple_size):
                descriptors[:, j] = net(images[j].to(device)).squeeze()
            # Backward per tuple to keep memory bounded; gradients accumulate
            loss = criterion(descriptors, label.to(device))
            loss.backward()
            losses.update(loss.item())
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        # Log the first batch, every 20th batch, and the last batch
        if (i+1) % 20 == 0 or i == 0 or (i+1) == len(train_loader):
            logger.info(f'>> Train: [{epoch1}][{i+1}/{len(train_loader)}]\t' \
                    f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                    f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                    f'Loss {losses.val:.4f} ({losses.avg:.4f})')

    return losses.avg
def set_seed(seed):
    """Sets given seed globally in used libraries

    :param int seed: Seed applied to numpy and torch (CPU and, if present, all CUDA devices)
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def initialize_training(net_parameters, training, globals):
    """Initialize classes necessary for training

    Returns a tuple (optimizer, scheduler, criterion, train_loader).
    """
    # Need to check for keys because of defaults
    required_keys = {
        "optimizer": {"lr", "weight_decay"},
        "lr_scheduler": {"gamma"},
        "loss": {"margin"},
        "dataset": {"name", "mode", "imsize", "nnum", "qsize", "poolsize"},
        "loader": {"batch_size"},
    }
    for section, keys in required_keys.items():
        assert training[section].keys() == keys

    optimizer = torch.optim.Adam(net_parameters, **training["optimizer"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, **training["lr_scheduler"])
    criterion = ContrastiveLoss(**training["loss"]).to(globals["device"])
    train_dataset = TuplesDataset(**training['dataset'], transform=globals["transform"])
    train_loader = torch.utils.data.DataLoader(
        train_dataset, **training['loader'], pin_memory=True, drop_last=True,
        shuffle=True, collate_fn=collate_tuples, num_workers=how_net.NUM_WORKERS)
    return optimizer, scheduler, criterion, train_loader
def extract_train_descriptors(net, globals, *, images, features_num):
    """Extract descriptors for a given number of images from the train set

    :param net: Network used for extraction
    :param dict globals: Global options (root_path, transform, device)
    :param int images: Number of train-set images to use
    :param features_num: Features per image; falls back to the network runtime default if None
    """
    if features_num is None:
        features_num = net.runtime['features_num']

    image_paths = data_helpers.load_dataset('train', data_root=globals['root_path'])[0][:images]
    dataset = ImagesFromList(root='', images=image_paths, imsize=net.runtime['image_size'],
                             bbxs=None, transform=globals["transform"])
    extracted = how_net.extract_vectors_local(net, dataset, globals["device"],
                                              scales=net.runtime['training_scales'],
                                              features_num=features_num)
    return extracted[0]
def initialize_dim_reduction(net, globals, **kwargs):
    """Initialize dimensionality reduction by PCA whitening from 'images' number of descriptors

    :param net: Network whose dim_reduction layer is initialized (no-op when absent)
    :param dict globals: Global options (provides logger, device, transform, root_path)
    :param kwargs: Passed to extract_train_descriptors (images, features_num)
    """
    if not net.dim_reduction:
        return

    # Use the experiment logger instead of a bare print() for consistency with
    # the rest of this module (everything else logs via globals["logger"])
    globals["logger"].info(">> Initializing dim reduction")
    # Extract descriptors with a copy of the network that skips the (untrained)
    # dim-reduction layer, then fit PCA whitening on them
    des_train = extract_train_descriptors(net.copy_excluding_dim_reduction(), globals, **kwargs)
    net.dim_reduction.initialize_pca_whitening(des_train)
class Validation:
    """A convenient interface to validation, keeping historical values and plotting continuously

    :param dict validations: Options for each validation type (e.g. local_descriptor)
    :param dict globals: Global options
    """

    # Evaluation function per validation type
    methods = {
        "global_descriptor": evaluate.eval_global,
        "local_descriptor": evaluate.eval_asmk,
    }

    def __init__(self, validations, globals):
        validations = copy.deepcopy(validations)
        # Pop 'frequency' so only evaluation kwargs remain in self.validations
        self.frequencies = {x: y.pop("frequency") for x, y in validations.items()}
        self.validations = validations
        self.globals = globals
        self.scores = {x: defaultdict(list) for x in validations}
        self.scores["train_loss"] = []

    def add_train_loss(self, loss, epoch):
        """Store training loss for given epoch"""
        self.scores['train_loss'].append((epoch, loss))
        fig = plots.EpochFigure("train set", ylabel="loss")
        fig.plot(*list(zip(*self.scores["train_loss"])), 'o-', label='train')
        fig.save(self.globals['exp_path'] / "fig_train.jpg")

    def validate(self, net, epoch):
        """Perform validation of the network and store the resulting score for given epoch"""
        for name, frequency in self.frequencies.items():
            if frequency and epoch % frequency == 0:
                scores = self.methods[name](net, net.runtime, self.globals, **self.validations[name])
                for dataset, values in scores.items():
                    # Prefer medium-difficulty mAP where the protocol provides it
                    value = values['map_medium'] if "map_medium" in values else values['map']
                    self.scores[name][dataset].append((epoch, value))
                if "val_eccv20" in scores:
                    fig = plots.EpochFigure(f"val set - {name}", ylabel="mAP")
                    fig.plot(*list(zip(*self.scores[name]['val_eccv20'])), 'o-', label='val')
                    fig.save(self.globals['exp_path'] / f"fig_val_{name}.jpg")
                if scores.keys() - {"val_eccv20"}:
                    fig = plots.EpochFigure(f"test set - {name}", ylabel="mAP")
                    for dataset, value in self.scores[name].items():
                        if dataset != "val_eccv20":
                            fig.plot(*list(zip(*value)), 'o-', label=dataset)
                    fig.save(self.globals['exp_path'] / f"fig_test_{name}.jpg")

    # NOTE: the following three must be properties — callers access them as
    # attributes (e.g. `validation.best_score[0]` in train(), and
    # `self.decisive_scores` without a call below). Without @property those
    # accesses would index/truth-test a bound method instead of its result.

    @property
    def decisive_scores(self):
        """List of pairs (epoch, score) where score is decisive for comparing epochs"""
        for name in ["local_descriptor", "global_descriptor"]:
            if self.frequencies.get(name) and "val_eccv20" in self.scores[name]:
                return self.scores[name]['val_eccv20']
        return self.scores["train_loss"]

    @property
    def last_epoch(self):
        """Tuple (last epoch, last score) or (None, None) before decisive score is computed"""
        decisive_scores = self.decisive_scores
        if not decisive_scores:
            return None, None
        return decisive_scores[-1]

    @property
    def best_score(self):
        """Tuple (best epoch, best score) or (None, None) before decisive score is computed"""
        decisive_scores = self.decisive_scores
        if not decisive_scores:
            return None, None

        # Validation mAP is maximized; raw train loss is minimized
        aggr = min
        for name in ["local_descriptor", "global_descriptor"]:
            if self.frequencies.get(name) and "val_eccv20" in self.scores[name]:
                aggr = max
        return aggr(decisive_scores, key=lambda x: x[1])