Spaces:
Running
Running
Added option to output probability map from CLI
Browse files- lungtumormask/__main__.py +4 -0
- lungtumormask/dataprocessing.py +16 -6
- lungtumormask/mask.py +1 -1
lungtumormask/__main__.py
CHANGED
|
@@ -27,6 +27,10 @@ def main():
|
|
| 27 |
if args.cpu:
|
| 28 |
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# import method here to enable faster testing
|
| 31 |
from lungtumormask import mask
|
| 32 |
|
|
|
|
| 27 |
if args.cpu:
|
| 28 |
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
| 29 |
|
| 30 |
+
# check if chosen threshold is in accepted range
|
| 31 |
+
if not (((args.threshold >= 0.0) and (args.threshold <= 1.0)) or (args.threshold == -1)):
|
| 32 |
+
raise ValueError("Chosen threshold must be -1 or in range [0.0, 1.0], but was:", args.threshold)
|
| 33 |
+
|
| 34 |
# import method here to enable faster testing
|
| 35 |
from lungtumormask import mask
|
| 36 |
|
lungtumormask/dataprocessing.py
CHANGED
|
@@ -226,20 +226,30 @@ def post_process(left, right, preprocess_dump, lung_filter, threshold, radius):
|
|
| 226 |
left = voxel_space(left, preprocess_dump['left_extremes'])
|
| 227 |
right = voxel_space(right, preprocess_dump['right_extremes'])
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
|
|
|
| 231 |
|
| 232 |
left = stitch(preprocess_dump['org_shape'], left, preprocess_dump['left_extremes'])
|
| 233 |
right = stitch(preprocess_dump['org_shape'], right, preprocess_dump['right_extremes'])
|
| 234 |
|
| 235 |
-
|
|
|
|
|
|
|
| 236 |
|
| 237 |
# filter tumor predictions outside the predicted lung area
|
| 238 |
if lung_filter:
|
| 239 |
stitched[preprocess_dump['lungmask'] == 0] = 0
|
| 240 |
|
| 241 |
-
# final post-processing - fix fragmentation
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
return stitched
|
|
|
|
| 226 |
left = voxel_space(left, preprocess_dump['left_extremes'])
|
| 227 |
right = voxel_space(right, preprocess_dump['right_extremes'])
|
| 228 |
|
| 229 |
+
if threshold != -1:
|
| 230 |
+
left = (left >= threshold).astype(int)
|
| 231 |
+
right = (right >= threshold).astype(int)
|
| 232 |
|
| 233 |
left = stitch(preprocess_dump['org_shape'], left, preprocess_dump['left_extremes'])
|
| 234 |
right = stitch(preprocess_dump['org_shape'], right, preprocess_dump['right_extremes'])
|
| 235 |
|
| 236 |
+
# fuse left and right lung preds together (but keep probability maps if available)
|
| 237 |
+
stitched = np.maximum(left, right)
|
| 238 |
+
del left, right
|
| 239 |
|
| 240 |
# filter tumor predictions outside the predicted lung area
|
| 241 |
if lung_filter:
|
| 242 |
stitched[preprocess_dump['lungmask'] == 0] = 0
|
| 243 |
|
| 244 |
+
# final post-processing - fix fragmentation (only relevant for binary volume)
|
| 245 |
+
if threshold != -1:
|
| 246 |
+
for i in range(stitched.shape[-1]):
|
| 247 |
+
stitched[..., i] = binary_closing(stitched[..., i], footprint=disk(radius=radius))
|
| 248 |
+
|
| 249 |
+
# for threshold != -1, set result to uint8 dtype, else float32 (for probability map)
|
| 250 |
+
if threshold == -1:
|
| 251 |
+
stitched = stitched.astype("float32")
|
| 252 |
+
else:
|
| 253 |
+
stitched = stitched.astype("uint8")
|
| 254 |
|
| 255 |
return stitched
|
lungtumormask/mask.py
CHANGED
|
@@ -27,7 +27,7 @@ def mask(image_path, save_path, lung_filter, threshold, radius, batch_size):
|
|
| 27 |
right = model(preprocess_dump['right_lung']).squeeze(0).squeeze(0).detach().numpy()
|
| 28 |
|
| 29 |
print("Post-processing image...")
|
| 30 |
-
inferred = post_process(left, right, preprocess_dump, lung_filter, threshold, radius)
|
| 31 |
|
| 32 |
print(f"Storing segmentation at {save_path}")
|
| 33 |
nimage = nibabel.Nifti1Image(inferred, preprocess_dump['org_affine'])
|
|
|
|
| 27 |
right = model(preprocess_dump['right_lung']).squeeze(0).squeeze(0).detach().numpy()
|
| 28 |
|
| 29 |
print("Post-processing image...")
|
| 30 |
+
inferred = post_process(left, right, preprocess_dump, lung_filter, threshold, radius)
|
| 31 |
|
| 32 |
print(f"Storing segmentation at {save_path}")
|
| 33 |
nimage = nibabel.Nifti1Image(inferred, preprocess_dump['org_affine'])
|