File size: 3,552 Bytes
c427af8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import streamlit as st
import faiss
import numpy as np
import fitz  # PyMuPDF for PDF text extraction
from sentence_transformers import SentenceTransformer
from groq import Groq

# Set up API key for Groq LLM.
# The key is read from the environment only; the app refuses to start without it
# so that no request is ever attempted with an empty credential.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    st.error("🚨 Groq API Key is missing! Set `GROQ_API_KEY` in the environment.")
    st.stop()  # Halts the Streamlit script here; nothing below runs.

# Initialize Groq Client (reused for every question).
client = Groq(api_key=GROQ_API_KEY)

# Load Embedding Model.
# NOTE(review): this runs on every Streamlit rerun — consider @st.cache_resource.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize FAISS Index.
# An exact (brute-force) L2 index; fine for a single uploaded document.
embedding_size = 384  # Dimension of embeddings from MiniLM
index = faiss.IndexFlatL2(embedding_size)
documents = []  # To store text chunks, parallel to the vectors in `index`

# Function to extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Extract all text from an uploaded PDF.

    Args:
        pdf_file: A binary file-like object (e.g. a Streamlit ``UploadedFile``)
            whose raw bytes are a PDF document.

    Returns:
        str: The text of every page joined with newlines, or ``""`` if
        extraction fails (the error is surfaced in the Streamlit UI).
    """
    try:
        # Open from the in-memory byte stream. The context manager closes the
        # underlying MuPDF document; the original version leaked it.
        with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
            return "\n".join(page.get_text("text") for page in doc)
    except Exception as e:
        # Best-effort: report in the UI and let the caller treat "" as failure.
        st.error(f"❌ Error extracting text: {e}")
        return ""

# Function to split text into chunks
def chunk_text(text, chunk_size=512):
    """Split *text* into word-based chunks of at most ``chunk_size`` words.

    Args:
        text: The raw document text.
        chunk_size: Maximum number of whitespace-separated words per chunk.

    Returns:
        list[str]: Consecutive, non-overlapping chunks; empty for empty input.
    """
    tokens = text.split()
    chunks = []
    for start in range(0, len(tokens), chunk_size):
        chunks.append(" ".join(tokens[start:start + chunk_size]))
    return chunks

# Function to store document embeddings in FAISS
def store_embeddings(chunks):
    """Embed *chunks* and append them to the global FAISS index.

    Side effects: grows the module-level ``index`` with one vector per chunk
    and appends the chunk texts to ``documents`` (kept in the same order so
    search results can be mapped back to text).
    """
    global documents, index
    vectors = np.asarray(embed_model.encode(chunks), dtype="float32")
    index.add(vectors)
    documents.extend(chunks)

# Function to retrieve relevant chunks from FAISS
def retrieve_relevant_chunks(query, top_k=3):
    """Return up to *top_k* stored text chunks most similar to *query*.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to return.

    Returns:
        list[str]: Matching chunks (possibly fewer than ``top_k``), or an
        empty list when no document has been indexed yet.
    """
    if index.ntotal == 0:
        return []

    query_embedding = embed_model.encode([query]).astype("float32")
    distances, indices = index.search(query_embedding, top_k)
    # FAISS pads the result with -1 when fewer than top_k vectors exist.
    # The original `i < len(documents)` check let -1 through, and Python's
    # negative indexing silently returned the LAST chunk; require i >= 0.
    return [documents[i] for i in indices[0] if 0 <= i < len(documents)]

# Function to query Groq API with retrieved context
def ask_groq(question, context):
    """Ask the Groq LLM a question grounded in the retrieved *context*.

    Args:
        question: The user's question.
        context: Concatenated document chunks to answer from.

    Returns:
        str: The model's answer, or an error message string if the API
        call fails (the caller displays whatever is returned).
    """
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"

    try:
        completion = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Errors are returned as text rather than raised so the UI never crashes.
        return f"❌ Error generating response: {e}"

# Streamlit UI — runs top to bottom on every user interaction (rerun model),
# so the order of these statements is load-bearing.
st.set_page_config(page_title="RAG Q&A with Groq", page_icon="📄", layout="wide")

st.title("📄 RAG-based Q&A with Open Source LLM & FAISS")
st.write("Upload a **PDF document**, then ask questions based on its content!")

uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])

if uploaded_file:
    # NOTE(review): this re-extracts and re-embeds on every rerun while a file
    # is uploaded, appending duplicate chunks to the index — verify intended.
    with st.spinner("Extracting text from PDF..."):
        pdf_text = extract_text_from_pdf(uploaded_file)

    if pdf_text:
        with st.spinner("Chunking and embedding document..."):
            chunks = chunk_text(pdf_text)
            store_embeddings(chunks)

        st.success("✅ Document processed! You can now ask questions.")

question = st.text_input("Ask a question from the document:", "")

if st.button("Get Answer"):
    if question:
        # Guard: an empty index means no document has been processed yet.
        if index.ntotal == 0:
            st.warning("⚠️ No document uploaded! Please upload a PDF first.")
        else:
            with st.spinner("Retrieving relevant context..."):
                context = " ".join(retrieve_relevant_chunks(question))

            with st.spinner("Generating answer using Groq LLM..."):
                answer = ask_groq(question, context)

            st.success("Answer:")
            st.write(answer)
    else:
        st.warning("⚠️ Please enter a question!")