laria-startup / app.py
MiCkSoftware's picture
back to meta
673f80d
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from typing import List, Tuple
import asyncio
import os
# Initialisation du client Hugging Face
token = os.environ.get("HF_TOKEN")
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
# Initialisation de FastAPI
app = FastAPI()
# Modèle pour les données d'entrée
class PredictionRequest(BaseModel):
message: str
history: List[Tuple[str, str]] = []
system_message: str = "You are a friendly Chatbot."
max_tokens: int = 512
temperature: float = 0.7
top_p: float = 0.95
async def generate_stream(request: PredictionRequest):
"""
Générateur asynchrone pour produire les tokens progressivement.
Utilise asyncio.to_thread pour rendre l'appel synchrone compatible avec async.
"""
messages = [{"role": "system", "content": request.system_message}]
for user_input, assistant_response in request.history:
if user_input:
messages.append({"role": "user", "content": user_input})
if assistant_response:
messages.append(
{"role": "assistant", "content": assistant_response})
messages.append({"role": "user", "content": request.message})
try:
# Exécution du client synchrone dans un thread séparé
def sync_stream():
return client.chat_completion(
messages,
max_tokens=request.max_tokens,
stream=True,
temperature=request.temperature,
top_p=request.top_p,
)
# Appel synchrone dans un thread asynchrone
for message in await asyncio.to_thread(sync_stream):
token = message.choices[0].delta.content
yield f"{token}\n"
except Exception as e:
yield f"Error: {str(e)}\n"
@app.post("/predict")
async def predict(request: PredictionRequest):
"""
Endpoint REST avec réponse en streaming.
"""
return StreamingResponse(
generate_stream(request),
media_type="text/plain" # Peut être changé en JSON si besoin
)
# Pour le test en local
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)