Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import PyPDF2 | |
| import docx | |
| from io import BytesIO | |
| import numpy as np | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import pickle | |
| from groq import Groq | |
| from typing import List, Tuple | |
| import re | |
# Page configuration: browser-tab title, favicon, wide layout, sidebar open.
# page_icon was the mojibake "π§ " (UTF-8 bytes of 🧠 mis-decoded as
# ISO-8859-7), which is not a usable emoji for the favicon.
st.set_page_config(
    page_title="π€ Smart RAG Assistant",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS: gradient page header, user/assistant chat bubbles, info panels,
# and full-width gradient buttons that lift slightly on hover.
_CUSTOM_CSS = """
<style>
.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
    color: white;
}
.chat-message {
    padding: 1rem;
    border-radius: 10px;
    margin: 1rem 0;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.user-message {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    margin-left: 20%;
}
.bot-message {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    color: white;
    margin-right: 20%;
}
.sidebar-info {
    background: #f0f2f6;
    padding: 1rem;
    border-radius: 10px;
    border-left: 4px solid #667eea;
}
.doc-info {
    background: #e8f4fd;
    padding: 1rem;
    border-radius: 10px;
    border: 1px solid #b3d9ff;
    margin: 1rem 0;
}
.stButton > button {
    width: 100%;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    border: none;
    padding: 0.5rem 1rem;
    border-radius: 10px;
    font-weight: bold;
}
.stButton > button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
</style>
"""

st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
class RAGSystem:
    """Retrieval-augmented question answering over uploaded documents.

    Holds a SentenceTransformer embedding model, a FAISS inner-product index
    built over L2-normalised chunk embeddings (so scores are cosine
    similarities), the chunk texts backing that index, and a Groq chat client
    used to generate answers from retrieved context. Most failures are shown
    in the UI via ``st.error`` (or returned as an error string by
    ``generate_answer``) rather than raised.
    """

    def __init__(self):
        self.embedding_model = None  # SentenceTransformer, loaded lazily on first indexing
        self.index = None  # faiss.IndexFlatIP over normalised chunk embeddings
        self.documents = []  # chunk texts; index row i corresponds to documents[i]
        self.groq_client = None  # Groq client created by setup_groq_client()

    def load_embedding_model(self):
        """Load the 'all-MiniLM-L6-v2' sentence-transformer model.

        Returns:
            The model, or None if loading failed (the error is shown in the UI).
        """
        try:
            return SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            st.error(f"Error loading embedding model: {str(e)}")
            return None

    def setup_groq_client(self, api_key: str):
        """Create the Groq client used by ``generate_answer``.

        Returns:
            True on success, False on failure (the error is shown in the UI).
        """
        try:
            self.groq_client = Groq(api_key=api_key)
            return True
        except Exception as e:
            st.error(f"Error setting up Groq client: {str(e)}")
            return False

    def extract_text_from_pdf(self, pdf_file) -> str:
        """Extract the text of every page of an uploaded PDF.

        Args:
            pdf_file: File-like object whose ``read()`` returns the PDF bytes.

        Returns:
            Each page's text followed by a newline, concatenated; "" on failure
            (the error is shown in the UI).
        """
        try:
            reader = PyPDF2.PdfReader(BytesIO(pdf_file.read()))
            # extract_text() may return None for pages without a text layer in
            # some PyPDF2 versions; treat that as an empty page, not a crash.
            return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
            return ""

    def extract_text_from_docx(self, docx_file) -> str:
        """Extract paragraph text from an uploaded DOCX file.

        Args:
            docx_file: File-like object whose ``read()`` returns the DOCX bytes.

        Returns:
            Each paragraph's text followed by a newline, concatenated; "" on
            failure (the error is shown in the UI).
        """
        try:
            doc = docx.Document(BytesIO(docx_file.read()))
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
            return ""

    @staticmethod
    def _split_oversized(sentence: str, limit: int) -> List[str]:
        """Break a sentence longer than ``limit`` chars into word-bounded pieces."""
        if limit <= 0 or len(sentence) <= limit:
            return [sentence]
        parts: List[str] = []
        buf = ""
        for word in sentence.split():
            # A single token longer than the limit is cut into fixed-size slices.
            while len(word) > limit:
                if buf:
                    parts.append(buf)
                    buf = ""
                parts.append(word[:limit])
                word = word[limit:]
            if not buf:
                buf = word
            elif len(buf) + 1 + len(word) <= limit:
                buf += " " + word
            else:
                parts.append(buf)
                buf = word
        if buf:
            parts.append(buf)
        return parts

    @staticmethod
    def _overlap_tail(chunk: str, overlap: int) -> str:
        """Return at most ``overlap`` trailing chars of ``chunk``, starting on a word boundary."""
        if overlap <= 0:
            return ""
        start = len(chunk) - overlap
        if start <= 0:
            return chunk
        if not chunk[start - 1].isspace():
            # Skip the partial word the cut landed in.
            boundary = chunk.find(" ", start)
            if boundary == -1:
                return ""
            start = boundary + 1
        return chunk[start:].strip()

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping, sentence-packed chunks.

        Sentences are found by splitting on runs of '.', '!' or '?' (the
        punctuation is normalised to ". "). Consecutive sentences are packed
        into a chunk while it stays under ``chunk_size`` characters; a
        sentence longer than ``chunk_size`` is first broken on whitespace
        (or, for a single over-long token, by character count) so no text is
        lost to embedding-model truncation. When a chunk is closed, up to
        ``overlap`` of its trailing characters (starting on a word boundary)
        are repeated at the start of the next chunk, provided that carried
        text plus the next sentence still fits.

        Args:
            text: Raw document text.
            chunk_size: Target maximum chunk length in characters.
            overlap: Maximum number of characters repeated between consecutive
                chunks; 0 disables overlap.

        Returns:
            The chunks, in document order (empty if ``text`` has no sentences).
        """
        pieces: List[str] = []
        for sentence in re.split(r'[.!?]+', text):
            sentence = sentence.strip()
            if sentence:
                pieces.extend(self._split_oversized(sentence, chunk_size))

        chunks: List[str] = []
        current_chunk = ""
        for piece in pieces:
            if len(current_chunk) + len(piece) < chunk_size:
                current_chunk += piece + ". "
                continue
            if current_chunk:
                finished = current_chunk.strip()
                chunks.append(finished)
                carry = self._overlap_tail(finished, overlap)
                # Carry context forward only if it leaves room for the new sentence.
                if carry and len(carry) + len(piece) < chunk_size:
                    current_chunk = carry + " " + piece + ". "
                    continue
            current_chunk = piece + ". "
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def create_embeddings_and_index(self, documents: List[str]):
        """Embed ``documents`` and replace the FAISS index with one over them.

        The new index and document list are only stored once both have been
        built, so a failure leaves the previous knowledge base intact.

        Returns:
            True on success; False if there is nothing to index, the model
            could not be loaded, or embedding/indexing failed (errors are
            shown in the UI).
        """
        if not documents:
            st.error("No text could be extracted from the uploaded documents.")
            return False
        if self.embedding_model is None:
            self.embedding_model = self.load_embedding_model()
            if self.embedding_model is None:
                return False
        try:
            embeddings = self.embedding_model.encode(documents, show_progress_bar=True)
            # faiss.normalize_L2 normalises in place and needs a contiguous
            # float32 array, so convert first, then normalise for cosine similarity.
            embeddings = np.ascontiguousarray(embeddings, dtype='float32')
            faiss.normalize_L2(embeddings)
            index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product on unit vectors
            index.add(embeddings)
            self.index = index
            self.documents = list(documents)
            return True
        except Exception as e:
            st.error(f"Error creating embeddings: {str(e)}")
            return False

    def retrieve_relevant_docs(self, query: str, k: int = 3) -> List[Tuple[str, float]]:
        """Return up to ``k`` ``(chunk, cosine_score)`` pairs most similar to ``query``.

        Returns [] when nothing is indexed or on error (shown in the UI).
        """
        if self.embedding_model is None or self.index is None or not self.documents:
            return []
        # Asking FAISS for more neighbours than it holds pads the result with
        # label -1, which Python indexing would map to the last chunk.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []
        try:
            query_embedding = np.ascontiguousarray(
                self.embedding_model.encode([query]), dtype='float32'
            )
            faiss.normalize_L2(query_embedding)
            scores, indices = self.index.search(query_embedding, k)
            return [
                (self.documents[idx], float(score))
                for score, idx in zip(scores[0], indices[0])
                if 0 <= idx < len(self.documents)
            ]
        except Exception as e:
            st.error(f"Error retrieving documents: {str(e)}")
            return []

    def generate_answer(self, query: str, context: str, model: str = "llama-3.3-70b-versatile") -> str:
        """Ask the Groq chat model to answer ``query`` from ``context``.

        Returns:
            The model's reply ("" if it returned no content), or an
            "Error: ..." / "Error generating answer: ..." string on failure.
        """
        if not self.groq_client:
            return "Error: Groq client not initialized"
        try:
            prompt = f"""Based on the following context, please answer the question accurately and concisely. If the answer cannot be found in the context, please say so.
Context:
{context}
Question: {query}
Answer:"""
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context. Be accurate and concise."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=model,
                temperature=0.3,
                max_tokens=1000
            )
            return chat_completion.choices[0].message.content or ""
        except Exception as e:
            return f"Error generating answer: {str(e)}"
def _extract_text(rag_system, uploaded_file) -> str:
    """Return the text content of one uploaded file.

    Dispatches on the file extension first and the browser-reported MIME type
    second: the uploader only accepts .pdf/.docx/.txt, while the MIME type may
    be missing or generic for some files. Anything that is neither PDF nor
    DOCX is decoded as UTF-8, replacing undecodable bytes instead of raising.
    """
    name = (uploaded_file.name or "").lower()
    mime = uploaded_file.type or ""
    if name.endswith(".pdf") or mime == "application/pdf":
        return rag_system.extract_text_from_pdf(uploaded_file)
    if name.endswith(".docx") or mime == (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    ):
        return rag_system.extract_text_from_docx(uploaded_file)
    return uploaded_file.read().decode("utf-8", errors="replace")


def _chat_message_html(role: str, message: str) -> str:
    """Build the HTML for one chat bubble.

    Messages contain user input and model output derived from uploaded
    documents, so they are HTML-escaped before being rendered with
    ``unsafe_allow_html=True``. Newlines become <br> so a blank line in an
    answer cannot terminate the surrounding HTML block.
    """
    import html

    if role == "user":
        css_class, label = "user-message", "πββοΈ You:"
    else:
        css_class, label = "bot-message", "π€ Assistant:"
    body = html.escape(message or "").replace("\n", "<br>")
    return (
        f'<div class="chat-message {css_class}">'
        f"<strong>{label}</strong><br>{body}"
        "</div>"
    )


def main():
    """Render the Smart RAG Assistant page.

    Sidebar: Groq API key, model choice, document upload/processing, and a
    button to clear the chat. Main column: chat history plus a question form.
    Right column: knowledge-base status and usage notes. Per-session state is
    kept in ``st.session_state.rag_system`` (a RAGSystem) and
    ``st.session_state.chat_history`` (a list of ``(role, message)`` tuples).
    """
    # Header
    st.markdown(
        """
        <div class="main-header">
            <h1>π€ Smart RAG Assistant</h1>
            <p>Upload documents and ask questions - powered by Groq & Sentence Transformers</p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Initialize per-session state
    if "rag_system" not in st.session_state:
        st.session_state.rag_system = RAGSystem()
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    rag_system = st.session_state.rag_system

    # Sidebar
    with st.sidebar:
        st.markdown("## βοΈ Configuration")
        # Pre-fill from the GROQ_API_KEY environment variable; a literal
        # placeholder default would always be truthy and be sent as the key.
        api_key = st.text_input(
            "π Groq API Key",
            type="password",
            value=os.environ.get("GROQ_API_KEY", ""),
            help="Enter your Groq API key",
        )
        if api_key and rag_system.setup_groq_client(api_key):
            st.success("β Groq client configured!")
        st.markdown("---")

        # Model selection
        model_options = [
            "llama-3.3-70b-versatile",
            "llama-3.1-70b-versatile",
            "llama-3.1-8b-instant",
            "mixtral-8x7b-32768",
        ]
        selected_model = st.selectbox("π€ Select Model", model_options)
        st.markdown("---")

        # Document upload
        st.markdown("## π Document Upload")
        uploaded_files = st.file_uploader(
            "Upload documents",
            type=["pdf", "docx", "txt"],
            accept_multiple_files=True,
            help="Upload PDF, DOCX, or TXT files",
        )
        if uploaded_files and st.button("π Process Documents"):
            with st.spinner("Processing documents..."):
                texts = []
                doc_info = []
                for file in uploaded_files:
                    text = _extract_text(rag_system, file)
                    texts.append(text)
                    doc_info.append(f"π {file.name} ({len(text)} chars)")
                all_text = "".join(text + "\n\n" for text in texts)

                chunks = rag_system.chunk_text(all_text)
                if rag_system.create_embeddings_and_index(chunks):
                    st.success(f"β Processed {len(chunks)} chunks from {len(uploaded_files)} documents!")
                    st.markdown("### π Processed Documents:")
                    for info in doc_info:
                        st.markdown(f"- {info}")

        # Clear chat history
        if st.button("ποΈ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()

    # Main content area
    col1, col2 = st.columns([2, 1])

    with col1:
        st.markdown("## π¬ Chat with your documents")

        with st.container():
            for role, message in st.session_state.chat_history:
                st.markdown(_chat_message_html(role, message), unsafe_allow_html=True)

        # A form with clear_on_submit empties the input after each question.
        # A bare text_input keeps its value across st.rerun(), which would
        # re-submit the same question on every rerun in an endless loop.
        with st.form("query_form", clear_on_submit=True):
            query = st.text_input(
                "Ask a question about your documents:",
                placeholder="e.g., What is the main topic discussed in the documents?",
                key="query_input",
            )
            send_button = st.form_submit_button("π€ Send")

        query = (query or "").strip()
        if send_button and query:
            if not rag_system.documents:
                st.warning("β οΈ Please upload and process documents first!")
            elif not api_key:
                st.warning("β οΈ Please enter your Groq API key!")
            else:
                with st.spinner("Searching and generating answer..."):
                    relevant_docs = rag_system.retrieve_relevant_docs(query, k=3)
                    if relevant_docs:
                        context = "\n\n".join(doc for doc, _score in relevant_docs)
                        answer = rag_system.generate_answer(query, context, selected_model)
                        st.session_state.chat_history.append(("user", query))
                        st.session_state.chat_history.append(("assistant", answer))
                        # Rerun so the history rendered above includes this exchange.
                        st.rerun()
                    else:
                        st.error("No relevant documents found for your query.")

    with col2:
        st.markdown("## π System Status")

        if rag_system.documents:
            st.markdown(
                f"""
                <div class="doc-info">
                <h4>π Knowledge Base</h4>
                <p><strong>Documents:</strong> {len(rag_system.documents)} chunks</p>
                <p><strong>Status:</strong> β Ready</p>
                <p><strong>Model:</strong> {selected_model}</p>
                </div>
                """,
                unsafe_allow_html=True,
            )
        else:
            st.markdown(
                """
                <div class="doc-info">
                <h4>π Knowledge Base</h4>
                <p><strong>Status:</strong> β No documents loaded</p>
                <p>Upload documents to get started!</p>
                </div>
                """,
                unsafe_allow_html=True,
            )

        # Instructions
        st.markdown(
            """
            <div class="sidebar-info">
            <h4>π How to use:</h4>
            <ol>
            <li>Enter your Groq API key</li>
            <li>Upload documents (PDF, DOCX, TXT)</li>
            <li>Click "Process Documents"</li>
            <li>Ask questions about your documents</li>
            </ol>
            </div>
            """,
            unsafe_allow_html=True,
        )

        # Features
        st.markdown(
            """
            <div class="sidebar-info">
            <h4>β¨ Features:</h4>
            <ul>
            <li>π Fast inference with Groq</li>
            <li>π§ Smart document chunking</li>
            <li>π Semantic search</li>
            <li>π¬ Chat history</li>
            <li>π± Responsive design</li>
            </ul>
            </div>
            """,
            unsafe_allow_html=True,
        )
if __name__ == "__main__":
    main()  # entry point when this module is executed as a script