#!/usr/bin/env python3
# app.py
# Streamlit app for link detection with word-level highlighting
from html import escape

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
import json

st.set_page_config(page_title="Link Detection", page_icon="🔗")


@st.cache_resource
def load_model(model_path="model_link_token_cls"):
    """Load model and tokenizer.

    Cached by Streamlit so the model is loaded once per process.

    Args:
        model_path: Directory containing the fine-tuned token-classification
            checkpoint.

    Returns:
        Tuple of (tokenizer, model, device); the model is in eval mode and
        moved to CUDA when available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return tokenizer, model, device


def group_tokens_into_words(tokens, offset_mapping, link_probs):
    """Group subword tokens into words based on tokenizer patterns.

    Handles both SentencePiece-style boundaries ("▁" prefix, e.g. DeBERTa)
    and WordPiece-style continuations ("##" prefix, e.g. BERT). Tokens with
    a [0, 0] offset are treated as special tokens and skipped.

    Args:
        tokens: Token strings from ``convert_ids_to_tokens``.
        offset_mapping: Per-token ``[start, end]`` character offsets (lists).
        link_probs: Per-token link probability, aligned with ``tokens``.

    Returns:
        List of dicts with keys 'tokens', 'offsets', 'probs' — one per word.
    """
    words = []
    cur_tokens = []
    cur_offsets = []
    cur_probs = []

    def _flush():
        # Close out the word being accumulated, if any.
        nonlocal cur_tokens, cur_offsets, cur_probs
        if cur_tokens:
            words.append({
                'tokens': cur_tokens,
                'offsets': cur_offsets,
                'probs': cur_probs,
            })
            cur_tokens = []
            cur_offsets = []
            cur_probs = []

    for i, (token, offsets, prob) in enumerate(
        zip(tokens, offset_mapping, link_probs)
    ):
        # Special tokens carry a [0, 0] offset; flush and skip them.
        if offsets == [0, 0]:
            _flush()
            continue

        # Decide whether this token starts a new word or continues the
        # previous one.
        is_new_word = False
        if token.startswith("▁"):
            # DeBERTa/SentencePiece marks word starts with ▁.
            is_new_word = True
        elif i == 0 or not token.startswith("##"):
            # Non-continuation token: new word at sequence start or right
            # after a special token.
            if i == 0 or offset_mapping[i - 1] == [0, 0]:
                is_new_word = True
        elif cur_offsets and offsets[0] > cur_offsets[-1][1]:
            # A character gap between tokens indicates a word break.
            is_new_word = True

        if is_new_word:
            _flush()

        cur_tokens.append(token)
        cur_offsets.append(offsets)
        cur_probs.append(prob)

    # Flush the trailing word, if any.
    _flush()
    return words


def predict_links(text, tokenizer, model, device, threshold=0.5,
                  max_length=512, doc_stride=128):
    """Predict link tokens with word-level highlighting using sliding windows.

    Long inputs are split into overlapping windows (mirroring the training
    ``_prep.py`` setup); per-token probabilities are averaged across
    overlapping windows before word grouping.

    Args:
        text: Raw input text.
        tokenizer / model / device: As returned by ``load_model``.
        threshold: A word is flagged if ANY of its tokens reaches this
            link probability.
        max_length: Model context size including special tokens.
        doc_stride: Overlap (in content tokens) between windows.

    Returns:
        Tuple of (link_spans, link_details): character spans ``(start, end)``
        and per-word detail dicts with confidences.
    """
    if not text.strip():
        return [], []

    # Tokenize full text without truncation or special tokens so offsets
    # map directly onto `text`.
    full_enc = tokenizer(
        text,
        add_special_tokens=False,
        truncation=False,
        return_offsets_mapping=True,
    )
    all_ids = full_enc["input_ids"]
    all_offsets = full_enc["offset_mapping"]
    n_tokens = len(all_ids)

    # Accumulate probabilities per token position (for averaging overlaps).
    prob_sums = [0.0] * n_tokens
    prob_counts = [0] * n_tokens

    # Sliding window parameters (matching training _prep.py).
    specials = tokenizer.num_special_tokens_to_add(pair=False)  # 2 for DeBERTa
    cap = max_length - specials          # e.g. 510 content tokens per window
    step = max(cap - doc_stride, 1)      # e.g. 382

    # Generate windows and run inference.
    start = 0
    while start < n_tokens:
        end = min(start + cap, n_tokens)
        window_ids = all_ids[start:end]

        # Add special tokens (CLS + content + SEP).
        input_ids = torch.tensor(
            [tokenizer.build_inputs_with_special_tokens(window_ids)],
            device=device,
        )
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            logits = model(input_ids=input_ids,
                           attention_mask=attention_mask).logits
            probs = F.softmax(logits, dim=-1)[0].cpu()

        # Drop the leading/trailing special tokens; column 1 is the
        # "link" class probability.
        content_probs = probs[1:-1, 1].tolist()

        # Map window-local positions back to original token positions.
        for i, p in enumerate(content_probs):
            orig_idx = start + i
            if orig_idx < n_tokens:
                prob_sums[orig_idx] += p
                prob_counts[orig_idx] += 1

        if end == n_tokens:
            break
        start += step

    # Average probabilities across overlapping windows.
    link_probs = [
        prob_sums[i] / prob_counts[i] if prob_counts[i] > 0 else 0.0
        for i in range(n_tokens)
    ]

    # Get tokens and offsets for word grouping (lists so they compare
    # against [0, 0] inside group_tokens_into_words).
    tokens = tokenizer.convert_ids_to_tokens(all_ids)
    offset_mapping = [list(o) for o in all_offsets]

    # Group tokens into words.
    words = group_tokens_into_words(tokens, offset_mapping, link_probs)

    # Extract link spans — if ANY token in a word meets the threshold,
    # highlight the entire word.
    link_spans = []
    link_details = []
    for word_group in words:
        word_offsets = word_group['offsets']
        word_probs = word_group['probs']

        if any(prob >= threshold for prob in word_probs):
            # Span of the entire word in the original text.
            start = word_offsets[0][0]
            end = word_offsets[-1][1]
            link_spans.append((start, end))

            max_confidence = max(word_probs)
            avg_confidence = sum(word_probs) / len(word_probs)
            link_text = text[start:end]
            link_details.append({
                "text": link_text,
                "start": start,
                "end": end,
                "max_confidence": round(max_confidence, 4),
                "avg_confidence": round(avg_confidence, 4),
            })

    return link_spans, link_details


def render_highlighted_text(text, link_spans):
    """Render text as HTML with link spans visually highlighted.

    All raw text is HTML-escaped before interpolation, since the result is
    rendered with ``unsafe_allow_html=True``; without escaping, input
    containing ``<``/``&`` would break rendering or inject markup.

    Args:
        text: The original input text.
        link_spans: Iterable of ``(start, end)`` character spans to highlight.

    Returns:
        An HTML string (empty string for empty input).
    """
    if not text:
        return ""

    # Sort spans by start position so we can walk the text left-to-right.
    link_spans = sorted(link_spans, key=lambda x: x[0])

    # Build HTML with highlights.
    html_parts = []
    last_end = 0
    for start, end in link_spans:
        # Add (escaped) text before the link.
        if start > last_end:
            html_parts.append(escape(text[last_end:start]))
        # Add the highlighted link span.
        html_parts.append(
            '<span style="background-color: #fff176; '
            'border-radius: 3px; padding: 0 1px;">'
            f'{escape(text[start:end])}</span>'
        )
        last_end = end

    # Add remaining (escaped) text.
    if last_end < len(text):
        html_parts.append(escape(text[last_end:]))

    html_content = "".join(html_parts)

    # Wrap in a div; pre-wrap preserves the input's line breaks.
    full_html = (
        '<div style="line-height: 1.6; white-space: pre-wrap;">'
        f'{html_content}'
        '</div>'
    )
    return full_html


def main():
    """Streamlit entry point: load the model and drive the UI."""
    st.title("Link Detection")

    # Load model.
    try:
        tokenizer, model, device = load_model()
        st.success(f"Model loaded on {device}")
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return

    # Threshold slider (percent in the UI, fraction internally).
    threshold = st.slider(
        "Confidence Threshold (%)",
        min_value=0,
        max_value=100,
        value=5,
        step=1,
        help="Highlights entire word if ANY of its tokens meet this threshold"
    ) / 100.0

    # Text input.
    text = st.text_area("Input text:", height=200)

    if st.button("Detect Links"):
        if text:
            link_spans, link_details = predict_links(
                text, tokenizer, model, device, threshold
            )

            # Display highlighted text.
            st.subheader("Text with Highlighted Links")
            html = render_highlighted_text(text, link_spans)
            st.markdown(html, unsafe_allow_html=True)

            # Show statistics.
            st.info(f"Found {len(link_details)} words with link confidence above {threshold:.0%}")

            # Display JSON details.
            if link_details:
                st.subheader("Link Details (JSON)")
                st.json(link_details)
        else:
            st.warning("Please enter text")


if __name__ == "__main__":
    main()