Upload model
Browse files- modeling.py +2 -2
modeling.py
CHANGED
|
@@ -418,9 +418,9 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 418 |
x = x[:, :, :, np.newaxis]
|
| 419 |
if torchify:
|
| 420 |
if x.ndim == 3:
|
| 421 |
-
x = torch.from_numpy(x).float()
|
| 422 |
elif x.ndim == 4:
|
| 423 |
-
x = torch.from_numpy(x).float()
|
| 424 |
return x
|
| 425 |
|
| 426 |
def crop_single_plane(
|
|
|
|
| 418 |
x = x[:, :, :, np.newaxis]
|
| 419 |
if torchify:
|
| 420 |
if x.ndim == 3:
|
| 421 |
+
x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
|
| 422 |
elif x.ndim == 4:
|
| 423 |
+
x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
|
| 424 |
return x
|
| 425 |
|
| 426 |
def crop_single_plane(
|