briankchan commited on
Commit
12e8ac9
·
1 Parent(s): 280e63f

Add setting for Azure OpenAI

Browse files
Files changed (1) hide show
  1. app.py +50 -11
app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import re
5
  import gradio as gr
6
  from langchain.chains import LLMChain
7
- from langchain.chat_models import PromptLayerChatOpenAI
8
  from langchain.memory import ConversationBufferMemory
9
  from langchain.prompts import PromptTemplate
10
  from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
@@ -48,10 +48,39 @@ Use this output format:
48
  - [textual evidence 2] - outside source"""
49
  # 7. Make a section for each key point. In each section, make a bullet list of textual evidence from task 3 that relate to that key point.
50
 
51
- def load_chain(api_key):
52
- api_key = os.environ["OPENAI_API_KEY"] if api_key == "" or api_key.isspace() else api_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  if api_key:
54
- llm = PromptLayerChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_key=api_key, pl_tags=["grammar"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  prompt1 = PromptTemplate(
56
  input_variables=["content"],
57
  template=GRAMMAR_PROMPT
@@ -85,7 +114,7 @@ def run(content, chain):
85
  return chain.run(content)
86
 
87
  def run_followup(followup_question, input_vars, chain, chat):
88
-
89
  history = [HumanMessage(content=chain.prompt.format(content=m.content)) if isinstance(m, HumanMessage) else m
90
  for m in chain.memory.chat_memory.messages]
91
  prompt = ChatPromptTemplate.from_messages([
@@ -189,7 +218,7 @@ def select_diff(evt: gr.SelectData, changes):
189
  # original, edited = text.split("→")
190
  return f"Why is it better to change [{original}] to [{edited}]?"
191
 
192
- block = gr.Blocks(css="""
193
  .diff-component {
194
  white-space: pre-wrap;
195
  }
@@ -197,13 +226,14 @@ block = gr.Blocks(css="""
197
  white-space: normal;
198
  }
199
  """)
200
- with block:
201
  # api_key = gr.Textbox(
202
  # placeholder="Paste your OpenAPI API key here (sk-...)",
203
  # show_label=False,
204
  # lines=1,
205
  # type="password"
206
  # )
 
207
  gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="https://thinkcol.com/thinkcol-logo.png" alt="ThinkCol" /></a></div>""")
208
  gr.Markdown("""Paste a paragraph below, and then choose one of the modes to generate feedback.""")
209
  content = gr.Textbox(
@@ -291,6 +321,13 @@ with block:
291
  # followup_answer_custom = gr.Textbox(
292
  # label="Answer"
293
  # )
 
 
 
 
 
 
 
294
 
295
  changes = gr.State()
296
  edited = gr.State()
@@ -301,7 +338,9 @@ with block:
301
 
302
  chain_custom = gr.State()
303
 
304
- # api_key.change(load_chain, [api_key], [chain, llm, chain_intro, chain_body1])
 
 
305
 
306
  inputs = [content, chain]
307
  outputs = [output_before, output_after, changes, edited]
@@ -321,9 +360,9 @@ with block:
321
  submit_intro.click(run, [content, chain_intro], output_intro)
322
  submit_body.click(run_body, [content, title, chain_body1, llm], output_body) # body part A only
323
  # submit_custom.click(run_custom, [content, llm, prompt], [output_custom, chain_custom]) # TODO standardize api--return memory instead of using chain?
324
-
325
  # followup_custom.submit(run_followup, [followup_custom, empty_input, chain_custom, llm], followup_answer_custom)
326
 
327
- block.load(load_chain, gr.State(""), [chain, llm, chain_intro, chain_body1])
328
 
329
- block.launch(debug=True)
 
4
  import re
5
  import gradio as gr
6
  from langchain.chains import LLMChain
7
+ from langchain.chat_models import PromptLayerChatOpenAI, AzureChatOpenAI
8
  from langchain.memory import ConversationBufferMemory
9
  from langchain.prompts import PromptTemplate
10
  from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
 
48
  - [textual evidence 2] - outside source"""
49
  # 7. Make a section for each key point. In each section, make a bullet list of textual evidence from task 3 that relate to that key point.
50
 
51
+
52
+ class AzurePromptLayerChatOpenAI(AzureChatOpenAI, PromptLayerChatOpenAI):
53
+ pass
54
+
55
+
56
+
57
+ def load_chain(api_key, api_type):
58
+ if api_key == "" or api_key.isspace():
59
+ if api_type == "OpenAI":
60
+ api_key = os.environ.get("OPENAI_API_KEY", None)
61
+ elif api_type == "Azure OpenAI":
62
+ api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
63
+ else:
64
+ raise RuntimeError("Unknown API type? " + api_type)
65
+
66
+
67
  if api_key:
68
+ shared_args = {
69
+ "temperature": 0,
70
+ "model_name": "gpt-3.5-turbo",
71
+ "openai_api_key": api_key,
72
+ "pl_tags": ["grammar"],
73
+ }
74
+ if api_type == "OpenAI":
75
+ llm = PromptLayerChatOpenAI(**shared_args)
76
+ elif api_type == "Azure OpenAI":
77
+ llm = AzurePromptLayerChatOpenAI(
78
+ openai_api_base = os.environ.get("AZURE_OPENAI_API_BASE", None),
79
+ openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
80
+ deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
81
+ **shared_args
82
+ )
83
+
84
  prompt1 = PromptTemplate(
85
  input_variables=["content"],
86
  template=GRAMMAR_PROMPT
 
114
  return chain.run(content)
115
 
116
  def run_followup(followup_question, input_vars, chain, chat):
117
+
118
  history = [HumanMessage(content=chain.prompt.format(content=m.content)) if isinstance(m, HumanMessage) else m
119
  for m in chain.memory.chat_memory.messages]
120
  prompt = ChatPromptTemplate.from_messages([
 
218
  # original, edited = text.split("→")
219
  return f"Why is it better to change [{original}] to [{edited}]?"
220
 
221
+ demo = gr.Blocks(css="""
222
  .diff-component {
223
  white-space: pre-wrap;
224
  }
 
226
  white-space: normal;
227
  }
228
  """)
229
+ with demo:
230
  # api_key = gr.Textbox(
231
  # placeholder="Paste your OpenAPI API key here (sk-...)",
232
  # show_label=False,
233
  # lines=1,
234
  # type="password"
235
  # )
236
+ api_key = gr.State("")
237
  gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="https://thinkcol.com/thinkcol-logo.png" alt="ThinkCol" /></a></div>""")
238
  gr.Markdown("""Paste a paragraph below, and then choose one of the modes to generate feedback.""")
239
  content = gr.Textbox(
 
321
  # followup_answer_custom = gr.Textbox(
322
  # label="Answer"
323
  # )
324
+ with gr.Tab("Settings"):
325
+ api_type = gr.Radio(
326
+ ["OpenAI", "Azure OpenAI"],
327
+ value="OpenAI",
328
+ label="Server",
329
+ info="You can try changing this if responses are slow."
330
+ )
331
 
332
  changes = gr.State()
333
  edited = gr.State()
 
338
 
339
  chain_custom = gr.State()
340
 
341
+
342
+ # api_key.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
343
+ api_type.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
344
 
345
  inputs = [content, chain]
346
  outputs = [output_before, output_after, changes, edited]
 
360
  submit_intro.click(run, [content, chain_intro], output_intro)
361
  submit_body.click(run_body, [content, title, chain_body1, llm], output_body) # body part A only
362
  # submit_custom.click(run_custom, [content, llm, prompt], [output_custom, chain_custom]) # TODO standardize api--return memory instead of using chain?
363
+
364
  # followup_custom.submit(run_followup, [followup_custom, empty_input, chain_custom, llm], followup_answer_custom)
365
 
366
+ demo.load(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
367
 
368
+ demo.launch(debug=True)