Lakoc committed
Commit c00ff2c · verified · 1 Parent(s): c4fb00d

Upload DiCoWForConditionalGeneration

Files changed (15)
  1. FDDT.py +81 -0
  2. README.md +199 -0
  3. SCBs.py +245 -0
  4. coattention.py +120 -0
  5. config.json +80 -0
  6. config.py +90 -0
  7. contrastive_loss.py +140 -0
  8. decoding.py +397 -0
  9. encoder.py +268 -0
  10. generation.py +1768 -0
  11. generation_config.json +12 -0
  12. layers.py +38 -0
  13. model.safetensors +3 -0
  14. modeling_dicow.py +387 -0
  15. utils.py +96 -0
FDDT.py ADDED
@@ -0,0 +1,81 @@
from typing import Optional

import torch
from torch import nn

from .layers import CustomDiagonalLinear, CustomLinear


class FDDT(nn.Module):
    def __init__(self, d_model, non_target_rate=0.01, is_diagonal=False, bias_only=False, use_silence=True,
                 use_target=True, use_overlap=True, use_non_target=True, use_interaction=False,
                 scb_module: Optional[nn.Module] = None):
        super().__init__()
        if use_target:
            self.target_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=1.0))
        if use_non_target:
            self.non_target_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=non_target_rate))
        if use_overlap:
            self.overlap_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=1.0))
        if use_silence:
            self.silence_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=non_target_rate))

        if use_interaction:
            self.scb = scb_module

        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.use_interaction = use_interaction
        self.bias_only = bias_only

    @staticmethod
    def mask_out_non_interaction_signal(hidden_states, mask):
        mask = torch.round(mask).bool()
        masked_hidden_states = hidden_states * mask
        return masked_hidden_states

    def forward(self, hidden_states, stno_mask):
        stno_mask = stno_mask.to(hidden_states.device)[..., None]
        if self.bias_only:
            if self.use_silence:
                hidden_states += stno_mask[:, 0, ...] * self.silence_linear
            if self.use_target:
                hidden_states += stno_mask[:, 1, ...] * self.target_linear
            if self.use_non_target:
                hidden_states += stno_mask[:, 2, ...] * self.non_target_linear
            if self.use_overlap:
                hidden_states += stno_mask[:, 3, ...] * self.overlap_linear
            # if self.use_interaction:
            #     hidden_states += stno_mask[:, 4, ...] * self.scb
        else:
            orig_hidden_states = hidden_states
            hidden_states = \
                (self.silence_linear(orig_hidden_states) if self.use_silence else orig_hidden_states) * stno_mask[:, 0, :] + \
                (self.target_linear(orig_hidden_states) if self.use_target else orig_hidden_states) * stno_mask[:, 1, :] + \
                (self.non_target_linear(orig_hidden_states) if self.use_non_target else orig_hidden_states) * stno_mask[:, 2, :] + \
                (self.overlap_linear(orig_hidden_states) if self.use_overlap else orig_hidden_states) * stno_mask[:, 3, :]
            # (self.scb(orig_hidden_states) * stno_mask[:, 4, :] if self.use_interaction else (
            #     0 if stno_mask.size(1) == 4 else orig_hidden_states * stno_mask[:, 4, :]))
        if self.use_interaction:
            hidden_states = self.scb(hidden_states)
        return hidden_states
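FDDT applies one learned per-frame transform per STNO class (silence, target, non-target, overlap) and combines them according to the mask, optionally followed by a speaker-communication block. A minimal usage sketch, assuming the file can be imported standalone; `bias_only=True` avoids the `CustomLinear`/`CustomDiagonalLinear` dependencies from `layers.py`, and the zero-initialised biases leave the features unchanged until trained:

```python
import torch
from FDDT import FDDT  # assumption: standalone import; inside the package it is `from .FDDT import FDDT`

d_model, batch, frames = 8, 2, 6
fddt = FDDT(d_model, bias_only=True, use_interaction=False)

# STNO mask: one channel per frame class (silence, target, non-target, overlap), shape (batch, 4, frames)
stno_mask = torch.zeros(batch, 4, frames)
stno_mask[:, 1, :] = 1.0  # mark every frame as "target speaker active"

hidden_states = torch.randn(batch, frames, d_model)
out = fddt(hidden_states, stno_mask)
print(out.shape)  # torch.Size([2, 6, 8])
```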
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
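A hedged loading sketch based on the `auto_map` entries in this repository's `config.json`; the repository id below is a placeholder, and inference additionally requires the per-frame diarization (STNO) masks handled by `modeling_dicow.py` / `generation.py`:

```python
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

repo_id = "<org>/<this-repo>"  # placeholder: substitute the actual Hub repository id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)                 # resolves to config.DiCoWConfig
model = AutoModelForSpeechSeq2Seq.from_pretrained(repo_id, trust_remote_code=True)   # modeling_dicow.DiCoWForConditionalGeneration
```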
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
SCBs.py ADDED
@@ -0,0 +1,245 @@
1
+ import torch
2
+ from torch import nn
3
+ from transformers import WhisperConfig
4
+ from transformers.activations import ACT2FN
5
+ from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
6
+ import torch.nn.functional as F
7
+ from .coattention import CoAttention
8
+ from .layers import CustomLinear, CustomDiagonalLinear, Gate
9
+
10
+ class LowRankApproxSelectFirst(nn.Module):
11
+ def __init__(self, d_in, d_out, rank):
12
+ super().__init__()
13
+ self.d_in = d_in
14
+ self.d_out = d_out
15
+ self.rank = rank
16
+ self.proj_in = nn.Linear(d_in, rank)
17
+ self.proj_out = nn.Linear(rank, d_out)
18
+
19
+ def forward(self, x):
20
+ return self.proj_out(self.proj_in(x))
21
+
22
+ def _init_weights(self):
23
+ # Create low-rank approximation of the identity projection from first d_out of input
24
+ eye = torch.eye(self.d_out, self.d_in) # (d_out x d_in)
25
+
26
+ # Low-rank SVD of eye matrix
27
+ U, S, Vh = torch.linalg.svd(eye, full_matrices=False) # U: (d_out x d_out), Vh: (d_in x d_in)
28
+
29
+ U_k = U[:, :self.rank] # (d_out x rank)
30
+ S_k = S[:self.rank] # (rank,)
31
+ V_k = Vh[:self.rank, :] # (rank x d_in)
32
+
33
+ A = V_k # (rank x d_in)
34
+ B = U_k @ torch.diag(S_k) # (d_out x rank)
35
+
36
+ # Set weights
37
+ self.proj_in.weight.data.copy_(A)
38
+ self.proj_in.bias.data.zero_()
39
+ self.proj_out.weight.data.copy_(B)
40
+ self.proj_out.bias.data.zero_()
41
+
42
+
43
+
44
+ class TACBlock(nn.Module):
45
+ def __init__(self, config: WhisperConfig, d_int_factor: float = 1, num_speakers=2):
46
+ super().__init__()
47
+ d = config.d_model
48
+ d_prime = int(d * d_int_factor)
49
+ self.num_speakers = num_speakers
50
+ self.proj_in_1 = nn.Linear(d, d_prime, bias=True)
51
+ self.proj_in_2 = nn.Linear(d, d_prime, bias=True)
52
+ self.proj_int = nn.Linear(d_prime, d_prime,bias=True)
53
+ self.proj_out_1 = nn.Linear(d+d_prime, d,bias=True)
54
+ self.proj_out_2 = nn.Linear(d+d_prime, d,bias=True)
55
+ self.activation_fn = ACT2FN[config.activation_function]
56
+ self.norms = nn.ModuleList([nn.LayerNorm(d) for _ in range(self.num_speakers)])
57
+ self.gate = Gate(self.num_speakers, 0.01)
58
+
59
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
60
+ # hidden_states: (B, self.num_speakers, T, F)
61
+
62
+ x_proj = torch.stack([self.activation_fn(self.proj_in_1(hidden_states[:,0])), self.activation_fn(self.proj_in_2(hidden_states[:, 1]))], dim=1) # (B, 2, T, d')
63
+ x_mean = x_proj.mean(dim=1, keepdim=True) # (B, 1, T, d')
64
+ z = self.activation_fn(self.proj_int(x_mean)) # (B, 1, T, d')
65
+
66
+ z_expand = z.expand(-1, self.num_speakers, -1, -1) # (B, self.num_speakers, T, d')
67
+ x_cat = torch.cat([hidden_states, z_expand], dim=-1) # (B, self.num_speakers, T, d + d')
68
+ x_out = torch.stack([self.norms[0](self.proj_out_1(x_cat[:, 0])), self.norms[1](self.proj_out_2(x_cat[:, 1]))], dim=1) # (B, self.num_speakers, T, d)
69
+ return hidden_states + self.gate(x_out, dim=1)
70
+
71
+
72
+ class CrossAttentionBlock(nn.Module):
73
+ def __init__(self, config: WhisperConfig):
74
+ super().__init__()
75
+ self.embed_dim = config.d_model
76
+
77
+ self.num_speakers = getattr(config, "mt_num_speakers", 2)
78
+ if self.num_speakers != 2:
79
+ raise ValueError("CrossAttentionBlock supports only 2 speakers.")
80
+
81
+ # Separate attention block per speaker
82
+ self.attn_blocks = nn.ModuleList([
83
+ WHISPER_ATTENTION_CLASSES[config._attn_implementation](
84
+ embed_dim=self.embed_dim,
85
+ num_heads=config.encoder_attention_heads,
86
+ dropout=config.attention_dropout,
87
+ config=config,
88
+ )
89
+ for _ in range(self.num_speakers)
90
+ ])
91
+
92
+ self.norms = nn.ModuleList([nn.LayerNorm(self.embed_dim) for _ in range(self.num_speakers)])
93
+ self.gate = Gate(self.num_speakers, 0.01)
94
+
95
+ def forward(self, hidden_states):
96
+ # hidden_states: (B, 2, T, F)
97
+ outputs = []
98
+ for s in range(self.num_speakers):
99
+ q = hidden_states[:, s] # (B, T, F)
100
+ other_s = 1 - s
101
+ kv = hidden_states[:, other_s] # (B, T, F)
102
+
103
+ attn_out, _, _ = self.attn_blocks[s](hidden_states=q, key_value_states=kv) # (B, T, F)
104
+ outputs.append(self.norms[s](attn_out[:, None, :, :]))
105
+ outputs = torch.concat(outputs, dim=1)
106
+ outputs_modulated = self.gate(outputs, dim=1) + hidden_states
107
+ return outputs_modulated
108
+
109
+
110
+ class CompetitiveCrossAttentionBlock(nn.Module):
111
+ def __init__(self, config):
112
+ super().__init__()
113
+ self.embed_dim = config.d_model
114
+ self.num_heads = config.encoder_attention_heads
115
+ self.head_dim = self.embed_dim // self.num_heads
116
+ assert (
117
+ self.head_dim * self.num_heads == self.embed_dim
118
+ ), "embed_dim must be divisible by num_heads"
119
+
120
+ self.num_speakers = getattr(config, "mt_num_speakers", 2)
121
+ if self.num_speakers != 2:
122
+ raise ValueError("CompetitiveCrossAttentionBlock supports only 2 speakers.")
123
+
124
+ # Separate projections for Q, K, V
125
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
126
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
127
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
128
+
129
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
130
+
131
+ self.norms = nn.ModuleList([nn.LayerNorm(self.embed_dim) for _ in range(self.num_speakers)])
132
+ self.eps = 1e-6
133
+ self.gate = Gate(self.num_speakers, 0.01)
134
+
135
+ def _shape(self, tensor, seq_len, batch_size):
136
+ # reshape into (B, num_heads, T, head_dim)
137
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
138
+
139
+ def forward(self, hidden_states):
140
+ # hidden_states: (B, 2, T, F)
141
+ B, _, T, _ = hidden_states.shape
142
+
143
+ h1, h2 = hidden_states[:, 0], hidden_states[:, 1] # (B, T, F)
144
+
145
+ # Project Q,K,V
146
+ Q1 = self.q_proj(h1) # (B, T, F)
147
+ K2 = self.k_proj(h2)
148
+ V2 = self.v_proj(h2)
149
+
150
+ Q2 = self.q_proj(h2)
151
+ K1 = self.k_proj(h1)
152
+ V1 = self.v_proj(h1)
153
+
154
+ # Reshape for multi-head attention
155
+ Q1 = self._shape(Q1, T, B) # (B, heads, T, head_dim)
156
+ K2 = self._shape(K2, T, B)
157
+ V2 = self._shape(V2, T, B)
158
+
159
+ Q2 = self._shape(Q2, T, B)
160
+ K1 = self._shape(K1, T, B)
161
+ V1 = self._shape(V1, T, B)
162
+
163
+ # Scaled dot-product attention logits
164
+ scale = 1 / (self.head_dim ** 0.5)
165
+ L_1to2 = torch.matmul(Q1, K2.transpose(-1, -2)) * scale # (B, heads, T, T)
166
+ L_2to1 = torch.matmul(Q2, K1.transpose(-1, -2)) * scale # (B, heads, T, T)
167
+
168
+ # Softmax over last dim (keys)
169
+ S_1to2 = F.softmax(L_1to2, dim=-1)
170
+ S_2to1 = F.softmax(L_2to1, dim=-1)
171
+
172
+ # Competitive normalization (soft exclusivity)
173
+ M_joint = S_1to2 + S_2to1 + self.eps
174
+ A_1to2 = S_1to2 / M_joint
175
+ A_2to1 = S_2to1 / M_joint
176
+
177
+ # Weighted sum of values
178
+ H1_attn = torch.matmul(A_1to2, V2) # (B, heads, T, head_dim)
179
+ H2_attn = torch.matmul(A_2to1, V1)
180
+
181
+ # Concatenate heads back
182
+ H1_attn = H1_attn.transpose(1, 2).contiguous().view(B, T, self.embed_dim) # (B, T, F)
183
+ H2_attn = H2_attn.transpose(1, 2).contiguous().view(B, T, self.embed_dim)
184
+
185
+ # Output projection
186
+ H1_attn = self.norms[0](self.out_proj(H1_attn))
187
+ H2_attn = self.norms[1](self.out_proj(H2_attn))
188
+
189
+ # Residuals
190
+ out = hidden_states + self.gate(torch.concat([H1_attn[:, None, :, :], H2_attn[:, None, :, :]], dim=1), dim=1)
191
+
192
+ return out # (B, 2, T, F)
193
+
194
+
195
+ class CoAttentionWrapper(nn.Module):
196
+ def __init__(self, config, num_speakers=2):
197
+ super().__init__()
198
+ self.coa = CoAttention(embed_dim=config.d_model, single_dim=config.d_model//2, multi_dim=config.d_model // 4, n_heads=config.encoder_attention_heads, attn_dropout=config.attention_dropout)
199
+ self.gate = Gate(num_speakers, 0.01)
200
+
201
+ def forward(self, coa_input: torch.Tensor) -> torch.Tensor:
202
+ # hidden_states: (B, 2, T, F)
203
+ hidden_states = coa_input.permute(-2, 0, 1, -1)
204
+ hidden_states = self.coa(hidden_states)
205
+ out = coa_input + self.gate(hidden_states.permute(1, 2, 0, -1), dim=1)
206
+ return out
207
+
208
+
209
+ class SpeakerCommunicationBlock(nn.Module):
210
+ def __init__(self, config, scb_method):
211
+ super().__init__()
212
+ self.num_speakers = getattr(config, "mt_num_speakers", 2)
213
+ self.embed_dim = config.d_model
214
+ self.scb_method = scb_method
215
+ self.config = config
216
+
217
+ if self.scb_method == "tac":
218
+ self.method = TACBlock(config)
219
+ elif self.scb_method == "cross_attention":
220
+ self.method = CrossAttentionBlock(config)
221
+ elif self.scb_method == "competitive_cross_attention":
222
+ self.method = CompetitiveCrossAttentionBlock(config)
223
+ elif self.scb_method == "co_attention":
224
+ self.method = CoAttentionWrapper(config)
225
+ elif self.scb_method == "identity":
226
+ self.method = (nn.Parameter(torch.zeros(self.embed_dim)) if config.fddt_bias_only else (
227
+ CustomDiagonalLinear(self.embed_dim, bias=True, init_eye_val=1.0) if config.fddt_is_diagonal else CustomLinear(
228
+ self.embed_dim, self.embed_dim, bias=True, init_eye_val=1.0)))
229
+ else:
230
+ raise ValueError(f"Unsupported scb_method: {self.scb_method}")
231
+
232
+ def forward(self, x):
233
+ # x: (B, T, F)
234
+ B, T, F = x.shape
235
+ S = self.num_speakers
236
+
237
+ # Reshape to (B//S, S, T, F)
238
+ x_reshaped = x.view(B//S, S, T, F)
239
+
240
+ # Call the selected method
241
+ out = self.method(x_reshaped)
242
+
243
+ # Reshape back (B, T, F)
244
+ out_merged = out.view(B, T, F)
245
+ return out_merged
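All SCB variants operate on hidden states whose speaker axis has been folded into the batch: `SpeakerCommunicationBlock.forward` unfolds `(B, T, F)` into `(B//S, S, T, F)`, lets the chosen method exchange information across the speaker dimension, and folds back. A minimal sketch of that layout convention using plain tensor ops (the repo's `Gate`/`CustomLinear` layers are not needed for the illustration):

```python
import torch

B, S, T, F = 4, 2, 10, 16          # B stacks S speaker-specific copies of each recording
x = torch.randn(B, T, F)           # what a Whisper encoder layer passes in

x_unfolded = x.view(B // S, S, T, F)   # view seen by the TAC / cross-attention / co-attention blocks
# ... a speaker-communication method mixes information along dim=1 here ...
x_folded = x_unfolded.view(B, T, F)    # folded back before the next encoder layer

assert torch.equal(x, x_folded)
```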
coattention.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ class MultiHeadCoAttention(nn.Module):
5
+ def __init__(self, multi_dim, single_dim, num_heads):
6
+ assert multi_dim % num_heads == 0, 'multi_dim must be divisible by num_heads'
7
+ assert single_dim % num_heads == 0, 'single_dim must be divisible by num_heads'
8
+ super().__init__()
9
+ self.q_proj = nn.Linear(single_dim, single_dim)
10
+ self.k_proj = nn.Linear(single_dim, single_dim)
11
+ self.multi_v_proj = nn.Linear(multi_dim, multi_dim) # D'
12
+ self.single_v_proj = nn.Linear(single_dim, single_dim) # D
13
+
14
+ self.multi_out_proj = nn.Linear(multi_dim, multi_dim) # D'
15
+ self.single_out_proj = nn.Linear(single_dim, single_dim) # D
16
+
17
+ self.multi_dim = multi_dim
18
+ self.single_dim = single_dim
19
+ self.num_heads = num_heads
20
+
21
+ def forward(self, query, key, multi_value, single_value):
22
+ # q, k, multi_v: (T,B,ch,D')
23
+ # single_v: (T,B,1,D)
24
+ query = torch.transpose(query, 0, 1) # (B,T,ch,D')...[32, 150, 4, 64]
25
+ key = torch.transpose(key, 0, 1) # (B,T,ch,D')...[32, 150, 4, 64]
26
+ multi_value = torch.permute(multi_value, (1, 2, 0, 3)) # (B,ch,T,D')...[32, 4, 150, 64]
27
+ single_value = torch.permute(single_value, (1, 2, 0, 3)) # (B,1,T,D)...[32, 1, 150, 256]
28
+ ###########
29
+
30
+ q = torch.split(self.q_proj(query), self.single_dim // self.num_heads, dim=-1) # seq: (B,T,ch,D'/h)
31
+ q = torch.stack(q, dim=1) # (B,h,T,ch,D'/h)...[32, 8, 150, 4, 8]
32
+
33
+ k = torch.split(self.k_proj(key), self.single_dim // self.num_heads, dim=-1) # seq: (B,T,ch,D'/h)
34
+ k = torch.stack(k, dim=1) # (B,h,T,ch,D'/h)...[32, 8, 150, 4, 8]
35
+
36
+ multi_v = torch.split(self.multi_v_proj(multi_value), self.multi_dim // self.num_heads,
37
+ dim=-1) # seq: (B,ch,T,D'/h)
38
+ multi_v = torch.stack(multi_v, dim=1) # (B, h, ch, T, D'/h)...[32, 8, 4, 150, 8]
39
+
40
+ single_v = torch.split(self.single_v_proj(single_value), self.single_dim // self.num_heads,
41
+ dim=-1) # seq: (B,1,T,D/h)
42
+ single_v = torch.stack(single_v, dim=1) # seq: (B,h,1,T,D/h)...[32, 32, 1, 150, 8]
43
+
44
+ q = q.view(*q.shape[:-2], -1) # (B, h, T, ch*D/h)
45
+ k = k.view(*k.shape[:-2], -1) # (B, h, T, ch*D/h)
46
+ normalizer = torch.sqrt(torch.Tensor([float(q.shape[-1])]).to(q.device))
47
+
48
+ sim_mat = torch.matmul(q, torch.transpose(k, -2, -1)) / normalizer # (B, h, T, T)
49
+ att_mat = torch.unsqueeze(nn.functional.softmax(sim_mat, dim=-1), 2) # (B, h, 1, T, T)
50
+
51
+ # co-attention
52
+ multi_result = torch.matmul(att_mat, multi_v) # (B, h, ch, T, D'/h)
53
+ single_result = torch.matmul(att_mat, single_v) # (B, h, 1, T, D/h)
54
+
55
+ multi_result = torch.permute(multi_result, (3, 0, 2, 1, 4)) # (T, B, ch, h, D'/h)
56
+ single_result = torch.permute(single_result, (3, 0, 2, 1, 4)) # (T, B, 1, h, D/h)
57
+ multi_result = torch.reshape(multi_result, multi_result.shape[:-2] + (-1,)) # (T, B, ch, D')
58
+ single_result = torch.reshape(single_result, single_result.shape[:-2] + (-1,)) # (T, B, 1, D)
59
+
60
+ multi_result = self.multi_out_proj(multi_result)
61
+ single_result = self.single_out_proj(single_result)
62
+ return multi_result, single_result
63
+
64
+
65
+ class CoAttention(nn.Module):
66
+ def __init__(self, embed_dim=768, single_dim=256, multi_dim=64, n_heads=8, attn_dropout=0.,
67
+ init_mult=1e-2): # , pre_norm=True):
68
+ super().__init__()
69
+ self.init_mult = init_mult
70
+
71
+ self.in_single_proj = nn.Linear(embed_dim, single_dim) # single_dim == D
72
+ self.in_single_ln = nn.LayerNorm(single_dim)
73
+
74
+ self.in_multi_proj = nn.Linear(embed_dim, multi_dim) # multi_dim == D'
75
+ self.in_multi_ln = nn.LayerNorm(multi_dim)
76
+
77
+ self.mca = MultiHeadCoAttention(multi_dim, single_dim, n_heads)
78
+ self.mca_multi_out_ln = nn.LayerNorm(multi_dim)
79
+ self.mca_single_out_ln = nn.LayerNorm(single_dim)
80
+
81
+ # default MHA input: (seq, batch, feature)
82
+ self.cross_frame_mha = nn.MultiheadAttention(single_dim, n_heads, dropout=attn_dropout, bias=True, kdim=None,
83
+ vdim=None)
84
+ self.mha_ln = nn.LayerNorm(single_dim)
85
+
86
+ self.cat_proj = nn.Linear(single_dim + multi_dim, embed_dim)
87
+
88
+ self.miso = False
89
+
90
+ def scale_weights(self):
91
+ self.cat_proj.bias.data *= 0.
92
+ self.cat_proj.weight.data *= self.init_mult
93
+
94
+ def forward(self, x):
95
+ # x: (T,B,ch,F); (150, 32, 4, 768)
96
+ frames, B, chans, feat_dim = x.shape
97
+
98
+ single_x = torch.mean(x,dim=2) # (T,B,F)
99
+ single_x = self.in_single_ln(self.in_single_proj(single_x)).unsqueeze(dim=-2) # (T,B,1,D)
100
+
101
+ multi_x = self.in_multi_ln(self.in_multi_proj(x)) # (T,B,ch,D')
102
+
103
+ # MCA
104
+ multi_mca, single_mca = self.mca(single_x, single_x, multi_x, single_x) # (T,B,ch,D'), (T,B,ch,D)
105
+ single_x = single_x + single_mca
106
+ multi_x = multi_x + multi_mca
107
+ multi_x = self.mca_multi_out_ln(multi_x) # (T,B,ch,D')
108
+ single_x = torch.squeeze(self.mca_single_out_ln(single_x), -2) # (T,B,D)
109
+
110
+ # MHA
111
+ single_mha, _ = self.cross_frame_mha(single_x, single_x, single_x, need_weights=False) # (T, B, D)
112
+ single_x = self.mha_ln(single_mha + single_x)
113
+
114
+ # join representations
115
+ single_x = single_x.unsqueeze(-2) # (T,B,1,D)
116
+ single_x_tile = torch.tile(single_x, (1, 1, chans, 1)) # (T,B,ch,D)
117
+ cat_x = torch.cat([single_x_tile, multi_x], dim=-1) # (T,B,ch,D+D')
118
+ out = self.cat_proj(cat_x) # (T,B,ch,F)
119
+
120
+ return out
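`CoAttention` expects time-major input of shape `(T, B, ch, F)` and returns the same shape. A small runnable sketch with the default projection sizes; the import path is an assumption (inside the package the import is relative):

```python
import torch
from coattention import CoAttention  # assumption: standalone import of this file

T, B, ch, F = 150, 2, 4, 768
coa = CoAttention(embed_dim=F, single_dim=256, multi_dim=64, n_heads=8)

x = torch.randn(T, B, ch, F)   # (frames, batch, channels/speakers, features)
out = coa(x)
print(out.shape)               # torch.Size([150, 2, 4, 768])
```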
config.json ADDED
@@ -0,0 +1,80 @@
1
+ {
2
+ "_name_or_path": "/mnt/matylda5/ipoloka/challenges/NOTSOFAR1-Challenge/exp/train/asru/table2/full_augment/checkpoint-25500",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "additional_layer": false,
6
+ "additional_self_attention_layer": true,
7
+ "apply_fddt_to_n_layers": -1,
8
+ "apply_spec_augment": false,
9
+ "architectures": [
10
+ "DiCoWForConditionalGeneration"
11
+ ],
12
+ "attention_dropout": 0.0,
13
+ "auto_map": {
14
+ "AutoConfig": "config.DiCoWConfig",
15
+ "AutoModelForSpeechSeq2Seq": "modeling_dicow.DiCoWForConditionalGeneration"
16
+ },
17
+ "begin_suppress_tokens": [
18
+ 220,
19
+ 50256
20
+ ],
21
+ "blank_token_id": null,
22
+ "bos_token_id": 50257,
23
+ "classifier_proj_size": 256,
24
+ "contrastive_loss_weight": 0,
25
+ "ctc_loss_reduction": "mean",
26
+ "ctc_weight": 0.3,
27
+ "ctc_zero_infinity": false,
28
+ "d_model": 1280,
29
+ "decoder_attention_heads": 20,
30
+ "decoder_ffn_dim": 5120,
31
+ "decoder_layerdrop": 0.0,
32
+ "decoder_layers": 4,
33
+ "decoder_start_token_id": 50258,
34
+ "dropout": 0.0,
35
+ "encoder_attention_heads": 20,
36
+ "encoder_ffn_dim": 5120,
37
+ "encoder_layerdrop": 0.0,
38
+ "encoder_layers": 32,
39
+ "eos_token_id": 50257,
40
+ "fddt_bias_only": false,
41
+ "fddt_init": "disparagement",
42
+ "fddt_is_diagonal": true,
43
+ "fddt_use_non_target": true,
44
+ "fddt_use_overlap": true,
45
+ "fddt_use_silence": true,
46
+ "fddt_use_target": true,
47
+ "final_dropout": 0.0,
48
+ "forced_decoder_ids": null,
49
+ "init_std": 0.02,
50
+ "is_encoder_decoder": true,
51
+ "is_mt": false,
52
+ "mask_feature_length": 10,
53
+ "mask_feature_min_masks": 0,
54
+ "mask_feature_prob": 0.0,
55
+ "mask_time_length": 10,
56
+ "mask_time_min_masks": 2,
57
+ "mask_time_prob": 0.05,
58
+ "max_source_positions": 1500,
59
+ "max_target_positions": 448,
60
+ "median_filter_width": 7,
61
+ "model_type": "DiCoW",
62
+ "mt_num_speakers": 1,
63
+ "n_soft_prompts": 16,
64
+ "non_target_fddt_value": 0.5,
65
+ "num_hidden_layers": 32,
66
+ "num_mel_bins": 128,
67
+ "pad_token_id": 50257,
68
+ "remove_timestamps_from_ctc": true,
69
+ "scale_embedding": false,
70
+ "scb_layers": -1,
71
+ "scb_method": null,
72
+ "sub_sample": true,
73
+ "torch_dtype": "float32",
74
+ "transformers_version": "4.42.0",
75
+ "use_cache": true,
76
+ "use_fddt": true,
77
+ "use_initial_fddt": true,
78
+ "use_weighted_layer_sum": false,
79
+ "vocab_size": 51866
80
+ }
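The checkpoint's DiCoW-specific switches can be inspected after loading this config; a sketch assuming a local clone of the repository (`trust_remote_code` is required because `auto_map` points at `config.DiCoWConfig`):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("path/to/this/repo", trust_remote_code=True)  # or the Hub repository id
print(cfg.model_type)                                   # "DiCoW"
print(cfg.ctc_weight)                                   # 0.3 -> encoder CTC head and joint decoding enabled
print(cfg.use_fddt, cfg.fddt_is_diagonal, cfg.non_target_fddt_value)  # True True 0.5
```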
config.py ADDED
@@ -0,0 +1,90 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from transformers import WhisperConfig
6
+ from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput, Seq2SeqModelOutput
7
+
8
+
9
+ @dataclass
10
+ class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
11
+ enc_loss: Optional[torch.FloatTensor] = None
12
+ dec_loss: Optional[torch.FloatTensor] = None
13
+ encoder_logits: Optional[torch.FloatTensor] = None
14
+
15
+
16
+ @dataclass
17
+ class BaseModelOutputLogit(BaseModelOutput):
18
+ logits: Optional[torch.FloatTensor] = None
19
+
20
+
21
+ @dataclass
22
+ class Seq2SeqModelOutputLogit(Seq2SeqModelOutput):
23
+ encoder_logits: Optional[torch.FloatTensor] = None
24
+
25
+
26
+ class DiCoWConfig(WhisperConfig):
27
+ """This is a modified version of the `WhisperEncoder` model from the `transformers` library.
28
+ The model has been modified to support CTC loss computation in the forward pass."""
29
+ model_type = "DiCoW"
30
+
31
+ def __init__(
32
+ self,
33
+ ctc_loss_reduction: str = "mean",
34
+ final_dropout: float = 0.0,
35
+ ctc_zero_infinity: bool = False,
36
+ ctc_weight: float = 0.0,
37
+ blank_token_id: Optional[int] = None,
38
+ additional_layer: bool = False,
39
+ additional_self_attention_layer: bool = False,
40
+ sub_sample: bool = False,
41
+ use_fddt: bool = True,
42
+ fddt_is_diagonal: bool = True,
43
+ fddt_bias_only: bool = False,
44
+ fddt_use_silence: bool = True,
45
+ fddt_use_target: bool = True,
46
+ fddt_use_overlap: bool = True,
47
+ fddt_use_non_target: bool = True,
48
+ remove_timestamps_from_ctc: bool = False,
49
+ apply_fddt_to_n_layers: int = -1,
50
+ fddt_init: str = 'non-disturbing', # random, non-disturbing, dispargement
51
+ n_soft_prompts: int = 16,
52
+ mt_num_speakers: int = 1,
53
+ is_mt: bool = False,
54
+ non_target_fddt_value: float = 0.0,
55
+ use_initial_fddt: bool = False,
56
+ scb_method: str = None,
57
+ scb_layers: int = -1,
58
+ contrastive_loss_weight: float = 0.0,
59
+ **kwargs,
60
+ ):
61
+ super().__init__(**kwargs)
62
+ self.ctc_loss_reduction = ctc_loss_reduction
63
+ self.final_dropout = final_dropout
64
+ self.ctc_zero_infinity = ctc_zero_infinity
65
+ self.ctc_weight = ctc_weight
66
+ self.blank_token_id = blank_token_id
67
+ self.additional_layer = additional_layer
68
+ self.additional_self_attention_layer = additional_self_attention_layer
69
+ self.sub_sample = sub_sample
70
+ self.use_fddt = use_fddt
71
+ self.fddt_is_diagonal = fddt_is_diagonal
72
+ self.fddt_bias_only = fddt_bias_only
73
+ self.fddt_use_silence = fddt_use_silence
74
+ self.fddt_use_target = fddt_use_target
75
+ self.fddt_use_overlap = fddt_use_overlap
76
+ self.fddt_use_non_target = fddt_use_non_target
77
+ self.remove_timestamps_from_ctc = remove_timestamps_from_ctc
78
+ self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
79
+ self.fddt_init = fddt_init
80
+ self.n_soft_prompts = n_soft_prompts
81
+ self.mt_num_speakers = mt_num_speakers
82
+ self.non_target_fddt_value = non_target_fddt_value
83
+ self.use_initial_fddt = use_initial_fddt
84
+ self.scb_method = scb_method
85
+ self.scb_layers = scb_layers
86
+ self.contrastive_loss_weight = contrastive_loss_weight
87
+ self.is_mt = is_mt
88
+
89
+
90
+ _HIDDEN_STATES_START_POSITION = 2
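`DiCoWConfig` is a `WhisperConfig` plus the switches above, so it can also be built programmatically. A small sketch with Whisper-base-like dimensions; the values are illustrative, not those of the shipped checkpoint:

```python
from config import DiCoWConfig  # assumption: standalone import; in the repo it is `from .config import DiCoWConfig`

cfg = DiCoWConfig(
    d_model=512,
    encoder_layers=6,
    encoder_attention_heads=8,
    ctc_weight=0.3,              # > 0 activates the encoder CTC head in DiCoWEncoder
    use_fddt=True,
    fddt_is_diagonal=True,
    non_target_fddt_value=0.5,
)
print(cfg.model_type)  # "DiCoW"
```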
contrastive_loss.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ContrastiveLoss(nn.Module):
7
+ def __init__(self, temperature=.25, distance_metric='cosine'):
8
+ super(ContrastiveLoss, self).__init__()
9
+ self.temperature = temperature
10
+ self.distance_metric = distance_metric
11
+
12
+ def compute_similarity(self, embeddings):
13
+ if self.distance_metric == 'cosine':
14
+ embeddings = F.normalize(embeddings, p=2, dim=-1) # [B, 2T, D]
15
+ sim = torch.matmul(embeddings, embeddings.transpose(-1, -2)) # [B, 2T, 2T]
16
+ else:
17
+ raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
18
+ return sim / self.temperature
19
+
20
+ def pairwise_and_no_diag(self, m):
21
+ m_i = m.unsqueeze(2) # [B, T, 1]
22
+ m_j = m.unsqueeze(1) # [B, 1, T]
23
+ out = m_i & m_j # [B, T, T]
24
+ diag = torch.eye(m.size(1), dtype=torch.bool, device=m.device).unsqueeze(0)
25
+ return out & ~diag
26
+
27
+ def forward(self, embeddings, pos_indicator_mask):
28
+ """
29
+ Args:
30
+ embeddings: [B, 2T, D]
31
+ pos_indicator_mask: [B, 2T] - boolean, positions that belong to each speaker group
32
+ Returns:
33
+ Scalar contrastive loss
34
+ """
35
+ B, two_T, D = embeddings.shape
36
+ T = two_T // 2
37
+ sim = self.compute_similarity(embeddings) # [B, 2T, 2T]
38
+
39
+ # Split input mask
40
+ m1 = pos_indicator_mask[:, :T] # [B, T]
41
+ m2 = pos_indicator_mask[:, T:] # [B, T]
42
+
43
+ # Positive mask (same speaker pairs, diagonal excluded)
44
+ pos_block1 = self.pairwise_and_no_diag(m1) # [B, T, T]
45
+ pos_block2 = self.pairwise_and_no_diag(m2) # [B, T, T]
46
+ pos_mask = torch.cat([
47
+ torch.cat([pos_block1, torch.zeros_like(pos_block1)], dim=2), # [B, T, 2T]
48
+ torch.cat([torch.zeros_like(pos_block2), pos_block2], dim=2) # [B, T, 2T]
49
+ ], dim=1) # [B, 2T, 2T]
50
+
51
+ # Negative mask (cross-speaker pairs where both are active)
52
+ cross = m1.unsqueeze(2) & m2.unsqueeze(1) # [B, T, T]
53
+ neg_mask = torch.cat([
54
+ torch.cat([torch.zeros_like(cross), cross], dim=2), # [B, T, 2T]
55
+ torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2) # [B, T, 2T]
56
+ ], dim=1) # [B, 2T, 2T]
57
+
58
+ # Identity mask (exclude [i, i] self-pairs)
59
+ identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0) # [1, 2T, 2T]
60
+ pos_mask &= ~identity_mask
61
+ neg_mask &= ~identity_mask
62
+
63
+ # Fully vectorized InfoNCE computation
64
+ if pos_mask.any():
65
+ # Compute exp(similarities) for numerical stability
66
+ exp_sim = torch.exp(sim) # [B, 2T, 2T]
67
+
68
+ # Get positive similarities
69
+ pos_sim = sim[pos_mask] # [num_pos_pairs]
70
+ pos_exp = torch.exp(pos_sim) # [num_pos_pairs]
71
+
72
+ # For each position, sum the exponentials of its negatives
73
+ neg_exp_avg = 10 * torch.mean(exp_sim * neg_mask.float(), dim=2) # [B, 2T]
74
+
75
+ # Get the negative sums corresponding to each positive pair
76
+ pos_indices = torch.nonzero(pos_mask, as_tuple=False) # [num_pos_pairs, 3]
77
+ batch_idx = pos_indices[:, 0] # [num_pos_pairs]
78
+ row_idx = pos_indices[:, 1] # [num_pos_pairs]
79
+
80
+ # Get negative sums for each positive pair's anchor
81
+ neg_avgs_for_pos = neg_exp_avg[batch_idx, row_idx] # [num_pos_pairs]
82
+
83
+ # Compute denominators: exp(pos) + sum(exp(neg)) for each positive pair
84
+ denominators = pos_exp + neg_avgs_for_pos # [num_pos_pairs]
85
+
86
+ # InfoNCE loss: -log(exp(pos) / denominator)
87
+ loss = -torch.log(pos_exp / denominators)
88
+ total_loss = loss.mean()
89
+ # logits = sim
90
+ # logits = logits.masked_fill(~(pos_mask | neg_mask), float('-inf')) # Mask out irrelevant pairs
91
+ #
92
+ # log_probs = F.log_softmax(logits, dim=-1) # [B, 2T, 2T]
93
+ # pos_log_probs = log_probs[pos_mask]
94
+ # total_loss = -pos_log_probs.mean()
95
+ else:
96
+ # No positive pairs found, return zero loss
97
+ total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
98
+ return total_loss
99
+
100
+
101
+
102
+ # Example usage and testing
103
+ def create_example_data():
104
+ """Create example data for testing."""
105
+ B, T, D = 2, 3, 64
106
+
107
+ # Create random embeddings
108
+ embeddings = torch.randn(B, T, D)
109
+
110
+ # Create example positive and negative masks
111
+ pos_mask = torch.zeros(B, T, B, T, dtype=torch.bool)
112
+ neg_mask = torch.zeros(B, T, B, T, dtype=torch.bool)
113
+
114
+ # Example: make adjacent time steps positive pairs
115
+ for b in range(B):
116
+ for t in range(T - 1):
117
+ pos_mask[b, t, b, t + 1] = True
118
+ pos_mask[b, t + 1, b, t] = True
119
+
120
+ # Example: make cross-batch pairs negative
121
+ for b1 in range(B):
122
+ for b2 in range(B):
123
+ if b1 != b2:
124
+ neg_mask[b1, :, b2, :] = True
125
+
126
+ pair_masks = torch.stack([pos_mask, neg_mask])
127
+
128
+ return embeddings, pair_masks
129
+
130
+
131
+ if __name__ == "__main__":
132
+ # Test the implementation
133
+ embeddings, pair_masks = create_example_data()
134
+
135
+ # Initialize loss function
136
+ contrastive_loss = ContrastiveLoss(temperature=0.07, distance_metric='cosine')
137
+
138
+ # Compute loss
139
+ loss = contrastive_loss(embeddings, pair_masks)
140
+ print(f"Contrastive Loss: {loss.item():.4f}")
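Note that the bundled `__main__` example builds 4-D pair masks, while `ContrastiveLoss.forward` expects a single boolean `pos_indicator_mask` of shape `[B, 2T]` (first half for one speaker, second half for the other). A sketch that matches the forward signature; the import path is an assumption:

```python
import torch
from contrastive_loss import ContrastiveLoss  # assumption: standalone import of this file

B, T, D = 2, 5, 16
embeddings = torch.randn(B, 2 * T, D)                 # rows [0:T] speaker one, rows [T:2T] speaker two
pos_indicator_mask = torch.zeros(B, 2 * T, dtype=torch.bool)
pos_indicator_mask[:, :3] = True                      # a few active frames for speaker one
pos_indicator_mask[:, T:T + 3] = True                 # a few active frames for speaker two

loss = ContrastiveLoss(temperature=0.25)(embeddings, pos_indicator_mask)
print(loss.item())
```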
decoding.py ADDED
@@ -0,0 +1,397 @@
1
+ # pylint: skip-file
2
+ # Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
3
+ import itertools as it
4
+ from typing import List
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from transformers import LogitsProcessor, PreTrainedTokenizer
9
+
10
+
11
+ class CTCPrefixScore(object):
12
+ """Compute CTC label sequence scores
13
+
14
+ which is based on Algorithm 2 in WATANABE et al.
15
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
16
+ but extended to efficiently compute the label probabilities for multiple
17
+ hypotheses simultaneously
18
+ See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
19
+ Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
20
+ """
21
+
22
+ def __init__(self, x, blank, eos):
23
+ self.logzero = -1e10
24
+ self.blank = blank
25
+ self.eos = eos
26
+ self.input_length = x.shape[1]
27
+ self.batch_size = x.shape[0]
28
+ self.x = x
29
+ self.device = x.device
30
+
31
+ # Preallocate `r` and `xs` tensors
32
+ # `num_labels` will be set dynamically in __call__ but preallocated with maximum capacity
33
+ self.max_num_labels = x.shape[2] # Set to a max value that can be dynamically resized
34
+ self.r = torch.full((self.batch_size, self.input_length, 2, self.max_num_labels), self.logzero,
35
+ device=self.device)
36
+ self.xs = torch.full((self.batch_size, self.input_length, self.max_num_labels), self.logzero,
37
+ device=self.device)
38
+
39
+ def initial_state(self):
40
+ """Obtain an initial CTC state."""
41
+ # Create initial CTC state tensor and use in-place operations to fill
42
+ r = torch.full((self.batch_size, self.input_length, 2), self.logzero, device=self.device)
43
+ r[..., 1] = torch.cumsum(self.x[..., self.blank], dim=1)
44
+ s = torch.zeros((self.batch_size, 1), device=self.device)
45
+
46
+ return r, s
47
+
48
+ def _resize_tensors(self, number_of_current_samples, num_labels):
49
+ if self.r.shape[0] != number_of_current_samples:
50
+ self.r = self.r[:number_of_current_samples, ...]
51
+ self.xs = self.xs[:number_of_current_samples, ...]
52
+
53
+ if self.r.shape[3] != num_labels:
54
+ self.r = self.r[:, :, :, :num_labels].fill_(self.logzero)
55
+ self.xs = self.xs[:, :, :num_labels].fill_(self.logzero)
56
+ else:
57
+ self.r.fill_(self.logzero)
58
+ self.xs.fill_(self.logzero)
59
+
60
+ def _initialize_r(self, decoded_len):
61
+ mask = (decoded_len == 0)
62
+ self.r[mask, 0, 0, :] = self.xs[mask, 0]
63
+
64
+ def _compute_log_phi(self, r_sum, cs, last, decoded_len, r_prev):
65
+ # Expand r_sum for num_labels and initialize log_phi
66
+ log_phi = r_sum[..., None].expand(-1, -1, cs.shape[1])
67
+
68
+ # Create mask for cases where `decoded_len > 0` and to identify where `c == last[i]` for all `i`
69
+ non_zero_mask = (decoded_len > 0)
70
+ label_match_mask = (cs == last.unsqueeze(1))
71
+
72
+ # Update log_phi where both `decoded_len > 0` and `c == last[i]`
73
+ log_phi = torch.where((non_zero_mask.unsqueeze(1) & label_match_mask)[:, None, :], r_prev[..., 1:2], log_phi)
74
+ return log_phi
75
+
76
+ def _compute_log_psi(self, decoded_len, log_phi, x_current):
77
+ """This function computes forward probabilities log(r_t^n(h)), log(r_t^b(h)),
78
+ and log prefix probabilities log(psi) for all labels in the batch.
79
+
80
+ :param decoded_len: tensor of shape (batch_size,) containing the length of the decoded sequence
81
+ :param log_phi: tensor of shape (batch_size, input_length, num_labels) containing the forward probabilities
82
+ :param x_current: tensor of shape (batch_size, input_length, num_labels) containing the input frame
83
+
84
+ :return log_psi: tensor of shape (batch_size,num_labels) containing the log prefix probabilities
85
+ """
86
+ B, T, V = log_phi.shape
87
+ start = torch.clamp(decoded_len, min=1) # Ensure start is at least 1 to avoid out-of-bounds
88
+
89
+ # Initialize log_psi with the start position of r[:, start - 1, 0, :]
90
+ log_psi = self.r[torch.arange(B), start - 1, 0, :]
91
+
92
+ # Mask for handling sequence lengths based on decoded_len
93
+ mask_t = torch.arange(1, T, device=decoded_len.device).expand(B, T - 1) >= decoded_len.unsqueeze(1)
94
+
95
+ # Accumulate log_psi only up to the last valid time step for each sequence
96
+ log_psi = torch.logaddexp(log_psi, torch.logsumexp(
97
+ torch.where(mask_t.unsqueeze(-1), log_phi[:, :-1] + self.xs[:, 1:], self.logzero), dim=1))
98
+
99
+ start = torch.clamp(decoded_len, 1)
100
+
101
+ # TODO: Vectorize this loop by compute suffix xs and multiplying with log_phi
102
+ # xs = self.xs[:,1:,:].clone()
103
+ # xs_cum = torch.cumsum(xs, dim=1)
104
+ # xs_cum_expanded = xs_cum.unsqueeze(1).repeat(1, T-1, 1, 1)
105
+ # xs_u = (xs_cum_expanded - torch.nn.functional.pad(xs_cum[:,:-1,:], (0,0,1,0), value=0).unsqueeze(2).repeat(1, 1,T-1,1)).permute(0,2,1,3)
106
+ #
107
+ # phis_new = log_phi[:,:-1].clone()
108
+ # phis_new[:, 0] = torch.logaddexp(phis_new[:, 0], self.r[:, 0, 0, :])
109
+ # phis_new = phis_new.unsqueeze(1).repeat(1, T-1, 1, 1)
110
+ # causal_mask = torch.ones((T-1,T-1), dtype=torch.bool, device=self.device).tril().unsqueeze(0).unsqueeze(-1).repeat(B,1,1,1)
111
+ # mask = causal_mask & mask_t.unsqueeze(2).unsqueeze(-1)
112
+ # r_zero = torch.logsumexp(torch.where(mask, xs_u + phis_new, self.logzero), dim=2)
113
+ # self.r[:,1:,0] = r_zero
114
+
115
+ for t in range(start.min(), self.input_length):
116
+ should_decode = decoded_len <= t
117
+ self.r[:, t, 0] = torch.logaddexp(self.r[:, t - 1, 0],
118
+ log_phi[:, t - 1]) + self.xs[:, t]
119
+ self.r[:, t, 1] = (
120
+ torch.logaddexp(self.r[:, t - 1, 0], self.r[:, t - 1, 1]) + x_current[:, t, self.blank][:, None]
121
+ )
122
+ if ~should_decode.any():
123
+ self.r[:, t] = torch.where(should_decode.unsqueeze(-1).unsqueeze(-1), self.r[:, t], self.logzero)
124
+
125
+ return log_psi
126
+
127
+ def _update_log_psi_with_eos(self, log_psi, cs, r_sum):
128
+ # Update log_psi for eos positions
129
+ eos_mask = (cs == self.eos)
130
+ log_psi[eos_mask] = r_sum[:, -1].unsqueeze(1).expand_as(log_psi)[eos_mask]
131
+
132
+ # Exclude blank probabilities if eos is not the blank
133
+ if self.eos != self.blank:
134
+ blank_mask = (cs == self.blank)
135
+ log_psi[blank_mask] = self.logzero
136
+ return log_psi
137
+
138
+ def __call__(self, y, cs, decoded_len, samples_to_be_decoded, r_prev):
139
+ """Compute CTC prefix scores for next labels
140
+
141
+ :param y : prefix label sequence
142
+ :param cs : array of next labels
143
+ :param r_prev: previous CTC state
144
+ :return ctc_scores, ctc_states
145
+ """
146
+ # initialize CTC states
147
+ # output_length = y.shape[1] - 1 # ignore sos
148
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
149
+ # that corresponds to r_t^n(h) and r_t^b(h).
150
+
151
+ # Dynamically resize r and xs to match num_labels if necessary
152
+ num_labels = cs.shape[1]
153
+ number_of_current_samples = cs.shape[0]
154
+ self._resize_tensors(number_of_current_samples, num_labels)
155
+
156
+ # Create a view of the current input frame
157
+ x_current = self.x[samples_to_be_decoded]
158
+ self.xs = torch.gather(x_current, 2, cs.unsqueeze(1).expand(-1, self.input_length, -1))
159
+
160
+ # Initialize r for the first frame
161
+ self._initialize_r(decoded_len)
162
+
163
+ # prepare forward probabilities for the last label
164
+ r_sum = torch.logaddexp(r_prev[:, :, 0], r_prev[:, :, 1]) # log(r_t^n(g) + r_t^b(g))
165
+ last = y[:, -1]
166
+
167
+ # precompute log_phi
168
+ log_phi = self._compute_log_phi(r_sum, cs, last, decoded_len, r_prev)
169
+
170
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
171
+ # and log prefix probabilities log(psi)
172
+ log_psi = self._compute_log_psi(decoded_len, log_phi, x_current)
173
+
174
+ # get P(...eos|X) that ends with the prefix itself
175
+ log_psi = self._update_log_psi_with_eos(log_psi, cs, r_sum)
176
+
177
+ # return the log prefix probability and CTC states, where the label axis
178
+ # of the CTC states is moved to the first axis to slice it easily
179
+ return log_psi, self.r
180
+
181
+
182
+ class CTCRescorerLogitsProcessor(LogitsProcessor):
183
+ def __init__(
184
+ self,
185
+ encoder_logits: torch.FloatTensor,
186
+ encoder_output_lens: torch.Tensor,
187
+ blank_token_id: int,
188
+ pad_token_id: int,
189
+ eos_token_id: int,
190
+ bos_token_id: int,
191
+ tokenizer: PreTrainedTokenizer,
192
+ ctc_margin: int,
193
+ ctc_weight: float,
194
+ num_beams: int,
195
+ debug: bool = False,
196
+ ctc_tokens_to_score: int = 500
197
+ ):
198
+ super().__init__()
199
+ same_logits = torch.tensor(list((tokenizer.upper_cased_tokens.items())))
200
+
201
+ logits = torch.nn.functional.log_softmax(encoder_logits, dim=-1)
202
+ logits[..., same_logits[:, 1]] = logits[..., same_logits[:, 0]]
203
+
204
+ self.logits = logits
205
+
206
+ self.ctc_prefix_scorer = CTCPrefixScore(
207
+ self.logits,
208
+ blank_token_id,
209
+ eos_token_id,
210
+ )
211
+ self.batch_size = logits.shape[0]
212
+ self.input_length = logits.shape[1]
213
+ self.num_tokens = logits.shape[2]
214
+ self.device = logits.device
215
+ self.ctc_weight = ctc_weight
216
+ self.num_beams = num_beams
217
+ self.ctc_state_prev, self.ctc_score_prev = self.ctc_prefix_scorer.initial_state()
218
+ self.eos_token_id = eos_token_id
219
+ self.bos_token_id = bos_token_id
220
+ self.tokenizer = tokenizer
221
+ self.pad_token_id = pad_token_id
222
+ self.blank_token_id = blank_token_id
223
+ self.debug = False
224
+ self.first_timestamp_token_id = tokenizer.get_vocab()["<|0.00|>"]
225
+ self.tmp_ctc_scores = torch.empty((self.batch_size, self.num_tokens - 1), device=self.device)
226
+ self.tmp_ctc_states = torch.empty((self.batch_size, self.num_tokens - 1, self.input_length, 2),
227
+ device=self.device)
228
+ self.ctc_tokens_to_score = ctc_tokens_to_score
229
+
230
+ def analyze_predictions(self,
231
+ scores, ctc_scores, next_token_scores, input_ids, k=10):
232
+ print("\n" + "#" * 100)
233
+
234
+ batch_size = input_ids.shape[0]
235
+
236
+ best_att_ids = scores.topk(k=k, dim=1)
237
+ ctc_scores[:, self.first_timestamp_token_id:] = self.ctc_prefix_scorer.logzero
238
+ best_ctc_ids = ctc_scores.topk(k=k, dim=1)
239
+ best_ids = next_token_scores.topk(k=k, dim=1)
240
+
241
+ decoded_prefixes = self.tokenizer.batch_decode(
242
+ input_ids, decode_with_timestamps=True, skip_special_tokens=False
243
+ )
244
+
245
+ def prepare_and_decode(best_ids_tensor):
246
+ new_tensor = torch.zeros((batch_size, k * 2), dtype=torch.long)
247
+ new_tensor[:, 0::2] = best_ids_tensor.indices
248
+ new_tensor[:, 1::2] = self.tokenizer.vocab['#']
249
+
250
+ # Flatten to (batch_size * k, 2)
251
+ flat_tensor = new_tensor.view(-1, 2)
252
+ decoded = self.tokenizer.batch_decode(
253
+ flat_tensor, decode_with_timestamps=True, skip_special_tokens=False
254
+ )
255
+ # Reshape back to (batch_size, k)
256
+ decoded = [(decoded[i * k:(i + 1) * k]) for i in range(batch_size)]
257
+ return decoded
258
+
259
+ decoded_att = prepare_and_decode(best_att_ids)
260
+ decoded_ctc = prepare_and_decode(best_ctc_ids)
261
+ decoded_next = prepare_and_decode(best_ids)
262
+
263
+ for idx in range(batch_size):
264
+ print("-" * 80)
265
+ print(f"HYPOTHESIS {idx}")
266
+ print("\nPREFIX:")
267
+ print(decoded_prefixes[idx])
268
+
269
+ def print_with_pandas(tokens, scores, title):
270
+ df = pd.DataFrame([tokens, [f"{s.item():.2f}" for s in scores]])
271
+ df.index = [f"{title}", "Score"]
272
+ print(f"\n{title}:")
273
+ print(df.to_string(index=True, header=False))
274
+
275
+ print_with_pandas(decoded_att[idx], best_att_ids.values[idx], "ATT_TOKENS")
276
+ print_with_pandas(decoded_ctc[idx], best_ctc_ids.values[idx], "CTC_TOKENS")
277
+ print_with_pandas(decoded_next[idx], best_ids.values[idx], "NEXT_TOKENS")
278
+
279
+ print(f"\nCTC_EOS: {ctc_scores[idx, self.tokenizer.eos_token_id].item():.2f}")
280
+ print()
281
+
282
+ print("#" * 100)
283
+
284
+ def update_state(self, best_ids, beam_idx):
285
+ mask = best_ids < self.first_timestamp_token_id
286
+ self.ctc_state_prev = torch.where(mask.unsqueeze(-1).unsqueeze(-1),
287
+ self.tmp_ctc_states[beam_idx, best_ids],
288
+ self.ctc_state_prev[beam_idx])
289
+ self.ctc_score_prev = torch.where(mask.unsqueeze(-1),
290
+ self.tmp_ctc_scores[beam_idx, best_ids].unsqueeze(-1),
291
+ self.ctc_score_prev[beam_idx])
292
+
293
+ def __call__(self, input_ids_orig: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
294
+ input_ids = input_ids_orig.clone()
295
+
296
+ # Remove prefix from CTC scoring
297
+ if (input_ids[:, 0] != self.bos_token_id).any():
298
+ input_ids = torch.stack(
299
+ [row[(row == self.bos_token_id).nonzero(as_tuple=True)[0].item():] for row in input_ids])
300
+
301
+ # Remove task/lang/timestamp tokens from input_ids
302
+ input_prefix_len = len(self.tokenizer.prefix_tokens)
303
+ if input_prefix_len > 1:
304
+ input_ids = input_ids[:, input_prefix_len - 1:]
305
+
306
+ # Setup the first token to be the blank token(sos)
307
+ input_ids[:, 0] = self.blank_token_id
308
+
309
+ # If the last token in input_ids is a timestamp, replicate the last non-timestamp token (which could even be the first token)
310
+ decoded_len = torch.logical_and(input_ids <= self.first_timestamp_token_id,
311
+ input_ids != self.blank_token_id).sum(dim=1)
312
+ mask = torch.logical_and(input_ids[:, -1] >= self.first_timestamp_token_id,
313
+ input_ids[:, -1] != self.blank_token_id)
314
+ last_non_timestamp_token = torch.gather(input_ids, 1,
315
+ torch.logical_or(input_ids < self.first_timestamp_token_id,
316
+ input_ids == self.blank_token_id).sum(dim=1,
317
+ keepdim=True) - 1)
318
+ input_ids[mask, -1] = last_non_timestamp_token[mask, 0]
319
+
320
+ # If there is no eos token in the last position, we need to continue decoding
321
+ to_be_decoded = input_ids[:, -1] != self.eos_token_id
322
+ self.tmp_ctc_scores[:] = self.ctc_prefix_scorer.logzero
323
+
324
+ input_ids_local = input_ids[to_be_decoded]
325
+ ids_to_score = torch.topk(scores[:, :self.first_timestamp_token_id], k=self.ctc_tokens_to_score).indices
326
+
327
+ # always score the EOS token; if it is not among the selected ids, place it in the last slot
328
+ is_eos_present = (ids_to_score == self.eos_token_id).any(dim=1)
329
+ ids_to_score[~is_eos_present, self.ctc_tokens_to_score - 1] = self.eos_token_id
330
+
331
+ decoded_len_local = decoded_len[to_be_decoded]
332
+
333
+ ctc_scores_local, ctc_states_local = self.ctc_prefix_scorer(input_ids_local, ids_to_score[to_be_decoded],
334
+ decoded_len_local, to_be_decoded,
335
+ self.ctc_state_prev[to_be_decoded])
336
+
337
+ # As the CTC scorer might run on subset of samples, we need to scatter the results back to the original batch
338
+ self.tmp_ctc_scores[to_be_decoded] = (self.tmp_ctc_scores[to_be_decoded]
339
+ .scatter(1, ids_to_score[to_be_decoded], ctc_scores_local))
340
+ self.tmp_ctc_states[to_be_decoded] = (self.tmp_ctc_states[to_be_decoded].permute(0, 2, 3, 1)
341
+ .scatter(3, ids_to_score[to_be_decoded].unsqueeze(1).unsqueeze(1)
342
+ .repeat(1, *ctc_states_local.shape[1:3], 1), ctc_states_local)
343
+ .permute(0, 3, 1, 2))
344
+
345
+ # Set the CTC score for the timestamp tokens to the maximum to prefer them over the rest
346
+ self.tmp_ctc_scores[:, self.first_timestamp_token_id:] = self.tmp_ctc_scores.max(dim=1).values[:, None]
347
+ ctc_scores = self.tmp_ctc_scores - self.ctc_score_prev
348
+
349
+ next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
350
+
351
+ if self.debug:
352
+ self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids_orig)
353
+
354
+ return next_token_scores
355
+
356
+
357
+ class LogSoftmaxProcessor(LogitsProcessor):
358
+ def __init__(
359
+ self,
360
+ ):
361
+ super().__init__()
362
+
363
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
364
+ scores = torch.nn.functional.log_softmax(scores, dim=-1)
365
+ return scores
366
+
367
+
368
+ class GreedyCTCDecoder(torch.nn.Module):
369
+ def __init__(self, tokenizer, blank=0):
370
+ super().__init__()
371
+ self.blank = blank
372
+ self.tokenizer = tokenizer
373
+
374
+ def forward(self, emission: torch.Tensor) -> List[str]:
375
+ """Given a sequence emission over labels, get the best path
376
+ Args:
377
+ emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
378
+
379
+ Returns:
380
+ List[str]: The resulting transcript
381
+ """
382
+ indices = torch.argmax(emission, dim=-1) # [num_seq,]
383
+ indices = [torch.unique_consecutive(index, dim=-1) for index in indices]
384
+ indices = [index[index != self.blank] for index in indices]
385
+ indices = torch.nn.utils.rnn.pad_sequence(indices, batch_first=True,
386
+ padding_value=self.tokenizer.pad_token_id)
387
+ indices[indices >= len(self.tokenizer)] = self.tokenizer.unk_token_id
388
+ return indices
389
+
390
+
391
+ def ctc_greedy_decode(logits: torch.Tensor, blank, pad_token_id) -> torch.Tensor:
392
+ idxs = torch.argmax(logits, dim=-1)
393
+ for i, prediction in enumerate(idxs):
394
+ deduplicated = [k for k, g in it.groupby(prediction) if k != blank]
395
+ idxs[i, : len(deduplicated)] = torch.tensor(deduplicated)
396
+ idxs[i, len(deduplicated):] = pad_token_id
397
+ return idxs
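`CTCRescorerLogitsProcessor` needs the full tokenizer and beam-search plumbing, but the standalone greedy helper can be exercised directly. A sketch with toy CTC emissions; the import path is an assumption:

```python
import torch
from decoding import ctc_greedy_decode  # assumption: standalone import of this file

vocab_size, blank_id, pad_id = 7, 0, 6
logits = torch.randn(2, 12, vocab_size)               # (batch, frames, vocab) CTC emissions
ids = ctc_greedy_decode(logits, blank=blank_id, pad_token_id=pad_id)
print(ids.shape)                                      # (2, 12): repeats collapsed, blanks removed, padded with pad_id
```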
encoder.py ADDED
@@ -0,0 +1,268 @@
1
+ import torch
2
+ from torch import nn
3
+ from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
4
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES
5
+
6
+ from .FDDT import FDDT
7
+ from .config import DiCoWConfig
8
+ from .SCBs import SpeakerCommunicationBlock
9
+
10
+
11
+ class DiCoWEncoder(WhisperEncoder):
12
+ config_class = DiCoWConfig
13
+
14
+ def __init__(self, config: DiCoWConfig):
15
+ super().__init__(config)
16
+ self.ctc_weight = config.ctc_weight
17
+ if config.additional_layer and self.ctc_weight > 0.0:
18
+ self.additional_layer = WhisperEncoderLayer(config)
19
+ if config.additional_self_attention_layer and self.ctc_weight > 0.0:
20
+ self.additional_self_attention_layer = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
21
+ embed_dim=config.d_model,
22
+ num_heads=config.encoder_attention_heads,
23
+ dropout=config.attention_dropout,
24
+ config=config,
25
+ )
26
+ if config.sub_sample and self.ctc_weight > 0.0:
27
+ self.subsample_conv1 = nn.Conv1d(
28
+ in_channels=config.d_model,
29
+ out_channels=config.d_model,
30
+ kernel_size=3,
31
+ stride=2,
32
+ padding=1,
33
+ bias=False,
34
+ )
35
+ self.subsample_conv2 = nn.Conv1d(
36
+ in_channels=config.d_model,
37
+ out_channels=config.d_model,
38
+ kernel_size=3,
39
+ stride=2,
40
+ padding=1,
41
+ bias=False,
42
+ )
43
+ if self.ctc_weight > 0.0:
44
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size + 1, bias=False)
45
+ self.final_dropout = nn.Dropout(config.final_dropout)
46
+ if config.use_fddt:
47
+ num_fddts = self.config.apply_fddt_to_n_layers if self.config.apply_fddt_to_n_layers != -1 else len(
48
+ self.layers)
49
+ self.initial_fddt = FDDT(config.d_model,
50
+ non_target_rate=config.non_target_fddt_value,
51
+ is_diagonal=config.fddt_is_diagonal,
52
+ bias_only=config.fddt_bias_only,
53
+ use_silence=config.fddt_use_silence,
54
+ use_target=config.fddt_use_target,
55
+ use_overlap=config.fddt_use_overlap,
56
+ use_non_target=config.fddt_use_non_target,
57
+ use_interaction=False,
58
+ scb_module=None
59
+ # in the initial layers we don't want communication
60
+ )
61
+ num_scbs = (self.config.scb_layers if self.config.scb_layers != -1 else len(
62
+ self.layers)) if self.config.is_mt else 0
63
+ self.scbs_identity_layers = config.encoder_layers - num_scbs
64
+ self.fddts = nn.ModuleList([
65
+ FDDT(config.d_model,
66
+ non_target_rate=1.0,
67
+ is_diagonal=config.fddt_is_diagonal,
68
+ bias_only=config.fddt_bias_only,
69
+ use_silence=config.fddt_use_silence,
70
+ use_target=config.fddt_use_target,
71
+ use_overlap=config.fddt_use_overlap,
72
+ use_non_target=config.fddt_use_non_target,
73
+ use_interaction=i >= self.scbs_identity_layers,
74
+ scb_module=SpeakerCommunicationBlock(config,
75
+ scb_method=config.scb_method) if i >= self.scbs_identity_layers else None,
76
+ )
77
+ for i in range(num_fddts)
78
+ ])
79
+ self.first_task_token = self.config.vocab_size - 30 * 50 - 1 - 6 # 30 seconds of 50 Hz timestamp tokens, -1 to reach the 0.00 timestamp, and -6 for the task tokens
80
+ self.post_init()
81
+
82
+ @classmethod
83
+ def _load_pretrained_model(
84
+ cls,
85
+ model,
86
+ state_dict,
87
+ loaded_keys,
88
+ resolved_archive_file,
89
+ pretrained_model_name_or_path,
90
+ **kwargs
91
+ ):
92
+ for key in list(state_dict.keys()):
93
+ if key.startswith("encoder."):
94
+ state_dict[key[8:]] = state_dict.pop(key)
95
+ loaded_keys.remove(key)
96
+ loaded_keys.append(key[8:])
97
+ output = super()._load_pretrained_model(
98
+ model,
99
+ state_dict,
100
+ loaded_keys,
101
+ resolved_archive_file,
102
+ pretrained_model_name_or_path,
103
+ **kwargs
104
+ )
105
+ return output
106
+
107
+ def get_loss(self, logits, labels):
108
+ if labels.max() >= self.config.vocab_size:
109
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
110
+ if self.config.remove_timestamps_from_ctc:
111
+ labels = torch.nn.utils.rnn.pad_sequence([label[label < self.first_task_token] for label in labels],
112
+ padding_value=-100).T
113
+ input_lengths = torch.full((logits.shape[0],), fill_value=logits.shape[1],
114
+ device=logits.device)
115
+
116
+ # assuming that padded tokens are filled with -100
117
+ # when not being attended to
118
+ labels_mask = labels >= 0
119
+ target_lengths = labels_mask.sum(-1)
120
+ # flattened_targets = labels_enc.masked_select(labels_mask)
121
+
122
+ # ctc_loss doesn't support fp16
123
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
124
+
125
+ with torch.backends.cudnn.flags(enabled=True):
126
+ ctc_loss = nn.functional.ctc_loss(
127
+ log_probs,
128
+ labels,
129
+ input_lengths,
130
+ target_lengths,
131
+ blank=logits.shape[-1] - 1,
132
+ reduction=self.config.ctc_loss_reduction,
133
+ zero_infinity=True,
134
+ )
135
+ return ctc_loss
136
+
137
+ def forward(
138
+ self,
139
+ input_features,
140
+ attention_mask=None,
141
+ head_mask=None,
142
+ output_attentions=None,
143
+ output_hidden_states=None,
144
+ return_dict=None,
145
+ stno_mask=None,
146
+ per_group_sizes=None
147
+ ):
148
+ # For MT-ASR the input has shape (B x S) x F x T;
149
+ # we can use torch.view(B, S, F, -1) to obtain
150
+ # a new tensor with a speaker dimension
151
+ expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
152
+ if input_features.shape[-1] != expected_seq_length:
153
+ if input_features.shape[-1] > expected_seq_length:
154
+ return CausalLMOutput(
155
+ logits=None,
156
+ hidden_states=None,
157
+ attentions=None,
158
+ )
159
+ else:
160
+ raise ValueError(
161
+ f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
162
+ )
163
+
164
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
165
+ output_hidden_states = (
166
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
167
+ )
168
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
169
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
170
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
171
+
172
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
173
+ embed_pos = self.embed_positions.weight
174
+
175
+ if self.config.use_fddt:
176
+ inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask)
177
+
178
+ hidden_states = inputs_embeds + embed_pos
179
+
180
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
181
+
182
+ encoder_states = () if output_hidden_states else None
183
+ all_attentions = () if output_attentions else None
184
+
185
+ # check if head_mask has a correct number of layers specified if desired
186
+ if head_mask is not None:
187
+ assert head_mask.size()[0] == (
188
+ len(self.layers)
189
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
190
+
191
+ for idx, encoder_layer in enumerate(self.layers):
192
+ if output_hidden_states:
193
+ encoder_states = encoder_states + (hidden_states,)
194
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
195
+ to_drop = False
196
+ if self.training:
197
+ dropout_probability = torch.rand([])
198
+ if dropout_probability < self.layerdrop: # skip the layer
199
+ to_drop = True
200
+
201
+ if self.config.use_fddt and idx < len(self.fddts):
202
+ hidden_states = self.fddts[idx](hidden_states, stno_mask)
203
+
204
+ if to_drop:
205
+ layer_outputs = (None, None)
206
+ else:
207
+ if self.gradient_checkpointing and self.training:
208
+ layer_outputs = self._gradient_checkpointing_func(
209
+ encoder_layer.__call__,
210
+ hidden_states,
211
+ None,
212
+ (head_mask[idx] if head_mask is not None else None),
213
+ output_attentions,
214
+ )
215
+ else:
216
+ layer_outputs = encoder_layer(
217
+ hidden_states,
218
+ None,
219
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
220
+ output_attentions=output_attentions,
221
+ )
222
+
223
+ hidden_states = layer_outputs[0]
224
+
225
+ if output_attentions:
226
+ all_attentions = all_attentions + (layer_outputs[1],)
227
+
228
+ hidden_states = self.layer_norm(hidden_states)
229
+ if output_hidden_states:
230
+ encoder_states = encoder_states + (hidden_states,)
231
+
232
+ if not return_dict:
233
+ outputs = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
234
+ else:
235
+ outputs = BaseModelOutput(
236
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
237
+ )
238
+
239
+ if hasattr(self, "additional_layer"):
240
+ inter_output, = self.additional_layer(
241
+ outputs.last_hidden_state,
242
+ attention_mask=None,
243
+ output_attentions=output_attentions,
244
+ layer_head_mask=None,
245
+ )
246
+ elif hasattr(self, "additional_self_attention_layer"):
247
+ inter_output, _, __ = self.additional_self_attention_layer(
248
+ outputs.last_hidden_state,
249
+ attention_mask=None,
250
+ output_attentions=output_attentions,
251
+ layer_head_mask=None,
252
+ )
253
+ else:
254
+ inter_output = outputs.last_hidden_state
255
+
256
+ inter_output = self.final_dropout(inter_output)
257
+ if hasattr(self, "subsample_conv2"):
258
+ inter_output = self.subsample_conv2(self.subsample_conv1(inter_output.transpose(1, 2))).transpose(1, 2)
259
+ if self.ctc_weight > 0.0:
260
+ logits = self.lm_head(inter_output)
261
+ else:
262
+ logits = None
263
+
264
+ return CausalLMOutput(
265
+ logits=logits,
266
+ hidden_states=outputs.hidden_states,
267
+ attentions=outputs.attentions,
268
+ )
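
The encoder's target-speaker conditioning is carried entirely by `stno_mask`: the FDDT modules index its second dimension as silence / target / non-target / overlap and broadcast each per-frame channel over the model dimension. Below is a rough shape sketch of a call; it assumes an already-instantiated `encoder` (e.g. loaded from this checkpoint), and the mel-bin and frame counts are Whisper-style defaults rather than values read from this repository's config.json.

import torch

# `encoder` is assumed to be a loaded DiCoWEncoder instance.
# 30 s of log-mel input: 3000 mel frames reduce to 1500 encoder frames (stride-2 conv stack);
# n_mels is 128 for large-v3-style checkpoints and 80 for older Whisper variants.
B, n_mels, T_mel, T_enc = 1, 128, 3000, 1500
input_features = torch.randn(B, n_mels, T_mel)

# STNO diarization mask: one channel each for Silence, Target, Non-target, Overlap,
# given per encoder frame and broadcast over d_model inside the FDDT modules.
stno_mask = torch.zeros(B, 4, T_enc)
stno_mask[:, 1, :] = 1.0  # e.g. treat the whole window as target-speaker speech

out = encoder(input_features, stno_mask=stno_mask)
# out.logits holds the frame-level CTC logits when ctc_weight > 0, otherwise None

For multi-talker input the batch is laid out as (B x S) speaker-specific copies of the same audio, each with its own STNO mask, as noted in the comment at the top of `forward`.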
generation.py ADDED
@@ -0,0 +1,1768 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+ from typing import Iterator
4
+ import warnings
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+ from decimal import Decimal, ROUND_HALF_UP
14
+
15
+ from transformers import LogitsProcessorList, SuppressTokensLogitsProcessor, \
16
+ SuppressTokensAtBeginLogitsProcessor
17
+ from transformers.generation.configuration_utils import GenerationConfig
18
+ from transformers.generation.configuration_utils import GenerationMode
19
+ from transformers.generation.logits_process import (
20
+ LogitsProcessorList,
21
+ SuppressTokensAtBeginLogitsProcessor,
22
+ SuppressTokensLogitsProcessor, )
23
+ from transformers.generation.logits_process import WhisperNoSpeechDetection
24
+ from transformers.generation.stopping_criteria import (
25
+ StoppingCriteriaList,
26
+ )
27
+ from transformers.generation.utils import GenerateBeamOutput, BeamScorer, GenerateBeamDecoderOnlyOutput, \
28
+ stack_model_outputs, GenerateBeamEncoderDecoderOutput, _split_model_inputs, GenerateNonBeamOutput, \
29
+ GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
30
+ from transformers.modeling_outputs import BaseModelOutput
31
+ from transformers.models.whisper.modeling_whisper import (
32
+ WhisperForConditionalGeneration,
33
+ )
34
+ from transformers.models.whisper.generation_whisper import _get_attr_from_logit_processors, _pad_to_max_length
35
+ from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
36
+ from transformers.utils import logging
37
+
38
+ from .utils import WhisperTimeStampLogitsProcessorCustom
39
+ from .decoding import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
40
+
41
+ logging.set_verbosity_debug()
42
+ logger = logging.get_logger("transformers")
43
+
44
+
45
+ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
46
+ def _prepare_encoder_decoder_kwargs_for_generation(
47
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
48
+ ) -> Dict[str, Any]:
49
+ # self.encoder_output_lens = self._get_feat_extract_output_lengths(
50
+ # model_kwargs['attention_mask_enc'].sum(dim=1)
51
+ # ).int()
52
+ generation_config.output_hidden_states = True
53
+
54
+ # pylint: disable=no-member
55
+ model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
56
+ inputs_tensor, model_kwargs, model_input_name, generation_config
57
+ )
58
+ self.encoder_logits = model_kwargs["encoder_outputs"].logits
59
+
60
+ return model_kwargs
61
+
62
+ @staticmethod
63
+ def _expand_inputs_for_generation(
64
+ expand_size: int = 1,
65
+ is_encoder_decoder: bool = False,
66
+ input_ids: Optional[torch.LongTensor] = None,
67
+ **model_kwargs,
68
+ ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
69
+ """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
70
+
71
+ def _expand_dict_for_generation(dict_to_expand):
72
+ for key in dict_to_expand:
73
+ if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) and key != "loss":
74
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
75
+ return dict_to_expand
76
+
77
+ if input_ids is not None:
78
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
79
+
80
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
81
+
82
+ if is_encoder_decoder:
83
+ if model_kwargs.get("encoder_outputs") is None:
84
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
85
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
86
+ if "hidden_states" in model_kwargs["encoder_outputs"]:
87
+ model_kwargs["encoder_outputs"]["hidden_states"] = tuple(
88
+ hidden_state.repeat_interleave(expand_size, dim=0) for hidden_state in
89
+ model_kwargs["encoder_outputs"]["hidden_states"]
90
+ )
91
+
92
+ return input_ids, model_kwargs
93
+
94
+ def generate(
95
+ self,
96
+ input_features: Optional[torch.Tensor] = None,
97
+ generation_config: Optional[GenerationConfig] = None,
98
+ logits_processor: Optional[LogitsProcessorList] = None,
99
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
100
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
101
+ synced_gpus: bool = False,
102
+ return_timestamps: Optional[bool] = None,
103
+ task: Optional[str] = None,
104
+ language: Optional[str] = None,
105
+ is_multilingual: Optional[bool] = None,
106
+ prompt_ids: Optional[torch.Tensor] = None,
107
+ prompt_condition_type: Optional[str] = None, # first-segment, all-segments
108
+ condition_on_prev_tokens: Optional[bool] = None,
109
+ temperature: Optional[Union[float, Tuple[float, ...]]] = None,
110
+ compression_ratio_threshold: Optional[float] = None,
111
+ logprob_threshold: Optional[float] = None,
112
+ no_speech_threshold: Optional[float] = None,
113
+ num_segment_frames: Optional[int] = None,
114
+ attention_mask: Optional[torch.Tensor] = None,
115
+ time_precision: float = 0.02,
116
+ return_token_timestamps: Optional[bool] = None,
117
+ return_segments: bool = False,
118
+ return_dict_in_generate: Optional[bool] = None,
119
+ assistant_model: Optional["PreTrainedModel"] = None,
120
+ **kwargs,
121
+ ):
122
+ if condition_on_prev_tokens:
123
+ raise NotImplementedError("Current version does not support conditioning")
124
+
125
+ gen_c, _ = self._prepare_generation_config(generation_config, **kwargs)
126
+ gen_mode = gen_c.get_generation_mode(assistant_model)
127
+
128
+ if gen_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.BEAM_SEARCH]:
129
+ raise ValueError(
130
+ f"Provided generation mode {gen_mode} is not supported"
131
+ f" for WhisperForConditionalGeneration with joint CTC decoding")
132
+
133
+ if "stno_mask" in kwargs:
134
+ self.stno_mask = kwargs["stno_mask"]
135
+ if "encoder_outputs" in kwargs:
136
+ self.encoder_logits = kwargs["encoder_outputs"].logits
137
+ # pylint: disable=no-member
138
+ # 0. deprecate old inputs
139
+ if "inputs" in kwargs:
140
+ input_features = kwargs.pop("inputs")
141
+ warnings.warn(
142
+ "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
143
+ FutureWarning,
144
+ )
145
+
146
+ # 1. prepare generation config
147
+ generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
148
+
149
+ # 2. set global generate variables
150
+ input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
151
+ num_segment_frames = input_stride * self.config.max_source_positions
152
+ batch_size, total_input_frames = self._retrieve_total_input_frames(
153
+ input_features=input_features, input_stride=input_stride, kwargs=kwargs
154
+ )
155
+ is_shortform = total_input_frames <= num_segment_frames
156
+
157
+ if is_shortform:
158
+ # warn user of ignored inputs
159
+ self._maybe_warn_unused_inputs(
160
+ condition_on_prev_tokens=condition_on_prev_tokens,
161
+ temperature=temperature,
162
+ compression_ratio_threshold=compression_ratio_threshold,
163
+ logprob_threshold=logprob_threshold,
164
+ no_speech_threshold=no_speech_threshold,
165
+ total_input_frames=total_input_frames,
166
+ )
167
+
168
+ # 3. Make sure generation config is correctly set
169
+ # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
170
+ self._set_return_outputs(
171
+ return_dict_in_generate=return_dict_in_generate,
172
+ return_token_timestamps=return_token_timestamps,
173
+ is_shortform=is_shortform,
174
+ logprob_threshold=logprob_threshold,
175
+ generation_config=generation_config,
176
+ )
177
+ self._set_return_timestamps(
178
+ return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
179
+ )
180
+ self._set_language_and_task(
181
+ language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
182
+ )
183
+ self._set_num_frames(
184
+ return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
185
+ )
186
+ self._set_thresholds_and_condition(
187
+ generation_config=generation_config,
188
+ logprob_threshold=logprob_threshold,
189
+ compression_ratio_threshold=compression_ratio_threshold,
190
+ no_speech_threshold=no_speech_threshold,
191
+ condition_on_prev_tokens=condition_on_prev_tokens,
192
+ )
193
+ self._set_prompt_condition_type(
194
+ generation_config=generation_config,
195
+ prompt_condition_type=prompt_condition_type,
196
+ )
197
+
198
+ # pass self.config for backward compatibility
199
+ init_tokens = self._retrieve_init_tokens(
200
+ input_features,
201
+ batch_size=batch_size,
202
+ generation_config=generation_config,
203
+ config=self.config,
204
+ num_segment_frames=num_segment_frames,
205
+ kwargs=kwargs,
206
+ )
207
+ # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
208
+ # where the input ids are handled explicitly by the generate method
209
+ self._check_decoder_input_ids(kwargs=kwargs)
210
+
211
+ # 3. Retrieve logits processors
212
+ device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
213
+ begin_index = init_tokens.shape[1]
214
+ logits_processor = self._retrieve_logit_processors(
215
+ generation_config=generation_config,
216
+ logits_processor=logits_processor,
217
+ begin_index=begin_index, # begin index is index of first generated decoder token
218
+ is_shortform=is_shortform,
219
+ num_beams=kwargs.get("num_beams", 1),
220
+ device=device,
221
+ )
222
+
223
+ # 5. If we're in shortform mode, simply generate the whole input at once and return the output
224
+ if is_shortform:
225
+ if temperature is not None:
226
+ generation_config.temperature = temperature
227
+
228
+ decoder_input_ids = kwargs.pop("decoder_input_ids", None)
229
+ if decoder_input_ids is None:
230
+ decoder_input_ids = init_tokens
231
+
232
+ if prompt_ids is not None:
233
+ decoder_input_ids = torch.cat(
234
+ [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
235
+ )
236
+
237
+ max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
238
+ if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
239
+ raise ValueError(
240
+ f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
241
+ f"is {max_new_tokens}. Thus, the combined length of "
242
+ f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
243
+ f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
244
+ "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
245
+ f"so that their combined length is less than {self.config.max_target_positions}."
246
+ )
247
+
248
+ outputs = super().generate(
249
+ input_features,
250
+ generation_config=generation_config,
251
+ logits_processor=logits_processor,
252
+ stopping_criteria=stopping_criteria,
253
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
254
+ synced_gpus=synced_gpus,
255
+ decoder_input_ids=decoder_input_ids,
256
+ **kwargs,
257
+ )
258
+
259
+ if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"):
260
+ outputs["token_timestamps"] = self._extract_token_timestamps(
261
+ outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames
262
+ )
263
+
264
+ # print("\n".join(self.tokenizer.batch_decode(outputs,skip_special_tokens=True, decode_with_timestamps=True)))
265
+ return outputs
266
+
267
+ # 6. Else we're in longform mode which is more complex.
268
+ # We need to chunk the audio input depending on when the model generates timestamp tokens
269
+
270
+ # 6.1 Set and retrieve global longform generation variables
271
+ self._set_condition_on_prev_tokens(
272
+ condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
273
+ )
274
+
275
+ timestamp_begin = generation_config.no_timestamps_token_id + 1
276
+ temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
277
+ temperature = temperatures[0]
278
+ batch_size = input_features.shape[0]
279
+
280
+ max_frames, seek = self._retrieve_max_frames_and_seek(
281
+ batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
282
+ )
283
+
284
+ # 6.2 Prepare running variables and lists for generation
285
+ cur_bsz = batch_size
286
+ current_segments = self._prepare_segments(
287
+ prompt_ids=prompt_ids,
288
+ batch_size=batch_size,
289
+ generation_config=generation_config,
290
+ )
291
+
292
+ batch_idx_map = list(range(batch_size))
293
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)]
294
+
295
+ # 6.2 Transcribe audio until we reach the end of all input audios
296
+ while (seek < max_frames).any():
297
+ # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
298
+ # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
299
+ # to know which original audio is being decoded
300
+ # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
301
+ input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
302
+ input_features=input_features,
303
+ seek=seek,
304
+ max_frames=max_frames,
305
+ cur_bsz=cur_bsz,
306
+ batch_idx_map=batch_idx_map,
307
+ )
308
+ time_offset = seek * time_precision / input_stride
309
+ seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
310
+
311
+ # 6.4 cut out next 30s segment from input features
312
+ segment_input = self._get_input_segment(
313
+ input_features=input_features,
314
+ seek=seek,
315
+ seek_num_frames=seek_num_frames,
316
+ num_segment_frames=num_segment_frames,
317
+ cur_bsz=cur_bsz,
318
+ batch_idx_map=batch_idx_map,
319
+ )
320
+
321
+ # 6.5 prepare decoder input ids
322
+ suppress_tokens = _get_attr_from_logit_processors(
323
+ logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
324
+ )
325
+ decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
326
+ cur_bsz=cur_bsz,
327
+ init_tokens=init_tokens,
328
+ current_segments=current_segments,
329
+ batch_idx_map=batch_idx_map,
330
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
331
+ prompt_ids=prompt_ids,
332
+ generation_config=generation_config,
333
+ config=self.config,
334
+ device=segment_input.device,
335
+ suppress_tokens=suppress_tokens,
336
+ kwargs=kwargs,
337
+ )
338
+
339
+ # 6.6 set max new tokens or max length
340
+ self._set_max_new_tokens_and_length(
341
+ config=self.config,
342
+ decoder_input_ids=decoder_input_ids,
343
+ generation_config=generation_config,
344
+ )
345
+
346
+ # 6.7 Set current `begin_index` for all logit processors
347
+ for proc in logits_processor:
348
+ if hasattr(proc, "set_begin_index"):
349
+ proc.set_begin_index(decoder_input_ids.shape[-1])
350
+
351
+ # 6.8 Run generate with fallback
352
+ seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
353
+ segment_input=segment_input,
354
+ decoder_input_ids=decoder_input_ids,
355
+ cur_bsz=cur_bsz,
356
+ batch_idx_map=batch_idx_map,
357
+ seek=seek,
358
+ num_segment_frames=num_segment_frames,
359
+ max_frames=max_frames,
360
+ temperatures=temperatures,
361
+ generation_config=generation_config,
362
+ logits_processor=logits_processor,
363
+ stopping_criteria=stopping_criteria,
364
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
365
+ synced_gpus=synced_gpus,
366
+ return_token_timestamps=return_token_timestamps,
367
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
368
+ kwargs=kwargs,
369
+ )
370
+
371
+ # 6.9 In every generated sequence, split by timestamp tokens and extract segments
372
+ if self.config.mt_num_speakers == 1:
373
+ for i, seek_sequence in enumerate(seek_sequences):
374
+ prev_i = batch_idx_map[i]
375
+
376
+ if should_skip[i]:
377
+ seek[prev_i] += seek_num_frames[prev_i]
378
+ continue
379
+
380
+ segments, segment_offset = self._retrieve_segment(
381
+ seek_sequence=seek_sequence,
382
+ seek_outputs=seek_outputs,
383
+ time_offset=time_offset,
384
+ timestamp_begin=timestamp_begin,
385
+ seek_num_frames=seek_num_frames,
386
+ time_precision=time_precision,
387
+ input_stride=input_stride,
388
+ prev_idx=prev_i,
389
+ idx=i,
390
+ return_token_timestamps=return_token_timestamps,
391
+ )
392
+
393
+ current_segments[prev_i] += segments
394
+ seek[prev_i] += segment_offset
395
+ else:
396
+ # We have to make sure all speakers stay synchronized, thus we have to find the minimum of the seeks produced by the individual speaker instances
397
+ for j, seek_seqs in enumerate(
398
+ [seek_sequences[i * self.config.mt_num_speakers:(i + 1) * self.config.mt_num_speakers] for i in
399
+ range(len(seek_sequences) // self.config.mt_num_speakers)]):
400
+ indexes = [j * self.config.mt_num_speakers + i for i in range(self.config.mt_num_speakers)]
401
+ prev_ids = [batch_idx_map[i] for i in indexes]
402
+
403
+ if all([should_skip[i] for i in indexes]):
404
+ for i, prev_i in zip(indexes, prev_ids):
405
+ seek[prev_i] += seek_num_frames[prev_i]
406
+ continue
407
+
408
+ segments, segment_offset = self._retrieve_segment_mt(
409
+ seek_sequences=seek_seqs,
410
+ seek_outputs=seek_outputs,
411
+ time_offset=time_offset,
412
+ timestamp_begin=timestamp_begin,
413
+ seek_num_frames=seek_num_frames,
414
+ time_precision=time_precision,
415
+ input_stride=input_stride,
416
+ prev_ids=prev_ids,
417
+ ids=indexes,
418
+ return_token_timestamps=return_token_timestamps,
419
+ )
420
+
421
+ for prev_i, i in zip(prev_ids, range(self.config.mt_num_speakers)):
422
+ current_segments[prev_i] += segments[i]
423
+ seek[prev_i] += segment_offset[i]
424
+
425
+ # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
426
+ # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
427
+ final_segments = (
428
+ [x[1:] for x in current_segments]
429
+ if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
430
+ else current_segments
431
+ )
432
+ sequences = _pad_to_max_length(
433
+ final_segments, generation_config.pad_token_id, device=self.device, padding="right"
434
+ )
435
+
436
+ # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
437
+ output = {"sequences": sequences, "segments": final_segments}
438
+
439
+ self.encoder_logits = None
440
+
441
+ if isinstance(output, dict):
442
+ output = self._fix_timestamps_from_segmentation(output)
443
+
444
+ return output
445
+
446
+ @staticmethod
447
+ def _find_common_seek(sequences, seeks):
448
+ """
449
+ Finds the minimum seek that does not overlap with other sequences,
450
+ and falls back to candidate segment boundaries if needed. Assumes:
451
+ - 'seeks' is a list of integer seek times,
452
+ - each seek time is in timestamp * 100 format (e.g., 125.5s -> 12550).
453
+ """
454
+
455
+ def is_valid_seek(seek_time, exclude_seq_idx):
456
+ for idx, seq in enumerate(sequences):
457
+ if idx == exclude_seq_idx:
458
+ continue
459
+ for segment in seq:
460
+ start = getattr(segment, 'start', segment['start'])
461
+ end = getattr(segment, 'end', segment['end'])
462
+ if seek_time < start:
463
+ break # Segments are sorted by end
464
+ if start < seek_time < end:
465
+ return False
466
+ return True
467
+
468
+ # Step 1: Find minimum seek
469
+ # if all seek values are the same, return it immediately
470
+ seeks = [s if isinstance(s, int) else s.item() for s in seeks]
471
+ if len(set(seeks)) == 1:
472
+ return seeks[0]
473
+
474
+ min_seek_val = min(seeks)
475
+ min_seek_idx = seeks.index(min_seek_val)
476
+ min_seek_real = min_seek_val / 100
477
+
478
+ if is_valid_seek(min_seek_real, min_seek_idx):
479
+ return min_seek_val
480
+
481
+ # Step 2: Try fallback seeks taken from segment start/end boundaries of all sequences
482
+ fallback_seeks = set()
483
+ for idx, seq in enumerate(sequences):
484
+ for segment in seq:
485
+ start = getattr(segment, 'start', segment['start'])
486
+ if isinstance(start, torch.Tensor):
487
+ start = start.item()
488
+ candidate = round(start, 2)
489
+ fallback_seeks.add((candidate, idx, True))
490
+ end = getattr(segment, 'end', segment['end'])
491
+ if isinstance(end, torch.Tensor):
492
+ end = end.item()
493
+ if end < min_seek_real:
494
+ candidate = round(end, 2)
495
+ fallback_seeks.add((candidate, idx, True))
496
+
497
+ valid_fallbacks = [
498
+ (int(s * 100), idx, is_start) for s, idx, is_start in fallback_seeks
499
+ if is_valid_seek(s, min_seek_idx)
500
+ ]
501
+
502
+ if valid_fallbacks:
503
+ return max(valid_fallbacks)
504
+
505
+ # Step 3: Nothing valid
506
+ return 0
507
+
508
+ @staticmethod
509
+ def remove_segments_after_seek(sequences, seek, eps=100):
510
+ """
511
+ Keep only segments that finish before given timestamp.
512
+
513
+ Args:
514
+ sequences: List of lists, each containing segments (dict or object with 'start' and 'end').
515
+ seek: Integer seek timestamp (e.g., timestamp * 100).
516
+
517
+ Returns:
518
+ None. Modifies the sequences in-place.
519
+ """
520
+ return [[seg for seg in seq if (getattr(seg, 'end', seg['end']) * 100 <= seek + eps)] for seq in sequences]
521
+
522
+ @staticmethod
523
+ def _retrieve_segment_wo_seek(
524
+ seek_sequence,
525
+ seek_outputs,
526
+ time_offset,
527
+ timestamp_begin,
528
+ seek_num_frames,
529
+ time_precision,
530
+ input_stride,
531
+ prev_idx,
532
+ idx,
533
+ return_token_timestamps,
534
+ ):
535
+ # find the predicted "end of segment" predictions of Whisper
536
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
537
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
538
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
539
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
540
+ timestamp_segment_indices.add_(1)
541
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
542
+
543
+ # If Whisper predicted an "end of segment" via a timestamp token, let's go over each
544
+ # "end of segment" prediction and slice the decoding into segments accordingly
545
+ if len(timestamp_segment_indices) > 0:
546
+ # if the output contains two consecutive timestamp tokens
547
+ slices = timestamp_segment_indices.tolist()
548
+ segments = []
549
+ if single_timestamp_ending:
550
+ slices.append(len(seek_sequence))
551
+
552
+ last_slice = 0
553
+ # Add each segment to list of all segments
554
+ for current_slice in slices:
555
+ sliced_tokens = seek_sequence[last_slice:current_slice]
556
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
557
+ end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
558
+ segments.append(
559
+ {
560
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
561
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
562
+ "tokens": sliced_tokens,
563
+ "result": seek_outputs[idx],
564
+ }
565
+ )
566
+ if return_token_timestamps:
567
+ segments[-1]["token_timestamps"] = (
568
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
569
+ )
570
+ last_slice = current_slice
571
+
572
+ if not single_timestamp_ending:
573
+ # generate all predictions after the last predicted "end of segment" and seek by 30s
574
+ sliced_tokens = seek_sequence[last_slice:]
575
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
576
+ end_timestamp_pos = seek_num_frames[prev_idx] // 2
577
+ segments.append(
578
+ {
579
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
580
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
581
+ "tokens": sliced_tokens,
582
+ "result": seek_outputs[idx],
583
+ }
584
+ )
585
+ segment_offset = seek_num_frames[prev_idx]
586
+ else:
587
+ # If whisper does not predict any "end of segment" token, then
588
+ # the whole decoding is considered a segment and we add it to the list of segments
589
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
590
+ start_timestamp_pos = 0.0
591
+ last_timestamp_pos = seek_num_frames[prev_idx] // 2
592
+
593
+ if timestamps.numel() > 1:
594
+ start_timestamp_pos = timestamps[-2].item() - timestamp_begin
595
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
596
+ elif timestamps.numel() == 1:
597
+ # no consecutive timestamps but it has a timestamp; use the last one.
598
+ start_timestamp_pos = timestamps[-1].item() - timestamp_begin
599
+ segments = [
600
+ {
601
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
602
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
603
+ "tokens": seek_sequence,
604
+ "result": seek_outputs[idx],
605
+ }
606
+ ]
607
+
608
+ segment_offset = seek_num_frames[prev_idx]
609
+
610
+ return segments, segment_offset
611
+
612
+ def _retrieve_segment_mt(
613
+ self,
614
+ seek_sequences,
615
+ seek_outputs,
616
+ time_offset,
617
+ timestamp_begin,
618
+ seek_num_frames,
619
+ time_precision,
620
+ input_stride,
621
+ prev_ids,
622
+ ids,
623
+ return_token_timestamps,
624
+ ):
625
+ sequences, seeks = [], []
626
+ for sequence, prev_id, idx in zip(seek_sequences, prev_ids, ids):
627
+ seq, seek = self._retrieve_segment(
628
+ seek_sequence=sequence,
629
+ seek_outputs=seek_outputs,
630
+ time_offset=time_offset,
631
+ timestamp_begin=timestamp_begin,
632
+ seek_num_frames=seek_num_frames,
633
+ time_precision=time_precision,
634
+ input_stride=input_stride,
635
+ prev_idx=prev_id,
636
+ idx=idx,
637
+ return_token_timestamps=return_token_timestamps,
638
+ )
639
+ sequences.append(seq)
640
+ seeks.append(seek + int(time_offset[prev_id] * 100))
641
+ # best_seek = self._find_common_seek(sequences, seeks)
642
+ best_seek = seeks[0]
643
+ # print(f"Best seek {best_seek}")
644
+ if best_seek - (min(time_offset[prev_ids]) * 100) < 100:
645
+ # we cannot rollback, we have to decode segments as they are
646
+ sequences, seeks = [], []
647
+ for sequence, prev_id, idx in zip(seek_sequences, prev_ids, ids):
648
+ seq, seek = self._retrieve_segment_wo_seek(
649
+ seek_sequence=sequence,
650
+ seek_outputs=seek_outputs,
651
+ time_offset=time_offset,
652
+ timestamp_begin=timestamp_begin,
653
+ seek_num_frames=seek_num_frames,
654
+ time_precision=time_precision,
655
+ input_stride=input_stride,
656
+ prev_idx=prev_id,
657
+ idx=idx,
658
+ return_token_timestamps=return_token_timestamps,
659
+ )
660
+ sequences.append(seq)
661
+ seeks.append(seek)
662
+ return sequences, seeks
663
+
664
+ seqs_new = self.remove_segments_after_seek(sequences, best_seek)
665
+ seeks = [best_seek - int(min(time_offset[prev_ids]) * 100) for _ in seeks]
666
+ return seqs_new, seeks
667
+
668
+ def _beam_search(
669
+ self,
670
+ input_ids: torch.LongTensor,
671
+ beam_scorer: BeamScorer,
672
+ logits_processor: LogitsProcessorList,
673
+ stopping_criteria: StoppingCriteriaList,
674
+ generation_config: GenerationConfig,
675
+ synced_gpus: bool,
676
+ logits_warper: Optional[LogitsProcessorList] = None,
677
+ **model_kwargs,
678
+ ) -> Union[GenerateBeamOutput, torch.LongTensor]:
679
+ r"""
680
+ Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
681
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
682
+
683
+ Parameters:
684
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
685
+ The sequence used as a prompt for the generation.
686
+ beam_scorer (`BeamScorer`):
687
+ An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
688
+ sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
689
+ logits_processor (`LogitsProcessorList`):
690
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
691
+ used to modify the prediction scores of the language modeling head applied at each generation step.
692
+ stopping_criteria (`StoppingCriteriaList`):
693
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
694
+ used to tell if the generation loop should stop.
695
+ generation_config ([`~generation.GenerationConfig`]):
696
+ The generation configuration to be used as parametrization of the decoding method.
697
+ synced_gpus (`bool`):
698
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
699
+ logits_warper (`LogitsProcessorList`, *optional*):
700
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
701
+ to warp the prediction score distribution of the language modeling head applied before multinomial
702
+ sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
703
+ `generation_config`)
704
+ model_kwargs:
705
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
706
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
707
+
708
+ Return:
709
+ [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
710
+ `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
711
+ [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
712
+ `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
713
+ `model.config.is_encoder_decoder=True`.
714
+ """
715
+ # init values
716
+ pad_token_id = generation_config.pad_token_id
717
+ eos_token_id = generation_config.eos_token_id
718
+ output_attentions = generation_config.output_attentions
719
+ output_hidden_states = generation_config.output_hidden_states
720
+ output_scores = generation_config.output_scores
721
+ output_logits = generation_config.output_logits
722
+ return_dict_in_generate = generation_config.return_dict_in_generate
723
+ sequential = generation_config.low_memory
724
+ do_sample = generation_config.do_sample
725
+ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
726
+ raise ValueError(
727
+ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
728
+ f"{logits_warper})."
729
+ )
730
+
731
+ batch_size = len(beam_scorer._beam_hyps)
732
+ num_beams = beam_scorer.num_beams
733
+
734
+ batch_beam_size, cur_len = input_ids.shape
735
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
736
+
737
+ if num_beams * batch_size != batch_beam_size:
738
+ raise ValueError(
739
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
740
+ )
741
+
742
+ # init attention / hidden states / scores tuples
743
+ scores = () if (return_dict_in_generate and output_scores) else None
744
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
745
+ beam_indices = (
746
+ tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
747
+ )
748
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
749
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
750
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
751
+
752
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
753
+ if return_dict_in_generate and self.config.is_encoder_decoder:
754
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
755
+ encoder_hidden_states = (
756
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
757
+ )
758
+
759
+ # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
760
+ # of the first beam are considered to avoid sampling the exact same tokens across all beams.
761
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
762
+ beam_scores[:, 1:] = -1e9
763
+ beam_scores = beam_scores.view((batch_size * num_beams,))
764
+
765
+ this_peer_finished = False
766
+
767
+ decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
768
+
769
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
770
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
771
+
772
+ # if sequential is True, split the input to batches of batch_size and run sequentially
773
+ if sequential:
774
+ if any(
775
+ model_name in self.__class__.__name__.lower()
776
+ for model_name in [
777
+ "fsmt",
778
+ "reformer",
779
+ "bloom",
780
+ "ctrl",
781
+ "gpt_bigcode",
782
+ "transo_xl",
783
+ "xlnet",
784
+ "cpm",
785
+ "jamba",
786
+ ]
787
+ ):
788
+ raise RuntimeError(
789
+ f"Currently generation for {self.__class__.__name__} is not supported "
790
+ f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
791
+ )
792
+
793
+ inputs_per_sub_batches = _split_model_inputs(
794
+ model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
795
+ )
796
+ outputs_per_sub_batch = [
797
+ self(
798
+ **inputs_per_sub_batch,
799
+ return_dict=True,
800
+ output_attentions=output_attentions,
801
+ output_hidden_states=output_hidden_states,
802
+ )
803
+ for inputs_per_sub_batch in inputs_per_sub_batches
804
+ ]
805
+
806
+ outputs = stack_model_outputs(outputs_per_sub_batch)
807
+
808
+ else: # Unchanged original behavior
809
+ outputs = self(
810
+ **model_inputs,
811
+ return_dict=True,
812
+ output_attentions=output_attentions,
813
+ output_hidden_states=output_hidden_states,
814
+ )
815
+
816
+ if synced_gpus and this_peer_finished:
817
+ cur_len = cur_len + 1
818
+ continue # don't waste resources running the code we don't need
819
+
820
+ next_token_logits = outputs.logits[:, -1, :]
821
+ next_token_scores = nn.functional.log_softmax(
822
+ next_token_logits, dim=-1
823
+ ) # (batch_size * num_beams, vocab_size)
824
+
825
+ next_token_scores_processed = logits_processor(input_ids, next_token_scores)
826
+ if do_sample:
827
+ next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
828
+ next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
829
+ next_token_scores_processed
830
+ )
831
+
832
+ # Store scores, attentions and hidden_states when required
833
+ if return_dict_in_generate:
834
+ if output_scores:
835
+ scores += (next_token_scores_processed,)
836
+ if output_logits:
837
+ raw_logits += (next_token_logits,)
838
+ if output_attentions:
839
+ decoder_attentions += (
840
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
841
+ )
842
+ if self.config.is_encoder_decoder:
843
+ cross_attentions += (outputs.cross_attentions,)
844
+ if output_hidden_states:
845
+ decoder_hidden_states += (
846
+ (outputs.decoder_hidden_states,)
847
+ if self.config.is_encoder_decoder
848
+ else (outputs.hidden_states,)
849
+ )
850
+
851
+ # reshape for beam search
852
+ vocab_size = next_token_scores.shape[-1]
853
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
854
+
855
+ # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
856
+ # non eos token per beam.
857
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
858
+ n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
859
+ if do_sample:
860
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
861
+ next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
862
+ next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
863
+ next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
864
+ next_tokens = torch.gather(next_tokens, -1, _indices)
865
+ else:
866
+ next_token_scores, next_tokens = torch.topk(
867
+ next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
868
+ )
869
+
870
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
871
+ next_tokens = next_tokens % vocab_size
872
+
873
+ # stateless
874
+ beam_outputs = beam_scorer.process(
875
+ input_ids,
876
+ next_token_scores,
877
+ next_tokens,
878
+ next_indices,
879
+ pad_token_id=pad_token_id,
880
+ eos_token_id=eos_token_id,
881
+ beam_indices=beam_indices,
882
+ decoder_prompt_len=decoder_prompt_len,
883
+ )
884
+
885
+ beam_scores = beam_outputs["next_beam_scores"]
886
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
887
+ beam_idx = beam_outputs["next_beam_indices"]
888
+
889
+ # Based on the beam idx and next tokens reshuffle the ctc prev states and scores
890
+ if hasattr(self, "ctc_rescorer"):
891
+ self.ctc_rescorer.update_state(beam_next_tokens, beam_idx)
892
+ input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
893
+
894
+ model_kwargs = self._update_model_kwargs_for_generation(
895
+ outputs,
896
+ model_kwargs,
897
+ is_encoder_decoder=self.config.is_encoder_decoder,
898
+ )
899
+ if model_kwargs.get("past_key_values", None) is not None:
900
+ model_kwargs["past_key_values"] = self._temporary_reorder_cache(
901
+ model_kwargs["past_key_values"], beam_idx
902
+ )
903
+
904
+ if return_dict_in_generate and output_scores:
905
+ beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
906
+
907
+ # increase cur_len
908
+ cur_len = cur_len + 1
909
+
910
+ if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
911
+ this_peer_finished = True
912
+
913
+ sequence_outputs = beam_scorer.finalize(
914
+ input_ids,
915
+ beam_scores,
916
+ next_tokens,
917
+ next_indices,
918
+ pad_token_id=pad_token_id,
919
+ eos_token_id=eos_token_id,
920
+ max_length=stopping_criteria.max_length,
921
+ beam_indices=beam_indices,
922
+ decoder_prompt_len=decoder_prompt_len,
923
+ )
924
+
925
+ if return_dict_in_generate:
926
+ if not output_scores:
927
+ sequence_outputs["sequence_scores"] = None
928
+
929
+ if self.config.is_encoder_decoder:
930
+ return GenerateBeamEncoderDecoderOutput(
931
+ sequences=sequence_outputs["sequences"],
932
+ sequences_scores=sequence_outputs["sequence_scores"],
933
+ scores=scores,
934
+ logits=raw_logits,
935
+ beam_indices=sequence_outputs["beam_indices"],
936
+ encoder_attentions=encoder_attentions,
937
+ encoder_hidden_states=encoder_hidden_states,
938
+ decoder_attentions=decoder_attentions,
939
+ cross_attentions=cross_attentions,
940
+ decoder_hidden_states=decoder_hidden_states,
941
+ past_key_values=model_kwargs.get("past_key_values"),
942
+ )
943
+ else:
944
+ return GenerateBeamDecoderOnlyOutput(
945
+ sequences=sequence_outputs["sequences"],
946
+ sequences_scores=sequence_outputs["sequence_scores"],
947
+ scores=scores,
948
+ logits=raw_logits,
949
+ beam_indices=sequence_outputs["beam_indices"],
950
+ attentions=decoder_attentions,
951
+ hidden_states=decoder_hidden_states,
952
+ past_key_values=model_kwargs.get("past_key_values"),
953
+ )
954
+ else:
955
+ return sequence_outputs["sequences"]
956
+
957
+ def _sample(
958
+ self,
959
+ input_ids: torch.LongTensor,
960
+ logits_processor: LogitsProcessorList,
961
+ stopping_criteria: StoppingCriteriaList,
962
+ generation_config: GenerationConfig,
963
+ synced_gpus: bool,
964
+ streamer: Optional["BaseStreamer"],
965
+ logits_warper: Optional[LogitsProcessorList] = None,
966
+ **model_kwargs,
967
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
968
+ r"""
969
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
970
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
971
+
972
+ Parameters:
973
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
974
+ The sequence used as a prompt for the generation.
975
+ logits_processor (`LogitsProcessorList`):
976
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
977
+ used to modify the prediction scores of the language modeling head applied at each generation step.
978
+ stopping_criteria (`StoppingCriteriaList`):
979
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
980
+ used to tell if the generation loop should stop.
981
+ generation_config ([`~generation.GenerationConfig`]):
982
+ The generation configuration to be used as parametrization of the decoding method.
983
+ synced_gpus (`bool`):
984
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
985
+ streamer (`BaseStreamer`, *optional*):
986
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
987
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
988
+ logits_warper (`LogitsProcessorList`, *optional*):
989
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
990
+ to warp the prediction score distribution of the language modeling head applied before multinomial
991
+ sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
992
+ `generation_config`)
993
+ model_kwargs:
994
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
995
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
996
+
997
+ Return:
998
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
999
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
1000
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
1001
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
1002
+ `model.config.is_encoder_decoder=True`.
1003
+ """
1004
+ # init values
1005
+ pad_token_id = generation_config.pad_token_id
1006
+ output_attentions = generation_config.output_attentions
1007
+ output_hidden_states = generation_config.output_hidden_states
1008
+ output_scores = generation_config.output_scores
1009
+ output_logits = generation_config.output_logits
1010
+ return_dict_in_generate = generation_config.return_dict_in_generate
1011
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
1012
+ do_sample = generation_config.do_sample
1013
+ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
1014
+ raise ValueError(
1015
+ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
1016
+ f"{logits_warper})."
1017
+ )
1018
+
1019
+ # init attention / hidden states / scores tuples
1020
+ scores = () if (return_dict_in_generate and output_scores) else None
1021
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
1022
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
1023
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
1024
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
1025
+
1026
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
1027
+ if return_dict_in_generate and self.config.is_encoder_decoder:
1028
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
1029
+ encoder_hidden_states = (
1030
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
1031
+ )
1032
+
1033
+ # keep track of which sequences are already finished
1034
+ batch_size = input_ids.shape[0]
1035
+ this_peer_finished = False
1036
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1037
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1038
+
1039
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
1040
+ # prepare model inputs
1041
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1042
+
1043
+ # forward pass to get next token
1044
+ outputs = self(
1045
+ **model_inputs,
1046
+ return_dict=True,
1047
+ output_attentions=output_attentions,
1048
+ output_hidden_states=output_hidden_states,
1049
+ )
1050
+
1051
+ if synced_gpus and this_peer_finished:
1052
+ continue # don't waste resources running the code we don't need
1053
+
1054
+ next_token_logits = outputs.logits[:, -1, :]
1055
+
1056
+ # pre-process distribution
1057
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1058
+ if do_sample:
1059
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1060
+
1061
+ # Store scores, attentions and hidden_states when required
1062
+ if return_dict_in_generate:
1063
+ if output_scores:
1064
+ scores += (next_token_scores,)
1065
+ if output_logits:
1066
+ raw_logits += (next_token_logits,)
1067
+ if output_attentions:
1068
+ decoder_attentions += (
1069
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
1070
+ )
1071
+ if self.config.is_encoder_decoder:
1072
+ cross_attentions += (outputs.cross_attentions,)
1073
+
1074
+ if output_hidden_states:
1075
+ decoder_hidden_states += (
1076
+ (outputs.decoder_hidden_states,)
1077
+ if self.config.is_encoder_decoder
1078
+ else (outputs.hidden_states,)
1079
+ )
1080
+
1081
+ # token selection
1082
+ if do_sample:
1083
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1084
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1085
+ else:
1086
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
1087
+
1088
+ # finished sentences should have their next token be a padding token
1089
+ if has_eos_stopping_criteria:
1090
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1091
+
1092
+ # Based on the next tokens select the ctc prev states and scores
1093
+ if hasattr(self, "ctc_rescorer"):
1094
+ self.ctc_rescorer.update_state(next_tokens, torch.arange(next_tokens.shape[0]))
1095
+
1096
+ # update generated ids, model inputs, and length for next step
1097
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1098
+ if streamer is not None:
1099
+ streamer.put(next_tokens.cpu())
1100
+ model_kwargs = self._update_model_kwargs_for_generation(
1101
+ outputs,
1102
+ model_kwargs,
1103
+ is_encoder_decoder=self.config.is_encoder_decoder,
1104
+ )
1105
+
1106
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
1107
+ this_peer_finished = unfinished_sequences.max() == 0
1108
+
1109
+ if streamer is not None:
1110
+ streamer.end()
1111
+
1112
+ if return_dict_in_generate:
1113
+ if self.config.is_encoder_decoder:
1114
+ return GenerateEncoderDecoderOutput(
1115
+ sequences=input_ids,
1116
+ scores=scores,
1117
+ logits=raw_logits,
1118
+ encoder_attentions=encoder_attentions,
1119
+ encoder_hidden_states=encoder_hidden_states,
1120
+ decoder_attentions=decoder_attentions,
1121
+ cross_attentions=cross_attentions,
1122
+ decoder_hidden_states=decoder_hidden_states,
1123
+ past_key_values=model_kwargs.get("past_key_values"),
1124
+ )
1125
+ else:
1126
+ return GenerateDecoderOnlyOutput(
1127
+ sequences=input_ids,
1128
+ scores=scores,
1129
+ logits=raw_logits,
1130
+ attentions=decoder_attentions,
1131
+ hidden_states=decoder_hidden_states,
1132
+ past_key_values=model_kwargs.get("past_key_values"),
1133
+ )
1134
+ else:
1135
+ return input_ids
1136
+
1137
+ def prepare_kwargs_for_generate(self,
1138
+ segment_input,
1139
+ cur_bsz,
1140
+ batch_idx_map,
1141
+ seek,
1142
+ num_segment_frames,
1143
+ max_frames,
1144
+ kwargs):
1145
+ kwargs["attention_mask_enc"] = torch.ones(cur_bsz, segment_input.size(-1), device=segment_input.device)
1146
+ seek_vad = seek // 2
1147
+ num_frames_vad = num_segment_frames // 2
1148
+ max_frames_vad = max_frames // 2
1149
+ seek_num_frames = (max_frames_vad - seek_vad).clamp(max=num_frames_vad)
1150
+
1151
+ stno_masks = []
1152
+ for i in range(cur_bsz):
1153
+ prev_i = batch_idx_map[i]
1154
+ segment_input_slice = kwargs["stno_mask"][prev_i: prev_i + 1, :,
1155
+ seek_vad[prev_i]: seek_vad[prev_i] + seek_num_frames[prev_i]]
1156
+
1157
+ if segment_input_slice.shape[-1] < num_frames_vad:
1158
+ orig_len = segment_input_slice.shape[-1]
1159
+ # pad to the full VAD window (num_frames_vad frames) if necessary
1160
+ segment_input_slice = torch.nn.functional.pad(
1161
+ segment_input_slice, pad=(0, num_frames_vad - orig_len)
1162
+ )
1163
+ # set corresponding padding tokens to 1 in vad mask representing silence
1164
+ segment_input_slice[0, 0, orig_len:] = 1.0
1165
+
1166
+ stno_masks.append(segment_input_slice)
1167
+ kwargs["stno_mask"] = torch.cat(stno_masks, dim=0)
1168
+ self.stno_mask_seek = kwargs["stno_mask"]
1169
+
1170
+ if "per_group_sizes" in kwargs:
1171
+ group_sizes = kwargs["per_group_sizes"].clone()
1172
+ group_sizes[:] = 0
1173
+ cummulative_group_sizes = (
1174
+ kwargs["per_group_sizes"].max().repeat(kwargs["per_group_sizes"].shape[0])).cumsum(dim=0)
1175
+ for i in batch_idx_map:
1176
+ group_idx = (cummulative_group_sizes > i).nonzero().min()
1177
+ group_sizes[group_idx] += 1
1178
+ kwargs["per_group_sizes"] = group_sizes
1179
+
1180
+ if self.vad_seek_callback is not None:
1181
+ self.vad_seek_callback(kwargs["stno_mask"])
1182
+ return kwargs
1183
+
1184
+ def generate_with_fallback(
1185
+ self,
1186
+ segment_input,
1187
+ decoder_input_ids,
1188
+ cur_bsz,
1189
+ batch_idx_map,
1190
+ seek,
1191
+ num_segment_frames,
1192
+ max_frames,
1193
+ temperatures,
1194
+ generation_config,
1195
+ logits_processor,
1196
+ stopping_criteria,
1197
+ prefix_allowed_tokens_fn,
1198
+ synced_gpus,
1199
+ return_token_timestamps,
1200
+ do_condition_on_prev_tokens,
1201
+ kwargs,
1202
+ ):
1203
+ kwargs = copy.copy(kwargs)
1204
+ kwargs = self.prepare_kwargs_for_generate(segment_input, cur_bsz, batch_idx_map, seek, num_segment_frames,
1205
+ max_frames, kwargs)
1206
+ seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = super().generate_with_fallback(
1207
+ segment_input,
1208
+ decoder_input_ids,
1209
+ cur_bsz,
1210
+ batch_idx_map,
1211
+ seek,
1212
+ num_segment_frames,
1213
+ max_frames,
1214
+ temperatures,
1215
+ generation_config,
1216
+ logits_processor,
1217
+ stopping_criteria,
1218
+ prefix_allowed_tokens_fn,
1219
+ synced_gpus,
1220
+ return_token_timestamps,
1221
+ do_condition_on_prev_tokens,
1222
+ kwargs,
1223
+ )
1224
+ self.stno_mask_seek = None
1225
+
1226
+ # for i, seq in enumerate(seek_outputs):
1227
+ # print(f"Sequence {i}: {self.tokenizer.decode(seq, decode_with_timestamps=True)}")
1228
+ # print("-"*50)
1229
+
1230
+ return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
1231
+
1232
+ def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1233
+ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
1234
+ """short function to replace num with a itr in lst"""
1235
+ found = any(i in lst for i in itr)
1236
+ if found:
1237
+ lst = [num if i in itr else i for i in lst]
1238
+ else:
1239
+ lst.append(num)
1240
+ return lst
1241
+
1242
+ def language_to_id(language: str) -> int:
1243
+ language = language.lower()
1244
+ if language in generation_config.lang_to_id.keys():
1245
+ language_token = language
1246
+ elif language in TO_LANGUAGE_CODE.keys():
1247
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
1248
+ elif language in TO_LANGUAGE_CODE.values():
1249
+ language_token = f"<|{language}|>"
1250
+ else:
1251
+ is_language_code = len(language) == 2
1252
+ raise ValueError(
1253
+ f"Unsupported language: {language}. Language should be one of:"
1254
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1255
+ )
1256
+ if language_token not in generation_config.lang_to_id:
1257
+ raise ValueError(
1258
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
1259
+ "(You should just add it to the generation config)"
1260
+ )
1261
+
1262
+ return generation_config.lang_to_id[language_token]
1263
+
1264
+ task = getattr(generation_config, "task", None)
1265
+ language = getattr(generation_config, "language", None)
1266
+
1267
+ forced_decoder_ids = generation_config.forced_decoder_ids
1268
+ if forced_decoder_ids is not None:
1269
+ if language is None and task is None and forced_decoder_ids[0][1] is None:
1270
+ logger.warning_once(
1271
+ "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
1272
+ "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
1273
+ )
1274
+ elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
1275
+ forced_decoder_ids = config.forced_decoder_ids
1276
+
1277
+ elif forced_decoder_ids is not None and language is not None:
1278
+ logger.info(
1279
+ f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
1280
+ )
1281
+ forced_decoder_ids = None
1282
+
1283
+ init_tokens = [generation_config.decoder_start_token_id]
1284
+
1285
+ # Update init_tokens with languages
1286
+ lang_ids = None
1287
+
1288
+ if forced_decoder_ids is not None:
1289
+ return forced_decoder_ids
1290
+
1291
+ # from v4.39 the forced decoder ids are always None in favour of decoder input ids
1292
+ generation_config.forced_decoder_ids = None
1293
+
1294
+ is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
1295
+
1296
+ # Make sure language is a list of strings of the correct length
1297
+ if isinstance(language, (list, tuple)):
1298
+ if any(l is None for l in language):
1299
+ raise TypeError(
1300
+ "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
1301
+ )
1302
+ if len(language) != batch_size:
1303
+ raise ValueError(
1304
+ "When passing a list of languages, the length of the list must match the batch size. "
1305
+ f"Expected length of {batch_size}, but got {len(language)} languages."
1306
+ )
1307
+ languages = language
1308
+ elif language is None:
1309
+ # Language will be detected for each item in batch
1310
+ languages = [None] * batch_size
1311
+ else:
1312
+ languages = [language] # Use a length-1 list now, broadcast later
1313
+
1314
+ # Separate init_tokens for each language
1315
+ init_tokens = [copy.copy(init_tokens) for _ in languages]
1316
+
1317
+ if language is not None and lang_ids is not None:
1318
+ lang_ids = [language_to_id(l) for l in languages]
1319
+ elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
1320
+ # language is not defined or intentionally set to `None` to trigger language detection
1321
+ lang_ids = self.detect_language(
1322
+ input_features=input_features,
1323
+ encoder_outputs=kwargs.get("encoder_outputs", None),
1324
+ generation_config=generation_config,
1325
+ num_segment_frames=num_segment_frames,
1326
+ ).tolist()
1327
+ if lang_ids is not None:
1328
+ # append or replace lang_ids to init_tokens
1329
+ for i in range(len(init_tokens)):
1330
+ if len(init_tokens[i]) > 1:
1331
+ init_tokens[i][1] = lang_ids[i]
1332
+ else:
1333
+ init_tokens[i].append(lang_ids[i])
1334
+ del languages
1335
+
1336
+ # Update init_tokens with task
1337
+ for i in range(len(init_tokens)):
1338
+ if task is not None:
1339
+ if task in TASK_IDS:
1340
+ init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1341
+ task_id = generation_config.task_to_id[generation_config.task]
1342
+
1343
+ # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
1344
+ replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
1345
+ else:
1346
+ raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
1347
+ elif language is not None and hasattr(generation_config, "task_to_id"):
1348
+ # if language is defined, but no task id is in `init_tokens`, default to transcribe
1349
+ if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
1350
+ init_tokens[i].append(generation_config.task_to_id["transcribe"])
1351
+
1352
+ # let's make sure we don't pass `None` tokens as prompt tokens
1353
+ init_tokens[i] = [t for t in init_tokens[i] if t is not None]
1354
+
1355
+ return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
1356
+
1357
+ def detect_language(
1358
+ self,
1359
+ input_features: Optional[torch.FloatTensor] = None,
1360
+ encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
1361
+ generation_config: Optional[GenerationConfig] = None,
1362
+ num_segment_frames: int = 3000,
1363
+ ) -> torch.Tensor:
1364
+ """
1365
+ Detects language from log-mel input features or encoder_outputs
1366
+
1367
+ Parameters:
1368
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
1369
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
1370
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
1371
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
1372
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
1373
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
1374
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
1375
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
1376
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
1377
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1378
+ generation_config (`~generation.GenerationConfig`, *optional*):
1379
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1380
+ passed to generate matching the attributes of `generation_config` will override them. If
1381
+ `generation_config` is not provided, the default will be used, which had the following loading
1382
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1383
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1384
+ default values, whose documentation should be checked to parameterize generation.
1385
+ num_segment_frames (`int`, defaults to 3000):
1386
+ The number of log-mel frames the model expects
1387
+
1388
+ Return:
1389
+ A `torch.LongTensor` representing the detected language ids.
1390
+ """
1391
+ if input_features is None and encoder_outputs is None:
1392
+ raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
1393
+ elif input_features is not None and encoder_outputs is not None:
1394
+ raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
1395
+ elif input_features is not None:
1396
+ inputs = {"input_features": input_features[:, :, :num_segment_frames]}
1397
+ batch_size = input_features.shape[0]
1398
+ elif encoder_outputs is not None:
1399
+ inputs = {"encoder_outputs": encoder_outputs}
1400
+ batch_size = (
1401
+ encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
1402
+ )
1403
+
1404
+ generation_config = generation_config or self.generation_config
1405
+ decoder_input_ids = (
1406
+ torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
1407
+ * generation_config.decoder_start_token_id
1408
+ )
1409
+
1410
+ with torch.no_grad():
1411
+ logits = self(**inputs, decoder_input_ids=decoder_input_ids,
1412
+ stno_mask=self.stno_mask[:, :, :num_segment_frames // 2]).logits[:, -1]
1413
+
1414
+ non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
1415
+ non_lang_mask[list(generation_config.lang_to_id.values())] = False
1416
+
1417
+ logits[:, non_lang_mask] = -np.inf
1418
+
1419
+ lang_ids = logits.argmax(-1)
1420
+
1421
+ return lang_ids
1422
+
1423
+ def _get_logits_processor(
1424
+ self,
1425
+ generation_config: GenerationConfig,
1426
+ input_ids_seq_length: int,
1427
+ encoder_input_ids: torch.LongTensor,
1428
+ prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
1429
+ logits_processor: Optional[LogitsProcessorList],
1430
+ device: str = None,
1431
+ model_kwargs: Optional[Dict[str, Any]] = None,
1432
+ negative_prompt_ids: Optional[torch.Tensor] = None,
1433
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
1434
+ ) -> LogitsProcessorList:
1435
+ # pylint: disable=no-member
1436
+ gen_config_copy = copy.deepcopy(generation_config)
1437
+ gen_config_copy.forced_decoder_ids = None
1438
+ processors = super()._get_logits_processor(
1439
+ gen_config_copy,
1440
+ input_ids_seq_length,
1441
+ encoder_input_ids,
1442
+ prefix_allowed_tokens_fn,
1443
+ logits_processor,
1444
+ device,
1445
+ model_kwargs,
1446
+ negative_prompt_ids,
1447
+ negative_prompt_attention_mask,
1448
+ )
1449
+ if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
1450
+ enc_logits = self.encoder_logits
1451
+ if generation_config.num_beams <= 1:
1452
+ processors.append(LogSoftmaxProcessor())
1453
+ else:
1454
+ enc_logits = enc_logits.repeat_interleave(generation_config.num_beams, dim=0)
1455
+ self.ctc_rescorer = CTCRescorerLogitsProcessor(
1456
+ enc_logits,
1457
+ torch.full((enc_logits.shape[0],), fill_value=enc_logits.shape[1],
1458
+ device=enc_logits.device),
1459
+ enc_logits.shape[-1] - 1,
1460
+ generation_config.pad_token_id.item(),
1461
+ generation_config.eos_token_id.item(),
1462
+ generation_config.decoder_start_token_id.item(),
1463
+ self.tokenizer,
1464
+ generation_config.ctc_margin,
1465
+ generation_config.ctc_weight,
1466
+ generation_config.num_beams,
1467
+ False,
1468
+ )
1469
+ processors.append(self.ctc_rescorer)
1470
+ return processors
1471
+
1472
+ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams,
1473
+ device):
1474
+ if generation_config.return_timestamps is True:
1475
+ timestamp_processor = WhisperTimeStampLogitsProcessorCustom(generation_config, begin_index=begin_index)
1476
+ logits_processor = (
1477
+ [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
1478
+ )
1479
+
1480
+ if generation_config.suppress_tokens is not None:
1481
+ suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
1482
+ logits_processor = (
1483
+ [suppress_tokens_processor]
1484
+ if logits_processor is None
1485
+ else [suppress_tokens_processor] + logits_processor
1486
+ )
1487
+ generation_config.suppress_tokens = None
1488
+
1489
+ if generation_config.begin_suppress_tokens is not None:
1490
+ begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
1491
+ generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
1492
+ )
1493
+ logits_processor = (
1494
+ [begin_suppress_processor]
1495
+ if logits_processor is None
1496
+ else [begin_suppress_processor] + logits_processor
1497
+ )
1498
+ generation_config.begin_suppress_tokens = None
1499
+
1500
+ if generation_config.no_speech_threshold is not None and not is_shortform:
1501
+ no_speech_detector = WhisperNoSpeechDetection(
1502
+ no_speech_token=generation_config.no_timestamps_token_id - 1,
1503
+ begin_index=begin_index,
1504
+ scores_is_logprobs=num_beams > 1,
1505
+ )
1506
+ logits_processor = (
1507
+ [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
1508
+ )
1509
+ no_speech_detector.set_model(self)
1510
+
1511
+ return logits_processor
1512
+
1513
+ @staticmethod
1514
+ def round_to_nearest_0_02(x):
1515
+ d = Decimal(str(x)) # Use str(x) to preserve input precision
1516
+ step = Decimal('0.02')
1517
+ # Divide, round, multiply back
1518
+ rounded = (d / step).to_integral_value(rounding=ROUND_HALF_UP) * step
1519
+ return rounded
1520
+
1521
+ def _fix_timestamps_from_segmentation(self, sequences):
1522
+ """
1523
+ Adjusts token sequences with global timestamps to fit within Whisper's 0–30s timestamp token range.
1524
+
1525
+ This function modifies the input sequences by inserting appropriate timestamp tokens and
1526
+ offset corrections to ensure the decoded token order is correct, without splitting any segment.
1527
+ It aligns all timestamps to 0.02-second precision, inserts placeholder segments to bridge
1528
+ time gaps between 30-second windows, and maintains segment continuity during encoding.
1529
+
1530
+ Args:
1531
+ sequences (dict): A dictionary containing:
1532
+ - 'segments': A list of segment lists, each segment being a dict with 'start', 'end', and 'tokens'.
1533
+ - 'sequences': A tensor used to determine device for padding.
1534
+
1535
+ Returns:
1536
+ torch.Tensor: A batch of padded token sequences with corrected timestamp alignment.
1537
+ """
1538
+ # Get the token ID for the "<|0.00|>" timestamp used to detect dummy segments
1539
+ first_timestamp_token = self.tokenizer.get_vocab()["<|0.00|>"]
1540
+ results = []
1541
+
1542
+ # Filter out segments that are either empty or consist only of the "<|0.00|>" token
1543
+ for idx, sequence_segs in enumerate(sequences['segments']):
1544
+ sequences['segments'][idx] = [
1545
+ seg for seg in sequence_segs
1546
+ if len(seg['tokens']) > 0 and (len(seg['tokens']) != 1 or seg['tokens'][0] != first_timestamp_token)
1547
+ ]
1548
+
1549
+ # Iterate over each group of segments (e.g., one per utterance)
1550
+ for idx, sequence_segs in enumerate(sequences['segments']):
1551
+ result = []
1552
+ prev_segment_end_time = None
1553
+ correction = Decimal(0.0)
1554
+
1555
+ for i, seg in enumerate(sequence_segs):
1556
+ # Round start and end times to nearest 0.02 seconds
1557
+ start_time = self.round_to_nearest_0_02(seg['start'].item())
1558
+ end_time = self.round_to_nearest_0_02(seg['end'].item())
1559
+ tokens = seg['tokens']
1560
+
1561
+ # Determine which 30s window this segment falls into
1562
+ current_block = (start_time + correction) // 30
1563
+
1564
+ if prev_segment_end_time is not None:
1565
+ # If not the first segment, calculate difference in 30s windows
1566
+ prev_block = prev_segment_end_time // 30
1567
+ num_dummies = current_block - prev_block - 1
1568
+
1569
+ # Insert (30, [], 30) marker if we're moving to a new block
1570
+ if current_block > prev_block:
1571
+ result.append((30, [], 30))
1572
+
1573
+ # Insert dummy segments to bridge skipped 30s blocks
1574
+ for _ in range(int(num_dummies)):
1575
+ result.append((0, [], 30))
1576
+ else:
1577
+ # For the first segment, add dummy blocks if it starts after 30s
1578
+ for _ in range(int(start_time // 30)):
1579
+ result.append((0, [], 30))
1580
+
1581
+ # Determine whether segment fits in one block or wraps to the next
1582
+ if (start_time + correction) // 30 == (end_time + correction) // 30:
1583
+ # Segment fits within a single 30s window
1584
+ result.append(((start_time + correction) % 30, tokens, (end_time + correction) % 30))
1585
+ else:
1586
+ # Segment would wrap across a 30s boundary
1587
+ new_seg_start = (correction + start_time) % 30
1588
+ new_seg_end = end_time - start_time
1589
+
1590
+ if new_seg_end >= new_seg_start:
1591
+ # Seek back to the beginning of the segment window
1592
+ result.append((new_seg_start, [], new_seg_start))
1593
+ result.append((0, tokens, new_seg_end))
1594
+ # Apply correction to align future timestamps to new 30s block
1595
+ correction = self.round_to_nearest_0_02(-(start_time % 30))
1596
+ else:
1597
+ # Otherwise, just insert with adjusted times
1598
+ result.append((new_seg_start, tokens, new_seg_end))
1599
+ correction = self.round_to_nearest_0_02(30 - (start_time % 30))
1600
+ # print(f'Processed segment {i}, result: {self.tokenizer.decode(self.tokenizer("".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result]))["input_ids"], decode_with_timestamps=True)[-250:]}')
1601
+ # Update the previous segment's end time for next iteration
1602
+ prev_segment_end_time = end_time + correction
1603
+
1604
+ # Convert result segments into a token sequence with proper timestamp formatting
1605
+ encoded = self.tokenizer(
1606
+ "".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result])
1607
+ )['input_ids']
1608
+ results.append(encoded)
1609
+
1610
+ # Pad all sequences to the same length for batching
1611
+ sequences = pad_sequence(
1612
+ [torch.tensor(res, device=sequences['sequences'].device) for res in results],
1613
+ batch_first=True,
1614
+ padding_value=self.tokenizer.pad_token_id
1615
+ )
1616
+ return sequences
1617
+
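The window re-basing above is easiest to see on toy numbers. A minimal sketch of the core branch (invented segment times, no tokenizer and no multi-window dummy bridging):

```python
# Toy illustration of the 30 s window re-basing performed by _fix_timestamps_from_segmentation.
from decimal import Decimal

segments = [(Decimal("5.00"), Decimal("12.00")),   # falls in the first 30 s window
            (Decimal("41.50"), Decimal("47.00"))]  # falls in the second window

result, prev_end, correction = [], None, Decimal("0")
for start, end in segments:
    block = (start + correction) // 30
    if prev_end is not None and block > prev_end // 30:
        result.append((30, [], 30))                # bridge into the next 30 s window
    result.append(((start + correction) % 30, ["<tokens>"], (end + correction) % 30))
    prev_end = end + correction

# result: [(5.00, [...], 12.00), (30, [], 30), (11.50, [...], 17.00)]
```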
1618
+ @staticmethod
1619
+ def _retrieve_segment(
1620
+ seek_sequence,
1621
+ seek_outputs,
1622
+ time_offset,
1623
+ timestamp_begin,
1624
+ seek_num_frames,
1625
+ time_precision,
1626
+ input_stride,
1627
+ prev_idx,
1628
+ idx,
1629
+ return_token_timestamps,
1630
+ ):
1631
+ # find the predicted "end of segment" predictions of Whisper
1632
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
1633
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
1634
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
1635
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
1636
+ timestamp_segment_indices.add_(1)
1637
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
1638
+
1639
+ # If Whisper predicted an "end of segment" via a timestamp token, let's go over each
1640
+ # "end of segment" prediction and slice the decoding into segments accordingly
1641
+ if len(timestamp_segment_indices) > 0:
1642
+ # if the output contains two consecutive timestamp tokens
1643
+ slices = timestamp_segment_indices.tolist()
1644
+ segments = []
1645
+ if single_timestamp_ending:
1646
+ slices.append(len(seek_sequence))
1647
+
1648
+ last_slice = 0
1649
+ # Add each segment to list of all segments
1650
+ for current_slice in slices:
1651
+ sliced_tokens = seek_sequence[last_slice:current_slice]
1652
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
1653
+ end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
1654
+ segments.append(
1655
+ {
1656
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
1657
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
1658
+ "tokens": sliced_tokens,
1659
+ "result": seek_outputs[idx],
1660
+ }
1661
+ )
1662
+ if return_token_timestamps:
1663
+ segments[-1]["token_timestamps"] = (
1664
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
1665
+ )
1666
+ last_slice = current_slice
1667
+
1668
+ if single_timestamp_ending:
1669
+ # single timestamp at the end means no speech after the last timestamp.
1670
+ segment_offset = seek_num_frames[prev_idx]
1671
+ else:
1672
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
1673
+ # here we throw away all predictions after the last predicted "end of segment"
1674
+ # since we are cutting right in the middle of an audio
1675
+ last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
1676
+ segment_offset = last_timestamp_pos * input_stride
1677
+ else:
1678
+ # If whisper does not predict any "end of segment" token, then
1679
+ # the whole decoding is considered a segment and we add it to the list of segments
1680
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
1681
+ start_timestamp_pos = 0.0
1682
+ last_timestamp_pos = seek_num_frames[prev_idx] // 2
1683
+ skip = False
1684
+ segment_offset = seek_num_frames[prev_idx]
1685
+
1686
+ if timestamps.numel() > 1:
1687
+ start_timestamp_pos = timestamps[-2].item() - timestamp_begin
1688
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
1689
+ elif timestamps.numel() == 1:
1690
+ # no consecutive timestamps but it has a timestamp; use the last one.
1691
+ start_timestamp_pos = timestamps[-1].item() - timestamp_begin
1692
+ if start_timestamp_pos > 200:
1693
+ # segment does not fit into decoding window, so we need to rollback
1694
+ segment_offset = start_timestamp_pos * input_stride - 100 # timestamp might be inaccurate
1695
+ skip = True
1696
+ else:
1697
+ # empty sequence, or sequence w/o timestamps
1698
+ skip = True
1699
+
1700
+ if skip:
1701
+ segments = []
1702
+ else:
1703
+ segments = [
1704
+ {
1705
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
1706
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
1707
+ "tokens": seek_sequence,
1708
+ "result": seek_outputs[idx],
1709
+ }
1710
+ ]
1711
+ if return_token_timestamps:
1712
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
1713
+ segment_offset = seek_num_frames[prev_idx]
1714
+
1715
+ if segment_offset <= 0:
1716
+ msg = f"Timestamps: {timestamps}, Segments: {segments}"
1717
+ raise ValueError(f"Segment offset: {segment_offset} <= 0. This should not happen!\n{msg}")
1718
+
1719
+ return segments, segment_offset
1720
+
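As a quick sanity check of the boundary detection in `_retrieve_segment`: consecutive timestamp tokens mark where one segment ends and the next begins (toy ids, assuming `timestamp_begin = 100`):

```python
import torch

# fake decoded ids: <|t1|> text text <|t5|> <|t5|> text <|t10|>
seek_sequence = torch.tensor([101, 7, 8, 105, 105, 9, 110])
timestamp_tokens = seek_sequence.ge(100)
slices = (torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1).tolist()
print(slices)  # [4] -> first segment is seek_sequence[:4], the next one starts at index 4
```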
1721
+ def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
1722
+ # remove all previously passed decoder input ids
1723
+ if isinstance(seek_outputs, torch.Tensor):
1724
+ seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1]:]
1725
+ seek_outputs = torch.hstack((
1726
+ seek_outputs,
1727
+ torch.full((seek_outputs.shape[0], 1),
1728
+ fill_value=generation_config.pad_token_id,
1729
+ dtype=seek_outputs.dtype,
1730
+ device=seek_outputs.device
1731
+ )
1732
+ ))
1733
+ # first_eos = (seek_outputs == generation_config.eos_token_id).int().argmax(dim=1)
1734
+ # biggest_timestamp = generation_config.no_timestamps_token_id + 1 + 30 * 50
1735
+
1736
+ # empty_transcriptions = first_eos == 0
1737
+ # seek_outputs[empty_transcriptions, 0] = generation_config.no_timestamps_token_id + 1 # 0.00 timestamp
1738
+ # seek_outputs[empty_transcriptions, 1] = biggest_timestamp # 30.00 timestamp
1739
+ # seek_outputs[empty_transcriptions, 2] = generation_config.eos_token_id # 30.00 timestamp
1740
+
1741
+ return seek_outputs, seek_outputs
1742
+
1743
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
1744
+ num_frames = getattr(generation_config, "num_frames", None)
1745
+ seek_outputs["token_timestamps"] = self._extract_token_timestamps(
1746
+ seek_outputs, generation_config.alignment_heads, num_frames=num_frames
1747
+ )
1748
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1]:]
1749
+
1750
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1]:]
1751
+
1752
+ def split_by_batch_index(values, key, batch_idx):
1753
+ if key == "scores":
1754
+ return [v[batch_idx].cpu() for v in values]
1755
+ elif key == "past_key_values":
1756
+ # we don't save `past_key_values` as this is too costly
1757
+ return None
1758
+ elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
1759
+ return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
1760
+ return values[batch_idx].cpu()
1761
+
1762
+ sequence_tokens = seek_outputs["sequences"]
1763
+ seek_outputs = [
1764
+ {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
1765
+ for i in range(sequence_tokens.shape[0])
1766
+ ]
1767
+
1768
+ return sequence_tokens, seek_outputs
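Putting the generation pieces together, a minimal usage sketch — the repo id is a placeholder, and the STNO layout (silence/target/non-target/overlap channels at half the mel frame rate) is inferred from `FDDT.py` and `prepare_kwargs_for_generate`, so exact kwargs may differ:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

model_id = "your-org/DiCoW"  # placeholder checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, trust_remote_code=True)
model.set_tokenizer(processor.tokenizer)

audio = torch.zeros(16000 * 30)  # 30 s of dummy audio
inputs = processor(audio.numpy(), sampling_rate=16000, return_tensors="pt")

# STNO mask: (batch, 4 channels, mel_frames // 2); here the "target" channel is active throughout
stno_mask = torch.zeros(1, 4, 1500)
stno_mask[:, 1] = 1.0

ids = model.generate(
    input_features=inputs.input_features,
    stno_mask=stno_mask,
    language="en",
    task="transcribe",
)
print(processor.batch_decode(ids, skip_special_tokens=True))
```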
generation_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "begin_suppress_tokens": [
4
+ 220,
5
+ 50256
6
+ ],
7
+ "bos_token_id": 50257,
8
+ "decoder_start_token_id": 50258,
9
+ "eos_token_id": 50257,
10
+ "pad_token_id": 50257,
11
+ "transformers_version": "4.42.0"
12
+ }
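The special-token ids above can be cross-checked against a multilingual Whisper tokenizer, e.g.:

```python
from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
print(tok.convert_ids_to_tokens([220, 50256, 50257, 50258]))
```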
layers.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class CustomLinear(nn.Linear):
6
+ def __init__(self, *args, init_eye_val=0.0, is_diagonal=False, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.init_eye_val = init_eye_val
9
+
10
+
11
+ class CustomDiagonalLinear(nn.Module):
12
+ def __init__(self, d_model, bias=True, init_eye_val=0.0):
13
+ super().__init__()
14
+ self.init_eye_val = init_eye_val
15
+ self.weight = nn.Parameter(torch.full((d_model,), init_eye_val))
16
+ self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None
17
+
18
+ def forward(self, input):
19
+ out = input * self.weight
20
+ if self.bias is not None:
21
+ out += self.bias
22
+ return out
23
+
24
+ class Gate(nn.Module):
25
+ def __init__(self, items, init_val=0.0):
26
+ super().__init__()
27
+ self.init_val = init_val
28
+ self.gate = nn.Parameter(torch.full((items,), init_val))
29
+
30
+ def forward(self, input, dim):
31
+ if input.ndim != 4:
32
+ raise ValueError('input must be a 4D tensor')
33
+ if not (0 <= dim <= 3):
34
+ raise ValueError('dim must be 0, 1, 2, or 3')
35
+
36
+ shape = [1] * 4
37
+ shape[dim] = -1
38
+ return input * self.gate.view(*shape)
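A quick smoke test of the layers above (shapes chosen arbitrarily; assumes the snippet runs next to `layers.py`):

```python
import torch
from layers import CustomDiagonalLinear, Gate

x = torch.randn(2, 4, 10, 8)                        # (batch, heads, time, d_model)-style tensor
diag = CustomDiagonalLinear(d_model=8, bias=True, init_eye_val=1.0)
print(torch.allclose(diag(x), x))                    # True: identity transform at init

gate = Gate(items=4, init_val=0.0)                   # one gate value per head
print(gate(x, dim=1).abs().max())                    # tensor(0.): gates start closed
```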
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f8ccccd743d29563587aef0612e16394916fd034ee2ca62bd134837ca34b4b7
3
+ size 3833628952
modeling_dicow.py ADDED
@@ -0,0 +1,387 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import CrossEntropyLoss
6
+ import torch.utils.checkpoint
8
+ from transformers.modeling_outputs import Seq2SeqLMOutput
9
+ from transformers.models.speech_encoder_decoder.modeling_speech_encoder_decoder import (
10
+ shift_tokens_right,
11
+ )
12
+ from transformers.models.whisper.modeling_whisper import (
13
+ WhisperEncoder,
14
+ )
15
+ from transformers.models.whisper.modeling_whisper import (
16
+ WhisperForConditionalGeneration,
17
+ shift_tokens_right,
18
+ WhisperModel,
19
+ )
20
+ from transformers.models.whisper.modeling_whisper import sinusoids
21
+ from transformers.utils import logging
22
+
23
+ from .config import Seq2SeqLMOutputLosses, Seq2SeqModelOutputLogit, DiCoWConfig
24
+ from .encoder import DiCoWEncoder
25
+ from .FDDT import FDDT
26
+ from .layers import CustomLinear, CustomDiagonalLinear, Gate
27
+ from .generation import DiCoWGenerationMixin
28
+ from .contrastive_loss import ContrastiveLoss
29
+ import wandb
30
+ logging.set_verbosity_debug()
31
+ logger = logging.get_logger("transformers")
32
+
33
+
34
+ class DiCoW(WhisperModel):
35
+ def __init__(self, config: DiCoWConfig):
36
+ super().__init__(config)
37
+ self.encoder = DiCoWEncoder(config)
38
+
39
+ def forward(
40
+ self,
41
+ input_features: Optional[torch.FloatTensor] = None,
42
+ attention_mask: Optional[torch.LongTensor] = None,
43
+ decoder_input_ids: Optional[torch.LongTensor] = None,
44
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
45
+ head_mask: Optional[torch.Tensor] = None,
46
+ decoder_head_mask: Optional[torch.Tensor] = None,
47
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
48
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
49
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
50
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
51
+ decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
52
+ use_cache: Optional[bool] = None,
53
+ output_attentions: Optional[bool] = None,
54
+ output_hidden_states: Optional[bool] = None,
55
+ return_dict: Optional[bool] = None,
56
+ stno_mask: Optional[torch.FloatTensor] = None,
57
+ per_group_sizes: Optional[torch.LongTensor] = None,
58
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutputLosses]:
59
+ r"""
60
+ Returns:
61
+
62
+ Example:
63
+ ```python
64
+ >>> import torch
65
+ >>> from transformers import AutoFeatureExtractor, WhisperModel
66
+ >>> from datasets import load_dataset
67
+
68
+ >>> model = WhisperModel.from_pretrained("openai/whisper-base")
69
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
70
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
71
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
72
+ >>> input_features = inputs.input_features
73
+ >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
74
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
75
+ >>> list(last_hidden_state.shape)
76
+ [1, 2, 512]
77
+ ```"""
78
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
79
+ output_hidden_states = (
80
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
81
+ )
82
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
83
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
84
+
85
+ if encoder_outputs is None:
86
+ input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
87
+
88
+ encoder_outputs = self.encoder(
89
+ input_features,
90
+ output_attentions=output_attentions,
91
+ output_hidden_states=True,
92
+ head_mask=head_mask,
93
+ return_dict=return_dict,
94
+ stno_mask=stno_mask,
95
+ per_group_sizes=per_group_sizes
96
+ )
97
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
98
+ # elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
99
+ # raise ValueError("encoder_outputs should be of type BaseModelOutput when return_dict=True.")
100
+
101
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
102
+ decoder_outputs = self.decoder(
103
+ input_ids=decoder_input_ids,
104
+ attention_mask=decoder_attention_mask,
105
+ encoder_hidden_states=encoder_outputs.hidden_states[-1],
106
+ head_mask=decoder_head_mask,
107
+ cross_attn_head_mask=cross_attn_head_mask,
108
+ past_key_values=past_key_values,
109
+ inputs_embeds=decoder_inputs_embeds,
110
+ position_ids=decoder_position_ids,
111
+ use_cache=use_cache,
112
+ output_attentions=output_attentions,
113
+ output_hidden_states=output_hidden_states,
114
+ return_dict=return_dict,
115
+ )
116
+
117
+ if not return_dict:
118
+ return decoder_outputs + encoder_outputs
119
+
120
+ return Seq2SeqModelOutputLogit(
121
+ last_hidden_state=decoder_outputs.last_hidden_state,
122
+ past_key_values=decoder_outputs.past_key_values,
123
+ decoder_hidden_states=decoder_outputs.hidden_states,
124
+ decoder_attentions=decoder_outputs.attentions,
125
+ cross_attentions=decoder_outputs.cross_attentions,
126
+ encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
127
+ encoder_hidden_states=encoder_outputs.hidden_states,
128
+ encoder_attentions=encoder_outputs.attentions,
129
+ encoder_logits=encoder_outputs.logits,
130
+ )
131
+
132
+
133
+ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
134
+ config_class = DiCoWConfig
135
+
136
+ def __init__(self, config: DiCoWConfig):
137
+ super().__init__(config)
138
+ self.model = DiCoW(config)
139
+ self.encoder_logits = None
140
+ self.tokenizer = None
141
+ self.vad_seek_callback = None
142
+ self.stno_mask = None
143
+ self.stno_mask_seek = None
144
+
145
+ # We need this setter as we can't pass a function/method as a config argument.
146
+ # JSON serialization fails at that point.
147
+ def set_vad_seek_callback(self, vad_seek_callback):
148
+ self.vad_seek_callback = vad_seek_callback
149
+
150
+ def set_tokenizer(self, tokenizer):
151
+ self.tokenizer = tokenizer
152
+
153
+ def _init_weights(self, module):
154
+ std = self.config.init_std
155
+ fddt_init = self.config.fddt_init
156
+ if isinstance(module, CustomLinear):
157
+ with torch.no_grad():
158
+ if fddt_init == 'random':
159
+ module.weight.data.normal_(mean=0.0, std=std)
160
+ if module.bias is not None:
161
+ module.bias.data.normal_(mean=0.0, std=std)
162
+ elif fddt_init == 'non-disturbing':
163
+ module.weight.data = torch.eye(*module.weight.shape).data
164
+ if module.bias is not None:
165
+ module.bias.data.zero_()
166
+ elif fddt_init == 'disparagement':
167
+ eye = torch.eye(*module.weight.shape)
168
+ eye *= module.init_eye_val
169
+ module.weight.data = eye.data
170
+ if module.bias is not None:
171
+ module.bias.data.zero_()
172
+ elif isinstance(module, CustomDiagonalLinear):
173
+ with torch.no_grad():
174
+ if fddt_init == 'random':
175
+ module.weight.data.normal_(mean=0.0, std=std)
176
+ if module.bias is not None:
177
+ module.bias.data.normal_(mean=0.0, std=std)
178
+ elif fddt_init == 'non-disturbing':
179
+ module.weight.data = torch.ones_like(module.weight.data).data
180
+ if module.bias is not None:
181
+ module.bias.data.zero_()
182
+ elif fddt_init == 'disparagement':
183
+ module.weight.data = module.init_eye_val * torch.ones_like(module.weight.data).data
184
+ if module.bias is not None:
185
+ module.bias.data.zero_()
186
+ elif isinstance(module, FDDT):
187
+ if module.bias_only:
188
+ if fddt_init == 'random':
189
+ module.target_linear.data.normal_(mean=0.0, std=std)
190
+ module.non_target_linear.data.normal_(mean=0.0, std=std)
191
+ module.overlap_linear.data.normal_(mean=0.0, std=std)
192
+ module.silence_linear.data.normal_(mean=0.0, std=std)
193
+ module.scb.data.normal_(mean=0.0, std=std)
194
+ else:
195
+ module.target_linear.data.zero_()
196
+ module.non_target_linear.data.zero_()
197
+ module.overlap_linear.data.zero_()
198
+ module.silence_linear.data.zero_()
199
+ module.scb.data.zero_()
200
+ elif isinstance(module, (nn.Linear, nn.Conv1d)):
201
+ module.weight.data.normal_(mean=0.0, std=std)
202
+ if module.bias is not None:
203
+ module.bias.data.zero_()
204
+ elif isinstance(module, nn.Embedding):
205
+ module.weight.data.normal_(mean=0.0, std=std)
206
+ if module.padding_idx is not None:
207
+ module.weight.data[module.padding_idx].zero_()
208
+ elif isinstance(module, WhisperEncoder):
209
+ with torch.no_grad():
210
+ embed_positions = module.embed_positions.weight
211
+ embed_positions.copy_(sinusoids(*embed_positions.shape))
212
+ elif isinstance(module, nn.LayerNorm):
213
+ module.reset_parameters()
214
+ elif isinstance(module, nn.MultiheadAttention):
215
+ module._reset_parameters()
216
+ elif isinstance(module, nn.ConvTranspose1d):
217
+ module.reset_parameters()
218
+ elif isinstance(module, Gate):
219
+ module.gate.data = module.init_val * torch.ones_like(module.gate.data).data
220
+
221
+ def forward(
222
+ self,
223
+ input_features: Optional[torch.FloatTensor] = None,
224
+ stno_mask: Optional[torch.FloatTensor] = None,
225
+ per_group_sizes: Optional[torch.LongTensor] = None,
226
+ attention_mask_enc: Optional[torch.LongTensor] = None,
227
+ attention_mask: Optional[torch.LongTensor] = None,
228
+ decoder_input_ids: Optional[torch.LongTensor] = None,
229
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
230
+ head_mask: Optional[torch.Tensor] = None,
231
+ decoder_head_mask: Optional[torch.Tensor] = None,
232
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
233
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
234
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
235
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
236
+ decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
237
+ labels: Optional[torch.LongTensor] = None,
238
+ upp_labels: Optional[torch.LongTensor] = None,
239
+ use_cache: Optional[bool] = None,
240
+ output_attentions: Optional[bool] = None,
241
+ output_hidden_states: Optional[bool] = None,
242
+ return_dict: Optional[bool] = None,
243
+ is_valid: Optional[bool] = None,
244
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
245
+ r"""
246
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
247
+ Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
248
+ or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
249
+ only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
250
+
251
+ Returns:
252
+
253
+ Example:
254
+
255
+ ```python
256
+ >>> import torch
257
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
258
+ >>> from datasets import load_dataset
259
+
260
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
261
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
262
+
263
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
264
+
265
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
266
+ >>> input_features = inputs.input_features
267
+
268
+ >>> generated_ids = model.generate(inputs=input_features)
269
+
270
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
271
+ >>> transcription
272
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
273
+ ```"""
274
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
275
+
276
+ if labels is not None:
277
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
278
+ decoder_input_ids = shift_tokens_right(
279
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
280
+ )
281
+
282
+ outputs = self.model(
283
+ input_features,
284
+ attention_mask=attention_mask,
285
+ decoder_input_ids=decoder_input_ids,
286
+ encoder_outputs=encoder_outputs,
287
+ decoder_attention_mask=decoder_attention_mask,
288
+ head_mask=head_mask,
289
+ decoder_head_mask=decoder_head_mask,
290
+ cross_attn_head_mask=cross_attn_head_mask,
291
+ past_key_values=past_key_values,
292
+ decoder_inputs_embeds=decoder_inputs_embeds,
293
+ decoder_position_ids=decoder_position_ids,
294
+ use_cache=use_cache,
295
+ output_attentions=output_attentions,
296
+ output_hidden_states=output_hidden_states,
297
+ return_dict=return_dict,
298
+ stno_mask=stno_mask,
299
+ per_group_sizes=per_group_sizes
300
+ )
301
+
302
+ dec_lm_logits = self.proj_out(outputs.last_hidden_state)
303
+ enc_lm_logits = outputs.encoder_logits
304
+
305
+ loss = None
306
+ ctc_loss = 0
307
+
308
+ # remove fake inputs from labels and logits given per group sizes
309
+ if is_valid is not None:
310
+ if self.config.ctc_weight > 0.0:
311
+ enc_lm_logits = enc_lm_logits[is_valid]
312
+ dec_lm_logits = dec_lm_logits[is_valid]
313
+ labels = labels[is_valid]
314
+ upp_labels = upp_labels[is_valid]
315
+
316
+ if labels is not None and self.config.ctc_weight > 0.0:
317
+ enc_labels = labels.clone()
318
+ for token in self.tokenizer.prefix_tokens:
319
+ if (enc_labels[:, 0] == token).all():
320
+ enc_labels = enc_labels[:, 1:]
321
+ enc_labels[enc_labels == self.config.eos_token_id] = -100
322
+
323
+ ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)
324
+
325
+ if labels is not None:
326
+ loss_fct = CrossEntropyLoss(reduction='none')
327
+ # move labels to correct device to enable PP
328
+ labels = labels.to(dec_lm_logits.device)
329
+ dec_loss1 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
330
+ dec_loss2 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), upp_labels.reshape(-1))
331
+ dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
332
+ if wandb.run is not None:
333
+ wandb.log({"dec_loss": dec_loss})
334
+ wandb.log({"ctc_loss": ctc_loss})
335
+ loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
336
+
337
+ if self.config.contrastive_loss_weight > 0.0:
338
+ loss_fct = ContrastiveLoss(distance_metric="cosine")
339
+ stno_per_spk_pair = stno_mask.view(-1, self.config.mt_num_speakers, stno_mask.shape[1], stno_mask.shape[2])
340
+ positive_mask = ((stno_per_spk_pair[:, :, 1, :] + stno_per_spk_pair[:, :, 3, :]) > 0.5).flatten(1)
341
+ intermediate_states = outputs.encoder_hidden_states[8].view(-1, self.config.mt_num_speakers, stno_mask.shape[2],
342
+ outputs.encoder_hidden_states[8].shape[-1]).flatten(1, 2)
343
+ valid_pairs = is_valid.view((-1, self.config.mt_num_speakers)).all(dim=-1)
344
+ contrastive_loss = loss_fct(
345
+ intermediate_states[valid_pairs],
346
+ positive_mask[valid_pairs])
347
+ # print(contrastive_loss)
348
+ if wandb.run is not None:
349
+ wandb.log({"contrastive_loss": contrastive_loss})
350
+ if contrastive_loss != 0.0 and loss < 0.5:
351
+ loss += self.config.contrastive_loss_weight * contrastive_loss
352
+ if not return_dict:
353
+ output = (dec_lm_logits,) + outputs[1:]
354
+ return ((loss,) + output) if loss is not None else output
355
+
356
+ return Seq2SeqLMOutputLosses(
357
+ loss=loss,
358
+ logits=dec_lm_logits,
359
+ past_key_values=outputs.past_key_values,
360
+ decoder_hidden_states=outputs.decoder_hidden_states,
361
+ decoder_attentions=outputs.decoder_attentions,
362
+ cross_attentions=outputs.cross_attentions,
363
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
364
+ encoder_hidden_states=outputs.encoder_hidden_states,
365
+ encoder_attentions=outputs.encoder_attentions,
366
+ encoder_logits=enc_lm_logits,
367
+ )
368
+
369
+ def _get_feat_extract_output_lengths(self, attention_mask: torch.Tensor) -> torch.Tensor:
370
+ return (self.model.encoder._get_feat_extract_output_lengths(attention_mask) / 4).ceil()
371
+
372
+ def freeze_except(self, prefixes_to_preheat):
373
+ for name, param in self.named_parameters():
374
+ param.requires_grad = False
375
+ for prefix in prefixes_to_preheat:
376
+ if name.startswith(prefix):
377
+ param.requires_grad = True
378
+
379
+ def suppress_interactions(self):
380
+ """This method suppress final projection in CoAttention blocks to let the original information flow through"""
381
+ for name, param in self.named_parameters():
382
+ if "interaction" in name and "cat_proj" in name:
383
+ with torch.no_grad():
384
+ if "bias" in name:
385
+ param[:] = 0.
386
+ else:
387
+ param[:] *= 0.001
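A hedged sketch of a typical fine-tuning setup using the helpers above, building on the `model` loaded in the earlier generation sketch; the parameter-name prefix is illustrative only and should be checked against `model.named_parameters()`:

```python
# `model` as loaded earlier (a DiCoWForConditionalGeneration instance under the hood).
model.freeze_except(prefixes_to_preheat=["model.encoder.fddt"])   # illustrative prefix
model.suppress_interactions()  # damp CoAttention output projections at the start of training

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable / 1e6:.1f}M")
```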
utils.py ADDED
@@ -0,0 +1,96 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from transformers import WhisperTimeStampLogitsProcessor
5
+
6
+
7
+ def remove_fake_elements(inputs, per_group_sizes):
8
+ max_spks = per_group_sizes.max()
9
+ number_of_groups = per_group_sizes.shape[0]
10
+ outputs = []
11
+ inputs = inputs.view(number_of_groups, max_spks, *inputs.shape[1:])
12
+ for i, group_size in enumerate(per_group_sizes):
13
+ outputs.append(inputs[i, :group_size])
14
+ outputs = torch.cat(outputs, dim=0)
15
+ return outputs
16
+
17
+
18
+ class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
19
+ def __init__(
20
+ self, generate_config, begin_index: Optional[int] = None,
21
+ _detect_timestamp_from_logprob: Optional[bool] = None
22
+ ): # support for the kwargs
23
+ self.no_timestamps_token_id = generate_config.no_timestamps_token_id
24
+ self.timestamp_begin = generate_config.no_timestamps_token_id + 1
25
+ self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id
26
+
27
+ # this variable is mostly just used for testing
28
+ self._detect_timestamp_from_logprob = (
29
+ _detect_timestamp_from_logprob
30
+ if _detect_timestamp_from_logprob is not None
31
+ else getattr(generate_config, "_detect_timestamp_from_logprob", True)
32
+ )
33
+
34
+ num_forced_ids = (
35
+ len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
36
+ )
37
+ self.begin_index = begin_index or (num_forced_ids + 1)
38
+
39
+ self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
40
+ self.min_initial_timestamp_index = getattr(generate_config, "min_initial_timestamp_index", None)
41
+ # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
42
+ # self.max_initial_timestamp_index = 50
43
+
44
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
45
+ # suppress <|notimestamps|> which is handled by without_timestamps
46
+ scores_processed = scores.clone()
47
+ scores_processed[:, self.no_timestamps_token_id] = -float("inf")
48
+
49
+ # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
50
+ for k in range(input_ids.shape[0]):
51
+ sampled_tokens = input_ids[k, self.begin_index:]
52
+ seq = list(sampled_tokens.tolist())
53
+
54
+ last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
55
+ penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
56
+
57
+ if last_was_timestamp:
58
+ if penultimate_was_timestamp: # has to be non-timestamp
59
+ scores_processed[k, self.timestamp_begin:] = -float("inf")
60
+ else: # cannot be normal text tokens
61
+ scores_processed[k, : self.eos_token_id] = -float("inf")
62
+
63
+ timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
64
+ if timestamps.numel() > 0:
65
+ # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
66
+ # The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
67
+ if last_was_timestamp and not penultimate_was_timestamp:
68
+ timestamp_last = timestamps[-1]
69
+ else:
70
+ # Avoid to emit <|0.00|> again
71
+ timestamp_last = timestamps[-1] + 1
72
+
73
+ scores_processed[k, self.timestamp_begin: timestamp_last] = -float("inf")
74
+
75
+ # apply the `max_initial_timestamp` option
76
+ if input_ids.shape[1] == self.begin_index:
77
+ eos_scores = scores_processed[:, self.eos_token_id].clone()
78
+ scores_processed[:, : self.timestamp_begin] = -float("inf")
79
+ scores_processed[:, self.eos_token_id] = eos_scores
80
+
81
+ if self.max_initial_timestamp_index is not None:
82
+ last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
83
+ scores_processed[:, last_allowed + 1:] = -float("inf")
84
+ if self.min_initial_timestamp_index is not None:
85
+ first_allowed = self.timestamp_begin + self.min_initial_timestamp_index
86
+ scores_processed[:, self.timestamp_begin:first_allowed] = -float("inf")
87
+
88
+ # if sum of probability over timestamps is above any other token, sample timestamp
89
+ logprobs = torch.nn.functional.log_softmax(scores_processed.float(), dim=-1)
90
+ for k in range(input_ids.shape[0]):
91
+ timestamp_logprob = logprobs[k, self.timestamp_begin:].logsumexp(dim=-1)
92
+ max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
93
+ if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
94
+ scores_processed[k, : self.timestamp_begin] = -float("inf")
95
+
96
+ return scores_processed
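`remove_fake_elements` strips the padded "fake" speaker channels that make every group the same size; a small self-contained check (assumes the snippet runs next to `utils.py`):

```python
import torch
from utils import remove_fake_elements

per_group_sizes = torch.tensor([2, 3])                         # group 0 has 2 real speakers, group 1 has 3
inputs = torch.arange(6 * 5, dtype=torch.float).view(6, 5)     # 2 groups x max 3 speakers, feature dim 5
out = remove_fake_elements(inputs, per_group_sizes)
print(out.shape)                                               # torch.Size([5, 5]) -> only the real rows remain
```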