# Hugging Face file-view residue (not part of the program):
# BetaGen's picture · Create app.py · ccd7243 verified
import streamlit as st
import os
import PyPDF2
import docx
from io import BytesIO
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
import pickle
from groq import Groq
from typing import List, Tuple
import re
# Page configuration: title, icon, wide layout, sidebar expanded initially.
# The title emoji was mis-encoded (UTF-8 bytes shown as Greek/Latin-1
# characters); it is restored to the intended robot emoji.
st.set_page_config(
    page_title="🤖 Smart RAG Assistant",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Stylesheet for the banner, chat bubbles, info cards and buttons. It is kept
# in a named constant so the injection call below stays a one-liner.
_CUSTOM_CSS = """
<style>
.main-header {
text-align: center;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
color: white;
}
.chat-message {
padding: 1rem;
border-radius: 10px;
margin: 1rem 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.user-message {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
margin-left: 20%;
}
.bot-message {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
color: white;
margin-right: 20%;
}
.sidebar-info {
background: #f0f2f6;
padding: 1rem;
border-radius: 10px;
border-left: 4px solid #667eea;
}
.doc-info {
background: #e8f4fd;
padding: 1rem;
border-radius: 10px;
border: 1px solid #b3d9ff;
margin: 1rem 0;
}
.stButton > button {
width: 100%;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 10px;
font-weight: bold;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
</style>
"""

# Inject the raw <style> element into the page.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
class RAGSystem:
    """Small retrieval-augmented generation pipeline used by the app.

    It extracts text from uploaded files, splits it into chunks, embeds the
    chunks with a SentenceTransformer model, stores the vectors in a FAISS
    index, and answers questions with a Groq chat model using the retrieved
    chunks as context.
    """

    def __init__(self):
        # Embedding model; loaded lazily by create_embeddings_and_index().
        self.embedding_model = None
        # FAISS inner-product index over L2-normalised chunk embeddings.
        self.index = None
        # Text chunks, positionally aligned with the vectors in self.index.
        self.documents = []
        # Groq API client; set by setup_groq_client().
        self.groq_client = None

    def load_embedding_model(self):
        """Load the sentence transformer model.

        Returns:
            The loaded model, or ``None`` when loading failed (the error is
            reported through ``st.error``).
        """
        # Only a successful load is cached. The try/except deliberately sits
        # outside the cached function: catching inside it would make the
        # ``None`` fallback the cached value, so a transient failure (e.g. a
        # failed model download) would never be retried.
        @st.cache_resource
        def _load_model():
            return SentenceTransformer('all-MiniLM-L6-v2')

        try:
            return _load_model()
        except Exception as e:
            st.error(f"Error loading embedding model: {str(e)}")
            return None

    def setup_groq_client(self, api_key: str):
        """Create the Groq client from ``api_key``.

        Returns:
            ``True`` when the client was created, ``False`` otherwise (the
            error is reported through ``st.error``).
        """
        try:
            self.groq_client = Groq(api_key=api_key)
            return True
        except Exception as e:
            st.error(f"Error setting up Groq client: {str(e)}")
            return False

    def extract_text_from_pdf(self, pdf_file) -> str:
        """Extract text from a PDF file-like object.

        Returns:
            The text of all pages, each followed by a newline, or ``""`` when
            the file cannot be read (the error is reported via ``st.error``).
        """
        try:
            pdf_reader = PyPDF2.PdfReader(BytesIO(pdf_file.read()))
            # Guard against pages with no extractable text: some PyPDF2
            # releases may return None there, which would break concatenation.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
            return ""

    def extract_text_from_docx(self, docx_file) -> str:
        """Extract paragraph text from a DOCX file-like object.

        Returns:
            The text of every paragraph, each followed by a newline, or ``""``
            when the file cannot be read (the error is reported via
            ``st.error``).
        """
        try:
            doc = docx.Document(BytesIO(docx_file.read()))
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
            return ""

    @staticmethod
    def _overlap_tail(chunk: str, overlap: int) -> str:
        """Return up to ``overlap`` trailing characters of ``chunk``.

        The tail always starts on a word boundary; if the cut would land inside
        a word, that partial word is dropped (possibly leaving an empty tail).
        """
        if overlap <= 0:
            return ""
        if len(chunk) <= overlap:
            return chunk
        tail = chunk[-overlap:]
        if not chunk[-overlap - 1].isspace():
            # The cut landed mid-word: discard the partial leading word.
            _, _, tail = tail.partition(" ")
        return tail.strip()

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks of roughly ``chunk_size`` characters.

        The text is split on runs of ``.``, ``!`` and ``?``; each sentence is
        re-terminated with ``". "`` and sentences are packed greedily. A single
        sentence longer than ``chunk_size`` becomes its own oversized chunk.

        Args:
            text: Raw text to split.
            chunk_size: Soft upper bound on chunk length, in characters.
            overlap: Number of trailing characters of a finished chunk that are
                repeated at the start of the next one, so context is not lost
                at chunk boundaries. ``0`` disables overlap.

        Returns:
            The list of chunks; empty when ``text`` has no sentences.
        """
        # Never carry over a whole chunk's worth of text.
        overlap = max(0, min(overlap, chunk_size - 1))
        chunks = []
        current_chunk = ""
        for sentence in re.split(r'[.!?]+', text):
            sentence = sentence.strip()
            if not sentence:
                continue
            if current_chunk and len(current_chunk) + len(sentence) >= chunk_size:
                finished = current_chunk.strip()
                chunks.append(finished)
                # Seed the next chunk with the end of the one just closed.
                tail = self._overlap_tail(finished, overlap)
                current_chunk = f"{tail} " if tail else ""
            current_chunk += sentence + ". "
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def create_embeddings_and_index(self, documents: List[str]) -> bool:
        """Embed ``documents`` and build a fresh FAISS index over them.

        Embeddings are L2-normalised and stored in an inner-product index, so
        search scores are cosine similarities.

        Returns:
            ``True`` on success. ``False`` when there is nothing to index, the
            model cannot be loaded, or embedding/indexing fails; in that case
            the previously built index and documents are left untouched.
        """
        if not documents:
            st.error("Error creating embeddings: no text chunks to index")
            return False
        if self.embedding_model is None:
            self.embedding_model = self.load_embedding_model()
        if self.embedding_model is None:
            return False
        try:
            embeddings = self.embedding_model.encode(documents, show_progress_bar=True)
            # normalize_L2 works in place, so hand it a contiguous float32 array.
            embeddings = np.ascontiguousarray(embeddings, dtype='float32')
            faiss.normalize_L2(embeddings)
            index = faiss.IndexFlatIP(embeddings.shape[1])  # Inner product similarity
            index.add(embeddings)
            # Publish index and documents together so they can never disagree.
            self.index = index
            self.documents = documents
            return True
        except Exception as e:
            st.error(f"Error creating embeddings: {str(e)}")
            return False

    def retrieve_relevant_docs(self, query: str, k: int = 3) -> List[Tuple[str, float]]:
        """Retrieve the chunks most similar to ``query``.

        Args:
            query: The user question.
            k: Maximum number of chunks to return.

        Returns:
            Up to ``k`` ``(chunk, score)`` pairs in the order returned by the
            index search; empty when nothing is indexed or the search fails.
        """
        if self.embedding_model is None or self.index is None or not self.documents:
            return []
        try:
            query_embedding = np.ascontiguousarray(
                self.embedding_model.encode([query]), dtype='float32'
            )
            faiss.normalize_L2(query_embedding)
            # Asking for more neighbours than there are vectors makes FAISS pad
            # the result with index -1, which Python would silently read as
            # "last document". Clamp k and reject negative indices.
            k = min(k, len(self.documents))
            if k <= 0:
                return []
            scores, indices = self.index.search(query_embedding, k)
            return [
                (self.documents[idx], float(score))
                for score, idx in zip(scores[0], indices[0])
                if 0 <= idx < len(self.documents)
            ]
        except Exception as e:
            st.error(f"Error retrieving documents: {str(e)}")
            return []

    def generate_answer(self, query: str, context: str, model: str = "llama-3.3-70b-versatile") -> str:
        """Generate an answer to ``query`` from ``context`` using Groq.

        Args:
            query: The user question.
            context: Retrieved text to ground the answer in.
            model: Groq model identifier.

        Returns:
            The model's reply, or a string starting with ``"Error"`` when the
            client is not initialised or the request fails.
        """
        if not self.groq_client:
            return "Error: Groq client not initialized"
        try:
            prompt = f"""Based on the following context, please answer the question accurately and concisely. If the answer cannot be found in the context, please say so.
Context:
{context}
Question: {query}
Answer:"""
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on the provided context. Be accurate and concise."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=model,
                temperature=0.3,
                max_tokens=1000
            )
            # Keep the declared ``str`` return type even if no content comes back.
            return chat_completion.choices[0].message.content or ""
        except Exception as e:
            return f"Error generating answer: {str(e)}"
def _render_header():
    """Render the gradient page banner."""
    st.markdown("""
    <div class="main-header">
        <h1>🤖 Smart RAG Assistant</h1>
        <p>Upload documents and ask questions - powered by Groq & Sentence Transformers</p>
    </div>
    """, unsafe_allow_html=True)


def _index_uploaded_files(rag, uploaded_files) -> None:
    """Extract text from the uploads, chunk it, and (re)build the index."""
    texts = []
    doc_info = []
    for file in uploaded_files:
        if file.type == "application/pdf":
            text = rag.extract_text_from_pdf(file)
            icon = "📄"
        elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            text = rag.extract_text_from_docx(file)
            icon = "📝"
        else:  # txt
            # Replace undecodable bytes instead of crashing on non-UTF-8 files.
            text = file.read().decode("utf-8", errors="replace")
            icon = "📄"
        doc_info.append(f"{icon} {file.name} ({len(text)} chars)")
        texts.append(text)

    chunks = rag.chunk_text("".join(text + "\n\n" for text in texts))
    if not chunks:
        st.warning("⚠️ No extractable text found in the uploaded documents.")
        return
    if rag.create_embeddings_and_index(chunks):
        st.success(f"✅ Processed {len(chunks)} chunks from {len(uploaded_files)} documents!")
        st.markdown("### 📊 Processed Documents:")
        for info in doc_info:
            st.markdown(f"- {info}")


def _render_sidebar(rag) -> Tuple[str, str]:
    """Render the sidebar and return ``(api_key, selected_model)``."""
    with st.sidebar:
        st.markdown("## ⚙️ Configuration")
        # Pre-fill from the GROQ_API_KEY environment variable when it is set.
        # The field previously defaulted to the literal text "GROQ_API_KEY",
        # which was then used as if it were a real key.
        api_key = st.text_input(
            "🔑 Groq API Key",
            type="password",
            value=os.getenv("GROQ_API_KEY", ""),
            help="Enter your Groq API key",
        )
        if api_key and rag.setup_groq_client(api_key):
            st.success("✅ Groq client configured!")
        st.markdown("---")

        model_options = [
            "llama-3.3-70b-versatile",
            "llama-3.1-70b-versatile",
            "llama-3.1-8b-instant",
            "mixtral-8x7b-32768",
        ]
        selected_model = st.selectbox("🤖 Select Model", model_options)
        st.markdown("---")

        st.markdown("## 📁 Document Upload")
        uploaded_files = st.file_uploader(
            "Upload documents",
            type=['pdf', 'docx', 'txt'],
            accept_multiple_files=True,
            help="Upload PDF, DOCX, or TXT files",
        )
        if uploaded_files and st.button("🚀 Process Documents"):
            with st.spinner("Processing documents..."):
                _index_uploaded_files(rag, uploaded_files)

        if st.button("🗑️ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()
    return api_key, selected_model


def _chat_bubble_html(role: str, message: str) -> str:
    """Build the HTML for one chat bubble, escaping the message text."""
    import html  # local import: only this helper needs it

    if role == "user":
        css_class, speaker = "user-message", "🙋‍♂️ You:"
    else:
        css_class, speaker = "bot-message", "🤖 Assistant:"
    # The message is user- or model-supplied and is rendered with
    # unsafe_allow_html, so it must be escaped before it is embedded.
    safe_message = html.escape(str(message)).replace("\n", "<br>")
    return (
        f'<div class="chat-message {css_class}">'
        f'<strong>{speaker}</strong><br>{safe_message}'
        '</div>'
    )


def _render_chat_history(chat_history) -> None:
    """Render every ``(role, message)`` pair as a styled chat bubble."""
    with st.container():
        for role, message in chat_history:
            st.markdown(_chat_bubble_html(role, message), unsafe_allow_html=True)


def _answer_query(rag, query: str, api_key: str, selected_model: str) -> None:
    """Answer ``query`` from the indexed documents and append it to the chat."""
    if not rag.documents:
        st.warning("⚠️ Please upload and process documents first!")
        return
    if not api_key:
        st.warning("⚠️ Please enter your Groq API key!")
        return
    with st.spinner("Searching and generating answer..."):
        relevant_docs = rag.retrieve_relevant_docs(query, k=3)
        if not relevant_docs:
            st.error("No relevant documents found for your query.")
            return
        context = "\n\n".join(doc for doc, _score in relevant_docs)
        answer = rag.generate_answer(query, context, selected_model)
    st.session_state.chat_history.append(("user", query))
    st.session_state.chat_history.append(("assistant", answer))
    # Rerun so the new messages appear in the history above the input.
    st.rerun()


def _render_status(rag, selected_model: str) -> None:
    """Render the knowledge-base status card, instructions and feature list."""
    st.markdown("## 📈 System Status")
    if rag.documents:
        st.markdown(f"""
        <div class="doc-info">
            <h4>📚 Knowledge Base</h4>
            <p><strong>Documents:</strong> {len(rag.documents)} chunks</p>
            <p><strong>Status:</strong> ✅ Ready</p>
            <p><strong>Model:</strong> {selected_model}</p>
        </div>
        """, unsafe_allow_html=True)
    else:
        st.markdown("""
        <div class="doc-info">
            <h4>📚 Knowledge Base</h4>
            <p><strong>Status:</strong> ❌ No documents loaded</p>
            <p>Upload documents to get started!</p>
        </div>
        """, unsafe_allow_html=True)

    st.markdown("""
    <div class="sidebar-info">
        <h4>📋 How to use:</h4>
        <ol>
            <li>Enter your Groq API key</li>
            <li>Upload documents (PDF, DOCX, TXT)</li>
            <li>Click "Process Documents"</li>
            <li>Ask questions about your documents</li>
        </ol>
    </div>
    """, unsafe_allow_html=True)

    st.markdown("""
    <div class="sidebar-info">
        <h4>✨ Features:</h4>
        <ul>
            <li>🚀 Fast inference with Groq</li>
            <li>🧠 Smart document chunking</li>
            <li>🔍 Semantic search</li>
            <li>💬 Chat history</li>
            <li>📱 Responsive design</li>
        </ul>
    </div>
    """, unsafe_allow_html=True)


def main():
    """Build the Streamlit page: header, sidebar, chat column, status column."""
    _render_header()

    # Per-session state: the RAG pipeline and the chat transcript.
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = RAGSystem()
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    rag = st.session_state.rag_system

    api_key, selected_model = _render_sidebar(rag)

    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("## 💬 Chat with your documents")
        _render_chat_history(st.session_state.chat_history)

        # A form submits exactly once per click/Enter and clears the input
        # afterwards. The previous check `(send_button or query) and query`
        # stayed true after st.rerun() because the text box kept its value, so
        # the same question was answered (and the API called) in an endless loop.
        with st.form("query_form", clear_on_submit=True):
            query = st.text_input(
                "Ask a question about your documents:",
                placeholder="e.g., What is the main topic discussed in the documents?",
                key="query_input",
            )
            col_send, _ = st.columns([3, 1])
            with col_send:
                send_button = st.form_submit_button("📤 Send")

        if send_button and query.strip():
            _answer_query(rag, query, api_key, selected_model)

    with col2:
        _render_status(rag, selected_model)


if __name__ == "__main__":
    main()