Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from typing import Dict, Tuple | |
| import os | |
| from tqdm import tqdm | |
| import wandb | |
| from ..models.encoder import SpeakerEncoder | |
| from ..configs.config import Config, TrainingConfig | |
class MetaTrainer:
    """Meta-learning trainer for few-shot voice cloning.

    Each training step is one episode made of a support set and a query set.
    Support embeddings are averaged per speaker into centroids, and query
    embeddings are classified against those centroids with cross-entropy.
    """

    def __init__(
        self,
        model: SpeakerEncoder,
        config: Config,
        use_wandb: bool = True
    ):
        """Move the model to the configured device and build the optimizer.

        Args:
            model: Encoder that maps a batch of mel spectrograms to embeddings.
            config: Project configuration; this class reads
                ``training.device``, ``training.learning_rate``,
                ``training.checkpoint_dir`` and ``meta_learning.n_way``.
            use_wandb: When true, start a wandb run here and log per-batch
                metrics during training.
        """
        self.config = config
        self.use_wandb = use_wandb
        self.device = torch.device(config.training.device)
        self.model = model.to(self.device)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.training.learning_rate,
        )
        self.criterion = nn.CrossEntropyLoss()
        if use_wandb:
            wandb.init(project="voice-cloning", config=config)

    def compute_loss(
        self,
        support_data: Dict[str, torch.Tensor],
        query_data: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, float]:
        """Compute the episode loss and accuracy.

        Args:
            support_data: Dict with
                - ``mel_spec``: [n_way*k_shot, n_mels, time]
                - ``speaker_ids``: [n_way*k_shot]
            query_data: Dict with
                - ``mel_spec``: [n_way*k_query, n_mels, time]
                - ``speaker_ids``: [n_way*k_query]

        Returns:
            loss: Scalar cross-entropy loss tensor (still attached to the
                graph unless called under ``torch.no_grad()``).
            acc: Fraction of query samples whose highest-scoring centroid
                matches their speaker id, as a Python float.
        """
        support_mel = support_data['mel_spec'].to(self.device)
        query_mel = query_data['mel_spec'].to(self.device)
        # Move the ids once instead of transferring a fresh mask per speaker.
        support_ids = support_data['speaker_ids'].to(self.device)

        support_embeds = self.model(support_mel)  # [n_way*k_shot, embedding_dim]
        query_embeds = self.model(query_mel)      # [n_way*k_query, embedding_dim]

        # One centroid per speaker index 0..n_way-1: the mean of that
        # speaker's support embeddings.
        # NOTE(review): this relies on speaker ids being episode-local labels
        # in [0, n_way); a speaker with no support sample yields a NaN
        # centroid (mean over an empty selection) -- confirm the sampler
        # guarantees at least one sample per speaker.
        centroids = torch.stack([
            support_embeds[support_ids == speaker_idx].mean(dim=0)
            for speaker_idx in range(self.config.meta_learning.n_way)
        ])  # [n_way, embedding_dim]

        # Dot-product similarity of every query embedding to every centroid,
        # used directly as classification logits.
        similarities = torch.matmul(query_embeds, centroids.T)  # [n_way*k_query, n_way]

        target = query_data['speaker_ids'].to(self.device)  # [n_way*k_query]
        loss = self.criterion(similarities, target)

        pred = similarities.argmax(dim=1)  # [n_way*k_query]
        acc = (pred == target).float().mean().item()
        return loss, acc

    def train_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch.

        Args:
            dataloader: Iterable yielding ``(support_batch, query_batch)``
                pairs; must support ``len()``.

        Returns:
            Mean loss and mean accuracy over ``len(dataloader)`` batches.

        Raises:
            ZeroDivisionError: If ``len(dataloader)`` is 0.
        """
        self.model.train()
        total_loss = 0
        total_acc = 0
        with tqdm(dataloader, desc="Training") as pbar:
            for batch_idx, (support_batch, query_batch) in enumerate(pbar):
                self.optimizer.zero_grad()
                loss, acc = self.compute_loss(support_batch, query_batch)
                loss.backward()
                # Clip the global gradient norm to 3.0 before the update.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3.0)
                self.optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                total_acc += acc
                # The progress bar shows running means over the batches so far.
                pbar.set_postfix({
                    'loss': total_loss / (batch_idx + 1),
                    'acc': total_acc / (batch_idx + 1)
                })
                if self.use_wandb:
                    wandb.log({
                        'batch_loss': batch_loss,
                        'batch_acc': acc
                    })
        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        return avg_loss, avg_acc

    def validate(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Evaluate the model without gradient tracking.

        Args:
            dataloader: Iterable yielding ``(support_batch, query_batch)``
                pairs; must support ``len()``.

        Returns:
            Mean loss and mean accuracy over ``len(dataloader)`` batches.

        Raises:
            ZeroDivisionError: If ``len(dataloader)`` is 0.
        """
        self.model.eval()
        total_loss = 0
        total_acc = 0
        with torch.no_grad():
            for support_batch, query_batch in dataloader:
                loss, acc = self.compute_loss(support_batch, query_batch)
                total_loss += loss.item()
                total_acc += acc
        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        return avg_loss, avg_acc

    def save_checkpoint(self, epoch: int, loss: float, acc: float):
        """Write model and optimizer state to ``checkpoint_epoch_{epoch}.pt``.

        The file is placed in ``config.training.checkpoint_dir``, which is
        created if it does not exist.

        Args:
            epoch: Epoch number, stored in the checkpoint and used in the
                file name.
            loss: Loss value to store alongside the state.
            acc: Accuracy value to store alongside the state.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'acc': acc
        }
        checkpoint_dir = self.config.training.checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(
            checkpoint_dir,
            f'checkpoint_epoch_{epoch}.pt'
        )
        torch.save(checkpoint, checkpoint_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a checkpoint file.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            Tuple ``(epoch, loss, acc)`` as stored in the checkpoint.
        """
        # map_location remaps stored tensors onto this trainer's device, so a
        # checkpoint saved on a GPU can be restored on a CPU-only host (or on
        # a different GPU) instead of failing during deserialization.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['loss'], checkpoint['acc']