import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import ConcatDataset

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import HubertModel

from model_pinyin import MMKWS2
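
# Training script for MMKWS2 (pinyin variant): matches a compare utterance
# against an anchor keyword given as speech plus a pinyin/text embedding,
# supervised at both the utterance and the frame level. Anchor speech is
# encoded by a frozen TencentGameMate/chinese-hubert-large model.
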

class MMKWS2_Wrapper(LightningModule):
    """LightningModule wrapping MMKWS2 with a frozen HuBERT feature extractor."""

    def __init__(self):
        super().__init__()
        self.model = MMKWS2()
        self.criterion = nn.BCEWithLogitsLoss()

        # Frozen fp16 HuBERT encoder, used only to extract anchor-speech
        # features; freezing keeps its weights out of the optimizer update.
        hubert_model = HubertModel.from_pretrained(
            "TencentGameMate/chinese-hubert-large"
        ).half().eval()
        hubert_model.requires_grad_(False)
        self._hubert_model = hubert_model
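
    # Batch schema, inferred from its use in training_step (shapes are
    # illustrative assumptions, not taken from dataloader_pinyin):
    #   anchor_wave      (B, T)    raw anchor waveform fed to HuBERT
    #   anchor_embedding (B, L, D) pinyin/text embedding of the keyword
    #   compare_wave     (B, T')   query waveform matched against the anchor
    #   compare_lengths  (B,)      valid lengths of compare_wave
    #   label            (B,)      utterance-level match target in {0, 1}
    #   seq_label        (B, S)    frame-level targets, -1 marks padding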

    def training_step(self, batch, batch_idx):
        anchor_wave, anchor_text_embedding, compare_wave, compare_lengths, label, seq_label = (
            batch['anchor_wave'], batch['anchor_embedding'], batch['compare_wave'],
            batch['compare_lengths'], batch['label'], batch['seq_label'],
        )

        # Anchor features come from the frozen fp16 HuBERT; cast the
        # activations back to the training dtype before the model sees them.
        with torch.no_grad():
            outputs = self._hubert_model(anchor_wave.half())
            anchor_wave_embedding = outputs.last_hidden_state
        anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype)

        logits, seq_logits = self.model(
            anchor_wave_embedding,
            anchor_text_embedding,
            compare_wave,
            compare_lengths,
        )

        # Utterance-level detection loss.
        utt_loss = self.criterion(logits, label.float())

        # Frame-level loss: -1 marks padded frames, so zero them in the
        # targets, mask them out of the loss, and normalise by the number
        # of valid frames (the epsilon guards against an all-padding batch).
        mask = (seq_label != -1).float()
        seq_label_valid = seq_label.clone()
        seq_label_valid[seq_label == -1] = 0
        seq_loss = F.binary_cross_entropy_with_logits(
            seq_logits, seq_label_valid.float(), weight=mask, reduction='sum'
        ) / (mask.sum() + 1e-6)

        loss = utt_loss + seq_loss

        self.log('train/utt_loss', utt_loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log('train/seq_loss', seq_loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log('train/loss', loss, on_step=True, on_epoch=False, prog_bar=True)
        return loss
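
    # LambdaLR scales the base LR by lr_lambda(step), so the schedule below
    # is a staircase: steps 0-999 run at 1e-3, steps 1000-1999 at 9.5e-4,
    # steps 2000-2999 at ~9.03e-4, and so on.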

    def configure_optimizers(self):
        # Optimise only the trainable parameters (HuBERT stays frozen).
        params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(params, lr=1e-3)

        # Decay the LR by a factor of 0.95 every 1000 optimizer steps.
        def lr_lambda(step):
            return 0.95 ** (step // 1000)

        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # step the scheduler per optimizer step
                "frequency": 1,
            },
        }
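
# Example of restoring a trained wrapper for evaluation (hedged sketch; the
# checkpoint filename is illustrative):
#
#   model = MMKWS2_Wrapper.load_from_checkpoint(
#       "/nvme01/aizq/kws-agent/ckpts/step001000.ckpt"
#   )
#   model.eval()
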

if __name__ == "__main__":
    import os

    # Restrict training to GPU 7; this must be set before any CUDA context
    # is created, hence first thing in main.
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'

    pl.seed_everything(2024)

    from dataset_pinyin import PairDataset
    from dataloader_pinyin import get_dataloader

    # Anchor pairs plus synthetic pairs, sharing the same audio root.
    dataset1 = PairDataset(
        '/nvme01/aizq/kws-agent/data/anchor_pairs.parquet',
        '/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
        augment=True,
    )
    dataset2 = PairDataset(
        '/nvme01/aizq/kws-agent/data/synthetic_pairs.parquet',
        '/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
        augment=True,
    )
    dataset = ConcatDataset([dataset1, dataset2])

    dataloader = get_dataloader(dataset, batch_size=1024)

    model = MMKWS2_Wrapper()

    # Keep every checkpoint, saving every 500 training steps instead of at
    # epoch boundaries.
    model_checkpoint = ModelCheckpoint(
        dirpath="/nvme01/aizq/kws-agent/ckpts",
        filename="step{step:06d}",
        save_top_k=-1,
        save_on_train_epoch_end=False,
        every_n_train_steps=500,
    )
    logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/kws-agent/logs/', name='MMKWS+')
    trainer = Trainer(
        devices=1,
        accelerator='gpu',
        logger=logger,
        max_epochs=4,
        callbacks=[model_checkpoint],
        accumulate_grad_batches=2,  # effective batch size 2048
    )
    trainer.fit(model, train_dataloaders=dataloader)