# Source: Hugging Face Space "zhang0319/Multimodel_Surv" — app.py (commit dd6729b, verified)
# app.py
import os
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import spaces
from transformers import AutoTokenizer, RobertaModel
from huggingface_hub import hf_hub_download
# ==============================
# Hard-coded Hub defaults (no UI)
# ==============================
DEFAULT_REPO_ID = "zhang0319/Multimodel_Surv"  # Hub repo holding both the text encoder and the checkpoint
DEFAULT_TEXT_SUBFOLDER = "models/radiobert_BigDataset_epoch10"  # subfolder with the RoBERTa report encoder
DEFAULT_CKPT_FILENAME = "weights/20251003_dropout0.3_best_image_report_clin_model20251003_8__.pth"  # multimodal model weights
# ==============================
# Device & AMP
# ==============================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = torch.cuda.is_available()  # autocast is only used on the CUDA path (see run_infer)
# Prefer bf16 when the GPU supports it (wider dynamic range); otherwise fall back to fp16.
AMP_DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16
# ==============================
# Utils
# ==============================
def clip_and_norm(img: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
    """Clip *img* to [vmin, vmax] and linearly rescale it into [0, 1].

    The divisor is floored at 1e-6 so a degenerate window (vmax <= vmin)
    cannot cause a division by zero. Returns a float32 array.
    """
    span = max(vmax - vmin, 1e-6)
    clipped = np.clip(img, vmin, vmax)
    return ((clipped - vmin) / span).astype(np.float32)
def load_nii(file) -> np.ndarray | None:
"""
兼容 Gradio 4/5:File 可能是字符串路径、{'path': ...} 字典,或旧式带 .name 的对象。
"""
if file is None:
return None
if isinstance(file, (str, os.PathLike)):
path = str(file)
elif isinstance(file, dict) and "path" in file:
path = file["path"]
else:
path = getattr(file, "name", None)
if not path or not os.path.exists(path):
return None
img = nib.load(path)
return img.get_fdata().astype(np.float32)
def clean_report_text(raw: str) -> str:
    """Normalize a pasted radiology report.

    Newlines are stripped; when a Dutch "Klinische" header is present,
    only the text from the subsequent "Verslag" (report) section onward
    is kept. Returns "" for None input.
    """
    if raw is None:
        return ""
    flat = raw.replace("\n", "")
    pos_klinische = flat.find("Klinische")
    if pos_klinische == -1:
        return flat
    pos_verslag = flat.find("Verslag", pos_klinische)
    return flat if pos_verslag == -1 else flat[pos_verslag:]
def clamp_age_0_100(age_val):
    """Coerce *age_val* to an integer age clamped to [0, 100].

    Returns the age as a string (the clinical one-hot builder consumes
    string ages), or the sentinel "-1" when the value is not numeric.
    """
    try:
        v = int(float(age_val))
    except (TypeError, ValueError):
        # Non-numeric / None input -> "unknown age" sentinel.
        # (Narrowed from a bare `except:` that would also hide e.g.
        # KeyboardInterrupt.)
        return "-1"
    return str(max(0, min(100, v)))
# ==============================
# Report encoder
# ==============================
class RadioLOGIC_666(torch.nn.Module):
    """Radiology-report encoder.

    Wraps a pretrained RoBERTa backbone and returns the embedding of the
    first token of the last hidden state as the report feature vector,
    shape (B, 768).
    """

    def __init__(self, repo_id: str, subfolder: str = "", add_pooling_layer: bool = False):
        super().__init__()
        self.bert = RobertaModel.from_pretrained(repo_id, subfolder=subfolder, add_pooling_layer=add_pooling_layer)

    def forward(self, input_id=None, attention_mask=None):
        last_hidden = self.bert(input_ids=input_id, attention_mask=attention_mask, return_dict=False)[0]
        return last_hidden[:, 0]  # (B, 768) first-token embedding
# ==============================
# 3D ResNet-18 (feature -> 512)
# ==============================
class BasicBlock3D(nn.Module):
    """3D residual basic block: two 3x3x3 conv+BN pairs with an identity
    (or projected, via ``downsample``) skip connection."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample

    def forward(self, x):
        # Project the shortcut when shape changes, otherwise pass-through.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = F.relu(self.bn1(self.conv1(x)), inplace=True)
        y = self.bn2(self.conv2(y))
        return F.relu(y + shortcut, inplace=True)
class ResNet3D(nn.Module):
    """3D ResNet trunk producing a 512-d global feature vector.

    Single-channel input; a 7x7x7 stride-2 stem plus max-pool is followed
    by four residual stages (64/128/256/512 planes) and global average
    pooling. ``forward`` returns a (B, 512) tensor.
    """

    def __init__(self, block, layers):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv3d(1, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(3, 2, 1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def _make_layer(self, block, planes, blocks, stride=1):
        # Shortcut needs a 1x1x1 projection whenever spatial size or
        # channel count changes across the stage boundary.
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, planes * block.expansion, 1, stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )
        stage = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes * block.expansion
        stage.extend(block(self.in_planes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return torch.flatten(x, 1)  # (B, 512)
def resnet18_3d_feat():
    """Build a ResNet-18-style 3D feature extractor (two blocks per stage)."""
    stage_depths = [2, 2, 2, 2]
    return ResNet3D(BasicBlock3D, stage_depths)
# ==============================
# Clinical one-hot (25 x 12)
# ==============================
def generate_one_hot(T_stage_value, N_stage_value, M_stage_value,
                     T_stage_post_value, N_stage_post_value, M_stage_post_value,
                     family_history, age, E, P, H, tumor_types):
    """Encode clinical variables into a fixed (25, 12) one-hot feature matrix.

    Columns:
      0-2  pre-treatment T / N / M stage
      3-5  post-treatment T / N / M stage
      6    family history of cancer ('kanker' substring, Dutch)
      7    age bucket (4-year bins over 0..100 -> rows 0..24)
      8-9  ER / PR status (row 1 = positive range, row 0 = otherwise)
      10   HER2 status (scores in (2, 8) count as positive)
      11   tumor histology type (Dutch pathology terms; row 9 = "other")

    Sentinel inputs ("-1", "-", "None", None) leave their column all-zero.
    Bare ``except:`` clauses were narrowed to (TypeError, ValueError);
    behavior is otherwise unchanged.
    """
    one_hot_matrix = np.zeros((25, 12), dtype=np.float32)
    # Stage vocabularies; the row index of a matched value is set to 1.
    T_stage = ['0', '1', '1A', '1B', '1C', '1M', '1MI', '2', '3', '4', '4A', '4B', '4D', 'IS', 'X']
    N_stage = ['0', '0IS', '0S', '1', '1MS', '1S', '1BS', '2A', '2B', '3', '3A', '3B', '3BS', '3C', 'X']
    M_stage = ['0', '1', 'X']
    T_stage_post = ['0', '1', '1A', '1B', '1C', '1MI', '2', '3', '4B', 'IS', 'X', 'Y0', 'Y1', 'Y1A', 'Y1B',
                    'Y1C', 'Y1MI', 'Y2', 'Y3', 'Y4A', 'Y4B', 'Y4D', 'YIS', 'YX']
    N_stage_post = ['0', '0I', '0IS', '0S', '1', '1A', '1AS', '1B', '1BS', '1B1', '1B2', '1B3', '1B4',
                    '1M', '1MI', '1MS', '2', '2A', '2B', '3A', '3B', '3C', 'X']
    M_stage_post = ['0', '1', 'X']

    def _set(lst, val, col):
        # One-hot a stage value into column `col` when it is a known entry.
        if val not in ["-1", "-", "None", None]:
            s = str(val).strip()
            if s in lst:
                one_hot_matrix[lst.index(s), col] = 1

    _set(T_stage, T_stage_value, 0)
    _set(N_stage, N_stage_value, 1)
    _set(M_stage, M_stage_value, 2)
    _set(T_stage_post, T_stage_post_value, 3)
    _set(N_stage_post, N_stage_post_value, 4)
    _set(M_stage_post, M_stage_post_value, 5)
    # Family history: any mention of 'kanker' (Dutch for cancer) counts.
    if isinstance(family_history, str) and 'kanker' in family_history.lower():
        one_hot_matrix[0, 6] = 1
    # Age bucket per 4 years (0..100 -> rows 0..24).
    try:
        if age not in ["-1", None, "None", ""]:
            agei = int(float(age))
            agei = max(0, min(100, agei))
            idx = max(0, min(24, agei // 4))
            one_hot_matrix[idx, 7] = 1
    except (TypeError, ValueError):
        pass  # non-numeric age -> leave column 7 empty

    # ER / PR handling.
    def _ep_handle(v, col):
        if v == "-1":
            # NOTE(review): dead branch with the current callers — they
            # always pass "90" or "0"; kept for backward compatibility.
            one_hot_matrix[:, col] = 0
        else:
            try:
                vi = int(v)
                if 0 < vi / 4 < 25:
                    one_hot_matrix[1, col] = 1
                else:
                    one_hot_matrix[0, col] = 1
            except (TypeError, ValueError):
                pass

    _ep_handle("90" if E == "90" else "0", 8)  # ER
    _ep_handle("90" if P == "90" else "0", 9)  # PR
    # HER2: "-1" clears the column; integer scores in (2, 8) count positive.
    if H == "-1":
        one_hot_matrix[:, 10] = 0
    else:
        try:
            hi = int(H)
            if 2 < hi < 8:
                one_hot_matrix[1, 10] = 1
            else:
                one_hot_matrix[0, 10] = 1
        except (TypeError, ValueError):
            pass
    # Tumor histology (Dutch pathology terms); row 9 is the "other" bucket.
    if tumor_types is not None:
        tt = tumor_types.lower()
        if ('ductaal' in tt) and ('infiltrerend ductaal' not in tt) and ('intraductaal carcinoom' not in tt) and ('ductaal carcinoma in situ' not in tt):
            one_hot_matrix[0, 11] = 1
        if ('infiltrerend ductaal' in tt) and ('intraductaal carcinoom' not in tt):
            one_hot_matrix[1, 11] = 1
        if ('lobulair' in tt) and ('infiltrerend lobulair' not in tt):
            one_hot_matrix[2, 11] = 1
        if 'infiltrerend lobulair' in tt:
            one_hot_matrix[3, 11] = 1
        if 'tubular' in tt:
            one_hot_matrix[4, 11] = 1
        if 'mucineus' in tt:
            one_hot_matrix[5, 11] = 1
        if 'micropapillair' in tt:
            one_hot_matrix[6, 11] = 1
        if ('papillair' in tt) and ('micropapillair' not in tt) and ('intraductaal papillair adenocarcinoom' not in tt):
            one_hot_matrix[7, 11] = 1
        if ('ductaal carcinoma in situ' in tt) or ('intraductaal carcinoom' in tt) or ('intraductaal papillair adenocarcinoom' in tt):
            one_hot_matrix[8, 11] = 1
        if np.sum(one_hot_matrix[:, 11]) == 0:
            one_hot_matrix[9, 11] = 1
    return one_hot_matrix  # (25, 12)
# ==============================
# Multimodal model — outputs a single pred score
# ==============================
class ClinicalProj(nn.Module):
    """Project the (B, 25, 12) clinical one-hot matrix to a dense feature
    vector via a two-layer MLP (Linear + BatchNorm + ReLU, with dropout
    between the layers)."""

    def __init__(self, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(25 * 12, 256), nn.BatchNorm1d(256), nn.ReLU(True),
            nn.Dropout(0.1),
            nn.Linear(256, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(True)
        )

    def forward(self, x):
        # Flatten (B, 25, 12) -> (B, 300) before the MLP.
        flat = x.view(x.size(0), -1)
        return self.net(flat)
class MultiModalModel(nn.Module):
    """Four-modality fusion network producing one prognosis score.

    Inputs: two MRI volumes (each via its own 3D ResNet-18 trunk), a
    768-d report embedding, the (B, 25, 12) clinical one-hot matrix, and
    a treatment-type index. ``forward`` returns ``{"pred": (B, 1)}``.
    Attribute names are kept as-is for checkpoint compatibility.
    """

    def __init__(self):
        super().__init__()
        self.img1 = resnet18_3d_feat()  # pre-treatment MRI  -> 512-d
        self.img2 = resnet18_3d_feat()  # post-treatment MRI -> 512-d
        self.report_proj = nn.Sequential(nn.Linear(768, 256, bias=False), nn.BatchNorm1d(256), nn.ReLU(True))
        self.clin_proj = ClinicalProj(out_dim=128)
        self.treat_emb = nn.Embedding(2, 32)  # 0=NA (Neoadjuvant), 1=Surgery
        fuse_dim = 512 + 512 + 256 + 128 + 32
        self.fuse_ln = nn.LayerNorm(fuse_dim)
        self.head = nn.Sequential(
            nn.Linear(fuse_dim, 512), nn.ReLU(True), nn.Dropout(0.2),
            nn.Linear(512, 256), nn.ReLU(True), nn.Dropout(0.2),
        )
        self.out = nn.Linear(256, 1)  # single score

    def forward(self, image1, image2, vector, clin_features, prompt_type_indices):
        # Encode each modality, concatenate, normalize, then score.
        features = torch.cat(
            [
                self.img1(image1),
                self.img2(image2),
                self.report_proj(vector),
                self.clin_proj(clin_features),
                self.treat_emb(prompt_type_indices),
            ],
            dim=1,
        )
        score = self.out(self.head(self.fuse_ln(features)))  # (B, 1)
        return {"pred": score}
# ==============================
# UI mapping (English → internal)
# ==============================
def map_ui_to_internal(family_history_opt, er_opt, pr_opt, her2_opt, tumor_type_opt, treatment_opt):
    """Translate English UI selections into the internal (Dutch/legacy)
    tokens that generate_one_hot and the model expect.

    Returns (family_history, ER, PR, HER2, tumor_type, primary_treatment).
    """
    fam = "kanker" if family_history_opt == "Yes" else "NA"
    er_code = "90" if er_opt == "Positive" else "0"
    pr_code = "90" if pr_opt == "Positive" else "0"
    her2_code = "3" if her2_opt == "Positive" else "1"
    tumor_map = {
        "Invasive ductal carcinoma": "infiltrerend ductaal",
        "Lobular carcinoma": "infiltrerend lobulair",
    }
    tumor = tumor_map.get(tumor_type_opt, "others")
    primary = "NA" if treatment_opt == "Neoadjuvant" else "Surgery"
    return fam, er_code, pr_code, her2_code, tumor, primary
# ==============================
# Inference (GPU-decorated)
# ==============================
@spaces.GPU
@torch.no_grad()
def run_infer(
    img1_file, img2_file, mask_file,
    T_stage, N_stage, M_stage, T_stage_post, N_stage_post, M_stage_post,
    family_history_opt, age_val, er_opt, pr_opt, her2_opt, tumor_type_opt, treatment_opt,
    report_text
):
    """Run the full multimodal inference pipeline for a single case.

    Steps: load and window the two MRI volumes, up-weight voxels inside
    the mask, encode the report with the Hub RoBERTa encoder, build the
    clinical one-hot matrix, load the multimodal checkpoint from the
    Hub, and run one forward pass.

    Returns:
        (status_message: str, pred: float | None)

    NOTE(review): the tokenizer, report encoder, and model weights are
    rebuilt/re-downloaded on every call — presumably to suit the ZeroGPU
    worker lifecycle implied by @spaces.GPU; confirm before caching.
    """
    if img1_file is None or img2_file is None or mask_file is None:
        return "❌ Please upload Pre MRI, Post MRI, and Mask.", None
    # load volumes
    img1 = load_nii(img1_file)
    img2 = load_nii(img2_file)
    msk = load_nii(mask_file)
    if img1 is None or img2 is None or msk is None:
        return "❌ Failed to load NIfTI files.", None
    # preprocess + mask weighting: window intensities to [0, 3000], then
    # multiply voxels where mask == 1 by 10 so the lesion dominates.
    img1 = clip_and_norm(img1, 0, 3000)
    img2 = clip_and_norm(img2, 0, 3000)
    mask_modified = np.where(msk == 1, 10.0, 1.0).astype(np.float32)
    img1 = np.expand_dims(np.expand_dims(img1 * mask_modified, axis=0), axis=0)  # (1,1,D,H,W)
    img2 = np.expand_dims(np.expand_dims(img2 * mask_modified, axis=0), axis=0)
    image1 = torch.from_numpy(img1).to(DEVICE)
    image2 = torch.from_numpy(img2).to(DEVICE)
    # tokenizer & report encoder from Hub (hard-coded defaults)
    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_REPO_ID, subfolder=DEFAULT_TEXT_SUBFOLDER, use_fast=True)
    clean_txt = clean_report_text(report_text or "")
    code = tokenizer(clean_txt, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    input_id = code['input_ids'].to(DEVICE)
    attn_mask = code['attention_mask'].to(DEVICE)
    report_net = RadioLOGIC_666(repo_id=DEFAULT_REPO_ID, subfolder=DEFAULT_TEXT_SUBFOLDER).to(DEVICE).eval()
    vector768 = report_net(input_id=input_id, attention_mask=attn_mask)  # (1,768) first-token embedding
    # clinical variables -> internal tokens -> (1, 25, 12) one-hot tensor
    fam, E, P, H, tumor, primary = map_ui_to_internal(family_history_opt, er_opt, pr_opt, her2_opt, tumor_type_opt, treatment_opt)
    age = clamp_age_0_100(age_val)  # "0..100" or "-1"
    clin_np = generate_one_hot(T_stage, N_stage, M_stage, T_stage_post, N_stage_post, M_stage_post,
                               fam, age, E, P, H, tumor)  # (25,12)
    clin = torch.from_numpy(clin_np).unsqueeze(0).to(DEVICE)
    prompt_type_indices = torch.tensor([0 if primary == "NA" else 1], dtype=torch.long, device=DEVICE)
    # model + weights (always from Hub defaults)
    model = MultiModalModel().to(DEVICE).eval()
    try:
        ckpt_path = hf_hub_download(repo_id=DEFAULT_REPO_ID, filename=DEFAULT_CKPT_FILENAME)
        state = torch.load(ckpt_path, map_location=DEVICE)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        # Strip DataParallel's "module." prefix; strict=False tolerates
        # minor key mismatches between checkpoint and this architecture.
        new_state = {k.replace("module.", ""): v for k, v in state.items()}
        model.load_state_dict(new_state, strict=False)
    except Exception as e:
        # Best-effort: on failure the model keeps random init weights.
        print(f"[WARN] Could not load default weights from Hub: {e}")
    # forward pass (bf16/fp16 autocast on CUDA only)
    if USE_AMP and DEVICE.type == "cuda":
        with torch.autocast(device_type="cuda", dtype=AMP_DTYPE):
            out = model(image1, image2, vector768, clin, prompt_type_indices)
    else:
        out = model(image1, image2, vector768, clin, prompt_type_indices)
    pred = float(out["pred"].squeeze().item())
    return f"✅ Inference done. Prediction (pred): {pred:.4f}", pred
# ==============================
# Gradio UI (English, no Hub settings)
# ==============================
# Dropdown vocabularies for TNM staging. "-1" means unknown/not provided
# and is treated as a sentinel by generate_one_hot (column stays zero).
STAGES_T = ['-1','0','1','1A','1B','1C','1M','1MI','2','3','4','4A','4B','4D','IS','X']
STAGES_N = ['-1','0','0IS','0S','1','1MS','1S','1BS','2A','2B','3','3A','3B','3BS','3C','X']
STAGES_M = ['-1','0','1','X']
STAGES_T_POST = ['-1','0','1','1A','1B','1C','1MI','2','3','4B','IS','X','Y0','Y1','Y1A','Y1B','Y1C','Y1MI','Y2','Y3','Y4A','Y4B','Y4D','YIS','YX']
STAGES_N_POST = ['-1','0','0I','0IS','0S','1','1A','1AS','1B','1BS','1B1','1B2','1B3','1B4','1M','1MI','1MS','2','2A','2B','3A','3B','3C','X']
STAGES_M_POST = ['-1','0','1','X']
# Build the Gradio interface: file inputs + report text on the left,
# structured clinical variables on the right, single score output below.
with gr.Blocks(title="Breast Cancer Prognosis (4-Modal)", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🩺 Breast Cancer Prognosis (MRI + Report + Clinical + Treatment)")
    gr.Markdown("Upload **Pre/Post MRI and Mask**, paste the **report text**, choose **clinical options**, then get a single prediction score (pred).")
    with gr.Row():
        with gr.Column():
            # Left column: imaging inputs + free-text report.
            img1 = gr.File(label="Pre MRI (NIfTI .nii/.nii.gz)")
            img2 = gr.File(label="Post MRI (NIfTI .nii/.nii.gz)")
            msk = gr.File(label="Mask (NIfTI .nii/.nii.gz)")
            report = gr.Textbox(label="Report Text", lines=8, placeholder="Paste the full report…")
        with gr.Column():
            # Right column: structured clinical variables.
            gr.Markdown("### Pre-treatment staging")
            T = gr.Dropdown(choices=STAGES_T, value="-1", label="T_stage", interactive=True)
            N = gr.Dropdown(choices=STAGES_N, value="-1", label="N_stage", interactive=True)
            M = gr.Dropdown(choices=STAGES_M, value="-1", label="M_stage", interactive=True)
            gr.Markdown("### Post-treatment staging")
            T_post = gr.Dropdown(choices=STAGES_T_POST, value="-1", label="T_stage_post", interactive=True)
            N_post = gr.Dropdown(choices=STAGES_N_POST, value="-1", label="N_stage_post", interactive=True)
            M_post = gr.Dropdown(choices=STAGES_M_POST, value="-1", label="M_stage_post", interactive=True)
            gr.Markdown("### Other clinical variables")
            fam = gr.Radio(choices=["No","Yes"], value="No", label="Family history")
            age = gr.Number(label="Age (0–100)", value=55, minimum=0, maximum=100, precision=0)
            er = gr.Radio(choices=["Negative","Positive"], value="Negative", label="ER")
            pr = gr.Radio(choices=["Negative","Positive"], value="Negative", label="PR")
            her2 = gr.Radio(choices=["Negative","Positive"], value="Negative", label="HER2")
            tumor = gr.Radio(choices=["Invasive ductal carcinoma","Lobular carcinoma","Others"],
                             value="Invasive ductal carcinoma", label="Tumor type")
            treat = gr.Radio(choices=["Neoadjuvant","Surgery"], value="Surgery", label="Treatment choice")
    run_btn = gr.Button("Run Inference", variant="primary")
    status = gr.Textbox(label="Status / Result", interactive=False)
    pred_num = gr.Number(label="Prediction (pred)", interactive=False)
    # Wire the button to the inference function; input order must match
    # run_infer's signature exactly.
    run_btn.click(
        fn=run_infer,
        inputs=[img1, img2, msk,
                T, N, M, T_post, N_post, M_post, fam, age, er, pr, her2, tumor, treat,
                report],
        outputs=[status, pred_num]
    )
# Launch the Gradio app when executed as a script.
if __name__ == "__main__":
    demo.launch()