Spaces:
Runtime error
Runtime error
| from transformers import LogitsProcessor, LogitsProcessorList | |
| from transformers.pytorch_utils import isin_mps_friendly | |
| import math | |
| import torch | |
| class ParlerTTSLogitsProcessor(LogitsProcessor): | |
| r"""This processor ensures that the delayed pattern mask constraints are respected. | |
| <Tip warning={true}> | |
| This logits processor is exclusively compatible with Parler-TTS. | |
| See the model documentation for examples. | |
| </Tip> | |
| Args: | |
| eos_token_id (`Union[int, List[int], torch.Tensor]`): | |
| The id(s) of the *end-of-sequence* token. | |
| min_eos_p (`float`, *optional*): | |
| Minimum end of speech threshold. | |
| """ | |
| def __init__(self, eos_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"): | |
| if not isinstance(eos_token_id, torch.Tensor): | |
| if isinstance(eos_token_id, int): | |
| eos_token_id = [eos_token_id] | |
| eos_token_id = torch.tensor(eos_token_id, device=device) | |
| self.eos_token_id = eos_token_id | |
| self.batch_size = batch_size | |
| if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): | |
| raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") | |
| self.num_codebooks = num_codebooks | |
| self.device = device | |
| self.codebook_idx = torch.arange(self.batch_size*self.num_codebooks, device=self.device) | |
| self.first_codebooks_unfinished = torch.arange(batch_size, device=device)*num_codebooks | |
| max_codebooks = torch.arange(self.batch_size, device=self.device)*self.num_codebooks + self.num_codebooks -1 | |
| self.max_codebooks = max_codebooks | |
| def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | |
| is_eos = isin_mps_friendly(input_ids, self.eos_token_id).sum(1) | |
| self.first_codebooks_unfinished = torch.where((is_eos[self.first_codebooks_unfinished]>0) & (self.first_codebooks_unfinished<self.max_codebooks), self.first_codebooks_unfinished+1, self.first_codebooks_unfinished) | |
| # every codebook higher than the first one unfinished will never be eos | |
| eos_token_mask = self.codebook_idx > self.first_codebooks_unfinished.repeat_interleave(self.num_codebooks) | |
| scores[eos_token_mask, self.eos_token_id] = -math.inf | |
| return scores |