Spaces:
Runtime error
Runtime error
"""
Cognitive Debriefing App - Respondent Interface
Author: Dr Musashi Hinck

Respondent-facing app. Reads its configuration from the request's query
parameters, which are supplied through a shareable link.

Change Log:
- 2024.01.16: Continuous logging to wandb; the run is now named after `userid`.
"""
from __future__ import annotations

import json
import logging
import os
from base64 import urlsafe_b64decode
from collections.abc import Iterator

import gradio as gr
import openai
import wandb

from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key

logger = logging.getLogger(__name__)
# %% Initialization
# Ask the project helper for a key only when the environment does not already
# provide a usable one: unset, empty, or the "DEFAULT" sentinel all count as missing.
if os.environ.get("OPENAI_API_KEY", "DEFAULT") in ("", "DEFAULT"):
    seed_openai_key()
# Module-level client used by initialize_interview and bot_message below.
client = openai.OpenAI()
# %% (functions)
def decode_config(config_dta: str) -> dict[str, str | float]:
    """
    Decode a URL-safe base64 string that holds a JSON object into a config dict.

    Missing ``=`` padding is restored before decoding. A link whose trailing
    padding was stripped therefore still decodes, and correctly padded input
    is left unchanged.

    Raises:
        ValueError: (including ``binascii.Error``) if the base64 is malformed.
        json.JSONDecodeError: if the decoded payload is not valid JSON.
    """
    # base64 works on 4-character groups; -len % 4 is the number of '=' missing.
    padded = config_dta + "=" * (-len(config_dta) % 4)
    config_bytes: bytes = urlsafe_b64decode(padded)
    config: dict = json.loads(config_bytes)
    return config
def load_config(request: gr.Request):
    """
    Read the interview configuration from the ``dta`` query parameter.

    ``dta`` is decoded with ``decode_config``. The resulting dict must contain
    ``question``, ``template``, ``initial_message``, ``model``,
    ``temperature`` and ``userid``.

    Returns:
        (survey_question, survey_template, initial_message, model_args, userid),
        where ``model_args`` is ``{"model": ..., "temperature": ...}``.

    Raises:
        gr.Error: if the link has no ``dta`` parameter or it cannot be decoded.
    """
    dta = request.query_params.get("dta")
    if not dta:
        raise gr.Error(
            "This link is missing its survey configuration ('dta' parameter)."
        )
    try:
        config = decode_config(dta)
    except ValueError as e:
        # Covers bad base64 (binascii.Error), bad UTF-8 and bad JSON.
        logger.warning("Could not decode 'dta' query parameter: %s", e)
        raise gr.Error(
            "Could not read the survey configuration from this link."
        ) from e
    survey_question = config["question"]
    survey_template = config["template"]
    initial_message = config["initial_message"]
    model_args = {"model": config["model"], "temperature": config["temperature"]}
    userid = config["userid"]
    return survey_question, survey_template, initial_message, model_args, userid
# Post-loading
def update_template(question: str, template: PromptTemplate | str) -> str:
    """
    Fill the ``question`` slot of a prompt template, if it has one.

    A plain string is wrapped in ``PromptTemplate`` first. Only the
    ``question`` variable is substituted for now; a template that does not
    declare it is returned as its plain text.
    """
    prompt = PromptTemplate(template) if isinstance(template, str) else template
    if "question" not in prompt.variables:
        return str(prompt)
    return prompt.format(question=question)
def reset_interview() -> tuple[list[list[str | None]], gr.Button, gr.Button]:
    """
    End the current interview and return the UI to its pre-interview state.

    Finishes the active wandb run and shows an info toast. The three returned
    values match the outputs wired to the Save-and-Exit button, in order:
    the chat display (cleared), the "Start Interview" button (shown again)
    and the "Save and Exit" button (hidden, keeping its own label).
    """
    wandb.finish()
    gr.Info("Interview reset.")
    return (
        [],
        gr.Button("Start Interview", visible=True),
        gr.Button("Save and Exit", visible=False, variant="stop"),
    )
def initialize_interview(
    system_message: str, first_question: str, model_args: dict[str, str | float]
) -> tuple[list[list[str | None]], gr.Textbox, gr.Button, gr.Button, gr.Button]:
    """
    Open the interview with its first interviewer message and unlock the chat.

    If no fixed opening message is configured (empty or None), one is
    generated by the model from ``system_message`` via ``call_openai``.

    Returns updates for, in order: the chat display (seeded with the opening
    bot turn), the chat input (enabled), the send button (enabled, primary),
    the "Start Interview" button (hidden) and the "Save and Exit" button
    (shown).

    Raises:
        gr.Error: if no opening message could be obtained.
    """
    if not first_question:
        # No fixed opening message configured: ask the model to write one.
        first_question = call_openai(
            [], system_message, client, model_args, stream=False
        )
    if not first_question:
        raise gr.Error("Could not generate the opening question. Please try again.")
    # The opening turn has no user half; the bot speaks first.
    chat_history = [[None, first_question]]
    return (
        chat_history,
        gr.Textbox(
            placeholder="Type response here.", interactive=True, show_label=False
        ),
        gr.Button(variant="primary", interactive=True),
        gr.Button("Start Interview", visible=False),
        gr.Button("Save and Exit", visible=True, variant="stop"),
    )
def initialize_tracker(
    model_args: dict[str, str | float],
    question: str,
    template: PromptTemplate | str,
    userid: str | None = None,
) -> None:
    """
    Start the wandb run that records this interview.

    The run goes to the "cognitive-debrief" project with the "dev" tag and is
    named after ``userid``. Its config is ``model_args`` merged with the survey
    question, the template text and the user id. ``model_args`` itself is
    not mutated.

    The former default ``userid=str`` made the *type object* the default.
    ``None`` now lets wandb choose a run name when no id is supplied.
    """
    run_config = model_args | {
        "question": question,
        "template": str(template),
        "userid": userid,
    }
    wandb.init(
        project="cognitive-debrief", name=userid, config=run_config, tags=["dev"]
    )
def save_interview(
    chat_history: list[list[str | None]],
) -> None:
    """
    Log the current transcript to the active wandb run as a two-column table.

    Each ``[user, bot]`` turn contributes up to two ``[role, message]`` rows,
    user first. Halves that are ``None`` are skipped, for example the bot-only
    opening turn.
    """
    rows = [
        [role, turn[idx]]
        for turn in chat_history
        for idx, role in enumerate(("user", "bot"))
        if turn[idx] is not None
    ]
    transcript = wandb.Table(data=rows, columns=["role", "message"])
    logger.info("Uploading interview transcript to WandB...")
    wandb.log({"chat_history": transcript})
    logger.info("Uploading complete.")
def call_openai(
    messages: list[dict[str, str]] | None,
    system_message: str | None,
    client: openai.Client,
    model_args: dict,
    stream: bool = False,
) -> str | Iterator[str] | None:
    """
    Call the OpenAI chat-completions endpoint with an optional system prompt.

    Args:
        messages: OpenAI-format messages (``{"role": ..., "content": ...}``).
            ``None`` or empty means no prior conversation. Not mutated.
        system_message: if truthy, prepended as a ``system`` message.
        client: OpenAI client used for the request.
        model_args: extra keyword arguments for ``chat.completions.create``
            (e.g. ``model``, ``temperature``).
        stream: if False, return the reply text. If True, return an iterator
            over the non-empty text fragments of the streamed reply.

    Returns:
        The reply text (``stream=False``) or an iterator of text fragments
        (``stream=True``). On a rate-limit error a warning is shown instead,
        and ``None`` (or an empty iterator when streaming) is returned.

    Raises:
        gr.Error: if the API is unreachable or returns an error status.
    """
    # Build a new list so the caller's messages are never modified.
    full_messages = list(messages or [])
    if system_message:
        full_messages.insert(0, {"role": "system", "content": system_message})
    try:
        response = client.chat.completions.create(
            messages=full_messages, **model_args, stream=stream
        )
    # RateLimitError subclasses APIStatusError, so it must be caught first.
    except openai.RateLimitError:
        warning_msg = "Hit rate limit. Wait a moment and retry."
        gr.Warning(warning_msg)
        logger.warning(warning_msg)
        return iter(()) if stream else None
    except openai.APIStatusError as e:
        error_msg = (
            "API returned an error.\n"
            f"STATUS_CODE: {e.status_code}\n"
            f"ERROR: {e.message}"
        )
        logger.error(error_msg)
        raise gr.Error(error_msg) from e
    except openai.APIConnectionError as e:
        # Connection failures carry no HTTP status code.
        error_msg = f"API unreachable.\nERROR: {e}"
        logger.error(error_msg)
        raise gr.Error(error_msg) from e
    # Keep the generator separate: a `yield` in this function would turn even
    # the stream=False path into a generator that never returns the text.
    if stream:
        return _iter_stream_content(response)
    return response.choices[0].message.content


def _iter_stream_content(response) -> Iterator[str]:
    "Yield the non-empty text deltas from a streamed chat-completions response."
    for chunk in response:
        # Guard against chunks that carry no choices (e.g. usage-only chunks).
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            yield delta
def user_message(
    message: str, chat_history: list[list[str | None]]
) -> tuple[str, list[list[str | None]]]:
    """
    Append the respondent's message as a new, not-yet-answered chat turn.

    Returns an empty string (used to clear the input box) and a new history
    list ending in ``[message, None]``. The passed-in list is not modified.
    """
    pending_turn = [message, None]
    cleared_box = ""
    return cleared_box, chat_history + [pending_turn]
def bot_message(
    chat_history: list[list[str | None]],
    system_message: str,
    model_args: dict[str, str | float],
) -> Iterator[list[list[str | None]]]:
    """
    Stream the model's reply into the last (unanswered) chat turn.

    The prompt is built from three parts: the system message, the earlier
    turns converted with ``convert_gradio_to_openai``, and the pending user
    message. The reply accumulates in place in ``chat_history[-1][1]``. The
    history is yielded after every non-empty text fragment so the chat
    display updates live.

    Raises:
        gr.Error: if the OpenAI request fails before or during streaming.
    """
    user_msg = chat_history[-1][0]
    messages = (
        [{"role": "system", "content": system_message}]
        + convert_gradio_to_openai(chat_history[:-1])
        + [{"role": "user", "content": user_msg}]
    )
    try:
        response = client.chat.completions.create(
            messages=messages, stream=True, **model_args
        )
        chat_history[-1][1] = ""
        for chunk in response:
            # Guard against chunks that carry no choices (e.g. usage-only chunks).
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                chat_history[-1][1] += delta
                yield chat_history
    except openai.APIError as e:
        logger.error("OpenAI request failed: %s", e)
        raise gr.Error(
            "The interviewer could not respond. Please wait a moment and try again."
        ) from e
# LAYOUT
def _lock_chat_controls() -> tuple[gr.Textbox, gr.Button]:
    "Disable the chat input and send button, restoring their pre-interview look."
    return (
        gr.Textbox(placeholder="Click 'Start Interview' to begin.", interactive=False),
        gr.Button(variant="secondary", interactive=False),
    )


with gr.Blocks() as demo:
    gr.Markdown("# Cognitive Debriefing Prototype")
    # Hidden values, filled from the shareable link when the interview starts.
    surveyQuestion = gr.Textbox(visible=False)
    surveyTemplate = gr.Textbox(visible=False)
    initialMessage = gr.Textbox(visible=False)
    systemMessage = gr.Textbox(visible=False)
    modelArgs = gr.State(value={"model": "", "temperature": ""})
    userid = gr.Textbox(visible=False, interactive=False)
    ## RESPONDENT
    chatDisplay = gr.Chatbot(
        show_label=False,
    )
    with gr.Row():
        chatInput = gr.Textbox(
            placeholder="Click 'Start Interview' to begin.",
            interactive=False,
            show_label=False,
            scale=10,
        )
        chatSubmit = gr.Button(
            "",
            variant="secondary",
            interactive=False,
            icon="./arrow_icon.svg",
        )
    startInterview = gr.Button("Start Interview", variant="primary")
    resetButton = gr.Button("Save and Exit", visible=False, variant="stop")
    ## INTERACTIONS
    # Start Interview: load config from the link, fill both templates, open the
    # chat, then start the wandb run.
    startInterview.click(
        load_config,
        inputs=None,
        outputs=[
            surveyQuestion,
            surveyTemplate,
            initialMessage,
            modelArgs,
            userid,
        ],
    ).then(
        update_template,
        inputs=[surveyQuestion, surveyTemplate],
        outputs=[systemMessage],
    ).then(
        update_template,
        inputs=[surveyQuestion, initialMessage],
        outputs=initialMessage,
    ).then(
        initialize_interview,
        inputs=[systemMessage, initialMessage, modelArgs],
        outputs=[
            chatDisplay,
            chatInput,
            chatSubmit,
            startInterview,
            resetButton,
        ],
    ).then(
        initialize_tracker, inputs=[modelArgs, surveyQuestion, surveyTemplate, userid]
    )
    # "Enter" on the textbox and the "Submit" button run the same pipeline:
    # echo the user's message, stream the bot reply, then log the transcript.
    for send_trigger in (chatInput.submit, chatSubmit.click):
        send_trigger(
            user_message,
            inputs=[chatInput, chatDisplay],
            outputs=[chatInput, chatDisplay],
            queue=False,
        ).then(
            bot_message,
            inputs=[chatDisplay, systemMessage, modelArgs],
            outputs=[chatDisplay],
        ).then(
            save_interview, inputs=[chatDisplay]
        )
    # Save and Exit: log the transcript, then reset the UI (which finishes the
    # wandb run). Finally, lock the chat controls: a message sent after the run
    # is finished would make save_interview log to a run that no longer exists.
    resetButton.click(save_interview, [chatDisplay]).then(
        reset_interview,
        outputs=[chatDisplay, startInterview, resetButton],
        show_progress=False,
    ).then(
        _lock_chat_controls,
        outputs=[chatInput, chatSubmit],
        show_progress=False,
    )

if __name__ == "__main__":
    demo.launch()