ianpan commited on
Commit
b1abe51
·
verified ·
1 Parent(s): 07c8d8d

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +9 -1
modeling.py CHANGED
@@ -405,8 +405,14 @@ class TotalClassifierModel(PreTrainedModel):
405
  return slices
406
 
407
  def preprocess(
408
- self, x: np.ndarray, mode: str = "2d", torchify: bool = True
 
 
 
 
409
  ) -> np.ndarray:
 
 
410
  mode = mode.lower()
411
  if mode == "2d":
412
  x = cv2.resize(x, self.image_size)
@@ -421,6 +427,8 @@ class TotalClassifierModel(PreTrainedModel):
421
  x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
422
  elif x.ndim == 4:
423
  x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
 
 
424
  return x
425
 
426
  def crop_single_plane(
 
405
  return slices
406
 
407
  def preprocess(
408
+ self,
409
+ x: np.ndarray,
410
+ mode: str = "2d",
411
+ torchify: bool = True,
412
+ device: str | torch.device | None = None,
413
  ) -> np.ndarray:
414
+ if device is not None:
415
+ assert torchify, "`torchify` must be `True` if specifying `device`"
416
  mode = mode.lower()
417
  if mode == "2d":
418
  x = cv2.resize(x, self.image_size)
 
427
  x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
428
  elif x.ndim == 4:
429
  x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
430
+ if device is not None:
431
+ x = x.to(device)
432
  return x
433
 
434
  def crop_single_plane(