Upload processor
Browse files- processor.py +4 -2
processor.py
CHANGED
|
@@ -71,7 +71,7 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
| 71 |
BatchFeature
|
| 72 |
pixel_values : torch.Tensor
|
| 73 |
Processed image tensor (B C H W)
|
| 74 |
-
|
| 75 |
Categories indices (B)
|
| 76 |
"""
|
| 77 |
use_cats = categories is not None
|
|
@@ -86,6 +86,8 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
| 86 |
data["pixel_values"] = torch.stack([self.process_img(img) for img in images])
|
| 87 |
|
| 88 |
if use_cats:
|
| 89 |
-
data["
|
|
|
|
|
|
|
| 90 |
|
| 91 |
return BatchFeature(data=data)
|
|
|
|
| 71 |
BatchFeature
|
| 72 |
pixel_values : torch.Tensor
|
| 73 |
Processed image tensor (B C H W)
|
| 74 |
+
category_indices : torch.Tensor
|
| 75 |
Categories indices (B)
|
| 76 |
"""
|
| 77 |
use_cats = categories is not None
|
|
|
|
| 86 |
data["pixel_values"] = torch.stack([self.process_img(img) for img in images])
|
| 87 |
|
| 88 |
if use_cats:
|
| 89 |
+
data["category_indices"] = torch.stack(
|
| 90 |
+
[self.process_cat(c) for c in categories]
|
| 91 |
+
)
|
| 92 |
|
| 93 |
return BatchFeature(data=data)
|