drmjh commited on
Commit
617da91
·
1 Parent(s): 8c50ce7

initial commit for testing

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ .venv
3
+ wandb
app.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatbot App for Cognitive Debriefing Interview
3
+
4
+ Author: Dr Musashi Hinck
5
+
6
+ Version Log:
7
+ - 02.04.24: Initial demo with passed values from Qualtrics survey
8
+ - 07.04.24: Added configurations for survey edition
9
+
10
+ Notes:
11
+ - Need to call Request from start state
12
+ - Example URL: localhost:7860/?user=123&session=456&questionid=0&response=0
13
+
14
+
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import json
21
+ import logging
22
+ import gradio as gr
23
+ from uuid import uuid4
24
+ from typing import Generator, Any
25
+
26
+ from pathlib import Path
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+ from utils import (
31
+ PromptTemplate,
32
+ convert_gradio_to_openai,
33
+ initialize_client,
34
+ seed_azure_key,
35
+ upload_azure,
36
+ record_chat,
37
+ )
38
+
39
+
40
# %% Initialize common assets (module level — shared by every Gradio session)
if os.environ.get("AZURE_ENDPOINT") is None:  # Set Azure credentials from local files
    seed_azure_key()
client = initialize_client()  # Shared across sessions
# Maps the `questionid` query param to a question JSON filename under assets/questions/
question_mapping: dict[str, str] = json.loads(Path("assets/question_mapping.json").read_text())
45
+
46
+
47
+ # %% (functions)
48
+
49
+ # Initialization
50
+ # - Record user and session id
51
+ # - Record question and response
52
+ # - Build system message
53
+ # - Build initial message
54
+ # - Wrapper - start_survey
55
+
56
+
57
def initialize_interview(request: gr.Request) -> tuple[str, str, str, str, str]:
    """
    Parse the incoming request's query parameters and build per-session state.

    Read: gr.Request carrying ?user=&session=&questionid=&response= params
    (e.g. localhost:7860/?user=123&session=456&questionid=0&response=0).
    Returns: (user_id, session_id, question_wording, initial_message, system_message).
    Raises: KeyError if `questionid` is not in question_mapping; ValueError /
    IndexError if `response` is not a valid index into the question's choices.
    """
    # Parse identity params; defaults allow running outside Qualtrics for testing
    request_params = request.query_params
    user_id: str = request_params.get("user", "testUser")
    session_id: str = request_params.get("session", "testSession")
    logger.info(f"User: {user_id} (Session: {session_id})")

    # Resolve the question wording and the respondent's chosen answer text
    question_id: str = request_params.get("questionid", "0")
    response_id: str = request_params.get("response", "0")
    question_data: dict = json.loads(Path(f"./assets/questions/{question_mapping[question_id]}").read_text())
    question_wording: str = question_data["question"]
    # Bug fix: "choices" is a JSON array, not a string — annotation corrected.
    question_choices: list[str] = question_data["choices"]
    response_text: str = question_choices[int(response_id)]
    logger.info(f"Question: {question_wording} ({response_text})")

    # Fill the prompt templates with the question/response context
    initial_message: str = PromptTemplate.from_file("assets/initial_message.txt").format(surveyQuestion=question_wording)
    system_message: str = PromptTemplate.from_file("assets/system_message.txt").format(surveyQuestion=question_wording, responseVal=response_text)
    logger.info(f"Initial message: {initial_message}")
    logger.info(f"System message: {system_message}")

    # Return all per-session values (stored into gr.State components)
    return (
        user_id,
        session_id,
        question_wording,
        initial_message,
        system_message
    )
91
+
92
def initialize_interface(initial_message: str) -> tuple:
    """
    Switch the UI from its idle state into interactive chat mode.

    Read: initial_message (first bot turn used to seed the chat display).
    Returns replacement components, in order:
        instruction_text -> cleared
        chat_display     -> seeded with the initial bot message
        chat_input       -> visible + interactive, submit-hint placeholder
        chat_submit      -> visible + interactive
        start_button     -> hidden
    """
    cleared_instructions = gr.Markdown("")
    seeded_display = gr.Chatbot(
        value=[[None, initial_message]],
        elem_id="chatDisplay",
        show_label=False,
        visible=True,
    )
    live_input = gr.Textbox(
        placeholder="Type response here. Hit `Enter` or click the arrow to submit.",
        visible=True,
        interactive=True,
        show_label=False,
        scale=10,
    )
    live_submit = gr.Button(
        "",
        variant="primary",
        interactive=True,
        icon="./arrow_icon.svg",
        visible=True,
    )
    hidden_start = gr.Button("Start Interview", visible=False, variant="primary")
    return (cleared_instructions, seeded_display, live_input, live_submit, hidden_start)
126
+
127
+
128
+ # Interaction
129
+ # - User message
130
+ # - Bot message
131
+ # - Check if interview finished
132
+ # - Record interaction (local log)
133
+
134
+
135
+ def user_message(
136
+ message: str, chat_history: list[list[str | None]]
137
+ ) -> tuple[str, list[list[str | None]]]:
138
+ "Display user message immediately"
139
+ return "", chat_history + [[message, None]]
140
+
141
+
142
def bot_message(
    chat_history: list[list[str | None]],
    system_message: str,
    model_args: dict | None = None,
) -> Generator[Any, Any, Any]:
    """Streams response from OpenAI API to chat interface.

    :param chat_history: gradio-format history; last entry is the pending user turn.
    :param system_message: system prompt prepended to the request.
    :param model_args: extra kwargs for chat.completions.create. Bug fix: the
        original used a mutable dict literal as the default argument (shared
        across calls); default to None and build it per call instead.
    :yields: chat_history with the assistant slot progressively filled.
    """
    if model_args is None:
        model_args = {"model": "gpt-4o-default", "temperature": 0.0}
    # Prep messages: history minus the pending turn, wrapped with system + user
    user_msg = chat_history[-1][0]
    messages = convert_gradio_to_openai(chat_history[:-1])
    messages = (
        [{"role": "system", "content": system_message}]
        + messages
        + [{"role": "user", "content": user_msg}]
    )
    # API call (uses the module-level shared client)
    response = client.chat.completions.create(
        messages=messages, stream=True, **model_args
    )
    # Accumulate streamed deltas into the assistant slot, yielding each update
    chat_history[-1][1] = ""
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta:
            chat_history[-1][1] += delta
        yield chat_history
167
+
168
+
169
def log_interaction(
    chat_history: list[list[str | None]],
    session_id: str,
) -> None:
    """Record the most recent user/bot message pair to the local chat log.

    NOTE(review): utils.record_chat is declared as
    record_chat(logger, session, role, record) — these 3-argument calls will
    raise TypeError at runtime unless record_chat's signature is adjusted to
    match; verify against utils.py.
    """
    record_chat(session_id, "user", chat_history[-1][0])
    record_chat(session_id, "bot", chat_history[-1][1])
176
+
177
+
178
def interview_end_check(
    chat_history: list[list[str | None]],
    limit: int = 20,
    end_of_interview: str = "<end_interview>",
) -> tuple[list[list[str | None]], gr.Button]:
    """
    Checks if the interview has completed using two conditions:
    1. The last bot message contains `end_of_interview` (default: "<end_interview>")
    2. Conversation length has reached `limit` (default: 20)

    If either condition is met, the "Save and Exit" button is displayed.
    (Doc fix: the docstring previously said the default limit was 10; it is 20.)
    The end-of-interview token is stripped from the displayed message.
    """
    flag = len(chat_history) >= limit
    # Robustness fix: the assistant slot can still be None (e.g. a failed
    # stream); the original did `end_of_interview in chat_history[-1][1]`
    # unconditionally, which raises TypeError on None.
    last_bot = chat_history[-1][1] if chat_history else None
    if last_bot and end_of_interview in last_bot:
        chat_history[-1][1] = last_bot.replace(end_of_interview, "")
        flag = True
    button = gr.Button("Save and Exit", visible=flag, variant="stop")
    return chat_history, button
198
+
199
+
200
+ # Completion
201
+ # - Create completion code
202
+ # - Append to message history
203
+ # - Display completion code
204
+
205
+
206
def generate_completion_code() -> str:
    """Return a fresh random completion code (a canonical 36-character UUID4)."""
    return f"{uuid4()}"
208
+
209
+
210
def upload_interview(
    session_id: str,
    chat_history: list[list[str | None]],
) -> None:
    """Upload the full chat history to Azure blob storage.

    Thin wrapper over utils.upload_azure; the blob is named after
    `session_id`. Network I/O only — no return value.
    """
    upload_azure(session_id, chat_history)
216
+
217
+
218
def end_interview(
    session_id: str,
    chat_history: list[list[str | None]],
) -> list[list[str | None]]:
    """Append a completion-code message to the chat and upload the transcript."""
    code = generate_completion_code()
    completion_message = (
        f"Thank you for participating.\n\n"
        f"Your completion code is: {code}\n\n"
        "Please now return to the Qualtrics survey "
        "and paste this code into the completion "
        "code box."
    )
    chat_history.append([None, completion_message])
    upload_interview(session_id, chat_history)
    return chat_history
233
+
234
+
235
# LAYOUT
with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
    # Header and instructions
    gr.Markdown("# SurveyGPT Interview")
    instructionText = gr.Markdown(
        "Use this chat interface to talk to SurveyGPT.\n"
        "To start, click 'Start Interview' and follow the instructions.\n\n"
        "You can type your answer into the box below and hit 'Enter' or click the arrow to submit.\n\n"
        "The interview will end either after 2 minutes, or if the chatbot decides the interview is done.\n"
        "At this point, you will see a 'Save and Exit' button. Click this to save your responses and receive a completion code."
    )
    # Initialize empty hidden values. gr.State is per-session, so these hold
    # the values produced by initialize_interview for each visitor.
    userId = gr.State()
    sessionId = gr.State()
    questionWording = gr.State()
    initialMessage = gr.State()
    systemMessage = gr.State()
    modelArgs = gr.State(value={"model": "gpt-4o-default", "temperature": 0.0})

    # Chat app (display, input, submit button). Input/submit start hidden and
    # non-interactive; initialize_interface swaps in live versions on start.
    startButton = gr.Button("Start Interview", visible=True, variant="primary")
    chatDisplay = gr.Chatbot(
        value=None,
        elem_id="chatDisplay",
        show_label=False,
        visible=True,
    )
    with gr.Row():  # Interaction
        chatInput = gr.Textbox(
            placeholder="Click 'Start Interview' to begin.",
            visible=False,
            interactive=False,
            show_label=False,
            scale=10,
        )
        chatSubmit = gr.Button(
            "",
            variant="primary",
            visible=False,
            interactive=False,
            icon="./arrow_icon.svg",
        )
    exitButton = gr.Button("Save and Exit", visible=False, variant="stop")
    # Footer
    disclaimer = gr.HTML(
        """
        <div
            style='font-size: 1em;
            font-style: italic;
            position: fixed;
            left: 50%;
            bottom: 20px;
            transform: translate(-50%, -50%);
            margin: 0 auto;
            '
        >{}</div>
        """.format(
            "Statements by the chatbot may contain factual inaccuracies."
        )
    )

    # INTERACTIONS
    # Initialization: parse query params into session state, then flip the UI
    startButton.click(
        initialize_interview,  # Reads in request params
        inputs=None,
        outputs=[
            userId,
            sessionId,
            questionWording,
            initialMessage,
            systemMessage,
        ],
    ).then(
        initialize_interface,  # Changes interface to interactive mode
        inputs=[initialMessage],
        outputs=[
            instructionText,
            chatDisplay,
            chatInput,
            chatSubmit,
            startButton,
        ],
    )
    # Chat interaction chain: echo user msg -> stream bot reply -> log -> end-check
    # "Enter" key path
    chatInput.submit(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        log_interaction,
        inputs=[chatDisplay, sessionId],
    ).then(
        interview_end_check, inputs=[chatDisplay], outputs=[chatDisplay, exitButton]
    )

    # Button path (identical chain to the Enter-key path above)
    chatSubmit.click(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        log_interaction,
        inputs=[chatDisplay, sessionId],
    ).then(
        interview_end_check, inputs=[chatDisplay], outputs=[chatDisplay, exitButton]
    )

    # Save-and-exit: append the completion code and upload the transcript
    exitButton.click(
        end_interview, inputs=[sessionId, chatDisplay], outputs=[chatDisplay]
    )
358
+
359
+
360
if __name__ == "__main__":
    # Launch without auth; pass auth=auth_no_user to re-enable password gating.
    demo.launch()
arrow_icon.svg ADDED
assets/arrow_icon.svg ADDED
assets/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "model_args" : {
3
+ "model": "gpt-4o-default",
4
+ "temperature": 0.0
5
+ }
6
+ }
assets/initial_message.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Hello! My name is SurveyGPT, a conversational AI designed to help improve survey research.
2
+
3
+ In the survey you were asked the following question:
4
+
5
+ {surveyQuestion}
6
+
7
+ What did you think was meant by the question?
assets/question_mapping.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "0": "partisanship.json",
3
+ "1": "ideology.json"
4
+ }
assets/questions/partisanship.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "question": "Generally speaking, do you think of yourself as a Republican, a Democrat, an independent, or something else?",
3
+ "choices": [
4
+ "Republican",
5
+ "Democrat",
6
+ "Independent",
7
+ "Something else"
8
+ ]
9
+ }
assets/system_message.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are an AI designed to help researchers validate questions before they field their surveys by pre-testing interviews with human respondents.
2
+ The person you are speaking with is a participant in a survey.
3
+
4
+ Your job is to conduct a Cognitive Debriefing interview with the respondent.
5
+
6
+ This interview consists of two parts.
7
+
8
+ - In the first part, ask "In the survey you were asked the following question:\n{surveyQuestion}\n\nWhat did you think was meant by the question?"
9
+ - In the second part, ask "You answered {responseVal} to the question. What did you mean by that?"
10
+
11
+ In both parts:
12
+
13
+ - After each answer, ask follow-up questions designed to expand and clarify responses.
14
+ - Move on to the next part once there is a satisfactory amount of information to conduct analyses of differences in how questions are understood.
15
+ - Do not lead the respondent to a particular answer or suggest answers to the respondent; your goal is to provide an informative transcript that the researcher can review afterwards to help revise their question wording so that new respondents will have the same understanding as intended by the researcher.
16
+
17
+ At the end of the survey, thank the respondent for participating and include the special <end_interview> token in the response to end the interview.
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ openai
3
+ wandb
4
+ azure-storage-blob
5
+ azure-identity
6
+ debugpy
utils.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import json
5
+ import os
6
+ from configparser import ConfigParser
7
+ from pathlib import Path
8
+ from string import Formatter
9
+
10
+ import openai
11
+ from azure.storage.blob import BlobClient
12
+
13
+
14
# Logging util
def get_current_timestamp() -> str:
    """Return the current local time as an ISO-8601 formatted string."""
    now = datetime.datetime.now()
    return now.isoformat()
17
+
18
+
19
class ChatLoggerHandler:
    """Shared logging handler for chat logs. Runs common to all Gradio sessions.

    Writes one JSON object per line (JSONL) to ``<logdir>/<session>.jsonl``.
    """

    def __init__(self, logdir: str = "./logs") -> None:
        # Directory is created lazily on first record(), so constructing the
        # handler has no filesystem side effects.
        self.logdir: Path = Path(logdir)

    def record(self, session: str, role: str, record: str) -> None:
        """Append one chat message to the session's JSONL log file.

        Bug fix: the original used Path.write_text, which truncates the file
        on every call — only the most recent message survived. Open in append
        mode instead, and ensure the log directory exists.
        """
        log_entry = {
            "metadata": {"session": session, "timestamp": get_current_timestamp()},
            "record": {"role": role, "message": record},
        }
        self.logdir.mkdir(parents=True, exist_ok=True)
        with open(self.logdir / f"{session}.jsonl", "a", encoding="utf-8") as fh:
            fh.write(json.dumps(log_entry) + "\n")
31
+
32
+
33
# NOTE(review): app.py calls record_chat(session_id, role, message) with three
# positional arguments, but this function was declared
# record_chat(logger, session, role, record), which raised TypeError at
# runtime. The logger is now an optional trailing argument backed by a shared
# module-level handler, matching the existing call sites.
_default_chat_logger = ChatLoggerHandler()


def record_chat(
    session: str, role: str, record: str, logger: ChatLoggerHandler | None = None
) -> None:
    """Record one chat message via `logger` (defaults to the shared handler)."""
    handler = logger if logger is not None else _default_chat_logger
    handler.record(session, role, record)
37
+
38
+
39
# General Class
class PromptTemplate(str):
    """More robust string formatter. Takes a template string and parses out the
    named placeholder variables so .format() can validate its arguments."""

    def __init__(self, template: str) -> None:
        self.template: str = template
        self.variables: list[str] = self.parse_template()

    def parse_template(self) -> list[str]:
        "Returns the placeholder variable names found in the template."
        return [
            fn for _, fn, _, _ in Formatter().parse(self.template) if fn is not None
        ]

    def format(self, *args, **kwargs) -> str:
        """
        Formats the template string with the given arguments.
        Provides slightly more informative error handling than str.format.

        :param args: Positional arguments for unnamed placeholders, or a single
            dict mapping variable names to values.
        :param kwargs: Keyword arguments for named placeholders.
        :return: Formatted string.
        :raises ValueError: if arguments do not match template variables.
        """
        # Handle the single-dict positional argument FIRST. (Bug fix: the
        # original checked len(args) against len(self.variables) before this
        # case, so a dict supplying a multi-variable template was rejected
        # with "Number of arguments does not match" before it was unpacked.)
        if len(args) == 1 and isinstance(args[0], dict):
            arg_dict = args[0]
            if set(arg_dict) != set(self.variables):
                raise ValueError("Dictionary keys do not match template variables.")
            return self.template.format(**arg_dict)

        # Keyword arguments must cover exactly the template variables
        if kwargs and set(kwargs) != set(self.variables):
            raise ValueError("Keyword arguments do not match template variables.")

        # Positional argument count must match the number of template variables
        if args and len(args) != len(self.variables):
            raise ValueError(
                "Number of arguments does not match the number of template variables."
            )

        # No arguments at all is only valid for a variable-free template
        if not args and not kwargs and self.variables:
            raise ValueError("No arguments provided, but template expects variables.")

        try:
            return self.template.format(*args, **kwargs)
        except KeyError as e:
            # Chain the original KeyError for easier debugging (fix: `from e`)
            raise ValueError(f"Missing a keyword argument: {e}") from e

    @classmethod
    def from_file(cls, file_path: str) -> "PromptTemplate":
        "Alternate constructor: load the template from a UTF-8 text file."
        with open(file_path, encoding="utf-8") as file:
            template_content = file.read()
        return cls(template_content)

    def dump_prompt(self, file_path: str) -> None:
        "Write the raw template text to a file."
        # (Fix: removed the redundant file.close() — the context manager
        # already closes the file.)
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(self.template)
100
+
101
+
102
+ def convert_gradio_to_openai(
103
+ chat_history: list[list[str | None]],
104
+ ) -> list[dict[str, str]]:
105
+ "Converts gradio chat format -> openai chat request format"
106
+ messages = []
107
+ for pair in chat_history: # [(user), (assistant)]
108
+ for i, role in enumerate(["user", "assistant"]):
109
+ if not ((pair[i] is None) or (pair[i] == "")):
110
+ messages += [{"role": role, "content": pair[i]}]
111
+ return messages
112
+
113
+
114
+ def convert_openai_to_gradio(
115
+ messages: list[dict[str, str]]
116
+ ) -> list[list[str, str | None]]:
117
+ "Converts openai chat request format -> gradio chat format"
118
+ chat_history = []
119
+ if messages[0]["role"] != "user":
120
+ messages.insert(0, {"role": "user", "content": None})
121
+ for i in range(0, len(messages), 2):
122
+ chat_history.append([messages[i]["content"], messages[i + 1]["content"]])
123
+ return chat_history
124
+
125
+
126
def seed_azure_key(cfg: str = "~/.cfg/openai.cfg") -> None:
    """Populate the AZURE_ENDPOINT / AZURE_SECRET env vars from a local INI file.

    :param cfg: path to an INI file with an [AZURE] section containing
        ``endpoint`` and ``key`` entries. ``~`` is expanded.
    :raises ValueError: if the file cannot be read or lacks the expected keys.
    """
    config = ConfigParser()
    # Bug fix: ConfigParser.read silently skips unreadable/missing files and
    # returns the list of files it parsed — it does not raise — so the
    # original bare `except:` never fired. Check the return value instead.
    parsed = config.read(Path(cfg).expanduser())
    if not parsed:
        raise ValueError(f"Could not read config file at: {cfg}.")
    try:
        os.environ["AZURE_ENDPOINT"] = config["AZURE"]["endpoint"]
        os.environ["AZURE_SECRET"] = config["AZURE"]["key"]
    except KeyError as e:
        raise ValueError(f"Config file {cfg} is missing expected AZURE entry: {e}") from e
134
+
135
+
136
def initialize_client() -> openai.AzureOpenAI:
    """Create the shared Azure OpenAI client from environment variables.

    Bug fix: the return annotation said openai.AsyncClient, but a synchronous
    AzureOpenAI client is constructed and returned.

    :raises KeyError: if AZURE_ENDPOINT / AZURE_SECRET are unset.
    """
    client = openai.AzureOpenAI(
        azure_endpoint=os.environ["AZURE_ENDPOINT"],
        api_key=os.environ["AZURE_SECRET"],
        api_version="2023-05-15",
    )
    return client
143
+
144
+
145
def auth_no_user(username, password):
    """Gradio auth callback: accept any username, checking only the password
    against the GRADIO_PASSWORD environment variable (empty-string default)."""
    expected = os.getenv("GRADIO_PASSWORD", "")
    return password == expected
150
+
151
+
152
def upload_azure(conversation_id: str, chat_history) -> None:
    """Upload a chat transcript to Azure Blob Storage as JSON lines.

    One blob per conversation, named after `conversation_id`. Requires the
    AZURE_CONN_STR and AZURE_CONTAINER_NAME environment variables; if either
    is unset, BlobClient.from_connection_string will fail (no local guard).
    """
    # Get blob client
    conn_str = os.getenv("AZURE_CONN_STR")
    container_name = os.getenv("AZURE_CONTAINER_NAME")
    blob_name = conversation_id
    blob_client = BlobClient.from_connection_string(conn_str, container_name, blob_name)

    # Convert chat_history to json lines (OpenAI-style role/content records)
    records = convert_gradio_to_openai(chat_history)
    records_text = "\n".join([json.dumps(record) for record in records])
    # NOTE(review): blob_type="AppendBlob" combined with overwrite=True looks
    # inconsistent — confirm the intended semantics against the Azure SDK docs.
    blob_client.upload_blob(records_text, blob_type="AppendBlob", overwrite=True)
+ blob_client.upload_blob(records_text, blob_type="AppendBlob", overwrite=True)