# ChatWIthPdf / app.py
# (Hugging Face Space page header: manoj1hcl's picture · "Update app.py" · commit 2e07a52, verified)
import os
import shutil
import uuid

import gradio as gr
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Pull settings (e.g. OPENAI_API_KEY) from a local .env file. override=True makes
# values in .env take precedence over variables already set in the environment.
load_dotenv(override=True)

# Chat model used both to rewrite follow-up questions and to answer them (low-cost tier).
MODEL = "gpt-4o-mini"

# {"llm": ..., "retriever": ...} for the most recently processed PDF, or None when
# no PDF is loaded. It is one module-level value shared by every session, so the
# app is effectively single-user.
chain = None
def pick_db_dir() -> str:
    """Return the directory in which the Chroma vector store is persisted.

    Resolution order:
      1. the ``DB_DIR`` environment variable, when set to a non-empty value;
      2. ``/data/vector_db`` when ``/data`` is an existing, writable directory
         (a persistent mount, if the host provides one);
      3. ``/tmp/vector_db`` otherwise (ephemeral).
    """
    override = os.environ.get("DB_DIR")
    if override:
        return override

    mount = "/data"
    mount_is_usable = os.path.isdir(mount) and os.access(mount, os.W_OK)
    if mount_is_usable:
        return os.path.join(mount, "vector_db")
    return "/tmp/vector_db"


# Resolved once at import time; every upload persists under this root.
DB_DIR = pick_db_dir()
def process_pdf(pdf_file):
    """Index a PDF in a fresh Chroma store and build the question-answering parts.

    Args:
        pdf_file: The value produced by ``gr.File``. This is either a filesystem
            path (``str`` / ``os.PathLike``, as passed by Gradio 4+) or an object
            exposing the path as ``.name`` (older Gradio tempfile wrappers).

    Returns:
        dict: ``{"llm": ChatOpenAI, "retriever": retriever}``, where the retriever
        returns the 4 most similar chunks for a query.

    Raises:
        RuntimeError: On any failure (missing API key, unreadable PDF, no
            extractable text, embedding/storage errors). The message is prefixed
            with ``"PDF processing failed: "`` and the original exception is chained
            as ``__cause__``.
    """
    try:
        if not os.getenv("OPENAI_API_KEY"):
            raise RuntimeError(
                "OPENAI_API_KEY is not set. Add it to your environment or as a Secret on HF Spaces."
            )

        if isinstance(pdf_file, (str, os.PathLike)):
            file_path = os.fspath(pdf_file)
        else:
            # Legacy Gradio tempfile wrapper: the path lives in .name.
            file_path = pdf_file.name

        pages = PyPDFLoader(file_path).load()
        # PyPDFLoader yields one Document per page even when a page has no
        # extractable text (e.g. scanned images), so check the text itself.
        pages = [page for page in pages if (page.page_content or "").strip()]
        if not pages:
            raise ValueError("No text found in PDF (may be scanned or protected).")

        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = splitter.split_documents(pages)
        if not chunks:
            raise ValueError("Unable to split PDF into chunks (empty/protected PDF).")

        embeddings = OpenAIEmbeddings()

        # Drop data from earlier uploads, then persist into a brand-new
        # subdirectory. Chroma caches its client (and open SQLite handle) per
        # path inside the process; re-creating a store at a path whose files were
        # just deleted can fail with "attempt to write a readonly database".
        shutil.rmtree(DB_DIR, ignore_errors=True)
        persist_dir = os.path.join(DB_DIR, uuid.uuid4().hex)
        os.makedirs(persist_dir, exist_ok=True)

        vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=embeddings,
            persist_directory=persist_dir,
        )
        llm = ChatOpenAI(model=MODEL, temperature=0.2)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
        # Only the llm and retriever are returned; ask_question builds the prompts.
        return {"llm": llm, "retriever": retriever}
    except Exception as e:
        raise RuntimeError(f"PDF processing failed: {e}") from e
def upload_pdf(file):
    """Handle a change of the PDF upload widget.

    On success, store the QA components built by ``process_pdf`` in the
    module-level ``chain``. When the upload is cleared (``file is None``) or
    processing fails, reset ``chain`` to ``None``.

    Returns:
        str: Status text for the "Status" textbox.
    """
    global chain
    if file is None:
        chain = None
        return "Please upload a PDF."

    try:
        chain = process_pdf(file)
    except RuntimeError as err:
        chain = None
        reason = str(err)
        if "OPENAI_API_KEY" not in reason:
            return f"Error: {reason}"
        # Give a setup hint instead of the raw wrapped message.
        return "Error: OPENAI_API_KEY is not set. Add it to a .env file in this folder (OPENAI_API_KEY=sk-...) or run: export OPENAI_API_KEY=your-key"
    else:
        return f"PDF processed. Vector DB at: {DB_DIR}. Ask questions now."
def _gradio_history_to_langchain(history):
    """Convert Gradio chat history into a list of LangChain messages.

    Accepted entry shapes:
      * ``{"role": ..., "content": ...}`` dicts (``gr.Chatbot(type="messages")``);
      * objects exposing ``.role`` and ``.content`` attributes;
      * ``(user_text, assistant_text)`` pairs (legacy tuple-format history);
        a ``None`` side of a pair is skipped.

    ``"user"`` entries become ``HumanMessage``; every other role becomes
    ``AIMessage``. ``None``/empty history yields ``[]``.
    """
    if not history:
        return []
    lc_messages = []
    for m in history:
        if isinstance(m, dict):
            turns = [(m.get("role", ""), m.get("content", ""))]
        elif hasattr(m, "role") and hasattr(m, "content"):
            turns = [(m.role, m.content)]
        else:
            # Legacy pair: item 0 is the user's text, item 1 is the bot's reply
            # (the second item is NOT a role name).
            pair = list(m)
            turns = []
            if len(pair) > 0 and pair[0] is not None:
                turns.append(("user", pair[0]))
            if len(pair) > 1 and pair[1] is not None:
                turns.append(("assistant", pair[1]))
        for role, content in turns:
            if role == "user":
                lc_messages.append(HumanMessage(content=content or ""))
            else:
                lc_messages.append(AIMessage(content=content or ""))
    return lc_messages
def _get_answer_from_message(msg) -> str:
    """Extract the reply text from an LLM result.

    If *msg* has a ``content`` attribute (e.g. a LangChain ``AIMessage``), that
    attribute is returned, with falsy values mapped to ``""``. Otherwise falsy
    inputs give ``""`` and anything else is passed through ``str()``.
    """
    if hasattr(msg, "content"):
        text = msg.content
        return text if text else ""
    if not msg:
        return ""
    return str(msg)
def ask_question(message, history):
    """Answer *message* from the indexed PDF, taking the conversation into account.

    Pipeline:
      1. If there are earlier turns, ask the LLM to rewrite the question as a
         standalone query. With no earlier turns the question is used as-is,
         which saves an LLM call.
      2. Retrieve matching chunks with the stored retriever.
      3. Answer from those chunks plus the conversation.

    Args:
        message: Text submitted in the question box.
        history: List of ``{"role", "content"}`` dicts held in ``gr.State``;
            it is appended to in place.

    Returns:
        tuple: ``(history, history, "")``: the updated history for the Chatbot
        and the State, plus ``""`` to clear the question box.
    """
    history = history or []

    # Blank submissions are ignored rather than sent to the LLM.
    if not (message or "").strip():
        return history, history, ""

    if chain is None:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Upload the PDF first."})
        return history, history, ""

    chat_history_lc = _gradio_history_to_langchain(history)
    llm = chain["llm"]
    retriever = chain["retriever"]
    try:
        # 1) Turn the follow-up question into a standalone query (only needed
        #    when there is history to resolve references against).
        query = message
        if chat_history_lc:
            contextualize_prompt = ChatPromptTemplate.from_messages([
                ("system", "Given the chat history and the latest user question, write a single standalone question that can be understood without the chat history. If the question is already standalone, return it unchanged. Do not answer the question."),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ])
            standalone = (contextualize_prompt | llm).invoke(
                {"input": message, "chat_history": chat_history_lc}
            )
            query = _get_answer_from_message(standalone).strip() or message

        # 2) Retrieve supporting chunks.
        docs = retriever.invoke(query)
        context = "\n\n".join(getattr(d, "page_content", "") or str(d) for d in docs)

        # 3) Answer with context + history.
        qa_prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an assistant for question-answering. Use the following context to answer. If the answer is not in the context, say so. Be concise.\n\nContext:\n{context}"),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        response = (qa_prompt | llm).invoke(
            {"input": message, "chat_history": chat_history_lc, "context": context}
        )
        answer = _get_answer_from_message(response) or "No answer found."
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        answer = f"Error: {e}"

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history, history, ""
with gr.Blocks() as demo:
    gr.Markdown("## Chat with your PDF")
    file_input = gr.File(label="Upload your PDF", file_types=[".pdf"])
    status = gr.Textbox(label="Status", interactive=False)
    chatbot = gr.Chatbot(label="Chat history", type="messages")
    msg = gr.Textbox(label="Ask anything related to the PDF...")
    clear = gr.Button("Clear chat")
    # Conversation as {"role", "content"} dicts, passed to and returned by ask_question.
    state = gr.State([])

    # Re-index whenever the uploaded file changes (including when it is cleared).
    file_input.change(upload_pdf, inputs=[file_input], outputs=[status])
    msg.submit(ask_question, inputs=[msg, state], outputs=[chatbot, state, msg])
    clear.click(lambda: ([], []), inputs=None, outputs=[chatbot, state])

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch(inline=False)