| import os, json, csv, tempfile |
| import numpy as np |
| import pandas as pd |
| import SimpleITK as sitk |
| import gradio as gr |
|
|
| |
| |
| |
| |
# Bundled demo volumes used by the one-click demo button.
DEMO_SUV = "SUVPET_demo.nii.gz"
DEMO_SEG = "segment_demo.nii.gz"


# Fitted IMD regression parameters: pairwise slope/intercept/sigma matrices
# plus the ROI ordering they were fitted with.
IMD_SLOPE_CSV = "imd_slope.csv"
IMD_INTERCEPT_CSV = "imd_intercept.csv"
IMD_SIGMA_CSV = "imd_sigma.csv"
IMD_REGIONS_JSON = "imd_regions.json"
|
|
def load_imd_params():
    """Load the fitted IMD parameters and ROI ordering from disk.

    Reads the slope/intercept/sigma matrices (CSV, ROI-labeled rows and
    columns) and the ROI order (JSON list).

    Returns
    -------
    tuple | None
        ``(slope, interc, sigma, regions_order)`` when all four files
        exist, otherwise ``None``.

    Raises
    ------
    ValueError
        If the three matrices disagree in shape, or slope is not a square
        matrix with identical row/column labels.
    """
    paths = (IMD_SLOPE_CSV, IMD_INTERCEPT_CSV, IMD_SIGMA_CSV, IMD_REGIONS_JSON)
    if not all(os.path.exists(p) for p in paths):
        return None
    slope = pd.read_csv(IMD_SLOPE_CSV, index_col=0)
    interc = pd.read_csv(IMD_INTERCEPT_CSV, index_col=0)
    sigma = pd.read_csv(IMD_SIGMA_CSV, index_col=0)
    with open(IMD_REGIONS_JSON, "r", encoding="utf-8") as f:
        regions_order = json.load(f)

    # Validate with real exceptions: `assert` is silently stripped under
    # `python -O`, which would let malformed matrices through.
    if not (slope.shape == interc.shape == sigma.shape):
        raise ValueError("IMD matrices shape mismatch")
    if list(slope.index) != list(slope.columns):
        raise ValueError("slope must be square with labels")
    return slope, interc, sigma, regions_order
|
|
def compute_imd_matrix_for_subject(resid_series: pd.Series, slope: pd.DataFrame, interc: pd.DataFrame, sigma: pd.DataFrame, regions_order: list[str]):
    """Compute one subject's IMD matrix of studentized residuals.

    Parameters
    ----------
    resid_series : pd.Series
        Residuals for a single subject, indexed by ROI name (matching
        ``df_resid.columns``); ROIs missing from the series count as 0.
    slope, interc, sigma : pd.DataFrame
        Pairwise regression parameters; each is reindexed to
        ``regions_order`` x ``regions_order`` before use.
    regions_order : list[str]
        ROI ordering used for both axes of the output.

    Returns
    -------
    pd.DataFrame
        N x N float32 matrix with entry
        (i, j) = |x_i - (b0_ij + b1_ij * x_j)| / sigma_ij,
        clipped to [0, 4]; cells with sigma == 0 and the diagonal are 0.
    """
    x = resid_series.reindex(regions_order).fillna(0.0).astype(float).values
    n = len(regions_order)

    # Align all parameter matrices to the requested ROI order.
    b1 = slope.reindex(index=regions_order, columns=regions_order).values
    b0 = interc.reindex(index=regions_order, columns=regions_order).values
    sg = sigma.reindex(index=regions_order, columns=regions_order).values

    # Vectorized studentized residual (replaces the former per-row Python
    # loop): row i holds the observed value x[i], column j the predictor x[j].
    yhat = b0 + b1 * x[None, :]
    abs_resid = np.abs(x[:, None] - yhat)

    sr_mat = np.zeros((n, n), dtype=np.float32)
    safe = sg != 0  # avoid division by zero; those cells stay 0
    sr_mat[safe] = (abs_resid[safe] / sg[safe]).astype(np.float32)

    # Clip once at the end (the original clipped the whole matrix on every
    # loop iteration — same result, redundant work).
    np.clip(sr_mat, 0.0, 4.0, out=sr_mat)  # cap extreme deviations at 4 SD
    np.fill_diagonal(sr_mat, 0.0)  # self-deviation carries no information

    return pd.DataFrame(sr_mat, index=regions_order, columns=regions_order)
|
|
| |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
def save_imd_heatmap_and_csv(imd_df: pd.DataFrame, title="IMD (Studentized Residuals)", vmax=None):
    """Render the IMD matrix as a heatmap PNG and dump it as CSV.

    Both files are created via ``tempfile.mkstemp`` in the system temp
    directory (portable — the previous hard-coded ``/tmp`` fails on
    platforms without that directory, e.g. Windows).

    Parameters
    ----------
    imd_df : pd.DataFrame
        Square IMD matrix (ROI x ROI).
    title : str
        Plot title.
    vmax : float | None
        Upper color limit; defaults to the 99th percentile of the data.

    Returns
    -------
    (png_path, csv_path) : tuple[str, str]
    """
    fd_png, png_path = tempfile.mkstemp(suffix=".png", prefix="imd_")
    os.close(fd_png)
    fd_csv, csv_path = tempfile.mkstemp(suffix=".csv", prefix="imd_")
    os.close(fd_csv)

    imd_df.to_csv(csv_path)

    plt.figure(figsize=(10, 9))
    vmin = np.nanmin(imd_df.values)
    if vmax is None:
        # 99th percentile keeps a few extreme cells from washing out the scale.
        vmax = np.nanpercentile(imd_df.values, 99) if np.isfinite(vmin) else None
    im = plt.imshow(imd_df.values, aspect="auto", vmin=vmin if np.isfinite(vmin) else None, vmax=vmax)
    plt.xticks(range(len(imd_df.columns)), imd_df.columns, rotation=90, fontsize=6)
    plt.yticks(range(len(imd_df.index)), imd_df.index, fontsize=6)
    cbar = plt.colorbar(im)
    cbar.ax.tick_params(labelsize=8)
    plt.title(title, fontsize=12, pad=10)
    plt.tight_layout()
    plt.savefig(png_path, dpi=300)
    plt.close()
    return png_path, csv_path
|
|
|
|
# JSON lookup tables shipped alongside the app.
ATLAS_LABELS_JSON = "atlas_labels.json"
BUILTIN_TRAIN_MEANS_JSON = "builtin_train_means.json"
BUILTIN_TRAIN_COEFS_JSON = "train_coefs.json"


# Reference regions used for SUV normalization (brain vs. body).
BRAIN_REF = ['Brainstem']
BODY_REF = 'liver'


# Reference metabolic connectome R² matrix (square, ROI-labeled CSV).
RMC_R2_CSV = "rmc_r2.csv"
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
def load_rmc_r2_df():
    """Load the reference R² matrix; None when the file is missing or non-square."""
    if not os.path.exists(RMC_R2_CSV):
        return None
    matrix = pd.read_csv(RMC_R2_CSV, index_col=0)
    rows, cols = matrix.shape
    return matrix if rows == cols else None
|
|
def plot_r2_heatmap_to_tmp(df: pd.DataFrame, title="Reference R² Heatmap", vmax=1.0):
    """Render the R² matrix as a heatmap PNG; return (png_path, csv_path).

    The PNG goes to a fresh temp file in the system temp directory
    (portable — the previous hard-coded ``/tmp`` fails on Windows). The
    returned CSV path is the existing ``RMC_R2_CSV`` when that file is on
    disk; otherwise the matrix is written to a new temp CSV.
    """
    fd, png_path = tempfile.mkstemp(suffix=".png", prefix="rmc_r2_")
    os.close(fd)
    plt.figure(figsize=(10, 9))
    im = plt.imshow(df.values, aspect="auto", vmin=np.nanmin(df.values), vmax=vmax)
    plt.xticks(range(len(df.columns)), df.columns, rotation=90, fontsize=6)
    plt.yticks(range(len(df.index)), df.index, fontsize=6)
    cbar = plt.colorbar(im)
    cbar.ax.tick_params(labelsize=8)
    plt.title(title, fontsize=12, pad=10)
    plt.tight_layout()
    plt.savefig(png_path, dpi=300)
    plt.close()

    # Prefer handing back the canonical on-disk CSV so downloads match it.
    if os.path.exists(RMC_R2_CSV):
        csv_out_path = RMC_R2_CSV
    else:
        fd2, csv_out_path = tempfile.mkstemp(suffix=".csv", prefix="rmc_r2_")
        os.close(fd2)
        df.to_csv(csv_out_path)
    return png_path, csv_out_path
|
|
|
|
| |
# Optional whitelist of atlas label ints to process; None keeps all labels.
USE_LABEL_RANGE = None
|
|
| |
def load_json_if_exists(path):
    """Parse *path* as UTF-8 JSON; return None when path is falsy or absent."""
    if not path or not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
# Built-in lookup tables loaded once at import time; empty dicts when the
# corresponding JSON files are absent.
ATLAS_LABELS = load_json_if_exists(ATLAS_LABELS_JSON) or {}
TRAIN_MEANS = load_json_if_exists(BUILTIN_TRAIN_MEANS_JSON) or {}
TRAIN_COEFS = load_json_if_exists(BUILTIN_TRAIN_COEFS_JSON) or {}
|
|
|
|
def to_uint8(arr2d, vmin=None, vmax=None):
    """Window a 2-D array to [vmin, vmax] and rescale to uint8 in [0, 255].

    NaN/±inf are replaced by 0 first. When a bound is not given it defaults
    to the 2nd/98th percentile of the cleaned data; non-finite bounds fall
    back to the data min/max, and a degenerate window is widened slightly.
    """
    data = np.nan_to_num(np.asarray(arr2d, dtype=np.float32),
                         nan=0.0, posinf=0.0, neginf=0.0)
    lo = np.percentile(data, 2) if vmin is None else vmin
    hi = np.percentile(data, 98) if vmax is None else vmax
    if not np.isfinite(lo):
        lo = float(np.min(data))
    if not np.isfinite(hi):
        hi = float(np.max(data))
    if hi <= lo:
        hi = lo + 1e-6  # constant image: avoid a zero-width window
    data = np.clip(data, lo, hi)
    data = (data - lo) / (hi - lo + 1e-12)
    return (data * 255.0).astype(np.uint8)
|
|
def roi_name(label_int):
    """Human-readable name for an atlas label; falls back to 'ROI_<n>'."""
    key = str(label_int)
    if key in ATLAS_LABELS:
        return ATLAS_LABELS[key]
    return f"ROI_{label_int}"
|
|
def read_nii(path):
    """Read a NIfTI file at *path* as a SimpleITK Image."""
    return sitk.ReadImage(path)
|
|
def img_to_array(img, dtype=np.float32, pick_first_if_4d=True):
    """Convert a SimpleITK image to a numpy array.

    ``sitk.GetArrayFromImage`` yields (z, y, x) axis order. When the input
    is 4-D and *pick_first_if_4d* is true, the extra dimension is dropped:
    a small leading axis (<= 16) is treated as the time/channel axis and
    its first volume is taken; otherwise the last axis is sliced.
    """
    arr = sitk.GetArrayFromImage(img).astype(dtype, copy=False)
    if arr.ndim != 4 or not pick_first_if_4d:
        return arr
    # Heuristic: a short leading axis is almost certainly the extra dimension.
    return arr[0] if arr.shape[0] <= 16 else arr[..., 0]
|
|
def ensure_same_grid(seg_img, ref_img):
    """Return *seg_img* on *ref_img*'s voxel grid.

    The grids match when size, spacing, origin and direction are all equal;
    otherwise the segmentation is resampled onto the reference grid with
    nearest-neighbour interpolation (label values must not be blended).
    """
    grids_match = (
        list(seg_img.GetSize()) == list(ref_img.GetSize())
        and np.allclose(seg_img.GetSpacing(), ref_img.GetSpacing())
        and np.allclose(seg_img.GetOrigin(), ref_img.GetOrigin())
        and np.allclose(seg_img.GetDirection(), ref_img.GetDirection())
    )
    if grids_match:
        return seg_img

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(ref_img)
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    resampler.SetOutputPixelType(seg_img.GetPixelID())
    return resampler.Execute(seg_img)
|
|
def mid_slices_for_preview_sitk(suv_img):
    """Extract the three orthogonal mid slices of the SUV volume as uint8.

    The SimpleITK array is (z, y, x); each slice is transposed/flipped so
    the previews read intuitively. Returns (axial, coronal, sagittal).
    """
    volume = img_to_array(suv_img, dtype=np.float32)
    z_c, y_c, x_c = (extent // 2 for extent in volume.shape)

    axial = np.flipud(volume[z_c, :, :])
    coronal = np.flipud(volume[:, y_c, :].T)
    sagittal = np.flipud(volume[:, :, x_c].T)

    return to_uint8(axial), to_uint8(coronal), to_uint8(sagittal)
|
|
def extract_roi_means_sitk(suv_img, seg_img, pbar: gr.Progress):
    """Per-ROI SUV statistics via SimpleITK.

    Aligns the segmentation to the SUV grid, then for every non-background
    label computes mean/max SUV and a fixed-bin histogram (0–30 in 0.1
    steps plus an overflow bin). Progress is reported through *pbar*.

    Returns dict[roi_name] -> {"meansuv": float, "maxsuv": float, "hist": list[int]}.
    """
    seg_img = ensure_same_grid(seg_img, suv_img)

    suv = np.nan_to_num(img_to_array(suv_img, dtype=np.float32),
                        nan=0.0, posinf=0.0, neginf=0.0)
    seg = np.rint(img_to_array(seg_img, dtype=np.float32)).astype(np.int32, copy=False)

    if suv.shape != seg.shape:
        raise ValueError(f"Shape mismatch after resample: SUV {suv.shape} vs SEG {seg.shape}")

    labels = sorted(np.unique(seg).astype(int).tolist())
    if USE_LABEL_RANGE is not None:
        labels = [lab for lab in labels if lab in USE_LABEL_RANGE]
    labels = [lab for lab in labels if lab != 0]  # label 0 is background

    bins = np.append(np.arange(0, 30.1, 0.1), np.inf)
    total = len(labels)
    results = {}

    for idx, lab in enumerate(labels, start=1):
        pbar((idx, total), desc=f"Extract ROI {lab}")
        vals = suv[seg == lab]
        if vals.size == 0:
            continue
        counts, _ = np.histogram(vals, bins=bins)
        results[roi_name(lab)] = {
            "meansuv": float(np.mean(vals)),
            "maxsuv": float(np.max(vals)),
            "hist": [int(c) for c in counts],
        }
    return results
|
|
def fill_missing_with_train_mean(target_dict, region_list):
    """Build {region: {"meansuv": value}} for every region in *region_list*.

    *target_dict* is one subject's {roi_name: {"meansuv": ...}}. Regions
    absent from it (or lacking a "meansuv" key) fall back to the
    training-set mean, defaulting to 0.0 for regions unknown to TRAIN_MEANS.
    """
    def _mean_for(region):
        entry = target_dict.get(region)
        if entry and "meansuv" in entry:
            return float(entry["meansuv"])
        return float(TRAIN_MEANS.get(region, 0.0))

    return {region: {"meansuv": _mean_for(region)} for region in region_list}
|
|
def build_subject_row(filled_roi_dict, age, sex):
    """Flatten a subject's covariates and ROI means into one pd.Series.

    Sex is encoded as 1.0 for any value starting with 'M'/'m', else 0.0.
    """
    is_male = str(sex).upper().startswith("M")
    values = {"age": float(age), "sex": 1.0 if is_male else 0.0}
    values.update((region, float(info["meansuv"]))
                  for region, info in filled_roi_dict.items())
    return pd.Series(values)
|
|
def split_brain_body_regions(REGIONS):
    """Partition REGIONS into (brain_regions, body_regions).

    Brain regions are identified by substring markers in the ROI name; the
    reference regions (BRAIN_REF, BODY_REF) are excluded from the body list.
    """
    brain_markers = ('L-', 'R-', 'Third ventricle', 'Corpus callosum')
    brain_regions = [r for r in REGIONS if any(m in r for m in brain_markers)]
    excluded = set(brain_regions) | set(BRAIN_REF) | {BODY_REF}
    body_regions = [r for r in REGIONS if r not in excluded]
    return brain_regions, body_regions
|
|
def add_rmeansuv(df, REGIONS):
    """Normalize ROI SUVs by reference regions and return the *_rel frame.

    Brain regions are divided row-wise by the mean of the BRAIN_REF
    columns; body regions by the BODY_REF column. age/sex are carried over
    unchanged; every normalized column is suffixed with "_rel".
    """
    brain_regions, body_regions = split_brain_body_regions(REGIONS)

    brain_reference = df[BRAIN_REF].mean(axis=1)
    brain_rel = df[brain_regions].div(brain_reference, axis=0)
    brain_rel.columns = [f"{name}_rel" for name in brain_regions]

    body_rel = df[body_regions].div(df[BODY_REF], axis=0)
    body_rel.columns = [f"{name}_rel" for name in body_regions]

    return pd.concat([df[['age', 'sex']], brain_rel, body_rel], axis=1)
|
|
def apply_level1(test_rel_df):
    """Residualize each *_rel column against age/sex using TRAIN_COEFS.

    Columns whose region has no stored coefficients are passed through
    unchanged. The returned DataFrame is keyed by region name (the "_rel"
    suffix stripped) and shares *test_rel_df*'s index.
    """
    residuals = {}
    rel_cols = [c for c in test_rel_df.columns if c.endswith("_rel")]
    for col in rel_cols:
        region = col[:-4]  # strip the "_rel" suffix
        params = TRAIN_COEFS.get(region)
        observed = test_rel_df[col].values
        if not params:
            residuals[region] = observed
            continue
        beta_age, beta_sex = params["coef"]
        predicted = (beta_age * test_rel_df["age"].values
                     + beta_sex * test_rel_df["sex"].values
                     + params["intercept"])
        residuals[region] = observed - predicted
    return pd.DataFrame(residuals, index=test_rel_df.index)
|
|
| |
def pipeline_demo(age, sex, progress=gr.Progress(track_tqdm=False)):
    """Run the full pipeline on the bundled demo SUV/segmentation volumes."""
    return pipeline(DEMO_SUV, DEMO_SEG, age, sex, progress)
|
|
def pipeline(suv_path, seg_path, age, sex, progress=gr.Progress(track_tqdm=False)):
    """End-to-end feature pipeline for one subject.

    Steps: load NIfTI volumes, build a three-plane preview, extract per-ROI
    SUV metrics, fill missing ROIs with training means, compute
    reference-normalized SUVs (rmeanSUV), residualize against age/sex when
    TRAIN_COEFS is available, and render the reference-R² and IMD heatmaps.

    Returns
    -------
    tuple
        Exactly 10 items matching the Gradio outputs list:
        (axial, coronal, sagittal, info, features_csv_path, preview_df,
         r2_png, r2_csv, imd_png, imd_csv).

    Note: every return path MUST yield 10 items — the click handlers bind
    10 output components, and shorter tuples make Gradio error out. (The
    previous early-return paths returned only 6.)
    """
    if not suv_path or not seg_path:
        return None, None, None, "Please upload both SUV and segmentation.", None, None, None, None, None, None

    progress(0, desc="Loading NIfTI (SimpleITK)")
    suv_img = read_nii(suv_path)
    seg_img = read_nii(seg_path)

    progress(0.05, desc="Preparing preview")
    try:
        axial_u8, cor_u8, sag_u8 = mid_slices_for_preview_sitk(suv_img)
    except Exception:
        # Preview is cosmetic; the pipeline can proceed without it.
        axial_u8 = cor_u8 = sag_u8 = None

    progress(0.15, desc="Extracting ROI metrics")
    roi_metrics = extract_roi_means_sitk(suv_img, seg_img, progress)

    # Union of training ROIs and ROIs present in this segmentation.
    REGIONS = sorted(set(TRAIN_MEANS) | set(roi_metrics))

    missing_refs = [r for r in (BRAIN_REF + [BODY_REF]) if r not in REGIONS]
    if missing_refs:
        info = f"Missing reference regions in TRAIN_MEANS/seg: {missing_refs}"
        return axial_u8, cor_u8, sag_u8, info, None, None, None, None, None, None

    progress(0.40, desc="Filling missing ROIs with train means")
    filled = fill_missing_with_train_mean(roi_metrics, REGIONS)

    subj_row = build_subject_row(filled, age, sex)
    df = pd.DataFrame([subj_row])

    progress(0.50, desc="Computing rmeanSUV")
    try:
        df_rel = add_rmeansuv(df, REGIONS)
    except Exception as e:
        return axial_u8, cor_u8, sag_u8, f"[rmeanSUV error] {e}", None, None, None, None, None, None

    progress(0.60, desc="Applying age/sex residualization")
    if TRAIN_COEFS:
        df_resid = apply_level1(df_rel)
        result_kind = "residuals"
        feature_df = df_resid
    else:
        result_kind = "rel"
        feature_df = df_rel[[c for c in df_rel.columns if c.endswith("_rel")]]

    progress(0.70, desc="Reference Connectome Visualization")
    r2_df = load_rmc_r2_df()
    if r2_df is not None:
        r2_png_path, r2_csv_path = plot_r2_heatmap_to_tmp(r2_df, title="Reference R² Heatmap", vmax=1.0)
    else:
        r2_png_path, r2_csv_path = None, None

    progress(0.80, desc="Indivisual Metabolic Visualization")
    imd_png_path, imd_csv_path = None, None
    imd_params = load_imd_params()
    # IMD needs both the fitted parameters on disk and residualized features;
    # otherwise the IMD outputs simply stay None.
    if result_kind == "residuals" and imd_params is not None:
        slope, interc, sigma, regions_order = imd_params
        resid_series = df_resid.iloc[0]
        imd_df = compute_imd_matrix_for_subject(resid_series, slope, interc, sigma, regions_order)
        imd_png_path, imd_csv_path = save_imd_heatmap_and_csv(imd_df, title="IMD Heatmap (Studentized Residuals)")

    out_df = pd.concat([df[['age', 'sex']].reset_index(drop=True), feature_df.reset_index(drop=True)], axis=1)
    fd, csv_path = tempfile.mkstemp(suffix=".csv", prefix=f"features_{result_kind}_")
    os.close(fd)
    out_df.to_csv(csv_path, index=False)

    info = (f"ROIs extracted: {len(roi_metrics)} / {len(REGIONS)}\n"
            f"Output columns: {out_df.shape[1]} (features kind: {result_kind})\n"
            f"Preview from SUV (mid slices).")

    progress(1.0, desc="Done")
    return axial_u8, cor_u8, sag_u8, info, csv_path, out_df.head(20), r2_png_path, r2_csv_path, imd_png_path, imd_csv_path
|
|
| |
# --- Gradio UI wiring ---
with gr.Blocks(title="Total-Body 18F-FDG PET — IMD Features (SimpleITK)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Total-Body 18F-FDG PET — Indivisual Metabolic Deviation Network\n"
                "Upload **SUVPET** and **segmentation**")

    # One-click demo runner (uses the bundled demo volumes).
    with gr.Row():
        run_demo = gr.Button("Run demo (one click) or Upload SUV NIFTI & Segment NIFTI")

    # Subject inputs: SUV volume + matching segmentation (NIfTI file paths).
    with gr.Row():
        suv_in = gr.File(file_types=[".nii",".nii.gz"], type="filepath", label="Upload SUV NIfTI")
        seg_in = gr.File(file_types=[".nii",".nii.gz"], type="filepath", label="Upload Segmentation NIfTI")

    # Covariates consumed by the age/sex residualization step.
    with gr.Row():
        age_in = gr.Number(value=65, label="Age")
        sex_in = gr.Dropdown(choices=["M","F"], value="M", label="Sex")

    run = gr.Button("Analyze", variant="primary")

    # Three orthogonal mid-slice previews of the SUV volume.
    with gr.Tabs():
        with gr.Tab("Axial"): axial_im = gr.Image(type="numpy")
        with gr.Tab("Coronal"): cor_im = gr.Image(type="numpy")
        with gr.Tab("Sagittal"): sag_im = gr.Image(type="numpy")

    info_box = gr.Textbox(label="Process Log", lines=5)
    csv_out = gr.File(label="Download Features CSV")
    preview = gr.Dataframe(label="Feature preview (top rows)", interactive=False)

    # Reference connectome heatmap + downloadable matrix.
    with gr.Tab("Reference Metabolic Connectome (R²)"):
        r2_img = gr.Image(type="filepath", label="R² Heatmap (Reference)")
        r2_csv = gr.File(label="Download R² CSV")

    # Individual metabolic deviation heatmap + downloadable matrix.
    with gr.Tab("Individual Metabolic Deviation (IMD)"):
        imd_img = gr.Image(type="filepath", label="IMD Heatmap")
        imd_csv = gr.File(label="Download IMD CSV")

    # NOTE(review): both handlers bind 10 output components — verify every
    # return path of pipeline() yields exactly 10 values.
    run.click(
        fn=pipeline,
        inputs=[suv_in, seg_in, age_in, sex_in],
        outputs=[axial_im, cor_im, sag_im, info_box, csv_out, preview, r2_img, r2_csv, imd_img, imd_csv]
    )
    run_demo.click(
        fn=pipeline_demo,
        inputs=[age_in, sex_in],
        outputs=[axial_im, cor_im, sag_im, info_box, csv_out, preview, r2_img, r2_csv, imd_img, imd_csv]
    )


# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|