import copy
import itertools
import json
import os
import random
import sys
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from transformers import WhisperForConditionalGeneration, set_seed
from peft import LoraConfig, get_peft_model, TaskType

import optuna
from optuna.samplers import TPESampler

# Put the project root on sys.path so `src` resolves when run directly.
sys.path.append(str(Path(__file__).parent.parent))

from src.utils.config import load_config


def objective(trial, train_ds, safe_collate_fn, device):
    """Optuna objective function; defaults come from the YAML config."""
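    # One trial: sample a candidate config, run a few distillation steps on
    # real batches, and return the mean loss as the value Optuna minimizes.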
    config = load_config("../configs/default_config.yaml")

    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    lora_dropout = trial.suggest_float("lora_dropout", 0.05, 0.3)
    temperature = trial.suggest_float("temperature", 1.0, 5.0)
    kl_weight = trial.suggest_float("kl_weight", 0.0, 1.0)
    hidden_beta = trial.suggest_float("hidden_beta", 0.0, 3.0)
    grad_accum = trial.suggest_categorical("grad_accum", [1, 2, 4])
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    lora_r = trial.suggest_categorical("lora_r", [32, 64, 128])
    lora_alpha = trial.suggest_categorical("lora_alpha", [8, 16, 32])
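
    # Frozen fp16 teacher: supplies soft targets for the distillation loss
    # and never receives gradients.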
    teacher = WhisperForConditionalGeneration.from_pretrained(
        config.model.teacher_model, torch_dtype=torch.float16, use_cache=False
    ).to(device).eval()
    for p in teacher.parameters():
        p.requires_grad = False

    # Freeze the student backbone; the LoRA adapters added below stay trainable.
    base = WhisperForConditionalGeneration.from_pretrained(
        config.model.student_model, torch_dtype=torch.float16, use_cache=False
    )
    for p in itertools.chain(base.model.encoder.parameters(),
                             base.model.decoder.parameters()):
        p.requires_grad = False
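
    # LoRA adapters are the only trainable student weights; the optional
    # linear head projects student hidden states up to the teacher's width.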
    lcfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=config.lora.target_modules,
        bias="none",
    )
    student = get_peft_model(base, lcfg).to(device)
    proj = None
    if hidden_beta > 0:
        proj = nn.Linear(config.model.student_hidden_dim,
                         config.model.teacher_hidden_dim).to(device)

    opt = torch.optim.AdamW(
        list(student.parameters()) + ([] if proj is None else list(proj.parameters())),
        lr=lr, weight_decay=config.optimizer.weight_decay,
    )

    loader = DataLoader(train_ds,
                        batch_size=batch_size,
                        shuffle=True,
                        collate_fn=safe_collate_fn,
                        num_workers=0,
                        drop_last=True)
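
    # Cheap proxy objective: five optimizer steps per trial are enough for
    # TPE to rank configurations without paying for a full epoch each time.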
    scaler = GradScaler()
    total_loss, count = 0.0, 0
    for i, batch in enumerate(loader):
        # 5 optimizer steps regardless of the sampled grad_accum.
        if i >= 5 * grad_accum:
            break
        feats = batch["input_features"].half().to(device)
        labels = batch["labels"].to(device)
        mask = (feats.sum(1) != 0).long()

        with autocast():
            out = student.model(input_features=feats,
                                attention_mask=mask,
                                labels=labels,
                                output_hidden_states=True)
            with torch.no_grad():
                t_out = teacher(input_features=feats,
                                attention_mask=mask,
                                labels=labels,
                                output_hidden_states=proj is not None)
            # Distillation objective (standard formulation, assumed here):
            # CE blended with temperature-scaled KL to the teacher's logits,
            # so the sampled temperature/kl_weight actually shape the loss.
            kl = F.kl_div(
                F.log_softmax(out.logits / temperature, dim=-1),
                F.softmax(t_out.logits / temperature, dim=-1),
                reduction="batchmean",
            ) * temperature ** 2
            loss = (1.0 - kl_weight) * out.loss + kl_weight * kl
            if proj is not None:
                # Hidden-state matching on the last decoder layer.
                loss = loss + hidden_beta * F.mse_loss(
                    proj(out.decoder_hidden_states[-1]),
                    t_out.decoder_hidden_states[-1])

        scaler.scale(loss / grad_accum).backward()
        if (i + 1) % grad_accum == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
        total_loss += loss.item()
        count += 1

    # Release this trial's models before the next trial allocates its own.
    del teacher, student, base, proj
    torch.cuda.empty_cache()

    return total_loss / max(1, count)


def run_hyperparameter_search(train_dataset, collate_function):
    """Run the hyperparameter search and write the best values back to the config."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed every RNG in play so individual trials are reproducible.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    set_seed(SEED)
    warnings.filterwarnings("ignore")

    print(f"Using device: {device}")

    # Seeded TPE sampler: the search trajectory itself is reproducible too.
    study = optuna.create_study(direction="minimize", sampler=TPESampler(seed=SEED))
    study.optimize(
        lambda trial: objective(trial, train_dataset, collate_function, device),
        n_trials=50,
    )

    config = load_config("../configs/default_config.yaml")
    best_params = study.best_params

    # Where each searched parameter lives in the config tree; used for both
    # the in-memory config object and the YAML file below.
    param_map = {
        "lr": ("optimizer", "lr"),
        "lora_dropout": ("lora", "dropout"),
        "temperature": ("distillation", "temperature"),
        "kl_weight": ("distillation", "kl_weight"),
        "hidden_beta": ("distillation", "hidden_beta"),
        "grad_accum": ("training", "grad_accum"),
        "batch_size": ("training", "batch_size"),
        "lora_r": ("lora", "r"),
        "lora_alpha": ("lora", "alpha"),
    }

    updated_config = copy.deepcopy(config)
    for k, v in best_params.items():
        section, key = param_map[k]
        setattr(getattr(updated_config, section), key, v)

    os.makedirs(config.output.dir, exist_ok=True)

    # Keep a JSON record of the winning trial alongside the training outputs.
    with open(os.path.join(config.output.dir, "training_config.json"), "w") as f:
        json.dump(best_params, f, indent=2)

    # Persist the best values back into the YAML file so later runs pick them up.
    config_path = "../configs/default_config.yaml"
    with open(config_path, "r", encoding="utf-8") as f:
        yaml_data = yaml.safe_load(f)

    for k, v in best_params.items():
        section, key = param_map[k]
        yaml_data[section][key] = v

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(yaml_data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

    print(f"Best parameters written to {config_path}")

    print("=== Best hyperparameters from the 50-trial auto-search ===")
    for k, v in best_params.items():
        print(f"  {k}: {v}")

    return updated_config
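

# Usage sketch (illustrative only; the dataset builder below is a
# hypothetical stand-in for whatever this repo actually provides):
#
#   train_ds, collate = build_train_dataset()   # hypothetical helpers
#   best_cfg = run_hyperparameter_search(train_ds, collate)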