import math

import torch
import torch.nn as nn
import pytorch_lightning as pl


class PhonemeCorrector(pl.LightningModule):
    def __init__(self, vocab_size, audio_vocab_size, d_model=256, nhead=4, num_layers=4, dropout=0.1, lr=1e-4,
                 weight_decay=0.01, scheduler_config=None, optimizer_config=None):
        super().__init__()
        self.save_hyperparameters()
        self.scheduler_config = scheduler_config or {}
        self.optimizer_config = optimizer_config or {}
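
        # Illustrative config dicts (the keys mirror the .get() lookups in
        # configure_optimizers below; the values shown are that method's defaults):
        #   optimizer_config = {"name": "adamw", "betas": [0.9, 0.999], "eps": 1e-8}
        #   scheduler_config = {"type": "cosine", "warmup_ratio": 0.1, "eta_min": 1e-6}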

        # Separate embedding tables for text (phoneme) tokens and audio tokens,
        # both projected into the shared d_model space.
        self.text_embedding = nn.Embedding(vocab_size, d_model)
        self.audio_embedding = nn.Embedding(audio_vocab_size, d_model)

        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # Cross-attention decoder: text positions (tgt) attend over the audio
        # token sequence (memory). Pass dropout through so the hyperparameter
        # actually takes effect here, not only in the positional encoder.
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                                                   dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
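
        # Two prediction heads over the decoded text positions:
        #   head_op  - one edit operation per position: 0 = KEEP, 1 = DELETE,
        #              2..vocab_size+1 = SUBSTITUTE with that phoneme (the same
        #              index layout the keep/del/sub masks in validation_step use);
        #   head_ins - which phoneme from the text vocabulary to insert.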
        num_ops = vocab_size + 2
        self.head_op = nn.Linear(d_model, num_ops)

        num_inserts = vocab_size
        self.head_ins = nn.Linear(d_model, num_inserts)

    def forward(self, text_ids, audio_ids, text_mask=None, audio_mask=None):
        """
        text_ids: (Batch, Text_Len)
        audio_ids: (Batch, Audio_Len)
        masks: (Batch, Len) - 1 for valid positions, 0 for padding.
        """
        text_emb = self.pos_encoder(self.text_embedding(text_ids))
        audio_emb = self.pos_encoder(self.audio_embedding(audio_ids))

        # nn.TransformerDecoder expects key_padding_masks that are True at
        # padded positions, hence the inversion of the 1-valid/0-pad masks.
        txt_pad_mask = (text_mask == 0) if text_mask is not None else None
        aud_pad_mask = (audio_mask == 0) if audio_mask is not None else None

        encoded_features = self.transformer(
            tgt=text_emb,
            memory=audio_emb,
            tgt_key_padding_mask=txt_pad_mask,
            memory_key_padding_mask=aud_pad_mask
        )

        logits_op = self.head_op(encoded_features)
        logits_ins = self.head_ins(encoded_features)

        return logits_op, logits_ins

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        audio_tokens = batch['audio_tokens']
        lbl_op = batch['labels']['op']
        lbl_ins = batch['labels']['ins']
        txt_mask = batch['masks']['text']
        audio_mask = batch['masks']['audio']

        logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)

        # Restrict both losses to non-padded text positions.
        active_loss = txt_mask.view(-1) == 1

        num_ops = self.hparams.vocab_size + 2
        loss_op = nn.functional.cross_entropy(
            logits_op.view(-1, num_ops)[active_loss],
            lbl_op.view(-1)[active_loss]
        )
        loss_ins = nn.functional.cross_entropy(
            logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
            lbl_ins.view(-1)[active_loss]
        )

        loss = loss_op + loss_ins
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_loss_op', loss_op)
        self.log('train_loss_ins', loss_ins)
        return loss
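
    # A sketch of the batch layout both steps expect (shapes illustrative;
    # B = batch size, T = text length, A = audio length):
    #   batch = {
    #       'input_ids':    LongTensor (B, T),
    #       'audio_tokens': LongTensor (B, A),
    #       'labels': {'op': LongTensor (B, T), 'ins': LongTensor (B, T)},
    #       'masks':  {'text': Tensor (B, T), 'audio': Tensor (B, A)},
    #   }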

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        audio_tokens = batch['audio_tokens']
        lbl_op = batch['labels']['op']
        lbl_ins = batch['labels']['ins']
        txt_mask = batch['masks']['text']
        audio_mask = batch['masks']['audio']

        logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)

        active_loss = txt_mask.view(-1) == 1
        num_ops = self.hparams.vocab_size + 2

        loss_op = nn.functional.cross_entropy(
            logits_op.view(-1, num_ops)[active_loss],
            lbl_op.view(-1)[active_loss]
        )
        loss_ins = nn.functional.cross_entropy(
            logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
            lbl_ins.view(-1)[active_loss]
        )

        loss = loss_op + loss_ins

        pred_op = torch.argmax(logits_op, dim=-1)
        pred_ins = torch.argmax(logits_ins, dim=-1)

        # Per-head accuracies over non-padded positions; .bool() guards against
        # masks stored as 0/1 integer or float tensors.
        valid = txt_mask.bool()
        op_correct = (pred_op == lbl_op) & valid
        op_acc = op_correct.sum().float() / valid.sum().float()

        ins_correct = (pred_ins == lbl_ins) & valid
        ins_acc = ins_correct.sum().float() / valid.sum().float()

        # The op-head accuracy doubles as the headline validation accuracy.
        overall_acc = op_acc

        # Break the op accuracy down by operation type.
        keep_mask = (lbl_op == 0) & valid
        del_mask = (lbl_op == 1) & valid
        sub_op_mask = (lbl_op >= 2) & valid

        keep_acc = torch.tensor(0.0, device=loss.device)
        del_acc = torch.tensor(0.0, device=loss.device)
        sub_op_acc = torch.tensor(0.0, device=loss.device)

        if keep_mask.sum() > 0:
            keep_correct = (pred_op == lbl_op) & keep_mask
            keep_acc = keep_correct.sum().float() / keep_mask.sum().float()

        if del_mask.sum() > 0:
            del_correct = (pred_op == lbl_op) & del_mask
            del_acc = del_correct.sum().float() / del_mask.sum().float()

        if sub_op_mask.sum() > 0:
            sub_op_correct = (pred_op == lbl_op) & sub_op_mask
            sub_op_acc = sub_op_correct.sum().float() / sub_op_mask.sum().float()

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_loss_op', loss_op, sync_dist=True)
        self.log('val_loss_ins', loss_ins, sync_dist=True)
        self.log('val_acc', overall_acc, prog_bar=True, sync_dist=True)
        self.log('val_acc_op', op_acc, sync_dist=True)
        self.log('val_acc_ins', ins_acc, sync_dist=True)
        self.log('val_acc_keep', keep_acc, sync_dist=True)
        self.log('val_acc_del', del_acc, sync_dist=True)
        self.log('val_acc_sub_op', sub_op_acc, sync_dist=True)

        return {
            'val_loss': loss,
            'val_acc': overall_acc,
            'val_acc_op': op_acc,
            'val_acc_ins': ins_acc
        }

    def configure_optimizers(self):
        optimizer_name = self.optimizer_config.get("name", "adamw").lower()
        lr = self.hparams.lr
        weight_decay = getattr(self.hparams, 'weight_decay', 0.01)

        if optimizer_name == "adamw":
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=self.optimizer_config.get("betas", [0.9, 0.999]),
                eps=self.optimizer_config.get("eps", 1.0e-8)
            )
        elif optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=self.optimizer_config.get("betas", [0.9, 0.999]),
                eps=self.optimizer_config.get("eps", 1.0e-8)
            )
        else:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        scheduler_type = self.scheduler_config.get("type", "cosine").lower()

        # Prefer the trainer's own estimate of the total optimizer steps; fall
        # back to a rough guess when no trainer is attached (e.g. in unit tests).
        max_epochs = getattr(self.trainer, 'max_epochs', 50)
        if self.trainer and hasattr(self.trainer, 'estimated_stepping_batches'):
            total_steps = self.trainer.estimated_stepping_batches
        else:
            estimated_steps_per_epoch = 1000
            total_steps = max_epochs * estimated_steps_per_epoch

        warmup_ratio = self.scheduler_config.get("warmup_ratio", 0.1)
        warmup_steps = max(1, int(total_steps * warmup_ratio))
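        # e.g. with total_steps = 10000 and warmup_ratio = 0.1, the LR ramps up
        # over the first 1000 steps and decays over the remaining 9000.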

        if scheduler_type == "cosine":
            # Use the transformers schedule when available, otherwise an
            # equivalent warmup + cosine decay built on LambdaLR.
            try:
                from transformers import get_cosine_schedule_with_warmup
                scheduler = get_cosine_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps,
                    num_cycles=0.5,
                    last_epoch=-1
                )
            except ImportError:
                from torch.optim.lr_scheduler import LambdaLR
                eta_min = self.scheduler_config.get("eta_min", 1.0e-6)

                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    # Cosine decay from lr down to eta_min.
                    progress = (step - warmup_steps) / (total_steps - warmup_steps)
                    cosine_value = 0.5 * (1 + math.cos(math.pi * progress))
                    return eta_min / lr + (1 - eta_min / lr) * cosine_value

                scheduler = LambdaLR(optimizer, lr_lambda)

        elif scheduler_type == "linear":
            try:
                from transformers import get_linear_schedule_with_warmup
                scheduler = get_linear_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps
                )
            except ImportError:
                from torch.optim.lr_scheduler import LambdaLR

                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    # Linear decay from lr down to zero.
                    progress = (step - warmup_steps) / (total_steps - warmup_steps)
                    return max(0.0, 1.0 - progress)

                scheduler = LambdaLR(optimizer, lr_lambda)

        elif scheduler_type == "polynomial":
            # power must be read outside the try block: the fallback lr_lambda
            # below closes over it when transformers is unavailable.
            power = self.scheduler_config.get("power", 1.0)
            try:
                from transformers import get_polynomial_decay_schedule_with_warmup
                scheduler = get_polynomial_decay_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps,
                    power=power
                )
            except ImportError:
                from torch.optim.lr_scheduler import LambdaLR

                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    # Polynomial decay from lr down to zero.
                    progress = (step - warmup_steps) / (total_steps - warmup_steps)
                    return max(0.0, (1.0 - progress) ** power)

                scheduler = LambdaLR(optimizer, lr_lambda)

        elif scheduler_type == "reduce_on_plateau":
            from torch.optim.lr_scheduler import ReduceLROnPlateau
            scheduler = ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=self.scheduler_config.get("factor", 0.5),
                patience=self.scheduler_config.get("patience", 3),
                min_lr=self.scheduler_config.get("min_lr", 1.0e-6),
                verbose=True
            )
            # Plateau scheduling steps once per epoch on a monitored metric.
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "train_loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            }
        else:
            # Unknown scheduler type: fall back to a constant learning rate.
            return optimizer

        # The warmup-based schedules above step once per optimizer step.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, batch-first."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the sinusoid table once and register it as a buffer so it
        # moves with the module across devices without being a trainable parameter.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
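

if __name__ == "__main__":
    # Minimal smoke test (a sketch: the vocab sizes and shapes below are
    # illustrative, not taken from any real tokenizer or audio codec).
    vocab_size, audio_vocab_size = 40, 1024
    model = PhonemeCorrector(vocab_size=vocab_size, audio_vocab_size=audio_vocab_size)

    text_ids = torch.randint(0, vocab_size, (2, 10))
    audio_ids = torch.randint(0, audio_vocab_size, (2, 50))
    text_mask = torch.ones(2, 10, dtype=torch.long)
    audio_mask = torch.ones(2, 50, dtype=torch.long)

    logits_op, logits_ins = model(text_ids, audio_ids, text_mask, audio_mask)
    assert logits_op.shape == (2, 10, vocab_size + 2)  # KEEP/DELETE + substitutions
    assert logits_ins.shape == (2, 10, vocab_size)     # insertion phonemes
    print("ok:", logits_op.shape, logits_ins.shape)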