"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets
Uses MS MARCO Cross-Encoder for search relevance ranking.
"""
import re
import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder
# --- Configuration ---
MODEL_NAME = "cross-encoder/ms-marco-electra-base"  # CrossEncoder checkpoint for query/sentence relevance
MAX_SNIPPET_CHARS = 450  # default character budget for the stitched snippet
MAX_SENTENCES = 5        # default cap on sentences included in the snippet

# st.set_page_config() must be the FIRST Streamlit command executed in the
# script; calling st.logo() before it raises a StreamlitAPIException, so the
# page config is applied before any other st.* call.
st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="✂️",
    layout="centered"
)
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)
@st.cache_resource
def load_model():
    """Return the MS MARCO CrossEncoder, on GPU when one is available.

    Cached as a Streamlit resource so the model is loaded once per process.
    """
    return CrossEncoder(
        MODEL_NAME,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
def segment_sentences(text: str) -> list[str]:
    """Split *text* into clean, deduplicated candidate sentences.

    Segmentation breaks after end-of-sentence punctuation and on newlines.
    Fragments are discarded when they are shorter than 20 characters,
    URL-like, mostly non-alphabetic (tables, prices, metadata), questions,
    or duplicates of an earlier sentence (case/whitespace-insensitive).
    """
    kept: list[str] = []
    seen_keys: set[str] = set()
    for fragment in re.split(r'(?<=[.!?])\s+|\n+', text):
        fragment = fragment.strip()
        if len(fragment) < 20:
            continue  # empty or too short to be a useful sentence
        if fragment.startswith(('http', 'URL:')):
            continue  # bare links / URL metadata lines
        letters = sum(1 for ch in fragment if ch.isalpha())
        if letters / max(len(fragment), 1) < 0.5:
            continue  # mostly digits/symbols: tables, prices, metadata
        if fragment.endswith('?'):
            continue  # questions make poor grounding material
        key = ' '.join(fragment.lower().split())
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(fragment)
    return kept
def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
    """Build a grounding snippet for *query* from *document*.

    Every candidate sentence is scored against the query with the
    cross-encoder, the highest-scoring sentences are greedily selected
    within the character and sentence budgets, and the picks are stitched
    back together in document order with "..." marking skipped gaps.

    Returns the snippet string plus up to five (score, sentence) pairs
    for debugging; the debug list is empty when no sentence fit the budget.
    """
    sentences = segment_sentences(document)
    if not sentences:
        return "", []

    # Score (query, sentence) pairs and rank best-first.
    scores = model.predict([[query, s] for s in sentences])
    order = np.argsort(scores)[::-1]

    # Greedy selection: take top-ranked sentences while both budgets allow.
    picked = []
    used = 0
    for i in order:
        candidate = sentences[i]
        if len(picked) < max_sents and used + len(candidate) <= max_chars:
            picked.append((i, candidate, scores[i]))
            used += len(candidate)

    if not picked:
        # Even the single best sentence exceeds the budget: hard-truncate it.
        return sentences[order[0]][:max_chars] + "...", []

    # Restore document order for readability.
    picked.sort(key=lambda item: item[0])

    # Stitch, inserting an ellipsis wherever sentences were skipped and a
    # trailing one when the snippet stops before the document's last sentence.
    parts = []
    last = -1
    for i, candidate, _ in picked:
        if last >= 0 and i > last + 1:
            parts.append("...")
        parts.append(candidate)
        last = i
    if last < len(sentences) - 1:
        parts.append("...")

    debug_info = [(scores[j], sentences[j]) for j in order[:5]]
    return " ".join(parts), debug_info
# --- Streamlit UI ---
# Flat Streamlit script: widgets render in declaration order, and the whole
# script reruns on every interaction, so statement order is the layout.
st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")
st.write("How much of your page will be used to ground the model for a particular fanout query?")
st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")
# Inputs: the fan-out query and the raw page text to snippet.
query = st.text_input("Query", placeholder="enter a search query...")
document = st.text_area(
    "Web Page Text",
    height=250,
    placeholder="Paste the full page content here..."
)
# Tunables default to the module-level constants above.
with st.expander("Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)
if st.button("Generate Snippet", help="cross-encoder/ms-marco-electra-base"):
    if query and document:
        # load_model() is cached via st.cache_resource, so only the first
        # click pays the model-loading cost; later runs just re-score.
        with st.spinner("Loading model & scoring sentences..."):
            model = load_model()
            snippet, debug = generate_snippet(query, document, model, max_chars, max_sents)
        st.subheader("Generated Snippet")
        # language=None renders the snippet as plain text, not highlighted code.
        st.code(snippet, language=None)
        if show_debug and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            # debug holds (score, sentence) pairs for the top-ranked sentences.
            for score, sent in debug:
                st.text(f"{score:.4f}: {sent[:80]}...")
    else:
        st.warning("Please enter both a query and document.")