# Synthetic_Model / app.py — HuggingFace Space by zhang0319 (commit 30cf2d5)
import gradio as gr
import torch
import numpy as np
import nibabel as nib
import torch.nn.functional as F
import matplotlib.pyplot as plt
from generator import UnetGenerator
from huggingface_hub import hf_hub_download
from collections import OrderedDict
import tempfile
import os
import time
import spaces # for @spaces.GPU
# Device setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the pretrained T2 generator checkpoint from the Hub.
model_path = hf_hub_download(repo_id="zhang0319/synthetic_t2", filename="generator_t2_88.pth")
state_dict = torch.load(model_path, map_location="cuda")

# Checkpoints saved from nn.DataParallel prefix every key with "module." —
# strip that prefix so the weights load into a plain UnetGenerator.
new_state_dict = OrderedDict()
for key, tensor in state_dict.items():
    clean_key = key[len("module."):] if key.startswith("module.") else key
    new_state_dict[clean_key] = tensor

model = UnetGenerator()
model.load_state_dict(new_state_dict)
model.eval().to("cuda")

# Target size (H, W) for every 2-D slice fed to the network.
target_h, target_w = 384, 192
def normalize(img):
    """Linearly rescale an image to [-1, 1].

    A constant image (max == min) is returned unchanged to avoid a
    division by zero.
    """
    lo, hi = img.min(), img.max()
    if hi > lo:
        return 2 * (img - lo) / (hi - lo) - 1
    return img
def center_crop_pad(img, target_h=384, target_w=192):
    """Crop/pad a 2-D slice to (target_h, target_w).

    Rows are taken from a vertically centered window (clamped to the
    image). Columns are the rightmost ``target_w`` ones, mirrored
    left-right to keep the left side of the anatomy. If the window is
    smaller than the target, it is zero-padded symmetrically.
    """
    h, w = img.shape

    # Vertical window: centered, clamped so it never leaves the image.
    top = max(h // 2 - target_h // 2, 0)
    bottom = min(top + target_h, h)
    top = max(bottom - target_h, 0)

    # Horizontal window: keep the left side by taking the rightmost
    # target_w columns, then flip them.
    left = max(w - target_w, 0)
    window = img[top:bottom, left:w][:, ::-1]

    # Symmetric zero-padding up to the target size, if needed.
    short_h = max(target_h - window.shape[0], 0)
    short_w = max(target_w - window.shape[1], 0)
    pads = (
        (short_h // 2, short_h - short_h // 2),
        (short_w // 2, short_w - short_w // 2),
    )
    return np.pad(window, pads, mode='constant', constant_values=0)
@spaces.GPU
def predict_volume(dce, sinwas, dwi):
    """Run the generator slice-by-slice over three aligned (H, W, Z) volumes.

    Returns a (Z, target_h, target_w) float32 array of raw generator
    outputs (still in the model's [-1, 1] range).
    """
    n_slices = dce.shape[2]
    synthesized = np.zeros((n_slices, target_h, target_w), dtype=np.float32)

    def as_batch(slice_2d):
        # Mirror horizontally to match the training orientation, crop/pad
        # to the network size, then add batch and channel dimensions.
        prepared = center_crop_pad(np.fliplr(slice_2d))
        return torch.tensor(prepared, dtype=torch.float32)[None, None].to("cuda")

    for idx in range(n_slices):
        x1 = as_batch(dce[:, :, idx])
        x2 = as_batch(sinwas[:, :, idx])
        x3 = as_batch(dwi[:, :, idx])
        with torch.no_grad():
            fake, _, _, _ = model(x1, x2, x3)
        synthesized[idx] = fake.squeeze().cpu().numpy()
    return synthesized
@spaces.GPU
def run_synthesis(dce_file, sinwas_file, dwi_file):
    """End-to-end synthesis: load volumes, predict T2, save NIfTI, plot preview.

    Args:
        dce_file, sinwas_file, dwi_file: Gradio file objects (``.name`` is a
            local path to a NIfTI volume).

    Returns:
        (matplotlib figure with original vs synthesized middle slice,
         path to the saved .nii.gz volume,
         status string with elapsed time)
    """
    start_time = time.time()

    # Load the DCE image once and keep the object: its affine and header are
    # reused for the output NIfTI below (the original reloaded the file twice
    # more just to read them).
    dce_img = nib.load(dce_file.name)
    dce = dce_img.get_fdata().astype(np.float32)
    sinwas = nib.load(sinwas_file.name).get_fdata().astype(np.float32)
    dwi = nib.load(dwi_file.name).get_fdata().astype(np.float32)

    original_h = dce.shape[0]  # original height, restored before saving

    dce = normalize(dce)
    sinwas = normalize(sinwas)
    dwi = normalize(dwi)

    # Predict the full volume; generator output is in [-1, 1], map to [0, 1].
    fake_volume = predict_volume(dce, sinwas, dwi)
    fake_volume = (fake_volume + 1) / 2

    # Restore the original Y extent: center-crop when the prediction is
    # taller, otherwise zero-pad symmetrically.
    if fake_volume.shape[1] >= original_h:
        start_y = (fake_volume.shape[1] - original_h) // 2
        final_volume = fake_volume[:, start_y:start_y + original_h, :]
    else:
        pad_top = (original_h - fake_volume.shape[1]) // 2
        pad_bottom = original_h - fake_volume.shape[1] - pad_top
        final_volume = np.pad(fake_volume, ((0, 0), (pad_top, pad_bottom), (0, 0)), mode='constant')

    # (Z, H, W) -> (H, W, Z) to match the axis order of the input NIfTIs.
    final_volume = np.transpose(final_volume, (1, 2, 0))

    # Save with the DCE geometry as reference.
    output_nii = nib.Nifti1Image(final_volume, affine=dce_img.affine, header=dce_img.header)
    temp_dir = tempfile.mkdtemp()
    filename = os.path.join(temp_dir, "t2_synthesized.nii.gz")
    nib.save(output_nii, filename)

    # Preview: middle slice, original T1 next to the synthesized T2.
    mid_slice = final_volume.shape[2] // 2
    dce_slice = center_crop_pad(np.fliplr(dce[:, :, mid_slice]))
    if dce_slice.shape[0] > 352:
        # Trim the display slice to 352 rows so both panels line up.
        x_start = (dce_slice.shape[0] - 352) // 2
        dce_slice = dce_slice[x_start:x_start + 352, :]

    fig, ax = plt.subplots(2, 1, figsize=(5, 5))
    ax[0].imshow(np.rot90(dce_slice, k=3), cmap="gray")
    ax[0].set_title("Original T1 Slice")
    ax[0].axis("off")
    ax[1].imshow(np.rot90(final_volume[:, :, mid_slice], k=3), cmap="gray")
    ax[1].set_title("Synthesized T2 Slice")
    ax[1].axis("off")

    elapsed_time = time.time() - start_time
    return fig, filename, f"Synthesis completed in {elapsed_time:.2f} seconds"
# Demo triplets (T1, Sinwas, DWI) hosted alongside the model weights.
_EXAMPLE_BASE = "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples"
example_list = [
    [
        f"{_EXAMPLE_BASE}/1_case{case}.nii.gz",
        f"{_EXAMPLE_BASE}/sinwas_case{case}.nii.gz",
        f"{_EXAMPLE_BASE}/0_case{case}.nii.gz",
    ]
    for case in (1, 2, 3)
]
# Light theme: white background with indigo primary buttons.
custom_theme = gr.themes.Base().set(
    body_background_fill="#ffffff",
    button_primary_background_fill="#e0e7ff",
    button_primary_text_color="#4f46e5"
)
# Gradio UI definition. Built at import time; `interface` is launched from
# the __main__ guard at the bottom of the file.
with gr.Blocks(title="Breast MRI T2 Synthesizer", theme=custom_theme) as interface:
    # Title / contact header.
    gr.Markdown("""
# 🧠 IMPORTANT-Net: Breast MRI T2 Synthesizer
✨ Upload your T1, Sinwas, and DWI volumes to generate a synthetic T2-weighted MRI volume using a deep learning model.✨
📧 Contact us: Dr. Tianyu Zhang (Tianyu.Zhang@radboudumc.nl), Dr. Ritse Mann (Ritse.Mann@radboudumc.nl)
""")
    # Decorative gradient divider.
    gr.HTML("""
<div style="
height: 6px;
background: linear-gradient(to right, #6366f1, #a78bfa);
border-radius: 3px;
margin-top: 10px;
margin-bottom: 20px;
"></div>
""")
    with gr.Row():
        # Left column: the three input NIfTI volumes.
        with gr.Column(scale=0.8):
            dce_input = gr.File(label="T1 (1.nii.gz)", height=150)
            sinwas_input = gr.File(label="Sinwas (sinwas.nii.gz)", height=150)
            dwi_input = gr.File(label="DWI (0.nii.gz)", height=150)
        # with gr.Column(scale=1):
        #     gr.Image(
        #         value="https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExanVjNG1lM3JlMnZyajFoMm5hcXh1dDlkZW83Ymx4bTh6emFrZmJ3cSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/eljCVpMrhepUSgZaVP/giphy.gif",
        #         label="Neural Activity",
        #         show_label=False,
        #         interactive=False
        #     )
        # Middle column: decorative animation (embedded via Markdown).
        with gr.Column(scale=1):
            gr.Markdown("""
![Neural Activity](https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExanVjNG1lM3JlMnZyajFoMm5hcXh1dDlkZW83Ymx4bTh6emFrZmJ3cSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/eljCVpMrhepUSgZaVP/giphy.gif)
""")
        # Right column: outputs — preview plot, downloadable volume, status.
        with gr.Column(scale=1):
            image_output = gr.Plot(label="Middle Slice: Original vs Synthesized T2")
            file_output = gr.File(label="Download Full Synthesized T2 Volume")
            status_output = gr.Textbox(label="Status")
    # Second gradient divider.
    gr.HTML("""
<div style="
height: 6px;
background: linear-gradient(to right, #6366f1, #a78bfa);
border-radius: 3px;
margin-top: 10px;
margin-bottom: 20px;
"></div>
""")
    # Pre-baked demo cases; file URLs are fetched by Gradio when selected.
    gr.Examples(
        examples=example_list,
        inputs=[dce_input, sinwas_input, dwi_input],
        label="Select an example to try",
        examples_per_page=2,
        cache_examples=False
    )
    with gr.Row():
        run_button = gr.Button("Run Synthesis", variant="primary")
        clear_button = gr.Button("Clear All")
    # gr.Examples(
    #     examples=example_list,
    #     inputs=[dce_input, sinwas_input, dwi_input],
    #     label="Try with example files"
    # )
    # run_button = gr.Button("Run Synthesis")
    # Wire the synthesis pipeline to the run button.
    run_button.click(
        fn=run_synthesis,
        inputs=[dce_input, sinwas_input, dwi_input],
        outputs=[image_output, file_output, status_output]
    )
    gr.Markdown("## 📊 Overview")
    gr.Markdown("""
---
### 🖼️ Flowchart Overview
![Network Diagram](https://huggingface.co/zhang0319/synthetic_t2/resolve/main/IMPORTANT-NET.jpg)
""")

    def clear_inputs():
        # Reset the three file inputs plus plot, download, and status widgets.
        return None, None, None, gr.update(value=None), gr.update(value=None), ""

    clear_button.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[dce_input, sinwas_input, dwi_input, image_output, file_output, status_output]
    )
if __name__ == "__main__":
    # share=True requests a public Gradio link when run outside HF Spaces.
    interface.launch(share=True)