# RAG Chat Assistant — Streamlit application.
# (Header reconstructed from a Hugging Face Spaces page export.)
| import streamlit as st | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| from groq import Groq | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import pickle | |
| from typing import List, Dict, Any | |
| import PyPDF2 | |
| import docx | |
| from io import BytesIO | |
| import tempfile | |
# Configure the Streamlit page; must run before any other st.* call.
st.set_page_config(
    page_title="RAG Chat Assistant",
    page_icon="🤖",
    layout="wide"
)
class RAGSystem:
    """Retrieval-augmented generation over user-uploaded documents.

    Extracts text from PDF/DOCX/TXT uploads, splits it into overlapping
    word chunks, embeds the chunks with a SentenceTransformer model,
    indexes them in a flat FAISS L2 index, and answers queries with a
    Groq-hosted LLM grounded in the retrieved chunks.
    """

    def __init__(self, groq_api_key: str):
        """Initialize the Groq client and embedding model; index starts empty."""
        self.groq_client = Groq(api_key=groq_api_key)
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.index = None        # FAISS index; built by process_documents()
        self.documents = []      # list of {'text': str, 'source': str} chunks
        self.embeddings = None   # matrix of chunk embeddings

    def extract_text_from_pdf(self, file) -> str:
        """Return the concatenated text of all PDF pages, or '' on error."""
        try:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                # extract_text() may return None (e.g. image-only pages);
                # the original crashed with TypeError on such pages.
                text += (page.extract_text() or "") + "\n"
            return text
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
            return ""

    def extract_text_from_docx(self, file) -> str:
        """Return the text of all paragraphs in a DOCX file, or '' on error."""
        try:
            doc = docx.Document(file)
            text = ""
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
            return text
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
            return ""

    def extract_text_from_txt(self, file) -> str:
        """Decode an uploaded text file as UTF-8, or return '' on error."""
        try:
            return str(file.read(), "utf-8")
        except Exception as e:
            st.error(f"Error reading TXT: {str(e)}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
        """Split text into overlapping word-based chunks.

        Args:
            text: Raw document text (split on whitespace).
            chunk_size: Maximum number of words per chunk.
            overlap: Words shared between consecutive chunks.

        Returns:
            Non-empty chunk strings in document order.
        """
        words = text.split()
        # Guard the stride: chunk_size <= overlap would hand range() a
        # non-positive step and raise ValueError.
        step = max(chunk_size - overlap, 1)
        chunks = []
        for i in range(0, len(words), step):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk.strip():
                chunks.append(chunk.strip())
        return chunks

    def process_documents(self, uploaded_files) -> None:
        """Extract, chunk, embed, and index text from the uploaded files."""
        all_chunks = []
        for uploaded_file in uploaded_files:
            file_extension = uploaded_file.name.split('.')[-1].lower()
            # Dispatch on file extension.
            if file_extension == 'pdf':
                text = self.extract_text_from_pdf(uploaded_file)
            elif file_extension == 'docx':
                text = self.extract_text_from_docx(uploaded_file)
            elif file_extension == 'txt':
                text = self.extract_text_from_txt(uploaded_file)
            else:
                st.error(f"Unsupported file type: {file_extension}")
                continue
            if text:
                for chunk in self.chunk_text(text):
                    all_chunks.append({
                        'text': chunk,
                        'source': uploaded_file.name
                    })
        if all_chunks:
            self.documents = all_chunks
            # Embed every chunk and build an exact (flat) L2 index.
            texts = [doc['text'] for doc in all_chunks]
            embeddings = self.embedding_model.encode(texts)
            self.embeddings = embeddings
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype('float32'))
            st.success(f"Processed {len(all_chunks)} chunks from {len(uploaded_files)} documents")
        else:
            st.error("No text could be extracted from the uploaded files")

    def retrieve_relevant_chunks(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Return up to k indexed chunks most similar to the query.

        Each result carries 'text', 'source', 'similarity_score'
        (monotone transform of L2 distance), and 1-based 'rank'.
        Returns [] when no index has been built yet.
        """
        if self.index is None:
            return []
        query_embedding = self.embedding_model.encode([query])
        distances, indices = self.index.search(query_embedding.astype('float32'), k)
        relevant_chunks = []
        for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
            # FAISS pads results with idx == -1 when fewer than k vectors
            # exist; the original `idx < len(...)` check let -1 silently
            # index the *last* document. Require a valid non-negative idx.
            if 0 <= idx < len(self.documents):
                relevant_chunks.append({
                    'text': self.documents[idx]['text'],
                    'source': self.documents[idx]['source'],
                    'similarity_score': 1 / (1 + distance),  # distance -> similarity
                    'rank': i + 1
                })
        return relevant_chunks

    def generate_response(self, query: str, relevant_chunks: List[Dict[str, Any]]) -> str:
        """Ask the Groq LLM to answer `query` using only the retrieved context.

        Returns the model's answer, or an error string if the API call fails.
        """
        context = "\n\n".join([f"Source: {chunk['source']}\nContent: {chunk['text']}"
                               for chunk in relevant_chunks])
        prompt = f"""Based on the following context, please answer the question accurately and comprehensively.

Context:
{context}

Question: {query}

Instructions:
- Use only the information provided in the context to answer the question
- If the context doesn't contain enough information to answer the question, say so
- Cite the sources when possible
- Be concise but comprehensive

Answer:"""
        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided context. Always be accurate and cite your sources."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model="llama-3.3-70b-versatile",
                temperature=0.1,  # low temperature for factual, grounded answers
                max_tokens=1024
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            return f"Error generating response: {str(e)}"
def main():
    """Streamlit entry point: sidebar config/upload, chat pane, context pane."""
    st.title("🤖 RAG Chat Assistant")
    st.markdown("Upload documents and ask questions using Retrieval Augmented Generation")

    # Initialize session state so reruns keep the system and chat log.
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = None
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Sidebar: API key, document upload, retrieval settings.
    with st.sidebar:
        st.header("Configuration")
        groq_api_key = st.text_input(
            "Groq API Key",
            value=os.getenv("GROQ_API_KEY", ""),  # pre-fill from the environment
            type="password",
            help="Enter your Groq API key"
        )
        # Lazily build the RAG system once a key is available.
        if groq_api_key and st.session_state.rag_system is None:
            try:
                st.session_state.rag_system = RAGSystem(groq_api_key)
                st.success("RAG System initialized!")
            except Exception as e:
                st.error(f"Error initializing RAG system: {str(e)}")

        st.header("Document Upload")
        uploaded_files = st.file_uploader(
            "Upload documents",
            accept_multiple_files=True,
            type=['pdf', 'txt', 'docx'],
            help="Upload PDF, TXT, or DOCX files"
        )
        # Only allow processing once the system exists and files are chosen.
        if uploaded_files and st.session_state.rag_system:
            if st.button("Process Documents"):
                with st.spinner("Processing documents..."):
                    st.session_state.rag_system.process_documents(uploaded_files)

        st.header("Retrieval Settings")
        num_chunks = st.slider("Number of chunks to retrieve", 1, 10, 3)

    # Main area: chat on the left, retrieved context on the right.
    col1, col2 = st.columns([2, 1])
    with col1:
        st.header("Chat")
        # Replay the conversation so far.
        for question, answer in st.session_state.chat_history:
            st.write(f"**You:** {question}")
            st.write(f"**Assistant:** {answer}")
            st.divider()

        query = st.text_input("Ask a question about your documents:", key="query_input")
        if st.button("Ask") and query:
            if not st.session_state.rag_system:
                st.error("Please enter a valid Groq API key first")
            elif not st.session_state.rag_system.documents:
                st.error("Please upload and process documents first")
            else:
                with st.spinner("Generating response..."):
                    # Retrieve the most relevant chunks for this query.
                    relevant_chunks = st.session_state.rag_system.retrieve_relevant_chunks(
                        query, k=num_chunks
                    )
                    if relevant_chunks:
                        response = st.session_state.rag_system.generate_response(query, relevant_chunks)
                        st.session_state.chat_history.append((query, response))
                        st.write(f"**You:** {query}")
                        st.write(f"**Assistant:** {response}")
                        # Show the retrieved chunks in the right-hand column.
                        with col2:
                            st.header("Retrieved Context")
                            for chunk in relevant_chunks:
                                with st.expander(f"Rank {chunk['rank']} - {chunk['source']}"):
                                    st.write(f"**Similarity:** {chunk['similarity_score']:.3f}")
                                    st.write(f"**Text:** {chunk['text'][:200]}...")
                    else:
                        st.error("No relevant information found in the documents")

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            # st.experimental_rerun() was removed in modern Streamlit;
            # prefer st.rerun() and fall back only on old versions.
            if hasattr(st, "rerun"):
                st.rerun()
            else:
                st.experimental_rerun()

    with col2:
        # Placeholder until a question has produced retrieved context.
        if not st.session_state.chat_history:
            st.header("Retrieved Context")
            st.info("Ask a question to see retrieved context here")
| if __name__ == "__main__": | |
| main() |