"""
LoRA knowledge-distillation training with TAID (Temperature-Aware Interpolation Distillation).
Covers hyperparameter tuning and the complete training pipeline.
"""
|
|
| import os |
| import gc |
| import json |
| import torch |
| from transformers import WhisperProcessor, set_seed |
|
|
| from src.data.dataset import load_dataset |
| from src.data.dataloader import ASRDataLoader |
| from src.utils.config import load_config |
| from src.trainers.taid_trainer import TAIDDistillationTrainer |
|
|
| |
| from scripts.hyperparameter_search import run_hyperparameter_search |
|
|
def _build_data_loader(cfg, processor) -> "ASRDataLoader":
    """Create an ASRDataLoader from the training/data settings found in *cfg*."""
    return ASRDataLoader(
        processor=processor,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.data.num_workers,
        max_frames=cfg.data.max_frames,
        sample_rate=cfg.data.sample_rate,
        pin_memory=cfg.data.pin_memory,
        persistent_workers=cfg.data.persistent_workers,
        prefetch_factor=cfg.data.prefetch_factor,
    )


def _print_lambda_progress(hist: dict) -> None:
    """Print a strided sample of the ``taid_lambda`` values stored in *hist*.

    Does nothing when *hist* has no ``"taid_lambda"`` key. The stride is
    ``len // 5`` (at least 1), and each sampled value is printed next to the
    entry of ``hist["steps"]`` at the same index.
    """
    if "taid_lambda" not in hist:
        return

    lambdas = hist["taid_lambda"]
    stride = max(1, len(lambdas) // 5)  # loop-invariant, computed once
    print("\nTAID Lambda 进展:")
    for i in range(0, len(lambdas), stride):
        step = hist["steps"][i]
        lambda_val = lambdas[i]
        print(f" Step {step}: λ = {lambda_val:.3f}")


def run_training(config_path: str = "../configs/default_config.yaml") -> float:
    """Run the full pipeline: hyperparameter search, then the training run.

    The configuration is loaded twice: once to set up the data for the
    hyperparameter search, and again afterwards so that the training run
    uses whatever the search step left in the configuration file.

    Args:
        config_path: Path of the configuration file handed to ``load_config``.

    Returns:
        The value returned by ``TAIDDistillationTrainer.train()``; the
        original docstring describes it as the best validation metric.
    """
    print("========================================")
    print("开始完整训练流程")
    print("========================================")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(42)

    cfg = load_config(config_path)

    processor = WhisperProcessor.from_pretrained(cfg.model.student_model)
    train_ds, val_ds = load_dataset(
        root_dir=cfg.data.root_dir,
        processor=processor,
        sample_cap=cfg.data.sample_cap,
        val_ratio=cfg.data.val_ratio,
    )

    # Step 1: hyperparameter search. Only the collate function of the loader
    # is needed here; the return value of the search is not used by this
    # function.
    data_loader = _build_data_loader(cfg, processor)
    print("\n第一步: 开始超参数优化...")
    run_hyperparameter_search(train_ds, data_loader.safe_collate_fn)
    print("✓ 超参数优化完成,最佳参数已更新到配置文件")

    # Reload the configuration. NOTE(review): per the message above, the
    # search is expected to have written the best parameters back to the
    # config file -- confirm in scripts.hyperparameter_search.
    cfg = load_config(config_path)

    # Step 2: the actual training run, with a loader rebuilt from the
    # reloaded configuration.
    print("\n第二步: 开始正式训练...")
    print("使用优化后的超参数进行模型训练...")

    data_loader = _build_data_loader(cfg, processor)
    train_loader = data_loader.get_loader(train_ds, shuffle=True, drop_last=True)
    val_loader = data_loader.get_loader(val_ds, shuffle=False, drop_last=False)

    trainer = TAIDDistillationTrainer(cfg, processor, train_loader, val_loader, device)
    try:
        best_metric = trainer.train()
    finally:
        # Release trainer resources even when training fails. Collect
        # garbage first so memory freed by the collection can actually be
        # returned by empty_cache().
        trainer.cleanup()
        gc.collect()
        torch.cuda.empty_cache()

    out_dir = cfg.output.dir
    print("\n训练完成!")
    print("输出目录内容:", os.listdir(out_dir))
    print("适配器目录内容:", os.listdir(os.path.join(out_dir, "adapter")))

    # Read the training history with an explicit encoding and a context
    # manager so the file handle is closed deterministically.
    history_path = os.path.join(out_dir, "training_history.json")
    with open(history_path, encoding="utf-8") as fh:
        hist = json.load(fh)
    _print_lambda_progress(hist)

    print("\n========================================")
    print("✓ 完整训练流程成功完成!")
    print(" 1. 超参数优化 ✓")
    print(" 2. 模型训练 ✓")
    print("========================================")

    return best_metric
|
|
# Script entry point: run the whole pipeline with the default config path.
if __name__ == "__main__":
    run_training()
|
|