cherokee / app2.py
amber
1
777015e
import gradio as gr
from dotenv import load_dotenv
from models import get_all_models, get_random_models
# Load variables from a local .env file into the process environment.
# NOTE(review): this runs after `models` is imported on the line above; if
# that module reads environment variables at import time it will not see
# values from .env — confirm, or move this call above that import.
load_dotenv()
# Browser-side handler for the "📷 Share Image" button (it is passed as the
# `js=` argument of share_btn.click): renders the element with id
# "share-region-annoy" to a canvas via html2canvas and downloads it as a PNG.
# Nothing in this app loads html2canvas into the page, so the script injects
# it from a CDN the first time the button is clicked.
share_js = """
function () {
  const capture = () => {
    const captureElement = document.querySelector('#share-region-annoy');
    html2canvas(captureElement)
      .then(canvas => {
        canvas.style.display = 'none'
        document.body.appendChild(canvas)
        return canvas
      })
      .then(canvas => {
        const image = canvas.toDataURL('image/png')
        const a = document.createElement('a')
        a.setAttribute('download', 'guardrails-arena.png')
        a.setAttribute('href', image)
        a.click()
        canvas.remove()
      });
  };
  // html2canvas is not bundled with Gradio: load it on first use.
  if (typeof html2canvas === 'undefined') {
    const script = document.createElement('script');
    script.src = 'https://cdn.jsdelivr.net/npm/html2canvas@1.4.1/dist/html2canvas.min.js';
    script.onload = capture;
    script.onerror = () => console.error('Failed to load html2canvas');
    document.head.appendChild(script);
  } else {
    capture();
  }
  return [];
}
"""
def _chat_buttons(interactive):
    """Build the (Regenerate, Clear) button pair with the given interactivity.

    The returned components are used as event outputs for the existing
    ``regenerate_btn`` and ``clear_btn``; only the arguments passed here
    (value/elem_id/interactive) are supplied.
    """
    return (
        gr.Button(
            value="🔄 Regenerate",
            interactive=interactive,
            elem_id="regenerate_btn",
        ),
        gr.ClearButton(interactive=interactive, elem_id="clear_btn"),
    )


def activate_chat_buttons():
    """Return an enabled (Regenerate, Clear) button pair."""
    return _chat_buttons(True)


def deactivate_chat_buttons():
    """Return a disabled (Regenerate, Clear) button pair."""
    return _chat_buttons(False)
def handle_message(
    llms, user_input, temperature, top_p, max_output_tokens, states1, states2
):
    """Send ``user_input`` to both models and stream their replies side by side.

    Args:
        llms: The session's two model entries (dicts whose ``"model"``
            callable is invoked by ``process_responses``).
        user_input: The message the user just submitted.
        temperature: Temperature value forwarded to both models.
        top_p: Top-p value forwarded to both models.
        max_output_tokens: Output-token limit forwarded to both models.
        states1: Left-pane state: ``None`` before the first turn, otherwise an
            object whose ``.value`` is that pane's list of ``(user, bot)``
            turns (``process_responses`` stores ``gr.State(history)`` here).
        states2: Right-pane state, same shape as ``states1``.

    Yields:
        ``(history1, history2, state1, state2)`` tuples, one per streamed
        update, exactly as produced by ``process_responses``.
    """
    # Work on copies: process_responses rewrites the lists it receives, and
    # editing the incoming state's list in place could leave a dangling
    # ``(user_input, None)`` turn inside that state if generation fails.
    history = [
        list(states1.value) if states1 else [],
        list(states2.value) if states2 else [],
    ]
    for hist in history:
        hist.append((user_input, None))
    yield from process_responses(
        llms, temperature, top_p, max_output_tokens, history, [states1, states2]
    )
def regenerate_message(llms, temperature, top_p, max_output_tokens, states1, states2):
    """Re-send the most recent prompt to both models, replacing their last replies.

    The last turn is dropped from each pane's history, and the left pane's
    last user message is appended again with no reply and streamed to both
    models.

    Args:
        llms: The session's two model entries (dicts whose ``"model"``
            callable is invoked by ``process_responses``).
        temperature: Temperature value forwarded to both models.
        top_p: Top-p value forwarded to both models.
        max_output_tokens: Output-token limit forwarded to both models.
        states1: Left-pane state: ``None`` or an object whose ``.value`` is
            the pane's list of ``(user, bot)`` turns.
        states2: Right-pane state, same shape as ``states1``.

    Yields:
        ``(history1, history2, state1, state2)`` tuples. If the left pane has
        no history there is nothing to regenerate, and a single tuple with
        the current histories and the unchanged state objects is yielded.
    """
    # Copies, so the incoming state objects are never modified in place.
    history1 = list(states1.value) if states1 else []
    history2 = list(states2.value) if states2 else []
    if not history1:
        # Nothing to regenerate: avoid prompting both models with ``None``.
        yield history1, history2, states1, states2
        return
    user_input = history1.pop()[0]
    if history2:
        history2.pop()
    history = [history1, history2]
    for hist in history:
        hist.append((user_input, None))
    yield from process_responses(
        llms, temperature, top_p, max_output_tokens, history, [states1, states2]
    )
def process_responses(llms, temperature, top_p, max_output_tokens, history, states):
    """Stream replies from two models side by side, interleaving their chunks.

    Both model callables are started up front. Each pass of the loop then
    pulls one chunk from every stream that has not finished, so the two panes
    fill in together.

    Args:
        llms: Sequence whose first two entries are dicts with a ``"model"``
            callable, invoked as
            ``model(history, temperature, top_p, max_output_tokens)`` and
            expected to return an iterator of text chunks.
        temperature, top_p, max_output_tokens: Forwarded unchanged to both
            models.
        history: Two chat histories (lists of ``(user, bot)`` tuples); the
            last turn of each is rewritten in place as chunks arrive.
        states: Two-element list; a slot is replaced in place with
            ``gr.State(history[side])`` whenever that side receives a
            non-empty chunk.

    Yields:
        ``(history[0], history[1], states[0], states[1])`` after every chunk
        pulled (empty chunks included), and once more after both streams end.
    """
    streams = [
        llms[side]["model"](history[side], temperature, top_p, max_output_tokens)
        for side in (0, 1)
    ]
    pieces = ([], [])
    finished = [False, False]
    while not (finished[0] and finished[1]):
        for side in (0, 1):
            if finished[side]:
                continue
            try:
                chunk = next(streams[side])
            except StopIteration:
                finished[side] = True
                continue
            if chunk:
                # Chunks are accumulated: the shown reply is everything
                # received so far for this side.
                pieces[side].append(chunk)
                prompt = history[side][-1][0]
                history[side][-1] = (prompt, "".join(pieces[side]))
                states[side] = gr.State(history[side])
            yield history[0], history[1], states[0], states[1]
    yield history[0], history[1], states[0], states[1]
# Two-pane arena UI. Gradio lays components out in the order they are created
# inside these context managers, so statement order below defines the layout.
with gr.Blocks(
    title="Chatbot Arena",
    # theme=gr.themes.Soft(secondary_hue=gr.themes.colors.sky),
    theme=gr.themes.Base(),
) as demo:
    num_sides = 2
    # One chat-history state per pane, read and written by the event handlers.
    states = [gr.State() for _ in range(num_sides)]
    chatbots = [None] * num_sides
    # Model pair for the session. Each entry is a dict read with ["name"] for
    # the pane label and ["model"] by process_responses.
    models = gr.State(get_random_models)
    # Only referenced by the commented-out model list below.
    all_models = get_all_models()
    # gr.Markdown(
    #     "# Cherokee Language Preserve Model V0.4 \n\nChat with multiple models at the same time and compare their responses. "
    # )
    # Everything in this group is what the "Share Image" button captures
    # (share_js targets elem_id "share-region-annoy").
    with gr.Group(elem_id="share-region-annoy"):
        # with gr.Accordion(f"🔍 Expand to see the {len(all_models)} models", open=False):
        #     model_description_md = """| | | |\n| ---- | ---- | ---- |\n"""
        #     count = 0
        #     for model in all_models:
        #         if count % 3 == 0:
        #             model_description_md += "|"
        #         model_description_md += f" {model['name']} |"
        #         if count % 3 == 2:
        #             model_description_md += "\n"
        #         count += 1
        #     gr.Markdown(model_description_md, elem_id="model_description_markdown")
        with gr.Row():
            for i in range(num_sides):
                # NOTE(review): this label is the build-time value of `models`;
                # confirm it matches the pair each session actually gets on
                # page load (New Round relabels the panes further below).
                label = models.value[i]["name"]
                with gr.Column():
                    # NOTE(review): both panes share elem_id "chatbot", which
                    # yields duplicate DOM ids; left as-is in case styling
                    # targets it.
                    chatbots[i] = gr.Chatbot(
                        label=label,
                        elem_id="chatbot",
                        height=550,
                        show_copy_button=True,
                    )
    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="Enter your query and press ENTER",
            elem_id="input_box",
            scale=4,
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)
    with gr.Row() as button_row:
        # NOTE(review): no event handler is attached to upload_button in this
        # file, so uploaded files are not used yet.
        upload_button = gr.UploadButton(
            "Upload File", file_types=["image", "video"], file_count="multiple"
        )
        # "New Round": clears both panes and their history states. Starts
        # disabled; enabled once a message has been answered.
        clear_btn = gr.ClearButton(
            value="🎲 New Round",
            elem_id="clear_btn",
            interactive=False,
            components=chatbots + states,
        )
        regenerate_btn = gr.Button(
            value="🔄 Regenerate", interactive=False, elem_id="regenerate_btn"
        )
        share_btn = gr.Button(value="📷 Share Image")
    # Clicking an example copies it into the input textbox.
    with gr.Row():
        examples = gr.Examples(
            [
                "translate: ᎧᏃᎮᏍᎩ",
                "Could you assist in rendering this Cherokee word into English?\nᎤᏲᎢ",
                "translate the following Cherokee word into English. ᏧᎩᏨᏅᏓ",
            ],
            inputs=[textbox],
            label="Example task: Translate words",
        )
    with gr.Row():
        examples = gr.Examples(
            [
                "translate: ᏚᏁᏤᎴᏃ ᎬᏩᏍᏓᏩᏗᏙᎯ, ᎾᏍᎩ ᏥᏳ ᎤᎦᏘᏗᏍᏗᏱ, ᎤᏂᏣᏘ ᎨᏒ ᎢᏳᏍᏗ, ᎾᏍᎩ ᎬᏩᏁᏄᎳᏍᏙᏗᏱ ᏂᎨᏒᎾ",
                "translate following Cherokee sentences into English.\nᏥᏌᏃ ᎤᏓᏅᏎ ᏚᏘᏅᏎ ᎬᏩᏍᏓᏩᏗᏙᎯ ᎥᏓᎵ ᏭᏂᎶᏎᎢ; ᎤᏂᏣᏘᏃ ᎬᏩᏍᏓᏩᏛᏎᎢ, ᏅᏓᏳᏂᎶᏒᎯ ᎨᎵᎵ, ᎠᎴ ᏧᏗᏱ,",
                "translate following sentences.\nᎯᎠᏃ ᏄᏪᏎᎴ ᎠᏍᎦᏯ ᎤᏬᏰᏂ ᎤᏩᎢᏎᎸᎯ; ᎠᏰᎵ ᎭᎴᎲᎦ.",
            ],
            inputs=[textbox],
            label="Example task: Translate sentences",
        )
    # Sampling controls forwarded to both models on every send/regenerate.
    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=0.5,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=4096,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    # Outputs shared by every event that streams model replies.
    chat_outputs = [chatbots[0], chatbots[1], states[0], states[1]]
    send_inputs = [
        models,
        textbox,
        temperature,
        top_p,
        max_output_tokens,
        states[0],
        states[1],
    ]
    # Pressing Enter and clicking Send run the same pipeline: stream both
    # replies, then enable Regenerate / New Round.
    for trigger in (textbox.submit, send_btn.click):
        trigger(
            handle_message,
            inputs=send_inputs,
            outputs=chat_outputs,
        ).then(
            activate_chat_buttons,
            inputs=[],
            outputs=[regenerate_btn, clear_btn],
        )
    regenerate_btn.click(
        regenerate_message,
        inputs=[
            models,
            temperature,
            top_p,
            max_output_tokens,
            states[0],
            states[1],
        ],
        outputs=chat_outputs,
    )

    def _relabel_chatbots(selected):
        """Return Chatbot updates whose labels name the newly selected models.

        ``selected`` is the value of the ``models`` state; entry ``i`` is the
        model shown in pane ``i``.
        """
        return tuple(
            gr.Chatbot(label=selected[i]["name"]) for i in range(num_sides)
        )

    # New Round: disable the chat buttons, draw a fresh model pair, then
    # relabel the panes so the names shown match the models that will answer.
    clear_btn.click(
        deactivate_chat_buttons,
        inputs=[],
        outputs=[regenerate_btn, clear_btn],
    ).then(lambda: get_random_models(), inputs=None, outputs=[models]).then(
        _relabel_chatbots, inputs=[models], outputs=chatbots
    )
    # Runs entirely in the browser (fn=None); the capture logic is in share_js.
    share_btn.click(None, inputs=[], outputs=[], js=share_js)
if __name__ == "__main__":
    # Event listeners that don't set their own concurrency_limit may each run
    # up to 10 calls at once.
    demo.queue(default_concurrency_limit=10)
    # Listen on the IPv4 loopback address only; share=True additionally
    # requests a public Gradio share link.
    demo.launch(server_name="127.0.0.1", server_port=5009, share=True)