Yeroyan commited on
Commit
51cf2c0
·
verified ·
1 Parent(s): de13ab3

Fix BFloat16 numpy conversion

Browse files
visual_rag/retrieval/multi_vector.py CHANGED
@@ -157,9 +157,10 @@ class MultiVectorRetriever:
157
  ) -> List[Dict[str, Any]]:
158
  q = self.embedder.embed_query(query)
159
  if isinstance(q, torch.Tensor):
160
- query_embedding = q.detach().cpu().numpy()
 
161
  else:
162
- query_embedding = np.asarray(q)
163
 
164
  return self.search_embedded(
165
  query_embedding=query_embedding,
 
157
  ) -> List[Dict[str, Any]]:
158
  q = self.embedder.embed_query(query)
159
  if isinstance(q, torch.Tensor):
160
+ # .float() converts BFloat16 to Float32 (numpy doesn't support BFloat16)
161
+ query_embedding = q.detach().cpu().float().numpy()
162
  else:
163
+ query_embedding = np.asarray(q, dtype=np.float32)
164
 
165
  return self.search_embedded(
166
  query_embedding=query_embedding,
visual_rag/retrieval/single_stage.py CHANGED
@@ -129,5 +129,5 @@ class SingleStageRetriever:
129
  if isinstance(embedding, torch.Tensor):
130
  if embedding.dtype == torch.bfloat16:
131
  return embedding.cpu().float().numpy()
132
- return embedding.cpu().numpy()
133
  return np.array(embedding, dtype=np.float32)
 
129
  if isinstance(embedding, torch.Tensor):
130
  if embedding.dtype == torch.bfloat16:
131
  return embedding.cpu().float().numpy()
132
+ return embedding.cpu().float().numpy() # .float() for BFloat16 compatibility
133
  return np.array(embedding, dtype=np.float32)
visual_rag/retrieval/three_stage.py CHANGED
@@ -51,7 +51,7 @@ class ThreeStageRetriever:
51
  if isinstance(embedding, torch.Tensor):
52
  if embedding.dtype == torch.bfloat16:
53
  return embedding.cpu().float().numpy()
54
- return embedding.cpu().numpy()
55
  return np.array(embedding, dtype=np.float32)
56
 
57
  def _infer_vector_is_multivector(self, vector_name: str) -> bool:
 
51
  if isinstance(embedding, torch.Tensor):
52
  if embedding.dtype == torch.bfloat16:
53
  return embedding.cpu().float().numpy()
54
+ return embedding.cpu().float().numpy() # .float() for BFloat16 compatibility
55
  return np.array(embedding, dtype=np.float32)
56
 
57
  def _infer_vector_is_multivector(self, vector_name: str) -> bool:
visual_rag/retrieval/two_stage.py CHANGED
@@ -418,7 +418,7 @@ class TwoStageRetriever:
418
  if isinstance(embedding, torch.Tensor):
419
  if embedding.dtype == torch.bfloat16:
420
  return embedding.cpu().float().numpy()
421
- return embedding.cpu().numpy()
422
  return np.array(embedding, dtype=np.float32)
423
 
424
  def build_filter(
 
418
  if isinstance(embedding, torch.Tensor):
419
  if embedding.dtype == torch.bfloat16:
420
  return embedding.cpu().float().numpy()
421
+ return embedding.cpu().float().numpy() # .float() for BFloat16 compatibility
422
  return np.array(embedding, dtype=np.float32)
423
 
424
  def build_filter(