dejanseo committed
Commit d8dd02e · verified · 1 Parent(s): a2bc870

Create app.py

Files changed (1)
  app.py +232 -0
app.py ADDED
@@ -0,0 +1,232 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+ import numpy as np
+ import logging
+ from dataclasses import dataclass
+ from typing import Dict, List, Tuple, Optional
+
+ # ----------------------------------
+ # Logging
+ # ----------------------------------
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # ----------------------------------
+ # Config
+ # ----------------------------------
+ @dataclass
+ class AppConfig:
+     model_name: str = "dejanseo/link-prediction"
+     max_length: int = 512
+     doc_stride: int = 128
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # ----------------------------------
+ # Load model/tokenizer
+ # ----------------------------------
+ config = AppConfig()
+ logger.info(f"Loading model: {config.model_name} on {config.device}")
+
+ model = AutoModelForTokenClassification.from_pretrained(config.model_name)
+ tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+ model.to(config.device)
+ model.eval()
+
+ logger.info("Model loaded successfully.")
+
+ # ----------------------------------
+ # Inference helpers
+ # ----------------------------------
+ def windowize_inference(
+     plain_text: str, tokenizer: AutoTokenizer, max_length: int, doc_stride: int
+ ) -> List[Dict]:
+     """Slice long text into overlapping windows for inference."""
+     specials = tokenizer.num_special_tokens_to_add(pair=False)
+     cap = max_length - specials
+     full_encoding = tokenizer(
+         plain_text, add_special_tokens=False, return_offsets_mapping=True, truncation=False
+     )
+     temp_tokenization = tokenizer(plain_text, truncation=False)
+     full_word_ids = temp_tokenization.word_ids(batch_index=0)
+
+     windows_data = []
+     step = max(cap - doc_stride, 1)
+     start_token_idx = 0
+     total_tokens = len(full_encoding["input_ids"])
+
+     if total_tokens == 0 and len(plain_text) > 0:
+         logger.warning("Tokenizer produced 0 tokens for a non-empty string.")
+         return []
+
+     while start_token_idx < total_tokens:
+         end_token_idx = min(start_token_idx + cap, total_tokens)
+         ids_slice = full_encoding["input_ids"][start_token_idx:end_token_idx]
+         offsets_slice = full_encoding["offset_mapping"][start_token_idx:end_token_idx]
+
+         word_ids_slice = []
+         current_token = 0
+         for i, wid in enumerate(full_word_ids):
+             if temp_tokenization.token_to_chars(i) is not None:
+                 if current_token >= start_token_idx and current_token < end_token_idx:
+                     word_ids_slice.append(wid)
+                 current_token += 1
+
+         input_ids = tokenizer.build_inputs_with_special_tokens(ids_slice)
+         attention_mask = [1] * len(input_ids)
+         padding_length = max_length - len(input_ids)
+         input_ids.extend([tokenizer.pad_token_id] * padding_length)
+         attention_mask.extend([0] * padding_length)
+
+         window_offset_mapping = [(0, 0)] + offsets_slice + [(0, 0)]
+         window_offset_mapping += [(0, 0)] * padding_length
+
+         window_word_ids = [None] + word_ids_slice + [None]
+         window_word_ids += [None] * padding_length
+
+         windows_data.append({
+             "input_ids": torch.tensor(input_ids, dtype=torch.long),
+             "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+             "word_ids": window_word_ids[:max_length],
+             "offset_mapping": window_offset_mapping[:max_length],
+         })
+         if end_token_idx >= total_tokens:
+             break
+         start_token_idx += step
+     return windows_data
+
+
+ def classify_text(text: str, threshold_percent: float) -> Tuple[str, Optional[str]]:
+     """Classify link tokens with windowing. Returns (html, warning)."""
+     if not text.strip():
+         return "", "Input text is empty."
+
+     windows = windowize_inference(text, tokenizer, config.max_length, config.doc_stride)
+     if not windows:
+         return "", "Could not generate any windows for processing."
+
+     char_link_probabilities = np.zeros(len(text), dtype=np.float32)
+
+     with torch.no_grad():
+         for window in windows:
+             inputs = {
+                 'input_ids': window['input_ids'].unsqueeze(0).to(config.device),
+                 'attention_mask': window['attention_mask'].unsqueeze(0).to(config.device)
+             }
+             outputs = model(**inputs)
+             probabilities = torch.softmax(outputs.logits, dim=-1).squeeze(0)
+             link_probs = probabilities[:, 1].cpu().numpy()
+
+             for i, offset in enumerate(window['offset_mapping']):
+                 if isinstance(offset, (list, tuple)) and len(offset) == 2:
+                     start, end = offset
+                     if window['word_ids'][i] is not None and start < end:
+                         char_link_probabilities[start:end] = np.maximum(
+                             char_link_probabilities[start:end], link_probs[i]
+                         )
+
+     final_threshold = threshold_percent / 100.0
+
+     full_encoding = tokenizer(text, return_offsets_mapping=True, truncation=False)
+     word_ids = full_encoding.word_ids(batch_index=0)
+     offsets = full_encoding['offset_mapping']
+
+     word_max_prob_map: Dict[int, float] = {}
+     word_char_spans: Dict[int, List[int]] = {}
+
+     for i, word_id in enumerate(word_ids):
+         if word_id is not None and i < len(offsets):
+             start_char, end_char = offsets[i]
+             if start_char < end_char:
+                 current_token_max_prob = np.max(char_link_probabilities[start_char:end_char]) if start_char < len(char_link_probabilities) else 0.0
+
+                 if word_id not in word_max_prob_map:
+                     word_max_prob_map[word_id] = current_token_max_prob
+                     word_char_spans[word_id] = [start_char, end_char]
+                 else:
+                     word_max_prob_map[word_id] = max(word_max_prob_map[word_id], current_token_max_prob)
+                     word_char_spans[word_id][1] = end_char
+
+     highlight_candidates: Dict[int, float] = {}
+     for word_id, max_prob in word_max_prob_map.items():
+         if max_prob >= final_threshold:
+             highlight_candidates[word_id] = max_prob
+
+     max_highlight_prob = max(highlight_candidates.values()) if highlight_candidates else 0.0
+
+     html_parts, current_char = [], 0
+     sorted_word_ids = sorted(word_char_spans.keys(), key=lambda k: word_char_spans[k][0])
+
+     for word_id in sorted_word_ids:
+         start_char, end_char = word_char_spans[word_id]
+
+         if start_char > current_char:
+             html_parts.append(text[current_char:start_char])
+
+         word_text = text[start_char:end_char]
+
+         if word_id in highlight_candidates:
+             word_prob = highlight_candidates[word_id]
+             normalized_opacity = (word_prob / max_highlight_prob) * 0.9 + 0.1 if max_highlight_prob > 0 else 1.0
+
+             html_parts.append(
+                 f"<span style='background-color: #D4EDDA; color: #155724; "
+                 f"padding: 0.1em 0.2em; border-radius: 0.2em; opacity: {normalized_opacity:.2f};' "
+                 f"title='Link Probability: {word_prob:.1%}'>{word_text}</span>"
+             )
+         else:
+             html_parts.append(word_text)
+         current_char = end_char
+
+     if current_char < len(text):
+         html_parts.append(text[current_char:])
+
+     return "".join(html_parts), None
+
+
+ # ----------------------------------
+ # Gradio Interface
+ # ----------------------------------
+ def predict(text: str, threshold: float) -> str:
+     """Main prediction function for Gradio."""
+     html, warning = classify_text(text, threshold)
+     if warning:
+         return f"<p style='color: orange;'>{warning}</p>"
+     return html
+
+
+ # Build the interface
+ with gr.Blocks(title="LinkBERT by DEJAN AI") as demo:
+     gr.Markdown("# LinkBERT")
+     gr.Markdown("Predict natural link placement in plain text.")
+
+     with gr.Row():
+         with gr.Column():
+             text_input = gr.Textbox(
+                 label="Input Text",
+                 placeholder="Paste your text here...",
+                 lines=8,
+                 value="DEJAN AI is the world's leading AI SEO agency. This tool showcases the capability of our latest link prediction model called LinkBERT."
+             )
+             threshold_slider = gr.Slider(
+                 minimum=0,
+                 maximum=100,
+                 value=70,
+                 step=1,
+                 label="Link Probability Threshold (%)"
+             )
+             submit_btn = gr.Button("Classify Text", variant="primary")
+
+         with gr.Column():
+             output_html = gr.HTML(label="Results")
+
+     submit_btn.click(
+         fn=predict,
+         inputs=[text_input, threshold_slider],
+         outputs=output_html,
+         api_name="predict"  # Exposes as /api/predict
+     )
+
+ # Launch
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
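Because the click handler sets api_name="predict", the running Space can also be queried programmatically. Below is a minimal sketch using gradio_client; the Space id "dejanseo/LinkBERT" is a placeholder assumption, not confirmed by this commit, so substitute the actual Space (or the local URL printed by demo.launch).

# Hypothetical client call against the /predict endpoint exposed above.
from gradio_client import Client

# Placeholder Space id; replace with the real Space name or "http://localhost:7860".
client = Client("dejanseo/LinkBERT")

html = client.predict(
    "DEJAN AI is the world's leading AI SEO agency.",  # text_input
    70,                                                # threshold_slider, in percent
    api_name="/predict",
)
print(html)  # HTML string with highlighted link-candidate spans

The endpoint returns the same HTML string the app renders in the Results panel, so downstream code would need to parse or strip the <span> markup if it only wants the predicted anchor text.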