{ "cells": [ { "cell_type": "markdown", "id": "s0-intro", "metadata": {}, "source": [ "# Stage 0 — Train CheXpert Classifier (14 pathologies × 3 classes)\n", "\n", "Trains the **U-MultiClass** CheXpert head (`model/chexpert_classifier.py`) on top of\n", "**frozen RAD-DINO** global (CLS) embeddings. Output: 14 pathology heads, each predicting\n", "`negative / positive / uncertain`; the predictions are fed downstream as the PNU structured-findings string.\n", "\n", "Designed for **T4 free-tier Colab** because:\n", "- RAD-DINO is frozen, so we **precompute CLS embeddings once** and cache them to disk.\n", "- The classifier itself is a tiny MLP (768 → 256 → 14×3), so training the head on\n", "  cached embeddings takes minutes and needs little VRAM.\n", "\n", "Data: MIMIC-CXR_resized (manifests + flat `Study_.jpg` files), pulled from\n", "`hieu3636/cxr-vlm-data`. Labels come from the manifest's `chex_*` columns\n", "(1→positive, 0→negative, -1→uncertain, blank/NaN→negative, following the META-CXR convention).\n", "\n", "The output checkpoint is uploaded to `hieu3636/cxr-vlm-data/chexpert_classifier/`. Point\n", "`model_config.yaml: chexpert_classifier.checkpoint` at the downloaded copy for\n", "Stage 1/2 training." ] }, { "cell_type": "markdown", "id": "s0-sel-md", "metadata": {}, "source": [ "## 1. Selectors + hyperparams" ] }, { "cell_type": "code", "id": "s0-sel", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# ── Selectors ──\n", "PLATFORM = 'colab'                   # 'colab' | 'kaggle' | 'lightning' | 'gcp' | 'local'\n", "DATASET_NAME = 'MIMIC-CXR_resized'   # this notebook is built for the resized subset\n", "\n", "# ── Training hyperparams (small head; defaults are sane for T4) ──\n", "NUM_EPOCHS = 8\n", "BATCH_SIZE = 256        # head only, cached embeddings — go big\n", "LR = 1.0e-3\n", "WEIGHT_DECAY = 1.0e-4\n", "DROPOUT = 0.2\n", "USE_CLS_TOKEN = True    # True = RAD-DINO CLS (recommended). 
False = mean-pool patches.\n", "\n", "# ── Cache + checkpoint paths ──\n", "EMBED_CACHE_PT = 'data/data_files/raddino_global_embeds_mimic_resized.pt'\n", "CKPT_OUT = 'checkpoints/stage0_chexpert/chexpert_mimic_resized.pt'\n", "\n", "# ── HF upload (set HF_TOKEN secret on Colab first) ──\n", "HF_REPO_ID = 'hieu3636/cxr-vlm-data'\n", "HF_REPO_TYPE = 'dataset'\n", "HF_SUBDIR = 'chexpert_classifier'\n", "DO_UPLOAD = True\n", "print(f'{DATASET_NAME} | epochs={NUM_EPOCHS} bs={BATCH_SIZE} lr={LR}')" ] }, { "cell_type": "markdown", "id": "s0-inst-md", "metadata": {}, "source": [ "## 2. Install deps (light)" ] }, { "cell_type": "code", "id": "s0-inst", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Mirrors the train notebook's install dance but trimmed for Stage 0:\n", "# - no bitsandbytes / peft (no Vicuna, no LoRA in this stage)\n", "# - no sacrebleu / rouge / bert-score (no text generation)\n", "# Same transformers + huggingface_hub pins as the train notebook so this\n", "# notebook can import the repo's model/data modules without version skew.\n", "!pip uninstall -y -q torchao transformers peft accelerate\n", "\n", "!pip install -q \\\n", " 'transformers>=4.46,<4.50' \\\n", " 'accelerate>=1.0' \\\n", " 'huggingface_hub>=0.27,<1.0' \\\n", " omegaconf sentencepiece 'protobuf>=3.20' \\\n", " pillow tqdm scikit-learn\n", "\n", "import torch, transformers, huggingface_hub, httpx\n", "print('torch :', torch.__version__, '| cuda:', torch.cuda.is_available())\n", "print('transformers :', transformers.__version__)\n", "print('huggingface_hub:', huggingface_hub.__version__)\n", "print('httpx :', httpx.__version__)\n", "\n", "# ── httpx 0.28+ compat shim (same as train notebook) ──────────────────\n", "# transformers ≤4.49 calls Client.head(..., allow_redirects=True, proxies=...)\n", "# which httpx 0.28 removed → TypeError on the Hub probe path. 
Patch both\n", "# kwargs at the call site; no-op on httpx <0.28.\n", "def _patch_httpx():\n", " if tuple(int(x) for x in httpx.__version__.split('.')[:2]) < (0, 28):\n", " return\n", " if getattr(httpx.Client, '_cxr_vlm_compat_patched', False):\n", " return\n", " def _make(orig):\n", " def patched(self, *args, **kwargs):\n", " if 'allow_redirects' in kwargs:\n", " kwargs['follow_redirects'] = kwargs.pop('allow_redirects')\n", " kwargs.pop('proxies', None)\n", " return orig(self, *args, **kwargs)\n", " return patched\n", " for cls in (httpx.Client, httpx.AsyncClient):\n", " for m in ('request','get','head','post','put','patch','delete','options'):\n", " if hasattr(cls, m):\n", " setattr(cls, m, _make(getattr(cls, m)))\n", " httpx.Client._cxr_vlm_compat_patched = True\n", " print(f'httpx {httpx.__version__}: monkey-patched allow_redirects + proxies')\n", "\n", "_patch_httpx()" ] }, { "cell_type": "markdown", "id": "s0-env-md", "metadata": {}, "source": [ "## 3. Env + pull code & data (reused from train notebook)" ] }, { "cell_type": "code", "id": "cell-env", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0' # single-GPU\n", "os.environ['TOKENIZERS_PARALLELISM'] = 'false' # silence HF tokenizers fork warning\n", "os.environ['BITSANDBYTES_NOWELCOME'] = '1'\n", "os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # kill per-shard download bars\n", "os.environ['TRANSFORMERS_VERBOSITY'] = 'warning'\n", "os.environ['PYTHONUNBUFFERED'] = '1'\n", "\n", "import sys, shutil, subprocess\n", "from pathlib import Path\n" ] }, { "cell_type": "code", "id": "cell-paths", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# ── Per-platform storage + source-of-truth ─────────────────────────\n", "# All platforms (kaggle / colab / lightning / gcp / local) pull code +\n", "# data from HF Hub. 
The only platform-specific bit is:\n", "# * WORK : where to land outputs (persisted dirs differ per host)\n", "# * TOKEN : how HF_TOKEN reaches os.environ (secrets API differs)\n", "#\n", "# Required HF repos:\n", "# /cxr-vlm-code — project source (flat folder)\n", "# /cxr-vlm-data — per-dataset payloads:\n", "# MIMIC-CXR_resized/ (tar shards + manifests + vqa)\n", "# MIMIC-CXR.zip (single zip)\n", "# IU-Xray.zip (single zip)\n", "\n", "HF_USER = 'hieu3636' # <<< EDIT ME\n", "\n", "# ── 1) WORK dir + HF_TOKEN bootstrap (platform-specific) ───────────\n", "if PLATFORM == 'kaggle':\n", " from kaggle_secrets import UserSecretsClient\n", " os.environ['HF_TOKEN'] = UserSecretsClient().get_secret('HF_TOKEN')\n", " WORK = Path('/kaggle/working')\n", "elif PLATFORM == 'colab':\n", " from google.colab import userdata\n", " os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n", " WORK = Path('/content')\n", "elif PLATFORM == 'lightning':\n", " WORK = Path('/teamspace/studios/this_studio')\n", "elif PLATFORM == 'gcp':\n", " WORK = Path('/workspace')\n", "else: # 'local'\n", " WORK = Path.home() / 'cxr-vlm-work'\n", "WORK.mkdir(parents=True, exist_ok=True)\n", "\n", "assert os.environ.get('HF_TOKEN'), \\\n", " 'HF_TOKEN missing — set it via the platform secrets UI before re-running.'\n", "\n", "try:\n", " from huggingface_hub import snapshot_download, hf_hub_download, HfApi\n", "except ImportError:\n", " !pip install -q huggingface_hub\n", " from huggingface_hub import snapshot_download, hf_hub_download, HfApi\n", "\n", "# ── 2) Code: flat folder, few hundred files → snapshot_download ──\n", "print(f'Pulling code from HF (user: {HF_USER}) …')\n", "CODE_SRC = Path(snapshot_download(\n", " repo_id = f'{HF_USER}/cxr-vlm-code',\n", " repo_type = 'model',\n", " token = os.environ['HF_TOKEN'],\n", " local_dir = str(WORK / 'cxr-vlm-code'),\n", "))\n", "\n", "# ── 3) Data: layout depends on DATASET_NAME ──\n", "DATA_SRC = WORK / 'data'\n", "DATA_SRC.mkdir(parents=True, 
exist_ok=True)\n", "\n", "if DATASET_NAME == 'MIMIC-CXR_resized':\n", " # Tar-sharded payload. Reports + images live INSIDE the tars under\n", " # `files/pXX/pXXXX/{sYYYY/*.jpg, sYYYY.txt}` so extracting all shards\n", " # gives one unified tree. We download manifests + vqa + SHARDS.txt\n", " # first (small, ~tens of MB), then each *.tar one at a time →\n", " # extract → delete (saves disk).\n", " # Final on-disk layout:\n", " # DATA_SRC/MIMIC-CXR_resized/\n", " # ├── manifest_{train,val,test}.csv\n", " # ├── vqa/ {vqa.json, vqa_val.json, vqa_test.json}\n", " # ├── SHARDS.txt + _manifest.json\n", " # └── files/pXX/pXXXX/ ← from tars\n", " # ├── sYYYY.txt (report)\n", " # └── sYYYY/.jpg (images)\n", " import tarfile\n", " mr_dir = DATA_SRC / 'MIMIC-CXR_resized'\n", " mr_dir.mkdir(parents=True, exist_ok=True)\n", " files_dir = mr_dir / 'files'\n", "\n", " # Marker: if files/ already has shards extracted AND manifests exist,\n", " # skip everything. Lets the cell be re-run safely.\n", " manifests_present = all(\n", " (mr_dir / f).is_file() for f in ('manifest_train.csv', 'manifest_val.csv', 'manifest_test.csv')\n", " )\n", " if manifests_present and files_dir.is_dir() and any(files_dir.glob('p*')):\n", " print(f'{mr_dir} already populated — skipping download.')\n", " else:\n", " api = HfApi(token=os.environ['HF_TOKEN'])\n", " all_files = api.list_repo_files(\n", " repo_id=f'{HF_USER}/cxr-vlm-data', repo_type='dataset')\n", " mr_files = [f for f in all_files if f.startswith('MIMIC-CXR_resized/')]\n", " tar_files = sorted(f for f in mr_files if f.endswith('.tar'))\n", " meta_files = [f for f in mr_files if not f.endswith('.tar')]\n", " print(f'MIMIC-CXR_resized on HF: {len(tar_files)} tar shards + {len(meta_files)} metadata files')\n", "\n", " # 3a) Pull metadata (manifests, vqa, SHARDS.txt, _manifest.json)\n", " # in one snapshot (small; few MB).\n", " print(f' downloading manifests + vqa + SHARDS.txt …')\n", " snapshot_download(\n", " repo_id = 
f'{HF_USER}/cxr-vlm-data',\n", " repo_type = 'dataset',\n", " allow_patterns = ['MIMIC-CXR_resized/*.csv',\n", " 'MIMIC-CXR_resized/*.json',\n", " 'MIMIC-CXR_resized/*.txt',\n", " 'MIMIC-CXR_resized/vqa/**'],\n", " token = os.environ['HF_TOKEN'],\n", " local_dir = str(DATA_SRC),\n", " )\n", "\n", " # 3b) Sequentially fetch + extract + delete each image tar to\n", " # minimise peak disk usage (each shard ~2 GB). Reports come\n", " # out alongside images — both land under mr_dir/files/.\n", " print(f' downloading + extracting {len(tar_files)} tar shards …')\n", " for i, tf in enumerate(tar_files, 1):\n", " print(f' [{i}/{len(tar_files)}] {tf}')\n", " tar_path = Path(hf_hub_download(\n", " repo_id=f'{HF_USER}/cxr-vlm-data', repo_type='dataset',\n", " filename=tf, token=os.environ['HF_TOKEN'],\n", " local_dir=str(DATA_SRC),\n", " ))\n", " with tarfile.open(tar_path) as t:\n", " # Extract into mr_dir so member paths like\n", " # \"files/p10/.../*.jpg\" + \"files/p10/.../*.txt\" land at\n", " # mr_dir/files/p10/…\n", " t.extractall(mr_dir)\n", " tar_path.unlink(missing_ok=True)\n", " print(f' done. 
{mr_dir} ready.')\n", "\n", "else:\n", " # MIMIC-CXR / IU-Xray: single zip per dataset (legacy path)\n", " import zipfile\n", " zip_name = f'{DATASET_NAME}.zip' # 'IU-Xray.zip' | 'MIMIC-CXR.zip'\n", " marker = DATA_SRC / DATASET_NAME # DATA_SRC/IU-Xray after unzip\n", "\n", " if not marker.exists():\n", " print(f'Pulling {zip_name} from HF …')\n", " zpath = hf_hub_download(\n", " repo_id = f'{HF_USER}/cxr-vlm-data',\n", " filename = zip_name,\n", " repo_type = 'dataset',\n", " token = os.environ['HF_TOKEN'],\n", " local_dir = str(DATA_SRC),\n", " )\n", " print(f' unzipping → {DATA_SRC}')\n", " with zipfile.ZipFile(zpath) as zf:\n", " zf.extractall(DATA_SRC)\n", " try:\n", " os.remove(zpath) # free disk\n", " except OSError:\n", " pass\n", " else:\n", " print(f'{marker} already present — skipping download.')\n", "\n", "print(f'Contents of {DATA_SRC}: {sorted(os.listdir(DATA_SRC))}')\n", "\n", "# ── Common: copy code into writable PROJECT dir ────────────────────\n", "PROJECT = WORK / 'cxr_vlm'\n", "if CODE_SRC.resolve() != PROJECT.resolve() and not PROJECT.exists():\n", " shutil.copytree(CODE_SRC, PROJECT)\n", "\n", "os.chdir(PROJECT)\n", "sys.path.insert(0, str(PROJECT))\n", "print('PLATFORM :', PLATFORM)\n", "print('CODE_SRC :', CODE_SRC)\n", "print('DATA_SRC :', DATA_SRC)\n", "print('PROJECT :', PROJECT)\n", "print('WORK :', WORK)" ] }, { "cell_type": "markdown", "id": "s0-loc-md", "metadata": {}, "source": [ "## 4. 
Locate data" ] }, { "cell_type": "code", "id": "cell-find-data-mimic", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "def find_split_parent(root: Path) -> Path:\n", " for cand in [root, root / 'MIMIC-CXR', root / 'data' / 'MIMIC-CXR']:\n", " if (cand / 'train').exists() and (cand / 'valid').exists() and (cand / 'test').exists():\n", " return cand\n", " for p in root.rglob('train'):\n", " if p.is_dir() and (p.parent / 'valid').exists() and (p.parent / 'test').exists():\n", " return p.parent\n", " raise FileNotFoundError('Could not find train/ valid/ test/ under ' + str(root))\n", "\n", "\n", "def find_mimic_resized_root(root: Path) -> Path:\n", " \"\"\"Find the MIMIC-CXR_resized payload — folder with manifest_*.csv + files/.\"\"\"\n", " for cand in [root / 'MIMIC-CXR_resized', root, *root.rglob('MIMIC-CXR_resized')]:\n", " if (cand / 'manifest_train.csv').is_file():\n", " return cand\n", " raise FileNotFoundError(\n", " f'Could not find MIMIC-CXR_resized payload under {root}. '\n", " f'Expected manifest_train.csv (alongside manifest_val.csv / manifest_test.csv).'\n", " )\n", "\n", "\n", "def find_iu_dirs(root: Path):\n", " \"\"\"Locate IU-Xray `images/` and `labels/` (flat XMLs) under `root`.\n", "\n", " Resolution order:\n", " 1. `{root}/IU-Xray/{images,labels}` — canonical layout.\n", " 2. Any nested `IU-Xray` folder that contains both.\n", " 3. 
Fallback: any folder containing CXR*.png (images) and\n", " any folder containing *.xml — whichever comes first.\n", "\n", " The labels subfolder is treated as a flat directory of XMLs (we no\n", " longer require the legacy `ecgen-radiology/` subfolder).\n", " \"\"\"\n", " # Canonical + nested\n", " for cand in [root / 'IU-Xray', *root.rglob('IU-Xray')]:\n", " if not cand.is_dir():\n", " continue\n", " imgs = cand / 'images'\n", " lbls = cand / 'labels'\n", " if imgs.is_dir() and lbls.is_dir() and any(lbls.glob('*.xml')):\n", " return imgs, lbls\n", " # Legacy: labels/ecgen-radiology/*.xml\n", " legacy = lbls / 'ecgen-radiology'\n", " if imgs.is_dir() and legacy.is_dir() and any(legacy.glob('*.xml')):\n", " return imgs, legacy\n", "\n", " # Fallback: any images/ with CXR*.png + any folder with XML\n", " img_dir = lbl_dir = None\n", " for cand in [root / 'images', *root.rglob('images')]:\n", " if cand.is_dir() and any(cand.glob('CXR*.png')):\n", " img_dir = cand; break\n", " for cand in [root / 'labels', *root.rglob('labels')]:\n", " if cand.is_dir() and any(cand.glob('*.xml')):\n", " lbl_dir = cand; break\n", " if lbl_dir is None:\n", " # very last resort — any ecgen-radiology folder with XMLs\n", " for cand in root.rglob('ecgen-radiology'):\n", " if cand.is_dir() and any(cand.glob('*.xml')):\n", " lbl_dir = cand; break\n", " return img_dir, lbl_dir\n", "\n", "\n", "# Filled in below depending on DATASET_NAME\n", "CXR_ROOT = None # MIMIC-CXR root (with train/valid/test subdirs)\n", "SPLIT_DIRS = None # MIMIC only\n", "VQA_ROOT = None # MIMIC only\n", "MR_ROOT = None # MIMIC-CXR_resized root (manifests + files/ + vqa/)\n", "IU_IMAGES_DIR = None # IU-Xray only\n", "IU_LABELS_DIR = None # IU-Xray only\n", "\n", "if DATASET_NAME == 'MIMIC-CXR':\n", " CXR_ROOT = find_split_parent(DATA_SRC)\n", " print('MIMIC-CXR root:', CXR_ROOT)\n", "\n", " SPLIT_DIRS = {\n", " 'train' : ('train', CXR_ROOT / 'train'),\n", " 'validate': ('valid', CXR_ROOT / 'valid'),\n", " 'test' : 
('test', CXR_ROOT / 'test'),\n", " }\n", " for s, (sub, d) in SPLIT_DIRS.items():\n", " assert d.exists(), f'Missing split dir: {d}'\n", " print(f' {s:<9s} → {d}')\n", "\n", " for p in DATA_SRC.rglob('MIMIC-Ext-MIMIC-CXR-VQA'):\n", " cand = p / 'dataset'\n", " if cand.exists() and (cand / 'train.json').exists():\n", " VQA_ROOT = cand\n", " break\n", " assert VQA_ROOT is not None, 'VQA dataset folder not found under ' + str(DATA_SRC)\n", " print('VQA root:', VQA_ROOT)\n", "\n", "elif DATASET_NAME == 'MIMIC-CXR_resized':\n", " MR_ROOT = find_mimic_resized_root(DATA_SRC)\n", " print('MIMIC-CXR_resized root:', MR_ROOT)\n", " # Sanity: 3 manifest CSVs, files/ (images+reports), vqa/\n", " for cf in ('manifest_train.csv', 'manifest_val.csv', 'manifest_test.csv'):\n", " f = MR_ROOT / cf\n", " print(f' {cf}: {\"OK\" if f.is_file() else \"MISSING\"}')\n", " for sub in ('files', 'vqa'):\n", " d = MR_ROOT / sub\n", " print(f' {sub:<5s}: {\"OK\" if d.is_dir() else \"MISSING\"} ({d})')\n", " # Spot-check one report (.txt) sits at patient-dir level inside files/\n", " txt_hits = list((MR_ROOT / 'files').glob('p*/p*/s*.txt')) if (MR_ROOT / 'files').is_dir() else []\n", " print(f' reports inside files/ : {len(txt_hits):,} found (sample: {txt_hits[0] if txt_hits else \"—\"})')\n", "\n", "else: # IU-Xray\n", " IU_IMAGES_DIR, IU_LABELS_DIR = find_iu_dirs(DATA_SRC)\n", " assert IU_IMAGES_DIR is not None, f'IU images/ not found under {DATA_SRC}'\n", " assert IU_LABELS_DIR is not None, f'IU labels/ (with *.xml) not found under {DATA_SRC}'\n", " print('IU images dir:', IU_IMAGES_DIR, '→', len(list(IU_IMAGES_DIR.glob('*.png'))), 'PNGs')\n", " print('IU labels dir:', IU_LABELS_DIR, '→', len(list(IU_LABELS_DIR.glob('*.xml'))), 'XMLs')" ] }, { "cell_type": "markdown", "id": "s0-lab-md", "metadata": {}, "source": [ "## 5. 
Read manifests → per-image (image_path, 14-label tensor)" ] }, { "cell_type": "code", "id": "s0-lab", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Read manifests, build per-image label tensors (14 × {0,1,2}).\n", "import csv\n", "import torch\n", "from pathlib import Path\n", "from model.chexpert_classifier import (\n", " PATHOLOGIES, CLASS_NEGATIVE, CLASS_POSITIVE, CLASS_UNCERTAIN, NUM_STATES\n", ")\n", "\n", "VAL_TO_CLS = {\n", " '1': CLASS_POSITIVE, '1.0': CLASS_POSITIVE,\n", " '0': CLASS_NEGATIVE, '0.0': CLASS_NEGATIVE,\n", " '-1': CLASS_UNCERTAIN, '-1.0': CLASS_UNCERTAIN,\n", "}\n", "\n", "def row_to_labels(row):\n", " out = torch.zeros(len(PATHOLOGIES), dtype=torch.long)\n", " for j, name in enumerate(PATHOLOGIES):\n", " v = str(row.get(f'chex_{name}', '')).strip()\n", " out[j] = VAL_TO_CLS.get(v, CLASS_NEGATIVE)\n", " return out\n", "\n", "def locate_image(row, mr_root):\n", " # Layout used in this dataset: flat `Study_.jpg` directly under MR_ROOT\n", " sn = row.get('study_name', '').strip()\n", " if sn:\n", " p = mr_root / f'{sn}.jpg'\n", " if p.is_file(): return p\n", " # Fallback: image_filename (e.g. 
.jpg) at MR_ROOT\n", " fn = row.get('image_filename', '').strip()\n", " if fn:\n", " p = mr_root / fn\n", " if p.is_file(): return p\n", " # Fallback: canonical image_relpath\n", " rp = row.get('image_relpath', '').strip()\n", " if rp:\n", " p = mr_root / rp\n", " if p.is_file(): return p\n", " return None\n", "\n", "per_split = {'train': [], 'validate': [], 'test': []}\n", "csv_map = {'train': 'manifest_train.csv',\n", " 'validate': 'manifest_val.csv',\n", " 'test': 'manifest_test.csv'}\n", "missing = 0\n", "for split, csv_name in csv_map.items():\n", " csv_path = MR_ROOT / csv_name\n", " if not csv_path.is_file():\n", " print(f' skip {csv_name} (not found)')\n", " continue\n", " with open(csv_path, encoding='utf-8', newline='') as f:\n", " for row in csv.DictReader(f):\n", " img = locate_image(row, MR_ROOT)\n", " if img is None:\n", " missing += 1\n", " continue\n", " per_split[split].append({\n", " 'image_path': str(img),\n", " 'labels': row_to_labels(row),\n", " 'study_name': row.get('study_name', ''),\n", " })\n", "\n", "for s, lst in per_split.items():\n", " print(f' {s:8s}: {len(lst):,}')\n", "print(f' missing image files: {missing:,}')\n", "assert len(per_split['train']) > 0, 'no train samples — check MR_ROOT / manifest layout'" ] }, { "cell_type": "markdown", "id": "s0-pre-md", "metadata": {}, "source": [ "## 6. 
Precompute RAD-DINO CLS embeddings (one forward per image — cached)" ] }, { "cell_type": "code", "id": "s0-pre", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Precompute RAD-DINO global (CLS) embeddings — ONE forward per unique image.\n", "# The encoder is frozen, so the head only ever needs these 768-d vectors.\n", "import torch, gc\n", "from PIL import Image\n", "from tqdm.auto import tqdm\n", "from transformers import AutoModel, AutoImageProcessor\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print('device:', device)\n", "\n", "proc = AutoImageProcessor.from_pretrained('microsoft/rad-dino')\n", "enc = AutoModel.from_pretrained('microsoft/rad-dino').eval().to(device)\n", "\n", "cache_path = Path(EMBED_CACHE_PT)\n", "cache_path.parent.mkdir(parents=True, exist_ok=True)\n", "\n", "# Resume support: reuse any embeddings already in the cache and skip those images.\n", "if cache_path.is_file():\n", "    print(f'found existing cache, loading: {cache_path}')\n", "    cached = torch.load(cache_path, map_location='cpu')\n", "    embeds = cached.get('embeds', cached)\n", "    # A cache built with the other pooling mode (CLS vs. mean-pool) is not reusable.\n", "    if cached.get('meta', {}).get('cls', USE_CLS_TOKEN) != USE_CLS_TOKEN:\n", "        print('cache pooling mode does not match USE_CLS_TOKEN; re-encoding everything')\n", "        embeds = {}\n", "else:\n", "    embeds = {}\n", "\n", "# Union of all unique image_paths across splits (deduplicated so no image is encoded twice).\n", "all_imgs = list({s['image_path'] for split in per_split.values() for s in split})\n", "todo = [p for p in all_imgs if p not in embeds]\n", "print(f'images total {len(all_imgs):,} | already cached {len(embeds):,} | todo {len(todo):,}')\n", "\n", "@torch.no_grad()\n", "def encode_batch(paths):\n", "    pils = [Image.open(p).convert('RGB') for p in paths]\n", "    pix = proc(images=pils, return_tensors='pt')['pixel_values'].to(device)\n", "    out = enc(pixel_values=pix)\n", "    h = out.last_hidden_state               # (B, N+1, 768)\n", "    if USE_CLS_TOKEN:\n", "        v = h[:, 0, :]                      # CLS\n", "    else:\n", "        v = h[:, 1:, :].mean(dim=1)         # mean-pool patches\n", "    return v.cpu()\n", "\n", "BATCH_ENC = 32\n", "save_every = 2000   # checkpoint the cache to disk periodically (T4 sessions can be killed)\n", "last_save = 
len(embeds)\n", "\n", "for i in tqdm(range(0, len(todo), BATCH_ENC)):\n", "    chunk = todo[i:i+BATCH_ENC]\n", "    try:\n", "        v = encode_batch(chunk)\n", "        for p, vec in zip(chunk, v):\n", "            embeds[p] = vec.clone()   # own storage, not a view into the batch tensor\n", "    except Exception as e:\n", "        print(f'  ! batch failed at {i}: {type(e).__name__}: {e}; skipping')\n", "        continue\n", "    if len(embeds) - last_save >= save_every:\n", "        torch.save({'embeds': embeds, 'meta': {'cls': USE_CLS_TOKEN}}, cache_path)\n", "        last_save = len(embeds)\n", "\n", "torch.save({'embeds': embeds, 'meta': {'cls': USE_CLS_TOKEN}}, cache_path)\n", "print(f'cache saved: {cache_path} ({len(embeds):,} embeddings)')\n", "\n", "# Free encoder VRAM; head training does not need it.\n", "del enc, proc\n", "gc.collect()\n", "if torch.cuda.is_available():\n", "    torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "id": "s0-dl-md", "metadata": {}, "source": [ "## 7. Build DataLoaders + class frequency stats\n", "\n", "No class weights here: we use **Asymmetric Focal CE** (Section 8) to handle\n", "imbalance, a cleaner mechanism for highly skewed multi-label data\n", "(Ben-Baruch et al., ICCV 2021). Stacking inverse-frequency weights on top of\n", "asymmetric focal decay over-corrects (the recall-bias / precision-collapse\n", "pattern we saw on the first run)."
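, "\n", "\n", "A quick way to see the over-correction: suppose we also applied a hypothetical inverse-frequency weight $w_c = N / (3\\,n_c)$, where $n_c$ is the number of training labels of class $c$ for a pathology and $N$ is the total (this weighting is illustrative only; the notebook does not compute it). The per-sample loss for a target of class $c$ would then be\n", "\n", "$$\\ell = -\\,w_c\\,(1 - p_t)^{\\gamma_c}\\,\\log p_t$$\n", "\n", "with $\\gamma_c$ the class-dependent focal exponent from Section 8. Where negatives dominate ($n_{\\text{neg}} > N/3$, hence $w_{\\text{neg}} < 1$), an easy negative is shrunk twice, by $w_{\\text{neg}}$ and by $(1 - p_t)^{\\gamma_{\\text{neg}}}$, while a rare positive ($n_{\\text{pos}} < N/3$) is boosted by $w_{\\text{pos}} > 1$ on top of $\\gamma_{\\text{pos}} = 0$. Both factors push the head toward predicting positive, which is consistent with the recall bias described above."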
] }, { "cell_type": "code", "id": "s0-dl", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Build (embedding, labels) datasets from the cache.\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "class CachedChexDataset(Dataset):\n", " def __init__(self, samples, embed_dict):\n", " # Keep only samples whose embedding was successfully encoded.\n", " self.samples = [s for s in samples if s['image_path'] in embed_dict]\n", " self.embeds = embed_dict\n", " def __len__(self): return len(self.samples)\n", " def __getitem__(self, i):\n", " s = self.samples[i]\n", " return self.embeds[s['image_path']].float(), s['labels']\n", "\n", "train_ds = CachedChexDataset(per_split['train'], embeds)\n", "val_ds = CachedChexDataset(per_split['validate'], embeds)\n", "print(f'train: {len(train_ds):,} val: {len(val_ds):,}')\n", "\n", "train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n", " num_workers=2, pin_memory=True)\n", "val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,\n", " num_workers=2, pin_memory=True)\n", "\n", "# Per-(pathology, class) frequency stats — diagnostic only. 
ASL replaces\n", "# the old inverse-frequency weighting (see Section 8).\n", "import torch\n", "counts = torch.zeros(len(PATHOLOGIES), NUM_STATES, dtype=torch.float)\n", "for s in per_split['train']:\n", "    for j, c in enumerate(s['labels'].tolist()):\n", "        counts[j, c] += 1\n", "\n", "n_train = float(counts.sum(dim=1).max().item())\n", "print(f'per-pathology counts (n_train ≈ {int(n_train):,}):')\n", "print(f'  {\"pathology\":<26s} {\"pos\":>7s} {\"neg\":>7s} {\"unc\":>7s}   pos%')\n", "for j, name in enumerate(PATHOLOGIES):\n", "    pos = counts[j, CLASS_POSITIVE].item()\n", "    neg = counts[j, CLASS_NEGATIVE].item()\n", "    unc = counts[j, CLASS_UNCERTAIN].item()\n", "    pct = 100.0 * pos / max(1.0, pos + neg + unc)\n", "    print(f'  {name:<26s} {pos:>7.0f} {neg:>7.0f} {unc:>7.0f} {pct:5.2f}%')" ] }, { "cell_type": "markdown", "id": "s0-tr-md", "metadata": {}, "source": [ "## 8. Train the head with Asymmetric Focal CE (ASL-style)\n", "\n", "Per-pathology 3-class softmax (`neg / pos / unc`). Loss per sample:\n", "\n", "```\n", "L = -(1 - p_t)^γ * log(p_t)\n", "γ = γ_pos (= 0) if target ∈ {POSITIVE, UNCERTAIN} else γ_neg (= 4)\n", "```\n", "\n", "Here `p_t` is the softmax probability assigned to the target class.\n", "\n", "- positive (rare, informative) → **no focal decay** → keeps the full gradient signal\n", "- negative (dominant, ~50–99%) → strong decay → easy negatives stop driving the loss\n", "- uncertain (also rare) → treated like positive (also γ = 0)\n", "\n", "The best epoch is selected by the **macro F1 of the positive class**, not by recall\n", "(selecting on recall is what biased the first run toward over-predicting positive)."
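, "\n", "\n", "For reference, the loss as implemented in the training cell can be written as follows (a restatement of the code, with $B$ the batch size, $J = 14$ pathologies, $y_{ij}$ the target state of pathology $j$ for sample $i$, and $p_{ij}$ the softmax probability assigned to that target):\n", "\n", "$$\\gamma_{ij} = \\begin{cases} \\gamma_{\\text{pos}} & \\text{if } y_{ij} \\in \\{\\text{pos}, \\text{unc}\\} \\\\ \\gamma_{\\text{neg}} & \\text{if } y_{ij} = \\text{neg} \\end{cases} \\qquad \\mathcal{L} = \\frac{1}{J} \\sum_{j=1}^{J} \\frac{1}{B} \\sum_{i=1}^{B} -\\,(1 - p_{ij})^{\\gamma_{ij}} \\log p_{ij}$$\n", "\n", "With $\\gamma_{\\text{pos}} = 0$ the rare-class terms reduce to plain cross-entropy, $-\\log p_{ij}$. For a negative predicted with $p_{ij} = 0.9$, the focal factor is $(1 - 0.9)^4 = 10^{-4}$, so that easy negative contributes almost nothing to the gradient."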
] }, { "cell_type": "code", "id": "s0-tr", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Train the head with Asymmetric Focal CE (multi-class adaptation of ASL).\n", "import torch, torch.nn as nn, torch.nn.functional as F\n", "from model.chexpert_classifier import CheXpertClassifier\n", "\n", "# ── ASL hyperparams ──\n", "ASL_GAMMA_POS = 0.0   # focal exponent when target is POSITIVE or UNCERTAIN (rare → no decay)\n", "ASL_GAMMA_NEG = 4.0   # focal exponent when target is NEGATIVE (dominant → strong decay)\n", "                      # Push to 5–6 if rare pathologies (Pleural Other, EC) still\n", "                      # underperform after this run.\n", "\n", "def asymmetric_focal_ce(logits, target,\n", "                        gamma_pos=ASL_GAMMA_POS, gamma_neg=ASL_GAMMA_NEG):\n", "    \"\"\"\n", "    Multi-class asymmetric focal CE on a 3-class softmax head.\n", "      logits : (B, 3)\n", "      target : (B,) ∈ {0=neg, 1=pos, 2=unc}\n", "    POSITIVE and UNCERTAIN targets get γ_pos (no decay by default); NEGATIVE\n", "    targets get γ_neg. Uncertain is rare too, so it shares the positive\n", "    treatment (see is_rare below). Returns the per-sample loss, shape (B,).\n", "    \"\"\"\n", "    log_probs = F.log_softmax(logits, dim=-1)                        # (B, 3)\n", "    probs = log_probs.exp()\n", "    pt = probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)         # prob. of target class\n", "    log_pt = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)\n", "    # Rare classes (positive OR uncertain) → keep full signal\n", "    is_rare = (target != CLASS_NEGATIVE).float()\n", "    gamma = is_rare * gamma_pos + (1.0 - is_rare) * gamma_neg\n", "    return -((1.0 - pt) ** gamma) * log_pt                          # (B,)\n", "\n", "emb_dim = next(iter(train_ds))[0].numel()   # 768 for RAD-DINO CLS\n", "print('embedding dim:', emb_dim)\n", "\n", "clf = CheXpertClassifier(input_dim=emb_dim, num_classes=len(PATHOLOGIES)).to(device)\n", "opt = torch.optim.AdamW(clf.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=NUM_EPOCHS)\n", "\n", "def epoch_pass(dl, train=True):\n", "    clf.train(train)\n", "    total_loss = 0.0\n", "    n_total = 0\n", "    # 
Per-pathology tp/fp/fn for precision-recall-F1 of POSITIVE class\n", "    tp = [0]*len(PATHOLOGIES); fp = [0]*len(PATHOLOGIES); fn = [0]*len(PATHOLOGIES)\n", "    for x, y in dl:\n", "        x = x.to(device); y = y.to(device)\n", "        # Track gradients only when training; the validation pass runs grad-free.\n", "        with torch.set_grad_enabled(train):\n", "            logits = clf(x)                                   # (B, 14, 3)\n", "            # Loss: mean asymmetric focal CE over the 14 pathologies\n", "            loss = 0.0\n", "            for j in range(len(PATHOLOGIES)):\n", "                loss = loss + asymmetric_focal_ce(logits[:, j, :], y[:, j]).mean()\n", "            loss = loss / len(PATHOLOGIES)\n", "        if train:\n", "            opt.zero_grad(); loss.backward(); opt.step()\n", "        total_loss += loss.item() * x.size(0)\n", "        pred = logits.argmax(dim=-1)                          # (B, 14)\n", "        n_total += x.size(0)\n", "        for j in range(len(PATHOLOGIES)):\n", "            p = pred[:, j]; t = y[:, j]\n", "            tp[j] += ((p == CLASS_POSITIVE) & (t == CLASS_POSITIVE)).sum().item()\n", "            fp[j] += ((p == CLASS_POSITIVE) & (t != CLASS_POSITIVE)).sum().item()\n", "            fn[j] += ((p != CLASS_POSITIVE) & (t == CLASS_POSITIVE)).sum().item()\n", "    prec = [tp[j] / max(1, tp[j] + fp[j]) for j in range(len(PATHOLOGIES))]\n", "    rec  = [tp[j] / max(1, tp[j] + fn[j]) for j in range(len(PATHOLOGIES))]\n", "    f1   = [2*p*r/max(1e-9, p+r) for p, r in zip(prec, rec)]\n", "    return total_loss / max(1, n_total), prec, rec, f1\n", "\n", "best_macro_f1 = -1.0\n", "best_state = None\n", "for ep in range(1, NUM_EPOCHS + 1):\n", "    tl, _, _, _ = epoch_pass(train_dl, train=True)\n", "    vl, vprec, vrec, vf1 = epoch_pass(val_dl, train=False)\n", "    cur_lr = sched.get_last_lr()[0]   # LR used this epoch (read before stepping)\n", "    sched.step()\n", "    macro_p = sum(vprec) / len(vprec)\n", "    macro_r = sum(vrec) / len(vrec)\n", "    macro_f1 = sum(vf1) / len(vf1)\n", "    print(f'epoch {ep:2d}/{NUM_EPOCHS}  train_loss={tl:.4f}  val_loss={vl:.4f}  '\n", "          f'P={macro_p:.3f} R={macro_r:.3f} F1={macro_f1:.3f}  '\n", "          f'lr={cur_lr:.2e}')\n", "    if macro_f1 > best_macro_f1:\n", "        best_macro_f1 = macro_f1\n", "        best_state = {k: v.detach().cpu().clone() for k, v in clf.state_dict().items()}\n", "\n", "print(f'best val macro-F1 (positive class): {best_macro_f1:.3f}')\n", 
"clf.load_state_dict(best_state)" ] }, { "cell_type": "markdown", "id": "s0-ev-md", "metadata": {}, "source": [ "## 9. Per-pathology validation metrics" ] }, { "cell_type": "code", "id": "s0-ev", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Per-pathology metrics on val (uses best checkpoint loaded above).\n", "from collections import defaultdict\n", "clf.eval()\n", "pred_log = defaultdict(lambda: {'tp':0,'fp':0,'fn':0,'tn':0,'unc_correct':0,'unc_total':0})\n", "n_total = 0\n", "with torch.no_grad():\n", " for x, y in val_dl:\n", " x = x.to(device); y = y.to(device)\n", " pred = clf(x).argmax(dim=-1)\n", " n_total += x.size(0)\n", " for j, name in enumerate(PATHOLOGIES):\n", " p = pred[:, j]; t = y[:, j]\n", " d = pred_log[name]\n", " d['tp'] += ((p == CLASS_POSITIVE) & (t == CLASS_POSITIVE)).sum().item()\n", " d['fp'] += ((p == CLASS_POSITIVE) & (t != CLASS_POSITIVE)).sum().item()\n", " d['fn'] += ((p != CLASS_POSITIVE) & (t == CLASS_POSITIVE)).sum().item()\n", " d['tn'] += ((p == CLASS_NEGATIVE) & (t == CLASS_NEGATIVE)).sum().item()\n", " d['unc_total'] += (t == CLASS_UNCERTAIN).sum().item()\n", " d['unc_correct'] += ((p == CLASS_UNCERTAIN) & (t == CLASS_UNCERTAIN)).sum().item()\n", "\n", "print(f'{\"pathology\":<26s} prec rec F1 uncert-acc')\n", "print('-' * 70)\n", "f1s = []\n", "for name in PATHOLOGIES:\n", " d = pred_log[name]\n", " prec = d['tp'] / max(1, d['tp'] + d['fp'])\n", " rec = d['tp'] / max(1, d['tp'] + d['fn'])\n", " f1 = 2 * prec * rec / max(1e-9, prec + rec)\n", " uacc = d['unc_correct'] / max(1, d['unc_total'])\n", " f1s.append(f1)\n", " print(f'{name:<26s} {prec:.3f} {rec:.3f} {f1:.3f} {uacc:.3f}')\n", "print('-' * 70)\n", "print(f'macro-F1 (positive class): {sum(f1s)/len(f1s):.3f}')" ] }, { "cell_type": "markdown", "id": "s0-sv-md", "metadata": {}, "source": [ "## 10. 
Save + upload checkpoint" ] }, { "cell_type": "code", "id": "s0-sv", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Save head weights, upload to HF.\n", "import os\n", "from pathlib import Path\n", "from huggingface_hub import HfApi\n", "\n", "ckpt_path = Path(CKPT_OUT)\n", "ckpt_path.parent.mkdir(parents=True, exist_ok=True)\n", "torch.save(clf.state_dict(), ckpt_path)\n", "print(f'saved -> {ckpt_path} ({ckpt_path.stat().st_size/1e6:.2f} MB)')\n", "\n", "if DO_UPLOAD:\n", " api = HfApi(token=os.environ.get('HF_TOKEN'))\n", " path_in_repo = f'{HF_SUBDIR}/{ckpt_path.name}'\n", " api.upload_file(path_or_fileobj=str(ckpt_path), path_in_repo=path_in_repo,\n", " repo_id=HF_REPO_ID, repo_type=HF_REPO_TYPE)\n", " print(f'uploaded -> {HF_REPO_ID}/{path_in_repo}')\n", "\n", "print()\n", "print('=== Next step ===')\n", "print('In configs/model_config.yaml set:')\n", "print(f' chexpert_classifier.checkpoint: \"/{ckpt_path.name}\"')\n", "print('The downstream Stage 1/2 training will then build the classifier with these weights.')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10" } }, "nbformat": 4, "nbformat_minor": 5 }