JorgeAV committed on
Commit
da959e9
·
verified ·
1 Parent(s): 3f403f6

feat: add persistent Trackio Space for image logging (space_id + sync)

Browse files
Files changed (1) hide show
  1. train_phase2.py +118 -98
train_phase2.py CHANGED
@@ -19,6 +19,8 @@ Visual diagnostics logged to Trackio (state-of-the-art for JEPA):
19
  9. Cross-Attention Weights in Perceiver — which evidence each query attends to
20
  10. Eigenspectrum Plot — singular value distribution of latent space
21
 
 
 
22
  Usage:
23
  python train_phase2.py --checkpoint checkpoints/hybrid_main_best.pt
24
  python train_phase2.py --epochs 10 --backbone_lr 1e-5
@@ -712,6 +714,8 @@ def main():
712
  parser.add_argument("--max_eval_samples", type=int, default=500)
713
  parser.add_argument("--vis_interval", type=int, default=100)
714
  parser.add_argument("--output_dir", default="./outputs/mrjepa_phase2")
 
 
715
  args = parser.parse_args()
716
 
717
  log.info("Downloading Phase 1 training script...")
@@ -753,9 +757,12 @@ def main():
753
  log.info(f"Device: {device}")
754
  os.makedirs(cfg.output_dir, exist_ok=True)
755
 
 
756
  import trackio
757
  trackio.init(
758
- name=args.run_name, project="MR-JEPA",
 
 
759
  config={
760
  "phase": 2, "epochs": args.epochs,
761
  "core_lr": args.core_lr, "backbone_lr": args.backbone_lr, "text_lr": args.text_lr,
@@ -767,7 +774,7 @@ def main():
767
  "backbone": cfg.backbone, "K": cfg.K, "use_jepa": cfg.use_jepa, "loss_fn": cfg.loss_fn,
768
  }
769
  )
770
- log.info("Trackio initialized with visual diagnostics")
771
 
772
  log.info("Building model...")
773
  model = p1.MRJEPAModel(cfg)
@@ -839,102 +846,115 @@ def main():
839
  amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
840
  trainable = [p for p in model.parameters() if p.requires_grad]
841
 
842
- for epoch in range(cfg.epochs):
843
- model.train()
844
- epoch_losses = defaultdict(list)
845
- epoch_correct = 0
846
- epoch_total = 0
847
- optimizer.zero_grad()
848
-
849
- for batch_idx, batch in enumerate(train_dl):
850
- batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
851
- vis_tok = model.vis(batch["pixel_values"]).float()
852
- txt_tok = model.txt(batch["input_ids"], batch["attention_mask"]).float()
853
-
854
- with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type == "cuda"):
855
- evidence, _, ev_mask = model.evidence(vis_tok, txt_tok, batch["attention_mask"])
856
- if model._use_rollout:
857
- traj, z_final, z_proj = model.rollout(evidence)
858
- else:
859
- B = batch["batch_size"]
860
- z0 = model.rollout.init_tokens.expand(B, -1, -1) + \
861
- model.rollout.z0_proj(F.adaptive_avg_pool1d(
862
- evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1))
863
- z_final = z0
864
- z_proj = model.rollout.out_proj(z0).unsqueeze(1)
865
-
866
- if model._use_jepa:
867
- target_proj = model.target(vis_tok.detach(), txt_tok.detach(), batch["attention_mask"].detach())
868
- else:
869
- target_proj = None
870
-
871
- opt_emb = model.encode_options(batch["opt_input_ids"], batch["opt_attention_mask"])
872
- opt_emb = opt_emb.view(batch["batch_size"], cfg.max_options, -1)
873
- logits = model.disc(z_final, opt_emb, batch["opt_mask"])
874
- task_loss = F.cross_entropy(logits, batch["labels"])
875
-
876
- if model._use_jepa and target_proj is not None:
877
- losses = model.jepa_loss(z_proj, target_proj, task_loss)
878
- else:
879
- losses = {"total": task_loss, "jepa": torch.tensor(0.0), "task": task_loss, "reg": torch.tensor(0.0)}
880
- loss = losses["total"] / cfg.grad_accum
881
-
882
- loss.backward()
883
-
884
- if (batch_idx + 1) % cfg.grad_accum == 0:
885
- nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
886
- optimizer.step(); scheduler.step(); optimizer.zero_grad()
887
- model.update_target(global_step, total_steps)
888
- global_step += 1
889
- if global_step % args.vis_interval == 0 and global_step > 0:
890
- log.info(f"Generating visual diagnostics at step {global_step}...")
891
- log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
892
- diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
893
-
894
- preds = logits.argmax(dim=-1)
895
- for k, v in losses.items():
896
- if isinstance(v, torch.Tensor):
897
- epoch_losses[k].append(v.item())
898
- epoch_correct += (preds == batch["labels"]).sum().item()
899
- epoch_total += batch["batch_size"]
900
-
901
- if batch_idx % 50 == 0:
902
- avg = {k: np.mean(v[-50:]) for k, v in epoch_losses.items()}
903
- acc = epoch_correct / max(epoch_total, 1) * 100
904
- lrs = scheduler.get_last_lr()
905
- log.info(f"P2 E{epoch} B{batch_idx}/{len(train_dl)} | "
906
- f"loss={avg.get('total',0):.4f} jepa={avg.get('jepa',0):.4f} "
907
- f"task={avg.get('task',0):.4f} | acc={acc:.1f}%")
908
- trackio.log({
909
- "train/loss": avg.get("total", 0), "train/jepa_loss": avg.get("jepa", 0),
910
- "train/task_loss": avg.get("task", 0), "train/reg_loss": avg.get("reg", 0),
911
- "train/accuracy": acc, "train/lr": lrs[0] if lrs else 0,
912
- "train/backbone_lr": lrs[1] if len(lrs) > 1 else 0,
913
- "train/text_lr": lrs[2] if len(lrs) > 2 else 0,
914
- "train/ema_momentum": model.target.mom,
915
- "train/epoch": epoch, "train/step": global_step,
916
- })
917
-
918
- eval_acc = p1.evaluate(model, eval_dl, device, cfg)
919
- train_acc = epoch_correct / max(epoch_total, 1) * 100
920
- log.info(f"=== Phase 2 Epoch {epoch} | Train: {train_acc:.1f}% | Eval: {eval_acc:.1f}% ===")
921
- trackio.log({"eval/accuracy": eval_acc, "eval/epoch": epoch,
922
- "eval/train_accuracy": train_acc, "eval/best_accuracy": max(best_acc, eval_acc)})
923
-
924
- log.info(f"Generating epoch-end visual diagnostics...")
925
- diag_batch = next(iter(eval_dl))
926
- diag_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in diag_batch.items()}
927
- log_visual_diagnostics(model, diag_batch, device, cfg, global_step, epoch,
928
- diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
929
-
930
- if eval_acc > best_acc:
931
- best_acc = eval_acc
932
- p1.save_checkpoint(model, cfg, epoch, eval_acc, is_best=True)
933
- log.info(f"New best accuracy: {best_acc:.1f}%")
934
-
935
- log.info(f"Phase 2 complete. Best eval accuracy: {best_acc:.1f}%")
936
- diag_collector.detach()
937
- trackio.log({"final/best_accuracy": best_acc, "final/phase": 2, "final/total_steps": global_step})
 
 
 
 
 
 
 
 
 
 
 
 
 
938
  if cfg.push_to_hub:
939
  p1.push_results(cfg, best_acc)
940
 
 
19
  9. Cross-Attention Weights in Perceiver — which evidence each query attends to
20
  10. Eigenspectrum Plot — singular value distribution of latent space
21
 
22
+ All images are persisted to the HF Space given by --trackio_space (default: JorgeAV/MR-JEPA-Trackio) via Trackio's space_id parameter.
23
+
24
  Usage:
25
  python train_phase2.py --checkpoint checkpoints/hybrid_main_best.pt
26
  python train_phase2.py --epochs 10 --backbone_lr 1e-5
 
714
  parser.add_argument("--max_eval_samples", type=int, default=500)
715
  parser.add_argument("--vis_interval", type=int, default=100)
716
  parser.add_argument("--output_dir", default="./outputs/mrjepa_phase2")
717
+ parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio",
718
+ help="HF Space ID for persistent Trackio dashboard")
719
  args = parser.parse_args()
720
 
721
  log.info("Downloading Phase 1 training script...")
 
757
  log.info(f"Device: {device}")
758
  os.makedirs(cfg.output_dir, exist_ok=True)
759
 
760
+ # ── Initialize Trackio with persistent HF Space ──
761
  import trackio
762
  trackio.init(
763
+ name=args.run_name,
764
+ project="MR-JEPA",
765
+ space_id=args.trackio_space,
766
  config={
767
  "phase": 2, "epochs": args.epochs,
768
  "core_lr": args.core_lr, "backbone_lr": args.backbone_lr, "text_lr": args.text_lr,
 
774
  "backbone": cfg.backbone, "K": cfg.K, "use_jepa": cfg.use_jepa, "loss_fn": cfg.loss_fn,
775
  }
776
  )
777
+ log.info(f"Trackio initialized Space: https://huggingface.co/spaces/{args.trackio_space}")
778
 
779
  log.info("Building model...")
780
  model = p1.MRJEPAModel(cfg)
 
846
  amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
847
  trainable = [p for p in model.parameters() if p.requires_grad]
848
 
849
+ try:
850
+ for epoch in range(cfg.epochs):
851
+ model.train()
852
+ epoch_losses = defaultdict(list)
853
+ epoch_correct = 0
854
+ epoch_total = 0
855
+ optimizer.zero_grad()
856
+
857
+ for batch_idx, batch in enumerate(train_dl):
858
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
859
+ vis_tok = model.vis(batch["pixel_values"]).float()
860
+ txt_tok = model.txt(batch["input_ids"], batch["attention_mask"]).float()
861
+
862
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type == "cuda"):
863
+ evidence, _, ev_mask = model.evidence(vis_tok, txt_tok, batch["attention_mask"])
864
+ if model._use_rollout:
865
+ traj, z_final, z_proj = model.rollout(evidence)
866
+ else:
867
+ B = batch["batch_size"]
868
+ z0 = model.rollout.init_tokens.expand(B, -1, -1) + \
869
+ model.rollout.z0_proj(F.adaptive_avg_pool1d(
870
+ evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1))
871
+ z_final = z0
872
+ z_proj = model.rollout.out_proj(z0).unsqueeze(1)
873
+
874
+ if model._use_jepa:
875
+ target_proj = model.target(vis_tok.detach(), txt_tok.detach(), batch["attention_mask"].detach())
876
+ else:
877
+ target_proj = None
878
+
879
+ opt_emb = model.encode_options(batch["opt_input_ids"], batch["opt_attention_mask"])
880
+ opt_emb = opt_emb.view(batch["batch_size"], cfg.max_options, -1)
881
+ logits = model.disc(z_final, opt_emb, batch["opt_mask"])
882
+ task_loss = F.cross_entropy(logits, batch["labels"])
883
+
884
+ if model._use_jepa and target_proj is not None:
885
+ losses = model.jepa_loss(z_proj, target_proj, task_loss)
886
+ else:
887
+ losses = {"total": task_loss, "jepa": torch.tensor(0.0), "task": task_loss, "reg": torch.tensor(0.0)}
888
+ loss = losses["total"] / cfg.grad_accum
889
+
890
+ loss.backward()
891
+
892
+ if (batch_idx + 1) % cfg.grad_accum == 0:
893
+ nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
894
+ optimizer.step(); scheduler.step(); optimizer.zero_grad()
895
+ model.update_target(global_step, total_steps)
896
+ global_step += 1
897
+ if global_step % args.vis_interval == 0 and global_step > 0:
898
+ log.info(f"Generating visual diagnostics at step {global_step}...")
899
+ log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
900
+ diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
901
+
902
+ preds = logits.argmax(dim=-1)
903
+ for k, v in losses.items():
904
+ if isinstance(v, torch.Tensor):
905
+ epoch_losses[k].append(v.item())
906
+ epoch_correct += (preds == batch["labels"]).sum().item()
907
+ epoch_total += batch["batch_size"]
908
+
909
+ if batch_idx % 50 == 0:
910
+ avg = {k: np.mean(v[-50:]) for k, v in epoch_losses.items()}
911
+ acc = epoch_correct / max(epoch_total, 1) * 100
912
+ lrs = scheduler.get_last_lr()
913
+ log.info(f"P2 E{epoch} B{batch_idx}/{len(train_dl)} | "
914
+ f"loss={avg.get('total',0):.4f} jepa={avg.get('jepa',0):.4f} "
915
+ f"task={avg.get('task',0):.4f} | acc={acc:.1f}%")
916
+ trackio.log({
917
+ "train/loss": avg.get("total", 0), "train/jepa_loss": avg.get("jepa", 0),
918
+ "train/task_loss": avg.get("task", 0), "train/reg_loss": avg.get("reg", 0),
919
+ "train/accuracy": acc, "train/lr": lrs[0] if lrs else 0,
920
+ "train/backbone_lr": lrs[1] if len(lrs) > 1 else 0,
921
+ "train/text_lr": lrs[2] if len(lrs) > 2 else 0,
922
+ "train/ema_momentum": model.target.mom,
923
+ "train/epoch": epoch, "train/step": global_step,
924
+ })
925
+
926
+ eval_acc = p1.evaluate(model, eval_dl, device, cfg)
927
+ train_acc = epoch_correct / max(epoch_total, 1) * 100
928
+ log.info(f"=== Phase 2 Epoch {epoch} | Train: {train_acc:.1f}% | Eval: {eval_acc:.1f}% ===")
929
+ trackio.log({"eval/accuracy": eval_acc, "eval/epoch": epoch,
930
+ "eval/train_accuracy": train_acc, "eval/best_accuracy": max(best_acc, eval_acc)})
931
+
932
+ log.info(f"Generating epoch-end visual diagnostics...")
933
+ diag_batch = next(iter(eval_dl))
934
+ diag_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in diag_batch.items()}
935
+ log_visual_diagnostics(model, diag_batch, device, cfg, global_step, epoch,
936
+ diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
937
+
938
+ if eval_acc > best_acc:
939
+ best_acc = eval_acc
940
+ p1.save_checkpoint(model, cfg, epoch, eval_acc, is_best=True)
941
+ log.info(f"New best accuracy: {best_acc:.1f}%")
942
+
943
+ log.info(f"Phase 2 complete. Best eval accuracy: {best_acc:.1f}%")
944
+
945
+ finally:
946
+ # ── Ensure Trackio data is persisted even if training crashes ──
947
+ diag_collector.detach()
948
+ trackio.log({"final/best_accuracy": best_acc, "final/phase": 2, "final/total_steps": global_step})
949
+ log.info("Finishing Trackio and syncing to Space...")
950
+ trackio.finish()
951
+ # Belt-and-suspenders: explicit sync to ensure all images are uploaded
952
+ try:
953
+ trackio.sync(project="MR-JEPA", space_id=args.trackio_space)
954
+ log.info(f"Trackio synced to https://huggingface.co/spaces/{args.trackio_space}")
955
+ except Exception as e:
956
+ log.warning(f"Trackio sync failed (data may still be available via finish): {e}")
957
+
958
  if cfg.push_to_hub:
959
  p1.push_results(cfg, best_acc)
960