# cxr-vlm-code / utils / checkpoint.py
# convitom
# f
# 6a13626
"""
checkpoint.py
-------------
Utilities for saving and loading model checkpoints.
Saves only the trainable components — projection layer, LoRA adapters and,
when present, the ITC head and CheXpert classifier — as separate files
(not the full frozen LLM), keeping checkpoint size small.
"""
import torch
from pathlib import Path
from typing import Optional
def save_checkpoint(model, output_dir: str, name: str = "checkpoint"):
    """
    Save trainable model weights under ``output_dir``.

    Files written (each only when the corresponding component exists):
        - ``{name}_projection.pt``            projection layer state dict
        - ``{name}_itc_head.pt``              ITC head state dict
        - ``{name}_lora/``                    LoRA adapter weights (via PEFT save)
        - ``{name}_chexpert_classifier.pt``   CheXpert classifier state dict

    Does NOT save frozen encoder or full LLM weights.

    Args:
        model:      CXRVisionLanguageModel
        output_dir: directory to save into (created if missing)
        name:       checkpoint name, used as the file-name prefix
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save projection layer (the only component that is always required).
    proj_path = output_dir / f"{name}_projection.pt"
    torch.save(model.projection.state_dict(), proj_path)
    print(f"[Checkpoint] Projection saved → {proj_path}")

    # Save ITC head if present (Stage-1 ITC mode). Tiny; only exists when
    # the model was built with build_itc_head=True.
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        itc_path = output_dir / f"{name}_itc_head.pt"
        torch.save(itc_head.state_dict(), itc_path)
        print(f"[Checkpoint] ITC head saved → {itc_path}")

    # Save LoRA adapters via PEFT — skipped in ITC Stage-1 (llm not loaded).
    llm = getattr(model, "llm", None)
    if llm is not None:
        lora_dir = output_dir / f"{name}_lora"
        llm.save_pretrained(str(lora_dir))
        print(f"[Checkpoint] LoRA adapters saved → {lora_dir}")
    else:
        print("[Checkpoint] llm is None (ITC Stage-1) — LoRA save skipped.")

    # Save CheXpert classifier if it exists and was trained. Looked up with
    # getattr like the other optional components, so a model that never
    # defines the attribute does not fail after the files above were written.
    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        clf_path = output_dir / f"{name}_chexpert_classifier.pt"
        torch.save(classifier.state_dict(), clf_path)
        print(f"[Checkpoint] CheXpert classifier saved → {clf_path}")
def _resolve_checkpoint_location(checkpoint_path: Path):
    """Return ``(ckpt_dir, ckpt_name)`` for a checkpoint dir or a ``.pt`` path."""
    if checkpoint_path.suffix != ".pt":
        return checkpoint_path, "checkpoint"

    ckpt_name = checkpoint_path.stem
    # The component files are named "<name>_<component>.pt", so a path that
    # points at one of them carries a component suffix in its stem. Strip it;
    # otherwise the lookups below would search for e.g.
    # "<name>_projection_projection.pt" and silently load nothing. Only done
    # for an existing file so the "dir/<name>.pt" naming convention (no such
    # file on disk) keeps resolving exactly as before.
    if checkpoint_path.is_file():
        for suffix in ("_projection", "_itc_head", "_chexpert_classifier"):
            if ckpt_name.endswith(suffix) and len(ckpt_name) > len(suffix):
                ckpt_name = ckpt_name[: -len(suffix)]
                break
    return checkpoint_path.parent, ckpt_name


def _load_state_dict_file(module, path: Path, strict: bool, label: str) -> bool:
    """Load ``path`` into ``module`` if the file exists; return whether it did."""
    if not path.exists():
        return False
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered or untrusted checkpoint file cannot execute code on load.
    state = torch.load(path, map_location="cpu", weights_only=True)
    module.load_state_dict(state, strict=strict)
    print(f"[Checkpoint] {label} loaded ← {path}")
    return True


def load_checkpoint(
    model,
    checkpoint_path: str,
    load_lora: bool = True,
    strict: bool = False,
):
    """
    Load trainable weights from checkpoint.

    Each component is loaded only when the model has it and its file or
    directory exists; a missing file is reported and skipped.

    Args:
        model:            CXRVisionLanguageModel
        checkpoint_path:  path to checkpoint dir OR .pt file. For a dir the
                          checkpoint name is "checkpoint". For a .pt path the
                          parent is the dir and the stem is the name; when the
                          path is an existing component file (e.g.
                          "<name>_projection.pt") the component suffix is
                          stripped from the name.
        load_lora:        whether to load LoRA adapters
        strict:           strict state dict loading

    Raises:
        FileNotFoundError: the LoRA directory exists but has no
            adapter_config.json (partially written/downloaded checkpoint).
    """
    ckpt_dir, ckpt_name = _resolve_checkpoint_location(Path(checkpoint_path))

    # Load projection
    proj_path = ckpt_dir / f"{ckpt_name}_projection.pt"
    if not _load_state_dict_file(model.projection, proj_path, strict, "Projection"):
        print(f"[Checkpoint] No projection file found at {proj_path}, skipping.")

    # Load ITC head (Stage-1 ITC mode only; present iff model.itc_head built)
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        itc_path = ckpt_dir / f"{ckpt_name}_itc_head.pt"
        if not _load_state_dict_file(itc_head, itc_path, strict, "ITC head"):
            print(f"[Checkpoint] No ITC head file found at {itc_path}, skipping.")

    # Load LoRA — skipped when llm not loaded (ITC Stage-1) or no dir present.
    if load_lora and getattr(model, "llm", None) is not None:
        lora_dir = ckpt_dir / f"{ckpt_name}_lora"
        # Defensive: PEFT raises an opaque HFValidationError when the dir
        # exists but `adapter_config.json` is missing (a partially-written
        # or partially-downloaded checkpoint). Surface a clearer message so
        # the user knows the fix: delete the dir and resume from HF Hub.
        if lora_dir.is_dir() and not (lora_dir / "adapter_config.json").is_file():
            raise FileNotFoundError(
                f"[load_checkpoint] {lora_dir} exists but adapter_config.json "
                f"is missing — checkpoint is partially-written/downloaded. "
                f"Fix: delete the parent checkpoint folder "
                f"({lora_dir.parent}) and rerun with --mode resume so it "
                f"gets re-pulled from HF Hub, OR rm -rf the stage2_instruct "
                f"folder to train Stage 2 fresh from stage1_final."
            )
        if lora_dir.exists():
            # Imported lazily so PEFT is only needed when adapters are loaded.
            from peft import PeftModel

            # is_trainable=True is REQUIRED on resume: PEFT defaults to
            # inference mode (requires_grad=False on all LoRA params), which
            # would shrink the trainable set to projection-only (~5 tensors)
            # and break optimizer state loading with a param-group size
            # mismatch against the saved 261-tensor stage-2 optimizer.
            model.llm = PeftModel.from_pretrained(
                model.llm.base_model.model,
                str(lora_dir),
                is_trainable=True,
            )
            print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")
        else:
            print(f"[Checkpoint] No LoRA dir found at {lora_dir}, skipping.")

    # Load CheXpert classifier. Looked up with getattr like the other optional
    # components, so a model without the attribute is simply skipped.
    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        clf_path = ckpt_dir / f"{ckpt_name}_chexpert_classifier.pt"
        if not _load_state_dict_file(
            classifier, clf_path, strict, "CheXpert classifier"
        ):
            print(
                f"[Checkpoint] No CheXpert classifier file found at "
                f"{clf_path}, skipping."
            )