AI Agent committed on
Commit
d8802bd
·
1 Parent(s): 273261c

Add roi_align + layer_norm interceptors to fix HalfTensor/FloatTensor mismatch in geometry encoder

Browse files
Files changed (1) hide show
  1. app.py +23 -0
app.py CHANGED
@@ -46,6 +46,29 @@ if torch.cuda.is_available():
46
  return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)
47
  F.conv2d = patched_conv2d
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
50
  # (HuggingFace Spaces can use the hf_hub_download mechanism)
51
  from huggingface_hub import hf_hub_download
 
46
  return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)
47
  F.conv2d = patched_conv2d
48
 
49
# 3. Patch torchvision.ops.roi_align. Per the original commit note, Meta's
# geometry_encoders.py calls boxes_xyxy.float(), which yields float32 boxes
# while img_feats is float16, so the two dtypes disagree inside roi_align.
try:
    import torchvision.ops
except ImportError:
    # Best-effort: without torchvision there is nothing to patch.
    pass
else:
    import functools

    orig_roi_align = torchvision.ops.roi_align

    @functools.wraps(orig_roi_align)
    def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
        """Call the original roi_align after aligning ``boxes`` to ``input``'s dtype.

        Args:
            input: Feature tensor handed to roi_align.
            boxes: Either one tensor of boxes or a list/tuple of per-image
                box tensors.
            output_size, spatial_scale, sampling_ratio, aligned: Forwarded
                unchanged to the original implementation.

        Returns:
            Whatever the original ``torchvision.ops.roi_align`` returns.
        """
        # Only cast when the features are floating point; casting box
        # coordinates to an integer dtype would silently truncate them.
        if input.is_floating_point():
            if isinstance(boxes, torch.Tensor):
                if boxes.dtype != input.dtype:
                    boxes = boxes.to(input.dtype)
            elif isinstance(boxes, (list, tuple)):
                # Non-tensor entries are passed through untouched.
                boxes = [
                    b.to(input.dtype)
                    if isinstance(b, torch.Tensor) and b.dtype != input.dtype
                    else b
                    for b in boxes
                ]
        return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)

    # NOTE(review): this rebinds only the ``torchvision.ops.roi_align``
    # attribute. Code that already did ``from torchvision.ops import
    # roi_align`` keeps its reference to the unpatched function — confirm
    # this runs before the model code is imported.
    torchvision.ops.roi_align = patched_roi_align
63
+
64
# 4. Patch F.layer_norm — a common ViT dtype mismatch point.
# (Only layer_norm is intercepted here; group_norm is left untouched.)
import functools

orig_layer_norm = F.layer_norm


@functools.wraps(orig_layer_norm)
def patched_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Run the original layer_norm after matching ``input`` to the affine dtype.

    Args:
        input: Tensor to normalize.
        normalized_shape: Forwarded unchanged to the original layer_norm.
        weight: Optional affine scale; its dtype is the reference when given.
        bias: Optional affine shift; used as the dtype reference only when
            ``weight`` is None.
        eps: Forwarded unchanged to the original layer_norm.

    Returns:
        The result of the original ``F.layer_norm``. When a cast happened the
        output follows the parameter dtype, not the caller's input dtype.
    """
    # Prefer weight as the reference dtype; fall back to bias so a bias-only
    # call with a mismatched dtype is aligned as well.
    reference = weight if weight is not None else bias
    if (
        reference is not None
        and input.is_floating_point()
        and input.dtype != reference.dtype
    ):
        input = input.to(reference.dtype)
    return orig_layer_norm(input, normalized_shape, weight, bias, eps)


F.layer_norm = patched_layer_norm
71
+
72
  # ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
73
  # (HuggingFace Spaces can use the hf_hub_download mechanism)
74
  from huggingface_hub import hf_hub_download