Haojiacheng commited on
Commit
a773c9a
·
1 Parent(s): 42c51f6

Upload 2 files

Browse files
Files changed (2) hide show
  1. demo.ipynb +461 -0
  2. demo.py +178 -0
demo.ipynb ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# video 导入"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 10,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stderr",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
20
+ "`optional` parameter is deprecated, and it has no effect\n",
21
+ "`keep_filename` parameter is deprecated, and it has no effect\n",
22
+ "The `allow_flagging` parameter in `Interface` nowtakes a string value ('auto', 'manual', or 'never'), not a boolean. Setting parameter to: 'never'.\n"
23
+ ]
24
+ },
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "Running on local URL: http://127.0.0.1:7865\n",
30
+ "\n",
31
+ "To create a public link, set `share=True` in `launch()`.\n"
32
+ ]
33
+ },
34
+ {
35
+ "data": {
36
+ "text/html": [
37
+ "<div><iframe src=\"http://127.0.0.1:7865/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
38
+ ],
39
+ "text/plain": [
40
+ "<IPython.core.display.HTML object>"
41
+ ]
42
+ },
43
+ "metadata": {},
44
+ "output_type": "display_data"
45
+ },
46
+ {
47
+ "data": {
48
+ "text/plain": []
49
+ },
50
+ "execution_count": 10,
51
+ "metadata": {},
52
+ "output_type": "execute_result"
53
+ },
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "C:\\WINDOWS\\TEMP\\gradio\\6da74a6a81402070d14fdaec056ed2fd2ef5f186\\62691117.nii.gz\n"
59
+ ]
60
+ },
61
+ {
62
+ "name": "stderr",
63
+ "output_type": "stream",
64
+ "text": [
65
+ "Traceback (most recent call last):\n",
66
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\gradio\\routes.py\", line 439, in run_predict\n",
67
+ " output = await app.get_blocks().process_api(\n",
68
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\gradio\\blocks.py\", line 1384, in process_api\n",
69
+ " result = await self.call_function(\n",
70
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\gradio\\blocks.py\", line 1089, in call_function\n",
71
+ " prediction = await anyio.to_thread.run_sync(\n",
72
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\anyio\\to_thread.py\", line 33, in run_sync\n",
73
+ " return await get_asynclib().run_sync_in_worker_thread(\n",
74
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 877, in run_sync_in_worker_thread\n",
75
+ " return await future\n",
76
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 807, in run\n",
77
+ " result = context.run(func, *args)\n",
78
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\gradio\\utils.py\", line 700, in wrapper\n",
79
+ " response = f(*args, **kwargs)\n",
80
+ " File \"C:\\Windows\\Temp\\ipykernel_17604\\3933752410.py\", line 12, in process_nii_file\n",
81
+ " model = UNETR(\n",
82
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\monai\\networks\\nets\\unetr.py\", line 93, in __init__\n",
83
+ " self.vit = ViT(\n",
84
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\monai\\networks\\nets\\vit.py\", line 93, in __init__\n",
85
+ " self.patch_embedding = PatchEmbeddingBlock(\n",
86
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\monai\\networks\\blocks\\patchembedding.py\", line 99, in __init__\n",
87
+ " Rearrange(f\"{from_chars} -> {to_chars}\", **axes_len), nn.Linear(self.patch_dim, hidden_size)\n",
88
+ " File \"d:\\anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\linear.py\", line 96, in __init__\n",
89
+ " self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))\n",
90
+ "RuntimeError: [enforce fail at C:\\cb\\pytorch_1000000000000\\work\\c10\\core\\impl\\alloc_cpu.cpp:81] data. DefaultCPUAllocator: not enough memory: you tried to allocate 12582912 bytes.\n"
91
+ ]
92
+ }
93
+ ],
94
+ "source": [
95
+ "import nibabel as nib\n",
96
+ "import numpy as np\n",
97
+ "import matplotlib.pyplot as plt\n",
98
+ "import gradio as gr\n",
99
+ "from pathlib import Path\n",
100
+ "import torch\n",
101
+ "from monai.networks.nets import UNETR\n",
102
+ "import pytorch_lightning as pl\n",
103
+ "import tempfile\n",
104
+ "import base64\n",
105
+ "from celluloid import Camera\n",
106
+ "from IPython.display import HTML \n",
107
+ "\n",
108
+ "def load_nifti(sample_path):\n",
109
+ " print(sample_path)\n",
110
+ " data = nib.load(sample_path).get_fdata()\n",
111
+ " data = np.rot90(data, 3)\n",
112
+ " return data\n",
113
+ "\n",
114
+ "def generate_animation(mri):\n",
115
+ " fig = plt.figure()\n",
116
+ " plt.axis('off')\n",
117
+ " camera = Camera(fig) # Create the camera object from celluloid\n",
118
+ "\n",
119
+ " for i in range(mri.shape[2]): # Sagital view\n",
120
+ " plt.imshow(mri[:,:,i], cmap=\"bone\")\n",
121
+ " camera.snap() # Store the current slice\n",
122
+ " \n",
123
+ " animation = camera.animate(interval=200)\n",
124
+ "\n",
125
+ " # Save the animation as a GIF file\n",
126
+ " with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as temp_file:\n",
127
+ " temp_filename = temp_file.name\n",
128
+ " animation.save(temp_filename, writer='pillow', fps=3)\n",
129
+ "\n",
130
+ " return temp_filename\n",
131
+ "\n",
132
+ "def predict_nifti(file):\n",
133
+ " file_path = file.name\n",
134
+ " # Load and process NIfTI file\n",
135
+ " mri = load_nifti(file_path)\n",
136
+ " \n",
137
+ " # Generate animation and get the temporary file path\n",
138
+ " animation_path = generate_animation(mri)\n",
139
+ "\n",
140
+ " # Read the GIF file as bytes\n",
141
+ " with open(animation_path, 'rb') as file:\n",
142
+ " animation_bytes = file.read()\n",
143
+ "\n",
144
+ " # Convert the bytes to base64 string\n",
145
+ " animation_base64 = base64.b64encode(animation_bytes).decode('utf-8')\n",
146
+ "\n",
147
+ " # Generate the HTML code to display the animation\n",
148
+ " html_code = f'<img src=\"data:image/gif;base64,{animation_base64}\" alt=\"animation\">'\n",
149
+ "\n",
150
+ " # Return the HTML code\n",
151
+ " return html_code\n",
152
+ "\n",
153
+ "examples = [[r\"F:\\sth\\23Fall\\fcpro\\brain_image2\\imageTs\\60071979.nii.gz\"]]\n",
154
+ "# 创建 Gradio 用户界面\n",
155
+ "iface = gr.Interface(\n",
156
+ " fn=predict_nifti,\n",
157
+ " inputs=gr.inputs.File(label=\"上传MRI文件\", type=\"file\"),\n",
158
+ " outputs=\"html\",\n",
159
+ " title=\"NKU \",\n",
160
+ " description=\"南开大学智齿辅助诊断系统\",\n",
161
+ " allow_flagging=False,\n",
162
+ " examples=examples\n",
163
+ ")\n",
164
+ "\n",
165
+ "iface.launch(share=False)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "metadata": {},
171
+ "source": [
172
+ "# 显示切片"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 1,
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "import os\n",
182
+ "import shutil\n",
183
+ "import tempfile\n",
184
+ "\n",
185
+ "import matplotlib.pyplot as plt\n",
186
+ "from tqdm import tqdm\n",
187
+ "\n",
188
+ "from monai.losses import DiceCELoss\n",
189
+ "from monai.inferers import sliding_window_inference\n",
190
+ "from monai.transforms import (\n",
191
+ " AsDiscrete,\n",
192
+ " EnsureChannelFirstd,\n",
193
+ " Compose,\n",
194
+ " CropForegroundd,\n",
195
+ " LoadImaged,\n",
196
+ " Orientationd,\n",
197
+ " RandFlipd,\n",
198
+ " RandCropByPosNegLabeld,\n",
199
+ " RandShiftIntensityd,\n",
200
+ " ScaleIntensityRanged,\n",
201
+ " Spacingd,\n",
202
+ " SpatialPadd,\n",
203
+ " RandRotate90d,\n",
204
+ " CenterSpatialCropd,\n",
205
+ " ResizeWithPadOrCropd,\n",
206
+ " Flipd,\n",
207
+ " Rotate90d,\n",
208
+ " RandAffined,\n",
209
+ " RandGaussianNoised,\n",
210
+ ")\n",
211
+ "\n",
212
+ "from monai.config import print_config\n",
213
+ "from monai.metrics import DiceMetric\n",
214
+ "from monai.networks.nets import UNETR\n",
215
+ "\n",
216
+ "from monai.data import (\n",
217
+ " DataLoader,\n",
218
+ " CacheDataset,\n",
219
+ " load_decathlon_datalist,\n",
220
+ " decollate_batch,\n",
221
+ " pad_list_data_collate,\n",
222
+ " SmartCacheDataset,\n",
223
+ " ArrayDataset,\n",
224
+ " Dataset\n",
225
+ ")\n",
226
+ "\n",
227
+ "import numpy as np\n",
228
+ "import torch"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 6,
234
+ "metadata": {},
235
+ "outputs": [
236
+ {
237
+ "name": "stderr",
238
+ "output_type": "stream",
239
+ "text": [
240
+ "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
241
+ "`optional` parameter is deprecated, and it has no effect\n",
242
+ "Expected 4 arguments for function <function process_nii_file at 0x000001E6090B2160>, received 3.\n",
243
+ "Expected at least 4 arguments for function <function process_nii_file at 0x000001E6090B2160>, received 3.\n"
244
+ ]
245
+ },
246
+ {
247
+ "name": "stdout",
248
+ "output_type": "stream",
249
+ "text": [
250
+ "Running on local URL: http://127.0.0.1:7864\n",
251
+ "Running on public URL: https://69e978c9547c47c32b.gradio.live\n",
252
+ "\n",
253
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
254
+ ]
255
+ },
256
+ {
257
+ "data": {
258
+ "text/html": [
259
+ "<div><iframe src=\"https://69e978c9547c47c32b.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
260
+ ],
261
+ "text/plain": [
262
+ "<IPython.core.display.HTML object>"
263
+ ]
264
+ },
265
+ "metadata": {},
266
+ "output_type": "display_data"
267
+ },
268
+ {
269
+ "data": {
270
+ "text/plain": []
271
+ },
272
+ "execution_count": 6,
273
+ "metadata": {},
274
+ "output_type": "execute_result"
275
+ }
276
+ ],
277
+ "source": [
278
+ "import gradio as gr\n",
279
+ "import matplotlib.pyplot as plt\n",
280
+ "import torch\n",
281
+ "import nibabel as nib\n",
282
+ "import numpy as np\n",
283
+ "import SimpleITK as sitk\n",
284
+ "\n",
285
+ "def dcm2nii(dcms_path, nii_path):\n",
286
+ "\t# 1.构建dicom序列文件阅读器,并执行(即将dicom序列文件“打包整合”)\n",
287
+ " reader = sitk.ImageSeriesReader()\n",
288
+ " dicom_names = reader.GetGDCMSeriesFileNames(dcms_path)\n",
289
+ " reader.SetFileNames(dicom_names)\n",
290
+ " image2 = reader.Execute()\n",
291
+ "\t# 2.将整合后的数据转为array,并获取dicom文件基本信息\n",
292
+ " image_array = sitk.GetArrayFromImage(image2) # z, y, x\n",
293
+ " origin = image2.GetOrigin() # x, y, z\n",
294
+ " print(origin)\n",
295
+ " spacing = image2.GetSpacing() # x, y, z\n",
296
+ " print(spacing)\n",
297
+ " direction = image2.GetDirection() # x, y, z\n",
298
+ " print(direction)\n",
299
+ "\n",
300
+ " # 3.将array转为img,并保存为.nii.gz\n",
301
+ " image3 = sitk.GetImageFromArray(image_array)\n",
302
+ " image3.SetSpacing(spacing)\n",
303
+ " image3.SetDirection(direction)\n",
304
+ " image3.SetOrigin(origin)\n",
305
+ " sitk.WriteImage(image3, nii_path)\n",
306
+ "\n",
307
+ "def calculate_volume(mask_image_path):\n",
308
+ " # 读取分割结果的图像文件\n",
309
+ " mask_image = sitk.ReadImage(mask_image_path)\n",
310
+ "\n",
311
+ " # 获取图像的大小、原点和间距\n",
312
+ " size = mask_image.GetSize()\n",
313
+ " origin = mask_image.GetOrigin()\n",
314
+ " spacing = mask_image.GetSpacing()\n",
315
+ "\n",
316
+ " # 将 SimpleITK 图像转换为 NumPy 数组\n",
317
+ " mask_array = sitk.GetArrayFromImage(mask_image)\n",
318
+ "\n",
319
+ " # if len(np.unique(mask_array)) != 5:\n",
320
+ " # print(mask_image_path[-15:-12])\n",
321
+ " # print(np.unique(mask_array))\n",
322
+ " \n",
323
+ " # 计算非零像素的数量\n",
324
+ " one_voxels = (mask_array == 1).sum()\n",
325
+ " two_voxels = (mask_array == 2).sum()\n",
326
+ " three_voxels = (mask_array == 3).sum()\n",
327
+ " four_voxels = (mask_array == 4).sum()\n",
328
+ " # print(one_voxels,two_voxels,three_voxels,four_voxels)\n",
329
+ " # 计算像素的体积(以立方毫米为单位)\n",
330
+ " voxel_volume_mm3 = spacing[0] * spacing[1] * spacing[2]\n",
331
+ "\n",
332
+ " # 计算体积(以 mm³ 为单位)\n",
333
+ " V_Right_ventricular_cistern = one_voxels * voxel_volume_mm3 / 1000.0\n",
334
+ " V_Right_cerebral_sulcus = two_voxels * voxel_volume_mm3 / 1000.0\n",
335
+ " V_Left_ventricular_cistern = three_voxels * voxel_volume_mm3 / 1000.0\n",
336
+ " V_Left_cerebral_sulcus = four_voxels * voxel_volume_mm3 / 1000.0\n",
337
+ " # 如果需要以其他单位(例如 cm³)显示,请进行适当的单位转换\n",
338
+ " # volume_cm3 = volume_mm3 / 1000.0\n",
339
+ "\n",
340
+ " return size,spacing,V_Right_ventricular_cistern, V_Right_cerebral_sulcus, V_Left_ventricular_cistern, V_Left_cerebral_sulcus\n",
341
+ " \n",
342
+ "def process_nii_file(input_nii_file, dicom_file, slice, mode):\n",
343
+ " \n",
344
+ " if mode == \"Step1:Segment\":\n",
345
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
346
+ " root_dir = \"./run\"\n",
347
+ " model = UNETR(\n",
348
+ " in_channels=1,\n",
349
+ " out_channels=5,\n",
350
+ " img_size=(96, 96, 16),\n",
351
+ " feature_size=16,\n",
352
+ " hidden_size=768,\n",
353
+ " mlp_dim=3072,\n",
354
+ " num_heads=12,\n",
355
+ " pos_embed=\"perceptron\",\n",
356
+ " norm_name=\"instance\",\n",
357
+ " res_block=True,\n",
358
+ " dropout_rate=0.0,\n",
359
+ " ).to(device)\n",
360
+ " model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model67v2.pth\")))\n",
361
+ " \n",
362
+ " test_transforms = Compose(\n",
363
+ " [\n",
364
+ " LoadImaged(keys=[\"image\"]),\n",
365
+ " EnsureChannelFirstd(keys=[\"image\"]),\n",
366
+ " Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n",
367
+ " ScaleIntensityRanged(\n",
368
+ " keys=[\"image\"],\n",
369
+ " a_min=-50,\n",
370
+ " a_max=100,\n",
371
+ " b_min=0.0,\n",
372
+ " b_max=1.0,\n",
373
+ " clip=True,\n",
374
+ " ),\n",
375
+ " Rotate90d(keys=[\"image\"], k=1)\n",
376
+ " # ResizeWithPadOrCropd(keys=[\"image\"], spatial_size=(512, 512, 16)),\n",
377
+ " ]\n",
378
+ " )\n",
379
+ " test_file = [{'image':input_nii_file.name}]\n",
380
+ " # test_file = [{'image':r'F:\\sth\\23Fall\\fcpro\\brain_image_copy\\image\\60020599.nii.gz'}]\n",
381
+ " test_image = SmartCacheDataset(data=test_file, transform=test_transforms)[0]['image']\n",
382
+ "\n",
383
+ " with torch.no_grad():\n",
384
+ "\n",
385
+ " inputs = torch.unsqueeze(test_image, 1).cuda()\n",
386
+ "\n",
387
+ " val_outputs = sliding_window_inference(inputs, (96, 96, 16), 8, model, overlap=0.8)\n",
388
+ " \n",
389
+ " # Process the output image\n",
390
+ " output_image = torch.argmax(val_outputs, dim=1).detach().cpu().squeeze(0)\n",
391
+ " \n",
392
+ " # Display the images\n",
393
+ " fig1 = plt.figure()\n",
394
+ " plt.title(\"image\")\n",
395
+ " plt.axis('off') # Remove axis\n",
396
+ " plt.imshow(inputs.cpu().numpy()[0, 0, :, :, slice], cmap=\"gray\")\n",
397
+ "\n",
398
+ " fig2 = plt.figure()\n",
399
+ " plt.title(\"output\")\n",
400
+ " plt.axis('off') # Remove axis\n",
401
+ " plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, slice])\n",
402
+ "\n",
403
+ " val_outputs = torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, :]\n",
404
+ " val_outputs = val_outputs.numpy().astype('int16')\n",
405
+ " # val_outputs = np.transpose(val_outputs, (2, 1, 0))\n",
406
+ " val_outputs = np.rot90(val_outputs, k=3)\n",
407
+ " val_outputs = nib.Nifti1Image(val_outputs, np.eye(4))\n",
408
+ " nib.save(val_outputs, f'D:/{input_nii_file.name[-15:-7]}_mask.nii.gz')\n",
409
+ "\n",
410
+ " return [\"指定切片分割结果如下, mask文件已保存至D:/\", fig1, fig2]\n",
411
+ " \n",
412
+ " if mode == \"Step2:Volumn\":\n",
413
+ " maskFilePath = input_nii_file.name\n",
414
+ " size,spacing,V_Right_ventricular_cistern, V_Right_cerebral_sulcus, V_Left_ventricular_cistern, V_Left_cerebral_sulcus = calculate_volume(maskFilePath)\n",
415
+ "\n",
416
+ " vol = f\"\"\"右侧脑室脑池的体积为{V_Right_ventricular_cistern}cm³\\n 右侧脑沟的体积为{V_Right_cerebral_sulcus}cm³\\n 左侧脑室脑池的体积为{V_Left_ventricular_cistern}cm³\\n 左侧脑沟的体积为{V_Left_cerebral_sulcus}cm³\"\"\"\n",
417
+ " fig1 = plt.figure()\n",
418
+ " fig2 = plt.figure()\n",
419
+ " return [vol, fig1, fig2]\n",
420
+ "\n",
421
+ "# Define the Gradio interface\n",
422
+ "iface = gr.Interface(\n",
423
+ " fn=process_nii_file,\n",
424
+ " inputs=\n",
425
+ " [gr.File(file_count='single', file_types=['.nii.gz']), \n",
426
+ " gr.inputs.Slider(0, 24, default=8, label=\"Select Slice\", step=1),\n",
427
+ " gr.Radio(\n",
428
+ " [\"Step1:Segment\", \"Step2:Volumn\"], label=\"mode\"\n",
429
+ " ),\n",
430
+ " ],\n",
431
+ " \n",
432
+ " outputs=[gr.Text(label=\"Output\"), gr.Plot(label=\"image\"), gr.Plot(label=\"mask\")], # Display both \"image\" and \"output\"\n",
433
+ ")\n",
434
+ "\n",
435
+ "iface.launch(share=True)\n"
436
+ ]
437
+ }
438
+ ],
439
+ "metadata": {
440
+ "kernelspec": {
441
+ "display_name": "pytorch",
442
+ "language": "python",
443
+ "name": "python3"
444
+ },
445
+ "language_info": {
446
+ "codemirror_mode": {
447
+ "name": "ipython",
448
+ "version": 3
449
+ },
450
+ "file_extension": ".py",
451
+ "mimetype": "text/x-python",
452
+ "name": "python",
453
+ "nbconvert_exporter": "python",
454
+ "pygments_lexer": "ipython3",
455
+ "version": "3.9.13"
456
+ },
457
+ "orig_nbformat": 4
458
+ },
459
+ "nbformat": 4,
460
+ "nbformat_minor": 2
461
+ }
demo.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+
5
+ from monai.losses import DiceCELoss
6
+ from monai.inferers import sliding_window_inference
7
+ from monai.transforms import (
8
+ EnsureChannelFirstd,
9
+ Compose,
10
+ LoadImaged,
11
+ Orientationd,
12
+ ScaleIntensityRanged,
13
+ Rotate90d,
14
+ )
15
+ from monai.networks.nets import UNETR
16
+
17
+ from monai.data import (
18
+ SmartCacheDataset,
19
+ )
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ import gradio as gr
25
+ import matplotlib.pyplot as plt
26
+ import torch
27
+ import nibabel as nib
28
+ import numpy as np
29
+ import SimpleITK as sitk
30
+
31
+ def dcm2nii(dcms_path, nii_path):
32
+ # 1.构建dicom序列文件阅读器,并执行(即将dicom序列文件“打包整合”)
33
+ reader = sitk.ImageSeriesReader()
34
+ dicom_names = reader.GetGDCMSeriesFileNames(dcms_path)
35
+ reader.SetFileNames(dicom_names)
36
+ image2 = reader.Execute()
37
+ # 2.将整合后的数据转为array,并获取dicom文件基本信息
38
+ image_array = sitk.GetArrayFromImage(image2) # z, y, x
39
+ origin = image2.GetOrigin() # x, y, z
40
+ print(origin)
41
+ spacing = image2.GetSpacing() # x, y, z
42
+ print(spacing)
43
+ direction = image2.GetDirection() # x, y, z
44
+ print(direction)
45
+
46
+ # 3.将array转为img,并保存为.nii.gz
47
+ image3 = sitk.GetImageFromArray(image_array)
48
+ image3.SetSpacing(spacing)
49
+ image3.SetDirection(direction)
50
+ image3.SetOrigin(origin)
51
+ sitk.WriteImage(image3, nii_path)
52
+
53
+ def calculate_volume(mask_image_path):
54
+ # 读取分割结果的图像文件
55
+ mask_image = sitk.ReadImage(mask_image_path)
56
+
57
+ # 获取图像的大小、原点和间距
58
+ size = mask_image.GetSize()
59
+ origin = mask_image.GetOrigin()
60
+ spacing = mask_image.GetSpacing()
61
+
62
+ # 将 SimpleITK 图像转换为 NumPy 数组
63
+ mask_array = sitk.GetArrayFromImage(mask_image)
64
+
65
+ # if len(np.unique(mask_array)) != 5:
66
+ # print(mask_image_path[-15:-12])
67
+ # print(np.unique(mask_array))
68
+
69
+ # 计算非零像素的数量
70
+ one_voxels = (mask_array == 1).sum()
71
+ two_voxels = (mask_array == 2).sum()
72
+ three_voxels = (mask_array == 3).sum()
73
+ four_voxels = (mask_array == 4).sum()
74
+ # print(one_voxels,two_voxels,three_voxels,four_voxels)
75
+ # 计算像素的体积(以立方毫米为单位)
76
+ voxel_volume_mm3 = spacing[0] * spacing[1] * spacing[2]
77
+
78
+ # 计算体积(以 mm³ 为单位)
79
+ V_Right_ventricular_cistern = one_voxels * voxel_volume_mm3 / 1000.0
80
+ V_Right_cerebral_sulcus = two_voxels * voxel_volume_mm3 / 1000.0
81
+ V_Left_ventricular_cistern = three_voxels * voxel_volume_mm3 / 1000.0
82
+ V_Left_cerebral_sulcus = four_voxels * voxel_volume_mm3 / 1000.0
83
+ # 如果需要以其他单位(例如 cm³)显示,请进行适当的单位转换
84
+ # volume_cm3 = volume_mm3 / 1000.0
85
+
86
+ return size,spacing,V_Right_ventricular_cistern, V_Right_cerebral_sulcus, V_Left_ventricular_cistern, V_Left_cerebral_sulcus
87
+
88
+ def process_nii_file(input_nii_file, dicom_file, slice, mode):
89
+
90
+ if mode == "Step1:Segment":
91
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+ root_dir = "./run"
93
+ model = UNETR(
94
+ in_channels=1,
95
+ out_channels=5,
96
+ img_size=(96, 96, 16),
97
+ feature_size=16,
98
+ hidden_size=768,
99
+ mlp_dim=3072,
100
+ num_heads=12,
101
+ pos_embed="perceptron",
102
+ norm_name="instance",
103
+ res_block=True,
104
+ dropout_rate=0.0,
105
+ ).to(device)
106
+ model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model67v2.pth")))
107
+
108
+ test_transforms = Compose(
109
+ [
110
+ LoadImaged(keys=["image"]),
111
+ EnsureChannelFirstd(keys=["image"]),
112
+ Orientationd(keys=["image"], axcodes="RAS"),
113
+ ScaleIntensityRanged(
114
+ keys=["image"],
115
+ a_min=-50,
116
+ a_max=100,
117
+ b_min=0.0,
118
+ b_max=1.0,
119
+ clip=True,
120
+ ),
121
+ Rotate90d(keys=["image"], k=1)
122
+ # ResizeWithPadOrCropd(keys=["image"], spatial_size=(512, 512, 16)),
123
+ ]
124
+ )
125
+ test_file = [{'image':input_nii_file.name}]
126
+ # test_file = [{'image':r'F:\sth\23Fall\fcpro\brain_image_copy\image\60020599.nii.gz'}]
127
+ test_image = SmartCacheDataset(data=test_file, transform=test_transforms)[0]['image']
128
+
129
+ with torch.no_grad():
130
+
131
+ inputs = torch.unsqueeze(test_image, 1).cuda()
132
+
133
+ val_outputs = sliding_window_inference(inputs, (96, 96, 16), 8, model, overlap=0.8)
134
+
135
+ # Display the images
136
+ fig1 = plt.figure()
137
+ plt.title("image")
138
+ plt.axis('off') # Remove axis
139
+ plt.imshow(inputs.cpu().numpy()[0, 0, :, :, slice], cmap="gray")
140
+
141
+ fig2 = plt.figure()
142
+ plt.title("output")
143
+ plt.axis('off') # Remove axis
144
+ plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, slice])
145
+
146
+ val_outputs = torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, :]
147
+ val_outputs = val_outputs.numpy().astype('int16')
148
+ # val_outputs = np.transpose(val_outputs, (2, 1, 0))
149
+ val_outputs = np.rot90(val_outputs, k=3)
150
+ val_outputs = nib.Nifti1Image(val_outputs, np.eye(4))
151
+ nib.save(val_outputs, f'D:/{input_nii_file.name[-15:-7]}_mask.nii.gz')
152
+
153
+ return ["指定切片分割结果如下, mask文件已保存至D:/", fig1, fig2]
154
+
155
+ if mode == "Step2:Volumn":
156
+ maskFilePath = input_nii_file.name
157
+ size,spacing,V_Right_ventricular_cistern, V_Right_cerebral_sulcus, V_Left_ventricular_cistern, V_Left_cerebral_sulcus = calculate_volume(maskFilePath)
158
+
159
+ vol = f"""右侧脑室脑池的体积为{V_Right_ventricular_cistern}cm³\n 右侧脑沟的体积为{V_Right_cerebral_sulcus}cm³\n 左侧脑室脑池的体积为{V_Left_ventricular_cistern}cm³\n 左侧脑沟的体积为{V_Left_cerebral_sulcus}cm³"""
160
+ fig1 = plt.figure()
161
+ fig2 = plt.figure()
162
+ return [vol, fig1, fig2]
163
+
164
+ # Define the Gradio interface
165
+ iface = gr.Interface(
166
+ fn=process_nii_file,
167
+ inputs=
168
+ [gr.File(file_count='single', file_types=['.nii.gz']),
169
+ gr.inputs.Slider(0, 24, default=8, label="Select Slice", step=1),
170
+ gr.Radio(
171
+ ["Step1:Segment", "Step2:Volumn"], label="mode"
172
+ ),
173
+ ],
174
+
175
+ outputs=[gr.Text(label="Output"), gr.Plot(label="image"), gr.Plot(label="mask")], # Display both "image" and "output"
176
+ )
177
+
178
+ iface.launch(share=True)