Spaces:
Runtime error
Runtime error
Updates - mostly wandb
Browse files- app.py +70 -36
- arrow_icon.svg +1 -0
- utils.py +17 -0
app.py
CHANGED
|
@@ -6,12 +6,14 @@ Author: Dr Musashi Hinck
|
|
| 6 |
|
| 7 |
Respondent-facing app. Reads arguments from request (in form of shareable link)
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
"""
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
| 15 |
import logging
|
| 16 |
import json
|
| 17 |
import wandb
|
|
@@ -22,10 +24,12 @@ from base64 import urlsafe_b64decode
|
|
| 22 |
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
-
from utils import PromptTemplate, convert_gradio_to_openai
|
| 26 |
|
| 27 |
|
| 28 |
# %% Initialization
|
|
|
|
|
|
|
| 29 |
client = openai.OpenAI()
|
| 30 |
|
| 31 |
|
|
@@ -39,14 +43,15 @@ def decode_config(config_dta: str) -> dict[str, str | float]:
|
|
| 39 |
|
| 40 |
def load_config(request: gr.Request):
|
| 41 |
"Read parameters from request header"
|
| 42 |
-
config = decode_config(request.query_params[
|
| 43 |
-
survey_question = config[
|
| 44 |
-
survey_template = config[
|
| 45 |
-
initial_message = config[
|
| 46 |
-
model_args = {
|
| 47 |
-
userid = config[
|
| 48 |
return survey_question, survey_template, initial_message, model_args, userid
|
| 49 |
|
|
|
|
| 50 |
# Post-loading
|
| 51 |
def update_template(question: str, template: PromptTemplate | str) -> str:
|
| 52 |
"""
|
|
@@ -87,17 +92,27 @@ def initialize_interview(
|
|
| 87 |
gr.Textbox(
|
| 88 |
placeholder="Type response here.", interactive=True, show_label=False
|
| 89 |
),
|
|
|
|
| 90 |
gr.Button("Start Interview", visible=False),
|
| 91 |
gr.Button("Save and Exit", visible=True, variant="stop"),
|
| 92 |
)
|
| 93 |
|
| 94 |
|
| 95 |
def initialize_tracker(
|
| 96 |
-
model_args: dict[str, str | float],
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
| 98 |
"Initializes wandb run for interview"
|
| 99 |
-
run_config = model_args | {
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
def save_interview(
|
|
@@ -109,9 +124,9 @@ def save_interview(
|
|
| 109 |
if pair[i] is not None:
|
| 110 |
chat_data += [[role, pair[i]]]
|
| 111 |
chat_table = wandb.Table(data=chat_data, columns=["role", "message"])
|
| 112 |
-
|
| 113 |
wandb.log({"chat_history": chat_table})
|
| 114 |
-
|
| 115 |
|
| 116 |
|
| 117 |
def call_openai(
|
|
@@ -190,28 +205,30 @@ with gr.Blocks() as demo:
|
|
| 190 |
initialMessage = gr.Textbox(visible=False)
|
| 191 |
systemMessage = gr.Textbox(visible=False)
|
| 192 |
modelArgs = gr.State(value={"model": "", "temperature": ""})
|
| 193 |
-
userid = gr.Textbox(visible=False)
|
| 194 |
-
|
| 195 |
-
# Debugging
|
| 196 |
-
# with gr.Accordion("Debugging Panel", open=False):
|
| 197 |
-
# debugPane = gr.Textbox(show_label=False, lines=8)
|
| 198 |
-
# debugRequest = gr.Button('Read Request')
|
| 199 |
-
# debugRequest.click(load_config, outputs=[debugPane])
|
| 200 |
-
|
| 201 |
|
| 202 |
## RESPONDENT
|
| 203 |
chatDisplay = gr.Chatbot(
|
| 204 |
show_label=False,
|
| 205 |
)
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
startInterview = gr.Button("Start Interview", variant="primary")
|
| 212 |
resetButton = gr.Button("Save and Exit", visible=False, variant="stop")
|
| 213 |
|
| 214 |
## INTERACTIONS
|
|
|
|
| 215 |
startInterview.click(
|
| 216 |
load_config,
|
| 217 |
inputs=None,
|
|
@@ -221,7 +238,7 @@ with gr.Blocks() as demo:
|
|
| 221 |
initialMessage,
|
| 222 |
modelArgs,
|
| 223 |
userid,
|
| 224 |
-
]
|
| 225 |
).then(
|
| 226 |
update_template,
|
| 227 |
inputs=[surveyQuestion, surveyTemplate],
|
|
@@ -236,25 +253,43 @@ with gr.Blocks() as demo:
|
|
| 236 |
outputs=[
|
| 237 |
chatDisplay,
|
| 238 |
chatInput,
|
|
|
|
| 239 |
startInterview,
|
| 240 |
resetButton,
|
| 241 |
],
|
| 242 |
-
).then(
|
|
|
|
|
|
|
| 243 |
|
|
|
|
| 244 |
chatInput.submit(
|
| 245 |
user_message,
|
| 246 |
inputs=[chatInput, chatDisplay],
|
| 247 |
outputs=[chatInput, chatDisplay],
|
| 248 |
-
queue=False
|
| 249 |
).then(
|
| 250 |
bot_message,
|
| 251 |
inputs=[chatDisplay, systemMessage, modelArgs],
|
| 252 |
-
outputs=[chatDisplay]
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
).then(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
reset_interview,
|
| 259 |
outputs=[chatDisplay, startInterview, resetButton],
|
| 260 |
show_progress=False,
|
|
@@ -262,5 +297,4 @@ with gr.Blocks() as demo:
|
|
| 262 |
|
| 263 |
|
| 264 |
if __name__ == "__main__":
|
| 265 |
-
# Testing
|
| 266 |
demo.launch()
|
|
|
|
| 6 |
|
| 7 |
Respondent-facing app. Reads arguments from request (in form of shareable link)
|
| 8 |
|
| 9 |
+
Change Log:
|
| 10 |
+
|
| 11 |
+
- 2024.01.16: Continuous logging to wandb, change name of run to `userid`
|
| 12 |
|
| 13 |
"""
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
+
import os
|
| 17 |
import logging
|
| 18 |
import json
|
| 19 |
import wandb
|
|
|
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
|
| 27 |
+
from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key
|
| 28 |
|
| 29 |
|
| 30 |
# %% Initialization
|
| 31 |
+
if os.environ.get(f"OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
|
| 32 |
+
seed_openai_key()
|
| 33 |
client = openai.OpenAI()
|
| 34 |
|
| 35 |
|
|
|
|
| 43 |
|
| 44 |
def load_config(request: gr.Request):
    """
    Read interview parameters from the request's URL query string.

    The shareable link carries a base64-encoded config blob in the "dta"
    query parameter; `decode_config` turns it back into a dict.
    (Previous docstring said "request header", which was inaccurate —
    the values come from `request.query_params`.)

    Returns a 5-tuple matching the Gradio `.click()` outputs:
        (survey_question, survey_template, initial_message, model_args, userid)

    Raises:
        KeyError: if the link is missing "dta" or any expected config key.
    """
    config = decode_config(request.query_params["dta"])
    survey_question = config["question"]
    survey_template = config["template"]
    initial_message = config["initial_message"]
    # Only the OpenAI call arguments are bundled together; the rest are
    # survey text / identifiers consumed by separate components.
    model_args = {"model": config["model"], "temperature": config["temperature"]}
    userid = config["userid"]
    return survey_question, survey_template, initial_message, model_args, userid
|
| 53 |
|
| 54 |
+
|
| 55 |
# Post-loading
|
| 56 |
def update_template(question: str, template: PromptTemplate | str) -> str:
|
| 57 |
"""
|
|
|
|
| 92 |
gr.Textbox(
|
| 93 |
placeholder="Type response here.", interactive=True, show_label=False
|
| 94 |
),
|
| 95 |
+
gr.Button(variant="primary", interactive=True),
|
| 96 |
gr.Button("Start Interview", visible=False),
|
| 97 |
gr.Button("Save and Exit", visible=True, variant="stop"),
|
| 98 |
)
|
| 99 |
|
| 100 |
|
| 101 |
def initialize_tracker(
    model_args: dict[str, str | float],
    question: str,
    template: PromptTemplate,
    userid: str = "",
) -> None:
    """
    Initialize a wandb run for this interview.

    NOTE(fix): the signature previously read ``userid=str``, which made the
    default value the ``str`` *type* object instead of annotating the
    parameter. All known callers pass ``userid`` explicitly, so adding a
    typed parameter with an empty-string default is backward compatible.

    Args:
        model_args: OpenAI call arguments (e.g. model name, temperature).
        question: The survey question for this interview.
        template: Prompt template; stringified for logging.
        userid: Respondent identifier; also used as the wandb run name.
    """
    # Merge model args with interview metadata into one flat run config.
    run_config = model_args | {
        "question": question,
        "template": str(template),
        "userid": userid,
    }
    wandb.init(
        project="cognitive-debrief", name=userid, config=run_config, tags=["dev"]
    )
|
| 116 |
|
| 117 |
|
| 118 |
def save_interview(
|
|
|
|
| 124 |
if pair[i] is not None:
|
| 125 |
chat_data += [[role, pair[i]]]
|
| 126 |
chat_table = wandb.Table(data=chat_data, columns=["role", "message"])
|
| 127 |
+
logger.info("Uploading interview transcript to WandB...")
|
| 128 |
wandb.log({"chat_history": chat_table})
|
| 129 |
+
logger.info("Uploading complete.")
|
| 130 |
|
| 131 |
|
| 132 |
def call_openai(
|
|
|
|
| 205 |
initialMessage = gr.Textbox(visible=False)
|
| 206 |
systemMessage = gr.Textbox(visible=False)
|
| 207 |
modelArgs = gr.State(value={"model": "", "temperature": ""})
|
| 208 |
+
userid = gr.Textbox(visible=False, interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
## RESPONDENT
|
| 211 |
chatDisplay = gr.Chatbot(
|
| 212 |
show_label=False,
|
| 213 |
)
|
| 214 |
+
with gr.Row():
|
| 215 |
+
chatInput = gr.Textbox(
|
| 216 |
+
placeholder="Click 'Start Interview' to begin.",
|
| 217 |
+
interactive=False,
|
| 218 |
+
show_label=False,
|
| 219 |
+
scale=10,
|
| 220 |
+
)
|
| 221 |
+
chatSubmit = gr.Button(
|
| 222 |
+
"",
|
| 223 |
+
variant="secondary",
|
| 224 |
+
interactive=False,
|
| 225 |
+
icon="./arrow_icon.svg",
|
| 226 |
+
)
|
| 227 |
startInterview = gr.Button("Start Interview", variant="primary")
|
| 228 |
resetButton = gr.Button("Save and Exit", visible=False, variant="stop")
|
| 229 |
|
| 230 |
## INTERACTIONS
|
| 231 |
+
# Start Interview button
|
| 232 |
startInterview.click(
|
| 233 |
load_config,
|
| 234 |
inputs=None,
|
|
|
|
| 238 |
initialMessage,
|
| 239 |
modelArgs,
|
| 240 |
userid,
|
| 241 |
+
],
|
| 242 |
).then(
|
| 243 |
update_template,
|
| 244 |
inputs=[surveyQuestion, surveyTemplate],
|
|
|
|
| 253 |
outputs=[
|
| 254 |
chatDisplay,
|
| 255 |
chatInput,
|
| 256 |
+
chatSubmit,
|
| 257 |
startInterview,
|
| 258 |
resetButton,
|
| 259 |
],
|
| 260 |
+
).then(
|
| 261 |
+
initialize_tracker, inputs=[modelArgs, surveyQuestion, surveyTemplate, userid]
|
| 262 |
+
)
|
| 263 |
|
| 264 |
+
# "Enter" on textbox
|
| 265 |
chatInput.submit(
|
| 266 |
user_message,
|
| 267 |
inputs=[chatInput, chatDisplay],
|
| 268 |
outputs=[chatInput, chatDisplay],
|
| 269 |
+
queue=False,
|
| 270 |
).then(
|
| 271 |
bot_message,
|
| 272 |
inputs=[chatDisplay, systemMessage, modelArgs],
|
| 273 |
+
outputs=[chatDisplay],
|
| 274 |
+
).then(
|
| 275 |
+
save_interview, inputs=[chatDisplay]
|
| 276 |
+
)
|
| 277 |
|
| 278 |
+
# "Submit" button
|
| 279 |
+
chatSubmit.click(
|
| 280 |
+
user_message,
|
| 281 |
+
inputs=[chatInput, chatDisplay],
|
| 282 |
+
outputs=[chatInput, chatDisplay],
|
| 283 |
+
queue=False,
|
| 284 |
+
).then(
|
| 285 |
+
bot_message,
|
| 286 |
+
inputs=[chatDisplay, systemMessage, modelArgs],
|
| 287 |
+
outputs=[chatDisplay],
|
| 288 |
).then(
|
| 289 |
+
save_interview, inputs=[chatDisplay]
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
resetButton.click(save_interview, [chatDisplay]).then(
|
| 293 |
reset_interview,
|
| 294 |
outputs=[chatDisplay, startInterview, resetButton],
|
| 295 |
show_progress=False,
|
|
|
|
| 297 |
|
| 298 |
|
| 299 |
if __name__ == "__main__":
|
|
|
|
| 300 |
demo.launch()
|
arrow_icon.svg
ADDED
|
|
utils.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
from string import Formatter
|
| 4 |
|
| 5 |
|
|
@@ -88,3 +91,17 @@ def convert_openai_to_gradio(
|
|
| 88 |
for i in range(0, len(messages), 2):
|
| 89 |
chat_history.append([messages[i]["content"], messages[i + 1]["content"]])
|
| 90 |
return chat_history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import os
|
| 4 |
+
from configparser import ConfigParser
|
| 5 |
+
from pathlib import Path
|
| 6 |
from string import Formatter
|
| 7 |
|
| 8 |
|
|
|
|
| 91 |
for i in range(0, len(messages), 2):
|
| 92 |
chat_history.append([messages[i]["content"], messages[i + 1]["content"]])
|
| 93 |
return chat_history
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def seed_openai_key(cfg: str = "~/.cfg/openai.cfg") -> None:
    """
    Read the OpenAI API key from a config file and export it to the environment.

    The file is an INI-style config with an ``[API_KEY]`` section containing a
    ``secret`` option. Default location is ``~/.cfg/openai.cfg``.

    Args:
        cfg: Path to the config file; ``~`` is expanded.

    Raises:
        ValueError: if the file cannot be read or lacks ``[API_KEY] secret``.
    """
    config = ConfigParser()
    # ConfigParser.read() does NOT raise for a missing/unreadable file — it
    # returns the list of files it managed to parse. The original bare
    # try/except here could therefore never fire; check the return instead.
    read_ok = config.read(Path(cfg).expanduser())
    if not read_ok:
        raise ValueError(f"Could not read config file at: {cfg}.")
    try:
        os.environ["OPENAI_API_KEY"] = config["API_KEY"]["secret"]
    except KeyError as err:
        # Surface a clear error instead of a raw KeyError from the parser.
        raise ValueError(
            f"Config file {cfg} is missing an [API_KEY] section with 'secret'."
        ) from err
|