peterciank committed on
Commit
1a62c2c
·
verified ·
1 Parent(s): a6212b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py CHANGED
@@ -14,6 +14,240 @@ def query(payload):
14
  response = requests.post(API_URL, headers=headers, json=payload)
15
  return response.json()
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
# Streamlit UI
# NOTE(review): this fragment also appears near the end of the file —
# presumably diff-context duplication from the scrape; confirm against
# the real app.py before deduplicating.
st.title("Chat App with Hugging Face")
# Single-line chat input; value is captured on Enter.
user_input = st.text_input("You:", "")
 
14
  response = requests.post(API_URL, headers=headers, json=payload)
15
  return response.json()
16
 
17
+
18
# Third-party streaming client for text-generation-inference backends.
# BUG FIX: the line was missing the leading "f" ("rom ... import"), which
# is a syntax error; restore the `from` keyword.
from text_generation import Client, InferenceAPIClient
19
+
20
# Conversational pre-prompt returned by get_usernames() for the
# togethercomputer/GPT-NeoXT-Chat-Base-20B (OpenChat) model; it is
# prepended to the conversation before generation.
openchat_preprompt = (
    "\n<human>: Hi!\n<bot>: My name is Bot, model version is 0.15, part of an open-source kit for "
    "fine-tuning new bots! I was created by Together, LAION, and Ontocord.ai and the open-source "
    "community. I am not human, not evil and not alive, and thus have no thoughts and feelings, "
    "but I am programmed to be helpful, polite, honest, and friendly.\n"
)
26
+
27
+
28
def get_client(model: str):
    """Return a streaming text-generation client for *model*.

    The OpenChat model talks to a dedicated endpoint (OPENCHAT_API_URL);
    every other model goes through the Hugging Face Inference API,
    optionally authenticated via the HF_TOKEN environment variable.
    """
    if model != "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
    return Client(os.getenv("OPENCHAT_API_URL"))
32
+
33
+
34
def get_usernames(model: str):
    """
    Map a model id to its prompt-formatting conventions.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    oasst_models = (
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    )
    if model in oasst_models:
        preprompt, user, bot, sep = "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    elif model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        preprompt, user, bot, sep = openchat_preprompt, "<human>: ", "<bot>: ", "\n"
    else:
        # Generic fallback format used by all remaining models.
        preprompt, user, bot, sep = "", "User: ", "Assistant: ", "\n"
    return preprompt, user, bot, sep
44
+
45
+
46
def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """
    Stream a chat completion for *inputs* and yield incremental UI updates.

    Args:
        model: Hub model id; selects the client and prompt format.
        inputs: The new user utterance.
        typical_p, top_p, temperature, top_k, repetition_penalty, watermark:
            Sampling options forwarded to ``generate_stream``.
        chatbot: Prior (user, bot) message pairs.
        history: Flat list of alternating user/bot strings; mutated in place.

    Yields:
        (chat, history): list of (user, bot) display pairs and the updated
        flat history, emitted once per generated token.
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    history.append(inputs)

    # Rebuild the prompt from prior turns, normalizing each turn so it
    # carries the expected role prefixes and separator.
    past = []
    for data in chatbot:
        user_data, model_data = data

        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    # Full prompt: preprompt + prior turns + new input + bot-turn opener.
    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    partial_words = ""

    if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
        # OASST endpoints only take typical-p sampling here.
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=1000,
            watermark=watermark,
            max_new_tokens=500,
        )
    else:
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if top_p < 1.0 else None,
            top_k=top_k,
            truncate=1000,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    for i, response in enumerate(iterator):
        if response.token.special:
            continue

        partial_words = partial_words + response.token.text
        # BUG FIX: the original used str.rstrip(marker), which strips any
        # trailing *characters* that appear in the marker rather than the
        # marker as a suffix, and can eat legitimate output. Remove the
        # exact suffix instead (slice form also works on Python < 3.9).
        for marker in (user_name.rstrip(), assistant_name.rstrip()):
            if marker and partial_words.endswith(marker):
                partial_words = partial_words[: -len(marker)]

        if i == 0:
            history.append(" " + partial_words)
        elif response.token.text not in user_name:
            # NOTE(review): substring test against user_name looks like a
            # guard against echoing the role prefix — confirm intent.
            history[-1] = partial_words

        # Pair up the flat history into (user, bot) tuples for display.
        chat = [
            (history[j].strip(), history[j + 1].strip())
            for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history
122
+
123
+
124
def reset_textbox():
    """Return a component update that clears the textbox value.

    NOTE(review): ``gr`` is the Gradio namespace and is not imported in the
    visible code; this looks like a leftover from a Gradio version of this
    app and is not called by the Streamlit UI shown here — confirm before
    removing.
    """
    return gr.update(value="")
126
+
127
+
128
def radio_on_change(
    value: str,
    disclaimer,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """
    Reconfigure the sampling-parameter widgets when the model choice changes.

    Every argument after *value* is a UI component; its ``update(...)``
    payload replaces the component in the returned tuple.

    Returns:
        tuple: (disclaimer, typical_p, top_p, top_k, temperature,
        repetition_penalty, watermark) update payloads.
    """
    oasst_models = (
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    )
    if value in oasst_models:
        # OASST models expose only typical-p sampling.
        typ = typical_p.update(value=0.2, visible=True)
        nucleus = top_p.update(visible=False)
        k = top_k.update(visible=False)
        temp = temperature.update(visible=False)
        note = disclaimer.update(visible=False)
        rep = repetition_penalty.update(visible=False)
        wm = watermark.update(False)
        return note, typ, nucleus, k, temp, rep, wm
    if value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        # OpenChat: nucleus/top-k sampling plus its disclaimer banner.
        typ = typical_p.update(visible=False)
        nucleus = top_p.update(value=0.25, visible=True)
        k = top_k.update(value=50, visible=True)
        temp = temperature.update(value=0.6, visible=True)
        rep = repetition_penalty.update(value=1.01, visible=True)
        wm = watermark.update(False)
        note = disclaimer.update(visible=True)
        return note, typ, nucleus, k, temp, rep, wm
    # Default configuration for all remaining models.
    typ = typical_p.update(visible=False)
    nucleus = top_p.update(value=0.95, visible=True)
    k = top_k.update(value=4, visible=True)
    temp = temperature.update(value=0.5, visible=True)
    rep = repetition_penalty.update(value=1.03, visible=True)
    wm = watermark.update(True)
    note = disclaimer.update(visible=False)
    return note, typ, nucleus, k, temp, rep, wm
171
+
172
+
173
# Page heading (raw HTML fragment).
title = """<h1 align="center">Large Language Model Chat API</h1>"""
# Markdown intro shown on the page; includes a fenced prompt example.
description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
```
User: <utterance>
Assistant: <utterance>
User: <utterance>
Assistant: <utterance>
...
```
In this app, you can explore the outputs of multiple LLMs when prompted in this way.
"""

# Footer crediting the inference backend (raw HTML fragment).
text_generation_inference = """
<div align="center">Powered by: <a href=https://github.com/huggingface/text-generation-inference>Text Generation Inference</a></div>
"""

# Pointer to the official OpenChatKit space (raw HTML fragment).
openchat_disclaimer = """
<div align="center">Checkout the official <a href=https://huggingface.co/spaces/togethercomputer/OpenChatKit>OpenChatKit feedback app</a> for the full experience.</div>
"""
192
import streamlit as st

# Global CSS for the chat layout.
st.markdown(
    """
    <style>
    #col_container {margin-left: auto; margin-right: auto;}
    #chatbot {height: 520px; overflow: auto;}
    </style>
    """,
    unsafe_allow_html=True
)

# Title and description.
# BUG FIX: `title` is an HTML <h1> fragment; st.title() would render the
# raw markup as literal text. Emit both HTML fragments through st.markdown
# with unsafe_allow_html so they actually render.
st.markdown(title, unsafe_allow_html=True)
st.markdown(text_generation_inference, unsafe_allow_html=True)

# Model selection
model = st.radio(
    "Model",
    [
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "google/flan-t5-xxl",
        "google/flan-ul2",
        "bigscience/bloom",
        "bigscience/bloomz",
        "EleutherAI/gpt-neox-20b",
    ]
)

# Input textbox
input_text = st.text_input(label="Type an input and press Enter", placeholder="Hi there!")

# Sampling parameters (kept collapsed by default).
with st.expander("Parameters", expanded=False):
    typical_p = st.slider("Typical P mass", min_value=0.0, max_value=1.0, value=0.2, step=0.05)
    top_p = st.slider("Top-p (nucleus sampling)", min_value=0.0, max_value=1.0, value=0.25, step=0.05)
    temperature = st.slider("Temperature", min_value=0.0, max_value=5.0, value=0.6, step=0.1)
    top_k = st.slider("Top-k", min_value=1, max_value=50, value=50, step=1)
    repetition_penalty = st.slider("Repetition Penalty", min_value=0.1, max_value=3.0, value=1.03, step=0.01)
    watermark = st.checkbox("Text watermarking", value=False)

# Conversation state must survive Streamlit's script reruns.
if "chatbot" not in st.session_state:
    st.session_state.chatbot = []
if "history" not in st.session_state:
    st.session_state.history = []

# Submit button
if st.button("Submit") and input_text:
    # BUG FIX: predict() is a generator whose signature ends with
    # (..., chatbot, history); the original called it without those two
    # arguments (TypeError) and never iterated it, so no generation could
    # happen. Drain the stream and keep the last yielded state.
    result = None
    for result in predict(
        model,
        input_text,
        typical_p,
        top_p,
        temperature,
        top_k,
        repetition_penalty,
        watermark,
        st.session_state.chatbot,
        st.session_state.history,
    ):
        pass
    if result is not None:
        chat, st.session_state.history = result
        st.session_state.chatbot = chat
        # Render the finished conversation.
        for user_msg, bot_msg in chat:
            st.markdown(f"**User:** {user_msg}")
            st.markdown(f"**Assistant:** {bot_msg}")

# Reset button
if st.button("Reset"):
    # BUG FIX: reassigning input_text cannot clear the widget; drop the
    # stored conversation state instead so the next submit starts fresh.
    st.session_state.chatbot = []
    st.session_state.history = []
    input_text = ""

# Description
st.markdown(description)
251
# Streamlit UI
# NOTE(review): this duplicates the fragment earlier in the file —
# presumably diff-context duplication from the scrape; confirm against
# the real app.py before deduplicating.
st.title("Chat App with Hugging Face")
# Single-line chat input; value is captured on Enter.
user_input = st.text_input("You:", "")