convitom committed on
Commit ·
cba2b6c
1
Parent(s): d313d81
- model/projection.py +12 -0
model/projection.py
CHANGED
|
@@ -98,6 +98,18 @@ class MLPProjection(nn.Module):
|
|
| 98 |
"""
|
| 99 |
B = patch_features.size(0)
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
# Expand query tokens to batch size
|
| 102 |
queries = self.query_tokens.expand(B, -1, -1) # (B, 32, 768)
|
| 103 |
|
|
|
|
| 98 |
"""
|
| 99 |
B = patch_features.size(0)
|
| 100 |
|
| 101 |
+
# Align input dtype with the projection's own parameter dtype.
|
| 102 |
+
# The frozen image encoder may run in bf16/fp16 (llm_dtype) while
|
| 103 |
+
# the projection's MLP/MHA weights stay fp32. Under bf16 autocast,
|
| 104 |
+
# nn.MultiheadAttention's in-projection sometimes bypasses autocast
|
| 105 |
+
# (cross-attention path), giving:
|
| 106 |
+
# RuntimeError: mat1 and mat2 must have the same dtype: BFloat16 vs Float
|
| 107 |
+
# Casting patch_features to the parameter dtype keeps the matmul self-consistent on any
|
| 108 |
+
# GPU/precision. No-op when dtypes already match (T4 fp16 fast path).
|
| 109 |
+
target_dtype = self.query_tokens.dtype
|
| 110 |
+
if patch_features.dtype != target_dtype:
|
| 111 |
+
patch_features = patch_features.to(target_dtype)
|
| 112 |
+
|
| 113 |
# Expand query tokens to batch size
|
| 114 |
queries = self.query_tokens.expand(B, -1, -1) # (B, 32, 768)
|
| 115 |
|