"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets.

Uses an MS MARCO Cross-Encoder for search relevance ranking: the pasted page
text is split into candidate sentences, each sentence is scored against the
query, and the best-scoring sentences that fit a character/sentence budget
are stitched back together in document order to form the snippet.
"""

import re
import html

import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder

# --- Configuration ---
MODEL_NAME = "cross-encoder/ms-marco-electra-base"
MAX_SNIPPET_CHARS = 450
MAX_SENTENCES = 5

# Sentence boundary: whitespace after terminal punctuation, or any run of newlines.
_SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+|\n+')

# Inline styles for the HTML views. st.html() does no markdown parsing, so all
# presentation (highlight colours, preserved line breaks) must be inline CSS.
_DIM_STYLE = "color: #8a8a8a;"
_HIGHLIGHT_STYLE = (
    "background-color: #c6f6d5; color: #14532d; "
    "border-radius: 3px; padding: 0 2px;"
)
_QUERY_STYLE = (
    "padding: 0.6rem 0.9rem; border-left: 4px solid #22c55e; "
    "background-color: rgba(127, 127, 127, 0.08); border-radius: 4px; "
    "font-weight: 600;"
)
# white-space: pre-wrap keeps the pasted document's line breaks visible.
_DOC_CONTAINER_STYLE = (
    "white-space: pre-wrap; line-height: 1.6; padding: 1rem; "
    "border: 1px solid rgba(127, 127, 127, 0.3); border-radius: 6px; "
    "max-height: 500px; overflow-y: auto;"
)

st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="✂️",
    layout="centered"
)

st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)

# --- Session State ---
# results_mode switches between the input form and the results view;
# snippet_data holds the inputs and output of the last analysis.
if "results_mode" not in st.session_state:
    st.session_state.results_mode = False
if "snippet_data" not in st.session_state:
    st.session_state.snippet_data = None


@st.cache_resource
def load_model():
    """Load the CrossEncoder model (cached by ``st.cache_resource``).

    Runs on CUDA when ``torch.cuda.is_available()``, otherwise on CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CrossEncoder(MODEL_NAME, device=device)
    return model


def segment_sentences(text: str) -> list[str]:
    """Split *text* into candidate sentences, dropping noise and duplicates.

    The text is split after ``.``, ``!`` or ``?`` followed by whitespace, and
    on newlines. A stripped fragment is discarded when it:

    * is shorter than 20 characters,
    * starts with ``http`` or ``URL:``,
    * is less than 50% alphabetic characters,
    * ends with ``?``,
    * duplicates an earlier fragment (case- and whitespace-insensitive).

    Returns the surviving sentences in document order.
    """
    seen = set()
    sentences = []
    for raw in _SENTENCE_SPLIT_RE.split(text):
        s = raw.strip()
        if len(s) < 20:  # also covers empty fragments
            continue
        if s.startswith(('http', 'URL:')):
            continue
        # len(s) >= 20 here, so no division-by-zero guard is needed.
        alpha_ratio = sum(c.isalpha() for c in s) / len(s)
        if alpha_ratio < 0.5:
            continue
        if s.endswith('?'):
            continue
        normalized = ' '.join(s.lower().split())
        if normalized in seen:
            continue
        seen.add(normalized)
        sentences.append(s)
    return sentences


def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> dict:
    """Generate a grounding snippet for *query* from *document*.

    Every sentence from ``segment_sentences`` is scored against the query via
    ``model.predict``. Sentences are then taken greedily in descending score
    order while they fit the budget: the running total of sentence characters
    must stay <= ``max_chars`` and at most ``max_sents`` sentences are taken.
    A sentence that would overflow the character budget is skipped, and a
    lower-ranked, shorter one may still be taken. The picks are re-ordered by
    document position and joined with spaces. ``"..."`` is inserted between
    two non-adjacent picks and after the last pick if it is not the final
    sentence.

    If no sentence fits the character budget, the best-scoring sentence is
    truncated to ``max_chars`` and suffixed with ``"..."``.

    The character budget counts sentence text only; separators and ellipses
    are not counted, so the snippet string may be somewhat longer.

    Returns:
        dict with keys ``snippet`` (str), ``selected_sentences`` (chosen
        sentences in document order), ``all_sentences`` (all candidates),
        ``scores`` (one float per candidate) and ``selected_indices``
        (positions of the chosen sentences in ``all_sentences``).
    """
    sentences = segment_sentences(document)
    if not sentences:
        # Same key set as the non-empty paths so callers can rely on it.
        return {
            "snippet": "",
            "selected_sentences": [],
            "all_sentences": [],
            "scores": [],
            "selected_indices": [],
        }

    pairs = [[query, sent] for sent in sentences]
    scores = np.asarray(model.predict(pairs), dtype=float)
    # Plain ints (not numpy ints) so the result dict holds only builtin types.
    ranked_indices = [int(i) for i in np.argsort(scores)[::-1]]

    selected_indices = []
    total_length = 0
    for idx in ranked_indices:
        if len(selected_indices) >= max_sents:
            break  # sentence budget exhausted; nothing further can be added
        sent_len = len(sentences[idx])
        if total_length + sent_len <= max_chars:
            selected_indices.append(idx)
            total_length += sent_len

    if not selected_indices:
        # Every sentence is longer than the budget: fall back to a truncated best match.
        best_idx = ranked_indices[0]
        truncated = sentences[best_idx][:max_chars]
        return {
            "snippet": truncated + "...",
            "selected_sentences": [truncated],
            "all_sentences": sentences,
            "scores": scores.tolist(),
            "selected_indices": [best_idx],
        }

    selected_indices.sort()
    selected_sentences = [sentences[i] for i in selected_indices]

    snippet_parts = []
    prev_idx = -1
    for idx in selected_indices:
        if prev_idx >= 0 and idx > prev_idx + 1:
            snippet_parts.append("...")  # skipped sentences between two picks
        snippet_parts.append(sentences[idx])
        prev_idx = idx
    if prev_idx < len(sentences) - 1:
        snippet_parts.append("...")  # document continues after the last pick

    return {
        "snippet": " ".join(snippet_parts),
        "selected_sentences": selected_sentences,
        "all_sentences": sentences,
        "scores": scores.tolist(),
        "selected_indices": selected_indices,
    }


def render_highlighted_html(document: str, selected_sentences: list[str]) -> str:
    """Render *document* as HTML with the selected sentences highlighted.

    Each sentence is located by exact substring search first. If that fails,
    a regex that tolerates any whitespace between the sentence's words is
    tried. Sentences that cannot be located are not highlighted.
    Overlapping or touching spans are merged.

    Selected spans are wrapped in a highlighted ``<span>`` and everything else
    in a dimmed ``<span>``. All document text passes through
    ``html.escape()``, so markup in the pasted page is shown as text rather
    than rendered.
    """
    # Locate (start, end) character ranges of the selected sentences.
    highlights = []
    for sent in selected_sentences:
        start = document.find(sent)
        if start != -1:
            highlights.append((start, start + len(sent)))
            continue
        sent_pattern = r'\s+'.join(re.escape(word) for word in sent.split())
        match = re.search(sent_pattern, document)
        if match:
            highlights.append((match.start(), match.end()))

    highlights.sort(key=lambda x: x[0])

    # Merge overlapping / touching ranges.
    merged = []
    for start, end in highlights:
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))

    # Emit alternating dimmed / highlighted spans with escaped text.
    parts = []
    pos = 0
    for start, end in merged:
        if pos < start:
            text = html.escape(document[pos:start])
            parts.append(f'<span style="{_DIM_STYLE}">{text}</span>')
        text = html.escape(document[start:end])
        parts.append(f'<span style="{_HIGHLIGHT_STYLE}">{text}</span>')
        pos = end
    if pos < len(document):
        text = html.escape(document[pos:])
        parts.append(f'<span style="{_DIM_STYLE}">{text}</span>')
    return "".join(parts)


def generate_regex_pattern(selected_sentences: list[str]) -> str:
    """Build a regex matching the selected sentences in order.

    Each sentence is matched literally (``re.escape``). Consecutive sentences
    may be separated by any text, matched lazily with ``[\\s\\S]*?``.
    Returns ``""`` when nothing was selected.
    """
    if not selected_sentences:
        return ""
    escaped = [re.escape(sent) for sent in selected_sentences]
    return r'[\s\S]*?'.join(escaped)


def reset_to_input():
    """Clear the stored analysis and switch the UI back to input mode."""
    st.session_state.results_mode = False
    st.session_state.snippet_data = None


# --- Main UI ---
st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")

if not st.session_state.results_mode:
    # === INPUT MODE ===
    st.write("How much of your page will be used to ground the model for a particular fanout query?")
    st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")

    query = st.text_input("Query", placeholder="enter a search query...")
    document = st.text_area(
        "Web Page Text",
        height=250,
        placeholder="Paste the full page content here..."
    )

    with st.expander("Settings"):
        max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
        max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)

    if st.button("Generate Snippet", type="primary"):
        # Reject whitespace-only input as well as empty input.
        if query.strip() and document.strip():
            with st.spinner("Loading model & scoring sentences..."):
                model = load_model()
                result = generate_snippet(query, document, model, max_chars, max_sents)
            st.session_state.snippet_data = {
                "query": query,
                "document": document,
                "result": result,
                "max_chars": max_chars,
                "max_sents": max_sents
            }
            st.session_state.results_mode = True
            st.rerun()
        else:
            st.warning("Please enter both a query and document.")
else:
    # === RESULTS MODE ===
    data = st.session_state.snippet_data
    query = data["query"]
    document = data["document"]
    result = data["result"]

    if st.button("← New Analysis"):
        reset_to_input()
        st.rerun()

    # Stats: characters of selected sentence text vs. the full document.
    snippet_chars = sum(len(s) for s in result["selected_sentences"])
    doc_chars = len(document)
    pct = (snippet_chars / doc_chars * 100) if doc_chars > 0 else 0
    st.caption(f"{snippet_chars:,} / {doc_chars:,} chars ({pct:.1f}%) • {len(result['selected_sentences'])} sentences")

    # Query - st.html with escaped text, so no markdown/HTML in the query is interpreted.
    st.html(f'<div style="{_QUERY_STYLE}">{html.escape(query)}</div>')

    # Generated snippet
    st.subheader("Generated Snippet")
    st.code(result["snippet"], wrap_lines=True, language=None)

    # Highlighted document - st.html() does NO markdown parsing
    highlighted = render_highlighted_html(document, result["selected_sentences"])
    st.html(f'''<div style="{_DOC_CONTAINER_STYLE}">{highlighted}</div>''')
    st.caption("🟢 Green = included in snippet")

    # Regex pattern
    with st.expander("📋 Regex Pattern"):
        regex = generate_regex_pattern(result["selected_sentences"])
        st.code(regex, language=None, wrap_lines=True)
        st.caption("Match selected snippets in other tools.")