File size: 16,247 Bytes
26536bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
import streamlit as st
import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2TokenizerFast, DebertaV2Config, AutoTokenizer
from pathlib import Path
import numpy as np
import json
import logging
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple
from tqdm import tqdm
from skimage.filters import threshold_otsu

# ----------------------------------
# Logging
# ----------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------------
# Config / Model
# ----------------------------------

@dataclass
class TrainingConfig:
    """Training/inference configuration for link token classification.

    This app only reads the inference-related fields (model_name, num_labels,
    max_length, doc_stride, device, output_dir); the remaining fields are
    placeholders kept for parity with the training script's config.
    """
    model_name: str = "microsoft/deberta-v3-large"
    num_labels: int = 2  # 0: not link, 1: link token

    # Inference windowing
    max_length: int = 512   # total window size, including special tokens
    doc_stride: int = 128  # match _prep.py for consistent windowing

    # Train-only placeholders
    train_file: str = ""
    val_file: str = ""
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    learning_rate: float = 1e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    label_smoothing: float = 0.0

    # NOTE: evaluated once at class-definition (import) time, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 0
    bf16: bool = False
    seed: int = 42

    logging_steps: int = 1
    eval_steps: int = 100
    save_steps: int = 100
    output_dir: str = "./deberta_link_output"  # model is loaded from here

    wandb_project: str = ""
    wandb_name: str = ""

    patience: int = 2
    min_delta: float = 0.0001


class DeBERTaForTokenClassification(nn.Module):
    """DeBERTa-v2/v3 backbone with a linear token-classification head.

    forward() returns a dict with:
      - 'logits': (batch, seq_len, num_labels) raw scores per token.
      - 'loss': cross-entropy over tokens when `labels` is given, else None.
    """

    def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1):
        super().__init__()
        self.num_labels = num_labels
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)
        # Freshly initialized head on top of the pretrained backbone.
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Run the backbone and classification head.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len) 1 for real tokens, 0 for padding.
            labels: optional (batch, seq_len) gold labels; positions set to
                -100 are ignored by the loss (standard HF convention).

        Returns:
            Dict with 'loss' (tensor or None) and 'logits'.
        """
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Bug fix: the original accepted `labels` but silently ignored them
            # and always returned loss=None. Compute the standard token-level
            # cross-entropy so the class is usable for training as well.
            # Inference callers (which never pass labels) are unaffected.
            loss = nn.functional.cross_entropy(
                logits.view(-1, self.num_labels),
                labels.view(-1),
                ignore_index=-100,
            )
        return {'loss': loss, 'logits': logits}

# ----------------------------------
# Load model/tokenizer (robust)
# ----------------------------------

@st.cache_resource
def load_model():
    """Load the fine-tuned model and its tokenizer (cached by Streamlit).

    Accepts either a raw state_dict checkpoint or one wrapped under a
    'model_state_dict' / 'state_dict' key. Halts the Streamlit app if the
    checkpoint file is missing.

    Returns:
        (model, tokenizer, device, max_length, doc_stride)
    """
    cfg = TrainingConfig()
    ckpt_path = Path(cfg.output_dir) / "final_model" / "pytorch_model.bin"

    if not ckpt_path.exists():
        st.error(f"Model checkpoint not found at {ckpt_path}.")
        st.stop()

    logger.info(f"Loading model from {ckpt_path}...")
    model = DeBERTaForTokenClassification(cfg.model_name, cfg.num_labels)

    # weights_only is only available on newer torch versions; fall back if not.
    try:
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'), weights_only=False)
    except TypeError:
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))

    if not isinstance(ckpt, dict):
        raise TypeError(f"Unexpected checkpoint type: {type(ckpt)}")

    # Work out which layout the checkpoint uses.
    if ckpt and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
        # Case A: the dict maps parameter names directly to tensors.
        state_dict = ckpt
        logger.info("Detected raw state_dict checkpoint.")
    elif isinstance(ckpt.get('model_state_dict'), dict):
        # Case B: trainer-style wrapper.
        state_dict = ckpt['model_state_dict']
        logger.info("Detected 'model_state_dict' in checkpoint.")
    elif isinstance(ckpt.get('state_dict'), dict):
        state_dict = ckpt['state_dict']
        logger.info("Detected 'state_dict' in checkpoint.")
    else:
        raise KeyError(f"Unrecognized checkpoint format keys: {list(ckpt.keys())}")

    # strict=False tolerates harmless mismatches (e.g. position buffers).
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning(f"Missing keys: {missing}")
    if unexpected:
        logger.warning(f"Unexpected keys: {unexpected}")

    model.to(cfg.device)
    model.eval()

    logger.info(f"Loading tokenizer {cfg.model_name}...")
    tokenizer = DebertaV2TokenizerFast.from_pretrained(cfg.model_name)
    logger.info("Tokenizer loaded.")

    return model, tokenizer, cfg.device, cfg.max_length, cfg.doc_stride

model, tokenizer, device, MAX_LENGTH, DOC_STRIDE = load_model()

# ----------------------------------
# Inference helpers
# ----------------------------------

def windowize_inference(
    plain_text: str,
    tokenizer: AutoTokenizer,
    max_length: int,
    doc_stride: int
) -> List[Dict]:
    """Slice long text into overlapping windows for inference.

    Args:
        plain_text: Full input text.
        tokenizer: A *fast* HF tokenizer (offset mapping / word_ids required).
        max_length: Total window length including special tokens.
        doc_stride: Token overlap between consecutive windows.

    Returns:
        List of dicts, one per window, each with 'input_ids' and
        'attention_mask' tensors of length max_length, plus aligned
        'word_ids' and 'offset_mapping' lists (None / (0, 0) at special
        and padding positions).

    Raises:
        ValueError: if max_length leaves no room for content tokens.
    """
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials  # content tokens per window
    if cap <= 0:
        raise ValueError(f"max_length too small; specials={specials}")

    full_encoding = tokenizer(
        plain_text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        return_attention_mask=False,
        return_token_type_ids=False,
        truncation=False,
    )
    input_ids_no_special = full_encoding["input_ids"]
    offsets_no_special = full_encoding["offset_mapping"]
    # BUG FIX: the original derived word_ids from a SECOND tokenization that
    # kept special tokens, so slicing it with indices into the special-free
    # token list shifted every window's word_ids by the leading CLS.
    # Fast tokenizers expose word_ids() on this encoding directly, so reuse
    # it — indices now line up with input_ids_no_special exactly.
    full_word_ids = full_encoding.word_ids(batch_index=0)

    windows_data: List[Dict] = []
    step = max(cap - doc_stride, 1)  # guard against non-positive stride
    start_token_idx = 0
    total_tokens_no_special = len(input_ids_no_special)

    while start_token_idx < total_tokens_no_special:
        end_token_idx = min(start_token_idx + cap, total_tokens_no_special)

        ids_slice_no_special = input_ids_no_special[start_token_idx:end_token_idx]
        offsets_slice_no_special = offsets_no_special[start_token_idx:end_token_idx]
        word_ids_slice = full_word_ids[start_token_idx:end_token_idx]

        # Re-add CLS/SEP around the content slice, then pad to max_length.
        input_ids_with_special = tokenizer.build_inputs_with_special_tokens(ids_slice_no_special)
        attention_mask_with_special = [1] * len(input_ids_with_special)

        padding_length = max_length - len(input_ids_with_special)
        if padding_length > 0:
            input_ids_with_special.extend([tokenizer.pad_token_id] * padding_length)
            attention_mask_with_special.extend([0] * padding_length)

        # Mirror the special tokens/padding in the offset and word-id lists
        # so every list indexes the same positions as input_ids_with_special.
        window_offset_mapping = offsets_slice_no_special[:]
        window_word_ids = word_ids_slice[:]

        if tokenizer.cls_token_id is not None:
            window_offset_mapping.insert(0, (0, 0))
            window_word_ids.insert(0, None)
        if tokenizer.sep_token_id is not None and len(window_offset_mapping) < max_length:
            window_offset_mapping.append((0, 0))
            window_word_ids.append(None)

        while len(window_offset_mapping) < max_length:
            window_offset_mapping.append((0, 0))
            window_word_ids.append(None)

        windows_data.append({
            "input_ids": torch.tensor(input_ids_with_special, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask_with_special, dtype=torch.long),
            "word_ids": window_word_ids,
            "offset_mapping": window_offset_mapping,
        })

        if end_token_idx == total_tokens_no_special:
            break
        start_token_idx += step

    return windows_data


def classify_text(

    text: str,

    otsu_mode: str,

    prediction_threshold_override: Optional[float] = None

) -> Tuple[str, Optional[str], Optional[float]]:
    """Classify link tokens with windowing. Returns (html, warning, threshold%).

    Pipeline: window the text -> run the model per window -> max-pool link
    probabilities per character -> pick a threshold (manual override or Otsu
    over token probabilities) -> aggregate to word level -> rebuild the text
    as HTML with predicted link words highlighted.

    Args:
        text: Raw input text to annotate.
        otsu_mode: 'conservative' | 'generous' | other (standard); shifts the
            auto-selected Otsu threshold up or down by 0.1.
        prediction_threshold_override: Manual threshold as a percentage
            (0-100); when given, Otsu selection is skipped entirely.

    Returns:
        (html, warning_or_None, threshold_as_percent_or_None).

    Uses module globals: model, tokenizer, device, MAX_LENGTH, DOC_STRIDE.
    """
    if not text.strip():
        return "", None, None

    windows = windowize_inference(text, tokenizer, MAX_LENGTH, DOC_STRIDE)
    if not windows:
        return "", "Could not generate any windows for processing.", None

    # Per-character link probability, max-pooled across overlapping windows.
    char_link_probabilities = np.zeros(len(text), dtype=np.float32)
    char_covered = np.zeros(len(text), dtype=bool)
    all_content_token_probs = []  # token-level probs feeding Otsu below

    with torch.no_grad():
        for window in tqdm(windows, desc="Processing windows"):
            inputs = {
                'input_ids': window['input_ids'].unsqueeze(0).to(device),
                'attention_mask': window['attention_mask'].unsqueeze(0).to(device)
            }
            outputs = model(**inputs)
            logits = outputs['logits'].squeeze(0)
            probabilities = torch.softmax(logits, dim=-1)
            # Column 1 is the "link" class probability per token.
            link_probs_for_window_tokens = probabilities[:, 1].cpu().numpy()

            for i, (offset_start, offset_end) in enumerate(window['offset_mapping']):
                # Skip special/padding positions (word_id None or empty span).
                if window['word_ids'][i] is not None and offset_start < offset_end:
                    char_link_probabilities[offset_start:offset_end] = np.maximum(
                        char_link_probabilities[offset_start:offset_end],
                        link_probs_for_window_tokens[i]
                    )
                    char_covered[offset_start:offset_end] = True
                    all_content_token_probs.append(link_probs_for_window_tokens[i])

    # Threshold selection (Otsu or manual)
    determined_threshold_float = None
    determined_threshold_for_display = None  # 0-100%

    if prediction_threshold_override is not None:
        determined_threshold_float = prediction_threshold_override / 100.0
        determined_threshold_for_display = prediction_threshold_override
    else:
        # Otsu needs at least two samples; fall back to 0.5 otherwise.
        if len(all_content_token_probs) > 1:
            try:
                otsu_base_threshold = threshold_otsu(np.array(all_content_token_probs))
                conservative_delta = 0.1  # stricter
                generous_delta = 0.1      # more lenient
                if otsu_mode == 'conservative':
                    determined_threshold_float = otsu_base_threshold + conservative_delta
                elif otsu_mode == 'generous':
                    determined_threshold_float = otsu_base_threshold - generous_delta
                else:
                    determined_threshold_float = otsu_base_threshold
                # Clamp the shifted threshold back into [0, 1].
                determined_threshold_float = max(0.0, min(1.0, determined_threshold_float))
                determined_threshold_for_display = determined_threshold_float * 100
            except ValueError:
                # threshold_otsu raises when all values are identical.
                logger.warning("Otsu failed; defaulting to 0.5.")
                determined_threshold_float = 0.5
                determined_threshold_for_display = 50.0
        else:
            logger.warning("Insufficient tokens for Otsu; defaulting to 0.5.")
            determined_threshold_float = 0.5
            determined_threshold_for_display = 50.0

    final_threshold = determined_threshold_float

    # Word-level aggregation: re-tokenize the full text once to get stable
    # word ids and character spans, then collect per-word token probabilities.
    full_text_encoding = tokenizer(text, return_offsets_mapping=True, truncation=False, padding=False)
    full_word_ids = full_text_encoding.word_ids(batch_index=0)
    full_offset_mapping = full_text_encoding['offset_mapping']

    word_prob_map: Dict[int, List[float]] = {}
    word_char_spans: Dict[int, List[int]] = {}

    for i, word_id in enumerate(full_word_ids):
        if word_id is not None:
            start_char, end_char = full_offset_mapping[i]
            if start_char < end_char and np.any(char_covered[start_char:end_char]):
                if word_id not in word_prob_map:
                    word_prob_map[word_id] = []
                    word_char_spans[word_id] = [start_char, end_char]
                else:
                    # Grow the word's char span to cover all of its tokens.
                    word_char_spans[word_id][0] = min(word_char_spans[word_id][0], start_char)
                    word_char_spans[word_id][1] = max(word_char_spans[word_id][1], end_char)

                token_span_probs = char_link_probabilities[start_char:end_char]
                word_prob_map[word_id].append(np.max(token_span_probs) if token_span_probs.size > 0 else 0.0)
            elif word_id not in word_prob_map:
                # Word never covered by any window: record it with prob 0 so
                # the HTML reconstruction below still emits its text.
                word_prob_map[word_id] = [0.0]
                word_char_spans[word_id] = list(full_offset_mapping[i])

    # A word is highlighted if any of its tokens crosses the threshold.
    words_to_highlight_status: Dict[int, bool] = {}
    for word_id, probs in word_prob_map.items():
        max_word_prob = np.max(probs) if probs else 0.0
        words_to_highlight_status[word_id] = (max_word_prob >= final_threshold)

    # Reconstruct HTML with highlights
    html_output_parts: List[str] = []
    current_char_idx = 0
    sorted_word_ids = sorted(word_char_spans.keys(), key=lambda k: word_char_spans[k][0])

    for word_id in sorted_word_ids:
        start_char, end_char = word_char_spans[word_id]
        # Emit any inter-word text (whitespace, punctuation) verbatim.
        if start_char > current_char_idx:
            html_output_parts.append(text[current_char_idx:start_char])

        word_text = text[start_char:end_char]
        if words_to_highlight_status.get(word_id, False):
            html_output_parts.append(
                "<span style='background-color: #D4EDDA; color: #155724; padding: 0.1em 0.2em; border-radius: 0.2em;'>"
                + word_text +
                "</span>"
            )
        else:
            html_output_parts.append(word_text)
        current_char_idx = end_char

    # Trailing text after the last word.
    if current_char_idx < len(text):
        html_output_parts.append(text[current_char_idx:])

    return "".join(html_output_parts), None, determined_threshold_for_display

# ----------------------------------
# Streamlit UI
# ----------------------------------

# Streamlit renders widgets in source order, so the layout below IS the UI.
st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")
st.title("LinkBERT")

user_input = st.text_area(
    "Paste your text here:",
    "DEJAN AI is the world's leading AI SEO agency.",
    height=200
)

with st.expander('Settings'):
    # "Automagic" toggles between Otsu auto-thresholding and a manual slider.
    auto_threshold_enabled = st.checkbox(
        "Automagic",
        value=True,
        help="Uncheck to set manual threshold value for link prediction."
    )

    otsu_mode_options = ['Conservative', 'Standard', 'Generous']
    selected_otsu_mode = 'Standard'
    if auto_threshold_enabled:
        selected_otsu_mode = st.radio(
            "Generosity:",
            otsu_mode_options,
            index=1,
            help="Generous suggests more links; conservative suggests fewer."
        )

    prediction_threshold_manual = 50.0
    if not auto_threshold_enabled:
        prediction_threshold_manual = st.slider(
            "Manual Link Probability Threshold (%)",
            min_value=0,
            max_value=100,
            value=50,
            step=1,
            help="Minimum probability to classify a token as a link when Automagic is off."
        )

if st.button("Classify Text"):
    if not user_input.strip():
        st.warning("Please enter some text to classify.")
    else:
        # None -> classify_text picks the threshold via Otsu; otherwise manual %.
        threshold_to_pass = None if auto_threshold_enabled else prediction_threshold_manual
        highlighted_html, warning_message, determined_threshold_for_display = classify_text(
            user_input,
            selected_otsu_mode.lower(),
            threshold_to_pass
        )
        if warning_message:
            st.warning(warning_message)
        if determined_threshold_for_display is not None and auto_threshold_enabled:
            st.info(f"Auto threshold: {determined_threshold_for_display:.1f}% ({selected_otsu_mode})")
        st.markdown(highlighted_html, unsafe_allow_html=True)