d3dname committed on
Commit
e8e2b88
·
verified ·
1 Parent(s): ae00ac5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from fastapi import FastAPI, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
+ from typing import List, Tuple
6
+ from huggingface_hub import InferenceClient
7
+ import os
8
+ from dotenv import load_dotenv
9
+
10
# Load environment variables via python-dotenv before anything reads them.
load_dotenv()

app = FastAPI()

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers should not be combined with allow_credentials=True. Confirm whether
# credentialed cross-origin requests are actually needed before tightening.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Inference client shared by every request; the token comes from HF_TOKEN
# and is None when that variable is not set.
_hf_token = os.getenv("HF_TOKEN")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=_hf_token)
24
+
25
class ChatRequest(BaseModel):
    # Request body of POST /chat; every field is required.

    # New user message to answer.
    message: str
    # Earlier turns as (user_text, assistant_text) pairs.
    history: list[tuple[str, str]]
    # System prompt placed at the start of the conversation.
    system_message: str
    # Sampling settings forwarded unchanged to the chat-completion call.
    max_tokens: int
    temperature: float
    top_p: float
32
+
33
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    system_message: str = """You are a chatbot serving a user a text based adventure. When the user says 'start adventure', you will write a short (((70 word))) adventure story with 2 to 4 choices for the user to take at the end. Progress the story based on their choices. Number the choices as 1,2,3 and 4 etc. Don't take the choice yourself. Wait for the user to respond.""",
):
    """Stream a chat completion for ``message`` given the prior conversation.

    The conversation sent to the module-level ``client`` is: the system
    prompt, then each turn of ``history`` (user text followed by assistant
    text, skipping empty halves), then the new user message.

    Args:
        message: The new user message.
        history: Prior turns; item ``[0]`` of each turn is the user text and
            item ``[1]`` the assistant text.
        max_tokens: Forwarded to ``client.chat_completion``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.
        system_message: Content of the leading system message.

    Yields:
        The response text accumulated so far, once per streamed chunk that
        carries a choice.
    """
    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # Use a distinct loop name so the ``message`` parameter is not shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A chunk without choices has nothing to append.
        if not chunk.choices:
            continue

        # ``delta.content`` can be None on a chunk (e.g. the closing one);
        # concatenating None to a str would raise TypeError.
        token = chunk.choices[0].delta.content
        response += token or ""
        yield response
64
+
65
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Run one chat turn and return every partial response as a list.

    The returned JSON is ``{"response": [...]}`` where the list holds each
    value yielded by ``respond`` in order; the last element is the complete
    reply.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when
            anything fails while generating the reply.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    try:
        stream = respond(
            request.message,
            request.history,
            request.max_tokens,
            request.temperature,
            request.top_p,
            request.system_message,
        )
        # respond() is a synchronous generator that blocks on network I/O.
        # Draining it on a worker thread keeps the event loop free to serve
        # other requests in the meantime.
        partial_responses = await asyncio.to_thread(list, stream)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"response": partial_responses}
79
+
80
# Gradio chat UI. The three sliders are the extra inputs given to respond()
# after (message, history): max_tokens, temperature and top_p, in that order.
_max_tokens_slider = gr.Slider(
    label="Max new tokens", minimum=1, maximum=2048, step=1, value=250
)
_temperature_slider = gr.Slider(
    label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7
)
_top_p_slider = gr.Slider(
    label="Top-p (nucleus sampling)",
    minimum=0.1,
    maximum=1.0,
    step=0.05,
    value=0.95,
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[_max_tokens_slider, _temperature_slider, _top_p_slider],
)

# Attach the Gradio UI to the FastAPI application at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)