Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from utils import * | |
| from save_data import add_or_update_row_at_fixed_position, get_sheet_service | |
| from instructions import * | |
| from user_groups import user_data | |
| from constants import SDG_DETAILS, WORD_LIMIT_MIN, WORD_LIMIT_MAX, GROUP_SEPERATION, LOCAL_PARAMS | |
| from html_codes import * | |
class SessionManager:
    """In-memory store of experiment sessions, keyed by participant code.

    Each session is a dict whose keys, in insertion order, depend on the
    cooperation style. ``save_session_to_sheet`` writes the values in that
    order as one spreadsheet row, so the schema order doubles as the column
    order.
    """

    # Style-specific session fields, in spreadsheet column order. They follow
    # the three common leading columns (user_identification_code, task,
    # cooperate_style) and all start out as None.
    _SESSION_FIELDS = {
        "sequential": (
            "human_initial_answer",
            "ai_modificated_output",
            "evaluation",
        ),
        "reverse_sequential": (
            "ai_initial_answer",
            "human_modifications",
            "final_answer",
            "evaluation",
        ),
        "parallel": (
            "ai_initial_answer",
            "human_initial_answer",
            "merged_final_answer",
            "evaluation",
        ),
    }

    def __init__(self):
        # Maps canonical session key (see _key) -> session dict.
        self.sessions = {}

    @staticmethod
    def _key(index):
        """Return the canonical dict key for a session index.

        Session indices round-trip through hidden ``gr.Number`` components,
        which can hand them back as floats (``1001.0``). The sessions,
        however, were created from the string code typed by the participant
        (``"1001"``). Integral values are therefore mapped to the string of
        the integer, whether they arrive as int, float or numeric string.
        Anything else falls back to ``str(index)``.
        """
        try:
            number = float(index)
        except (TypeError, ValueError):
            return str(index)
        if number.is_integer():
            return str(int(number))
        return str(index)

    def add_session(self, cooperate_style, task, identification_code):
        """Create an empty session for a participant and return its index.

        Args:
            cooperate_style: One of ``"sequential"``, ``"reverse_sequential"``
                or ``"parallel"``; selects the session's field schema.
            task: The task description shown to the participant.
            identification_code: The participant's code. It is stored in the
                session and also serves as the session index.

        Returns:
            ``identification_code``, unchanged.

        Raises:
            ValueError: If ``cooperate_style`` is not a known style.
        """
        try:
            fields = self._SESSION_FIELDS[cooperate_style]
        except KeyError:
            raise ValueError(
                f"Unknown cooperate_style: {cooperate_style!r}"
            ) from None
        session = {
            "user_identification_code": identification_code,
            "task": task,
            "cooperate_style": cooperate_style,
        }
        session.update(dict.fromkeys(fields))
        # A participant who starts again replaces their previous session.
        self.sessions[self._key(identification_code)] = session
        return identification_code

    def update(self, index, output_content, key='final_output'):
        """Set ``key`` to ``output_content`` in the session at ``index``.

        A key that is not yet part of the session is appended, which also
        appends a column when the session is saved.

        NOTE(review): the default ``'final_output'`` is not in any session
        schema; every caller in this file passes ``key`` explicitly.

        Raises:
            KeyError: If no session exists for ``index``.
        """
        self.sessions[self._key(index)][key] = output_content

    def get_session(self, index):
        """Return the (mutable) session dict for ``index``.

        Raises:
            KeyError: If no session exists for ``index``.
        """
        return self.sessions[self._key(index)]

    def save_session_to_sheet(self, index, service, SHEET_ID):
        """Write the session at ``index`` as one row of a Google Sheet.

        The target row is the numeric participant code modulo
        ``GROUP_SEPERATION``, offset by 2 because user data starts on row 2.
        Row values are the session's values in insertion order.
        """
        key = self._key(index)
        session = self.sessions[key]
        row_id = int(key) % GROUP_SEPERATION + 2  # user data starts from row 2
        new_row = list(session.values())
        add_or_update_row_at_fixed_position(
            row_id=row_id,
            new_row=new_row,
            service=service,
            SPREADSHEET_ID=SHEET_ID,
            num_of_columns=len(new_row))
def handle_create_sequential(task, human_input, session_manager, api_key, identification_code):
    """Run the sequential flow (human answers first, then the AI modifies).

    Records the participant's answer, then asks the AI to improve it via
    ``merge_texts_sequential``. If the answer fails the word-limit check, the
    validation error message is stored and returned in place of AI output,
    and no AI call is made.

    Returns:
        ``(output, session_index)``: the AI output (or the error message) and
        the index returned by ``session_manager.add_session``.
    """
    cooperate_style = "sequential"
    session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style, identification_code=identification_code)
    session_manager.update(session_index, human_input, 'human_initial_answer')
    session_manager.update(session_index, identification_code, 'user_identification_code')
    # Run the word-limit check once and reuse its result.
    validation_error = word_limit_validation(human_input)
    if validation_error:
        output = validation_error
    else:
        output = merge_texts_sequential(task, human_input, api_key)
    session_manager.update(session_index, output, 'ai_modificated_output')
    return output, session_index
def handle_create_parallel(task, human_input, session_manager, api_key, identification_code):
    """Run the parallel flow: human and AI answer independently, then merge.

    The AI's independent answer comes from ``generate_ai_initial_answer``.
    It is merged with the participant's answer by ``merge_texts_parallel``,
    and the participant answer, AI answer and merged answer are all stored in
    the session. If the participant's answer fails the word-limit check, the
    validation error message is stored in place of both AI results and no AI
    call is made.

    Returns:
        ``(ai_initial_answer, session_index)``. Only the independent AI answer
        is returned; the merged answer is kept in the session under
        ``merged_final_answer``.
    """
    cooperate_style = "parallel"
    session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style, identification_code=identification_code)
    # Run the word-limit check once and reuse its result.
    validation_error = word_limit_validation(human_input)
    if validation_error:
        ai_initial_answer = validation_error
        final_answer = validation_error
    else:
        ai_initial_answer = generate_ai_initial_answer(task, api_key)
        final_answer = merge_texts_parallel(task, human_input, ai_initial_answer, api_key)
    session_manager.update(session_index, human_input, 'human_initial_answer')
    session_manager.update(session_index, ai_initial_answer, 'ai_initial_answer')
    session_manager.update(session_index, final_answer, 'merged_final_answer')
    session_manager.update(session_index, identification_code, 'user_identification_code')
    return ai_initial_answer, session_index
def handle_create_reverse_sequential(task, session_manager, api_key, identification_code):
    """Start the reverse-sequential flow: the AI drafts first, the human refines later.

    Creates the session, generates the AI's first answer with
    ``generate_ai_initial_answer`` and records it.

    Returns:
        ``(ai_initial_answer, session_index)``.
    """
    session_index = session_manager.add_session(
        task=task,
        cooperate_style="reverse_sequential",
        identification_code=identification_code,
    )
    draft = generate_ai_initial_answer(task, api_key)
    for field_value, field_name in ((draft, 'ai_initial_answer'),
                                    (identification_code, 'user_identification_code')):
        session_manager.update(session_index, field_value, field_name)
    return draft, session_index
def handle_modify_reverse_sequential(session_index, modification_suggestions, session_manager, api_key):
    """Record the participant's refinement of the AI draft as the final answer.

    The participant's text is always stored as ``human_modifications``. If it
    passes the word-limit check it is also used verbatim as ``final_answer``;
    otherwise the validation error message is stored as the final answer.

    ``api_key`` is currently unused because the AI rewrite step
    (``modify_with_suggestion``) is disabled. It is kept so the Gradio wiring
    and any other callers do not need to change.

    Returns:
        ``(final_answer, session_index)``.
    """
    session_manager.update(session_index, modification_suggestions, 'human_modifications')
    # Run the word-limit check once and reuse its result.
    validation_error = word_limit_validation(modification_suggestions)
    final_answer = validation_error if validation_error else modification_suggestions
    session_manager.update(session_index, final_answer, 'final_answer')
    return final_answer, session_index
def evaluate_interaction(session_index, session_manager, api_key):
    """Grade a session's final answer and store the grade in the session.

    The answer that gets graded depends on the cooperation style:

    - sequential: the AI-modified output;
    - reverse sequential: the human-refined final answer;
    - parallel: the merged answer.

    Returns:
        The value returned by ``get_evaluation_with_gpt``, which is also
        written to the session's ``evaluation`` field.

    Raises:
        ValueError: If the session's ``cooperate_style`` is not recognised.
    """
    # Session key holding the answer to grade, per cooperation style.
    answer_keys = {
        "sequential": 'ai_modificated_output',
        "reverse_sequential": 'final_answer',
        "parallel": 'merged_final_answer',
    }
    session = session_manager.get_session(session_index)
    style = session['cooperate_style']
    if style not in answer_keys:
        raise ValueError(f"Unknown cooperate_style: {style!r}")
    evaluation = get_evaluation_with_gpt(session['task'], session[answer_keys[style]], api_key)
    # get_session hands back the stored dict, so this updates the session itself.
    session['evaluation'] = evaluation
    return evaluation
def save_data(session_index, session_manager, service, SHEET_ID):
    """Persist one session to the given Google Sheet and report success.

    Delegates to ``session_manager.save_session_to_sheet``. Any error raised
    while writing propagates to the caller. On success, returns a
    confirmation string for display in the UI.
    """
    confirmation = "Data has been saved to Google Sheets."
    session_manager.save_session_to_sheet(session_index, service, SHEET_ID)
    return confirmation
def login(identification_code):
    """Show the task UI for the participant's group.

    The group is chosen by ``int(code) // 1000``: 0 -> "A", 1 -> "B",
    2 -> "C". Empty, non-numeric or out-of-range codes produce the view
    returned by ``update_content(None)``.

    Relies on ``update_content``, which is defined in this module's
    ``__main__`` section, so it only works when the file runs as a script.

    Returns:
        The tuple of component updates produced by ``update_content``.
    """
    groups = ["A", "B", "C"]
    if not identification_code:
        return update_content(None)
    try:
        user_group_id = int(identification_code) // 1000
    except (TypeError, ValueError):
        # A mistyped code (e.g. letters) should show the invalid-group view
        # instead of raising inside the Gradio callback.
        return update_content(None)
    if 0 <= user_group_id < len(groups):
        return update_content(groups[user_group_id])
    return update_content(None)
def word_limit_validation(human_input, min_words=None, max_words=None):
    """Check that a submission's word count lies within the allowed range.

    Args:
        human_input: The submitted text. Words are whitespace-separated
            tokens; ``None`` is treated as empty text.
        min_words: Minimum number of words (inclusive). Defaults to
            ``WORD_LIMIT_MIN``.
        max_words: Maximum number of words (inclusive). Defaults to
            ``WORD_LIMIT_MAX``.

    Returns:
        An ``"Error: ..."`` message naming the violated limit, or ``None``
        when the text is within limits. Callers test the truthiness of the
        result to decide whether to proceed.
    """
    if min_words is None:
        min_words = WORD_LIMIT_MIN
    if max_words is None:
        max_words = WORD_LIMIT_MAX
    word_count = len((human_input or "").split())
    # Quote the configured limits rather than hard-coded numbers so the
    # message stays correct if the constants change.
    if word_count < min_words:
        return f"Error: Please enter at least {min_words} words."
    if word_count > max_words:
        # The upper bound is inclusive, so "less than" would be off by one.
        return f"Error: Please enter no more than {max_words} words."
    return None
def on_textbox_change(session_index, session_manager, service, SHEET_ID):
    """Change-event hook that saves the session whenever a watched textbox updates.

    A thin alias for ``save_data``; returns its confirmation string.
    """
    return save_data(
        session_index=session_index,
        session_manager=session_manager,
        service=service,
        SHEET_ID=SHEET_ID,
    )
def update_word_count(text):
    """Format the live word-count label for a textbox's current contents.

    Words are runs of non-whitespace characters, as split by ``str.split()``.
    """
    count = sum(1 for _ in text.split())
    return "Word Count: " + str(count)
def check_initial_generated(initial_answer):
    """Show a Gradio warning if the AI draft has not been generated yet.

    Used as a change hook on the participant's refinement box. It has no
    component outputs and always returns ``None``.
    """
    if initial_answer:
        return None
    gr.Warning("Please click 'Create' to generate the AI output first.")
    return None
if __name__ == "__main__":
    # Script entry point: build the Gradio UI for the three experiment groups
    # and launch it.
    #
    # NOTE(review): this section was recovered from a copy whose indentation
    # had been stripped; the nesting of the layout blocks below is a
    # best-effort reconstruction -- confirm it against the deployed Space.
    api_key = get_api_key(local=LOCAL_PARAMS)
    # Three spreadsheet IDs. The wiring below saves group A sessions to
    # SHEET_ID1, group B to SHEET_ID3 and group C to SHEET_ID2.
    service, SHEET_IDs = get_sheet_service(local=LOCAL_PARAMS)
    SHEET_ID1, SHEET_ID2, SHEET_ID3 = SHEET_IDs
    # One shared in-memory store holding every participant's session.
    session_manager = SessionManager()
    with gr.Blocks(fill_width=True,
                   css = background_css,
                   js = no_copy_paste_js
                   ) as app:
        title = gr.HTML("<h1 style='color: white;'> Human-AI Ensemble </h1>")
        # --- Login area ---------------------------------------------------
        with gr.Row():
            identification_code = gr.Textbox(label="Enter your identification code")
            login_button = gr.Button("Login")
        experiment_notes = gr.Textbox(label ="Reward & Bonus",
                                      value = notes_for_participants())
        # Receives the group-specific instructions returned by update_content.
        login_status = gr.Textbox(label="Next Tasks", interactive=False)
        # NOTE(review): not referenced anywhere else in this file.
        group = gr.State()
        # Task area; hidden until login reveals it via update_content.
        with gr.Column(visible=False) as task:
            description = gr.Textbox(label="Task Description",
                                     value = default_task_description(),
                                     interactive=False,
                                     lines=12)
            with gr.Accordion(label = "Click to See 17 SDGs",
                              open=False):
                gr.Markdown(SDG_DETAILS)
            # initialization of different group contents
            # (each group's widgets are added later with `with group_x_content:`)
            group_a_content = gr.Group(visible=False, elem_id="group-a")
            group_b_content = gr.Group(visible=False, elem_id="group-b")
            group_c_content = gr.Group(visible=False, elem_id="group-c")

        def update_content(group):
            """Return the UI updates that reveal the panel for ``group``.

            The five values line up with the outputs of ``login_button.click``:
            task column, group A panel, group B panel, group C panel, and the
            text for ``login_status``. Any value other than "A"/"B"/"C" hides
            everything and shows ``invalid_group()``. Because it is defined at
            module scope (inside the ``__main__`` guard), it is the
            ``update_content`` that ``login`` calls.
            """
            if group == "A":
                return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), group_a_instructions()
            elif group == "B":
                return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), group_b_instructions()
            elif group == "C":
                return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), group_c_instructions()
            else:
                return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), invalid_group()
        login_button.click(login, inputs=identification_code, outputs=[task, group_a_content, group_b_content, group_c_content, login_status])

        # --- Group A: sequential (human answers, then the AI modifies) ----
        # Lambda parameters receive the values of the components listed in
        # `inputs`, so each group's wiring uses its own components even though
        # names such as human_input and session_index are reused per group.
        with group_a_content:
            with gr.Row():
                human_input = gr.Textbox(label="Enter each idea on a new line (Shift+Enter), starting with '1', '2', and '3'.", placeholder="Please propose 3 ideas to help Airbnb’s business model align with 17 SDGs (At least 100 words)")
                word_count_display_a = gr.Label(value="Word count: 0")
            # Live word count while typing.
            human_input.change(fn=update_word_count, inputs=human_input, outputs=word_count_display_a)
            with gr.Row():
                submit_btn = gr.Button("Submit & See AI Output")
            with gr.Row():
                ai_output = gr.Textbox(label="AI Output", interactive=False)
            # Hidden carrier for the session key between events.
            session_index = gr.Number(label="Session Index", visible=False)
            submit_btn.click(
                fn=lambda task, human_input, id: handle_create_sequential(task, human_input, session_manager, api_key, id),
                inputs=[description, human_input, identification_code],
                outputs=[ai_output, session_index]
            )
            # Evaluate without showing
            # (new AI output triggers grading; a new grade triggers a save)
            evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive = False)
            ai_output.change(
                fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
                inputs=[session_index],
                outputs=[evaluation_result]
            )
            evaluation_result.change(
                fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID1),
                inputs = [session_index]
            )
            # Manual save. NOTE(review): all three groups reuse
            # elem_id="save_btn", so the DOM id is not unique on the page.
            save_btn = gr.Button("Save Data", elem_id="save_btn")
            save_result = gr.Label()
            save_btn.click(
                fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID1),
                inputs=[session_index],
                outputs=[save_result]
            )

        # --- Group B: reverse sequential (AI drafts, human refines) -------
        with group_b_content:
            # gr.HTML("<p>Group B Content</p>")
            with gr.Row():
                create_initial_btn = gr.Button("Create")
            with gr.Row():
                initial_answer = gr.Textbox(label="AI Output", interactive=False)
            with gr.Row():
                modification_suggestions = gr.Textbox(label="Please refine AI's three ideas as your final answer, starting with '1', '2', and '3'.", placeholder="Please propose 3 ideas to help Airbnb’s business model align with 17 SDGs (At least 100 words)")
                word_count_display_b = gr.Label(value="Word count: 0")
            modification_suggestions.change(fn=update_word_count, inputs=modification_suggestions, outputs=word_count_display_b)
            # Warn if the participant starts typing before generating the AI draft.
            modification_suggestions.change(fn=check_initial_generated, inputs = [initial_answer] )
            with gr.Row():
                create_final_btn = gr.Button("Review")
            with gr.Row():
                final_answer = gr.Textbox(label="Final Answer", interactive=False)
            session_index = gr.Number(label="Session Index", visible=False)
            create_initial_btn.click(
                fn=lambda task, id: handle_create_reverse_sequential(task, session_manager, api_key, id),
                inputs=[description, identification_code],
                outputs=[initial_answer, session_index]
            )
            # Save the session as soon as the AI draft appears.
            initial_answer.change(
                fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID3),
                inputs = [session_index]
            )
            create_final_btn.click(
                fn=lambda session_index, modification_suggestions: handle_modify_reverse_sequential(session_index, modification_suggestions, session_manager, api_key),
                inputs=[session_index, modification_suggestions],
                outputs=[final_answer, session_index]
            )
            #evaluate_btn = gr.Button("Evaluate")
            # Hidden grading: a new final answer is graded, then the grade is saved.
            evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive=False)
            final_answer.change(
                fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
                inputs=[session_index],
                outputs=[evaluation_result]
            )
            evaluation_result.change(
                fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID3),
                inputs = [session_index]
            )
            save_btn = gr.Button("Save Data", elem_id="save_btn")
            save_result = gr.Label()
            save_btn.click(
                fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID3),
                inputs=[session_index],
                outputs=[save_result]
            )

        # --- Group C: parallel (independent answers, then AI merge) -------
        with group_c_content:
            with gr.Row():
                human_input = gr.Textbox(label="Enter each idea on a new line (Shift+Enter), starting with '1', '2', and '3'.", placeholder="Please propose 3 ideas to help Airbnb’s business model align with 17 SDGs (At least 100 words)")
                word_count_display_c = gr.Label(value="Word count: 0")
            human_input.change(fn=update_word_count, inputs=human_input, outputs=word_count_display_c)
            with gr.Row():
                create_btn = gr.Button("Submit & See AI Output")
            with gr.Row():
                ai_initial_output = gr.Textbox(label="AI Output Generated Independently", interactive=False)
            with gr.Row():
                merge_btn = gr.Button("Merge Using the Second AI")
            with gr.Row():
                final_output = gr.Textbox(label="Final Merged Output", interactive=False)
            session_index = gr.Number(label="Session Index", visible=False)
            # handle_create_parallel already computes and stores the merged
            # answer; only the independent AI answer is shown here.
            create_btn.click(
                fn=lambda task, human_input, id: handle_create_parallel(task, human_input, session_manager, api_key, id),
                inputs=[description, human_input, identification_code],
                outputs=[ai_initial_output, session_index]
            )
            # display_merged_output is not defined in this file (presumably
            # star-imported); it presumably returns the session's stored
            # merged_final_answer -- verify.
            merge_btn.click(
                fn= lambda session_index : display_merged_output(session_index, session_manager),
                inputs = [session_index],
                outputs=[final_output]
            )
            #evaluate_btn = gr.Button("Evaluate")
            # Hidden grading: a new merged output is graded, then the grade is saved.
            evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive = False)
            final_output.change(
                fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
                inputs=[session_index],
                outputs=[evaluation_result]
            )
            evaluation_result.change(
                fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID2),
                inputs = [session_index]
            )
            save_btn = gr.Button("Save Data",elem_id="save_btn")
            save_result = gr.Label()
            save_btn.click(
                fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID2),
                inputs=[session_index],
                outputs=[save_result]
            )
    # Start the server; share=True also requests a public share link.
    app.launch(share=True)