yro7 commited on
Commit
fecdbe0
·
1 Parent(s): bdf6d6c

feat(transformers): progress on update transformers to 4.55.0

Browse files
Files changed (2) hide show
  1. generation.py +16 -7
  2. modeling_dicow.py +1 -0
generation.py CHANGED
@@ -444,7 +444,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
444
  proc.set_begin_index(decoder_input_ids.shape[-1])
445
 
446
  # 6.8 Run generate with fallback
447
- seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
448
  segment_input=segment_input,
449
  decoder_input_ids=decoder_input_ids,
450
  cur_bsz=cur_bsz,
@@ -827,7 +827,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
827
  num_beams = beam_scorer.num_beams
828
 
829
  batch_beam_size, cur_len = input_ids.shape
830
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
831
 
832
  if num_beams * batch_size != batch_beam_size:
833
  raise ValueError(
@@ -890,7 +890,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
890
  )
891
  outputs_per_sub_batch = [
892
  self(
893
- **inputs_per_sub_batch,
 
894
  return_dict=True,
895
  output_attentions=output_attentions,
896
  output_hidden_states=output_hidden_states,
@@ -901,8 +902,10 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
901
  outputs = stack_model_outputs(outputs_per_sub_batch)
902
 
903
  else: # Unchanged original behavior
 
 
904
  outputs = self(
905
- **model_inputs,
906
  return_dict=True,
907
  output_attentions=output_attentions,
908
  output_hidden_states=output_hidden_states,
@@ -1129,13 +1132,16 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
1129
  batch_size = input_ids.shape[0]
1130
  this_peer_finished = False
1131
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1132
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1133
 
1134
  while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
1135
  # prepare model inputs
1136
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1137
 
1138
  # forward pass to get next token
 
 
 
1139
  outputs = self(
1140
  **model_inputs,
1141
  return_dict=True,
@@ -1330,7 +1336,7 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
1330
  )
1331
  self.stno_mask_seek = None
1332
 
1333
- return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
1334
 
1335
  def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1336
  def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
@@ -1731,6 +1737,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
1731
  prev_idx,
1732
  idx,
1733
  return_token_timestamps,
 
 
1734
  ):
1735
  # find the predicted "end of segment" predictions of Whisper
1736
  # "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1822,7 +1830,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
1822
 
1823
  return segments, segment_offset
1824
 
1825
- def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
 
1826
  # remove all previously passed decoder input ids
1827
  if isinstance(seek_outputs, torch.Tensor):
1828
  seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1]:]
 
444
  proc.set_begin_index(decoder_input_ids.shape[-1])
445
 
446
  # 6.8 Run generate with fallback
447
+ seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type = self.generate_with_fallback(
448
  segment_input=segment_input,
449
  decoder_input_ids=decoder_input_ids,
450
  cur_bsz=cur_bsz,
 
827
  num_beams = beam_scorer.num_beams
828
 
829
  batch_beam_size, cur_len = input_ids.shape
830
+ model_kwargs = self._get_initial_cache_position(input_ids.shape[-1], input_ids.device, model_kwargs)
831
 
832
  if num_beams * batch_size != batch_beam_size:
833
  raise ValueError(
 
890
  )
891
  outputs_per_sub_batch = [
892
  self(
893
+ **{k: v for k, v in inputs_per_sub_batch.items()
894
+ if k not in ("output_attentions", "output_hidden_states")},
895
  return_dict=True,
896
  output_attentions=output_attentions,
897
  output_hidden_states=output_hidden_states,
 
902
  outputs = stack_model_outputs(outputs_per_sub_batch)
903
 
904
  else: # Unchanged original behavior
905
+ _beam_model_inputs = {k: v for k, v in model_inputs.items()
906
+ if k not in ("output_attentions", "output_hidden_states")}
907
  outputs = self(
908
+ **_beam_model_inputs,
909
  return_dict=True,
910
  output_attentions=output_attentions,
911
  output_hidden_states=output_hidden_states,
 
1132
  batch_size = input_ids.shape[0]
1133
  this_peer_finished = False
1134
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1135
+ model_kwargs = self._get_initial_cache_position(input_ids.shape[-1], input_ids.device, model_kwargs)
1136
 
1137
  while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
1138
  # prepare model inputs
1139
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1140
 
1141
  # forward pass to get next token
1142
+ # Pop keys that we pass explicitly to avoid duplicate keyword argument error
1143
+ model_inputs.pop("output_attentions", None)
1144
+ model_inputs.pop("output_hidden_states", None)
1145
  outputs = self(
1146
  **model_inputs,
1147
  return_dict=True,
 
1336
  )
1337
  self.stno_mask_seek = None
1338
 
1339
+ return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, _model_output_type
1340
 
1341
  def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1342
  def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
 
1737
  prev_idx,
1738
  idx,
1739
  return_token_timestamps,
1740
+ time_precision_features=None,
1741
+ decoder_input_ids=None,
1742
  ):
1743
  # find the predicted "end of segment" predictions of Whisper
1744
  # "end of segment" predictions occur whenever Whisper predicts a timestamp token
 
1830
 
1831
  return segments, segment_offset
1832
 
1833
+ def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config,
1834
+ is_shortform=False, seek=None, batch_idx_map=None):
1835
  # remove all previously passed decoder input ids
1836
  if isinstance(seek_outputs, torch.Tensor):
1837
  seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1]:]
modeling_dicow.py CHANGED
@@ -240,6 +240,7 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalG
240
  output_hidden_states: Optional[bool] = None,
241
  return_dict: Optional[bool] = None,
242
  is_valid: Optional[bool] = None,
 
243
  ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
244
  r"""
245
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
240
  output_hidden_states: Optional[bool] = None,
241
  return_dict: Optional[bool] = None,
242
  is_valid: Optional[bool] = None,
243
+ cache_position: Optional[torch.LongTensor] = None,
244
  ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
245
  r"""
246
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):