fix: add compatibility with Transformers 4.55.0

#2
opened by Yrooo
SCBs.py CHANGED
@@ -2,7 +2,18 @@ import torch
 from torch import nn
 from transformers import WhisperConfig
 from transformers.activations import ACT2FN
-from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
+# Compatibility fallback for Transformers versions
+# - Transformers <= 4.38: WHISPER_ATTENTION_CLASSES is available
+# - Transformers >= 4.39 (including 4.55.0): WHISPER_ATTENTION_CLASSES removed, use WhisperAttention dispatcher
+try:
+    from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
+except ImportError:
+    from transformers.models.whisper.modeling_whisper import WhisperAttention
+    WHISPER_ATTENTION_CLASSES = {
+        "eager": WhisperAttention,
+        "sdpa": WhisperAttention,
+        "flash_attention_2": WhisperAttention,
+    }
 import torch.nn.functional as F
 from .coattention import CoAttention
 from .layers import CustomLinear, CustomDiagonalLinear, Gate
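
For reference, the fallback keeps any caller that indexes `WHISPER_ATTENTION_CLASSES` version-agnostic: on >= 4.39 every key resolves to `WhisperAttention`, which dispatches to the eager/SDPA/flash kernels internally based on the model config. A minimal sketch of that usage (not part of this PR; the default `WhisperConfig` and the `impl` lookup are illustrative assumptions):

```python
from transformers import WhisperConfig

try:
    from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
except ImportError:  # registry removed in newer Transformers
    from transformers.models.whisper.modeling_whisper import WhisperAttention
    WHISPER_ATTENTION_CLASSES = {
        "eager": WhisperAttention,
        "sdpa": WhisperAttention,
        "flash_attention_2": WhisperAttention,
    }

config = WhisperConfig()  # illustrative defaults, not DiCoW's real config
impl = getattr(config, "_attn_implementation", None) or "eager"
attn = WHISPER_ATTENTION_CLASSES[impl](
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
print(type(attn).__name__)  # WhisperAttention on 4.55.0
```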
__pycache__/modeling_dicow.cpython-312.pyc ADDED
Binary file (20.1 kB)
 
encoder.py CHANGED
@@ -1,7 +1,19 @@
 import torch
 from torch import nn
 from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
-from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES
+# Compatibility fallback for Transformers versions
+# - Transformers <= 4.38: WHISPER_ATTENTION_CLASSES is available
+# - Transformers >= 4.39 (including 4.55.0): WHISPER_ATTENTION_CLASSES removed, use WhisperAttention dispatcher
+from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer
+try:
+    from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
+except ImportError:
+    from transformers.models.whisper.modeling_whisper import WhisperAttention
+    WHISPER_ATTENTION_CLASSES = {
+        "eager": WhisperAttention,
+        "sdpa": WhisperAttention,
+        "flash_attention_2": WhisperAttention,
+    }

 from .FDDT import FDDT
 from .config import DiCoWConfig
@@ -244,7 +256,7 @@ class DiCoWEncoder(WhisperEncoder):
                 layer_head_mask=None,
             )
         elif hasattr(self, "additional_self_attention_layer"):
-            inter_output, _, __ = self.additional_self_attention_layer(
+            inter_output, _ = self.additional_self_attention_layer(
                 outputs.last_hidden_state,
                 attention_mask=None,
                 output_attentions=output_attentions,
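
The second hunk tracks a return-arity change: recent Transformers attention modules return `(attn_output, attn_weights)` rather than the old `(attn_output, attn_weights, past_key_value)` 3-tuple, so the `_, __` unpacking fails on 4.55.0. A hedged alternative sketch (ours, not the PR's) that tolerates either arity by indexing instead of unpacking; `attn_layer` stands in for `additional_self_attention_layer`:

```python
def run_additional_attention(attn_layer, hidden_states, output_attentions=False):
    outputs = attn_layer(
        hidden_states,
        attention_mask=None,
        output_attentions=output_attentions,
    )
    inter_output = outputs[0]  # hidden states come first on any version
    attn_weights = outputs[1] if len(outputs) > 1 else None  # 2- or 3-tuple
    return inter_output, attn_weights
```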
generation.py CHANGED
@@ -25,9 +25,10 @@ from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
 )
 from transformers.generation.utils import GenerateBeamOutput, BeamScorer, GenerateBeamDecoderOnlyOutput, \
-    stack_model_outputs, GenerateBeamEncoderDecoderOutput, _split_model_inputs, GenerateNonBeamOutput, \
+    GenerateBeamEncoderDecoderOutput, GenerateNonBeamOutput, \
     GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
 from transformers.modeling_outputs import BaseModelOutput
+from transformers.utils import ModelOutput
 from transformers.models.whisper.modeling_whisper import (
     WhisperForConditionalGeneration,
 )
@@ -42,6 +43,93 @@ logging.set_verbosity_debug()
 logger = logging.get_logger("transformers")


+# Backport of transformers 4.40 helpers removed in 4.55
+def _split(data, full_batch_size: int, split_size: int = None):
+    if data is None:
+        return [None] * (full_batch_size // split_size)
+    if isinstance(data, torch.Tensor):
+        return [data[i: i + split_size] for i in range(0, full_batch_size, split_size)]
+    elif isinstance(data, tuple):
+        if isinstance(data[0], tuple):
+            return [
+                tuple(tuple(tensor[i: i + split_size] for tensor in inner_tuple) for inner_tuple in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+        else:
+            return [
+                tuple(sub_tensor[i: i + split_size] for sub_tensor in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+    else:
+        raise ValueError(f"Unexpected attribute type: {type(data)}")
+
+
+def _split_model_inputs(
+    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
+) -> List[Union[ModelOutput, Dict]]:
+    if model_input is None:
+        return [model_input] * (full_batch_size // split_size)
+    model_output_cls = type(model_input)
+    if (full_batch_size % split_size) != 0:
+        raise ValueError("`full_batch_size` must be divisible by `split_size`")
+    if split_size > full_batch_size:
+        raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
+    keys = (
+        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
+    )
+    keys = [k for k in keys if k in model_input]
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+    keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
+    data_split_list = [
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
+        for i in range(full_batch_size // split_size)
+    ]
+    bool_data = {k: model_input[k] for k in bool_keys}
+    if "encoder_outputs" in model_input:
+        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
+        data_split_list = [
+            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
+        ]
+    if "num_logits_to_keep" in model_input:
+        data_split_list = [
+            {**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list
+        ]
+    return [model_output_cls(**data_split, **bool_data) for data_split in data_split_list]
+
+
+def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
+    if not model_outputs:
+        raise ValueError("Input list is empty.")
+    model_output_cls = type(model_outputs[0])
+    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
+        raise ValueError("All elements in the list should be of the same type.")
+
+    def _concat(data):
+        if any(d is None for d in data):
+            return None
+        if isinstance(data[0], torch.Tensor):
+            return torch.cat(data, dim=0)
+        elif isinstance(data[0], tuple):
+            if isinstance(data[0][0], tuple):
+                return tuple(
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
+                    for i in range(len(data[0]))
+                )
+            else:
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
+        elif isinstance(data[0], (int, float)):
+            return torch.tensor(data)
+        else:
+            raise ValueError(f"Unexpected attribute type: {type(data[0])}")
+
+    concatenated_data = {
+        k: _concat([getattr(model_output, k) for model_output in model_outputs])
+        for k in model_output_cls.__dataclass_fields__.keys()
+    }
+    return model_output_cls(**concatenated_data)
+
+
 class DiCoWGenerationMixin(WhisperForConditionalGeneration):
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
@@ -55,6 +143,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
             inputs_tensor, model_kwargs, model_input_name, generation_config
         )
+        # Ensure output_hidden_states is in model_kwargs
+        model_kwargs["output_hidden_states"] = True
         self.encoder_logits = model_kwargs["encoder_outputs"].logits

         return model_kwargs
@@ -153,6 +243,9 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             input_features=input_features, input_stride=input_stride, kwargs=kwargs
         )
         is_shortform = total_input_frames <= num_segment_frames
+        # Store for use in generate_with_fallback when called via parent's generate (shortform path)
+        self._num_segment_frames = num_segment_frames
+        self._max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames

         if is_shortform:
             # warn user of ignored inputs
@@ -170,7 +263,6 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         self._set_return_outputs(
             return_dict_in_generate=return_dict_in_generate,
             return_token_timestamps=return_token_timestamps,
-            is_shortform=is_shortform,
             logprob_threshold=logprob_threshold,
             generation_config=generation_config,
         )
@@ -181,7 +273,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
         )
         self._set_num_frames(
-            return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
+            return_token_timestamps=return_token_timestamps, generation_config=generation_config,
+            attention_mask=attention_mask, kwargs=kwargs
         )
         self._set_thresholds_and_condition(
             generation_config=generation_config,
@@ -278,8 +371,10 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         batch_size = input_features.shape[0]

         max_frames, seek = self._retrieve_max_frames_and_seek(
-            batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
+            batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames,
+            is_shortform=is_shortform,
         )
+        self._max_frames = max_frames

         # 6.2 Preppare running variables, list for generation
         cur_bsz = batch_size
@@ -349,7 +444,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             proc.set_begin_index(decoder_input_ids.shape[-1])

         # 6.8 Run generate with fallback
-        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
+        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type = self.generate_with_fallback(
             segment_input=segment_input,
             decoder_input_ids=decoder_input_ids,
             cur_bsz=cur_bsz,
@@ -732,7 +827,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         num_beams = beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(input_ids.shape[-1], input_ids.device, model_kwargs)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -795,7 +890,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
                 )
                 outputs_per_sub_batch = [
                     self(
-                        **inputs_per_sub_batch,
+                        **{k: v for k, v in inputs_per_sub_batch.items()
+                           if k not in ("output_attentions", "output_hidden_states")},
                         return_dict=True,
                         output_attentions=output_attentions,
                         output_hidden_states=output_hidden_states,
@@ -806,8 +902,10 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
                 outputs = stack_model_outputs(outputs_per_sub_batch)

             else:  # Unchanged original behavior
+                _beam_model_inputs = {k: v for k, v in model_inputs.items()
+                                      if k not in ("output_attentions", "output_hidden_states")}
                 outputs = self(
-                    **model_inputs,
+                    **_beam_model_inputs,
                     return_dict=True,
                     output_attentions=output_attentions,
                     output_hidden_states=output_hidden_states,
@@ -1034,13 +1132,16 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         batch_size = input_ids.shape[0]
         this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(input_ids.shape[-1], input_ids.device, model_kwargs)

         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # forward pass to get next token
+            # Pop keys that we pass explicitly to avoid duplicate keyword argument error
+            model_inputs.pop("output_attentions", None)
+            model_inputs.pop("output_hidden_states", None)
             outputs = self(
                 **model_inputs,
                 return_dict=True,
@@ -1186,10 +1287,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         segment_input,
         decoder_input_ids,
         cur_bsz,
-        batch_idx_map,
         seek,
-        num_segment_frames,
-        max_frames,
+        batch_idx_map,
         temperatures,
         generation_config,
         logits_processor,
@@ -1198,36 +1297,46 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         synced_gpus,
         return_token_timestamps,
         do_condition_on_prev_tokens,
-        kwargs,
+        is_shortform=False,
+        batch_size=None,
+        attention_mask=None,
+        kwargs=None,
+        # Legacy args kept for DiCoW's own generate() call
+        num_segment_frames=None,
+        max_frames=None,
     ):
+        if kwargs is None:
+            kwargs = {}
         kwargs = copy.copy(kwargs)
-        kwargs = self.prepare_kwargs_for_generate(segment_input, cur_bsz, batch_idx_map, seek, num_segment_frames,
-                                                  max_frames, kwargs)
-        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = super().generate_with_fallback(
-            segment_input,
-            decoder_input_ids,
-            cur_bsz,
-            batch_idx_map,
-            seek,
-            num_segment_frames,
-            max_frames,
-            temperatures,
-            generation_config,
-            logits_processor,
-            stopping_criteria,
-            prefix_allowed_tokens_fn,
-            synced_gpus,
-            return_token_timestamps,
-            do_condition_on_prev_tokens,
-            kwargs,
-        )
+        # Use instance-stored values if not provided (e.g. when called from parent's generate)
+        _num_segment_frames = num_segment_frames if num_segment_frames is not None else getattr(self, "_num_segment_frames", None)
+        _max_frames = max_frames if max_frames is not None else getattr(self, "_max_frames", None)
+        if _num_segment_frames is not None and _max_frames is not None:
+            kwargs = self.prepare_kwargs_for_generate(segment_input, cur_bsz, batch_idx_map, seek,
+                                                      _num_segment_frames, _max_frames, kwargs)
+        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type = \
+            super().generate_with_fallback(
+                segment_input=segment_input,
+                decoder_input_ids=decoder_input_ids,
+                cur_bsz=cur_bsz,
+                seek=seek,
+                batch_idx_map=batch_idx_map,
+                temperatures=temperatures,
+                generation_config=generation_config,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                synced_gpus=synced_gpus,
+                return_token_timestamps=return_token_timestamps,
+                do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                is_shortform=is_shortform,
+                batch_size=batch_size if batch_size is not None else cur_bsz,
+                attention_mask=attention_mask,
+                kwargs=kwargs,
+            )
         self.stno_mask_seek = None

-        # for i, seq in enumerate(seek_outputs):
-        #     print(f"Sequence {i}: {self.tokenizer.decode(seq, decode_with_timestamps=True)}")
-        #     print("-"*50)
-
-        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
+        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type

     def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
         def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
@@ -1264,7 +1373,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         task = getattr(generation_config, "task", None)
         language = getattr(generation_config, "language", None)

-        forced_decoder_ids = generation_config.forced_decoder_ids
+        forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None)
         if forced_decoder_ids is not None:
             if language is None and task is None and forced_decoder_ids[0][1] is None:
                 logger.warning_once(
@@ -1289,7 +1398,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             return forced_decoder_ids

         # from v4.39 the forced decoder ids are always None in favour of decoder input ids
-        generation_config.forced_decoder_ids = None
+        if hasattr(generation_config, "forced_decoder_ids"):
+            generation_config.forced_decoder_ids = None

         is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)

@@ -1434,17 +1544,17 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
     ) -> LogitsProcessorList:
         # pylint: disable=no-member
         gen_config_copy = copy.deepcopy(generation_config)
-        gen_config_copy.forced_decoder_ids = None
+        if hasattr(gen_config_copy, "forced_decoder_ids"):
+            gen_config_copy.forced_decoder_ids = None
         processors = super()._get_logits_processor(
-            gen_config_copy,
-            input_ids_seq_length,
-            encoder_input_ids,
-            prefix_allowed_tokens_fn,
-            logits_processor,
-            device,
-            model_kwargs,
-            negative_prompt_ids,
-            negative_prompt_attention_mask,
+            generation_config=gen_config_copy,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=encoder_input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+            model_kwargs=model_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
             enc_logits = self.encoder_logits
@@ -1469,8 +1579,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
             processors.append(self.ctc_rescorer)
         return processors

-    def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams,
-                                   device):
+    def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform=False,
+                                   num_beams=1, device=None):
         if generation_config.return_timestamps is True:
             timestamp_processor = WhisperTimeStampLogitsProcessorCustom(generation_config, begin_index=begin_index)
             logits_processor = (
@@ -1627,6 +1737,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
         prev_idx,
         idx,
         return_token_timestamps,
+        time_precision_features=None,
+        decoder_input_ids=None,
     ):
         # find the predicted "end of segment" predictions of Whisper
         # "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1718,7 +1830,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):

         return segments, segment_offset

-    def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
+    def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config,
+                             is_shortform=False, seek=None, batch_idx_map=None):
         # remove all previously passed decoder input ids
         if isinstance(seek_outputs, torch.Tensor):
             seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1]:]
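
The module-level backport above restores the sub-batching machinery that sequential beam search relies on: `_split_model_inputs` slices a batch of model inputs into sub-batches and `stack_model_outputs` concatenates the per-sub-batch outputs back. A toy round-trip, assuming the backported helpers are in scope (e.g. inside generation.py):

```python
import torch
from transformers.modeling_outputs import BaseModelOutput

# Batch of 4 split into two sub-batches of 2, then stacked back together.
full = BaseModelOutput(last_hidden_state=torch.randn(4, 10, 8))
parts = _split_model_inputs(full, split_size=2, full_batch_size=4)
assert len(parts) == 2
assert parts[0].last_hidden_state.shape == (2, 10, 8)

restacked = stack_model_outputs(parts)
assert torch.equal(restacked.last_hidden_state, full.last_hidden_state)
```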
modeling_dicow.py CHANGED
@@ -101,7 +101,7 @@ class DiCoW(WhisperModel):
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
-            encoder_hidden_states=encoder_outputs.hidden_states[-1],
+            encoder_hidden_states=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
             head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
@@ -122,7 +122,7 @@ class DiCoW(WhisperModel):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
+            encoder_last_hidden_state=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
             encoder_logits=encoder_outputs.logits,
@@ -240,6 +240,7 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         is_valid: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
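
Both guarded hunks address the same failure mode: `encoder_outputs.hidden_states` is `None` whenever the encoder ran without `output_hidden_states=True`, while `last_hidden_state` is always populated. The pattern in isolation (a runnable sketch; `final_encoder_states` is our name, not a DiCoW method):

```python
import torch
from transformers.modeling_outputs import BaseModelOutput

def final_encoder_states(enc_out: BaseModelOutput) -> torch.Tensor:
    # hidden_states is a per-layer tuple only when output_hidden_states=True
    if enc_out.hidden_states is not None:
        return enc_out.hidden_states[-1]
    return enc_out.last_hidden_state

x = torch.randn(1, 1500, 384)
print(final_encoder_states(BaseModelOutput(last_hidden_state=x)).shape)
print(final_encoder_states(BaseModelOutput(last_hidden_state=x, hidden_states=(x,))).shape)
```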
utils.py CHANGED
@@ -31,9 +31,8 @@ class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
             else getattr(generate_config, "_detect_timestamp_from_logprob", True)
         )

-        num_forced_ids = (
-            len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
-        )
+        _forced_decoder_ids = getattr(generate_config, "forced_decoder_ids", None)
+        num_forced_ids = len(_forced_decoder_ids) if _forced_decoder_ids is not None else 0
         self.begin_index = begin_index or (num_forced_ids + 1)

         self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
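
Same defensive pattern as in generation.py: `forced_decoder_ids` can be absent from newer generation configs entirely, so direct attribute access may raise `AttributeError`. A quick illustration (not part of the PR):

```python
from transformers import GenerationConfig

cfg = GenerationConfig()
# getattr with a default covers both "set to None" and "attribute missing"
forced = getattr(cfg, "forced_decoder_ids", None)
num_forced_ids = len(forced) if forced is not None else 0
print(num_forced_ids)  # 0 unless a legacy config carries forced_decoder_ids
```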