{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5a6c76f2",
   "metadata": {},
   "source": [
    "## Chest2Vec\n",
    "\n",
    "Embed an instruction + query and a few candidate report sentences with `Chest2Vec`, then rank the candidates by cosine similarity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26215417",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Configure the environment BEFORE importing chest2vec: Hugging Face libraries may\n",
    "# resolve HF_HOME when they are first imported, so setting it afterwards can be ignored.\n",
    "# setdefault lets an HF_HOME exported in the shell take precedence over this default.\n",
    "os.environ.setdefault(\"HF_HOME\", \"/model/huggingface\")\n",
    "# Expose only physical GPU 2; inside this process it is addressed as \"cuda:0\".\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "from chest2vec import Chest2Vec\n",
    "\n",
    "m = Chest2Vec.from_pretrained(\"chest2vec/chest2vec_0.6b_cxr\", device=\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "624ad061",
   "metadata": {},
   "outputs": [],
   "source": [
    "instructions = [\"Find findings about the lungs.\"]\n",
    "queries = [\"Consolidation in the right lower lobe.\"]\n",
    "\n",
    "out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)\n",
    "\n",
    "# Global embedding (derived): mean of 9 section vectors then L2-normalized\n",
    "g = out.global_embedding  # [N, H]\n",
    "\n",
    "# Per-section embeddings (by full name)\n",
    "lung = out.by_section_name[\"Lungs and Airways\"]  # [N, H]\n",
    "imp = out.by_section_name[\"impression\"]  # [N, H]\n",
    "\n",
    "# Or use aliases (case-insensitive)\n",
    "lung = out.by_alias[\"lungs\"]  # [N, H]\n",
    "cardio = out.by_alias[\"cardio\"]  # [N, H]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b083b9a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "candidates = [\n",
    "    \"Lungs are clear. No focal consolidation.\",\n",
    "    \"Pleural effusion on the left.\",\n",
    "    \"Right lower lobe consolidation.\",\n",
    "    \"Cardiomediastinal silhouette is normal.\"\n",
    "]\n",
    "\n",
    "cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)\n",
    "\n",
    "cand_global = cand_out.global_embedding  # [N, H]\n",
    "cand_lung = cand_out.by_alias[\"lungs\"]  # [N, H]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "98ebf6d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top-4 scores: tensor([ 0.3646, -0.0407, -0.0810, -0.1504])\n",
      "Top-4 indices: tensor([2, 1, 3, 0])\n"
     ]
    }
   ],
   "source": [
    "# Query embeddings: global (section-averaged) vectors, one row per query.\n",
    "# For section-specific retrieval use out.by_alias[\"lungs\"] here and cand_lung below.\n",
    "q = out.global_embedding  # [Nq, H]\n",
    "\n",
    "# Candidate embeddings: global vectors, one row per candidate\n",
    "d = cand_out.global_embedding  # [Nd, H]\n",
    "\n",
    "# Never ask for more neighbours than there are candidates\n",
    "top_k = min(5, d.shape[0])\n",
    "\n",
    "# Compute top-k cosine similarities\n",
    "scores, idx = Chest2Vec.cosine_topk(q, d, k=top_k, device=\"cuda\")\n",
    "# scores: [Nq, k] - similarity scores\n",
    "# idx: [Nq, k] - indices of top-k candidates\n",
    "\n",
    "print(f\"Top-{top_k} scores: {scores[0]}\")\n",
    "print(f\"Top-{top_k} indices: {idx[0]}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "906d89b8",
   "metadata": {},
   "source": [
    "## CT\n",
    "\n",
    "Load a preprocessed CT case from `.npz`, inspect its ReX and TotalSeg masks, and plot axial slices with mask-edge overlays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "347dd738",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install nibabel monai   (space-separated; a trailing comma makes pip reject the name)\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Optional (for nicer overlays). If scipy isn't installed, code will fall back.\n",
    "try:\n",
    "    from scipy.ndimage import binary_erosion\n",
    "    _HAS_SCIPY = True\n",
    "except ImportError:\n",
    "    _HAS_SCIPY = False\n",
    "\n",
    "# Point this to your preprocessed folder\n",
    "NPZ_ROOT = Path(\"./data/preprocessed\")  # <-- EDIT\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edbddaf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_npz_case(npz_path: Path):\n",
    "    \"\"\"Load one preprocessed case from an .npz archive.\n",
    "\n",
    "    Array keys: CT from \"ct\" (fallback \"image\"); optional ReX masks from \"rex\";\n",
    "    optional TotalSeg labels from \"totalseg\" (fallback \"label\").\n",
    "\n",
    "    Returns:\n",
    "        (ct, rex, tot, keys): ct is (1, D, H, W); rex is (F, D, H, W) or None;\n",
    "        tot is (1, D, H, W) or None; keys lists every array name in the file.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if neither \"ct\" nor \"image\" is present.\n",
    "        AssertionError: if the arrays do not have the expected shapes.\n",
    "    \"\"\"\n",
    "    npz_path = Path(npz_path)\n",
    "    # allow_pickle=False: refuse object arrays, which would require unpickling file data.\n",
    "    with np.load(npz_path, allow_pickle=False) as z:\n",
    "        keys = list(z.keys())\n",
    "\n",
    "        # Support a few common key names\n",
    "        if \"ct\" in keys:\n",
    "            ct = z[\"ct\"]\n",
    "        elif \"image\" in keys:\n",
    "            ct = z[\"image\"]\n",
    "        else:\n",
    "            raise KeyError(f\"No CT key found. Available keys: {keys}\")\n",
    "\n",
    "        rex = z[\"rex\"] if \"rex\" in keys else None\n",
    "        tot = z[\"totalseg\"] if \"totalseg\" in keys else (z[\"label\"] if \"label\" in keys else None)\n",
    "\n",
    "    # Basic sanity checks (indexing the NpzFile above already loaded the arrays into memory)\n",
    "    assert ct.ndim == 4 and ct.shape[0] == 1, f\"Expected ct shape (1,D,H,W), got {ct.shape}\"\n",
    "    D, H, W = ct.shape[1], ct.shape[2], ct.shape[3]\n",
    "\n",
    "    if tot is not None:\n",
    "        assert tot.ndim == 4 and tot.shape[0] == 1, f\"Expected totalseg shape (1,D,H,W), got {tot.shape}\"\n",
    "        assert tot.shape[1:] == (D, H, W), f\"totalseg spatial mismatch: {tot.shape} vs ct {ct.shape}\"\n",
    "\n",
    "    if rex is not None:\n",
    "        assert rex.ndim == 4, f\"Expected rex shape (F,D,H,W), got {rex.shape}\"\n",
    "        assert rex.shape[1:] == (D, H, W), f\"rex spatial mismatch: {rex.shape} vs ct {ct.shape}\"\n",
    "\n",
    "    return ct, rex, tot, keys\n",
    "\n",
    "# List files\n",
    "npz_files = sorted(NPZ_ROOT.rglob(\"*.npz\"))\n",
    "print(\"Found npz files:\", len(npz_files))\n",
    "if not npz_files:\n",
    "    # Fail with an actionable message instead of an IndexError on npz_files[0] below.\n",
    "    raise FileNotFoundError(f\"No .npz files found under {NPZ_ROOT.resolve()}; edit NPZ_ROOT in the cell above.\")\n",
    "print(\"Example:\", npz_files[0])\n",
    "\n",
    "# Pick one (edit index or set by name)\n",
    "case_path = npz_files[0]  # <-- change to inspect a specific file\n",
    "ct, rex, tot, keys = load_npz_case(case_path)\n",
    "\n",
    "print(\"Loaded:\", case_path.name)\n",
    "print(\"Keys:\", keys)\n",
    "print(\"CT:\", ct.shape, ct.dtype, f\"min={ct.min():.3f}, max={ct.max():.3f}\")\n",
    "print(\"Rex:\", None if rex is None else (rex.shape, rex.dtype, f\"channels={rex.shape[0]}\"))\n",
    "print(\"Tot:\", None if tot is None else (tot.shape, tot.dtype))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d10b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def choose_rex_channel(rex_arr: np.ndarray):\n",
    "    \"\"\"Pick the ReX channel with the most positive voxels.\n",
    "\n",
    "    rex_arr: (F, D, H, W) array, or None.\n",
    "    Returns (best_channel_index, counts_per_channel), where counts[f] is the\n",
    "    number of voxels > 0 in channel f; (None, None) when rex_arr is None.\n",
    "    Ties resolve to the lowest channel index (np.argmax returns the first maximum).\n",
    "    \"\"\"\n",
    "    if rex_arr is None:\n",
    "        return None, None\n",
    "    counts = (rex_arr > 0).reshape(rex_arr.shape[0], -1).sum(axis=1)\n",
    "    best = int(np.argmax(counts))\n",
    "    return best, counts\n",
    "\n",
    "rex_ch, rex_counts = choose_rex_channel(rex)\n",
    "if rex is not None:\n",
    "    print(\"Top 10 ReX channels by voxel count:\")\n",
    "    top = np.argsort(-rex_counts)[:10]\n",
    "    for i in top:\n",
    "        print(f\"  ch={int(i):4d}  voxels={int(rex_counts[i])}\")\n",
    "    print(\"Auto-selected channel:\", rex_ch)\n",
    "else:\n",
    "    print(\"No ReX mask in this NPZ.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c75853f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_edges_2d(m2d: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Return a thin outline of a 2D mask as uint8 (mask XOR its binary erosion).\n",
    "\n",
    "    Without scipy this falls back to returning the filled mask.\n",
    "    \"\"\"\n",
    "    m2d = (m2d > 0)\n",
    "    if not _HAS_SCIPY:\n",
    "        return m2d.astype(np.uint8)  # fallback: filled mask\n",
    "    er = binary_erosion(m2d)\n",
    "    return (m2d ^ er).astype(np.uint8)\n",
    "\n",
    "def top_slices_by_area(mask_3d: np.ndarray, topk: int = 8):\n",
    "    \"\"\"\n",
    "    mask_3d: (D,H,W) boolean/int\n",
    "    returns (indices, areas):\n",
    "      indices - up to `topk` axial slice indices with non-empty mask, largest area first\n",
    "      areas   - per-slice mask voxel count, shape (D,)\n",
    "    \"\"\"\n",
    "    areas = (mask_3d > 0).sum(axis=(1, 2))\n",
    "    idx = np.argsort(-areas)[:topk]\n",
    "    return [int(i) for i in idx if areas[i] > 0], areas\n",
    "\n",
    "# Build binary masks for display\n",
    "ct_vol = ct[0]  # (D,H,W)\n",
    "rex_mask = None\n",
    "if rex is not None:\n",
    "    rex_mask = (rex[rex_ch] > 0)  # (D,H,W)\n",
    "\n",
    "tot_mask = None\n",
    "if tot is not None:\n",
    "    tot_mask = (tot[0] > 0)  # (D,H,W)\n",
    "\n",
    "if rex_mask is not None:\n",
    "    idxs, areas = top_slices_by_area(rex_mask, topk=10)\n",
    "    print(\"Top axial slices by ReX area:\", idxs[:10])\n",
    "else:\n",
    "    print(\"No ReX mask to suggest slices.\")\n",
    "\n",
    "if tot_mask is not None:\n",
    "    idxs2, areas2 = top_slices_by_area(tot_mask, topk=10)\n",
    "    print(\"Top axial slices by TotalSeg area:\", idxs2[:10])\n",
    "else:\n",
    "    print(\"No TotalSeg mask to suggest slices.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683579a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_axial_grid(ct_vol, rex_mask=None, tot_mask=None, slice_indices=None, rex_title=\"ReX\", tot_title=\"TotalSeg\"):\n",
    "    \"\"\"Plot one row per axial slice: [CT | CT + ReX edges | CT + TotalSeg edges].\n",
    "\n",
    "    ct_vol: (D,H,W) float\n",
    "    rex_mask / tot_mask: (D,H,W) bool, or None to leave that panel as plain CT\n",
    "    slice_indices: list[int]; defaults to the middle slice when None or empty\n",
    "    \"\"\"\n",
    "    if slice_indices is None or len(slice_indices) == 0:\n",
    "        slice_indices = [ct_vol.shape[0] // 2]\n",
    "\n",
    "    n = len(slice_indices)\n",
    "    # squeeze=False always returns a 2-D (n, 3) array of Axes, even when n == 1.\n",
    "    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(12, 4 * n), squeeze=False)\n",
    "\n",
    "    # NOTE(review): origin=\"lower\" draws array row 0 at the bottom; switch to \"upper\"\n",
    "    # if anatomy appears upside down for your preprocessing orientation.\n",
    "    for r, d in enumerate(slice_indices):\n",
    "        ct2d = ct_vol[d]\n",
    "\n",
    "        # Panel 1: CT\n",
    "        ax = axes[r, 0]\n",
    "        ax.imshow(ct2d, cmap=\"gray\", origin=\"lower\")\n",
    "        ax.set_title(f\"CT (axial d={d})\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "        # Panel 2: CT + ReX\n",
    "        ax = axes[r, 1]\n",
    "        ax.imshow(ct2d, cmap=\"gray\", origin=\"lower\")\n",
    "        if rex_mask is not None:\n",
    "            e = mask_edges_2d(rex_mask[d])\n",
    "            ax.imshow(e, cmap=\"Reds\", alpha=0.7, origin=\"lower\")\n",
    "        ax.set_title(f\"CT + {rex_title}\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "        # Panel 3: CT + TotalSeg\n",
    "        ax = axes[r, 2]\n",
    "        ax.imshow(ct2d, cmap=\"gray\", origin=\"lower\")\n",
    "        if tot_mask is not None:\n",
    "            e = mask_edges_2d(tot_mask[d])\n",
    "            ax.imshow(e, cmap=\"Blues\", alpha=0.6, origin=\"lower\")\n",
    "        ax.set_title(f\"CT + {tot_title}\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Choose slices to visualize (prefer slices with ReX content if present)\n",
    "if rex_mask is not None:\n",
    "    slices, _ = top_slices_by_area(rex_mask, topk=3)\n",
    "    if len(slices) == 0:\n",
    "        slices = [ct_vol.shape[0] // 2]\n",
    "else:\n",
    "    slices = [ct_vol.shape[0] // 2]\n",
    "\n",
    "show_axial_grid(ct_vol, rex_mask=rex_mask, tot_mask=tot_mask, slice_indices=slices[:3],\n",
    "                rex_title=f\"ReX(ch={rex_ch})\", tot_title=\"TotalSeg\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}