| """ |
| train.py |
| -------- |
| Main training entry point. |
| Implements 2-stage training following RaDialog: |
| Stage 1 — Train projection layer only (2 epochs, LR=1e-3) |
| Stage 2 — Train projection + LLM LoRA (10 epochs, LR=2e-4) |
| |
| Supports two datasets, selected via `train_cfg.data.dataset_name`: |
| - "MIMIC-CXR" → all 3 tasks (findings, impression, VQA) |
| - "IU-Xray" → findings + impression only (no VQA) |
| |
| Checkpoints and results are written under: |
| {training.output_root}/{dataset_name}_run_{N}/stageX_* |
| |
| Usage: |
| python -m training.train --train_config configs/train_config.yaml |
| """ |
|
|
| import os |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
| import utils._quiet |
|
|
| import argparse |
| import torch |
| from omegaconf import OmegaConf |
|
|
| |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
| import transformers |
| from transformers import TrainingArguments, Trainer, TrainerCallback, PrinterCallback |
| from transformers.trainer_callback import ProgressCallback |
|
|
|
|
class _NoEvalTqdmCallback(ProgressCallback):
    """HF ``ProgressCallback`` with the per-batch evaluation bar switched off.

    Rationale (as recorded by the original author): when training is launched
    from a Colab ``!python -m ...`` subprocess, HF Trainer does not detect a
    notebook and falls back to plain tqdm. Colab's text renderer copes badly
    with fast carriage-return updates, so the evaluation bar ends up printing
    a fresh line per step and slows the browser tab down. The training bar
    updates slowly enough to stay readable, so only the prediction bar is
    suppressed. ``eval_loss`` is still reported at the end of each evaluation
    pass through the regular log history.
    """

    def on_prediction_step(self, args, state, control, **kwargs):
        # Intentional no-op: the parent hook is never invoked from here, so
        # the per-batch evaluation progress bar is never created or advanced.
        return None
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from model import CXRVisionLanguageModel |
| from model.rad_dino import BioViLTEncoder |
| from data import CXRInstructDataset, CXRDataCollator, ITCDataCollator |
| from utils.logger import setup_logger |
| from utils.checkpoint import save_checkpoint, load_checkpoint |
| from utils.hf_uploader import build_tracker_from_cfg, pull_last_for_resume, hydrate_run_dir_from_hf |
| from utils.dataset_resolver import ( |
| resolve_dataset_spec, |
| resolve_run_id, |
| save_run_config, |
| run_dir, |
| stage_dir, |
| DatasetSpec, |
| ) |
|
|
|
|
def parse_args():
    """Parse the command-line options of the training entry point.

    Returns:
        argparse.Namespace with the attributes ``model_config``,
        ``train_config``, ``stage``, ``mode``, ``resume_from``, ``run_id``
        and ``resume_from_hf``.
    """
    cli = argparse.ArgumentParser(description="Train CXR VLM")

    # (flag, add_argument keyword arguments) — registered in this exact order
    # so that the generated --help output stays the same.
    option_table = (
        ("--model_config", {
            "type": str,
            "default": "configs/model_config.yaml",
            "help": "Path to model config",
        }),
        ("--train_config", {
            "type": str,
            "default": "configs/train_config.yaml",
            "help": "Path to training config",
        }),
        ("--stage", {
            "type": int,
            "default": None,
            "help": (
                "Run only stage 1 or stage 2 (default: run both). With --mode resume, "
                "the stage is auto-detected and this flag should be left unset."
            ),
        }),
        ("--mode", {
            "type": str,
            "default": None,
            "choices": ["fresh", "resume"],
            "help": (
                "Unified resume controller. 'fresh' → new run_N folder. "
                "'resume' → reuse latest matching run_id (or --run_id), auto-detect "
                "which stage to continue from based on checkpoints on disk. "
                "If unset, behaviour is inferred from --resume_from / --run_id (legacy)."
            ),
        }),
        ("--resume_from", {
            "type": str,
            "default": None,
            "help": "Path to checkpoint to resume from (legacy; prefer --mode resume)",
        }),
        ("--run_id", {
            "type": str,
            "default": None,
            "help": "Explicit run id (e.g. 'IU-Xray_run_3'). If unset, auto-resolve.",
        }),
        ("--resume_from_hf", {
            "action": "store_true",
            "help": (
                "Pull <run_id>/<stage>/last/ from the HF Hub and resume from it. "
                "Use when training on a fresh VM after the previous one was killed."
            ),
        }),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
|
|
| |
|
|
| def _list_checkpoints(stage_dir): |
| """Return [Path, …] of `checkpoint-NNN` folders sorted ascending by step.""" |
| if not stage_dir.is_dir(): |
| return [] |
| out = [] |
| for p in stage_dir.iterdir(): |
| if not p.is_dir() or not p.name.startswith("checkpoint-"): |
| continue |
| suffix = p.name.split("-", 1)[1] |
| if suffix.isdigit(): |
| out.append((int(suffix), p)) |
| return [p for _, p in sorted(out)] |
|
|
|
|
def detect_resume_point(run_dir_path, stage1_subdir, stage2_subdir):
    """Decide, from what is on disk, where training should continue.

    Args:
        run_dir_path: Root folder of the run (str or Path).
        stage1_subdir: Name of the stage-1 folder inside the run folder.
        stage2_subdir: Name of the stage-2 folder inside the run folder.

    Returns:
        ``(target_stage, ckpt_path)`` where ``target_stage`` is one of
        ``"stage1"``, ``"stage2"`` or ``"done"`` and ``ckpt_path`` is the
        checkpoint folder to resume from, or ``None`` when the stage starts
        from scratch.

    Checks are made in this order; the first match wins:
        1. stage-2 final projection file exists → ``("done", None)``
        2. stage-2 has step checkpoints         → ``("stage2", latest)``
        3. stage-1 final projection file exists → ``("stage2", None)``
        4. stage-1 has step checkpoints         → ``("stage1", latest)``
        5. nothing found                        → ``("stage1", None)``
    """
    from pathlib import Path as _P

    root = _P(run_dir_path)
    stage1_root = root / stage1_subdir
    stage2_root = root / stage2_subdir

    # 1. Stage 2 already produced its final artifact: nothing left to do.
    if (stage2_root / "stage2_final_projection.pt").exists():
        return ("done", None)

    # 2. Stage 2 was interrupted: continue from its newest checkpoint.
    stage2_steps = _list_checkpoints(stage2_root)
    if stage2_steps:
        return ("stage2", stage2_steps[-1])

    # 3. Stage 1 finished but stage 2 never started.
    if (stage1_root / "stage1_final_projection.pt").exists():
        return ("stage2", None)

    # 4. Stage 1 was interrupted: continue from its newest checkpoint.
    stage1_steps = _list_checkpoints(stage1_root)
    if stage1_steps:
        return ("stage1", stage1_steps[-1])

    # 5. Empty run folder.
    return ("stage1", None)
|
|
|
|
def compute_training_plan(train_cfg, instruct_json_path):
    """Estimate the optimizer-step budget of stage 1 + stage 2.

    The estimate is derived from the training config and from the number of
    samples tagged ``split == "train"`` in the instruct JSON. It is only used
    to print a human-readable summary at startup.

    Args:
        train_cfg: Training config exposing ``training``, ``stage1`` and
            ``stage2`` sections (mapping-like, with ``.get``).
        instruct_json_path: Path to the instruct JSON (a list of sample dicts).

    Returns:
        dict with the keys ``train_samples``, ``effective_batch``,
        ``stage1_eff_batch``, ``stage2_eff_batch``, ``steps_per_epoch``,
        ``stage1_steps``, ``stage2_steps``, ``total_steps``,
        ``stage1_epochs``, ``stage2_epochs`` and ``itc_stage1``.
        ``effective_batch`` / ``steps_per_epoch`` report the stage-2 values.
    """
    import json as _json

    tr = train_cfg.training
    try:
        with open(instruct_json_path, "r", encoding="utf-8") as f:
            all_samples = _json.load(f)
        train_count = sum(1 for s in all_samples if s.get("split") == "train")
    except Exception:
        # Deliberately best-effort: the plan is informational only, so an
        # unreadable or oddly shaped JSON degrades to "0 samples" instead of
        # aborting the run before training starts.
        train_count = 0

    def _eff_batch(stage_cfg, itc_override=None):
        """Effective batch = per-device batch x gradient accumulation (>= 1)."""

        def _resolve(key):
            # An ITC override only wins when it carries a real value. A key
            # that is present but set to null/None is treated as "not set",
            # the same rule run_stage1 and _cfg apply; previously such a key
            # reached int(None) and crashed.
            if itc_override is not None:
                forced = itc_override.get(key, None)
                if forced is not None:
                    return forced
            value = _cfg(stage_cfg, tr, key, 1)
            return 1 if value is None else value

        bs = int(_resolve("per_device_train_batch_size"))
        ga = int(_resolve("gradient_accumulation_steps"))
        return max(1, bs * ga)

    s1_cfg = train_cfg.stage1
    s2_cfg = train_cfg.stage2
    s1_enabled = bool(getattr(s1_cfg, "enabled", True))
    s2_enabled = bool(getattr(s2_cfg, "enabled", True))
    # A disabled stage contributes zero epochs (and therefore zero steps).
    s1_epochs = int(getattr(s1_cfg, "num_epochs", 0)) if s1_enabled else 0
    s2_epochs = int(getattr(s2_cfg, "num_epochs", 0)) if s2_enabled else 0

    itc_cfg = s1_cfg.get("itc", None)
    itc_on = bool(itc_cfg and itc_cfg.get("enabled", False))
    s1_eff = _eff_batch(s1_cfg, itc_cfg if itc_on else None)
    s2_eff = _eff_batch(s2_cfg)

    # Ceil-divide samples by the effective batch; never report fewer than one
    # step per epoch, even when the sample count is unknown (0).
    s1_spe = max(1, (train_count + s1_eff - 1) // s1_eff)
    s2_spe = max(1, (train_count + s2_eff - 1) // s2_eff)
    s1_steps = s1_spe * s1_epochs
    s2_steps = s2_spe * s2_epochs
    return {
        "train_samples": train_count,
        "effective_batch": s2_eff,
        "stage1_eff_batch": s1_eff,
        "stage2_eff_batch": s2_eff,
        "steps_per_epoch": s2_spe,
        "stage1_steps": s1_steps,
        "stage2_steps": s2_steps,
        "total_steps": s1_steps + s2_steps,
        "stage1_epochs": s1_epochs,
        "stage2_epochs": s2_epochs,
        "itc_stage1": itc_on,
    }
|
|
|
|
| def _fmt_plan_banner(plan, run_id, target_stage, resume_ckpt): |
| s1, s2, tot = plan["stage1_steps"], plan["stage2_steps"], plan["total_steps"] |
| head = f"TRAINING PLAN — {run_id}" |
| sep = "=" * max(len(head) + 4, 60) |
| cur = "" |
|
|
| |
| |
| |
| def _ckpt_step(ckpt): |
| if not ckpt: |
| return None |
| try: |
| import json as _json |
| from pathlib import Path as _P |
| ts = _P(str(ckpt)) / "trainer_state.json" |
| if ts.is_file(): |
| return int(_json.load(open(ts))["global_step"]) |
| except Exception: |
| pass |
| |
| suf = str(ckpt).split("-")[-1] |
| return int(suf) if suf.isdigit() else None |
|
|
| real_step = _ckpt_step(resume_ckpt) |
| if target_stage == "stage1": |
| offset = real_step if real_step is not None else 0 |
| cur = f"Resuming at step {offset} / {tot} (inside stage 1)" |
| elif target_stage == "stage2": |
| offset = (s1 + real_step) if real_step is not None else s1 |
| cur = f"Resuming at step {offset} / {tot} (inside stage 2)" |
| elif target_stage == "done": |
| cur = f"All {tot} steps already complete — nothing to do" |
|
|
| s1_eff = plan.get("stage1_eff_batch", plan["effective_batch"]) |
| s2_eff = plan.get("stage2_eff_batch", plan["effective_batch"]) |
| itc_tag = " [ITC]" if plan.get("itc_stage1") else "" |
| lines = [ |
| sep, f" {head}", sep, |
| f" Train samples : {plan['train_samples']:,}", |
| f" Stage 1{itc_tag:<6} : {plan['stage1_epochs']} epochs → {s1} steps " |
| f"(eff.batch {s1_eff}; global steps 1–{s1})", |
| f" Stage 2 : {plan['stage2_epochs']} epochs → {s2} steps " |
| f"(eff.batch {s2_eff}; global steps {s1+1}–{tot})", |
| f" TOTAL : {tot} optimizer steps", |
| ] |
| if cur: |
| lines += [" " + "─" * (len(sep) - 4), f" {cur}"] |
| lines.append(sep) |
| return "\n".join(lines) |
|
|
|
|
def get_trainer(
    model,
    train_dataset,
    val_dataset,
    collator,
    training_args: TrainingArguments,
    itc_mode: bool = False,
    itc_temperature: float = 0.07,
) -> Trainer:
    """Build the project-specific HuggingFace ``Trainer``.

    Args:
        model: The vision-language model to train.
        train_dataset: Training dataset. If it exposes
            ``get_per_sample_weights()`` returning non-None weights, a
            weighted sampler is used for training.
        val_dataset: Evaluation dataset.
        collator: Batch collator passed to the Trainer.
        training_args: Fully built ``TrainingArguments``.
        itc_mode: When True the loss is a symmetric image-text InfoNCE
            (stage-1 contrastive alignment); otherwise the loss returned by
            the model's forward pass is used.
        itc_temperature: Softmax temperature of the InfoNCE loss; values
            below 1e-4 are clamped to 1e-4.

    Returns:
        A configured Trainer whose default printer/progress callbacks are
        replaced by ``_NoEvalTqdmCallback``.
    """

    class CXRTrainer(Trainer):
        """Trainer that feeds images to ``model(...)`` and persists only the
        project artifacts written by ``save_checkpoint`` instead of the full
        model state dict. Resuming reloads them through ``load_checkpoint``.

        NOTE(review): because ``_save`` does not write the standard HF weight
        files, confirm that ``load_best_model_at_end`` can actually restore
        the best weights from these checkpoint folders.
        """

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extra arguments newer Trainer versions pass.
            if itc_mode:
                return self._itc_loss(model, inputs, return_outputs)
            outputs = model(
                images=inputs["images"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                labels=inputs["labels"],
            )
            loss = outputs["loss"]
            return (loss, outputs) if return_outputs else loss

        def _itc_loss(self, model, inputs, return_outputs=False):
            """Symmetric image-text InfoNCE (CLIP / BLIP-2 ITC style).

            Image embeddings come from ``forward_itc``; text embeddings are
            taken from the batch (``text_embeds``) and L2-normalised here.
            Row i of the (B, B) similarity matrix is matched against column
            i, i.e. the diagonal holds the positive pairs. The original
            author notes the dataset is deduplicated to one image per study,
            so in-batch collisions are not expected.
            """
            import torch.nn.functional as F

            # Unwrap DataParallel/DDP-style wrappers to reach forward_itc.
            mdl = model.module if hasattr(model, "module") else model
            img = mdl.forward_itc(inputs["images"])
            txt = F.normalize(inputs["text_embeds"].to(img.dtype), dim=-1)
            logit_scale = 1.0 / max(itc_temperature, 1e-4)
            logits = logit_scale * img @ txt.t()
            tgt = torch.arange(logits.size(0), device=logits.device)
            # Average of image->text and text->image cross-entropies.
            loss = 0.5 * (F.cross_entropy(logits, tgt) +
                          F.cross_entropy(logits.t(), tgt))
            return (loss, {"loss": loss, "logits": logits}) if return_outputs else loss

        def prediction_step(self, model, inputs, prediction_loss_only,
                            ignore_keys=None):
            if itc_mode:
                # ITC batches carry no labels/logits in the shape the stock
                # implementation expects, so only the loss is reported.
                # Run the batch through _prepare_inputs first, exactly as the
                # stock prediction_step does, so tensors are placed on the
                # training device regardless of how the dataloader was built.
                inputs = self._prepare_inputs(inputs)
                with torch.no_grad():
                    loss = self.compute_loss(model, inputs)
                return (loss.detach(), None, None)
            return super().prediction_step(model, inputs, prediction_loss_only,
                                           ignore_keys=ignore_keys)

        def _save(self, output_dir=None, state_dict=None):
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            # Project artifacts only (see save_checkpoint); `state_dict` is
            # intentionally ignored.
            save_checkpoint(self.model, output_dir, name="checkpoint")
            # Kept so the checkpoint folder still records its TrainingArguments.
            torch.save(self.args, Path(output_dir) / "training_args.bin")

        def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
            # Counterpart of _save: restore through the project loader.
            if model is None:
                model = self.model
            load_checkpoint(model, resume_from_checkpoint)

        def _get_train_sampler(self, *args, **kwargs):
            """Use ``WeightedRandomSampler`` when the train dataset exposes
            per-sample weights; otherwise defer to the Trainer default.

            Sampling is done with replacement and draws ``len(dataset)``
            indices per epoch, so up-weighted samples can be drawn more often
            than they physically occur, while some down-weighted samples may
            not appear in a given epoch. Only the training sampler is
            overridden here.
            """
            ds = self.train_dataset
            getter = getattr(ds, "get_per_sample_weights", None)
            if getter is not None:
                weights = getter()
                if weights is not None:
                    from torch.utils.data import WeightedRandomSampler
                    return WeightedRandomSampler(
                        weights=weights,
                        num_samples=len(ds),
                        replacement=True,
                    )
            return super()._get_train_sampler(*args, **kwargs)

    trainer = CXRTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collator,
    )
    trainer.remove_callback(PrinterCallback)
    # Swap the stock progress bar for the variant without the eval bar.
    trainer.remove_callback(ProgressCallback)
    trainer.add_callback(_NoEvalTqdmCallback())
    return trainer
|
|
|
|
| def _cfg(stage_cfg, tr, key, default=None): |
| """ |
| Resolve a training hyperparameter with per-stage override semantics: |
| use the value from `stage_cfg` (stage1/stage2) if present, else fall back |
| to the global `training:` block, else `default`. This lets stages share |
| machine-level settings (fp16, dataloader_*) while overriding stage-specific |
| ones (batch size, warmup, optim) — see train_config.yaml docs. |
| """ |
| if stage_cfg is not None: |
| v = stage_cfg.get(key, None) |
| if v is not None: |
| return v |
| return tr.get(key, default) |
|
|
|
|
def _build_training_args(train_cfg, stage_cfg, out_dir, run_name, *, enable_best,
                         overrides=None):
    """Assemble ``TrainingArguments`` for one stage.

    Save/eval cadence, checkpoint limit, precision flags, dataloader settings
    and best-model options are read from ``train_cfg.training``. Batch sizes,
    gradient accumulation, scheduler, warmup, weight decay and optimizer are
    resolved as: ``overrides`` → stage config → global training config →
    built-in default. Epoch count and learning rate always come from
    ``stage_cfg``.

    Args:
        train_cfg: Full training config.
        stage_cfg: ``train_cfg.stage1`` or ``train_cfg.stage2``.
        out_dir: Output directory of the stage.
        run_name: Run name forwarded to the experiment tracker.
        enable_best: Keyword-only; when true, the best-model options are set.
        overrides: Optional dict that wins over both config levels.

    Returns:
        The constructed ``TrainingArguments``.
    """
    tr = train_cfg.training
    forced = overrides or {}

    def resolve(name, fallback=None):
        if name not in forced:
            return _cfg(stage_cfg, tr, name, fallback)
        return forced[name]

    # Evaluation cadence follows the save cadence unless set explicitly.
    save_strategy = getattr(tr, "save_strategy", "steps")
    eval_strategy = getattr(tr, "evaluation_strategy", save_strategy)
    save_steps = getattr(tr, "save_steps", 200)
    eval_steps = getattr(tr, "eval_steps", save_steps)
    num_workers = getattr(tr, "dataloader_num_workers", 4)

    ta_kwargs = {
        "output_dir": out_dir,
        "num_train_epochs": stage_cfg.num_epochs,
        "per_device_train_batch_size": resolve("per_device_train_batch_size", 1),
        "per_device_eval_batch_size": resolve("per_device_eval_batch_size", 1),
        "gradient_accumulation_steps": resolve("gradient_accumulation_steps", 1),
        "learning_rate": stage_cfg.learning_rate,
        "lr_scheduler_type": resolve("lr_scheduler_type", "cosine"),
        "warmup_ratio": resolve("warmup_ratio", 0.0),
        "weight_decay": resolve("weight_decay", 0.0),
        "fp16": tr.fp16,
        "bf16": getattr(tr, "bf16", False),
        "save_strategy": save_strategy,
        "eval_strategy": eval_strategy,
        "logging_steps": tr.logging_steps,
        "save_total_limit": getattr(tr, "save_total_limit", 1),
        "report_to": "wandb" if train_cfg.wandb.enabled else "none",
        "run_name": run_name,
        "dataloader_num_workers": num_workers,
        "dataloader_pin_memory": getattr(tr, "dataloader_pin_memory", True),
        # Persistent workers are only requested when there are workers at all.
        "dataloader_persistent_workers": (
            getattr(tr, "dataloader_persistent_workers", True)
            and num_workers > 0
        ),
        "optim": resolve("optim", "adamw_torch"),
        # Batches carry custom keys that must not be dropped by the Trainer.
        "remove_unused_columns": False,
    }

    # Step counts are only passed for step-based strategies.
    if save_strategy == "steps":
        ta_kwargs["save_steps"] = save_steps
    if eval_strategy == "steps":
        ta_kwargs["eval_steps"] = eval_steps

    if enable_best:
        ta_kwargs.update(
            load_best_model_at_end=getattr(tr, "load_best_model_at_end", True),
            metric_for_best_model=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        )
    return TrainingArguments(**ta_kwargs)
|
|
|
|
class HFBestLastCallback(TrainerCallback):
    """Mirror two checkpoint folders plus a training log to the HF Hub.

    Remote layout per stage:

        <run_id>/<stage>/last/                 latest checkpoint (for resume)
        <run_id>/<stage>/best/                 best metric so far (inference)
        <run_id>/<stage>/best/best_meta.json
        <run_id>/<stage>/training_log.jsonl

    ``last/`` is replaced on every save; ``best/`` only when
    ``metric_for_best`` improved at that very step. Each upload deletes the
    remote folder first so files that no longer exist locally do not linger.
    All uploads are best-effort: failures are logged as warnings and never
    interrupt training.

    Args:
        tracker: Hub tracker exposing ``delete_remote``, ``upload_folder``,
            ``upload_json`` and ``upload_jsonl``; ``None`` disables the
            callback.
        stage_subdir: Remote sub-folder of the stage (e.g. ``"stage1"``).
        logger: Logger used for upload warnings.
        metric_for_best: Key looked up in the evaluation metrics.
        greater_is_better: Direction of improvement for that metric.
    """

    def __init__(self, tracker, stage_subdir: str, logger,
                 metric_for_best: str = "eval_loss",
                 greater_is_better: bool = False):
        self.tracker = tracker
        self.stage_subdir = stage_subdir
        self.logger = logger
        self.metric_for_best = metric_for_best
        self.greater_is_better = greater_is_better
        self.best_metric = None  # best value seen so far (None = nothing yet)
        self.best_step = None    # global step at which best_metric was set

    def _is_better(self, m: float) -> bool:
        """Return True when ``m`` beats the best metric recorded so far."""
        if self.best_metric is None:
            return True
        return (m > self.best_metric) if self.greater_is_better else (m < self.best_metric)

    def on_train_begin(self, args, state, control, **kw):
        # On resume the Trainer state restored from the checkpoint may already
        # carry the best metric of the earlier session. Seed from it so the
        # first evaluation after a resume does not overwrite the remote best/
        # with a worse checkpoint.
        # Assumes the Trainer tracks the same metric/direction this callback
        # was configured with — both are read from the same config keys by
        # the callers in this file.
        prior = getattr(state, "best_metric", None)
        if self.best_metric is None and prior is not None:
            self.best_metric = float(prior)

    def on_evaluate(self, args, state, control, metrics=None, **kw):
        if self.tracker is None or not metrics:
            return
        m = metrics.get(self.metric_for_best)
        if m is None:
            return
        m = float(m)
        if m != m:
            # NaN never compares as better or worse; accepting it as the first
            # "best" would block every later real improvement.
            return
        if self._is_better(m):
            self.best_metric = m
            self.best_step = state.global_step

    def on_save(self, args, state, control, **kw):
        if self.tracker is None:
            return
        ckpt_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
        if not ckpt_dir.exists():
            return

        # 1) last/ — always refreshed.
        try:
            self.tracker.delete_remote(f"{self.stage_subdir}/last")
            self.tracker.upload_folder(str(ckpt_dir), f"{self.stage_subdir}/last")
        except Exception as e:
            self.logger.warning(f"[HF upload] last @ step {state.global_step} failed: {e}")

        # 2) best/ — only when this step's evaluation set a new best.
        if self.best_step == state.global_step:
            try:
                self.tracker.delete_remote(f"{self.stage_subdir}/best")
                self.tracker.upload_folder(str(ckpt_dir), f"{self.stage_subdir}/best")
                self.tracker.upload_json(
                    {
                        "step": state.global_step,
                        "epoch": state.epoch,
                        self.metric_for_best: self.best_metric,
                    },
                    f"{self.stage_subdir}/best/best_meta.json",
                )
            except Exception as e:
                self.logger.warning(f"[HF upload] best @ step {state.global_step} failed: {e}")

        # 3) training log — full history, rewritten on every save.
        try:
            self.tracker.upload_jsonl(
                state.log_history,
                f"{self.stage_subdir}/training_log.jsonl",
            )
        except Exception as e:
            self.logger.warning(f"[HF upload] training_log failed: {e}")
|
|
|
|
def _build_datasets(spec: DatasetSpec, train_cfg, model, transform_train, transform_val,
                    itc_mode: bool = False, itc_text_cache=None):
    """Create the train and validation ``CXRInstructDataset`` objects.

    Both datasets share every setting except the image transform and the
    split name (``train_cfg.data.train_split`` / ``val_split``).

    Args:
        spec: Resolved dataset description (paths, task weights, image cap).
        train_cfg: Training config (reads ``data`` and ``training.cutoff_len``).
        model: Model providing the tokenizer.
        transform_train: Image transform for the training split.
        transform_val: Image transform for the validation split.
        itc_mode: Forwarded to the dataset; selects the stage-1 contrastive
            sample format.
        itc_text_cache: Forwarded to the dataset together with ``itc_mode``.

    Returns:
        ``(train_ds, val_ds)``.
    """
    # Empty / missing setting is normalised to None.
    cache_dir = getattr(train_cfg.data, "feature_cache_dir", None) or None
    if cache_dir:
        print(f"[_build_datasets] feature_cache_dir = {cache_dir} "
              f"(encoder bypass on cache hit)")

    shared = {
        "data_path": spec.instruct_json,
        "image_root": spec.image_root,
        "tokenizer": model.tokenizer,
        "task": "mixed",
        "cutoff_len": train_cfg.training.cutoff_len,
        "task_weights": spec.task_weights,
        "max_images": spec.max_images,
        "feature_cache_dir": cache_dir,
        "itc_mode": itc_mode,
        "itc_text_cache": itc_text_cache,
    }

    train_ds = CXRInstructDataset(
        transform=transform_train,
        split=train_cfg.data.train_split,
        **shared,
    )
    val_ds = CXRInstructDataset(
        transform=transform_val,
        split=train_cfg.data.val_split,
        **shared,
    )
    return train_ds, val_ds
|
|
|
|
| def _stage1_itc_cfg(train_cfg): |
| """Return the stage1.itc DictConfig if ITC is enabled, else None.""" |
| itc = train_cfg.stage1.get("itc", None) |
| if itc and itc.get("enabled", False): |
| return itc |
| return None |
|
|
|
|
def run_stage1(model, train_cfg, model_cfg, spec, out_dir, logger, tracker=None, resume_from=None):
    """Run stage 1 (projection alignment) and return the trained model.

    Two variants, selected by ``stage1.itc.enabled``:
      * default — the model is put in stage-1 mode and trained with the loss
        returned by its forward pass;
      * ITC — the model is put in stage-1 ITC mode and trained with the
        InfoNCE loss against the text embeddings referenced by
        ``stage1.itc.text_embed_cache``.

    Args:
        model: Model to train (mutated in place and returned).
        train_cfg: Training config.
        model_cfg: Model config (not used by this function).
        spec: Resolved dataset description.
        out_dir: Stage-1 output directory.
        logger: Logger for progress messages.
        tracker: Optional HF Hub tracker for uploads.
        resume_from: Optional checkpoint folder to resume the Trainer from.

    Raises:
        ValueError: ITC is enabled but ``text_embed_cache`` is missing/empty.
    """
    itc = _stage1_itc_cfg(train_cfg)
    rule = "=" * 60
    stage_label = "ITC contrastive" if itc else "projection only (LM loss)"
    logger.info(rule)
    logger.info(f"STAGE 1: {stage_label} [{spec.dataset_name}]")
    logger.info(f" output_dir = {out_dir}")
    logger.info(rule)

    batch_overrides = None
    if not itc:
        model.set_stage1_mode()
        train_ds, val_ds = _build_datasets(
            spec, train_cfg, model,
            transform_train=BioViLTEncoder.get_transform("train"),
            transform_val=BioViLTEncoder.get_transform("val"),
        )
        collator = CXRDataCollator(pad_token_id=model.tokenizer.pad_token_id)
    else:
        model.set_stage1_itc_mode()
        text_cache = itc.get("text_embed_cache", None)
        if not text_cache:
            raise ValueError("stage1.itc.enabled=true requires stage1.itc.text_embed_cache")
        train_ds, val_ds = _build_datasets(
            spec, train_cfg, model,
            transform_train=BioViLTEncoder.get_transform("train"),
            transform_val=BioViLTEncoder.get_transform("val"),
            itc_mode=True,
            itc_text_cache=text_cache,
        )
        collator = ITCDataCollator()
        # ITC may use its own batch geometry; only keys with a real value
        # override the stage/global settings.
        batch_overrides = {}
        for key in ("per_device_train_batch_size", "per_device_eval_batch_size",
                    "gradient_accumulation_steps"):
            if itc.get(key, None) is not None:
                batch_overrides[key] = itc.get(key)

    model.print_trainable_params()

    training_args = _build_training_args(
        train_cfg, train_cfg.stage1, out_dir,
        run_name=f"{spec.dataset_name}-{train_cfg.wandb.run_name}-stage1",
        enable_best=True,
        overrides=batch_overrides,
    )

    temperature = float(itc.get("temperature", 0.07)) if itc else 0.07
    trainer = get_trainer(
        model, train_ds, val_ds, collator, training_args,
        itc_mode=bool(itc),
        itc_temperature=temperature,
    )

    wants_hub_sync = getattr(train_cfg.training, "upload_intermediate_to_hf", False)
    if wants_hub_sync and tracker is not None:
        tr = train_cfg.training
        hub_callback = HFBestLastCallback(
            tracker,
            stage_subdir="stage1",
            logger=logger,
            metric_for_best=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        )
        trainer.add_callback(hub_callback)

    if resume_from:
        logger.info(f"Resuming stage1 from checkpoint: {resume_from}")
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()

    save_checkpoint(model, out_dir, "stage1_final")
    logger.info(f"Stage 1 complete. Checkpoint saved to {out_dir}")

    if tracker is not None:
        # Publish the final stage-1 artifacts as the remote "best" folder.
        stage_root = Path(out_dir)
        tracker.delete_remote("stage1/best")
        tracker.upload_file(
            str(stage_root / "stage1_final_projection.pt"),
            "stage1/best/checkpoint_projection.pt",
        )

        lora_root = stage_root / "stage1_final_lora"
        if lora_root.is_dir():
            tracker.upload_folder(str(lora_root), "stage1/best/checkpoint_lora")

        # Optional single-file artifacts, uploaded only when present.
        optional_files = (
            ("stage1_final_itc_head.pt", "stage1/best/checkpoint_itc_head.pt"),
            ("stage1_final_chexpert_classifier.pt",
             "stage1/best/checkpoint_chexpert_classifier.pt"),
        )
        for local_name, remote_path in optional_files:
            candidate = stage_root / local_name
            if candidate.exists():
                tracker.upload_file(str(candidate), remote_path)

        tracker.write_meta({
            "dataset_name": spec.dataset_name,
            "stage1_done": True,
            "stage1_mode": "itc" if itc else "lm",
            "stage1_output_dir": out_dir,
            "stage1_epochs": train_cfg.stage1.num_epochs,
            "stage1_lr": train_cfg.stage1.learning_rate,
        })
    return model
|
|
|
|
def run_stage2(model, train_cfg, model_cfg, spec, out_dir, logger,
               resume_from=None, tracker=None):
    """Run stage 2 (instruction tuning) and return the trained model.

    The model is switched to stage-2 mode, trained on the mixed-task
    dataset, saved as ``stage2_final`` under ``out_dir`` and, when a tracker
    is given, published to the remote ``stage2/best`` folder.

    Args:
        model: Model to train (mutated in place and returned).
        train_cfg: Training config.
        model_cfg: Model config (not used by this function).
        spec: Resolved dataset description.
        out_dir: Stage-2 output directory.
        logger: Logger for progress messages.
        resume_from: Optional checkpoint folder to resume the Trainer from.
        tracker: Optional HF Hub tracker for uploads.
    """
    rule = "=" * 60
    logger.info(rule)
    logger.info(f"STAGE 2: Instruction tuning (projection + LoRA) [{spec.dataset_name}]")
    logger.info(f" output_dir = {out_dir}")
    logger.info(rule)

    model.set_stage2_mode()
    model.print_trainable_params()

    train_ds, val_ds = _build_datasets(
        spec, train_cfg, model,
        transform_train=BioViLTEncoder.get_transform("train"),
        transform_val=BioViLTEncoder.get_transform("val"),
    )

    training_args = _build_training_args(
        train_cfg, train_cfg.stage2, out_dir,
        run_name=f"{spec.dataset_name}-{train_cfg.wandb.run_name}-stage2",
        enable_best=True,
    )

    collator = CXRDataCollator(pad_token_id=model.tokenizer.pad_token_id)
    trainer = get_trainer(model, train_ds, val_ds, collator, training_args)

    wants_hub_sync = getattr(train_cfg.training, "upload_intermediate_to_hf", False)
    if wants_hub_sync and tracker is not None:
        tr = train_cfg.training
        hub_callback = HFBestLastCallback(
            tracker,
            stage_subdir="stage2",
            logger=logger,
            metric_for_best=getattr(tr, "metric_for_best_model", "eval_loss"),
            greater_is_better=getattr(tr, "greater_is_better", False),
        )
        trainer.add_callback(hub_callback)

    if resume_from:
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()

    save_checkpoint(model, out_dir, "stage2_final")
    logger.info(f"Stage 2 complete. Checkpoint saved to {out_dir}")

    if tracker is not None:
        # Publish the final stage-2 artifacts as the remote "best" folder.
        stage_root = Path(out_dir)
        tracker.delete_remote("stage2/best")
        tracker.upload_file(
            str(stage_root / "stage2_final_projection.pt"),
            "stage2/best/checkpoint_projection.pt",
        )
        tracker.upload_folder(
            str(stage_root / "stage2_final_lora"),
            "stage2/best/checkpoint_lora",
        )
        # Optional artifact, uploaded only when present.
        classifier_file = stage_root / "stage2_final_chexpert_classifier.pt"
        if classifier_file.exists():
            tracker.upload_file(
                str(classifier_file),
                "stage2/best/checkpoint_chexpert_classifier.pt",
            )
        tracker.write_meta({
            "dataset_name": spec.dataset_name,
            "stage2_done": True,
            "stage2_output_dir": out_dir,
            "stage2_epochs": train_cfg.stage2.num_epochs,
            "stage2_lr": train_cfg.stage2.learning_rate,
        })
    return model
|
|
|
|
def main():
    """Command-line entry point: resolve the run, then train stage 1 and/or 2.

    Flow:
      1. Load the model/train configs and resolve the dataset spec.
      2. Resolve the run id (fresh or resumed) and, if requested, pull or
         hydrate checkpoints from the HF Hub.
      3. With ``--mode resume`` and no ``--stage``, auto-detect the stage and
         checkpoint to continue from; exit early if everything is done.
      4. Persist the run config, set up tracking, build the model and run the
         selected stages.
    """
    args = parse_args()
    logger = setup_logger("cxr_vlm_train")

    # ── Configs ──────────────────────────────────────────────────────────
    model_cfg = OmegaConf.load(args.model_config)
    train_cfg = OmegaConf.load(args.train_config)

    logger.info(f"Model config: {args.model_config}")
    logger.info(f"Train config: {args.train_config}")

    # ── Dataset ──────────────────────────────────────────────────────────
    spec = resolve_dataset_spec(train_cfg)
    logger.info(f"Dataset: {spec.dataset_name}")
    logger.info(f" image_root = {spec.image_root}")
    logger.info(f" instruct_json = {spec.instruct_json}")
    logger.info(f" tasks = {spec.tasks}")
    logger.info(f" task_weights = {spec.task_weights}")

    # ── Run identity / Hub credentials ───────────────────────────────────
    output_root = str(train_cfg.training.get("output_root", "checkpoints"))
    state_file = str(train_cfg.hf_hub.run_state_file)
    hf_token = os.environ.get(
        train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
    ) if train_cfg.hf_hub.enabled else None
    hf_repo_id = train_cfg.hf_hub.repo_id if train_cfg.hf_hub.enabled else None

    # Stage folder names are needed in several places; resolve them once.
    stage1_subdir = str(train_cfg.stage1.get("subdir", "stage1_projection"))
    stage2_subdir = str(train_cfg.stage2.get("subdir", "stage2_instruct"))

    # --mode wins; without it, fall back to the legacy resume flags.
    if args.mode == "resume":
        resuming = True
    elif args.mode == "fresh":
        resuming = False
    else:
        resuming = bool(args.resume_from) or args.resume_from_hf

    run_id = resolve_run_id(
        dataset_name=spec.dataset_name,
        output_root=output_root,
        state_file=state_file,
        resuming=resuming,
        explicit=args.run_id,
        hf_repo_id=hf_repo_id,
        hf_token=hf_token,
    )
    logger.info(f"run_id = {run_id}")

    # ── Legacy: pull <stage>/last/ from the Hub and resume from it ───────
    if args.resume_from_hf:
        if not train_cfg.hf_hub.enabled or not train_cfg.hf_hub.repo_id:
            sys.exit("--resume_from_hf requires hf_hub.enabled=true and a repo_id")
        # Anything other than --stage 1 (including no --stage) targets stage 2.
        hub_stage_dir = "stage1" if args.stage == 1 else "stage2"
        local_resume = pull_last_for_resume(
            repo_id=train_cfg.hf_hub.repo_id,
            token=hf_token,  # hub is enabled here, so this is the env token
            run_id=run_id,
            stage_subdir=hub_stage_dir,
            local_root=str(Path(output_root) / "_resume_from_hf"),
        )
        if local_resume is None:
            sys.exit(f"Could not pull {run_id}/{hub_stage_dir}/last/ from HF — abort.")
        args.resume_from = local_resume
        # Report the stage actually pulled (previously printed "stageNone"
        # when --stage was omitted).
        logger.info(f"Will resume from pulled checkpoint: {local_resume} ({hub_stage_dir})")

    # ── --mode resume: make the local run folder match the Hub ───────────
    if args.mode == "resume" and hf_repo_id and hf_token:
        try:
            hydrate_run_dir_from_hf(
                repo_id=hf_repo_id,
                token=hf_token,
                run_id=run_id,
                output_root=output_root,
                stage1_subdir=stage1_subdir,
                stage2_subdir=stage2_subdir,
            )
        except Exception as e:
            # Best-effort: fall back to whatever is already on disk.
            logger.warning(f"[resume hydrate] {type(e).__name__}: {e}")

    stage1_out = stage_dir(output_root, run_id, stage1_subdir)
    stage2_out = stage_dir(output_root, run_id, stage2_subdir)

    # ── Auto-detect where to continue (only --mode resume, no --stage) ───
    auto_target_stage = None
    auto_resume_ckpt = None
    if args.mode == "resume" and args.stage is None:
        auto_target_stage, auto_resume_ckpt = detect_resume_point(
            run_dir(output_root, run_id),
            stage1_subdir,
            stage2_subdir,
        )
        logger.info(
            f"[resume autodetect] target={auto_target_stage} "
            f"ckpt={auto_resume_ckpt}"
        )

    # ── Startup banner ───────────────────────────────────────────────────
    plan = compute_training_plan(train_cfg, spec.instruct_json)
    logger.info("\n" + _fmt_plan_banner(plan, run_id,
                                        auto_target_stage or "stage1",
                                        auto_resume_ckpt))

    if auto_target_stage == "done":
        logger.info("Both stages already complete for this run. Exiting cleanly.")
        return

    # ── Persist the exact configuration of this run ──────────────────────
    save_run_config(
        run_dir_path=run_dir(output_root, run_id),
        spec=spec,
        model_cfg=model_cfg,
        train_cfg=train_cfg,
        extra={
            "stage_arg": args.stage,
            "resumed": bool(args.resume_from) or args.resume_from_hf,
            "resume_from": args.resume_from,
            "resume_from_hf": args.resume_from_hf,
            "model_config_path": args.model_config,
            "train_config_path": args.train_config,
        },
    )

    # ── Experiment tracking ──────────────────────────────────────────────
    if train_cfg.wandb.enabled:
        os.environ["WANDB_PROJECT"] = train_cfg.wandb.project

    tracker = build_tracker_from_cfg(
        train_cfg,
        resuming=bool(args.resume_from),
        explicit_run_id=run_id,
    )
    if tracker is not None:
        tracker.write_meta({
            "dataset_name": spec.dataset_name,
            "run_id": run_id,
            "config_model": args.model_config,
            "config_train": args.train_config,
            "resumed": bool(args.resume_from),
            "resume_from": args.resume_from,
        })
        # Mirror the saved run configuration next to the checkpoints.
        rd = run_dir(output_root, run_id)
        if (rd / "configs").is_dir():
            tracker.upload_folder(str(rd / "configs"), "configs")
        if (rd / "run_meta.json").is_file():
            tracker.upload_file(str(rd / "run_meta.json"), "run_meta.json")

    # ── Which stages run ─────────────────────────────────────────────────
    if args.stage is not None:
        run_s1 = (args.stage == 1) and train_cfg.stage1.enabled
        run_s2 = (args.stage == 2) and train_cfg.stage2.enabled
    elif auto_target_stage == "stage2":
        # Stage 1 is already finished for this run.
        run_s1 = False
        run_s2 = train_cfg.stage2.enabled
    else:
        run_s1 = train_cfg.stage1.enabled
        run_s2 = train_cfg.stage2.enabled

    # ── Which Trainer checkpoint (if any) each stage resumes from ────────
    s1_resume_path = None
    s2_resume_path = None
    if args.stage == 1:
        s1_resume_path = args.resume_from
    elif args.stage == 2:
        s2_resume_path = args.resume_from
    elif auto_target_stage == "stage1":
        s1_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None
    elif auto_target_stage == "stage2":
        s2_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None

    # ── Build the model ──────────────────────────────────────────────────
    # ITC stage 1 does not need the LLM, so a light model is built first.
    itc_on = _stage1_itc_cfg(train_cfg) is not None
    if itc_on and run_s1:
        logger.info("Building CXR VLM (light: ITC Stage-1, Vicuna NOT loaded)...")
        model = CXRVisionLanguageModel(model_cfg, load_llm=False, build_itc_head=True)
    else:
        logger.info("Building CXR VLM...")
        model = CXRVisionLanguageModel(model_cfg, load_llm=True, build_itc_head=False)
    if args.resume_from:
        logger.info(f"Resuming from: {args.resume_from}")
        load_checkpoint(model, args.resume_from)

    if run_s1:
        model = run_stage1(
            model, train_cfg, model_cfg, spec, stage1_out, logger,
            tracker=tracker,
            resume_from=s1_resume_path,
        )

    if run_s2:
        # A light (LLM-less) stage-1 model must be replaced by the full one.
        if getattr(model, "llm", None) is None:
            logger.info("Rebuilding full CXR VLM (with Vicuna) for Stage 2...")
            import gc
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            model = CXRVisionLanguageModel(model_cfg, load_llm=True, build_itc_head=False)
            if args.resume_from and not s2_resume_path:
                load_checkpoint(model, args.resume_from)

        # Stage-1 weights are addressed via the "stage1_final.pt" path that
        # load_checkpoint is already given elsewhere in this function. For the
        # existence test, also accept "stage1_final_projection.pt" — the file
        # detect_resume_point and the upload code rely on — so a fresh stage-2
        # start (e.g. after auto-resume decided "stage 1 done") actually picks
        # up the stage-1 weights instead of warning about a random projection.
        stage1_ckpt = Path(stage1_out) / "stage1_final.pt"
        stage1_marker = Path(stage1_out) / "stage1_final_projection.pt"
        stage1_available = stage1_ckpt.exists() or stage1_marker.exists()
        if run_s1:
            load_checkpoint(model, str(stage1_ckpt))
            logger.info(f"Loaded stage1 weights from this run: {stage1_ckpt}")
        elif stage1_available and not s2_resume_path:
            load_checkpoint(model, str(stage1_ckpt))
            logger.info(f"Auto-loaded existing stage1 weights: {stage1_ckpt}")
        elif not s2_resume_path:
            logger.warning(
                "⚠ No stage1 weights found and not resuming. Projection layer "
                "will start RANDOMLY for stage2. Expect degraded convergence. "
                f"Looked at: {stage1_ckpt}"
            )

        model = run_stage2(
            model, train_cfg, model_cfg, spec, stage2_out, logger,
            resume_from=s2_resume_path,
            tracker=tracker,
        )

    # ── Final uploads ────────────────────────────────────────────────────
    if tracker is not None:
        results_dir = Path("results") / run_id
        if results_dir.exists():
            tracker.upload_folder(str(results_dir), "results")
        tracker.write_meta({"training_complete": True})

    logger.info("Training complete!")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|