AI Agent commited on
Commit
77f13e1
·
1 Parent(s): d8802bd

Fix roi_align: handle tuple from .unbind() not just list

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -52,9 +52,11 @@ if torch.cuda.is_available():
52
  import torchvision.ops
53
  orig_roi_align = torchvision.ops.roi_align
54
  def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
55
- if isinstance(boxes, torch.Tensor) and input.is_floating_point() and boxes.dtype != input.dtype:
56
- boxes = boxes.to(input.dtype)
57
- elif isinstance(boxes, list):
 
 
58
  boxes = [b.to(input.dtype) if isinstance(b, torch.Tensor) and b.dtype != input.dtype else b for b in boxes]
59
  return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)
60
  torchvision.ops.roi_align = patched_roi_align
 
52
  import torchvision.ops
53
  orig_roi_align = torchvision.ops.roi_align
54
  def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
55
+ # Handle Tensor, list, or tuple (Meta uses .unbind() which returns a tuple!)
56
+ if isinstance(boxes, torch.Tensor):
57
+ if input.is_floating_point() and boxes.dtype != input.dtype:
58
+ boxes = boxes.to(input.dtype)
59
+ elif isinstance(boxes, (list, tuple)):
60
  boxes = [b.to(input.dtype) if isinstance(b, torch.Tensor) and b.dtype != input.dtype else b for b in boxes]
61
  return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)
62
  torchvision.ops.roi_align = patched_roi_align