File size: 6,712 Bytes
9edbb70
aca5cc5
9edbb70
 
aca5cc5
9edbb70
 
aca5cc5
 
2e07a52
 
9edbb70
aca5cc5
9edbb70
 
 
2554008
aca5cc5
 
e73bf0f
2e07a52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9edbb70
a64e065
aca5cc5
 
 
 
 
2e07a52
 
 
aca5cc5
 
a64e065
aca5cc5
a64e065
aca5cc5
 
 
a64e065
aca5cc5
a64e065
aca5cc5
 
a64e065
aca5cc5
 
a64e065
aca5cc5
2e07a52
aca5cc5
 
2e07a52
aca5cc5
 
 
 
 
 
 
 
 
 
2e07a52
 
aca5cc5
a64e065
 
9edbb70
aca5cc5
9edbb70
 
 
 
aca5cc5
2e07a52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aca5cc5
9edbb70
a64e065
9edbb70
aca5cc5
 
 
 
2e07a52
 
 
 
aca5cc5
2e07a52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aca5cc5
 
 
 
 
 
3394cd2
aca5cc5
 
9edbb70
aca5cc5
 
a64e065
1795471
9edbb70
aca5cc5
 
a64e065
65c53a0
aca5cc5
a64e065
 
aca5cc5
 
9edbb70
aca5cc5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import shutil
import tempfile

import gradio as gr
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Pull settings (e.g. OPENAI_API_KEY, DB_DIR) from a local .env file; values in
# the file take precedence over variables already set in the environment.
load_dotenv(override=True)

# Chat model used for both question rewriting and answering (a low-cost model).
MODEL = "gpt-4o-mini"

# The {"llm", "retriever"} pair built from the most recent upload, or None.
# It lives at module level, so every browser session shares one PDF at a time;
# this is fine for single-user use only.
chain = None


def pick_db_dir() -> str:
    """Choose the directory where the Chroma vector store is persisted.

    Resolution order:
      1. The ``DB_DIR`` environment variable, when it is set and non-empty.
      2. ``/data/vector_db`` when ``/data`` exists as a writable directory
         (a persistent mount, if the host provides one).
      3. ``/tmp/vector_db`` otherwise; its contents may not survive restarts.
    """
    override = os.environ.get("DB_DIR")
    if override:
        return override

    persistent_root = "/data"
    mount_usable = os.path.isdir(persistent_root) and os.access(persistent_root, os.W_OK)
    if not mount_usable:
        return "/tmp/vector_db"
    return os.path.join(persistent_root, "vector_db")


DB_DIR = pick_db_dir()


def process_pdf(pdf_file):
    try:
        if not os.getenv("OPENAI_API_KEY"):
            raise RuntimeError(
                "OPENAI_API_KEY is not set. Add it to your environment or as a Secret on HF Spaces."
            )

        # Ensure DB dir is writable/exists
        os.makedirs(DB_DIR, exist_ok=True)

        file_path = pdf_file.name  # gr.File gives a temp file with .name path
        loader = PyPDFLoader(file_path)
        pages = loader.load()

        if not pages:
            raise ValueError("No text found in PDF (may be scanned or protected).")

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
        )
        chunks = splitter.split_documents(pages)

        if not chunks:
            raise ValueError("Unable to split PDF into chunks (empty/protected PDF).")

        embeddings = OpenAIEmbeddings()

        # Reset persisted DB each upload (now in a writable location)
        if os.path.exists(DB_DIR):
            shutil.rmtree(DB_DIR, ignore_errors=True)
        os.makedirs(DB_DIR, exist_ok=True)

        vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=embeddings,
            persist_directory=DB_DIR,
        )

        llm = ChatOpenAI(model=MODEL, temperature=0.2)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Store llm + retriever only (no langchain.chains to avoid langchain_core.memory)
        return {"llm": llm, "retriever": retriever}

    except Exception as e:
        raise RuntimeError(f"PDF processing failed: {str(e)}")


def upload_pdf(file):
    """Gradio handler for the file widget: (re)index the PDF and report status.

    On success the module-level ``chain`` holds the result of ``process_pdf``.
    When no file is given or processing raises RuntimeError, it is set to None.
    Returns the text shown in the Status box.
    """
    global chain
    if file is None:
        chain = None
        return "Please upload a PDF."
    try:
        chain = process_pdf(file)
    except RuntimeError as err:
        chain = None
        reason = str(err)
        # A missing key gets a more actionable hint for local runs.
        return (
            "Error: OPENAI_API_KEY is not set. Add it to a .env file in this folder (OPENAI_API_KEY=sk-...) or run: export OPENAI_API_KEY=your-key"
            if "OPENAI_API_KEY" in reason
            else f"Error: {reason}"
        )
    return f"PDF processed. Vector DB at: {DB_DIR}. Ask questions now."


def _gradio_history_to_langchain(history):
    """Convert Gradio chat history to LangChain message list."""
    if not history:
        return []
    lc_messages = []
    for m in history:
        if isinstance(m, dict):
            role, content = m.get("role", ""), m.get("content", "")
        else:
            content = getattr(m, "content", m[0] if len(m) > 0 else "")
            role = getattr(m, "role", m[1] if len(m) > 1 else "assistant")
        if role == "user":
            lc_messages.append(HumanMessage(content=content or ""))
        else:
            lc_messages.append(AIMessage(content=content or ""))
    return lc_messages


def _get_answer_from_message(msg) -> str:
    """Extract text from LLM response (AIMessage or str)."""
    if hasattr(msg, "content"):
        return getattr(msg, "content", "") or ""
    return str(msg) if msg else ""


def ask_question(message, history):
    if chain is None:
        history = history or []
        history.append({"role": "assistant", "content": "Upload the PDF first."})
        return history, history, ""

    chat_history_lc = _gradio_history_to_langchain(history or [])
    llm = chain["llm"]
    retriever = chain["retriever"]

    try:
        # 1) Turn current question + history into a standalone question
        contextualize_prompt = ChatPromptTemplate.from_messages([
            ("system", "Given the chat history and the latest user question, write a single standalone question that can be understood without the chat history. If the question is already standalone, return it unchanged. Do not answer the question."),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        contextualize_chain = contextualize_prompt | llm
        standalone = contextualize_chain.invoke({"input": message, "chat_history": chat_history_lc})
        query = _get_answer_from_message(standalone).strip() or message

        # 2) Retrieve docs
        docs = retriever.invoke(query)
        context = "\n\n".join((getattr(d, "page_content", "") or str(d) for d in docs))

        # 3) Answer with context + history
        qa_prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an assistant for question-answering. Use the following context to answer. If the answer is not in the context, say so. Be concise.\n\nContext:\n{context}"),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])
        qa_chain = qa_prompt | llm
        response = qa_chain.invoke({"input": message, "chat_history": chat_history_lc, "context": context})
        answer = _get_answer_from_message(response) or "No answer found."
    except Exception as e:
        answer = f"Error: {str(e)}"

    history = history or []
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history, history, ""


# UI layout: upload widget + status line, then the chat area.
with gr.Blocks() as demo:
    gr.Markdown("## Chat with your PDF")

    file_input = gr.File(label="Upload your PDF", file_types=[".pdf"])
    status = gr.Textbox(label="Status", interactive=False)

    chatbot = gr.Chatbot(label="Chat history", type="messages")
    msg = gr.Textbox(label="Ask anything related to the PDF...")
    clear = gr.Button("Clear chat")

    # Per-session list of {"role", "content"} dicts mirrored into the Chatbot.
    state = gr.State([])

    # Fires on upload and on removal (removal passes None, which resets the chain).
    file_input.change(upload_pdf, inputs=[file_input], outputs=[status])
    msg.submit(ask_question, inputs=[msg, state], outputs=[chatbot, state, msg])
    clear.click(lambda: ([], []), inputs=None, outputs=[chatbot, state])

# Guarded so that merely importing this module does not start a server;
# running it as a script still launches the app.
if __name__ == "__main__":
    demo.launch(inline=False)