Spaces:
Sleeping
Sleeping
AI Agent committed on
Commit ·
d8802bd
1
Parent(s): 273261c
Add roi_align + layer_norm interceptors to fix HalfTensor/FloatTensor mismatch in geometry encoder
Browse files
app.py
CHANGED
|
@@ -46,6 +46,29 @@ if torch.cuda.is_available():
|
|
| 46 |
return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)
|
| 47 |
F.conv2d = patched_conv2d
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
|
| 50 |
# (HuggingFace Spaces can use the hf_hub_download mechanism)
|
| 51 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 46 |
return orig_conv2d(input, weight, bias, stride, padding, dilation, groups)
|
| 47 |
F.conv2d = patched_conv2d
|
| 48 |
|
| 49 |
+
# 3. Patch torchvision.ops.roi_align — Meta's geometry_encoders.py
# calls boxes_xyxy.float() which creates float32 while img_feats is float16.
# NOTE(review): the diff hunk header suggests this code sits inside
# `if torch.cuda.is_available():` — re-indent to match when applying.
# NOTE(review): only the `torchvision.ops.roi_align` attribute is rebound; code
# that ran `from torchvision.ops import roi_align` before this point presumably
# keeps its reference to the unpatched function — confirm import order.
import functools


def match_boxes_dtype(input, boxes):
    """Return *boxes* cast to the dtype of *input* when *input* is floating point.

    *boxes* may be a single tensor or a list/tuple of tensors. Entries that are
    not tensors, or that already have the right dtype, are passed through
    unchanged. A list/tuple is returned as a new list. When *input* is not a
    floating-point tensor nothing is cast, so box coordinates are never
    truncated to an integer dtype.
    """
    if not input.is_floating_point():
        return boxes
    target = input.dtype
    if isinstance(boxes, torch.Tensor):
        return boxes if boxes.dtype == target else boxes.to(target)
    if isinstance(boxes, (list, tuple)):
        return [
            b.to(target) if isinstance(b, torch.Tensor) and b.dtype != target else b
            for b in boxes
        ]
    return boxes


try:
    import torchvision.ops
except ImportError:
    # torchvision is optional here; without it there is nothing to patch.
    pass
else:
    orig_roi_align = torchvision.ops.roi_align

    # functools.wraps keeps the original name/docstring on the replacement.
    @functools.wraps(orig_roi_align)
    def patched_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
        # Align the boxes' dtype with the feature map before delegating.
        boxes = match_boxes_dtype(input, boxes)
        return orig_roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned)

    torchvision.ops.roi_align = patched_roi_align
|
| 63 |
+
|
| 64 |
+
# 4. Patch F.layer_norm — a common ViT dtype mismatch point.
# Only layer_norm is intercepted here; group_norm is left untouched.
# NOTE(review): the diff hunk header suggests this code sits inside
# `if torch.cuda.is_available():` — re-indent to match when applying.
import functools

orig_layer_norm = F.layer_norm


# Replacement for F.layer_norm: when a weight is supplied and the floating-point
# input has a different dtype, the input is cast to the weight's dtype before
# the original function runs. The result therefore follows the weight's dtype
# in that case. Calls without a weight, or with a non-floating-point input, are
# forwarded unchanged.
# functools.wraps keeps the original name/docstring on the replacement.
@functools.wraps(orig_layer_norm)
def patched_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    if weight is not None and input.is_floating_point() and input.dtype != weight.dtype:
        input = input.to(weight.dtype)
    return orig_layer_norm(input, normalized_shape, weight, bias, eps)


F.layer_norm = patched_layer_norm
|
| 71 |
+
|
| 72 |
# ── Ensure SAM 3 Checkpoint is downloaded ────────────────────────
|
| 73 |
# (HuggingFace Spaces can use the hf_hub_download mechanism)
|
| 74 |
from huggingface_hub import hf_hub_download
|