|
|
|
|
|
""" |
|
|
Evaluation script for PAD Predictor

This script implements the complete model evaluation workflow, including:
- Loading a trained model
- Test-set evaluation and performance analysis
- Generating detailed evaluation reports and visualization plots
- Model comparison and batch evaluation
- Confidence calibration analysis
- PAD-specific metric analysis

Usage:
    python evaluate.py --model-path checkpoints/best_model.pth --data-path data/test.csv
    python evaluate.py --model-path checkpoints/final_model.pth --config configs/training_config.yaml
    python evaluate.py --compare-models model1.pth model2.pth --data-path data/test.csv
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import yaml |
|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Any, Optional, Union, Tuple |
|
|
import logging |
|
|
import warnings |
|
|
from datetime import datetime |
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
# Make the repository root importable so the `src.*` imports below resolve.
# The path is resolved to an absolute one first: __file__ may be relative
# (e.g. a script started as `python evaluate.py` on Python < 3.9), in which
# case walking three parents up would collapse to '.' instead of the real root.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
|
|
|
|
|
from src.models.pad_predictor import PADPredictor, create_pad_predictor |
|
|
from src.data.data_loader import DataLoader, load_data_from_config |
|
|
from src.models.metrics import PADMetrics, RegressionMetrics, CalibrationMetrics |
|
|
from src.utils.logger import TrainingLogger, create_logger |
|
|
|
|
|
|
|
|
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments of the evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The populated namespace. At least one of ``model_path`` and
        ``compare_models`` is set; otherwise the parser exits with a usage
        error (``SystemExit`` with code 2).
    """
    parser = argparse.ArgumentParser(
        description='PAD预测器评估脚本',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # --- Model arguments ---
    # Not marked `required`: comparison mode (--compare-models) does not need
    # it. The "one of the two" rule is enforced after parsing instead.
    parser.add_argument(
        '--model-path', '-m',
        type=str,
        help='模型文件路径'
    )
    parser.add_argument(
        '--model-config', '-mc',
        type=str,
        default='configs/model_config.yaml',
        help='模型配置文件路径'
    )

    # --- Data arguments ---
    parser.add_argument(
        '--data-path', '-d',
        type=str,
        help='测试数据文件路径'
    )
    parser.add_argument(
        '--config', '-c',
        type=str,
        default='configs/training_config.yaml',
        help='训练配置文件路径(用于数据加载配置)'
    )

    # --- Output arguments ---
    parser.add_argument(
        '--output-dir', '-o',
        type=str,
        default='evaluation_results',
        help='输出目录'
    )
    parser.add_argument(
        '--report-name', '-r',
        type=str,
        default='evaluation_report',
        help='评估报告名称'
    )

    # --- Evaluation arguments ---
    parser.add_argument(
        '--batch-size', '-b',
        type=int,
        help='批次大小(覆盖配置文件中的设置)'
    )
    parser.add_argument(
        '--device',
        type=str,
        choices=['auto', 'cpu', 'cuda', 'mps'],
        default='auto',
        help='评估设备'
    )
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='GPU ID(当使用CUDA时)'
    )

    # --- Model comparison arguments ---
    parser.add_argument(
        '--compare-models',
        nargs='+',
        type=str,
        help='比较多个模型,提供模型路径列表'
    )
    parser.add_argument(
        '--model-names',
        nargs='+',
        type=str,
        help='比较模型时的名称列表'
    )

    # --- Analysis options ---
    parser.add_argument(
        '--detailed-analysis',
        action='store_true',
        help='进行详细分析(包括组件级别分析)'
    )
    parser.add_argument(
        '--calibration-analysis',
        action='store_true',
        help='进行置信度校准分析'
    )
    parser.add_argument(
        '--error-analysis',
        action='store_true',
        help='进行误差分析'
    )
    parser.add_argument(
        '--generate-plots',
        action='store_true',
        default=True,
        help='生成可视化图表'
    )
    # --generate-plots defaults to True and therefore could never be switched
    # off; this companion flag writes False into the same destination.
    # default=SUPPRESS keeps the default owned by --generate-plots.
    parser.add_argument(
        '--no-plots',
        dest='generate_plots',
        action='store_false',
        default=argparse.SUPPRESS,
        help='不生成可视化图表'
    )

    # --- Synthetic data options ---
    parser.add_argument(
        '--synthetic-data',
        action='store_true',
        help='使用合成数据进行评估'
    )
    parser.add_argument(
        '--num-samples',
        type=int,
        default=1000,
        help='合成数据样本数量'
    )

    # --- Miscellaneous options ---
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='详细输出'
    )
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        help='保存预测结果'
    )
    parser.add_argument(
        '--format',
        choices=['json', 'csv', 'xlsx'],
        default='json',
        help='输出格式'
    )

    args = parser.parse_args(argv)

    # Single-model evaluation needs --model-path, comparison mode needs
    # --compare-models; reject invocations that provide neither.
    if not args.model_path and not args.compare_models:
        parser.error('必须提供 --model-path 或 --compare-models 之一')

    return args
|
|
|
|
|
|
|
|
def load_model(model_path: str,
               model_config: Optional[Dict[str, Any]] = None,
               device: Union[str, torch.device] = 'cpu') -> nn.Module:
    """Build a PAD predictor and restore its weights from a checkpoint file.

    The model configuration is taken, in order of preference, from the
    ``model_config`` argument, from the checkpoint's ``'model_config'`` entry,
    or from a built-in fallback configuration.

    Args:
        model_path: Path of the checkpoint file.
        model_config: Model configuration; ``None`` triggers the lookup
            described above.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件不存在: {model_path}")

    # NOTE(review): torch.load unpickles the file content; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    if model_config is None:
        if 'model_config' in checkpoint:
            model_config = checkpoint['model_config']
        else:
            # Fallback configuration used when nothing else is available.
            hidden_layers = [
                {'size': 128, 'activation': 'ReLU', 'dropout': 0.2},
                {'size': 64, 'activation': 'ReLU', 'dropout': 0.2},
                {'size': 32, 'activation': 'ReLU', 'dropout': 0.1}
            ]
            model_config = {
                'dimensions': {'input_dim': 10, 'output_dim': 4},
                'architecture': {'hidden_layers': hidden_layers},
                'initialization': {'weight_init': 'xavier_uniform', 'bias_init': 'zeros'}
            }

    model = create_pad_predictor(model_config)

    # A checkpoint is either a wrapper dict holding 'model_state_dict' or the
    # bare state dict itself.
    has_wrapper = 'model_state_dict' in checkpoint
    state_dict = checkpoint['model_state_dict'] if has_wrapper else checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    logging.info(f"模型已加载: {model_path}")
    parameter_count = sum(p.numel() for p in model.parameters())
    logging.info(f"模型参数数量: {parameter_count:,}")

    return model
|
|
|
|
|
|
|
|
def load_data_for_evaluation(config: Dict[str, Any],
                             data_path: Optional[str] = None,
                             synthetic_data: bool = False,
                             num_samples: int = 1000,
                             batch_size: Optional[int] = None) -> torch.utils.data.DataLoader:
    """Create the test data loader used for evaluation.

    The data source is chosen in this order: synthetic data when
    ``synthetic_data`` is set, the file at ``data_path`` when given, otherwise
    the ``data.test_data_path`` entry of ``config``.

    Args:
        config: Configuration dict; ``config['data']['dataloader']`` (when
            present) configures the loader. The dict is not modified.
        data_path: Path of the test data file.
        synthetic_data: Whether to evaluate on generated synthetic data.
        num_samples: Number of synthetic samples to generate.
        batch_size: Batch size overriding the configured one (ignored when
            falsy, and not applied when loading from the configured path).

    Returns:
        The test data loader.
    """
    data_section = config.get('data', {})

    # Work on a copy so the batch-size override never leaks into the caller's
    # configuration dict.
    loader_config = dict(data_section.get('dataloader', {}))
    if batch_size:
        loader_config['batch_size'] = batch_size

    if synthetic_data:
        logging.info(f"生成合成数据,样本数量: {num_samples}")

        # Imported lazily: only needed for synthetic evaluation.
        from src.data.synthetic_generator import SyntheticDataGenerator
        generator = SyntheticDataGenerator(num_samples=num_samples)
        data, labels = generator.generate_data()

        data_loader = DataLoader(loader_config)
        # Features and labels are handed over as one column-stacked array.
        test_loader = data_loader.get_test_loader(data=np.hstack([data, labels]))
    elif data_path:
        logging.info(f"从文件加载数据: {data_path}")

        data_loader = DataLoader(loader_config)
        test_loader = data_loader.get_test_loader(data_path=data_path)
    else:
        logging.info("从配置文件加载数据")
        _, _, test_loader = load_data_from_config(data_section.get('test_data_path', ''))

    logging.info(f"测试数据批次数: {len(test_loader)}")
    return test_loader
|
|
|
|
|
|
|
|
def _tensor_to_frame(tensor: torch.Tensor, preferred_names: List[str], prefix: str) -> pd.DataFrame:
    """Turn a per-sample tensor into a DataFrame with one named column per value."""
    values = tensor.reshape(tensor.size(0), -1).numpy()
    width = values.shape[1]
    # Use the known names for as many columns as exist and generate generic
    # names for any extra ones, so a width mismatch never aborts the export.
    names = list(preferred_names[:width])
    names.extend(f'dim_{i}' for i in range(len(names), width))
    return pd.DataFrame(values, columns=[f'{prefix}{name}' for name in names])


def evaluate_model(model: nn.Module,
                   data_loader: torch.utils.data.DataLoader,
                   device: torch.device,
                   save_predictions: bool = False,
                   output_dir: Optional[str] = None) -> Dict[str, Any]:
    """Run a model over a data loader and compute the PAD metrics.

    Args:
        model: Model to evaluate; it is switched to evaluation mode.
        data_loader: Loader yielding ``(features, targets)`` batches.
        device: Device the batches are moved to before the forward pass.
        save_predictions: Whether to write features, targets and predictions
            to ``predictions.csv``.
        output_dir: Directory for ``predictions.csv``; nothing is saved when
            it is empty or ``None``.

    Returns:
        The result of ``PADMetrics.evaluate_predictions`` extended with the
        concatenated CPU tensors under ``'predictions'``, ``'targets'`` and
        ``'features'``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()

    prediction_batches = []
    target_batches = []
    feature_batches = []

    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)

            predictions = model(features)

            # Collect on the CPU so memory use on the device stays bounded.
            prediction_batches.append(predictions.cpu())
            target_batches.append(targets.cpu())
            feature_batches.append(features.cpu())

    if not prediction_batches:
        # torch.cat on an empty list fails with an unhelpful message.
        raise ValueError("数据加载器为空,无法评估模型")

    all_predictions = torch.cat(prediction_batches, dim=0)
    all_targets = torch.cat(target_batches, dim=0)
    all_features = torch.cat(feature_batches, dim=0)

    metrics = PADMetrics()
    evaluation_results = metrics.evaluate_predictions(all_predictions, all_targets)

    # Keep the raw tensors so that reports and plots can be derived later.
    evaluation_results['predictions'] = all_predictions
    evaluation_results['targets'] = all_targets
    evaluation_results['features'] = all_features

    if save_predictions and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        predictions_file = output_path / 'predictions.csv'

        output_names = ['delta_pad_p', 'delta_pad_a', 'delta_pad_d', 'delta_pressure', 'confidence']
        feature_names = ['user_pad_p', 'user_pad_a', 'user_pad_d', 'vitality',
                         'current_pad_p', 'current_pad_a', 'current_pad_d']

        # Targets and predictions share the same component names, so they are
        # prefixed to keep every column of the combined CSV unique.
        feature_df = _tensor_to_frame(all_features, feature_names, '')
        target_df = _tensor_to_frame(all_targets, output_names, 'true_')
        pred_df = _tensor_to_frame(all_predictions, output_names, 'pred_')

        combined_df = pd.concat([feature_df, target_df, pred_df], axis=1)
        combined_df.to_csv(predictions_file, index=False)

        logging.info(f"预测结果已保存: {predictions_file}")

    return evaluation_results
|
|
|
|
|
|
|
|
def _to_json_compatible(value: Any) -> Any:
    """Recursively convert tensors, numpy values and containers to JSON-friendly objects."""
    if isinstance(value, (torch.Tensor, np.ndarray)):
        return value.tolist()
    if isinstance(value, np.generic):
        # numpy scalars (e.g. np.float32) are rejected by the json module.
        return value.item()
    if isinstance(value, dict):
        return {
            (key if isinstance(key, (str, int, float, bool)) or key is None else str(key)):
                _to_json_compatible(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_json_compatible(item) for item in value]
    return value


def generate_evaluation_report(results: Dict[str, Any],
                               output_dir: str,
                               report_name: str = 'evaluation_report',
                               detailed_analysis: bool = False,
                               calibration_analysis: bool = False,
                               error_analysis: bool = False,
                               generate_plots: bool = True) -> str:
    """Write the text report, the JSON metrics file and (optionally) the plots.

    Args:
        results: Evaluation results; must contain the ``'predictions'`` and
            ``'targets'`` tensors.
        output_dir: Directory receiving all generated files (created if
            missing).
        report_name: Base file name of the ``.txt`` and ``.json`` reports.
        detailed_analysis: Forwarded to ``generate_evaluation_plots``.
        calibration_analysis: Forwarded to ``generate_evaluation_plots``.
        error_analysis: Forwarded to ``generate_evaluation_plots``.
        generate_plots: Whether to generate the plots at all.

    Returns:
        Path of the text report as a string.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    text_report_path = output_path / f'{report_name}.txt'

    # The text report is written by PADMetrics itself via ``save_path``.
    metrics = PADMetrics()
    metrics.generate_evaluation_report(
        results['predictions'],
        results['targets'],
        save_path=text_report_path
    )

    # The raw tensors are too large for the JSON summary, so they are skipped
    # up front rather than converted and thrown away. Everything else is
    # converted recursively, because metric dicts may hold nested tensors or
    # numpy values that json cannot serialize.
    raw_data_keys = ('predictions', 'targets', 'features')
    json_results = {
        key: _to_json_compatible(value)
        for key, value in results.items()
        if key not in raw_data_keys
    }

    with open(output_path / f'{report_name}.json', 'w', encoding='utf-8') as f:
        json.dump(json_results, f, indent=2, ensure_ascii=False)

    if generate_plots:
        generate_evaluation_plots(results, output_path, detailed_analysis, calibration_analysis, error_analysis)

    logging.info(f"评估报告已生成: {text_report_path}")
    return str(text_report_path)
|
|
|
|
|
|
|
|
def generate_evaluation_plots(results: Dict[str, Any],
                              output_path: Path,
                              detailed_analysis: bool = False,
                              calibration_analysis: bool = False,
                              error_analysis: bool = False):
    """Render the evaluation figures and save them as PNG files.

    Always writes ``prediction_vs_true.png``. With ``detailed_analysis`` it
    additionally writes ``error_distribution.png`` and ``pad_analysis.png``.

    Args:
        results: Evaluation results holding the ``'predictions'`` and
            ``'targets'`` tensors.
        output_path: Directory the figures are written to.
        detailed_analysis: Whether to produce the additional figures.
        calibration_analysis: Accepted for interface compatibility; not used
            by this function.
        error_analysis: Accepted for interface compatibility; not used by
            this function.
    """
    predicted = results['predictions']
    actual = results['targets']
    labels = ['ΔPAD_P', 'ΔPAD_A', 'ΔPAD_D', 'ΔPressure']

    # Note: this changes the global matplotlib style, not only these figures.
    plt.style.use('seaborn-v0_8')

    # Figure 1: predicted vs. true value, one panel per output component.
    scatter_fig, scatter_axes = plt.subplots(2, 2, figsize=(14, 12))
    scatter_fig.suptitle('预测值 vs 真实值', fontsize=16)

    for col, (panel, label) in enumerate(zip(scatter_axes.flat, labels)):
        # Panels without a matching output column are left empty.
        if col >= predicted.size(1):
            continue

        y_pred = predicted[:, col].numpy()
        y_true = actual[:, col].numpy()

        panel.scatter(y_true, y_pred, alpha=0.6, s=20)

        # Diagonal marking a perfect prediction.
        lower = min(y_true.min(), y_pred.min())
        upper = max(y_true.max(), y_pred.max())
        panel.plot([lower, upper], [lower, upper], 'r--', linewidth=2)

        panel.set_xlabel('真实值')
        panel.set_ylabel('预测值')
        panel.set_title(label)
        panel.grid(True, alpha=0.3)

        # The displayed value is the squared Pearson correlation.
        squared_corr = np.corrcoef(y_true, y_pred)[0, 1] ** 2
        panel.text(0.05, 0.95, f'R² = {squared_corr:.3f}', transform=panel.transAxes,
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.savefig(output_path / 'prediction_vs_true.png', dpi=300, bbox_inches='tight')
    plt.close()

    if detailed_analysis:
        # Figure 2: distribution of the signed error per output component.
        error_fig, error_axes = plt.subplots(2, 3, figsize=(18, 12))
        error_fig.suptitle('误差分布', fontsize=16)

        for col, (panel, label) in enumerate(zip(error_axes.flat, labels)):
            if col >= predicted.size(1):
                continue

            residuals = (predicted[:, col] - actual[:, col]).numpy()
            mean_error = np.mean(residuals)
            median_error = np.median(residuals)

            panel.hist(residuals, bins=30, alpha=0.7, density=True)
            panel.axvline(0, color='r', linestyle='--', linewidth=2)
            panel.axvline(mean_error, color='g', linestyle='-', linewidth=2, label=f'均值: {mean_error:.4f}')
            panel.axvline(median_error, color='b', linestyle='-', linewidth=2, label=f'中位数: {median_error:.4f}')

            panel.set_xlabel('误差')
            panel.set_ylabel('密度')
            panel.set_title(label)
            panel.legend()
            panel.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_path / 'error_distribution.png', dpi=300, bbox_inches='tight')
        plt.close()

        # Figure 3: direction agreement between the predicted and the true
        # vector formed by the first three output columns.
        pad_pred = predicted[:, :3]
        pad_true = actual[:, :3]

        cos_sim = torch.nn.functional.cosine_similarity(pad_pred, pad_true, dim=1)
        # Clamp before acos so rounding noise cannot leave its domain.
        angle_errors = torch.acos(torch.clamp(cos_sim, -1 + 1e-8, 1 - 1e-8)) * 180 / np.pi

        _, (angle_panel, cosine_panel) = plt.subplots(1, 2, figsize=(15, 6))

        angle_panel.hist(angle_errors.numpy(), bins=30, alpha=0.7, density=True)
        angle_panel.set_xlabel('角度误差 (度)')
        angle_panel.set_ylabel('密度')
        angle_panel.set_title('PAD向量角度误差分布')
        angle_panel.grid(True, alpha=0.3)

        cosine_panel.hist(cos_sim.numpy(), bins=30, alpha=0.7, density=True)
        cosine_panel.set_xlabel('余弦相似度')
        cosine_panel.set_ylabel('密度')
        cosine_panel.set_title('PAD向量余弦相似度分布')
        cosine_panel.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_path / 'pad_analysis.png', dpi=300, bbox_inches='tight')
        plt.close()

    logging.info(f"评估图表已保存到: {output_path}")
|
|
|
|
|
|
|
|
def compare_models(model_paths: List[str],
                   model_names: List[str],
                   data_loader: torch.utils.data.DataLoader,
                   device: torch.device,
                   output_dir: str) -> Dict[str, Any]:
    """Evaluate several checkpoints on the same data and write a comparison report.

    A model that fails to load or evaluate does not abort the comparison; its
    entry only records the error message.

    Args:
        model_paths: Checkpoint paths of the models to compare.
        model_names: Display names, one per path. When the count does not
            match ``model_paths``, generated names ``Model_1``, ``Model_2``,
            ... are used instead.
        data_loader: Test data loader shared by all models.
        device: Device used for loading and evaluation.
        output_dir: Directory the comparison report is written to.

    Returns:
        A dict keyed by model name. Successful entries hold ``'model_path'``,
        ``'metrics'`` (flat numeric metrics) and ``'full_results'``; failed
        entries hold ``'error'``.
    """
    # Local import: `numbers` is only needed for the numeric check below.
    import numbers

    if len(model_names) != len(model_paths):
        if model_names:
            # Do not silently discard names the user supplied.
            logging.warning(
                f"模型名称数量({len(model_names)})与模型路径数量({len(model_paths)})不一致,将使用默认名称"
            )
        model_names = [f"Model_{i+1}" for i in range(len(model_paths))]

    comparison_results = {}

    logging.info(f"开始比较 {len(model_paths)} 个模型...")

    for model_path, model_name in zip(model_paths, model_names):
        logging.info(f"评估模型: {model_name} ({model_path})")

        try:
            model = load_model(model_path, device=device)
            results = evaluate_model(model, data_loader, device)

            # Flatten the metrics of interest into one "<group>_<metric>" dict.
            key_metrics = {}

            overall_metrics = results.get('regression', {}).get('overall', {}) \
                if 'regression' in results else {}
            for metric, value in overall_metrics.items():
                key_metrics[f'regression_{metric}'] = value

            if 'calibration' in results:
                for metric, value in results['calibration'].items():
                    # numbers.Real also accepts numpy scalars such as
                    # np.float32, which a plain (int, float) check drops.
                    if isinstance(value, numbers.Real):
                        key_metrics[f'calibration_{metric}'] = value

            comparison_results[model_name] = {
                'model_path': model_path,
                'metrics': key_metrics,
                'full_results': results
            }

        except Exception as e:
            # Deliberately broad: one broken checkpoint must not stop the
            # evaluation of the remaining models.
            logging.error(f"评估模型 {model_name} 时发生错误: {e}")
            comparison_results[model_name] = {'error': str(e)}

    generate_comparison_report(comparison_results, output_dir)

    return comparison_results
|
|
|
|
|
|
|
|
def generate_comparison_report(comparison_results: Dict[str, Any], output_dir: str):
    """Write the comparison table, bar charts and text report for several models.

    Produces ``model_comparison.csv`` when at least one model succeeded,
    ``model_comparison.png`` when more than one succeeded and at least one of
    the charted metrics is present, and always ``comparison_report.txt``.

    Args:
        comparison_results: Per-model results; an entry either holds
            ``'error'`` or ``'model_path'`` and ``'metrics'``.
        output_dir: Directory the files are written to (created if missing).
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # One table row per successfully evaluated model.
    rows = [
        {'Model': name, **entry['metrics']}
        for name, entry in comparison_results.items()
        if 'error' not in entry
    ]

    if rows:
        df = pd.DataFrame(rows)
        df.to_csv(target_dir / 'model_comparison.csv', index=False)

        charted = ['regression_mae', 'regression_rmse', 'regression_r2', 'calibration_ece']
        # For these metrics a smaller value is better, so bars are sorted ascending.
        lower_is_better = ('regression_mae', 'regression_rmse', 'calibration_ece')
        available_metrics = [name for name in charted if name in df.columns]

        # A chart only makes sense with at least two models to compare.
        if len(rows) > 1 and available_metrics:
            _, axes = plt.subplots(2, 2, figsize=(15, 10))

            for ax, metric in zip(axes.flatten(), available_metrics):
                ordered = df.sort_values(metric, ascending=metric in lower_is_better)
                positions = range(len(ordered))
                pretty_name = metric.replace('_', ' ').title()

                bars = ax.bar(positions, ordered[metric])
                ax.set_xticks(positions)
                ax.set_xticklabels(ordered['Model'], rotation=45, ha='right')
                ax.set_ylabel(pretty_name)
                ax.set_title(f'{pretty_name} Comparison')
                ax.grid(True, alpha=0.3)

                # Print the value above each bar.
                for bar in bars:
                    height = bar.get_height()
                    ax.text(bar.get_x() + bar.get_width()/2., height,
                            f'{height:.4f}', ha='center', va='bottom')

            plt.tight_layout()
            plt.savefig(target_dir / 'model_comparison.png', dpi=300, bbox_inches='tight')
            plt.close()

    # Plain-text report, including the models that failed.
    separator = "=" * 60
    report_lines = [
        separator,
        "模型比较报告",
        separator,
        f"比较时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"模型数量: {len(comparison_results)}",
        "",
    ]

    for name, entry in comparison_results.items():
        report_lines.append(f"模型: {name}")
        if 'error' in entry:
            report_lines.append(f" 错误: {entry['error']}")
        else:
            report_lines.append(f" 路径: {entry['model_path']}")
            report_lines.extend(
                f" {metric}: {value:.6f}" for metric, value in entry['metrics'].items()
            )
        report_lines.append("")

    with open(target_dir / 'comparison_report.txt', 'w', encoding='utf-8') as f:
        f.write("\n".join(report_lines))

    logging.info(f"模型比较报告已生成: {target_dir}")
|
|
|
|
|
|
|
|
def _load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Read a YAML file and return its top-level mapping (an empty file yields {})."""
    with open(config_path, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(f"配置文件格式错误(顶层必须是字典): {config_path}")
    return loaded


def main():
    """Script entry point: evaluate one model or compare several models.

    Steps: parse the arguments, configure logging, pick the device, load the
    configuration and the test data, then either compare the models given via
    ``--compare-models`` or evaluate ``--model-path`` and generate its report.
    Any error is logged with its traceback and the process exits with status 1.
    """
    args = parse_arguments()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    logger = logging.getLogger(__name__)
    logger.info("开始PAD预测器评估")

    try:
        # --- Device selection ---
        if args.device == 'auto':
            if torch.cuda.is_available():
                device = torch.device(f'cuda:{args.gpu_id}')
                logger.info(f"使用GPU: {torch.cuda.get_device_name(args.gpu_id)}")
            else:
                device = torch.device('cpu')
                logger.info("使用CPU")
        elif args.device == 'cuda':
            # Honour --gpu-id for an explicit CUDA request as well.
            device = torch.device(f'cuda:{args.gpu_id}')
        else:
            device = torch.device(args.device)

        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # --- Data configuration ---
        # --config has a default value, so the file may simply not exist; fall
        # back to built-in loader settings instead of failing in that case.
        if args.config and os.path.exists(args.config):
            config = _load_yaml_config(args.config)
        else:
            if args.config:
                logger.warning(f"配置文件不存在,使用默认配置: {args.config}")
            config = {
                'data': {
                    'dataloader': {
                        'batch_size': args.batch_size or 32,
                        'num_workers': 0,
                        'pin_memory': False,
                        'shuffle': False
                    }
                }
            }

        # The command-line batch size takes precedence; setdefault guards
        # against config files that lack the 'data'/'dataloader' sections.
        if args.batch_size:
            config.setdefault('data', {}).setdefault('dataloader', {})['batch_size'] = args.batch_size

        # --- Model configuration (optional) ---
        model_config = None
        if args.model_config and os.path.exists(args.model_config):
            model_config = _load_yaml_config(args.model_config)

        data_loader = load_data_for_evaluation(
            config,
            args.data_path,
            args.synthetic_data,
            args.num_samples,
            args.batch_size
        )

        if args.compare_models:
            # --- Comparison mode ---
            logger.info(f"比较 {len(args.compare_models)} 个模型")

            compare_models(
                args.compare_models,
                args.model_names if args.model_names else [],
                data_loader,
                device,
                str(output_dir)
            )

            logger.info(f"模型比较完成,结果保存在: {output_dir}")

        else:
            # --- Single-model mode ---
            logger.info(f"评估模型: {args.model_path}")

            model = load_model(args.model_path, model_config, device)

            results = evaluate_model(
                model,
                data_loader,
                device,
                args.save_predictions,
                str(output_dir)
            )

            report_path = generate_evaluation_report(
                results,
                str(output_dir),
                args.report_name,
                args.detailed_analysis,
                args.calibration_analysis,
                args.error_analysis,
                args.generate_plots
            )

            # Echo the headline metrics to the log.
            if 'regression' in results and 'overall' in results['regression']:
                overall_metrics = results['regression']['overall']
                logger.info("评估结果:")
                logger.info(f" MAE: {overall_metrics.get('mae', 0):.6f}")
                logger.info(f" RMSE: {overall_metrics.get('rmse', 0):.6f}")
                logger.info(f" R²: {overall_metrics.get('r2', 0):.6f}")
                logger.info(f" MAPE: {overall_metrics.get('mape', 0):.6f}")

            if 'calibration' in results:
                calibration_metrics = results['calibration']
                logger.info("校准指标:")
                logger.info(f" ECE: {calibration_metrics.get('ece', 0):.6f}")
                logger.info(f" Sharpness: {calibration_metrics.get('sharpness', 0):.6f}")

            logger.info(f"评估完成,报告保存在: {report_path}")

    except Exception as e:
        # Top-level boundary: log the message together with the traceback and
        # signal the failure through the exit status.
        logger.exception(f"评估过程中发生错误: {e}")
        sys.exit(1)
|
|
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()