File size: 540 Bytes
96b9702
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from transformers import WhisperTimeStampLogitsProcessor


class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
    """Whisper timestamp logits processor that keeps EOS available at the first step.

    All processing is delegated to the parent ``WhisperTimeStampLogitsProcessor``
    from ``transformers``. Afterwards, at the first decoding position
    (``input_ids.shape[1] == self.begin_index``), the end-of-sequence logit is
    put back to its original, unprocessed value. The intent is to let decoding
    exit early, e.g. from a segment of silence.

    NOTE(review): this presumes the parent masks EOS at ``begin_index``.
    Otherwise the override is a no-op. Confirm against the installed
    transformers version.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the parent's timestamp rules, then restore the EOS logit on the first step.

        Args:
            input_ids: Token ids generated so far. Its second dimension is compared
                with ``self.begin_index`` to detect the first decoding step.
            scores: Next-token logits. Presumably shaped (batch, vocab_size);
                TODO confirm.

        Returns:
            The logits returned by the parent processor. On the first decoding step,
            the ``self.eos_token_id`` column holds the logit it had *before* the
            parent ran.
        """
        is_first_step = input_ids.shape[1] == self.begin_index

        # Snapshot the EOS logits *before* delegating. Some transformers releases
        # modify ``scores`` in place inside the parent ``__call__``. Reading
        # ``scores[:, eos]`` afterwards would then return the masked value, and
        # the restoration below would silently do nothing.
        original_eos = scores[:, self.eos_token_id].clone() if is_first_step else None

        scores_processed = super().__call__(input_ids, scores)

        # Enable early exit from silence via the EOS token.
        if original_eos is not None:
            scores_processed[:, self.eos_token_id] = original_eos

        return scores_processed