# Scraped from a Hugging Face Space page (Space status at capture time: Sleeping).
# Original file size: 5,138 bytes.
# Source revision: e59f78e
import argparse
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import umap
from torch.utils.data import DataLoader
from tqdm import tqdm

from celldreamer.models import load_config
from celldreamer.models.class_celldreamer import ClassCellDreamer
def _load_test_dataset(args):
    """Load the serialized test split from ``{args.data_path}/test.pt``.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if the loaded dataset has no samples.
    """
    test_path = f"{args.data_path}/test.pt"
    print(f"Loading test dataset from {test_path}...")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test dataset not found at {test_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code on load. Only point data_path at files you produced/trust.
    test_ds = torch.load(test_path, weights_only=False)
    # Fail before inference: an empty split would otherwise yield NaN metrics,
    # write them to disk, and then crash in torch.cat on an empty list.
    if len(test_ds) == 0:
        raise ValueError(f"Test dataset at {test_path} is empty")
    return test_ds


def _build_model(args):
    """Instantiate the wrapper named by ``args.model_type``, load its checkpoint, set eval mode.

    Raises:
        ValueError: if ``args.model_type`` is not a supported model.
    """
    print(f"Initializing Model: {args.model_type}")
    if args.model_type.lower() != "celldreamer":
        raise ValueError(f"Unknown model type: {args.model_type}")
    model_wrapper = ClassCellDreamer(args)
    model_wrapper.load(args.checkpoint_path)
    model_wrapper.model.eval()
    return model_wrapper


def _compute_batch_losses(model_wrapper, batch, device, kl_scale, free_bits_per_dim):
    """Compute one batch's losses with the same formula as training.

    Returns:
        Tuple ``(recon_loss, dyn_loss, post_kl, total_loss, post_mean)``. The
        losses are tensors on which ``.item()`` is called by the caller, so
        they are expected to be scalars; ``post_mean`` is the posterior-mean
        tensor for the batch.
    """
    x_t = batch["x_t"].to(device)
    x_next = batch["x_next"].to(device)

    outputs = model_wrapper.model(x_t)
    # The encoder's distribution for x_next is the target that the model's
    # prior_next_* outputs are compared against.
    target_mean, target_std = model_wrapper.model.encoder(x_next)

    recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
    dyn_loss = model_wrapper.get_kl_loss(
        target_mean, target_std,
        outputs["prior_next_mean"], outputs["prior_next_std"],
    )
    # Posterior KL against N(0, 1), for consistency with training.
    post_kl = model_wrapper.get_kl_loss(
        outputs["post_mean"], outputs["post_std"],
        torch.zeros_like(outputs["post_mean"]),
        torch.ones_like(outputs["post_std"]),
    )

    # Free-bits floor, same as training. Assumes post_mean is
    # (batch, latent_dim), so shape[1] is the latent size -- TODO confirm.
    min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
    post_kl = torch.clamp(post_kl, min=min_kl)
    dyn_loss = torch.clamp(dyn_loss, min=min_kl)

    total_loss = recon_loss + (kl_scale * dyn_loss) + (kl_scale * post_kl)
    return recon_loss, dyn_loss, post_kl, total_loss, outputs["post_mean"]


def _plot_latent_umap(latents, output_dir):
    """Project ``latents`` to 2-D with UMAP, save a scatter plot, and return the image path."""
    reducer = umap.UMAP(n_components=2)
    coords = reducer.fit_transform(latents.numpy())
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(coords[:, 0], coords[:, 1], s=1, alpha=0.5)
        ax.set_title("Latent Space Visualization")
        umap_path = os.path.join(output_dir, "latent_umap.png")
        fig.savefig(umap_path)
    finally:
        # Release the figure even if savefig fails, so repeated runs don't leak.
        plt.close(fig)
    return umap_path


def evaluate(args):
    """Evaluate a checkpoint on the test split, save metrics as JSON and a latent UMAP plot.

    Args:
        args: Config object read via attribute access. Required attributes:
            ``device``, ``output_dir``, ``data_path``, ``batch_size``,
            ``model_type``, ``checkpoint_path``, ``kl_scale``,
            ``output_filename``. Optional: ``num_workers`` (default 2) and
            ``free_bits_per_dim`` (default 0.1). The object is also passed to
            ``ClassCellDreamer``.

    Returns:
        The metrics dict that was written to ``output_dir/output_filename``.

    Raises:
        FileNotFoundError: if ``{data_path}/test.pt`` does not exist.
        ValueError: if the model type is unknown or the test set is empty.
    """
    device = torch.device(args.device)
    os.makedirs(args.output_dir, exist_ok=True)

    test_ds = _load_test_dataset(args)
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=getattr(args, "num_workers", 2),
    )
    print(f"Test Size: {len(test_ds)} samples")

    model_wrapper = _build_model(args)
    free_bits_per_dim = getattr(args, "free_bits_per_dim", 0.1)

    test_recon_losses = []
    test_dynamics_losses = []
    test_posterior_kl_losses = []
    test_total_losses = []
    batch_sizes = []
    all_latents = []

    print("Running inference...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            recon_loss, dyn_loss, post_kl, total_loss, post_mean = _compute_batch_losses(
                model_wrapper, batch, device, args.kl_scale, free_bits_per_dim
            )
            batch_sizes.append(len(batch["x_t"]))
            test_recon_losses.append(recon_loss.item())
            test_dynamics_losses.append(dyn_loss.item())
            test_posterior_kl_losses.append(post_kl.item())
            test_total_losses.append(total_loss.item())
            all_latents.append(post_mean.cpu())

    def _weighted_avg(values):
        # Weight each batch by its sample count so a short final batch does
        # not skew the dataset-level average.
        return float(np.average(values, weights=batch_sizes))

    metrics = {
        "model": args.model_type,
        "checkpoint": args.checkpoint_path,
        "test_samples": len(test_ds),
        "metrics": {
            "avg_total_loss": _weighted_avg(test_total_losses),
            "avg_recon_loss_mse": _weighted_avg(test_recon_losses),
            "avg_dynamics_loss_kl": _weighted_avg(test_dynamics_losses),
            "avg_posterior_kl": _weighted_avg(test_posterior_kl_losses),
            # Spread of the per-batch total loss (not weighted by batch size).
            "std_total_loss": float(np.std(test_total_losses)),
        },
    }

    print("Results:")
    print(f"MSE (Rec): {metrics['metrics']['avg_recon_loss_mse']:.6f}")
    print(f"KL (Dynamics/Dream): {metrics['metrics']['avg_dynamics_loss_kl']:.6f}")
    print(f"KL (Posterior): {metrics['metrics']['avg_posterior_kl']:.6f}")
    print(f"Total Loss: {metrics['metrics']['avg_total_loss']:.6f}")

    output_file_path = os.path.join(args.output_dir, args.output_filename)
    with open(output_file_path, 'w') as f:
        json.dump(metrics, f, indent=4)
    print(f"\nResults saved to: {output_file_path}")

    print("Generating UMAP visualization...")
    umap_path = _plot_latent_umap(torch.cat(all_latents), args.output_dir)
    print(f"UMAP plot saved to {umap_path}")
    return metrics
if __name__ == "__main__":
    # CLI entry point: read the YAML config path, load it, and run evaluation.
    parser = argparse.ArgumentParser(description="Evaluation script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/eval_config.yml",
        help="Path to the YAML configuration file",
    )
    args = parser.parse_args()
    config = load_config(args.config)
    evaluate(config)