# SE-DiCoW / modeling_dicow.py
# Uploaded by Lakoc — commit 96b9702 (verified): "Upload DiCoWForConditionalGeneration"
import re
from typing import Optional, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers import Cache
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperForConditionalGeneration,
    WhisperModel,
    shift_tokens_right,
)
from transformers.utils import logging

from .config import DiCoWConfig
from .encoder import DiCoWEncoder
from .generation import DiCoWGenerationMixin
# Module logger: shares the library-wide "transformers" logger rather than a
# per-module one.
logger = logging.get_logger("transformers")
# NOTE(review): this sets the verbosity of the whole `transformers` library to
# DEBUG as an import-time side effect (process-wide, not just this module).
# Consider removing it or gating it behind configuration.
logging.set_verbosity_debug()
class SoftLabelCreator(torch.nn.Module):
    """Case-invariant, timestamp-smoothed cross-entropy for the decoder.

    Timestamp tokens (vocabulary entries matching ``<|<number>.<number>|>``) are
    trained against a normalized Gaussian distribution over all timestamp tokens,
    centred on the target timestamp, instead of a one-hot target. All other tokens
    use ordinary one-hot targets. When an alternative label sequence
    (``upp_labels``, e.g. upper- vs lower-case transcripts) is supplied, each token
    keeps the smaller of the two losses.

    Only a compact ``(num_timestamps, num_timestamps)`` weight table is stored, and
    the loss is evaluated directly on log-probabilities. No dense
    ``(num_timestamps, vocab_size)`` matrix and no ``(tokens, vocab_size)`` target
    tensor is ever built, so the logits may be wider than the tokenizer vocabulary.

    Args:
        tokenizer: object exposing ``get_vocab() -> dict[str, int]``.
        timestamp_sigma: standard deviation of the smoothing Gaussian, in the same
            units as the number inside the timestamp token (presumably seconds).
            Must be positive.

    Raises:
        ValueError: if ``timestamp_sigma`` is not positive.
    """

    def __init__(self, tokenizer, timestamp_sigma=0.08):
        super().__init__()
        if timestamp_sigma <= 0:
            # A zero/negative sigma would divide by zero and yield NaN weights.
            raise ValueError(f"timestamp_sigma must be positive, got {timestamp_sigma}")
        self.tokenizer = tokenizer
        self.timestamp_sigma = timestamp_sigma
        ts_ids, ts_weights = self._build_timestamp_weights()
        # Both tables are derived from the tokenizer. They are non-persistent, so
        # they never enter state_dict / checkpoints, yet they still follow the
        # module across .to(device).
        self.register_buffer("sorted_ts_ids", ts_ids, persistent=False)
        self.register_buffer("ts_weights", ts_weights, persistent=False)

    def _build_timestamp_weights(self):
        """Return ``(sorted_ts_ids, weights)`` for the tokenizer's timestamp tokens.

        ``sorted_ts_ids`` is a 1-D long tensor of timestamp token ids in ascending
        order. ``weights[i, j]`` is the Gaussian weight that timestamp
        ``sorted_ts_ids[i]`` assigns to ``sorted_ts_ids[j]``; every row sums to 1.
        Returns ``(None, None)`` when the vocabulary has no timestamp tokens.
        """
        timestamp_pattern = re.compile(r'<\|(\d+\.\d+)\|>')
        id_to_time = {}
        for token_str, token_id in self.tokenizer.get_vocab().items():
            match = timestamp_pattern.match(token_str)
            if match:
                id_to_time[token_id] = float(match.group(1))
        if not id_to_time:
            return None, None

        # Ascending ids are required by torch.searchsorted in the lookups below.
        sorted_ids = sorted(id_to_time)
        times = torch.tensor([id_to_time[i] for i in sorted_ids])
        diff_sq = (times.unsqueeze(1) - times.unsqueeze(0)) ** 2
        weights = torch.exp(-diff_sq / (2 * self.timestamp_sigma ** 2))
        weights = weights / weights.sum(dim=1, keepdim=True)
        return torch.tensor(sorted_ids, dtype=torch.long), weights

    def _get_soft_distribution(self, labels, vocab_size):
        """Return dense soft targets of shape ``labels.shape + (vocab_size,)``.

        Kept for inspection/debugging; ``compute_loss`` does not materialize these.
        ``-100`` entries become a one-hot on token 0, so callers must mask them.
        """
        soft_labels = F.one_hot(labels.clamp(min=0), num_classes=vocab_size).float()
        if self.ts_weights is None:
            return soft_labels
        ts_ids = self.sorted_ts_ids.to(labels.device)
        is_timestamp = torch.isin(labels, ts_ids)
        if is_timestamp.any():
            rows = torch.searchsorted(ts_ids, labels[is_timestamp])
            ts_rows = soft_labels.new_zeros(rows.numel(), vocab_size)
            ts_rows[:, ts_ids] = self.ts_weights.to(
                device=labels.device, dtype=soft_labels.dtype
            )[rows]
            soft_labels[is_timestamp] = ts_rows
        return soft_labels

    def _per_token_loss(self, log_probs, flat_labels):
        """Soft-target cross-entropy per token: ``-sum_v target_v * log_probs_v``.

        Args:
            log_probs: ``(N, V)`` log-probabilities.
            flat_labels: ``(N,)`` target ids. ``-100`` is scored as token 0 and
                must be masked by the caller.

        Returns:
            ``(N,)`` per-token losses.
        """
        safe_labels = flat_labels.clamp(min=0)
        # With a one-hot target the sum collapses to the label's log-probability.
        losses = -log_probs.gather(1, safe_labels.unsqueeze(1)).squeeze(1)
        if self.ts_weights is None:
            return losses

        ts_ids = self.sorted_ts_ids.to(log_probs.device)
        is_timestamp = torch.isin(flat_labels, ts_ids)
        if not is_timestamp.any():
            return losses

        positions = is_timestamp.nonzero(as_tuple=True)[0]
        rows = torch.searchsorted(ts_ids, flat_labels[positions])
        weights = self.ts_weights.to(device=log_probs.device, dtype=log_probs.dtype)[rows]
        # (num_ts_positions, num_timestamps): log-prob of every timestamp token at
        # each position whose target is a timestamp.
        ts_log_probs = log_probs[positions.unsqueeze(1), ts_ids.unsqueeze(0)]
        ts_losses = -(weights * ts_log_probs).sum(dim=-1)
        # Out-of-place so autograd never sees an in-place write.
        return losses.index_put((positions,), ts_losses)

    def compute_loss(self, logits, labels, upp_labels):
        """Compute the case-invariant, timestamp-smoothed decoder loss.

        Steps:
        1. Score every token against its (timestamp-smoothed) target, for both
           ``labels`` and ``upp_labels``.
        2. Keep the smaller of the two losses per token.
        3. Average over tokens whose ``labels`` entry is not ``-100``.

        Args:
            logits: ``(..., V)`` decoder logits.
            labels: target ids with the same leading shape as ``logits``; ``-100``
                marks ignored positions.
            upp_labels: optional alternative targets with the same shape, or
                ``None`` to use ``labels`` alone.

        Returns:
            Scalar float32 loss (0 when every position is ignored).
        """
        vocab_size = logits.size(-1)
        device = logits.device
        # Float32 log-softmax for numerical stability with half-precision logits.
        # reshape (not view) also accepts non-contiguous logits.
        log_probs = F.log_softmax(logits.reshape(-1, vocab_size).float(), dim=-1)

        flat_labels = labels.to(device).reshape(-1)
        loss_lower = self._per_token_loss(log_probs, flat_labels)
        if upp_labels is None:
            # No alternative targets: the minimum below becomes a no-op.
            loss_upper = loss_lower
        else:
            loss_upper = self._per_token_loss(log_probs, upp_labels.to(device).reshape(-1))

        # Soft-target CE has no ignore_index, so padding is masked explicitly.
        # torch.where (not multiplication) keeps an inf at a padded position from
        # turning into NaN.
        valid = flat_labels != -100
        per_token = torch.where(
            valid, torch.minimum(loss_lower, loss_upper), torch.zeros_like(loss_lower)
        )
        return per_token.sum() / valid.sum().clamp(min=1)
class DiCoW(WhisperModel):
    """Whisper encoder-decoder whose encoder is a ``DiCoWEncoder``.

    The decoder is inherited from ``WhisperModel`` unchanged. The encoder
    additionally receives ``stno_mask`` and ``enrollments``; their meaning is
    defined by ``DiCoWEncoder``, not here.
    """

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        # Replace the stock Whisper encoder created by the parent constructor.
        self.encoder = DiCoWEncoder(config)
        self.post_init()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        enrollments=None
    ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
        """Encode ``input_features`` (unless ``encoder_outputs`` is given), then decode.

        Arguments mirror ``WhisperModel.forward``, plus:
            stno_mask: forwarded to ``DiCoWEncoder``.
            enrollments: forwarded to ``DiCoWEncoder``.

        ``encoder_outputs`` may be a model-output object or a plain tuple
        ``(last_hidden_state[, hidden_states[, attentions]])``. A tuple is wrapped
        into ``BaseModelOutput`` when ``return_dict`` is truthy.

        Returns:
            ``Seq2SeqModelOutput`` when ``return_dict`` resolves truthy, otherwise the
            decoder output tuple concatenated with the encoder output tuple.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            # Inherited WhisperModel hook applied to the features before encoding.
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
            encoder_outputs = self.encoder(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                head_mask=head_mask,
                return_dict=return_dict,
                stno_mask=stno_mask,
                enrollments=enrollments
            )
        elif return_dict and isinstance(encoder_outputs, (tuple, list)):
            # Precomputed encoder outputs passed as a plain tuple: wrap them so the
            # attribute access in the return statement below works.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
    """Whisper-style seq2seq model built on the ``DiCoW`` backbone.

    Training adds two things on top of the usual decoder loss:
    an optional timestamp-smoothed, case-invariant decoder loss via
    ``SoftLabelCreator`` (enabled by ``set_tokenizer``), and an optional auxiliary
    CTC loss on the encoder, weighted by ``config.ctc_weight``.
    """

    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.model = DiCoW(config)
        # NOTE(review): the following attributes are not read in this class.
        # Presumably they are state used by DiCoWGenerationMixin — confirm there.
        self.encoder_logits = None
        self.tokenizer = None
        self.stno_mask = None
        self.stno_mask_seek = None
        self.soft_label_creator = None
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Attach ``tokenizer`` and build the soft-label loss helper from its vocabulary.

        The helper is created on the model's current device, so its lookup tables
        are not copied from host memory on every training step. It is registered as
        a submodule, so it follows later ``.to(...)`` calls.
        """
        self.tokenizer = tokenizer
        self.soft_label_creator = SoftLabelCreator(tokenizer).to(self.device)

    def get_enc_logits(self, hidden_states):
        """Project encoder hidden states to vocabulary logits with the encoder's CTC head."""
        encoder = self.model.get_encoder()
        hidden_states = encoder.possibly_update_last_hidden_states(hidden_states)
        logits = encoder.lm_head(hidden_states)
        return logits

    def _dicow_decoder_loss(self, logits, labels, upp_labels):
        """Decoder loss: the soft-label helper if set, else case-invariant hard-label CE.

        The fallback takes the per-token minimum of the CE against ``labels`` and
        ``upp_labels`` (when given). It averages over positions whose ``labels``
        entry is not ``-100``.
        """
        if self.soft_label_creator is not None:
            return self.soft_label_creator.compute_loss(logits, labels, upp_labels)

        device = logits.device
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.to(device).reshape(-1)
        # reduction='none' yields 0 at ignore_index (-100) positions.
        loss_fct = CrossEntropyLoss(reduction='none')
        token_loss = loss_fct(flat_logits, flat_labels)
        if upp_labels is not None:
            upp_loss = loss_fct(flat_logits, upp_labels.to(device).reshape(-1))
            token_loss = torch.minimum(token_loss, upp_loss)
        # Normalize by real target tokens only, matching the soft-label path and
        # CrossEntropyLoss's 'mean'. A plain .mean() would let padding dilute the loss.
        num_tokens = (flat_labels != -100).sum().clamp(min=1)
        return token_loss.sum() / num_tokens

    def _dicow_ctc_loss(self, encoder_hidden_states, labels, device):
        """Auxiliary CTC loss on the encoder, using ``labels`` minus prefix tokens and EOS.

        Raises:
            ValueError: if no tokenizer has been attached via ``set_tokenizer``.
        """
        if self.tokenizer is None:
            raise ValueError(
                "config.ctc_weight > 0 requires a tokenizer to prepare CTC labels; "
                "call set_tokenizer() first."
            )
        enc_lm_logits = self.get_enc_logits(encoder_hidden_states)
        enc_labels = labels.clone().to(device)
        # Drop leading prefix tokens (in tokenizer order) shared by the whole batch.
        for token in self.tokenizer.prefix_tokens:
            if enc_labels.size(1) == 0:
                break
            if (enc_labels[:, 0] == token).all():
                enc_labels = enc_labels[:, 1:]
        # EOS has no acoustic counterpart; exclude it from the CTC targets.
        enc_labels[enc_labels == self.config.eos_token_id] = -100
        return self.get_encoder().get_loss(enc_lm_logits, enc_labels)

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        upp_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        forced_decoder_ids: Optional[torch.LongTensor] = None,
        enrollments=None,
    ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
        """Run DiCoW, project decoder states to vocabulary logits, and optionally compute the loss.

        Arguments mirror ``WhisperForConditionalGeneration.forward``, plus:
            stno_mask, enrollments: forwarded to the DiCoW encoder.
            upp_labels: optional alternative targets aligned with ``labels``. Each
                token is scored against whichever target gives the lower loss.
            forced_decoder_ids: accepted for signature compatibility; unused here.

        When ``labels`` is given, the loss is
        ``(1 - ctc_weight) * decoder_loss + ctc_weight * ctc_loss`` if
        ``config.ctc_weight > 0``, otherwise just the decoder loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Teacher forcing: decoder inputs are the labels shifted right by one.
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            stno_mask=stno_mask,
            enrollments=enrollments,
        )
        # Index 0 is the decoder's last hidden state for both tuple and dict outputs.
        dec_lm_logits = self.proj_out(outputs[0])

        loss = None
        if labels is not None:
            dec_loss = self._dicow_decoder_loss(dec_lm_logits, labels, upp_labels)
            ctc_weight = self.config.ctc_weight
            if ctc_weight > 0.0:
                # NOTE(review): named-field access, so this branch assumes return_dict is truthy.
                ctc_loss = self._dicow_ctc_loss(
                    outputs.encoder_last_hidden_state, labels, dec_lm_logits.device
                )
                loss = (1 - ctc_weight) * dec_loss + ctc_weight * ctc_loss
            else:
                loss = dec_loss

        if not return_dict:
            output = (dec_lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=dec_lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def _get_feat_extract_output_lengths(self, attention_mask: torch.LongTensor) -> torch.LongTensor:
        """Encoder's feature-extractor output lengths divided by 4 and rounded up.

        NOTE(review): ``/ 4`` followed by ``.ceil()`` yields a floating-point tensor
        despite the ``LongTensor`` annotation. Confirm whether callers expect
        integers before changing it.
        """
        return (self.model.get_encoder()._get_feat_extract_output_lengths(attention_mask) / 4).ceil()