""" checkpoint.py ------------- Utilities for saving and loading model checkpoints. Saves projection layer + LoRA adapter weights separately (not the full frozen LLM), keeping checkpoint size small. """ import torch from pathlib import Path from typing import Optional def save_checkpoint(model, output_dir: str, name: str = "checkpoint"): """ Save trainable model weights: - projection layer weights - LoRA adapter weights (via PEFT save) Does NOT save frozen encoder or full LLM weights. Args: model: CXRVisionLanguageModel output_dir: directory to save into name: checkpoint name """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Save projection layer proj_path = output_dir / f"{name}_projection.pt" torch.save(model.projection.state_dict(), proj_path) print(f"[Checkpoint] Projection saved → {proj_path}") # Save ITC head if present (Stage-1 ITC mode). Tiny; only exists when # the model was built with build_itc_head=True. if getattr(model, "itc_head", None) is not None: itc_path = output_dir / f"{name}_itc_head.pt" torch.save(model.itc_head.state_dict(), itc_path) print(f"[Checkpoint] ITC head saved → {itc_path}") # Save LoRA adapters via PEFT — skipped in ITC Stage-1 (llm not loaded). if getattr(model, "llm", None) is not None: lora_dir = output_dir / f"{name}_lora" model.llm.save_pretrained(str(lora_dir)) print(f"[Checkpoint] LoRA adapters saved → {lora_dir}") else: print("[Checkpoint] llm is None (ITC Stage-1) — LoRA save skipped.") # Save CheXpert classifier if it exists and was trained if model.chexpert_classifier is not None: clf_path = output_dir / f"{name}_chexpert_classifier.pt" torch.save(model.chexpert_classifier.state_dict(), clf_path) print(f"[Checkpoint] CheXpert classifier saved → {clf_path}") def load_checkpoint( model, checkpoint_path: str, load_lora: bool = True, strict: bool = False, ): """ Load trainable weights from checkpoint. Args: model: CXRVisionLanguageModel checkpoint_path: path to checkpoint dir OR .pt file load_lora: whether to load LoRA adapters strict: strict state dict loading """ checkpoint_path = Path(checkpoint_path) # If given a .pt file, treat its parent as the checkpoint dir if checkpoint_path.suffix == ".pt": ckpt_dir = checkpoint_path.parent ckpt_name = checkpoint_path.stem else: ckpt_dir = checkpoint_path ckpt_name = "checkpoint" # Load projection proj_path = ckpt_dir / f"{ckpt_name}_projection.pt" if proj_path.exists(): state = torch.load(proj_path, map_location="cpu") model.projection.load_state_dict(state, strict=strict) print(f"[Checkpoint] Projection loaded ← {proj_path}") else: print(f"[Checkpoint] No projection file found at {proj_path}, skipping.") # Load ITC head (Stage-1 ITC mode only; present iff model.itc_head built) if getattr(model, "itc_head", None) is not None: itc_path = ckpt_dir / f"{ckpt_name}_itc_head.pt" if itc_path.exists(): state = torch.load(itc_path, map_location="cpu") model.itc_head.load_state_dict(state, strict=strict) print(f"[Checkpoint] ITC head loaded ← {itc_path}") # Load LoRA — skipped when llm not loaded (ITC Stage-1) or no dir present. if load_lora and getattr(model, "llm", None) is not None: lora_dir = ckpt_dir / f"{ckpt_name}_lora" # Defensive: PEFT raises an opaque HFValidationError when the dir # exists but `adapter_config.json` is missing (a partially-written # or partially-downloaded checkpoint). Surface a clearer message so # the user knows the fix: delete the dir and resume from HF Hub. if lora_dir.is_dir() and not (lora_dir / "adapter_config.json").is_file(): raise FileNotFoundError( f"[load_checkpoint] {lora_dir} exists but adapter_config.json " f"is missing — checkpoint is partially-written/downloaded. " f"Fix: delete the parent checkpoint folder " f"({lora_dir.parent}) and rerun with --mode resume so it " f"gets re-pulled from HF Hub, OR rm -rf the stage2_instruct " f"folder to train Stage 2 fresh from stage1_final." ) if lora_dir.exists(): from peft import PeftModel # is_trainable=True is REQUIRED on resume: PEFT defaults to # inference mode (requires_grad=False on all LoRA params), which # would shrink the trainable set to projection-only (~5 tensors) # and break optimizer state loading with a param-group size # mismatch against the saved 261-tensor stage-2 optimizer. model.llm = PeftModel.from_pretrained( model.llm.base_model.model, str(lora_dir), is_trainable=True, ) print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}") else: print(f"[Checkpoint] No LoRA dir found at {lora_dir}, skipping.") # Load CheXpert classifier if model.chexpert_classifier is not None: clf_path = ckpt_dir / f"{ckpt_name}_chexpert_classifier.pt" if clf_path.exists(): state = torch.load(clf_path, map_location="cpu") model.chexpert_classifier.load_state_dict(state, strict=strict) print(f"[Checkpoint] CheXpert classifier loaded ← {clf_path}")