| | import datasets |
| | import evaluate |
| | from typing import List |
| | import torch |
| |
|
| |
|
| | _DESCRIPTION = """ |
| | Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere. |
| | (https://github.com/ssnl/align_uniform) |
| | """ |
| |
|
| | _KWARGS_DESCRIPTION = """ |
| | Args: |
| | xs (`list` of a list of `int`): a group of embeddings |
| | ys (`list` of `int`): the other group of embeddings paired with the ys |
| | |
| | Returns: |
| | "align_loss": float(align_loss_val), |
| | "x_unif_loss": float(x_unif_loss_v), |
| | "y_unif_loss": float(y_unif_loss_v), |
| | "unif_loss": float(unif_loss) |
| | |
| | |
| | Examples: |
| | |
| | Example 1-A simple example |
| | >>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity") |
| | >>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]]) |
| | >>> print(results) |
| | {'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0} |
| | """ |
| |
|
| | _CITATION = """""" |
| |
|
| |
|
def align_loss(x, y, alpha=2):
    """Return the alignment loss between two paired groups of embeddings.

    Row ``i`` of ``x`` is paired with row ``i`` of ``y``. The L2 distance of
    each pair (norm taken over ``dim=1``) is raised to ``alpha``, and the
    results are averaged.

    Args:
        x: tensor of embeddings, rows along dim 0 and features along dim 1
            (the metric class passes a 2-D ``(n, d)`` tensor).
        y: tensor broadcast-compatible with ``x`` holding the paired embeddings.
        alpha: exponent applied to each pair's L2 distance.

    Returns:
        0-dim tensor with the mean powered distance.
    """
    diff = x - y
    pair_dist = diff.norm(p=2, dim=1)
    return torch.mean(pair_dist ** alpha)
| |
|
| |
|
def uniform_loss(x, t=2):
    """Return the uniformity loss of one group of embeddings.

    Computes ``log(mean_{i<j} exp(-t * ||x_i - x_j||_2 ** 2))`` over all
    distinct row pairs. The log-mean-exp is evaluated as
    ``logsumexp(v) - log(len(v))``: exponentiating large negative values
    directly underflows to 0 (the inputs are not normalized, so large
    distances are easy to hit), which would turn the result into ``-inf``.

    Args:
        x: 2-D ``(n, d)`` tensor of embeddings (``torch.pdist`` requires 2-D input).
        t: scale applied to the squared pairwise distances before exponentiation.

    Returns:
        0-dim tensor with the loss, in ``x``'s dtype and device. NaN when ``x``
        has fewer than two rows, because there are no pairs to average over.
    """
    scaled = torch.pdist(x, p=2).pow(2).mul(-t)
    n_pairs = scaled.numel()
    if n_pairs == 0:
        # Same result as the mean of an empty tensor, made explicit.
        return scaled.new_tensor(float("nan"))
    # Stable log(mean(exp(scaled))).
    return torch.logsumexp(scaled, dim=0) - scaled.new_tensor(float(n_pairs)).log()
| |
|
| |
|
def nonneg_uniform_loss(x, t=2):
    """Return the uniformity loss shifted by its Jensen lower bound.

    With ``v = -t * ||x_i - x_j||_2 ** 2`` over all distinct row pairs, Jensen's
    inequality gives ``log(mean(exp(v))) >= mean(v)``. This returns
    ``log(mean(exp(v))) - mean(v)``, which is non-negative up to floating-point
    rounding. The log-mean-exp is computed with ``torch.logsumexp`` so that
    large squared distances do not underflow ``exp`` to 0, which would give
    ``-inf``.

    Args:
        x: 2-D ``(n, d)`` tensor of embeddings (``torch.pdist`` requires 2-D input).
        t: scale applied to the squared pairwise distances.

    Returns:
        0-dim tensor in ``x``'s dtype and device; NaN when ``x`` has fewer than
        two rows (no pairs).
    """
    sq_dists = torch.pdist(x, p=2).pow(2)
    n_pairs = sq_dists.numel()
    if n_pairs == 0:
        # Same result as averaging an empty tensor, made explicit.
        return sq_dists.new_tensor(float("nan"))
    scaled = sq_dists.mul(-t)
    # Stable log(mean(exp(scaled))).
    original = torch.logsumexp(scaled, dim=0) - sq_dists.new_tensor(float(n_pairs)).log()
    boundary = -t * sq_dists.mean()
    return original - boundary
| |
|
| |
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AlignmentandUniformity(evaluate.Metric):
    """Alignment and uniformity losses for two paired groups of embeddings.

    Args:
        align_alpha: exponent applied to each pair's L2 distance in ``align_loss``.
        unif_t: scale applied to squared pairwise distances in the uniformity losses.
        *args, **kwargs: forwarded unchanged to ``evaluate.Metric.__init__``.
    """

    def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.align_alpha = align_alpha
        self.unif_t = unif_t

    def _info(self):
        """Describe the metric: docs, citation and the expected input features."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    # One embedding (a variable-length float vector) per example.
                    "xs": datasets.Sequence(datasets.Value("float32")),
                    "ys": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            codebase_urls=["https://github.com/ssnl/align_uniform"],
            reference_urls=[],
        )

    @staticmethod
    def _as_float_tensor(name, values):
        """Convert one input group (tensor or list of lists) to a float32 tensor.

        Raises:
            NotImplementedError: if ``values`` is neither a ``torch.Tensor`` nor a list.
        """
        if isinstance(values, (torch.Tensor, list)):
            return torch.as_tensor(values, dtype=torch.float32)
        raise NotImplementedError(
            f"`{name}` must be a torch.Tensor or a list of lists of floats, "
            f"got {type(values).__name__}"
        )

    def _compute(self, xs: List[List], ys: List[List]):
        """Compute all alignment/uniformity values for paired embeddings.

        Args:
            xs: one group of embeddings (list of float lists or a tensor).
            ys: the other group, paired row by row with ``xs``.

        Returns:
            dict mapping each metric name to a Python float.

        Raises:
            NotImplementedError: if ``xs`` or ``ys`` has an unsupported type.
        """
        # Each input is validated on its own; previously the `xs` check
        # mistakenly tested `ys`.
        xs = self._as_float_tensor("xs", xs)
        ys = self._as_float_tensor("ys", ys)

        align_loss_val = align_loss(xs, ys, self.align_alpha)
        x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
        y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
        unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2

        nn_x_unif_loss_v = nonneg_uniform_loss(xs, t=self.unif_t)
        nn_y_unif_loss_v = nonneg_uniform_loss(ys, t=self.unif_t)
        nn_unif_loss = (nn_x_unif_loss_v + nn_y_unif_loss_v) / 2

        return {
            "align_loss": float(align_loss_val),
            "x_unif_loss": float(x_unif_loss_v),
            "y_unif_loss": float(y_unif_loss_v),
            "unif_loss": float(unif_loss),
            "nonneg_x_unif_loss": float(nn_x_unif_loss_v),
            "nonneg_y_unif_loss": float(nn_y_unif_loss_v),
            "nonneg_unif_loss": float(nn_unif_loss),
        }
| |
|