|
|
""" |
|
|
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets |
|
|
Uses MS MARCO Cross-Encoder for search relevance ranking. |
|
|
""" |
|
|
|
|
|
import re |
|
|
import numpy as np |
|
|
import streamlit as st |
|
|
import torch |
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
|
|
|
# Hugging Face checkpoint: MS MARCO-trained cross-encoder that scores
# (query, passage) relevance pairs.
MODEL_NAME = "cross-encoder/ms-marco-electra-base"

# Default character budget for the assembled snippet (used as the UI slider default).
MAX_SNIPPET_CHARS = 450

# Default cap on the number of sentences in the snippet (used as the UI slider default).
MAX_SENTENCES = 5
|
|
|
|
|
# st.set_page_config must be the first Streamlit command executed on the
# page; the original script called st.logo() before it, which raises a
# StreamlitAPIException at runtime. Reordered so page config runs first.
st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="✂️",
    layout="centered"
)

st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the CrossEncoder once per process, preferring GPU when available."""
    # st.cache_resource memoizes the returned model across Streamlit reruns.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return CrossEncoder(MODEL_NAME, device=target_device)
|
|
|
|
|
|
|
|
def segment_sentences(text: str) -> list[str]:
    """Split *text* into candidate sentences, dropping noise and duplicates.

    Filters applied, in order: minimum length (20 chars), URL-like lines,
    mostly non-alphabetic lines, questions, and case/whitespace-insensitive
    duplicates. Returns the surviving sentences in document order.
    """
    kept: list[str] = []
    seen_keys: set[str] = set()

    # Break on sentence-ending punctuation followed by whitespace, or on newlines.
    for candidate in re.split(r'(?<=[.!?])\s+|\n+', text):
        candidate = candidate.strip()

        # Too short to be a useful snippet sentence (also drops empties).
        if len(candidate) < 20:
            continue

        # Skip bare links / URL annotations.
        if candidate.startswith('http') or candidate.startswith('URL:'):
            continue

        # Mostly symbols/digits — likely boilerplate rather than prose.
        letter_count = sum(ch.isalpha() for ch in candidate)
        if letter_count / max(len(candidate), 1) < 0.5:
            continue

        # Questions are excluded from snippet candidates.
        if candidate.endswith('?'):
            continue

        # Deduplicate ignoring case and internal whitespace differences.
        key = ' '.join(candidate.lower().split())
        if key in seen_keys:
            continue
        seen_keys.add(key)
        kept.append(candidate)

    return kept
|
|
|
|
|
|
|
|
def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
    """Build an extractive snippet for *query* from *document*.

    Sentences are scored by the cross-encoder *model*, greedily picked in
    descending-score order under the character and sentence budgets, then
    emitted in document order with "..." marking gaps. Returns the snippet
    text and a (score, sentence) list for the top-5 candidates (the list is
    empty on the fallback paths).
    """
    candidates = segment_sentences(document)
    if not candidates:
        return "", []

    # Score every (query, sentence) pair; higher score = more relevant.
    relevance = model.predict([[query, cand] for cand in candidates])
    order = np.argsort(relevance)[::-1]  # best-first candidate indices

    # Greedy selection under both the character and sentence budgets.
    picked: list[tuple[int, str, float]] = []
    used_chars = 0
    for i in order:
        cand = candidates[i]
        if used_chars + len(cand) <= max_chars and len(picked) < max_sents:
            picked.append((i, cand, relevance[i]))
            used_chars += len(cand)

    # Nothing fit the budget: hard-truncate the single best sentence.
    if not picked:
        return candidates[order[0]][:max_chars] + "...", []

    # Restore document order and insert ellipses wherever sentences were skipped.
    picked.sort(key=lambda item: item[0])
    parts: list[str] = []
    last_idx = -1
    for i, cand, _ in picked:
        if last_idx >= 0 and i > last_idx + 1:
            parts.append("...")
        parts.append(cand)
        last_idx = i
    # Trailing ellipsis when the snippet stops before the document's end.
    if last_idx < len(candidates) - 1:
        parts.append("...")

    top_five = [(relevance[j], candidates[j]) for j in order[:5]]
    return " ".join(parts), top_five
|
|
|
|
|
|
|
|
|
|
|
# --- Streamlit UI (top-level script; reruns end-to-end on every interaction) ---

st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")

st.markdown("""
How much of your page will be used to ground the model for a particular fanout query?
""")

# Inputs: the search query and the page text to extract a snippet from.
query = st.text_input("Query", value="best prostate cancer treatment in the world")

document = st.text_area(
    "Web Page Text",
    height=250,
    placeholder="Paste document content here..."
)

# Tunables; defaults come from the module-level constants.
with st.expander("Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)

if st.button("Generate Snippet", help="cross-encoder/ms-marco-electra-base"):
    if query and document:
        with st.spinner("Loading model & scoring sentences..."):
            # load_model() is cached via st.cache_resource, so the model is
            # only actually loaded on the first press.
            model = load_model()
            snippet, debug = generate_snippet(query, document, model, max_chars, max_sents)

        st.subheader("Generated Snippet")
        st.code(snippet, language=None)

        if show_debug and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            # debug is a list of (score, sentence) pairs, best-scoring first.
            for score, sent in debug:
                st.text(f"{score:.4f}: {sent[:80]}...")
    else:
        st.warning("Please enter both a query and document.")
|
|
|