Update src/streamlit_app.py
src/streamlit_app.py (+98 -97)

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
import trafilatura

# Streamlit config
st.set_page_config(layout="wide", page_title="LinkBERT")

# Load tokenizer & model (avoid meta-tensor .to() issue)
MODEL_ID = "dejanseo/LinkBERT-XL"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

load_kwargs = {}
if torch.cuda.is_available():
    # Load directly onto GPU(s); do NOT call .to(...) afterward
    load_kwargs.update(dict(device_map="auto", torch_dtype=torch.float16))
else:
    # CPU load without meta tensors
    load_kwargs.update(dict(device_map=None))

model = AutoModelForTokenClassification.from_pretrained(MODEL_ID, **load_kwargs)
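# With device_map="auto" the checkpoint is dispatched by accelerate and may be
# sharded across devices, so it is not moved again with .to(...); inputs are sent
# to model.device at inference time instead (see process_text below).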
model.eval()

# Functions
def tokenize_with_indices(text: str):
    encoded = tokenizer.encode_plus(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=512
    )
    return encoded["input_ids"], encoded["offset_mapping"]

def fetch_and_extract_content(url: str):
    downloaded = trafilatura.fetch_url(url)
    # ... (unchanged extraction logic omitted from this diff view)
    return None

def process_text(inputs: str, confidence_threshold: float):
    max_chunk_length = 512 - 2  # safe window for special tokens
    words = inputs.split()
    chunk_texts = []
    current_chunk, current_length = [], 0
    for word in words:
        tok_len = len(tokenizer.tokenize(word))
        if tok_len + current_length > max_chunk_length:
            chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = tok_len
        else:
            current_chunk.append(word)
            current_length += tok_len
    if current_chunk:
        chunk_texts.append(" ".join(current_chunk))

    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
    reconstructed_text = ""
    original_position_offset = 0

    with torch.no_grad():
        for chunk in chunk_texts:
            input_ids, token_offsets = tokenize_with_indices(chunk)
            input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(model.device)

            outputs = model(input_ids_tensor)
            logits = outputs.logits  # [1, seq_len, num_labels]
            predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()

            word_info = {}
            for idx, (start, end) in enumerate(token_offsets):
                if idx == 0 or idx == len(token_offsets) - 1:
                    continue  # skip specials

                # Walk back to the start of the whole word this subtoken belongs to
                word_start = start
                while word_start > 0 and chunk[word_start - 1] != ' ':
                    word_start -= 1

                if word_start not in word_info:
                    word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}

                conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
                if predictions[idx] == 1 and conf_pct >= confidence_threshold:
                    word_info[word_start]["prediction"] = 1
                word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
                word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))

            last_end = 0
            for word_start in sorted(word_info.keys()):
                word_data = word_info[word_start]
                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
                    escaped = subtoken_text.replace("$", "\\$")
                    if last_end < subtoken_start:
                        reconstructed_text += chunk[last_end:subtoken_start]
                    if word_data["prediction"] == 1:
                        reconstructed_text += (
                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
                        )
                    else:
                        reconstructed_text += escaped
                    last_end = subtoken_end

                    df_data["Word"].append(escaped)
                    df_data["Prediction"].append(word_data["prediction"])
                    df_data["Confidence"].append(word_info[word_start]["confidence"])
                    df_data["Start"].append(subtoken_start + original_position_offset)
                    df_data["End"].append(subtoken_end + original_position_offset)

            original_position_offset += len(chunk) + 1

            reconstructed_text += chunk[last_end:].replace("$", "\\$")

    df_tokens = pd.DataFrame(df_data)
    return reconstructed_text, df_tokens

# UI
st.title("LinkBERT")
st.markdown("""
LinkBERT predicts natural link placement within web content. Enter text or a URL for extraction. Increase the threshold to reduce link predictions.
""")

confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:")
    if st.button("Process Text"):
        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
        st.markdown(highlighted_text, unsafe_allow_html=True)
        st.dataframe(df_tokens)

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button("Fetch and Process"):
        content = fetch_and_extract_content(url_input)
        if content:
            highlighted_text, df_tokens = process_text(content, confidence_threshold)
            # ... (unchanged display logic omitted from this diff view)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")

st.divider()
st.markdown("""
## Applications of LinkBERT
- **Anchor Text Suggestion**
- **Evaluation of Existing Links**
- **Link Placement Guide**
- **Anchor Text Idea Generator**
- **Spam and Inorganic SEO Detection**

## Training and Performance
LinkBERT was fine-tuned on a dataset of organic web content and editorial links.

[Watch the video](https://www.youtube.com/watch?v=A0ZulyVqjZo)

# Engage Our Team
Interested in using this in an automated pipeline for bulk link prediction?

Please [book an appointment](https://dejanmarketing.com/conference/).
""")
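
With streamlit, torch, transformers, pandas and trafilatura installed, the app starts with `streamlit run src/streamlit_app.py`. For a quick sanity check outside Streamlit, below is a minimal sketch of the same token-classification call; it assumes, as process_text() above does, that label 1 marks tokens predicted to sit inside a link, and the sample sentence is only illustrative.

# Minimal sketch: direct token-level inference with the same checkpoint.
# Assumption: label 1 marks link tokens, as in process_text() above.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification

MODEL_ID = "dejanseo/LinkBERT-XL"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
model.eval()

text = "Read our guide to internal linking for more details."  # illustrative input
enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    logits = model(**enc).logits  # [1, seq_len, num_labels]

probs = F.softmax(logits, dim=-1).squeeze(0)   # [seq_len, num_labels]
labels = probs.argmax(dim=-1)                  # predicted class per token
confidences = probs.max(dim=-1).values * 100   # confidence in percent

tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())
for token, label, conf in zip(tokens, labels.tolist(), confidences.tolist()):
    if token in tokenizer.all_special_tokens:
        continue
    print(f"{token:>15}  label={label}  confidence={conf:.1f}%")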