# cxr-vlm-code/utils/hf_uploader.py
# convitom · f · b961b41  (repository page header captured together with the file)
"""
hf_uploader.py
---------------
HuggingFace Hub uploader for training runs.

Conventions:
  - Each fresh training launch gets a monotonically increasing folder on the hub:
        <repo>/run_1/, run_2/, ..., run_N/
  - Resuming a run (train.py --resume_from ...) re-uses the previous run_id
    (read from a local state file) so all subsequent uploads land in the same folder.
  - Run contents (per run_N):
        run_N/
          stage1/stage1_final.pt
          stage2/stage2_final.pt
          stage2/checkpoint-<step>/...   (optional, if upload_intermediate=True)
          results/predictions_*.json
          results/metrics_summary.json
          meta.json                      (start time, resume count, config snapshot)
"""
import json
import os
import re
import time
from pathlib import Path
from typing import Optional, List
try:
from huggingface_hub import HfApi, create_repo
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
class HFRunTracker:
    """Determine the current run_id and upload artifacts for it.

    Every ``remote_subpath`` accepted by the upload helpers is relative to the
    run folder, i.e. it ends up at ``<repo_id>/<run_id>/<remote_subpath>``.
    """

    # Matches "run_<N>" either as a bare entry or as the first path component.
    _RUN_DIR_RX = re.compile(r"^run_(\d+)(?:/|$)")

    def __init__(
        self,
        repo_id: str,
        token: Optional[str] = None,
        state_file: str = "checkpoints/run_id.txt",
        resuming: bool = False,
        explicit_run_id: Optional[str] = None,
        private: bool = True,
    ):
        """Create the repo if needed and resolve which run folder to use.

        Args:
            repo_id: Hub repository, e.g. ``"username/cxr-vlm-runs"``.
            token: Hub token; falls back to the ``HF_TOKEN`` env variable.
            state_file: Local file that remembers the run_id between launches.
            resuming: True when the training launch resumes a previous run.
            explicit_run_id: Forces this run_id (and records it in the state file).
            private: Visibility used if the repo has to be created.

        Raises:
            ImportError: ``huggingface_hub`` is not installed.
            ValueError: ``repo_id`` is empty.
            RuntimeError: resuming, but no run_id can be found locally or on the hub.
        """
        if not HF_AVAILABLE:
            raise ImportError("huggingface_hub not installed. pip install huggingface_hub")
        if not repo_id:
            raise ValueError("repo_id is required (e.g. 'username/cxr-vlm-runs')")
        self.repo_id = repo_id
        self.token = token or os.environ.get("HF_TOKEN")
        self.state_file = Path(state_file)
        self.api = HfApi(token=self.token)
        # Make sure repo exists. Failure is only a warning: the repo may
        # already exist and the token may simply lack create permission.
        try:
            create_repo(
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model",
                private=private,
                exist_ok=True,
            )
        except Exception as e:
            print(f"[HFRunTracker] warn: create_repo: {e}")
        self.run_id = self._resolve_run_id(resuming, explicit_run_id)
        print(f"[HFRunTracker] using run_id = {self.run_id}")

    # ── run_id resolution ──────────────────────────────────────────────────
    def _read_state(self) -> Optional[str]:
        """Return the run_id stored in the state file, or None if absent/blank.

        A blank state file is treated like a missing one; otherwise the tracker
        would end up with ``run_id == ""`` and upload to the repo root.
        """
        try:
            stored = self.state_file.read_text().strip()
        except (FileNotFoundError, NotADirectoryError):
            return None
        return stored or None

    def _resolve_run_id(self, resuming: bool, explicit: Optional[str]) -> str:
        """Pick the run folder name.

        Priority: explicit id → local state file → (when resuming) newest run
        on the hub → (fresh launch) next free ``run_N`` on the hub. Whenever a
        run_id is chosen from anything but the state file, it is written back
        to the state file.

        Raises:
            RuntimeError: resuming with no local state and no runs on the hub.
        """
        if explicit:
            self._write_state(explicit)
            return explicit

        stored = self._read_state()

        if resuming:
            if stored:
                return stored
            # If resuming but no local state → try last run on hub
            runs = self._list_remote_runs()
            if runs:
                run_id = f"run_{max(runs)}"
                self._write_state(run_id)
                return run_id
            raise RuntimeError(
                "Resuming but no run_id.txt locally and no runs on HF hub. "
                "Pass --run_id explicitly."
            )

        # Fresh session: honor local state if present (user may have reset kernel
        # but wants to continue the same run)
        if stored:
            print(f"[HFRunTracker] resuming via local state file: {stored}")
            return stored

        # Pick next run number from hub
        runs = self._list_remote_runs()
        next_n = max(runs) + 1 if runs else 1
        run_id = f"run_{next_n}"
        self._write_state(run_id)
        return run_id

    def _list_remote_runs(self) -> List[int]:
        """Return the sorted run numbers N for which ``run_N`` exists on the hub.

        Any listing failure is logged and treated as an empty repo.
        """
        try:
            files = self.api.list_repo_files(self.repo_id, token=self.token)
        except Exception as e:
            print(f"[HFRunTracker] list_repo_files: {e} → assuming empty repo")
            return []
        matches = (self._RUN_DIR_RX.match(f) for f in files)
        return sorted({int(m.group(1)) for m in matches if m})

    def _write_state(self, run_id: str):
        """Persist ``run_id`` to the local state file, creating parent dirs."""
        self.state_file.parent.mkdir(parents=True, exist_ok=True)
        self.state_file.write_text(run_id)

    # ── upload helpers ─────────────────────────────────────────────────────
    # All upload methods swallow exceptions and print a warning — training
    # must never crash because the hub is unreachable / token is read-only.
    def upload_file(self, local_path: str, remote_subpath: str):
        """Upload one local file to ``<run_id>/<remote_subpath>``.

        A missing local file is skipped with a message; hub errors are logged.
        """
        local_path = Path(local_path)
        if not local_path.exists():
            print(f"[HFRunTracker] skip upload (missing): {local_path}")
            return
        print(f"[HFRunTracker] ↑ {local_path} → {self.run_id}/{remote_subpath}")
        try:
            self.api.upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=f"{self.run_id}/{remote_subpath}",
                repo_id=self.repo_id,
                token=self.token,
            )
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_file failed ({type(e).__name__}): {e}")

    def upload_folder(self, local_folder: str, remote_subpath: str, allow_patterns=None, ignore_patterns=None):
        """Upload a local folder to ``<run_id>/<remote_subpath>``.

        ``allow_patterns`` / ``ignore_patterns`` are forwarded unchanged to
        ``HfApi.upload_folder``. A missing folder is skipped; hub errors are logged.
        """
        local_folder = Path(local_folder)
        if not local_folder.exists():
            print(f"[HFRunTracker] skip upload_folder (missing): {local_folder}")
            return
        print(f"[HFRunTracker] ↑ folder {local_folder} → {self.run_id}/{remote_subpath}")
        try:
            self.api.upload_folder(
                folder_path=str(local_folder),
                path_in_repo=f"{self.run_id}/{remote_subpath}",
                repo_id=self.repo_id,
                token=self.token,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
            )
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_folder failed ({type(e).__name__}): {e}")

    def delete_remote(self, remote_subpath: str):
        """Best-effort delete of a remote folder (e.g. run_N/stage1/last).

        Used to clear stale files before re-uploading 'last' / 'best' so no
        orphan files from a previous step linger. Never raises.
        """
        path_in_repo = f"{self.run_id}/{remote_subpath}"
        try:
            self.api.delete_folder(
                path_in_repo=path_in_repo,
                repo_id=self.repo_id,
                token=self.token,
            )
            return
        except Exception:
            # Folder may not exist yet, or older hub client lacks delete_folder.
            # Fall back to per-file delete if we can list.
            pass

        try:
            files = self.api.list_repo_files(self.repo_id, token=self.token)
        except Exception as e:
            print(f"[HFRunTracker] WARN delete_remote: cannot list repo ({type(e).__name__}): {e}")
            return

        prefix = path_in_repo.rstrip("/") + "/"
        for f in files:
            if not f.startswith(prefix):
                continue
            try:
                self.api.delete_file(
                    path_in_repo=f,
                    repo_id=self.repo_id,
                    token=self.token,
                )
            except Exception as e:
                print(f"[HFRunTracker] WARN delete_remote: {f} ({type(e).__name__}): {e}")

    def _upload_via_tempfile(self, write_payload, suffix: str, remote_subpath: str):
        """Write a payload to a temp file, upload it, and always remove the temp file.

        ``write_payload`` receives the open UTF-8 text handle. Exceptions raised
        by it propagate to the caller, but only after the temp file is deleted.
        """
        import tempfile

        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                "w", suffix=suffix, delete=False, encoding="utf-8"
            ) as tmp:
                tmp_path = tmp.name
                write_payload(tmp)
            # The handle is closed here, so the upload sees the full content.
            self.upload_file(tmp_path, remote_subpath)
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def upload_jsonl(self, lines, remote_subpath: str):
        """Replace a remote .jsonl file with the given list of dict entries."""

        def _write(handle):
            for entry in lines:
                handle.write(json.dumps(entry, default=str) + "\n")

        try:
            self._upload_via_tempfile(_write, ".jsonl", remote_subpath)
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_jsonl failed ({type(e).__name__}): {e}")

    def upload_json(self, obj: dict, remote_subpath: str):
        """Upload a small dict as a JSON file (used for best_meta)."""
        try:
            self._upload_via_tempfile(
                lambda handle: json.dump(obj, handle, indent=2, default=str),
                ".json",
                remote_subpath,
            )
        except Exception as e:
            print(f"[HFRunTracker] WARN upload_json failed ({type(e).__name__}): {e}")

    def write_meta(self, meta: dict, remote_subpath: str = "meta.json"):
        """Merge+upload a meta.json for the run. Reads existing if present.

        Keys in ``meta`` override the ones already on the hub. ``created_at``
        is kept from the first write, ``last_updated`` is refreshed, and
        ``resume_count`` is incremented whenever an existing meta file was found.
        Any failure (network, permission) is logged but does not raise — so
        training never crashes because of a hub glitch.
        """
        existing = {}
        try:
            # Try download current meta.json if exists
            path = self.api.hf_hub_download(
                repo_id=self.repo_id,
                filename=f"{self.run_id}/{remote_subpath}",
                token=self.token,
            )
            existing = json.loads(Path(path).read_text(encoding="utf-8"))
        except Exception:
            # No previous meta (first write) or hub unreachable: start fresh.
            pass

        now = time.time()
        merged = {**existing, **meta}
        merged.setdefault("created_at", now)
        merged["last_updated"] = now
        merged["resume_count"] = existing.get("resume_count", 0) + (1 if existing else 0)

        try:
            # default=str matches upload_json, so a config snapshot holding
            # non-JSON values (e.g. paths) does not abort the write.
            self._upload_via_tempfile(
                lambda handle: json.dump(merged, handle, indent=2, default=str),
                ".json",
                remote_subpath,
            )
        except Exception as e:
            print(f"[HFRunTracker] WARN write_meta failed ({type(e).__name__}): {e}")
def pull_last_for_resume(
    repo_id: str,
    token: Optional[str],
    run_id: str,
    stage_subdir: str,
    local_root: str = "checkpoints/_resume_from_hf",
) -> Optional[str]:
    """
    Download <run_id>/<stage_subdir>/last/ from the hub into a local folder
    that can be passed straight to `trainer.train(resume_from_checkpoint=...)`.

    The hub layout is mirrored under `local_root`, so the files end up in
    `<local_root>/<run_id>/<stage_subdir>/last/`.

    Args:
        repo_id: Hub repository to pull from.
        token: Hub token; falls back to the HF_TOKEN env variable.
        run_id: Run folder on the hub (e.g. "run_3").
        stage_subdir: Stage folder inside the run; may contain several path
            components, surrounding slashes are ignored.
        local_root: Local directory that mirrors the repo root.

    Returns the local path or None on failure.
    """
    if not HF_AVAILABLE:
        print("[hf_uploader] huggingface_hub not installed — cannot pull resume state")
        return None
    from huggingface_hub import snapshot_download

    token = token or os.environ.get("HF_TOKEN")
    # A stray leading/trailing slash would break the allow pattern below.
    stage_subdir = stage_subdir.strip("/")

    # snapshot_download recreates repo-relative paths below local_dir, so the
    # mirror root must be local_root itself — independent of how many path
    # components stage_subdir has.
    mirror_root = Path(local_root)
    stage_dir = mirror_root / run_id / stage_subdir
    stage_dir.mkdir(parents=True, exist_ok=True)

    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            token=token,
            allow_patterns=[f"{run_id}/{stage_subdir}/last/**"],
            local_dir=str(mirror_root),  # repo root mirror
        )
    except Exception as e:
        print(f"[hf_uploader] WARN pull_last_for_resume: {e}")
        return None

    last_dir = stage_dir / "last"
    if not last_dir.exists() or not any(last_dir.iterdir()):
        print(f"[hf_uploader] no files pulled into {last_dir}")
        return None
    print(f"[hf_uploader] pulled resume state → {last_dir}")
    return str(last_dir)
def hydrate_run_dir_from_hf(
    repo_id: str,
    token: Optional[str],
    run_id: str,
    output_root: str,
    stage1_subdir: str = "stage1_projection",
    stage2_subdir: str = "stage2_instruct",
) -> bool:
    """
    Repopulate a local run dir from HF artifacts so `detect_resume_point`
    can find checkpoints after a fresh-VM resume (persistence lost / new host).

    HF layout (uploaded by HFBestLastCallback + end-of-stage saves):
        {run_id}/configs/                     (YAML snapshots)
        {run_id}/run_meta.json
        {run_id}/timing.json
        {run_id}/stage1/last/ + stage1/best/  (best/ = stage1 final, renamed `checkpoint_*`)
        {run_id}/stage2/last/ + stage2/best/

    Local layout `detect_resume_point` expects:
        {output_root}/{run_id}/stage1_projection/stage1_final_*     ← stage1 done
        {output_root}/{run_id}/stage1_projection/checkpoint-N/...   ← stage1 mid
        {output_root}/{run_id}/stage2_instruct/stage2_final_*       ← stage2 done
        {output_root}/{run_id}/stage2_instruct/checkpoint-N/...     ← stage2 mid

    Mapping rules:
      * `stage2/last/` → `stage2_instruct/checkpoint-1/` (placeholder N=1;
        Trainer reads the real global_step from trainer_state.json inside).
      * `stage1/best/` → `stage1_projection/stage1_final_*` (rename files
        from `checkpoint_*` to `stage1_final_*` so save_checkpoint conventions
        line up with what the rest of the pipeline expects).
      * `stage1/last/` → `stage1_projection/checkpoint-1/` (only if no
        stage1_final placed — i.e. stage 1 hadn't finished yet on HF).

    The download goes through a scratch folder `{output_root}/_hf_pull`, which
    is removed again on every exit path (success, failed download, failed copy).

    Returns True if at least one artifact was placed, False otherwise
    (including when the local run dir is already populated and nothing is pulled).
    """
    if not HF_AVAILABLE:
        print("[hydrate_run_dir_from_hf] huggingface_hub not installed — skip")
        return False
    from huggingface_hub import snapshot_download
    import shutil

    token = token or os.environ.get("HF_TOKEN")
    output_root = Path(output_root)
    staging = output_root / "_hf_pull"
    dst_root = output_root / run_id

    # Skip if local already has any final/checkpoint — we're not on a fresh VM.
    s1_local = dst_root / stage1_subdir
    s2_local = dst_root / stage2_subdir

    def _has_ckpt(d: Path) -> bool:
        return d.is_dir() and any(d.glob("checkpoint-*"))

    if (
        (s1_local / "stage1_final_projection.pt").exists()
        or (s2_local / "stage2_final_projection.pt").exists()
        or _has_ckpt(s1_local)
        or _has_ckpt(s2_local)
    ):
        print(f"[hydrate_run_dir_from_hf] local {dst_root} already populated — skip pull")
        return False

    # Start from an empty scratch dir so leftovers of an earlier, interrupted
    # pull can never be mistaken for fresh hub content.
    shutil.rmtree(staging, ignore_errors=True)
    staging.mkdir(parents=True, exist_ok=True)
    placed_any = False
    try:
        # Pull the run's relevant files (configs + meta + last/best, skip
        # training_log.jsonl which can be large).
        try:
            snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                token=token,
                allow_patterns=[
                    f"{run_id}/configs/**",
                    f"{run_id}/run_meta.json",
                    f"{run_id}/timing.json",
                    f"{run_id}/meta.json",
                    f"{run_id}/stage1/last/**",
                    f"{run_id}/stage1/best/**",
                    f"{run_id}/stage2/last/**",
                    f"{run_id}/stage2/best/**",
                ],
                local_dir=str(staging),
            )
        except Exception as e:
            print(f"[hydrate_run_dir_from_hf] snapshot_download failed: {e}")
            return False

        src_root = staging / run_id
        if not src_root.is_dir():
            print(f"[hydrate_run_dir_from_hf] HF has no '{run_id}/' folder")
            return False

        dst_root.mkdir(parents=True, exist_ok=True)

        # configs/, run_meta.json, timing.json, meta.json: straight copy
        configs_src = src_root / "configs"
        if configs_src.is_dir():
            shutil.copytree(configs_src, dst_root / "configs", dirs_exist_ok=True)
            placed_any = True
        for fname in ("run_meta.json", "timing.json", "meta.json"):
            meta_src = src_root / fname
            if meta_src.is_file():
                shutil.copy2(meta_src, dst_root / fname)
                placed_any = True

        # Stage 2 last → checkpoint-1
        s2_last_src = src_root / "stage2" / "last"
        if s2_last_src.is_dir() and any(s2_last_src.iterdir()):
            dst = s2_local / "checkpoint-1"
            dst.mkdir(parents=True, exist_ok=True)
            shutil.copytree(s2_last_src, dst, dirs_exist_ok=True)
            placed_any = True
            print(f"[hydrate_run_dir_from_hf] stage2 mid-resume placed at {dst}")

        # Stage 1 best (= final) → stage1_final_*
        s1_best_src = src_root / "stage1" / "best"
        if s1_best_src.is_dir() and (s1_best_src / "checkpoint_projection.pt").exists():
            s1_local.mkdir(parents=True, exist_ok=True)
            for entry in s1_best_src.iterdir():
                # Rename "checkpoint_*" → "stage1_final_*"
                if entry.name.startswith("checkpoint_"):
                    new_name = entry.name.replace("checkpoint_", "stage1_final_", 1)
                else:
                    new_name = entry.name
                if entry.is_file():
                    shutil.copy2(entry, s1_local / new_name)
                elif entry.is_dir():
                    shutil.copytree(entry, s1_local / new_name, dirs_exist_ok=True)
            placed_any = True
            print(f"[hydrate_run_dir_from_hf] stage1 final placed at {s1_local}")

        # Stage 1 last → checkpoint-1 (ONLY if stage1 didn't finish yet)
        if not (s1_local / "stage1_final_projection.pt").exists():
            s1_last_src = src_root / "stage1" / "last"
            if s1_last_src.is_dir() and any(s1_last_src.iterdir()):
                dst = s1_local / "checkpoint-1"
                dst.mkdir(parents=True, exist_ok=True)
                shutil.copytree(s1_last_src, dst, dirs_exist_ok=True)
                placed_any = True
                print(f"[hydrate_run_dir_from_hf] stage1 mid-resume placed at {dst}")
    finally:
        # Cleanup staging, whatever happened above
        shutil.rmtree(staging, ignore_errors=True)

    if placed_any:
        print(f"[hydrate_run_dir_from_hf] hydrated {dst_root} from HF")
    else:
        print(f"[hydrate_run_dir_from_hf] nothing usable on HF for {run_id}")
    return placed_any
def build_tracker_from_cfg(train_cfg, resuming: bool = False, explicit_run_id: Optional[str] = None):
    """Convenience factory from OmegaConf DictConfig.

    Reads the ``hf_hub`` section of ``train_cfg`` (``enabled``, ``token_env``,
    ``repo_id``, ``run_state_file``, ``private``) and builds an HFRunTracker.

    Args:
        train_cfg: Config object exposing an optional ``hf_hub`` attribute.
        resuming: Forwarded to HFRunTracker.
        explicit_run_id: Forwarded to HFRunTracker.

    Returns:
        An HFRunTracker, or None when hub upload is disabled: section missing,
        ``enabled`` false, no token in the environment, or no ``repo_id``.
    """
    hf = getattr(train_cfg, "hf_hub", None)
    if hf is None or not getattr(hf, "enabled", False):
        return None

    # token_env is optional; an unset OR empty variable falls back to HF_TOKEN.
    token_env = getattr(hf, "token_env", None)
    token = (os.environ.get(token_env) if token_env else None) or os.environ.get("HF_TOKEN")
    if not token:
        searched = f"{token_env} / HF_TOKEN" if token_env else "HF_TOKEN"
        print(f"[hf_uploader] no {searched} in env — hub upload disabled")
        return None

    repo_id = getattr(hf, "repo_id", None)
    if not repo_id:
        print("[hf_uploader] hf_hub.repo_id not set — hub upload disabled")
        return None

    return HFRunTracker(
        repo_id=repo_id,
        token=token,
        # Same default as HFRunTracker when the config does not name a file.
        state_file=getattr(hf, "run_state_file", None) or "checkpoints/run_id.txt",
        resuming=resuming,
        explicit_run_id=explicit_run_id,
        private=getattr(hf, "private", True),
    )