import os
import json
import copy
import random
import warnings
import itertools
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from transformers import WhisperForConditionalGeneration, set_seed
from peft import LoraConfig, get_peft_model, TaskType

import yaml
import optuna
from optuna.samplers import TPESampler

# Make the project root importable
sys.path.append(str(Path(__file__).parent.parent))

from src.utils.config import load_config


def objective(trial, train_ds, safe_collate_fn, device):
    """Optuna 目标函数 - 使用 YAML 配置"""
    # 加载基础配置
    config = load_config("../configs/default_config.yaml")
    
    # Hyperparameters explored by Optuna
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    lora_dropout = trial.suggest_float("lora_dropout", 0.05, 0.3)
    temperature = trial.suggest_float("temperature", 1.0, 5.0)
    kl_weight = trial.suggest_float("kl_weight", 0.0, 1.0)
    hidden_beta = trial.suggest_float("hidden_beta", 0.0, 3.0)
    grad_accum = trial.suggest_categorical("grad_accum", [1, 2, 4])
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    lora_r = trial.suggest_categorical("lora_r", [32, 64, 128])
    lora_alpha = trial.suggest_categorical("lora_alpha", [8, 16, 32])

    # Load the frozen teacher and the student base model
    teacher = WhisperForConditionalGeneration.from_pretrained(
        config.model.teacher_model, torch_dtype=torch.float16, use_cache=False
    ).to(device).eval()
    for p in teacher.parameters():
        p.requires_grad = False

    base = WhisperForConditionalGeneration.from_pretrained(
        config.model.student_model, torch_dtype=torch.float16, use_cache=False
    )
    for p in itertools.chain(base.model.encoder.parameters(),
                             base.model.decoder.parameters()):
        p.requires_grad = False

    # LoRA configuration
    lcfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=config.lora.target_modules,
        bias="none"
    )
    student = get_peft_model(base, lcfg).to(device)
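    # Project the student's hidden states up to the teacher's width so the
    # hidden-state matching loss can compare them; skipped when hidden_beta == 0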
    proj = nn.Linear(config.model.student_hidden_dim, config.model.teacher_hidden_dim).to(device) \
           if hidden_beta > 0 else None

    opt = torch.optim.AdamW(
        list(student.parameters()) + ([] if proj is None else list(proj.parameters())),
        lr=lr, weight_decay=config.optimizer.weight_decay
    )

    loader = DataLoader(train_ds,
                        batch_size=batch_size,
                        shuffle=True,
                        collate_fn=safe_collate_fn,
                        num_workers=0,
                        drop_last=True)
    
    total_loss, count = 0.0, 0
    for i, batch in enumerate(loader):
        if i >= 5:  # a handful of mini-batches serves as a cheap proxy objective
            break
        feats = batch["input_features"].half().to(device)
        labels = batch["labels"].to(device)
        mask = (feats.sum(1) != 0).long()  # mask all-zero (padded) frames

        with autocast():
            out = student.model(input_features=feats,
                                attention_mask=mask,
                                labels=labels,
                                output_hidden_states=True)
            with torch.no_grad():
                t_out = teacher(input_features=feats,
                                attention_mask=mask,
                                labels=labels,
                                output_hidden_states=True)

            # Cross-entropy on the labels
            loss = out.loss
            # Softened-KL distillation term, scaled by T^2 as in standard KD
            if kl_weight > 0:
                kl = nn.functional.kl_div(
                    nn.functional.log_softmax(out.logits / temperature, dim=-1),
                    nn.functional.softmax(t_out.logits / temperature, dim=-1),
                    reduction="batchmean"
                ) * temperature ** 2
                loss = loss + kl_weight * kl
            # Hidden-state matching between projected student and teacher states
            if proj is not None:
                h_s = proj(out.decoder_hidden_states[-1])
                h_t = t_out.decoder_hidden_states[-1]
                loss = loss + hidden_beta * nn.functional.mse_loss(h_s, h_t)

        (loss / grad_accum).backward()
        if (i + 1) % grad_accum == 0:
            opt.step()
            opt.zero_grad()
        total_loss += loss.item()
        count += 1

    # Release model memory before the next trial
    del teacher, student, base
    if proj is not None:
        del proj
    torch.cuda.empty_cache()

    return total_loss / max(1, count)
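
# Note: for longer proxy runs, intermediate losses could be reported with
# trial.report(value, step) and hopeless trials cut early via
# trial.should_prune(); with only five mini-batches per trial, pruning is
# skipped here.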


def run_hyperparameter_search(train_dataset, collate_function):
    """运行超参数搜索"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Seed everything for reproducibility
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    set_seed(SEED)
    warnings.filterwarnings("ignore")
    
    print(f"Using device: {device}")

    # Create the Optuna study (TPE sampler, minimizing the proxy loss)
    study = optuna.create_study(direction="minimize", sampler=TPESampler(seed=SEED))
    study.optimize(
        lambda trial: objective(trial, train_dataset, collate_function, device),
        n_trials=50
    )

    # Load the base config and the best parameters found by the search
    config = load_config("../configs/default_config.yaml")
    best_params = study.best_params

    # Map each searched hyperparameter to its (section, key) in the config
    PARAM_MAP = {
        "lr": ("optimizer", "lr"),
        "lora_dropout": ("lora", "dropout"),
        "temperature": ("distillation", "temperature"),
        "kl_weight": ("distillation", "kl_weight"),
        "hidden_beta": ("distillation", "hidden_beta"),
        "grad_accum": ("training", "grad_accum"),
        "batch_size": ("training", "batch_size"),
        "lora_r": ("lora", "r"),
        "lora_alpha": ("lora", "alpha"),
    }

    # Update a deep copy of the base config with the best parameters
    updated_config = copy.deepcopy(config)
    for k, v in best_params.items():
        section, key = PARAM_MAP[k]
        setattr(getattr(updated_config, section), key, v)

    # Persist the results
    os.makedirs(config.output.dir, exist_ok=True)

    # Save the best parameters to JSON (for record keeping)
    with open(os.path.join(config.output.dir, "training_config.json"), "w") as f:
        json.dump(best_params, f, indent=2)

    # Write the best parameters back into default_config.yaml
    config_path = "../configs/default_config.yaml"
    
    # Read the original YAML file
    with open(config_path, 'r', encoding='utf-8') as f:
        yaml_data = yaml.safe_load(f)
    
    # Apply the same parameter mapping to the raw YAML data
    for k, v in best_params.items():
        section, key = PARAM_MAP[k]
        yaml_data[section][key] = v
    
    # Write the updated YAML back to disk
    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.dump(yaml_data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

    print(f"Best parameters written back to {config_path}")

    # Print the results
    print("=== Best hyperparameters from the 50-trial search ===")
    for k, v in best_params.items():
        print(f"  {k}: {v}")

    return updated_config
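

# Minimal usage sketch (hypothetical entry point). The dataset builder and
# collate function below are assumptions; substitute the project's real ones.
if __name__ == "__main__":
    # from src.data.dataset import build_train_dataset, safe_collate_fn  # hypothetical
    # best_cfg = run_hyperparameter_search(build_train_dataset(), safe_collate_fn)
    pass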