Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| from dotenv import load_dotenv | |
| import gradio as gr | |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
| from langchain_chroma import Chroma | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.messages import HumanMessage, AIMessage | |
# Chat model name handed to ChatOpenAI; picked because it is a low-cost model.
MODEL = "gpt-4o-mini"

# Read a local .env file (if any) into the process environment. With
# override=True, values from the file replace variables that are already set.
load_dotenv(override=True)

# Module-global holder for the current {"llm", "retriever"} pair. It stays None
# until a PDF has been processed. A single global is OK for single-user use.
chain = None
def pick_db_dir() -> str:
    """Choose the directory that will hold the persisted vector store.

    Resolution order:
      1. The ``DB_DIR`` environment variable, when set to a non-empty value.
      2. ``/data/vector_db``, when ``/data`` is an existing, writable directory
         (preferred as the persistent mount).
      3. ``/tmp/vector_db`` as the ephemeral fallback.

    Returns:
        The selected path. Nothing is created on disk by this function.
    """
    override = os.environ.get("DB_DIR")
    if override:
        return override

    persistent_root = "/data"
    if not os.path.isdir(persistent_root) or not os.access(persistent_root, os.W_OK):
        return "/tmp/vector_db"
    return os.path.join(persistent_root, "vector_db")


# Resolved once at import time; the rest of the module reads this constant.
DB_DIR = pick_db_dir()
def process_pdf(pdf_file):
    """Index an uploaded PDF and return the objects needed to answer questions.

    The PDF is loaded page by page, split into 500-character chunks with a
    50-character overlap, embedded with OpenAI embeddings and stored in a
    Chroma collection persisted under ``DB_DIR``. Vectors from a previous
    upload are discarded first.

    Args:
        pdf_file: The uploaded file, given either as a filesystem path
            (``str`` / ``os.PathLike``) or as an object exposing the path via
            its ``.name`` attribute.

    Returns:
        A dict with two keys: ``"llm"`` (a ``ChatOpenAI`` for ``MODEL`` at
        temperature 0.2) and ``"retriever"`` (the vector store retriever,
        configured with ``k=4``).

    Raises:
        RuntimeError: On any failure (missing ``OPENAI_API_KEY``, unreadable
            or empty PDF, embedding/storage errors). The message is prefixed
            with ``"PDF processing failed: "`` and the original exception is
            attached as ``__cause__``.
    """
    try:
        # Checked first so a missing key is reported before any work is done.
        if not os.getenv("OPENAI_API_KEY"):
            raise RuntimeError(
                "OPENAI_API_KEY is not set. Add it to your environment or as a Secret on HF Spaces."
            )

        # Ensure the DB dir exists (and surface permission problems early).
        os.makedirs(DB_DIR, exist_ok=True)

        # Accept a plain path as well as a temp-file wrapper with a `.name` path.
        if isinstance(pdf_file, (str, os.PathLike)):
            file_path = os.fspath(pdf_file)
        else:
            file_path = pdf_file.name

        pages = PyPDFLoader(file_path).load()
        if not pages:
            raise ValueError("No text found in PDF (may be scanned or protected).")

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
        )
        chunks = splitter.split_documents(pages)
        if not chunks:
            raise ValueError("Unable to split PDF into chunks (empty/protected PDF).")

        embeddings = OpenAIEmbeddings()

        # Reset the persisted data for each upload. This goes through Chroma
        # itself rather than deleting DB_DIR from disk: a client opened by an
        # earlier upload may still be alive in this process, and removing the
        # files underneath it can make later writes to the same path fail.
        Chroma(
            persist_directory=DB_DIR,
            embedding_function=embeddings,
        ).delete_collection()

        vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=embeddings,
            persist_directory=DB_DIR,
        )

        llm = ChatOpenAI(model=MODEL, temperature=0.2)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Store llm + retriever only (no langchain.chains to avoid langchain_core.memory).
        return {"llm": llm, "retriever": retriever}
    except Exception as e:
        # Single boundary: callers only need to handle RuntimeError. Chain the
        # cause explicitly so the original traceback stays attached.
        raise RuntimeError(f"PDF processing failed: {str(e)}") from e
def upload_pdf(file):
    """Handle a change of the upload widget and refresh the global ``chain``.

    Args:
        file: The value delivered by the file widget, or ``None`` when the
            widget has been cleared.

    Returns:
        A status string: a prompt to upload, a success message naming
        ``DB_DIR``, or an ``"Error: ..."`` description.
    """
    global chain

    if file is None:
        chain = None
        return "Please upload a PDF."

    try:
        built = process_pdf(file)
    except RuntimeError as err:
        # A failed upload must not leave a stale chain answering questions.
        chain = None
        detail = str(err)
        if "OPENAI_API_KEY" not in detail:
            return f"Error: {detail}"
        return "Error: OPENAI_API_KEY is not set. Add it to a .env file in this folder (OPENAI_API_KEY=sk-...) or run: export OPENAI_API_KEY=your-key"

    chain = built
    return f"PDF processed. Vector DB at: {DB_DIR}. Ask questions now."
| def _gradio_history_to_langchain(history): | |
| """Convert Gradio chat history to LangChain message list.""" | |
| if not history: | |
| return [] | |
| lc_messages = [] | |
| for m in history: | |
| if isinstance(m, dict): | |
| role, content = m.get("role", ""), m.get("content", "") | |
| else: | |
| content = getattr(m, "content", m[0] if len(m) > 0 else "") | |
| role = getattr(m, "role", m[1] if len(m) > 1 else "assistant") | |
| if role == "user": | |
| lc_messages.append(HumanMessage(content=content or "")) | |
| else: | |
| lc_messages.append(AIMessage(content=content or "")) | |
| return lc_messages | |
| def _get_answer_from_message(msg) -> str: | |
| """Extract text from LLM response (AIMessage or str).""" | |
| if hasattr(msg, "content"): | |
| return getattr(msg, "content", "") or "" | |
| return str(msg) if msg else "" | |
def _standalone_query(llm, message, chat_history_lc):
    """Rewrite ``message`` into a self-contained retrieval query using prior turns."""
    if not chat_history_lc:
        # No earlier turns to resolve references against, so the question is
        # already standalone; skip the extra LLM round-trip.
        return message
    contextualize_prompt = ChatPromptTemplate.from_messages([
        ("system", "Given the chat history and the latest user question, write a single standalone question that can be understood without the chat history. If the question is already standalone, return it unchanged. Do not answer the question."),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    standalone = (contextualize_prompt | llm).invoke(
        {"input": message, "chat_history": chat_history_lc}
    )
    # Fall back to the raw question if the model returned nothing usable.
    return _get_answer_from_message(standalone).strip() or message


def _answer_with_context(llm, message, chat_history_lc, context):
    """Ask the LLM to answer ``message`` from ``context`` and the prior turns."""
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an assistant for question-answering. Use the following context to answer. If the answer is not in the context, say so. Be concise.\n\nContext:\n{context}"),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    response = (qa_prompt | llm).invoke(
        {"input": message, "chat_history": chat_history_lc, "context": context}
    )
    return _get_answer_from_message(response) or "No answer found."


def ask_question(message, history):
    """Answer one chat message against the currently indexed PDF.

    Args:
        message: The text the user submitted.
        history: The stored chat history (list of ``{"role", "content"}``
            dicts), or ``None``. A non-empty list is extended in place.

    Returns:
        A 3-tuple ``(history, history, "")``: the updated transcript for the
        chat display, the same list for the stored state, and an empty string
        that clears the input box.

        * Blank or whitespace-only input returns the transcript unchanged.
        * If no PDF has been processed yet, the user's message is recorded
          followed by the reply ``"Upload the PDF first."``.
        * Any exception raised while querying is reported in the transcript
          as ``"Error: ..."`` instead of propagating.
    """
    history = history or []

    # Nothing was asked: do not spend LLM calls or add empty chat bubbles.
    if not (message or "").strip():
        return history, history, ""

    if chain is None:
        # Record the question too, so the transcript matches the normal path.
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Upload the PDF first."})
        return history, history, ""

    # Built before the current turn is appended: it holds prior turns only.
    chat_history_lc = _gradio_history_to_langchain(history)
    llm = chain["llm"]
    retriever = chain["retriever"]

    try:
        # 1) Turn current question + history into a standalone question.
        query = _standalone_query(llm, message, chat_history_lc)
        # 2) Retrieve the chunks most relevant to that question.
        docs = retriever.invoke(query)
        context = "\n\n".join(getattr(d, "page_content", "") or str(d) for d in docs)
        # 3) Answer with context + history.
        answer = _answer_with_context(llm, message, chat_history_lc, context)
    except Exception as e:
        # UI boundary: show the failure in the chat rather than crash the handler.
        answer = f"Error: {str(e)}"

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history, history, ""
def _reset_chat():
    """Return empty values for the chat display and the stored history."""
    return [], []


with gr.Blocks() as demo:
    gr.Markdown("## Chat with your PDF")

    # Components, in the order they are added to the page.
    file_input = gr.File(label="Upload your PDF", file_types=[".pdf"])
    status = gr.Textbox(label="Status", interactive=False)
    chatbot = gr.Chatbot(label="Chat history", type="messages")
    msg = gr.Textbox(label="Ask anything related to the PDF...")
    clear = gr.Button("Clear chat")
    # Stored history, kept alongside the visible chatbot value.
    state = gr.State([])

    # Event wiring.
    # A new (or cleared) upload rebuilds the index and reports in the status box.
    file_input.change(fn=upload_pdf, inputs=[file_input], outputs=[status])
    # Submitting a question updates the chat and history and clears the input box.
    msg.submit(fn=ask_question, inputs=[msg, state], outputs=[chatbot, state, msg])
    # The clear button empties both the visible chat and the stored history.
    clear.click(fn=_reset_chat, inputs=None, outputs=[chatbot, state])

demo.launch(inline=False)