# Uploaded by RobroKools ("Upload 44 files"), commit e59f78e (verified).
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import numpy as np
import json
import argparse
import sys
import umap
import matplotlib.pyplot as plt
from celldreamer.models.class_celldreamer import ClassCellDreamer
from celldreamer.models import load_config
def _cfg(args, name, default):
    """Return ``args.<name>``, or *default* if the attribute is missing or None.

    ``args`` is the object returned by ``load_config``; only attribute access
    is assumed. Treating ``None`` like "absent" lets optional keys be omitted
    from the config file.
    """
    value = getattr(args, name, None)
    return default if value is None else value


def _load_test_loader(args):
    """Load ``<args.data_path>/test.pt`` and wrap it in a non-shuffling DataLoader.

    Returns:
        ``(test_ds, test_loader)``.

    Raises:
        FileNotFoundError: if the test split file does not exist.
        ValueError: if the loaded dataset has no samples.
    """
    test_path = f"{args.data_path}/test.pt"
    print(f"Loading test dataset from {test_path}...")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test dataset not found at {test_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects. It is
    # required here because the file stores a whole Dataset object, so only
    # load test splits that come from a trusted source.
    test_ds = torch.load(test_path, weights_only=False)
    if len(test_ds) == 0:
        # Without this guard the metrics become NaN and torch.cat() on the
        # (empty) latent list fails much later with an obscure error.
        raise ValueError(f"Test dataset at {test_path} is empty")
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=_cfg(args, "num_workers", 2),
    )
    print(f"Test Size: {len(test_ds)} samples")
    return test_ds, test_loader


def _build_model(args, device):
    """Create the model wrapper named by ``args.model_type``, load its checkpoint, set eval mode.

    Raises:
        ValueError: if ``args.model_type`` is not a supported model.
    """
    print(f"Initializing Model: {args.model_type}")
    if args.model_type.lower() != "celldreamer":
        raise ValueError(f"Unknown model type: {args.model_type}")
    model_wrapper = ClassCellDreamer(args)
    model_wrapper.load(args.checkpoint_path)
    # Batches are moved to `device` during inference, so make sure the weights
    # live there too (no-op if the wrapper already placed them).
    model_wrapper.model.to(device)
    model_wrapper.model.eval()
    return model_wrapper


def _run_inference(model_wrapper, test_loader, device, kl_scale, free_bits_per_dim):
    """Run the model over *test_loader* without gradients and collect losses.

    Returns:
        losses: dict with keys ``"recon"``, ``"dynamics"``, ``"posterior_kl"``
            and ``"total"``, each a list holding one float per batch.
        batch_sizes: number of samples in each batch (size of dim 0 of ``x_t``),
            used to weight the per-batch values when averaging.
        latents: CPU tensor of the posterior means of all samples,
            concatenated along dim 0.
    """
    losses = {"recon": [], "dynamics": [], "posterior_kl": [], "total": []}
    batch_sizes = []
    all_latents = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)
            outputs = model_wrapper.model(x_t)
            # Encode the true next state; its distribution is the target the
            # predicted prior (prior_next_*) is compared against.
            target_mean, target_std = model_wrapper.model.encoder(x_next)

            recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
            dyn_loss = model_wrapper.get_kl_loss(
                target_mean, target_std,
                outputs["prior_next_mean"], outputs["prior_next_std"],
            )
            # Posterior KL against a zero-mean / unit-std Gaussian, included
            # for consistency with training (per the original author).
            post_kl = model_wrapper.get_kl_loss(
                outputs["post_mean"], outputs["post_std"],
                torch.zeros_like(outputs["post_mean"]),
                torch.ones_like(outputs["post_std"]),
            )

            # Free-bits floor, intended to mirror the training objective.
            # Assumes post_mean is (batch, latent_dim) so shape[1] is the
            # latent size — confirm against the model.
            min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
            post_kl = torch.clamp(post_kl, min=min_kl)
            dyn_loss = torch.clamp(dyn_loss, min=min_kl)

            total_loss = recon_loss + (kl_scale * dyn_loss) + (kl_scale * post_kl)

            losses["recon"].append(recon_loss.item())
            losses["dynamics"].append(dyn_loss.item())
            losses["posterior_kl"].append(post_kl.item())
            losses["total"].append(total_loss.item())
            batch_sizes.append(x_t.shape[0])
            all_latents.append(outputs["post_mean"].cpu())
    return losses, batch_sizes, torch.cat(all_latents)


def _save_latent_umap(latents, umap_path):
    """Project *latents* to 2-D with UMAP and save a scatter plot to *umap_path*."""
    reducer = umap.UMAP(n_components=2)
    coords = reducer.fit_transform(latents.numpy())
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(coords[:, 0], coords[:, 1], s=1, alpha=0.5)
    ax.set_title("Latent Space Visualization")
    fig.savefig(umap_path)
    plt.close(fig)


def evaluate(args):
    """Evaluate a trained model checkpoint on the held-out test split.

    Loads ``<args.data_path>/test.pt`` and runs the model over it without
    gradients. Writes aggregate loss metrics as JSON to
    ``<args.output_dir>/<args.output_filename>``. Saves a 2-D UMAP plot of the
    posterior means to ``<args.output_dir>/latent_umap.png``.

    Args:
        args: Config object (from ``load_config``), read via attribute access.
            Required: ``device``, ``output_dir``, ``data_path``,
            ``batch_size``, ``model_type``, ``checkpoint_path``, ``kl_scale``,
            ``output_filename``. Optional: ``num_workers`` (default 2) and
            ``free_bits_per_dim`` (default 0.1). ``ClassCellDreamer(args)``
            may read further keys that are not visible here.

    Raises:
        FileNotFoundError: if the test split file does not exist.
        ValueError: if the test split is empty or ``model_type`` is unknown.
    """
    device = torch.device(args.device)
    os.makedirs(args.output_dir, exist_ok=True)

    test_ds, test_loader = _load_test_loader(args)
    model_wrapper = _build_model(args, device)

    print("Running inference...")
    losses, batch_sizes, latents = _run_inference(
        model_wrapper,
        test_loader,
        device,
        kl_scale=args.kl_scale,
        free_bits_per_dim=_cfg(args, "free_bits_per_dim", 0.1),
    )

    # Weight each batch by its size so a smaller final batch does not skew the
    # averages. Assumes each per-batch loss is a per-sample mean (true for
    # mse_loss's default reduction; presumed for get_kl_loss — verify).
    def batch_weighted_mean(key):
        return float(np.average(losses[key], weights=batch_sizes))

    metrics = {
        "model": args.model_type,
        "checkpoint": args.checkpoint_path,
        "test_samples": len(test_ds),
        "metrics": {
            "avg_total_loss": batch_weighted_mean("total"),
            "avg_recon_loss_mse": batch_weighted_mean("recon"),
            "avg_dynamics_loss_kl": batch_weighted_mean("dynamics"),
            "avg_posterior_kl": batch_weighted_mean("posterior_kl"),
            # Spread of the per-batch total losses (not a per-sample std).
            "std_total_loss": float(np.std(losses["total"])),
        },
    }
    print("Results:")
    print(f"MSE (Rec): {metrics['metrics']['avg_recon_loss_mse']:.6f}")
    print(f"KL (Dynamics/Dream): {metrics['metrics']['avg_dynamics_loss_kl']:.6f}")
    print(f"KL (Posterior): {metrics['metrics']['avg_posterior_kl']:.6f}")
    print(f"Total Loss: {metrics['metrics']['avg_total_loss']:.6f}")

    output_file_path = os.path.join(args.output_dir, args.output_filename)
    with open(output_file_path, 'w', encoding="utf-8") as f:
        json.dump(metrics, f, indent=4)
    print(f"\nResults saved to: {output_file_path}")

    print("Generating UMAP visualization...")
    umap_path = os.path.join(args.output_dir, "latent_umap.png")
    _save_latent_umap(latents, umap_path)
    print(f"UMAP plot saved to {umap_path}")
if __name__ == "__main__":
    # CLI entry point. The only flag is the config path; evaluate() receives
    # the object that load_config() builds from that file.
    cli = argparse.ArgumentParser(description="Evaluation script for celldreamer")
    cli.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/eval_config.yml",
        help="Path to the YAML configuration file",
    )
    cli_args = cli.parse_args()
    evaluate(load_config(cli_args.config))