MiCkSoftware committed
Commit e0bc5c6 · 1 Parent(s): 0a33686

streaming

Files changed (1)
  1. app.py +20 -10
app.py CHANGED
@@ -1,7 +1,9 @@
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from fastapi.responses import StreamingResponse
 from huggingface_hub import InferenceClient
+from pydantic import BaseModel
 from typing import List, Tuple
+import asyncio

 # Initialize the Hugging Face client
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
@@ -18,12 +20,11 @@ class PredictionRequest(BaseModel):
     temperature: float = 0.7
     top_p: float = 0.95

-@app.post("/predict")
-async def predict(request: PredictionRequest):
+
+async def generate_stream(request: PredictionRequest):
     """
-    REST endpoint to perform a prediction.
+    Asynchronous generator for streaming the response.
     """
-    # Prepare the messages for inference
     messages = [{"role": "system", "content": request.system_message}]
     for user_input, assistant_response in request.history:
         if user_input:
@@ -32,21 +33,30 @@ async def predict(request: PredictionRequest):
             messages.append({"role": "assistant", "content": assistant_response})
     messages.append({"role": "user", "content": request.message})

-    # Call the Hugging Face API
     try:
-        response = ""
-        for message in client.chat_completion(
+        for message in client.chat_completion(  # sync iterator; valid inside an async generator
             messages,
             max_tokens=request.max_tokens,
             stream=True,
             temperature=request.temperature,
             top_p=request.top_p,
         ):
-            response += message.choices[0].delta.content
-        return {"response": response}
+            token = message.choices[0].delta.content or ""  # final chunk may carry None
+            yield token
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))

+
+@app.post("/predict")
+async def predict(request: PredictionRequest):
+    """
+    REST endpoint with a streaming response.
+    """
+    return StreamingResponse(
+        generate_stream(request),
+        media_type="text/plain"  # can be switched to JSON if needed
+    )
+
 # For local testing
 if __name__ == "__main__":
     import uvicorn
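A note on the iteration style: the sync InferenceClient returns a plain iterator from chat_completion(stream=True), which is why generate_stream iterates with a regular for loop (an async def generator may contain a sync loop, though it blocks the event loop between tokens). For fully non-blocking streaming, huggingface_hub also provides AsyncInferenceClient; a minimal sketch of that variant, reusing the same model id, with the history-building code shortened:

from huggingface_hub import AsyncInferenceClient

async_client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")

async def generate_stream_async(request: PredictionRequest):
    # Awaiting the call with stream=True produces an async iterator,
    # so the event loop stays free between tokens.
    stream = await async_client.chat_completion(
        [{"role": "user", "content": request.message}],  # shortened; build the full history as above
        max_tokens=request.max_tokens,
        stream=True,
        temperature=request.temperature,
        top_p=request.top_p,
    )
    async for message in stream:
        yield message.choices[0].delta.content or ""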
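To exercise the endpoint, a minimal client sketch, assuming the app is served locally at http://127.0.0.1:8000 (uvicorn's default). The payload fields mirror the PredictionRequest model visible in the diff; defaults for system_message and max_tokens are outside the shown hunks, so they are passed explicitly here:

import requests

URL = "http://127.0.0.1:8000/predict"  # hypothetical local address

payload = {
    "message": "Explain HTTP streaming in one sentence.",
    "history": [],  # list of (user, assistant) turns
    "system_message": "You are a helpful assistant.",
    "max_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.95,
}

# stream=True keeps the connection open so tokens print as they arrive.
with requests.post(URL, json=payload, stream=True) as resp:
    resp.raise_for_status()
    resp.encoding = "utf-8"  # the endpoint streams plain UTF-8 text
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
print()

One caveat of the design: StreamingResponse sends the 200 status before generate_stream finishes, so the HTTPException raised inside the generator cannot change the status code once tokens have gone out; a mid-generation failure surfaces client-side as a truncated or aborted stream rather than a 500.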