yro7 committed on
Commit
1cb39b4
·
1 Parent(s): 00a09c6

fix(transformers): begin monkey patching the imports

Browse files
Files changed (4) hide show
  1. SCBs.py +13 -1
  2. encoder.py +13 -1
  3. generation.py +10 -9
  4. modeling_dicow.py +2 -2
SCBs.py CHANGED
@@ -2,7 +2,19 @@ import torch
2
  from torch import nn
3
  from transformers import WhisperConfig
4
  from transformers.activations import ACT2FN
5
- from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
 
 
 
 
 
 
 
 
 
 
 
 
6
  import torch.nn.functional as F
7
  from .coattention import CoAttention
8
  from .layers import CustomLinear, CustomDiagonalLinear, Gate
 
2
  from torch import nn
3
  from transformers import WhisperConfig
4
  from transformers.activations import ACT2FN
5
+ try:
6
+ from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
7
+ except ImportError:
8
+ from transformers.models.whisper.modeling_whisper import WhisperAttention, WhisperSdpaAttention
9
+ try:
10
+ from transformers.models.whisper.modeling_whisper import WhisperFlashAttention2
11
+ except ImportError:
12
+ WhisperFlashAttention2 = WhisperAttention
13
+ WHISPER_ATTENTION_CLASSES = {
14
+ "eager": WhisperAttention,
15
+ "sdpa": WhisperSdpaAttention,
16
+ "flash_attention_2": WhisperFlashAttention2,
17
+ }
18
  import torch.nn.functional as F
19
  from .coattention import CoAttention
20
  from .layers import CustomLinear, CustomDiagonalLinear, Gate
encoder.py CHANGED
@@ -1,7 +1,19 @@
1
  import torch
2
  from torch import nn
3
  from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
4
- from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  from .FDDT import FDDT
7
  from .config import DiCoWConfig
 
1
  import torch
2
  from torch import nn
3
  from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
4
+ try:
5
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES
6
+ except ImportError:
7
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WhisperAttention, WhisperSdpaAttention
8
+ try:
9
+ from transformers.models.whisper.modeling_whisper import WhisperFlashAttention2
10
+ except ImportError:
11
+ WhisperFlashAttention2 = WhisperAttention
12
+ WHISPER_ATTENTION_CLASSES = {
13
+ "eager": WhisperAttention,
14
+ "sdpa": WhisperSdpaAttention,
15
+ "flash_attention_2": WhisperFlashAttention2,
16
+ }
17
 
18
  from .FDDT import FDDT
19
  from .config import DiCoWConfig
generation.py CHANGED
@@ -55,6 +55,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
55
  model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
56
  inputs_tensor, model_kwargs, model_input_name, generation_config
57
  )
 
 
58
  self.encoder_logits = model_kwargs["encoder_outputs"].logits
59
 
60
  return model_kwargs
@@ -1436,15 +1438,14 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
1436
  gen_config_copy = copy.deepcopy(generation_config)
1437
  gen_config_copy.forced_decoder_ids = None
1438
  processors = super()._get_logits_processor(
1439
- gen_config_copy,
1440
- input_ids_seq_length,
1441
- encoder_input_ids,
1442
- prefix_allowed_tokens_fn,
1443
- logits_processor,
1444
- device,
1445
- model_kwargs,
1446
- negative_prompt_ids,
1447
- negative_prompt_attention_mask,
1448
  )
1449
  if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
1450
  enc_logits = self.encoder_logits
 
55
  model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
56
  inputs_tensor, model_kwargs, model_input_name, generation_config
57
  )
58
+ # Ensure output_hidden_states is in model_kwargs
59
+ model_kwargs["output_hidden_states"] = True
60
  self.encoder_logits = model_kwargs["encoder_outputs"].logits
61
 
62
  return model_kwargs
 
1438
  gen_config_copy = copy.deepcopy(generation_config)
1439
  gen_config_copy.forced_decoder_ids = None
1440
  processors = super()._get_logits_processor(
1441
+ generation_config=gen_config_copy,
1442
+ input_ids_seq_length=input_ids_seq_length,
1443
+ encoder_input_ids=encoder_input_ids,
1444
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1445
+ logits_processor=logits_processor,
1446
+ model_kwargs=model_kwargs,
1447
+ negative_prompt_ids=negative_prompt_ids,
1448
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
 
1449
  )
1450
  if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
1451
  enc_logits = self.encoder_logits
modeling_dicow.py CHANGED
@@ -101,7 +101,7 @@ class DiCoW(WhisperModel):
101
  decoder_outputs = self.decoder(
102
  input_ids=decoder_input_ids,
103
  attention_mask=decoder_attention_mask,
104
- encoder_hidden_states=encoder_outputs.hidden_states[-1],
105
  head_mask=decoder_head_mask,
106
  cross_attn_head_mask=cross_attn_head_mask,
107
  past_key_values=past_key_values,
@@ -122,7 +122,7 @@ class DiCoW(WhisperModel):
122
  decoder_hidden_states=decoder_outputs.hidden_states,
123
  decoder_attentions=decoder_outputs.attentions,
124
  cross_attentions=decoder_outputs.cross_attentions,
125
- encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
126
  encoder_hidden_states=encoder_outputs.hidden_states,
127
  encoder_attentions=encoder_outputs.attentions,
128
  encoder_logits=encoder_outputs.logits,
 
101
  decoder_outputs = self.decoder(
102
  input_ids=decoder_input_ids,
103
  attention_mask=decoder_attention_mask,
104
+ encoder_hidden_states=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
105
  head_mask=decoder_head_mask,
106
  cross_attn_head_mask=cross_attn_head_mask,
107
  past_key_values=past_key_values,
 
122
  decoder_hidden_states=decoder_outputs.hidden_states,
123
  decoder_attentions=decoder_outputs.attentions,
124
  cross_attentions=decoder_outputs.cross_attentions,
125
+ encoder_last_hidden_state=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
126
  encoder_hidden_states=encoder_outputs.hidden_states,
127
  encoder_attentions=encoder_outputs.attentions,
128
  encoder_logits=encoder_outputs.logits,