dejanseo committed on
Commit
848b652
·
verified ·
1 Parent(s): e8fb8fd

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +160 -34
src/streamlit_app.py CHANGED
@@ -1,40 +1,166 @@
1
- import altair as alt
 
 
 
 
 
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Snippet Generator - Recreates Google Vertex AI/Gemini grounding snippets
3
+ Uses MS MARCO Cross-Encoder for search relevance ranking.
4
+ """
5
+
6
+ import re
7
  import numpy as np
 
8
  import streamlit as st
9
+ import torch
10
+ from sentence_transformers import CrossEncoder
11
 
12
# --- Configuration ---
# Model identifier passed to CrossEncoder when the scoring model is loaded.
MODEL_NAME = "cross-encoder/ms-marco-electra-base"
# Default value of the "Max snippet characters" slider.
MAX_SNIPPET_CHARS = 450
# Default value of the "Max sentences" slider.
MAX_SENTENCES = 5

st.set_page_config(
    page_title="Snippet Generator",
    page_icon="✂️",  # was mojibake ("βœ‚οΈ"): UTF-8 scissors emoji decoded with the wrong codec
    layout="centered",
)
22
 
 
 
23
 
24
@st.cache_resource
def load_model():
    """Construct the CrossEncoder for ``MODEL_NAME`` and return it.

    The model is placed on ``"cuda"`` when ``torch.cuda.is_available()``
    reports a GPU, otherwise on ``"cpu"``.
    """
    gpu_available = torch.cuda.is_available()
    return CrossEncoder(MODEL_NAME, device="cuda" if gpu_available else "cpu")
30
+
31
+
32
# Split after sentence-ending punctuation followed by whitespace, or on any run of newlines.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+|\n+')


def segment_sentences(
    text: str,
    *,
    min_chars: int = 20,
    min_alpha_ratio: float = 0.5,
) -> list[str]:
    """Split *text* into candidate sentences, filtering noise and duplicates.

    A fragment is dropped when, after stripping surrounding whitespace, it:

    * is empty or shorter than ``min_chars`` characters;
    * starts with ``http`` or ``URL:``;
    * has a share of alphabetic characters below ``min_alpha_ratio``
      (metadata, tables, prices);
    * ends with ``?`` (questions are skipped);
    * repeats an earlier fragment, compared case-insensitively with runs of
      whitespace collapsed.

    Args:
        text: Raw document text.
        min_chars: Minimum length of a kept fragment.
        min_alpha_ratio: Minimum fraction of ``str.isalpha`` characters in a
            kept fragment.

    Returns:
        The surviving fragments in their original order.
    """
    if not text:
        return []

    seen: set[str] = set()
    sentences: list[str] = []

    for fragment in _SENTENCE_BOUNDARY.split(text):
        s = fragment.strip()

        # The emptiness check also protects the ratio division below when
        # a caller passes min_chars <= 0.
        if not s or len(s) < min_chars:
            continue

        if s.startswith(('http', 'URL:')):
            continue

        # Skip low-alpha content (metadata, tables, prices)
        alpha_ratio = sum(c.isalpha() for c in s) / len(s)
        if alpha_ratio < min_alpha_ratio:
            continue

        # Skip questions
        if s.endswith('?'):
            continue

        normalized = ' '.join(s.lower().split())
        if normalized in seen:
            continue
        seen.add(normalized)

        sentences.append(s)

    return sentences
67
+
68
+
69
def generate_snippet(
    query: str,
    document: str,
    model,
    max_chars: int,
    max_sents: int,
    debug_top_k: int = 5,
) -> tuple[str, list]:
    """Generate an extractive snippet of *document* for *query*.

    The document is segmented with ``segment_sentences``; every sentence is
    scored against the query with ``model.predict`` on ``[query, sentence]``
    pairs. Sentences are then taken greedily from highest to lowest score
    while they fit the budget, and finally stitched back together in
    document order with ``"..."`` marking gaps.

    Args:
        query: Search query to score sentences against.
        document: Raw document text.
        model: Object exposing ``predict(pairs)`` that returns one score per
            pair (the app passes a ``CrossEncoder``).
        max_chars: Budget for the summed length of the selected sentences.
            Joining spaces and ``"..."`` markers are not counted, so the
            returned string can be longer than this.
        max_sents: Maximum number of sentences to select.
        debug_top_k: How many top-ranked ``(score, sentence)`` pairs to
            return as debug info.

    Returns:
        ``(snippet, debug_info)``. ``snippet`` is ``""`` and ``debug_info``
        is ``[]`` when no usable sentence is found. If no sentence fits the
        budget, the best-scoring sentence is truncated to ``max_chars`` and
        suffixed with ``"..."``.
    """
    sentences = segment_sentences(document)
    if not sentences:
        return "", []

    # Cross-encoder: score query-sentence pairs
    pairs = [[query, sent] for sent in sentences]
    scores = np.asarray(model.predict(pairs))

    # Descending by score. A stable sort on the negated scores keeps tied
    # sentences in document order, so the result is deterministic.
    ranked_indices = np.argsort(-scores, kind="stable")

    # Debug info: the top-ranked sentences with their scores.
    debug_info = [
        (scores[i], sentences[i])
        for i in ranked_indices[:max(debug_top_k, 0)]
    ]

    # Select with budget: a sentence that does not fit is skipped, and
    # lower-ranked (shorter) sentences may still be taken.
    chosen = []
    total_length = 0
    for idx in ranked_indices:
        if len(chosen) >= max_sents:
            break
        sent_length = len(sentences[idx])
        if total_length + sent_length <= max_chars:
            chosen.append(int(idx))
            total_length += sent_length

    if not chosen:
        best_sentence = sentences[ranked_indices[0]]
        return best_sentence[:max_chars] + "...", debug_info

    # Sort by document order
    chosen.sort()

    # Stitch with ellipsis for gaps. Gaps are measured in the filtered
    # sentence list, not in the raw document.
    snippet_parts = []
    prev_idx = None
    for idx in chosen:
        if prev_idx is not None and idx > prev_idx + 1:
            snippet_parts.append("...")
        snippet_parts.append(sentences[idx])
        prev_idx = idx

    # Trailing ellipsis when the snippet stops before the last sentence.
    if chosen[-1] < len(sentences) - 1:
        snippet_parts.append("...")

    return " ".join(snippet_parts), debug_info
117
+
118
+
119
# --- Streamlit UI ---
# Page flow: intro text, query/document inputs, settings expander, then a
# button that runs the scoring pipeline and renders the snippet.
st.title("✂️ Snippet Generator")
st.caption("Recreates Google Vertex AI / Gemini grounding-style snippets")

# The model name is interpolated from MODEL_NAME so the text cannot drift
# from the model that is actually loaded.
st.markdown(f"""
This tool generates extractive snippets from documents using a Cross-Encoder model trained on MS MARCO search relevance data.

**How it works:**
1. Segments document into sentences
2. Scores each sentence against your query using `{MODEL_NAME}`
3. Selects top-scoring sentences within budget
4. Stitches them in document order with `...` for gaps
""")

st.markdown("---")

query = st.text_input("🔍 Query", value="best prostate cancer treatment in the world")

document = st.text_area(
    "📄 Document",
    height=250,
    placeholder="Paste document content here...",
)

with st.expander("⚙️ Settings"):
    max_chars = st.slider("Max snippet characters", 200, 1500, MAX_SNIPPET_CHARS, 50)
    max_sents = st.slider("Max sentences", 2, 15, MAX_SENTENCES)
    show_debug = st.checkbox("Show debug info", value=True)

if st.button("Generate Snippet", type="primary"):
    # Whitespace-only input counts as missing.
    query_text = query.strip()
    document_text = document.strip()

    if query_text and document_text:
        with st.spinner("Loading model & scoring sentences..."):
            model = load_model()
            snippet, debug = generate_snippet(
                query_text, document_text, model, max_chars, max_sents
            )

        if snippet:
            st.subheader("Generated Snippet")
            st.code(snippet, language=None)
        else:
            # generate_snippet returns "" when every sentence was filtered out.
            st.warning("No usable sentences were found in the document.")

        if show_debug and debug:
            st.markdown("---")
            st.write("**Top sentences by score:**")
            for score, sent in debug:
                # Only mark the preview as truncated when it really is.
                preview = sent if len(sent) <= 80 else f"{sent[:80]}..."
                st.text(f"{score:.4f}: {preview}")
    else:
        st.warning("Please enter both a query and document.")

st.markdown("---")
st.caption(f"Model: `{MODEL_NAME}` | [GitHub](https://github.com/UKPLab/sentence-transformers)")