# Hugging Face Space "google-links" — src/streamlit_app.py
# (uploaded by dejanseo, commit c98b628, verified)
#!/usr/bin/env python3
# app.py
# Streamlit app for link detection with word-level highlighting
import html

import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
st.set_page_config(page_title="Link Detection", page_icon="πŸ”—", layout="centered")
st.logo(
"https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
size="large",
link="https://dejan.ai",
)
@st.cache_resource
def load_model(model_path="dejanseo/google-links"):
"""Load model and tokenizer."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(model_path)
model = model.to(device)
model.eval()
return tokenizer, model, device
def group_tokens_into_words(tokens, offset_mapping, link_probs):
"""Group tokens into words based on tokenizer patterns."""
words = []
current_word_tokens = []
current_word_offsets = []
current_word_probs = []
for i, (token, offsets, prob) in enumerate(zip(tokens, offset_mapping, link_probs)):
# Skip special tokens
if offsets == [0, 0]:
if current_word_tokens:
words.append({
'tokens': current_word_tokens,
'offsets': current_word_offsets,
'probs': current_word_probs
})
current_word_tokens = []
current_word_offsets = []
current_word_probs = []
continue
# Check if this is a new word or continuation
is_new_word = False
# DeBERTa uses ▁ for word boundaries
if token.startswith("▁"):
is_new_word = True
# BERT uses ## for subword continuation
elif i == 0 or not token.startswith("##"):
# If previous token exists and doesn't indicate continuation
if i == 0 or offset_mapping[i-1] == [0, 0]:
is_new_word = True
# Check if there's a gap between tokens (indicates new word)
elif current_word_offsets and offsets[0] > current_word_offsets[-1][1]:
is_new_word = True
if is_new_word and current_word_tokens:
# Save current word
words.append({
'tokens': current_word_tokens,
'offsets': current_word_offsets,
'probs': current_word_probs
})
current_word_tokens = []
current_word_offsets = []
current_word_probs = []
# Add token to current word
current_word_tokens.append(token)
current_word_offsets.append(offsets)
current_word_probs.append(prob)
# Add last word if exists
if current_word_tokens:
words.append({
'tokens': current_word_tokens,
'offsets': current_word_offsets,
'probs': current_word_probs
})
return words
def predict_links(text, tokenizer, model, device,
max_length=512, doc_stride=128):
"""Predict link tokens with word-level highlighting using sliding windows."""
if not text.strip():
return []
# Tokenize full text without truncation or special tokens
full_enc = tokenizer(
text,
add_special_tokens=False,
truncation=False,
return_offsets_mapping=True,
)
all_ids = full_enc["input_ids"]
all_offsets = full_enc["offset_mapping"]
n_tokens = len(all_ids)
# Accumulate probabilities per token position (for averaging overlaps)
prob_sums = [0.0] * n_tokens
prob_counts = [0] * n_tokens
# Sliding window parameters (matching training _prep.py)
specials = 2 # CLS + SEP for DeBERTa
cap = max_length - specials # 510 content tokens per window
step = max(cap - doc_stride, 1) # 382
# Generate windows and run inference
start = 0
while start < n_tokens:
end = min(start + cap, n_tokens)
window_ids = all_ids[start:end]
# Add special tokens (CLS + content + SEP)
cls_id = tokenizer.cls_token_id or tokenizer.bos_token_id or 1
sep_id = tokenizer.sep_token_id or tokenizer.eos_token_id or 2
input_ids = torch.tensor(
[[cls_id] + window_ids + [sep_id]],
device=device
)
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
probs = F.softmax(logits, dim=-1)[0].cpu()
# Skip special tokens (first and last) to get content probs
content_probs = probs[1:-1, 1].tolist()
# Map back to original token positions
for i, p in enumerate(content_probs):
orig_idx = start + i
if orig_idx < n_tokens:
prob_sums[orig_idx] += p
prob_counts[orig_idx] += 1
if end == n_tokens:
break
start += step
# Average probabilities across overlapping windows
link_probs = [
prob_sums[i] / prob_counts[i] if prob_counts[i] > 0 else 0.0
for i in range(n_tokens)
]
# Get tokens and offsets for word grouping
tokens = tokenizer.convert_ids_to_tokens(all_ids)
offset_mapping = [list(o) for o in all_offsets]
# Group tokens into words
words = group_tokens_into_words(tokens, offset_mapping, link_probs)
# Build word results with max confidence per word
# Opacity tiers: >=5% β†’ 1.0, >=4% β†’ 0.75, >=3% β†’ 0.5, >=2% β†’ 0.25
results = []
for word_group in words:
word_offsets = word_group['offsets']
word_probs = word_group['probs']
max_conf = max(word_probs)
if max_conf >= 0.02:
start = word_offsets[0][0]
end = word_offsets[-1][1]
if max_conf >= 0.05:
opacity = 1.0
elif max_conf >= 0.04:
opacity = 0.75
elif max_conf >= 0.03:
opacity = 0.5
else:
opacity = 0.25
results.append({
"start": start,
"end": end,
"opacity": opacity,
"confidence": round(max_conf, 4),
})
return results
def render_highlighted_text(text, word_results):
"""Render text with opacity-tiered green highlights."""
if not text:
return ""
# Sort spans by start position
spans = sorted(word_results, key=lambda x: x["start"])
html_parts = []
last_end = 0
for span in spans:
start, end, opacity = span["start"], span["end"], span["opacity"]
if start > last_end:
html_parts.append(text[last_end:start])
html_parts.append(
f'<span style="background-color: rgba(46, 125, 50, {opacity}); '
f'color: {"#fff" if opacity >= 0.75 else "#1A1A1A"}; padding: 2px 4px; '
f'border-radius: 3px; font-weight: 500;">{text[start:end]}</span>'
)
last_end = end
if last_end < len(text):
html_parts.append(text[last_end:])
html_content = "".join(html_parts)
return f"""
<div style="
padding: 20px;
background-color: #f8f9fa;
border-radius: 8px;
line-height: 1.8;
font-size: 16px;
white-space: pre-wrap;
word-wrap: break-word;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
">
{html_content}
</div>
"""
def main():
st.subheader("Google Link Model")
st.markdown(
"A transformer model trained by [DEJAN AI](https://dejan.ai/) that predicts which words should be hyperlinks. Trained on **10,273 pages from [Google's official blog](https://blog.google/)** β€” learning link placement directly from Google's own editorial decisions."
)
# Load model
try:
tokenizer, model, device = load_model()
#st.success(f"Model loaded on {device}")
except Exception as e:
st.error(f"Failed to load model: {e}")
return
# Text input
text = st.text_area("Input text:", height=200)
if st.button("Detect Links"):
if text:
word_results = predict_links(text, tokenizer, model, device)
# Display highlighted text
st.subheader("Text with Highlighted Links")
html = render_highlighted_text(text, word_results)
st.markdown(html, unsafe_allow_html=True)
# Show statistics
st.info(f"Found {len(word_results)} link candidates")
# Merge adjacent words into contiguous spans
if word_results:
sorted_results = sorted(word_results, key=lambda x: x["start"])
merged = []
cur = sorted_results[0].copy()
for nxt in sorted_results[1:]:
gap = text[cur["end"]:nxt["start"]]
if gap == "" or gap.strip() == "":
# Adjacent or separated only by whitespace β€” merge
cur["end"] = nxt["end"]
cur["confidence"] = (cur["confidence"] + nxt["confidence"]) / 2
else:
merged.append(cur)
cur = nxt.copy()
merged.append(cur)
st.subheader("Predicted Link Spans")
df = pd.DataFrame([
{
"Text": text[r["start"]:r["end"]],
"Confidence": f"{r['confidence']:.2%}",
}
for r in merged
])
st.dataframe(df, use_container_width=True, hide_index=True)
else:
st.warning("Please enter text")
if __name__ == "__main__":
main()