|
|
""" |
|
|
损失函数模块 |
|
|
Loss Functions for PAD Predictor |
|
|
|
|
|
该模块包含了PAD预测器的各种损失函数,包括: |
|
|
- 加权均方误差损失(WMSE) |
|
|
- 置信度损失函数 |
|
|
- 组合损失函数 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Dict, Any, Optional, Tuple |
|
|
import logging |
|
|
|
|
|
|
|
|
class WeightedMSELoss(nn.Module):
    """
    Weighted mean-squared-error loss.

    The 5-column output is split into three groups, each scored with its own
    MSE and weight:

    * ΔPAD       -- columns 0-2
    * ΔPressure  -- column 3
    * Confidence -- column 4
    """

    def __init__(self,
                 delta_pad_weight: float = 1.0,
                 delta_pressure_weight: float = 1.0,
                 confidence_weight: float = 0.5,
                 reduction: str = 'mean'):
        """
        Initialize the weighted MSE loss.

        Args:
            delta_pad_weight: Weight of the ΔPAD (columns 0-2) MSE term.
            delta_pressure_weight: Weight of the ΔPressure (column 3) MSE term.
            confidence_weight: Weight of the Confidence (column 4) MSE term.
            reduction: Reduction passed to ``F.mse_loss`` ('mean', 'sum', 'none').
        """
        super(WeightedMSELoss, self).__init__()

        self.delta_pad_weight = delta_pad_weight
        self.delta_pressure_weight = delta_pressure_weight
        self.confidence_weight = confidence_weight
        self.reduction = reduction

        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _validate(predictions: torch.Tensor, targets: torch.Tensor) -> None:
        """Raise ValueError unless both tensors share a shape with 5 columns in dim 1."""
        if predictions.shape != targets.shape:
            raise ValueError(f"预测值和真实值形状不匹配: {predictions.shape} vs {targets.shape}")

        # Without this, a 1-D input would fail inside ``size(1)`` with an IndexError.
        if predictions.dim() < 2:
            raise ValueError(f"预测值应为二维张量 (batch_size, 5),但得到的形状是 {tuple(predictions.shape)}")

        if predictions.size(1) != 5:
            raise ValueError(f"输出维度应该是5,但得到的是 {predictions.size(1)}")

    @staticmethod
    def _split(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the (ΔPAD, ΔPressure, Confidence) column slices of ``tensor``."""
        return tensor[:, :3], tensor[:, 3:4], tensor[:, 4:5]

    def _component_mse(self, predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the unweighted per-group MSE using the configured reduction."""
        pred_pad, pred_pressure, pred_conf = self._split(predictions)
        tgt_pad, tgt_pressure, tgt_conf = self._split(targets)

        return {
            'delta_pad_mse': F.mse_loss(pred_pad, tgt_pad, reduction=self.reduction),
            'delta_pressure_mse': F.mse_loss(pred_pressure, tgt_pressure, reduction=self.reduction),
            'confidence_mse': F.mse_loss(pred_conf, tgt_conf, reduction=self.reduction),
        }

    def _weighted_sum(self, losses: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Combine the per-group MSE values with the configured weights."""
        return (self.delta_pad_weight * losses['delta_pad_mse'] +
                self.delta_pressure_weight * losses['delta_pressure_mse'] +
                self.confidence_weight * losses['confidence_mse'])

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted MSE loss.

        Args:
            predictions: Predicted values, shape (batch_size, 5).
            targets: Ground-truth values, same shape as ``predictions``.

        Returns:
            ``w_pad * MSE(ΔPAD) + w_pressure * MSE(ΔPressure) + w_conf * MSE(Confidence)``.
            NOTE: with ``reduction='none'`` the per-group tensors have shapes
            (N, 3), (N, 1) and (N, 1) and are combined by broadcasting, giving (N, 3).

        Raises:
            ValueError: If the shapes differ, the input is not at least 2-D,
                or dim 1 is not of size 5.
        """
        self._validate(predictions, targets)
        return self._weighted_sum(self._component_mse(predictions, targets))

    def get_component_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Return each group's MSE plus the weighted total.

        Args:
            predictions: Predicted values, shape (batch_size, 5).
            targets: Ground-truth values, same shape as ``predictions``.

        Returns:
            Dict with keys 'delta_pad_mse', 'delta_pressure_mse',
            'confidence_mse' and 'weighted_total'.

        Raises:
            ValueError: Under the same conditions as ``forward``.
        """
        self._validate(predictions, targets)
        losses = self._component_mse(predictions, targets)
        losses['weighted_total'] = self._weighted_sum(losses)
        return losses
|
|
|
|
|
|
|
|
class ConfidenceLoss(nn.Module):
    """
    Confidence loss.

    Combines an MSE on the four regression columns (0-3) with a term that
    scales ``-log(sigmoid(confidence_logit / temperature))`` by each sample's
    mean squared error. The confidence logit is column 4 of ``predictions``.
    """

    def __init__(self,
                 base_loss_weight: float = 1.0,
                 confidence_weight: float = 0.1,
                 temperature: float = 1.0,
                 reduction: str = 'mean'):
        """
        Initialize the confidence loss.

        Args:
            base_loss_weight: Weight of the base MSE term.
            confidence_weight: Weight of the confidence term.
            temperature: Divisor applied to the confidence logit before the sigmoid.
            reduction: 'mean', 'sum' or 'none'.
        """
        super(ConfidenceLoss, self).__init__()

        self.base_loss_weight = base_loss_weight
        self.confidence_weight = confidence_weight
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the confidence loss.

        Args:
            predictions: Model outputs, shape (batch_size, 5). Columns 0-3 are
                regression values and column 4 is a confidence logit.
            targets: Ground truth; only columns 0-3 are read.

        Returns:
            ``base_loss_weight * MSE + confidence_weight * confidence_term``,
            each reduced according to ``self.reduction``.
        """
        pred_components = predictions[:, :4]
        pred_confidence = predictions[:, 4:5]
        target_components = targets[:, :4]

        base_loss = F.mse_loss(pred_components, target_components, reduction=self.reduction)

        # Per-sample mean squared error over the regression columns, shape (N, 1).
        # It is needed unreduced for every reduction mode, so it is computed once.
        sample_errors = torch.mean((pred_components - target_components) ** 2, dim=1, keepdim=True)

        # -log(sigmoid(x)) computed via logsigmoid (= softplus(-x)). This stays finite
        # and keeps a gradient for very negative logits, where sigmoid underflows to 0.
        neg_log_confidence = -F.logsigmoid(pred_confidence / self.temperature)

        # NOTE(review): sample_errors >= 0, so for a fixed error this term is always
        # minimized by pushing confidence toward 1. On its own it does not penalize
        # over-confidence on high-error samples. Confirm this matches the intended
        # calibration objective.
        confidence_loss = neg_log_confidence * sample_errors

        if self.reduction == 'mean':
            confidence_loss = torch.mean(confidence_loss)
        elif self.reduction == 'sum':
            confidence_loss = torch.sum(confidence_loss)

        total_loss = self.base_loss_weight * base_loss + self.confidence_weight * confidence_loss

        return total_loss
|
|
|
|
|
|
|
|
class AdaptiveWeightedLoss(nn.Module):
    """
    Adaptively weighted MSE loss.

    Holds one non-trainable weight per output group: ΔPAD (columns 0-2),
    ΔPressure (column 3) and Confidence (column 4). ``update_weights`` nudges
    those weights according to each group's share of the total loss.
    """

    def __init__(self,
                 initial_weights: Optional[Dict[str, float]] = None,
                 adaptation_rate: float = 0.01,
                 min_weight: float = 0.1,
                 max_weight: float = 2.0):
        """
        Initialize the adaptive weighted loss.

        Args:
            initial_weights: Initial weights keyed by 'delta_pad', 'delta_pressure'
                and 'confidence'. Missing keys fall back to the defaults
                (1.0, 1.0, 0.5); extra keys are kept.
            adaptation_rate: Step size used by ``update_weights``.
            min_weight: Lower clamp for updated weights.
            max_weight: Upper clamp for updated weights.
        """
        super(AdaptiveWeightedLoss, self).__init__()

        defaults = {
            'delta_pad': 1.0,
            'delta_pressure': 1.0,
            'confidence': 0.5
        }
        # Keep the caller's key order, and fill in any key that ``forward`` reads
        # but the caller omitted. A partial dict would otherwise raise KeyError later.
        weights = dict(initial_weights) if initial_weights is not None else {}
        for key, value in defaults.items():
            weights.setdefault(key, value)

        self.weights = nn.ParameterDict({
            key: nn.Parameter(torch.tensor(value, dtype=torch.float32))
            for key, value in weights.items()
        })

        self.adaptation_rate = adaptation_rate
        self.min_weight = min_weight
        self.max_weight = max_weight

        # The weights are adjusted manually in ``update_weights``, not by the optimizer.
        for param in self.weights.parameters():
            param.requires_grad = False

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the adaptively weighted loss.

        Args:
            predictions: Predicted values. Columns 0-2, 3 and 4 are read
                (presumably shape (batch_size, 5)).
            targets: Ground-truth values, sliced the same way.

        Returns:
            Weighted sum of the three per-group mean MSE values.
        """
        pred_delta_pad = predictions[:, :3]
        pred_delta_pressure = predictions[:, 3:4]
        pred_confidence = predictions[:, 4:5]

        target_delta_pad = targets[:, :3]
        target_delta_pressure = targets[:, 3:4]
        target_confidence = targets[:, 4:5]

        mse_delta_pad = F.mse_loss(pred_delta_pad, target_delta_pad, reduction='mean')
        mse_delta_pressure = F.mse_loss(pred_delta_pressure, target_delta_pressure, reduction='mean')
        mse_confidence = F.mse_loss(pred_confidence, target_confidence, reduction='mean')

        total_loss = (self.weights['delta_pad'] * mse_delta_pad +
                      self.weights['delta_pressure'] * mse_delta_pressure +
                      self.weights['confidence'] * mse_confidence)

        return total_loss

    def update_weights(self, component_losses: Dict[str, float]):
        """
        Update the weights from the latest component losses.

        Each known component's weight is multiplied by
        ``1 + adaptation_rate * (share - 1 / len(component_losses))``, where
        ``share`` is that component's fraction of the summed losses. The result
        is then clamped to ``[min_weight, max_weight]``. Unknown keys are ignored.

        Args:
            component_losses: Loss value per component (floats or 0-dim tensors).
        """
        total_loss = sum(float(loss) for loss in component_losses.values())
        if total_loss == 0:
            # Empty input or all-zero losses: shares are undefined (division by zero)
            # and there is nothing to rebalance, so keep the current weights.
            return

        equal_share = 1 / len(component_losses)

        for component, loss in component_losses.items():
            if component not in self.weights:
                continue

            share = float(loss) / total_loss
            new_weight = self.weights[component].item() * (1 + self.adaptation_rate * (share - equal_share))

            new_weight = max(self.min_weight, min(self.max_weight, new_weight))

            self.weights[component].data.fill_(new_weight)

    def get_current_weights(self) -> Dict[str, float]:
        """
        Return the current weights.

        Returns:
            Dict mapping each component name to its weight as a Python float.
        """
        return {key: param.item() for key, param in self.weights.items()}
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """
    Focal-style loss adapted to regression.

    Each element's squared error is scaled by
    ``alpha * (1 - exp(-|prediction - target|)) ** gamma``. Elements with a
    small absolute error are down-weighted, so large-error elements dominate.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'mean'):
        """
        Configure the focal regression loss.

        Args:
            alpha: Overall scale applied to the modulating factor.
            gamma: Exponent of the modulating factor; larger values focus harder
                on large-error elements.
            reduction: 'mean' or 'sum'; any other value returns the unreduced
                per-element loss.
        """
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the focal regression loss.

        Args:
            predictions: Predicted values.
            targets: Ground-truth values.

        Returns:
            Scalar mean or sum of the modulated squared errors, or the
            element-wise tensor for any other ``reduction``.
        """
        squared_error = F.mse_loss(predictions, targets, reduction='none')
        modulating = 1 - torch.exp(-torch.abs(predictions - targets))
        per_element = self.alpha * torch.pow(modulating, self.gamma) * squared_error

        if self.reduction == 'mean':
            return per_element.mean()
        if self.reduction == 'sum':
            return per_element.sum()
        return per_element
|
|
|
|
|
|
|
|
class MultiTaskLoss(nn.Module):
    """
    Multi-task MSE loss.

    Treats each of the first ``num_tasks`` output columns as a separate task.
    It computes a per-task MSE and combines the results either with fixed task
    weights or with learned per-task log-variance (uncertainty) weighting.
    """

    def __init__(self,
                 num_tasks: int = 3,
                 task_weights: Optional[list] = None,
                 use_uncertainty_weighting: bool = False,
                 log_variance_init: float = 0.0):
        """
        Initialize the multi-task loss.

        Args:
            num_tasks: Number of tasks; task ``i`` is output column ``i``.
            task_weights: Fixed per-task weights, used when uncertainty weighting
                is off. Defaults to ``[1.0] * num_tasks``.
            use_uncertainty_weighting: If True, learn one log-variance per task
                and weight each task loss by ``exp(-log_var)``.
            log_variance_init: Initial value of every task's log-variance.

        Raises:
            ValueError: If fixed weighting is used and ``task_weights`` has fewer
                than ``num_tasks`` entries.
        """
        super(MultiTaskLoss, self).__init__()

        self.num_tasks = num_tasks
        self.use_uncertainty_weighting = use_uncertainty_weighting

        if task_weights is None:
            task_weights = [1.0] * num_tasks
        elif not use_uncertainty_weighting and len(task_weights) < num_tasks:
            # forward() indexes one weight per task; fail here rather than with an
            # IndexError in the middle of training.
            raise ValueError(f"task_weights 长度 ({len(task_weights)}) 小于任务数量 ({num_tasks})")
        self.task_weights = task_weights

        if use_uncertainty_weighting:
            # One learnable log-variance per task.
            self.log_vars = nn.Parameter(torch.ones(num_tasks) * log_variance_init)

    def _check_inputs(self, predictions: torch.Tensor, targets: torch.Tensor) -> None:
        """Raise ValueError if either tensor has fewer than ``num_tasks`` columns in dim 1."""
        for name, tensor in (('predictions', predictions), ('targets', targets)):
            # Slicing past the last column yields an empty tensor whose mean MSE is
            # NaN. Reject such inputs up front instead of returning a silent NaN.
            if tensor.dim() < 2 or tensor.size(1) < self.num_tasks:
                raise ValueError(
                    f"{name} 的第二维应至少包含 {self.num_tasks} 列,但得到的形状是 {tuple(tensor.shape)}")

    def _task_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> list:
        """Return the mean MSE of each task column as a list of 0-dim tensors."""
        self._check_inputs(predictions, targets)
        return [
            F.mse_loss(predictions[:, i:i + 1], targets[:, i:i + 1], reduction='mean')
            for i in range(self.num_tasks)
        ]

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the multi-task loss.

        Args:
            predictions: Predicted values, shape (batch_size, output_dim) with
                ``output_dim >= num_tasks``.
            targets: Ground-truth values with at least ``num_tasks`` columns.

        Returns:
            Sum of the weighted per-task losses.

        Raises:
            ValueError: If either input has fewer than ``num_tasks`` columns.
        """
        task_losses = self._task_losses(predictions, targets)

        if self.use_uncertainty_weighting:
            # exp(-s) * L + s with s = log-variance. Tasks with a larger learned s are
            # down-weighted, and the +s term stops s from growing without bound.
            weighted_losses = [
                torch.exp(-self.log_vars[i]) * task_loss + self.log_vars[i]
                for i, task_loss in enumerate(task_losses)
            ]
        else:
            weighted_losses = [
                weight * task_loss
                for weight, task_loss in zip(self.task_weights, task_losses)
            ]

        return torch.stack(weighted_losses).sum()

    def get_task_losses(self, predictions: torch.Tensor, targets: torch.Tensor) -> list:
        """
        Return the unweighted loss of each task.

        Args:
            predictions: Predicted values.
            targets: Ground-truth values.

        Returns:
            List of ``num_tasks`` Python floats.

        Raises:
            ValueError: If either input has fewer than ``num_tasks`` columns.
        """
        return [task_loss.item() for task_loss in self._task_losses(predictions, targets)]

    def get_uncertainties(self) -> torch.Tensor:
        """
        Return the per-task uncertainty values.

        With uncertainty weighting, this returns ``exp(log_vars)``. Because
        ``log_vars`` holds log-variances, that is the per-task variance
        (take ``.sqrt()`` for the standard deviation). Otherwise it returns the
        fixed task weights as a tensor.

        Returns:
            1-D tensor of per-task values.
        """
        if self.use_uncertainty_weighting:
            return torch.exp(self.log_vars)
        else:
            return torch.tensor(self.task_weights)
|
|
|
|
|
|
|
|
def create_loss_function(loss_type: str, **kwargs) -> nn.Module:
    """
    Create a loss function by name.

    Args:
        loss_type: One of 'wmse', 'confidence', 'adaptive', 'focal', 'mse',
            'l1', 'multitask'.
        **kwargs: Forwarded unchanged to the selected constructor.

    Returns:
        The constructed loss module.

    Raises:
        ValueError: If ``loss_type`` is not a supported key.
    """
    loss_functions = {
        'wmse': WeightedMSELoss,
        'confidence': ConfidenceLoss,
        'adaptive': AdaptiveWeightedLoss,
        'focal': FocalLoss,
        # The constructors are callable directly; no lambda wrapper is needed.
        'mse': nn.MSELoss,
        'l1': nn.L1Loss,
        # MultiTaskLoss is defined in this module but was previously unreachable here.
        'multitask': MultiTaskLoss,
    }

    if loss_type not in loss_functions:
        raise ValueError(f"不支持的损失函数类型: {loss_type}. 支持的类型: {list(loss_functions.keys())}")

    return loss_functions[loss_type](**kwargs)
|
|
|
|
|
|
|
|
def _run_demo():
    """Smoke-test the loss functions on random (4, 5) tensors and print the results."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Random inputs are created on the CPU and then moved to the device.
    batch_size = 4
    predictions = torch.randn(batch_size, 5).to(device)
    targets = torch.randn(batch_size, 5).to(device)

    print("测试损失函数:")
    print(f"输入形状: {predictions.shape}")

    # Weighted MSE: the overall value plus the per-group breakdown.
    wmse_loss = WeightedMSELoss(
        delta_pad_weight=1.0,
        delta_pressure_weight=1.0,
        confidence_weight=0.5,
    ).to(device)
    wmse = wmse_loss(predictions, targets)
    component_losses = wmse_loss.get_component_losses(predictions, targets)

    print(f"\n加权MSE损失: {wmse.item():.6f}")
    print("组件损失:")
    for key, value in component_losses.items():
        print(f" {key}: {value.item():.6f}")

    conf = ConfidenceLoss().to(device)(predictions, targets)
    print(f"\n置信度损失: {conf.item():.6f}")

    adaptive = AdaptiveWeightedLoss().to(device)(predictions, targets)
    print(f"\n自适应加权损失: {adaptive.item():.6f}")

    focal = FocalLoss().to(device)(predictions, targets)
    print(f"\nFocal Loss: {focal.item():.6f}")

    print("\n损失函数测试完成!")


if __name__ == "__main__":
    _run_demo()