RizwanSajad commited on
Commit
cb681da
·
verified ·
1 Parent(s): 67b6dfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -58
app.py CHANGED
@@ -3,95 +3,82 @@ import streamlit as st
3
  import numpy as np
4
  import faiss
5
  from groq import Groq
 
 
6
  from sentence_transformers import SentenceTransformer
7
- from PyPDF2 import PdfReader
8
 
9
  # Constants
10
  DRIVE_FILE_LINK = "https://drive.google.com/file/d/1kYGomSibXW-wCFptEMcWP12jOz1390OK/view?usp=drive_link"
11
  GROQ_MODEL = "llama-3.3-70b-versatile"
12
 
13
# Download the document
def download_document(file_link):
    """Authenticate against Google Drive and fetch the shared PDF locally.

    Parameters
    ----------
    file_link : str
        A Drive share link of the form ``.../file/d/<id>/view?...``.

    Returns
    -------
    str
        Path of the downloaded file (always ``"document.pdf"``).
    """
    # Imported lazily so the app can start without pydrive installed.
    from pydrive.auth import GoogleAuth
    from pydrive.drive import GoogleDrive

    st.info("Authenticating with Google Drive...")
    auth = GoogleAuth()
    auth.LocalWebserverAuth()  # opens a browser window for the OAuth flow
    drive = GoogleDrive(auth)

    # Pull the file id out of a ".../d/<id>/view..." share link.
    drive_id = file_link.split("/d/")[1].split("/view")[0]
    remote = drive.CreateFile({"id": drive_id})
    remote.GetContentFile("document.pdf")
    return "document.pdf"
27
 
28
# Chunk the text
def chunk_text(text, chunk_size=500, chunk_overlap=200):
    """Split *text* into overlapping fixed-size chunks.

    Parameters
    ----------
    text : str
        Document text to split.
    chunk_size : int
        Maximum number of characters per chunk.
    chunk_overlap : int
        Characters shared between consecutive chunks; must be smaller
        than ``chunk_size``.

    Returns
    -------
    list[str]
        Chunks in document order; empty list for empty input.

    Raises
    ------
    ValueError
        If ``chunk_overlap >= chunk_size``.  The original code then ran
        ``range`` with a zero step (crash) or a negative step (silently
        returned no chunks) — fail loudly instead.
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]
 
 
 
 
 
32
 
33
# Create embeddings and store in FAISS
def create_vector_database(chunks):
    """Embed *chunks* in batches of 100 and index them in a FAISS L2 index.

    Parameters
    ----------
    chunks : list[str]
        Text chunks to embed.

    Returns
    -------
    faiss.IndexFlatL2
        Flat L2 index containing one vector per chunk.
    """
    st.info("Creating embeddings...")
    model = SentenceTransformer("all-MiniLM-L6-v2")

    # Encode in batches of 100 to bound peak memory, then stack into
    # one (n_chunks, dim) matrix.
    batch_size = 100
    batches = [
        model.encode(chunks[start:start + batch_size], convert_to_tensor=True).detach().numpy()
        for start in range(0, len(chunks), batch_size)
    ]
    matrix = np.vstack(batches)

    st.info("Initializing FAISS vector database...")
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)
    return index
 
 
 
51
 
52
# Query the vector database
def query_vector_db(query, chunks, index, embedder):
    """Return the single chunk nearest to *query* in the FAISS index.

    Parameters
    ----------
    query : str
        User query text.
    chunks : list[str]
        The chunks the index was built from, in insertion order.
    index : faiss index
        Index whose ``search`` returns (distances, ids).
    embedder : SentenceTransformer
        Model used to embed the query.

    Returns
    -------
    str
        Best-matching chunk, or a fallback message when the index
        reports no match (id ``-1``).
    """
    vec = embedder.encode([query], convert_to_tensor=True).detach().numpy()
    _distances, ids = index.search(vec, k=1)
    best = ids[0][0]
    return chunks[best] if best != -1 else "No relevant content found."
59
 
60
- # Main Streamlit App
61
  def main():
62
  st.title("RAG-based Application with Groq")
63
 
64
- # Step 1: Load Document
65
- if st.button("Download and Load Document"):
66
- document_path = download_document(DRIVE_FILE_LINK)
67
- reader = PdfReader(document_path)
68
- text = "".join([page.extract_text() for page in reader.pages])
69
- chunks = chunk_text(text)
70
- st.success("Document loaded and chunked!")
71
- st.session_state["chunks"] = chunks
72
-
73
- # Step 2: Create Vector Database
74
- if st.button("Create Vector Database"):
75
- if "chunks" not in st.session_state:
76
- st.error("Please load the document first!")
77
- else:
78
- index = create_vector_database(st.session_state["chunks"])
79
- st.session_state["index"] = index
80
- st.success("Vector database created successfully!")
81
-
82
- # Step 3: Query
83
  user_input = st.text_input("Enter your query:")
84
  if user_input:
85
- if "index" not in st.session_state or "chunks" not in st.session_state:
86
- st.error("Please load the document and create the vector database first!")
87
- else:
88
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
89
- context = query_vector_db(user_input, st.session_state["chunks"], st.session_state["index"], embedder)
90
- st.write("**Relevant Context:**", context)
91
 
92
- # Query Groq model
93
- client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
94
- st.info("Querying Groq model...")
95
  chat_completion = client.chat.completions.create(
96
  messages=[
97
  {"role": "user", "content": f"Based on this context: {context}, {user_input}"}
 
3
  import numpy as np
4
  import faiss
5
  from groq import Groq
6
+ from pydrive.auth import GoogleAuth
7
+ from pydrive.drive import GoogleDrive
8
  from sentence_transformers import SentenceTransformer
 
9
 
10
  # Constants
11
  DRIVE_FILE_LINK = "https://drive.google.com/file/d/1kYGomSibXW-wCFptEMcWP12jOz1390OK/view?usp=drive_link"
12
  GROQ_MODEL = "llama-3.3-70b-versatile"
13
 
14
# Authentication and setup for Google Drive
@st.cache_resource
def load_drive_content(file_link, output_path="document.pdf"):
    """Authenticate with Google Drive and download the linked file.

    Cached with ``st.cache_resource`` so the OAuth flow and download run
    only once per session for a given link.

    Parameters
    ----------
    file_link : str
        Drive share link of the form ``.../file/d/<id>/view?...``.
    output_path : str
        Where to write the downloaded file (default preserves the
        original hard-coded ``"document.pdf"``).

    Returns
    -------
    str
        Path of the downloaded file (``output_path``).

    Raises
    ------
    ValueError
        If the link does not contain a ``/d/<id>`` segment — the
        original code surfaced this as a cryptic ``IndexError``.
    """
    gauth = GoogleAuth()
    gauth.LocalWebserverAuth()  # opens a browser window for the OAuth flow
    drive = GoogleDrive(gauth)

    # Pull the file id out of a ".../d/<id>/view..." share link.
    try:
        file_id = file_link.split('/d/')[1].split('/view')[0]
    except IndexError:
        raise ValueError(f"Unrecognized Google Drive link: {file_link!r}") from None

    downloaded_file = drive.CreateFile({'id': file_id})
    downloaded_file.GetContentFile(output_path)
    return output_path
24
 
25
# Chunking and embedding creation
@st.cache_resource
def prepare_embeddings(document_path):
    """Read a PDF, chunk its text, embed the chunks, and build a FAISS index.

    Cached with ``st.cache_resource`` so the heavy embedding work runs
    once per session for a given document path.

    Parameters
    ----------
    document_path : str
        Path to the PDF downloaded by ``load_drive_content``.

    Returns
    -------
    tuple[list[str], faiss.IndexFlatL2]
        The text chunks and a flat L2 index with one vector per chunk.

    Raises
    ------
    ValueError
        If no text could be extracted from the PDF (the original code
        would instead fail later inside FAISS with an opaque error).
    """
    from PyPDF2 import PdfReader

    reader = PdfReader(document_path)
    # extract_text() returns None for pages without a text layer (e.g.
    # scanned images); the original `text += page.extract_text()` raised
    # a TypeError on such pages.
    text = "".join(page.extract_text() or "" for page in reader.pages)

    # Create chunks of 500 characters with a sliding window of 200
    chunk_size = 500
    chunk_overlap = 200
    step = chunk_size - chunk_overlap
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]

    if not chunks:
        raise ValueError("No extractable text found in the document.")

    # Embedding model
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = embedder.encode(chunks, convert_to_tensor=True).detach().numpy()

    # Store in FAISS
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return chunks, index
49
 
50
# Groq setup
@st.cache_resource
def groq_client():
    """Build (and cache for the session) a Groq API client.

    Reads the API key from the ``GROQ_API_KEY`` environment variable;
    passes ``None`` through if it is unset, letting the Groq SDK report
    the missing credential.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    return Groq(api_key=api_key)
54
 
55
# Retrieve and query vector DB
def query_vector_db(query, chunks, index, embedder):
    """Retrieve the top matching chunk for *query*, or a fallback message.

    Parameters
    ----------
    query : str
        User query text.
    chunks : list[str]
        Chunks the index was built from, in insertion order.
    index : faiss index
        Index whose ``search`` returns (distances, ids).
    embedder : SentenceTransformer
        Model used to embed the query.

    Returns
    -------
    str
        The nearest chunk, or ``"No relevant content found."`` when the
        index reports no match (id ``-1``).
    """
    emb = embedder.encode([query], convert_to_tensor=True).detach().numpy()
    scores, hits = index.search(emb, k=1)  # top-1 neighbor
    hit = hits[0][0]
    if hit == -1:  # FAISS signals "no result" with -1
        return "No relevant content found."
    return chunks[hit]
62
 
63
+ # Streamlit application
64
  def main():
65
  st.title("RAG-based Application with Groq")
66
 
67
+ # Load document and prepare FAISS
68
+ st.info("Loading document and preparing FAISS...")
69
+ document_path = load_drive_content(DRIVE_FILE_LINK)
70
+ chunks, index = prepare_embeddings(document_path)
71
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
72
+ client = groq_client()
73
+
74
+ # Interface
 
 
 
 
 
 
 
 
 
 
 
75
  user_input = st.text_input("Enter your query:")
76
  if user_input:
77
+ context = query_vector_db(user_input, chunks, index, embedder)
78
+ st.write("**Relevant Context:**", context)
 
 
 
 
79
 
80
+ # Query Groq model
81
+ with st.spinner("Querying Groq model..."):
 
82
  chat_completion = client.chat.completions.create(
83
  messages=[
84
  {"role": "user", "content": f"Based on this context: {context}, {user_input}"}