# cxr-vlm-code / training/train.py
# convitom
# p
# 477f80e
"""
train.py
--------
Main training entry point.
Implements 2-stage training following RaDialog:
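Stage 1 — Train projection layer only (2 epochs, LR=1e-3); with
stage1.itc.enabled it instead trains projection + ITC head with an
image-text contrastive (InfoNCE) loss, without loading the LLM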
Stage 2 — Train projection + LLM LoRA (10 epochs, LR=2e-4)
Supports two datasets, selected via `train_cfg.data.dataset_name`:
- "MIMIC-CXR" → all 3 tasks (findings, impression, VQA)
- "IU-Xray" → findings + impression only (no VQA)
Checkpoints and results are written under:
{training.output_root}/{dataset_name}_run_{N}/stageX_*
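Usage:
python -m training.train --train_config configs/train_config.yaml --model_config configs/model_config.yaml
python -m training.train --mode resume   # continue the latest matching run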
"""
import os
import sys
from pathlib import Path
# Silence HF per-shard download tqdm spam — MUST be before transformers/peft/hf_hub imports
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import utils._quiet # noqa: F401
import argparse
import torch
from omegaconf import OmegaConf
# Enable TF32 for fp32 matmuls and cuDNN kernels: a free speed-up on
# A100/H100-class GPUs, and a no-op on cards without TF32 support (T4).
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
import transformers
from transformers import TrainingArguments, Trainer, TrainerCallback, PrinterCallback
from transformers.trainer_callback import ProgressCallback
class _NoEvalTqdmCallback(ProgressCallback):
    """HF ``ProgressCallback`` with the per-batch evaluation bar suppressed.

    Rationale: when launched from Colab as ``!python -m ...`` the child process
    has no IPython kernel, so HF Trainer's ``is_in_notebook()`` returns False
    and plain tqdm is used. Colab's text renderer handles fast ``\\r`` updates
    poorly, so the eval bar (~1 batch/sec x ~1250 batches) prints a new line
    per step and lags the browser tab. The training bar advances slowly enough
    (about one line per ~9s at 24M params + LoRA + bf16) to stay clean, so only
    the prediction bar is disabled. eval_loss is still logged after every
    evaluation pass through the standard log_history mechanism.
    """

    def on_prediction_step(self, args, state, control, **kwargs):  # noqa: D401
        # Deliberate no-op: not delegating to the parent is what keeps the
        # per-batch eval progress bar from being drawn.
        return None
# Add project root to path (same directory as the sys.path insert near the top of the file; redundant but harmless)
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from model import CXRVisionLanguageModel
from model.rad_dino import BioViLTEncoder
from data import CXRInstructDataset, CXRDataCollator, ITCDataCollator
from utils.logger import setup_logger
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.hf_uploader import build_tracker_from_cfg, pull_last_for_resume, hydrate_run_dir_from_hf
from utils.dataset_resolver import (
resolve_dataset_spec,
resolve_run_id,
save_run_config,
run_dir,
stage_dir,
DatasetSpec,
)
def parse_args():
    """Parse the training CLI.

    Returns:
        argparse.Namespace with ``model_config``, ``train_config``, ``stage``,
        ``mode``, ``resume_from``, ``run_id`` and ``resume_from_hf``.
    """
    cli = argparse.ArgumentParser(description="Train CXR VLM")
    add = cli.add_argument

    # Config file locations.
    add("--model_config", type=str,
        default="configs/model_config.yaml",
        help="Path to model config")
    add("--train_config", type=str,
        default="configs/train_config.yaml",
        help="Path to training config")

    # Stage selection and resume control.
    add("--stage", type=int, default=None,
        help=("Run only stage 1 or stage 2 (default: run both). With --mode resume, "
              "the stage is auto-detected and this flag should be left unset."))
    add("--mode", type=str, default=None, choices=["fresh", "resume"],
        help=("Unified resume controller. 'fresh' → new run_N folder. "
              "'resume' → reuse latest matching run_id (or --run_id), auto-detect "
              "which stage to continue from based on checkpoints on disk. "
              "If unset, behaviour is inferred from --resume_from / --run_id (legacy)."))
    add("--resume_from", type=str, default=None,
        help="Path to checkpoint to resume from (legacy; prefer --mode resume)")
    add("--run_id", type=str, default=None,
        help="Explicit run id (e.g. 'IU-Xray_run_3'). If unset, auto-resolve.")
    add("--resume_from_hf", action="store_true",
        help=("Pull <run_id>/<stage>/last/ from the HF Hub and resume from it. "
              "Use when training on a fresh VM after the previous one was killed."))

    return cli.parse_args()
# ─── Resume-point auto-detection ────────────────────────────────────────────
def _list_checkpoints(stage_dir):
"""Return [Path, …] of `checkpoint-NNN` folders sorted ascending by step."""
if not stage_dir.is_dir():
return []
out = []
for p in stage_dir.iterdir():
if not p.is_dir() or not p.name.startswith("checkpoint-"):
continue
suffix = p.name.split("-", 1)[1]
if suffix.isdigit():
out.append((int(suffix), p))
return [p for _, p in sorted(out)]
def detect_resume_point(run_dir_path, stage1_subdir, stage2_subdir):
    """
    Decide, from what is on disk in the run dir, where training should continue.

    Returns ``(target_stage, ckpt_path)``:
      target_stage : "stage1" | "stage2" | "done"
      ckpt_path    : checkpoint folder to hand to the HF Trainer, or None to
                     start that stage from scratch.

    Checks, first match wins:
      1. stage2_final_projection.pt present   → ("done", None)
      2. stage-2 checkpoint-N folders present → ("stage2", newest)
      3. stage1_final_projection.pt present   → ("stage2", None)
      4. stage-1 checkpoint-N folders present → ("stage1", newest)
      5. nothing found                        → ("stage1", None)
    """
    from pathlib import Path as _P

    base = _P(run_dir_path)
    stage1_path = base / stage1_subdir
    stage2_path = base / stage2_subdir

    if (stage2_path / "stage2_final_projection.pt").exists():
        return ("done", None)

    stage2_ckpts = _list_checkpoints(stage2_path)
    if stage2_ckpts:
        return ("stage2", stage2_ckpts[-1])

    if (stage1_path / "stage1_final_projection.pt").exists():
        return ("stage2", None)

    stage1_ckpts = _list_checkpoints(stage1_path)
    return ("stage1", stage1_ckpts[-1] if stage1_ckpts else None)
def compute_training_plan(train_cfg, instruct_json_path):
    """
    Compute a coarse plan of total optimizer steps across stage 1 + stage 2.

    The sample count is the number of entries in the instruct JSON whose
    ``split`` equals ``train_cfg.data.train_split`` (falling back to
    ``"train"``), which is the split name ``_build_datasets`` trains on. Steps
    per epoch are estimated as ``ceil(samples / effective_batch)``. The
    Trainer's real count may differ if the dataset filters or dedups samples.
    Used only to print a human-readable summary at startup.

    Args:
        train_cfg: training config (OmegaConf) with ``training``, ``stage1``,
            ``stage2`` and optionally ``data`` sections.
        instruct_json_path: path to the instruct JSON (a list of dicts).

    Returns:
        dict of ints (plus the ``itc_stage1`` bool). Missing config fields fall
        back to defaults; an unreadable/malformed JSON counts as 0 samples.
    """
    import json as _json

    tr = train_cfg.training
    # Count the split that is actually trained on, not a hard-coded "train".
    train_split = getattr(getattr(train_cfg, "data", None), "train_split", None) or "train"
    try:
        with open(instruct_json_path, "r", encoding="utf-8") as f:
            all_samples = _json.load(f)
        train_count = sum(1 for s in all_samples if s.get("split") == train_split)
    except (OSError, ValueError, TypeError, AttributeError):
        # Best-effort: the plan is informational, so a missing/unreadable or
        # malformed JSON just reports 0 samples instead of aborting startup.
        train_count = 0

    def _eff_batch(stage_cfg, itc_override=None):
        """Effective batch (per-device batch × grad accumulation), at least 1."""
        # ITC override (stage1.itc.*) wins; else per-stage override; else global.
        if itc_override is not None:
            bs = int(itc_override.get("per_device_train_batch_size",
                                      _cfg(stage_cfg, tr, "per_device_train_batch_size", 1)))
            ga = int(itc_override.get("gradient_accumulation_steps",
                                      _cfg(stage_cfg, tr, "gradient_accumulation_steps", 1)))
        else:
            bs = int(_cfg(stage_cfg, tr, "per_device_train_batch_size", 1))
            ga = int(_cfg(stage_cfg, tr, "gradient_accumulation_steps", 1))
        return max(1, bs * ga)

    s1_cfg = train_cfg.stage1
    s2_cfg = train_cfg.stage2
    s1_enabled = bool(getattr(s1_cfg, "enabled", True))
    s2_enabled = bool(getattr(s2_cfg, "enabled", True))
    s1_epochs = int(getattr(s1_cfg, "num_epochs", 0)) if s1_enabled else 0
    s2_epochs = int(getattr(s2_cfg, "num_epochs", 0)) if s2_enabled else 0

    itc_cfg = s1_cfg.get("itc", None)
    itc_on = bool(itc_cfg and itc_cfg.get("enabled", False))
    s1_eff = _eff_batch(s1_cfg, itc_cfg if itc_on else None)
    s2_eff = _eff_batch(s2_cfg)

    # Ceil-division; at least one step per epoch.
    s1_spe = max(1, (train_count + s1_eff - 1) // s1_eff)
    s2_spe = max(1, (train_count + s2_eff - 1) // s2_eff)
    s1_steps = s1_spe * s1_epochs
    s2_steps = s2_spe * s2_epochs

    return {
        "train_samples": train_count,
        "effective_batch": s2_eff,  # representative (Stage-2 / LM path)
        "stage1_eff_batch": s1_eff,
        "stage2_eff_batch": s2_eff,
        "steps_per_epoch": s2_spe,
        "stage1_steps": s1_steps,
        "stage2_steps": s2_steps,
        "total_steps": s1_steps + s2_steps,
        "stage1_epochs": s1_epochs,
        "stage2_epochs": s2_epochs,
        "itc_stage1": itc_on,
    }
def _fmt_plan_banner(plan, run_id, target_stage, resume_ckpt):
s1, s2, tot = plan["stage1_steps"], plan["stage2_steps"], plan["total_steps"]
head = f"TRAINING PLAN — {run_id}"
sep = "=" * max(len(head) + 4, 60)
cur = ""
# Prefer real global_step from trainer_state.json over parsing the folder
# name — hydrate_run_dir_from_hf uses "checkpoint-1" as a placeholder
# regardless of actual step, so the folder digit is meaningless.
def _ckpt_step(ckpt):
if not ckpt:
return None
try:
import json as _json
from pathlib import Path as _P
ts = _P(str(ckpt)) / "trainer_state.json"
if ts.is_file():
return int(_json.load(open(ts))["global_step"])
except Exception:
pass
# Fallback: parse folder suffix
suf = str(ckpt).split("-")[-1]
return int(suf) if suf.isdigit() else None
real_step = _ckpt_step(resume_ckpt)
if target_stage == "stage1":
offset = real_step if real_step is not None else 0
cur = f"Resuming at step {offset} / {tot} (inside stage 1)"
elif target_stage == "stage2":
offset = (s1 + real_step) if real_step is not None else s1
cur = f"Resuming at step {offset} / {tot} (inside stage 2)"
elif target_stage == "done":
cur = f"All {tot} steps already complete — nothing to do"
s1_eff = plan.get("stage1_eff_batch", plan["effective_batch"])
s2_eff = plan.get("stage2_eff_batch", plan["effective_batch"])
itc_tag = " [ITC]" if plan.get("itc_stage1") else ""
lines = [
sep, f" {head}", sep,
f" Train samples : {plan['train_samples']:,}",
f" Stage 1{itc_tag:<6} : {plan['stage1_epochs']} epochs → {s1} steps "
f"(eff.batch {s1_eff}; global steps 1–{s1})",
f" Stage 2 : {plan['stage2_epochs']} epochs → {s2} steps "
f"(eff.batch {s2_eff}; global steps {s1+1}{tot})",
f" TOTAL : {tot} optimizer steps",
]
if cur:
lines += [" " + "─" * (len(sep) - 4), f" {cur}"]
lines.append(sep)
return "\n".join(lines)
def get_trainer(
    model,
    train_dataset,
    val_dataset,
    collator,
    training_args: TrainingArguments,
    itc_mode: bool = False,
    itc_temperature: float = 0.07,
) -> Trainer:
    """Build the project's HuggingFace Trainer.

    When `itc_mode=True` the loss is a symmetric image-text InfoNCE
    (Stage-1 contrastive alignment); otherwise it's the causal-LM loss
    returned by ``model(...)["loss"]``.

    Args:
        model: the CXR VLM (must expose ``forward_itc`` when ``itc_mode``).
        train_dataset: training dataset.
        val_dataset: evaluation dataset.
        collator: batch collator.
        training_args: HF TrainingArguments for this stage.
        itc_mode: use the InfoNCE loss instead of the LM loss.
        itc_temperature: InfoNCE temperature (clamped to >= 1e-4).

    Returns:
        A ``CXRTrainer`` whose default Printer/Progress callbacks are replaced
        by ``_NoEvalTqdmCallback``.
    """
    hf_log = transformers.logging.get_logger(__name__)

    class CXRTrainer(Trainer):
        """Trainer that passes images to ``model.forward`` and checkpoints only
        the trainable artifacts.

        Saves go through ``save_checkpoint(..., name="checkpoint")``, which
        writes the projection MLP, LoRA adapters and optional heads instead of
        the full ~5 GB Vicuna state dict the base Trainer would dump. Resume
        and best-model reload read the same artifacts back via
        ``load_checkpoint``.
        """

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extra arguments newer HF versions may pass
            # (e.g. num_items_in_batch); the LM loss is computed by the model.
            if itc_mode:
                return self._itc_loss(model, inputs, return_outputs)
            outputs = model(
                images=inputs["images"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                labels=inputs["labels"],
            )
            loss = outputs["loss"]
            return (loss, outputs) if return_outputs else loss

        def _itc_loss(self, model, inputs, return_outputs=False):
            """Symmetric image-text InfoNCE (CLIP/BLIP-2 ITC style).

            Image embeds come from model.forward_itc; text embeds are the
            precomputed CXR-BERT vectors (re-normalised here). The diagonal of
            the (B, B) similarity matrix is the positive pair (dataset is
            deduped to one image per study, so no in-batch collisions)."""
            import torch.nn.functional as F
            mdl = model.module if hasattr(model, "module") else model
            img = mdl.forward_itc(inputs["images"])  # (B, d) normed
            txt = F.normalize(inputs["text_embeds"].to(img.dtype), dim=-1)  # (B, d)
            logit_scale = 1.0 / max(itc_temperature, 1e-4)
            logits = logit_scale * img @ txt.t()  # (B, B)
            tgt = torch.arange(logits.size(0), device=logits.device)
            loss = 0.5 * (F.cross_entropy(logits, tgt) +
                          F.cross_entropy(logits.t(), tgt))
            return (loss, {"loss": loss, "logits": logits}) if return_outputs else loss

        def prediction_step(self, model, inputs, prediction_loss_only,
                            ignore_keys=None):
            # ITC eval has no token "labels"; force loss-only eval so
            # metric_for_best_model="eval_loss" works.
            if itc_mode:
                # Mirror the base implementation: move inputs to the right
                # device and evaluate under the same autocast context used
                # during training, so eval_loss matches the training precision.
                inputs = self._prepare_inputs(inputs)
                with torch.no_grad(), self.compute_loss_context_manager():
                    loss = self.compute_loss(model, inputs)
                return (loss.detach(), None, None)
            return super().prediction_step(model, inputs, prediction_loss_only,
                                           ignore_keys=ignore_keys)

        def _save(self, output_dir=None, state_dict=None):
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            # projection (.pt) + LoRA folder + optional CheXpert classifier
            save_checkpoint(self.model, output_dir, name="checkpoint")
            # TrainingArguments dump — needed for resume sanity check
            torch.save(self.args, Path(output_dir) / "training_args.bin")

        def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
            # Bypass upstream's WEIGHTS_NAME / SAFE_WEIGHTS_NAME existence
            # check entirely; we never write those files.
            if model is None:
                model = self.model
            load_checkpoint(model, resume_from_checkpoint)

        def _load_best_model(self):
            """Reload the best checkpoint at the end of training.

            The base implementation only looks for the standard HF weight
            files, which ``_save`` never writes, so it would log a warning and
            keep the *last* weights despite load_best_model_at_end=True.
            """
            best = self.state.best_model_checkpoint
            if not best or not Path(best).is_dir():
                hf_log.warning(
                    f"Could not locate best checkpoint {best!r}; "
                    "keeping the current (last) weights."
                )
                return
            hf_log.info(f"Loading best model from {best} (score: {self.state.best_metric}).")
            load_checkpoint(self.model, best)

        def _get_train_sampler(self, *args, **kwargs):
            """
            Use `WeightedRandomSampler` when the train dataset is mixed-task
            and exposes per-sample weights — this is what makes the configured
            `tasks.*.weight` ratios actually control batch composition.
            Falls back to HF's default (RandomSampler / DistributedSampler)
            for single-task or eval-time datasets.
            Notes:
            * Newer HF versions pass the dataset as an argument; older ones
              pass nothing, so fall back to `self.train_dataset`.
            * Eval is unaffected — HF's `_get_eval_sampler` returns a
              `SequentialSampler` by default, so weighted reweighting only
              applies to training.
            * `replacement=True` is required for true oversampling — without
              it you can't draw more samples of a rare-but-upweighted task
              than physically exist. Tradeoff: a small fraction of samples
              in a numerous-but-downweighted task may never appear in a
              given epoch. Acceptable across multiple epochs.
            """
            ds = kwargs.get("train_dataset", args[0] if args else None)
            if ds is None:
                ds = self.train_dataset
            getter = getattr(ds, "get_per_sample_weights", None)
            if getter is not None:
                weights = getter()
                if weights is not None:
                    from torch.utils.data import WeightedRandomSampler
                    return WeightedRandomSampler(
                        weights=weights,
                        num_samples=len(ds),
                        replacement=True,
                    )
            return super()._get_train_sampler(*args, **kwargs)

    trainer = CXRTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collator,
    )
    trainer.remove_callback(PrinterCallback)
    # Replace default ProgressCallback with one that skips the eval per-batch
    # bar — see _NoEvalTqdmCallback docstring for the Colab-subprocess rationale.
    trainer.remove_callback(ProgressCallback)
    trainer.add_callback(_NoEvalTqdmCallback())
    return trainer
def _cfg(stage_cfg, tr, key, default=None):
"""
Resolve a training hyperparameter with per-stage override semantics:
use the value from `stage_cfg` (stage1/stage2) if present, else fall back
to the global `training:` block, else `default`. This lets stages share
machine-level settings (fp16, dataloader_*) while overriding stage-specific
ones (batch size, warmup, optim) — see train_config.yaml docs.
"""
if stage_cfg is not None:
v = stage_cfg.get(key, None)
if v is not None:
return v
return tr.get(key, default)
def _build_training_args(train_cfg, stage_cfg, out_dir, run_name, *, enable_best,
                         overrides=None):
    """
    Translate the config into ``TrainingArguments`` for one stage.

    Per-stage hyperparameters (batch, warmup, optim, ...) resolve as
    ``overrides`` → ``stage_cfg`` → ``train_cfg.training`` → hard default;
    ``overrides`` is used for the ITC Stage-1 batch bump. Save/eval cadence,
    checkpoint limit, precision and dataloader settings come from
    ``train_cfg.training``. ``enable_best`` adds the
    load_best_model_at_end / metric_for_best_model / greater_is_better trio.
    """
    tr = train_cfg.training
    forced = overrides or {}

    def pick(key, default=None):
        return forced[key] if key in forced else _cfg(stage_cfg, tr, key, default)

    save_strategy = getattr(tr, "save_strategy", "steps")
    eval_strategy = getattr(tr, "evaluation_strategy", save_strategy)
    save_steps = getattr(tr, "save_steps", 200)
    eval_steps = getattr(tr, "eval_steps", save_steps)
    num_workers = getattr(tr, "dataloader_num_workers", 4)

    kwargs = {
        "output_dir": out_dir,
        "num_train_epochs": stage_cfg.num_epochs,
        "per_device_train_batch_size": pick("per_device_train_batch_size", 1),
        "per_device_eval_batch_size": pick("per_device_eval_batch_size", 1),
        "gradient_accumulation_steps": pick("gradient_accumulation_steps", 1),
        "learning_rate": stage_cfg.learning_rate,
        "lr_scheduler_type": pick("lr_scheduler_type", "cosine"),
        "warmup_ratio": pick("warmup_ratio", 0.0),
        "weight_decay": pick("weight_decay", 0.0),
        "fp16": tr.fp16,
        "bf16": getattr(tr, "bf16", False),
        "save_strategy": save_strategy,
        "eval_strategy": eval_strategy,
        "logging_steps": tr.logging_steps,
        "save_total_limit": getattr(tr, "save_total_limit", 1),
        "report_to": "wandb" if train_cfg.wandb.enabled else "none",
        "run_name": run_name,
        "dataloader_num_workers": num_workers,
        "dataloader_pin_memory": getattr(tr, "dataloader_pin_memory", True),
        # Persistent workers only make sense when there are workers at all.
        "dataloader_persistent_workers": (
            getattr(tr, "dataloader_persistent_workers", True) and num_workers > 0
        ),
        # `paged_adamw_8bit` (bnb) cuts optimizer-state memory sharply; the
        # default stays on fp32 AdamW for backward-compat, and the Colab
        # notebook's auto-detect cell switches it on for Ampere+ GPUs.
        "optim": pick("optim", "adamw_torch"),
        "remove_unused_columns": False,
    }

    if save_strategy == "steps":
        kwargs["save_steps"] = save_steps
    if eval_strategy == "steps":
        kwargs["eval_steps"] = eval_steps
    if enable_best:
        kwargs.update(
            load_best_model_at_end=getattr(tr, "load_best_model_at_end", True),
            metric_for_best_model=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        )
    return TrainingArguments(**kwargs)
class HFBestLastCallback(TrainerCallback):
    """
    Maintain exactly two checkpoint folders + a training log on HF Hub for
    each stage:
    <run_id>/<stage>/last/ ← latest checkpoint (for resume)
    <run_id>/<stage>/best/ ← best `metric_for_best` so far (for inference)
    <run_id>/<stage>/best/best_meta.json
    <run_id>/<stage>/training_log.jsonl
    `last/` is overwritten on every save. `best/` is overwritten only when
    `metric_for_best` improves. No checkpoint-<step>/ folders accumulate.
    Each upload first deletes the remote folder so orphan files (e.g. a
    `optimizer.pt` from a previous step that no longer exists locally) are
    purged.

    Notes:
    * `metric_for_best` is normalised to the ``eval_`` prefix under which the
      Trainer reports evaluation metrics, so "loss" and "eval_loss" both work.
    * On resume the best value is seeded from ``state.best_metric`` (restored
      by the Trainer), so a worse post-resume eval does not overwrite best/.
    * Uploads run only on the world-main process.
    * best/ is uploaded only when the best eval step coincides with a save.
    * Upload failures are logged as warnings and never abort training.
    """

    def __init__(self, tracker, stage_subdir: str, logger,
                 metric_for_best: str = "eval_loss",
                 greater_is_better: bool = False):
        self.tracker = tracker
        self.stage_subdir = stage_subdir
        self.logger = logger
        # Same normalisation the HF Trainer applies to metric_for_best_model.
        metric_for_best = metric_for_best or "eval_loss"
        if not metric_for_best.startswith("eval_"):
            metric_for_best = f"eval_{metric_for_best}"
        self.metric_for_best = metric_for_best
        self.greater_is_better = greater_is_better
        self.best_metric = None
        self.best_step = None

    def _is_better(self, m: float) -> bool:
        """True if *m* beats the best value seen so far (or none is known)."""
        if self.best_metric is None:
            return True
        return (m > self.best_metric) if self.greater_is_better else (m < self.best_metric)

    def on_train_begin(self, args, state, control, **kw):
        # When resuming, the Trainer restores state.best_metric from
        # trainer_state.json; start from it instead of "no best yet".
        prior = getattr(state, "best_metric", None)
        if self.best_metric is None and prior is not None:
            self.best_metric = float(prior)

    def on_evaluate(self, args, state, control, metrics=None, **kw):
        if self.tracker is None or not metrics:
            return
        m = metrics.get(self.metric_for_best)
        if m is None:
            return
        if self._is_better(m):
            self.best_metric = float(m)
            self.best_step = state.global_step

    def on_save(self, args, state, control, **kw):
        if self.tracker is None:
            return
        # Callbacks fire on every rank; only the main process uploads.
        if not getattr(state, "is_world_process_zero", True):
            return
        ckpt_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
        if not ckpt_dir.exists():
            return
        # ── last/ — every save ────────────────────────────────────────
        try:
            self.tracker.delete_remote(f"{self.stage_subdir}/last")
            self.tracker.upload_folder(str(ckpt_dir), f"{self.stage_subdir}/last")
        except Exception as e:
            self.logger.warning(f"[HF upload] last @ step {state.global_step} failed: {e}")
        # ── best/ — only if this step is the new best ────────────────
        if self.best_step == state.global_step:
            try:
                self.tracker.delete_remote(f"{self.stage_subdir}/best")
                self.tracker.upload_folder(str(ckpt_dir), f"{self.stage_subdir}/best")
                self.tracker.upload_json(
                    {
                        "step": state.global_step,
                        "epoch": state.epoch,
                        self.metric_for_best: self.best_metric,
                    },
                    f"{self.stage_subdir}/best/best_meta.json",
                )
            except Exception as e:
                self.logger.warning(f"[HF upload] best @ step {state.global_step} failed: {e}")
        # ── training log — full log_history each save ────────────────
        try:
            self.tracker.upload_jsonl(
                state.log_history,
                f"{self.stage_subdir}/training_log.jsonl",
            )
        except Exception as e:
            self.logger.warning(f"[HF upload] training_log failed: {e}")
def _build_datasets(spec: DatasetSpec, train_cfg, model, transform_train, transform_val,
                    itc_mode: bool = False, itc_text_cache=None):
    """Create the (train, val) ``CXRInstructDataset`` pair described by *spec*.

    Both splits share every constructor argument except ``transform`` and
    ``split`` (``train_cfg.data.train_split`` / ``train_cfg.data.val_split``).
    With ``itc_mode=True`` the datasets yield (image, precomputed text_embed)
    pairs for Stage-1 contrastive alignment instead of tokenized
    prompt/target.
    """
    cache_dir = getattr(train_cfg.data, "feature_cache_dir", None) or None
    if cache_dir:
        print(f"[_build_datasets] feature_cache_dir = {cache_dir} "
              f"(encoder bypass on cache hit)")

    shared_kwargs = {
        "data_path": spec.instruct_json,
        "image_root": spec.image_root,
        "tokenizer": model.tokenizer,
        "task": "mixed",
        "cutoff_len": train_cfg.training.cutoff_len,
        "task_weights": spec.task_weights,
        "max_images": spec.max_images,
        "feature_cache_dir": cache_dir,
        "itc_mode": itc_mode,
        "itc_text_cache": itc_text_cache,
    }

    def _make(split_name, tfm):
        # One split of the dataset with its own transform.
        return CXRInstructDataset(transform=tfm, split=split_name, **shared_kwargs)

    train_split_ds = _make(train_cfg.data.train_split, transform_train)
    val_split_ds = _make(train_cfg.data.val_split, transform_val)
    return train_split_ds, val_split_ds
def _stage1_itc_cfg(train_cfg):
"""Return the stage1.itc DictConfig if ITC is enabled, else None."""
itc = train_cfg.stage1.get("itc", None)
if itc and itc.get("enabled", False):
return itc
return None
def run_stage1(model, train_cfg, model_cfg, spec, out_dir, logger, tracker=None, resume_from=None):
    """
    Stage 1: align the projection.
    • default → train projection with the causal-LM loss (Vicuna fwd).
    • stage1.itc on → train projection + ITC head with InfoNCE against
      precomputed CXR-BERT text embeddings (no Vicuna).
    Vision encoder is frozen in both.

    Args:
        model: CXRVisionLanguageModel (a light, LLM-less build in ITC mode).
        train_cfg: training config; model_cfg is accepted for signature
            symmetry and not used here.
        spec: resolved DatasetSpec.
        out_dir: stage-1 output dir (Trainer checkpoints + stage1_final_*).
        logger: logger.
        tracker: optional HF Hub tracker; final artifacts go to stage1/best/.
            Upload failures are logged and do not abort the run.
        resume_from: optional Trainer checkpoint folder to resume from.

    Returns:
        The same ``model`` object after training.

    Raises:
        ValueError: ITC is enabled without ``stage1.itc.text_embed_cache``.
    """
    itc = _stage1_itc_cfg(train_cfg)
    mode_tag = "ITC contrastive" if itc else "projection only (LM loss)"
    logger.info("=" * 60)
    logger.info(f"STAGE 1: {mode_tag} [{spec.dataset_name}]")
    logger.info(f" output_dir = {out_dir}")
    logger.info("=" * 60)

    overrides = None
    if itc:
        model.set_stage1_itc_mode()
        text_cache = itc.get("text_embed_cache", None)
        if not text_cache:
            raise ValueError("stage1.itc.enabled=true requires stage1.itc.text_embed_cache")
        train_ds, val_ds = _build_datasets(
            spec, train_cfg, model,
            transform_train=BioViLTEncoder.get_transform("train"),
            transform_val=BioViLTEncoder.get_transform("val"),
            itc_mode=True,
            itc_text_cache=text_cache,
        )
        collator = ITCDataCollator()
        # ITC batch overrides (no Vicuna → big batch). Unset keys fall
        # through to the stage1 / global values.
        overrides = {
            k: itc.get(k)
            for k in ("per_device_train_batch_size", "per_device_eval_batch_size",
                      "gradient_accumulation_steps")
            if itc.get(k, None) is not None
        }
    else:
        model.set_stage1_mode()
        train_ds, val_ds = _build_datasets(
            spec, train_cfg, model,
            transform_train=BioViLTEncoder.get_transform("train"),
            transform_val=BioViLTEncoder.get_transform("val"),
        )
        collator = CXRDataCollator(pad_token_id=model.tokenizer.pad_token_id)
    model.print_trainable_params()

    training_args = _build_training_args(
        train_cfg, train_cfg.stage1, out_dir,
        run_name=f"{spec.dataset_name}-{train_cfg.wandb.run_name}-stage1",
        enable_best=True,
        overrides=overrides,
    )
    trainer = get_trainer(
        model, train_ds, val_ds, collator, training_args,
        itc_mode=bool(itc),
        itc_temperature=float(itc.get("temperature", 0.07)) if itc else 0.07,
    )
    if getattr(train_cfg.training, "upload_intermediate_to_hf", False) and tracker is not None:
        tr = train_cfg.training
        trainer.add_callback(HFBestLastCallback(
            tracker,
            stage_subdir="stage1",
            logger=logger,
            metric_for_best=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        ))

    if resume_from:
        logger.info(f"Resuming stage1 from checkpoint: {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()

    save_checkpoint(model, out_dir, "stage1_final")
    logger.info(f"Stage 1 complete. Checkpoint saved to {out_dir}")

    # ── HF Hub upload: stage1 final → overwrite best/ ────────────────
    # save_checkpoint dumps whatever weights are in memory after train();
    # that is the best checkpoint only if the Trainer reloaded it at the end
    # (load_best_model_at_end). The files are uploaded into stage1/best/
    # under the canonical artifact names load_checkpoint(name="checkpoint")
    # expects. The local copy is already saved, so an upload failure is
    # logged rather than aborting the run before Stage 2.
    if tracker is not None:
        s1 = Path(out_dir)
        try:
            tracker.delete_remote("stage1/best")
            tracker.upload_file(
                str(s1 / "stage1_final_projection.pt"),
                "stage1/best/checkpoint_projection.pt",
            )
            # ITC mode → no LoRA folder; upload the ITC head instead.
            lora_dir = s1 / "stage1_final_lora"
            if lora_dir.is_dir():
                tracker.upload_folder(str(lora_dir), "stage1/best/checkpoint_lora")
            itc_head_pt = s1 / "stage1_final_itc_head.pt"
            if itc_head_pt.exists():
                tracker.upload_file(str(itc_head_pt), "stage1/best/checkpoint_itc_head.pt")
            chexpert_pt = s1 / "stage1_final_chexpert_classifier.pt"
            if chexpert_pt.exists():
                tracker.upload_file(
                    str(chexpert_pt),
                    "stage1/best/checkpoint_chexpert_classifier.pt",
                )
            tracker.write_meta({
                "dataset_name": spec.dataset_name,
                "stage1_done": True,
                "stage1_mode": "itc" if itc else "lm",
                # str() keeps the meta JSON-serialisable if out_dir is a Path.
                "stage1_output_dir": str(out_dir),
                "stage1_epochs": train_cfg.stage1.num_epochs,
                "stage1_lr": train_cfg.stage1.learning_rate,
            })
        except Exception as e:
            logger.warning(f"[HF upload] stage1 final artifacts failed: {type(e).__name__}: {e}")
    return model
def run_stage2(model, train_cfg, model_cfg, spec, out_dir, logger,
               resume_from=None, tracker=None):
    """
    Stage 2: Train projection + LLM LoRA (instruction tuning).
    Vision encoder frozen. LLM trained via LoRA adapters.

    Args:
        model: full CXRVisionLanguageModel (with the LLM loaded).
        train_cfg: training config; model_cfg is accepted for signature
            symmetry and not used here.
        spec: resolved DatasetSpec.
        out_dir: stage-2 output dir (Trainer checkpoints + stage2_final_*).
        logger: logger.
        resume_from: optional Trainer checkpoint folder to resume from.
        tracker: optional HF Hub tracker; final artifacts go to stage2/best/.
            Upload failures are logged and do not abort the run.

    Returns:
        The same ``model`` object after training.
    """
    logger.info("=" * 60)
    logger.info(f"STAGE 2: Instruction tuning (projection + LoRA) [{spec.dataset_name}]")
    logger.info(f" output_dir = {out_dir}")
    logger.info("=" * 60)

    model.set_stage2_mode()
    model.print_trainable_params()
    train_ds, val_ds = _build_datasets(
        spec, train_cfg, model,
        transform_train=BioViLTEncoder.get_transform("train"),
        transform_val=BioViLTEncoder.get_transform("val"),
    )
    training_args = _build_training_args(
        train_cfg, train_cfg.stage2, out_dir,
        run_name=f"{spec.dataset_name}-{train_cfg.wandb.run_name}-stage2",
        enable_best=True,
    )
    collator = CXRDataCollator(pad_token_id=model.tokenizer.pad_token_id)
    trainer = get_trainer(model, train_ds, val_ds, collator, training_args)
    if getattr(train_cfg.training, "upload_intermediate_to_hf", False) and tracker is not None:
        tr = train_cfg.training
        trainer.add_callback(HFBestLastCallback(
            tracker,
            stage_subdir="stage2",
            logger=logger,
            metric_for_best=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        ))

    if resume_from:
        logger.info(f"Resuming stage2 from checkpoint: {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()

    save_checkpoint(model, out_dir, "stage2_final")
    logger.info(f"Stage 2 complete. Checkpoint saved to {out_dir}")

    # ── HF Hub upload: stage2 final → overwrite best/ ────────────────
    # The local copy is already saved, so an upload failure is logged
    # rather than crashing a finished run.
    if tracker is not None:
        s2 = Path(out_dir)
        try:
            tracker.delete_remote("stage2/best")
            tracker.upload_file(
                str(s2 / "stage2_final_projection.pt"),
                "stage2/best/checkpoint_projection.pt",
            )
            lora_dir = s2 / "stage2_final_lora"
            if lora_dir.is_dir():
                tracker.upload_folder(str(lora_dir), "stage2/best/checkpoint_lora")
            else:
                logger.warning(f"[HF upload] expected LoRA folder missing: {lora_dir}")
            chexpert_pt = s2 / "stage2_final_chexpert_classifier.pt"
            if chexpert_pt.exists():
                tracker.upload_file(
                    str(chexpert_pt),
                    "stage2/best/checkpoint_chexpert_classifier.pt",
                )
            tracker.write_meta({
                "dataset_name": spec.dataset_name,
                "stage2_done": True,
                # str() keeps the meta JSON-serialisable if out_dir is a Path.
                "stage2_output_dir": str(out_dir),
                "stage2_epochs": train_cfg.stage2.num_epochs,
                "stage2_lr": train_cfg.stage2.learning_rate,
            })
        except Exception as e:
            logger.warning(f"[HF upload] stage2 final artifacts failed: {type(e).__name__}: {e}")
    return model
def main():
    """Training CLI entry point.

    Flow:
    1. Parse args and load the configs.
    2. Resolve the dataset spec and run_id.
    3. Optionally pull or hydrate checkpoints from the HF Hub.
    4. With ``--mode resume`` and no ``--stage``, auto-detect the resume point.
    5. Print the plan banner and snapshot the config into the run dir.
    6. Build the model and run Stage 1 and/or Stage 2.
    7. Upload results if a tracker is configured.
    """
    args = parse_args()
    logger = setup_logger("cxr_vlm_train")

    # Load configs
    model_cfg = OmegaConf.load(args.model_config)
    train_cfg = OmegaConf.load(args.train_config)
    logger.info(f"Model config: {args.model_config}")
    logger.info(f"Train config: {args.train_config}")

    # ── Resolve dataset spec (paths, tasks, weights) ─────────────────
    spec = resolve_dataset_spec(train_cfg)
    logger.info(f"Dataset: {spec.dataset_name}")
    logger.info(f" image_root = {spec.image_root}")
    logger.info(f" instruct_json = {spec.instruct_json}")
    logger.info(f" tasks = {spec.tasks}")
    logger.info(f" task_weights = {spec.task_weights}")

    # Local per-stage subfolder names (config-overridable), resolved once.
    stage1_subdir = str(train_cfg.stage1.get("subdir", "stage1_projection"))
    stage2_subdir = str(train_cfg.stage2.get("subdir", "stage2_instruct"))

    # ── Resolve per-dataset run_id (e.g. IU-Xray_run_3) ──────────────
    output_root = str(train_cfg.training.get("output_root", "checkpoints"))
    state_file = str(train_cfg.hf_hub.run_state_file)
    hf_token = os.environ.get(
        train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
    ) if train_cfg.hf_hub.enabled else None
    hf_repo_id = train_cfg.hf_hub.repo_id if train_cfg.hf_hub.enabled else None

    # Unified --mode controller. Falls back to the legacy inference (any of
    # --resume_from / --resume_from_hf set ⇒ resuming) when --mode is unset.
    if args.mode == "resume":
        resuming = True
    elif args.mode == "fresh":
        resuming = False
    else:
        resuming = bool(args.resume_from) or args.resume_from_hf

    run_id = resolve_run_id(
        dataset_name=spec.dataset_name,
        output_root=output_root,
        state_file=state_file,
        resuming=resuming,
        explicit=args.run_id,
        hf_repo_id=hf_repo_id,
        hf_token=hf_token,
    )
    logger.info(f"run_id = {run_id}")

    # ── Optional: pull last/ from HF Hub for resume on a fresh VM ────
    if args.resume_from_hf:
        if not train_cfg.hf_hub.enabled or not train_cfg.hf_hub.repo_id:
            sys.exit("--resume_from_hf requires hf_hub.enabled=true and a repo_id")
        # On the hub the stage subfolder is just "stage1" / "stage2"
        hub_stage_dir = "stage1" if args.stage == 1 else "stage2"
        local_resume = pull_last_for_resume(
            repo_id=train_cfg.hf_hub.repo_id,
            token=hf_token,  # hub is enabled here, so this is the env token
            run_id=run_id,
            stage_subdir=hub_stage_dir,
            local_root=str(Path(output_root) / "_resume_from_hf"),
        )
        if local_resume is None:
            sys.exit(f"Could not pull {run_id}/{hub_stage_dir}/last/ from HF — abort.")
        args.resume_from = local_resume
        logger.info(f"Will resume from pulled checkpoint: {local_resume} ({hub_stage_dir})")

    # ── Fresh-VM resume: hydrate from HF before detect_resume_point ──
    # When `--mode resume` is set but the local run dir is empty (Colab
    # persistence lost, switching machines), pull configs + last/best
    # checkpoints from HF Hub into the canonical local layout so the
    # detector finds them. No-op if local already has artifacts or HF
    # tracking is disabled.
    if args.mode == "resume" and hf_repo_id and hf_token:
        try:
            hydrate_run_dir_from_hf(
                repo_id=hf_repo_id,
                token=hf_token,
                run_id=run_id,
                output_root=output_root,
                stage1_subdir=stage1_subdir,
                stage2_subdir=stage2_subdir,
            )
        except Exception as e:
            logger.warning(f"[resume hydrate] {type(e).__name__}: {e}")

    # ── Compute per-stage output dirs under {output_root}/{run_id}/ ──
    stage1_out = stage_dir(output_root, run_id, stage1_subdir)
    stage2_out = stage_dir(output_root, run_id, stage2_subdir)

    # ── Auto-detect where to resume from (when --mode resume) ─────────
    # Examines disk state inside {output_root}/{run_id}/ and chooses:
    # • stage1 from scratch / stage1 mid-checkpoint
    # • stage2 from scratch (stage1 done) / stage2 mid-checkpoint
    # • done (both stages finished — skip everything)
    # If the user passed --stage explicitly, that wins over auto-detect.
    auto_target_stage = None
    auto_resume_ckpt = None
    if args.mode == "resume" and args.stage is None:
        auto_target_stage, auto_resume_ckpt = detect_resume_point(
            run_dir(output_root, run_id), stage1_subdir, stage2_subdir,
        )
        logger.info(
            f"[resume autodetect] target={auto_target_stage} "
            f"ckpt={auto_resume_ckpt}"
        )

    # ── Pretty plan banner (total steps across both stages) ───────────
    # Reflect an explicit --stage / --resume_from in the banner too.
    plan = compute_training_plan(train_cfg, spec.instruct_json)
    if auto_target_stage is not None:
        banner_stage, banner_ckpt = auto_target_stage, auto_resume_ckpt
    elif args.stage == 2:
        banner_stage, banner_ckpt = "stage2", args.resume_from
    elif args.stage == 1:
        banner_stage, banner_ckpt = "stage1", args.resume_from
    else:
        banner_stage, banner_ckpt = "stage1", None
    logger.info("\n" + _fmt_plan_banner(plan, run_id, banner_stage, banner_ckpt))

    if auto_target_stage == "done":
        logger.info("Both stages already complete for this run. Exiting cleanly.")
        return

    # A run counts as resumed for --mode resume as well as for an explicit
    # (or HF-pulled) --resume_from checkpoint.
    resumed = bool(resuming or args.resume_from)

    # ── Snapshot resolved config into the run dir ────────────────────
    # Every run gets its own self-describing folder so we never have to ask
    # "what config did IU-Xray_run_3 actually use?" — open run_meta.json.
    # Written AFTER stage dirs are created so the run dir definitely exists.
    save_run_config(
        run_dir_path=run_dir(output_root, run_id),
        spec=spec,
        model_cfg=model_cfg,
        train_cfg=train_cfg,
        extra={
            "stage_arg": args.stage,
            "resumed": resumed,
            "resume_from": args.resume_from,
            "resume_from_hf": args.resume_from_hf,
            "model_config_path": args.model_config,
            "train_config_path": args.train_config,
        },
    )

    # Setup WandB
    if train_cfg.wandb.enabled:
        os.environ["WANDB_PROJECT"] = train_cfg.wandb.project

    # ── HuggingFace Hub tracker (optional) ───────────────────────────
    # Pass our resolved run_id as `explicit_run_id` so the tracker uses
    # the same dataset-prefixed folder on the hub.
    tracker = build_tracker_from_cfg(
        train_cfg,
        resuming=resumed,
        explicit_run_id=run_id,
    )
    if tracker is not None:
        tracker.write_meta({
            "dataset_name": spec.dataset_name,
            "run_id": run_id,
            "config_model": args.model_config,
            "config_train": args.train_config,
            "resumed": resumed,
            "resume_from": args.resume_from,
        })
        # Snapshot the resolved config + run_meta.json to HF so the run is
        # self-describing on the hub. `save_run_config` writes these into
        # {run_dir}/configs/ + {run_dir}/run_meta.json a few lines above.
        rd = run_dir(output_root, run_id)
        if (rd / "configs").is_dir():
            tracker.upload_folder(str(rd / "configs"), "configs")
        if (rd / "run_meta.json").is_file():
            tracker.upload_file(str(rd / "run_meta.json"), "run_meta.json")

    # NOTE: model is built BELOW, after the stage-selection logic, so we know
    # whether Stage 1 runs in ITC mode (→ build a light model without Vicuna).
    #
    # Stage selection priority:
    # 1. Explicit --stage from CLI wins.
    # 2. --mode resume + auto-detect: skip stage1 when its final ckpt exists,
    #    resume stage1/stage2 from `auto_resume_ckpt` as detected above.
    # 3. Otherwise: enabled flags from train_cfg drive it (legacy: run both).
    if args.stage is not None:
        run_s1 = (args.stage == 1) and train_cfg.stage1.enabled
        run_s2 = (args.stage == 2) and train_cfg.stage2.enabled
    elif auto_target_stage == "stage2":
        # Stage 1 finished previously — skip it entirely.
        run_s1 = False
        run_s2 = train_cfg.stage2.enabled
    else:
        run_s1 = train_cfg.stage1.enabled
        run_s2 = train_cfg.stage2.enabled

    # Decide the resume checkpoint each stage should use.
    # Manual --resume_from still wins when --stage is given explicitly.
    s1_resume_path = None
    s2_resume_path = None
    if args.stage == 1:
        s1_resume_path = args.resume_from
    elif args.stage == 2:
        s2_resume_path = args.resume_from
    elif auto_target_stage == "stage1":
        s1_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None
    elif auto_target_stage == "stage2":
        s2_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None

    # ── Build model (per-stage) ──────────────────────────────────────
    # ITC Stage-1 needs no Vicuna → build a light model (saves VRAM and
    # enables a big contrastive batch). Every other case builds the full
    # model with Vicuna + LoRA.
    itc_on = _stage1_itc_cfg(train_cfg) is not None
    if itc_on and run_s1:
        logger.info("Building CXR VLM (light: ITC Stage-1, Vicuna NOT loaded)...")
        model = CXRVisionLanguageModel(model_cfg, load_llm=False, build_itc_head=True)
    else:
        logger.info("Building CXR VLM...")
        model = CXRVisionLanguageModel(model_cfg, load_llm=True, build_itc_head=False)

    if args.resume_from:
        logger.info(f"Resuming from: {args.resume_from}")
        load_checkpoint(model, args.resume_from)

    if run_s1:
        model = run_stage1(
            model, train_cfg, model_cfg, spec, stage1_out, logger,
            tracker=tracker,
            resume_from=s1_resume_path,
        )

    if run_s2:
        # Stage 2 always needs the full model. If Stage 1 built a light
        # (no-Vicuna) model, free it and rebuild the full one; the stage1
        # projection is then seeded from the checkpoint just below.
        if getattr(model, "llm", None) is None:
            logger.info("Rebuilding full CXR VLM (with Vicuna) for Stage 2...")
            import gc
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            model = CXRVisionLanguageModel(model_cfg, load_llm=True, build_itc_head=False)
            if args.resume_from and not s2_resume_path:
                load_checkpoint(model, args.resume_from)

        # Load stage1 projection weights before stage2 if available.
        # Priority:
        # 1. Just finished stage1 in this run → load the stage1_final weights
        # 2. Not running stage1 but stage1 final weights exist on disk → load them
        # 3. s2_resume_path set (we're mid-stage2) → Trainer will reload from
        #    the checkpoint itself; no need to seed stage1 weights here.
        # 4. Nothing → warn loudly; stage2 starts with random projection.
        #
        # save_checkpoint(model, dir, "stage1_final") writes
        # stage1_final_projection.pt (the same "stage 1 done" marker
        # detect_resume_point uses); the legacy single-file stage1_final.pt
        # is also accepted.
        # NOTE(review): load_checkpoint is handed the legacy stage1_final.pt
        # path in both branches, exactly as before — confirm it resolves the
        # stage1_final_* artifacts.
        stage1_ckpt = Path(stage1_out) / "stage1_final.pt"
        stage1_marker = Path(stage1_out) / "stage1_final_projection.pt"
        stage1_done = stage1_ckpt.exists() or stage1_marker.exists()
        if run_s1:
            load_checkpoint(model, str(stage1_ckpt))
            logger.info(f"Loaded stage1 weights from this run: {stage1_ckpt}")
        elif stage1_done and not s2_resume_path:
            load_checkpoint(model, str(stage1_ckpt))
            logger.info(f"Auto-loaded existing stage1 weights: {stage1_ckpt}")
        elif not s2_resume_path:
            logger.warning(
                "⚠ No stage1 weights found and not resuming. Projection layer "
                "will start RANDOMLY for stage2. Expect degraded convergence. "
                f"Looked at: {stage1_ckpt} and {stage1_marker}"
            )

        model = run_stage2(
            model, train_cfg, model_cfg, spec, stage2_out, logger,
            resume_from=s2_resume_path,
            tracker=tracker,
        )

    # Upload final results folder (predictions + metrics) if evaluate.py has run
    if tracker is not None:
        results_dir = Path("results") / run_id
        if results_dir.exists():
            tracker.upload_folder(str(results_dir), "results")
        tracker.write_meta({"training_complete": True})

    logger.info("Training complete!")
# Script entry point: `python -m training.train ...` (see module docstring).
if __name__ == "__main__":
    main()