File size: 540 Bytes
96b9702 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch
from transformers import WhisperTimeStampLogitsProcessor
class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
    """Whisper timestamp logits processor that lets EOS through at the first step.

    Delegates all timestamp-rule processing to the parent class. Then, on the
    decoding step where ``input_ids.shape[-1] == self.begin_index``, it puts the
    EOS logit back to its pre-processing value. This lets generation terminate
    immediately (e.g. on silent audio) instead of being forced to emit a token.

    NOTE(review): this relies on the parent suppressing EOS at ``begin_index``
    (in current transformers it masks every id below ``timestamp_begin`` there);
    confirm against the installed transformers version.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the parent's timestamp rules, then re-enable EOS on the first step.

        Args:
            input_ids: Token ids generated so far; the last dimension is the
                current sequence length.
            scores: Next-token logits; column ``self.eos_token_id`` is the EOS logit.

        Returns:
            The processed logits tensor returned by the parent. On the first
            decoding step, its EOS column is overwritten with the original logits.
        """
        at_first_step = input_ids.shape[-1] == self.begin_index

        # Snapshot the raw EOS logits *before* delegating. Some transformers
        # versions mask ``scores`` in place and return the same tensor, so
        # reading ``scores[:, eos]`` afterwards would yield the already-masked
        # value and the override below would be a silent no-op.
        original_eos_scores = scores[:, self.eos_token_id].clone() if at_first_step else None

        scores_processed = super().__call__(input_ids, scores)

        if original_eos_scores is not None:
            # Enable early exit from silence via the EOS token.
            scores_processed[:, self.eos_token_id] = original_eos_scores
        return scores_processed
|