Automatic Speech Recognition
Transformers
Safetensors
DiCoW
speech
whisper
multilingual
speaker-diarization
meeting-transcription
BUT-FIT
custom_code
Instructions to use BUT-FIT/DiCoW_v3_2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BUT-FIT/DiCoW_v3_2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="BUT-FIT/DiCoW_v3_2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("BUT-FIT/DiCoW_v3_2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import Optional | |
| import torch | |
| from transformers import WhisperTimeStampLogitsProcessor | |
def remove_fake_elements(inputs, per_group_sizes):
    """Keep only the first ``per_group_sizes[i]`` rows of every group in a group-padded batch.

    ``inputs`` is reshaped to ``(len(per_group_sizes), per_group_sizes.max(), *inputs.shape[1:])``.
    From group ``i`` only the leading ``per_group_sizes[i]`` rows are kept; the
    remaining rows (presumably padding / "fake" entries) are discarded. The kept rows
    are concatenated along dim 0 in group order.

    Args:
        inputs: tensor whose first dimension must equal
            ``len(per_group_sizes) * per_group_sizes.max()`` for the reshape to succeed.
        per_group_sizes: 1-D integer tensor holding the number of rows to keep per group.

    Returns:
        Tensor of shape ``(per_group_sizes.sum(), *inputs.shape[1:])``.
    """
    padded_width = per_group_sizes.max()
    grouped = inputs.view(per_group_sizes.shape[0], padded_width, *inputs.shape[1:])
    kept_rows = [grouped[idx, :size] for idx, size in enumerate(per_group_sizes)]
    return torch.cat(kept_rows, dim=0)
class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
    """Timestamp logits processor with optional bounds on the first emitted timestamp.

    Re-implements both ``__init__`` and ``__call__`` of the parent class. On every
    decoding step ``__call__`` returns a copy of ``scores`` in which:

    * ``<|notimestamps|>`` is always masked;
    * timestamp tokens must appear in pairs, except directly before EOS;
    * a new timestamp may not be smaller than the last emitted one;
    * on the first generated step (``input_ids.shape[1] == begin_index``) every id below
      ``timestamp_begin`` except EOS is masked, and the first timestamp is restricted to
      indices ``[min_initial_timestamp_index, max_initial_timestamp_index]`` (relative to
      ``timestamp_begin``) when those attributes are set on the generation config;
    * if ``_detect_timestamp_from_logprob`` is truthy and the total probability of all
      timestamp tokens exceeds that of the best non-timestamp token, all ids below
      ``timestamp_begin`` are masked.

    Every id ``>= timestamp_begin`` (``no_timestamps_token_id + 1``) is treated as a
    timestamp token.
    """

    def __init__(
        self,
        generate_config,
        begin_index: Optional[int] = None,
        _detect_timestamp_from_logprob: Optional[bool] = None,
    ):  # support for the kwargs
        """Read token ids and timestamp options from ``generate_config``.

        Args:
            generate_config: generation config. It must provide ``no_timestamps_token_id``,
                ``eos_token_id`` and ``bos_token_id``. ``_detect_timestamp_from_logprob``,
                ``forced_decoder_ids``, ``max_initial_timestamp_index`` and
                ``min_initial_timestamp_index`` are optional and read via ``getattr``.
            begin_index: position in ``input_ids`` where sampled tokens start. When ``None``
                it defaults to ``len(forced_decoder_ids) + 1`` (``1`` without forced ids).
            _detect_timestamp_from_logprob: overrides the config flag of the same name
                (default ``True``) when not ``None``.
        """
        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
        # All ids from here upward are treated as timestamp tokens by __call__.
        self.timestamp_begin = generate_config.no_timestamps_token_id + 1
        self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

        # this variable is mostly just used for testing
        self._detect_timestamp_from_logprob = (
            _detect_timestamp_from_logprob
            if _detect_timestamp_from_logprob is not None
            else getattr(generate_config, "_detect_timestamp_from_logprob", True)
        )

        # getattr: the config object may not define `forced_decoder_ids` at all.
        forced_decoder_ids = getattr(generate_config, "forced_decoder_ids", None)
        num_forced_ids = len(forced_decoder_ids) if forced_decoder_ids is not None else 0
        # Compare with None (not truthiness) so an explicit begin_index=0 is honoured.
        self.begin_index = begin_index if begin_index is not None else num_forced_ids + 1

        # Optional bounds, counted in timestamp tokens from timestamp_begin, on the first timestamp.
        self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
        self.min_initial_timestamp_index = getattr(generate_config, "min_initial_timestamp_index", None)
        # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a masked copy of ``scores`` enforcing the timestamp rules.

        Args:
            input_ids: ``(batch, seq_len)`` token ids decoded so far.
            scores: ``(batch, vocab)`` next-token logits; left unmodified.

        Returns:
            A clone of ``scores`` with disallowed token ids set to ``-inf``.
        """
        neg_inf = -float("inf")
        scores_processed = scores.clone()
        # suppress <|notimestamps|> which is handled by without_timestamps
        scores_processed[:, self.no_timestamps_token_id] = neg_inf

        # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
        for k in range(input_ids.shape[0]):
            sampled_tokens = input_ids[k, self.begin_index:]
            # Only the last two sampled tokens matter here; copy just those to the host.
            tail = sampled_tokens[-2:].tolist()
            last_was_timestamp = len(tail) >= 1 and tail[-1] >= self.timestamp_begin
            penultimate_was_timestamp = len(tail) < 2 or tail[-2] >= self.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    scores_processed[k, self.timestamp_begin:] = neg_inf
                else:  # cannot be normal text tokens
                    # masks ids below eos_token_id (assumes text tokens precede EOS in the vocab - Whisper layout)
                    scores_processed[k, : self.eos_token_id] = neg_inf

            timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
            if timestamps.numel() > 0:
                # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
                # The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    # Avoid to emit <|0.00|> again
                    timestamp_last = timestamps[-1] + 1
                scores_processed[k, self.timestamp_begin: timestamp_last] = neg_inf

        # apply the `max_initial_timestamp` / `min_initial_timestamp` options
        if input_ids.shape[1] == self.begin_index:
            # Mask every non-timestamp id but restore EOS, so EOS stays selectable at the
            # first step (presumably to allow an empty output).
            eos_scores = scores_processed[:, self.eos_token_id].clone()
            scores_processed[:, : self.timestamp_begin] = neg_inf
            scores_processed[:, self.eos_token_id] = eos_scores

            # NOTE(review): if eos_token_id were >= timestamp_begin, the two masks below could
            # also hit EOS; with Whisper vocabularies EOS lies below timestamp_begin.
            if self.max_initial_timestamp_index is not None:
                last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
                scores_processed[:, last_allowed + 1:] = neg_inf
            if self.min_initial_timestamp_index is not None:
                first_allowed = self.timestamp_begin + self.min_initial_timestamp_index
                scores_processed[:, self.timestamp_begin:first_allowed] = neg_inf

        # if sum of probability over timestamps is above any other token, sample timestamp.
        # Skipped entirely when detection is disabled: the masking below could never fire.
        if self._detect_timestamp_from_logprob:
            logprobs = torch.nn.functional.log_softmax(scores_processed.float(), dim=-1)
            timestamp_logprob = logprobs[:, self.timestamp_begin:].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[:, : self.timestamp_begin].max(dim=-1).values
            # Batched row selection instead of one Python-level comparison (and device sync) per row.
            prefer_timestamp = timestamp_logprob > max_text_token_logprob
            scores_processed[prefer_timestamp, : self.timestamp_begin] = neg_inf

        return scores_processed