# ashwin / app.py
# Uploaded by d3dname — commit e8e2b88 "Create app.py" (verified)
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Tuple
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv
# Read settings such as HF_TOKEN from a local .env file into the environment
# before anything below looks them up.
load_dotenv()

# Hosted-inference client for the chat model. The token comes from the
# HF_TOKEN environment variable (os.getenv yields None when it is unset).
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN"))

app = FastAPI()

# Wide-open CORS policy: every origin, HTTP method and request header is
# accepted, and credentialed requests are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class ChatRequest(BaseModel):
    """JSON body accepted by the POST /chat endpoint.

    Each field is forwarded to ``respond`` by the endpoint.
    """

    # The newest user message.
    message: str
    # Earlier conversation turns as (user_text, assistant_text) pairs.
    history: list[tuple[str, str]]
    # System prompt; required here even though respond() has its own default.
    system_message: str
    # Sampling / length settings passed straight through to the model call.
    max_tokens: int
    temperature: float
    top_p: float
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    system_message: str = """You are a chatbot serving a user a text based adventure. When the user says 'start adventure', you will write a short (((70 word))) adventure story with 2 to 4 choices for the user to take at the end. Progress the story based on their choices. Number the choices as 1,2,3 and 4 etc. Don't take the choice yourself. Wait for the user to respond.""",
):
    """Stream a reply from the chat model, yielding the text accumulated so far.

    Used both as the Gradio ChatInterface callback and by the /chat endpoint.

    Args:
        message: The newest user message.
        history: Earlier turns as (user_text, assistant_text) pairs; an empty
            or falsy side of a pair is left out of the prompt.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling probability mass forwarded to the model.
        system_message: System prompt sent first; defaults to the
            text-adventure instructions.

    Yields:
        The full reply text received so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})

    response = ""
    # `chunk`, not `message`: the loop variable must not shadow the parameter.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A stream chunk may carry no choices, or a delta whose content is
        # None (e.g. the final chunk). Treat both as an empty token instead of
        # failing with IndexError / `str += None` halfway through the reply.
        choices = chunk.choices
        token = choices[0].delta.content if choices else None
        response += token or ""
        yield response
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Generate one complete chat reply for an HTTP client (non-streaming).

    Args:
        request: Parsed request body; its fields are forwarded to ``respond``.

    Returns:
        ``{"response": [...]}``, where the list holds every cumulative partial
        reply yielded by ``respond``, in order. The last element is the
        complete reply.

    Raises:
        HTTPException: status 500 carrying the error text if generation fails.
    """
    # Local import: fastapi is already a dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    def _collect_partials():
        # respond() is a synchronous generator that blocks on the model's
        # network stream, so it is drained here, off the event loop.
        return list(
            respond(
                request.message,
                request.history,
                request.max_tokens,
                request.temperature,
                request.top_p,
                request.system_message,
            )
        )

    try:
        # Run in FastAPI's worker thread pool so a slow generation does not
        # stall every other request handled by this event loop.
        partials = await run_in_threadpool(_collect_partials)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": partials}
# Gradio chat UI driven by respond(). Gradio passes the values of
# additional_inputs after (message, history), so the sliders below fill
# respond()'s max_tokens, temperature and top_p, in that order. The UI never
# supplies system_message, so respond()'s default prompt is used.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        # -> max_tokens
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=250),
        # -> temperature
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        # -> top_p
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

# Serve the Gradio UI at the root of the FastAPI app. The /chat route was
# registered on `app` before this mount, so it stays reachable.
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    # Direct-run entry point: serve the combined FastAPI + Gradio app on all
    # network interfaces at port 7860.
    from uvicorn import run

    run(app, port=7860, host="0.0.0.0")