Ojasvi-Nagayach commited on
Commit
dc75c07
·
verified ·
1 Parent(s): 380c2ea

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +98 -0
  2. heart_model.pkl +3 -0
  3. requirements.txt +1 -0
  4. sample.nii.gz +3 -0
  5. vars.pkl +3 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from pathlib import Path
6
+ from huggingface_hub import snapshot_download
7
+ from fastMONAI.vision_all import *
8
+ #import pathlib
9
+ #temp = pathlib.PosixPath
10
+ #pathlib.PosixPath = pathlib.WindowsPath
11
+ #pathlib.PosixPath = temp
12
+
13
+ def initialize_system():
14
+ """Initial setup of model paths and other constants."""
15
+ models_path = Path.cwd()
16
+ save_dir = Path.cwd() / 'hs_pred'
17
+ save_dir.mkdir(parents=True, exist_ok=True)
18
+ download_example_endometrial_cancer_data(path=save_dir, multi_channel=False)
19
+
20
+ return models_path, save_dir
21
+
22
+ def extract_slices_from_mask(img, mask_data):
23
+ """Extract all slices from the 3D [W, H, D] image and mask data."""
24
+ slices = []
25
+ for idx in range(img.shape[-1]):
26
+ slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
27
+ slice_img = np.fliplr(np.rot90(slice_img, -1))
28
+ slice_mask = np.fliplr(np.rot90(slice_mask, -1))
29
+ slices.append((slice_img, slice_mask))
30
+ return slices
31
+
32
+ def get_fused_image(img, pred_mask, alpha=0.8):
33
+ """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
34
+ gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
35
+ mask_color = np.array([255, 0, 0])
36
+ colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
37
+
38
+ fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
39
+
40
+ # Flip the fused image vertically and horizontally
41
+ fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
42
+
43
+ return fused_flipped
44
+
45
+ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir):
46
+ """Predict function using the learner and other resources."""
47
+ img_path = Path(fileobj.name)
48
+
49
+ save_fn = 'pred_' + img_path.stem
50
+ save_path = save_dir / save_fn
51
+ org_img, input_img, org_size = med_img_reader(img_path,
52
+ reorder=reorder,
53
+ resample=resample,
54
+ only_tensor=False)
55
+
56
+ mask_data = inference(learn, reorder=reorder, resample=resample,
57
+ org_img=org_img, input_img=input_img,
58
+ org_size=org_size).data
59
+
60
+ if "".join(org_img.orientation) == "LSA":
61
+ mask_data = mask_data.permute(0,1,3,2)
62
+ mask_data = torch.flip(mask_data[0], dims=[1])
63
+ mask_data = torch.Tensor(mask_data)[None]
64
+
65
+ img = org_img.data
66
+ org_img.set_data(mask_data)
67
+ org_img.save(save_path)
68
+
69
+ slices = extract_slices_from_mask(img[0], mask_data[0])
70
+ fused_images = [(get_fused_image(
71
+ ((slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255).astype(np.uint8),
72
+ slice_mask))
73
+ for slice_img, slice_mask in slices]
74
+
75
+ volume = compute_binary_tumor_volume(org_img)
76
+
77
+ return fused_images, round(volume, 2)
78
+
79
+ # Initialize the system
80
+ models_path, save_dir = initialize_system()
81
+
82
+ # Load the model and other required resources
83
+ learn, reorder, resample = load_system_resources(models_path=Path.cwd(),
84
+ learner_fn='heart_model.pkl',
85
+ variables_fn='vars.pkl')
86
+
87
+ # Gradio interface setup
88
+ output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
89
+
90
+ demo = gr.Interface(
91
+ fn=lambda fileobj: gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir),
92
+ inputs=["file"],
93
+ outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
94
+ examples=[[str(Path.cwd() /"sample.nii.gz")]],
95
+ allow_flagging='never')
96
+
97
+ # Launch the Gradio interface
98
+ demo.launch()
heart_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8afefff7a465f013ca5978f03f6d0a0c4aa1dd2650dc0308962f1ad66cee4ae6
3
+ size 19363377
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fastMONAI
sample.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05477be1b339567c8b304dafa737ad22be268024140197eeae9f14172a76e0c4
3
+ size 16059519
vars.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd8458577a45f5ee60fc50a8ec5f6a499c6733b0f241a33cf76fa22bb9e715d3
3
+ size 173