Spaces:
Build error
Build error
| import argparse | |
| import torch | |
| from pathlib import Path | |
| from torch.utils.data import DataLoader | |
| from src.configs.config import Config | |
| from src.models.encoder import SpeakerEncoder | |
| from src.data.dataset import create_meta_learning_dataloader | |
| from src.trainers.meta_trainer import MetaTrainer | |
def parse_args():
    """Read the command-line options for speaker-encoder training.

    Returns:
        argparse.Namespace with:
          - data_dir (str): dataset root directory; required.
          - checkpoint (str | None): checkpoint path to resume from, or None.
          - no_wandb (bool): True when Weights & Biases logging is disabled.
    """
    cli = argparse.ArgumentParser(description='训练说话人编码器')
    # Dataset root directory (mandatory).
    cli.add_argument('--data_dir', required=True, type=str, help='数据集根目录')
    # Optional checkpoint from which training is resumed.
    cli.add_argument('--checkpoint', help='恢复训练的检查点路径', type=str)
    # Boolean switch that turns Weights & Biases logging off.
    cli.add_argument('--no_wandb', help='禁用Weights & Biases日志', action='store_true')
    namespace = cli.parse_args()
    return namespace
def _build_loader(config, n_tasks):
    """Build a meta-learning dataloader from ``config`` with ``n_tasks`` tasks.

    Shared by the training and validation loaders so both use the same
    n_way / k_shot / k_query / batch_size / num_workers settings; only the
    number of tasks differs.
    """
    meta = config.meta_learning
    return create_meta_learning_dataloader(
        root_dir=config.data.root_dir,
        n_way=meta.n_way,
        k_shot=meta.k_shot,
        k_query=meta.k_query,
        n_tasks=n_tasks,
        batch_size=meta.batch_size,
        num_workers=meta.num_workers,
    )


def main():
    """Train the speaker encoder with meta-learning.

    Parses CLI arguments, builds the train/validation dataloaders, the
    ``SpeakerEncoder`` model and the ``MetaTrainer``. Optionally resumes from
    a checkpoint, then runs the epoch loop. A checkpoint is saved whenever
    validation accuracy improves on the best value seen so far.
    """
    args = parse_args()

    # Load the configuration and point it at the dataset given on the CLI.
    config = Config()
    config.data.root_dir = args.data_dir

    train_loader = _build_loader(config, config.meta_learning.n_tasks)
    # Validation uses a tenth of the training task count, clamped to at least
    # one so the loader is never empty when n_tasks < 10.
    # NOTE(review): validation tasks are drawn from the same root_dir as
    # training; unless the dataset separates splits internally, this measures
    # accuracy on training speakers — confirm.
    val_loader = _build_loader(config, max(1, config.meta_learning.n_tasks // 10))

    model = SpeakerEncoder(
        input_dim=config.audio.n_mels,
        hidden_dim=256,
        embedding_dim=512,
    )

    trainer = MetaTrainer(
        model=model,
        config=config,
        use_wandb=not args.no_wandb,
    )

    # Start from -inf so the first validated epoch always produces a
    # checkpoint, even if its accuracy is exactly 0.
    best_val_acc = float('-inf')
    start_epoch = 0
    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        # Assumes load_checkpoint returns (epoch, loss, acc), mirroring the
        # keyword arguments passed to save_checkpoint below — TODO confirm.
        last_epoch, _, ckpt_acc = trainer.load_checkpoint(args.checkpoint)
        start_epoch = last_epoch + 1
        # Seed the running best from the checkpoint so that a worse epoch
        # right after resuming does not overwrite the saved best model.
        if ckpt_acc is not None:
            best_val_acc = float(ckpt_acc)

    print("Starting training...")
    for epoch in range(start_epoch, config.training.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.training.num_epochs}")

        train_loss, train_acc = trainer.train_epoch(train_loader)
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        val_loss, val_acc = trainer.validate(val_loader)
        print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Keep only improvements in validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            trainer.save_checkpoint(
                epoch=epoch,
                loss=val_loss,
                acc=val_acc,
            )
            print(f"Saved new best model with validation accuracy: {val_acc:.4f}")
if __name__ == '__main__':
    # Start training only when executed as a script, not when imported.
    main()