# app.py — RAG-based PDF Q&A Streamlit app
# (Hugging Face Space by RizwanSajad, commit 76bb0a6 verified)
import os
import streamlit as st
import PyPDF2
from groq import Groq
from sentence_transformers import SentenceTransformer
import faiss

# Initialize Groq client.
# SECURITY: the API key must come from the environment, never from source —
# a key committed to a public repo is compromised and must be revoked.
client = Groq(api_key=os.environ.get("GROQ_API_KEY", ""))

# Streamlit frontend.
st.title("RAG-Based Q&A App")
uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
user_question = st.text_input("Ask a question:")


@st.cache_resource
def _load_embedding_model():
    """Load the sentence-embedding model once per process.

    Streamlit re-executes the whole script on every interaction; without
    caching, the model would be re-downloaded/re-instantiated on each rerun.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')


embedding_model = _load_embedding_model()

# Initialize FAISS index (L2 distance); dimension must match the embedding
# model's output size (384 for all-MiniLM-L6-v2).
dimension = 384
faiss_index = faiss.IndexFlatL2(dimension)

# Parallel stores: doc_chunks[i] is the text whose embedding sits at FAISS id i.
doc_chunks = []
chunk_embeddings = []
# PDF Text Extraction
def extract_text_from_pdf(file):
    """Return the concatenated text of every page in the uploaded PDF.

    Args:
        file: A binary file-like object (e.g. Streamlit's UploadedFile).

    Returns:
        str: All extractable page text joined together. Pages with no
        extractable text (e.g. scanned images) contribute an empty string —
        PyPDF2's extract_text() may return None, which would otherwise
        crash string concatenation.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # join() avoids quadratic repeated string concatenation on large PDFs.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
# Chunk Text
def chunk_text(text, chunk_size=500):
    """Split *text* into chunks of at most *chunk_size* whitespace-separated words.

    Args:
        text: The raw document text.
        chunk_size: Maximum number of words per chunk (default 500).

    Returns:
        list[str]: Word chunks in document order; empty list for empty text.
    """
    tokens = text.split()
    chunks = []
    for start in range(0, len(tokens), chunk_size):
        chunks.append(" ".join(tokens[start:start + chunk_size]))
    return chunks
# Embed and Index Chunks
def index_chunks(chunks):
    """Embed text chunks and add them to the global FAISS index.

    Args:
        chunks: List of text chunks to embed.

    Returns:
        The (n_chunks, 384) float32 embedding matrix that was indexed.
    """
    embeddings = embedding_model.encode(chunks)
    # FAISS only accepts float32 input; sentence-transformers usually returns
    # float32 already, but coerce explicitly so a dtype change upstream
    # cannot silently break indexing.
    embeddings = embeddings.astype("float32")
    faiss_index.add(embeddings)
    return embeddings
# Retrieve Relevant Chunks
def retrieve_relevant_chunks(question, top_k=2):
    """Return the *top_k* indexed chunks most similar to *question*.

    Args:
        question: The user's query string.
        top_k: Number of nearest chunks to retrieve (default 2).

    Returns:
        list[str]: Matching chunks, nearest first; may be shorter than
        top_k when fewer chunks are indexed.
    """
    question_embedding = embedding_model.encode([question])
    distances, indices = faiss_index.search(question_embedding, top_k)
    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; -1 is a *valid* Python index (last element), so without this
    # guard we would silently return a duplicate/wrong chunk — or crash
    # with IndexError on an empty index.
    return [doc_chunks[i] for i in indices[0] if 0 <= i < len(doc_chunks)]
# Get Groq Response
def generate_answer(prompt):
    """Send *prompt* to the Groq chat API and return the model's reply text.

    Args:
        prompt: The full prompt string (context + question).

    Returns:
        str: The assistant message content from the first completion choice.
    """
    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": prompt}],
    )
    first_choice = response.choices[0]
    return first_choice.message.content
# Application Logic
if uploaded_file:
    # Step 1: Extract text from the uploaded PDF.
    with st.spinner("Extracting text from PDF..."):
        document_text = extract_text_from_pdf(uploaded_file)

    # Step 2: Split into word chunks.
    doc_chunks = chunk_text(document_text)

    # Step 3: Embed and index the chunks.
    with st.spinner("Creating embeddings and indexing..."):
        chunk_embeddings = index_chunks(doc_chunks)
    st.success("PDF Processed and Indexed!")

    # Answer questions only once a document is indexed: searching an empty
    # FAISS index yields -1 indices and would crash chunk lookup.
    if user_question:
        # Step 4: Retrieve the most relevant chunks for the question.
        with st.spinner("Retrieving relevant context..."):
            relevant_chunks = retrieve_relevant_chunks(user_question)
            context = " ".join(relevant_chunks)

        # Step 5: Generate the answer from context + question.
        with st.spinner("Generating answer..."):
            answer = generate_answer(f"Context: {context}\n\nQuestion: {user_question}")
            st.write("### Answer:")
            st.write(answer)
elif user_question:
    st.info("Please upload a PDF before asking a question.")