"""Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets

Uses MS MARCO Cross-Encoder for search relevance ranking.
"""

import re

import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder

# --- Configuration ---
MODEL_NAME = "cross-encoder/ms-marco-electra-base"
MAX_SNIPPET_CHARS = 450
MAX_SENTENCES = 5

# FIX: st.set_page_config must be the FIRST Streamlit command executed in a
# script run; calling st.logo before it raises StreamlitAPIException.
st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="✂️",
    layout="centered"
)

st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)


@st.cache_resource
def load_model() -> CrossEncoder:
    """Load the CrossEncoder once per process (cached via st.cache_resource).

    Returns:
        A CrossEncoder for MODEL_NAME, placed on CUDA when available,
        otherwise CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return CrossEncoder(MODEL_NAME, device=device)


def segment_sentences(text: str) -> list[str]:
    """Sentence segmentation with deduplication and filtering.

    Splits on sentence-final punctuation ([.!?] + whitespace) and on
    newlines, then drops fragments that are unlikely snippet material:
    short strings (< 20 chars), URL-ish lines, low-alpha content
    (metadata, tables, prices), questions, and case/whitespace-normalized
    duplicates.

    Args:
        text: Raw document text.

    Returns:
        Filtered sentences in original document order.
    """
    # Split on sentence boundaries AND newlines
    pattern = r'(?<=[.!?])\s+|\n+'
    raw_sentences = re.split(pattern, text)

    seen = set()
    sentences = []
    for s in raw_sentences:
        s = s.strip()
        if not s or len(s) < 20:
            continue
        if s.startswith('http') or s.startswith('URL:'):
            continue
        # Skip low-alpha content (metadata, tables, prices)
        alpha_ratio = sum(c.isalpha() for c in s) / max(len(s), 1)
        if alpha_ratio < 0.5:
            continue
        # Skip questions
        if s.endswith('?'):
            continue
        # Normalize case + whitespace so near-identical repeats dedupe.
        normalized = ' '.join(s.lower().split())
        if normalized in seen:
            continue
        seen.add(normalized)
        sentences.append(s)
    return sentences


def generate_snippet(query: str, document: str, model,
                     max_chars: int, max_sents: int) -> tuple[str, list]:
    """Generate a grounding snippet using Cross-Encoder relevance scoring.

    Scores every (query, sentence) pair, greedily takes the top-ranked
    sentences that fit the character/sentence budget, then re-stitches
    them in document order with "..." marking omitted content.

    Args:
        query: Search/fanout query.
        document: Full web-page text.
        model: CrossEncoder exposing .predict(list_of_pairs) -> scores.
        max_chars: Character budget for selected sentences (joiners and
            ellipses are not counted against it).
        max_sents: Maximum number of sentences in the snippet.

    Returns:
        (snippet, debug_info) where debug_info is up to five
        (score, sentence) pairs in descending score order; empty when no
        sentence fit the budget or the document produced no sentences.
    """
    sentences = segment_sentences(document)
    if not sentences:
        return "", []

    # Cross-encoder: score query-sentence pairs
    pairs = [[query, sent] for sent in sentences]
    scores = model.predict(pairs)
    ranked_indices = np.argsort(scores)[::-1]  # descending relevance

    # Greedy selection within the budget, highest score first.
    selected = []
    total_length = 0
    for idx in ranked_indices:
        sent = sentences[idx]
        if total_length + len(sent) <= max_chars and len(selected) < max_sents:
            selected.append((idx, sent, scores[idx]))
            total_length += len(sent)

    if not selected:
        # Every sentence alone exceeds the budget: truncate the top one.
        best_idx = ranked_indices[0]
        return sentences[best_idx][:max_chars] + "...", []

    # Sort by document order
    selected.sort(key=lambda x: x[0])

    # Stitch with ellipsis for gaps
    snippet_parts = []
    prev_idx = -1
    for idx, sent, _ in selected:
        if prev_idx >= 0 and idx > prev_idx + 1:
            snippet_parts.append("...")
        snippet_parts.append(sent)
        prev_idx = idx
    # FIX: content omitted *before* the first selected sentence was never
    # marked, unlike interior gaps and trailing content — add the leading
    # ellipsis for consistency with the other omission markers.
    if selected[0][0] > 0:
        snippet_parts.insert(0, "...")
    if prev_idx < len(sentences) - 1:
        snippet_parts.append("...")

    # Debug info: top-5 (score, sentence) pairs by relevance.
    debug_info = [(scores[ranked_indices[i]], sentences[ranked_indices[i]])
                  for i in range(min(5, len(ranked_indices)))]
    return " ".join(snippet_parts), debug_info


# --- Streamlit UI ---
st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")
st.markdown("""
How much of your page will be used to ground the model for a particular fanout query?
""")

query = st.text_input("Query", value="best prostate cancer treatment in the world")
document = st.text_area(
    "Web Page Text",
    height=250,
    placeholder="Paste document content here..."
)

with st.expander("Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)

if st.button("Generate Snippet", help="cross-encoder/ms-marco-electra-base"):
    if query and document:
        with st.spinner("Loading model & scoring sentences..."):
            model = load_model()
            snippet, debug = generate_snippet(query, document, model,
                                              max_chars, max_sents)
            st.subheader("Generated Snippet")
            st.code(snippet, language=None)
            if show_debug and debug:
                st.markdown("---")
                st.write("**Top sentences by score:**")
                for score, sent in debug:
                    st.text(f"{score:.4f}: {sent[:80]}...")
    else:
        st.warning("Please enter both a query and document.")