drankush-ai commited on
Commit
febd0e1
·
verified ·
1 Parent(s): 24eeab3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -176
app.py CHANGED
@@ -1,183 +1,183 @@
1
  import gradio as gr
2
- import torch
3
- import cv2
4
- import numpy as np
5
- from pathlib import Path
6
- from huggingface_hub import snapshot_download
7
- from fastMONAI.vision_all import *
8
- from git import Repo
9
- import os
10
- from fastai.learner import load_learner
11
- from fastai.basics import load_pickle
12
- import pickle
13
-
14
- # Function to extract slices from mask
15
- def extract_slices_from_mask(img, mask_data, view):
16
- """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
17
- slices = []
18
- target_size = (320, 320)
 
 
 
 
 
 
 
 
19
 
20
- for idx in range(img.shape[2] if view == "Sagittal" else img.shape[1] if view == "Axial" else img.shape[0]):
21
- if view == "Sagittal":
22
- slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
23
- elif view == "Axial":
24
- slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
25
- elif view == "Coronal":
26
- slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
27
-
28
- slice_img = np.fliplr(np.rot90(slice_img, -1))
29
- slice_mask = np.fliplr(np.rot90(slice_mask, -1))
30
-
31
- slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
32
- slices.append((slice_img_resized, slice_mask_resized))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- return slices
35
-
36
- # Function to resize and pad slices
37
- def resize_and_pad(slice_img, slice_mask, target_size):
38
- """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
39
- h, w = slice_img.shape
40
- scale = min(target_size[0] / w, target_size[1] / h)
41
- new_w, new_h = int(w * scale), int(h * scale)
42
 
43
- resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
44
- resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
45
-
46
- pad_w = (target_size[0] - new_w) // 2
47
- pad_h = (target_size[1] - new_h) // 2
48
-
49
- padded_img = np.pad(resized_img, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
50
- padded_mask = np.pad(resized_mask, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
51
-
52
- return padded_img, padded_mask
53
-
54
- # Function to normalize image
55
- def normalize_image(slice_img):
56
- """Normalize the image to the range [0, 255] safely."""
57
- slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
58
- if slice_img_min == slice_img_max: # Avoid division by zero
59
- return np.zeros_like(slice_img, dtype=np.uint8)
60
- normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
61
- return normalized_img.astype(np.uint8)
62
-
63
- # Function to get fused image
64
- def get_fused_image(img, pred_mask, view, alpha=0.8):
65
- """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
66
- gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
67
- mask_color = np.array([255, 0, 0])
68
- colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
69
 
70
- fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
 
71
 
72
- # Flip the fused image vertically and horizontally
73
- fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
74
 
75
- if view == 'Sagittal':
76
- return fused_flipped
77
- elif view == 'Coronal' or view == 'Axial':
78
- rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
79
- return rotated
80
-
81
- # Function for Gradio image segmentation
82
- def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
83
- """Predict function using the learner and other resources."""
84
-
85
- if view is None:
86
- view = 'Sagittal'
87
-
88
- img_path = Path(fileobj.name)
89
-
90
- save_fn = 'pred_' + img_path.stem
91
- save_path = save_dir / save_fn
92
- org_img, input_img, org_size = med_img_reader(img_path,
93
- reorder=reorder,
94
- resample=resample,
95
- only_tensor=False)
96
-
97
- mask_data = inference(learn, reorder=reorder, resample=resample,
98
- org_img=org_img, input_img=input_img,
99
- org_size=org_size).data
100
-
101
- if "".join(org_img.orientation) == "LSA":
102
- mask_data = mask_data.permute(0,1,3,2)
103
- mask_data = torch.flip(mask_data[0], dims=[1])
104
- mask_data = torch.Tensor(mask_data)[None]
105
-
106
- img = org_img.data
107
- org_img.set_data(mask_data)
108
- org_img.save(save_path)
109
-
110
- slices = extract_slices_from_mask(img[0], mask_data[0], view)
111
- fused_images = [(get_fused_image(
112
- normalize_image(slice_img), # Normalize safely
113
- slice_mask, view))
114
- for slice_img, slice_mask in slices]
115
-
116
- volume = compute_binary_tumor_volume(org_img)
117
-
118
- return fused_images, round(volume, 2)
119
-
120
- # Function to load system resources
121
- def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
122
- """Load the model and other required resources."""
123
- try:
124
- learn = load_learner(models_path / learner_fn)
125
- except Exception as e:
126
- raise ValueError(f"Error loading the model: {str(e)}")
127
-
128
- try:
129
- with open(models_path / variables_fn, 'rb') as f:
130
- variables = pickle.load(f)
131
-
132
- if not isinstance(variables, list) or len(variables) != 3:
133
- raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
134
-
135
- # Assuming the format is [shape, reorder, resample]
136
- shape = variables[0]
137
- reorder = variables[1]
138
- resample = variables[2]
139
-
140
- if not isinstance(reorder, bool):
141
- raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
142
-
143
- if not isinstance(resample, list) or len(resample) != 3:
144
- raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
145
-
146
- except Exception as e:
147
- raise ValueError(f"Error loading variables: {str(e)}")
148
-
149
- return learn, reorder, resample
150
-
151
- # Initialize the system
152
- clone_dir = Path.cwd() / 'clone_dir'
153
- URI = os.getenv('PAT_Token_URI')
154
-
155
- if not URI:
156
- raise ValueError("PAT_Token_URI environment variable is not set")
157
-
158
- if os.path.exists(clone_dir):
159
- pass
160
- else:
161
- Repo.clone_from(URI, clone_dir)
162
-
163
- models_path = clone_dir
164
- save_dir = Path.cwd() / 'hs_pred'
165
- save_dir.mkdir(parents=True, exist_ok=True)
166
-
167
- # Load the model and other required resources
168
- learn, reorder, resample = load_system_resources(models_path=models_path)
169
-
170
- # Gradio interface setup
171
- output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
172
-
173
- view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
174
-
175
- demo = gr.Interface(
176
- fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
177
- inputs=["file", view_selector],
178
- outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
179
- examples=[[str(Path.cwd() / "sample.nii.gz")]],
180
- allow_flagging='never')
181
-
182
- # Launch the Gradio interface
183
- demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from pathlib import Path
6
+ from huggingface_hub import snapshot_download
7
+ from fastMONAI.vision_all import *
8
+ from git import Repo
9
+ import os
10
+ from fastai.learner import load_learner
11
+ from fastai.basics import load_pickle
12
+ import pickle
13
+
14
+ # Function to extract slices from mask
15
+ def extract_slices_from_mask(img, mask_data, view):
16
+ """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
17
+ slices = []
18
+ target_size = (320, 320)
19
+
20
+ for idx in range(img.shape[2] if view == "Sagittal" else img.shape[1] if view == "Axial" else img.shape[0]):
21
+ if view == "Sagittal":
22
+ slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
23
+ elif view == "Axial":
24
+ slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
25
+ elif view == "Coronal":
26
+ slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
27
 
28
+ slice_img = np.fliplr(np.rot90(slice_img, -1))
29
+ slice_mask = np.fliplr(np.rot90(slice_mask, -1))
30
+
31
+ slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
32
+ slices.append((slice_img_resized, slice_mask_resized))
33
+
34
+ return slices
35
+
36
+ # Function to resize and pad slices
37
+ def resize_and_pad(slice_img, slice_mask, target_size):
38
+ """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
39
+ h, w = slice_img.shape
40
+ scale = min(target_size[0] / w, target_size[1] / h)
41
+ new_w, new_h = int(w * scale), int(h * scale)
42
+
43
+ resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
44
+ resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
45
+
46
+ pad_w = (target_size[0] - new_w) // 2
47
+ pad_h = (target_size[1] - new_h) // 2
48
+
49
+ padded_img = np.pad(resized_img, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
50
+ padded_mask = np.pad(resized_mask, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
51
+
52
+ return padded_img, padded_mask
53
+
54
+ # Function to normalize image
55
+ def normalize_image(slice_img):
56
+ """Normalize the image to the range [0, 255] safely."""
57
+ slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
58
+ if slice_img_min == slice_img_max: # Avoid division by zero
59
+ return np.zeros_like(slice_img, dtype=np.uint8)
60
+ normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
61
+ return normalized_img.astype(np.uint8)
62
+
63
+ # Function to get fused image
64
+ def get_fused_image(img, pred_mask, view, alpha=0.8):
65
+ """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
66
+ gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
67
+ mask_color = np.array([255, 0, 0])
68
+ colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
69
+
70
+ fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
71
+
72
+ # Flip the fused image vertically and horizontally
73
+ fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
74
+
75
+ if view == 'Sagittal':
76
+ return fused_flipped
77
+ elif view == 'Coronal' or view == 'Axial':
78
+ rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
79
+ return rotated
80
+
81
+ # Function for Gradio image segmentation
82
+ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
83
+ """Predict function using the learner and other resources."""
84
+
85
+ if view is None:
86
+ view = 'Sagittal'
87
+
88
+ img_path = Path(fileobj.name)
89
+
90
+ save_fn = 'pred_' + img_path.stem
91
+ save_path = save_dir / save_fn
92
+ org_img, input_img, org_size = med_img_reader(img_path,
93
+ reorder=reorder,
94
+ resample=resample,
95
+ only_tensor=False)
96
+
97
+ mask_data = inference(learn, reorder=reorder, resample=resample,
98
+ org_img=org_img, input_img=input_img,
99
+ org_size=org_size).data
100
+
101
+ if "".join(org_img.orientation) == "LSA":
102
+ mask_data = mask_data.permute(0,1,3,2)
103
+ mask_data = torch.flip(mask_data[0], dims=[1])
104
+ mask_data = torch.Tensor(mask_data)[None]
105
+
106
+ img = org_img.data
107
+ org_img.set_data(mask_data)
108
+ org_img.save(save_path)
109
+
110
+ slices = extract_slices_from_mask(img[0], mask_data[0], view)
111
+ fused_images = [(get_fused_image(
112
+ normalize_image(slice_img), # Normalize safely
113
+ slice_mask, view))
114
+ for slice_img, slice_mask in slices]
115
+
116
+ volume = compute_binary_tumor_volume(org_img)
117
+
118
+ return fused_images, round(volume, 2)
119
+
120
+ # Function to load system resources
121
+ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
122
+ """Load the model and other required resources."""
123
+ try:
124
+ learn = load_learner(models_path / learner_fn)
125
+ except Exception as e:
126
+ raise ValueError(f"Error loading the model: {str(e)}")
127
+
128
+ try:
129
+ with open(models_path / variables_fn, 'rb') as f:
130
+ variables = pickle.load(f)
131
 
132
+ if not isinstance(variables, list) or len(variables) != 3:
133
+ raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
 
 
 
 
 
 
134
 
135
+ # Assuming the format is [shape, reorder, resample]
136
+ shape = variables[0]
137
+ reorder = variables[1]
138
+ resample = variables[2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ if not isinstance(reorder, bool):
141
+ raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
142
 
143
+ if not isinstance(resample, list) or len(resample) != 3:
144
+ raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
145
 
146
+ except Exception as e:
147
+ raise ValueError(f"Error loading variables: {str(e)}")
148
+
149
+ return learn, reorder, resample
150
+
151
+ # Initialize the system
152
+ clone_dir = Path.cwd() / 'clone_dir'
153
+ URI = os.getenv('PAT_Token_URI')
154
+
155
+ if not URI:
156
+ raise ValueError("PAT_Token_URI environment variable is not set")
157
+
158
+ if os.path.exists(clone_dir):
159
+ pass
160
+ else:
161
+ Repo.clone_from(URI, clone_dir)
162
+
163
+ models_path = clone_dir
164
+ save_dir = Path.cwd() / 'hs_pred'
165
+ save_dir.mkdir(parents=True, exist_ok=True)
166
+
167
+ # Load the model and other required resources
168
+ learn, reorder, resample = load_system_resources(models_path=models_path)
169
+
170
+ # Gradio interface setup
171
+ output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
172
+
173
+ view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
174
+
175
+ demo = gr.Interface(
176
+ fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
177
+ inputs=["file", view_selector],
178
+ outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
179
+ examples=[[str(Path.cwd() / "sample.nii.gz")]],
180
+ allow_flagging='never')
181
+
182
+ # Launch the Gradio interface
183
+ demo.launch()