Release HuPER Corrector weights and inference code
7d0662d
import torch
import torch.nn as nn
import math
import pytorch_lightning as pl
class PhonemeCorrector(pl.LightningModule):
    def __init__(self, vocab_size, audio_vocab_size, d_model=256, nhead=4, num_layers=4, dropout=0.1, lr=1e-4,
                 weight_decay=0.01, scheduler_config=None, optimizer_config=None):
        super().__init__()
        self.save_hyperparameters()
        self.scheduler_config = scheduler_config or {}
        self.optimizer_config = optimizer_config or {}
        # 1. Embeddings
        self.text_embedding = nn.Embedding(vocab_size, d_model)
        self.audio_embedding = nn.Embedding(audio_vocab_size, d_model)
        # Positional Encoding (Standard Sinusoidal)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # 2. The Core Transformer (Text querying Audio)
        # Pass the dropout hyperparameter through so it also applies inside the decoder layers.
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        # 3. Prediction Heads - 2-head architecture
        # Head 1: Operation (KEEP, DEL, SUB:AA, SUB:AE, ...)
        # num_ops = vocab_size + 2 (KEEP=0, DEL=1, SUB:phonemes=2+)
        # This matches the precomputed op_ids format
        num_ops = vocab_size + 2
        self.head_op = nn.Linear(d_model, num_ops)
        # Head 2: Insertion (NONE=0, AA, AE, ...)
        # num_inserts = vocab_size (NONE=0, then phonemes)
        num_inserts = vocab_size
        self.head_ins = nn.Linear(d_model, num_inserts)
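
    # Illustrative sketch only (my assumption, not part of the released code): given the
    # head layouts above, a greedy per-position decode would map op id 0 -> KEEP,
    # 1 -> DEL, and k >= 2 -> SUB with some phoneme index (here assumed to be k - 2),
    # while a non-zero insertion id names a phoneme to insert at that position.
    @staticmethod
    def _decode_position_sketch(op_id, ins_id):
        if op_id == 0:
            op = ("KEEP", None)
        elif op_id == 1:
            op = ("DEL", None)
        else:
            op = ("SUB", op_id - 2)  # assumed offset into the phoneme vocabulary
        ins = None if ins_id == 0 else ins_id  # assumed: non-zero ids name the phoneme to insert
        return op, ins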

    def forward(self, text_ids, audio_ids, text_mask=None, audio_mask=None):
        """
        text_ids: (Batch, Text_Len)
        audio_ids: (Batch, Audio_Len)
        masks: (Batch, Len) - 1 for valid, 0 for pad.
        """
        text_emb = self.pos_encoder(self.text_embedding(text_ids))
        audio_emb = self.pos_encoder(self.audio_embedding(audio_ids))
        txt_pad_mask = (text_mask == 0) if text_mask is not None else None
        aud_pad_mask = (audio_mask == 0) if audio_mask is not None else None
        encoded_features = self.transformer(
            tgt=text_emb,
            memory=audio_emb,
            tgt_key_padding_mask=txt_pad_mask,
            memory_key_padding_mask=aud_pad_mask
        )
        logits_op = self.head_op(encoded_features)
        logits_ins = self.head_ins(encoded_features)
        return logits_op, logits_ins

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        audio_tokens = batch['audio_tokens']
        lbl_op = batch['labels']['op']
        lbl_ins = batch['labels']['ins']
        txt_mask = batch['masks']['text']
        audio_mask = batch['masks']['audio']
        logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)
        # Active loss mask (only compute loss on valid text tokens)
        active_loss = txt_mask.view(-1) == 1
        # OP LOSS (includes KEEP, DEL, and all SUB:phoneme operations)
        num_ops = self.hparams.vocab_size + 2
        loss_op = nn.functional.cross_entropy(
            logits_op.view(-1, num_ops)[active_loss],
            lbl_op.view(-1)[active_loss]
        )
        # INS LOSS
        loss_ins = nn.functional.cross_entropy(
            logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
            lbl_ins.view(-1)[active_loss]
        )
        loss = loss_op + loss_ins
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_loss_op', loss_op)
        self.log('train_loss_ins', loss_ins)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        audio_tokens = batch['audio_tokens']
        lbl_op = batch['labels']['op']
        lbl_ins = batch['labels']['ins']
        txt_mask = batch['masks']['text']
        audio_mask = batch['masks']['audio']
        logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)
        # Compute losses
        active_loss = txt_mask.view(-1) == 1
        num_ops = self.hparams.vocab_size + 2
        loss_op = nn.functional.cross_entropy(
            logits_op.view(-1, num_ops)[active_loss],
            lbl_op.view(-1)[active_loss]
        )
        loss_ins = nn.functional.cross_entropy(
            logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
            lbl_ins.view(-1)[active_loss]
        )
        loss = loss_op + loss_ins
        # Compute accuracy (cast the mask to bool so the bitwise ops below work
        # regardless of the mask's integer/float dtype)
        txt_mask = txt_mask.bool()
        pred_op = torch.argmax(logits_op, dim=-1)
        pred_ins = torch.argmax(logits_ins, dim=-1)
        # OP accuracy
        op_correct = (pred_op == lbl_op) & txt_mask
        op_acc = op_correct.sum().float() / txt_mask.sum().float()
        # INS accuracy
        ins_correct = (pred_ins == lbl_ins) & txt_mask
        ins_acc = ins_correct.sum().float() / txt_mask.sum().float()
        # Overall accuracy: correct OP prediction
        overall_acc = op_acc
        # Per-operation accuracy (KEEP=0, DEL=1, SUB>=2)
        keep_mask = (lbl_op == 0) & txt_mask
        del_mask = (lbl_op == 1) & txt_mask
        sub_op_mask = (lbl_op >= 2) & txt_mask
        keep_acc = torch.tensor(0.0, device=loss.device)
        del_acc = torch.tensor(0.0, device=loss.device)
        sub_op_acc = torch.tensor(0.0, device=loss.device)
        if keep_mask.sum() > 0:
            keep_correct = (pred_op == lbl_op) & keep_mask
            keep_acc = keep_correct.sum().float() / keep_mask.sum().float()
        if del_mask.sum() > 0:
            del_correct = (pred_op == lbl_op) & del_mask
            del_acc = del_correct.sum().float() / del_mask.sum().float()
        if sub_op_mask.sum() > 0:
            sub_op_correct = (pred_op == lbl_op) & sub_op_mask
            sub_op_acc = sub_op_correct.sum().float() / sub_op_mask.sum().float()
        # Log metrics
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_loss_op', loss_op, sync_dist=True)
        self.log('val_loss_ins', loss_ins, sync_dist=True)
        self.log('val_acc', overall_acc, prog_bar=True, sync_dist=True)
        self.log('val_acc_op', op_acc, sync_dist=True)
        self.log('val_acc_ins', ins_acc, sync_dist=True)
        self.log('val_acc_keep', keep_acc, sync_dist=True)
        self.log('val_acc_del', del_acc, sync_dist=True)
        self.log('val_acc_sub_op', sub_op_acc, sync_dist=True)
        return {
            'val_loss': loss,
            'val_acc': overall_acc,
            'val_acc_op': op_acc,
            'val_acc_ins': ins_acc
        }

    def configure_optimizers(self):
        # Get optimizer configuration
        optimizer_name = self.optimizer_config.get("name", "adamw").lower()
        lr = self.hparams.lr
        weight_decay = getattr(self.hparams, 'weight_decay', 0.01)
        if optimizer_name == "adamw":
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=self.optimizer_config.get("betas", [0.9, 0.999]),
                eps=self.optimizer_config.get("eps", 1.0e-8)
            )
        elif optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=self.optimizer_config.get("betas", [0.9, 0.999]),
                eps=self.optimizer_config.get("eps", 1.0e-8)
            )
        else:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")
        # Configure scheduler
        scheduler_type = self.scheduler_config.get("type", "cosine").lower()
        # Calculate total training steps
        max_epochs = getattr(self.trainer, 'max_epochs', 50)
        if self.trainer and hasattr(self.trainer, 'estimated_stepping_batches'):
            total_steps = self.trainer.estimated_stepping_batches
        else:
            # Fallback: estimate steps per epoch
            estimated_steps_per_epoch = 1000  # Conservative estimate
            total_steps = max_epochs * estimated_steps_per_epoch
        warmup_ratio = self.scheduler_config.get("warmup_ratio", 0.1)
        warmup_steps = max(1, int(total_steps * warmup_ratio))
        if scheduler_type == "cosine":
            # Use transformers' cosine scheduler with warmup
            try:
                from transformers import get_cosine_schedule_with_warmup
                scheduler = get_cosine_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps,
                    num_cycles=0.5,  # Default cosine cycles
                    last_epoch=-1
                )
            except ImportError:
                # Fallback to PyTorch implementation (eta_min acts as the LR floor here)
                from torch.optim.lr_scheduler import LambdaLR
                eta_min = self.scheduler_config.get("eta_min", 1.0e-6)
                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    else:
                        # Cosine annealing after warmup
                        progress = (step - warmup_steps) / (total_steps - warmup_steps)
                        cosine_value = 0.5 * (1 + math.cos(math.pi * progress))
                        return eta_min / lr + (1 - eta_min / lr) * cosine_value
                scheduler = LambdaLR(optimizer, lr_lambda)
        elif scheduler_type == "linear":
            # Use transformers' linear scheduler with warmup
            try:
                from transformers import get_linear_schedule_with_warmup
                scheduler = get_linear_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps
                )
            except ImportError:
                # Fallback to PyTorch implementation
                from torch.optim.lr_scheduler import LambdaLR
                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    else:
                        progress = (step - warmup_steps) / (total_steps - warmup_steps)
                        return max(0.0, 1.0 - progress)
                scheduler = LambdaLR(optimizer, lr_lambda)
        elif scheduler_type == "polynomial":
            # Use transformers' polynomial scheduler with warmup
            # (fetch `power` before the try so the fallback below can also use it)
            power = self.scheduler_config.get("power", 1.0)
            try:
                from transformers import get_polynomial_decay_schedule_with_warmup
                scheduler = get_polynomial_decay_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps,
                    power=power
                )
            except ImportError:
                # Fallback to PyTorch implementation (polynomial decay with warmup)
                from torch.optim.lr_scheduler import LambdaLR
                def lr_lambda(step):
                    if step < warmup_steps:
                        return step / warmup_steps
                    else:
                        progress = (step - warmup_steps) / (total_steps - warmup_steps)
                        return max(0.0, (1.0 - progress) ** power)
                scheduler = LambdaLR(optimizer, lr_lambda)
        elif scheduler_type == "reduce_on_plateau":
            from torch.optim.lr_scheduler import ReduceLROnPlateau
            scheduler = ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=self.scheduler_config.get("factor", 0.5),
                patience=self.scheduler_config.get("patience", 3),
                min_lr=self.scheduler_config.get("min_lr", 1.0e-6),
                verbose=True
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "train_loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            }
        else:
            # No scheduler
            return optimizer
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }
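
    # Example configs (illustration only; the key names are exactly the ones read by
    # configure_optimizers above, but the values shown are placeholders, not the
    # settings used for the released checkpoint):
    #   optimizer_config = {"name": "adamw", "betas": [0.9, 0.999], "eps": 1e-8}
    #   scheduler_config = {"type": "cosine", "warmup_ratio": 0.1, "eta_min": 1e-6}
    # For "polynomial" the extra key is "power"; for "reduce_on_plateau" the keys are
    # "factor", "patience", and "min_lr".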

# Helper for Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: (Batch, Seq, Dim)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
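

# Minimal smoke test (illustration only; the vocabulary sizes and sequence lengths below
# are made-up placeholders, not the values used for the released HuPER Corrector
# checkpoint). It just checks that a forward pass produces logits of the expected shapes.
if __name__ == "__main__":
    vocab_size, audio_vocab_size = 45, 1024  # assumed sizes for this sketch
    model = PhonemeCorrector(vocab_size=vocab_size, audio_vocab_size=audio_vocab_size)
    text_ids = torch.randint(0, vocab_size, (2, 10))          # (Batch, Text_Len)
    audio_ids = torch.randint(0, audio_vocab_size, (2, 50))   # (Batch, Audio_Len)
    text_mask = torch.ones(2, 10, dtype=torch.long)           # 1 = valid, 0 = pad
    audio_mask = torch.ones(2, 50, dtype=torch.long)
    logits_op, logits_ins = model(text_ids, audio_ids, text_mask, audio_mask)
    print(logits_op.shape)   # expected: (2, 10, vocab_size + 2)
    print(logits_ins.shape)  # expected: (2, 10, vocab_size)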