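"""Gradio PDF chatbot.

Extracts text from an uploaded PDF with PyMuPDF, chunks and embeds it into a
per-session Chroma vector store using Google Generative AI embeddings, and
answers questions through a history-aware retrieval (RAG) chain backed by
Gemini. Old session stores are garbage-collected by a background thread.
"""
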
import os
import shutil
import tempfile
import threading
import time
import uuid

import fitz  # PyMuPDF
import gradio as gr
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

# --- Cleanup Configuration ---
CHROMA_DB_PATH = os.path.join(tempfile.gettempdir(), "chroma_db")
CLEANUP_INTERVAL_HOURS = 3  # Run a cleanup pass every 3 hours
SESSION_TTL_HOURS = 3  # Sessions older than 3 hours will be deleted


# --- Cleanup Functions ---
def cleanup_old_sessions():
    """Deletes session directories older than SESSION_TTL_HOURS."""
    while True:
        now = time.time()
        ttl_seconds = SESSION_TTL_HOURS * 3600
        if not os.path.exists(CHROMA_DB_PATH):
            time.sleep(CLEANUP_INTERVAL_HOURS * 3600)
            continue
        for session_id in os.listdir(CHROMA_DB_PATH):
            session_path = os.path.join(CHROMA_DB_PATH, session_id)
            if os.path.isdir(session_path):
                try:
                    mod_time = os.path.getmtime(session_path)
                    if (now - mod_time) > ttl_seconds:
                        print(f"Cleaning up old session: {session_id}")
                        shutil.rmtree(session_path)
                except Exception as e:
                    print(f"Error cleaning up session {session_id}: {e}")
        time.sleep(CLEANUP_INTERVAL_HOURS * 3600)


# --- Initial Cleanup on Startup ---
print("Performing initial cleanup of old ChromaDB directories...")
if os.path.exists(CHROMA_DB_PATH):
    shutil.rmtree(CHROMA_DB_PATH)
os.makedirs(CHROMA_DB_PATH)
print("Cleanup complete. Starting background cleanup thread.")

# --- Start Background Cleanup Thread ---
# Daemon thread: it is terminated automatically when the main process exits.
cleanup_thread = threading.Thread(target=cleanup_old_sessions, daemon=True)
cleanup_thread.start()

# Require the Google API key from the environment.
if "GOOGLE_API_KEY" not in os.environ:
    raise RuntimeError("Please set the GOOGLE_API_KEY environment variable.")
google_api_key = os.environ["GOOGLE_API_KEY"]

# Constants
LLM_MODEL = "gemini-1.5-flash"
EMBEDDING_MODEL = "models/embedding-001"


class SessionState:
    """Per-user session: a unique ID and its own Chroma vector store."""

    def __init__(self):
        self.session_id = str(uuid.uuid4())
        self.db = None
        self.vector_store_path = os.path.join(CHROMA_DB_PATH, self.session_id)

    def is_db_ready(self):
        return self.db is not None


async def process_pdf(pdf_file, state: SessionState):
    """Processes the PDF and updates the state object."""
    try:
        # Reject files over the 75 MB limit before doing any work.
        file_size_mb = os.path.getsize(pdf_file.name) / (1024 * 1024)
        if file_size_mb > 75:
            raise gr.Error("File size exceeds the 75 MB limit. Please upload a smaller PDF.")
        print("Opening PDF file...")
        try:
            doc = fitz.open(pdf_file.name)
            text = ""
            for page in doc:
                text += page.get_text()
            doc.close()
        except Exception as e:
            raise gr.Error(f"Error processing PDF document: {str(e)}")
        print("PDF file opened successfully. Splitting text into chunks...")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        docs = text_splitter.create_documents([text])
        print("Text split into chunks successfully.")
        embeddings = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, google_api_key=google_api_key)
        state.db = await Chroma.afrom_documents(
            documents=docs,
            embedding=embeddings,
            persist_directory=state.vector_store_path,
            collection_name=state.session_id,
        )
        print("PDF processed successfully! Database is ready.")
    except Exception as e:
        # Remove any partially written vector store before surfacing the error.
        if os.path.exists(state.vector_store_path):
            shutil.rmtree(state.vector_store_path)
        if isinstance(e, gr.Error):
            raise  # Re-raise Gradio errors directly
        raise gr.Error(f"An unexpected error occurred: {str(e)}")


async def chat_with_pdf(message, history, state: SessionState):
    print("Chat interface called. Checking if database is ready...")
    if not state or not state.is_db_ready():
        print("Database is not ready.")
        yield "Error: Database not ready. Please upload a PDF first."
        return
    print("Database is ready. Retrieving relevant documents...")
    retriever = state.db.as_retriever()
    llm = ChatGoogleGenerativeAI(model=LLM_MODEL, temperature=0.7, google_api_key=google_api_key)

    # Condense the latest question plus chat history into a standalone query.
    condenser_prompt = ChatPromptTemplate.from_messages([
        ("system", "Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, condenser_prompt
    )
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant for a PDF document. Answer the user's question based on the following context. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    # With a chatbot of type="messages", Gradio passes history as a list of
    # {"role": ..., "content": ...} dicts; the tuple branch is kept as a
    # fallback for tuple-style history.
    chat_history_for_chain = []
    for turn in history:
        if isinstance(turn, dict):
            msg_cls = HumanMessage if turn.get("role") == "user" else AIMessage
            chat_history_for_chain.append(msg_cls(content=turn.get("content", "")))
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, ai_msg = turn
            chat_history_for_chain.append(HumanMessage(content=user_msg))
            chat_history_for_chain.append(AIMessage(content=ai_msg))
    response = await rag_chain.ainvoke({
        "chat_history": chat_history_for_chain,
        "input": message
    })
    yield response["answer"]


with gr.Blocks(title="PDF Chatbot") as demo:
    state = gr.State()
    gr.Markdown(
        """
        # PDF Chatbot
        Upload a PDF to start a conversation with your document.
        """
    )
    with gr.Row():
        file_upload_input = gr.File(
            file_types=[".pdf"],
            label="Upload your PDF document",
            interactive=True
        )
    with gr.Row(visible=False) as chat_row:
        chat_interface = gr.ChatInterface(
            fn=chat_with_pdf,
            additional_inputs=[state],
            chatbot=gr.Chatbot(type="messages"),
            textbox=gr.Textbox(placeholder="Type your question here...", scale=7),
            examples=[["What is the main topic of the document?"], ["Summarize the key findings."], ["Who are the authors?"]],
            title="Chat Interface",
            theme="soft",
            type="messages"
        )

    async def process_and_show_chat(file, state):
        gr.Info("Processing your PDF, please wait...")
        new_state = SessionState()
        try:
            await process_pdf(file, new_state)
            gr.Info("PDF processed successfully! You can now chat with it.")
            return [
                gr.update(visible=True),
                gr.update(interactive=False),
                new_state,
            ]
        except gr.Error as e:
            # gr.Error is only displayed when raised; use gr.Warning to show
            # the message as a toast while still returning the UI updates.
            gr.Warning(str(e))
            return [
                gr.update(visible=False),
                gr.update(interactive=True),
                state,  # Return original state on failure
            ]

    file_upload_input.upload(
        fn=process_and_show_chat,
        inputs=[file_upload_input, state],
        outputs=[chat_row, file_upload_input, state]
    )

demo.launch()
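
# To run this Space locally (a sketch; assumes the file is saved as app.py,
# the usual Hugging Face Spaces entry point, and that gradio, langchain,
# langchain-community, langchain-google-genai, chromadb, and pymupdf are
# installed):
#   export GOOGLE_API_KEY=<your key>
#   python app.py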