NavyDevilDoc committed on
Commit
f09334e
·
verified ·
1 Parent(s): 0b474cc

Update app.py

Browse files

refactored for document matching versus chunk matching

Files changed (1) hide show
  1. app.py +123 -154
app.py CHANGED
@@ -8,83 +8,37 @@ from huggingface_hub import HfApi, hf_hub_download
8
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
9
  import pypdf
10
  import docx
11
- import time
12
 
13
  # --- CONFIGURATION ---
14
- DATASET_REPO_ID = "NavyDevilDoc/navy-policy-index" # Your Dataset
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
-
17
- # File paths for local storage
18
  INDEX_FILE = "navy_index.faiss"
19
  META_FILE = "navy_metadata.pkl"
20
 
21
- st.set_page_config(page_title="Navy Search (FAISS)", layout="wide")
22
 
23
# --- PERSISTENCE MANAGER ---
class IndexManager:
    """Manages loading/saving the FAISS index and metadata from a Hugging Face Dataset."""

    @staticmethod
    def load_from_hub():
        """Download the index files from the HF Dataset.

        Returns True when both files were fetched, False otherwise
        (missing token, missing repo/file, or any other sync error).
        """
        if not HF_TOKEN:
            st.warning("HF_TOKEN missing. Running in local-only mode.")
            return False

        try:
            with st.spinner("Downloading Knowledge Base..."):
                # Download vector index
                hf_hub_download(
                    repo_id=DATASET_REPO_ID,
                    filename=INDEX_FILE,
                    repo_type="dataset",
                    local_dir=".",
                    token=HF_TOKEN
                )
                # Download metadata
                hf_hub_download(
                    repo_id=DATASET_REPO_ID,
                    filename=META_FILE,
                    repo_type="dataset",
                    local_dir=".",
                    token=HF_TOKEN
                )
            return True
        except (EntryNotFoundError, RepositoryNotFoundError):
            # A missing repo/file just means "first run" — not an error.
            st.toast("No existing index found in Cloud. Starting fresh.", icon="🆕")
            return False
        except Exception as e:
            st.error(f"Sync Error: {e}")
            return False

    @staticmethod
    def save_to_hub():
        """Upload the local index/metadata files to the HF Dataset (no-op without a token)."""
        if not HF_TOKEN:
            return

        api = HfApi(token=HF_TOKEN)
        try:
            st.toast("Syncing to Cloud...", icon="☁️")
            api.upload_file(
                path_or_fileobj=INDEX_FILE,
                path_in_repo=INDEX_FILE,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message="Update FAISS Index"
            )
            api.upload_file(
                path_or_fileobj=META_FILE,
                path_in_repo=META_FILE,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message="Update Metadata"
            )
            st.success("Knowledge Base Saved!")
        except Exception as e:
            st.error(f"Upload failed: {e}")
86
-
87
# --- HELPER FUNCTIONS ---
def parse_file(uploaded_file):
    """Extract plain text from an uploaded PDF/DOCX/TXT file.

    Returns a (text, filename) tuple; text is "" when the extension is
    unsupported or parsing raises (the error is shown in the UI).
    """
    filename = uploaded_file.name
    text = ""
    try:
        if filename.endswith(".pdf"):
            pages = pypdf.PdfReader(uploaded_file).pages
            for page_no, page in enumerate(pages, start=1):
                extracted = page.extract_text()
                if extracted:
                    # Inline page markers let downstream chunking recover page numbers.
                    text += f"\n[PAGE {page_no}] {extracted}"
        elif filename.endswith(".docx"):
            document = docx.Document(uploaded_file)
            text = "\n".join(paragraph.text for paragraph in document.paragraphs)
        elif filename.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
    except Exception as e:
        st.error(f"Error parsing (unknown): {e}")
    return text, filename
106
 
107
def recursive_chunking(text, source, chunk_size=500, overlap=100):
    """Split text into overlapping word-window chunks tagged with source and page.

    Args:
        text: full document text (may contain "[PAGE n]" markers from parse_file).
        source: originating filename, stored on every chunk.
        chunk_size: words per chunk.
        overlap: words shared between consecutive chunks.

    Returns a list of {"text", "source", "page"} dicts; chunks of 50
    characters or fewer are dropped.
    """
    words = text.split()
    chunks = []
    # Guard the stride: overlap >= chunk_size would otherwise make
    # range() raise (step 0) or silently produce nothing (negative step).
    step = max(1, chunk_size - overlap)
    for i in range(0, len(words), step):
        chunk_words = words[i:i + chunk_size]
        chunk_text = " ".join(chunk_words)

        # Best-effort page extraction: use the last "[PAGE n]" marker
        # seen inside this chunk ("[PAGE " is 6 characters long).
        page_num = "Unknown"
        if "[PAGE" in chunk_text:
            try:
                start = chunk_text.rfind("[PAGE") + 6
                end = chunk_text.find("]", start)
                page_num = chunk_text[start:end]
            except Exception:
                pass

        if len(chunk_text) > 50:
            chunks.append({
                "text": chunk_text,
                "source": source,
                "page": page_num
            })
    return chunks
130
 
131
# --- CORE SEARCH ENGINE (FAISS VERSION) ---
class RobustSearchEngine:
    """Chunk-level semantic search over a FAISS inner-product index."""

    def __init__(self):
        # Models pinned to CPU to avoid meta-tensor loading errors.
        self.bi_encoder = SentenceTransformer('all-MiniLM-L6-v2', device="cpu")
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device="cpu", automodel_args={"low_cpu_mem_usage": False})

        self.index = None      # created lazily on the first add_documents call
        self.metadata = []     # chunk dicts, positionally aligned with index rows

        # Resume from disk when a previously-saved index is present.
        if os.path.exists(INDEX_FILE) and os.path.exists(META_FILE):
            self.index = faiss.read_index(INDEX_FILE)
            with open(META_FILE, "rb") as fh:
                self.metadata = pickle.load(fh)

    def add_documents(self, chunks):
        """Embed, index, and persist the given chunk dicts; returns the count added."""
        chunk_texts = [chunk["text"] for chunk in chunks]
        vectors = self.bi_encoder.encode(chunk_texts)
        faiss.normalize_L2(vectors)  # unit vectors => inner product behaves as cosine

        if self.index is None:
            self.index = faiss.IndexFlatIP(vectors.shape[1])

        self.index.add(vectors)
        self.metadata.extend(chunks)

        # Persist both halves so a restart can resume.
        faiss.write_index(self.index, INDEX_FILE)
        with open(META_FILE, "wb") as fh:
            pickle.dump(self.metadata, fh)

        return len(chunk_texts)

    def search(self, query, top_k=5):
        """Bi-encoder retrieval followed by cross-encoder re-ranking.

        Returns up to top_k chunk dicts sorted by re-ranked score.
        """
        if not self.index or self.index.ntotal == 0:
            return []

        # Over-fetch (3x) so the re-ranker has candidates to reorder.
        q_vec = self.bi_encoder.encode([query])
        faiss.normalize_L2(q_vec)
        fetch = min(self.index.ntotal, top_k * 3)
        scores, indices = self.index.search(q_vec, fetch)

        hits = []
        for rank, row in enumerate(indices[0]):
            if row == -1:
                continue  # FAISS pads with -1 when fewer results exist
            meta = self.metadata[row]
            hits.append({
                "text": meta["text"],
                "source": meta["source"],
                "page": meta["page"],
                "base_score": scores[0][rank]
            })

        # Re-score every candidate pair (query, chunk) with the cross-encoder.
        rerank = self.cross_encoder.predict([[query, hit["text"]] for hit in hits])
        for hit, new_score in zip(hits, rerank):
            hit["score"] = new_score

        return sorted(hits, key=lambda h: h["score"], reverse=True)[:top_k]
 
 
 
 
 
 
 
 
204
 
205
# --- UI LOGIC ---
# Build the engine once per session; pull any saved index from the Hub first
# so the engine's constructor can load it from disk.
if 'engine' not in st.session_state:
    # 1. Try cloud sync first
    IndexManager.load_from_hub()
    # 2. Start engine
    st.session_state.engine = RobustSearchEngine()

with st.sidebar:
    st.header("🗄️ Knowledge Base")
    uploaded_files = st.file_uploader("Ingest Documents", accept_multiple_files=True)

    if uploaded_files and st.button("Index Documents"):
        with st.spinner("Processing..."):
            new_chunks = []
            for f in uploaded_files:
                txt, fname = parse_file(f)
                chunks = recursive_chunking(txt, fname)
                new_chunks.extend(chunks)

            if new_chunks:
                count = st.session_state.engine.add_documents(new_chunks)
                IndexManager.save_to_hub()
                st.success(f"Added {count} chunks!")

st.title("⚓ Navy Search (FAISS Architecture)")
query = st.text_input("Enter Query:")

if query:
    results = st.session_state.engine.search(query)

    st.markdown("### 🔍 Results")
    # Accumulate the shown chunks as LLM context for the summary button.
    context_text = ""
    for res in results:
        context_text += f"Source: {res['source']}\n{res['text']}\n\n"
        with st.expander(f"{res['source']} (Pg {res['page']}) - Score {res['score']:.2f}", expanded=True):
            st.markdown(res['text'])

    if st.button("Generate Summary"):
        from huggingface_hub import InferenceClient
        client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token=HF_TOKEN)
        prompt = f"Context:\n{context_text}\n\nUser: {query}\nAnswer:"
        with st.spinner("Thinking..."):
            try:
                st.write(client.text_generation(prompt, max_new_tokens=400))
            except Exception as e:
                st.error(f"LLM Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
9
  import pypdf
10
  import docx
 
11
 
12
  # --- CONFIGURATION ---
13
+ DATASET_REPO_ID = "NavyDevilDoc/navy-policy-index"
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
15
  INDEX_FILE = "navy_index.faiss"
16
  META_FILE = "navy_metadata.pkl"
17
 
18
+ st.set_page_config(page_title="Document Finder", layout="wide")
19
 
20
# --- PERSISTENCE (SAME AS BEFORE) ---
class IndexManager:
    """Loads/saves the FAISS index and metadata pickle from a Hugging Face Dataset repo."""

    @staticmethod
    def load_from_hub():
        """Download index + metadata from the HF Dataset; returns True on success."""
        if not HF_TOKEN:
            return False
        try:
            # repo_type="dataset" is required: these files live in a Dataset
            # repo, and hf_hub_download defaults to repo_type="model".
            hf_hub_download(repo_id=DATASET_REPO_ID, filename=INDEX_FILE, repo_type="dataset", local_dir=".", token=HF_TOKEN)
            hf_hub_download(repo_id=DATASET_REPO_ID, filename=META_FILE, repo_type="dataset", local_dir=".", token=HF_TOKEN)
            return True
        except Exception:
            # Deliberate best-effort: a missing repo/file means "start fresh".
            # (Narrowed from a bare `except:` that also swallowed KeyboardInterrupt.)
            return False

    @staticmethod
    def save_to_hub():
        """Upload the local index files to the HF Dataset (no-op without a token)."""
        if not HF_TOKEN:
            return
        api = HfApi(token=HF_TOKEN)
        try:
            api.upload_file(path_or_fileobj=INDEX_FILE, path_in_repo=INDEX_FILE, repo_id=DATASET_REPO_ID, repo_type="dataset")
            api.upload_file(path_or_fileobj=META_FILE, path_in_repo=META_FILE, repo_id=DATASET_REPO_ID, repo_type="dataset")
            st.toast("Database Synced!", icon="☁️")
        except Exception as e:
            st.error(f"Sync Error: {e}")
40
+
41
+ # --- PARSING & CHUNKING (SAME AS BEFORE) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
def parse_file(uploaded_file):
    """Extract plain text from an uploaded PDF/DOCX/TXT file.

    Returns (text, filename); text is "" for unsupported types or on
    any parse error (deliberate best-effort — errors are swallowed).
    """
    text = ""
    filename = uploaded_file.name
    try:
        if filename.endswith(".pdf"):
            reader = pypdf.PdfReader(uploaded_file)
            for i, page in enumerate(reader.pages):
                # Extract once per page: pypdf text extraction is expensive,
                # and the original called extract_text() twice per page.
                page_text = page.extract_text()
                if page_text:
                    text += f"\n[PAGE {i+1}] {page_text}"
        elif filename.endswith(".docx"):
            doc = docx.Document(uploaded_file)
            text = "\n".join([para.text for para in doc.paragraphs])
        elif filename.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
    except Exception:
        # Narrowed from a bare `except:` (which also caught KeyboardInterrupt).
        pass
    return text, filename
57
 
58
def recursive_chunking(text, source, chunk_size=500, overlap=100):
    """Split text into overlapping word-window chunks.

    Args:
        text: full document text.
        source: originating filename, stored on every chunk.
        chunk_size: words per chunk.
        overlap: words shared between consecutive chunks.

    Returns a list of {"text", "source"} dicts; chunks of 50 characters
    or fewer are dropped.
    """
    words = text.split()
    chunks = []
    # Guard the stride: overlap >= chunk_size would otherwise make
    # range() raise (step 0) or silently produce nothing (negative step).
    step = max(1, chunk_size - overlap)
    for i in range(0, len(words), step):
        chunk_text = " ".join(words[i:i + chunk_size])
        if len(chunk_text) > 50:
            chunks.append({"text": chunk_text, "source": source})
    return chunks
66
 
67
# --- CORE SEARCH ENGINE (UPDATED FOR DOC LEVEL) ---
class DocSearchEngine:
    """Document-level semantic search: FAISS chunk retrieval aggregated per source file."""

    def __init__(self):
        # Models pinned to CPU to avoid meta-tensor/device loading errors.
        self.bi_encoder = SentenceTransformer('all-mpnet-base-v2', device="cpu")
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device="cpu", automodel_args={"low_cpu_mem_usage": False})

        self.index = None      # faiss.IndexFlatIP, created lazily on first add
        self.metadata = []     # chunk dicts, positionally aligned with index rows

        # Resume from a previously-saved index when both files exist locally.
        if os.path.exists(INDEX_FILE) and os.path.exists(META_FILE):
            self.index = faiss.read_index(INDEX_FILE)
            with open(META_FILE, "rb") as f:
                self.metadata = pickle.load(f)

    def add_documents(self, chunks):
        """Embed chunks, append them to the index, and persist to disk.

        Returns the number of chunks added.
        """
        if not chunks:
            # encode([]) yields a degenerate array whose shape[1] access
            # (and FAISS add) would crash — nothing to do anyway.
            return 0

        texts = [c["text"] for c in chunks]
        embeddings = self.bi_encoder.encode(texts)
        faiss.normalize_L2(embeddings)  # unit vectors => inner product == cosine

        if self.index is None:
            self.index = faiss.IndexFlatIP(embeddings.shape[1])

        self.index.add(embeddings)
        self.metadata.extend(chunks)

        # Persist both halves so a restart (or Hub sync) can resume.
        faiss.write_index(self.index, INDEX_FILE)
        with open(META_FILE, "wb") as f:
            pickle.dump(self.metadata, f)
        return len(texts)

    def search_documents(self, query, top_k=5):
        """Return up to top_k documents ranked by their best-matching chunk.

        Each result is {"source", "score", "snippet"}, where score is the
        cross-encoder re-rank of the document's best snippet.
        """
        if self.index is None or self.index.ntotal == 0:
            return []

        # 1. Retrieve MANY chunks so the candidate pool spans diverse
        # documents — the top few chunks may all come from one file.
        candidate_k = top_k * 10

        q_vec = self.bi_encoder.encode([query])
        faiss.normalize_L2(q_vec)
        scores, indices = self.index.search(q_vec, min(self.index.ntotal, candidate_k))

        # 2. Extract raw chunk candidates.
        raw_candidates = []
        for i, idx in enumerate(indices[0]):
            if idx != -1:  # FAISS pads with -1 when fewer results exist
                raw_candidates.append({
                    "text": self.metadata[idx]["text"],
                    "source": self.metadata[idx]["source"],
                    "bi_score": scores[0][i]
                })

        # 3. Aggregation: keep each document's best-scoring chunk.
        doc_map = {}  # {filename: {"score": best bi_score, "snippet": best chunk text}}
        for cand in raw_candidates:
            source = cand['source']
            score = cand['bi_score']
            if source not in doc_map or score > doc_map[source]["score"]:
                doc_map[source] = {"score": score, "snippet": cand['text']}

        # 4. Sort documents by their best chunk score.
        ranked_docs = sorted(doc_map.items(), key=lambda item: item[1]['score'], reverse=True)

        # 5. Cross-encoder verification of each top document's best snippet.
        final_results = []
        top_docs = ranked_docs[:top_k]  # only re-rank the top contenders
        if top_docs:
            pairs = [[query, doc[1]['snippet']] for doc in top_docs]
            cross_scores = self.cross_encoder.predict(pairs)
            for i, (source, data) in enumerate(top_docs):
                final_results.append({
                    "source": source,
                    "score": cross_scores[i],  # high-accuracy score
                    "snippet": data['snippet']
                })

        # Final sort: cross-encoder ordering may differ from bi-encoder ordering.
        final_results = sorted(final_results, key=lambda x: x["score"], reverse=True)
        return final_results
156
 
157
# --- UI LOGIC ---
# Build the engine once per session; pull any saved index from the Hub first
# so the engine's constructor can load it from disk.
if 'engine' not in st.session_state:
    IndexManager.load_from_hub()
    st.session_state.engine = DocSearchEngine()

with st.sidebar:
    st.header("🗄️ Upload Documents")
    uploaded_files = st.file_uploader("Upload Files", accept_multiple_files=True)
    if uploaded_files and st.button("Index"):
        with st.spinner("Indexing..."):
            new_chunks = []
            for f in uploaded_files:
                txt, fname = parse_file(f)
                new_chunks.extend(recursive_chunking(txt, fname))
            if new_chunks:
                st.session_state.engine.add_documents(new_chunks)
                IndexManager.save_to_hub()
                st.success("Indexed!")

st.title("⚓ Document Finder")
st.caption("Locates the specific Instruction or NAVADMIN relevant to your query.")

query = st.text_input("What are you looking for?", placeholder="e.g. 'FY25 Retention Bonuses'")

if query:
    results = st.session_state.engine.search_documents(query, top_k=5)

    st.subheader("Top Relevant Documents")

    if not results:
        st.info("No documents found.")

    for res in results:
        score = res['score']

        # Color-code confidence from the raw cross-encoder score.
        # NOTE(review): thresholds assume raw (unbounded) logits — confirm
        # the cross-encoder is not emitting 0..1 probabilities.
        if score > 2:
            border_color = "#09ab3b"  # Green
            confidence = "High Match"
        elif score > 0:
            border_color = "#ffbd45"  # Orange
            confidence = "Possible Match"
        else:
            border_color = "#ff4b4b"  # Red
            confidence = "Low Match"

        # --- DOCUMENT CARD UI ---
        with st.container():
            st.markdown(f"""
            <div style="
                border: 1px solid #ddd;
                border-left: 5px solid {border_color};
                padding: 15px;
                border-radius: 5px;
                margin-bottom: 10px;
            ">
                <h3 style="margin:0; padding:0;">📄 {res['source']}</h3>
                <small style="color: gray;">Confidence: {confidence} ({score:.2f})</small>
            </div>
            """, unsafe_allow_html=True)

        with st.expander("View matching excerpt"):
            st.markdown(f"**...{res['snippet']}...**")