"""Gradio app (Hugging Face Space) that synthesizes a T2-weighted breast MRI volume.

Three NIfTI inputs (T1/DCE, Sinwas, DWI; assumed to be co-registered) are min-max
normalized, fed slice by slice through the IMPORTANT-Net ``UnetGenerator``, and the
result is saved as a NIfTI volume that reuses the T1 affine/header. A middle-slice
preview (T1 vs. synthesized T2) is shown alongside the download.
"""

import os
import tempfile
import time
from collections import OrderedDict

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import spaces  # provides @spaces.GPU (Hugging Face ZeroGPU)
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure

from generator import UnetGenerator

# The device is deliberately fixed to CUDA (no CPU fallback): without a GPU the
# model move below fails at import time. NOTE(review): on ZeroGPU Spaces, moving
# the model to "cuda" at module level is the usual pattern — confirm for other hosts.
DEVICE = "cuda"
MODEL_REPO_ID = "zhang0319/synthetic_t2"
MODEL_FILENAME = "generator_t2_88.pth"

# Network input size (H, W) for every slice taken along the volume's third axis.
target_h, target_w = 384, 192


def _strip_module_prefix(key):
    """Remove a leading ``module.`` from a state-dict key (only the prefix, not inner matches).

    The prefix is presumably left over from training under DataParallel/DDP.
    """
    prefix = "module."
    return key[len(prefix):] if key.startswith(prefix) else key


def _load_generator():
    """Download the generator checkpoint from the Hub and return it in eval mode on DEVICE."""
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    # SECURITY NOTE: torch.load unpickles the file — only load checkpoints from a trusted repo.
    # Tensors are loaded on CPU; the whole module is moved to DEVICE afterwards, which gives the
    # same final placement without staging a second copy of the weights on the GPU.
    state_dict = torch.load(model_path, map_location="cpu")
    net = UnetGenerator()
    net.load_state_dict(
        OrderedDict((_strip_module_prefix(k), v) for k, v in state_dict.items())
    )
    return net.eval().to(DEVICE)


model = _load_generator()


def normalize(img):
    """Min-max scale ``img`` to [-1, 1] using its global minimum and maximum.

    A constant array (max == min) is returned unchanged to avoid dividing by zero.
    """
    min_val, max_val = img.min(), img.max()
    return 2 * (img - min_val) / (max_val - min_val) - 1 if max_val > min_val else img


def center_crop_pad(img, target_h=384, target_w=192):
    """Crop and/or zero-pad a 2-D slice to ``(target_h, target_w)``.

    Rows are center-cropped when the slice is taller than ``target_h``. Columns keep the
    right-most ``target_w`` columns and are then mirrored left-right. Callers pass an
    already-mirrored slice, so this keeps the left side of the unmirrored image, in its
    original orientation. Any remaining shortfall is zero-padded symmetrically, with the
    odd pixel going to the bottom/right.
    """
    h, w = img.shape

    crop_top = max(h // 2 - target_h // 2, 0)
    crop_bottom = min(crop_top + target_h, h)
    crop_top = max(crop_bottom - target_h, 0)

    crop_left = max(w - target_w, 0)
    cropped = img[crop_top:crop_bottom, crop_left:][:, ::-1]

    pad_h = max(target_h - cropped.shape[0], 0)
    pad_w = max(target_w - cropped.shape[1], 0)
    return np.pad(
        cropped,
        ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
        mode="constant",
        constant_values=0,
    )


def _center_fit(arr, size, axis):
    """Center-crop or symmetrically zero-pad ``arr`` along ``axis`` to exactly ``size``.

    When cropping, the start index is ``(current - size) // 2``. When padding, the
    extra row goes to the end.
    """
    current = arr.shape[axis]
    if current >= size:
        start = (current - size) // 2
        index = [slice(None)] * arr.ndim
        index[axis] = slice(start, start + size)
        return arr[tuple(index)]
    pad_before = (size - current) // 2
    pad_width = [(0, 0)] * arr.ndim
    pad_width[axis] = (pad_before, size - current - pad_before)
    return np.pad(arr, pad_width, mode="constant")


def _prepare_slice(volume, index):
    """Return slice ``index`` of a volume as a (1, 1, target_h, target_w) float32 tensor on DEVICE."""
    # Flip horizontally to align with the training data orientation.
    plane = center_crop_pad(np.fliplr(volume[:, :, index]), target_h, target_w)
    plane = np.ascontiguousarray(plane, dtype=np.float32)
    return torch.from_numpy(plane)[None, None].to(DEVICE)


@spaces.GPU
def predict_volume(dce, sinwas, dwi):
    """Run the generator on every slice (third axis) of the three input volumes.

    Args:
        dce, sinwas, dwi: normalized 3-D arrays with the same number of slices along axis 2.

    Returns:
        float32 array of shape ``(Z, target_h, target_w)`` holding the raw model output
        (the first of the generator's four outputs) for each slice.
    """
    num_slices = dce.shape[2]
    output = np.zeros((num_slices, target_h, target_w), dtype=np.float32)
    with torch.no_grad():
        for i in range(num_slices):
            out, _, _, _ = model(
                _prepare_slice(dce, i), _prepare_slice(sinwas, i), _prepare_slice(dwi, i)
            )
            output[i] = out.squeeze().cpu().numpy()
    return output


def _file_path(file_obj):
    """Return the filesystem path for a ``gr.File`` value (a path string or an object with ``.name``)."""
    if isinstance(file_obj, (str, os.PathLike)):
        return file_obj
    return file_obj.name


def _load_nifti(file_obj):
    """Load a NIfTI upload once; return ``(image, float32 data array)``."""
    img = nib.load(_file_path(file_obj))
    return img, img.get_fdata().astype(np.float32)


def _check_volumes(volumes):
    """Raise ``gr.Error`` unless every volume is a non-empty 3-D array and all share one slice count."""
    for label, vol in volumes.items():
        if vol.ndim != 3 or 0 in vol.shape:
            raise gr.Error(f"{label} volume must be a non-empty 3-D volume, got shape {vol.shape}.")
    slice_counts = {label: vol.shape[2] for label, vol in volumes.items()}
    if len(set(slice_counts.values())) > 1:
        raise gr.Error(f"All volumes must have the same number of slices, got {slice_counts}.")


def _preview_figure(t1_slice, t2_slice):
    """Build the two-panel (T1 on top, synthesized T2 below) preview figure.

    A standalone ``Figure`` is used instead of ``pyplot`` so figures are not registered in
    pyplot's global state — they would otherwise accumulate across requests and that
    state is not thread-safe.
    """
    fig = Figure(figsize=(5, 5))
    axes = fig.subplots(2, 1)
    panels = ((t1_slice, "Original T1 Slice"), (t2_slice, "Synthesized T2 Slice"))
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(np.rot90(image, k=3), cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    return fig


@spaces.GPU
def run_synthesis(dce_file, sinwas_file, dwi_file):
    """Gradio callback: synthesize a T2 volume from the three uploaded NIfTI files.

    Args:
        dce_file, sinwas_file, dwi_file: values of the ``gr.File`` inputs — a path string or a
            file wrapper exposing ``.name``; ``None`` when nothing was uploaded.

    Returns:
        ``(figure, output_path, status_message)`` for the Plot, File and Textbox outputs.

    Raises:
        gr.Error: if an input is missing, or the volumes are not 3-D with equal slice counts.
    """
    start_time = time.perf_counter()

    if dce_file is None or sinwas_file is None or dwi_file is None:
        raise gr.Error("Please upload all three volumes (T1, Sinwas and DWI).")

    dce_img, dce = _load_nifti(dce_file)
    _, sinwas = _load_nifti(sinwas_file)
    _, dwi = _load_nifti(dwi_file)
    _check_volumes({"T1": dce, "Sinwas": sinwas, "DWI": dwi})

    original_h = dce.shape[0]  # height to restore after the network's fixed-size crop/pad
    dce, sinwas, dwi = normalize(dce), normalize(sinwas), normalize(dwi)

    # Map model output from [-1, 1] to [0, 1] (assumes a tanh-range generator output — TODO confirm).
    fake_volume = (predict_volume(dce, sinwas, dwi) + 1) / 2  # (Z, target_h, target_w)
    # Undo the row crop/pad so rows line up with the input again, then reorder to (H, W, Z).
    final_volume = np.transpose(_center_fit(fake_volume, original_h, axis=1), (1, 2, 0))

    # Save as NIfTI using the T1 affine/header as reference.
    output_nii = nib.Nifti1Image(final_volume, affine=dce_img.affine, header=dce_img.header)
    # A supplied header dictates the on-disk dtype; force float32 so the [0, 1] result is not
    # rescaled into the T1's (possibly integer) storage type.
    output_nii.set_data_dtype(np.float32)
    # The path is handed to the gr.File output, so the temp dir must outlive this call.
    filename = os.path.join(tempfile.mkdtemp(), "t2_synthesized.nii.gz")
    nib.save(output_nii, filename)

    # Preview: crop/pad the T1 slice exactly like the output so both panels share geometry.
    mid_slice = final_volume.shape[2] // 2
    t1_preview = _center_fit(
        center_crop_pad(np.fliplr(dce[:, :, mid_slice]), target_h, target_w),
        original_h,
        axis=0,
    )
    fig = _preview_figure(t1_preview, final_volume[:, :, mid_slice])

    elapsed_time = time.perf_counter() - start_time
    return fig, filename, f"Synthesis completed in {elapsed_time:.2f} seconds"


def clear_inputs():
    """Reset the three uploads, the preview plot, the download file and the status text."""
    return None, None, None, gr.update(value=None), gr.update(value=None), ""


example_list = [
    [
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/1_case1.nii.gz",
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/sinwas_case1.nii.gz",
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/0_case1.nii.gz",
    ],
    [
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/1_case2.nii.gz",
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/sinwas_case2.nii.gz",
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/0_case2.nii.gz",
    ],
    [
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/1_case3.nii.gz",
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/sinwas_case3.nii.gz",
        "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/0_case3.nii.gz",
    ],
]

custom_theme = gr.themes.Base().set(
    body_background_fill="#ffffff",
    button_primary_background_fill="#e0e7ff",
    button_primary_text_color="#4f46e5",
)

with gr.Blocks(title="Breast MRI T2 Synthesizer", theme=custom_theme) as interface:
    gr.Markdown(
        """
# 🧠 IMPORTANT-Net: Breast MRI T2 Synthesizer

✨ Upload your T1, Sinwas, and DWI volumes to generate a synthetic T2-weighted MRI volume using a deep learning model.✨

📧 Contact us: Dr. Tianyu Zhang (Tianyu.Zhang@radboudumc.nl), Dr. Ritse Mann (Ritse.Mann@radboudumc.nl)
"""
    )
    # NOTE(review): this HTML element contains only a newline — confirm whether markup was intended.
    gr.HTML("\n")

    with gr.Row():
        with gr.Column(scale=0.8):
            dce_input = gr.File(label="T1 (1.nii.gz)", height=150)
            sinwas_input = gr.File(label="Sinwas (sinwas.nii.gz)", height=150)
            dwi_input = gr.File(label="DWI (0.nii.gz)", height=150)
        with gr.Column(scale=1):
            gr.Markdown(
                "![Neural Activity](https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExanVjNG1lM3JlMnZyajFoMm5hcXh1dDlkZW83Ymx4bTh6emFrZmJ3cSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/eljCVpMrhepUSgZaVP/giphy.gif)"
            )
        with gr.Column(scale=1):
            image_output = gr.Plot(label="Middle Slice: Original vs Synthesized T2")
            file_output = gr.File(label="Download Full Synthesized T2 Volume")
            status_output = gr.Textbox(label="Status")

    # NOTE(review): this HTML element contains only a newline — confirm whether markup was intended.
    gr.HTML("\n")

    gr.Examples(
        examples=example_list,
        inputs=[dce_input, sinwas_input, dwi_input],
        label="Select an example to try",
        examples_per_page=2,
        cache_examples=False,
    )

    with gr.Row():
        run_button = gr.Button("Run Synthesis", variant="primary")
        clear_button = gr.Button("Clear All")

    run_button.click(
        fn=run_synthesis,
        inputs=[dce_input, sinwas_input, dwi_input],
        outputs=[image_output, file_output, status_output],
    )

    gr.Markdown("## 📊 Overview")
    gr.Markdown(
        """
---

### 🖼️ Flowchart Overview

![Network Diagram](https://huggingface.co/zhang0319/synthetic_t2/resolve/main/IMPORTANT-NET.jpg)
"""
    )

    clear_button.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[dce_input, sinwas_input, dwi_input, image_output, file_output, status_output],
    )

if __name__ == "__main__":
    interface.launch(share=True)