import os
import streamlit as st
import numpy as np
import faiss
from groq import Groq
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from sentence_transformers import SentenceTransformer

# Constants
# Public Google Drive share link to the source PDF; load_drive_content()
# parses the file id out of the "/d/<id>/" segment of this URL.
DRIVE_FILE_LINK = "https://drive.google.com/file/d/1kYGomSibXW-wCFptEMcWP12jOz1390OK/view?usp=drive_link"
# Groq chat-completion model used to answer user queries.
GROQ_MODEL = "llama-3.3-70b-versatile"

# Authentication and setup for Google Drive
@st.cache_resource
def load_drive_content(file_link, output_path="document.pdf"):
    """Authenticate with Google Drive and download the shared file locally.

    Args:
        file_link: A Google Drive share link containing a ``/d/<file_id>`` segment.
        output_path: Local filename to write the downloaded file to.

    Returns:
        The local path of the downloaded file.

    Raises:
        ValueError: If no file id can be parsed from ``file_link``.
    """
    gauth = GoogleAuth()
    # Opens a local browser window for the OAuth consent flow.
    gauth.LocalWebserverAuth()
    drive = GoogleDrive(gauth)

    # Parse the file id robustly. The original split on "/view", which breaks
    # for links that lack that segment (the id would keep trailing query
    # params); splitting on the next "/" handles both ".../d/<id>/view?..."
    # and bare ".../d/<id>" forms.
    try:
        file_id = file_link.split("/d/")[1].split("/")[0]
    except IndexError:
        raise ValueError(f"Could not parse a Drive file id from: {file_link!r}") from None
    if not file_id:
        raise ValueError(f"Could not parse a Drive file id from: {file_link!r}")

    downloaded_file = drive.CreateFile({"id": file_id})
    downloaded_file.GetContentFile(output_path)
    return output_path

# Chunking and embedding creation
@st.cache_resource
def prepare_embeddings(document_path):
    """Extract text from a PDF, chunk it, embed the chunks, and build a FAISS index.

    Args:
        document_path: Path to a local PDF file.

    Returns:
        A ``(chunks, index)`` tuple: the list of text chunks and a FAISS
        ``IndexFlatL2`` built over their embeddings, in the same order.

    Raises:
        ValueError: If the PDF yields no extractable text.
    """
    from PyPDF2 import PdfReader

    reader = PdfReader(document_path)
    # extract_text() may return None for pages with no extractable text
    # (e.g. scanned images); the original "text += page.extract_text()"
    # raised TypeError in that case.
    text = "".join(page.extract_text() or "" for page in reader.pages)

    # Sliding-window chunking: 500-char chunks with 200 chars of overlap
    # (stride 300), so content near a boundary appears in two chunks.
    chunk_size = 500
    chunk_overlap = 200
    stride = chunk_size - chunk_overlap
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), stride)]
    if not chunks:
        raise ValueError(f"No extractable text found in {document_path!r}.")

    # Encode straight to a NumPy array. The original tensor round-trip
    # (convert_to_tensor=True -> .detach().numpy()) fails when the model
    # runs on a GPU, since .numpy() requires a CPU tensor.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = embedder.encode(chunks, convert_to_numpy=True)

    # FAISS requires float32, C-contiguous input.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return chunks, index

# Groq setup
@st.cache_resource
def groq_client():
    """Return a cached Groq API client.

    Returns:
        A ``Groq`` client configured from the ``GROQ_API_KEY`` environment
        variable.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is not set — failing fast here
            beats an opaque auth error on the first chat request (the
            original silently passed ``api_key=None`` to the client).
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set.")
    return Groq(api_key=api_key)

# Retrieve and query vector DB
def query_vector_db(query, chunks, index, embedder, top_k=1):
    """Return the chunk(s) most similar to *query* from the FAISS index.

    Args:
        query: User query string.
        chunks: Text chunks, aligned index-for-index with the FAISS vectors.
        index: FAISS index built over the chunk embeddings.
        embedder: Model with an ``encode`` method; must match the model used
            to build the index so query and chunks share a vector space.
        top_k: Number of nearest chunks to return, joined by blank lines
            (default 1, preserving the original single-chunk behavior).

    Returns:
        The matching chunk text, or ``"No relevant content found."`` when
        the search yields no valid hit.
    """
    # Encode straight to NumPy: the original tensor round-trip
    # (convert_to_tensor=True -> .detach().numpy()) fails on GPU devices.
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    # FAISS expects float32, C-contiguous input.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)

    _, indices = index.search(query_embedding, k=top_k)
    # FAISS fills slots it cannot satisfy (e.g. fewer vectors than k) with -1.
    hits = [chunks[i] for i in indices[0] if i != -1]
    if hits:
        return "\n\n".join(hits)
    return "No relevant content found."

# Streamlit application
@st.cache_resource
def _load_embedder():
    """Load the query embedding model once per process (cached by Streamlit).

    Must be the same model name used in prepare_embeddings() so query and
    chunk embeddings live in the same vector space.
    """
    return SentenceTransformer("all-MiniLM-L6-v2")


def main():
    """Streamlit entry point: download the document, build the FAISS index,
    then answer user queries using retrieved context plus the Groq model."""
    st.title("RAG-based Application with Groq")

    # Load document and prepare the vector store (cached across reruns).
    st.info("Loading document and preparing FAISS...")
    document_path = load_drive_content(DRIVE_FILE_LINK)
    chunks, index = prepare_embeddings(document_path)
    # Cached helper: the original constructed SentenceTransformer inline,
    # reloading the model on every Streamlit rerun (each user interaction).
    embedder = _load_embedder()
    client = groq_client()

    # Interface
    user_input = st.text_input("Enter your query:")
    if user_input:
        context = query_vector_db(user_input, chunks, index, embedder)
        st.write("**Relevant Context:**", context)

        # Query Groq model with the retrieved chunk prepended to the question.
        with st.spinner("Querying Groq model..."):
            chat_completion = client.chat.completions.create(
                messages=[
                    {"role": "user", "content": f"Based on this context: {context}, {user_input}"}
                ],
                model=GROQ_MODEL,
            )
            st.write("**Groq Model Response:**", chat_completion.choices[0].message.content)


if __name__ == "__main__":
    main()