ThomasSimonini committed on
Commit
779a991
·
1 Parent(s): f02b2fd

Create the app

Browse files
Files changed (1) hide show
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ import os
4
+ import requests
5
# Configuration pulled from the environment: the HF access token and one
# inference-endpoint URL per Llama model size.
hf_token = os.getenv("HF_TOKEN")
api_url_7b = os.getenv("API_URL_LLAMA_7")
api_url_13b = os.getenv("API_URL_LLAMA_13")
api_url_70b = os.getenv("API_URL_LLAMA_70")

# All requests send JSON bodies.
headers = {'Content-Type': 'application/json'}
14
+
15
def predict(message,
            chatbot,
            system_prompt="",
            temperature=0.9,
            max_new_tokens=256,
            top_p=0.6,
            repetition_penalty=1.0,
            model="api_url_70b"):
    """Stream a chat completion from a Llama text-generation-inference endpoint.

    Args:
        message: The new user message.
        chatbot: Conversation history as a list of (user, assistant) pairs.
        system_prompt: Optional system prompt, wrapped in Llama-2 <<SYS>> tags.
        temperature: Sampling temperature; clamped to a minimum of 1e-2.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        model: Dropdown choice naming which endpoint to call
            ("api_url_7b" / "api_url_13b" / "api_url_70b").

    Yields:
        The partial assistant reply, growing as tokens stream in.
    """
    # BUG FIX: `model` was declared after the defaulted parameters with no
    # default of its own — a SyntaxError. It now defaults to the 70B endpoint.
    # BUG FIX: the request below used an undefined `api_url`; resolve it
    # from the dropdown value instead.
    endpoints = {
        "api_url_7b": api_url_7b,
        "api_url_13b": api_url_13b,
        "api_url_70b": api_url_70b,
    }
    api_url = endpoints.get(model, api_url_70b)

    # Build the prompt in the Llama-2 chat format.
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = f"<s>[INST] "

    temperature = float(temperature)
    # The endpoint rejects temperatures below 1e-2.
    if temperature < 1e-2:
        temperature = 1e-2

    top_p = float(top_p)

    # Replay the conversation history in the expected turn delimiters.
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "

    input_prompt = input_prompt + str(message) + " [/INST] "

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    # BUG FIX: auth was written as `("hf, hf_token")` — a parenthesized
    # single string, not a (user, password) tuple — so the token was never
    # sent. Pass the token as the HTTP basic-auth password.
    response = requests.post(api_url, headers=headers, data=json.dumps(data),
                             auth=("hf", hf_token), stream=True)

    # The endpoint streams server-sent events; accumulate token texts and
    # yield the running reply so Gradio renders it incrementally.
    partial_message = ""
    for line in response.iter_lines():
        if line:  # filter out keep-alive new lines
            # Decode from bytes to string.
            decoded_line = line.decode('utf-8')

            # Each SSE payload line is prefixed with 'data:'.
            if decoded_line.startswith('data:'):
                json_line = decoded_line[5:]  # Exclude the 'data:' prefix
            else:
                gr.Warning(f"This line does not start with 'data:': {decoded_line}")
                continue

            # Load as JSON.
            try:
                json_obj = json.loads(json_line)
                if 'token' in json_obj:
                    partial_message = partial_message + json_obj['token']['text']
                    yield partial_message
                elif 'error' in json_obj:
                    yield json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.'
                else:
                    gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
            except json.JSONDecodeError:
                gr.Warning(f"This line is not valid JSON: {json_line}")
                continue
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
                continue
87
+
88
# Extra controls shown under the chat box. gr.ChatInterface binds these, in
# order, to the parameters of `predict` after (message, chatbot):
# system_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model.
# BUG FIX: the Dropdown and Textbox entries were missing trailing commas
# (SyntaxError), and the model Dropdown was listed first even though `model`
# is the LAST extra parameter — every widget would have fed the wrong
# argument. The Dropdown now comes last to match the signature.
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
    gr.Dropdown(
        ["api_url_7b", "api_url_13b", "api_url_70b"],
        label="Model",
        info="Which model to use?",
    ),
]
129
+
130
# Shared chat widget: custom user/bot avatars, bubbles sized to their content.
chatbot = gr.Chatbot(
    avatar_images=('user.png', 'bot2.png'),
    bubble_full_width=False,
)
131
+
132
# The streaming chat interface wired to `predict`.
# BUG FIX: the original passed several undefined names (`title`,
# `description`, `css`, `examples`, the misspelled `chatbot_stream`, and
# `model`), an invalid `model=` keyword that gr.ChatInterface does not
# accept, and `cache_examples=True` with no examples to cache. The model is
# selected through the Dropdown in `additional_inputs` instead.
chat_interface_stream = gr.ChatInterface(
    predict,
    textbox=gr.Textbox(),
    chatbot=chatbot,
    additional_inputs=additional_inputs,
)
142
+
143
# Gradio Demo
# BUG FIX: the interface was created as `chat_interface_stream` but rendered
# as the undefined `chat_interface`. Also launch the app behind a queue,
# which streaming (generator) handlers require.
with gr.Blocks() as demo:
    with gr.Tab("Llama 70B"):
        chat_interface_stream.render()

demo.queue().launch()