File size: 8,407 Bytes
1f44a86
 
 
 
febca85
1f44a86
 
 
22dd68c
 
 
 
1f44a86
 
 
 
2b5cccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5865dab
9a69935
 
 
 
 
60bbf45
1f44a86
 
124f540
1f44a86
f5dabf5
1f44a86
f5dabf5
 
 
886337d
78d7b7f
f5dabf5
78d7b7f
f11f97c
 
f5dabf5
 
 
f11f97c
79ff6ae
f5dabf5
 
 
 
 
 
 
 
f11f97c
1f44a86
f5dabf5
 
 
 
 
 
9cc9353
f11f97c
f5dabf5
 
f11f97c
 
78d7b7f
f5dabf5
9cc9353
f5dabf5
f11f97c
f5dabf5
9cc9353
f11f97c
 
 
 
 
f5dabf5
 
 
 
 
 
 
 
 
 
 
 
22dd68c
 
 
 
 
 
 
 
f5dabf5
1f44a86
22dd68c
 
 
 
 
 
 
 
 
1f44a86
22dd68c
480beb6
 
 
 
 
22dd68c
 
 
 
 
 
 
79ff6ae
1f44a86
22dd68c
1f44a86
 
 
 
 
 
 
79ff6ae
22dd68c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f11f97c
22dd68c
 
 
5070fef
 
 
 
 
f11f97c
 
 
5070fef
 
 
 
 
1f44a86
 
22dd68c
9cc9353
22dd68c
1f44a86
 
480beb6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import gradio as gr
import os
import uuid
import shutil
import fitz
from langchain_community.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import tempfile
import time
import threading

# --- Cleanup Configuration ---
# Maximum age of a session's on-disk vector store before the background
# cleaner deletes it.
SESSION_TTL_HOURS = 3       # Sessions older than 3 hours will be deleted
# How long the background cleaner sleeps between passes.
CLEANUP_INTERVAL_HOURS = 3  # Cleanup every 3 hours
# Root directory (inside the OS temp dir) holding one Chroma store per session.
CHROMA_DB_PATH = os.path.join(tempfile.gettempdir(), "chroma_db")

# --- Cleanup Functions ---
def cleanup_old_sessions():
    """Deletes session directories older than SESSION_TTL_HOURS."""
    while True:
        now = time.time()
        ttl_seconds = SESSION_TTL_HOURS * 3600
        
        if not os.path.exists(CHROMA_DB_PATH):
            time.sleep(CLEANUP_INTERVAL_HOURS * 3600)
            continue

        for session_id in os.listdir(CHROMA_DB_PATH):
            session_path = os.path.join(CHROMA_DB_PATH, session_id)
            if os.path.isdir(session_path):
                try:
                    mod_time = os.path.getmtime(session_path)
                    if (now - mod_time) > ttl_seconds:
                        print(f"Cleaning up old session: {session_id}")
                        shutil.rmtree(session_path)
                except Exception as e:
                    print(f"Error cleaning up session {session_id}: {e}")
        
        time.sleep(CLEANUP_INTERVAL_HOURS * 3600)

# --- Initial Cleanup on Startup ---
def _reset_chroma_root():
    """Remove whatever a previous run left under CHROMA_DB_PATH and recreate it empty."""
    print("Performing initial cleanup of old ChromaDB directories...")
    if os.path.exists(CHROMA_DB_PATH):
        shutil.rmtree(CHROMA_DB_PATH)
    os.makedirs(CHROMA_DB_PATH)
    print("Cleanup complete. Starting background cleanup thread.")


_reset_chroma_root()

# --- Start Background Cleanup Thread ---
# daemon=True: the cleaner loops forever and must not keep the process alive.
cleanup_thread = threading.Thread(daemon=True, target=cleanup_old_sessions)
cleanup_thread.start()


# Set the Google API key from environment variables.
# An empty value is rejected as well: `"GOOGLE_API_KEY" in os.environ` alone
# would accept `GOOGLE_API_KEY=` and pass an empty key to the Google clients.
# RuntimeError subclasses Exception, so existing handlers still catch it.
google_api_key = os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
    raise RuntimeError("Please set the GOOGLE_API_KEY environment variable.")

# Constants
LLM_MODEL = "gemini-1.5-flash"            # chat model used for condensing and answering
EMBEDDING_MODEL = "models/embedding-001"  # embedding model used for the Chroma store

class SessionState:
    """Per-user session data: a unique id, its vector store, and the store's on-disk path."""

    def __init__(self):
        sid = str(uuid.uuid4())
        self.session_id = sid
        # Stays None until process_pdf attaches a Chroma store.
        self.db = None
        self.vector_store_path = os.path.join(CHROMA_DB_PATH, sid)

    def is_db_ready(self):
        """Return True once a vector store has been attached to this session."""
        return not (self.db is None)

def _extract_pdf_text(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, concatenated in page order.

    Raises:
        gr.Error: if PyMuPDF cannot open or read the document.
    """
    try:
        # The ``with`` block closes the document even if a page fails to
        # extract; the previous open()/close() pair leaked it on error.
        with fitz.open(pdf_path) as doc:
            # join() instead of repeated += keeps this linear on long PDFs.
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        raise gr.Error(f"Error processing PDF document: {str(e)}") from e


async def process_pdf(pdf_file, state: SessionState):
    """Processes the PDF and updates the state object.

    Extracts the PDF's text, splits it into overlapping chunks, embeds them,
    and stores the resulting Chroma vector store in ``state.db``. The store is
    persisted under ``state.vector_store_path`` with collection name
    ``state.session_id``.

    Args:
        pdf_file: The uploaded file. Either an object whose ``.name`` is the
            on-disk path, or a plain path string.
        state: The session to populate.

    Raises:
        gr.Error: The file is missing, 75 MB or larger, or unreadable, or it
            contains no extractable text. Any other failure is also wrapped in
            gr.Error. Before raising, any partially written vector store for
            this session is deleted.
    """
    try:
        if pdf_file is None:
            raise gr.Error("No file was uploaded. Please upload a PDF document.")
        # Accept a file wrapper exposing .name as well as a plain path string.
        pdf_path = getattr(pdf_file, "name", pdf_file)

        file_size_mb = os.path.getsize(pdf_path) / (1024 * 1024)
        if file_size_mb >= 75:
            raise gr.Error("File size exceeds the 75 MB limit. Please upload a smaller PDF.")

        print("Opening PDF file...")
        text = _extract_pdf_text(pdf_path)
        if not text.strip():
            # Nothing to embed (e.g. an image-only PDF). Fail with a clear
            # message instead of handing an empty document list to Chroma.
            raise gr.Error(
                "No extractable text was found in this PDF. "
                "Scanned or image-only PDFs are not supported."
            )

        print("PDF file opened successfully. Splitting text into chunks...")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        docs = text_splitter.create_documents([text])
        print("Text split into chunks successfully.")

        embeddings = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, google_api_key=google_api_key)

        state.db = await Chroma.afrom_documents(
            documents=docs,
            embedding=embeddings,
            persist_directory=state.vector_store_path,
            collection_name=state.session_id
        )
        print("PDF processed successfully! Database is ready.")

    except Exception as e:
        # Don't leave a half-written store behind for the cleaner to find later.
        if os.path.exists(state.vector_store_path):
            shutil.rmtree(state.vector_store_path)

        if isinstance(e, gr.Error):
            raise  # Re-raise Gradio errors directly
        raise gr.Error(f"An unexpected error occurred: {str(e)}") from e


def _history_to_messages(history):
    """Convert Gradio chat history into a list of LangChain messages.

    Two formats are accepted. The first is the ``type="messages"`` format
    (dicts with ``role`` and ``content``), which this app's Chatbot and
    ChatInterface are configured for. The second is the legacy
    ``[user, assistant]`` pair format. Entries whose content is not a string
    (e.g. attachments, or ``None`` placeholders) are skipped. Roles other than
    user/assistant are also skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if not isinstance(content, str):
                continue
            role = turn.get("role")
            if role == "user":
                messages.append(HumanMessage(content=content))
            elif role == "assistant":
                messages.append(AIMessage(content=content))
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, ai_msg = turn
            if isinstance(user_msg, str):
                messages.append(HumanMessage(content=user_msg))
            if isinstance(ai_msg, str):
                messages.append(AIMessage(content=ai_msg))
    return messages


async def chat_with_pdf(message, history, state: SessionState):
    """Answer *message* about the session's PDF using history-aware retrieval.

    Args:
        message: The user's latest question.
        history: Prior conversation as supplied by Gradio. Both messages-format
            dicts and legacy pairs are handled.
        state: The session whose ``db`` vector store is searched.

    Yields:
        The chain's answer string, or an error string if no PDF has been
        processed for this session yet.
    """
    print("Chat interface called. Checking if database is ready...")
    if not state or not state.is_db_ready():
        print("Database is not ready.")
        yield "Error: Database not ready. Please upload a PDF first."
        return

    print("Database is ready. Retrieving relevant documents...")
    retriever = state.db.as_retriever()
    llm = ChatGoogleGenerativeAI(model=LLM_MODEL, temperature=0.7, google_api_key=google_api_key)

    # The LLM rewrites follow-up questions into standalone ones before retrieval.
    condenser_prompt = ChatPromptTemplate.from_messages([
        ("system", "Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])

    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, condenser_prompt
    )

    # The retrieved chunks are stuffed into {context} for the answering prompt.
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant for a PDF document. Answer the user's question based on the following context. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])

    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    # Previously only tuple pairs were converted. With type="messages" the
    # history is dicts, so the chain always received an empty history.
    response = await rag_chain.ainvoke({
        "chat_history": _history_to_messages(history),
        "input": message
    })

    yield response["answer"]

with gr.Blocks(title="PDF Chatbot") as demo:
    # Holds this browser session's SessionState. It starts empty and is replaced
    # by process_and_show_chat once a PDF has been processed.
    state = gr.State()

    gr.Markdown(
        """
        # PDF Chatbot
        Upload a PDF to start a conversation with your document.
        """
    )

    with gr.Row():
        file_upload_input = gr.File(
            file_types=[".pdf"],
            label="Upload your PDF document",
            interactive=True
        )

    # Hidden until a PDF has been processed successfully.
    with gr.Row(visible=False) as chat_row:
        chat_interface = gr.ChatInterface(
            fn=chat_with_pdf,
            additional_inputs=[state],
            chatbot=gr.Chatbot(type="messages"),
            textbox=gr.Textbox(placeholder="Type your question here...", scale=7),
            examples=[["What is the main topic of the document?"], ["Summarize the key findings."], ["Who are the authors?"]],
            title="Chat Interface",
            theme="soft",
            type="messages"
        )

    async def process_and_show_chat(file, state):
        """Build a fresh session from the uploaded PDF and reveal the chat UI.

        Returns:
            Updates for ``[chat_row, file_upload_input, state]``. On success the
            chat row is shown, further uploads are disabled, and the new session
            is stored. On failure the chat row stays hidden, uploading stays
            enabled, and the previous session is kept.
        """
        gr.Info("Processing your PDF, please wait...")
        new_state = SessionState()
        try:
            await process_pdf(file, new_state)
        except gr.Error as e:
            # Constructing gr.Error without raising it (as before) shows nothing.
            # gr.Warning displays the message and still lets the reset updates
            # below be returned.
            gr.Warning(str(e))
            return [
                gr.update(visible=False),
                gr.update(interactive=True),
                state,  # Return original state on failure
            ]
        gr.Info("PDF processed successfully! You can now chat with it.")
        return [
            gr.update(visible=True),
            gr.update(interactive=False),
            new_state,
        ]

    file_upload_input.upload(
        fn=process_and_show_chat,
        inputs=[file_upload_input, state],
        outputs=[chat_row, file_upload_input, state]
    )

demo.launch()