Upload model
Browse files- modeling.py +9 -1
modeling.py
CHANGED
|
@@ -405,8 +405,14 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 405 |
return slices
|
| 406 |
|
| 407 |
def preprocess(
|
| 408 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
) -> np.ndarray:
|
|
|
|
|
|
|
| 410 |
mode = mode.lower()
|
| 411 |
if mode == "2d":
|
| 412 |
x = cv2.resize(x, self.image_size)
|
|
@@ -421,6 +427,8 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 421 |
x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
|
| 422 |
elif x.ndim == 4:
|
| 423 |
x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
|
|
|
|
|
|
|
| 424 |
return x
|
| 425 |
|
| 426 |
def crop_single_plane(
|
|
|
|
| 405 |
return slices
|
| 406 |
|
| 407 |
def preprocess(
|
| 408 |
+
self,
|
| 409 |
+
x: np.ndarray,
|
| 410 |
+
mode: str = "2d",
|
| 411 |
+
torchify: bool = True,
|
| 412 |
+
device: str | torch.device | None = None,
|
| 413 |
) -> np.ndarray:
|
| 414 |
+
if device is not None:
|
| 415 |
+
assert torchify, "`torchify` must be `True` if specifying `device`"
|
| 416 |
mode = mode.lower()
|
| 417 |
if mode == "2d":
|
| 418 |
x = cv2.resize(x, self.image_size)
|
|
|
|
| 427 |
x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
|
| 428 |
elif x.ndim == 4:
|
| 429 |
x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
|
| 430 |
+
if device is not None:
|
| 431 |
+
x = x.to(device)
|
| 432 |
return x
|
| 433 |
|
| 434 |
def crop_single_plane(
|