SafeSeal app: changed files

- README.md +66 -7
- SynthID_randomization.py +294 -0
- app.py +364 -0
- requirements.txt +10 -0
- utils_final.py +1213 -0
README.md
CHANGED
@@ -1,13 +1,72 @@
 ---
-title: SafeSeal
-emoji:
-colorFrom:
 colorTo: purple
-sdk:
-sdk_version:
 app_file: app.py
 pinned: false
-
 ---
-
---
title: SafeSeal Watermark
emoji: 🔒
colorFrom: blue
colorTo: purple
sdk: streamlit
sdk_version: 1.50.0
app_file: app.py
pinned: false
license: mit
---

# SafeSeal Watermark

**Content-Preserving Watermarking for Large Language Model Deployments.**

Generate watermarked text by key-conditioned sampling over context-aware synonym candidates.

## Features

- 🔑 **Secret Key**: Deterministic watermarking with a user-controlled key
- 📊 **BERTScore Filtering**: Adjustable similarity threshold (0.0 - 1.0)
- 🏆 **Tournament Sampling**: Select synonyms using tournament-based randomization
- ✨ **Visual Highlighting**: See exactly which words were changed
- 🚀 **GPU Support**: Fast inference with automatic GPU detection
- 🛡️ **Smart Filtering**: Excludes antonyms and same-category specific nouns, and preserves entity names

## How It Works

1. **Entity Detection**: Extracts eligible words (nouns, verbs, adjectives, adverbs) while skipping named entities
2. **Candidate Generation**: Uses RoBERTa-base to generate semantically similar alternatives
3. **BERTScore Filtering**: Evaluates candidates against a similarity threshold
4. **Tournament Selection**: Deterministically selects replacements based on the secret key
5. **Visualization**: Highlights changed words in the output

## Usage

1. Enter your text in the left panel
2. Adjust hyperparameters in the sidebar:
   - **Secret Key**: Used for deterministic randomization
   - **Threshold**: Similarity threshold (default: 0.98)
   - **Tournament parameters**: Fine-tune the selection process
3. Click "🚀 Generate Watermark"
4. View the watermarked text with highlighted changes

## Parameters

- **Secret Key**: Used for deterministic randomization
- **Threshold (0.98)**: BERTScore similarity threshold; higher values mean more conservative changes
- **m (10)**: Number of tournament rounds
- **c (2)**: Competitors per tournament match
- **h (6)**: Context size (number of left tokens to consider)
- **Alpha (1.1)**: Temperature scaling factor
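To make the Alpha parameter concrete, here is a minimal, self-contained sketch (not part of the app code) of a temperature-scaled softmax over hypothetical similarity scores; a higher alpha concentrates probability on the most similar candidate, while a lower alpha flattens the distribution:

```python
import math

def softmax(xs, alpha=1.0):
    # Temperature-scaled softmax; alpha sharpens or flattens the distribution
    mx = max(xs)
    zs = [math.exp(alpha * (x - mx)) for x in xs]
    s = sum(zs)
    return [z / s for z in zs]

sims = [0.82, 0.55, 0.60]        # hypothetical similarity scores for three candidates
mild = softmax(sims, alpha=1.1)  # the app's default temperature
sharp = softmax(sims, alpha=10.0)

# With a higher alpha, the top candidate receives a larger share of the mass
print(mild)
print(sharp)
```

With near-identical similarity scores (as the 0.98 threshold tends to produce), the distribution stays close to uniform at alpha around 1, which is what keeps the key-conditioned choice unpredictable without the key.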
## Technical Details

- **Model**: RoBERTa-base for masked language modeling
- **Similarity Scoring**: BERTScore F1 scores
- **Selection**: Tournament-based deterministic sampling
- **Filtering**: POS tag matching, antonym exclusion, semantic compatibility checks

## Example

**Input:**

> "The quick brown fox jumps over the lazy dog."

**Watermarked Output:**

> "The swift brown fox leaps over the idle dog."

Changed words highlighted: swift (was quick), leaps (was jumps), idle (was lazy)

⚠️ **Demo Version**: This demonstration uses a lightweight model to showcase the watermarking pipeline. Results may not be perfect and are intended for testing only.
SynthID_randomization.py
ADDED
@@ -0,0 +1,294 @@
# Sampling with Tournament
#
# Input:
#   original: original word
#   alternatives: substitute candidates (K = len(alternatives))
#   similarity: similarity scores, one per alternative
#   context: 1-token left context
#   key: secret key for hashing
#   m: number of tournament rounds (default 2)
#   c: number of competitors per match (default 2)
#   alpha: temperature-like scaling factor (default 1.0)
#
# Process:
# 1. Normalize similarity scores to a distribution through softmax with scaling factor alpha.
# 2. Create a deterministic "random" score for each candidate for each of the m rounds,
#    i.e. a matrix of shape (K, m).
# 3. Each score is generated by hashing (key, context, candidate, round index).
# 4. For the first round, sample M = c^m candidates from the normalized similarity
#    distribution (each candidate can appear multiple times) and divide them into
#    groups of size c.
# 5. For each group, the candidate with the highest score for that round advances.
#    Ties (e.g. scores 1, 1, 0) are broken deterministically from the key and context.
# 6. For the next round, group the winners from the previous round and use that
#    round's scores. Repeat until one winner remains.
#
# Output: the winner's candidate index.

import hashlib
import hmac
import math
import random
from typing import List, Tuple


def _softmax(xs: List[float], alpha: float = 1.0) -> List[float]:
    # Temperature-scaled softmax, normalized to sum to 1
    if not xs:
        return []
    # Subtract the max for numerical stability
    m = max(xs)
    zs = [math.exp(alpha * (x - m)) for x in xs]
    s = sum(zs)
    return [z / s for z in zs]


def _hash_bytes(key: str, payload: str) -> bytes:
    return hmac.new(key.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).digest()


def _hash_to_uniform01(key: str, payload: str) -> float:
    # Map the first 8 hash bytes to [0, 1) as a 64-bit integer / 2**64
    b = _hash_bytes(key, payload)[:8]
    n = int.from_bytes(b, "big", signed=False)
    return n / 2**64


def _ctx_str(context: List[str]) -> str:
    # Use the last context token (widen to more tokens if desired)
    toks = context or []
    return " ".join(toks[-1:]).lower()


def _per_round_score(key: str, ctx: str, candidate: str, round_idx: int) -> int:
    # Deterministic "random" score following Bernoulli(0.5): 0 or 1
    payload = f"score::{ctx}::{candidate}::r{round_idx}"
    uniform_score = _hash_to_uniform01(key, payload)
    # Convert uniform [0, 1) to Bernoulli(0.5): 0 if < 0.5, 1 if >= 0.5
    return 1 if uniform_score >= 0.5 else 0


def _sample_categorical_from_uniform(u: float, probs: List[float]) -> int:
    # Inverse-CDF sampling from probs using a uniform u in [0, 1)
    c = 0.0
    for i, p in enumerate(probs):
        c += p
        if u < c:
            return i
    return len(probs) - 1  # fall back on the last index due to floating-point sums


def _draw_M_candidates(key: str, ctx: str, probs: List[float], K: int, c: int, m: int) -> List[int]:
    """
    Draw M = c^m candidates (indices 0..K-1) with replacement from probs,
    using a deterministic stream of uniforms keyed by (key, ctx).
    """
    M = c ** m
    picks = []
    for draw_idx in range(M):
        u = _hash_to_uniform01(key, f"draw::{ctx}::m{m}::c{c}::i{draw_idx}")
        j = _sample_categorical_from_uniform(u, probs)
        picks.append(j)
    return picks


def _run_tournament_round(
    key: str,
    ctx: str,
    picks: List[int],
    candidates: List[str],
    round_idx: int,
    group_size: int,
) -> List[int]:
    """
    Split 'picks' into consecutive groups of size 'group_size'.
    For each group, select the winner (argmax of the per-round score);
    ties are broken deterministically. Returns the list of winning
    indices (into the candidates list).
    """
    assert len(picks) % group_size == 0, "Group partition must divide evenly."
    winners = []

    for g in range(0, len(picks), group_size):
        group = picks[g:g + group_size]  # indices into the candidates list
        group_num = g // group_size + 1

        # Compute per-round scores
        scored: List[Tuple[int, int]] = []
        for idx in group:
            cand = candidates[idx]
            s = _per_round_score(key, ctx, cand, round_idx)
            scored.append((s, idx))

        # Find the max score; tie-break with a deterministic secondary key
        max_s = max(s for s, _ in scored)
        tied = [idx for (s, idx) in scored if s == max_s]

        if len(tied) == 1:
            winners.append(tied[0])
        else:
            # Deterministic tie-breaker seeded from key, context, round, and group,
            # using a local RNG so the global random state is untouched
            seed_string = f"{key}:{ctx}:r{round_idx}:g{group_num}"
            seed_value = sum(ord(ch) * (i + 1) for i, ch in enumerate(seed_string))
            rng = random.Random(seed_value)
            winners.append(rng.choice(tied))

    return winners


def tournament_randomize(
    original: str,
    alternatives: List[str],
    similarity: List[float],
    context: List[str],
    key: str,
    m: int = 2,
    c: int = 2,
    alpha: float = 1.0,
) -> int:
    """
    Tournament sampling among alternatives only (K = len(alternatives)).
    Returns the winner index (0..K-1) into 'alternatives'.

    Steps:
      1) softmax(similarity, alpha) -> probs over alternatives
      2) build per-round scores for each candidate via HMAC (on the fly)
      3) Round 1: draw M = c^m picks from probs; group into size c;
         pick per-group winners by round-1 scores
      4) Later rounds: group the winners into size c; pick winners
         using that round's scores
      5) Repeat until one winner remains
    """
    assert len(alternatives) == len(similarity) > 0, "Need at least one alternative with a similarity score."
    assert m >= 1, "Tournament rounds m must be >= 1"
    assert c >= 2, "Competitors per match c must be >= 2"

    K = len(alternatives)
    ctx = _ctx_str(context)
    probs = _softmax(similarity, alpha=alpha)

    # First round: M = c^m draws from probs (indices 0..K-1 into alternatives)
    picks = _draw_M_candidates(key, ctx, probs, K, c, m)

    # We treat 'candidates' as the alternatives list; indices map directly
    candidates = alternatives

    # Run m rounds. Each round groups the current list into blocks of size c
    # and picks one winner per group using that round's scores.
    current = picks
    for r in range(1, m + 1):
        current = _run_tournament_round(
            key=key,
            ctx=ctx,
            picks=current,
            candidates=candidates,
            round_idx=r,
            group_size=c,
        )
    # The number of survivors shrinks by a factor of c each round;
    # after m rounds there must be exactly one winner.
    assert len(current) == 1, f"Expected a single winner after {m} rounds, got {len(current)}"

    return current[0]


# Convenience wrapper that returns the selected word
def tournament_select_word(
    original: str,
    alternatives: List[str],
    similarity: List[float],
    context: List[str],
    key: str,
    m: int = 2,
    c: int = 2,
    alpha: float = 1.0,
) -> str:
    idx = tournament_randomize(original, alternatives, similarity, context, key, m, c, alpha)
    return alternatives[idx]


# Sample run:
# original = "big"
# alts = ["large", "huge", "massive"]  # K = 3
# sims = [0.82, 0.55, 0.60]            # any similarity scores
# ctx = ["a"]                          # 1-token left context (e.g., "... a ___")
# key = "super-secret-key"
#
# winner_idx = tournament_randomize(original, alts, sims, ctx, key, m=2, c=2, alpha=1.0)
# print("Winner index:", winner_idx, "->", alts[winner_idx])
#
# # Or directly:
# print("Winner word:", tournament_select_word(original, alts, sims, ctx, key))
#
# # A different key generally yields a different winner:
# winner_idx2 = tournament_randomize(original, alts, sims, ctx, "different-secret-key")
# print("Winner index with different key:", winner_idx2, "->", alts[winner_idx2])
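The keyed randomness in this module rests entirely on mapping HMAC-SHA256 digests to uniforms in [0, 1). This condensed, self-contained sketch (restating the construction rather than importing the module) checks the properties the tournament relies on:

```python
import hashlib
import hmac

def hash_to_uniform01(key: str, payload: str) -> float:
    # Interpret the first 8 bytes of HMAC-SHA256(key, payload)
    # as an unsigned 64-bit integer scaled into [0, 1)
    digest = hmac.new(key.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).digest()
    return int.from_bytes(digest[:8], "big") / 2**64

u1 = hash_to_uniform01("super-secret-key", "draw::a::m2::c2::i0")
u2 = hash_to_uniform01("super-secret-key", "draw::a::m2::c2::i0")
u3 = hash_to_uniform01("different-secret-key", "draw::a::m2::c2::i0")

print(u1 == u2)         # the same key and payload always give the same uniform
print(0.0 <= u1 < 1.0)  # values land in [0, 1)
```

Because every draw and per-round score is derived this way, the whole tournament is a pure function of (key, context, candidates), which is what makes the watermark reproducible by anyone holding the key.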
app.py
ADDED
@@ -0,0 +1,364 @@
import streamlit as st
import os
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM
import spacy
import subprocess
import sys
import nltk
from nltk.tokenize import word_tokenize
from utils_final import extract_entities_and_pos, whole_context_process_sentence


# Download NLTK data if not already available
def setup_nltk():
    """Set up NLTK data with error handling."""
    for resource in ("punkt_tab", "averaged_perceptron_tagger_eng", "wordnet", "omw-1.4"):
        try:
            nltk.download(resource, quiet=True)
        except Exception:
            pass

setup_nltk()

# Set environment
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Load the spaCy model; download it if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("Downloading spaCy model...")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")


# apply_replacements (from Safeseal_gen_final.py)
def apply_replacements(sentence, replacements):
    """
    Apply replacements to the sentence while preserving the original
    formatting, spacing, and punctuation.
    """
    doc = nlp(sentence)  # Tokenize the sentence
    tokens = [token.text_with_ws for token in doc]  # Preserve original whitespace with tokens

    # Apply replacements based on token positions
    for position, target, replacement in replacements:
        if position < len(tokens) and tokens[position].strip() == target:
            tokens[position] = replacement + (" " if tokens[position].endswith(" ") else "")

    # Reassemble the sentence
    return "".join(tokens)


@st.cache_resource
def load_model():
    """Load the model and tokenizer (cached to avoid reloading on every rerun)."""
    print("Loading model...")
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    lm_model = RobertaForMaskedLM.from_pretrained('roberta-base', attn_implementation="eager")

    tokenizer.model_max_length = 512

    if hasattr(lm_model.config, 'max_position_embeddings'):
        lm_model.config.max_position_embeddings = 512

    lm_model.eval()

    if torch.cuda.is_available():
        lm_model = lm_model.cuda()
        print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
    else:
        print("Model loaded on CPU")

    return tokenizer, lm_model


sampling_results = []


def process_text_wrapper(text, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha,
                         batch_size=32, max_length=512, similarity_context_mode='whole'):
    """
    Process the text line by line and return the watermarked output,
    tracking which words were changed.
    """
    global sampling_results
    sampling_results = []

    lines = text.splitlines(keepends=True)
    final_text = []
    total_randomized_words = 0
    total_words = len(word_tokenize(text))

    # Track changed words and their positions
    changed_words = []  # list of (original, replacement, position)

    for line in lines:
        if line.strip():
            replacements, sampling_results_line = whole_context_process_sentence(
                text,
                line.strip(),
                tokenizer, lm_model, Top_K, threshold,
                secret_key, m, c, h, alpha, "output",
                batch_size=batch_size, max_length=max_length,
                similarity_context_mode=similarity_context_mode
            )

            sampling_results.extend(sampling_results_line)

            if replacements:
                randomized_line = apply_replacements(line, replacements)
                final_text.append(randomized_line)

                # Track only actual changes (where original != replacement)
                for position, original, replacement in replacements:
                    if original != replacement:
                        changed_words.append((original, replacement, position))
                        total_randomized_words += 1
            else:
                final_text.append(line)
        else:
            final_text.append(line)

    return "".join(final_text), total_randomized_words, total_words, changed_words, sampling_results


def create_html_with_highlights(original_text, watermarked_text, changed_words_info, sampling_results):
    """
    Create HTML with highlighted changed words using spaCy tokenization.
    """
    # Map replacement words back to their originals (actual changes only)
    actual_replacements = set()
    replacement_to_original = {}

    for original, replacement, _ in changed_words_info:
        if original.lower() != replacement.lower():  # only map actual changes
            actual_replacements.add(replacement.lower())
            replacement_to_original[replacement.lower()] = original

    # Parse the watermarked text with spaCy
    doc_watermarked = nlp(watermarked_text)

    # Build HTML by walking the watermarked tokens
    result_html = []
    words_highlighted = set()  # track highlighted words to avoid duplicates

    for token in doc_watermarked:
        text = token.text_with_ws
        text_clean = token.text.strip('.,!?;:')
        text_lower = text_clean.lower()

        # Highlight only if this word is in the replacements set
        # and has not already been highlighted
        if text_lower in actual_replacements and text_lower not in words_highlighted:
            original_word = replacement_to_original.get(text_lower, text_clean)

            # Highlight only if it actually differs from the original
            if original_word.lower() != text_lower:
                tooltip = f"Original: {original_word} → New: {text_clean}"
                highlighted_text = (
                    "<mark style='background: linear-gradient(120deg, #84fab0 0%, #8fd3f4 100%); "
                    "padding: 2px 6px; border-radius: 4px; font-weight: 500; "
                    "box-shadow: 0 1px 2px rgba(0,0,0,0.1);' "
                    f"title='{tooltip}'>{text_clean}</mark>"
                )

                # Preserve trailing whitespace and punctuation
                if text != text_clean:
                    highlighted_text += text[len(text_clean):]

                result_html.append(highlighted_text)
                words_highlighted.add(text_lower)
            else:
                result_html.append(text)
        else:
            result_html.append(text)

    # Return the inner content only; the caller adds the outer div
    return "".join(result_html)


# Streamlit UI
def main():
    st.set_page_config(
        page_title="Watermarked Text Generator",
        page_icon="🔒",
        layout="wide"
    )

    # Centered, styled title
    st.markdown(
        """
        <div style="text-align: center; margin-bottom: 10px;">
            <h1 style="color: #4A90E2; font-size: 2.5rem; font-weight: bold; margin: 0;">
                🔒 SafeSeal Watermark
            </h1>
        </div>
        <div style="text-align: center; margin-bottom: 20px; color: #666; font-size: 1.1rem;">
            Content-Preserving Watermarking for Large Language Model Deployments.
        </div>
        """,
        unsafe_allow_html=True
    )

    # Separator between the title and the main layout
    st.markdown("---")

    # Sidebar for hyperparameters
    with st.sidebar:
        st.markdown("### ⚙️ Hyperparameters")
        st.caption("Configure the watermarking algorithm")

        # Main inputs
        secret_key = st.text_input(
            "🔑 Secret Key",
            value="My_Secret_Key",
            help="Secret key for deterministic randomization"
        )

        threshold = st.slider(
            "📊 Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.98,
            step=0.01,
            help="BERTScore similarity threshold (higher = more similar replacements)"
        )

        st.divider()

        # Tournament sampling parameters
        st.markdown("### 🏆 Tournament Sampling")
        st.caption("Control the randomization process")

        # Hidden Top_K parameter (default 6)
        Top_K = 6

        m = st.number_input(
            "m (Tournament Rounds)",
            min_value=1,
            max_value=20,
            value=10,
            help="Number of tournament rounds"
        )

        c = st.number_input(
            "c (Competitors per Round)",
            min_value=2,
            max_value=10,
            value=2,
            help="Number of competitors per tournament match"
        )

        h = st.number_input(
            "h (Context Size)",
            min_value=1,
            max_value=20,
            value=6,
            help="Number of left context tokens to consider"
        )

        alpha = st.slider(
            "Alpha (Temperature)",
            min_value=0.1,
            max_value=5.0,
            value=1.1,
            step=0.1,
            help="Temperature scaling factor for softmax"
        )

    # Main content area
    col1, col2 = st.columns(2)

    # Load the model once and cache it in session state
    if 'tokenizer' not in st.session_state:
        with st.spinner("Loading model... This may take a minute"):
            tokenizer, lm_model = load_model()
            st.session_state.tokenizer = tokenizer
            st.session_state.lm_model = lm_model

    with col1:
        st.markdown("### 📝 Input Text")
        input_text = st.text_area(
            "Enter text to watermark",
            height=400,
            placeholder="Paste your text here to generate a watermarked version...",
            label_visibility="collapsed"
        )

        # Process button at the bottom of the input column
        if st.button("🚀 Generate Watermark", type="primary", use_container_width=True):
            if not input_text or len(input_text.strip()) == 0:
                st.warning("Please enter some text to watermark.")
            else:
                with st.spinner("Generating watermarked text... This may take a few moments"):
                    try:
                        # Process the text
                        watermarked_text, total_randomized_words, total_words, changed_words, sampling_results = process_text_wrapper(
                            input_text,
                            st.session_state.tokenizer,
                            st.session_state.lm_model,
                            Top_K=int(Top_K),
                            threshold=float(threshold),
                            secret_key=secret_key,
                            m=int(m),
                            c=int(c),
                            h=int(h),
                            alpha=float(alpha),
                            batch_size=32,
                            max_length=512,
|
| 309 |
+
similarity_context_mode='whole'
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Store results in session state
|
| 313 |
+
st.session_state.watermarked_text = watermarked_text
|
| 314 |
+
st.session_state.changed_words = changed_words
|
| 315 |
+
st.session_state.sampling_results = sampling_results
|
| 316 |
+
st.session_state.total_randomized = total_randomized_words
|
| 317 |
+
st.session_state.total_words = total_words
|
| 318 |
+
|
| 319 |
+
st.success(f"Watermark generated! Changed {total_randomized_words} out of {total_words} words ({100*total_randomized_words/max(total_words,1):.1f}%)")
|
| 320 |
+
except Exception as e:
|
| 321 |
+
st.error(f"Error generating watermark: {str(e)}")
|
| 322 |
+
import traceback
|
| 323 |
+
st.code(traceback.format_exc())
|
| 324 |
+
|
| 325 |
+
with col2:
|
| 326 |
+
st.markdown("### 🔒 Watermarked Text")
|
| 327 |
+
|
| 328 |
+
# Display watermarked text with highlights
|
| 329 |
+
if 'watermarked_text' in st.session_state:
|
| 330 |
+
highlight_html = create_html_with_highlights(
|
| 331 |
+
input_text,
|
| 332 |
+
st.session_state.watermarked_text,
|
| 333 |
+
st.session_state.changed_words,
|
| 334 |
+
st.session_state.sampling_results
|
| 335 |
+
)
|
| 336 |
+
# Show highlighted version with border - wrap the complete HTML
|
| 337 |
+
full_html = f"""
|
| 338 |
+
<div style='padding: 15px; background-color: #f8f9fa; border-radius: 8px; border: 1px solid #e0e0e0; min-height: 400px; max-height: 400px; overflow-y: auto; line-height: 1.8; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; font-size: 15px; white-space: pre-wrap; word-wrap: break-word;'>
|
| 339 |
+
{highlight_html}
|
| 340 |
+
</div>
|
| 341 |
+
"""
|
| 342 |
+
st.markdown(full_html, unsafe_allow_html=True)
|
| 343 |
+
else:
|
| 344 |
+
st.info("👈 Enter text in the left panel and click 'Generate Watermark' to start")
|
| 345 |
+
|
| 346 |
+
# Footer
|
| 347 |
+
st.divider()
|
| 348 |
+
st.caption("🔒 Secure AI Watermarking Tool | Built with SafeSeal")
|
| 349 |
+
|
| 350 |
+
# Demo warning at the bottom
|
| 351 |
+
st.markdown(
|
| 352 |
+
"""
|
| 353 |
+
<div style="text-align: center; margin-top: 20px; padding: 10px; font-size: 0.85rem; color: #666;">
|
| 354 |
+
⚠️ <strong>Demo Version</strong>: This is a demonstration using a light model to showcase the watermarking pipeline.
|
| 355 |
+
Results may not be perfect and are intended for testing purposes only.
|
| 356 |
+
</div>
|
| 357 |
+
""",
|
| 358 |
+
unsafe_allow_html=True
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
if __name__ == "__main__":
|
| 362 |
+
# Run the app
|
| 363 |
+
main()
|
| 364 |
+
|
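The sidebar's `m` (tournament rounds) and `c` (competitors per round) knobs above parameterize the tournament sampler implemented in `SynthID_randomization.py` (not shown in this diff). As a rough, hypothetical sketch of the idea only — ignoring the app's `alpha` temperature and candidate scores, and using names invented for illustration — a keyed hash can decide each match so the winning synonym is reproducible from the secret key and context:

```python
import hashlib
import random


def tournament_sample(candidates, secret_key, context, m=10, c=2):
    """Deterministically pick one candidate via m rounds of c-way matches."""
    # Seed the bracket with c**m entrants drawn (with replacement) from the candidates,
    # using an RNG keyed on (secret_key, context) so the bracket is reproducible.
    rng = random.Random(f"{secret_key}|{context}")
    entrants = [rng.choice(candidates) for _ in range(c ** m)]

    for depth in range(m):
        def g(word, r=depth):
            # Keyed pseudorandom score for this round's matches
            digest = hashlib.sha256(f"{secret_key}|{context}|{r}|{word}".encode()).digest()
            return int.from_bytes(digest[:8], "big")

        # Each group of c entrants sends its highest-g member to the next round
        entrants = [max(entrants[i:i + c], key=g) for i in range(0, len(entrants), c)]

    return entrants[0]
```

With a fixed key and context the same synonym always wins, which is what lets a detector holding the key verify the watermark later.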
requirements.txt ADDED

@@ -0,0 +1,10 @@

```
streamlit>=1.28.0
torch>=2.0.0
transformers>=4.30.0
spacy>=3.4.0
nltk>=3.8.0
bert-score>=0.3.13
pandas>=2.0.0
numpy>=1.24.0
scikit-learn>=1.3.0
```
utils_final.py ADDED

@@ -0,0 +1,1213 @@
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import time
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import string
|
| 7 |
+
import spacy
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import nltk
|
| 11 |
+
import sys
|
| 12 |
+
import subprocess
|
| 13 |
+
from nltk.tokenize import word_tokenize
|
| 14 |
+
from nltk.stem.wordnet import WordNetLemmatizer
|
| 15 |
+
from nltk.corpus import wordnet as wn
|
| 16 |
+
import json
|
| 17 |
+
from filelock import FileLock
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
| 19 |
+
from functools import lru_cache
|
| 20 |
+
from typing import List, Tuple, Dict, Any
|
| 21 |
+
import multiprocessing as mp
|
| 22 |
+
|
| 23 |
+
# Ensure the HF_HOME environment variable points to your desired cache location
|
| 24 |
+
# Token removed for security
|
| 25 |
+
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'
|
| 26 |
+
|
| 27 |
+
# Handle potential import conflicts with sentence_transformers
|
| 28 |
+
try:
|
| 29 |
+
# Try to import bert_score directly to avoid sentence_transformers conflicts
|
| 30 |
+
from bert_score import score as bert_score
|
| 31 |
+
SIMILARITY_AVAILABLE = True
|
| 32 |
+
|
| 33 |
+
def calc_scores_bert(original_sentence, substitute_sentences):
|
| 34 |
+
"""BERTScore function using direct bert_score import."""
|
| 35 |
+
try:
|
| 36 |
+
# Safety check: truncate inputs if they're too long
|
| 37 |
+
max_chars = 2000 # Roughly 500 tokens
|
| 38 |
+
if len(original_sentence) > max_chars:
|
| 39 |
+
original_sentence = original_sentence[:max_chars]
|
| 40 |
+
|
| 41 |
+
truncated_substitutes = []
|
| 42 |
+
for sub in substitute_sentences:
|
| 43 |
+
if len(sub) > max_chars:
|
| 44 |
+
sub = sub[:max_chars]
|
| 45 |
+
truncated_substitutes.append(sub)
|
| 46 |
+
|
| 47 |
+
references = [original_sentence] * len(truncated_substitutes)
|
| 48 |
+
P, R, F1 = bert_score(
|
| 49 |
+
cands=truncated_substitutes,
|
| 50 |
+
refs=references,
|
| 51 |
+
model_type="bert-base-uncased",
|
| 52 |
+
verbose=False
|
| 53 |
+
)
|
| 54 |
+
return F1.tolist()
|
| 55 |
+
except Exception as e:
|
| 56 |
+
return [0.5] * len(substitute_sentences)
|
| 57 |
+
|
| 58 |
+
def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
|
| 59 |
+
"""Similarity function using direct bert_score import."""
|
| 60 |
+
if method == 'bert':
|
| 61 |
+
return calc_scores_bert(original_sentence, substitute_sentences)
|
| 62 |
+
else:
|
| 63 |
+
return [0.5] * len(substitute_sentences)
|
| 64 |
+
|
| 65 |
+
except ImportError as e:
|
| 66 |
+
print(f"Warning: bert_score import failed: {e}")
|
| 67 |
+
print("Falling back to neutral similarity scores...")
|
| 68 |
+
SIMILARITY_AVAILABLE = False
|
| 69 |
+
|
| 70 |
+
def calc_scores_bert(original_sentence, substitute_sentences):
|
| 71 |
+
"""Fallback BERTScore function with neutral scores."""
|
| 72 |
+
return [0.5] * len(substitute_sentences)
|
| 73 |
+
|
| 74 |
+
def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
|
| 75 |
+
"""Fallback similarity function with neutral scores."""
|
| 76 |
+
return [0.5] * len(substitute_sentences)
|
| 77 |
+
|
| 78 |
+
# Setup NLTK data
|
| 79 |
+
def setup_nltk_data():
|
| 80 |
+
"""Setup NLTK data with error handling."""
|
| 81 |
+
try:
|
| 82 |
+
nltk.download('punkt_tab', quiet=True)
|
| 83 |
+
except:
|
| 84 |
+
pass
|
| 85 |
+
try:
|
| 86 |
+
nltk.download('averaged_perceptron_tagger_eng', quiet=True)
|
| 87 |
+
except:
|
| 88 |
+
pass
|
| 89 |
+
try:
|
| 90 |
+
nltk.download('wordnet', quiet=True)
|
| 91 |
+
except:
|
| 92 |
+
pass
|
| 93 |
+
try:
|
| 94 |
+
nltk.download('omw-1.4', quiet=True)
|
| 95 |
+
except:
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
setup_nltk_data()
|
| 99 |
+
|
| 100 |
+
lemmatizer = WordNetLemmatizer()
|
| 101 |
+
|
| 102 |
+
# Load spaCy model - download if not available
|
| 103 |
+
try:
|
| 104 |
+
nlp = spacy.load("en_core_web_sm")
|
| 105 |
+
except OSError:
|
| 106 |
+
print("Downloading spaCy model...")
|
| 107 |
+
subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
|
| 108 |
+
nlp = spacy.load("en_core_web_sm")
|
| 109 |
+
|
| 110 |
+
# Define the detailed whitelist of POS tags (excluding adverbs)
|
| 111 |
+
DETAILED_POS_WHITELIST = {
|
| 112 |
+
'NN', # Noun, singular or mass (e.g., dog, car)
|
| 113 |
+
'NNS', # Noun, plural (e.g., dogs, cars)
|
| 114 |
+
'VB', # Verb, base form (e.g., run, eat)
|
| 115 |
+
'VBD', # Verb, past tense (e.g., ran, ate)
|
| 116 |
+
'VBG', # Verb, gerund or present participle (e.g., running, eating)
|
| 117 |
+
'VBN', # Verb, past participle (e.g., run, eaten)
|
| 118 |
+
'VBP', # Verb, non-3rd person singular present (e.g., run, eat)
|
| 119 |
+
'VBZ', # Verb, 3rd person singular present (e.g., runs, eats)
|
| 120 |
+
'JJ', # Adjective (e.g., big, blue)
|
| 121 |
+
'JJR', # Adjective, comparative (e.g., bigger, bluer)
|
| 122 |
+
'JJS', # Adjective, superlative (e.g., biggest, bluest)
|
| 123 |
+
'RB', # Adverb (e.g., very, silently)
|
| 124 |
+
'RBR', # Adverb, comparative (e.g., better)
|
| 125 |
+
'RBS' # Adverb, superlative (e.g., best)
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
# Global caches for better performance
|
| 129 |
+
_pos_cache = {}
|
| 130 |
+
_antonym_cache = {}
|
| 131 |
+
_word_validity_cache = {}
|
| 132 |
+
|
| 133 |
+
def extract_entities_and_pos(text):
|
| 134 |
+
"""
|
| 135 |
+
Detect eligible tokens for replacement while skipping:
|
| 136 |
+
- Named entities (e.g., names, locations, organizations).
|
| 137 |
+
- Compound words (e.g., "Opteron-based").
|
| 138 |
+
- Phrasal verbs (e.g., "make up", "focus on").
|
| 139 |
+
- Punctuation and non-POS-whitelisted tokens.
|
| 140 |
+
"""
|
| 141 |
+
doc = nlp(text)
|
| 142 |
+
sentence_target_pairs = [] # List to hold (sentence, target word, token index)
|
| 143 |
+
|
| 144 |
+
for sent in doc.sents:
|
| 145 |
+
for token in sent:
|
| 146 |
+
# Skip named entities using token.ent_type_ (more reliable than a text match)
|
| 147 |
+
if token.ent_type_:
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
# Skip standalone punctuation
|
| 151 |
+
if token.is_punct:
|
| 152 |
+
continue
|
| 153 |
+
|
| 154 |
+
# Skip compound words (e.g., "Opteron-based")
|
| 155 |
+
if "-" in token.text or token.dep_ in {"compound", "amod"}:
|
| 156 |
+
continue
|
| 157 |
+
|
| 158 |
+
# Skip phrasal verbs (e.g., "make up", "focus on")
|
| 159 |
+
if token.pos_ == "VERB" and any(child.dep_ == "prt" for child in token.children):
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
+
# Include regular tokens matching the POS whitelist
|
| 163 |
+
if token.tag_ in DETAILED_POS_WHITELIST:
|
| 164 |
+
sentence_target_pairs.append((sent.text, token.text, token.i))
|
| 165 |
+
|
| 166 |
+
return sentence_target_pairs
|
| 167 |
+
|
| 168 |
+
def preprocess_text(text):
|
| 169 |
+
"""
|
| 170 |
+
Preprocesses the text to handle abbreviations, titles, and edge cases
|
| 171 |
+
where a period or other punctuation does not signify a sentence end.
|
| 172 |
+
Ensures figures, acronyms, and short names are left untouched.
|
| 173 |
+
"""
|
| 174 |
+
# Protect common abbreviations like "U.S." and "Corp."
|
| 175 |
+
text = re.sub(r'\b(U\.S|U\.K|Corp|Inc|Ltd)\.', r'\1<PERIOD>', text)
|
| 176 |
+
|
| 177 |
+
# Protect floating-point numbers or ranges like "3.57" or "1.48–2.10"
|
| 178 |
+
text = re.sub(r'(\b\d+)\.(\d+)', r'\1<PERIOD>\2', text)
|
| 179 |
+
|
| 180 |
+
# Avoid modifying standalone single-letter initials in names (e.g., "J. Smith")
|
| 181 |
+
text = re.sub(r'\b([A-Z])\.(?=\s[A-Z])', r'\1<PERIOD>', text)
|
| 182 |
+
|
| 183 |
+
# Protect acronym-like patterns with dots, such as "F.B.I."
|
| 184 |
+
text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<PERIOD>'), text)
|
| 185 |
+
|
| 186 |
+
return text
|
| 187 |
+
|
| 188 |
+
def split_sentences(text):
|
| 189 |
+
"""
|
| 190 |
+
Splits text into sentences while preserving original newlines exactly.
|
| 191 |
+
- Protects abbreviations, acronyms, and floating-point numbers.
|
| 192 |
+
- Only adds newlines where necessary without duplicating them.
|
| 193 |
+
"""
|
| 194 |
+
# Step 1: Protect abbreviations, floating numbers, acronyms
|
| 195 |
+
text = re.sub(r'\b(U\.S\.|U\.K\.|Inc\.|Ltd\.|Corp\.|e\.g\.|i\.e\.|etc\.)\b', r'\1<ABBR>', text)
|
| 196 |
+
text = re.sub(r'(\b\d+)\.(\d+)', r'\1<FLOAT>\2', text)
|
| 197 |
+
text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<ABBR>'), text)
|
| 198 |
+
|
| 199 |
+
# Step 2: Identify sentence boundaries without duplicating newlines
|
| 200 |
+
sentences = []
|
| 201 |
+
for line in text.splitlines(keepends=True): # Retain original newlines
|
| 202 |
+
# Split only if punctuation marks end a sentence
|
| 203 |
+
split_line = re.split(r'(?<=[.!?])\s+', line.strip())
|
| 204 |
+
sentences.extend([segment + "\n" if line.endswith("\n") else segment for segment in split_line])
|
| 205 |
+
|
| 206 |
+
# Step 3: Restore protected patterns
|
| 207 |
+
return [sent.replace('<ABBR>', '.').replace('<FLOAT>', '.') for sent in sentences]
|
| 208 |
+
|
| 209 |
+
@lru_cache(maxsize=10000)
|
| 210 |
+
def is_valid_word(word):
|
| 211 |
+
"""Check if a word is valid using WordNet (cached)."""
|
| 212 |
+
return bool(wn.synsets(word))
|
| 213 |
+
|
| 214 |
+
@lru_cache(maxsize=5000)
|
| 215 |
+
def get_word_pos_tags(word):
|
| 216 |
+
"""Get POS tags for a word using both NLTK and spaCy (cached)."""
|
| 217 |
+
nltk_pos = nltk.pos_tag([word])[0][1]
|
| 218 |
+
spacy_pos = nlp(word)[0].pos_
|
| 219 |
+
return nltk_pos, spacy_pos
|
| 220 |
+
|
| 221 |
+
@lru_cache(maxsize=5000)
|
| 222 |
+
def get_word_lemma(word):
|
| 223 |
+
"""Get lemmatized form of a word (cached)."""
|
| 224 |
+
return lemmatizer.lemmatize(word)
|
| 225 |
+
|
| 226 |
+
@lru_cache(maxsize=2000)
|
| 227 |
+
def get_word_antonyms(word):
|
| 228 |
+
"""Get antonyms for a word (cached). Includes all lemmas from all synsets."""
|
| 229 |
+
target_synsets = wn.synsets(word)
|
| 230 |
+
antonyms = set()
|
| 231 |
+
|
| 232 |
+
# Get antonyms from all synsets and all lemmas
|
| 233 |
+
for syn in target_synsets:
|
| 234 |
+
for lem in syn.lemmas():
|
| 235 |
+
for ant in lem.antonyms():
|
| 236 |
+
# Add the antonym word (first part before the dot)
|
| 237 |
+
antonyms.add(ant.name().split('.')[0])
|
| 238 |
+
# Also add other lemmas of the antonym for completeness
|
| 239 |
+
for alt_lem in wn.synsets(ant.name().split('.')[0]):
|
| 240 |
+
for alt_ant_lem in alt_lem.lemmas():
|
| 241 |
+
antonyms.add(alt_ant_lem.name().split('.')[0])
|
| 242 |
+
|
| 243 |
+
return antonyms
|
| 244 |
+
|
| 245 |
+
def _are_semantically_compatible(target, candidate):
|
| 246 |
+
"""
|
| 247 |
+
Check if target and candidate are semantically compatible for replacement.
|
| 248 |
+
Returns False if they are specific nouns in the same category (e.g., different crops, fruits, animals).
|
| 249 |
+
"""
|
| 250 |
+
try:
|
| 251 |
+
# Direct check: if target and candidate are both specific terms for crops, animals, etc.
|
| 252 |
+
# check if they're NOT near-synonyms
|
| 253 |
+
|
| 254 |
+
# Agricultural/crop terms that shouldn't be swapped
|
| 255 |
+
agricultural_terms = ['soybean', 'corn', 'maize', 'wheat', 'rice', 'barley', 'oats', 'sorghum',
|
| 256 |
+
'millet', 'grain', 'cereal', 'pulse', 'bean', 'legume']
|
| 257 |
+
|
| 258 |
+
# If both are agricultural terms and different, block
|
| 259 |
+
if (target.lower() in agricultural_terms and candidate.lower() in agricultural_terms and
|
| 260 |
+
target.lower() != candidate.lower()):
|
| 261 |
+
return False
|
| 262 |
+
|
| 263 |
+
target_synsets = wn.synsets(target)
|
| 264 |
+
cand_synsets = wn.synsets(candidate)
|
| 265 |
+
|
| 266 |
+
if not target_synsets or not cand_synsets:
|
| 267 |
+
return True # If no synsets, allow through
|
| 268 |
+
|
| 269 |
+
# Check if they're near-synonyms (very similar) - if so, allow
|
| 270 |
+
# We can use path similarity to check if they're similar enough
|
| 271 |
+
max_similarity = 0.0
|
| 272 |
+
for t_syn in target_synsets:
|
| 273 |
+
for c_syn in cand_synsets:
|
| 274 |
+
try:
|
| 275 |
+
similarity = t_syn.path_similarity(c_syn) or 0.0
|
| 276 |
+
max_similarity = max(max_similarity, similarity)
|
| 277 |
+
except:
|
| 278 |
+
pass
|
| 279 |
+
|
| 280 |
+
# If they have high path similarity (>0.5), they're similar enough to allow
|
| 281 |
+
if max_similarity > 0.5:
|
| 282 |
+
return True
|
| 283 |
+
|
| 284 |
+
# Otherwise, check if they share common direct hypernyms
|
| 285 |
+
target_hypernyms = set()
|
| 286 |
+
for syn in target_synsets:
|
| 287 |
+
# Get immediate hypernyms (parent concepts)
|
| 288 |
+
for hypernym in syn.hypernyms():
|
| 289 |
+
target_hypernyms.add(hypernym)
|
| 290 |
+
|
| 291 |
+
cand_hypernyms = set()
|
| 292 |
+
for syn in cand_synsets:
|
| 293 |
+
for hypernym in syn.hypernyms():
|
| 294 |
+
cand_hypernyms.add(hypernym)
|
| 295 |
+
|
| 296 |
+
# If they share hypernyms, check if they're both specific instances (not general terms)
|
| 297 |
+
common_hypernyms = target_hypernyms & cand_hypernyms
|
| 298 |
+
|
| 299 |
+
if common_hypernyms:
|
| 300 |
+
# Check if both words are specific instances of the same category
|
| 301 |
+
# If so, they shouldn't be replaced with each other
|
| 302 |
+
# We identify this by checking if their hypernym has many siblings
|
| 303 |
+
for hypernym in common_hypernyms:
|
| 304 |
+
siblings = hypernym.hyponyms()
|
| 305 |
+
# If there are many specific instances (e.g., many crops, many fruits)
|
| 306 |
+
# it's likely a category with specific instances that shouldn't be interchanged
|
| 307 |
+
if len(siblings) > 3:
|
| 308 |
+
# Check if hypernym name suggests a specific category
|
| 309 |
+
hypernym_name = hypernym.name().split('.')[0]
|
| 310 |
+
category_keywords = [
|
| 311 |
+
'crop', 'grain', 'fruit', 'animal', 'bird', 'fish', 'company',
|
| 312 |
+
'country', 'city', 'brand', 'product', 'food', 'vehicle'
|
| 313 |
+
]
|
| 314 |
+
|
| 315 |
+
# If the hypernym contains category keywords, these are likely
|
| 316 |
+
# specific instances that shouldn't be swapped
|
| 317 |
+
if any(keyword in hypernym_name for keyword in category_keywords):
|
| 318 |
+
return False
|
| 319 |
+
|
| 320 |
+
return True
|
| 321 |
+
|
| 322 |
+
except Exception as e:
|
| 323 |
+
# On any error, allow the candidate through (conservative approach)
|
| 324 |
+
return True
|
| 325 |
+
|
| 326 |
+
def create_context_windows(full_text, target_sentence, target_word, tokenizer, max_tokens=400):
|
| 327 |
+
"""
|
| 328 |
+
Create context windows around the target sentence for better MLM generation.
|
| 329 |
+
Intelligently handles tokenizer length limits by preserving the most relevant context.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
full_text: The complete document text
|
| 333 |
+
target_sentence: The sentence containing the target word
|
| 334 |
+
target_word: The word to be replaced
|
| 335 |
+
tokenizer: The tokenizer to check length limits
|
| 336 |
+
max_tokens: Maximum tokens to use for context (leave room for instruction + mask)
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
List of context windows with different levels of context
|
| 340 |
+
"""
|
| 341 |
+
# Split full text into sentences
|
| 342 |
+
sentences = split_sentences(full_text)
|
| 343 |
+
|
| 344 |
+
# Find the target sentence index
|
| 345 |
+
target_sentence_idx = None
|
| 346 |
+
for i, sent in enumerate(sentences):
|
| 347 |
+
if target_sentence.strip() in sent.strip():
|
| 348 |
+
target_sentence_idx = i
|
| 349 |
+
break
|
| 350 |
+
|
| 351 |
+
if target_sentence_idx is None:
|
| 352 |
+
return [target_sentence] # Fallback to original sentence
|
| 353 |
+
|
| 354 |
+
# Create context windows with sentence-prioritized approach
|
| 355 |
+
context_windows = []
|
| 356 |
+
|
| 357 |
+
# Window 1: Just the target sentence (always include)
|
| 358 |
+
context_windows.append(target_sentence)
|
| 359 |
+
|
| 360 |
+
# Window 2: Target sentence + 1 sentence before and after (if fits)
|
| 361 |
+
start_idx = max(0, target_sentence_idx - 1)
|
| 362 |
+
end_idx = min(len(sentences), target_sentence_idx + 2)
|
| 363 |
+
context_window = " ".join(sentences[start_idx:end_idx])
|
| 364 |
+
|
| 365 |
+
try:
|
| 366 |
+
encoded_len = len(tokenizer.encode(context_window))
|
| 367 |
+
if encoded_len <= max_tokens:
|
| 368 |
+
context_windows.append(context_window)
|
| 369 |
+
except Exception as e:
|
| 370 |
+
pass
|
| 371 |
+
|
| 372 |
+
# Window 3: Target sentence + 2 sentences before and after (if fits)
|
| 373 |
+
start_idx = max(0, target_sentence_idx - 2)
|
| 374 |
+
end_idx = min(len(sentences), target_sentence_idx + 3)
|
| 375 |
+
context_window = " ".join(sentences[start_idx:end_idx])
|
| 376 |
+
|
| 377 |
+
try:
|
| 378 |
+
encoded_len = len(tokenizer.encode(context_window))
|
| 379 |
+
if encoded_len <= max_tokens:
|
| 380 |
+
context_windows.append(context_window)
|
| 381 |
+
except Exception as e:
|
| 382 |
+
pass
|
| 383 |
+
|
| 384 |
+
# Window 4: Target sentence + 3 sentences before and after (if fits)
|
| 385 |
+
start_idx = max(0, target_sentence_idx - 3)
|
| 386 |
+
end_idx = min(len(sentences), target_sentence_idx + 4)
|
| 387 |
+
context_window = " ".join(sentences[start_idx:end_idx])
|
| 388 |
+
|
| 389 |
+
try:
|
| 390 |
+
encoded_len = len(tokenizer.encode(context_window))
|
| 391 |
+
if encoded_len <= max_tokens:
|
| 392 |
+
context_windows.append(context_window)
|
| 393 |
+
except Exception as e:
|
| 394 |
+
pass
|
| 395 |
+
|
| 396 |
+
# Window 5: Intelligent context with sentence prioritization + word expansion
|
| 397 |
+
intelligent_context = _create_intelligent_context(
|
| 398 |
+
        full_text, target_word, target_sentence_idx, tokenizer, max_tokens
    )

    context_windows.append(intelligent_context)

    return context_windows


def _create_intelligent_context(full_text, target_word, target_sentence_idx, tokenizer, max_tokens):
    """
    Create intelligent context that prioritizes sentence boundaries while respecting token limits.
    Strategy: Target sentence → Nearby sentences → Word-level expansion
    """
    sentences = split_sentences(full_text)

    # Strategy 1: always start with the target sentence
    target_sentence = sentences[target_sentence_idx]
    try:
        target_sentence_tokens = len(tokenizer.encode(target_sentence))
    except Exception:
        target_sentence_tokens = 1000  # Fallback: assume it is too long

    if target_sentence_tokens > max_tokens:
        # Even the target sentence is too long, so truncate it intelligently
        return _truncate_sentence_intelligently(target_sentence, target_word, tokenizer, max_tokens)

    # Strategy 2: expand sentence-by-sentence around the target sentence
    best_context = target_sentence
    best_token_count = target_sentence_tokens

    # Try adding sentences before and after the target sentence
    for sentence_radius in range(1, min(len(sentences), 20)):  # Max 20-sentence radius
        start_idx = max(0, target_sentence_idx - sentence_radius)
        end_idx = min(len(sentences), target_sentence_idx + sentence_radius + 1)

        # Create context from complete sentences
        context_sentences = sentences[start_idx:end_idx]
        context_window = " ".join(context_sentences)
        try:
            token_count = len(tokenizer.encode(context_window))
        except Exception:
            token_count = 1000  # Fallback: assume it is too long

        if token_count <= max_tokens:
            # This expansion fits; keep it as the best option so far
            best_context = context_window
            best_token_count = token_count
        else:
            # This expansion is too big; stop here
            break

    # Strategy 3: if there is room left, try word-level expansion within the best sentence context
    remaining_tokens = max_tokens - best_token_count
    if remaining_tokens > 50:  # Significant room left
        enhanced_context = _enhance_with_word_expansion(
            full_text, target_word, best_context, tokenizer, remaining_tokens
        )
        if enhanced_context:
            return enhanced_context

    return best_context

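The sentence-by-sentence growth step above can be sketched in isolation. This is an illustrative sketch, not the project's code: `toy_encode` is a hypothetical stand-in for a real subword tokenizer (one token per whitespace-separated word), and the function only shows the "grow a ring of whole sentences until the budget is exceeded" idea.

```python
def toy_encode(text):
    # Hypothetical stand-in for a subword tokenizer:
    # one token per whitespace-separated word
    return text.split()

def expand_by_sentence_radius(sentences, target_idx, max_tokens):
    """Grow a window of whole sentences around sentences[target_idx],
    one ring at a time, until adding another ring would exceed max_tokens."""
    best = sentences[target_idx]
    for radius in range(1, len(sentences)):
        start = max(0, target_idx - radius)
        end = min(len(sentences), target_idx + radius + 1)
        window = " ".join(sentences[start:end])
        if len(toy_encode(window)) <= max_tokens:
            best = window  # This ring still fits; keep it
        else:
            break  # Budget exceeded; stop growing
    return best
```

Because only complete sentences are added, the returned context never cuts a sentence in half, which keeps the masked-LM input grammatical.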
def _enhance_with_word_expansion(full_text, target_word, current_context, tokenizer, remaining_tokens):
    """
    Enhance the current sentence-based context with word-level expansion if there is room.
    """
    words = full_text.split()
    target_word_idx = None

    # Find the target word's position in the full text
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break

    if target_word_idx is None:
        return current_context

    # Try to expand word-by-word around the target word
    try:
        current_tokens = len(tokenizer.encode(current_context))
    except Exception as e:
        print(f"WARNING: Error encoding current context: {e}")
        current_tokens = 1000  # Fallback: assume it is too long

    for expansion_size in range(1, min(len(words), 100)):  # Max 100-word expansion
        start_word = max(0, target_word_idx - expansion_size)
        end_word = min(len(words), target_word_idx + expansion_size + 1)

        expanded_context = " ".join(words[start_word:end_word])
        try:
            expanded_tokens = len(tokenizer.encode(expanded_context))
        except Exception:
            expanded_tokens = 1000  # Fallback: assume it is too long

        if expanded_tokens <= current_tokens + remaining_tokens:
            # This expansion fits within the remaining token budget
            return expanded_context
        else:
            # This expansion is too big; stop here
            break

    return current_context

def _truncate_sentence_intelligently(sentence, target_word, tokenizer, max_tokens):
    """
    Intelligently truncate a sentence while preserving context around the target word.
    """
    words = sentence.split()
    target_word_idx = None

    # Find the target word's position
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break

    if target_word_idx is None:
        # Target word not found; truncate from the end
        truncated = " ".join(words)
        try:
            while len(tokenizer.encode(truncated)) > max_tokens and len(words) > 1:
                words = words[:-1]
                truncated = " ".join(words)
        except Exception:
            # Fallback: return the first few words
            truncated = " ".join(words[:10]) if len(words) >= 10 else " ".join(words)
        return truncated

    # Truncate symmetrically around the target word
    context_words = 10  # Start with 10 words before/after
    while context_words > 0:
        start_word = max(0, target_word_idx - context_words)
        end_word = min(len(words), target_word_idx + context_words + 1)
        truncated_sentence = " ".join(words[start_word:end_word])

        try:
            if len(tokenizer.encode(truncated_sentence)) <= max_tokens:
                return truncated_sentence
        except Exception:
            # Encoding failed; try a smaller window
            pass

        context_words -= 1

    # Fallback: just the target word with minimal context
    return f"... {target_word} ..."

def _intelligent_token_slicing(input_text, tokenizer, max_length=512, mask_token_id=None):
    """
    Intelligently slice input text to fit within max_length tokens while preserving the mask token.
    Strategy: keep the mask token and its surrounding context; remove excess tokens from less important areas.

    Args:
        input_text: The full input text to be tokenized
        tokenizer: The tokenizer to use
        max_length: Maximum allowed sequence length (default 512)
        mask_token_id: The mask token ID to preserve

    Returns:
        Tuple of (sliced_input_ids, mask_position_in_sliced)
    """
    # First, tokenize the full input
    input_ids = tokenizer.encode(input_text, add_special_tokens=True)

    # If already within limits, return as-is
    if len(input_ids) <= max_length:
        mask_pos = input_ids.index(mask_token_id) if mask_token_id in input_ids else None
        return input_ids, mask_pos

    # Find the mask token position
    mask_positions = [i for i, token_id in enumerate(input_ids) if token_id == mask_token_id]

    if not mask_positions:
        # No mask token found; truncate from the end
        return input_ids[:max_length], None

    mask_pos = mask_positions[0]  # Use the first mask token

    # Calculate how many tokens need to be removed
    excess_tokens = len(input_ids) - max_length

    # Strategy: remove tokens from both ends while preserving mask context.
    # Reserve 25% of max_length or 50 tokens around the mask, whichever is smaller.
    mask_context_size = min(50, max_length // 4)

    # Calculate available context around the mask
    available_before = min(mask_pos, mask_context_size)
    available_after = min(len(input_ids) - mask_pos - 1, mask_context_size)

    # Calculate how much can be removed from each end
    tokens_to_remove_before = max(0, mask_pos - available_before)
    tokens_to_remove_after = max(0, (len(input_ids) - mask_pos - 1) - available_after)

    # Initialize removal counters
    remove_before = 0
    remove_after = 0

    # Distribute the excess tokens proportionally
    if excess_tokens > 0:
        if tokens_to_remove_before + tokens_to_remove_after >= excess_tokens:
            # We can remove enough from the ends
            if tokens_to_remove_before >= excess_tokens // 2:
                remove_before = excess_tokens // 2
                remove_after = excess_tokens - remove_before
            else:
                remove_before = tokens_to_remove_before
                remove_after = min(tokens_to_remove_after, excess_tokens - remove_before)
        else:
            # Need to remove more aggressively
            remove_before = tokens_to_remove_before
            remove_after = tokens_to_remove_after
            remaining_excess = excess_tokens - remove_before - remove_after

            # Remove any remaining excess from the end
            if remaining_excess > 0:
                remove_after += remaining_excess

    # Calculate final indices
    start_idx = remove_before
    end_idx = len(input_ids) - remove_after

    # Ensure we do not exceed max_length
    if end_idx - start_idx > max_length:
        # Center the window around the mask token
        half_length = max_length // 2
        start_idx = max(0, mask_pos - half_length)
        end_idx = min(len(input_ids), start_idx + max_length)

    # Slice the input_ids
    sliced_input_ids = input_ids[start_idx:end_idx]

    if len(sliced_input_ids) > max_length:
        # Force truncation as a final fallback
        sliced_input_ids = sliced_input_ids[:max_length]

    # Adjust the mask position for the sliced sequence
    adjusted_mask_pos = mask_pos - start_idx

    return sliced_input_ids, adjusted_mask_pos

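The core of the slicing logic is "keep a chosen token inside a fixed-size window, centered when possible". A minimal self-contained sketch of that idea (not the project's code, and without the proportional-removal refinements above):

```python
def slice_around_position(ids, keep_pos, max_length):
    """Return (sliced_ids, new_pos) with len(sliced_ids) <= max_length and
    ids[keep_pos] still present in the slice at index new_pos."""
    if len(ids) <= max_length:
        return ids, keep_pos  # Already fits; nothing to do
    half = max_length // 2
    start = max(0, keep_pos - half)          # Try to center on keep_pos
    end = min(len(ids), start + max_length)  # Clamp to the right edge
    start = max(0, end - max_length)         # Shift left if we hit that edge
    return ids[start:end], keep_pos - start
```

The last `start = max(0, end - max_length)` line is what keeps the window full-size even when the kept position sits near the end of the sequence.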
def _create_word_level_context(full_text, target_word, tokenizer, max_tokens):
    """
    Create context by expanding word-by-word around the target word until reaching the token limit.
    This maximizes context while respecting tokenizer limits.
    """
    words = full_text.split()
    target_word_idx = None

    # Find the target word's position in the full text
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break

    if target_word_idx is None:
        # Fallback: expand from the beginning until the token limit
        return _expand_from_start(words, tokenizer, max_tokens)

    # Word-by-word expansion around the target word
    return _expand_around_target(words, target_word_idx, tokenizer, max_tokens)


def _expand_around_target(words, target_idx, tokenizer, max_tokens):
    """
    Expand word-by-word around the target word until reaching the token limit.
    """
    best_context = ""
    best_token_count = 0

    # Try increasing expansion sizes
    for expansion_size in range(1, min(len(words), 200)):  # Max 200-word expansion
        start_word = max(0, target_idx - expansion_size)
        end_word = min(len(words), target_idx + expansion_size + 1)

        context_window = " ".join(words[start_word:end_word])
        try:
            token_count = len(tokenizer.encode(context_window))
        except Exception:
            token_count = 1000  # Fallback: assume it is too long

        if token_count <= max_tokens:
            # This expansion fits; keep it as the best option so far
            best_context = context_window
            best_token_count = token_count
        else:
            # This expansion is too big; stop here
            break

    # If we found a good context, return it
    if best_context:
        return best_context

    # Fallback: minimal context around the target word
    start_word = max(0, target_idx - 5)
    end_word = min(len(words), target_idx + 6)
    return " ".join(words[start_word:end_word])


def _expand_from_start(words, tokenizer, max_tokens):
    """
    Expand from the start of the text until reaching the token limit.
    """
    for end_idx in range(len(words), 0, -1):
        context_window = " ".join(words[:end_idx])
        try:
            if len(tokenizer.encode(context_window)) <= max_tokens:
                return context_window
        except Exception:
            # Encoding failed; try a shorter window
            pass

    # Fallback: the first few words
    return " ".join(words[:10]) if len(words) >= 10 else " ".join(words)

def whole_context_mlm_inference(full_text, sentence_target_pairs, tokenizer, lm_model, Top_K=20, batch_size=32, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Enhanced MLM inference using whole-document context for better candidate generation.
    """
    results = {}

    # Group targets by sentence for batch processing
    sentence_groups = {}
    for sent, target, index in sentence_target_pairs:
        if sent not in sentence_groups:
            sentence_groups[sent] = []
        sentence_groups[sent].append((target, index))

    for sentence, targets in sentence_groups.items():
        # Process targets in batches
        for i in range(0, len(targets), batch_size):
            batch_targets = targets[i:i+batch_size]
            batch_results = _process_whole_context_mlm_batch(
                full_text, sentence, batch_targets, tokenizer, lm_model, Top_K, max_context_tokens, max_length, similarity_context_mode
            )
            results.update(batch_results)

    return results

def _process_whole_context_mlm_batch(full_text, sentence, targets, tokenizer, lm_model, Top_K, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Process a batch of targets using whole-document context for MLM.
    """
    results = {}

    # Tokenize the sentence once
    doc = nlp(sentence)
    tokens = [token.text for token in doc]

    # Create multiple masked versions for batch processing
    masked_inputs = []
    mask_positions = []
    contexts_for_targets = []

    for target, index in targets:
        if index < len(tokens):
            # Create context windows with tokenizer length awareness
            context_windows = create_context_windows(full_text, sentence, target, tokenizer, max_tokens=max_context_tokens)

            # Use the most comprehensive context window that fits within token limits
            full_context = context_windows[-1]  # Built around the target sentence
            # Select the context used for similarity according to the mode
            context = sentence if similarity_context_mode == 'sentence' else full_context

            # Create a masked version of the FULL context (not just the sentence)
            masked_full_context = context.replace(target, tokenizer.mask_token, 1)

            instruction = "Given the full document context, replace the masked word with a word that fits grammatically, preserves the original meaning, and ensures natural flow in the document:"
            input_text = f"{instruction} {context} {tokenizer.sep_token} {masked_full_context}"

            # Aggressive fix: truncate the input text BEFORE tokenization to prevent errors.
            # Estimate the token count (roughly 1 token per 4 characters for English).
            estimated_tokens = len(input_text) // 4
            if estimated_tokens > 500:  # Leave some buffer
                # Truncate to roughly 2000 characters (~500 tokens)
                input_text = input_text[:2000]

            # Simple fix: truncate the input text if it is still too long
            try:
                temp_tokens = tokenizer.encode(input_text, add_special_tokens=True)
                if len(temp_tokens) > 512:
                    # Truncate the input text by removing words from the end
                    words = input_text.split()
                    while len(tokenizer.encode(" ".join(words), add_special_tokens=True)) > 512 and len(words) > 10:
                        words = words[:-1]
                    input_text = " ".join(words)
            except Exception:
                # Emergency truncation: just take the first 200 words
                words = input_text.split()
                input_text = " ".join(words[:200])

            masked_inputs.append(input_text)
            # Store the original sentence-level index for reference; the mask position is recalculated during tokenization
            mask_positions.append(index)
            contexts_for_targets.append(context)

    if not masked_inputs:
        return results

    # Batch tokenize
    MAX_LENGTH = max_length  # Use the parameter for A100 optimization
    batch_inputs = []
    batch_mask_positions = []
    batch_contexts = []

    for input_text, mask_pos in zip(masked_inputs, mask_positions):
        # Use intelligent token slicing to stay within MAX_LENGTH
        try:
            input_ids, adjusted_mask_pos = _intelligent_token_slicing(
                input_text, tokenizer, max_length=MAX_LENGTH, mask_token_id=tokenizer.mask_token_id
            )

            if adjusted_mask_pos is not None:
                batch_inputs.append(input_ids)
                batch_mask_positions.append(adjusted_mask_pos)
            else:
                # Mask token not found in the sliced sequence; skip this input
                continue

        except Exception:
            # Fallback: simple truncation
            try:
                input_ids = tokenizer.encode(input_text, add_special_tokens=True)
                if len(input_ids) > MAX_LENGTH:
                    input_ids = input_ids[:MAX_LENGTH]

                masked_position = input_ids.index(tokenizer.mask_token_id)
                batch_inputs.append(input_ids)
                batch_mask_positions.append(masked_position)
            except ValueError:
                # Mask token not found; skip this input
                continue

    if not batch_inputs:
        return results

    # Pad sequences to the same length without exceeding MAX_LENGTH
    max_len = min(max(len(ids) for ids in batch_inputs), MAX_LENGTH)

    # Additional safety check: truncate any sequences that are still too long
    truncated_batch_inputs = []
    for input_ids in batch_inputs:
        if len(input_ids) > MAX_LENGTH:
            input_ids = input_ids[:MAX_LENGTH]
        truncated_batch_inputs.append(input_ids)

    padded_inputs = []
    attention_masks = []

    for input_ids in truncated_batch_inputs:
        attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
        padded_ids = input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
        padded_inputs.append(padded_ids)
        attention_masks.append(attention_mask)

    # Final safety check: ensure no sequence exceeds MAX_LENGTH
    for i, padded_ids in enumerate(padded_inputs):
        if len(padded_ids) > MAX_LENGTH:
            padded_inputs[i] = padded_ids[:MAX_LENGTH]
            attention_masks[i] = attention_masks[i][:MAX_LENGTH]

    # Batch inference, optimized for A100 with mixed precision
    with torch.no_grad():
        input_tensor = torch.tensor(padded_inputs, dtype=torch.long)
        attention_tensor = torch.tensor(attention_masks, dtype=torch.long)

        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
            attention_tensor = attention_tensor.cuda()

            # Use mixed precision for A100 optimization
            with torch.amp.autocast('cuda'):
                outputs = lm_model(input_tensor, attention_mask=attention_tensor)
                logits = outputs.logits
        else:
            outputs = lm_model(input_tensor, attention_mask=attention_tensor)
            logits = outputs.logits

    # Process results: collect filtered candidates first
    batch_filtered_results = {}
    for i, (target, index) in enumerate(targets):
        if i < len(batch_mask_positions):
            mask_pos = batch_mask_positions[i]
            mask_logits = logits[i, mask_pos].squeeze()

            # Get the top predictions
            top_tokens = torch.topk(mask_logits, k=Top_K, dim=0)[1]
            scores = torch.softmax(mask_logits, dim=0)[top_tokens].tolist()
            words = [tokenizer.decode(token.item()).strip() for token in top_tokens]

            # Filter candidates (similarity scoring happens later, at batch level)
            filtered_candidates = _filter_candidates_batch(target, words, scores, tokens, index)
            if filtered_candidates:
                # Attach the exact context window used for this target
                batch_filtered_results[(sentence, target, index)] = {
                    'filtered_words': filtered_candidates,
                    'context': contexts_for_targets[i]
                }

    # Batch similarity scoring for all candidates
    if batch_filtered_results:
        similarity_results = _batch_similarity_scoring(batch_filtered_results, tokenizer)
        results.update(similarity_results)

    return results

def _filter_candidates_batch(target, words, scores, tokens, target_index):
    """
    Optimized batch filtering of candidates (no similarity scoring here; that happens at batch level).
    """
    # Basic filtering
    filtered_words = []
    filtered_scores = []
    seen_words = set()

    for word, score in zip(words, scores):
        word_lower = word.lower()
        if word_lower in seen_words or word_lower == target.lower():
            continue
        seen_words.add(word_lower)

        if not is_valid_word(word):
            continue

        # Quick POS check
        target_nltk_pos, target_spacy_pos = get_word_pos_tags(target)
        cand_nltk_pos, cand_spacy_pos = get_word_pos_tags(word)

        if target_nltk_pos != cand_nltk_pos or target_spacy_pos != cand_spacy_pos:
            continue

        # Check antonyms (bidirectional and case-insensitive)
        antonyms = get_word_antonyms(target)
        if word.lower() in [ant.lower() for ant in antonyms]:
            continue

        # Also check whether the candidate has the target as an antonym (reverse check)
        candidate_antonyms = get_word_antonyms(word)
        if target.lower() in [ant.lower() for ant in candidate_antonyms]:
            continue

        # Hardcoded common antonym pairs (for words not in WordNet, and as an additional safeguard)
        common_antonyms = {
            'big': ['small', 'tiny', 'little'],
            'small': ['big', 'large', 'huge'],
            'large': ['small', 'tiny', 'little'],
            'good': ['bad', 'evil', 'wrong'],
            'bad': ['good', 'great', 'excellent'],
            'high': ['low'],
            'low': ['high'],
            'new': ['old'],
            'old': ['new'],
            'fast': ['slow'],
            'slow': ['fast'],
            'rich': ['poor'],
            'poor': ['rich'],
            'hot': ['cold'],
            'cold': ['hot'],
            'happy': ['sad', 'unhappy'],
            'sad': ['happy', 'joyful'],
            'true': ['false', 'untrue'],
            'false': ['true'],
            'real': ['fake', 'unreal'],
            'fake': ['real'],
            'up': ['down'],
            'down': ['up'],
            'yes': ['no'],
            'no': ['yes'],
            'alive': ['dead'],
            'dead': ['alive'],
            'safe': ['unsafe', 'dangerous'],
            'dangerous': ['safe'],
            'clean': ['dirty'],
            'dirty': ['clean'],
            'full': ['empty'],
            'empty': ['full'],
            'open': ['closed', 'shut'],
            'closed': ['open'],
            'begin': ['end', 'finish'],
            'end': ['begin', 'start'],
            'start': ['end', 'finish'],
            'finish': ['start', 'begin'],
            'first': ['last'],
            'last': ['first']
        }

        # Check whether the word is a known antonym of the target (case-insensitive)
        target_lower = target.lower()
        if target_lower in common_antonyms and word.lower() in common_antonyms[target_lower]:
            continue

        # Check whether the word and target are different specific nouns in the same
        # category (e.g., crops, animals, companies); if so, exclude the candidate
        if not _are_semantically_compatible(target, word):
            continue

        filtered_words.append(word)
        filtered_scores.append(score)

    if len(filtered_words) < 2:
        return None

    # Return filtered words without similarity scoring (done at batch level)
    return filtered_words

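The bidirectional antonym check above can be isolated into a small, testable helper. This is an illustrative sketch, not the project's code: `ANTONYMS` is a toy table standing in for the full hardcoded list and the WordNet lookups.

```python
# Toy antonym table; the app combines WordNet antonyms with a larger hardcoded list
ANTONYMS = {
    'big': ['small', 'tiny', 'little'],
    'small': ['big', 'large', 'huge'],
    'hot': ['cold'],
}

def is_antonym_pair(target, candidate, table=ANTONYMS):
    """True if either word lists the other as an antonym (case-insensitive)."""
    t, c = target.lower(), candidate.lower()
    forward = c in (a.lower() for a in table.get(t, []))
    reverse = t in (a.lower() for a in table.get(c, []))
    return forward or reverse

def filter_antonyms(target, candidates, table=ANTONYMS):
    """Drop candidates that would flip the meaning of the target word."""
    return [w for w in candidates if not is_antonym_pair(target, w, table)]
```

The reverse check matters because antonym tables are rarely symmetric: `'cold'` may be listed under `'hot'` without `'hot'` being listed under `'cold'`.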
def _batch_similarity_scoring(batch_results, tokenizer):
    """
    Optimized batched similarity scoring across multiple sentences using full context.
    Processes candidates from multiple sentences together for better efficiency.
    """
    # Collect all similarity scoring tasks
    similarity_tasks = []
    sentence_contexts = {}

    for (sentence, target, index), value in batch_results.items():
        if value is None:
            continue
        # Support both the legacy list format and the new dict with context
        if isinstance(value, dict):
            filtered_words = value.get('filtered_words')
            context = value.get('context', sentence)
        else:
            filtered_words = value
            context = sentence

        # Tokenize the sentence once
        tokens = tokenizer.tokenize(sentence)
        if index >= len(tokens):
            continue

        # Store the sentence context for later use
        sentence_contexts[(sentence, target, index)] = {
            'tokens': tokens,
            'target_index': index,
            'filtered_words': filtered_words
        }

        # Create candidate sentences for this target
        for word in filtered_words:
            candidate_tokens = tokens.copy()
            candidate_tokens[index] = word
            candidate_sentence = tokenizer.convert_tokens_to_string(candidate_tokens)

            # Build the full-context candidate by replacing the sentence inside the chosen context once
            candidate_full_context = context.replace(sentence, candidate_sentence, 1)
            similarity_tasks.append({
                'original_context': context,
                'candidate_full_context': candidate_full_context,
                'target_word': word,
                'context_key': (sentence, target, index)
            })

    if not similarity_tasks:
        return {}

    # Batch process all similarity scoring
    try:
        # Group by original full context for efficient BERTScore computation
        context_groups = {}
        for task in similarity_tasks:
            orig_ctx = task['original_context']
            if orig_ctx not in context_groups:
                context_groups[orig_ctx] = []
            context_groups[orig_ctx].append(task)

        # Process each context group
        final_results = {}
        for orig_context, tasks in context_groups.items():
            # Extract candidate full contexts
            candidate_contexts = [task['candidate_full_context'] for task in tasks]

            # Batch BERTScore computation against the same full context
            try:
                similarity_scores = calc_scores_bert(orig_context, candidate_contexts)
            except Exception:
                # Fallback to neutral scores
                similarity_scores = [0.5] * len(candidate_contexts)

            if similarity_scores and not all(score == 0.5 for score in similarity_scores):
                # Group results by context key
                for task, score in zip(tasks, similarity_scores):
                    context_key = task['context_key']
                    if context_key not in final_results:
                        final_results[context_key] = []
                    final_results[context_key].append((task['target_word'], score))

        # Sort results by similarity score, best first
        for context_key in final_results:
            final_results[context_key].sort(key=lambda x: x[1], reverse=True)

        return final_results

    except Exception:
        return {}

def parallel_tournament_sampling(target_results, secret_key, m, c, h, alpha):
    """
    Parallel tournament sampling for multiple targets.
    """
    results = {}

    if not target_results:
        return results

    def process_single_tournament(item):
        (sentence, target, index), candidates = item
        if not candidates:
            return (sentence, target, index), None

        alternatives = [alt[0] for alt in candidates]
        similarity = [alt[1] for alt in candidates]

        if not alternatives or not similarity:
            return (sentence, target, index), None

        # Get the left context of up to h tokens
        context_tokens = word_tokenize(sentence)
        left_context = context_tokens[max(0, index - h):index]

        # Tournament selection
        from SynthID_randomization import tournament_select_word
        randomized_word = tournament_select_word(
            target, alternatives, similarity,
            context=left_context, key=secret_key, m=m, c=c, alpha=alpha
        )

        return (sentence, target, index), randomized_word

    # Process in parallel
    max_workers = max(1, min(8, len(target_results)))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_item = {executor.submit(process_single_tournament, item): item for item in target_results.items()}

        for future in as_completed(future_to_item):
            key, result = future.result()
            results[key] = result

    return results

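The key-conditioned tournament idea can be sketched generically. This is an illustrative sketch, not the actual `tournament_select_word` from `SynthID_randomization` (whose scoring and parameters `m`, `c`, `alpha` are not shown here): a pseudorandom "g-value" is derived from the secret key, the left context, and each candidate, and pairwise tournament rounds keep the candidate with the higher g-value, so the same key and context deterministically reproduce the same choice.

```python
import hashlib

def g_value(key, context, word):
    # Hash (key, context, word) into a reproducible float in [0, 1)
    payload = f"{key}|{' '.join(context)}|{word}".encode("utf-8")
    digest = hashlib.sha256(payload).digest()
    return int.from_bytes(digest[:8], "big") / 2**64

def tournament_select(candidates, context, key, rounds=3):
    """Run pairwise tournament rounds; each pair keeps its higher-g-value member."""
    pool = list(candidates)
    for _ in range(rounds):
        if len(pool) == 1:
            break
        nxt = []
        for i in range(0, len(pool) - 1, 2):
            a, b = pool[i], pool[i + 1]
            nxt.append(a if g_value(key, context, a) >= g_value(key, context, b) else b)
        if len(pool) % 2 == 1:
            nxt.append(pool[-1])  # Odd one out advances unopposed
        pool = nxt
    return pool[0]
```

A detector holding the same key can recompute the g-values and check whether the chosen words score suspiciously high, which is the statistical signal the watermark relies on.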
def whole_context_process_sentence(full_text, sentence, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, output_name, batch_size=32, max_length=512, max_context_tokens=400, similarity_context_mode='whole'):
    """
    Enhanced sentence processing using whole-document context for better candidate generation.
    """
    replacements = []
    sampling_results = []
    doc = nlp(sentence)
    sentence_target_pairs = extract_entities_and_pos(sentence)

    if not sentence_target_pairs:
        return replacements, sampling_results

    # Filter valid target pairs
    valid_pairs = []
    spacy_tokens = [token.text for token in doc]

    for sent, target, position in sentence_target_pairs:
        if position < len(spacy_tokens) and spacy_tokens[position] == target:
            valid_pairs.append((sent, target, position))

    if not valid_pairs:
        return replacements, sampling_results

    # Enhanced MLM inference with whole-document context
    batch_results = whole_context_mlm_inference(full_text, valid_pairs, tokenizer, lm_model, Top_K, batch_size, max_context_tokens, max_length, similarity_context_mode)

    # Filter by threshold (matching the original logic)
    filtered_results = {}
    for key, candidates in batch_results.items():
        if candidates:
            threshold_candidates = [(word, score) for word, score in candidates if score >= threshold]
            if len(threshold_candidates) >= 2:
                filtered_results[key] = threshold_candidates

    # Parallel tournament sampling
    tournament_results = parallel_tournament_sampling(filtered_results, secret_key, m, c, h, alpha)

    # Collect replacements and sampling results
    for (sent, target, position), randomized_word in tournament_results.items():
        if randomized_word:
            # Get the alternatives for this target from the filtered results
            alternatives = filtered_results.get((sent, target, position), [])
            alternatives_list = [alt[0] for alt in alternatives]
            # Include similarity scores for each alternative (preserve the old 'alternatives' list for compatibility)
            alternatives_with_similarity = [
                {"word": alt[0], "similarity": float(alt[1])} for alt in alternatives
            ]

            # Track sampling results
            sampling_results.append({
                "word": target,
                "alternatives": alternatives_list,
                "alternatives_with_similarity": alternatives_with_similarity,
                "randomized_word": randomized_word
            })

            replacements.append((position, target, randomized_word))

    return replacements, sampling_results

# Legacy function for compatibility
|
| 1196 |
+
def look_up(sentence, target, index, tokenizer, lm_model, Top_K=20, threshold=0.75):
|
| 1197 |
+
"""
|
| 1198 |
+
Legacy single-target lookup function for compatibility.
|
| 1199 |
+
"""
|
| 1200 |
+
results = batch_mlm_inference([(sentence, target, index)], tokenizer, lm_model, Top_K)
|
| 1201 |
+
return results.get((sentence, target, index), None)
|
| 1202 |
+
|
| 1203 |
+
def batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K=20):
|
| 1204 |
+
"""
|
| 1205 |
+
Legacy batch MLM inference function for compatibility.
|
| 1206 |
+
"""
|
| 1207 |
+
return whole_context_mlm_inference("", sentence_target_pairs, tokenizer, lm_model, Top_K)
|
| 1208 |
+
|
| 1209 |
+
def batch_look_up(sentence_target_pairs, tokenizer, lm_model, Top_K=20, threshold=0.75, max_workers=4):
|
| 1210 |
+
"""
|
| 1211 |
+
Optimized batch lookup using the new batch MLM inference.
|
| 1212 |
+
"""
|
| 1213 |
+
return batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K)
|
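The core idea behind the tournament step above is that a secret key deterministically selects one synonym from each candidate set, so a key holder can later re-derive and verify the choices. A minimal, self-contained sketch of that key-conditioned selection is shown below; the helper name `keyed_pick` and the HMAC-based scoring are illustrative assumptions, not the scheme used by `parallel_tournament_sampling` in this file.

```python
import hashlib
import hmac


def keyed_pick(secret_key, context_word, candidates):
    """Deterministically pick one candidate, conditioned on the secret key.

    The same (key, context_word, candidates) triple always yields the same
    choice, which is what lets a key holder detect the watermark later.
    """
    scored = []
    for word in candidates:
        digest = hmac.new(
            secret_key.encode(),
            f"{context_word}|{word}".encode(),
            hashlib.sha256,
        ).digest()
        # Interpret the first 8 bytes of the MAC as the candidate's score
        scored.append((int.from_bytes(digest[:8], "big"), word))
    # Highest keyed score wins the "tournament"
    return max(scored)[1]


# Same key and context always reproduce the same choice
choice = keyed_pick("my-secret", "quick", ["fast", "rapid", "swift"])
assert choice == keyed_pick("my-secret", "quick", ["fast", "rapid", "swift"])
```

A different key generally produces a different winner, which is why detection requires the original secret key.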