|
|
""" |
|
|
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets |
|
|
Uses MS MARCO Cross-Encoder for search relevance ranking. |
|
|
""" |
|
|
|
|
|
import re |
|
|
import numpy as np |
|
|
import streamlit as st |
|
|
import torch |
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
|
|
|
# Hugging Face model id for the MS MARCO-trained Cross-Encoder reranker.
MODEL_NAME = "cross-encoder/ms-marco-electra-base"

# Default character budget for the assembled snippet.
MAX_SNIPPET_CHARS = 450

# Default cap on the number of sentences stitched into one snippet.
MAX_SENTENCES = 5
|
|
|
|
|
# Configure the Streamlit page; must run before any other st.* call.
st.set_page_config(
    page_title="Snippet Generator",
    page_icon="βοΈ",  # NOTE(review): looks like a mis-encoded emoji glyph — confirm intended character
    layout="centered"
)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the MS MARCO CrossEncoder and cache it for the server session.

    ``st.cache_resource`` ensures the model is constructed only once per
    Streamlit process; subsequent reruns reuse the same instance.
    """
    # Prefer the GPU whenever torch can see one.
    return CrossEncoder(
        MODEL_NAME,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
|
|
|
|
|
|
|
|
def segment_sentences(text: str, *, min_len: int = 20,
                      min_alpha_ratio: float = 0.5) -> list[str]:
    """Split *text* into candidate snippet sentences.

    Splits on sentence-ending punctuation followed by whitespace, or on
    newline runs, then drops fragments that make poor snippet material:

    - shorter than *min_len* characters,
    - bare links (``http...``) and ``URL:`` source-attribution lines,
    - fragments whose alphabetic-character ratio is below
      *min_alpha_ratio* (tables, navigation junk, number runs),
    - questions (a question rarely contains the answer),
    - case-insensitive duplicates of an earlier sentence.

    Args:
        text: Raw document text.
        min_len: Minimum kept-sentence length in characters (keyword-only;
            default 20, matching the original hard-coded filter).
        min_alpha_ratio: Minimum fraction of alphabetic characters
            (keyword-only; default 0.5).

    Returns:
        Filtered, deduplicated sentences in document order.
    """
    # Split after '.', '!' or '?' + whitespace, or on any newline run.
    raw_sentences = re.split(r'(?<=[.!?])\s+|\n+', text)

    seen: set[str] = set()
    sentences: list[str] = []

    for s in raw_sentences:
        s = s.strip()

        if not s or len(s) < min_len:
            continue

        # Skip bare links and source-attribution lines.
        if s.startswith(('http', 'URL:')):
            continue

        # Mostly non-alphabetic fragments are layout noise, not prose.
        # len(s) >= 1 is guaranteed here by the emptiness check above.
        if sum(c.isalpha() for c in s) / len(s) < min_alpha_ratio:
            continue

        # Questions rarely carry the answer content a query wants.
        if s.endswith('?'):
            continue

        # Deduplicate on a whitespace-normalized, lower-cased key.
        normalized = ' '.join(s.lower().split())
        if normalized in seen:
            continue
        seen.add(normalized)

        sentences.append(s)

    return sentences
|
|
|
|
|
|
|
|
def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
    """Build an extractive snippet for *query* from *document*.

    Sentences are scored against the query by the Cross-Encoder, the top
    scorers are picked greedily within the character and sentence budget,
    and the survivors are stitched back in document order with "..."
    marking gaps and a trailing "..." when the document continues.

    Returns:
        A ``(snippet, debug)`` tuple, where ``debug`` lists up to five
        ``(score, sentence)`` pairs for the highest-ranked sentences
        (empty when no sentence fit the budget).
    """
    sents = segment_sentences(document)
    if not sents:
        return "", []

    # Score every (query, sentence) pair; higher score = more relevant.
    scores = model.predict([[query, s] for s in sents])

    # Indices sorted best-first.
    order = np.argsort(scores)[::-1]

    # Greedily keep top-scoring sentences while they fit the budget.
    chosen = []
    used = 0
    for i in order:
        candidate = sents[i]
        if used + len(candidate) <= max_chars and len(chosen) < max_sents:
            chosen.append((i, candidate, scores[i]))
            used += len(candidate)

    # Nothing fit: hard-truncate the single best sentence instead.
    if not chosen:
        top = order[0]
        return sents[top][:max_chars] + "...", []

    # Restore original document order before stitching.
    chosen.sort(key=lambda item: item[0])

    parts = []
    last = -1
    for i, sentence, _ in chosen:
        if last >= 0 and i > last + 1:
            parts.append("...")  # mark skipped sentences
        parts.append(sentence)
        last = i

    if last < len(sents) - 1:
        parts.append("...")  # document continues past the snippet

    # Up to five best-ranked sentences with their scores, for debugging.
    debug_info = [(scores[j], sents[j]) for j in order[:min(5, len(order))]]

    return " ".join(parts), debug_info
|
|
|
|
|
|
|
|
|
|
|
# ---- Page header ----
# NOTE(review): the title/expander strings below contain what look like
# mis-encoded emoji ("βοΈ", "π") — confirm the intended glyphs before
# re-saving this file with a different encoding.
st.title("βοΈ Snippet Generator")
st.caption("Recreates Google Vertex AI / Gemini grounding-style snippets")

st.markdown("""
This tool generates extractive snippets from documents using a Cross-Encoder model trained on MS MARCO search relevance data.

**How it works:**
1. Segments document into sentences
2. Scores each sentence against your query using `cross-encoder/ms-marco-electra-base`
3. Selects top-scoring sentences within budget
4. Stitches them in document order with `...` for gaps
""")

st.markdown("---")

# ---- Inputs ----
query = st.text_input("π Query", value="best prostate cancer treatment in the world")

document = st.text_area(
    "π Document",
    height=250,
    placeholder="Paste document content here..."
)

# Tunables for snippet generation; defaults come from the module constants.
with st.expander("βοΈ Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)

# ---- Generation ----
if st.button("Generate Snippet", type="primary"):
    if query and document:
        # Model loading is cached (st.cache_resource), so the spinner is
        # slow only on the first run of the server process.
        with st.spinner("Loading model & scoring sentences..."):
            model = load_model()
            snippet, debug = generate_snippet(query, document, model, max_chars, max_sents)

        st.subheader("Generated Snippet")
        st.code(snippet, language=None)

        # Optional: show the top-ranked sentences with their raw scores.
        if show_debug and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            for score, sent in debug:
                st.text(f"{score:.4f}: {sent[:80]}...")
    else:
        st.warning("Please enter both a query and document.")

st.markdown("---")
st.caption("Model: `cross-encoder/ms-marco-electra-base` | [GitHub](https://github.com/UKPLab/sentence-transformers)")