drankush-ai commited on
Commit
f37e491
·
verified ·
1 Parent(s): febd0e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -124
app.py CHANGED
@@ -13,152 +13,155 @@ import pickle
13
 
14
  # Function to extract slices from mask
15
  def extract_slices_from_mask(img, mask_data, view):
16
- """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
17
- slices = []
18
- target_size = (320, 320)
19
-
20
- for idx in range(img.shape[2] if view == "Sagittal" else img.shape[1] if view == "Axial" else img.shape[0]):
21
- if view == "Sagittal":
22
- slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
23
- elif view == "Axial":
24
- slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
25
- elif view == "Coronal":
26
- slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
27
-
28
- slice_img = np.fliplr(np.rot90(slice_img, -1))
29
- slice_mask = np.fliplr(np.rot90(slice_mask, -1))
30
-
31
- slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
32
- slices.append((slice_img_resized, slice_mask_resized))
33
-
34
- return slices
35
 
36
  # Function to resize and pad slices
37
  def resize_and_pad(slice_img, slice_mask, target_size):
38
- """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
39
- h, w = slice_img.shape
40
- scale = min(target_size[0] / w, target_size[1] / h)
41
- new_w, new_h = int(w * scale), int(h * scale)
42
-
43
- resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
44
- resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
45
 
46
- pad_w = (target_size[0] - new_w) // 2
47
- pad_h = (target_size[1] - new_h) // 2
48
 
49
- padded_img = np.pad(resized_img, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
50
- padded_mask = np.pad(resized_mask, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
51
 
52
- return padded_img, padded_mask
53
 
54
  # Function to normalize image
55
  def normalize_image(slice_img):
56
- """Normalize the image to the range [0, 255] safely."""
57
- slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
58
- if slice_img_min == slice_img_max: # Avoid division by zero
59
- return np.zeros_like(slice_img, dtype=np.uint8)
60
- normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
61
- return normalized_img.astype(np.uint8)
62
 
63
  # Function to get fused image
64
  def get_fused_image(img, pred_mask, view, alpha=0.8):
65
- """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
66
- gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
67
- mask_color = np.array([255, 0, 0])
68
- colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
69
-
70
- fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
71
-
72
- # Flip the fused image vertically and horizontally
73
- fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
74
-
75
- if view == 'Sagittal':
76
- return fused_flipped
77
- elif view == 'Coronal' or view == 'Axial':
78
- rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
79
- return rotated
80
 
81
  # Function for Gradio image segmentation
82
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
83
- """Predict function using the learner and other resources."""
84
-
85
- if view is None:
86
- view = 'Sagittal'
87
-
88
- img_path = Path(fileobj.name)
89
-
90
- save_fn = 'pred_' + img_path.stem
91
- save_path = save_dir / save_fn
92
- org_img, input_img, org_size = med_img_reader(img_path,
93
- reorder=reorder,
94
- resample=resample,
95
- only_tensor=False)
96
-
97
- mask_data = inference(learn, reorder=reorder, resample=resample,
98
- org_img=org_img, input_img=input_img,
99
- org_size=org_size).data
100
-
101
- if "".join(org_img.orientation) == "LSA":
102
- mask_data = mask_data.permute(0,1,3,2)
103
- mask_data = torch.flip(mask_data[0], dims=[1])
104
- mask_data = torch.Tensor(mask_data)[None]
105
-
106
- img = org_img.data
107
- org_img.set_data(mask_data)
108
- org_img.save(save_path)
109
-
110
- slices = extract_slices_from_mask(img[0], mask_data[0], view)
111
- fused_images = [(get_fused_image(
112
- normalize_image(slice_img), # Normalize safely
113
- slice_mask, view))
114
- for slice_img, slice_mask in slices]
115
-
116
- volume = compute_binary_tumor_volume(org_img)
117
-
118
- return fused_images, round(volume, 2)
 
 
 
119
 
120
  # Function to load system resources
121
  def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
122
- """Load the model and other required resources."""
123
- try:
124
- learn = load_learner(models_path / learner_fn)
125
- except Exception as e:
126
- raise ValueError(f"Error loading the model: {str(e)}")
127
-
128
- try:
129
- with open(models_path / variables_fn, 'rb') as f:
130
- variables = pickle.load(f)
131
-
132
- if not isinstance(variables, list) or len(variables) != 3:
133
- raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
134
-
135
- # Assuming the format is [shape, reorder, resample]
136
- shape = variables[0]
137
- reorder = variables[1]
138
- resample = variables[2]
139
-
140
- if not isinstance(reorder, bool):
141
- raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
142
-
143
- if not isinstance(resample, list) or len(resample) != 3:
144
- raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
145
-
146
- except Exception as e:
147
- raise ValueError(f"Error loading variables: {str(e)}")
148
-
149
- return learn, reorder, resample
150
 
151
  # Initialize the system
152
  clone_dir = Path.cwd() / 'clone_dir'
153
  URI = os.getenv('PAT_Token_URI')
154
 
155
  if not URI:
156
- raise ValueError("PAT_Token_URI environment variable is not set")
157
 
158
  if os.path.exists(clone_dir):
159
- pass
160
  else:
161
- Repo.clone_from(URI, clone_dir)
162
 
163
  models_path = clone_dir
164
  save_dir = Path.cwd() / 'hs_pred'
@@ -172,12 +175,15 @@ output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
172
 
173
  view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
174
 
 
 
 
175
  demo = gr.Interface(
176
- fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
177
- inputs=["file", view_selector],
178
- outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
179
- examples=[[str(Path.cwd() / "sample.nii.gz")]],
180
- allow_flagging='never')
181
 
182
  # Launch the Gradio interface
183
  demo.launch()
 
13
 
14
  # Function to extract slices from mask
15
  def extract_slices_from_mask(img, mask_data, view):
16
+ """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
17
+ slices = []
18
+ target_size = (320, 320)
19
+
20
+ for idx in range(img.shape[2] if view == "Sagittal" else img.shape[1] if view == "Axial" else img.shape[0]):
21
+ if view == "Sagittal":
22
+ slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
23
+ elif view == "Axial":
24
+ slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
25
+ elif view == "Coronal":
26
+ slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
27
+
28
+ slice_img = np.fliplr(np.rot90(slice_img, -1))
29
+ slice_mask = np.fliplr(np.rot90(slice_mask, -1))
30
+
31
+ slice_img_resized, slice_mask_resized = resize_and_pad(slice_img, slice_mask, target_size)
32
+ slices.append((slice_img_resized, slice_mask_resized))
33
+
34
+ return slices
35
 
36
  # Function to resize and pad slices
37
  def resize_and_pad(slice_img, slice_mask, target_size):
38
+ """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
39
+ h, w = slice_img.shape
40
+ scale = min(target_size[0] / w, target_size[1] / h)
41
+ new_w, new_h = int(w * scale), int(h * scale)
42
+
43
+ resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
44
+ resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
45
 
46
+ pad_w = (target_size[0] - new_w) // 2
47
+ pad_h = (target_size[1] - new_h) // 2
48
 
49
+ padded_img = np.pad(resized_img, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
50
+ padded_mask = np.pad(resized_mask, ((pad_h, target_size[1] - new_h - pad_h), (pad_w, target_size[0] - new_w - pad_w)), mode='constant', constant_values=0)
51
 
52
+ return padded_img, padded_mask
53
 
54
  # Function to normalize image
55
  def normalize_image(slice_img):
56
+ """Normalize the image to the range [0, 255] safely."""
57
+ slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
58
+ if slice_img_min == slice_img_max: # Avoid division by zero
59
+ return np.zeros_like(slice_img, dtype=np.uint8)
60
+ normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
61
+ return normalized_img.astype(np.uint8)
62
 
63
  # Function to get fused image
64
  def get_fused_image(img, pred_mask, view, alpha=0.8):
65
+ """Fuse a grayscale image with a mask overlay and flip both horizontally and vertically."""
66
+ gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
67
+ mask_color = np.array([255, 0, 0])
68
+ colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
69
+
70
+ fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
71
+
72
+ # Flip the fused image vertically and horizontally
73
+ fused_flipped = cv2.flip(fused, -1) # Flip both vertically and horizontally
74
+
75
+ if view == 'Sagittal':
76
+ return fused_flipped
77
+ elif view == 'Coronal' or view == 'Axial':
78
+ rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
79
+ return rotated
80
 
81
  # Function for Gradio image segmentation
82
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
83
+ """Predict function using the learner and other resources."""
84
+
85
+ if view is None:
86
+ view = 'Sagittal'
87
+
88
+ img_path = Path(fileobj.name)
89
+
90
+ # Convert PosixPath to string
91
+ img_path_str = str(img_path)
92
+
93
+ save_fn = 'pred_' + img_path.stem
94
+ save_path = save_dir / save_fn
95
+ org_img, input_img, org_size = med_img_reader(img_path_str,
96
+ reorder=reorder,
97
+ resample=resample,
98
+ only_tensor=False)
99
+
100
+ mask_data = inference(learn, reorder=reorder, resample=resample,
101
+ org_img=org_img, input_img=input_img,
102
+ org_size=org_size).data
103
+
104
+ if "".join(org_img.orientation) == "LSA":
105
+ mask_data = mask_data.permute(0,1,3,2)
106
+ mask_data = torch.flip(mask_data[0], dims=[1])
107
+ mask_data = torch.Tensor(mask_data)[None]
108
+
109
+ img = org_img.data
110
+ org_img.set_data(mask_data)
111
+ org_img.save(save_path)
112
+
113
+ slices = extract_slices_from_mask(img[0], mask_data[0], view)
114
+ fused_images = [(get_fused_image(
115
+ normalize_image(slice_img), # Normalize safely
116
+ slice_mask, view))
117
+ for slice_img, slice_mask in slices]
118
+
119
+ volume = compute_binary_tumor_volume(org_img)
120
+
121
+ return fused_images, round(volume, 2)
122
 
123
  # Function to load system resources
124
  def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
125
+ """Load the model and other required resources."""
126
+ try:
127
+ learn = load_learner(models_path / learner_fn)
128
+ except Exception as e:
129
+ raise ValueError(f"Error loading the model: {str(e)}")
130
+
131
+ try:
132
+ with open(models_path / variables_fn, 'rb') as f:
133
+ variables = pickle.load(f)
134
+
135
+ if not isinstance(variables, list) or len(variables) != 3:
136
+ raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")
137
+
138
+ # Assuming the format is [shape, reorder, resample]
139
+ shape = variables[0]
140
+ reorder = variables[1]
141
+ resample = variables[2]
142
+
143
+ if not isinstance(reorder, bool):
144
+ raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")
145
+
146
+ if not isinstance(resample, list) or len(resample) != 3:
147
+ raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")
148
+
149
+ except Exception as e:
150
+ raise ValueError(f"Error loading variables: {str(e)}")
151
+
152
+ return learn, reorder, resample
153
 
154
  # Initialize the system
155
  clone_dir = Path.cwd() / 'clone_dir'
156
  URI = os.getenv('PAT_Token_URI')
157
 
158
  if not URI:
159
+ raise ValueError("PAT_Token_URI environment variable is not set")
160
 
161
  if os.path.exists(clone_dir):
162
+ pass
163
  else:
164
+ Repo.clone_from(URI, clone_dir)
165
 
166
  models_path = clone_dir
167
  save_dir = Path.cwd() / 'hs_pred'
 
175
 
176
  view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
177
 
178
+ # Ensure the example file path is correct
179
+ example_path = str(clone_dir / "sample.nii.gz")
180
+
181
  demo = gr.Interface(
182
+ fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
183
+ inputs=["file", view_selector],
184
+ outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
185
+ examples=[[example_path]],
186
+ allow_flagging='never')
187
 
188
  # Launch the Gradio interface
189
  demo.launch()