Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import os | |
| import numpy as np | |
| import json | |
| import argparse | |
| import sys | |
| import umap | |
| import matplotlib.pyplot as plt | |
| from celldreamer.models.class_celldreamer import ClassCellDreamer | |
| from celldreamer.models import load_config | |
# Free-bits floor per latent dimension applied to both KL terms. The original
# author's comment says this matches the training objective; keep it in sync
# with the training code (not visible from this file).
_FREE_BITS_PER_DIM = 0.1


def _load_test_dataset(data_path):
    """Load the serialized test split stored as ``test.pt`` under *data_path*.

    Raises:
        FileNotFoundError: if ``<data_path>/test.pt`` does not exist.
        ValueError: if the loaded dataset contains no samples.
    """
    test_path = os.path.join(data_path, "test.pt")
    print(f"Loading test dataset from {test_path}...")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test dataset not found at {test_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only point this script at dataset files you trust.
    test_ds = torch.load(test_path, weights_only=False)
    if len(test_ds) == 0:
        # Fail early: with no batches the metric means would be NaN and the
        # latent concatenation further down would crash anyway.
        raise ValueError(f"Test dataset at {test_path} is empty; nothing to evaluate")
    return test_ds


def _build_model(args):
    """Instantiate the model wrapper named by ``args.model_type`` in eval mode.

    Raises:
        ValueError: if ``args.model_type`` is not ``"celldreamer"``
            (case-insensitive).
    """
    print(f"Initializing Model: {args.model_type}")
    if args.model_type.lower() != "celldreamer":
        raise ValueError(f"Unknown model type: {args.model_type}")
    model_wrapper = ClassCellDreamer(args)
    model_wrapper.load(args.checkpoint_path)
    model_wrapper.model.eval()
    return model_wrapper


def _run_inference(model_wrapper, test_loader, device, kl_scale):
    """Run the model over *test_loader* and collect per-batch losses.

    Returns:
        A ``(losses, latents)`` tuple. ``losses`` maps ``"recon"``,
        ``"dynamics"``, ``"posterior_kl"`` and ``"total"`` to lists holding one
        float per batch; ``latents`` is the list of per-batch posterior means
        moved to the CPU.
    """
    losses = {"recon": [], "dynamics": [], "posterior_kl": [], "total": []}
    latents = []
    model = model_wrapper.model

    # NOTE(review): inputs are moved to `device`, but this script never moves
    # the model itself; presumably ClassCellDreamer places it on args.device.
    # Confirm against the wrapper.
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)

            outputs = model(x_t)
            target_mean, target_std = model.encoder(x_next)

            recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)

            # KL between the encoding of the next state and the predicted prior.
            dyn_loss = model_wrapper.get_kl_loss(
                target_mean, target_std,
                outputs["prior_next_mean"], outputs["prior_next_std"],
            )

            # KL between the posterior and a zero-mean, unit-std reference.
            post_kl = model_wrapper.get_kl_loss(
                outputs["post_mean"], outputs["post_std"],
                torch.zeros_like(outputs["post_mean"]),
                torch.ones_like(outputs["post_std"]),
            )

            # Free-bits floor: per-dimension allowance times dimension 1 of
            # the posterior mean.
            min_kl = _FREE_BITS_PER_DIM * outputs["post_mean"].shape[1]
            post_kl = torch.clamp(post_kl, min=min_kl)
            dyn_loss = torch.clamp(dyn_loss, min=min_kl)

            total_loss = recon_loss + (kl_scale * dyn_loss) + (kl_scale * post_kl)

            losses["recon"].append(recon_loss.item())
            losses["dynamics"].append(dyn_loss.item())
            losses["posterior_kl"].append(post_kl.item())
            losses["total"].append(total_loss.item())
            latents.append(outputs["post_mean"].cpu())

    return losses, latents


def _save_latent_umap(latents, output_dir):
    """Project *latents* to 2-D with UMAP and save a scatter plot as PNG."""
    print("Generating UMAP visualization...")
    coords = umap.UMAP(n_components=2).fit_transform(latents.numpy())

    umap_path = os.path.join(output_dir, "latent_umap.png")
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(coords[:, 0], coords[:, 1], s=1, alpha=0.5)
        ax.set_title("Latent Space Visualization")
        fig.savefig(umap_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"UMAP plot saved to {umap_path}")


def evaluate(args):
    """Evaluate a checkpoint on the test split and write metrics plus a UMAP plot.

    Reads the following attributes from *args*: ``device``, ``output_dir``,
    ``data_path``, ``batch_size``, ``model_type``, ``checkpoint_path``,
    ``kl_scale`` and ``output_filename``. *args* is also handed unchanged to
    the model wrapper's constructor.

    Side effects:
        * creates ``args.output_dir`` if needed;
        * writes the metrics as JSON to ``output_dir/output_filename``;
        * writes ``latent_umap.png`` into ``output_dir``.

    The reported averages are unweighted means over batches, and
    ``std_total_loss`` is the standard deviation across batches.

    Raises:
        FileNotFoundError: if ``<data_path>/test.pt`` is missing.
        ValueError: if the model type is unknown or the test set is empty.
    """
    device = torch.device(args.device)
    os.makedirs(args.output_dir, exist_ok=True)

    test_ds = _load_test_dataset(args.data_path)
    test_loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False, num_workers=2
    )
    print(f"Test Size: {len(test_ds)} samples")

    model_wrapper = _build_model(args)

    print("Running inference...")
    losses, all_latents = _run_inference(
        model_wrapper, test_loader, device, args.kl_scale
    )

    metrics = {
        "model": args.model_type,
        "checkpoint": args.checkpoint_path,
        "test_samples": len(test_ds),
        "metrics": {
            "avg_total_loss": float(np.mean(losses["total"])),
            "avg_recon_loss_mse": float(np.mean(losses["recon"])),
            "avg_dynamics_loss_kl": float(np.mean(losses["dynamics"])),
            "avg_posterior_kl": float(np.mean(losses["posterior_kl"])),
            "std_total_loss": float(np.std(losses["total"])),
        },
    }

    summary = metrics["metrics"]
    print("Results:")
    print(f"MSE (Rec): {summary['avg_recon_loss_mse']:.6f}")
    print(f"KL (Dynamics/Dream): {summary['avg_dynamics_loss_kl']:.6f}")
    print(f"KL (Posterior): {summary['avg_posterior_kl']:.6f}")
    print(f"Total Loss: {summary['avg_total_loss']:.6f}")

    output_file_path = os.path.join(args.output_dir, args.output_filename)
    with open(output_file_path, 'w', encoding="utf-8") as f:
        json.dump(metrics, f, indent=4)
    print(f"\nResults saved to: {output_file_path}")

    _save_latent_umap(torch.cat(all_latents), args.output_dir)
if __name__ == "__main__":
    # Script entry point: read the config path from the command line, load
    # that configuration, and hand it to evaluate().
    cli = argparse.ArgumentParser(description="Evaluation script for celldreamer")
    cli.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/eval_config.yml",
        help="Path to the YAML configuration file",
    )
    cli_args = cli.parse_args()
    evaluate(load_config(cli_args.config))