Napron commited on
Commit
736b2a1
·
verified ·
1 Parent(s): ae6bf0f

modified nomic encode images

Browse files
Files changed (1) hide show
  1. nomic_fewshot.py +9 -17
nomic_fewshot.py CHANGED
@@ -228,34 +228,26 @@ class NomicVisionEncoderONNX:
228
 
229
  def encode_images(self, images: list[Image.Image]) -> np.ndarray:
230
  rgb = [img.convert("RGB") for img in images]
231
- processed = self.processor(images=rgb, return_tensors="np")
232
-
233
  if "pixel_values" not in processed:
234
  raise RuntimeError(f"Processor did not return pixel_values. Keys={list(processed.keys())}")
235
-
236
- pixel_values = processed["pixel_values"]
237
- pixel_values = (
238
- pixel_values.numpy().astype(np.float32)
239
- if hasattr(pixel_values, "numpy")
240
- else np.asarray(pixel_values, dtype=np.float32)
241
- )
242
-
243
- feeds = {}
244
  if self._pixel_name is None:
245
  raise RuntimeError(f"Could not find pixel input in ONNX inputs: {self.input_names}")
246
- feeds[self._pixel_name] = pixel_values
247
-
248
- outputs = self.session.run(self.output_names, feeds)
249
  main_out = _pick_output(outputs, self.output_names, kind="vision")
250
-
251
- # Current PyTorch behavior: CLS token from last_hidden_state
252
  if main_out.ndim == 3:
253
  embs = main_out[:, 0, :]
254
  elif main_out.ndim == 2:
255
  embs = main_out
256
  else:
257
  raise RuntimeError(f"Unexpected vision output rank: {main_out.ndim}")
258
-
259
  return _l2_normalize(embs, axis=1)
260
 
261
 
 
228
 
229
  def encode_images(self, images: list[Image.Image]) -> np.ndarray:
230
  rgb = [img.convert("RGB") for img in images]
231
+ processed = self.processor(images=rgb, return_tensors="pt")
232
+
233
  if "pixel_values" not in processed:
234
  raise RuntimeError(f"Processor did not return pixel_values. Keys={list(processed.keys())}")
235
+
236
+ pixel_values = processed["pixel_values"].detach().cpu().numpy().astype(np.float32)
237
+
 
 
 
 
 
 
238
  if self._pixel_name is None:
239
  raise RuntimeError(f"Could not find pixel input in ONNX inputs: {self.input_names}")
240
+
241
+ outputs = self.session.run(self.output_names, {self._pixel_name: pixel_values})
 
242
  main_out = _pick_output(outputs, self.output_names, kind="vision")
243
+
 
244
  if main_out.ndim == 3:
245
  embs = main_out[:, 0, :]
246
  elif main_out.ndim == 2:
247
  embs = main_out
248
  else:
249
  raise RuntimeError(f"Unexpected vision output rank: {main_out.ndim}")
250
+
251
  return _l2_normalize(embs, axis=1)
252
 
253