Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import nibabel as nib | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
| from generator import UnetGenerator | |
| from huggingface_hub import hf_hub_download | |
| from collections import OrderedDict | |
| import tempfile | |
| import os | |
| import time | |
| import spaces # for @spaces.GPU | |
# Load the pretrained T2 generator from the Hugging Face Hub and move it to the GPU.
# NOTE(review): the device is hard-coded to "cuda" (a CPU fallback was left
# commented out); with `spaces` imported this presumably targets ZeroGPU — confirm.
model_path = hf_hub_download(repo_id="zhang0319/synthetic_t2", filename="generator_t2_88.pth")
# NOTE(review): torch.load unpickles the checkpoint. It comes from a fixed Hub repo,
# but if the file is a plain state dict, `weights_only=True` (torch>=1.13) is safer.
state_dict = torch.load(model_path, map_location="cuda")

# Keys saved from a wrapped model (e.g. nn.DataParallel) start with "module.".
# Strip only that leading prefix: str.replace would also remove inner
# occurrences, turning e.g. "down.submodule.weight" into "down.subweight".
_MODULE_PREFIX = "module."
new_state_dict = OrderedDict(
    (key[len(_MODULE_PREFIX):] if key.startswith(_MODULE_PREFIX) else key, value)
    for key, value in state_dict.items()
)
model = UnetGenerator()
model.load_state_dict(new_state_dict)
model.eval().to("cuda")
# Spatial size of every model input/output slice: rows (H) and columns (W).
# predict_volume allocates its output with these dimensions.
target_h = 384
target_w = 192
def normalize(img):
    """Min-max scale ``img`` to the range [-1, 1].

    The minimum maps to -1 and the maximum to +1. A constant input has no
    dynamic range to stretch, so it is mapped to all zeros (the centre of the
    range) rather than being returned unscaled, which could leave values far
    outside [-1, 1]. If the min/max comparison is undefined (NaN present),
    the input is returned unchanged.

    Parameters
    ----------
    img : np.ndarray
        Array of any shape.

    Returns
    -------
    np.ndarray
        Array of the same shape with values in [-1, 1].
    """
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        return 2 * (img - min_val) / (max_val - min_val) - 1
    if max_val == min_val:
        return np.zeros_like(img)
    # NaN min/max: the range is undefined, leave the data untouched.
    return img
def center_crop_pad(img, target_h=384, target_w=192):
    """Crop and/or zero-pad a 2-D slice to ``(target_h, target_w)``.

    Rows: a window of ``target_h`` rows centred on the middle row (clamped to
    the image), zero-padded symmetrically if the slice is shorter.
    Columns: the last ``target_w`` columns are kept in reversed order, then
    zero-padded symmetrically if the slice is narrower.

    Callers pass a horizontally flipped slice, so the kept columns are the
    first ``target_w`` columns of the unflipped slice, in their original order.
    """
    rows, _ = img.shape

    # Vertical window around the middle row, shifted back inside the image.
    top = max(rows // 2 - target_h // 2, 0)
    bottom = min(top + target_h, rows)
    top = max(bottom - target_h, 0)

    # Reverse the columns, then take the first target_w of them
    # (== the last target_w columns of the input, mirrored).
    window = img[top:bottom, ::-1][:, :target_w]

    missing_h = max(target_h - window.shape[0], 0)
    missing_w = max(target_w - window.shape[1], 0)
    pad_spec = (
        (missing_h // 2, missing_h - missing_h // 2),
        (missing_w // 2, missing_w - missing_w // 2),
    )
    return np.pad(window, pad_spec, mode='constant', constant_values=0)
@spaces.GPU  # request a GPU for this call on ZeroGPU hardware; no effect elsewhere
def predict_volume(dce, sinwas, dwi):
    """Run the global generator ``model`` slice-by-slice over three volumes.

    Parameters
    ----------
    dce, sinwas, dwi : np.ndarray
        Normalized volumes. 2-D slices are taken along the last axis
        (``vol[:, :, i]``), so all three must have the same number of slices.

    Returns
    -------
    np.ndarray
        float32 array of shape ``(Z, target_h, target_w)`` holding the first
        of the model's four outputs for each slice.

    Raises
    ------
    ValueError
        If the volumes do not have the same number of slices.
    """
    if not dce.shape[2] == sinwas.shape[2] == dwi.shape[2]:
        raise ValueError(
            "T1, Sinwas and DWI volumes must have the same number of slices; "
            f"got {dce.shape[2]}, {sinwas.shape[2]} and {dwi.shape[2]}."
        )
    z = dce.shape[2]
    # Run on whichever device the model's weights live on instead of hard-coding "cuda".
    device = next(model.parameters()).device
    output = np.zeros((z, target_h, target_w), dtype=np.float32)

    def to_model_input(volume, i):
        # Flip horizontally to align with the training data, crop/pad to the
        # model input size, then add batch and channel dimensions.
        cropped = center_crop_pad(np.fliplr(volume[:, :, i]))
        return torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    # A single no_grad context for the whole loop rather than one per slice.
    with torch.no_grad():
        for i in range(z):
            out, _, _, _ = model(
                to_model_input(dce, i),
                to_model_input(sinwas, i),
                to_model_input(dwi, i),
            )
            output[i] = out.squeeze().cpu().numpy()
    return output
def _upload_path(upload):
    """Return the local path of a ``gr.File`` value.

    Uses ``.name`` when the value has one (file-object style uploads) and
    otherwise treats the value itself as the path string.
    """
    return getattr(upload, "name", upload)


def _fit_height(arr, height):
    """Center-crop or symmetrically zero-pad the second-to-last axis of ``arr`` to ``height``.

    The centring arithmetic matches the row handling in ``center_crop_pad``,
    so applying this to a ``center_crop_pad`` result with the slice's original
    height restores that height with rows aligned to the input.
    """
    current = arr.shape[-2]
    if current >= height:
        start = (current - height) // 2
        return arr[..., start:start + height, :]
    pad_top = (height - current) // 2
    pad_width = [(0, 0)] * arr.ndim
    pad_width[-2] = (pad_top, height - current - pad_top)
    return np.pad(arr, pad_width, mode='constant')


def run_synthesis(dce_file, sinwas_file, dwi_file):
    """Synthesize a T2 volume from uploaded T1, Sinwas and DWI NIfTI files.

    Parameters
    ----------
    dce_file, sinwas_file, dwi_file
        Values of the three ``gr.File`` inputs (T1, Sinwas, DWI).

    Returns
    -------
    tuple
        ``(figure, nifti_path, status_message)``: a matplotlib figure comparing
        the middle T1 slice with the synthesized slice, the path of the saved
        ``t2_synthesized.nii.gz`` volume, and a timing message.

    Raises
    ------
    gr.Error
        If any of the three inputs is missing.
    """
    start_time = time.perf_counter()
    if dce_file is None or sinwas_file is None or dwi_file is None:
        raise gr.Error("Please upload the T1, Sinwas and DWI volumes first.")

    # Load the T1 image once; its affine/header are reused for the output volume.
    dce_img = nib.load(_upload_path(dce_file))
    dce = dce_img.get_fdata().astype(np.float32)
    sinwas = nib.load(_upload_path(sinwas_file)).get_fdata().astype(np.float32)
    dwi = nib.load(_upload_path(dwi_file)).get_fdata().astype(np.float32)
    original_h = dce.shape[0]  # height to restore after the fixed-size model crop

    dce = normalize(dce)
    sinwas = normalize(sinwas)
    dwi = normalize(dwi)

    # (Z, target_h, target_w); (x + 1) / 2 maps the [-1, 1] range used for the
    # inputs to [0, 1].
    fake_volume = (predict_volume(dce, sinwas, dwi) + 1) / 2
    # Restore the original height, then reorder (Z, H, W) -> (H, W, Z) like the inputs.
    final_volume = np.transpose(_fit_height(fake_volume, original_h), (1, 2, 0))

    # Save as NIfTI using the T1 geometry. The T1 header may declare an integer
    # on-disk dtype; store the [0, 1] intensities as float32 instead.
    output_nii = nib.Nifti1Image(final_volume, affine=dce_img.affine, header=dce_img.header)
    output_nii.set_data_dtype(np.float32)
    filename = os.path.join(tempfile.mkdtemp(), "t2_synthesized.nii.gz")
    nib.save(output_nii, filename)

    # Preview the middle slice. The T1 slice gets exactly the same crop/pad as
    # the synthesized one, so both images have original_h rows (a fixed
    # 352-row crop only matched 352-row inputs).
    mid_slice = final_volume.shape[2] // 2
    dce_slice = _fit_height(center_crop_pad(np.fliplr(dce[:, :, mid_slice])), original_h)
    fig, ax = plt.subplots(2, 1, figsize=(5, 5))
    ax[0].imshow(np.rot90(dce_slice, k=3), cmap="gray")
    ax[0].set_title("Original T1 Slice")
    ax[0].axis("off")
    ax[1].imshow(np.rot90(final_volume[:, :, mid_slice], k=3), cmap="gray")
    ax[1].set_title("Synthesized T2 Slice")
    ax[1].axis("off")
    # Release pyplot's reference so figures don't pile up across requests;
    # the Figure object itself remains usable for Gradio to render.
    plt.close(fig)

    elapsed_time = time.perf_counter() - start_time
    return fig, filename, f"Synthesis completed in {elapsed_time:.2f} seconds"
# Example cases for gr.Examples: one row per case, ordered [T1, Sinwas, DWI]
# to match the order of the three upload widgets.
example_list = [
    [
        f"https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples/{modality}_case{case}.nii.gz"
        for modality in ("1", "sinwas", "0")
    ]
    for case in (1, 2, 3)
]
# Light theme on top of Gradio's Base theme: indigo text on a pale-indigo
# primary button, white page background.
custom_theme = gr.themes.Base().set(
    button_primary_text_color="#4f46e5",
    button_primary_background_fill="#e0e7ff",
    body_background_fill="#ffffff",
)
# Thin gradient bar used to separate sections of the page.
_DIVIDER_HTML = """
<div style="
height: 6px;
background: linear-gradient(to right, #6366f1, #a78bfa);
border-radius: 3px;
margin-top: 10px;
margin-bottom: 20px;
"></div>
"""

with gr.Blocks(title="Breast MRI T2 Synthesizer", theme=custom_theme) as interface:
    gr.Markdown("""
# 🧠 IMPORTANT-Net: Breast MRI T2 Synthesizer
✨ Upload your T1, Sinwas, and DWI volumes to generate a synthetic T2-weighted MRI volume using a deep learning model.✨
📧 Contact us: Dr. Tianyu Zhang (Tianyu.Zhang@radboudumc.nl), Dr. Ritse Mann (Ritse.Mann@radboudumc.nl)
""")
    gr.HTML(_DIVIDER_HTML)

    with gr.Row():
        # Gradio's Column `scale` is an integer; 4:5:5 keeps the intended
        # 0.8:1:1 width ratio without the non-integer-scale warning.
        with gr.Column(scale=4):
            dce_input = gr.File(label="T1 (1.nii.gz)", height=150)
            sinwas_input = gr.File(label="Sinwas (sinwas.nii.gz)", height=150)
            dwi_input = gr.File(label="DWI (0.nii.gz)", height=150)
        with gr.Column(scale=5):
            # NOTE(review): this Markdown is blank in the captured source; it
            # most likely held an image whose markup was lost — restore it.
            gr.Markdown("""

""")
        with gr.Column(scale=5):
            image_output = gr.Plot(label="Middle Slice: Original vs Synthesized T2")
            file_output = gr.File(label="Download Full Synthesized T2 Volume")
            status_output = gr.Textbox(label="Status")

    gr.HTML(_DIVIDER_HTML)
    gr.Examples(
        examples=example_list,
        inputs=[dce_input, sinwas_input, dwi_input],
        label="Select an example to try",
        examples_per_page=2,
        cache_examples=False,
    )

    with gr.Row():
        run_button = gr.Button("Run Synthesis", variant="primary")
        clear_button = gr.Button("Clear All")

    run_button.click(
        fn=run_synthesis,
        inputs=[dce_input, sinwas_input, dwi_input],
        outputs=[image_output, file_output, status_output],
    )

    gr.Markdown("## 📊 Overview")
    # NOTE(review): the flowchart image markup below also appears to be lost.
    gr.Markdown("""
---
### 🖼️ Flowchart Overview

""")

    def clear_inputs():
        """Reset the three uploads, the preview plot, the download file and the status box."""
        return None, None, None, gr.update(value=None), gr.update(value=None), ""

    clear_button.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[dce_input, sinwas_input, dwi_input, image_output, file_output, status_output],
    )
if __name__ == "__main__":
    # Gradio share links are not supported on Hugging Face Spaces (which sets
    # the SPACE_ID environment variable); only request one when running elsewhere.
    interface.launch(share="SPACE_ID" not in os.environ)