# NOTE: removed extraction artifacts (file-size header, git blame hashes, and a
# line-number ruler from the source viewer) that were not part of the program.
"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets
Uses MS MARCO Cross-Encoder for search relevance ranking.
"""
import re
import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder
# --- Configuration ---
# Hugging Face cross-encoder fine-tuned on MS MARCO passage ranking.
MODEL_NAME = "cross-encoder/ms-marco-electra-base"
MAX_SNIPPET_CHARS = 450  # default character budget for the stitched snippet
MAX_SENTENCES = 5        # default cap on sentences included in a snippet
st.set_page_config(
    page_title="Snippet Generator",
    # NOTE(review): this literal looks like mojibake (UTF-8 emoji bytes decoded
    # with a legacy codec, likely originally an emoji icon) — confirm encoding.
    page_icon="βοΈ",
    layout="centered"
)
@st.cache_resource
def load_model():
    """Load the MS MARCO CrossEncoder once and cache it for the app session.

    Picks CUDA when a GPU is visible to torch, otherwise falls back to CPU.
    """
    use_gpu = torch.cuda.is_available()
    return CrossEncoder(MODEL_NAME, device="cuda" if use_gpu else "cpu")
def segment_sentences(text: str) -> list[str]:
"""Sentence segmentation with deduplication and filtering."""
# Split on sentence boundaries AND newlines
pattern = r'(?<=[.!?])\s+|\n+'
raw_sentences = re.split(pattern, text)
seen = set()
sentences = []
for s in raw_sentences:
s = s.strip()
if not s or len(s) < 20:
continue
if s.startswith('http') or s.startswith('URL:'):
continue
# Skip low-alpha content (metadata, tables, prices)
alpha_ratio = sum(c.isalpha() for c in s) / max(len(s), 1)
if alpha_ratio < 0.5:
continue
# Skip questions
if s.endswith('?'):
continue
normalized = ' '.join(s.lower().split())
if normalized in seen:
continue
seen.add(normalized)
sentences.append(s)
return sentences
def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
    """Build an extractive snippet for *query* from *document*.

    Pipeline: segment the document into sentences, score each (query,
    sentence) pair with the cross-encoder, greedily keep top-scoring
    sentences within the character/sentence budget, then stitch the
    survivors back in document order with "..." marking gaps.

    Args:
        query: The search query to score sentences against.
        document: Raw document text to extract from.
        model: A CrossEncoder-like object exposing ``predict(pairs)``.
        max_chars: Character budget for the stitched snippet.
        max_sents: Maximum number of sentences to include.

    Returns:
        A ``(snippet, debug_info)`` tuple where ``debug_info`` is up to the
        top five ``(score, sentence)`` pairs (empty when nothing qualified).
    """
    sentences = segment_sentences(document)
    if not sentences:
        return "", []
    # Cross-encoder: score each query-sentence pair jointly.
    pairs = [[query, sent] for sent in sentences]
    scores = model.predict(pairs)
    ranked_indices = np.argsort(scores)[::-1]
    # Greedy selection under the character and sentence budgets.
    selected = []
    total_length = 0
    for idx in ranked_indices:
        sent = sentences[idx]
        if total_length + len(sent) <= max_chars and len(selected) < max_sents:
            selected.append((idx, sent, scores[idx]))
            total_length += len(sent)
    if not selected:
        # Even the best sentence exceeds the budget: hard-truncate it.
        best_idx = ranked_indices[0]
        return sentences[best_idx][:max_chars] + "...", []
    # Restore document order before stitching.
    selected.sort(key=lambda x: x[0])
    # Stitch with ellipsis for gaps. Starting prev_idx at -1 means a
    # snippet that begins mid-document also gets a leading "..." (bug fix:
    # previously only interior and trailing gaps were marked).
    snippet_parts = []
    prev_idx = -1
    for idx, sent, _ in selected:
        if idx > prev_idx + 1:
            snippet_parts.append("...")
        snippet_parts.append(sent)
        prev_idx = idx
    if prev_idx < len(sentences) - 1:
        snippet_parts.append("...")
    # Debug info: top-5 (score, sentence) by rank.
    debug_info = [(scores[ranked_indices[i]], sentences[ranked_indices[i]])
                  for i in range(min(5, len(ranked_indices)))]
    return " ".join(snippet_parts), debug_info
# --- Streamlit UI ---
# NOTE(review): the emoji-like literals below ("βοΈ", "π") appear to be
# mojibake (UTF-8 emoji bytes decoded with a legacy codec). They are runtime
# strings, so they are left untouched here — confirm and restore the originals.
st.title("βοΈ Snippet Generator")
st.caption("Recreates Google Vertex AI / Gemini grounding-style snippets")
st.markdown("""
This tool generates extractive snippets from documents using a Cross-Encoder model trained on MS MARCO search relevance data.
**How it works:**
1. Segments document into sentences
2. Scores each sentence against your query using `cross-encoder/ms-marco-electra-base`
3. Selects top-scoring sentences within budget
4. Stitches them in document order with `...` for gaps
""")
st.markdown("---")
# Inputs: free-text query plus the document to extract the snippet from.
query = st.text_input("π Query", value="best prostate cancer treatment in the world")
document = st.text_area(
    "π Document",
    height=250,
    placeholder="Paste document content here..."
)
# Tunable generation settings, collapsed behind an expander.
with st.expander("βοΈ Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)
# --- Main action: score the document against the query and render results ---
if st.button("Generate Snippet", type="primary"):
    if query and document:
        with st.spinner("Loading model & scoring sentences..."):
            model = load_model()
            snippet, debug = generate_snippet(query, document, model, max_chars, max_sents)
        st.subheader("Generated Snippet")
        st.code(snippet, language=None)
        if show_debug and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            for score, sent in debug:
                # Bug fix: only append an ellipsis when the 80-char preview
                # actually truncated the sentence (previously unconditional).
                preview = sent[:80] + ("..." if len(sent) > 80 else "")
                st.text(f"{score:.4f}: {preview}")
    else:
        st.warning("Please enter both a query and document.")
st.markdown("---")
st.caption("Model: `cross-encoder/ms-marco-electra-base` | [GitHub](https://github.com/UKPLab/sentence-transformers)")