Spaces:
Runtime error
Runtime error
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # This work is licensed under the Creative Commons Attribution-NonCommercial | |
| # 4.0 International License. To view a copy of this license, visit | |
| # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
| # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
| """Linear Separability (LS).""" | |
| from collections import defaultdict | |
| import numpy as np | |
| import sklearn.svm | |
| import tensorflow as tf | |
| import dnnlib.tflib as tflib | |
| from metrics import metric_base | |
| from training import misc | |
| #---------------------------------------------------------------------------- | |
# Attribute classifiers used by the LS metric, in attribute-index order.
# Each entry pairs the pickle's file name (celebahq-classifier-NN-<attribute>.pkl)
# with the URL it is fetched from. The position in this table is the attribute
# index that callers use to select a classifier.
_classifier_files = (
    ('celebahq-classifier-00-male.pkl',                'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX'),
    ('celebahq-classifier-01-smiling.pkl',             'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo'),
    ('celebahq-classifier-02-attractive.pkl',          'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU'),
    ('celebahq-classifier-03-wavy-hair.pkl',           'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o'),
    ('celebahq-classifier-04-young.pkl',               'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV'),
    ('celebahq-classifier-05-5-o-clock-shadow.pkl',    'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L'),
    ('celebahq-classifier-06-arched-eyebrows.pkl',     'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY'),
    ('celebahq-classifier-07-bags-under-eyes.pkl',     'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg'),
    ('celebahq-classifier-08-bald.pkl',                'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7'),
    ('celebahq-classifier-09-bangs.pkl',               'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF'),
    ('celebahq-classifier-10-big-lips.pkl',            'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh'),
    ('celebahq-classifier-11-big-nose.pkl',            'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK'),
    ('celebahq-classifier-12-black-hair.pkl',          'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18'),
    ('celebahq-classifier-13-blond-hair.pkl',          'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO'),
    ('celebahq-classifier-14-blurry.pkl',              'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW'),
    ('celebahq-classifier-15-brown-hair.pkl',          'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0'),
    ('celebahq-classifier-16-bushy-eyebrows.pkl',      'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7'),
    ('celebahq-classifier-17-chubby.pkl',              'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v'),
    ('celebahq-classifier-18-double-chin.pkl',         'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67'),
    ('celebahq-classifier-19-eyeglasses.pkl',          'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI'),
    ('celebahq-classifier-20-goatee.pkl',              'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9'),
    ('celebahq-classifier-21-gray-hair.pkl',           'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-'),
    ('celebahq-classifier-22-heavy-makeup.pkl',        'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep'),
    ('celebahq-classifier-23-high-cheekbones.pkl',     'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow'),
    ('celebahq-classifier-24-mouth-slightly-open.pkl', 'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM'),
    ('celebahq-classifier-25-mustache.pkl',            'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr'),
    ('celebahq-classifier-26-narrow-eyes.pkl',         'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO'),
    ('celebahq-classifier-27-no-beard.pkl',            'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh'),
    ('celebahq-classifier-28-oval-face.pkl',           'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E'),
    ('celebahq-classifier-29-pale-skin.pkl',           'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN'),
    ('celebahq-classifier-30-pointy-nose.pkl',         'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA'),
    ('celebahq-classifier-31-receding-hairline.pkl',   'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W'),
    ('celebahq-classifier-32-rosy-cheeks.pkl',         'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY'),
    ('celebahq-classifier-33-sideburns.pkl',           'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU'),
    ('celebahq-classifier-34-straight-hair.pkl',       'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH'),
    ('celebahq-classifier-35-wearing-earrings.pkl',    'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY'),
    ('celebahq-classifier-36-wearing-hat.pkl',         'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_'),
    ('celebahq-classifier-37-wearing-lipstick.pkl',    'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-'),
    ('celebahq-classifier-38-wearing-necklace.pkl',    'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX'),
    ('celebahq-classifier-39-wearing-necktie.pkl',     'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-'),
)

# Download URLs only, indexed by attribute number.
classifier_urls = [url for _file_name, url in _classifier_files]
| #---------------------------------------------------------------------------- | |
def prob_normalize(p):
    """Scale a 2-D table of non-negative weights so that all entries sum to 1.

    The input is converted to a float32 NumPy array first. The function asserts
    that the array is two-dimensional and divides by the grand total, so a
    table whose entries sum to zero yields non-finite values.
    """
    table = np.asarray(p).astype(np.float32)
    assert table.ndim == 2
    total = np.sum(table)
    return table / total
def mutual_information(p):
    """Return the mutual information I(X;Y), in bits, of a joint table.

    `p` is normalized with prob_normalize; axis 0 indexes X and axis 1
    indexes Y. Cells that are not strictly positive contribute nothing.
    """
    joint = prob_normalize(p)
    marginal_x = np.sum(joint, axis=1)
    marginal_y = np.sum(joint, axis=0)

    total = 0.0
    # ndenumerate walks the table row by row, i.e. in the same order as a
    # nested loop over axis 0 and then axis 1.
    for (ix, iy), cell in np.ndenumerate(joint):
        if not cell > 0.0:
            continue
        independent = marginal_x[ix] * marginal_y[iy]
        total += cell * np.log2(cell / independent)  # log2 => result in bits
    return total
def entropy(p):
    """Return the Shannon entropy, in bits, of a 2-D probability table.

    `p` is normalized with prob_normalize and every strictly positive cell
    contributes -p * log2(p); other cells are skipped.
    """
    table = prob_normalize(p)
    bits = 0.0
    # ravel() yields the cells in row-major order, matching a nested
    # row/column traversal.
    for cell in table.ravel():
        if not cell > 0.0:
            continue
        bits -= cell * np.log2(cell)
    return bits
def conditional_entropy(p):
    """Return H(Y|X) in bits, where X is axis 0 and Y is axis 1 of `p`.

    In other words: how many additional bits are needed to determine the
    position along axis 1 once the position along axis 0 is known. Computed
    as H(Y) - I(X;Y) and clamped at zero.
    """
    joint = prob_normalize(p)
    # Marginalize over X (axis 0) to obtain the 1 x N table for H(Y).
    marginal_y = np.sum(joint, axis=0, keepdims=True)
    h_y = entropy(marginal_y)
    i_xy = mutual_information(joint)
    # The difference can slip just below 0 due to FP inaccuracies; clean those up.
    return max(0.0, h_y - i_xy)
| #---------------------------------------------------------------------------- | |
class LS(metric_base.MetricBase):
    """Linear separability metric.

    For every selected attribute, generated images are labelled by an
    attribute classifier, the least confident samples are discarded, and a
    linear SVM is fitted to predict the label from the 'latents' and from the
    'dlatents' of each sample. The conditional entropy of the classifier label
    given the SVM prediction is summed over attributes, and 2**sum is reported
    per space (suffix '_z' for latents, '_w' for dlatents).
    """

    def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs):
        """Store the metric configuration.

        Args:
            num_samples: Number of generated samples considered per attribute.
            num_keep: Number of most confident samples kept for the SVM fit;
                must not exceed `num_samples`.
            attrib_indices: Indices into `classifier_urls` selecting which
                attribute classifiers to evaluate.
            minibatch_per_gpu: Number of samples generated per GPU per run.
            **kwargs: Forwarded to `metric_base.MetricBase.__init__`.
        """
        assert num_keep <= num_samples
        super().__init__(**kwargs)
        self.num_samples = num_samples
        self.num_keep = num_keep
        self.attrib_indices = attrib_indices
        self.minibatch_per_gpu = minibatch_per_gpu

    def _evaluate(self, Gs, num_gpus):
        """Generate samples, classify them and report separability scores.

        Args:
            Gs: Generator network; it is cloned once per GPU and must expose
                `components.mapping` and `components.synthesis`.
            num_gpus: Number of GPUs to build the evaluation graph on.
        """
        minibatch_size = num_gpus * self.minibatch_per_gpu

        # Construct TensorFlow graph for each GPU.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()

                # Generate images.
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True)
                images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True)

                # Downsample to 256x256. The attribute classifiers were built for 256x256.
                # The reshape splits both spatial axes into (size // factor, factor)
                # and the mean over the two `factor` axes is a box filter.
                if images.shape[2] > 256:
                    factor = images.shape[2] // 256
                    images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
                    images = tf.reduce_mean(images, axis=[3, 5])

                # Run classifier for each attribute.
                # dlatents[:, -1] keeps the last entry along axis 1 of the mapping
                # output -- presumably one vector per sample; confirm against the
                # mapping network's output layout.
                result_dict = dict(latents=latents, dlatents=dlatents[:,-1])
                for attrib_idx in self.attrib_indices:
                    classifier = misc.load_pkl(classifier_urls[attrib_idx])
                    logits = classifier.get_output_for(images, None)
                    # Softmax over [logit, -logit] turns the classifier output
                    # into two complementary class probabilities per sample.
                    predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))
                    result_dict[attrib_idx] = predictions
                result_expr.append(result_dict)

        # Sampling loop. Each run yields one result dict per GPU; the number of
        # rows gathered can exceed num_samples when it is not a multiple of
        # minibatch_size.
        results = []
        for _ in range(0, self.num_samples, minibatch_size):
            results += tflib.run(result_expr)
        results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}

        # Calculate conditional entropy for each attribute.
        conditional_entropies = defaultdict(list)
        for attrib_idx in self.attrib_indices:

            # Prune the least confident samples. Only the first num_samples rows
            # are candidates. A stable sort on the negated confidence keeps ties
            # in their original order.
            confidence = np.max(results[attrib_idx][:self.num_samples], axis=1)
            pruned_indices = np.argsort(-confidence, kind='mergesort')[:self.num_keep]

            # Fit SVM to the remaining samples.
            svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
            for space in ['latents', 'dlatents']:
                svm_inputs = results[space][pruned_indices]
                try:
                    svm = sklearn.svm.LinearSVC()
                    svm.fit(svm_inputs, svm_targets)
                    svm_outputs = svm.predict(svm_inputs)
                except Exception:
                    # Best effort: if the fit fails (e.g. every kept sample has
                    # the same label), assume perfect prediction. Catching
                    # Exception rather than everything lets KeyboardInterrupt
                    # and SystemExit propagate.
                    svm_outputs = svm_targets

                # Calculate conditional entropy from the 2x2 joint frequency
                # table: rows = SVM prediction, columns = classifier label.
                p = [[np.mean((svm_outputs == row) & (svm_targets == col)) for col in (0, 1)] for row in (0, 1)]
                conditional_entropies[space].append(conditional_entropy(p))

        # Calculate separability scores.
        scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}
        self._report_result(scores['latents'], suffix='_z')
        self._report_result(scores['dlatents'], suffix='_w')
| #---------------------------------------------------------------------------- | |