""" Signal reconstruction evaluation for Multi-Mechanism Normalizing Flow. For each test sample: 1. Infer mechanism + parameter posterior 2. Reconstruct signals from posterior mean, MAP, and random samples 3. Compare reconstructed signals with observed signals This validates whether the inferred posteriors produce physically consistent predictions, even when individual parameters have poor R² (due to compensation). Usage: python evaluate_reconstruction.py --checkpoint outputs/multi_mechanism_multiscan/.../best.pt python evaluate_reconstruction.py --checkpoint outputs/tpd_multiheat/.../best.pt --domain tpd """ import os import sys import json import glob import signal import argparse import time as time_module from pathlib import Path from collections import defaultdict import numpy as np import torch from tqdm import tqdm import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt class _Timeout: """Context manager that raises TimeoutError after `seconds` seconds. Uses POSIX signal.alarm in the main thread; falls back to a no-op in worker threads (e.g. Gradio request handlers) where signals are unavailable. """ def __init__(self, seconds): self.seconds = seconds self._use_signal = False def _handler(self, signum, frame): raise TimeoutError(f"Reconstruction timed out after {self.seconds}s") def __enter__(self): import threading if threading.current_thread() is threading.main_thread(): self._use_signal = True self._old = signal.signal(signal.SIGALRM, self._handler) signal.alarm(self.seconds) return self def __exit__(self, *args): if self._use_signal: signal.alarm(0) signal.signal(signal.SIGALRM, self._old) def parse_args(): parser = argparse.ArgumentParser(description="Evaluate signal reconstruction") parser.add_argument("--checkpoint", type=str, required=True) parser.add_argument("--domain", type=str, default="ec", choices=["ec", "tpd"]) parser.add_argument("--split", type=str, default="test", choices=["train", "val", "test"]) parser.add_argument("--data_dir", type=str, default=None, help="Override data directory (e.g. for clean test set)") parser.add_argument("--max_samples", type=int, default=200) parser.add_argument("--n_posterior_samples", type=int, default=100, help="Posterior samples for reconstruction") parser.add_argument("--n_recon_samples", type=int, default=10, help="Number of random posterior samples to reconstruct per test sample") parser.add_argument("--n_visualize", type=int, default=20) parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--temperature", type=float, default=1.0, help="Base distribution temperature for sampling (>1 broadens posteriors)") parser.add_argument("--noise_augmentation", action="store_true", help="Apply domain-appropriate noise augmentation to " "the eval signals (matches train-time noise " "distribution). Used for noise-robustness " "ablations.") parser.add_argument("--mechanisms", type=str, nargs="*", default=None, help="Restrict reconstruction to this subset of " "mechanism names (default: all in checkpoint).") parser.add_argument("--output_suffix", type=str, default="", help="Suffix for the eval output directory (e.g. " "'_noisy'). Lets clean and noisy recon evals " "coexist.") return parser.parse_args() # ============================================================================= # EC reconstruction # ============================================================================= def _safe_pow10(val): """Compute 10**val, clamping to avoid OverflowError for extreme values.""" val = np.clip(val, -300, 300) return 10.0 ** val def _configure_tpd_mechanisms(mechanism_list): """Restore the active TPD mechanism ordering used by a checkpoint.""" import generate_tpd_data as _tpd_gen import tpd_model as _tpd_mod import dataset_tpd as _ds_tpd active_mechs = ( list(mechanism_list) if mechanism_list is not None else list(_tpd_gen.TPD_MECHANISM_LIST) ) _tpd_gen.TPD_MECHANISM_LIST = active_mechs _tpd_gen.TPD_MECHANISM_TO_ID = {m: i for i, m in enumerate(active_mechs)} _tpd_mod.TPD_MECHANISM_LIST = active_mechs _ds_tpd.TPD_MECHANISM_LIST = active_mechs return active_mechs def theta_to_ec_params(theta, mechanism, base_params): """ Convert model output (log-space) back to physical simulator parameters. Args: theta: [D] numpy array of inferred parameters mechanism: str mechanism name base_params: dict of fixed parameters from the original sample """ from flow_model import MECHANISM_PARAMS names = MECHANISM_PARAMS[mechanism]['names'] p = dict(base_params) p['kinetics'] = mechanism for i, name in enumerate(names): val = float(theta[i]) if name.startswith('log10(') and name.endswith(')'): phys_name = name[6:-1] p[phys_name] = _safe_pow10(val) elif name in ('alpha', 'alpha_1', 'alpha_2'): p[name] = float(np.clip(val, 0.01, 0.99)) elif name in ('E0_offset', 'E0_2_offset'): p[name] = val return p def reconstruct_ec_signal(theta, mechanism, base_params, sigmas, n_spatial=64): """ Reconstruct CV signal(s) from inferred parameters. Args: theta: [D] inferred parameters (model output space) mechanism: str base_params: dict with fixed params (theta_i, theta_v, dA, etc.) sigmas: list of scan rates n_spatial: spatial grid points Returns: list of dicts with 'potential', 'flux', 'time', 'c_ox_surface', 'c_red_surface' per scan rate """ import warnings, inspect from generate_dataset_diffec import _run_single_cv _EXT_SIMULATORS = None if mechanism in ('EE', 'EC_prime', 'CE', 'ECE', 'EC_LH', 'MHC_EC', 'MHC_LH'): from generate_extended_mechanisms import SIMULATORS as _EXT_SIMULATORS phys = theta_to_ec_params(theta, mechanism, base_params) K0_at_1 = phys.get('K0', 1.0) K0_1_at_1 = phys.get('K0_1', 1.0) K0_2_at_1 = phys.get('K0_2', 1.0) kc_at_1 = phys.get('kc', 1.0) kf_at_1 = phys.get('kf', 1.0) results = [] for sigma in sigmas: p = dict(phys) p['sigma'] = float(sigma) if mechanism in ('BV', 'MHC', 'EC', 'LH', 'EC_prime', 'CE'): p['K0'] = K0_at_1 / np.sqrt(sigma) elif mechanism == 'Ads': p['K0'] = K0_at_1 / sigma elif mechanism == 'EE': p['K0_1'] = K0_1_at_1 / np.sqrt(sigma) p['K0_2'] = K0_2_at_1 / np.sqrt(sigma) p.setdefault('dA', 1.0) p.setdefault('dB', 1.0) if mechanism in ('EC', 'EC_prime'): p['kc'] = kc_at_1 / sigma if mechanism == 'CE': p['kf'] = kf_at_1 / sigma if mechanism in ('EC_prime', 'CE'): p.setdefault('dA', 1.0) p.setdefault('dB', 1.0) try: with _Timeout(30), warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) if _EXT_SIMULATORS is not None and mechanism in _EXT_SIMULATORS: sim_fn = _EXT_SIMULATORS[mechanism] sig = inspect.signature(sim_fn) kwargs = {k: v for k, v in p.items() if k in sig.parameters} kwargs['n_spatial_out'] = n_spatial result = sim_fn(**kwargs) else: result = _run_single_cv(p, n_spatial) entry = { 'potential': result['potential'], 'flux': result['flux'], 'time': result['time'], 'success': True, } if 'c_ox' in result and 'c_red' in result: entry['c_ox_surface'] = result['c_ox'][:, -1] entry['c_red_surface'] = result['c_red'][:, 0] results.append(entry) except Exception as e: results.append({'success': False, 'error': str(e)}) return results # ============================================================================= # TPD reconstruction # ============================================================================= def theta_to_tpd_params(theta, mechanism, base_params): """Convert model output back to physical TPD simulator parameters.""" from generate_tpd_data import TPD_MECHANISM_PARAMS names = TPD_MECHANISM_PARAMS[mechanism]['names'] p = dict(base_params) p['mechanism'] = mechanism for i, name in enumerate(names): val = float(theta[i]) if name.startswith('log10(') and name.endswith(')'): phys_name = name[6:-1] p[phys_name] = _safe_pow10(val) else: p[name] = val return p def reconstruct_tpd_signal(theta, mechanism, base_params, betas): """ Reconstruct TPD signal(s) from inferred parameters. Args: theta: [D] inferred parameters mechanism: str base_params: dict with T_start, T_end, etc. betas: list of heating rates Returns: list of dicts with 'temperature', 'rate', 'time' per heating rate """ import warnings from generate_tpd_data import _run_single_tpd phys = theta_to_tpd_params(theta, mechanism, base_params) results = [] for beta in betas: p = dict(phys) p['beta'] = float(beta) try: with _Timeout(30), warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) result = _run_single_tpd(p) results.append({ 'temperature': result['temperature'], 'rate': result['rate'], 'time': result['time'], 'success': True, }) except Exception as e: results.append({'success': False, 'error': str(e)}) return results # ============================================================================= # Metrics # ============================================================================= def _valid_signal(arr): """Check that a signal array contains no NaN/Inf/extreme values.""" if not np.all(np.isfinite(arr)): return False if np.max(np.abs(arr)) > 1e30: return False return True def signal_rmse(observed, reconstructed, length=None): """RMSE between two signals, optionally truncated to valid length.""" min_len = min(len(observed), len(reconstructed)) if length is not None: min_len = min(min_len, length) o = observed[:min_len] r = reconstructed[:min_len] if not (_valid_signal(o) and _valid_signal(r)): return float('nan') return float(np.sqrt(np.mean((o - r) ** 2))) def signal_nrmse(observed, reconstructed, length=None): """Normalized RMSE (by peak-to-peak range of observed signal).""" min_len = min(len(observed), len(reconstructed)) if length is not None: min_len = min(min_len, length) o = observed[:min_len] r = reconstructed[:min_len] if not (_valid_signal(o) and _valid_signal(r)): return float('nan') ptp = np.ptp(o) if ptp < 1e-20: return float('inf') return float(np.sqrt(np.mean((o - r) ** 2)) / ptp) def signal_r2(observed, reconstructed, length=None): """R² between observed and reconstructed signals.""" min_len = min(len(observed), len(reconstructed)) if length is not None: min_len = min(min_len, length) o = observed[:min_len] r = reconstructed[:min_len] if not (_valid_signal(o) and _valid_signal(r)): return float('nan') ss_res = np.sum((o - r) ** 2) ss_tot = np.sum((o - np.mean(o)) ** 2) if ss_tot < 1e-20: return 0.0 return float(1 - ss_res / ss_tot) # ============================================================================= # EC evaluation # ============================================================================= def evaluate_ec(args): import multi_mechanism_model as _mm_module import flow_model as _fm_module from multi_mechanism_model import MultiMechanismFlow from flow_model import MECHANISM_PARAMS from dataset import DiffECDataset, collate_fn from torch.utils.data import DataLoader ckpt_path = os.path.expanduser(args.checkpoint) checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) ckpt_args = checkpoint['args'] device = 'cuda' if torch.cuda.is_available() else 'cpu' use_summary = ckpt_args.get('use_summary_features', False) # If the checkpoint was trained on a custom mechanism subset (e.g. v14_9mech), # patch the global MECHANISM_LIST in the model + flow_model modules BEFORE # instantiating MultiMechanismFlow, so the classifier and flow_heads sizes # match the checkpoint state_dict. if ckpt_args.get('mechanism_list') is not None: new_list = list(ckpt_args['mechanism_list']) _fm_module.MECHANISM_LIST = new_list _mm_module.MECHANISM_LIST = new_list print(f"Mechanism list overridden from checkpoint: {new_list} " f"({len(new_list)} mechanisms)") MECHANISM_LIST = _fm_module.MECHANISM_LIST model = MultiMechanismFlow( d_context=ckpt_args.get('d_context', 128), d_model=ckpt_args.get('d_model', 128), n_coupling_layers=ckpt_args.get('n_coupling_layers', 6), hidden_dim=ckpt_args.get('hidden_dim', 96), coupling_type=ckpt_args.get('coupling_type', 'spline'), n_bins=ckpt_args.get('n_bins', 8), tail_bound=ckpt_args.get('tail_bound', 5.0), aggregation=ckpt_args.get('aggregation', 'set_transformer'), use_summary_features=use_summary, ) ckpt_dir = Path(ckpt_path).parent.parent theta_stats_path = ckpt_dir / "theta_stats.json" with open(theta_stats_path) as f: theta_stats = json.load(f) for mech in MECHANISM_LIST: if mech in theta_stats: model.set_theta_stats( mech, torch.tensor(theta_stats[mech]['mean']), torch.tensor(theta_stats[mech]['std']), ) norm_stats_path = ckpt_dir / "norm_stats.json" with open(norm_stats_path) as f: norm_stats = json.load(f) _ckpt_sd = checkpoint['model_state_dict'] _model_sd = model.state_dict() _filtered = {k: v for k, v in _ckpt_sd.items() if k not in _model_sd or v.shape == _model_sd[k].shape} model.load_state_dict(_filtered, strict=False) for m in model.modules(): if hasattr(m, '_initialized') and not m.initialized: m.initialized = True model = model.to(device) model.eval() if args.data_dir: data_dir = os.path.expanduser(args.data_dir) else: data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/DiffEC/data')) # Support raw per-mechanism directory structure: # {data_dir}/{Mechanism}/{split}/sample_*.npz # as well as the assembled flat structure: # {data_dir}/{split}/sample_*.npz split_dir = os.path.join(data_dir, args.split) raw_per_mechanism = False if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")): # Try raw per-mechanism structure mech_dirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d, args.split))] if mech_dirs: raw_per_mechanism = True print(f"Detected raw per-mechanism directory structure in {data_dir}") print(f" Mechanisms found: {sorted(mech_dirs)}") # Create a temporary flat directory with symlinks import tempfile tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_") flat_dir = os.path.join(tmp_dir, args.split) os.makedirs(flat_dir, exist_ok=True) file_idx = 0 for mech_name in sorted(mech_dirs): mech_split = os.path.join(data_dir, mech_name, args.split) for f in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))): dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz") os.symlink(os.path.abspath(f), dst) file_idx += 1 split_dir = flat_dir print(f" Linked {file_idx} samples into temporary flat directory") print(f"Loading data from: {split_dir}") noise_aug = None if getattr(args, 'noise_augmentation', False): from noise_augmentation import ECNoiseAugmentation noise_aug = ECNoiseAugmentation() print("Noise augmentation ENABLED for recon eval (matches train-time noise distribution)") dataset = DiffECDataset( split_dir, max_samples=args.max_samples, normalize_input=True, compute_summary=use_summary, noise_augmentation=noise_aug, ) dataset.potential_mean = norm_stats['potential'][0] dataset.potential_std = norm_stats['potential'][1] dataset.flux_mean = norm_stats['flux'][0] dataset.flux_std = norm_stats['flux'][1] dataset.time_mean = norm_stats['time'][0] dataset.time_std = norm_stats['time'][1] raw_dataset = DiffECDataset(split_dir, max_samples=args.max_samples, normalize_input=False) loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn) suffix = getattr(args, 'output_suffix', '') eval_dir = ckpt_dir / f"eval_recon_{args.split}{suffix}" if raw_per_mechanism: eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}{suffix}" eval_dir.mkdir(exist_ok=True) per_mech_nrmse_mean = defaultdict(list) per_mech_nrmse_map = defaultdict(list) per_mech_nrmse_samples = defaultdict(list) per_mech_r2_mean = defaultdict(list) per_mech_r2_map = defaultdict(list) n_failed = 0 vis_count = defaultdict(int) print(f"Evaluating signal reconstruction on {len(dataset)} samples...") for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")): x = batch['input'].to(device) scan_mask = batch['scan_mask'].to(device) sigmas_tensor = batch['sigmas'].to(device) flux_scales = batch['flux_scales'].to(device) summary = batch['summary'].to(device) if 'summary' in batch else None mech_id = batch['mechanism_id'].item() if mech_id < 0 or mech_id >= len(MECHANISM_LIST): continue mech = MECHANISM_LIST[mech_id] if args.mechanisms is not None and mech not in args.mechanisms: continue raw_data = np.load(raw_dataset.sample_files[idx], allow_pickle=True) raw_params = raw_data['params'].item() raw_flux = raw_data['flux'].astype(np.float32) raw_potential = raw_data['potential'].astype(np.float32) if 'sigmas' in raw_data: scan_rates = raw_data['sigmas'].astype(np.float64) lengths = raw_data['lengths'].astype(int) else: scan_rates = np.array([raw_params.get('sigma', 1.0)]) lengths = np.array([len(raw_potential)]) raw_flux = raw_flux[np.newaxis, :] raw_potential = raw_potential[np.newaxis, :] with torch.no_grad(): pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor, flux_scales=flux_scales, n_samples=args.n_posterior_samples, temperature=args.temperature, summary=summary) if pred['stats'][mech] is None: continue theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy() samples = pred['samples'][mech][0].cpu().numpy() # MAP estimate via 1D KDE from scipy.stats import gaussian_kde theta_map = np.zeros_like(theta_mean) for d in range(len(theta_mean)): s = samples[:, d] if np.std(s) < 1e-10: theta_map[d] = np.mean(s) else: try: kde = gaussian_kde(s) grid = np.linspace(s.min(), s.max(), 200) theta_map[d] = grid[np.argmax(kde(grid))] except Exception: theta_map[d] = np.median(s) base_params = dict(raw_params) # Reconstruct from mean recon_mean = reconstruct_ec_signal(theta_mean, mech, base_params, scan_rates) # Reconstruct from MAP recon_map = reconstruct_ec_signal(theta_map, mech, base_params, scan_rates) # Compute metrics per scan rate nrmse_mean_list = [] nrmse_map_list = [] r2_mean_list = [] r2_map_list = [] for s_idx in range(len(scan_rates)): obs_flux = raw_flux[s_idx] length = lengths[s_idx] if recon_mean[s_idx]['success']: v = signal_nrmse(obs_flux, recon_mean[s_idx]['flux'], length) r = signal_r2(obs_flux, recon_mean[s_idx]['flux'], length) if np.isfinite(v): nrmse_mean_list.append(v) if np.isfinite(r): r2_mean_list.append(r) if recon_map[s_idx]['success']: v = signal_nrmse(obs_flux, recon_map[s_idx]['flux'], length) r = signal_r2(obs_flux, recon_map[s_idx]['flux'], length) if np.isfinite(v): nrmse_map_list.append(v) if np.isfinite(r): r2_map_list.append(r) if nrmse_mean_list: per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list)) if r2_mean_list: per_mech_r2_mean[mech].append(np.mean(r2_mean_list)) if nrmse_map_list: per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list)) if r2_map_list: per_mech_r2_map[mech].append(np.mean(r2_map_list)) # Reconstruct from random posterior samples sample_nrmses = [] n_recon = min(args.n_recon_samples, samples.shape[0]) sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False) for si in sample_indices: recon_s = reconstruct_ec_signal(samples[si], mech, base_params, scan_rates) nrmses = [] for s_idx in range(len(scan_rates)): if recon_s[s_idx]['success']: v = signal_nrmse(raw_flux[s_idx], recon_s[s_idx]['flux'], lengths[s_idx]) if np.isfinite(v): nrmses.append(v) if nrmses: sample_nrmses.append(np.mean(nrmses)) if sample_nrmses: per_mech_nrmse_samples[mech].append(np.median(sample_nrmses)) # Visualization if vis_count[mech] < args.n_visualize and recon_mean[0]['success']: fig, axes = plt.subplots(1, len(scan_rates), figsize=(5 * len(scan_rates), 4)) if len(scan_rates) == 1: axes = [axes] for s_idx, ax in enumerate(axes): length = lengths[s_idx] obs_pot = raw_potential[s_idx, :length] obs_flux_s = raw_flux[s_idx, :length] ax.plot(obs_pot, obs_flux_s, 'k-', lw=1.5, label='Observed', alpha=0.8) if recon_mean[s_idx]['success']: r_pot = recon_mean[s_idx]['potential'] r_flux = recon_mean[s_idx]['flux'] min_len = min(length, len(r_pot)) nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length] if length <= len(r_flux) else r_flux, length) lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)' ax.plot(r_pot[:min_len], r_flux[:min_len], 'r--', lw=1.2, label=lbl) if recon_map[s_idx]['success']: r_pot = recon_map[s_idx]['potential'] r_flux = recon_map[s_idx]['flux'] min_len = min(length, len(r_pot)) nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length] if length <= len(r_flux) else r_flux, length) lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)' ax.plot(r_pot[:min_len], r_flux[:min_len], 'b:', lw=1.2, label=lbl) # Plot a few posterior samples for si in sample_indices[:3]: recon_s = reconstruct_ec_signal(samples[si], mech, base_params, [scan_rates[s_idx]]) if recon_s[0]['success']: r_pot = recon_s[0]['potential'] r_flux = recon_s[0]['flux'] min_len = min(length, len(r_pot)) ax.plot(r_pot[:min_len], r_flux[:min_len], '-', lw=0.5, alpha=0.3, color='gray') ax.set_xlabel('Potential (θ)') ax.set_ylabel('Flux') ax.set_title(f'σ={scan_rates[s_idx]:.2f}') ax.legend(fontsize=7) fig.suptitle(f'{mech} sample {idx}', fontsize=12) plt.tight_layout() plt.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150) plt.close(fig) vis_count[mech] += 1 # Print and save results print("\n" + "=" * 60) print("SIGNAL RECONSTRUCTION METRICS") print("=" * 60) results = {} for mech in MECHANISM_LIST: if not per_mech_nrmse_mean[mech]: continue nrmse_mean = np.array(per_mech_nrmse_mean[mech]) nrmse_map = np.array(per_mech_nrmse_map[mech]) nrmse_samp = np.array(per_mech_nrmse_samples[mech]) if per_mech_nrmse_samples[mech] else np.array([]) r2_mean = np.array(per_mech_r2_mean[mech]) r2_map = np.array(per_mech_r2_map[mech]) results[mech] = { 'n_samples': len(nrmse_mean), 'nrmse_mean': {'median': float(np.median(nrmse_mean)), 'mean': float(np.mean(nrmse_mean)), 'std': float(np.std(nrmse_mean)), 'q25': float(np.percentile(nrmse_mean, 25)), 'q75': float(np.percentile(nrmse_mean, 75))}, 'nrmse_map': {'median': float(np.median(nrmse_map)), 'mean': float(np.mean(nrmse_map)), 'std': float(np.std(nrmse_map))}, 'r2_signal_mean': {'median': float(np.median(r2_mean)), 'mean': float(np.mean(r2_mean))}, 'r2_signal_map': {'median': float(np.median(r2_map)), 'mean': float(np.mean(r2_map))}, } if len(nrmse_samp) > 0: results[mech]['nrmse_posterior_median'] = { 'median': float(np.median(nrmse_samp)), 'mean': float(np.mean(nrmse_samp)), } print(f"\n{mech} ({len(nrmse_mean)} samples):") print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} " f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}") print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} " f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}") if len(nrmse_samp) > 0: print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} " f"mean={np.mean(nrmse_samp):.4f}") print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}") print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}") with open(eval_dir / "reconstruction_results.json", "w") as f: json.dump(results, f, indent=2) if raw_per_mechanism: import shutil shutil.rmtree(tmp_dir, ignore_errors=True) print(f"\nResults saved to {eval_dir}") print(f"Visualizations: {sum(vis_count.values())} plots saved") # ============================================================================= # TPD evaluation # ============================================================================= def evaluate_tpd(args): from tpd_model import MultiMechanismFlowTPD from generate_tpd_data import TPD_MECHANISM_PARAMS from dataset_tpd import TPDDataset, collate_fn from torch.utils.data import DataLoader ckpt_path = os.path.expanduser(args.checkpoint) checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False) ckpt_args = checkpoint['args'] device = 'cuda' if torch.cuda.is_available() else 'cpu' use_summary = ckpt_args.get('use_summary_features', False) active_mechs = _configure_tpd_mechanisms(ckpt_args.get('mechanism_list')) model = MultiMechanismFlowTPD( d_context=ckpt_args.get('d_context', 128), d_model=ckpt_args.get('d_model', 128), n_coupling_layers=ckpt_args.get('n_coupling_layers', 6), hidden_dim=ckpt_args.get('hidden_dim', 96), coupling_type=ckpt_args.get('coupling_type', 'spline'), n_bins=ckpt_args.get('n_bins', 8), tail_bound=ckpt_args.get('tail_bound', 5.0), use_summary_features=use_summary, mechanism_list=active_mechs, use_bounded_flow=ckpt_args.get('use_bounded_flow', False), ) ckpt_dir = Path(ckpt_path).parent.parent theta_stats_path = ckpt_dir / "theta_stats.json" with open(theta_stats_path) as f: theta_stats = json.load(f) for mech in active_mechs: if mech in theta_stats: model.set_theta_stats( mech, torch.tensor(theta_stats[mech]['mean']), torch.tensor(theta_stats[mech]['std']), ) norm_stats_path = ckpt_dir / "norm_stats.json" with open(norm_stats_path) as f: norm_stats = json.load(f) _ckpt_sd = checkpoint['model_state_dict'] _model_sd = model.state_dict() _filtered = {k: v for k, v in _ckpt_sd.items() if k not in _model_sd or v.shape == _model_sd[k].shape} model.load_state_dict(_filtered, strict=False) for m in model.modules(): if hasattr(m, '_initialized') and not m.initialized: m.initialized = True model = model.to(device) model.eval() if args.data_dir: data_dir = os.path.expanduser(args.data_dir) else: data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/ECFlow/data_tpd_multiheat')) split_dir = os.path.join(data_dir, args.split) raw_per_mechanism = False if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")): mech_dirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d, args.split))] if mech_dirs: raw_per_mechanism = True print(f"Detected raw per-mechanism directory structure in {data_dir}") print(f" Mechanisms found: {sorted(mech_dirs)}") import tempfile tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_tpd_") flat_dir = os.path.join(tmp_dir, args.split) os.makedirs(flat_dir, exist_ok=True) file_idx = 0 for mech_name in sorted(mech_dirs): mech_split = os.path.join(data_dir, mech_name, args.split) for f in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))): dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz") os.symlink(os.path.abspath(f), dst) file_idx += 1 split_dir = flat_dir print(f" Linked {file_idx} samples into temporary flat directory") print(f"Loading data from: {split_dir}") noise_aug = None if getattr(args, 'noise_augmentation', False): from noise_augmentation import TPDNoiseAugmentation noise_aug = TPDNoiseAugmentation() print("Noise augmentation ENABLED for recon eval (matches train-time noise distribution)") dataset = TPDDataset( split_dir, max_samples=args.max_samples, normalize_input=True, compute_summary=use_summary, noise_augmentation=noise_aug, ) dataset.temperature_mean = norm_stats['temperature'][0] dataset.temperature_std = norm_stats['temperature'][1] dataset.rate_mean = norm_stats['rate'][0] dataset.rate_std = norm_stats['rate'][1] raw_dataset = TPDDataset(split_dir, max_samples=args.max_samples, normalize_input=False) loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn) suffix = getattr(args, 'output_suffix', '') eval_dir = ckpt_dir / f"eval_recon_{args.split}{suffix}" if raw_per_mechanism: eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}{suffix}" eval_dir.mkdir(exist_ok=True) per_mech_nrmse_mean = defaultdict(list) per_mech_nrmse_map = defaultdict(list) per_mech_nrmse_samples = defaultdict(list) per_mech_r2_mean = defaultdict(list) per_mech_r2_map = defaultdict(list) vis_count = defaultdict(int) print(f"Evaluating TPD signal reconstruction on {len(dataset)} samples...") for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")): x = batch['input'].to(device) scan_mask = batch['scan_mask'].to(device) sigmas_tensor = batch['sigmas'].to(device) flux_scales = batch['flux_scales'].to(device) summary = batch['summary'].to(device) if 'summary' in batch else None mech_id = batch['mechanism_id'].item() if mech_id < 0 or mech_id >= len(active_mechs): continue mech = active_mechs[mech_id] if args.mechanisms is not None and mech not in args.mechanisms: continue raw_data = np.load(raw_dataset.sample_files[idx], allow_pickle=True) raw_params = raw_data['params'].item() raw_rate = raw_data['rate'].astype(np.float32) raw_temp = raw_data['temperature'].astype(np.float32) if 'heating_rates' in raw_data: betas = raw_data['heating_rates'].astype(np.float64) lengths = raw_data['lengths'].astype(int) else: betas = np.array([raw_params.get('beta', 1.0)]) lengths = np.array([len(raw_rate)]) raw_rate = raw_rate[np.newaxis, :] raw_temp = raw_temp[np.newaxis, :] with torch.no_grad(): pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor, flux_scales=flux_scales, n_samples=args.n_posterior_samples, temperature=args.temperature, summary=summary) if pred['stats'][mech] is None: continue theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy() samples = pred['samples'][mech][0].cpu().numpy() from scipy.stats import gaussian_kde theta_map = np.zeros_like(theta_mean) for d in range(len(theta_mean)): s = samples[:, d] if np.std(s) < 1e-10: theta_map[d] = np.mean(s) else: try: kde = gaussian_kde(s) grid = np.linspace(s.min(), s.max(), 200) theta_map[d] = grid[np.argmax(kde(grid))] except Exception: theta_map[d] = np.median(s) base_params = dict(raw_params) recon_mean = reconstruct_tpd_signal(theta_mean, mech, base_params, betas) recon_map = reconstruct_tpd_signal(theta_map, mech, base_params, betas) nrmse_mean_list = [] nrmse_map_list = [] r2_mean_list = [] r2_map_list = [] for s_idx in range(len(betas)): obs_rate = raw_rate[s_idx] length = lengths[s_idx] if recon_mean[s_idx]['success']: v = signal_nrmse(obs_rate, recon_mean[s_idx]['rate'], length) r = signal_r2(obs_rate, recon_mean[s_idx]['rate'], length) if np.isfinite(v): nrmse_mean_list.append(v) if np.isfinite(r): r2_mean_list.append(r) if recon_map[s_idx]['success']: v = signal_nrmse(obs_rate, recon_map[s_idx]['rate'], length) r = signal_r2(obs_rate, recon_map[s_idx]['rate'], length) if np.isfinite(v): nrmse_map_list.append(v) if np.isfinite(r): r2_map_list.append(r) if nrmse_mean_list: per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list)) if r2_mean_list: per_mech_r2_mean[mech].append(np.mean(r2_mean_list)) if nrmse_map_list: per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list)) if r2_map_list: per_mech_r2_map[mech].append(np.mean(r2_map_list)) sample_nrmses = [] n_recon = min(args.n_recon_samples, samples.shape[0]) sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False) for si in sample_indices: recon_s = reconstruct_tpd_signal(samples[si], mech, base_params, betas) nrmses = [] for s_idx in range(len(betas)): if recon_s[s_idx]['success']: v = signal_nrmse(raw_rate[s_idx], recon_s[s_idx]['rate'], lengths[s_idx]) if np.isfinite(v): nrmses.append(v) if nrmses: sample_nrmses.append(np.mean(nrmses)) if sample_nrmses: per_mech_nrmse_samples[mech].append(np.median(sample_nrmses)) # Visualization if vis_count[mech] < args.n_visualize and recon_mean[0]['success']: fig, axes = plt.subplots(1, len(betas), figsize=(5 * len(betas), 4)) if len(betas) == 1: axes = [axes] for s_idx, ax in enumerate(axes): length = lengths[s_idx] obs_temp = raw_temp[s_idx, :length] obs_rate_s = raw_rate[s_idx, :length] ax.plot(obs_temp, obs_rate_s, 'k-', lw=1.5, label='Observed', alpha=0.8) if recon_mean[s_idx]['success']: r_temp = recon_mean[s_idx]['temperature'] r_rate = recon_mean[s_idx]['rate'] min_len = min(length, len(r_temp)) nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length] if length <= len(r_rate) else r_rate, length) lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)' ax.plot(r_temp[:min_len], r_rate[:min_len], 'r--', lw=1.2, label=lbl) if recon_map[s_idx]['success']: r_temp = recon_map[s_idx]['temperature'] r_rate = recon_map[s_idx]['rate'] min_len = min(length, len(r_temp)) nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length] if length <= len(r_rate) else r_rate, length) lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)' ax.plot(r_temp[:min_len], r_rate[:min_len], 'b:', lw=1.2, label=lbl) for si in sample_indices[:3]: recon_s = reconstruct_tpd_signal(samples[si], mech, base_params, [betas[s_idx]]) if recon_s[0]['success']: r_temp = recon_s[0]['temperature'] r_rate = recon_s[0]['rate'] min_len = min(length, len(r_temp)) ax.plot(r_temp[:min_len], r_rate[:min_len], '-', lw=0.5, alpha=0.3, color='gray') ax.set_xlabel('Temperature (K)') ax.set_ylabel('Rate') ax.set_title(f'β={betas[s_idx]:.2f} K/s') ax.legend(fontsize=7) fig.suptitle(f'{mech} sample {idx}', fontsize=12) plt.tight_layout() plt.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150) plt.close(fig) vis_count[mech] += 1 print("\n" + "=" * 60) print("SIGNAL RECONSTRUCTION METRICS (TPD)") print("=" * 60) results = {} for mech in active_mechs: if not per_mech_nrmse_mean[mech]: continue nrmse_mean = np.array(per_mech_nrmse_mean[mech]) nrmse_map = np.array(per_mech_nrmse_map[mech]) nrmse_samp = np.array(per_mech_nrmse_samples[mech]) if per_mech_nrmse_samples[mech] else np.array([]) r2_mean = np.array(per_mech_r2_mean[mech]) r2_map = np.array(per_mech_r2_map[mech]) results[mech] = { 'n_samples': len(nrmse_mean), 'nrmse_mean': {'median': float(np.median(nrmse_mean)), 'mean': float(np.mean(nrmse_mean)), 'std': float(np.std(nrmse_mean))}, 'nrmse_map': {'median': float(np.median(nrmse_map)), 'mean': float(np.mean(nrmse_map)), 'std': float(np.std(nrmse_map))}, 'r2_signal_mean': {'median': float(np.median(r2_mean)), 'mean': float(np.mean(r2_mean))}, 'r2_signal_map': {'median': float(np.median(r2_map)), 'mean': float(np.mean(r2_map))}, } if len(nrmse_samp) > 0: results[mech]['nrmse_posterior_median'] = { 'median': float(np.median(nrmse_samp)), 'mean': float(np.mean(nrmse_samp)), } print(f"\n{mech} ({len(nrmse_mean)} samples):") print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} " f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}") print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} " f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}") if len(nrmse_samp) > 0: print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} " f"mean={np.mean(nrmse_samp):.4f}") print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}") print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}") with open(eval_dir / "reconstruction_results.json", "w") as f: json.dump(results, f, indent=2) if raw_per_mechanism: import shutil shutil.rmtree(tmp_dir, ignore_errors=True) print(f"\nResults saved to {eval_dir}") print(f"Visualizations: {sum(vis_count.values())} plots saved") if __name__ == "__main__": args = parse_args() if args.domain == "ec": evaluate_ec(args) else: evaluate_tpd(args)