kirudang committed on
Commit fc6dcab · 1 Parent(s): c20b0f9

Sync SafeSeal app

Files changed (5):
  1. README.md +66 -7
  2. SynthID_randomization.py +294 -0
  3. app.py +364 -0
  4. requirements.txt +10 -0
  5. utils_final.py +1213 -0
README.md CHANGED
@@ -1,13 +1,72 @@
 ---
- title: SafeSeal
- emoji: 🏆
- colorFrom: purple
 colorTo: purple
- sdk: gradio
- sdk_version: 5.49.1
 app_file: app.py
 pinned: false
- short_description: Demo for SafeSeal watermarking method.
 ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 ---
+ title: SafeSeal Watermark
+ emoji: 🔒
+ colorFrom: blue
 colorTo: purple
+ sdk: streamlit
+ sdk_version: 1.50.0
 app_file: app.py
 pinned: false
+ license: mit
 ---

+ # SafeSeal Watermark
+
+ **Content-Preserving Watermarking for Large Language Model Deployments.**
+
+ Generate watermarked text by key-conditioned sampling of context-aware synonyms for eligible words.
+
+ ## Features
+
+ - 🔑 **Secret Key**: Deterministic watermarking with a user-controlled key
+ - 📊 **BERTScore Filtering**: Adjustable similarity threshold (0.0 - 1.0)
+ - 🏆 **Tournament Sampling**: Select synonyms using tournament-based randomization
+ - ✨ **Visual Highlighting**: See exactly which words were changed
+ - 🚀 **GPU Support**: Fast inference with automatic GPU detection
+ - 🛡️ **Smart Filtering**: Excludes antonyms and same-category nouns, and preserves entity names
+
+ ## How It Works
+
+ 1. **Entity Detection**: Extracts eligible words (nouns, verbs, adjectives, adverbs) while skipping named entities
+ 2. **Candidate Generation**: Uses RoBERTa-base to generate semantically similar alternatives
+ 3. **BERTScore Filtering**: Evaluates candidates against a similarity threshold
+ 4. **Tournament Selection**: Deterministically selects replacements based on the secret key
+ 5. **Visualization**: Highlights changed words in the output
+
+ ## Usage
+
+ 1. Enter your text in the left panel
+ 2. Adjust hyperparameters in the sidebar:
+    - **Secret Key**: Used for deterministic randomization
+    - **Threshold**: Similarity threshold (default: 0.98)
+    - **Tournament parameters**: Fine-tune the selection process
+ 3. Click "🚀 Generate Watermark"
+ 4. View the watermarked text with highlighted changes
+
+ ## Parameters
+
+ - **Secret Key**: Used for deterministic randomization
+ - **Threshold (0.98)**: BERTScore similarity threshold; higher means more conservative changes
+ - **m (10)**: Number of tournament rounds
+ - **c (2)**: Competitors per tournament match
+ - **h (6)**: Context size (number of left tokens to consider)
+ - **Alpha (1.1)**: Temperature scaling factor
+
+ ## Technical Details
+
+ - **Model**: RoBERTa-base for masked language modeling
+ - **Similarity Scoring**: BERTScore F1 scores
+ - **Selection**: Tournament-based deterministic sampling
+ - **Filtering**: POS tag matching, antonym exclusion, semantic compatibility checks
+
+ ## Example
+
+ **Input:**
+ > "The quick brown fox jumps over the lazy dog."
+
+ **Watermarked Output:**
+ > "The swift brown fox leaps over the idle dog."
+
+ Changed words highlighted: swift (was quick), leaps (was jumps), idle (was lazy)
+
+ ⚠️ **Demo Version**: This is a demonstration using a lightweight model to showcase the watermarking pipeline. Results may not be perfect and are intended for testing purposes only.
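The Alpha parameter above scales a softmax over candidate similarity scores before tournament sampling. A minimal standalone sketch of that normalization step (illustrative only, not the app's exact code):

```python
import math

def softmax(scores, alpha=1.1):
    # Temperature-scaled softmax; subtracting the max keeps exp() numerically stable
    m = max(scores)
    zs = [math.exp(alpha * (s - m)) for s in scores]
    total = sum(zs)
    return [z / total for z in zs]

# Higher-similarity candidates receive proportionally more probability mass
probs = softmax([0.82, 0.55, 0.60], alpha=1.1)
print(abs(sum(probs) - 1.0) < 1e-9, probs.index(max(probs)))  # → True 0
```

Larger Alpha sharpens the distribution toward the most similar candidate; Alpha near 0 flattens it toward uniform sampling.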
SynthID_randomization.py ADDED
@@ -0,0 +1,294 @@
+ # Sampling with Tournament
+ # Input:
+ #   original: original word
+ #   alternatives: substitute candidates
+ #   similarity: similarity scores
+ #   context: 1-token left context
+ #   key: secret key for hashing
+ #   m: number of tournament rounds (set to 2)
+ #   c: number of candidates per match, not including the original; c can be set to the number of alternatives K
+ #   alpha: temperature-like scaling factor (set to 1)
+
+ # Process:
+ # 1. Normalize the similarity scores to [0,1] through a softmax with scaling factor alpha
+ # 2. Create a random score for each candidate in each of the m rounds, giving a matrix of shape (K, m)
+ # 3. Each score is generated by hashing the key, context, candidate, and round index
+ # 4. For the first round, sample M = c^m candidates from the normalized similarity scores (each candidate can appear multiple times) and divide them into groups of size c
+ # 5. In each group, the candidate with the highest random score advances to the next round; ties are broken randomly. For example, with scores 1 1 0, break the 1 1 tie randomly.
+ # 6. In the next round, use the winners from the previous round with that round's column of scores. Repeat until one winner remains.
+
+ # Output: the index of the winning candidate
+
+
+ import hmac
+ import hashlib
+ import math
+ from typing import List, Tuple
+
+ def _softmax(xs: List[float], alpha: float = 1.0) -> List[float]:
+     # Temperature-scaled softmax, normalized to sum to 1
+     if not xs:
+         return []
+     # Subtract the max for numerical stability
+     m = max(xs)
+     zs = [math.exp(alpha * (x - m)) for x in xs]
+     s = sum(zs)
+     return [z / s for z in zs]
+
+ def _hash_bytes(key: str, payload: str) -> bytes:
+     return hmac.new(key.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).digest()
+
+ def _hash_to_uniform01(key: str, payload: str) -> float:
+     # Map the first 8 digest bytes to [0,1) as a 64-bit integer / 2**64
+     b = _hash_bytes(key, payload)[:8]
+     n = int.from_bytes(b, "big", signed=False)
+     return n / 2**64
+
+ def _ctx_str(context: List[str]) -> str:
+     # Use the last token (widen to the last 4 if desired)
+     toks = (context or [])
+     return " ".join(toks[-1:]).lower()
+
+ def _per_round_score(key: str, ctx: str, candidate: str, round_idx: int) -> int:
+     # Deterministic "random" score following Bernoulli(0.5): 0 or 1
+     payload = f"score::{ctx}::{candidate}::r{round_idx}"
+     uniform_score = _hash_to_uniform01(key, payload)
+     # Convert uniform [0,1) to Bernoulli(0.5): 0 if < 0.5, 1 if >= 0.5
+     return 1 if uniform_score >= 0.5 else 0
+
+ def _sample_categorical_from_uniform(u: float, probs: List[float]) -> int:
+     # Inverse-CDF sampling from probs using a uniform u in [0,1)
+     c = 0.0
+     for i, p in enumerate(probs):
+         c += p
+         if u < c:
+             return i
+     return len(probs) - 1  # fall back to the last index due to floating-point sums
+
+ def _draw_M_candidates(key: str, ctx: str, probs: List[float], K: int, c: int, m: int) -> List[int]:
+     """
+     Draw M = c^m candidates (indices 0..K-1) with replacement from probs,
+     using a deterministic stream of uniforms keyed by (key, ctx).
+     """
+     M = c ** m
+     picks = []
+     for draw_idx in range(M):
+         u = _hash_to_uniform01(key, f"draw::{ctx}::m{m}::c{c}::i{draw_idx}")
+         j = _sample_categorical_from_uniform(u, probs)
+         picks.append(j)
+     return picks
+
+ def _run_tournament_round(
+     key: str,
+     ctx: str,
+     picks: List[int],
+     candidates: List[str],
+     round_idx: int,
+     group_size: int,
+ ) -> List[int]:
+     """
+     Split 'picks' into consecutive groups of size 'group_size'.
+     For each group, select the winner = argmax of the per-round score;
+     ties are broken deterministically from the keyed hash stream.
+     Returns the list of winning indices (into the candidates list).
+     """
+     assert len(picks) % group_size == 0, "Group partition must divide evenly."
+     winners = []
+
+     for g in range(0, len(picks), group_size):
+         group = picks[g:g+group_size]  # indices into candidates list
+         group_num = g // group_size + 1
+
+         # Compute per-round scores
+         scored: List[Tuple[int, int]] = []
+         for idx in group:
+             cand = candidates[idx]
+             s = _per_round_score(key, ctx, cand, round_idx)
+             scored.append((s, idx))
+
+         # Find the max score; break ties with a deterministic keyed draw
+         max_s = max(s for s, _ in scored)
+         tied = [idx for (s, idx) in scored if s == max_s]
+
+         if len(tied) == 1:
+             winners.append(tied[0])
+         else:
+             # Deterministic tie-breaker: draw a keyed uniform and index into the
+             # tied list, avoiding any mutation of Python's global random state
+             u = _hash_to_uniform01(key, f"tiebreak::{ctx}::r{round_idx}::g{group_num}")
+             winners.append(tied[int(u * len(tied))])
+
+     return winners
+
+ def _show_score_matrix(key: str, ctx: str, candidates: List[str], m: int):
+     """
+     Display the complete score matrix for all candidates across all rounds.
+     Each candidate gets a Bernoulli(0.5) score for each round.
+     """
+     # print("\n=== SCORE MATRIX (Bernoulli 0.5) ===")
+     # print("Format: candidate -> [round1_score, round2_score, ...]")
+     # for candidate in candidates:
+     #     scores = [_per_round_score(key, ctx, candidate, round_idx)
+     #               for round_idx in range(1, m + 1)]
+     #     print(f"{candidate:8} -> {scores}")
+     # print("=" * 50)
+     pass
+
+ def tournament_randomize(
+     original: str,
+     alternatives: List[str],
+     similarity: List[float],
+     context: List[str],
+     key: str,
+     m: int = 2,
+     c: int = 2,
+     alpha: float = 1.0,
+ ) -> int:
+     """
+     Tournament sampling among alternatives only (K = len(alternatives)).
+     Returns the WINNER INDEX (0..K-1) into 'alternatives'.
+
+     Steps:
+     1) softmax(similarity, alpha) -> probs over alternatives
+     2) build per-round scores for each candidate via HMAC (computed on the fly)
+     3) Round 1: draw M = c^m picks from probs; group into blocks of size c; pick per-group winners by round-1 scores
+     4) Later rounds: group the winners into blocks of size c; pick winners using that round's scores
+     5) Repeat until one winner remains
+     """
+     assert len(alternatives) == len(similarity) > 0, "Need at least one alternative with a similarity score."
+     assert m >= 1, "Tournament rounds m must be >= 1"
+     assert c >= 2, "Number of competitors per match c must be >= 2"
+
+     K = len(alternatives)
+     ctx = _ctx_str(context)
+     probs = _softmax(similarity, alpha=alpha)
+
+     # First round: M = c^m draws from probs (indices 0..K-1 into alternatives)
+     picks = _draw_M_candidates(key, ctx, probs, K, c, m)
+
+     # We treat 'candidates' as the alternatives list; indices map directly
+     candidates = alternatives
+
+     # Optionally show the complete score matrix for all candidates across all rounds
+     _show_score_matrix(key, ctx, candidates, m)
+
+     # Run m rounds. Each round groups the current list into blocks of size c
+     # and picks one winner per group using that round's scores.
+     current = picks
+     for r in range(1, m + 1):
+         current = _run_tournament_round(
+             key=key,
+             ctx=ctx,
+             picks=current,
+             candidates=candidates,
+             round_idx=r,
+             group_size=c,
+         )
+
+     # Each round shrinks the field by a factor of c, so after m rounds
+     # exactly one winner must remain.
+     assert len(current) == 1, f"Expected a single winner after {m} rounds, got {len(current)}"
+
+     final_winner_idx = current[0]
+     return final_winner_idx
+
+ # Convenience wrapper that returns the selected word
+ def tournament_select_word(
+     original: str,
+     alternatives: List[str],
+     similarity: List[float],
+     context: List[str],
+     key: str,
+     m: int = 2,
+     c: int = 2,
+     alpha: float = 1.0,
+ ) -> str:
+     idx = tournament_randomize(original, alternatives, similarity, context, key, m, c, alpha)
+     return alternatives[idx]
+
+
+ # # Sample run
+ # original = "big"
+ # alts = ["large", "huge", "massive"]   # K = 3
+ # sims = [0.82, 0.55, 0.60]             # any similarity scores
+ # ctx = ["a"]                           # 1-token left context (e.g., "... a ___")
+ # key = "super-secret-key"
+ # m = 2
+ # c = 2
+ # alpha = 1.0
+
+ # winner_idx = tournament_randomize(original, alts, sims, ctx, key, m, c, alpha)
+ # print("Winner index:", winner_idx, "->", alts[winner_idx])
+
+ # # Or directly:
+ # print("Winner word:", tournament_select_word(original, alts, sims, ctx, key, m, c, alpha))
+
+ # # Test with a different key
+ # key2 = "different-secret-key"
+ # winner_idx2 = tournament_randomize(original, alts, sims, ctx, key2, m, c, alpha)
+ # print("Winner index with different key:", winner_idx2, "->", alts[winner_idx2])
+
+ # # Test with another key
+ # key3 = "key224242"
+ # winner_idx3 = tournament_randomize(original, alts, sims, ctx, key3, m, c, alpha)
+ # print("Winner index with different key:", winner_idx3, "->", alts[winner_idx3])
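The per-round scores in this file come from an HMAC keyed by the secret, which is what makes the sampling deterministic and replayable. A standalone check of that primitive (re-implemented here to mirror `_hash_to_uniform01` above, so treat it as a sketch):

```python
import hmac
import hashlib

def hash_to_uniform01(key: str, payload: str) -> float:
    # First 8 digest bytes interpreted as a 64-bit integer, mapped into [0, 1)
    digest = hmac.new(key.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).digest()
    return int.from_bytes(digest[:8], "big") / 2**64

# The same (key, payload) pair always yields the same score, so a holder of the
# key can replay the exact tournament; payloads encode context, candidate, round.
u_same = hash_to_uniform01("super-secret-key", "score::a::large::r1")
u_repeat = hash_to_uniform01("super-secret-key", "score::a::large::r1")
print(u_same == u_repeat, 0.0 <= u_same < 1.0)  # → True True
```

A different key produces an unrelated stream of uniforms, which is why the watermark cannot be reproduced or detected without the secret.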
app.py ADDED
@@ -0,0 +1,364 @@
+ import streamlit as st
+ import os
+ import torch
+ from transformers import RobertaTokenizer, RobertaForMaskedLM
+ import spacy
+ import subprocess
+ import sys
+ import nltk
+ from nltk.tokenize import word_tokenize
+ from utils_final import extract_entities_and_pos, whole_context_process_sentence
+
+ # Download NLTK data if not already available
+ def setup_nltk():
+     """Download the required NLTK data packages, ignoring failures."""
+     for package in ('punkt_tab', 'averaged_perceptron_tagger_eng', 'wordnet', 'omw-1.4'):
+         try:
+             nltk.download(package, quiet=True)
+         except Exception:
+             pass
+
+ setup_nltk()
+
+ # Set environment
+ cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'
+
+ # Load the spaCy model, downloading it if not available
+ try:
+     nlp = spacy.load("en_core_web_sm")
+ except OSError:
+     print("Downloading spaCy model...")
+     subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
+     nlp = spacy.load("en_core_web_sm")
+
+ # Define apply_replacements (from Safeseal_gen_final.py)
+ def apply_replacements(sentence, replacements):
+     """
+     Apply replacements to the sentence while preserving original formatting,
+     spacing, and punctuation.
+     """
+     doc = nlp(sentence)  # Tokenize the sentence
+     tokens = [token.text_with_ws for token in doc]  # Preserve each token's trailing whitespace
+
+     # Apply replacements based on token positions
+     for position, target, replacement in replacements:
+         if position < len(tokens) and tokens[position].strip() == target:
+             tokens[position] = replacement + (" " if tokens[position].endswith(" ") else "")
+
+     # Reassemble the sentence
+     return "".join(tokens)
+
+ # Cache the model across Streamlit reruns
+ @st.cache_resource
+ def load_model():
+     """Load the tokenizer and model (cached to avoid reloading on every rerun)."""
+     print("Loading model...")
+     tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+     lm_model = RobertaForMaskedLM.from_pretrained('roberta-base', attn_implementation="eager")
+
+     tokenizer.model_max_length = 512
+     tokenizer.max_len = 512
+
+     if hasattr(lm_model.config, 'max_position_embeddings'):
+         lm_model.config.max_position_embeddings = 512
+
+     lm_model.eval()
+
+     if torch.cuda.is_available():
+         lm_model = lm_model.cuda()
+         print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
+     else:
+         print("Model loaded on CPU")
+
+     return tokenizer, lm_model
+
+ sampling_results = []
+
+ def process_text_wrapper(text, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, batch_size=32, max_length=512, similarity_context_mode='whole'):
+     """
+     Wrapper that processes text and returns the watermarked output,
+     tracking which words were changed.
+     """
+     global sampling_results
+     sampling_results = []
+
+     lines = text.splitlines(keepends=True)
+     final_text = []
+     total_randomized_words = 0
+     total_words = len(word_tokenize(text))
+
+     # Track changed words and their positions
+     changed_words = []  # List of (original, replacement, position)
+
+     for line in lines:
+         if line.strip():
+             replacements, sampling_results_line = whole_context_process_sentence(
+                 text,
+                 line.strip(),
+                 tokenizer, lm_model, Top_K, threshold,
+                 secret_key, m, c, h, alpha, "output",
+                 batch_size=batch_size, max_length=max_length, similarity_context_mode=similarity_context_mode
+             )
+
+             sampling_results.extend(sampling_results_line)
+
+             if replacements:
+                 randomized_line = apply_replacements(line, replacements)
+                 final_text.append(randomized_line)
+
+                 # Track ONLY actual changes (where original != replacement)
+                 for position, original, replacement in replacements:
+                     if original != replacement:
+                         changed_words.append((original, replacement, position))
+                         total_randomized_words += 1
+             else:
+                 final_text.append(line)
+         else:
+             final_text.append(line)
+
+     return "".join(final_text), total_randomized_words, total_words, changed_words, sampling_results
+
+ def create_html_with_highlights(original_text, watermarked_text, changed_words_info, sampling_results):
+     """
+     Create HTML with highlighted changed words using spaCy tokenization.
+     """
+     # Collect replacement words that actually differ from their originals
+     actual_replacements = set()
+     replacement_to_original = {}
+
+     for original, replacement, _ in changed_words_info:
+         if original.lower() != replacement.lower():  # Only map actual changes
+             actual_replacements.add(replacement.lower())
+             replacement_to_original[replacement.lower()] = original
+
+     # Parse the watermarked text with spaCy
+     doc_watermarked = nlp(watermarked_text)
+
+     # Build HTML by processing the watermarked text token by token
+     result_html = []
+     words_highlighted = set()  # Track highlighted words to avoid duplicates
+
+     for token in doc_watermarked:
+         text = token.text_with_ws
+         text_clean = token.text.strip('.,!?;:')
+         text_lower = text_clean.lower()
+
+         # Only highlight if this word is in the replacements set
+         # and we haven't already highlighted this exact word
+         if text_lower in actual_replacements and text_lower not in words_highlighted:
+             original_word = replacement_to_original.get(text_lower, text_clean)
+
+             # Only highlight if actually different from the original
+             if original_word.lower() != text_lower:
+                 tooltip = f"Original: {original_word} → New: {text_clean}"
+                 highlighted_text = f"<mark style='background: linear-gradient(120deg, #84fab0 0%, #8fd3f4 100%); padding: 2px 6px; border-radius: 4px; font-weight: 500; box-shadow: 0 1px 2px rgba(0,0,0,0.1);' title='{tooltip}'>{text_clean}</mark>"
+
+                 # Preserve trailing whitespace and punctuation
+                 if text != text_clean:
+                     highlighted_text += text[len(text_clean):]
+
+                 result_html.append(highlighted_text)
+                 words_highlighted.add(text_lower)  # Mark as highlighted
+             else:
+                 result_html.append(text)
+         else:
+             result_html.append(text)
+
+     # Return just the inner content; the caller adds the outer div
+     return "".join(result_html)
+
+ # Streamlit UI
+ def main():
+     st.set_page_config(
+         page_title="Watermarked Text Generator",
+         page_icon="🔒",
+         layout="wide"
+     )
+
+     # Centered and styled title
+     st.markdown(
+         """
+         <div style="text-align: center; margin-bottom: 10px;">
+             <h1 style="color: #4A90E2; font-size: 2.5rem; font-weight: bold; margin: 0;">
+                 🔒 SafeSeal Watermark
+             </h1>
+         </div>
+         <div style="text-align: center; margin-bottom: 20px; color: #666; font-size: 1.1rem;">
+             Content-Preserving Watermarking for Large Language Model Deployments.
+         </div>
+         """,
+         unsafe_allow_html=True
+     )
+
+     # Separator
+     st.markdown("---")
+
+     # Sidebar for hyperparameters
+     with st.sidebar:
+         st.markdown("### ⚙️ Hyperparameters")
+         st.caption("Configure the watermarking algorithm")
+
+         # Main inputs
+         secret_key = st.text_input(
+             "🔑 Secret Key",
+             value="My_Secret_Key",
+             help="Secret key for deterministic randomization"
+         )
+
+         threshold = st.slider(
+             "📊 Similarity Threshold",
+             min_value=0.0,
+             max_value=1.0,
+             value=0.98,
+             step=0.01,
+             help="BERTScore similarity threshold (higher = more similar replacements)"
+         )
+
+         st.divider()
+
+         # Tournament sampling parameters
+         st.markdown("### 🏆 Tournament Sampling")
+         st.caption("Control the randomization process")
+
+         # Hidden Top_K parameter (default 6)
+         Top_K = 6
+
+         m = st.number_input(
+             "m (Tournament Rounds)",
+             min_value=1,
+             max_value=20,
+             value=10,
+             help="Number of tournament rounds"
+         )
+
+         c = st.number_input(
+             "c (Competitors per Round)",
+             min_value=2,
+             max_value=10,
+             value=2,
+             help="Number of competitors per tournament match"
+         )
+
+         h = st.number_input(
+             "h (Context Size)",
+             min_value=1,
+             max_value=20,
+             value=6,
+             help="Number of left context tokens to consider"
+         )
+
+         alpha = st.slider(
+             "Alpha (Temperature)",
+             min_value=0.1,
+             max_value=5.0,
+             value=1.1,
+             step=0.1,
+             help="Temperature scaling factor for softmax"
+         )
+
+     # Main content area
+     col1, col2 = st.columns(2)
+
+     # Load the model once and keep it in session state
+     if 'tokenizer' not in st.session_state:
+         with st.spinner("Loading model... This may take a minute"):
+             tokenizer, lm_model = load_model()
+             st.session_state.tokenizer = tokenizer
+             st.session_state.lm_model = lm_model
+
+     with col1:
+         st.markdown("### 📝 Input Text")
+         input_text = st.text_area(
+             "Enter text to watermark",
+             height=400,
+             placeholder="Paste your text here to generate a watermarked version...",
+             label_visibility="collapsed"
+         )
+
+         # Process button at the bottom of the input column
+         if st.button("🚀 Generate Watermark", type="primary", use_container_width=True):
+             if not input_text or len(input_text.strip()) == 0:
+                 st.warning("Please enter some text to watermark.")
+             else:
+                 with st.spinner("Generating watermarked text... This may take a few moments"):
+                     try:
+                         # Process the text
+                         watermarked_text, total_randomized_words, total_words, changed_words, sampling_results = process_text_wrapper(
+                             input_text,
+                             st.session_state.tokenizer,
+                             st.session_state.lm_model,
+                             Top_K=int(Top_K),
+                             threshold=float(threshold),
+                             secret_key=secret_key,
+                             m=int(m),
+                             c=int(c),
+                             h=int(h),
+                             alpha=float(alpha),
+                             batch_size=32,
+                             max_length=512,
+                             similarity_context_mode='whole'
+                         )
+
+                         # Store results in session state
+                         st.session_state.watermarked_text = watermarked_text
+                         st.session_state.changed_words = changed_words
+                         st.session_state.sampling_results = sampling_results
+                         st.session_state.total_randomized = total_randomized_words
+                         st.session_state.total_words = total_words
+
+                         st.success(f"Watermark generated! Changed {total_randomized_words} out of {total_words} words ({100*total_randomized_words/max(total_words,1):.1f}%)")
+                     except Exception as e:
+                         st.error(f"Error generating watermark: {str(e)}")
+                         import traceback
+                         st.code(traceback.format_exc())
+
+     with col2:
+         st.markdown("### 🔒 Watermarked Text")
+
+         # Display the watermarked text with highlights
+         if 'watermarked_text' in st.session_state:
+             highlight_html = create_html_with_highlights(
+                 input_text,
+                 st.session_state.watermarked_text,
+                 st.session_state.changed_words,
+                 st.session_state.sampling_results
+             )
+             # Show the highlighted version with a border, wrapping the complete HTML
+             full_html = f"""
+             <div style='padding: 15px; background-color: #f8f9fa; border-radius: 8px; border: 1px solid #e0e0e0; min-height: 400px; max-height: 400px; overflow-y: auto; line-height: 1.8; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; font-size: 15px; white-space: pre-wrap; word-wrap: break-word;'>
+             {highlight_html}
+             </div>
+             """
+             st.markdown(full_html, unsafe_allow_html=True)
+         else:
+             st.info("👈 Enter text in the left panel and click 'Generate Watermark' to start")
+
+     # Footer
+     st.divider()
+     st.caption("🔒 Secure AI Watermarking Tool | Built with SafeSeal")
+
+     # Demo warning at the bottom
+     st.markdown(
+         """
+         <div style="text-align: center; margin-top: 20px; padding: 10px; font-size: 0.85rem; color: #666;">
+             ⚠️ <strong>Demo Version</strong>: This is a demonstration using a lightweight model to showcase the watermarking pipeline.
+             Results may not be perfect and are intended for testing purposes only.
+         </div>
+         """,
+         unsafe_allow_html=True
+     )
+
+ if __name__ == "__main__":
+     # Run the app
+     main()
+
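The `apply_replacements` function above relies on spaCy's `text_with_ws` to keep spacing intact while swapping tokens by position. The same whitespace-preserving idea can be sketched with a plain regex tokenizer (illustrative only; spaCy also splits punctuation into separate tokens, which this sketch does not):

```python
import re

def simple_replace(sentence, replacements):
    # Tokenize into (token, trailing_whitespace) pairs, mirroring spaCy's text_with_ws
    parts = re.findall(r"(\S+)(\s*)", sentence)
    tokens = [t + ws for t, ws in parts]
    # Replace only when both the position and the target word match
    for position, target, replacement in replacements:
        if position < len(tokens) and tokens[position].strip() == target:
            trailing = " " if tokens[position].endswith(" ") else ""
            tokens[position] = replacement + trailing
    return "".join(tokens)

print(simple_replace("The quick brown fox", [(1, "quick", "swift")]))  # → The swift brown fox
```

Checking both position and target before replacing guards against tokenization drift: if positions computed upstream no longer line up with the text, the mismatched replacement is silently skipped rather than corrupting the output.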
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ streamlit>=1.28.0
+ torch>=2.0.0
+ transformers>=4.30.0
+ spacy>=3.4.0
+ nltk>=3.8.0
+ bert-score>=0.3.13
+ pandas>=2.0.0
+ numpy>=1.24.0
+ scikit-learn>=1.3.0
+
utils_final.py ADDED
@@ -0,0 +1,1213 @@
import os
import re
import time
import math
import torch
import string
import spacy
import pandas as pd
import numpy as np
import nltk
import sys
import subprocess
from nltk.tokenize import word_tokenize
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.corpus import wordnet as wn
import json
from filelock import FileLock
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from functools import lru_cache
from typing import List, Tuple, Dict, Any
import multiprocessing as mp

# Ensure the HF_HOME environment variable points to your desired cache location
# Token removed for security
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Handle potential import conflicts with sentence_transformers
try:
    # Try to import bert_score directly to avoid sentence_transformers conflicts
    from bert_score import score as bert_score
    SIMILARITY_AVAILABLE = True

    def calc_scores_bert(original_sentence, substitute_sentences):
        """BERTScore function using direct bert_score import."""
        try:
            # Safety check: truncate inputs if they're too long
            max_chars = 2000  # Roughly 500 tokens
            if len(original_sentence) > max_chars:
                original_sentence = original_sentence[:max_chars]

            truncated_substitutes = []
            for sub in substitute_sentences:
                if len(sub) > max_chars:
                    sub = sub[:max_chars]
                truncated_substitutes.append(sub)

            references = [original_sentence] * len(truncated_substitutes)
            P, R, F1 = bert_score(
                cands=truncated_substitutes,
                refs=references,
                model_type="bert-base-uncased",
                verbose=False
            )
            return F1.tolist()
        except Exception:
            return [0.5] * len(substitute_sentences)

    def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
        """Similarity function using direct bert_score import."""
        if method == 'bert':
            return calc_scores_bert(original_sentence, substitute_sentences)
        else:
            return [0.5] * len(substitute_sentences)

except ImportError as e:
    print(f"Warning: bert_score import failed: {e}")
    print("Falling back to neutral similarity scores...")
    SIMILARITY_AVAILABLE = False

    def calc_scores_bert(original_sentence, substitute_sentences):
        """Fallback BERTScore function with neutral scores."""
        return [0.5] * len(substitute_sentences)

    def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
        """Fallback similarity function with neutral scores."""
        return [0.5] * len(substitute_sentences)

# Setup NLTK data
def setup_nltk_data():
    """Download required NLTK data, ignoring failures (e.g., offline runs)."""
    for resource in ('punkt_tab', 'averaged_perceptron_tagger_eng', 'wordnet', 'omw-1.4'):
        try:
            nltk.download(resource, quiet=True)
        except Exception:
            pass

setup_nltk_data()

lemmatizer = WordNetLemmatizer()

# Load spaCy model - download if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("Downloading spaCy model...")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")

# Define the detailed whitelist of POS tags (nouns, verbs, adjectives, and adverbs)
DETAILED_POS_WHITELIST = {
    'NN',   # Noun, singular or mass (e.g., dog, car)
    'NNS',  # Noun, plural (e.g., dogs, cars)
    'VB',   # Verb, base form (e.g., run, eat)
    'VBD',  # Verb, past tense (e.g., ran, ate)
    'VBG',  # Verb, gerund or present participle (e.g., running, eating)
    'VBN',  # Verb, past participle (e.g., run, eaten)
    'VBP',  # Verb, non-3rd person singular present (e.g., run, eat)
    'VBZ',  # Verb, 3rd person singular present (e.g., runs, eats)
    'JJ',   # Adjective (e.g., big, blue)
    'JJR',  # Adjective, comparative (e.g., bigger, bluer)
    'JJS',  # Adjective, superlative (e.g., biggest, bluest)
    'RB',   # Adverb (e.g., very, silently)
    'RBR',  # Adverb, comparative (e.g., better)
    'RBS',  # Adverb, superlative (e.g., best)
}

# Global caches for better performance
_pos_cache = {}
_antonym_cache = {}
_word_validity_cache = {}

def extract_entities_and_pos(text):
    """
    Detect eligible tokens for replacement while skipping:
    - Named entities (e.g., names, locations, organizations).
    - Compound words (e.g., "Opteron-based").
    - Phrasal verbs (e.g., "make up", "focus on").
    - Punctuation and non-POS-whitelisted tokens.
    """
    doc = nlp(text)
    sentence_target_pairs = []  # List to hold (sentence, target word, token index)

    for sent in doc.sents:
        for token in sent:
            # Skip named entities using token.ent_type_ (more reliable than a text match)
            if token.ent_type_:
                continue

            # Skip standalone punctuation
            if token.is_punct:
                continue

            # Skip compound words (e.g., "Opteron-based")
            if "-" in token.text or token.dep_ in {"compound", "amod"}:
                continue

            # Skip phrasal verbs (e.g., "make up", "focus on")
            if token.pos_ == "VERB" and any(child.dep_ == "prt" for child in token.children):
                continue

            # Include regular tokens matching the POS whitelist
            if token.tag_ in DETAILED_POS_WHITELIST:
                sentence_target_pairs.append((sent.text, token.text, token.i))

    return sentence_target_pairs

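The filtering rules above can be exercised without loading a spaCy model. The sketch below applies the same skip logic to pre-tagged tokens; the `(text, tag, ent_type, dep, is_punct)` tuple layout and the reduced whitelist are illustrative stand-ins, not the spaCy API.

```python
# Minimal, dependency-free sketch of the eligibility filter in
# extract_entities_and_pos, operating on hand-tagged tokens.
POS_WHITELIST = {'NN', 'NNS', 'VB', 'VBD', 'JJ', 'RB'}

def eligible_tokens(tokens):
    out = []
    for i, (text, tag, ent_type, dep, is_punct) in enumerate(tokens):
        if ent_type or is_punct:                         # skip entities / punctuation
            continue
        if '-' in text or dep in {'compound', 'amod'}:   # skip compounds
            continue
        if tag in POS_WHITELIST:
            out.append((text, i))
    return out

tokens = [
    ("Paris", "NNP", "GPE", "nsubj", False),   # named entity -> skipped
    ("hosts", "VBZ", "", "ROOT", False),       # tag not in reduced whitelist
    ("large", "JJ", "", "amod", False),        # amod dependency -> skipped
    ("events", "NNS", "", "dobj", False),      # eligible
    (".", ".", "", "punct", True),             # punctuation -> skipped
]
print(eligible_tokens(tokens))  # [('events', 3)]
```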
def preprocess_text(text):
    """
    Preprocesses the text to handle abbreviations, titles, and edge cases
    where a period or other punctuation does not signify a sentence end.
    Ensures figures, acronyms, and short names are left untouched.
    """
    # Protect common abbreviations like "U.S." and "Corp."
    text = re.sub(r'\b(U\.S|U\.K|Corp|Inc|Ltd)\.', r'\1<PERIOD>', text)

    # Protect floating-point numbers or ranges like "3.57" or "1.48–2.10"
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<PERIOD>\2', text)

    # Avoid modifying standalone single-letter initials in names (e.g., "J. Smith")
    text = re.sub(r'\b([A-Z])\.(?=\s[A-Z])', r'\1<PERIOD>', text)

    # Protect acronym-like patterns with dots, such as "F.B.I."
    text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<PERIOD>'), text)

    return text

def split_sentences(text):
    """
    Splits text into sentences while preserving original newlines exactly.
    - Protects abbreviations, acronyms, and floating-point numbers.
    - Only adds newlines where necessary without duplicating them.
    """
    # Step 1: Protect abbreviations, floating numbers, acronyms.
    # The trailing period is replaced by the <ABBR> marker so that restoring
    # <ABBR> -> '.' reproduces the original text exactly.
    text = re.sub(r'\b(U\.S|U\.K|Inc|Ltd|Corp|e\.g|i\.e|etc)\.', r'\1<ABBR>', text)
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<FLOAT>\2', text)
    text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<ABBR>'), text)

    # Step 2: Identify sentence boundaries without duplicating newlines
    sentences = []
    for line in text.splitlines(keepends=True):  # Retain original newlines
        # Split only if punctuation marks end a sentence
        split_line = re.split(r'(?<=[.!?])\s+', line.strip())
        sentences.extend([segment + "\n" if line.endswith("\n") else segment for segment in split_line])

    # Step 3: Restore protected patterns
    return [sent.replace('<ABBR>', '.').replace('<FLOAT>', '.') for sent in sentences]

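The protect-split-restore pattern can be exercised in isolation. This is a minimal self-contained sketch with a shortened abbreviation list; the simplifications (no newline handling, fewer protected patterns) are mine:

```python
import re

def split_sentences(text):
    # Protect: swap the sentence-final-looking period for a marker.
    text = re.sub(r'\b(U\.S|e\.g|etc)\.', r'\1<ABBR>', text)
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<FLOAT>\2', text)
    # Split on sentence-ending punctuation followed by whitespace.
    parts = re.split(r'(?<=[.!?])\s+', text.strip())
    # Restore the protected periods.
    return [p.replace('<ABBR>', '.').replace('<FLOAT>', '.') for p in parts]

print(split_sentences("The U.S. GDP grew 2.4 percent. Exports rose."))
# ['The U.S. GDP grew 2.4 percent.', 'Exports rose.']
```

Without the protection step, the regex would also split after "U.S.", cutting the first sentence in half.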
@lru_cache(maxsize=10000)
def is_valid_word(word):
    """Check if a word is valid using WordNet (cached)."""
    return bool(wn.synsets(word))

@lru_cache(maxsize=5000)
def get_word_pos_tags(word):
    """Get POS tags for a word using both NLTK and spaCy (cached)."""
    nltk_pos = nltk.pos_tag([word])[0][1]
    spacy_pos = nlp(word)[0].pos_
    return nltk_pos, spacy_pos

@lru_cache(maxsize=5000)
def get_word_lemma(word):
    """Get lemmatized form of a word (cached)."""
    return lemmatizer.lemmatize(word)

@lru_cache(maxsize=2000)
def get_word_antonyms(word):
    """Get antonyms for a word (cached). Includes all lemmas from all synsets."""
    target_synsets = wn.synsets(word)
    antonyms = set()

    # Get antonyms from all synsets and all lemmas
    for syn in target_synsets:
        for lem in syn.lemmas():
            for ant in lem.antonyms():
                # Add the antonym word (first part before the dot)
                antonyms.add(ant.name().split('.')[0])
                # Also add other lemmas of the antonym for completeness
                for alt_syn in wn.synsets(ant.name().split('.')[0]):
                    for alt_lem in alt_syn.lemmas():
                        antonyms.add(alt_lem.name().split('.')[0])

    return antonyms

def _are_semantically_compatible(target, candidate):
    """
    Check if target and candidate are semantically compatible for replacement.
    Returns False if they are specific nouns in the same category (e.g., different crops, fruits, animals).
    """
    try:
        # Direct check: if target and candidate are both specific terms for
        # crops, animals, etc., check whether they're near-synonyms.

        # Agricultural/crop terms that shouldn't be swapped
        agricultural_terms = ['soybean', 'corn', 'maize', 'wheat', 'rice', 'barley', 'oats', 'sorghum',
                              'millet', 'grain', 'cereal', 'pulse', 'bean', 'legume']

        # If both are agricultural terms and different, block
        if (target.lower() in agricultural_terms and candidate.lower() in agricultural_terms and
                target.lower() != candidate.lower()):
            return False

        target_synsets = wn.synsets(target)
        cand_synsets = wn.synsets(candidate)

        if not target_synsets or not cand_synsets:
            return True  # If no synsets, allow through

        # Check if they're near-synonyms (very similar) - if so, allow.
        # Path similarity tells us whether they're similar enough.
        max_similarity = 0.0
        for t_syn in target_synsets:
            for c_syn in cand_synsets:
                try:
                    similarity = t_syn.path_similarity(c_syn) or 0.0
                    max_similarity = max(max_similarity, similarity)
                except Exception:
                    pass

        # If they have high path similarity (>0.5), they're similar enough to allow
        if max_similarity > 0.5:
            return True

        # Otherwise, check if they share common direct hypernyms
        target_hypernyms = set()
        for syn in target_synsets:
            # Get immediate hypernyms (parent concepts)
            for hypernym in syn.hypernyms():
                target_hypernyms.add(hypernym)

        cand_hypernyms = set()
        for syn in cand_synsets:
            for hypernym in syn.hypernyms():
                cand_hypernyms.add(hypernym)

        # If they share hypernyms, check if they're both specific instances (not general terms)
        common_hypernyms = target_hypernyms & cand_hypernyms

        if common_hypernyms:
            # If both words are specific instances of the same category,
            # they shouldn't be replaced with each other. We identify this by
            # checking whether their shared hypernym has many siblings.
            for hypernym in common_hypernyms:
                siblings = hypernym.hyponyms()
                # If there are many specific instances (e.g., many crops, many fruits),
                # it's likely a category whose members shouldn't be interchanged.
                if len(siblings) > 3:
                    # Check if the hypernym name suggests a specific category
                    hypernym_name = hypernym.name().split('.')[0]
                    category_keywords = [
                        'crop', 'grain', 'fruit', 'animal', 'bird', 'fish', 'company',
                        'country', 'city', 'brand', 'product', 'food', 'vehicle'
                    ]

                    # If the hypernym contains category keywords, these are likely
                    # specific instances that shouldn't be swapped
                    if any(keyword in hypernym_name for keyword in category_keywords):
                        return False

        return True

    except Exception:
        # On any error, allow the candidate through (conservative approach)
        return True

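The category-blocking rule can be illustrated without WordNet. The toy hypernym map and keyword set below are hand-written stand-ins for WordNet's synset hierarchy, reduced to the one decision that matters: distinct specific members of a blocked category must not be swapped.

```python
# Toy stand-in for WordNet hypernym lookup (illustrative only).
HYPERNYM = {"soybean": "crop", "wheat": "crop", "harvest": "activity", "yield": "activity"}
CATEGORY_KEYWORDS = {"crop", "fruit", "animal"}

def compatible(target, candidate):
    """Block swaps between distinct specific members of a keyword category."""
    t, c = HYPERNYM.get(target), HYPERNYM.get(candidate)
    if t and t == c and t in CATEGORY_KEYWORDS and target != candidate:
        return False
    return True

print(compatible("soybean", "wheat"))   # False: different crops
print(compatible("harvest", "yield"))   # True: 'activity' isn't a blocked category
```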
def create_context_windows(full_text, target_sentence, target_word, tokenizer, max_tokens=400):
    """
    Create context windows around the target sentence for better MLM generation.
    Intelligently handles tokenizer length limits by preserving the most relevant context.

    Args:
        full_text: The complete document text
        target_sentence: The sentence containing the target word
        target_word: The word to be replaced
        tokenizer: The tokenizer to check length limits
        max_tokens: Maximum tokens to use for context (leave room for instruction + mask)

    Returns:
        List of context windows with different levels of context
    """
    # Split full text into sentences
    sentences = split_sentences(full_text)

    # Find the target sentence index
    target_sentence_idx = None
    for i, sent in enumerate(sentences):
        if target_sentence.strip() in sent.strip():
            target_sentence_idx = i
            break

    if target_sentence_idx is None:
        return [target_sentence]  # Fallback to original sentence

    # Create context windows with a sentence-prioritized approach
    context_windows = []

    # Window 1: Just the target sentence (always include)
    context_windows.append(target_sentence)

    # Windows 2-4: Target sentence plus 1, 2, or 3 sentences on each side (if they fit)
    for radius in (1, 2, 3):
        start_idx = max(0, target_sentence_idx - radius)
        end_idx = min(len(sentences), target_sentence_idx + radius + 1)
        context_window = " ".join(sentences[start_idx:end_idx])

        try:
            if len(tokenizer.encode(context_window)) <= max_tokens:
                context_windows.append(context_window)
        except Exception:
            pass

    # Window 5: Intelligent context with sentence prioritization + word expansion
    intelligent_context = _create_intelligent_context(
        full_text, target_word, target_sentence_idx, tokenizer, max_tokens
    )
    context_windows.append(intelligent_context)

    return context_windows

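The window-growing idea can be sketched without a real tokenizer. Here a whitespace word count stands in for `tokenizer.encode`; the function name and the fixed radii are illustrative:

```python
def grow_windows(sentences, target_idx, max_tokens):
    """Return progressively wider sentence windows around target_idx whose
    whitespace-token count stays within max_tokens."""
    count = lambda s: len(s.split())  # stand-in for len(tokenizer.encode(s))
    windows = [sentences[target_idx]]
    for radius in (1, 2, 3):
        lo = max(0, target_idx - radius)
        hi = min(len(sentences), target_idx + radius + 1)
        window = " ".join(sentences[lo:hi])
        if count(window) <= max_tokens:
            windows.append(window)
    return windows

sents = ["One two.", "Three four.", "Five six.", "Seven eight.", "Nine ten."]
print(grow_windows(sents, 2, 6))
# ['Five six.', 'Three four. Five six. Seven eight.']
```

The radius-2 and radius-3 windows exceed the six-token budget here, so only the target sentence and the radius-1 window survive.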
def _create_intelligent_context(full_text, target_word, target_sentence_idx, tokenizer, max_tokens):
    """
    Create intelligent context that prioritizes sentence boundaries while respecting token limits.
    Strategy: Target sentence -> Nearby sentences -> Word-level expansion
    """
    sentences = split_sentences(full_text)

    # Strategy 1: Always start with the target sentence
    target_sentence = sentences[target_sentence_idx]
    try:
        target_sentence_tokens = len(tokenizer.encode(target_sentence))
    except Exception:
        target_sentence_tokens = 1000  # Fallback: assume it's too long

    if target_sentence_tokens > max_tokens:
        # If even the target sentence is too long, truncate intelligently
        return _truncate_sentence_intelligently(target_sentence, target_word, tokenizer, max_tokens)

    # Strategy 2: Expand sentence-by-sentence around the target sentence
    best_context = target_sentence
    best_token_count = target_sentence_tokens

    # Try adding sentences before and after the target sentence
    for sentence_radius in range(1, min(len(sentences), 20)):  # Max 20-sentence radius
        start_idx = max(0, target_sentence_idx - sentence_radius)
        end_idx = min(len(sentences), target_sentence_idx + sentence_radius + 1)

        # Create context with complete sentences
        context_window = " ".join(sentences[start_idx:end_idx])
        try:
            token_count = len(tokenizer.encode(context_window))
        except Exception:
            token_count = 1000  # Fallback: assume it's too long

        if token_count <= max_tokens:
            # This sentence expansion fits; keep it as our best option
            best_context = context_window
            best_token_count = token_count
        else:
            # This expansion is too big; stop here
            break

    # Strategy 3: If we have room left, try word-level expansion within the best sentence context
    remaining_tokens = max_tokens - best_token_count
    if remaining_tokens > 50:  # If we have significant room left
        enhanced_context = _enhance_with_word_expansion(
            full_text, target_word, best_context, tokenizer, remaining_tokens
        )
        if enhanced_context:
            return enhanced_context

    return best_context

def _enhance_with_word_expansion(full_text, target_word, current_context, tokenizer, remaining_tokens):
    """
    Enhance the current sentence-based context with word-level expansion if there's room.
    """
    words = full_text.split()
    target_word_idx = None

    # Find target word position in full text
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break

    if target_word_idx is None:
        return current_context

    # Try to expand word-by-word around the target word
    try:
        current_tokens = len(tokenizer.encode(current_context))
    except Exception as e:
        print(f"WARNING: Error encoding current context: {e}")
        current_tokens = 1000  # Fallback: assume it's too long

    best_context = current_context
    for expansion_size in range(1, min(len(words), 100)):  # Max 100-word expansion
        start_word = max(0, target_word_idx - expansion_size)
        end_word = min(len(words), target_word_idx + expansion_size + 1)

        expanded_context = " ".join(words[start_word:end_word])
        try:
            expanded_tokens = len(tokenizer.encode(expanded_context))
        except Exception:
            expanded_tokens = 1000  # Fallback: assume it's too long

        if expanded_tokens <= current_tokens + remaining_tokens:
            # This expansion fits; keep the largest one that still fits
            best_context = expanded_context
        else:
            # This expansion is too big; stop here
            break

    return best_context

def _truncate_sentence_intelligently(sentence, target_word, tokenizer, max_tokens):
    """
    Intelligently truncate a sentence while preserving context around the target word.
    """
    words = sentence.split()
    target_word_idx = None

    # Find target word position
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break

    if target_word_idx is None:
        # If target word not found, truncate from the end
        truncated = " ".join(words)
        try:
            while len(tokenizer.encode(truncated)) > max_tokens and len(words) > 1:
                words = words[:-1]
                truncated = " ".join(words)
        except Exception:
            # Fallback: return the first few words
            truncated = " ".join(words[:10]) if len(words) >= 10 else " ".join(words)
        return truncated

    # Truncate symmetrically around the target word
    context_words = 10  # Start with 10 words before/after
    while context_words > 0:
        start_word = max(0, target_word_idx - context_words)
        end_word = min(len(words), target_word_idx + context_words + 1)
        truncated_sentence = " ".join(words[start_word:end_word])

        try:
            if len(tokenizer.encode(truncated_sentence)) <= max_tokens:
                return truncated_sentence
        except Exception:
            # Continue to the next iteration
            pass

        context_words -= 1

    # Fallback: just the target word with minimal context
    return f"... {target_word} ..."

def _intelligent_token_slicing(input_text, tokenizer, max_length=512, mask_token_id=None):
    """
    Intelligently slice input text to fit within max_length tokens while preserving the mask token.
    Strategy: preserve the mask token and surrounding context, removing excess tokens from less important areas.

    Args:
        input_text: The full input text to be tokenized
        tokenizer: The tokenizer to use
        max_length: Maximum allowed sequence length (default 512)
        mask_token_id: The mask token ID to preserve

    Returns:
        Tuple of (sliced_input_ids, mask_position_in_sliced)
    """
    # First, tokenize the full input
    input_ids = tokenizer.encode(input_text, add_special_tokens=True)

    # If already within limits, return as is
    if len(input_ids) <= max_length:
        mask_pos = input_ids.index(mask_token_id) if mask_token_id in input_ids else None
        return input_ids, mask_pos

    # Find mask token position
    mask_positions = [i for i, token_id in enumerate(input_ids) if token_id == mask_token_id]

    if not mask_positions:
        # No mask token found; truncate from the end
        return input_ids[:max_length], None

    mask_pos = mask_positions[0]  # Use the first mask token

    # Calculate how many tokens we need to remove
    excess_tokens = len(input_ids) - max_length

    # Strategy: remove tokens from both ends while preserving mask context.
    # Reserve some context around the mask token.
    mask_context_size = min(50, max_length // 4)  # Reserve 25% of max_length or 50 tokens, whichever is smaller

    # Calculate available space for context around the mask
    available_before = min(mask_pos, mask_context_size)
    available_after = min(len(input_ids) - mask_pos - 1, mask_context_size)

    # Calculate how much to remove from each end
    tokens_to_remove_before = max(0, mask_pos - available_before)
    tokens_to_remove_after = max(0, (len(input_ids) - mask_pos - 1) - available_after)

    # Initialize removal variables
    remove_before = 0
    remove_after = 0

    # Distribute excess tokens proportionally
    if excess_tokens > 0:
        if tokens_to_remove_before + tokens_to_remove_after >= excess_tokens:
            # We can remove enough from the ends
            if tokens_to_remove_before >= excess_tokens // 2:
                remove_before = excess_tokens // 2
                remove_after = excess_tokens - remove_before
            else:
                remove_before = tokens_to_remove_before
                remove_after = min(tokens_to_remove_after, excess_tokens - remove_before)
        else:
            # Need to remove more aggressively
            remove_before = tokens_to_remove_before
            remove_after = tokens_to_remove_after
            remaining_excess = excess_tokens - remove_before - remove_after

            # Remove remaining excess from the end
            if remaining_excess > 0:
                remove_after += remaining_excess

    # Calculate final indices
    start_idx = remove_before
    end_idx = len(input_ids) - remove_after

    # Ensure we don't exceed max_length
    if end_idx - start_idx > max_length:
        # Center around the mask token
        half_length = max_length // 2
        start_idx = max(0, mask_pos - half_length)
        end_idx = min(len(input_ids), start_idx + max_length)

    # Slice the input_ids
    sliced_input_ids = input_ids[start_idx:end_idx]

    if len(sliced_input_ids) > max_length:
        # Force truncation as a final fallback
        sliced_input_ids = sliced_input_ids[:max_length]

    # Adjust the mask position for the sliced sequence
    adjusted_mask_pos = mask_pos - start_idx

    return sliced_input_ids, adjusted_mask_pos

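The heart of the slicing logic is keeping a fixed-size window centered on the mask position. A minimal sketch over a plain list of token ids, with no tokenizer involved (the `-1` mask id is an arbitrary sentinel for illustration):

```python
def center_on_mask(input_ids, mask_id, max_length):
    """Keep at most max_length tokens, centered on the first mask_id."""
    if len(input_ids) <= max_length:
        return input_ids, input_ids.index(mask_id)
    mask_pos = input_ids.index(mask_id)
    start = max(0, mask_pos - max_length // 2)
    end = min(len(input_ids), start + max_length)
    start = max(0, end - max_length)  # re-anchor if we hit the right edge
    return input_ids[start:end], mask_pos - start

ids = list(range(100))
ids[80] = -1  # pretend -1 is the mask token id
window, pos = center_on_mask(ids, -1, 20)
print(len(window), pos, window[pos])  # 20 10 -1
```

The mask survives the cut and its new index is returned relative to the sliced window, which is exactly the invariant `_intelligent_token_slicing` has to maintain.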
def _create_word_level_context(full_text, target_word, tokenizer, max_tokens):
    """
    Create context by expanding word-by-word around the target word until reaching the token limit.
    This maximizes context while respecting tokenizer limits.
    """
    words = full_text.split()
    target_word_idx = None

    # Find target word position in full text
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break

    if target_word_idx is None:
        # Fallback: expand from the beginning until the token limit
        return _expand_from_start(words, tokenizer, max_tokens)

    # Word-by-word expansion around the target word
    return _expand_around_target(words, target_word_idx, tokenizer, max_tokens)

def _expand_around_target(words, target_idx, tokenizer, max_tokens):
    """
    Expand word-by-word around the target word until reaching the token limit.
    """
    best_context = ""

    # Try different expansion sizes
    for expansion_size in range(1, min(len(words), 200)):  # Max 200-word expansion
        start_word = max(0, target_idx - expansion_size)
        end_word = min(len(words), target_idx + expansion_size + 1)

        context_window = " ".join(words[start_word:end_word])
        try:
            token_count = len(tokenizer.encode(context_window))
        except Exception:
            token_count = 1000  # Fallback: assume it's too long

        if token_count <= max_tokens:
            # This expansion fits; keep it as our best option
            best_context = context_window
        else:
            # This expansion is too big; stop here
            break

    # If we found a good context, return it
    if best_context:
        return best_context

    # Fallback: minimal context around the target word
    start_word = max(0, target_idx - 5)
    end_word = min(len(words), target_idx + 6)
    return " ".join(words[start_word:end_word])

def _expand_from_start(words, tokenizer, max_tokens):
    """
    Expand from the start of the text until reaching the token limit.
    """
    for end_idx in range(len(words), 0, -1):
        context_window = " ".join(words[:end_idx])
        try:
            if len(tokenizer.encode(context_window)) <= max_tokens:
                return context_window
        except Exception:
            # Continue to the next iteration
            pass

    # Fallback: the first few words
    return " ".join(words[:10]) if len(words) >= 10 else " ".join(words)

def whole_context_mlm_inference(full_text, sentence_target_pairs, tokenizer, lm_model, Top_K=20, batch_size=32, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Enhanced MLM inference using whole-document context for better candidate generation.
    """
    results = {}

    # Group targets by sentence for batch processing
    sentence_groups = {}
    for sent, target, index in sentence_target_pairs:
        if sent not in sentence_groups:
            sentence_groups[sent] = []
        sentence_groups[sent].append((target, index))

    for sentence, targets in sentence_groups.items():
        # Process targets in batches
        for i in range(0, len(targets), batch_size):
            batch_targets = targets[i:i+batch_size]
            batch_results = _process_whole_context_mlm_batch(
                full_text, sentence, batch_targets, tokenizer, lm_model, Top_K, max_context_tokens, max_length, similarity_context_mode
            )
            results.update(batch_results)

    return results

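The group-then-batch traversal above can be isolated as a small generator. This is a sketch of the same iteration order over `(sentence, target, index)` triples; the helper name is mine:

```python
from collections import defaultdict

def group_and_batch(pairs, batch_size):
    """Group (sentence, target, index) triples by sentence, then yield
    per-sentence batches of (target, index) pairs."""
    groups = defaultdict(list)
    for sent, target, index in pairs:
        groups[sent].append((target, index))
    for sent, targets in groups.items():
        for i in range(0, len(targets), batch_size):
            yield sent, targets[i:i + batch_size]

pairs = [("S1", "cat", 2), ("S1", "runs", 3), ("S1", "fast", 4), ("S2", "dog", 1)]
print(list(group_and_batch(pairs, 2)))
# [('S1', [('cat', 2), ('runs', 3)]), ('S1', [('fast', 4)]), ('S2', [('dog', 1)])]
```

Grouping by sentence first means every batch shares one sentence, so the sentence is tokenized once per group rather than once per target.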
734
+ def _process_whole_context_mlm_batch(full_text, sentence, targets, tokenizer, lm_model, Top_K, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
735
+ """
736
+ Process a batch of targets using whole document context for MLM.
737
+ """
738
+ results = {}
739
+
740
+ # Tokenize sentence once
741
+ doc = nlp(sentence)
742
+ tokens = [token.text for token in doc]
743
+
744
+ # Create multiple masked versions for batch processing
745
+ masked_inputs = []
746
+ mask_positions = []
747
+ contexts_for_targets = []
748
+
749
+ for target, index in targets:
750
+ if index < len(tokens):
751
+ # Create context windows with tokenizer length awareness
752
+ context_windows = create_context_windows(full_text, sentence, target, tokenizer, max_tokens=max_context_tokens)
753
+
754
+ # Use the most comprehensive context window that fits within token limits
755
+ full_context = context_windows[-1] # Built around the target sentence
756
+ # Select context for similarity according to mode
757
+ context = sentence if similarity_context_mode == 'sentence' else full_context
758
+
759
+ # Create masked version of the FULL context (not just the sentence)
760
+ masked_full_context = context.replace(target, tokenizer.mask_token, 1)
761
+
762
+ instruction = "Given the full document context, replace the masked word with a word that fits grammatically, preserves the original meaning, and ensures natural flow in the document:"
763
+ input_text = f"{instruction} {context} {tokenizer.sep_token} {masked_full_context}"
764
+
765
+ # AGGRESSIVE FIX: Truncate input text BEFORE tokenization to prevent errors
766
+ # Estimate token count (roughly 1 token per 4 characters for English)
767
+ estimated_tokens = len(input_text) // 4
768
+ if estimated_tokens > 500: # Leave some buffer
769
+ # Truncate to roughly 2000 characters (500 tokens)
770
+ input_text = input_text[:2000]
771
+
772
+ # SIMPLE FIX: Truncate input text if it's too long
773
+ try:
774
+ temp_tokens = tokenizer.encode(input_text, add_special_tokens=True)
775
+ if len(temp_tokens) > 512:
776
+ # Truncate the input text by removing words from the end
777
+ words = input_text.split()
778
+ while len(tokenizer.encode(" ".join(words), add_special_tokens=True)) > 512 and len(words) > 10:
779
+ words = words[:-1]
780
+ input_text = " ".join(words)
781
+ except Exception as e:
782
+ # Emergency truncation - just take first 200 words
783
+ words = input_text.split()
784
+ input_text = " ".join(words[:200])
785
+
786
+ masked_inputs.append(input_text)
787
+ # Store the original sentence-level index for reference, but mask position will be calculated during tokenization
788
+ mask_positions.append(index)
789
+ contexts_for_targets.append(context)
790
+
791
+ if not masked_inputs:
792
+ return results
+ 
+     # Batch tokenize
+     MAX_LENGTH = max_length  # configurable cap (A100 optimization)
+     batch_inputs = []
+     batch_mask_positions = []
+     batch_source_indices = []  # maps each batch row back to its source input
+ 
+     for src, input_text in enumerate(masked_inputs):
+         # Use intelligent token slicing to ensure we stay within MAX_LENGTH
+         try:
+             input_ids, adjusted_mask_pos = _intelligent_token_slicing(
+                 input_text, tokenizer, max_length=MAX_LENGTH, mask_token_id=tokenizer.mask_token_id
+             )
+             if adjusted_mask_pos is not None:
+                 batch_inputs.append(input_ids)
+                 batch_mask_positions.append(adjusted_mask_pos)
+                 batch_source_indices.append(src)
+             else:
+                 # Mask token not found in the sliced sequence; skip this input
+                 continue
+         except Exception:
+             # Fallback: simple truncation
+             try:
+                 input_ids = tokenizer.encode(input_text, add_special_tokens=True)
+                 if len(input_ids) > MAX_LENGTH:
+                     input_ids = input_ids[:MAX_LENGTH]
+                 masked_position = input_ids.index(tokenizer.mask_token_id)
+                 batch_inputs.append(input_ids)
+                 batch_mask_positions.append(masked_position)
+                 batch_source_indices.append(src)
+             except ValueError:
+                 # Mask token was truncated away; skip this input
+                 continue
+ 
+     if not batch_inputs:
+         return results
+ 
+     # Pad sequences to the same length, capped at MAX_LENGTH
+     max_len = min(max(len(ids) for ids in batch_inputs), MAX_LENGTH)
+ 
+     padded_inputs = []
+     attention_masks = []
+     for input_ids in batch_inputs:
+         input_ids = input_ids[:MAX_LENGTH]  # safety: never exceed MAX_LENGTH
+         attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
+         padded_ids = input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
+         padded_inputs.append(padded_ids[:MAX_LENGTH])
+         attention_masks.append(attention_mask[:MAX_LENGTH])
+ 
+     # Batch inference - mixed precision on GPU for A100 optimization
+     with torch.no_grad():
+         input_tensor = torch.tensor(padded_inputs, dtype=torch.long)
+         attention_tensor = torch.tensor(attention_masks, dtype=torch.long)
+ 
+         if torch.cuda.is_available():
+             input_tensor = input_tensor.cuda()
+             attention_tensor = attention_tensor.cuda()
+             with torch.amp.autocast('cuda'):
+                 outputs = lm_model(input_tensor, attention_mask=attention_tensor)
+                 logits = outputs.logits
+         else:
+             outputs = lm_model(input_tensor, attention_mask=attention_tensor)
+             logits = outputs.logits
+ 
+     # Process results - collect filtered candidates first. Iterate over batch
+     # rows and map each row back to its source target via batch_source_indices,
+     # so that skipped inputs cannot misalign rows and targets.
+     batch_filtered_results = {}
+     for row, src in enumerate(batch_source_indices):
+         target, index = targets[src]
+         mask_pos = batch_mask_positions[row]
+         mask_logits = logits[row, mask_pos].squeeze()
+ 
+         # Get top predictions
+         top_tokens = torch.topk(mask_logits, k=Top_K, dim=0)[1]
+         scores = torch.softmax(mask_logits, dim=0)[top_tokens].tolist()
+         words = [tokenizer.decode(token.item()).strip() for token in top_tokens]
+ 
+         # Filter candidates (similarity scoring happens later, at batch level)
+         filtered_candidates = _filter_candidates_batch(target, words, scores, tokens, index)
+         if filtered_candidates:
+             # Attach the exact context window used for this target
+             batch_filtered_results[(sentence, target, index)] = {
+                 'filtered_words': filtered_candidates,
+                 'context': contexts_for_targets[src]
+             }
+ 
+     # Batch similarity scoring for all candidates
+     if batch_filtered_results:
+         similarity_results = _batch_similarity_scoring(batch_filtered_results, tokenizer)
+         results.update(similarity_results)
+ 
+     return results
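The right-padding scheme used above (pad every sequence to the batch maximum, mark real tokens with 1 and padding with 0 in the attention mask) can be sketched independently of any tokenizer. Note `pad_id=0` below is a placeholder value, not a specific model's pad token id:

```python
# Minimal sketch of the batch padding scheme: every sequence is right-padded
# with a pad id up to the longest length in the batch (capped at max_length),
# and the attention mask marks real tokens with 1 and padding with 0.
# pad_id=0 is a placeholder, not a particular tokenizer's pad token id.
def pad_batch(sequences, pad_id=0, max_length=512):
    max_len = min(max(len(s) for s in sequences), max_length)
    padded, masks = [], []
    for seq in sequences:
        seq = seq[:max_length]  # never exceed the model's length limit
        pad = max_len - len(seq)
        padded.append(seq + [pad_id] * pad)
        masks.append([1] * len(seq) + [0] * pad)
    return padded, masks
```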
+ 
+ # Hardcoded common antonym pairs, used as a safeguard for words that are
+ # missing from WordNet (checked case-insensitively).
+ COMMON_ANTONYMS = {
+     'big': ['small', 'tiny', 'little'],
+     'small': ['big', 'large', 'huge'],
+     'large': ['small', 'tiny', 'little'],
+     'good': ['bad', 'evil', 'wrong'],
+     'bad': ['good', 'great', 'excellent'],
+     'high': ['low'],
+     'low': ['high'],
+     'new': ['old'],
+     'old': ['new'],
+     'fast': ['slow'],
+     'slow': ['fast'],
+     'rich': ['poor'],
+     'poor': ['rich'],
+     'hot': ['cold'],
+     'cold': ['hot'],
+     'happy': ['sad', 'unhappy'],
+     'sad': ['happy', 'joyful'],
+     'true': ['false', 'untrue'],
+     'false': ['true'],
+     'real': ['fake', 'unreal'],
+     'fake': ['real'],
+     'up': ['down'],
+     'down': ['up'],
+     'yes': ['no'],
+     'no': ['yes'],
+     'alive': ['dead'],
+     'dead': ['alive'],
+     'safe': ['unsafe', 'dangerous'],
+     'dangerous': ['safe'],
+     'clean': ['dirty'],
+     'dirty': ['clean'],
+     'full': ['empty'],
+     'empty': ['full'],
+     'open': ['closed', 'shut'],
+     'closed': ['open'],
+     'begin': ['end', 'finish'],
+     'end': ['begin', 'start'],
+     'start': ['end', 'finish'],
+     'finish': ['start', 'begin'],
+     'first': ['last'],
+     'last': ['first']
+ }
+ 
+ def _filter_candidates_batch(target, words, scores, tokens, target_index):
+     """
+     Optimized batch filtering of candidates. Similarity scoring is not done
+     here; it is deferred to the batch level (_batch_similarity_scoring).
+     """
+     # Target-side properties are loop-invariant, so compute them once
+     target_lower = target.lower()
+     target_nltk_pos, target_spacy_pos = get_word_pos_tags(target)
+     target_antonyms = {ant.lower() for ant in get_word_antonyms(target)}
+ 
+     filtered_words = []
+     seen_words = set()
+ 
+     for word in words:
+         word_lower = word.lower()
+         if word_lower in seen_words or word_lower == target_lower:
+             continue
+         seen_words.add(word_lower)
+ 
+         if not is_valid_word(word):
+             continue
+ 
+         # Quick POS check: candidate must match the target's POS tags
+         cand_nltk_pos, cand_spacy_pos = get_word_pos_tags(word)
+         if target_nltk_pos != cand_nltk_pos or target_spacy_pos != cand_spacy_pos:
+             continue
+ 
+         # Antonym check (bidirectional and case-insensitive): reject the
+         # candidate if either word lists the other as a WordNet antonym
+         if word_lower in target_antonyms:
+             continue
+         candidate_antonyms = {ant.lower() for ant in get_word_antonyms(word)}
+         if target_lower in candidate_antonyms:
+             continue
+ 
+         # Safeguard: hardcoded antonym pairs for words not covered by WordNet
+         if word_lower in COMMON_ANTONYMS.get(target_lower, []):
+             continue
+ 
+         # Exclude candidates that are different specific terms in the same
+         # narrow noun category as the target (e.g. crops, animals, companies)
+         if not _are_semantically_compatible(target, word):
+             continue
+ 
+         filtered_words.append(word)
+ 
+     if len(filtered_words) < 2:
+         return None
+ 
+     # Return filtered words; similarity scoring happens at the batch level
+     return filtered_words
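The bidirectional, case-insensitive antonym check above can be isolated as a small predicate. The sketch below uses a tiny stand-in antonym map rather than the WordNet-backed `get_word_antonyms`:

```python
# Sketch of the bidirectional antonym check: a candidate is rejected if either
# word lists the other as an antonym, compared case-insensitively.
# ANTONYMS is a tiny stand-in map, not the WordNet-backed lookup.
ANTONYMS = {"big": ["small", "little"], "small": ["big", "large"]}

def is_antonym_pair(target, candidate, antonyms=ANTONYMS):
    t, c = target.lower(), candidate.lower()
    # Check both directions: target -> candidate and candidate -> target
    return c in antonyms.get(t, []) or t in antonyms.get(c, [])
```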
+ 
+ def _batch_similarity_scoring(batch_results, tokenizer):
+     """
+     Batched similarity scoring across multiple sentences using the full
+     context. Processing all candidates together is more efficient than
+     scoring them sentence by sentence.
+     """
+     # Collect all similarity scoring tasks
+     similarity_tasks = []
+ 
+     for (sentence, target, index), value in batch_results.items():
+         if value is None:
+             continue
+         # Support both the legacy list format and the new dict with context
+         if isinstance(value, dict):
+             filtered_words = value.get('filtered_words')
+             context = value.get('context', sentence)
+         else:
+             filtered_words = value
+             context = sentence
+ 
+         # Tokenize the sentence once
+         tokens = tokenizer.tokenize(sentence)
+         if index >= len(tokens):
+             continue
+ 
+         # Create one candidate sentence per filtered word
+         for word in filtered_words:
+             candidate_tokens = tokens.copy()
+             candidate_tokens[index] = word
+             candidate_sentence = tokenizer.convert_tokens_to_string(candidate_tokens)
+ 
+             # Build the full-context candidate by replacing the sentence
+             # inside the chosen context exactly once
+             candidate_full_context = context.replace(sentence, candidate_sentence, 1)
+             similarity_tasks.append({
+                 'original_context': context,
+                 'candidate_full_context': candidate_full_context,
+                 'target_word': word,
+                 'context_key': (sentence, target, index)
+             })
+ 
+     if not similarity_tasks:
+         return {}
+ 
+     # Batch process all similarity scoring
+     try:
+         # Group by original full context for efficient BERTScore computation
+         context_groups = {}
+         for task in similarity_tasks:
+             context_groups.setdefault(task['original_context'], []).append(task)
+ 
+         # Process each context group
+         final_results = {}
+         for orig_context, tasks in context_groups.items():
+             candidate_contexts = [task['candidate_full_context'] for task in tasks]
+ 
+             # Batch BERTScore computation against the same full context
+             try:
+                 similarity_scores = calc_scores_bert(orig_context, candidate_contexts)
+             except Exception:
+                 # Fall back to neutral scores; these are discarded below
+                 similarity_scores = [0.5] * len(candidate_contexts)
+ 
+             if similarity_scores and not all(score == 0.5 for score in similarity_scores):
+                 # Group results by context key
+                 for task, score in zip(tasks, similarity_scores):
+                     final_results.setdefault(task['context_key'], []).append(
+                         (task['target_word'], score)
+                     )
+ 
+         # Sort each target's candidates by similarity score, best first
+         for context_key in final_results:
+             final_results[context_key].sort(key=lambda x: x[1], reverse=True)
+ 
+         return final_results
+ 
+     except Exception:
+         return {}
+ 
+ def parallel_tournament_sampling(target_results, secret_key, m, c, h, alpha):
+     """
+     Run tournament sampling for multiple targets in parallel.
+     """
+     from SynthID_randomization import tournament_select_word
+ 
+     results = {}
+     if not target_results:
+         return results
+ 
+     def process_single_tournament(item):
+         (sentence, target, index), candidates = item
+         if not candidates:
+             return (sentence, target, index), None
+ 
+         alternatives = [alt[0] for alt in candidates]
+         similarity = [alt[1] for alt in candidates]
+         if not alternatives or not similarity:
+             return (sentence, target, index), None
+ 
+         # Left context window of up to h tokens before the target
+         context_tokens = word_tokenize(sentence)
+         left_context = context_tokens[max(0, index - h):index]
+ 
+         # Key-conditioned tournament selection among the alternatives
+         randomized_word = tournament_select_word(
+             target, alternatives, similarity,
+             context=left_context, key=secret_key, m=m, c=c, alpha=alpha
+         )
+         return (sentence, target, index), randomized_word
+ 
+     # Process targets in parallel
+     max_workers = max(1, min(8, len(target_results)))
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         future_to_item = {executor.submit(process_single_tournament, item): item for item in target_results.items()}
+         for future in as_completed(future_to_item):
+             key, result = future.result()
+             results[key] = result
+ 
+     return results
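The key conditioning above means the same secret key and left context always select the same word. The sketch below illustrates that principle with a hash-seeded choice; it is a simplified stand-in for `tournament_select_word`, not the actual SynthID tournament algorithm:

```python
import hashlib

# Illustration of key-conditioned selection: the secret key plus the left
# context deterministically seed the choice among candidate words.
# This is a simplified stand-in, not the SynthID tournament itself.
def keyed_select(candidates, left_context, key):
    material = key + "|" + " ".join(left_context)
    digest = hashlib.sha256(material.encode("utf-8")).digest()
    seed = int.from_bytes(digest[:8], "big")
    return candidates[seed % len(candidates)]
```

Because the seed depends only on the key and context, detection can recompute the same choice without access to the generator.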
+ 
+ def whole_context_process_sentence(full_text, sentence, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, output_name, batch_size=32, max_length=512, max_context_tokens=400, similarity_context_mode='whole'):
+     """
+     Enhanced sentence processing that uses the whole document context for
+     better candidate generation.
+     """
+     replacements = []
+     sampling_results = []
+     doc = nlp(sentence)
+     sentence_target_pairs = extract_entities_and_pos(sentence)
+ 
+     if not sentence_target_pairs:
+         return replacements, sampling_results
+ 
+     # Keep only target pairs whose recorded position matches the spaCy token
+     valid_pairs = []
+     spacy_tokens = [token.text for token in doc]
+     for sent, target, position in sentence_target_pairs:
+         if position < len(spacy_tokens) and spacy_tokens[position] == target:
+             valid_pairs.append((sent, target, position))
+ 
+     if not valid_pairs:
+         return replacements, sampling_results
+ 
+     # Enhanced MLM inference with whole-document context
+     batch_results = whole_context_mlm_inference(full_text, valid_pairs, tokenizer, lm_model, Top_K, batch_size, max_context_tokens, max_length, similarity_context_mode)
+ 
+     # Keep targets that have at least two candidates at or above the
+     # similarity threshold (matching the original logic)
+     filtered_results = {}
+     for key, candidates in batch_results.items():
+         if candidates:
+             threshold_candidates = [(word, score) for word, score in candidates if score >= threshold]
+             if len(threshold_candidates) >= 2:
+                 filtered_results[key] = threshold_candidates
+ 
+     # Parallel tournament sampling
+     tournament_results = parallel_tournament_sampling(filtered_results, secret_key, m, c, h, alpha)
+ 
+     # Collect replacements and sampling results
+     for (sent, target, position), randomized_word in tournament_results.items():
+         if randomized_word:
+             # Alternatives for this target, from the threshold-filtered results
+             alternatives = filtered_results.get((sent, target, position), [])
+             alternatives_list = [alt[0] for alt in alternatives]
+             # Include similarity scores for each alternative (the plain
+             # 'alternatives' list is preserved for backward compatibility)
+             alternatives_with_similarity = [
+                 {"word": alt[0], "similarity": float(alt[1])} for alt in alternatives
+             ]
+ 
+             # Track sampling results
+             sampling_results.append({
+                 "word": target,
+                 "alternatives": alternatives_list,
+                 "alternatives_with_similarity": alternatives_with_similarity,
+                 "randomized_word": randomized_word
+             })
+             replacements.append((position, target, randomized_word))
+ 
+     return replacements, sampling_results
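The threshold step above keeps a target only when at least two candidates meet the similarity threshold, since a tournament needs more than one contender. A standalone sketch of that filter:

```python
# Sketch of the threshold filter: keep a target's candidates only if at least
# two of them score at or above the similarity threshold; otherwise drop the
# target entirely (a tournament needs more than one contender).
def threshold_filter(candidates, threshold):
    kept = [(word, score) for word, score in candidates if score >= threshold]
    return kept if len(kept) >= 2 else None
```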
+ 
+ # Legacy functions kept for compatibility
+ def look_up(sentence, target, index, tokenizer, lm_model, Top_K=20, threshold=0.75):
+     """
+     Legacy single-target lookup, kept for compatibility.
+     """
+     results = batch_mlm_inference([(sentence, target, index)], tokenizer, lm_model, Top_K)
+     return results.get((sentence, target, index), None)
+ 
+ def batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K=20):
+     """
+     Legacy batch MLM inference, kept for compatibility. Delegates to
+     whole_context_mlm_inference with an empty document context.
+     """
+     return whole_context_mlm_inference("", sentence_target_pairs, tokenizer, lm_model, Top_K)
+ 
+ def batch_look_up(sentence_target_pairs, tokenizer, lm_model, Top_K=20, threshold=0.75, max_workers=4):
+     """
+     Optimized batch lookup using the batched MLM inference path.
+     """
+     return batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K)