Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import DataLoader | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
| import os | |
| import numpy as np | |
| from datetime import datetime | |
| import argparse | |
| from celldreamer.models.class_celldreamer import ClassCellDreamer | |
| from celldreamer.models import load_config | |
def train(args):
    """Train a CellDreamer model, logging metrics to TensorBoard.

    Args:
        args: Configuration namespace providing (at least) device, save_dir,
            log_dir, run_name, data_path, batch_size, model_type, epochs,
            log_interval, and save_freq.

    Raises:
        ValueError: If args.model_type is not a recognized model name.

    Side effects:
        Creates save/log directories, writes TensorBoard event files, prints
        per-epoch statistics, and saves checkpoints "last.pth" (periodic) and
        "best.pth" (best validation MSE) under args.save_dir.
    """
    device = torch.device(args.device)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    # Timestamped run directory so repeated runs don't overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    writer = SummaryWriter(f"{args.log_dir}/{args.run_name}_{timestamp}")

    print(f"Loading datasets from {args.data_path}")
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load dataset files from trusted sources.
    train_ds = torch.load(f"{args.data_path}/train.pt", weights_only=False)
    val_ds = torch.load(f"{args.data_path}/val.pt", weights_only=False)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    print(f"Train Size: {len(train_ds)} samples")
    print(f"Val Size: {len(val_ds)} samples")
    print(f"Model: {args.model_type}")

    if args.model_type.lower() == "celldreamer":
        model_wrapper = ClassCellDreamer(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    global_step = 0
    best_val_loss = float('inf')
    best_val_mse = float('inf')  # Track best validation MSE separately

    # KL warmup: the KL weight anneals linearly to 1.0 over the first half of
    # training. max(1, ...) prevents ZeroDivisionError when args.epochs == 1.
    warmup_period = max(1, args.epochs // 2)

    for epoch in range(1, args.epochs + 1):
        # --- TRAIN ---
        model_wrapper.model.train()
        train_mse = []
        train_kl = []
        train_posterior_kl = []
        train_total = []
        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [Train]")
        for batch in loop:
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)
            logs = model_wrapper.train_step(x_t, x_next, epoch, args.epochs)
            train_total.append(logs['loss'])
            train_mse.append(logs['recon_loss'])
            train_kl.append(logs['dynamics_loss'])
            # posterior_kl may be absent from older train_step logs; default 0.
            train_posterior_kl.append(logs.get('posterior_kl', 0))
            global_step += 1
            if global_step % args.log_interval == 0:
                writer.add_scalar("Step/Total_Loss", logs['loss'], global_step)
                writer.add_scalar("Step/Recon_Loss", logs['recon_loss'], global_step)
                writer.add_scalar("Step/Dynamics_KL", logs['dynamics_loss'], global_step)
                writer.add_scalar("Step/Posterior_KL", logs.get('posterior_kl', 0), global_step)
            loop.set_postfix(loss=logs['loss'])

        # --- VALIDATION ---
        model_wrapper.model.eval()
        val_mse = []
        val_kl = []
        val_posterior_kl = []
        val_total = []
        # The KL weight is constant within an epoch: compute it once here
        # instead of once per validation batch (and reuse it for the stats
        # section below).
        kl_weight = min(1.0, (epoch / warmup_period))
        effective_kl = model_wrapper.kl_scale * kl_weight
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{args.epochs} [Val] "):
                x_t = batch['x_t'].to(device)
                x_next = batch['x_next'].to(device)
                outputs = model_wrapper.model(x_t)
                target_mean, target_std = model_wrapper.model.encoder(x_next)
                # NOTE(review): reconstruction target is x_t (autoencode the
                # current state), not x_next — confirm this matches train_step.
                recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
                dyn_loss = model_wrapper.get_kl_loss(
                    target_mean, target_std,
                    outputs["prior_next_mean"], outputs["prior_next_std"]
                )
                # Posterior KL against a standard normal, for parity with training.
                zeros = torch.zeros_like(outputs["post_mean"])
                ones = torch.ones_like(outputs["post_std"])
                post_kl = model_wrapper.get_kl_loss(
                    outputs["post_mean"], outputs["post_std"],
                    zeros, ones
                )
                # Apply the same free-bits floor as training: 0.1 nats per
                # latent dimension, so KL terms below the floor don't shrink
                # the loss further.
                free_bits_per_dim = 0.1
                min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
                post_kl = torch.clamp(post_kl, min=min_kl)
                dyn_loss = torch.clamp(dyn_loss, min=min_kl)
                total_val_loss = recon_loss + (effective_kl * dyn_loss) + (effective_kl * post_kl)
                val_total.append(total_val_loss.item())
                val_mse.append(recon_loss.item())
                val_kl.append(dyn_loss.item())
                val_posterior_kl.append(post_kl.item())

        # --- STATS ---
        avg_train_loss = np.mean(train_total)
        avg_val_loss = np.mean(val_total)
        writer.add_scalars("Epoch/MSE", {'Train': np.mean(train_mse), 'Val': np.mean(val_mse)}, epoch)
        writer.add_scalars("Epoch/Dynamics_KL", {'Train': np.mean(train_kl), 'Val': np.mean(val_kl)}, epoch)
        writer.add_scalars("Epoch/Posterior_KL", {'Train': np.mean(train_posterior_kl), 'Val': np.mean(val_posterior_kl)}, epoch)
        # Break the total loss into MSE + weighted-KL parts to understand why
        # the validation loss isn't dropping. effective_kl was computed above.
        val_kl_contribution = effective_kl * (np.mean(val_kl) + np.mean(val_posterior_kl))
        train_kl_contribution = effective_kl * (np.mean(train_kl) + np.mean(train_posterior_kl))
        print(f"Stats: Train MSE: {np.mean(train_mse):.4f} | Val MSE: {np.mean(val_mse):.4f} | Train Dyn KL: {np.mean(train_kl):.4f} | Val Dyn KL: {np.mean(val_kl):.4f} | Train Post KL: {np.mean(train_posterior_kl):.4f} | Val Post KL: {np.mean(val_posterior_kl):.4f}")
        print(f"Loss Breakdown: Train Total: {avg_train_loss:.4f} (MSE: {np.mean(train_mse):.4f} + KL: {train_kl_contribution:.4f}) | Val Total: {avg_val_loss:.4f} (MSE: {np.mean(val_mse):.4f} + KL: {val_kl_contribution:.4f}) | KL Weight: {effective_kl:.6f}")

        # Periodic checkpoint of the most recent weights.
        if epoch % args.save_freq == 0:
            model_wrapper.save(f"{args.save_dir}/last.pth")

        avg_val_mse = np.mean(val_mse)
        if avg_val_loss < best_val_loss:
            print(f"Best Total Loss: ({best_val_loss:.4f} -> {avg_val_loss:.4f})")
            best_val_loss = avg_val_loss
        # Also track best validation MSE (more meaningful metric)
        if avg_val_mse < best_val_mse:
            print(f"Best Val MSE: ({best_val_mse:.4f} -> {avg_val_mse:.4f}) - Saving best model")
            best_val_mse = avg_val_mse
            model_wrapper.save(f"{args.save_dir}/best.pth")

    writer.close()
if __name__ == "__main__":
    # CLI entry point: the only argument is the path to a YAML config file;
    # everything else (device, paths, hyperparameters) comes from that config.
    parser = argparse.ArgumentParser(description="training script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/train_config.yml",
        help="Path to the YAML configuration file (default: celldreamer/config/train_config.yml)"
    )
    args = parser.parse_args()
    config = load_config(args.config)
    train(config)