NHZ committed on
Commit
d386915
·
verified ·
1 Parent(s): 75f7375

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -64
app.py CHANGED
@@ -1,87 +1,107 @@
1
- import os
2
- import streamlit as st
3
- import PyPDF2
4
  import requests
5
- from sentence_transformers import SentenceTransformer
6
  import faiss
 
 
7
  from groq import Groq
 
 
 
8
 
9
- # Initialize Groq client using the secret environment variable
10
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
11
 
12
# Function to download and read PDF content
def extract_text_from_google_drive():
    """Download the fixed Google Drive PDF and return its full text.

    Returns:
        str: text of every page joined with spaces; pages with no
        extractable text (e.g. scanned images) contribute "".

    Raises:
        requests.HTTPError: if the download does not return HTTP 200.
    """
    # export=download asks Drive for the raw file rather than a viewer page
    link = "https://drive.google.com/uc?export=download&id=1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0"
    response = requests.get(link)
    # Fail loudly instead of silently parsing an HTML error page as a PDF
    response.raise_for_status()
    with open("document.pdf", "wb") as file:
        file.write(response.content)

    with open("document.pdf", "rb") as file:
        reader = PyPDF2.PdfReader(file)
        # extract_text() can return None; guard so join() never sees None
        text = " ".join([(page.extract_text() or "") for page in reader.pages])
    return text
23
 
24
# Function to chunk text
def chunk_text(text, max_length=500):
    """Split *text* into sentence-aligned chunks of roughly *max_length* chars.

    Sentences are delimited by ". "; each emitted chunk is stripped and ends
    with a period. A single sentence longer than max_length still becomes its
    own (oversized) chunk.
    """
    pieces = text.split(". ")
    result = []
    current = ""
    for piece in pieces:
        if len(current) + len(piece) > max_length:
            # buffer is full — flush it and start a fresh one with this piece
            result.append(current.strip())
            current = piece + ". "
        else:
            current += piece + ". "
    if current:
        result.append(current.strip())
    return result
38
 
39
# Function to create FAISS index
def create_faiss_index(chunks, model):
    """Embed the text chunks and load them into a flat L2 FAISS index.

    Args:
        chunks: list of text chunks to embed.
        model: sentence-embedding model exposing .encode(list) -> 2-D array.

    Returns:
        tuple: (faiss.IndexFlatL2, chunks) — chunks are passed through unchanged.
    """
    vectors = model.encode(chunks)
    # index dimensionality must match the embedding width
    dim = len(vectors[0])
    flat_index = faiss.IndexFlatL2(dim)
    flat_index.add(vectors)
    return flat_index, chunks
 
 
46
 
47
# Function to query Groq API
def query_groq(question, model_name="llama-3.3-70b-versatile"):
    """Send *question* as a single user message to the Groq chat API.

    Relies on the module-level `client`; returns the assistant's reply text.
    """
    messages = [{"role": "user", "content": question}]
    completion = client.chat.completions.create(
        messages=messages,
        model=model_name,
    )
    return completion.choices[0].message.content
54
 
55
# Streamlit app
def main():
    """Streamlit entry point: index the Drive PDF with FAISS and answer
    user questions via the Groq API using the best-matching chunk."""
    st.title("RAG-based Application with Groq API")
    st.subheader("Query the document stored on Google Drive")

    st.write("Extracting text from the document...")
    text = extract_text_from_google_drive()
    st.write("Document text extracted successfully!")

    st.write("Chunking and embedding text...")
    model = SentenceTransformer("all-MiniLM-L6-v2")
    chunks = chunk_text(text)
    index, chunks = create_faiss_index(chunks, model)
    st.write(f"Created FAISS index with {len(chunks)} chunks.")

    # Query input
    question = st.text_input("Ask a question based on the document:")
    if question:
        st.write("Searching for relevant chunks...")
        question_embedding = model.encode([question])
        _, indices = index.search(question_embedding, k=1)
        relevant_chunk = chunks[indices[0][0]]

        st.write("Generating answer using Groq API...")
        # BUG FIX: the original sent only the retrieved chunk to the LLM and
        # discarded the user's question; send both so the model can answer it.
        prompt = (
            "Answer the question using the context below.\n\n"
            f"Context: {relevant_chunk}\n\nQuestion: {question}"
        )
        answer = query_groq(prompt)
        st.write("### Answer:")
        st.write(answer)

if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
85
 
86
 
87
 
 
 
 
 
1
  import requests
2
+ import numpy as np
3
  import faiss
4
+ from PyPDF2 import PdfReader
5
+ from transformers import AutoTokenizer, AutoModel
6
  from groq import Groq
7
+ import streamlit as st
8
+ import torch
9
+ import os
10
 
11
+ # Initialize Groq client using secret API key
12
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
13
 
14
# Function to download and extract content from a public Google Drive PDF link
def extract_pdf_content(drive_url):
    """Download a publicly shared Google Drive PDF and return its text.

    Args:
        drive_url: share link of the form .../file/d/<id>/view...

    Returns:
        str: concatenated page text, or None when the URL cannot be parsed
        or the download fails (the caller already checks for a falsy result).
    """
    # Extract the file ID from the Google Drive URL; a malformed link now
    # yields None instead of raising IndexError.
    try:
        file_id = drive_url.split("/d/")[1].split("/view")[0]
    except IndexError:
        return None
    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"

    # Download the PDF content
    response = requests.get(download_url)
    if response.status_code != 200:
        return None

    # Save and extract text from the PDF
    with open("document.pdf", "wb") as f:
        f.write(response.content)
    reader = PdfReader("document.pdf")
    text = ""
    for page in reader.pages:
        # extract_text() can return None on image-only pages; avoid the
        # "str + None" TypeError the original would raise.
        text += page.extract_text() or ""
    return text
33
 
34
# Function to chunk and tokenize text
def chunk_and_tokenize(text, tokenizer, chunk_size=512):
    """Tokenize *text* (no special tokens) and split the token ids into
    consecutive chunks of at most *chunk_size* tokens each."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    starts = range(0, len(token_ids), chunk_size)
    return [token_ids[start:start + chunk_size] for start in starts]
39
 
40
# Function to compute embeddings and build FAISS index
def build_faiss_index(chunks, model):
    """Mean-pool the model's last hidden state over each token chunk and
    add the resulting vectors to a flat L2 FAISS index."""
    def _embed(token_chunk):
        # one forward pass per chunk; gradients are not needed for inference
        ids = torch.tensor([token_chunk])
        with torch.no_grad():
            hidden = model(ids).last_hidden_state
        return hidden.mean(dim=1).numpy()

    embeddings = np.vstack([_embed(chunk) for chunk in chunks])

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
 
 
 
 
53
 
54
# Streamlit app
st.title("RAG-based Application with Groq API")

# Predefined Google Drive link
drive_url = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"

# Extract document content
st.write("Extracting content from the document...")
text = extract_pdf_content(drive_url)
if text:
    st.write("Document extracted successfully!")

    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    st.write("Chunking and tokenizing content...")
    chunks = chunk_and_tokenize(text, tokenizer)

    st.write("Building FAISS index...")
    index = build_faiss_index(chunks, model)

    # Query input
    query = st.text_input("Enter your query:")
    if query:
        st.write("Searching for the most relevant chunk...")
        query_tokens = tokenizer.encode(query, add_special_tokens=False)
        # no_grad: inference only — avoids building an autograd graph
        with torch.no_grad():
            query_embedding = (
                model(torch.tensor([query_tokens])).last_hidden_state.mean(dim=1).numpy()
            )
        _, indices = index.search(query_embedding, k=1)

        # Retrieve the most relevant chunk
        relevant_chunk = chunks[indices[0][0]]
        relevant_text = tokenizer.decode(relevant_chunk)
        st.write("Relevant chunk found:", relevant_text)

        # Interact with Groq API
        st.write("Querying the Groq API...")
        # BUG FIX: the original sent only the retrieved chunk and discarded
        # the user's query; include both so the model answers the question.
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Answer the question using the context below.\n\n"
                        f"Context: {relevant_text}\n\nQuestion: {query}"
                    ),
                }
            ],
            model="llama-3.3-70b-versatile",
        )
        st.write("Model Response:", chat_completion.choices[0].message.content)
else:
    st.error("Failed to extract content from the document.")
105
 
106
 
107