drankush-ai commited on
Commit
24eeab3
·
verified ·
1 Parent(s): a4ab603

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -178
app.py CHANGED
@@ -1,179 +1,183 @@
1
  import gradio as gr
2
- import torch
3
- import cv2
4
- import numpy as np
5
- from pathlib import Path
6
- from huggingface_hub import snapshot_download
7
- from fastMONAI.vision_all import *
8
- from git import Repo
9
- import os
10
- from fastai.learner import load_learner
11
- from fastai.basics import load_pickle
12
-
13
- # Function to extract slices from mask
14
- def extract_slices_from_mask(img, mask_data, view):
15
- """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
16
- slices = []
17
- target_size = (320, 320)
18
-
19
- for idx in range(img.shape[2] if view == "Sagittal" else img.shape[1] if view == "Axial" else img.shape[0]):
20
- if view == "Sagittal":
21
- slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
22
- elif view == "Axial":
23
- slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
24
- elif view == "Coronal":
25
- slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
26
-
27
- slice_img = np.fliplr(np.rot90(slice_img, -1))
28
- slice_mask = np.fliplr(np.rot90(slice_mask, -1))
29
-
30
- slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
31
- slices.append((slice_img_resized, slice_mask_resized))
32
-
33
- return slices
34
-
35
- # Function to resize and pad slices
36
- def resize_and_pad(slice_img, slice_mask, target_size):
37
- """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
38
- h, w = slice_img.shape
39
- scale = min(target_size[0] / w, target_size[1] / h)
40
- new_w, new_h = int(w * scale), int(h * scale)
41
-
42
- resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
43
- resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
44
-
45
- pad_w = (target_size[0] - new_w) // 2
46
- pad_h = (target_size[1] - new_h) // 2
47
-
48
- padded_img = np.pad(resized_img, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
49
- padded_mask = np.pad(resized_mask, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
50
-
51
- return padded_img, padded_mask
52
-
53
- # Function to normalize image
54
- def normalize_image(slice_img):
55
- """Normalize the image to the range [0, 255] safely."""
56
- slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
57
- if slice_img_min == slice_img_max: # Avoid division by zero
58
- return np.zeros_like(slice_img, dtype=np.uint8)
59
- normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
60
- return normalized_img.astype(np.uint8)
61
-
62
- # Function to get fused image
63
- def get_fused_image(img, pred_mask, view, alpha=0.8):
64
- """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
65
- gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
66
- mask_color = np.array([255, 0, 0])
67
- colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
68
-
69
- fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
70
-
71
- # Flip the fused image vertically and horizontally
72
- fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
73
-
74
- if view == 'Sagittal':
75
- return fused_flipped
76
- elif view == 'Coronal' or view == 'Axial':
77
- rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
78
- return rotated
79
-
80
- # Function for Gradio image segmentation
81
- def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
82
- """Predict function using the learner and other resources."""
83
-
84
- if view is None:
85
- view = 'Sagittal'
86
-
87
- img_path = Path(fileobj.name)
88
-
89
- save_fn = 'pred_' + img_path.stem
90
- save_path = save_dir / save_fn
91
- org_img, input_img, org_size = med_img_reader(img_path,
92
- reorder=reorder,
93
- resample=resample,
94
- only_tensor=False)
95
-
96
- mask_data = inference(learn, reorder=reorder, resample=resample,
97
- org_img=org_img, input_img=input_img,
98
- org_size=org_size).data
99
-
100
- if "".join(org_img.orientation) == "LSA":
101
- mask_data = mask_data.permute(0,1,3,2)
102
- mask_data = torch.flip(mask_data[0], dims=[1])
103
- mask_data = torch.Tensor(mask_data)[None]
104
-
105
- img = org_img.data
106
- org_img.set_data(mask_data)
107
- org_img.save(save_path)
108
-
109
- slices = extract_slices_from_mask(img[0], mask_data[0], view)
110
- fused_images = [(get_fused_image(
111
- normalize_image(slice_img), # Normalize safely
112
- slice_mask, view))
113
- for slice_img, slice_mask in slices]
114
-
115
- volume = compute_binary_tumor_volume(org_img)
116
-
117
- return fused_images, round(volume, 2)
118
-
119
- # Function to load system resources
120
- def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
121
- """Load the model and other required resources."""
122
- try:
123
- learn = load_learner(models_path / learner_fn)
124
- except Exception as e:
125
- raise ValueError(f"Error loading the model: {str(e)}")
126
-
127
- try:
128
- variables = load_pickle(models_path / variables_fn)
129
- if isinstance(variables, dict):
130
- reorder = variables.get('reorder')
131
- resample = variables.get('resample')
132
- if reorder is None or resample is None:
133
- raise ValueError("'reorder' or 'resample' not found in vars.pkl. Using default values.")
134
- reorder = True # Set a default value
135
- resample = True # Set a default value
136
- else:
137
- raise ValueError("vars.pkl does not contain a dictionary. Using default values for reorder and resample.")
138
- reorder = True # Set a default value
139
- resample = True # Set a default value
140
- except Exception as e:
141
- raise ValueError(f"Error loading variables: {str(e)}")
142
- reorder = True # Set a default value
143
- resample = True # Set a default value
144
-
145
- return learn, reorder, resample
146
-
147
- # Initialize the system
148
- clone_dir = Path.cwd() / 'clone_dir'
149
- URI = os.getenv('PAT_Token_URI')
150
-
151
- if not URI:
152
- raise ValueError("PAT_Token_URI environment variable is not set")
153
-
154
- if os.path.exists(clone_dir):
155
- pass
156
- else:
157
- Repo.clone_from(URI, clone_dir)
158
-
159
- models_path = clone_dir
160
- save_dir = Path.cwd() / 'hs_pred'
161
- save_dir.mkdir(parents=True, exist_ok=True)
162
-
163
- # Load the model and other required resources
164
- learn, reorder, resample = load_system_resources(models_path=models_path)
165
-
166
- # Gradio interface setup
167
- output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
168
-
169
- view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
170
-
171
- demo = gr.Interface(
172
- fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
173
- inputs=["file", view_selector],
174
- outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
175
- examples=[[str(Path.cwd() / "sample.nii.gz")]],
176
- allow_flagging='never')
177
-
178
- # Launch the Gradio interface
179
- demo.launch()
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from pathlib import Path
6
+ from huggingface_hub import snapshot_download
7
+ from fastMONAI.vision_all import *
8
+ from git import Repo
9
+ import os
10
+ from fastai.learner import load_learner
11
+ from fastai.basics import load_pickle
12
+ import pickle
13
+
14
+ # Function to extract slices from mask
15
+ def extract_slices_from_mask(img, mask_data, view):
16
+ """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
17
+ slices = []
18
+ target_size = (320, 320)
19
+
20
+ for idx in range(img.shape[2] if view == "Sagittal" else img.shape[1] if view == "Axial" else img.shape[0]):
21
+ if view == "Sagittal":
22
+ slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
23
+ elif view == "Axial":
24
+ slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
25
+ elif view == "Coronal":
26
+ slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
27
+
28
+ slice_img = np.fliplr(np.rot90(slice_img, -1))
29
+ slice_mask = np.fliplr(np.rot90(slice_mask, -1))
30
+
31
+ slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
32
+ slices.append((slice_img_resized, slice_mask_resized))
33
+
34
+ return slices
35
+
36
+ # Function to resize and pad slices
37
+ def resize_and_pad(slice_img, slice_mask, target_size):
38
+ """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
39
+ h, w = slice_img.shape
40
+ scale = min(target_size[0] / w, target_size[1] / h)
41
+ new_w, new_h = int(w * scale), int(h * scale)
42
+
43
+ resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
44
+ resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
45
+
46
+ pad_w = (target_size[0] - new_w) // 2
47
+ pad_h = (target_size[1] - new_h) // 2
48
+
49
+ padded_img = np.pad(resized_img, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
50
+ padded_mask = np.pad(resized_mask, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
51
+
52
+ return padded_img, padded_mask
53
+
54
+ # Function to normalize image
55
+ def normalize_image(slice_img):
56
+ """Normalize the image to the range [0, 255] safely."""
57
+ slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
58
+ if slice_img_min == slice_img_max: # Avoid division by zero
59
+ return np.zeros_like(slice_img, dtype=np.uint8)
60
+ normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
61
+ return normalized_img.astype(np.uint8)
62
+
63
+ # Function to get fused image
64
+ def get_fused_image(img, pred_mask, view, alpha=0.8):
65
+ """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
66
+ gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
67
+ mask_color = np.array([255, 0, 0])
68
+ colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
69
+
70
+ fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
71
+
72
+ # Flip the fused image vertically and horizontally
73
+ fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
74
+
75
+ if view == 'Sagittal':
76
+ return fused_flipped
77
+ elif view == 'Coronal' or view == 'Axial':
78
+ rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
79
+ return rotated
80
+
81
+ # Function for Gradio image segmentation
82
+ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
83
+ """Predict function using the learner and other resources."""
84
+
85
+ if view is None:
86
+ view = 'Sagittal'
87
+
88
+ img_path = Path(fileobj.name)
89
+
90
+ save_fn = 'pred_' + img_path.stem
91
+ save_path = save_dir / save_fn
92
+ org_img, input_img, org_size = med_img_reader(img_path,
93
+ reorder=reorder,
94
+ resample=resample,
95
+ only_tensor=False)
96
+
97
+ mask_data = inference(learn, reorder=reorder, resample=resample,
98
+ org_img=org_img, input_img=input_img,
99
+ org_size=org_size).data
100
+
101
+ if "".join(org_img.orientation) == "LSA":
102
+ mask_data = mask_data.permute(0,1,3,2)
103
+ mask_data = torch.flip(mask_data[0], dims=[1])
104
+ mask_data = torch.Tensor(mask_data)[None]
105
+
106
+ img = org_img.data
107
+ org_img.set_data(mask_data)
108
+ org_img.save(save_path)
109
+
110
+ slices = extract_slices_from_mask(img[0], mask_data[0], view)
111
+ fused_images = [(get_fused_image(
112
+ normalize_image(slice_img), # Normalize safely
113
+ slice_mask, view))
114
+ for slice_img, slice_mask in slices]
115
+
116
+ volume = compute_binary_tumor_volume(org_img)
117
+
118
+ return fused_images, round(volume, 2)
119
+
120
+ # Function to load system resources
121
+ def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
122
+ """Load the model and other required resources."""
123
+ try:
124
+ learn = load_learner(models_path / learner_fn)
125
+ except Exception as e:
126
+ raise ValueError(f"Error loading the model: {str(e)}")
127
+
128
+ try:
129
+ with open(models_path / variables_fn, 'rb') as f:
130
+ variables = pickle.load(f)
131
+
132
+ if not isinstance(variables, list) or len(variables) != 3:
133
+ raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
134
+
135
+ # Assuming the format is [shape, reorder, resample]
136
+ shape = variables[0]
137
+ reorder = variables[1]
138
+ resample = variables[2]
139
+
140
+ if not isinstance(reorder, bool):
141
+ raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
142
+
143
+ if not isinstance(resample, list) or len(resample) != 3:
144
+ raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
145
+
146
+ except Exception as e:
147
+ raise ValueError(f"Error loading variables: {str(e)}")
148
+
149
+ return learn, reorder, resample
150
+
151
+ # Initialize the system
152
+ clone_dir = Path.cwd() / 'clone_dir'
153
+ URI = os.getenv('PAT_Token_URI')
154
+
155
+ if not URI:
156
+ raise ValueError("PAT_Token_URI environment variable is not set")
157
+
158
+ if os.path.exists(clone_dir):
159
+ pass
160
+ else:
161
+ Repo.clone_from(URI, clone_dir)
162
+
163
+ models_path = clone_dir
164
+ save_dir = Path.cwd() / 'hs_pred'
165
+ save_dir.mkdir(parents=True, exist_ok=True)
166
+
167
+ # Load the model and other required resources
168
+ learn, reorder, resample = load_system_resources(models_path=models_path)
169
+
170
+ # Gradio interface setup
171
+ output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
172
+
173
+ view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
174
+
175
+ demo = gr.Interface(
176
+ fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
177
+ inputs=["file", view_selector],
178
+ outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
179
+ examples=[[str(Path.cwd() / "sample.nii.gz")]],
180
+ allow_flagging='never')
181
+
182
+ # Launch the Gradio interface
183
+ demo.launch()