# NOTE: Hugging Face Spaces status banner ("Spaces: Running") — scrape artifact, not code.
| #!/usr/bin/env python3 | |
| # app.py | |
| # Streamlit app for link detection with word-level highlighting | |
import html

import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
| st.set_page_config(page_title="Link Detection", page_icon="π", layout="centered") | |
| st.logo( | |
| "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png", | |
| size="large", | |
| link="https://dejan.ai", | |
| ) | |
| def load_model(model_path="dejanseo/google-links"): | |
| """Load model and tokenizer.""" | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) | |
| model = AutoModelForTokenClassification.from_pretrained(model_path) | |
| model = model.to(device) | |
| model.eval() | |
| return tokenizer, model, device | |
| def group_tokens_into_words(tokens, offset_mapping, link_probs): | |
| """Group tokens into words based on tokenizer patterns.""" | |
| words = [] | |
| current_word_tokens = [] | |
| current_word_offsets = [] | |
| current_word_probs = [] | |
| for i, (token, offsets, prob) in enumerate(zip(tokens, offset_mapping, link_probs)): | |
| # Skip special tokens | |
| if offsets == [0, 0]: | |
| if current_word_tokens: | |
| words.append({ | |
| 'tokens': current_word_tokens, | |
| 'offsets': current_word_offsets, | |
| 'probs': current_word_probs | |
| }) | |
| current_word_tokens = [] | |
| current_word_offsets = [] | |
| current_word_probs = [] | |
| continue | |
| # Check if this is a new word or continuation | |
| is_new_word = False | |
| # DeBERTa uses β for word boundaries | |
| if token.startswith("β"): | |
| is_new_word = True | |
| # BERT uses ## for subword continuation | |
| elif i == 0 or not token.startswith("##"): | |
| # If previous token exists and doesn't indicate continuation | |
| if i == 0 or offset_mapping[i-1] == [0, 0]: | |
| is_new_word = True | |
| # Check if there's a gap between tokens (indicates new word) | |
| elif current_word_offsets and offsets[0] > current_word_offsets[-1][1]: | |
| is_new_word = True | |
| if is_new_word and current_word_tokens: | |
| # Save current word | |
| words.append({ | |
| 'tokens': current_word_tokens, | |
| 'offsets': current_word_offsets, | |
| 'probs': current_word_probs | |
| }) | |
| current_word_tokens = [] | |
| current_word_offsets = [] | |
| current_word_probs = [] | |
| # Add token to current word | |
| current_word_tokens.append(token) | |
| current_word_offsets.append(offsets) | |
| current_word_probs.append(prob) | |
| # Add last word if exists | |
| if current_word_tokens: | |
| words.append({ | |
| 'tokens': current_word_tokens, | |
| 'offsets': current_word_offsets, | |
| 'probs': current_word_probs | |
| }) | |
| return words | |
| def predict_links(text, tokenizer, model, device, | |
| max_length=512, doc_stride=128): | |
| """Predict link tokens with word-level highlighting using sliding windows.""" | |
| if not text.strip(): | |
| return [] | |
| # Tokenize full text without truncation or special tokens | |
| full_enc = tokenizer( | |
| text, | |
| add_special_tokens=False, | |
| truncation=False, | |
| return_offsets_mapping=True, | |
| ) | |
| all_ids = full_enc["input_ids"] | |
| all_offsets = full_enc["offset_mapping"] | |
| n_tokens = len(all_ids) | |
| # Accumulate probabilities per token position (for averaging overlaps) | |
| prob_sums = [0.0] * n_tokens | |
| prob_counts = [0] * n_tokens | |
| # Sliding window parameters (matching training _prep.py) | |
| specials = 2 # CLS + SEP for DeBERTa | |
| cap = max_length - specials # 510 content tokens per window | |
| step = max(cap - doc_stride, 1) # 382 | |
| # Generate windows and run inference | |
| start = 0 | |
| while start < n_tokens: | |
| end = min(start + cap, n_tokens) | |
| window_ids = all_ids[start:end] | |
| # Add special tokens (CLS + content + SEP) | |
| cls_id = tokenizer.cls_token_id or tokenizer.bos_token_id or 1 | |
| sep_id = tokenizer.sep_token_id or tokenizer.eos_token_id or 2 | |
| input_ids = torch.tensor( | |
| [[cls_id] + window_ids + [sep_id]], | |
| device=device | |
| ) | |
| attention_mask = torch.ones_like(input_ids) | |
| with torch.no_grad(): | |
| logits = model(input_ids=input_ids, attention_mask=attention_mask).logits | |
| probs = F.softmax(logits, dim=-1)[0].cpu() | |
| # Skip special tokens (first and last) to get content probs | |
| content_probs = probs[1:-1, 1].tolist() | |
| # Map back to original token positions | |
| for i, p in enumerate(content_probs): | |
| orig_idx = start + i | |
| if orig_idx < n_tokens: | |
| prob_sums[orig_idx] += p | |
| prob_counts[orig_idx] += 1 | |
| if end == n_tokens: | |
| break | |
| start += step | |
| # Average probabilities across overlapping windows | |
| link_probs = [ | |
| prob_sums[i] / prob_counts[i] if prob_counts[i] > 0 else 0.0 | |
| for i in range(n_tokens) | |
| ] | |
| # Get tokens and offsets for word grouping | |
| tokens = tokenizer.convert_ids_to_tokens(all_ids) | |
| offset_mapping = [list(o) for o in all_offsets] | |
| # Group tokens into words | |
| words = group_tokens_into_words(tokens, offset_mapping, link_probs) | |
| # Build word results with max confidence per word | |
| # Opacity tiers: >=5% β 1.0, >=4% β 0.75, >=3% β 0.5, >=2% β 0.25 | |
| results = [] | |
| for word_group in words: | |
| word_offsets = word_group['offsets'] | |
| word_probs = word_group['probs'] | |
| max_conf = max(word_probs) | |
| if max_conf >= 0.02: | |
| start = word_offsets[0][0] | |
| end = word_offsets[-1][1] | |
| if max_conf >= 0.05: | |
| opacity = 1.0 | |
| elif max_conf >= 0.04: | |
| opacity = 0.75 | |
| elif max_conf >= 0.03: | |
| opacity = 0.5 | |
| else: | |
| opacity = 0.25 | |
| results.append({ | |
| "start": start, | |
| "end": end, | |
| "opacity": opacity, | |
| "confidence": round(max_conf, 4), | |
| }) | |
| return results | |
| def render_highlighted_text(text, word_results): | |
| """Render text with opacity-tiered green highlights.""" | |
| if not text: | |
| return "" | |
| # Sort spans by start position | |
| spans = sorted(word_results, key=lambda x: x["start"]) | |
| html_parts = [] | |
| last_end = 0 | |
| for span in spans: | |
| start, end, opacity = span["start"], span["end"], span["opacity"] | |
| if start > last_end: | |
| html_parts.append(text[last_end:start]) | |
| html_parts.append( | |
| f'<span style="background-color: rgba(46, 125, 50, {opacity}); ' | |
| f'color: {"#fff" if opacity >= 0.75 else "#1A1A1A"}; padding: 2px 4px; ' | |
| f'border-radius: 3px; font-weight: 500;">{text[start:end]}</span>' | |
| ) | |
| last_end = end | |
| if last_end < len(text): | |
| html_parts.append(text[last_end:]) | |
| html_content = "".join(html_parts) | |
| return f""" | |
| <div style=" | |
| padding: 20px; | |
| background-color: #f8f9fa; | |
| border-radius: 8px; | |
| line-height: 1.8; | |
| font-size: 16px; | |
| white-space: pre-wrap; | |
| word-wrap: break-word; | |
| font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; | |
| "> | |
| {html_content} | |
| </div> | |
| """ | |
| def main(): | |
| st.subheader("Google Link Model") | |
| st.markdown( | |
| "A transformer model trained by [DEJAN AI](https://dejan.ai/) that predicts which words should be hyperlinks. Trained on **10,273 pages from [Google's official blog](https://blog.google/)** β learning link placement directly from Google's own editorial decisions." | |
| ) | |
| # Load model | |
| try: | |
| tokenizer, model, device = load_model() | |
| #st.success(f"Model loaded on {device}") | |
| except Exception as e: | |
| st.error(f"Failed to load model: {e}") | |
| return | |
| # Text input | |
| text = st.text_area("Input text:", height=200) | |
| if st.button("Detect Links"): | |
| if text: | |
| word_results = predict_links(text, tokenizer, model, device) | |
| # Display highlighted text | |
| st.subheader("Text with Highlighted Links") | |
| html = render_highlighted_text(text, word_results) | |
| st.markdown(html, unsafe_allow_html=True) | |
| # Show statistics | |
| st.info(f"Found {len(word_results)} link candidates") | |
| # Merge adjacent words into contiguous spans | |
| if word_results: | |
| sorted_results = sorted(word_results, key=lambda x: x["start"]) | |
| merged = [] | |
| cur = sorted_results[0].copy() | |
| for nxt in sorted_results[1:]: | |
| gap = text[cur["end"]:nxt["start"]] | |
| if gap == "" or gap.strip() == "": | |
| # Adjacent or separated only by whitespace β merge | |
| cur["end"] = nxt["end"] | |
| cur["confidence"] = (cur["confidence"] + nxt["confidence"]) / 2 | |
| else: | |
| merged.append(cur) | |
| cur = nxt.copy() | |
| merged.append(cur) | |
| st.subheader("Predicted Link Spans") | |
| df = pd.DataFrame([ | |
| { | |
| "Text": text[r["start"]:r["end"]], | |
| "Confidence": f"{r['confidence']:.2%}", | |
| } | |
| for r in merged | |
| ]) | |
| st.dataframe(df, use_container_width=True, hide_index=True) | |
| else: | |
| st.warning("Please enter text") | |
| if __name__ == "__main__": | |
| main() |