File size: 5,735 Bytes
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8356dae
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8356dae
 
 
 
 
 
 
 
 
 
28b13fc
320063f
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
6a13626
 
 
 
 
28b13fc
 
 
6a13626
28b13fc
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
checkpoint.py
-------------
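Utilities for saving and loading model checkpoints.
Saves only the trainable components — projection layer, LoRA adapter
weights, and (when present) the ITC head and CheXpert classifier — as
separate files instead of the full frozen LLM, keeping checkpoint size small.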
"""

import torch
from pathlib import Path
from typing import Optional


def save_checkpoint(model, output_dir: str, name: str = "checkpoint"):
    """
    Save trainable model weights:
      - projection layer weights          → <name>_projection.pt
      - ITC head weights, if present      → <name>_itc_head.pt
      - LoRA adapter weights (PEFT save)  → <name>_lora/   (only if model.llm)
      - CheXpert classifier, if present   → <name>_chexpert_classifier.pt

    Does NOT save frozen encoder or full LLM weights.

    Args:
        model:      CXRVisionLanguageModel. Must have `projection`; `itc_head`,
                    `llm` and `chexpert_classifier` are optional and skipped
                    when missing or None.
        output_dir: directory to save into (created, with parents, if missing)
        name:       checkpoint name, used as the prefix of every saved file
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save projection layer
    proj_path = output_dir / f"{name}_projection.pt"
    torch.save(model.projection.state_dict(), proj_path)
    print(f"[Checkpoint] Projection saved → {proj_path}")

    # Save ITC head if present (Stage-1 ITC mode). Tiny; only exists when
    # the model was built with build_itc_head=True.
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        itc_path = output_dir / f"{name}_itc_head.pt"
        torch.save(itc_head.state_dict(), itc_path)
        print(f"[Checkpoint] ITC head saved → {itc_path}")

    # Save LoRA adapters via PEFT — skipped in ITC Stage-1 (llm not loaded).
    llm = getattr(model, "llm", None)
    if llm is not None:
        lora_dir = output_dir / f"{name}_lora"
        llm.save_pretrained(str(lora_dir))
        print(f"[Checkpoint] LoRA adapters saved → {lora_dir}")
    else:
        print("[Checkpoint] llm is None (ITC Stage-1) — LoRA save skipped.")

    # Save CheXpert classifier if the model has one. getattr keeps this
    # consistent with the optional itc_head / llm checks above, so models
    # built without the attribute do not raise AttributeError.
    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        clf_path = output_dir / f"{name}_chexpert_classifier.pt"
        torch.save(classifier.state_dict(), clf_path)
        print(f"[Checkpoint] CheXpert classifier saved → {clf_path}")


# Suffixes save_checkpoint appends to the checkpoint name for its
# single-file components (e.g. "<name>_projection.pt").
_COMPONENT_SUFFIXES = ("_projection", "_itc_head", "_chexpert_classifier")


def _resolve_checkpoint(checkpoint_path):
    """Split a checkpoint dir or .pt path into (checkpoint_dir, checkpoint_name)."""
    if checkpoint_path.suffix != ".pt":
        # A directory: components use the default name "checkpoint".
        return checkpoint_path, "checkpoint"

    ckpt_dir = checkpoint_path.parent
    name = checkpoint_path.stem
    # A component file written by save_checkpoint may have been passed
    # directly (e.g. "<name>_projection.pt"). Recover "<name>" in that case.
    # Strip only when the literal stem does not already resolve, so
    # "<name>.pt" paths and names that genuinely end in a suffix keep
    # resolving as before.
    if not (ckpt_dir / f"{name}_projection.pt").exists():
        for suffix in _COMPONENT_SUFFIXES:
            if name.endswith(suffix) and len(name) > len(suffix):
                return ckpt_dir, name[: -len(suffix)]
    return ckpt_dir, name


def _load_module_state(module, path, strict, label):
    """Load a torch-saved state dict from `path` into `module` and report key mismatches."""
    # NOTE(security): torch.load unpickles the file. Checkpoints may be pulled
    # from a remote hub, so consider weights_only=True once every saved state
    # dict is confirmed to contain only tensors.
    state = torch.load(path, map_location="cpu")
    result = module.load_state_dict(state, strict=strict)
    print(f"[Checkpoint] {label} loaded ← {path}")
    # With strict=False, load_state_dict ignores mismatched keys instead of
    # raising. Surface them so a partial match is not mistaken for a full restore.
    missing = list(getattr(result, "missing_keys", []) or [])
    unexpected = list(getattr(result, "unexpected_keys", []) or [])
    if missing or unexpected:
        print(
            f"[Checkpoint] {label}: missing keys {missing}, "
            f"unexpected keys {unexpected}."
        )


def load_checkpoint(
    model,
    checkpoint_path: str,
    load_lora: bool = True,
    strict: bool = False,
):
    """
    Load trainable weights from checkpoint.

    Mirrors save_checkpoint. Each component is looked up as
    ``<ckpt_dir>/<name>_<component>``. A missing file is skipped with a message.

    Args:
        model:            CXRVisionLanguageModel. Must have `projection`;
                          `itc_head`, `llm` and `chexpert_classifier` are
                          optional and skipped when missing or None.
        checkpoint_path:  path to checkpoint dir OR .pt file.
                          - A directory uses the default name "checkpoint".
                          - A .pt path may be "<dir>/<name>.pt" or any component
                            file written by save_checkpoint, e.g.
                            "<dir>/<name>_projection.pt". Both resolve to <name>.
        load_lora:        whether to load LoRA adapters
        strict:           strict state dict loading

    Raises:
        FileNotFoundError: the LoRA dir exists but adapter_config.json is missing.
    """
    ckpt_dir, ckpt_name = _resolve_checkpoint(Path(checkpoint_path))

    # Load projection
    proj_path = ckpt_dir / f"{ckpt_name}_projection.pt"
    if proj_path.exists():
        _load_module_state(model.projection, proj_path, strict, "Projection")
    else:
        print(f"[Checkpoint] No projection file found at {proj_path}, skipping.")

    # Load ITC head (Stage-1 ITC mode only; present iff model.itc_head built)
    itc_head = getattr(model, "itc_head", None)
    if itc_head is not None:
        itc_path = ckpt_dir / f"{ckpt_name}_itc_head.pt"
        if itc_path.exists():
            _load_module_state(itc_head, itc_path, strict, "ITC head")
        else:
            print(f"[Checkpoint] No ITC head file found at {itc_path}, skipping.")

    # Load LoRA — skipped when llm not loaded (ITC Stage-1) or no dir present.
    if load_lora and getattr(model, "llm", None) is not None:
        lora_dir = ckpt_dir / f"{ckpt_name}_lora"
        # Defensive: PEFT raises an opaque HFValidationError when the dir
        # exists but `adapter_config.json` is missing (a partially-written
        # or partially-downloaded checkpoint). Surface a clearer message so
        # the user knows the fix: delete the dir and resume from HF Hub.
        if lora_dir.is_dir() and not (lora_dir / "adapter_config.json").is_file():
            raise FileNotFoundError(
                f"[load_checkpoint] {lora_dir} exists but adapter_config.json "
                f"is missing — checkpoint is partially-written/downloaded. "
                f"Fix: delete the parent checkpoint folder "
                f"({lora_dir.parent}) and rerun with --mode resume so it "
                f"gets re-pulled from HF Hub, OR rm -rf the stage2_instruct "
                f"folder to train Stage 2 fresh from stage1_final."
            )
        if lora_dir.exists():
            from peft import PeftModel
            # is_trainable=True is REQUIRED on resume: PEFT defaults to
            # inference mode (requires_grad=False on all LoRA params), which
            # would shrink the trainable set to projection-only (~5 tensors)
            # and break optimizer state loading with a param-group size
            # mismatch against the saved 261-tensor stage-2 optimizer.
            # NOTE(review): `.base_model.model` presumably unwraps an existing
            # PeftModel to its underlying HF model; it assumes model.llm is
            # already PEFT-wrapped — confirm against the model constructor.
            model.llm = PeftModel.from_pretrained(
                model.llm.base_model.model,
                str(lora_dir),
                is_trainable=True,
            )
            print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")
        else:
            print(f"[Checkpoint] No LoRA dir found at {lora_dir}, skipping.")

    # Load CheXpert classifier (getattr: optional, like itc_head / llm above)
    classifier = getattr(model, "chexpert_classifier", None)
    if classifier is not None:
        clf_path = ckpt_dir / f"{ckpt_name}_chexpert_classifier.pt"
        if clf_path.exists():
            _load_module_state(classifier, clf_path, strict, "CheXpert classifier")
        else:
            print(
                f"[Checkpoint] No CheXpert classifier file found at {clf_path}, "
                f"skipping."
            )