{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/javascript": "\nvar cell = this.closest('.cell');\nif (cell) {\n cell.classList.remove('output_scroll');\n}\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "Bad key paths in file /export/home/daifang/.config/matplotlib/matplotlibrc, line 3 ('paths: /export/home/daifang/fonts/arial/')\n", "You probably need to get an updated matplotlibrc file from\n", "https://github.com/matplotlib/matplotlib/blob/v3.3.4/matplotlibrc.template\n", "or from the matplotlib source distribution\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda:1\n", "\n", "[Epoch 001] Train Loss=4.903 | Val Loss=4.777\n", " All | Subtype AUC=0.610 | TNM AUC=0.571 | DFS C-index=0.631 | OS C-index=0.454\n", " Immune | Subtype AUC=0.571 | TNM AUC=0.585 | DFS C-index=0.627 | OS C-index=0.478\n", " Chemo | Subtype AUC=0.675 | TNM AUC=0.559 | DFS C-index=0.588 | OS C-index=0.476\n", " ✓ Best model updated\n", "\n", "[Epoch 002] Train Loss=4.770 | Val Loss=4.596\n", " All | Subtype AUC=0.678 | TNM AUC=0.683 | DFS C-index=0.606 | OS C-index=0.536\n", " Immune | Subtype AUC=0.691 | TNM AUC=0.654 | DFS C-index=0.585 | OS C-index=0.591\n", " Chemo | Subtype AUC=0.634 | TNM AUC=0.729 | DFS C-index=0.639 | OS C-index=0.456\n", " ✓ Best model updated\n", "\n", "[Epoch 003] Train Loss=4.551 | Val Loss=4.404\n", " All | Subtype AUC=0.698 | TNM AUC=0.678 | DFS C-index=0.592 | OS C-index=0.450\n", " Immune | Subtype AUC=0.667 | TNM AUC=0.686 | DFS C-index=0.635 | OS C-index=0.480\n", " Chemo | Subtype AUC=0.762 | TNM AUC=0.687 | DFS C-index=0.541 | OS C-index=0.379\n", " ✓ Best model updated\n", "\n", "[Epoch 004] Train Loss=4.299 | Val Loss=4.076\n", " All | Subtype AUC=0.763 | TNM AUC=0.753 | DFS C-index=0.550 | OS C-index=0.597\n", " Immune | Subtype AUC=0.720 | TNM AUC=0.745 | DFS 
C-index=0.554 | OS C-index=0.606\n", " Chemo | Subtype AUC=0.824 | TNM AUC=0.735 | DFS C-index=0.510 | OS C-index=0.589\n", " ✓ Best model updated\n", "\n", "[Epoch 005] Train Loss=3.962 | Val Loss=3.941\n", " All | Subtype AUC=0.693 | TNM AUC=0.703 | DFS C-index=0.640 | OS C-index=0.658\n", " Immune | Subtype AUC=0.804 | TNM AUC=0.714 | DFS C-index=0.607 | OS C-index=0.675\n", " Chemo | Subtype AUC=0.539 | TNM AUC=0.694 | DFS C-index=0.694 | OS C-index=0.621\n", " ✓ Best model updated\n", "\n", "[Epoch 006] Train Loss=3.911 | Val Loss=3.840\n", " All | Subtype AUC=0.810 | TNM AUC=0.636 | DFS C-index=0.591 | OS C-index=0.549\n", " Immune | Subtype AUC=0.840 | TNM AUC=0.657 | DFS C-index=0.552 | OS C-index=0.551\n", " Chemo | Subtype AUC=0.771 | TNM AUC=0.618 | DFS C-index=0.580 | OS C-index=0.577\n", " ✓ Best model updated\n", "\n", "[Epoch 007] Train Loss=3.823 | Val Loss=3.876\n", " All | Subtype AUC=0.800 | TNM AUC=0.722 | DFS C-index=0.645 | OS C-index=0.666\n", " Immune | Subtype AUC=0.833 | TNM AUC=0.764 | DFS C-index=0.604 | OS C-index=0.654\n", " Chemo | Subtype AUC=0.750 | TNM AUC=0.668 | DFS C-index=0.671 | OS C-index=0.605\n", "\n", "[Epoch 008] Train Loss=3.779 | Val Loss=3.794\n", " All | Subtype AUC=0.689 | TNM AUC=0.711 | DFS C-index=0.654 | OS C-index=0.616\n", " Immune | Subtype AUC=0.729 | TNM AUC=0.729 | DFS C-index=0.680 | OS C-index=0.682\n", " Chemo | Subtype AUC=0.637 | TNM AUC=0.725 | DFS C-index=0.624 | OS C-index=0.565\n", " ✓ Best model updated\n", "\n", "[Epoch 009] Train Loss=3.725 | Val Loss=3.630\n", " All | Subtype AUC=0.856 | TNM AUC=0.718 | DFS C-index=0.692 | OS C-index=0.677\n", " Immune | Subtype AUC=0.897 | TNM AUC=0.791 | DFS C-index=0.713 | OS C-index=0.675\n", " Chemo | Subtype AUC=0.785 | TNM AUC=0.652 | DFS C-index=0.678 | OS C-index=0.677\n", " ✓ Best model updated\n", "\n", "[Epoch 010] Train Loss=3.706 | Val Loss=3.646\n", " All | Subtype AUC=0.749 | TNM AUC=0.773 | DFS C-index=0.659 | OS C-index=0.585\n", " Immune | 
Subtype AUC=0.817 | TNM AUC=0.723 | DFS C-index=0.638 | OS C-index=0.596\n", " Chemo | Subtype AUC=0.649 | TNM AUC=0.855 | DFS C-index=0.710 | OS C-index=0.540\n", "\n", "[Epoch 011] Train Loss=3.694 | Val Loss=3.690\n", " All | Subtype AUC=0.791 | TNM AUC=0.769 | DFS C-index=0.637 | OS C-index=0.542\n", " Immune | Subtype AUC=0.812 | TNM AUC=0.731 | DFS C-index=0.677 | OS C-index=0.593\n", " Chemo | Subtype AUC=0.787 | TNM AUC=0.812 | DFS C-index=0.604 | OS C-index=0.488\n", "\n", "[Epoch 012] Train Loss=3.600 | Val Loss=3.469\n", " All | Subtype AUC=0.875 | TNM AUC=0.744 | DFS C-index=0.663 | OS C-index=0.662\n", " Immune | Subtype AUC=0.896 | TNM AUC=0.754 | DFS C-index=0.671 | OS C-index=0.753\n", " Chemo | Subtype AUC=0.887 | TNM AUC=0.749 | DFS C-index=0.635 | OS C-index=0.573\n", " ✓ Best model updated\n", "\n", "[Epoch 013] Train Loss=3.641 | Val Loss=3.486\n", " All | Subtype AUC=0.870 | TNM AUC=0.701 | DFS C-index=0.659 | OS C-index=0.649\n", " Immune | Subtype AUC=0.912 | TNM AUC=0.701 | DFS C-index=0.708 | OS C-index=0.722\n", " Chemo | Subtype AUC=0.811 | TNM AUC=0.724 | DFS C-index=0.600 | OS C-index=0.512\n", "\n", "[Epoch 014] Train Loss=3.554 | Val Loss=3.560\n", " All | Subtype AUC=0.796 | TNM AUC=0.672 | DFS C-index=0.670 | OS C-index=0.656\n", " Immune | Subtype AUC=0.773 | TNM AUC=0.739 | DFS C-index=0.685 | OS C-index=0.753\n", " Chemo | Subtype AUC=0.818 | TNM AUC=0.600 | DFS C-index=0.659 | OS C-index=0.556\n", "\n", "[Epoch 015] Train Loss=3.490 | Val Loss=3.608\n", " All | Subtype AUC=0.815 | TNM AUC=0.730 | DFS C-index=0.620 | OS C-index=0.622\n", " Immune | Subtype AUC=0.850 | TNM AUC=0.846 | DFS C-index=0.657 | OS C-index=0.635\n", " Chemo | Subtype AUC=0.752 | TNM AUC=0.590 | DFS C-index=0.565 | OS C-index=0.629\n", "\n", "[Epoch 016] Train Loss=3.559 | Val Loss=3.347\n", " All | Subtype AUC=0.793 | TNM AUC=0.782 | DFS C-index=0.655 | OS C-index=0.636\n", " Immune | Subtype AUC=0.889 | TNM AUC=0.781 | DFS C-index=0.643 | OS 
C-index=0.659\n", " Chemo | Subtype AUC=0.673 | TNM AUC=0.809 | DFS C-index=0.663 | OS C-index=0.605\n", " ✓ Best model updated\n", "\n", "[Epoch 017] Train Loss=3.530 | Val Loss=3.405\n", " All | Subtype AUC=0.840 | TNM AUC=0.731 | DFS C-index=0.645 | OS C-index=0.611\n", " Immune | Subtype AUC=0.832 | TNM AUC=0.738 | DFS C-index=0.688 | OS C-index=0.654\n", " Chemo | Subtype AUC=0.847 | TNM AUC=0.728 | DFS C-index=0.643 | OS C-index=0.597\n", "\n", "[Epoch 018] Train Loss=3.545 | Val Loss=3.434\n", " All | Subtype AUC=0.912 | TNM AUC=0.729 | DFS C-index=0.660 | OS C-index=0.595\n", " Immune | Subtype AUC=0.979 | TNM AUC=0.677 | DFS C-index=0.702 | OS C-index=0.585\n", " Chemo | Subtype AUC=0.804 | TNM AUC=0.767 | DFS C-index=0.671 | OS C-index=0.605\n", "\n", "[Epoch 019] Train Loss=3.605 | Val Loss=3.394\n", " All | Subtype AUC=0.807 | TNM AUC=0.722 | DFS C-index=0.655 | OS C-index=0.664\n", " Immune | Subtype AUC=0.858 | TNM AUC=0.736 | DFS C-index=0.674 | OS C-index=0.701\n", " Chemo | Subtype AUC=0.713 | TNM AUC=0.768 | DFS C-index=0.675 | OS C-index=0.645\n", "\n", "[Epoch 020] Train Loss=3.386 | Val Loss=3.381\n", " All | Subtype AUC=0.784 | TNM AUC=0.757 | DFS C-index=0.674 | OS C-index=0.646\n", " Immune | Subtype AUC=0.796 | TNM AUC=0.732 | DFS C-index=0.696 | OS C-index=0.727\n", " Chemo | Subtype AUC=0.758 | TNM AUC=0.818 | DFS C-index=0.675 | OS C-index=0.565\n", "\n", "[Epoch 021] Train Loss=3.505 | Val Loss=3.388\n", " All | Subtype AUC=0.767 | TNM AUC=0.836 | DFS C-index=0.679 | OS C-index=0.632\n", " Immune | Subtype AUC=0.800 | TNM AUC=0.877 | DFS C-index=0.641 | OS C-index=0.696\n", " Chemo | Subtype AUC=0.663 | TNM AUC=0.772 | DFS C-index=0.714 | OS C-index=0.544\n", "\n", "[Epoch 022] Train Loss=3.577 | Val Loss=3.618\n", " All | Subtype AUC=0.723 | TNM AUC=0.683 | DFS C-index=0.641 | OS C-index=0.569\n", " Immune | Subtype AUC=0.711 | TNM AUC=0.592 | DFS C-index=0.699 | OS C-index=0.619\n", " Chemo | Subtype AUC=0.745 | TNM AUC=0.765 | DFS 
C-index=0.635 | OS C-index=0.524\n", "\n", "[Epoch 023] Train Loss=3.420 | Val Loss=3.244\n", " All | Subtype AUC=0.874 | TNM AUC=0.752 | DFS C-index=0.654 | OS C-index=0.653\n", " Immune | Subtype AUC=0.912 | TNM AUC=0.839 | DFS C-index=0.652 | OS C-index=0.701\n", " Chemo | Subtype AUC=0.833 | TNM AUC=0.629 | DFS C-index=0.639 | OS C-index=0.625\n", " ✓ Best model updated\n", "\n", "[Epoch 024] Train Loss=3.525 | Val Loss=3.595\n", " All | Subtype AUC=0.775 | TNM AUC=0.754 | DFS C-index=0.600 | OS C-index=0.610\n", " Immune | Subtype AUC=0.733 | TNM AUC=0.758 | DFS C-index=0.641 | OS C-index=0.690\n", " Chemo | Subtype AUC=0.800 | TNM AUC=0.773 | DFS C-index=0.631 | OS C-index=0.589\n", "\n", "[Epoch 025] Train Loss=3.571 | Val Loss=3.480\n", " All | Subtype AUC=0.820 | TNM AUC=0.763 | DFS C-index=0.599 | OS C-index=0.627\n", " Immune | Subtype AUC=0.825 | TNM AUC=0.742 | DFS C-index=0.557 | OS C-index=0.627\n", " Chemo | Subtype AUC=0.837 | TNM AUC=0.796 | DFS C-index=0.616 | OS C-index=0.613\n", "\n", "[Epoch 026] Train Loss=3.508 | Val Loss=3.209\n", " All | Subtype AUC=0.874 | TNM AUC=0.765 | DFS C-index=0.724 | OS C-index=0.633\n", " Immune | Subtype AUC=0.853 | TNM AUC=0.804 | DFS C-index=0.713 | OS C-index=0.688\n", " Chemo | Subtype AUC=0.903 | TNM AUC=0.719 | DFS C-index=0.788 | OS C-index=0.573\n", " ✓ Best model updated\n", "\n", "[Epoch 027] Train Loss=3.417 | Val Loss=3.468\n", " All | Subtype AUC=0.836 | TNM AUC=0.674 | DFS C-index=0.685 | OS C-index=0.672\n", " Immune | Subtype AUC=0.834 | TNM AUC=0.707 | DFS C-index=0.713 | OS C-index=0.701\n", " Chemo | Subtype AUC=0.873 | TNM AUC=0.604 | DFS C-index=0.608 | OS C-index=0.589\n", "\n", "[Epoch 028] Train Loss=3.443 | Val Loss=3.314\n", " All | Subtype AUC=0.817 | TNM AUC=0.734 | DFS C-index=0.653 | OS C-index=0.670\n", " Immune | Subtype AUC=0.875 | TNM AUC=0.738 | DFS C-index=0.604 | OS C-index=0.719\n", " Chemo | Subtype AUC=0.744 | TNM AUC=0.745 | DFS C-index=0.694 | OS C-index=0.593\n", "\n", 
"[Epoch 029] Train Loss=3.451 | Val Loss=3.259\n", " All | Subtype AUC=0.782 | TNM AUC=0.819 | DFS C-index=0.708 | OS C-index=0.637\n", " Immune | Subtype AUC=0.814 | TNM AUC=0.811 | DFS C-index=0.638 | OS C-index=0.688\n", " Chemo | Subtype AUC=0.750 | TNM AUC=0.833 | DFS C-index=0.776 | OS C-index=0.552\n", "\n", "[Epoch 030] Train Loss=3.556 | Val Loss=3.322\n", " All | Subtype AUC=0.880 | TNM AUC=0.762 | DFS C-index=0.650 | OS C-index=0.662\n", " Immune | Subtype AUC=0.884 | TNM AUC=0.756 | DFS C-index=0.660 | OS C-index=0.727\n", " Chemo | Subtype AUC=0.863 | TNM AUC=0.773 | DFS C-index=0.584 | OS C-index=0.548\n", "\n", "[Epoch 031] Train Loss=3.493 | Val Loss=3.233\n", " All | Subtype AUC=0.857 | TNM AUC=0.764 | DFS C-index=0.705 | OS C-index=0.638\n", " Immune | Subtype AUC=0.894 | TNM AUC=0.772 | DFS C-index=0.663 | OS C-index=0.688\n", " Chemo | Subtype AUC=0.800 | TNM AUC=0.755 | DFS C-index=0.757 | OS C-index=0.585\n", "\n", "[Epoch 032] Train Loss=3.483 | Val Loss=3.330\n", " All | Subtype AUC=0.894 | TNM AUC=0.743 | DFS C-index=0.616 | OS C-index=0.595\n", " Immune | Subtype AUC=0.917 | TNM AUC=0.792 | DFS C-index=0.613 | OS C-index=0.567\n", " Chemo | Subtype AUC=0.869 | TNM AUC=0.644 | DFS C-index=0.561 | OS C-index=0.597\n", "\n", "[Epoch 033] Train Loss=3.447 | Val Loss=3.307\n", " All | Subtype AUC=0.824 | TNM AUC=0.846 | DFS C-index=0.657 | OS C-index=0.660\n", " Immune | Subtype AUC=0.936 | TNM AUC=0.821 | DFS C-index=0.635 | OS C-index=0.740\n", " Chemo | Subtype AUC=0.655 | TNM AUC=0.898 | DFS C-index=0.671 | OS C-index=0.577\n", "\n", "[Epoch 034] Train Loss=3.396 | Val Loss=3.207\n", " All | Subtype AUC=0.779 | TNM AUC=0.746 | DFS C-index=0.679 | OS C-index=0.693\n", " Immune | Subtype AUC=0.747 | TNM AUC=0.774 | DFS C-index=0.641 | OS C-index=0.719\n", " Chemo | Subtype AUC=0.855 | TNM AUC=0.703 | DFS C-index=0.584 | OS C-index=0.605\n", " ✓ Best model updated\n", "\n", "[Epoch 035] Train Loss=3.362 | Val Loss=3.279\n", " All | Subtype 
AUC=0.802 | TNM AUC=0.756 | DFS C-index=0.666 | OS C-index=0.667\n", " Immune | Subtype AUC=0.861 | TNM AUC=0.826 | DFS C-index=0.749 | OS C-index=0.709\n", " Chemo | Subtype AUC=0.733 | TNM AUC=0.658 | DFS C-index=0.580 | OS C-index=0.609\n", "\n", "[Epoch 036] Train Loss=3.468 | Val Loss=3.541\n", " All | Subtype AUC=0.737 | TNM AUC=0.736 | DFS C-index=0.614 | OS C-index=0.601\n", " Immune | Subtype AUC=0.764 | TNM AUC=0.743 | DFS C-index=0.579 | OS C-index=0.635\n", " Chemo | Subtype AUC=0.725 | TNM AUC=0.736 | DFS C-index=0.616 | OS C-index=0.565\n", "\n", "[Epoch 037] Train Loss=3.259 | Val Loss=3.274\n", " All | Subtype AUC=0.883 | TNM AUC=0.793 | DFS C-index=0.654 | OS C-index=0.630\n", " Immune | Subtype AUC=0.857 | TNM AUC=0.751 | DFS C-index=0.660 | OS C-index=0.675\n", " Chemo | Subtype AUC=0.905 | TNM AUC=0.837 | DFS C-index=0.655 | OS C-index=0.569\n", "\n", "[Epoch 038] Train Loss=3.328 | Val Loss=3.333\n", " All | Subtype AUC=0.809 | TNM AUC=0.772 | DFS C-index=0.655 | OS C-index=0.630\n", " Immune | Subtype AUC=0.861 | TNM AUC=0.813 | DFS C-index=0.652 | OS C-index=0.646\n", " Chemo | Subtype AUC=0.758 | TNM AUC=0.719 | DFS C-index=0.647 | OS C-index=0.589\n", "\n", "[Epoch 039] Train Loss=3.409 | Val Loss=3.216\n", " All | Subtype AUC=0.871 | TNM AUC=0.768 | DFS C-index=0.679 | OS C-index=0.625\n", " Immune | Subtype AUC=0.882 | TNM AUC=0.744 | DFS C-index=0.783 | OS C-index=0.703\n", " Chemo | Subtype AUC=0.865 | TNM AUC=0.823 | DFS C-index=0.624 | OS C-index=0.544\n", "\n", "[Epoch 040] Train Loss=3.421 | Val Loss=3.211\n", " All | Subtype AUC=0.883 | TNM AUC=0.694 | DFS C-index=0.737 | OS C-index=0.680\n", " Immune | Subtype AUC=0.901 | TNM AUC=0.700 | DFS C-index=0.674 | OS C-index=0.701\n", " Chemo | Subtype AUC=0.870 | TNM AUC=0.688 | DFS C-index=0.741 | OS C-index=0.573\n", "\n", "[Epoch 041] Train Loss=3.318 | Val Loss=3.326\n", " All | Subtype AUC=0.821 | TNM AUC=0.778 | DFS C-index=0.680 | OS C-index=0.678\n", " Immune | Subtype AUC=0.855 
| TNM AUC=0.817 | DFS C-index=0.613 | OS C-index=0.722\n", " Chemo | Subtype AUC=0.781 | TNM AUC=0.724 | DFS C-index=0.769 | OS C-index=0.661\n", "\n", "[Epoch 042] Train Loss=3.203 | Val Loss=3.307\n", " All | Subtype AUC=0.825 | TNM AUC=0.778 | DFS C-index=0.677 | OS C-index=0.603\n", " Immune | Subtype AUC=0.851 | TNM AUC=0.770 | DFS C-index=0.708 | OS C-index=0.609\n", " Chemo | Subtype AUC=0.850 | TNM AUC=0.815 | DFS C-index=0.675 | OS C-index=0.597\n", "\n", "⏹ Early stopping triggered\n", "\n", "Running inference with best model...\n", "train | loss=3.330 | Immune=105 Chemo=75\n", "val | loss=3.458 | Immune=34 Chemo=26\n", "test | loss=3.877 | Immune=32 Chemo=28\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.\n", "findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.\n", "findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "✔ Figure 7 generated (DFS/OS KM + HR) for Immune/Chemo.\n" ] } ], "source": [ "import sys\n", "sys.path.insert(0, \"/export/home/daifang/lunghospital/MM-DLS-master/MM-DLS-master\")\n", "# main.py\n", "import os\n", "import sys\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader, random_split\n", "\n", "from sklearn.metrics import roc_auc_score, accuracy_score\n", "from sklearn.preprocessing import label_binarize\n", "\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from lifelines import KaplanMeierFitter, CoxPHFitter\n", "from lifelines.statistics import multivariate_logrank_test\n", "from lifelines.utils import concordance_index\n", "from sklearn.metrics import brier_score_loss\n", "from scipy.stats import norm\n", "\n", "\n", "\n", "# =========================================================\n", "# Project path (IMPORTANT for Jupyter / HPC)\n", "# 
=========================================================\n",
"PROJECT_ROOT = os.path.abspath(\".\")\n",
"if PROJECT_ROOT not in sys.path:\n",
"    sys.path.insert(0, PROJECT_ROOT)\n",
"\n",
"# =========================================================\n",
"# imports: mm_dls/\n",
"# =========================================================\n",
"def _import_modules():\n",
"    \"\"\"Import the project model, dataset and Cox loss classes from ``mm_dls``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    tuple\n",
"        (HierMM_DLS, FakePatientDataset, CoxPHLoss)\n",
"    \"\"\"\n",
"    from mm_dls.HierMM_DLS import HierMM_DLS\n",
"    from mm_dls.FakePatientDataset import FakePatientDataset\n",
"    from mm_dls.CoxphLoss import CoxPHLoss\n",
"    return HierMM_DLS, FakePatientDataset, CoxPHLoss\n",
"\n",
"\n",
"HierMM_DLS, FakePatientDataset, CoxPHLoss = _import_modules()\n",
"\n",
"\n",
"# =========================\n",
"# Training configuration\n",
"# =========================\n",
"EPOCHS = 300\n",
"PATIENCE = 8  # epochs without a val-loss improvement before early stopping\n",
"BATCH_SIZE = 4\n",
"LR = 1e-4\n",
"WEIGHT_DECAY = 1e-5\n",
"SEED = 42  # global numpy / torch seed for reproducible runs\n",
"\n",
"# =========================\n",
"# Task definition\n",
"# =========================\n",
"NUM_SUBTYPES = 2  # e.g., LUAD vs LUSC\n",
"NUM_TNM = 3  # Stage I–II / III / IV\n",
"\n",
"# =========================\n",
"# Image settings\n",
"# =========================\n",
"N_SLICES = 30  # max slices per patient\n",
"IMG_SIZE = 224\n",
"\n",
"\n",
"SAVE_DIR = \"./results\"\n",
"FIG_DIR = \"./figures\"\n",
"os.makedirs(SAVE_DIR, exist_ok=True)\n",
"os.makedirs(FIG_DIR, exist_ok=True)\n",
"\n",
"# Seed early so model initialisation and data shuffling are repeatable.\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"# -------------------------\n",
"# Device: cuda:1 by default; override with the MM_DLS_DEVICE env variable\n",
"# (e.g. MM_DLS_DEVICE=cuda:0).\n",
"# -------------------------\n",
"DEVICE = torch.device(os.environ.get(\"MM_DLS_DEVICE\", \"cuda:1\"))\n",
"if DEVICE.type == \"cuda\":\n",
"    # Explicit check rather than `assert`, which `python -O` strips out.\n",
"    if not torch.cuda.is_available():\n",
"        raise RuntimeError(\"CUDA not available\")\n",
"    if DEVICE.index is not None:\n",
"        torch.cuda.set_device(DEVICE)\n",
"print(\"Using device:\", DEVICE)\n",
"\n",
"\n",
"# =========================================================\n",
"# Core utils\n",
"# =========================================================\n",
"def _sigmoid(x):\n",
"    \"\"\"Logistic function 1 / (1 + exp(-x)).\n",
"\n",
"    Evaluated as exp(-log(1 + exp(-x))) via ``np.logaddexp`` so that large\n",
"    negative inputs do not overflow ``np.exp``.\n",
"    \"\"\"\n",
"    return np.exp(-np.logaddexp(0, -x))\n",
"\n",
"def _ensure_numpy(x):\n",
"    \"\"\"Return a torch tensor as a detached CPU numpy array; pass anything else through.\"\"\"\n",
"    if isinstance(x, torch.Tensor):\n",
"        return x.detach().cpu().numpy()\n",
"    return x\n",
"\n",
"def 
_risk_to_groups(risk, q=(1/3, 2/3), labels=(\"Low\", \"Mediate\", \"High\")):\n",
"    \"\"\"\n",
"    Bin a continuous risk score into three groups by quantile cut points.\n",
"\n",
"    Values <= the q[0] quantile get labels[0], values >= the q[1] quantile\n",
"    get labels[2] (which wins if the two cut points coincide), and all other\n",
"    values get labels[1]. Returns an object array of labels, one per element\n",
"    of the flattened input.\n",
"    \"\"\"\n",
"    r = np.asarray(risk).reshape(-1)\n",
"    t1, t2 = np.quantile(r, q[0]), np.quantile(r, q[1])\n",
"    out = np.full(len(r), labels[1], dtype=object)\n",
"    out[r <= t1] = labels[0]\n",
"    out[r >= t2] = labels[2]\n",
"    return out\n",
"\n",
"def _evaluate_survival_metrics(time, event, risk, time_point=30):\n",
"    \"\"\"\n",
"    Harrell's C-index plus a proxy Brier score at a fixed time point.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    time, event, risk : array-like (flattened to 1-D)\n",
"        Follow-up time, event indicator (1 = event observed) and predicted\n",
"        risk, where higher risk means an earlier expected event (hence\n",
"        ``-risk`` is passed to ``concordance_index``).\n",
"    time_point : float\n",
"        Horizon, in the units of ``time``, for the Brier score.\n",
"\n",
"    Returns\n",
"    -------\n",
"    (c_index, brier) : tuple of float\n",
"        ``brier`` is NaN when no subject has a known status at the horizon.\n",
"\n",
"    Notes\n",
"    -----\n",
"    The survival probability is a min-max rescaling of risk (a proxy for\n",
"    demo/debug, not a calibrated survival model). Subjects censored at or\n",
"    before ``time_point`` have unknown status there, so they are excluded\n",
"    from the Brier score instead of being scored as non-survivors.\n",
"    \"\"\"\n",
"    time = np.asarray(time, dtype=float).reshape(-1)\n",
"    event = np.asarray(event).reshape(-1).astype(int)\n",
"    risk = np.asarray(risk, dtype=float).reshape(-1)\n",
"\n",
"    c_index = concordance_index(time, -risk, event)\n",
"\n",
"    # Proxy survival probability in [0, 1]: higher risk => lower survival.\n",
"    y_prob = 1 - (risk - risk.min()) / (risk.max() - risk.min() + 1e-8)\n",
"\n",
"    survived = time > time_point\n",
"    # Status at the horizon is known for anyone followed past it, or who had\n",
"    # the event before it; subjects censored earlier are left out.\n",
"    known = survived | (event == 1)\n",
"    if known.any():\n",
"        brier = brier_score_loss(survived[known].astype(int), y_prob[known])\n",
"    else:\n",
"        brier = float(\"nan\")\n",
"\n",
"    return float(c_index), float(brier)\n",
"\n",
"\n",
"# =========================================================\n",
"# One epoch (train / eval)\n",
"# =========================================================\n",
"def run_epoch_verbose(model, loader, optimizer, device, train=True):\n",
"    \"\"\"\n",
"    Run one pass over ``loader`` and collect the mean loss and all predictions.\n",
"\n",
"    With ``train=True`` the model is in train mode, gradients are tracked and\n",
"    ``optimizer`` is stepped after every batch; otherwise the model is in\n",
"    eval mode with gradients disabled. Each batch must be the 19-item tuple\n",
"    unpacked below (treatment last).\n",
"\n",
"    The per-batch loss is the unweighted sum of subtype/TNM cross-entropy,\n",
"    DFS/OS Cox losses and the DFS/OS 1y/3y/5y BCE terms.\n",
"    \"\"\"\n",
"    ce = nn.CrossEntropyLoss()\n",
"    bce = nn.BCEWithLogitsLoss(reduction=\"none\")\n",
"    cox = CoxPHLoss()\n",
"\n",
"    if train:\n",
"        model.train()\n",
"    else:\n",
"        model.eval()\n",
"\n",
"    losses = []\n",
"\n",
"    # classification\n",
"    sub_y_all, sub_s_all = [], []\n",
"    tnm_y_all, tnm_s_all = [], []\n",
"    treat_all = []\n",
"\n",
"    # survival (cox risk + time/event)\n",
"    dfs_r_all, dfs_t_all, dfs_e_all = [], [], []\n",
"    os_r_all, os_t_all, os_e_all = [], [], []\n",
"\n",
"    # survival 1y/3y/5y logits (optional save)\n",
"    dfs_log_all, os_log_all = [], []\n",
"\n",
"    for batch in loader:\n",
"        # NOTE: dataset must return 19 items including treatment\n",
"        if len(batch) != 19:\n",
"            raise ValueError(f\"Batch length mismatch: expected 19, got {len(batch)}. \"\n",
"                             f\"Please ensure Dataset __getitem__ returns treatment as the 19th item.\")\n",
"\n",
"        (\n",
"            pid, lesion, space, rad, pet, cli,\n",
"            y_sub, y_tnm,\n",
"            dfs_t, dfs_e,\n",
"            os_t, os_e,\n",
"            dfs1, dfs3, dfs5,\n",
"            os1, os3, os5,\n",
"            treatment\n",
"        ) = batch\n",
"\n",
"        lesion, space = lesion.to(device), space.to(device)\n",
"        rad, pet, cli = rad.to(device), pet.to(device), cli.to(device)\n",
"        y_sub, y_tnm = y_sub.to(device), y_tnm.to(device)\n",
"        dfs_t, dfs_e = dfs_t.to(device), dfs_e.to(device)\n",
"        os_t, os_e = os_t.to(device), os_e.to(device)\n",
"        treatment = treatment.to(device)\n",
"\n",
"        # 1y/3y/5y targets for the BCE heads, stacked to [B, 3]\n",
"        dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(device)\n",
"        os_y = torch.stack([os1, os3, os5], dim=1).to(device)\n",
"\n",
"        with torch.set_grad_enabled(train):\n",
"            sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(\n",
"                lesion, space, rad, pet, cli\n",
"            )\n",
"\n",
"            loss = (\n",
"                ce(sub_l, y_sub) +\n",
"                ce(tnm_l, y_tnm) +\n",
"                cox(dfs_r, dfs_t, dfs_e) +\n",
"                cox(os_r, os_t, os_e) +\n",
"                bce(dfs_log, dfs_y).mean() +\n",
"                bce(os_log, os_y).mean()\n",
"            )\n",
"\n",
"            if train:\n",
"                optimizer.zero_grad()\n",
"                loss.backward()\n",
"                optimizer.step()\n",
"\n",
"        losses.append(loss.item())\n",
"\n",
"        # ----- Collect predictions -----\n",
"        sub_prob = torch.softmax(sub_l, dim=1)[:, 1]  # subtype prob (class 1)\n",
"        tnm_prob = torch.softmax(tnm_l, dim=1)  # [B,3]\n",
"\n",
"        sub_s_all.append(_ensure_numpy(sub_prob))\n",
"        sub_y_all.append(_ensure_numpy(y_sub))\n",
"\n",
"        tnm_s_all.append(_ensure_numpy(tnm_prob))\n",
"        tnm_y_all.append(_ensure_numpy(y_tnm))\n",
"\n",
"        
treat_all.append(_ensure_numpy(treatment))\n",
"\n",
"        # survival\n",
"        dfs_r_all.append(_ensure_numpy(dfs_r))\n",
"        dfs_t_all.append(_ensure_numpy(dfs_t))\n",
"        dfs_e_all.append(_ensure_numpy(dfs_e))\n",
"\n",
"        os_r_all.append(_ensure_numpy(os_r))\n",
"        os_t_all.append(_ensure_numpy(os_t))\n",
"        os_e_all.append(_ensure_numpy(os_e))\n",
"\n",
"        dfs_log_all.append(_ensure_numpy(dfs_log))\n",
"        os_log_all.append(_ensure_numpy(os_log))\n",
"\n",
"    # Fail with a clear message instead of np.concatenate's on an empty loader.\n",
"    if not losses:\n",
"        raise ValueError(\"run_epoch_verbose: loader yielded no batches\")\n",
"\n",
"    return (\n",
"        float(np.mean(losses)),\n",
"\n",
"        np.concatenate(sub_y_all),\n",
"        np.concatenate(sub_s_all),\n",
"\n",
"        np.concatenate(tnm_y_all),\n",
"        np.concatenate(tnm_s_all),\n",
"\n",
"        np.concatenate(treat_all),\n",
"\n",
"        np.concatenate(dfs_r_all),\n",
"        np.concatenate(dfs_t_all),\n",
"        np.concatenate(dfs_e_all),\n",
"\n",
"        np.concatenate(os_r_all),\n",
"        np.concatenate(os_t_all),\n",
"        np.concatenate(os_e_all),\n",
"\n",
"        np.concatenate(dfs_log_all, axis=0),  # [N,3]\n",
"        np.concatenate(os_log_all, axis=0),  # [N,3]\n",
"    )\n",
"\n",
"\n",
"# =========================================================\n",
"# Evaluation by cohort (classification + survival)\n",
"# =========================================================\n",
"def _safe_auc(y_true, y_score, **kwargs):\n",
"    \"\"\"``roc_auc_score`` that returns NaN instead of raising ValueError when\n",
"    the AUC is undefined (e.g. only one class present in a small cohort).\"\"\"\n",
"    try:\n",
"        return float(roc_auc_score(y_true, y_score, **kwargs))\n",
"    except ValueError:\n",
"        return float(\"nan\")\n",
"\n",
"\n",
"def evaluate_by_treatment(sub_y, sub_s, tnm_y, tnm_s, treat,\n",
"                          dfs_r, dfs_t, dfs_e, os_r, os_t, os_e):\n",
"    \"\"\"\n",
"    Classification and survival metrics for the All, Immune (treat == 0)\n",
"    and Chemo (treat == 1) cohorts.\n",
"\n",
"    Cohorts with fewer than 10 samples are skipped. AUCs that are undefined\n",
"    for a cohort (a class absent) are reported as NaN instead of raising.\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict\n",
"        Cohort name -> dict of Subtype_AUC, Subtype_ACC, TNM_AUC_macro,\n",
"        TNM_ACC, DFS/OS C-index and DFS/OS Brier at time point 30.\n",
"    \"\"\"\n",
"    results = {}\n",
"\n",
"    cohorts = {\n",
"        \"All\": np.ones_like(treat, dtype=bool),\n",
"        \"Immune\": treat == 0,\n",
"        \"Chemo\": treat == 1,\n",
"    }\n",
"    # One-vs-rest columns for every TNM class the model scores.\n",
"    tnm_classes = list(range(tnm_s.shape[1]))\n",
"\n",
"    for name, mask in cohorts.items():\n",
"        if mask.sum() < 10:\n",
"            continue\n",
"\n",
"        res = {}\n",
"\n",
"        # Subtype (binary)\n",
"        res[\"Subtype_AUC\"] = _safe_auc(sub_y[mask], sub_s[mask])\n",
"        res[\"Subtype_ACC\"] = accuracy_score(sub_y[mask], (sub_s[mask] > 0.5).astype(int))\n",
"\n",
"        # TNM (multiclass macro AUC + ACC)\n",
"        tnm_bin = label_binarize(tnm_y[mask], classes=tnm_classes)\n",
"        res[\"TNM_AUC_macro\"] = _safe_auc(\n",
"            tnm_bin, tnm_s[mask], average=\"macro\", multi_class=\"ovr\"\n",
"        )\n",
"        res[\"TNM_ACC\"] = accuracy_score(\n",
"            tnm_y[mask], np.argmax(tnm_s[mask], axis=1)\n",
"        )\n",
"\n",
"        # Survival\n",
"        dfs_c, dfs_b = _evaluate_survival_metrics(dfs_t[mask], dfs_e[mask], dfs_r[mask], time_point=30)\n",
"        os_c, os_b = _evaluate_survival_metrics(os_t[mask], os_e[mask], os_r[mask], time_point=30)\n",
"\n",
"        res[\"DFS_C_index\"] = dfs_c\n",
"        res[\"DFS_Brier_30m\"] = dfs_b\n",
"        res[\"OS_C_index\"] = os_c\n",
"        res[\"OS_Brier_30m\"] = os_b\n",
"\n",
"        results[name] = res\n",
"\n",
"    return results\n",
"\n",
"\n",
"# =========================================================\n",
"# Figure 7: KM + HR (per cohort, per endpoint)\n",
"# =========================================================\n",
"def plot_km_curve_with_hr(df, title, save_prefix):\n",
"    \"\"\"\n",
"    Kaplan-Meier curves per risk group, annotated with survival statistics.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df : pandas.DataFrame\n",
"        Columns ``time``, ``event`` (1 = event observed) and ``group`` with\n",
"        values among Low / Mediate / High.\n",
"    title : str\n",
"        Figure title.\n",
"    save_prefix : str\n",
"        Output path without extension; ``.png`` and ``.pdf`` are written.\n",
"\n",
"    Returns\n",
"    -------\n",
"    str\n",
"        ``save_prefix``.\n",
"\n",
"    Notes\n",
"    -----\n",
"    Hazard ratios come from a Cox model with one indicator per non-reference\n",
"    group (Low is the reference), so Intermediate-vs-Low and High-vs-Low are\n",
"    each estimated with their own Wald p-value, not derived from a single\n",
"    linear-trend coefficient.\n",
"    \"\"\"\n",
"    kmf = KaplanMeierFitter()\n",
"    fig, ax = plt.subplots(figsize=(8, 6), facecolor=\"white\")\n",
"    ax.set_facecolor(\"white\")\n",
"\n",
"    colors = {\"Low\": \"#91c7ae\", \"Mediate\": \"#f7b977\", \"High\": \"#d87c7c\"}\n",
"    legend_names = {\"Low\": \"Low\", \"Mediate\": \"Medium\", \"High\": \"High\"}\n",
"    groups = [\"Low\", \"Mediate\", \"High\"]\n",
"\n",
"    # plot KM\n",
"    lines = {}\n",
"    at_risk_table = []\n",
"    times = np.arange(0, 70, 10)\n",
"\n",
"    for g in groups:\n",
"        m = df[\"group\"] == g\n",
"        if m.sum() == 0:\n",
"            continue\n",
"\n",
"        kmf.fit(df.loc[m, \"time\"], event_observed=df.loc[m, \"event\"], label=g)\n",
"        kmf.plot_survival_function(\n",
"            ax=ax, ci_show=True, linewidth=2, color=colors[g], marker=\"+\"\n",
"        )\n",
"        lines[g] = ax.get_lines()[-1]\n",
"\n",
"        at_risk_table.append([np.sum(df.loc[m, \"time\"] >= t) for t in times])\n",
"\n",
"    # legend: names follow the groups actually plotted, so a missing group\n",
"    # cannot shift the remaining names onto the wrong curves\n",
"    present = [g for g in groups if g in lines]\n",
"    ax.legend([lines[g] for g in present], [legend_names[g] for g in present],\n",
"              title=\"Groups\", loc=\"upper right\",\n",
"              frameon=True, framealpha=0.5, fontsize=12, title_fontsize=12)\n",
"\n",
"    # at risk numbers (only drawn when all three groups are present)\n",
"    if len(at_risk_table) == 3:\n",
"        low, mid, high = at_risk_table\n",
"        for i, t in enumerate(times):\n",
"            ax.text(t, -0.38, str(low[i]), color=\"#207f4c\", fontsize=14, ha=\"center\")\n",
"            ax.text(t, -0.48, str(mid[i]), color=\"#fca106\", fontsize=14, ha=\"center\")\n",
"            ax.text(t, -0.58, str(high[i]), color=\"#cc163a\", fontsize=14, ha=\"center\")\n",
"\n",
"        ax.text(-1, -0.28, \"Number at risk\", color=\"black\", ha=\"center\", fontsize=14)\n",
"        ax.text(-10, -0.38, \"Low\", color=\"#207f4c\", fontsize=14)\n",
"        ax.text(-10, -0.48, \"Medium\", color=\"#fca106\", fontsize=14)\n",
"        ax.text(-10, -0.58, \"High\", color=\"#cc163a\", fontsize=14)\n",
"\n",
"    df2 = df.copy()\n",
"    df2[\"group_code\"] = df2[\"group\"].map({\"Low\": 0, \"Mediate\": 1, \"High\": 2})\n",
"\n",
"    # Cox HR + Wald p-values: one indicator per group, Low as the reference\n",
"    hr_text = {}\n",
"    contrasts = [g for g in (\"Mediate\", \"High\") if g in lines]\n",
"    if \"Low\" in lines and contrasts:\n",
"        design = df2[[\"time\", \"event\"]].copy()\n",
"        for g in contrasts:\n",
"            design[\"is_\" + g] = (df2[\"group\"] == g).astype(int)\n",
"        cph = CoxPHFitter()\n",
"        cph.fit(design, duration_col=\"time\", event_col=\"event\")\n",
"        for g in contrasts:\n",
"            coef = float(cph.params_[\"is_\" + g])\n",
"            se = float(cph.standard_errors_[\"is_\" + g])\n",
"            z = coef / se if se > 0 else float(\"nan\")\n",
"            p = 2 * norm.sf(abs(z))\n",
"            hr_text[g] = f\"{np.exp(coef):.2f}, P={p:.3f}\"\n",
"\n",
"    # logrank\n",
"    res_lr = multivariate_logrank_test(df2[\"time\"], df2[\"group\"], df2[\"event\"])\n",
"\n",
"    # C-index + brier (proxy), using the ordinal group code as the risk score\n",
"    c_index, brier = _evaluate_survival_metrics(df2[\"time\"].values, df2[\"event\"].values,\n",
"                                                df2[\"group_code\"].values, time_point=30)\n",
"\n",
"    ax.text(25, 0.46, f\"P(log-rank)={res_lr.p_value:.3f}\", fontsize=12)\n",
"    ax.text(25, 0.36, f\"C-index={c_index:.3f}\", fontsize=12)\n",
"    ax.text(25, 0.26, f\"Brier(30m)={brier:.3f}\", fontsize=12)\n",
"    ax.text(25, 0.16, f\"HR Intermediate vs Low = {hr_text.get('Mediate', 'n/a')}\", fontsize=12)\n",
"    ax.text(25, 0.06, f\"HR High vs Low = {hr_text.get('High', 'n/a')}\", fontsize=12)\n",
"\n",
"    # cosmetics\n",
"    ax.spines[\"top\"].set_visible(False)\n",
"    ax.spines[\"right\"].set_visible(False)\n",
"    ax.set_title(title, fontsize=14)\n",
"    ax.set_xlabel(\"Time since treatment start (months)\", fontsize=14)\n",
"    ax.set_ylabel(\"Survival probability\", fontsize=14)\n",
"    ax.set_ylim(0, 1.05)\n",
"    ax.grid(alpha=0.3)\n",
"\n",
"    fig.tight_layout()\n",
"    fig.savefig(save_prefix + \".png\", dpi=600, bbox_inches=\"tight\")\n",
"    fig.savefig(save_prefix + \".pdf\", dpi=600, bbox_inches=\"tight\")\n",
"    plt.close(fig)\n",
"    return save_prefix\n",
"\n",
"\n",
"def generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=(\"val\", \"test\")):\n",
"    \"\"\"\n",
"    Load the saved DFS/OS risk, time and event arrays of each split and draw\n",
"    KM + HR figures separately for the Immune (treatment 0) and Chemo\n",
"    (treatment 1) cohorts; cohorts with fewer than 20 samples are skipped.\n",
"    The per-cohort group tables are also written to ``result_dir`` as CSV.\n",
"    \"\"\"\n",
"    os.makedirs(fig_dir, exist_ok=True)\n",
"\n",
"    for split in which_split:\n",
"        # load arrays\n",
"        trt = np.load(os.path.join(result_dir, f\"treatment_{split}.npy\"))\n",
"\n",
"        dfs_r = np.load(os.path.join(result_dir, f\"dfs_{split}_risk.npy\"))\n",
"        dfs_t = np.load(os.path.join(result_dir, f\"dfs_{split}_time.npy\"))\n",
"        dfs_e = np.load(os.path.join(result_dir, f\"dfs_{split}_event.npy\"))\n",
"\n",
"        os_r = np.load(os.path.join(result_dir, f\"os_{split}_risk.npy\"))\n",
"        os_t = np.load(os.path.join(result_dir, f\"os_{split}_time.npy\"))\n",
"        os_e = np.load(os.path.join(result_dir, f\"os_{split}_event.npy\"))\n",
"\n",
"        for cohort_name, mask in {\n",
"            \"Immune\": trt == 0,\n",
"            \"Chemo\": trt == 1\n",
"        }.items():\n",
"            if mask.sum() < 20:\n",
"                print(f\"[Figure7] Skip {split}-{cohort_name}: too few samples ({mask.sum()})\")\n",
"                continue\n",
"\n",
"            # DFS groups\n",
"            dfs_group = _risk_to_groups(dfs_r[mask])\n",
"            df_dfs = pd.DataFrame({\n",
"                \"time\": dfs_t[mask],\n",
"                \"event\": dfs_e[mask].astype(int),\n",
"                \"group\": dfs_group\n",
"            })\n",
"\n",
"            # OS groups\n",
"            os_group = _risk_to_groups(os_r[mask])\n",
"            df_os = pd.DataFrame({\n",
"                \"time\": os_t[mask],\n",
"                \"event\": os_e[mask].astype(int),\n",
"                \"group\": os_group\n",
"            })\n",
"\n",
"            # save CSV (optional, for reproducibility)\n",
"            df_dfs.to_csv(os.path.join(result_dir, f\"dfs_{split}_{cohort_name}.csv\"), index=False)\n",
"            df_os.to_csv(os.path.join(result_dir, f\"os_{split}_{cohort_name}.csv\"), index=False)\n",
"\n",
"            # plot\n",
"            plot_km_curve_with_hr(\n",
"                df_dfs,\n",
"                title=f\"Disease-Free Survival (DFS) — Kaplan-Meier Curves\\n{cohort_name} {split} set (n={mask.sum()})\",\n",
"                save_prefix=os.path.join(fig_dir, f\"Figure7_DFS_{cohort_name}_{split}\")\n",
"            )\n",
"            plot_km_curve_with_hr(\n",
"                df_os,\n",
"                title=f\"Overall Survival (OS) — Kaplan-Meier Curves\\n{cohort_name} {split} set (n={mask.sum()})\",\n",
"                save_prefix=os.path.join(fig_dir, f\"Figure7_OS_{cohort_name}_{split}\")\n",
"            )\n",
"\n",
"    print(\"✔ Figure 7 generated (DFS/OS KM + HR) for Immune/Chemo.\")\n",
"\n",
"\n",
"# =========================================================\n",
"# Main\n",
"# =========================================================\n",
"def main(data_root=\"/path/to/DATA_ROOT\",\n",
"         clinical_csv=\"/path/to/clinical.csv\",\n",
"         radiomics_npy=\"/path/to/radiomics.npy\",\n",
"         pet_npy=\"/path/to/pet.npy\",\n",
"         seed=42):\n",
"    \"\"\"\n",
"    Train HierMM_DLS with early stopping on the validation loss, save the\n",
"    best model's predictions for every split under SAVE_DIR, and draw the\n",
"    Immune/Chemo KM + HR figures.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    data_root, clinical_csv, radiomics_npy, pet_npy : str\n",
"        Forwarded to ``PatientDataset`` (defaults are placeholders).\n",
"    seed : int\n",
"        Seeds the 60/20/20 patient split and the torch RNG used afterwards,\n",
"        so reruns see the same partition and initialisation.\n",
"    \"\"\"\n",
"    # -------------------------\n",
"    # Dataset (must return treatment as 19th item)\n",
"    # -------------------------\n",
"    from mm_dls.PatientDataset import PatientDataset\n",
"    from torch.utils.data import Subset\n",
"\n",
"    dataset = PatientDataset(\n",
"        data_root=data_root,\n",
"        clinical_csv=clinical_csv,\n",
"        radiomics_npy=radiomics_npy,\n",
"        pet_npy=pet_npy,\n",
"        n_slices=N_SLICES,\n",
"        img_size=IMG_SIZE,\n",
"    )\n",
"\n",
"    n_train = int(0.6 * len(dataset))\n",
"    n_val = int(0.2 * len(dataset))\n",
"    n_test = len(dataset) - n_train - n_val\n",
"    if min(n_train, n_val, n_test) == 0:\n",
"        raise ValueError(f\"Dataset too small for a 60/20/20 split: {len(dataset)} samples\")\n",
"\n",
"    # Seeded permutation instead of an unseeded random_split, so the saved\n",
"    # per-split predictions always refer to the same patients across reruns.\n",
"    perm = np.random.RandomState(seed).permutation(len(dataset)).tolist()\n",
"    train_set = Subset(dataset, perm[:n_train])\n",
"    val_set = Subset(dataset, perm[n_train:n_train + n_val])\n",
"    test_set = Subset(dataset, perm[n_train + n_val:])\n",
"    torch.manual_seed(seed)  # repeatable weight init and loader shuffling\n",
"\n",
"    loaders = {\n",
"        \"train\": DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=4),\n",
"        \"val\": DataLoader(val_set, BATCH_SIZE, shuffle=False, num_workers=4),\n",
"        \"test\": DataLoader(test_set, BATCH_SIZE, shuffle=False, num_workers=4),\n",
"    
}\n",
"\n",
"    # -------------------------\n",
"    # Model\n",
"    # -------------------------\n",
"    model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)\n",
"    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
"\n",
"    best_path = os.path.join(SAVE_DIR, \"best_model.pt\")\n",
"    best_val_loss = float(\"inf\")\n",
"    saved_this_run = False\n",
"    wait = 0\n",
"\n",
"    # -------------------------\n",
"    # Training\n",
"    # -------------------------\n",
"    for epoch in range(1, EPOCHS + 1):\n",
"        tr = run_epoch_verbose(model, loaders[\"train\"], optimizer, DEVICE, train=True)\n",
"        va = run_epoch_verbose(model, loaders[\"val\"], optimizer, DEVICE, train=False)\n",
"\n",
"        tr_loss = tr[0]\n",
"        va_loss = va[0]\n",
"\n",
"        # unpack val for metrics\n",
"        _, sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e, _, _ = va\n",
"        metrics = evaluate_by_treatment(sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e)\n",
"\n",
"        print(f\"\\n[Epoch {epoch:03d}] Train Loss={tr_loss:.3f} | Val Loss={va_loss:.3f}\")\n",
"        for k, v in metrics.items():\n",
"            print(\n",
"                f\" {k:7s} | \"\n",
"                f\"Subtype AUC={v['Subtype_AUC']:.3f} | \"\n",
"                f\"TNM AUC={v['TNM_AUC_macro']:.3f} | \"\n",
"                f\"DFS C-index={v['DFS_C_index']:.3f} | \"\n",
"                f\"OS C-index={v['OS_C_index']:.3f}\"\n",
"            )\n",
"\n",
"        # early stopping (a NaN val loss never compares as an improvement)\n",
"        if va_loss < best_val_loss:\n",
"            best_val_loss = va_loss\n",
"            wait = 0\n",
"            torch.save(model.state_dict(), best_path)\n",
"            saved_this_run = True\n",
"            print(\" ✓ Best model updated\")\n",
"        else:\n",
"            wait += 1\n",
"            if wait >= PATIENCE:\n",
"                print(\"\\n⏹ Early stopping triggered\")\n",
"                break\n",
"\n",
"    # -------------------------\n",
"    # Inference (best model)\n",
"    # -------------------------\n",
"    if not saved_this_run:\n",
"        # Otherwise torch.load below would silently pick up a stale checkpoint\n",
"        # left by an earlier run (or fail with FileNotFoundError).\n",
"        raise RuntimeError(\"No checkpoint saved in this run: validation loss never improved (NaN?)\")\n",
"    print(\"\\nRunning inference with best model...\")\n",
"    model.load_state_dict(torch.load(best_path, map_location=DEVICE))\n",
"\n",
"    for split in [\"train\", \"val\", \"test\"]:\n",
"        out = run_epoch_verbose(model, loaders[split], optimizer, DEVICE, train=False)\n",
"        (\n",
"            loss,\n",
"            sy, ss,\n",
"            ty, ts,\n",
"            trt,\n",
"            dfs_r, dfs_t, dfs_e,\n",
"            os_r, os_t, os_e,\n",
"            dfs_log, os_log\n",
"        ) = out\n",
"\n",
"        # classification\n",
"        np.save(os.path.join(SAVE_DIR, f\"subtype_{split}_labels.npy\"), sy)\n",
"        np.save(os.path.join(SAVE_DIR, f\"subtype_{split}_scores.npy\"), ss)\n",
"        np.save(os.path.join(SAVE_DIR, f\"tnm_{split}_labels.npy\"), ty)\n",
"        np.save(os.path.join(SAVE_DIR, f\"tnm_{split}_scores.npy\"), ts)\n",
"        np.save(os.path.join(SAVE_DIR, f\"treatment_{split}.npy\"), trt)\n",
"\n",
"        # survival (cox risk + time/event)\n",
"        np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_risk.npy\"), dfs_r)\n",
"        np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_time.npy\"), dfs_t)\n",
"        np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_event.npy\"), dfs_e)\n",
"\n",
"        np.save(os.path.join(SAVE_DIR, f\"os_{split}_risk.npy\"), os_r)\n",
"        np.save(os.path.join(SAVE_DIR, f\"os_{split}_time.npy\"), os_t)\n",
"        np.save(os.path.join(SAVE_DIR, f\"os_{split}_event.npy\"), os_e)\n",
"\n",
"        # 1y/3y/5y logits (optional, for AUC at specific horizons)\n",
"        np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_logits_1y3y5y.npy\"), dfs_log)\n",
"        np.save(os.path.join(SAVE_DIR, f\"os_{split}_logits_1y3y5y.npy\"), os_log)\n",
"\n",
"        print(f\"{split:5s} | loss={loss:.3f} | Immune={np.sum(trt==0)} Chemo={np.sum(trt==1)}\")\n",
"\n",
"    print(\"\\n✓ Inference completed. Results saved.\")\n",
"\n",
"    # -------------------------\n",
"    # Figure: Immune/Chemo KM + HR\n",
"    # -------------------------\n",
"    print(\"\\nGenerating Figure (KM + HR) ...\")\n",
"    generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=(\"val\", \"test\"))\n",
"    print(\"✓ Figure done. Files saved under ./figures\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    main()\n"
] } ], "metadata": { "kernelspec": { "display_name": "VGG", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }