"""FastAPI backend for a document question-answering chat.

Uploaded .txt/.csv/.docx/.pdf files are converted into LangChain documents,
indexed in a Chroma vector store, and queried through a per-user
ConversationalRetrievalChain that is looked up by a ``user_id`` cookie.
"""

from fastapi import FastAPI, File, UploadFile, HTTPException, Response, Request
from fastapi.middleware.cors import CORSMiddleware
import uuid
from datetime import datetime, timedelta
import asyncio

from typing import List, Dict, Any
from io import BytesIO, StringIO
from docx import Document
from langchain.docstore.document import Document as langchain_Document
from PyPDF2 import PdfReader
import csv

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain

from dotenv import load_dotenv

# Load variables from a local .env file into the process environment before
# any client object is constructed. Presumably this is where the OpenAI API
# key used by ChatOpenAI / OpenAIEmbeddings comes from -- TODO confirm.
load_dotenv()


class Document_Processor:
    """Convert uploaded files into LangChain documents ready for embedding.

    Each entry of ``file_details`` is a dict with two keys:

    * ``"name"``    -- the uploaded file name (used to pick a loader and stored
      as the ``source`` metadata of every produced document);
    * ``"content"`` -- the raw bytes of the file.

    Supported extensions are ``.txt``, ``.csv``, ``.docx`` and ``.pdf``
    (matched case-insensitively). Files with any other extension are skipped.
    """

    # Splitter settings shared by every free-text format (txt, pdf, docx).
    CHUNK_SIZE = 1000
    CHUNK_OVERLAP = 100

    def __init__(self, file_details: List[Dict[str, Any]]):
        self.file_details = file_details

    def get_docs(self) -> List[langchain_Document]:
        """Return the documents extracted from every supported file.

        Returns:
            A flat list of documents in upload order. The list is empty when
            no file has a supported extension.
        """
        loaders = (
            (".txt", self.get_txt_docs),
            (".csv", self.get_csv_docs),
            (".docx", self.get_docx_docs),
            (".pdf", self.get_pdf_docs),
        )
        docs = []
        for file_detail in self.file_details:
            # Compare on the lower-cased name so "REPORT.PDF" is not silently
            # dropped; the original name is still what ends up in metadata.
            lowered_name = file_detail["name"].lower()
            for suffix, loader in loaders:
                if lowered_name.endswith(suffix):
                    docs.extend(loader(file_detail=file_detail))
                    break
        return docs

    @staticmethod
    def _split_text(text: str, source: str) -> List[langchain_Document]:
        """Split ``text`` into overlapping chunks tagged with ``source``."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=Document_Processor.CHUNK_SIZE,
            chunk_overlap=Document_Processor.CHUNK_OVERLAP,
        )
        return splitter.create_documents([text], metadatas=[{"source": source}])

    @staticmethod
    def get_txt_docs(file_detail: Dict[str, Any]) -> List[langchain_Document]:
        """Decode a plain-text file and split it into chunks.

        Raises:
            UnicodeDecodeError: if the content is not valid UTF-8.
        """
        # "utf-8-sig" decodes plain UTF-8 identically but also strips a
        # leading byte-order mark if one is present.
        text = file_detail["content"].decode("utf-8-sig")
        return Document_Processor._split_text(text, file_detail["name"])

    @staticmethod
    def get_csv_docs(file_detail: Dict[str, Any]) -> List[langchain_Document]:
        """Turn every CSV data row into one document of ``key: value`` lines.

        The first row is treated as the header (``csv.DictReader`` default).

        Raises:
            UnicodeDecodeError: if the content is not valid UTF-8.
        """
        source = file_detail["name"]
        # Strip a BOM (otherwise it becomes part of the first column name) and
        # disable newline translation as the csv module requires, so quoted
        # fields containing line breaks are parsed intact.
        csv_file = StringIO(file_detail["content"].decode("utf-8-sig"), newline="")
        return [
            langchain_Document(
                page_content="".join(
                    f"{key}: {value}\n" for key, value in row.items()
                ),
                metadata={"source": source},
            )
            for row in csv.DictReader(csv_file)
        ]

    @staticmethod
    def get_pdf_docs(file_detail: Dict[str, Any]) -> List[langchain_Document]:
        """Extract the text of every PDF page and split it into chunks."""
        reader = PdfReader(BytesIO(file_detail["content"]))
        # Guard against a page yielding no text object at all (e.g. a scanned
        # image page); treat it as empty instead of failing the whole upload.
        pdf_text = "".join(
            (page.extract_text() or "") + "\n" for page in reader.pages
        )
        return Document_Processor._split_text(pdf_text, file_detail["name"])

    @staticmethod
    def get_docx_docs(file_detail: Dict[str, Any]) -> List[langchain_Document]:
        """Join the paragraphs of a .docx file with spaces and split the result.

        Only ``document.paragraphs`` is read, so text that python-docx does not
        expose there is not included.
        """
        document = Document(BytesIO(file_detail["content"]))
        docx_text = " ".join(paragraph.text for paragraph in document.paragraphs)
        return Document_Processor._split_text(docx_text, file_detail["name"])


class Conversational_Chain:
    """Build a ConversationalRetrievalChain over a set of uploaded files.

    ``file_details`` uses the same shape ``Document_Processor`` expects: a list
    of dicts with a ``"name"`` (str) and a ``"content"`` (bytes) entry.
    """

    def __init__(self, file_details: List[Dict[str, Any]]):
        self.llm_model = ChatOpenAI()
        self.embeddings = OpenAIEmbeddings()
        self.file_details = file_details

    def create_conversational_chain(self):
        """Index the files and return a chain with its own chat memory.

        Returns:
            A ``ConversationalRetrievalChain`` whose retriever searches only
            the documents extracted from ``self.file_details``.
        """
        docs = Document_Processor(self.file_details).get_docs()
        memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        )
        # Give every chain its own collection. Without an explicit name each
        # call falls back to the library's default collection name, so uploads
        # made by different users in the same process can end up in one shared
        # index and be retrieved for each other's questions.
        vectordb = Chroma.from_documents(
            docs,
            self.embeddings,
            collection_name=f"session-{uuid.uuid4().hex}",
        )
        retriever = vectordb.as_retriever()
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm_model,
            retriever=retriever,
            condense_question_prompt=self.get_question_generator_prompt(),
            combine_docs_chain_kwargs={
                "document_prompt": self.get_document_prompt(),
                "prompt": self.get_final_prompt(),
            },
            memory=memory,
        )

        return conversation_chain

    @staticmethod
    def get_document_prompt() -> PromptTemplate:
        """Return the template used to render one retrieved document."""
        document_template = """Document Content:{page_content}
    Document Path: {source}"""
        return PromptTemplate(
            input_variables=["page_content", "source"],
            template=document_template,
        )

    @staticmethod
    def get_question_generator_prompt() -> PromptTemplate:
        """Return the template that condenses history + follow-up into one question."""
        question_generator_template = """Combine the chat history and follow up question into
    a standalone question.\n Chat History: {chat_history}\n
    Follow up question: {question}
    """
        return PromptTemplate.from_template(question_generator_template)

    @staticmethod
    def get_final_prompt() -> ChatPromptTemplate:
        """Return the answering prompt: a system message plus the user question."""
        final_prompt_template = """Answer question based on the context and chat_history.
    If you cannot find answers, ask more related questions from the user.
    Use only the basename of the file path as name of the documents.
    Mention document name of the documents you used in your answer.

    context:
    {context}

    chat_history:
    {chat_history}

    question:
    {question}

    Answer:
    """

        messages = [
            SystemMessagePromptTemplate.from_template(final_prompt_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]

        return ChatPromptTemplate.from_messages(messages)


class UserSessionManager:
    """In-memory registry of per-user conversational chains.

    Two dicts are kept in step, both keyed by user id:

    * ``sessions``          -- the chain for that user, or ``None`` when the
      user is known but has not uploaded any files yet;
    * ``last_request_time`` -- when the user was last seen, used to expire
      idle sessions.
    """

    def __init__(self):
        self.sessions = {}
        self.last_request_time = {}

    def get_session(self, user_id: str):
        """Return the chain stored for ``user_id`` (``None`` if there is none).

        An unknown user is registered with an empty session and a fresh
        timestamp so the inactivity sweep can remove the entry later.
        """
        if user_id not in self.sessions:
            self.sessions[user_id] = None
            # Record the time under this user's key; rebinding the attribute
            # itself would replace the whole dict with a datetime and break
            # every later lookup.
            self.last_request_time[user_id] = datetime.now()
        return self.sessions[user_id]

    def set_session(self, user_id: str, conversational_chain):
        """Store ``conversational_chain`` for ``user_id`` and mark the user active."""
        self.sessions[user_id] = conversational_chain
        self.last_request_time[user_id] = datetime.now()

    def delete_inactive_sessions(self, inactive_period: timedelta):
        """Drop every user whose last request is older than ``inactive_period``."""
        current_time = datetime.now()
        # Iterate over a snapshot because entries are removed while looping.
        for user_id, last_seen in list(self.last_request_time.items()):
            if current_time - last_seen > inactive_period:
                # A timestamp can exist without a session (it may be written
                # from outside this class), so a missing key is not an error.
                self.sessions.pop(user_id, None)
                del self.last_request_time[user_id]


# The ASGI application object; routes and middleware below attach to it.
app = FastAPI()

# Browser origins allowed to call this API with credentials (cookies).
origins = ["https://viboognesh-react-chat.static.hf.space"]
# Alternative origin list for a frontend served from a local dev server:
# origins = ["http://localhost:3000"]

# CORS: only the listed origins, only GET/POST, any request header.
# allow_credentials=True lets the browser send and receive cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Single process-wide registry of per-user chains; state lives in memory only.
user_session_manager = UserSessionManager()


@app.middleware("http")
async def update_last_request_time(request: Request, call_next):
    """Mark the calling user as active before the request is handled.

    The user is identified by the ``user_id`` cookie. Only ids that already
    have an entry in the session registry are refreshed: the cookie value is
    client-controlled, and recording arbitrary values would create timestamps
    with no matching session.
    """
    user_id = request.cookies.get("user_id")
    if user_id and user_id in user_session_manager.sessions:
        user_session_manager.last_request_time[user_id] = datetime.now()
    return await call_next(request)


async def check_inactivity():
    """Periodically remove sessions that have been idle for too long.

    Runs forever: every 10 minutes it asks the session manager to delete the
    sessions whose last request is more than 2 hours old. Intended to be
    scheduled as a background task; it only stops when cancelled.
    """
    # Local import so this block does not depend on the file's import section.
    import logging

    inactive_period = timedelta(hours=2)
    sweep_interval_seconds = 600
    while True:
        await asyncio.sleep(sweep_interval_seconds)
        try:
            user_session_manager.delete_inactive_sessions(inactive_period)
        except Exception:
            # One failed sweep must not end the loop, otherwise sessions would
            # never be cleaned up again. Cancellation is not an ``Exception``
            # subclass and therefore still propagates.
            logging.getLogger(__name__).exception("Inactive-session sweep failed")


@app.on_event("startup")
async def startup_event():
    """Start the background task that expires inactive sessions."""
    # Keep a strong reference to the task: the event loop only holds a weak
    # one, so an unreferenced task may be garbage-collected while running.
    app.state.inactivity_task = asyncio.create_task(check_inactivity())

@app.post("/upload_files/")
async def upload_files(
    response: Response,
    files: List[UploadFile] = File(...),
    request: Request = None,
):
    """Build a conversational chain from the uploaded files for the calling user.

    The user is identified by the ``user_id`` cookie of the incoming request;
    when it is absent a new id is generated and sent back as a cookie.

    Args:
        response: Outgoing response, used only to set the ``user_id`` cookie.
        files: The uploaded files.
        request: Incoming request, injected by FastAPI. It defaults to ``None``
            only so the parameter can be appended without changing the
            existing positional signature.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 400 if an upload cannot be read, 500 if the chain
            cannot be built.
    """
    file_details = []
    try:
        for file in files:
            content = await file.read()
            name = f"{file.filename}"
            file_details.append({"content": content, "name": name})
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Cookies sent by the client live on the request; the response object only
    # carries what this handler sets.
    user_id = request.cookies.get("user_id") if request is not None else None
    if not user_id:
        user_id = str(uuid.uuid4())
        response.set_cookie(key="user_id", value=user_id)

    try:
        conversational_chain = Conversational_Chain(
            file_details
        ).create_conversational_chain()
        user_session_manager.set_session(
            user_id=user_id, conversational_chain=conversational_chain
        )
        print("conversational_chain_manager created")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "ConversationalRetrievalChain is created. Please ask questions."}


@app.get("/predict/")
async def predict(query: str, request: Request = None, response: Response = None):
    """Answer ``query`` using the caller's chain, or the bare LLM if none exists.

    The user is identified by the ``user_id`` cookie of the incoming request;
    when it is absent a new id is generated and sent back as a cookie.

    Args:
        query: The user's question.
        request: Incoming request, injected by FastAPI (source of the cookie).
        response: Outgoing response, injected by FastAPI (target of the
            cookie). Both default to ``None`` only so the original
            ``predict(query)`` call shape keeps working.

    Returns:
        A dict with the generated ``answer``.

    Raises:
        HTTPException: 500 if the model or the chain fails.
    """
    user_id = request.cookies.get("user_id") if request is not None else None
    if not user_id:
        user_id = str(uuid.uuid4())
        if response is not None:
            response.set_cookie(key="user_id", value=user_id)

    try:
        conversational_chain = user_session_manager.get_session(user_id=user_id)
        if conversational_chain is None:
            # No files uploaded yet: answer directly and nudge towards upload.
            system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
            llm_model = ChatOpenAI()
            # Model results get their own names so they never shadow the
            # HTTP ``response`` parameter.
            llm_reply = llm_model.invoke(system_prompt + query)
            answer = llm_reply.content
        else:
            chain_result = conversational_chain.invoke(query)
            answer = chain_result["answer"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    print("predict called")
    return {"answer": answer}