#!/usr/bin/env python3 # app.py # Streamlit app for link detection with word-level highlighting import streamlit as st import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModelForTokenClassification import pandas as pd st.set_page_config(page_title="Link Detection", page_icon="🔗", layout="centered") st.logo( "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png", size="large", link="https://dejan.ai", ) @st.cache_resource def load_model(model_path="dejanseo/google-links"): """Load model and tokenizer.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = AutoModelForTokenClassification.from_pretrained(model_path) model = model.to(device) model.eval() return tokenizer, model, device def group_tokens_into_words(tokens, offset_mapping, link_probs): """Group tokens into words based on tokenizer patterns.""" words = [] current_word_tokens = [] current_word_offsets = [] current_word_probs = [] for i, (token, offsets, prob) in enumerate(zip(tokens, offset_mapping, link_probs)): # Skip special tokens if offsets == [0, 0]: if current_word_tokens: words.append({ 'tokens': current_word_tokens, 'offsets': current_word_offsets, 'probs': current_word_probs }) current_word_tokens = [] current_word_offsets = [] current_word_probs = [] continue # Check if this is a new word or continuation is_new_word = False # DeBERTa uses ▁ for word boundaries if token.startswith("▁"): is_new_word = True # BERT uses ## for subword continuation elif i == 0 or not token.startswith("##"): # If previous token exists and doesn't indicate continuation if i == 0 or offset_mapping[i-1] == [0, 0]: is_new_word = True # Check if there's a gap between tokens (indicates new word) elif current_word_offsets and offsets[0] > current_word_offsets[-1][1]: is_new_word = True if is_new_word and current_word_tokens: # Save current word words.append({ 'tokens': current_word_tokens, 'offsets': current_word_offsets, 'probs': current_word_probs }) current_word_tokens = [] current_word_offsets = [] current_word_probs = [] # Add token to current word current_word_tokens.append(token) current_word_offsets.append(offsets) current_word_probs.append(prob) # Add last word if exists if current_word_tokens: words.append({ 'tokens': current_word_tokens, 'offsets': current_word_offsets, 'probs': current_word_probs }) return words def predict_links(text, tokenizer, model, device, max_length=512, doc_stride=128): """Predict link tokens with word-level highlighting using sliding windows.""" if not text.strip(): return [] # Tokenize full text without truncation or special tokens full_enc = tokenizer( text, add_special_tokens=False, truncation=False, return_offsets_mapping=True, ) all_ids = full_enc["input_ids"] all_offsets = full_enc["offset_mapping"] n_tokens = len(all_ids) # Accumulate probabilities per token position (for averaging overlaps) prob_sums = [0.0] * n_tokens prob_counts = [0] * n_tokens # Sliding window parameters (matching training _prep.py) specials = 2 # CLS + SEP for DeBERTa cap = max_length - specials # 510 content tokens per window step = max(cap - doc_stride, 1) # 382 # Generate windows and run inference start = 0 while start < n_tokens: end = min(start + cap, n_tokens) window_ids = all_ids[start:end] # Add special tokens (CLS + content + SEP) cls_id = tokenizer.cls_token_id or tokenizer.bos_token_id or 1 sep_id = tokenizer.sep_token_id or tokenizer.eos_token_id or 2 input_ids = torch.tensor( [[cls_id] + window_ids + [sep_id]], device=device ) attention_mask = torch.ones_like(input_ids) with torch.no_grad(): logits = model(input_ids=input_ids, attention_mask=attention_mask).logits probs = F.softmax(logits, dim=-1)[0].cpu() # Skip special tokens (first and last) to get content probs content_probs = probs[1:-1, 1].tolist() # Map back to original token positions for i, p in enumerate(content_probs): orig_idx = start + i if orig_idx < n_tokens: prob_sums[orig_idx] += p prob_counts[orig_idx] += 1 if end == n_tokens: break start += step # Average probabilities across overlapping windows link_probs = [ prob_sums[i] / prob_counts[i] if prob_counts[i] > 0 else 0.0 for i in range(n_tokens) ] # Get tokens and offsets for word grouping tokens = tokenizer.convert_ids_to_tokens(all_ids) offset_mapping = [list(o) for o in all_offsets] # Group tokens into words words = group_tokens_into_words(tokens, offset_mapping, link_probs) # Build word results with max confidence per word # Opacity tiers: >=5% → 1.0, >=4% → 0.75, >=3% → 0.5, >=2% → 0.25 results = [] for word_group in words: word_offsets = word_group['offsets'] word_probs = word_group['probs'] max_conf = max(word_probs) if max_conf >= 0.02: start = word_offsets[0][0] end = word_offsets[-1][1] if max_conf >= 0.05: opacity = 1.0 elif max_conf >= 0.04: opacity = 0.75 elif max_conf >= 0.03: opacity = 0.5 else: opacity = 0.25 results.append({ "start": start, "end": end, "opacity": opacity, "confidence": round(max_conf, 4), }) return results def render_highlighted_text(text, word_results): """Render text with opacity-tiered green highlights.""" if not text: return "" # Sort spans by start position spans = sorted(word_results, key=lambda x: x["start"]) html_parts = [] last_end = 0 for span in spans: start, end, opacity = span["start"], span["end"], span["opacity"] if start > last_end: html_parts.append(text[last_end:start]) html_parts.append( f'= 0.75 else "#1A1A1A"}; padding: 2px 4px; ' f'border-radius: 3px; font-weight: 500;">{text[start:end]}' ) last_end = end if last_end < len(text): html_parts.append(text[last_end:]) html_content = "".join(html_parts) return f"""