postprocess fix
Browse files- bikefusion/data_utils.py +7 -2
bikefusion/data_utils.py
CHANGED
|
@@ -29,7 +29,7 @@ def pad_to_square(image, target_size=(128, 128)):
|
|
| 29 |
|
| 30 |
def preprocess(images):
|
| 31 |
# Convert arrays to tensors
|
| 32 |
-
images = torch.
|
| 33 |
|
| 34 |
# Apply padding to each image in the dataset
|
| 35 |
images = torch.stack([pad_to_square(img) for img in images])
|
|
@@ -52,7 +52,12 @@ def un_pad(image, target_size=(80, 128)):
|
|
| 52 |
|
| 53 |
def postprocess(images):
|
| 54 |
# Convert tensors to arrays
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# Unpad each image in the dataset
|
| 58 |
images = np.stack([un_pad(img) for img in images])
|
|
|
|
| 29 |
|
| 30 |
def preprocess(images):
|
| 31 |
# Convert arrays to tensors
|
| 32 |
+
images = torch.tensor(images).float()
|
| 33 |
|
| 34 |
# Apply padding to each image in the dataset
|
| 35 |
images = torch.stack([pad_to_square(img) for img in images])
|
|
|
|
| 52 |
|
| 53 |
def postprocess(images):
|
| 54 |
# Convert tensors to arrays
|
| 55 |
+
if isinstance(images, torch.Tensor):
|
| 56 |
+
images = images.detach().cpu().numpy()
|
| 57 |
+
elif isinstance(images, np.ndarray):
|
| 58 |
+
pass
|
| 59 |
+
else:
|
| 60 |
+
raise ValueError("images must be either a torch.Tensor or a np.ndarray")
|
| 61 |
|
| 62 |
# Unpad each image in the dataset
|
| 63 |
images = np.stack([un_pad(img) for img in images])
|