""" Cognitive Debriefing App - Respondent Interface Author: Dr Musashi Hinck Respondent-facing app. Reads arguments from request (in form of shareable link) Change Log: - 2024.01.16: Continuous logging to wandb, change name of run to `userid` """ from __future__ import annotations import os import logging import json import wandb import gradio as gr import openai from base64 import urlsafe_b64decode logger = logging.getLogger(__name__) from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key # %% Initialization if os.environ.get(f"OPENAI_API_KEY", "DEFAULT") == "DEFAULT": seed_openai_key() client = openai.OpenAI() # %% (functions) def decode_config(config_dta: str) -> dict[str, str | float]: "Read base64_url encoded json and loads into configuration" config_str: str = urlsafe_b64decode(config_dta) config: dict = json.loads(config_str) return config def load_config(request: gr.Request): "Read parameters from request header" config = decode_config(request.query_params["dta"]) survey_question = config["question"] survey_template = config["template"] initial_message = config["initial_message"] model_args = {"model": config["model"], "temperature": config["temperature"]} userid = config["userid"] return survey_question, survey_template, initial_message, model_args, userid # Post-loading def update_template(question: str, template: PromptTemplate | str) -> str: """ Updates templates. Currently only accepts a "question" variable, but can add future templating in the future. """ if isinstance(template, str): template = PromptTemplate(template) if "question" in template.variables: return template.format(question=question) else: return str(template) def reset_interview() -> tuple[list[list[str | None]], gr.Button, gr.Button]: wandb.finish() gr.Info("Interview reset.") return ( [], gr.Button("Start Interview", visible=True), gr.Button("Reply", visible=False), gr.Button("Save Survey", visible=False, variant="secondary"), gr.Button("Save and Exit", visible=False, variant="stop"), ) def initialize_interview( system_message: str, first_question: str, model_args: dict[str, str | float] ) -> tuple[list[list[str | None]], gr.Textbox, gr.Button, gr.Button]: "Read system prompt and start interview" if len(first_question) == 0: first_question = call_openai( [], system_message, client, model_args, stream=False ) # Use fixed prompt chat_history = [[None, first_question]] return ( chat_history, gr.Textbox( placeholder="Type response here.", interactive=True, show_label=False ), gr.Button(variant="primary", interactive=True), gr.Button("Start Interview", visible=False), gr.Button("Save and Exit", visible=True, variant="stop"), ) def initialize_tracker( model_args: dict[str, str | float], question: str, template: PromptTemplate, userid=str, ) -> None: "Initializes wandb run for interview" run_config = model_args | { "question": question, "template": str(template), "userid": userid, } wandb.init( project="cognitive-debrief", name=userid, config=run_config, tags=["dev"] ) def save_interview( chat_history: list[list[str | None]], ) -> None: chat_data = [] for pair in chat_history: for i, role in enumerate(["user", "bot"]): if pair[i] is not None: chat_data += [[role, pair[i]]] chat_table = wandb.Table(data=chat_data, columns=["role", "message"]) logger.info("Uploading interview transcript to WandB...") wandb.log({"chat_history": chat_table}) logger.info("Uploading complete.") def call_openai( messages: list[dict[str, str]], system_message: str | None, client: openai.Client, model_args: dict, stream: bool = False, ): "Utility function for calling OpenAI chat. Expects formatted messages." if not messages: messages = [] if system_message: messages = [{"role": "system", "content": system_message}] + messages try: response = client.chat.completions.create( messages=messages, **model_args, stream=stream ) if stream: for chunk in response: yield chunk.choices[0].message.content else: content = response.choices[0].message.content return content except openai.APIConnectionError | openai.APIStatusError as e: error_msg = ( "API unreachable.\n" f"STATUS_CODE: {e.status_code}" f"ERROR: {e.response}" ) gr.Error(error_msg) logger.error(error_msg) except openai.RateLimitError: warning_msg = "Hit rate limit. Wait a moment and retry." gr.Warning(warning_msg) logger.warning(warning_msg) def user_message( message: str, chat_history: list[list[str | None]] ) -> tuple[str, list[list[str | None]]]: "Displays user message immediately." return "", chat_history + [[message, None]] def bot_message( chat_history: list[list[str | None]], system_message: str, model_args: dict[str, str | float], ) -> list[list[str | None]]: # Prep messages user_msg = chat_history[-1][0] messages = convert_gradio_to_openai(chat_history[:-1]) messages = ( [{"role": "system", "content": system_message}] + messages + [{"role": "user", "content": user_msg}] ) response = client.chat.completions.create( messages=messages, stream=True, **model_args ) # Streaming chat_history[-1][1] = "" for chunk in response: delta = chunk.choices[0].delta.content if delta: chat_history[-1][1] += delta yield chat_history # LAYOUT with gr.Blocks() as demo: gr.Markdown("# Cognitive Debriefing Prototype") # Hidden values surveyQuestion = gr.Textbox(visible=False) surveyTemplate = gr.Textbox(visible=False) initialMessage = gr.Textbox(visible=False) systemMessage = gr.Textbox(visible=False) modelArgs = gr.State(value={"model": "", "temperature": ""}) userid = gr.Textbox(visible=False, interactive=False) ## RESPONDENT chatDisplay = gr.Chatbot( show_label=False, ) with gr.Row(): chatInput = gr.Textbox( placeholder="Click 'Start Interview' to begin.", interactive=False, show_label=False, scale=10, ) chatSubmit = gr.Button( "", variant="secondary", interactive=False, icon="./arrow_icon.svg", ) startInterview = gr.Button("Start Interview", variant="primary") resetButton = gr.Button("Save and Exit", visible=False, variant="stop") ## INTERACTIONS # Start Interview button startInterview.click( load_config, inputs=None, outputs=[ surveyQuestion, surveyTemplate, initialMessage, modelArgs, userid, ], ).then( update_template, inputs=[surveyQuestion, surveyTemplate], outputs=[systemMessage], ).then( update_template, inputs=[surveyQuestion, initialMessage], outputs=initialMessage, ).then( initialize_interview, inputs=[systemMessage, initialMessage, modelArgs], outputs=[ chatDisplay, chatInput, chatSubmit, startInterview, resetButton, ], ).then( initialize_tracker, inputs=[modelArgs, surveyQuestion, surveyTemplate, userid] ) # "Enter" on textbox chatInput.submit( user_message, inputs=[chatInput, chatDisplay], outputs=[chatInput, chatDisplay], queue=False, ).then( bot_message, inputs=[chatDisplay, systemMessage, modelArgs], outputs=[chatDisplay], ).then( save_interview, inputs=[chatDisplay] ) # "Submit" button chatSubmit.click( user_message, inputs=[chatInput, chatDisplay], outputs=[chatInput, chatDisplay], queue=False, ).then( bot_message, inputs=[chatDisplay, systemMessage, modelArgs], outputs=[chatDisplay], ).then( save_interview, inputs=[chatDisplay] ) resetButton.click(save_interview, [chatDisplay]).then( reset_interview, outputs=[chatDisplay, startInterview, resetButton], show_progress=False, ) if __name__ == "__main__": demo.launch()