# Source: the-well-diffusion / train_jepa.py
# Uploaded by AlexWortega with huggingface_hub (revision 47f8396, verified)
#!/usr/bin/env python3
"""
Training script for Spatial JEPA on The Well datasets.
Usage:
python train_jepa.py --dataset turbulent_radiative_layer_2D --batch_size 16
python train_jepa.py --dataset active_matter --streaming --epochs 50
"""
import argparse
import logging
import math
import os
import time
import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from data_pipeline import create_dataloader, prepare_batch, get_channel_info
from jepa import JEPA
# Keep the root logger at WARNING so chatty third-party libraries stay quiet.
logging.basicConfig(level=logging.WARNING)

# The script gets its own INFO-level logger with a compact timestamped format.
# Propagation is switched off so records are emitted once by the handler below
# and are not passed on to the root logger's handlers.
_handler = logging.StreamHandler()
_handler.setFormatter(
    logging.Formatter(
        fmt="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
    )
)
logger = logging.getLogger("train_jepa")
logger.propagate = False
logger.addHandler(_handler)
logger.setLevel(logging.INFO)
def cosine_lr(step, warmup, total, base_lr, min_lr=1e-6):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    Args:
        step: Current optimizer step (0-based).
        warmup: Number of warmup steps; the rate ramps linearly from 0 to
            ``base_lr`` over this span.
        total: Total number of steps in the schedule.
        base_lr: Peak learning rate reached at the end of warmup.
        min_lr: Floor reached at the end of the cosine decay.

    Returns:
        The learning rate to use at ``step``. For steps at or beyond
        ``total`` the rate stays at ``min_lr``.
    """
    if step < warmup:
        return base_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    # Clamp: without this, steps past `total` wrap around the cosine and the
    # learning rate climbs back up towards base_lr.
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(progress * math.pi))
def cosine_ema(step, total, start=0.996, end=1.0):
    """Return the EMA decay for ``step`` on a cosine ramp from start to end.

    Args:
        step: Current training step.
        total: Number of steps over which the decay ramps.
        start: Decay value at step 0.
        end: Decay value reached at ``total``.

    Returns:
        The decay at ``step``. Steps outside ``[0, total]`` are clamped, so
        the value holds at ``end`` once the schedule is finished.
    """
    progress = step / max(total, 1)
    # Clamp: without this, steps past `total` wrap around the cosine and the
    # decay falls back towards `start`.
    progress = min(max(progress, 0.0), 1.0)
    return end - (end - start) * (1 + math.cos(progress * math.pi)) / 2
def train(args):
    """Train a JEPA model using the options in ``args``.

    Builds the training dataloader and the model, optimizes the online
    encoder and predictor with AdamW, applies per-step cosine schedules to
    the learning rate and the target-network EMA decay, and periodically
    writes checkpoints to ``args.ckpt_dir``.

    Args:
        args: Namespace with the attributes defined by the argument parser
            in ``main()`` (dataset, optimization, checkpoint and logging
            options). ``args.wandb`` is set to ``False`` if wandb logging is
            requested but the package is not installed.

    Raises:
        AssertionError: If the dataset reports different input and output
            channel counts.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Device: {device}")

    # autocast/GradScaler below target CUDA, so only enable mixed precision
    # when a CUDA device is actually in use.
    use_amp = bool(args.amp) and device.type == "cuda"
    if args.amp and not use_amp:
        logger.info("AMP requested but CUDA is unavailable; running in full precision")

    # ---- Data ----
    logger.info(f"Loading dataset: {args.dataset} (streaming={args.streaming})")
    train_loader, train_dataset = create_dataloader(
        dataset_name=args.dataset,
        split="train",
        batch_size=args.batch_size,
        n_steps_input=args.n_input,
        n_steps_output=args.n_output,
        num_workers=args.workers,
        streaming=args.streaming,
        local_path=args.local_path,
    )
    ch_info = get_channel_info(train_dataset)
    logger.info(f"Channel info: {ch_info}")
    c_in = ch_info["input_channels"]
    c_out = ch_info["output_channels"]
    # The model is built with a single channel count, so input and target
    # must agree; mismatches are rejected here rather than padded.
    assert c_in == c_out, (
        f"JEPA expects same input/output channels, got {c_in} vs {c_out}. "
        "Set n_input == n_output or use different architecture."
    )

    # ---- Model ----
    model = JEPA(
        in_channels=c_in,
        latent_channels=args.latent_ch,
        base_ch=args.base_ch,
        pred_hidden=args.pred_hidden,
        ema_decay=args.ema_start,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Trainable parameters: {n_params:,}")

    # ---- Optimizer ----
    # Only the online encoder and predictor are optimized; the target
    # network is updated through model.update_target() instead.
    trainable = list(model.online_encoder.parameters()) + list(model.predictor.parameters())
    optimizer = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=args.wd)
    scaler = GradScaler("cuda", enabled=use_amp)

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    if args.resume:
        if os.path.exists(args.resume):
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only resume from checkpoints you trust.
            ckpt = torch.load(args.resume, map_location=device, weights_only=False)
            model.load_state_dict(ckpt["model"])
            optimizer.load_state_dict(ckpt["optimizer"])
            # A checkpoint written with the scaler disabled holds an empty
            # state, which an enabled GradScaler refuses to load.
            scaler_state = ckpt.get("scaler")
            if scaler_state:
                scaler.load_state_dict(scaler_state)
            start_epoch = ckpt["epoch"] + 1
            global_step = ckpt["global_step"]
            logger.info(f"Resumed from epoch {start_epoch}, step {global_step}")
        else:
            # Previously this case fell through silently.
            logger.warning(
                f"Resume checkpoint not found: {args.resume}; starting from scratch"
            )

    # ---- Training ----
    os.makedirs(args.ckpt_dir, exist_ok=True)
    total_steps = args.epochs * len(train_loader)

    # Import wandb only when requested so it stays an optional dependency.
    if args.wandb:
        try:
            import wandb
        except ImportError:
            logger.warning("wandb is not installed; disabling wandb logging")
            args.wandb = False
        else:
            wandb.init(project="the-well-jepa", config=vars(args))

    logger.info(f"Starting training: {args.epochs} epochs, ~{total_steps} steps")
    for epoch in range(start_epoch, args.epochs):
        model.train()
        epoch_loss = 0.0
        epoch_metrics = {"sim": 0, "var": 0, "cov": 0}
        n_batches = 0
        t0 = time.time()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            # Best effort: a batch that cannot be prepared is skipped and
            # does not advance global_step.
            try:
                x_input, x_target = prepare_batch(batch, device)
            except Exception as e:
                logger.warning(f"Batch error: {e}, skipping")
                continue

            # LR schedule
            lr = cosine_lr(global_step, args.warmup, total_steps, args.lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            # EMA schedule
            ema = cosine_ema(global_step, total_steps, args.ema_start, args.ema_end)
            model.set_ema_decay(ema)

            optimizer.zero_grad(set_to_none=True)
            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_amp):
                loss, metrics = model.compute_loss(x_input, x_target)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(trainable, args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

            # EMA update
            model.update_target()

            # Read the scalar once; each .item() call forces a device sync.
            loss_value = loss.item()
            epoch_loss += loss_value
            for k in epoch_metrics:
                epoch_metrics[k] += metrics[k]
            n_batches += 1
            global_step += 1

            pbar.set_postfix(
                loss=f"{loss_value:.4f}",
                sim=f"{metrics['sim']:.4f}",
                ema=f"{ema:.4f}",
            )
            if args.wandb:
                wandb.log(
                    {
                        "train/loss": loss_value,
                        "train/lr": lr,
                        "train/ema": ema,
                        **{f"train/{k}": v for k, v in metrics.items()},
                    },
                    step=global_step,
                )

        avg_loss = epoch_loss / max(n_batches, 1)
        avg_m = {k: v / max(n_batches, 1) for k, v in epoch_metrics.items()}
        elapsed = time.time() - t0
        logger.info(
            f"Epoch {epoch}: loss={avg_loss:.4f}, sim={avg_m['sim']:.4f}, "
            f"var={avg_m['var']:.4f}, cov={avg_m['cov']:.4f}, "
            f"time={elapsed:.1f}s"
        )

        # Checkpoint every `save_every` epochs and always after the last one.
        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            ckpt_path = os.path.join(args.ckpt_dir, f"jepa_ep{epoch:04d}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),
                    "args": vars(args),
                    "ch_info": ch_info,
                },
                ckpt_path,
            )
            logger.info(f"Saved {ckpt_path}")

    logger.info("Training complete.")
def main():
    """Parse the command-line options and hand them to ``train``."""
    parser = argparse.ArgumentParser(description="Train Spatial JEPA on The Well")

    # Declarative option table: (flag, add_argument keyword arguments).
    # Order is kept so the generated --help output lists options as before.
    option_table = [
        # Data
        ("--dataset", {"default": "turbulent_radiative_layer_2D"}),
        ("--streaming", {"action": "store_true", "default": True}),
        ("--no-streaming", {"dest": "streaming", "action": "store_false"}),
        ("--local_path", {"default": None}),
        ("--batch_size", {"type": int, "default": 16}),
        ("--workers", {"type": int, "default": 0}),
        ("--n_input", {"type": int, "default": 1}),
        ("--n_output", {"type": int, "default": 1}),
        # Model
        ("--latent_ch", {"type": int, "default": 128}),
        ("--base_ch", {"type": int, "default": 32}),
        ("--pred_hidden", {"type": int, "default": 256}),
        # Optimization
        ("--lr", {"type": float, "default": 3e-4}),
        ("--wd", {"type": float, "default": 0.05}),
        ("--warmup", {"type": int, "default": 500}),
        ("--grad_clip", {"type": float, "default": 1.0}),
        ("--amp", {"action": "store_true", "default": True}),
        ("--no-amp", {"dest": "amp", "action": "store_false"}),
        ("--epochs", {"type": int, "default": 100}),
        ("--ema_start", {"type": float, "default": 0.996}),
        ("--ema_end", {"type": float, "default": 1.0}),
        # Checkpointing
        ("--ckpt_dir", {"default": "checkpoints/jepa"}),
        ("--save_every", {"type": int, "default": 5}),
        ("--resume", {"default": None}),
        # Logging
        ("--wandb", {"action": "store_true", "default": False}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    train(parser.parse_args())


if __name__ == "__main__":
    main()