|
|
import os |
|
|
import streamlit as st |
|
|
from PyPDF2 import PdfReader |
|
|
from docx import Document |
|
|
import google.generativeai as genai |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain.vectorstores import FAISS |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings |
|
|
from langchain.chains.question_answering import load_qa_chain |
|
|
from langchain.prompts import PromptTemplate |
|
|
from dotenv import load_dotenv |
|
|
from fuzzywuzzy import process |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
# Load environment variables from a local .env file (provides GOOGLE_API_KEY).
load_dotenv()

google_api_key = os.getenv("GOOGLE_API_KEY")

if google_api_key is None:
    st.error("GOOGLE_API_KEY is not set. Please set it in the .env file.")
    # Halt the script run here: without this, execution falls through to
    # genai.configure(api_key=None) and only fails later, on the first API
    # call, with a much less helpful message.
    st.stop()

genai.configure(api_key=google_api_key)

# Per-session conversation log; persists across Streamlit reruns within a
# single browser session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Canned questions offered as fuzzy-match suggestions while the user types.
suggested_questions = [
    "What are the key achievements mentioned in the report?",
    "What is the focus of the Agri-Innovation Hub (AIH) established by PJTSAU?",
    "What are the main objectives of NABARD’s Livelihood and Enterprise Development Programme (LEDP)?",
    "Which award did the Veerapandy Kalanjia Jeevidam Producer Company Limited (VKJPCL) win?",
    "What is the target number of households surveyed under NABARD’s NAFIS 2.0 initiative?",
    "How many climate-related projects has NABARD facilitated with grant and loan assistance?"
]
|
|
|
|
|
|
|
|
def extract_text_from_pdf(pdf_docs):
    """Concatenate the text of every page of every uploaded PDF.

    Args:
        pdf_docs: iterable of file-like objects (e.g. Streamlit UploadedFile)
            readable by PyPDF2's ``PdfReader``.

    Returns:
        str: all extracted page text joined together (no separator inserted
        between pages or documents, matching downstream chunking).
    """
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only/scanned pages;
            # guard so the concatenation does not raise TypeError.
            text += page.extract_text() or ""
    return text
|
|
|
|
|
|
|
|
def extract_text_from_docx(docx_docs):
    """Extract paragraph text, tables, and embedded images from .docx files.

    Args:
        docx_docs: iterable of file-like objects readable by python-docx.

    Returns:
        tuple ``(text, tables, images)``:
            text (str): all paragraph text, newline separated.
            tables (list[str]): one string per table — cells joined with
                ``" | "``, rows separated by newlines.
            images (list[str]): one base64 data URI per embedded image.
    """
    text = ""
    tables = []
    images = []
    for doc in docx_docs:
        document = Document(doc)

        # Plain paragraph text.
        for para in document.paragraphs:
            text += para.text + "\n"

        # Tables: flatten each one into a pipe-separated text grid.
        for table in document.tables:
            table_text = ""
            for row in table.rows:
                row_text = [cell.text for cell in row.cells]
                table_text += " | ".join(row_text) + "\n"
            tables.append(table_text)

        # Embedded images are reachable through the document part's
        # relationships; image relationships have "image" in target_ref.
        for rel in document.part.rels.values():
            if "image" in rel.target_ref:
                img_data = rel.target_part.blob
                img_b64 = base64.b64encode(img_data).decode("utf-8")
                # Derive the MIME subtype from the stored file extension
                # instead of hardcoding PNG — docx frequently embeds JPEG,
                # and a wrong subtype can make viewers reject the data URI.
                ext = os.path.splitext(rel.target_ref)[1].lstrip(".").lower() or "png"
                if ext == "jpg":
                    ext = "jpeg"
                images.append(f"data:image/{ext};base64,{img_b64}")

    return text, tables, images
|
|
|
|
|
|
|
|
def split_text_into_chunks(text):
    """Break *text* into overlapping chunks sized for embedding.

    Uses a 1000-character window with a 100-character overlap between
    consecutive chunks so sentence fragments at chunk borders stay retrievable.
    """
    chunker = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
    )
    chunks = chunker.split_text(text)
    return chunks
|
|
|
|
|
|
|
|
def create_vector_store(text_chunks):
    """Embed *text_chunks* with Google embeddings and persist a FAISS index.

    The index is written to the local "faiss_index" directory, from which
    process_user_question() later reloads it. Returns None.
    """
    embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    store = FAISS.from_texts(text_chunks, embedding=embedder)
    store.save_local("faiss_index")
|
|
|
|
|
|
|
|
def load_qa_chain_model():
    """Build a "stuff"-type question-answering chain backed by Gemini Pro.

    The prompt instructs the model to answer strictly from the supplied
    context and to reply "Answer not available in the context." otherwise.
    """
    prompt_template = """
    Use the context provided to answer the question accurately. If the answer is not found, respond with "Answer not available in the context."
    Context:\n{context}\n
    Question:\n{question}\n
    Answer:
    """
    llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.5)
    qa_prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )
    return load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)
|
|
|
|
|
|
|
|
def process_user_question(question):
    """Answer *question* using the persisted FAISS index and the QA chain.

    Reloads the local "faiss_index" written by create_vector_store(),
    retrieves the chunks most similar to the question, and runs them
    through the Gemini QA chain.

    Returns:
        str: the model's answer text.
    """
    embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    # allow_dangerous_deserialization is required because the index is a
    # pickle written locally by this same app.
    index = FAISS.load_local(
        "faiss_index", embedder, allow_dangerous_deserialization=True
    )
    relevant_docs = index.similarity_search(question)
    qa_chain = load_qa_chain_model()
    result = qa_chain(
        {"input_documents": relevant_docs, "question": question},
        return_only_outputs=True,
    )
    return result["output_text"]
|
|
|
|
|
|
|
|
def main():
    """Streamlit entry point: document-upload sidebar plus a QA chat UI."""
    st.set_page_config(page_title="Virtual Agent", layout="wide")
    st.title("AI-Virtual Agent")

    user_input = st.text_input(
        "Type your question", placeholder="Ask a question...", key="question_input"
    )

    # Fuzzy-match the partial question against the canned suggestions.
    suggestions = (
        process.extract(user_input, suggested_questions, limit=5) if user_input else []
    )
    if user_input:
        st.markdown("**Suggestions:**")
        for suggestion, _ in suggestions:
            # Bind the suggestion as a default argument so each button fills
            # the input with its own text (avoids the late-binding closure bug).
            st.button(
                suggestion,
                on_click=lambda s=suggestion: st.session_state.update(
                    {"question_input": s}
                ),
            )

    with st.sidebar:
        st.header("Upload Documents")
        pdf_docs = st.file_uploader(
            "Upload PDF files", type="pdf", accept_multiple_files=True
        )
        docx_docs = st.file_uploader(
            "Upload .docx files", type="docx", accept_multiple_files=True
        )
        if st.button("Process Documents"):
            if pdf_docs or docx_docs:
                # st.spinner is a context manager; the original called it as a
                # bare statement, which never displays the spinner.
                with st.spinner("Processing..."):
                    pdf_text = extract_text_from_pdf(pdf_docs) if pdf_docs else ""
                    docx_text, tables, images = (
                        extract_text_from_docx(docx_docs)
                        if docx_docs
                        else ("", [], [])
                    )
                    combined_text = pdf_text + docx_text
                    text_chunks = split_text_into_chunks(combined_text)
                    create_vector_store(text_chunks)
                st.success("Documents processed successfully!")

                st.subheader("Tables Extracted:")
                for table in tables:
                    st.write(table)

                st.subheader("Figures/Images Extracted:")
                for img in images:
                    st.image(img)
            else:
                st.error("Please upload at least one document.")

    if user_input:
        # Same spinner fix as above: use the context-manager form.
        with st.spinner("Generating response..."):
            answer = process_user_question(user_input)
        st.session_state.chat_history.append(
            {"question": user_input, "answer": answer}
        )
        st.write(f"**Answer:** {answer}")

    # Offer the download directly instead of nesting download_button inside a
    # regular button: the nested form needed two clicks and the download
    # widget disappeared on the rerun triggered by the first click.
    if st.session_state.chat_history:
        chat_history = "\n".join(
            f"Q: {entry['question']}\nA: {entry['answer']}"
            for entry in st.session_state.chat_history
        )
        st.sidebar.download_button(
            "Download", chat_history, file_name="chat_history.txt", mime="text/plain"
        )
|
|
|
|
|
# Entry point when executed directly (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()
|
|
|