# rag3 / app.py
# zeeshan4801's picture
# Create app.py
# 1faeb48 verified
import streamlit as st
import os
import numpy as np
import pandas as pd
from groq import Groq
from sentence_transformers import SentenceTransformer
import faiss
import pickle
from typing import List, Dict, Any
import PyPDF2
import docx
from io import BytesIO
import tempfile
# Configure the Streamlit page before any other st.* call:
# browser-tab title, emoji favicon, and full-width layout.
st.set_page_config(page_title="RAG Chat Assistant", page_icon="🤖", layout="wide")
class RAGSystem:
    """Retrieval-Augmented Generation pipeline.

    Ingests uploaded documents (PDF/DOCX/TXT), splits them into overlapping
    word chunks, embeds them with sentence-transformers, indexes them in a
    FAISS L2 index, and answers queries by feeding the top-k retrieved
    chunks to a Groq-hosted LLM.
    """

    def __init__(self, groq_api_key: str):
        """Initialize the RAG system with Groq client and embedding model."""
        self.groq_client = Groq(api_key=groq_api_key)
        # all-MiniLM-L6-v2: small, fast sentence-embedding model.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.index = None       # FAISS index, built by process_documents()
        self.documents = []     # list of {'text': chunk, 'source': filename}
        self.embeddings = None  # numpy array of chunk embeddings

    def extract_text_from_pdf(self, file) -> str:
        """Extract text from a PDF file-like object; return "" on failure."""
        try:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                # extract_text() may return None (e.g. image-only pages);
                # coalesce to "" so concatenation cannot raise TypeError.
                text += (page.extract_text() or "") + "\n"
            return text
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
            return ""

    def extract_text_from_docx(self, file) -> str:
        """Extract text from a DOCX file-like object; return "" on failure."""
        try:
            doc = docx.Document(file)
            text = ""
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
            return text
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
            return ""

    def extract_text_from_txt(self, file) -> str:
        """Extract text from a TXT upload (assumed UTF-8); return "" on failure."""
        try:
            return str(file.read(), "utf-8")
        except Exception as e:
            st.error(f"Error reading TXT: {str(e)}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks of up to `chunk_size` words.

        Consecutive chunks share `overlap` words. The step is clamped to >= 1
        so that overlap >= chunk_size can no longer raise ValueError (step 0)
        or silently drop all text (negative step). Returns [] for empty text.
        """
        words = text.split()
        step = max(1, chunk_size - overlap)
        chunks = []
        for i in range(0, len(words), step):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk.strip():
                chunks.append(chunk.strip())
        return chunks

    def process_documents(self, uploaded_files) -> None:
        """Extract, chunk, embed, and index the uploaded files.

        Populates self.documents, self.embeddings, and self.index.
        Reports progress/errors via Streamlit widgets.
        """
        all_chunks = []
        for uploaded_file in uploaded_files:
            file_extension = uploaded_file.name.split('.')[-1].lower()
            # Dispatch on file extension to the matching extractor.
            if file_extension == 'pdf':
                text = self.extract_text_from_pdf(uploaded_file)
            elif file_extension == 'docx':
                text = self.extract_text_from_docx(uploaded_file)
            elif file_extension == 'txt':
                text = self.extract_text_from_txt(uploaded_file)
            else:
                st.error(f"Unsupported file type: {file_extension}")
                continue
            if text:
                for chunk in self.chunk_text(text):
                    all_chunks.append({
                        'text': chunk,
                        'source': uploaded_file.name
                    })
        if all_chunks:
            self.documents = all_chunks
            texts = [doc['text'] for doc in all_chunks]
            embeddings = self.embedding_model.encode(texts)
            self.embeddings = embeddings
            # FAISS flat L2 index: exact (brute-force) nearest-neighbor search.
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype('float32'))
            st.success(f"Processed {len(all_chunks)} chunks from {len(uploaded_files)} documents")
        else:
            st.error("No text could be extracted from the uploaded files")

    def retrieve_relevant_chunks(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Return up to k chunks most similar to `query`, ranked by distance.

        Each result dict has 'text', 'source', 'similarity_score', 'rank'.
        Returns [] if no index has been built yet.
        """
        if self.index is None:
            return []
        query_embedding = self.embedding_model.encode([query])
        distances, indices = self.index.search(query_embedding.astype('float32'), k)
        relevant_chunks = []
        for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
            # FAISS pads results with -1 when fewer than k vectors exist;
            # `idx < len(...)` alone would wrongly accept -1 (last element),
            # so also require idx >= 0.
            if 0 <= idx < len(self.documents):
                relevant_chunks.append({
                    'text': self.documents[idx]['text'],
                    'source': self.documents[idx]['source'],
                    'similarity_score': 1 / (1 + distance),  # Convert distance to similarity
                    'rank': i + 1
                })
        return relevant_chunks

    def generate_response(self, query: str, relevant_chunks: List[Dict[str, Any]]) -> str:
        """Generate a grounded answer with Groq using the retrieved context.

        On any API failure the error is returned as the response string
        (deliberate best-effort: the chat UI displays it instead of crashing).
        """
        context = "\n\n".join([f"Source: {chunk['source']}\nContent: {chunk['text']}"
                               for chunk in relevant_chunks])
        prompt = f"""Based on the following context, please answer the question accurately and comprehensively.
Context:
{context}
Question: {query}
Instructions:
- Use only the information provided in the context to answer the question
- If the context doesn't contain enough information to answer the question, say so
- Cite the sources when possible
- Be concise but comprehensive
Answer:"""
        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided context. Always be accurate and cite your sources."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model="llama-3.3-70b-versatile",
                temperature=0.1,  # low temperature: favor faithful, grounded answers
                max_tokens=1024
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            return f"Error generating response: {str(e)}"
def main():
    """Streamlit entry point: sidebar configuration/upload, chat UI, and
    a retrieved-context panel.

    Layout: a sidebar holds the Groq API key, document uploader, and
    retrieval settings; the main area is a 2:1 column split between the
    chat and the retrieved-context expander list.
    """
    st.title("🤖 RAG Chat Assistant")
    st.markdown("Upload documents and ask questions using Retrieval Augmented Generation")

    # Session state survives Streamlit reruns; initialize once.
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = None
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    with st.sidebar:
        st.header("Configuration")
        # Pre-fill from the environment so deployments can inject the key.
        groq_api_key = st.text_input(
            "Groq API Key",
            value=os.getenv("GROQ_API_KEY", ""),
            type="password",
            help="Enter your Groq API key"
        )
        # Lazily build the RAG system the first time a key is provided.
        if groq_api_key and st.session_state.rag_system is None:
            try:
                st.session_state.rag_system = RAGSystem(groq_api_key)
                st.success("RAG System initialized!")
            except Exception as e:
                st.error(f"Error initializing RAG system: {str(e)}")

        st.header("Document Upload")
        uploaded_files = st.file_uploader(
            "Upload documents",
            accept_multiple_files=True,
            type=['pdf', 'txt', 'docx'],
            help="Upload PDF, TXT, or DOCX files"
        )
        if uploaded_files and st.session_state.rag_system:
            if st.button("Process Documents"):
                with st.spinner("Processing documents..."):
                    st.session_state.rag_system.process_documents(uploaded_files)

        st.header("Retrieval Settings")
        num_chunks = st.slider("Number of chunks to retrieve", 1, 10, 3)

    # Main area: chat on the left, retrieved context on the right.
    col1, col2 = st.columns([2, 1])
    with col1:
        st.header("Chat")
        for i, (question, answer) in enumerate(st.session_state.chat_history):
            st.write(f"**You:** {question}")
            st.write(f"**Assistant:** {answer}")
            st.divider()

        query = st.text_input("Ask a question about your documents:", key="query_input")
        if st.button("Ask") and query:
            # Guard clauses: need an initialized system and processed documents.
            if not st.session_state.rag_system:
                st.error("Please enter a valid Groq API key first")
            elif not st.session_state.rag_system.documents:
                st.error("Please upload and process documents first")
            else:
                with st.spinner("Generating response..."):
                    relevant_chunks = st.session_state.rag_system.retrieve_relevant_chunks(
                        query, k=num_chunks
                    )
                    if relevant_chunks:
                        response = st.session_state.rag_system.generate_response(query, relevant_chunks)
                        st.session_state.chat_history.append((query, response))
                        st.write(f"**You:** {query}")
                        st.write(f"**Assistant:** {response}")
                        # Show the supporting chunks in the right-hand column.
                        with col2:
                            st.header("Retrieved Context")
                            for chunk in relevant_chunks:
                                with st.expander(f"Rank {chunk['rank']} - {chunk['source']}"):
                                    st.write(f"**Similarity:** {chunk['similarity_score']:.3f}")
                                    st.write(f"**Text:** {chunk['text'][:200]}...")
                    else:
                        st.error("No relevant information found in the documents")

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            # st.experimental_rerun was deprecated and removed in modern
            # Streamlit releases; prefer st.rerun, falling back for old versions.
            rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
            rerun()

    with col2:
        # Placeholder shown until the first question populates the panel.
        if not st.session_state.chat_history:
            st.header("Retrieved Context")
            st.info("Ask a question to see retrieved context here")
# Script entry point (intended to be launched via `streamlit run app.py`).
if __name__ == "__main__":
    main()