import math

import torch
import transformers
from transformers import LogitsWarper
from transformers.generation.logits_process import (
    LogitNormalization,
    LogitsProcessor,
    LogitsProcessorList,
    TemperatureLogitsWarper
)

# The last scores seen by SpyLogitsWarper, i.e. the logits of the most recent
# generation step after all warpers have been applied.
global_scores = None


class TailFreeLogitsWarper(LogitsWarper):
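    """
    Tail-Free Sampling (TFS): sorts the token probabilities, takes the absolute
    second derivative of the sorted distribution, and removes the tail of tokens
    whose cumulative normalized second derivative exceeds `tfs`.
    """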
    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        tfs = float(tfs)
        if tfs < 0 or tfs > 1.0:
            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
        self.tfs = tfs
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Compute the normalized CDF of the absolute second derivative of the sorted distribution
        d2 = probs.diff().diff().abs()
        normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
        normalized_d2_cdf = normalized_d2.cumsum(dim=-1)

        # Remove tokens whose CDF value is above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = normalized_d2_cdf > self.tfs

        # The double diff() shortens the vector by two, so pad the mask with a leading
        # "keep" and a trailing "remove" to centre it around the cutoff
        sorted_indices_to_remove = torch.cat(
            (
                torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
                sorted_indices_to_remove,
                torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
            ),
            dim=-1,
        )

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class TopALogitsWarper(LogitsWarper):
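    """
    Top-A sampling: removes every token whose probability is below
    top_a * p_max**2, where p_max is the probability of the most likely token.
    """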
    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_a = float(top_a)
        if top_a < 0 or top_a > 1.0:
            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
        self.top_a = top_a
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Remove tokens with probability less than top_a * (max(probs)) ** 2 (tokens with 0 are kept)
        probs_max = probs[..., 0, None]
        sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class MirostatLogitsWarper(LogitsWarper):
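    """
    Mirostat 2.0 sampling: keeps the observed surprise (-log2 probability) of the
    sampled token close to the target `mirostat_tau` by adjusting the truncation
    threshold `mu` with learning rate `mirostat_eta` after every step.
    Only mirostat_mode == 2 is supported.
    """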
    def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if mirostat_mode not in [2]:
            raise ValueError(f"`mirostat_mode` has to be the integer 2, but is {mirostat_mode}")
        self.mirostat_mode = mirostat_mode
        self.mirostat_eta = mirostat_eta
        self.mirostat_tau = mirostat_tau
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep
        self.mu = 2 * self.mirostat_tau
        self.e = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        logits = scores[0]
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        prob_original = torch.softmax(sorted_logits, dim=-1).tolist()

        # Truncate the candidates whose surprise value (-log2 p) exceeds mu
        for i, candidate in enumerate(prob_original):
            if candidate > 0 and -math.log2(candidate) > self.mu:
                if i == 0:
                    sorted_logits = sorted_logits[:1]
                else:
                    sorted_logits = sorted_logits[:i]
                break

        # Normalize the probabilities of the remaining candidates and sample one token
        prob_topk = torch.softmax(sorted_logits, dim=0)
        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)

        # Adjust mu towards the target surprise using the learning rate and the error
        observed_surprise = -math.log2(prob_topk[prev_i])
        self.e = observed_surprise - self.mirostat_tau
        self.mu -= self.mirostat_eta * self.e

        # Keep only the sampled token; every other token is filtered out
        sorted_indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
        sorted_indices_to_remove[prev_i] = False

        indices_to_remove = sorted_indices_to_remove.unsqueeze(0).scatter(1, sorted_indices.unsqueeze(0), sorted_indices_to_remove.unsqueeze(0))
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class SpyLogitsWarper(LogitsWarper):
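    """
    A no-op warper that stores the latest scores in the module-level
    `global_scores` so they can be inspected after generation.
    """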
    def __init__(self):
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        global global_scores
        global_scores = scores
        return scores


class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
    '''
    Copied from the transformers library, but the penalty is applied only to
    the last `_range` tokens of the input.
    '''

    def __init__(self, penalty: float, _range: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self._range = _range

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        input_ids = input_ids[:, -self._range:]
        score = torch.gather(scores, 1, input_ids)

        # If score < 0, multiply by the penalty to reduce the previous token probability; otherwise divide by it
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores.scatter_(1, input_ids, score)
        return scores


def get_logits_warper_patch(self, generation_config):
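    """
    Replacement for GenerationMixin._get_logits_warper: appends the custom warpers
    (Mirostat, TFS, Top-A, SpyLogitsWarper) to the list returned by the stock
    implementation, keeping LogitNormalization, if present, as the last warper.
    """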
    warpers = self._get_logits_warper_old(generation_config)
    warpers_to_add = LogitsProcessorList()
    min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1

    if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
        warpers_to_add.append(MirostatLogitsWarper(mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep))
        # Mirostat handles truncation itself, so disable every other warper except temperature.
        # Iterate over a copy: removing items from the list being iterated would skip elements.
        for warper in list(warpers):
            if not isinstance(warper, TemperatureLogitsWarper):
                warpers.remove(warper)
    else:
        if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
            warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
        if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
            warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))

    if warpers and isinstance(warpers[-1], LogitNormalization):
        # Keep LogitNormalization as the last warper; slicing returns a plain list,
        # so rebuild a LogitsProcessorList
        warpers = LogitsProcessorList(warpers[:-1] + warpers_to_add + [warpers[-1]])
    else:
        warpers += warpers_to_add

    warpers.append(SpyLogitsWarper())
    return warpers


def get_logits_processor_patch(self, **kwargs):
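    """
    Replacement for GenerationMixin._get_logits_processor: swaps the stock
    RepetitionPenaltyLogitsProcessor for the range-limited variant when
    `repetition_penalty_range` is set.
    """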
    result = self._get_logits_processor_old(**kwargs)
    repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
    repetition_penalty = kwargs['generation_config'].repetition_penalty

    if repetition_penalty_range > 0:
        for i in range(len(result)):
            if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)

    return result


def generation_config_init_patch(self, **kwargs):
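    """
    Replacement for GenerationConfig.__init__: registers the extra sampling
    parameters (tfs, top_a, mirostat_*, repetition_penalty_range) as attributes
    with default values.
    """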
    self.__init___old(**kwargs)
    self.tfs = kwargs.pop("tfs", 1.0)
    self.top_a = kwargs.pop("top_a", 0.0)
    self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
    self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
    self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
    self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)


def hijack_samplers():
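    """
    Monkey-patches GenerationMixin and GenerationConfig so that the custom
    warpers and parameters defined above are used by model.generate(). The
    original methods are saved under an *_old suffix and called by the patches.
    """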
    transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
    transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch

    transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
    transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch

    transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
    transformers.GenerationConfig.__init__ = generation_config_init_patch


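# Usage sketch (illustrative, not part of the original module): the model name and
# parameter values below are assumptions. Call hijack_samplers() once, before the
# GenerationConfig is built and generate() is called, so the extra parameters are
# recognized.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
#
#     hijack_samplers()
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     inputs = tokenizer("Hello", return_tensors="pt")
#     gen_config = GenerationConfig(do_sample=True, max_new_tokens=50, tfs=0.95,
#                                   top_a=0.2, repetition_penalty=1.15,
#                                   repetition_penalty_range=1024)
#     outputs = model.generate(**inputs, generation_config=gen_config)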