Spaces:
Running
on
T4
Running
on
T4
Rework demo UI.
Browse files
app.py
CHANGED
|
@@ -110,22 +110,22 @@ Arrange the given numbers in ascending order.
|
|
| 110 |
["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
|
| 111 |
]
|
| 112 |
|
| 113 |
-
infer_interface = gr.Interface(
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
).queue()
|
| 129 |
|
| 130 |
########################################################################################################
|
| 131 |
|
|
@@ -159,8 +159,12 @@ She also likes to tell {user} a lot about herself and her opinions, and she usua
|
|
| 159 |
|
| 160 |
_, intro_state = model.forward(pipeline.encode(chat_intro), None)
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
def chat(
|
| 163 |
-
|
| 164 |
history,
|
| 165 |
token_count=10,
|
| 166 |
temperature=1.0,
|
|
@@ -174,6 +178,7 @@ def chat(
|
|
| 174 |
token_ban=[], # ban the generation of some tokens
|
| 175 |
token_stop=[]) # stop generation whenever you see any token here
|
| 176 |
|
|
|
|
| 177 |
message = message.strip(' ')
|
| 178 |
message = message.replace('\n', '')
|
| 179 |
ctx = f"{user}{interface} {message}\n\n{bot}{interface}"
|
|
@@ -181,9 +186,9 @@ def chat(
|
|
| 181 |
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
|
| 182 |
print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
|
| 183 |
|
| 184 |
-
history = history or [
|
| 185 |
|
| 186 |
-
[
|
| 187 |
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
|
| 188 |
|
| 189 |
begin = len(all_tokens)
|
|
@@ -230,35 +235,80 @@ def chat(
|
|
| 230 |
gc.collect()
|
| 231 |
torch.cuda.empty_cache()
|
| 232 |
|
| 233 |
-
|
| 234 |
-
history = [
|
| 235 |
-
return
|
| 236 |
-
|
| 237 |
-
chat_interface = gr.Interface(
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
).queue()
|
| 255 |
|
| 256 |
########################################################################################################
|
| 257 |
|
| 258 |
-
demo = gr.TabbedInterface(
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
demo.queue(max_size=10)
|
| 264 |
-
demo.launch(share=
|
|
|
|
| 110 |
["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
|
| 111 |
]
|
| 112 |
|
| 113 |
+
# infer_interface = gr.Interface(
|
| 114 |
+
# fn=infer,
|
| 115 |
+
# description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
|
| 116 |
+
# allow_flagging="never",
|
| 117 |
+
# inputs=[
|
| 118 |
+
# gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"), # prompt
|
| 119 |
+
# gr.Slider(10, 200, step=10, value=150), # token_count
|
| 120 |
+
# gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
|
| 121 |
+
# gr.Slider(0.0, 1.0, step=0.05, value=0.7), # top_p
|
| 122 |
+
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presencePenalty
|
| 123 |
+
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # countPenalty
|
| 124 |
+
# ],
|
| 125 |
+
# outputs=gr.Textbox(label="Generated Output", lines=28),
|
| 126 |
+
# examples=examples,
|
| 127 |
+
# cache_examples=False,
|
| 128 |
+
# ).queue()
|
| 129 |
|
| 130 |
########################################################################################################
|
| 131 |
|
|
|
|
| 159 |
|
| 160 |
_, intro_state = model.forward(pipeline.encode(chat_intro), None)
|
| 161 |
|
| 162 |
+
def user(user_message, chatbot):
|
| 163 |
+
chatbot = chatbot or []
|
| 164 |
+
return "", chatbot + [[user_message, None]]
|
| 165 |
+
|
| 166 |
def chat(
|
| 167 |
+
chatbot,
|
| 168 |
history,
|
| 169 |
token_count=10,
|
| 170 |
temperature=1.0,
|
|
|
|
| 178 |
token_ban=[], # ban the generation of some tokens
|
| 179 |
token_stop=[]) # stop generation whenever you see any token here
|
| 180 |
|
| 181 |
+
message = chatbot[-1][0]
|
| 182 |
message = message.strip(' ')
|
| 183 |
message = message.replace('\n', '')
|
| 184 |
ctx = f"{user}{interface} {message}\n\n{bot}{interface}"
|
|
|
|
| 186 |
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
|
| 187 |
print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
|
| 188 |
|
| 189 |
+
history = history or [intro_state, []] # [chat, state, all_tokens]
|
| 190 |
|
| 191 |
+
[state, all_tokens] = history
|
| 192 |
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
|
| 193 |
|
| 194 |
begin = len(all_tokens)
|
|
|
|
| 235 |
gc.collect()
|
| 236 |
torch.cuda.empty_cache()
|
| 237 |
|
| 238 |
+
chatbot[-1][1] = out_str.strip()
|
| 239 |
+
history = [state, all_tokens]
|
| 240 |
+
return chatbot, history
|
| 241 |
+
|
| 242 |
+
# chat_interface = gr.Interface(
|
| 243 |
+
# fn=chat,
|
| 244 |
+
# description=f'''You are {user}, bot is {bot}.''',
|
| 245 |
+
# allow_flagging="never",
|
| 246 |
+
# inputs = [
|
| 247 |
+
# gr.Textbox(label="Message"),
|
| 248 |
+
# "state",
|
| 249 |
+
# gr.Slider(10, 1000, step=10, value=250), # token_count
|
| 250 |
+
# gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
|
| 251 |
+
# gr.Slider(0.0, 1.0, step=0.05, value=0.8), # top_p
|
| 252 |
+
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presence_penalty
|
| 253 |
+
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # count_penalty
|
| 254 |
+
# ],
|
| 255 |
+
# outputs=[
|
| 256 |
+
# gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
|
| 257 |
+
# "state"
|
| 258 |
+
# ]
|
| 259 |
+
# ).queue()
|
| 260 |
|
| 261 |
########################################################################################################
|
| 262 |
|
| 263 |
+
# demo = gr.TabbedInterface(
|
| 264 |
+
# [infer_interface, chat_interface], ["Generative", "Chat"],
|
| 265 |
+
# title=title,
|
| 266 |
+
# )
|
| 267 |
+
|
| 268 |
+
# demo.queue(max_size=10)
|
| 269 |
+
# demo.launch(share=True)
|
| 270 |
+
|
| 271 |
+
with gr.Blocks() as demo:
|
| 272 |
+
with gr.Tab("Generative"):
|
| 273 |
+
with gr.Row():
|
| 274 |
+
with gr.Column():
|
| 275 |
+
prompt = gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n")
|
| 276 |
+
token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
|
| 277 |
+
temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
|
| 278 |
+
top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
|
| 279 |
+
presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
|
| 280 |
+
count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
|
| 281 |
+
with gr.Column():
|
| 282 |
+
with gr.Row():
|
| 283 |
+
submit = gr.Button("Submit")
|
| 284 |
+
clear = gr.Button("Clear")
|
| 285 |
+
output = gr.Textbox(label="Generated Output", lines=28)
|
| 286 |
+
data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Prompts", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
|
| 287 |
+
submit.click(infer, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
|
| 288 |
+
clear.click(lambda: None, [], [output])
|
| 289 |
+
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
|
| 290 |
+
with gr.Tab("Chat"):
|
| 291 |
+
with gr.Row():
|
| 292 |
+
with gr.Column():
|
| 293 |
+
chatbot = gr.Chatbot()
|
| 294 |
+
state = gr.State()
|
| 295 |
+
message = gr.Textbox(label="Message")
|
| 296 |
+
with gr.Row():
|
| 297 |
+
send = gr.Button("Send")
|
| 298 |
+
clear = gr.Button("Clear")
|
| 299 |
+
with gr.Column():
|
| 300 |
+
token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
|
| 301 |
+
temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
|
| 302 |
+
top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
|
| 303 |
+
presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
|
| 304 |
+
count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
|
| 305 |
+
message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(
|
| 306 |
+
chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
|
| 307 |
+
)
|
| 308 |
+
send.click(user, [message, chatbot], [message, chatbot], queue=False).then(
|
| 309 |
+
chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
|
| 310 |
+
)
|
| 311 |
+
clear.click(lambda: ([], None, ""), [], [chatbot, state, message])
|
| 312 |
|
| 313 |
demo.queue(max_size=10)
|
| 314 |
+
demo.launch(share=False)
|