Spaces:
Runtime error
Runtime error
neverix
commited on
Commit
·
a40667a
1
Parent(s):
2b04aed
Fix bug? (?)
Browse files- data_loader.py +2 -2
data_loader.py
CHANGED
|
@@ -226,7 +226,7 @@ class FileDataset(Dataset):
|
|
| 226 |
if "labels" in sample:
|
| 227 |
# return UDP as 4chn XYZV float tensor
|
| 228 |
sample["labels"] = torch.from_numpy(
|
| 229 |
-
sample["labels"].transpose((2, 0, 1)))
|
| 230 |
assert (sample["labels"].dtype == torch.float32)
|
| 231 |
|
| 232 |
if "image_np" in sample:
|
|
@@ -270,4 +270,4 @@ class FileDataset(Dataset):
|
|
| 270 |
"character_masks": character_masks
|
| 271 |
})
|
| 272 |
# do not make fake labels in inference
|
| 273 |
-
return sample
|
|
|
|
| 226 |
if "labels" in sample:
|
| 227 |
# return UDP as 4chn XYZV float tensor
|
| 228 |
sample["labels"] = torch.from_numpy(
|
| 229 |
+
sample["labels"].transpose((2, 0, 1)).astype(np.float32))
|
| 230 |
assert (sample["labels"].dtype == torch.float32)
|
| 231 |
|
| 232 |
if "image_np" in sample:
|
|
|
|
| 270 |
"character_masks": character_masks
|
| 271 |
})
|
| 272 |
# do not make fake labels in inference
|
| 273 |
+
return sample
|