dejanseo commited on
Commit
2c44b4b
Β·
verified Β·
1 Parent(s): 4f5851a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +181 -57
src/streamlit_app.py CHANGED
@@ -4,6 +4,7 @@ Uses MS MARCO Cross-Encoder for search relevance ranking.
4
  """
5
 
6
  import re
 
7
  import numpy as np
8
  import streamlit as st
9
  import torch
@@ -14,17 +15,24 @@ MODEL_NAME = "cross-encoder/ms-marco-electra-base"
14
  MAX_SNIPPET_CHARS = 450
15
  MAX_SENTENCES = 5
16
 
 
 
 
 
 
 
17
  st.logo(
18
  image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
19
  link="https://dejan.ai/",
20
  size="large"
21
  )
22
 
23
- st.set_page_config(
24
- page_title="Snippet Generator by DEJAN AI",
25
- page_icon="βœ‚οΈ",
26
- layout="centered"
27
- )
 
28
 
29
  @st.cache_resource
30
  def load_model():
@@ -36,7 +44,6 @@ def load_model():
36
 
37
  def segment_sentences(text: str) -> list[str]:
38
  """Sentence segmentation with deduplication and filtering."""
39
- # Split on sentence boundaries AND newlines
40
  pattern = r'(?<=[.!?])\s+|\n+'
41
  raw_sentences = re.split(pattern, text)
42
 
@@ -52,12 +59,10 @@ def segment_sentences(text: str) -> list[str]:
52
  if s.startswith('http') or s.startswith('URL:'):
53
  continue
54
 
55
- # Skip low-alpha content (metadata, tables, prices)
56
  alpha_ratio = sum(c.isalpha() for c in s) / max(len(s), 1)
57
  if alpha_ratio < 0.5:
58
  continue
59
 
60
- # Skip questions
61
  if s.endswith('?'):
62
  continue
63
 
@@ -71,88 +76,207 @@ def segment_sentences(text: str) -> list[str]:
71
  return sentences
72
 
73
 
74
- def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
75
- """Generate snippet using Cross-Encoder scoring."""
76
  sentences = segment_sentences(document)
77
 
78
  if not sentences:
79
- return "", []
80
 
81
- # Cross-encoder: score query-sentence pairs
82
  pairs = [[query, sent] for sent in sentences]
83
  scores = model.predict(pairs)
84
 
85
  ranked_indices = np.argsort(scores)[::-1]
86
 
87
- # Select with budget
88
- selected = []
89
  total_length = 0
90
 
91
  for idx in ranked_indices:
92
  sent = sentences[idx]
93
- if total_length + len(sent) <= max_chars and len(selected) < max_sents:
94
- selected.append((idx, sent, scores[idx]))
95
  total_length += len(sent)
96
 
97
- if not selected:
98
  best_idx = ranked_indices[0]
99
- return sentences[best_idx][:max_chars] + "...", []
 
 
 
 
 
 
100
 
101
- # Sort by document order
102
- selected.sort(key=lambda x: x[0])
103
 
104
- # Stitch with ellipsis for gaps
105
  snippet_parts = []
106
  prev_idx = -1
107
 
108
- for idx, sent, _ in selected:
109
  if prev_idx >= 0 and idx > prev_idx + 1:
110
  snippet_parts.append("...")
111
- snippet_parts.append(sent)
112
  prev_idx = idx
113
 
114
  if prev_idx < len(sentences) - 1:
115
  snippet_parts.append("...")
116
 
117
- # Debug info
118
- debug_info = [(scores[ranked_indices[i]], sentences[ranked_indices[i]])
119
- for i in range(min(5, len(ranked_indices)))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- return " ".join(snippet_parts), debug_info
122
 
123
 
124
- # --- Streamlit UI ---
125
- st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")
 
 
 
 
126
 
127
- st.write("How much of your page will be used to ground the model for a particular fanout query?")
128
- st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")
129
 
130
- query = st.text_input("Query", placeholder="enter a search query...")
 
 
131
 
132
- document = st.text_area(
133
- "Web Page Text",
134
- height=250,
135
- placeholder="Paste the full page content here..."
136
- )
137
 
138
- with st.expander("Settings"):
139
- max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
140
- max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
141
- show_debug = st.checkbox("Show debug info", value=True)
142
 
143
- if st.button("Generate Snippet", help="cross-encoder/ms-marco-electra-base"):
144
- if query and document:
145
- with st.spinner("Loading model & scoring sentences..."):
146
- model = load_model()
147
- snippet, debug = generate_snippet(query, document, model, max_chars, max_sents)
148
-
149
- st.subheader("Generated Snippet")
150
- st.code(snippet, language=None)
151
-
152
- if show_debug and debug:
153
- st.markdown("---")
154
- st.write("**Top sentences by score:**")
155
- for score, sent in debug:
156
- st.text(f"{score:.4f}: {sent[:80]}...")
157
- else:
158
- st.warning("Please enter both a query and document.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import re
7
+ import html
8
  import numpy as np
9
  import streamlit as st
10
  import torch
 
15
  MAX_SNIPPET_CHARS = 450
16
  MAX_SENTENCES = 5
17
 
18
+ st.set_page_config(
19
+ page_title="Snippet Generator by DEJAN AI",
20
+ page_icon="βœ‚οΈ",
21
+ layout="centered"
22
+ )
23
+
24
  st.logo(
25
  image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
26
  link="https://dejan.ai/",
27
  size="large"
28
  )
29
 
30
+ # --- Session State ---
31
+ if "results_mode" not in st.session_state:
32
+ st.session_state.results_mode = False
33
+ if "snippet_data" not in st.session_state:
34
+ st.session_state.snippet_data = None
35
+
36
 
37
  @st.cache_resource
38
  def load_model():
 
44
 
45
  def segment_sentences(text: str) -> list[str]:
46
  """Sentence segmentation with deduplication and filtering."""
 
47
  pattern = r'(?<=[.!?])\s+|\n+'
48
  raw_sentences = re.split(pattern, text)
49
 
 
59
  if s.startswith('http') or s.startswith('URL:'):
60
  continue
61
 
 
62
  alpha_ratio = sum(c.isalpha() for c in s) / max(len(s), 1)
63
  if alpha_ratio < 0.5:
64
  continue
65
 
 
66
  if s.endswith('?'):
67
  continue
68
 
 
76
  return sentences
77
 
78
 
79
+ def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> dict:
80
+ """Generate snippet using Cross-Encoder scoring. Returns full analysis."""
81
  sentences = segment_sentences(document)
82
 
83
  if not sentences:
84
+ return {"snippet": "", "selected_sentences": [], "all_sentences": [], "scores": []}
85
 
 
86
  pairs = [[query, sent] for sent in sentences]
87
  scores = model.predict(pairs)
88
 
89
  ranked_indices = np.argsort(scores)[::-1]
90
 
91
+ selected_indices = []
 
92
  total_length = 0
93
 
94
  for idx in ranked_indices:
95
  sent = sentences[idx]
96
+ if total_length + len(sent) <= max_chars and len(selected_indices) < max_sents:
97
+ selected_indices.append(idx)
98
  total_length += len(sent)
99
 
100
+ if not selected_indices:
101
  best_idx = ranked_indices[0]
102
+ return {
103
+ "snippet": sentences[best_idx][:max_chars] + "...",
104
+ "selected_sentences": [sentences[best_idx][:max_chars]],
105
+ "all_sentences": sentences,
106
+ "scores": scores.tolist(),
107
+ "selected_indices": [best_idx]
108
+ }
109
 
110
+ selected_indices.sort()
111
+ selected_sentences = [sentences[i] for i in selected_indices]
112
 
 
113
  snippet_parts = []
114
  prev_idx = -1
115
 
116
+ for idx in selected_indices:
117
  if prev_idx >= 0 and idx > prev_idx + 1:
118
  snippet_parts.append("...")
119
+ snippet_parts.append(sentences[idx])
120
  prev_idx = idx
121
 
122
  if prev_idx < len(sentences) - 1:
123
  snippet_parts.append("...")
124
 
125
+ return {
126
+ "snippet": " ".join(snippet_parts),
127
+ "selected_sentences": selected_sentences,
128
+ "all_sentences": sentences,
129
+ "scores": scores.tolist(),
130
+ "selected_indices": selected_indices
131
+ }
132
+
133
+
134
+ def render_highlighted_html(document: str, selected_sentences: list[str]) -> str:
135
+ """Render document as HTML with highlighted sentences. Uses html.escape() for safety."""
136
+
137
+ # Find positions of selected sentences
138
+ highlights = []
139
+ for sent in selected_sentences:
140
+ start = document.find(sent)
141
+ if start != -1:
142
+ highlights.append((start, start + len(sent)))
143
+ continue
144
+ sent_pattern = r'\s+'.join(re.escape(word) for word in sent.split())
145
+ match = re.search(sent_pattern, document)
146
+ if match:
147
+ highlights.append((match.start(), match.end()))
148
+
149
+ highlights.sort(key=lambda x: x[0])
150
+
151
+ # Merge overlapping
152
+ merged = []
153
+ for start, end in highlights:
154
+ if merged and start <= merged[-1][1]:
155
+ merged[-1] = (merged[-1][0], max(merged[-1][1], end))
156
+ else:
157
+ merged.append((start, end))
158
+
159
+ # Build HTML with proper escaping
160
+ parts = []
161
+ pos = 0
162
+
163
+ for start, end in merged:
164
+ # Non-selected: gray text
165
+ if pos < start:
166
+ text = html.escape(document[pos:start])
167
+ parts.append(f'<span style="color:#888">{text}</span>')
168
+
169
+ # Selected: green highlight
170
+ text = html.escape(document[start:end])
171
+ parts.append(f'<span style="background:#c6f6d5;color:#166534;padding:1px 3px;border-radius:3px">{text}</span>')
172
+ pos = end
173
+
174
+ # Remaining non-selected
175
+ if pos < len(document):
176
+ text = html.escape(document[pos:])
177
+ parts.append(f'<span style="color:#888">{text}</span>')
178
 
179
+ return "".join(parts)
180
 
181
 
182
+ def generate_regex_pattern(selected_sentences: list[str]) -> str:
183
+ """Generate regex pattern for matching selected snippets."""
184
+ if not selected_sentences:
185
+ return ""
186
+ escaped = [re.escape(sent) for sent in selected_sentences]
187
+ return r'[\s\S]*?'.join(escaped)
188
 
 
 
189
 
190
+ def reset_to_input():
191
+ st.session_state.results_mode = False
192
+ st.session_state.snippet_data = None
193
 
 
 
 
 
 
194
 
195
+ # --- Main UI ---
196
+ st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")
 
 
197
 
198
+ if not st.session_state.results_mode:
199
+ # === INPUT MODE ===
200
+ st.write("How much of your page will be used to ground the model for a particular fanout query?")
201
+ st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")
202
+
203
+ query = st.text_input("Query", placeholder="enter a search query...")
204
+
205
+ document = st.text_area(
206
+ "Web Page Text",
207
+ height=250,
208
+ placeholder="Paste the full page content here..."
209
+ )
210
+
211
+ with st.expander("Settings"):
212
+ max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
213
+ max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
214
+
215
+ if st.button("Generate Snippet", type="primary"):
216
+ if query and document:
217
+ with st.spinner("Loading model & scoring sentences..."):
218
+ model = load_model()
219
+ result = generate_snippet(query, document, model, max_chars, max_sents)
220
+
221
+ st.session_state.snippet_data = {
222
+ "query": query,
223
+ "document": document,
224
+ "result": result,
225
+ "max_chars": max_chars,
226
+ "max_sents": max_sents
227
+ }
228
+ st.session_state.results_mode = True
229
+ st.rerun()
230
+ else:
231
+ st.warning("Please enter both a query and document.")
232
+
233
+ else:
234
+ # === RESULTS MODE ===
235
+ data = st.session_state.snippet_data
236
+ query = data["query"]
237
+ document = data["document"]
238
+ result = data["result"]
239
+
240
+ if st.button("← New Analysis"):
241
+ reset_to_input()
242
+ st.rerun()
243
+
244
+ # Stats
245
+ snippet_chars = sum(len(s) for s in result["selected_sentences"])
246
+ doc_chars = len(document)
247
+ pct = (snippet_chars / doc_chars * 100) if doc_chars > 0 else 0
248
+
249
+ st.caption(f"{snippet_chars:,} / {doc_chars:,} chars ({pct:.1f}%) β€’ {len(result['selected_sentences'])} sentences")
250
+
251
+ # Query - use st.html to prevent any rendering issues
252
+ st.html(f'<p style="font-size:1.3em;font-weight:600;margin:1em 0">{html.escape(query)}</p>')
253
+
254
+ # Generated snippet
255
+ st.subheader("Generated Snippet")
256
+ st.code(result["snippet"], wrap_lines=True, language=None)
257
+
258
+ # Highlighted document - st.html() does NO markdown parsing
259
+ highlighted = render_highlighted_html(document, result["selected_sentences"])
260
+
261
+ st.html(f'''
262
+ <div style="
263
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
264
+ font-size: 14px;
265
+ line-height: 1.7;
266
+ white-space: pre-wrap;
267
+ word-wrap: break-word;
268
+ padding: 16px;
269
+ border: 1px solid #e0e0e0;
270
+ border-radius: 8px;
271
+ background: #fafafa;
272
+ overflow-y: auto;
273
+ ">{highlighted}</div>
274
+ ''')
275
+
276
+ st.caption("🟒 Green = included in snippet")
277
+
278
+ # Regex pattern
279
+ with st.expander("πŸ“‹ Regex Pattern"):
280
+ regex = generate_regex_pattern(result["selected_sentences"])
281
+ st.code(regex, language=None, wrap_lines=True)
282
+ st.caption("Match selected snippets in other tools.")