Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from groq import Groq | |
| import faiss | |
| import pickle | |
| from typing import List, Dict, Tuple | |
| import PyPDF2 | |
| import docx | |
| from io import BytesIO | |
| import time | |
# Initialize Groq client
def init_groq_client(api_key: str):
    """Create a Groq API client authenticated with *api_key*.

    Args:
        api_key: Groq API key, forwarded unchanged to the ``Groq`` constructor.

    Returns:
        A freshly constructed ``Groq`` client.
    """
    client = Groq(api_key=api_key)
    return client
# Initialize embedding model
@st.cache_resource
def load_embedding_model(model_name: str = 'all-MiniLM-L6-v2'):
    """Load the sentence-transformer embedding model once and reuse it.

    Streamlit re-executes the whole script on every user interaction, so
    without ``st.cache_resource`` the model would be re-instantiated on every
    rerun. The cached instance is shared across reruns and sessions; in this
    app it is only used through its ``encode`` method.

    Args:
        model_name: sentence-transformers model identifier. Defaults to the
            model the app was built around.

    Returns:
        A ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
# Document processing functions
def extract_text_from_pdf(file):
    """Extract the text of every page of a PDF upload.

    Pages are joined with a newline so the last word of one page is not fused
    with the first word of the next, which would corrupt the word-based
    chunking downstream.

    Args:
        file: the uploaded PDF, passed straight to ``PyPDF2.PdfReader``.

    Returns:
        The page texts joined by newlines; a page with no extractable text
        contributes an empty string.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # Defensive `or ""`: treat a None/empty extract_text() result as an empty page.
    return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
def extract_text_from_docx(file):
    """Return the paragraph text of a DOCX upload, one paragraph per line.

    Only ``Document.paragraphs`` is read. Every paragraph, including the last,
    is terminated by a newline; a document without paragraphs yields "".
    """
    document = docx.Document(file)
    return "".join(f"{para.text}\n" for para in document.paragraphs)
def extract_text_from_txt(file):
    """Read a plain-text upload and return its contents as ``str``.

    Bytes are decoded as UTF-8. A leading byte-order mark is dropped, and
    undecodable bytes become U+FFFD rather than raising, so one badly encoded
    file cannot abort processing of the whole upload batch. If ``read()``
    already returns ``str`` (a text-mode file object), it is returned as-is.

    Args:
        file: a file-like object exposing ``read()``.

    Returns:
        The decoded text.
    """
    data = file.read()
    if isinstance(data, str):
        return data
    return data.decode("utf-8-sig", errors="replace")
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping chunks of whitespace-separated words.

    Consecutive chunks share ``overlap`` words. The window advances by
    ``chunk_size - overlap`` words and stops once a chunk reaches the end of
    the text, so the final chunk may be shorter than ``chunk_size``.

    Args:
        text: text to split; any run of whitespace separates words.
        chunk_size: maximum number of words per chunk; must be positive.
        overlap: number of words repeated between consecutive chunks;
            must satisfy ``0 <= overlap < chunk_size``.

    Returns:
        The list of chunks (space-joined words); empty if *text* has no words.

    Raises:
        ValueError: if ``chunk_size`` or ``overlap`` is out of range. Otherwise
            the stride would be zero (crash) or negative (silent empty
            result), or words would be skipped.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, "
            f"chunk_size={chunk_size}"
        )

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(' '.join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            # This window already reached the end; any further window would be pure overlap.
            break
    return chunks
# Vector store class
class VectorStore:
    """In-memory collection of text chunks searchable by cosine similarity.

    Chunks are embedded with ``embedding_model.encode``. The embeddings are
    L2-normalised before being added to a FAISS ``IndexFlatIP`` (inner
    product), so search scores are cosine similarities.
    """

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model  # must provide encode(list_of_texts)
        self.documents = []   # chunk texts, row-aligned with self.embeddings
        self.embeddings = []  # becomes a 2-D np.ndarray (n_docs, dim) once populated
        self.index = None     # FAISS index over the normalised embeddings; None while empty

    @staticmethod
    def _normalize(vectors):
        """Scale each row of *vectors* to unit L2 norm; all-zero rows stay zero instead of NaN."""
        vectors = np.asarray(vectors)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return vectors / np.where(norms == 0, 1.0, norms)

    def add_documents(self, documents: List[str]):
        """Embed *documents*, append them to the store and rebuild the index.

        An empty list is a no-op. Encoding happens before any state is
        mutated, so a failing ``encode`` leaves the store unchanged.
        """
        if not documents:
            return  # nothing to add; also avoids vstack shape errors on an empty batch
        new_embeddings = np.asarray(self.embedding_model.encode(documents))
        self.documents.extend(documents)
        if len(self.embeddings) == 0:
            self.embeddings = new_embeddings
        else:
            self.embeddings = np.vstack([self.embeddings, new_embeddings])
        self._build_index()

    def _build_index(self):
        """(Re)build the FAISS inner-product index from all stored embeddings.

        Sets ``self.index`` to None when there are no embeddings, so a stale
        index can never outlive the data it was built from.
        """
        if len(self.embeddings) == 0:
            self.index = None
            return
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product on unit vectors == cosine
        self.index.add(self._normalize(self.embeddings).astype('float32'))

    def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """Return up to *top_k* ``(chunk_text, cosine_score)`` pairs most similar to *query*.

        Returns an empty list when the store is empty or ``top_k`` is not positive.
        """
        if self.index is None or len(self.documents) == 0 or top_k <= 0:
            return []
        query_embedding = self._normalize(self.embedding_model.encode([query]))
        # Never ask FAISS for more neighbours than are indexed.
        k = min(top_k, self.index.ntotal)
        scores, indices = self.index.search(query_embedding.astype('float32'), k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS marks missing results with -1; a bare `idx < len(...)` check would
            # let -1 through and return documents[-1] with a bogus score.
            if 0 <= idx < len(self.documents):
                results.append((self.documents[idx], float(score)))
        return results

    def save(self, filepath: str):
        """Pickle the chunk texts and the raw (un-normalised) embeddings to *filepath*."""
        data = {
            'documents': self.documents,
            'embeddings': self.embeddings.tolist() if len(self.embeddings) > 0 else [],
        }
        with open(filepath, 'wb') as f:
            pickle.dump(data, f)

    def load(self, filepath: str):
        """Replace the store's contents with data written by ``save`` and rebuild the index.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load files
        this application wrote itself, never user-supplied uploads.
        """
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        self.documents = data['documents']
        # Reset embeddings too when the file holds none, so stale vectors from
        # before the load cannot point at the newly loaded documents.
        self.embeddings = np.array(data['embeddings']) if data['embeddings'] else []
        self._build_index()
# RAG class
class RAGSystem:
    """Retrieval-augmented QA: retrieve chunks from a VectorStore, then query a Groq chat model."""

    def __init__(self, groq_client, embedding_model):
        self.groq_client = groq_client  # used via chat.completions.create(...)
        self.vector_store = VectorStore(embedding_model)

    def add_documents(self, documents: List[str]):
        """Embed and index *documents* (text chunks) in the knowledge base."""
        self.vector_store.add_documents(documents)

    @staticmethod
    def _format_sources(retrieved_docs: List[Tuple[str, float]], preview_chars: int = 200) -> List[Dict]:
        """Convert ``(text, score)`` pairs into UI source entries with a short text preview.

        "..." is appended only when the text was actually truncated.
        """
        sources = []
        for doc, score in retrieved_docs:
            preview = doc if len(doc) <= preview_chars else doc[:preview_chars] + "..."
            sources.append({"text": preview, "score": score})
        return sources

    def query(self, question: str, model: str = "llama-3.3-70b-versatile", top_k: int = 3) -> Dict:
        """Answer *question* using the *top_k* most similar chunks as context.

        Args:
            question: the user's question.
            model: Groq model name passed to the chat completion call.
            top_k: number of chunks to retrieve.

        Returns:
            A dict with:
            - "answer" (str): the model's reply, or an explanatory/error message.
            - "sources" (list of {"text", "score"} dicts): the retrieved chunks.
            - "confidence" (float): the highest retrieval score, or 0.0 when
              nothing was retrieved or the model call failed.
        """
        retrieved_docs = self.vector_store.search(question, top_k=top_k)
        if not retrieved_docs:
            return {
                "answer": "I don't have any relevant information to answer your question.",
                "sources": [],
                "confidence": 0.0
            }

        context = "\n\n".join(doc for doc, _ in retrieved_docs)
        prompt = f"""Based on the following context, answer the question. If the answer is not in the context, say "I don't have enough information to answer this question."
Context:
{context}
Question: {question}
Answer:"""

        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model=model,
                temperature=0.1,
                max_tokens=1000,
            )
            answer = chat_completion.choices[0].message.content
        except Exception as e:
            # API/network failures are shown to the user instead of crashing the app.
            return {
                "answer": f"Error generating response: {str(e)}",
                "sources": [],
                "confidence": 0.0
            }

        return {
            "answer": answer or "",  # guard against a None content field
            "sources": self._format_sources(retrieved_docs),
            "confidence": max(score for _, score in retrieved_docs),
        }
# Streamlit App

# MIME types reported for the supported upload formats, mapped to their extension.
_MIME_TO_EXTENSION = {
    "application/pdf": ".pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "text/plain": ".txt",
}


def _secret_api_key() -> str:
    """Return GROQ_API_KEY from Streamlit secrets, or "" when it is unavailable."""
    try:
        return st.secrets.get("GROQ_API_KEY", "")
    except Exception:
        # NOTE(review): reading st.secrets with no secrets.toml present raises
        # (FileNotFoundError or a subclass, depending on the Streamlit version),
        # which would crash the app before the env-var fallback in main() runs.
        return ""


def _resolve_extension(mime_type, filename) -> str:
    """Return the upload's extension ('.pdf', '.docx', '.txt', ...).

    The MIME type is consulted first. When it is not one of the known types
    (e.g. a generic type was reported), the lower-cased filename extension is
    used instead.
    """
    extension = _MIME_TO_EXTENSION.get(mime_type)
    if extension is None:
        extension = os.path.splitext(filename or "")[1].lower()
    return extension


def _extract_uploaded_text(file):
    """Extract text from an uploaded file, or return None if its format is unsupported."""
    extension = _resolve_extension(file.type, file.name)
    if extension == ".pdf":
        return extract_text_from_pdf(file)
    if extension == ".docx":
        return extract_text_from_docx(file)
    if extension == ".txt":
        return extract_text_from_txt(file)
    return None


def _process_uploads(rag_system, uploaded_files):
    """Extract, chunk and index uploaded files, reporting per-file progress in the UI.

    Files already indexed in this session (same name and size) are skipped, so
    pressing the button again does not duplicate chunks. A file that fails to
    parse is reported and skipped without aborting the rest of the batch.
    """
    if "processed_files" not in st.session_state:
        st.session_state.processed_files = set()
    processed = st.session_state.processed_files

    all_chunks = []
    new_keys = []
    for file in uploaded_files:
        key = (file.name, file.size)
        if key in processed:
            st.info(f"Skipping {file.name}: already in the knowledge base")
            continue
        try:
            text = _extract_uploaded_text(file)
        except Exception as e:
            st.error(f"Failed to read {file.name}: {str(e)}")
            continue
        if text is None:
            st.error(f"Unsupported file type: {file.type}")
            continue

        chunks = chunk_text(text, chunk_size=500, overlap=50)
        if not chunks:
            st.warning(f"No extractable text found in {file.name}")
            continue
        all_chunks.extend(chunks)
        new_keys.append(key)
        st.success(f"✅ Processed {file.name}: {len(chunks)} chunks")

    if not all_chunks:
        return
    try:
        rag_system.add_documents(all_chunks)
    except Exception as e:
        st.error(f"Error adding documents to the knowledge base: {str(e)}")
        return
    # Mark files as processed only after their chunks were actually indexed.
    processed.update(new_keys)
    st.success(f"🎉 Added {len(all_chunks)} chunks to knowledge base!")


def _render_sources(sources):
    """Show retrieved source snippets and their similarity scores inside an expander."""
    with st.expander("📚 Sources"):
        for number, source in enumerate(sources, start=1):
            st.write(f"**Source {number}** (Score: {source['score']:.3f})")
            st.write(source["text"])


def _render_upload_column(rag_system):
    """Render the upload widget, the 'Process Documents' action and knowledge-base stats."""
    st.header("📁 Document Upload")
    uploaded_files = st.file_uploader(
        "Upload your documents",
        type=['pdf', 'docx', 'txt'],
        accept_multiple_files=True,
        help="Supported formats: PDF, DOCX, TXT"
    )
    if uploaded_files and st.button("Process Documents", type="primary"):
        with st.spinner("Processing documents..."):
            _process_uploads(rag_system, uploaded_files)

    chunk_count = len(rag_system.vector_store.documents)
    if chunk_count > 0:
        st.info(f"📊 Knowledge Base: {chunk_count} chunks")


def _render_chat_column(rag_system, selected_model, top_k):
    """Render the chat history, answer a newly submitted question, and offer a clear button."""
    st.header("💬 Ask Questions")
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            if message["role"] == "assistant" and message.get("sources"):
                _render_sources(message["sources"])

    if prompt := st.chat_input("Ask a question about your documents..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = rag_system.query(prompt, model=selected_model, top_k=top_k)
            st.write(response["answer"])
            if response["sources"]:
                _render_sources(response["sources"])

        st.session_state.messages.append({
            "role": "assistant",
            "content": response["answer"],
            "sources": response["sources"]
        })

    if st.button("🗑️ Clear Chat"):
        st.session_state.messages = []
        st.rerun()


def _render_workspace(groq_client, embedding_model, selected_model, top_k):
    """Ensure a session RAG system exists, then render the upload and chat columns."""
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = RAGSystem(groq_client, embedding_model)
    rag_system = st.session_state.rag_system
    # The RAG system lives for the whole session; refresh its client so a key
    # edited in the sidebar after the first run actually takes effect.
    rag_system.groq_client = groq_client

    col1, col2 = st.columns([1, 1])
    with col1:
        _render_upload_column(rag_system)
    with col2:
        _render_chat_column(rag_system, selected_model, top_k)


def main():
    """Render the Streamlit RAG app: sidebar configuration, document upload and chat."""
    st.set_page_config(
        page_title="RAG App with Groq",
        page_icon="🤖",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    st.title("🤖 RAG App with Groq & Sentence Transformers")
    st.markdown("Ask questions about your documents using open-source models!")

    # Sidebar
    st.sidebar.header("⚙️ Configuration")
    # Option 1: key typed in the sidebar, pre-filled from Streamlit secrets when present
    api_key = st.sidebar.text_input(
        "Groq API Key",
        value=_secret_api_key(),
        type="password",
        help="Enter your Groq API key"
    )
    # Option 2: Fallback to environment variable (useful for local dev)
    if not api_key:
        api_key = os.getenv("GROQ_API_KEY")

    model_options = [
        "llama-3.3-70b-versatile",
        "llama-3.1-70b-versatile",
        "llama-3.1-8b-instant",
        "mixtral-8x7b-32768"
    ]
    selected_model = st.sidebar.selectbox("Select Model", model_options)
    top_k = st.sidebar.slider("Number of retrieved documents", 1, 10, 3)

    if api_key:
        try:
            groq_client = init_groq_client(api_key)
            embedding_model = load_embedding_model()
        except Exception as e:
            st.error(f"Error initializing components: {str(e)}")
        else:
            _render_workspace(groq_client, embedding_model, selected_model, top_k)
    else:
        st.warning("Please enter your Groq API key in the sidebar to get started.")

    # Footer
    st.sidebar.markdown("---")
    st.sidebar.markdown(
        """
**About this app:**
- Uses Groq for fast inference
- Sentence Transformers for embeddings
- FAISS for vector search
- Supports PDF, DOCX, TXT files
"""
    )
# Script entry point: build the app UI when this file is run as the main module.
if __name__ == "__main__":
    main()