# google-links / app.py
# Source: dejanseo's Hugging Face Space — "Upload 22 files" (commit f29b6e6, verified)
#!/usr/bin/env python3
# app.py
# Streamlit app for link detection with word-level highlighting
import html
import json

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Configure the page; must be the first Streamlit command executed in the script.
st.set_page_config(page_title="Link Detection", page_icon="🔗")
@st.cache_resource
def load_model(model_path="model_link_token_cls"):
    """Load the tokenizer and token-classification model for link detection.

    Cached via Streamlit's resource cache so the checkpoint is only read
    from disk once per server process.

    Args:
        model_path: Directory containing the saved model/tokenizer.

    Returns:
        Tuple of (tokenizer, model, device); the model is moved to GPU
        when available and switched to eval mode.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(model_path).to(device)
    model.eval()  # inference only — disables dropout etc.
    return tokenizer, model, device
def group_tokens_into_words(tokens, offset_mapping, link_probs):
    """Merge subword tokens into word-level groups.

    Recognizes SentencePiece word starts ("▁" prefix, e.g. DeBERTa) and
    WordPiece continuations ("##" prefix, e.g. BERT); a character gap
    between consecutive offsets also starts a new word. Tokens whose
    offsets are [0, 0] are treated as special tokens and skipped.

    Args:
        tokens: Token strings aligned with `offset_mapping` and `link_probs`.
        offset_mapping: Per-token [start, end] character offsets.
        link_probs: Per-token link probabilities.

    Returns:
        List of dicts, each with 'tokens', 'offsets', and 'probs' lists
        covering one word.
    """
    grouped = []
    buf_tokens, buf_offsets, buf_probs = [], [], []

    def _flush():
        # Emit the word accumulated so far, if any, and reset the buffers.
        if buf_tokens:
            grouped.append({
                'tokens': list(buf_tokens),
                'offsets': list(buf_offsets),
                'probs': list(buf_probs),
            })
            buf_tokens.clear()
            buf_offsets.clear()
            buf_probs.clear()

    for idx, (tok, span, prob) in enumerate(zip(tokens, offset_mapping, link_probs)):
        # Special tokens carry a [0, 0] offset: close any open word and skip.
        if span == [0, 0]:
            _flush()
            continue

        starts_word = False
        if tok.startswith("▁"):
            # SentencePiece marks word starts explicitly.
            starts_word = True
        elif idx == 0 or not tok.startswith("##"):
            if idx == 0 or offset_mapping[idx - 1] == [0, 0]:
                # First content token, or first token after a special token.
                starts_word = True
            elif buf_offsets and span[0] > buf_offsets[-1][1]:
                # A character gap between tokens implies a word boundary.
                starts_word = True

        if starts_word:
            _flush()

        buf_tokens.append(tok)
        buf_offsets.append(span)
        buf_probs.append(prob)

    # Don't drop a word that runs to the end of the sequence.
    _flush()
    return grouped
def predict_links(text, tokenizer, model, device, threshold=0.5,
                  max_length=512, doc_stride=128):
    """Predict link tokens with word-level highlighting using sliding windows.

    The full text is tokenized once (no truncation); the model is then run
    over overlapping windows of at most `max_length` tokens with an overlap
    of `doc_stride`. Per-token link probabilities are averaged across the
    windows that cover them, tokens are grouped into words, and every word
    where ANY token meets `threshold` is reported.

    Args:
        text: Raw input text.
        tokenizer: Fast tokenizer matching the model (must provide offsets).
        model: Token-classification model already on `device`, in eval mode.
        device: torch device for inference tensors.
        threshold: Minimum per-token probability for a word to be flagged.
        max_length: Maximum window size in tokens, including special tokens.
        doc_stride: Token overlap between consecutive windows.

    Returns:
        (link_spans, link_details): list of (start, end) character tuples,
        and a list of dicts with text, offsets, and max/avg confidence.
    """
    if not text.strip():
        return [], []
    # Tokenize full text without truncation or special tokens
    full_enc = tokenizer(
        text,
        add_special_tokens=False,
        truncation=False,
        return_offsets_mapping=True,
    )
    all_ids = full_enc["input_ids"]
    all_offsets = full_enc["offset_mapping"]
    n_tokens = len(all_ids)
    # Accumulate probabilities per token position (for averaging overlaps)
    prob_sums = [0.0] * n_tokens
    prob_counts = [0] * n_tokens
    # Sliding window parameters (matching training _prep.py)
    specials = tokenizer.num_special_tokens_to_add(pair=False)  # 2 for DeBERTa
    cap = max_length - specials  # 510 content tokens per window
    step = max(cap - doc_stride, 1)  # 382
    # Generate windows and run inference
    start = 0
    while start < n_tokens:
        end = min(start + cap, n_tokens)
        window_ids = all_ids[start:end]
        # Add special tokens (CLS + content + SEP)
        input_ids = torch.tensor(
            [tokenizer.build_inputs_with_special_tokens(window_ids)],
            device=device
        )
        # All positions are real tokens (no padding), so the mask is all ones.
        attention_mask = torch.ones_like(input_ids)
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            probs = F.softmax(logits, dim=-1)[0].cpu()
            # Skip special tokens (first and last) to get content probs.
            # NOTE(review): assumes exactly one leading and one trailing
            # special token, and that class index 1 is the "link" label —
            # holds for DeBERTa-style models; confirm if the checkpoint changes.
            content_probs = probs[1:-1, 1].tolist()
        # Map back to original token positions
        for i, p in enumerate(content_probs):
            orig_idx = start + i
            if orig_idx < n_tokens:
                prob_sums[orig_idx] += p
                prob_counts[orig_idx] += 1
        if end == n_tokens:
            break
        start += step
    # Average probabilities across overlapping windows
    link_probs = [
        prob_sums[i] / prob_counts[i] if prob_counts[i] > 0 else 0.0
        for i in range(n_tokens)
    ]
    # Get tokens and offsets for word grouping
    tokens = tokenizer.convert_ids_to_tokens(all_ids)
    # Normalize offsets to lists so group_tokens_into_words' [0, 0] check works.
    offset_mapping = [list(o) for o in all_offsets]
    # Group tokens into words
    words = group_tokens_into_words(tokens, offset_mapping, link_probs)
    # Extract link spans - if ANY token in a word meets threshold, highlight entire word
    link_spans = []
    link_details = []
    for word_group in words:
        word_offsets = word_group['offsets']
        word_probs = word_group['probs']
        # Check if any token in the word meets the threshold
        if any(prob >= threshold for prob in word_probs):
            # Get the span of the entire word
            start = word_offsets[0][0]
            end = word_offsets[-1][1]
            link_spans.append((start, end))
            # Calculate max confidence for the word
            max_confidence = max(word_probs)
            avg_confidence = sum(word_probs) / len(word_probs)
            link_text = text[start:end]
            link_details.append({
                "text": link_text,
                "start": start,
                "end": end,
                "max_confidence": round(max_confidence, 4),
                "avg_confidence": round(avg_confidence, 4)
            })
    return link_spans, link_details
def render_highlighted_text(text, link_spans):
    """Render text as HTML with the given character spans highlighted.

    Fix: all text segments are passed through `html.escape` before being
    interpolated into the markup. The result is rendered with
    `st.markdown(..., unsafe_allow_html=True)`, so unescaped user input
    containing `<`, `>`, or `&` would previously break the layout and
    allow HTML/script injection (XSS).

    Args:
        text: The original input string.
        link_spans: Iterable of (start, end) character offsets to highlight.

    Returns:
        An HTML string (a styled wrapper div), or "" for empty input.
    """
    if not text:
        return ""
    # Process spans left-to-right so plain segments slot between highlights.
    link_spans = sorted(link_spans, key=lambda x: x[0])
    html_parts = []
    last_end = 0
    for start, end in link_spans:
        # Add (escaped) text before the link
        if start > last_end:
            html_parts.append(html.escape(text[last_end:start]))
        # Add highlighted link; the span text is escaped too.
        html_parts.append(
            f'<span style="background-color: #90EE90; padding: 2px 4px; '
            f'border-radius: 3px; font-weight: 500;">{html.escape(text[start:end])}</span>'
        )
        last_end = end
    # Add remaining text
    if last_end < len(text):
        html_parts.append(html.escape(text[last_end:]))
    html_content = "".join(html_parts)
    # Wrap in a styled div (pre-wrap preserves the input's line breaks).
    full_html = f"""
<div style="
    padding: 20px;
    background-color: #f8f9fa;
    border-radius: 8px;
    line-height: 1.8;
    font-size: 16px;
    white-space: pre-wrap;
    word-wrap: break-word;
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
">
{html_content}
</div>
"""
    return full_html
def main():
    """Streamlit entry point: threshold slider, text input, and results."""
    st.title("Link Detection")

    # Surface model-load failures in the UI instead of crashing the app.
    try:
        tokenizer, model, device = load_model()
        st.success(f"Model loaded on {device}")
    except Exception as err:
        st.error(f"Failed to load model: {err}")
        return

    # The slider works in whole percent; the model emits fractional probabilities.
    threshold = st.slider(
        "Confidence Threshold (%)",
        min_value=0,
        max_value=100,
        value=5,
        step=1,
        help="Highlights entire word if ANY of its tokens meet this threshold"
    ) / 100.0

    text = st.text_area("Input text:", height=200)

    if not st.button("Detect Links"):
        return
    if not text:
        st.warning("Please enter text")
        return

    link_spans, link_details = predict_links(text, tokenizer, model, device, threshold)

    # Highlighted rendering of the input.
    st.subheader("Text with Highlighted Links")
    st.markdown(render_highlighted_text(text, link_spans), unsafe_allow_html=True)

    # Summary line plus the raw per-word details.
    st.info(f"Found {len(link_details)} words with link confidence above {threshold:.0%}")
    if link_details:
        st.subheader("Link Details (JSON)")
        st.json(link_details)
# Run the app when executed as a script (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()