ianpan commited on
Commit
9e9710a
·
verified ·
1 Parent(s): 9df570f

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +8 -1
modeling.py CHANGED
@@ -404,7 +404,9 @@ class TotalClassifierModel(PreTrainedModel):
404
  return slices, [dicom_files[idx] for idx in indices]
405
  return slices
406
 
407
- def preprocess(self, x: np.ndarray, mode="2d") -> np.ndarray:
 
 
408
  mode = mode.lower()
409
  if mode == "2d":
410
  x = cv2.resize(x, self.image_size)
@@ -414,6 +416,11 @@ class TotalClassifierModel(PreTrainedModel):
414
  x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
415
  if x.ndim == 3:
416
  x = x[:, :, :, np.newaxis]
 
 
 
 
 
417
  return x
418
 
419
  def crop_single_plane(
 
404
  return slices, [dicom_files[idx] for idx in indices]
405
  return slices
406
 
407
+ def preprocess(
408
+ self, x: np.ndarray, mode: str = "2d", torchify: bool = True
409
+ ) -> np.ndarray:
410
  mode = mode.lower()
411
  if mode == "2d":
412
  x = cv2.resize(x, self.image_size)
 
416
  x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
417
  if x.ndim == 3:
418
  x = x[:, :, :, np.newaxis]
419
+ if torchify:
420
+ if x.ndim == 3:
421
+ x = torch.from_numpy(x).float().permute(2, 0, 1)
422
+ elif x.ndim == 4:
423
+ x = torch.from_numpy(x).float().permute(3, 0, 1, 2)
424
  return x
425
 
426
  def crop_single_plane(