import base64
import io
import os

import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from pandasai import SmartDataframe, SmartDatalake
from pandasai.llm import GoogleGemini
from pandasai.responses.response_parser import ResponseParser
from PIL import Image
from PyPDF2 import PdfReader

# Maximum number of files the sidebar uploader accepts per "Process" click.
MAX_UPLOAD_FILES = 3


class StreamLitResponse(ResponseParser):
    """PandasAI response parser that returns tagged dicts for the Streamlit UI.

    Every ``format_*`` method returns ``{'type': ..., 'value': ...}`` so that
    ``handle_userinput`` can decide how to store and render the answer.
    """

    def __init__(self, context):
        super().__init__(context)

    def format_dataframe(self, result):
        """Return the DataFrame result tagged as ``'dataframe'``."""
        return {'type': 'dataframe', 'value': result['value']}

    def format_plot(self, result):
        """Return the plot as base64-encoded image bytes tagged as ``'plot'``.

        ``result['value']`` may be a PIL image, raw bytes, or a path to an
        existing image file. Any other value, or any failure while encoding,
        is reported as a ``'text'`` response instead of raising.
        """
        try:
            image = result['value']
            if isinstance(image, Image.Image):
                buffered = io.BytesIO()
                image.save(buffered, format="PNG")
                raw = buffered.getvalue()
            elif isinstance(image, bytes):
                raw = image
            elif isinstance(image, str) and os.path.exists(image):
                with open(image, "rb") as f:
                    raw = f.read()
            else:
                return {'type': 'text', 'value': "Unsupported image format"}
            # Stored as text so chat history entries hold plain strings.
            return {'type': 'plot', 'value': base64.b64encode(raw).decode('utf-8')}
        except Exception as e:
            return {'type': 'text', 'value': f"Error processing plot: {e}"}

    def format_other(self, result):
        """Return any other result as its string form tagged as ``'text'``."""
        return {'type': 'text', 'value': str(result['value'])}


# Load environment variables; the app cannot work without a Google API key.
load_dotenv()
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    st.error("GOOGLE_API_KEY environment variable not set.")
    st.stop()


def generateResponse(prompt, dfs):
    """Answer ``prompt`` over the DataFrames in ``dfs`` using PandasAI.

    Returns whatever ``SmartDatalake.chat`` returns. With the
    ``StreamLitResponse`` parser this is expected to be a ``{'type', 'value'}``
    dict, but other values (e.g. an error string) may come back — callers
    should not assume a dict. A new LLM client and agent are built per call.
    """
    llm = GoogleGemini(api_key=GOOGLE_API_KEY)
    pandas_agent = SmartDatalake(dfs, config={
        "llm": llm,
        "response_parser": StreamLitResponse,
    })
    return pandas_agent.chat(prompt)


def get_pdf_text(pdf_docs):
    """Extract and concatenate the text of every page of every PDF.

    Pages are joined with newlines so the last word of one page does not run
    into the first word of the next (the chunker also splits on ``"\\n"``).
    Pages whose ``extract_text`` yields nothing contribute an empty string.
    """
    pages = []
    for pdf in pdf_docs:
        for page in PdfReader(pdf).pages:
            pages.append(page.extract_text() or "")
    return "\n".join(pages)


def get_text_chunks(text):
    """Split ``text`` into ~1000-character chunks with 200 characters of overlap."""
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return text_splitter.split_text(text)


def get_vectorstore(text_chunks):
    """Embed ``text_chunks`` with Google's text-embedding-004 and index them in FAISS."""
    embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
    return FAISS.from_texts(texts=text_chunks, embedding=embeddings)


def get_conversation_chain(vectorstore):
    """Build a retrieval QA chain over ``vectorstore`` with buffered chat memory."""
    llm = ChatGoogleGenerativeAI(model='gemini-2.0-flash-exp')
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
    )


def render_chat_message(message):
    """Render one chat-history entry inside the current chat bubble.

    The text ``content`` (if any) is drawn first so captions such as
    "Here's the table:" appear above the table or chart they introduce.
    """
    if "content" in message:
        st.markdown(message["content"])

    if "dataframe" in message:
        st.dataframe(message["dataframe"])
    elif "plot" in message:
        try:
            _render_plot(message["plot"])
        except Exception as e:
            st.error(f"Error rendering plot: {e}")


def _render_plot(plot_data):
    """Display a stored plot according to the type it was saved as."""
    if isinstance(plot_data, str):
        # format_plot stores images as base64 text; decode back to bytes so
        # Streamlit detects the real image format rather than assuming PNG.
        st.image(base64.b64decode(plot_data))
    elif isinstance(plot_data, Image.Image):
        st.image(plot_data)
    elif isinstance(plot_data, go.Figure):
        st.plotly_chart(plot_data)
    elif isinstance(plot_data, bytes):
        st.image(Image.open(io.BytesIO(plot_data)))
    else:
        st.write("Unsupported plot format")


def handle_userinput(question, pdf_vectorstore, dfs):
    """Answer ``question`` and append both turns to ``st.session_state.chat_history``.

    Uses the PDF retrieval chain when a vector store and conversation chain
    are loaded, otherwise PandasAI over ``dfs``. On success the script is
    rerun so ``main`` redraws the updated history; on failure an error is
    shown and no rerun happens (a rerun would hide the error).
    """
    has_pdf = pdf_vectorstore is not None and st.session_state.conversation is not None
    if not has_pdf and not dfs:
        # No rerun here: it would immediately wipe this message.
        st.warning("Please upload and process your documents/data first.")
        return

    # Record the question up front so it is kept even if answering fails.
    st.session_state.chat_history.append({"role": "user", "content": question})

    try:
        if has_pdf:
            assistant_message = _answer_from_pdfs(question)
        else:
            assistant_message = _answer_from_data(question, dfs)
    except Exception as e:
        st.error(f"Error processing input: {e}")
        return

    st.session_state.chat_history.append(assistant_message)
    # Kept outside the try block: st.rerun() stops the current run
    # immediately and must not be intercepted by the broad except above.
    st.rerun()


def _answer_from_pdfs(question):
    """Query the PDF retrieval chain and return the assistant history entry."""
    response = st.session_state.conversation.invoke({"question": question})
    return {"role": "assistant", "content": response.get('answer', '')}


def _answer_from_data(question, dfs):
    """Query PandasAI over ``dfs`` and return the assistant history entry."""
    result = generateResponse(question, dfs)
    if not isinstance(result, dict):
        return {"role": "assistant", "content": str(result)}

    response_type = result.get('type', 'text')
    response_value = result.get('value')
    if response_type == 'dataframe':
        return {"role": "assistant", "content": "Here's the table:", "dataframe": response_value}
    if response_type == 'plot':
        return {"role": "assistant", "content": "Here's the chart:", "plot": response_value}
    return {"role": "assistant", "content": str(response_value)}


def _process_uploads(uploaded_files):
    """Load uploaded files into session state, keeping only one mode active.

    PDFs and data files (CSV/Excel) are mutually exclusive: when a file of
    the other kind follows, everything collected so far is discarded with a
    warning, so the later kind wins. PDFs are indexed into a FAISS vector
    store with a retrieval chain; data files become DataFrames for PandasAI.
    A data file that fails to parse reports an error and stops the script.
    """
    pdf_docs = []
    dfs = []

    for uploaded_file in uploaded_files:
        file_extension = uploaded_file.name.split(".")[-1].lower()
        if file_extension == "pdf":
            if dfs:
                # Clear the local list too, otherwise both modes end up loaded.
                dfs = []
                st.session_state.dfs = None
                st.warning("Switching to PDF mode. Data files removed.")
            pdf_docs.append(uploaded_file)
        elif file_extension in ["csv", "xlsx", "xls"]:
            if pdf_docs:
                pdf_docs = []
                st.session_state.vectorstore = None
                st.session_state.conversation = None
                st.warning("Switching to Data mode. PDF files removed.")
            try:
                if file_extension == 'csv':
                    df = pd.read_csv(uploaded_file)
                else:
                    df = pd.read_excel(uploaded_file)
            except Exception as e:
                st.error(f"Error reading {uploaded_file.name}: {e}")
                st.stop()
            else:
                dfs.append(df)

    st.session_state.vectorstore = None
    st.session_state.conversation = None
    if pdf_docs:
        text_chunks = get_text_chunks(get_pdf_text(pdf_docs))
        if text_chunks:
            st.session_state.vectorstore = get_vectorstore(text_chunks)
            st.session_state.conversation = get_conversation_chain(st.session_state.vectorstore)
        else:
            # FAISS.from_texts fails on an empty list, so report it instead.
            st.error("No extractable text found in the uploaded PDF(s).")

    st.session_state.dfs = dfs or None


def main():
    """Build the page: chat history, chat input, and the file-upload sidebar."""
    st.set_page_config(page_title="Chat with PDFs or your Data", page_icon=":books:")

    # Initialise session state on the first run only (fresh list each time).
    defaults = {"conversation": None, "chat_history": [], "vectorstore": None, "dfs": None}
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    st.title("AI Chat with your PDFs :books: or your Data :bar_chart:")

    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            render_chat_message(message)

    user_question = st.chat_input("Ask a question about your documents or data:")
    if user_question:
        handle_userinput(user_question, st.session_state.vectorstore, st.session_state.dfs)

    with st.sidebar:
        st.sidebar.image("logoqb.jpeg", use_container_width=True)
        st.subheader("Your files")
        uploaded_files = st.file_uploader(
            f"Upload PDFs, CSVs, or Excel files (up to {MAX_UPLOAD_FILES})",
            accept_multiple_files=True,
            key="file_uploader",
            type=['pdf', 'csv', 'xlsx', 'xls'],
        )

        if st.button("Process"):
            if not uploaded_files:
                st.warning("Please upload at least one file first.")
            elif len(uploaded_files) > MAX_UPLOAD_FILES:
                st.warning(f"Please upload at most {MAX_UPLOAD_FILES} files.")
            else:
                with st.spinner("Processing"):
                    _process_uploads(uploaded_files)

        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()


if __name__ == "__main__":
    main()