Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,211 +1,344 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
-
import
|
|
|
|
| 6 |
from rank_bm25 import BM25Okapi
|
|
|
|
|
|
|
| 7 |
import pypdf
|
| 8 |
import docx
|
| 9 |
-
import
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# --- CONFIGURATION ---
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# --- HELPER FUNCTIONS ---
|
| 15 |
def parse_file(uploaded_file):
|
| 16 |
text = ""
|
|
|
|
| 17 |
try:
|
| 18 |
-
if
|
| 19 |
reader = pypdf.PdfReader(uploaded_file)
|
| 20 |
-
for page in reader.pages:
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
doc = docx.Document(uploaded_file)
|
| 24 |
text = "\n".join([para.text for para in doc.paragraphs])
|
| 25 |
-
elif
|
| 26 |
text = uploaded_file.read().decode("utf-8")
|
| 27 |
-
elif uploaded_file.name.endswith(".csv"):
|
| 28 |
-
df = pd.read_csv(uploaded_file)
|
| 29 |
-
text = df.to_string()
|
| 30 |
except Exception as e:
|
| 31 |
-
st.error(f"Error
|
| 32 |
-
return text
|
| 33 |
|
| 34 |
-
def
|
|
|
|
|
|
|
|
|
|
| 35 |
words = text.split()
|
| 36 |
chunks = []
|
|
|
|
| 37 |
for i in range(0, len(words), chunk_size - overlap):
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
return chunks
|
| 42 |
|
| 43 |
-
# --- CORE
|
| 44 |
-
class
|
| 45 |
-
def __init__(self,
|
| 46 |
-
# 1.
|
| 47 |
-
self.
|
|
|
|
| 48 |
|
| 49 |
-
# 2.
|
| 50 |
-
|
| 51 |
self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
| 52 |
|
| 53 |
-
|
| 54 |
-
self.faiss_index = None
|
| 55 |
self.bm25 = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
self.documents = documents
|
| 59 |
-
|
| 60 |
-
# Build Dense Index
|
| 61 |
-
embeddings = self.bi_encoder.encode(documents, convert_to_tensor=True)
|
| 62 |
-
# Convert to numpy for FAISS
|
| 63 |
-
embeddings_np = embeddings.cpu().numpy()
|
| 64 |
-
faiss.normalize_L2(embeddings_np)
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
self.bm25 = BM25Okapi(tokenized_corpus)
|
|
|
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def search(self, query, top_k=5, alpha=0.5):
|
| 75 |
-
#
|
| 76 |
-
#
|
| 77 |
-
candidate_k = top_k * 3
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# Normalize BM25
|
| 89 |
-
if len(bm25_scores) > 0 and max(bm25_scores) > 0:
|
| 90 |
-
bm25_scores = (bm25_scores - min(bm25_scores)) / (max(bm25_scores) - min(bm25_scores))
|
| 91 |
-
|
| 92 |
-
# Combine Scores to get candidates
|
| 93 |
-
candidates = {} # {doc_idx: hybrid_score}
|
| 94 |
-
|
| 95 |
-
# Map vector results
|
| 96 |
-
for i, idx in enumerate(v_indices[0]):
|
| 97 |
-
if idx != -1:
|
| 98 |
-
v_score = v_scores[0][i]
|
| 99 |
-
candidates[idx] = alpha * v_score
|
| 100 |
-
|
| 101 |
-
# Add BM25 results (for all docs, efficient enough for small corpora)
|
| 102 |
-
# In production, you'd only check top BM25 results
|
| 103 |
-
top_bm25_indices = np.argsort(bm25_scores)[-candidate_k:]
|
| 104 |
-
for idx in top_bm25_indices:
|
| 105 |
-
score = (1 - alpha) * bm25_scores[idx]
|
| 106 |
-
if idx in candidates:
|
| 107 |
-
candidates[idx] += score
|
| 108 |
-
else:
|
| 109 |
-
candidates[idx] = score
|
| 110 |
-
|
| 111 |
-
# Sort candidates by Hybrid Score
|
| 112 |
-
sorted_candidates = sorted(candidates.items(), key=lambda x: x[1], reverse=True)[:candidate_k]
|
| 113 |
-
|
| 114 |
-
# STAGE 2: RE-RANKING (Cross-Encoder)
|
| 115 |
-
# Prepare pairs for the Cross-Encoder: [[query, doc1], [query, doc2]...]
|
| 116 |
-
candidate_indices = [idx for idx, score in sorted_candidates]
|
| 117 |
-
candidate_docs = [self.documents[idx] for idx in candidate_indices]
|
| 118 |
-
|
| 119 |
-
pairs = [[query, doc] for doc in candidate_docs]
|
| 120 |
-
|
| 121 |
-
if not pairs:
|
| 122 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
cross_scores = self.cross_encoder.predict(pairs)
|
| 126 |
|
| 127 |
-
# Combine everything into final results
|
| 128 |
final_results = []
|
| 129 |
-
for i,
|
| 130 |
final_results.append({
|
| 131 |
-
"chunk":
|
| 132 |
-
"
|
| 133 |
-
"
|
| 134 |
})
|
| 135 |
|
| 136 |
-
|
| 137 |
-
final_results = sorted(final_results, key=lambda x: x["score"], reverse=True)
|
| 138 |
-
|
| 139 |
return final_results[:top_k]
|
| 140 |
|
| 141 |
-
# --- UI
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
with st.sidebar:
|
| 150 |
-
st.header("
|
| 151 |
-
uploaded_files = st.file_uploader(
|
| 152 |
-
"Upload Documents",
|
| 153 |
-
type=['txt', 'pdf', 'docx', 'csv'],
|
| 154 |
-
accept_multiple_files=True
|
| 155 |
-
)
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
"Base Embedding Model",
|
| 162 |
-
["all-MiniLM-L6-v2", "all-mpnet-base-v2"],
|
| 163 |
-
help="Used for the initial fast retrieval."
|
| 164 |
-
)
|
| 165 |
-
|
| 166 |
-
alpha = st.slider("Hybrid Alpha", 0.0, 1.0, 0.4,
|
| 167 |
-
help="0.0 = Keywords, 1.0 = Vectors. 0.4 is often best for Hybrid.")
|
| 168 |
-
|
| 169 |
-
top_k = st.number_input("Final Results", 1, 20, 5)
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
for file in uploaded_files:
|
| 181 |
-
raw = parse_file(file)
|
| 182 |
-
chunks = chunk_text(raw)
|
| 183 |
-
all_chunks.extend(chunks)
|
| 184 |
-
|
| 185 |
-
if all_chunks:
|
| 186 |
-
# Initialize Engine
|
| 187 |
-
st.session_state.engine = SearchEngine(model_choice)
|
| 188 |
-
st.session_state.engine.fit(all_chunks)
|
| 189 |
-
st.success(f"Indexed {len(all_chunks)} chunks!")
|
| 190 |
-
else:
|
| 191 |
-
st.warning("No text extracted.")
|
| 192 |
-
|
| 193 |
-
# --- SEARCH ---
|
| 194 |
-
if st.session_state.engine:
|
| 195 |
-
query = st.text_input("Ask a question:")
|
| 196 |
-
if query:
|
| 197 |
-
with st.spinner("Retrieving & Re-Ranking..."):
|
| 198 |
-
results = st.session_state.engine.search(query, top_k=top_k, alpha=alpha)
|
| 199 |
-
|
| 200 |
-
for i, res in enumerate(results):
|
| 201 |
-
score = res['score']
|
| 202 |
-
# Color code high relevance
|
| 203 |
-
color = "green" if score > 0 else "blue"
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import numpy as np
|
| 4 |
+
import chromadb
|
| 5 |
+
from chromadb.config import Settings
|
| 6 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 7 |
from rank_bm25 import BM25Okapi
|
| 8 |
+
from huggingface_hub import HfApi, snapshot_download
|
| 9 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 10 |
import pypdf
|
| 11 |
import docx
|
| 12 |
+
import os
|
| 13 |
+
import shutil
|
| 14 |
+
import pickle
|
| 15 |
+
import time
|
| 16 |
|
| 17 |
# --- CONFIGURATION ---
# REPLACE THIS WITH YOUR NEW DATASET NAME!
# Hugging Face dataset repo used as the persistent backing store for the index.
DATASET_REPO_ID = "NavyDevilDoc/navy-policy-index"
# Local working directory holding the ChromaDB files and the pickled BM25 index.
LOCAL_DB_PATH = "./data_store"
# Read from Space Secrets / environment; None disables all cloud persistence.
HF_TOKEN = os.environ.get("HF_TOKEN")

st.set_page_config(page_title="Navy Search & Intel", layout="wide")
|
| 24 |
+
|
| 25 |
+
# --- PERSISTENCE MANAGER ---
class DataManager:
    """Handles syncing the ChromaDB and BM25 index with the Hugging Face Hub."""

    @staticmethod
    def sync_from_hub():
        """Download the latest DB snapshot from the HF dataset into LOCAL_DB_PATH.

        Returns:
            bool: True on a successful download, False when no token is
            configured or the download fails (e.g. the dataset does not
            exist yet — fine for a fresh start).
        """
        if not HF_TOKEN:
            st.warning("HF_TOKEN not found in Secrets. Persistence will not work.")
            return False

        try:
            st.toast("Syncing database from Cloud...", icon="☁️")
            snapshot_download(
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                local_dir=LOCAL_DB_PATH,
                token=HF_TOKEN,
            )
            return True
        # FIX: was `except (RepositoryNotFoundError, Exception)` — Exception
        # already subsumes RepositoryNotFoundError, so the tuple was redundant.
        except Exception as e:
            # If dataset is empty or doesn't exist yet, that's fine for a fresh start
            print(f"Cloud sync note: {e}")
            return False

    @staticmethod
    def sync_to_hub():
        """Upload the local DB folder to the HF dataset (no-op without a token)."""
        if not HF_TOKEN:
            return

        api = HfApi(token=HF_TOKEN)
        try:
            st.toast("Uploading new index to Cloud...", icon="🚀")
            api.upload_folder(
                folder_path=LOCAL_DB_PATH,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message="Auto-save: Update Index",
            )
            st.success("Database saved to Cloud!")
        except Exception as e:
            st.error(f"Failed to sync to cloud: {e}")
|
| 68 |
|
| 69 |
# --- HELPER FUNCTIONS ---
def parse_file(uploaded_file):
    """Extract text from an uploaded file (.pdf, .docx, .txt, .csv).

    Args:
        uploaded_file: Streamlit UploadedFile-like object (has .name and .read()).

    Returns:
        tuple[str, str]: (extracted_text, filename). PDF text gets "[PAGE n]"
        markers injected so downstream chunking can recover page numbers.
        Unsupported extensions (and parse failures) yield an empty string.
    """
    text = ""
    filename = uploaded_file.name
    try:
        if filename.endswith(".pdf"):
            reader = pypdf.PdfReader(uploaded_file)
            for i, page in enumerate(reader.pages):
                page_text = page.extract_text()
                if page_text:
                    # We inject Page markers into the text for the LLM to see later
                    text += f"\n[PAGE {i+1}] {page_text}"
        elif filename.endswith(".docx"):
            doc = docx.Document(uploaded_file)
            text = "\n".join([para.text for para in doc.paragraphs])
        elif filename.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
        elif filename.endswith(".csv"):
            # Restored: CSV support existed previously and pandas is still
            # imported for it; the uploader accepts any file type.
            df = pd.read_csv(uploaded_file)
            text = df.to_string()
    except Exception as e:
        # FIX: report which file failed instead of the literal "(unknown)".
        st.error(f"Error parsing {filename}: {e}")
    return text, filename
|
| 89 |
|
| 90 |
+
def recursive_chunking(text, source, chunk_size=500, overlap=100):
    """
    Split *text* into overlapping word-window chunks.

    Args:
        text: Raw document text (may contain "[PAGE n]" markers from parse_file).
        source: Filename recorded in each chunk's metadata.
        chunk_size: Window size in words.
        overlap: Words shared between consecutive windows.

    Returns:
        list[dict]: {"text": ..., "metadata": {"source": ..., "page": ...}} per
        chunk; chunks shorter than 50 characters are dropped as noise.
    """
    words = text.split()
    chunks = []

    # FIX: chunk_size <= overlap would make the range() step zero (ValueError)
    # or negative; clamp to at least 1 so the function degrades gracefully.
    step = max(1, chunk_size - overlap)
    for i in range(0, len(words), step):
        chunk_words = words[i:i + chunk_size]
        chunk_text = " ".join(chunk_words)

        # Metadata extraction (simple heuristic for page numbers we injected)
        page_num = "Unknown"
        if "[PAGE" in chunk_text:
            try:
                # Find the last page marker in this chunk; +6 skips "[PAGE ".
                start = chunk_text.rfind("[PAGE") + 6
                end = chunk_text.find("]", start)
                page_num = chunk_text[start:end]
            # FIX: was a bare `except:` — keep the best-effort behaviour but
            # don't swallow SystemExit / KeyboardInterrupt.
            except Exception:
                pass

        if len(chunk_text) > 50:
            chunks.append({
                "text": chunk_text,
                "metadata": {"source": source, "page": page_num}
            })
    return chunks
|
| 118 |
|
| 119 |
+
# --- CORE SEARCH ENGINE ---
class PersistentSearchEngine:
    """Hybrid (dense + keyword) retriever with cross-encoder re-ranking.

    Dense vectors live in a persistent ChromaDB collection under
    LOCAL_DB_PATH; the BM25 model plus its shadow document list are pickled
    alongside it so both survive restarts and Hub syncs.
    """

    def __init__(self, collection_name="navy_docs"):
        # 1. Initialize ChromaDB (Persistent) — PersistentClient creates the
        # directory tree if it doesn't exist yet.
        self.client = chromadb.PersistentClient(path=os.path.join(LOCAL_DB_PATH, "chroma"))
        self.collection = self.client.get_or_create_collection(name=collection_name)

        # 2. Load Models
        self.bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # 3. Initialize/Load BM25 (Sparse)
        self.bm25 = None
        self.doc_store = []  # shadow copy of all chunk texts, needed to rebuild BM25
        self.load_bm25()

    def load_bm25(self):
        """Load the pickled BM25 index from disk, if one exists."""
        bm25_path = os.path.join(LOCAL_DB_PATH, "bm25.pkl")
        if os.path.exists(bm25_path):
            # NOTE(security): pickle.load executes arbitrary code. This file
            # comes from our own HF dataset, but never point DATASET_REPO_ID
            # at a repo you don't control.
            with open(bm25_path, "rb") as f:
                data = pickle.load(f)
            self.bm25 = data['model']
            self.doc_store = data['docs']

    def save_bm25(self):
        """Persist the BM25 model and document list to disk."""
        # FIX: ensure the data dir exists — a fresh start with no cloud sync
        # and no Chroma write yet would otherwise raise FileNotFoundError.
        os.makedirs(LOCAL_DB_PATH, exist_ok=True)
        bm25_path = os.path.join(LOCAL_DB_PATH, "bm25.pkl")
        with open(bm25_path, "wb") as f:
            pickle.dump({'model': self.bm25, 'docs': self.doc_store}, f)

    def add_documents(self, parsed_chunks):
        """Index chunk dicts ({"text", "metadata"}) into Chroma and BM25.

        Returns:
            int: number of chunks added.
        """
        # 1. Add to Chroma (Dense). time.time() in the id keeps re-uploads of
        # the same file from colliding on ids.
        ids = [f"{c['metadata']['source']}_{i}_{time.time()}" for i, c in enumerate(parsed_chunks)]
        texts = [c['text'] for c in parsed_chunks]
        metadatas = [c['metadata'] for c in parsed_chunks]

        embeddings = self.bi_encoder.encode(texts).tolist()

        self.collection.add(
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids
        )

        # 2. Update BM25 (Sparse)
        # Note: BM25 is not incremental by default, we rebuild it.
        # For huge datasets, we would implement incremental updates, but for
        # <10k docs, rebuilding is fast.
        current_docs = self.doc_store + texts
        tokenized_corpus = [doc.lower().split() for doc in current_docs]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.doc_store = current_docs

        # 3. Save Aux Data
        self.save_bm25()

        return len(texts)

    def search(self, query, top_k=5, alpha=0.5):
        """Two-stage retrieval: hybrid candidate selection, then re-ranking.

        Args:
            query: User query string.
            top_k: Number of final results to return.
            alpha: Weight of the dense (semantic) score; (1 - alpha) weights
                the keyword score.

        Returns:
            list[dict]: at most top_k items with "chunk", "metadata" and a
            cross-encoder "score", sorted best-first.
        """
        # --- DENSE SEARCH (Chroma) ---
        # Get more candidates for re-ranking
        candidate_k = top_k * 3

        query_embedding = self.bi_encoder.encode([query]).tolist()

        chroma_results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=candidate_k
        )

        # FIX: an empty collection yields {'documents': [[]]} — the outer list
        # is truthy, so the original `if not chroma_results['documents']`
        # never fired and we'd reach the cross-encoder with zero pairs.
        if not chroma_results['documents'] or not chroma_results['documents'][0]:
            return []

        # Process Chroma Results
        # Chroma structure: {'ids': [[]], 'documents': [[]], 'metadatas': [[]], 'distances': [[]]}
        dense_hits = {}
        retrieved_docs_map = {}  # ID -> Text/Meta mapping

        for i, doc_id in enumerate(chroma_results['ids'][0]):
            score = 1 - chroma_results['distances'][0][i]  # Convert distance to similarity
            dense_hits[doc_id] = score
            retrieved_docs_map[doc_id] = {
                'text': chroma_results['documents'][0][i],
                'metadata': chroma_results['metadatas'][0][i]
            }

        # --- SPARSE SEARCH (BM25) ---
        # Note: Mapping BM25 indices back to Chroma IDs is complex if lists
        # aren't perfectly synced. For this hybrid implementation, we rely on
        # Chroma for the *candidates* and use keyword matching to score the
        # query against those candidates specifically.
        hybrid_candidates = []

        q_tokens = query.lower().split()

        for doc_id, dense_score in dense_hits.items():
            doc_text = retrieved_docs_map[doc_id]['text']

            # Score this specific candidate with BM25-style logic (on the fly).
            # This is "re-scoring" rather than "retrieving" with BM25, which
            # is safer for sync.
            doc_tokens = doc_text.lower().split()
            # Simple term frequency for the candidate
            bm25_score = 0
            for token in q_tokens:
                bm25_score += doc_tokens.count(token)

            # Normalize BM25 score roughly (0-10 range usually, squeeze to 0-1)
            bm25_score = min(bm25_score / 5.0, 1.0)

            final_hybrid_score = (alpha * dense_score) + ((1 - alpha) * bm25_score)

            hybrid_candidates.append({
                "id": doc_id,
                "text": doc_text,
                "metadata": retrieved_docs_map[doc_id]['metadata'],
                "hybrid_score": final_hybrid_score
            })

        # Sort by Hybrid Score
        hybrid_candidates.sort(key=lambda x: x['hybrid_score'], reverse=True)

        # --- RE-RANKING (Cross-Encoder) ---
        top_candidates = hybrid_candidates[:candidate_k]

        pairs = [[query, c['text']] for c in top_candidates]
        # Defensive: predict([]) is undefined behaviour for the cross-encoder.
        if not pairs:
            return []
        cross_scores = self.cross_encoder.predict(pairs)

        final_results = []
        for i, cand in enumerate(top_candidates):
            final_results.append({
                "chunk": cand['text'],
                "metadata": cand['metadata'],
                "score": cross_scores[i]
            })

        final_results.sort(key=lambda x: x['score'], reverse=True)
        return final_results[:top_k]
|
| 257 |
|
| 258 |
+
# --- UI LOGIC ---
# Top-level Streamlit script: runs top-to-bottom on every interaction, with
# expensive objects cached in st.session_state so they survive reruns.

# 1. Sync on Startup — pull the persisted index from the Hub exactly once
# per session, before the engine opens the local DB files.
if 'synced' not in st.session_state:
    DataManager.sync_from_hub()
    st.session_state.synced = True

# 2. Init Engine — constructed once; loads ChromaDB, both transformer models
# and the pickled BM25 index.
if 'engine' not in st.session_state:
    with st.spinner("Initializing Vector Database..."):
        st.session_state.engine = PersistentSearchEngine()

with st.sidebar:
    st.header("🗄️ Knowledge Base")
    uploaded_files = st.file_uploader("Ingest Documents", accept_multiple_files=True)

    # Ingestion: parse -> chunk -> index -> push the whole DB back to the Hub.
    if uploaded_files and st.button("Add to Database"):
        with st.spinner("Parsing & Indexing..."):
            new_chunks = []
            for f in uploaded_files:
                txt, fname = parse_file(f)
                chunks = recursive_chunking(txt, fname)
                new_chunks.extend(chunks)

            if new_chunks:
                count = st.session_state.engine.add_documents(new_chunks)
                DataManager.sync_to_hub() # Auto-save to cloud
                st.success(f"Added {count} chunks and synced to Cloud!")

    st.divider()
    st.info(f"Connected to: {DATASET_REPO_ID}")

# --- MAIN SEARCH UI ---
st.title("⚓ Navy Intelligent Search (RAG)")

query = st.text_input("Enter Query (e.g. 'Leave policy for O-3 and below'):")
col1, col2 = st.columns([1, 1])
with col1:
    top_k = st.number_input("Documents", 1, 10, 3)
with col2:
    alpha = st.slider("Hybrid Weight", 0.0, 1.0, 0.6, help="Higher = More Semantic")

if query:
    results = st.session_state.engine.search(query, top_k=top_k, alpha=alpha)

    # Store results for RAG — accumulated as the context the LLM will see.
    context_text = ""

    st.markdown("### 🔍 Search Results")
    for res in results:
        meta = res['metadata']
        score = res['score']
        text = res['chunk']
        context_text += f"Source: {meta['source']} (Page {meta['page']})\nContent: {text}\n\n"

        with st.expander(f"{meta['source']} | Pg {meta['page']} (Score: {score:.2f})", expanded=True):
            st.markdown(text)

    # --- RAG: SUMMARIZATION ---
    # NOTE(review): the button click triggers a rerun; `query` persists in the
    # text_input, so this branch re-executes with fresh search results.
    st.divider()
    st.markdown("### 🤖 AI Intelligence")
    if st.button("Generate Summary / Answer"):
        from huggingface_hub import InferenceClient

        # Use a free, powerful model via HF Inference API
        repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
        llm_client = InferenceClient(model=repo_id, token=HF_TOKEN)

        prompt = f"""
        You are a Navy Administrative Aide. Answer the user's question based ONLY on the context provided below.
        If the answer is not in the context, say "I cannot find the answer in the provided documents."

        CONTEXT:
        {context_text}

        USER QUESTION:
        {query}

        ANSWER:
        """

        with st.spinner("Consulting LLM..."):
            try:
                response = llm_client.text_generation(prompt, max_new_tokens=500)
                st.success(response)
            except Exception as e:
                st.error(f"LLM Error: {e}")
|