"""
checkpoint.py
-------------
Utilities for saving and loading model checkpoints.
Saves projection layer + LoRA adapter weights separately
(not the full frozen LLM), keeping checkpoint size small.
"""
import torch
from pathlib import Path
from typing import Optional
def save_checkpoint(model, output_dir: str, name: str = "checkpoint"):
    """
    Save the trainable model weights into ``output_dir``.

    Files written (optional ones only when the component is present on
    ``model`` and is not None):
      - ``{name}_projection.pt``           projection state dict (always)
      - ``{name}_itc_head.pt``             ITC head state dict (optional)
      - ``{name}_lora/``                   ``model.llm.save_pretrained`` output
                                           (optional, LoRA adapters via PEFT)
      - ``{name}_chexpert_classifier.pt``  classifier state dict (optional)

    Does NOT save frozen encoder or full LLM weights.

    Args:
        model: CXRVisionLanguageModel
        output_dir: directory to save into (created, with parents, if missing)
        name: checkpoint name, used as the prefix of every file written
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save projection layer
    proj_path = output_dir / f"{name}_projection.pt"
    torch.save(model.projection.state_dict(), proj_path)
    print(f"[Checkpoint] Projection saved → {proj_path}")

    # Save ITC head if present (Stage-1 ITC mode). Tiny; only exists when
    # the model was built with build_itc_head=True.
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        itc_path = output_dir / f"{name}_itc_head.pt"
        torch.save(itc_head.state_dict(), itc_path)
        print(f"[Checkpoint] ITC head saved → {itc_path}")

    # Save LoRA adapters via PEFT — skipped in ITC Stage-1 (llm not loaded).
    llm = getattr(model, "llm", None)
    if llm is not None:
        lora_dir = output_dir / f"{name}_lora"
        llm.save_pretrained(str(lora_dir))
        print(f"[Checkpoint] LoRA adapters saved → {lora_dir}")
    else:
        print("[Checkpoint] llm is None (ITC Stage-1) — LoRA save skipped.")

    # Save CheXpert classifier if it exists and was trained.
    # Probed with getattr, like itc_head/llm above, so a model built without
    # this attribute does not raise AttributeError after the other files
    # have already been written.
    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        clf_path = output_dir / f"{name}_chexpert_classifier.pt"
        torch.save(classifier.state_dict(), clf_path)
        print(f"[Checkpoint] CheXpert classifier saved → {clf_path}")
def load_checkpoint(
    model,
    checkpoint_path: str,
    load_lora: bool = True,
    strict: bool = False,
):
    """
    Load trainable weights from a checkpoint into ``model`` (in place).

    ``checkpoint_path`` may be:
      - a directory: files are looked up as ``checkpoint_<component>``;
      - ``<dir>/<name>.pt``: ``<dir>`` is the checkpoint dir and ``<name>``
        the file-name prefix (the ``.pt`` path itself need not exist);
      - an existing component file such as ``<dir>/<name>_projection.pt``:
        the component suffix is stripped so all sibling components of
        ``<name>`` are found.

    Components that are missing on disk are skipped; components that are
    absent (or None) on ``model`` are not looked up.

    Args:
        model: CXRVisionLanguageModel
        checkpoint_path: path to checkpoint dir OR .pt file
        load_lora: whether to load LoRA adapters
        strict: strict state dict loading

    Raises:
        FileNotFoundError: the LoRA directory exists but contains no
            ``adapter_config.json`` (partially written/downloaded checkpoint).
    """
    checkpoint_path = Path(checkpoint_path)

    # If given a .pt file, treat its parent as the checkpoint dir
    if checkpoint_path.suffix == ".pt":
        ckpt_dir = checkpoint_path.parent
        ckpt_name = checkpoint_path.stem
        # A real component file carries a "_<component>" suffix in its stem;
        # without stripping it we would look for e.g.
        # "checkpoint_projection_projection.pt" and silently load nothing.
        # Only strip when the file actually exists, so the "<name>.pt"
        # prefix form keeps working for names that happen to end the same way.
        if checkpoint_path.is_file():
            for component in ("_projection", "_itc_head", "_chexpert_classifier"):
                if ckpt_name.endswith(component):
                    ckpt_name = ckpt_name[: -len(component)]
                    break
    else:
        ckpt_dir = checkpoint_path
        ckpt_name = "checkpoint"

    # NOTE(review): torch.load unpickles the file. If checkpoints can come
    # from an untrusted source, consider passing weights_only=True (supported
    # by recent PyTorch releases) to the torch.load calls below.

    # Load projection
    proj_path = ckpt_dir / f"{ckpt_name}_projection.pt"
    if proj_path.exists():
        state = torch.load(proj_path, map_location="cpu")
        model.projection.load_state_dict(state, strict=strict)
        print(f"[Checkpoint] Projection loaded ← {proj_path}")
    else:
        print(f"[Checkpoint] No projection file found at {proj_path}, skipping.")

    # Load ITC head (Stage-1 ITC mode only; present iff model.itc_head built)
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        itc_path = ckpt_dir / f"{ckpt_name}_itc_head.pt"
        if itc_path.exists():
            state = torch.load(itc_path, map_location="cpu")
            itc_head.load_state_dict(state, strict=strict)
            print(f"[Checkpoint] ITC head loaded ← {itc_path}")

    # Load LoRA — skipped when llm not loaded (ITC Stage-1) or no dir present.
    if load_lora and getattr(model, "llm", None) is not None:
        lora_dir = ckpt_dir / f"{ckpt_name}_lora"
        # Defensive: PEFT raises an opaque HFValidationError when the dir
        # exists but `adapter_config.json` is missing (a partially-written
        # or partially-downloaded checkpoint). Surface a clearer message so
        # the user knows the fix: delete the dir and resume from HF Hub.
        if lora_dir.is_dir() and not (lora_dir / "adapter_config.json").is_file():
            raise FileNotFoundError(
                f"[load_checkpoint] {lora_dir} exists but adapter_config.json "
                f"is missing — checkpoint is partially-written/downloaded. "
                f"Fix: delete the parent checkpoint folder "
                f"({lora_dir.parent}) and rerun with --mode resume so it "
                f"gets re-pulled from HF Hub, OR rm -rf the stage2_instruct "
                f"folder to train Stage 2 fresh from stage1_final."
            )
        if lora_dir.exists():
            # Imported lazily so PEFT is only required when adapters are loaded.
            from peft import PeftModel

            # is_trainable=True is REQUIRED on resume: PEFT defaults to
            # inference mode (requires_grad=False on all LoRA params), which
            # would shrink the trainable set to projection-only (~5 tensors)
            # and break optimizer state loading with a param-group size
            # mismatch against the saved 261-tensor stage-2 optimizer.
            #
            # NOTE(review): `model.llm.base_model.model` presumably unwraps an
            # already PEFT-wrapped llm back to the raw base model — confirm
            # that model.llm is always a PeftModel at this point.
            model.llm = PeftModel.from_pretrained(
                model.llm.base_model.model,
                str(lora_dir),
                is_trainable=True,
            )
            print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")
        else:
            print(f"[Checkpoint] No LoRA dir found at {lora_dir}, skipping.")

    # Load CheXpert classifier. Probed with getattr, like itc_head/llm above,
    # so a model built without this attribute does not raise AttributeError.
    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        clf_path = ckpt_dir / f"{ckpt_name}_chexpert_classifier.pt"
        if clf_path.exists():
            state = torch.load(clf_path, map_location="cpu")
            classifier.load_state_dict(state, strict=strict)
            print(f"[Checkpoint] CheXpert classifier loaded ← {clf_path}")
|