"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets
Uses MS MARCO Cross-Encoder for search relevance ranking.
"""

import re
import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder

# --- Configuration ---
MODEL_NAME = "cross-encoder/ms-marco-electra-base"
MAX_SNIPPET_CHARS = 450
MAX_SENTENCES = 5

st.set_page_config(
    page_title="Snippet Generator by DEJAN AI",
    page_icon="✂️",
    layout="centered"
)

# st.set_page_config must be the first Streamlit command in the script,
# so it runs before st.logo.
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)

@st.cache_resource
def load_model():
    """Load CrossEncoder model."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CrossEncoder(MODEL_NAME, device=device)
    return model
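
# A minimal sketch of the cross-encoder call on its own. The query/passage
# strings below are invented for illustration, but CrossEncoder.predict is
# the real sentence-transformers API: it takes [query, passage] pairs and
# returns one relevance score per pair (higher = more relevant).
#
#     model = load_model()
#     scores = model.predict([
#         ["what is a grounding snippet", "Snippets ground the model's answer in page text."],
#         ["what is a grounding snippet", "Our office is open 9-5 on weekdays."],
#     ])
#     # scores is a 1-D numpy array; the first pair should outscore the second.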


def segment_sentences(text: str) -> list[str]:
    """Sentence segmentation with deduplication and filtering."""
    # Split on sentence boundaries AND newlines
    pattern = r'(?<=[.!?])\s+|\n+'
    raw_sentences = re.split(pattern, text)
    
    seen = set()
    sentences = []
    
    for s in raw_sentences:
        s = s.strip()
        
        if not s or len(s) < 20:
            continue
        
        if s.startswith('http') or s.startswith('URL:'):
            continue
        
        # Skip low-alpha content (metadata, tables, prices)
        alpha_ratio = sum(c.isalpha() for c in s) / max(len(s), 1)
        if alpha_ratio < 0.5:
            continue
        
        # Skip questions
        if s.endswith('?'):
            continue
            
        normalized = ' '.join(s.lower().split())
        if normalized in seen:
            continue
        seen.add(normalized)
        
        sentences.append(s)
    
    return sentences
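
# For intuition, a hypothetical run of the filters above (the input text is
# invented for this example):
#
#     segment_sentences(
#         "https://example.com/pricing\n"
#         "Price: $49 | SKU 8841\n"
#         "Is this product actually worth the money?\n"
#         "The filter keeps plain declarative sentences like this one."
#     )
#     # -> ["The filter keeps plain declarative sentences like this one."]
#     # The URL, the low-alpha price row, and the question are all dropped.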


def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
    """Generate snippet using Cross-Encoder scoring."""
    sentences = segment_sentences(document)
    
    if not sentences:
        return "", []
    
    # Cross-encoder: score query-sentence pairs
    pairs = [[query, sent] for sent in sentences]
    scores = model.predict(pairs)
    
    ranked_indices = np.argsort(scores)[::-1]
    
    # Greedy selection under the character/sentence budget. The budget counts
    # sentence characters only (not the " " and "..." joiners added below),
    # so the final snippet can run slightly over max_chars.
    selected = []
    total_length = 0
    
    for idx in ranked_indices:
        if len(selected) >= max_sents:
            break
        sent = sentences[idx]
        if total_length + len(sent) <= max_chars:
            selected.append((idx, sent, scores[idx]))
            total_length += len(sent)
    
    if not selected:
        best_idx = ranked_indices[0]
        return sentences[best_idx][:max_chars] + "...", []
    
    # Sort by document order
    selected.sort(key=lambda x: x[0])
    
    # Stitch with ellipsis for gaps
    snippet_parts = []
    prev_idx = -1
    
    for idx, sent, _ in selected:
        if prev_idx >= 0 and idx > prev_idx + 1:
            snippet_parts.append("...")
        snippet_parts.append(sent)
        prev_idx = idx
    
    if prev_idx < len(sentences) - 1:
        snippet_parts.append("...")
    
    # Debug info
    debug_info = [(scores[ranked_indices[i]], sentences[ranked_indices[i]]) 
                  for i in range(min(5, len(ranked_indices)))]
    
    return " ".join(snippet_parts), debug_info


# --- Streamlit UI ---
st.title("Grounding Snippet Generator", help="cross-encoder/ms-marco-electra-base")

st.write("How much of your page will be used to ground the model for a particular fanout query?")
st.write("Full Context: https://dejan.ai/blog/ai-search-filter/")

query = st.text_input("Query", placeholder="enter a search query...")

document = st.text_area(
    "Web Page Text", 
    height=250,
    placeholder="Paste the full page content here..."
)

with st.expander("Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)

if st.button("Generate Snippet", help="cross-encoder/ms-marco-electra-base"):
    if query and document:
        with st.spinner("Loading model & scoring sentences..."):
            model = load_model()
            snippet, debug = generate_snippet(query, document, model, max_chars, max_sents)
        
        st.subheader("Generated Snippet")
        st.code(snippet, language=None)
        
        if show_debug and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            for score, sent in debug:
                st.text(f"{score:.4f}: {sent[:80]}...")
    else:
        st.warning("Please enter both a query and document.")