import argparse
import os
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from celldreamer.models import load_config
from celldreamer.models.class_celldreamer import ClassCellDreamer


def train(args):
    """Train and validate the model described by `args`, logging to TensorBoard
    and checkpointing to args.save_dir."""
    device = torch.device(args.device)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    writer = SummaryWriter(f"{args.log_dir}/{args.run_name}_{timestamp}")

    print(f"Loading datasets from {args.data_path}")
    train_ds = torch.load(f"{args.data_path}/train.pt", weights_only=False)
    val_ds = torch.load(f"{args.data_path}/val.pt", weights_only=False)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    print(f"Train Size: {len(train_ds)} samples")
    print(f"Val Size: {len(val_ds)} samples")
    print(f"Model: {args.model_type}")

    if args.model_type.lower() == "celldreamer":
        model_wrapper = ClassCellDreamer(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    global_step = 0
    best_val_loss = float('inf')
    best_val_mse = float('inf')  # Track best validation MSE separately

    for epoch in range(1, args.epochs + 1):
        # --- TRAIN ---
        model_wrapper.model.train()
        train_mse, train_kl, train_posterior_kl, train_total = [], [], [], []

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [Train]")
        for batch in loop:
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)

            logs = model_wrapper.train_step(x_t, x_next, epoch, args.epochs)

            train_total.append(logs['loss'])
            train_mse.append(logs['recon_loss'])
            train_kl.append(logs['dynamics_loss'])
            train_posterior_kl.append(logs.get('posterior_kl', 0))

            global_step += 1
            if global_step % args.log_interval == 0:
                writer.add_scalar("Step/Total_Loss", logs['loss'], global_step)
                writer.add_scalar("Step/Recon_Loss", logs['recon_loss'], global_step)
                writer.add_scalar("Step/Dynamics_KL", logs['dynamics_loss'], global_step)
                writer.add_scalar("Step/Posterior_KL", logs.get('posterior_kl', 0), global_step)

            loop.set_postfix(loss=logs['loss'])

        # --- VALIDATION ---
        model_wrapper.model.eval()
        val_mse, val_kl, val_posterior_kl, val_total = [], [], [], []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{args.epochs} [Val] "):
                x_t = batch['x_t'].to(device)
                x_next = batch['x_next'].to(device)

                outputs = model_wrapper.model(x_t)
                target_mean, target_std = model_wrapper.model.encoder(x_next)

                recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
                dyn_loss = model_wrapper.get_kl_loss(
                    target_mean, target_std,
                    outputs["prior_next_mean"], outputs["prior_next_std"]
                )

                # Posterior KL against a standard normal, for consistency with training
                zeros = torch.zeros_like(outputs["post_mean"])
                ones = torch.ones_like(outputs["post_std"])
                post_kl = model_wrapper.get_kl_loss(
                    outputs["post_mean"], outputs["post_std"], zeros, ones
                )

                # Apply the same free-bits constraint as in training
                free_bits_per_dim = 0.1
                min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
                post_kl = torch.clamp(post_kl, min=min_kl)
                dyn_loss = torch.clamp(dyn_loss, min=min_kl)

                # Use the same KL warm-up schedule as training
                # (max(1, ...) guards against division by zero when epochs < 2)
                warmup_period = max(1, args.epochs // 2)
                kl_weight = min(1.0, epoch / warmup_period)
                effective_kl = model_wrapper.kl_scale * kl_weight
                total_val_loss = recon_loss + (effective_kl * dyn_loss) + (effective_kl * post_kl)

                val_total.append(total_val_loss.item())
                val_mse.append(recon_loss.item())
                val_kl.append(dyn_loss.item())
                val_posterior_kl.append(post_kl.item())

        # --- STATS ---
        avg_train_loss = np.mean(train_total)
        avg_val_loss = np.mean(val_total)

        writer.add_scalars("Epoch/MSE", {'Train': np.mean(train_mse), 'Val': np.mean(val_mse)}, epoch)
        writer.add_scalars("Epoch/Dynamics_KL", {'Train': np.mean(train_kl), 'Val': np.mean(val_kl)}, epoch)
        writer.add_scalars("Epoch/Posterior_KL", {'Train': np.mean(train_posterior_kl), 'Val': np.mean(val_posterior_kl)}, epoch)

        # Break the loss into MSE and KL contributions to show which term dominates
        # (useful for diagnosing why the validation loss isn't dropping)
        warmup_period = max(1, args.epochs // 2)
        kl_weight = min(1.0, epoch / warmup_period)
        effective_kl = model_wrapper.kl_scale * kl_weight
        val_kl_contribution = effective_kl * (np.mean(val_kl) + np.mean(val_posterior_kl))
        train_kl_contribution = effective_kl * (np.mean(train_kl) + np.mean(train_posterior_kl))

        print(f"Stats: Train MSE: {np.mean(train_mse):.4f} | Val MSE: {np.mean(val_mse):.4f} | "
              f"Train Dyn KL: {np.mean(train_kl):.4f} | Val Dyn KL: {np.mean(val_kl):.4f} | "
              f"Train Post KL: {np.mean(train_posterior_kl):.4f} | Val Post KL: {np.mean(val_posterior_kl):.4f}")
        print(f"Loss Breakdown: Train Total: {avg_train_loss:.4f} "
              f"(MSE: {np.mean(train_mse):.4f} + KL: {train_kl_contribution:.4f}) | "
              f"Val Total: {avg_val_loss:.4f} "
              f"(MSE: {np.mean(val_mse):.4f} + KL: {val_kl_contribution:.4f}) | "
              f"KL Weight: {effective_kl:.6f}")

        if epoch % args.save_freq == 0:
            model_wrapper.save(f"{args.save_dir}/last.pth")

        avg_val_mse = np.mean(val_mse)
        if avg_val_loss < best_val_loss:
            print(f"Best Total Loss: ({best_val_loss:.4f} -> {avg_val_loss:.4f})")
            best_val_loss = avg_val_loss

        # Also track the best validation MSE, which is unaffected by the KL warm-up
        if avg_val_mse < best_val_mse:
            print(f"Best Val MSE: ({best_val_mse:.4f} -> {avg_val_mse:.4f}) - Saving best model")
            best_val_mse = avg_val_mse
            model_wrapper.save(f"{args.save_dir}/best.pth")

    writer.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Training script for CellDreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/train_config.yml",
        help="Path to the YAML configuration file (default: celldreamer/config/train_config.yml)"
    )
    args = parser.parse_args()

    config = load_config(args.config)
    train(config)
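

# ---------------------------------------------------------------------------
# Usage notes (illustrative sketches only, not part of the training code).
#
# The config fields below are inferred from the attributes accessed on `args`
# in train(); the actual schema is whatever celldreamer.models.load_config
# returns, and every value here is a placeholder:
#
#   # train_config.yml (hypothetical example)
#   device: cuda
#   save_dir: checkpoints
#   log_dir: runs
#   run_name: celldreamer
#   data_path: data
#   model_type: celldreamer
#   batch_size: 128
#   epochs: 100
#   log_interval: 50
#   save_freq: 5
#
# Likewise, train.pt / val.pt are assumed to deserialize into a Dataset whose
# items are dicts holding consecutive state pairs, since the loops above read
# batch['x_t'] and batch['x_next']. A minimal sketch of such a dataset:
#
#   class PairDataset(torch.utils.data.Dataset):
#       """Yields {'x_t': state at time t, 'x_next': state at time t + 1}."""
#
#       def __init__(self, states: torch.Tensor):
#           self.states = states  # shape: (num_timepoints, num_features)
#
#       def __len__(self):
#           return len(self.states) - 1
#
#       def __getitem__(self, idx):
#           return {'x_t': self.states[idx], 'x_next': self.states[idx + 1]}
# ---------------------------------------------------------------------------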