NavyDevilDoc committed on
Commit
b045f9c
·
verified ·
1 Parent(s): a239196

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -107
app.py CHANGED
@@ -1,19 +1,18 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
- from sentence_transformers import SentenceTransformer
5
  import faiss
6
  from rank_bm25 import BM25Okapi
7
  import pypdf
8
  import docx
9
- from io import BytesIO
10
 
11
  # --- CONFIGURATION ---
12
- st.set_page_config(page_title="Hybrid Semantic Search", layout="wide")
13
 
14
- # --- HELPER FUNCTIONS: FILE PARSING ---
15
  def parse_file(uploaded_file):
16
- """Extracts text from various file formats."""
17
  text = ""
18
  try:
19
  if uploaded_file.name.endswith(".pdf"):
@@ -27,26 +26,30 @@ def parse_file(uploaded_file):
27
  text = uploaded_file.read().decode("utf-8")
28
  elif uploaded_file.name.endswith(".csv"):
29
  df = pd.read_csv(uploaded_file)
30
- # Assuming a generic CSV, we just flatten it to text for now
31
  text = df.to_string()
32
  except Exception as e:
33
  st.error(f"Error reading file: {e}")
34
  return text
35
 
36
  def chunk_text(text, chunk_size=300, overlap=50):
37
- """Splits text into overlapping chunks for better context."""
38
  words = text.split()
39
  chunks = []
40
  for i in range(0, len(words), chunk_size - overlap):
41
  chunk = " ".join(words[i:i + chunk_size])
42
- if len(chunk) > 50: # Filter out tiny chunks
43
  chunks.append(chunk)
44
  return chunks
45
 
46
- # --- CORE LOGIC: HYBRID SEARCH ENGINE ---
47
- class HybridSearchEngine:
48
- def __init__(self, model_name):
49
- self.model = SentenceTransformer(model_name)
 
 
 
 
 
 
50
  self.documents = []
51
  self.faiss_index = None
52
  self.bm25 = None
@@ -54,136 +57,155 @@ class HybridSearchEngine:
54
  def fit(self, documents):
55
  self.documents = documents
56
 
57
- # 1. Build Dense Index (FAISS)
58
- embeddings = self.model.encode(documents)
59
- # Normalize for Cosine Similarity (Inner Product)
60
- faiss.normalize_L2(embeddings)
61
- dimension = embeddings.shape[1]
62
- self.faiss_index = faiss.IndexFlatIP(dimension) # Inner Product = Cosine Sim
63
- self.faiss_index.add(embeddings)
 
 
64
 
65
- # 2. Build Sparse Index (BM25)
66
  tokenized_corpus = [doc.lower().split() for doc in documents]
67
  self.bm25 = BM25Okapi(tokenized_corpus)
68
 
69
  def search(self, query, top_k=5, alpha=0.5):
70
- """
71
- Alpha: Weighting factor.
72
- 1.0 = Pure Vector Search
73
- 0.0 = Pure Keyword Search
74
- 0.5 = Equal Hybrid
75
- """
76
- # --- Vector Search ---
77
- query_vector = self.model.encode([query])
78
  faiss.normalize_L2(query_vector)
79
- # Search more than we need to allow for re-ranking
80
- v_scores, v_indices = self.faiss_index.search(query_vector, len(self.documents))
81
 
82
- # Create a map of {doc_index: vector_score}
83
- # Normalize vector scores to 0-1 range (approx)
84
- v_results = {}
85
- for i, idx in enumerate(v_indices[0]):
86
- if idx != -1:
87
- v_results[idx] = v_scores[0][i]
88
-
89
- # --- Keyword Search (BM25) ---
90
  tokenized_query = query.lower().split()
91
  bm25_scores = self.bm25.get_scores(tokenized_query)
92
 
93
- # Normalize BM25 scores (Min-Max Scaling) to match Vector scale
94
- if max(bm25_scores) > 0:
95
  bm25_scores = (bm25_scores - min(bm25_scores)) / (max(bm25_scores) - min(bm25_scores))
 
 
 
96
 
97
- # --- Hybrid Combination ---
98
- final_results = []
99
- for idx, doc in enumerate(self.documents):
100
- v_score = v_results.get(idx, 0.0)
101
- k_score = bm25_scores[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Weighted Score
104
- final_score = (alpha * v_score) + ((1 - alpha) * k_score)
 
 
 
 
105
  final_results.append({
106
- "chunk": doc,
107
- "score": final_score,
108
- "vector_score": v_score,
109
- "keyword_score": k_score
110
  })
111
 
112
- # Sort by final score
113
  final_results = sorted(final_results, key=lambda x: x["score"], reverse=True)
 
114
  return final_results[:top_k]
115
 
116
- # --- STREAMLIT UI ---
117
-
118
- st.title("⚡ Hybrid Search: Vector + Keywords")
119
- st.caption("Robust semantic search powered by FAISS (Dense) and BM25 (Sparse).")
 
 
 
120
 
121
  with st.sidebar:
122
- st.header("⚙️ Configuration")
 
 
 
 
 
 
 
123
 
124
- # 3. Select Embedding Model
125
  model_choice = st.selectbox(
126
- "Embedding Model",
127
- options=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-mpnet-base-dot-v1"],
128
- index=0,
129
- help="MiniLM is fast; MPNet is more accurate but slower."
130
  )
131
 
132
- # 2. Results Count
133
- top_k = st.number_input("Results to Retrieve", min_value=1, max_value=50, value=5, step=1)
134
 
135
- # Hybrid Weight Slider
136
- alpha = st.slider("Hybrid Balance (Alpha)", 0.0, 1.0, 0.5,
137
- help="0.0 = Keywords Only, 1.0 = Vectors Only")
138
-
139
- st.divider()
140
 
141
- # 1. File Upload
142
- uploaded_files = st.file_uploader(
143
- "Upload Knowledge Base",
144
- type=['txt', 'pdf', 'docx', 'csv'],
145
- accept_multiple_files=True
146
- )
147
-
148
- process_btn = st.button("Build Database")
149
 
150
- # --- APP STATE MANAGEMENT ---
151
- if 'search_engine' not in st.session_state:
152
- st.session_state.search_engine = None
153
 
154
- if process_btn and uploaded_files:
155
- with st.spinner(f"Parsing files and initializing {model_choice}..."):
156
  all_chunks = []
157
  for file in uploaded_files:
158
- raw_text = parse_file(file)
159
- file_chunks = chunk_text(raw_text)
160
- all_chunks.extend(file_chunks)
161
 
162
  if all_chunks:
163
- engine = HybridSearchEngine(model_choice)
164
- engine.fit(all_chunks)
165
- st.session_state.search_engine = engine
166
- st.success(f"Indexed {len(all_chunks)} chunks from {len(uploaded_files)} files!")
167
  else:
168
- st.warning("No text found in uploaded files.")
169
 
170
- # --- SEARCH INTERFACE ---
171
- if st.session_state.search_engine:
172
- query = st.text_input("Enter your query:", placeholder="e.g., 'What are the safety protocols for the engine room?'")
173
-
174
  if query:
175
- results = st.session_state.search_engine.search(query, top_k=top_k, alpha=alpha)
176
-
177
- st.subheader(f"Top {top_k} Matches")
178
-
179
  for i, res in enumerate(results):
180
- with st.expander(f"Rank {i+1} (Score: {res['score']:.4f})", expanded=(i==0)):
181
- st.markdown(f"**{res['chunk']}**")
182
-
183
- # Metadata columns
184
- c1, c2, c3 = st.columns(3)
185
- c1.metric("Hybrid Score", f"{res['score']:.4f}")
186
- c2.metric("Vector Match", f"{res['vector_score']:.4f}")
187
- c3.metric("Keyword Match", f"{res['keyword_score']:.4f}")
 
188
  else:
189
- st.info("👈 Please upload documents in the sidebar to begin.")
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
5
  import faiss
6
  from rank_bm25 import BM25Okapi
7
  import pypdf
8
  import docx
9
+ import torch
10
 
11
  # --- CONFIGURATION ---
12
+ st.set_page_config(page_title="Advanced Semantic Search", layout="wide")
13
 
14
+ # --- HELPER FUNCTIONS ---
15
  def parse_file(uploaded_file):
 
16
  text = ""
17
  try:
18
  if uploaded_file.name.endswith(".pdf"):
 
26
  text = uploaded_file.read().decode("utf-8")
27
  elif uploaded_file.name.endswith(".csv"):
28
  df = pd.read_csv(uploaded_file)
 
29
  text = df.to_string()
30
  except Exception as e:
31
  st.error(f"Error reading file: {e}")
32
  return text
33
 
34
def chunk_text(text, chunk_size=300, overlap=50):
    """Split `text` into overlapping word-based chunks for indexing.

    Args:
        text: Raw text to split (may be empty).
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks, to
            preserve context across chunk boundaries.

    Returns:
        A list of chunk strings. Chunks of 50 characters or fewer are
        dropped as too small to be useful search units.

    Raises:
        ValueError: If ``overlap >= chunk_size``. Previously this either
            raised a bare ``range() arg 3 must not be zero`` (overlap ==
            chunk_size) or silently returned no chunks (overlap >
            chunk_size, negative step); fail loudly instead.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    chunks = []
    step = chunk_size - overlap  # window advance; guaranteed positive above
    for start in range(0, len(words), step):
        chunk = " ".join(words[start:start + chunk_size])
        if len(chunk) > 50:  # filter out tiny (trailing) fragments
            chunks.append(chunk)
    return chunks
42
 
43
# --- CORE LOGIC: RETRIEVER + RE-RANKER ---
class SearchEngine:
    """Two-stage search engine.

    Stage 1 retrieves candidates with a hybrid of dense (FAISS bi-encoder)
    and sparse (BM25) scoring; stage 2 re-ranks the candidates with a
    cross-encoder that reads query and document jointly.
    """

    def __init__(self, bi_encoder_name,
                 cross_encoder_name='cross-encoder/ms-marco-MiniLM-L-6-v2'):
        """Load the retrieval and re-ranking models.

        Args:
            bi_encoder_name: SentenceTransformer model for fast retrieval.
            cross_encoder_name: CrossEncoder model for accurate re-ranking.
                Defaults to a standard MS MARCO model trained for exactly
                this query/passage relevance task; overridable for other
                domains (backward-compatible generalization of the
                previously hard-coded name).
        """
        # 1. Bi-Encoder (fast retrieval)
        self.bi_encoder = SentenceTransformer(bi_encoder_name)

        # 2. Cross-Encoder (accurate re-ranking)
        self.cross_encoder = CrossEncoder(cross_encoder_name)

        self.documents = []      # chunk texts, addressed by integer index
        self.faiss_index = None  # dense index, built by fit()
        self.bm25 = None         # sparse index, built by fit()

    def fit(self, documents):
        """Index `documents` (a list of text chunks) for hybrid search."""
        self.documents = documents

        # Build Dense Index. L2-normalizing the embeddings makes inner
        # product equal to cosine similarity in the IndexFlatIP below.
        embeddings = self.bi_encoder.encode(documents, convert_to_tensor=True)
        embeddings_np = embeddings.cpu().numpy()  # FAISS needs numpy
        faiss.normalize_L2(embeddings_np)

        dimension = embeddings_np.shape[1]
        self.faiss_index = faiss.IndexFlatIP(dimension)
        self.faiss_index.add(embeddings_np)

        # Build Sparse Index. Lowercased whitespace tokenization is crude
        # but matches how queries are tokenized in search().
        tokenized_corpus = [doc.lower().split() for doc in documents]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def search(self, query, top_k=5, alpha=0.5):
        """Return up to `top_k` re-ranked results for `query`.

        Args:
            query: Free-text search query.
            top_k: Number of final results to return.
            alpha: Hybrid weight. 1.0 = pure vector, 0.0 = pure keyword.

        Returns:
            List of dicts with keys "chunk", "score" (cross-encoder
            relevance), and "original_hybrid_score", best first.
        """
        # STAGE 1: RETRIEVAL — gather a candidate pool 3x the requested
        # size so the re-ranker has options to promote/demote.
        candidate_k = top_k * 3

        # Vector search
        query_vector = self.bi_encoder.encode([query])
        faiss.normalize_L2(query_vector)
        v_scores, v_indices = self.faiss_index.search(
            query_vector, min(len(self.documents), candidate_k))

        # BM25 search
        tokenized_query = query.lower().split()
        bm25_scores = np.asarray(self.bm25.get_scores(tokenized_query),
                                 dtype=float)

        # Min-max normalize BM25 to [0, 1] so it is comparable to cosine
        # scores. BUGFIX: the previous guard (`max > 0`) divided by zero
        # when all scores were equal and positive (max == min), producing
        # NaN/inf scores; guard on the actual spread instead. An all-equal
        # distribution carries no ranking signal, so collapse it to zeros.
        if bm25_scores.size:
            lo, hi = bm25_scores.min(), bm25_scores.max()
            if hi > lo:
                bm25_scores = (bm25_scores - lo) / (hi - lo)
            else:
                bm25_scores = np.zeros_like(bm25_scores)

        # Combine scores to get candidates: {doc_idx: hybrid_score}
        candidates = {}

        # Map vector results (FAISS pads with -1 when fewer hits exist).
        for i, idx in enumerate(v_indices[0]):
            if idx != -1:
                candidates[idx] = alpha * v_scores[0][i]

        # Merge in the strongest BM25 hits. Scanning all docs is fine for
        # small corpora; in production you'd only check top BM25 results.
        top_bm25_indices = np.argsort(bm25_scores)[-candidate_k:]
        for idx in top_bm25_indices:
            score = (1 - alpha) * bm25_scores[idx]
            candidates[idx] = candidates.get(idx, 0.0) + score

        # Keep the best candidate_k by hybrid score.
        sorted_candidates = sorted(candidates.items(),
                                   key=lambda x: x[1],
                                   reverse=True)[:candidate_k]

        # STAGE 2: RE-RANKING — cross-encoder scores [query, doc] pairs.
        candidate_indices = [idx for idx, _score in sorted_candidates]
        pairs = [[query, self.documents[idx]] for idx in candidate_indices]

        if not pairs:
            return []

        # Predict relevance scores (raw logits; higher = more relevant).
        cross_scores = self.cross_encoder.predict(pairs)

        final_results = []
        for i, idx in enumerate(candidate_indices):
            final_results.append({
                "chunk": self.documents[idx],
                "score": cross_scores[i],  # the high-accuracy score
                "original_hybrid_score": sorted_candidates[i][1]
            })

        # Sort by Cross-Encoder score, best first.
        final_results = sorted(final_results,
                               key=lambda x: x["score"],
                               reverse=True)
        return final_results[:top_k]
140
 
141
# --- UI LAYOUT ---
st.title("🧠 Semantic Search: Hybrid + Cross-Encoder")
st.markdown("""
This system uses a **Two-Stage Retrieval Process**:
1. **Retrieval:** Finds top candidates using Vector (semantic) and BM25 (keyword) search.
2. **Re-Ranking:** A Cross-Encoder model reads the query and candidates to score true relevance.
""")

# Sidebar: knowledge-base upload plus retrieval tuning knobs.
with st.sidebar:
    st.header("1. Setup Knowledge Base")
    uploaded_files = st.file_uploader(
        "Upload Documents",
        type=['txt', 'pdf', 'docx', 'csv'],
        accept_multiple_files=True
    )

    st.divider()

    st.header("2. Tuning")
    model_choice = st.selectbox(
        "Base Embedding Model",
        ["all-MiniLM-L6-v2", "all-mpnet-base-v2"],
        help="Used for the initial fast retrieval."
    )

    alpha = st.slider("Hybrid Alpha", 0.0, 1.0, 0.4,
                      help="0.0 = Keywords, 1.0 = Vectors. 0.4 is often best for Hybrid.")

    top_k = st.number_input("Final Results", 1, 20, 5)

    build_btn = st.button("Build Database")

# --- APP STATE ---
# The engine lives in session_state so it survives Streamlit reruns.
if "engine" not in st.session_state:
    st.session_state.engine = None

# Build the index when the user clicks the button with files attached.
if build_btn and uploaded_files:
    with st.spinner("Processing files..."):
        all_chunks = []
        for uploaded in uploaded_files:
            extracted = parse_file(uploaded)
            all_chunks.extend(chunk_text(extracted))

        if all_chunks:
            # Initialize Engine
            st.session_state.engine = SearchEngine(model_choice)
            st.session_state.engine.fit(all_chunks)
            st.success(f"Indexed {len(all_chunks)} chunks!")
        else:
            st.warning("No text extracted.")

# --- SEARCH ---
if st.session_state.engine:
    query = st.text_input("Ask a question:")
    if query:
        with st.spinner("Retrieving & Re-Ranking..."):
            results = st.session_state.engine.search(query, top_k=top_k, alpha=alpha)

        for rank, result in enumerate(results, start=1):
            score = result['score']
            # Color code high relevance (positive cross-encoder logit).
            color = "green" if score > 0 else "blue"

            with st.container():
                st.markdown(f"### Rank {rank}")
                st.caption(f"Relevance Score: :{color}[{score:.3f}]")
                st.info(result['chunk'])
                st.divider()
else:
    st.info("Upload documents to start.")