RobroKools's picture
Upload 44 files
e59f78e verified
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import os
import numpy as np
from datetime import datetime
import argparse
from celldreamer.models.class_celldreamer import ClassCellDreamer
from celldreamer.models import load_config
def _kl_weight(epoch, total_epochs):
    """Return the linear KL-annealing weight for a 1-indexed ``epoch``.

    The weight ramps up linearly over the first half of training
    (``total_epochs // 2`` epochs) and is capped at 1.0 afterwards. The
    warm-up length is clamped to at least one epoch: with ``total_epochs < 2``
    the unclamped ``total_epochs // 2`` is 0 and the division would raise
    ``ZeroDivisionError``.

    NOTE(review): this is meant to mirror the schedule applied inside
    ``model_wrapper.train_step`` (not visible here) -- keep the two in sync.
    """
    warmup_period = max(1, total_epochs // 2)
    return min(1.0, epoch / warmup_period)


def _train_epoch(model_wrapper, train_loader, device, epoch, args, writer, global_step):
    """Run one training epoch over ``train_loader``.

    Calls ``model_wrapper.train_step`` once per batch. Every
    ``args.log_interval`` global steps, it logs the step losses to ``writer``.

    Returns:
        ``(metrics, global_step)``:
        - ``metrics`` maps ``"total"``, ``"mse"``, ``"dyn_kl"`` and
          ``"post_kl"`` to per-batch lists of the corresponding values from
          the ``train_step`` logs.
        - ``global_step`` is the step counter advanced by the number of
          batches.
    """
    model_wrapper.model.train()
    metrics = {"total": [], "mse": [], "dyn_kl": [], "post_kl": []}
    loop = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [Train]")
    for batch in loop:
        x_t = batch['x_t'].to(device)
        x_next = batch['x_next'].to(device)
        logs = model_wrapper.train_step(x_t, x_next, epoch, args.epochs)
        # 'posterior_kl' is optional in the train_step logs; treat a missing value as 0.
        post_kl = logs.get('posterior_kl', 0)
        metrics["total"].append(logs['loss'])
        metrics["mse"].append(logs['recon_loss'])
        metrics["dyn_kl"].append(logs['dynamics_loss'])
        metrics["post_kl"].append(post_kl)
        global_step += 1
        if global_step % args.log_interval == 0:
            writer.add_scalar("Step/Total_Loss", logs['loss'], global_step)
            writer.add_scalar("Step/Recon_Loss", logs['recon_loss'], global_step)
            writer.add_scalar("Step/Dynamics_KL", logs['dynamics_loss'], global_step)
            writer.add_scalar("Step/Posterior_KL", post_kl, global_step)
        loop.set_postfix(loss=logs['loss'])
    return metrics, global_step


def _validate(model_wrapper, val_loader, device, epoch, total_epochs, effective_kl,
              free_bits_per_dim=0.1):
    """Evaluate the model on ``val_loader`` without gradient tracking.

    Per batch, this computes three terms:
    - the reconstruction MSE of ``outputs["recon_x"]`` against ``x_t``;
    - a dynamics KL between the encoder's (mean, std) for ``x_next`` and the
      model's predicted ``prior_next`` (mean, std);
    - a posterior KL between ``post_mean``/``post_std`` and a zero-mean,
      unit-std reference.

    Both KL terms are clamped from below by
    ``free_bits_per_dim * post_mean.shape[1]`` (presumably the latent size --
    confirm). The total loss is ``mse + effective_kl * (dyn_kl + post_kl)``.
    This is intended to mirror the training objective; the training side
    lives in ``train_step`` and is not visible here.

    Returns:
        dict mapping ``"total"``, ``"mse"``, ``"dyn_kl"`` and ``"post_kl"``
        to per-batch lists of Python floats.
    """
    model = model_wrapper.model
    model.eval()
    metrics = {"total": [], "mse": [], "dyn_kl": [], "post_kl": []}
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{total_epochs} [Val] "):
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)
            outputs = model(x_t)
            target_mean, target_std = model.encoder(x_next)

            recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
            dyn_loss = model_wrapper.get_kl_loss(
                target_mean, target_std,
                outputs["prior_next_mean"], outputs["prior_next_std"],
            )
            post_kl = model_wrapper.get_kl_loss(
                outputs["post_mean"], outputs["post_std"],
                torch.zeros_like(outputs["post_mean"]),
                torch.ones_like(outputs["post_std"]),
            )

            # Free-bits floor: KL terms below this value are not penalised further.
            min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
            post_kl = torch.clamp(post_kl, min=min_kl)
            dyn_loss = torch.clamp(dyn_loss, min=min_kl)

            total_val_loss = recon_loss + (effective_kl * dyn_loss) + (effective_kl * post_kl)
            metrics["total"].append(total_val_loss.item())
            metrics["mse"].append(recon_loss.item())
            metrics["dyn_kl"].append(dyn_loss.item())
            metrics["post_kl"].append(post_kl.item())
    return metrics


def _train_loop(args, device, writer):
    """Do the body of ``train``: load data, build the model, run every epoch.

    Logs to ``writer`` and saves checkpoints under ``args.save_dir``. Does
    not close ``writer``; the caller owns it.
    """
    print(f"Loading datasets from {args.data_path}")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # only point data_path at dataset files from a trusted source.
    train_ds = torch.load(f"{args.data_path}/train.pt", weights_only=False)
    val_ds = torch.load(f"{args.data_path}/val.pt", weights_only=False)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    print(f"Train Size: {len(train_ds)} samples")
    print(f"Val Size: {len(val_ds)} samples")
    print(f"Model: {args.model_type}")
    if args.model_type.lower() != "celldreamer":
        raise ValueError(f"Unknown model type: {args.model_type}")
    model_wrapper = ClassCellDreamer(args)

    global_step = 0
    best_val_loss = float('inf')
    best_val_mse = float('inf')  # Track best validation MSE separately
    for epoch in range(1, args.epochs + 1):
        # --- TRAIN ---
        train_metrics, global_step = _train_epoch(
            model_wrapper, train_loader, device, epoch, args, writer, global_step
        )

        # --- VALIDATION ---
        # The KL weight depends only on the epoch, so compute it once here
        # rather than once per validation batch.
        effective_kl = model_wrapper.kl_scale * _kl_weight(epoch, args.epochs)
        val_metrics = _validate(
            model_wrapper, val_loader, device, epoch, args.epochs, effective_kl
        )

        # --- STATS ---
        train_means = {name: np.mean(values) for name, values in train_metrics.items()}
        val_means = {name: np.mean(values) for name, values in val_metrics.items()}
        avg_train_loss = train_means["total"]
        avg_val_loss = val_means["total"]
        writer.add_scalars("Epoch/MSE", {'Train': train_means["mse"], 'Val': val_means["mse"]}, epoch)
        writer.add_scalars("Epoch/Dynamics_KL", {'Train': train_means["dyn_kl"], 'Val': val_means["dyn_kl"]}, epoch)
        writer.add_scalars("Epoch/Posterior_KL", {'Train': train_means["post_kl"], 'Val': val_means["post_kl"]}, epoch)

        # Weighted KL share of each total, to see how much of the loss the KL terms account for.
        val_kl_contribution = effective_kl * (val_means["dyn_kl"] + val_means["post_kl"])
        train_kl_contribution = effective_kl * (train_means["dyn_kl"] + train_means["post_kl"])
        print(
            f"Stats: Train MSE: {train_means['mse']:.4f} | Val MSE: {val_means['mse']:.4f} | "
            f"Train Dyn KL: {train_means['dyn_kl']:.4f} | Val Dyn KL: {val_means['dyn_kl']:.4f} | "
            f"Train Post KL: {train_means['post_kl']:.4f} | Val Post KL: {val_means['post_kl']:.4f}"
        )
        print(
            f"Loss Breakdown: Train Total: {avg_train_loss:.4f} "
            f"(MSE: {train_means['mse']:.4f} + KL: {train_kl_contribution:.4f}) | "
            f"Val Total: {avg_val_loss:.4f} "
            f"(MSE: {val_means['mse']:.4f} + KL: {val_kl_contribution:.4f}) | "
            f"KL Weight: {effective_kl:.6f}"
        )

        # --- CHECKPOINTS ---
        if epoch % args.save_freq == 0:
            model_wrapper.save(f"{args.save_dir}/last.pth")
        avg_val_mse = val_means["mse"]
        if avg_val_loss < best_val_loss:
            print(f"Best Total Loss: ({best_val_loss:.4f} -> {avg_val_loss:.4f})")
            best_val_loss = avg_val_loss
        # best.pth is selected on validation MSE, not total loss: the total
        # contains the annealed KL weight, so it is not comparable across epochs.
        if avg_val_mse < best_val_mse:
            print(f"Best Val MSE: ({best_val_mse:.4f} -> {avg_val_mse:.4f}) - Saving best model")
            best_val_mse = avg_val_mse
            model_wrapper.save(f"{args.save_dir}/best.pth")


def train(args):
    """Train a CellDreamer model, logging to TensorBoard and saving checkpoints.

    Args:
        args: Loaded configuration object. Reads ``device``, ``save_dir``,
            ``log_dir``, ``run_name``, ``data_path``, ``batch_size``,
            ``model_type``, ``epochs``, ``log_interval`` and ``save_freq``,
            and is passed whole to ``ClassCellDreamer``.

    Side effects:
        - Creates ``save_dir`` and ``log_dir``.
        - Writes TensorBoard events to ``log_dir/<run_name>_<timestamp>``.
        - Saves ``last.pth`` every ``save_freq`` epochs.
        - Saves ``best.pth`` whenever the mean validation MSE improves.

    Raises:
        ValueError: if ``args.model_type`` is not ``"celldreamer"``
            (case-insensitive).
    """
    device = torch.device(args.device)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    writer = SummaryWriter(f"{args.log_dir}/{args.run_name}_{timestamp}")
    # Close (and thereby flush) the event writer even when data loading,
    # training or validation raises.
    try:
        _train_loop(args, device, writer)
    finally:
        writer.close()
if __name__ == "__main__":
    # CLI entry point: the only option is the path to the YAML training config.
    # Everything else comes from that file, is loaded via load_config and is
    # handed to train().
    parser = argparse.ArgumentParser(description="training script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/train_config.yml",
        # %(default)s is expanded by argparse, so the help text cannot drift
        # from the actual default path.
        help="Path to the YAML configuration file (default: %(default)s)",
    )
    args = parser.parse_args()
    config = load_config(args.config)
    train(config)