Spaces:
Paused
Paused
File size: 2,405 Bytes
0a33686 e0bc5c6 0a33686 e0bc5c6 0a33686 0d3bbff d56d0b4 f3affbe 0a33686 e87158e 673f80d e2cdbc6 f3affbe 0a33686 f3affbe 0a33686 e0bc5c6 0d3bbff 0a33686 0d3bbff 0a33686 41f2a92 0a33686 41f2a92 0a33686 0d3bbff e0bc5c6 0d3bbff 0a33686 0d3bbff b496d68 e0bc5c6 0d3bbff e0bc5c6 0d3bbff e0bc5c6 0d3bbff e0bc5c6 0d3bbff e0bc5c6 0a33686 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from typing import List, Tuple
import asyncio
import os
# Initialisation du client Hugging Face
token = os.environ.get("HF_TOKEN")
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
# Initialisation de FastAPI
app = FastAPI()
# Modèle pour les données d'entrée
class PredictionRequest(BaseModel):
message: str
history: List[Tuple[str, str]] = []
system_message: str = "You are a friendly Chatbot."
max_tokens: int = 512
temperature: float = 0.7
top_p: float = 0.95
async def generate_stream(request: PredictionRequest):
"""
Générateur asynchrone pour produire les tokens progressivement.
Utilise asyncio.to_thread pour rendre l'appel synchrone compatible avec async.
"""
messages = [{"role": "system", "content": request.system_message}]
for user_input, assistant_response in request.history:
if user_input:
messages.append({"role": "user", "content": user_input})
if assistant_response:
messages.append(
{"role": "assistant", "content": assistant_response})
messages.append({"role": "user", "content": request.message})
try:
# Exécution du client synchrone dans un thread séparé
def sync_stream():
return client.chat_completion(
messages,
max_tokens=request.max_tokens,
stream=True,
temperature=request.temperature,
top_p=request.top_p,
)
# Appel synchrone dans un thread asynchrone
for message in await asyncio.to_thread(sync_stream):
token = message.choices[0].delta.content
yield f"{token}\n"
except Exception as e:
yield f"Error: {str(e)}\n"
@app.post("/predict")
async def predict(request: PredictionRequest):
"""
Endpoint REST avec réponse en streaming.
"""
return StreamingResponse(
generate_stream(request),
media_type="text/plain" # Peut être changé en JSON si besoin
)
# Pour le test en local
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
|