Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from PyPDF2 import PdfReader | |
| from io import BytesIO | |
| import os | |
| import tempfile | |
| # Fixed imports for LangChain | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain.chains.question_answering import load_qa_chain | |
| from langchain.prompts import PromptTemplate | |
# --- Configuration ----------------------------------------------------------
# The Gemini API key comes from the GOOGLE_API_KEY environment variable (on
# Hugging Face Spaces: Settings -> Repository secrets). An empty string means
# "not configured"; in that case the UI asks the user to type a key instead.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")

# Generated artifacts (the FAISS index) live under the system temp directory.
TEMP_DIR = tempfile.gettempdir()
FAISS_INDEX_PATH = os.path.join(TEMP_DIR, "faiss_index")
def get_pdf_text(pdf_docs):
    """Extract the text of every page of every PDF in *pdf_docs*.

    Args:
        pdf_docs: Iterable of binary file-like objects (e.g. the Streamlit
            uploads passed in by ``main``); each one is read in full with
            ``.read()``.

    Returns:
        The extracted text of all pages, joined with newlines. Pages for
        which ``extract_text()`` yields nothing are skipped. Returns an empty
        string when *pdf_docs* is empty or no page has extractable text.
    """
    page_texts = []
    for pdf in pdf_docs:
        # Buffer the whole upload in memory so PdfReader gets a seekable stream.
        pdf_reader = PdfReader(BytesIO(pdf.read()))
        for page in pdf_reader.pages:
            page_text = page.extract_text()
            if page_text:
                page_texts.append(page_text)
    # Separate pages with a newline. Plain concatenation could fuse the last
    # word of one page onto the first word of the next. The text splitter
    # also prefers to break on newlines.
    return "\n".join(page_texts)
def get_text_chunks(text, *, chunk_size=10000, chunk_overlap=1000):
    """Split *text* into overlapping chunks suitable for embedding.

    Args:
        text: The full document text.
        chunk_size: Target maximum chunk length, measured with the splitter's
            default length function (characters).
        chunk_overlap: Number of characters shared by consecutive chunks, so
            content straddling a boundary appears in both.

    Returns:
        A list of chunk strings.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_text(text)
# Embedding model used both to build the index and to embed queries; the two
# sides must agree or similarity search is meaningless.
EMBEDDING_MODEL = "models/embedding-001"


def _get_embeddings(api_key):
    """Build the Gemini embedding client shared by indexing and querying."""
    return GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, google_api_key=api_key)


def _session_index_path():
    """Return a FAISS index directory private to the current Streamlit session.

    A single module-level path would be shared by every session served by
    this process, letting one visitor overwrite (and query) another's
    documents. Each session therefore gets its own ``mkdtemp`` directory,
    remembered in ``st.session_state``.
    NOTE(review): these directories are never removed when a session ends.
    """
    path = st.session_state.get("faiss_index_path")
    if not path:
        path = tempfile.mkdtemp(prefix="faiss_index_", dir=TEMP_DIR)
        st.session_state["faiss_index_path"] = path
    return path


def get_vector_store(text_chunks, api_key):
    """Embed *text_chunks* with Gemini and save a FAISS index for this session.

    Args:
        text_chunks: Strings to index (``main`` only calls this after checking
            that some text was extracted).
        api_key: Google Generative AI API key.
    """
    vector_store = FAISS.from_texts(text_chunks, embedding=_get_embeddings(api_key))
    vector_store.save_local(_session_index_path())


def get_conversational_chain(api_key):
    """Build a "stuff" QA chain that answers only from the supplied context.

    Args:
        api_key: Google Generative AI API key.

    Returns:
        A LangChain question-answering chain expecting ``input_documents`` and
        ``question`` inputs and producing ``output_text``.
    """
    prompt_template = """
You are a helpful assistant that only answers based on the context provided from the PDF documents.
Do not use any external knowledge or assumptions. If the answer is not found in the context below, reply with "I don't know."
Context:
{context}
Question:
{question}
Answer:
"""
    # temperature=0 keeps answers as deterministic as the model allows.
    model = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", temperature=0, google_api_key=api_key)
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    return load_qa_chain(model, chain_type="stuff", prompt=prompt)


def user_input(user_question, api_key):
    """Answer *user_question* from this session's FAISS index and display it.

    Args:
        user_question: The question typed by the user.
        api_key: Google Generative AI API key.

    Raises:
        Whatever FAISS/LangChain raise, e.g. when no index has been saved yet
        for this session; ``main`` catches and reports these.
    """
    db = FAISS.load_local(
        _session_index_path(),
        _get_embeddings(api_key),
        # The index was pickled by get_vector_store into this session's own
        # mkdtemp directory, so it is not untrusted input.
        allow_dangerous_deserialization=True,
    )
    docs = db.similarity_search(user_question)
    chain = get_conversational_chain(api_key)
    # Chain.__call__ is deprecated; invoke returns the same output mapping.
    response = chain.invoke({"input_documents": docs, "question": user_question})
    st.write("Reply: ", response["output_text"])
def _init_session_state():
    """Seed the session-state flags that drive the step-by-step UI."""
    for key in ("api_entered", "pdf_processed"):
        if key not in st.session_state:
            st.session_state[key] = False


def _resolve_api_key():
    """Return the Gemini API key, prompting for one when the env has none.

    Halts the current script run (``st.stop``) until a key is available.
    """
    if not st.session_state["api_entered"]:
        if GOOGLE_API_KEY:
            st.session_state["user_api_key"] = GOOGLE_API_KEY
            st.session_state["api_entered"] = True
        else:
            st.warning(" Google API Key not found in environment variables.")
            st.info("Please add GOOGLE_API_KEY to your Hugging Face Space secrets or enter it below.")
            user_api_key = st.text_input(
                "Enter your Gemini API key",
                type="password",
                help="Get your API key from https://makersuite.google.com/app/apikey",
            )
            # Pasted keys often carry stray whitespace; a key never contains any.
            user_api_key = user_api_key.strip()
            # The button is evaluated first so it is always rendered.
            if st.button("Continue", type="primary") and user_api_key:
                st.session_state["user_api_key"] = user_api_key
                st.session_state["api_entered"] = True
                st.rerun()
            st.stop()
    return st.session_state.get("user_api_key", "")


def _process_pdfs(pdf_docs, api_key):
    """Extract, chunk and index *pdf_docs*, reporting failures in the UI.

    Returns:
        True when the index was built, False otherwise.
    """
    with st.spinner("Processing PDFs... This may take a moment."):
        try:
            raw_text = get_pdf_text(pdf_docs)
            if not raw_text.strip():
                st.error(" No text could be extracted from the PDF(s). Please check your files.")
                return False
            get_vector_store(get_text_chunks(raw_text), api_key)
        except Exception as e:  # UI boundary: surface any failure to the user
            st.error(f" Error processing PDFs: {str(e)}")
            return False
    return True


def _render_upload_step(api_key):
    """Render step 1 (upload + process). Always ends the current script run."""
    st.subheader(" Step 1: Upload your PDF file(s)")
    pdf_docs = st.file_uploader(
        "Upload PDF files",
        accept_multiple_files=True,
        type=['pdf'],
        help="Select one or more PDF files to analyze",
    )
    # The button is disabled without uploads, so a click implies pdf_docs.
    if st.button("Submit & Process PDFs", type="primary", disabled=not pdf_docs):
        # st.rerun/st.stop are called outside any try block so their
        # control-flow exceptions can never be caught by error handling.
        if _process_pdfs(pdf_docs, api_key):
            st.session_state["pdf_processed"] = True
            # Shown after the rerun; a message written now would be discarded.
            st.session_state["pdf_success"] = True
            st.rerun()
        st.stop()
    if not pdf_docs:
        st.info(" Please upload one or more PDF files to get started.")
    st.stop()


def _render_question_step(api_key):
    """Render step 2: the reset button, the question box and the answer."""
    st.subheader(" Step 2: Ask questions about your PDFs")
    if st.session_state.pop("pdf_success", False):
        st.success(" PDFs processed successfully! You can now ask questions.")
    _, reset_col = st.columns([3, 1])
    with reset_col:
        if st.button(" Upload New PDFs"):
            st.session_state["pdf_processed"] = False
            st.rerun()
    user_question = st.text_input(
        "Ask a question about your uploaded PDFs",
        placeholder="e.g., What are the main topics discussed in the document?",
        help="The AI will only answer based on the content of your uploaded PDFs",
    )
    if user_question:
        with st.spinner("Searching for answer..."):
            try:
                user_input(user_question, api_key)
            except Exception as e:  # UI boundary: surface any failure to the user
                st.error(f" Error getting answer: {str(e)}")


def _render_footer():
    """Render the centered page footer."""
    st.markdown("---")
    st.markdown(
        """
<div style='text-align: center'>
<small></small>
</div>
""",
        unsafe_allow_html=True,
    )


def main():
    """Streamlit entry point: API key -> PDF upload/indexing -> Q&A."""
    st.set_page_config(page_title="Chat PDF", page_icon="")
    st.header("Retrieval-Augmented Generation - Gemini 2.0")
    st.markdown("---")
    _init_session_state()
    api_key = _resolve_api_key()
    if not st.session_state["pdf_processed"]:
        _render_upload_step(api_key)  # always halts this run
    _render_question_step(api_key)
    _render_footer()
# Run the app when this file is executed as the main script (Streamlit also
# executes the app file under the name "__main__").
if __name__ == "__main__":
    main()