# app.py — Streamlit RAG demo by RizwanSajad (commit cb681da, verified)
import os
import streamlit as st
import numpy as np
import faiss
from groq import Groq
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from sentence_transformers import SentenceTransformer
# Constants
# Share link for the source PDF to index — presumably publicly readable; TODO confirm.
DRIVE_FILE_LINK = "https://drive.google.com/file/d/1kYGomSibXW-wCFptEMcWP12jOz1390OK/view?usp=drive_link"
# Groq-hosted chat model used to generate the final answer.
GROQ_MODEL = "llama-3.3-70b-versatile"
# Authentication and setup for Google Drive
def _extract_drive_file_id(file_link):
    """Return the Drive file id from a share link.

    Handles both ``.../file/d/<id>/view...`` and ``...?id=<id>`` link forms
    (the original parser required a literal ``/view`` suffix).

    Raises:
        ValueError: If no file id can be found in ``file_link``.
    """
    if "/d/" in file_link:
        # https://drive.google.com/file/d/<id>/view?usp=...
        file_id = file_link.split("/d/", 1)[1].split("/", 1)[0]
    elif "id=" in file_link:
        # https://drive.google.com/open?id=<id>&...
        file_id = file_link.split("id=", 1)[1].split("&", 1)[0]
    else:
        file_id = ""
    if not file_id:
        raise ValueError(f"Cannot extract a Google Drive file id from: {file_link!r}")
    return file_id


@st.cache_resource
def load_drive_content(file_link):
    """Download the shared Google Drive file to ./document.pdf and return its path.

    Args:
        file_link: Google Drive share link for the PDF to index.

    Returns:
        The local filename of the downloaded file ("document.pdf").

    Raises:
        ValueError: If the link does not contain a recognizable file id.
    """
    # NOTE(review): LocalWebserverAuth opens a browser on the machine running
    # the app — fine for local development; a service account would be needed
    # for a headless deployment. Confirm against the deployment target.
    gauth = GoogleAuth()
    gauth.LocalWebserverAuth()
    drive = GoogleDrive(gauth)

    file_id = _extract_drive_file_id(file_link)
    downloaded_file = drive.CreateFile({'id': file_id})
    downloaded_file.GetContentFile("document.pdf")
    return "document.pdf"
# Chunking and embedding creation
@st.cache_resource
def prepare_embeddings(document_path):
    """Read a PDF, chunk its text, embed the chunks, and build a FAISS index.

    Args:
        document_path: Path to a local PDF file.

    Returns:
        Tuple ``(chunks, index)``: the list of text chunks and a FAISS
        ``IndexFlatL2`` built over their embeddings.
    """
    from PyPDF2 import PdfReader
    reader = PdfReader(document_path)
    # extract_text() may return None for image-only pages — treat those as empty
    # (the original `text += page.extract_text()` raised TypeError in that case).
    text = "".join(page.extract_text() or "" for page in reader.pages)

    # Sliding-window chunks: 500 characters, 200 overlapping (step of 300).
    chunk_size = 500
    chunk_overlap = 200
    step = chunk_size - chunk_overlap
    # Guard the empty-document case with one empty chunk so the index still
    # builds (otherwise embeddings.shape[1] below would fail).
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)] or [""]

    # Embedding model. Encode straight to NumPy instead of tensor->.detach()
    # ->.numpy(), which breaks when the model runs on a GPU (CUDA tensors
    # cannot be converted with .numpy() directly).
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = np.asarray(
        embedder.encode(chunks, convert_to_numpy=True), dtype="float32"
    )

    # Exact L2 index over the chunk embeddings (FAISS expects float32).
    vector_dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(vector_dim)
    index.add(embeddings)
    return chunks, index
# Groq setup
@st.cache_resource
def groq_client():
    """Build (and cache) a Groq API client from the GROQ_API_KEY env var."""
    api_key = os.environ.get("GROQ_API_KEY")
    return Groq(api_key=api_key)
# Retrieve and query vector DB
def query_vector_db(query, chunks, index, embedder, k=1):
    """Return the chunk(s) most similar to *query* from the FAISS index.

    Args:
        query: The user's free-text question.
        chunks: List of text chunks the index was built over.
        index: FAISS index whose row ids correspond to positions in ``chunks``.
        embedder: SentenceTransformer (or compatible) with an ``encode`` method.
        k: Number of nearest chunks to retrieve (default 1, matching the
           original behavior; multiple hits are joined with blank lines).

    Returns:
        The matching chunk text, or "No relevant content found." when the
        index returns no valid hit.
    """
    # Encode straight to NumPy: the tensor->.detach().numpy() round-trip of
    # the original fails for GPU-resident tensors. FAISS wants float32.
    query_embedding = np.asarray(
        embedder.encode([query], convert_to_numpy=True), dtype="float32"
    )
    _distances, ids = index.search(query_embedding, k=k)
    # FAISS pads missing results with -1 — skip those ids.
    hits = [chunks[i] for i in ids[0] if i != -1]
    if hits:
        return "\n\n".join(hits)
    return "No relevant content found."
# Streamlit application
def main():
    """Streamlit entry point: download the PDF, build the index, answer queries."""
    st.title("RAG-based Application with Groq")

    # Load document and prepare FAISS (both functions are cached resources,
    # so this only does real work on the first run).
    st.info("Loading document and preparing FAISS...")
    document_path = load_drive_content(DRIVE_FILE_LINK)
    chunks, index = prepare_embeddings(document_path)

    # Cache the query-time embedder: Streamlit re-executes this script on
    # every user interaction, and the original re-loaded SentenceTransformer
    # from scratch each rerun, which is slow.
    @st.cache_resource
    def _load_embedder():
        return SentenceTransformer("all-MiniLM-L6-v2")

    embedder = _load_embedder()
    client = groq_client()

    # Interface
    user_input = st.text_input("Enter your query:")
    if user_input:
        context = query_vector_db(user_input, chunks, index, embedder)
        st.write("**Relevant Context:**", context)

        # Query Groq model with the retrieved context prepended to the question.
        with st.spinner("Querying Groq model..."):
            chat_completion = client.chat.completions.create(
                messages=[
                    {"role": "user", "content": f"Based on this context: {context}, {user_input}"}
                ],
                model=GROQ_MODEL,
            )
        st.write("**Groq Model Response:**", chat_completion.choices[0].message.content)
# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    main()