convitom commited on
Commit ·
320063f
1
Parent(s): 0b18c4a
- docker/Dockerfile +44 -0
- docker/requirements_docker.txt +38 -0
- evaluation/evaluate.py +1 -1
- requirements.txt +40 -19
- scripts/cxrvlm_colab_train.ipynb +2 -59
- training/train.py +4 -2
- utils/checkpoint.py +13 -0
docker/Dockerfile
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CXR-VLM training environment.
|
| 2 |
+
# Matches the Colab GPU runtime fingerprint verified on 2026-05:
|
| 3 |
+
# Python 3.12.13
|
| 4 |
+
# torch 2.10.0+cu128
|
| 5 |
+
# CUDA 12.8 (nvcc 12.8.93)
|
| 6 |
+
# cuDNN 9.10.2
|
| 7 |
+
# glibc 2.35 (Ubuntu 22.04 base)
|
| 8 |
+
# bnb 0.49.2 (4-bit quantize verified)
|
| 9 |
+
#
|
| 10 |
+
# ─── Host requirements ────────────────────────────────────────────────────────
|
| 11 |
+
# This image requires NVIDIA driver >= 550.54 on the host (CUDA 12.8 runtime).
|
| 12 |
+
# • Colab — driver 580+, OK
|
| 13 |
+
# • Vast.ai — filter "CUDA Driver >= 550" when picking an instance
|
| 14 |
+
# • Lightning AI — A10G / A100 / H100 OK; check older T4 Studios
|
| 15 |
+
# • RunPod — pick a 12.8-compatible template or BYO image
|
| 16 |
+
#
|
| 17 |
+
# T4 (sm_75) note: torch.cuda.is_bf16_supported() returns True via emulation,
|
| 18 |
+
# but T4 has no hardware BF16. Keep train_cfg.training.fp16=True / bf16=False
|
| 19 |
+
# on T4. On A100/L4/H100 (sm_80+) you can flip to bf16.
|
| 20 |
+
#
|
| 21 |
+
# ─── Build & push ─────────────────────────────────────────────────────────────
|
| 22 |
+
# docker build -t <hub>/cxr-vlm-env:cu128 docker/
|
| 23 |
+
# docker push <hub>/cxr-vlm-env:cu128
|
| 24 |
+
#
|
| 25 |
+
# ─── Base image fallbacks (if 2.10.0-cuda12.8 tag is missing on Docker Hub) ──
|
| 26 |
+
# FROM nvcr.io/nvidia/pytorch:25.04-py3 # NVIDIA NGC, always cu128
|
| 27 |
+
# FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel # stable, slightly older
|
| 28 |
+
FROM pytorch/pytorch:2.10.0-cuda12.8-cudnn9-devel
|
| 29 |
+
|
| 30 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 31 |
+
PYTHONUNBUFFERED=1 \
|
| 32 |
+
BITSANDBYTES_NOWELCOME=1 \
|
| 33 |
+
TOKENIZERS_PARALLELISM=false \
|
| 34 |
+
HF_HUB_DISABLE_PROGRESS_BARS=1 \
|
| 35 |
+
TRANSFORMERS_VERBOSITY=warning
|
| 36 |
+
|
| 37 |
+
RUN apt-get update && \
|
| 38 |
+
apt-get install -y --no-install-recommends git wget curl && \
|
| 39 |
+
rm -rf /var/lib/apt/lists/*
|
| 40 |
+
|
| 41 |
+
COPY requirements_docker.txt /tmp/requirements_docker.txt
|
| 42 |
+
RUN pip install --no-cache-dir -r /tmp/requirements_docker.txt
|
| 43 |
+
|
| 44 |
+
WORKDIR /workspace
|
docker/requirements_docker.txt
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Versions match Colab GPU runtime (verified 2026-05). torch + torchvision
|
| 2 |
+
# are NOT listed here — they come from the pytorch/pytorch base image with
|
| 3 |
+
# the right CUDA build baked in.
|
| 4 |
+
|
| 5 |
+
# ── Core HF stack ─────────────────────────────────────────────────────────────
|
| 6 |
+
transformers==4.49.0
|
| 7 |
+
peft==0.14.0
|
| 8 |
+
accelerate==1.13.0
|
| 9 |
+
bitsandbytes==0.49.2
|
| 10 |
+
huggingface_hub==1.11.0
|
| 11 |
+
httpx==0.28.1 # utils/_httpx_compat.py handles the allow_redirects removal
|
| 12 |
+
|
| 13 |
+
# ── Vision encoder ────────────────────────────────────────────────────────────
|
| 14 |
+
# rad_dino loads via transformers AutoModel — no extra dep needed.
|
| 15 |
+
# hi-ml-multimodal (BioViL-T) intentionally omitted; model/rad_dino.py wraps
|
| 16 |
+
# its import in try/except and falls back to timm/transformers cleanly.
|
| 17 |
+
timm==1.0.26
|
| 18 |
+
Pillow==11.3.0
|
| 19 |
+
|
| 20 |
+
# ── Config / data ─────────────────────────────────────────────────────────────
|
| 21 |
+
omegaconf==2.3.0
|
| 22 |
+
sentencepiece==0.2.1
|
| 23 |
+
protobuf==5.29.6
|
| 24 |
+
numpy==2.0.2
|
| 25 |
+
pandas==2.2.2
|
| 26 |
+
|
| 27 |
+
# ── Eval metrics ──────────────────────────────────────────────────────────────
|
| 28 |
+
nltk==3.9.1
|
| 29 |
+
rouge-score==0.1.2
|
| 30 |
+
bert-score==0.3.12
|
| 31 |
+
scikit-learn==1.6.1
|
| 32 |
+
|
| 33 |
+
# ── Training / experiment tracking ────────────────────────────────────────────
|
| 34 |
+
wandb==0.26.1
|
| 35 |
+
tqdm==4.67.3
|
| 36 |
+
|
| 37 |
+
# ── Optional: LLM-as-judge for VQA ────────────────────────────────────────────
|
| 38 |
+
openai==2.32.0
|
evaluation/evaluate.py
CHANGED
|
@@ -47,7 +47,7 @@ from typing import List, Dict, Optional
|
|
| 47 |
import torch
|
| 48 |
from torch.utils.data import DataLoader
|
| 49 |
from omegaconf import OmegaConf
|
| 50 |
-
from tqdm import tqdm
|
| 51 |
|
| 52 |
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 53 |
|
|
|
|
| 47 |
import torch
|
| 48 |
from torch.utils.data import DataLoader
|
| 49 |
from omegaconf import OmegaConf
|
| 50 |
+
from tqdm.auto import tqdm
|
| 51 |
|
| 52 |
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 53 |
|
requirements.txt
CHANGED
|
@@ -1,21 +1,42 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
omegaconf==2.3.0
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
rouge-score==0.1.2
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
| 1 |
+
# CXR-VLM dependencies — versions match Colab GPU runtime (verified 2026-05).
|
| 2 |
+
#
|
| 3 |
+
# NOTE: torch / torchvision are deliberately NOT pinned here. Install them
|
| 4 |
+
# separately matching your CUDA driver, e.g. on Colab they come pre-installed,
|
| 5 |
+
# and for local/Vast.ai/Lightning use:
|
| 6 |
+
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
|
| 7 |
+
# (or whichever CUDA the host driver supports).
|
| 8 |
+
|
| 9 |
+
# ── Core HF stack ─────────────────────────────────────────────────────────────
|
| 10 |
+
transformers==4.49.0
|
| 11 |
+
peft==0.14.0
|
| 12 |
+
accelerate==1.13.0
|
| 13 |
+
bitsandbytes==0.49.2
|
| 14 |
+
huggingface_hub==1.11.0
|
| 15 |
+
httpx==0.28.1 # utils/_httpx_compat.py handles the allow_redirects removal
|
| 16 |
+
|
| 17 |
+
# ── Vision encoder ────────────────────────────────────────────────────────────
|
| 18 |
+
# rad_dino (default) loads via transformers AutoModel — no extra dep.
|
| 19 |
+
# biovilt backend (hi-ml-multimodal) intentionally OMITTED — code falls back
|
| 20 |
+
# to timm/transformers automatically when health_multimodal isn't installed.
|
| 21 |
+
timm==1.0.26
|
| 22 |
+
Pillow==11.3.0
|
| 23 |
+
|
| 24 |
+
# ── Config / data ─────────────────────────────────────────────────────────────
|
| 25 |
omegaconf==2.3.0
|
| 26 |
+
sentencepiece==0.2.1
|
| 27 |
+
protobuf==5.29.6
|
| 28 |
+
numpy==2.0.2
|
| 29 |
+
pandas==2.2.2
|
| 30 |
+
|
| 31 |
+
# ── Eval metrics ──────────────────────────────────────────────────────────────
|
| 32 |
+
nltk==3.9.1
|
| 33 |
rouge-score==0.1.2
|
| 34 |
+
bert-score==0.3.12
|
| 35 |
+
scikit-learn==1.6.1
|
| 36 |
+
|
| 37 |
+
# ── Training / experiment tracking ────────────────────────────────────────────
|
| 38 |
+
wandb==0.26.1
|
| 39 |
+
tqdm==4.67.3
|
| 40 |
+
|
| 41 |
+
# ── Optional: LLM-as-judge for VQA (set --judge_model on evaluation/evaluate.py)
|
| 42 |
+
openai==2.32.0
|
scripts/cxrvlm_colab_train.ipynb
CHANGED
|
@@ -299,64 +299,7 @@
|
|
| 299 |
},
|
| 300 |
"outputs": [],
|
| 301 |
"source": [
|
| 302 |
-
"!pip uninstall -y -q torchao transformers bitsandbytes peft accelerate\n",
|
| 303 |
-
"\n",
|
| 304 |
-
"# Let pip pick latest bnb that matches Colab's CUDA 12.8 + triton 3.x\n",
|
| 305 |
-
"!pip install -q -U bitsandbytes\n",
|
| 306 |
-
"\n",
|
| 307 |
-
"# Install everything. We DON'T pin httpx anymore — Colab's firebase-admin and\n",
|
| 308 |
-
"# google-genai hard-pin httpx==0.28.1, so the resolver always wins. Instead\n",
|
| 309 |
-
"# we monkey-patch httpx 0.28+ below to keep accepting the legacy\n",
|
| 310 |
-
"# `allow_redirects` kwarg that transformers ≤4.50 still passes.\n",
|
| 311 |
-
"!pip install -q \\\n",
|
| 312 |
-
" 'transformers>=4.46,<4.50' \\\n",
|
| 313 |
-
" 'peft>=0.13,<0.15' \\\n",
|
| 314 |
-
" 'accelerate>=1.0' \\\n",
|
| 315 |
-
" 'huggingface_hub>=0.27,<1.0' \\\n",
|
| 316 |
-
" omegaconf sentencepiece 'protobuf>=3.20' \\\n",
|
| 317 |
-
" nltk rouge-score bert-score sacrebleu\n",
|
| 318 |
-
"\n",
|
| 319 |
-
"import torch, transformers, bitsandbytes, peft, accelerate, huggingface_hub, httpx\n",
|
| 320 |
-
"print('torch :', torch.__version__, '| cuda:', torch.cuda.is_available())\n",
|
| 321 |
-
"print('transformers :', transformers.__version__)\n",
|
| 322 |
-
"print('bitsandbytes :', bitsandbytes.__version__)\n",
|
| 323 |
-
"print('peft :', peft.__version__)\n",
|
| 324 |
-
"print('accelerate :', accelerate.__version__)\n",
|
| 325 |
-
"print('huggingface_hub:', huggingface_hub.__version__)\n",
|
| 326 |
-
"print('httpx :', httpx.__version__)\n",
|
| 327 |
-
"\n",
|
| 328 |
-
"# ── httpx 0.28+ compat shim ───────────────────────────────────────────────\n",
|
| 329 |
-
"# transformers ≤4.50 calls httpx.Client.head(..., allow_redirects=True) which\n",
|
| 330 |
-
"# httpx 0.28 removed → \"Client.head() got an unexpected keyword argument\n",
|
| 331 |
-
"# 'allow_redirects'\". Translate the kwarg at the call site so the rest of\n",
|
| 332 |
-
"# the stack keeps working. No-op on httpx <0.28.\n",
|
| 333 |
-
"#\n",
|
| 334 |
-
"# The same patch is auto-applied inside the train.py subprocess via\n",
|
| 335 |
-
"# utils._quiet → utils._httpx_compat. Here we apply it in the NOTEBOOK\n",
|
| 336 |
-
"# kernel too, so the smoke test cell (which runs in-kernel) benefits.\n",
|
| 337 |
-
"def _patch_httpx():\n",
|
| 338 |
-
" if tuple(int(x) for x in httpx.__version__.split('.')[:2]) < (0, 28):\n",
|
| 339 |
-
" return\n",
|
| 340 |
-
" if getattr(httpx.Client, '_cxr_vlm_compat_patched', False):\n",
|
| 341 |
-
" return\n",
|
| 342 |
-
" def _make(orig):\n",
|
| 343 |
-
" def patched(self, *args, **kwargs):\n",
|
| 344 |
-
" if 'allow_redirects' in kwargs:\n",
|
| 345 |
-
" kwargs['follow_redirects'] = kwargs.pop('allow_redirects')\n",
|
| 346 |
-
" # httpx 0.28+ removed per-request `proxies=` too — transformers ≤4.49\n",
|
| 347 |
-
" # still passes it via huggingface_hub.has_file → drop it silently.\n",
|
| 348 |
-
" kwargs.pop('proxies', None)\n",
|
| 349 |
-
" return orig(self, *args, **kwargs)\n",
|
| 350 |
-
" return patched\n",
|
| 351 |
-
" for cls in (httpx.Client, httpx.AsyncClient):\n",
|
| 352 |
-
" for m in ('request', 'get', 'head', 'post', 'put',\n",
|
| 353 |
-
" 'patch', 'delete', 'options'):\n",
|
| 354 |
-
" if hasattr(cls, m):\n",
|
| 355 |
-
" setattr(cls, m, _make(getattr(cls, m)))\n",
|
| 356 |
-
" httpx.Client._cxr_vlm_compat_patched = True\n",
|
| 357 |
-
" print(f'httpx {httpx.__version__}: monkey-patched allow_redirects → follow_redirects')\n",
|
| 358 |
-
"\n",
|
| 359 |
-
"_patch_httpx()\n"
|
| 360 |
]
|
| 361 |
},
|
| 362 |
{
|
|
@@ -1723,4 +1666,4 @@
|
|
| 1723 |
},
|
| 1724 |
"nbformat": 4,
|
| 1725 |
"nbformat_minor": 5
|
| 1726 |
-
}
|
|
|
|
| 299 |
},
|
| 300 |
"outputs": [],
|
| 301 |
"source": [
|
| 302 |
+
"import os as _os\n_in_docker = _os.path.exists('/.dockerenv')\nif _in_docker:\n print('Running inside Docker image -- skipping pip install (env pre-baked).')\nelse:\n !pip uninstall -y -q torchao transformers bitsandbytes peft accelerate\n \n # Let pip pick latest bnb that matches Colab's CUDA 12.8 + triton 3.x\n !pip install -q -U bitsandbytes\n \n # Install everything. We DON'T pin httpx anymore — Colab's firebase-admin and\n # google-genai hard-pin httpx==0.28.1, so the resolver always wins. Instead\n # we monkey-patch httpx 0.28+ below to keep accepting the legacy\n # `allow_redirects` kwarg that transformers ≤4.50 still passes.\n !pip install -q \\\n 'transformers>=4.46,<4.50' \\\n 'peft>=0.13,<0.15' \\\n 'accelerate>=1.0' \\\n 'huggingface_hub>=0.27,<1.0' \\\n omegaconf sentencepiece 'protobuf>=3.20' \\\n nltk rouge-score bert-score sacrebleu\n \n import torch, transformers, bitsandbytes, peft, accelerate, huggingface_hub, httpx\n print('torch :', torch.__version__, '| cuda:', torch.cuda.is_available())\n print('transformers :', transformers.__version__)\n print('bitsandbytes :', bitsandbytes.__version__)\n print('peft :', peft.__version__)\n print('accelerate :', accelerate.__version__)\n print('huggingface_hub:', huggingface_hub.__version__)\n print('httpx :', httpx.__version__)\n \n # ── httpx 0.28+ compat shim ─────────────────────────────────���─────────────\n # transformers ≤4.50 calls httpx.Client.head(..., allow_redirects=True) which\n # httpx 0.28 removed → \"Client.head() got an unexpected keyword argument\n # 'allow_redirects'\". Translate the kwarg at the call site so the rest of\n # the stack keeps working. No-op on httpx <0.28.\n #\n # The same patch is auto-applied inside the train.py subprocess via\n # utils._quiet → utils._httpx_compat. Here we apply it in the NOTEBOOK\n # kernel too, so the smoke test cell (which runs in-kernel) benefits.\n def _patch_httpx():\n if tuple(int(x) for x in httpx.__version__.split('.')[:2]) < (0, 28):\n return\n if getattr(httpx.Client, '_cxr_vlm_compat_patched', False):\n return\n def _make(orig):\n def patched(self, *args, **kwargs):\n if 'allow_redirects' in kwargs:\n kwargs['follow_redirects'] = kwargs.pop('allow_redirects')\n # httpx 0.28+ removed per-request `proxies=` too — transformers ≤4.49\n # still passes it via huggingface_hub.has_file → drop it silently.\n kwargs.pop('proxies', None)\n return orig(self, *args, **kwargs)\n return patched\n for cls in (httpx.Client, httpx.AsyncClient):\n for m in ('request', 'get', 'head', 'post', 'put',\n 'patch', 'delete', 'options'):\n if hasattr(cls, m):\n setattr(cls, m, _make(getattr(cls, m)))\n httpx.Client._cxr_vlm_compat_patched = True\n print(f'httpx {httpx.__version__}: monkey-patched allow_redirects → follow_redirects')\n \n _patch_httpx()"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
]
|
| 304 |
},
|
| 305 |
{
|
|
|
|
| 1666 |
},
|
| 1667 |
"nbformat": 4,
|
| 1668 |
"nbformat_minor": 5
|
| 1669 |
+
}
|
training/train.py
CHANGED
|
@@ -34,7 +34,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
|
|
| 34 |
torch.backends.cudnn.allow_tf32 = True
|
| 35 |
|
| 36 |
import transformers
|
| 37 |
-
from transformers import TrainingArguments, Trainer, TrainerCallback
|
| 38 |
|
| 39 |
# Add project root to path
|
| 40 |
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
@@ -353,13 +353,15 @@ def get_trainer(
|
|
| 353 |
)
|
| 354 |
return super()._get_train_sampler(*args, **kwargs)
|
| 355 |
|
| 356 |
-
|
| 357 |
model = model,
|
| 358 |
args = training_args,
|
| 359 |
train_dataset = train_dataset,
|
| 360 |
eval_dataset = val_dataset,
|
| 361 |
data_collator = collator,
|
| 362 |
)
|
|
|
|
|
|
|
| 363 |
|
| 364 |
|
| 365 |
def _cfg(stage_cfg, tr, key, default=None):
|
|
|
|
| 34 |
torch.backends.cudnn.allow_tf32 = True
|
| 35 |
|
| 36 |
import transformers
|
| 37 |
+
from transformers import TrainingArguments, Trainer, TrainerCallback, PrinterCallback
|
| 38 |
|
| 39 |
# Add project root to path
|
| 40 |
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
| 353 |
)
|
| 354 |
return super()._get_train_sampler(*args, **kwargs)
|
| 355 |
|
| 356 |
+
trainer = CXRTrainer(
|
| 357 |
model = model,
|
| 358 |
args = training_args,
|
| 359 |
train_dataset = train_dataset,
|
| 360 |
eval_dataset = val_dataset,
|
| 361 |
data_collator = collator,
|
| 362 |
)
|
| 363 |
+
trainer.remove_callback(PrinterCallback)
|
| 364 |
+
return trainer
|
| 365 |
|
| 366 |
|
| 367 |
def _cfg(stage_cfg, tr, key, default=None):
|
utils/checkpoint.py
CHANGED
|
@@ -99,6 +99,19 @@ def load_checkpoint(
|
|
| 99 |
# Load LoRA — skipped when llm not loaded (ITC Stage-1) or no dir present.
|
| 100 |
if load_lora and getattr(model, "llm", None) is not None:
|
| 101 |
lora_dir = ckpt_dir / f"{ckpt_name}_lora"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
if lora_dir.exists():
|
| 103 |
from peft import PeftModel
|
| 104 |
model.llm = PeftModel.from_pretrained(
|
|
|
|
| 99 |
# Load LoRA — skipped when llm not loaded (ITC Stage-1) or no dir present.
|
| 100 |
if load_lora and getattr(model, "llm", None) is not None:
|
| 101 |
lora_dir = ckpt_dir / f"{ckpt_name}_lora"
|
| 102 |
+
# Defensive: PEFT raises an opaque HFValidationError when the dir
|
| 103 |
+
# exists but `adapter_config.json` is missing (a partially-written
|
| 104 |
+
# or partially-downloaded checkpoint). Surface a clearer message so
|
| 105 |
+
# the user knows the fix: delete the dir and resume from HF Hub.
|
| 106 |
+
if lora_dir.is_dir() and not (lora_dir / "adapter_config.json").is_file():
|
| 107 |
+
raise FileNotFoundError(
|
| 108 |
+
f"[load_checkpoint] {lora_dir} exists but adapter_config.json "
|
| 109 |
+
f"is missing — checkpoint is partially-written/downloaded. "
|
| 110 |
+
f"Fix: delete the parent checkpoint folder "
|
| 111 |
+
f"({lora_dir.parent}) and rerun with --mode resume so it "
|
| 112 |
+
f"gets re-pulled from HF Hub, OR rm -rf the stage2_instruct "
|
| 113 |
+
f"folder to train Stage 2 fresh from stage1_final."
|
| 114 |
+
)
|
| 115 |
if lora_dir.exists():
|
| 116 |
from peft import PeftModel
|
| 117 |
model.llm = PeftModel.from_pretrained(
|