drmjh committed on
Commit
55eff01
·
1 Parent(s): beb7e6c

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +253 -0
  2. requirements.txt +3 -0
  3. utils.py +90 -0
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cognitive Debriefing App - Respondent Interface
3
+
4
+ Author: Dr Musashi Hinck
5
+
6
+
7
+ Respondent-facing app. Reads arguments from request (in form of shareable link)
8
+
9
+ TODO:
10
+ - uuid from request for wandb
11
+
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ import json
17
+ import wandb
18
+ import gradio as gr
19
+ import openai
20
+
21
+ from base64 import urlsafe_b64decode
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ from utils import PromptTemplate, convert_gradio_to_openai
26
+
27
+
28
+ # %% Initialization
29
+ client = openai.OpenAI()
30
+
31
+
32
+ # %% (functions)
33
+ def decode_config(config_dta: str) -> dict[str, str | float]:
34
+ "Read base64_url encoded json and loads into configuration"
35
+ config_str: str = urlsafe_b64decode(config_dta)
36
+ config: dict = json.loads(config_str)
37
+ return config
38
+
39
+
40
def load_config(request: gr.Request):
    """Extract the survey configuration from the request's query string.

    Returns (question, template, initial_message, model_args) for the
    hidden state components in the layout.
    """
    cfg = decode_config(request.query_params['dta'])
    model_args = {'model': cfg['model'], 'temperature': cfg['temperature']}
    return cfg['question'], cfg['template'], cfg['initial_message'], model_args
48
+
49
+ # Post-loading
50
def update_system_message(question: str, template: PromptTemplate | str) -> str:
    """Render the system prompt by substituting the survey question into
    the template. Triggered when questionBox or templateBox changes."""
    rendered = template.format(question=question)
    return rendered
55
+
56
+
57
def reset_interview() -> tuple[list[list[str | None]], gr.Button, gr.Button]:
    """Clear the chat and restore the pre-interview button state.

    Finishes the active wandb run, then returns exactly the three values
    wired to (chatDisplay, startInterview, resetButton) in the layout.
    """
    wandb.finish()
    gr.Info("Interview reset.")
    # BUG FIX: previously returned five values (including "Reply" and
    # "Save Survey" buttons no handler wires up) while both the return
    # annotation and the .then() outputs list declare three.
    return (
        [],
        gr.Button("Start Interview", visible=True),
        gr.Button("Reset Survey", visible=False, variant="stop"),
    )
67
+
68
+
69
def initialize_interview(
    system_message: str, first_question: str, model_args: dict[str, str | float]
) -> tuple[list[list[str | None]], gr.Textbox, gr.Button, gr.Button]:
    """Seed the chat with the opening question and switch UI into interview mode."""
    # When no fixed opening question is configured, ask the model for one.
    if not first_question:
        first_question = call_openai(
            [], system_message, client, model_args, stream=False
        )
    history = [[None, first_question]]
    reply_box = gr.Textbox(
        placeholder="Type response here.", interactive=True, show_label=False
    )
    start_btn = gr.Button("Start Interview", visible=False)
    reset_btn = gr.Button("Reset Survey", visible=True, variant="stop")
    return history, reply_box, start_btn, reset_btn
87
+
88
+
89
def initialize_tracker(
    model_args: dict[str, str | float], question: str, template: PromptTemplate
):
    """Start a wandb run recording the model settings and survey prompt."""
    run_config = dict(model_args)
    run_config["question"] = question
    run_config["template"] = str(template)
    wandb.init(project="cognitive-debrief", config=run_config, tags=["dev"])
95
+
96
+
97
def save_interview(
    chat_history: list[list[str | None]],
) -> None:
    """Flatten the gradio chat history and log it to wandb as a table.

    :param chat_history: gradio-style ``[[user_msg, bot_msg], ...]`` pairs;
        either side of a pair may be None and is then skipped.
    """
    chat_data = []
    for pair in chat_history:
        # BUG FIX: bot replies were previously labelled "system"; they are
        # assistant messages (consistent with convert_gradio_to_openai).
        for i, role in enumerate(["user", "assistant"]):
            if pair[i] is not None:
                chat_data.append([role, pair[i]])
    chat_table = wandb.Table(data=chat_data, columns=["role", "message"])
    gr.Info("Uploading interview transcript to WandB...")
    wandb.log({"chat_history": chat_table})
108
+
109
+
110
def call_openai(
    messages: list[dict[str, str]],
    system_message: str | None,
    client: openai.Client,
    model_args: dict,
    stream: bool = False,
):
    """Utility function for calling OpenAI chat. Expects formatted messages.

    :param messages: pre-formatted role/content dicts (may be empty/None).
    :param system_message: optional system prompt prepended to *messages*.
    :param client: OpenAI client used for the request.
    :param model_args: extra keyword args for the API (model, temperature, ...).
    :param stream: when True, return a generator of content deltas;
        otherwise return the full completion text.
    :return: str when stream=False, generator of str when stream=True,
        or None when an API error was handled.
    """
    if not messages:
        messages = []
    if system_message:
        messages = [{"role": "system", "content": system_message}] + messages
    try:
        response = client.chat.completions.create(
            messages=messages, **model_args, stream=stream
        )
        if stream:
            # BUG FIX: the old body used `yield` here, turning the whole
            # function into a generator — the stream=False caller then got
            # a generator object instead of a string, and its `return`
            # never reached the caller. Streaming is now an inner
            # generator expression. Streamed chunks also carry the text
            # under `.delta.content`, not `.message.content`.
            return (chunk.choices[0].delta.content for chunk in response)
        return response.choices[0].message.content
    except (openai.APIConnectionError, openai.APIStatusError) as e:
        # BUG FIX: `except A | B` raises TypeError when the handler fires;
        # exception alternatives must be given as a tuple.
        error_msg = (
            "API unreachable.\n" f"STATUS_CODE: {e.status_code}" f"ERROR: {e.response}"
        )
        gr.Error(error_msg)
        logger.error(error_msg)
    except openai.RateLimitError:
        warning_msg = "Hit rate limit. Wait a moment and retry."
        gr.Warning(warning_msg)
        logger.warning(warning_msg)
142
+
143
+
144
def user_message(
    message: str, chat_history: list[list[str | None]]
) -> tuple[str, list[list[str | None]]]:
    """Echo the user's message into the chat immediately; the bot slot of
    the new pair stays None until bot_message streams the reply in.
    Also clears the input textbox (first return value)."""
    updated_history = chat_history + [[message, None]]
    return "", updated_history
149
+
150
+
151
def bot_message(
    chat_history: list[list[str | None]],
    system_message: str,
    model_args: dict[str, str | float],
) -> list[list[str | None]]:
    """Stream the assistant's reply into the last chat-history slot,
    yielding the updated history after each chunk."""
    # Rebuild the OpenAI-format conversation: system prompt, prior turns,
    # then the just-submitted user message (last pair, bot slot empty).
    latest_user_msg = chat_history[-1][0]
    request_messages = [{"role": "system", "content": system_message}]
    request_messages.extend(convert_gradio_to_openai(chat_history[:-1]))
    request_messages.append({"role": "user", "content": latest_user_msg})
    stream = client.chat.completions.create(
        messages=request_messages, stream=True, **model_args
    )
    # Accumulate streamed deltas into the bot slot of the last pair.
    chat_history[-1][1] = ""
    for part in stream:
        piece = part.choices[0].delta.content
        if piece:
            chat_history[-1][1] += piece
        yield chat_history
174
+
175
+
176
+ # LAYOUT
177
with gr.Blocks() as demo:
    gr.Markdown("# Cognitive Debriefing Prototype")

    # Hidden values — populated from the shareable link's query string
    # by load_config when the interview starts.
    surveyQuestion = gr.Textbox(visible=False)
    surveyTemplate = gr.Textbox(visible=False)
    initialMessage = gr.Textbox(visible=False)
    systemMessage = gr.Textbox(visible=False)
    modelArgs = gr.State(value={"model": "", "temperature": ""})

    # Debugging
    with gr.Accordion("Debugging Panel", open=False):
        debugPane = gr.Textbox(show_label=False, lines=8)
        debugRequest = gr.Button('Read Request')
        # NOTE(review): load_config returns four values but only one
        # output (debugPane) is wired here — confirm this is intended.
        debugRequest.click(load_config, outputs=[debugPane])


    ## RESPONDENT
    chatDisplay = gr.Chatbot(
        show_label=False,
    )
    # Input stays disabled until the interview is initialized.
    chatInput = gr.Textbox(
        placeholder="Click 'Start Interview' to begin.",
        interactive=False,
        show_label=False,
    )
    startInterview = gr.Button("Start Interview", variant="primary")
    resetButton = gr.Button("Reset Survey", visible=False, variant="stop")

    ## INTERACTIONS
    # Start flow: read config from the request, render the system prompt,
    # seed the chat, then start the wandb tracker.
    startInterview.click(
        load_config,
        inputs=None,
        outputs=[
            surveyQuestion,
            surveyTemplate,
            initialMessage,
            modelArgs,
        ]
    ).then(
        update_system_message,
        inputs=[surveyQuestion, surveyTemplate],
        outputs=[systemMessage],
    ).then(
        initialize_interview,
        inputs=[systemMessage, initialMessage, modelArgs],
        outputs=[
            chatDisplay,
            chatInput,
            startInterview,
            resetButton,
        ],
    ).then(initialize_tracker, inputs=[modelArgs, surveyQuestion, surveyTemplate])

    # Chat flow: echo the user message immediately (queue=False), then
    # stream the assistant reply into the display.
    chatInput.submit(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay])

    # Reset flow: persist the transcript first, then clear the UI.
    # NOTE(review): reset_interview returns five components while only
    # three outputs are declared here — verify the counts match.
    resetButton.click(
        save_interview,
        [chatDisplay]
    ).then(
        reset_interview,
        outputs=[chatDisplay, startInterview, resetButton],
        show_progress=False,
    )
249
+
250
+
251
if __name__ == "__main__":
    # Testing — launch the gradio app locally when run as a script.
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ openai
3
+ wandb
utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from string import Formatter
4
+
5
+
6
+ # General Class
7
class PromptTemplate(str):
    """More robust String Formatter. Takes a string and parses out the keywords.

    Subclasses ``str`` so instances can be used anywhere a plain string is
    expected, while ``format`` adds explicit validation of the supplied
    arguments against the template's placeholders.
    """

    def __init__(self, template) -> None:
        # Raw format string, e.g. "Ask about {question}".
        self.template: str = template
        # Placeholder names parsed from the template, in appearance order.
        self.variables: list[str] = self.parse_template()

    def parse_template(self) -> list[str]:
        "Returns template variables"
        return [
            fn for _, fn, _, _ in Formatter().parse(self.template) if fn is not None
        ]

    def format(self, *args, **kwargs) -> str:
        """
        Formats the template string with the given arguments.
        Provides slightly more informative error handling.

        :param args: Positional arguments for unnamed placeholders.
        :param kwargs: Keyword arguments for named placeholders.
        :return: Formatted string.
        :raises: ValueError if arguments do not match template variables.
        """
        # Keyword arguments must match the template variables exactly.
        if kwargs and set(kwargs) != set(self.variables):
            raise ValueError("Keyword arguments do not match template variables.")

        # BUG FIX: the single-dict special case is now checked *before* the
        # positional-count check, so passing one dict for a multi-variable
        # template reports the accurate key-mismatch error instead of a
        # misleading positional-count error.
        if len(args) == 1 and isinstance(args[0], dict):
            arg_dict = args[0]
            if set(arg_dict) != set(self.variables):
                raise ValueError("Dictionary keys do not match template variables.")
            return self.template.format(**arg_dict)

        # Positional argument count must match the number of placeholders.
        if args and len(args) != len(self.variables):
            raise ValueError(
                "Number of arguments does not match the number of template variables."
            )

        # No arguments at all is only valid when the template is constant.
        if not args and not kwargs and self.variables:
            raise ValueError("No arguments provided, but template expects variables.")

        try:
            return self.template.format(*args, **kwargs)
        except KeyError as e:
            # Chain the cause for a clearer, debuggable message.
            raise ValueError(f"Missing a keyword argument: {e}") from e

    @classmethod
    def from_file(cls, file_path: str) -> PromptTemplate:
        "Alternate constructor: read the template text from a file."
        with open(file_path, encoding="utf-8") as file:
            template_content = file.read()
        return cls(template_content)

    def dump_prompt(self, file_path: str) -> None:
        "Write the raw template text to a file."
        # BUG FIX: removed the redundant file.close() — the `with` block
        # already closes the file on exit.
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(self.template)
67
+
68
+
69
def convert_gradio_to_openai(
    chat_history: list[list[str | None]],
) -> list[dict[str, str]]:
    """Convert gradio chat pairs into OpenAI chat-request messages.

    Each pair is [user, assistant]; empty or None slots are skipped.
    """
    roles = ("user", "assistant")
    return [
        {"role": roles[slot], "content": pair[slot]}
        for pair in chat_history
        for slot in (0, 1)
        if pair[slot] not in (None, "")
    ]
79
+
80
+
81
def convert_openai_to_gradio(
    messages: list[dict[str, str]]
) -> list[list[str | None]]:
    """Convert OpenAI chat-request messages into gradio chat pairs.

    :param messages: role/content dicts, alternating user/assistant.
    :return: gradio-style ``[[user_msg, assistant_msg], ...]`` pairs.
    """
    # Work on a copy — the original inserted the padding message into the
    # caller's list in place.
    msgs = list(messages)
    # Pad a missing leading user turn so pairs stay [user, assistant].
    if msgs and msgs[0]["role"] != "user":
        msgs.insert(0, {"role": "user", "content": None})
    # BUG FIX: pad a trailing lone user message so the pairwise indexing
    # below cannot raise IndexError on odd-length input.
    if len(msgs) % 2:
        msgs.append({"role": "assistant", "content": None})
    chat_history = []
    for i in range(0, len(msgs), 2):
        chat_history.append([msgs[i]["content"], msgs[i + 1]["content"]])
    return chat_history