# ASR_latent_space/scripts/train_taid.py
# Author: Yiru Yang
# fresh LFS version for Hugging Face push (commit 5f2f308)
"""
使用 TAID (Temperature-Aware Interpolation Distillation) 进行 LoRA 知识蒸馏训练。
包含超参数调优和完整训练流程。
"""
import os
import gc
import json
import torch
from transformers import WhisperProcessor, set_seed
from src.data.dataset import load_dataset
from src.data.dataloader import ASRDataLoader
from src.utils.config import load_config
from src.trainers.taid_trainer import TAIDDistillationTrainer
# Import the hyperparameter search functionality
from scripts.hyperparameter_search import run_hyperparameter_search
def run_training(config_path: str = "../configs/default_config.yaml") -> float:
"""运行完整训练流程:超参数优化 + 训练
Args:
config_path: 配置文件路径
Returns:
float: 验证集上的最佳性能指标
"""
print("========================================")
print("开始完整训练流程")
print("========================================")
# 1. 设置环境
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(42)
# 2. 加载配置
cfg = load_config(config_path)
# 3. 加载数据集和处理器
processor = WhisperProcessor.from_pretrained(cfg.model.student_model)
train_ds, val_ds = load_dataset(
root_dir=cfg.data.root_dir,
processor=processor,
sample_cap=cfg.data.sample_cap,
val_ratio=cfg.data.val_ratio
)
# 4. 创建数据加载器
data_loader = ASRDataLoader(
processor=processor,
batch_size=cfg.training.batch_size,
num_workers=cfg.data.num_workers,
max_frames=cfg.data.max_frames,
sample_rate=cfg.data.sample_rate,
pin_memory=cfg.data.pin_memory,
persistent_workers=cfg.data.persistent_workers,
prefetch_factor=cfg.data.prefetch_factor
)
# 5. 第一步:超参数优化
print("\n第一步: 开始超参数优化...")
best_params = run_hyperparameter_search(train_ds, data_loader.safe_collate_fn)
print("✓ 超参数优化完成,最佳参数已更新到配置文件")
# 6. 重新加载更新后的配置
cfg = load_config(config_path)
# 7. 第二步:使用最佳参数进行训练
print("\n第二步: 开始正式训练...")
print("使用优化后的超参数进行模型训练...")
# 重新创建数据加载器(使用新的batch_size)
data_loader = ASRDataLoader(
processor=processor,
batch_size=cfg.training.batch_size,
num_workers=cfg.data.num_workers,
max_frames=cfg.data.max_frames,
sample_rate=cfg.data.sample_rate,
pin_memory=cfg.data.pin_memory,
persistent_workers=cfg.data.persistent_workers,
prefetch_factor=cfg.data.prefetch_factor
)
train_loader = data_loader.get_loader(train_ds, shuffle=True, drop_last=True)
val_loader = data_loader.get_loader(val_ds, shuffle=False, drop_last=False)
# 8. 创建训练器并开始训练
trainer = TAIDDistillationTrainer(cfg, processor, train_loader, val_loader, device)
best_metric = trainer.train()
# 9. 清理资源
trainer.cleanup()
torch.cuda.empty_cache()
gc.collect()
# 10. 显示训练结果
out_dir = cfg.output.dir
print("\n训练完成!")
print("输出目录内容:", os.listdir(out_dir))
print("适配器目录内容:", os.listdir(os.path.join(out_dir, "adapter")))
# 11. 显示 TAID lambda 进展
hist = json.load(open(os.path.join(out_dir, "training_history.json")))
if "taid_lambda" in hist:
print("\nTAID Lambda 进展:")
for i in range(0, len(hist["taid_lambda"]), max(1, len(hist["taid_lambda"])//5)):
step = hist["steps"][i]
lambda_val = hist["taid_lambda"][i]
print(f" Step {step}: λ = {lambda_val:.3f}")
print("\n========================================")
print("✓ 完整训练流程成功完成!")
print(" 1. 超参数优化 ✓")
print(" 2. 模型训练 ✓")
print("========================================")
return best_metric
if __name__ == "__main__":
run_training()