feat: add persistent Trackio Space for image logging (space_id + sync)
Browse files- train_phase2.py +118 -98
train_phase2.py
CHANGED
|
@@ -19,6 +19,8 @@ Visual diagnostics logged to Trackio (state-of-the-art for JEPA):
|
|
| 19 |
9. Cross-Attention Weights in Perceiver — which evidence each query attends to
|
| 20 |
10. Eigenspectrum Plot — singular value distribution of latent space
|
| 21 |
|
|
|
|
|
|
|
| 22 |
Usage:
|
| 23 |
python train_phase2.py --checkpoint checkpoints/hybrid_main_best.pt
|
| 24 |
python train_phase2.py --epochs 10 --backbone_lr 1e-5
|
|
@@ -712,6 +714,8 @@ def main():
|
|
| 712 |
parser.add_argument("--max_eval_samples", type=int, default=500)
|
| 713 |
parser.add_argument("--vis_interval", type=int, default=100)
|
| 714 |
parser.add_argument("--output_dir", default="./outputs/mrjepa_phase2")
|
|
|
|
|
|
|
| 715 |
args = parser.parse_args()
|
| 716 |
|
| 717 |
log.info("Downloading Phase 1 training script...")
|
|
@@ -753,9 +757,12 @@ def main():
|
|
| 753 |
log.info(f"Device: {device}")
|
| 754 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
| 755 |
|
|
|
|
| 756 |
import trackio
|
| 757 |
trackio.init(
|
| 758 |
-
name=args.run_name,
|
|
|
|
|
|
|
| 759 |
config={
|
| 760 |
"phase": 2, "epochs": args.epochs,
|
| 761 |
"core_lr": args.core_lr, "backbone_lr": args.backbone_lr, "text_lr": args.text_lr,
|
|
@@ -767,7 +774,7 @@ def main():
|
|
| 767 |
"backbone": cfg.backbone, "K": cfg.K, "use_jepa": cfg.use_jepa, "loss_fn": cfg.loss_fn,
|
| 768 |
}
|
| 769 |
)
|
| 770 |
-
log.info("Trackio initialized
|
| 771 |
|
| 772 |
log.info("Building model...")
|
| 773 |
model = p1.MRJEPAModel(cfg)
|
|
@@ -839,102 +846,115 @@ def main():
|
|
| 839 |
amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
|
| 840 |
trainable = [p for p in model.parameters() if p.requires_grad]
|
| 841 |
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
if cfg.push_to_hub:
|
| 939 |
p1.push_results(cfg, best_acc)
|
| 940 |
|
|
|
|
| 19 |
9. Cross-Attention Weights in Perceiver — which evidence each query attends to
|
| 20 |
10. Eigenspectrum Plot — singular value distribution of latent space
|
| 21 |
|
| 22 |
+
All images are persisted to HF Space JorgeAV/MR-JEPA-Trackio via space_id parameter.
|
| 23 |
+
|
| 24 |
Usage:
|
| 25 |
python train_phase2.py --checkpoint checkpoints/hybrid_main_best.pt
|
| 26 |
python train_phase2.py --epochs 10 --backbone_lr 1e-5
|
|
|
|
| 714 |
parser.add_argument("--max_eval_samples", type=int, default=500)
|
| 715 |
parser.add_argument("--vis_interval", type=int, default=100)
|
| 716 |
parser.add_argument("--output_dir", default="./outputs/mrjepa_phase2")
|
| 717 |
+
parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio",
|
| 718 |
+
help="HF Space ID for persistent Trackio dashboard")
|
| 719 |
args = parser.parse_args()
|
| 720 |
|
| 721 |
log.info("Downloading Phase 1 training script...")
|
|
|
|
| 757 |
log.info(f"Device: {device}")
|
| 758 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
| 759 |
|
| 760 |
+
# ── Initialize Trackio with persistent HF Space ──
|
| 761 |
import trackio
|
| 762 |
trackio.init(
|
| 763 |
+
name=args.run_name,
|
| 764 |
+
project="MR-JEPA",
|
| 765 |
+
space_id=args.trackio_space,
|
| 766 |
config={
|
| 767 |
"phase": 2, "epochs": args.epochs,
|
| 768 |
"core_lr": args.core_lr, "backbone_lr": args.backbone_lr, "text_lr": args.text_lr,
|
|
|
|
| 774 |
"backbone": cfg.backbone, "K": cfg.K, "use_jepa": cfg.use_jepa, "loss_fn": cfg.loss_fn,
|
| 775 |
}
|
| 776 |
)
|
| 777 |
+
log.info(f"Trackio initialized → Space: https://huggingface.co/spaces/{args.trackio_space}")
|
| 778 |
|
| 779 |
log.info("Building model...")
|
| 780 |
model = p1.MRJEPAModel(cfg)
|
|
|
|
| 846 |
amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
|
| 847 |
trainable = [p for p in model.parameters() if p.requires_grad]
|
| 848 |
|
| 849 |
+
try:
|
| 850 |
+
for epoch in range(cfg.epochs):
|
| 851 |
+
model.train()
|
| 852 |
+
epoch_losses = defaultdict(list)
|
| 853 |
+
epoch_correct = 0
|
| 854 |
+
epoch_total = 0
|
| 855 |
+
optimizer.zero_grad()
|
| 856 |
+
|
| 857 |
+
for batch_idx, batch in enumerate(train_dl):
|
| 858 |
+
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
| 859 |
+
vis_tok = model.vis(batch["pixel_values"]).float()
|
| 860 |
+
txt_tok = model.txt(batch["input_ids"], batch["attention_mask"]).float()
|
| 861 |
+
|
| 862 |
+
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type == "cuda"):
|
| 863 |
+
evidence, _, ev_mask = model.evidence(vis_tok, txt_tok, batch["attention_mask"])
|
| 864 |
+
if model._use_rollout:
|
| 865 |
+
traj, z_final, z_proj = model.rollout(evidence)
|
| 866 |
+
else:
|
| 867 |
+
B = batch["batch_size"]
|
| 868 |
+
z0 = model.rollout.init_tokens.expand(B, -1, -1) + \
|
| 869 |
+
model.rollout.z0_proj(F.adaptive_avg_pool1d(
|
| 870 |
+
evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1))
|
| 871 |
+
z_final = z0
|
| 872 |
+
z_proj = model.rollout.out_proj(z0).unsqueeze(1)
|
| 873 |
+
|
| 874 |
+
if model._use_jepa:
|
| 875 |
+
target_proj = model.target(vis_tok.detach(), txt_tok.detach(), batch["attention_mask"].detach())
|
| 876 |
+
else:
|
| 877 |
+
target_proj = None
|
| 878 |
+
|
| 879 |
+
opt_emb = model.encode_options(batch["opt_input_ids"], batch["opt_attention_mask"])
|
| 880 |
+
opt_emb = opt_emb.view(batch["batch_size"], cfg.max_options, -1)
|
| 881 |
+
logits = model.disc(z_final, opt_emb, batch["opt_mask"])
|
| 882 |
+
task_loss = F.cross_entropy(logits, batch["labels"])
|
| 883 |
+
|
| 884 |
+
if model._use_jepa and target_proj is not None:
|
| 885 |
+
losses = model.jepa_loss(z_proj, target_proj, task_loss)
|
| 886 |
+
else:
|
| 887 |
+
losses = {"total": task_loss, "jepa": torch.tensor(0.0), "task": task_loss, "reg": torch.tensor(0.0)}
|
| 888 |
+
loss = losses["total"] / cfg.grad_accum
|
| 889 |
+
|
| 890 |
+
loss.backward()
|
| 891 |
+
|
| 892 |
+
if (batch_idx + 1) % cfg.grad_accum == 0:
|
| 893 |
+
nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
|
| 894 |
+
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
| 895 |
+
model.update_target(global_step, total_steps)
|
| 896 |
+
global_step += 1
|
| 897 |
+
if global_step % args.vis_interval == 0 and global_step > 0:
|
| 898 |
+
log.info(f"Generating visual diagnostics at step {global_step}...")
|
| 899 |
+
log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
|
| 900 |
+
diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
|
| 901 |
+
|
| 902 |
+
preds = logits.argmax(dim=-1)
|
| 903 |
+
for k, v in losses.items():
|
| 904 |
+
if isinstance(v, torch.Tensor):
|
| 905 |
+
epoch_losses[k].append(v.item())
|
| 906 |
+
epoch_correct += (preds == batch["labels"]).sum().item()
|
| 907 |
+
epoch_total += batch["batch_size"]
|
| 908 |
+
|
| 909 |
+
if batch_idx % 50 == 0:
|
| 910 |
+
avg = {k: np.mean(v[-50:]) for k, v in epoch_losses.items()}
|
| 911 |
+
acc = epoch_correct / max(epoch_total, 1) * 100
|
| 912 |
+
lrs = scheduler.get_last_lr()
|
| 913 |
+
log.info(f"P2 E{epoch} B{batch_idx}/{len(train_dl)} | "
|
| 914 |
+
f"loss={avg.get('total',0):.4f} jepa={avg.get('jepa',0):.4f} "
|
| 915 |
+
f"task={avg.get('task',0):.4f} | acc={acc:.1f}%")
|
| 916 |
+
trackio.log({
|
| 917 |
+
"train/loss": avg.get("total", 0), "train/jepa_loss": avg.get("jepa", 0),
|
| 918 |
+
"train/task_loss": avg.get("task", 0), "train/reg_loss": avg.get("reg", 0),
|
| 919 |
+
"train/accuracy": acc, "train/lr": lrs[0] if lrs else 0,
|
| 920 |
+
"train/backbone_lr": lrs[1] if len(lrs) > 1 else 0,
|
| 921 |
+
"train/text_lr": lrs[2] if len(lrs) > 2 else 0,
|
| 922 |
+
"train/ema_momentum": model.target.mom,
|
| 923 |
+
"train/epoch": epoch, "train/step": global_step,
|
| 924 |
+
})
|
| 925 |
+
|
| 926 |
+
eval_acc = p1.evaluate(model, eval_dl, device, cfg)
|
| 927 |
+
train_acc = epoch_correct / max(epoch_total, 1) * 100
|
| 928 |
+
log.info(f"=== Phase 2 Epoch {epoch} | Train: {train_acc:.1f}% | Eval: {eval_acc:.1f}% ===")
|
| 929 |
+
trackio.log({"eval/accuracy": eval_acc, "eval/epoch": epoch,
|
| 930 |
+
"eval/train_accuracy": train_acc, "eval/best_accuracy": max(best_acc, eval_acc)})
|
| 931 |
+
|
| 932 |
+
log.info(f"Generating epoch-end visual diagnostics...")
|
| 933 |
+
diag_batch = next(iter(eval_dl))
|
| 934 |
+
diag_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in diag_batch.items()}
|
| 935 |
+
log_visual_diagnostics(model, diag_batch, device, cfg, global_step, epoch,
|
| 936 |
+
diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
|
| 937 |
+
|
| 938 |
+
if eval_acc > best_acc:
|
| 939 |
+
best_acc = eval_acc
|
| 940 |
+
p1.save_checkpoint(model, cfg, epoch, eval_acc, is_best=True)
|
| 941 |
+
log.info(f"New best accuracy: {best_acc:.1f}%")
|
| 942 |
+
|
| 943 |
+
log.info(f"Phase 2 complete. Best eval accuracy: {best_acc:.1f}%")
|
| 944 |
+
|
| 945 |
+
finally:
|
| 946 |
+
# ── Ensure Trackio data is persisted even if training crashes ──
|
| 947 |
+
diag_collector.detach()
|
| 948 |
+
trackio.log({"final/best_accuracy": best_acc, "final/phase": 2, "final/total_steps": global_step})
|
| 949 |
+
log.info("Finishing Trackio and syncing to Space...")
|
| 950 |
+
trackio.finish()
|
| 951 |
+
# Belt-and-suspenders: explicit sync to ensure all images are uploaded
|
| 952 |
+
try:
|
| 953 |
+
trackio.sync(project="MR-JEPA", space_id=args.trackio_space)
|
| 954 |
+
log.info(f"Trackio synced to https://huggingface.co/spaces/{args.trackio_space}")
|
| 955 |
+
except Exception as e:
|
| 956 |
+
log.warning(f"Trackio sync failed (data may still be available via finish): {e}")
|
| 957 |
+
|
| 958 |
if cfg.push_to_hub:
|
| 959 |
p1.push_results(cfg, best_acc)
|
| 960 |
|