import os
import shutil
import uuid

from dotenv import load_dotenv
import gradio as gr
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage

# Low-cost model
MODEL = "gpt-4o-mini"

# Chunking / retrieval settings used when indexing and querying the PDF.
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
RETRIEVER_K = 4

load_dotenv(override=True)

# QA components for the currently loaded PDF: {"llm": ..., "retriever": ...},
# or None until a PDF has been processed successfully.
# NOTE: this is a module-level global shared by every Gradio session in the
# process, so concurrent users overwrite each other's document. That is fine
# for single-user use; per-session gr.State would be needed otherwise.
chain = None

# Rewrites the latest user question into one that stands alone without history.
_CONTEXTUALIZE_PROMPT = ChatPromptTemplate.from_messages([
    (
        "system",
        "Given the chat history and the latest user question, write a single "
        "standalone question that can be understood without the chat history. "
        "If the question is already standalone, return it unchanged. "
        "Do not answer the question.",
    ),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# Answers the user's question from the retrieved PDF context plus chat history.
_QA_PROMPT = ChatPromptTemplate.from_messages([
    (
        "system",
        "You are an assistant for question-answering. "
        "Use the following context to answer. "
        "If the answer is not in the context, say so. Be concise.\n\n"
        "Context:\n{context}",
    ),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])


def pick_db_dir() -> str:
    """Choose the root directory for the persisted vector DB.

    Order of preference:
      1. the ``DB_DIR`` environment variable, if set;
      2. ``/data/vector_db`` when ``/data`` exists and is writable
         (a persistent mount, e.g. on HF Spaces);
      3. ``/tmp/vector_db`` (ephemeral fallback).
    """
    env_dir = os.environ.get("DB_DIR")
    if env_dir:
        return env_dir

    data_dir = "/data"
    if os.path.isdir(data_dir) and os.access(data_dir, os.W_OK):
        return os.path.join(data_dir, "vector_db")

    return "/tmp/vector_db"


DB_DIR = pick_db_dir()


def _fresh_persist_dir() -> str:
    """Wipe everything under DB_DIR and return a new, empty, uniquely named subdir.

    Each upload is written to its own subdirectory instead of deleting and
    recreating DB_DIR itself. Chroma may keep an in-process client per persist
    path, and recreating a just-deleted directory at the same path can leave
    that client broken (e.g. "readonly database" errors on write). A new path
    per upload avoids reusing a deleted location while still discarding the
    previous upload's data.
    """
    os.makedirs(DB_DIR, exist_ok=True)
    with os.scandir(DB_DIR) as it:
        entries = list(it)  # snapshot before deleting
    for entry in entries:
        if entry.is_dir(follow_symlinks=False):
            shutil.rmtree(entry.path, ignore_errors=True)
        else:
            try:
                os.remove(entry.path)
            except OSError:
                pass  # best-effort cleanup, consistent with ignore_errors above
    persist_dir = os.path.join(DB_DIR, f"upload-{uuid.uuid4().hex}")
    os.makedirs(persist_dir)
    return persist_dir


def process_pdf(pdf_file):
    """Index a PDF into a fresh Chroma store and return the QA components.

    Args:
        pdf_file: Value from ``gr.File``. This is either a path (str or
            ``os.PathLike``) or an object exposing the temp-file path as
            ``.name``.

    Returns:
        dict with keys ``"llm"`` (a ChatOpenAI) and ``"retriever"`` (a Chroma
        retriever returning the top ``RETRIEVER_K`` chunks).

    Raises:
        RuntimeError: If OPENAI_API_KEY is not set, or if loading, splitting,
            embedding or persisting the PDF fails (the original error is
            chained as ``__cause__``).
    """
    if not os.getenv("OPENAI_API_KEY"):
        raise RuntimeError(
            "OPENAI_API_KEY is not set. Add it to your environment or as a Secret on HF Spaces."
        )

    try:
        # Depending on the Gradio version, gr.File yields a path string or a
        # tempfile-like object whose .name is the path. Check for a path first:
        # a pathlib.Path's .name is only the basename.
        if isinstance(pdf_file, (str, os.PathLike)):
            file_path = os.fspath(pdf_file)
        else:
            file_path = pdf_file.name

        pages = PyPDFLoader(file_path).load()
        # Scanned or protected PDFs can still yield pages, just with no text.
        if not any((getattr(p, "page_content", "") or "").strip() for p in pages):
            raise ValueError("No text found in PDF (may be scanned or protected).")

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
        )
        chunks = splitter.split_documents(pages)
        if not chunks:
            raise ValueError("Unable to split PDF into chunks (empty/protected PDF).")

        vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=OpenAIEmbeddings(),
            persist_directory=_fresh_persist_dir(),
        )

        llm = ChatOpenAI(model=MODEL, temperature=0.2)
        retriever = vectorstore.as_retriever(search_kwargs={"k": RETRIEVER_K})

        # Store llm + retriever only (no langchain.chains to avoid langchain_core.memory)
        return {"llm": llm, "retriever": retriever}
    except Exception as e:
        raise RuntimeError(f"PDF processing failed: {e}") from e


def upload_pdf(file):
    """Gradio handler: (re)build the global QA components and return a status line.

    Sets the module-level ``chain`` on success and clears it to None when no
    file is given or processing fails.
    """
    global chain
    if file is None:
        chain = None
        return "Please upload a PDF."

    try:
        chain = process_pdf(file)
    except RuntimeError as e:
        chain = None
        err_text = str(e)
        if "OPENAI_API_KEY" in err_text:
            return (
                "Error: OPENAI_API_KEY is not set. Add it to a .env file in this folder "
                "(OPENAI_API_KEY=sk-...) or run: export OPENAI_API_KEY=your-key"
            )
        return f"Error: {err_text}"
    return f"PDF processed. Vector DB at: {DB_DIR}. Ask questions now."


def _content_to_text(content) -> str:
    """Coerce message content to plain text.

    Handles falsy values (-> ""), strings, LangChain multi-part content (a
    list of strings and/or ``{"text": ...}`` blocks), and falls back to
    ``str()`` for anything else.
    """
    if not content:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "".join(texts)
    return str(content)


def _gradio_history_to_langchain(history):
    """Convert Gradio chat history to a list of LangChain messages.

    Accepts the "messages" format (dicts or objects with ``role`` and
    ``content``, e.g. ``gr.ChatMessage``) and the legacy pair format
    ``[user_text, assistant_text]``. Any role other than "user" becomes an
    AIMessage. Unrecognized entries are skipped.
    """
    lc_messages = []
    for m in history or []:
        if isinstance(m, dict):
            turns = [(m.get("role", ""), m.get("content", ""))]
        elif hasattr(m, "role") and hasattr(m, "content"):
            turns = [(m.role, m.content)]
        elif isinstance(m, (list, tuple)):
            # Legacy pair: either slot may be None (e.g. a pending reply).
            turns = [
                (role, text)
                for role, text in zip(("user", "assistant"), m)
                if text is not None
            ]
        else:
            continue

        for role, content in turns:
            text = _content_to_text(content)
            if role == "user":
                lc_messages.append(HumanMessage(content=text))
            else:
                lc_messages.append(AIMessage(content=text))
    return lc_messages


def _get_answer_from_message(msg) -> str:
    """Extract text from an LLM response (AIMessage-like object or plain value)."""
    return _content_to_text(getattr(msg, "content", msg))


def ask_question(message, history):
    """Gradio handler: answer a question about the loaded PDF.

    Args:
        message: The user's question from the input textbox.
        history: Chat history in Gradio "messages" format (list of
            ``{"role", "content"}`` dicts), or None.

    Returns:
        ``(history, history, "")``: the updated history for the Chatbot and
        the State, and an empty string to clear the input textbox.
    """
    history = list(history or [])
    question = (message or "").strip()
    if not question:
        # Nothing to ask; leave the conversation unchanged.
        return history, history, ""

    if chain is None:
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": "Upload the PDF first."})
        return history, history, ""

    # Convert only the prior turns; the prompts receive the new question
    # separately through {input}.
    chat_history_lc = _gradio_history_to_langchain(history)
    llm = chain["llm"]
    retriever = chain["retriever"]

    try:
        # 1) Turn the question + history into a standalone retrieval query.
        #    With no history the question is already standalone, so skip the call.
        if chat_history_lc:
            standalone = (_CONTEXTUALIZE_PROMPT | llm).invoke(
                {"input": question, "chat_history": chat_history_lc}
            )
            query = _get_answer_from_message(standalone).strip() or question
        else:
            query = question

        # 2) Retrieve the most relevant chunks.
        docs = retriever.invoke(query)
        context = "\n\n".join(
            getattr(d, "page_content", "") or str(d) for d in docs
        )

        # 3) Answer using the retrieved context plus the conversation so far.
        response = (_QA_PROMPT | llm).invoke(
            {"input": question, "chat_history": chat_history_lc, "context": context}
        )
        answer = _get_answer_from_message(response) or "No answer found."
    except Exception as e:  # UI boundary: show the failure in the chat, don't crash
        answer = f"Error: {e}"

    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": answer})
    return history, history, ""


with gr.Blocks() as demo:
    gr.Markdown("## Chat with your PDF")
    file_input = gr.File(label="Upload your PDF", file_types=[".pdf"])
    status = gr.Textbox(label="Status", interactive=False)
    chatbot = gr.Chatbot(label="Chat history", type="messages")
    msg = gr.Textbox(label="Ask anything related to the PDF...")
    clear = gr.Button("Clear chat")
    state = gr.State([])

    file_input.change(upload_pdf, inputs=[file_input], outputs=[status])
    msg.submit(ask_question, inputs=[msg, state], outputs=[chatbot, state, msg])
    clear.click(lambda: ([], []), inputs=None, outputs=[chatbot, state])


if __name__ == "__main__":
    demo.launch(inline=False)