|
|
import gradio as gr

from dotenv import load_dotenv

# Project-local registry of chat model backends (each entry supplies a
# "name" and a streaming "model" callable — see process_responses below).
from models import get_all_models, get_random_models

# Load API keys / configuration from a local .env file into the environment.
load_dotenv()
|
|
|
|
|
|
|
|
# Client-side JavaScript for the "Share Image" button: renders the
# #share-region-annoy container to a canvas with html2canvas (expected to be
# loaded in the page) and triggers a PNG download named guardrails-arena.png.
share_js = """
function () {
    const captureElement = document.querySelector('#share-region-annoy');
    // console.log(captureElement);
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'guardrails-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [];
}
"""
|
|
|
|
|
|
|
|
def activate_chat_buttons():
    """Return enabled replacements for the regenerate and clear buttons.

    Used as a `.then()` callback after a chat round completes so the user
    can regenerate the last answer or start a new round.
    """
    enabled_regenerate = gr.Button(
        value="🔄 Regenerate",
        interactive=True,
        elem_id="regenerate_btn",
    )
    enabled_clear = gr.ClearButton(elem_id="clear_btn", interactive=True)
    return enabled_regenerate, enabled_clear
|
|
|
|
|
|
|
|
def deactivate_chat_buttons():
    """Return disabled replacements for the regenerate and clear buttons.

    Invoked when the chat is cleared, so neither button can fire until a
    new conversation round has produced something to act on.
    """
    disabled_regenerate = gr.Button(
        value="🔄 Regenerate",
        interactive=False,
        elem_id="regenerate_btn",
    )
    disabled_clear = gr.ClearButton(elem_id="clear_btn", interactive=False)
    return disabled_regenerate, disabled_clear
|
|
|
|
|
|
|
|
def handle_message(
    llms, user_input, temperature, top_p, max_output_tokens, states1, states2, states3, states4
):
    """Append the user's message to all four histories and stream replies.

    Yields 8-tuples (history1..4, states1..4) as chunks arrive, delegating
    the actual round-robin streaming to process_responses.
    """
    states = [states1, states2, states3, states4]
    # NOTE(review): reads `.value` off each state object — assumes the handler
    # receives gr.State instances rather than raw values; confirm against the
    # Gradio version in use.
    history = [s.value if s else [] for s in states]

    # Register the new user turn with a pending (None) assistant reply.
    for conversation in history:
        conversation.append((user_input, None))

    # process_responses yields the same 8-tuple shape we must emit, so the
    # stream can be forwarded directly.
    yield from process_responses(
        llms, temperature, top_p, max_output_tokens, history, states
    )
|
|
|
|
|
|
|
|
def regenerate_message(llms, temperature, top_p, max_output_tokens, states1, states2, states3, states4):
    """Drop the last exchange from every history and re-ask the same prompt.

    Recovers the most recent user input from the first history, removes the
    stale (prompt, answer) pair from all four, then streams fresh replies via
    process_responses, yielding 8-tuples (history1..4, states1..4).
    """
    states = [states1, states2, states3, states4]
    # NOTE(review): assumes the handler receives gr.State instances exposing
    # `.value` — confirm against the Gradio version in use.
    history = [s.value if s else [] for s in states]

    # The user prompt is taken from the first history's popped turn; the
    # remaining histories just discard their stale last turn.
    user_input = history[0].pop()[0] if history[0] else None
    for conversation in history[1:]:
        if conversation:
            conversation.pop()

    # Re-register the prompt with a pending (None) assistant reply.
    for conversation in history:
        conversation.append((user_input, None))

    yield from process_responses(
        llms, temperature, top_p, max_output_tokens, history, states
    )
|
|
|
|
|
|
|
|
def process_responses(llms, temperature, top_p, max_output_tokens, history, states):
    """Stream replies from all four models in round-robin fashion.

    Each model's generator is advanced one chunk per pass so the four
    chatbots update quasi-simultaneously. After every received chunk the
    corresponding history's pending answer is replaced with the text
    accumulated so far, and an 8-tuple (history1..4, states1..4) is yielded;
    one final tuple is yielded when every generator is exhausted.

    Args:
        llms: sequence of four dicts whose "model" key is a generator
            factory taking (history, temperature, top_p, max_output_tokens).
        temperature, top_p, max_output_tokens: sampling parameters forwarded
            to each model.
        history: list of four chat histories (lists of (user, answer) pairs);
            mutated in place.
        states: list of four state slots, rebound to fresh gr.State objects
            as histories change.
    """
    generators = [
        llms[i]["model"](history[i], temperature, top_p, max_output_tokens)
        for i in range(4)
    ]

    responses = [[], [], [], []]
    done = [False, False, False, False]

    while not all(done):
        for i in range(4):
            if done[i]:
                continue
            try:
                response = next(generators[i])
            except StopIteration:
                done[i] = True
                continue
            if response:
                responses[i].append(response)
                # Replace the pending answer with everything streamed so far.
                history[i][-1] = (history[i][-1][0], "".join(responses[i]))
                states[i] = gr.State(history[i])
                yield history[0], history[1], history[2], history[3], states[0], states[1], states[2], states[3]

    # Final emission so the UI reflects the fully completed conversations.
    yield history[0], history[1], history[2], history[3], states[0], states[1], states[2], states[3]
|
|
|
|
|
|
|
|
with gr.Blocks(
    title="Cherokee Language Function Test",
    theme=gr.themes.Soft(secondary_hue=gr.themes.colors.sky, neutral_hue=gr.themes.colors.stone),
) as demo:
    # Four models are compared side by side. One gr.State per pane holds that
    # pane's conversation history; `models` holds the currently drawn model set.
    num_sides = 4
    states = [gr.State() for _ in range(num_sides)]
    chatbots = [None] * num_sides
    models = gr.State(get_random_models)
    all_models = get_all_models()

    gr.Markdown(
        "# Cherokee Language Preserve Model V0.4 \n\nChat with multiple models at the same time and compare their responses. "
    )

    # The share-image JS (share_js) captures this container by its elem_id.
    with gr.Group(elem_id="share-region-annoy"):
        with gr.Row():
            for i in range(num_sides):
                label = models.value[i]["name"]
                with gr.Column(scale=1, min_width=200):
                    chatbots[i] = gr.Chatbot(
                        label=label,
                        # NOTE(review): all four panes share this elem_id so
                        # common CSS applies; kept as-is for compatibility.
                        elem_id="chatbot",
                        height=300,
                        show_copy_button=True,
                    )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="Enter your query and press ENTER",
            elem_id="input_box",
            scale=4,
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)

    with gr.Row() as button_row:
        # ClearButton wipes both the visible chatbots and their state slots.
        clear_btn = gr.ClearButton(
            value="🎲 New Round",
            elem_id="clear_btn",
            interactive=False,
            components=chatbots + states,
        )
        regenerate_btn = gr.Button(
            value="🔄 Regenerate", interactive=False, elem_id="regenerate_btn"
        )
        share_btn = gr.Button(value="📷 Share Image")

    with gr.Row():
        examples = gr.Examples(
            [
                "Tell me a story",
                "What is the capital of France?",
                "Do you like me?",
            ],
            inputs=[textbox],
            label="Example task: General skill",
        )
    with gr.Row():
        examples = gr.Examples(
            [
                "translate: ᎧᏃᎮᏍᎩ",
                "Could you assist in rendering this Cherokee word into English?\nᎤᏲᎢ",
                "translate the following Cherokee word into English. ᏧᎩᏨᏅᏓ",
            ],
            inputs=[textbox],
            label="Example task: Translate words",
        )
    with gr.Row():
        examples = gr.Examples(
            [
                "translate: ᏚᏁᏤᎴᏃ ᎬᏩᏍᏓᏩᏗᏙᎯ, ᎾᏍᎩ ᏥᏳ ᎤᎦᏘᏗᏍᏗᏱ, ᎤᏂᏣᏘ ᎨᏒ ᎢᏳᏍᏗ, ᎾᏍᎩ ᎬᏩᏁᏄᎳᏍᏙᏗᏱ ᏂᎨᏒᎾ",
                "translate following Cherokee sentences into English.\nᏥᏌᏃ ᎤᏓᏅᏎ ᏚᏘᏅᏎ ᎬᏩᏍᏓᏩᏗᏙᎯ ᎥᏓᎵ ᏭᏂᎶᏎᎢ; ᎤᏂᏣᏘᏃ ᎬᏩᏍᏓᏩᏛᏎᎢ, ᏅᏓᏳᏂᎶᏒᎯ ᎨᎵᎵ, ᎠᎴ ᏧᏗᏱ,",
                "translate following sentences.\nᎯᎠᏃ ᏄᏪᏎᎴ ᎠᏍᎦᏯ ᎤᏬᏰᏂ ᎤᏩᎢᏎᎸᎯ; ᎠᏰᎵ ᎭᎴᎲᎦ.",
            ],
            inputs=[textbox],
            label="Example task: Translate sentences",
        )

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.01,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=4096,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    # -- Event wiring ---------------------------------------------------
    # ENTER in the textbox and the Send button both stream a new round,
    # then re-enable the regenerate/clear buttons.
    textbox.submit(
        handle_message,
        inputs=[
            models,
            textbox,
            temperature,
            top_p,
            max_output_tokens,
            states[0],
            states[1],
            states[2],
            states[3],
        ],
        outputs=[chatbots[0], chatbots[1], chatbots[2], chatbots[3], states[0], states[1], states[2], states[3]],
    ).then(
        activate_chat_buttons,
        inputs=[],
        outputs=[regenerate_btn, clear_btn],
    )

    send_btn.click(
        handle_message,
        inputs=[
            models,
            textbox,
            temperature,
            top_p,
            max_output_tokens,
            states[0],
            states[1],
            states[2],
            states[3],
        ],
        outputs=[chatbots[0], chatbots[1], chatbots[2], chatbots[3], states[0], states[1], states[2], states[3]],
    ).then(
        activate_chat_buttons,
        inputs=[],
        outputs=[regenerate_btn, clear_btn],
    )

    regenerate_btn.click(
        regenerate_message,
        inputs=[
            models,
            temperature,
            top_p,
            max_output_tokens,
            states[0],
            states[1],
            states[2],
            states[3],
        ],
        outputs=[chatbots[0], chatbots[1], chatbots[2], chatbots[3], states[0], states[1], states[2], states[3]],
    )

    # New round: disable the buttons, then draw a fresh random model set.
    clear_btn.click(
        deactivate_chat_buttons,
        inputs=[],
        outputs=[regenerate_btn, clear_btn],
    ).then(get_random_models, inputs=None, outputs=[models])

    # Screenshot/share runs entirely client-side.
    share_btn.click(None, inputs=[], outputs=[], js=share_js)
|
|
|
|
|
if __name__ == "__main__":
    # Allow up to 10 concurrent streaming sessions.
    demo.queue(default_concurrency_limit=10)
    # Fix: "127.0.01" was a malformed loopback address; bind to 127.0.0.1.
    # share=True additionally exposes a public Gradio tunnel link.
    demo.launch(server_name="127.0.0.1", server_port=5009, share=True)