stevafernandes committed on
Commit
d96618a
·
verified ·
1 Parent(s): cd18249

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -59
app.py CHANGED
@@ -1,72 +1,86 @@
1
- import os
2
- import asyncio
3
- import gradio as gr
4
  from PyPDF2 import PdfReader
 
5
 
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
8
  from langchain_community.vectorstores import FAISS
9
- from langchain.chains import RetrievalQA
 
 
10
 
11
- # ---- ensure asyncio loop ----
12
- try:
13
- asyncio.get_running_loop()
14
- except RuntimeError:
15
- asyncio.set_event_loop(asyncio.new_event_loop())
16
 
17
- # load key
18
- API_KEY = os.getenv("GOOGLE_API_KEY", "").strip()
19
- if not API_KEY:
20
- raise RuntimeError("Set the GOOGLE_API_KEY env var")
21
-
22
- # 1) build FAISS index over librarianship.pdf
23
- vector_store = None
24
- def build_index():
25
- global vector_store
26
- reader = PdfReader("librarianship.pdf")
27
- full_text = ""
28
  for page in reader.pages:
29
- txt = page.extract_text() or ""
30
- full_text += txt + "\n"
31
- splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
32
- chunks = splitter.split_text(full_text)
33
- embeds = GoogleGenerativeAIEmbeddings(
34
- model="models/embedding-001", google_api_key=API_KEY
35
- )
36
- vector_store = FAISS.from_texts(chunks, embedding=embeds)
37
- print(f"Indexed {len(chunks)} chunks from librarianship.pdf")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- build_index()
 
 
 
40
 
41
- # 2) set up a RetrievalQA chain
42
- llm = ChatGoogleGenerativeAI(
43
- model="gemini-2.0-flash-exp",
44
- temperature=0,
45
- google_api_key=API_KEY
46
- )
47
- qa = RetrievalQA.from_chain_type(
48
- llm=llm,
49
- chain_type="stuff",
50
- retriever=vector_store.as_retriever()
51
- )
52
 
53
- # 3) Gradio interface
54
- def answer(question, chat_history):
55
- if not question.strip():
56
- return chat_history, ""
57
- result = qa.run(question)
58
- chat_history.append({"role": "user", "content": question})
59
- chat_history.append({"role": "assistant", "content": result})
60
- return chat_history, ""
61
 
62
- with gr.Blocks() as demo:
63
- gr.Markdown("## 📚 Chat over **librarianship.pdf** with Gemini AI")
64
- chatbot = gr.Chatbot(type="messages")
65
- user_input = gr.Textbox(placeholder="Ask anything about librarianship…")
66
- user_input.submit(answer, [user_input, chatbot], [chatbot, user_input])
67
 
68
  if __name__ == "__main__":
69
- demo.launch(
70
- server_name="0.0.0.0",
71
- server_port=int(os.environ.get("PORT", 7860))
72
- )
 
1
+ import streamlit as st
 
 
2
  from PyPDF2 import PdfReader
3
+ import os
4
 
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain.chains.question_answering import load_qa_chain
10
+ from langchain.prompts import PromptTemplate
11
 
12
+ # --- Get API key from environment variable (set in Hugging Face Secrets or .env file) ---
13
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
14
+ PDF_PATH = "librarianship.pdf"
15
+ INDEX_PATH = "/tmp/faiss_index"
 
16
 
17
def get_pdf_text(pdf_path):
    """Return the concatenated text of every page in the PDF at *pdf_path*.

    Pages for which PyPDF2 extracts nothing (scanned images, empty pages —
    `extract_text()` returns None/"" for those) contribute no text.
    """
    reader = PdfReader(pdf_path)
    # Join at C speed instead of repeated `+=` on a growing string.
    return "".join(page.extract_text() or "" for page in reader.pages)
25
+
26
def get_text_chunks(text):
    """Split *text* into overlapping chunks sized for embedding.

    10k-character chunks with 1k overlap so context is not lost at
    chunk boundaries.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
    return splitter.split_text(text)
29
+
30
def build_and_save_vector_store(text_chunks, api_key):
    """Embed *text_chunks* with Gemini embeddings and persist a FAISS index.

    The index is written to the module-level INDEX_PATH so later runs can
    reload it instead of re-embedding the PDF.
    """
    embedder = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001", google_api_key=api_key
    )
    FAISS.from_texts(text_chunks, embedding=embedder).save_local(INDEX_PATH)
34
+
35
def load_vector_store(api_key):
    """Reload the FAISS index persisted at INDEX_PATH.

    Must use the same embedding model the index was built with.
    `allow_dangerous_deserialization=True` is required because the index
    is pickle-backed; safe here since we only load our own local file.
    """
    embedder = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001", google_api_key=api_key
    )
    store = FAISS.load_local(INDEX_PATH, embedder, allow_dangerous_deserialization=True)
    return store
38
+
39
def get_conversational_chain(api_key):
    # Build a "stuff"-type QA chain: all retrieved documents are stuffed
    # into a single prompt alongside the user's question.
    # {context} and {question} are filled in by the chain at call time.
    prompt_template = """
    You are a helpful assistant that only answers based on the context provided from the PDF document.
    Do not use any external knowledge or assumptions. If the answer is not found in the context below, reply with "I don't know."
    Context:
    {context}
    Question:
    {question}
    Answer:
    """
    # temperature=0 for deterministic, grounded answers.
    model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0, google_api_key=api_key)
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
    return chain
53
+
54
def user_input(user_question, api_key):
    """Answer *user_question* from the persisted index and render the reply.

    Retrieves the most similar chunks from the FAISS store, runs them
    through the QA chain, and writes the model's answer to the Streamlit page.
    """
    index = load_vector_store(api_key)
    matching_docs = index.similarity_search(user_question)
    qa_chain = get_conversational_chain(api_key)
    result = qa_chain(
        {"input_documents": matching_docs, "question": user_question},
        return_only_outputs=True,
    )
    st.write("**Reply:**", result["output_text"])
60
 
61
def main():
    """Streamlit entry point.

    Verifies the API key, builds the FAISS index on first run (persisted
    to INDEX_PATH so reruns reuse it), then serves a single-question QA box.
    """
    st.set_page_config(page_title="Chat librarianship.pdf")
    st.header("RAG: Chat with librarianship.pdf using Gemini 2.0")
    st.markdown("---")

    # Ensure API key is present before doing any work.
    if not GOOGLE_API_KEY:
        st.error("Please set the GOOGLE_API_KEY environment variable in your Hugging Face Space secrets or .env file.")
        st.stop()

    # Build the FAISS index only if it is not already persisted.
    # BUG FIX: FAISS.save_local(INDEX_PATH) writes a *directory* containing
    # index.faiss and index.pkl — no "<INDEX_PATH>.index" file ever exists,
    # so the old check `os.path.exists(INDEX_PATH + ".index")` was always
    # False and the PDF was re-embedded on every Streamlit rerun (slow and
    # costly). Check for the actual saved artifact instead.
    if not os.path.exists(os.path.join(INDEX_PATH, "index.faiss")):
        with st.spinner(f"Indexing {PDF_PATH}..."):
            raw_text = get_pdf_text(PDF_PATH)
            text_chunks = get_text_chunks(raw_text)
            build_and_save_vector_store(text_chunks, GOOGLE_API_KEY)
        st.success(f"Indexed {PDF_PATH}. You can now ask questions.")

    # Simple chat UI: one text box, answer rendered below on submit.
    st.subheader("Ask a question about librarianship.pdf")
    user_question = st.text_input("Ask a question", key="user_question")
    if user_question:
        user_input(user_question, GOOGLE_API_KEY)


if __name__ == "__main__":
    main()