File size: 2,093 Bytes
5f2f308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
LoRA (Low-Rank Adaptation) implementation.
""" 

import torch
import torch.nn as nn
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from typing import Optional
import itertools


def build_lora_student_model(
    config,
    device: torch.device
) -> tuple[PeftModel, Optional[nn.Linear]]:
    """Build the LoRA student model for distillation.

    Loads the student Whisper checkpoint in fp16, freezes the base
    encoder/decoder, wraps it with LoRA adapters, and optionally creates a
    linear projection mapping student hidden states to the teacher's hidden
    size (used when hidden-state distillation is enabled).

    Args:
        config: Experiment config; reads ``config.model.*``, ``config.lora.*``
            and ``config.distillation.hidden_beta``.
        device: Device the student (and optional projection) is moved to.

    Returns:
        Tuple of (PEFT-wrapped student model, projection layer or ``None``).
    """
    base = WhisperForConditionalGeneration.from_pretrained(
        config.model.student_model,
        torch_dtype=torch.float16,
        use_cache=False
    )
    base.gradient_checkpointing_enable()
    # Fix: with the base fully frozen and gradient checkpointing on, the
    # checkpointed segments see inputs with requires_grad=False and backward
    # never reaches the LoRA adapters. Making the embedding outputs require
    # grad is the documented remedy (PEFT/transformers docs).
    base.enable_input_require_grads()

    # Freeze the whole base model; only the LoRA adapters added below train.
    for p in itertools.chain(base.model.encoder.parameters(),
                             base.model.decoder.parameters()):
        p.requires_grad = False

    # LoRA configuration
    lcfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora.r,
        lora_alpha=config.lora.alpha,
        lora_dropout=config.lora.dropout,
        target_modules=config.lora.target_modules,
        bias="none"
    )

    student = get_peft_model(base, lcfg).to(device)

    # Optional projection student_hidden -> teacher_hidden for the
    # hidden-state distillation loss.
    # NOTE(review): created in default fp32 while the student is fp16 —
    # confirm the training loop casts hidden states accordingly.
    proj: Optional[nn.Linear] = None
    if config.distillation.hidden_beta > 0:
        proj = nn.Linear(
            config.model.student_hidden_dim,
            config.model.teacher_hidden_dim
        ).to(device)

    # Report trainable vs. total parameters (projection counted in both).
    trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
    total = sum(p.numel() for p in student.parameters())
    if proj is not None:
        proj_numel = sum(p.numel() for p in proj.parameters())
        trainable += proj_numel
        total += proj_numel
    print(f"Trainable parameters: {trainable:,} ({trainable/total*100:.2f}%)")

    return student, proj


def build_teacher_model(
    config,
    device: torch.device
) -> WhisperForConditionalGeneration:
    """Build the frozen teacher model.

    Loads the teacher Whisper checkpoint in fp16 with the KV cache disabled,
    freezes every parameter, and returns it in eval mode. The model is NOT
    moved to ``device`` here (matches the original behavior; callers place it).
    """
    teacher = WhisperForConditionalGeneration.from_pretrained(
        config.model.teacher_model,
        torch_dtype=torch.float16,
        use_cache=False
    )

    # Freeze all parameters — the teacher only provides targets.
    teacher.requires_grad_(False)

    return teacher.eval()