"""
LoRA (Low-Rank Adaptation) implementation.
"""
import torch
import torch.nn as nn
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from typing import Optional
import itertools


def build_lora_student_model(
    config,
    device: torch.device
) -> tuple[PeftModel, Optional[nn.Linear]]:
    """Build the LoRA student model and an optional hidden-state projection."""
    base = WhisperForConditionalGeneration.from_pretrained(
        config.model.student_model,
        torch_dtype=torch.float16,
        use_cache=False
    )
    base.gradient_checkpointing_enable()
    # Freeze the full base model; only the LoRA adapters added below are trained.
    for p in itertools.chain(base.model.encoder.parameters(),
                             base.model.decoder.parameters()):
        p.requires_grad = False
    # LoRA configuration
    lcfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora.r,
        lora_alpha=config.lora.alpha,
        lora_dropout=config.lora.dropout,
        target_modules=config.lora.target_modules,
        bias="none"
    )
    student = get_peft_model(base, lcfg).to(device)
    # Optional projection that maps student hidden states to the teacher's width;
    # only needed when hidden-state distillation is enabled.
    proj = nn.Linear(
        config.model.student_hidden_dim,
        config.model.teacher_hidden_dim
    ).to(device) if config.distillation.hidden_beta > 0 else None
    trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
    if proj is not None:
        trainable += sum(p.numel() for p in proj.parameters())
    total = sum(p.numel() for p in student.parameters()) + \
        (sum(p.numel() for p in proj.parameters()) if proj is not None else 0)
    print(f"Trainable parameters: {trainable:,} ({trainable / total * 100:.2f}%)")
    return student, proj


def build_teacher_model(
    config,
    device: torch.device
) -> WhisperForConditionalGeneration:
    """Build the frozen teacher model."""
    teacher = WhisperForConditionalGeneration.from_pretrained(
        config.model.teacher_model,
        torch_dtype=torch.float16,
        use_cache=False
    ).to(device).eval()
    # Freeze all parameters; the teacher is used for inference only.
    for p in teacher.parameters():
        p.requires_grad = False
    return teacher
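

# ---------------------------------------------------------------------------
# Usage sketch (minimal, with placeholder values): shows how the two builders
# above could be wired together. The nested `config` namespace mirrors the
# attributes read in this module (config.model.*, config.lora.*,
# config.distillation.hidden_beta); the checkpoint names, hidden dims, and
# LoRA hyperparameters below are placeholder assumptions, not values taken
# from this repo's actual configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        model=SimpleNamespace(
            student_model="openai/whisper-small",     # placeholder student checkpoint
            teacher_model="openai/whisper-large-v3",  # placeholder teacher checkpoint
            student_hidden_dim=768,                   # d_model of whisper-small
            teacher_hidden_dim=1280,                  # d_model of whisper-large-v3
        ),
        lora=SimpleNamespace(
            r=16,
            alpha=32,
            dropout=0.05,
            target_modules=["q_proj", "v_proj"],
        ),
        distillation=SimpleNamespace(hidden_beta=1.0),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    student, proj = build_lora_student_model(config, device)
    teacher = build_teacher_model(config, device)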