# Chordia / src/models/loss_functions.py
# Corolin's picture — first commit (rev 0a6452f)
"""
损失函数模块
Loss Functions for PAD Predictor
该模块包含了PAD预测器的各种损失函数,包括:
- 加权均方误差损失(WMSE)
- 置信度损失函数
- 组合损失函数
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Optional, Tuple
import logging
class WeightedMSELoss(nn.Module):
    """Weighted mean-squared-error loss for the PAD predictor.

    The 5-dimensional model output is split into three components —
    delta-PAD (dims 0-2), delta-pressure (dim 3) and confidence (dim 4) —
    and each component's MSE is scaled by its own weight before summing.
    """

    def __init__(self,
                 delta_pad_weight: float = 1.0,
                 delta_pressure_weight: float = 1.0,
                 confidence_weight: float = 0.5,
                 reduction: str = 'mean'):
        """Initialize the weighted MSE loss.

        Args:
            delta_pad_weight: Weight of the delta-PAD component loss.
            delta_pressure_weight: Weight of the delta-pressure component loss.
            confidence_weight: Weight of the confidence component loss.
            reduction: Aggregation mode ('mean', 'sum' or 'none').

        Raises:
            ValueError: If ``reduction`` is not a supported mode.
        """
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"不支持的reduction: {reduction}")
        self.delta_pad_weight = delta_pad_weight
        self.delta_pressure_weight = delta_pressure_weight
        self.confidence_weight = confidence_weight
        self.reduction = reduction
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _split_components(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split a (batch, 5) tensor into (delta_pad, delta_pressure, confidence) slices."""
        return tensor[:, :3], tensor[:, 3:4], tensor[:, 4:5]

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the weighted MSE loss.

        Args:
            predictions: Predicted values, shape (batch_size, 5).
            targets: Ground-truth values, shape (batch_size, 5).

        Returns:
            Weighted sum of the three component MSE losses.

        Raises:
            ValueError: If the shapes mismatch or the last dim is not 5.
        """
        if predictions.shape != targets.shape:
            raise ValueError(f"预测值和真实值形状不匹配: {predictions.shape} vs {targets.shape}")
        if predictions.size(1) != 5:
            raise ValueError(f"输出维度应该是5,但得到的是 {predictions.size(1)}")
        # Reuse the component computation so forward and the breakdown
        # can never drift apart (the original duplicated the slicing logic).
        return self.get_component_losses(predictions, targets)['weighted_total']

    def get_component_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the per-component MSE losses plus their weighted total.

        Args:
            predictions: Predicted values, shape (batch_size, 5).
            targets: Ground-truth values, shape (batch_size, 5).

        Returns:
            Dict with 'delta_pad_mse', 'delta_pressure_mse',
            'confidence_mse' and 'weighted_total'.
        """
        pred_pad, pred_pressure, pred_conf = self._split_components(predictions)
        tgt_pad, tgt_pressure, tgt_conf = self._split_components(targets)
        losses = {
            'delta_pad_mse': F.mse_loss(pred_pad, tgt_pad, reduction=self.reduction),
            'delta_pressure_mse': F.mse_loss(pred_pressure, tgt_pressure, reduction=self.reduction),
            'confidence_mse': F.mse_loss(pred_conf, tgt_conf, reduction=self.reduction),
        }
        losses['weighted_total'] = (self.delta_pad_weight * losses['delta_pad_mse'] +
                                    self.delta_pressure_weight * losses['delta_pressure_mse'] +
                                    self.confidence_weight * losses['confidence_mse'])
        return losses
class ConfidenceLoss(nn.Module):
    """Confidence-calibration loss.

    Combines a base MSE on the value components (delta-PAD + delta-pressure,
    dims 0-3) with a calibration term that ties the predicted confidence
    (dim 4) to the observed per-sample prediction error.
    """

    def __init__(self,
                 base_loss_weight: float = 1.0,
                 confidence_weight: float = 0.1,
                 temperature: float = 1.0,
                 reduction: str = 'mean'):
        """Initialize the confidence loss.

        Args:
            base_loss_weight: Weight of the base MSE term.
            confidence_weight: Weight of the confidence-calibration term.
            temperature: Temperature scaling the confidence sensitivity.
            reduction: Aggregation mode ('mean', 'sum' or 'none').
        """
        super().__init__()
        self.base_loss_weight = base_loss_weight
        self.confidence_weight = confidence_weight
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the combined base + confidence-calibration loss.

        Args:
            predictions: Predicted values, shape (batch_size, 5).
            targets: Ground-truth values, shape (batch_size, 5). The target
                confidence column (dim 4) is intentionally unused — the
                calibration target is the observed prediction error itself.

        Returns:
            base_loss_weight * base MSE + confidence_weight * calibration term.
        """
        pred_components = predictions[:, :4]   # delta-PAD (3 dims) + delta-pressure (1 dim)
        pred_confidence = predictions[:, 4:5]  # raw confidence logit
        target_components = targets[:, :4]

        # Base regression loss on the value components.
        base_loss = F.mse_loss(pred_components, target_components, reduction=self.reduction)

        # Per-sample squared error, always kept per-sample regardless of
        # `reduction`. (The original had an if/else whose two branches were
        # byte-identical — collapsed here.)
        sample_errors = torch.mean((pred_components - target_components) ** 2, dim=1, keepdim=True)

        # Map the raw confidence logit into (0, 1).
        confidence = torch.sigmoid(pred_confidence / self.temperature)

        # NLL-style calibration: for a given error, the penalty grows as the
        # predicted confidence drops; epsilon guards log(0).
        confidence_loss = -torch.log(confidence + 1e-8) * sample_errors
        if self.reduction == 'mean':
            confidence_loss = torch.mean(confidence_loss)
        elif self.reduction == 'sum':
            confidence_loss = torch.sum(confidence_loss)

        return self.base_loss_weight * base_loss + self.confidence_weight * confidence_loss
class AdaptiveWeightedLoss(nn.Module):
    """Adaptively weighted MSE loss.

    Component weights are stored as frozen parameters (so they follow the
    module across devices/checkpoints) and are adjusted between training
    steps via :meth:`update_weights`, never by the optimizer.
    """

    def __init__(self,
                 initial_weights: Dict[str, float] = None,
                 adaptation_rate: float = 0.01,
                 min_weight: float = 0.1,
                 max_weight: float = 2.0):
        """Initialize the adaptive weighted loss.

        Args:
            initial_weights: Initial weight per component; defaults to
                {'delta_pad': 1.0, 'delta_pressure': 1.0, 'confidence': 0.5}.
            adaptation_rate: Step size of each weight adjustment.
            min_weight: Lower clamp for any weight.
            max_weight: Upper clamp for any weight.
        """
        super().__init__()
        if initial_weights is None:
            initial_weights = {
                'delta_pad': 1.0,
                'delta_pressure': 1.0,
                'confidence': 0.5,
            }
        self.weights = nn.ParameterDict({
            key: nn.Parameter(torch.tensor(value, dtype=torch.float32))
            for key, value in initial_weights.items()
        })
        self.adaptation_rate = adaptation_rate
        self.min_weight = min_weight
        self.max_weight = max_weight
        # Freeze the weights: they are changed only through update_weights().
        for param in self.weights.parameters():
            param.requires_grad = False

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the adaptively weighted loss.

        Args:
            predictions: Predicted values, shape (batch_size, 5).
            targets: Ground-truth values, shape (batch_size, 5).

        Returns:
            Sum of the per-component mean MSEs scaled by the current weights.
        """
        pred_delta_pad = predictions[:, :3]
        pred_delta_pressure = predictions[:, 3:4]
        pred_confidence = predictions[:, 4:5]
        target_delta_pad = targets[:, :3]
        target_delta_pressure = targets[:, 3:4]
        target_confidence = targets[:, 4:5]

        mse_delta_pad = F.mse_loss(pred_delta_pad, target_delta_pad, reduction='mean')
        mse_delta_pressure = F.mse_loss(pred_delta_pressure, target_delta_pressure, reduction='mean')
        mse_confidence = F.mse_loss(pred_confidence, target_confidence, reduction='mean')

        return (self.weights['delta_pad'] * mse_delta_pad +
                self.weights['delta_pressure'] * mse_delta_pressure +
                self.weights['confidence'] * mse_confidence)

    def update_weights(self, component_losses: Dict[str, float]):
        """Nudge each weight according to its component's share of the total loss.

        Components whose loss exceeds the uniform share 1/N get a larger
        weight; all weights are clamped to [min_weight, max_weight].

        Args:
            component_losses: Mapping from component name to its loss value.
        """
        total_loss = sum(component_losses.values())
        # Guard: with no components or an all-zero total there is nothing to
        # rebalance, and dividing by total_loss would raise ZeroDivisionError.
        if not component_losses or total_loss <= 0:
            return
        uniform_share = 1 / len(component_losses)
        for component, loss in component_losses.items():
            if component not in self.weights:
                continue
            new_weight = self.weights[component].item() * (
                1 + self.adaptation_rate * (loss / total_loss - uniform_share))
            new_weight = max(self.min_weight, min(self.max_weight, new_weight))
            self.weights[component].data.fill_(new_weight)

    def get_current_weights(self) -> Dict[str, float]:
        """Return the current component weights as plain floats."""
        return {key: param.item() for key, param in self.weights.items()}
class FocalLoss(nn.Module):
    """Focal-loss variant for regression.

    Down-weights easy (low-error) elements so training focuses on the
    hard-to-predict samples.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'mean'):
        """Initialize the focal loss.

        Args:
            alpha: Global balancing factor.
            gamma: Focusing exponent; larger values suppress easy samples more.
            reduction: Aggregation mode ('mean', 'sum' or 'none').
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal-weighted squared error.

        Args:
            predictions: Predicted values.
            targets: Ground-truth values.

        Returns:
            Focal loss, aggregated according to ``reduction``.
        """
        residual = predictions - targets
        squared_error = residual ** 2                      # element-wise MSE
        # Modulating factor: ~0 when the error is small, -> alpha as it grows.
        modulator = self.alpha * torch.pow(1 - torch.exp(-residual.abs()), self.gamma)
        weighted = modulator * squared_error

        if self.reduction == 'sum':
            return torch.sum(weighted)
        if self.reduction == 'mean':
            return torch.mean(weighted)
        return weighted
class MultiTaskLoss(nn.Module):
    """Multi-task loss with fixed or uncertainty-based task weighting.

    Each of the first ``num_tasks`` output columns is one task with its own
    mean MSE. With uncertainty weighting enabled, per-task log-variances
    s_i = log(sigma_i^2) are learned and the combined loss is
    sum_i exp(-s_i) * L_i + s_i (Kendall et al., CVPR 2018 formulation,
    up to constant factors).
    """

    def __init__(self,
                 num_tasks: int = 3,
                 task_weights: Optional[list] = None,
                 use_uncertainty_weighting: bool = False,
                 log_variance_init: float = 0.0):
        """Initialize the multi-task loss.

        Args:
            num_tasks: Number of tasks (leading output columns).
            task_weights: Fixed per-task weights; defaults to all ones.
            use_uncertainty_weighting: Learn per-task log-variances instead
                of using the fixed weights.
            log_variance_init: Initial value of each log-variance.
        """
        super().__init__()
        self.num_tasks = num_tasks
        self.use_uncertainty_weighting = use_uncertainty_weighting
        self.task_weights = task_weights if task_weights is not None else [1.0] * num_tasks
        if use_uncertainty_weighting:
            # Learnable log-variances s_i = log(sigma_i^2), one per task.
            self.log_vars = nn.Parameter(torch.full((num_tasks,), float(log_variance_init)))

    def _per_task_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> list:
        """Mean MSE of each task column, as a list of 0-dim tensors."""
        return [
            F.mse_loss(predictions[:, i:i + 1], targets[:, i:i + 1], reduction='mean')
            for i in range(self.num_tasks)
        ]

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the combined multi-task loss.

        Args:
            predictions: Predicted values, shape (batch_size, output_dim).
            targets: Ground-truth values, shape (batch_size, output_dim).

        Returns:
            The weighted sum over all task losses.
        """
        task_losses = self._per_task_losses(predictions, targets)
        if self.use_uncertainty_weighting:
            # L = sum_i exp(-s_i) * L_i + s_i
            weighted = [
                torch.exp(-self.log_vars[i]) * task_losses[i] + self.log_vars[i]
                for i in range(self.num_tasks)
            ]
        else:
            weighted = [
                self.task_weights[i] * task_losses[i]
                for i in range(self.num_tasks)
            ]
        return torch.stack(weighted).sum()

    def get_task_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> list:
        """Return the unweighted per-task MSE values as Python floats."""
        return [loss.item() for loss in self._per_task_losses(predictions, targets)]

    def get_uncertainties(self) -> torch.Tensor:
        """Return the per-task uncertainty.

        With uncertainty weighting this is the standard deviation
        sigma_i = exp(s_i / 2). (Bug fix: the original returned exp(s_i),
        which is the variance, while the docstring promised the standard
        deviation.) Without uncertainty weighting, the fixed task weights
        are returned as a tensor.
        """
        if self.use_uncertainty_weighting:
            return torch.exp(0.5 * self.log_vars)
        return torch.tensor(self.task_weights)
def create_loss_function(loss_type: str, **kwargs) -> nn.Module:
    """Factory for the loss functions in this module.

    Args:
        loss_type: One of 'wmse', 'confidence', 'adaptive', 'focal',
            'multitask', 'mse', 'l1'.
        **kwargs: Forwarded to the selected loss constructor.

    Returns:
        The instantiated loss module.

    Raises:
        ValueError: If ``loss_type`` is not a supported key.
    """
    loss_functions = {
        'wmse': WeightedMSELoss,
        'confidence': ConfidenceLoss,
        'adaptive': AdaptiveWeightedLoss,
        'focal': FocalLoss,
        'multitask': MultiTaskLoss,  # bug fix: defined in this module but missing from the registry
        'mse': nn.MSELoss,
        'l1': nn.L1Loss,
    }
    if loss_type not in loss_functions:
        raise ValueError(f"不支持的损失函数类型: {loss_type}. 支持的类型: {list(loss_functions.keys())}")
    return loss_functions[loss_type](**kwargs)
if __name__ == "__main__":
    # Quick smoke test of every loss in this module on random data.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n_samples = 4
    predictions = torch.randn(n_samples, 5).to(device)
    targets = torch.randn(n_samples, 5).to(device)

    print("测试损失函数:")
    print(f"输入形状: {predictions.shape}")

    # Weighted MSE, including its per-component breakdown.
    wmse_criterion = WeightedMSELoss(
        delta_pad_weight=1.0,
        delta_pressure_weight=1.0,
        confidence_weight=0.5
    ).to(device)
    wmse_value = wmse_criterion(predictions, targets)
    breakdown = wmse_criterion.get_component_losses(predictions, targets)
    print(f"\n加权MSE损失: {wmse_value.item():.6f}")
    print("组件损失:")
    for name, tensor in breakdown.items():
        print(f" {name}: {tensor.item():.6f}")

    # Remaining losses share the same call pattern — drive them from a table.
    for label, criterion in (
        ("\n置信度损失", ConfidenceLoss()),
        ("\n自适应加权损失", AdaptiveWeightedLoss()),
        ("\nFocal Loss", FocalLoss()),
    ):
        value = criterion.to(device)(predictions, targets)
        print(f"{label}: {value.item():.6f}")

    print("\n损失函数测试完成!")