# app.py — Streamlit RAG document Q&A (revision aa5fbc4)
import os
import streamlit as st
from PyPDF2 import PdfReader
import pandas as pd
import docx
from sentence_transformers import SentenceTransformer
import faiss
from groq import Groq
import numpy as np
from sklearn.preprocessing import normalize
# Initialize Groq API
# Reads the key from the GROQ_API_KEY environment variable; api_key may be
# None if the variable is unset (os.environ.get does not raise).
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Initialize SentenceTransformer model
# Loaded once at import time; used to embed both document chunks and queries.
embedder_model = SentenceTransformer("all-MiniLM-L6-v2")
# Helper function to extract text from PDF
def extract_text_from_pdf(file):
    """Extract and concatenate the text of every page in a PDF.

    Args:
        file: A file path or binary file-like object accepted by PdfReader.

    Returns:
        str: All extractable page text, concatenated in page order.

    Pages with no extractable text (e.g. scanned images) yield None from
    extract_text(); the `or ""` guard prevents the TypeError that naive
    string concatenation with None would raise.
    """
    pdf_reader = PdfReader(file)
    # "".join avoids quadratic repeated string concatenation.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
# Helper function to extract text from Excel
def extract_text_from_excel(file):
    """Render an Excel workbook's first sheet as a plain-text table.

    Args:
        file: A path or file-like object accepted by pandas.read_excel.

    Returns:
        str: The DataFrame's string rendering (pandas.DataFrame.to_string).
    """
    frame = pd.read_excel(file)
    return frame.to_string()
# Helper function to extract text from Word document
def extract_text_from_docx(file):
    """Return a Word document's paragraph text, one paragraph per line.

    Args:
        file: A path or file-like object accepted by docx.Document.

    Returns:
        str: Paragraph texts joined with newlines.
    """
    document = docx.Document(file)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    return "\n".join(paragraph_texts)
# Function to chunk text into smaller parts
def chunk_text(text, chunk_size=512):
    """Split `text` into consecutive pieces of at most `chunk_size` characters.

    Args:
        text: The string to split.
        chunk_size: Maximum length of each piece (default 512).

    Returns:
        list[str]: The pieces in order; the last piece may be shorter, and an
        empty input yields an empty list.
    """
    pieces = []
    start = 0
    while start < len(text):
        pieces.append(text[start:start + chunk_size])
        start += chunk_size
    return pieces
# Function to create FAISS index and store embeddings
def create_faiss_index(texts, model):
    """Embed `texts` and build a flat L2 FAISS index over the embeddings.

    Args:
        texts: Sequence of strings to embed.
        model: An encoder exposing .encode(list[str]) -> 2-D array.

    Returns:
        tuple: (faiss.IndexFlatL2, embeddings) where embeddings is the
        L2-normalized float32 matrix actually added to the index.
    """
    embeddings = model.encode(texts)
    # Normalize so L2 distance ranks like cosine similarity.
    embeddings = normalize(embeddings)
    # FAISS requires C-contiguous float32 input; sklearn's normalize may
    # hand back float64 depending on the encoder's output dtype, so coerce
    # explicitly instead of letting index.add() fail.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
# Function to retrieve context from FAISS
def retrieve_context(query, index, texts, model, top_k=5):
    """Return the `top_k` chunks from `texts` most similar to `query`.

    Args:
        query: The user's question.
        index: A FAISS index built over normalized chunk embeddings.
        texts: The chunk strings in the same order they were indexed.
        model: Encoder exposing .encode(list[str]) -> 2-D array.
        top_k: Number of chunks to retrieve (default 5).

    Returns:
        str: The retrieved chunks joined with newlines.
    """
    query_embedding = model.encode([query])
    # The indexed vectors were L2-normalized (see create_faiss_index), so
    # the query must be normalized the same way or the L2 distances mix
    # normalized and unnormalized vectors and the ranking is inconsistent.
    query_embedding = np.asarray(query_embedding, dtype=np.float32)
    norms = np.linalg.norm(query_embedding, axis=1, keepdims=True)
    # Guard against a zero vector (division by zero).
    query_embedding = query_embedding / np.where(norms == 0.0, 1.0, norms)
    distances, indices = index.search(query_embedding, top_k)
    retrieved_texts = [texts[i] for i in indices[0]]
    return "\n".join(retrieved_texts)
# Function to query Groq API
def query_groq_api(context, question):
    """Ask the Groq chat API `question` grounded in `context`.

    Args:
        context: Retrieved document text to ground the answer.
        question: The user's question.

    Returns:
        str: The model's answer, or an error string if the API call fails
        (any failure is reported rather than raised so the Streamlit UI can
        display it directly).
    """
    prompt = f"Context: {context}\nQuestion: {question}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    try:
        response = client.chat.completions.create(
            messages=messages,
            model="llama-3.3-70b-versatile",
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error querying Groq API: {e}"
# Streamlit App
def main():
    """Streamlit entry point: upload a document, index it, answer questions."""
    st.title("RAG-based Document Q&A")
    st.write("Upload a document, and ask questions based on its content.")

    uploaded_file = st.file_uploader("Upload your document", type=["pdf", "xlsx", "docx", "txt"])
    user_question = st.text_input("Enter your question:")

    # Nothing to do until a file is provided.
    if uploaded_file is None:
        return

    # Dispatch on the file extension to the matching text extractor.
    filename = uploaded_file.name
    if filename.endswith(".pdf"):
        context = extract_text_from_pdf(uploaded_file)
    elif filename.endswith(".xlsx"):
        context = extract_text_from_excel(uploaded_file)
    elif filename.endswith(".docx"):
        context = extract_text_from_docx(uploaded_file)
    elif filename.endswith(".txt"):
        context = uploaded_file.read().decode("utf-8")
    else:
        st.error("Unsupported file format!")
        return

    # Chunk the document and build a searchable embedding index over it.
    chunks = chunk_text(context)
    index, embeddings = create_faiss_index(chunks, embedder_model)

    # st.button is only evaluated when a question has been entered,
    # matching the original nested-if short-circuit behavior.
    if user_question and st.button("Submit Question"):
        st.write("Answer:")
        retrieved_context = retrieve_context(user_question, index, chunks, embedder_model)
        answer = query_groq_api(retrieved_context, user_question)
        st.success(answer)


if __name__ == "__main__":
    main()