mountainsma commited on
Commit
bcfaacf
·
1 Parent(s): 3b2121f

Normalize the embedding

Browse files
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -89,8 +89,8 @@ class EndpointHandler:
89
  image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
90
  text_tensor = self._tokenize_text(text)
91
 
92
- image_features = self.model.encode_image(image_tensor, normalize=False)
93
- text_features = self.model.encode_text(text_tensor, normalize=False)
94
 
95
  response = {"image_embedding": image_features[0].cpu().tolist()}
96
  if isinstance(text, list):
@@ -100,11 +100,11 @@ class EndpointHandler:
100
  return response
101
  elif image is not None:
102
  image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
103
- image_features = self.model.encode_image(image_tensor, normalize=False)
104
  return {"image_embedding": image_features[0].cpu().tolist()}
105
  elif text is not None:
106
  text_tensor = self._tokenize_text(text)
107
- text_features = self.model.encode_text(text_tensor, normalize=False)
108
  if isinstance(text, list):
109
  return {"text_embeddings": text_features.cpu().tolist()}
110
  return {"text_embedding": text_features[0].cpu().tolist()}
 
89
  image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
90
  text_tensor = self._tokenize_text(text)
91
 
92
+ image_features = self.model.encode_image(image_tensor, normalize=True)
93
+ text_features = self.model.encode_text(text_tensor, normalize=True)
94
 
95
  response = {"image_embedding": image_features[0].cpu().tolist()}
96
  if isinstance(text, list):
 
100
  return response
101
  elif image is not None:
102
  image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
103
+ image_features = self.model.encode_image(image_tensor, normalize=True)
104
  return {"image_embedding": image_features[0].cpu().tolist()}
105
  elif text is not None:
106
  text_tensor = self._tokenize_text(text)
107
+ text_features = self.model.encode_text(text_tensor, normalize=True)
108
  if isinstance(text, list):
109
  return {"text_embeddings": text_features.cpu().tolist()}
110
  return {"text_embedding": text_features[0].cpu().tolist()}