File size: 3,974 Bytes
981adb2
 
 
 
 
 
 
 
82ca82c
 
981adb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82ca82c
 
aa5fbc4
 
82ca82c
 
 
981adb2
 
aa5fbc4
82ca82c
 
981adb2
 
 
 
 
 
aa5fbc4
 
981adb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86926ac
981adb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82ca82c
 
981adb2
82ca82c
981adb2
 
 
82ca82c
aa5fbc4
 
82ca82c
aa5fbc4
 
 
 
 
981adb2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import streamlit as st
from PyPDF2 import PdfReader
import pandas as pd
import docx
from sentence_transformers import SentenceTransformer
import faiss
from groq import Groq
import numpy as np
from sklearn.preprocessing import normalize

# Groq chat-completions client. The key is read from the GROQ_API_KEY
# environment variable (os.getenv yields None when it is unset).
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Sentence-embedding model, shared by document indexing and query encoding.
embedder_model = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")

def extract_text_from_pdf(file):
    """Return the text of every page of a PDF, one page after another.

    Args:
        file: A path or binary file-like object accepted by ``PdfReader``.

    Returns:
        The per-page texts joined with newlines, so the last word of one page
        is not glued to the first word of the next. Pages without extractable
        text contribute an empty string.
    """
    pdf_reader = PdfReader(file)
    # extract_text() may return None for pages without a text layer in some
    # PyPDF2 versions; `or ""` keeps the join from raising TypeError.
    # str.join also avoids the quadratic `+=` accumulation.
    return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)

def extract_text_from_excel(file):
    """Render a spreadsheet as a plain-text table.

    Only the sheet that ``pd.read_excel`` loads by default (the first one) is
    read; the resulting DataFrame is rendered with ``DataFrame.to_string``.
    """
    return pd.read_excel(file).to_string()

def extract_text_from_docx(file):
    """Return the paragraph text of a Word (.docx) document.

    Each entry of ``Document.paragraphs`` becomes one line of the result.
    Only that list is read; python-docx does not include table cells,
    headers or footers in it.
    """
    document = docx.Document(file)
    lines = [paragraph.text for paragraph in document.paragraphs]
    return "\n".join(lines)

def chunk_text(text, chunk_size=512, overlap=0):
    """Split *text* into consecutive fixed-size character windows.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters each chunk shares with the previous
            one. Must satisfy ``0 <= overlap < chunk_size``. The default of 0
            gives non-overlapping chunks; a positive overlap helps retrieval
            when the relevant passage straddles a chunk boundary.

    Returns:
        A list of chunks that together cover all of *text*; ``[]`` when
        *text* is empty.

    Raises:
        ValueError: If *chunk_size* is not positive or *overlap* is outside
            ``[0, chunk_size)``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got {overlap}")
    if not text:
        return []
    step = chunk_size - overlap
    # Stop once a window reaches the end of the text, so the tail is not
    # emitted again as a redundant sub-chunk of the previous window.
    # With overlap == 0 this is range(0, len(text), chunk_size).
    last_start = max(len(text) - overlap, 1)
    return [text[i:i + chunk_size] for i in range(0, last_start, step)]

def create_faiss_index(texts, model):
    """Embed *texts* and load the vectors into an exact L2 FAISS index.

    Args:
        texts: Non-empty sequence of strings to index; row ``i`` of the index
            corresponds to ``texts[i]``.
        model: Encoder exposing ``encode(texts)`` that returns one embedding
            row per text (e.g. a SentenceTransformer).

    Returns:
        ``(index, embeddings)`` where *embeddings* is the L2-normalised
        float32 matrix that was added to *index*.

    Raises:
        ValueError: If *texts* is empty. There is nothing to index, and the
            embedding dimension needed to build the index is unknown.
    """
    if len(texts) == 0:
        raise ValueError("create_faiss_index() needs at least one text to index")
    embeddings = model.encode(texts)
    # Unit-normalise rows (sklearn's default L2 norm) so that L2 distance
    # ranks neighbours the same way cosine similarity would.
    embeddings = normalize(embeddings)
    # FAISS's Python API works on C-contiguous float32 matrices; make that
    # explicit rather than relying on the encoder's output dtype/layout.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings

def retrieve_context(query, index, texts, model, top_k=5):
    """Return the chunks of *texts* nearest to *query*, newline-joined.

    Args:
        query: The user's question.
        index: FAISS index whose row ``i`` holds the embedding of ``texts[i]``.
        texts: The chunk strings that were indexed.
        model: The encoder that was used to build *index*.
        top_k: Maximum number of chunks to return.

    Returns:
        The retrieved chunks, best match first, separated by newlines. If the
        index holds fewer than *top_k* vectors, only the real hits are
        returned.
    """
    query_embedding = np.asarray(model.encode([query]), dtype=np.float32)
    # Unit-normalise the query just as create_faiss_index normalises the
    # stored vectors. Zero vectors are left unchanged, as sklearn's
    # normalize does.
    norms = np.linalg.norm(query_embedding, axis=1, keepdims=True)
    query_embedding = np.ascontiguousarray(
        query_embedding / np.where(norms == 0, 1, norms), dtype=np.float32
    )
    _, indices = index.search(query_embedding, top_k)
    # FAISS pads missing results with -1. Without this filter, texts[-1]
    # would silently repeat the last chunk for small documents.
    return "\n".join(texts[i] for i in indices[0] if i >= 0)

def query_groq_api(context, question, model="llama-3.3-70b-versatile"):
    """Ask the Groq chat model *question*, with *context* included in the prompt.

    Args:
        context: Retrieved document text to include in the prompt.
        question: The user's question.
        model: Groq model identifier to send the request to.

    Returns:
        The model's reply text, or ``""`` if the reply carries no content.
        Any failure is reported as an ``"Error querying Groq API: ..."``
        string rather than raised, because the caller in this file displays
        the return value directly.
    """
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
            ],
            model=model,
        )
        # message.content can be None; normalise to a string for display.
        return response.choices[0].message.content or ""
    except Exception as e:  # deliberate best-effort boundary: report, don't crash the UI
        return f"Error querying Groq API: {e}"

# Streamlit App
def _extract_text(uploaded_file):
    """Return the text of *uploaded_file*, or None if its extension is unsupported."""
    # Match extensions case-insensitively so "Report.PDF" is accepted.
    name = uploaded_file.name.lower()
    if name.endswith(".pdf"):
        return extract_text_from_pdf(uploaded_file)
    if name.endswith(".xlsx"):
        return extract_text_from_excel(uploaded_file)
    if name.endswith(".docx"):
        return extract_text_from_docx(uploaded_file)
    if name.endswith(".txt"):
        # getvalue() returns the whole buffer regardless of the read position.
        # Undecodable bytes are replaced instead of crashing the app.
        return uploaded_file.getvalue().decode("utf-8", errors="replace")
    return None


def _load_document(uploaded_file):
    """Extract, chunk and index *uploaded_file*, reusing the result across reruns.

    Streamlit re-executes the script on every widget interaction. The
    ``(chunks, index)`` pair is therefore kept in ``st.session_state``, keyed
    by a SHA-256 of the file bytes, and the document is only re-embedded when
    a different file is uploaded.

    Returns:
        ``(chunks, index)``, or None if the format is unsupported or no text
        could be extracted. In that case an error/warning has already been
        shown.
    """
    import hashlib

    digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    cached = st.session_state.get("_rag_document")
    if cached is not None and cached[0] == digest:
        return cached[1], cached[2]

    uploaded_file.seek(0)  # extractors read from the current position
    context = _extract_text(uploaded_file)
    if context is None:
        st.error("Unsupported file format!")
        return None
    if not context.strip():
        # e.g. a scanned PDF with no text layer; indexing nothing would crash.
        st.warning("No text could be extracted from this document.")
        return None

    # Chunk the extracted text and build the FAISS index over the chunks.
    chunks = chunk_text(context)
    index, _ = create_faiss_index(chunks, embedder_model)
    st.session_state["_rag_document"] = (digest, chunks, index)
    return chunks, index


def main():
    """Render the upload / question / answer UI of the RAG document Q&A app."""
    st.title("RAG-based Document Q&A")
    st.write("Upload a document, and ask questions based on its content.")

    uploaded_file = st.file_uploader("Upload your document", type=["pdf", "xlsx", "docx", "txt"])
    user_question = st.text_input("Enter your question:")

    if uploaded_file is None:
        return

    document = _load_document(uploaded_file)
    if document is None:
        return
    chunks, index = document

    # The submit button is only rendered once a question has been entered.
    if user_question and st.button("Submit Question"):
        st.write("Answer:")
        retrieved_context = retrieve_context(user_question, index, chunks, embedder_model)
        answer = query_groq_api(retrieved_context, user_question)
        st.success(answer)

# Script entry point: `streamlit run <this file>` executes the module as
# __main__, which renders the app via main().
if __name__ == "__main__":
    main()