Update apis/chat_api.py

Changed file: apis/chat_api.py (+6 −3)
@@ -24,7 +24,7 @@ from fastapi.middleware.cors import CORSMiddleware
 class EmbeddingResponseItem(BaseModel):
     object: str = "embedding"
     index: int
-    embedding: List[float]

 class EmbeddingResponse(BaseModel):
     object: str = "list"
@@ -193,11 +193,14 @@ class ChatAPIApp:
         response = requests.post(api_url, headers=headers, json={"inputs": texts})
         result = response.json()
         if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
-            [removed line — content not captured in this rendering]
         elif "error" in result:
             raise RuntimeError("The model is currently loading, please re-run the query.")
         else:
-            raise RuntimeError("Unexpected response format.")

     async def embedding(self, request: QueryRequest):
         try:
 class EmbeddingResponseItem(BaseModel):
     object: str = "embedding"
     index: int
+    embedding: List[List[float]]

 class EmbeddingResponse(BaseModel):
     object: str = "list"
         response = requests.post(api_url, headers=headers, json={"inputs": texts})
         result = response.json()
         if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
+            # Assuming each embedding is a list of lists of floats, flatten it
+            flattened_embeddings = [sum(embedding, []) for embedding in result]
+            return flattened_embeddings
         elif "error" in result:
             raise RuntimeError("The model is currently loading, please re-run the query.")
         else:
+            raise RuntimeError("Unexpected response format.")
+

     async def embedding(self, request: QueryRequest):
         try: