# rag2 / app.py
# anasfsd123's picture
# Create app.py
# c392441 verified
import os
import streamlit as st
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from groq import Groq
import faiss
import pickle
from typing import List, Dict, Tuple
import PyPDF2
import docx
from io import BytesIO
import time
# Initialize Groq client
def init_groq_client(api_key: str):
    """Create and return a Groq API client authenticated with *api_key*."""
    client = Groq(api_key=api_key)
    return client
# Initialize embedding model
@st.cache_resource
def load_embedding_model():
    """Load the all-MiniLM-L6-v2 sentence-transformer (cached across Streamlit reruns)."""
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)
# Document processing functions
def extract_text_from_pdf(file):
    """Extract and concatenate the text of every page in a PDF.

    Args:
        file: Binary file-like object containing the PDF.

    Returns:
        All extractable page text joined into one string. Pages with no
        text layer (e.g. scanned images) contribute nothing.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    parts = []
    for page in pdf_reader.pages:
        # extract_text() may return None for pages without a text layer;
        # the original `text += page.extract_text()` raised TypeError there.
        parts.append(page.extract_text() or "")
    # join() avoids the quadratic cost of repeated string +=
    return "".join(parts)
def extract_text_from_docx(file):
    """Return the full text of a DOCX file, one paragraph per line."""
    document = docx.Document(file)
    # Each paragraph is followed by a newline, matching per-paragraph append.
    return "".join(paragraph.text + "\n" for paragraph in document.paragraphs)
def extract_text_from_txt(file):
    """Decode an uploaded plain-text file as UTF-8 and return its contents."""
    raw_bytes = file.read()
    return raw_bytes.decode("utf-8")
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping word-based chunks.

    Args:
        text: Raw text to split (whitespace-tokenized).
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings; empty list for empty/whitespace-only text.

    Raises:
        ValueError: If overlap >= chunk_size — previously this produced a
            confusing `range() arg 3 must not be zero` error (overlap ==
            chunk_size) or silently returned [] (overlap > chunk_size).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks = []
    step = chunk_size - overlap
    for i in range(0, len(words), step):
        chunks.append(' '.join(words[i:i + chunk_size]))
        # Stop once the final words are emitted so no trailing chunk made
        # purely of already-seen overlap words is produced.
        if i + chunk_size >= len(words):
            break
    return chunks
# Vector store class
class VectorStore:
    """In-memory chunk store with FAISS cosine-similarity search.

    Embeddings are L2-normalized and searched with an inner-product index
    (IndexFlatIP), so returned scores are cosine similarities.
    """

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model  # any object with .encode(list[str]) -> ndarray
        self.documents = []   # chunk strings, row-aligned with self.embeddings
        self.embeddings = []  # becomes an (n_docs, dim) ndarray after first add
        self.index = None     # FAISS index, rebuilt on every add_documents()

    def add_documents(self, documents: List[str]):
        """Embed *documents*, append them to the store, and rebuild the index."""
        if not documents:
            return  # nothing to encode; avoids an empty encode() batch
        self.documents.extend(documents)
        new_embeddings = self.embedding_model.encode(documents)
        if len(self.embeddings) == 0:
            self.embeddings = new_embeddings
        else:
            self.embeddings = np.vstack([self.embeddings, new_embeddings])
        self._build_index()

    def _build_index(self):
        """(Re)build the FAISS inner-product index over normalized embeddings."""
        if len(self.embeddings) > 0:
            dimension = self.embeddings.shape[1]
            # Inner product on unit vectors == cosine similarity
            self.index = faiss.IndexFlatIP(dimension)
            norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
            normalized_embeddings = self.embeddings / norms
            self.index.add(normalized_embeddings.astype('float32'))

    def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """Return up to *top_k* (document, cosine_score) pairs for *query*."""
        if self.index is None or len(self.documents) == 0:
            return []
        query_embedding = self.embedding_model.encode([query])
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
        scores, indices = self.index.search(query_embedding.astype('float32'), top_k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads short result sets with idx == -1; the previous
            # `idx < len(self.documents)` check let -1 through and silently
            # returned the *last* document via Python negative indexing.
            if 0 <= idx < len(self.documents):
                results.append((self.documents[idx], float(score)))
        return results

    def save(self, filepath: str):
        """Persist documents and embeddings to *filepath* via pickle."""
        data = {
            'documents': self.documents,
            'embeddings': self.embeddings.tolist() if len(self.embeddings) > 0 else []
        }
        with open(filepath, 'wb') as f:
            pickle.dump(data, f)

    def load(self, filepath: str):
        """Restore documents/embeddings from *filepath* and rebuild the index.

        SECURITY: pickle.load executes arbitrary code during deserialization —
        only load files this app itself saved, never untrusted input.
        """
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        self.documents = data['documents']
        if data['embeddings']:
            self.embeddings = np.array(data['embeddings'])
            self._build_index()
# RAG class
class RAGSystem:
    """Retrieval-augmented generation: vector search + Groq chat completion."""

    def __init__(self, groq_client, embedding_model):
        self.groq_client = groq_client
        self.vector_store = VectorStore(embedding_model)

    def add_documents(self, documents: List[str]):
        """Add document chunks to the knowledge base."""
        self.vector_store.add_documents(documents)

    def query(self, question: str, model: str = "llama-3.3-70b-versatile", top_k: int = 3) -> Dict:
        """Answer *question* using retrieved context and the Groq API.

        Args:
            question: Natural-language question.
            model: Groq model identifier.
            top_k: Number of chunks to retrieve as context.

        Returns:
            Dict with keys:
                answer: Model response, or a fallback/error message.
                sources: List of {"text", "score"} dicts for retrieved chunks.
                confidence: Highest retrieval similarity (0.0 when none).
        """
        # Retrieve relevant documents
        retrieved_docs = self.vector_store.search(question, top_k=top_k)
        if not retrieved_docs:
            return {
                "answer": "I don't have any relevant information to answer your question.",
                "sources": [],
                "confidence": 0.0,
            }
        # Prepare context
        context = "\n\n".join(doc for doc, _ in retrieved_docs)
        # Create prompt
        prompt = f"""Based on the following context, answer the question. If the answer is not in the context, say "I don't have enough information to answer this question."
Context:
{context}
Question: {question}
Answer:"""
        # Build source previews; only append an ellipsis when the text was
        # actually truncated (the original always added "..." even for
        # chunks shorter than 200 characters).
        sources = []
        for doc, score in retrieved_docs:
            preview = doc[:200] + "..." if len(doc) > 200 else doc
            sources.append({"text": preview, "score": score})
        try:
            # Get response from Groq
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model=model,
                temperature=0.1,
                max_tokens=1000,
            )
            answer = chat_completion.choices[0].message.content
            # retrieved_docs is non-empty here, so max() is safe (the
            # original's `if retrieved_docs else 0.0` branch was dead code).
            return {
                "answer": answer,
                "sources": sources,
                "confidence": max(score for _, score in retrieved_docs),
            }
        except Exception as e:
            # Broad catch is intentional at this boundary: surface the API
            # error to the UI instead of crashing the Streamlit app.
            return {
                "answer": f"Error generating response: {str(e)}",
                "sources": [],
                "confidence": 0.0,
            }
# Streamlit App
def main():
    """Render the Streamlit app: sidebar config, document upload, and RAG chat."""
    st.set_page_config(
        page_title="RAG App with Groq",
        page_icon="πŸ€–",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    st.title("πŸ€– RAG App with Groq & Sentence Transformers")
    st.markdown("Ask questions about your documents using open-source models!")
    # Sidebar
    st.sidebar.header("βš™οΈ Configuration")
    # API key — Option 1: pre-filled from Streamlit secrets when deployed
    api_key = st.sidebar.text_input(
        "Groq API Key",
        value=st.secrets.get("GROQ_API_KEY", ""),
        type="password",
        help="Enter your Groq API key"
    )
    # Option 2: Fallback to environment variable (useful for local dev)
    if not api_key:
        api_key = os.getenv("GROQ_API_KEY")
    # Model selection
    model_options = [
        "llama-3.3-70b-versatile",
        "llama-3.1-70b-versatile",
        "llama-3.1-8b-instant",
        "mixtral-8x7b-32768"
    ]
    selected_model = st.sidebar.selectbox("Select Model", model_options)
    # Number of retrieved documents
    top_k = st.sidebar.slider("Number of retrieved documents", 1, 10, 3)
    # Initialize components
    if api_key:
        try:
            groq_client = init_groq_client(api_key)
            embedding_model = load_embedding_model()
            # Keep one RAGSystem per browser session in session_state so the
            # uploaded-document knowledge base survives Streamlit reruns.
            if 'rag_system' not in st.session_state:
                st.session_state.rag_system = RAGSystem(groq_client, embedding_model)
            # Main content area: upload pane (left) and chat pane (right)
            col1, col2 = st.columns([1, 1])
            with col1:
                st.header("πŸ“ Document Upload")
                uploaded_files = st.file_uploader(
                    "Upload your documents",
                    type=['pdf', 'docx', 'txt'],
                    accept_multiple_files=True,
                    help="Supported formats: PDF, DOCX, TXT"
                )
                if uploaded_files:
                    if st.button("Process Documents", type="primary"):
                        with st.spinner("Processing documents..."):
                            all_chunks = []
                            for file in uploaded_files:
                                # Extract text based on the MIME type reported
                                # by the upload widget
                                if file.type == "application/pdf":
                                    text = extract_text_from_pdf(file)
                                elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                                    text = extract_text_from_docx(file)
                                elif file.type == "text/plain":
                                    text = extract_text_from_txt(file)
                                else:
                                    st.error(f"Unsupported file type: {file.type}")
                                    continue
                                # Chunk the text
                                chunks = chunk_text(text, chunk_size=500, overlap=50)
                                all_chunks.extend(chunks)
                                st.success(f"βœ… Processed {file.name}: {len(chunks)} chunks")
                            # Add to RAG system
                            if all_chunks:
                                st.session_state.rag_system.add_documents(all_chunks)
                                st.success(f"πŸŽ‰ Added {len(all_chunks)} chunks to knowledge base!")
                # Display document stats
                if hasattr(st.session_state.rag_system, 'vector_store') and len(st.session_state.rag_system.vector_store.documents) > 0:
                    st.info(f"πŸ“Š Knowledge Base: {len(st.session_state.rag_system.vector_store.documents)} chunks")
            with col2:
                st.header("πŸ’¬ Ask Questions")
                # Chat interface: history is kept in session_state.messages
                if "messages" not in st.session_state:
                    st.session_state.messages = []
                # Display chat history (replayed on every rerun)
                for message in st.session_state.messages:
                    with st.chat_message(message["role"]):
                        st.write(message["content"])
                        if message["role"] == "assistant" and "sources" in message:
                            with st.expander("πŸ“š Sources"):
                                for i, source in enumerate(message["sources"]):
                                    st.write(f"**Source {i+1}** (Score: {source['score']:.3f})")
                                    st.write(source["text"])
                # Chat input (walrus: prompt is None until the user submits)
                if prompt := st.chat_input("Ask a question about your documents..."):
                    # Add user message
                    st.session_state.messages.append({"role": "user", "content": prompt})
                    with st.chat_message("user"):
                        st.write(prompt)
                    # Generate response
                    with st.chat_message("assistant"):
                        with st.spinner("Thinking..."):
                            response = st.session_state.rag_system.query(
                                prompt,
                                model=selected_model,
                                top_k=top_k
                            )
                            st.write(response["answer"])
                            # Show sources
                            if response["sources"]:
                                with st.expander("πŸ“š Sources"):
                                    for i, source in enumerate(response["sources"]):
                                        st.write(f"**Source {i+1}** (Score: {source['score']:.3f})")
                                        st.write(source["text"])
                    # Add to chat history
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": response["answer"],
                        "sources": response["sources"]
                    })
            # Clear chat button
            if st.button("πŸ—‘οΈ Clear Chat"):
                st.session_state.messages = []
                st.rerun()
        except Exception as e:
            # NOTE(review): broad catch keeps the UI alive on any init failure
            st.error(f"Error initializing components: {str(e)}")
    else:
        st.warning("Please enter your Groq API key in the sidebar to get started.")
    # Footer
    st.sidebar.markdown("---")
    st.sidebar.markdown(
        """
    **About this app:**
    - Uses Groq for fast inference
    - Sentence Transformers for embeddings
    - FAISS for vector search
    - Supports PDF, DOCX, TXT files
    """
    )
if __name__ == "__main__":
    # Entry point when executed via `streamlit run app.py`
    main()