# (Hosting-page header captured with this file; kept as comments so the module parses.)
# trace / evaluate_reconstruction.py — 43.3 kB
# Commit d1cefb1 (verified), bing-yan: "Initial TRACE deployment: rebrand from
# ECFlow + noise-aug headline checkpoints"
"""
Signal reconstruction evaluation for Multi-Mechanism Normalizing Flow.
For each test sample:
1. Infer mechanism + parameter posterior
2. Reconstruct signals from posterior mean, MAP, and random samples
3. Compare reconstructed signals with observed signals
This validates whether the inferred posteriors produce physically consistent
predictions, even when individual parameters have poor R² (due to compensation).
Usage:
python evaluate_reconstruction.py --checkpoint outputs/multi_mechanism_multiscan/.../best.pt
python evaluate_reconstruction.py --checkpoint outputs/tpd_multiheat/.../best.pt --domain tpd
"""
import os
import sys
import json
import glob
import signal
import argparse
import time as time_module
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class _Timeout:
    """Context manager that raises TimeoutError after `seconds` seconds.

    Arms a POSIX ``ITIMER_REAL`` interval timer (delivered as SIGALRM) when
    entered from the main thread. It falls back to a no-op in two cases:
    - worker threads (e.g. Gradio request handlers), where Python cannot
      install signal handlers;
    - platforms without SIGALRM / ``setitimer`` (e.g. Windows), where the
      previous implementation raised AttributeError.

    ``seconds`` may be fractional; ``signal.alarm`` only accepted whole
    seconds. A value of 0 arms no timer.
    """

    def __init__(self, seconds):
        self.seconds = seconds
        self._use_signal = False
        self._old = None

    def _handler(self, signum, frame):
        raise TimeoutError(f"Reconstruction timed out after {self.seconds}s")

    def __enter__(self):
        import threading
        self._use_signal = (
            hasattr(signal, 'SIGALRM')
            and hasattr(signal, 'setitimer')
            and threading.current_thread() is threading.main_thread()
        )
        if self._use_signal:
            self._old = signal.signal(signal.SIGALRM, self._handler)
            signal.setitimer(signal.ITIMER_REAL, self.seconds)
        return self

    def __exit__(self, *exc_info):
        if self._use_signal:
            # Disarm first so the timer cannot fire after the handler is swapped back.
            signal.setitimer(signal.ITIMER_REAL, 0)
            # signal.signal() returns None when the previous handler was not
            # installed from Python; passing None back would raise TypeError.
            previous = self._old if self._old is not None else signal.SIG_DFL
            signal.signal(signal.SIGALRM, previous)
            self._use_signal = False
        return False  # never suppress exceptions
def parse_args(argv=None):
    """Parse command-line options for the reconstruction evaluation.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list lets
            the parser be driven from tests or notebooks.

    Returns:
        argparse.Namespace with the options defined below.

    Note: ``--batch_size`` is accepted for CLI compatibility, but the
    evaluation loops in this script always iterate with batch size 1.
    """
    parser = argparse.ArgumentParser(description="Evaluate signal reconstruction")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--domain", type=str, default="ec", choices=["ec", "tpd"])
    parser.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Override data directory (e.g. for clean test set)")
    parser.add_argument("--max_samples", type=int, default=200)
    parser.add_argument("--n_posterior_samples", type=int, default=100,
                        help="Posterior samples for reconstruction")
    parser.add_argument("--n_recon_samples", type=int, default=10,
                        help="Number of random posterior samples to reconstruct per test sample")
    parser.add_argument("--n_visualize", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Base distribution temperature for sampling (>1 broadens posteriors)")
    parser.add_argument("--noise_augmentation", action="store_true",
                        help="Apply domain-appropriate noise augmentation to "
                             "the eval signals (matches train-time noise "
                             "distribution). Used for noise-robustness "
                             "ablations.")
    parser.add_argument("--mechanisms", type=str, nargs="*", default=None,
                        help="Restrict reconstruction to this subset of "
                             "mechanism names (default: all in checkpoint).")
    parser.add_argument("--output_suffix", type=str, default="",
                        help="Suffix for the eval output directory (e.g. "
                             "'_noisy'). Lets clean and noisy recon evals "
                             "coexist.")
    return parser.parse_args(argv)
# =============================================================================
# EC reconstruction
# =============================================================================
def _safe_pow10(val):
    """Return 10 raised to ``val`` without overflowing.

    The exponent is clipped to [-300, 300] first, so extreme inputs saturate
    (at 1e300 / 1e-300) instead of raising OverflowError. Scalars and NumPy
    arrays are both accepted (clip and power act elementwise).
    """
    lowest, highest = -300, 300
    return 10.0 ** np.clip(val, lowest, highest)
def _configure_tpd_mechanisms(mechanism_list):
    """Restore the active TPD mechanism ordering used by a checkpoint.

    Rebinds ``TPD_MECHANISM_LIST`` in ``generate_tpd_data``, ``tpd_model``
    and ``dataset_tpd`` to one shared list. It also rebuilds
    ``generate_tpd_data.TPD_MECHANISM_TO_ID`` (name -> index) to match.

    Args:
        mechanism_list: iterable of mechanism names, or None to keep the
            ordering currently defined in ``generate_tpd_data``.

    Returns:
        list: the active mechanism ordering (the same object now stored in
        all three modules).
    """
    import generate_tpd_data as _tpd_gen
    import tpd_model as _tpd_mod
    import dataset_tpd as _ds_tpd

    if mechanism_list is None:
        ordering = list(_tpd_gen.TPD_MECHANISM_LIST)
    else:
        ordering = list(mechanism_list)

    _tpd_gen.TPD_MECHANISM_LIST = ordering
    _tpd_gen.TPD_MECHANISM_TO_ID = dict(zip(ordering, range(len(ordering))))
    for module in (_tpd_mod, _ds_tpd):
        module.TPD_MECHANISM_LIST = ordering
    return ordering
def theta_to_ec_params(theta, mechanism, base_params):
    """
    Map an inferred EC parameter vector back onto simulator parameters.

    Args:
        theta: [D] numpy array of inferred parameters (model output space),
            ordered like ``MECHANISM_PARAMS[mechanism]['names']``
        mechanism: str mechanism name
        base_params: dict of fixed parameters from the original sample;
            it is copied, not mutated

    Returns:
        dict: a copy of ``base_params`` with 'kinetics' set to ``mechanism``,
        and each recognised parameter overwritten as follows:
        - ``log10(X)`` names are stored as ``X = 10**value`` (exponent
          clipped to [-300, 300]);
        - alpha / alpha_1 / alpha_2 are clipped to [0.01, 0.99];
        - E0_offset / E0_2_offset are copied unchanged.

    NOTE(review): any other parameter name is silently ignored, so the value
    already present in ``base_params`` (if any) is kept for it.
    """
    from flow_model import MECHANISM_PARAMS

    log_prefix = 'log10('
    clipped_names = ('alpha', 'alpha_1', 'alpha_2')
    passthrough_names = ('E0_offset', 'E0_2_offset')

    params = dict(base_params)
    params['kinetics'] = mechanism
    for i, name in enumerate(MECHANISM_PARAMS[mechanism]['names']):
        value = float(theta[i])
        if name.startswith(log_prefix) and name.endswith(')'):
            params[name[len(log_prefix):-1]] = _safe_pow10(value)
        elif name in clipped_names:
            params[name] = float(np.clip(value, 0.01, 0.99))
        elif name in passthrough_names:
            params[name] = value
    return params
def reconstruct_ec_signal(theta, mechanism, base_params, sigmas, n_spatial=64):
    """
    Reconstruct CV signal(s) from inferred parameters.

    Rate constants recovered by ``theta_to_ec_params`` are treated as values
    at sigma = 1 and rescaled for every scan rate before simulating:
    - K0 by 1/sqrt(sigma) for BV, MHC, EC, LH, EC_prime and CE;
    - K0 by 1/sigma for Ads;
    - K0_1 / K0_2 by 1/sqrt(sigma) for EE;
    - kc by 1/sigma for EC and EC_prime;
    - kf by 1/sigma for CE.
    Any other mechanism is simulated with its rates unchanged.
    NOTE(review): confirm that ECE / EC_LH / MHC_EC / MHC_LH indeed need no
    scan-rate scaling.

    Args:
        theta: [D] inferred parameters (model output space)
        mechanism: str
        base_params: dict with fixed params (theta_i, theta_v, dA, etc.)
        sigmas: iterable of scan rates
        n_spatial: spatial grid points

    Returns:
        list with one dict per scan rate, in input order. A successful
        entry holds 'potential', 'flux', 'time' and 'success'=True. When the
        simulator also returns 'c_ox' / 'c_red', it adds 'c_ox_surface' /
        'c_red_surface'. A simulation that raises (including the 30 s
        timeout) yields {'success': False, 'error': <message>}.
    """
    import inspect
    import warnings
    from generate_dataset_diffec import _run_single_cv

    # Extended mechanisms have dedicated simulators. Resolve the simulator and
    # the keyword arguments it accepts once, instead of once per scan rate.
    ext_sim = None
    ext_arg_names = frozenset()
    if mechanism in ('EE', 'EC_prime', 'CE', 'ECE', 'EC_LH', 'MHC_EC', 'MHC_LH'):
        from generate_extended_mechanisms import SIMULATORS as _EXT_SIMULATORS
        if mechanism in _EXT_SIMULATORS:
            ext_sim = _EXT_SIMULATORS[mechanism]
            ext_arg_names = frozenset(inspect.signature(ext_sim).parameters)

    phys = theta_to_ec_params(theta, mechanism, base_params)
    # Inferred rate constants, referenced to sigma = 1.
    K0_at_1 = phys.get('K0', 1.0)
    K0_1_at_1 = phys.get('K0_1', 1.0)
    K0_2_at_1 = phys.get('K0_2', 1.0)
    kc_at_1 = phys.get('kc', 1.0)
    kf_at_1 = phys.get('kf', 1.0)

    results = []
    for sigma in sigmas:
        p = dict(phys)
        p['sigma'] = float(sigma)
        if mechanism in ('BV', 'MHC', 'EC', 'LH', 'EC_prime', 'CE'):
            p['K0'] = K0_at_1 / np.sqrt(sigma)
        elif mechanism == 'Ads':
            p['K0'] = K0_at_1 / sigma
        elif mechanism == 'EE':
            p['K0_1'] = K0_1_at_1 / np.sqrt(sigma)
            p['K0_2'] = K0_2_at_1 / np.sqrt(sigma)
        if mechanism in ('EC', 'EC_prime'):
            p['kc'] = kc_at_1 / sigma
        if mechanism == 'CE':
            p['kf'] = kf_at_1 / sigma
        if mechanism in ('EE', 'EC_prime', 'CE'):
            # Fill dA / dB with 1.0 when the sample does not provide them.
            p.setdefault('dA', 1.0)
            p.setdefault('dB', 1.0)

        try:
            with _Timeout(30), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                if ext_sim is not None:
                    kwargs = {k: v for k, v in p.items() if k in ext_arg_names}
                    kwargs['n_spatial_out'] = n_spatial
                    result = ext_sim(**kwargs)
                else:
                    result = _run_single_cv(p, n_spatial)
                entry = {
                    'potential': result['potential'],
                    'flux': result['flux'],
                    'time': result['time'],
                    'success': True,
                }
                if 'c_ox' in result and 'c_red' in result:
                    # NOTE(review): the surface is taken as the last column for
                    # c_ox but the first column for c_red; confirm that the
                    # simulator's spatial grids are oriented this way.
                    entry['c_ox_surface'] = result['c_ox'][:, -1]
                    entry['c_red_surface'] = result['c_red'][:, 0]
                results.append(entry)
        except Exception as e:
            # Best effort: record the failure and keep going with other scan rates.
            results.append({'success': False, 'error': str(e)})
    return results
# =============================================================================
# TPD reconstruction
# =============================================================================
def theta_to_tpd_params(theta, mechanism, base_params):
    """Map an inferred TPD parameter vector onto simulator parameters.

    Returns a copy of ``base_params`` with 'mechanism' set to ``mechanism``.
    Every name in ``TPD_MECHANISM_PARAMS[mechanism]['names']`` is overwritten
    from ``theta`` (same order):
    - ``log10(X)`` names are stored as ``X = 10**value``, with the exponent
      clipped to [-300, 300];
    - all other names are stored verbatim as floats.
    """
    from generate_tpd_data import TPD_MECHANISM_PARAMS

    log_prefix = 'log10('
    out = dict(base_params)
    out['mechanism'] = mechanism
    for i, name in enumerate(TPD_MECHANISM_PARAMS[mechanism]['names']):
        value = float(theta[i])
        is_log = name.startswith(log_prefix) and name.endswith(')')
        if is_log:
            out[name[len(log_prefix):-1]] = _safe_pow10(value)
        else:
            out[name] = value
    return out
def reconstruct_tpd_signal(theta, mechanism, base_params, betas):
    """
    Reconstruct TPD signal(s) from inferred parameters.

    Args:
        theta: [D] inferred parameters
        mechanism: str
        base_params: dict with T_start, T_end, etc.
        betas: iterable of heating rates

    Returns:
        list with one dict per heating rate, in input order. A successful
        entry holds 'temperature', 'rate', 'time' and 'success'=True. A
        simulation that raises (including the 30 s timeout) yields
        {'success': False, 'error': <message>}.
    """
    import warnings
    from generate_tpd_data import _run_single_tpd

    phys = theta_to_tpd_params(theta, mechanism, base_params)

    def _simulate(beta):
        params = {**phys, 'beta': float(beta)}
        try:
            with _Timeout(30), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                out = _run_single_tpd(params)
                return {
                    'temperature': out['temperature'],
                    'rate': out['rate'],
                    'time': out['time'],
                    'success': True,
                }
        except Exception as e:
            return {'success': False, 'error': str(e)}

    return [_simulate(beta) for beta in betas]
# =============================================================================
# Metrics
# =============================================================================
def _valid_signal(arr):
    """Return True if ``arr`` is non-empty, all-finite and free of extreme values.

    Empty arrays count as invalid. ``np.max`` raises on zero-size input, so
    without this check a zero-length overlap crashed the caller instead of
    producing a NaN metric.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return False
    if not np.all(np.isfinite(arr)):
        return False
    if np.max(np.abs(arr)) > 1e30:
        return False
    return True


def _overlap(observed, reconstructed, length=None):
    """Truncate both signals to their common prefix as float64 arrays.

    The prefix is the shorter of the two lengths, further capped at
    ``length`` when given (negative caps are treated as 0).
    """
    n = min(len(observed), len(reconstructed))
    if length is not None:
        n = min(n, int(length))
    n = max(n, 0)
    o = np.asarray(observed[:n], dtype=np.float64)
    r = np.asarray(reconstructed[:n], dtype=np.float64)
    return o, r


def signal_rmse(observed, reconstructed, length=None):
    """RMSE between two signals over their common (optionally capped) prefix.

    Returns NaN when either truncated signal is empty, non-finite or extreme.
    """
    o, r = _overlap(observed, reconstructed, length)
    if not (_valid_signal(o) and _valid_signal(r)):
        return float('nan')
    return float(np.sqrt(np.mean((o - r) ** 2)))


def signal_nrmse(observed, reconstructed, length=None):
    """RMSE normalised by the peak-to-peak range of the observed signal.

    Returns NaN for invalid or empty overlaps. Returns inf when the observed
    signal is flat (range < 1e-20), so callers filtering on ``np.isfinite``
    skip it.
    """
    o, r = _overlap(observed, reconstructed, length)
    if not (_valid_signal(o) and _valid_signal(r)):
        return float('nan')
    ptp = np.ptp(o)
    if ptp < 1e-20:
        return float('inf')
    return float(np.sqrt(np.mean((o - r) ** 2)) / ptp)


def signal_r2(observed, reconstructed, length=None):
    """Coefficient of determination R² of ``reconstructed`` against ``observed``.

    Returns NaN for invalid or empty overlaps, and 0.0 when the observed
    signal has (near-)zero variance.
    """
    o, r = _overlap(observed, reconstructed, length)
    if not (_valid_signal(o) and _valid_signal(r)):
        return float('nan')
    ss_res = np.sum((o - r) ** 2)
    ss_tot = np.sum((o - np.mean(o)) ** 2)
    if ss_tot < 1e-20:
        return 0.0
    return float(1 - ss_res / ss_tot)
# =============================================================================
# EC evaluation
# =============================================================================
def evaluate_ec(args):
    """
    Evaluate signal reconstruction for an EC (cyclic voltammetry) checkpoint.

    For each sample of ``args.split``:
    1. sample the model posterior;
    2. re-simulate CVs from the posterior mean, from a per-dimension KDE MAP
       estimate, and from ``args.n_recon_samples`` random posterior draws;
    3. compare the simulated flux with the observed flux (NRMSE and R²,
       averaged over scan rates).

    Per-mechanism summaries are printed and written to
    ``<ckpt_dir>/eval_recon_[clean_]<split><suffix>/reconstruction_results.json``,
    together with up to ``args.n_visualize`` plots per mechanism. Here
    ``ckpt_dir`` is the grandparent directory of the checkpoint file, which
    must also contain ``theta_stats.json`` and ``norm_stats.json``.

    A raw ``{data_dir}/{Mechanism}/{split}/sample_*.npz`` layout is flattened
    into a temporary directory of symlinks. That directory is removed on
    exit, also when evaluation raises.

    Args:
        args: argparse.Namespace as produced by ``parse_args``.
    """
    import shutil
    import tempfile

    import multi_mechanism_model as _mm_module
    import flow_model as _fm_module
    from multi_mechanism_model import MultiMechanismFlow
    from dataset import DiffECDataset, collate_fn
    from scipy.stats import gaussian_kde
    from torch.utils.data import DataLoader

    ckpt_path = os.path.expanduser(args.checkpoint)
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    ckpt_args = checkpoint['args']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_summary = ckpt_args.get('use_summary_features', False)

    # If the checkpoint was trained on a custom mechanism subset (e.g. v14_9mech),
    # patch the global MECHANISM_LIST in the model + flow_model modules BEFORE
    # instantiating MultiMechanismFlow, so the classifier and flow_heads sizes
    # match the checkpoint state_dict.
    if ckpt_args.get('mechanism_list') is not None:
        new_list = list(ckpt_args['mechanism_list'])
        _fm_module.MECHANISM_LIST = new_list
        _mm_module.MECHANISM_LIST = new_list
        print(f"Mechanism list overridden from checkpoint: {new_list} "
              f"({len(new_list)} mechanisms)")
    mechanism_list = _fm_module.MECHANISM_LIST

    model = MultiMechanismFlow(
        d_context=ckpt_args.get('d_context', 128),
        d_model=ckpt_args.get('d_model', 128),
        n_coupling_layers=ckpt_args.get('n_coupling_layers', 6),
        hidden_dim=ckpt_args.get('hidden_dim', 96),
        coupling_type=ckpt_args.get('coupling_type', 'spline'),
        n_bins=ckpt_args.get('n_bins', 8),
        tail_bound=ckpt_args.get('tail_bound', 5.0),
        aggregation=ckpt_args.get('aggregation', 'set_transformer'),
        use_summary_features=use_summary,
    )

    ckpt_dir = Path(ckpt_path).parent.parent
    with open(ckpt_dir / "theta_stats.json") as f:
        theta_stats = json.load(f)
    for mech in mechanism_list:
        if mech in theta_stats:
            model.set_theta_stats(
                mech,
                torch.tensor(theta_stats[mech]['mean']),
                torch.tensor(theta_stats[mech]['std']),
            )
    with open(ckpt_dir / "norm_stats.json") as f:
        norm_stats = json.load(f)

    # Load only tensors whose shapes match the freshly built model, and report
    # the ones that are dropped instead of skipping them silently.
    ckpt_sd = checkpoint['model_state_dict']
    model_sd = model.state_dict()
    mismatched = {k for k, v in ckpt_sd.items()
                  if k in model_sd and v.shape != model_sd[k].shape}
    if mismatched:
        print(f"WARNING: skipping {len(mismatched)} checkpoint tensor(s) with "
              f"mismatched shapes: {sorted(mismatched)}")
    model.load_state_dict(
        {k: v for k, v in ckpt_sd.items() if k not in mismatched}, strict=False)
    # Force `initialized` on modules exposing an `_initialized` flag, presumably
    # so data-dependent initialisation does not overwrite the loaded weights on
    # the first forward pass — TODO confirm against the module definitions.
    for m in model.modules():
        if hasattr(m, '_initialized') and not m.initialized:
            m.initialized = True
    model = model.to(device)
    model.eval()

    if args.data_dir:
        data_dir = os.path.expanduser(args.data_dir)
    else:
        data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/DiffEC/data'))

    tmp_dir = None
    try:
        # Support raw per-mechanism directory structure:
        #   {data_dir}/{Mechanism}/{split}/sample_*.npz
        # as well as the assembled flat structure:
        #   {data_dir}/{split}/sample_*.npz
        split_dir = os.path.join(data_dir, args.split)
        raw_per_mechanism = False
        if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")):
            mech_dirs = [d for d in os.listdir(data_dir)
                         if os.path.isdir(os.path.join(data_dir, d, args.split))]
            if mech_dirs:
                raw_per_mechanism = True
                print(f"Detected raw per-mechanism directory structure in {data_dir}")
                print(f" Mechanisms found: {sorted(mech_dirs)}")
                # Temporary flat directory of symlinks; removed in `finally`.
                tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_")
                flat_dir = os.path.join(tmp_dir, args.split)
                os.makedirs(flat_dir, exist_ok=True)
                file_idx = 0
                for mech_name in sorted(mech_dirs):
                    mech_split = os.path.join(data_dir, mech_name, args.split)
                    for src in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))):
                        dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz")
                        os.symlink(os.path.abspath(src), dst)
                        file_idx += 1
                split_dir = flat_dir
                print(f" Linked {file_idx} samples into temporary flat directory")
        print(f"Loading data from: {split_dir}")

        noise_aug = None
        if getattr(args, 'noise_augmentation', False):
            from noise_augmentation import ECNoiseAugmentation
            noise_aug = ECNoiseAugmentation()
            print("Noise augmentation ENABLED for recon eval (matches train-time noise distribution)")

        dataset = DiffECDataset(
            split_dir,
            max_samples=args.max_samples,
            normalize_input=True,
            compute_summary=use_summary,
            noise_augmentation=noise_aug,
        )
        # Normalise with the statistics stored next to the checkpoint.
        dataset.potential_mean = norm_stats['potential'][0]
        dataset.potential_std = norm_stats['potential'][1]
        dataset.flux_mean = norm_stats['flux'][0]
        dataset.flux_std = norm_stats['flux'][1]
        dataset.time_mean = norm_stats['time'][0]
        dataset.time_std = norm_stats['time'][1]
        # Un-normalised twin; only its file list is used, to read raw samples by index.
        raw_dataset = DiffECDataset(split_dir, max_samples=args.max_samples, normalize_input=False)
        loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

        suffix = getattr(args, 'output_suffix', '')
        eval_dir = ckpt_dir / f"eval_recon_{args.split}{suffix}"
        if raw_per_mechanism:
            eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}{suffix}"
        eval_dir.mkdir(exist_ok=True)

        per_mech_nrmse_mean = defaultdict(list)
        per_mech_nrmse_map = defaultdict(list)
        per_mech_nrmse_samples = defaultdict(list)
        per_mech_r2_mean = defaultdict(list)
        per_mech_r2_map = defaultdict(list)
        n_failed = 0  # samples whose mean-estimate reconstruction failed at every scan rate
        vis_count = defaultdict(int)

        print(f"Evaluating signal reconstruction on {len(dataset)} samples...")
        for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")):
            x = batch['input'].to(device)
            scan_mask = batch['scan_mask'].to(device)
            sigmas_tensor = batch['sigmas'].to(device)
            flux_scales = batch['flux_scales'].to(device)
            summary = batch['summary'].to(device) if 'summary' in batch else None
            mech_id = batch['mechanism_id'].item()
            if mech_id < 0 or mech_id >= len(mechanism_list):
                continue
            mech = mechanism_list[mech_id]
            if args.mechanisms is not None and mech not in args.mechanisms:
                continue

            # Close the NpzFile handle as soon as the arrays are extracted.
            with np.load(raw_dataset.sample_files[idx], allow_pickle=True) as raw_data:
                raw_params = raw_data['params'].item()
                raw_flux = raw_data['flux'].astype(np.float32)
                raw_potential = raw_data['potential'].astype(np.float32)
                multi_scan = 'sigmas' in raw_data
                if multi_scan:
                    scan_rates = raw_data['sigmas'].astype(np.float64)
                    lengths = raw_data['lengths'].astype(int)
            if not multi_scan:
                # Single-scan file: promote to the [n_scans, T] layout.
                scan_rates = np.array([raw_params.get('sigma', 1.0)])
                lengths = np.array([len(raw_potential)])
                raw_flux = raw_flux[np.newaxis, :]
                raw_potential = raw_potential[np.newaxis, :]

            with torch.no_grad():
                pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor,
                                     flux_scales=flux_scales,
                                     n_samples=args.n_posterior_samples,
                                     temperature=args.temperature,
                                     summary=summary)
            if pred['stats'][mech] is None:
                continue
            theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy()
            samples = pred['samples'][mech][0].cpu().numpy()

            # MAP estimate via independent 1D KDEs (mode of each marginal).
            theta_map = np.zeros_like(theta_mean)
            for d in range(len(theta_mean)):
                s = samples[:, d]
                if np.std(s) < 1e-10:
                    theta_map[d] = np.mean(s)
                else:
                    try:
                        kde = gaussian_kde(s)
                        grid = np.linspace(s.min(), s.max(), 200)
                        theta_map[d] = grid[np.argmax(kde(grid))]
                    except Exception:
                        # KDE failed; fall back to the marginal median.
                        theta_map[d] = np.median(s)

            base_params = dict(raw_params)
            recon_mean = reconstruct_ec_signal(theta_mean, mech, base_params, scan_rates)
            recon_map = reconstruct_ec_signal(theta_map, mech, base_params, scan_rates)
            if not any(r['success'] for r in recon_mean):
                n_failed += 1

            # Metrics per scan rate, then averaged over scan rates.
            nrmse_mean_list = []
            nrmse_map_list = []
            r2_mean_list = []
            r2_map_list = []
            for s_idx in range(len(scan_rates)):
                obs_flux = raw_flux[s_idx]
                length = lengths[s_idx]
                if recon_mean[s_idx]['success']:
                    v = signal_nrmse(obs_flux, recon_mean[s_idx]['flux'], length)
                    r = signal_r2(obs_flux, recon_mean[s_idx]['flux'], length)
                    if np.isfinite(v):
                        nrmse_mean_list.append(v)
                    if np.isfinite(r):
                        r2_mean_list.append(r)
                if recon_map[s_idx]['success']:
                    v = signal_nrmse(obs_flux, recon_map[s_idx]['flux'], length)
                    r = signal_r2(obs_flux, recon_map[s_idx]['flux'], length)
                    if np.isfinite(v):
                        nrmse_map_list.append(v)
                    if np.isfinite(r):
                        r2_map_list.append(r)
            if nrmse_mean_list:
                per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list))
            if r2_mean_list:
                per_mech_r2_mean[mech].append(np.mean(r2_mean_list))
            if nrmse_map_list:
                per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list))
            if r2_map_list:
                per_mech_r2_map[mech].append(np.mean(r2_map_list))

            # Reconstruct from random posterior samples. The reconstructions are
            # kept so the plots below can reuse them instead of re-simulating.
            sample_nrmses = []
            sample_recons = []
            n_recon = min(args.n_recon_samples, samples.shape[0])
            sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False)
            for si in sample_indices:
                recon_s = reconstruct_ec_signal(samples[si], mech, base_params, scan_rates)
                sample_recons.append(recon_s)
                nrmses = []
                for s_idx in range(len(scan_rates)):
                    if recon_s[s_idx]['success']:
                        v = signal_nrmse(raw_flux[s_idx], recon_s[s_idx]['flux'], lengths[s_idx])
                        if np.isfinite(v):
                            nrmses.append(v)
                if nrmses:
                    sample_nrmses.append(np.mean(nrmses))
            if sample_nrmses:
                per_mech_nrmse_samples[mech].append(np.median(sample_nrmses))

            # Visualization
            if vis_count[mech] < args.n_visualize and recon_mean and recon_mean[0]['success']:
                fig, axes = plt.subplots(1, len(scan_rates), figsize=(5 * len(scan_rates), 4))
                if len(scan_rates) == 1:
                    axes = [axes]
                for s_idx, ax in enumerate(axes):
                    length = lengths[s_idx]
                    obs_pot = raw_potential[s_idx, :length]
                    obs_flux_s = raw_flux[s_idx, :length]
                    ax.plot(obs_pot, obs_flux_s, 'k-', lw=1.5, label='Observed', alpha=0.8)
                    if recon_mean[s_idx]['success']:
                        r_pot = recon_mean[s_idx]['potential']
                        r_flux = recon_mean[s_idx]['flux']
                        min_len = min(length, len(r_pot))
                        nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length], length)
                        lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)'
                        ax.plot(r_pot[:min_len], r_flux[:min_len], 'r--', lw=1.2, label=lbl)
                    if recon_map[s_idx]['success']:
                        r_pot = recon_map[s_idx]['potential']
                        r_flux = recon_map[s_idx]['flux']
                        min_len = min(length, len(r_pot))
                        nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length], length)
                        lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)'
                        ax.plot(r_pot[:min_len], r_flux[:min_len], 'b:', lw=1.2, label=lbl)
                    # Plot a few posterior samples (already simulated above).
                    for recon_s in sample_recons[:3]:
                        if recon_s[s_idx]['success']:
                            r_pot = recon_s[s_idx]['potential']
                            r_flux = recon_s[s_idx]['flux']
                            min_len = min(length, len(r_pot))
                            ax.plot(r_pot[:min_len], r_flux[:min_len], '-', lw=0.5,
                                    alpha=0.3, color='gray')
                    ax.set_xlabel('Potential (θ)')
                    ax.set_ylabel('Flux')
                    ax.set_title(f'σ={scan_rates[s_idx]:.2f}')
                    ax.legend(fontsize=7)
                fig.suptitle(f'{mech} sample {idx}', fontsize=12)
                fig.tight_layout()
                fig.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150)
                plt.close(fig)
                vis_count[mech] += 1

        # Print and save results
        print("\n" + "=" * 60)
        print("SIGNAL RECONSTRUCTION METRICS")
        print("=" * 60)
        results = {}
        for mech in mechanism_list:
            if not per_mech_nrmse_mean[mech]:
                continue
            nrmse_mean = np.array(per_mech_nrmse_mean[mech])
            nrmse_map = np.array(per_mech_nrmse_map[mech])
            nrmse_samp = np.array(per_mech_nrmse_samples[mech])
            r2_mean = np.array(per_mech_r2_mean[mech])
            r2_map = np.array(per_mech_r2_map[mech])
            results[mech] = {
                'n_samples': len(nrmse_mean),
                'nrmse_mean': {'median': float(np.median(nrmse_mean)),
                               'mean': float(np.mean(nrmse_mean)),
                               'std': float(np.std(nrmse_mean)),
                               'q25': float(np.percentile(nrmse_mean, 25)),
                               'q75': float(np.percentile(nrmse_mean, 75))},
                'nrmse_map': {'median': float(np.median(nrmse_map)),
                              'mean': float(np.mean(nrmse_map)),
                              'std': float(np.std(nrmse_map))},
                'r2_signal_mean': {'median': float(np.median(r2_mean)),
                                   'mean': float(np.mean(r2_mean))},
                'r2_signal_map': {'median': float(np.median(r2_map)),
                                  'mean': float(np.mean(r2_map))},
            }
            if len(nrmse_samp) > 0:
                results[mech]['nrmse_posterior_median'] = {
                    'median': float(np.median(nrmse_samp)),
                    'mean': float(np.mean(nrmse_samp)),
                }
            print(f"\n{mech} ({len(nrmse_mean)} samples):")
            print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} "
                  f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}")
            print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} "
                  f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}")
            if len(nrmse_samp) > 0:
                print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} "
                      f"mean={np.mean(nrmse_samp):.4f}")
            print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}")
            print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}")

        with open(eval_dir / "reconstruction_results.json", "w") as f:
            json.dump(results, f, indent=2)
        if n_failed:
            print(f"\n{n_failed} sample(s): mean-estimate reconstruction failed at every scan rate")
        print(f"\nResults saved to {eval_dir}")
        print(f"Visualizations: {sum(vis_count.values())} plots saved")
    finally:
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)
# =============================================================================
# TPD evaluation
# =============================================================================
def evaluate_tpd(args):
    """
    Evaluate signal reconstruction for a TPD checkpoint.

    For each sample of ``args.split``:
    1. sample the model posterior;
    2. re-simulate TPD rate curves from the posterior mean, from a
       per-dimension KDE MAP estimate, and from ``args.n_recon_samples``
       random posterior draws;
    3. compare the simulated rate with the observed rate (NRMSE and R²,
       averaged over heating rates).

    Per-mechanism summaries are printed and written to
    ``<ckpt_dir>/eval_recon_[clean_]<split><suffix>/reconstruction_results.json``,
    together with up to ``args.n_visualize`` plots per mechanism. Here
    ``ckpt_dir`` is the grandparent directory of the checkpoint file, which
    must also contain ``theta_stats.json`` and ``norm_stats.json``.

    A raw ``{data_dir}/{Mechanism}/{split}/sample_*.npz`` layout is flattened
    into a temporary directory of symlinks. That directory is removed on
    exit, also when evaluation raises.

    Args:
        args: argparse.Namespace as produced by ``parse_args``.
    """
    import shutil
    import tempfile

    from tpd_model import MultiMechanismFlowTPD
    from dataset_tpd import TPDDataset, collate_fn
    from scipy.stats import gaussian_kde
    from torch.utils.data import DataLoader

    ckpt_path = os.path.expanduser(args.checkpoint)
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    ckpt_args = checkpoint['args']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_summary = ckpt_args.get('use_summary_features', False)

    # Align the module-level TPD mechanism lists with the checkpoint ordering
    # before building the model.
    active_mechs = _configure_tpd_mechanisms(ckpt_args.get('mechanism_list'))
    model = MultiMechanismFlowTPD(
        d_context=ckpt_args.get('d_context', 128),
        d_model=ckpt_args.get('d_model', 128),
        n_coupling_layers=ckpt_args.get('n_coupling_layers', 6),
        hidden_dim=ckpt_args.get('hidden_dim', 96),
        coupling_type=ckpt_args.get('coupling_type', 'spline'),
        n_bins=ckpt_args.get('n_bins', 8),
        tail_bound=ckpt_args.get('tail_bound', 5.0),
        use_summary_features=use_summary,
        mechanism_list=active_mechs,
        use_bounded_flow=ckpt_args.get('use_bounded_flow', False),
    )

    ckpt_dir = Path(ckpt_path).parent.parent
    with open(ckpt_dir / "theta_stats.json") as f:
        theta_stats = json.load(f)
    for mech in active_mechs:
        if mech in theta_stats:
            model.set_theta_stats(
                mech,
                torch.tensor(theta_stats[mech]['mean']),
                torch.tensor(theta_stats[mech]['std']),
            )
    with open(ckpt_dir / "norm_stats.json") as f:
        norm_stats = json.load(f)

    # Load only tensors whose shapes match the freshly built model, and report
    # the ones that are dropped instead of skipping them silently.
    ckpt_sd = checkpoint['model_state_dict']
    model_sd = model.state_dict()
    mismatched = {k for k, v in ckpt_sd.items()
                  if k in model_sd and v.shape != model_sd[k].shape}
    if mismatched:
        print(f"WARNING: skipping {len(mismatched)} checkpoint tensor(s) with "
              f"mismatched shapes: {sorted(mismatched)}")
    model.load_state_dict(
        {k: v for k, v in ckpt_sd.items() if k not in mismatched}, strict=False)
    # Force `initialized` on modules exposing an `_initialized` flag, presumably
    # so data-dependent initialisation does not overwrite the loaded weights on
    # the first forward pass — TODO confirm against the module definitions.
    for m in model.modules():
        if hasattr(m, '_initialized') and not m.initialized:
            m.initialized = True
    model = model.to(device)
    model.eval()

    if args.data_dir:
        data_dir = os.path.expanduser(args.data_dir)
    else:
        data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/ECFlow/data_tpd_multiheat'))

    tmp_dir = None
    try:
        # Support both {data_dir}/{Mechanism}/{split}/sample_*.npz and the flat
        # {data_dir}/{split}/sample_*.npz layout.
        split_dir = os.path.join(data_dir, args.split)
        raw_per_mechanism = False
        if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")):
            mech_dirs = [d for d in os.listdir(data_dir)
                         if os.path.isdir(os.path.join(data_dir, d, args.split))]
            if mech_dirs:
                raw_per_mechanism = True
                print(f"Detected raw per-mechanism directory structure in {data_dir}")
                print(f" Mechanisms found: {sorted(mech_dirs)}")
                # Temporary flat directory of symlinks; removed in `finally`.
                tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_tpd_")
                flat_dir = os.path.join(tmp_dir, args.split)
                os.makedirs(flat_dir, exist_ok=True)
                file_idx = 0
                for mech_name in sorted(mech_dirs):
                    mech_split = os.path.join(data_dir, mech_name, args.split)
                    for src in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))):
                        dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz")
                        os.symlink(os.path.abspath(src), dst)
                        file_idx += 1
                split_dir = flat_dir
                print(f" Linked {file_idx} samples into temporary flat directory")
        print(f"Loading data from: {split_dir}")

        noise_aug = None
        if getattr(args, 'noise_augmentation', False):
            from noise_augmentation import TPDNoiseAugmentation
            noise_aug = TPDNoiseAugmentation()
            print("Noise augmentation ENABLED for recon eval (matches train-time noise distribution)")

        dataset = TPDDataset(
            split_dir,
            max_samples=args.max_samples,
            normalize_input=True,
            compute_summary=use_summary,
            noise_augmentation=noise_aug,
        )
        # Normalise with the statistics stored next to the checkpoint.
        dataset.temperature_mean = norm_stats['temperature'][0]
        dataset.temperature_std = norm_stats['temperature'][1]
        dataset.rate_mean = norm_stats['rate'][0]
        dataset.rate_std = norm_stats['rate'][1]
        # Un-normalised twin; only its file list is used, to read raw samples by index.
        raw_dataset = TPDDataset(split_dir, max_samples=args.max_samples, normalize_input=False)
        loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

        suffix = getattr(args, 'output_suffix', '')
        eval_dir = ckpt_dir / f"eval_recon_{args.split}{suffix}"
        if raw_per_mechanism:
            eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}{suffix}"
        eval_dir.mkdir(exist_ok=True)

        per_mech_nrmse_mean = defaultdict(list)
        per_mech_nrmse_map = defaultdict(list)
        per_mech_nrmse_samples = defaultdict(list)
        per_mech_r2_mean = defaultdict(list)
        per_mech_r2_map = defaultdict(list)
        n_failed = 0  # samples whose mean-estimate reconstruction failed at every heating rate
        vis_count = defaultdict(int)

        print(f"Evaluating TPD signal reconstruction on {len(dataset)} samples...")
        for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")):
            x = batch['input'].to(device)
            scan_mask = batch['scan_mask'].to(device)
            sigmas_tensor = batch['sigmas'].to(device)
            flux_scales = batch['flux_scales'].to(device)
            summary = batch['summary'].to(device) if 'summary' in batch else None
            mech_id = batch['mechanism_id'].item()
            if mech_id < 0 or mech_id >= len(active_mechs):
                continue
            mech = active_mechs[mech_id]
            if args.mechanisms is not None and mech not in args.mechanisms:
                continue

            # Close the NpzFile handle as soon as the arrays are extracted.
            with np.load(raw_dataset.sample_files[idx], allow_pickle=True) as raw_data:
                raw_params = raw_data['params'].item()
                raw_rate = raw_data['rate'].astype(np.float32)
                raw_temp = raw_data['temperature'].astype(np.float32)
                multi_heat = 'heating_rates' in raw_data
                if multi_heat:
                    betas = raw_data['heating_rates'].astype(np.float64)
                    lengths = raw_data['lengths'].astype(int)
            if not multi_heat:
                # Single-ramp file: promote to the [n_ramps, T] layout.
                betas = np.array([raw_params.get('beta', 1.0)])
                lengths = np.array([len(raw_rate)])
                raw_rate = raw_rate[np.newaxis, :]
                raw_temp = raw_temp[np.newaxis, :]

            with torch.no_grad():
                pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor,
                                     flux_scales=flux_scales,
                                     n_samples=args.n_posterior_samples,
                                     temperature=args.temperature,
                                     summary=summary)
            if pred['stats'][mech] is None:
                continue
            theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy()
            samples = pred['samples'][mech][0].cpu().numpy()

            # MAP estimate via independent 1D KDEs (mode of each marginal).
            theta_map = np.zeros_like(theta_mean)
            for d in range(len(theta_mean)):
                s = samples[:, d]
                if np.std(s) < 1e-10:
                    theta_map[d] = np.mean(s)
                else:
                    try:
                        kde = gaussian_kde(s)
                        grid = np.linspace(s.min(), s.max(), 200)
                        theta_map[d] = grid[np.argmax(kde(grid))]
                    except Exception:
                        # KDE failed; fall back to the marginal median.
                        theta_map[d] = np.median(s)

            base_params = dict(raw_params)
            recon_mean = reconstruct_tpd_signal(theta_mean, mech, base_params, betas)
            recon_map = reconstruct_tpd_signal(theta_map, mech, base_params, betas)
            if not any(r['success'] for r in recon_mean):
                n_failed += 1

            # Metrics per heating rate, then averaged over heating rates.
            nrmse_mean_list = []
            nrmse_map_list = []
            r2_mean_list = []
            r2_map_list = []
            for s_idx in range(len(betas)):
                obs_rate = raw_rate[s_idx]
                length = lengths[s_idx]
                if recon_mean[s_idx]['success']:
                    v = signal_nrmse(obs_rate, recon_mean[s_idx]['rate'], length)
                    r = signal_r2(obs_rate, recon_mean[s_idx]['rate'], length)
                    if np.isfinite(v):
                        nrmse_mean_list.append(v)
                    if np.isfinite(r):
                        r2_mean_list.append(r)
                if recon_map[s_idx]['success']:
                    v = signal_nrmse(obs_rate, recon_map[s_idx]['rate'], length)
                    r = signal_r2(obs_rate, recon_map[s_idx]['rate'], length)
                    if np.isfinite(v):
                        nrmse_map_list.append(v)
                    if np.isfinite(r):
                        r2_map_list.append(r)
            if nrmse_mean_list:
                per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list))
            if r2_mean_list:
                per_mech_r2_mean[mech].append(np.mean(r2_mean_list))
            if nrmse_map_list:
                per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list))
            if r2_map_list:
                per_mech_r2_map[mech].append(np.mean(r2_map_list))

            # Reconstruct from random posterior samples. The reconstructions are
            # kept so the plots below can reuse them instead of re-simulating.
            sample_nrmses = []
            sample_recons = []
            n_recon = min(args.n_recon_samples, samples.shape[0])
            sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False)
            for si in sample_indices:
                recon_s = reconstruct_tpd_signal(samples[si], mech, base_params, betas)
                sample_recons.append(recon_s)
                nrmses = []
                for s_idx in range(len(betas)):
                    if recon_s[s_idx]['success']:
                        v = signal_nrmse(raw_rate[s_idx], recon_s[s_idx]['rate'], lengths[s_idx])
                        if np.isfinite(v):
                            nrmses.append(v)
                if nrmses:
                    sample_nrmses.append(np.mean(nrmses))
            if sample_nrmses:
                per_mech_nrmse_samples[mech].append(np.median(sample_nrmses))

            # Visualization
            if vis_count[mech] < args.n_visualize and recon_mean and recon_mean[0]['success']:
                fig, axes = plt.subplots(1, len(betas), figsize=(5 * len(betas), 4))
                if len(betas) == 1:
                    axes = [axes]
                for s_idx, ax in enumerate(axes):
                    length = lengths[s_idx]
                    obs_temp = raw_temp[s_idx, :length]
                    obs_rate_s = raw_rate[s_idx, :length]
                    ax.plot(obs_temp, obs_rate_s, 'k-', lw=1.5, label='Observed', alpha=0.8)
                    if recon_mean[s_idx]['success']:
                        r_temp = recon_mean[s_idx]['temperature']
                        r_rate = recon_mean[s_idx]['rate']
                        min_len = min(length, len(r_temp))
                        nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length], length)
                        lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)'
                        ax.plot(r_temp[:min_len], r_rate[:min_len], 'r--', lw=1.2, label=lbl)
                    if recon_map[s_idx]['success']:
                        r_temp = recon_map[s_idx]['temperature']
                        r_rate = recon_map[s_idx]['rate']
                        min_len = min(length, len(r_temp))
                        nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length], length)
                        lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)'
                        ax.plot(r_temp[:min_len], r_rate[:min_len], 'b:', lw=1.2, label=lbl)
                    # Plot a few posterior samples (already simulated above).
                    for recon_s in sample_recons[:3]:
                        if recon_s[s_idx]['success']:
                            r_temp = recon_s[s_idx]['temperature']
                            r_rate = recon_s[s_idx]['rate']
                            min_len = min(length, len(r_temp))
                            ax.plot(r_temp[:min_len], r_rate[:min_len], '-', lw=0.5,
                                    alpha=0.3, color='gray')
                    ax.set_xlabel('Temperature (K)')
                    ax.set_ylabel('Rate')
                    ax.set_title(f'β={betas[s_idx]:.2f} K/s')
                    ax.legend(fontsize=7)
                fig.suptitle(f'{mech} sample {idx}', fontsize=12)
                fig.tight_layout()
                fig.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150)
                plt.close(fig)
                vis_count[mech] += 1

        print("\n" + "=" * 60)
        print("SIGNAL RECONSTRUCTION METRICS (TPD)")
        print("=" * 60)
        results = {}
        for mech in active_mechs:
            if not per_mech_nrmse_mean[mech]:
                continue
            nrmse_mean = np.array(per_mech_nrmse_mean[mech])
            nrmse_map = np.array(per_mech_nrmse_map[mech])
            nrmse_samp = np.array(per_mech_nrmse_samples[mech])
            r2_mean = np.array(per_mech_r2_mean[mech])
            r2_map = np.array(per_mech_r2_map[mech])
            results[mech] = {
                'n_samples': len(nrmse_mean),
                # q25/q75 added for parity with the EC summary.
                'nrmse_mean': {'median': float(np.median(nrmse_mean)),
                               'mean': float(np.mean(nrmse_mean)),
                               'std': float(np.std(nrmse_mean)),
                               'q25': float(np.percentile(nrmse_mean, 25)),
                               'q75': float(np.percentile(nrmse_mean, 75))},
                'nrmse_map': {'median': float(np.median(nrmse_map)),
                              'mean': float(np.mean(nrmse_map)),
                              'std': float(np.std(nrmse_map))},
                'r2_signal_mean': {'median': float(np.median(r2_mean)),
                                   'mean': float(np.mean(r2_mean))},
                'r2_signal_map': {'median': float(np.median(r2_map)),
                                  'mean': float(np.mean(r2_map))},
            }
            if len(nrmse_samp) > 0:
                results[mech]['nrmse_posterior_median'] = {
                    'median': float(np.median(nrmse_samp)),
                    'mean': float(np.mean(nrmse_samp)),
                }
            print(f"\n{mech} ({len(nrmse_mean)} samples):")
            print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} "
                  f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}")
            print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} "
                  f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}")
            if len(nrmse_samp) > 0:
                print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} "
                      f"mean={np.mean(nrmse_samp):.4f}")
            print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}")
            print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}")

        with open(eval_dir / "reconstruction_results.json", "w") as f:
            json.dump(results, f, indent=2)
        if n_failed:
            print(f"\n{n_failed} sample(s): mean-estimate reconstruction failed at every heating rate")
        print(f"\nResults saved to {eval_dir}")
        print(f"Visualizations: {sum(vis_count.values())} plots saved")
    finally:
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)
if __name__ == "__main__":
    # Dispatch to the domain-specific evaluator selected by --domain.
    args = parse_args()
    runner = evaluate_ec if args.domain == "ec" else evaluate_tpd
    runner(args)