# Source: grounding-snippet-generator / src/streamlit_app.py
# Hugging Face Space by dejanseo (revision 2c44b4b, verified)
"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets
Uses MS MARCO Cross-Encoder for search relevance ranking.
"""
import re
import html
import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder
# --- Configuration ---
# Cross-Encoder checkpoint used to score (query, sentence) relevance pairs.
MODEL_NAME = "cross-encoder/ms-marco-electra-base"
# Default character budget for the assembled snippet (UI slider default).
MAX_SNIPPET_CHARS = 450
# Default cap on sentences in the snippet (UI slider default).
MAX_SENTENCES = 5

# Page chrome: browser-tab title/icon and a centered layout.
st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="βœ‚οΈ",
    layout="centered"
)
# Branding logo linking back to the DEJAN site.
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)

# --- Session State ---
# results_mode: False -> show the input form; True -> show the analysis view.
if "results_mode" not in st.session_state:
    st.session_state.results_mode = False
# snippet_data: dict with the last query/document/result payload, or None.
if "snippet_data" not in st.session_state:
    st.session_state.snippet_data = None
@st.cache_resource
def load_model():
    """Instantiate the MS MARCO CrossEncoder once per process.

    Cached by Streamlit (`st.cache_resource`) so reruns reuse the same
    model object. Runs on GPU when CUDA is available, otherwise CPU.
    """
    return CrossEncoder(
        MODEL_NAME,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
def segment_sentences(text: str) -> list[str]:
    """Split *text* into candidate sentences, dropping noise and duplicates.

    A candidate survives only if it is at least 20 characters long, is not
    a URL line, is at least half alphabetic, and does not end in a question
    mark. Duplicates (case/whitespace-normalized) are removed; the first
    occurrence wins and document order is preserved.
    """
    chunks = (c.strip() for c in re.split(r'(?<=[.!?])\s+|\n+', text))
    kept: list[str] = []
    seen_keys: set[str] = set()
    for cand in chunks:
        if len(cand) < 20:
            continue  # empty or too short to be a useful sentence
        if cand.startswith(('http', 'URL:')):
            continue  # bare link lines carry no prose
        letters = sum(ch.isalpha() for ch in cand)
        if letters / max(len(cand), 1) < 0.5:
            continue  # mostly digits/punctuation (tables, IDs, etc.)
        if cand.endswith('?'):
            continue  # questions make poor grounding snippets
        key = ' '.join(cand.lower().split())
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(cand)
    return kept
def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> dict:
    """Score document sentences against *query* and assemble a snippet.

    Sentences are ranked by the Cross-Encoder, greedily packed under the
    character and sentence budgets, then re-joined in document order with
    "..." marking skipped stretches. Returns the snippet together with the
    full scoring breakdown for the results view.
    """
    sentences = segment_sentences(document)
    if not sentences:
        return {"snippet": "", "selected_sentences": [], "all_sentences": [], "scores": []}

    scores = model.predict([[query, s] for s in sentences])
    by_relevance = np.argsort(scores)[::-1]  # best-scoring first

    # Greedy packing: take the highest-scoring sentences that still fit.
    chosen = []
    used_chars = 0
    for i in by_relevance:
        candidate = sentences[i]
        if len(chosen) < max_sents and used_chars + len(candidate) <= max_chars:
            chosen.append(i)
            used_chars += len(candidate)

    if not chosen:
        # Nothing fits whole; hard-truncate the single best sentence.
        top = by_relevance[0]
        clipped = sentences[top][:max_chars]
        return {
            "snippet": clipped + "...",
            "selected_sentences": [clipped],
            "all_sentences": sentences,
            "scores": scores.tolist(),
            "selected_indices": [top],
        }

    chosen.sort()  # back to reading order
    pieces = []
    last = -1
    for i in chosen:
        if last >= 0 and i - last > 1:
            pieces.append("...")  # gap between non-adjacent sentences
        pieces.append(sentences[i])
        last = i
    if last < len(sentences) - 1:
        pieces.append("...")  # document continues past the last selection

    return {
        "snippet": " ".join(pieces),
        "selected_sentences": [sentences[i] for i in chosen],
        "all_sentences": sentences,
        "scores": scores.tolist(),
        "selected_indices": chosen,
    }
def render_highlighted_html(document: str, selected_sentences: list[str]) -> str:
    """Return the document as HTML with the chosen sentences highlighted.

    Each selected sentence is located by exact substring search, falling
    back to a whitespace-tolerant regex; overlapping matches are merged.
    Every text segment passes through html.escape(), so the result is safe
    to embed directly.
    """
    spans: list[tuple[int, int]] = []
    for sentence in selected_sentences:
        at = document.find(sentence)
        if at != -1:
            spans.append((at, at + len(sentence)))
        else:
            # Whitespace may differ between sentence and document; retry
            # with a pattern that treats any whitespace run as equivalent.
            loose = r'\s+'.join(re.escape(tok) for tok in sentence.split())
            hit = re.search(loose, document)
            if hit:
                spans.append((hit.start(), hit.end()))

    spans.sort()
    merged: list[tuple[int, int]] = []
    for lo, hi in spans:
        if merged and lo <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], hi))
        else:
            merged.append((lo, hi))

    def _gray(text: str) -> str:
        # Non-selected text is rendered dimmed.
        return f'<span style="color:#888">{html.escape(text)}</span>'

    out: list[str] = []
    cursor = 0
    for lo, hi in merged:
        if cursor < lo:
            out.append(_gray(document[cursor:lo]))
        # Selected text gets the green highlight treatment.
        out.append(
            '<span style="background:#c6f6d5;color:#166534;'
            f'padding:1px 3px;border-radius:3px">{html.escape(document[lo:hi])}</span>'
        )
        cursor = hi
    if cursor < len(document):
        out.append(_gray(document[cursor:]))
    return "".join(out)
def generate_regex_pattern(selected_sentences: list[str]) -> str:
    r"""Build a regex matching the selected sentences in order.

    Each sentence is escaped so it matches literally; the `[\s\S]*?`
    separator bridges arbitrary gaps, including newlines (which `.`
    would miss without DOTALL).
    """
    if not selected_sentences:
        return ""
    return r'[\s\S]*?'.join(re.escape(part) for part in selected_sentences)
def reset_to_input():
    """Switch the app back to input mode, discarding the last analysis."""
    st.session_state.snippet_data = None
    st.session_state.results_mode = False
# --- Main UI ---
# Two-mode flow driven by session state: input form first, then a results
# view after "Generate Snippet" triggers a rerun.
st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")
if not st.session_state.results_mode:
    # === INPUT MODE ===
    st.write("How much of your page will be used to ground the model for a particular fanout query?")
    st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")
    query = st.text_input("Query", placeholder="enter a search query...")
    document = st.text_area(
        "Web Page Text",
        height=250,
        placeholder="Paste the full page content here..."
    )
    # Budget controls; defaults come from the module-level constants.
    with st.expander("Settings"):
        max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
        max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    if st.button("Generate Snippet", type="primary"):
        if query and document:
            with st.spinner("Loading model & scoring sentences..."):
                model = load_model()
                result = generate_snippet(query, document, model, max_chars, max_sents)
            # Persist everything the results view needs across the rerun.
            st.session_state.snippet_data = {
                "query": query,
                "document": document,
                "result": result,
                "max_chars": max_chars,
                "max_sents": max_sents
            }
            st.session_state.results_mode = True
            st.rerun()
        else:
            st.warning("Please enter both a query and document.")
else:
    # === RESULTS MODE ===
    data = st.session_state.snippet_data
    query = data["query"]
    document = data["document"]
    result = data["result"]
    if st.button("← New Analysis"):
        reset_to_input()
        st.rerun()
    # Stats
    # NOTE(review): snippet_chars counts the raw selected sentences, not the
    # final snippet string (which also contains "..." separators).
    snippet_chars = sum(len(s) for s in result["selected_sentences"])
    doc_chars = len(document)
    pct = (snippet_chars / doc_chars * 100) if doc_chars > 0 else 0
    st.caption(f"{snippet_chars:,} / {doc_chars:,} chars ({pct:.1f}%) β€’ {len(result['selected_sentences'])} sentences")
    # Query - use st.html to prevent any rendering issues
    st.html(f'<p style="font-size:1.3em;font-weight:600;margin:1em 0">{html.escape(query)}</p>')
    # Generated snippet
    st.subheader("Generated Snippet")
    st.code(result["snippet"], wrap_lines=True, language=None)
    # Highlighted document - st.html() does NO markdown parsing
    highlighted = render_highlighted_html(document, result["selected_sentences"])
    st.html(f'''
    <div style="
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    font-size: 14px;
    line-height: 1.7;
    white-space: pre-wrap;
    word-wrap: break-word;
    padding: 16px;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    background: #fafafa;
    overflow-y: auto;
    ">{highlighted}</div>
    ''')
    st.caption("🟒 Green = included in snippet")
    # Regex pattern
    with st.expander("πŸ“‹ Regex Pattern"):
        regex = generate_regex_pattern(result["selected_sentences"])
        st.code(regex, language=None, wrap_lines=True)
        st.caption("Match selected snippets in other tools.")