drmjh committed on
Commit
7549871
·
1 Parent(s): e03a0ea

Copied from CD Interface

Browse files
Files changed (4) hide show
  1. app.py +300 -0
  2. arrow_icon.svg +1 -0
  3. requirements.txt +3 -0
  4. utils.py +107 -0
app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cognitive Debriefing App - Respondent Interface
3
+
4
+ Author: Dr Musashi Hinck
5
+
6
+
7
+ Respondent-facing app. Reads arguments from request (in form of shareable link)
8
+
9
+ Change Log:
10
+
11
+ - 2024.01.16: Continuous logging to wandb, change name of run to `userid`
12
+
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ import logging
18
+ import json
19
+ import wandb
20
+ import gradio as gr
21
+ import openai
22
+
23
+ from base64 import urlsafe_b64decode
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key
28
+
29
+
30
# %% Initialization
# Seed the key from the config file only when the environment does not
# already provide one ("DEFAULT" is a sentinel, not a real key).
# Fix: the original used f"OPENAI_API_KEY" — an f-string with no
# placeholders — where a plain string literal was intended.
if os.environ.get("OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
    seed_openai_key()
client = openai.OpenAI()
34
+
35
+
36
# %% (functions)
def decode_config(config_dta: str) -> dict[str, str | float]:
    """
    Decode a base64url-encoded JSON payload into a configuration dict.

    :param config_dta: base64url string (the ``dta`` query parameter of a
        shareable link).
    :return: decoded configuration mapping.
    """
    # urlsafe_b64decode returns *bytes* (the original annotated it as str);
    # json.loads accepts bytes, but name/annotate it correctly.
    config_bytes: bytes = urlsafe_b64decode(config_dta)
    config: dict = json.loads(config_bytes)
    return config
42
+
43
+
44
def load_config(request: gr.Request):
    "Read parameters from request header"
    # Decode the shareable-link payload and fan it out to the hidden fields.
    config = decode_config(request.query_params["dta"])
    model_args = {"model": config["model"], "temperature": config["temperature"]}
    return (
        config["question"],
        config["template"],
        config["initial_message"],
        model_args,
        config["userid"],
    )
53
+
54
+
55
# Post-loading
def update_template(question: str, template: PromptTemplate | str) -> str:
    """
    Fill the survey question into a prompt template.

    Currently only a "question" variable is supported; more template
    variables can be added in the future. Templates without a "question"
    placeholder are returned unchanged.
    """
    if isinstance(template, str):
        # NOTE: PromptTemplate subclasses str, so both plain strings and
        # PromptTemplate instances take this branch; re-wrapping is harmless.
        template = PromptTemplate(template)
    needs_question = "question" in template.variables
    return template.format(question=question) if needs_question else str(template)
66
+
67
+
68
def reset_interview() -> tuple[list[list[str | None]], gr.Button, gr.Button]:
    """
    End the wandb run and restore the UI to its pre-interview state.

    :return: (cleared chat history, visible "Start Interview" button,
        hidden "Save and Exit" button) — exactly the three outputs wired
        to ``resetButton.click`` (chatDisplay, startInterview, resetButton).
    """
    wandb.finish()
    gr.Info("Interview reset.")
    # Fix: the original returned five values ("Reply" / "Save Survey"
    # buttons left over from another interface) while only three outputs
    # are wired up and the annotation declares three — gradio would raise
    # on the mismatch.
    return (
        [],
        gr.Button("Start Interview", visible=True),
        gr.Button("Save and Exit", visible=False, variant="stop"),
    )
78
+
79
+
80
def initialize_interview(
    system_message: str, first_question: str, model_args: dict[str, str | float]
) -> tuple[list[list[str | None]], gr.Textbox, gr.Button, gr.Button, gr.Button]:
    """
    Start the interview: seed the chat history and unlock the input widgets.

    :param system_message: fully-templated system prompt.
    :param first_question: fixed opening message; if empty, the model
        generates one from the system prompt.
    :param model_args: ``model`` / ``temperature`` forwarded to OpenAI.
    :return: (chat history, chat input textbox, submit button,
        hidden "Start Interview" button, visible "Save and Exit" button).
        (Fix: original annotation declared a 4-tuple but 5 values are
        returned and 5 outputs are wired up.)
    """
    if not first_question:
        # No fixed opener configured -> ask the model to produce one.
        first_question = call_openai(
            [], system_message, client, model_args, stream=False
        )
    # Gradio chat rows are [user, bot]; None user slot = bot speaks first.
    chat_history = [[None, first_question]]
    return (
        chat_history,
        gr.Textbox(
            placeholder="Type response here.", interactive=True, show_label=False
        ),
        gr.Button(variant="primary", interactive=True),
        gr.Button("Start Interview", visible=False),
        gr.Button("Save and Exit", visible=True, variant="stop"),
    )
99
+
100
+
101
def initialize_tracker(
    model_args: dict[str, str | float],
    question: str,
    template: PromptTemplate,
    userid: str = "",
) -> None:
    """
    Start a wandb run for this interview, named after the respondent.

    :param model_args: model name / temperature used for the interview.
    :param question: survey question under debrief.
    :param template: system prompt template (stringified into the config).
    :param userid: respondent identifier; becomes the wandb run name.
        (Fix: original signature had ``userid=str`` — the *type* object as
        a default value — instead of the annotation ``userid: str``.
        Callers pass it positionally, so this change is backward-compatible.)
    """
    run_config = model_args | {
        "question": question,
        "template": str(template),
        "userid": userid,
    }
    wandb.init(
        project="cognitive-debrief", name=userid, config=run_config, tags=["dev"]
    )
116
+
117
+
118
def save_interview(
    chat_history: list[list[str | None]],
) -> None:
    """Upload the current transcript to wandb as a two-column table."""
    # Flatten [user, bot] pairs into (role, message) rows, skipping
    # empty (None) slots.
    rows = [
        [role, text]
        for pair in chat_history
        for role, text in zip(["user", "bot"], pair)
        if text is not None
    ]
    transcript_table = wandb.Table(data=rows, columns=["role", "message"])
    logger.info("Uploading interview transcript to WandB...")
    wandb.log({"chat_history": transcript_table})
    logger.info("Uploading complete.")
130
+
131
+
132
def call_openai(
    messages: list[dict[str, str]],
    system_message: str | None,
    client: openai.Client,
    model_args: dict,
    stream: bool = False,
):
    """
    Utility function for calling OpenAI chat. Expects formatted messages.

    :param messages: OpenAI-format message dicts (may be empty).
    :param system_message: optional system prompt, prepended when given.
    :param client: OpenAI client instance.
    :param model_args: extra kwargs (model, temperature, ...) for the API.
    :param stream: when True, return a generator of content deltas;
        otherwise return the complete completion string.
    :return: completion text, a delta generator, or None on API error.
    """
    # Fix: the original body contained both `yield` and `return content`,
    # making the whole function a generator — the non-stream path returned
    # a generator object instead of the text (breaking initialize_interview).
    if not messages:
        messages = []
    if system_message:
        messages = [{"role": "system", "content": system_message}] + messages
    try:
        response = client.chat.completions.create(
            messages=messages, **model_args, stream=stream
        )
        if stream:
            # Fix: streaming chunks carry text in `delta`, not `message`.
            return (chunk.choices[0].delta.content for chunk in response)
        return response.choices[0].message.content
    # Fix: `except A | B` is invalid — exception unions are not catchable;
    # a tuple of exception types is required.
    except (openai.APIConnectionError, openai.APIStatusError) as e:
        error_msg = (
            "API unreachable.\n" f"STATUS_CODE: {e.status_code}" f"ERROR: {e.response}"
        )
        gr.Error(error_msg)
        logger.error(error_msg)
    except openai.RateLimitError:
        warning_msg = "Hit rate limit. Wait a moment and retry."
        gr.Warning(warning_msg)
        logger.warning(warning_msg)
164
+
165
+
166
def user_message(
    message: str, chat_history: list[list[str | None]]
) -> tuple[str, list[list[str | None]]]:
    "Displays user message immediately."
    # Append the user's turn with an empty bot slot (filled in later by
    # bot_message) and clear the textbox.
    updated = [*chat_history, [message, None]]
    return "", updated
171
+
172
+
173
def bot_message(
    chat_history: list[list[str | None]],
    system_message: str,
    model_args: dict[str, str | float],
) -> list[list[str | None]]:
    """
    Stream the assistant's reply into the last chat-history slot.

    Yields the progressively-updated history so gradio can re-render as
    tokens arrive. (NOTE(review): annotation says list but this is a
    generator — kept as-is for interface compatibility.)
    """
    # Rebuild the OpenAI-format conversation: system prompt, prior turns,
    # then the user's latest (still-unanswered) message.
    latest_user_msg = chat_history[-1][0]
    prior_turns = convert_gradio_to_openai(chat_history[:-1])
    messages = (
        [{"role": "system", "content": system_message}]
        + prior_turns
        + [{"role": "user", "content": latest_user_msg}]
    )
    response = client.chat.completions.create(
        messages=messages, stream=True, **model_args
    )
    # Accumulate deltas into the bot slot, yielding after each chunk.
    chat_history[-1][1] = ""
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta:
            chat_history[-1][1] += delta
        yield chat_history
196
+
197
+
198
# LAYOUT
with gr.Blocks() as demo:
    gr.Markdown("# Cognitive Debriefing Prototype")

    # Hidden values
    # Invisible components acting as per-session state; `load_config` fills
    # them from the shareable link's query string on "Start Interview".
    surveyQuestion = gr.Textbox(visible=False)
    surveyTemplate = gr.Textbox(visible=False)
    initialMessage = gr.Textbox(visible=False)
    systemMessage = gr.Textbox(visible=False)
    modelArgs = gr.State(value={"model": "", "temperature": ""})
    userid = gr.Textbox(visible=False, interactive=False)

    ## RESPONDENT
    chatDisplay = gr.Chatbot(
        show_label=False,
    )
    with gr.Row():
        # Input starts disabled; initialize_interview unlocks it.
        chatInput = gr.Textbox(
            placeholder="Click 'Start Interview' to begin.",
            interactive=False,
            show_label=False,
            scale=10,
        )
        chatSubmit = gr.Button(
            "",
            variant="secondary",
            interactive=False,
            icon="./arrow_icon.svg",
        )
    startInterview = gr.Button("Start Interview", variant="primary")
    resetButton = gr.Button("Save and Exit", visible=False, variant="stop")

    ## INTERACTIONS
    # Start Interview button: load link config -> render system prompt and
    # opening message from templates -> open the chat UI -> start tracker.
    startInterview.click(
        load_config,
        inputs=None,
        outputs=[
            surveyQuestion,
            surveyTemplate,
            initialMessage,
            modelArgs,
            userid,
        ],
    ).then(
        update_template,
        inputs=[surveyQuestion, surveyTemplate],
        outputs=[systemMessage],
    ).then(
        # Second pass fills the question into the initial message in place.
        update_template,
        inputs=[surveyQuestion, initialMessage],
        outputs=initialMessage,
    ).then(
        initialize_interview,
        inputs=[systemMessage, initialMessage, modelArgs],
        outputs=[
            chatDisplay,
            chatInput,
            chatSubmit,
            startInterview,
            resetButton,
        ],
    ).then(
        initialize_tracker, inputs=[modelArgs, surveyQuestion, surveyTemplate, userid]
    )

    # "Enter" on textbox: echo the user turn immediately, stream the bot
    # reply, then log the running transcript to wandb (continuous logging).
    chatInput.submit(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        save_interview, inputs=[chatDisplay]
    )

    # "Submit" button: same chain as pressing Enter.
    chatSubmit.click(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        save_interview, inputs=[chatDisplay]
    )

    # Save and Exit: persist the final transcript, then reset the UI.
    # NOTE(review): reset_interview returns five values but only three
    # outputs are wired here — confirm against reset_interview's body.
    resetButton.click(save_interview, [chatDisplay]).then(
        reset_interview,
        outputs=[chatDisplay, startInterview, resetButton],
        show_progress=False,
    )


if __name__ == "__main__":
    demo.launch()
arrow_icon.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ openai
3
+ wandb
utils.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from configparser import ConfigParser
5
+ from pathlib import Path
6
+ from string import Formatter
7
+
8
+
9
# General Class
class PromptTemplate(str):
    """More robust String Formatter. Takes a string and parses out the keywords."""

    def __init__(self, template) -> None:
        # str.__new__ already stored the text; keep an explicit copy plus
        # the parsed placeholder names for validation in `format`.
        self.template: str = template
        self.variables: list[str] = self.parse_template()

    def parse_template(self) -> list[str]:
        "Returns template variables"
        return [
            fn for _, fn, _, _ in Formatter().parse(self.template) if fn is not None
        ]

    def format(self, *args, **kwargs) -> str:
        """
        Formats the template string with the given arguments.
        Provides slightly more informative error handling.

        :param args: Positional arguments for unnamed placeholders, or a
            single dict mapping variable names to values.
        :param kwargs: Keyword arguments for named placeholders.
        :return: Formatted string.
        :raises ValueError: if arguments do not match template variables.
        """
        # A single dict positional argument supplies values by name.
        # Fix: this check must come FIRST — the positional-count test below
        # would otherwise reject a dict for any template with more than one
        # variable (len(args) == 1 != len(self.variables)).
        if len(args) == 1 and isinstance(args[0], dict):
            arg_dict = args[0]
            if set(arg_dict) != set(self.variables):
                raise ValueError("Dictionary keys do not match template variables.")
            return self.template.format(**arg_dict)

        # Keyword arguments must cover the template variables exactly.
        if kwargs and set(kwargs) != set(self.variables):
            raise ValueError("Keyword arguments do not match template variables.")

        # Positional argument count must match the number of variables.
        if args and len(args) != len(self.variables):
            raise ValueError(
                "Number of arguments does not match the number of template variables."
            )

        # No arguments at all is only valid for a variable-free template.
        if not args and not kwargs and self.variables:
            raise ValueError("No arguments provided, but template expects variables.")

        try:
            return self.template.format(*args, **kwargs)
        except KeyError as e:
            raise ValueError(f"Missing a keyword argument: {e}")

    @classmethod
    def from_file(cls, file_path: str) -> PromptTemplate:
        "Alternate constructor: read the template text from a file."
        with open(file_path, encoding="utf-8") as file:
            template_content = file.read()
        return cls(template_content)

    def dump_prompt(self, file_path: str) -> None:
        "Write the template text to a file."
        # Fix: removed redundant file.close() — `with` already closes it.
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(self.template)
69
+ file.close()
70
+
71
+
72
def convert_gradio_to_openai(
    chat_history: list[list[str | None]],
) -> list[dict[str, str]]:
    "Converts gradio chat format -> openai chat request format"
    # Each history entry is a [user, assistant] pair; drop None/empty turns.
    return [
        {"role": role, "content": text}
        for pair in chat_history
        for role, text in zip(["user", "assistant"], pair)
        if text is not None and text != ""
    ]
82
+
83
+
84
def convert_openai_to_gradio(
    messages: list[dict[str, str]],
) -> list[list[str | None]]:
    """
    Converts openai chat request format -> gradio chat format.

    Messages are paired into [user, assistant] rows; a leading assistant
    message gets a None user slot, and a trailing unanswered user message
    gets a None assistant slot. The input list is not modified.

    :param messages: OpenAI-format message dicts (roles "user"/"assistant").
    :return: gradio chat history.
    """
    # Fixes over the original: IndexError on an empty list and on odd
    # message counts (messages[i + 1]); mutation of the caller's list via
    # insert(); malformed annotation list[list[str, str | None]].
    if not messages:
        return []
    msgs = list(messages)  # copy so the caller's list is untouched
    if msgs[0]["role"] != "user":
        msgs.insert(0, {"role": "user", "content": None})
    if len(msgs) % 2 == 1:
        # Odd count: last user turn has no reply yet.
        msgs.append({"role": "assistant", "content": None})
    return [
        [msgs[i]["content"], msgs[i + 1]["content"]]
        for i in range(0, len(msgs), 2)
    ]
94
+
95
+
96
def seed_openai_key(cfg: str = "~/.cfg/openai.cfg") -> None:
    """
    Reads OpenAI key from config file and adds it to environment.
    Assumed config location is "~/.cfg/openai.cfg".

    :param cfg: path to an INI file with an [API_KEY] section holding `secret`.
    :raises ValueError: when the file cannot be read or lacks the key entry.
    """
    config = ConfigParser()
    # Fix: ConfigParser.read does NOT raise on a missing/unreadable file —
    # it returns the list of files successfully parsed — so the original
    # bare `except:` around it could never fire and the real failure
    # surfaced later as a KeyError. Check the return value instead.
    parsed = config.read(Path(cfg).expanduser())
    if not parsed:
        raise ValueError(f"Could not read file at: {cfg}.")
    try:
        os.environ["OPENAI_API_KEY"] = config["API_KEY"]["secret"]
    except KeyError as e:
        raise ValueError(f"Missing [API_KEY] 'secret' entry in: {cfg}.") from e