dejanseo commited on
Commit
fd3e951
·
verified ·
1 Parent(s): fcc0e5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -16
app.py CHANGED
@@ -6,9 +6,11 @@ import logging
6
  from dataclasses import dataclass
7
  from typing import Optional, Dict, List, Tuple
8
 
9
- # --- HIDE STREAMLIT MENU ---
10
  st.set_page_config(
11
- initial_sidebar_state="collapsed"
 
 
12
  )
13
 
14
  hide_streamlit_style = """
@@ -86,8 +88,8 @@ def windowize_inference(
86
  total_tokens = len(full_encoding["input_ids"])
87
 
88
  if total_tokens == 0 and len(plain_text) > 0:
89
- logger.warning("Tokenizer produced 0 tokens for a non-empty string.")
90
- return []
91
 
92
  while start_token_idx < total_tokens:
93
  end_token_idx = min(start_token_idx + cap, total_tokens)
@@ -168,8 +170,8 @@ def classify_text(
168
 
169
  for i, word_id in enumerate(word_ids):
170
  if word_id is not None and i < len(offsets):
171
- start_char, end_char = offsets[i]
172
- if start_char < end_char:
173
  current_token_max_prob = np.max(char_link_probabilities[start_char:end_char]) if start_char < len(char_link_probabilities) else 0.0
174
 
175
  if word_id not in word_max_prob_map:
@@ -209,7 +211,8 @@ def classify_text(
209
  base_text_color = "#155724"
210
 
211
  html_parts.append(f"<span style='background-color: {base_bg_color}; color: {base_text_color}; "
212
- f"padding: 0.1em 0.2em; border-radius: 0.2em; opacity: {normalized_opacity:.2f};'>"
 
213
  f"{word_text}</span>")
214
  else:
215
  html_parts.append(word_text)
@@ -223,12 +226,11 @@ def classify_text(
223
  # ----------------------------------
224
  # Streamlit UI
225
  # ----------------------------------
226
- st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")
227
  st.title("LinkBERT")
228
 
229
  DEFAULT_THRESHOLD = 70.0
230
  THRESHOLD_STEP = 10.0
231
- THRESHOLD_BOUNDARY_PERCENT = 10.0
232
 
233
  if 'current_threshold' not in st.session_state:
234
  st.session_state.current_threshold = DEFAULT_THRESHOLD
@@ -258,24 +260,34 @@ def run_classification(new_threshold: float):
258
  st.warning("Please enter some text to classify.")
259
  st.session_state.output_html = ""
260
  else:
261
- with st.spinner("Processing..."):
262
  html, warning = classify_text(st.session_state.user_input, st.session_state.current_threshold)
263
  if warning: st.warning(warning)
264
  st.session_state.output_html = html
265
  st.rerun()
266
 
267
- if st.button("Classify Text", type="primary"):
268
  run_classification(slider_threshold)
269
 
270
  if st.session_state.output_html:
271
  st.markdown("---")
272
- st.subheader(f"Results (Threshold: {st.session_state.current_threshold:.1f}%)")
273
  st.markdown(st.session_state.output_html, unsafe_allow_html=True)
 
 
 
 
 
 
274
 
275
  col1, col2, col3 = st.columns(3)
276
 
277
  with col1:
278
- if st.button("Less", icon=":material/playlist_remove:", use_container_width=True, disabled=not st.session_state.output_html):
 
 
 
 
 
279
  current_thr = st.session_state.current_threshold
280
  if current_thr >= (100.0 - THRESHOLD_BOUNDARY_PERCENT):
281
  new_threshold = current_thr + (100.0 - current_thr) / 2.0
@@ -284,14 +296,24 @@ if st.session_state.output_html:
284
  run_classification(min(100.0, new_threshold))
285
 
286
  with col2:
287
- if st.button("Default", icon=":material/notes:", use_container_width=True, disabled=not st.session_state.output_html):
 
 
 
 
 
288
  run_classification(DEFAULT_THRESHOLD)
289
 
290
  with col3:
291
- if st.button("More", icon=":material/docs_add_on:", use_container_width=True, disabled=not st.session_state.output_html):
 
 
 
 
 
292
  current_thr = st.session_state.current_threshold
293
  if current_thr <= THRESHOLD_BOUNDARY_PERCENT:
294
  new_threshold = current_thr / 2.0
295
  else:
296
  new_threshold = current_thr - THRESHOLD_STEP
297
- run_classification(max(0.0, new_threshold))
 
6
  from dataclasses import dataclass
7
  from typing import Optional, Dict, List, Tuple
8
 
9
+ # --- HIDE STREAMLIT MENU / PAGE CONFIG ---
10
  st.set_page_config(
11
+ initial_sidebar_state="collapsed",
12
+ layout="wide",
13
+ page_title="LinkBERT by DEJAN AI"
14
  )
15
 
16
  hide_streamlit_style = """
 
88
  total_tokens = len(full_encoding["input_ids"])
89
 
90
  if total_tokens == 0 and len(plain_text) > 0:
91
+ logger.warning("Tokenizer produced 0 tokens for a non-empty string.")
92
+ return []
93
 
94
  while start_token_idx < total_tokens:
95
  end_token_idx = min(start_token_idx + cap, total_tokens)
 
170
 
171
  for i, word_id in enumerate(word_ids):
172
  if word_id is not None and i < len(offsets):
173
+ start_char, end_char = offsets[i]
174
+ if start_char < end_char:
175
  current_token_max_prob = np.max(char_link_probabilities[start_char:end_char]) if start_char < len(char_link_probabilities) else 0.0
176
 
177
  if word_id not in word_max_prob_map:
 
211
  base_text_color = "#155724"
212
 
213
  html_parts.append(f"<span style='background-color: {base_bg_color}; color: {base_text_color}; "
214
+ f"padding: 0.1em 0.2em; border-radius: 0.2em; opacity: {normalized_opacity:.2f};' "
215
+ f"title='Link Probability: {word_prob:.1%}'>"
216
  f"{word_text}</span>")
217
  else:
218
  html_parts.append(word_text)
 
226
  # ----------------------------------
227
  # Streamlit UI
228
  # ----------------------------------
 
229
  st.title("LinkBERT")
230
 
231
  DEFAULT_THRESHOLD = 70.0
232
  THRESHOLD_STEP = 10.0
233
+ THRESHOLD_BOUNDARY_PERCENT = 10.0 # Top/Bottom 10% for finer control
234
 
235
  if 'current_threshold' not in st.session_state:
236
  st.session_state.current_threshold = DEFAULT_THRESHOLD
 
260
  st.warning("Please enter some text to classify.")
261
  st.session_state.output_html = ""
262
  else:
263
+ with st.spinner("Analyzing text..."):
264
  html, warning = classify_text(st.session_state.user_input, st.session_state.current_threshold)
265
  if warning: st.warning(warning)
266
  st.session_state.output_html = html
267
  st.rerun()
268
 
269
+ if st.button("Classify Text", type="primary", use_container_width=True):
270
  run_classification(slider_threshold)
271
 
272
  if st.session_state.output_html:
273
  st.markdown("---")
 
274
  st.markdown(st.session_state.output_html, unsafe_allow_html=True)
275
+ st.markdown("---")
276
+
277
+ st.markdown(
278
+ f"<p style='text-align: center;'>Confidence Threshold: {st.session_state.current_threshold:.1f}%</p>",
279
+ unsafe_allow_html=True
280
+ )
281
 
282
  col1, col2, col3 = st.columns(3)
283
 
284
  with col1:
285
+ if st.button(
286
+ "Less",
287
+ icon=":material/playlist_remove:",
288
+ use_container_width=True,
289
+ help="Show fewer, more probable links"
290
+ ):
291
  current_thr = st.session_state.current_threshold
292
  if current_thr >= (100.0 - THRESHOLD_BOUNDARY_PERCENT):
293
  new_threshold = current_thr + (100.0 - current_thr) / 2.0
 
296
  run_classification(min(100.0, new_threshold))
297
 
298
  with col2:
299
+ if st.button(
300
+ "Default",
301
+ icon=":material/notes:",
302
+ use_container_width=True,
303
+ help="Reset to default threshold (70%)"
304
+ ):
305
  run_classification(DEFAULT_THRESHOLD)
306
 
307
  with col3:
308
+ if st.button(
309
+ "More",
310
+ icon=":material/docs_add_on:",
311
+ use_container_width=True,
312
+ help="Show more potential links"
313
+ ):
314
  current_thr = st.session_state.current_threshold
315
  if current_thr <= THRESHOLD_BOUNDARY_PERCENT:
316
  new_threshold = current_thr / 2.0
317
  else:
318
  new_threshold = current_thr - THRESHOLD_STEP
319
+ run_classification(max(0.0, new_threshold))