Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from dotenv import load_dotenv | |
| from PyPDF2 import PdfReader | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.chains import ConversationalRetrievalChain | |
| import os | |
| import pandas as pd | |
| from pandasai import SmartDataframe, SmartDatalake | |
| from pandasai.responses.response_parser import ResponseParser | |
| from pandasai.llm import GoogleGemini | |
| import plotly.graph_objects as go | |
| from PIL import Image | |
| import io | |
| import base64 | |
class StreamLitResponse(ResponseParser):
    """PandasAI response parser that normalizes every result into a
    ``{'type': ..., 'value': ...}`` dict the Streamlit UI can render."""

    def __init__(self, context):
        super().__init__(context)

    def format_dataframe(self, result):
        """Tag a DataFrame result so the UI renders it with st.dataframe."""
        return {'type': 'dataframe', 'value': result['value']}

    def format_plot(self, result):
        """Normalize a plot (PIL image, raw bytes, or image file path) to a
        base64-encoded PNG string for consistent storage in chat history."""
        try:
            plot = result['value']
            if isinstance(plot, Image.Image):
                # In-memory PIL image: serialize to PNG bytes first.
                buf = io.BytesIO()
                plot.save(buf, format="PNG")
                encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
            elif isinstance(plot, bytes):
                encoded = base64.b64encode(plot).decode('utf-8')
            elif isinstance(plot, str) and os.path.exists(plot):
                # PandasAI sometimes hands back a path to a saved chart file.
                with open(plot, "rb") as fh:
                    encoded = base64.b64encode(fh.read()).decode('utf-8')
            else:
                return {'type': 'text', 'value': "Unsupported image format"}
        except Exception as e:
            return {'type': 'text', 'value': f"Error processing plot: {e}"}
        return {'type': 'plot', 'value': encoded}

    def format_other(self, result):
        """Fall back to the string form for any other response type."""
        return {'type': 'text', 'value': str(result['value'])}
# Load environment variables (expects GOOGLE_API_KEY in a .env file or the process env)
load_dotenv()
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    # Fail fast: every Gemini/embedding call below depends on this key.
    st.error("GOOGLE_API_KEY environment variable not set.")
    st.stop()
def generateResponse(prompt, dfs):
    """Run a natural-language query against the uploaded DataFrames via PandasAI.

    Returns whatever the SmartDatalake produces — with StreamLitResponse as the
    parser this is normally a ``{'type': ..., 'value': ...}`` dict.
    """
    gemini = GoogleGemini(api_key=GOOGLE_API_KEY)
    config = {
        "llm": gemini,
        "response_parser": StreamLitResponse,
    }
    agent = SmartDatalake(dfs, config=config)
    return agent.chat(prompt)
| # Other utility functions remain the same as in the original code | |
| # (get_pdf_text, get_text_chunks, get_vectorstore, get_conversation_chain) | |
| # Processing pdfs | |
def get_pdf_text(pdf_docs):
    """Concatenate the extracted text of every page of every uploaded PDF.

    Args:
        pdf_docs: iterable of file-like objects readable by PyPDF2's PdfReader.

    Returns:
        str: all page text joined together. Pages with no extractable text
        layer contribute the empty string.
    """
    pages = []
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # BUG FIX: extract_text() can return None for pages without a text
            # layer (e.g. scanned images); the original `text += ...` would
            # raise TypeError. Also use join instead of quadratic `+=`.
            pages.append(page.extract_text() or "")
    return "".join(pages)
| # Splitting text into small chunks to create embeddings | |
def get_text_chunks(text):
    """Split raw document text into overlapping chunks sized for embedding."""
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,      # characters per chunk
        chunk_overlap=200,    # overlap keeps context across chunk boundaries
        length_function=len,
    )
    return splitter.split_text(text)
| # Using Google's embedding004 model to create embeddings and FAISS to store the embeddings | |
def get_vectorstore(text_chunks):
    """Embed the chunks with Google's text-embedding-004 model and index them
    in an in-memory FAISS store."""
    embedder = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
    return FAISS.from_texts(texts=text_chunks, embedding=embedder)
def get_conversation_chain(vectorstore):
    """Build a Gemini-backed conversational retrieval chain over the given
    vectorstore, with buffered chat-history memory."""
    chat_llm = ChatGoogleGenerativeAI(model='gemini-2.0-flash-exp')
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=chat_llm,
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )
def render_chat_message(message):
    """Render one chat-history entry: an optional dataframe or plot payload,
    then any text content."""
    if "dataframe" in message:
        st.dataframe(message["dataframe"])
    elif "plot" in message:
        try:
            payload = message["plot"]
            if isinstance(payload, str):
                # base64-encoded PNG produced by StreamLitResponse.format_plot
                st.image(f"data:image/png;base64,{payload}")
            elif isinstance(payload, Image.Image):
                st.image(payload)
            elif isinstance(payload, go.Figure):
                st.plotly_chart(payload)
            elif isinstance(payload, bytes):
                st.image(Image.open(io.BytesIO(payload)))
            else:
                st.write("Unsupported plot format")
        except Exception as e:
            st.error(f"Error rendering plot: {e}")
    # Always render text content when present.
    if "content" in message:
        st.markdown(message["content"])
def handle_userinput(question, pdf_vectorstore, dfs):
    """Route a user question to the active backend and record the exchange.

    Args:
        question: the user's free-text question.
        pdf_vectorstore: FAISS store when PDFs were processed, else None.
        dfs: list of DataFrames when data files were processed, else None.

    Side effects: appends message dicts to st.session_state.chat_history and
    triggers a Streamlit rerun so the new messages are rendered.
    """
    history = st.session_state.chat_history
    try:
        if pdf_vectorstore and st.session_state.conversation:
            # PDF / vector-search mode
            response = st.session_state.conversation({"question": question})
            history.append({"role": "user", "content": question})
            history.append({
                "role": "assistant",
                "content": response.get('answer', ''),
            })
        elif dfs:
            # PandasAI data-analysis mode
            history.append({"role": "user", "content": question})
            result = generateResponse(question, dfs)
            history.append(_as_history_entry(result))
        else:
            st.write("Please upload and process your documents/data first.")
            return
    except Exception as e:
        st.error(f"Error processing input: {e}")
        return
    # BUG FIX: st.rerun() works by raising a control-flow exception. Calling it
    # inside the try block above meant `except Exception` swallowed it and
    # reported a spurious error instead of refreshing the page — so it must
    # run outside the try.
    st.rerun()


def _as_history_entry(result):
    """Convert a PandasAI result (a StreamLitResponse dict, or any other
    value) into an assistant chat-history message dict."""
    if not isinstance(result, dict):
        return {"role": "assistant", "content": str(result)}
    response_type = result.get('type', 'text')
    response_value = result.get('value')
    if response_type == 'dataframe':
        return {
            "role": "assistant",
            "content": "Here's the table:",
            "dataframe": response_value,
        }
    if response_type == 'plot':
        return {
            "role": "assistant",
            "content": "Here's the chart:",
            "plot": response_value,
        }
    return {"role": "assistant", "content": str(response_value)}
def main():
    """Streamlit entry point.

    Renders the chat UI plus a sidebar uploader. Exactly one mode is active at
    a time: PDF mode (FAISS vectorstore + conversational retrieval chain) or
    Data mode (DataFrames queried through PandasAI).
    """
    st.set_page_config(page_title="Chat with PDFs or your Data", page_icon=":books:")
    # Initialize session state variables so later code can test them safely.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None   # retrieval chain (PDF mode only)
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []     # list of message dicts
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None    # FAISS index (PDF mode only)
    if "dfs" not in st.session_state:
        st.session_state.dfs = None            # DataFrames (Data mode only)
    st.title("AI Chat with your PDFs :books: or your Data :bar_chart:")
    # Replay the stored conversation with type-aware rendering.
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            render_chat_message(message)
    # Chat input
    user_question = st.chat_input("Ask a question about your documents or data:")
    if user_question:
        handle_userinput(user_question, st.session_state.vectorstore, st.session_state.dfs)
    # Sidebar for file upload
    with st.sidebar:
        st.sidebar.image("logoqb.jpeg", use_container_width=True)
        st.subheader("Your files")
        uploaded_files = st.file_uploader(
            "Upload PDFs, CSVs, or Excel files (up to 3)",
            accept_multiple_files=True,
            key="file_uploader",
            type=['pdf', 'csv', 'xlsx', 'xls']
        )
        if st.button("Process"):
            with st.spinner("Processing"):
                pdf_docs = []
                dfs = []
                pdf_uploaded = False
                data_uploaded = False
                # Classify each upload. The two modes are mutually exclusive:
                # seeing the other kind of file clears the previous mode's state.
                # NOTE(review): nesting reconstructed from a whitespace-mangled
                # paste — confirm the mode-switch branches against the original.
                for uploaded_file in uploaded_files:
                    file_extension = uploaded_file.name.split(".")[-1].lower()
                    if file_extension == "pdf":
                        if data_uploaded:
                            if st.session_state.dfs:
                                st.session_state.dfs = None
                            data_uploaded = False
                            st.warning("Switching to PDF mode. Data files removed.")
                        pdf_docs.append(uploaded_file)
                        pdf_uploaded = True
                    elif file_extension in ["csv", "xlsx", "xls"]:
                        if pdf_uploaded:
                            if st.session_state.vectorstore:
                                st.session_state.vectorstore = None
                                st.session_state.conversation = None
                            pdf_uploaded = False
                            st.warning("Switching to Data mode. PDF files removed.")
                        try:
                            if file_extension == 'csv':
                                df = pd.read_csv(uploaded_file)
                            else:
                                df = pd.read_excel(uploaded_file)
                            dfs.append(df)
                            data_uploaded = True
                        except Exception as e:
                            st.error(f"Error reading {uploaded_file.name}: {e}")
                            st.stop()  # abort the whole script run on a bad data file
                # Set up vectorstore and conversation chain for PDFs;
                # otherwise drop any stale PDF-mode state.
                if pdf_docs:
                    raw_text = get_pdf_text(pdf_docs)
                    text_chunks = get_text_chunks(raw_text)
                    st.session_state.vectorstore = get_vectorstore(text_chunks)
                    st.session_state.conversation = get_conversation_chain(st.session_state.vectorstore)
                else:
                    st.session_state.vectorstore = None
                    st.session_state.conversation = None
                # Keep DataFrames for PandasAI, or clear them if none uploaded.
                if dfs:
                    st.session_state.dfs = dfs
                else:
                    st.session_state.dfs = None
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()


if __name__ == "__main__":
    main()