Spaces:
Running
Running
added llm_qa_chain_with_memory.py
Browse files- .gitignore +1 -0
- Makefile +4 -1
- app.py +9 -4
- app_modules/init.py +10 -1
- app_modules/llm_qa_chain_with_memory.py +32 -0
- app_modules/utils.py +4 -0
- qa_chain_with_memory_test.py +104 -0
.gitignore
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
*.out
|
|
|
|
| 2 |
pdfs/
|
| 3 |
.vscode/
|
| 4 |
|
|
|
|
| 1 |
*.out
|
| 2 |
+
*.log
|
| 3 |
pdfs/
|
| 4 |
.vscode/
|
| 5 |
|
Makefile
CHANGED
|
@@ -5,8 +5,11 @@ start:
|
|
| 5 |
test:
|
| 6 |
python qa_chain_test.py
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
chat:
|
| 9 |
-
python
|
| 10 |
|
| 11 |
ingest:
|
| 12 |
python ingest.py
|
|
|
|
| 5 |
test:
|
| 6 |
python qa_chain_test.py
|
| 7 |
|
| 8 |
+
long-test:
|
| 9 |
+
python qa_chain_with_memory_test.py 100
|
| 10 |
+
|
| 11 |
chat:
|
| 12 |
+
python qa_chain_with_memory_test.py chat
|
| 13 |
|
| 14 |
ingest:
|
| 15 |
python ingest.py
|
app.py
CHANGED
|
@@ -8,6 +8,8 @@ from timeit import default_timer as timer
|
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
|
|
|
|
|
|
|
| 11 |
from app_modules.init import app_init
|
| 12 |
from app_modules.utils import print_llm_response
|
| 13 |
|
|
@@ -29,10 +31,13 @@ href = (
|
|
| 29 |
)
|
| 30 |
|
| 31 |
title = "Chat with PCI DSS v4"
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
description = f"""\
|
| 38 |
<div align="left">
|
|
|
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
+
os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"
|
| 12 |
+
|
| 13 |
from app_modules.init import app_init
|
| 14 |
from app_modules.utils import print_llm_response
|
| 15 |
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
title = "Chat with PCI DSS v4"
|
| 34 |
+
|
| 35 |
+
questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
|
| 36 |
+
|
| 37 |
+
# Open the file for reading
|
| 38 |
+
with open(questions_file_path, "r") as file:
|
| 39 |
+
examples = file.readlines()
|
| 40 |
+
examples = [example.strip() for example in examples]
|
| 41 |
|
| 42 |
description = f"""\
|
| 43 |
<div align="left">
|
app_modules/init.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
"""Main entrypoint for the app."""
|
|
|
|
| 2 |
import os
|
| 3 |
from timeit import default_timer as timer
|
| 4 |
from typing import List, Optional
|
|
@@ -9,7 +10,6 @@ from langchain.vectorstores.chroma import Chroma
|
|
| 9 |
from langchain.vectorstores.faiss import FAISS
|
| 10 |
|
| 11 |
from app_modules.llm_loader import LLMLoader
|
| 12 |
-
from app_modules.llm_qa_chain import QAChain
|
| 13 |
from app_modules.utils import get_device_types, init_settings
|
| 14 |
|
| 15 |
found_dotenv = find_dotenv(".env")
|
|
@@ -27,6 +27,15 @@ if os.environ.get("LANGCHAIN_DEBUG") == "true":
|
|
| 27 |
|
| 28 |
langchain.debug = True
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def app_init():
|
| 32 |
# https://github.com/huggingface/transformers/issues/17611
|
|
|
|
| 1 |
"""Main entrypoint for the app."""
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
from timeit import default_timer as timer
|
| 5 |
from typing import List, Optional
|
|
|
|
| 10 |
from langchain.vectorstores.faiss import FAISS
|
| 11 |
|
| 12 |
from app_modules.llm_loader import LLMLoader
|
|
|
|
| 13 |
from app_modules.utils import get_device_types, init_settings
|
| 14 |
|
| 15 |
found_dotenv = find_dotenv(".env")
|
|
|
|
| 27 |
|
| 28 |
langchain.debug = True
|
| 29 |
|
| 30 |
+
if os.environ.get("USER_CONVERSATION_SUMMARY_BUFFER_MEMORY") == "true":
|
| 31 |
+
from app_modules.llm_qa_chain_with_memory import QAChain
|
| 32 |
+
|
| 33 |
+
print("using llm_qa_chain_with_memory")
|
| 34 |
+
else:
|
| 35 |
+
from app_modules.llm_qa_chain import QAChain
|
| 36 |
+
|
| 37 |
+
print("using llm_qa_chain")
|
| 38 |
+
|
| 39 |
|
| 40 |
def app_init():
|
| 41 |
# https://github.com/huggingface/transformers/issues/17611
|
app_modules/llm_qa_chain_with_memory.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 2 |
+
from langchain.chains.base import Chain
|
| 3 |
+
from langchain.memory import ConversationSummaryBufferMemory
|
| 4 |
+
|
| 5 |
+
from app_modules.llm_inference import LLMInference
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class QAChain(LLMInference):
|
| 9 |
+
def __init__(self, vectorstore, llm_loader):
|
| 10 |
+
super().__init__(llm_loader)
|
| 11 |
+
self.vectorstore = vectorstore
|
| 12 |
+
|
| 13 |
+
def create_chain(self) -> Chain:
|
| 14 |
+
memory = ConversationSummaryBufferMemory(
|
| 15 |
+
llm=self.llm_loader.llm,
|
| 16 |
+
output_key="answer",
|
| 17 |
+
memory_key="chat_history",
|
| 18 |
+
max_token_limit=1024,
|
| 19 |
+
return_messages=True,
|
| 20 |
+
)
|
| 21 |
+
qa = ConversationalRetrievalChain.from_llm(
|
| 22 |
+
self.llm_loader.llm,
|
| 23 |
+
memory=memory,
|
| 24 |
+
chain_type="stuff",
|
| 25 |
+
retriever=self.vectorstore.as_retriever(
|
| 26 |
+
search_kwargs=self.llm_loader.search_kwargs
|
| 27 |
+
),
|
| 28 |
+
get_chat_history=lambda h: h,
|
| 29 |
+
return_source_documents=True,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return qa
|
app_modules/utils.py
CHANGED
|
@@ -85,6 +85,10 @@ def print_llm_response(llm_response):
|
|
| 85 |
source["page_content"] if "page_content" in source else source.page_content
|
| 86 |
)
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
def get_device_types():
|
| 90 |
print("Running on: ", platform.platform())
|
|
|
|
| 85 |
source["page_content"] if "page_content" in source else source.page_content
|
| 86 |
)
|
| 87 |
|
| 88 |
+
if "chat_history" in llm_response:
|
| 89 |
+
print("\nChat History:")
|
| 90 |
+
print(llm_response["chat_history"])
|
| 91 |
+
|
| 92 |
|
| 93 |
def get_device_types():
|
| 94 |
print("Running on: ", platform.platform())
|
qa_chain_with_memory_test.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from timeit import default_timer as timer
|
| 4 |
+
|
| 5 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
| 6 |
+
from langchain.schema import LLMResult
|
| 7 |
+
|
| 8 |
+
os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"
|
| 9 |
+
|
| 10 |
+
from app_modules.init import app_init
|
| 11 |
+
from app_modules.utils import print_llm_response
|
| 12 |
+
|
| 13 |
+
llm_loader, qa_chain = app_init()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MyCustomHandler(BaseCallbackHandler):
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.reset()
|
| 19 |
+
|
| 20 |
+
def reset(self):
|
| 21 |
+
self.texts = []
|
| 22 |
+
|
| 23 |
+
def get_standalone_question(self) -> str:
|
| 24 |
+
return self.texts[0].strip() if len(self.texts) > 0 else None
|
| 25 |
+
|
| 26 |
+
def on_llm_end(self, response: LLMResult, **kwargs) -> None:
|
| 27 |
+
"""Run when chain ends running."""
|
| 28 |
+
print("\n<on_llm_end>")
|
| 29 |
+
# print(response)
|
| 30 |
+
self.texts.append(response.generations[0][0].text)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
num_of_test_runs = 1
|
| 34 |
+
chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
|
| 35 |
+
if len(sys.argv) > 1 and not chatting:
|
| 36 |
+
num_of_test_runs = int(sys.argv[1])
|
| 37 |
+
|
| 38 |
+
questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
|
| 39 |
+
chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") or "true"
|
| 40 |
+
|
| 41 |
+
custom_handler = MyCustomHandler()
|
| 42 |
+
|
| 43 |
+
# Chatbot loop
|
| 44 |
+
chat_history = []
|
| 45 |
+
|
| 46 |
+
# Open the file for reading
|
| 47 |
+
file = open(questions_file_path, "r")
|
| 48 |
+
|
| 49 |
+
# Read the contents of the file into a list of strings
|
| 50 |
+
questions = file.readlines()
|
| 51 |
+
for i in range(len(questions)):
|
| 52 |
+
questions[i] = questions[i].strip()
|
| 53 |
+
|
| 54 |
+
if num_of_test_runs > 1:
|
| 55 |
+
new_questions = []
|
| 56 |
+
|
| 57 |
+
for i in range(num_of_test_runs):
|
| 58 |
+
new_questions += questions
|
| 59 |
+
|
| 60 |
+
questions = new_questions
|
| 61 |
+
|
| 62 |
+
# Close the file
|
| 63 |
+
file.close()
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
questions.append("exit")
|
| 67 |
+
|
| 68 |
+
chat_start = timer()
|
| 69 |
+
|
| 70 |
+
while True:
|
| 71 |
+
if chatting:
|
| 72 |
+
query = input("Please enter your question: ")
|
| 73 |
+
else:
|
| 74 |
+
query = questions.pop(0)
|
| 75 |
+
|
| 76 |
+
query = query.strip()
|
| 77 |
+
if query.lower() == "exit":
|
| 78 |
+
break
|
| 79 |
+
|
| 80 |
+
print("\nQuestion: " + query)
|
| 81 |
+
custom_handler.reset()
|
| 82 |
+
|
| 83 |
+
start = timer()
|
| 84 |
+
result = qa_chain.call_chain(
|
| 85 |
+
{"question": query, "chat_history": chat_history},
|
| 86 |
+
custom_handler,
|
| 87 |
+
None,
|
| 88 |
+
True,
|
| 89 |
+
)
|
| 90 |
+
end = timer()
|
| 91 |
+
print(f"Completed in {end - start:.3f}s")
|
| 92 |
+
|
| 93 |
+
if chat_history_enabled == "true":
|
| 94 |
+
chat_history.append((query, result["answer"]))
|
| 95 |
+
|
| 96 |
+
print_llm_response(result)
|
| 97 |
+
|
| 98 |
+
chat_end = timer()
|
| 99 |
+
total_time = chat_end - chat_start
|
| 100 |
+
print(f"Total time used: {total_time:.3f} s")
|
| 101 |
+
print(f"Number of tokens generated: {llm_loader.streamer.total_tokens}")
|
| 102 |
+
print(
|
| 103 |
+
f"Average generation speed: {llm_loader.streamer.total_tokens / total_time:.3f} tokens/s"
|
| 104 |
+
)
|