Spaces:
Runtime error
Runtime error
checkpointing changes
Browse files
app.py
CHANGED
|
@@ -11,6 +11,8 @@ Version Log:
|
|
| 11 |
- Read sysprompt and initial_message from file
|
| 12 |
- Begins with user entering name/alias
|
| 13 |
- Azure OpenAI?
|
|
|
|
|
|
|
| 14 |
|
| 15 |
"""
|
| 16 |
from __future__ import annotations
|
|
@@ -23,7 +25,6 @@ import gradio as gr
|
|
| 23 |
import openai
|
| 24 |
|
| 25 |
from pathlib import Path
|
| 26 |
-
from base64 import urlsafe_b64decode
|
| 27 |
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
|
@@ -31,7 +32,8 @@ from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key
|
|
| 31 |
|
| 32 |
|
| 33 |
# %% Initialization
|
| 34 |
-
|
|
|
|
| 35 |
if os.environ.get(f"OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
|
| 36 |
seed_openai_key()
|
| 37 |
client = openai.OpenAI()
|
|
@@ -53,9 +55,15 @@ def load_config(
|
|
| 53 |
|
| 54 |
|
| 55 |
def initialize_interview(
|
| 56 |
-
system_message: str,
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
if len(initial_message) == 0: # If empty inital message, ask the LM to write it.
|
| 60 |
initial_message = call_openai(
|
| 61 |
[], system_message, client, model_args, stream=False
|
|
@@ -64,7 +72,7 @@ def initialize_interview(
|
|
| 64 |
[None, initial_message]
|
| 65 |
] # First item is for user, in this case bot starts interaction.
|
| 66 |
return (
|
| 67 |
-
gr.Chatbot(visible=True, value=chat_history),
|
| 68 |
gr.Textbox(
|
| 69 |
placeholder="Type response here. Hit 'enter' to submit.",
|
| 70 |
visible=True,
|
|
@@ -72,7 +80,6 @@ def initialize_interview(
|
|
| 72 |
), # chatInput
|
| 73 |
gr.Button(visible=True, interactive=True), # chatSubmit
|
| 74 |
gr.Button(visible=False), # startInterview
|
| 75 |
-
gr.Textbox(visible=False), # userBox
|
| 76 |
gr.Button(visible=True), # resetButton
|
| 77 |
)
|
| 78 |
|
|
@@ -82,18 +89,20 @@ def initialize_tracker(
|
|
| 82 |
system_message: PromptTemplate,
|
| 83 |
userid: str,
|
| 84 |
wandb_args: dict[str, str | list[str]],
|
| 85 |
-
) ->
|
| 86 |
-
"Initializes wandb run for interview"
|
| 87 |
run_config = model_args | {
|
| 88 |
"system_message": str(system_message),
|
| 89 |
"userid": userid,
|
| 90 |
}
|
|
|
|
| 91 |
wandb.init(
|
| 92 |
project=wandb_args["project"],
|
| 93 |
name=userid,
|
| 94 |
config=run_config,
|
| 95 |
tags=wandb_args["tags"],
|
| 96 |
)
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
def save_interview(
|
|
@@ -197,6 +206,7 @@ def reset_interview() -> (
|
|
| 197 |
# LAYOUT
|
| 198 |
with gr.Blocks() as demo:
|
| 199 |
gr.Markdown("# StewartLab LM Interviewer")
|
|
|
|
| 200 |
|
| 201 |
# Config values
|
| 202 |
configDir = gr.State(value=CONFIG_DIR)
|
|
@@ -207,7 +217,7 @@ with gr.Blocks() as demo:
|
|
| 207 |
|
| 208 |
## Start interview by entering name or alias
|
| 209 |
userBox = gr.Textbox(
|
| 210 |
-
placeholder="Enter name or alias and hit 'enter' to begin.", show_label=False
|
| 211 |
)
|
| 212 |
startInterview = gr.Button("Start Interview", variant="primary", visible=True)
|
| 213 |
|
|
@@ -232,6 +242,7 @@ with gr.Blocks() as demo:
|
|
| 232 |
|
| 233 |
## INTERACTIONS
|
| 234 |
# Start Interview button
|
|
|
|
| 235 |
userBox.submit(
|
| 236 |
load_config,
|
| 237 |
inputs=configDir,
|
|
@@ -244,12 +255,14 @@ with gr.Blocks() as demo:
|
|
| 244 |
chatInput,
|
| 245 |
chatSubmit,
|
| 246 |
startInterview,
|
| 247 |
-
userBox,
|
| 248 |
resetButton,
|
| 249 |
],
|
| 250 |
).then(
|
| 251 |
-
initialize_tracker,
|
|
|
|
|
|
|
| 252 |
)
|
|
|
|
| 253 |
startInterview.click(
|
| 254 |
load_config,
|
| 255 |
inputs=configDir,
|
|
@@ -257,9 +270,17 @@ with gr.Blocks() as demo:
|
|
| 257 |
).then(
|
| 258 |
initialize_interview,
|
| 259 |
inputs=[systemMessage, initialMessage, modelArgs],
|
| 260 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
).then(
|
| 262 |
-
initialize_tracker,
|
|
|
|
|
|
|
| 263 |
)
|
| 264 |
|
| 265 |
# Chat interaction
|
|
|
|
| 11 |
- Read sysprompt and initial_message from file
|
| 12 |
- Begins with user entering name/alias
|
| 13 |
- Azure OpenAI?
|
| 14 |
+
- 2024.01.31: wandb does not work for use case, what to do instead?
|
| 15 |
+
- Write to local file and then upload at end? (does filestream cause blocking?)
|
| 16 |
|
| 17 |
"""
|
| 18 |
from __future__ import annotations
|
|
|
|
| 25 |
import openai
|
| 26 |
|
| 27 |
from pathlib import Path
|
|
|
|
| 28 |
|
| 29 |
logger = logging.getLogger(__name__)
|
| 30 |
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
# %% Initialization
|
| 35 |
+
|
| 36 |
+
CONFIG_DIR: Path = Path("./SPIA2024_localtesting")
|
| 37 |
if os.environ.get(f"OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
|
| 38 |
seed_openai_key()
|
| 39 |
client = openai.OpenAI()
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
def initialize_interview(
|
| 58 |
+
system_message: str,
|
| 59 |
+
initial_message: str,
|
| 60 |
+
model_args: dict[str, str | float]
|
| 61 |
+
) -> tuple[gr.Chatbot,
|
| 62 |
+
gr.Textbox,
|
| 63 |
+
gr.Button,
|
| 64 |
+
gr.Button,
|
| 65 |
+
gr.Button]:
|
| 66 |
+
"Read system prompt and start interview. Change visibilities of elements."
|
| 67 |
if len(initial_message) == 0: # If empty inital message, ask the LM to write it.
|
| 68 |
initial_message = call_openai(
|
| 69 |
[], system_message, client, model_args, stream=False
|
|
|
|
| 72 |
[None, initial_message]
|
| 73 |
] # First item is for user, in this case bot starts interaction.
|
| 74 |
return (
|
| 75 |
+
gr.Chatbot(visible=True, value=chat_history), # chatDisplay
|
| 76 |
gr.Textbox(
|
| 77 |
placeholder="Type response here. Hit 'enter' to submit.",
|
| 78 |
visible=True,
|
|
|
|
| 80 |
), # chatInput
|
| 81 |
gr.Button(visible=True, interactive=True), # chatSubmit
|
| 82 |
gr.Button(visible=False), # startInterview
|
|
|
|
| 83 |
gr.Button(visible=True), # resetButton
|
| 84 |
)
|
| 85 |
|
|
|
|
| 89 |
system_message: PromptTemplate,
|
| 90 |
userid: str,
|
| 91 |
wandb_args: dict[str, str | list[str]],
|
| 92 |
+
) -> gr.Textbox:
|
| 93 |
+
"Initializes wandb run for interview. Resets userBox afterwards."
|
| 94 |
run_config = model_args | {
|
| 95 |
"system_message": str(system_message),
|
| 96 |
"userid": userid,
|
| 97 |
}
|
| 98 |
+
logger.info(f"Initializing WandB run for {userid}")
|
| 99 |
wandb.init(
|
| 100 |
project=wandb_args["project"],
|
| 101 |
name=userid,
|
| 102 |
config=run_config,
|
| 103 |
tags=wandb_args["tags"],
|
| 104 |
)
|
| 105 |
+
return gr.Textbox(value=None, visible=False)
|
| 106 |
|
| 107 |
|
| 108 |
def save_interview(
|
|
|
|
| 206 |
# LAYOUT
|
| 207 |
with gr.Blocks() as demo:
|
| 208 |
gr.Markdown("# StewartLab LM Interviewer")
|
| 209 |
+
userDisplay = gr.Markdown("")
|
| 210 |
|
| 211 |
# Config values
|
| 212 |
configDir = gr.State(value=CONFIG_DIR)
|
|
|
|
| 217 |
|
| 218 |
## Start interview by entering name or alias
|
| 219 |
userBox = gr.Textbox(
|
| 220 |
+
value=None, placeholder="Enter name or alias and hit 'enter' to begin.", show_label=False
|
| 221 |
)
|
| 222 |
startInterview = gr.Button("Start Interview", variant="primary", visible=True)
|
| 223 |
|
|
|
|
| 242 |
|
| 243 |
## INTERACTIONS
|
| 244 |
# Start Interview button
|
| 245 |
+
userBox.change(lambda x: x, inputs=[userBox], outputs=[userDisplay], show_progress=False)
|
| 246 |
userBox.submit(
|
| 247 |
load_config,
|
| 248 |
inputs=configDir,
|
|
|
|
| 255 |
chatInput,
|
| 256 |
chatSubmit,
|
| 257 |
startInterview,
|
|
|
|
| 258 |
resetButton,
|
| 259 |
],
|
| 260 |
).then(
|
| 261 |
+
initialize_tracker,
|
| 262 |
+
inputs=[modelArgs, systemMessage, userBox, wandbArgs],
|
| 263 |
+
outputs=[userBox]
|
| 264 |
)
|
| 265 |
+
|
| 266 |
startInterview.click(
|
| 267 |
load_config,
|
| 268 |
inputs=configDir,
|
|
|
|
| 270 |
).then(
|
| 271 |
initialize_interview,
|
| 272 |
inputs=[systemMessage, initialMessage, modelArgs],
|
| 273 |
+
outputs=[
|
| 274 |
+
chatDisplay,
|
| 275 |
+
chatInput,
|
| 276 |
+
chatSubmit,
|
| 277 |
+
startInterview,
|
| 278 |
+
resetButton,
|
| 279 |
+
],
|
| 280 |
).then(
|
| 281 |
+
initialize_tracker,
|
| 282 |
+
inputs=[modelArgs, systemMessage, userBox, wandbArgs],
|
| 283 |
+
outputs=[userBox]
|
| 284 |
)
|
| 285 |
|
| 286 |
# Chat interaction
|
utils.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import os
|
|
|
|
| 4 |
from configparser import ConfigParser
|
| 5 |
from pathlib import Path
|
| 6 |
from string import Formatter
|
|
@@ -105,3 +106,34 @@ def seed_openai_key(cfg: str = "~/.cfg/openai.cfg") -> None:
|
|
| 105 |
except:
|
| 106 |
raise ValueError(f"Could not using read file at: {cfg}.")
|
| 107 |
os.environ["OPENAI_API_KEY"] = config["API_KEY"]["secret"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
import openai
|
| 5 |
from configparser import ConfigParser
|
| 6 |
from pathlib import Path
|
| 7 |
from string import Formatter
|
|
|
|
| 106 |
except:
|
| 107 |
raise ValueError(f"Could not using read file at: {cfg}.")
|
| 108 |
os.environ["OPENAI_API_KEY"] = config["API_KEY"]["secret"]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def query_openai(
|
| 113 |
+
messages: list[dict[str, str]],
|
| 114 |
+
system_message: str | None,
|
| 115 |
+
client: openai.Client,
|
| 116 |
+
model_args: dict,
|
| 117 |
+
stream: bool = False,
|
| 118 |
+
) -> None:
|
| 119 |
+
"Utility function for calling OpenAI chat. Expects formatted messages."
|
| 120 |
+
if not messages:
|
| 121 |
+
messages = []
|
| 122 |
+
if system_message:
|
| 123 |
+
messages = [{"role": "system", "content": system_message}] + messages
|
| 124 |
+
try:
|
| 125 |
+
response = client.chat.completions.create(
|
| 126 |
+
messages=messages, **model_args, stream=stream
|
| 127 |
+
)
|
| 128 |
+
if stream:
|
| 129 |
+
for chunk in response:
|
| 130 |
+
yield chunk.choices[0].message.content
|
| 131 |
+
else:
|
| 132 |
+
content = response.choices[0].message.content
|
| 133 |
+
return content
|
| 134 |
+
except openai.APIConnectionError | openai.APIStatusError as e:
|
| 135 |
+
error_msg = (
|
| 136 |
+
"API unreachable.\n" f"STATUS_CODE: {e.status_code}" f"ERROR: {e.response}"
|
| 137 |
+
)
|
| 138 |
+
except openai.RateLimitError:
|
| 139 |
+
warning_msg = "Hit rate limit. Wait a moment and retry."
|