Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import warnings | |
| from typing import Callable, Tuple | |
| import torch | |
| from torch import Tensor | |
def _divide_and_aggregate_metrics(
    inputs: Tuple[Tensor, ...],
    n_perturb_samples: int,
    metric_func: Callable,
    agg_func: Callable = torch.add,
    max_examples_per_batch: int = None,
) -> Tensor:
    r"""
    Evaluate `metric_func` over `n_perturb_samples` in memory-bounded chunks and
    combine the per-chunk results with `agg_func`.

    `metric_func` is called with the number of perturbation samples it should
    process in the current chunk. The chunk size is chosen so that
    `chunk size * input batch size` does not exceed `max_examples_per_batch`,
    clamped to the range `[1, n_perturb_samples]`. The chunk sizes passed to
    `metric_func` add up to `n_perturb_samples`; only the last chunk may be
    smaller than the others.

    Args:
        inputs (tuple): The original inputs formatted in a tuple. Only the size of
                the first dimension of `inputs[0]` is read here and it is
                treated as the input batch size.
        n_perturb_samples (int): The total number of perturbation samples per
                input example that have to be processed.
        metric_func (callable): Called as `metric_func(n)` where `n` is the number
                of perturbation samples for the current chunk. Its return
                value is fed into `agg_func`.
        agg_func (callable, optional): Binary function used to fold the results
                of consecutive `metric_func` calls, called as
                `agg_func(accumulated, new)`.
                Default: `torch.add`
        max_examples_per_batch (int, optional): The maximum number of allowed
                examples (`chunk size * input batch size`) per `metric_func`
                call. If `None`, or if the input batch is empty, all
                `n_perturb_samples` are processed in a single call. A warning
                is issued when the value is smaller than the input batch
                size or larger than `input batch size * n_perturb_samples`;
                the chunk size is then clamped instead of failing.
                Default: `None`

    Returns:
        metric (tensor): The result of the first `metric_func` call folded with
                the results of all subsequent calls through `agg_func`.
    """
    bsz = inputs[0].size(0)

    if max_examples_per_batch is None or bsz == 0:
        # No limit requested, or an empty batch for which any chunk size yields
        # zero examples (and `// bsz` would divide by zero): use a single chunk.
        max_inps_per_batch = n_perturb_samples
    else:
        samples_that_fit = max_examples_per_batch // bsz
        if samples_that_fit < 1 or samples_that_fit > n_perturb_samples:
            warnings.warn(
                (
                    "`max_examples_per_batch` must be at least equal to the"
                    " input batch size and at most to "
                    "`input batch size` * `n_perturb_samples`. "
                    "`max_examples_per_batch` is: {} and the input batch size"
                    " is: {}. "
                    "This is necessary because we require that each sub-batch"
                    " that is used "
                    "to compute the metrics, contains at least an instance of "
                    "the original example and doesn't exceed the number of "
                    "expanded n_perturb_samples."
                ).format(max_examples_per_batch, bsz)
            )
        # Clamp rather than fail: at least one sample per chunk, never more
        # than the total number of samples requested.
        max_inps_per_batch = min(max(samples_that_fit, 1), n_perturb_samples)

    # The first chunk seeds the accumulator so that `agg_func` needs no identity.
    metrics_sum = metric_func(max_inps_per_batch)
    n_remaining = n_perturb_samples - max_inps_per_batch

    while n_remaining > 0:
        # Full-size chunks until fewer than a chunk remains; the last one is
        # shrunk so that exactly `n_perturb_samples` are processed overall.
        n_current = min(max_inps_per_batch, n_remaining)
        metrics_sum = agg_func(metrics_sum, metric_func(n_current))
        n_remaining -= n_current

    return metrics_sum