dejanseo committed
Commit df3962f · verified · 1 Parent(s): 88da8fc

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +69 -32
src/streamlit_app.py CHANGED
@@ -8,23 +8,28 @@ import trafilatura
 # Streamlit config
 st.set_page_config(layout="wide", page_title="LinkBERT")

-# Load tokenizer & model (avoid meta-tensor .to() issue)
+# Load tokenizer & model
 MODEL_ID = "dejanseo/LinkBERT-XL"

 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

-load_kwargs = {}
-if torch.cuda.is_available():
-    # Load directly onto GPU(s); do NOT call .to(...) afterward
-    load_kwargs.update(dict(device_map="auto", torch_dtype=torch.float16))
+# Determine the device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load the model directly to the determined device
+# Avoid device_map="auto" if it's causing meta tensor issues with certain torch/transformers versions.
+# Load to CPU first, then move to GPU if available.
+model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
+
+# Explicitly move model to the determined device and dtype
+if device == "cuda":
+    model.half().to(device)  # Use .half() for float16 on GPU
 else:
-    # CPU load without meta tensors
-    load_kwargs.update(dict(device_map=None))
-
-model = AutoModelForTokenClassification.from_pretrained(MODEL_ID, **load_kwargs)
+    model.to(device)  # For CPU, typically stick to float32 unless model was specifically trained with bfloat16 for CPU
+
 model.eval()

-# Functions
+# Functions (rest of your functions remain mostly the same)
 def tokenize_with_indices(text: str):
     encoded = tokenizer.encode_plus(
         text,
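For context, the load-then-move pattern introduced in this hunk, shown in isolation as a minimal sketch (same MODEL_ID as the app; the final print line is illustrative only and not part of the commit):

    import torch
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    MODEL_ID = "dejanseo/LinkBERT-XL"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)  # plain float32 load, no device_map
    if device == "cuda":
        model = model.half().to(device)  # fp16 only on the GPU
    model.eval()

    print(model.device, next(model.parameters()).dtype)  # e.g. cuda:0 / torch.float16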
@@ -66,6 +71,7 @@ def process_text(inputs: str, confidence_threshold: float):
     with torch.no_grad():
         for chunk in chunk_texts:
             input_ids, token_offsets = tokenize_with_indices(chunk)
+            # Ensure input_ids_tensor is on the same device as the model
             input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(model.device)

             outputs = model(input_ids_tensor)
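tokenize_with_indices is only referenced here, not shown. Presumably it relies on the fast tokenizer's offset mapping; a minimal sketch of that pattern, assuming the tokenizer and model loaded above:

    import torch

    encoded = tokenizer.encode_plus(
        "LinkBERT predicts natural link placement.",
        return_offsets_mapping=True,  # requires a fast tokenizer (use_fast=True above)
        add_special_tokens=True,
    )
    input_ids = encoded["input_ids"]
    token_offsets = encoded["offset_mapping"]  # (start, end) character span for each token

    # Inputs go on whatever device the model lives on before the forward pass
    input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(model.device)
    with torch.no_grad():
        logits = model(input_ids_tensor).logits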
@@ -73,53 +79,77 @@ def process_text(inputs: str, confidence_threshold: float):
             predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
             softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()

+            # The rest of your processing logic
             word_info = {}
             for idx, (start, end) in enumerate(token_offsets):
                 if idx == 0 or idx == len(token_offsets) - 1:
                     continue  # skip specials

                 word_start = start
-                while word_start > 0 and chunk[word_start - 1] != ' ':
+                # Find the actual start of the word corresponding to this token
+                # This logic assumes space-separated words for the most part
+                while word_start > 0 and chunk[word_start - 1] not in [' ', '\n', '\t']:
                     word_start -= 1
+                # If a word_start maps to multiple tokens (e.g., "don't" -> ["don", "'", "t"])
+                # ensure we pick the earliest start for that conceptual word
+                while word_start > 0 and (chunk[word_start-1:word_start] == ' ' or tokenizer.decode(tokenizer.encode(chunk[word_start-1:end], add_special_tokens=False))[0] == chunk[word_start-1]):
+                    word_start -= 1

+                # Use a tuple (word_start, actual_word_text_from_chunk) as key for more robust aggregation
+                # For simplicity here, we stick to word_start
                 if word_start not in word_info:
+                    # Initialize with default for "not link"
                     word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}

                 conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
+
+                # Only mark as 1 if the current token's prediction is 1 AND confidence meets threshold
                 if predictions[idx] == 1 and conf_pct >= confidence_threshold:
-                    word_info[word_start]["prediction"] = 1
+                    word_info[word_start]["prediction"] = 1  # Mark the whole 'word' as a link
+
+                # Keep the max confidence for any token within the 'word'
                 word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
                 word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))

             last_end = 0
+            # Sort by word_start to maintain order
             for word_start in sorted(word_info.keys()):
                 word_data = word_info[word_start]
-                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
+                # Sort subtokens to ensure they are processed in order within a word
+                for subtoken_start, subtoken_end, subtoken_text in sorted(word_data["subtokens"], key=lambda x: x[0]):
                     escaped = subtoken_text.replace("$", "\\$")
+                    # Add any text between the last processed token and the current one
                     if last_end < subtoken_start:
                         reconstructed_text += chunk[last_end:subtoken_start]
+
                     if word_data["prediction"] == 1:
+                        # Apply highlight to the subtoken
                         reconstructed_text += (
-                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
+                            f"<span style='background-color: rgba(0, 255, 0, 0.5); display: inline;'>{escaped}</span>"  # Added alpha for better readability
                         )
                     else:
-                        reconstructed_text += escaped
+                        reconstructed_text += escaped  # No highlight
+
                     last_end = subtoken_end

+                    # For DataFrame, append the info for each *subtoken*
                     df_data["Word"].append(escaped)
-                    df_data["Prediction"].append(word_data["prediction"])
-                    df_data["Confidence"].append(word_info[word_start]["confidence"])
+                    df_data["Prediction"].append(word_data["prediction"])  # Prediction applies to the whole conceptual word
+                    df_data["Confidence"].append(word_data["confidence"])  # Confidence applies to the whole conceptual word
                     df_data["Start"].append(subtoken_start + original_position_offset)
                     df_data["End"].append(subtoken_end + original_position_offset)

-            original_position_offset += len(chunk) + 1
-
-            reconstructed_text += chunk[last_end:].replace("$", "\\$")
+            # Add any remaining text from the current chunk after the last token
+            if last_end < len(chunk):
+                reconstructed_text += chunk[last_end:].replace("$", "\\$")
+
+            # Update offset for the next chunk. Add 1 for the space that was implicitly there.
+            original_position_offset += len(chunk) + 1

     df_tokens = pd.DataFrame(df_data)
     return reconstructed_text, df_tokens

-# UI
+# UI (remains the same)
 st.title("LinkBERT")
 st.markdown("""
 LinkBERT predicts natural link placement within web content. Enter text or a URL for extraction. Increase the threshold to reduce link predictions.
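The per-word aggregation added in this hunk (walk back to the preceding whitespace, mark the whole word as a link if any token clears the threshold, keep the maximum confidence) can be sketched on its own. The function name and sample values below are illustrative, not from the app:

    def aggregate_words(chunk, token_offsets, predictions, confidences, threshold):
        word_info = {}
        for (start, end), pred, conf in zip(token_offsets, predictions, confidences):
            word_start = start
            while word_start > 0 and chunk[word_start - 1] not in (' ', '\n', '\t'):
                word_start -= 1  # walk back to the word boundary
            info = word_info.setdefault(word_start, {"prediction": 0, "confidence": 0.0})
            if pred == 1 and conf >= threshold:
                info["prediction"] = 1  # one confident link token marks the whole word
            info["confidence"] = max(info["confidence"], conf)
        return word_info

    # "here" is predicted as a link at 92.5%, above the 50% threshold
    print(aggregate_words("click here now", [(0, 5), (6, 10), (11, 14)], [0, 1, 0], [40.0, 92.5, 30.0], 50))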
@@ -130,22 +160,29 @@ confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50)
 tab1, tab2 = st.tabs(["Text Input", "URL Input"])

 with tab1:
-    user_input = st.text_area("Enter text to process:")
+    user_input = st.text_area("Enter text to process:", height=200)  # Added height for better UX
     if st.button("Process Text"):
-        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
-        st.markdown(highlighted_text, unsafe_allow_html=True)
-        st.dataframe(df_tokens)
+        if user_input:  # Ensure input is not empty
+            highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
+            st.markdown(highlighted_text, unsafe_allow_html=True)
+            st.dataframe(df_tokens)
+        else:
+            st.warning("Please enter some text to process.")

 with tab2:
-    url_input = st.text_input("Enter URL to process:")
+    url_input = st.text_input("Enter URL to process:", value="https://dejan.ai/blog/gpt-5-made-seo-irreplaceable/")  # Pre-fill with example
     if st.button("Fetch and Process"):
-        content = fetch_and_extract_content(url_input)
-        if content:
-            highlighted_text, df_tokens = process_text(content, confidence_threshold)
-            st.markdown(highlighted_text, unsafe_allow_html=True)
-            st.dataframe(df_tokens)
+        if url_input:  # Ensure URL input is not empty
+            with st.spinner("Fetching and processing content..."):
+                content = fetch_and_extract_content(url_input)
+                if content:
+                    highlighted_text, df_tokens = process_text(content, confidence_threshold)
+                    st.markdown(highlighted_text, unsafe_allow_html=True)
+                    st.dataframe(df_tokens)
+                else:
+                    st.error("Could not fetch content from the URL. Please check the URL and try again.")
         else:
-            st.error("Could not fetch content from the URL. Please check the URL and try again.")
+            st.warning("Please enter a URL to process.")

 st.divider()
 st.markdown("""
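fetch_and_extract_content is called in this hunk but not touched by the commit. A plausible minimal implementation with trafilatura (an assumption; the app's actual helper may differ):

    import trafilatura

    def fetch_and_extract_content(url: str):
        # Both calls return None on failure, which the UI branch above checks for
        downloaded = trafilatura.fetch_url(url)
        if downloaded is None:
            return None
        return trafilatura.extract(downloaded)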
@@ -165,4 +202,4 @@ LinkBERT was fine-tuned on a dataset of organic web content and editorial links.
 Interested in using this in an automated pipeline for bulk link prediction?

 Please [book an appointment](https://dejanmarketing.com/conference/).
-""")
+""")
 