| """
|
| Phase 2-A Toy PoC: 3-way Training (Vision + Audio + Text)
|
| Loss = L_vision(MSE) + L_audio(MSE) + L_text(CE)
|
| """
|
|
|
| import os
|
| import json
|
| import time
|
| import argparse
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.amp import GradScaler, autocast
|
|
|
| from model import TriModalModel, CONFIG
|
| from data import (
|
| VisionDataset, AudioDataset, TextDataset, TriModalDataLoader,
|
| )
|
|
|
|
|
def _build_datasets(args):
    """Build the (vision, audio, text) dataset triple.

    When ``args.synthetic`` is set, returns in-memory random-tensor datasets
    shaped like the real ones (useful for smoke tests); otherwise loads the
    real datasets from the paths in *args*.
    """
    if not args.synthetic:
        v_ds = VisionDataset(path=args.vision_path, n_frames=5,
                             patch_size=CONFIG["patch_size"], mask_ratio=0.5)
        a_ds = AudioDataset(path=args.audio_path, mask_ratio=0.5)
        t_ds = TextDataset(path=args.text_path, seq_len=128,
                           vocab_size=CONFIG["vocab_size"])
        return v_ds, a_ds, t_ds

    from torch.utils.data import Dataset

    class FakeVision(Dataset):
        # 80 patches of dim 256; the first 40 double as the masked targets.
        def __len__(self):
            return 500

        def __getitem__(self, i):
            p = torch.randn(80, 256)
            m = torch.arange(40)
            return {"all_patches": p, "target_patches": p[:40], "mask_idx": m}

    class FakeAudio(Dataset):
        # 200 feature frames of dim 768; the first 100 are the masked targets.
        def __len__(self):
            return 500

        def __getitem__(self, i):
            f = torch.randn(200, 768)
            m = torch.arange(100)
            return {"all_features": f, "target_features": f[:100], "mask_idx": m}

    class FakeText(Dataset):
        # Random (input, target) token-id pairs of length 128.
        def __len__(self):
            return 500

        def __getitem__(self, i):
            return torch.randint(0, 10000, (128,)), torch.randint(0, 10000, (128,))

    return FakeVision(), FakeAudio(), FakeText()


def train(args):
    """Run tri-modal (vision + audio + text) pre-training.

    Total loss = MSE(vision) + MSE(audio) + CE(text). Tracks per-epoch
    averages, keeps the checkpoint with the lowest mean total loss as
    ``best.pt``, snapshots every ``args.save_every`` epochs, and writes a
    JSON training log to ``args.ckpt_dir``.

    Args:
        args: argparse.Namespace with epochs, batch_size, lr, save_every,
            ckpt_dir, synthetic, vision_path, audio_path, text_path.

    Returns:
        list[dict]: per-epoch metrics (per-modality losses, total loss,
        learning rate, wall-clock seconds).

    Raises:
        ValueError: if the data loader yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    v_ds, a_ds, t_ds = _build_datasets(args)
    loader = TriModalDataLoader(v_ds, a_ds, t_ds, batch_size=args.batch_size)
    print(f"Batches/epoch: {len(loader)}")
    if len(loader) == 0:
        # Fail fast with a clear message instead of ZeroDivisionError later.
        raise ValueError("data loader produced no batches; check dataset paths/sizes")

    model = TriModalModel(CONFIG).to(device)
    params = model.count_params()
    print(f"Parameters: {params['total']:,} ({params['total']/1e6:.1f}M)")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs * len(loader), eta_min=1e-6
    )
    # GradScaler is a no-op on CPU (enabled=False). With bf16 autocast it is
    # not strictly required, but it is harmless and keeps an fp16 path open.
    scaler = GradScaler("cuda", enabled=(device.type == "cuda"))

    criterion_vision = nn.MSELoss()
    criterion_audio = nn.MSELoss()
    criterion_text = nn.CrossEntropyLoss()

    os.makedirs(args.ckpt_dir, exist_ok=True)
    log = []
    best_total = float("inf")

    def save_ckpt(filename, epoch):
        # Single definition of the checkpoint payload so the "best" and
        # periodic snapshots can never drift out of sync.
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "config": CONFIG,
            "log": log,
        }, os.path.join(args.ckpt_dir, filename))

    for epoch in range(args.epochs):
        model.train()
        ep_v, ep_a, ep_t, ep_total = 0.0, 0.0, 0.0, 0.0
        n_steps = 0
        t0 = time.time()

        for v_batch, a_batch, t_batch in loader:
            all_patches = v_batch["all_patches"].to(device)
            target_patches = v_batch["target_patches"].to(device)
            v_mask_idx = v_batch["mask_idx"]

            all_features = a_batch["all_features"].to(device)
            target_features = a_batch["target_features"].to(device)
            a_mask_idx = a_batch["mask_idx"]

            t_x, t_y = t_batch[0].to(device), t_batch[1].to(device)

            # Use the actual device type; `enabled` still restricts bf16
            # mixed precision to CUDA, so CPU behavior is unchanged.
            with autocast(device_type=device.type, dtype=torch.bfloat16,
                          enabled=(device.type == "cuda")):
                vision_out, audio_out, text_out = model(
                    all_patches, all_features, t_x
                )

                # NOTE(review): mask_idx[0] applies the FIRST sample's mask to
                # the whole batch — valid only if masks are identical across a
                # batch (true for the synthetic data); confirm for real data.
                v_midx = v_mask_idx[0].to(device)
                loss_v = criterion_vision(vision_out[:, v_midx, :], target_patches)

                a_midx = a_mask_idx[0].to(device)
                loss_a = criterion_audio(audio_out[:, a_midx, :], target_features)

                loss_t = criterion_text(
                    text_out.reshape(-1, CONFIG["vocab_size"]),
                    t_y.reshape(-1)
                )

                # Unweighted sum of the three modality losses.
                loss = loss_v + loss_a + loss_t

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Unscale before clipping so the clip threshold applies to the
            # true gradient magnitudes.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            ep_v += loss_v.item()
            ep_a += loss_a.item()
            ep_t += loss_t.item()
            ep_total += loss.item()
            n_steps += 1

        avg_v = ep_v / n_steps
        avg_a = ep_a / n_steps
        avg_t = ep_t / n_steps
        avg_total = ep_total / n_steps
        elapsed = time.time() - t0
        lr_now = scheduler.get_last_lr()[0]

        log.append({
            "epoch": epoch,
            "loss_vision": avg_v,
            "loss_audio": avg_a,
            "loss_text": avg_t,
            "loss_total": avg_total,
            "lr": lr_now,
            "time_s": elapsed,
        })

        print(f"Epoch {epoch:3d}/{args.epochs} | "
              f"V={avg_v:.4f} A={avg_a:.4f} T={avg_t:.4f} Total={avg_total:.4f} | "
              f"lr={lr_now:.2e} | {elapsed:.1f}s")

        # Best-so-far checkpoint by mean total training loss.
        if avg_total < best_total:
            best_total = avg_total
            save_ckpt("best.pt", epoch)

        # Periodic snapshot, plus always one at the final epoch.
        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            save_ckpt(f"ckpt_epoch{epoch}.pt", epoch)

    with open(os.path.join(args.ckpt_dir, "training_log.json"), "w") as f:
        json.dump(log, f, indent=2)

    print(f"\nTraining done. Best total loss: {best_total:.4f}")
    print(f"Final: V={log[-1]['loss_vision']:.4f} A={log[-1]['loss_audio']:.4f} T={log[-1]['loss_text']:.4f}")

    return log
|
|
|
|
|
def build_parser():
    """Build the CLI argument parser for this training script.

    Extracted from the ``__main__`` guard so defaults and flags can be
    unit-tested without running training.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--save_every", type=int, default=10)
    p.add_argument("--ckpt_dir", type=str, default="checkpoints")
    # Use random in-memory data instead of the real datasets below.
    p.add_argument("--synthetic", action="store_true")
    p.add_argument("--vision_path", type=str,
                   default="../phase1/mnist_test_seq.npy")
    p.add_argument("--audio_path", type=str,
                   default="../phase1b/audio_features/audio_features.npy")
    p.add_argument("--text_path", type=str,
                   default="../phase1/TinyStoriesV2-GPT4-train.txt")
    return p


if __name__ == "__main__":
    train(build_parser().parse_args())
|
|
|