stevafernandes committed on
Commit
87ff579
·
verified ·
1 Parent(s): 94f0aa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -26
app.py CHANGED
@@ -2,17 +2,24 @@ import streamlit as st
2
  from PyPDF2 import PdfReader
3
  from io import BytesIO
4
  import os
 
5
 
6
- from langchain.text_splitter import RecursiveCharacterTextSplitter
 
7
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
8
  from langchain_community.vectorstores import FAISS
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain.chains.question_answering import load_qa_chain
11
  from langchain.prompts import PromptTemplate
12
 
13
- # --- Get API key from environment variable (set in Hugging Face Secrets or .env file) ---
 
14
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
15
 
 
 
 
 
16
  def get_pdf_text(pdf_docs):
17
  text = ""
18
  for pdf in pdf_docs:
@@ -30,7 +37,7 @@ def get_text_chunks(text):
30
  def get_vector_store(text_chunks, api_key):
31
  embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
32
  vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
33
- vector_store.save_local("/tmp/faiss_index")
34
 
35
  def get_conversational_chain(api_key):
36
  prompt_template = """
@@ -42,69 +49,119 @@ def get_conversational_chain(api_key):
42
  {question}
43
  Answer:
44
  """
45
- model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0, google_api_key=api_key)
46
  prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
47
  chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
48
  return chain
49
 
50
  def user_input(user_question, api_key):
51
  embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
52
- new_db = FAISS.load_local("/tmp/faiss_index", embeddings, allow_dangerous_deserialization=True)
53
  docs = new_db.similarity_search(user_question)
54
  chain = get_conversational_chain(api_key)
55
  response = chain({"input_documents": docs, "question": user_question}, return_only_outputs=True)
56
  st.write("Reply: ", response["output_text"])
57
 
58
  def main():
59
- st.set_page_config(page_title="Chat PDF")
60
- st.header("Retrieval-Augmented Generation - Gemini 2.0")
61
  st.markdown("---")
62
 
63
- # STEP 1: Use API key from env or ask user
64
  if "api_entered" not in st.session_state:
65
  st.session_state["api_entered"] = False
66
  if "pdf_processed" not in st.session_state:
67
  st.session_state["pdf_processed"] = False
68
 
 
69
  api_key = GOOGLE_API_KEY
70
 
 
71
  if not st.session_state["api_entered"]:
72
  if not api_key:
73
- user_api_key = st.text_input("Enter your Gemini API key", type="password")
74
- if st.button("Continue") and user_api_key:
 
 
75
  st.session_state["user_api_key"] = user_api_key
76
  st.session_state["api_entered"] = True
77
- st.experimental_rerun()
78
  st.stop()
79
  else:
80
  st.session_state["user_api_key"] = api_key
81
  st.session_state["api_entered"] = True
82
- st.experimental_rerun()
83
 
84
  api_key = st.session_state.get("user_api_key", "")
85
 
86
  # STEP 2: Upload PDF(s)
87
  if not st.session_state["pdf_processed"]:
88
- st.subheader("Step 2: Upload your PDF file(s)")
89
- pdf_docs = st.file_uploader("Upload PDF files", accept_multiple_files=True, type=['pdf'])
90
- if st.button("Submit & Process PDFs"):
 
 
 
 
 
 
91
  if pdf_docs:
92
- with st.spinner("Processing..."):
93
- raw_text = get_pdf_text(pdf_docs)
94
- text_chunks = get_text_chunks(raw_text)
95
- get_vector_store(text_chunks, api_key)
96
- st.session_state["pdf_processed"] = True
97
- st.success("PDFs processed! You can now ask questions.")
98
- st.experimental_rerun()
 
 
 
 
 
 
 
 
99
  else:
100
  st.error("Please upload at least one PDF file.")
 
 
 
 
101
  st.stop()
102
 
103
  # STEP 3: Ask questions
104
- st.subheader("Step 3: Ask a question about your PDFs")
105
- user_question = st.text_input("Ask a question")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if user_question:
107
- user_input(user_question, api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  if __name__ == "__main__":
110
- main()
 
2
  from PyPDF2 import PdfReader
3
  from io import BytesIO
4
  import os
5
+ import tempfile
6
 
7
+ # Fixed imports for LangChain
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
10
  from langchain_community.vectorstores import FAISS
11
  from langchain_google_genai import ChatGoogleGenerativeAI
12
  from langchain.chains.question_answering import load_qa_chain
13
  from langchain.prompts import PromptTemplate
14
 
15
+ # --- Get API key from Hugging Face Secrets ---
16
+ # In Hugging Face Spaces, set this in Settings -> Repository secrets
17
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
18
 
19
+ # Use temporary directory for Hugging Face Spaces
20
+ TEMP_DIR = tempfile.gettempdir()
21
+ FAISS_INDEX_PATH = os.path.join(TEMP_DIR, "faiss_index")
22
+
23
  def get_pdf_text(pdf_docs):
24
  text = ""
25
  for pdf in pdf_docs:
 
37
def get_vector_store(text_chunks, api_key):
    """Embed *text_chunks* with Gemini embeddings and persist a FAISS index.

    Args:
        text_chunks: list of text strings to embed (one vector per chunk).
        api_key: Google Generative AI API key.

    Returns:
        The in-memory FAISS vector store (new, backward-compatible: existing
        callers ignore the previously-implicit ``None`` return). The index is
        also saved to ``FAISS_INDEX_PATH`` so ``user_input`` can reload it.
    """
    # Embedding model must match the one used at query time in user_input().
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
    vector_store.save_local(FAISS_INDEX_PATH)
    return vector_store
41
 
42
  def get_conversational_chain(api_key):
43
  prompt_template = """
 
49
  {question}
50
  Answer:
51
  """
52
+ model = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", temperature=0, google_api_key=api_key)
53
  prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
54
  chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
55
  return chain
56
 
57
def user_input(user_question, api_key):
    """Answer *user_question* from the persisted FAISS index and render the reply."""
    # Recreate the exact embedding model used at index-build time so query
    # vectors live in the same space as the stored chunks.
    embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    qa_chain = get_conversational_chain(api_key)

    # The opt-in flag is required because FAISS indexes are pickled on disk;
    # this index was written locally by get_vector_store(), not untrusted input.
    index = FAISS.load_local(FAISS_INDEX_PATH, embedder, allow_dangerous_deserialization=True)
    relevant_docs = index.similarity_search(user_question)

    result = qa_chain({"input_documents": relevant_docs, "question": user_question}, return_only_outputs=True)
    st.write("Reply: ", result["output_text"])
64
 
65
def _ensure_api_key(env_key):
    """STEP 1: resolve the Gemini API key (env secret or user prompt).

    On success sets st.session_state["user_api_key"] / ["api_entered"].
    When no key is available yet, renders the prompt and calls st.stop().
    """
    if st.session_state["api_entered"]:
        return
    if env_key:
        st.session_state["user_api_key"] = env_key
        st.session_state["api_entered"] = True
        return
    st.warning(" Google API Key not found in environment variables.")
    st.info("Please add GOOGLE_API_KEY to your Hugging Face Space secrets or enter it below.")
    user_api_key = st.text_input("Enter your Gemini API key", type="password", help="Get your API key from https://makersuite.google.com/app/apikey")
    if st.button("Continue", type="primary") and user_api_key:
        st.session_state["user_api_key"] = user_api_key
        st.session_state["api_entered"] = True
        st.rerun()
    st.stop()


def _upload_and_index_pdfs(api_key):
    """STEP 2: upload PDFs, extract text, chunk, and build the FAISS index.

    Always ends with st.stop() so the question UI is not rendered until
    st.session_state["pdf_processed"] is True (a rerun re-enters main()).
    """
    st.subheader("📤 Step 1: Upload your PDF file(s)")
    pdf_docs = st.file_uploader(
        "Upload PDF files",
        accept_multiple_files=True,
        type=['pdf'],
        help="Select one or more PDF files to analyze"
    )

    if st.button("Submit & Process PDFs", type="primary", disabled=not pdf_docs):
        if pdf_docs:
            # BUG FIX: st.stop()/st.rerun() work by raising Streamlit
            # control-flow exceptions; the original kept them inside the
            # `except Exception` scope, so a deliberate stop or a successful
            # rerun was reported as " Error processing PDFs: ...". Only the
            # fallible work stays inside the try now.
            extracted_ok = False
            with st.spinner("Processing PDFs... This may take a moment."):
                try:
                    raw_text = get_pdf_text(pdf_docs)
                    if raw_text.strip():
                        text_chunks = get_text_chunks(raw_text)
                        get_vector_store(text_chunks, api_key)
                        extracted_ok = True
                except Exception as e:
                    st.error(f" Error processing PDFs: {str(e)}")
                    st.stop()
            if not extracted_ok:
                st.error(" No text could be extracted from the PDF(s). Please check your files.")
                st.stop()
            st.session_state["pdf_processed"] = True
            st.success(" PDFs processed successfully! You can now ask questions.")
            st.rerun()
        else:
            st.error("Please upload at least one PDF file.")

    if not pdf_docs:
        st.info(" Please upload one or more PDF files to get started.")

    st.stop()


def _question_ui(api_key):
    """STEP 3: ask questions answered from the indexed PDFs via user_input()."""
    st.subheader(" Step 2: Ask questions about your PDFs")

    # Reset button in the narrow right-hand column; the wide left column is
    # unused ("_" replaces the original dead `col1` binding).
    _, reset_col = st.columns([3, 1])
    with reset_col:
        if st.button(" Upload New PDFs"):
            st.session_state["pdf_processed"] = False
            st.rerun()

    user_question = st.text_input(
        "Ask a question about your uploaded PDFs",
        placeholder="e.g., What are the main topics discussed in the document?",
        help="The AI will only answer based on the content of your uploaded PDFs"
    )

    if user_question:
        with st.spinner("Searching for answer..."):
            try:
                user_input(user_question, api_key)
            except Exception as e:
                st.error(f" Error getting answer: {str(e)}")


def _footer():
    """Render the static page footer."""
    st.markdown("---")
    st.markdown(
        """
        <div style='text-align: center'>
        <small>Built with Streamlit, LangChain, and Google Gemini 2.0</small>
        </div>
        """,
        unsafe_allow_html=True
    )


def main():
    """Streamlit entry point: API-key gate, PDF ingestion gate, then Q&A.

    Each gate calls st.stop()/st.rerun() internally, so reaching the
    question UI implies both session-state flags are set.
    """
    st.set_page_config(page_title="Chat PDF", page_icon="📄")
    st.header("📚 RAG Chat with PDF - Gemini 2.0")
    st.markdown("---")

    # Initialize session-state flags on the first run of the script.
    if "api_entered" not in st.session_state:
        st.session_state["api_entered"] = False
    if "pdf_processed" not in st.session_state:
        st.session_state["pdf_processed"] = False

    _ensure_api_key(GOOGLE_API_KEY)
    api_key = st.session_state.get("user_api_key", "")

    if not st.session_state["pdf_processed"]:
        _upload_and_index_pdfs(api_key)  # does not fall through: ends in st.stop()/st.rerun()

    _question_ui(api_key)
    _footer()
165
 
166
# Run the app only when executed directly (not when imported as a module).
if __name__ == "__main__":
    main()