RizwanSajad commited on
Commit
67b6dfd
·
verified ·
1 Parent(s): 73b110f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -45
app.py CHANGED
@@ -3,82 +3,95 @@ import streamlit as st
3
  import numpy as np
4
  import faiss
5
  from groq import Groq
6
- from pydrive.auth import GoogleAuth
7
- from pydrive.drive import GoogleDrive
8
  from sentence_transformers import SentenceTransformer
 
9
 
10
  # Constants
11
  DRIVE_FILE_LINK = "https://drive.google.com/file/d/1kYGomSibXW-wCFptEMcWP12jOz1390OK/view?usp=drive_link"
12
  GROQ_MODEL = "llama-3.3-70b-versatile"
13
 
14
- # Authentication and setup for Google Drive
15
- @st.cache_resource
16
- def load_drive_content(file_link):
 
 
 
17
  gauth = GoogleAuth()
18
  gauth.LocalWebserverAuth()
19
  drive = GoogleDrive(gauth)
20
- file_id = file_link.split('/d/')[1].split('/view')[0]
21
- downloaded_file = drive.CreateFile({'id': file_id})
 
22
  downloaded_file.GetContentFile("document.pdf")
23
  return "document.pdf"
24
 
25
- # Chunking and embedding creation
26
- @st.cache_resource
27
- def prepare_embeddings(document_path):
28
- from PyPDF2 import PdfReader
29
-
30
- reader = PdfReader(document_path)
31
- text = ""
32
- for page in reader.pages:
33
- text += page.extract_text()
34
 
35
- # Create chunks of 500 characters with a sliding window of 200
36
- chunk_size = 500
37
- chunk_overlap = 200
38
- chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size - chunk_overlap)]
39
-
40
- # Embedding model
41
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
42
- embeddings = embedder.encode(chunks, convert_to_tensor=True).detach().numpy()
 
 
 
 
 
 
43
 
44
- # Store in FAISS
45
  vector_dim = embeddings.shape[1]
46
  index = faiss.IndexFlatL2(vector_dim)
47
  index.add(embeddings)
48
- return chunks, index
49
 
50
- # Groq setup
51
- @st.cache_resource
52
- def groq_client():
53
- return Groq(api_key=os.environ.get("GROQ_API_KEY"))
54
 
55
- # Retrieve and query vector DB
56
  def query_vector_db(query, chunks, index, embedder):
57
  query_embedding = embedder.encode([query], convert_to_tensor=True).detach().numpy()
58
- D, I = index.search(query_embedding, k=1) # Find top result
59
- if I[0][0] != -1: # Valid match
60
  return chunks[I[0][0]]
61
  return "No relevant content found."
62
 
63
- # Streamlit application
64
  def main():
65
  st.title("RAG-based Application with Groq")
66
 
67
- # Load document and prepare FAISS
68
- st.info("Loading document and preparing FAISS...")
69
- document_path = load_drive_content(DRIVE_FILE_LINK)
70
- chunks, index = prepare_embeddings(document_path)
71
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
72
- client = groq_client()
73
-
74
- # Interface
 
 
 
 
 
 
 
 
 
 
 
75
  user_input = st.text_input("Enter your query:")
76
  if user_input:
77
- context = query_vector_db(user_input, chunks, index, embedder)
78
- st.write("**Relevant Context:**", context)
 
 
 
 
79
 
80
- # Query Groq model
81
- with st.spinner("Querying Groq model..."):
 
82
  chat_completion = client.chat.completions.create(
83
  messages=[
84
  {"role": "user", "content": f"Based on this context: {context}, {user_input}"}
 
3
  import numpy as np
4
  import faiss
5
  from groq import Groq
 
 
6
  from sentence_transformers import SentenceTransformer
7
+ from PyPDF2 import PdfReader
8
 
9
  # Constants
10
  DRIVE_FILE_LINK = "https://drive.google.com/file/d/1kYGomSibXW-wCFptEMcWP12jOz1390OK/view?usp=drive_link"
11
  GROQ_MODEL = "llama-3.3-70b-versatile"
12
 
13
# Download the document
def download_document(file_link):
    """Download the PDF behind a Google Drive share link to ``document.pdf``.

    Args:
        file_link: A Drive share URL of the form
            ``https://drive.google.com/file/d/<id>/view...``.

    Returns:
        str: The local path of the downloaded file (``"document.pdf"``).

    Raises:
        ValueError: If *file_link* is not a recognizable Drive share URL.
    """
    # pydrive is imported lazily so the rest of the app loads without it.
    from pydrive.auth import GoogleAuth
    from pydrive.drive import GoogleDrive

    st.info("Authenticating with Google Drive...")
    gauth = GoogleAuth()
    # NOTE(review): LocalWebserverAuth opens a browser on the machine running
    # Streamlit — this only works for local runs; confirm for deployments.
    gauth.LocalWebserverAuth()
    drive = GoogleDrive(gauth)

    # Parse the file id defensively: the original chained split() calls raise
    # an opaque IndexError when the link does not match the expected shape.
    try:
        file_id = file_link.split("/d/")[1].split("/view")[0]
    except IndexError:
        raise ValueError(f"Unrecognized Google Drive link: {file_link!r}")

    downloaded_file = drive.CreateFile({"id": file_id})
    downloaded_file.GetContentFile("document.pdf")
    return "document.pdf"
27
 
28
# Chunk the text
def chunk_text(text, chunk_size=500, chunk_overlap=200):
    """Split *text* into overlapping fixed-size character chunks.

    Consecutive chunks share ``chunk_overlap`` characters so sentences cut at
    a boundary still appear whole in at least one chunk.

    Args:
        text: The full document text.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters shared between consecutive chunks; must be
            smaller than ``chunk_size``.

    Returns:
        list[str]: The chunks, in document order (empty list for empty text).

    Raises:
        ValueError: If ``chunk_overlap`` is not smaller than ``chunk_size``
            (otherwise the window would never advance).
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]
 
 
 
 
 
32
 
33
# Create embeddings and store in FAISS
def create_vector_database(chunks):
    """Embed *chunks* and index them in an in-memory FAISS L2 index.

    Args:
        chunks: Non-empty list of text chunks to embed.

    Returns:
        faiss.IndexFlatL2: An exact L2 nearest-neighbor index over the chunk
        embeddings; row i of the index corresponds to ``chunks[i]``.

    Raises:
        ValueError: If *chunks* is empty (nothing to index).
    """
    if not chunks:
        raise ValueError("chunks must be non-empty")

    st.info("Creating embeddings...")
    embedder = SentenceTransformer("all-MiniLM-L6-v2")

    # Encode in batches of 100 to bound peak memory. convert_to_numpy avoids
    # a device-bound tensor: .numpy() on a CUDA tensor raises, whereas this
    # works on CPU and GPU alike.
    batches = []
    for start in range(0, len(chunks), 100):
        batch = chunks[start:start + 100]
        batches.append(embedder.encode(batch, convert_to_numpy=True))
    embeddings = np.vstack(batches)

    st.info("Initializing FAISS vector database...")
    # FAISS requires float32, C-contiguous input.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
 
 
 
51
 
52
# Query the vector database
def query_vector_db(query, chunks, index, embedder):
    """Return the chunk nearest to *query*, or a fallback message.

    Args:
        query: The user's free-text question.
        chunks: The chunk list the index was built from (row i -> chunks[i]).
        index: A FAISS index supporting ``search``.
        embedder: A SentenceTransformer-like object supporting ``encode``.

    Returns:
        str: The best-matching chunk, or ``"No relevant content found."`` when
        the index yields no valid neighbor.
    """
    # convert_to_numpy keeps the embedding on CPU regardless of model device;
    # the original .detach().numpy() fails for CUDA-resident tensors.
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    # FAISS requires float32, C-contiguous input.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
    _, neighbor_ids = index.search(query_embedding, k=1)  # top-1 match
    best = neighbor_ids[0][0]
    if best != -1:  # FAISS reports -1 when no neighbor exists
        return chunks[best]
    return "No relevant content found."
59
 
60
+ # Main Streamlit App
61
  def main():
62
  st.title("RAG-based Application with Groq")
63
 
64
+ # Step 1: Load Document
65
+ if st.button("Download and Load Document"):
66
+ document_path = download_document(DRIVE_FILE_LINK)
67
+ reader = PdfReader(document_path)
68
+ text = "".join([page.extract_text() for page in reader.pages])
69
+ chunks = chunk_text(text)
70
+ st.success("Document loaded and chunked!")
71
+ st.session_state["chunks"] = chunks
72
+
73
+ # Step 2: Create Vector Database
74
+ if st.button("Create Vector Database"):
75
+ if "chunks" not in st.session_state:
76
+ st.error("Please load the document first!")
77
+ else:
78
+ index = create_vector_database(st.session_state["chunks"])
79
+ st.session_state["index"] = index
80
+ st.success("Vector database created successfully!")
81
+
82
+ # Step 3: Query
83
  user_input = st.text_input("Enter your query:")
84
  if user_input:
85
+ if "index" not in st.session_state or "chunks" not in st.session_state:
86
+ st.error("Please load the document and create the vector database first!")
87
+ else:
88
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
89
+ context = query_vector_db(user_input, st.session_state["chunks"], st.session_state["index"], embedder)
90
+ st.write("**Relevant Context:**", context)
91
 
92
+ # Query Groq model
93
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
94
+ st.info("Querying Groq model...")
95
  chat_completion = client.chat.completions.create(
96
  messages=[
97
  {"role": "user", "content": f"Based on this context: {context}, {user_input}"}