# Author: Muqadas-13
# Commit: 29be773 ("Update app.py", verified)
import html
import os

import streamlit as st
from PyPDF2 import PdfReader
from docx import Document
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
# =========================
# ✅ Initialize Groq client
# =========================
# The API key comes from the environment rather than being committed to source.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# =========================
# ✅ Load embedding model
# =========================
@st.cache_resource
def _load_embed_model(model_name="all-MiniLM-L6-v2"):
    """Load the SentenceTransformer once and reuse it across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction; without
    caching, the model would be reloaded from disk each time a button is pressed.
    """
    return SentenceTransformer(model_name)


embed_model = _load_embed_model()
# =========================
# ✅ Initialize FAISS index
# =========================
# Size the index from the embedding model instead of hard-coding 384, so that
# swapping the model cannot cause a dimension mismatch inside INDEX.add().
# Deliberately NOT cached: the upload section below re-indexes the current file
# on every script run, so a fresh index per run keeps INDEX and stored_chunks
# from accumulating duplicate entries.
INDEX = faiss.IndexFlatL2(embed_model.get_sentence_embedding_dimension())
# Chunk texts, aligned row-for-row with the vectors added to INDEX.
stored_chunks = []
# =========================
# ✅ Streamlit UI Styling
# =========================
# Page-wide CSS: title look, the answer "card", and the page background.
_PAGE_CSS = """
<style>
.main-title {
font-size: 40px;
color: #2E86C1;
font-weight: bold;
text-align: center;
margin-bottom: 30px;
}
.card {
background-color: #ffffff;
padding: 20px;
border-radius: 15px;
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
margin-top: 20px;
}
body {
background-color: #f8fbfd;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Centered page heading, styled by the .main-title rule above.
_TITLE_HTML = '<div class="main-title">πŸ“„ Smart RAG Document QA Assistant</div>'
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
# =========================
# ✅ Extract text from files
# =========================
_DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"

# Extension fallback used when the upload's MIME type is missing or unrecognized.
_EXT_KINDS = {".pdf": "pdf", ".docx": "docx", ".txt": "text"}


def _detect_kind(file):
    """Classify an upload as "pdf", "docx", "text", or "" (unsupported)."""
    mime = getattr(file, "type", None) or ""
    if mime == "application/pdf":
        return "pdf"
    if mime == _DOCX_MIME:
        return "docx"
    if mime.startswith("text"):
        return "text"
    # Browsers may report an empty or generic MIME type (for example
    # application/octet-stream), so fall back to the file extension.
    name = (getattr(file, "name", None) or "").lower()
    return _EXT_KINDS.get(os.path.splitext(name)[1], "")


def _read_bytes(file):
    """Return the whole byte content of *file*, regardless of its read position."""
    getvalue = getattr(file, "getvalue", None)
    if callable(getvalue):
        return getvalue()
    return file.read()


def extract_text(file):
    """Return the plain text of an uploaded PDF, DOCX, or text file.

    The format is chosen from the upload's MIME type (``file.type``). When that
    is missing or not recognized, the filename extension is used instead.

    Args:
        file: A file-like upload (e.g. the object returned by
            ``st.file_uploader``) exposing ``type``, ``name`` and ``read()``.

    Returns:
        The extracted text, or ``""`` for unsupported formats.
    """
    kind = _detect_kind(file)
    if kind == "pdf":
        reader = PdfReader(file)
        # extract_text() may return None for pages without a text layer.
        return " ".join(page.extract_text() or "" for page in reader.pages)
    if kind == "docx":
        doc = Document(file)
        return "\n".join(p.text for p in doc.paragraphs)
    if kind == "text":
        # utf-8-sig drops a leading BOM; errors="replace" keeps a stray
        # non-UTF-8 byte from aborting the whole upload.
        return _read_bytes(file).decode("utf-8-sig", errors="replace")
    return ""
# =========================
# ✅ Chunk text for embedding
# =========================
def chunk_text(text, chunk_size=200, overlap=0):
    """Split *text* into whitespace-delimited word chunks.

    Args:
        text: Input text; any run of whitespace separates words.
        chunk_size: Maximum number of words per chunk (must be > 0).
        overlap: Number of trailing words from one chunk repeated at the start
            of the next (0 <= overlap < chunk_size). The default of 0 yields
            disjoint chunks.

    Returns:
        List of chunk strings with words joined by single spaces; ``[]`` for
        empty or whitespace-only text.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got {overlap}")
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        # Stop once a chunk reaches the end; with overlap, any later window
        # would just be a suffix of this one.
        if start + chunk_size >= len(words):
            break
    return chunks
# =========================
# ✅ Store vector embeddings in FAISS
# =========================
def store_embeddings(chunks):
    """Embed *chunks* and append them to ``INDEX`` and ``stored_chunks``.

    Row ``i`` of ``INDEX`` and ``stored_chunks[i]`` describe the same chunk.
    The texts are recorded only after the vectors were added successfully, so
    the two stay aligned.

    Args:
        chunks: List of text chunks to index. An empty list is a no-op.
    """
    if not chunks:
        # Nothing to index (e.g. a document with no extractable text). Encoding
        # an empty list is wasted work and presumably does not yield the 2-D
        # (n, dim) matrix that FAISS add() expects.
        return
    vectors = embed_model.encode(chunks)
    # FAISS requires a C-contiguous float32 matrix; this only copies if needed.
    INDEX.add(np.ascontiguousarray(vectors, dtype=np.float32))
    stored_chunks.extend(chunks)
# =========================
# ✅ Retrieve top-k similar chunks
# =========================
def retrieve_similar_chunks(query, top_k=3):
    """Return up to *top_k* stored chunks closest to *query* by L2 distance.

    Args:
        query: Question text to embed and search with.
        top_k: Maximum number of chunks to return.

    Returns:
        Matching chunk strings in FAISS result order (nearest first). The list
        is shorter than *top_k* when the index holds fewer vectors, and ``[]``
        when the index is empty or ``top_k <= 0``.
    """
    # FAISS pads results with id -1 when k exceeds the number of stored vectors,
    # and stored_chunks[-1] would silently repeat the last chunk. Clamp k
    # instead.
    k = min(top_k, INDEX.ntotal)
    if k <= 0:
        return []
    query_vector = embed_model.encode([query])
    _, indices = INDEX.search(np.ascontiguousarray(query_vector, dtype=np.float32), k)
    # Defensive filter: drop any padding/invalid ids rather than wrap around.
    return [stored_chunks[i] for i in indices[0] if 0 <= i < len(stored_chunks)]
# =========================
# ✅ Ask Groq LLM using context
# =========================
# "llama3-13b" was never a Groq model id, so every request failed. Groq retires
# model ids over time; override with the GROQ_MODEL environment variable or
# the `model` argument instead of editing code.
_DEFAULT_GROQ_MODEL = "llama-3.1-8b-instant"


def get_llm_answer(query, context, model=None):
    """Ask the Groq chat model to answer *query* using the retrieved *context*.

    Args:
        query: The user's question.
        context: Retrieved document text to ground the answer.
        model: Groq model id. Defaults to the GROQ_MODEL environment variable,
            falling back to ``_DEFAULT_GROQ_MODEL``.

    Returns:
        The model's reply text (``""`` if the response carried no content).
    """
    if model is None:
        model = os.environ.get("GROQ_MODEL", _DEFAULT_GROQ_MODEL)
    prompt = f"Answer the question based on the following context:\n\n{context}\n\nQuestion: {query}"
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model,
    )
    return chat_completion.choices[0].message.content or ""
# =========================
# ✅ Streamlit File Uploader
# =========================
# Document input (PDF / Word / plain text) and the free-text question box.
uploaded_file = st.file_uploader(
    label="πŸ“ Upload your document",
    type=["pdf", "docx", "txt"],
)
query = st.text_input(label="πŸ’¬ Ask a question about your document")
# =========================
# ✅ Process document and index
# =========================
# Streamlit reruns the script on every interaction, so the uploaded file is
# re-indexed into this run's fresh INDEX each time.
if uploaded_file:
    with st.spinner("Processing file..."):
        text = extract_text(uploaded_file)
        chunks = chunk_text(text)
        if chunks:
            store_embeddings(chunks)
    if chunks:
        st.success("βœ… Document uploaded and indexed!")
    else:
        # extract_text returns "" for unsupported types; image-only PDFs also
        # produce no text. Don't claim success when nothing was indexed.
        st.warning("No extractable text was found in this document.")
# =========================
# ✅ Ask question and get answer
# =========================
ask_clicked = st.button("πŸ§  Get Answer")
if ask_clicked:
    if not (query or "").strip():
        st.warning("Please enter a question first!")
    elif not stored_chunks:
        st.warning("Please upload and process a document first!")
    else:
        with st.spinner("Thinking..."):
            context = "\n\n".join(retrieve_similar_chunks(query))
            answer = get_llm_answer(query, context)
        # The answer is model output that can echo uploaded document text, so
        # escape it before embedding it in raw HTML. Converting newlines to
        # <br> keeps blank lines from ending the HTML block mid-card.
        answer_html = html.escape(answer or "").replace("\n", "<br>")
        st.markdown(f'<div class="card"><b>Answer:</b><br>{answer_html}</div>', unsafe_allow_html=True)
# =========================
# ✅ Footer
# =========================
# Small centered grey credit line at the bottom of the page.
_FOOTER_HTML = "<br><center style='color: grey;'>Built by Muqadas with ❀️ using Streamlit + Groq + FAISS</center>"
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)