"""Streamlit RAG application: Groq LLM + SentenceTransformers + FAISS.

Pipeline: upload PDF/DOCX/TXT -> extract text -> chunk into overlapping
word windows -> embed with a SentenceTransformer -> index with FAISS ->
retrieve top-k chunks for a question -> answer with a Groq-hosted LLM.
"""

import os
import streamlit as st
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from groq import Groq
import faiss
import pickle
from typing import List, Dict, Optional, Tuple
import PyPDF2
import docx
from io import BytesIO
import time


# Initialize Groq client
def init_groq_client(api_key: str) -> Groq:
    """Initialize and return a Groq API client for the given key."""
    return Groq(api_key=api_key)


# Initialize embedding model
@st.cache_resource
def load_embedding_model() -> SentenceTransformer:
    """Load and cache the sentence transformer model.

    st.cache_resource makes the model load once per server process
    instead of on every Streamlit rerun.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')


# Document processing functions
def extract_text_from_pdf(file) -> str:
    """Extract text from a PDF file-like object (all pages concatenated)."""
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() may return None for image-only pages; guard it
        # so concatenation does not raise TypeError.
        text += page.extract_text() or ""
    return text


def extract_text_from_docx(file) -> str:
    """Extract text from a DOCX file-like object, one line per paragraph."""
    doc = docx.Document(file)
    return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)


def extract_text_from_txt(file) -> str:
    """Extract text from a TXT file-like object (assumes UTF-8 bytes)."""
    return str(file.read(), "utf-8")


def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """Split text into overlapping chunks of whitespace-separated words.

    Args:
        text: Source text to split.
        chunk_size: Maximum words per chunk.
        overlap: Words shared between consecutive chunks.

    Returns:
        List of chunk strings (empty list for empty/whitespace-only text).

    Raises:
        ValueError: If overlap >= chunk_size (the window would never advance).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size - overlap):
        chunks.append(' '.join(words[i:i + chunk_size]))
        # Stop once the current window reaches the end; otherwise the
        # overlap step would emit a redundant tail-only chunk.
        if i + chunk_size >= len(words):
            break
    return chunks


# Vector store class
class VectorStore:
    """In-memory vector store backed by a FAISS inner-product index.

    All embeddings are L2-normalized before indexing, so inner-product
    scores are cosine similarities.
    """

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model  # SentenceTransformer used for encoding
        self.documents: List[str] = []          # chunk texts, row-aligned with embeddings
        self.embeddings = []                    # becomes an (n, dim) ndarray once populated
        self.index = None                       # faiss.IndexFlatIP, rebuilt on each add

    def add_documents(self, documents: List[str]):
        """Embed documents, append them to the store, and rebuild the index."""
        if not documents:
            return  # nothing to encode; avoid an empty vstack
        self.documents.extend(documents)
        new_embeddings = self.embedding_model.encode(documents)
        if len(self.embeddings) == 0:
            self.embeddings = new_embeddings
        else:
            self.embeddings = np.vstack([self.embeddings, new_embeddings])
        self._build_index()

    def _build_index(self):
        """Build the FAISS index for similarity search from all stored embeddings."""
        if len(self.embeddings) == 0:
            return
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # inner product == cosine after normalization
        # Normalize embeddings for cosine similarity; clamp zero norms so a
        # degenerate all-zero row cannot inject NaNs into the index.
        norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        self.index.add((self.embeddings / norms).astype('float32'))

    def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """Return up to top_k (document, cosine-similarity) pairs for query."""
        if self.index is None or not self.documents:
            return []
        query_embedding = self.embedding_model.encode([query])
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
        scores, indices = self.index.search(query_embedding.astype('float32'), top_k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer than top_k vectors exist; a bare
            # `idx < len(...)` check would let -1 wrap to the last document.
            if 0 <= idx < len(self.documents):
                results.append((self.documents[idx], float(score)))
        return results

    def save(self, filepath: str):
        """Persist documents and embeddings to filepath via pickle."""
        data = {
            'documents': self.documents,
            'embeddings': self.embeddings.tolist() if len(self.embeddings) > 0 else []
        }
        with open(filepath, 'wb') as f:
            pickle.dump(data, f)

    def load(self, filepath: str):
        """Load a previously saved store and rebuild the index.

        SECURITY: pickle.load can execute arbitrary code — only load files
        this application wrote itself, never untrusted input.
        """
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        self.documents = data['documents']
        if data['embeddings']:
            self.embeddings = np.array(data['embeddings'])
            self._build_index()


# RAG class
class RAGSystem:
    """Retrieval-augmented generation: retrieve relevant chunks, then ask an LLM."""

    def __init__(self, groq_client, embedding_model):
        self.groq_client = groq_client
        self.vector_store = VectorStore(embedding_model)

    def add_documents(self, documents: List[str]):
        """Add document chunks to the knowledge base."""
        self.vector_store.add_documents(documents)

    def query(self, question: str, model: str = "llama-3.3-70b-versatile", top_k: int = 3) -> Dict:
        """Answer a question using RAG.

        Returns:
            Dict with keys 'answer' (str), 'sources' (list of
            {'text', 'score'} dicts), and 'confidence' (best retrieval score).
        """
        # Retrieve relevant documents
        retrieved_docs = self.vector_store.search(question, top_k=top_k)
        if not retrieved_docs:
            return {
                "answer": "I don't have any relevant information to answer your question.",
                "sources": [],
                "confidence": 0.0
            }

        # Prepare context
        context = "\n\n".join(doc for doc, _ in retrieved_docs)

        # Create prompt
        prompt = f"""Based on the following context, answer the question. If the answer is not in the context, say "I don't have enough information to answer this question."

Context:
{context}

Question: {question}

Answer:"""

        try:
            # Get response from Groq
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model=model,
                temperature=0.1,  # low temperature keeps answers close to the context
                max_tokens=1000,
            )
            answer = chat_completion.choices[0].message.content
            return {
                "answer": answer,
                "sources": [
                    {"text": doc[:200] + "...", "score": score}
                    for doc, score in retrieved_docs
                ],
                "confidence": max(score for _, score in retrieved_docs)
            }
        except Exception as e:
            # Surface API failures in the UI instead of crashing the app.
            return {
                "answer": f"Error generating response: {str(e)}",
                "sources": [],
                "confidence": 0.0
            }


def _extract_uploaded_text(file) -> Optional[str]:
    """Dispatch text extraction by upload MIME type; None if unsupported."""
    if file.type == "application/pdf":
        return extract_text_from_pdf(file)
    if file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        return extract_text_from_docx(file)
    if file.type == "text/plain":
        return extract_text_from_txt(file)
    return None


def _render_upload_column():
    """Render the document-upload panel and feed chunks into the RAG system."""
    st.header("📁 Document Upload")
    uploaded_files = st.file_uploader(
        "Upload your documents",
        type=['pdf', 'docx', 'txt'],
        accept_multiple_files=True,
        help="Supported formats: PDF, DOCX, TXT"
    )

    if uploaded_files:
        if st.button("Process Documents", type="primary"):
            with st.spinner("Processing documents..."):
                all_chunks = []
                for file in uploaded_files:
                    # Extract text based on file type
                    text = _extract_uploaded_text(file)
                    if text is None:
                        st.error(f"Unsupported file type: {file.type}")
                        continue
                    # Chunk the text
                    chunks = chunk_text(text, chunk_size=500, overlap=50)
                    all_chunks.extend(chunks)
                    st.success(f"✅ Processed {file.name}: {len(chunks)} chunks")

                # Add to RAG system
                if all_chunks:
                    st.session_state.rag_system.add_documents(all_chunks)
                    st.success(f"🎉 Added {len(all_chunks)} chunks to knowledge base!")

    # Display document stats
    if hasattr(st.session_state.rag_system, 'vector_store') and len(st.session_state.rag_system.vector_store.documents) > 0:
        st.info(f"📊 Knowledge Base: {len(st.session_state.rag_system.vector_store.documents)} chunks")


def _render_chat_column(selected_model: str, top_k: int):
    """Render chat history, handle new questions, and show answers with sources."""
    st.header("💬 Ask Questions")

    # Chat interface
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            if message["role"] == "assistant" and "sources" in message:
                with st.expander("📚 Sources"):
                    for i, source in enumerate(message["sources"]):
                        st.write(f"**Source {i+1}** (Score: {source['score']:.3f})")
                        st.write(source["text"])

    # Chat input
    if prompt := st.chat_input("Ask a question about your documents..."):
        # Add user message
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = st.session_state.rag_system.query(
                    prompt,
                    model=selected_model,
                    top_k=top_k
                )
                st.write(response["answer"])
                # Show sources
                if response["sources"]:
                    with st.expander("📚 Sources"):
                        for i, source in enumerate(response["sources"]):
                            st.write(f"**Source {i+1}** (Score: {source['score']:.3f})")
                            st.write(source["text"])

        # Add to chat history
        st.session_state.messages.append({
            "role": "assistant",
            "content": response["answer"],
            "sources": response["sources"]
        })

    # Clear chat button
    if st.button("🗑️ Clear Chat"):
        st.session_state.messages = []
        st.rerun()


# Streamlit App
def main():
    """Entry point: configuration sidebar plus upload and chat columns."""
    st.set_page_config(
        page_title="RAG App with Groq",
        page_icon="🤖",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    st.title("🤖 RAG App with Groq & Sentence Transformers")
    st.markdown("Ask questions about your documents using open-source models!")

    # Sidebar
    st.sidebar.header("⚙️ Configuration")

    # API Key input
    api_key = st.sidebar.text_input(
        "Groq API Key",
        value=os.getenv("GROQ_API_KEY", ""),
        type="password",
        help="Enter your Groq API key"
    )

    # Model selection
    model_options = [
        "llama-3.3-70b-versatile",
        "llama-3.1-70b-versatile",
        "llama-3.1-8b-instant",
        "mixtral-8x7b-32768"
    ]
    selected_model = st.sidebar.selectbox("Select Model", model_options)

    # Number of retrieved documents
    top_k = st.sidebar.slider("Number of retrieved documents", 1, 10, 3)

    # Initialize components
    if api_key:
        try:
            groq_client = init_groq_client(api_key)
            embedding_model = load_embedding_model()

            # RAG system lives in session state so the index survives reruns.
            if 'rag_system' not in st.session_state:
                st.session_state.rag_system = RAGSystem(groq_client, embedding_model)

            # Main content area
            col1, col2 = st.columns([1, 1])
            with col1:
                _render_upload_column()
            with col2:
                _render_chat_column(selected_model, top_k)

        except Exception as e:
            st.error(f"Error initializing components: {str(e)}")
    else:
        st.warning("Please enter your Groq API key in the sidebar to get started.")

    # Footer
    st.sidebar.markdown("---")
    st.sidebar.markdown(
        """
        **About this app:**
        - Uses Groq for fast inference
        - Sentence Transformers for embeddings
        - FAISS for vector search
        - Supports PDF, DOCX, TXT files
        """
    )


if __name__ == "__main__":
    main()