# OpenKWS / train_pinyin.py
# Source: ZhiqiAi's Hugging Face repo, commit 693e27f ("Upload 16 files").
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import Wav2Vec2FeatureExtractor, HubertModel
from model_pinyin import MMKWS2
from torch.utils.data import Dataset, ConcatDataset
class MMKWS2_Wrapper(LightningModule):
    """LightningModule that trains the MMKWS2 keyword-spotting head.

    The anchor waveform is embedded with a frozen fp16
    ``chinese-hubert-large`` encoder; only the MMKWS2 head receives
    gradient updates.
    """

    def __init__(self):
        super().__init__()
        self.model = MMKWS2()
        self.criterion = nn.BCEWithLogitsLoss()
        # HuBERT is a fixed feature extractor. Assigning an nn.Module to an
        # attribute registers it as a submodule regardless of the name, so
        # explicitly disable its gradients to keep it frozen and out of the
        # optimizer (training_step also runs it under no_grad).
        hubert_model = (
            HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large")
            .half()
            .eval()
        )
        hubert_model.requires_grad_(False)
        self._hubert_model = hubert_model  # leading underscore: internal use only

    def training_step(self, batch, batch_idx):
        """Return the combined utterance-level and sequence-level BCE loss.

        Expects ``batch`` to contain 'anchor_wave', 'anchor_embedding',
        'compare_wave', 'compare_lengths', 'label' and 'seq_label', where
        ``seq_label == -1`` marks positions to ignore (padding).
        """
        anchor_wave = batch['anchor_wave']
        anchor_text_embedding = batch['anchor_embedding']
        compare_wave = batch['compare_wave']
        compare_lengths = batch['compare_lengths']
        label = batch['label']
        seq_label = batch['seq_label']

        # Extract anchor acoustic features with the frozen fp16 encoder.
        with torch.no_grad():
            outputs = self._hubert_model(anchor_wave.half())
        anchor_wave_embedding = outputs.last_hidden_state
        # Cast back from fp16 to the training dtype before the trainable head.
        anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype)

        logits, seq_logits = self.model(
            anchor_wave_embedding,
            anchor_text_embedding,
            compare_wave,
            compare_lengths,
        )

        # Utterance-level binary classification loss.
        utt_loss = self.criterion(logits, label.float())

        # Sequence loss, masking out positions where seq_label == -1.
        mask = (seq_label != -1).float()
        seq_label_valid = seq_label.clone()
        seq_label_valid[seq_label == -1] = 0  # neutralize ignored positions
        seq_loss = F.binary_cross_entropy_with_logits(
            seq_logits, seq_label_valid.float(), weight=mask, reduction='sum'
        ) / (mask.sum() + 1e-6)

        loss = utt_loss + seq_loss
        # Log every step (on_step=True); nothing is aggregated per epoch.
        self.log('train/utt_loss', utt_loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log('train/seq_loss', seq_loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log('train/loss', loss, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam (lr=1e-3) with LR multiplied by 0.95 every 1000 optimizer steps."""
        # Only optimize trainable parameters — excludes the frozen HuBERT.
        optimizer = torch.optim.Adam(
            (p for p in self.parameters() if p.requires_grad), lr=1e-3
        )

        def lr_lambda(step):
            # Stepwise exponential decay: 0.95 ** (completed thousands of steps).
            return 0.95 ** (step // 1000)

        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # decay per optimizer step, not per epoch
                "frequency": 1,
            },
        }
# 3. Trainer setup and training entry point.
if __name__ == "__main__":
    pl.seed_everything(2024)

    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'

    from dataset_pinyin import PairDataset
    from dataloader_pinyin import get_dataloader

    # Training data: real anchor pairs plus synthetic pairs, both augmented.
    dataset1 = PairDataset(
        '/nvme01/aizq/kws-agent/data/anchor_pairs.parquet',
        '/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
        augment=True,
    )
    dataset2 = PairDataset(
        '/nvme01/aizq/kws-agent/data/synthetic_pairs.parquet',
        '/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
        augment=True,
    )
    dataset = ConcatDataset([dataset1, dataset2])

    dataloader = get_dataloader(dataset, batch_size=1024)
    model = MMKWS2_Wrapper()

    # Save a checkpoint every 500 training steps and keep all of them.
    model_checkpoint = ModelCheckpoint(
        dirpath="/nvme01/aizq/kws-agent/ckpts",
        filename="step{step:06d}",
        save_top_k=-1,
        save_on_train_epoch_end=False,  # step-based, not epoch-based, saving
        every_n_train_steps=500,
    )
    logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/kws-agent/logs/', name='MMKWS+')

    trainer = Trainer(
        devices=1,
        accelerator='gpu',
        logger=logger,
        max_epochs=4,
        callbacks=[model_checkpoint],
        accumulate_grad_batches=2,  # effective batch size 2 * 1024 = 2048
    )
    trainer.fit(model, train_dataloaders=dataloader)