Automatic Speech Recognition
Transformers
Safetensors
DiCoW
speech
whisper
multilingual
speaker-diarization
meeting-transcription
BUT-FIT
custom_code
Instructions to use BUT-FIT/DiCoW_v3_2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BUT-FIT/DiCoW_v3_2 with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("automatic-speech-recognition", model="BUT-FIT/DiCoW_v3_2", trust_remote_code=True)# Load model directly from transformers import AutoModelForSpeechSeq2Seq model = AutoModelForSpeechSeq2Seq.from_pretrained("BUT-FIT/DiCoW_v3_2", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
yro7 commited on
Commit ·
1cb39b4
1
Parent(s): 00a09c6
fix(transformers): beging monkey patching the imports
Browse files- SCBs.py +13 -1
- encoder.py +13 -1
- generation.py +10 -9
- modeling_dicow.py +2 -2
SCBs.py
CHANGED
|
@@ -2,7 +2,19 @@ import torch
|
|
| 2 |
from torch import nn
|
| 3 |
from transformers import WhisperConfig
|
| 4 |
from transformers.activations import ACT2FN
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from .coattention import CoAttention
|
| 8 |
from .layers import CustomLinear, CustomDiagonalLinear, Gate
|
|
|
|
| 2 |
from torch import nn
|
| 3 |
from transformers import WhisperConfig
|
| 4 |
from transformers.activations import ACT2FN
|
| 5 |
+
try:
|
| 6 |
+
from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
|
| 7 |
+
except ImportError:
|
| 8 |
+
from transformers.models.whisper.modeling_whisper import WhisperAttention, WhisperSdpaAttention
|
| 9 |
+
try:
|
| 10 |
+
from transformers.models.whisper.modeling_whisper import WhisperFlashAttention2
|
| 11 |
+
except ImportError:
|
| 12 |
+
WhisperFlashAttention2 = WhisperAttention
|
| 13 |
+
WHISPER_ATTENTION_CLASSES = {
|
| 14 |
+
"eager": WhisperAttention,
|
| 15 |
+
"sdpa": WhisperSdpaAttention,
|
| 16 |
+
"flash_attention_2": WhisperFlashAttention2,
|
| 17 |
+
}
|
| 18 |
import torch.nn.functional as F
|
| 19 |
from .coattention import CoAttention
|
| 20 |
from .layers import CustomLinear, CustomDiagonalLinear, Gate
|
encoder.py
CHANGED
|
@@ -1,7 +1,19 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
from .FDDT import FDDT
|
| 7 |
from .config import DiCoWConfig
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
|
| 4 |
+
try:
|
| 5 |
+
from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES
|
| 6 |
+
except ImportError:
|
| 7 |
+
from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WhisperAttention, WhisperSdpaAttention
|
| 8 |
+
try:
|
| 9 |
+
from transformers.models.whisper.modeling_whisper import WhisperFlashAttention2
|
| 10 |
+
except ImportError:
|
| 11 |
+
WhisperFlashAttention2 = WhisperAttention
|
| 12 |
+
WHISPER_ATTENTION_CLASSES = {
|
| 13 |
+
"eager": WhisperAttention,
|
| 14 |
+
"sdpa": WhisperSdpaAttention,
|
| 15 |
+
"flash_attention_2": WhisperFlashAttention2,
|
| 16 |
+
}
|
| 17 |
|
| 18 |
from .FDDT import FDDT
|
| 19 |
from .config import DiCoWConfig
|
generation.py
CHANGED
|
@@ -55,6 +55,8 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
|
|
| 55 |
model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
|
| 56 |
inputs_tensor, model_kwargs, model_input_name, generation_config
|
| 57 |
)
|
|
|
|
|
|
|
| 58 |
self.encoder_logits = model_kwargs["encoder_outputs"].logits
|
| 59 |
|
| 60 |
return model_kwargs
|
|
@@ -1436,15 +1438,14 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
|
|
| 1436 |
gen_config_copy = copy.deepcopy(generation_config)
|
| 1437 |
gen_config_copy.forced_decoder_ids = None
|
| 1438 |
processors = super()._get_logits_processor(
|
| 1439 |
-
gen_config_copy,
|
| 1440 |
-
input_ids_seq_length,
|
| 1441 |
-
encoder_input_ids,
|
| 1442 |
-
prefix_allowed_tokens_fn,
|
| 1443 |
-
logits_processor,
|
| 1444 |
-
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
negative_prompt_attention_mask,
|
| 1448 |
)
|
| 1449 |
if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
|
| 1450 |
enc_logits = self.encoder_logits
|
|
|
|
| 55 |
model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
|
| 56 |
inputs_tensor, model_kwargs, model_input_name, generation_config
|
| 57 |
)
|
| 58 |
+
# Ensure output_hidden_states is in model_kwargs
|
| 59 |
+
model_kwargs["output_hidden_states"] = True
|
| 60 |
self.encoder_logits = model_kwargs["encoder_outputs"].logits
|
| 61 |
|
| 62 |
return model_kwargs
|
|
|
|
| 1438 |
gen_config_copy = copy.deepcopy(generation_config)
|
| 1439 |
gen_config_copy.forced_decoder_ids = None
|
| 1440 |
processors = super()._get_logits_processor(
|
| 1441 |
+
generation_config=gen_config_copy,
|
| 1442 |
+
input_ids_seq_length=input_ids_seq_length,
|
| 1443 |
+
encoder_input_ids=encoder_input_ids,
|
| 1444 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 1445 |
+
logits_processor=logits_processor,
|
| 1446 |
+
model_kwargs=model_kwargs,
|
| 1447 |
+
negative_prompt_ids=negative_prompt_ids,
|
| 1448 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
|
|
|
| 1449 |
)
|
| 1450 |
if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
|
| 1451 |
enc_logits = self.encoder_logits
|
modeling_dicow.py
CHANGED
|
@@ -101,7 +101,7 @@ class DiCoW(WhisperModel):
|
|
| 101 |
decoder_outputs = self.decoder(
|
| 102 |
input_ids=decoder_input_ids,
|
| 103 |
attention_mask=decoder_attention_mask,
|
| 104 |
-
encoder_hidden_states=encoder_outputs.hidden_states[-1],
|
| 105 |
head_mask=decoder_head_mask,
|
| 106 |
cross_attn_head_mask=cross_attn_head_mask,
|
| 107 |
past_key_values=past_key_values,
|
|
@@ -122,7 +122,7 @@ class DiCoW(WhisperModel):
|
|
| 122 |
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 123 |
decoder_attentions=decoder_outputs.attentions,
|
| 124 |
cross_attentions=decoder_outputs.cross_attentions,
|
| 125 |
-
encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
|
| 126 |
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 127 |
encoder_attentions=encoder_outputs.attentions,
|
| 128 |
encoder_logits=encoder_outputs.logits,
|
|
|
|
| 101 |
decoder_outputs = self.decoder(
|
| 102 |
input_ids=decoder_input_ids,
|
| 103 |
attention_mask=decoder_attention_mask,
|
| 104 |
+
encoder_hidden_states=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
|
| 105 |
head_mask=decoder_head_mask,
|
| 106 |
cross_attn_head_mask=cross_attn_head_mask,
|
| 107 |
past_key_values=past_key_values,
|
|
|
|
| 122 |
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 123 |
decoder_attentions=decoder_outputs.attentions,
|
| 124 |
cross_attentions=decoder_outputs.cross_attentions,
|
| 125 |
+
encoder_last_hidden_state=encoder_outputs.hidden_states[-1] if encoder_outputs.hidden_states is not None else encoder_outputs.last_hidden_state,
|
| 126 |
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 127 |
encoder_attentions=encoder_outputs.attentions,
|
| 128 |
encoder_logits=encoder_outputs.logits,
|