"""
Phase 2-A Toy PoC: 3-way Training (Vision + Audio + Text)
Loss = L_vision(MSE) + L_audio(MSE) + L_text(CE)
"""
import os
import json
import time
import argparse
import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from model import TriModalModel, CONFIG
from data import (
VisionDataset, AudioDataset, TextDataset, TriModalDataLoader,
)
def train(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Data
if args.synthetic:
from torch.utils.data import Dataset
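        # Synthetic stand-ins that mirror the real datasets' item formats, sized for a
        # quick smoke test with no data on disk. Note the text "targets" are independent
        # random tokens, so the CE loss only exercises the plumbing and will not
        # meaningfully decrease.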
class FakeVision(Dataset):
def __len__(self): return 500
def __getitem__(self, i):
p = torch.randn(80, 256)
m = torch.arange(40)
return {"all_patches": p, "target_patches": p[:40], "mask_idx": m}
class FakeAudio(Dataset):
def __len__(self): return 500
def __getitem__(self, i):
f = torch.randn(200, 768)
m = torch.arange(100)
return {"all_features": f, "target_features": f[:100], "mask_idx": m}
class FakeText(Dataset):
def __len__(self): return 500
def __getitem__(self, i):
return torch.randint(0, 10000, (128,)), torch.randint(0, 10000, (128,))
v_ds = FakeVision()
a_ds = FakeAudio()
t_ds = FakeText()
else:
v_ds = VisionDataset(path=args.vision_path, n_frames=5,
patch_size=CONFIG["patch_size"], mask_ratio=0.5)
a_ds = AudioDataset(path=args.audio_path, mask_ratio=0.5)
t_ds = TextDataset(path=args.text_path, seq_len=128,
vocab_size=CONFIG["vocab_size"])
loader = TriModalDataLoader(v_ds, a_ds, t_ds, batch_size=args.batch_size)
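    # TriModalDataLoader yields one (vision, audio, text) batch triple per step,
    # so len(loader) below is also the number of optimizer steps per epoch.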
print(f"Batches/epoch: {len(loader)}")
# Model
model = TriModalModel(CONFIG).to(device)
params = model.count_params()
print(f"Parameters: {params['total']:,} ({params['total']/1e6:.1f}M)")
# Optimizer & Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
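    # The cosine schedule is measured in optimizer steps (scheduler.step() runs
    # once per batch in the loop below), hence T_max = epochs * batches-per-epoch.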
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs * len(loader), eta_min=1e-6
)
scaler = GradScaler("cuda", enabled=(device.type == "cuda"))
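    # Note: loss scaling exists mainly for float16; with the bfloat16 autocast used
    # below it is effectively redundant (bf16 keeps float32's exponent range), but
    # keeping the scaler path makes switching to fp16 a one-line change.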
# Loss
criterion_vision = nn.MSELoss()
criterion_audio = nn.MSELoss()
criterion_text = nn.CrossEntropyLoss()
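    # Per the module docstring, the three losses are summed with equal weights;
    # per-modality loss weights would slot in where they are combined in the loop.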
# Checkpoint dir
os.makedirs(args.ckpt_dir, exist_ok=True)
log = []
best_total = float("inf")
for epoch in range(args.epochs):
model.train()
ep_v, ep_a, ep_t, ep_total = 0.0, 0.0, 0.0, 0.0
n_steps = 0
t0 = time.time()
for v_batch, a_batch, t_batch in loader:
# Vision
all_patches = v_batch["all_patches"].to(device)
target_patches = v_batch["target_patches"].to(device)
v_mask_idx = v_batch["mask_idx"]
# Audio
all_features = a_batch["all_features"].to(device)
target_features = a_batch["target_features"].to(device)
a_mask_idx = a_batch["mask_idx"]
# Text
t_x, t_y = t_batch[0].to(device), t_batch[1].to(device)
with autocast(device_type="cuda", dtype=torch.bfloat16,
enabled=(device.type == "cuda")):
vision_out, audio_out, text_out = model(
all_patches, all_features, t_x
)
            # Vision JEPA loss (masked patches only). mask_idx[0] is applied to the
            # whole batch, which assumes every sample shares the same mask indices.
            v_midx = v_mask_idx[0].to(device)
            loss_v = criterion_vision(vision_out[:, v_midx, :], target_patches)
            # Audio JEPA loss (masked tokens only); same shared-mask assumption as above
            a_midx = a_mask_idx[0].to(device)
            loss_a = criterion_audio(audio_out[:, a_midx, :], target_features)
# Text LLM loss
loss_t = criterion_text(
text_out.reshape(-1, CONFIG["vocab_size"]),
t_y.reshape(-1)
)
loss = loss_v + loss_a + loss_t
optimizer.zero_grad()
scaler.scale(loss).backward()
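            # Unscale before clipping so the max-norm threshold applies to the true
            # gradient magnitudes, not the scaled ones.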
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
scheduler.step()
ep_v += loss_v.item()
ep_a += loss_a.item()
ep_t += loss_t.item()
ep_total += loss.item()
n_steps += 1
# Epoch stats
avg_v = ep_v / n_steps
avg_a = ep_a / n_steps
avg_t = ep_t / n_steps
avg_total = ep_total / n_steps
elapsed = time.time() - t0
lr_now = scheduler.get_last_lr()[0]
entry = {
"epoch": epoch,
"loss_vision": avg_v,
"loss_audio": avg_a,
"loss_text": avg_t,
"loss_total": avg_total,
"lr": lr_now,
"time_s": elapsed,
}
log.append(entry)
print(f"Epoch {epoch:3d}/{args.epochs} | "
f"V={avg_v:.4f} A={avg_a:.4f} T={avg_t:.4f} Total={avg_total:.4f} | "
f"lr={lr_now:.2e} | {elapsed:.1f}s")
        # Checkpoint payload shared by both save paths below
        state = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "config": CONFIG,
            "log": log,
        }
        # Save best (by training loss; this toy PoC has no validation split)
        if avg_total < best_total:
            best_total = avg_total
            torch.save(state, os.path.join(args.ckpt_dir, "best.pt"))
        # Save periodic
        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            torch.save(state, os.path.join(args.ckpt_dir, f"ckpt_epoch{epoch}.pt"))
# Save log
with open(os.path.join(args.ckpt_dir, "training_log.json"), "w") as f:
json.dump(log, f, indent=2)
print(f"\nTraining done. Best total loss: {best_total:.4f}")
print(f"Final: V={log[-1]['loss_vision']:.4f} A={log[-1]['loss_audio']:.4f} T={log[-1]['loss_text']:.4f}")
return log
if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("--epochs", type=int, default=50)
p.add_argument("--batch_size", type=int, default=32)
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--save_every", type=int, default=10)
p.add_argument("--ckpt_dir", type=str, default="checkpoints")
p.add_argument("--synthetic", action="store_true")
p.add_argument("--vision_path", type=str,
default="../phase1/mnist_test_seq.npy")
p.add_argument("--audio_path", type=str,
default="../phase1b/audio_features/audio_features.npy")
p.add_argument("--text_path", type=str,
default="../phase1/TinyStoriesV2-GPT4-train.txt")
args = p.parse_args()
train(args)
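# Example smoke test on synthetic data:
#   python train.py --synthetic --epochs 2 --batch_size 8
# Full run with the default data paths:
#   python train.py --epochs 50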