achapman committed on
Commit
55e7097
·
1 Parent(s): 8c5710f

Remove pdf parsing for now

Browse files
aimakerspace/openai_utils/embedding.py CHANGED
@@ -21,6 +21,8 @@ class EmbeddingModel:
21
  self.embeddings_model_name = embeddings_model_name
22
 
23
  async def async_get_embeddings(self, list_of_text: List[str]) -> List[List[float]]:
 
 
24
  embedding_response = await self.async_client.embeddings.create(
25
  input=list_of_text, model=self.embeddings_model_name
26
  )
 
21
  self.embeddings_model_name = embeddings_model_name
22
 
23
  async def async_get_embeddings(self, list_of_text: List[str]) -> List[List[float]]:
24
+ if not list_of_text:
25
+ raise(ValueError("Cannot embed nonexistent text."))
26
  embedding_response = await self.async_client.embeddings.create(
27
  input=list_of_text, model=self.embeddings_model_name
28
  )
aimakerspace/qa_pipeline.py CHANGED
@@ -1,11 +1,30 @@
 
 
1
  from rank_bm25 import BM25Plus
2
- from langchain.vectorstores import Qdrant
 
 
3
 
4
  from .openai_utils.prompts import (
 
5
  SystemRolePrompt
6
  )
7
- from .vectordatabase import VectorDatabase
8
- from .openai_utils.chatmodel import ChatOpenAI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Utility function for reranking
11
  def bm25plus_rerank(corpus, query, initial_ranking, top_n=3):
@@ -18,12 +37,6 @@ def bm25plus_rerank(corpus, query, initial_ranking, top_n=3):
18
  ranked_indices = [initial_ranking[i] for i in bm25_scores.argsort()[::-1]]
19
  return ranked_indices[:top_n]
20
 
21
- def search_by_text(qdrant: Qdrant, query_text: str, k: int, return_as_text: bool = False) -> List[Tuple[str, float]]:
22
- results = qdrant.similarity_search_with_score(query_text, k)
23
- if return_as_text:
24
- return [result[0].page_content for result in results]
25
- return [(result[0].page_content, result[1]) for result in results]
26
-
27
 
28
  class RetrievalAugmentedQAPipeline:
29
  def __init__(self, llm: ChatOpenAI(), vector_db_retriever) -> None:
@@ -31,10 +44,7 @@ class RetrievalAugmentedQAPipeline:
31
  self.vector_db_retriever = vector_db_retriever
32
 
33
  async def arun_pipeline(self, user_query: str):
34
- if type(self.vector_db_retriever == "Qdrant"):
35
- context_list = search_by_text(self.vector_db_retriever,user_query, k=4)
36
- else:
37
- context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
38
 
39
  context_prompt = ""
40
  for context in context_list:
@@ -56,10 +66,8 @@ class RerankedQAPipeline(RetrievalAugmentedQAPipeline):
56
  async def arun_pipeline(self, user_query: str, rerank: bool=False) -> str:
57
  # Retrieve the top 10 results. Either return the top 3, or rerank with BM25 and then return
58
  # the new top 3
59
- if type(self.vector_db_retriever == "Qdrant"):
60
- context_list = search_by_text(self.vector_db_retriever,user_query, k=10)
61
- else:
62
- context_list = self.vector_db_retriever.search_by_text(user_query, k=10)
63
  # Convert from tuples to strings
64
  context_list_str = [context_list[i][0] for i in range(len(context_list))]
65
 
@@ -72,10 +80,13 @@ class RerankedQAPipeline(RetrievalAugmentedQAPipeline):
72
  reranked_indices = bm25plus_rerank(context_list_str, user_query, initial_ranking, top_n=n)
73
  reranked_contexts = [context_list_str[i] for i in reranked_indices]
74
 
75
- context_prompt = "\n\n".join(context for context in reranked_contexts) + "\n\n"
 
 
 
 
76
 
77
- formatted_system_prompt = SystemRolePrompt(system_prompt).create_message() if system_prompt else rag_prompt.create_message()
78
- formatted_user_prompt = user_prompt.create_message(user_query=user_query, context=context_prompt)
79
 
80
  async def generate_response():
81
  async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
 
1
+ from typing import List, Tuple
2
+
3
  from rank_bm25 import BM25Plus
4
+
5
+ from .vectordatabase import VectorDatabase
6
+ from .openai_utils.chatmodel import ChatOpenAI
7
 
8
  from .openai_utils.prompts import (
9
+ UserRolePrompt,
10
  SystemRolePrompt
11
  )
12
+
13
+ system_template = """\
14
+ Use the provided context to answer the user's question. Answer in one paragraph and provide lots
15
+ of details based on the context. If you are certain the context is not relevant, apologize and say you don't
16
+ have enough information to answer."""
17
+ system_role_prompt = SystemRolePrompt(system_template)
18
+
19
+ user_prompt_template = """\
20
+ Context:
21
+ {context}
22
+
23
+ The question to answer is:
24
+ {question}
25
+ """
26
+ user_role_prompt = UserRolePrompt(user_prompt_template)
27
+
28
 
29
  # Utility function for reranking
30
  def bm25plus_rerank(corpus, query, initial_ranking, top_n=3):
 
37
  ranked_indices = [initial_ranking[i] for i in bm25_scores.argsort()[::-1]]
38
  return ranked_indices[:top_n]
39
 
 
 
 
 
 
 
40
 
41
  class RetrievalAugmentedQAPipeline:
42
  def __init__(self, llm: ChatOpenAI(), vector_db_retriever) -> None:
 
44
  self.vector_db_retriever = vector_db_retriever
45
 
46
  async def arun_pipeline(self, user_query: str):
47
+ context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
 
 
 
48
 
49
  context_prompt = ""
50
  for context in context_list:
 
66
  async def arun_pipeline(self, user_query: str, rerank: bool=False) -> str:
67
  # Retrieve the top 10 results. Either return the top 3, or rerank with BM25 and then return
68
  # the new top 3
69
+ context_list = self.vector_db_retriever.search_by_text(user_query, k=10)
70
+
 
 
71
  # Convert from tuples to strings
72
  context_list_str = [context_list[i][0] for i in range(len(context_list))]
73
 
 
80
  reranked_indices = bm25plus_rerank(context_list_str, user_query, initial_ranking, top_n=n)
81
  reranked_contexts = [context_list_str[i] for i in reranked_indices]
82
 
83
+ context_prompt = ""
84
+ for context in context_list:
85
+ context_prompt += context[0] + "\n"
86
+
87
+ formatted_system_prompt = system_role_prompt.create_message()
88
 
89
+ formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
 
90
 
91
  async def generate_response():
92
  async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
app.py CHANGED
@@ -1,39 +1,20 @@
1
- import os
2
  from typing import List
3
  import tempfile
4
 
5
  import chainlit as cl
6
  from chainlit.types import AskFileResponse
7
- from PyPDF2 import PdfReader
8
 
9
- from langchain.vectorstores import Qdrant
10
- from langchain.embeddings import OpenAIEmbeddings
11
 
12
  from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader
13
- from aimakerspace.openai_utils.prompts import (
14
- UserRolePrompt,
15
- SystemRolePrompt
16
- )
17
  from aimakerspace.openai_utils.embedding import EmbeddingModel
18
  from aimakerspace.vectordatabase import VectorDatabase
19
  from aimakerspace.openai_utils.chatmodel import ChatOpenAI
20
  from aimakerspace.qa_pipeline import RerankedQAPipeline
21
 
22
- system_template = """\
23
- Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
24
- system_role_prompt = SystemRolePrompt(system_template)
25
-
26
- user_prompt_template = """\
27
- Context:
28
- {context}
29
-
30
- Question:
31
- {question}
32
- """
33
- user_role_prompt = UserRolePrompt(user_prompt_template)
34
-
35
  text_splitter = CharacterTextSplitter()
36
- EmbeddingModel = OpenAIEmbeddings()
37
 
38
  def process_text_file(file: AskFileResponse):
39
 
@@ -49,33 +30,18 @@ def process_text_file(file: AskFileResponse):
49
  return texts
50
 
51
  def process_pdf(file: AskFileResponse) -> list[str]:
52
-
53
- # Create a temporary file to store the PDF content
54
  with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pdf") as temp_file:
55
  temp_file_path = temp_file.name
56
  temp_file.write(file.content)
 
57
 
58
- # Read the PDF content
59
- with open(temp_file_path, "rb") as f:
60
- pdf_reader = PdfReader(f)
61
- text = ""
62
- for page in pdf_reader.pages:
63
- text += page.extract_text()
64
-
65
- # Assuming you have a text splitter similar to the one used for text files
66
- texts = text_splitter.split_texts([text])
67
- return texts
68
 
69
- # Function to build the vector database from a list of texts
70
- async def build_qdrant_vector_database(list_of_text: List[str]) -> Qdrant:
71
- embeddings = await embedding_model.async_get_embeddings(list_of_text)
72
- qdrant = Qdrant.from_texts(
73
- texts=list_of_text,
74
- embeddings=[embedding.tolist() for embedding in embeddings],
75
- embedding=embedding_model,
76
- collection_name="vectors"
77
- )
78
- return qdrant
79
 
80
  @cl.on_chat_start
81
  async def on_chat_start():
@@ -84,9 +50,9 @@ async def on_chat_start():
84
  # Wait for the user to upload a file
85
  while files == None:
86
  files = await cl.AskFileMessage(
87
- content="Please upload a Text or PDF File file to begin!",
88
- accept=["text/plain","application/pdf"],
89
- max_size_mb=5,
90
  timeout=180,
91
  ).send()
92
 
@@ -98,29 +64,30 @@ async def on_chat_start():
98
  await msg.send()
99
 
100
  # load the file
101
- if "pdf" not in file.name.lower():
102
- texts = process_text_file(file)
103
- else: texts = process_pdf(file)
104
-
105
- print(f"Processing {len(texts)} text chunks")
106
-
107
- # Create a dict vector store
108
- qdrant_db = await build_qdrant_vector_database(texts)
109
-
110
- chat_openai = ChatOpenAI()
111
-
112
- # Create a chain
113
- retrieval_augmented_qa_pipeline = RerankedQAPipeline(
114
- vector_db_retriever=qdrant_db,
115
- llm=chat_openai,
116
- True
117
- )
118
-
119
- # Let the user know that the system is ready
120
- msg.content = f"Processing `{file.name}` done. You can now ask questions!"
121
- await msg.update()
 
122
 
123
- cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
124
 
125
 
126
  @cl.on_message
@@ -128,7 +95,7 @@ async def main(message):
128
  chain = cl.user_session.get("chain")
129
 
130
  msg = cl.Message(content="")
131
- result = await chain.arun_pipeline(message.content)
132
 
133
  async for stream_resp in result["response"]:
134
  await msg.stream_token(stream_resp)
 
 
1
  from typing import List
2
  import tempfile
3
 
4
  import chainlit as cl
5
  from chainlit.types import AskFileResponse
6
+ import fitz
7
 
8
+ from langchain_community.embeddings import OpenAIEmbeddings
 
9
 
10
  from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader
 
 
 
 
11
  from aimakerspace.openai_utils.embedding import EmbeddingModel
12
  from aimakerspace.vectordatabase import VectorDatabase
13
  from aimakerspace.openai_utils.chatmodel import ChatOpenAI
14
  from aimakerspace.qa_pipeline import RerankedQAPipeline
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  text_splitter = CharacterTextSplitter()
17
+ embedding_model = OpenAIEmbeddings()
18
 
19
  def process_text_file(file: AskFileResponse):
20
 
 
30
  return texts
31
 
32
  def process_pdf(file: AskFileResponse) -> list[str]:
 
 
33
  with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pdf") as temp_file:
34
  temp_file_path = temp_file.name
35
  temp_file.write(file.content)
36
+ temp_file.flush()
37
 
38
+ text = ""
39
+ with fitz.open(temp_file_path) as doc:
40
+ for page in doc:
41
+ text += page.get_text().strip()
 
 
 
 
 
 
42
 
43
+ text_list = text_splitter.split_texts(text)
44
+ return text_list
 
 
 
 
 
 
 
 
45
 
46
  @cl.on_chat_start
47
  async def on_chat_start():
 
50
  # Wait for the user to upload a file
51
  while files == None:
52
  files = await cl.AskFileMessage(
53
+ content="Please upload a Text File file to begin!",
54
+ accept=["text/plain"],
55
+ max_size_mb=20,
56
  timeout=180,
57
  ).send()
58
 
 
64
  await msg.send()
65
 
66
  # load the file
67
+ texts = process_text_file(file)
68
+
69
+ if not texts:
70
+ await cl.Message(content=f"Error: Could not extract any text from input file").send()
71
+ else:
72
+ print(f"Processing {len(texts)} text chunks")
73
+
74
+ # Create a dict vector store
75
+ vector_db = VectorDatabase()
76
+ vector_db = await vector_db.abuild_from_list(texts)
77
+
78
+ chat_openai = ChatOpenAI()
79
+
80
+ # Create a chain
81
+ retrieval_augmented_qa_pipeline = RerankedQAPipeline(
82
+ vector_db_retriever=vector_db,
83
+ llm=chat_openai,
84
+ )
85
+
86
+ # Let the user know that the system is ready
87
+ msg.content = f"Processing `{file.name}` done. You can now ask questions!"
88
+ await msg.update()
89
 
90
+ cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
91
 
92
 
93
  @cl.on_message
 
95
  chain = cl.user_session.get("chain")
96
 
97
  msg = cl.Message(content="")
98
+ result = await chain.arun_pipeline(message.content,rerank=True)
99
 
100
  async for stream_resp in result["response"]:
101
  await msg.stream_token(stream_resp)
requirements.txt CHANGED
@@ -2,5 +2,8 @@ numpy
2
  chainlit==0.7.700
3
  openai
4
  rank_bm25
5
- PyPDF2
6
- langchain>=0.2
 
 
 
 
2
  chainlit==0.7.700
3
  openai
4
  rank_bm25
5
+ pymupdf
6
+ langchain>=0.2
7
+ langchain-community
8
+ tiktoken
9
+ langchain-openai