Update apis/chat_api.py
Browse files- apis/chat_api.py +14 -26
apis/chat_api.py
CHANGED
|
@@ -187,42 +187,30 @@ class ChatAPIApp:
|
|
| 187 |
data_response = streamer.chat_return_dict(stream_response)
|
| 188 |
return data_response
|
| 189 |
|
| 190 |
-
async def chat_embedding(self,
|
| 191 |
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
|
| 192 |
headers = {"Authorization": f"Bearer {api_key}"}
|
| 193 |
-
response = requests.post(api_url, headers=headers, json={"inputs":
|
| 194 |
result = response.json()
|
| 195 |
if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
|
| 196 |
-
|
| 197 |
-
flattened_embeddings = [sum(embedding, []) for embedding in result]
|
| 198 |
-
return flattened_embeddings
|
| 199 |
elif "error" in result:
|
| 200 |
raise RuntimeError("The model is currently loading, please re-run the query.")
|
| 201 |
else:
|
| 202 |
raise RuntimeError("Unexpected response format.")
|
| 203 |
-
|
| 204 |
|
| 205 |
async def embedding(self, request: QueryRequest, api_key: str = Depends(extract_api_key)):
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
"model": request.model,
|
| 218 |
-
"usage": {"prompt_tokens": len(request.input), "total_tokens": len(request.input)}
|
| 219 |
-
}
|
| 220 |
-
except RuntimeError as e:
|
| 221 |
-
if attempt < 2: # Don't sleep on the last attempt
|
| 222 |
-
await asyncio.sleep(10) # Delay for the retry
|
| 223 |
-
raise HTTPException(status_code=503, detail="The model is currently loading, please try again later.")
|
| 224 |
-
except Exception as e:
|
| 225 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 226 |
|
| 227 |
def setup_routes(self):
|
| 228 |
for prefix in ["", "/v1", "/api", "/api/v1"]:
|
|
|
|
| 187 |
data_response = streamer.chat_return_dict(stream_response)
|
| 188 |
return data_response
|
| 189 |
|
| 190 |
+
async def chat_embedding(self, input_text: str, model_name: str, api_key: str):
    """Fetch feature-extraction embeddings for *input_text* from the HF Inference API.

    Args:
        input_text: The text (or list of texts) to embed; sent as the `inputs` payload.
        model_name: Hugging Face model id, interpolated into the pipeline URL.
        api_key: Bearer token forwarded to the Inference API.

    Returns:
        A flat list built from the nested response vectors.

    Raises:
        RuntimeError: If the model is still loading on HF's side, or the
            response does not have the expected list-of-lists shape.
    """
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
    headers = {"Authorization": f"Bearer {api_key}"}
    # NOTE(review): requests.post is a blocking call inside an async def; it will
    # stall the event loop for the duration of the request. Consider httpx.AsyncClient
    # or loop.run_in_executor. The explicit timeout at least bounds the stall —
    # without it a hung endpoint blocks the server indefinitely.
    response = requests.post(api_url, headers=headers, json={"inputs": input_text}, timeout=30)
    result = response.json()
    if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
        # NOTE(review): this flattens ACROSS the outer dimension into one flat list,
        # unlike the previous per-input flatten (`sum(embedding, [])` for each entry).
        # Confirm the caller expects a single flattened sequence — TODO verify shape.
        return [item for sublist in result for item in sublist]  # Flatten the list of lists
    elif isinstance(result, dict) and "error" in result:
        # Guard with isinstance: `"error" in result` on a non-dict (e.g. a bare
        # string response) would do substring/membership matching instead of a key test.
        raise RuntimeError("The model is currently loading, please re-run the query.")
    else:
        raise RuntimeError("Unexpected response format.")
|
|
|
|
| 201 |
|
| 202 |
async def embedding(self, request: QueryRequest, api_key: str = Depends(extract_api_key)):
    """OpenAI-compatible embeddings endpoint backed by `chat_embedding`.

    Args:
        request: Query payload carrying `input` (text to embed) and `model` (HF model id).
        api_key: Bearer token extracted from the request by `extract_api_key`.

    Returns:
        An EmbeddingResponse with one entry per element of the embeddings list.

    Raises:
        HTTPException: 503 when the upstream model is loading or returns an
            unexpected shape (retryable); 500 for any other failure.
    """
    try:
        embeddings = await self.chat_embedding(request.input, request.model, api_key)
        data = [
            {"object": "embedding", "index": i, "embedding": embedding}
            for i, embedding in enumerate(embeddings)
        ]
        return EmbeddingResponse(
            object="list",
            data=data,
            model=request.model,
            # NOTE(review): len(request.input) is a character (or list-element) count,
            # not a token count — TODO confirm this is the intended usage metric.
            usage={"prompt_tokens": len(request.input), "total_tokens": len(request.input)},
        )
    except RuntimeError as e:
        # chat_embedding raises RuntimeError for "model is loading" / bad response
        # shape — both retryable. Map to 503 (as the previous implementation did)
        # instead of letting the broad handler below turn them into a 500.
        raise HTTPException(status_code=503, detail=str(e)) from e
    except Exception as e:
        # Boundary handler: surface anything else as a 500 with the message, chained
        # so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
def setup_routes(self):
|
| 216 |
for prefix in ["", "/v1", "/api", "/api/v1"]:
|