import torch
from transformers import WhisperTimeStampLogitsProcessor


class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
    """Whisper timestamp processor that allows EOS on the first decoding step.

    Runs the stock ``WhisperTimeStampLogitsProcessor`` and then, on the step
    where ``input_ids.shape[1] == self.begin_index``, overwrites the EOS
    column of the processed scores with the EOS score the parent received.
    The intent (per the original comment) is to let decoding exit early from
    silence via the EOS token. Presumably the parent suppresses EOS at that
    step; TODO confirm against the installed transformers version.
    """

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply the parent timestamp rules, then re-enable EOS on the begin step.

        Args:
            input_ids: Token ids generated so far; ``shape[1]`` is compared
                against ``self.begin_index`` to detect the first decoding step.
            scores: Next-token scores passed to the parent processor.

        Returns:
            The scores returned by the parent processor, with the EOS column
            restored to its pre-processing value on the begin step.
        """
        at_begin = input_ids.shape[1] == self.begin_index

        # Snapshot the EOS scores *before* delegating. The parent may write
        # into ``scores`` in place and return the same tensor (version
        # dependent); reading ``scores`` afterwards would then yield the
        # already-masked value and make the restore below a silent no-op.
        eos_scores = scores[:, self.eos_token_id].clone() if at_begin else None

        scores_processed = super().__call__(input_ids, scores)

        # Enable early exit from silence via the EOS token.
        if eos_scores is not None:
            scores_processed[:, self.eos_token_id] = eos_scores

        return scores_processed