File size: 4,189 Bytes
5f2f308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
使用 TAID (Temperature-Aware Interpolation Distillation) 进行 LoRA 知识蒸馏训练。
包含超参数调优和完整训练流程。
"""

import os
import gc
import json
import torch
from transformers import WhisperProcessor, set_seed

from src.data.dataset import load_dataset
from src.data.dataloader import ASRDataLoader
from src.utils.config import load_config
from src.trainers.taid_trainer import TAIDDistillationTrainer

# 导入超参数搜索功能
from scripts.hyperparameter_search import run_hyperparameter_search

def run_training(config_path: str = "../configs/default_config.yaml") -> float:
    """运行完整训练流程:超参数优化 + 训练
    
    Args:
        config_path: 配置文件路径
        
    Returns:
        float: 验证集上的最佳性能指标
    """
    print("========================================")
    print("开始完整训练流程")
    print("========================================")
    
    # 1. 设置环境
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(42)
    
    # 2. 加载配置
    cfg = load_config(config_path)
    
    # 3. 加载数据集和处理器
    processor = WhisperProcessor.from_pretrained(cfg.model.student_model)
    train_ds, val_ds = load_dataset(
        root_dir=cfg.data.root_dir,
        processor=processor,
        sample_cap=cfg.data.sample_cap,
        val_ratio=cfg.data.val_ratio
    )
    
    # 4. 创建数据加载器
    data_loader = ASRDataLoader(
        processor=processor,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.data.num_workers,
        max_frames=cfg.data.max_frames,
        sample_rate=cfg.data.sample_rate,
        pin_memory=cfg.data.pin_memory,
        persistent_workers=cfg.data.persistent_workers,
        prefetch_factor=cfg.data.prefetch_factor
    )
    
    # 5. 第一步:超参数优化
    print("\n第一步: 开始超参数优化...")
    best_params = run_hyperparameter_search(train_ds, data_loader.safe_collate_fn)
    print("✓ 超参数优化完成,最佳参数已更新到配置文件")
    
    # 6. 重新加载更新后的配置
    cfg = load_config(config_path)
    
    # 7. 第二步:使用最佳参数进行训练
    print("\n第二步: 开始正式训练...")
    print("使用优化后的超参数进行模型训练...")
    
    # 重新创建数据加载器(使用新的batch_size)
    data_loader = ASRDataLoader(
        processor=processor,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.data.num_workers,
        max_frames=cfg.data.max_frames,
        sample_rate=cfg.data.sample_rate,
        pin_memory=cfg.data.pin_memory,
        persistent_workers=cfg.data.persistent_workers,
        prefetch_factor=cfg.data.prefetch_factor
    )
    train_loader = data_loader.get_loader(train_ds, shuffle=True, drop_last=True)
    val_loader = data_loader.get_loader(val_ds, shuffle=False, drop_last=False)
    
    # 8. 创建训练器并开始训练
    trainer = TAIDDistillationTrainer(cfg, processor, train_loader, val_loader, device)
    best_metric = trainer.train()
    
    # 9. 清理资源
    trainer.cleanup()
    torch.cuda.empty_cache()
    gc.collect()
    
    # 10. 显示训练结果
    out_dir = cfg.output.dir
    print("\n训练完成!")
    print("输出目录内容:", os.listdir(out_dir))
    print("适配器目录内容:", os.listdir(os.path.join(out_dir, "adapter")))
    
    # 11. 显示 TAID lambda 进展
    hist = json.load(open(os.path.join(out_dir, "training_history.json")))
    if "taid_lambda" in hist:
        print("\nTAID Lambda 进展:")
        for i in range(0, len(hist["taid_lambda"]), max(1, len(hist["taid_lambda"])//5)):
            step = hist["steps"][i]
            lambda_val = hist["taid_lambda"][i]
            print(f"  Step {step}: λ = {lambda_val:.3f}")
    
    print("\n========================================")
    print("✓ 完整训练流程成功完成!")
    print("  1. 超参数优化 ✓")
    print("  2. 模型训练 ✓")
    print("========================================")
    
    return best_metric

if __name__ == "__main__":
    run_training()