"""
train.py
--------
Main training entry point.

Implements 2-stage training following RaDialog:
    Stage 1 — Train projection layer only (2 epochs, LR=1e-3).
              With `stage1.itc.enabled`, the projection + an ITC head are
              trained with an image-text InfoNCE loss instead of the
              causal-LM loss (the LLM is not loaded for that stage).
    Stage 2 — Train projection + LLM LoRA (10 epochs, LR=2e-4)

Supports two datasets, selected via `train_cfg.data.dataset_name`:
    - "MIMIC-CXR" → all 3 tasks (findings, impression, VQA)
    - "IU-Xray"   → findings + impression only (no VQA)

Checkpoints and results are written under:
    {training.output_root}/{dataset_name}_run_{N}/stageX_*

Usage:
    python -m training.train --train_config configs/train_config.yaml
    (see `parse_args` for --model_config, --stage, --mode, --run_id, ...)
"""
import argparse
import gc
import json
import os
import sys
from pathlib import Path

# Add the project root to sys.path (so `model`, `data`, `utils` import), then
# silence HF per-shard download tqdm spam — `utils._quiet` MUST be imported
# before any transformers / peft / huggingface_hub import.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import utils._quiet  # noqa: F401,E402

import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from omegaconf import OmegaConf  # noqa: E402
from torch.utils.data import WeightedRandomSampler  # noqa: E402

# Free perf win on A100/H100 (no-op on T4): allow TF32 for fp32 matmul / cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import transformers  # noqa: F401,E402
from transformers import TrainingArguments, Trainer, TrainerCallback, PrinterCallback  # noqa: E402
from transformers.trainer_callback import ProgressCallback  # noqa: E402

from model import CXRVisionLanguageModel  # noqa: E402
from model.rad_dino import BioViLTEncoder  # noqa: E402
from data import CXRInstructDataset, CXRDataCollator, ITCDataCollator  # noqa: E402
from utils.logger import setup_logger  # noqa: E402
from utils.checkpoint import save_checkpoint, load_checkpoint  # noqa: E402
from utils.hf_uploader import (  # noqa: E402
    build_tracker_from_cfg,
    pull_last_for_resume,
    hydrate_run_dir_from_hf,
)
from utils.dataset_resolver import (  # noqa: E402
    resolve_dataset_spec,
    resolve_run_id,
    save_run_config,
    run_dir,
    stage_dir,
    DatasetSpec,
)


class _NoEvalTqdmCallback(ProgressCallback):
    """Same as HF's ProgressCallback but with the per-batch eval bar disabled.

    In a Colab `!python -m ...` subprocess HF Trainer's `is_in_notebook()`
    returns False (no IPython kernel in the child) so it falls back to plain
    tqdm. Colab's text renderer mishandles `\\r` for fast updates, so the eval
    bar (~1 batch/sec × 1250 batches) prints a fresh line every step and lags
    the browser tab. Training tqdm updates slowly enough (one bar line per ~9s
    at 24M params + LoRA + bf16) that it stays clean, so we only kill the
    prediction bar.

    eval_loss is still logged at the end of each eval pass via the standard
    log_history mechanism.
    """

    def on_prediction_step(self, args, state, control, **kwargs):  # noqa: D401
        # Intentionally a no-op: suppresses the per-batch eval progress bar.
        return


def parse_args():
    """Parse the command-line interface of the training entry point."""
    parser = argparse.ArgumentParser(description="Train CXR VLM")
    parser.add_argument(
        "--model_config", type=str, default="configs/model_config.yaml",
        help="Path to model config",
    )
    parser.add_argument(
        "--train_config", type=str, default="configs/train_config.yaml",
        help="Path to training config",
    )
    parser.add_argument(
        "--stage", type=int, default=None, choices=[1, 2],
        help="Run only stage 1 or stage 2 (default: run both). With --mode resume, "
             "the stage is auto-detected and this flag should be left unset.",
    )
    parser.add_argument(
        "--mode", type=str, default=None, choices=["fresh", "resume"],
        help="Unified resume controller. 'fresh' → new run_N folder. "
             "'resume' → reuse latest matching run_id (or --run_id), auto-detect "
             "which stage to continue from based on checkpoints on disk. "
             "If unset, behaviour is inferred from --resume_from / --run_id (legacy).",
    )
    parser.add_argument(
        "--resume_from", type=str, default=None,
        help="Path to checkpoint to resume from (legacy; prefer --mode resume)",
    )
    parser.add_argument(
        "--run_id", type=str, default=None,
        help="Explicit run id (e.g. 'IU-Xray_run_3'). If unset, auto-resolve.",
    )
    parser.add_argument(
        "--resume_from_hf", action="store_true",
        help="Pull <run_id>/<stage>/last/ from the HF Hub and resume from it. "
             "Use when training on a fresh VM after the previous one was killed.",
    )
    return parser.parse_args()


# ─── Resume-point auto-detection ────────────────────────────────────────────

def _list_checkpoints(stage_dir):
    """Return [Path, …] of `checkpoint-NNN` folders sorted ascending by step.

    Folders whose suffix after `checkpoint-` is not all digits are ignored.
    A missing directory yields an empty list.
    """
    if not stage_dir.is_dir():
        return []
    found = []
    for entry in stage_dir.iterdir():
        if entry.is_dir() and entry.name.startswith("checkpoint-"):
            step_str = entry.name.split("-", 1)[1]
            if step_str.isdigit():
                found.append((int(step_str), entry))
    found.sort()
    return [entry for _, entry in found]


def detect_resume_point(run_dir_path, stage1_subdir, stage2_subdir):
    """
    Inspect the run dir on disk and decide where to pick up training.

    Returns a tuple `(target_stage, ckpt_path)` where:
        target_stage : "stage1" | "stage2" | "done"
        ckpt_path    : Path to the checkpoint folder to pass to HF Trainer,
                       or None if the stage should start from scratch.

    Priority:
        1. stage 2 final saved → ("done", None)       everything finished
        2. stage 2 has ckpts   → ("stage2", latest)   resume mid-stage2
        3. stage 1 final saved → ("stage2", None)     stage 1 done; start stage 2
        4. stage 1 has ckpts   → ("stage1", latest)   resume mid-stage1
        5. otherwise           → ("stage1", None)     brand-new run
    """
    root = Path(run_dir_path)
    s1d = root / stage1_subdir
    s2d = root / stage2_subdir

    if (s2d / "stage2_final_projection.pt").exists():
        return ("done", None)
    s2_ckpts = _list_checkpoints(s2d)
    if s2_ckpts:
        return ("stage2", s2_ckpts[-1])
    if (s1d / "stage1_final_projection.pt").exists():
        return ("stage2", None)
    s1_ckpts = _list_checkpoints(s1d)
    if s1_ckpts:
        return ("stage1", s1_ckpts[-1])
    return ("stage1", None)


def compute_training_plan(train_cfg, instruct_json_path):
    """
    Compute a coarse plan of total optimizer steps across stage 1 + stage 2,
    derived from the train_config + the train-split sample count in the
    instruct JSON. Used to print a human-readable summary at startup.

    Returns a dict (all ints / bools). If the instruct JSON cannot be read or
    parsed, the sample count falls back to 0 instead of raising.
    """
    tr = train_cfg.training

    try:
        with open(instruct_json_path, "r", encoding="utf-8") as f:
            all_samples = json.load(f)
        train_count = sum(1 for s in all_samples if s.get("split") == "train")
    except (OSError, ValueError, TypeError, AttributeError):
        # Missing/unreadable file, bad JSON, or unexpected record shape —
        # the plan is informational only, so degrade to 0 samples.
        train_count = 0

    def _eff_batch(stage_cfg, itc_override=None):
        # ITC override (stage1.itc.*) wins; else per-stage override; else global.
        if itc_override is not None:
            bs = int(itc_override.get("per_device_train_batch_size",
                                      _cfg(stage_cfg, tr, "per_device_train_batch_size", 1)))
            ga = int(itc_override.get("gradient_accumulation_steps",
                                      _cfg(stage_cfg, tr, "gradient_accumulation_steps", 1)))
        else:
            bs = int(_cfg(stage_cfg, tr, "per_device_train_batch_size", 1))
            ga = int(_cfg(stage_cfg, tr, "gradient_accumulation_steps", 1))
        return max(1, bs * ga)

    s1_cfg = train_cfg.stage1
    s2_cfg = train_cfg.stage2
    s1_enabled = bool(getattr(s1_cfg, "enabled", True))
    s2_enabled = bool(getattr(s2_cfg, "enabled", True))
    s1_epochs = int(getattr(s1_cfg, "num_epochs", 0)) if s1_enabled else 0
    s2_epochs = int(getattr(s2_cfg, "num_epochs", 0)) if s2_enabled else 0

    itc_cfg = s1_cfg.get("itc", None)
    itc_on = bool(itc_cfg and itc_cfg.get("enabled", False))
    s1_eff = _eff_batch(s1_cfg, itc_cfg if itc_on else None)
    s2_eff = _eff_batch(s2_cfg)

    # ceil(train_count / eff_batch), at least 1 step per epoch
    s1_spe = max(1, (train_count + s1_eff - 1) // s1_eff)
    s2_spe = max(1, (train_count + s2_eff - 1) // s2_eff)
    s1_steps = s1_spe * s1_epochs
    s2_steps = s2_spe * s2_epochs
    return {
        "train_samples": train_count,
        "effective_batch": s2_eff,  # representative (Stage-2 / LM path)
        "stage1_eff_batch": s1_eff,
        "stage2_eff_batch": s2_eff,
        "steps_per_epoch": s2_spe,
        "stage1_steps": s1_steps,
        "stage2_steps": s2_steps,
        "total_steps": s1_steps + s2_steps,
        "stage1_epochs": s1_epochs,
        "stage2_epochs": s2_epochs,
        "itc_stage1": itc_on,
    }


def _ckpt_step(ckpt):
    """Return the real global step of an HF checkpoint folder, or None.

    Prefer `trainer_state.json` over parsing the folder name —
    hydrate_run_dir_from_hf uses "checkpoint-1" as a placeholder regardless
    of actual step, so the folder digit can be meaningless.
    """
    if not ckpt:
        return None
    ts = Path(str(ckpt)) / "trainer_state.json"
    try:
        if ts.is_file():
            with open(ts, "r", encoding="utf-8") as f:
                return int(json.load(f)["global_step"])
    except (OSError, ValueError, KeyError, TypeError):
        pass
    # Fallback: parse folder suffix
    suf = str(ckpt).split("-")[-1]
    return int(suf) if suf.isdigit() else None


def _fmt_plan_banner(plan, run_id, target_stage, resume_ckpt):
    """Render the startup banner for a plan from `compute_training_plan`.

    `target_stage` / `resume_ckpt` come from `detect_resume_point` and add a
    "where are we" line (global step offset across both stages).
    """
    s1, s2, tot = plan["stage1_steps"], plan["stage2_steps"], plan["total_steps"]
    head = f"TRAINING PLAN — {run_id}"
    sep = "=" * max(len(head) + 4, 60)

    real_step = _ckpt_step(resume_ckpt)
    cur = ""
    if target_stage == "stage1":
        if real_step is None:
            cur = f"Starting at step 0 / {tot} (stage 1)"
        else:
            cur = f"Resuming at step {real_step} / {tot} (inside stage 1)"
    elif target_stage == "stage2":
        # Stage-2 steps are offset by the full stage-1 budget.
        offset = (s1 + real_step) if real_step is not None else s1
        cur = f"Resuming at step {offset} / {tot} (inside stage 2)"
    elif target_stage == "done":
        cur = f"All {tot} steps already complete — nothing to do"

    s1_eff = plan.get("stage1_eff_batch", plan["effective_batch"])
    s2_eff = plan.get("stage2_eff_batch", plan["effective_batch"])
    itc_tag = " [ITC]" if plan.get("itc_stage1") else ""
    lines = [
        sep,
        f" {head}",
        sep,
        f" Train samples : {plan['train_samples']:,}",
        f" Stage 1{itc_tag:<6} : {plan['stage1_epochs']} epochs → {s1} steps "
        f"(eff.batch {s1_eff}; global steps 1–{s1})",
        f" Stage 2 : {plan['stage2_epochs']} epochs → {s2} steps "
        f"(eff.batch {s2_eff}; global steps {s1+1}–{tot})",
        f" TOTAL : {tot} optimizer steps",
    ]
    if cur:
        lines += [" " + "─" * (len(sep) - 4), f" {cur}"]
    lines.append(sep)
    return "\n".join(lines)


def get_trainer(
    model,
    train_dataset,
    val_dataset,
    collator,
    training_args: TrainingArguments,
    itc_mode: bool = False,
    itc_temperature: float = 0.07,
) -> Trainer:
    """Build a HuggingFace Trainer.

    When `itc_mode=True` the loss is a symmetric image-text InfoNCE (Stage-1
    contrastive alignment); otherwise it's the causal-LM loss.
    """

    class CXRTrainer(Trainer):
        """Custom Trainer that passes images to model.forward(), and saves
        only the trainable artifacts (projection MLP + LoRA adapters) instead
        of the full ~5 GB Vicuna state dict that base Trainer would dump.
        Resume and best-model reload re-load the same artifacts."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            if itc_mode:
                return self._itc_loss(model, inputs, return_outputs)
            outputs = model(
                images=inputs["images"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                labels=inputs["labels"],
            )
            loss = outputs["loss"]
            return (loss, outputs) if return_outputs else loss

        def _itc_loss(self, model, inputs, return_outputs=False):
            """Symmetric image-text InfoNCE (CLIP/BLIP-2 ITC style).

            Image embeds come from model.forward_itc; text embeds are the
            precomputed CXR-BERT vectors (re-normalized here). The diagonal
            of the (B, B) similarity matrix is the positive pair (dataset is
            deduped to one image per study, so no in-batch collisions)."""
            mdl = model.module if hasattr(model, "module") else model
            img = mdl.forward_itc(inputs["images"])                          # (B, d) normed
            txt = F.normalize(inputs["text_embeds"].to(img.dtype), dim=-1)   # (B, d)
            logit_scale = 1.0 / max(itc_temperature, 1e-4)
            logits = logit_scale * img @ txt.t()                             # (B, B)
            tgt = torch.arange(logits.size(0), device=logits.device)
            loss = 0.5 * (F.cross_entropy(logits, tgt) + F.cross_entropy(logits.t(), tgt))
            return (loss, {"loss": loss, "logits": logits}) if return_outputs else loss

        def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
            if itc_mode:
                # ITC eval has no token "labels"; force loss-only eval so
                # metric_for_best_model="eval_loss" works. Mirror the base
                # implementation's device placement before the forward.
                inputs = self._prepare_inputs(inputs)
                with torch.no_grad():
                    loss = self.compute_loss(model, inputs)
                return (loss.detach(), None, None)
            if self.compute_metrics is None:
                # Nothing consumes predictions (only eval_loss is used), so
                # don't gather (B, T, vocab) logits for the whole eval set.
                prediction_loss_only = True
            return super().prediction_step(model, inputs, prediction_loss_only,
                                           ignore_keys=ignore_keys)

        def _save(self, output_dir=None, state_dict=None):
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            # projection (.pt) + LoRA folder + optional CheXpert classifier
            save_checkpoint(self.model, output_dir, name="checkpoint")
            # TrainingArguments dump — needed for resume sanity check
            torch.save(self.args, Path(output_dir) / "training_args.bin")

        def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
            # Bypass upstream's WEIGHTS_NAME / SAFE_WEIGHTS_NAME existence
            # check entirely; we never write those files.
            if model is None:
                model = self.model
            load_checkpoint(model, resume_from_checkpoint)

        def _load_best_model(self):
            # Upstream looks for pytorch_model.bin / model.safetensors /
            # adapter_model.* in best_model_checkpoint; `_save` never writes
            # those, so upstream only logs "Could not locate the best model"
            # and keeps the LAST weights despite load_best_model_at_end=True.
            # Reload via our own artifact format instead.
            best = self.state.best_model_checkpoint
            if best is None:
                return
            self._load_from_checkpoint(best)

        def _get_train_sampler(self, *args, **kwargs):
            """
            Use `WeightedRandomSampler` when the train dataset is mixed-task
            and exposes per-sample weights — this is what makes the
            configured `tasks.*.weight` ratios actually control batch
            composition. Falls back to HF's default (RandomSampler /
            DistributedSampler) for single-task or eval-time datasets.

            Notes:
              * Eval is unaffected — HF's `_get_eval_sampler` returns a
                `SequentialSampler` by default, so weighted reweighting only
                applies to training.
              * `replacement=True` is required for true oversampling —
                without it you can't draw more samples of a rare-but-upweighted
                task than physically exist. Tradeoff: a small fraction of
                samples in a numerous-but-downweighted task may never appear
                in a given epoch. Acceptable across multiple epochs.
            """
            ds = self.train_dataset
            getter = getattr(ds, "get_per_sample_weights", None)
            if getter is not None:
                weights = getter()
                if weights is not None:
                    return WeightedRandomSampler(
                        weights=weights,
                        num_samples=len(ds),
                        replacement=True,
                    )
            return super()._get_train_sampler(*args, **kwargs)

    trainer = CXRTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collator,
    )
    trainer.remove_callback(PrinterCallback)
    # Replace default ProgressCallback with one that skips the eval per-batch
    # bar — see _NoEvalTqdmCallback docstring for the Colab-subprocess rationale.
    trainer.remove_callback(ProgressCallback)
    trainer.add_callback(_NoEvalTqdmCallback())
    return trainer


def _cfg(stage_cfg, tr, key, default=None):
    """
    Resolve a training hyperparameter with per-stage override semantics:
    use the value from `stage_cfg` (stage1/stage2) if present and not None,
    else fall back to the global `training:` block, else `default`.

    This lets stages share machine-level settings (fp16, dataloader_*) while
    overriding stage-specific ones (batch size, warmup, optim) — see
    train_config.yaml docs.
    """
    stage_value = None if stage_cfg is None else stage_cfg.get(key, None)
    if stage_value is not None:
        return stage_value
    return tr.get(key, default)


def _build_training_args(train_cfg, stage_cfg, out_dir, run_name, *, enable_best, overrides=None):
    """
    Build TrainingArguments from config.

    Save / eval cadence, total limit, best-model logic come from
    `train_cfg.training`; per-stage values (batch, lr, warmup, optim, ...)
    resolve via `_cfg` (stage override → global). `overrides` (dict) wins
    over both — used for the ITC Stage-1 batch bump. `enable_best` toggles
    load_best_model_at_end.
    """
    tr = train_cfg.training
    forced = overrides or {}

    def pick(key, default=None):
        return forced[key] if key in forced else _cfg(stage_cfg, tr, key, default)

    save_strategy = getattr(tr, "save_strategy", "steps")
    eval_strategy = getattr(tr, "evaluation_strategy", save_strategy)
    save_steps = getattr(tr, "save_steps", 200)
    eval_steps = getattr(tr, "eval_steps", save_steps)
    num_workers = getattr(tr, "dataloader_num_workers", 4)

    kwargs = dict(
        output_dir=out_dir,
        num_train_epochs=stage_cfg.num_epochs,
        per_device_train_batch_size=pick("per_device_train_batch_size", 1),
        per_device_eval_batch_size=pick("per_device_eval_batch_size", 1),
        gradient_accumulation_steps=pick("gradient_accumulation_steps", 1),
        learning_rate=stage_cfg.learning_rate,
        lr_scheduler_type=pick("lr_scheduler_type", "cosine"),
        warmup_ratio=pick("warmup_ratio", 0.0),
        weight_decay=pick("weight_decay", 0.0),
        fp16=tr.fp16,
        bf16=getattr(tr, "bf16", False),
        save_strategy=save_strategy,
        eval_strategy=eval_strategy,
        logging_steps=tr.logging_steps,
        save_total_limit=getattr(tr, "save_total_limit", 1),
        report_to="wandb" if train_cfg.wandb.enabled else "none",
        run_name=run_name,
        dataloader_num_workers=num_workers,
        dataloader_pin_memory=getattr(tr, "dataloader_pin_memory", True),
        # Persistent workers are only meaningful with >0 workers.
        dataloader_persistent_workers=(
            getattr(tr, "dataloader_persistent_workers", True) and num_workers > 0
        ),
        # `paged_adamw_8bit` (bnb) cuts optimizer-state VRAM ~4× with no
        # measurable quality loss (Dettmers ICLR'22). Default keeps the
        # legacy fp32 AdamW for backward-compat; the auto-detect cell in
        # the Colab notebook switches it on for Ampere+ GPUs.
        optim=pick("optim", "adamw_torch"),
        remove_unused_columns=False,
    )
    if save_strategy == "steps":
        kwargs["save_steps"] = save_steps
    if eval_strategy == "steps":
        kwargs["eval_steps"] = eval_steps
    if enable_best:
        kwargs["load_best_model_at_end"] = getattr(tr, "load_best_model_at_end", True)
        kwargs["metric_for_best_model"] = getattr(tr, "metric_for_best_model", "eval_loss")
        kwargs["greater_is_better"] = getattr(tr, "greater_is_better", False)
    return TrainingArguments(**kwargs)


class HFBestLastCallback(TrainerCallback):
    """
    Maintain exactly two checkpoint folders + a training log on HF Hub for
    each stage (paths relative to the tracker's remote run folder):
        <stage>/last/                  ← latest checkpoint (for resume)
        <stage>/best/                  ← best `metric_for_best` so far (for inference)
        <stage>/best/best_meta.json
        <stage>/training_log.jsonl

    `last/` is overwritten on every save. `best/` is overwritten only when
    `metric_for_best` improves. No checkpoint-<step>/ folders accumulate.
    Each upload first deletes the remote folder so orphan files (e.g. a
    `optimizer.pt` from a previous step that no longer exists locally) are
    purged. Upload failures are logged as warnings, never raised.
    """

    def __init__(self, tracker, stage_subdir: str, logger,
                 metric_for_best: str = "eval_loss",
                 greater_is_better: bool = False):
        self.tracker = tracker
        self.stage_subdir = stage_subdir
        self.logger = logger
        self.metric_for_best = metric_for_best
        self.greater_is_better = greater_is_better
        self.best_metric = None
        self.best_step = None

    def _is_better(self, m: float) -> bool:
        if self.best_metric is None:
            return True
        return (m > self.best_metric) if self.greater_is_better else (m < self.best_metric)

    def _metric_value(self, metrics):
        """Look up the tracked metric, accepting names without the `eval_`
        prefix (HF normalizes metric_for_best_model the same way)."""
        key = self.metric_for_best
        value = metrics.get(key)
        if value is None and not key.startswith("eval_"):
            value = metrics.get(f"eval_{key}")
        return value

    def on_evaluate(self, args, state, control, metrics=None, **kw):
        if self.tracker is None or not metrics:
            return
        m = self._metric_value(metrics)
        if m is None:
            return
        if self._is_better(m):
            self.best_metric = float(m)
            self.best_step = state.global_step

    def on_save(self, args, state, control, **kw):
        if self.tracker is None:
            return
        ckpt_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
        if not ckpt_dir.exists():
            return

        # ── last/ — every save ────────────────────────────────────────
        try:
            self.tracker.delete_remote(f"{self.stage_subdir}/last")
            self.tracker.upload_folder(str(ckpt_dir), f"{self.stage_subdir}/last")
        except Exception as e:
            self.logger.warning(f"[HF upload] last @ step {state.global_step} failed: {e}")

        # ── best/ — only if this step is the new best ────────────────
        # on_evaluate runs before on_save within the same step, so best_step
        # already reflects this step's eval.
        if self.best_step == state.global_step:
            try:
                self.tracker.delete_remote(f"{self.stage_subdir}/best")
                self.tracker.upload_folder(str(ckpt_dir), f"{self.stage_subdir}/best")
                self.tracker.upload_json(
                    {
                        "step": state.global_step,
                        "epoch": state.epoch,
                        self.metric_for_best: self.best_metric,
                    },
                    f"{self.stage_subdir}/best/best_meta.json",
                )
            except Exception as e:
                self.logger.warning(f"[HF upload] best @ step {state.global_step} failed: {e}")

        # ── training log — full log_history each save ────────────────
        try:
            self.tracker.upload_jsonl(
                state.log_history,
                f"{self.stage_subdir}/training_log.jsonl",
            )
        except Exception as e:
            self.logger.warning(f"[HF upload] training_log failed: {e}")


def _build_datasets(spec: DatasetSpec, train_cfg, model, transform_train, transform_val,
                    itc_mode: bool = False, itc_text_cache=None):
    """Construct train + val CXRInstructDataset instances from a DatasetSpec.

    When `itc_mode=True`, datasets yield (image, precomputed text_embed)
    pairs for Stage-1 contrastive alignment instead of tokenized
    prompt/target."""
    feature_cache_dir = getattr(train_cfg.data, "feature_cache_dir", None) or None
    if feature_cache_dir:
        print(f"[_build_datasets] feature_cache_dir = {feature_cache_dir} "
              f"(encoder bypass on cache hit)")
    shared_kwargs = {
        "data_path": spec.instruct_json,
        "image_root": spec.image_root,
        "tokenizer": model.tokenizer,
        "task": "mixed",
        "cutoff_len": train_cfg.training.cutoff_len,
        "task_weights": spec.task_weights,
        "max_images": spec.max_images,
        "feature_cache_dir": feature_cache_dir,
        "itc_mode": itc_mode,
        "itc_text_cache": itc_text_cache,
    }
    train_ds = CXRInstructDataset(transform=transform_train, split=train_cfg.data.train_split,
                                  **shared_kwargs)
    val_ds = CXRInstructDataset(transform=transform_val, split=train_cfg.data.val_split,
                                **shared_kwargs)
    return train_ds, val_ds


def _stage1_itc_cfg(train_cfg):
    """Return the stage1.itc DictConfig if ITC is enabled, else None."""
    itc = train_cfg.stage1.get("itc", None)
    return itc if (itc and itc.get("enabled", False)) else None


def _upload_stage_final_to_hub(tracker, out_dir, stage, meta, logger):
    """Mirror `<stage>_final_*` artifacts into `<stage>/best/` on the Hub.

    Files are renamed to the `checkpoint_*` names that
    load_checkpoint(name="checkpoint") expects. The projection is always
    uploaded; the LoRA folder, ITC head and CheXpert classifier only when
    present locally. `meta` is then written via `tracker.write_meta`.

    Failures are logged, not raised: the artifacts are already on local
    disk, and a Hub hiccup must not abort the remaining stages.
    """
    src = Path(out_dir)
    prefix = f"{stage}_final"
    dst = f"{stage}/best"
    try:
        tracker.delete_remote(dst)
        tracker.upload_file(
            str(src / f"{prefix}_projection.pt"),
            f"{dst}/checkpoint_projection.pt",
        )
        # ITC Stage-1 has no LoRA folder (no Vicuna loaded).
        lora_dir = src / f"{prefix}_lora"
        if lora_dir.is_dir():
            tracker.upload_folder(str(lora_dir), f"{dst}/checkpoint_lora")
        else:
            logger.info(f"[HF upload] no {lora_dir.name}/ folder — skipped")
        for artifact in ("itc_head", "chexpert_classifier"):
            pt = src / f"{prefix}_{artifact}.pt"
            if pt.exists():
                tracker.upload_file(str(pt), f"{dst}/checkpoint_{artifact}.pt")
        tracker.write_meta(meta)
    except Exception as e:
        logger.warning(f"[HF upload] {stage} final upload failed: {type(e).__name__}: {e}")


def run_stage1(model, train_cfg, model_cfg, spec, out_dir, logger, tracker=None, resume_from=None):
    """
    Stage 1: align the projection.
      • default        → train projection with the causal-LM loss (Vicuna fwd).
      • stage1.itc on  → train projection + ITC head with InfoNCE against
                          precomputed CXR-BERT text embeddings (no Vicuna).
    Vision encoder is frozen in both.
    """
    itc = _stage1_itc_cfg(train_cfg)
    mode_tag = "ITC contrastive" if itc else "projection only (LM loss)"
    logger.info("=" * 60)
    logger.info(f"STAGE 1: {mode_tag} [{spec.dataset_name}]")
    logger.info(f" output_dir = {out_dir}")
    logger.info("=" * 60)

    overrides = None
    if itc:
        model.set_stage1_itc_mode()
        text_cache = itc.get("text_embed_cache", None)
        if not text_cache:
            raise ValueError("stage1.itc.enabled=true requires stage1.itc.text_embed_cache")
        train_ds, val_ds = _build_datasets(
            spec, train_cfg, model,
            transform_train=BioViLTEncoder.get_transform("train"),
            transform_val=BioViLTEncoder.get_transform("val"),
            itc_mode=True, itc_text_cache=text_cache,
        )
        collator = ITCDataCollator()
        # ITC batch overrides (no Vicuna → big batch). Fall through to stage/global.
        overrides = {
            k: itc.get(k)
            for k in ("per_device_train_batch_size", "per_device_eval_batch_size",
                      "gradient_accumulation_steps")
            if itc.get(k, None) is not None
        }
    else:
        model.set_stage1_mode()
        train_ds, val_ds = _build_datasets(
            spec, train_cfg, model,
            transform_train=BioViLTEncoder.get_transform("train"),
            transform_val=BioViLTEncoder.get_transform("val"),
        )
        collator = CXRDataCollator(pad_token_id=model.tokenizer.pad_token_id)
    model.print_trainable_params()

    training_args = _build_training_args(
        train_cfg, train_cfg.stage1, out_dir,
        run_name=f"{spec.dataset_name}-{train_cfg.wandb.run_name}-stage1",
        enable_best=True,
        overrides=overrides,
    )
    trainer = get_trainer(
        model, train_ds, val_ds, collator, training_args,
        itc_mode=bool(itc),
        itc_temperature=float(itc.get("temperature", 0.07)) if itc else 0.07,
    )
    if getattr(train_cfg.training, "upload_intermediate_to_hf", False) and tracker is not None:
        tr = train_cfg.training
        trainer.add_callback(HFBestLastCallback(
            tracker,
            stage_subdir="stage1",
            logger=logger,
            metric_for_best=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        ))

    if resume_from:
        logger.info(f"Resuming stage1 from checkpoint: {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()

    save_checkpoint(model, out_dir, "stage1_final")
    logger.info(f"Stage 1 complete. Checkpoint saved to {out_dir}")

    # ── HF Hub upload: stage1 final → overwrite best/ ────────────────
    # With load_best_model_at_end=True (and CXRTrainer._load_best_model),
    # the in-memory model after train() is the best one; save_checkpoint
    # just dumped it as stage1_final_*.
    if tracker is not None:
        _upload_stage_final_to_hub(
            tracker, out_dir, "stage1",
            meta={
                "dataset_name": spec.dataset_name,
                "stage1_done": True,
                "stage1_mode": "itc" if itc else "lm",
                "stage1_output_dir": out_dir,
                "stage1_epochs": train_cfg.stage1.num_epochs,
                "stage1_lr": train_cfg.stage1.learning_rate,
            },
            logger=logger,
        )
    return model


def run_stage2(model, train_cfg, model_cfg, spec, out_dir, logger, resume_from=None, tracker=None):
    """
    Stage 2: Train projection + LLM LoRA (instruction tuning).
    Vision encoder frozen. LLM trained via LoRA adapters.
    """
    logger.info("=" * 60)
    logger.info(f"STAGE 2: Instruction tuning (projection + LoRA) [{spec.dataset_name}]")
    logger.info(f" output_dir = {out_dir}")
    logger.info("=" * 60)

    model.set_stage2_mode()
    model.print_trainable_params()

    train_ds, val_ds = _build_datasets(
        spec, train_cfg, model,
        transform_train=BioViLTEncoder.get_transform("train"),
        transform_val=BioViLTEncoder.get_transform("val"),
    )
    training_args = _build_training_args(
        train_cfg, train_cfg.stage2, out_dir,
        run_name=f"{spec.dataset_name}-{train_cfg.wandb.run_name}-stage2",
        enable_best=True,
    )
    collator = CXRDataCollator(pad_token_id=model.tokenizer.pad_token_id)
    trainer = get_trainer(model, train_ds, val_ds, collator, training_args)
    if getattr(train_cfg.training, "upload_intermediate_to_hf", False) and tracker is not None:
        tr = train_cfg.training
        trainer.add_callback(HFBestLastCallback(
            tracker,
            stage_subdir="stage2",
            logger=logger,
            metric_for_best=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        ))

    if resume_from:
        logger.info(f"Resuming stage2 from checkpoint: {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()

    save_checkpoint(model, out_dir, "stage2_final")
    logger.info(f"Stage 2 complete. Checkpoint saved to {out_dir}")

    # ── HF Hub upload: stage2 final → overwrite best/ ────────────────
    if tracker is not None:
        _upload_stage_final_to_hub(
            tracker, out_dir, "stage2",
            meta={
                "dataset_name": spec.dataset_name,
                "stage2_done": True,
                "stage2_output_dir": out_dir,
                "stage2_epochs": train_cfg.stage2.num_epochs,
                "stage2_lr": train_cfg.stage2.learning_rate,
            },
            logger=logger,
        )
    return model


def main():
    """CLI entry point: resolve configs / run_id / resume point, then train."""
    args = parse_args()
    logger = setup_logger("cxr_vlm_train")

    # Load configs
    model_cfg = OmegaConf.load(args.model_config)
    train_cfg = OmegaConf.load(args.train_config)
    logger.info(f"Model config: {args.model_config}")
    logger.info(f"Train config: {args.train_config}")

    # ── Resolve dataset spec (paths, tasks, weights) ─────────────────
    spec = resolve_dataset_spec(train_cfg)
    logger.info(f"Dataset: {spec.dataset_name}")
    logger.info(f" image_root = {spec.image_root}")
    logger.info(f" instruct_json = {spec.instruct_json}")
    logger.info(f" tasks = {spec.tasks}")
    logger.info(f" task_weights = {spec.task_weights}")

    # ── Resolve per-dataset run_id (e.g. IU-Xray_run_3) ──────────────
    output_root = str(train_cfg.training.get("output_root", "checkpoints"))
    state_file = str(train_cfg.hf_hub.run_state_file)
    hf_enabled = bool(train_cfg.hf_hub.enabled)
    hf_token = (
        os.environ.get(train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN"))
        if hf_enabled else None
    )
    hf_repo_id = train_cfg.hf_hub.repo_id if hf_enabled else None
    stage1_subdir = str(train_cfg.stage1.get("subdir", "stage1_projection"))
    stage2_subdir = str(train_cfg.stage2.get("subdir", "stage2_instruct"))

    # Unified --mode controller. Falls back to the legacy inference (any of
    # --resume_from / --resume_from_hf set ⇒ resuming) when --mode is unset.
    if args.mode == "resume":
        resuming = True
    elif args.mode == "fresh":
        resuming = False
    else:
        resuming = bool(args.resume_from) or args.resume_from_hf

    run_id = resolve_run_id(
        dataset_name=spec.dataset_name,
        output_root=output_root,
        state_file=state_file,
        resuming=resuming,
        explicit=args.run_id,
        hf_repo_id=hf_repo_id,
        hf_token=hf_token,
    )
    logger.info(f"run_id = {run_id}")

    # ── Optional: pull last/ from HF Hub for resume on a fresh VM ────
    if args.resume_from_hf:
        if not hf_enabled or not train_cfg.hf_hub.repo_id:
            sys.exit("--resume_from_hf requires hf_hub.enabled=true and a repo_id")
        # On the hub the stage subfolder is just "stage1" / "stage2"
        hub_stage_dir = "stage1" if args.stage == 1 else "stage2"
        local_resume = pull_last_for_resume(
            repo_id=train_cfg.hf_hub.repo_id,
            token=hf_token,
            run_id=run_id,
            stage_subdir=hub_stage_dir,
            local_root=str(Path(output_root) / "_resume_from_hf"),
        )
        if local_resume is None:
            sys.exit(f"Could not pull {run_id}/{hub_stage_dir}/last/ from HF — abort.")
        args.resume_from = local_resume
        logger.info(f"Will resume from pulled checkpoint: {local_resume} ({hub_stage_dir})")

    # Single source of truth for "is this run a continuation?" — covers
    # --mode resume as well as legacy --resume_from / --resume_from_hf.
    resumed = resuming or bool(args.resume_from)

    # ── Fresh-VM resume: hydrate from HF before detect_resume_point ──
    # When `--mode resume` is set but the local run dir is empty (Colab
    # persistence lost, switching machines), pull configs + last/best
    # checkpoints from HF Hub into the canonical local layout so the
    # detector finds them. Skipped when HF tracking is disabled or no token.
    if args.mode == "resume" and hf_repo_id and hf_token:
        try:
            hydrate_run_dir_from_hf(
                repo_id=hf_repo_id,
                token=hf_token,
                run_id=run_id,
                output_root=output_root,
                stage1_subdir=stage1_subdir,
                stage2_subdir=stage2_subdir,
            )
        except Exception as e:
            logger.warning(f"[resume hydrate] {type(e).__name__}: {e}")

    # ── Compute per-stage output dirs under {output_root}/{run_id}/ ──
    stage1_out = stage_dir(output_root, run_id, stage1_subdir)
    stage2_out = stage_dir(output_root, run_id, stage2_subdir)

    # ── Auto-detect where to resume from (when --mode resume) ─────────
    # Examines disk state inside {output_root}/{run_id}/ and chooses:
    #   • stage1 from scratch / stage1 mid-checkpoint
    #   • stage2 from scratch (stage1 done) / stage2 mid-checkpoint
    #   • done (both stages finished — skip everything)
    # If the user passed --stage explicitly, that wins over auto-detect.
    auto_target_stage = None
    auto_resume_ckpt = None
    if args.mode == "resume" and args.stage is None:
        auto_target_stage, auto_resume_ckpt = detect_resume_point(
            run_dir(output_root, run_id), stage1_subdir, stage2_subdir,
        )
        logger.info(
            f"[resume autodetect] target={auto_target_stage} "
            f"ckpt={auto_resume_ckpt}"
        )

    # ── Pretty plan banner (total steps across both stages) ───────────
    plan = compute_training_plan(train_cfg, spec.instruct_json)
    logger.info("\n" + _fmt_plan_banner(plan, run_id, auto_target_stage or "stage1",
                                        auto_resume_ckpt))
    if auto_target_stage == "done":
        logger.info("Both stages already complete for this run. Exiting cleanly.")
        return

    # ── Snapshot resolved config into the run dir ────────────────────
    # Every run gets its own self-describing folder so we never have to ask
    # "what config did IU-Xray_run_3 actually use?" — open run_meta.json.
    # Written AFTER stage dirs are resolved so the run dir definitely exists.
    save_run_config(
        run_dir_path=run_dir(output_root, run_id),
        spec=spec,
        model_cfg=model_cfg,
        train_cfg=train_cfg,
        extra={
            "stage_arg": args.stage,
            "resumed": resumed,
            "resume_from": args.resume_from,
            "resume_from_hf": args.resume_from_hf,
            "model_config_path": args.model_config,
            "train_config_path": args.train_config,
        },
    )

    # Setup WandB
    if train_cfg.wandb.enabled:
        os.environ["WANDB_PROJECT"] = train_cfg.wandb.project

    # ── HuggingFace Hub tracker (optional) ───────────────────────────
    # Pass our resolved run_id as `explicit_run_id` so the tracker uses
    # the same dataset-prefixed folder on the hub.
    tracker = build_tracker_from_cfg(
        train_cfg,
        resuming=resumed,
        explicit_run_id=run_id,
    )
    if tracker is not None:
        tracker.write_meta({
            "dataset_name": spec.dataset_name,
            "run_id": run_id,
            "config_model": args.model_config,
            "config_train": args.train_config,
            "resumed": resumed,
            "resume_from": args.resume_from,
        })
        # Snapshot the resolved config + run_meta.json to HF so the run is
        # self-describing on the hub. `save_run_config` writes these into
        # {run_dir}/configs/ + {run_dir}/run_meta.json a few lines above.
        rd = run_dir(output_root, run_id)
        if (rd / "configs").is_dir():
            tracker.upload_folder(str(rd / "configs"), "configs")
        if (rd / "run_meta.json").is_file():
            tracker.upload_file(str(rd / "run_meta.json"), "run_meta.json")

    # Run training stages
    #
    # Stage selection priority:
    #   1. Explicit --stage from CLI wins.
    #   2. --mode resume + auto-detect: skip stage1 when its final ckpt exists,
    #      resume stage1/stage2 from `auto_resume_ckpt` as detected above.
    #   3. Otherwise: enabled flags from train_cfg drive it (legacy: run both).
    if args.stage is not None:
        run_s1 = (args.stage == 1) and train_cfg.stage1.enabled
        run_s2 = (args.stage == 2) and train_cfg.stage2.enabled
    elif auto_target_stage == "stage2":
        # Stage 1 finished previously — skip it entirely.
        run_s1 = False
        run_s2 = train_cfg.stage2.enabled
    else:
        run_s1 = train_cfg.stage1.enabled
        run_s2 = train_cfg.stage2.enabled

    # Decide the resume checkpoint each stage should use.
    # Manual --resume_from still wins when --stage is given explicitly.
    s1_resume_path = None
    s2_resume_path = None
    if args.stage == 1:
        s1_resume_path = args.resume_from
    elif args.stage == 2:
        s2_resume_path = args.resume_from
    elif auto_target_stage == "stage1":
        s1_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None
    elif auto_target_stage == "stage2":
        s2_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None

    # ── Build model (per-stage) ──────────────────────────────────────
    # ITC Stage-1 needs no Vicuna → build a light model (saves ~13GB VRAM,
    # enables a big contrastive batch). Every other case builds the full
    # model with Vicuna + LoRA.
    itc_on = _stage1_itc_cfg(train_cfg) is not None
    if itc_on and run_s1:
        logger.info("Building CXR VLM (light: ITC Stage-1, Vicuna NOT loaded)...")
        model = CXRVisionLanguageModel(model_cfg, load_llm=False, build_itc_head=True)
    else:
        logger.info("Building CXR VLM...")
        model = CXRVisionLanguageModel(model_cfg, load_llm=True, build_itc_head=False)

    if args.resume_from:
        logger.info(f"Resuming from: {args.resume_from}")
        load_checkpoint(model, args.resume_from)

    if run_s1:
        model = run_stage1(
            model, train_cfg, model_cfg, spec, stage1_out, logger,
            tracker=tracker,
            resume_from=s1_resume_path,
        )

    if run_s2:
        # Stage 2 always needs the full model. If Stage 1 built a light
        # (no-Vicuna) model, free it and rebuild the full one; the stage1
        # projection is then seeded from the checkpoint just below.
        if getattr(model, "llm", None) is None:
            logger.info("Rebuilding full CXR VLM (with Vicuna) for Stage 2...")
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            model = CXRVisionLanguageModel(model_cfg, load_llm=True, build_itc_head=False)
            if args.resume_from and not s2_resume_path:
                load_checkpoint(model, args.resume_from)

        # Load stage1 projection weights before stage2 if available.
        # `save_checkpoint(model, stage1_out, "stage1_final")` writes
        # stage1_final_projection.pt (+ optional siblings) into stage1_out,
        # so we probe that file and load by artifact name.
        # Priority:
        #   1. Just finished stage1 in this run → load stage1_final_* from stage1_out
        #   2. Not running stage1 but stage1_final_* exists on disk → load it
        #   3. s2_resume_path set (we're mid-stage2) → Trainer will reload from
        #      the checkpoint itself; no need to seed stage1 weights here.
        #   4. Nothing → warn loudly; stage2 starts with random projection.
        stage1_proj = Path(stage1_out) / "stage1_final_projection.pt"
        if run_s1:
            load_checkpoint(model, str(stage1_out), name="stage1_final")
            logger.info(f"Loaded stage1 weights from this run: {stage1_proj}")
        elif stage1_proj.exists() and not s2_resume_path:
            load_checkpoint(model, str(stage1_out), name="stage1_final")
            logger.info(f"Auto-loaded existing stage1 weights: {stage1_proj}")
        elif not s2_resume_path:
            logger.warning(
                "⚠ No stage1 weights found and not resuming. Projection layer "
                "will start RANDOMLY for stage2. Expect degraded convergence. "
                f"Looked at: {stage1_proj}"
            )

        model = run_stage2(
            model, train_cfg, model_cfg, spec, stage2_out, logger,
            resume_from=s2_resume_path,
            tracker=tracker,
        )

    # Upload final results folder (predictions + metrics) if evaluate.py has run
    if tracker is not None:
        try:
            results_dir = Path("results") / run_id
            if results_dir.exists():
                tracker.upload_folder(str(results_dir), "results")
            tracker.write_meta({"training_complete": True})
        except Exception as e:
            logger.warning(f"[HF upload] final results upload failed: {type(e).__name__}: {e}")

    logger.info("Training complete!")


if __name__ == "__main__":
    main()