Spaces:
Running
Running
| """Metrics for S2F training and evaluation. | |
| Includes: MSE, MS-SSIM, Pixel Correlation (Pearson), Relative Magnitude Error (WFM), | |
| and evaluation helpers for notebooks and scripts. | |
| """ | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from skimage.metrics import structural_similarity as ssim | |
| from scipy.stats import pearsonr | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| try: | |
| from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure | |
| from torchmetrics import MeanSquaredError | |
| HAS_TORCHMETRICS = True | |
| except ImportError: | |
| HAS_TORCHMETRICS = False | |
def calculate_mse(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred`` as a float.

    The input type is decided by ``y_true``: tensors are scored with
    ``F.mse_loss``; anything else is converted with ``np.asarray`` and scored
    with numpy.
    """
    if not isinstance(y_true, torch.Tensor):
        residual = np.asarray(y_true) - np.asarray(y_pred)
        return float(np.mean(residual ** 2))
    return F.mse_loss(y_pred, y_true).item()
def calculate_psnr(y_true, y_pred, max_pixel_value=1.0):
    """Return the peak signal-to-noise ratio (dB) computed from ``calculate_mse``.

    A perfect reconstruction (MSE exactly zero) yields ``float('inf')``.
    """
    mse = calculate_mse(y_true, y_pred)
    if mse != 0:
        rmse = np.sqrt(mse)
        return 20 * np.log10(max_pixel_value / rmse)
    return float('inf')
def calculate_ssim_tensor(y_true, y_pred, data_range=1.0):
    """Average single-scale SSIM over the samples of a batch.

    Args:
        y_true: ground truth as a tensor or array-like, indexed as (N, C, H, W)
            or (N, H, W). For 4-D input only channel 0 is compared.
        y_pred: prediction with the same layout as ``y_true``.
        data_range: value range forwarded to skimage's ``structural_similarity``.

    Returns:
        Mean of the per-sample SSIM values (numpy scalar).
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    four_d = y_true.ndim == 4
    ssim_values = []
    for i in range(y_true.shape[0]):
        if four_d:
            # Only the first channel is scored, regardless of the channel count.
            true_img, pred_img = y_true[i, 0], y_pred[i, 0]
        else:
            true_img, pred_img = y_true[i], y_pred[i]
        ssim_values.append(ssim(true_img, pred_img, data_range=data_range))
    return np.mean(ssim_values)
def calculate_pearson_correlation(y_true, y_pred):
    """Pearson correlation between all elements of ``y_true`` and ``y_pred``.

    Both inputs are flattened, so a whole batch is treated as one long vector.
    Tensors are detached and moved to CPU first, so tensors that require grad
    are accepted; lists and other array-likes go through ``np.asarray``.

    Returns:
        Pearson r as a Python float. scipy's ``pearsonr`` yields NaN for
        undefined cases such as a constant input.
    """
    if isinstance(y_true, torch.Tensor):
        # detach() is required: .numpy() refuses tensors that require grad.
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    correlation, _ = pearsonr(np.asarray(y_true).ravel(), np.asarray(y_pred).ravel())
    return float(correlation)
def calculate_individual_pixel_correlation(y_true, y_pred):
    """Pixel-wise Pearson correlation per sample in batch.

    Each sample ``y_true[i]`` / ``y_pred[i]`` is flattened and correlated on its
    own. Tensors are detached and moved to CPU first, so tensors that require
    grad are accepted; lists and other array-likes go through ``np.asarray``.

    Returns:
        List of one float per sample (NaN where scipy's ``pearsonr`` considers
        r undefined, e.g. a constant sample).
    """
    if isinstance(y_true, torch.Tensor):
        # detach() is required: .numpy() refuses tensors that require grad.
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    correlations = []
    for i in range(y_true.shape[0]):
        r, _ = pearsonr(y_true[i].ravel(), y_pred[i].ravel())
        correlations.append(float(r))
    return correlations
| # --- WFM (Wrinkle Force Microscopy) metrics for heatmap as magnitude --- | |
| def _to_numpy_wfm(x): | |
| if isinstance(x, torch.Tensor): | |
| return x.detach().cpu().numpy() | |
| return np.asarray(x) | |
| def _ensure_shape_wfm(f): | |
| """Ensure (N, 2, H, W). Heatmap -> fx=magnitude, fy=0.""" | |
| if f.ndim == 3: | |
| if f.shape[-1] == 2: | |
| f = np.transpose(f, (2, 0, 1))[None, ...] | |
| elif f.shape[0] == 2: | |
| f = f[None, ...] | |
| else: | |
| raise ValueError(f"Unsupported 3D shape {f.shape}") | |
| elif f.ndim == 4: | |
| if f.shape[-1] == 2: | |
| f = np.transpose(f, (0, 3, 1, 2)) | |
| else: | |
| raise ValueError(f"Unsupported ndim={f.ndim}") | |
| return f | |
| def _force_mag_wfm(f): | |
| fx, fy = f[:, 0], f[:, 1] | |
| return np.sqrt(fx**2 + fy**2) | |
def wfm_correlation(y_true, y_pred, mode="magnitude"):
    """Pearson correlation of true vs predicted force magnitudes over all pixels.

    Both inputs are normalised by ``_ensure_shape_wfm``, reduced to magnitudes
    with ``_force_mag_wfm``, flattened across the whole batch and cast to
    float64. Only ``mode="magnitude"`` is supported.

    Returns:
        The correlation as a float, or 0.0 when either magnitude vector has
        (numerically) zero standard deviation.

    Raises:
        ValueError: on a shape mismatch or an unknown ``mode``.
    """
    true_field = _ensure_shape_wfm(_to_numpy_wfm(y_true))
    pred_field = _ensure_shape_wfm(_to_numpy_wfm(y_pred))
    if true_field.shape != pred_field.shape:
        raise ValueError(f"Shape mismatch: true {true_field.shape} vs pred {pred_field.shape}")
    if mode != "magnitude":
        raise ValueError(f"Unknown mode '{mode}'")
    true_vec = _force_mag_wfm(true_field).ravel().astype(np.float64)
    pred_vec = _force_mag_wfm(pred_field).ravel().astype(np.float64)
    # A constant vector makes Pearson r undefined; report 0.0 instead.
    degenerate = np.allclose(true_vec.std(), 0) or np.allclose(pred_vec.std(), 0)
    if degenerate:
        return 0.0
    corr_matrix = np.corrcoef(true_vec, pred_vec)
    return float(corr_matrix[0, 1])
def wfm_relative_magnitude_error(y_true, y_pred, eps=1e-8):
    """Magnitude-weighted relative error between predicted and true force magnitudes.

    With ``m_t`` / ``m_p`` the true / predicted magnitudes and ``f̄ = mean(m_t)``,
    returns ``mean(|m_p - m_t| / (m_t + eps) * m_t / f̄)`` over every pixel of
    every sample. Because each term is weighted by ``m_t``, pixels whose true
    magnitude is zero contribute nothing.

    Returns:
        The error as a float, or 0.0 when ``f̄`` is (close to) zero.

    Raises:
        ValueError: if the normalised shapes of the two inputs differ.
    """
    true_field = _ensure_shape_wfm(_to_numpy_wfm(y_true))
    pred_field = _ensure_shape_wfm(_to_numpy_wfm(y_pred))
    if true_field.shape != pred_field.shape:
        raise ValueError(f"Shape mismatch: true {true_field.shape} vs pred {pred_field.shape}")
    mag_true = _force_mag_wfm(true_field)
    mag_pred = _force_mag_wfm(pred_field)
    mean_true = np.mean(mag_true)
    if np.isclose(mean_true, 0):
        return 0.0
    relative = np.abs(mag_pred - mag_true) / (mag_true + eps)
    weights = mag_true / mean_true
    return float(np.mean(relative * weights))
def apply_threshold_mask(tensor, threshold=0.0):
    """Zero every entry of ``tensor`` that is below ``threshold``; entries >= threshold are kept."""
    keep = (tensor >= threshold).float()
    return tensor * keep
def detect_tanh_output_model(model):
    """Heuristically decide whether ``model`` produces outputs in [-1, 1] (Tanh).

    Returns True if any of these holds, checked in order:
        * the model has a falsy ``use_sigmoid`` attribute;
        * the model has a truthy ``use_tanh_output`` attribute;
        * ``model.final_conv`` is an ``nn.Tanh``, or an ``nn.Sequential`` whose
          last layer is an ``nn.Tanh``.
    Otherwise returns False.
    """
    missing = object()
    use_sigmoid = getattr(model, 'use_sigmoid', missing)
    if use_sigmoid is not missing and not use_sigmoid:
        return True
    if getattr(model, 'use_tanh_output', False):
        return True
    if not hasattr(model, 'final_conv'):
        return False
    head = model.final_conv
    last_layer = head[-1] if isinstance(head, nn.Sequential) else head
    return isinstance(last_layer, nn.Tanh)
def convert_tanh_to_sigmoid_range(tensor):
    """Affinely map values from [-1, 1] to [0, 1]: ``(x + 1) / 2``."""
    return 0.5 * (tensor + 1.0)
| # --- TorchMetrics wrapper for MS-SSIM --- | |
class TorchMetricsWrapper:
    """Adapter around torchmetrics MS-SSIM / MSE with module-level fallbacks.

    When ``HAS_TORCHMETRICS`` is False the metric objects are ``None``. In that
    case the compute methods delegate to ``calculate_ssim_tensor`` (single-scale
    SSIM) and ``calculate_mse`` instead.
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.reset_metrics()

    def reset_metrics(self):
        """(Re)create the metric objects on ``self.device`` (MS-SSIM uses data_range=1.0)."""
        if not HAS_TORCHMETRICS:
            self.ms_ssim = None
            self.mse = None
            return
        self.ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
        self.mse = MeanSquaredError().to(self.device)

    def compute_ms_ssim(self, y_true, y_pred):
        """MS-SSIM of ``y_pred`` vs ``y_true`` as a Python float.

        Multi-channel inputs are reduced to channel 0 (kept 4-D) before scoring.
        """
        if not HAS_TORCHMETRICS:
            return float(calculate_ssim_tensor(y_true, y_pred))  # fallback to SSIM
        target = y_true.to(self.device)
        preds = y_pred.to(self.device)
        if target.shape[1] != 1:
            target, preds = target[:, 0:1], preds[:, 0:1]
        return self.ms_ssim(preds, target).item()

    def compute_mse(self, y_true, y_pred):
        """Mean squared error of ``y_pred`` vs ``y_true`` as a Python float."""
        if not HAS_TORCHMETRICS:
            return calculate_mse(y_true, y_pred)
        target = y_true.to(self.device)
        preds = y_pred.to(self.device)
        return self.mse(preds, target).item()
| # --- Full evaluation on dataset --- | |
def _unpack_eval_batch(batch_data):
    """Split a loader batch into ``(images, heatmaps, metadata, has_metadata)``."""
    # 5-element batches carry metadata in the last slot; for any other layout
    # only the first two entries (images, heatmaps) are used.
    if len(batch_data) == 5:
        images, heatmaps, _, _, metadata = batch_data
        return images, heatmaps, metadata, True
    return batch_data[0], batch_data[1], None, False


def _heatmap_as_force_field(heatmap, device):
    """Embed channel 0 of a ``(B, C, H, W)`` heatmap as fx of a ``(B, 2, H, W)`` field with fy = 0."""
    batch, _, height, width = heatmap.shape
    field = torch.zeros(batch, 2, height, width, device=device)
    field[:, 0] = heatmap[:, 0]
    return field


def _nan_on_error(metric_fn, *args, **kwargs):
    """Call ``metric_fn``; return NaN instead of propagating any exception."""
    try:
        return metric_fn(*args, **kwargs)
    except Exception:  # best-effort metric: a failure is recorded as NaN
        return float('nan')


def _safe_pearson(x, y):
    """Pearson r of ``x`` and ``y`` as a float, or 0.0 when it cannot be computed or is NaN."""
    try:
        r, _ = pearsonr(x, y)
    except Exception:  # e.g. fewer than two samples
        return 0.0
    if r is None or np.isnan(r):
        return 0.0
    return float(r)


def _finite_mean_std(values):
    """Mean and std over the non-NaN entries of ``values``; ``(nan, nan)`` if none remain."""
    valid = [v for v in values if not np.isnan(v)]
    if not valid:
        return float('nan'), float('nan')
    return np.mean(valid), np.std(valid)


def evaluate_metrics_on_dataset(generator, data_loader, device=None, description="Evaluating",
                                save_predictions=False, threshold=0.0, use_settings=False,
                                normalization_params=None, config_path=None, substrate_override=None):
    """
    Evaluate S2F generator on a dataset. Returns MSE, MS-SSIM, Pixel Correlation,
    Relative Magnitude Error, and force sum/mean correlations.

    Args:
        generator: model mapping images (optionally with settings channels
            appended) to heatmaps. Its outputs are rescaled from [-1, 1] to
            [0, 1] when ``detect_tanh_output_model`` reports a Tanh head.
        data_loader: iterable of batches. 5-element batches are unpacked as
            ``(images, heatmaps, _, _, metadata)``; otherwise the first two
            elements are used as ``(images, heatmaps)``.
        device: torch device; defaults to MPS, then CUDA, then CPU.
        description: tqdm progress-bar label.
        save_predictions: also return per-sample arrays and metrics under
            ``'individual_predictions'``.
        threshold: ground-truth values below this are zeroed before scoring.
            Predictions are not thresholded.
        use_settings, normalization_params, config_path, substrate_override:
            when ``use_settings`` is set and ``normalization_params`` is not
            None, channels from ``models.s2f_model.create_settings_channels``
            are concatenated to the images. Without batch metadata the substrate
            defaults to ``substrate_override or 'fibroblasts_PDMS'``.

    Returns:
        dict with 'heatmap', 'wfm', 'force_sum' and 'force_mean' sections, plus
        'individual_predictions' when requested.
        * MSE and MS-SSIM mean/std are taken over per-batch values.
        * Pixel correlation is per sample; WFM metrics are per batch. Both
          ignore NaN entries (e.g. scipy's NaN for a constant map, or a failed
          WFM computation).
        * Force sum/mean correlations fall back to 0.0 when undefined.
    """
    if device is None:
        device = torch.device('mps' if torch.backends.mps.is_available() else
                              'cuda' if torch.cuda.is_available() else 'cpu')
    generator = generator.to(device)
    generator.eval()
    metrics_wrapper = TorchMetricsWrapper(device=device)
    heatmap_mse = []
    heatmap_ms_ssim = []
    heatmap_pixel_corr = []
    wfm_corr_mag = []
    wfm_rel_mag_err = []
    force_sum_gt, force_sum_pred = [], []
    force_mean_gt, force_mean_pred = [], []
    individual_predictions = [] if save_predictions else None
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(data_loader, desc=description)):
            images, heatmaps, metadata, has_metadata = _unpack_eval_batch(batch_data)
            images = images.to(device, dtype=torch.float32)
            heatmaps = heatmaps.to(device, dtype=torch.float32)
            if use_settings and normalization_params is not None:
                from models.s2f_model import create_settings_channels
                meta = metadata if has_metadata else {'substrate': [substrate_override or 'fibroblasts_PDMS'] * images.size(0)}
                settings_ch = create_settings_channels(meta, normalization_params, device, images.shape, config_path=config_path)
                images = torch.cat([images, settings_ch], dim=1)
            pred = generator(images)
            if detect_tanh_output_model(generator):
                pred = convert_tanh_to_sigmoid_range(pred)
            gt_thresh = apply_threshold_mask(heatmaps, threshold)
            pred_thresh = pred  # no threshold on pred for metrics
            heatmap_mse.append(metrics_wrapper.compute_mse(gt_thresh, pred_thresh))
            heatmap_ms_ssim.append(metrics_wrapper.compute_ms_ssim(gt_thresh, pred_thresh))
            heatmap_pixel_corr.extend(calculate_individual_pixel_correlation(gt_thresh, pred_thresh))
            # WFM: heatmap as magnitude (fx=magnitude, fy=0). Each metric is
            # guarded on its own so one failure cannot desync the two lists.
            gt_ff = _heatmap_as_force_field(gt_thresh, device)
            pred_ff = _heatmap_as_force_field(pred_thresh, device)
            wfm_corr_mag.append(_nan_on_error(wfm_correlation, gt_ff, pred_ff, mode="magnitude"))
            wfm_rel_mag_err.append(_nan_on_error(wfm_relative_magnitude_error, gt_ff, pred_ff))
            force_sum_gt.extend(torch.sum(gt_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())
            force_sum_pred.extend(torch.sum(pred_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())
            force_mean_gt.extend(torch.mean(gt_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())
            force_mean_pred.extend(torch.mean(pred_thresh, dim=[1, 2, 3]).cpu().numpy().tolist())
            if save_predictions:
                for i in range(images.size(0)):
                    p, t = pred_thresh[i:i + 1], gt_thresh[i:i + 1]
                    rme = _nan_on_error(
                        wfm_relative_magnitude_error,
                        _heatmap_as_force_field(t, device),
                        _heatmap_as_force_field(p, device),
                    )
                    individual_predictions.append({
                        'batch_idx': batch_idx,
                        'sample_idx': i,
                        'original_image': images[i].cpu().numpy(),
                        'ground_truth': heatmaps[i].cpu().numpy(),
                        'ground_truth_thresholded': gt_thresh[i].cpu().numpy(),
                        'prediction': pred[i].cpu().numpy(),
                        'prediction_thresholded': pred_thresh[i].cpu().numpy(),
                        'mse': metrics_wrapper.compute_mse(t, p),
                        'ms_ssim': metrics_wrapper.compute_ms_ssim(t, p),
                        'pixel_correlation': calculate_pearson_correlation(t, p),
                        'wfm_relative_magnitude_error': rme,
                        'force_sum_gt': torch.sum(gt_thresh[i]).item(),
                        'force_sum_pred': torch.sum(pred_thresh[i]).item(),
                        'force_mean_gt': torch.mean(gt_thresh[i]).item(),
                        'force_mean_pred': torch.mean(pred_thresh[i]).item(),
                    })
    pixel_corr_mean, pixel_corr_std = _finite_mean_std(heatmap_pixel_corr)
    wfm_corr_mean, wfm_corr_std = _finite_mean_std(wfm_corr_mag)
    wfm_rme_mean, wfm_rme_std = _finite_mean_std(wfm_rel_mag_err)
    # Computed independently so a failure in one cannot zero the other.
    force_sum_corr = _safe_pearson(force_sum_gt, force_sum_pred)
    force_mean_corr = _safe_pearson(force_mean_gt, force_mean_pred)
    results = {
        'heatmap': {
            'mse': np.mean(heatmap_mse),
            'mse_std': np.std(heatmap_mse),
            'ms_ssim': np.mean(heatmap_ms_ssim),
            'ms_ssim_std': np.std(heatmap_ms_ssim),
            'pixel_correlation': pixel_corr_mean,
            'pixel_correlation_std': pixel_corr_std,
        },
        'wfm': {
            'correlation_magnitude': wfm_corr_mean,
            'correlation_magnitude_std': wfm_corr_std,
            'relative_magnitude_error': wfm_rme_mean,
            'relative_magnitude_error_std': wfm_rme_std,
        },
        'force_sum': {
            'correlation': float(force_sum_corr),
            'gt_mean': np.mean(force_sum_gt),
            'pred_mean': np.mean(force_sum_pred),
            'gt_std': np.std(force_sum_gt),
            'pred_std': np.std(force_sum_pred),
        },
        'force_mean': {
            'correlation': float(force_mean_corr),
            'gt_mean': np.mean(force_mean_gt),
            'pred_mean': np.mean(force_mean_pred),
        },
    }
    if save_predictions:
        results['individual_predictions'] = individual_predictions
    return results
def print_metrics_report(report, threshold=0.0, uses_tanh=False):
    """Print formatted metrics report.

    ``report`` maps a set name (e.g. "train") to a metrics dict with
    'heatmap', 'wfm' and 'force_sum' sections. A positive ``threshold`` is shown
    in each header. ``uses_tanh`` appends a note about the [-1,1] -> [0,1]
    output conversion.
    """
    for set_name, set_metrics in report.items():
        header = f"\n🔸 {set_name.upper()} SET METRICS"
        header += f" (threshold={threshold})" if threshold > 0 else ""
        print(header)
        print("-" * 60)

        print("HEATMAP METRICS:")
        hm = set_metrics['heatmap']
        print(f" MSE: {hm['mse']:.6f} ± {hm['mse_std']:.6f}")
        print(f" MS-SSIM: {hm['ms_ssim']:.4f} ± {hm['ms_ssim_std']:.4f}")
        print(f" Pixel Corr: {hm['pixel_correlation']:.4f} ± {hm['pixel_correlation_std']:.4f}")

        print("WFM METRICS (heatmap as magnitude):")
        wfm = set_metrics['wfm']
        print(f" Correlation (Magnitude): {wfm['correlation_magnitude']:.4f} ± {wfm['correlation_magnitude_std']:.4f}")
        print(f" Relative Magnitude Error: {wfm['relative_magnitude_error']:.4f} ± {wfm['relative_magnitude_error_std']:.4f}")

        print("FORCE SUM CORRELATION:")
        fs = set_metrics['force_sum']
        print(f" Correlation: {fs['correlation']:.4f}")
        print(f" GT Mean: {fs['gt_mean']:.2f} ± {fs['gt_std']:.2f}")
        print(f" Pred Mean: {fs['pred_mean']:.2f} ± {fs['pred_std']:.2f}")

        if uses_tanh:
            print(" Note: Model outputs [-1,1], converted to [0,1] for evaluation")
        print("=" * 60)
def gen_prediction_plots(individual_predictions, save_dir, sort_by='ms_ssim', sort_order='desc', threshold=0.0):
    """Generate prediction plots (BF | GT | Pred) sorted by metric.

    Args:
        individual_predictions: list of per-sample dicts with at least
            'original_image', 'ground_truth', 'prediction', 'mse', 'ms_ssim',
            'pixel_correlation', 'batch_idx' and 'sample_idx'.
        save_dir: output directory, created if missing. One PNG is written per
            ranked sample.
        sort_by: key used for ranking (lower-cased before lookup). Samples whose
            value is missing, None or NaN are skipped.
        sort_order: 'desc' (case-insensitive) puts the largest value first; any
            other value sorts ascending.
        threshold: unused; kept for call-site compatibility.
    """
    os.makedirs(save_dir, exist_ok=True)
    key = sort_by.lower()
    reverse = sort_order.lower() == 'desc'
    # Skip entries lacking the key as well as NaN ones; otherwise sorting below
    # would raise KeyError on a missing key.
    valid = [p for p in individual_predictions
             if p.get(key) is not None and not np.isnan(p[key])]
    sorted_preds = sorted(valid, key=lambda pred: pred[key], reverse=reverse)
    print(f"Sorting {len(sorted_preds)} predictions by {sort_by} ({sort_order})")
    for rank, p in enumerate(tqdm(sorted_preds, desc="Saving plots"), 1):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        try:
            panels = (
                (p['original_image'], 'Bright Field', {'cmap': 'gray'}),
                (p['ground_truth'], 'Ground Truth', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
                (p['prediction'], 'Prediction', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
            )
            for ax, (arr, title, style) in zip(axes, panels):
                # 3-D arrays are channels-first: show channel 0.
                ax.imshow(arr[0] if arr.ndim == 3 else arr, **style)
                ax.set_title(title)
                ax.axis('off')
            rme = p.get('wfm_relative_magnitude_error', 'N/A')
            rme_text = f"{rme:.4f}" if isinstance(rme, (int, float, np.floating)) else str(rme)
            m = (f"MSE: {p['mse']:.4f} | MS-SSIM: {p['ms_ssim']:.4f} | "
                 f"Pixel Corr: {p['pixel_correlation']:.4f} | Rel Mag Err: {rme_text}")
            fig.suptitle(f"Rank {rank} (by {sort_by})\n{m}", fontsize=10, y=0.02)
            fig.tight_layout()
            fig.savefig(os.path.join(save_dir, f"rank{rank:03d}_batch{p['batch_idx']:03d}_sample{p['sample_idx']:02d}.png"), dpi=150, bbox_inches='tight')
        finally:
            # Close this specific figure even if drawing or saving failed.
            plt.close(fig)
def plot_predictions(loader, generator, n_samples, device, threshold=0.0,
                     use_settings=False, normalization_params=None, config_path=None, substrate_override=None):
    """Plot BF | GT | Pred for first n_samples from loader.

    Args:
        loader: iterable of batches whose first two elements are images and
            ground-truth heatmaps.
        generator: model to run; moved to ``device`` and put in eval mode.
        n_samples: maximum number of rows to plot. If no samples are available,
            a message is printed and nothing is plotted.
        device: torch device used for inference.
        threshold: if > 0, predicted values below it are zeroed before display.
        use_settings, normalization_params, config_path, substrate_override:
            when ``use_settings`` and ``normalization_params`` are truthy,
            channels from ``models.s2f_model.create_settings_channels`` are
            appended using ``substrate_override or 'fibroblasts_PDMS'`` for
            every sample. Batch metadata is not used.
    """
    generator = generator.to(device)
    generator.eval()
    bf_list, gt_list = [], []
    batches = iter(loader)
    while len(bf_list) < n_samples:
        try:
            batch = next(batches)
        except StopIteration:
            break
        images, heatmaps = batch[0], batch[1]
        take = min(images.shape[0], n_samples - len(bf_list))
        for i in range(take):
            bf_list.append(images[i])
            gt_list.append(heatmaps[i])
    n = min(n_samples, len(bf_list))
    if n <= 0:
        # torch.stack / plt.subplots would fail on an empty selection.
        print("No samples to plot.")
        return
    bf_batch = torch.stack(bf_list[:n]).to(device, dtype=torch.float32)
    if use_settings and normalization_params:
        from models.s2f_model import create_settings_channels
        sub = substrate_override or 'fibroblasts_PDMS'
        meta_dict = {'substrate': [sub] * n}
        settings_ch = create_settings_channels(meta_dict, normalization_params, device, bf_batch.shape, config_path=config_path)
        bf_batch = torch.cat([bf_batch, settings_ch], dim=1)
    with torch.no_grad():
        pred = generator(bf_batch)
        if detect_tanh_output_model(generator):
            pred = convert_tanh_to_sigmoid_range(pred)
        if threshold > 0:
            pred = pred * (pred >= threshold).float()
    # squeeze=False keeps axes 2-D even when n == 1.
    _, axes = plt.subplots(n, 3, figsize=(12, 4 * n), squeeze=False)
    for i in range(n):
        panels = (
            (bf_list[i], 'Bright Field', {'cmap': 'gray'}),
            (gt_list[i], 'Ground Truth', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
            (pred[i], 'Prediction', {'cmap': 'jet', 'vmin': 0, 'vmax': 1}),
        )
        for ax, (tensor, title, style) in zip(axes[i], panels):
            ax.imshow(tensor.squeeze().cpu().numpy(), **style)
            ax.set_title(title)
            ax.axis('off')
    plt.tight_layout()
    plt.show()