""" 损失函数模块 Loss Functions for PAD Predictor 该模块包含了PAD预测器的各种损失函数,包括: - 加权均方误差损失(WMSE) - 置信度损失函数 - 组合损失函数 """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Any, Optional, Tuple import logging class WeightedMSELoss(nn.Module): """ 加权均方误差损失函数 支持对不同输出组件(ΔPAD、ΔPressure、Confidence)设置不同的权重 """ def __init__(self, delta_pad_weight: float = 1.0, delta_pressure_weight: float = 1.0, confidence_weight: float = 0.5, reduction: str = 'mean'): """ 初始化加权MSE损失 Args: delta_pad_weight: ΔPAD损失的权重 delta_pressure_weight: ΔPressure损失的权重 confidence_weight: Confidence损失的权重 reduction: 损失聚合方式 ('mean', 'sum', 'none') """ super(WeightedMSELoss, self).__init__() self.delta_pad_weight = delta_pad_weight self.delta_pressure_weight = delta_pressure_weight self.confidence_weight = confidence_weight self.reduction = reduction self.logger = logging.getLogger(__name__) def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """ 计算加权MSE损失 Args: predictions: 预测值,形状为 (batch_size, 5) targets: 真实值,形状为 (batch_size, 5) Returns: 加权MSE损失 """ # 输入验证 if predictions.shape != targets.shape: raise ValueError(f"预测值和真实值形状不匹配: {predictions.shape} vs {targets.shape}") if predictions.size(1) != 5: raise ValueError(f"输出维度应该是5,但得到的是 {predictions.size(1)}") # 分解输出组件 pred_delta_pad = predictions[:, :3] # ΔPAD (3维) pred_delta_pressure = predictions[:, 3:4] # ΔPressure (1维) pred_confidence = predictions[:, 4:5] # Confidence (1维) target_delta_pad = targets[:, :3] target_delta_pressure = targets[:, 3:4] target_confidence = targets[:, 4:5] # 计算各组件的MSE损失 mse_delta_pad = F.mse_loss(pred_delta_pad, target_delta_pad, reduction=self.reduction) mse_delta_pressure = F.mse_loss(pred_delta_pressure, target_delta_pressure, reduction=self.reduction) mse_confidence = F.mse_loss(pred_confidence, target_confidence, reduction=self.reduction) # 加权求和 total_loss = (self.delta_pad_weight * mse_delta_pad + self.delta_pressure_weight * mse_delta_pressure + 
                      self.confidence_weight * mse_confidence)

        return total_loss

    def get_component_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Return the loss of each component.

        Args:
            predictions: Predicted values, shape (batch_size, 5)
            targets: Ground-truth values, shape (batch_size, 5)

        Returns:
            Dictionary with the per-component MSE losses and their weighted total
        """
        # Split the output into its components
        pred_delta_pad = predictions[:, :3]
        pred_delta_pressure = predictions[:, 3:4]
        pred_confidence = predictions[:, 4:5]

        target_delta_pad = targets[:, :3]
        target_delta_pressure = targets[:, 3:4]
        target_confidence = targets[:, 4:5]

        # MSE loss of each component
        losses = {
            'delta_pad_mse': F.mse_loss(pred_delta_pad, target_delta_pad, reduction=self.reduction),
            'delta_pressure_mse': F.mse_loss(pred_delta_pressure, target_delta_pressure, reduction=self.reduction),
            'confidence_mse': F.mse_loss(pred_confidence, target_confidence, reduction=self.reduction)
        }

        # Weighted total
        losses['weighted_total'] = (self.delta_pad_weight * losses['delta_pad_mse'] +
                                    self.delta_pressure_weight * losses['delta_pressure_mse'] +
                                    self.confidence_weight * losses['confidence_mse'])

        return losses


class ConfidenceLoss(nn.Module):
    """
    Confidence loss.

    Intended to calibrate the predicted confidence so that it reflects
    the actual prediction accuracy.
    """

    def __init__(self,
                 base_loss_weight: float = 1.0,
                 confidence_weight: float = 0.1,
                 temperature: float = 1.0,
                 reduction: str = 'mean'):
        """
        Initialize the confidence loss.

        Args:
            base_loss_weight: Weight of the base loss (MSE)
            confidence_weight: Weight of the confidence calibration loss
            temperature: Temperature that controls how sensitive the confidence is
            reduction: Loss reduction mode ('mean', 'sum', 'none')
        """
        super().__init__()

        self.base_loss_weight = base_loss_weight
        self.confidence_weight = confidence_weight
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the confidence loss.

        Args:
            predictions: Predicted values, shape (batch_size, 5)
            targets: Ground-truth values, shape (batch_size, 5)

        Returns:
            The confidence loss
        """
        # Separate the regression outputs from the confidence
        pred_components = predictions[:, :4]   # ΔPAD (3 dims) + ΔPressure (1 dim)
        pred_confidence = predictions[:, 4:5]  # Confidence (1 dim)

        target_components = targets[:, :4]

        # Base loss (MSE)
        base_loss = F.mse_loss(pred_components, target_components, reduction=self.reduction)
        # Per-sample prediction error, shape (batch_size, 1).
        # This is needed for every reduction mode, so it is computed once.
        sample_errors = torch.mean((pred_components - target_components) ** 2, dim=1, keepdim=True)

        # Map the confidence to the [0, 1] range
        confidence = torch.sigmoid(pred_confidence / self.temperature)

        # Confidence calibration loss: the goal is high confidence for low
        # error and low confidence for high error. It uses a negative
        # log-likelihood term weighted by the per-sample error.
        # NOTE: as written, this term is minimized by pushing the confidence
        # toward 1, more strongly on high-error samples; nothing here rewards
        # low confidence on high-error samples.
        confidence_loss = -torch.log(confidence + 1e-8) * sample_errors

        if self.reduction == 'mean':
            confidence_loss = torch.mean(confidence_loss)
        elif self.reduction == 'sum':
            confidence_loss = torch.sum(confidence_loss)

        # Combined loss. With reduction='none', base_loss is (batch_size, 4)
        # and confidence_loss is (batch_size, 1), so the sum broadcasts.
        total_loss = self.base_loss_weight * base_loss + self.confidence_weight * confidence_loss

        return total_loss


class AdaptiveWeightedLoss(nn.Module):
    """
    Adaptive weighted loss.

    Adjusts the weight of each component dynamically during training.
    """

    def __init__(self,
                 initial_weights: Optional[Dict[str, float]] = None,
                 adaptation_rate: float = 0.01,
                 min_weight: float = 0.1,
                 max_weight: float = 2.0):
        """
        Initialize the adaptive weighted loss.

        Args:
            initial_weights: Dictionary of initial weights
            adaptation_rate: Rate at which the weights are adjusted
            min_weight: Minimum weight value
            max_weight: Maximum weight value
        """
        super().__init__()

        if initial_weights is None:
            initial_weights = {
                'delta_pad': 1.0,
                'delta_pressure': 1.0,
                'confidence': 0.5
            }

        # The weights are stored as frozen parameters: they move with the
        # module across devices, but the optimizer does not update them.
        self.weights = nn.ParameterDict({
            key: nn.Parameter(torch.tensor(value, dtype=torch.float32), requires_grad=False)
            for key, value in initial_weights.items()
        })

        self.adaptation_rate = adaptation_rate
        self.min_weight = min_weight
        self.max_weight = max_weight

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the adaptive weighted loss.

        Args:
            predictions: Predicted values, shape (batch_size, 5)
            targets: Ground-truth values, shape (batch_size, 5)

        Returns:
            The adaptive weighted loss
        """
        # Split the output into its components
        pred_delta_pad = predictions[:, :3]
        pred_delta_pressure = predictions[:, 3:4]
        pred_confidence = predictions[:, 4:5]

        target_delta_pad = targets[:, :3]
        target_delta_pressure = targets[:, 3:4]
        target_confidence = targets[:, 4:5]

        # MSE loss of each component
        mse_delta_pad = F.mse_loss(pred_delta_pad, target_delta_pad, reduction='mean')
        mse_delta_pressure = F.mse_loss(pred_delta_pressure, target_delta_pressure, reduction='mean')
        mse_confidence = F.mse_loss(pred_confidence, target_confidence, reduction='mean')

        # Weighted sum
        total_loss = (self.weights['delta_pad'] * mse_delta_pad +
                      self.weights['delta_pressure'] * mse_delta_pressure +
                      self.weights['confidence'] * mse_confidence)

        return total_loss

    def update_weights(self, component_losses: Dict[str, float]):
        """
        Update the weights from the component losses.

        Args:
            component_losses: Loss value of each component. Keys must match
                the weight names (e.g. 'delta_pad'); other keys are not
                updated but still count toward the total.
        """
        # Total loss
        total_loss = sum(component_losses.values())
        if not component_losses or total_loss <= 0:
            # Nothing to rebalance; this also avoids a division by zero
            return

        # Update the weights: a component with a larger loss gets a larger weight
        for component, loss in component_losses.items():
            if component in self.weights:
                # New weight, scaled by how far this component's share of
                # the loss is from an equal share
                new_weight = self.weights[component].item() * (
                    1 + self.adaptation_rate * (loss / total_loss - 1 / len(component_losses)))

                # Clamp the weight to the allowed range
                new_weight = max(self.min_weight, min(self.max_weight, new_weight))

                # Store the new weight in place
                with torch.no_grad():
                    self.weights[component].fill_(new_weight)

    def get_current_weights(self) -> Dict[str, float]:
        """
        Return the current weights.

        Returns:
            Dictionary of current weights
        """
        return {key: param.item() for key, param in self.weights.items()}


class FocalLoss(nn.Module):
    """
    Focal Loss variant for regression.

    Focuses on samples that are hard to predict.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'mean'):
        """
        Initialize the Focal Loss.

        Args:
            alpha: Balancing factor
            gamma: Focusing parameter
            reduction: Loss reduction mode ('mean', 'sum', 'none')
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the Focal Loss.

        Args:
            predictions: Predicted values
            targets: Ground-truth values

        Returns:
            The Focal Loss
        """
        mse = F.mse_loss(predictions, targets, reduction='none')

        # Element-wise absolute error
        abs_error = torch.abs(predictions - targets)

        # Focal weight: close to 0 for small errors, approaches alpha for large ones
        focal_weight = self.alpha * torch.pow(1 - torch.exp(-abs_error), self.gamma)

        focal_loss = focal_weight * mse

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss


class MultiTaskLoss(nn.Module):
    """
    Multi-task loss.

    Used for joint training of several related tasks. Supports fixed task
    weights and task-uncertainty weighting.
    """

    def __init__(self,
                 num_tasks: int = 3,
                 task_weights: Optional[list] = None,
                 use_uncertainty_weighting: bool = False,
                 log_variance_init: float = 0.0):
        """
        Initialize the multi-task loss.

        Args:
            num_tasks: Number of tasks
            task_weights: Fixed weight of each task
            use_uncertainty_weighting: Whether to use task-uncertainty weighting
            log_variance_init: Initial value of each task's log-variance
        """
        super().__init__()

        self.num_tasks = num_tasks
        self.use_uncertainty_weighting = use_uncertainty_weighting

        if task_weights is None:
            task_weights = [1.0] * num_tasks
        if len(task_weights) != num_tasks:
            raise ValueError(
                f"Expected {num_tasks} task weights, got {len(task_weights)}")

        self.task_weights = task_weights

        if use_uncertainty_weighting:
            # Learnable per-task variance parameters (log-variance)
            self.log_vars = nn.Parameter(torch.ones(num_tasks) * log_variance_init)

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the multi-task loss.

        Args:
            predictions: Predicted values, shape (batch_size, output_dim)
            targets: Ground-truth values, shape (batch_size, output_dim)

        Returns:
            The multi-task loss
        """
        # Split into tasks: task i is output column i
        task_losses = []
        for i in range(self.num_tasks):
            task_pred = predictions[:, i:i+1]
            task_target = targets[:, i:i+1]
            task_loss = F.mse_loss(task_pred, task_target, reduction='mean')
            task_losses.append(task_loss)

        if self.use_uncertainty_weighting:
            # Task-uncertainty weighting with s_i = log(sigma_i^2):
            #   loss = sum_i exp(-s_i) * MSE_i + s_i
            # which is twice 1/(2*sigma_i^2) * MSE_i + log(sigma_i)
            weighted_losses = [
                torch.exp(-self.log_vars[i]) * task_losses[i] + self.log_vars[i]
                for i in range(self.num_tasks)
            ]
            total_loss = torch.stack(weighted_losses).sum()
        else:
            # Fixed weights
            weighted_losses = [
                self.task_weights[i] * task_losses[i]
                for i in range(self.num_tasks)
            ]
            total_loss = torch.stack(weighted_losses).sum()

        return total_loss

    def get_task_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> list:
        """
        Return the loss of each task.

        Args:
            predictions: Predicted values
            targets: Ground-truth values

        Returns:
            List of per-task losses as Python floats
        """
        task_losses = []
        for i in range(self.num_tasks):
            task_pred = predictions[:, i:i+1]
            task_target = targets[:, i:i+1]
            task_loss = F.mse_loss(task_pred, task_target, reduction='mean')
            task_losses.append(task_loss.item())

        return task_losses

    def get_uncertainties(self) -> torch.Tensor:
        """
        Return the task uncertainties (standard deviations).

        Returns:
            Standard deviation of each task, sigma = exp(log_var / 2), when
            uncertainty weighting is enabled; otherwise the fixed task weights
        """
        if self.use_uncertainty_weighting:
            # log_vars stores log(sigma^2), so halve it before exponentiating
            return torch.exp(0.5 * self.log_vars)
        else:
            return torch.tensor(self.task_weights)


def create_loss_function(loss_type: str, **kwargs) -> nn.Module:
    """
    Factory function that creates a loss function.

    Args:
        loss_type: Type of loss function
        **kwargs: Arguments passed to the loss function constructor

    Returns:
        Loss function instance
    """
    loss_functions = {
        'wmse': WeightedMSELoss,
        'confidence': ConfidenceLoss,
        'adaptive': AdaptiveWeightedLoss,
        'focal': FocalLoss,
        'mse': nn.MSELoss,
        'l1': nn.L1Loss
    }

    if loss_type not in loss_functions:
        raise ValueError(
            f"Unsupported loss function type: {loss_type}. "
            f"Supported types: {list(loss_functions.keys())}")

    return loss_functions[loss_type](**kwargs)


if __name__ == "__main__":
    # Test code
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create test data
    batch_size = 4
    predictions = torch.randn(batch_size, 5, device=device)
    targets = torch.randn(batch_size, 5, device=device)

    print("Testing loss functions:")
    print(f"Input shape: {predictions.shape}")

    # Test the weighted MSE loss
    wmse_loss = WeightedMSELoss(
        delta_pad_weight=1.0,
        delta_pressure_weight=1.0,
        confidence_weight=0.5
    ).to(device)

    wmse = wmse_loss(predictions, targets)
    component_losses = wmse_loss.get_component_losses(predictions, targets)

    print(f"\nWeighted MSE loss: {wmse.item():.6f}")
    print("Component losses:")
    for key, value in component_losses.items():
        print(f"  {key}: {value.item():.6f}")

    # Test the confidence loss
    conf_loss = ConfidenceLoss().to(device)
    conf = conf_loss(predictions, targets)
    print(f"\nConfidence loss: {conf.item():.6f}")

    # Test the adaptive weighted loss
    adaptive_loss = AdaptiveWeightedLoss().to(device)
    adaptive = adaptive_loss(predictions, targets)
    print(f"\nAdaptive weighted loss: {adaptive.item():.6f}")

    # Test the Focal Loss
    focal_loss = FocalLoss().to(device)
    focal = focal_loss(predictions, targets)
    print(f"\nFocal Loss: {focal.item():.6f}")

    print("\nLoss function tests complete!")