Spaces:
Build error
Build error
"""Implements evaluation of trained models"""

import time
import warnings
from pathlib import Path
import pickle

import numpy as np
import torch
from torchvision import transforms
from PIL import ImageFile

from cirtorch.datasets.genericdataset import ImagesFromList
from asmk import asmk_method, kernel as kern_pkg

from ..networks import how_net
from ..utils import score_helpers, data_helpers, logging

# Be lenient with dataset images: load truncated files instead of raising,
# and silence the noisy corrupt-EXIF warnings PIL emits on some benchmarks
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore", r"^Possibly corrupt EXIF data", category=UserWarning)
def evaluate_demo(demo_eval, evaluation, globals):
    """Demo evaluating a trained network

    :param dict demo_eval: Demo-related options
    :param dict evaluation: Evaluation-related options
    :param dict globals: Global options
    """
    # Select compute device: cpu unless a gpu id was configured
    gpu_id = demo_eval['gpu_id']
    globals["device"] = torch.device("cpu" if gpu_id is None else "cuda:%s" % gpu_id)

    # A directory net_path is resolved to its best checkpoint, if present
    net_path = Path(demo_eval['exp_folder']) / demo_eval['net_path']
    best_checkpoint = net_path / "epochs/model_best.pth"
    if net_path.is_dir() and best_checkpoint.exists():
        net_path = best_checkpoint

    # Load the network and restore its weights (old checkpoints are converted)
    state = _convert_checkpoint(torch.load(net_path, map_location='cpu'))
    net = how_net.init_network(**state['net_params']).to(globals['device'])
    net.load_state_dict(state['state_dict'])
    mean, std = net.runtime['mean_std']
    globals["transform"] = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Dispatch the evaluations requested by the configuration
    if evaluation['global_descriptor']['datasets']:
        eval_global(net, evaluation['inference'], globals, **evaluation['global_descriptor'])
    if evaluation['multistep']:
        eval_asmk_multistep(net, evaluation['inference'], evaluation['multistep'], globals,
                            **evaluation['local_descriptor'])
    elif evaluation['local_descriptor']['datasets']:
        eval_asmk(net, evaluation['inference'], globals, **evaluation['local_descriptor'])
def eval_global(net, inference, globals, *, datasets):
    """Evaluate global descriptors"""
    net.eval()
    started = time.time()
    logger = globals["logger"]
    logger.info("Starting global evaluation")

    results = {}
    for dataset in datasets:
        images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root=globals['root_path'])
        logger.info(f"Evaluating {dataset}")

        # Database descriptors
        with logging.LoggingStopwatch("extracting database images", logger.info, logger.debug):
            dset = ImagesFromList(root='', images=images, imsize=inference['image_size'],
                                  bbxs=None, transform=globals['transform'])
            vecs = how_net.extract_vectors(net, dset, globals["device"], scales=inference['scales'])
        # Query descriptors (queries may come with bounding boxes)
        with logging.LoggingStopwatch("extracting query images", logger.info, logger.debug):
            qdset = ImagesFromList(root='', images=qimages, imsize=inference['image_size'],
                                   bbxs=bbxs, transform=globals['transform'])
            qvecs = how_net.extract_vectors(net, qdset, globals["device"], scales=inference['scales'])

        # Rank the database by descending dot-product similarity and score
        similarity = np.dot(vecs.numpy(), qvecs.numpy().T)
        ranks = np.argsort(-similarity, axis=0)
        results[dataset] = score_helpers.compute_map_and_log(dataset, ranks, gnd, logger=logger)

    logger.info(f"Finished global evaluation in {int(time.time()-started) // 60} min")
    return results
def eval_asmk(net, inference, globals, *, datasets, codebook_training, asmk):
    """Evaluate local descriptors with ASMK"""
    net.eval()
    started = time.time()
    logger = globals["logger"]
    logger.info("Starting asmk evaluation")

    # Train the codebook once up-front, then index and query each dataset
    asmk = asmk_method.ASMKMethod.initialize_untrained(asmk)
    asmk = asmk_train_codebook(net, inference, globals, logger,
                               codebook_training=codebook_training, asmk=asmk, cache_path=None)

    results = {}
    for dataset in datasets:
        dataset_name = dataset['name'] if not isinstance(dataset, str) else dataset
        images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root=globals['root_path'])
        logger.info(f"Evaluating '{dataset_name}'")

        asmk_dataset = asmk_index_database(net, inference, globals, logger, asmk=asmk, images=images)
        asmk_query_ivf(net, inference, globals, logger, dataset=dataset, asmk_dataset=asmk_dataset,
                       qimages=qimages, bbxs=bbxs, gnd=gnd, results=results,
                       cache_path=globals["exp_path"] / "query_results.pkl")

    logger.info(f"Finished asmk evaluation in {int(time.time()-started) // 60} min")
    return results
def eval_asmk_multistep(net, inference, multistep, globals, *, datasets, codebook_training, asmk):
    """Evaluate local descriptors with ASMK

    Runs a single step of the pipeline, selected by ``multistep['step']``;
    intermediate artefacts are shared between step invocations under
    ``<exp_path>/eval``, so steps can be run separately (e.g. on different
    machines, with optional partitioning of database/query images).
    """
    valid_steps = ["train_codebook", "aggregate_database", "build_ivf", "query_ivf", "aggregate_build_query"]
    assert multistep['step'] in valid_steps, multistep['step']

    net.eval()
    time0 = time.time()
    logger = globals["logger"]
    (globals["exp_path"] / "eval").mkdir(exist_ok=True)
    logger.info(f"Starting asmk evaluation step '{multistep['step']}'")

    # Handle partitioning: multistep['partition'] = (total, index) splits the
    # images into 'total' chunks; this invocation processes chunk 'index'
    partition = {"suffix": "", "norm_start": 0, "norm_end": 1}
    if multistep.get("partition"):
        total, index = multistep['partition']
        # Zero-padded suffix keeps the per-partition pickles lexically sorted
        partition = {"suffix": f":{total}_{str(index).zfill(len(str(total-1)))}",
                     "norm_start": index / total,
                     "norm_end": (index+1) / total}
        if multistep['step'] == "aggregate_database" or multistep['step'] == "query_ivf":
            logger.info(f"Processing partition '{total}_{index}'")

    # Handle distractors (a pre-built ivf the database ivf is initialized from)
    distractors_path = None
    distractors = multistep.get("distractors")
    if distractors:
        distractors_path = globals["exp_path"] / f"eval/{distractors}.ivf.pkl"

    # Train codebook (the only step that runs without datasets; later steps
    # just load the cached codebook)
    asmk = asmk_method.ASMKMethod.initialize_untrained(asmk)
    cdb_path = globals["exp_path"] / "eval/codebook.pkl"
    if multistep['step'] == "train_codebook":
        asmk_train_codebook(net, inference, globals, logger, codebook_training=codebook_training,
                            asmk=asmk, cache_path=cdb_path)
        return None
    asmk = asmk.train_codebook(None, cache_path=cdb_path)

    results = {}
    for dataset in datasets:
        dataset_name = database_name = dataset if isinstance(dataset, str) else dataset['name']
        if distractors and multistep['step'] != "aggregate_database":
            # The ivf combined with distractors is stored under a joint name
            dataset_name = f"{distractors}_{database_name}"
        images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root=globals['root_path'])
        logger.info(f"Processing dataset '{dataset_name}'")

        # Infer database
        if multistep['step'] == "aggregate_database":
            agg_path = globals["exp_path"] / f"eval/{database_name}.agg{partition['suffix']}.pkl"
            asmk_aggregate_database(net, inference, globals, logger, asmk=asmk, images=images,
                                    partition=partition, cache_path=agg_path)
        # Build ivf
        elif multistep['step'] == "build_ivf":
            ivf_path = globals["exp_path"] / f"eval/{dataset_name}.ivf.pkl"
            asmk_build_ivf(globals, logger, asmk=asmk, cache_path=ivf_path, database_name=database_name,
                           distractors=distractors, distractors_path=distractors_path)
        # Query ivf
        elif multistep['step'] == "query_ivf":
            asmk_dataset = asmk.build_ivf(None, None, cache_path=globals["exp_path"] / f"eval/{dataset_name}.ivf.pkl")
            # Only this partition's slice of the queries; imid_offset keeps
            # query image ids global across partitions
            start, end = int(len(qimages)*partition['norm_start']), int(len(qimages)*partition['norm_end'])
            bbxs = bbxs[start:end] if bbxs is not None else None
            results_path = globals["exp_path"] / f"eval/{dataset_name}.results{partition['suffix']}.pkl"
            asmk_query_ivf(net, inference, globals, logger, dataset=dataset, asmk_dataset=asmk_dataset,
                           qimages=qimages[start:end], bbxs=bbxs, gnd=gnd, results=results,
                           cache_path=results_path, imid_offset=start)
        # All 3 dataset steps
        elif multistep['step'] == "aggregate_build_query":
            if multistep.get("partition"):
                raise NotImplementedError("Partitions within step 'aggregate_build_query' are not" \
                        " supported, use separate steps")
            results_path = globals["exp_path"] / "query_results.pkl"
            if gnd is None and results_path.exists():
                logger.debug("Step results already exist")
                continue
            asmk_dataset = asmk_index_database(net, inference, globals, logger, asmk=asmk, images=images,
                                               distractors_path=distractors_path)
            asmk_query_ivf(net, inference, globals, logger, dataset=dataset, asmk_dataset=asmk_dataset,
                           qimages=qimages, bbxs=bbxs, gnd=gnd, results=results, cache_path=results_path)

    logger.info(f"Finished asmk evaluation step '{multistep['step']}' in {int(time.time()-time0) // 60} min")
    return results
| # | |
| # Separate steps | |
| # | |
def asmk_train_codebook(net, inference, globals, logger, *, codebook_training, asmk, cache_path):
    """Asmk evaluation step 'train_codebook'"""
    # Reuse a cached codebook when one was already trained
    if cache_path and cache_path.exists():
        return asmk.train_codebook(None, cache_path=cache_path)

    train_images = data_helpers.load_dataset('train', data_root=globals['root_path'])[0]
    train_images = train_images[:codebook_training['images']]
    dset = ImagesFromList(root='', images=train_images, imsize=inference['image_size'],
                          bbxs=None, transform=globals['transform'])
    extract_opts = {"scales": codebook_training['scales'], "features_num": inference['features_num']}
    descriptors = how_net.extract_vectors_local(net, dset, globals["device"], **extract_opts)[0]
    asmk = asmk.train_codebook(descriptors, cache_path=cache_path)
    logger.info(f"Codebook trained in {asmk.metadata['train_codebook']['train_time']:.1f}s")
    return asmk
def asmk_aggregate_database(net, inference, globals, logger, *, asmk, images, partition, cache_path):
    """Asmk evaluation step 'aggregate_database'"""
    if cache_path.exists():
        logger.debug("Step results already exist")
        return

    codebook = asmk.codebook
    kernel = kern_pkg.ASMKKernel(codebook, **asmk.params['build_ivf']['kernel'])
    # This partition's slice of the database images
    start = int(len(images) * partition['norm_start'])
    end = int(len(images) * partition['norm_end'])
    data_opts = {"imsize": inference['image_size'], "transform": globals['transform']}
    infer_opts = {"scales": inference['scales'], "features_num": inference['features_num']}

    # Extract, quantize and aggregate local descriptors for the slice
    dset = ImagesFromList(root='', images=images[start:end], bbxs=None, **data_opts)
    vecs, imids, *_ = how_net.extract_vectors_local(net, dset, globals["device"], **infer_opts)
    imids += start  # image ids stay global across partitions
    quantized = codebook.quantize(vecs, imids, **asmk.params["build_ivf"]["quantize"])
    aggregated = kernel.aggregate(*quantized, **asmk.params["build_ivf"]["aggregate"])
    payload = dict(zip(["des", "word_ids", "image_ids"], aggregated))
    with cache_path.open("wb") as handle:
        pickle.dump(payload, handle)
def asmk_build_ivf(globals, logger, *, asmk, cache_path, database_name, distractors, distractors_path):
    """Asmk evaluation step 'build_ivf'"""
    if cache_path.exists():
        logger.debug("Step results already exist")
        return asmk.build_ivf(None, None, cache_path=cache_path)

    builder = asmk.create_ivf_builder(cache_path=cache_path)
    # Feed every aggregated partition into the ivf, unless loaded from cache
    if not builder.loaded_from_cache:
        if distractors:
            builder.initialize_with_distractors(distractors_path)
            logger.debug(f"Loaded ivf with distractors '{distractors}'")
        for agg_file in sorted(globals["exp_path"].glob(f"eval/{database_name}.agg*.pkl")):
            with agg_file.open("rb") as handle:
                payload = pickle.load(handle)
            builder.ivf.add(payload['des'], payload['word_ids'], payload['image_ids'])
            logger.info(f"Indexed '{agg_file.name}'")

    asmk_dataset = asmk.add_ivf_builder(builder)
    logger.debug(f"IVF stats: {asmk_dataset.metadata['build_ivf']['ivf_stats']}")
    return asmk_dataset
def asmk_index_database(net, inference, globals, logger, *, asmk, images, distractors_path=None):
    """Asmk evaluation step 'aggregate_database' and 'build_ivf'"""
    data_opts = {"imsize": inference['image_size'], "transform": globals['transform']}
    infer_opts = {"scales": inference['scales'], "features_num": inference['features_num']}

    # Extract database local descriptors and build the ivf in a single pass
    dset = ImagesFromList(root='', images=images, bbxs=None, **data_opts)
    vecs, imids, *_ = how_net.extract_vectors_local(net, dset, globals["device"], **infer_opts)
    asmk_dataset = asmk.build_ivf(vecs, imids, distractors_path=distractors_path)
    logger.info(f"Indexed images in {asmk_dataset.metadata['build_ivf']['index_time']:.2f}s")
    logger.debug(f"IVF stats: {asmk_dataset.metadata['build_ivf']['ivf_stats']}")
    return asmk_dataset
def asmk_query_ivf(net, inference, globals, logger, *, dataset, asmk_dataset, qimages, bbxs, gnd,
                   results, cache_path, imid_offset=0):
    """Asmk evaluation step 'query_ivf'

    Extract local descriptors for the query images, search the ivf and cache
    the raw results; when ground truth is available, also compute the score
    and store it in ``results``.

    :param dataset: Dataset identifier, used as the key into ``results``
    :param asmk_dataset: ASMK object with a built ivf to query
    :param list qimages: Query image paths
    :param bbxs: Optional query bounding boxes (may be None)
    :param gnd: Ground truth for scoring, or None when unavailable
    :param dict results: Output dict updated in-place with the dataset's score
    :param cache_path: Path for pickling raw query results, or None to skip caching
    :param int imid_offset: Offset added to query image ids (used by partitioned runs)
    """
    # Without ground truth, the cached raw results are the final artefact
    if gnd is None and cache_path and cache_path.exists():
        logger.debug("Step results already exist")
        return

    data_opts = {"imsize": inference['image_size'], "transform": globals['transform']}
    infer_opts = {"scales": inference['scales'], "features_num": inference['features_num']}
    # Query vectors
    qdset = ImagesFromList(root='', images=qimages, bbxs=bbxs, **data_opts)
    qvecs, qimids, *_ = how_net.extract_vectors_local(net, qdset, globals["device"], **infer_opts)
    qimids += imid_offset
    metadata, query_ids, ranks, scores = asmk_dataset.query_ivf(qvecs, qimids)
    logger.debug(f"Average query time (quant+aggr+search) is {metadata['query_avg_time']:.3f}s")

    # Evaluate
    if gnd is not None:
        results[dataset] = score_helpers.compute_map_and_log(dataset, ranks.T, gnd, logger=logger)
    # Fix: the early-exit guard above already treats cache_path as optional,
    # but the dump previously dereferenced it unconditionally and raised
    # AttributeError when it was None — only cache when a path was given
    if cache_path:
        with cache_path.open("wb") as handle:
            pickle.dump({"metadata": metadata, "query_ids": query_ids, "ranks": ranks,
                         "scores": scores}, handle)
| # | |
| # Helpers | |
| # | |
| def _convert_checkpoint(state): | |
| """Enable loading checkpoints in the old format""" | |
| if "_version" not in state: | |
| # Old checkpoint format | |
| meta = state['meta'] | |
| state['net_params'] = { | |
| "architecture": meta['architecture'], | |
| "pretrained": True, | |
| "skip_layer": meta['skip_layer'], | |
| "dim_reduction": {"dim": meta["dim"]}, | |
| "smoothing": {"kernel_size": meta["feat_pool_k"]}, | |
| "runtime": { | |
| "mean_std": [meta['mean'], meta['std']], | |
| "image_size": 1024, | |
| "features_num": 1000, | |
| "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], | |
| "training_scales": [1], | |
| }, | |
| } | |
| state_dict = state['state_dict'] | |
| state_dict['dim_reduction.weight'] = state_dict.pop("whiten.weight") | |
| state_dict['dim_reduction.bias'] = state_dict.pop("whiten.bias") | |
| state['_version'] = "how/2020" | |
| return state | |