ianpan commited on
Commit
07c8d8d
·
verified ·
1 Parent(s): 9e9710a

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +2 -2
modeling.py CHANGED
@@ -418,9 +418,9 @@ class TotalClassifierModel(PreTrainedModel):
418
  x = x[:, :, :, np.newaxis]
419
  if torchify:
420
  if x.ndim == 3:
421
- x = torch.from_numpy(x).float().permute(2, 0, 1)
422
  elif x.ndim == 4:
423
- x = torch.from_numpy(x).float().permute(3, 0, 1, 2)
424
  return x
425
 
426
  def crop_single_plane(
 
418
  x = x[:, :, :, np.newaxis]
419
  if torchify:
420
  if x.ndim == 3:
421
+ x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
422
  elif x.ndim == 4:
423
+ x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
424
  return x
425
 
426
  def crop_single_plane(