"""
损失函数实现
"""
import torch
import torch.nn.functional as F
from peft import PeftModel


def compute_loss(student, feats, labels, mask):
    """Compute the plain cross-entropy loss for the student model."""
    # PEFT wraps the base model; unwrap it before the forward pass.
    core = student.model if isinstance(student, PeftModel) else student
    out = core(
        input_features=feats,
        attention_mask=mask,
        labels=labels,
        output_hidden_states=True
    )
    return out.loss
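
# Minimal usage sketch (hypothetical: assumes a Whisper-style student and
# dummy inputs; the 80-mel x 3000-frame feature shape matches Whisper's
# feature extractor, and the label ids are random placeholders):
#
#   from transformers import WhisperForConditionalGeneration
#   model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
#   feats = torch.randn(1, 80, 3000)
#   labels = torch.randint(0, model.config.vocab_size, (1, 8))
#   loss = compute_loss(model, feats, labels, mask=None)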


def compute_taid_lambda(step: int, config) -> float:
    """Compute the TAID interpolation weight lambda for the current step.

    Piecewise-linear schedule: start -> mid over the first half of training,
    then mid -> end over the second half.
    """
    if step <= 0:
        return config.distillation.taid.start
    if step >= config.training.max_steps:
        return config.distillation.taid.end
    half = config.training.max_steps / 2
    if step <= half:
        return config.distillation.taid.start + \
            (config.distillation.taid.mid - config.distillation.taid.start) * (step / half)
    return config.distillation.taid.mid + \
        (config.distillation.taid.end - config.distillation.taid.mid) * ((step - half) / half)
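
# Schedule sanity check (hypothetical config values, shown doctest-style):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       training=SimpleNamespace(max_steps=1000),
#       distillation=SimpleNamespace(
#           taid=SimpleNamespace(start=0.0, mid=0.5, end=1.0)))
#   compute_taid_lambda(0, cfg)     # -> 0.0  (start)
#   compute_taid_lambda(250, cfg)   # -> 0.25 (halfway from start to mid)
#   compute_taid_lambda(500, cfg)   # -> 0.5  (mid)
#   compute_taid_lambda(1000, cfg)  # -> 1.0  (end)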


def compute_loss_with_taid(student, teacher, proj, feats, labels, mask, step, config, device):
    """Compute the training loss with TAID distillation.

    Combines the student's cross-entropy loss with a temperature-scaled KL
    term against a TAID-interpolated teacher distribution, plus an optional
    MSE alignment loss between projected student and teacher encoder states.
    """
    core = student.model if isinstance(student, PeftModel) else student
    student_outputs = core(input_features=feats, attention_mask=mask,
                           labels=labels, output_hidden_states=True)
    # When distillation is active, the cross-entropy term is down-weighted by
    # a hard-coded factor of 0.8 before the weighted KL term is added.
    total_loss = 0.8 * student_outputs.loss if config.distillation.kl_weight > 0 else student_outputs.loss
    if config.distillation.kl_weight > 0:
        student_logits = student_outputs.logits
        # The teacher lives on CPU; move inputs there and bring logits back.
        with torch.no_grad():
            teacher_outputs = teacher(input_features=feats.cpu(), attention_mask=mask.cpu(), labels=labels.cpu())
            teacher_logits = teacher_outputs.logits.to(device)
        # Align sequence lengths and vocabulary sizes before comparing.
        min_len = min(student_logits.size(1), teacher_logits.size(1))
        min_vocab = min(student_logits.size(2), teacher_logits.size(2))
        student_logits = student_logits[:, :min_len, :min_vocab]
        teacher_logits = teacher_logits[:, :min_len, :min_vocab]
        # Temperature-scaled probability distributions.
        student_probs = F.softmax(student_logits / config.distillation.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / config.distillation.temperature, dim=-1)
        # TAID: interpolate between student and teacher distributions; lambda
        # grows over training, drifting the target from student toward teacher.
        lambda_val = compute_taid_lambda(step, config)
        interp = (1 - lambda_val) * student_probs + lambda_val * teacher_probs
        # KL divergence against the detached interpolated target, rescaled by
        # T^2 to keep gradient magnitudes comparable across temperatures.
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / config.distillation.temperature, dim=-1),
            interp.detach(),
            reduction='batchmean'
        ) * (config.distillation.temperature ** 2)
        total_loss += config.distillation.kl_weight * kl_loss
    # Hidden-state alignment: project the student's encoder output into the
    # teacher's hidden dimension and penalize the MSE between them.
    if config.distillation.hidden_beta > 0 and proj is not None:
        student_hidden = student_outputs.encoder_last_hidden_state
        # Second teacher forward pass to expose the encoder hidden states.
        with torch.no_grad():
            teacher_full = teacher(input_features=feats.cpu(), attention_mask=mask.cpu(), labels=labels.cpu(),
                                   output_hidden_states=True)
            teacher_hidden = teacher_full.encoder_last_hidden_state.to(device)
        projected = proj(student_hidden)
        # Truncate to the shorter encoder sequence before the MSE.
        L = min(projected.size(1), teacher_hidden.size(1))
        total_loss += config.distillation.hidden_beta * F.mse_loss(projected[:, :L, :], teacher_hidden[:, :L, :])
    return total_loss
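
# End-to-end sketch of one distillation step (hypothetical wiring: student,
# teacher, loader, config, and the hidden-size arguments to the projection
# layer are placeholders, not values taken from this repo):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   proj = torch.nn.Linear(student_hidden_size, teacher_hidden_size).to(device)
#   optimizer = torch.optim.AdamW(
#       list(student.parameters()) + list(proj.parameters()), lr=1e-4)
#   for step, (feats, labels, mask) in enumerate(loader):
#       loss = compute_loss_with_taid(
#           student, teacher, proj,
#           feats.to(device), labels.to(device), mask.to(device),
#           step, config, device)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()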