| """ |
| 损失函数实现 |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from peft import PeftModel |
|
|
|
|
def compute_loss(student, feats, labels, mask):
    """Compute the base loss: run the student forward pass and return its loss.

    If ``student`` is a ``PeftModel``, its ``.model`` attribute is called
    instead of the wrapper itself.

    Args:
        student: the student model, optionally wrapped in ``PeftModel``.
        feats: passed to the model as ``input_features``.
        labels: passed to the model as ``labels`` so it computes its own loss.
        mask: passed to the model as ``attention_mask``.

    Returns:
        The ``loss`` attribute of the model output.
    """
    model = student
    if isinstance(student, PeftModel):
        model = student.model

    forward_kwargs = dict(
        input_features=feats,
        attention_mask=mask,
        labels=labels,
        output_hidden_states=True,
    )
    outputs = model(**forward_kwargs)
    return outputs.loss
|
|
|
|
def compute_taid_lambda(step: int, config) -> float:
    """Return the TAID lambda for ``step`` (piecewise-linear interpolation).

    The schedule passes through three points taken from
    ``config.distillation.taid``:

    * ``start`` for ``step <= 0``,
    * ``mid`` at ``step == config.training.max_steps / 2``,
    * ``end`` for ``step >= config.training.max_steps``.

    Between those points the value is linearly interpolated.
    """
    taid = config.distillation.taid
    if step <= 0:
        return taid.start

    total_steps = config.training.max_steps
    if step >= total_steps:
        return taid.end

    half = total_steps / 2
    if step <= half:
        # First half: start -> mid.
        fraction = step / half
        return taid.start + (taid.mid - taid.start) * fraction

    # Second half: mid -> end.
    fraction = (step - half) / half
    return taid.mid + (taid.end - taid.mid) * fraction
|
|
|
|
def _taid_kl_loss(student_logits, teacher_logits, lambda_val, temperature):
    """Return the temperature-scaled TAID KL term for a pair of logit tensors.

    Both tensors are cropped to their common size along dims 1 and 2. The
    target distribution is a detached mixture
    ``(1 - lambda_val) * softmax(student) + lambda_val * softmax(teacher)``,
    both taken at ``temperature``. The KL is multiplied by ``temperature ** 2``.
    """
    # Crop to the shared length of dims 1 and 2 so the distributions align
    # elementwise (presumably sequence length and vocab size — confirm).
    seq_len = min(student_logits.size(1), teacher_logits.size(1))
    vocab = min(student_logits.size(2), teacher_logits.size(2))
    student_logits = student_logits[:, :seq_len, :vocab]
    teacher_logits = teacher_logits[:, :seq_len, :vocab]

    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    # The interpolated target is a constant for the optimiser, so build it
    # without autograd; gradients flow only through student_log_probs.
    with torch.no_grad():
        student_probs = F.softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        target = (1 - lambda_val) * student_probs + lambda_val * teacher_probs

    # NOTE(review): 'batchmean' divides only by dim 0, so this sums over every
    # position in dim 1, including positions whose labels may be padding
    # (-100). Confirm whether those positions should be masked.
    kl = F.kl_div(student_log_probs, target, reduction='batchmean')
    return kl * (temperature ** 2)


def compute_loss_with_taid(student, teacher, proj, feats, labels, mask, step, config, device):
    """Compute the student loss with optional TAID distillation terms.

    The total loss is the sum of up to three terms:

    * The student's own loss from its forward pass. When
      ``config.distillation.kl_weight > 0`` it is scaled by
      ``config.distillation.ce_weight``. That field is optional; if it is
      missing or None, 0.8 is used.
    * ``kl_weight * TAID KL`` between the student and teacher logits. The
      interpolation weight comes from ``compute_taid_lambda(step, config)``.
    * ``hidden_beta * MSE(proj(student encoder state), teacher encoder state)``.
      This term is active only when ``hidden_beta > 0`` and ``proj`` is not
      None.

    When any distillation term is active, the teacher forward pass runs
    exactly once, on CPU copies of the inputs, under ``torch.no_grad()``.
    The teacher outputs are then moved to ``device``.

    Args:
        student: student model, optionally wrapped in ``PeftModel``.
        teacher: teacher model; it receives CPU copies of the inputs.
        proj: module that maps the student encoder state into the teacher's
            hidden space, or None.
        feats, labels, mask: model inputs (``input_features``, ``labels``,
            ``attention_mask``).
        step: current training step, used by the TAID schedule.
        config: training configuration (reads ``config.distillation.*``).
        device: device the teacher outputs are moved to before comparison.

    Returns:
        The combined scalar loss tensor.
    """
    dist_cfg = config.distillation
    use_kl = dist_cfg.kl_weight > 0
    use_hidden = dist_cfg.hidden_beta > 0 and proj is not None

    core = student.model if isinstance(student, PeftModel) else student
    student_outputs = core(input_features=feats, attention_mask=mask,
                           labels=labels, output_hidden_states=True)

    total_loss = student_outputs.loss
    if use_kl:
        # The CE term is down-weighted when KL distillation is on. The 0.8
        # default reproduces the previously hard-coded value.
        ce_weight = getattr(dist_cfg, "ce_weight", None)
        if ce_weight is None:
            ce_weight = 0.8
        total_loss = ce_weight * total_loss

    # Run the teacher once and share its outputs between both terms.
    teacher_outputs = None
    if use_kl or use_hidden:
        teacher_kwargs = dict(
            input_features=feats.cpu(),
            attention_mask=mask.cpu(),
            labels=labels.cpu(),
        )
        if use_hidden:
            teacher_kwargs["output_hidden_states"] = True
        with torch.no_grad():
            teacher_outputs = teacher(**teacher_kwargs)

    # Accumulate out-of-place so student_outputs.loss is never mutated.
    if use_kl:
        teacher_logits = teacher_outputs.logits.to(device)
        lambda_val = compute_taid_lambda(step, config)
        kl_loss = _taid_kl_loss(student_outputs.logits, teacher_logits,
                                lambda_val, dist_cfg.temperature)
        total_loss = total_loss + dist_cfg.kl_weight * kl_loss

    if use_hidden:
        teacher_hidden = teacher_outputs.encoder_last_hidden_state.to(device)
        projected = proj(student_outputs.encoder_last_hidden_state)
        # Crop dim 1 to the shorter of the two before comparing.
        seq_len = min(projected.size(1), teacher_hidden.size(1))
        hidden_loss = F.mse_loss(projected[:, :seq_len, :], teacher_hidden[:, :seq_len, :])
        total_loss = total_loss + dist_cfg.hidden_beta * hidden_loss

    return total_loss