from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from typing import List, Tuple
import asyncio
import os

# Initialize the Hugging Face inference client
token = os.environ.get("HF_TOKEN")
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)

# Initialize FastAPI
app = FastAPI()

# Request model for the input data
class PredictionRequest(BaseModel):
    message: str
    history: List[Tuple[str, str]] = []
    system_message: str = "You are a friendly Chatbot."
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95

async def generate_stream(request: PredictionRequest):
    """
    Asynchronous generator that yields tokens as they are produced.
    Uses asyncio.to_thread to run the synchronous client call off the event loop.
    """
    # Rebuild the conversation as a list of role/content messages
    messages = [{"role": "system", "content": request.system_message}]
    for user_input, assistant_response in request.history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_response:
            messages.append(
                {"role": "assistant", "content": assistant_response})
    messages.append({"role": "user", "content": request.message})

    try:
        # Run the synchronous client call in a separate thread
        def sync_stream():
            return client.chat_completion(
                messages,
                max_tokens=request.max_tokens,
                stream=True,
                temperature=request.temperature,
                top_p=request.top_p,
            )

        # Start the stream in a worker thread, then iterate over its chunks
        for message in await asyncio.to_thread(sync_stream):
            token = message.choices[0].delta.content
            if token:  # skip empty or None deltas (e.g. role-only chunks)
                yield f"{token}\n"
    except Exception as e:
        yield f"Error: {str(e)}\n"

@app.post("/predict")  # register the endpoint (route path is illustrative)
async def predict(request: PredictionRequest):
    """
    REST endpoint with a streaming response.
    """
    return StreamingResponse(
        generate_stream(request),
        media_type="text/plain"  # Could be switched to JSON if needed
    )

# For local testing
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
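
The endpoint streams plain text, one token per line, so any HTTP client that supports chunked responses can consume it. Below is a minimal client sketch using the requests library; it assumes the server runs locally on port 7860 and that the route is /predict as registered above, so adjust the URL and payload to match your deployment.

# Minimal consumer sketch (assumes the app is reachable at http://localhost:7860/predict)
import requests

payload = {
    "message": "Hello, who are you?",
    "history": [],
    "system_message": "You are a friendly Chatbot.",
    "max_tokens": 128,
    "temperature": 0.7,
    "top_p": 0.95,
}

with requests.post("http://localhost:7860/predict", json=payload, stream=True) as resp:
    resp.raise_for_status()
    # The server yields one token per line, so read the body line by line
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line, end=" ", flush=True)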