# DDPM-2param / src / evaluate_conditional.py
# Uploaded by collins909
# Commit 0d05ab1 (verified): Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200)
"""
Evaluate Conditional Diffusion Model (2-label: Omega_m, sigma_8)
Usage:
python evaluate_conditional.py --checkpoint outputs_conditional_YYYYMMDD_HHMMSS/checkpoints/best_model.pt
Changes from original:
- Loads args.json (saved by improved training script) for robust config parsing
- Falls back to args.txt parsing if JSON not available
- Vectorized power spectrum calculation (~100x speedup)
- Added weights_only parameter to torch.load
"""
import argparse
import ast
import json
import os
from pathlib import Path
from typing import Dict, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel
from unet_conditional import ConditionalUNet
def parse_args() -> argparse.Namespace:
    """Define the evaluation CLI and parse the process arguments.

    ``--checkpoint`` is the only required option; all others fall back to the
    defaults listed in the option table below.

    Returns:
        The parsed namespace with one attribute per option.
    """
    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output lists them the same way.
    option_table = (
        ("--checkpoint", {
            "type": str,
            "required": True,
            "help": "Path to trained checkpoint (e.g. outputs_conditional_*/checkpoints/best_model.pt)",
        }),
        ("--training_args", {
            "type": str,
            "default": None,
            "help": "Path to args.json or args.txt from training (auto-detected if not provided)",
        }),
        ("--data_dir", {
            "type": str,
            "default": "./data/params_2",
            "help": "Directory containing the CAMELS LH dataset (default matches repo structure)",
        }),
        ("--split", {
            "type": str,
            "default": "test",
            "choices": ["train", "val", "test"],
            "help": "Which split to use for real images",
        }),
        ("--num_samples", {
            "type": int,
            "default": 8,
            "help": "Number of examples to show in the comparison grid",
        }),
        ("--seed", {
            "type": int,
            "default": 42,
            "help": "Random seed for reproducibility",
        }),
        ("--output_dir", {
            "type": str,
            "default": "evaluation_outputs",
            "help": "Where to save plots and results",
        }),
        ("--ddim_steps", {
            "type": int,
            "default": 50,
            "help": "Number of DDIM sampling steps",
        }),
    )

    cli = argparse.ArgumentParser(description="Evaluate conditional 2-label diffusion model")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
# Characters that may appear in a plain decimal / scientific-notation number.
_NUMERIC_CHARS = frozenset("0123456789+-.eE")


def _parse_config_value(value: str):
    """Convert one ``args.txt`` value string to a list, int, float or str.

    Anything that cannot be parsed cleanly is returned unchanged as a string,
    so a malformed value never aborts loading the rest of the file.
    """
    if value.startswith("[") and value.endswith("]"):
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            return value
    # Only attempt numeric conversion on number-like text; this keeps words
    # such as "nan" or "inf" as strings while still accepting signs/exponents.
    if value and _NUMERIC_CHARS.issuperset(value):
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
    return value


def load_training_config(path: str) -> Dict:
    """Load the training configuration, preferring JSON over the txt dump.

    If *path* ends in ``.txt`` and a sibling file with the same stem and a
    ``.json`` extension exists, the JSON file is loaded instead. Otherwise
    *path* is parsed as ``key: value`` lines; lines without a colon are
    skipped, bracketed values are parsed as Python literals, and numeric
    values (including negative and scientific notation) become int/float.

    Args:
        path: Path to ``args.json`` or ``args.txt``.

    Returns:
        The configuration mapping.

    Raises:
        FileNotFoundError: If neither the JSON candidate nor *path* exists.
    """
    # Swap only the trailing extension; str.replace would also rewrite a
    # ".txt" occurring in a directory name and miss the real JSON sibling.
    if path.endswith(".txt"):
        json_path = path[: -len(".txt")] + ".json"
    else:
        json_path = path
    if json_path.endswith(".json") and os.path.isfile(json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # Fall back to txt parsing
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Training args file not found: {path}")

    config: Dict = {}
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            # Split on the first colon only so values may themselves contain ":".
            key, separator, value = raw_line.strip().partition(":")
            if not separator:
                continue
            config[key.strip()] = _parse_config_value(value.strip())
    return config
def _detect_label_suffix(data_dir: Path) -> str:
    """Return the suffix of the training-label file found in *data_dir*.

    ``train_labels_LH_2.npy`` (suffix ``"_2"``) is checked before
    ``train_labels_LH.npy`` (empty suffix).

    Raises:
        FileNotFoundError: If neither label file exists.
    """
    for suffix in ("_2", ""):
        if (data_dir / f"train_labels_LH{suffix}.npy").exists():
            return suffix
    raise FileNotFoundError(f"No label files found in {data_dir}")
def _detect_image_suffix(data_dir: Path) -> str:
    """Return the suffix of the training-image file found in *data_dir*.

    ``train_LH.npy`` (empty suffix) is checked before ``train_LH_6.npy``
    (suffix ``"_6"``).

    Raises:
        FileNotFoundError: If neither image file exists.
    """
    for suffix in ("", "_6"):
        if (data_dir / f"train_LH{suffix}.npy").exists():
            return suffix
    raise FileNotFoundError(f"No image files found in {data_dir}")
def load_label_stats(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-column mean and std of the training labels.

    The statistics are taken over axis 0 of the training-label file located
    via ``_detect_label_suffix``.

    Returns:
        ``(mean, std)``; any std entry equal to zero is replaced by 1.0.
    """
    label_file = data_dir / f"train_labels_LH{_detect_label_suffix(data_dir)}.npy"
    train_labels = np.load(label_file)

    mean = train_labels.mean(axis=0)
    std = train_labels.std(axis=0)
    # A constant label column has zero spread; use 1.0 so it can be divided by.
    std = np.where(std == 0, 1.0, std)
    return mean, std
def load_split(data_dir: Path, split: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load the image and label arrays of one dataset split as float32.

    File names are ``{split}_LH{suffix}.npy`` and
    ``{split}_labels_LH{suffix}.npy``, with the suffixes detected from the
    training files present in *data_dir*.

    Raises:
        FileNotFoundError: If the suffix detection fails or either split
            file is missing.
    """
    image_path = data_dir / f"{split}_LH{_detect_image_suffix(data_dir)}.npy"
    label_path = data_dir / f"{split}_labels_LH{_detect_label_suffix(data_dir)}.npy"

    # Check both files up front so the error names the missing one.
    for kind, candidate in (("Image", image_path), ("Label", label_path)):
        if not candidate.exists():
            raise FileNotFoundError(f"{kind} file not found: {candidate}")

    images = np.load(image_path).astype(np.float32)
    labels = np.load(label_path).astype(np.float32)
    return images, labels
def build_model(config: Dict, device: torch.device) -> ConditionalDiffusionModel:
    """Construct the U-Net and diffusion process from a training config.

    Every hyper-parameter is read from *config* with a fallback default, so a
    partial config still yields a model. Image channels are fixed to one in
    and one out.

    Returns:
        The assembled ``ConditionalDiffusionModel`` moved to *device*.
    """
    unet_kwargs = {
        "in_channels": 1,
        "out_channels": 1,
        "label_dim": int(config.get("label_dim", 2)),
        "base_channels": int(config.get("base_channels", 64)),
        "channel_multipliers": config.get("channel_multipliers", [1, 2, 4, 8]),
        "attention_levels": config.get("attention_levels", [2, 3]),
        "dropout": float(config.get("dropout", 0.1)),
    }
    diffusion_kwargs = {
        "timesteps": int(config.get("timesteps", 1500)),
        "beta_start": float(config.get("beta_start", 1e-4)),
        "beta_end": float(config.get("beta_end", 0.02)),
        "schedule_type": config.get("schedule_type", "linear"),
    }

    denoiser = ConditionalUNet(**unet_kwargs)
    process = GaussianDiffusion(**diffusion_kwargs)
    combined = ConditionalDiffusionModel(denoiser, process)
    return combined.to(device)
def load_checkpoint(model: "ConditionalDiffusionModel", checkpoint_path: str, device: torch.device) -> None:
    """Load weights from *checkpoint_path* into *model* and switch it to eval mode.

    Supported checkpoint layouts:

    * a dict with ``"model_state_dict"`` (loaded with ``load_state_dict``),
    * a dict that additionally has ``"ema_shadow"``: the regular state dict is
      applied first and the EMA entries are overlaid on top, so entries the
      EMA does not track come from the checkpoint rather than from the
      model's fresh initialisation,
    * a bare state dict.

    In the EMA path only names that exist in the model are applied; a warning
    is printed if no EMA name matches.

    Args:
        model: Module to receive the weights (modified in place).
        checkpoint_path: File produced by ``torch.save``.
        device: Passed to ``torch.load`` as ``map_location``.
    """
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    is_mapping = isinstance(checkpoint, dict)
    has_model_state = is_mapping and "model_state_dict" in checkpoint
    state_dict = checkpoint["model_state_dict"] if has_model_state else checkpoint

    # If EMA weights are available, use them (they are the better weights)
    if is_mapping and "ema_shadow" in checkpoint:
        print("Loading EMA shadow weights from checkpoint")
        merged_state = model.state_dict()
        if has_model_state:
            # Start from the saved weights so anything absent from the EMA
            # shadow is still restored from the checkpoint.
            merged_state.update(
                (name, tensor) for name, tensor in state_dict.items() if name in merged_state
            )
        ema_entries = {
            name: tensor
            for name, tensor in checkpoint["ema_shadow"].items()
            if name in merged_state
        }
        if not ema_entries:
            print("Warning: no EMA shadow entries matched the model's parameter names; "
                  "EMA weights were not applied")
        merged_state.update(ema_entries)
        model.load_state_dict(merged_state)
    else:
        model.load_state_dict(state_dict)

    model.eval()
    print(f"Loaded checkpoint: {checkpoint_path}")
def PowerSpectrum(box: np.ndarray, N: int, dl: float) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the radially binned 2D power spectrum of *box*.

    Each Fourier mode is assigned to the bin nearest to ``|k| / dk`` where
    ``dk = 2*pi / (N*dl)``; only the first ``N // 2`` bins are kept and modes
    falling beyond them are discarded. The mean mode power per bin is scaled
    by ``dl**2``.

    Args:
        box: Field to transform (the k-grid built here is ``N x N``).
        N: Number of pixels per side.
        dl: Pixel size.

    Returns:
        ``(k, pk)``: the bin wavenumbers ``arange(N // 2) * dk`` and the
        binned power. Bins without any mode hold 0.
    """
    modes = np.fft.fftn(box, norm="ortho")
    mode_power = (modes * np.conj(modes)).real

    fundamental = 2 * np.pi / (N * dl)
    wavenumbers = 2 * np.pi * np.fft.fftfreq(N, dl)
    # Broadcasting a column against a row yields |k| for every (i, j) mode.
    k_magnitude = np.sqrt(wavenumbers[:, None] ** 2 + wavenumbers[None, :] ** 2)
    bin_index = np.round(k_magnitude / fundamental).astype(int)

    num_bins = N // 2
    # Drop modes whose bin lies past the retained range.
    keep = bin_index < num_bins

    summed_power = np.zeros(num_bins)
    mode_count = np.zeros(num_bins)
    np.add.at(summed_power, bin_index[keep], mode_power[keep])
    np.add.at(mode_count, bin_index[keep], 1)

    # Substitute 1 for empty bins so the division leaves them at zero.
    binned_power = summed_power / np.where(mode_count == 0, 1, mode_count)
    binned_power = binned_power * dl**2

    return np.arange(num_bins) * fundamental, binned_power
def calculate_pdf_batch(images: np.ndarray, log_nhi_min=14.0, log_nhi_max=22.0, n_bins=100):
    """Compute the mean and std of per-image value histograms.

    Pixel values are clipped to [0, 1] and mapped linearly onto
    ``[log_nhi_min, log_nhi_max]``; each image is then histogrammed with
    ``density=True``.

    Note that ``n_bins`` is the number of bin *edges*, so each histogram has
    ``n_bins - 1`` bins.

    Returns:
        ``(bin_centers, mean_pdf, std_pdf)``, the latter two taken across
        images.
    """
    edges = np.linspace(log_nhi_min, log_nhi_max, n_bins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    span = log_nhi_max - log_nhi_min

    per_image = np.stack([
        np.histogram(log_nhi_min + span * frame.reshape(-1), bins=edges, density=True)[0]
        for frame in np.clip(images, 0.0, 1.0)
    ])
    return centers, per_image.mean(axis=0), per_image.std(axis=0)
def calculate_power_spectrum_batch(images: np.ndarray, box_size: float = 25.0):
    """Compute the mean and std of the power spectrum over a batch of images.

    The pixel count ``N`` is taken from the last image axis and the pixel
    size is ``box_size / N``. Each image is transformed exactly once; the
    k-axis is taken from the first result, since it depends only on ``N`` and
    the pixel size.

    Args:
        images: Batch of images, iterated along axis 0.
        box_size: Side length of the field.

    Returns:
        ``(k, mean_pk, std_pk)`` with the statistics taken across images.

    Raises:
        IndexError: If *images* contains no image.
    """
    N = images.shape[-1]
    dl = box_size / N
    spectra = [PowerSpectrum(img, N=N, dl=dl) for img in images]
    dk = spectra[0][0]
    power_array = np.stack([pk for _, pk in spectra])
    return dk, power_array.mean(axis=0), power_array.std(axis=0)
def prepare_labels_for_model(labels: np.ndarray, mean: np.ndarray, std: np.ndarray) -> torch.Tensor:
    """Standardise *labels* with *mean* and *std* and return a float32 tensor."""
    centred = labels - mean
    standardised = centred / std
    as_tensor = torch.from_numpy(standardised)
    return as_tensor.float()
def from_model_output(samples: torch.Tensor) -> np.ndarray:
    """Map 4-D model samples from [-1, 1] to [0, 1] and keep channel 0.

    Values are shifted and scaled with ``(x + 1) / 2``, clipped to [0, 1],
    and the result is indexed as ``[:, 0, :, :]``.
    """
    host_array = samples.cpu().numpy()
    rescaled = (host_array + 1.0) / 2.0
    clipped = np.clip(rescaled, 0.0, 1.0)
    return clipped[:, 0, :, :]
def plot_image_grid(generated, real, labels, output_path: Path, num_samples=8):
    """Save a two-column figure comparing generated and real images.

    One row is drawn per sample, up to ``num_samples`` or the number of
    generated images, whichever is smaller. The left panel's title lists the
    row's label values to three decimals.
    """
    row_count = min(num_samples, generated.shape[0])
    # squeeze=False keeps the axes array 2-D even for a single row.
    fig, axes = plt.subplots(row_count, 2, figsize=(6, 3 * row_count), squeeze=False)

    for row in range(row_count):
        left, right = axes[row, 0], axes[row, 1]
        label_str = ", ".join(f"{v:.3f}" for v in labels[row])

        left.imshow(generated[row], cmap="magma", origin="lower")
        left.set_title(f"Generated\n{label_str}")
        left.axis("off")

        right.imshow(real[row], cmap="magma", origin="lower")
        right.set_title("Real")
        right.axis("off")

    plt.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
def plot_mean_std(x, mean_real, std_real, mean_gen, std_gen, xlabel, ylabel, title, output_path: Path, yscale="linear"):
    """Save a plot of real vs. generated means with 1- and 3-sigma bands.

    Both mean curves are drawn first, then for each series a labelled
    +/-1 std band and an unlabelled, fainter +/-3 std band.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    series = (
        ("Real", "tab:blue", mean_real, std_real),
        ("Generated", "tab:orange", mean_gen, std_gen),
    )
    for name, colour, centre, _ in series:
        ax.plot(x, centre, label=f"{name} mean", color=colour, linewidth=2)
    for name, colour, centre, spread in series:
        ax.fill_between(x, centre - spread, centre + spread,
                        color=colour, alpha=0.15, label=f"{name} +/-1s")
        ax.fill_between(x, centre - 3*spread, centre + 3*spread,
                        color=colour, alpha=0.05)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_yscale(yscale)
    ax.legend()
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
# Upper bound on how many split images are compared against generated ones.
_MAX_EVAL_SAMPLES = 100
# Upper bound on how many images are sampled from the model per call.
_MAX_SAMPLING_BATCH = 8


def _autodetect_training_args(checkpoint_path: str) -> str:
    """Locate the args.json / args.txt that belongs to *checkpoint_path*.

    The checkpoint's own run directory is searched first (the usage layout is
    ``<run_dir>/checkpoints/<file>.pt``), then the checkpoint's directory. Only
    if neither holds an args file does the search fall back to the most
    recently changed ``outputs_conditional_*`` run in the current directory.

    Raises:
        FileNotFoundError: If no args file can be found anywhere.
    """
    checkpoint_dir = Path(checkpoint_path).parent
    # Prefer the args saved with this checkpoint: the newest run in the CWD
    # may have been trained with a different architecture.
    for run_dir in dict.fromkeys((checkpoint_dir.parent, checkpoint_dir)):
        for file_name in ("args.json", "args.txt"):
            candidate = run_dir / file_name
            if candidate.is_file():
                return str(candidate)

    # Try JSON first, then txt
    possible = [
        *Path(".").glob("outputs_conditional_*/args.json"),
        *Path(".").glob("outputs_conditional_*/args.txt"),
    ]
    if not possible:
        raise FileNotFoundError("Please provide --training_args path to your training args.json or args.txt")
    return str(max(possible, key=os.path.getctime))


def main():
    """Run the full evaluation: load model and data, sample, plot, save.

    Steps:
        1. Parse CLI arguments and seed torch / NumPy.
        2. Load the training config (auto-detected when not given), rebuild
           the model and load the checkpoint.
        3. Draw up to 100 random images from the chosen split and generate
           one sample per selected label with DDIM, in batches of up to 8.
        4. Save the image grid, the PDF plot, the power-spectrum plot and an
           ``evaluation_data.npz`` file with the numerical results.

    Raises:
        FileNotFoundError: If required files cannot be located.
        ValueError: If the split is empty or images and labels differ in count.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load training config
    if args.training_args is None:
        args.training_args = _autodetect_training_args(args.checkpoint)
        print(f"Auto-detected training args: {args.training_args}")
    config = load_training_config(args.training_args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(config, device)
    load_checkpoint(model, args.checkpoint, device)

    # Load data
    data_dir = Path(args.data_dir)
    images_split, labels_split = load_split(data_dir, args.split)
    label_mean, label_std = load_label_stats(data_dir)

    # Fail early with a clear message instead of a zero-step range() error
    # or an out-of-bounds label lookup further down.
    if len(images_split) == 0:
        raise ValueError(f"Split '{args.split}' in {data_dir} contains no images")
    if len(images_split) != len(labels_split):
        raise ValueError(
            f"Split '{args.split}' has {len(images_split)} images but {len(labels_split)} labels"
        )

    # Select random samples
    num_select = min(_MAX_EVAL_SAMPLES, len(images_split))
    indices = np.random.choice(len(images_split), num_select, replace=False)
    real_images = images_split[indices]
    original_labels = labels_split[indices]

    # Generate samples in batches
    batch_size = min(_MAX_SAMPLING_BATCH, num_select)
    num_batches = (num_select + batch_size - 1) // batch_size
    generated_list = []
    print(f"Generating {num_select} samples (batch size = {batch_size})...")
    for batch_number, start in enumerate(range(0, num_select, batch_size), start=1):
        batch_labels = original_labels[start:start + batch_size]
        batch_labels_tensor = prepare_labels_for_model(batch_labels, label_mean, label_std).to(device)
        with torch.no_grad():
            batch_gen = model.sample(
                labels=batch_labels_tensor,
                channels=1,
                height=real_images.shape[-2],
                width=real_images.shape[-1],
                device=device,
                progress=False,
                use_ddim=True,
                ddim_steps=args.ddim_steps,
            )
        generated_list.append(from_model_output(batch_gen))
        print(f" Batch {batch_number}/{num_batches} done")
    generated_images = np.concatenate(generated_list, axis=0)

    # Plots
    plot_image_grid(generated_images, real_images, original_labels,
                    output_dir / "real_vs_generated.png", num_samples=args.num_samples)

    # PDF
    bin_centers, mean_pdf_real, std_pdf_real = calculate_pdf_batch(real_images)
    _, mean_pdf_gen, std_pdf_gen = calculate_pdf_batch(generated_images)
    plot_mean_std(bin_centers, mean_pdf_real, std_pdf_real, mean_pdf_gen, std_pdf_gen,
                  "log N_HI [cm^-2]", "PDF", "Column Density PDF", output_dir / "pdf_mean_std.png")

    # Power Spectrum (skip k=0 DC component for log-scale plotting)
    dk, mean_pk_real, std_pk_real = calculate_power_spectrum_batch(real_images)
    _, mean_pk_gen, std_pk_gen = calculate_power_spectrum_batch(generated_images)
    plot_mean_std(dk[1:], mean_pk_real[1:], std_pk_real[1:], mean_pk_gen[1:], std_pk_gen[1:],
                  "k [h/Mpc]", "P(k)", "Power Spectrum", output_dir / "power_spectrum_mean_std.png", yscale="log")

    # Save numerical results
    np.savez(
        output_dir / "evaluation_data.npz",
        indices=indices,
        labels_original=original_labels,
        bin_centers=bin_centers,
        mean_pdf_real=mean_pdf_real, std_pdf_real=std_pdf_real,
        mean_pdf_gen=mean_pdf_gen, std_pdf_gen=std_pdf_gen,
        dk=dk,
        mean_pk_real=mean_pk_real, std_pk_real=std_pk_real,
        mean_pk_gen=mean_pk_gen, std_pk_gen=std_pk_gen,
    )

    print("\nEvaluation complete!")
    print(f" Plots saved to: {output_dir}")
    print(f" Numerical data saved to: {output_dir}/evaluation_data.npz")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()