Commit
·
b38d2da
1
Parent(s):
d19a591
output serialization
Browse files- handler.py +8 -3
handler.py
CHANGED
|
@@ -36,9 +36,14 @@ class EndpointHandler:
|
|
| 36 |
with torch.no_grad():
|
| 37 |
outputs = self.model(**inputs)
|
| 38 |
|
| 39 |
-
# Process outputs -
|
| 40 |
-
#
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
if __name__ == "__main__":
|
|
|
|
| 36 |
with torch.no_grad():
|
| 37 |
outputs = self.model(**inputs)
|
| 38 |
|
| 39 |
+
# Process outputs - convert tensors to serializable format
|
| 40 |
+
# Extract the last hidden state and convert to list for JSON serialization
|
| 41 |
+
last_hidden_state = outputs.last_hidden_state
|
| 42 |
+
|
| 43 |
+
# Convert to Python list (serializable) - using the mean of the embeddings as a simple approach
|
| 44 |
+
embedding = last_hidden_state.mean(dim=1).cpu().numpy().tolist()
|
| 45 |
+
|
| 46 |
+
return [{"input_text": input_text, "embedding": embedding}]
|
| 47 |
|
| 48 |
|
| 49 |
if __name__ == "__main__":
|