RizwanSajad committed on
Commit
73b110f
·
verified ·
1 Parent(s): 3166986

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -75
app.py CHANGED
@@ -1,94 +1,91 @@
1
  import os
2
- import json
3
- import faiss
4
- import numpy as np
5
- import PyPDF2
6
- import requests
7
  import streamlit as st
 
 
8
  from groq import Groq
 
 
 
9
 
10
  # Constants
11
- PDF_URL = "https://drive.google.com/uc?export=download&id=1YWX-RYxgtcKO1QETnz1N3rboZUhRZwcH"
12
- VECTOR_DIM = 768
13
- CHUNK_SIZE = 512
14
 
15
- # Function to download and extract text from the PDF
16
- def extract_text_from_pdf(url):
17
- response = requests.get(url)
18
- with open("document.pdf", "wb") as f:
19
- f.write(response.content)
 
 
 
 
 
20
 
21
- with open("document.pdf", "rb") as f:
22
- reader = PyPDF2.PdfReader(f)
23
- text = "\n".join(page.extract_text() for page in reader.pages)
24
- return text
 
 
 
 
 
25
 
26
- # Function to split text into chunks
27
- def create_chunks(text, chunk_size):
28
- words = text.split()
29
- chunks = [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
30
- return chunks
 
 
 
31
 
32
- # Function to create FAISS vector store
33
- def create_faiss_index(chunks, vector_dim):
34
- # Check if GPU is available and use it
35
- if faiss.get_num_gpus() > 0:
36
- st.write("Using GPU for FAISS indexing.")
37
- resource = faiss.StandardGpuResources() # Initialize GPU resources
38
- index_flat = faiss.IndexFlatL2(vector_dim)
39
- index = faiss.index_cpu_to_gpu(resource, 0, index_flat)
40
- else:
41
- st.write("Using CPU for FAISS indexing.")
42
- index = faiss.IndexFlatL2(vector_dim)
43
-
44
- embeddings = np.random.rand(len(chunks), vector_dim).astype('float32') # Replace with real embeddings
45
  index.add(embeddings)
46
- return index, embeddings
47
 
48
- # Initialize Groq API client
49
- def get_groq_client():
50
- return os.environ.get("GROQ_API_KEY")
 
51
 
52
- # Query Groq model
53
- def query_model(client, question):
54
- chat_completion = client.chat.completions.create(
55
- messages=[{"role": "user", "content": question}],
56
- model="llama-3.3-70b-versatile",
57
- )
58
- return chat_completion.choices[0].message.content
59
 
60
- # Streamlit app
61
  def main():
62
- st.title("RAG-Based Application")
63
-
64
- # Step 1: Extract text from the document
65
- st.header("Step 1: Extract Text")
66
- if st.button("Extract Text from PDF"):
67
- text = extract_text_from_pdf(PDF_URL)
68
- st.session_state["text"] = text
69
- st.success("Text extracted successfully!")
70
-
71
- # Step 2: Chunk the text
72
- st.header("Step 2: Create Chunks")
73
- if "text" in st.session_state and st.button("Create Chunks"):
74
- chunks = create_chunks(st.session_state["text"], CHUNK_SIZE)
75
- st.session_state["chunks"] = chunks
76
- st.success(f"Created {len(chunks)} chunks.")
77
 
78
- # Step 3: Create FAISS index
79
- st.header("Step 3: Create Vector Database")
80
- if "chunks" in st.session_state and st.button("Create Vector Database"):
81
- index, embeddings = create_faiss_index(st.session_state["chunks"], VECTOR_DIM)
82
- st.session_state["index"] = index
83
- st.success("FAISS vector database created.")
 
 
 
 
 
 
84
 
85
- # Step 4: Ask a question
86
- st.header("Step 4: Query the Model")
87
- question = st.text_input("Ask a question about the document:")
88
- if question and "index" in st.session_state:
89
- client = get_groq_client()
90
- answer = query_model(client, question)
91
- st.write("Answer:", answer)
 
 
92
 
93
  if __name__ == "__main__":
94
  main()
 
1
  import os
 
 
 
 
 
2
  import streamlit as st
3
+ import numpy as np
4
+ import faiss
5
  from groq import Groq
6
+ from pydrive.auth import GoogleAuth
7
+ from pydrive.drive import GoogleDrive
8
+ from sentence_transformers import SentenceTransformer
9
 
10
  # Constants
11
+ DRIVE_FILE_LINK = "https://drive.google.com/file/d/1kYGomSibXW-wCFptEMcWP12jOz1390OK/view?usp=drive_link"
12
+ GROQ_MODEL = "llama-3.3-70b-versatile"
 
13
 
14
# Authentication and setup for Google Drive
@st.cache_resource
def load_drive_content(file_link):
    """Download the PDF behind a Google Drive share link to ./document.pdf.

    Runs PyDrive's local-webserver OAuth flow, extracts the file id from the
    share link, and fetches the file contents. Wrapped in st.cache_resource so
    the auth flow and download happen once per Streamlit session.

    Args:
        file_link: A Drive URL containing either ``/d/<id>`` or ``?id=<id>``.

    Returns:
        The local path ("document.pdf") of the downloaded file.

    Raises:
        ValueError: If no file id can be extracted from *file_link*.
    """
    import re

    # Accept both ".../file/d/<id>/view" and ".../uc?id=<id>" link styles;
    # the previous split('/d/') parsing raised IndexError on the second form.
    match = re.search(r"(?:/d/|[?&]id=)([A-Za-z0-9_-]+)", file_link)
    if match is None:
        raise ValueError(f"Could not extract a Drive file id from: {file_link}")
    file_id = match.group(1)

    gauth = GoogleAuth()
    gauth.LocalWebserverAuth()  # NOTE(review): opens a browser; requires a desktop session
    drive = GoogleDrive(gauth)
    downloaded_file = drive.CreateFile({'id': file_id})
    downloaded_file.GetContentFile("document.pdf")
    return "document.pdf"
24
 
25
# Chunking and embedding creation
@st.cache_resource
def prepare_embeddings(document_path):
    """Read a PDF, chunk its text, embed the chunks, and build a FAISS index.

    Cached by Streamlit so the (slow) extraction/embedding work runs once.

    Args:
        document_path: Path to a local PDF file.

    Returns:
        A ``(chunks, index)`` tuple: the list of text chunks and a
        ``faiss.IndexFlatL2`` over their embeddings, in the same order.

    Raises:
        ValueError: If the PDF yields no extractable text.
    """
    from PyPDF2 import PdfReader

    reader = PdfReader(document_path)
    # extract_text() can return None for image-only pages; coalesce to ""
    # instead of crashing with TypeError on string concatenation.
    text = "".join(page.extract_text() or "" for page in reader.pages)
    if not text.strip():
        raise ValueError(f"No extractable text found in {document_path}")

    # Create chunks of 500 characters with a sliding window of 200
    chunk_size = 500
    chunk_overlap = 200
    step = chunk_size - chunk_overlap
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]

    # Embedding model; convert_to_numpy avoids the .detach().numpy() round
    # trip, which fails outright on CUDA tensors.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = np.asarray(
        embedder.encode(chunks, convert_to_numpy=True), dtype="float32"
    )

    # Store in FAISS (IndexFlatL2 expects contiguous float32 rows)
    vector_dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(vector_dim)
    index.add(embeddings)
    return chunks, index
49
 
50
# Groq setup
@st.cache_resource
def groq_client():
    """Build (and cache for the session) a Groq API client.

    The API key is read from the GROQ_API_KEY environment variable; a missing
    variable yields ``api_key=None`` and the client errors on first use.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    return Groq(api_key=api_key)
54
 
55
# Retrieve and query vector DB
def query_vector_db(query, chunks, index, embedder):
    """Return the chunk most similar to *query* from the FAISS index.

    Args:
        query: The user's question as a string.
        chunks: Text chunks in the same order they were added to *index*.
        index: A FAISS index supporting ``search(vectors, k)``.
        embedder: Model with an ``encode`` method (sentence-transformers API).

    Returns:
        The best-matching chunk, or "No relevant content found." when the
        index reports no valid neighbor.
    """
    # convert_to_numpy replaces .detach().numpy(), which fails on CUDA
    # tensors; FAISS requires float32 input.
    query_embedding = np.asarray(
        embedder.encode([query], convert_to_numpy=True), dtype="float32"
    )
    distances, indices = index.search(query_embedding, k=1)  # top-1 neighbor
    best = int(indices[0][0])
    # FAISS returns -1 for missing neighbors; also bounds-check defensively.
    if 0 <= best < len(chunks):
        return chunks[best]
    return "No relevant content found."
62
 
63
@st.cache_resource
def _load_embedder():
    """Load the query-embedding model once per session.

    main() re-executes on every Streamlit interaction; without caching, the
    SentenceTransformer model was re-instantiated on each rerun.
    """
    return SentenceTransformer("all-MiniLM-L6-v2")


# Streamlit application
def main():
    """Streamlit entry point: load the document, build the index, answer queries."""
    st.title("RAG-based Application with Groq")

    # Load document and prepare FAISS (both helpers are cache_resource-backed,
    # so the download/embedding work runs only once per session)
    st.info("Loading document and preparing FAISS...")
    document_path = load_drive_content(DRIVE_FILE_LINK)
    chunks, index = prepare_embeddings(document_path)
    embedder = _load_embedder()
    client = groq_client()

    # Interface
    user_input = st.text_input("Enter your query:")
    if user_input:
        context = query_vector_db(user_input, chunks, index, embedder)
        st.write("**Relevant Context:**", context)

        # Query Groq model
        with st.spinner("Querying Groq model..."):
            chat_completion = client.chat.completions.create(
                messages=[
                    {"role": "user", "content": f"Based on this context: {context}, {user_input}"}
                ],
                model=GROQ_MODEL,
            )
            st.write("**Groq Model Response:**", chat_completion.choices[0].message.content)


if __name__ == "__main__":
    main()