drankush-ai committed on
Commit
490c5a6
·
verified ·
1 Parent(s): 723b107

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -205
app.py CHANGED
@@ -1,222 +1,115 @@
1
- import gradio as gr
2
  import torch
3
  import cv2
4
  import numpy as np
5
- from pathlib import Path
6
- from huggingface_hub import snapshot_download
7
- from fastMONAI.vision_all import *
8
- from git import Repo
9
- import os
10
  from fastai.learner import load_learner
11
- from fastai.basics import load_pickle
12
  import pickle
13
 
14
# Build the per-slice gallery data for one anatomical view.
def extract_slices_from_mask(img, mask_data, view):
    """Cut a 3D [W, H, D] image/mask pair into oriented, resized 2D slices.

    The slicing axis follows the selected view (Sagittal: last axis,
    Axial: middle axis, any other value: first axis). Every slice is rotated
    and mirrored into display orientation, then letterboxed to 320x320 by
    resize_and_pad. Returns a list of (image_slice, mask_slice) pairs.
    """
    target_size = (320, 320)

    if view == "Sagittal":
        n_slices = img.shape[2]
    elif view == "Axial":
        n_slices = img.shape[1]
    else:
        n_slices = img.shape[0]

    result = []
    for idx in range(n_slices):
        if view == "Sagittal":
            raw_img, raw_mask = img[:, :, idx], mask_data[:, :, idx]
        elif view == "Axial":
            raw_img, raw_mask = img[:, idx, :], mask_data[:, idx, :]
        elif view == "Coronal":
            raw_img, raw_mask = img[idx, :, :], mask_data[idx, :, :]

        # Rotate 90 degrees clockwise, then mirror left-right, to match the
        # on-screen orientation expected by the viewer.
        raw_img = np.fliplr(np.rot90(raw_img, -1))
        raw_mask = np.fliplr(np.rot90(raw_mask, -1))

        result.append(resize_and_pad(raw_img, raw_mask, target_size))

    return result
35
-
36
# Letterbox a slice (and its mask) onto a fixed-size canvas.
def resize_and_pad(slice_img, slice_mask, target_size):
    """Resize image and mask to fit target_size, keeping aspect ratio, then
    center them on a zero-filled canvas.

    The image is resampled bilinearly; the mask uses nearest-neighbour so
    label values are never blended. Returns (padded_img, padded_mask).
    """
    h, w = slice_img.shape
    scale = min(target_size[0] / w, target_size[1] / h)
    new_w, new_h = int(w * scale), int(h * scale)

    scaled_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    scaled_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # Split the leftover space evenly; any odd pixel goes to the right/bottom.
    left = (target_size[0] - new_w) // 2
    top = (target_size[1] - new_h) // 2
    right = target_size[0] - new_w - left
    bottom = target_size[1] - new_h - top
    pad_spec = ((top, bottom), (left, right))

    padded_img = np.pad(scaled_img, pad_spec, mode='constant', constant_values=0)
    padded_mask = np.pad(scaled_mask, pad_spec, mode='constant', constant_values=0)

    return padded_img, padded_mask
53
-
54
# Map an arbitrary-range slice to 8-bit grayscale.
def normalize_image(slice_img):
    """Linearly rescale `slice_img` into [0, 255] and return it as uint8.

    A constant-valued slice (min == max) would divide by zero, so it is
    mapped to an all-zero (black) image instead.
    """
    lo, hi = slice_img.min(), slice_img.max()
    if lo == hi:
        # Flat slice: nothing to normalize.
        return np.zeros_like(slice_img, dtype=np.uint8)
    rescaled = (slice_img - lo) / (hi - lo) * 255
    return rescaled.astype(np.uint8)
62
-
63
# Overlay the predicted mask on a grayscale slice, oriented per view.
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Blend a uint8 grayscale slice with a red mask overlay.

    `alpha` weights the grayscale image; the mask contributes (1 - alpha).
    Sagittal output is flipped on both axes; Coronal/Axial are rotated 90
    degrees counter-clockwise then mirrored. Any other view value returns
    None, matching the original contract.
    """
    base = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    red = np.array([255, 0, 0])
    overlay = (pred_mask[..., None] * red).astype(np.uint8)

    fused = cv2.addWeighted(base, alpha, overlay, 1 - alpha, 0)

    if view == 'Sagittal':
        # Flip both vertically and horizontally.
        return cv2.flip(fused, -1)
    if view in ('Coronal', 'Axial'):
        return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
80
-
81
# Run the segmentation model on a preprocessed volume.
def inference(learn, reorder, resample, org_img, input_img, org_size):
    """Perform segmentation using the loaded model.

    Only `learn` and `input_img` are used here; the remaining parameters are
    kept for interface compatibility with the caller. Raises ValueError when
    `input_img` is not a torch.Tensor. Returns the first element of the
    learner's prediction (the mask).
    """
    if not isinstance(input_img, torch.Tensor):
        raise ValueError(
            f"Expected input_img to be a torch.Tensor, but got {type(input_img)}"
        )

    with torch.no_grad():
        prediction = learn.predict(input_img)

    # The learner returns a tuple whose first element is the predicted mask.
    return prediction[0]
96
 
97
# Function for Gradio image segmentation
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Predict function using the learner and other resources.

    Reads the uploaded volume, runs the segmentation model, writes the
    predicted mask under `save_dir`, and returns (fused slice images,
    volume rounded to 2 decimals) for the Gradio gallery and textbox.
    """
    # Default to the sagittal view when the radio button sent nothing.
    if view is None:
        view = 'Sagittal'

    img_path = Path(fileobj.name)

    # Convert PosixPath to string
    img_path_str = str(img_path)

    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn

    # Ensure only_tensor is set to False to get all values
    try:
        org_img, input_img, org_size = med_img_reader(img_path_str,
                                                      reorder=reorder,
                                                      resample=resample,
                                                      only_tensor=False,
                                                      dtype=torch.Tensor)
    except ValueError:
        # Handle the case where med_img_reader returns only two values
        org_img, input_img = med_img_reader(img_path_str,
                                            reorder=reorder,
                                            resample=resample,
                                            only_tensor=False,
                                            dtype=torch.Tensor)
        # Infer org_size from org_img
        org_size = org_img.shape[1:]  # Assuming org_img has a shape attribute

    # Ensure input_img is a torch.Tensor
    if not isinstance(input_img, torch.Tensor):
        raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")

    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size)

    # Re-orient the mask when the image was read in LSA orientation.
    # NOTE(review): indentation of the flip/re-batch lines was lost in this
    # copy — presumed to belong inside the LSA branch (the later revision of
    # this function groups all three there); confirm against the repository.
    if "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0,1,3,2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]

    # Keep the original voxel data for display, then store the mask on the
    # image object so it is saved (and later measured) as a NIfTI volume.
    img = org_img.data
    org_img.set_data(mask_data)
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [(get_fused_image(
        normalize_image(slice_img),  # Normalize safely
        slice_mask, view))
        for slice_img, slice_mask in slices]

    # Volume is computed from org_img, which now carries the mask data.
    volume = compute_binary_tumor_volume(org_img)

    return fused_images, round(volume, 2)
155
-
156
# Load the trained learner plus the preprocessing variables it was built with.
def load_system_resources(models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'):
    """Load the model and other required resources.

    Reads the exported fastai learner and a pickled [shape, reorder,
    resample] list, validating the latter two entries. Any failure is
    re-raised as ValueError with the original message embedded.
    Returns (learn, reorder, resample).
    """
    try:
        learn = load_learner(models_path / learner_fn)
    except Exception as e:
        raise ValueError(f"Error loading the model: {str(e)}")

    try:
        with open(models_path / variables_fn, 'rb') as f:
            variables = pickle.load(f)

        if not isinstance(variables, list) or len(variables) != 3:
            raise ValueError(f"vars.pkl does not contain the expected list format. Found: {variables}")

        # Expected layout: [shape, reorder, resample].
        shape, reorder, resample = variables

        if not isinstance(reorder, bool):
            raise ValueError(f"vars.pkl does not contain a valid 'reorder' value. Found: {reorder}")

        if not isinstance(resample, list) or len(resample) != 3:
            raise ValueError(f"vars.pkl does not contain a valid 'resample' value. Found: {resample}")

    except Exception as e:
        raise ValueError(f"Error loading variables: {str(e)}")

    return learn, reorder, resample
186
-
187
# ---------------------------------------------------------------------------
# Initialize the system (runs at import time).
# Clones the private model repository into ./clone_dir using a PAT-bearing
# URI taken from the environment, then loads the learner and its
# preprocessing variables from the checkout.
# ---------------------------------------------------------------------------
clone_dir = Path.cwd() / 'clone_dir'
URI = os.getenv('PAT_Token_URI')  # assumed to embed an access token — TODO confirm format

if not URI:
    raise ValueError("PAT_Token_URI environment variable is not set")

# Clone only once; an existing checkout is reused on restart.
if os.path.exists(clone_dir):
    pass
else:
    Repo.clone_from(URI, clone_dir)

models_path = clone_dir
save_dir = Path.cwd() / 'hs_pred'  # destination for predicted-mask files
save_dir.mkdir(parents=True, exist_ok=True)

# Load the model and other required resources
learn, reorder, resample = load_system_resources(models_path=models_path)
205
-
206
# ---------------------------------------------------------------------------
# Gradio interface setup
# ---------------------------------------------------------------------------
output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")

view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")

# Ensure the example file path is correct
example_path = str(clone_dir / "sample.nii.gz")

demo = gr.Interface(
    # The lambda closes over the module-level learner/settings so the Gradio
    # callback only needs the uploaded file and the selected view.
    fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
    inputs=["file", view_selector],
    outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
    examples=[[example_path]],
    allow_flagging='never')  # NOTE(review): deprecated in Gradio 4.x — confirm installed version

# Launch the Gradio interface
demo.launch()
 
1
+ from pathlib import Path
2
  import torch
3
  import cv2
4
  import numpy as np
5
+ from fastMONAI.vision_all import med_img_reader # keep the import
 
 
 
 
6
  from fastai.learner import load_learner
 
7
  import pickle
8
 
9
def read_image_as_tensor(path: str, reorder: bool, resample: list) -> tuple:
    """Read a medical image and always return a torch.Tensor as the second element.

    Wraps med_img_reader so callers never have to care whether it returned a
    dict or a bare tensor, or whether it included the original size.

    Args:
        path: filesystem path of the image to read.
        reorder: passed through to med_img_reader.
        resample: passed through to med_img_reader.

    Returns:
        (org_img, tensor, org_size) — `tensor` is guaranteed to be a
        torch.Tensor; `org_size` falls back to org_img.shape[1:] when the
        reader does not supply it.

    Raises:
        ValueError: if a dict result contains no torch.Tensor at all.
    """
    org_img, raw, *rest = med_img_reader(
        path,
        reorder=reorder,
        resample=resample,
        only_tensor=False,
        dtype=torch.Tensor,
    )

    # raw may be a dict or a tensor
    if isinstance(raw, dict):
        # fastMONAI convention: the actual tensor lives under the key "tensor"
        tensor = raw.get("tensor")
        if tensor is None:
            # Fallback: first torch.Tensor found anywhere in the dict. A
            # default of None avoids the bare StopIteration the original
            # `next(...)` raised when no tensor was present, which surfaces
            # as a cryptic error to the caller.
            tensor = next((v for v in raw.values() if isinstance(v, torch.Tensor)), None)
            if tensor is None:
                raise ValueError(f"med_img_reader returned a dict with no tensor values: {list(raw)}")
    else:
        tensor = raw

    # original size (used later for volume calculation)
    org_size = rest[0] if rest else org_img.shape[1:]

    return org_img, tensor, org_size
32
+
33
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
def inference(learn, reorder, resample, org_img, input_img, org_size):
    """Run the learner on a single 3-D volume.

    `input_img` must be a torch.Tensor; a 4-D (C, D, H, W) tensor gets a
    leading batch axis so the fastai learner sees N x C x D x H x W. The
    unused parameters are retained for caller compatibility. Returns the
    predicted mask tensor.
    """
    if not isinstance(input_img, torch.Tensor):
        raise TypeError(f"input_img must be a torch.Tensor, got {type(input_img)}")

    # Add the batch dimension only when it is missing.
    batched = input_img.unsqueeze(0) if input_img.dim() == 4 else input_img

    with torch.no_grad():
        prediction = learn.predict(batched)

    # fastai may hand back (mask, ...) or the bare mask tensor.
    if isinstance(prediction, (list, tuple)):
        return prediction[0]
    return prediction
49
 
 
 
 
 
50
 
51
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Main Gradio callback – reads the file, runs inference, returns visualisation.

    Args:
        fileobj: Gradio file object; only its `.name` (path) is used.
        learn: loaded fastai learner.
        reorder, resample: preprocessing settings forwarded to the reader.
        save_dir: Path where the predicted mask NIfTI is written.
        view: "Sagittal" | "Axial" | "Coronal"; falsy values default to Sagittal.

    Returns:
        (fused_images, volume): RGB slice overlays for the gallery, and the
        segmented volume rounded to two decimals.
    """
    view = view or "Sagittal"

    img_path = Path(fileobj.name)
    # Strip the medical-image extension(s) before re-appending ".nii.gz".
    # Path.stem alone keeps the inner ".nii" of "x.nii.gz", which would have
    # produced "pred_x.nii.nii.gz".
    base_name = img_path.name
    for suffix in (".nii.gz", ".nii"):
        if base_name.endswith(suffix):
            base_name = base_name[: -len(suffix)]
            break
    save_fn = f"pred_{base_name}.nii.gz"
    save_path = save_dir / save_fn

    # 1. Read the image and guarantee a torch.Tensor for the model.
    org_img, input_tensor, org_size = read_image_as_tensor(
        str(img_path),
        reorder=reorder,
        resample=resample,
    )

    # 2. Run the model.
    mask_data = inference(
        learn,
        reorder=reorder,
        resample=resample,
        org_img=org_img,
        input_img=input_tensor,
        org_size=org_size,
    )

    # 3. Post-process orientation (keeps the original logic).
    if "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0, 1, 3, 2)       # swap the last two axes
        mask_data = torch.flip(mask_data[0], dims=[1])  # drop batch, flip dim 1
        mask_data = mask_data.unsqueeze(0)              # restore the batch axis

    # 4. Save the mask as a NIfTI file.
    img = org_img.data           # keep the original voxels for display
    org_img.set_data(mask_data)  # the image object now carries the mask
    org_img.save(save_path)      # writes a .nii.gz file

    # 5. Build the gallery of fused slices.
    slices = extract_slices_from_mask(img[0].cpu().numpy(), mask_data[0].cpu().numpy(), view)

    fused_images = [
        get_fused_image(
            normalize_image(slice_img),  # safe 0-255 uint8 conversion
            slice_mask,
            view,
        )
        for slice_img, slice_mask in slices
    ]

    # 6. Compute the volume from the mask now stored inside org_img.
    volume = compute_binary_tumor_volume(org_img)

    return fused_images, round(volume, 2)