File size: 9,130 Bytes
7549871
0aa2975
7549871
 
 
 
0aa2975
7549871
0aa2975
 
 
 
 
12dd101
 
a4a3dbe
7549871
 
 
 
 
 
 
 
 
a4a3dbe
7549871
0aa2975
7549871
 
 
626a080
 
 
 
 
 
7549871
 
 
a4a3dbe
626a080
 
 
7549871
 
0aa2975
 
 
 
 
 
 
 
 
7549871
0aa2975
 
7549871
 
 
12dd101
 
 
 
 
 
 
0aa2975
 
 
7549871
12dd101
7549871
0aa2975
 
 
 
 
 
 
7549871
 
 
 
 
0aa2975
 
 
12dd101
 
7549871
0aa2975
7549871
 
12dd101
7549871
0aa2975
 
 
 
7549871
12dd101
7549871
 
 
 
 
a4a3dbe
 
 
7549871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa2975
7549871
 
 
 
 
 
 
a4a3dbe
7549871
 
 
 
 
 
 
 
0aa2975
7549871
 
 
 
 
 
 
 
 
 
 
 
0aa2975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7549871
a4a3dbe
 
626a080
7549871
0aa2975
 
7549871
 
 
0aa2975
7549871
0aa2975
 
12dd101
7549871
0aa2975
 
 
 
7549871
 
 
0aa2975
7549871
 
 
 
 
 
0aa2975
7549871
 
0aa2975
7549871
 
a4a3dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7549871
 
 
12dd101
0aa2975
7549871
0aa2975
 
7549871
 
626a080
7549871
 
 
 
 
 
 
 
12dd101
 
 
0aa2975
12dd101
0aa2975
 
51f767a
0aa2975
 
 
626a080
12dd101
 
 
 
 
 
 
0aa2975
12dd101
 
 
7549871
 
0aa2975
 
7549871
 
 
 
 
 
 
 
 
 
 
 
0aa2975
7549871
 
 
 
 
 
 
 
 
 
 
 
 
0aa2975
7549871
 
0aa2975
 
 
 
 
 
 
 
7549871
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""
General-Purpose LM Interview Interface

Author: Dr Musashi Hinck


Version Log:

- 2024.01.29: prototype without separate launching interface for demoing in SPIA class.
    - Remove URL decoding
    - Read sysprompt and initial_message from file
    - Begins with user entering name/alias
    - Azure OpenAI?
- 2024.01.31: wandb does not work for use case, what to do instead?
    - Write to local file and then upload at end? (does filestream cause blocking?)
- 2024.03.03: Creating new instance for demoing to IRB

"""
from __future__ import annotations

import os
import logging
import json
import wandb
import gradio as gr
from typing import Generator, Any

from pathlib import Path

logger = logging.getLogger(__name__)

from utils import (
    PromptTemplate,
    convert_gradio_to_openai,
    initialize_client,
    seed_azure_key
)


# %% Initialization
# Directory containing initial_message.txt, system_message.txt and config.json
# for this interview instance (IRB demo).
CONFIG_DIR: Path = Path("./CogDebIRB")
if os.environ.get("AZURE_ENDPOINT") is None: # Set Azure credentials from local files
    seed_azure_key()
# Shared OpenAI/Azure client used by all chat handlers below.
client = initialize_client()

# %% (functions)
def load_config(
    path: Path,
) -> tuple[str, str, dict[str, str | float], dict[str, str | list[str]]]:
    "Read configs, return inital_message, system_message, model_args, wandb_args"
    initial_message: str = (path / "initial_message.txt").read_text().strip()
    system_message: str = (path / "system_message.txt").read_text().strip()
    cfg: dict[str, str] = json.loads((path / "config.json").read_bytes())
    model_args: dict[str, str | float] = cfg.get(
        "model_args", {"model": "gpt4", "temperature": 0.0}
    )
    wandb_args: dict = cfg.get("wandb_args")
    return initial_message, system_message, model_args, wandb_args


def initialize_interview(
    initial_message: str,
) -> tuple[gr.Chatbot,
           gr.Textbox,
           gr.Button,
           gr.Button,
           gr.Button]:
    """Begin the interview: seed the chat with the bot's opening message
    and flip widget visibility into the chat phase."""
    # User slot is None because the bot opens the conversation.
    opening_history: list[list[str | None]] = [[None, initial_message]]
    chat_display = gr.Chatbot(visible=True, value=opening_history)
    chat_input = gr.Textbox(
        placeholder="Type response here. Hit 'enter' to submit.",
        visible=True,
        interactive=True,
    )
    chat_submit = gr.Button(visible=True, interactive=True)
    start_button = gr.Button(visible=False)   # hide "Start Interview"
    reset_button = gr.Button(visible=True)    # show "Save and Exit"
    return chat_display, chat_input, chat_submit, start_button, reset_button


def initialize_tracker(
    model_args: dict[str, str | float],
    system_message: PromptTemplate,
    userid: str,
    wandb_args: dict[str, str | list[str]],
) -> gr.Textbox:
    """Start a WandB run for this interview session, then clear and hide
    the user-name textbox."""
    run_config = dict(model_args)
    run_config["system_message"] = str(system_message)
    run_config["userid"] = userid
    logger.info(f"Initializing WandB run for {userid}")
    wandb.init(
        project=wandb_args["project"],
        name=userid,
        tags=wandb_args["tags"],
        config=run_config,
    )
    return gr.Textbox(value=None, visible=False)


def save_interview(
    chat_history: list[list[str | None]],
) -> None:
    """Persist the transcript as local JSON, then mirror it to WandB as a
    two-column (role, message) table."""
    # Local copy first, so a wandb failure never loses the transcript.
    with open(CONFIG_DIR / "transcript.json", 'w') as fh:
        json.dump(chat_history, fh, indent=2)
    # Flatten [user, bot] pairs into rows, skipping empty slots (e.g. the
    # None user slot of the bot's opening message).
    chat_data = [
        [role, text]
        for pair in chat_history
        for role, text in zip(["user", "bot"], pair)
        if text is not None
    ]
    chat_table = wandb.Table(data=chat_data, columns=["role", "message"])
    logger.info("Uploading interview transcript to WandB...")
    wandb.log({"chat_history": chat_table})
    logger.info("Uploading complete.")



def user_message(
    message: str, chat_history: list[list[str | None]]
) -> tuple[str, list[list[str | None]]]:
    "Display user message immediately"
    return "", chat_history + [[message, None]]


def bot_message(
    chat_history: list[list[str | None]],
    system_message: str,
    model_args: dict[str, str | float],
) -> Generator[Any, Any, Any]:
    """Stream the assistant's reply into the last turn of *chat_history*,
    yielding the updated history after each received token so the UI
    refreshes incrementally."""
    # Build the OpenAI-format message list: system prompt, prior turns,
    # then the user's newest message (last turn, bot slot still empty).
    latest_user_msg = chat_history[-1][0]
    messages = [{"role": "system", "content": system_message}]
    messages.extend(convert_gradio_to_openai(chat_history[:-1]))
    messages.append({"role": "user", "content": latest_user_msg})
    # API call (streaming)
    response = client.chat.completions.create(
        messages=messages, stream=True, **model_args
    )
    # Replace the None placeholder, then accumulate tokens.
    chat_history[-1][1] = ""
    for chunk in response:
        token = chunk.choices[0].delta.content
        if token:
            chat_history[-1][1] += token
            yield chat_history


def reset_interview() -> (
    tuple[
        gr.Chatbot, gr.Textbox, gr.Button, gr.Textbox, gr.Button, gr.Button
    ]
):
    """Finish the current WandB run and restore the UI to its
    pre-interview state (name entry visible, chat hidden).

    Fix: the previous return annotation declared a leading transcript list
    and a Button in the fourth slot, but the function actually returns six
    Gradio components with a Textbox (userBox) fourth — the annotation now
    matches the returned values and the resetButton.click outputs wiring.
    """
    wandb.finish()
    gr.Info("Interview reset.")
    return (
        gr.Chatbot(visible=False, value=[]),  # chatDisplay
        gr.Textbox(visible=False),  # chatInput
        gr.Button(visible=False),  # chatSubmit
        gr.Textbox(value=None, visible=True),  # userBox
        gr.Button(visible=True),  # startInterview
        gr.Button(visible=False),  # resetButton
    )


# LAYOUT
# Gradio Blocks app: name entry -> interview chat -> save/reset.
with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
    gr.Markdown("# Chat Interview Interface")
    # Mirrors whatever the user types into userBox (see userBox.change below).
    userDisplay = gr.Markdown("", visible=False)

    # Config values — hidden components holding state loaded by load_config.
    configDir = gr.State(value=CONFIG_DIR)
    initialMessage = gr.Textbox(visible=False)
    systemMessage = gr.Textbox(visible=False)
    modelArgs = gr.State(value={"model": "", "temperature": ""})
    wandbArgs = gr.State(value={"project": "", "tags": []})

    ## Start interview by entering name or alias
    userBox = gr.Textbox(
        value=None, placeholder="Enter name or alias and hit 'enter' to begin.", show_label=False
    )
    startInterview = gr.Button("Start Interview", variant="primary", visible=True)

    ## RESPONDENT
    # Chat widgets start hidden/disabled; initialize_interview reveals them.
    chatDisplay = gr.Chatbot(show_label=False, visible=False)
    with gr.Row():
        chatInput = gr.Textbox(
            placeholder="Click 'Start Interview' to begin.",
            visible=False,
            interactive=False,
            show_label=False,
            scale=10,
        )
        chatSubmit = gr.Button(
            "",
            variant="primary",
            interactive=False,
            icon="./arrow_icon.svg",
            visible=False,
        )
    resetButton = gr.Button("Save and Exit", visible=False, variant="stop")
    # Fixed-position disclaimer pinned to the bottom of the page.
    disclaimer = gr.HTML(
        """
        <div
        style='font-size: 1em;
               font-style: italic;   
               position: fixed;
               left: 50%;
               bottom: 20px;
               transform: translate(-50%, -50%);
               margin: 0 auto;
               '
        >{}</div>
        """.format(
            "Statements by the chatbot may contain factual inaccuracies."
        )
    )

    ## INTERACTIONS
    # Start Interview button
    userBox.change(lambda x: x, inputs=[userBox], outputs=[userDisplay], show_progress=False)
    # Enter in userBox: load config -> reveal chat UI -> start wandb run.
    userBox.submit(
        load_config,
        inputs=configDir,
        outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
    ).then(
        initialize_interview,
        inputs=[initialMessage],
        outputs=[
            chatDisplay,
            chatInput,
            chatSubmit,
            startInterview,
            resetButton,
        ],
    ).then(
        initialize_tracker,
        inputs=[modelArgs, systemMessage, userBox, wandbArgs],
        outputs=[userBox]
    )

    # Same chain as userBox.submit, triggered by the button instead.
    startInterview.click(
        load_config,
        inputs=configDir,
        outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
    ).then(
        initialize_interview,
        inputs=[initialMessage],
        outputs=[
            chatDisplay,
            chatInput,
            chatSubmit,
            startInterview,
            resetButton,
        ],
    ).then(
        initialize_tracker,
        inputs=[modelArgs, systemMessage, userBox, wandbArgs],
        outputs=[userBox]
    )

    # Chat interaction
    # "Enter": echo user message -> stream bot reply -> save transcript.
    chatInput.submit(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        save_interview, inputs=[chatDisplay]
    )
    # Button: identical chain to chatInput.submit.
    chatSubmit.click(
        user_message,
        inputs=[chatInput, chatDisplay],
        outputs=[chatInput, chatDisplay],
        queue=False,
    ).then(
        bot_message,
        inputs=[chatDisplay, systemMessage, modelArgs],
        outputs=[chatDisplay],
    ).then(
        save_interview, inputs=[chatDisplay]
    )

    # Reset button: save one last time, then restore the pre-interview UI.
    resetButton.click(save_interview, [chatDisplay]).then(
        reset_interview,
        outputs=[
            chatDisplay,
            chatInput,
            chatSubmit,
            userBox,
            startInterview,
            resetButton,
        ],
        show_progress=False,
    )


if __name__ == "__main__":
    demo.launch()