Spaces:
Running
Running
| import gradio as gr | |
| from utils import get_base_answer, get_nudging_answer | |
| from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS | |
| import datetime | |
| import logging | |
# Send INFO-level log records (the per-request diagnostics) to the console.
logging.basicConfig(level=logging.INFO)

# Number of answered requests per client host address since the last reset.
addr_limit_counter = dict()
# Moment the counter above was last reset; starts out as module import time.
LAST_UPDATE_TIME = datetime.datetime.now()

# Local aliases for the model lists imported from `constant`.
base_models, nudging_models = BASE_MODELS, NUDGING_MODELS
def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
    request: gr.Request,
):
    """Answer ``message`` with the base model only, under a per-client daily quota.

    Args:
        system_prompt: System prompt forwarded to ``get_base_answer``.
        message: User prompt; forwarded as ``question`` and echoed in the result.
        max_tokens: Forwarded to ``get_base_answer`` as ``max_tokens``.
        base_model: Name of the base model to query.
        request: Gradio request object; ``request.client.host`` is the key
            under which this caller's requests are counted.

    Returns:
        A one-element list ``[(message, base_answer)]``.

    Raises:
        gr.Error: If this client already has 50 answered requests in the
            current counting window.
    """
    global LAST_UPDATE_TIME, addr_limit_counter

    # Start a fresh counting window once more than 24 hours have elapsed.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    used = addr_limit_counter.setdefault(host_addr, 0)
    # `>=`, not `>`: the counter starts at 0 and is bumped after every answered
    # request, so `> 50` would let a 51st request through.
    if used >= 50:
        raise gr.Error("You have reached the limit of 50 requests for today.", duration=10)

    base_answer = get_base_answer(
        base_model=base_model,
        system_prompt=system_prompt,
        question=message,
        max_tokens=max_tokens,
    )

    # Only answered requests are counted. Use .get() because another request
    # may have reset the dict while the model call above was in flight.
    addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    logging.info(f"Requesting chat completion from OpenAI API with model {base_model}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
    return [(message, base_answer)]
def respond_nudging(
    system_prompt: str,
    message: str,
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request,
):
    """Answer ``message`` with base + nudging models, under a per-client daily quota.

    Args:
        system_prompt: System prompt forwarded to ``get_nudging_answer``.
        message: User prompt; forwarded as ``question`` and echoed in the result.
        max_tokens: Forwarded to ``get_nudging_answer`` as ``max_token_total``.
        nudging_thres: Forwarded to ``get_nudging_answer`` as ``top_prob_thres``.
        base_model: Name of the base model.
        nudging_model: Name of the nudging model.
        request: Gradio request object; ``request.client.host`` is the key
            under which this caller's requests are counted.

    Returns:
        A one-element list ``[(message, formatted_response)]`` where
        ``formatted_response`` is the HTML string built by ``format_response``.

    Raises:
        gr.Error: If this client already has 50 answered requests in the
            current counting window.
    """
    global LAST_UPDATE_TIME, addr_limit_counter

    # Start a fresh counting window once more than 24 hours have elapsed.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    used = addr_limit_counter.setdefault(host_addr, 0)
    # `>=`, not `>`: the counter starts at 0 and is bumped after every answered
    # request, so `> 50` would let a 51st request through.
    if used >= 50:
        raise gr.Error("You have reached the limit of 50 requests for today.", duration=10)

    all_info = get_nudging_answer(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt=system_prompt,
        question=message,
        max_token_total=max_tokens,
        top_prob_thres=nudging_thres,
    )
    formatted_response = format_response(
        all_info["all_completions"],
        all_info["all_nudging_words"],
    )

    # Only answered requests are counted. Use .get() because another request
    # may have reset the dict while the model call above was in flight.
    addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    logging.info(f"Requesting chat completion from OpenAI API with model {base_model} and {nudging_model}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
    return [(message, formatted_response)]
def clear_fn():
    """Return three ``None`` values, one for each output of the Clear button.

    The Clear button routes them to the prompt box and the two chat panes.
    """
    return (None,) * 3
def format_response(all_completions, nudging_words):
    """Build an HTML string that highlights the nudging words of each round.

    The two sequences are walked in lockstep (stopping at the shorter one).
    Each completion is expected to begin with its nudging word; that prefix is
    removed by length only, without checking that it actually matches.

    Args:
        all_completions: Strings of the form ``nudging_word + base_completion``.
        nudging_words: The nudging-word prefix of each completion.

    Returns:
        The concatenation of ``<mark>{nudging_word}</mark>{base_completion}``
        over all rounds; ``""`` when there are no rounds.
    """
    # NOTE(review): both pieces are inserted as-is, with no HTML escaping.
    return "".join(
        f"<mark>{nudging_word}</mark>{all_completion[len(nudging_word):]}"
        for all_completion, nudging_word in zip(all_completions, nudging_words)
    )
# UI definition and event wiring for the demo app.
# NOTE(review): leading indentation is missing from this copy of the file, so
# the nesting of the `with` containers below cannot be read off the text --
# confirm the intended layout against the original source.
| with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=custom_css) as demo: | |
# API-key field: created with visible=False and not passed to any handler below.
| api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False) | |
| gr.Markdown(HEADER_MD) | |
| with gr.Group(): | |
| with gr.Row(): | |
| with gr.Column(scale=1.5): | |
# Text inputs: system prompt and user prompt.
| system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter your system prompt here") | |
| message = gr.Textbox(label="Prompt", placeholder="Enter your message here") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| with gr.Row(): | |
# Model pickers, populated from the module-level base_models / nudging_models lists.
| base_model_choice = gr.Dropdown(label="Base Model", choices=base_models, interactive=True) | |
| nudging_model_choice = gr.Dropdown(label="Nudging Model", choices=nudging_models, interactive=True) | |
| with gr.Accordion("Nudging Parameters", open=True): | |
| with gr.Row(): | |
# Generation controls: token budget (0-512 in steps of 16, default 256) and
# nudging threshold (0.1-0.9 in steps of 0.1, default 0.4).
| max_tokens = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=512, step=16, interactive=True, visible=True) | |
| nudging_thres = gr.Slider(label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4) | |
| with gr.Row(): | |
| btn = gr.Button("Generate") | |
| with gr.Row(): | |
| stop_btn = gr.Button("Stop") | |
| clear_btn = gr.Button("Clear") | |
| with gr.Row(): | |
# Output panes: chat_b receives respond_base's result, chat_a receives respond_nudging's.
| chat_b = gr.Chatbot(height=500, label="Base Answer") | |
| chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot") | |
# Initial values are assigned on the already-constructed components rather
# than through the constructors' value= argument.
| base_model_choice.value = "Llama-2-70B" | |
| nudging_model_choice.value = "Mistral-7B-v0.1-Instruct" | |
| # nudging_model_choice.value = "Llama-2-13B-chat" | |
| system_prompt.value = "Answer the question by walking through the reasoning steps." | |
| message.value = "Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?" | |
# Hidden textboxes holding "base" / "aligned"; no handler below reads them.
| model_type_left = gr.Textbox(visible=False, value="base") | |
| model_type_right = gr.Textbox(visible=False, value="aligned") | |
# A single click on Generate registers two events, one per answer pane. The
# handlers' trailing gr.Request parameter is not listed among the inputs
# (presumably supplied by Gradio -- confirm against the Gradio docs).
| go1 = btn.click(respond_nudging, [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice], chat_a) | |
| go2 = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b) | |
# Stop runs no function of its own; it only cancels the two Generate events.
| stop_btn.click(None, None, None, cancels=[go1, go2]) | |
# Clear routes clear_fn's three return values to the prompt box and both chat panes.
| clear_btn.click(clear_fn, None, [message, chat_a, chat_b]) | |
# Launch the app only when this file is run as a script; show_api=False is
# passed through to launch().
| if __name__ == "__main__": | |
| demo.launch(show_api=False) |