|
|
|
|
|
|
|
|
|
|
|
|
|
|
import html
import json

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
|
|
|
|
# Configure the Streamlit page (browser tab title and icon).
st.set_page_config(page_title="Link Detection", page_icon="🔗")
|
|
|
|
|
|
@st.cache_resource
def load_model(model_path="model_link_token_cls"):
    """Load the tokenizer and token-classification model onto the best device.

    Cached by Streamlit (`st.cache_resource`) so the expensive load runs
    once per process rather than on every rerun.

    Args:
        model_path: Directory or hub id of the fine-tuned checkpoint.

    Returns:
        (tokenizer, model, device) — model is in eval mode on `device`.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(model_path).to(device)
    model.eval()

    return tokenizer, model, device
|
|
|
|
|
|
def group_tokens_into_words(tokens, offset_mapping, link_probs):
    """Group sub-word tokens back into whole words.

    A new word begins at a SentencePiece "▁" token, or at a token that is
    not a WordPiece "##" continuation and that either is the first token,
    follows a special token (offsets ``[0, 0]``), or leaves a character gap
    after the previous token. Special tokens themselves are dropped but
    terminate the word currently being built.

    Args:
        tokens: Token strings, aligned with the other two lists.
        offset_mapping: Per-token ``[start, end]`` character offsets;
            ``[0, 0]`` marks a special token.
        link_probs: Per-token link probabilities.

    Returns:
        List of dicts, each with 'tokens', 'offsets' and 'probs' lists
        covering one word.
    """
    words = []
    buf_tokens, buf_offsets, buf_probs = [], [], []

    def _flush():
        # Emit the buffered word (if any) and reset the buffers.
        if buf_tokens:
            words.append({
                'tokens': list(buf_tokens),
                'offsets': list(buf_offsets),
                'probs': list(buf_probs),
            })
            buf_tokens.clear()
            buf_offsets.clear()
            buf_probs.clear()

    for idx, (token, span, prob) in enumerate(zip(tokens, offset_mapping, link_probs)):
        # Special tokens carry no source text: close the word and skip them.
        if span == [0, 0]:
            _flush()
            continue

        starts_new = (
            token.startswith("▁")  # SentencePiece word boundary
            or (
                (idx == 0 or not token.startswith("##"))  # not a WordPiece continuation
                and (
                    idx == 0
                    or offset_mapping[idx - 1] == [0, 0]  # previous token was special
                    or (bool(buf_offsets) and span[0] > buf_offsets[-1][1])  # char gap
                )
            )
        )

        if starts_new:
            _flush()

        buf_tokens.append(token)
        buf_offsets.append(span)
        buf_probs.append(prob)

    _flush()

    return words
|
|
|
|
|
|
def predict_links(text, tokenizer, model, device, threshold=0.5,
                  max_length=512, doc_stride=128):
    """Predict link tokens with word-level highlighting using sliding windows.

    The full text is tokenized once without special tokens, then scored in
    overlapping windows of at most ``max_length`` tokens (consecutive
    windows overlap by ``doc_stride`` tokens). Per-token link probabilities
    from overlapping windows are averaged before tokens are grouped back
    into words.

    Args:
        text: Raw input string to scan for links.
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: Token-classification model; class index 1 is read as "link".
        device: torch device the model lives on.
        threshold: A word is reported when ANY of its tokens has a link
            probability >= this value.
        max_length: Maximum model input length, including special tokens.
        doc_stride: Token overlap between consecutive windows.

    Returns:
        (link_spans, link_details): character ``(start, end)`` tuples, and
        per-word dicts with text, offsets and rounded confidence scores.
    """
    if not text.strip():
        # Nothing to score in empty / whitespace-only input.
        return [], []

    # Tokenize the whole text once, keeping character offsets; special
    # tokens are added per-window later so offsets stay aligned with `text`.
    full_enc = tokenizer(
        text,
        add_special_tokens=False,
        truncation=False,
        return_offsets_mapping=True,
    )
    all_ids = full_enc["input_ids"]
    all_offsets = full_enc["offset_mapping"]
    n_tokens = len(all_ids)

    # Accumulators for averaging probabilities over overlapping windows.
    prob_sums = [0.0] * n_tokens
    prob_counts = [0] * n_tokens

    # Window capacity excludes the special tokens the tokenizer will add.
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials
    step = max(cap - doc_stride, 1)  # guard against a non-positive stride

    start = 0
    while start < n_tokens:
        end = min(start + cap, n_tokens)
        window_ids = all_ids[start:end]

        # Wrap the window with the model's special tokens (e.g. [CLS]/[SEP]).
        input_ids = torch.tensor(
            [tokenizer.build_inputs_with_special_tokens(window_ids)],
            device=device
        )
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            probs = F.softmax(logits, dim=-1)[0].cpu()

        # NOTE(review): the [1:-1] slice assumes exactly one special token at
        # each end of the sequence, and column 1 assumes label id 1 = "link"
        # — confirm both for the tokenizer/model checkpoint in use.
        content_probs = probs[1:-1, 1].tolist()

        # Fold window-local scores back into full-text token positions.
        for i, p in enumerate(content_probs):
            orig_idx = start + i
            if orig_idx < n_tokens:
                prob_sums[orig_idx] += p
                prob_counts[orig_idx] += 1

        if end == n_tokens:
            break  # last window reached the end of the text
        start += step

    # Average overlapping predictions (0.0 for tokens never scored).
    link_probs = [
        prob_sums[i] / prob_counts[i] if prob_counts[i] > 0 else 0.0
        for i in range(n_tokens)
    ]

    tokens = tokenizer.convert_ids_to_tokens(all_ids)
    # group_tokens_into_words compares offsets against [0, 0] lists, so
    # normalize the tokenizer's (start, end) tuples to lists.
    offset_mapping = [list(o) for o in all_offsets]

    words = group_tokens_into_words(tokens, offset_mapping, link_probs)

    link_spans = []
    link_details = []

    for word_group in words:
        word_offsets = word_group['offsets']
        word_probs = word_group['probs']

        # Highlight the whole word if ANY of its tokens crosses the threshold.
        if any(prob >= threshold for prob in word_probs):
            start = word_offsets[0][0]
            end = word_offsets[-1][1]
            link_spans.append((start, end))

            max_confidence = max(word_probs)
            avg_confidence = sum(word_probs) / len(word_probs)

            link_text = text[start:end]
            link_details.append({
                "text": link_text,
                "start": start,
                "end": end,
                "max_confidence": round(max_confidence, 4),
                "avg_confidence": round(avg_confidence, 4)
            })

    return link_spans, link_details
|
|
|
|
|
|
def render_highlighted_text(text, link_spans):
    """Render text as an HTML block with the detected link spans highlighted.

    Args:
        text: The original input string.
        link_spans: Iterable of (start, end) character offsets into ``text``.

    Returns:
        An HTML string ("" for empty input) suitable for
        ``st.markdown(..., unsafe_allow_html=True)``. All user text is
        HTML-escaped, so markup or script in the input is shown literally
        instead of being interpreted by the browser (fixes broken rendering
        and an XSS vector in the previous version).
    """
    if not text:
        return ""

    # Process spans left-to-right so the untouched gaps between them can be
    # emitted in order.
    link_spans = sorted(link_spans, key=lambda x: x[0])

    html_parts = []
    last_end = 0

    for start, end in link_spans:
        if start > last_end:
            html_parts.append(html.escape(text[last_end:start]))

        # Escape the span text too: a detected "link" may contain &, < or >.
        html_parts.append(
            f'<span style="background-color: #90EE90; padding: 2px 4px; '
            f'border-radius: 3px; font-weight: 500;">{html.escape(text[start:end])}</span>'
        )
        last_end = end

    if last_end < len(text):
        html_parts.append(html.escape(text[last_end:]))

    html_content = "".join(html_parts)

    # pre-wrap preserves the input's line breaks inside the styled container.
    full_html = f"""
    <div style="
        padding: 20px;
        background-color: #f8f9fa;
        border-radius: 8px;
        line-height: 1.8;
        font-size: 16px;
        white-space: pre-wrap;
        word-wrap: break-word;
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    ">
    {html_content}
    </div>
    """

    return full_html
|
|
|
|
|
|
def main():
    """Streamlit entry point: interactive UI for link detection."""
    st.title("Link Detection")

    # Load (or reuse the cached) model; bail out of the app on failure.
    try:
        tokenizer, model, device = load_model()
        st.success(f"Model loaded on {device}")
    except Exception as exc:
        st.error(f"Failed to load model: {exc}")
        return

    # Slider works in whole percent; the model consumes a 0..1 probability.
    percent = st.slider(
        "Confidence Threshold (%)",
        min_value=0,
        max_value=100,
        value=5,
        step=1,
        help="Highlights entire word if ANY of its tokens meet this threshold",
    )
    threshold = percent / 100.0

    text = st.text_area("Input text:", height=200)

    if not st.button("Detect Links"):
        return

    if not text:
        st.warning("Please enter text")
        return

    spans, details = predict_links(text, tokenizer, model, device, threshold)

    st.subheader("Text with Highlighted Links")
    st.markdown(render_highlighted_text(text, spans), unsafe_allow_html=True)

    st.info(f"Found {len(details)} words with link confidence above {threshold:.0%}")

    if details:
        st.subheader("Link Details (JSON)")
        st.json(details)
|
|
|
|
|
|
# Script entry point (e.g. `streamlit run <this file>`).
if __name__ == "__main__":
    main()