JLW committed on
Commit
fb2f766
·
1 Parent(s): 1436284

Implement embeddings

Browse files
Files changed (1) hide show
  1. app.py +81 -17
app.py CHANGED
@@ -32,6 +32,13 @@ from langchain.prompts import PromptTemplate
32
  from polly_utils import PollyVoiceData, NEURAL_ENGINE
33
  from azure_utils import AzureVoiceData
34
 
 
 
 
 
 
 
 
35
  news_api_key = os.environ["NEWS_API_KEY"]
36
  tmdb_bearer_token = os.environ["TMDB_BEARER_TOKEN"]
37
 
@@ -57,7 +64,8 @@ LANG_LEVEL_DEFAULT = "N/A"
57
  TRANSLATE_TO_DEFAULT = "N/A"
58
  LITERARY_STYLE_DEFAULT = "N/A"
59
  PROMPT_TEMPLATE = PromptTemplate(
60
- input_variables=["original_words", "num_words", "formality", "emotions", "lang_level", "translate_to", "literary_style"],
 
61
  template="Restate {num_words}{formality}{emotions}{lang_level}{translate_to}{literary_style}the following: \n{original_words}\n",
62
  )
63
 
@@ -150,7 +158,8 @@ def transform_text(desc, express_chain, num_words, formality,
150
 
151
  translate_to_str = ""
152
  if translate_to != TRANSLATE_TO_DEFAULT:
153
- translate_to_str = "translated to " + ("" if lang_level == TRANSLATE_TO_DEFAULT else lang_level + " level ") + translate_to + ", "
 
154
 
155
  literary_style_str = ""
156
  if literary_style != LITERARY_STYLE_DEFAULT:
@@ -216,7 +225,6 @@ def load_chain(tools_list, llm):
216
 
217
  chain = initialize_agent(tools, llm, agent="conversational-react-description", verbose=True, memory=memory)
218
  express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
219
-
220
  return chain, express_chain
221
 
222
 
@@ -227,14 +235,22 @@ def set_openai_api_key(api_key):
227
  if api_key and api_key.startswith("sk-") and len(api_key) > 50:
228
  os.environ["OPENAI_API_KEY"] = api_key
229
  print("\n\n ++++++++++++++ Setting OpenAI API key ++++++++++++++ \n\n")
230
- print(str(datetime.datetime.now()) + ": Before OpenAI, OPENAI_API_KEY length: " + str(len(os.environ["OPENAI_API_KEY"])))
 
231
  llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
232
- print(str(datetime.datetime.now()) + ": After OpenAI, OPENAI_API_KEY length: " + str(len(os.environ["OPENAI_API_KEY"])))
 
233
  chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
234
- print(str(datetime.datetime.now()) + ": After load_chain, OPENAI_API_KEY length: " + str(len(os.environ["OPENAI_API_KEY"])))
 
 
 
 
 
 
235
  os.environ["OPENAI_API_KEY"] = ""
236
- return chain, express_chain, llm
237
- return None, None, None
238
 
239
 
240
  def run_chain(chain, inp, capture_hidden_text):
@@ -311,7 +327,7 @@ class ChatWrapper:
311
  trace_chain: bool, speak_text: bool, talking_head: bool, monologue: bool, express_chain: Optional[LLMChain],
312
  num_words, formality, anticipation_level, joy_level, trust_level,
313
  fear_level, surprise_level, sadness_level, disgust_level, anger_level,
314
- lang_level, translate_to, literary_style
315
  ):
316
  """Execute the chat functionality."""
317
  self.lock.acquire()
@@ -332,7 +348,15 @@ class ChatWrapper:
332
  import openai
333
  openai.api_key = api_key
334
  if not monologue:
335
- output, hidden_text = run_chain(chain, inp, capture_hidden_text=trace_chain)
 
 
 
 
 
 
 
 
336
  else:
337
  output, hidden_text = inp, None
338
 
@@ -486,6 +510,24 @@ def update_foo(widget, state):
486
  return state
487
 
488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
490
  llm_state = gr.State()
491
  history_state = gr.State()
@@ -515,12 +557,18 @@ with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
515
  # Pertains to WHISPER functionality
516
  whisper_lang_state = gr.State(WHISPER_DETECT_LANG)
517
 
 
 
 
 
 
 
518
  with gr.Tab("Chat"):
519
  with gr.Row():
520
  with gr.Column():
521
  gr.HTML(
522
  """<b><center>GPT + WolframAlpha + Whisper</center></b>
523
- <p><center>New feature in <b>Translate to</b>: Choose <b>Language level</b> (e.g. for conversation practice or explain like I'm five)</center></p>""")
524
 
525
  openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...)",
526
  show_label=False, lines=1, type='password')
@@ -587,7 +635,7 @@ with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
587
 
588
  talking_head_cb = gr.Checkbox(label="Show talking head", value=True)
589
  talking_head_cb.change(update_talking_head, inputs=[talking_head_cb, talking_head_state],
590
- outputs=[talking_head_state, video_html])
591
 
592
  monologue_cb = gr.Checkbox(label="Babel fish mode (translate/restate what you enter, no conversational agent)",
593
  value=False)
@@ -715,6 +763,20 @@ with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
715
  inputs=[num_words_slider, num_words_state],
716
  outputs=[num_words_state])
717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718
  gr.HTML("""
719
  <p>This application, developed by <a href='https://www.linkedin.com/in/javafxpert/'>James L. Weaver</a>,
720
  demonstrates a conversational agent implemented with OpenAI GPT-3.5 and LangChain.
@@ -745,21 +807,23 @@ with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
745
  express_chain_state, num_words_state, formality_state,
746
  anticipation_level_state, joy_level_state, trust_level_state, fear_level_state,
747
  surprise_level_state, sadness_level_state, disgust_level_state, anger_level_state,
748
- lang_level_state, translate_to_state, literary_style_state],
 
749
  outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
750
- # outputs=[chatbot, history_state, audio_html, tmp_aud_file, message])
751
 
752
  submit.click(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state, trace_chain_state,
753
  speak_text_state, talking_head_state, monologue_state,
754
  express_chain_state, num_words_state, formality_state,
755
  anticipation_level_state, joy_level_state, trust_level_state, fear_level_state,
756
  surprise_level_state, sadness_level_state, disgust_level_state, anger_level_state,
757
- lang_level_state, translate_to_state, literary_style_state],
 
758
  outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
759
- # outputs=[chatbot, history_state, audio_html, tmp_aud_file, message])
760
 
761
  openai_api_key_textbox.change(set_openai_api_key,
762
  inputs=[openai_api_key_textbox],
763
- outputs=[chain_state, express_chain_state, llm_state])
764
 
765
  block.launch(debug=True)
 
32
  from polly_utils import PollyVoiceData, NEURAL_ENGINE
33
  from azure_utils import AzureVoiceData
34
 
35
+ # Pertains to question answering functionality
36
+ from langchain.embeddings.openai import OpenAIEmbeddings
37
+ from langchain.text_splitter import CharacterTextSplitter
38
+ from langchain.vectorstores.faiss import FAISS
39
+ from langchain.docstore.document import Document
40
+ from langchain.chains.question_answering import load_qa_chain
41
+
42
  news_api_key = os.environ["NEWS_API_KEY"]
43
  tmdb_bearer_token = os.environ["TMDB_BEARER_TOKEN"]
44
 
 
64
  TRANSLATE_TO_DEFAULT = "N/A"
65
  LITERARY_STYLE_DEFAULT = "N/A"
66
  PROMPT_TEMPLATE = PromptTemplate(
67
+ input_variables=["original_words", "num_words", "formality", "emotions", "lang_level", "translate_to",
68
+ "literary_style"],
69
  template="Restate {num_words}{formality}{emotions}{lang_level}{translate_to}{literary_style}the following: \n{original_words}\n",
70
  )
71
 
 
158
 
159
  translate_to_str = ""
160
  if translate_to != TRANSLATE_TO_DEFAULT:
161
+ translate_to_str = "translated to " + (
162
+ "" if lang_level == TRANSLATE_TO_DEFAULT else lang_level + " level ") + translate_to + ", "
163
 
164
  literary_style_str = ""
165
  if literary_style != LITERARY_STYLE_DEFAULT:
 
225
 
226
  chain = initialize_agent(tools, llm, agent="conversational-react-description", verbose=True, memory=memory)
227
  express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
 
228
  return chain, express_chain
229
 
230
 
 
235
  if api_key and api_key.startswith("sk-") and len(api_key) > 50:
236
  os.environ["OPENAI_API_KEY"] = api_key
237
  print("\n\n ++++++++++++++ Setting OpenAI API key ++++++++++++++ \n\n")
238
+ print(str(datetime.datetime.now()) + ": Before OpenAI, OPENAI_API_KEY length: " + str(
239
+ len(os.environ["OPENAI_API_KEY"])))
240
  llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
241
+ print(str(datetime.datetime.now()) + ": After OpenAI, OPENAI_API_KEY length: " + str(
242
+ len(os.environ["OPENAI_API_KEY"])))
243
  chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
244
+
245
+ # Pertains to question answering functionality
246
+ embeddings = OpenAIEmbeddings()
247
+ qa_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")
248
+
249
+ print(str(datetime.datetime.now()) + ": After load_chain, OPENAI_API_KEY length: " + str(
250
+ len(os.environ["OPENAI_API_KEY"])))
251
  os.environ["OPENAI_API_KEY"] = ""
252
+ return chain, express_chain, llm, embeddings, qa_chain
253
+ return None, None, None, None, None
254
 
255
 
256
  def run_chain(chain, inp, capture_hidden_text):
 
327
  trace_chain: bool, speak_text: bool, talking_head: bool, monologue: bool, express_chain: Optional[LLMChain],
328
  num_words, formality, anticipation_level, joy_level, trust_level,
329
  fear_level, surprise_level, sadness_level, disgust_level, anger_level,
330
+ lang_level, translate_to, literary_style, qa_chain, docsearch, use_embeddings
331
  ):
332
  """Execute the chat functionality."""
333
  self.lock.acquire()
 
348
  import openai
349
  openai.api_key = api_key
350
  if not monologue:
351
+ if use_embeddings:
352
+ output, hidden_text = "What's on your mind?", None
353
+ if inp and inp.strip() != "" and docsearch:
354
+ docs = docsearch.similarity_search(inp)
355
+ output = qa_chain.run(input_documents=docs, question=inp)
356
+ else:
357
+ output = "Please supply some text in the the Embeddings tab."
358
+ else:
359
+ output, hidden_text = run_chain(chain, inp, capture_hidden_text=trace_chain)
360
  else:
361
  output, hidden_text = inp, None
362
 
 
510
  return state
511
 
512
 
513
+ # Pertains to question answering functionality
514
def update_embeddings(embeddings_text, embeddings, qa_chain):
    """Build a FAISS similarity index from the text entered in the Embeddings tab.

    Args:
        embeddings_text: Raw text from the embeddings textbox; may be empty or None.
        embeddings: Embeddings object used to vectorize the text chunks. It is
            None until set_openai_api_key has accepted a key.
        qa_chain: Not used here; accepted because the click handler passes it
            as an input.

    Returns:
        The FAISS index built from the split text, or None when there is
        nothing to index or no embeddings object is available. The chat
        handler treats a falsy index as "no embeddings supplied".
    """
    # Guard clauses: previously `docsearch` was assigned only inside the `if`,
    # so an empty textbox raised UnboundLocalError at the return statement.
    if not embeddings_text or not embeddings_text.strip():
        print("Embeddings not updated: no text supplied")
        return None
    if embeddings is None:
        print("Embeddings not updated: OpenAI API key has not been set")
        return None

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_text(embeddings_text)

    docsearch = FAISS.from_texts(texts, embeddings)
    print("Embeddings updated")
    return docsearch
522
+
523
+
524
+ # Pertains to question answering functionality
525
def update_use_embeddings(widget, state):
    """Mirror the "Use embeddings" checkbox value into its state.

    Args:
        widget: Current checkbox value (True/False), or None if no value is
            supplied.
        state: Previously stored flag.

    Returns:
        The new flag: the checkbox value when one is supplied, otherwise the
        unchanged previous state.
    """
    # The original `if widget:` copied only truthy values, so unchecking the
    # box (False) never cleared the flag. Only a missing value keeps the old state.
    if widget is None:
        return state
    return widget
529
+
530
+
531
  with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
532
  llm_state = gr.State()
533
  history_state = gr.State()
 
557
  # Pertains to WHISPER functionality
558
  whisper_lang_state = gr.State(WHISPER_DETECT_LANG)
559
 
560
+ # Pertains to question answering functionality
561
+ embeddings_state = gr.State()
562
+ qa_chain_state = gr.State()
563
+ docsearch_state = gr.State()
564
+ use_embeddings_state = gr.State(False)
565
+
566
  with gr.Tab("Chat"):
567
  with gr.Row():
568
  with gr.Column():
569
  gr.HTML(
570
  """<b><center>GPT + WolframAlpha + Whisper</center></b>
571
+ <p><center>New feature: <b>Embeddings</b></center></p>""")
572
 
573
  openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...)",
574
  show_label=False, lines=1, type='password')
 
635
 
636
  talking_head_cb = gr.Checkbox(label="Show talking head", value=True)
637
  talking_head_cb.change(update_talking_head, inputs=[talking_head_cb, talking_head_state],
638
+ outputs=[talking_head_state, video_html])
639
 
640
  monologue_cb = gr.Checkbox(label="Babel fish mode (translate/restate what you enter, no conversational agent)",
641
  value=False)
 
763
  inputs=[num_words_slider, num_words_state],
764
  outputs=[num_words_state])
765
 
766
+ with gr.Tab("Embeddings"):
767
+ embeddings_text_box = gr.Textbox(label="Enter text for embeddings and hit Create:",
768
+ lines=20)
769
+
770
+ with gr.Row():
771
+ use_embeddings_cb = gr.Checkbox(label="Use embeddings", value=False)
772
+ use_embeddings_cb.change(update_use_embeddings, inputs=[use_embeddings_cb, use_embeddings_state],
773
+ outputs=[use_embeddings_state])
774
+
775
+ embeddings_text_submit = gr.Button(value="Create", variant="secondary").style(full_width=False)
776
+ embeddings_text_submit.click(update_embeddings,
777
+ inputs=[embeddings_text_box, embeddings_state, qa_chain_state],
778
+ outputs=[docsearch_state])
779
+
780
  gr.HTML("""
781
  <p>This application, developed by <a href='https://www.linkedin.com/in/javafxpert/'>James L. Weaver</a>,
782
  demonstrates a conversational agent implemented with OpenAI GPT-3.5 and LangChain.
 
807
  express_chain_state, num_words_state, formality_state,
808
  anticipation_level_state, joy_level_state, trust_level_state, fear_level_state,
809
  surprise_level_state, sadness_level_state, disgust_level_state, anger_level_state,
810
+ lang_level_state, translate_to_state, literary_style_state,
811
+ qa_chain_state, docsearch_state, use_embeddings_state],
812
  outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
813
+ # outputs=[chatbot, history_state, audio_html, tmp_aud_file, message])
814
 
815
  submit.click(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state, trace_chain_state,
816
  speak_text_state, talking_head_state, monologue_state,
817
  express_chain_state, num_words_state, formality_state,
818
  anticipation_level_state, joy_level_state, trust_level_state, fear_level_state,
819
  surprise_level_state, sadness_level_state, disgust_level_state, anger_level_state,
820
+ lang_level_state, translate_to_state, literary_style_state,
821
+ qa_chain_state, docsearch_state, use_embeddings_state],
822
  outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
823
+ # outputs=[chatbot, history_state, audio_html, tmp_aud_file, message])
824
 
825
  openai_api_key_textbox.change(set_openai_api_key,
826
  inputs=[openai_api_key_textbox],
827
+ outputs=[chain_state, express_chain_state, llm_state, embeddings_state, qa_chain_state])
828
 
829
  block.launch(debug=True)