# NOTE: the original paste carried Hugging Face Space page residue here
# ("Spaces:" / "Runtime error" status lines); preserved as a comment so the
# module parses.
| #from utils.multiple_stream import create_interface | |
| import random | |
| import gradio as gr | |
| import json | |
| from utils.data import dataset | |
| from utils.multiple_stream import stream_data | |
| from pages.summarization_playground import get_model_batch_generation | |
| from pages.summarization_playground import custom_css | |
# Winning choice picked by the user; written by the vote callback.
# NOTE: the original line was `global global_selected_choice`, which is a
# no-op at module scope and left the name undefined until a vote happened
# (NameError on any earlier read). Initialize it explicitly instead.
global_selected_choice = None
def random_data_selection():
    """Pick one random entry from the dataset and format it for display.

    Returns:
        str: the entry's section text followed by its dialogue.
    """
    sample = random.choice(dataset)
    # assumes each entry carries string 'section_text' and 'dialogue' fields
    return f"{sample['section_text']}\n\nDialogue:\n{sample['dialogue']}"
def lock_selection(selected_option):
    """Record the user's pick and freeze the voting widgets.

    Stores the choice in the module-level ``global_selected_choice``, then
    returns updates that reveal the result box, echo the choice into it, and
    disable both the radio and the submit button so the vote cannot change.
    """
    global global_selected_choice
    global_selected_choice = selected_option
    return (
        gr.update(visible=True),       # reveal the result textbox
        selected_option,               # echo the choice into it
        gr.update(interactive=False),  # freeze the radio
        gr.update(interactive=False),  # freeze the submit button
    )
def create_arena():
    """Build the prompt-comparison arena UI.

    Loads candidate prompts from ``prompt/prompt.json``, streams model
    responses for three randomly selected prompts side by side, and records
    the user's winning vote back into the JSON file.

    Fixes vs. the original:
      * Vote recording ran at *build* time (before any user choice existed),
        raising immediately; it now runs inside the submit callback.
      * ``prompt_id`` was assigned a whole prompt dict and compared against
        ``id`` fields (never matched); it now uses the prompt's ``'id'``.
      * The streamed contents captured ``data_textbox.value`` at build time,
        so "Change Data" had no effect; the current value is now passed to
        the streaming callback.

    Returns:
        gr.Blocks: the assembled Gradio demo.
    """
    with open("prompt/prompt.json", "r") as file:
        prompts = json.load(file)

    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm", text_size="sm"),
                   css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.
Once the streaming is complete, you can choose the best response.\u2764\ufe0f""")

            data_textbox = gr.Textbox(label="Data", lines=10,
                                      placeholder="Datapoints to test...",
                                      value=datapoint)

            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                submit_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection,
                inputs=[],
                outputs=[data_textbox],
            )

            # Shuffle all prompts, compete the first three.
            random.shuffle(prompts)
            random_selected_prompts = prompts[:3]

            with gr.Row():
                columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10)
                           for i in range(len(random_selected_prompts))]

            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming(data):
                # Build contents from the *current* textbox value so that
                # "Change Data" is reflected in the streamed responses.
                content_list = [
                    prompt['prompt'] + '\n{' + data + '}\n\nsummary:'
                    for prompt in random_selected_prompts
                ]
                for chunk in stream_data(content_list, model):
                    yield tuple(gr.update(value=chunk[i])
                                for i in range(len(columns)))

            submit_button.click(
                fn=start_streaming,
                inputs=[data_textbox],
                outputs=columns,
                show_progress=False,
            )

            choice = gr.Radio(label="Choose the best response:",
                              choices=["Response 1", "Response 2", "Response 3"])
            vote_button = gr.Button("Submit")
            # Hidden until a vote is submitted.
            output = gr.Textbox(label="You selected:", visible=False)

            def record_vote(selected_option):
                # Map the displayed label back to the competing prompt.
                index = {"Response 1": 0,
                         "Response 2": 1,
                         "Response 3": 2}.get(selected_option)
                if index is None:
                    raise ValueError(
                        f"No corresponding response of {selected_option}")
                prompt_id = random_selected_prompts[index]['id']
                for prompt in prompts:
                    if prompt['id'] == prompt_id:
                        prompt["metric"]["winning_number"] += 1
                        break
                else:
                    raise ValueError(f"No prompt of id {prompt_id}")
                # Persist the updated win counts.
                with open("prompt/prompt.json", "w") as f:
                    json.dump(prompts, f)
                # Reveal the result box, echo the choice, and lock the widgets.
                return (gr.update(visible=True), selected_option,
                        gr.update(interactive=False),
                        gr.update(interactive=False))

            vote_button.click(fn=record_vote, inputs=choice,
                              outputs=[output, output, choice, vote_button])

    return demo
| if __name__ == "__main__": | |
| demo = create_arena() | |
| demo.queue() | |
| demo.launch() | |