abhivsh commited on
Commit
5052352
·
verified ·
1 Parent(s): dc50cde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -13,6 +13,9 @@ import initialize
13
  from langchain_openai import ChatOpenAI
14
  from langchain.memory import ConversationBufferMemory
15
  from langchain.chains import ConversationalRetrievalChain
 
 
 
16
 
17
  import gradio as gr
18
  import os
@@ -49,16 +52,21 @@ def chat_query(question):
49
 
50
  llm = ChatOpenAI(model=llm_name, temperature=0.1, api_key = OPENAI_API_KEY)
51
 
52
- # Memory
53
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
54
-
55
- # Conversation Retrival Chain
56
  retriever=vectordb.as_retriever()
57
- qa = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory)
58
 
59
  # Replace input() with question variable for Gradio
60
- result = qa({"question": question})
61
- return result['answer']
 
 
 
 
 
 
 
62
 
63
 
64
  # logo_path = os.path.join(os.getcwd(), "Logo.png")
 
13
  from langchain_openai import ChatOpenAI
14
  from langchain.memory import ConversationBufferMemory
15
  from langchain.chains import ConversationalRetrievalChain
16
+ from langchain.chains import VectorDBQA
17
+ from langchain.llms import OpenAI
18
+
19
 
20
  import gradio as gr
21
  import os
 
52
 
53
  llm = ChatOpenAI(model=llm_name, temperature=0.1, api_key = OPENAI_API_KEY)
54
 
55
+ # Conversation Retrival Chain with Memory
56
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
 
57
  retriever=vectordb.as_retriever()
58
+ #qa = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory)
59
 
60
  # Replace input() with question variable for Gradio
61
+ # result = qa({"question": question})
62
+ # return result['answer']
63
+
64
+ # Chatbot only answers based on Documents
65
+ qa = VectorDBQA.from_chain_type(llm=OpenAI(openai_api_key = OPENAI_API_KEY, ), chain_type="stuff", vectorstore=vectordb)
66
+ result = qa.query(question)
67
+ return result
68
+
69
+
70
 
71
 
72
  # logo_path = os.path.join(os.getcwd(), "Logo.png")