File size: 2,405 Bytes
0a33686
e0bc5c6
0a33686
e0bc5c6
0a33686
0d3bbff
d56d0b4
f3affbe
0a33686
e87158e
673f80d
 
e2cdbc6
f3affbe
0a33686
 
f3affbe
0a33686
 
 
 
 
 
 
 
 
e0bc5c6
0d3bbff
0a33686
0d3bbff
 
0a33686
 
 
 
 
 
41f2a92
 
0a33686
41f2a92
0a33686
0d3bbff
 
 
 
 
 
 
 
 
 
 
 
e0bc5c6
0d3bbff
0a33686
0d3bbff
b496d68
 
e0bc5c6
0d3bbff
e0bc5c6
0d3bbff
e0bc5c6
0d3bbff
e0bc5c6
0d3bbff
e0bc5c6
 
0a33686
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from typing import List, Tuple
import asyncio
import os

# Initialisation du client Hugging Face
token = os.environ.get("HF_TOKEN")
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)


# Initialisation de FastAPI
app = FastAPI()

# Modèle pour les données d'entrée
class PredictionRequest(BaseModel):
    message: str
    history: List[Tuple[str, str]] = []
    system_message: str = "You are a friendly Chatbot."
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95


async def generate_stream(request: PredictionRequest):
    """
    Générateur asynchrone pour produire les tokens progressivement.
    Utilise asyncio.to_thread pour rendre l'appel synchrone compatible avec async.
    """
    messages = [{"role": "system", "content": request.system_message}]
    for user_input, assistant_response in request.history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_response:
            messages.append(
                {"role": "assistant", "content": assistant_response})
    messages.append({"role": "user", "content": request.message})

    try:
        # Exécution du client synchrone dans un thread séparé
        def sync_stream():
            return client.chat_completion(
                messages,
                max_tokens=request.max_tokens,
                stream=True,
                temperature=request.temperature,
                top_p=request.top_p,
            )

        # Appel synchrone dans un thread asynchrone
        for message in await asyncio.to_thread(sync_stream):
            token = message.choices[0].delta.content
            yield f"{token}\n"
    except Exception as e:
        yield f"Error: {str(e)}\n"


@app.post("/predict")
async def predict(request: PredictionRequest):
    """
    Endpoint REST avec réponse en streaming.
    """
    return StreamingResponse(
        generate_stream(request),
        media_type="text/plain"  # Peut être changé en JSON si besoin
    )

# Pour le test en local
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)