# Chenyixin's picture
# Update app.py
# fb1b122 verified
import os, json, csv, tempfile
import numpy as np
import pandas as pd
import SimpleITK as sitk
import gradio as gr
# =========================
# 0) —— Configurable / bundled resources ——
# =========================
# ===== Training-time matrices required by IMD (studentized pairwise residuals) =====
DEMO_SUV = "SUVPET_demo.nii.gz"
DEMO_SEG = "segment_demo.nii.gz"
IMD_SLOPE_CSV = "imd_slope.csv"          # N x N matrix, rows/cols labeled by ROI
IMD_INTERCEPT_CSV = "imd_intercept.csv"  # N x N matrix, rows/cols labeled by ROI
IMD_SIGMA_CSV = "imd_sigma.csv"          # N x N matrix, rows/cols labeled by ROI
IMD_REGIONS_JSON = "imd_regions.json"    # training-time ROI order (list[str])


def load_imd_params():
    """Load the training-time IMD parameter matrices.

    Returns:
        (slope, intercept, sigma, regions_order) when all four resource
        files exist, otherwise None. The three matrices are DataFrames
        indexed/columned by ROI name; regions_order is a list[str].

    Raises:
        ValueError: if the matrices disagree in shape or slope is not a
            labeled square matrix.
    """
    paths = (IMD_SLOPE_CSV, IMD_INTERCEPT_CSV, IMD_SIGMA_CSV, IMD_REGIONS_JSON)
    if not all(os.path.exists(p) for p in paths):
        return None
    slope = pd.read_csv(IMD_SLOPE_CSV, index_col=0)
    interc = pd.read_csv(IMD_INTERCEPT_CSV, index_col=0)
    sigma = pd.read_csv(IMD_SIGMA_CSV, index_col=0)
    with open(IMD_REGIONS_JSON, "r", encoding="utf-8") as f:
        regions_order = json.load(f)
    # Basic sanity checks. Raise instead of assert so the checks survive
    # running under `python -O` (asserts are stripped there).
    if not (slope.shape == interc.shape == sigma.shape):
        raise ValueError("IMD matrices shape mismatch")
    if list(slope.index) != list(slope.columns):
        raise ValueError("slope must be square with labels")
    return slope, interc, sigma, regions_order
def compute_imd_matrix_for_subject(resid_series: pd.Series, slope: pd.DataFrame, interc: pd.DataFrame, sigma: pd.DataFrame, regions_order: list[str], clip_max: float = 4.0):
    """Compute one subject's individual-metabolic-deviation (IMD) matrix.

    Each cell is the studentized residual of region i predicted from
    region j under the training-time pairwise regression:
        sr_ij = | x_i - (b0_ij + b1_ij * x_j) | / sigma_ij

    Args:
        resid_series: the subject's residuals, indexed by ROI name
            (same names as df_resid.columns).
        slope, interc, sigma: training-time N x N parameter matrices.
        regions_order: training-time ROI order; fixes row/column order.
        clip_max: upper bound applied to the studentized residuals
            (default 4.0, matching the original hard-coded clip).

    Returns:
        DataFrame (N x N), indexed and columned by regions_order. Cells
        with sigma == 0 and the diagonal are set to 0.
    """
    # Align residuals to the training order; missing ROIs become 0
    # (0 is more stable here than NaN + masking).
    x = resid_series.reindex(regions_order).fillna(0.0).astype(float).to_numpy()
    b1 = slope.reindex(index=regions_order, columns=regions_order).to_numpy()
    b0 = interc.reindex(index=regions_order, columns=regions_order).to_numpy()
    sg = sigma.reindex(index=regions_order, columns=regions_order).to_numpy()
    # Fully vectorized: y_i broadcasts down rows (x[:, None]), x_j across
    # columns (x[None, :]) — replaces the per-row Python loop.
    yhat = b0 + b1 * x[None, :]
    abs_resid = np.abs(x[:, None] - yhat)
    sr_mat = np.zeros_like(abs_resid, dtype=np.float32)
    safe = sg != 0  # avoid division by zero; those cells stay 0
    sr_mat[safe] = (abs_resid[safe] / sg[safe]).astype(np.float32)
    np.clip(sr_mat, 0.0, clip_max, out=sr_mat)
    # Self-prediction carries no information: zero the diagonal.
    np.fill_diagonal(sr_mat, 0.0)
    return pd.DataFrame(sr_mat, index=regions_order, columns=regions_order)
# Plot the IMD heatmap and export it
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def save_imd_heatmap_and_csv(imd_df: pd.DataFrame, title="IMD (Studentized Residuals)", vmax=None):
    """Render the IMD matrix as a heatmap PNG and dump it to CSV.

    Args:
        imd_df: square IMD DataFrame (ROI x ROI).
        title: figure title.
        vmax: upper color limit; defaults to the 99th percentile so a few
            extreme cells do not wash out the rest of the map.

    Returns:
        (png_path, csv_path) pointing at freshly created temp files.
    """
    # mkstemp without an explicit dir uses the platform temp directory
    # (/tmp on Linux) — portable, unlike the previous hard-coded "/tmp".
    fd_png, png_path = tempfile.mkstemp(suffix=".png", prefix="imd_")
    os.close(fd_png)
    fd_csv, csv_path = tempfile.mkstemp(suffix=".csv", prefix="imd_")
    os.close(fd_csv)
    # Save the raw matrix for download.
    imd_df.to_csv(csv_path)
    # Plot.
    plt.figure(figsize=(10, 9))
    vmin = np.nanmin(imd_df.values)
    if vmax is None:
        # Fall back to autoscaling (None) when the matrix is all-NaN.
        vmax = np.nanpercentile(imd_df.values, 99) if np.isfinite(vmin) else None
    im = plt.imshow(imd_df.values, aspect="auto", vmin=vmin if np.isfinite(vmin) else None, vmax=vmax)
    plt.xticks(range(len(imd_df.columns)), imd_df.columns, rotation=90, fontsize=6)
    plt.yticks(range(len(imd_df.index)), imd_df.index, fontsize=6)
    cbar = plt.colorbar(im)
    cbar.ax.tick_params(labelsize=8)
    plt.title(title, fontsize=12, pad=10)
    plt.tight_layout()
    plt.savefig(png_path, dpi=300)
    plt.close()
    return png_path, csv_path
ATLAS_LABELS_JSON = "atlas_labels.json" # optional; "" means not provided
BUILTIN_TRAIN_MEANS_JSON = "builtin_train_means.json"
BUILTIN_TRAIN_COEFS_JSON = "train_coefs.json"
BRAIN_REF = ['Brainstem'] # brain reference region(s)
BODY_REF = 'liver' # body reference region (the liver)
# Built-in RMC (R² matrix) CSV path. Export your r2_df to this CSV in the
# repo root: first row and first column hold ROI names, cells are R² floats.
RMC_R2_CSV = "rmc_r2.csv"
# NOTE(review): matplotlib is imported and configured again below — this is
# redundant with the earlier import/use("Agg") block, but harmless.
import matplotlib
matplotlib.use("Agg") # headless backend (no display on the server)
import matplotlib.pyplot as plt
def load_rmc_r2_df():
    """Return the built-in RMC R² matrix as a DataFrame, or None.

    None is returned when the CSV is absent or not a square matrix.
    """
    if not os.path.exists(RMC_R2_CSV):
        return None
    df = pd.read_csv(RMC_R2_CSV, index_col=0)
    n_rows, n_cols = df.shape
    # Basic robustness: reject anything that is not square.
    return df if n_rows == n_cols else None
def plot_r2_heatmap_to_tmp(df: pd.DataFrame, title="Reference R² Heatmap", vmax=1.0):
    """Render the R² matrix as a heatmap and return (png_path, csv_path).

    The PNG goes to a fresh temp file. For the CSV download we reuse the
    repo file when it exists (gr.File can serve repo files); otherwise a
    temp copy of *df* is written.
    """
    # Use the platform temp dir instead of a hard-coded "/tmp" so the app
    # also works on non-POSIX hosts.
    fd, png_path = tempfile.mkstemp(suffix=".png", prefix="rmc_r2_")
    os.close(fd)
    plt.figure(figsize=(10, 9))
    im = plt.imshow(df.values, aspect="auto", vmin=np.nanmin(df.values), vmax=vmax)
    plt.xticks(range(len(df.columns)), df.columns, rotation=90, fontsize=6)
    plt.yticks(range(len(df.index)), df.index, fontsize=6)
    cbar = plt.colorbar(im)
    cbar.ax.tick_params(labelsize=8)
    plt.title(title, fontsize=12, pad=10)
    plt.tight_layout()
    plt.savefig(png_path, dpi=300)
    plt.close()
    if os.path.exists(RMC_R2_CSV):
        csv_out_path = RMC_R2_CSV  # serve the repo file directly
    else:
        fd2, csv_out_path = tempfile.mkstemp(suffix=".csv", prefix="rmc_r2_")
        os.close(fd2)
        # Only write when we created a temp file — never overwrite the
        # repo CSV with its own contents.
        df.to_csv(csv_out_path)
    return png_path, csv_out_path
# By default use every non-zero label present in the segmentation; set to a
# fixed range (e.g. range(3, 215)) to restrict which labels are considered.
USE_LABEL_RANGE = None
# ============ Utility functions ============
def load_json_if_exists(path):
    """Parse *path* as JSON and return the result.

    Returns None when *path* is falsy ("" / None) or the file is missing.
    """
    if not path or not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Bundled resources; each falls back to {} when the file is absent.
ATLAS_LABELS = load_json_if_exists(ATLAS_LABELS_JSON) or {}  # str(label id) -> ROI name
TRAIN_MEANS = load_json_if_exists(BUILTIN_TRAIN_MEANS_JSON) or {}  # ROI -> training mean SUV
TRAIN_COEFS = load_json_if_exists(BUILTIN_TRAIN_COEFS_JSON) or {}  # ROI -> {"coef": [b_age, b_sex], "intercept": b0}
def to_uint8(arr2d, vmin=None, vmax=None):
    """Window a 2-D array into [vmin, vmax] and rescale to uint8 [0, 255].

    When vmin/vmax are omitted they default to the 2nd/98th percentiles,
    which suppresses outliers in the preview slices.
    """
    data = np.nan_to_num(np.asarray(arr2d, dtype=np.float32),
                         nan=0.0, posinf=0.0, neginf=0.0)
    lo = np.percentile(data, 2) if vmin is None else vmin
    hi = np.percentile(data, 98) if vmax is None else vmax
    if not np.isfinite(lo):
        lo = float(np.min(data))
    if not np.isfinite(hi):
        hi = float(np.max(data))
    if hi <= lo:
        hi = lo + 1e-6  # degenerate window: avoid a zero-width range
    scaled = (np.clip(data, lo, hi) - lo) / (hi - lo + 1e-12)
    return (scaled * 255.0).astype(np.uint8)
def roi_name(label_int):
    """Map an integer atlas label to its ROI name; fall back to 'ROI_<n>'."""
    key = str(label_int)
    return ATLAS_LABELS.get(key, f"ROI_{label_int}")
def read_nii(path):
    """Read a NIfTI file at *path* as a SimpleITK Image."""
    return sitk.ReadImage(path)
def img_to_array(img, dtype=np.float32, pick_first_if_4d=True):
    """
    Convert a SimpleITK image to a numpy array.

    SimpleITK.GetArrayFromImage returns axes in (z, y, x) order.
    For 4-D input ((t,z,y,x) or (z,y,x,t)) the first frame is taken when
    pick_first_if_4d is True.
    """
    arr = sitk.GetArrayFromImage(img).astype(dtype, copy=False)
    # 4-D handling: take frame 0, guessing which axis is time.
    if arr.ndim == 4 and pick_first_if_4d:
        # Heuristic: a small leading axis (<= 16) is assumed to be time
        # (t-first is the common layout); otherwise t is assumed last.
        # NOTE(review): this guess fails for volumes with <= 16 slices on
        # the leading spatial axis — confirm against the expected inputs.
        if arr.shape[0] <= 16:
            arr = arr[0]
        else:
            arr = arr[..., 0]
    return arr  # (z, y, x)
def ensure_same_grid(seg_img, ref_img):
    """Return *seg_img* on the voxel grid of *ref_img*.

    If size, spacing, origin and direction already match, the segmentation
    is returned untouched; otherwise it is resampled onto the reference
    grid with nearest-neighbour interpolation (labels must not be blended).
    """
    grids_match = (
        list(seg_img.GetSize()) == list(ref_img.GetSize())
        and np.allclose(seg_img.GetSpacing(), ref_img.GetSpacing())
        and np.allclose(seg_img.GetOrigin(), ref_img.GetOrigin())
        and np.allclose(seg_img.GetDirection(), ref_img.GetDirection())
    )
    if grids_match:
        return seg_img
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(ref_img)
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    resampler.SetOutputPixelType(seg_img.GetPixelID())
    return resampler.Execute(seg_img)
def mid_slices_for_preview_sitk(suv_img):
    """Build the three orthogonal mid-slice previews from the SUV volume.

    The (z, y, x) array is sliced at the center of each axis; slices are
    transposed/flipped so they display upright, then windowed to uint8.
    """
    vol = img_to_array(suv_img, dtype=np.float32)
    zc, yc, xc = (extent // 2 for extent in vol.shape)
    axial = np.flipud(vol[zc, :, :])      # fix z -> (y, x)
    coronal = np.flipud(vol[:, yc, :].T)  # fix y -> (z, x), transpose + flip
    sagittal = np.flipud(vol[:, :, xc].T) # fix x -> (z, y), transpose + flip
    return to_uint8(axial), to_uint8(coronal), to_uint8(sagittal)
def extract_roi_means_sitk(suv_img, seg_img, pbar: gr.Progress):
    """Extract per-ROI mean/max SUV and a histogram with SimpleITK.

    The segmentation is first resampled onto the SUV grid. Histogram bins
    are 0.1-wide from 0 to 30 plus one overflow bin to +inf.

    Args:
        suv_img: SUV volume (SimpleITK Image).
        seg_img: integer-label segmentation (SimpleITK Image).
        pbar: Gradio progress callback.

    Returns:
        dict[roi_name] = {"meansuv": float, "maxsuv": float, "hist": list[int]}

    Raises:
        ValueError: if the volumes still disagree in shape after resampling.
    """
    # Align the segmentation to the SUV grid first.
    seg_img = ensure_same_grid(seg_img, suv_img)
    suv = img_to_array(suv_img, dtype=np.float32)  # (z, y, x)
    seg = img_to_array(seg_img, dtype=np.float32)  # may be float; cast to int labels below
    suv = np.nan_to_num(suv, nan=0.0, posinf=0.0, neginf=0.0)
    seg = np.rint(seg).astype(np.int32, copy=False)
    if suv.shape != seg.shape:
        raise ValueError(f"Shape mismatch after resample: SUV {suv.shape} vs SEG {seg.shape}")
    labels = sorted(np.unique(seg).astype(int).tolist())
    if USE_LABEL_RANGE is not None:
        labels = [lab for lab in labels if lab in USE_LABEL_RANGE]
    labels = [lab for lab in labels if lab != 0]  # 0 is background
    out = {}
    total = len(labels)
    # 0.1-wide bins on [0, 30] plus a +inf overflow bin.
    bins = np.append(np.arange(0, 30.1, 0.1), np.inf)
    for i, lab in enumerate(labels, start=1):
        pbar((i, total), desc=f"Extract ROI {lab}")
        mask = (seg == lab)
        if not np.any(mask):
            # Labels come from np.unique, so this should never trigger;
            # kept as a cheap guard. (The previous vals.size check after
            # this guard was unreachable and has been removed.)
            continue
        vals = suv[mask]
        hist, _ = np.histogram(vals, bins=bins)
        out[roi_name(lab)] = {
            "meansuv": float(np.mean(vals)),
            "maxsuv": float(np.max(vals)),
            "hist": [int(h) for h in hist],
        }
    return out
def fill_missing_with_train_mean(target_dict, region_list):
    """Return {region: {"meansuv": value}} for every region in *region_list*.

    The subject's own mean SUV (from *target_dict*, a single subject's
    {roi_name: {"meansuv": ...}} mapping) is used when present; otherwise
    the built-in training mean (0.0 if unknown) is substituted.
    """
    filled = {}
    for region in region_list:
        entry = target_dict.get(region)
        if entry is not None and "meansuv" in entry:
            value = float(entry["meansuv"])
        else:
            value = float(TRAIN_MEANS.get(region, 0.0))
        filled[region] = {"meansuv": value}
    return filled
def build_subject_row(filled_roi_dict, age, sex):
    """Assemble one subject's feature row as a pandas Series.

    'sex' is encoded as 1.0 for anything starting with 'M'/'m' (male) and
    0.0 otherwise; every ROI contributes its mean SUV as a float.
    """
    is_male = str(sex).upper().startswith("M")
    row = {"age": float(age), "sex": 1.0 if is_male else 0.0}
    row.update((region, float(entry["meansuv"]))
               for region, entry in filled_roi_dict.items())
    return pd.Series(row)
def split_brain_body_regions(REGIONS):
    """Split *REGIONS* into (brain_regions, body_regions).

    Brain regions are recognised by naming markers ('L-'/'R-' hemisphere
    prefixes plus two midline structures); everything else except the
    reference regions (BRAIN_REF, BODY_REF) counts as body.
    """
    brain_markers = ('L-', 'R-', 'Third ventricle', 'Corpus callosum')
    brain_regions = [region for region in REGIONS
                     if any(marker in region for marker in brain_markers)]
    excluded = set(brain_regions) | {BODY_REF} | set(BRAIN_REF)
    body_regions = [region for region in REGIONS if region not in excluded]
    return brain_regions, body_regions
def add_rmeansuv(df, REGIONS):
    """Compute reference-normalised mean SUV ('rmeanSUV') features.

    Brain ROIs are divided by the mean over the brain reference region(s)
    (Brainstem); body ROIs are divided by the liver value. Returns a new
    DataFrame holding age/sex plus the *_rel columns.
    """
    brain_regions, body_regions = split_brain_body_regions(REGIONS)
    # Brain: relative to Brainstem (mean over possibly several refs).
    brain_rel = df[brain_regions].div(df[BRAIN_REF].mean(axis=1), axis=0)
    brain_rel.columns = [f"{name}_rel" for name in brain_regions]
    # Body: relative to the liver.
    body_rel = df[body_regions].div(df[BODY_REF], axis=0)
    body_rel.columns = [f"{name}_rel" for name in body_regions]
    return pd.concat([df[['age', 'sex']], brain_rel, body_rel], axis=1)
def apply_level1(test_rel_df):
    """Residualize each *_rel column against age and sex.

    Uses the built-in TRAIN_COEFS (per region: {'coef': [b_age, b_sex],
    'intercept': b0}). Columns without stored coefficients pass through
    unchanged. Returns a DataFrame of residuals keyed by region name
    (the '_rel' suffix stripped).
    """
    raw = {}
    rel_cols = [c for c in test_rel_df.columns if c.endswith("_rel")]
    for col in rel_cols:
        region = col[:-len("_rel")]
        y = test_rel_df[col].values
        entry = TRAIN_COEFS.get(region)
        if not entry:
            # No trained coefficients: keep the raw relative value.
            raw[region] = y
            continue
        beta_age, beta_sex = entry["coef"]
        predicted = (beta_age * test_rel_df["age"].values
                     + beta_sex * test_rel_df["sex"].values
                     + entry["intercept"])
        raw[region] = y - predicted
    return pd.DataFrame(raw, index=test_rel_df.index)
# ============== Main pipeline ==============
def pipeline_demo(age, sex, progress=gr.Progress(track_tqdm=False)):
    """Run the full pipeline on the bundled demo SUV/segmentation volumes."""
    return pipeline(DEMO_SUV, DEMO_SEG, age, sex, progress)
def pipeline(suv_path, seg_path, age, sex, progress=gr.Progress(track_tqdm=False)):
    """End-to-end analysis for one subject.

    Steps: load NIfTI volumes -> preview slices -> ROI metrics ->
    rmeanSUV -> age/sex residualization -> reference-connectome heatmap ->
    individual metabolic deviation (IMD) heatmap -> feature CSV.

    Returns a 10-tuple matching the Gradio outputs:
    (axial, coronal, sagittal, info, features_csv, preview_df,
     r2_png, r2_csv, imd_png, imd_csv).
    """

    def _early(axial, cor, sag, message):
        # BUGFIX: every exit must have the same arity as the 10 declared
        # Gradio outputs — the old 6-tuple early returns crashed the UI.
        return axial, cor, sag, message, None, None, None, None, None, None

    if not suv_path or not seg_path:
        return _early(None, None, None, "Please upload both SUV and segmentation.")
    progress(0, desc="Loading NIfTI (SimpleITK)")
    suv_img = read_nii(suv_path)
    seg_img = read_nii(seg_path)
    # Three orthogonal preview slices from the SUV volume; preview failure
    # is non-fatal (best effort).
    progress(0.05, desc="Preparing preview")
    try:
        axial_u8, cor_u8, sag_u8 = mid_slices_for_preview_sitk(suv_img)
    except Exception:
        axial_u8 = cor_u8 = sag_u8 = None
    # 1) Per-ROI mean/max SUV and histograms.
    progress(0.15, desc="Extracting ROI metrics")
    roi_metrics = extract_roi_means_sitk(suv_img, seg_img, progress)
    # 2) REGIONS = union of built-in training ROIs and the ROIs seen here.
    REGIONS = sorted(set(TRAIN_MEANS) | set(roi_metrics))
    # The reference regions must be available for normalisation.
    missing_refs = [r for r in (BRAIN_REF + [BODY_REF]) if r not in REGIONS]
    if missing_refs:
        info = f"Missing reference regions in TRAIN_MEANS/seg: {missing_refs}"
        return _early(axial_u8, cor_u8, sag_u8, info)
    # 3) Fill ROIs missing in this subject with training means.
    progress(0.40, desc="Filling missing ROIs with train means")
    filled = fill_missing_with_train_mean(roi_metrics, REGIONS)
    # 4) Single-row DataFrame for this subject.
    subj_row = build_subject_row(filled, age, sex)
    df = pd.DataFrame([subj_row])
    # 5) Reference-normalised features.
    progress(0.50, desc="Computing rmeanSUV")
    try:
        df_rel = add_rmeansuv(df, REGIONS)
    except Exception as e:
        return _early(axial_u8, cor_u8, sag_u8, f"[rmeanSUV error] {e}")
    # 6) Age/sex residualization when coefficients are bundled;
    #    otherwise fall back to the raw *_rel features.
    progress(0.60, desc="Applying age/sex residualization")
    if TRAIN_COEFS:
        df_resid = apply_level1(df_rel)
        result_kind = "residuals"
        feature_df = df_resid
    else:
        result_kind = "rel"
        feature_df = df_rel[[c for c in df_rel.columns if c.endswith("_rel")]]
    # 7) Reference connectome (R²) heatmap + CSV, when bundled.
    progress(0.70, desc="Reference Connectome Visualization")
    r2_df = load_rmc_r2_df()
    if r2_df is not None:
        r2_png_path, r2_csv_path = plot_r2_heatmap_to_tmp(r2_df, title="Reference R² Heatmap", vmax=1.0)
    else:
        r2_png_path, r2_csv_path = None, None
    # 8) Single-subject IMD, only when residuals and IMD params both exist;
    #    otherwise the IMD outputs simply stay None.
    progress(0.80, desc="Individual Metabolic Visualization")
    imd_png_path, imd_csv_path = None, None
    imd_params = load_imd_params()
    if result_kind == "residuals" and imd_params is not None:
        slope, interc, sigma, regions_order = imd_params
        resid_series = df_resid.iloc[0]  # single subject: take the only row
        imd_df = compute_imd_matrix_for_subject(resid_series, slope, interc, sigma, regions_order)
        imd_png_path, imd_csv_path = save_imd_heatmap_and_csv(imd_df, title="IMD Heatmap (Studentized Residuals)")
    # Output CSV: age/sex plus the computed features.
    out_df = pd.concat([df[['age', 'sex']].reset_index(drop=True),
                        feature_df.reset_index(drop=True)], axis=1)
    fd, csv_path = tempfile.mkstemp(suffix=".csv", prefix=f"features_{result_kind}_")
    os.close(fd)
    out_df.to_csv(csv_path, index=False)
    info = (f"ROIs extracted: {len(roi_metrics)} / {len(REGIONS)}\n"
            f"Output columns: {out_df.shape[1]} (features kind: {result_kind})\n"
            f"Preview from SUV (mid slices).")
    progress(1.0, desc="Done")
    return axial_u8, cor_u8, sag_u8, info, csv_path, out_df.head(20), r2_png_path, r2_csv_path, imd_png_path, imd_csv_path
# ============== Gradio UI ==============
with gr.Blocks(title="Total-Body 18F-FDG PET — IMD Features (SimpleITK)", theme=gr.themes.Soft()) as demo:
    # Typo fixed in the user-facing title: "Indivisual" -> "Individual".
    gr.Markdown("# Total-Body 18F-FDG PET — Individual Metabolic Deviation Network\n"
                "Upload **SUVPET** and **segmentation**")
    with gr.Row():
        # One-click demo run using the bundled example volumes.
        run_demo = gr.Button("Run demo (one click) or Upload SUV NIFTI & Segment NIFTI")
    with gr.Row():
        suv_in = gr.File(file_types=[".nii", ".nii.gz"], type="filepath", label="Upload SUV NIfTI")
        seg_in = gr.File(file_types=[".nii", ".nii.gz"], type="filepath", label="Upload Segmentation NIfTI")
    with gr.Row():
        age_in = gr.Number(value=65, label="Age")
        sex_in = gr.Dropdown(choices=["M", "F"], value="M", label="Sex")
    run = gr.Button("Analyze", variant="primary")
    with gr.Tabs():
        with gr.Tab("Axial"): axial_im = gr.Image(type="numpy")
        with gr.Tab("Coronal"): cor_im = gr.Image(type="numpy")
        with gr.Tab("Sagittal"): sag_im = gr.Image(type="numpy")
    info_box = gr.Textbox(label="Process Log", lines=5)
    csv_out = gr.File(label="Download Features CSV")
    preview = gr.Dataframe(label="Feature preview (top rows)", interactive=False)
    with gr.Tab("Reference Metabolic Connectome (R²)"):
        r2_img = gr.Image(type="filepath", label="R² Heatmap (Reference)")
        r2_csv = gr.File(label="Download R² CSV")
    with gr.Tab("Individual Metabolic Deviation (IMD)"):
        imd_img = gr.Image(type="filepath", label="IMD Heatmap")
        imd_csv = gr.File(label="Download IMD CSV")
    # Both triggers feed the same 10 output components, so `pipeline` and
    # `pipeline_demo` must always return 10-tuples.
    run.click(
        fn=pipeline,
        inputs=[suv_in, seg_in, age_in, sex_in],
        outputs=[axial_im, cor_im, sag_im, info_box, csv_out, preview, r2_img, r2_csv, imd_img, imd_csv]
    )
    run_demo.click(
        fn=pipeline_demo,
        inputs=[age_in, sex_in],
        outputs=[axial_im, cor_im, sag_im, info_box, csv_out, preview, r2_img, r2_csv, imd_img, imd_csv]
    )
if __name__ == "__main__":
    demo.launch()