# chat.pdf / app.py
# Author: dimoZ — commit 8d6ca2f (verified): "Update app.py"
import os
import streamlit as st
from PyPDF2 import PdfReader
from docx import Document
import google.generativeai as genai
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from dotenv import load_dotenv
from fuzzywuzzy import process
import base64
from io import BytesIO
# Load environment variables from the local .env file (expects GOOGLE_API_KEY).
load_dotenv()
google_api_key = os.getenv("GOOGLE_API_KEY")
if google_api_key is None:
    # Surface the misconfiguration in the UI and halt this script run:
    # the original fell through and called genai.configure(api_key=None),
    # which only defers the failure to the first API call with a far less
    # actionable error message.
    st.error("GOOGLE_API_KEY is not set. Please set it in the .env file.")
    st.stop()
# Configure the Gemini API client with the loaded key.
genai.configure(api_key=google_api_key)
# Global variables
# Chat history must live in st.session_state because Streamlit re-executes the
# whole script on every widget interaction; this guard initializes the list
# exactly once per browser session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# List of predefined questions for suggestions
# Fuzzy-matched against the user's partial input in main() and offered as
# clickable suggestion buttons.
suggested_questions = [
    "What are the key achievements mentioned in the report?",
    "What is the focus of the Agri-Innovation Hub (AIH) established by PJTSAU?",
    "What are the main objectives of NABARD’s Livelihood and Enterprise Development Programme (LEDP)?",
    "Which award did the Veerapandy Kalanjia Jeevidam Producer Company Limited (VKJPCL) win?",
    "What is the target number of households surveyed under NABARD’s NAFIS 2.0 initiative?",
    "How many climate-related projects has NABARD facilitated with grant and loan assistance?"
]
# Function to extract text from PDF
def extract_text_from_pdf(pdf_docs):
    """Concatenate the text of every page of every uploaded PDF.

    Args:
        pdf_docs: iterable of file-like objects (e.g. Streamlit UploadedFile)
            readable by PyPDF2's ``PdfReader``.

    Returns:
        A single string containing the extracted text of all pages, in order.
    """
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() returns None for pages with no extractable text
            # (e.g. scanned images); guard against TypeError on concatenation.
            text += page.extract_text() or ""
    return text
# Function to extract text from .docx
def extract_text_from_docx(docx_docs):
    """Extract paragraph text, tables and embedded images from .docx files.

    Args:
        docx_docs: iterable of file-like objects readable by python-docx.

    Returns:
        Tuple ``(text, tables, images)``:
        - text: all paragraph text joined with newlines,
        - tables: one string per table, each row's cells joined with " | ",
        - images: base64 data-URI strings for each embedded image.
    """
    text = ""
    tables = []
    images = []
    for doc in docx_docs:
        document = Document(doc)
        # Extract paragraph text
        for para in document.paragraphs:
            text += para.text + "\n"
        # Extract tables: flatten each row to "cell | cell | ..." lines
        for table in document.tables:
            table_text = ""
            for row in table.rows:
                row_text = [cell.text for cell in row.cells]
                table_text += " | ".join(row_text) + "\n"
            tables.append(table_text)
        # Extract images (figures): embedded media live in the document
        # part's relationships.
        for rel in document.part.rels.values():
            if "image" in rel.target_ref:
                img_part = rel.target_part
                # BUG FIX: the original hard-coded "image/png" for every image;
                # JPEG/GIF figures got a wrong MIME type. Use the part's real
                # content type, falling back to PNG if unavailable.
                mime = getattr(img_part, "content_type", "image/png")
                img_b64 = base64.b64encode(img_part.blob).decode("utf-8")
                images.append(f"data:{mime};base64,{img_b64}")  # store as base64 data URI
    return text, tables, images
# Function to split text into chunks
def split_text_into_chunks(text):
    """Split raw document text into overlapping chunks for embedding."""
    # 1000-character chunks with 100 characters of overlap keep enough
    # context across boundaries for similarity retrieval.
    chunker = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = chunker.split_text(text)
    return chunks
# Function to create vector store
def create_vector_store(text_chunks):
    """Embed the chunks and persist a FAISS index under ./faiss_index."""
    # Google's embedding-001 model converts each chunk into a vector.
    embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    index = FAISS.from_texts(text_chunks, embedding=embedder)
    # Saved to disk so process_user_question() can reload it on later reruns.
    index.save_local("faiss_index")
# Function to load a QA chain
def load_qa_chain_model():
    """Build a 'stuff'-type QA chain over Gemini with a grounding prompt."""
    # The prompt instructs the model to answer strictly from the retrieved
    # context and to admit when the answer is absent.
    prompt_template = """
Use the context provided to answer the question accurately. If the answer is not found, respond with "Answer not available in the context."
Context:\n{context}\n
Question:\n{question}\n
Answer:
"""
    qa_prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.5)
    return load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)
# Function to process user questions
def process_user_question(question):
    """Answer a question from the persisted FAISS index via the QA chain."""
    # Embeddings must match the model used when the index was built.
    embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    # allow_dangerous_deserialization is needed to reload the pickled index;
    # acceptable here because the index is produced locally by this app.
    index = FAISS.load_local("faiss_index", embedder, allow_dangerous_deserialization=True)
    relevant_docs = index.similarity_search(question)
    qa_chain = load_qa_chain_model()
    result = qa_chain({"input_documents": relevant_docs, "question": question}, return_only_outputs=True)
    return result["output_text"]
# Main app function
def main():
    """Streamlit entry point: upload documents, build an index, answer questions."""
    st.set_page_config(page_title="Virtual Agent", layout="wide")
    st.title("AI-Virtual Agent")

    # Real-time suggestion box: fuzzy-match the partial input against the
    # predefined questions and offer the top 5 as clickable buttons.
    user_input = st.text_input("Type your question", placeholder="Ask a question...", key="question_input")
    suggestions = process.extract(user_input, suggested_questions, limit=5) if user_input else []
    if user_input:
        st.markdown("**Suggestions:**")
        for suggestion, _ in suggestions:
            # s=suggestion binds the current value as a default argument so
            # each button fills the input with its own text (avoids the
            # late-binding-closure pitfall).
            st.button(suggestion, on_click=lambda s=suggestion: st.session_state.update({"question_input": s}))

    # Sidebar for file upload
    with st.sidebar:
        st.header("Upload Documents")
        pdf_docs = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
        docx_docs = st.file_uploader("Upload .docx files", type="docx", accept_multiple_files=True)
        if st.button("Process Documents"):
            if pdf_docs or docx_docs:
                # BUG FIX: st.spinner() only displays when used as a context
                # manager; the original bare call was a no-op.
                with st.spinner("Processing..."):
                    pdf_text = extract_text_from_pdf(pdf_docs) if pdf_docs else ""
                    docx_text, tables, images = extract_text_from_docx(docx_docs) if docx_docs else ("", [], [])
                    combined_text = pdf_text + docx_text
                    text_chunks = split_text_into_chunks(combined_text)
                    create_vector_store(text_chunks)
                st.success("Documents processed successfully!")
                # Optionally display tables and images
                st.subheader("Tables Extracted:")
                for table in tables:
                    st.write(table)
                st.subheader("Figures/Images Extracted:")
                for img in images:
                    st.image(img)  # Display base64 image
            else:
                st.error("Please upload at least one document.")

    # Handle question input and response
    if user_input:
        if not os.path.exists("faiss_index"):
            # Guard: querying before any documents were processed would crash
            # on the missing FAISS index directory.
            st.error("Please process documents before asking questions.")
        else:
            # BUG FIX: spinner must be a context manager (see above).
            with st.spinner("Generating response..."):
                answer = process_user_question(user_input)
            st.session_state.chat_history.append({"question": user_input, "answer": answer})
            st.write(f"**Answer:** {answer}")

    # Chat history download option.
    # BUG FIX: the original nested download_button inside
    # `if st.sidebar.button(...)`; the download button then vanished on the
    # rerun its own click triggers, so the download never completed. Render
    # the download button directly whenever there is history to export.
    if st.session_state.chat_history:
        chat_history = "\n".join([f"Q: {entry['question']}\nA: {entry['answer']}" for entry in st.session_state.chat_history])
        st.sidebar.download_button("Download Chat History", chat_history, file_name="chat_history.txt", mime="text/plain")
# Run the app only when this file is executed directly (not on import).
if __name__ == "__main__":
    main()