Spaces:
Sleeping
Sleeping
File size: 7,659 Bytes
e59f78e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 | import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import os
import numpy as np
from datetime import datetime
import argparse
from celldreamer.models.class_celldreamer import ClassCellDreamer
from celldreamer.models import load_config
def _kl_weight(epoch, total_epochs):
    """Return the linear KL warm-up factor for a 1-based ``epoch``.

    The factor is ``epoch / (total_epochs // 2)`` capped at 1.0. The warm-up
    length is clamped to at least one epoch: without the clamp, runs with
    ``total_epochs < 2`` divided by zero.
    """
    warmup_period = max(1, total_epochs // 2)
    return min(1.0, epoch / warmup_period)


def _train_epoch(model_wrapper, train_loader, device, writer, epoch, args, global_step):
    """Run one training epoch; return per-batch metric lists and the updated global step."""
    model_wrapper.model.train()
    stats = {"total": [], "mse": [], "dyn_kl": [], "post_kl": []}
    loop = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [Train]")
    for batch in loop:
        x_t = batch['x_t'].to(device)
        x_next = batch['x_next'].to(device)
        logs = model_wrapper.train_step(x_t, x_next, epoch, args.epochs)
        # train_step may not report 'posterior_kl'; a missing value is logged as 0.
        posterior_kl = logs.get('posterior_kl', 0)
        stats["total"].append(logs['loss'])
        stats["mse"].append(logs['recon_loss'])
        stats["dyn_kl"].append(logs['dynamics_loss'])
        stats["post_kl"].append(posterior_kl)
        global_step += 1
        if global_step % args.log_interval == 0:
            writer.add_scalar("Step/Total_Loss", logs['loss'], global_step)
            writer.add_scalar("Step/Recon_Loss", logs['recon_loss'], global_step)
            writer.add_scalar("Step/Dynamics_KL", logs['dynamics_loss'], global_step)
            writer.add_scalar("Step/Posterior_KL", posterior_kl, global_step)
        loop.set_postfix(loss=logs['loss'])
    return stats, global_step


def _validate(model_wrapper, val_loader, device, effective_kl, desc):
    """Evaluate on ``val_loader`` without gradients; return per-batch metric lists.

    The loss is intended to mirror the training objective (free-bits floor on
    both KL terms, KL terms scaled by ``effective_kl``).
    NOTE(review): confirm this still matches ClassCellDreamer.train_step.
    """
    free_bits_per_dim = 0.1
    stats = {"total": [], "mse": [], "dyn_kl": [], "post_kl": []}
    model_wrapper.model.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=desc):
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)
            outputs = model_wrapper.model(x_t)
            # Encode the next frame; its (mean, std) is the target for the dynamics prior.
            target_mean, target_std = model_wrapper.model.encoder(x_next)
            recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
            dyn_loss = model_wrapper.get_kl_loss(
                target_mean, target_std,
                outputs["prior_next_mean"], outputs["prior_next_std"],
            )
            # Posterior KL against zero mean / unit std (a standard-normal prior,
            # assuming get_kl_loss takes (mean, std, mean, std) -- confirm).
            post_kl = model_wrapper.get_kl_loss(
                outputs["post_mean"], outputs["post_std"],
                torch.zeros_like(outputs["post_mean"]),
                torch.ones_like(outputs["post_std"]),
            )
            # Free-bits floor: free_bits_per_dim * post_mean.shape[1]
            # (presumably the latent dimension).
            min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
            post_kl = torch.clamp(post_kl, min=min_kl)
            dyn_loss = torch.clamp(dyn_loss, min=min_kl)
            total_val_loss = recon_loss + (effective_kl * dyn_loss) + (effective_kl * post_kl)
            stats["total"].append(total_val_loss.item())
            stats["mse"].append(recon_loss.item())
            stats["dyn_kl"].append(dyn_loss.item())
            stats["post_kl"].append(post_kl.item())
    return stats


def train(args):
    """Train a CellDreamer model, logging to TensorBoard and saving checkpoints.

    ``args`` is the config object returned by ``load_config``. This function
    reads its ``device``, ``save_dir``, ``log_dir``, ``run_name``,
    ``data_path``, ``batch_size``, ``model_type``, ``epochs``,
    ``log_interval`` and ``save_freq`` attributes.

    Side effects:
        * loads ``{data_path}/train.pt`` and ``{data_path}/val.pt`` with
          ``torch.load``; batches are indexed with ``'x_t'`` and ``'x_next'``;
        * writes TensorBoard events under ``{log_dir}/{run_name}_<timestamp>``;
        * saves ``{save_dir}/last.pth`` every ``save_freq`` epochs, and
          ``{save_dir}/best.pth`` whenever the mean validation MSE improves.

    Raises:
        ValueError: if ``args.model_type`` is not "celldreamer" (case-insensitive).
    """
    device = torch.device(args.device)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

    print(f"Loading datasets from {args.data_path}")
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects -- only point data_path at dataset files you trust.
    train_ds = torch.load(f"{args.data_path}/train.pt", weights_only=False)
    val_ds = torch.load(f"{args.data_path}/val.pt", weights_only=False)
    # Pinned host memory only helps host->CUDA copies; on CPU it just warns.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=4, pin_memory=pin_memory)
    print(f"Train Size: {len(train_ds)} samples")
    print(f"Val Size: {len(val_ds)} samples")
    print(f"Model: {args.model_type}")

    if args.model_type.lower() == "celldreamer":
        model_wrapper = ClassCellDreamer(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    writer = SummaryWriter(f"{args.log_dir}/{args.run_name}_{timestamp}")
    global_step = 0
    best_val_loss = float('inf')
    best_val_mse = float('inf')  # best.pth is selected on validation MSE
    try:
        for epoch in range(1, args.epochs + 1):
            train_stats, global_step = _train_epoch(
                model_wrapper, train_loader, device, writer, epoch, args, global_step
            )
            # Read kl_scale after the training epoch, as the original inline code did.
            effective_kl = model_wrapper.kl_scale * _kl_weight(epoch, args.epochs)
            val_stats = _validate(
                model_wrapper, val_loader, device, effective_kl,
                desc=f"Epoch {epoch}/{args.epochs} [Val] ",
            )

            tr = {key: np.mean(values) for key, values in train_stats.items()}
            va = {key: np.mean(values) for key, values in val_stats.items()}
            avg_train_loss = tr["total"]
            avg_val_loss = va["total"]
            avg_val_mse = va["mse"]

            writer.add_scalars("Epoch/MSE", {'Train': tr["mse"], 'Val': va["mse"]}, epoch)
            writer.add_scalars("Epoch/Dynamics_KL", {'Train': tr["dyn_kl"], 'Val': va["dyn_kl"]}, epoch)
            writer.add_scalars("Epoch/Posterior_KL", {'Train': tr["post_kl"], 'Val': va["post_kl"]}, epoch)
            writer.add_scalars("Epoch/Total_Loss", {'Train': avg_train_loss, 'Val': avg_val_loss}, epoch)

            # KL share of each total, to explain total-loss trends that MSE alone does not.
            val_kl_contribution = effective_kl * (va["dyn_kl"] + va["post_kl"])
            train_kl_contribution = effective_kl * (tr["dyn_kl"] + tr["post_kl"])
            print(
                f"Stats: Train MSE: {tr['mse']:.4f} | Val MSE: {va['mse']:.4f} | "
                f"Train Dyn KL: {tr['dyn_kl']:.4f} | Val Dyn KL: {va['dyn_kl']:.4f} | "
                f"Train Post KL: {tr['post_kl']:.4f} | Val Post KL: {va['post_kl']:.4f}"
            )
            print(
                f"Loss Breakdown: Train Total: {avg_train_loss:.4f} "
                f"(MSE: {tr['mse']:.4f} + KL: {train_kl_contribution:.4f}) | "
                f"Val Total: {avg_val_loss:.4f} "
                f"(MSE: {va['mse']:.4f} + KL: {val_kl_contribution:.4f}) | "
                f"KL Weight: {effective_kl:.6f}"
            )

            if epoch % args.save_freq == 0:
                model_wrapper.save(f"{args.save_dir}/last.pth")
            if avg_val_loss < best_val_loss:
                print(f"Best Total Loss: ({best_val_loss:.4f} -> {avg_val_loss:.4f})")
                best_val_loss = avg_val_loss
            if avg_val_mse < best_val_mse:
                print(f"Best Val MSE: ({best_val_mse:.4f} -> {avg_val_mse:.4f}) - Saving best model")
                best_val_mse = avg_val_mse
                model_wrapper.save(f"{args.save_dir}/best.pth")
    finally:
        # Flush and close event files even if training raises.
        writer.close()
def build_parser():
    """Return the command-line parser for the CellDreamer training script.

    The parser accepts a single ``--config`` option: the path to the YAML
    training configuration, defaulting to ``celldreamer/config/train_config.yml``.
    """
    parser = argparse.ArgumentParser(description="training script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/train_config.yml",
        # argparse expands %(default)s, so the help text cannot drift from the default.
        help="Path to the YAML configuration file (default: %(default)s)",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    config = load_config(args.config)
    train(config)
|