| """
|
| RadioUNet V3 推理脚本
|
| 使用训练好的模型对SoundMapDiff数据集进行推理
|
| """
|
|
|
| import os
|
| import sys
|
| import argparse
|
| import torch
|
| import numpy as np
|
| from PIL import Image
|
| import matplotlib.pyplot as plt
|
| from pathlib import Path
|
| from skimage.metrics import structural_similarity as ssim
|
|
|
| sys.path.append(os.path.join(os.path.dirname(__file__), 'lib'))
|
|
|
| from lib.modules import RadioWNet
|
| from lib.soundmap_loader import SoundMapDataset
|
| from torch.utils.data import DataLoader
|
|
|
|
|
def calculate_metrics(pred, target, data_range=1.0):
    """Compute error and image-quality metrics for a single prediction.

    Both tensors are detached from any autograd graph, moved to the CPU and
    squeezed, so singleton batch/channel dimensions are dropped before the
    comparison. NOTE(review): after squeezing, ``pred`` and ``target`` are
    assumed to have identical shapes — confirm for multi-channel targets.

    Args:
        pred: predicted map (torch tensor).
        target: ground-truth map (torch tensor).
        data_range: dynamic range of the values (max possible - min possible).
            Used by both SSIM and PSNR. The default 1.0 matches maps
            normalized to [0, 1] and reproduces the previous behaviour.

    Returns:
        dict with keys 'mse', 'mae', 'rmse', 'ssim' and 'psnr' (PSNR in dB,
        ``inf`` when the prediction equals the target exactly).
    """
    # detach() so the function also works on tensors that still require grad
    # (calling .numpy() on those raises).
    pred_np = pred.detach().cpu().numpy().squeeze()
    target_np = target.detach().cpu().numpy().squeeze()

    diff = pred_np - target_np
    mse = np.mean(diff ** 2)
    mae = np.mean(np.abs(diff))
    rmse = np.sqrt(mse)

    ssim_val = ssim(pred_np, target_np, data_range=data_range)

    # PSNR = 10 * log10(MAX^2 / MSE); it diverges for a perfect match.
    if mse > 0:
        psnr = 10 * np.log10(data_range ** 2 / mse)
    else:
        psnr = float('inf')

    return {
        'mse': mse,
        'mae': mae,
        'rmse': rmse,
        'ssim': ssim_val,
        'psnr': psnr,
    }
|
|
|
|
|
def visualize_prediction(inputs, target, pred, metrics, save_path):
    """Save a 2x2 comparison figure of inputs, ground truth and prediction.

    Top row: ``inputs[0]`` shown as the building layout (gray) and
    ``inputs[1]`` shown as the sound source (hot). Bottom row: ground truth
    and prediction, both drawn on a fixed [0, 1] viridis scale so the two
    panels are directly comparable. A one-line metrics summary is placed at
    the bottom of the figure.

    Args:
        inputs: input tensor indexed along its first axis (channels 0 and 1
            are plotted).
        target: ground-truth tensor (squeezed before plotting).
        pred: predicted tensor (squeezed before plotting).
        metrics: dict providing 'mse', 'mae', 'ssim' and 'psnr' values, as
            returned by calculate_metrics.
        save_path: destination passed to ``Figure.savefig``.
    """
    def _to_np(tensor):
        # detach() so tensors that still require grad can be converted too.
        return tensor.detach().cpu().numpy()

    def _show(ax, image, title, **imshow_kwargs):
        # Draw one panel with a title and no axis ticks; return the image
        # artist so a colorbar can be attached to it.
        im = ax.imshow(image, **imshow_kwargs)
        ax.set_title(title, fontsize=14)
        ax.axis('off')
        return im

    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    try:
        _show(axes[0, 0], _to_np(inputs[0]), 'Building Layout', cmap='gray')
        _show(axes[0, 1], _to_np(inputs[1]), 'Sound Source', cmap='hot')

        im_gt = _show(axes[1, 0], _to_np(target).squeeze(), 'Ground Truth',
                      cmap='viridis', vmin=0, vmax=1)
        fig.colorbar(im_gt, ax=axes[1, 0], fraction=0.046, pad=0.04)

        im_pred = _show(axes[1, 1], _to_np(pred).squeeze(),
                        f"Prediction (SSIM: {metrics['ssim']:.4f})",
                        cmap='viridis', vmin=0, vmax=1)
        fig.colorbar(im_pred, ax=axes[1, 1], fraction=0.046, pad=0.04)

        metrics_text = f"MSE: {metrics['mse']:.6f} | MAE: {metrics['mae']:.4f} | SSIM: {metrics['ssim']:.4f} | PSNR: {metrics['psnr']:.2f} dB"
        fig.suptitle(metrics_text, fontsize=12, y=0.02)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release this figure, even if plotting or saving fails, so a
        # long inference loop does not accumulate open figures.
        plt.close(fig)
|
|
|
|
|
def _parse_args():
    """Parse the command-line options for the inference run."""
    parser = argparse.ArgumentParser(description='RadioUNet V3 推理脚本')
    parser.add_argument('--checkpoint', type=str,
                        default='outputs/radiounet_v3/checkpoints/best_model.pth',
                        help='模型检查点路径')
    parser.add_argument('--dataset_dir', type=str,
                        default='/home/djk/generate/dataset/SoundMapDiff',
                        help='数据集目录')
    parser.add_argument('--output_dir', type=str,
                        default='outputs/radiounet_v3/inference',
                        help='输出目录')
    parser.add_argument('--num_samples', type=int, default=20,
                        help='推理样本数量')
    parser.add_argument('--img_size', type=int, default=256,
                        help='图像尺寸')
    return parser.parse_args()


def _load_model(checkpoint_path, device):
    """Build RadioWNet(inputs=2, phase="firstU"), load its weights, set eval mode.

    The checkpoint may be either a dict holding 'model_state_dict' (and
    optionally 'epoch') or a bare state dict.
    """
    print(f"加载模型: {checkpoint_path}")
    model = RadioWNet(inputs=2, phase="firstU").to(device)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"加载Epoch {checkpoint.get('epoch', 'unknown')}的模型")
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    return model


def _sample_indices(total_samples, num_samples):
    """Return up to ``num_samples`` evenly spaced indices in [0, total_samples).

    The requested count is clamped to the dataset size. This avoids the
    duplicate indices ``np.linspace`` yields when more samples are requested
    than exist, and the out-of-range index -1 it yields for an empty dataset.
    An empty list means there is nothing to evaluate.
    """
    count = max(0, min(num_samples, total_samples))
    if count == 0:
        return []
    return [int(i) for i in np.linspace(0, total_samples - 1, count, dtype=int)]


def _run_inference(model, dataset, indices, device, output_dir):
    """Predict, score and plot each selected sample; return per-sample metrics."""
    all_metrics = []
    num_selected = len(indices)
    with torch.no_grad():
        for i, idx in enumerate(indices):
            inputs, target = dataset[idx]
            # Add a leading batch dimension of size 1 before the forward pass.
            inputs = inputs.unsqueeze(0).to(device)
            target = target.unsqueeze(0).to(device)

            outputs = model(inputs)
            # The model may return several outputs in a list/tuple; only the
            # first one is evaluated.
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[0]

            pred = outputs.squeeze(0)
            ground_truth = target.squeeze(0)
            metrics = calculate_metrics(pred, ground_truth)
            all_metrics.append(metrics)

            save_path = output_dir / f'prediction_{i+1}_idx{idx}.png'
            visualize_prediction(inputs.squeeze(0), ground_truth, pred,
                                 metrics, save_path)

            print(f"样本 {i+1}/{num_selected} (idx={idx}): "
                  f"SSIM={metrics['ssim']:.4f}, MSE={metrics['mse']:.6f}, PSNR={metrics['psnr']:.2f}dB")
    return all_metrics


def _average_metrics(all_metrics):
    """Return the per-key arithmetic mean over a non-empty list of metric dicts.

    RMSE is averaged per sample (not sqrt of the mean MSE). A single exact
    prediction (PSNR = inf) makes the mean PSNR inf.
    """
    keys = ('mse', 'mae', 'rmse', 'ssim', 'psnr')
    return {key: np.mean([m[key] for m in all_metrics]) for key in keys}


def _summary_lines(avg_metrics):
    """Format the averaged metrics; the same lines go to the console and the report."""
    return [
        f" 平均 MSE: {avg_metrics['mse']:.6f}",
        f" 平均 MAE: {avg_metrics['mae']:.4f}",
        f" 平均 RMSE: {avg_metrics['rmse']:.4f}",
        f" 平均 SSIM: {avg_metrics['ssim']:.4f}",
        f" 平均 PSNR: {avg_metrics['psnr']:.2f} dB",
    ]


def _write_report(report_path, checkpoint_path, indices, all_metrics, avg_metrics):
    """Write per-sample and averaged metrics to a UTF-8 text report."""
    lines = [
        "RadioUNet V3 评估报告",
        "=" * 60,
        "",
        f"模型: {checkpoint_path}",
        # Report the number of samples actually evaluated, not the request.
        f"测试样本数: {len(all_metrics)}",
        "",
    ]
    for i, (idx, m) in enumerate(zip(indices, all_metrics)):
        lines += [
            f"样本 {i+1} (索引 {idx}):",
            f" MSE: {m['mse']:.6f}",
            f" MAE: {m['mae']:.4f}",
            f" RMSE: {m['rmse']:.4f}",
            f" SSIM: {m['ssim']:.4f}",
            f" PSNR: {m['psnr']:.2f} dB",
            "",
        ]
    lines += ["=" * 60, "平均指标:", "=" * 60]
    lines += _summary_lines(avg_metrics)

    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Evaluate the checkpoint on evenly spaced test samples and save plots and a report."""
    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model = _load_model(args.checkpoint, device)

    print(f"加载数据集: {args.dataset_dir}")
    test_dataset = SoundMapDataset(
        dataset_dir=args.dataset_dir,
        phase="test",
        img_size=args.img_size,
    )

    total_samples = len(test_dataset)
    indices = _sample_indices(total_samples, args.num_samples)
    print(f"测试样本数: {total_samples}, 采样数: {len(indices)}")
    if not indices:
        # Nothing to evaluate: avoid indexing an empty dataset and averaging
        # an empty metric list (which would yield NaN).
        print("没有可推理的样本,退出。")
        return

    print(f"\n{'='*60}")
    print("开始推理...")
    print(f"{'='*60}\n")

    all_metrics = _run_inference(model, test_dataset, indices, device, output_dir)
    avg_metrics = _average_metrics(all_metrics)

    print(f"\n{'='*60}")
    print("平均评估指标")
    print(f"{'='*60}")
    for line in _summary_lines(avg_metrics):
        print(line)
    print(f"{'='*60}")

    _write_report(output_dir / 'evaluation_report.txt', args.checkpoint,
                  indices, all_metrics, avg_metrics)

    print(f"\n✅ 推理完成!结果保存在: {output_dir}")
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run inference only when executed directly, not on import.
    main()
|
|
|