"""
hf_uploader.py
--------------
HuggingFace Hub uploader for training runs.

Conventions:
  - Each fresh training launch gets a monotonically increasing folder on the
    hub: run_1/, run_2/, ..., run_N/
  - Resuming a run (train.py --resume_from ...) re-uses the previous run_id
    (read from a local state file) so all subsequent uploads land in the same
    folder.
  - Run contents (per run_N):
        run_N/
            stage1/stage1_final.pt
            stage2/stage2_final.pt
            stage2/checkpoint-<step>/...  (optional, if upload_intermediate=True)
            results/predictions_*.json
            results/metrics_summary.json
            meta.json                     (start time, resume count, config snapshot)

NOTE(review): `hydrate_run_dir_from_hf` below reads `stage*/last/` and
`stage*/best/` sub-folders; confirm which layout the upload callbacks currently
produce and keep this summary in sync.
"""

import contextlib
import json
import os
import re
import shutil
import tempfile
import time
from pathlib import Path
from typing import List, Optional

try:
    from huggingface_hub import HfApi, create_repo
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

# Matches top-level run folders in a repo listing ("run_7/..." or a bare "run_7").
_RUN_DIR_RE = re.compile(r"^run_(\d+)(?:/|$)")


class HFRunTracker:
    """Determine the current run_id and upload artifacts under ``<run_id>/``.

    All upload helpers are best-effort: failures are printed as warnings and
    never raised, so a hub outage or read-only token cannot crash training.
    """

    def __init__(
        self,
        repo_id: str,
        token: Optional[str] = None,
        state_file: str = "checkpoints/run_id.txt",
        resuming: bool = False,
        explicit_run_id: Optional[str] = None,
        private: bool = True,
    ):
        """
        Args:
            repo_id: Hub repo, e.g. 'username/cxr-vlm-runs'. Created as a model
                repo if it does not exist (creation failure is only logged).
            token: Hub token; falls back to the HF_TOKEN environment variable.
            state_file: Local file that persists the chosen run_id.
            resuming: True when continuing a previous training launch.
            explicit_run_id: Force this run_id (also written to state_file).
            private: Visibility passed to create_repo.

        Raises:
            ImportError: huggingface_hub is not installed.
            ValueError: repo_id is empty.
            RuntimeError: resuming, but no run_id exists locally or on the hub.
        """
        if not HF_AVAILABLE:
            raise ImportError("huggingface_hub not installed. pip install huggingface_hub")
        if not repo_id:
            raise ValueError("repo_id is required (e.g. 'username/cxr-vlm-runs')")

        self.repo_id = repo_id
        self.token = token or os.environ.get("HF_TOKEN")
        self.state_file = Path(state_file)
        self.api = HfApi(token=self.token)

        # Make sure the repo exists; any failure here is logged, not raised.
        try:
            create_repo(
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model",
                private=private,
                exist_ok=True,
            )
        except Exception as e:
            print(f"[HFRunTracker] warn: create_repo: {e}")

        self.run_id = self._resolve_run_id(resuming, explicit_run_id)
        print(f"[HFRunTracker] using run_id = {self.run_id}")

    # ── run_id resolution ──────────────────────────────────────────────────

    def _resolve_run_id(self, resuming: bool, explicit: Optional[str]) -> str:
        """Pick the run_id for this session.

        Priority:
          1. ``explicit`` (persisted to the state file).
          2. A non-blank run_id in the local state file. This is honored both
             when resuming and on a fresh session (the user may have reset the
             kernel but wants to continue the same run).
          3. When resuming without local state: the highest ``run_N`` on the hub.
          4. Otherwise: ``run_{max(N) + 1}``, or ``run_1`` for an empty repo.

        Raises:
            RuntimeError: resuming, but neither local state nor hub runs exist.
        """
        if explicit:
            self._write_state(explicit)
            return explicit

        # A blank state file is treated as missing; returning "" would make
        # every remote path start with "/" (e.g. "/meta.json").
        local = self._read_state()

        if resuming:
            if local:
                return local
            runs = self._list_remote_runs()
            if runs:
                run_id = f"run_{max(runs)}"
                self._write_state(run_id)
                return run_id
            raise RuntimeError(
                "Resuming but no run_id.txt locally and no runs on HF hub. "
                "Pass --run_id explicitly."
            )

        if local:
            print(f"[HFRunTracker] resuming via local state file: {local}")
            return local

        # Fresh session: pick the next run number from the hub.
        runs = self._list_remote_runs()
        run_id = f"run_{max(runs) + 1 if runs else 1}"
        self._write_state(run_id)
        return run_id

    def _read_state(self) -> Optional[str]:
        """Return the run_id stored in the state file, or None if missing/blank."""
        try:
            run_id = self.state_file.read_text().strip()
        except FileNotFoundError:
            return None
        return run_id or None

    def _list_remote_runs(self) -> List[int]:
        """Return the sorted, de-duplicated N of every top-level ``run_N`` on
        the hub, or ``[]`` if the repo cannot be listed."""
        try:
            files = self.api.list_repo_files(self.repo_id, token=self.token)
        except Exception as e:
            print(f"[HFRunTracker] list_repo_files: {e} → assuming empty repo")
            return []
        nums = set()
        for f in files:
            m = _RUN_DIR_RE.match(f)
            if m:
                nums.add(int(m.group(1)))
        return sorted(nums)

    def _write_state(self, run_id: str) -> None:
        """Persist ``run_id`` to the state file, creating parent directories."""
        self.state_file.parent.mkdir(parents=True, exist_ok=True)
        self.state_file.write_text(run_id)

    # ── upload helpers ─────────────────────────────────────────────────────
    # All upload methods swallow exceptions and print a warning — training
    # must never crash because the hub is unreachable / token is read-only.

    def _remote_path(self, remote_subpath: str) -> str:
        """Map a run-relative path to its path in the repo."""
        return f"{self.run_id}/{remote_subpath}"

    def upload_file(self, local_path: str, remote_subpath: str):
        """Upload one local file to ``<run_id>/<remote_subpath>``.

        A missing local file is skipped with a message; upload errors are
        printed as warnings.
        """
        local_path = Path(local_path)
        if not local_path.exists():
            print(f"[HFRunTracker] skip upload (missing): {local_path}")
            return
        path_in_repo = self._remote_path(remote_subpath)
        print(f"[HFRunTracker] ↑ {local_path} → {path_in_repo}")
        try:
            self.api.upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=path_in_repo,
                repo_id=self.repo_id,
                token=self.token,
            )
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_file failed ({type(e).__name__}): {e}")

    def upload_folder(self, local_folder: str, remote_subpath: str,
                      allow_patterns=None, ignore_patterns=None):
        """Upload a local folder to ``<run_id>/<remote_subpath>``.

        ``allow_patterns`` / ``ignore_patterns`` are passed through to
        ``HfApi.upload_folder``. A missing folder is skipped with a message.
        """
        local_folder = Path(local_folder)
        if not local_folder.exists():
            print(f"[HFRunTracker] skip upload_folder (missing): {local_folder}")
            return
        path_in_repo = self._remote_path(remote_subpath)
        print(f"[HFRunTracker] ↑ folder {local_folder} → {path_in_repo}")
        try:
            self.api.upload_folder(
                folder_path=str(local_folder),
                path_in_repo=path_in_repo,
                repo_id=self.repo_id,
                token=self.token,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
            )
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_folder failed ({type(e).__name__}): {e}")

    def delete_remote(self, remote_subpath: str):
        """Best-effort delete of a remote folder (e.g. run_N/stage1/last).

        Used to clear stale files before re-uploading 'last' / 'best' so no
        orphan files from a previous step linger. Never raises; failures of
        the per-file fallback are printed as warnings.
        """
        path_in_repo = self._remote_path(remote_subpath)
        try:
            self.api.delete_folder(
                path_in_repo=path_in_repo,
                repo_id=self.repo_id,
                token=self.token,
            )
            return
        except Exception:
            # Folder may not exist yet, or an older hub client lacks
            # delete_folder: fall back to per-file delete below.
            pass

        try:
            files = self.api.list_repo_files(self.repo_id, token=self.token)
        except Exception as e:
            print(f"[HFRunTracker] WARN delete_remote: cannot list repo "
                  f"({type(e).__name__}): {e}")
            return
        prefix = path_in_repo.rstrip("/") + "/"
        for f in files:
            if not f.startswith(prefix):
                continue
            try:
                self.api.delete_file(
                    path_in_repo=f,
                    repo_id=self.repo_id,
                    token=self.token,
                )
            except Exception as e:
                print(f"[HFRunTracker] WARN delete_remote: {f} "
                      f"({type(e).__name__}): {e}")

    def _upload_text(self, text: str, suffix: str, remote_subpath: str) -> None:
        """Write ``text`` to a UTF-8 temp file, upload it, and always remove
        the temp file afterwards (even if writing or uploading fails)."""
        tmp = tempfile.NamedTemporaryFile(
            "w", suffix=suffix, delete=False, encoding="utf-8"
        )
        try:
            with tmp:
                tmp.write(text)
            self.upload_file(tmp.name, remote_subpath)  # logs its own failures
        finally:
            with contextlib.suppress(OSError):
                os.unlink(tmp.name)

    def upload_jsonl(self, lines, remote_subpath: str):
        """Replace a remote .jsonl file with the given list of dict entries.

        One JSON object per line; values json cannot encode are written via
        ``str()``.
        """
        try:
            text = "".join(json.dumps(entry, default=str) + "\n" for entry in lines)
            self._upload_text(text, ".jsonl", remote_subpath)
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_jsonl failed ({type(e).__name__}): {e}")

    def upload_json(self, obj: dict, remote_subpath: str):
        """Upload a small dict as a JSON file (used for best_meta)."""
        try:
            self._upload_text(json.dumps(obj, indent=2, default=str), ".json", remote_subpath)
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_json failed ({type(e).__name__}): {e}")

    def write_meta(self, meta: dict, remote_subpath: str = "meta.json"):
        """Merge ``meta`` into the run's remote meta.json and re-upload it.

        Reads the existing remote file if present (a non-dict payload is
        ignored). Keys in ``meta`` override existing ones. ``created_at`` is
        set on the first write only. ``last_updated`` is refreshed on every
        call. ``resume_count`` is incremented whenever an existing meta.json
        was found. Any failure (network, permission) is logged but does not
        raise — so training never crashes because of a hub glitch.
        """
        existing = {}
        try:
            path = self.api.hf_hub_download(
                repo_id=self.repo_id,
                filename=self._remote_path(remote_subpath),
                token=self.token,
            )
            loaded = json.loads(Path(path).read_text(encoding="utf-8"))
            if isinstance(loaded, dict):
                existing = loaded
        except Exception:
            # e.g. the file does not exist yet on the first write of a run.
            pass

        try:
            now = time.time()
            merged = {**existing, **meta}
            merged.setdefault("created_at", now)
            merged["last_updated"] = now
            # NOTE(review): this increments on every call that finds an existing
            # meta.json, not strictly once per resume — confirm callers invoke
            # write_meta once per session.
            merged["resume_count"] = existing.get("resume_count", 0) + (1 if existing else 0)
            self._upload_text(json.dumps(merged, indent=2, default=str), ".json", remote_subpath)
        except Exception as e:
            print(f"[HFRunTracker] WARN write_meta failed ({type(e).__name__}): {e}")


def pull_last_for_resume(
    repo_id: str,
    token: Optional[str],
    run_id: str,
    stage_subdir: str,
    local_root: str = "checkpoints/_resume_from_hf",
) -> Optional[str]:
    """
    Download ``<run_id>/<stage_subdir>/last/`` from the hub into a local folder
    that can be passed straight to `trainer.train(resume_from_checkpoint=...)`.

    Files are mirrored under ``local_root`` with their repo paths, so the
    result lives at ``<local_root>/<run_id>/<stage_subdir>/last``.

    Returns the local path, or None on failure / when nothing was downloaded.
    """
    if not HF_AVAILABLE:
        print("[hf_uploader] huggingface_hub not installed — cannot pull resume state")
        return None
    from huggingface_hub import snapshot_download

    token = token or os.environ.get("HF_TOKEN")
    mirror_root = Path(local_root)
    last_dir = mirror_root / run_id / stage_subdir / "last"
    last_dir.parent.mkdir(parents=True, exist_ok=True)
    try:
        snapshot_download(
            repo_id=repo_id,
            token=token,
            allow_patterns=[f"{run_id}/{stage_subdir}/last/**"],
            # local_dir is the repo-root mirror: repo paths are recreated under
            # it. Using local_root directly stays correct even when
            # stage_subdir spans several path components.
            local_dir=str(mirror_root),
        )
    except Exception as e:
        print(f"[hf_uploader] WARN pull_last_for_resume: {e}")
        return None

    if not last_dir.is_dir() or not any(last_dir.iterdir()):
        print(f"[hf_uploader] no files pulled into {last_dir}")
        return None
    print(f"[hf_uploader] pulled resume state → {last_dir}")
    return str(last_dir)


def _has_checkpoint_dirs(d: Path) -> bool:
    """True if ``d`` is a directory containing any ``checkpoint-*`` entry."""
    return d.is_dir() and any(d.glob("checkpoint-*"))


def _copy_nonempty_dir(src: Path, dst: Path) -> bool:
    """Merge-copy a non-empty directory ``src`` into ``dst``; True if copied."""
    if not (src.is_dir() and any(src.iterdir())):
        return False
    dst.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return True


def _place_hf_artifacts(src_root: Path, dst_root: Path,
                        stage1_subdir: str, stage2_subdir: str) -> bool:
    """Map a pulled ``<run_id>/`` HF mirror onto the local run-dir layout.

    Returns True if anything (including configs / meta files) was placed.
    """
    dst_root.mkdir(parents=True, exist_ok=True)
    placed_any = False

    # configs/, run_meta.json, timing.json, meta.json: straight copy.
    configs_src = src_root / "configs"
    if configs_src.is_dir():
        shutil.copytree(configs_src, dst_root / "configs", dirs_exist_ok=True)
        placed_any = True
    for name in ("run_meta.json", "timing.json", "meta.json"):
        src = src_root / name
        if src.is_file():
            shutil.copy2(src, dst_root / name)
            placed_any = True

    # Stage 2 last → checkpoint-1 (placeholder N; Trainer reads the real
    # global_step from trainer_state.json inside).
    # NOTE(review): stage2/best/ is downloaded but not mapped to anything here —
    # confirm whether it should become stage2_final_*.
    s2_dst = dst_root / stage2_subdir / "checkpoint-1"
    if _copy_nonempty_dir(src_root / "stage2" / "last", s2_dst):
        placed_any = True
        print(f"[hydrate_run_dir_from_hf] stage2 mid-resume placed at {s2_dst}")

    # Stage 1 best (= final) → stage1_final_*: rename the "checkpoint_" prefix
    # so save_checkpoint naming matches what the rest of the pipeline expects.
    s1_best_src = src_root / "stage1" / "best"
    dst_s1 = dst_root / stage1_subdir
    if (s1_best_src / "checkpoint_projection.pt").exists():
        dst_s1.mkdir(parents=True, exist_ok=True)
        old_prefix = "checkpoint_"
        for entry in s1_best_src.iterdir():
            if entry.name.startswith(old_prefix):
                new_name = "stage1_final_" + entry.name[len(old_prefix):]
            else:
                new_name = entry.name
            if entry.is_file():
                shutil.copy2(entry, dst_s1 / new_name)
            elif entry.is_dir():
                shutil.copytree(entry, dst_s1 / new_name, dirs_exist_ok=True)
        placed_any = True
        print(f"[hydrate_run_dir_from_hf] stage1 final placed at {dst_s1}")

    # Stage 1 last → checkpoint-1, ONLY if stage 1 hadn't finished yet.
    if not (dst_s1 / "stage1_final_projection.pt").exists():
        s1_dst = dst_s1 / "checkpoint-1"
        if _copy_nonempty_dir(src_root / "stage1" / "last", s1_dst):
            placed_any = True
            print(f"[hydrate_run_dir_from_hf] stage1 mid-resume placed at {s1_dst}")

    return placed_any


def hydrate_run_dir_from_hf(
    repo_id: str,
    token: Optional[str],
    run_id: str,
    output_root: str,
    stage1_subdir: str = "stage1_projection",
    stage2_subdir: str = "stage2_instruct",
) -> bool:
    """
    Repopulate a local run dir from HF artifacts so `detect_resume_point` can
    find checkpoints after a fresh-VM resume (persistence lost / new host).

    HF layout (uploaded by HFBestLastCallback + end-of-stage saves):
        {run_id}/configs/                     (YAML snapshots)
        {run_id}/run_meta.json
        {run_id}/timing.json
        {run_id}/meta.json
        {run_id}/stage1/last/ + stage1/best/  (best/ = stage1 final, files named `checkpoint_*`)
        {run_id}/stage2/last/ + stage2/best/

    Local layout `detect_resume_point` expects (default subdir names shown):
        {output_root}/{run_id}/stage1_projection/stage1_final_*     ← stage1 done
        {output_root}/{run_id}/stage1_projection/checkpoint-N/...   ← stage1 mid
        {output_root}/{run_id}/stage2_instruct/stage2_final_*       ← stage2 done
        {output_root}/{run_id}/stage2_instruct/checkpoint-N/...     ← stage2 mid

    Mapping rules:
      * `stage2/last/` → `stage2_instruct/checkpoint-1/` (placeholder N=1;
        Trainer reads the real global_step from trainer_state.json inside).
      * `stage1/best/` → `stage1_projection/stage1_final_*` (files renamed from
        `checkpoint_*` to `stage1_final_*` so save_checkpoint conventions line
        up with what the rest of the pipeline expects).
      * `stage1/last/` → `stage1_projection/checkpoint-1/` (only if no
        stage1_final was placed — i.e. stage 1 hadn't finished yet on HF).

    The download goes to ``{output_root}/_hf_pull``, which is emptied first and
    always removed afterwards.

    Returns True if at least one artifact was placed. Returns False otherwise,
    including when huggingface_hub is missing or the local dir is already
    populated.
    """
    if not HF_AVAILABLE:
        print("[hydrate_run_dir_from_hf] huggingface_hub not installed — skip")
        return False
    from huggingface_hub import snapshot_download

    token = token or os.environ.get("HF_TOKEN")
    output_root = Path(output_root)
    staging = output_root / "_hf_pull"
    dst_root = output_root / run_id

    # Skip if local already has any final/checkpoint — we're not on a fresh VM.
    s1_local = dst_root / stage1_subdir
    s2_local = dst_root / stage2_subdir
    if (
        (s1_local / "stage1_final_projection.pt").exists()
        or (s2_local / "stage2_final_projection.pt").exists()
        or _has_checkpoint_dirs(s1_local)
        or _has_checkpoint_dirs(s2_local)
    ):
        print(f"[hydrate_run_dir_from_hf] local {dst_root} already populated — skip pull")
        return False

    # Start from an empty staging dir so leftovers of an earlier, interrupted
    # pull cannot be mistaken for current HF artifacts.
    shutil.rmtree(staging, ignore_errors=True)
    staging.mkdir(parents=True, exist_ok=True)
    try:
        # Pull the run's relevant files (configs + meta + last/best); skip
        # training_log.jsonl, which can be large.
        try:
            snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                token=token,
                allow_patterns=[
                    f"{run_id}/configs/**",
                    f"{run_id}/run_meta.json",
                    f"{run_id}/timing.json",
                    f"{run_id}/meta.json",
                    f"{run_id}/stage1/last/**",
                    f"{run_id}/stage1/best/**",
                    f"{run_id}/stage2/last/**",
                    f"{run_id}/stage2/best/**",
                ],
                local_dir=str(staging),
            )
        except Exception as e:
            print(f"[hydrate_run_dir_from_hf] snapshot_download failed: {e}")
            return False

        src_root = staging / run_id
        if not src_root.is_dir():
            print(f"[hydrate_run_dir_from_hf] HF has no '{run_id}/' folder")
            return False

        placed_any = _place_hf_artifacts(src_root, dst_root, stage1_subdir, stage2_subdir)
    finally:
        # Always remove the staging mirror, on success and on every failure path.
        shutil.rmtree(staging, ignore_errors=True)

    if placed_any:
        print(f"[hydrate_run_dir_from_hf] hydrated {dst_root} from HF")
    else:
        print(f"[hydrate_run_dir_from_hf] nothing usable on HF for {run_id}")
    return placed_any


def build_tracker_from_cfg(train_cfg, resuming: bool = False,
                           explicit_run_id: Optional[str] = None):
    """Convenience factory from an OmegaConf DictConfig (or any attribute bag).

    Reads ``train_cfg.hf_hub``. Returns None (hub upload disabled) when:
    the section is absent or not ``enabled``, no token is found in
    ``$<hf_hub.token_env>`` or ``$HF_TOKEN``, or ``hf_hub.repo_id`` is unset.
    """
    hf = getattr(train_cfg, "hf_hub", None)
    if hf is None or not getattr(hf, "enabled", False):
        return None

    # An unset/blank token_env falls back to HF_TOKEN instead of raising.
    token_env = getattr(hf, "token_env", None) or "HF_TOKEN"
    token = os.environ.get(token_env) or os.environ.get("HF_TOKEN")
    if not token:
        print(f"[hf_uploader] no {token_env} / HF_TOKEN in env — hub upload disabled")
        return None

    repo_id = getattr(hf, "repo_id", None)
    if not repo_id:
        print("[hf_uploader] hf_hub.repo_id not set — hub upload disabled")
        return None

    return HFRunTracker(
        repo_id=repo_id,
        token=token,
        # Same default as HFRunTracker.__init__ when the config omits it.
        state_file=getattr(hf, "run_state_file", None) or "checkpoints/run_id.txt",
        resuming=resuming,
        explicit_run_id=explicit_run_id,
        private=getattr(hf, "private", True),
    )