Upload model
Browse files — modeling.py (+8 / −1)
modeling.py
CHANGED
|
@@ -349,6 +349,7 @@ class CTCropModel(PreTrainedModel):
|
|
| 349 |
mode: str,
|
| 350 |
device: str | None = None,
|
| 351 |
raw_hu: bool = False,
|
|
|
|
| 352 |
add_buffer: float | tuple[float, float] | None = None,
|
| 353 |
) -> np.ndarray:
|
| 354 |
assert mode in ["2d", "3d"]
|
|
@@ -375,7 +376,8 @@ class CTCropModel(PreTrainedModel):
|
|
| 375 |
coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
|
| 376 |
# get the union of all slice-wise bounding boxes
|
| 377 |
# exclude empty boxes
|
| 378 |
-
|
|
|
|
| 379 |
# if all empty, return original input
|
| 380 |
if coords.shape[0] == 0:
|
| 381 |
print("no foreground detected, returning original input ...")
|
|
@@ -386,4 +388,9 @@ class CTCropModel(PreTrainedModel):
|
|
| 386 |
x1, y1 = x1.min().item(), y1.min().item()
|
| 387 |
x2, y2 = x2.max().item(), y2.max().item()
|
| 388 |
cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
return cropped
|
|
|
|
| 349 |
mode: str,
|
| 350 |
device: str | None = None,
|
| 351 |
raw_hu: bool = False,
|
| 352 |
+
remove_empty_slices: bool = False,
|
| 353 |
add_buffer: float | tuple[float, float] | None = None,
|
| 354 |
) -> np.ndarray:
|
| 355 |
assert mode in ["2d", "3d"]
|
|
|
|
| 376 |
coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
|
| 377 |
# get the union of all slice-wise bounding boxes
|
| 378 |
# exclude empty boxes
|
| 379 |
+
empty = coords.sum(dim=1) == 0
|
| 380 |
+
coords = coords[~empty]
|
| 381 |
# if all empty, return original input
|
| 382 |
if coords.shape[0] == 0:
|
| 383 |
print("no foreground detected, returning original input ...")
|
|
|
|
| 388 |
x1, y1 = x1.min().item(), y1.min().item()
|
| 389 |
x2, y2 = x2.max().item(), y2.max().item()
|
| 390 |
cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
|
| 391 |
+
if remove_empty_slices and empty.sum() > 0:
|
| 392 |
+
empty_indices = list(torch.where(empty)[0].cpu().numpy())
|
| 393 |
+
print(f"removing {empty.sum()} empty slices ...")
|
| 394 |
+
cropped = cropped[~empty]
|
| 395 |
+
return cropped, empty_indices
|
| 396 |
return cropped
|