File size: 16,247 Bytes
26536bf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 | import streamlit as st
import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2TokenizerFast, DebertaV2Config, AutoTokenizer
from pathlib import Path
import numpy as np
import json
import logging
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple
from tqdm import tqdm
from skimage.filters import threshold_otsu
# ----------------------------------
# Logging
# ----------------------------------
# Root handler at INFO; basicConfig is a no-op when the root logger already
# has handlers, so an embedding host's configuration is left alone.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)  # module-scoped logger used throughout
# ----------------------------------
# Config / Model
# ----------------------------------
@dataclass
class TrainingConfig:
    """Settings for the link-token classifier.

    The inference app reads ``model_name``, ``num_labels``, ``max_length``,
    ``doc_stride``, ``device`` and ``output_dir``; every other field is a
    placeholder for training-side settings and is unused here.
    """

    # Model identity
    model_name: str = "microsoft/deberta-v3-large"
    num_labels: int = 2  # class 0 = ordinary token, class 1 = link token

    # Sliding-window tokenisation used at inference time
    max_length: int = 512  # tokens per window, special tokens included
    doc_stride: int = 128  # token overlap between windows; keep equal to _prep.py

    # Training hyper-parameters (placeholders for this app)
    train_file: str = ""
    val_file: str = ""
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    learning_rate: float = 1e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    label_smoothing: float = 0.0
    # Evaluated once, when the class is defined; inference moves the model here.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 0
    bf16: bool = False
    seed: int = 42
    logging_steps: int = 1
    eval_steps: int = 100
    save_steps: int = 100
    output_dir: str = "./deberta_link_output"  # final_model/ checkpoint is read from here
    wandb_project: str = ""
    wandb_name: str = ""
    patience: int = 2
    min_delta: float = 0.0001
class DeBERTaForTokenClassification(nn.Module):
    """DeBERTa-v2/v3 encoder followed by dropout and a linear per-token head.

    Args:
        model_name: checkpoint id or path handed to ``DebertaV2Model.from_pretrained``.
        num_labels: number of classes predicted for every token.
        dropout_rate: dropout probability applied to the encoder output.
    """

    def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1):
        super().__init__()
        self.deberta = DebertaV2Model.from_pretrained(model_name)
        # The encoder already carries its config; reuse it rather than
        # loading the same config a second time.
        self.config = self.deberta.config
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Score every token of the batch.

        Returns:
            ``{'loss': ..., 'logits': ...}``. ``logits`` has a trailing
            ``num_labels`` dimension. ``loss`` is the mean token-level
            cross-entropy when *labels* is given (positions labelled -100 are
            ignored), otherwise None.
        """
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten batch and sequence dims. Assumes labels holds one class
            # id per token with the same leading shape as input_ids.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {'loss': loss, 'logits': logits}
# ----------------------------------
# Load model/tokenizer (robust)
# ----------------------------------
def _extract_state_dict(checkpoint) -> Dict[str, torch.Tensor]:
    """Pull the parameter mapping out of a loaded checkpoint object.

    Accepts either a raw state_dict (a non-empty dict whose values are all
    tensors) or a dict wrapping one under ``'model_state_dict'`` or
    ``'state_dict'``. If every key starts with ``'module.'`` (the layout
    written when a model is saved while wrapped in ``nn.DataParallel`` /
    DDP), that prefix is stripped so keys match this module's parameter names.

    Raises:
        TypeError: *checkpoint* is not a dict.
        KeyError: *checkpoint* is a dict in none of the recognised layouts.
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(f"Unexpected checkpoint type: {type(checkpoint)}")

    if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
        # Case A: raw state_dict (keys -> tensors)
        state_dict = checkpoint
        logger.info("Detected raw state_dict checkpoint.")
    elif 'model_state_dict' in checkpoint and isinstance(checkpoint['model_state_dict'], dict):
        # Case B: wrapped dicts
        state_dict = checkpoint['model_state_dict']
        logger.info("Detected 'model_state_dict' in checkpoint.")
    elif 'state_dict' in checkpoint and isinstance(checkpoint['state_dict'], dict):
        state_dict = checkpoint['state_dict']
        logger.info("Detected 'state_dict' in checkpoint.")
    else:
        raise KeyError(f"Unrecognized checkpoint format keys: {list(checkpoint.keys())}")

    prefix = "module."
    if state_dict and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict):
        logger.info("Stripping 'module.' prefix from checkpoint keys.")
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


@st.cache_resource
def load_model():
    """Load the fine-tuned classifier and its tokenizer (cached by Streamlit).

    Reads ``<output_dir>/final_model/pytorch_model.bin``. Raw and wrapped
    state_dict checkpoints are both accepted, as are DataParallel-prefixed keys.

    Returns:
        ``(model, tokenizer, device, max_length, doc_stride)``. The model is
        moved to ``device`` and switched to eval mode.

    Raises:
        TypeError / KeyError: the checkpoint layout is not recognised.
        RuntimeError: the checkpoint has no classification-head weights.
    """
    config = TrainingConfig()
    final_dir = Path(config.output_dir) / "final_model"
    model_path = final_dir / "pytorch_model.bin"
    if not model_path.exists():
        st.error(f"Model checkpoint not found at {model_path}.")
        st.stop()
    logger.info(f"Loading model from {model_path}...")
    model = DeBERTaForTokenClassification(config.model_name, config.num_labels)

    # SECURITY: weights_only=False unpickles arbitrary Python objects. This is
    # acceptable only because the checkpoint comes from our own training run;
    # never point output_dir at files from an untrusted source.
    try:
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    state_dict = _extract_state_dict(checkpoint)

    # strict=False: key mismatches are logged rather than fatal ...
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning(f"Missing keys: {missing}")
    if unexpected:
        logger.warning(f"Unexpected keys: {unexpected}")
    # ... except for the head. It is freshly Xavier-initialised in __init__,
    # so predicting without trained head weights would silently yield noise.
    head_missing = [k for k in ("classifier.weight", "classifier.bias") if k in missing]
    if head_missing:
        raise RuntimeError(
            f"Checkpoint {model_path} is missing classification-head weights: {head_missing}"
        )

    model.to(config.device)
    model.eval()
    logger.info(f"Loading tokenizer {config.model_name}...")
    tokenizer = DebertaV2TokenizerFast.from_pretrained(config.model_name)
    logger.info("Tokenizer loaded.")
    return model, tokenizer, config.device, config.max_length, config.doc_stride
# Resolved once per script run; st.cache_resource hands back the same objects
# on every Streamlit rerun. classify_text reads these module globals.
(
    model,
    tokenizer,
    device,
    MAX_LENGTH,
    DOC_STRIDE,
) = load_model()
# ----------------------------------
# Inference helpers
# ----------------------------------
def windowize_inference(
    plain_text: str,
    tokenizer: "AutoTokenizer",
    max_length: int,
    doc_stride: int
) -> List[Dict]:
    """Slice long text into overlapping, padded windows for inference.

    Args:
        plain_text: raw text to tokenize (a fast tokenizer is required for
            offset mappings and ``word_ids``).
        tokenizer: the model's tokenizer.
        max_length: total window length, special tokens and padding included.
        doc_stride: token overlap between consecutive windows; the window
            start advances by ``max(cap - doc_stride, 1)`` content tokens.

    Returns:
        One dict per window, in text order:
        - ``input_ids`` / ``attention_mask``: 1-D long tensors of length
          ``max_length``.
        - ``word_ids`` / ``offset_mapping``: lists aligned with
          ``input_ids``. Special-token and padding positions hold
          ``None`` / ``(0, 0)``.

        Text that produces no tokens yields an empty list.

    Raises:
        ValueError: ``max_length`` leaves no room for content tokens.
    """
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials  # content tokens that fit in one window
    if cap <= 0:
        raise ValueError(f"max_length too small; specials={specials}")

    full_encoding = tokenizer(
        plain_text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        return_attention_mask=False,
        return_token_type_ids=False,
        truncation=False,
    )
    input_ids_no_special = full_encoding["input_ids"]
    offsets_no_special = full_encoding["offset_mapping"]
    # Word ids MUST come from this same special-token-free encoding so that
    # index k refers to the same token in all three lists. Re-encoding with
    # special tokens prepends [CLS], which shifts every word id by one slot
    # and left the first real token of the text unscored.
    full_word_ids = full_encoding.word_ids(batch_index=0)

    windows_data = []
    step = max(cap - doc_stride, 1)
    total_tokens_no_special = len(input_ids_no_special)
    start_token_idx = 0
    while start_token_idx < total_tokens_no_special:
        end_token_idx = min(start_token_idx + cap, total_tokens_no_special)
        ids_slice_no_special = input_ids_no_special[start_token_idx:end_token_idx]

        input_ids_with_special = tokenizer.build_inputs_with_special_tokens(ids_slice_no_special)
        attention_mask_with_special = [1] * len(input_ids_with_special)
        padding_length = max_length - len(input_ids_with_special)
        if padding_length > 0:
            input_ids_with_special.extend([tokenizer.pad_token_id] * padding_length)
            attention_mask_with_special.extend([0] * padding_length)

        # Align offsets/word ids with input_ids_with_special.
        # NOTE(review): assumes the special tokens are exactly one leading
        # [CLS] and one trailing [SEP], as DeBERTa's tokenizer produces;
        # confirm before swapping in another model family.
        window_offset_mapping = list(offsets_no_special[start_token_idx:end_token_idx])
        window_word_ids = list(full_word_ids[start_token_idx:end_token_idx])
        if tokenizer.cls_token_id is not None:
            window_offset_mapping.insert(0, (0, 0))
            window_word_ids.insert(0, None)
        if tokenizer.sep_token_id is not None and len(window_offset_mapping) < max_length:
            window_offset_mapping.append((0, 0))
            window_word_ids.append(None)
        while len(window_offset_mapping) < max_length:
            window_offset_mapping.append((0, 0))
            window_word_ids.append(None)

        windows_data.append({
            "input_ids": torch.tensor(input_ids_with_special, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask_with_special, dtype=torch.long),
            "word_ids": window_word_ids,
            "offset_mapping": window_offset_mapping,
        })
        if end_token_idx == total_tokens_no_special:
            break
        start_token_idx += step
    return windows_data
def _select_link_threshold(
    token_probs: List[float],
    otsu_mode: str,
    prediction_threshold_override: Optional[float],
) -> Tuple[float, float]:
    """Choose the link-probability cut-off; returns ``(threshold, threshold_percent)``.

    A manual override (given in percent) wins. Otherwise Otsu's method is
    applied to the per-token link probabilities. The result is raised by 0.1
    for 'conservative', lowered by 0.1 for 'generous', and left as-is for any
    other mode, then clamped to [0, 1]. With one or fewer tokens, or when
    Otsu raises ValueError, the threshold falls back to 0.5.
    """
    if prediction_threshold_override is not None:
        return prediction_threshold_override / 100.0, prediction_threshold_override
    if len(token_probs) <= 1:
        logger.warning("Insufficient tokens for Otsu; defaulting to 0.5.")
        return 0.5, 50.0
    try:
        otsu_base_threshold = threshold_otsu(np.array(token_probs))
    except ValueError:
        logger.warning("Otsu failed; defaulting to 0.5.")
        return 0.5, 50.0
    # Positive shift = stricter (fewer links), negative = more lenient.
    mode_delta = {'conservative': 0.1, 'generous': -0.1}.get(otsu_mode, 0.0)
    threshold = max(0.0, min(1.0, otsu_base_threshold + mode_delta))
    return threshold, threshold * 100


def _aggregate_word_probabilities(
    text: str,
    text_tokenizer,
    char_link_probabilities: np.ndarray,
    char_covered: np.ndarray,
) -> Tuple[Dict[int, List[int]], Dict[int, List[float]]]:
    """Group character-level link scores by tokenizer word.

    Returns ``(word_char_spans, word_prob_map)``, both keyed by word id.
    ``word_char_spans`` holds each word's ``[start, end)`` character span.
    ``word_prob_map`` holds the max link probability of each scored token.
    A word none of whose tokens was scored gets ``[0.0]`` and the span of
    its first token.
    """
    encoding = text_tokenizer(text, return_offsets_mapping=True, truncation=False, padding=False)
    word_ids = encoding.word_ids(batch_index=0)
    offsets = encoding['offset_mapping']

    word_prob_map: Dict[int, List[float]] = {}
    word_char_spans: Dict[int, List[int]] = {}
    for word_id, (start_char, end_char) in zip(word_ids, offsets):
        if word_id is None:
            continue  # special token
        if start_char < end_char and np.any(char_covered[start_char:end_char]):
            if word_id not in word_prob_map:
                word_prob_map[word_id] = []
                word_char_spans[word_id] = [start_char, end_char]
            else:
                span = word_char_spans[word_id]
                span[0] = min(span[0], start_char)
                span[1] = max(span[1], end_char)
            word_prob_map[word_id].append(
                float(np.max(char_link_probabilities[start_char:end_char]))
            )
        elif word_id not in word_prob_map:
            word_prob_map[word_id] = [0.0]
            word_char_spans[word_id] = [start_char, end_char]
    return word_char_spans, word_prob_map


def _render_highlighted_html(
    text: str,
    word_char_spans: Dict[int, List[int]],
    words_to_highlight_status: Dict[int, bool],
) -> str:
    """Rebuild *text* as HTML, wrapping highlighted words in a styled span.

    Every piece of the user's text is HTML-escaped, so characters such as
    ``<`` and ``&`` display literally instead of being parsed as markup.
    """
    import html

    html_output_parts: List[str] = []
    current_char_idx = 0
    for word_id in sorted(word_char_spans, key=lambda k: word_char_spans[k][0]):
        start_char, end_char = word_char_spans[word_id]
        if start_char > current_char_idx:
            # Gap between words (whitespace or unscored text) is kept verbatim.
            html_output_parts.append(html.escape(text[current_char_idx:start_char], quote=False))
        word_text = html.escape(text[start_char:end_char], quote=False)
        if words_to_highlight_status.get(word_id, False):
            html_output_parts.append(
                "<span style='background-color: #D4EDDA; color: #155724; padding: 0.1em 0.2em; border-radius: 0.2em;'>"
                + word_text +
                "</span>"
            )
        else:
            html_output_parts.append(word_text)
        current_char_idx = end_char
    if current_char_idx < len(text):
        html_output_parts.append(html.escape(text[current_char_idx:], quote=False))
    return "".join(html_output_parts)


def classify_text(
    text: str,
    otsu_mode: str,
    prediction_threshold_override: Optional[float] = None
) -> Tuple[str, Optional[str], Optional[float]]:
    """Highlight the words of *text* that the model predicts should be links.

    Args:
        text: raw user text.
        otsu_mode: 'conservative' or 'generous' to shift the automatic Otsu
            threshold; any other value uses it unshifted. Ignored when an
            override is given.
        prediction_threshold_override: manual threshold in percent (0-100),
            or None for the automatic threshold.

    Returns:
        ``(html, warning, threshold_percent)``:
        - *html* is *text* with HTML-special characters escaped and predicted
          link words wrapped in a highlight ``<span>``.
        - *warning* is a user-facing message or None.
        - *threshold_percent* is the threshold used, or None for blank input
          or when no windows could be built.

        NOTE(review): Markdown syntax in the text (``*``, ``#``, …) is still
        interpreted when the result is passed to ``st.markdown``.
    """
    if not text.strip():
        return "", None, None
    windows = windowize_inference(text, tokenizer, MAX_LENGTH, DOC_STRIDE)
    if not windows:
        return "", "Could not generate any windows for processing.", None

    # Per-character max link probability across all windows covering it, so
    # overlapping windows keep their most confident prediction.
    char_link_probabilities = np.zeros(len(text), dtype=np.float32)
    char_covered = np.zeros(len(text), dtype=bool)
    all_content_token_probs = []
    with torch.no_grad():
        for window in tqdm(windows, desc="Processing windows"):
            inputs = {
                'input_ids': window['input_ids'].unsqueeze(0).to(device),
                'attention_mask': window['attention_mask'].unsqueeze(0).to(device),
            }
            logits = model(**inputs)['logits'].squeeze(0)
            link_probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
            for i, (offset_start, offset_end) in enumerate(window['offset_mapping']):
                # Skip special/padding positions and zero-width tokens.
                if window['word_ids'][i] is None or offset_start >= offset_end:
                    continue
                span = slice(offset_start, offset_end)
                char_link_probabilities[span] = np.maximum(char_link_probabilities[span], link_probs[i])
                char_covered[span] = True
                all_content_token_probs.append(link_probs[i])

    final_threshold, determined_threshold_for_display = _select_link_threshold(
        all_content_token_probs, otsu_mode, prediction_threshold_override
    )
    word_char_spans, word_prob_map = _aggregate_word_probabilities(
        text, tokenizer, char_link_probabilities, char_covered
    )
    words_to_highlight_status = {
        word_id: max(probs, default=0.0) >= final_threshold
        for word_id, probs in word_prob_map.items()
    }
    highlighted_html = _render_highlighted_html(text, word_char_spans, words_to_highlight_status)
    return highlighted_html, None, determined_threshold_for_display
# ----------------------------------
# Streamlit UI
# ----------------------------------
# NOTE(review): some Streamlit releases require set_page_config to be the
# first Streamlit call of a run, and load_model() (with its cache spinner)
# runs earlier at import time. If the installed version complains, move
# this call to the top of the file — TODO confirm.
st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")
st.title("LinkBERT")

user_input = st.text_area(
    "Paste your text here:",
    "DEJAN AI is the world's leading AI SEO agency.",
    height=200,
)

with st.expander('Settings'):
    auto_threshold_enabled = st.checkbox(
        "Automagic",
        value=True,
        help="Uncheck to set manual threshold value for link prediction.",
    )
    otsu_mode_options = ['Conservative', 'Standard', 'Generous']
    if auto_threshold_enabled:
        # Automatic threshold: only the Otsu shift direction is configurable.
        selected_otsu_mode = st.radio(
            "Generosity:",
            otsu_mode_options,
            index=1,
            help="Generous suggests more links; conservative suggests fewer.",
        )
        prediction_threshold_manual = 50.0
    else:
        # Manual threshold, in percent; the Otsu mode keeps its default.
        selected_otsu_mode = 'Standard'
        prediction_threshold_manual = st.slider(
            "Manual Link Probability Threshold (%)",
            min_value=0,
            max_value=100,
            value=50,
            step=1,
            help="Minimum probability to classify a token as a link when Automagic is off.",
        )

if st.button("Classify Text"):
    if user_input.strip():
        threshold_to_pass = None
        if not auto_threshold_enabled:
            threshold_to_pass = prediction_threshold_manual
        highlighted_html, warning_message, determined_threshold_for_display = classify_text(
            user_input,
            selected_otsu_mode.lower(),
            threshold_to_pass,
        )
        if warning_message:
            st.warning(warning_message)
        if auto_threshold_enabled and determined_threshold_for_display is not None:
            st.info(f"Auto threshold: {determined_threshold_for_display:.1f}% ({selected_otsu_mode})")
        st.markdown(highlighted_html, unsafe_allow_html=True)
    else:
        st.warning("Please enter some text to classify.")
|