import streamlit as st
import os
import numpy as np
import pandas as pd
from groq import Groq
from sentence_transformers import SentenceTransformer
import faiss
import pickle
from typing import List, Dict, Any
import PyPDF2
import docx
from io import BytesIO
import tempfile

# Configure the page before any other st.* call (Streamlit requires this order).
st.set_page_config(
    page_title="RAG Chat Assistant",
    page_icon="🤖",
    layout="wide"
)


class RAGSystem:
    """Retrieval-Augmented Generation pipeline.

    Ingests PDF/DOCX/TXT uploads, chunks them into overlapping word windows,
    embeds the chunks with a SentenceTransformer model, indexes them in a
    FAISS L2 index, and answers questions via Groq chat completions grounded
    in the retrieved chunks.
    """

    def __init__(self, groq_api_key: str):
        """Initialize the Groq client and the sentence-embedding model.

        Args:
            groq_api_key: API key used to authenticate with the Groq service.
        """
        self.groq_client = Groq(api_key=groq_api_key)
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.index = None       # FAISS index; built lazily by process_documents()
        self.documents = []     # list of {'text': chunk, 'source': filename}
        self.embeddings = None  # numpy array of chunk embeddings (kept for inspection)

    def extract_text_from_pdf(self, file) -> str:
        """Extract text from a PDF file-like object.

        Returns the concatenated page text, or "" if the file cannot be read
        (the error is surfaced to the UI rather than raised).
        """
        try:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                # extract_text() can return None for image-only pages in some
                # PyPDF2 versions; guard so concatenation never raises.
                text += (page.extract_text() or "") + "\n"
            return text
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
            return ""

    def extract_text_from_docx(self, file) -> str:
        """Extract text from a DOCX file-like object; '' on failure."""
        try:
            doc = docx.Document(file)
            text = ""
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
            return text
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
            return ""

    def extract_text_from_txt(self, file) -> str:
        """Extract text from a UTF-8 TXT file-like object; '' on failure."""
        try:
            return str(file.read(), "utf-8")
        except Exception as e:
            st.error(f"Error reading TXT: {str(e)}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
        """Split text into overlapping word-window chunks.

        Args:
            text: Raw document text.
            chunk_size: Maximum number of words per chunk.
            overlap: Number of words shared between consecutive chunks.

        Returns:
            Non-empty, stripped chunks in document order.
        """
        words = text.split()
        chunks = []
        # Guard against overlap >= chunk_size, which would make the range
        # step non-positive and raise (or never advance).
        step = max(1, chunk_size - overlap)
        for i in range(0, len(words), step):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk.strip():
                chunks.append(chunk.strip())
        return chunks

    def process_documents(self, uploaded_files) -> None:
        """Extract, chunk, embed, and index the uploaded documents.

        Replaces any previously built index. Reports progress and errors
        through the Streamlit UI.
        """
        all_chunks = []

        for uploaded_file in uploaded_files:
            file_extension = uploaded_file.name.split('.')[-1].lower()

            # Dispatch on file extension for text extraction.
            if file_extension == 'pdf':
                text = self.extract_text_from_pdf(uploaded_file)
            elif file_extension == 'docx':
                text = self.extract_text_from_docx(uploaded_file)
            elif file_extension == 'txt':
                text = self.extract_text_from_txt(uploaded_file)
            else:
                st.error(f"Unsupported file type: {file_extension}")
                continue

            if text:
                chunks = self.chunk_text(text)
                for chunk in chunks:
                    all_chunks.append({
                        'text': chunk,
                        'source': uploaded_file.name
                    })

        if all_chunks:
            self.documents = all_chunks

            # Embed every chunk in one batch call.
            texts = [doc['text'] for doc in all_chunks]
            embeddings = self.embedding_model.encode(texts)
            self.embeddings = embeddings

            # Build a flat (exact) L2 index over the embeddings.
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype('float32'))

            st.success(f"Processed {len(all_chunks)} chunks from {len(uploaded_files)} documents")
        else:
            st.error("No text could be extracted from the uploaded files")

    def retrieve_relevant_chunks(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Return up to k chunks most similar to the query, best first.

        Each result dict carries 'text', 'source', 'similarity_score'
        (L2 distance mapped to (0, 1] via 1/(1+d)), and 1-based 'rank'.
        """
        if self.index is None:
            return []

        # Never ask FAISS for more neighbors than the index holds.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        query_embedding = self.embedding_model.encode([query])
        distances, indices = self.index.search(query_embedding.astype('float32'), k)

        relevant_chunks = []
        for rank, (distance, idx) in enumerate(zip(distances[0], indices[0]), start=1):
            # FAISS pads missing neighbors with -1; a plain `idx < len(...)`
            # check would accept -1 and silently return the LAST document via
            # Python's negative indexing, so guard both bounds.
            if 0 <= idx < len(self.documents):
                relevant_chunks.append({
                    'text': self.documents[idx]['text'],
                    'source': self.documents[idx]['source'],
                    'similarity_score': 1 / (1 + distance),  # distance -> similarity
                    'rank': rank
                })

        return relevant_chunks

    def generate_response(self, query: str, relevant_chunks: List[Dict[str, Any]]) -> str:
        """Generate an answer with Groq, grounded in the retrieved chunks.

        Returns the model's answer text, or an error string if the API
        call fails (the UI displays whatever comes back).
        """
        # Assemble the retrieved chunks into a single context section.
        context = "\n\n".join([f"Source: {chunk['source']}\nContent: {chunk['text']}"
                               for chunk in relevant_chunks])

        prompt = f"""Based on the following context, please answer the question accurately and comprehensively.

Context: {context}

Question: {query}

Instructions:
- Use only the information provided in the context to answer the question
- If the context doesn't contain enough information to answer the question, say so
- Cite the sources when possible
- Be concise but comprehensive

Answer:"""

        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided context. Always be accurate and cite your sources."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model="llama-3.3-70b-versatile",
                temperature=0.1,   # low temperature: keep answers grounded in context
                max_tokens=1024
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            # Surface the failure as the "answer" so the chat flow continues.
            return f"Error generating response: {str(e)}"


def main():
    """Streamlit entry point: sidebar config/upload plus the chat UI."""
    st.title("🤖 RAG Chat Assistant")
    st.markdown("Upload documents and ask questions using Retrieval Augmented Generation")

    # Session state survives Streamlit reruns; initialize once.
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = None
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # --- Sidebar: API key, document upload, retrieval settings ---
    with st.sidebar:
        st.header("Configuration")

        groq_api_key = st.text_input(
            "Groq API Key",
            value=os.getenv("GROQ_API_KEY", ""),  # pre-fill from environment if set
            type="password",
            help="Enter your Groq API key"
        )

        # Build the RAG system once a key is available.
        if groq_api_key and st.session_state.rag_system is None:
            try:
                st.session_state.rag_system = RAGSystem(groq_api_key)
                st.success("RAG System initialized!")
            except Exception as e:
                st.error(f"Error initializing RAG system: {str(e)}")

        st.header("Document Upload")
        uploaded_files = st.file_uploader(
            "Upload documents",
            accept_multiple_files=True,
            type=['pdf', 'txt', 'docx'],
            help="Upload PDF, TXT, or DOCX files"
        )

        if uploaded_files and st.session_state.rag_system:
            if st.button("Process Documents"):
                with st.spinner("Processing documents..."):
                    st.session_state.rag_system.process_documents(uploaded_files)

        st.header("Retrieval Settings")
        num_chunks = st.slider("Number of chunks to retrieve", 1, 10, 3)

    # --- Main area: chat on the left, retrieved context on the right ---
    col1, col2 = st.columns([2, 1])

    with col1:
        st.header("Chat")

        for i, (question, answer) in enumerate(st.session_state.chat_history):
            st.write(f"**You:** {question}")
            st.write(f"**Assistant:** {answer}")
            st.divider()

        query = st.text_input("Ask a question about your documents:", key="query_input")

        if st.button("Ask") and query:
            if not st.session_state.rag_system:
                st.error("Please enter a valid Groq API key first")
            elif not st.session_state.rag_system.documents:
                st.error("Please upload and process documents first")
            else:
                with st.spinner("Generating response..."):
                    relevant_chunks = st.session_state.rag_system.retrieve_relevant_chunks(
                        query, k=num_chunks
                    )

                    if relevant_chunks:
                        response = st.session_state.rag_system.generate_response(query, relevant_chunks)

                        st.session_state.chat_history.append((query, response))

                        st.write(f"**You:** {query}")
                        st.write(f"**Assistant:** {response}")

                        # Show the supporting chunks in the right column.
                        with col2:
                            st.header("Retrieved Context")
                            for chunk in relevant_chunks:
                                with st.expander(f"Rank {chunk['rank']} - {chunk['source']}"):
                                    st.write(f"**Similarity:** {chunk['similarity_score']:.3f}")
                                    st.write(f"**Text:** {chunk['text'][:200]}...")
                    else:
                        st.error("No relevant information found in the documents")

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            # st.experimental_rerun() was removed in modern Streamlit;
            # st.rerun() is the supported replacement.
            st.rerun()

    with col2:
        if not st.session_state.chat_history:
            st.header("Retrieved Context")
            st.info("Ask a question to see retrieved context here")


if __name__ == "__main__":
    main()