NHZ committed on
Commit
644455e
·
verified ·
1 Parent(s): acd8ea0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -66
app.py CHANGED
@@ -1,86 +1,113 @@
1
  import os
2
- import re
3
  import requests
4
- import pdfplumber
5
  import streamlit as st
 
6
  import faiss
7
  from sentence_transformers import SentenceTransformer
 
8
 
9
- # Constants
10
- DOCUMENT_URL = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"
11
- CHUNK_SIZE = 500
12
-
13
- # Function to download document
14
- def download_document(file_url):
15
- file_id = file_url.split("/d/")[1].split("/")[0]
16
- download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
17
  response = requests.get(download_url)
18
- output = "document.pdf"
19
- with open(output, "wb") as f:
20
- f.write(response.content)
21
- return output
22
-
23
- # Extract text from PDF
24
- def extract_text_from_pdf(file_path):
25
- text = ""
26
- with pdfplumber.open(file_path) as pdf:
27
- for page in pdf.pages:
28
- text += page.extract_text()
29
- return text
30
-
31
- # Chunk text into smaller parts
32
- def chunk_text(text, chunk_size=CHUNK_SIZE):
33
- sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
34
- chunks, current_chunk = [], ""
35
  for sentence in sentences:
36
- if len(current_chunk) + len(sentence) < chunk_size:
37
- current_chunk += sentence + " "
38
- else:
39
- chunks.append(current_chunk.strip())
40
- current_chunk = sentence + " "
 
 
 
41
  if current_chunk:
42
- chunks.append(current_chunk.strip())
 
43
  return chunks
44
 
45
- # Vectorize and store in FAISS
46
- def create_faiss_index(chunks, model):
47
- embeddings = model.encode(chunks)
48
  dimension = embeddings.shape[1]
49
  index = faiss.IndexFlatL2(dimension)
50
  index.add(embeddings)
51
- return index, embeddings
52
 
53
- # Query FAISS index
54
- def query_faiss(query, index, chunks, model, k=5):
55
- query_embedding = model.encode([query])
56
- distances, indices = index.search(query_embedding, k)
57
- return [chunks[i] for i in indices[0]]
58
 
59
- # Streamlit application
60
  def main():
61
- st.title("Document-Based Query Application")
62
- st.write("This application uses a pre-configured document as the dataset for answering queries.")
63
-
64
- # Download and process the document
65
- st.write("Processing the pre-configured document...")
66
- document_path = download_document(DOCUMENT_URL)
67
- text = extract_text_from_pdf(document_path)
68
- chunks = chunk_text(text)
69
-
70
- # Create FAISS index
71
- st.write("Creating FAISS index...")
72
- embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
73
- index, embeddings = create_faiss_index(chunks, embedding_model)
74
- st.success("Document processed and indexed!")
75
-
76
- # Query the database
77
- query = st.text_input("Enter your query")
78
- if query:
79
- st.write("Fetching relevant content from the document...")
80
- results = query_faiss(query, index, chunks, embedding_model)
81
- st.write("Top relevant chunks:")
82
- for i, result in enumerate(results):
83
- st.write(f"{i+1}. {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  if __name__ == "__main__":
86
  main()
 
1
  import os
 
2
  import requests
 
3
  import streamlit as st
4
+ import numpy as np
5
  import faiss
6
  from sentence_transformers import SentenceTransformer
7
+ from groq import Groq
8
 
9
# Function to download document from a public Google Drive link
def download_file_from_public_link(url):
    """Download a publicly shared Google Drive file and return its text content.

    Args:
        url: A Drive share link of the form
            ``https://drive.google.com/file/d/<FILE_ID>/view...``.

    Returns:
        The response body decoded as text.

    Raises:
        Exception: if the link does not contain a ``/d/<id>/`` segment or the
            download does not return HTTP 200.
    """
    # Extract the file id defensively: a malformed link would otherwise
    # surface as an opaque IndexError to the Streamlit error handler.
    try:
        file_id = url.split("/d/")[1].split("/")[0]
    except IndexError:
        raise Exception("Invalid Google Drive share link: missing '/d/<id>/' segment.")
    download_url = f"https://drive.google.com/uc?id={file_id}&export=download"

    # A timeout prevents the app from hanging forever on a stalled request
    # (requests has no default timeout).
    response = requests.get(download_url, timeout=60)
    if response.status_code == 200:
        return response.text
    else:
        raise Exception("Failed to download file from Google Drive.")
18
+
19
# Function to preprocess text
def preprocess_text(text, chunk_size=512):
    """Split *text* into chunks of at most ``chunk_size`` words.

    Sentences are delimited naively on ".". A sentence longer than
    ``chunk_size`` words becomes its own (oversized) chunk rather than being
    split mid-sentence.

    Args:
        text: The raw document text.
        chunk_size: Maximum number of words per chunk (soft limit, see above).

    Returns:
        A list of non-empty chunk strings.
    """
    sentences = text.split(".")
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        # Skip empty/whitespace fragments (e.g. from a trailing "." or "..");
        # they would otherwise produce blank entries and stray spaces.
        if not sentence.strip():
            continue
        sentence_length = len(sentence.split())
        # Flush only when the chunk is non-empty: previously an oversized
        # first sentence caused an empty-string chunk to be appended.
        if current_length + sentence_length > chunk_size and current_chunk:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sentence)
        current_length += sentence_length

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks
39
 
40
# Function to create a FAISS index
def create_faiss_index(embeddings):
    """Build an exact L2 (Euclidean) FAISS index over the given embedding matrix.

    Args:
        embeddings: 2-D array of shape (num_chunks, embedding_dim).

    Returns:
        A populated ``faiss.IndexFlatL2`` ready for ``search`` calls.
    """
    _, embedding_dim = embeddings.shape[0], embeddings.shape[1]
    faiss_index = faiss.IndexFlatL2(embedding_dim)
    faiss_index.add(embeddings)
    return faiss_index
46
 
47
# Function to query FAISS index
def query_faiss_index(index, query_embedding, top_k=5):
    """Return the nearest-neighbor ids and distances for a query embedding.

    Args:
        index: A FAISS index exposing ``search(query, k)``.
        query_embedding: 2-D array with a single query row.
        top_k: Number of neighbors to retrieve.

    Returns:
        Tuple ``(indices, distances)`` for the first (only) query row.
    """
    all_distances, all_indices = index.search(query_embedding, top_k)
    nearest_ids = all_indices[0]
    nearest_distances = all_distances[0]
    return nearest_ids, nearest_distances
 
51
 
52
# Streamlit App
def main():
    """Streamlit entry point: download a document, index it with FAISS, and
    answer user queries with Groq grounded in the retrieved chunks."""
    st.title("RAG-based Application")

    # Load Groq API Key from environment (set in Hugging Face secrets)
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        st.error("Groq API Key is missing. Ensure it is set as a secret in Hugging Face.")
        return

    # Predefined Google Drive link
    drive_link = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"

    if st.button("Load Document"):
        try:
            document_text = download_file_from_public_link(drive_link)
            st.success("Document downloaded successfully!")

            # Process the document
            chunks = preprocess_text(document_text)
            st.write(f"Document split into {len(chunks)} chunks.")

            # Load the embedding model once and batch-encode all chunks:
            # per-chunk encode() calls in a Python loop are far slower.
            model = SentenceTransformer("all-MiniLM-L6-v2")
            embeddings = np.asarray(model.encode(chunks))

            # Create FAISS index
            index = create_faiss_index(embeddings)
            st.success("FAISS index created.")

            # Persist across Streamlit reruns so each query does not
            # re-download and re-index the document.
            st.session_state["index"] = index
            st.session_state["chunks"] = chunks
            st.session_state["model"] = model

        except Exception as e:
            st.error(f"Failed to load document: {str(e)}")

    if "index" in st.session_state and "chunks" in st.session_state:
        query = st.text_input("Enter your query")
        if query:
            # Reuse the cached model instead of reloading it on every query
            # (fallback load kept for sessions indexed before this change).
            model = st.session_state.get("model") or SentenceTransformer("all-MiniLM-L6-v2")
            query_embedding = model.encode([query])
            indices, distances = query_faiss_index(st.session_state["index"], query_embedding)

            # Display results and collect the retrieved text for the LLM prompt
            relevant_chunks = []
            st.write("Relevant Chunks:")
            for i, idx in enumerate(indices):
                chunk = st.session_state["chunks"][idx]
                relevant_chunks.append(chunk)
                st.write(f"Chunk {i + 1} (Distance: {distances[i]}):")
                st.write(chunk)

            # Query Groq API, grounding the answer in the retrieved chunks.
            # Previously the raw query was sent with no document context,
            # so the retrieval step never influenced the LLM's answer.
            context = "\n\n".join(relevant_chunks)
            prompt = (
                "Answer the question using only the context below.\n\n"
                f"Context:\n{context}\n\n"
                f"Question: {query}"
            )
            client = Groq(api_key=groq_api_key)
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama-3.3-70b-versatile",
            )
            st.write("Groq Model Response:")
            st.write(chat_completion.choices[0].message.content)
110
+
111
 
112
# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    main()