{ "cells": [ { "cell_type": "markdown", "id": "pc-intro", "metadata": {}, "source": [ "# Precompute CXR-BERT text embeddings (SELF-CONTAINED)\n", "\n", "One-shot, runnable on its own (Colab / Kaggle / Lightning). It:\n", "1. pulls the project code + the chosen dataset from HF (same as the train notebook),\n", "2. locates the data and **builds the instruct JSON for the chosen mode via the same\n", " resolver the trainer uses** (so the text + ids are byte-identical to training),\n", "3. runs **microsoft/BiomedVLP-CXR-BERT-specialized** over the per-study report text,\n", "4. saves one **128-d L2-normalized** embedding per study and uploads it to\n", " `hieu3636/cxr-vlm-data/cxr_bert_text_embeddings/`.\n", "\n", "The cache feeds the Stage-1 ITC contrastive loss (`stage1.itc.enabled=true`). CXR-BERT\n", "never has to live in the training env → no version conflicts with Vicuna/PEFT/bnb.\n", "\n", "**Set `DATASET_NAME`, `REPORT_MODE`, `IMAGE_MODE` below to MATCH your training run.**" ] }, { "cell_type": "markdown", "id": "pc-sel-md", "metadata": {}, "source": [ "## 1. 
Selectors" ] }, { "cell_type": "code", "id": "pc-sel", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# ── Selectors — MUST match the training run ──\n", "PLATFORM = 'colab' # 'kaggle' | 'colab' | 'lightning' | 'gcp' | 'local'\n", "DATASET_NAME = 'IU-Xray' # 'MIMIC-CXR' | 'MIMIC-CXR_resized' | 'IU-Xray'\n", "REPORT_MODE = 'split_cascade' # 'split' | 'merged' | 'split_cascade'\n", "IMAGE_MODE = 'frontal_only_split' # 'all_views_split' | 'frontal_only_split' | 'multi_image_merged'\n", "\n", "assert PLATFORM in ('kaggle', 'colab', 'lightning', 'gcp', 'local')\n", "assert DATASET_NAME in ('MIMIC-CXR', 'MIMIC-CXR_resized', 'IU-Xray')\n", "assert REPORT_MODE in ('split', 'merged', 'split_cascade')\n", "assert IMAGE_MODE in ('all_views_split', 'frontal_only_split', 'multi_image_merged')\n", "\n", "# ── CXR-BERT / ITC settings ──\n", "MODEL_NAME = 'microsoft/BiomedVLP-CXR-BERT-specialized' # projected dim = 128\n", "MAX_LEN = 256\n", "BATCH_SIZE = 128\n", "FALLBACK_TO_IMPRESSION = True\n", "\n", "# ── HF upload target (dataset-based filename; mode-independent) ──\n", "HF_REPO_ID = 'hieu3636/cxr-vlm-data'\n", "HF_REPO_TYPE = 'dataset'\n", "HF_SUBDIR = 'cxr_bert_text_embeddings'\n", "DO_UPLOAD = True\n", "_CACHE_NAME = {\n", " 'IU-Xray': 'cxrbert_text_embeds_iu_xray.pt',\n", " 'MIMIC-CXR_resized': 'cxrbert_text_embeds_mimic_resized.pt',\n", " 'MIMIC-CXR': 'cxrbert_text_embeds_mimic.pt',\n", "}\n", "print(f'{DATASET_NAME} | {REPORT_MODE} | {IMAGE_MODE}')" ] }, { "cell_type": "markdown", "id": "pc-inst-md", "metadata": {}, "source": [ "## 2. Install deps" ] }, { "cell_type": "code", "id": "pc-inst", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Precompute env: needs transformers (for the frozen-encoder *class* import in\n", "# the builder chain) + huggingface_hub (modern, has .errors submodule).\n", "# - Uninstall peft: Colab preinstalls a newer peft that imports\n", "# huggingface_hub.errors. 
With the old hub on the Colab base image, importing peft\n", "# fails with ModuleNotFoundError. We don't use peft here (no Vicuna/LoRA in\n", "# precompute), so just remove it.\n", "# - Pin huggingface_hub>=0.24 (has .errors); transformers==4.35 (matches repo).\n", "!pip uninstall -y -q peft\n", "!pip install -q \"transformers==4.35.0\" \"huggingface_hub>=0.24.0\" omegaconf pillow tqdm einops sentencepiece" ] }, { "cell_type": "markdown", "id": "pc-env-md", "metadata": {}, "source": [ "## 3. Environment + pull code & data (from the train notebook)" ] }, { "cell_type": "code", "id": "cell-env", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0' # single-GPU\n", "os.environ['TOKENIZERS_PARALLELISM'] = 'false' # silence HF tokenizers fork warning\n", "os.environ['BITSANDBYTES_NOWELCOME'] = '1'\n", "os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # kill per-shard download bars\n", "os.environ['TRANSFORMERS_VERBOSITY'] = 'warning'\n", "os.environ['PYTHONUNBUFFERED'] = '1'\n", "\n", "import sys, shutil, subprocess\n", "from pathlib import Path\n" ] }, { "cell_type": "code", "id": "cell-paths", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# ── Per-platform storage + source-of-truth ─────────────────────────\n", "# All platforms (kaggle / colab / lightning / gcp / local) pull code +\n", "# data from HF Hub. 
The only platform-specific bit is:\n", "# * WORK : where to land outputs (persisted dirs differ per host)\n", "# * TOKEN : how HF_TOKEN reaches os.environ (secrets API differs)\n", "#\n", "# Required HF repos:\n", "# /cxr-vlm-code — project source (flat folder)\n", "# /cxr-vlm-data — per-dataset payloads:\n", "#     MIMIC-CXR_resized/ (tar shards + manifests + vqa)\n", "#     MIMIC-CXR.zip (single zip)\n", "#     IU-Xray.zip (single zip)\n", "\n", "HF_USER = 'hieu3636'  # <<< EDIT ME\n", "\n", "# ── 1) WORK dir + HF_TOKEN bootstrap (platform-specific) ───────────\n", "if PLATFORM == 'kaggle':\n", "    from kaggle_secrets import UserSecretsClient\n", "    os.environ['HF_TOKEN'] = UserSecretsClient().get_secret('HF_TOKEN')\n", "    WORK = Path('/kaggle/working')\n", "elif PLATFORM == 'colab':\n", "    from google.colab import userdata\n", "    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n", "    WORK = Path('/content')\n", "elif PLATFORM == 'lightning':\n", "    WORK = Path('/teamspace/studios/this_studio')\n", "elif PLATFORM == 'gcp':\n", "    WORK = Path('/workspace')\n", "else:  # 'local'\n", "    WORK = Path.home() / 'cxr-vlm-work'\n", "WORK.mkdir(parents=True, exist_ok=True)\n", "\n", "assert os.environ.get('HF_TOKEN'), \\\n", "    'HF_TOKEN missing — set it via the platform secrets UI before re-running.'\n", "\n", "try:\n", "    from huggingface_hub import snapshot_download, hf_hub_download, HfApi\n", "except ImportError:\n", "    !pip install -q huggingface_hub\n", "    from huggingface_hub import snapshot_download, hf_hub_download, HfApi\n", "\n", "# ── 2) Code: flat folder, few hundred files → snapshot_download ──\n", "print(f'Pulling code from HF (user: {HF_USER}) …')\n", "CODE_SRC = Path(snapshot_download(\n", "    repo_id = f'{HF_USER}/cxr-vlm-code',\n", "    repo_type = 'model',\n", "    token = os.environ['HF_TOKEN'],\n", "    local_dir = str(WORK / 'cxr-vlm-code'),\n", "))\n", "\n", "# ── 3) Data: layout depends on DATASET_NAME ──\n", "DATA_SRC = WORK / 'data'\n", "DATA_SRC.mkdir(parents=True, 
exist_ok=True)\n", "\n", "if DATASET_NAME == 'MIMIC-CXR_resized':\n", "    # Tar-sharded payload. Reports + images live INSIDE the tars under\n", "    # `files/pXX/pXXXX/{sYYYY/*.jpg, sYYYY.txt}` so extracting all shards\n", "    # gives one unified tree. We download manifests + vqa + SHARDS.txt\n", "    # first (small, ~tens of MB), then each *.tar one at a time →\n", "    # extract → delete (saves disk).\n", "    # Final on-disk layout:\n", "    #   DATA_SRC/MIMIC-CXR_resized/\n", "    #   ├── manifest_{train,val,test}.csv\n", "    #   ├── vqa/ {vqa.json, vqa_val.json, vqa_test.json}\n", "    #   ├── SHARDS.txt + _manifest.json\n", "    #   └── files/pXX/pXXXX/ ← from tars\n", "    #       ├── sYYYY.txt (report)\n", "    #       └── sYYYY/.jpg (images)\n", "    import tarfile\n", "    mr_dir = DATA_SRC / 'MIMIC-CXR_resized'\n", "    mr_dir.mkdir(parents=True, exist_ok=True)\n", "    files_dir = mr_dir / 'files'\n", "\n", "    # Marker: if files/ already has shards extracted AND manifests exist,\n", "    # skip everything. Lets the cell be re-run safely.\n", "    manifests_present = all(\n", "        (mr_dir / f).is_file() for f in ('manifest_train.csv', 'manifest_val.csv', 'manifest_test.csv')\n", "    )\n", "    if manifests_present and files_dir.is_dir() and any(files_dir.glob('p*')):\n", "        print(f'{mr_dir} already populated — skipping download.')\n", "    else:\n", "        api = HfApi(token=os.environ['HF_TOKEN'])\n", "        all_files = api.list_repo_files(\n", "            repo_id=f'{HF_USER}/cxr-vlm-data', repo_type='dataset')\n", "        mr_files = [f for f in all_files if f.startswith('MIMIC-CXR_resized/')]\n", "        tar_files = sorted(f for f in mr_files if f.endswith('.tar'))\n", "        meta_files = [f for f in mr_files if not f.endswith('.tar')]\n", "        print(f'MIMIC-CXR_resized on HF: {len(tar_files)} tar shards + {len(meta_files)} metadata files')\n", "\n", "        # 3a) Pull metadata (manifests, vqa, SHARDS.txt, _manifest.json)\n", "        # in one snapshot (small; few MB).\n", "        print(f' downloading manifests + vqa + SHARDS.txt …')\n", "        snapshot_download(\n", "            repo_id = f'{HF_USER}/cxr-vlm-data',\n", "            repo_type = 'dataset',\n", "            allow_patterns = ['MIMIC-CXR_resized/*.csv',\n", "                              'MIMIC-CXR_resized/*.json',\n", "                              'MIMIC-CXR_resized/*.txt',\n", "                              'MIMIC-CXR_resized/vqa/**'],\n", "            token = os.environ['HF_TOKEN'],\n", "            local_dir = str(DATA_SRC),\n", "        )\n", "\n", "        # 3b) Sequentially fetch + extract + delete each image tar to\n", "        # minimise peak disk usage (each shard ~2 GB). Reports come\n", "        # out alongside images — both land under mr_dir/files/.\n", "        print(f' downloading + extracting {len(tar_files)} tar shards …')\n", "        for i, tf in enumerate(tar_files, 1):\n", "            print(f' [{i}/{len(tar_files)}] {tf}')\n", "            tar_path = Path(hf_hub_download(\n", "                repo_id=f'{HF_USER}/cxr-vlm-data', repo_type='dataset',\n", "                filename=tf, token=os.environ['HF_TOKEN'],\n", "                local_dir=str(DATA_SRC),\n", "            ))\n", "            with tarfile.open(tar_path) as t:\n", "                # Extract into mr_dir so member paths like\n", "                # \"files/p10/.../*.jpg\" + \"files/p10/.../*.txt\" land at\n", "                # mr_dir/files/p10/…\n", "                t.extractall(mr_dir)\n", "            tar_path.unlink(missing_ok=True)\n", "        print(f' done. {mr_dir} ready.')\n", "\n", "else:\n", "    # MIMIC-CXR / IU-Xray: single zip per dataset (legacy path)\n", "    import zipfile\n", "    zip_name = f'{DATASET_NAME}.zip'  # 'IU-Xray.zip' | 'MIMIC-CXR.zip'\n", "    marker = DATA_SRC / DATASET_NAME  # DATA_SRC/IU-Xray after unzip\n", "\n", "    if not marker.exists():\n", "        print(f'Pulling {zip_name} from HF …')\n", "        zpath = hf_hub_download(\n", "            repo_id = f'{HF_USER}/cxr-vlm-data',\n", "            filename = zip_name,\n", "            repo_type = 'dataset',\n", "            token = os.environ['HF_TOKEN'],\n", "            local_dir = str(DATA_SRC),\n", "        )\n", "        print(f' unzipping → {DATA_SRC}')\n", "        with zipfile.ZipFile(zpath) as zf:\n", "            zf.extractall(DATA_SRC)\n", "        try:\n", "            os.remove(zpath)  # free disk\n", "        except OSError:\n", "            pass\n", "    else:\n", "        print(f'{marker} already present — skipping download.')\n", "\n", "print(f'Contents of {DATA_SRC}: {sorted(os.listdir(DATA_SRC))}')\n", "\n", "# ── Common: copy code into writable PROJECT dir ────────────────────\n", "PROJECT = WORK / 'cxr_vlm'\n", "if CODE_SRC.resolve() != PROJECT.resolve() and not PROJECT.exists():\n", "    shutil.copytree(CODE_SRC, PROJECT)\n", "\n", "os.chdir(PROJECT)\n", "sys.path.insert(0, str(PROJECT))\n", "print('PLATFORM :', PLATFORM)\n", "print('CODE_SRC :', CODE_SRC)\n", "print('DATA_SRC :', DATA_SRC)\n", "print('PROJECT :', PROJECT)\n", "print('WORK :', WORK)" ] }, { "cell_type": "markdown", "id": "pc-loc-md", "metadata": {}, "source": [ "## 4. 
Locate data" ] }, { "cell_type": "code", "id": "cell-find-data-mimic", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "def find_split_parent(root: Path) -> Path:\n", "    for cand in [root, root / 'MIMIC-CXR', root / 'data' / 'MIMIC-CXR']:\n", "        if (cand / 'train').exists() and (cand / 'valid').exists() and (cand / 'test').exists():\n", "            return cand\n", "    for p in root.rglob('train'):\n", "        if p.is_dir() and (p.parent / 'valid').exists() and (p.parent / 'test').exists():\n", "            return p.parent\n", "    raise FileNotFoundError('Could not find train/ valid/ test/ under ' + str(root))\n", "\n", "\n", "def find_mimic_resized_root(root: Path) -> Path:\n", "    \"\"\"Find the MIMIC-CXR_resized payload — folder with manifest_*.csv + files/.\"\"\"\n", "    for cand in [root / 'MIMIC-CXR_resized', root, *root.rglob('MIMIC-CXR_resized')]:\n", "        if (cand / 'manifest_train.csv').is_file():\n", "            return cand\n", "    raise FileNotFoundError(\n", "        f'Could not find MIMIC-CXR_resized payload under {root}. '\n", "        f'Expected manifest_train.csv (alongside manifest_val.csv / manifest_test.csv).'\n", "    )\n", "\n", "\n", "def find_iu_dirs(root: Path):\n", "    \"\"\"Locate IU-Xray `images/` and `labels/` (flat XMLs) under `root`.\n", "\n", "    Resolution order:\n", "      1. `{root}/IU-Xray/{images,labels}` — canonical layout.\n", "      2. Any nested `IU-Xray` folder that contains both.\n", "      3. Fallback: any folder containing CXR*.png (images) and\n", "         any folder containing *.xml — whichever comes first.\n", "\n", "    The labels subfolder is treated as a flat directory of XMLs (we no\n", "    longer require the legacy `ecgen-radiology/` subfolder).\n", "    \"\"\"\n", "    # Canonical + nested\n", "    for cand in [root / 'IU-Xray', *root.rglob('IU-Xray')]:\n", "        if not cand.is_dir():\n", "            continue\n", "        imgs = cand / 'images'\n", "        lbls = cand / 'labels'\n", "        if imgs.is_dir() and lbls.is_dir() and any(lbls.glob('*.xml')):\n", "            return imgs, lbls\n", "        # Legacy: labels/ecgen-radiology/*.xml\n", "        legacy = lbls / 'ecgen-radiology'\n", "        if imgs.is_dir() and legacy.is_dir() and any(legacy.glob('*.xml')):\n", "            return imgs, legacy\n", "\n", "    # Fallback: any images/ with CXR*.png + any folder with XML\n", "    img_dir = lbl_dir = None\n", "    for cand in [root / 'images', *root.rglob('images')]:\n", "        if cand.is_dir() and any(cand.glob('CXR*.png')):\n", "            img_dir = cand; break\n", "    for cand in [root / 'labels', *root.rglob('labels')]:\n", "        if cand.is_dir() and any(cand.glob('*.xml')):\n", "            lbl_dir = cand; break\n", "    if lbl_dir is None:\n", "        # very last resort — any ecgen-radiology folder with XMLs\n", "        for cand in root.rglob('ecgen-radiology'):\n", "            if cand.is_dir() and any(cand.glob('*.xml')):\n", "                lbl_dir = cand; break\n", "    return img_dir, lbl_dir\n", "\n", "\n", "# Filled in below depending on DATASET_NAME\n", "CXR_ROOT = None  # MIMIC-CXR root (with train/valid/test subdirs)\n", "SPLIT_DIRS = None  # MIMIC only\n", "VQA_ROOT = None  # MIMIC only\n", "MR_ROOT = None  # MIMIC-CXR_resized root (manifests + files/ + vqa/)\n", "IU_IMAGES_DIR = None  # IU-Xray only\n", "IU_LABELS_DIR = None  # IU-Xray only\n", "\n", "if DATASET_NAME == 'MIMIC-CXR':\n", "    CXR_ROOT = find_split_parent(DATA_SRC)\n", "    print('MIMIC-CXR root:', CXR_ROOT)\n", "\n", "    SPLIT_DIRS = {\n", "        'train'   : ('train', CXR_ROOT / 'train'),\n", "        'validate': ('valid', CXR_ROOT / 'valid'),\n", "        'test'    : ('test', CXR_ROOT / 'test'),\n", "    }\n", "    for s, (sub, d) in SPLIT_DIRS.items():\n", "        assert d.exists(), f'Missing split dir: {d}'\n", "        print(f' {s:<9s} → {d}')\n", "\n", "    for p in DATA_SRC.rglob('MIMIC-Ext-MIMIC-CXR-VQA'):\n", "        cand = p / 'dataset'\n", "        if cand.exists() and (cand / 'train.json').exists():\n", "            VQA_ROOT = cand\n", "            break\n", "    assert VQA_ROOT is not None, 'VQA dataset folder not found under ' + str(DATA_SRC)\n", "    print('VQA root:', VQA_ROOT)\n", "\n", "elif DATASET_NAME == 'MIMIC-CXR_resized':\n", "    MR_ROOT = find_mimic_resized_root(DATA_SRC)\n", "    print('MIMIC-CXR_resized root:', MR_ROOT)\n", "    # Sanity: 3 manifest CSVs, files/ (images+reports), vqa/\n", "    for cf in ('manifest_train.csv', 'manifest_val.csv', 'manifest_test.csv'):\n", "        f = MR_ROOT / cf\n", "        print(f' {cf}: {\"OK\" if f.is_file() else \"MISSING\"}')\n", "    for sub in ('files', 'vqa'):\n", "        d = MR_ROOT / sub\n", "        print(f' {sub:<5s}: {\"OK\" if d.is_dir() else \"MISSING\"} ({d})')\n", "    # Spot-check one report (.txt) sits at patient-dir level inside files/\n", "    txt_hits = list((MR_ROOT / 'files').glob('p*/p*/s*.txt')) if (MR_ROOT / 'files').is_dir() else []\n", "    print(f' reports inside files/ : {len(txt_hits):,} found (sample: {txt_hits[0] if txt_hits else \"—\"})')\n", "\n", "else:  # IU-Xray\n", "    IU_IMAGES_DIR, IU_LABELS_DIR = find_iu_dirs(DATA_SRC)\n", "    assert IU_IMAGES_DIR is not None, f'IU images/ not found under {DATA_SRC}'\n", "    assert IU_LABELS_DIR is not None, f'IU labels/ (with *.xml) not found under {DATA_SRC}'\n", "    print('IU images dir:', IU_IMAGES_DIR, '→', len(list(IU_IMAGES_DIR.glob('*.png'))), 'PNGs')\n", "    print('IU labels dir:', IU_LABELS_DIR, '→', len(list(IU_LABELS_DIR.glob('*.xml'))), 'XMLs')" ] }, { "cell_type": "markdown", "id": "pc-build-md", "metadata": {}, "source": [ "## 5. 
Build the instruct JSON (same resolver as training)" ] }, { "cell_type": "code", "id": "pc-build", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# ── Build the instruct JSON for the chosen mode (same resolver as the trainer) ──\n", "from omegaconf import OmegaConf\n", "from utils.dataset_resolver import resolve_dataset_spec\n", "\n", "cfg = OmegaConf.load(PROJECT / 'configs' / 'train_config.yaml')\n", "cfg.data.dataset_name = DATASET_NAME\n", "cfg.data.report_mode = REPORT_MODE\n", "cfg.data.image_mode = IMAGE_MODE\n", "\n", "if DATASET_NAME == 'IU-Xray':\n", "    cfg.data.iu_xray.images_dir = str(IU_IMAGES_DIR)\n", "    cfg.data.iu_xray.labels_dir = str(IU_LABELS_DIR)\n", "    cfg.data.iu_xray.instruct_json = str(PROJECT / 'data/data_files/iu_xray_instruct.json')\n", "    cfg.data.iu_xray.auto_build = True\n", "elif DATASET_NAME == 'MIMIC-CXR_resized':\n", "    cfg.data.mimic_cxr_resized.root = str(MR_ROOT)\n", "    cfg.data.mimic_cxr_resized.manifest_dir = None\n", "    cfg.data.mimic_cxr_resized.vqa_dir = None\n", "    cfg.data.mimic_cxr_resized.reports_root = None\n", "    cfg.data.mimic_cxr_resized.instruct_json = str(PROJECT / 'data/data_files/mimic_cxr_resized_instruct.json')\n", "    cfg.data.mimic_cxr_resized.auto_build = True\n", "else:  # MIMIC-CXR\n", "    cfg.data.mimic_cxr_root = str(CXR_ROOT)\n", "    cfg.data.instruct_json = str(PROJECT / 'data/data_files/mimic_cxr_instruct_unified.json')\n", "    cfg.data.mimic_auto_build = True\n", "    _cx = (sorted(DATA_SRC.rglob('*chexpert*.csv')) or sorted(DATA_SRC.rglob('*chexbert*.csv')))\n", "    cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None\n", "    cfg.data.mimic_vqa_root = str(VQA_ROOT) if VQA_ROOT is not None else None\n", "\n", "spec = resolve_dataset_spec(cfg)  # builds the suffixed JSON if missing\n", "INSTRUCT_JSON = spec.instruct_json\n", "OUT_PT = str(PROJECT / 'data/data_files' / _CACHE_NAME[DATASET_NAME])\n", "import os\n", "print('INSTRUCT_JSON ->', INSTRUCT_JSON, ' exists:', 
os.path.exists(INSTRUCT_JSON))\n", "print('OUT_PT ->', OUT_PT)" ] }, { "cell_type": "markdown", "id": "pc-collect-md", "metadata": {}, "source": [ "## 6. Collect per-study canonical text" ] }, { "cell_type": "code", "id": "pc-collect", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "import json\n", "\n", "with open(INSTRUCT_JSON, 'r', encoding='utf-8') as f:\n", "    samples = json.load(f)\n", "print(f'loaded {len(samples):,} instruct samples')\n", "\n", "# key -> {findings, impression, report}; key = study_id (MIMIC) else image_path\n", "# (IU-Xray). MUST match the ITC lookup key in data/dataset.py.\n", "per_study = {}\n", "for s in samples:\n", "    sid = s.get('study_id') or s.get('image_path')\n", "    if not sid:\n", "        continue\n", "    tgt = (s.get('target') or '').strip()\n", "    if not tgt:\n", "        continue\n", "    d = per_study.setdefault(sid, {})\n", "    task = s.get('task')\n", "    if task in ('findings', 'impression', 'report'):\n", "        d.setdefault(task, tgt)\n", "\n", "def _canonical(d):\n", "    if d.get('findings'):\n", "        return d['findings']\n", "    if FALLBACK_TO_IMPRESSION and d.get('impression'):\n", "        return d['impression']\n", "    return d.get('report')\n", "\n", "study_text = {k: _canonical(v) for k, v in per_study.items()}\n", "study_text = {k: v for k, v in study_text.items() if v}\n", "print(f'keys with usable text: {len(study_text):,} / {len(per_study):,}')\n", "next(iter(study_text.items()))" ] }, { "cell_type": "markdown", "id": "pc-enc-md", "metadata": {}, "source": [ "## 7. 
Encode with CXR-BERT (projected 128-d, L2-norm)" ] }, { "cell_type": "code", "id": "pc-enc", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "from transformers import AutoModel, AutoTokenizer\n", "from tqdm.auto import tqdm\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n", "mdl = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True).eval().to(device)\n", "\n", "sids = list(study_text.keys())\n", "texts = [study_text[s] for s in sids]\n", "\n", "@torch.no_grad()\n", "def encode(batch_texts):\n", "    enc = tok(batch_texts, padding=True, truncation=True,\n", "              max_length=MAX_LEN, return_tensors='pt').to(device)\n", "    emb = mdl.get_projected_text_embeddings(input_ids=enc.input_ids,\n", "                                            attention_mask=enc.attention_mask)\n", "    return F.normalize(emb, dim=-1).cpu()\n", "\n", "embeds = {}\n", "for i in tqdm(range(0, len(texts), BATCH_SIZE)):\n", "    chunk = sids[i:i + BATCH_SIZE]\n", "    out = encode(texts[i:i + BATCH_SIZE])\n", "    for sid, v in zip(chunk, out):\n", "        embeds[sid] = v.clone()\n", "\n", "proj_dim = next(iter(embeds.values())).shape[0]\n", "print(f'encoded {len(embeds):,} keys; proj_dim = {proj_dim}')\n", "assert proj_dim == 128, f'expected 128-d, got {proj_dim}'" ] }, { "cell_type": "markdown", "id": "pc-save-md", "metadata": {}, "source": [ "## 8. 
Save cache" ] }, { "cell_type": "code", "id": "pc-save", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "from pathlib import Path\n", "Path(OUT_PT).parent.mkdir(parents=True, exist_ok=True)\n", "torch.save({\n", "    'embeds': embeds,\n", "    'meta': {\n", "        'dataset': DATASET_NAME, 'report_mode': REPORT_MODE, 'image_mode': IMAGE_MODE,\n", "        'model': MODEL_NAME, 'proj_dim': proj_dim, 'source_json': INSTRUCT_JSON,\n", "        'max_len': MAX_LEN, 'fallback_to_impression': FALLBACK_TO_IMPRESSION,\n", "        'n_studies': len(embeds),\n", "    },\n", "}, OUT_PT)\n", "print(f'saved -> {OUT_PT} ({Path(OUT_PT).stat().st_size/1e6:.1f} MB)')" ] }, { "cell_type": "markdown", "id": "pc-up-md", "metadata": {}, "source": [ "## 9. Upload to hieu3636/cxr-vlm-data" ] }, { "cell_type": "code", "id": "pc-up", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "if DO_UPLOAD:\n", "    from huggingface_hub import HfApi\n", "    api = HfApi(token=os.environ.get('HF_TOKEN'))\n", "    path_in_repo = f'{HF_SUBDIR}/{Path(OUT_PT).name}'\n", "    api.upload_file(path_or_fileobj=OUT_PT, path_in_repo=path_in_repo,\n", "                    repo_id=HF_REPO_ID, repo_type=HF_REPO_TYPE)\n", "    print(f'uploaded -> {HF_REPO_ID}/{path_in_repo}')\n", "else:\n", "    print('DO_UPLOAD=False — skipped')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10" } }, "nbformat": 4, "nbformat_minor": 5 }