Spaces:
Runtime error
Runtime error
| """ | |
| """ | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from cortex_DIM.functions.misc import log_sum_exp | |
def raise_measure_error(measure):
    """Raise an error reporting that ``measure`` is not a supported measure.

    Args:
        measure: The measure name that was requested.

    Raises:
        NotImplementedError: always; the message lists the supported names.
    """
    known = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1']
    message = f'Measure `{measure}` not supported. Supported: {known}'
    raise NotImplementedError(message)
def get_positive_expectation(p_samples, measure, average=True):
    """Evaluate the positive-sample term of a divergence / difference.

    Args:
        p_samples: Scores of the positive samples.
        measure: Name of the measure; one of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2', 'W1'.
        average: When true, reduce the elementwise terms to their mean.

    Returns:
        torch.Tensor: the mean of the terms if ``average`` is true,
        otherwise the elementwise terms.

    Raises:
        NotImplementedError: (via ``raise_measure_error``) if ``measure``
            is not one of the supported names.
    """
    log_2 = math.log(2.)
    # Ordered table: names are compared with ``==`` in the same order an
    # if/elif chain would use, so non-string ``measure`` values still fall
    # through to ``raise_measure_error``.
    candidates = (
        ('GAN', lambda p: -F.softplus(-p)),
        ('JSD', lambda p: log_2 - F.softplus(-p)),
        ('X2', lambda p: p ** 2),
        ('KL', lambda p: p + 1.),
        ('RKL', lambda p: -torch.exp(-p)),
        ('DV', lambda p: p),
        ('H2', lambda p: 1. - torch.exp(-p)),
        ('W1', lambda p: p),
    )
    for name, transform in candidates:
        if measure == name:
            break
    else:
        raise_measure_error(measure)

    terms = transform(p_samples)
    if not average:
        return terms
    return terms.mean()
def get_negative_expectation(q_samples, measure, average=True):
    """Computes the negative part of a divergence / difference.

    Args:
        q_samples: Negative samples.
        measure: Measure to compute for; one of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2', 'W1'.
        average: Average the result over samples.

    Returns:
        torch.Tensor: the mean of the terms if ``average`` is true,
        otherwise the unreduced terms.

    Raises:
        NotImplementedError: (via ``raise_measure_error``) if ``measure``
            is not supported.
    """
    log_2 = math.log(2.)
    if measure == 'GAN':
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Eq = F.softplus(-q_samples) + q_samples - log_2
    elif measure == 'X2':
        # |q| via abs() rather than sqrt(q ** 2): the forward value is the
        # same, but sqrt's derivative at 0 is infinite, which made backprop
        # produce NaN gradients whenever a sample was exactly 0.
        Eq = -0.5 * ((q_samples.abs() + 1.) ** 2)
    elif measure == 'KL':
        Eq = torch.exp(q_samples)
    elif measure == 'RKL':
        Eq = q_samples - 1.
    elif measure == 'DV':
        # log-mean-exp over dim 0: logsumexp minus log of the dim-0 size.
        Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
    elif measure == 'H2':
        Eq = torch.exp(q_samples) - 1.
    elif measure == 'W1':
        Eq = q_samples
    else:
        raise_measure_error(measure)

    if average:
        return Eq.mean()
    else:
        return Eq