Lakoc commited on
Commit
96b9702
·
verified ·
1 Parent(s): 1346660

Upload DiCoWForConditionalGeneration

Browse files
Files changed (12) hide show
  1. FDDT.py +63 -0
  2. README.md +199 -0
  3. config.json +75 -0
  4. config.py +63 -0
  5. decoding.py +349 -0
  6. encoder.py +246 -0
  7. generation.py +1147 -0
  8. generation_config.json +12 -0
  9. layers.py +223 -0
  10. model.safetensors +3 -0
  11. modeling_dicow.py +357 -0
  12. utils.py +14 -0
FDDT.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .layers import CustomDiagonalLinear, CustomLinear
4
+
5
+
6
class FDDT(nn.Module):
    """Frame-wise Diarization-Dependent Transformation (FDDT).

    Applies a per-frame transformation to encoder hidden states conditioned
    on an STNO diarization mask (channel order implied by the indexing in
    ``forward``: 0=silence, 1=target, 2=non-target, 3=overlap).  Each enabled
    STNO class owns one transformation, which is either:

      * a plain additive bias vector (``bias_only=True``),
      * a diagonal linear layer (``is_diagonal=True``), or
      * a full linear layer.

    Args:
        d_model: Hidden-state dimensionality.
        non_target_rate: Initial eye/diagonal value for the non-target and
            silence transforms (suppressive initialization); target and
            overlap start from identity (1.0).
        fddt_init: Initialization scheme forwarded to the custom linear
            layers (e.g. 'suppressive', 'random', 'non-disturbing' —
            interpreted by CustomLinear/CustomDiagonalLinear).
        is_diagonal: Use diagonal linear layers instead of full ones.
        bias_only: Use plain bias vectors instead of linear layers.
        use_silence, use_target, use_overlap, use_non_target: Enable the
            transformation for the corresponding STNO class.
    """

    def __init__(self, d_model, non_target_rate=0.01, fddt_init=None, is_diagonal=False,
                 bias_only=False, use_silence=True, use_target=True, use_overlap=True, use_non_target=True):
        super().__init__()

        def build(init_eye_val):
            # Single factory for all four STNO transforms; they differ only
            # in the initial eye value (1.0 preserves input, non_target_rate
            # suppresses it).  Replaces four copy-pasted nested ternaries.
            if bias_only:
                return nn.Parameter(torch.zeros(d_model))
            if is_diagonal:
                return CustomDiagonalLinear(d_model, bias=True, fddt_init=fddt_init,
                                            init_eye_val=init_eye_val)
            return CustomLinear(d_model, d_model, bias=True, fddt_init=fddt_init,
                                init_eye_val=init_eye_val)

        if use_target:
            self.target_linear = build(1.0)
        if use_non_target:
            self.non_target_linear = build(non_target_rate)
        if use_overlap:
            self.overlap_linear = build(1.0)
        if use_silence:
            self.silence_linear = build(non_target_rate)

        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.bias_only = bias_only

    def forward(self, hidden_states, stno_mask):
        """Transform ``hidden_states`` according to the STNO mask.

        Args:
            hidden_states: Tensor of shape (batch, time, d_model) —
                assumed from the broadcasting below; TODO confirm at caller.
            stno_mask: Tensor of shape (batch, 4, time) holding per-frame
                weights for the four STNO classes.

        Returns:
            Transformed hidden states, same shape as the input.  In
            ``bias_only`` mode the input tensor is modified in place
            (original behavior, preserved).
        """
        # (batch, 4, time) -> (batch, 4, time, 1) so it broadcasts over d_model.
        stno_mask = stno_mask.to(hidden_states.device)[..., None]
        if self.bias_only:
            # Each active class adds its mask-weighted bias vector.
            if self.use_silence:
                hidden_states += stno_mask[:, 0] * self.silence_linear
            if self.use_target:
                hidden_states += stno_mask[:, 1] * self.target_linear
            if self.use_non_target:
                hidden_states += stno_mask[:, 2] * self.non_target_linear
            if self.use_overlap:
                hidden_states += stno_mask[:, 3] * self.overlap_linear
        else:
            # Mix the per-class projections of the ORIGINAL hidden states,
            # weighted by the per-frame mask.  A disabled class passes the
            # original input through unchanged (identity fallback).
            orig = hidden_states
            branches = (
                (self.silence_linear if self.use_silence else None, 0),
                (self.target_linear if self.use_target else None, 1),
                (self.non_target_linear if self.use_non_target else None, 2),
                (self.overlap_linear if self.use_overlap else None, 3),
            )
            hidden_states = sum(
                (layer(orig) if layer is not None else orig) * stno_mask[:, idx]
                for layer, idx in branches
            )
        return hidden_states
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "additional_layer": false,
5
+ "additional_self_attention_layer": true,
6
+ "apply_fddt_to_n_layers": -1,
7
+ "apply_spec_augment": false,
8
+ "architectures": [
9
+ "DiCoWForConditionalGeneration"
10
+ ],
11
+ "attention_dropout": 0.0,
12
+ "auto_map": {
13
+ "AutoConfig": "config.DiCoWConfig",
14
+ "AutoModelForSpeechSeq2Seq": "modeling_dicow.DiCoWForConditionalGeneration"
15
+ },
16
+ "begin_suppress_tokens": [
17
+ 220,
18
+ 50256
19
+ ],
20
+ "blank_token_id": null,
21
+ "bos_token_id": 50257,
22
+ "classifier_proj_size": 256,
23
+ "ctc_loss_reduction": "mean",
24
+ "ctc_weight": 0.3,
25
+ "ctc_zero_infinity": false,
26
+ "d_model": 1280,
27
+ "decoder_attention_heads": 20,
28
+ "decoder_ffn_dim": 5120,
29
+ "decoder_layerdrop": 0.0,
30
+ "decoder_layers": 4,
31
+ "decoder_start_token_id": 50258,
32
+ "dropout": 0.0,
33
+ "encoder_attention_heads": 20,
34
+ "encoder_ffn_dim": 5120,
35
+ "encoder_layerdrop": 0.0,
36
+ "encoder_layers": 32,
37
+ "eos_token_id": 50257,
38
+ "fddt_bias_only": false,
39
+ "fddt_init": "suppressive",
40
+ "fddt_is_diagonal": true,
41
+ "fddt_use_non_target": true,
42
+ "fddt_use_overlap": true,
43
+ "fddt_use_silence": true,
44
+ "fddt_use_target": true,
45
+ "final_dropout": 0.0,
46
+ "forced_decoder_ids": null,
47
+ "init_std": 0.02,
48
+ "is_encoder_decoder": true,
49
+ "mask_feature_length": 10,
50
+ "mask_feature_min_masks": 0,
51
+ "mask_feature_prob": 0.0,
52
+ "mask_time_length": 10,
53
+ "mask_time_min_masks": 2,
54
+ "mask_time_prob": 0.05,
55
+ "max_source_positions": 1500,
56
+ "max_target_positions": 448,
57
+ "median_filter_width": 7,
58
+ "model_type": "DiCoW",
59
+ "non_target_fddt_value": 0.5,
60
+ "num_hidden_layers": 32,
61
+ "num_mel_bins": 128,
62
+ "pad_token_id": 50257,
63
+ "pre_ctc_sub_sample": true,
64
+ "remove_timestamps_from_ctc": true,
65
+ "scale_embedding": false,
66
+ "scb_layers": 8,
67
+ "torch_dtype": "float32",
68
+ "transformers_version": "4.55.0",
69
+ "use_cache": true,
70
+ "use_enrollments": true,
71
+ "use_fddt": true,
72
+ "use_pre_pos_fddt": true,
73
+ "use_weighted_layer_sum": false,
74
+ "vocab_size": 51866
75
+ }
config.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from transformers import WhisperConfig
4
+
5
+
6
class DiCoWConfig(WhisperConfig):
    """Configuration for the DiCoW (diarization-conditioned Whisper) model.

    Extends ``WhisperConfig`` with options for an auxiliary CTC head and for
    the FDDT (frame-wise diarization-dependent transformation) layers used
    by the modified Whisper encoder.
    """
    model_type = "DiCoW"

    def __init__(
            self,
            ctc_loss_reduction: str = "mean",
            final_dropout: float = 0.0,
            ctc_zero_infinity: bool = False,
            ctc_weight: float = 0.0,
            blank_token_id: Optional[int] = None,
            additional_layer: bool = False,
            additional_self_attention_layer: bool = False,
            pre_ctc_sub_sample: bool = False,
            use_fddt: bool = True,
            fddt_is_diagonal: bool = True,
            fddt_bias_only: bool = False,
            fddt_use_silence: bool = True,
            fddt_use_target: bool = True,
            fddt_use_overlap: bool = True,
            fddt_use_non_target: bool = True,
            remove_timestamps_from_ctc: bool = False,
            apply_fddt_to_n_layers: int = -1,
            fddt_init: str = 'suppressive',  # random, non-disturbing
            non_target_fddt_value: float = 0.0,
            use_enrollments: bool = False,
            scb_layers: Optional[int] = None,
            use_pre_pos_fddt: bool = False,
            **kwargs,
    ):
        super().__init__(**kwargs)

        # --- auxiliary CTC head / loss options ---
        self.ctc_weight = ctc_weight
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity
        self.final_dropout = final_dropout
        self.blank_token_id = blank_token_id
        self.remove_timestamps_from_ctc = remove_timestamps_from_ctc

        # --- extra encoder layers feeding the CTC head ---
        self.additional_layer = additional_layer
        self.additional_self_attention_layer = additional_self_attention_layer
        self.pre_ctc_sub_sample = pre_ctc_sub_sample

        # --- FDDT (diarization conditioning) options ---
        self.use_fddt = use_fddt
        self.fddt_is_diagonal = fddt_is_diagonal
        self.fddt_bias_only = fddt_bias_only
        self.fddt_use_silence = fddt_use_silence
        self.fddt_use_target = fddt_use_target
        self.fddt_use_overlap = fddt_use_overlap
        self.fddt_use_non_target = fddt_use_non_target
        self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
        self.fddt_init = fddt_init
        self.non_target_fddt_value = non_target_fddt_value
        self.use_pre_pos_fddt = use_pre_pos_fddt

        # --- enrollment / speaker-communication options ---
        self.use_enrollments = use_enrollments
        self.scb_layers = scb_layers


# Index of the hidden-states entry in the model output tuple.
_HIDDEN_STATES_START_POSITION = 2
decoding.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
3
+ import pandas as pd
4
+ import torch
5
+ from transformers import LogitsProcessor, PreTrainedTokenizer
6
+
7
+
8
class CTCPrefixScore(object):
    """Compute CTC label sequence scores.

    Based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously.
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.

    Throughout, ``r[..., 0]`` holds log r_t^n(h) (prefix ends in non-blank)
    and ``r[..., 1]`` holds log r_t^b(h) (prefix ends in blank).
    """

    def __init__(self, x, blank, eos):
        # x: CTC log-probability lattice; the indexing below implies shape
        # (batch, input_length, vocab) — assumed, confirm at caller.
        self.logzero = -1e10  # practical -inf that stays finite under addition
        self.blank = blank    # CTC blank token id
        self.eos = eos        # end-of-sequence token id
        self.input_length = x.shape[1]
        self.batch_size = x.shape[0]
        self.x = x
        self.device = x.device

        # Preallocate `r` (forward probs, [batch, T, 2, labels]) and `xs`
        # (per-label log-probs gathered from x) at maximum label capacity;
        # __call__ shrinks/refills them via _resize_tensors as needed.
        self.max_num_labels = x.shape[2]  # upper bound, trimmed dynamically
        self.r = torch.full((self.batch_size, self.input_length, 2, self.max_num_labels), self.logzero,
                            device=self.device)
        self.xs = torch.full((self.batch_size, self.input_length, self.max_num_labels), self.logzero,
                             device=self.device)

    def initial_state(self):
        """Obtain an initial CTC state for the empty prefix.

        Returns:
            r: (batch, T, 2) forward log-probs; the blank channel is the
               running cumulative sum of blank log-probs over time.
            s: (batch, 1) zero log-score for the empty prefix.
        """
        r = torch.full((self.batch_size, self.input_length, 2), self.logzero, device=self.device)
        r[..., 1] = torch.cumsum(self.x[..., self.blank], dim=1)
        s = torch.zeros((self.batch_size, 1), device=self.device)

        return r, s

    def _resize_tensors(self, number_of_current_samples, num_labels):
        """Trim the preallocated `r`/`xs` buffers to the current batch size
        and label count, and reset their contents to logzero."""
        # Shrinking by slicing keeps views into the original allocation, so
        # no new memory is allocated between calls.
        if self.r.shape[0] != number_of_current_samples:
            self.r = self.r[:number_of_current_samples, ...]
            self.xs = self.xs[:number_of_current_samples, ...]

        if self.r.shape[3] != num_labels:
            self.r = self.r[:, :, :, :num_labels].fill_(self.logzero)
            self.xs = self.xs[:, :, :num_labels].fill_(self.logzero)
        else:
            self.r.fill_(self.logzero)
            self.xs.fill_(self.logzero)

    def _initialize_r(self, decoded_len):
        """Seed the t=0 non-blank forward probability for hypotheses whose
        prefix is still empty (decoded_len == 0)."""
        mask = (decoded_len == 0)
        self.r[mask, 0, 0, :] = self.xs[mask, 0]

    def _compute_log_phi(self, r_sum, cs, last, decoded_len, r_prev):
        """Compute log(phi), the per-frame transition term.

        phi is r_sum = log(r^n(g) + r^b(g)) for every candidate label,
        except when the candidate equals the last emitted label (CTC label
        repetition rule): then only the blank-terminated path r^b counts.
        """
        # Expand r_sum for num_labels and initialize log_phi
        log_phi = r_sum[..., None].expand(-1, -1, cs.shape[1])

        # Mask where `decoded_len > 0` (a last label exists) and `c == last[i]`.
        non_zero_mask = (decoded_len > 0)
        label_match_mask = (cs == last.unsqueeze(1))

        # Update log_phi where both `decoded_len > 0` and `c == last[i]`
        log_phi = torch.where((non_zero_mask.unsqueeze(1) & label_match_mask)[:, None, :], r_prev[..., 1:2], log_phi)
        return log_phi

    def _compute_log_psi(self, decoded_len, log_phi, x_current):
        """This function computes forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        and log prefix probabilities log(psi) for all labels in the batch.

        :param decoded_len: tensor of shape (batch_size,) containing the length of the decoded sequence
        :param log_phi: tensor of shape (batch_size, input_length, num_labels) containing the forward probabilities
        :param x_current: tensor of shape (batch_size, input_length, num_labels) containing the input frame

        :return log_psi: tensor of shape (batch_size, num_labels) containing the log prefix probabilities
        """
        B, T, V = log_phi.shape
        start = torch.clamp(decoded_len, min=1)  # Ensure start is at least 1 to avoid out-of-bounds

        # Initialize log_psi with the start position of r[:, start - 1, 0, :]
        log_psi = self.r[torch.arange(B), start - 1, 0, :]

        # Mask for handling sequence lengths based on decoded_len
        mask_t = torch.arange(1, T, device=decoded_len.device).expand(B, T - 1) >= decoded_len.unsqueeze(1)

        # Accumulate log_psi only up to the last valid time step for each sequence
        log_psi = torch.logaddexp(log_psi, torch.logsumexp(
            torch.where(mask_t.unsqueeze(-1), log_phi[:, :-1] + self.xs[:, 1:], self.logzero), dim=1))

        # NOTE(review): duplicate of the clamp above — redundant, kept as-is.
        start = torch.clamp(decoded_len, 1)

        # Forward recursion over time (Algorithm 2 inner loop), vectorized
        # over batch and candidate labels.
        for t in range(start.min(), self.input_length):
            should_decode = decoded_len <= t
            self.r[:, t, 0] = torch.logaddexp(self.r[:, t - 1, 0],
                                              log_phi[:, t - 1]) + self.xs[:, t]
            self.r[:, t, 1] = (
                torch.logaddexp(self.r[:, t - 1, 0], self.r[:, t - 1, 1]) + x_current[:, t, self.blank][:, None]
            )
            # NOTE(review): `~should_decode.any()` negates the scalar any(),
            # i.e. "no sample should decode yet"; the where() suggests the
            # intent was to mask the samples that should NOT decode, which
            # would be `if not should_decode.all():` — confirm against the
            # reference ESPnet implementation before changing.
            if ~should_decode.any():
                self.r[:, t] = torch.where(should_decode.unsqueeze(-1).unsqueeze(-1), self.r[:, t], self.logzero)

        return log_psi

    def _update_log_psi_with_eos(self, log_psi, cs, r_sum):
        """Score EOS as the probability of the prefix itself ending, and
        forbid the blank label as an explicit next-label candidate."""
        # Update log_psi for eos positions: P(...eos|X) = r_sum at final frame
        eos_mask = (cs == self.eos)
        log_psi[eos_mask] = r_sum[:, -1].unsqueeze(1).expand_as(log_psi)[eos_mask]

        # Exclude blank probabilities if eos is not the blank
        if self.eos != self.blank:
            blank_mask = (cs == self.blank)
            log_psi[blank_mask] = self.logzero
        return log_psi

    def __call__(self, y, cs, decoded_len, samples_to_be_decoded, r_prev):
        """Compute CTC prefix scores for next labels.

        :param y : prefix label sequence, (batch, prefix_len)
        :param cs : array of next-label candidates, (batch, num_labels)
        :param decoded_len: per-sample count of already-decoded labels
        :param samples_to_be_decoded: boolean index selecting the active
            rows of the full lattice ``self.x``
        :param r_prev: previous CTC state, (batch, T, 2)
        :return ctc_scores, ctc_states
        """
        # Dynamically resize r and xs to match num_labels if necessary.
        num_labels = cs.shape[1]
        number_of_current_samples = cs.shape[0]
        self._resize_tensors(number_of_current_samples, num_labels)

        # Gather per-candidate log-probs for every frame.
        # NOTE(review): this rebinds self.xs (the fill in _resize_tensors is
        # then discarded) — presumably intentional buffer reuse; confirm.
        x_current = self.x[samples_to_be_decoded]
        self.xs = torch.gather(x_current, 2, cs.unsqueeze(1).expand(-1, self.input_length, -1))

        # Initialize r for the first frame (empty prefixes only).
        self._initialize_r(decoded_len)

        # Forward probabilities of the current prefix g.
        r_sum = torch.logaddexp(r_prev[:, :, 0], r_prev[:, :, 1])  # log(r_t^n(g) + r_t^b(g))
        last = y[:, -1]

        # Precompute log_phi (transition term with repetition handling).
        log_phi = self._compute_log_phi(r_sum, cs, last, decoded_len, r_prev)

        # Compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi).
        log_psi = self._compute_log_psi(decoded_len, log_phi, x_current)

        # Fold in P(...eos|X), the probability the prefix ends here.
        log_psi = self._update_log_psi_with_eos(log_psi, cs, r_sum)

        # Return the log prefix probability and CTC states.
        return log_psi, self.r
164
+
165
+
166
class CTCRescorerLogitsProcessor(LogitsProcessor):
    """Logits processor that mixes attention-decoder scores with CTC
    prefix scores (joint CTC/attention decoding).

    At each generation step the top candidate tokens from the decoder are
    rescored with a vectorized ``CTCPrefixScore`` and combined as
    ``(1 - ctc_weight) * scores + ctc_weight * ctc_scores``.

    FIX: the ``debug`` constructor argument was previously ignored
    (``self.debug`` was hard-coded to ``False``); it is now honored.
    Note: ``encoder_output_lens`` and ``ctc_margin`` are accepted for
    interface compatibility but unused in this implementation.
    """

    def __init__(
            self,
            encoder_logits: torch.FloatTensor,
            encoder_output_lens: torch.Tensor,
            blank_token_id: int,
            pad_token_id: int,
            eos_token_id: int,
            bos_token_id: int,
            tokenizer: PreTrainedTokenizer,
            ctc_margin: int,
            ctc_weight: float,
            num_beams: int,
            debug: bool = False,
            ctc_tokens_to_score: int = 500
    ):
        super().__init__()
        # Pairs (lowercase_id, uppercase_id) whose logits are tied so that
        # casing does not affect the CTC score.
        same_logits = torch.tensor(list((tokenizer.upper_cased_tokens.items())))

        logits = torch.nn.functional.log_softmax(encoder_logits, dim=-1)
        logits[..., same_logits[:, 1]] = logits[..., same_logits[:, 0]]

        self.logits = logits

        self.ctc_prefix_scorer = CTCPrefixScore(
            self.logits,
            blank_token_id,
            eos_token_id,
        )
        self.batch_size = logits.shape[0]
        self.input_length = logits.shape[1]
        self.num_tokens = logits.shape[2]
        self.device = logits.device
        self.ctc_weight = ctc_weight
        self.num_beams = num_beams
        self.ctc_state_prev, self.ctc_score_prev = self.ctc_prefix_scorer.initial_state()
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.blank_token_id = blank_token_id
        # BUGFIX: was `self.debug = False`, silently discarding the argument.
        self.debug = debug
        self.first_timestamp_token_id = tokenizer.get_vocab()["<|0.00|>"]
        # Scratch buffers reused across decoding steps.
        self.tmp_ctc_scores = torch.empty((self.batch_size, self.num_tokens - 1), device=self.device)
        self.tmp_ctc_states = torch.empty((self.batch_size, self.num_tokens - 1, self.input_length, 2),
                                          device=self.device)
        self.ctc_tokens_to_score = ctc_tokens_to_score

    def analyze_predictions(self,
                            scores, ctc_scores, next_token_scores, input_ids, k=10):
        """Debug helper: print the top-k attention, CTC, and combined
        candidates (with scores) for every hypothesis in the batch."""
        print("\n" + "#" * 100)

        batch_size = input_ids.shape[0]

        best_att_ids = scores.topk(k=k, dim=1)
        ctc_scores[:, self.first_timestamp_token_id:] = self.ctc_prefix_scorer.logzero
        best_ctc_ids = ctc_scores.topk(k=k, dim=1)
        best_ids = next_token_scores.topk(k=k, dim=1)

        decoded_prefixes = self.tokenizer.batch_decode(
            input_ids, decode_with_timestamps=True, skip_special_tokens=False
        )

        def prepare_and_decode(best_ids_tensor):
            # Interleave each candidate with '#' so single tokens stay
            # visually separated after batch_decode.
            new_tensor = torch.zeros((batch_size, k * 2), dtype=torch.long)
            new_tensor[:, 0::2] = best_ids_tensor.indices
            new_tensor[:, 1::2] = self.tokenizer.vocab['#']

            # Flatten to (batch_size * k, 2)
            flat_tensor = new_tensor.view(-1, 2)
            decoded = self.tokenizer.batch_decode(
                flat_tensor, decode_with_timestamps=True, skip_special_tokens=False
            )
            # Reshape back to (batch_size, k)
            decoded = [(decoded[i * k:(i + 1) * k]) for i in range(batch_size)]
            return decoded

        decoded_att = prepare_and_decode(best_att_ids)
        decoded_ctc = prepare_and_decode(best_ctc_ids)
        decoded_next = prepare_and_decode(best_ids)

        for idx in range(batch_size):
            print("-" * 80)
            print(f"HYPOTHESIS {idx}")
            print("\nPREFIX:")
            print(decoded_prefixes[idx])

            def print_with_pandas(tokens, scores, title):
                df = pd.DataFrame([tokens, [f"{s.item():.2f}" for s in scores]])
                df.index = [f"{title}", "Score"]
                print(f"\n{title}:")
                print(df.to_string(index=True, header=False))

            print_with_pandas(decoded_att[idx], best_att_ids.values[idx], "ATT_TOKENS")
            print_with_pandas(decoded_ctc[idx], best_ctc_ids.values[idx], "CTC_TOKENS")
            print_with_pandas(decoded_next[idx], best_ids.values[idx], "NEXT_TOKENS")

            print(f"\nCTC_EOS: {ctc_scores[idx, self.tokenizer.eos_token_id].item():.2f}")
            print()

        print("#" * 100)

    def update_state(self, best_ids, beam_idx):
        """Advance the per-beam CTC state after beam selection.

        Timestamp tokens (>= first_timestamp_token_id) do not consume CTC
        frames, so for them the previous state/score is carried over.
        """
        mask = best_ids < self.first_timestamp_token_id
        self.ctc_state_prev = torch.where(mask.unsqueeze(-1).unsqueeze(-1),
                                          self.tmp_ctc_states[beam_idx, best_ids],
                                          self.ctc_state_prev[beam_idx])
        self.ctc_score_prev = torch.where(mask.unsqueeze(-1),
                                          self.tmp_ctc_scores[beam_idx, best_ids].unsqueeze(-1),
                                          self.ctc_score_prev[beam_idx])

    def __call__(self, input_ids_orig: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Rescore the decoder's next-token ``scores`` with CTC prefix scores.

        Returns the combined log-scores, same shape as ``scores``.
        """
        input_ids = input_ids_orig.clone()

        # Remove prefix from CTC scoring
        if (input_ids[:, 0] != self.bos_token_id).any():
            input_ids = torch.stack(
                [row[(row == self.bos_token_id).nonzero(as_tuple=True)[0].item():] for row in input_ids])

        # Remove task/lang/timestamp tokens from input_ids
        input_prefix_len = len(self.tokenizer.prefix_tokens)
        if input_prefix_len > 1:
            input_ids = input_ids[:, input_prefix_len - 1:]

        # Setup the first token to be the blank token(sos)
        input_ids[:, 0] = self.blank_token_id

        # If the last token in input_ids is a timestamp, replicate the last
        # non-timestamp token (which could even be the first token).
        decoded_len = torch.logical_and(input_ids <= self.first_timestamp_token_id,
                                        input_ids != self.blank_token_id).sum(dim=1)
        mask = torch.logical_and(input_ids[:, -1] >= self.first_timestamp_token_id,
                                 input_ids[:, -1] != self.blank_token_id)
        last_non_timestamp_token = torch.gather(input_ids, 1,
                                                torch.logical_or(input_ids < self.first_timestamp_token_id,
                                                                 input_ids == self.blank_token_id).sum(dim=1,
                                                                                                       keepdim=True) - 1)
        input_ids[mask, -1] = last_non_timestamp_token[mask, 0]

        # If there is no eos token in the last position, we need to continue decoding
        to_be_decoded = input_ids[:, -1] != self.eos_token_id
        self.tmp_ctc_scores[:] = self.ctc_prefix_scorer.logzero

        input_ids_local = input_ids[to_be_decoded]
        # Only rescore the decoder's top non-timestamp candidates (cost cap).
        ids_to_score = torch.topk(scores[:, :self.first_timestamp_token_id], k=self.ctc_tokens_to_score).indices

        # Always score EOS; if not present, put it in the last candidate slot.
        is_eos_present = (ids_to_score == self.eos_token_id).any(dim=1)
        ids_to_score[~is_eos_present, self.ctc_tokens_to_score - 1] = self.eos_token_id

        decoded_len_local = decoded_len[to_be_decoded]

        ctc_scores_local, ctc_states_local = self.ctc_prefix_scorer(input_ids_local, ids_to_score[to_be_decoded],
                                                                    decoded_len_local, to_be_decoded,
                                                                    self.ctc_state_prev[to_be_decoded])

        # The CTC scorer may run on a subset of samples; scatter the results
        # back into the full-batch scratch buffers.
        self.tmp_ctc_scores[to_be_decoded] = (self.tmp_ctc_scores[to_be_decoded]
                                              .scatter(1, ids_to_score[to_be_decoded], ctc_scores_local))
        self.tmp_ctc_states[to_be_decoded] = (self.tmp_ctc_states[to_be_decoded].permute(0, 2, 3, 1)
                                              .scatter(3, ids_to_score[to_be_decoded].unsqueeze(1).unsqueeze(1)
                                                       .repeat(1, *ctc_states_local.shape[1:3], 1), ctc_states_local)
                                              .permute(0, 3, 1, 2))

        # Set the CTC score for the timestamp tokens to the maximum to prefer them over the rest
        self.tmp_ctc_scores[:, self.first_timestamp_token_id:] = self.tmp_ctc_scores.max(dim=1).values[:, None]
        ctc_scores = self.tmp_ctc_scores - self.ctc_score_prev

        # Log-linear interpolation of decoder and CTC scores.
        next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores

        if self.debug:
            self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids_orig)

        return next_token_scores
339
+
340
+
341
class LogSoftmaxProcessor(LogitsProcessor):
    """Logits processor that renormalizes raw decoder scores into
    log-probabilities.

    Converting to the log domain here gives downstream processors (e.g. a
    CTC rescorer) a common scale to combine scores on.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` log-softmax-normalized over the last (vocab) axis."""
        return torch.nn.functional.log_softmax(scores, dim=-1)
encoder.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
4
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WhisperAttention
5
+ from .FDDT import FDDT
6
+ from .config import DiCoWConfig
7
+ from .layers import CustomLinear, CustomDiagonalLinear, Gate, SpeakerCommunicationBlock
8
+
9
+
10
class DiCoWEncoder(WhisperEncoder):
    """Whisper encoder extended for DiCoW (Diarization-Conditioned Whisper).

    Additions over the stock ``WhisperEncoder``:
      * per-layer FDDT modules that condition hidden states on the STNO
        (silence/target/non-target/overlap) diarization mask,
      * optional speaker-communication blocks (``ca_enrolls``) that mix in
        enrollment-utterance representations for the first ``scb_layers`` layers,
      * an optional CTC head (``lm_head``) with optional extra layer /
        self-attention layer and 4x pre-CTC time subsampling.
    """

    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.ctc_weight = config.ctc_weight
        # Optional extra encoder layer feeding the CTC head (only when CTC is active).
        if config.additional_layer and self.ctc_weight > 0.0:
            self.additional_layer = WhisperEncoderLayer(config)
        # Alternative: a single extra self-attention layer instead of a full encoder layer.
        if config.additional_self_attention_layer and self.ctc_weight > 0.0:
            self.additional_self_attention_layer = WhisperAttention(
                embed_dim=config.d_model,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
            )
        # Optional 4x time subsampling (two stride-2 convs) applied before the CTC head.
        if config.pre_ctc_sub_sample and self.ctc_weight > 0.0:
            self.subsample_conv1 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
            self.subsample_conv2 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
        if self.ctc_weight > 0.0:
            # +1 output unit for the CTC blank symbol (kept at the last index).
            self.lm_head = nn.Linear(config.d_model, config.vocab_size + 1, bias=False)
        self.final_dropout = nn.Dropout(config.final_dropout)
        if config.use_fddt:
            # One FDDT per encoder layer, or only the first N layers when configured.
            num_fddts = self.config.apply_fddt_to_n_layers if self.config.apply_fddt_to_n_layers != -1 else len(
                self.layers)
            self.fddts = nn.ModuleList([
                FDDT(
                    d_model=config.d_model,
                    non_target_rate=1.0,
                    fddt_init=config.fddt_init,
                    is_diagonal=config.fddt_is_diagonal,
                    bias_only=config.fddt_bias_only,
                    use_silence=config.fddt_use_silence,
                    use_target=config.fddt_use_target,
                    use_overlap=config.fddt_use_overlap,
                    use_non_target=config.fddt_use_non_target,
                )
                for _ in range(num_fddts)
            ])
            if config.use_pre_pos_fddt:
                # Extra FDDT applied to the conv features before positional embeddings.
                self.initial_fddt = FDDT(
                    d_model=config.d_model,
                    non_target_rate=config.non_target_fddt_value,
                    fddt_init=config.fddt_init,
                    is_diagonal=config.fddt_is_diagonal,
                    bias_only=config.fddt_bias_only,
                    use_silence=config.fddt_use_silence,
                    use_target=config.fddt_use_target,
                    use_overlap=config.fddt_use_overlap,
                    use_non_target=config.fddt_use_non_target,
                )
        if config.use_enrollments and config.scb_layers is not None:
            self.ca_enrolls = nn.ModuleList([SpeakerCommunicationBlock(config) for _ in range(config.scb_layers)])
        # 30 seconds of 50 Hz timestamps, -1 to get to 0.0, and -6 for the number of tasks.
        self.first_task_token = self.config.vocab_size - 30 * 50 - 1 - 6
        self.post_init()

    def _init_weights(self, module):
        """Initialize weights; custom DiCoW layers use their own reset_parameters()."""
        super()._init_weights(module)
        if isinstance(module, CustomLinear) or isinstance(module, CustomDiagonalLinear) or isinstance(module, Gate):
            module.reset_parameters()

    def get_output_embeddings(self):
        # The encoder exposes no output token embeddings.
        return None

    def possibly_update_last_hidden_states(self, hidden_states):
        """Apply the optional extra layer (or self-attention) and pre-CTC subsampling.

        Whichever of ``additional_layer`` / ``additional_self_attention_layer`` was
        created in ``__init__`` is applied (the full layer takes precedence).
        """
        if hasattr(self, "additional_layer"):
            hidden_states, = self.additional_layer(
                hidden_states,
                attention_mask=None,
                output_attentions=False,
                layer_head_mask=None,
            )
        elif hasattr(self, "additional_self_attention_layer"):
            hidden_states, _ = self.additional_self_attention_layer(
                hidden_states,
                attention_mask=None,
                output_attentions=False,
                layer_head_mask=None,
            )

        hidden_states = self.final_dropout(hidden_states)
        if hasattr(self, "subsample_conv2"):
            # Conv1d expects (batch, channels, time); transpose around the two convs.
            hidden_states = self.subsample_conv2(self.subsample_conv1(hidden_states.transpose(1, 2))).transpose(1, 2)
        return hidden_states

    def get_loss(self, logits, labels):
        """Compute the CTC loss of encoder `logits` against padded `labels`.

        Labels are padded with -100; the blank symbol is the last logit index.

        Raises:
            ValueError: if any label id is >= vocab_size.
        """
        if labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
        if self.config.remove_timestamps_from_ctc:
            # Strip timestamp/task tokens (ids >= first_task_token) from the CTC targets.
            labels = torch.nn.utils.rnn.pad_sequence([label[label < self.first_task_token] for label in labels],
                                                     padding_value=-100).T
        input_lengths = torch.full((logits.shape[0],), fill_value=logits.shape[1],
                                   device=logits.device)

        # assuming that padded tokens are filled with -100
        # when not being attended to
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)

        # ctc_loss doesn't support fp16
        log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

        with torch.backends.cudnn.flags(enabled=True):
            ctc_loss = nn.functional.ctc_loss(
                log_probs,
                labels,
                input_lengths,
                target_lengths,
                blank=logits.shape[-1] - 1,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=True,
            )
        return ctc_loss

    def get_max_len(self):
        """Return the expected number of input mel frames (positions * conv strides)."""
        return self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]

    def forward(
            self,
            input_features,
            attention_mask=None,
            head_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            stno_mask=None,
            return_logits=False,
            enrollments=None
    ):
        """Encode `input_features`, conditioning on the STNO diarization `stno_mask`.

        When `enrollments` is given, enrollment features are interleaved with the main
        inputs (even rows = main, odd rows = enrollment) so the speaker-communication
        blocks can attend across them; enrollment rows are dropped after `scb_layers`.
        When `return_logits` is True, returns CTC logits in a ``CausalLMOutput``.
        """
        if enrollments is not None:
            # Interleave main/enrollment examples along the batch dimension.
            input_features = torch.stack((input_features, enrollments['input_features']), dim=1).flatten(0, 1)
            stno_mask = torch.stack((stno_mask, enrollments['stno_mask']), dim=1).flatten(0, 1)

        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
        if input_features.shape[-1] != expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)

        """<DiCoW CODE>"""
        # Pre-positional FDDT: condition the conv features before positions are added.
        if self.config.use_fddt and self.config.use_pre_pos_fddt:
            inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask)
        """</DiCoW CODE>"""

        all_positions = torch.arange(self.embed_positions.num_embeddings, device=inputs_embeds.device)

        hidden_states = inputs_embeds + self.embed_positions(all_positions)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (len(self.layers)), (
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
            )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                """<DiCoW CODE>"""
                # Speaker-conditioned transformation of hidden states before each layer.
                if self.config.use_fddt and idx < len(self.fddts):
                    hidden_states = self.fddts[idx](hidden_states, stno_mask)

                if self.config.use_enrollments and idx < self.config.scb_layers:
                    hidden_states = self.ca_enrolls[idx](hidden_states)
                    if idx == self.config.scb_layers - 1:
                        # enrollment representations are no longer needed;
                        # keep only the main (even-indexed) rows
                        hidden_states = hidden_states[::2]
                        stno_mask = stno_mask[::2]
                """</DiCoW CODE>"""

                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if return_logits:
            hidden_states = self.possibly_update_last_hidden_states(hidden_states)
            logits = self.lm_head(hidden_states)

            return CausalLMOutput(
                loss=None, logits=logits, hidden_states=hidden_states,
            )

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
generation.py ADDED
@@ -0,0 +1,1147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ from decimal import Decimal, ROUND_HALF_UP
4
+ from typing import Any, Callable, Dict, Optional, Tuple, Union, TYPE_CHECKING
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.checkpoint
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from transformers import PreTrainedModel
13
+ from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
14
+ from transformers.generation.logits_process import (
15
+ LogitsProcessorList,
16
+ SuppressTokensAtBeginLogitsProcessor,
17
+ SuppressTokensLogitsProcessor, )
18
+ from transformers.generation.logits_process import WhisperNoSpeechDetection
19
+ from transformers.generation.stopping_criteria import (
20
+ StoppingCriteriaList,
21
+ )
22
+ from transformers.generation.utils import GenerateNonBeamOutput, \
23
+ GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerateBeamOutput, GenerateBeamDecoderOnlyOutput, \
24
+ GenerateBeamEncoderDecoderOutput
25
+ from transformers.modeling_outputs import BaseModelOutput
26
+ from transformers.models.whisper.modeling_whisper import (
27
+ WhisperForConditionalGeneration,
28
+ )
29
+ from transformers.utils import logging
30
+ from .decoding import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
31
+ from .utils import WhisperTimeStampLogitsProcessorCustom
32
+
33
+ if TYPE_CHECKING:
34
+ from transformers.generation.streamers import BaseStreamer
35
+
36
+ logging.set_verbosity_debug()
37
+ logger = logging.get_logger("transformers")
38
+
39
+
40
+ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
41
+
42
+ def _prepare_encoder_decoder_kwargs_for_generation(
43
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
44
+ ) -> Dict[str, Any]:
45
+ # pylint: disable=no-memberva
46
+ model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
47
+ inputs_tensor, model_kwargs, model_input_name, generation_config
48
+ )
49
+
50
+ if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
51
+ self.encoder_logits = self.get_enc_logits(model_kwargs["encoder_outputs"].last_hidden_state)
52
+
53
+ return model_kwargs
54
+
55
+ def _prepare_decoder_input_ids_for_generation(
56
+ self,
57
+ batch_size: int,
58
+ model_input_name: str,
59
+ model_kwargs: Dict[str, torch.Tensor],
60
+ decoder_start_token_id: torch.Tensor,
61
+ device: torch.device = None,
62
+ ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
63
+ batch_size = model_kwargs['decoder_input_ids'].shape[0]
64
+ out = super()._prepare_decoder_input_ids_for_generation(
65
+ batch_size,
66
+ model_input_name,
67
+ model_kwargs,
68
+ decoder_start_token_id,
69
+ device,
70
+ )
71
+ return out
72
+
73
    def prepare_kwargs_for_generate(self,
                                    max_frames,
                                    cur_bsz,
                                    batch_idx_map,
                                    seek,
                                    kwargs,
                                    attention_mask):
        """Slice per-sample kwargs for the current 30s decoding window.

        Cuts the STNO diarization mask to the window starting at `seek` (padding
        short tails with "silence"), stores it on `self.stno_mask_seek`, and
        re-indexes enrollments / attention_mask / labels to the still-active
        samples given by `batch_idx_map`.
        """
        # STNO masks run at half the mel-frame rate, hence all the // 2 conversions.
        seek_vad = seek // 2
        input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
        num_segment_frames = input_stride * self.config.max_source_positions
        num_frames_vad = num_segment_frames // 2
        max_frames_vad = max_frames // 2
        # Frames remaining for each sample, capped at one full segment.
        seek_num_frames = (max_frames_vad - seek_vad).clamp(max=num_frames_vad)

        stno_masks = []
        for i in range(cur_bsz):
            # batch_idx_map maps the compacted batch index back to the original one.
            prev_i = batch_idx_map[i]
            segment_input_slice = kwargs["stno_mask"][prev_i: prev_i + 1, :,
                                  seek_vad[prev_i]: seek_vad[prev_i] + seek_num_frames[prev_i]]

            if segment_input_slice.shape[-1] < num_frames_vad:
                orig_len = segment_input_slice.shape[-1]
                # pad to 1500 if necessary
                segment_input_slice = torch.nn.functional.pad(
                    segment_input_slice, pad=(0, num_frames_vad - orig_len)
                )
                # set corresponding padding tokens to 1 in vad mask representing silence
                # (channel 0 is assumed to be the silence channel — TODO confirm)
                segment_input_slice[0, 0, orig_len:] = 1.0

            stno_masks.append(segment_input_slice)
        kwargs["stno_mask"] = torch.cat(stno_masks, dim=0)
        # Kept for later use (e.g. language detection on the current window).
        self.stno_mask_seek = kwargs["stno_mask"]

        if self.config.use_enrollments and "enrollments" in kwargs:
            # Enrollment tensors are only re-indexed, not windowed.
            for key in kwargs["enrollments"]:
                kwargs["enrollments"][key] = kwargs["enrollments"][key][batch_idx_map]

        if attention_mask is not None:
            attention_mask = attention_mask[batch_idx_map]

        if "labels" in kwargs:
            kwargs['labels'] = kwargs["labels"][batch_idx_map]
            kwargs['upp_labels'] = kwargs["upp_labels"][batch_idx_map]
        return kwargs, attention_mask
119
+
120
+
121
    def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
        """Resolve the initial (forced) decoder tokens.

        Explicit `forced_decoder_ids` (from the generation config, else the model
        config) win over language/task-derived init tokens; otherwise defer to the
        parent implementation.
        """
        task = getattr(generation_config, "task", None)
        language = getattr(generation_config, "language", None)

        forced_decoder_ids = generation_config.forced_decoder_ids if hasattr(generation_config, "forced_decoder_ids") else None
        if forced_decoder_ids is not None:
            if language is None and task is None and forced_decoder_ids[0][1] is None:
                logger.warning_once(
                    "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
                    "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
                )
        elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
            forced_decoder_ids = config.forced_decoder_ids

        # NOTE(review): this branch is unreachable — the first `if` already handled
        # `forced_decoder_ids is not None`, so here it is always None. The documented
        # language-vs-forced_decoder_ids conflict is therefore never logged.
        elif forced_decoder_ids is not None and language is not None:
            logger.info(
                f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
            )
            forced_decoder_ids = None

        if forced_decoder_ids is not None:
            return forced_decoder_ids

        init_tokens = super()._retrieve_init_tokens(input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
        return init_tokens
146
+
147
    def detect_language(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
        generation_config: Optional[GenerationConfig] = None,
        num_segment_frames: int = 3000,
    ) -> torch.Tensor:
        """
        Detects language from log-mel input features or encoder_outputs

        Parameters:
            input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
                Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
                loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
                the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
                [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
                tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
            encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
                Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
                `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
                hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            num_segment_frames (`int`, *optional*, defaults to 3000):
                The number of log-mel frames the model expects

        Return:
            A `torch.LongTensor` representing the detected language ids.
        """
        if input_features is None and encoder_outputs is None:
            raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
        elif input_features is not None and encoder_outputs is not None:
            raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")
        elif input_features is not None:
            inputs = {"input_features": input_features[:, :, :num_segment_frames]}
            batch_size = input_features.shape[0]
        elif encoder_outputs is not None:
            inputs = {"encoder_outputs": encoder_outputs}
            # NOTE(review): the non-BaseModelOutput fallback yields `encoder_outputs[0]`
            # (a tensor, not an int) as batch_size — looks suspicious; upstream Whisper
            # uses `.shape[0]` in both cases. Verify before relying on tuple inputs.
            batch_size = (
                encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
            )

        generation_config = generation_config or self.generation_config
        # One forward step from the decoder start token; its logits rank the language tokens.
        decoder_input_ids = (
            torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
            * generation_config.decoder_start_token_id
        )

        with torch.no_grad():

            """<DiCoW CODE>"""
            # Pass the STNO mask (set elsewhere on self, at half the mel frame rate)
            # so the FDDT-conditioned encoder sees the diarization for this window.
            logits = self(**inputs, decoder_input_ids=decoder_input_ids, use_cache=False,
                          stno_mask=self.stno_mask[:, :, :num_segment_frames // 2]).logits[:, -1]
            """</DiCoW CODE>"""

            # Mask out every token that is not a language token, then argmax.
            non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
            non_lang_mask[list(generation_config.lang_to_id.values())] = False

            logits[:, non_lang_mask] = -np.inf

            lang_ids = logits.argmax(-1)

        return lang_ids
215
+
216
    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: Optional[int] = None,
        encoder_input_ids: Optional[torch.LongTensor] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        device: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorList:
        """Build the logits-processor list, appending CTC joint-decoding processors.

        When `generation_config.ctc_weight > 0`, a `LogSoftmaxProcessor` (single-beam
        only) and a `CTCRescorerLogitsProcessor` fed with the cached encoder CTC
        logits are appended after the standard processors.
        """
        # pylint: disable=no-member
        # Work on a copy with forced_decoder_ids cleared so the parent does not
        # act on it; DiCoW handles forced ids itself in _retrieve_init_tokens.
        gen_config_copy = copy.deepcopy(generation_config)
        gen_config_copy.forced_decoder_ids = None
        processors = super()._get_logits_processor(
            gen_config_copy,
            input_ids_seq_length,
            encoder_input_ids,
            prefix_allowed_tokens_fn,
            logits_processor,
            device,
            model_kwargs,
            negative_prompt_ids,
            negative_prompt_attention_mask,
        )
        if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
            # self.encoder_logits was cached in _prepare_encoder_decoder_kwargs_for_generation.
            enc_logits = self.encoder_logits
            if generation_config.num_beams <= 1:
                # Greedy/sampling: normalize decoder scores to log-probs before mixing.
                processors.append(LogSoftmaxProcessor())
            else:
                # Beam search: expand encoder logits to one copy per beam.
                enc_logits = enc_logits.repeat_interleave(generation_config.num_beams, dim=0)
            self.ctc_rescorer = CTCRescorerLogitsProcessor(
                enc_logits,
                torch.full((enc_logits.shape[0],), fill_value=enc_logits.shape[1],
                           device=enc_logits.device),
                enc_logits.shape[-1] - 1,  # blank id = last index of the CTC head
                generation_config.pad_token_id,
                generation_config.eos_token_id,
                generation_config.decoder_start_token_id,
                self.tokenizer,
                0,
                generation_config.ctc_weight,
                generation_config.num_beams,
                False,
            )
            processors.append(self.ctc_rescorer)
        return processors
264
+
265
+ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
266
+ if generation_config.return_timestamps is True:
267
+ """<DiCoW CODE>"""
268
+ timestamp_processor = WhisperTimeStampLogitsProcessorCustom(generation_config, begin_index=begin_index)
269
+ """</DiCoW CODE>"""
270
+ logits_processor = (
271
+ [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
272
+ )
273
+
274
+ if generation_config.suppress_tokens is not None:
275
+ suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
276
+ logits_processor = (
277
+ [suppress_tokens_processor]
278
+ if logits_processor is None
279
+ else [suppress_tokens_processor] + logits_processor
280
+ )
281
+ generation_config.suppress_tokens = None
282
+
283
+ if generation_config.begin_suppress_tokens is not None:
284
+ begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
285
+ generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
286
+ )
287
+ logits_processor = (
288
+ [begin_suppress_processor]
289
+ if logits_processor is None
290
+ else [begin_suppress_processor] + logits_processor
291
+ )
292
+ generation_config.begin_suppress_tokens = None
293
+
294
+ if generation_config.no_speech_threshold is not None:
295
+ no_speech_detector = WhisperNoSpeechDetection(
296
+ no_speech_token=generation_config.no_timestamps_token_id - 1,
297
+ begin_index=begin_index,
298
+ scores_is_logprobs=num_beams > 1,
299
+ )
300
+ logits_processor = (
301
+ [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
302
+ )
303
+ no_speech_detector.set_model(self)
304
+
305
+ return logits_processor
306
+
307
+ @staticmethod
308
+ def round_to_nearest_0_02(x):
309
+ d = Decimal(str(x)) # Use str(x) to preserve input precision
310
+ step = Decimal('0.02')
311
+ # Divide, round, multiply back
312
+ rounded = (d / step).to_integral_value(rounding=ROUND_HALF_UP) * step
313
+ return rounded
314
+
315
    def _fix_timestamps_from_segmentation(self, sequences):
        """
        Adjusts token sequences with global timestamps to fit within Whisper's 0–30s timestamp token range.

        Each segment's global start/end (seconds) is mapped into a 30-second block;
        dummy segments bridge skipped blocks, a running Decimal `correction` keeps
        boundary-straddling segments aligned to the 0.02s timestamp grid, and the
        rewritten segments are re-tokenized and padded into one batch tensor.
        """
        # Get the token ID for the "<|0.00|>" timestamp used to detect dummy segments
        first_timestamp_token = self.tokenizer.get_vocab()["<|0.00|>"]
        empty_text_token = self.tokenizer.get_vocab()["Ġ"]
        results = []

        # Filter out segments that are either empty or consist only of the "<|0.00|>" token
        for idx, sequence_segs in enumerate(sequences['segments']):
            sequences['segments'][idx] = [
                seg for seg in sequence_segs
                if len(seg['tokens']) > 0 and (len(seg['tokens']) != 1 or seg['tokens'][0] != first_timestamp_token)
            ]

        # Iterate over each group of segments
        for idx, sequence_segs in enumerate(sequences['segments']):
            result = []  # list of (start_in_block, tokens, end_in_block) triples
            prev_segment_end_time = None
            correction = Decimal(0.0)

            for i, seg in enumerate(sequence_segs):
                # Round start and end times to nearest 0.02 seconds
                start_time = self.round_to_nearest_0_02(seg['start'].item())
                end_time = self.round_to_nearest_0_02(seg['end'].item())
                tokens = seg['tokens']

                # Determine which 30s window this segment falls into
                current_block = (start_time + correction) // 30

                if prev_segment_end_time is not None:
                    # We subtract a tiny epsilon from prev_segment_end_time.
                    # If prev ended exactly at 30.0, it belongs to block 0, not block 1.
                    # 30.0 // 30 = 1 (Wrong) | 29.999 // 30 = 0 (Correct)
                    prev_block = (prev_segment_end_time - Decimal("0.001")) // 30

                    num_dummies = current_block - prev_block - 1

                    # Insert (30, [], 30) marker if we're moving to a new block
                    if current_block > prev_block:
                        result.append((30, [empty_text_token], 30))

                    # Insert dummy segments to bridge skipped 30s blocks
                    for _ in range(int(num_dummies)):
                        result.append((0, [empty_text_token], 30))
                else:
                    # For the first segment, add dummy blocks if it starts after 30s
                    for _ in range(int(start_time // 30)):
                        result.append((0, [empty_text_token], 30))

                # Determine whether segment fits in one block or wraps to the next
                if ((start_time + correction) // 30 == (end_time + correction) // 30):
                    # Segment fits within a single 30s window
                    result.append(((start_time + correction) % 30, tokens, (end_time + correction) % 30))
                elif (end_time + correction) % 30 == 0:
                    result.append(((start_time + correction) % 30, tokens, 30))
                    # Important: reset correction if we landed exactly on the boundary
                    correction = Decimal(0.0)
                else:
                    # Segment would wrap across a 30s boundary
                    new_seg_start = (correction + start_time) % 30
                    seg_duration = end_time - start_time
                    new_end_time = (end_time + correction) % 30

                    if seg_duration == 30.0:
                        if float(new_seg_start) % 30.0 == 0.0:
                            # Exactly fills a block: end at 30 and clear the correction.
                            new_end_time = Decimal(30.0)
                            correction = Decimal(0.0)
                        else:
                            # Shift by one 0.02s step so the wrapped end stays in-range.
                            correction = Decimal(-0.02)
                            new_end_time += Decimal(correction)
                    else:
                        correction = Decimal(0.0)
                    result.append((new_seg_start, tokens, new_end_time))

                # Update the previous segment's end time for next iteration
                prev_segment_end_time = end_time + correction

            # Convert result segments into a token sequence with proper timestamp formatting
            encoded = self.tokenizer(
                "".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result])
            )['input_ids']
            results.append(encoded)

        # Pad all sequences to the same length for batching
        sequences = pad_sequence(
            [torch.tensor(res, device=sequences['sequences'].device) for res in results],
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id
        )
        return sequences
407
+
408
+ @staticmethod
409
+ def _retrieve_segment(
410
+ seek_sequence,
411
+ seek_outputs,
412
+ time_offset,
413
+ timestamp_begin,
414
+ seek_num_frames,
415
+ time_precision,
416
+ time_precision_features,
417
+ input_stride,
418
+ prev_idx,
419
+ idx,
420
+ return_token_timestamps,
421
+ decoder_input_ids,
422
+ ):
423
+ # find the predicted "end of segment" predictions of Whisper
424
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
425
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
426
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
427
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
428
+ timestamp_segment_indices.add_(1)
429
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
430
+ idx_offset = decoder_input_ids.shape[-1]
431
+ device = seek_sequence.device
432
+
433
+ # If whisper predicted a "end of segment" via a timestep token, let's go ever each
434
+ # "end of segment" prediction and slice the decoding into segments accordingly
435
+ if len(timestamp_segment_indices) > 0:
436
+ # if the output contains two consecutive timestamp tokens
437
+ slices = timestamp_segment_indices.tolist()
438
+ segments = []
439
+ if single_timestamp_ending:
440
+ slices.append(len(seek_sequence))
441
+ else:
442
+ # we want to include the last timestamp token in the last segment to know it was no single ending
443
+ slices[-1] += 1
444
+
445
+ last_slice = 0
446
+ # Add each segment to list of all segments
447
+ for i, current_slice in enumerate(slices):
448
+ is_last_slice = i == len(slices) - 1
449
+ sliced_tokens = seek_sequence[last_slice:current_slice]
450
+ start_timestamp_pos = sliced_tokens[0] - timestamp_begin
451
+ idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
452
+ end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
453
+ segments.append(
454
+ {
455
+ "start": time_offset[prev_idx]
456
+ + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
457
+ * time_precision,
458
+ "end": time_offset[prev_idx]
459
+ + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
460
+ * time_precision,
461
+ "tokens": sliced_tokens,
462
+ "idxs": (idx_offset + last_slice, idx_offset + current_slice),
463
+ "result": seek_outputs[idx],
464
+ }
465
+ )
466
+ if return_token_timestamps:
467
+ segments[-1]["token_timestamps"] = (
468
+ token_timestamps[idx_offset + last_slice: idx_offset + current_slice] + time_offset[
469
+ prev_idx]
470
+ )
471
+ last_slice = current_slice
472
+
473
+ if single_timestamp_ending:
474
+ # single timestamp at the end means no speech after the last timestamp.
475
+ segment_offset = seek_num_frames[prev_idx]
476
+ else:
477
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
478
+ # here we throw away all predictions after the last predicted "end of segment"
479
+ # since we are cutting right in the middle of an audio
480
+ last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
481
+ segment_offset = last_timestamp_pos * input_stride
482
+ else:
483
+ # If whisper does not predict any "end of segment" token, then
484
+ # the whole decoding is considered a segment and we add it to the list of segments
485
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
486
+ start_timestamp_pos = 0.0
487
+ last_timestamp_pos = seek_num_frames[prev_idx] // 2
488
+ skip = False
489
+ segment_offset = seek_num_frames[prev_idx]
490
+
491
+ if timestamps.numel() > 1:
492
+ start_timestamp_pos = timestamps[-2].item() - timestamp_begin
493
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
494
+ elif timestamps.numel() == 1:
495
+ # no consecutive timestamps but it has a timestamp; use the last one.
496
+ start_timestamp_pos = timestamps[-1].item() - timestamp_begin
497
+ if start_timestamp_pos > 200:
498
+ # segment does not fit into decoding window, so we need to rollback
499
+ segment_offset = start_timestamp_pos * input_stride - 100 # timestamp might be inaccurate
500
+ skip = True
501
+ elif timestamps.numel() == 0 and len(seek_sequence) > 1:
502
+ # Decoding without timestamps, return output as it is
503
+ pass
504
+ else:
505
+ # empty sequence, or sequence w/o timestamps
506
+ skip = True
507
+
508
+ if skip:
509
+ segments = []
510
+ else:
511
+ segments = [
512
+ {
513
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
514
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
515
+ "tokens": seek_sequence,
516
+ "result": seek_outputs[idx],
517
+ }
518
+ ]
519
+ if return_token_timestamps:
520
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
521
+ segment_offset = seek_num_frames[prev_idx]
522
+
523
+ if segment_offset <= 0:
524
+ msg = f"Timestamps: {timestamps}, Segments: {segments}"
525
+ raise ValueError(f"Segment offset: {segment_offset} <= 0. This should not happen!\n{msg}")
526
+
527
+ return segments, segment_offset
528
+
529
+ def generate(
530
+ self,
531
+ generation_config: Optional[GenerationConfig] = None,
532
+ condition_on_prev_tokens: Optional[bool] = None,
533
+ assistant_model: Optional["PreTrainedModel"] = None,
534
+ **kwargs,
535
+ ):
536
+ if condition_on_prev_tokens:
537
+ raise NotImplementedError("Current version does not support conditioning")
538
+
539
+ gen_c, _ = self._prepare_generation_config(generation_config, **kwargs)
540
+ gen_mode = gen_c.get_generation_mode(assistant_model)
541
+
542
+ if gen_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.BEAM_SEARCH]:
543
+ raise ValueError(
544
+ f"Provided generation mode {gen_mode} is not supported"
545
+ f" for WhisperForConditionalGeneration with joint CTC decoding")
546
+
547
+ if "stno_mask" in kwargs:
548
+ self.stno_mask = kwargs["stno_mask"]
549
+
550
+ output = super().generate(**kwargs, return_segments=True)
551
+
552
+ self.encoder_logits = None
553
+
554
+ if isinstance(output, dict):
555
+ output = self._fix_timestamps_from_segmentation(output)
556
+
557
+ return output
558
+
559
+
560
    def generate_with_fallback(
        self,
        segment_input,
        decoder_input_ids,
        cur_bsz,
        seek,
        batch_idx_map,
        temperatures,
        generation_config,
        logits_processor,
        stopping_criteria,
        prefix_allowed_tokens_fn,
        synced_gpus,
        return_token_timestamps,
        do_condition_on_prev_tokens,
        is_shortform,
        batch_size,
        attention_mask,
        kwargs,
    ):
        """Per-window wrapper around the parent's temperature-fallback generation.

        Before delegating, it derives the number of valid input frames per batch
        item and lets ``prepare_kwargs_for_generate`` adjust the kwargs (and the
        attention mask) for the current seek position; afterwards it clears the
        per-window STNO mask so it cannot leak into the next decoding window.
        Signature and return values mirror the parent implementation.
        """
        # Deep copy so per-window adjustments made below never leak back into
        # the caller's kwargs (they are reused across seek iterations).
        kwargs_local = copy.deepcopy(kwargs)
        # Number of valid (non-padded) input frames per batch item.
        max_frames = attention_mask.sum(-1).cpu().to(torch.long)
        kwargs_local, attention_mask = self.prepare_kwargs_for_generate(max_frames, cur_bsz, batch_idx_map, seek, kwargs_local, attention_mask)
        seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type = super().generate_with_fallback(
            segment_input,
            decoder_input_ids,
            cur_bsz,
            seek,
            batch_idx_map,
            temperatures,
            generation_config,
            logits_processor,
            stopping_criteria,
            prefix_allowed_tokens_fn,
            synced_gpus,
            return_token_timestamps,
            do_condition_on_prev_tokens,
            is_shortform,
            batch_size,
            attention_mask,
            kwargs_local,
        )
        # The window-specific STNO mask has been consumed; reset it for the next window.
        self.stno_mask_seek = None

        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
605
+
606
+
607
    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        This is a copy of `transformers` `GenerationMixin._sample` with a DiCoW-specific
        insertion (marked by the `<DiCoW CODE>` string markers below) that advances the
        CTC rescorer's internal state after each token selection.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

        model_forward = self.__call__
        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
        if compile_forward:
            os.environ["TOKENIZERS_PARALLELISM"] = "0"
            # If we use FA2 and a static cache, we cannot compile with fullgraph
            if self.config._attn_implementation == "flash_attention_2":
                # only raise warning if the user passed an explicit compile-config
                if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
                    logger.warning_once(
                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
                    )
                    generation_config.compile_config.fullgraph = False
            model_forward = self.get_compiled_call(generation_config.compile_config)

        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
            is_prefill = False
        else:
            is_prefill = True

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # The first (prefill) step goes through the uncompiled call even when compilation is enabled.
            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue

            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            """<DiCoW CODE>"""
            # Based on the next tokens select the ctc prev states and scores
            # (greedy/sampling has no beam reordering, so the beam index map is just arange)
            if hasattr(self, "ctc_rescorer"):
                self.ctc_rescorer.update_state(next_tokens, torch.arange(next_tokens.shape[0]))
            """</DiCoW CODE>"""

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids
804
+
805
+
806
+
807
+
808
    def _beam_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        **model_kwargs,
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        This is a copy of `transformers` `GenerationMixin._beam_search` with a DiCoW-specific
        insertion (the `ctc_rescorer` update after cache reordering) that keeps the CTC
        rescorer's state aligned with the selected beams.

        If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
        https://huggingface.co/blog/how-to-generate (especially the beam search section).

        You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
        (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`:
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """

        # 1. init beam_search values
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        do_sample = generation_config.do_sample
        early_stopping = generation_config.early_stopping
        length_penalty = generation_config.length_penalty
        max_length = generation_config.max_length
        num_beams = generation_config.num_beams
        num_return_sequences = generation_config.num_return_sequences

        batch_size_unflattened, cur_len = input_ids.shape[:2]
        batch_size = batch_size_unflattened // num_beams
        # TODO (joao): standardize special cases
        if self.__class__.__name__ == "MoshiDepthDecoder":
            vocab_size = self.config.audio_vocab_size
        elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
            vocab_size = self.get_output_embeddings().out_features
        else:
            vocab_size = self.config.get_text_config().vocab_size
        decoder_prompt_len = cur_len
        this_peer_finished = False

        # At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
        # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
        # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
        # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
        n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
        beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
        top_num_beam_mask = torch.cat(
            (torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
            dim=0,
        ).to(input_ids.device)

        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

        # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
        # are newer low-memory alternatives like the offloaded cache)
        sequential = generation_config.low_memory
        if sequential:
            raise ValueError(
                "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
                "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
            )

        # 2. init output tuples
        all_scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        # NOTE: this initial value is replaced below by `running_beam_indices.detach().clone()`.
        beam_indices = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # 3. init running tensors and static-shaped placeholders

        # per batch, beam-item holding current token in loop and completed sequences
        output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
        running_sequences = torch.full(
            (batch_size, num_beams, max_length),
            fill_value=output_fill_value,
            dtype=torch.int64,
            device=input_ids.device,
        )
        running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
        sequences = running_sequences.detach().clone()

        # per batch, beam-item score, logprobs
        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
        running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        running_beam_scores[:, 1:] = -1e9
        beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)

        # per batch, beam-item state bit indicating if sentence has finished.
        is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)

        # per batch state bit indicating if there is a possibility to improve the best finished sentence.
        is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)

        # per batch, beam-item state bit indicating if there are valid continuations.
        next_token_hits_stopping_criteria = torch.zeros(
            (batch_size, num_beams), dtype=torch.bool, device=input_ids.device
        )

        # per batch selected beam indices
        running_beam_indices = torch.full(
            (batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
        )
        beam_indices = running_beam_indices.detach().clone()

        # 4. run the generation loop
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # a. Forward current tokens, obtain the logits
            flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
            model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            model_outputs = self(**model_inputs, return_dict=True)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                model_outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue

            # Copy is needed to avoid keeping a hanging ref
            logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

            # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
            # `temperature`, ...), and add new logprobs to existing running logprobs scores.
            log_probs = nn.functional.log_softmax(logits, dim=-1)
            log_probs = logits_processor(flat_running_sequences, log_probs)

            # Store logits, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_logits:
                    raw_logits += (logits.clone(),)
                if return_dict_in_generate and output_scores:
                    all_scores += (log_probs.clone(),)

                if output_attentions:
                    decoder_attentions += (
                        (model_outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (model_outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (model_outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (model_outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (model_outputs.hidden_states,)
                    )

            # This is needed to properly delete logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del model_outputs

            log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + running_beam_scores[:, :, None]
            log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))

            # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
            # continuations among all beams based on the accumulated scores.
            topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
                accumulated_log_probs=log_probs,
                running_sequences=running_sequences,
                running_beam_indices=running_beam_indices,
                cur_len=cur_len,
                decoder_prompt_len=decoder_prompt_len,
                do_sample=do_sample,
                beams_to_keep=beams_to_keep,
                num_beams=num_beams,
                vocab_size=vocab_size,
                batch_size=batch_size,
            )

            # d. Check which running sequences have finished
            next_token_hits_stopping_criteria = stopping_criteria(
                self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]),  # remove unfilled token indexes
                all_scores,
            )
            next_token_hits_stopping_criteria = self._unflatten_beam_dim(
                next_token_hits_stopping_criteria, batch_size, beams_to_keep
            )

            # e. Get the non-finished running `num_beams` sequences for the next generation step
            running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
                topk_log_probs=topk_log_probs,
                topk_running_sequences=topk_running_sequences,
                topk_running_beam_indices=topk_running_beam_indices,
                next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
                num_beams=num_beams,
            )

            # f. Update the completed beams if a new high score in a finished sequence is found
            sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
                sequences=sequences,
                topk_running_sequences=topk_running_sequences,
                beam_scores=beam_scores,
                topk_log_probs=topk_log_probs,
                beam_indices=beam_indices,
                topk_running_beam_indices=topk_running_beam_indices,
                is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
                is_sent_finished=is_sent_finished,
                next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
                top_num_beam_mask=top_num_beam_mask,
                num_beams=num_beams,
                cur_len=cur_len,
                decoder_prompt_len=decoder_prompt_len,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
            )

            # g. Prepare remaining data for the next iteration, including computing the stopping condition for
            # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)

            beam_idx = None
            # pluck the cache from the beam indices that will be used in the next iteration
            # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
            if model_kwargs.get("past_key_values", None) is not None:
                beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
                if hasattr(self, "_reorder_cache"):
                    model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
                else:
                    model_kwargs["past_key_values"].reorder_cache(beam_idx)

            # DiCoW: keep the CTC rescorer state aligned with the just-selected beams.
            # NOTE(review): `beam_idx` stays None when there is no `past_key_values` cache —
            # presumably `update_state` accepts that; verify against the rescorer implementation.
            if hasattr(self, "ctc_rescorer"):
                self.ctc_rescorer.update_state(running_sequences.flatten(0, 1)[:, cur_len], beam_idx)

            cur_len = cur_len + 1
            is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
                is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
                running_beam_scores=running_beam_scores,
                beam_scores=beam_scores,
                is_sent_finished=is_sent_finished,
                cur_len=cur_len,
                max_length=max_length,
                decoder_prompt_len=decoder_prompt_len,
                early_stopping=early_stopping,
                length_penalty=length_penalty,
            )
            this_peer_finished = not self._beam_search_has_unfinished_sequences(
                is_early_stop_heuristic_unsatisfied,
                is_sent_finished,
                next_token_hits_stopping_criteria,
                early_stopping,
            )

        # 5. prepare outputs
        # Take best beams for each batch (the score is sorted in descending order)
        sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
        beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
        beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])

        # Crop the static-shaped tensors to the actual size.
        # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
        # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
        # previous decoding iteration)
        max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
        output_length = decoder_prompt_len + max_generated_length
        sequences = sequences[:, :output_length]
        beam_indices = beam_indices[:, :max_generated_length]

        if return_dict_in_generate:
            if not output_scores:
                beam_scores = None

            if self.config.is_encoder_decoder:
                return GenerateBeamEncoderDecoderOutput(
                    sequences=sequences,
                    sequences_scores=beam_scores,
                    scores=all_scores,
                    logits=raw_logits,
                    beam_indices=beam_indices,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateBeamDecoderOnlyOutput(
                    sequences=sequences,
                    sequences_scores=beam_scores,
                    scores=all_scores,
                    logits=raw_logits,
                    beam_indices=beam_indices,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return sequences
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "begin_suppress_tokens": [
4
+ 220,
5
+ 50256
6
+ ],
7
+ "bos_token_id": 50257,
8
+ "decoder_start_token_id": 50258,
9
+ "eos_token_id": 50257,
10
+ "pad_token_id": 50257,
11
+ "transformers_version": "4.55.0"
12
+ }
layers.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import math
4
+ from transformers.models.whisper.modeling_whisper import WhisperAttention
5
+ from transformers.activations import ACT2FN
6
+
7
class CustomLinear(nn.Linear):
    """Linear layer with optional FDDT-specific (identity-like) initialization.

    ``fddt_init`` modes:
      - ``'non-disturbing'``: weight starts as a (possibly rectangular)
        identity, so the layer initially passes its input through unchanged.
      - ``'suppressive'``: identity scaled by ``init_eye_val``.
    A custom ``init_fun`` callable, when given, overrides every other scheme.
    """

    def __init__(self, *args, init_eye_val=0.0, fddt_init=None, init_fun=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_eye_val = init_eye_val
        self.fddt_init = fddt_init
        self.init_fun = init_fun
        self.reset_parameters()  # Ensure consistent init on creation

    def _fill_scaled_identity(self, scale: float) -> None:
        """Copy ``scale * I`` into the weight.

        For rectangular weights only the top-left ``min(out, in)`` square is
        filled; the remainder stays zero. (Factored out of reset_parameters
        to remove the duplicated branch bodies.)
        """
        out_dim, in_dim = self.weight.shape
        if out_dim == in_dim:
            self.weight.copy_(scale * torch.eye(out_dim, device=self.weight.device))
        else:
            eye = torch.zeros_like(self.weight)
            n = min(out_dim, in_dim)
            eye[:n, :n] = scale * torch.eye(n, device=self.weight.device)
            self.weight.copy_(eye)

    def reset_parameters(self) -> None:
        # NOTE: nn.Linear.__init__ invokes reset_parameters() before our extra
        # attributes exist, hence the hasattr() guards below.
        with torch.no_grad():
            # Apply the custom init function if provided
            if hasattr(self, "init_fun") and self.init_fun is not None:
                self.init_fun(self)
                return

            # Default initialization
            nn.init.xavier_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

            if hasattr(self, "fddt_init"):
                # FDDT-specific inits overwrite the default weight
                if self.fddt_init == 'non-disturbing':
                    self._fill_scaled_identity(1.0)
                elif self.fddt_init == 'suppressive':
                    self._fill_scaled_identity(self.init_eye_val)
49
class CustomDiagonalLinear(nn.Module):
    """Element-wise (diagonal) linear transform: ``out = input * weight (+ bias)``.

    Supports the same ``fddt_init`` modes as CustomLinear: 'non-disturbing'
    fills the diagonal with ones, 'suppressive' with ``init_eye_val``.
    """

    def __init__(self, d_model, bias=True, init_eye_val=0.0, fddt_init=None):
        super().__init__()
        self.init_eye_val = init_eye_val
        self.fddt_init = fddt_init
        self.weight = nn.Parameter(torch.full((d_model,), init_eye_val))
        self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            # Default random init: uniform with the usual fan-based bound
            fan_in = self.weight.numel()
            limit = math.sqrt(3.0 / fan_in)
            self.weight.uniform_(-limit, limit)
            if self.bias is not None:
                self.bias.zero_()

            # FDDT-specific overrides
            if self.fddt_init == 'non-disturbing':
                self.weight.fill_(1.0)
            elif self.fddt_init == 'suppressive':
                self.weight.fill_(self.init_eye_val)

    def forward(self, input):
        scaled = input * self.weight
        return scaled if self.bias is None else scaled + self.bias
79
class Gate(nn.Module):
    """Gated residual: ``orig + tanh(gate) * new`` with a learnable gate vector.

    With ``init_val=0`` the gate starts fully closed (tanh(0) == 0), so the
    module is initially an identity on ``orig_seq``.
    """

    def __init__(self, items, init_val=0.0):
        super().__init__()
        self.init_val = init_val
        self.gate = nn.Parameter(torch.full((items,), init_val))
        self.reset_parameters()

    def forward(self, orig_seq, new_seq):
        # tanh bounds the gate to (-1, 1)
        return orig_seq + torch.nn.functional.tanh(self.gate) * new_seq

    def reset_parameters(self):
        with torch.no_grad():
            self.gate.fill_(self.init_val)
95
def propagate_first_half_embeds_init(module):
    """Initialize ``module`` (a Linear consuming ``[cross_attn_out, q]``) so
    that right after init it approximately copies the first half of its input
    to the first ``in_features // 2`` outputs.

    Small xavier noise (gain 1e-1) is kept instead of exact zeros so every
    weight receives gradient signal from the start.

    Fixes vs original: wraps mutation in ``torch.no_grad()`` instead of going
    through ``.data`` (which bypasses autograd bookkeeping), and tolerates a
    bias-free Linear.
    """
    with torch.no_grad():
        # Small random component instead of exact zeros
        torch.nn.init.xavier_uniform_(module.weight, gain=1e-1)

        # Identity mapping for the first half of the input:
        # [cross_attn_output, q_orig] -> cross_attn_output goes to the first
        # in_features // 2 outputs
        half = module.weight.shape[1] // 2
        module.weight[:half, :half] += torch.eye(half, device=module.weight.device)

        # Zero bias (if present)
        if module.bias is not None:
            module.bias.zero_()
109
def propage_first_embeds_to_match_output_dim_init(module):
    """Initialize ``module`` (a Linear projecting back down to embed_dim) so
    that at init it approximately forwards its first ``out_features`` inputs
    unchanged. (Function name, including the historical typo, is kept for
    caller compatibility.)

    Fixes vs original: mutation under ``torch.no_grad()`` instead of ``.data``,
    and a guard for bias-free Linears.
    """
    with torch.no_grad():
        torch.nn.init.xavier_uniform_(module.weight, gain=1e-1)

        # Identity mapping from the first out_features inputs to the output
        out_dim = module.weight.shape[0]
        module.weight[:, :out_dim] += torch.eye(out_dim, device=module.weight.device)

        # Zero bias (if present)
        if module.bias is not None:
            module.bias.zero_()
119
# Cross attention block that can easily learn to ignore cross attention initially
class CrossAttentionEnrollBlock(nn.Module):
    """Cross-attention from a query stream to an enrollment (key/value) stream.

    Built to be an (almost) exact no-op right after initialization: the FFN is
    identity-initialized on the query path (see propagate_first_half_embeds_init
    and propage_first_embeds_to_match_output_dim_init) and the tanh gate starts
    at zero, so the query channel passes through unchanged until training
    opens the gate.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.ffn_dim = config.encoder_ffn_dim

        self.cross_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )

        # Layer normalization (pre-norm style)
        # self.norm_attn = nn.LayerNorm(self.embed_dim, eps=layer_norm_eps)
        self.cross_gate = Gate(1, init_val=.0)  # starts closed: tanh(0) == 0
        # Feed-forward network that maps concat space back to single channel;
        # both linears are identity-initialized on the attention-output path
        self.ffn = nn.Sequential(
            CustomLinear(self.embed_dim * 2, self.ffn_dim, init_fun=propagate_first_half_embeds_init),
            ACT2FN[config.activation_function],
            nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1),
            CustomLinear(self.ffn_dim, self.embed_dim, init_fun=propage_first_embeds_to_match_output_dim_init),
            nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (B, 2, T, F) - batch, channels, time, features;
                channel 0 is the query stream, channel 1 the key/value stream.
        Returns:
            Updated hidden states of the same shape (only channel 0 changes).
        """
        q = hidden_states[:, 0]  # (B, T, F)
        kv = hidden_states[:, 1]  # (B, T, F)

        # Cross-attention: query stream attends to the enrollment stream
        attn_output = self.cross_attn(
            hidden_states=q,
            key_value_states=kv,
            output_attentions=False
        )[0]

        # Concatenate attention output with the original query
        q_concat = torch.cat([attn_output, q], dim=-1)  # (B, T, 2*F)

        # Feed-forward processing (no normalization to preserve initialization)
        updated_q = self.ffn(q_concat)  # (B, T, F)

        q_out = self.cross_gate(q, updated_q)
        # Return stacked result (only query channel is updated)
        return torch.stack([q_out, kv], dim=1)
172
class SpeakerCommunicationBlock(nn.Module):
    """Lets paired speaker streams exchange information.

    Consecutive batch entries are grouped into pairs (the stream count is
    hard-wired to 2) and routed through a CrossAttentionEnrollBlock; the
    result is flattened back to the original batch layout.
    """

    def __init__(self, config):
        super().__init__()
        self.streams = 2
        self.config = config
        self.cae = CrossAttentionEnrollBlock(config)

    def forward(self, x):
        # x: (B, T, F); B must be divisible by self.streams
        batch, time, feat = x.shape
        paired = x.view(batch // self.streams, self.streams, time, feat)
        mixed = self.cae(paired)
        # Collapse the stream axis back into the batch dimension -> (B, T, F)
        return mixed.view(batch, time, feat)
195
+
196
if __name__ == "__main__":
    # Sanity check: the two identity-initialized linears approximately
    # propagate the first half of the input through to the output.
    model1 = CustomLinear(16 * 2, 64, init_fun=propagate_first_half_embeds_init)
    model2 = CustomLinear(64, 16, init_fun=propage_first_embeds_to_match_output_dim_init)
    ones_half = torch.ones(16, 16)
    zeros_half = torch.zeros(16, 16)
    probe = torch.concat((ones_half, zeros_half), dim=-1)
    recovered = model2(model1(probe))
    print(f"Mean err: {(ones_half - recovered).mean()}")

    # Training smoke test: a stack of suppressive diagonal layers fitted to
    # the identity target, watching the first layer's mean weight grow.
    layers = [
        CustomDiagonalLinear(4, bias=False, fddt_init='suppressive', init_eye_val=0.1)
        for _ in range(4)
    ]
    stack = nn.Sequential(*layers)
    opt = torch.optim.Adam(stack.parameters(), lr=0.01)
    layers[0].reset_parameters()

    x = torch.ones(2, 4)
    target = torch.ones(2, 4)

    for step in range(100):
        opt.zero_grad()
        loss = ((stack(x) - target) ** 2).mean()
        loss.backward()
        opt.step()
        print(f"Step {step}: mean weight {layers[0].weight.mean().item():.4f}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e42482cf705a422b022a6c0df84f02be60da5a2fbf454ca517001359a1b9b20
3
+ size 4407276024
modeling_dicow.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+ import re
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import Cache
8
+ from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
9
+ from transformers.models.whisper.modeling_whisper import (
10
+ WhisperForConditionalGeneration,
11
+ shift_tokens_right,
12
+ WhisperModel
13
+ )
14
+ from transformers.utils import logging
15
+ from .config import DiCoWConfig
16
+ from .encoder import DiCoWEncoder
17
+ from .generation import DiCoWGenerationMixin
18
+
19
+ logging.set_verbosity_debug()
20
+ logger = logging.get_logger("transformers")
21
+
22
+
23
class SoftLabelCreator(torch.nn.Module):
    """
    Handles label smoothing for timestamps and the dual-loss logic (Upper vs Lower case).

    Timestamp tokens (``<|x.xx|>``) receive a Gaussian soft target spread over
    neighbouring timestamps; regular tokens stay one-hot. The final loss is,
    per token, the minimum of the lower-case and upper-case cross entropies
    (case invariance), averaged over non-padding positions.
    """

    def __init__(self, tokenizer, timestamp_sigma=0.08):
        super().__init__()
        self.tokenizer = tokenizer
        self.timestamp_sigma = timestamp_sigma
        # Pre-compute the Gaussian smoothing matrix (None if no timestamp tokens)
        self.register_buffer('ts_smoothing_matrix', self._build_smoothing_matrix())

    def _build_smoothing_matrix(self):
        """Return a (num_timestamps, vocab_size) matrix whose row i is the soft
        target distribution for timestamp token i, or None if the vocabulary
        contains no timestamp tokens."""
        vocab = self.tokenizer.get_vocab()
        vocab_size = len(vocab)

        timestamp_pattern = re.compile(r'<\|(\d+\.\d+)\|>')

        # 1. Map Token IDs to Time Values
        id_to_time = {}
        for token_str, token_id in vocab.items():
            match = timestamp_pattern.match(token_str)
            if match:
                id_to_time[token_id] = float(match.group(1))

        if not id_to_time:
            return None

        # Sorted list for fast lookups (searchsorted in _get_soft_distribution).
        # NOTE: plain attribute, not a buffer — moved to the right device
        # explicitly at use time.
        sorted_ids = sorted(id_to_time.keys())
        self.sorted_ts_ids = torch.tensor(sorted_ids)
        times = torch.tensor([id_to_time[i] for i in sorted_ids])

        # 2. Create the Smoothing Matrix (Num_Timestamps x Vocab_Size)
        num_ts = len(sorted_ids)
        smoothing_matrix = torch.zeros(num_ts, vocab_size)

        # Vectorized Gaussian computation over pairwise time differences
        diff_sq = (times.unsqueeze(1) - times.unsqueeze(0)) ** 2
        weights = torch.exp(-diff_sq / (2 * self.timestamp_sigma ** 2))

        # Normalize each row to a probability distribution
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Scatter all rows back to vocab size in one vectorized assignment
        # (replaces the original per-row Python loop; identical result)
        smoothing_matrix[:, self.sorted_ts_ids] = weights

        return smoothing_matrix

    def _get_soft_distribution(self, labels, vocab_size):
        """Internal helper to convert hard labels -> soft timestamp labels.

        Non-timestamp labels stay one-hot; padding (-100) is clamped to class 0
        here and must be masked out by the caller.
        """
        device = labels.device

        # Start with One-Hot (clamp -100 to 0 temporarily)
        labels_clamped = labels.clamp(min=0)
        soft_labels = F.one_hot(labels_clamped, num_classes=vocab_size).float()

        # Apply Timestamp Smoothing if the matrix exists
        if self.ts_smoothing_matrix is not None:
            sorted_ts_ids = self.sorted_ts_ids.to(device)
            smoothing_matrix = self.ts_smoothing_matrix.to(device)

            is_timestamp = torch.isin(labels, sorted_ts_ids)

            if is_timestamp.any():
                # Map each timestamp token id to its row in the smoothing matrix
                ts_indices = torch.searchsorted(sorted_ts_ids, labels[is_timestamp])
                soft_labels[is_timestamp] = smoothing_matrix[ts_indices]

        return soft_labels

    def compute_loss(self, logits, labels, upp_labels):
        """
        Computes the enhanced SOT loss:
        1. Generates soft labels (timestamp smoothed) for both 'labels' and 'upp_labels'.
        2. Computes KL Divergence (via CrossEntropy) for both.
        3. Takes the minimum loss per token (case invariance).
        4. Applies padding mask.
        """
        vocab_size = logits.size(-1)
        device = logits.device

        # Ensure labels are on the correct device
        labels = labels.to(device)
        if upp_labels is not None:
            upp_labels = upp_labels.to(device)

        # Flatten inputs to (N, vocab) / (N,)
        flat_logits = logits.view(-1, vocab_size)
        flat_labels = labels.reshape(-1)

        # 1. Generate Soft Targets for Lowercase
        soft_lower = self._get_soft_distribution(flat_labels, vocab_size)

        # 2. Generate Soft Targets for Uppercase (if provided)
        if upp_labels is not None:
            flat_upp = upp_labels.reshape(-1)
            soft_upper = self._get_soft_distribution(flat_upp, vocab_size)
        else:
            # Fallback if no upper labels provided (shouldn't happen in this pipeline)
            soft_upper = soft_lower

        # 3. Compute Cross Entropy (soft-target mode)
        # Note: CE with soft targets = -sum(target * log_prob)
        loss_fct = CrossEntropyLoss(reduction='none')

        loss_lower = loss_fct(flat_logits, soft_lower)
        loss_upper = loss_fct(flat_logits, soft_upper)

        # 4. Mask Padding (ignore_index = -100)
        # Soft-target CE doesn't support ignore_index automatically
        mask = (flat_labels != -100).float()

        loss_lower = loss_lower * mask
        loss_upper = loss_upper * mask

        # 5. Take Min (Case Invariance) and Normalize
        combined_min = torch.min(loss_lower, loss_upper)

        # Sum and divide by number of non-padding tokens
        return combined_min.sum() / mask.sum().clamp(min=1)
146
class DiCoW(WhisperModel):
    """Whisper seq2seq model whose encoder is swapped for DiCoWEncoder.

    The encoder additionally receives ``stno_mask`` (a per-frame
    speaker-activity mask — presumably silence/target/non-target/overlap,
    cf. the FDDT module; confirm against DiCoWEncoder) and optional
    ``enrollments``. The decoder is the stock Whisper decoder.
    """

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        # Replace the encoder built by WhisperModel with the DiCoW encoder
        self.encoder = DiCoWEncoder(config)
        self.post_init()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        enrollments=None
    ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
        """Run the DiCoW encoder (unless ``encoder_outputs`` is precomputed),
        then the Whisper decoder.

        Returns a Seq2SeqModelOutput when ``return_dict`` is truthy, otherwise
        the tuple of decoder outputs followed by encoder outputs.
        """
        # Resolve flag defaults from the model config
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            # Mask padded input frames before encoding
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            # DiCoW-specific kwargs: stno_mask and enrollments
            encoder_outputs = self.encoder(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                head_mask=head_mask,
                return_dict=return_dict,
                stno_mask=stno_mask,
                enrollments=enrollments
            )

        # Standard Whisper decoder pass, cross-attending to encoder output
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
224
class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
    """Whisper-for-generation wrapper around the DiCoW model.

    Adds two pieces on top of WhisperForConditionalGeneration:
    (a) a case-invariant decoder loss — per-token minimum of lower-case and
        upper-case cross entropies, with Gaussian-smoothed timestamp targets
        when a SoftLabelCreator has been installed via set_tokenizer();
    (b) an optional CTC auxiliary loss on the encoder output, mixed in with
        weight ``config.ctc_weight``.
    """

    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.model = DiCoW(config)
        # Scratch state used by the generation mixin / CTC helpers; filled in
        # externally (set_tokenizer) or during generation.
        self.encoder_logits = None
        self.tokenizer = None
        self.stno_mask = None
        self.stno_mask_seek = None
        self.soft_label_creator = None
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Attach the tokenizer and build the soft-label helper from it."""
        self.tokenizer = tokenizer
        # Initialize the helper class
        self.soft_label_creator = SoftLabelCreator(tokenizer)

    def get_enc_logits(self, hidden_states):
        """Project encoder hidden states to vocabulary logits (for CTC)."""
        encoder = self.model.get_encoder()
        hidden_states = encoder.possibly_update_last_hidden_states(hidden_states)
        logits = encoder.lm_head(hidden_states)
        return logits

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        upp_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        forced_decoder_ids: Optional[torch.LongTensor] = None,  # NOTE(review): accepted but unused in this body
        enrollments=None,
    ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
        """Seq2seq forward with optional dual-case decoder loss and CTC loss.

        ``upp_labels`` is an upper-cased variant of ``labels``; per token the
        smaller of the two cross entropies is used (case invariance).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Teacher forcing: derive decoder inputs by shifting labels right
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            stno_mask=stno_mask,
            enrollments=enrollments,
        )

        dec_lm_logits = self.proj_out(outputs.last_hidden_state)
        loss = None

        if labels is not None:
            # --- UPDATED LOSS CALCULATION ---
            if self.soft_label_creator is not None:
                # Delegate all soft label creation, flattening, and min-loss logic to the helper
                dec_loss = self.soft_label_creator.compute_loss(dec_lm_logits, labels, upp_labels)
            else:
                # Fallback to original hard label implementation if tokenizer/helper not ready
                loss_fct = CrossEntropyLoss(reduction='none')
                labels = labels.to(dec_lm_logits.device)

                flat_logits = dec_lm_logits.view(-1, self.config.vocab_size)
                dec_loss1 = loss_fct(flat_logits, labels.reshape(-1))

                if upp_labels is not None:
                    upp_labels = upp_labels.to(dec_lm_logits.device)
                    dec_loss2 = loss_fct(flat_logits, upp_labels.reshape(-1))
                    # Per-token min of the two case variants, then mean.
                    # NOTE(review): mean() also averages over ignored (-100)
                    # positions (their per-token loss is 0), so the result is
                    # scaled by the padding ratio — confirm this is intended.
                    dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
                else:
                    dec_loss = dec_loss1.mean()
            # --------------------------------

            if self.config.ctc_weight > 0.0:
                enc_lm_logits = self.get_enc_logits(outputs.encoder_last_hidden_state)
                # Prepare CTC labels: strip leading prompt/prefix tokens that
                # are shared by the whole batch (requires set_tokenizer first)
                enc_labels = labels.clone().to(dec_lm_logits.device)
                for token in self.tokenizer.prefix_tokens:
                    if (enc_labels[:, 0] == token).all():
                        enc_labels = enc_labels[:, 1:]
                # EOS is not a CTC label — mask it out
                enc_labels[enc_labels == self.config.eos_token_id] = -100

                ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)
                # Convex mix of decoder CE and encoder CTC losses
                loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
            else:
                loss = dec_loss

        if not return_dict:
            output = (dec_lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=dec_lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def _get_feat_extract_output_lengths(self, attention_mask: torch.LongTensor) -> torch.LongTensor:
        # Encoder output length additionally downsampled by 4.
        # NOTE(review): division + ceil() yields a float tensor despite the
        # LongTensor annotation — callers may need a .long() cast; confirm.
        return (self.model.get_encoder()._get_feat_extract_output_lengths(attention_mask) / 4).ceil()
utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import WhisperTimeStampLogitsProcessor
3
+
4
+
5
class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
    """Whisper timestamp logits processor that additionally allows emitting
    EOS as the very first generated token, so decoding can exit early on
    silent segments (the stock processor suppresses EOS at that position)."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Apply the standard Whisper timestamp constraints first
        scores_processed = super().__call__(input_ids, scores)

        # Enable to early exit from silence via eos token: restore the
        # original (pre-processing) EOS score at the first decoding position
        if input_ids.shape[1] == self.begin_index:
            scores_processed[:, self.eos_token_id] = scores[:, self.eos_token_id]

        return scores_processed