MiCkSoftware commited on
Commit
0d3bbff
·
1 Parent(s): 3712677
Files changed (2) hide show
  1. app.py +22 -45
  2. client.py +3 -3
app.py CHANGED
@@ -1,9 +1,9 @@
1
- from starlette.types import Message
2
  from fastapi import FastAPI, HTTPException
3
  from fastapi.responses import StreamingResponse
4
  from huggingface_hub import InferenceClient
5
  from pydantic import BaseModel
6
  from typing import List, Tuple
 
7
 
8
  # Initialisation du client Hugging Face
9
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
@@ -12,8 +12,6 @@ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
12
  app = FastAPI()
13
 
14
  # Modèle pour les données d'entrée
15
-
16
-
17
  class PredictionRequest(BaseModel):
18
  message: str
19
  history: List[Tuple[str, str]] = []
@@ -23,9 +21,10 @@ class PredictionRequest(BaseModel):
23
  top_p: float = 0.95
24
 
25
 
26
- def generate_stream(request: PredictionRequest):
27
  """
28
- Générateur synchrone pour produire les tokens progressivement.
 
29
  """
30
  messages = [{"role": "system", "content": request.system_message}]
31
  for user_input, assistant_response in request.history:
@@ -36,57 +35,35 @@ def generate_stream(request: PredictionRequest):
36
  {"role": "assistant", "content": assistant_response})
37
  messages.append({"role": "user", "content": request.message})
38
 
39
- yield "START\n".encode("utf-8")
40
-
41
  try:
42
- # Appel à l'API Hugging Face avec streaming
43
- for message in client.chat_completion(
44
- messages,
45
- max_tokens=request.max_tokens,
46
- stream=True,
47
- temperature=request.temperature,
48
- top_p=request.top_p,
49
- ):
 
 
 
 
50
  token = message.choices[0].delta.content
51
- print(token)
52
- # Chaque token avec un saut de ligne
53
- yield f"{token}\n".encode("utf-8")
54
  except Exception as e:
55
- yield f"Error: {str(e)}\n".encode("utf-8")
56
-
57
-
58
- class CustomStreamingResponse(StreamingResponse):
59
- """
60
- Personnalisation de StreamingResponse pour s'assurer que chaque chunk est envoyé immédiatement.
61
- """
62
-
63
- def __init__(self, *args, **kwargs):
64
- super().__init__(*args, **kwargs)
65
- self.started = False # Initialisation de l'attribut `started`
66
-
67
- async def stream_response(self, send: Message):
68
- # Envoi du message de démarrage une seule fois
69
- if not self.started:
70
- await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"text/plain")]})
71
- self.started = True
72
-
73
- # Envoi des chunks de réponse
74
- async for chunk in self.body_iterator:
75
- await send({"type": "http.response.body", "body": chunk, "more_body": True})
76
- await send({"type": "http.response.body", "body": b"", "more_body": False})
77
 
78
 
79
  @app.post("/predict")
80
- def predict(request: PredictionRequest):
81
  """
82
- Endpoint REST avec réponse en streaming synchrone.
83
  """
84
- return CustomStreamingResponse(
85
  generate_stream(request),
86
- media_type="text/plain" # Peut être changé en JSON si nécessaire
87
  )
88
 
89
-
90
  # Pour le test en local
91
  if __name__ == "__main__":
92
  import uvicorn
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.responses import StreamingResponse
3
  from huggingface_hub import InferenceClient
4
  from pydantic import BaseModel
5
  from typing import List, Tuple
6
+ import asyncio
7
 
8
  # Initialisation du client Hugging Face
9
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
12
  app = FastAPI()
13
 
14
  # Modèle pour les données d'entrée
 
 
15
  class PredictionRequest(BaseModel):
16
  message: str
17
  history: List[Tuple[str, str]] = []
 
21
  top_p: float = 0.95
22
 
23
 
24
+ async def generate_stream(request: PredictionRequest):
25
  """
26
+ Générateur asynchrone pour produire les tokens progressivement.
27
+ Utilise asyncio.to_thread pour rendre l'appel synchrone compatible avec async.
28
  """
29
  messages = [{"role": "system", "content": request.system_message}]
30
  for user_input, assistant_response in request.history:
 
35
  {"role": "assistant", "content": assistant_response})
36
  messages.append({"role": "user", "content": request.message})
37
 
 
 
38
  try:
39
+ # Exécution du client synchrone dans un thread séparé
40
+ def sync_stream():
41
+ return client.chat_completion(
42
+ messages,
43
+ max_tokens=request.max_tokens,
44
+ stream=True,
45
+ temperature=request.temperature,
46
+ top_p=request.top_p,
47
+ )
48
+
49
+ # Appel synchrone dans un thread asynchrone
50
+ for message in await asyncio.to_thread(sync_stream):
51
  token = message.choices[0].delta.content
52
+ yield f"{token}\n"
 
 
53
  except Exception as e:
54
+ yield f"Error: {str(e)}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  @app.post("/predict")
58
+ async def predict(request: PredictionRequest):
59
  """
60
+ Endpoint REST avec réponse en streaming.
61
  """
62
+ return StreamingResponse(
63
  generate_stream(request),
64
+ media_type="text/plain" # Peut être changé en JSON si besoin
65
  )
66
 
 
67
  # Pour le test en local
68
  if __name__ == "__main__":
69
  import uvicorn
client.py CHANGED
@@ -5,10 +5,10 @@ url = "https://micksoftware-laria-startup.hf.space/predict"
5
 
6
  # Données pour la requête
7
  payload = {
8
- "message": "racontes moi une histoire de 50 mots",
9
  "history": [],
10
  "system_message": "You are a friendly Chatbot.",
11
- "max_tokens": 2048,
12
  "temperature": 0.7,
13
  "top_p": 0.95,
14
  }
@@ -21,6 +21,6 @@ if response.status_code == 200:
21
  print("Streaming response:")
22
  for chunk in response.iter_lines(decode_unicode=True):
23
  if chunk:
24
- print(chunk, end="\n")
25
  else:
26
  print(f"Erreur : {response.status_code} - {response.text}")
 
5
 
6
  # Données pour la requête
7
  payload = {
8
+ "message": "quelle longueur d'ypothenuse pour un triangle de cote 4 et 9",
9
  "history": [],
10
  "system_message": "You are a friendly Chatbot.",
11
+ "max_tokens": 512,
12
  "temperature": 0.7,
13
  "top_p": 0.95,
14
  }
 
21
  print("Streaming response:")
22
  for chunk in response.iter_lines(decode_unicode=True):
23
  if chunk:
24
+ print(chunk, end="")
25
  else:
26
  print(f"Erreur : {response.status_code} - {response.text}")