# grounding-snippet-generator / src/streamlit_app.py
# (Hugging Face Space by dejanseo — commit a6002dd)
"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets
Uses MS MARCO Cross-Encoder for search relevance ranking.
"""
import re
import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder
# --- Configuration ---
MODEL_NAME = "cross-encoder/ms-marco-electra-base"  # MS MARCO-trained relevance cross-encoder
MAX_SNIPPET_CHARS = 450  # default character budget for the stitched snippet
MAX_SENTENCES = 5        # default cap on sentences included in the snippet

# st.set_page_config() must be the FIRST Streamlit command in the script;
# calling st.logo() before it raises StreamlitAPIException, so it comes first.
st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="✂️",
    layout="centered"
)
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)
@st.cache_resource
def load_model():
    """Load and cache the MS MARCO CrossEncoder, on GPU when available."""
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return CrossEncoder(MODEL_NAME, device=target_device)
def segment_sentences(text: str) -> list[str]:
    """Split *text* into candidate sentences, deduplicated and filtered.

    Splits on sentence-final punctuation and newlines, then drops fragments
    that are too short (< 20 chars), look like URLs/metadata, are mostly
    non-alphabetic, or end in a question mark. Case- and
    whitespace-insensitive duplicates are removed, keeping the first
    occurrence and preserving document order.
    """
    # Split on sentence boundaries AND newlines.
    raw_sentences = re.split(r'(?<=[.!?])\s+|\n+', text)
    seen: set[str] = set()
    sentences: list[str] = []
    for s in raw_sentences:
        s = s.strip()
        # Too short to be a useful snippet sentence (also rejects empties).
        if len(s) < 20:
            continue
        # Skip URLs and URL metadata lines (tuple form of startswith).
        if s.startswith(('http', 'URL:')):
            continue
        # Skip low-alpha content (metadata, tables, prices).
        # len(s) >= 20 here, so no division-by-zero guard is needed.
        alpha_ratio = sum(c.isalpha() for c in s) / len(s)
        if alpha_ratio < 0.5:
            continue
        # Questions make poor grounding snippets.
        if s.endswith('?'):
            continue
        # Deduplicate on a lowercased, whitespace-collapsed key.
        normalized = ' '.join(s.lower().split())
        if normalized in seen:
            continue
        seen.add(normalized)
        sentences.append(s)
    return sentences
def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
    """Build a grounding snippet for *query* from *document*.

    Scores every candidate sentence against the query with the
    cross-encoder, greedily keeps the top-scoring sentences that fit the
    character/sentence budgets, and stitches them back together in
    document order with "..." marking omitted gaps. Returns the snippet
    string plus up to five (score, sentence) pairs for debugging; the
    debug list is empty when no sentence fits the budget.
    """
    sentences = segment_sentences(document)
    if not sentences:
        return "", []

    # Cross-encoder relevance score for each (query, sentence) pair.
    scores = model.predict([[query, sent] for sent in sentences])
    order = np.argsort(scores)[::-1]  # indices, best score first

    # Greedy selection under both the character and sentence budgets.
    chosen: list[tuple[int, str, float]] = []
    used = 0
    for idx in order:
        candidate = sentences[idx]
        if len(chosen) < max_sents and used + len(candidate) <= max_chars:
            chosen.append((idx, candidate, scores[idx]))
            used += len(candidate)

    # Nothing fit the budget: hard-truncate the single best sentence.
    if not chosen:
        return sentences[order[0]][:max_chars] + "...", []

    # Restore document order, then stitch with "..." across index gaps.
    chosen.sort(key=lambda item: item[0])
    parts: list[str] = []
    last = -1
    for idx, candidate, _ in chosen:
        if last >= 0 and idx > last + 1:
            parts.append("...")
        parts.append(candidate)
        last = idx
    # Trailing ellipsis when the document continues past the last pick.
    if last < len(sentences) - 1:
        parts.append("...")

    # Debug info: up to five top-ranked sentences with their scores.
    debug_info = [(scores[order[i]], sentences[order[i]])
                  for i in range(min(5, len(order)))]
    return " ".join(parts), debug_info
# --- Streamlit UI ---
st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")
st.write("How much of your page will be used to ground the model for a particular fanout query?")
st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")

# Inputs: the fanout query and the page content to snippet.
user_query = st.text_input("Query", placeholder="enter a search query...")
page_text = st.text_area(
    "Web Page Text",
    height=250,
    placeholder="Paste the full page content here..."
)

# Tunable budgets and debug toggle, tucked into an expander.
with st.expander("Settings"):
    char_budget = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    sentence_budget = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    debug_enabled = st.checkbox("Show debug info", value=True)

if st.button("Generate Snippet", help="cross-encoder/ms-marco-electra-base"):
    if not (user_query and page_text):
        # Guard clause: both fields are required before scoring.
        st.warning("Please enter both a query and document.")
    else:
        with st.spinner("Loading model & scoring sentences..."):
            ranker = load_model()
            snippet, debug = generate_snippet(user_query, page_text, ranker, char_budget, sentence_budget)
        st.subheader("Generated Snippet")
        st.code(snippet, language=None)
        if debug_enabled and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            for score, sent in debug:
                st.text(f"{score:.4f}: {sent[:80]}...")