# Hugging Face Space scrape artifact, kept as comments so the file parses:
# author: nikhmr1235 — commit 480beb6 (verified): "fix ValueError below"
import gradio as gr
import os
import uuid
import shutil
import fitz
from langchain_community.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import tempfile
import time
import threading
# --- Cleanup Configuration ---
# Per-session Chroma vector stores live under the OS temp dir, one
# subdirectory per session id (see SessionState.vector_store_path).
CHROMA_DB_PATH = os.path.join(tempfile.gettempdir(), "chroma_db")
CLEANUP_INTERVAL_HOURS = 3 # Cleanup every 3 hours
SESSION_TTL_HOURS = 3 # Sessions older than 3 hours will be deleted
# --- Cleanup Functions ---
def cleanup_old_sessions():
    """Background sweeper: deletes session directories older than SESSION_TTL_HOURS.

    Runs forever, waking every CLEANUP_INTERVAL_HOURS; intended to be run
    on a daemon thread (see startup code below). Never returns.
    """
    # Both durations are loop-invariant — computed once, not per iteration.
    ttl_seconds = SESSION_TTL_HOURS * 3600
    interval_seconds = CLEANUP_INTERVAL_HOURS * 3600
    while True:
        if os.path.exists(CHROMA_DB_PATH):
            _purge_expired_sessions(time.time(), ttl_seconds)
        time.sleep(interval_seconds)


def _purge_expired_sessions(now, ttl_seconds):
    """Delete every session directory under CHROMA_DB_PATH whose mtime is
    older than ttl_seconds relative to `now`.

    Per-session errors are logged and skipped so one undeletable directory
    cannot abort the whole sweep.
    """
    for session_id in os.listdir(CHROMA_DB_PATH):
        session_path = os.path.join(CHROMA_DB_PATH, session_id)
        if not os.path.isdir(session_path):
            continue  # ignore stray files in the store root
        try:
            if (now - os.path.getmtime(session_path)) > ttl_seconds:
                print(f"Cleaning up old session: {session_id}")
                shutil.rmtree(session_path)
        except Exception as e:
            print(f"Error cleaning up session {session_id}: {e}")
# --- Initial Cleanup on Startup ---
# Wipe any ChromaDB state left over from a previous process so every
# launch starts from an empty store directory.
print("Performing initial cleanup of old ChromaDB directories...")
if os.path.exists(CHROMA_DB_PATH):
    shutil.rmtree(CHROMA_DB_PATH)
os.makedirs(CHROMA_DB_PATH)
print("Cleanup complete. Starting background cleanup thread.")
# --- Start Background Cleanup Thread ---
# daemon=True: the sweeper loop must not keep the process alive on shutdown.
cleanup_thread = threading.Thread(target=cleanup_old_sessions, daemon=True)
cleanup_thread.start()
# --- API key & model configuration ---
# Read the key once; fail fast at import time if it is missing OR empty
# (the original checked membership and then re-read it, and accepted "").
# RuntimeError is a subclass of Exception, so any caller catching the
# original generic Exception still catches this.
google_api_key = os.environ.get("GOOGLE_API_KEY")
if not google_api_key:
    raise RuntimeError("Please set the GOOGLE_API_KEY environment variable.")
# Constants
LLM_MODEL = "gemini-1.5-flash"  # Gemini chat model used for answering
EMBEDDING_MODEL = "models/embedding-001"  # Gemini embedding model for Chroma
class SessionState:
    """Per-user session: a unique id plus the location of its Chroma store.

    `db` stays None until process_pdf has successfully built the vector
    store for this session.
    """

    def __init__(self):
        sid = str(uuid.uuid4())  # unique per browser session
        self.session_id = sid
        self.db = None  # set by process_pdf once indexing succeeds
        self.vector_store_path = os.path.join(CHROMA_DB_PATH, sid)

    def is_db_ready(self):
        """Return True once a vector store has been attached to this session."""
        return self.db is not None
async def process_pdf(pdf_file, state: SessionState):
    """Extract text from the uploaded PDF, chunk it, embed it, and attach a
    per-session Chroma vector store to `state.db`.

    Args:
        pdf_file: uploaded file object exposing a filesystem path as `.name`.
        state: SessionState to populate; `state.db` is set on success.

    Raises:
        gr.Error: for oversized files, unreadable PDFs, or any other failure
            (unexpected exceptions are wrapped; the half-built store on disk
            is removed first).
    """
    try:
        # Reject very large uploads before doing any parsing work.
        file_size_mb = os.path.getsize(pdf_file.name) / (1024 * 1024)
        if file_size_mb >= 75:
            raise gr.Error("File size exceeds the 75 MB limit. Please upload a smaller PDF.")
        print("Opening PDF file...")
        try:
            # Context manager guarantees the document handle is closed even
            # if text extraction raises mid-way (the original leaked the
            # handle in that case); join avoids quadratic string appends.
            with fitz.open(pdf_file.name) as doc:
                text = "".join(page.get_text() for page in doc)
        except Exception as e:
            raise gr.Error(f"Error processing PDF document: {str(e)}")
        print("PDF file opened successfully. Splitting text into chunks...")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        docs = text_splitter.create_documents([text])
        print("Text split into chunks successfully.")
        embeddings = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, google_api_key=google_api_key)
        state.db = await Chroma.afrom_documents(
            documents=docs,
            embedding=embeddings,
            persist_directory=state.vector_store_path,
            collection_name=state.session_id
        )
        print("PDF processed successfully! Database is ready.")
    except Exception as e:
        # Best-effort cleanup of any partially-written store for this session.
        if os.path.exists(state.vector_store_path):
            shutil.rmtree(state.vector_store_path)
        if isinstance(e, gr.Error):
            raise  # Re-raise Gradio errors directly
        else:
            raise gr.Error(f"An unexpected error occurred: {str(e)}")
async def chat_with_pdf(message, history, state: SessionState):
    """Answer `message` against the session's vector store via a
    history-aware RAG chain, yielding a single final response.

    Args:
        message: the user's latest question.
        history: prior turns from gr.ChatInterface. With the chatbot
            configured `type="messages"` these are {"role", "content"}
            dicts; the legacy (user, assistant) tuple format is also
            accepted for safety.
        state: SessionState whose `db` must already be populated.
    """
    print("Chat interface called. Checking if database is ready...")
    if not state or not state.is_db_ready():
        print("Database is not ready.")
        yield "Error: Database not ready. Please upload a PDF first."
        return
    print("Database is ready. Retrieving relevant documents...")
    retriever = state.db.as_retriever()
    llm = ChatGoogleGenerativeAI(model=LLM_MODEL, temperature=0.7, google_api_key=google_api_key)
    # Rewrites a follow-up question into a standalone one using the history.
    condenser_prompt = ChatPromptTemplate.from_messages([
        ("system", "Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, condenser_prompt
    )
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant for a PDF document. Answer the user's question based on the following context. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    chat_history_for_chain = []
    for turn in history:
        if isinstance(turn, dict):
            # BUG FIX: with `type="messages"` the history arrives as dicts;
            # the original only handled tuples, so past turns were silently
            # dropped and the condenser chain saw an empty history.
            role = turn.get("role")
            content = turn.get("content", "")
            if role == "user":
                chat_history_for_chain.append(HumanMessage(content=content))
            elif role == "assistant":
                chat_history_for_chain.append(AIMessage(content=content))
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            # Legacy tuple-format history: (user_msg, ai_msg).
            user_msg, ai_msg = turn
            chat_history_for_chain.append(HumanMessage(content=user_msg))
            chat_history_for_chain.append(AIMessage(content=ai_msg))
    response = await rag_chain.ainvoke({
        "chat_history": chat_history_for_chain,
        "input": message
    })
    yield response["answer"]
with gr.Blocks(title="PDF Chatbot") as demo:
    # Holds this browser session's SessionState once a PDF is processed.
    state = gr.State()
    gr.Markdown(
        """
        # PDF Chatbot
        Upload a PDF to start a conversation with your document.
        """
    )
    with gr.Row():
        file_upload_input = gr.File(
            file_types=[".pdf"],
            label="Upload your PDF document",
            interactive=True
        )
    # Hidden until a PDF has been successfully indexed.
    with gr.Row(visible=False) as chat_row:
        chat_interface = gr.ChatInterface(
            fn=chat_with_pdf,
            additional_inputs=[state],
            chatbot=gr.Chatbot(type="messages"),
            textbox=gr.Textbox(placeholder="Type your question here...", scale=7),
            examples=[["What is the main topic of the document?"], ["Summarize the key findings."], ["Who are the authors?"]],
            title="Chat Interface",
            theme="soft",
            type="messages"
        )

    async def process_and_show_chat(file, state):
        """Upload handler: index the PDF, then reveal the chat UI.

        Returns updates for [chat_row visibility, file-input interactivity,
        state]. On failure the UI is reset and the previous state is kept.
        """
        gr.Info("Processing your PDF, please wait...")
        new_state = SessionState()
        try:
            await process_pdf(file, new_state)
            gr.Info("PDF processed successfully! You can now chat with it.")
            return [
                gr.update(visible=True),
                gr.update(interactive=False),
                new_state,
            ]
        except gr.Error as e:
            # BUG FIX: the original built `gr.Error(str(e))` without raising
            # it, which displays nothing. gr.Warning shows a popup WITHOUT
            # aborting the handler, so the reset updates below still apply.
            gr.Warning(str(e))
            return [
                gr.update(visible=False),
                gr.update(interactive=True),
                state,  # Return original state on failure
            ]

    file_upload_input.upload(
        fn=process_and_show_chat,
        inputs=[file_upload_input, state],
        outputs=[chat_row, file_upload_input, state]
    )
demo.launch()