# diana3135
# fix task description error, pass a textbox rather than a string
# 093aba3
import gradio as gr
from utils import *
from save_data import add_or_update_row_at_fixed_position, get_sheet_service
from instructions import *
from user_groups import user_data
from constants import SDG_DETAILS, WORD_LIMIT_MIN, WORD_LIMIT_MAX, GROUP_SEPERATION, LOCAL_PARAMS
from html_codes import *
class SessionManager:
    """In-memory store of experiment sessions, keyed by identification code.

    Each session is a plain dict whose key order is significant:
    ``save_session_to_sheet`` writes ``list(session.values())`` as one
    spreadsheet row, so the order of the keys is the order of the columns.
    """

    # Answer fields that are specific to each cooperation style, listed in the
    # order they must appear in the session dict (between the three common
    # leading fields and the trailing "evaluation" field).
    _STYLE_FIELDS = {
        "sequential": (
            "human_initial_answer",
            "ai_modificated_output",
        ),
        "reverse_sequential": (
            "ai_initial_answer",
            "human_modifications",
            "final_answer",
        ),
        "parallel": (
            "ai_initial_answer",
            "human_initial_answer",
            "merged_final_answer",
        ),
    }

    def __init__(self):
        # Maps identification code -> session dict.
        self.sessions = {}

    def add_session(self, cooperate_style, task, identification_code):
        """Create (or overwrite) the session stored under *identification_code*.

        Args:
            cooperate_style: One of ``"sequential"``, ``"reverse_sequential"``
                or ``"parallel"``.
            task: Task description stored with the session.
            identification_code: Key under which the session is stored.

        Returns:
            The *identification_code*, which doubles as the session index.

        Raises:
            ValueError: If *cooperate_style* is not a known style. Nothing is
                stored in that case.
        """
        try:
            style_fields = self._STYLE_FIELDS[cooperate_style]
        except KeyError:
            known = ", ".join(sorted(self._STYLE_FIELDS))
            raise ValueError(
                f"Unknown cooperate_style {cooperate_style!r}; expected one of: {known}"
            ) from None

        session = {
            "user_identification_code": identification_code,
            "task": task,
            "cooperate_style": cooperate_style,
        }
        # Style-specific answer slots start out empty.
        session.update(dict.fromkeys(style_fields))
        session["evaluation"] = None

        self.sessions[identification_code] = session
        return identification_code

    def update(self, index, output_content, key='final_output'):
        """Set ``session[key]`` to *output_content* for the session at *index*.

        Raises:
            KeyError: If no session is stored under *index*.
        """
        self.sessions[index][key] = output_content

    def get_session(self, index):
        """Return the session dict stored under *index* (``KeyError`` if absent)."""
        return self.sessions[index]

    def save_session_to_sheet(self, index, service, SHEET_ID):
        """Write the session at *index* as one row of the given spreadsheet.

        The row number is derived from the numeric value of *index*, so
        *index* must be convertible with ``int()``.
        """
        session = self.sessions[index]
        row_id = int(index) % GROUP_SEPERATION + 2  # user data starts from row 2
        new_row = list(session.values())
        add_or_update_row_at_fixed_position(
            row_id=row_id,
            new_row=new_row,
            service=service,
            SPREADSHEET_ID=SHEET_ID,
            num_of_columns=len(new_row),
        )
def handle_create_sequential(task, human_input, session_manager, api_key, identification_code):
    """Start a "sequential" session: the human answers first, then the AI responds.

    The human input is stored in the new session. If it fails the word-limit
    check, the validation message becomes the output; otherwise the output is
    whatever ``merge_texts_sequential`` returns. Either way the output is
    stored under ``'ai_modificated_output'``.

    Returns:
        A ``(output, session_index)`` tuple.
    """
    session_index = session_manager.add_session(
        task=task,
        cooperate_style="sequential",
        identification_code=identification_code,
    )
    session_manager.update(session_index, human_input, 'human_initial_answer')
    session_manager.update(session_index, identification_code, 'user_identification_code')

    validation_error = word_limit_validation(human_input)
    output = (
        validation_error
        if validation_error
        else merge_texts_sequential(task, human_input, api_key)
    )

    session_manager.update(session_index, output, 'ai_modificated_output')
    return output, session_index
def handle_create_parallel(task, human_input, session_manager, api_key, identification_code):
    """Start a "parallel" session: human and AI answer independently, then merge.

    When the human input fails the word-limit check, the validation message
    is used as both the AI answer and the merged answer. Otherwise the AI
    answer comes from ``generate_ai_initial_answer`` and the merged answer
    from ``merge_texts_parallel``. All values are stored in the session; only
    the AI answer is returned for display.

    Returns:
        A ``(ai_initial_answer, session_index)`` tuple.
    """
    session_index = session_manager.add_session(
        task=task,
        cooperate_style="parallel",
        identification_code=identification_code,
    )

    validation_error = word_limit_validation(human_input)
    if validation_error:
        ai_initial_answer = validation_error
        final_answer = validation_error
    else:
        ai_initial_answer = generate_ai_initial_answer(task, api_key)
        final_answer = merge_texts_parallel(task, human_input, ai_initial_answer, api_key)

    stored_fields = (
        ('human_initial_answer', human_input),
        ('ai_initial_answer', ai_initial_answer),
        ('merged_final_answer', final_answer),
        ('user_identification_code', identification_code),
    )
    for field_name, field_value in stored_fields:
        session_manager.update(session_index, field_value, field_name)

    return ai_initial_answer, session_index
def handle_create_reverse_sequential(task, session_manager, api_key, identification_code):
    """Start a "reverse_sequential" session: the AI answers first.

    Creates the session, stores the answer produced by
    ``generate_ai_initial_answer`` under ``'ai_initial_answer'`` and records
    the identification code.

    Returns:
        A ``(ai_initial_answer, session_index)`` tuple.
    """
    session_index = session_manager.add_session(
        task=task,
        cooperate_style="reverse_sequential",
        identification_code=identification_code,
    )
    ai_initial_answer = generate_ai_initial_answer(task, api_key)

    for field_name, field_value in (
        ('ai_initial_answer', ai_initial_answer),
        ('user_identification_code', identification_code),
    ):
        session_manager.update(session_index, field_value, field_name)

    return ai_initial_answer, session_index
def handle_modify_reverse_sequential(session_index, modification_suggestions, session_manager, api_key):
    """Record the human's refinement of the AI answer as the final answer.

    The refinement text is stored under ``'human_modifications'``. If it
    fails the word-limit check, the validation message becomes the final
    answer; otherwise the text itself is used verbatim (the AI rewrite via
    ``modify_with_suggestion`` is disabled, so ``api_key`` is unused here).

    Returns:
        A ``(final_answer, session_index)`` tuple.
    """
    # Look the session up first so an unknown index fails before any write.
    session_manager.get_session(session_index)
    session_manager.update(session_index, modification_suggestions, 'human_modifications')

    validation_error = word_limit_validation(modification_suggestions)
    final_answer = validation_error if validation_error else modification_suggestions

    session_manager.update(session_index, final_answer, 'final_answer')
    return final_answer, session_index
def evaluate_interaction(session_index, session_manager, api_key):
    """Evaluate a session's final answer and store the result in the session.

    The session field that holds the final answer depends on the session's
    ``cooperate_style``. The evaluation returned by
    ``get_evaluation_with_gpt`` is written to ``session['evaluation']``.

    Args:
        session_index: Key of the session to evaluate.
        session_manager: Object exposing ``get_session(index)``.
        api_key: Passed through to ``get_evaluation_with_gpt``.

    Returns:
        The evaluation produced by ``get_evaluation_with_gpt``.

    Raises:
        ValueError: If the session's ``cooperate_style`` is not recognised.
    """
    # Which session field holds the answer to evaluate, per cooperation style.
    answer_key_by_style = {
        "sequential": "ai_modificated_output",
        "reverse_sequential": "final_answer",
        "parallel": "merged_final_answer",
    }

    session = session_manager.get_session(session_index)
    style = session['cooperate_style']
    if style not in answer_key_by_style:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(f"Unknown cooperate_style: {style!r}")

    evaluation = get_evaluation_with_gpt(
        session['task'], session[answer_key_by_style[style]], api_key
    )
    session['evaluation'] = evaluation
    return evaluation
def save_data(session_index, session_manager, service, SHEET_ID):
    """Persist one session through the session manager and report success.

    Delegates to ``session_manager.save_session_to_sheet`` and returns a
    fixed confirmation message.
    """
    confirmation = "Data has been saved to Google Sheets."
    session_manager.save_session_to_sheet(session_index, service, SHEET_ID)
    return confirmation
def login(identification_code):
    """Map an identification code to its user group and build the UI update.

    The group is the code's integer value divided by 1000 (floor division):
    0 -> "A", 1 -> "B", 2 -> "C". Empty, non-numeric or out-of-range codes
    produce the "invalid group" update instead of raising.

    Args:
        identification_code: The code typed by the participant.

    Returns:
        Whatever ``update_content`` returns for the resolved group (or for
        ``None`` when the code is not valid).
    """
    groups = ("A", "B", "C")
    if not identification_code:
        return update_content(None)

    try:
        user_group_id = int(identification_code) // 1000
    except (TypeError, ValueError):
        # Free-text input that is not an integer must not crash the handler.
        return update_content(None)

    if 0 <= user_group_id < len(groups):
        return update_content(groups[user_group_id])
    return update_content(None)
def word_limit_validation(human_input, min_words=None, max_words=None):
    """Check that *human_input* has an acceptable number of words.

    Words are the whitespace-separated tokens of the text; ``None`` is
    treated as empty text.

    Args:
        human_input: The text to check.
        min_words: Minimum allowed word count. Defaults to ``WORD_LIMIT_MIN``.
        max_words: Maximum allowed word count. Defaults to ``WORD_LIMIT_MAX``.

    Returns:
        An error message string when the text is too short or too long,
        otherwise ``None``.
    """
    # Resolve the defaults lazily so the module constants stay the single
    # source of truth for both the check and the message text.
    lower = WORD_LIMIT_MIN if min_words is None else min_words
    upper = WORD_LIMIT_MAX if max_words is None else max_words

    word_count = len((human_input or "").split())
    if word_count < lower:
        return f"Error: Please enter at least {lower} words."
    if word_count > upper:
        return f"Error: Please enter less than {upper} words."
    return None
def on_textbox_change(session_index, session_manager, service, SHEET_ID):
    """Change-event callback that saves the session and returns the status text.

    Thin wrapper around ``save_data`` with the same arguments.
    """
    save_status = save_data(session_index, session_manager, service, SHEET_ID)
    return save_status
def update_word_count(text):
    """Return a ``"Word Count: N"`` label for the whitespace-separated words in *text*."""
    word_total = len(text.split())
    return f"Word Count: {word_total}"
def check_initial_generated(initial_answer):
    """Warn the user if they start editing before an AI answer exists.

    Shows a ``gr.Warning`` when *initial_answer* is empty. Always returns
    ``None``.
    """
    if initial_answer:
        return None
    gr.Warning("Please click 'Create' to generate the AI output first.")
    return None
# Script entry point: load credentials, then build the Gradio Blocks UI.
# NOTE(review): the leading indentation of this script section is missing in
# this copy of the file, so the nesting of the `with` blocks below (which
# decides the page layout) cannot be read off the text — restore it from the
# original before running.
if __name__ == "__main__":
api_key = get_api_key(local=LOCAL_PARAMS)
service, SHEET_IDs = get_sheet_service(local=LOCAL_PARAMS)
# Three spreadsheet ids. Further down, SHEET_ID1 is used by the group A
# handlers, SHEET_ID3 by group B and SHEET_ID2 by group C.
# NOTE(review): confirm the B -> 3 / C -> 2 mapping is intentional.
SHEET_ID1, SHEET_ID2, SHEET_ID3 = SHEET_IDs
# Single in-process session store shared by every event handler.
session_manager = SessionManager()
with gr.Blocks(fill_width=True,
css = background_css,
js = no_copy_paste_js
) as app:
title = gr.HTML("<h1 style='color: white;'> Human-AI Ensemble </h1>")
# Login controls: the identification code is also passed to the create
# handlers of every group as the session key.
with gr.Row():
identification_code = gr.Textbox(label="Enter your identification code")
login_button = gr.Button("Login")
experiment_notes = gr.Textbox(label ="Reward & Bonus",
value = notes_for_participants())
# Receives the per-group instructions (or the invalid-group message) on login.
login_status = gr.Textbox(label="Next Tasks", interactive=False)
# NOTE(review): `group` is not referenced by any handler visible in this file.
group = gr.State()
# Task description area; hidden until a valid login makes it visible.
with gr.Column(visible=False) as task:
description = gr.Textbox(label="Task Description",
value = default_task_description(),
interactive=False,
lines=12)
with gr.Accordion(label = "Click to See 17 SDGs",
open=False):
gr.Markdown(SDG_DETAILS)
# initialization of different group contents
# Each container starts hidden; the login handler's return values toggle them.
group_a_content = gr.Group(visible=False, elem_id="group-a")
group_b_content = gr.Group(visible=False, elem_id="group-b")
group_c_content = gr.Group(visible=False, elem_id="group-c")
def update_content(group):
    """Build the five UI updates that follow a login attempt.

    Returns a 5-tuple: visibility updates for the task column and for the
    group A, B and C containers, followed by the instruction text for the
    group. For "A", "B" or "C" the task column and that group's container
    are shown; for any other value everything is hidden and the
    ``invalid_group()`` text is returned.
    """
    show_a = group == "A"
    show_b = group == "B"
    show_c = group == "C"

    if show_a:
        instructions = group_a_instructions()
    elif show_b:
        instructions = group_b_instructions()
    elif show_c:
        instructions = group_c_instructions()
    else:
        instructions = invalid_group()

    return (
        gr.update(visible=show_a or show_b or show_c),
        gr.update(visible=show_a),
        gr.update(visible=show_b),
        gr.update(visible=show_c),
        instructions,
    )
# Login wiring: the five values returned by `login` map, in order, onto the
# task column, the three group containers and the login status textbox.
login_button.click(login, inputs=identification_code, outputs=[task, group_a_content, group_b_content, group_c_content, login_status])
# --- Group A: the participant submits text and handle_create_sequential
# produces the AI output. Results are saved with SHEET_ID1. ---
with group_a_content:
with gr.Row():
human_input = gr.Textbox(label="Enter each idea on a new line (Shift+Enter), starting with '1', '2', and '3'.", placeholder="Please propose 3 ideas to help Airbnb’s business model align with 17 SDGs (At least 100 words)")
word_count_display_a = gr.Label(value="Word count: 0")
# Live word counter for the input box.
human_input.change(fn=update_word_count, inputs=human_input, outputs=word_count_display_a)
with gr.Row():
submit_btn = gr.Button("Submit & See AI Output")
with gr.Row():
ai_output = gr.Textbox(label="AI Output", interactive=False)
# Hidden field that carries the session key returned by the handler from
# one event to the next.
session_index = gr.Number(label="Session Index", visible=False)
# The lambda closes over session_manager and api_key; `id` is the
# identification code typed at login (it shadows the builtin of that name).
submit_btn.click(
fn=lambda task, human_input, id: handle_create_sequential(task, human_input, session_manager, api_key, id),
inputs=[description, human_input, identification_code],
outputs=[ai_output, session_index]
)
# Evaluate without showing
evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive = False)
# Chain: a change of the AI output runs the evaluation, whose result goes
# into the hidden textbox above.
# NOTE(review): handle_create_sequential also returns the word-limit error
# message as its output, so this evaluation runs for error messages too.
ai_output.change(
fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
inputs=[session_index],
outputs=[evaluation_result]
)
# A change of the evaluation result saves the session.
# NOTE(review): no `outputs` are declared although on_textbox_change
# returns a status string.
evaluation_result.change(
fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID1),
inputs = [session_index]
)
# Manual save; the same elem_id "save_btn" is reused by the save buttons of
# the other groups.
save_btn = gr.Button("Save Data", elem_id="save_btn")
save_result = gr.Label()
save_btn.click(
fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID1),
inputs=[session_index],
outputs=[save_result]
)
# --- Group B: the AI answer is generated first (handle_create_reverse_sequential),
# then the participant refines it. Results are saved with SHEET_ID3. ---
with group_b_content:
# gr.HTML("<p>Group B Content</p>")
with gr.Row():
create_initial_btn = gr.Button("Create")
with gr.Row():
initial_answer = gr.Textbox(label="AI Output", interactive=False)
with gr.Row():
modification_suggestions = gr.Textbox(label="Please refine AI's three ideas as your final answer, starting with '1', '2', and '3'.", placeholder="Please propose 3 ideas to help Airbnb’s business model align with 17 SDGs (At least 100 words)")
word_count_display_b = gr.Label(value="Word count: 0")
# Two listeners on the same textbox: a live word counter, and a warning when
# the participant types before an AI answer has been generated.
modification_suggestions.change(fn=update_word_count, inputs=modification_suggestions, outputs=word_count_display_b)
modification_suggestions.change(fn=check_initial_generated, inputs = [initial_answer] )
with gr.Row():
create_final_btn = gr.Button("Review")
with gr.Row():
final_answer = gr.Textbox(label="Final Answer", interactive=False)
# Rebinds the name `session_index` to a new hidden component for this group;
# handlers wired earlier keep the component they were given at wiring time.
session_index = gr.Number(label="Session Index", visible=False)
create_initial_btn.click(
fn=lambda task, id: handle_create_reverse_sequential(task, session_manager, api_key, id),
inputs=[description, identification_code],
outputs=[initial_answer, session_index]
)
# Save as soon as the AI answer changes, i.e. before any refinement exists.
initial_answer.change(
fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID3),
inputs = [session_index]
)
# "Review" stores the refinement; handle_modify_reverse_sequential returns
# either the text itself or a word-limit error message.
create_final_btn.click(
fn=lambda session_index, modification_suggestions: handle_modify_reverse_sequential(session_index, modification_suggestions, session_manager, api_key),
inputs=[session_index, modification_suggestions],
outputs=[final_answer, session_index]
)
#evaluate_btn = gr.Button("Evaluate")
evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive=False)
# Chain: final answer change -> evaluation -> save.
final_answer.change(
fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
inputs=[session_index],
outputs=[evaluation_result]
)
# NOTE(review): no `outputs` are declared although on_textbox_change
# returns a status string.
evaluation_result.change(
fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID3),
inputs = [session_index]
)
save_btn = gr.Button("Save Data", elem_id="save_btn")
save_result = gr.Label()
save_btn.click(
fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID3),
inputs=[session_index],
outputs=[save_result]
)
# --- Group C: human and AI answer independently (handle_create_parallel),
# then a merged answer is shown. Results are saved with SHEET_ID2. ---
with group_c_content:
with gr.Row():
human_input = gr.Textbox(label="Enter each idea on a new line (Shift+Enter), starting with '1', '2', and '3'.", placeholder="Please propose 3 ideas to help Airbnb’s business model align with 17 SDGs (At least 100 words)")
word_count_display_c = gr.Label(value="Word count: 0")
# Live word counter for the input box.
human_input.change(fn=update_word_count, inputs=human_input, outputs=word_count_display_c)
with gr.Row():
create_btn = gr.Button("Submit & See AI Output")
with gr.Row():
ai_initial_output = gr.Textbox(label="AI Output Generated Independently", interactive=False)
with gr.Row():
merge_btn = gr.Button("Merge Using the Second AI")
with gr.Row():
final_output = gr.Textbox(label="Final Merged Output", interactive=False)
# Hidden field carrying the session key between events for this group.
session_index = gr.Number(label="Session Index", visible=False)
# handle_create_parallel already computes and stores the merged answer in
# the session; only the independent AI answer is displayed at this point.
create_btn.click(
fn=lambda task, human_input, id: handle_create_parallel(task, human_input, session_manager, api_key, id),
inputs=[description, human_input, identification_code],
outputs=[ai_initial_output, session_index]
)
# NOTE(review): display_merged_output is not defined in this file —
# presumably it comes from one of the star imports; verify.
merge_btn.click(
fn= lambda session_index : display_merged_output(session_index, session_manager),
inputs = [session_index],
outputs=[final_output]
)
#evaluate_btn = gr.Button("Evaluate")
evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive = False)
# Chain: merged output change -> evaluation -> save.
final_output.change(
fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
inputs=[session_index],
outputs=[evaluation_result]
)
# NOTE(review): no `outputs` are declared although on_textbox_change
# returns a status string.
evaluation_result.change(
fn = lambda session_index: on_textbox_change(session_index, session_manager, service, SHEET_ID2),
inputs = [session_index]
)
save_btn = gr.Button("Save Data",elem_id="save_btn")
save_result = gr.Label()
save_btn.click(
fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID2),
inputs=[session_index],
outputs=[save_result]
)
# Start the app, passing share=True to launch().
app.launch(share=True)