from typing import Optional, Union
import re
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers import Cache
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperForConditionalGeneration,
    shift_tokens_right,
    WhisperModel
)
from transformers.utils import logging
from .config import DiCoWConfig
from .encoder import DiCoWEncoder
from .generation import DiCoWGenerationMixin

logging.set_verbosity_debug()
logger = logging.get_logger("transformers")


class SoftLabelCreator(torch.nn.Module):
    """
    Handles Gaussian label smoothing for timestamp tokens and the dual-target
    (upper- vs. lower-case) loss logic.
    """

    def __init__(self, tokenizer, timestamp_sigma=0.08):
        super().__init__()
        self.tokenizer = tokenizer
        self.timestamp_sigma = timestamp_sigma
        # Pre-compute the Gaussian smoothing matrix
        self.register_buffer('ts_smoothing_matrix', self._build_smoothing_matrix())

    def _build_smoothing_matrix(self):
        # get_vocab() also returns added tokens, so the timestamp tokens are included.
        vocab = self.tokenizer.get_vocab()
        vocab_size = len(vocab)

        timestamp_pattern = re.compile(r'<\|(\d+\.\d+)\|>')

        # 1. Map Token IDs to Time Values
        id_to_time = {}
        for token_str, token_id in vocab.items():
            match = timestamp_pattern.match(token_str)
            if match:
                id_to_time[token_id] = float(match.group(1))

        if not id_to_time:
            return None

        # Sorted timestamp token ids (torch.searchsorted in _get_soft_distribution relies
        # on this ordering); kept as a plain attribute and moved to the device at use time.
        sorted_ids = sorted(id_to_time.keys())
        self.sorted_ts_ids = torch.tensor(sorted_ids)
        times = torch.tensor([id_to_time[i] for i in sorted_ids])

        # 2. Create the Smoothing Matrix (Num_Timestamps x Vocab_Size)
        num_ts = len(sorted_ids)
        smoothing_matrix = torch.zeros(num_ts, vocab_size)

        # Vectorized Gaussian computation
        diff_sq = (times.unsqueeze(1) - times.unsqueeze(0)) ** 2
        weights = torch.exp(-diff_sq / (2 * self.timestamp_sigma ** 2))
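        # For example, with sigma = 0.08 s, a timestamp 0.02 s away gets a raw weight of
        # exp(-0.02**2 / (2 * 0.08**2)) ≈ 0.97, while one 0.24 s away gets ≈ 0.011,
        # so probability mass is concentrated on the immediate temporal neighbours.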

        # Normalize
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Scatter rows back to vocab size
        for i, ts_id in enumerate(sorted_ids):
            smoothing_matrix[i, self.sorted_ts_ids] = weights[i]

        return smoothing_matrix

    def _get_soft_distribution(self, labels, vocab_size):
        """Internal helper to convert hard labels -> soft timestamp labels"""
        device = labels.device

        # Start with One-Hot (clamp -100 to 0 temporarily)
        labels_clamped = labels.clamp(min=0)
        soft_labels = F.one_hot(labels_clamped, num_classes=vocab_size).float()

        # Apply Timestamp Smoothing if matrix exists
        if hasattr(self, 'ts_smoothing_matrix') and self.ts_smoothing_matrix is not None:
            sorted_ts_ids = self.sorted_ts_ids.to(device)
            smoothing_matrix = self.ts_smoothing_matrix.to(device)

            is_timestamp = torch.isin(labels, sorted_ts_ids)

            if is_timestamp.any():
                ts_indices = torch.searchsorted(sorted_ts_ids, labels[is_timestamp])
                soft_labels[is_timestamp] = smoothing_matrix[ts_indices]

        return soft_labels

    def compute_loss(self, logits, labels, upp_labels):
        """
        Computes the enhanced SOT loss:
        1. Generates soft labels (timestamp smoothed) for both 'labels' and 'upp_labels'.
        2. Computes KL Divergence (via CrossEntropy) for both.
        3. Takes the minimum loss per token (case invariance).
        4. Applies padding mask.
        """
        vocab_size = logits.size(-1)
        device = logits.device

        # Ensure labels are on correct device
        labels = labels.to(device)
        if upp_labels is not None:
            upp_labels = upp_labels.to(device)

        # Flatten inputs
        flat_logits = logits.view(-1, vocab_size)
        flat_labels = labels.reshape(-1)

        # 1. Generate Soft Targets for Lowercase
        soft_lower = self._get_soft_distribution(flat_labels, vocab_size)

        # 2. Generate Soft Targets for Uppercase (if provided)
        if upp_labels is not None:
            flat_upp = upp_labels.reshape(-1)
            soft_upper = self._get_soft_distribution(flat_upp, vocab_size)
        else:
            # Fallback if no upper labels provided (shouldn't happen in this pipeline)
            soft_upper = soft_lower

        # 3. Compute Cross Entropy (Soft Target Mode)
        # Note: CE with soft targets = -sum(target * log_prob)
        loss_fct = CrossEntropyLoss(reduction='none')

        loss_lower = loss_fct(flat_logits, soft_lower)
        loss_upper = loss_fct(flat_logits, soft_upper)

        # 4. Mask Padding (ignore_index = -100)
        # Soft-target CE doesn't support ignore_index automatically
        mask = (flat_labels != -100).float()

        loss_lower = loss_lower * mask
        loss_upper = loss_upper * mask

        # 5. Take Min (Case Invariance) and Normalize
        combined_min = torch.min(loss_lower, loss_upper)

        # Sum and divide by number of non-padding tokens
        return combined_min.sum() / mask.sum().clamp(min=1)
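
# Illustrative usage sketch (comment only, not executed). It assumes a Whisper-style
# tokenizer whose vocabulary contains timestamp tokens such as <|0.00|>, <|0.02|>, ...;
# the checkpoint name below is just an example.
#
#   from transformers import WhisperTokenizer
#   tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
#   creator = SoftLabelCreator(tokenizer, timestamp_sigma=0.08)
#   vocab_size = len(tokenizer.get_vocab())
#   logits = torch.randn(2, 10, vocab_size)        # (batch, seq_len, vocab)
#   labels = torch.randint(0, 1000, (2, 10))       # lower-cased reference token ids
#   upp_labels = labels.clone()                    # upper-cased variant of the same text
#   loss = creator.compute_loss(logits, labels, upp_labels)
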

class DiCoW(WhisperModel):
    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.encoder = DiCoWEncoder(config)
        self.post_init()

    def forward(
            self,
            input_features: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            stno_mask: Optional[torch.FloatTensor] = None,
            decoder_input_ids: Optional[torch.LongTensor] = None,
            decoder_attention_mask: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            decoder_head_mask: Optional[torch.Tensor] = None,
            cross_attn_head_mask: Optional[torch.Tensor] = None,
            encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
            past_key_values: Optional[Cache] = None,
            decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
            decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None,
            enrollments=None
    ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            encoder_outputs = self.encoder(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                head_mask=head_mask,
                return_dict=return_dict,
                stno_mask=stno_mask,
                enrollments=enrollments
            )
        # Mirror upstream WhisperModel.forward: wrap tuple-style encoder outputs so the
        # attribute accesses below work when return_dict=True.
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.model = DiCoW(config)
        self.encoder_logits = None
        self.tokenizer = None
        self.stno_mask = None
        self.stno_mask_seek = None
        self.soft_label_creator = None
        self.post_init()

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        # Initialize the helper class
        self.soft_label_creator = SoftLabelCreator(tokenizer)

    def get_enc_logits(self, hidden_states):
        encoder = self.model.get_encoder()
        hidden_states = encoder.possibly_update_last_hidden_states(hidden_states)
        logits = encoder.lm_head(hidden_states)
        return logits

    def forward(
            self,
            input_features: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            stno_mask: Optional[torch.FloatTensor] = None,
            decoder_input_ids: Optional[torch.LongTensor] = None,
            decoder_attention_mask: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            decoder_head_mask: Optional[torch.Tensor] = None,
            cross_attn_head_mask: Optional[torch.Tensor] = None,
            encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
            past_key_values: Optional[Cache] = None,
            decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
            decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            upp_labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None,
            forced_decoder_ids: Optional[torch.LongTensor] = None,
            enrollments=None,
    ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            stno_mask=stno_mask,
            enrollments=enrollments,
        )

        dec_lm_logits = self.proj_out(outputs.last_hidden_state)
        loss = None

        if labels is not None:
            # --- UPDATED LOSS CALCULATION ---
            if self.soft_label_creator is not None:
                # Delegate all soft label creation, flattening, and min-loss logic to the helper
                dec_loss = self.soft_label_creator.compute_loss(dec_lm_logits, labels, upp_labels)
            else:
                # Fallback to original hard label implementation if tokenizer/helper not ready
                loss_fct = CrossEntropyLoss(reduction='none')
                labels = labels.to(dec_lm_logits.device)

                flat_logits = dec_lm_logits.view(-1, self.config.vocab_size)
                dec_loss1 = loss_fct(flat_logits, labels.reshape(-1))

                if upp_labels is not None:
                    upp_labels = upp_labels.to(dec_lm_logits.device)
                    dec_loss2 = loss_fct(flat_logits, upp_labels.reshape(-1))
                    # Per-token minimum over the two casings; ignore_index (-100) positions
                    # contribute zero to the mean.
                    dec_loss = torch.minimum(dec_loss1, dec_loss2).mean()
                else:
                    dec_loss = dec_loss1.mean()
            # --------------------------------

            if self.config.ctc_weight > 0.0:
                enc_lm_logits = self.get_enc_logits(outputs.encoder_last_hidden_state)
                # Prepare CTC labels: strip the decoder prompt/prefix tokens (requires
                # set_tokenizer() to have been called) and mask EOS positions.
                enc_labels = labels.clone().to(dec_lm_logits.device)
                for token in self.tokenizer.prefix_tokens:
                    if (enc_labels[:, 0] == token).all():
                        enc_labels = enc_labels[:, 1:]
                enc_labels[enc_labels == self.config.eos_token_id] = -100

                ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)
                loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
            else:
                loss = dec_loss

        if not return_dict:
            output = (dec_lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=dec_lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def _get_feat_extract_output_lengths(self, attention_mask: torch.LongTensor) -> torch.LongTensor:
        # Account for the extra 4x downsampling in the DiCoW encoder; round up and cast
        # back to an integer dtype to match the annotated return type.
        return (self.model.get_encoder()._get_feat_extract_output_lengths(attention_mask) / 4).ceil().to(torch.long)
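
# Rough end-to-end sketch (comment only). Assumptions: `CKPT` points to a trained DiCoW
# checkpoint, `audio` is a 16 kHz waveform, `labels` are tokenized targets, and `stno_mask`
# is the per-frame speaker-activity (silence/target/non-target/overlap) mask consumed by
# DiCoWEncoder; none of these are defined in this file.
#
#   from transformers import WhisperProcessor
#   processor = WhisperProcessor.from_pretrained(CKPT)
#   model = DiCoWForConditionalGeneration.from_pretrained(CKPT)
#   model.set_tokenizer(processor.tokenizer)
#   inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
#   out = model(input_features=inputs.input_features, stno_mask=stno_mask, labels=labels)
#   print(out.loss)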