"""Evaluation script for CellDreamer-style models.

Loads a serialized test dataset, runs the trained model in inference mode,
computes the same loss terms used during training (reconstruction MSE,
dynamics KL, posterior KL with a free-bits floor), writes aggregate metrics
to a JSON file, and saves a UMAP projection of the posterior latents.
"""

import argparse
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import umap
from torch.utils.data import DataLoader
from tqdm import tqdm

from celldreamer.models import load_config
from celldreamer.models.class_celldreamer import ClassCellDreamer


def _load_test_data(args):
    """Load the saved test split and wrap it in a DataLoader.

    Returns (dataset, loader). Raises FileNotFoundError when the expected
    ``test.pt`` file is missing.
    """
    test_path = f"{args.data_path}/test.pt"
    print(f"Loading test dataset from {test_path}...")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test dataset not found at {test_path}")
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects — only load dataset files from trusted sources.
    test_ds = torch.load(test_path, weights_only=False)
    test_loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False, num_workers=2
    )
    print(f"Test Size: {len(test_ds)} samples")
    return test_ds, test_loader


def _build_model(args):
    """Instantiate the model wrapper, load its checkpoint, set eval mode."""
    print(f"Initializing Model: {args.model_type}")
    if args.model_type.lower() == "celldreamer":
        model_wrapper = ClassCellDreamer(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")
    model_wrapper.load(args.checkpoint_path)
    model_wrapper.model.eval()
    return model_wrapper


def _run_inference(model_wrapper, test_loader, device, args):
    """Run the model over the test set, accumulating per-batch losses.

    Mirrors the training objective: reconstruction MSE, dynamics KL between
    the encoder's next-step posterior and the model's prior, and posterior
    KL against a standard normal, with both KL terms clamped to a free-bits
    floor.

    Returns:
        (losses, all_latents) where ``losses`` maps
        {"recon", "dynamics", "posterior_kl", "total"} to per-batch lists
        and ``all_latents`` holds the posterior means (CPU tensors).
    """
    test_recon_losses = []
    test_dynamics_losses = []
    test_posterior_kl_losses = []
    test_total_losses = []
    all_latents = []

    # Free-bits floor per latent dimension. Configurable so evaluation can
    # match whatever value training used; 0.1 preserves the previous
    # hard-coded behavior when the config does not define it.
    free_bits_per_dim = getattr(args, "free_bits_per_dim", 0.1)

    print("Running inference...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            x_t = batch["x_t"].to(device)
            x_next = batch["x_next"].to(device)

            outputs = model_wrapper.model(x_t)
            # Encode the true next state to serve as the dynamics target.
            target_mean, target_std = model_wrapper.model.encoder(x_next)

            recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
            dyn_loss = model_wrapper.get_kl_loss(
                target_mean,
                target_std,
                outputs["prior_next_mean"],
                outputs["prior_next_std"],
            )

            # Posterior KL against N(0, I), for consistency with training.
            zeros = torch.zeros_like(outputs["post_mean"])
            ones = torch.ones_like(outputs["post_std"])
            post_kl = model_wrapper.get_kl_loss(
                outputs["post_mean"], outputs["post_std"], zeros, ones
            )

            # Apply the same free-bits constraint as training so the
            # reported losses are directly comparable.
            min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
            post_kl = torch.clamp(post_kl, min=min_kl)
            dyn_loss = torch.clamp(dyn_loss, min=min_kl)

            # Same weighted total as the training objective.
            total_loss = (
                recon_loss
                + (args.kl_scale * dyn_loss)
                + (args.kl_scale * post_kl)
            )

            test_recon_losses.append(recon_loss.item())
            test_dynamics_losses.append(dyn_loss.item())
            test_posterior_kl_losses.append(post_kl.item())
            test_total_losses.append(total_loss.item())
            all_latents.append(outputs["post_mean"].cpu())

    losses = {
        "recon": test_recon_losses,
        "dynamics": test_dynamics_losses,
        "posterior_kl": test_posterior_kl_losses,
        "total": test_total_losses,
    }
    return losses, all_latents


def _save_metrics(args, losses, n_samples):
    """Aggregate per-batch losses, print a summary, and write them to JSON.

    NOTE(review): per-batch means are averaged with equal weight, so a
    smaller final batch is slightly over-weighted — matches the original
    behavior and is acceptable for large test sets.
    """
    metrics = {
        "model": args.model_type,
        "checkpoint": args.checkpoint_path,
        "test_samples": n_samples,
        "metrics": {
            "avg_total_loss": float(np.mean(losses["total"])),
            "avg_recon_loss_mse": float(np.mean(losses["recon"])),
            "avg_dynamics_loss_kl": float(np.mean(losses["dynamics"])),
            "avg_posterior_kl": float(np.mean(losses["posterior_kl"])),
            "std_total_loss": float(np.std(losses["total"])),
        },
    }

    print("Results:")
    print(f"MSE (Rec): {metrics['metrics']['avg_recon_loss_mse']:.6f}")
    print(f"KL (Dynamics/Dream): {metrics['metrics']['avg_dynamics_loss_kl']:.6f}")
    print(f"KL (Posterior): {metrics['metrics']['avg_posterior_kl']:.6f}")
    print(f"Total Loss: {metrics['metrics']['avg_total_loss']:.6f}")

    output_file_path = os.path.join(args.output_dir, args.output_filename)
    with open(output_file_path, 'w') as f:
        json.dump(metrics, f, indent=4)
    print(f"\nResults saved to: {output_file_path}")


def _plot_umap(all_latents, output_dir):
    """Project collected latents to 2D with UMAP and save a scatter plot."""
    print("Generating UMAP visualization...")
    latents_tensor = torch.cat(all_latents)
    reducer = umap.UMAP(n_components=2)
    coords = reducer.fit_transform(latents_tensor.numpy())
    plt.figure(figsize=(10, 8))
    plt.scatter(coords[:, 0], coords[:, 1], s=1, alpha=0.5)
    plt.title("Latent Space Visualization")
    umap_path = os.path.join(output_dir, "latent_umap.png")
    plt.savefig(umap_path)
    plt.close()
    print(f"UMAP plot saved to {umap_path}")


def evaluate(args):
    """Evaluate a trained model on the held-out test split.

    Args:
        args: Config namespace with (at least) device, output_dir, data_path,
            batch_size, model_type, checkpoint_path, kl_scale, and
            output_filename attributes; free_bits_per_dim is optional
            (defaults to 0.1).
    """
    device = torch.device(args.device)
    os.makedirs(args.output_dir, exist_ok=True)

    test_ds, test_loader = _load_test_data(args)
    model_wrapper = _build_model(args)
    losses, all_latents = _run_inference(model_wrapper, test_loader, device, args)
    _save_metrics(args, losses, len(test_ds))
    _plot_umap(all_latents, args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluation script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/eval_config.yml",
        help="Path to the YAML configuration file",
    )
    args = parser.parse_args()
    config = load_config(args.config)
    evaluate(config)