drmjh committed on
Commit
12dd101
·
1 Parent(s): 51f767a

checkpointing changes

Browse files
Files changed (2) hide show
  1. app.py +35 -14
  2. utils.py +32 -0
app.py CHANGED
@@ -11,6 +11,8 @@ Version Log:
11
  - Read sysprompt and initial_message from file
12
  - Begins with user entering name/alias
13
  - Azure OpenAI?
 
 
14
 
15
  """
16
  from __future__ import annotations
@@ -23,7 +25,6 @@ import gradio as gr
23
  import openai
24
 
25
  from pathlib import Path
26
- from base64 import urlsafe_b64decode
27
 
28
  logger = logging.getLogger(__name__)
29
 
@@ -31,7 +32,8 @@ from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key
31
 
32
 
33
  # %% Initialization
34
- CONFIG_DIR: Path = Path("./SPIA2024")
 
35
  if os.environ.get(f"OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
36
  seed_openai_key()
37
  client = openai.OpenAI()
@@ -53,9 +55,15 @@ def load_config(
53
 
54
 
55
  def initialize_interview(
56
- system_message: str, initial_message: str, model_args: dict[str, str | float]
57
- ) -> tuple[gr.Chatbot, gr.Textbox, gr.Button, gr.Button]:
58
- "Read system prompt and start interview"
 
 
 
 
 
 
59
  if len(initial_message) == 0: # If empty inital message, ask the LM to write it.
60
  initial_message = call_openai(
61
  [], system_message, client, model_args, stream=False
@@ -64,7 +72,7 @@ def initialize_interview(
64
  [None, initial_message]
65
  ] # First item is for user, in this case bot starts interaction.
66
  return (
67
- gr.Chatbot(visible=True, value=chat_history),
68
  gr.Textbox(
69
  placeholder="Type response here. Hit 'enter' to submit.",
70
  visible=True,
@@ -72,7 +80,6 @@ def initialize_interview(
72
  ), # chatInput
73
  gr.Button(visible=True, interactive=True), # chatSubmit
74
  gr.Button(visible=False), # startInterview
75
- gr.Textbox(visible=False), # userBox
76
  gr.Button(visible=True), # resetButton
77
  )
78
 
@@ -82,18 +89,20 @@ def initialize_tracker(
82
  system_message: PromptTemplate,
83
  userid: str,
84
  wandb_args: dict[str, str | list[str]],
85
- ) -> None:
86
- "Initializes wandb run for interview"
87
  run_config = model_args | {
88
  "system_message": str(system_message),
89
  "userid": userid,
90
  }
 
91
  wandb.init(
92
  project=wandb_args["project"],
93
  name=userid,
94
  config=run_config,
95
  tags=wandb_args["tags"],
96
  )
 
97
 
98
 
99
  def save_interview(
@@ -197,6 +206,7 @@ def reset_interview() -> (
197
  # LAYOUT
198
  with gr.Blocks() as demo:
199
  gr.Markdown("# StewartLab LM Interviewer")
 
200
 
201
  # Config values
202
  configDir = gr.State(value=CONFIG_DIR)
@@ -207,7 +217,7 @@ with gr.Blocks() as demo:
207
 
208
  ## Start interview by entering name or alias
209
  userBox = gr.Textbox(
210
- placeholder="Enter name or alias and hit 'enter' to begin.", show_label=False
211
  )
212
  startInterview = gr.Button("Start Interview", variant="primary", visible=True)
213
 
@@ -232,6 +242,7 @@ with gr.Blocks() as demo:
232
 
233
  ## INTERACTIONS
234
  # Start Interview button
 
235
  userBox.submit(
236
  load_config,
237
  inputs=configDir,
@@ -244,12 +255,14 @@ with gr.Blocks() as demo:
244
  chatInput,
245
  chatSubmit,
246
  startInterview,
247
- userBox,
248
  resetButton,
249
  ],
250
  ).then(
251
- initialize_tracker, inputs=[modelArgs, systemMessage, userBox, wandbArgs]
 
 
252
  )
 
253
  startInterview.click(
254
  load_config,
255
  inputs=configDir,
@@ -257,9 +270,17 @@ with gr.Blocks() as demo:
257
  ).then(
258
  initialize_interview,
259
  inputs=[systemMessage, initialMessage, modelArgs],
260
- outputs=[chatDisplay, chatInput, chatSubmit, startInterview, resetButton],
 
 
 
 
 
 
261
  ).then(
262
- initialize_tracker, inputs=[modelArgs, systemMessage, userBox, wandbArgs]
 
 
263
  )
264
 
265
  # Chat interaction
 
11
  - Read sysprompt and initial_message from file
12
  - Begins with user entering name/alias
13
  - Azure OpenAI?
14
+ - 2024.01.31: wandb does not work for use case, what to do instead?
15
+ - Write to local file and then upload at end? (does filestream cause blocking?)
16
 
17
  """
18
  from __future__ import annotations
 
25
  import openai
26
 
27
  from pathlib import Path
 
28
 
29
  logger = logging.getLogger(__name__)
30
 
 
32
 
33
 
34
  # %% Initialization
35
+
36
+ CONFIG_DIR: Path = Path("./SPIA2024_localtesting")
37
  if os.environ.get(f"OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
38
  seed_openai_key()
39
  client = openai.OpenAI()
 
55
 
56
 
57
  def initialize_interview(
58
+ system_message: str,
59
+ initial_message: str,
60
+ model_args: dict[str, str | float]
61
+ ) -> tuple[gr.Chatbot,
62
+ gr.Textbox,
63
+ gr.Button,
64
+ gr.Button,
65
+ gr.Button]:
66
+ "Read system prompt and start interview. Change visibilities of elements."
67
  if len(initial_message) == 0: # If empty inital message, ask the LM to write it.
68
  initial_message = call_openai(
69
  [], system_message, client, model_args, stream=False
 
72
  [None, initial_message]
73
  ] # First item is for user, in this case bot starts interaction.
74
  return (
75
+ gr.Chatbot(visible=True, value=chat_history), # chatDisplay
76
  gr.Textbox(
77
  placeholder="Type response here. Hit 'enter' to submit.",
78
  visible=True,
 
80
  ), # chatInput
81
  gr.Button(visible=True, interactive=True), # chatSubmit
82
  gr.Button(visible=False), # startInterview
 
83
  gr.Button(visible=True), # resetButton
84
  )
85
 
 
89
  system_message: PromptTemplate,
90
  userid: str,
91
  wandb_args: dict[str, str | list[str]],
92
+ ) -> gr.Textbox:
93
+ "Initializes wandb run for interview. Resets userBox afterwards."
94
  run_config = model_args | {
95
  "system_message": str(system_message),
96
  "userid": userid,
97
  }
98
+ logger.info(f"Initializing WandB run for {userid}")
99
  wandb.init(
100
  project=wandb_args["project"],
101
  name=userid,
102
  config=run_config,
103
  tags=wandb_args["tags"],
104
  )
105
+ return gr.Textbox(value=None, visible=False)
106
 
107
 
108
  def save_interview(
 
206
  # LAYOUT
207
  with gr.Blocks() as demo:
208
  gr.Markdown("# StewartLab LM Interviewer")
209
+ userDisplay = gr.Markdown("")
210
 
211
  # Config values
212
  configDir = gr.State(value=CONFIG_DIR)
 
217
 
218
  ## Start interview by entering name or alias
219
  userBox = gr.Textbox(
220
+ value=None, placeholder="Enter name or alias and hit 'enter' to begin.", show_label=False
221
  )
222
  startInterview = gr.Button("Start Interview", variant="primary", visible=True)
223
 
 
242
 
243
  ## INTERACTIONS
244
  # Start Interview button
245
+ userBox.change(lambda x: x, inputs=[userBox], outputs=[userDisplay], show_progress=False)
246
  userBox.submit(
247
  load_config,
248
  inputs=configDir,
 
255
  chatInput,
256
  chatSubmit,
257
  startInterview,
 
258
  resetButton,
259
  ],
260
  ).then(
261
+ initialize_tracker,
262
+ inputs=[modelArgs, systemMessage, userBox, wandbArgs],
263
+ outputs=[userBox]
264
  )
265
+
266
  startInterview.click(
267
  load_config,
268
  inputs=configDir,
 
270
  ).then(
271
  initialize_interview,
272
  inputs=[systemMessage, initialMessage, modelArgs],
273
+ outputs=[
274
+ chatDisplay,
275
+ chatInput,
276
+ chatSubmit,
277
+ startInterview,
278
+ resetButton,
279
+ ],
280
  ).then(
281
+ initialize_tracker,
282
+ inputs=[modelArgs, systemMessage, userBox, wandbArgs],
283
+ outputs=[userBox]
284
  )
285
 
286
  # Chat interaction
utils.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import os
 
4
  from configparser import ConfigParser
5
  from pathlib import Path
6
  from string import Formatter
@@ -105,3 +106,34 @@ def seed_openai_key(cfg: str = "~/.cfg/openai.cfg") -> None:
105
  except:
106
  raise ValueError(f"Could not using read file at: {cfg}.")
107
  os.environ["OPENAI_API_KEY"] = config["API_KEY"]["secret"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import openai
5
  from configparser import ConfigParser
6
  from pathlib import Path
7
  from string import Formatter
 
106
  except:
107
  raise ValueError(f"Could not using read file at: {cfg}.")
108
  os.environ["OPENAI_API_KEY"] = config["API_KEY"]["secret"]
109
+
110
+
111
+
112
def query_openai(
    messages: list[dict[str, str]],
    system_message: str | None,
    client: openai.Client,
    model_args: dict,
    stream: bool = False,
) -> str | Iterator[str]:
    """Call the OpenAI chat-completions API with pre-formatted messages.

    Args:
        messages: Chat history in OpenAI message-dict format; may be empty or None.
        system_message: Optional system prompt, prepended to the history when given.
        client: An initialized OpenAI client.
        model_args: Keyword arguments (model, temperature, ...) forwarded to
            ``chat.completions.create``.
        stream: When True, return an iterator of content chunks; when False,
            return the full response content as a single string.

    Returns:
        The assistant's reply as a string (``stream=False``) or a generator of
        content chunks (``stream=True``). On API failure a human-readable
        error string is returned instead of raising, preserving the original
        best-effort behavior.
    """
    # Copy rather than mutate the caller's list when prepending the system prompt.
    messages = list(messages) if messages else []
    if system_message:
        messages = [{"role": "system", "content": system_message}] + messages
    try:
        response = client.chat.completions.create(
            messages=messages, **model_args, stream=stream
        )
    # BUG FIX: `except A | B` is invalid Python — multiple exception types must
    # be a tuple. Also split the two cases: APIConnectionError has no
    # `.status_code`/`.response` attributes (only APIStatusError does).
    except openai.APIConnectionError as e:
        return f"API unreachable.\nERROR: {e}"
    except openai.APIStatusError as e:
        return f"API unreachable.\nSTATUS_CODE: {e.status_code} ERROR: {e.response}"
    except openai.RateLimitError:
        return "Hit rate limit. Wait a moment and retry."

    if stream:
        # BUG FIX: the original mixed `yield` and `return` in one body, which
        # made the entire function a generator — the stream=False path then
        # handed callers a generator instead of the content string. Delegating
        # the streaming path to an inner generator keeps stream=False eager.
        def _iter_chunks() -> Iterator[str]:
            for chunk in response:
                # Streamed chunks carry the text in `.delta`, not `.message`
                # (the original read `chunk.choices[0].message.content`).
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content

        return _iter_chunks()
    return response.choices[0].message.content