viboognesh committed on
Commit
94dc18e
·
verified ·
1 Parent(s): 8ea5ecb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +26 -11
main.py CHANGED
@@ -6,7 +6,7 @@ import aiofiles
6
  import uuid
7
  import shutil
8
 
9
- # from dotenv import load_dotenv
10
 
11
  from langchain_community.document_loaders import TextLoader, Docx2txtLoader, PyPDFLoader
12
  from langchain.prompts import ChatPromptTemplate, PromptTemplate
@@ -18,11 +18,12 @@ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
18
  from langchain_community.vectorstores import Chroma
19
  from langchain.chains import ConversationalRetrievalChain
20
 
21
- # load_dotenv()
22
 
23
  app = FastAPI()
24
 
25
  origins = ["https://viboognesh-react-chat.static.hf.space"]
 
26
 
27
  app.add_middleware(
28
  CORSMiddleware,
@@ -34,12 +35,21 @@ app.add_middleware(
34
 
35
 
36
  class ConversationChainManager:
 
 
 
 
 
 
 
 
 
37
  def __init__(self):
38
  self.conversation_chain = None
39
  self.llm_model = ChatOpenAI()
40
  self.embeddings = OpenAIEmbeddings()
41
 
42
- def create_conversational_chain(self, file_paths: List[str], session_id: str):
43
  docs = self.get_docs(file_paths)
44
  memory = ConversationBufferMemory(
45
  memory_key="chat_history", return_messages=True
@@ -47,8 +57,6 @@ class ConversationChainManager:
47
  vectordb = Chroma.from_documents(
48
  docs,
49
  self.embeddings,
50
- collection_name=session_id,
51
- persist_directory="./chroma_db",
52
  )
53
  retriever = vectordb.as_retriever()
54
  self.conversation_chain = ConversationalRetrievalChain.from_llm(
@@ -90,6 +98,7 @@ class ConversationChainManager:
90
  loader = PyPDFLoader(file_path)
91
  pdf_documents = loader.load_and_split()
92
  docs.extend(pdf_documents)
 
93
  return docs
94
 
95
  @staticmethod
@@ -136,13 +145,17 @@ class ConversationChainManager:
136
  return ChatPromptTemplate.from_messages(messages)
137
 
138
 
 
 
 
139
  @app.post("/upload_files/")
140
  async def upload_files(
141
  files: List[UploadFile] = File(...),
142
- conversation_chain_manager: ConversationChainManager = Depends(),
 
 
143
  ):
144
- session_id = str(uuid.uuid4())
145
- session_folder = f"uploads/{session_id}"
146
  os.makedirs(session_folder, exist_ok=True)
147
  file_paths = []
148
  for file in files:
@@ -152,15 +165,17 @@ async def upload_files(
152
  await out_file.write(content)
153
  file_paths.append(file_path)
154
 
155
- conversation_chain_manager.create_conversational_chain(file_paths, session_id)
156
- # shutil.rmtree(session_folder)
157
  print("conversational_chain_manager created")
158
  return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
159
 
160
 
161
  @app.get("/predict/")
162
  async def predict(
163
- query: str, conversation_chain_manager: ConversationChainManager = Depends()
 
 
 
164
  ):
165
  if conversation_chain_manager.conversation_chain is None:
166
  system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
 
6
  import uuid
7
  import shutil
8
 
9
+ from dotenv import load_dotenv
10
 
11
  from langchain_community.document_loaders import TextLoader, Docx2txtLoader, PyPDFLoader
12
  from langchain.prompts import ChatPromptTemplate, PromptTemplate
 
18
  from langchain_community.vectorstores import Chroma
19
  from langchain.chains import ConversationalRetrievalChain
20
 
21
+ load_dotenv()
22
 
23
  app = FastAPI()
24
 
25
  origins = ["https://viboognesh-react-chat.static.hf.space"]
26
+ # origins = ["http://localhost:3000"]
27
 
28
  app.add_middleware(
29
  CORSMiddleware,
 
35
 
36
 
37
  class ConversationChainManager:
38
+ _instance = None
39
+
40
+ def __new__(cls, *args, **kwargs):
41
+ if not cls._instance:
42
+ cls._instance = super(ConversationChainManager, cls).__new__(
43
+ cls, *args, **kwargs
44
+ )
45
+ return cls._instance
46
+
47
  def __init__(self):
48
  self.conversation_chain = None
49
  self.llm_model = ChatOpenAI()
50
  self.embeddings = OpenAIEmbeddings()
51
 
52
+ def create_conversational_chain(self, file_paths: List[str]):
53
  docs = self.get_docs(file_paths)
54
  memory = ConversationBufferMemory(
55
  memory_key="chat_history", return_messages=True
 
57
  vectordb = Chroma.from_documents(
58
  docs,
59
  self.embeddings,
 
 
60
  )
61
  retriever = vectordb.as_retriever()
62
  self.conversation_chain = ConversationalRetrievalChain.from_llm(
 
98
  loader = PyPDFLoader(file_path)
99
  pdf_documents = loader.load_and_split()
100
  docs.extend(pdf_documents)
101
+ os.remove(file_path)
102
  return docs
103
 
104
  @staticmethod
 
145
  return ChatPromptTemplate.from_messages(messages)
146
 
147
 
148
+ app.state.conversational_chain_manager = ConversationChainManager()
149
+
150
+
151
  @app.post("/upload_files/")
152
  async def upload_files(
153
  files: List[UploadFile] = File(...),
154
+ conversation_chain_manager: ConversationChainManager = Depends(
155
+ lambda: app.state.conversational_chain_manager
156
+ ),
157
  ):
158
+ session_folder = f"uploads"
 
159
  os.makedirs(session_folder, exist_ok=True)
160
  file_paths = []
161
  for file in files:
 
165
  await out_file.write(content)
166
  file_paths.append(file_path)
167
 
168
+ conversation_chain_manager.create_conversational_chain(file_paths)
 
169
  print("conversational_chain_manager created")
170
  return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
171
 
172
 
173
  @app.get("/predict/")
174
  async def predict(
175
+ query: str,
176
+ conversation_chain_manager: ConversationChainManager = Depends(
177
+ lambda: app.state.conversational_chain_manager
178
+ ),
179
  ):
180
  if conversation_chain_manager.conversation_chain is None:
181
  system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"