NavyDevilDoc committed on
Commit
39f313e
·
verified ·
1 Parent(s): 5a9d0e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +293 -160
app.py CHANGED
@@ -1,211 +1,344 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
- from sentence_transformers import SentenceTransformer, CrossEncoder, util
5
- import faiss
 
6
  from rank_bm25 import BM25Okapi
 
 
7
  import pypdf
8
  import docx
9
- import torch
 
 
 
10
 
11
# --- CONFIGURATION ---
# Streamlit page setup; must be the first st.* call in the script.
st.set_page_config(page_title="Advanced Semantic Search", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # --- HELPER FUNCTIONS ---
15
def parse_file(uploaded_file):
    """Extract plain text from an uploaded file.

    Supported types (by extension): .pdf, .docx, .txt, .csv.

    Args:
        uploaded_file: Streamlit UploadedFile (file-like, has a ``.name``).

    Returns:
        str: Extracted text; "" for unsupported types or on parse failure
        (the error is surfaced via ``st.error``).
    """
    text = ""
    try:
        if uploaded_file.name.endswith(".pdf"):
            reader = pypdf.PdfReader(uploaded_file)
            for page in reader.pages:
                # Bug fix: extract_text() can return None for image-only
                # pages; the original crashed on None + "\n" and lost all
                # previously extracted pages via the broad except.
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n"
        elif uploaded_file.name.endswith(".docx"):
            doc = docx.Document(uploaded_file)
            text = "\n".join([para.text for para in doc.paragraphs])
        elif uploaded_file.name.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
        elif uploaded_file.name.endswith(".csv"):
            df = pd.read_csv(uploaded_file)
            text = df.to_string()
    except Exception as e:
        st.error(f"Error reading file: {e}")
    return text
33
 
34
- def chunk_text(text, chunk_size=300, overlap=50):
 
 
 
35
  words = text.split()
36
  chunks = []
 
37
  for i in range(0, len(words), chunk_size - overlap):
38
- chunk = " ".join(words[i:i + chunk_size])
39
- if len(chunk) > 50:
40
- chunks.append(chunk)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return chunks
42
 
43
- # --- CORE LOGIC: RETRIEVER + RE-RANKER ---
44
# --- CORE LOGIC: RETRIEVER + RE-RANKER ---
class SearchEngine:
    """Two-stage search: hybrid (dense + BM25) retrieval, then
    cross-encoder re-ranking of the candidate pool."""

    def __init__(self, bi_encoder_name):
        # 1. Bi-Encoder (fast retrieval)
        self.bi_encoder = SentenceTransformer(bi_encoder_name)

        # 2. Cross-Encoder (accurate re-ranking) — standard MS MARCO
        # model trained for query/passage relevance scoring.
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        self.documents = []
        self.faiss_index = None
        self.bm25 = None

    def fit(self, documents):
        """Index *documents* (list[str]) for dense and sparse retrieval."""
        self.documents = documents

        # Dense index: L2-normalized embeddings + inner product == cosine.
        embeddings = self.bi_encoder.encode(documents, convert_to_tensor=True)
        embeddings_np = embeddings.cpu().numpy()
        faiss.normalize_L2(embeddings_np)

        dimension = embeddings_np.shape[1]
        self.faiss_index = faiss.IndexFlatIP(dimension)
        self.faiss_index.add(embeddings_np)

        # Sparse index (whitespace tokenization, lowercased).
        tokenized_corpus = [doc.lower().split() for doc in documents]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def search(self, query, top_k=5, alpha=0.5):
        """Return up to *top_k* results, each a dict with ``chunk``,
        ``score`` (cross-encoder) and ``original_hybrid_score``.

        alpha: 1.0 = pure vector score, 0.0 = pure BM25 score.
        """
        # STAGE 1: RETRIEVAL — fetch 3x candidates for the re-ranker.
        candidate_k = top_k * 3

        # Vector search.
        query_vector = self.bi_encoder.encode([query])
        faiss.normalize_L2(query_vector)
        v_scores, v_indices = self.faiss_index.search(
            query_vector, min(len(self.documents), candidate_k))

        # BM25 scores over the whole corpus.
        tokenized_query = query.lower().split()
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Min-max normalize BM25.
        # Bug fix: when every score is equal and positive, the original
        # divided by (max - min) == 0; skip normalization in that case.
        if len(bm25_scores) > 0 and max(bm25_scores) > 0:
            score_range = max(bm25_scores) - min(bm25_scores)
            if score_range > 0:
                bm25_scores = (bm25_scores - min(bm25_scores)) / score_range

        # Combine scores into {doc_idx: hybrid_score}.
        candidates = {}
        for i, idx in enumerate(v_indices[0]):
            if idx != -1:  # faiss pads missing hits with -1
                candidates[idx] = alpha * v_scores[0][i]

        # Merge top BM25 hits (full argsort is fine for small corpora;
        # production code would restrict to top BM25 results only).
        top_bm25_indices = np.argsort(bm25_scores)[-candidate_k:]
        for idx in top_bm25_indices:
            score = (1 - alpha) * bm25_scores[idx]
            candidates[idx] = candidates.get(idx, 0.0) + score

        # Keep the best candidate_k by hybrid score.
        sorted_candidates = sorted(
            candidates.items(), key=lambda x: x[1], reverse=True)[:candidate_k]

        # STAGE 2: RE-RANKING (cross-encoder on [query, doc] pairs).
        candidate_indices = [idx for idx, _ in sorted_candidates]
        pairs = [[query, self.documents[idx]] for idx in candidate_indices]
        if not pairs:
            return []

        cross_scores = self.cross_encoder.predict(pairs)

        final_results = []
        for i, idx in enumerate(candidate_indices):
            final_results.append({
                "chunk": self.documents[idx],
                "score": cross_scores[i],  # high-accuracy re-rank score
                "original_hybrid_score": sorted_candidates[i][1]
            })

        # Final order is by cross-encoder score.
        final_results.sort(key=lambda x: x["score"], reverse=True)
        return final_results[:top_k]
140
 
141
- # --- UI LAYOUT ---
142
- st.title("🧠 Semantic Search: Hybrid + Cross-Encoder")
143
- st.markdown("""
144
- This system uses a **Two-Stage Retrieval Process**:
145
- 1. **Retrieval:** Finds top candidates using Vector (semantic) and BM25 (keyword) search.
146
- 2. **Re-Ranking:** A Cross-Encoder model reads the query and candidates to score true relevance.
147
- """)
 
 
 
 
148
 
149
  with st.sidebar:
150
- st.header("1. Setup Knowledge Base")
151
- uploaded_files = st.file_uploader(
152
- "Upload Documents",
153
- type=['txt', 'pdf', 'docx', 'csv'],
154
- accept_multiple_files=True
155
- )
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- st.header("2. Tuning")
160
- model_choice = st.selectbox(
161
- "Base Embedding Model",
162
- ["all-MiniLM-L6-v2", "all-mpnet-base-v2"],
163
- help="Used for the initial fast retrieval."
164
- )
165
-
166
- alpha = st.slider("Hybrid Alpha", 0.0, 1.0, 0.4,
167
- help="0.0 = Keywords, 1.0 = Vectors. 0.4 is often best for Hybrid.")
168
-
169
- top_k = st.number_input("Final Results", 1, 20, 5)
170
 
171
- build_btn = st.button("Build Database")
172
-
173
- # --- APP STATE ---
174
- if 'engine' not in st.session_state:
175
- st.session_state.engine = None
176
-
177
- if build_btn and uploaded_files:
178
- with st.spinner("Processing files..."):
179
- all_chunks = []
180
- for file in uploaded_files:
181
- raw = parse_file(file)
182
- chunks = chunk_text(raw)
183
- all_chunks.extend(chunks)
184
-
185
- if all_chunks:
186
- # Initialize Engine
187
- st.session_state.engine = SearchEngine(model_choice)
188
- st.session_state.engine.fit(all_chunks)
189
- st.success(f"Indexed {len(all_chunks)} chunks!")
190
- else:
191
- st.warning("No text extracted.")
192
-
193
- # --- SEARCH ---
194
- if st.session_state.engine:
195
- query = st.text_input("Ask a question:")
196
- if query:
197
- with st.spinner("Retrieving & Re-Ranking..."):
198
- results = st.session_state.engine.search(query, top_k=top_k, alpha=alpha)
199
-
200
- for i, res in enumerate(results):
201
- score = res['score']
202
- # Color code high relevance
203
- color = "green" if score > 0 else "blue"
204
 
205
- with st.container():
206
- st.markdown(f"### Rank {i+1}")
207
- st.caption(f"Relevance Score: :{color}[{score:.3f}]")
208
- st.info(res['chunk'])
209
- st.divider()
210
- else:
211
- st.info("Upload documents to start.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
+ import chromadb
5
+ from chromadb.config import Settings
6
+ from sentence_transformers import SentenceTransformer, CrossEncoder
7
  from rank_bm25 import BM25Okapi
8
+ from huggingface_hub import HfApi, snapshot_download
9
+ from huggingface_hub.utils import RepositoryNotFoundError
10
  import pypdf
11
  import docx
12
+ import os
13
+ import shutil
14
+ import pickle
15
+ import time
16
 
17
# --- CONFIGURATION ---
# REPLACE THIS WITH YOUR NEW DATASET NAME!
# HF dataset repo used as the cloud backup for the local index.
DATASET_REPO_ID = "NavyDevilDoc/navy-policy-index"
# Local working directory holding the ChromaDB files and bm25.pkl.
LOCAL_DB_PATH = "./data_store"
# Read from the host's secrets; persistence is disabled when absent.
HF_TOKEN = os.environ.get("HF_TOKEN")

st.set_page_config(page_title="Navy Search & Intel", layout="wide")
24
+
25
+ # --- PERSISTENCE MANAGER ---
26
class DataManager:
    """Handles syncing the ChromaDB and BM25 index with the Hugging Face Hub."""

    @staticmethod
    def sync_from_hub():
        """Download the latest DB snapshot from the HF dataset repo.

        Returns:
            bool: True on success; False when no token is configured or
            the download failed (e.g. the dataset does not exist yet).
        """
        if not HF_TOKEN:
            st.warning("HF_TOKEN not found in Secrets. Persistence will not work.")
            return False

        try:
            st.toast("Syncing database from Cloud...", icon="☁️")
            snapshot_download(
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                local_dir=LOCAL_DB_PATH,
                token=HF_TOKEN
            )
            return True
        except RepositoryNotFoundError:
            # Expected on a fresh start: the dataset hasn't been created yet.
            print(f"Cloud sync note: dataset '{DATASET_REPO_ID}' not found yet.")
            return False
        except Exception as e:
            # Bug fix: the original `except (RepositoryNotFoundError, Exception)`
            # was redundant — Exception subsumes the specific class. Keep a
            # best-effort catch-all for network/auth errors, but log it.
            print(f"Cloud sync note: {e}")
            return False

    @staticmethod
    def sync_to_hub():
        """Upload the local DB folder to the HF dataset repo (best effort).

        Silently does nothing when no token is configured; errors are
        surfaced to the UI via ``st.error``.
        """
        if not HF_TOKEN:
            return

        api = HfApi(token=HF_TOKEN)
        try:
            st.toast("Uploading new index to Cloud...", icon="🚀")
            api.upload_folder(
                folder_path=LOCAL_DB_PATH,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message="Auto-save: Update Index"
            )
            st.success("Database saved to Cloud!")
        except Exception as e:
            st.error(f"Failed to sync to cloud: {e}")
68
 
69
  # --- HELPER FUNCTIONS ---
70
def parse_file(uploaded_file):
    """Extract plain text from an uploaded file.

    Supported types (by extension): .pdf, .docx, .txt, .csv. PDF pages are
    prefixed with "[PAGE n]" markers so downstream chunking (and the LLM)
    can recover page numbers.

    Args:
        uploaded_file: Streamlit UploadedFile (file-like, has a ``.name``).

    Returns:
        tuple[str, str]: (extracted_text, filename). Text is "" when the
        type is unsupported or parsing fails.
    """
    text = ""
    filename = uploaded_file.name
    try:
        if filename.endswith(".pdf"):
            reader = pypdf.PdfReader(uploaded_file)
            for i, page in enumerate(reader.pages):
                page_text = page.extract_text()
                if page_text:
                    # Inject page markers for later citation.
                    text += f"\n[PAGE {i+1}] {page_text}"
        elif filename.endswith(".docx"):
            doc = docx.Document(uploaded_file)
            text = "\n".join([para.text for para in doc.paragraphs])
        elif filename.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
        elif filename.endswith(".csv"):
            # Restored CSV support: the uploader accepts any file type,
            # so .csv files previously fell through and yielded "".
            df = pd.read_csv(uploaded_file)
            text = df.to_string()
    except Exception as e:
        # Bug fix: the message hardcoded "(unknown)" although the
        # filename is available — report which file failed.
        st.error(f"Error parsing {filename}: {e}")
    return text, filename
89
 
90
def recursive_chunking(text, source, chunk_size=500, overlap=100):
    """Split *text* into overlapping word chunks with source/page metadata.

    Page numbers are recovered from "[PAGE n]" markers injected by
    parse_file(); chunks without a marker get page "Unknown".

    Args:
        text: Document text (possibly containing "[PAGE n]" markers).
        source: Filename recorded in each chunk's metadata.
        chunk_size: Words per chunk.
        overlap: Words shared between consecutive chunks.

    Returns:
        list[dict]: {"text": str, "metadata": {"source": str, "page": str}}
        for every chunk longer than 50 characters.
    """
    # Guard: overlap >= chunk_size would give range() a step of 0
    # (ValueError) or a negative step (no chunks at all).
    step = max(1, chunk_size - overlap)
    words = text.split()
    chunks = []

    for i in range(0, len(words), step):
        chunk_body = " ".join(words[i:i + chunk_size])

        # Recover the LAST page marker present in this chunk, if any.
        # Bug fix: the original used a bare `except: pass`, and an
        # unmatched "]" (find() == -1) silently produced a garbage page
        # string; check both indices explicitly instead.
        page_num = "Unknown"
        marker = chunk_body.rfind("[PAGE")
        if marker != -1:
            end = chunk_body.find("]", marker)
            if end != -1:
                # len("[PAGE ") == 6, so the number starts at marker + 6.
                page_num = chunk_body[marker + 6:end]

        if len(chunk_body) > 50:
            chunks.append({
                "text": chunk_body,
                "metadata": {"source": source, "page": page_num}
            })
    return chunks
118
 
119
+ # --- CORE SEARCH ENGINE ---
120
class PersistentSearchEngine:
    """Hybrid (dense + lexical) retriever with cross-encoder re-ranking.

    Dense vectors are persisted in ChromaDB; a BM25-style shadow copy of
    the documents is pickled alongside so lexical scoring survives restarts.
    """

    def __init__(self, collection_name="navy_docs"):
        # 1. Persistent vector store.
        self.client = chromadb.PersistentClient(path=os.path.join(LOCAL_DB_PATH, "chroma"))
        self.collection = self.client.get_or_create_collection(name=collection_name)

        # 2. Models.
        self.bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # 3. Sparse index (BM25) — kept separately because Chroma only
        # stores dense vectors.
        self.bm25 = None
        self.doc_store = []  # shadow copy of all indexed texts for BM25
        self.load_bm25()

    def load_bm25(self):
        """Load the pickled BM25 index from disk if it exists.

        NOTE(review): pickle.load is only safe because this file comes
        from our own HF dataset repo; never load untrusted pickles.
        """
        bm25_path = os.path.join(LOCAL_DB_PATH, "bm25.pkl")
        if os.path.exists(bm25_path):
            with open(bm25_path, "rb") as f:
                data = pickle.load(f)
            self.bm25 = data['model']
            self.doc_store = data['docs']

    def save_bm25(self):
        """Persist the BM25 model and document shadow copy to disk."""
        # Ensure the data dir exists even if Chroma hasn't created it yet.
        os.makedirs(LOCAL_DB_PATH, exist_ok=True)
        bm25_path = os.path.join(LOCAL_DB_PATH, "bm25.pkl")
        with open(bm25_path, "wb") as f:
            pickle.dump({'model': self.bm25, 'docs': self.doc_store}, f)

    def add_documents(self, parsed_chunks):
        """Index chunk dicts ({"text", "metadata"}) and return the count added."""
        # Robustness: nothing to do for an empty batch.
        if not parsed_chunks:
            return 0

        # 1. Add to Chroma (dense). IDs combine source, position and a
        # timestamp so re-ingesting the same file doesn't collide.
        ids = [f"{c['metadata']['source']}_{i}_{time.time()}" for i, c in enumerate(parsed_chunks)]
        texts = [c['text'] for c in parsed_chunks]
        metadatas = [c['metadata'] for c in parsed_chunks]

        embeddings = self.bi_encoder.encode(texts).tolist()

        self.collection.add(
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids
        )

        # 2. Rebuild BM25 (not incremental by default; rebuilding is fast
        # for corpora under ~10k chunks).
        current_docs = self.doc_store + texts
        tokenized_corpus = [doc.lower().split() for doc in current_docs]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.doc_store = current_docs

        # 3. Persist the sparse side-index.
        self.save_bm25()

        return len(texts)

    def search(self, query, top_k=5, alpha=0.5):
        """Hybrid search + cross-encoder re-rank.

        alpha: 1.0 = pure dense score, 0.0 = pure lexical score.
        Returns up to top_k dicts with ``chunk``, ``metadata``, ``score``.
        """
        # --- DENSE SEARCH (Chroma): over-fetch for the re-ranker. ---
        candidate_k = top_k * 3

        query_embedding = self.bi_encoder.encode([query]).tolist()

        chroma_results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=candidate_k
        )

        # Bug fix: for an empty collection Chroma returns {'documents': [[]]},
        # and a non-empty outer list is truthy — the original check never
        # fired. Check the inner result list instead.
        doc_lists = chroma_results.get('documents') or []
        if not doc_lists or not doc_lists[0]:
            return []

        # Result structure: {'ids': [[...]], 'documents': [[...]],
        # 'metadatas': [[...]], 'distances': [[...]]}.
        dense_hits = {}
        retrieved_docs_map = {}  # id -> {'text', 'metadata'}

        for i, doc_id in enumerate(chroma_results['ids'][0]):
            # Convert distance to a similarity-like score.
            # NOTE(review): assumes a cosine-style metric where smaller
            # distance == more similar; with the default L2 metric this is
            # only a monotonic proxy, not a bounded similarity — confirm
            # the collection's metric setting.
            score = 1 - chroma_results['distances'][0][i]
            dense_hits[doc_id] = score
            retrieved_docs_map[doc_id] = {
                'text': chroma_results['documents'][0][i],
                'metadata': chroma_results['metadatas'][0][i]
            }

        # --- SPARSE RE-SCORING ---
        # Mapping BM25 corpus indices back to Chroma IDs is fragile if the
        # two stores drift, so we re-score only the dense candidates with a
        # simple term-frequency heuristic instead of retrieving with BM25.
        hybrid_candidates = []
        q_tokens = query.lower().split()

        for doc_id, dense_score in dense_hits.items():
            doc_text = retrieved_docs_map[doc_id]['text']
            doc_tokens = doc_text.lower().split()

            # Raw term-frequency score, squeezed roughly into [0, 1].
            tf_score = sum(doc_tokens.count(token) for token in q_tokens)
            lexical_score = min(tf_score / 5.0, 1.0)

            final_hybrid_score = (alpha * dense_score) + ((1 - alpha) * lexical_score)

            hybrid_candidates.append({
                "id": doc_id,
                "text": doc_text,
                "metadata": retrieved_docs_map[doc_id]['metadata'],
                "hybrid_score": final_hybrid_score
            })

        # Sort by hybrid score and keep the candidate pool.
        hybrid_candidates.sort(key=lambda x: x['hybrid_score'], reverse=True)
        top_candidates = hybrid_candidates[:candidate_k]

        # --- RE-RANKING (cross-encoder) ---
        pairs = [[query, c['text']] for c in top_candidates]
        # Robustness: predict() on an empty list can fail; bail out early.
        if not pairs:
            return []

        cross_scores = self.cross_encoder.predict(pairs)

        final_results = []
        for i, cand in enumerate(top_candidates):
            final_results.append({
                "chunk": cand['text'],
                "metadata": cand['metadata'],
                "score": cross_scores[i]
            })

        final_results.sort(key=lambda x: x['score'], reverse=True)
        return final_results[:top_k]
257
 
258
# --- UI LOGIC ---

# 1. Sync from the Hub exactly once per session.
if 'synced' not in st.session_state:
    DataManager.sync_from_hub()
    st.session_state.synced = True

# 2. Initialize the engine once (models + Chroma client are expensive).
if 'engine' not in st.session_state:
    with st.spinner("Initializing Vector Database..."):
        st.session_state.engine = PersistentSearchEngine()

with st.sidebar:
    st.header("🗄️ Knowledge Base")
    uploaded_files = st.file_uploader("Ingest Documents", accept_multiple_files=True)

    if uploaded_files and st.button("Add to Database"):
        with st.spinner("Parsing & Indexing..."):
            new_chunks = []
            for f in uploaded_files:
                txt, fname = parse_file(f)
                chunks = recursive_chunking(txt, fname)
                new_chunks.extend(chunks)

            if new_chunks:
                count = st.session_state.engine.add_documents(new_chunks)
                DataManager.sync_to_hub()  # auto-save to cloud
                st.success(f"Added {count} chunks and synced to Cloud!")
            else:
                # Robustness fix: the original gave no feedback when
                # nothing could be extracted from the uploads.
                st.warning("No text could be extracted from the uploaded files.")

    st.divider()
    st.info(f"Connected to: {DATASET_REPO_ID}")

# --- MAIN SEARCH UI ---
st.title("⚓ Navy Intelligent Search (RAG)")

query = st.text_input("Enter Query (e.g. 'Leave policy for O-3 and below'):")
col1, col2 = st.columns([1, 1])
with col1:
    top_k = st.number_input("Documents", 1, 10, 3)
with col2:
    alpha = st.slider("Hybrid Weight", 0.0, 1.0, 0.6, help="Higher = More Semantic")

if query:
    results = st.session_state.engine.search(query, top_k=top_k, alpha=alpha)

    # Robustness fix: the original rendered an empty results section and
    # still offered a summary over empty context when nothing matched.
    if not results:
        st.info("No matching documents found. Try ingesting documents first.")
    else:
        # Accumulate retrieved context for the RAG prompt below.
        context_text = ""

        st.markdown("### 🔍 Search Results")
        for res in results:
            meta = res['metadata']
            score = res['score']
            text = res['chunk']
            context_text += f"Source: {meta['source']} (Page {meta['page']})\nContent: {text}\n\n"

            with st.expander(f"{meta['source']} | Pg {meta['page']} (Score: {score:.2f})", expanded=True):
                st.markdown(text)

        # --- RAG: SUMMARIZATION ---
        st.divider()
        st.markdown("### 🤖 AI Intelligence")
        if st.button("Generate Summary / Answer"):
            from huggingface_hub import InferenceClient

            # Hosted instruct model via the HF Inference API.
            repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
            llm_client = InferenceClient(model=repo_id, token=HF_TOKEN)

            prompt = f"""
You are a Navy Administrative Aide. Answer the user's question based ONLY on the context provided below.
If the answer is not in the context, say "I cannot find the answer in the provided documents."

CONTEXT:
{context_text}

USER QUESTION:
{query}

ANSWER:
"""

            with st.spinner("Consulting LLM..."):
                try:
                    response = llm_client.text_generation(prompt, max_new_tokens=500)
                    st.success(response)
                except Exception as e:
                    st.error(f"LLM Error: {e}")