# Talk2Docs — app.py
# (Hugging Face Space by rairo; commit 860778e, verified)
import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_community.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
import os
import pandas as pd
from pandasai import SmartDataframe, SmartDatalake
from pandasai.responses.response_parser import ResponseParser
from pandasai.llm import GoogleGemini
import plotly.graph_objects as go
from PIL import Image
import io
import base64
class StreamLitResponse(ResponseParser):
    """PandasAI response parser that packages results as typed dicts.

    Each formatter returns ``{'type': ..., 'value': ...}`` so the Streamlit
    layer can dispatch rendering on the ``type`` tag.
    """

    def __init__(self, context):
        super().__init__(context)

    def format_dataframe(self, result):
        """Tag a DataFrame result with type 'dataframe'."""
        return {'type': 'dataframe', 'value': result['value']}

    def format_plot(self, result):
        """Normalise a plot result to a base64 PNG payload tagged 'plot'.

        Accepts a PIL image, raw bytes, or a path to an image file; anything
        else (or any failure) degrades to a 'text' response.
        """
        try:
            raw = result['value']
            if isinstance(raw, Image.Image):
                buf = io.BytesIO()
                raw.save(buf, format="PNG")
                encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
            elif isinstance(raw, bytes):
                encoded = base64.b64encode(raw).decode('utf-8')
            elif isinstance(raw, str) and os.path.exists(raw):
                with open(raw, "rb") as fh:
                    encoded = base64.b64encode(fh.read()).decode('utf-8')
            else:
                return {'type': 'text', 'value': "Unsupported image format"}
            return {'type': 'plot', 'value': encoded}
        except Exception as e:
            return {'type': 'text', 'value': f"Error processing plot: {e}"}

    def format_other(self, result):
        """Fall back to a plain-text rendering of any other result."""
        return {'type': 'text', 'value': str(result['value'])}
# Load environment variables from a local .env file, then require the key
# that both the Gemini LLM and the embedding model depend on.
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    # Missing (or empty) key: surface the problem in the UI and halt the app.
    st.error("GOOGLE_API_KEY environment variable not set.")
    st.stop()
def generateResponse(prompt, dfs):
    """Run *prompt* against the uploaded dataframes via a PandasAI SmartDatalake.

    Returns whatever StreamLitResponse produced for the query
    (a typed dict, or a raw value for unformatted responses).
    """
    gemini = GoogleGemini(api_key=GOOGLE_API_KEY)
    lake = SmartDatalake(
        dfs,
        config={
            "llm": gemini,
            "response_parser": StreamLitResponse,
        },
    )
    return lake.chat(prompt)
# Other utility functions remain the same as in the original code
# (get_pdf_text, get_text_chunks, get_vectorstore, get_conversation_chain)
# Processing pdfs
# Processing pdfs
def get_pdf_text(pdf_docs):
    """Concatenate the extracted text of every page of every uploaded PDF.

    Args:
        pdf_docs: iterable of file-like objects accepted by PyPDF2's PdfReader.

    Returns:
        A single string of all extracted page text, in upload/page order.
    """
    parts = []
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only/empty pages;
            # the previous `text += page.extract_text()` raised TypeError then.
            parts.append(page.extract_text() or "")
    # join instead of repeated += — linear, not quadratic, for large PDFs.
    return "".join(parts)
# Splitting text into small chunks to create embeddings
# Splitting text into small chunks to create embeddings
def get_text_chunks(text):
    """Split raw document text into overlapping chunks suitable for embedding.

    Chunks are ~1000 characters with a 200-character overlap, split on
    newlines, so context is preserved across chunk boundaries.
    """
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return splitter.split_text(text)
# Using Google's embedding004 model to create embeddings and FAISS to store the embeddings
def get_vectorstore(text_chunks):
    """Embed *text_chunks* with Google's text-embedding-004 model and index them in FAISS.

    Args:
        text_chunks: list of text strings produced by get_text_chunks.

    Returns:
        A FAISS vector store ready to be used as a retriever.
    """
    embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
    vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
    return vectorstore
def get_conversation_chain(vectorstore):
    """Build a ConversationalRetrievalChain over *vectorstore* backed by Gemini.

    Chat turns are kept in a ConversationBufferMemory (under key
    "chat_history") so follow-up questions can reference earlier answers.

    Args:
        vectorstore: FAISS store from get_vectorstore.

    Returns:
        A chain callable as ``chain({"question": ...})``.
    """
    llm = ChatGoogleGenerativeAI(model='gemini-2.0-flash-exp')
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
    )
    return conversation_chain
def render_chat_message(message):
    """Render one chat-history entry, dispatching on which payload keys it has.

    A message may carry a "dataframe" or "plot" payload in addition to its
    "content" text; text is always rendered last.
    """
    if "dataframe" in message:
        st.dataframe(message["dataframe"])
    elif "plot" in message:
        try:
            plot = message["plot"]
            if isinstance(plot, str):
                # Base64-encoded PNG, as stored by StreamLitResponse.
                st.image(f"data:image/png;base64,{plot}")
            elif isinstance(plot, Image.Image):
                st.image(plot)
            elif isinstance(plot, go.Figure):
                st.plotly_chart(plot)
            elif isinstance(plot, bytes):
                st.image(Image.open(io.BytesIO(plot)))
            else:
                st.write("Unsupported plot format")
        except Exception as e:
            st.error(f"Error rendering plot: {e}")
    # Always render the text content, even alongside a table or plot.
    if "content" in message:
        st.markdown(message["content"])
def handle_userinput(question, pdf_vectorstore, dfs):
    """Route *question* to the PDF chain or the PandasAI data agent.

    Appends the user turn and the assistant reply to
    ``st.session_state.chat_history``, then reruns the app so the new
    messages get rendered. If neither mode is ready, prompts the user
    to upload files first.
    """
    try:
        if pdf_vectorstore and st.session_state.conversation:
            # PDF/Vector search mode
            response = st.session_state.conversation({"question": question})
            st.session_state.chat_history.append(
                {"role": "user", "content": question}
            )
            st.session_state.chat_history.append(
                {"role": "assistant", "content": response.get('answer', '')}
            )
        elif dfs:
            # PandasAI data-analysis mode
            st.session_state.chat_history.append(
                {"role": "user", "content": question}
            )
            result = generateResponse(question, dfs)
            # StreamLitResponse yields typed dicts; anything else is
            # stringified as plain text.
            if isinstance(result, dict):
                kind = result.get('type', 'text')
                value = result.get('value')
                if kind == 'dataframe':
                    reply = {
                        "role": "assistant",
                        "content": "Here's the table:",
                        "dataframe": value,
                    }
                elif kind == 'plot':
                    reply = {
                        "role": "assistant",
                        "content": "Here's the chart:",
                        "plot": value,
                    }
                else:
                    reply = {"role": "assistant", "content": str(value)}
            else:
                reply = {"role": "assistant", "content": str(result)}
            st.session_state.chat_history.append(reply)
        else:
            st.write("Please upload and process your documents/data first.")
        # Refresh the page so the chat display loop picks up new messages.
        st.rerun()
    except Exception as e:
        st.error(f"Error processing input: {e}")
def main():
    """Streamlit entry point: chat UI plus a sidebar for file upload/processing."""
    st.set_page_config(page_title="Chat with PDFs or your Data", page_icon=":books:")
    # Initialize session state variables
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None
    if "dfs" not in st.session_state:
        st.session_state.dfs = None
    st.title("AI Chat with your PDFs :books: or your Data :bar_chart:")
    # Chat display with enhanced rendering
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            render_chat_message(message)
    # Chat input
    user_question = st.chat_input("Ask a question about your documents or data:")
    if user_question:
        handle_userinput(user_question, st.session_state.vectorstore, st.session_state.dfs)
    # Sidebar for file upload
    with st.sidebar:
        st.sidebar.image("logoqb.jpeg", use_container_width=True)
        st.subheader("Your files")
        uploaded_files = st.file_uploader(
            "Upload PDFs, CSVs, or Excel files (up to 3)",
            accept_multiple_files=True,
            key="file_uploader",
            type=['pdf', 'csv', 'xlsx', 'xls']
        )
        if st.button("Process"):
            with st.spinner("Processing"):
                pdf_docs = []
                dfs = []
                pdf_uploaded = False
                data_uploaded = False
                # File processing logic remains the same as in the original code.
                # NOTE(review): when a mixed upload switches modes below, only the
                # session-state entries are cleared — the local pdf_docs/dfs lists
                # already collected are NOT emptied, so both may still be processed
                # despite the "files removed" warnings. Confirm intended behavior.
                for uploaded_file in uploaded_files:
                    file_extension = uploaded_file.name.split(".")[-1].lower()
                    if file_extension == "pdf":
                        # A PDF after data files: drop the stored dataframes.
                        if data_uploaded:
                            if st.session_state.dfs:
                                st.session_state.dfs = None
                            data_uploaded = False
                            st.warning("Switching to PDF mode. Data files removed.")
                        pdf_docs.append(uploaded_file)
                        pdf_uploaded = True
                    elif file_extension in ["csv", "xlsx", "xls"]:
                        # A data file after PDFs: drop the stored vectorstore/chain.
                        if pdf_uploaded:
                            if st.session_state.vectorstore:
                                st.session_state.vectorstore = None
                                st.session_state.conversation = None
                            pdf_uploaded = False
                            st.warning("Switching to Data mode. PDF files removed.")
                        try:
                            if file_extension == 'csv':
                                df = pd.read_csv(uploaded_file)
                            else:
                                df = pd.read_excel(uploaded_file)
                            dfs.append(df)
                            data_uploaded = True
                        except Exception as e:
                            st.error(f"Error reading {uploaded_file.name}: {e}")
                            st.stop()
                # Set up vectorstore and conversation chain for PDFs
                if pdf_docs:
                    raw_text = get_pdf_text(pdf_docs)
                    text_chunks = get_text_chunks(raw_text)
                    st.session_state.vectorstore = get_vectorstore(text_chunks)
                    st.session_state.conversation = get_conversation_chain(st.session_state.vectorstore)
                else:
                    st.session_state.vectorstore = None
                    st.session_state.conversation = None
                # Set up DataFrames for PandasAI
                if dfs:
                    st.session_state.dfs = dfs
                else:
                    st.session_state.dfs = None
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()