Spaces:
Sleeping
Sleeping
AI Agent commited on
Commit ·
77f13e1
1
Parent(s): d8802bd
Fix roi_align: handle tuple from .unbind() not just list
Browse files
app.py
CHANGED
|
@@ -52,9 +52,11 @@ if torch.cuda.is_available():
|
|
| 52 |
import torchvision.ops
|
| 53 |
orig_roi_align = torchvision.ops.roi_align
|
| 54 |
def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
boxes = [b.to(input.dtype) if isinstance(b, torch.Tensor) and b.dtype != input.dtype else b for b in boxes]
|
| 59 |
return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)
|
| 60 |
torchvision.ops.roi_align = patched_roi_align
|
|
|
|
| 52 |
import torchvision.ops
|
| 53 |
orig_roi_align = torchvision.ops.roi_align
|
| 54 |
def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
|
| 55 |
+
# Handle Tensor, list, or tuple (Meta uses .unbind() which returns a tuple!)
|
| 56 |
+
if isinstance(boxes, torch.Tensor):
|
| 57 |
+
if input.is_floating_point() and boxes.dtype != input.dtype:
|
| 58 |
+
boxes = boxes.to(input.dtype)
|
| 59 |
+
elif isinstance(boxes, (list, tuple)):
|
| 60 |
boxes = [b.to(input.dtype) if isinstance(b, torch.Tensor) and b.dtype != input.dtype else b for b in boxes]
|
| 61 |
return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)
|
| 62 |
torchvision.ops.roi_align = patched_roi_align
|