drmjh committed on
Commit
626a080
·
1 Parent(s): 12dd101

Changed to Azure endpoint

Browse files
Files changed (2) hide show
  1. app.py +13 -49
  2. utils.py +12 -38
app.py CHANGED
@@ -28,16 +28,19 @@ from pathlib import Path
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
- from utils import PromptTemplate, convert_gradio_to_openai, seed_openai_key
 
 
 
 
 
32
 
33
 
34
  # %% Initialization
35
-
36
- CONFIG_DIR: Path = Path("./SPIA2024_localtesting")
37
- if os.environ.get(f"OPENAI_API_KEY", "DEFAULT") == "DEFAULT":
38
- seed_openai_key()
39
- client = openai.OpenAI()
40
-
41
 
42
  # %% (functions)
43
  def load_config(
@@ -55,19 +58,13 @@ def load_config(
55
 
56
 
57
  def initialize_interview(
58
- system_message: str,
59
  initial_message: str,
60
- model_args: dict[str, str | float]
61
  ) -> tuple[gr.Chatbot,
62
  gr.Textbox,
63
  gr.Button,
64
  gr.Button,
65
  gr.Button]:
66
  "Read system prompt and start interview. Change visibilities of elements."
67
- if len(initial_message) == 0: # If empty inital message, ask the LM to write it.
68
- initial_message = call_openai(
69
- [], system_message, client, model_args, stream=False
70
- )
71
  chat_history = [
72
  [None, initial_message]
73
  ] # First item is for user, in this case bot starts interaction.
@@ -119,39 +116,6 @@ def save_interview(
119
  logger.info("Uploading complete.")
120
 
121
 
122
- def call_openai(
123
- messages: list[dict[str, str]],
124
- system_message: str | None,
125
- client: openai.Client,
126
- model_args: dict,
127
- stream: bool = False,
128
- ):
129
- "Utility function for calling OpenAI chat. Expects formatted messages."
130
- if not messages:
131
- messages = []
132
- if system_message:
133
- messages = [{"role": "system", "content": system_message}] + messages
134
- try:
135
- response = client.chat.completions.create(
136
- messages=messages, **model_args, stream=stream
137
- )
138
- if stream:
139
- for chunk in response:
140
- yield chunk.choices[0].message.content
141
- else:
142
- content = response.choices[0].message.content
143
- return content
144
- except openai.APIConnectionError | openai.APIStatusError as e:
145
- error_msg = (
146
- "API unreachable.\n" f"STATUS_CODE: {e.status_code}" f"ERROR: {e.response}"
147
- )
148
- gr.Error(error_msg)
149
- logger.error(error_msg)
150
- except openai.RateLimitError:
151
- warning_msg = "Hit rate limit. Wait a moment and retry."
152
- gr.Warning(warning_msg)
153
- logger.warning(warning_msg)
154
-
155
 
156
  def user_message(
157
  message: str, chat_history: list[list[str | None]]
@@ -206,7 +170,7 @@ def reset_interview() -> (
206
  # LAYOUT
207
  with gr.Blocks() as demo:
208
  gr.Markdown("# StewartLab LM Interviewer")
209
- userDisplay = gr.Markdown("")
210
 
211
  # Config values
212
  configDir = gr.State(value=CONFIG_DIR)
@@ -249,7 +213,7 @@ with gr.Blocks() as demo:
249
  outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
250
  ).then(
251
  initialize_interview,
252
- inputs=[systemMessage, initialMessage, modelArgs],
253
  outputs=[
254
  chatDisplay,
255
  chatInput,
@@ -269,7 +233,7 @@ with gr.Blocks() as demo:
269
  outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
270
  ).then(
271
  initialize_interview,
272
- inputs=[systemMessage, initialMessage, modelArgs],
273
  outputs=[
274
  chatDisplay,
275
  chatInput,
 
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
+ from utils import (
32
+ PromptTemplate,
33
+ convert_gradio_to_openai,
34
+ initialize_client,
35
+ seed_azure_key
36
+ )
37
 
38
 
39
  # %% Initialization
40
+ CONFIG_DIR: Path = Path("./SPIA2024")
41
+ if os.environ.get("AZURE_ENDPOINT") is None: # Set Azure credentials from local files
42
+ seed_azure_key()
43
+ client = initialize_client()
 
 
44
 
45
  # %% (functions)
46
  def load_config(
 
58
 
59
 
60
  def initialize_interview(
 
61
  initial_message: str,
 
62
  ) -> tuple[gr.Chatbot,
63
  gr.Textbox,
64
  gr.Button,
65
  gr.Button,
66
  gr.Button]:
67
  "Read system prompt and start interview. Change visibilities of elements."
 
 
 
 
68
  chat_history = [
69
  [None, initial_message]
70
  ] # First item is for user, in this case bot starts interaction.
 
116
  logger.info("Uploading complete.")
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def user_message(
121
  message: str, chat_history: list[list[str | None]]
 
170
  # LAYOUT
171
  with gr.Blocks() as demo:
172
  gr.Markdown("# StewartLab LM Interviewer")
173
+ userDisplay = gr.Markdown("", visible=False)
174
 
175
  # Config values
176
  configDir = gr.State(value=CONFIG_DIR)
 
213
  outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
214
  ).then(
215
  initialize_interview,
216
+ inputs=[initialMessage],
217
  outputs=[
218
  chatDisplay,
219
  chatInput,
 
233
  outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
234
  ).then(
235
  initialize_interview,
236
+ inputs=[initialMessage],
237
  outputs=[
238
  chatDisplay,
239
  chatInput,
utils.py CHANGED
@@ -94,46 +94,20 @@ def convert_openai_to_gradio(
94
  return chat_history
95
 
96
 
97
- def seed_openai_key(cfg: str = "~/.cfg/openai.cfg") -> None:
98
- """
99
- Reads OpenAI key from config file and adds it to environment.
100
- Assumed config location is "~/.cfg/openai.cfg"
101
- """
102
- # Get OpenAI Key
103
  config = ConfigParser()
104
  try:
105
  config.read(Path(cfg).expanduser())
106
  except:
107
  raise ValueError(f"Could not using read file at: {cfg}.")
108
- os.environ["OPENAI_API_KEY"] = config["API_KEY"]["secret"]
109
-
110
-
111
-
112
- def query_openai(
113
- messages: list[dict[str, str]],
114
- system_message: str | None,
115
- client: openai.Client,
116
- model_args: dict,
117
- stream: bool = False,
118
- ) -> None:
119
- "Utility function for calling OpenAI chat. Expects formatted messages."
120
- if not messages:
121
- messages = []
122
- if system_message:
123
- messages = [{"role": "system", "content": system_message}] + messages
124
- try:
125
- response = client.chat.completions.create(
126
- messages=messages, **model_args, stream=stream
127
- )
128
- if stream:
129
- for chunk in response:
130
- yield chunk.choices[0].message.content
131
- else:
132
- content = response.choices[0].message.content
133
- return content
134
- except openai.APIConnectionError | openai.APIStatusError as e:
135
- error_msg = (
136
- "API unreachable.\n" f"STATUS_CODE: {e.status_code}" f"ERROR: {e.response}"
137
- )
138
- except openai.RateLimitError:
139
- warning_msg = "Hit rate limit. Wait a moment and retry."
 
94
  return chat_history
95
 
96
 
97
+ def seed_azure_key(cfg: str = "~/.cfg/openai.cfg") -> None:
 
 
 
 
 
98
  config = ConfigParser()
99
  try:
100
  config.read(Path(cfg).expanduser())
101
  except:
102
  raise ValueError(f"Could not using read file at: {cfg}.")
103
+ os.environ["AZURE_ENDPOINT"] = config["AZURE"]["endpoint"]
104
+ os.environ["AZURE_SECRET"] = config["AZURE"]["key"]
105
+
106
+
107
+ def initialize_client() -> openai.AsyncClient:
108
+ client = openai.AzureOpenAI(
109
+ azure_endpoint=os.environ["AZURE_ENDPOINT"],
110
+ api_key=os.environ["AZURE_SECRET"],
111
+ api_version="2023-05-15",
112
+ )
113
+ return client