| """ |
| checkpoint.py |
| ------------- |
| Utilities for saving and loading model checkpoints. |
| Saves projection layer + LoRA adapter weights separately |
| (not the full frozen LLM), keeping checkpoint size small. |
| """ |
|
|
| import torch |
| from pathlib import Path |
| from typing import Optional |
|
|
|
|
def save_checkpoint(model, output_dir: str, name: str = "checkpoint"):
    """
    Save the trainable parts of ``model`` under ``output_dir``.

    Files written (optional components are skipped when the attribute is
    missing on the model or set to ``None``):
        - ``{name}_projection.pt``           projection state dict (always)
        - ``{name}_itc_head.pt``             ``model.itc_head`` state dict
        - ``{name}_lora/``                   output of ``model.llm.save_pretrained``
        - ``{name}_chexpert_classifier.pt``  ``model.chexpert_classifier`` state dict

    The frozen encoder is never written here. NOTE(review): the LoRA
    directory only stays small if ``model.llm`` is a PEFT-wrapped model whose
    ``save_pretrained`` emits adapter weights only — confirm against the
    model definition.

    Args:
        model:      CXRVisionLanguageModel
        output_dir: directory to save into (created if it does not exist)
        name:       checkpoint name, used as the file-name prefix
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # The projection layer is the one component every training stage has.
    proj_path = out_dir / f"{name}_projection.pt"
    torch.save(model.projection.state_dict(), proj_path)
    print(f"[Checkpoint] Projection saved → {proj_path}")

    # All remaining components are optional; probe them uniformly with
    # getattr so a model that lacks an attribute entirely is treated the
    # same as one where it is set to None.
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        itc_path = out_dir / f"{name}_itc_head.pt"
        torch.save(itc_head.state_dict(), itc_path)
        print(f"[Checkpoint] ITC head saved → {itc_path}")

    llm = getattr(model, "llm", None)
    if llm is not None:
        lora_dir = out_dir / f"{name}_lora"
        llm.save_pretrained(str(lora_dir))
        print(f"[Checkpoint] LoRA adapters saved → {lora_dir}")
    else:
        print("[Checkpoint] llm is None (ITC Stage-1) — LoRA save skipped.")

    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        clf_path = out_dir / f"{name}_chexpert_classifier.pt"
        torch.save(classifier.state_dict(), clf_path)
        print(f"[Checkpoint] CheXpert classifier saved → {clf_path}")
|
|
|
|
def _resolve_checkpoint_location(checkpoint_path):
    """Split a user-supplied path into ``(checkpoint_dir, checkpoint_name)``."""
    if checkpoint_path.suffix != ".pt":
        return checkpoint_path, "checkpoint"

    stem = checkpoint_path.stem
    # A real component file (e.g. ``foo_projection.pt``) carries the component
    # suffix in its stem; strip it so sibling files resolve as ``foo_*.pt``.
    # Only done for files that exist, so ``foo.pt`` used purely as a name
    # carrier keeps resolving to the name ``foo`` as before.
    if checkpoint_path.is_file():
        for suffix in ("_projection", "_itc_head", "_chexpert_classifier"):
            if stem.endswith(suffix):
                stem = stem[: -len(suffix)]
                break
    return checkpoint_path.parent, stem


def _load_module_weights(module, weights_path, label, strict):
    """Load ``weights_path`` into ``module``; return False if the file is absent."""
    if not weights_path.exists():
        return False
    # NOTE(security): torch.load unpickles the file. If checkpoints can come
    # from an untrusted source, pass ``weights_only=True`` (supported by
    # recent torch releases) — left unchanged here to stay compatible with
    # older torch versions.
    state = torch.load(weights_path, map_location="cpu")
    module.load_state_dict(state, strict=strict)
    print(f"[Checkpoint] {label} loaded ← {weights_path}")
    return True


def load_checkpoint(
    model,
    checkpoint_path: str,
    load_lora: bool = True,
    strict: bool = False,
):
    """
    Load trainable weights from a checkpoint written by ``save_checkpoint``.

    Components whose files are missing are skipped (best effort); only a
    half-written LoRA directory is treated as an error.

    Args:
        model:           CXRVisionLanguageModel
        checkpoint_path: path to the checkpoint dir (checkpoint name defaults
                         to ``"checkpoint"``), OR a ``.pt`` path. For a ``.pt``
                         path the parent is the checkpoint dir and the stem is
                         the checkpoint name; an existing component file such
                         as ``foo_projection.pt`` resolves to the name ``foo``.
        load_lora:       whether to load LoRA adapters into ``model.llm``
        strict:          forwarded to ``load_state_dict``

    Raises:
        FileNotFoundError: the ``{name}_lora`` directory exists but has no
            ``adapter_config.json`` (partially written or downloaded).
    """
    ckpt_dir, ckpt_name = _resolve_checkpoint_location(Path(checkpoint_path))

    # Projection: the only component whose absence is reported.
    proj_path = ckpt_dir / f"{ckpt_name}_projection.pt"
    if not _load_module_weights(model.projection, proj_path, "Projection", strict):
        print(f"[Checkpoint] No projection file found at {proj_path}, skipping.")

    # Optional heads are probed with getattr so a model that lacks the
    # attribute entirely behaves like one where it is None.
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        _load_module_weights(
            itc_head, ckpt_dir / f"{ckpt_name}_itc_head.pt", "ITC head", strict
        )

    llm = getattr(model, "llm", None)
    if load_lora and llm is not None:
        lora_dir = ckpt_dir / f"{ckpt_name}_lora"
        if not lora_dir.exists():
            print(f"[Checkpoint] No LoRA dir found at {lora_dir}, skipping.")
        else:
            # Fail loudly on a half-written adapter directory instead of
            # letting PEFT raise a less actionable error further down.
            if lora_dir.is_dir() and not (lora_dir / "adapter_config.json").is_file():
                raise FileNotFoundError(
                    f"[load_checkpoint] {lora_dir} exists but adapter_config.json "
                    f"is missing — checkpoint is partially-written/downloaded. "
                    f"Fix: delete the parent checkpoint folder "
                    f"({lora_dir.parent}) and rerun with --mode resume so it "
                    f"gets re-pulled from HF Hub, OR rm -rf the stage2_instruct "
                    f"folder to train Stage 2 fresh from stage1_final."
                )

            # Imported lazily so peft is only required when adapters are loaded.
            from peft import PeftModel

            # When the LLM is already PEFT-wrapped, hand from_pretrained the
            # inner model so the wrapper is not nested; a plain (unwrapped)
            # model is used directly as the base.
            base_llm = llm.base_model.model if isinstance(llm, PeftModel) else llm
            model.llm = PeftModel.from_pretrained(
                base_llm,
                str(lora_dir),
                is_trainable=True,
            )
            print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")

    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        _load_module_weights(
            classifier,
            ckpt_dir / f"{ckpt_name}_chexpert_classifier.pt",
            "CheXpert classifier",
            strict,
        )
|
|