convitom committed on
Commit
cba2b6c
·
1 Parent(s): d313d81
Files changed (1) hide show
  1. model/projection.py +12 -0
model/projection.py CHANGED
@@ -98,6 +98,18 @@ class MLPProjection(nn.Module):
98
  """
99
  B = patch_features.size(0)
100
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Expand query tokens to batch size
102
  queries = self.query_tokens.expand(B, -1, -1) # (B, 32, 768)
103
 
 
98
  """
99
  B = patch_features.size(0)
100
 
101
+ # Align input dtype with the projection's own parameter dtype.
102
+ # The frozen image encoder may run in bf16/fp16 (llm_dtype) while
103
+ # the projection's MLP/MHA weights stay fp32. Under bf16 autocast,
104
+ # nn.MultiheadAttention's in-projection sometimes bypasses autocast
105
+ # (cross-attention path), giving:
106
+ # RuntimeError: mat1 and mat2 must have the same dtype: BFloat16 vs Float
107
+ # Casting patch_features to the parameter dtype keeps the matmul self-consistent on any
108
+ # GPU/precision. No-op when dtypes already match (T4 fp16 fast path).
109
+ target_dtype = self.query_tokens.dtype
110
+ if patch_features.dtype != target_dtype:
111
+ patch_features = patch_features.to(target_dtype)
112
+
113
  # Expand query tokens to batch size
114
  queries = self.query_tokens.expand(B, -1, -1) # (B, 32, 768)
115