| """ |
| Adopted from https://github.com/cvpr2022-stylegan-v/stylegan-v |
| Verified to be the same as tf version by https://github.com/universome/fvd-comparison |
| """ |
|
|
| import io |
| import re |
| import requests |
| import html |
| import hashlib |
| import urllib |
| import urllib.request |
| from typing import Any, List, Tuple, Union, Dict |
| import scipy |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
|
|
def open_url(
    url: str,
    num_attempts: int = 10,
    verbose: bool = True,
    return_filename: bool = False,
) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: A local path (no ``scheme://`` prefix), a ``file://`` URL, or a
            remote URL that is fetched with ``requests``.
        num_attempts: Maximum number of download attempts for remote URLs.
            Must be at least 1.
        verbose: When true, print download progress to stdout.
        return_filename: When true, return the local filename instead of an
            open file object. Only supported for local paths and ``file://``
            URLs, because remote downloads are kept in memory.

    Returns:
        For local paths and ``file://`` URLs: the filename (if
        ``return_filename``) or a file object opened in ``"rb"`` mode.
        For remote URLs: an ``io.BytesIO`` holding the downloaded bytes.

    Raises:
        AssertionError: If ``num_attempts < 1``, or if ``return_filename`` is
            requested for a remote URL.
        Exception: The error of the last attempt is re-raised once all
            download attempts have failed.
    """
    # Kept as assertions so callers observe the same exception type as before.
    assert num_attempts >= 1, "num_attempts must be at least 1"

    # Doesn't look like a URL scheme, so interpret it as a local filename.
    if not re.match("^[a-z]+://", url):
        return url if return_filename else open(url, "rb")

    # file:// URLs. On Windows "file:///c:/foo.txt" parses to the path
    # "/c:/foo.txt"; drop the leading slash so the drive letter comes first.
    if url.startswith("file://"):
        filename = urllib.parse.urlparse(url).path
        if re.match(r"^/[a-zA-Z]:", filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    # Remote data only ever lives in memory, so there is no filename to hand
    # back. Fail before spending time on a download that cannot be returned.
    assert not return_filename, "return_filename is not supported for remote URLs"

    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive interstitial page
                    # rather than the requested file.
                    if len(res.content) < 8192:
                        # Decode leniently: a small *binary* payload is a
                        # legitimate download and must not fail here.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [
                                html.unescape(link)
                                for link in content_str.split('"')
                                if "export=download" in link
                            ]
                            if len(links) == 1:
                                # Follow the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError(
                                "Google Drive download quota exceeded -- please try again later"
                            )

                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # KeyboardInterrupt / SystemExit are not Exception subclasses,
                # so they propagate immediately instead of being retried.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Return the downloaded data as an in-memory file object.
    return io.BytesIO(url_data)
|
|
|
|
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
    """Return the Frechet distance between two sets of feature vectors.

    Each feature set is summarised by its mean and covariance (via
    ``compute_stats``); the result is
    ``||mu_f - mu_r||^2 + Tr(S_f + S_r - 2 * sqrt(S_f @ S_r))``.

    Args:
        feats_fake: Features of the generated samples, one sample per row.
        feats_real: Features of the reference samples, one sample per row.

    Returns:
        The distance as a Python float (the real part is taken, since the
        matrix square root may carry a small imaginary component).
    """
    # `import scipy` alone does not guarantee that the `linalg` submodule is
    # loaded on every SciPy version, so import it explicitly here.
    import scipy.linalg

    mu_gen, sigma_gen = compute_stats(feats_fake)
    mu_real, sigma_real = compute_stats(feats_real)

    mean_term = np.square(mu_gen - mu_real).sum()
    cov_product = np.dot(sigma_gen, sigma_real)
    try:
        # disp=False returns (sqrtm, error_estimate) instead of printing.
        cov_sqrt, _ = scipy.linalg.sqrtm(cov_product, disp=False)
    except TypeError:
        # Newer SciPy releases deprecate the `disp` keyword; once it is gone
        # the function returns the square root directly.
        cov_sqrt = scipy.linalg.sqrtm(cov_product)

    distance = np.real(mean_term + np.trace(sigma_gen + sigma_real - cov_sqrt * 2))

    return float(distance)
|
|
|
|
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-feature mean and the feature covariance of ``feats``.

    Args:
        feats: Feature array with one sample per row (samples along axis 0).

    Returns:
        ``(mu, sigma)``: the mean over axis 0 and the covariance computed by
        ``np.cov(feats, rowvar=False)``.

    Raises:
        ValueError: If fewer than two samples are given, since a covariance
            cannot be estimated from them.
    """
    feats = np.asarray(feats)
    # With a single row, np.cov(rowvar=False) does not produce a usable
    # covariance matrix (depending on the NumPy version it yields a scalar or
    # NaNs), which only surfaces later as an obscure linalg failure.
    if feats.ndim >= 1 and feats.shape[0] < 2:
        raise ValueError(
            "At least 2 samples are required to estimate a covariance, got %d"
            % feats.shape[0]
        )

    mu = feats.mean(axis=0)
    sigma = np.cov(feats, rowvar=False)

    return mu, sigma
|
|
|
|
class FrechetVideoDistance(nn.Module):
    """Frechet Video Distance between generated and reference videos.

    Features are extracted with a TorchScript detector (the URL points at an
    ``i3d_torchscript.pt`` file) that is downloaded on construction, and the
    distance is computed from those features with ``compute_fvd``.
    """

    def __init__(self):
        super().__init__()
        detector_url = (
            "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
        )
        # Keyword arguments forwarded verbatim to the detector on every call.
        self.detector_kwargs = dict(rescale=False, resize=True, return_features=True)
        with open_url(detector_url, verbose=False) as f:
            self.detector = torch.jit.load(f).eval()

    @staticmethod
    def _validate_videos(videos: torch.Tensor, name: str) -> None:
        """Raise ValueError unless ``videos`` is 5-D with >= 2 frames and >= 2 videos."""
        if videos.dim() != 5:
            raise ValueError(
                f"{name} must have shape (frame, batch, channel, height, width), "
                f"got {tuple(videos.shape)}"
            )
        n_frames, batch_size = videos.shape[0], videos.shape[1]
        if n_frames < 2:
            raise ValueError("Video must have more than 1 frame for FVD")
        if batch_size < 2:
            # The feature covariance cannot be estimated from a single video.
            raise ValueError(
                f"{name} must contain at least 2 videos for FVD, got {batch_size}"
            )

    @torch.no_grad()
    def compute(self, videos_fake: torch.Tensor, videos_real: torch.Tensor) -> float:
        """
        Compute the FVD between a batch of predicted and a batch of real videos.

        :param videos_fake: predicted video tensor of shape (frame, batch, channel, height, width)
        :param videos_real: ground-truth observation tensor of shape (frame, batch, channel, height, width)
        :return: the Frechet distance between the detector features of both batches
        :raises ValueError: if either tensor is not 5-D, has fewer than 2 frames,
            or holds fewer than 2 videos
        """
        self._validate_videos(videos_fake, "videos_fake")
        self._validate_videos(videos_real, "videos_real")

        # Reorder to (batch, channel, frame, height, width) before the detector call.
        videos_fake = videos_fake.permute(1, 2, 0, 3, 4).contiguous()
        videos_real = videos_real.permute(1, 2, 0, 3, 4).contiguous()

        feats_fake = self.detector(videos_fake, **self.detector_kwargs).cpu().numpy()
        feats_real = self.detector(videos_real, **self.detector_kwargs).cpu().numpy()

        return compute_fvd(feats_fake, feats_real)
|
|