File size: 4,942 Bytes
848b652
 
 
 
 
 
9ac4279
 
848b652
 
9ac4279
848b652
 
 
 
9ac4279
848b652
 
 
 
 
9ac4279
 
848b652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets
Uses MS MARCO Cross-Encoder for search relevance ranking.
"""

import re
import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder

# --- Configuration ---
MODEL_NAME = "cross-encoder/ms-marco-electra-base"  # HF cross-encoder fine-tuned on MS MARCO passage ranking
MAX_SNIPPET_CHARS = 450  # default character budget for the stitched snippet
MAX_SENTENCES = 5  # default cap on the number of sentences in a snippet

# Must run before any other Streamlit call in the script.
st.set_page_config(
    page_title="Snippet Generator",
    page_icon="βœ‚οΈ",
    layout="centered"
)


@st.cache_resource
def load_model():
    """Instantiate the CrossEncoder once per server process.

    ``st.cache_resource`` memoizes the model across reruns, so the
    download/initialization cost is paid only on the first call.
    Prefers GPU when available, otherwise falls back to CPU.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return CrossEncoder(MODEL_NAME, device=target_device)


def segment_sentences(text: str) -> list[str]:
    """Split *text* into candidate snippet sentences, dropping noise.

    Segments on sentence-ending punctuation and newlines, then filters out
    short fragments, URL lines, mostly-non-alphabetic content (tables,
    prices, metadata), and questions. Duplicate sentences (compared
    case- and whitespace-insensitively) are kept only once, preserving
    first-occurrence order.
    """
    # Boundaries: whitespace after ./!/? OR any run of newlines.
    boundary_re = r'(?<=[.!?])\s+|\n+'

    seen_keys: set[str] = set()
    kept: list[str] = []

    for candidate in re.split(boundary_re, text):
        candidate = candidate.strip()

        # Too short to carry meaning (also drops empty segments).
        if len(candidate) < 20:
            continue

        # Bare links and URL metadata lines.
        if candidate.startswith(('http', 'URL:')):
            continue

        # Mostly non-alphabetic content: tables, prices, markup.
        letter_count = sum(1 for ch in candidate if ch.isalpha())
        if letter_count / max(len(candidate), 1) < 0.5:
            continue

        # Questions rarely make good extractive snippet material.
        if candidate.endswith('?'):
            continue

        # Case/whitespace-insensitive dedup, first occurrence wins.
        key = ' '.join(candidate.lower().split())
        if key in seen_keys:
            continue
        seen_keys.add(key)
        kept.append(candidate)

    return kept


def generate_snippet(query: str, document: str, model, max_chars: int, max_sents: int) -> tuple[str, list]:
    """Build an extractive snippet for *query* from *document*.

    Scores every candidate sentence against the query with the
    cross-encoder *model*, greedily selects top-scoring sentences within
    the ``max_chars`` / ``max_sents`` budget, then stitches them back in
    document order with ``...`` marking elided content.

    Args:
        query: The search query to rank sentences against.
        document: Raw document text; segmented via ``segment_sentences``.
        model: An object with a ``predict(pairs) -> scores`` method
            (a sentence-transformers ``CrossEncoder``).
        max_chars: Character budget for the combined snippet sentences.
        max_sents: Maximum number of sentences to include.

    Returns:
        ``(snippet, debug_info)`` where ``debug_info`` is a list of up to
        five ``(score, sentence)`` pairs, best first. ``("", [])`` when
        the document yields no usable sentences.
    """
    sentences = segment_sentences(document)

    if not sentences:
        return "", []

    # Cross-encoder: score each (query, sentence) pair jointly.
    pairs = [[query, sent] for sent in sentences]
    scores = model.predict(pairs)

    ranked_indices = np.argsort(scores)[::-1]  # best first

    # Greedy selection: highest-scoring sentences that fit the budget.
    selected = []
    total_length = 0

    for idx in ranked_indices:
        if len(selected) >= max_sents:
            break  # sentence budget exhausted; nothing more can be added
        sent = sentences[idx]
        if total_length + len(sent) <= max_chars:
            selected.append((idx, sent, scores[idx]))
            total_length += len(sent)

    if not selected:
        # Every sentence alone exceeds the budget: hard-truncate the best one.
        best_idx = ranked_indices[0]
        return sentences[best_idx][:max_chars] + "...", []

    # Restore document order for readability.
    selected.sort(key=lambda x: x[0])

    # Stitch, marking elisions with "...". BUGFIX: the original marked only
    # internal gaps and trailing elision; with prev_idx = -1 the unguarded
    # gap check now also emits a leading "..." when the snippet starts
    # mid-document, matching the trailing behavior.
    snippet_parts = []
    prev_idx = -1

    for idx, sent, _ in selected:
        if idx > prev_idx + 1:
            snippet_parts.append("...")
        snippet_parts.append(sent)
        prev_idx = idx

    if prev_idx < len(sentences) - 1:
        snippet_parts.append("...")

    # Top sentences with scores for the optional debug panel.
    debug_info = [(scores[i], sentences[i]) for i in ranked_indices[:5]]

    return " ".join(snippet_parts), debug_info


# --- Streamlit UI ---
st.title("βœ‚οΈ Snippet Generator")
st.caption("Recreates Google Vertex AI / Gemini grounding-style snippets")

# Static explainer shown above the input widgets.
st.markdown("""
This tool generates extractive snippets from documents using a Cross-Encoder model trained on MS MARCO search relevance data.

**How it works:**
1. Segments document into sentences
2. Scores each sentence against your query using `cross-encoder/ms-marco-electra-base`
3. Selects top-scoring sentences within budget
4. Stitches them in document order with `...` for gaps
""")

st.markdown("---")

# Inputs: the search query and the document to extract from.
query = st.text_input("πŸ” Query", value="best prostate cancer treatment in the world")

document = st.text_area(
    "πŸ“„ Document", 
    height=250,
    placeholder="Paste document content here..."
)

# Tunables: sliders default to the module-level budget constants.
with st.expander("βš™οΈ Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)

if st.button("Generate Snippet", type="primary"):
    if query and document:
        # load_model() is cached via st.cache_resource, so only the first
        # click pays the model download/initialization cost.
        with st.spinner("Loading model & scoring sentences..."):
            model = load_model()
            snippet, debug = generate_snippet(query, document, model, max_chars, max_sents)
        
        st.subheader("Generated Snippet")
        st.code(snippet, language=None)
        
        # Optional score breakdown of the top-ranked sentences.
        if show_debug and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            for score, sent in debug:
                st.text(f"{score:.4f}: {sent[:80]}...")
    else:
        st.warning("Please enter both a query and document.")

st.markdown("---")
st.caption("Model: `cross-encoder/ms-marco-electra-base` | [GitHub](https://github.com/UKPLab/sentence-transformers)")