lab1-cvlface-code / cvlface /research /recognition /code /run_v1 /evaluations /tinyface_evaluator.py
| from datasets import Dataset | |
| import torch | |
| from functools import partial | |
| from .base_evaluator import BaseEvaluator | |
| from tqdm import tqdm | |
| import os | |
| from .tinyface.evaluate import evaluate | |
| def preprocess_transform(examples, image_transforms): | |
| images = [image.convert("RGB") for image in examples['image']] | |
| images = [image_transforms(image) for image in images] | |
| examples["pixel_values"] = images | |
| return examples | |
| def collate_fn(examples): | |
| pixel_values = torch.stack([example["pixel_values"] for example in examples]) | |
| pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() | |
| indexes = torch.tensor([example["index"] for example in examples], dtype=torch.int) | |
| image_paths = [example["path"] for example in examples] | |
| return { | |
| "pixel_values": pixel_values, | |
| "index": indexes, | |
| "image_paths": image_paths | |
| } | |
| class TinyFaceEvaluator(BaseEvaluator): | |
| def __init__(self, name, data_path, transform, fabric, batch_size, num_workers): | |
| super().__init__(name, fabric, batch_size) | |
| self.name = name | |
| self.data_path = data_path | |
| dataset = Dataset.load_from_disk(data_path) | |
| preprocess = partial(preprocess_transform, image_transforms=transform) | |
| dataset = dataset.with_transform(preprocess) | |
| self.dataloader = fabric.setup_dataloader_from_dataset(dataset, | |
| is_train=False, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| collate_fn=collate_fn) | |
| self.meta = torch.load(os.path.join(data_path, 'metadata.pt')) | |
| def integrity_check(self, eval_color_space, pipeline_color_space): | |
| assert eval_color_space == pipeline_color_space | |
| def evaluate(self, pipeline, epoch=0, step=0, n_images_seen=0): | |
| pipeline.eval() | |
| collection = self.extract(pipeline) | |
| collection_flip = self.extract(pipeline, flip_images=True) | |
| if self.fabric.local_rank == 0: | |
| result = self.compute_metric(collection, collection_flip) | |
| self.log(result, epoch, step, n_images_seen) | |
| else: | |
| result = {} | |
| return result | |
| def extract(self, pipeline, flip_images=False): | |
| all_features = [] | |
| all_index = [] | |
| all_image_paths = [] | |
| for batch_idx, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader), desc='TinyFace Feature', | |
| disable=self.fabric.local_rank != 0): | |
| batch = self.complete_batch(batch) # needed for last batch to be gather compatible | |
| if self.is_debug_run(): | |
| if batch_idx > 10: | |
| break | |
| images = batch['pixel_values'] | |
| index = batch['index'] | |
| image_paths = batch['image_paths'] | |
| if flip_images: | |
| images = torch.flip(images, dims=[3]) | |
| features = pipeline(images) | |
| all_features.append(features.cpu().detach()) | |
| all_index.append(index.cpu().detach()) | |
| all_image_paths.extend(image_paths) | |
| # aggregate across all gpus | |
| per_gpu_collection = {"index": torch.cat(all_index, dim=0), | |
| 'features': torch.cat(all_features, dim=0), | |
| 'image_paths': all_image_paths} | |
| # cpu based gathering just in case we have a lot of data | |
| collection = self.gather_collection(method='cpu', per_gpu_collection=per_gpu_collection) | |
| return collection | |
| def compute_metric(self, collection, collection_flip): | |
| if self.is_debug_run(): | |
| print('Debug run, skipping metric computation') | |
| ranks = [1, 5, 20] | |
| return {k: 0.0 for k in ['rank-{}'.format(r) for r in ranks]} | |
| embeddings = (collection['features'] + collection_flip['features']).numpy() | |
| result = evaluate( | |
| all_features=embeddings, | |
| image_paths=collection['image_paths'], | |
| meta=self.meta, | |
| ) | |
| return result | |