tcm03
commited on
Commit
·
aef4077
1
Parent(s):
1060235
Enable text-only embedding request
Browse files- handler.py +21 -4
handler.py
CHANGED
|
@@ -47,6 +47,14 @@ def get_image_embedding(image_base64, model, transformer):
|
|
| 47 |
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
| 48 |
return image_feature.cpu().numpy().tolist()
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
class EndpointHandler:
|
| 51 |
def __init__(self, path: str = ""):
|
| 52 |
"""
|
|
@@ -80,9 +88,10 @@ class EndpointHandler:
|
|
| 80 |
Returns:
|
| 81 |
dict: {"embedding": [float, float, ...]}
|
| 82 |
"""
|
| 83 |
-
|
| 84 |
inputs = data.pop("inputs", data)
|
| 85 |
-
|
|
|
|
| 86 |
sketch_base64 = inputs.get("sketch", "")
|
| 87 |
text_query = inputs.get("text", "")
|
| 88 |
if not sketch_base64 or not text_query:
|
|
@@ -91,11 +100,19 @@ class EndpointHandler:
|
|
| 91 |
# Generate Fused Embedding
|
| 92 |
fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
|
| 93 |
return {"embedding": fused_embedding}
|
| 94 |
-
|
|
|
|
| 95 |
image_base64 = inputs.get("image", "")
|
| 96 |
if not image_base64:
|
| 97 |
return {"error": "Image 'image' (base64) is required input."}
|
| 98 |
embedding = get_image_embedding(image_base64, self.model, self.transform)
|
| 99 |
return {"embedding": embedding}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
else:
|
| 101 |
-
return {"error": "
|
|
|
|
| 47 |
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
| 48 |
return image_feature.cpu().numpy().tolist()
|
| 49 |
|
| 50 |
+
def get_text_embedding(text, model):
|
| 51 |
+
"""Convert text query to tensor."""
|
| 52 |
+
text_tensor = preprocess_text(text)
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
text_feature = model.encode_text(text_tensor)
|
| 55 |
+
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
|
| 56 |
+
return text_feature.cpu().numpy().tolist()
|
| 57 |
+
|
| 58 |
class EndpointHandler:
|
| 59 |
def __init__(self, path: str = ""):
|
| 60 |
"""
|
|
|
|
| 88 |
Returns:
|
| 89 |
dict: {"embedding": [float, float, ...]}
|
| 90 |
"""
|
| 91 |
+
|
| 92 |
inputs = data.pop("inputs", data)
|
| 93 |
+
# text-sketch embedding
|
| 94 |
+
if len(inputs) == 2 and "sketch" in inputs and "text" in inputs:
|
| 95 |
sketch_base64 = inputs.get("sketch", "")
|
| 96 |
text_query = inputs.get("text", "")
|
| 97 |
if not sketch_base64 or not text_query:
|
|
|
|
| 100 |
# Generate Fused Embedding
|
| 101 |
fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
|
| 102 |
return {"embedding": fused_embedding}
|
| 103 |
+
# image-only embedding
|
| 104 |
+
elif len(inputs) == 1 and "image" in inputs:
|
| 105 |
image_base64 = inputs.get("image", "")
|
| 106 |
if not image_base64:
|
| 107 |
return {"error": "Image 'image' (base64) is required input."}
|
| 108 |
embedding = get_image_embedding(image_base64, self.model, self.transform)
|
| 109 |
return {"embedding": embedding}
|
| 110 |
+
# text-only embedding
|
| 111 |
+
elif len(inputs) == 1 and "text" in inputs:
|
| 112 |
+
text_query = inputs.get("text", "")
|
| 113 |
+
if not text_query:
|
| 114 |
+
return {"error": "Text 'text' is required input."}
|
| 115 |
+
embedding = get_text_embedding(text_query, self.model)
|
| 116 |
+
return {"embedding": embedding}
|
| 117 |
else:
|
| 118 |
+
return {"error": "Invalid request."}
|