|
|
""" |
|
|
MC Dropout Confidence Calculator

Monte Carlo Dropout Confidence Calculator

Quantifies prediction uncertainty by running multiple forward passes
with Dropout enabled.
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from typing import Dict, List, Tuple, Optional, Any |
|
|
import logging |
|
|
|
|
|
|
|
|
class MCDropoutConfidence:
    """Monte Carlo Dropout confidence calculator.

    Quantifies the uncertainty of a model's predictions with MC Dropout:

    - the same input is pushed through the model several times with
      Dropout active;
    - predictions that agree across passes -> high confidence;
    - predictions that differ widely across passes -> low confidence.

    Only Dropout *modules* are switched to training mode while sampling;
    every other layer (e.g. BatchNorm) stays in eval mode so that sampling
    neither uses batch statistics nor mutates running statistics.
    NOTE: dropout applied functionally inside a module's ``forward``
    (keyed on that module's own ``training`` flag) is therefore not
    activated by this class.

    Reference:
        Gal and Ghahramani (2016) "Dropout as a Bayesian Approximation"
    """

    # Dropout layer classes recognised by this calculator. Names missing
    # from the installed torch version (e.g. ``Dropout1d`` on older
    # releases) are skipped instead of raising AttributeError.
    _DROPOUT_TYPES = tuple(
        getattr(nn, name)
        for name in (
            "Dropout",
            "Dropout1d",
            "Dropout2d",
            "Dropout3d",
            "AlphaDropout",
            "FeatureAlphaDropout",
        )
        if hasattr(nn, name)
    )

    def __init__(self, model: nn.Module, n_samples: int = 30):
        """Initialise the MC Dropout confidence calculator.

        Args:
            model: PyTorch model, expected to contain Dropout layers.
            n_samples: Number of Monte Carlo forward passes (default 30).

        Raises:
            ValueError: If ``n_samples`` is smaller than 1.
        """
        # Fail early: zero samples would otherwise only surface later as an
        # opaque error when stacking an empty list of predictions.
        if n_samples < 1:
            raise ValueError(
                f"n_samples must be a positive integer, got {n_samples}"
            )

        self.model = model
        self.n_samples = n_samples
        self.logger = logging.getLogger(__name__)

        # Detected once at construction time; layers added to the model
        # afterwards are not picked up.
        self.has_dropout = self._check_dropout_layers()

        if not self.has_dropout:
            self.logger.warning("模型未发现Dropout层,MC Dropout将退化为确定性预测")

    def _check_dropout_layers(self) -> bool:
        """Return True if the model contains at least one Dropout module."""
        return any(
            isinstance(module, self._DROPOUT_TYPES)
            for module in self.model.modules()
        )

    @staticmethod
    def _sample_std(
        values: torch.Tensor, dim: int, keepdim: bool = False
    ) -> torch.Tensor:
        """Sample standard deviation along ``dim`` that never yields NaN.

        ``torch.std`` returns NaN when the reduced dimension holds a single
        element; a single observation has no spread, so zeros are returned
        in that case. Otherwise this is exactly ``torch.std``.
        """
        if values.shape[dim] < 2:
            return torch.zeros_like(values.sum(dim=dim, keepdim=keepdim))
        return torch.std(values, dim=dim, keepdim=keepdim)

    def compute_confidence(
        self,
        x: torch.Tensor,
        enable_dropout: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Compute the MC Dropout confidence of the model's prediction.

        Args:
            x: Input tensor, documented shape ``(batch_size, input_dim)``.
            enable_dropout: Whether to activate Dropout while sampling.
                Has no effect when the model has no Dropout layers.

        Returns:
            ``(predictions, confidence_dict)`` where

            - ``predictions`` is the mean over all sampled outputs;
            - ``confidence_dict`` contains:
                - ``'mean_prediction'``: same tensor as ``predictions``;
                - ``'std_prediction'``: std over the samples;
                - ``'confidence_by_dimension'``: ``exp(-std)``, in (0, 1];
                - ``'overall_confidence'``: mean of the per-dimension
                  confidence over dim 1 (kept as a size-1 dimension);
                - ``'confidence_std'``: std of the per-dimension confidence
                  over dim 1 (zero when there is a single output dimension);
                - ``'all_predictions'``: every sampled output stacked along
                  a new leading dimension of size ``n_samples``.
        """
        # Remember the mode of every submodule so a model with mixed
        # train/eval submodules is restored exactly as it was found.
        saved_modes = [
            (module, module.training) for module in self.model.modules()
        ]
        use_dropout = enable_dropout and self.has_dropout

        try:
            # Everything in eval mode first, then re-enable only the
            # Dropout modules: layers such as BatchNorm must not switch to
            # batch statistics or update their running stats here.
            self.model.eval()
            if use_dropout:
                for module in self.model.modules():
                    if isinstance(module, self._DROPOUT_TYPES):
                        module.train()

            with torch.no_grad():
                predictions_list = [
                    self.model(x) for _ in range(self.n_samples)
                ]
        finally:
            # Restore the flags even if the forward pass raised.
            for module, was_training in saved_modes:
                module.training = was_training

        # Shape: (n_samples, *model_output_shape)
        all_predictions = torch.stack(predictions_list, dim=0)

        mean_pred = torch.mean(all_predictions, dim=0)
        std_pred = self._sample_std(all_predictions, dim=0)

        # Map the spread to a confidence score: std 0 -> 1.0, large std -> 0.
        confidence_by_std = torch.exp(-std_pred)

        # Aggregate the per-dimension confidence for every sample in the batch.
        overall_confidence = torch.mean(confidence_by_std, dim=1, keepdim=True)
        confidence_std = self._sample_std(
            confidence_by_std, dim=1, keepdim=True
        )

        confidence_dict = {
            'mean_prediction': mean_pred,
            'std_prediction': std_pred,
            'confidence_by_dimension': confidence_by_std,
            'overall_confidence': overall_confidence,
            'confidence_std': confidence_std,
            'all_predictions': all_predictions
        }

        return mean_pred, confidence_dict

    def predict_with_confidence(
        self,
        x: torch.Tensor,
        return_all_samples: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
        """Predict and return the confidence as NumPy arrays.

        Args:
            x: Input tensor.
            return_all_samples: Whether to also return every sampled
                prediction under ``info_dict['all_predictions']``.

        Returns:
            ``(predictions, confidence, info_dict)`` where

            - ``predictions`` is the mean prediction;
            - ``confidence`` is the overall confidence;
            - ``info_dict`` holds ``'prediction_std'``,
              ``'confidence_by_dimension'`` and ``'confidence_std'`` (plus
              ``'all_predictions'`` when requested).
        """
        mean_pred, conf_dict = self.compute_confidence(x)

        # Move to host memory before converting; results may live on a GPU.
        predictions_np = mean_pred.cpu().numpy()
        confidence_np = conf_dict['overall_confidence'].cpu().numpy()
        std_np = conf_dict['std_prediction'].cpu().numpy()

        info_dict = {
            'prediction_std': std_np,
            'confidence_by_dimension': conf_dict['confidence_by_dimension'].cpu().numpy(),
            'confidence_std': conf_dict['confidence_std'].cpu().numpy()
        }

        if return_all_samples:
            info_dict['all_predictions'] = conf_dict['all_predictions'].cpu().numpy()

        return predictions_np, confidence_np, info_dict
|
|
|
|
|
|
|
|
def create_mc_dropout_confidence(
    model: nn.Module,
    n_samples: int = 30
) -> MCDropoutConfidence:
    """Factory function that builds an MC Dropout confidence calculator.

    Args:
        model: Model passed to ``MCDropoutConfidence``; expected to contain
            Dropout layers.
        n_samples: Number of Monte Carlo samples.

    Returns:
        A new ``MCDropoutConfidence`` instance.
    """
    calculator = MCDropoutConfidence(model=model, n_samples=n_samples)
    return calculator
|
|
|