Nikhil0987 commited on
Commit
cf940aa
·
verified ·
1 Parent(s): bb34982

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -23
app.py CHANGED
@@ -29,50 +29,44 @@ def set_custom_prompt():
29
  input_variables=['context', 'question'])
30
  return prompt
31
 
32
- #Retrieval QA Chain
33
  def retrieval_qa_chain(llm, prompt, db):
34
  qa_chain = RetrievalQA.from_chain_type(llm=llm,
35
- chain_type='stuff',
36
- retriever=db.as_retriever(search_kwargs={'k': 2}),
37
- return_source_documents=True,
38
- chain_type_kwargs={'prompt': prompt}
39
- )
40
  return qa_chain
41
 
42
- #Loading the model
43
  def load_llm():
44
  # Load the locally downloaded model here
45
  llm = CTransformers(
46
- model = "TheBloke/Llama-2-7B-Chat-GGML",
47
  model_type="llama",
48
- max_new_tokens = 512,
49
- temperature = 0.5
50
  )
51
  return llm
52
 
53
- #QA Model Function
54
  def qa_bot():
55
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
56
  model_kwargs={'device': 'cpu'})
57
- db = FAISS.load_local(DB_FAISS_PATH, embeddings)
58
  llm = load_llm()
59
  qa_prompt = set_custom_prompt()
60
  qa = retrieval_qa_chain(llm, qa_prompt, db)
61
- db = faiss.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
62
- # 1. Verify file ownership
63
- if not os.path.exists(DB_FAISS_PATH) or os.stat(DB_FAISS_PATH).st_uid != os.getuid():
64
- raise RuntimeError("Vector store file may have been tampered with. Aborting.")
65
-
66
-
67
  return qa
68
 
69
- #output function
70
  def final_result(query):
71
  qa_result = qa_bot()
72
  response = qa_result({'query': query})
73
  return response
74
 
75
- #chainlit code
76
  @cl.on_chat_start
77
  async def start():
78
  chain = qa_bot()
@@ -85,7 +79,7 @@ async def start():
85
 
86
  @cl.on_message
87
  async def main(message: cl.Message):
88
- chain = cl.user_session.get("chain")
89
  cb = cl.AsyncLangchainCallbackHandler(
90
  stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
91
  )
@@ -101,7 +95,6 @@ async def main(message: cl.Message):
101
 
102
  await cl.Message(content=answer).send()
103
 
104
-
105
  def main():
106
  st.title("Medical Bot")
107
  st.text_input("Enter your query:", key="query")
@@ -113,4 +106,4 @@ def main():
113
  st.markdown(response)
114
 
115
  if __name__ == "__main__":
116
- main()
 
29
  input_variables=['context', 'question'])
30
  return prompt
31
 
32
+ # Retrieval QA Chain
33
  def retrieval_qa_chain(llm, prompt, db):
34
  qa_chain = RetrievalQA.from_chain_type(llm=llm,
35
+ chain_type='stuff',
36
+ retriever=db.as_retriever(search_kwargs={'k': 2}),
37
+ return_source_documents=True,
38
+ chain_type_kwargs={'prompt': prompt}
39
+ )
40
  return qa_chain
41
 
42
+ # Loading the model
43
  def load_llm():
44
  # Load the locally downloaded model here
45
  llm = CTransformers(
46
+ model="TheBloke/Llama-2-7B-Chat-GGML",
47
  model_type="llama",
48
+ max_new_tokens=512,
49
+ temperature=0.5
50
  )
51
  return llm
52
 
53
+ # QA Model Function
54
  def qa_bot():
55
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
56
  model_kwargs={'device': 'cpu'})
57
+ db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
58
  llm = load_llm()
59
  qa_prompt = set_custom_prompt()
60
  qa = retrieval_qa_chain(llm, qa_prompt, db)
 
 
 
 
 
 
61
  return qa
62
 
63
+ # Output function
64
  def final_result(query):
65
  qa_result = qa_bot()
66
  response = qa_result({'query': query})
67
  return response
68
 
69
+ # Chainlit code
70
  @cl.on_chat_start
71
  async def start():
72
  chain = qa_bot()
 
79
 
80
  @cl.on_message
81
  async def main(message: cl.Message):
82
+ chain = cl.user_session.get("chain")
83
  cb = cl.AsyncLangchainCallbackHandler(
84
  stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
85
  )
 
95
 
96
  await cl.Message(content=answer).send()
97
 
 
98
  def main():
99
  st.title("Medical Bot")
100
  st.text_input("Enter your query:", key="query")
 
106
  st.markdown(response)
107
 
108
  if __name__ == "__main__":
109
+ main()