Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import math | |
| import bisect | |
| import re | |
| import shutil | |
| import random | |
| import numpy as np | |
| import nibabel as nib | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoTokenizer, RobertaModel | |
| from huggingface_hub import hf_hub_download | |
| import matplotlib | |
| matplotlib.use("Agg") # headless | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import io | |
| import pandas as pd | |
| # --------------------------- | |
| # HF token (for private repos) | |
| # --------------------------- | |
# Hub token read from the environment (None when HF_TOKEN is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# ---------------------------
# Deterministic setup
# ---------------------------
# A single seed (overridable through APP_SEED) drives every RNG the app touches.
SEED = int(os.getenv("APP_SEED", "2026"))
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng
# Ask cuDNN for deterministic kernels and turn off autotuning, which may pick different kernels per run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
| # ============================== | |
| # Device & (disable AMP for determinism) | |
| # ============================== | |
# Mixed precision (AMP) is disabled to guarantee reproducibility.
USE_AMP = False
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| # ======================================================== | |
| # 5-year survival mapping rules | |
| # ======================================================== | |
# Risk-group labels shared (in this order) by the two lookup tables below.
_RISK_GROUPS = ("Q1 (Low risk)", "Q2–Q3 (Medium risk)", "Q4 (High risk)")
# Displayed 5-year survival range per risk group.
FIVE_YEAR_SURVIVAL_RANGE_STR = dict(zip(_RISK_GROUPS, (
    "90% – 97%",
    "72% – 88%",
    "50% – 68%",
)))
# Follow-up recommendation text per risk group.
CLINICAL_SUGGESTION = dict(zip(_RISK_GROUPS, (
    "Routine annual follow-up according to standard guidelines.",
    "Regular follow-up is recommended. Consider supplemental evaluation if clinically indicated.",
    "Close clinical monitoring is advised. Consider multidisciplinary evaluation and potential treatment optimization.",
)))
| # ======================================================== | |
| # Hard-coded Hub defaults | |
| # ======================================================== | |
# Model repo: text encoder, fusion checkpoint and example volumes.
# Space repo: scores.xlsx reference distribution.
# NOTE(review): "Multimodel_Surv" vs "Multimodal_Surv" differ by one letter; presumably both
# repo ids are real -- confirm before "fixing" either spelling.
DEFAULT_MODEL_REPO = "zhang0319/Multimodel_Surv"
DEFAULT_SPACE_REPO = "zhang0319/Multimodal_Surv"
DEFAULT_TEXT_SUBFOLDER = "models/radiobert_BigDataset_epoch10"
DEFAULT_CKPT_FILENAME = "weights/20251003_dropout0.3_best_image_report_clin_model20251003_8__.pth"
# Fusion-head settings; HIDDEN_DIM sizes layers whose weights are loaded from the checkpoint.
HIDDEN_DIM, DROPOUT_PROB = 768, 0.3
| # ======================================================== | |
| # Neoadjuvant options (UI only) | |
| # ======================================================== | |
# Neoadjuvant treatment choices offered in the UI; run_infer does not feed them to the model.
NEO_OPTIONS = [
    f"Neoadjuvant {kind}"
    for kind in ("radiotherapy", "chemotherapy", "hormonal", "targeted")
]
| # ======================================================== | |
| # Load SCORES from Excel (from space repo) | |
| # ======================================================== | |
def load_scores_from_excel():
    """Download ``scores.xlsx`` from the Space repo and return its scores sorted ascending.

    The first column whose header contains "score" (case-insensitive) is used;
    if none matches, the first column is used. Empty cells are dropped and the
    values are converted to float.

    Returns:
        Sorted list of floats (the reference distribution for percentile ranking).
    """
    excel_path = hf_hub_download(
        repo_id=DEFAULT_SPACE_REPO,
        filename="scores.xlsx",
        repo_type="space",
        token=HF_TOKEN,
    )
    df = pd.read_excel(excel_path)
    # str(): read_excel can produce non-string headers (e.g. ints), which have no .lower().
    score_cols = [col for col in df.columns if "score" in str(col).lower()]
    column = df[score_cols[0]] if score_cols else df.iloc[:, 0]
    return sorted(column.dropna().astype(float).tolist())


SCORES = load_scores_from_excel()
| # ======================================================== | |
| # Percentile → Risk → 5-year Survival (Interpolation) | |
| # ======================================================== | |
def percentile_from_scores(pred_val: float, scores: list[float]):
    """Rank a prediction against a sorted reference list.

    Args:
        pred_val: Model score to rank.
        scores: Reference scores, assumed sorted ascending.

    Returns:
        ``(percentile, rank, total)``. ``rank`` is ``bisect_right`` (0..n, i.e. the
        number of reference scores <= pred_val). The percentile is
        ``(rank - 1) / (n - 1) * 100`` clamped to [0, 100], or 100 when n == 1.
        An empty list gives ``(nan, 0, 0)``.
    """
    if not scores:
        return float("nan"), 0, 0
    total = len(scores)
    rank = bisect.bisect_right(scores, float(pred_val))
    if total > 1:
        pct = (rank - 1) / (total - 1) * 100
    else:
        pct = 100.0
    return max(0.0, min(100.0, pct)), rank, total
def quartile_from_percentile(pct: float):
    """Map a 0-100 percentile to its risk-group label (<25, <75, otherwise top quartile)."""
    for upper_bound, label in ((25, "Q1 (Low risk)"), (75, "Q2–Q3 (Medium risk)")):
        if pct < upper_bound:
            return label
    return "Q4 (High risk)"
def interpolate_s5_from_percentile(pct: float):
    """Linearly interpolate 5-year survival (%) inside the percentile's risk band.

    Bands (percentile span -> S5 from band start to band end):
    Q1 0-25 -> 97..90, Q2-Q3 25-75 -> 88..72, Q4 75-100 -> 68..50.
    The value jumps at band boundaries by design.

    Returns:
        ``(s5_percent clamped to [0, 100], risk_label)``.
    """
    risk = quartile_from_percentile(pct)
    if risk == "Q1 (Low risk)":
        band_start, band_width, s5_top, s5_bottom = 0, 25, 97.0, 90.0
    elif risk == "Q2–Q3 (Medium risk)":
        band_start, band_width, s5_top, s5_bottom = 25, 50, 88.0, 72.0
    else:
        band_start, band_width, s5_top, s5_bottom = 75, 25, 68.0, 50.0
    frac = (pct - band_start) / band_width
    s5 = s5_top + frac * (s5_bottom - s5_top)
    return float(max(0, min(100, s5))), risk
| # ======================================================== | |
| # Smooth KM-like 0–10y Curve | |
| # ======================================================== | |
| def _make_baseline_curve_0_10y( | |
| s5_percent: float, | |
| alpha: float = 2.0, | |
| late_slowdown: float = 0.35, | |
| ): | |
| s5 = float(np.clip(s5_percent, 1.0, 99.9)) | |
| years = np.linspace(0.0, 10.0, 241) | |
| t1 = np.clip(years, 0.0, 5.0) | |
| a_eff = max(alpha, 1.0001) | |
| surv_0_5 = 100.0 - (100.0 - s5) * (t1 / 5.0) ** a_eff | |
| m5 = - (100.0 - s5) * a_eff / 5.0 | |
| delta10 = 6.0 * (1.0 - np.clip(late_slowdown, 0.0, 1.0)) | |
| S10 = max(s5 - delta10, 40.0) | |
| u = np.clip((years - 5.0) / 5.0, 0.0, 1.0) | |
| Delta = max(s5 - S10, 1e-6) | |
| a_raw = (-5.0 * m5) / Delta | |
| a = float(np.clip(a_raw, 0.0, 3.0)) | |
| r = a * u + (3.0 - 2.0 * a) * (u ** 2) + (a - 2.0) * (u ** 3) | |
| surv_5_10 = s5 - Delta * r | |
| surv = np.where(years <= 5.0, surv_0_5, surv_5_10) | |
| surv = np.clip(surv, 0.0, 100.0) | |
| return years, surv | |
def make_dynamic_curve_image(
    s5_percent: float,
    alpha: float = 2.0,
    late_slowdown: float = 0.35
):
    """Render the 0-10 year survival curve for a given 5-year survival as a PIL image.

    Args:
        s5_percent: 5-year survival in percent, forwarded to _make_baseline_curve_0_10y.
        alpha: Early-phase exponent, forwarded to _make_baseline_curve_0_10y.
        late_slowdown: Late-phase damping, forwarded to _make_baseline_curve_0_10y.

    Returns:
        A PIL image opened from the PNG rendering of the plot.
    """
    years, surv = _make_baseline_curve_0_10y(
        s5_percent=s5_percent,
        alpha=alpha,
        late_slowdown=late_slowdown
    )
    fig, ax = plt.subplots(figsize=(6, 4), dpi=160)
    try:
        ax.plot(years, surv, linewidth=2.5, linestyle="--", color="orange")
        # Dashed marker at the 5-year landmark.
        ax.axvline(5.0, linestyle="--", linewidth=1.5, color="gray", alpha=0.8)
        # Fill under the curve: dark navy for years 0-5, blue for years 5-10.
        ax.fill_between(
            years, 0, surv,
            where=years <= 5.0,
            color="#272A75", alpha=1, interpolate=True
        )
        ax.fill_between(
            years, 0, surv,
            where=years >= 5.0,
            color="blue", alpha=1, interpolate=True
        )
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 100)
        ax.set_xlabel("Years after surgery")
        ax.set_ylabel("Percentage of women surviving")
        ax.grid(alpha=0.3, linestyle="--")
        # Lay out *this* figure: plt.tight_layout() targets pyplot's current figure,
        # which a concurrent request could have switched.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Close the figure even when drawing fails so pyplot does not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
| # ============================== | |
| # Utils | |
| # ============================== | |
def clip_and_norm(img: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
    """Clip ``img`` to [vmin, vmax] and rescale linearly to [0, 1] as float32.

    A zero or negative window is widened to 1e-6 to avoid division by zero.
    """
    window = max(vmax - vmin, 1e-6)
    scaled = (np.clip(img, vmin, vmax) - vmin) / window
    return scaled.astype(np.float32)
def normalize_dce_pair_all(dce1: np.ndarray, dce2: np.ndarray):
    """Normalise two volumes with one shared scale taken from ``dce2``.

    The scale is the 99th percentile of ``dce2`` (floored at 1e-6). Both volumes
    are clipped to [0, scale] and divided by it, so ``dce2`` lands in [0, 1] and
    ``dce1`` keeps its intensity relative to ``dce2``.

    Returns:
        ``(dce1_norm, dce2_norm)`` as float32 arrays.
    """
    scale = np.percentile(dce2, 99)
    if scale < 1e-6:
        scale = 1e-6
    return tuple(
        (np.clip(volume, 0, scale) / scale).astype(np.float32)
        for volume in (dce1, dce2)
    )
def load_nii(file) -> np.ndarray:
    """Load a NIfTI file and return its data as a float32 array.

    Args:
        file: None, a filesystem path (``str`` / ``os.PathLike``), or an upload
            object exposing its path as ``.name``. Depending on the Gradio version
            and ``gr.File`` type, uploads may arrive as plain path strings.

    Returns:
        The image data as float32, or None when ``file`` is None.
    """
    if file is None:
        return None
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    img = nib.load(path)
    return img.get_fdata().astype(np.float32)
def clean_report_text(raw: str) -> str:
    """Flatten a radiology report and drop its clinical-history preamble.

    Newlines are removed (not replaced by spaces). If "Klinische" occurs and a
    later "Verslag" exists, the text from that "Verslag" onward is returned;
    otherwise the whole flattened text is returned. None becomes "".
    """
    if raw is None:
        return ""
    flat = raw.replace("\n", "")
    clin_pos = flat.find("Klinische")
    if clin_pos == -1:
        return flat
    report_pos = flat.find("Verslag", clin_pos)
    return flat if report_pos == -1 else flat[report_pos:]
def clamp_age_0_100(age_val):
    """Parse an age, truncate it to an int and clamp it to [0, 100].

    Returns:
        The clamped age as a string, or "-1" (the value generate_one_hot treats
        as missing) when the input is not a finite number.
    """
    try:
        v = int(float(age_val))
    except (TypeError, ValueError, OverflowError):
        # Only parse failures mean "unknown"; the old bare ``except`` also
        # swallowed KeyboardInterrupt/SystemExit.
        return "-1"
    return str(max(0, min(100, v)))
| def _ensure_3d(vol: np.ndarray) -> np.ndarray: | |
| if vol is None: | |
| return None | |
| if vol.ndim == 4 and vol.shape[-1] == 1: | |
| vol = vol[..., 0] | |
| return vol | |
def get_largest_mask_slice_idx(mask_3d: np.ndarray) -> int:
    """Return the index (last axis) of the slice with the largest mask sum.

    Falls back to the middle slice when the mask sums to <= 0 everywhere.

    Raises:
        ValueError: if the mask is missing or not 3D after squeezing.
    """
    mask_3d = _ensure_3d(mask_3d)
    if mask_3d is None or mask_3d.ndim != 3:
        got_shape = None if mask_3d is None else mask_3d.shape
        raise ValueError(f"Expected 3D mask, got shape {got_shape}")
    per_slice = mask_3d.sum(axis=(0, 1))
    if per_slice.max() > 0:
        return int(per_slice.argmax())
    return mask_3d.shape[2] // 2
def volume_to_slice_pil(vol: np.ndarray, slice_idx: int, apply_norm: bool = True) -> Image.Image:
    """Render ``vol[:, :, slice_idx]`` as an 8-bit grayscale PIL image.

    Args:
        vol: 3D volume, or 4D with a trailing singleton axis (squeezed first).
        slice_idx: Index along the last axis; clamped into the valid range.
        apply_norm: If True, window the slice to its 1st-99th percentile range.
            If False, values are assumed to already be in [0, 1] and are only clipped.

    Returns:
        A mode "L" image (inferred by Pillow from the 2D uint8 array).

    Raises:
        ValueError: if the volume is missing or not 3D after squeezing.
    """
    vol = _ensure_3d(vol)
    if vol is None or vol.ndim != 3:
        raise ValueError(f"Expected 3D volume, got shape {None if vol is None else vol.shape}")
    depth = vol.shape[2]
    slice_idx = max(0, min(depth - 1, int(slice_idx)))
    slice_2d = vol[:, :, slice_idx].astype(np.float32)
    if apply_norm:
        vmin = np.percentile(slice_2d, 1)
        vmax = np.percentile(slice_2d, 99)
        if vmax <= vmin:
            vmax = vmin + 1.0
        slice_2d = np.clip((slice_2d - vmin) / (vmax - vmin), 0, 1)
    else:
        slice_2d = np.clip(slice_2d, 0, 1)
    # Casting NaN to uint8 is undefined, so NaN voxels are rendered as black instead.
    slice_2d = np.nan_to_num(slice_2d, nan=0.0)
    slice_uint8 = (slice_2d * 255).astype(np.uint8)
    # No explicit mode: it is inferred as "L" for 2D uint8, and passing `mode`
    # to fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(slice_uint8)
| # ============================== | |
| # 3D ResNet-18 (feature -> 512) | |
| # ============================== | |
class BasicBlock3D(nn.Module):
    """ResNet basic block in 3D: two 3x3x3 conv+BN layers plus a shortcut.

    The shortcut is the input itself, or ``downsample(x)`` when provided
    (needed when stride or channel count changes).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)), inplace=True)
        hidden = self.bn2(self.conv2(hidden))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(hidden + shortcut, inplace=True)
class ResNet3D(nn.Module):
    """3D ResNet trunk without a classifier head.

    Takes single-channel volumes (conv1 has 1 input channel) and returns a flat
    feature vector of size ``512 * block.expansion`` per sample.
    """

    def __init__(self, block, layers):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def _make_layer(self, block, planes, blocks, stride=1):
        # Stage = one (possibly strided/projected) block followed by blocks-1 plain ones.
        out_planes = planes * block.expansion
        projection = None
        if stride != 1 or self.in_planes != out_planes:
            projection = nn.Sequential(
                nn.Conv3d(self.in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_planes),
            )
        stage = [block(self.in_planes, planes, stride, projection)]
        self.in_planes = out_planes
        stage.extend(block(self.in_planes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return torch.flatten(self.avgpool(x), 1)
def resnet18_3d_feat():
    """Build the 3D ResNet-18 feature extractor (four stages of two basic blocks)."""
    return ResNet3D(block=BasicBlock3D, layers=[2] * 4)
| # ============================== | |
| # Text encoder | |
| # ============================== | |
class RadioLOGIC_666(torch.nn.Module):
    """RoBERTa report encoder loaded from DEFAULT_MODEL_REPO / DEFAULT_TEXT_SUBFOLDER.

    ``forward`` returns the last hidden state at token position 0, one
    vector per sequence.
    """

    def __init__(self):
        super().__init__()
        self.bert = RobertaModel.from_pretrained(
            DEFAULT_MODEL_REPO,
            subfolder=DEFAULT_TEXT_SUBFOLDER,
            add_pooling_layer=False,
            token=HF_TOKEN,
        )

    def forward(self, input_id=None, attention_mask=None):
        hidden_states = self.bert(
            input_ids=input_id,
            attention_mask=attention_mask,
            return_dict=False,
        )[0]
        return hidden_states[:, 0]
class Report_Net(nn.Module):
    """Report-text branch: a RadioLOGIC_666 encoder with all parameters frozen."""

    def __init__(self):
        super().__init__()
        self.RadioLOGIC = RadioLOGIC_666()
        # Freeze the text encoder (requires_grad=False on every parameter).
        self.RadioLOGIC.requires_grad_(False)

    def forward(self, input_id=None, attention_mask=None):
        return self.RadioLOGIC(input_id=input_id, attention_mask=attention_mask)
| # ============================== | |
| # Attention | |
| # ============================== | |
class Attention(nn.Module):
    """Feature-wise soft gating with a residual path: ``x + x * softmax(MLP(x), dim=1)``."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        gate_mlp = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.attention = nn.Sequential(*gate_mlp)

    def forward(self, x):
        # dim=1 is the feature axis for the 2D (batch, features) input MultiModalModel passes.
        gate = torch.softmax(self.attention(x), dim=1)
        return x * gate + x
| # ============================== | |
| # MultiModalModel | |
| # ============================== | |
class MultiModalModel(nn.Module):
    """Fuse two image embeddings, a report embedding and clinical features into one logit.

    Pipeline: concat [img1 (512), img2 (512), text (768), clinical (512)] -> Attention
    -> prepend a 16-d treatment-prompt embedding -> Linear(2320, 512) -> ReLU -> BN
    -> Dropout -> Linear(512, 1).
    """

    def __init__(self, image_model, report_net, dropout_prob=DROPOUT_PROB, hidden_dim=HIDDEN_DIM):
        super().__init__()
        fused_dim = 512 + 512 + 768 + 512
        prompt_dim = 16
        self.image_branch = image_model  # shared weights for both image inputs
        self.vector_branch = report_net
        self.attention = Attention(input_dim=fused_dim, hidden_dim=hidden_dim)
        self.fc_clin = nn.Linear(25 * 12, 512)  # flattened (25, 12) clinical one-hot matrix
        self.prompt_embedding = nn.Embedding(2, prompt_dim)
        self.fc = nn.Linear(prompt_dim + fused_dim, 512)  # 16 + 2304 = 2320
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(512)
        self.drop = nn.Dropout(p=dropout_prob)
        self.classifier = nn.Linear(512, 1)

    def forward(self, image1, image2, input_id, attention_mask, clinical_feature, prompt_type):
        clinical_feature = clinical_feature.float()
        clinical_feature = clinical_feature.view(clinical_feature.size(0), -1)
        fused = torch.cat(
            [
                self.image_branch(image1),
                self.image_branch(image2),
                self.vector_branch(input_id, attention_mask),
                self.fc_clin(clinical_feature),
            ],
            dim=1,
        )
        fused = self.attention(fused)
        prompted = torch.cat([self.prompt_embedding(prompt_type), fused], dim=1)
        hidden = self.drop(self.bn(self.relu(self.fc(prompted))))
        return self.classifier(hidden)
| # ============================== | |
| # Clinical one-hot | |
| # ============================== | |
def generate_one_hot(T_stage_value, N_stage_value, M_stage_value,
                     T_stage_post_value, N_stage_post_value, M_stage_post_value,
                     family_history, age, E, P, H, tumor_types):
    """Encode clinical variables into the (25, 12) float32 matrix used as clinical input.

    One column per variable; rows are categories:
      col 0-2   pre-treatment T / N / M stage (row = index in the vocabulary below)
      col 3-5   post-treatment T / N / M stage
      col 6     row 0 set when ``family_history`` contains "kanker"
      col 7     age bucket: row = age // 4 (age clamped to 0..100, row capped at 24)
      col 8, 9  ER / PR: row 1 if 0 < value / 4 < 25, else row 0; all-zero for "-1"
      col 10    HER2: row 1 if 2 < value < 8, else row 0; all-zero for "-1"
      col 11    tumor-type keywords (rows 0-8); row 9 when no keyword matched

    Missing values ("-1", "-", "None", None) and values outside a vocabulary
    leave their column all-zero.
    """
    one_hot_matrix = np.zeros((25, 12), dtype=np.float32)
    # NOTE(review): the UI stage dropdowns offer N '2', post-treatment T '4' and
    # post-treatment N '3', none of which appear in these vocabularies, so those
    # selections encode as an all-zero column. Presumably the vocabularies mirror
    # the training encoding -- confirm before changing either side.
    T_stage = ['0', '1', '1A', '1B', '1C', '1M', '1MI', '2', '3', '4', '4A', '4B', '4D', 'IS', 'X']
    N_stage = ['0', '0IS', '0S', '1', '1MS', '1S', '1BS', '2A', '2B', '3', '3A', '3B', '3BS', '3C', 'X']
    M_stage = ['0', '1', 'X']
    T_stage_post = ['0', '1', '1A', '1B', '1C', '1MI', '2', '3', '4B', 'IS', 'X', 'Y0', 'Y1', 'Y1A', 'Y1B',
                    'Y1C', 'Y1MI', 'Y2', 'Y3', 'Y4A', 'Y4B', 'Y4D', 'YIS', 'YX']
    N_stage_post = ['0', '0I', '0IS', '0S', '1', '1A', '1AS', '1B', '1BS', '1B1', '1B2', '1B3', '1B4',
                    '1M', '1MI', '1MS', '2', '2A', '2B', '3A', '3B', '3C', 'X']
    M_stage_post = ['0', '1', 'X']

    def _set(lst, val, col):
        # Mark the vocabulary row of `val` in column `col`; unknown/missing values are skipped.
        if val not in ["-1", "-", "None", None]:
            s = str(val).strip()
            if s in lst:
                one_hot_matrix[lst.index(s), col] = 1

    _set(T_stage, T_stage_value, 0)
    _set(N_stage, N_stage_value, 1)
    _set(M_stage, M_stage_value, 2)
    _set(T_stage_post, T_stage_post_value, 3)
    _set(N_stage_post, N_stage_post_value, 4)
    _set(M_stage_post, M_stage_post_value, 5)

    if isinstance(family_history, str) and 'kanker' in family_history.lower():
        one_hot_matrix[0, 6] = 1

    try:
        if age not in ["-1", None, "None", ""]:
            agei = int(float(age))
            agei = max(0, min(100, agei))
            idx = max(0, min(24, agei // 4))
            one_hot_matrix[idx, 7] = 1
    except (TypeError, ValueError, OverflowError):
        # Unparseable age: leave the age column empty, same as "unknown".
        pass

    def _ep_handle(v, col):
        # ER/PR receptor value -> row 1 (0 < v/4 < 25) or row 0; "-1" leaves the column empty.
        if v == "-1":
            one_hot_matrix[:, col] = 0
        else:
            try:
                vi = int(v)
                if 0 < vi / 4 < 25:
                    one_hot_matrix[1, col] = 1
                else:
                    one_hot_matrix[0, col] = 1
            except (TypeError, ValueError):
                pass

    # Anything other than "90" or "-1" is treated as "0" (negative).
    _ep_handle("90" if E == "90" else ("-1" if E == "-1" else "0"), 8)
    _ep_handle("90" if P == "90" else ("-1" if P == "-1" else "0"), 9)

    if H == "-1":
        one_hot_matrix[:, 10] = 0
    else:
        try:
            hi = int(H)
            if 2 < hi < 8:
                one_hot_matrix[1, 10] = 1
            else:
                one_hot_matrix[0, 10] = 1
        except (TypeError, ValueError, OverflowError):
            # Unparseable HER2 value: leave the column empty.
            pass

    if tumor_types is not None:
        tt = tumor_types.lower()
        if ('ductaal' in tt) and ('infiltrerend ductaal' not in tt) and \
                ('intraductaal carcinoom' not in tt) and ('ductaal carcinoma in situ' not in tt):
            one_hot_matrix[0, 11] = 1
        if ('infiltrerend ductaal' in tt) and ('intraductaal carcinoom' not in tt):
            one_hot_matrix[1, 11] = 1
        if ('lobulair' in tt) and ('infiltrerend lobulair' not in tt):
            one_hot_matrix[2, 11] = 1
        if 'infiltrerend lobulair' in tt:
            one_hot_matrix[3, 11] = 1
        if 'tubular' in tt:
            one_hot_matrix[4, 11] = 1
        if 'mucineus' in tt:
            one_hot_matrix[5, 11] = 1
        if 'micropapillair' in tt:
            one_hot_matrix[6, 11] = 1
        if ('papillair' in tt) and ('micropapillair' not in tt) and \
                ('intraductaal papillair adenocarcinoom' not in tt):
            one_hot_matrix[7, 11] = 1
        if ('ductaal carcinoma in situ' in tt) or ('intraductaal carcinoom' in tt) or \
                ('intraductaal papillair adenocarcinoom' in tt):
            one_hot_matrix[8, 11] = 1
        # No keyword matched: fall back to the "other" row.
        if np.sum(one_hot_matrix[:, 11]) == 0:
            one_hot_matrix[9, 11] = 1
    return one_hot_matrix
| # ============================== | |
| # Global tokenizer / model | |
| # ============================== | |
# Tokenizer matching the RoBERTa report encoder; loaded once at import time.
print("[INIT] Loading tokenizer...")
GLOBAL_TOKENIZER = AutoTokenizer.from_pretrained(
    DEFAULT_MODEL_REPO,
    subfolder=DEFAULT_TEXT_SUBFOLDER,
    token=HF_TOKEN,
    use_fast=True,
)
def _build_model() -> "MultiModalModel":
    """Assemble the untrained fusion model: 3D ResNet-18 image branch + frozen report encoder."""
    return MultiModalModel(
        image_model=resnet18_3d_feat(),
        report_net=Report_Net(),
        dropout_prob=DROPOUT_PROB,
        hidden_dim=HIDDEN_DIM,
    )
| def _filter_state_dict_for_strict_load(model: nn.Module, state: dict) -> dict: | |
| model_sd = model.state_dict() | |
| filtered = {} | |
| drop_cnt = 0 | |
| for k, v in state.items(): | |
| if k in model_sd and model_sd[k].shape == v.shape: | |
| filtered[k] = v | |
| else: | |
| drop_cnt += 1 | |
| if drop_cnt: | |
| print(f"[INIT] Dropped {drop_cnt} unmatched ckpt keys (e.g., old heads like image_branch.fc.*).") | |
| missing = [k for k in model_sd.keys() if k not in filtered] | |
| if missing: | |
| print("[WARN] Missing keys after filtering (showing first 10):", missing[:10]) | |
| return filtered | |
# Build the fusion model once at import time and load the trained checkpoint from the Hub.
print("[INIT] Building model and loading weights...")
GLOBAL_MODEL = _build_model().to(DEVICE).eval()
ckpt_path = hf_hub_download(
    repo_id=DEFAULT_MODEL_REPO,
    filename=DEFAULT_CKPT_FILENAME,
    repo_type="model",
    token=HF_TOKEN
)
# NOTE(review): torch.load unpickles the file. It comes from this project's own Hub repo;
# if the checkpoint is a plain tensor dict, consider passing weights_only=True explicitly.
raw_state = torch.load(ckpt_path, map_location=DEVICE)
# Accept both a bare state dict and a checkpoint that wraps it under "state_dict".
if isinstance(raw_state, dict) and "state_dict" in raw_state:
    raw_state = raw_state["state_dict"]
# Strip only a *leading* "module." (typically added by DataParallel/DDP wrapping).
# str.replace would also cut "module." out of the middle of a key ("submodule.w" -> "subw").
raw_state = {k.removeprefix("module."): v for k, v in raw_state.items()}
state = _filter_state_dict_for_strict_load(GLOBAL_MODEL, raw_state)
# strict=True still raises if any model parameter is absent from the filtered checkpoint.
GLOBAL_MODEL.load_state_dict(state, strict=True)
print("[INIT] Model weights loaded with strict=True (after key filtering)")
| # ============================== | |
| # UI helper | |
| # ============================== | |
# UI tumor-type label -> keyword that generate_one_hot() looks for in the tumor text.
_TUMOR_TYPE_TOKENS = {
    "Invasive ductal carcinoma": "infiltrerend ductaal",
    "Invasive lobular carcinoma": "infiltrerend lobulair",
    "Lobular carcinoma (IS)": "lobulair",
    "DCIS": "ductaal carcinoma in situ",
}


def map_ui_to_internal(family_history_opt, er_opt, pr_opt, her2_opt, tumor_type_opts, treatment_opt):
    """Translate UI selections into the string codes generate_one_hot() expects.

    Args:
        family_history_opt: "Yes" -> "kanker", anything else -> "NA".
        er_opt, pr_opt: "Positive" -> "90", "Negative" -> "0", else "-1" (unknown).
        her2_opt: "Positive" -> "3", "Negative" -> "1", else "-1" (unknown).
        tumor_type_opts: one label or a list/tuple of labels; unknown labels map to "others".
        treatment_opt: "Neoadjuvant" -> primary "NA", anything else -> "Surgery".

    Returns:
        ``(fam, E, P, H, tumor, primary)``; ``tumor`` joins the mapped keywords with "; ".
    """
    fam = "kanker" if (family_history_opt == "Yes") else "NA"

    def _map_er_pr(v):
        if v == "Positive":
            return "90"
        if v == "Negative":
            return "0"
        return "-1"

    E = _map_er_pr(er_opt)
    P = _map_er_pr(pr_opt)
    if her2_opt == "Positive":
        H = "3"
    elif her2_opt == "Negative":
        H = "1"
    else:
        H = "-1"

    def _tumor_token(opt):
        # Non-string selections (e.g. None) can never match a label.
        return _TUMOR_TYPE_TOKENS.get(opt, "others") if isinstance(opt, str) else "others"

    # One mapping for both shapes: a multi-select list/tuple or a single value.
    opts = tumor_type_opts if isinstance(tumor_type_opts, (list, tuple)) else [tumor_type_opts]
    tumor = "; ".join(_tumor_token(opt) for opt in opts)
    primary = "NA" if (treatment_opt == "Neoadjuvant") else "Surgery"
    return fam, E, P, H, tumor, primary
| # ============================== | |
| # Inference | |
| # ============================== | |
def run_infer(
    has_metastasis,
    img1_file, img2_file, mask_file,
    T_stage, N_stage, M_stage, T_stage_post, N_stage_post, M_stage_post,
    family_history_opt, age_val, er_opt, pr_opt, her2_opt, tumor_type_opts, treatment_opt,
    neo_types,
    report_text
):
    """Run one end-to-end prediction from the UI inputs.

    Loads the pre/post MRI volumes and the mask, builds the image, report-text
    and clinical inputs, and scores them with GLOBAL_MODEL. The score is ranked
    against SCORES and the percentile is turned into a risk group and a 0-10
    year survival curve. ``neo_types`` is accepted for the UI but not used.

    Returns:
        A 7-tuple ``(status_text, pred, km_image, survival_text, suggestion_text,
        pre_slice_image, post_slice_image)``. On an early exit, the first slot
        holds the message and the other six are None.
    """
    def _early_err(msg):
        # Same 7-slot shape as the success path so the UI output bindings line up.
        return msg, None, None, None, None, None, None

    if has_metastasis == "Yes":
        return _early_err(
            "This prediction tool should not be used for patients with distant metastases; "
            "such patients are already considered a high-risk group."
        )
    if img1_file is None or img2_file is None or mask_file is None:
        return _early_err("❌ Please upload Pre MRI, Post MRI, and Mask.")
    try:
        img1 = load_nii(img1_file)
        img2 = load_nii(img2_file)
        msk = load_nii(mask_file)
    except Exception as e:
        return _early_err(f"❌ Failed to load NIfTI files: {e}")
    if img1 is None or img2 is None or msk is None:
        return _early_err("❌ Failed to load NIfTI files.")

    # Squeeze a trailing singleton axis on all three volumes. Previously only the
    # mask was squeezed, so an (H, W, D, 1) image broke the mask product below.
    img1 = _ensure_3d(img1)
    img2 = _ensure_3d(img2)
    mask_3d = _ensure_3d(msk.astype(np.float32))
    if mask_3d.ndim != 3 or img1.shape != mask_3d.shape or img2.shape != mask_3d.shape:
        return _early_err(
            "❌ Pre MRI, Post MRI and Mask must be 3D volumes of the same shape "
            f"(got {img1.shape}, {img2.shape} and {mask_3d.shape})."
        )

    # Preview images: the slice with the largest mask area, from the jointly normalised pair.
    largest_idx = get_largest_mask_slice_idx(mask_3d)
    img1_norm, img2_norm = normalize_dce_pair_all(img1, img2)
    pre_img_pil = volume_to_slice_pil(img1_norm, largest_idx, apply_norm=False)
    post_img_pil = volume_to_slice_pil(img2_norm, largest_idx, apply_norm=False)

    # Emphasise the lesion: voxels labelled 1 are scaled x10, all others x1.
    mask_modified = np.where(mask_3d == 1, 10.0, 1.0).astype(np.float32)
    # Add batch and channel axes -> (1, 1, H, W, D).
    image1 = torch.from_numpy((img1_norm * mask_modified)[np.newaxis, np.newaxis]).to(DEVICE)
    image2 = torch.from_numpy((img2_norm * mask_modified)[np.newaxis, np.newaxis]).to(DEVICE)

    clean_txt = clean_report_text(report_text or "")
    code = GLOBAL_TOKENIZER(
        clean_txt,
        padding='max_length',
        max_length=512,
        truncation=True,
        return_tensors="pt"
    )
    input_id = code['input_ids'].to(DEVICE)
    attn_mask = code['attention_mask'].to(DEVICE)

    fam, E, P, H, tumor, primary = map_ui_to_internal(
        family_history_opt, er_opt, pr_opt, her2_opt, tumor_type_opts, treatment_opt
    )
    age = clamp_age_0_100(age_val)
    clin_np = generate_one_hot(
        T_stage, N_stage, M_stage, T_stage_post, N_stage_post, M_stage_post,
        fam, age, E, P, H, tumor
    )
    clin = torch.from_numpy(clin_np).unsqueeze(0).to(DEVICE)
    # Prompt index 0 when map_ui_to_internal reports neoadjuvant treatment (primary == "NA"), else 1.
    prompt_type_indices = torch.tensor(
        [0 if primary == "NA" else 1],
        dtype=torch.long,
        device=DEVICE
    )

    # Inference only: without this, autograd would record the trainable 3D image
    # branch and keep every activation alive for a backward pass that never happens.
    with torch.inference_mode():
        pred_logit = GLOBAL_MODEL(
            image1, image2, input_id, attn_mask, clin, prompt_type_indices
        )
    pred = float(pred_logit.squeeze().item())

    pct, _rank, _total = percentile_from_scores(pred, SCORES)
    s5, risk = interpolate_s5_from_percentile(pct)
    status_text = f"• Percentile: {round(pct, 1)}% → {risk}"
    suggest_text = CLINICAL_SUGGESTION.get(risk, "Regular follow-up is recommended.")

    years_curve, surv_curve = _make_baseline_curve_0_10y(
        s5_percent=s5,
        alpha=2.0,
        late_slowdown=0.35
    )
    s3_plot = float(np.interp(3.0, years_curve, surv_curve))
    s5_plot = float(np.interp(5.0, years_curve, surv_curve))
    s10_plot = float(np.interp(10.0, years_curve, surv_curve))
    km_img_pil = make_dynamic_curve_image(
        s5_percent=s5,
        alpha=2.0,
        late_slowdown=0.35
    )
    surv_text = (
        f' • 3-year: {s3_plot:.0f}% '
        f' • 5-year: {s5_plot:.0f}% '
        f' • 10-year: {s10_plot:.0f}% '
    )
    return status_text, pred, km_img_pil, surv_text, suggest_text, pre_img_pil, post_img_pil
| # ------------------------------------------------- | |
| # Examples handling: WITHOUT metastasis column | |
| # ------------------------------------------------- | |
def _download_example_to_tmp(fname: str) -> str:
    """Fetch ``examples/<fname>`` from the model repo and copy it to /tmp/examples.

    Returns:
        Path of the copy. The copy presumably gives the UI a stable path outside
        the Hub cache -- confirm against the Gradio file-serving settings.
    """
    cached_path = hf_hub_download(
        repo_id=DEFAULT_MODEL_REPO,
        filename=f"examples/{fname}",
        repo_type="model",
        token=HF_TOKEN,
    )
    target_dir = "/tmp/examples"
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, fname)
    shutil.copy2(cached_path, target_path)
    return target_path
def get_examples_rows():
    """Download the three bundled example cases and return them as UI example rows.

    Each row holds, in order: pre MRI path, post MRI path, mask path, six
    T/N/M stage values (pre-treatment then post-treatment), family history,
    age, ER, PR, HER2, tumor types, treatment, neoadjuvant types and report
    text. There is no metastasis column.
    """
    def _case_files(case_id):
        # Pre (_1), post (_2) and mask volumes for one case, downloaded in that order.
        return [_download_example_to_tmp(f"{case_id}_{part}.nii.gz") for part in ("1", "2", "mask")]

    case1 = _case_files("B000013668_20110715")
    case2 = _case_files("B000019708_20170620")
    case3 = _case_files("B000026169_20190507")

    row1 = [
        *case1,
        "4", "3", "0", "4", "3", "0",
        "No", 54, "Positive", "Negative", "Negative",
        ["Invasive ductal carcinoma"], "Neoadjuvant",
        ["Neoadjuvant chemotherapy"],
        "Categorie: Poliklinisch order Afdeling: \nOnderzoeksdatum: 15-07-2011 \nOnderzoek(en): MRI mamma neo adj.chemo Typist(e): \n\nKlinische Gegevens:\nMammacarcinoom T4N1-2 li, ulcererend; screening voor neoadjuvante chemotherapie\n\nVerslag.\nRechts lateraalboven vrij scherp omschreven aankleurende structuur van ongeveer 6 mm. BIRADS-3. Advies: gerichte echo.\n\nLinks zeer uitgebreide in de huid infiltrerende tumor met ringvormige aankleuring en uitwas.\nBekende pathologische okselklieren.",
    ]
    row2 = [
        *case2,
        "X", "0", "0", "1", "0", "0",
        "No", 40, "Positive", "Positive", "Negative",
        ["Invasive ductal carcinoma"], "Surgery",
        [],
        "Klinische Gegevens:\nPalpabele tumor, DCIS graad 2 lateraal bovenkwadrant linker mamma, 2-3cm op echografie\n\nVraagstelling:\nGrootte afwijking/DCIS links. Contralateraal? Mammasparende therapie mogelijk? \n\n\nVerslag: MRI mamma:\nEerste onderzoek. Ter correlatie mammografie en echo 19 juni 2017.\nBeiderzijds heterogeen dens klierweefsel waarin verspreid vlekkige parenchym achtergrond aankleuring, met doorkleuring in de late fase. \nDaarnaast in de linker mamma diffuus, aankleurend massagebied van nagenoeg gehele laterale onderkwadrant (tussen 3 en 5 uur) over ruim 11 cm, te vervolgen tot aan de tepelknop.\nDeze toont een wat confluerend nodulair aspect. \nIn de late fase gemengd plateau en uitwas over ruim 11 cm.\nLateraal in het tumorgebied (op 3 - 4 uur) een kleine marker, alwaar eerder echogeleid een histologisch bewezen maligne palpabele massa gevonden werd (ca 3cm, PA:DCIS graad 2, met strikmarker, Zaandam 30 mei 2017).\nMogelijk is deze herkenbaar als een wat massiever aankleurend gebiedje van ca. 2,5 cm rond de marker, binnen het veel grotere pathologisch aankleurende tumorproces.\nVerder beiderzijds geen bijzonderheden.\nIn oksel beiderzijds onverdachte reactieve lymfeklieren\n\nConclusie:\n1.Het histologisch bewezen en gemarkeerde palpabele tumormassa lateraal mamma links (BIRADS 6)\nblijkt onderdeel van een groot diffuus aankleurend tumorproces van het LOQ mamma links van ruim 11 cm (BIRADS 5).\n2.Geen axillaire lymfadenopathie.",
    ]
    row3 = [
        *case3,
        "1", "0", "0", "1", "0", "0",
        "No", 46, "Positive", "Positive", "Negative",
        ["Invasive ductal carcinoma", "Lobular carcinoma (IS)"], "Surgery",
        [],
        "Klinische Gegevens:\nZie digitaal mammapoli formulier in EZIS of BOB-formulier.\n\ncT1cN0 mammacarcinoom rechts LBK\n\nVraagstelling:\nafmeting? \n\n\nVerslag:\nCorrelatie met echografie 06/05/2019.\nMRI mamma: Irregulaire cyclus.\nGedeformeerd aspect van de rechtermamma na recente biopsie.\nCompressie gazen nog in situ. Impressie van de tepelregio.\nZeer veel fibroglandulair weefsel beiderzijds。\nMatige symmetrische achtergrondaankleuring。\nBekende maligniteit rechts centraal craniaal (ca. 14 mm)。\nSnelle contrastopname。In de late fase uitwas - plateau。Geringe diffusie restrictie。\nPeritumoraal geringe non-mass aankleuring waarschijnlijk postbiopsie。Geen andere pathologisch aankleurende afwijkingen rechts。\nRechts axillair geen pathologisch vergrote lymfeklieren。\nIn de linkermamma, linkeraxilla en parasternaal geen pathologie。\nConclusie:\n- suboptimale scan vanwege compressie rechtermamma na recente biopsie。\n- unifocale maligniteit rechts ca. 14 mm。BIRADS 4。",
    ]
    return [row1, row2, row3]
| # ============================== | |
| # Gradio UI | |
| # ============================== | |
# Stage choices offered by the UI dropdowns.
# NOTE(review): generate_one_hot's vocabularies lack N '2' (pre), T '4' (post) and
# N '3' (post), so those choices encode as an empty column -- confirm intended.
STAGES_T = ['0', '1', '2', '3', '4', 'IS', 'X']
STAGES_N = ['0', '1', '2', '3', 'X']
STAGES_M = ['0', '1', 'X']
# Post-treatment dropdowns offer the same choices (as independent list objects).
STAGES_T_POST = list(STAGES_T)
STAGES_N_POST = list(STAGES_N)
STAGES_M_POST = list(STAGES_M)
def _on_metastasis_change(choice):
    """React to the metastasis radio button.

    "Yes" shows the not-applicable message and hides the main panel and both
    result boxes; any other choice clears the message and shows them again.

    Returns:
        Four gr.update objects: (status text, main_panel, surv5_box, suggest_box).
    """
    blocked = choice == "Yes"
    if blocked:
        status = (
            "This prediction tool should not be used for patients with distant metastases; "
            "such patients are already considered a high-risk group."
        )
    else:
        status = ""
    show_panels = not blocked
    return (
        gr.update(value=status),  # Status text
        gr.update(visible=show_panels),  # main_panel
        gr.update(visible=show_panels),  # surv5_box
        gr.update(visible=show_panels),  # suggest_box
    )
# Build the whole page. Gradio places components in the order they are created
# inside each `with` context, so statement order below defines the layout.
with gr.Blocks(title="Breast Cancer Prognosis (4-Modal)", theme=gr.themes.Base()) as demo:
    # Header: project logo, title and contact details (static HTML with inline CSS).
    gr.HTML(
        """
<style>
.ai-beacon-header {
display:flex; align-items:center; gap:18px; flex-wrap:wrap;
margin:10px 0 8px;
}
.ai-beacon-header img { height:160px; max-width:100%; object-fit:contain; }
.ai-beacon-title { font-weight:700; font-size:clamp(20px, 2.4vw, 28px); line-height:1.25; margin:0; }
.ai-beacon-sub { margin-top:6px; font-size:14px; }
</style>
<div class="ai-beacon-header">
<img src="https://huggingface.co/spaces/zhang0319/Multimodal_Surv/resolve/main/AI-BEACON_Hugging_face.png" alt="AI-BEACON" />
<div>
<div class="ai-beacon-title">Breast Cancer Prognosis (MRI + Report + Clinical + Treatment)</div>
<div class="ai-beacon-sub">
📧 Contact us: Dr. Tianyu Zhang
<a href="mailto:Tianyu.Zhang@radboudumc.nl">Tianyu.Zhang@radboudumc.nl</a>,
Dr. Ritse Mann <a href="mailto:Ritse.Mann@radboudumc.nl">Ritse.Mann@radboudumc.nl</a>
</div>
</div>
</div>
<hr style="border:none;border-top:1px solid #eee;margin:8px 0 14px;" />
"""
    )
    # Screening question. Answering "Yes" writes a warning into `status` and hides
    # main_panel, surv5_box and suggest_box (wired up via metastasis.change below).
    metastasis = gr.Radio(
        choices=["No", "Yes"],
        value="No",
        label="Has the patient already developed distant metastases (M1)?"
    )
    # NOTE(review): the nesting of this section was reconstructed; confirm which
    # components belong inside main_panel, since the group is hidden as a whole.
    with gr.Group(visible=True) as main_panel:
        # NOTE(review): this text promises a "prediction score (pred)", but the
        # pred_num output below is created with visible=False.
        gr.Markdown(
            "Upload **Pre/Post MRI and Mask**, paste the **report text**, choose **clinical options**, "
            "then get a single prediction score (pred)."
        )
        with gr.Row():
            # Left column: imaging uploads and the free-text radiology report.
            with gr.Column():
                img1 = gr.File(label="Pre MRI (NIfTI .nii/.nii.gz)")
                img2 = gr.File(label="Post MRI (NIfTI .nii/.nii.gz)")
                msk = gr.File(label="Mask (NIfTI .nii/.nii.gz)")
                report = gr.Textbox(
                    label="Report Text",
                    lines=8,
                    placeholder="Paste the full report…"
                )
            # Right column: TNM staging, then other clinical and treatment inputs.
            with gr.Column():
                gr.Markdown("### Pre-treatment staging")
                T = gr.Dropdown(
                    choices=STAGES_T, value="0",
                    label="T_stage", interactive=True
                )
                N = gr.Dropdown(
                    choices=STAGES_N, value="0",
                    label="N_stage", interactive=True
                )
                # Locked at "0" (interactive=False); presumably because M1
                # patients are screened out by the metastasis question above.
                M = gr.Dropdown(
                    choices=STAGES_M, value="0",
                    label="M_stage",
                    interactive=False
                )
                gr.Markdown("### Post-treatment staging")
                T_post = gr.Dropdown(
                    choices=STAGES_T_POST, value="0",
                    label="T_stage_post", interactive=True
                )
                N_post = gr.Dropdown(
                    choices=STAGES_N_POST, value="0",
                    label="N_stage_post", interactive=True
                )
                # Locked at "0" like M_stage above.
                M_post = gr.Dropdown(
                    choices=STAGES_M_POST, value="0",
                    label="M_stage_post",
                    interactive=False
                )
                gr.Markdown("### Other clinical variables")
                fam = gr.Radio(
                    choices=["No", "Yes"],
                    value="No",
                    label="Family history"
                )
                # Integer age (precision=0) within the bounds given in the label.
                age = gr.Number(
                    label="Age (0–100)",
                    value=55,
                    minimum=0,
                    maximum=100,
                    precision=0
                )
                er = gr.Radio(
                    choices=["Negative", "Positive", "Unknown"],
                    value="Negative",
                    label="ER"
                )
                pr = gr.Radio(
                    choices=["Negative", "Positive", "Unknown"],
                    value="Negative",
                    label="PR"
                )
                her2 = gr.Radio(
                    choices=["Negative", "Positive", "Unknown"],
                    value="Negative",
                    label="HER2"
                )
                # Multiple tumor types may be selected at once.
                tumor = gr.CheckboxGroup(
                    choices=[
                        "Invasive ductal carcinoma",
                        "Invasive lobular carcinoma",
                        "Lobular carcinoma (IS)",
                        "DCIS",
                        "Others"
                    ],
                    value=["Invasive ductal carcinoma"],
                    label="Tumor type(s)"
                )
                treat = gr.Radio(
                    choices=["Neoadjuvant", "Surgery"],
                    value="Surgery",
                    label="Treatment choice"
                )
                # Hidden initially (the default treatment is "Surgery"); shown by
                # _toggle_neo when "Neoadjuvant" is selected.
                neo = gr.CheckboxGroup(
                    choices=NEO_OPTIONS,
                    value=["Neoadjuvant chemotherapy"],
                    label="Neoadjuvant regimen(s)",
                    visible=False
                )

        def _toggle_neo(treatment_choice):
            """Show the regimen selector only when "Neoadjuvant" is chosen.

            Only visibility is updated; the current selection in `neo` is left
            unchanged and is still sent to run_infer when "Surgery" is chosen.
            """
            # NOTE(review): confirm run_infer ignores `neo` when treat == "Surgery".
            return gr.update(visible=(treatment_choice == "Neoadjuvant"))

        treat.change(
            fn=_toggle_neo,
            inputs=treat,
            outputs=neo
        )
        # Clickable examples. Each row from get_examples_rows() fills these inputs
        # positionally, so the row's value order must match this list.
        # `metastasis` is not part of it.
        examples_rows = get_examples_rows()
        gr.Examples(
            examples=examples_rows,
            inputs=[
                img1, img2, msk,
                T, N, M, T_post, N_post, M_post,
                fam, age, er, pr, her2, tumor, treat, neo,
                report
            ],
            cache_examples=False,
            label="Click an example to fill inputs, then press Run Inference"
        )
        run_btn = gr.Button("Run Inference", variant="primary")
        # Raw model output. Kept as a (hidden) output target of run_btn.click below.
        pred_num = gr.Number(
            label="Prediction (pred)",
            interactive=False,
            visible=False,
            precision=4
        )
        # Output figures: survival curve (wider, left) and MRI previews (right).
        with gr.Row():
            with gr.Column(scale=2):
                km_img = gr.Image(
                    label="Dynamic survival curve (0–10y)",
                    interactive=False,
                    height=400
                )
            with gr.Column(scale=1):
                pre_img_view = gr.Image(
                    label="Pre MRI",
                    interactive=False,
                    height=200
                )
                post_img_view = gr.Image(
                    label="Post MRI",
                    interactive=False,
                    height=200
                )
    # —— Results area, top to bottom: figure → Status → Estimated → Suggestion ——
    # `status` must stay visible to show the metastasis warning, while surv5_box
    # and suggest_box are hidden explicitly by _on_metastasis_change.
    status = gr.Textbox(label="Status / Result", interactive=False)
    surv5_box = gr.Textbox(
        label="Estimated 3/5/10-year Survival",
        interactive=False
    )
    suggest_box = gr.Textbox(
        label="Clinical Suggestion / Recommendation",
        interactive=False
    )
    # Bind the metastasis toggle logic. The outputs order matches the 4-tuple
    # returned by _on_metastasis_change.
    metastasis.change(
        fn=_on_metastasis_change,
        inputs=metastasis,
        outputs=[status, main_panel, surv5_box, suggest_box]
    )
    # Run button. Inputs start with `metastasis`, followed by the same inputs as
    # the Examples list. run_infer's return values map positionally onto `outputs`.
    run_btn.click(
        fn=run_infer,
        inputs=[
            metastasis,
            img1, img2, msk,
            T, N, M, T_post, N_post, M_post,
            fam, age, er, pr, her2, tumor, treat, neo,
            report
        ],
        outputs=[
            status, pred_num, km_img, surv5_box, suggest_box,
            pre_img_view, post_img_view
        ]
    )
    # Footer: partner-institution logos (static HTML with responsive CSS).
    gr.HTML(
        """
<style>
.logo-container {
margin-top: 30px;
display: flex;
justify-content: center;
align-items: center;
gap: 40px;
flex-wrap: wrap;
}
.logo-container img {
height: 80px;
max-width: 40%;
object-fit: contain;
}
@media (max-width: 600px) {
.logo-container {
gap: 20px;
}
.logo-container img {
height: 60px;
max-width: 45%;
}
}
</style>
<div class="logo-container">
<img src="https://huggingface.co/spaces/zhang0319/Multimodal_Surv/resolve/main/BIG.png"
alt="BIG logo" />
<img src="https://huggingface.co/spaces/zhang0319/Multimodal_Surv/resolve/main/RUMC.png"
alt="RUMC logo" />
<img src="https://huggingface.co/spaces/zhang0319/Multimodal_Surv/resolve/main/NKI-Logo.png"
alt="NKI logo" />
</div>
"""
    )
if __name__ == "__main__":
    # Start the Gradio app with launch() defaults when the module is run directly.
    demo.launch()