File size: 4,618 Bytes
f959683
 
 
 
 
 
ebc3422
 
1ac30c6
ebc3422
49c5632
b8ba0c8
d96618a
b8ba0c8
d96618a
 
 
b8ba0c8
b7c855c
d96618a
deda8cf
b7c855c
ed4265b
b7c855c
 
 
 
 
 
ed4265b
d96618a
ed4265b
 
 
d96618a
b7c855c
ed4265b
d96618a
b7c855c
d96618a
 
ed4265b
b7c855c
ed4265b
 
 
 
 
 
 
 
 
 
 
 
 
b7c855c
 
 
ed4265b
 
b7c855c
ebc3422
d96618a
b7c855c
 
d96618a
deda8cf
b7c855c
 
 
 
 
deda8cf
b7c855c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deda8cf
b7c855c
 
 
ed4265b
b7c855c
b8ba0c8
 
d96618a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import asyncio
# Make sure this thread has an event loop set before the imports below run.
# asyncio.get_running_loop() raises RuntimeError when no loop is running; in
# that case a fresh loop is created and installed as the current one.
# NOTE(review): presumably one of the libraries imported below expects a
# current event loop to exist — confirm which one before removing this.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())

import streamlit as st
from PyPDF2 import PdfReader
from io import BytesIO
import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate

# --- Gemini API key, read from the GOOGLE_API_KEY environment variable
# (e.g. set in Hugging Face Secrets or a .env file); "" when it is unset ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")

def get_pdf_text(pdf_docs):
    """Return the text of every page of the given PDFs as one string.

    Args:
        pdf_docs: Iterable of file-like objects exposing ``read()``; each one
            is read fully into memory and parsed with ``PdfReader``.

    Returns:
        The extracted page texts joined back-to-back (no separator), in
        document order then page order. Pages for which ``extract_text()``
        yields nothing (``None`` or ``""``) are skipped.
    """
    extracted = (
        page.extract_text()
        for document in pdf_docs
        for page in PdfReader(BytesIO(document.read())).pages
    )
    return "".join(piece for piece in extracted if piece)

def get_text_chunks(text, chunk_size=10000, chunk_overlap=1000):
    """Split ``text`` into overlapping chunks for embedding.

    Args:
        text: The full text to split.
        chunk_size: Maximum chunk size handed to
            ``RecursiveCharacterTextSplitter``. Defaults to 10000, the value
            this app has always used.
        chunk_overlap: Overlap between consecutive chunks handed to
            ``RecursiveCharacterTextSplitter``. Defaults to 1000.

    Returns:
        Whatever ``RecursiveCharacterTextSplitter.split_text`` returns for
        ``text`` (the list of chunk strings).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_text(text)

def get_vector_store(text_chunks, api_key):
    """Embed the chunks, build a FAISS index and save it to ``/tmp/faiss_index``.

    Args:
        text_chunks: Non-empty sequence of text chunks to index.
        api_key: Google API key passed to ``GoogleGenerativeAIEmbeddings``.

    Raises:
        ValueError: If ``text_chunks`` is empty (for example when the uploaded
            PDFs contained no extractable text).
    """
    # Fail early with a clear message: there is nothing to embed, and handing
    # an empty list to FAISS.from_texts is expected to fail with an opaque
    # error from inside the library instead.
    if not text_chunks:
        raise ValueError(
            "No text chunks to index; the PDF(s) may contain no extractable text."
        )

    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
    # The index is persisted to a fixed path; user_input() loads it from the same place.
    vector_store.save_local("/tmp/faiss_index")

def get_conversational_chain(api_key):
    """Build the "stuff" question-answering chain used to answer from PDF context.

    Args:
        api_key: Google API key passed to ``ChatGoogleGenerativeAI``.

    Returns:
        The chain produced by ``load_qa_chain`` for a ``gemini-2.0-flash``
        model at temperature 0, using a prompt that tells the model to answer
        only from the supplied context and to reply "I don't know." otherwise.
    """
    # The prompt text is sent to the model verbatim, including its indentation.
    template_text = """
    You are a helpful assistant that only answers based on the context provided from the PDF documents.
    Do not use any external knowledge or assumptions. If the answer is not found in the context below, reply with "I don't know."
    Context:
    {context}
    Question:
    {question}
    Answer:
    """
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash",
        temperature=0,
        google_api_key=api_key,
    )
    qa_prompt = PromptTemplate(
        template=template_text,
        input_variables=["context", "question"],
    )
    return load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)

def user_input(user_question, api_key):
    """Answer ``user_question`` from the saved FAISS index and render the reply.

    Loads the index stored at ``/tmp/faiss_index``, retrieves the documents
    most similar to the question, runs them through the QA chain and writes
    the chain's ``output_text`` to the Streamlit page.

    Args:
        user_question: The question typed by the user.
        api_key: Google API key used for both the embeddings and the chat model.
    """
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    # SECURITY: allow_dangerous_deserialization=True lets load_local unpickle
    # the stored index. That is only safe while /tmp/faiss_index is written
    # exclusively by this app; never point it at an untrusted file.
    index = FAISS.load_local("/tmp/faiss_index", embeddings, allow_dangerous_deserialization=True)
    docs = index.similarity_search(user_question)
    chain = get_conversational_chain(api_key)
    # Use invoke(): calling the chain object directly is deprecated in LangChain.
    response = chain.invoke(
        {"input_documents": docs, "question": user_question},
        return_only_outputs=True,
    )
    st.write("Reply: ", response["output_text"])

def _rerun():
    """Restart the script run using whichever rerun API this Streamlit offers."""
    # st.experimental_rerun is deprecated in favor of st.rerun and is missing
    # from newer releases; only fall back to it on installs without st.rerun.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def main():
    """Render the three-step Streamlit app: API key, PDF upload, questions.

    Progress is tracked in ``st.session_state``:
      * ``api_entered``   - an API key has been obtained,
      * ``user_api_key``  - the key in use (from the environment or the user),
      * ``pdf_processed`` - the PDFs have been indexed.
    Each unfinished step ends with ``st.stop()`` so later steps are not drawn.
    """
    st.set_page_config(page_title="Chat PDF")
    st.header("Retrieval-Augmented Generation - Gemini 2.0")
    st.markdown("---")

    if "api_entered" not in st.session_state:
        st.session_state["api_entered"] = False
    if "pdf_processed" not in st.session_state:
        st.session_state["pdf_processed"] = False

    # STEP 1: Use API key from env or ask user
    if not st.session_state["api_entered"]:
        if GOOGLE_API_KEY:
            # Key supplied by the environment: record it and carry on in the
            # same run (no rerun is needed, nothing has been drawn for this step).
            st.session_state["user_api_key"] = GOOGLE_API_KEY
            st.session_state["api_entered"] = True
        else:
            user_api_key = st.text_input("Enter your Gemini API key", type="password")
            if st.button("Continue") and user_api_key:
                st.session_state["user_api_key"] = user_api_key
                st.session_state["api_entered"] = True
                # Rerun so the key prompt disappears and step 2 is shown.
                _rerun()
            st.stop()

    api_key = st.session_state.get("user_api_key", "")

    # STEP 2: Upload PDF(s)
    if not st.session_state["pdf_processed"]:
        st.subheader("Step 2: Upload your PDF file(s)")
        pdf_docs = st.file_uploader("Upload PDF files", accept_multiple_files=True, type=['pdf'])
        if st.button("Submit & Process PDFs"):
            if not pdf_docs:
                st.error("Please upload at least one PDF file.")
            else:
                with st.spinner("Processing..."):
                    raw_text = get_pdf_text(pdf_docs)
                    # Skip chunking/indexing when nothing was extracted
                    # (e.g. scanned, image-only PDFs).
                    text_chunks = get_text_chunks(raw_text) if raw_text.strip() else []
                    if text_chunks:
                        get_vector_store(text_chunks, api_key)
                if text_chunks:
                    st.session_state["pdf_processed"] = True
                    st.success("PDFs processed! You can now ask questions.")
                    _rerun()
                else:
                    st.error("No extractable text was found in the uploaded PDF file(s).")
        st.stop()

    # STEP 3: Ask questions
    st.subheader("Step 3: Ask a question about your PDFs")
    user_question = st.text_input("Ask a question")
    if user_question:
        user_input(user_question, api_key)

# Build the page only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()