Spaces:
Sleeping
Sleeping
| # app.py | |
import functools
import os

import gradio as gr
import nibabel as nib
import numpy as np
import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, RobertaModel
# ------------------------------------------------------------------
# Model locations on the Hugging Face Hub (fixed; not exposed in the UI)
# ------------------------------------------------------------------
DEFAULT_REPO_ID = "zhang0319/Multimodel_Surv"
DEFAULT_TEXT_SUBFOLDER = "models/radiobert_BigDataset_epoch10"
DEFAULT_CKPT_FILENAME = "weights/20251003_dropout0.3_best_image_report_clin_model20251003_8__.pth"

# ------------------------------------------------------------------
# Compute device and automatic mixed precision settings
# ------------------------------------------------------------------
_CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _CUDA_AVAILABLE else torch.device("cpu")
# AMP is enabled only on CUDA; prefer bfloat16 when the GPU supports it.
USE_AMP = _CUDA_AVAILABLE
if _CUDA_AVAILABLE and torch.cuda.is_bf16_supported():
    AMP_DTYPE = torch.bfloat16
else:
    AMP_DTYPE = torch.float16
| # ============================== | |
| # Utils | |
| # ============================== | |
def clip_and_norm(img: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
    """Clip ``img`` to [vmin, vmax] and rescale it linearly to [0, 1] as float32.

    The divisor is floored at 1e-6 so a degenerate window (vmax <= vmin)
    does not divide by zero.
    """
    span = max(vmax - vmin, 1e-6)
    clipped = np.clip(img, vmin, vmax)
    return ((clipped - vmin) / span).astype(np.float32)
| def load_nii(file) -> np.ndarray | None: | |
| """ | |
| 兼容 Gradio 4/5:File 可能是字符串路径、{'path': ...} 字典,或旧式带 .name 的对象。 | |
| """ | |
| if file is None: | |
| return None | |
| if isinstance(file, (str, os.PathLike)): | |
| path = str(file) | |
| elif isinstance(file, dict) and "path" in file: | |
| path = file["path"] | |
| else: | |
| path = getattr(file, "name", None) | |
| if not path or not os.path.exists(path): | |
| return None | |
| img = nib.load(path) | |
| return img.get_fdata().astype(np.float32) | |
def clean_report_text(raw: str) -> str:
    """Flatten a report to one line and drop any header before the report body.

    All newlines are removed (not replaced by spaces). If the text contains
    "Klinische" followed later by "Verslag", everything before that "Verslag"
    is discarded; otherwise the flattened text is returned unchanged.
    None yields an empty string.
    """
    if raw is None:
        return ""
    flat = raw.replace("\n", "")
    klinische_at = flat.find("Klinische")
    if klinische_at == -1:
        return flat
    verslag_at = flat.find("Verslag", klinische_at)
    return flat if verslag_at == -1 else flat[verslag_at:]
def clamp_age_0_100(age_val):
    """Convert an age input into a string integer clamped to [0, 100].

    Accepts anything ``float()`` understands (int, float, numeric string);
    fractional ages are truncated toward zero by ``int``.

    Returns:
        The clamped age as a string, or ``"-1"`` (the value
        ``generate_one_hot`` treats as "age missing") when the input is None,
        non-numeric, NaN, or infinite.
    """
    try:
        v = int(float(age_val))
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / non-numeric type; ValueError: bad string or NaN;
        # OverflowError: +/-inf cannot be converted to int.
        return "-1"
    return str(max(0, min(100, v)))
| # ============================== | |
| # Report encoder | |
| # ============================== | |
class RadioLOGIC_666(torch.nn.Module):
    """RoBERTa report encoder that returns the first-token hidden state."""

    def __init__(self, repo_id: str, subfolder: str = "", add_pooling_layer: bool = False):
        super().__init__()
        # Weights come from transformers' from_pretrained (Hub repo id or local path).
        self.bert = RobertaModel.from_pretrained(
            repo_id,
            subfolder=subfolder,
            add_pooling_layer=add_pooling_layer,
        )

    def forward(self, input_id=None, attention_mask=None):
        # With return_dict=False the model returns a tuple whose first item is
        # the last hidden layer; take position 0 of the sequence axis.
        hidden_states, *_ = self.bert(
            input_ids=input_id,
            attention_mask=attention_mask,
            return_dict=False,
        )
        return hidden_states[:, 0]  # (B, 768)
| # ============================== | |
| # 3D ResNet-18 (feature -> 512) | |
| # ============================== | |
class BasicBlock3D(nn.Module):
    """3-D ResNet "basic" residual block.

    Main path: conv3x3x3(stride) -> BN -> ReLU -> conv3x3x3 -> BN. The result is
    added to the input (or to ``downsample(x)`` when one is supplied) and passed
    through a final ReLU.
    """

    expansion = 1  # output channels = planes * expansion

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names form the checkpoint's state_dict keys; keep them stable.
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        h = F.relu(self.bn1(self.conv1(x)), inplace=True)
        h = self.bn2(self.conv2(h))
        return F.relu(h + shortcut, inplace=True)
class ResNet3D(nn.Module):
    """Single-channel 3-D ResNet trunk that outputs a globally pooled feature vector.

    Args:
        block: residual block class exposing an ``expansion`` attribute and the
            constructor signature ``block(in_planes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
    """

    def __init__(self, block, layers):
        super().__init__()
        self.in_planes = 64
        # Stem: 7x7x7 stride-2 conv, BN, ReLU, then 3x3x3 stride-2 max-pool.
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        # Stages are registered as layer1..layer4 (state_dict key prefixes).
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_no, (planes, stride) in enumerate(stage_specs, start=1):
            setattr(self, f"layer{stage_no}",
                    self._make_layer(block, planes, layers[stage_no - 1], stride=stride))
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; only its first block may stride or change width."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.in_planes != out_planes:
            # 1x1x1 projection so the shortcut matches the main path's shape.
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_planes),
            )
        stage = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = out_planes
        stage.extend(block(self.in_planes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return torch.flatten(self.avgpool(x), 1)  # (B,512)
def resnet18_3d_feat():
    """Build the single-channel 3-D ResNet-18 feature extractor (512-d output)."""
    blocks_per_stage = [2, 2, 2, 2]  # ResNet-18 layout
    return ResNet3D(BasicBlock3D, blocks_per_stage)
| # ============================== | |
| # Clinical one-hot (25 x 12) | |
| # ============================== | |
def generate_one_hot(T_stage_value, N_stage_value, M_stage_value,
                     T_stage_post_value, N_stage_post_value, M_stage_post_value,
                     family_history, age, E, P, H, tumor_types):
    """Encode clinical variables as a (25, 12) float32 one-hot matrix.

    Each column holds one variable and the set row identifies its category:

    * 0-2: pre-treatment T / N / M stage (row = index in the stage list).
    * 3-5: post-treatment T / N / M stage.
    * 6: family history (row 0 set when the text contains "kanker").
    * 7: age bucketed per 4 years (row = age // 4, clamped to 0..24).
    * 8-9: ER / PR (row 1 when the value is exactly "90", otherwise row 0).
    * 10: HER2 (row 1 when 2 < int(H) < 8, row 0 otherwise, all zero for "-1"
      or a non-integer value).
    * 11: tumour histology keyword (rows 0-8), with row 9 as fallback.

    A stage value of "-1", "-", "None" or None, or one not found in its list,
    leaves that column all zero. The same applies to an age of "-1", "None",
    "", None or a non-numeric age.

    NOTE(review): these rules presumably mirror the training preprocessing;
    change category lists or thresholds only together with retraining.
    """
    one_hot_matrix = np.zeros((25, 12), dtype=np.float32)
    T_stage = ['0', '1', '1A', '1B', '1C', '1M', '1MI', '2', '3', '4', '4A', '4B', '4D', 'IS', 'X']
    N_stage = ['0', '0IS', '0S', '1', '1MS', '1S', '1BS', '2A', '2B', '3', '3A', '3B', '3BS', '3C', 'X']
    M_stage = ['0', '1', 'X']
    T_stage_post = ['0', '1', '1A', '1B', '1C', '1MI', '2', '3', '4B', 'IS', 'X', 'Y0', 'Y1', 'Y1A', 'Y1B',
                    'Y1C', 'Y1MI', 'Y2', 'Y3', 'Y4A', 'Y4B', 'Y4D', 'YIS', 'YX']
    N_stage_post = ['0', '0I', '0IS', '0S', '1', '1A', '1AS', '1B', '1BS', '1B1', '1B2', '1B3', '1B4',
                    '1M', '1MI', '1MS', '2', '2A', '2B', '3A', '3B', '3C', 'X']
    M_stage_post = ['0', '1', 'X']

    def _set(lst, val, col):
        # Mark the row of ``val`` in ``lst``; missing/unknown values leave the column empty.
        if val not in ["-1", "-", "None", None]:
            s = str(val).strip()
            if s in lst:
                one_hot_matrix[lst.index(s), col] = 1

    _set(T_stage, T_stage_value, 0)
    _set(N_stage, N_stage_value, 1)
    _set(M_stage, M_stage_value, 2)
    _set(T_stage_post, T_stage_post_value, 3)
    _set(N_stage_post, N_stage_post_value, 4)
    _set(M_stage_post, M_stage_post_value, 5)

    # family history
    if isinstance(family_history, str) and 'kanker' in family_history.lower():
        one_hot_matrix[0, 6] = 1

    # age bucket per 4 years (0..100); 100 // 4 == 25 is clamped into row 24
    try:
        if age not in ["-1", None, "None", ""]:
            agei = int(float(age))
            agei = max(0, min(100, agei))
            idx = max(0, min(24, agei // 4))
            one_hot_matrix[idx, 7] = 1
    except (TypeError, ValueError, OverflowError):
        pass  # unparseable age -> column left empty, as before

    # ER / PR
    def _ep_handle(v, col):
        if v == "-1":
            one_hot_matrix[:, col] = 0
        else:
            try:
                vi = int(v)
                if 0 < vi / 4 < 25:
                    one_hot_matrix[1, col] = 1
                else:
                    one_hot_matrix[0, col] = 1
            except (TypeError, ValueError):
                pass

    # Only "90" counts as positive; every other value is encoded as "0".
    _ep_handle("90" if E == "90" else "0", 8)  # ER
    _ep_handle("90" if P == "90" else "0", 9)  # PR

    # HER2
    if H == "-1":
        one_hot_matrix[:, 10] = 0
    else:
        try:
            hi = int(H)
            if 2 < hi < 8:
                one_hot_matrix[1, 10] = 1
            else:
                one_hot_matrix[0, 10] = 1
        except (TypeError, ValueError, OverflowError):
            pass  # non-integer HER2 -> column left empty, as before

    # tumor types (Dutch keywords; checks are order-independent, several rows may be set)
    if tumor_types is not None:
        tt = tumor_types.lower()
        if ('ductaal' in tt) and ('infiltrerend ductaal' not in tt) and ('intraductaal carcinoom' not in tt) and ('ductaal carcinoma in situ' not in tt):
            one_hot_matrix[0, 11] = 1
        if ('infiltrerend ductaal' in tt) and ('intraductaal carcinoom' not in tt):
            one_hot_matrix[1, 11] = 1
        if ('lobulair' in tt) and ('infiltrerend lobulair' not in tt):
            one_hot_matrix[2, 11] = 1
        if 'infiltrerend lobulair' in tt:
            one_hot_matrix[3, 11] = 1
        # NOTE(review): English 'tubular' among Dutch keywords — confirm it matches training data.
        if 'tubular' in tt:
            one_hot_matrix[4, 11] = 1
        if 'mucineus' in tt:
            one_hot_matrix[5, 11] = 1
        if 'micropapillair' in tt:
            one_hot_matrix[6, 11] = 1
        if ('papillair' in tt) and ('micropapillair' not in tt) and ('intraductaal papillair adenocarcinoom' not in tt):
            one_hot_matrix[7, 11] = 1
        if ('ductaal carcinoma in situ' in tt) or ('intraductaal carcinoom' in tt) or ('intraductaal papillair adenocarcinoom' in tt):
            one_hot_matrix[8, 11] = 1
        # Fallback "other" category when no keyword matched.
        # NOTE(review): applied only when tumor_types is given — confirm against training code.
        if np.sum(one_hot_matrix[:, 11]) == 0:
            one_hot_matrix[9, 11] = 1
    return one_hot_matrix  # (25,12)
| # ============================== | |
| # Multimodal model — outputs a single pred score | |
| # ============================== | |
class ClinicalProj(nn.Module):
    """Project the flattened 25x12 clinical one-hot matrix to an ``out_dim`` vector."""

    def __init__(self, out_dim=128):
        super().__init__()
        in_features = 25 * 12
        hidden = 256
        # Position inside ``net`` defines the state_dict keys (net.0 ... net.6).
        layers = [
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(True),
            nn.Dropout(0.1),
            nn.Linear(hidden, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):  # x: (B,25,12)
        flat = x.view(x.size(0), -1)
        return self.net(flat)
class MultiModalModel(nn.Module):
    """Fuse two MRI volumes, a report embedding, clinical one-hots and treatment.

    Returns:
        {"pred": (B,1)}  # single prediction score
    """

    def __init__(self):
        super().__init__()
        img_dim, report_dim, clin_dim, treat_dim = 512, 256, 128, 32
        # Attribute names (and creation order) must match the checkpoint's state_dict.
        self.img1 = resnet18_3d_feat()  # branch for image1 ("Pre MRI" upload)
        self.img2 = resnet18_3d_feat()  # branch for image2 ("Post MRI" upload)
        self.report_proj = nn.Sequential(
            nn.Linear(768, report_dim, bias=False),
            nn.BatchNorm1d(report_dim),
            nn.ReLU(True),
        )
        self.clin_proj = ClinicalProj(out_dim=clin_dim)
        self.treat_emb = nn.Embedding(2, treat_dim)  # 0=NA (Neoadjuvant), 1=Surgery
        fuse_dim = 2 * img_dim + report_dim + clin_dim + treat_dim
        self.fuse_ln = nn.LayerNorm(fuse_dim)
        self.head = nn.Sequential(
            nn.Linear(fuse_dim, 512), nn.ReLU(True), nn.Dropout(0.2),
            nn.Linear(512, 256), nn.ReLU(True), nn.Dropout(0.2),
        )
        self.out = nn.Linear(256, 1)  # single score

    def forward(self, image1, image2, vector, clin_features, prompt_type_indices):
        parts = [
            self.img1(image1),
            self.img2(image2),
            self.report_proj(vector),
            self.clin_proj(clin_features),
            self.treat_emb(prompt_type_indices),
        ]
        z = self.head(self.fuse_ln(torch.cat(parts, dim=1)))
        return {"pred": self.out(z)}  # (B,1)
| # ============================== | |
| # UI mapping (English → internal) | |
| # ============================== | |
def map_ui_to_internal(family_history_opt, er_opt, pr_opt, her2_opt, tumor_type_opt, treatment_opt):
    """Translate the English UI choices into the codes ``generate_one_hot`` expects.

    Returns:
        Tuple ``(fam, E, P, H, tumor, primary)``:
        fam "kanker"/"NA"; E and P "90"/"0"; H "3"/"1";
        tumor "infiltrerend ductaal"/"infiltrerend lobulair"/"others";
        primary "NA" (neoadjuvant) or "Surgery".
    """
    def positive(opt):
        return opt == "Positive"

    tumor_codes = (
        ("Invasive ductal carcinoma", "infiltrerend ductaal"),
        ("Lobular carcinoma", "infiltrerend lobulair"),
    )
    tumor = next((code for label, code in tumor_codes if tumor_type_opt == label), "others")

    fam = "kanker" if family_history_opt == "Yes" else "NA"
    E = "90" if positive(er_opt) else "0"
    P = "90" if positive(pr_opt) else "0"
    H = "3" if positive(her2_opt) else "1"
    primary = "NA" if treatment_opt == "Neoadjuvant" else "Surgery"
    return fam, E, P, H, tumor, primary
| # ============================== | |
| # Inference (GPU-decorated) | |
| # ============================== | |
@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Load the report tokenizer from the hard-coded Hub location (once per process)."""
    return AutoTokenizer.from_pretrained(DEFAULT_REPO_ID, subfolder=DEFAULT_TEXT_SUBFOLDER, use_fast=True)


@functools.lru_cache(maxsize=1)
def _get_report_net():
    """Load the report encoder in eval mode on DEVICE (once per process)."""
    return RadioLOGIC_666(repo_id=DEFAULT_REPO_ID, subfolder=DEFAULT_TEXT_SUBFOLDER).to(DEVICE).eval()


@functools.lru_cache(maxsize=1)
def _get_multimodal_model():
    """Build MultiModalModel and load the default Hub checkpoint (once per process).

    Raises whatever hf_hub_download / torch.load / load_state_dict raise.
    lru_cache does not cache exceptions, so a later call retries the download.
    """
    model = MultiModalModel().to(DEVICE).eval()
    ckpt_path = hf_hub_download(repo_id=DEFAULT_REPO_ID, filename=DEFAULT_CKPT_FILENAME)
    state = torch.load(ckpt_path, map_location=DEVICE)
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Strip a DataParallel/DDP "module." prefix (prefix only, never inner substrings).
    state = {k.removeprefix("module."): v for k, v in state.items()}
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        # strict=False is kept, but make partial loads visible in the logs.
        print(f"[WARN] Checkpoint/model key mismatch: {len(result.missing_keys)} missing, "
              f"{len(result.unexpected_keys)} unexpected")
    return model


def _to_weighted_tensor(volume, mask_weight):
    """Window an MRI volume to [0, 3000] -> [0, 1], apply mask weighting, return (1,1,D,H,W) on DEVICE."""
    weighted = clip_and_norm(volume, 0, 3000) * mask_weight
    return torch.from_numpy(weighted[np.newaxis, np.newaxis]).to(DEVICE)


@spaces.GPU  # requests a GPU on ZeroGPU Spaces; a no-op on other hardware
def run_infer(
    img1_file, img2_file, mask_file,
    T_stage, N_stage, M_stage, T_stage_post, N_stage_post, M_stage_post,
    family_history_opt, age_val, er_opt, pr_opt, her2_opt, tumor_type_opt, treatment_opt,
    report_text
):
    """Run one end-to-end prediction for the Gradio UI.

    Args:
        img1_file, img2_file, mask_file: Gradio File values for the Pre MRI,
            Post MRI and mask NIfTI uploads.
        T_stage ... M_stage_post: staging dropdown values ("-1" = not given).
        family_history_opt, age_val, er_opt, pr_opt, her2_opt, tumor_type_opt,
            treatment_opt: clinical UI choices (see map_ui_to_internal).
        report_text: free-text radiology report.

    Returns:
        ``(status_message, pred)`` where ``pred`` is a float, or None on error.
    """
    if img1_file is None or img2_file is None or mask_file is None:
        return "❌ Please upload Pre MRI, Post MRI, and Mask.", None

    # load volumes
    img1 = load_nii(img1_file)
    img2 = load_nii(img2_file)
    msk = load_nii(mask_file)
    if img1 is None or img2 is None or msk is None:
        return "❌ Failed to load NIfTI files.", None
    if img1.shape != msk.shape or img2.shape != msk.shape:
        return (f"❌ Shape mismatch: Pre MRI {img1.shape}, Post MRI {img2.shape}, "
                f"Mask {msk.shape}.", None)

    # Model weights first: never report a prediction from randomly initialised weights.
    try:
        model = _get_multimodal_model()
    except Exception as e:  # request boundary: Hub/network/checkpoint errors
        print(f"[WARN] Could not load default weights from Hub: {e}")
        return f"❌ Could not load model weights from Hub: {e}", None

    # preprocess + mask weighting (mask voxels == 1 are up-weighted x10)
    mask_weight = np.where(msk == 1, 10.0, 1.0).astype(np.float32)
    image1 = _to_weighted_tensor(img1, mask_weight)
    image2 = _to_weighted_tensor(img2, mask_weight)

    # clinical
    fam, E, P, H, tumor, primary = map_ui_to_internal(
        family_history_opt, er_opt, pr_opt, her2_opt, tumor_type_opt, treatment_opt)
    age = clamp_age_0_100(age_val)  # "0..100" or "-1"
    clin_np = generate_one_hot(T_stage, N_stage, M_stage, T_stage_post, N_stage_post, M_stage_post,
                               fam, age, E, P, H, tumor)  # (25,12)
    clin = torch.from_numpy(clin_np).unsqueeze(0).to(DEVICE)
    prompt_type_indices = torch.tensor([0 if primary == "NA" else 1], dtype=torch.long, device=DEVICE)

    # report tokens
    code = _get_tokenizer()(clean_report_text(report_text or ""), padding='max_length',
                            max_length=512, truncation=True, return_tensors="pt")
    input_id = code['input_ids'].to(DEVICE)
    attn_mask = code['attention_mask'].to(DEVICE)

    # Pure inference: no autograd graph (saves memory on the 3-D backbones).
    with torch.inference_mode():
        vector768 = _get_report_net()(input_id=input_id, attention_mask=attn_mask)  # (1,768)
        if USE_AMP and DEVICE.type == "cuda":
            with torch.autocast(device_type="cuda", dtype=AMP_DTYPE):
                out = model(image1, image2, vector768, clin, prompt_type_indices)
        else:
            out = model(image1, image2, vector768, clin, prompt_type_indices)
    pred = float(out["pred"].squeeze().item())
    return f"✅ Inference done. Prediction (pred): {pred:.4f}", pred
| # ============================== | |
| # Gradio UI (English, no Hub settings) | |
| # ============================== | |
# Dropdown choices for staging; "-1" (first entry) means "not specified".
STAGES_T = [
    "-1", "0", "1", "1A", "1B", "1C", "1M", "1MI", "2", "3",
    "4", "4A", "4B", "4D", "IS", "X",
]
STAGES_N = [
    "-1", "0", "0IS", "0S", "1", "1MS", "1S", "1BS", "2A", "2B",
    "3", "3A", "3B", "3BS", "3C", "X",
]
STAGES_M = ["-1", "0", "1", "X"]
STAGES_T_POST = [
    "-1", "0", "1", "1A", "1B", "1C", "1MI", "2", "3", "4B", "IS", "X",
    "Y0", "Y1", "Y1A", "Y1B", "Y1C", "Y1MI", "Y2", "Y3", "Y4A", "Y4B", "Y4D", "YIS", "YX",
]
STAGES_N_POST = [
    "-1", "0", "0I", "0IS", "0S", "1", "1A", "1AS", "1B", "1BS", "1B1", "1B2", "1B3", "1B4",
    "1M", "1MI", "1MS", "2", "2A", "2B", "3A", "3B", "3C", "X",
]
STAGES_M_POST = ["-1", "0", "1", "X"]
def _stage_dropdown(choices, label):
    """Create an interactive staging dropdown that defaults to "-1" (not specified)."""
    return gr.Dropdown(choices=choices, value="-1", label=label, interactive=True)


def _neg_pos_radio(label):
    """Create a Negative/Positive receptor radio that defaults to "Negative"."""
    return gr.Radio(choices=["Negative", "Positive"], value="Negative", label=label)


with gr.Blocks(title="Breast Cancer Prognosis (4-Modal)", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🩺 Breast Cancer Prognosis (MRI + Report + Clinical + Treatment)")
    gr.Markdown("Upload **Pre/Post MRI and Mask**, paste the **report text**, choose **clinical options**, then get a single prediction score (pred).")
    with gr.Row():
        # Left column: imaging uploads and the free-text report.
        with gr.Column():
            img1 = gr.File(label="Pre MRI (NIfTI .nii/.nii.gz)")
            img2 = gr.File(label="Post MRI (NIfTI .nii/.nii.gz)")
            msk = gr.File(label="Mask (NIfTI .nii/.nii.gz)")
            report = gr.Textbox(label="Report Text", lines=8, placeholder="Paste the full report…")
        # Right column: staging and other clinical variables.
        with gr.Column():
            gr.Markdown("### Pre-treatment staging")
            T = _stage_dropdown(STAGES_T, "T_stage")
            N = _stage_dropdown(STAGES_N, "N_stage")
            M = _stage_dropdown(STAGES_M, "M_stage")
            gr.Markdown("### Post-treatment staging")
            T_post = _stage_dropdown(STAGES_T_POST, "T_stage_post")
            N_post = _stage_dropdown(STAGES_N_POST, "N_stage_post")
            M_post = _stage_dropdown(STAGES_M_POST, "M_stage_post")
            gr.Markdown("### Other clinical variables")
            fam = gr.Radio(choices=["No", "Yes"], value="No", label="Family history")
            age = gr.Number(label="Age (0–100)", value=55, minimum=0, maximum=100, precision=0)
            er = _neg_pos_radio("ER")
            pr = _neg_pos_radio("PR")
            her2 = _neg_pos_radio("HER2")
            tumor = gr.Radio(choices=["Invasive ductal carcinoma", "Lobular carcinoma", "Others"],
                             value="Invasive ductal carcinoma", label="Tumor type")
            treat = gr.Radio(choices=["Neoadjuvant", "Surgery"], value="Surgery", label="Treatment choice")
    run_btn = gr.Button("Run Inference", variant="primary")
    status = gr.Textbox(label="Status / Result", interactive=False)
    pred_num = gr.Number(label="Prediction (pred)", interactive=False)
    # Order must match run_infer's positional parameters.
    ui_inputs = [img1, img2, msk,
                 T, N, M, T_post, N_post, M_post, fam, age, er, pr, her2, tumor, treat,
                 report]
    run_btn.click(fn=run_infer, inputs=ui_inputs, outputs=[status, pred_num])

if __name__ == "__main__":
    demo.launch()