Automatic Speech Recognition
Transformers
Safetensors
DiCoW
speech
whisper
multilingual
speaker-diarization
meeting-transcription
BUT-FIT
custom_code
Instructions for using BUT-FIT/DiCoW_v3_2 with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
- Transformers
How to use BUT-FIT/DiCoW_v3_2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="BUT-FIT/DiCoW_v3_2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("BUT-FIT/DiCoW_v3_2", trust_remote_code=True, dtype="auto")
```
A fuller end-to-end sketch follows the notebook links below.
- Notebooks
- Google Colab
- Kaggle
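For a quick smoke test, here is a minimal end-to-end sketch of the pipeline path above. The file name `meeting.wav` is a placeholder; DiCoW builds on Whisper, which consumes 16 kHz mono audio, and `trust_remote_code=True` is required because the repo ships custom modeling code (see the `custom_code` tag).

```python
# Minimal sketch, assuming a local 16 kHz mono recording at "meeting.wav".
from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model="BUT-FIT/DiCoW_v3_2",
    trust_remote_code=True,  # the repo ships custom modeling code
)

result = pipe("meeting.wav")
print(result["text"])
```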
fix: add compatibility with Transformers 4.55.0
#2 opened by Yrooo
- SCBs.py +12 -1
- __pycache__/modeling_dicow.cpython-312.pyc +0 -0
- encoder.py +14 -2
- generation.py +166 -53
- modeling_dicow.py +3 -2
- utils.py +2 -3
SCBs.py
CHANGED
```diff
@@ -2,7 +2,18 @@ import torch
 from torch import nn
 from transformers import WhisperConfig
 from transformers.activations import ACT2FN
-
+# Compatibility fallback for Transformers versions
+# - Transformers <= 4.38: WHISPER_ATTENTION_CLASSES is available
+# - Transformers >= 4.39 (including 4.55.0): WHISPER_ATTENTION_CLASSES removed, use WhisperAttention dispatcher
+try:
+    from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
+except ImportError:
+    from transformers.models.whisper.modeling_whisper import WhisperAttention
+    WHISPER_ATTENTION_CLASSES = {
+        "eager": WhisperAttention,
+        "sdpa": WhisperAttention,
+        "flash_attention_2": WhisperAttention,
+    }
 import torch.nn.functional as F
 from .coattention import CoAttention
 from .layers import CustomLinear, CustomDiagonalLinear, Gate
```
__pycache__/modeling_dicow.cpython-312.pyc
ADDED
Binary file (20.1 kB).
encoder.py
CHANGED
```diff
@@ -1,7 +1,19 @@
 import torch
 from torch import nn
 from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
-
+# Compatibility fallback for Transformers versions
+# - Transformers <= 4.38: WHISPER_ATTENTION_CLASSES is available
+# - Transformers >= 4.39 (including 4.55.0): WHISPER_ATTENTION_CLASSES removed, use WhisperAttention dispatcher
+from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer
+try:
+    from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
+except ImportError:
+    from transformers.models.whisper.modeling_whisper import WhisperAttention
+    WHISPER_ATTENTION_CLASSES = {
+        "eager": WhisperAttention,
+        "sdpa": WhisperAttention,
+        "flash_attention_2": WhisperAttention,
+    }
 
 from .FDDT import FDDT
 from .config import DiCoWConfig
@@ -244,7 +256,7 @@ class DiCoWEncoder(WhisperEncoder):
                 layer_head_mask=None,
             )
         elif hasattr(self, "additional_self_attention_layer"):
-            inter_output, _
+            inter_output, _ = self.additional_self_attention_layer(
                 outputs.last_hidden_state,
                 attention_mask=None,
                 output_attentions=output_attentions,
```
generation.py
CHANGED
```diff
@@ -25,9 +25,10 @@ from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
 )
 from transformers.generation.utils import GenerateBeamOutput, BeamScorer, GenerateBeamDecoderOnlyOutput, \
-
+    GenerateBeamEncoderDecoderOutput, GenerateNonBeamOutput, \
     GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
 from transformers.modeling_outputs import BaseModelOutput
+from transformers.utils import ModelOutput
 from transformers.models.whisper.modeling_whisper import (
     WhisperForConditionalGeneration,
 )
@@ -42,6 +43,93 @@ logging.set_verbosity_debug()
 logger = logging.get_logger("transformers")
 
 
+# Backport of transformers 4.40 helpers removed in 4.55
+def _split(data, full_batch_size: int, split_size: int = None):
+    if data is None:
+        return [None] * (full_batch_size // split_size)
+    if isinstance(data, torch.Tensor):
+        return [data[i: i + split_size] for i in range(0, full_batch_size, split_size)]
+    elif isinstance(data, tuple):
+        if isinstance(data[0], tuple):
+            return [
+                tuple(tuple(tensor[i: i + split_size] for tensor in inner_tuple) for inner_tuple in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+        else:
+            return [
+                tuple(sub_tensor[i: i + split_size] for sub_tensor in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+    else:
+        raise ValueError(f"Unexpected attribute type: {type(data)}")
+
+
+def _split_model_inputs(
+    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
+) -> List[Union[ModelOutput, Dict]]:
+    if model_input is None:
+        return [model_input] * (full_batch_size // split_size)
+    model_output_cls = type(model_input)
+    if (full_batch_size % split_size) != 0:
+        raise ValueError("`full_batch_size` must be divisible by `split_size`")
+    if split_size > full_batch_size:
+        raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
+    keys = (
+        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
+    )
+    keys = [k for k in keys if k in model_input]
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+    keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
+    data_split_list = [
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
+        for i in range(full_batch_size // split_size)
+    ]
+    bool_data = {k: model_input[k] for k in bool_keys}
+    if "encoder_outputs" in model_input:
+        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
+        data_split_list = [
+            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
+        ]
+    if "num_logits_to_keep" in model_input:
+        data_split_list = [
+            {**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list
+        ]
+    return [model_output_cls(**data_split, **bool_data) for data_split in data_split_list]
+
+
+def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
+    if not model_outputs:
+        raise ValueError("Input list is empty.")
+    model_output_cls = type(model_outputs[0])
+    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
+        raise ValueError("All elements in the list should be of the same type.")
+
+    def _concat(data):
+        if any(d is None for d in data):
+            return None
+        if isinstance(data[0], torch.Tensor):
+            return torch.cat(data, dim=0)
+        elif isinstance(data[0], tuple):
+            if isinstance(data[0][0], tuple):
+                return tuple(
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
+                    for i in range(len(data[0]))
+                )
+            else:
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
+        elif isinstance(data[0], (int, float)):
+            return torch.tensor(data)
+        else:
+            raise ValueError(f"Unexpected attribute type: {type(data[0])}")
+
+    concatenated_data = {
+        k: _concat([getattr(model_output, k) for model_output in model_outputs])
+        for k in model_output_cls.__dataclass_fields__.keys()
+    }
+    return model_output_cls(**concatenated_data)
+
+
 class DiCoWGenerationMixin(WhisperForConditionalGeneration):
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
@@ -55,6 +143,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
             inputs_tensor, model_kwargs, model_input_name, generation_config
         )
+        # Ensure output_hidden_states is in model_kwargs
+        model_kwargs["output_hidden_states"] = True
         self.encoder_logits = model_kwargs["encoder_outputs"].logits
 
         return model_kwargs
@@ -153,6 +243,9 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             input_features=input_features, input_stride=input_stride, kwargs=kwargs
         )
         is_shortform = total_input_frames <= num_segment_frames
+        # Store for use in generate_with_fallback when called via parent's generate (shortform path)
+        self._num_segment_frames = num_segment_frames
+        self._max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
 
         if is_shortform:
             # warn user of ignored inputs
@@ -170,7 +263,6 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         self._set_return_outputs(
             return_dict_in_generate=return_dict_in_generate,
             return_token_timestamps=return_token_timestamps,
-            is_shortform=is_shortform,
             logprob_threshold=logprob_threshold,
             generation_config=generation_config,
         )
@@ -181,7 +273,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
         )
         self._set_num_frames(
-            return_token_timestamps=return_token_timestamps, generation_config=generation_config,
+            return_token_timestamps=return_token_timestamps, generation_config=generation_config,
+            attention_mask=attention_mask, kwargs=kwargs
         )
         self._set_thresholds_and_condition(
             generation_config=generation_config,
@@ -278,8 +371,10 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         batch_size = input_features.shape[0]
 
         max_frames, seek = self._retrieve_max_frames_and_seek(
-            batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
+            batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames,
+            is_shortform=is_shortform,
         )
+        self._max_frames = max_frames
 
         # 6.2 Preppare running variables, list for generation
         cur_bsz = batch_size
@@ -349,7 +444,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             proc.set_begin_index(decoder_input_ids.shape[-1])
 
         # 6.8 Run generate with fallback
-        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
+        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type = self.generate_with_fallback(
             segment_input=segment_input,
             decoder_input_ids=decoder_input_ids,
             cur_bsz=cur_bsz,
@@ -732,7 +827,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         num_beams = beam_scorer.num_beams
 
         batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(input_ids.shape[-1], input_ids.device, model_kwargs)
 
         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -795,7 +890,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
                 )
                 outputs_per_sub_batch = [
                     self(
-                        **inputs_per_sub_batch,
+                        **{k: v for k, v in inputs_per_sub_batch.items()
+                           if k not in ("output_attentions", "output_hidden_states")},
                         return_dict=True,
                         output_attentions=output_attentions,
                         output_hidden_states=output_hidden_states,
@@ -806,8 +902,10 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
                 outputs = stack_model_outputs(outputs_per_sub_batch)
 
             else:  # Unchanged original behavior
+                _beam_model_inputs = {k: v for k, v in model_inputs.items()
+                                      if k not in ("output_attentions", "output_hidden_states")}
                 outputs = self(
-                    **model_inputs,
+                    **_beam_model_inputs,
                     return_dict=True,
                     output_attentions=output_attentions,
                     output_hidden_states=output_hidden_states,
@@ -1034,13 +1132,16 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         batch_size = input_ids.shape[0]
         this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(input_ids.shape[-1], input_ids.device, model_kwargs)
 
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # forward pass to get next token
+            # Pop keys that we pass explicitly to avoid duplicate keyword argument error
+            model_inputs.pop("output_attentions", None)
+            model_inputs.pop("output_hidden_states", None)
             outputs = self(
                 **model_inputs,
                 return_dict=True,
@@ -1186,10 +1287,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             segment_input,
             decoder_input_ids,
             cur_bsz,
-            batch_idx_map,
             seek,
-
-            max_frames,
+            batch_idx_map,
             temperatures,
             generation_config,
             logits_processor,
@@ -1198,36 +1297,46 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             synced_gpus,
             return_token_timestamps,
             do_condition_on_prev_tokens,
-
+            is_shortform=False,
+            batch_size=None,
+            attention_mask=None,
+            kwargs=None,
+            # Legacy args kept for DiCoW's own generate() call
+            num_segment_frames=None,
+            max_frames=None,
     ):
+        if kwargs is None:
+            kwargs = {}
         kwargs = copy.copy(kwargs)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Use instance-stored values if not provided (e.g. when called from parent's generate)
+        _num_segment_frames = num_segment_frames if num_segment_frames is not None else getattr(self, "_num_segment_frames", None)
+        _max_frames = max_frames if max_frames is not None else getattr(self, "_max_frames", None)
+        if _num_segment_frames is not None and _max_frames is not None:
+            kwargs = self.prepare_kwargs_for_generate(segment_input, cur_bsz, batch_idx_map, seek,
+                                                      _num_segment_frames, _max_frames, kwargs)
+        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type = \
+            super().generate_with_fallback(
+                segment_input=segment_input,
+                decoder_input_ids=decoder_input_ids,
+                cur_bsz=cur_bsz,
+                seek=seek,
+                batch_idx_map=batch_idx_map,
+                temperatures=temperatures,
+                generation_config=generation_config,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                synced_gpus=synced_gpus,
+                return_token_timestamps=return_token_timestamps,
+                do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                is_shortform=is_shortform,
+                batch_size=batch_size if batch_size is not None else cur_bsz,
+                attention_mask=attention_mask,
+                kwargs=kwargs,
+            )
         self.stno_mask_seek = None
 
-
-        # print(f"Sequence {i}: {self.tokenizer.decode(seq, decode_with_timestamps=True)}")
-        # print("-"*50)
-
-        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
+        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type
 
     def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
         def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
@@ -1264,7 +1373,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         task = getattr(generation_config, "task", None)
         language = getattr(generation_config, "language", None)
 
-        forced_decoder_ids = generation_config.forced_decoder_ids
+        forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None)
         if forced_decoder_ids is not None:
             if language is None and task is None and forced_decoder_ids[0][1] is None:
                 logger.warning_once(
@@ -1289,7 +1398,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             return forced_decoder_ids
 
         # from v4.39 the forced decoder ids are always None in favour of decoder input ids
-        generation_config.forced_decoder_ids = None
+        if hasattr(generation_config, "forced_decoder_ids"):
+            generation_config.forced_decoder_ids = None
 
         is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
 
@@ -1434,17 +1544,17 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
     ) -> LogitsProcessorList:
         # pylint: disable=no-member
        gen_config_copy = copy.deepcopy(generation_config)
-        gen_config_copy.forced_decoder_ids = None
+        if hasattr(gen_config_copy, "forced_decoder_ids"):
+            gen_config_copy.forced_decoder_ids = None
         processors = super()._get_logits_processor(
-            gen_config_copy,
-            input_ids_seq_length,
-            encoder_input_ids,
-            prefix_allowed_tokens_fn,
-            logits_processor,
-
-
-
-            negative_prompt_attention_mask,
+            generation_config=gen_config_copy,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=encoder_input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+            model_kwargs=model_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
             enc_logits = self.encoder_logits
@@ -1469,8 +1579,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             processors.append(self.ctc_rescorer)
         return processors
 
-    def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform,
-                                   device):
+    def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform=False,
+                                   num_beams=1, device=None):
         if generation_config.return_timestamps is True:
             timestamp_processor = WhisperTimeStampLogitsProcessorCustom(generation_config, begin_index=begin_index)
             logits_processor = (
@@ -1627,6 +1737,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             prev_idx,
             idx,
             return_token_timestamps,
+            time_precision_features=None,
+            decoder_input_ids=None,
     ):
         # find the predicted "end of segment" predictions of Whisper
         # "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1718,7 +1830,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
 
         return segments, segment_offset
 
-    def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
+    def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config,
+                             is_shortform=False, seek=None, batch_idx_map=None):
         # remove all previously passed decoder input ids
         if isinstance(seek_outputs, torch.Tensor):
             seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1]:]
```
modeling_dicow.py
CHANGED
```diff
@@ -101,7 +101,7 @@ class DiCoW(WhisperModel):
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
-            encoder_hidden_states=encoder_outputs.hidden_states[-1],
+            encoder_hidden_states=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
             head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
@@ -122,7 +122,7 @@ class DiCoW(WhisperModel):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
+            encoder_last_hidden_state=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
             encoder_logits=encoder_outputs.logits,
@@ -240,6 +240,7 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         is_valid: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
```
utils.py
CHANGED
```diff
@@ -31,9 +31,8 @@ class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
             else getattr(generate_config, "_detect_timestamp_from_logprob", True)
         )
 
-        num_forced_ids = (
-            len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
-        )
+        _forced_decoder_ids = getattr(generate_config, "forced_decoder_ids", None)
+        num_forced_ids = len(_forced_decoder_ids) if _forced_decoder_ids is not None else 0
         self.begin_index = begin_index or (num_forced_ids + 1)
 
         self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
```