"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets.

Uses MS MARCO Cross-Encoder for search relevance ranking.
"""

import re
import html

import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder
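
# This is a Streamlit app, so launch it via the Streamlit CLI rather than plain
# Python, e.g. (script filename assumed): streamlit run snippet_generator.py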


MODEL_NAME = "cross-encoder/ms-marco-electra-base"
MAX_SNIPPET_CHARS = 450
MAX_SENTENCES = 5
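
# cross-encoder/ms-marco-electra-base is an ELECTRA-base cross-encoder trained on
# MS MARCO passage ranking; CrossEncoder.predict() scores each (query, sentence)
# pair jointly, and higher scores mean higher relevance. MAX_SNIPPET_CHARS and
# MAX_SENTENCES are only defaults - both are exposed as sliders in the Settings
# expander below.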

st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="✂️",
    layout="centered"
)

st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)


if "results_mode" not in st.session_state:
    st.session_state.results_mode = False
if "snippet_data" not in st.session_state:
    st.session_state.snippet_data = None
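# results_mode flips to True once a snippet has been generated; snippet_data then
# carries the query, document, and scoring result across Streamlit reruns.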


@st.cache_resource
def load_model():
    """Load CrossEncoder model."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CrossEncoder(MODEL_NAME, device=device)
    return model
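
# st.cache_resource keeps a single model instance per server process and reuses it
# across sessions and reruns, so the download/initialisation cost is paid only once.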


def segment_sentences(text: str) -> list[str]:
    """Sentence segmentation with deduplication and filtering."""
    pattern = r'(?<=[.!?])\s+|\n+'
    raw_sentences = re.split(pattern, text)

    seen = set()
    sentences = []

    for s in raw_sentences:
        s = s.strip()

        # Skip empty or very short fragments.
        if not s or len(s) < 20:
            continue

        # Skip bare URLs and URL labels.
        if s.startswith('http') or s.startswith('URL:'):
            continue

        # Skip lines that are mostly non-alphabetic (numbers, punctuation, markup).
        alpha_ratio = sum(c.isalpha() for c in s) / max(len(s), 1)
        if alpha_ratio < 0.5:
            continue

        # Skip sentences that end with a question mark.
        if s.endswith('?'):
            continue

        # Deduplicate after lowercasing and collapsing whitespace.
        normalized = ' '.join(s.lower().split())
        if normalized in seen:
            continue
        seen.add(normalized)

        sentences.append(s)

    return sentences
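
# Illustrative example (not executed):
#   segment_sentences("What is a grounding snippet? Grounding snippets are passages "
#                     "a model cites when answering. http://example.com")
#   -> ["Grounding snippets are passages a model cites when answering."]
# The question and the bare URL are filtered out; only the middle sentence survives.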


def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> dict:
    """Generate snippet using Cross-Encoder scoring. Returns full analysis."""
    sentences = segment_sentences(document)

    if not sentences:
        return {"snippet": "", "selected_sentences": [], "all_sentences": [], "scores": [], "selected_indices": []}

    # Score every (query, sentence) pair, then rank sentences by descending relevance.
    pairs = [[query, sent] for sent in sentences]
    scores = model.predict(pairs)

    ranked_indices = np.argsort(scores)[::-1]

    # Greedily keep the highest-scoring sentences that fit both budgets.
    selected_indices = []
    total_length = 0

    for idx in ranked_indices:
        sent = sentences[idx]
        if total_length + len(sent) <= max_chars and len(selected_indices) < max_sents:
            selected_indices.append(idx)
            total_length += len(sent)

    # Fallback: if no sentence fits within max_chars, return the top-ranked one truncated.
    if not selected_indices:
        best_idx = ranked_indices[0]
        return {
            "snippet": sentences[best_idx][:max_chars] + "...",
            "selected_sentences": [sentences[best_idx][:max_chars]],
            "all_sentences": sentences,
            "scores": scores.tolist(),
            "selected_indices": [best_idx]
        }

    # Restore document order and insert "..." where selected sentences are non-adjacent.
    selected_indices.sort()
    selected_sentences = [sentences[i] for i in selected_indices]

    snippet_parts = []
    prev_idx = -1

    for idx in selected_indices:
        if prev_idx >= 0 and idx > prev_idx + 1:
            snippet_parts.append("...")
        snippet_parts.append(sentences[idx])
        prev_idx = idx

    if prev_idx < len(sentences) - 1:
        snippet_parts.append("...")

    return {
        "snippet": " ".join(snippet_parts),
        "selected_sentences": selected_sentences,
        "all_sentences": sentences,
        "scores": scores.tolist(),
        "selected_indices": selected_indices
    }
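
# Sketch of standalone (non-Streamlit) use, assuming a page_text string is at hand:
#   model = load_model()
#   out = generate_snippet("what is grounding", page_text, model, MAX_SNIPPET_CHARS, MAX_SENTENCES)
#   out["snippet"]     -> the assembled snippet
#   out["scores"][i]   -> relevance score for out["all_sentences"][i]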


def render_highlighted_html(document: str, selected_sentences: list[str]) -> str:
    """Render document as HTML with highlighted sentences. Uses html.escape() for safety."""

    # Locate each selected sentence; if an exact match fails, retry with a
    # whitespace-tolerant regex built from the sentence's words.
    highlights = []
    for sent in selected_sentences:
        start = document.find(sent)
        if start != -1:
            highlights.append((start, start + len(sent)))
            continue
        sent_pattern = r'\s+'.join(re.escape(word) for word in sent.split())
        match = re.search(sent_pattern, document)
        if match:
            highlights.append((match.start(), match.end()))

    highlights.sort(key=lambda x: x[0])

    # Merge overlapping ranges so duplicate or nested highlights collapse into one span.
    merged = []
    for start, end in highlights:
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))

    # Emit grey spans for unselected text and green spans for highlighted text.
    parts = []
    pos = 0

    for start, end in merged:
        if pos < start:
            text = html.escape(document[pos:start])
            parts.append(f'<span style="color:#888">{text}</span>')

        text = html.escape(document[start:end])
        parts.append(f'<span style="background:#c6f6d5;color:#166534;padding:1px 3px;border-radius:3px">{text}</span>')
        pos = end

    if pos < len(document):
        text = html.escape(document[pos:])
        parts.append(f'<span style="color:#888">{text}</span>')

    return "".join(parts)
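
# The returned markup is injected via st.html below; every slice of the document
# passes through html.escape, so any HTML in the pasted page shows up as literal text.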


def generate_regex_pattern(selected_sentences: list[str]) -> str:
    """Generate regex pattern for matching selected snippets."""
    if not selected_sentences:
        return ""
    # Join the escaped sentences with a lazy any-character gap that also spans newlines.
    escaped = [re.escape(sent) for sent in selected_sentences]
    return r'[\s\S]*?'.join(escaped)
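
# Illustrative example (Python 3.7+ re.escape behaviour, not executed):
#   generate_regex_pattern(["First fact.", "Second fact."])
#   returns r"First fact\.[\s\S]*?Second fact\."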


def reset_to_input():
    st.session_state.results_mode = False
    st.session_state.snippet_data = None


st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")
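
# Input mode: collect a query and the page text, score on demand, then switch to the
# results view. Results mode: show the snippet, the highlighted source, and a regex
# that matches the selected sentences, plus a button to start over.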

if not st.session_state.results_mode:

    st.write("How much of your page will be used to ground the model for a particular fanout query?")
    st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")

    query = st.text_input("Query", placeholder="enter a search query...")

    document = st.text_area(
        "Web Page Text",
        height=250,
        placeholder="Paste the full page content here..."
    )

    with st.expander("Settings"):
        max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
        max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)

    if st.button("Generate Snippet", type="primary"):
        if query and document:
            with st.spinner("Loading model & scoring sentences..."):
                model = load_model()
                result = generate_snippet(query, document, model, max_chars, max_sents)

            st.session_state.snippet_data = {
                "query": query,
                "document": document,
                "result": result,
                "max_chars": max_chars,
                "max_sents": max_sents
            }
            st.session_state.results_mode = True
            st.rerun()
        else:
            st.warning("Please enter both a query and a document.")

else:

    data = st.session_state.snippet_data
    query = data["query"]
    document = data["document"]
    result = data["result"]

    if st.button("← New Analysis"):
        reset_to_input()
        st.rerun()

    snippet_chars = sum(len(s) for s in result["selected_sentences"])
    doc_chars = len(document)
    pct = (snippet_chars / doc_chars * 100) if doc_chars > 0 else 0

    st.caption(f"{snippet_chars:,} / {doc_chars:,} chars ({pct:.1f}%) • {len(result['selected_sentences'])} sentences")

    st.html(f'<p style="font-size:1.3em;font-weight:600;margin:1em 0">{html.escape(query)}</p>')

    st.subheader("Generated Snippet")
    st.code(result["snippet"], wrap_lines=True, language=None)

    highlighted = render_highlighted_html(document, result["selected_sentences"])

    st.html(f'''
    <div style="
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        font-size: 14px;
        line-height: 1.7;
        white-space: pre-wrap;
        word-wrap: break-word;
        padding: 16px;
        border: 1px solid #e0e0e0;
        border-radius: 8px;
        background: #fafafa;
        overflow-y: auto;
    ">{highlighted}</div>
    ''')

    st.caption("🟢 Green = included in snippet")

    with st.expander("🔍 Regex Pattern"):
        regex = generate_regex_pattern(result["selected_sentences"])
        st.code(regex, language=None, wrap_lines=True)
        st.caption("Match selected snippets in other tools.")