Spaces:
Running
Running
| import gradio as gr | |
| import json | |
| import requests | |
| import os | |
| from model_inference import Inference | |
| import time | |
HF_TOKEN = os.environ.get("HF_TOKEN")
# Fail fast with a clear message instead of the opaque TypeError that
# `str += None` would raise below when the token is missing.
if HF_TOKEN is None:
    raise RuntimeError("The HF_TOKEN environment variable must be set.")

question_selector_map = {}
# All models that have an entry in the inference-endpoint config file.
every_model = ["llama2", "llama2-chat", "vicuna", "falcon", "falcon-instruct", "orca", "wizardlm"]

# Endpoint config — assumed shape per usage below:
# {model: {"API_URL": str, "headers": {"Authorization": "Bearer <prefix>"}}}
# TODO(review): confirm against src/inference_endpoint.json.
with open("src/inference_endpoint.json", "r") as f:
    inference_endpoint = json.load(f)

# Complete each model's Authorization header with the actual token (once).
for model_name in every_model:
    inference_endpoint[model_name]["headers"]["Authorization"] += HF_TOKEN
def build_question_selector_map(questions):
    """Map a short preview label of each question to its full record.

    The preview format ("<n>: <first 128 chars>...") must match the
    strings offered in the dropdowns so selections resolve back to the
    underlying question dict.
    """
    # NOTE(review): previews here use q["question_id"] + 1 while
    # load_questions numbers by list position — these agree only if
    # question_id is the 0-based index; confirm against the result JSONs.
    return {
        f"{q['question_id'] + 1}: {q['question'][:128]}...": q
        for q in questions
    }
def math_display_question_answer(question, cot, request: gr.Request):
    """Look up a stored Math result and fan it out to the 11 output widgets.

    Returns, in order: round-1 responses (llama, wizardlm, orca),
    summarization 1, round-2 responses, summarization 2, and the
    final round-3 responses.
    """
    selector = math_cot_question_selector_map if cot else math_question_selector_map
    record = selector[question]
    agents = record["agent_response"]
    outputs = []
    for round_idx in range(3):
        outputs += [agents["llama"][round_idx], agents["wizardlm"][round_idx], agents["orca"][round_idx]]
        if round_idx < 2:  # no summarization after the final round
            outputs.append(record["summarization"][round_idx])
    return tuple(outputs)
def gsm_display_question_answer(question, cot, request: gr.Request):
    """Look up a stored GSM8K result and fan it out to the 11 output widgets.

    Returns, in order: round-1 responses (llama, wizardlm, orca),
    summarization 1, round-2 responses, summarization 2, and the
    final round-3 responses.
    """
    selector = gsm_cot_question_selector_map if cot else gsm_question_selector_map
    record = selector[question]
    agents = record["agent_response"]
    outputs = []
    for round_idx in range(3):
        outputs += [agents["llama"][round_idx], agents["wizardlm"][round_idx], agents["orca"][round_idx]]
        if round_idx < 2:  # no summarization after the final round
            outputs.append(record["summarization"][round_idx])
    return tuple(outputs)
def mmlu_display_question_answer(question, cot, request: gr.Request):
    """Look up a stored MMLU result and fan it out to the 11 output widgets.

    Returns, in order: round-1 responses (llama, wizardlm, orca),
    summarization 1, round-2 responses, summarization 2, and the
    final round-3 responses.
    """
    selector = mmlu_cot_question_selector_map if cot else mmlu_question_selector_map
    record = selector[question]
    agents = record["agent_response"]
    outputs = []
    for round_idx in range(3):
        outputs += [agents["llama"][round_idx], agents["wizardlm"][round_idx], agents["orca"][round_idx]]
        if round_idx < 2:  # no summarization after the final round
            outputs.append(record["summarization"][round_idx])
    return tuple(outputs)
def warmup(list_model, model_inference_endpoints=inference_endpoint):
    """Ping each selected model's inference endpoint so it spins up, then
    reveal the interaction widgets.

    Blocks for 5 minutes (endpoint cold-start time) before returning the
    gr.update dict that swaps the warm-up button for the input controls.

    Args:
        list_model: model display names selected in the CheckboxGroup.
        model_inference_endpoints: per-model endpoint config; headers
            already carry the completed Authorization token (set once at
            module import time).
    """
    for model in list_model:
        model = model.lower()
        endpoint = model_inference_endpoints[model]
        # BUGFIX: do NOT append HF_TOKEN to the Authorization header here —
        # it was already appended once at module import time, so doing it
        # again produced a doubled token and failed authentication.
        api_url = endpoint["API_URL"]
        headers = endpoint["headers"]
        # Fire-and-forget request purely to trigger the cold start; the
        # response body is intentionally ignored.
        requests.post(api_url, headers=headers, json={"inputs": "Hello. "})
    time.sleep(300)  # wait for the endpoints to finish spinning up (~5 min)
    return {
        model_list: gr.update(visible=False),
        options: gr.update(visible=True),
        inputbox: gr.update(visible=True),
        submit: gr.update(visible=True),
        warmup_button: gr.update(visible=False),
        welcome_message: gr.update(visible=True),
    }
def inference(model_list, question, API_KEY, cot, hf_token=HF_TOKEN):
    """Run the multi-agent debate on `question` with exactly three models.

    Raises gr.Error when the selection is not exactly three models.
    Returns a dict of gr.update values / response strings keyed by the
    output components defined in the UI.
    """
    if len(model_list) != 3:
        raise gr.Error("Please choose just '3' models! Neither more nor less!")
    # Lower-case in place so the endpoint-lookup keys match the config.
    model_list[:] = [name.lower() for name in model_list]
    model_response = Inference(model_list, question, API_KEY, cot, hf_token)
    agents = model_response["agent_response"]
    summaries = model_response["summarization"]
    first, second, third = model_list
    return {
        output_msg: gr.update(visible=True),
        output_col: gr.update(visible=True),
        model1_output1: agents[first][0],
        model2_output1: agents[second][0],
        model3_output1: agents[third][0],
        summarization_text1: summaries[0],
        model1_output2: agents[first][1],
        model2_output2: agents[second][1],
        model3_output2: agents[third][1],
        summarization_text2: summaries[1],
        model1_output3: agents[first][2],
        model2_output3: agents[second][2],
        model3_output3: agents[third][2],
    }
def load_responses():
    """Load the six pre-computed result files (plain and chain-of-thought
    variants for Math, GSM8K, and MMLU).

    Returns them in the fixed order:
    (math, math_cot, gsm, gsm_cot, mmlu, mmlu_cot).
    """
    paths = (
        "result/Math/math_result.json",
        "result/Math/math_result_cot.json",
        "result/GSM8K/gsm_result.json",
        "result/GSM8K/gsm_result_cot.json",
        "result/MMLU/mmlu_result.json",
        "result/MMLU/mmlu_result_cot.json",
    )
    loaded = []
    for path in paths:
        with open(path, "r") as fp:
            loaded.append(json.load(fp))
    return tuple(loaded)
def load_questions(math, gsm, mmlu, limit=100):
    """Build dropdown preview strings ("<n>: <first 128 chars>...") for each
    benchmark's questions.

    Args:
        math, gsm, mmlu: lists of result records with a "question" field.
        limit: maximum number of previews per benchmark. Defaults to 100
            (the previously hard-coded count); lists shorter than `limit`
            no longer raise IndexError — they simply yield fewer previews.

    Returns:
        Three lists of preview strings (math, gsm, mmlu).
    """
    def _previews(records):
        # Slicing caps at `limit` without failing on shorter inputs.
        return [f"{i + 1}: {r['question'][:128]}..." for i, r in enumerate(records[:limit])]

    return _previews(math), _previews(gsm), _previews(mmlu)
# Load the pre-computed experiment results (plain and chain-of-thought)
# for the three benchmark tabs.
math_result, math_cot_result, gsm_result, gsm_cot_result, mmlu_result, mmlu_cot_result = load_responses()
# Preview strings used as the dropdown choices in each tab.
math_questions, gsm_questions, mmlu_questions = load_questions(math_result, gsm_result, mmlu_result)
# Preview-string -> full-record maps so dropdown selections resolve back
# to the stored results (one map per benchmark / CoT variant).
math_question_selector_map = build_question_selector_map(math_result)
math_cot_question_selector_map = build_question_selector_map(math_cot_result)
gsm_question_selector_map = build_question_selector_map(gsm_result)
gsm_cot_question_selector_map = build_question_selector_map(gsm_cot_result)
mmlu_question_selector_map = build_question_selector_map(mmlu_result)
mmlu_cot_question_selector_map = build_question_selector_map(mmlu_cot_result)
# Page header, rendered as raw HTML.
TITLE = """<h1 align="center">LLM Agora 🗣️🏦</h1>"""

# Introductory Markdown shown at the top of the page: what LLM Agora is
# and step-by-step usage instructions.
INTRODUCTION_TEXT = """
The **LLM Agora** 🗣️🏦 aims to improve the quality of open-source LMs' responses through debate & revision introduced in [Improving Factuality and Reasoning in Language Models through Multiagent Debate](https://arxiv.org/abs/2305.14325).
Thank you to the authors of this paper for suggesting a great idea!
Do you know that? 🤔 **LLMs can also improve their responses by debating with other LLMs**! 😮 We applied this concept to several open-source LMs to verify that the open-source model, not the proprietary one, can sufficiently improve the response through discussion. 🤗
For more details, please refer to the [GitHub Repository](https://github.com/gauss5930/LLM-Agora).
You can also check the results in this Space!
You can use LLM Agora with your own questions if the response of open-source LM is not satisfactory and you want to improve the quality!
The Math, GSM8K, and MMLU Tabs show the results of the experiment(Llama2, WizardLM2, Orca2), and for inference, please use the 'Inference' tab.
Here's how to use LLM Agora!
1. Before starting, choose just 3 models and click the 'Warm-up LLM Agora 🔥' button and wait until '🤗🔥 Welcome to LLM Agora 🔥🤗' appears. (Suggest to go grab a coffee☕ since it takes 5 minutes!)
2. Once the interaction space is available, proceed with the following process.
3. Check the CoT box if you want to utilize the Chain-of-Thought while inferencing.
4. Please fill in your OpenAI API KEY, it will be used to use ChatGPT to summarize the responses.
5. Type your question in the Question box and click the 'Submit' button! If you do so, LLM Agora will show you improved answers! 🤗 (It will take roughly a minute! Please wait for an answer!)
For more detailed information, please check '※ Specific information about LLM Agora' at the bottom of the page.
※ Due to quota limitations, 'Llama2-Chat' and 'Falcon-Instruct' are currently unavailable. We will provide additional updates in the future.
"""

# Banner revealed once warmup() completes.
WELCOME_TEXT = """<h1 align="center">🤗🔥 Welcome to LLM Agora 🔥🤗</h1>"""
# Banner revealed above the inference outputs.
RESPONSE_TEXT = """<h1 align="center">🤗 Here are the responses to each model!! 🤗</h1>"""

# Markdown body of the collapsible "Specific information" accordion:
# task descriptions, model sizes, experiment settings, and citation.
SPECIFIC_INFORMATION = """
This is the specific information about LLM Agora!
**Tasks**
- Math: The problem of arithmetic operations on six randomly selected numbers. The format is '{}+{}*{}+{}-{}*{}=?'
- GSM8K: GSM8K is a dataset of 8.5K high quality linguistically diverse grade school math word problems created by human problem writers.
- MMLU: MMLU (Massive Multitask Language Understanding) is a new benchmark designed to measure knowledge acquired during pretraining by evaluating models exclusively in zero-shot and few-shot settings.
**Model size**
Besides Falcon, all other models are based on Llama2.
|Model name|Model size|
|---|---|
|Llama2|13B|
|Llama2-Chat|13B|
|Vicuna|13B|
|Falcon|7B|
|Falcon-Instruct|7B|
|WizardLM|13B|
|Orca|13B|
**Agent numbers & Debate rounds**
- We limit the number of agents and debate rounds because of the limitation of resources. As a result, we decided to use 3 agents and 2 rounds of debate!
**GitHub Repository**
- If you want to see more specific information, please check the [GitHub Repository](https://github.com/gauss5930/LLM-Agora) of LLM Agora!
**Citation**
```
@article{du2023improving,
      title={Improving Factuality and Reasoning in Language Models through Multiagent Debate},
      author={Du, Yilun and Li, Shuang and Torralba, Antonio and Tenenbaum, Joshua B and Mordatch, Igor},
      journal={arXiv preprint arXiv:2305.14325},
      year={2023}
}
```
"""
# ---------------------------------------------------------------------------
# Gradio UI. Layout nesting below is reconstructed from the component order;
# statement order matters because gradio places components as they are created.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)
    with gr.Column():
        # --- Live inference tab: pick 3 models, warm up, ask a question ---
        with gr.Tab("Inference"):
            model_list = gr.CheckboxGroup(["Llama2", "Vicuna", "Falcon", "WizardLM", "Orca"], label="Model Selection", info="Choose 3 LMs to participate in LLM Agora.", type="value", visible=True)
            warmup_button = gr.Button("Warm-up LLM Agora 🔥", visible=True)
            welcome_message = gr.HTML(WELCOME_TEXT, visible=False)
            # Hidden until warmup() returns and flips visibility.
            with gr.Row(visible=False) as options:
                cot = gr.Checkbox(label="CoT", info="Do you want to use CoT for inference?")
                API_KEY = gr.Textbox(label="OpenAI API Key", value="", info="Please fill in your OpenAI API token.", placeholder="sk..", type="password")
            with gr.Column(visible=False) as inputbox:
                question = gr.Textbox(label="Question", value="", info="Please type your question!", placeholder="")
                submit = gr.Button("Submit", visible=False)
            # Revealed by inference()'s return dict.
            with gr.Row(visible=False) as output_msg:
                gr.HTML(RESPONSE_TEXT)
            with gr.Column(visible=False) as output_col:
                # Round 1: each model's initial answer, then the summary.
                with gr.Row(elem_id="model1_response"):
                    model1_output1 = gr.Textbox(label="1️⃣ model's initial response")
                    model2_output1 = gr.Textbox(label="2️⃣ model's initial response")
                    model3_output1 = gr.Textbox(label="3️⃣ model's initial response")
                summarization_text1 = gr.Textbox(label="Summarization 1")
                # Round 2: revised answers after seeing the debate summary.
                with gr.Row(elem_id="model2_response"):
                    model1_output2 = gr.Textbox(label="1️⃣ model's revised response")
                    model2_output2 = gr.Textbox(label="2️⃣ model's revised response")
                    model3_output2 = gr.Textbox(label="3️⃣ model's revised response")
                summarization_text2 = gr.Textbox(label="Summarization 2")
                # Round 3: final answers.
                with gr.Row(elem_id="model3_response"):
                    model1_output3 = gr.Textbox(label="1️⃣ model's final response")
                    model2_output3 = gr.Textbox(label="2️⃣ model's final response")
                    model3_output3 = gr.Textbox(label="3️⃣ model's final response")
        # --- Math results tab (pre-computed experiment output) ---
        with gr.Tab("Math"):
            math_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            math_question_list = gr.Dropdown(math_questions, label="Math Question")
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    math_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    math_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s 1️⃣st response")
                    math_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                math_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    math_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    math_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s 2️⃣nd response")
                    math_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                math_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    math_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    math_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s 3️⃣rd response")
                    math_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of Math </h1>""")
            gr.HTML("""<p align="center"><img src='https://github.com/gauss5930/LLM-Agora/assets/80087878/4fc22896-1306-4a93-bd54-a7a2ff184c98'></p>""")
            # Refresh the displayed answers when the CoT box or the
            # selected question changes.
            math_cot.select(
                math_display_question_answer,
                [math_question_list, math_cot],
                [math_model1_output1, math_model2_output1, math_model3_output1, math_summarization_text1, math_model1_output2, math_model2_output2, math_model3_output2, math_summarization_text2, math_model1_output3, math_model2_output3, math_model3_output3]
            )
            math_question_list.change(
                math_display_question_answer,
                [math_question_list, math_cot],
                [math_model1_output1, math_model2_output1, math_model3_output1, math_summarization_text1, math_model1_output2, math_model2_output2, math_model3_output2, math_summarization_text2, math_model1_output3, math_model2_output3, math_model3_output3]
            )
        # --- GSM8K results tab ---
        with gr.Tab("GSM8K"):
            gsm_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            gsm_question_list = gr.Dropdown(gsm_questions, label="GSM8K Question")
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    gsm_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    gsm_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s 1️⃣st response")
                    gsm_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                gsm_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    gsm_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    gsm_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s 2️⃣nd response")
                    gsm_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                gsm_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    gsm_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    gsm_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s 3️⃣rd response")
                    gsm_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of GSM8K </h1>""")
            gr.HTML("""<p align="center"><img src="https://github.com/gauss5930/LLM-Agora/assets/80087878/64f05ea4-5bec-41e4-83d7-d8855e753290"></p>""")
            gsm_cot.select(
                gsm_display_question_answer,
                [gsm_question_list, gsm_cot],
                [gsm_model1_output1, gsm_model2_output1, gsm_model3_output1, gsm_summarization_text1, gsm_model1_output2, gsm_model2_output2, gsm_model3_output2, gsm_summarization_text2, gsm_model1_output3, gsm_model2_output3, gsm_model3_output3]
            )
            gsm_question_list.change(
                gsm_display_question_answer,
                [gsm_question_list, gsm_cot],
                [gsm_model1_output1, gsm_model2_output1, gsm_model3_output1, gsm_summarization_text1, gsm_model1_output2, gsm_model2_output2, gsm_model3_output2, gsm_summarization_text2, gsm_model1_output3, gsm_model2_output3, gsm_model3_output3]
            )
        # --- MMLU results tab ---
        with gr.Tab("MMLU"):
            mmlu_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            mmlu_question_list = gr.Dropdown(mmlu_questions, label="MMLU Question")
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    mmlu_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    mmlu_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s 1️⃣st response")
                    mmlu_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                mmlu_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    mmlu_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    mmlu_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s 2️⃣nd response")
                    mmlu_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                mmlu_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    mmlu_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    mmlu_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s 3️⃣rd response")
                    mmlu_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of MMLU </h1>""")
            gr.HTML("""<p align="center"><img src="https://github.com/composable-models/llm_multiagent_debate/assets/80087878/963571aa-228b-4d73-9082-5f528552383e"></p>""")
            mmlu_cot.select(
                mmlu_display_question_answer,
                [mmlu_question_list, mmlu_cot],
                [mmlu_model1_output1, mmlu_model2_output1, mmlu_model3_output1, mmlu_summarization_text1, mmlu_model1_output2, mmlu_model2_output2, mmlu_model3_output2, mmlu_summarization_text2, mmlu_model1_output3, mmlu_model2_output3, mmlu_model3_output3]
            )
            mmlu_question_list.change(
                mmlu_display_question_answer,
                [mmlu_question_list, mmlu_cot],
                [mmlu_model1_output1, mmlu_model2_output1, mmlu_model3_output1, mmlu_summarization_text1, mmlu_model1_output2, mmlu_model2_output2, mmlu_model3_output2, mmlu_summarization_text2, mmlu_model1_output3, mmlu_model2_output3, mmlu_model3_output3]
            )
    with gr.Accordion("※ Specific information about LLM Agora", open=False):
        gr.Markdown(SPECIFIC_INFORMATION)
    # Wire up the inference-tab buttons to their handlers.
    warmup_button.click(warmup, [model_list], [model_list, options, inputbox, submit, warmup_button, welcome_message])
    submit.click(inference, [model_list, question, API_KEY, cot], [output_msg, output_col, model1_output1, model2_output1, model3_output1, summarization_text1, model1_output2, model2_output2, model3_output2, summarization_text2, model1_output3, model2_output3, model3_output3])

demo.launch()