| import gradio as gr |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| from typing import List, Tuple |
| from huggingface_hub import InferenceClient |
| import os |
| from dotenv import load_dotenv |
|
|
# Load variables from a local .env file (python-dotenv) into the process
# environment before anything below reads them.
load_dotenv()

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive CORS configuration available; confirm this is intended
# before exposing the service beyond a demo.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
# Module-wide Hugging Face inference client used by respond().
# os.getenv returns None when HF_TOKEN is not set, so the token may be None.
_hf_token = os.getenv("HF_TOKEN")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=_hf_token)
|
|
class ChatRequest(BaseModel):
    """JSON body accepted by the ``POST /chat`` endpoint.

    Every field is required (none has a default). The endpoint forwards the
    values to ``respond`` unchanged.
    """

    # Newest user message to be answered.
    message: str
    # Earlier turns as (user_text, assistant_text) pairs, oldest first.
    history: list[tuple[str, str]]
    # System prompt placed at the start of the conversation.
    system_message: str
    # Generation controls forwarded to the model call.
    max_tokens: int
    temperature: float
    top_p: float
|
|
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    system_message: str = """You are a chatbot serving a user a text based adventure. When the user says 'start adventure', you will write a short (((70 word))) adventure story with 2 to 4 choices for the user to take at the end. Progress the story based on their choices. Number the choices as 1,2,3 and 4 etc. Don't take the choice yourself. Wait for the user to respond.""",
):
    """Stream the model's reply to ``message`` as a growing string.

    Builds a chat transcript (system prompt, prior turns, new user message),
    sends it to the module-level ``client`` with streaming enabled, and
    yields the accumulated reply text after every received chunk.

    Args:
        message: The new user message.
        history: Prior turns; each item is indexed as ``[0]`` = user text
            and ``[1]`` = assistant text. Empty/falsy entries are skipped.
        max_tokens: Forwarded to ``client.chat_completion``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.
        system_message: System prompt; defaults to the text-adventure prompt.

    Yields:
        The reply accumulated so far (cumulative, not the per-chunk delta).
    """
    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    stream = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    # Use a distinct loop name so the ``message`` parameter is not shadowed.
    for chunk in stream:
        choices = chunk.choices
        # A stream chunk may carry no choices, or a delta whose content is
        # None; concatenating None onto a str would raise TypeError, so such
        # chunks contribute no text.
        token = choices[0].delta.content if choices else None
        if token:
            response += token
        yield response
|
|
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Run one chat turn and return every streamed partial reply.

    Args:
        request: Validated request body; its fields are forwarded to
            ``respond``.

    Returns:
        ``{"response": [...]}`` where the list holds each cumulative partial
        string yielded by ``respond``, in order. The last element, if any,
        is the complete reply.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when
            generating the reply fails for any reason.
    """
    # Function-scope import, matching the file's existing local-import style.
    import asyncio

    def _collect_partials():
        # Draining the generator performs blocking network I/O.
        return list(
            respond(
                request.message,
                request.history,
                request.max_tokens,
                request.temperature,
                request.top_p,
                request.system_message,
            )
        )

    try:
        # Run the blocking drain in a worker thread so the event loop stays
        # free to serve other requests while the model streams.
        partials = await asyncio.to_thread(_collect_partials)
    except Exception as e:
        # Top-level boundary: surface any failure as HTTP 500 and keep the
        # original exception chained for tracebacks/logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": partials}
|
|
| |
# Extra controls for the chat UI. They are listed in the same order as the
# max_tokens, temperature and top_p parameters of respond().
_max_tokens_slider = gr.Slider(
    minimum=1, maximum=2048, value=250, step=1, label="Max new tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Chat UI backed by respond(). No system-message input is offered here, so
# respond() falls back to its default system prompt.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[_max_tokens_slider, _temperature_slider, _top_p_slider],
)
|
|
| |
# Attach the Gradio UI to the FastAPI app at the site root and rebind ``app``
# to the object returned by the mount call.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port 7860.
    import uvicorn

    _bind = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_bind)