| | import torch |
| | from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper |
| |
|
class TypicalLogitsWarper(BaseTypicalLogitsWarper):
    """Locally typical sampling warper.

    For each row of logits, tokens are ranked by how close their surprisal
    ``-log p`` is to the entropy of that row's distribution. Tokens are kept in
    that order until their cumulative probability reaches ``mass``. Every other
    logit is replaced by ``filter_value``.
    """

    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        """Forward the configuration to the ``transformers`` base class.

        Args:
            mass: Target cumulative probability of the kept (most typical) set.
            filter_value: Value written into logits that are filtered out.
            min_tokens_to_keep: Number of most-typical tokens that are never
                filtered, regardless of ``mass``.
        """
        # NOTE(review): the base class is expected to store these as
        # self.mass / self.filter_value / self.min_tokens_to_keep, which
        # __call__ reads; confirm against the installed transformers version.
        super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask every logit that falls outside the locally typical set.

        Args:
            input_ids: Token ids generated so far. Unused; present to match the
                logits-warper call signature.
            scores: Raw next-token logits. The indexing below (``sum(dim=1)``,
                ``gather(1, ...)``, ``view(-1, 1)``) assumes a 2-D
                ``(batch, vocab)`` tensor.

        Returns:
            A new tensor shaped like ``scores`` in which filtered positions hold
            ``self.filter_value``. ``scores`` itself is not modified.
        """
        # Entropy H = -sum(p * log p) per row. nansum skips the NaN produced by
        # 0 * -inf for tokens whose logit is already -inf.
        log_probs = torch.nn.functional.log_softmax(scores, dim=-1)
        probs = torch.exp(log_probs)
        entropy = -(log_probs * probs).nansum(-1, keepdim=True)

        # Rank tokens by |surprisal - entropy| (most typical first) and
        # accumulate their probability mass in that order.
        shifted_scores = torch.abs((-log_probs) - entropy)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Number of sorted positions still below `mass`, i.e. the index of the
        # first position whose cumulative mass reaches it. If rounding leaves
        # the final cumulative sum just below `mass`, this count equals the
        # vocab size and gather() would index out of bounds. Clamp to the last
        # valid position instead, which keeps every token.
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        last_ind = last_ind.clamp(max=sorted_scores.shape[-1] - 1)
        cutoff = sorted_scores.gather(1, last_ind.view(-1, 1))
        # Strict '>' keeps every token tied with the cutoff score.
        sorted_indices_to_remove = sorted_scores > cutoff
        if self.min_tokens_to_keep > 1:
            # Never remove the `min_tokens_to_keep` most typical tokens.
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False
        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        return scores.masked_fill(indices_to_remove, self.filter_value)
| |
|