| | """Evaluator for reconstruction results.""" |
| |
|
| | import warnings |
| |
|
| | from typing import Sequence, Optional, Mapping, Text |
| | import numpy as np |
| | from scipy import linalg |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from .inception import get_inception_model |
| |
|
| |
|
def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor:
    """Builds a sample covariance matrix from running sums.

    The estimate is `(sigma - outer(total, total) / num_examples)` divided by
    `num_examples - 1`.

    Args:
        sigma: A torch.Tensor, the accumulated sum of per-example outer
            products of the features.
        total: A torch.Tensor, the accumulated (1-D) sum of the features.
        num_examples: An integer, how many examples were accumulated.

    Returns:
        A torch.Tensor shaped like `sigma`. When `num_examples` is 0 an
        all-zero tensor is returned. NOTE: `num_examples == 1` divides by
        zero, so the result then contains NaN/inf entries.
    """
    # Nothing accumulated yet: hand back zeros instead of dividing by zero.
    if num_examples == 0:
        return torch.zeros_like(sigma)

    # N * mean * mean^T, expressed through the running sum: total total^T / N.
    mean_outer = torch.outer(total, total) / num_examples
    centered = sigma - mean_outer
    return centered / (num_examples - 1)
| |
|
| |
|
class VQGANEvaluator:
    """Accumulates reconstruction metrics over batches of images.

    Depending on the constructor flags the evaluator tracks Inception Score,
    rFID, codebook usage and codebook entropy. Feed batches through
    `update()`, read the scores with `result()`, and call `reset_metrics()`
    to start a new evaluation.
    """

    def __init__(
        self,
        device,
        enable_rfid: bool = True,
        enable_inception_score: bool = True,
        enable_codebook_usage_measure: bool = False,
        enable_codebook_entropy_measure: bool = False,
        num_codebook_entries: int = 1024
    ):
        """Initializes VQGAN Evaluator.

        The Inception network is only built (via `get_inception_model`) when
        rFID or Inception Score is enabled.

        Args:
            device: The device to use for evaluation.
            enable_rfid: A boolean, whether enabling rFID score.
            enable_inception_score: A boolean, whether enabling Inception Score.
            enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure.
            enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure.
            num_codebook_entries: An integer, the number of codebook entries.
        """
        self._device = device

        self._enable_rfid = enable_rfid
        self._enable_inception_score = enable_inception_score
        self._enable_codebook_usage_measure = enable_codebook_usage_measure
        self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
        self._num_codebook_entries = num_codebook_entries

        # With both Inception-based metrics disabled the feature sizes stay 0,
        # so the accumulators created by reset_metrics() are empty tensors.
        self._inception_model = None
        self._is_num_features = 0
        self._rfid_num_features = 0
        if self._enable_inception_score or self._enable_rfid:
            self._rfid_num_features = 2048
            self._is_num_features = 1008
            self._inception_model = get_inception_model().to(self._device)
            self._inception_model.eval()
            # Epsilons used inside log() for Inception Score and in the rFID
            # fallback trace computation, respectively.
            self._is_eps = 1e-16
            self._rfid_eps = 1e-6

        self.reset_metrics()

    def reset_metrics(self):
        """Resets all metrics."""
        self._num_examples = 0
        self._num_updates = 0

        # Inception Score accumulators: sum of class probabilities and sum of
        # p * log(p), both per class.
        self._is_prob_total = torch.zeros(
            self._is_num_features, dtype=torch.float64, device=self._device
        )
        self._is_total_kl_d = torch.zeros(
            self._is_num_features, dtype=torch.float64, device=self._device
        )
        # rFID accumulators: sum of features ("total") and sum of feature
        # outer products ("sigma") for real and fake images.
        self._rfid_real_sigma = torch.zeros(
            (self._rfid_num_features, self._rfid_num_features),
            dtype=torch.float64, device=self._device
        )
        self._rfid_real_total = torch.zeros(
            self._rfid_num_features, dtype=torch.float64, device=self._device
        )
        self._rfid_fake_sigma = torch.zeros(
            (self._rfid_num_features, self._rfid_num_features),
            dtype=torch.float64, device=self._device
        )
        self._rfid_fake_total = torch.zeros(
            self._rfid_num_features, dtype=torch.float64, device=self._device
        )

        # Codebook accumulators: distinct indices seen, and per-entry counts.
        self._set_of_codebook_indices = set()
        self._codebook_frequencies = torch.zeros((self._num_codebook_entries), dtype=torch.float64, device=self._device)

    def update(
        self,
        real_images: torch.Tensor,
        fake_images: torch.Tensor,
        codebook_indices: Optional[torch.Tensor] = None
    ):
        """Updates the metrics with the given images.

        All validation runs before any accumulator is touched, so a failed
        call leaves the evaluator state unchanged.

        Args:
            real_images: A torch.Tensor, the real images. The first dimension
                is used as the batch size. Images are multiplied by 255 and
                cast to uint8 before the Inception model sees them
                (presumably they are expected in [0, 1] -- confirm with caller).
            fake_images: A torch.Tensor, the fake images, preprocessed the
                same way as `real_images`.
            codebook_indices: A torch.Tensor, the indices of the codebooks for
                each image. Required when a codebook measure is enabled.

        Raises:
            ValueError: If a codebook measure is enabled but
                `codebook_indices` is None.
            ValueError: If rFID is enabled and the pooled Inception features
                of the real and fake images differ in shape.
        """
        needs_codebook = (
            self._enable_codebook_usage_measure
            or self._enable_codebook_entropy_measure
        )
        if needs_codebook and codebook_indices is None:
            raise ValueError(
                "codebook_indices must be provided when a codebook measure is enabled."
            )

        features_fake = None
        pooled_real = None
        pooled_fake = None

        if self._enable_inception_score or self._enable_rfid:
            fake_inception_images = (fake_images * 255).to(torch.uint8)
            features_fake = self._inception_model(fake_inception_images)

        if self._enable_rfid:
            real_inception_images = (real_images * 255).to(torch.uint8)
            features_real = self._inception_model(real_inception_images)
            pooled_real = features_real['2048']
            pooled_fake = features_fake['2048']
            if (pooled_real.shape[0] != pooled_fake.shape[0] or
                    pooled_real.shape[1] != pooled_fake.shape[1]):
                raise ValueError("Number of features should be equal for real and fake.")

        # Everything below mutates state; nothing below is expected to fail on
        # validated inputs.
        self._num_examples += real_images.shape[0]
        self._num_updates += 1

        if self._enable_inception_score:
            inception_probabilities_fake = F.softmax(
                features_fake["logits_unbiased"], dim=-1
            )
            probabilities_sum = torch.sum(
                inception_probabilities_fake, 0, dtype=torch.float64
            )

            log_prob = torch.log(inception_probabilities_fake + self._is_eps)
            if log_prob.dtype != inception_probabilities_fake.dtype:
                log_prob = log_prob.to(inception_probabilities_fake)
            kl_sum = torch.sum(
                inception_probabilities_fake * log_prob, 0, dtype=torch.float64
            )

            self._is_prob_total += probabilities_sum
            self._is_total_kl_d += kl_sum

        if self._enable_rfid:
            # X^T X equals the sum of per-sample outer products; doing it as a
            # single float64 matmul avoids a Python loop over the batch and
            # avoids low-precision rounding in the products.
            real64 = pooled_real.to(torch.float64)
            fake64 = pooled_fake.to(torch.float64)

            self._rfid_real_total += real64.sum(dim=0)
            self._rfid_fake_total += fake64.sum(dim=0)
            self._rfid_real_sigma += real64.t().mm(real64)
            self._rfid_fake_sigma += fake64.t().mm(fake64)

        if needs_codebook:
            # One unique() pass serves both codebook measures.
            entries, counts = torch.unique(
                codebook_indices, sorted=False, return_counts=True
            )
            if self._enable_codebook_usage_measure:
                self._set_of_codebook_indices.update(entries.tolist())
            if self._enable_codebook_entropy_measure:
                self._codebook_frequencies.index_add_(0, entries.int(), counts.double())

    def result(self) -> Mapping[Text, torch.Tensor]:
        """Returns the evaluation result.

        Returns:
            A mapping with one entry per enabled metric: "InceptionScore",
            "rFID" and "CodebookUsage" are Python floats, while
            "CodebookEntropy" is a 0-dim torch.Tensor.

        Raises:
            ValueError: If no examples have been accumulated.
            ValueError: If the matrix square root used for rFID has a
                non-negligible imaginary diagonal.
        """
        eval_score = {}

        if self._num_examples < 1:
            raise ValueError("No examples to evaluate.")

        if self._enable_inception_score:
            # IS = exp( mean_x sum_c p(c|x) * (log p(c|x) - log p(c)) ).
            mean_probs = self._is_prob_total / self._num_examples
            log_mean_probs = torch.log(mean_probs + self._is_eps)
            if log_mean_probs.dtype != self._is_prob_total.dtype:
                log_mean_probs = log_mean_probs.to(self._is_prob_total)
            excess_entropy = self._is_prob_total * log_mean_probs
            avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples

            inception_score = torch.exp(avg_kl_d).item()
            eval_score["InceptionScore"] = inception_score

        if self._enable_rfid:
            # NOTE: get_covariance divides by (num_examples - 1), so a single
            # accumulated example yields non-finite covariances here.
            mu_real = self._rfid_real_total / self._num_examples
            mu_fake = self._rfid_fake_total / self._num_examples
            sigma_real = get_covariance(self._rfid_real_sigma, self._rfid_real_total, self._num_examples)
            sigma_fake = get_covariance(self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples)

            # SciPy's sqrtm works on NumPy arrays, so move everything to CPU.
            mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu()
            sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu()

            diff = mu_real - mu_fake

            covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False)

            # Small imaginary parts are treated as numerical noise and dropped.
            if np.iscomplexobj(covmean):
                if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                    m = np.max(np.abs(covmean.imag))
                    raise ValueError("Imaginary component {}".format(m))
                covmean = covmean.real

            tr_covmean = np.trace(covmean)

            # Fallback when sqrtm produced non-finite entries: use only the
            # diagonals of the two covariance matrices.
            if not np.isfinite(covmean).all():
                tr_covmean = np.sum(np.sqrt((
                    (np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps))
                    / (self._rfid_eps * self._rfid_eps)
                ))

            rfid = float(diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake)
                         - 2 * tr_covmean
                         )
            # Best effort: a non-finite score is reported with a warning
            # rather than an exception.
            if not np.isfinite(rfid):
                warnings.warn("The product of covariance of train and test features is out of bounds.")

            eval_score["rFID"] = rfid

        if self._enable_codebook_usage_measure:
            # Fraction of codebook entries that were selected at least once.
            usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries
            eval_score["CodebookUsage"] = usage

        if self._enable_codebook_entropy_measure:
            # Entropy (in bits) of the empirical index distribution; the
            # 1e-8 keeps log2 finite for unused entries.
            probs = self._codebook_frequencies / self._codebook_frequencies.sum()
            entropy = (-torch.log2(probs + 1e-8) * probs).sum()
            eval_score["CodebookEntropy"] = entropy

        return eval_score
| |
|