# Pdf_reader / app.py
# Author: Ahmed12322 (Hugging Face Space; commit 927fe6a, verified)
import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
import os
import pypdf
from langchain.text_splitter import RecursiveCharacterTextSplitter
# --- Session-state defaults --------------------------------------------------
# Make sure the FAISS index and chunk list exist before any rerun touches them.
st.session_state.setdefault("faiss_index", None)
st.session_state.setdefault("chunks", [])
# --- Groq API key ------------------------------------------------------------
# Read the key from the environment first, then from Streamlit secrets.
# SECURITY: the previous version embedded a live API key in source as the
# fallback value; that key must be treated as compromised and revoked.
# Never commit secrets — use env vars or .streamlit/secrets.toml.
GROQ_API_KEY = os.getenv("GROQ_API_KEY") or st.secrets.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    st.error("⚠️ GROQ_API_KEY is missing! Please set it in your environment variables or secrets.toml file.")
    st.stop()
# --- Embedding model ---------------------------------------------------------
# Load the sentence-transformer used to embed both document chunks and queries.
# If the model cannot be initialised, surface the error and halt the app.
try:
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
except Exception as e:
    st.error(f"❌ Failed to load embedding model: {str(e)}")
    st.stop()
# --- Groq client -------------------------------------------------------------
# Construct the chat-completion client; halt the app if initialisation fails.
try:
    client = Groq(api_key=GROQ_API_KEY)
except Exception as e:
    st.error(f"❌ Failed to initialize Groq client: {str(e)}")
    st.stop()
# Function to extract text from PDF with error handling
def extract_text_from_pdf(uploaded_file):
    """Extract all page text from an uploaded PDF.

    Args:
        uploaded_file: File-like object (Streamlit ``UploadedFile``) holding a PDF.

    Returns:
        The newline-joined text of all pages with extractable text, or ""
        when nothing could be extracted or reading failed (the error is
        shown in the Streamlit UI rather than raised).
    """
    try:
        reader = pypdf.PdfReader(uploaded_file)
        # Call extract_text() once per page — the original called it twice
        # (once in the filter, once in the value), doubling the parse work.
        extracted_text = [text for page in reader.pages if (text := page.extract_text())]
        return "\n".join(extracted_text) if extracted_text else ""
    except Exception as e:
        st.error(f"❌ Error extracting text from PDF: {str(e)}")
        return ""
# Function to create text chunks
def create_chunks(text, chunk_size=500, chunk_overlap=100):
    """Split *text* into overlapping character chunks for embedding.

    Args:
        text: Full document text.
        chunk_size: Target chunk length in characters.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        A list of chunk strings.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        # Prefer paragraph, then line, then word boundaries before
        # falling back to a hard character split.
        separators=["\n\n", "\n", " ", ""],
    )
    return splitter.split_text(text)
# Function to create and save FAISS index
def create_faiss_index(chunks):
    """Embed *chunks* and build an exhaustive (flat) L2 FAISS index over them.

    Args:
        chunks: List of text chunks to embed.

    Returns:
        ``(index, chunks)`` on success, or ``(None, [])`` when embedding or
        indexing fails (the error is shown in the Streamlit UI).
    """
    try:
        vectors = embedding_model.encode(chunks, convert_to_numpy=True)
        # IndexFlatL2 performs exact nearest-neighbour search — fine for
        # the small per-document indexes this app builds.
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        return index, chunks
    except Exception as e:
        st.error(f"❌ Error creating FAISS index: {str(e)}")
        return None, []
# Function to search FAISS
def search_faiss(query, index, chunks, top_k=2):
    """Return up to *top_k* chunks most similar to *query*.

    Args:
        query: User question to embed and search with.
        index: FAISS index built over *chunks*, or None if no PDF is loaded.
        chunks: Text chunks aligned row-for-row with the index.
        top_k: Number of nearest neighbours to retrieve.

    Returns:
        A list of matching chunk strings (possibly empty).
    """
    if index is None or not chunks:
        return []
    try:
        query_embedding = embedding_model.encode([query], convert_to_numpy=True)
        distances, indices = index.search(query_embedding, top_k)
        # FAISS pads missing neighbours with -1. The original filter
        # `i < len(chunks)` let -1 through, silently returning the *last*
        # chunk via negative indexing; require a non-negative index.
        return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    except Exception as e:
        st.error(f"❌ Search error: {str(e)}")
        return []
# Function to query Groq with enhanced prompt
def query_groq(query, context=None):
    """Ask the Groq LLM *query*, optionally grounded in retrieved *context*.

    Args:
        query: The user's question.
        context: Retrieved document text to ground the answer, or None.

    Returns:
        The model's answer string, or an error-message string on failure
        (callers render the return value either way).
    """
    try:
        prompt = f"""Use the following context to answer the question.
If you don't know the answer, say you don't know. Don't make up answers.
Context: {context if context else 'No specific context provided'}
Question: {query}
Answer:"""
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            # "llama-3-70b-8192" is not a valid Groq model id (the request
            # fails with a model-not-found error); the published id has no
            # dash after "llama" — verify against Groq's current model list.
            model="llama3-70b-8192",
            temperature=0.3,  # low temperature: favour grounded answers
            max_tokens=1024
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        return f"Error querying Groq: {str(e)}"
# Streamlit UI
# --- Page chrome & sidebar settings ------------------------------------------
st.set_page_config(page_title="RAG Chatbot", page_icon="πŸ€–", layout="wide")
st.title("πŸ“„ RAG-Based Chatbot with FAISS & Groq")

# Retrieval/chunking knobs, adjustable at runtime from the sidebar.
with st.sidebar:
    st.header("Settings")
    top_k = st.slider("Number of chunks to retrieve", 1, 5, 2)
    chunk_size = st.slider("Chunk size (characters)", 200, 1000, 500)
    chunk_overlap = st.slider("Chunk overlap (characters)", 0, 200, 100)
# Upload PDF
# --- PDF upload & indexing ---------------------------------------------------
uploaded_file = st.file_uploader("πŸ“€ Upload a PDF file", type="pdf")
if uploaded_file:
    with st.spinner("πŸ”„ Processing PDF..."):
        text = extract_text_from_pdf(uploaded_file)
        if not text.strip():
            st.error("❌ No text found in the uploaded PDF.")
        else:
            chunks = create_chunks(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
            index, chunks = create_faiss_index(chunks)
            # Persist across Streamlit reruns so chat queries can search
            # the indexed document.
            st.session_state["faiss_index"] = index
            st.session_state["chunks"] = chunks
            st.success(f"βœ… PDF processed successfully! Created {len(chunks)} chunks.")
# Chat interface
# --- Chat interface ----------------------------------------------------------
st.session_state.setdefault("messages", [])

# Replay the conversation so far (Streamlit reruns the script per interaction).
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle a new question: record it, retrieve context, ask the LLM, record reply.
prompt = st.chat_input("πŸ’¬ Ask me something about the document:")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner("πŸ”Ž Retrieving response..."):
        retrieved_text = search_faiss(
            prompt,
            st.session_state["faiss_index"],
            st.session_state["chunks"],
            top_k=top_k,
        )
        context = "\n".join(retrieved_text) if retrieved_text else "No relevant context found."
        response = query_groq(prompt, context)

    st.session_state.messages.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)