viboognesh committed on
Commit
d25a986
·
verified ·
1 Parent(s): 94dc18e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +117 -84
main.py CHANGED
@@ -1,17 +1,16 @@
1
  from fastapi import FastAPI, File, UploadFile, Depends
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from typing import List
4
- import os
5
- import aiofiles
6
- import uuid
7
- import shutil
8
 
 
9
  from dotenv import load_dotenv
10
 
11
- from langchain_community.document_loaders import TextLoader, Docx2txtLoader, PyPDFLoader
12
  from langchain.prompts import ChatPromptTemplate, PromptTemplate
13
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
14
- from langchain_community.document_loaders.csv_loader import CSVLoader
15
  from langchain.text_splitter import RecursiveCharacterTextSplitter
16
  from langchain.memory import ConversationBufferMemory
17
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
@@ -20,37 +19,97 @@ from langchain.chains import ConversationalRetrievalChain
20
 
21
  load_dotenv()
22
 
23
- app = FastAPI()
 
 
24
 
25
- origins = ["https://viboognesh-react-chat.static.hf.space"]
26
- # origins = ["http://localhost:3000"]
 
 
 
27
 
28
- app.add_middleware(
29
- CORSMiddleware,
30
- allow_origins=origins,
31
- allow_credentials=True,
32
- allow_methods=["GET", "POST"],
33
- allow_headers=["*"],
34
- )
35
 
 
 
36
 
37
- class ConversationChainManager:
38
- _instance = None
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def __new__(cls, *args, **kwargs):
41
- if not cls._instance:
42
- cls._instance = super(ConversationChainManager, cls).__new__(
43
- cls, *args, **kwargs
 
 
 
 
 
 
 
 
 
 
 
 
44
  )
45
- return cls._instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- def __init__(self):
48
- self.conversation_chain = None
49
  self.llm_model = ChatOpenAI()
50
  self.embeddings = OpenAIEmbeddings()
 
51
 
52
- def create_conversational_chain(self, file_paths: List[str]):
53
- docs = self.get_docs(file_paths)
54
  memory = ConversationBufferMemory(
55
  memory_key="chat_history", return_messages=True
56
  )
@@ -59,7 +118,7 @@ class ConversationChainManager:
59
  self.embeddings,
60
  )
61
  retriever = vectordb.as_retriever()
62
- self.conversation_chain = ConversationalRetrievalChain.from_llm(
63
  llm=self.llm_model,
64
  retriever=retriever,
65
  condense_question_prompt=self.get_question_generator_prompt(),
@@ -70,39 +129,10 @@ class ConversationChainManager:
70
  memory=memory,
71
  )
72
 
73
- @staticmethod
74
- def get_docs(file_paths: List[str]) -> List:
75
- docs = []
76
- for file_path in file_paths:
77
- if file_path.endswith(".txt"):
78
- loader = TextLoader(file_path)
79
- document = loader.load()
80
- splitter = RecursiveCharacterTextSplitter(
81
- chunk_size=1000, chunk_overlap=100
82
- )
83
- txt_documents = splitter.split_documents(document)
84
- docs.extend(txt_documents)
85
- elif file_path.endswith(".csv"):
86
- loader = CSVLoader(file_path)
87
- csv_documents = loader.load()
88
- docs.extend(csv_documents)
89
- elif file_path.endswith(".docx"):
90
- loader = Docx2txtLoader(file_path)
91
- document = loader.load()
92
- splitter = RecursiveCharacterTextSplitter(
93
- chunk_size=1000, chunk_overlap=100
94
- )
95
- docx_documents = splitter.split_documents(document)
96
- docs.extend(docx_documents)
97
- elif file_path.endswith(".pdf"):
98
- loader = PyPDFLoader(file_path)
99
- pdf_documents = loader.load_and_split()
100
- docs.extend(pdf_documents)
101
- os.remove(file_path)
102
- return docs
103
 
104
  @staticmethod
105
- def get_document_prompt() -> PromptTemplate:
106
  document_template = """Document Content:{page_content}
107
  Document Path: {source}"""
108
  return PromptTemplate(
@@ -111,7 +141,7 @@ class ConversationChainManager:
111
  )
112
 
113
  @staticmethod
114
- def get_question_generator_prompt() -> PromptTemplate:
115
  question_generator_template = """Combine the chat history and follow up question into
116
  a standalone question.\n Chat History: {chat_history}\n
117
  Follow up question: {question}
@@ -119,7 +149,7 @@ class ConversationChainManager:
119
  return PromptTemplate.from_template(question_generator_template)
120
 
121
  @staticmethod
122
- def get_final_prompt() -> ChatPromptTemplate:
123
  final_prompt_template = """Answer question based on the context and chat_history.
124
  If you cannot find answers, ask more related questions from the user.
125
  Use only the basename of the file path as name of the documents.
@@ -145,27 +175,33 @@ class ConversationChainManager:
145
  return ChatPromptTemplate.from_messages(messages)
146
 
147
 
148
- app.state.conversational_chain_manager = ConversationChainManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
  @app.post("/upload_files/")
152
- async def upload_files(
153
- files: List[UploadFile] = File(...),
154
- conversation_chain_manager: ConversationChainManager = Depends(
155
- lambda: app.state.conversational_chain_manager
156
- ),
157
- ):
158
- session_folder = f"uploads"
159
- os.makedirs(session_folder, exist_ok=True)
160
- file_paths = []
161
  for file in files:
162
- file_path = f"{session_folder}/{file.filename}"
163
- async with aiofiles.open(file_path, "wb") as out_file:
164
- content = await file.read()
165
- await out_file.write(content)
166
- file_paths.append(file_path)
167
 
168
- conversation_chain_manager.create_conversational_chain(file_paths)
169
  print("conversational_chain_manager created")
170
  return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
171
 
@@ -173,16 +209,13 @@ async def upload_files(
173
  @app.get("/predict/")
174
  async def predict(
175
  query: str,
176
- conversation_chain_manager: ConversationChainManager = Depends(
177
- lambda: app.state.conversational_chain_manager
178
- ),
179
  ):
180
- if conversation_chain_manager.conversation_chain is None:
181
  system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
182
- response = conversation_chain_manager.llm_model.invoke(system_prompt + query)
183
  answer = response.content
184
  else:
185
- response = conversation_chain_manager.conversation_chain.invoke(query)
186
  answer = response["answer"]
187
 
188
  print("predict called")
 
1
  from fastapi import FastAPI, File, UploadFile, Depends
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import List, Dict, Any
4
+ from io import BytesIO, StringIO
5
+ from docx import Document
6
+ from langchain.docstore.document import Document as langchain_Document
7
+ from PyPDF2 import PdfReader
8
 
9
+ import csv
10
  from dotenv import load_dotenv
11
 
 
12
  from langchain.prompts import ChatPromptTemplate, PromptTemplate
13
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
 
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
  from langchain.memory import ConversationBufferMemory
16
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
 
19
 
20
  load_dotenv()
21
 
22
+ class Document_Processor:
23
+ def __init__(self , file_details: List[Dict[Any, str]]):
24
+ self.file_details = file_details
25
 
26
+ def get_docs(self) -> List[langchain_Document]:
27
+ docs = []
28
+ for file_detail in self.file_details:
29
+ if file_detail["name"].endswith(".txt"):
30
+ docs.extend(self.get_txt_docs(file_detail))
31
 
32
+ elif file_detail["name"].endswith(".csv"):
33
+ docs.extend(self.get_csv_docs(file_detail))
34
+
35
+ elif file_detail["name"].endswith(".docx"):
36
+ docs.extend(self.get_docx_docs(file_detail))
 
 
37
 
38
+ elif file_detail["name"].endswith(".pdf"):
39
+ docs.extend(self.get_pdf_docs(file_detail))
40
 
41
+ return docs
42
+
43
+ @staticmethod
44
+ def get_txt_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
45
+ text = file_detail["content"].decode("utf-8")
46
+ source = file_detail["name"]
47
+ text_splitter = RecursiveCharacterTextSplitter(
48
+ chunk_size=1000, chunk_overlap=100
49
+ )
50
+ text_docs = text_splitter.create_documents(
51
+ [text], metadatas=[{"source": source}]
52
+ )
53
+ return text_docs
54
 
55
+ @staticmethod
56
+ def get_csv_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
57
+ csv_data = file_detail["content"]
58
+ source = file_detail["name"]
59
+ csv_string = csv_data.decode("utf-8")
60
+ # Use StringIO to create a file-like object from the string
61
+ csv_file = StringIO(csv_string)
62
+ csv_reader = csv.DictReader(csv_file)
63
+ csv_docs = []
64
+ for row in csv_reader:
65
+ # Convert each row into a dictionary of key/value pairs
66
+ page_content = ""
67
+ for key, value in row.items():
68
+ page_content += f"{key}: {value}\n"
69
+ doc = langchain_Document(
70
+ page_content=page_content, metadata={"source": source}
71
  )
72
+ csv_docs.append(doc)
73
+ return csv_docs
74
+
75
+ @staticmethod
76
+ def get_pdf_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
77
+ pdf_content = BytesIO(file_detail["content"])
78
+ source = file_detail["name"]
79
+
80
+ reader = PdfReader(pdf_content)
81
+ pdf_text = ""
82
+ for page in reader.pages:
83
+ pdf_text += page.extract_text() + "\n"
84
+
85
+ pdf_docs = RecursiveCharacterTextSplitter.create_documents(
86
+ texts=[pdf_text], metadatas=[{"source": source}]
87
+ )
88
+ return pdf_docs
89
+
90
+ @staticmethod
91
+ def get_docx_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
92
+ docx_content = BytesIO(file_detail["content"])
93
+ source = file_detail["name"]
94
+
95
+ document = Document(docx_content)
96
+ docx_text = " ".join([paragraph.text for paragraph in document.paragraphs])
97
+
98
+ docx_docs = RecursiveCharacterTextSplitter.create_documents(
99
+ [docx_text], metadatas=[{"source": source}]
100
+ )
101
+ return docx_docs
102
+
103
+
104
+ class Conversational_Chain:
105
 
106
+ def __init__(self, file_details: List[Dict[Any, str]]):
 
107
  self.llm_model = ChatOpenAI()
108
  self.embeddings = OpenAIEmbeddings()
109
+ self.file_details = file_details
110
 
111
+ def create_conversational_chain(self):
112
+ docs = Document_Processor(self.file_details).get_docs()
113
  memory = ConversationBufferMemory(
114
  memory_key="chat_history", return_messages=True
115
  )
 
118
  self.embeddings,
119
  )
120
  retriever = vectordb.as_retriever()
121
+ conversation_chain = ConversationalRetrievalChain.from_llm(
122
  llm=self.llm_model,
123
  retriever=retriever,
124
  condense_question_prompt=self.get_question_generator_prompt(),
 
129
  memory=memory,
130
  )
131
 
132
+ return conversation_chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  @staticmethod
135
+ def get_document_prompt(self) -> PromptTemplate:
136
  document_template = """Document Content:{page_content}
137
  Document Path: {source}"""
138
  return PromptTemplate(
 
141
  )
142
 
143
  @staticmethod
144
+ def get_question_generator_prompt(self) -> PromptTemplate:
145
  question_generator_template = """Combine the chat history and follow up question into
146
  a standalone question.\n Chat History: {chat_history}\n
147
  Follow up question: {question}
 
149
  return PromptTemplate.from_template(question_generator_template)
150
 
151
  @staticmethod
152
+ def get_final_prompt(self) -> ChatPromptTemplate:
153
  final_prompt_template = """Answer question based on the context and chat_history.
154
  If you cannot find answers, ask more related questions from the user.
155
  Use only the basename of the file path as name of the documents.
 
175
  return ChatPromptTemplate.from_messages(messages)
176
 
177
 
178
+ app = FastAPI()
179
+
180
+ origins = ["https://viboognesh-react-chat.static.hf.space"]
181
+ # origins = ["http://localhost:3000"]
182
+
183
+ app.add_middleware(
184
+ CORSMiddleware,
185
+ allow_origins=origins,
186
+ allow_credentials=True,
187
+ allow_methods=["GET", "POST"],
188
+ allow_headers=["*"],
189
+ )
190
+
191
+
192
+ app.state.conversation_chain = None
193
 
194
 
195
  @app.post("/upload_files/")
196
+ async def upload_files(files: List[UploadFile] = File(...)):
197
+ file_details = []
 
 
 
 
 
 
 
198
  for file in files:
199
+ content = await file.read()
200
+ name = f"{file.filename}"
201
+ details = {"content": content, "name": name}
202
+ file_details.append(details)
 
203
 
204
+ app.state.conversational_chain = Conversational_Chain(file_details).create_conversational_chain()
205
  print("conversational_chain_manager created")
206
  return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
207
 
 
209
  @app.get("/predict/")
210
  async def predict(
211
  query: str,
 
 
 
212
  ):
213
+ if app.state.conversation_chain is None:
214
  system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
215
+ response = app.state.llm_model.invoke(system_prompt + query)
216
  answer = response.content
217
  else:
218
+ response = app.state.conversation_chain.invoke(query)
219
  answer = response["answer"]
220
 
221
  print("predict called")