# Hugging Face commit metadata (web-UI residue, kept as comments so the file parses):
# drmjh — "Update log name to bot" — bfae8ff
"""
Cognitive Debriefing App - Respondent Interface
Author: Dr Musashi Hinck
Respondent-facing app. Reads arguments from request (in form of shareable link)
Change Log:
- 2024.01.16: Continuous logging to wandb, change name of run to `userid`
"""
from __future__ import annotations
import os
import logging
import json
import wandb
import gradio as gr
import openai
from base64 import urlsafe_b64decode
logger = logging.getLogger(__name__)
from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key
# %% Initialization
# Seed the OpenAI key only when the environment does not already provide one.
# BUG FIX: the lookup key was written as f"OPENAI_API_KEY" — an f-string with
# no placeholders; the prefix was noise and is removed.
if os.environ.get("OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
    seed_openai_key()
client = openai.OpenAI()
# %% (functions)
def decode_config(config_dta: str) -> dict[str, str | float]:
    """Decode a base64url-encoded JSON string into a configuration dict.

    Args:
        config_dta: base64url-encoded JSON, as carried by the shareable
            link's ``dta`` query parameter.

    Returns:
        The decoded configuration mapping.
    """
    # Re-pad before decoding: base64 tokens passed through URLs often arrive
    # with the trailing '=' stripped, and urlsafe_b64decode raises on
    # mis-padded input.
    padded = config_dta + "=" * (-len(config_dta) % 4)
    config_bytes: bytes = urlsafe_b64decode(padded)
    config: dict = json.loads(config_bytes)
    return config
def load_config(request: gr.Request):
    """Extract the interview configuration from the request's ``dta`` query param."""
    cfg = decode_config(request.query_params["dta"])
    model_args = {"model": cfg["model"], "temperature": cfg["temperature"]}
    return (
        cfg["question"],
        cfg["template"],
        cfg["initial_message"],
        model_args,
        cfg["userid"],
    )
# Post-loading
def update_template(question: str, template: PromptTemplate | str) -> str:
    """
    Render the prompt template with the survey question.

    Only a "question" variable is substituted for now; templates that do not
    declare that variable are returned as plain text. Further template
    variables can be added later.
    """
    tpl = PromptTemplate(template) if isinstance(template, str) else template
    if "question" not in tpl.variables:
        return str(tpl)
    return tpl.format(question=question)
def reset_interview() -> tuple[
    list[list[str | None]], gr.Button, gr.Button, gr.Button, gr.Button
]:
    """Finish the wandb run and restore the pre-interview UI state.

    Returns:
        Cleared chat history plus four button updates restoring initial
        visibility. BUG FIX: the annotation previously claimed a 3-tuple
        while five values are returned.
    """
    wandb.finish()
    gr.Info("Interview reset.")
    # NOTE(review): resetButton.click wires only three outputs to this
    # 5-tuple — confirm the last two button updates are actually consumed.
    return (
        [],
        gr.Button("Start Interview", visible=True),
        gr.Button("Reply", visible=False),
        gr.Button("Save Survey", visible=False, variant="secondary"),
        gr.Button("Save and Exit", visible=False, variant="stop"),
    )
def initialize_interview(
    system_message: str, first_question: str, model_args: dict[str, str | float]
) -> tuple[list[list[str | None]], gr.Textbox, gr.Button, gr.Button, gr.Button]:
    "Read system prompt and start interview"
    # When no fixed opening question is configured, ask the model to
    # generate one from the system prompt alone.
    # NOTE(review): call_openai contains `yield`, making it a generator
    # function — this call returns a generator object rather than a string,
    # so the first chat message may not render as text. Confirm and fix in
    # call_openai.
    if len(first_question) == 0:
        first_question = call_openai(
            [], system_message, client, model_args, stream=False
        )
    # Use fixed prompt
    chat_history = [[None, first_question]]
    # Seed the chat with the opening question, enable the input/submit
    # controls, hide Start, and reveal Save-and-Exit.
    return (
        chat_history,
        gr.Textbox(
            placeholder="Type response here.", interactive=True, show_label=False
        ),
        gr.Button(variant="primary", interactive=True),
        gr.Button("Start Interview", visible=False),
        gr.Button("Save and Exit", visible=True, variant="stop"),
    )
def initialize_tracker(
    model_args: dict[str, str | float],
    question: str,
    template: PromptTemplate,
    userid: str = "",
) -> None:
    """Initialize a wandb run for this interview.

    Args:
        model_args: Model name and temperature used for the interview.
        question: Survey question under test.
        template: System-prompt template (stringified for logging).
        userid: Respondent identifier; becomes the wandb run name.
    """
    # BUG FIX: the signature read `userid=str`, which defaulted the value to
    # the `str` type itself instead of annotating the parameter.
    run_config = model_args | {
        "question": question,
        "template": str(template),
        "userid": userid,
    }
    wandb.init(
        project="cognitive-debrief", name=userid, config=run_config, tags=["dev"]
    )
def save_interview(
    chat_history: list[list[str | None]],
) -> None:
    """Flatten the chat transcript into (role, message) rows and log to wandb."""
    # Each history entry is a [user, bot] pair; drop empty slots.
    rows = [
        [role, text]
        for pair in chat_history
        for role, text in zip(["user", "bot"], pair)
        if text is not None
    ]
    transcript = wandb.Table(data=rows, columns=["role", "message"])
    logger.info("Uploading interview transcript to WandB...")
    wandb.log({"chat_history": transcript})
    logger.info("Uploading complete.")
def call_openai(
    messages: list[dict[str, str]],
    system_message: str | None,
    client: openai.Client,
    model_args: dict,
    stream: bool = False,
):
    """Utility function for calling OpenAI chat. Expects formatted messages.

    Args:
        messages: Chat messages in OpenAI format.
        system_message: Optional system prompt, prepended when given.
        client: OpenAI client used for the request.
        model_args: Extra keyword arguments (model, temperature, ...).
        stream: When True, return a generator of content deltas; otherwise
            return the full response text.

    Returns:
        str (stream=False), a generator of str chunks (stream=True), or
        None when the API call failed.
    """
    if not messages:
        messages = []
    if system_message:
        messages = [{"role": "system", "content": system_message}] + messages
    try:
        response = client.chat.completions.create(
            messages=messages, **model_args, stream=stream
        )
    # BUG FIX: `except A | B` is not valid exception matching (a tuple is
    # required); the old form raised TypeError when the handler fired.
    except (openai.APIConnectionError, openai.APIStatusError) as e:
        # NOTE(review): APIConnectionError carries no status_code/response;
        # getattr guards the handler against an AttributeError.
        error_msg = (
            "API unreachable.\n"
            f"STATUS_CODE: {getattr(e, 'status_code', 'n/a')}"
            f"ERROR: {getattr(e, 'response', e)}"
        )
        gr.Error(error_msg)
        logger.error(error_msg)
        return None
    except openai.RateLimitError:
        warning_msg = "Hit rate limit. Wait a moment and retry."
        gr.Warning(warning_msg)
        logger.warning(warning_msg)
        return None
    if stream:
        # BUG FIX: the old body mixed `yield` and `return value`, turning the
        # whole function into a generator — callers with stream=False got a
        # generator object, never the text. The yield also read
        # `.message.content`; streamed chunks carry text in `.delta.content`.
        def _iter_chunks():
            for chunk in response:
                yield chunk.choices[0].delta.content

        return _iter_chunks()
    return response.choices[0].message.content
def user_message(
    message: str, chat_history: list[list[str | None]]
) -> tuple[str, list[list[str | None]]]:
    """Append the user's message to the chat right away and clear the textbox."""
    updated = list(chat_history)
    updated.append([message, None])  # bot slot filled later by bot_message
    return "", updated
def bot_message(
    chat_history: list[list[str | None]],
    system_message: str,
    model_args: dict[str, str | float],
) -> list[list[str | None]]:
    """Stream the assistant's reply into the last chat row, yielding per token."""
    # Rebuild the OpenAI-format conversation: system prompt, prior turns,
    # then the just-entered user message (last row, bot slot still empty).
    latest_user = chat_history[-1][0]
    messages = [{"role": "system", "content": system_message}]
    messages += convert_gradio_to_openai(chat_history[:-1])
    messages += [{"role": "user", "content": latest_user}]
    completion = client.chat.completions.create(
        messages=messages, stream=True, **model_args
    )
    # Accumulate streamed deltas into the bot slot, re-rendering on each token.
    chat_history[-1][1] = ""
    for part in completion:
        token = part.choices[0].delta.content
        if token:
            chat_history[-1][1] += token
            yield chat_history
# LAYOUT
with gr.Blocks() as demo:
    gr.Markdown("# Cognitive Debriefing Prototype")
    # Hidden values — populated by load_config from the shareable link's
    # query parameters when the respondent clicks "Start Interview".
    surveyQuestion = gr.Textbox(visible=False)
    surveyTemplate = gr.Textbox(visible=False)
    initialMessage = gr.Textbox(visible=False)
    systemMessage = gr.Textbox(visible=False)
    modelArgs = gr.State(value={"model": "", "temperature": ""})
    userid = gr.Textbox(visible=False, interactive=False)
    ## RESPONDENT
    chatDisplay = gr.Chatbot(
        show_label=False,
    )
    with gr.Row():
        chatInput = gr.Textbox(
            placeholder="Click 'Start Interview' to begin.",
            interactive=False,
            show_label=False,
            scale=10,
        )
        chatSubmit = gr.Button(
            "",
            variant="secondary",
            interactive=False,
            icon="./arrow_icon.svg",
        )
    startInterview = gr.Button("Start Interview", variant="primary")
    resetButton = gr.Button("Save and Exit", visible=False, variant="stop")
    ## INTERACTIONS
    # Start Interview button
    # Chain: read link config -> render system prompt -> render first
    # message -> seed chat UI -> start the wandb tracker.
    startInterview.click(
        load_config,
        inputs=None,
        outputs=[
            surveyQuestion,
            surveyTemplate,
            initialMessage,
            modelArgs,
            userid,
        ],
    ).then(
        update_template,
        inputs=[surveyQuestion, surveyTemplate],
        outputs=[systemMessage],
    ).then(
        update_template,
        inputs=[surveyQuestion, initialMessage],
        outputs=initialMessage,
    ).then(
        initialize_interview,
        inputs=[systemMessage, initialMessage, modelArgs],
        outputs=[
            chatDisplay,
            chatInput,
            chatSubmit,
            startInterview,
            resetButton,
        ],
    ).then(
        initialize_tracker, inputs=[modelArgs, surveyQuestion, surveyTemplate, userid]
    )
    # "Enter" on textbox
    # Echo the user message immediately, stream the bot reply, then
    # checkpoint the transcript to wandb after every exchange.
    chatInput.submit(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        save_interview, inputs=[chatDisplay]
    )
    # "Submit" button — same pipeline as pressing Enter.
    chatSubmit.click(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        save_interview, inputs=[chatDisplay]
    )
    # Save-and-Exit: persist the transcript, then reset the UI.
    # NOTE(review): reset_interview returns five values but only three
    # outputs are wired here — confirm all button states reset as intended.
    resetButton.click(save_interview, [chatDisplay]).then(
        reset_interview,
        outputs=[chatDisplay, startInterview, resetButton],
        show_progress=False,
    )
if __name__ == "__main__":
    demo.launch()