import streamlit as st
import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2TokenizerFast, DebertaV2Config, AutoTokenizer
from pathlib import Path
import numpy as np
import json
import logging
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple
from tqdm import tqdm
from skimage.filters import threshold_otsu
# ----------------------------------
# Logging
# ----------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ----------------------------------
# Config / Model
# ----------------------------------
@dataclass
class TrainingConfig:
"""Training configuration for link token classification"""
model_name: str = "microsoft/deberta-v3-large"
num_labels: int = 2 # 0: not link, 1: link token
# Inference windowing
max_length: int = 512
doc_stride: int = 128 # match _prep.py for consistent windowing
# Train-only placeholders
train_file: str = ""
val_file: str = ""
batch_size: int = 1
gradient_accumulation_steps: int = 1
num_epochs: int = 1
learning_rate: float = 1e-5
warmup_ratio: float = 0.1
weight_decay: float = 0.01
max_grad_norm: float = 1.0
label_smoothing: float = 0.0
device: str = "cuda" if torch.cuda.is_available() else "cpu"
num_workers: int = 0
bf16: bool = False
seed: int = 42
logging_steps: int = 1
eval_steps: int = 100
save_steps: int = 100
output_dir: str = "./deberta_link_output" # model is loaded from here
wandb_project: str = ""
wandb_name: str = ""
patience: int = 2
min_delta: float = 0.0001
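# Illustration (not executed): with the defaults above, DeBERTa adds two
# special tokens per window, so each window carries cap = 512 - 2 = 510
# content tokens, and consecutive windows start cap - doc_stride = 382
# tokens apart, overlapping by 128 tokens.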
class DeBERTaForTokenClassification(nn.Module):
"""DeBERTa model for token classification"""
def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1):
super().__init__()
self.config = DebertaV2Config.from_pretrained(model_name)
self.deberta = DebertaV2Model.from_pretrained(model_name)
self.dropout = nn.Dropout(dropout_rate)
self.classifier = nn.Linear(self.config.hidden_size, num_labels)
nn.init.xavier_uniform_(self.classifier.weight)
nn.init.zeros_(self.classifier.bias)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
        # Inference-only head: `labels` is accepted for interface
        # compatibility with the training script but ignored here,
        # so 'loss' is always None.
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
sequence_output = self.dropout(outputs.last_hidden_state)
logits = self.classifier(sequence_output)
return {'loss': None, 'logits': logits}
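# Shape note: for batch size B and sequence length T, `logits` has shape
# (B, T, num_labels); softmax over the last dimension yields per-token class
# probabilities, with index 1 being the "link token" probability used below.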
# ----------------------------------
# Load model/tokenizer (robust)
# ----------------------------------
@st.cache_resource
def load_model():
"""Loads pre-trained model and tokenizer. Handles raw state_dict and wrapped checkpoints."""
config = TrainingConfig()
final_dir = Path(config.output_dir) / "final_model"
model_path = final_dir / "pytorch_model.bin"
if not model_path.exists():
st.error(f"Model checkpoint not found at {model_path}.")
st.stop()
logger.info(f"Loading model from {model_path}...")
model = DeBERTaForTokenClassification(config.model_name, config.num_labels)
# Load checkpoint robustly
    try:
        # torch >= 2.6 defaults to weights_only=True, while older versions
        # don't accept the kwarg at all, hence the TypeError fallback.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
    except TypeError:
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
# Determine state_dict
state_dict = None
if isinstance(checkpoint, dict):
# Case A: raw state_dict (keys -> tensors)
if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
state_dict = checkpoint
logger.info("Detected raw state_dict checkpoint.")
# Case B: wrapped dicts
elif 'model_state_dict' in checkpoint and isinstance(checkpoint['model_state_dict'], dict):
state_dict = checkpoint['model_state_dict']
logger.info("Detected 'model_state_dict' in checkpoint.")
elif 'state_dict' in checkpoint and isinstance(checkpoint['state_dict'], dict):
state_dict = checkpoint['state_dict']
logger.info("Detected 'state_dict' in checkpoint.")
else:
raise KeyError(f"Unrecognized checkpoint format keys: {list(checkpoint.keys())}")
else:
raise TypeError(f"Unexpected checkpoint type: {type(checkpoint)}")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
logger.warning(f"Missing keys: {missing}")
if unexpected:
logger.warning(f"Unexpected keys: {unexpected}")
model.to(config.device)
model.eval()
logger.info(f"Loading tokenizer {config.model_name}...")
tokenizer = DebertaV2TokenizerFast.from_pretrained(config.model_name)
logger.info("Tokenizer loaded.")
return model, tokenizer, config.device, config.max_length, config.doc_stride
# NOTE: load_model() is invoked after st.set_page_config() below, because
# Streamlit requires set_page_config to be the first Streamlit call in a script.
# ----------------------------------
# Inference helpers
# ----------------------------------
def windowize_inference(
plain_text: str,
tokenizer: AutoTokenizer,
max_length: int,
doc_stride: int
) -> List[Dict]:
"""Slice long text into overlapping windows for inference."""
specials = tokenizer.num_special_tokens_to_add(pair=False)
cap = max_length - specials
if cap <= 0:
raise ValueError(f"max_length too small; specials={specials}")
full_encoding = tokenizer(
plain_text,
add_special_tokens=False,
return_offsets_mapping=True,
return_attention_mask=False,
return_token_type_ids=False,
truncation=False,
)
input_ids_no_special = full_encoding["input_ids"]
offsets_no_special = full_encoding["offset_mapping"]
    # Take word_ids from the same no-special-tokens encoding so indices stay
    # aligned with input_ids_no_special; re-tokenizing with special tokens
    # would shift every index by one ([CLS] at position 0).
    full_word_ids = full_encoding.word_ids(batch_index=0)
windows_data = []
step = max(cap - doc_stride, 1)
start_token_idx = 0
total_tokens_no_special = len(input_ids_no_special)
while start_token_idx < total_tokens_no_special:
end_token_idx = min(start_token_idx + cap, total_tokens_no_special)
ids_slice_no_special = input_ids_no_special[start_token_idx:end_token_idx]
offsets_slice_no_special = offsets_no_special[start_token_idx:end_token_idx]
word_ids_slice = full_word_ids[start_token_idx:end_token_idx]
input_ids_with_special = tokenizer.build_inputs_with_special_tokens(ids_slice_no_special)
attention_mask_with_special = [1] * len(input_ids_with_special)
padding_length = max_length - len(input_ids_with_special)
if padding_length > 0:
input_ids_with_special.extend([tokenizer.pad_token_id] * padding_length)
attention_mask_with_special.extend([0] * padding_length)
window_offset_mapping = offsets_slice_no_special[:]
window_word_ids = word_ids_slice[:]
if tokenizer.cls_token_id is not None:
window_offset_mapping.insert(0, (0, 0))
window_word_ids.insert(0, None)
if tokenizer.sep_token_id is not None and len(window_offset_mapping) < max_length:
window_offset_mapping.append((0, 0))
window_word_ids.append(None)
while len(window_offset_mapping) < max_length:
window_offset_mapping.append((0, 0))
window_word_ids.append(None)
windows_data.append({
"input_ids": torch.tensor(input_ids_with_special, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask_with_special, dtype=torch.long),
"word_ids": window_word_ids,
"offset_mapping": window_offset_mapping,
})
if end_token_idx == total_tokens_no_special:
break
start_token_idx += step
return windows_data
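# A minimal sanity-check sketch for the windowing above (not invoked by the
# app; `sample_text` is an arbitrary illustrative input).
def _debug_window_stats(sample_text: str) -> None:
    windows = windowize_inference(sample_text, tokenizer, MAX_LENGTH, DOC_STRIDE)
    for idx, w in enumerate(windows):
        real_tokens = int(w["attention_mask"].sum().item())
        content_tokens = sum(1 for wid in w["word_ids"] if wid is not None)
        logger.info(f"window {idx}: {real_tokens} real tokens, {content_tokens} content tokens")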
def classify_text(
text: str,
otsu_mode: str,
prediction_threshold_override: Optional[float] = None
) -> Tuple[str, Optional[str], Optional[float]]:
"""Classify link tokens with windowing. Returns (html, warning, threshold%)."""
if not text.strip():
return "", None, None
windows = windowize_inference(text, tokenizer, MAX_LENGTH, DOC_STRIDE)
if not windows:
return "", "Could not generate any windows for processing.", None
char_link_probabilities = np.zeros(len(text), dtype=np.float32)
char_covered = np.zeros(len(text), dtype=bool)
all_content_token_probs = []
with torch.no_grad():
for window in tqdm(windows, desc="Processing windows"):
inputs = {
'input_ids': window['input_ids'].unsqueeze(0).to(device),
'attention_mask': window['attention_mask'].unsqueeze(0).to(device)
}
outputs = model(**inputs)
logits = outputs['logits'].squeeze(0)
probabilities = torch.softmax(logits, dim=-1)
link_probs_for_window_tokens = probabilities[:, 1].cpu().numpy()
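            # Merge overlapping windows by keeping, per character, the maximum
            # link probability observed across all windows covering it.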
for i, (offset_start, offset_end) in enumerate(window['offset_mapping']):
if window['word_ids'][i] is not None and offset_start < offset_end:
char_link_probabilities[offset_start:offset_end] = np.maximum(
char_link_probabilities[offset_start:offset_end],
link_probs_for_window_tokens[i]
)
char_covered[offset_start:offset_end] = True
all_content_token_probs.append(link_probs_for_window_tokens[i])
# Threshold selection (Otsu or manual)
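    # threshold_otsu treats the collected token probabilities as a histogram
    # and picks the cut that best separates them into two clusters, so the
    # threshold adapts to each document's score distribution.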
determined_threshold_float = None
determined_threshold_for_display = None # 0-100%
if prediction_threshold_override is not None:
determined_threshold_float = prediction_threshold_override / 100.0
determined_threshold_for_display = prediction_threshold_override
else:
if len(all_content_token_probs) > 1:
try:
otsu_base_threshold = threshold_otsu(np.array(all_content_token_probs))
                conservative_delta = 0.1  # raises the cut: fewer predicted links
                generous_delta = 0.1      # lowers the cut: more predicted links
if otsu_mode == 'conservative':
determined_threshold_float = otsu_base_threshold + conservative_delta
elif otsu_mode == 'generous':
determined_threshold_float = otsu_base_threshold - generous_delta
else:
determined_threshold_float = otsu_base_threshold
determined_threshold_float = max(0.0, min(1.0, determined_threshold_float))
determined_threshold_for_display = determined_threshold_float * 100
except ValueError:
logger.warning("Otsu failed; defaulting to 0.5.")
determined_threshold_float = 0.5
determined_threshold_for_display = 50.0
else:
logger.warning("Insufficient tokens for Otsu; defaulting to 0.5.")
determined_threshold_float = 0.5
determined_threshold_for_display = 50.0
final_threshold = determined_threshold_float
# Word-level aggregation
full_text_encoding = tokenizer(text, return_offsets_mapping=True, truncation=False, padding=False)
full_word_ids = full_text_encoding.word_ids(batch_index=0)
full_offset_mapping = full_text_encoding['offset_mapping']
    word_prob_map: Dict[int, List[float]] = {}
    word_char_spans: Dict[int, List[int]] = {}
    for i, word_id in enumerate(full_word_ids):
        if word_id is None:
            continue  # special tokens ([CLS]/[SEP]) carry no word
        start_char, end_char = full_offset_mapping[i]
        if start_char < end_char and np.any(char_covered[start_char:end_char]):
            if word_id not in word_prob_map:
                word_prob_map[word_id] = []
                word_char_spans[word_id] = [start_char, end_char]
            else:
                word_char_spans[word_id][0] = min(word_char_spans[word_id][0], start_char)
                word_char_spans[word_id][1] = max(word_char_spans[word_id][1], end_char)
            token_span_probs = char_link_probabilities[start_char:end_char]
            word_prob_map[word_id].append(np.max(token_span_probs) if token_span_probs.size > 0 else 0.0)
        elif word_id not in word_prob_map:
            # Span was empty or never covered by a window: record the word
            # with zero probability so it still renders, unhighlighted.
            word_prob_map[word_id] = [0.0]
            word_char_spans[word_id] = list(full_offset_mapping[i])
words_to_highlight_status: Dict[int, bool] = {}
for word_id, probs in word_prob_map.items():
max_word_prob = np.max(probs) if probs else 0.0
words_to_highlight_status[word_id] = (max_word_prob >= final_threshold)
# Reconstruct HTML with highlights
html_output_parts: List[str] = []
current_char_idx = 0
sorted_word_ids = sorted(word_char_spans.keys(), key=lambda k: word_char_spans[k][0])
for word_id in sorted_word_ids:
start_char, end_char = word_char_spans[word_id]
if start_char > current_char_idx:
html_output_parts.append(text[current_char_idx:start_char])
word_text = text[start_char:end_char]
if words_to_highlight_status.get(word_id, False):
html_output_parts.append(
"<span style='background-color: #D4EDDA; color: #155724; padding: 0.1em 0.2em; border-radius: 0.2em;'>"
+ word_text +
"</span>"
)
else:
html_output_parts.append(word_text)
current_char_idx = end_char
if current_char_idx < len(text):
html_output_parts.append(text[current_char_idx:])
return "".join(html_output_parts), None, determined_threshold_for_display
# ----------------------------------
# Streamlit UI
# ----------------------------------
st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")
model, tokenizer, device, MAX_LENGTH, DOC_STRIDE = load_model()
st.title("LinkBERT")
user_input = st.text_area(
"Paste your text here:",
"DEJAN AI is the world's leading AI SEO agency.",
height=200
)
with st.expander('Settings'):
auto_threshold_enabled = st.checkbox(
"Automagic",
value=True,
help="Uncheck to set manual threshold value for link prediction."
)
otsu_mode_options = ['Conservative', 'Standard', 'Generous']
selected_otsu_mode = 'Standard'
if auto_threshold_enabled:
selected_otsu_mode = st.radio(
"Generosity:",
otsu_mode_options,
index=1,
help="Generous suggests more links; conservative suggests fewer."
)
prediction_threshold_manual = 50.0
if not auto_threshold_enabled:
prediction_threshold_manual = st.slider(
"Manual Link Probability Threshold (%)",
min_value=0,
max_value=100,
value=50,
step=1,
help="Minimum probability to classify a token as a link when Automagic is off."
)
if st.button("Classify Text"):
if not user_input.strip():
st.warning("Please enter some text to classify.")
else:
threshold_to_pass = None if auto_threshold_enabled else prediction_threshold_manual
highlighted_html, warning_message, determined_threshold_for_display = classify_text(
user_input,
selected_otsu_mode.lower(),
threshold_to_pass
)
if warning_message:
st.warning(warning_message)
if determined_threshold_for_display is not None and auto_threshold_enabled:
st.info(f"Auto threshold: {determined_threshold_for_display:.1f}% ({selected_otsu_mode})")
st.markdown(highlighted_html, unsafe_allow_html=True)