import gradio as gr
import torch
import numpy as np
import nibabel as nib
import torch.nn.functional as F
import matplotlib.pyplot as plt
from generator import UnetGenerator
from huggingface_hub import hf_hub_download
from collections import OrderedDict
import tempfile
import os
import time
import spaces # for @spaces.GPU
# --- Model setup ----------------------------------------------------------
# NOTE(review): the checkpoint is mapped straight onto "cuda" at import
# time; on a CPU-only machine this load will fail — confirm the Space
# always runs with a GPU attached (the @spaces.GPU decorators suggest so).
model_path = hf_hub_download(repo_id="zhang0319/synthetic_t2", filename="generator_t2_88.pth")
state_dict = torch.load(model_path, map_location="cuda")

# Checkpoints saved from nn.DataParallel prefix every key with "module.";
# strip that prefix so the keys match a plain UnetGenerator.
new_state_dict = OrderedDict(
    (key.replace("module.", "") if key.startswith("module.") else key, weight)
    for key, weight in state_dict.items()
)

model = UnetGenerator()
model.load_state_dict(new_state_dict)
model.eval().to("cuda")

# Target (H, W) every axial slice is cropped/padded to before inference.
target_h, target_w = 384, 192
def normalize(img):
    """Linearly rescale *img* into the range [-1, 1].

    A constant image (max == min) is returned unchanged, since there is
    no intensity range to normalize by.
    """
    lo, hi = img.min(), img.max()
    if hi > lo:
        return 2 * (img - lo) / (hi - lo) - 1
    return img
def center_crop_pad(img, target_h=384, target_w=192):
    """Crop/pad a 2-D slice to exactly (target_h, target_w).

    Vertically the crop window is centred on the image. Horizontally the
    rightmost ``target_w`` columns are kept and then mirrored left-right
    (the original comment said "keep the left side on X" — presumably the
    mirroring makes both descriptions agree visually; verify against the
    training preprocessing). Any shortfall is zero-padded symmetrically.
    """
    h, w = img.shape

    # Vertical: centred window of target_h rows, clamped to the image.
    top = max(h // 2 - target_h // 2, 0)
    bottom = min(top + target_h, h)
    top = bottom - target_h if bottom - target_h >= 0 else 0

    # Horizontal: take the right edge of the image, then flip the columns.
    left = max(w - target_w, 0)
    window = img[top:bottom, left:w][:, ::-1]

    # Symmetric zero padding up to the target size.
    short_h = max(target_h - window.shape[0], 0)
    short_w = max(target_w - window.shape[1], 0)
    pads = (
        (short_h // 2, short_h - short_h // 2),
        (short_w // 2, short_w - short_w // 2),
    )
    return np.pad(window, pads, mode='constant', constant_values=0)
@spaces.GPU
def predict_volume(dce, sinwas, dwi):
    """Run the generator slice-by-slice over three (H, W, Z) volumes.

    Each axial slice is flipped left-right (to match the training data
    orientation), cropped/padded to (target_h, target_w), and fed to the
    model. Returns a (Z, target_h, target_w) float32 array of predictions.
    """
    n_slices = dce.shape[2]
    output = np.zeros((n_slices, target_h, target_w), dtype=np.float32)

    for idx in range(n_slices):
        # Preprocess the three modalities identically: flip, crop/pad,
        # then promote to a (1, 1, H, W) CUDA tensor.
        inputs = []
        for volume in (dce, sinwas, dwi):
            prepared = center_crop_pad(np.fliplr(volume[:, :, idx]))
            inputs.append(
                torch.tensor(prepared, dtype=torch.float32)[None, None].to("cuda")
            )
        with torch.no_grad():
            pred, _, _, _ = model(*inputs)
        output[idx] = pred.squeeze().cpu().numpy()
    return output
@spaces.GPU
def run_synthesis(dce_file, sinwas_file, dwi_file):
    """Full pipeline: load the three NIfTI inputs, synthesize a T2 volume,
    save it to a temp directory, and build a before/after preview figure.

    Args:
        dce_file, sinwas_file, dwi_file: uploaded file objects (``.name``
            is the path on disk) for the T1, Sinwas and DWI volumes.

    Returns:
        (matplotlib figure, path to the saved .nii.gz, status string).
    """
    start_time = time.time()

    # Load each file exactly once; the original code re-loaded the DCE
    # file three times (data, affine, header) — reuse one image object.
    dce_img = nib.load(dce_file.name)
    dce = dce_img.get_fdata().astype(np.float32)
    sinwas = nib.load(sinwas_file.name).get_fdata().astype(np.float32)
    dwi = nib.load(dwi_file.name).get_fdata().astype(np.float32)

    original_h = dce.shape[0]  # original height, taken from the input volume

    dce = normalize(dce)
    sinwas = normalize(sinwas)
    dwi = normalize(dwi)

    # Predict the full volume; the generator outputs [-1, 1] (normalize()
    # maps inputs there), so rescale to [0, 1].
    fake_volume = predict_volume(dce, sinwas, dwi)
    fake_volume = (fake_volume + 1) / 2

    # Crop (or zero-pad) the Y axis back to the original input height.
    if fake_volume.shape[1] >= original_h:
        start_y = (fake_volume.shape[1] - original_h) // 2
        final_volume = fake_volume[:, start_y:start_y + original_h, :]
    else:
        pad_top = (original_h - fake_volume.shape[1]) // 2
        pad_bottom = original_h - fake_volume.shape[1] - pad_top
        final_volume = np.pad(fake_volume, ((0, 0), (pad_top, pad_bottom), (0, 0)), mode='constant')

    # (Z, H, W) -> (H, W, Z) to match the input volumes' axis order.
    final_volume = np.transpose(final_volume, (1, 2, 0))

    # Save as NIfTI, reusing the DCE affine/header as the spatial reference.
    output_nii = nib.Nifti1Image(final_volume, affine=dce_img.affine, header=dce_img.header)
    temp_dir = tempfile.mkdtemp()
    filename = os.path.join(temp_dir, "t2_synthesized.nii.gz")
    nib.save(output_nii, filename)

    # Preview: middle slice, preprocessed T1 vs synthesized T2.
    mid_slice = final_volume.shape[2] // 2
    dce_slice = center_crop_pad(np.fliplr(dce[:, :, mid_slice]))
    # NOTE(review): 352 is hard-coded here while original_h is computed
    # above — presumably the training volumes were 352 rows tall; confirm.
    if dce_slice.shape[0] > 352:
        x_start = (dce_slice.shape[0] - 352) // 2
        dce_slice = dce_slice[x_start:x_start + 352, :]

    fig, ax = plt.subplots(2, 1, figsize=(5, 5))
    ax[0].imshow(np.rot90(dce_slice, k=3), cmap="gray")
    ax[0].set_title("Original T1 Slice")
    ax[0].axis("off")
    ax[1].imshow(np.rot90(final_volume[:, :, mid_slice], k=3), cmap="gray")
    ax[1].set_title("Synthesized T2 Slice")
    ax[1].axis("off")

    elapsed_time = time.time() - start_time
    return fig, filename, f"Synthesis completed in {elapsed_time:.2f} seconds"
# Hugging Face-hosted example volumes: one row per case, ordered
# [T1 ("1"), Sinwas ("sinwas"), DWI ("0")] to match the input widgets.
_EXAMPLE_BASE = "https://huggingface.co/zhang0319/synthetic_t2/resolve/main/examples"
example_list = [
    [f"{_EXAMPLE_BASE}/{prefix}_case{case}.nii.gz" for prefix in ("1", "sinwas", "0")]
    for case in (1, 2, 3)
]
# Light theme: white page background with indigo primary buttons.
_THEME_OVERRIDES = {
    "body_background_fill": "#ffffff",
    "button_primary_background_fill": "#e0e7ff",
    "button_primary_text_color": "#4f46e5",
}
custom_theme = gr.themes.Base().set(**_THEME_OVERRIDES)
# ---------------------------------------------------------------------------
# Gradio UI: three file inputs on the left, preview + download on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Breast MRI T2 Synthesizer", theme=custom_theme) as interface:
    gr.Markdown("""
# 🧠 IMPORTANT-Net: Breast MRI T2 Synthesizer
✨ Upload your T1, Sinwas, and DWI volumes to generate a synthetic T2-weighted MRI volume using a deep learning model.✨
📧 Contact us: Dr. Tianyu Zhang (Tianyu.Zhang@radboudumc.nl), Dr. Ritse Mann (Ritse.Mann@radboudumc.nl)
""")

    # Decorative gradient divider; the identical markup is reused below
    # the examples section, so build it once.
    divider_html = """
<div style="
height: 6px;
background: linear-gradient(to right, #6366f1, #a78bfa);
border-radius: 3px;
margin-top: 10px;
margin-bottom: 20px;
"></div>
"""
    gr.HTML(divider_html)

    with gr.Row():
        # Left column: the three input volumes.
        with gr.Column(scale=0.8):
            dce_input = gr.File(label="T1 (1.nii.gz)", height=150)
            sinwas_input = gr.File(label="Sinwas (sinwas.nii.gz)", height=150)
            dwi_input = gr.File(label="DWI (0.nii.gz)", height=150)
        # Middle column: animated illustration.
        with gr.Column(scale=1):
            gr.Markdown("""
![Neural Activity](https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExanVjNG1lM3JlMnZyajFoMm5hcXh1dDlkZW83Ymx4bTh6emFrZmJ3cSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/eljCVpMrhepUSgZaVP/giphy.gif)
""")
        # Right column: results (preview figure, download link, status).
        with gr.Column(scale=1):
            image_output = gr.Plot(label="Middle Slice: Original vs Synthesized T2")
            file_output = gr.File(label="Download Full Synthesized T2 Volume")
            status_output = gr.Textbox(label="Status")

    gr.HTML(divider_html)

    gr.Examples(
        examples=example_list,
        inputs=[dce_input, sinwas_input, dwi_input],
        label="Select an example to try",
        examples_per_page=2,
        cache_examples=False,
    )

    with gr.Row():
        run_button = gr.Button("Run Synthesis", variant="primary")
        clear_button = gr.Button("Clear All")

    run_button.click(
        fn=run_synthesis,
        inputs=[dce_input, sinwas_input, dwi_input],
        outputs=[image_output, file_output, status_output],
    )

    gr.Markdown("## 📊 Overview")
    gr.Markdown("""
---
### 🖼️ Flowchart Overview
![Flowchart](https://huggingface.co/zhang0319/synthetic_t2/resolve/main/flowchart2.png)
""")

    def clear_inputs():
        """Reset every input and output widget to its empty state."""
        return None, None, None, gr.update(value=None), gr.update(value=None), ""

    clear_button.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[dce_input, sinwas_input, dwi_input, image_output, file_output, status_output],
    )
if __name__ == "__main__":
    # share=True additionally exposes a public Gradio link.
    interface.launch(share=True)
|