# SafeSeal / utils_final.py
# Source: kirudang, commit fc6dcab ("Sync SafeSeal app").
import os
import re
import time
import math
import torch
import string
import spacy
import pandas as pd
import numpy as np
import nltk
import sys
import subprocess
from nltk.tokenize import word_tokenize
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.corpus import wordnet as wn
import json
from filelock import FileLock
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from functools import lru_cache
from typing import List, Tuple, Dict, Any
import multiprocessing as mp
# Presumably a model/data cache directory for this project.
# NOTE(review): nothing in this part of the file sets HF_HOME or reads
# cache_dir — confirm where (or whether) it is consumed.
# (An access token that used to be defined here was removed for security.)
cache_dir = "/network/rit/lab/Lai_ReSecureAI/kiel/wmm"
# Handle potential import conflicts with sentence_transformers
try:
# Try to import bert_score directly to avoid sentence_transformers conflicts
from bert_score import score as bert_score
SIMILARITY_AVAILABLE = True
def calc_scores_bert(original_sentence, substitute_sentences):
"""BERTScore function using direct bert_score import."""
try:
# Safety check: truncate inputs if they're too long
max_chars = 2000 # Roughly 500 tokens
if len(original_sentence) > max_chars:
original_sentence = original_sentence[:max_chars]
truncated_substitutes = []
for sub in substitute_sentences:
if len(sub) > max_chars:
sub = sub[:max_chars]
truncated_substitutes.append(sub)
references = [original_sentence] * len(truncated_substitutes)
P, R, F1 = bert_score(
cands=truncated_substitutes,
refs=references,
model_type="bert-base-uncased",
verbose=False
)
return F1.tolist()
except Exception as e:
return [0.5] * len(substitute_sentences)
def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
"""Similarity function using direct bert_score import."""
if method == 'bert':
return calc_scores_bert(original_sentence, substitute_sentences)
else:
return [0.5] * len(substitute_sentences)
except ImportError as e:
print(f"Warning: bert_score import failed: {e}")
print("Falling back to neutral similarity scores...")
SIMILARITY_AVAILABLE = False
def calc_scores_bert(original_sentence, substitute_sentences):
"""Fallback BERTScore function with neutral scores."""
return [0.5] * len(substitute_sentences)
def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
"""Fallback similarity function with neutral scores."""
return [0.5] * len(substitute_sentences)
# Setup NLTK data
# NLTK packages fetched at import time (previously four hard-coded calls).
_NLTK_RESOURCES = ('punkt_tab', 'averaged_perceptron_tagger_eng', 'wordnet', 'omw-1.4')


def setup_nltk_data(resources=_NLTK_RESOURCES):
    """
    Download NLTK data packages on a best-effort basis.

    Each package is requested independently with nltk.download(quiet=True).
    An exception for one package is reported on stderr and does not stop the
    others or propagate, so importing this module still works offline.
    Unlike the former bare ``except:``, KeyboardInterrupt and SystemExit are
    no longer swallowed.

    Args:
        resources: iterable of NLTK package ids; defaults to _NLTK_RESOURCES.
    """
    for resource in resources:
        try:
            nltk.download(resource, quiet=True)
        except Exception as exc:
            print(f"Warning: NLTK download of {resource!r} failed: {exc}", file=sys.stderr)


setup_nltk_data()
# Shared WordNet lemmatizer (used by get_word_lemma).
lemmatizer = WordNetLemmatizer()


def _load_spacy_pipeline(model_name):
    """
    Load a spaCy pipeline, installing it with the spaCy CLI if it is missing.

    spacy.load raises OSError when the package is not installed. In that case
    the package is downloaded with the current interpreter and loading is
    retried once.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        print("Downloading spaCy model...")
        subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
        return spacy.load(model_name)


# Small English spaCy pipeline used for tokenization, tagging, parsing and NER.
nlp = _load_spacy_pipeline("en_core_web_sm")
# Define the detailed whitelist of POS tags (excluding adverbs)
DETAILED_POS_WHITELIST = {
'NN', # Noun, singular or mass (e.g., dog, car)
'NNS', # Noun, plural (e.g., dogs, cars)
'VB', # Verb, base form (e.g., run, eat)
'VBD', # Verb, past tense (e.g., ran, ate)
'VBG', # Verb, gerund or present participle (e.g., running, eating)
'VBN', # Verb, past participle (e.g., run, eaten)
'VBP', # Verb, non-3rd person singular present (e.g., run, eat)
'VBZ', # Verb, 3rd person singular present (e.g., runs, eats)
'JJ', # Adjective (e.g., big, blue)
'JJR', # Adjective, comparative (e.g., bigger, bluer)
'JJS', # Adjective, superlative (e.g., biggest, bluest)
'RB', # Adverb (e.g., very, silently)
'RBR', # Adverb, comparative (e.g., better)
'RBS' # Adverb, superlative (e.g., best)
}
# Global caches for better performance
_pos_cache = {}
_antonym_cache = {}
_word_validity_cache = {}
def extract_entities_and_pos(text):
    """
    Find tokens of ``text`` that are eligible for substitution.

    A token is skipped when it:
      - belongs to a named entity (non-empty spaCy ``ent_type_``);
      - is punctuation;
      - contains "-" or has the dependency label "compound" or "amod";
      - is a verb with a particle child ("prt"), i.e. a phrasal verb such as
        "make up" or "focus on".
    Of the remaining tokens, only those whose fine-grained tag is in
    DETAILED_POS_WHITELIST are kept.

    Returns:
        list of (sentence_text, token_text, token_index) tuples in document
        order. token_index is spaCy's ``Token.i``, i.e. the token's position in
        the whole parsed document, not within its sentence.

    NOTE(review): the "amod" test excludes every attributive adjective (e.g.
    "big" in "big dog"), not only hyphenated compounds — confirm this is intended.
    """
    def _is_eligible(tok):
        if tok.ent_type_ or tok.is_punct:
            return False
        if "-" in tok.text or tok.dep_ in {"compound", "amod"}:
            return False
        if tok.pos_ == "VERB" and any(child.dep_ == "prt" for child in tok.children):
            return False
        return tok.tag_ in DETAILED_POS_WHITELIST

    parsed = nlp(text)
    return [
        (sentence.text, tok.text, tok.i)
        for sentence in parsed.sents
        for tok in sentence
        if _is_eligible(tok)
    ]
def preprocess_text(text):
    """
    Mark periods that do not end a sentence by replacing them with "<PERIOD>".

    Handled cases:
      - the abbreviations U.S, U.K, Corp, Inc, Ltd (their final period only);
      - decimal numbers such as "3.57", also inside ranges like "1.48–2.10";
      - dotted acronyms of three or more capitals such as "F.B.I.";
      - single capital initials followed by a space and a capital (e.g. "J. Smith").

    The "<PERIOD>" markers are not restored by this function.
    """
    text = re.sub(r'\b(U\.S|U\.K|Corp|Inc|Ltd)\.', r'\1<PERIOD>', text)
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<PERIOD>\2', text)
    # Acronyms must be handled before initials. Otherwise, in "F.B.I. Director"
    # the initials rule consumes the final "I." first, the acronym pattern no
    # longer matches, and "F." and "B." stay unprotected.
    text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<PERIOD>'), text)
    text = re.sub(r'\b([A-Z])\.(?=\s[A-Z])', r'\1<PERIOD>', text)
    return text
def split_sentences(text):
    """
    Split text into sentences while keeping the original line structure.

    - Periods in dotted acronyms ("F.B.I.", "U.S.A."), in the abbreviations
      U.S., U.K., Inc., Ltd., Corp., e.g., i.e., etc., and in decimal numbers
      ("3.57") never end a sentence.
    - Text is processed line by line. Within a line, a sentence ends at ".",
      "!" or "?" followed by whitespace. Each line is stripped of surrounding
      whitespace.
    - A line's trailing newline is attached to its LAST sentence only, so the
      number of newlines is preserved. Blank lines become "\n" entries.

    Returns:
        list[str]: the sentences, in order.
    """
    # Step 1: protect periods that do not end a sentence.
    # Acronyms go first so that "U.S.A." is protected as a whole before the
    # "U.S." abbreviation rule can split it.
    text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<ABBR>'), text)
    # The abbreviation's own final period is replaced by the marker (not
    # followed by it), so restoration yields exactly one period. There is no
    # trailing \b: after a period it could only match before a word character,
    # never before the usual following space.
    text = re.sub(r'\b(U\.S|U\.K|Inc|Ltd|Corp|e\.g|i\.e|etc)\.', r'\1<ABBR>', text)
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<FLOAT>\2', text)

    # Step 2: split each line on sentence-ending punctuation.
    sentences = []
    for line in text.splitlines(keepends=True):
        segments = re.split(r'(?<=[.!?])\s+', line.strip())
        if line.endswith("\n"):
            segments[-1] += "\n"
        sentences.extend(segments)

    # Step 3: restore protected periods.
    return [sent.replace('<ABBR>', '.').replace('<FLOAT>', '.') for sent in sentences]
@lru_cache(maxsize=10000)
def is_valid_word(word):
    """Return True when WordNet has at least one synset for ``word`` (memoized)."""
    synsets = wn.synsets(word)
    return len(synsets) > 0
@lru_cache(maxsize=5000)
def get_word_pos_tags(word):
    """
    Tag ``word`` in isolation with both NLTK and spaCy (memoized).

    Returns:
        (nltk_tag, spacy_pos): the Penn Treebank tag from ``nltk.pos_tag`` and
        the coarse-grained POS (``Token.pos_``) of the first spaCy token.
    """
    tagged = nltk.pos_tag([word])
    nltk_tag = tagged[0][1]
    first_token = nlp(word)[0]
    return nltk_tag, first_token.pos_
@lru_cache(maxsize=5000)
def get_word_lemma(word):
    """
    Return the WordNet lemma of ``word`` from the module lemmatizer (memoized).

    No POS is passed, so WordNetLemmatizer uses its default (noun).
    """
    lemma = lemmatizer.lemmatize(word)
    return lemma
@lru_cache(maxsize=2000)
def get_word_antonyms(word):
    """
    Return WordNet antonyms of ``word`` together with their synonyms (memoized).

    For every lemma of every synset of ``word``, each antonym's name is
    collected, plus every lemma name of every synset of that antonym word.
    The result is therefore deliberately broad.

    Note: lru_cache hands back the same set object on every call with the same
    word, so callers must not mutate the returned set.
    """
    collected = set()
    for synset in wn.synsets(word):
        for lemma in synset.lemmas():
            for antonym in lemma.antonyms():
                antonym_word = antonym.name().split('.')[0]
                collected.add(antonym_word)
                # Widen with every lemma of every sense of the antonym word.
                collected.update(
                    related.name().split('.')[0]
                    for antonym_synset in wn.synsets(antonym_word)
                    for related in antonym_synset.lemmas()
                )
    return collected
# Crop/grain terms that must never be swapped for one another.
_AGRICULTURAL_TERMS = frozenset({
    'soybean', 'corn', 'maize', 'wheat', 'rice', 'barley', 'oats', 'sorghum',
    'millet', 'grain', 'cereal', 'pulse', 'bean', 'legume',
})

# Substrings of a shared WordNet hypernym's name that mark a category of
# specific, non-interchangeable instances (crops, animals, companies, ...).
# NOTE(review): matching is by substring, so e.g. "city" also matches
# "capacity"/"electricity" — confirm whether whole-word matching was intended.
_CATEGORY_KEYWORDS = (
    'crop', 'grain', 'fruit', 'animal', 'bird', 'fish', 'company',
    'country', 'city', 'brand', 'product', 'food', 'vehicle',
)


def _wn_path_similarity_or_zero(syn_a, syn_b):
    """Return syn_a.path_similarity(syn_b), treating None or any error (e.g. cross-POS pairs) as 0.0."""
    try:
        return syn_a.path_similarity(syn_b) or 0.0
    except Exception:
        return 0.0


def _are_semantically_compatible(target, candidate):
    """
    Decide whether ``candidate`` may replace ``target`` (False means block it).

    Rules, applied in order:
      1. Two different terms from _AGRICULTURAL_TERMS (case-insensitive) -> False.
      2. Either word has no WordNet synsets -> True.
      3. Some synset pair has path similarity > 0.5 -> True. Path similarity is
         1 / (1 + shortest path length), so this effectively means the words
         share a synset.
      4. The words share a direct hypernym that has more than 3 hyponyms and
         whose name contains a _CATEGORY_KEYWORDS entry -> False.
      5. Otherwise True.
    Any unexpected error also yields True (fail open).
    """
    try:
        target_lower = target.lower()
        candidate_lower = candidate.lower()
        if (target_lower in _AGRICULTURAL_TERMS and candidate_lower in _AGRICULTURAL_TERMS
                and target_lower != candidate_lower):
            return False

        target_synsets = wn.synsets(target)
        cand_synsets = wn.synsets(candidate)
        if not target_synsets or not cand_synsets:
            return True  # Unknown to WordNet: allow through

        # Near-synonyms are always acceptable.
        if any(_wn_path_similarity_or_zero(t_syn, c_syn) > 0.5
               for t_syn in target_synsets for c_syn in cand_synsets):
            return True

        target_hypernyms = {hyp for syn in target_synsets for hyp in syn.hypernyms()}
        cand_hypernyms = {hyp for syn in cand_synsets for hyp in syn.hypernyms()}
        for hypernym in target_hypernyms & cand_hypernyms:
            # Only categories with many specific members (e.g. many crops) count.
            if len(hypernym.hyponyms()) <= 3:
                continue
            hypernym_name = hypernym.name().split('.')[0]
            if any(keyword in hypernym_name for keyword in _CATEGORY_KEYWORDS):
                return False
        return True
    except Exception:
        # On any error, allow the candidate through (conservative approach).
        return True
def create_context_windows(full_text, target_sentence, target_word, tokenizer, max_tokens=400):
    """
    Build context windows of increasing size around a sentence for MLM generation.

    The document is split with split_sentences(). The anchor is the first
    sentence that contains ``target_sentence`` (both compared after stripping).

    Args:
        full_text: The complete document text.
        target_sentence: The sentence containing the target word.
        target_word: The word to be replaced.
        tokenizer: Tokenizer whose ``encode`` measures window lengths.
        max_tokens: Token budget per window (leave room for instruction + mask).

    Returns:
        list[str]: ``[target_sentence]`` if no anchor is found. Otherwise:
        target_sentence; then the ±1, ±2 and ±3-sentence windows that fit
        within max_tokens (windows that do not fit, or whose encoding raises,
        are skipped); and finally _create_intelligent_context()'s result.
    """
    sentences = split_sentences(full_text)
    anchor_idx = next(
        (pos for pos, candidate in enumerate(sentences)
         if target_sentence.strip() in candidate.strip()),
        None,
    )
    if anchor_idx is None:
        return [target_sentence]  # Fallback to the original sentence

    windows = [target_sentence]
    for radius in (1, 2, 3):
        lo = max(0, anchor_idx - radius)
        hi = min(len(sentences), anchor_idx + radius + 1)
        candidate_window = " ".join(sentences[lo:hi])
        try:
            fits = len(tokenizer.encode(candidate_window)) <= max_tokens
        except Exception:
            continue
        if fits:
            windows.append(candidate_window)

    # Widest option: sentence expansion plus optional word-level expansion.
    windows.append(
        _create_intelligent_context(full_text, target_word, anchor_idx, tokenizer, max_tokens)
    )
    return windows
def _create_intelligent_context(full_text, target_word, target_sentence_idx, tokenizer, max_tokens):
    """
    Build the widest sentence-aligned context around one sentence within max_tokens.

    1. If sentence ``target_sentence_idx`` alone exceeds max_tokens, return
       _truncate_sentence_intelligently()'s word window around target_word.
    2. Otherwise grow the window one sentence per side (radius 1..19, also
       bounded by the number of sentences) and stop at the first that does
       not fit.
    3. If more than 50 tokens of budget remain, let
       _enhance_with_word_expansion() try to enlarge the result; a non-empty
       answer from it is returned instead.
    Text whose encoding raises is counted as 1000 tokens (i.e. "too long").
    """
    sentences = split_sentences(full_text)
    anchor = sentences[target_sentence_idx]

    def count_tokens(snippet):
        try:
            return len(tokenizer.encode(snippet))
        except Exception:
            return 1000  # Fallback: assume it's too long

    anchor_tokens = count_tokens(anchor)
    if anchor_tokens > max_tokens:
        return _truncate_sentence_intelligently(anchor, target_word, tokenizer, max_tokens)

    chosen, chosen_tokens = anchor, anchor_tokens
    for radius in range(1, min(len(sentences), 20)):
        lo = max(0, target_sentence_idx - radius)
        hi = min(len(sentences), target_sentence_idx + radius + 1)
        window = " ".join(sentences[lo:hi])
        window_tokens = count_tokens(window)
        if window_tokens > max_tokens:
            break
        chosen, chosen_tokens = window, window_tokens

    leftover = max_tokens - chosen_tokens
    if leftover > 50:
        enhanced = _enhance_with_word_expansion(full_text, target_word, chosen, tokenizer, leftover)
        if enhanced:
            return enhanced
    return chosen
def _enhance_with_word_expansion(full_text, target_word, current_context, tokenizer, remaining_tokens):
    """
    Try to replace a sentence-based context with a larger word window.

    ``full_text`` is whitespace-split, and a window is grown one word per side
    (at most 99 per side) around the first case-insensitive exact match of
    ``target_word``. The LARGEST window whose token count fits within
    ``len(encode(current_context)) + remaining_tokens`` is kept.

    Previously the first window that fit (just three words) was returned, which
    shrank the context instead of enhancing it.

    The window is returned only if it still contains ``current_context``
    (compared with whitespace collapsed), so the result is a true enlargement.
    Otherwise, or when the target word is not found, ``current_context`` is
    returned unchanged.

    Returns:
        str: the enlarged window or current_context.
    """
    words = full_text.split()
    target_word_idx = next(
        (i for i, word in enumerate(words) if word.lower() == target_word.lower()),
        None,
    )
    if target_word_idx is None:
        return current_context

    try:
        current_tokens = len(tokenizer.encode(current_context))
    except Exception as e:
        print(f"WARNING: Error encoding current context: {e}")
        current_tokens = 1000  # Fallback to assume it's too long
    token_budget = current_tokens + remaining_tokens

    best_window = None
    for expansion_size in range(1, min(len(words), 100)):  # Max 99 words per side
        start_word = max(0, target_word_idx - expansion_size)
        end_word = min(len(words), target_word_idx + expansion_size + 1)
        window = " ".join(words[start_word:end_word])
        try:
            window_tokens = len(tokenizer.encode(window))
        except Exception:
            window_tokens = 1000  # Fallback to assume it's too long
        if window_tokens > token_budget:
            break
        best_window = window
        if start_word == 0 and end_word == len(words):
            break  # The whole text already fits; larger sizes add nothing

    if best_window is not None and " ".join(current_context.split()) in best_window:
        return best_window
    return current_context
def _truncate_sentence_intelligently(sentence, target_word, tokenizer, max_tokens):
    """
    Shorten an over-long sentence to fit max_tokens, keeping text around target_word.

    If target_word occurs (case-insensitive match on whitespace-split words),
    windows of 10, 9, ..., 1 words on each side of its first occurrence are
    tried, and the first that fits is returned. If none fits, the placeholder
    "... <target_word> ..." is returned. Windows whose encoding raises count as
    not fitting.

    If target_word does not occur, words are dropped from the end until the
    text fits (at least one word is kept). If encoding raises part-way, at
    most the first 10 remaining words are returned.
    """
    kept = sentence.split()
    anchor = next(
        (pos for pos, token in enumerate(kept) if token.lower() == target_word.lower()),
        None,
    )

    if anchor is None:
        try:
            while len(tokenizer.encode(" ".join(kept))) > max_tokens and len(kept) > 1:
                kept.pop()
        except Exception:
            return " ".join(kept[:10])
        return " ".join(kept)

    for radius in range(10, 0, -1):
        window = " ".join(kept[max(0, anchor - radius):anchor + radius + 1])
        try:
            if len(tokenizer.encode(window)) <= max_tokens:
                return window
        except Exception:
            continue
    return f"... {target_word} ..."
def _intelligent_token_slicing(input_text, tokenizer, max_length=512, mask_token_id=None):
    """
    Tokenize input_text and cut it to at most max_length ids, keeping the mask token.

    Up to min(50, max_length // 4) tokens on each side of the first mask are
    treated as protected context. The excess is removed from the two ends of
    the sequence, split as evenly as the unprotected parts allow. Whatever one
    side cannot absorb is taken from the other side. Previously the tail was
    always cut by half the excess, which could remove the mask and return an
    out-of-range mask position.

    Because cuts come from the ends, tokenizer-added special tokens at the
    start/end (e.g. [CLS]/[SEP] for BERT-style tokenizers) may be dropped.

    Args:
        input_text: The full input text (encoded with add_special_tokens=True).
        tokenizer: Object whose encode(text, add_special_tokens=True) returns a list of ids.
        max_length: Maximum number of ids to return (assumed >= 1).
        mask_token_id: The mask token ID to preserve.

    Returns:
        (input_ids, mask_pos): the kept ids and the index of the mask in them.
        mask_pos is None when no mask token occurs; the ids are then
        head-truncated to max_length.
    """
    input_ids = tokenizer.encode(input_text, add_special_tokens=True)

    if len(input_ids) <= max_length:
        mask_pos = input_ids.index(mask_token_id) if mask_token_id in input_ids else None
        return input_ids, mask_pos

    try:
        mask_pos = input_ids.index(mask_token_id)  # First mask token
    except ValueError:
        return input_ids[:max_length], None

    total = len(input_ids)
    excess_tokens = total - max_length
    tokens_after_mask = total - mask_pos - 1

    mask_context_size = min(50, max_length // 4)
    available_before = min(mask_pos, mask_context_size)
    available_after = min(tokens_after_mask, mask_context_size)
    # Tokens that can be removed without touching the protected context.
    tokens_to_remove_before = mask_pos - available_before
    tokens_to_remove_after = tokens_after_mask - available_after

    if tokens_to_remove_before + tokens_to_remove_after >= excess_tokens:
        # Aim for an even split; the tail takes at most what it can spare and
        # the head absorbs the rest (always possible in this branch).
        remove_before = min(tokens_to_remove_before, excess_tokens // 2)
        remove_after = min(tokens_to_remove_after, excess_tokens - remove_before)
        remove_before = excess_tokens - remove_after
    else:
        # Defensive: with max_length >= 1 this branch is not reached, because
        # protected context is at most half of max_length. Drop all unprotected
        # tokens, then eat into the protected context (tail first, then head),
        # never past the mask.
        remaining = excess_tokens - tokens_to_remove_before - tokens_to_remove_after
        extra_after = min(remaining, available_after)
        remove_after = tokens_to_remove_after + extra_after
        remove_before = tokens_to_remove_before + (remaining - extra_after)

    start_idx = remove_before
    end_idx = total - remove_after
    return input_ids[start_idx:end_idx], mask_pos - start_idx
def _create_word_level_context(full_text, target_word, tokenizer, max_tokens):
    """
    Build a context from the whitespace-split words of full_text that fits max_tokens.

    The context is centred on the first case-insensitive exact match of
    target_word via _expand_around_target(). Without a match, the longest
    fitting prefix from _expand_from_start() is used.
    """
    words = full_text.split()
    for position, word in enumerate(words):
        if word.lower() == target_word.lower():
            return _expand_around_target(words, position, tokenizer, max_tokens)
    return _expand_from_start(words, tokenizer, max_tokens)
def _expand_around_target(words, target_idx, tokenizer, max_tokens):
    """
    Grow a window of words around words[target_idx] while it fits max_tokens.

    Radii 1..min(len(words), 200) - 1 are tried in order, stopping at the first
    window that does not fit. Windows whose encoding raises count as 1000 tokens.
    If not even radius 1 fits, the window of up to 5 words on each side is
    returned regardless of its size.
    """
    widest = ""
    for radius in range(1, min(len(words), 200)):  # Max 199 words per side
        window = " ".join(words[max(0, target_idx - radius):target_idx + radius + 1])
        try:
            n_tokens = len(tokenizer.encode(window))
        except Exception:
            n_tokens = 1000  # Fallback to assume it's too long
        if n_tokens > max_tokens:
            break
        widest = window

    if widest:
        return widest
    # Fallback: minimal context around the target word.
    return " ".join(words[max(0, target_idx - 5):target_idx + 6])
def _expand_from_start(words, tokenizer, max_tokens):
    """
    Return the longest prefix of ``words`` (joined by spaces) that fits max_tokens.

    Prefixes are tried from longest to shortest. Prefixes whose encoding raises
    are skipped. If none fits, the first 10 words are returned.
    """
    end_idx = len(words)
    while end_idx > 0:
        prefix = " ".join(words[:end_idx])
        try:
            if len(tokenizer.encode(prefix)) <= max_tokens:
                return prefix
        except Exception:
            pass  # Treat as not fitting; try a shorter prefix
        end_idx -= 1
    return " ".join(words[:10])
def whole_context_mlm_inference(full_text, sentence_target_pairs, tokenizer, lm_model, Top_K=20, batch_size=32, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Generate substitution candidates for (sentence, target, index) triples using document context.

    Targets are grouped by sentence text (in first-seen order) and passed to
    _process_whole_context_mlm_batch() in chunks of ``batch_size``. The
    per-chunk result dicts are merged, with later keys overwriting earlier
    identical keys.

    Returns:
        dict: the merged batch results; {} when sentence_target_pairs is empty.
    """
    by_sentence = {}
    for sent, target, index in sentence_target_pairs:
        by_sentence.setdefault(sent, []).append((target, index))

    merged = {}
    for sentence, targets in by_sentence.items():
        for start in range(0, len(targets), batch_size):
            chunk = targets[start:start + batch_size]
            merged.update(
                _process_whole_context_mlm_batch(
                    full_text, sentence, chunk, tokenizer, lm_model, Top_K,
                    max_context_tokens, max_length, similarity_context_mode,
                )
            )
    return merged
def _wc_build_mlm_input(full_text, sentence, target, tokenizer, max_context_tokens, similarity_context_mode):
    """
    Return (input_text, context) for one target's masked-LM query.

    ``context`` is the sentence itself in 'sentence' mode. Otherwise it is the
    last (widest) window from create_context_windows(). The prompt has the form
    "<instruction> <context> <sep_token> <context with its first occurrence of
    target replaced by mask_token>".
    """
    if similarity_context_mode == 'sentence':
        context = sentence
    else:
        context = create_context_windows(
            full_text, sentence, target, tokenizer, max_tokens=max_context_tokens
        )[-1]
    masked_context = context.replace(target, tokenizer.mask_token, 1)
    instruction = "Given the full document context, replace the masked word with a word that fits grammatically, preserves the original meaning, and ensures natural flow in the document:"
    return f"{instruction} {context} {tokenizer.sep_token} {masked_context}", context


def _wc_encode_masked_input(input_text, tokenizer, max_length):
    """
    Tokenize input_text to at most max_length ids while keeping the mask token.

    The text is no longer cut by characters or from its end before
    tokenization: the masked copy of the context sits at the END of the
    prompt, so such cuts removed the mask. Mask-aware slicing is used instead.

    Returns:
        (input_ids, mask_pos), or None when the mask cannot be kept inside the
        window (the target is then skipped).
    """
    mask_id = tokenizer.mask_token_id
    try:
        input_ids, mask_pos = _intelligent_token_slicing(
            input_text, tokenizer, max_length=max_length, mask_token_id=mask_id
        )
    except Exception:
        # Fallback: plain head truncation, usable only if the mask survives it.
        try:
            input_ids = tokenizer.encode(input_text, add_special_tokens=True)[:max_length]
            mask_pos = input_ids.index(mask_id)
        except Exception:
            return None
    input_ids = list(input_ids)[:max_length]
    # Guard against a mask position that fell outside the kept window.
    if mask_pos is None or not 0 <= mask_pos < len(input_ids) or input_ids[mask_pos] != mask_id:
        return None
    return input_ids, mask_pos


def _wc_run_masked_lm(lm_model, tokenizer, batch_ids, max_length):
    """
    Right-pad the id lists in batch_ids and return lm_model's logits for them.

    Each sequence is clipped to max_length and padded with
    tokenizer.pad_token_id to the longest sequence. The attention mask is 1 for
    real tokens and 0 for padding. The call runs under torch.no_grad(). When
    CUDA is available, inputs are moved to the GPU and the forward pass uses
    autocast mixed precision (assumes lm_model is already on the GPU, as before).
    """
    clipped = [list(ids)[:max_length] for ids in batch_ids]
    width = max(len(ids) for ids in clipped)
    padded = [ids + [tokenizer.pad_token_id] * (width - len(ids)) for ids in clipped]
    attention = [[1] * len(ids) + [0] * (width - len(ids)) for ids in clipped]
    with torch.no_grad():
        input_tensor = torch.tensor(padded, dtype=torch.long)
        attention_tensor = torch.tensor(attention, dtype=torch.long)
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
            attention_tensor = attention_tensor.cuda()
            with torch.amp.autocast('cuda'):
                return lm_model(input_tensor, attention_mask=attention_tensor).logits
        return lm_model(input_tensor, attention_mask=attention_tensor).logits


def _process_whole_context_mlm_batch(full_text, sentence, targets, tokenizer, lm_model, Top_K, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Generate, filter and score MLM substitution candidates for targets of one sentence.

    Processing steps:
      1. For each (target, index), build a masked prompt (_wc_build_mlm_input).
      2. Tokenize it to at most ``max_length`` ids with the mask kept
         (_wc_encode_masked_input).
      3. Run the model once on the padded batch.
      4. Filter the ``Top_K`` predictions at each mask with
         _filter_candidates_batch; survivors are scored by
         _batch_similarity_scoring.

    Every model row is tracked together with its own target, index and
    context. Previously, a skipped target shifted all later rows, pairing
    logits, mask positions and contexts with the wrong target.

    Returns:
        dict: _batch_similarity_scoring's results for the (sentence, target,
        index) keys that kept enough candidates; targets that could not be
        processed are absent.
    """
    results = {}
    # spaCy tokens of the sentence; used for the index bound check and passed
    # through to _filter_candidates_batch.
    tokens = [token.text for token in nlp(sentence)]

    rows = []  # (target, index, context, input_ids, mask_pos)
    for target, index in targets:
        # NOTE(review): if `index` is a document-level token index (spaCy
        # Token.i), this bound check against the sentence's token count drops
        # targets in later sentences — confirm the intended index semantics.
        if index >= len(tokens):
            continue
        input_text, context = _wc_build_mlm_input(
            full_text, sentence, target, tokenizer, max_context_tokens, similarity_context_mode
        )
        encoded = _wc_encode_masked_input(input_text, tokenizer, max_length)
        if encoded is None:
            continue
        input_ids, mask_pos = encoded
        rows.append((target, index, context, input_ids, mask_pos))
    if not rows:
        return results

    logits = _wc_run_masked_lm(lm_model, tokenizer, [row[3] for row in rows], max_length)

    batch_filtered_results = {}
    for row_idx, (target, index, context, _input_ids, mask_pos) in enumerate(rows):
        mask_logits = logits[row_idx, mask_pos].squeeze()
        top_tokens = torch.topk(mask_logits, k=Top_K, dim=0)[1]
        scores = torch.softmax(mask_logits, dim=0)[top_tokens].tolist()
        words = [tokenizer.decode(token.item()).strip() for token in top_tokens]
        filtered_candidates = _filter_candidates_batch(target, words, scores, tokens, index)
        if filtered_candidates:
            # Attach the exact context window used for this target.
            batch_filtered_results[(sentence, target, index)] = {
                'filtered_words': filtered_candidates,
                'context': context,
            }

    if batch_filtered_results:
        results.update(_batch_similarity_scoring(batch_filtered_results, tokenizer))
    return results
# Hand-maintained antonym pairs, checked in addition to WordNet (covers words
# WordNet misses and acts as an extra safeguard). Maps a lower-cased target to
# candidates that must be rejected for it.
_COMMON_ANTONYMS = {
    'big': ['small', 'tiny', 'little'],
    'small': ['big', 'large', 'huge'],
    'large': ['small', 'tiny', 'little'],
    'good': ['bad', 'evil', 'wrong'],
    'bad': ['good', 'great', 'excellent'],
    'high': ['low'],
    'low': ['high'],
    'new': ['old'],
    'old': ['new'],
    'fast': ['slow'],
    'slow': ['fast'],
    'rich': ['poor'],
    'poor': ['rich'],
    'hot': ['cold'],
    'cold': ['hot'],
    'happy': ['sad', 'unhappy'],
    'sad': ['happy', 'joyful'],
    'true': ['false', 'untrue'],
    'false': ['true'],
    'real': ['fake', 'unreal'],
    'fake': ['real'],
    'up': ['down'],
    'down': ['up'],
    'yes': ['no'],
    'no': ['yes'],
    'alive': ['dead'],
    'dead': ['alive'],
    'safe': ['unsafe', 'dangerous'],
    'dangerous': ['safe'],
    'clean': ['dirty'],
    'dirty': ['clean'],
    'full': ['empty'],
    'empty': ['full'],
    'open': ['closed', 'shut'],
    'closed': ['open'],
    'begin': ['end', 'finish'],
    'end': ['begin', 'start'],
    'start': ['end', 'finish'],
    'finish': ['start', 'begin'],
    'first': ['last'],
    'last': ['first'],
}


def _filter_candidates_batch(target, words, scores, tokens, target_index):
    """
    Filter MLM candidate words for ``target``; similarity scoring happens later.

    A candidate is dropped when it:
      - is a case-insensitive duplicate of an earlier candidate, or equals the target;
      - has no WordNet synsets;
      - differs from the target in NLTK tag or spaCy POS (each word tagged in isolation);
      - is a WordNet antonym of the target, in either direction;
      - is listed for the target in _COMMON_ANTONYMS;
      - fails _are_semantically_compatible().

    Target-derived data (tags, antonyms) is computed once, on first need,
    instead of per candidate.

    Args:
        target: the word being replaced.
        words: candidate words, in model ranking order.
        scores: model scores aligned with ``words``; only their pairing with
            ``words`` (via zip) matters here.
        tokens, target_index: accepted for interface compatibility; unused.

    Returns:
        list[str] of surviving candidates in their original order, or None
        when fewer than two survive.
    """
    target_lower = target.lower()
    target_tags = None      # (nltk_tag, spacy_pos) of the target, computed lazily
    target_antonyms = None  # lower-cased WordNet antonyms of the target, computed lazily

    filtered_words = []
    seen_words = set()
    for word, _score in zip(words, scores):
        word_lower = word.lower()
        if word_lower in seen_words or word_lower == target_lower:
            continue
        seen_words.add(word_lower)

        if not is_valid_word(word):
            continue

        # Quick POS check (both taggers must agree with the target's tags).
        if target_tags is None:
            target_tags = get_word_pos_tags(target)
        if get_word_pos_tags(word) != target_tags:
            continue

        # WordNet antonyms, checked in both directions, case-insensitively.
        if target_antonyms is None:
            target_antonyms = {ant.lower() for ant in get_word_antonyms(target)}
        if word_lower in target_antonyms:
            continue
        if target_lower in {ant.lower() for ant in get_word_antonyms(word)}:
            continue

        # Hard-coded antonym pairs.
        if word_lower in _COMMON_ANTONYMS.get(target_lower, ()):
            continue

        # Different specific members of one category (crops, animals, ...).
        if not _are_semantically_compatible(target, word):
            continue

        filtered_words.append(word)

    if len(filtered_words) < 2:
        return None
    return filtered_words
def _batch_similarity_scoring(batch_results, tokenizer):
"""
Optimized batched similarity scoring across multiple sentences for full context.
Processes all candidates from multiple sentences together for better efficiency.
"""
# Collect all similarity scoring tasks
similarity_tasks = []
sentence_contexts = {}
for (sentence, target, index), value in batch_results.items():
if value is None:
continue
# Support both legacy list and new dict with context
if isinstance(value, dict):
filtered_words = value.get('filtered_words')
context = value.get('context', sentence)
else:
filtered_words = value
context = sentence
# Tokenize the sentence once
tokens = tokenizer.tokenize(sentence)
if index >= len(tokens):
continue
# Store sentence context for later use
sentence_contexts[(sentence, target, index)] = {
'tokens': tokens,
'target_index': index,
'filtered_words': filtered_words
}
# Create candidate sentences for this target
for word in filtered_words:
candidate_tokens = tokens.copy()
candidate_tokens[index] = word
candidate_sentence = tokenizer.convert_tokens_to_string(candidate_tokens)
# Build full-context candidate by replacing the sentence inside the chosen context once
candidate_full_context = context.replace(sentence, candidate_sentence, 1)
similarity_tasks.append({
'original_context': context,
'candidate_full_context': candidate_full_context,
'target_word': word,
'context_key': (sentence, target, index)
})
if not similarity_tasks:
return {}
# Batch process all similarity scoring
try:
# Group by original full context for efficient BERTScore computation
context_groups = {}
for task in similarity_tasks:
orig_ctx = task['original_context']
if orig_ctx not in context_groups:
context_groups[orig_ctx] = []
context_groups[orig_ctx].append(task)
# Process each context group
final_results = {}
for orig_context, tasks in context_groups.items():
# Extract candidate full-contexts
candidate_contexts = [task['candidate_full_context'] for task in tasks]
# Batch BERTScore computation against the same full context
try:
similarity_scores = calc_scores_bert(orig_context, candidate_contexts)
except Exception as e:
# Fallback to neutral scores
similarity_scores = [0.5] * len(candidate_contexts)
if similarity_scores and not all(score == 0.5 for score in similarity_scores):
# Group results by context key
for task, score in zip(tasks, similarity_scores):
context_key = task['context_key']
if context_key not in final_results:
final_results[context_key] = []
final_results[context_key].append((task['target_word'], score))
# Sort results by similarity score
for context_key in final_results:
final_results[context_key].sort(key=lambda x: x[1], reverse=True)
return final_results
except Exception as e:
return {}
def parallel_tournament_sampling(target_results, secret_key, m, c, h, alpha):
    """
    Run tournament sampling for many targets concurrently.

    Args:
        target_results: Mapping of ``(sentence, target, index)`` to a list of
            ``(word, similarity)`` candidate pairs. Falsy values produce
            ``None`` without any sampling.
        secret_key, m, c, alpha: Forwarded unchanged to
            ``SynthID_randomization.tournament_select_word`` as ``key``,
            ``m``, ``c`` and ``alpha``.
        h: Number of ``word_tokenize`` tokens immediately left of ``index``
            that are passed as the ``context`` argument.

    Returns:
        Dict mapping every key of ``target_results`` to the selected word,
        or to ``None`` when the key had no candidates. Keys appear in the
        same order as in ``target_results``, so the output is deterministic
        regardless of the order in which worker threads finish.
    """
    # Pre-populate in input order; assigning values later keeps that order.
    results = {key: None for key in target_results}
    pending = [(key, candidates) for key, candidates in target_results.items() if candidates]
    if not pending:
        return results

    # Import once up front rather than inside every worker invocation.
    from SynthID_randomization import tournament_select_word

    def process_single_tournament(item):
        (sentence, target, index), candidates = item
        alternatives = [alt[0] for alt in candidates]
        similarity = [alt[1] for alt in candidates]
        # Left context: up to `h` tokens preceding the target position.
        context_tokens = word_tokenize(sentence)
        left_context = context_tokens[max(0, index - h):index]
        return tournament_select_word(
            target, alternatives, similarity,
            context=left_context, key=secret_key, m=m, c=c, alpha=alpha
        )

    max_workers = max(1, min(8, len(pending)))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in submission order, unlike as_completed.
        selections = executor.map(process_single_tournament, pending)
        for (key, _candidates), selected in zip(pending, selections):
            results[key] = selected
    return results
def whole_context_process_sentence(full_text, sentence, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, output_name, batch_size=32, max_length=512, max_context_tokens=400, similarity_context_mode='whole'):
    """
    Propose replacements for one sentence using whole-document context.

    Pipeline:
      1. ``extract_entities_and_pos(sentence)`` yields
         ``(sentence, target, position)`` triples. Only triples whose
         ``position`` indexes a spaCy token equal to ``target`` are kept.
      2. ``whole_context_mlm_inference`` generates scored candidates for the
         kept triples, using ``full_text`` as document context.
      3. Candidates with ``score >= threshold`` are kept. A target needs at
         least two surviving candidates to be sampled.
      4. ``parallel_tournament_sampling`` picks one word per target.

    Args:
        full_text: Document text forwarded to ``whole_context_mlm_inference``.
        sentence: The sentence whose targets are processed.
        tokenizer, lm_model, Top_K, batch_size, max_context_tokens,
        max_length, similarity_context_mode: Forwarded to
            ``whole_context_mlm_inference``.
        threshold: Minimum candidate score to keep.
        secret_key, m, c, h, alpha: Forwarded to
            ``parallel_tournament_sampling``.
        output_name: Accepted for interface compatibility; not used here.

    Returns:
        ``(replacements, sampling_results)``. ``replacements`` is a list of
        ``(position, target, randomized_word)`` tuples, where ``position`` is
        a spaCy token index within ``sentence``. ``sampling_results`` holds
        one dict per replacement with keys ``word``, ``alternatives``,
        ``alternatives_with_similarity`` and ``randomized_word``. Both lists
        follow the order of the candidate mapping, independent of thread
        completion order.
    """
    replacements = []
    sampling_results = []

    sentence_target_pairs = extract_entities_and_pos(sentence)
    if not sentence_target_pairs:
        return replacements, sampling_results

    # Keep only targets whose position actually points at that spaCy token.
    spacy_tokens = [token.text for token in nlp(sentence)]
    valid_pairs = [
        (sent, target, position)
        for sent, target, position in sentence_target_pairs
        if 0 <= position < len(spacy_tokens) and spacy_tokens[position] == target
    ]
    if not valid_pairs:
        return replacements, sampling_results

    batch_results = whole_context_mlm_inference(full_text, valid_pairs, tokenizer, lm_model, Top_K, batch_size, max_context_tokens, max_length, similarity_context_mode)

    # Threshold filtering: a target needs at least two surviving candidates.
    filtered_results = {}
    for key, candidates in batch_results.items():
        if not candidates:
            continue
        threshold_candidates = [(word, score) for word, score in candidates if score >= threshold]
        if len(threshold_candidates) >= 2:
            filtered_results[key] = threshold_candidates

    tournament_results = parallel_tournament_sampling(filtered_results, secret_key, m, c, h, alpha)

    # Iterate in filtered_results order rather than tournament_results order.
    # This keeps the output deterministic even if the sampler returns keys in
    # thread-completion order.
    for key, alternatives in filtered_results.items():
        randomized_word = tournament_results.get(key)
        if not randomized_word:
            continue
        _sent, target, position = key
        sampling_results.append({
            "word": target,
            "alternatives": [word for word, _score in alternatives],
            # Similarity-annotated list; the plain 'alternatives' list is kept for compatibility.
            "alternatives_with_similarity": [
                {"word": word, "similarity": float(score)} for word, score in alternatives
            ],
            "randomized_word": randomized_word,
        })
        replacements.append((position, target, randomized_word))
    return replacements, sampling_results
# Legacy entry points kept for backward compatibility.
def look_up(sentence, target, index, tokenizer, lm_model, Top_K=20, threshold=0.75):
    """
    Return MLM substitution candidates for a single target (legacy API).

    Wraps ``batch_mlm_inference`` for one ``(sentence, target, index)`` triple.
    ``threshold`` is accepted for backward compatibility but is not applied.

    Returns:
        The entry ``batch_mlm_inference`` produced for the triple, or ``None``
        if it has no entry for it.
    """
    query = (sentence, target, index)
    candidates_by_key = batch_mlm_inference([query], tokenizer, lm_model, Top_K)
    return candidates_by_key.get(query)
def batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K=20):
    """
    Generate MLM candidates without document context (legacy API).

    Delegates to ``whole_context_mlm_inference`` with an empty string as the
    full document text. Every parameter after ``Top_K`` is left at that
    function's defaults.
    """
    no_document_context = ""
    return whole_context_mlm_inference(
        no_document_context,
        sentence_target_pairs,
        tokenizer,
        lm_model,
        Top_K,
    )
def batch_look_up(sentence_target_pairs, tokenizer, lm_model, Top_K=20, threshold=0.75, max_workers=4):
    """
    Look up MLM candidates for many targets at once (legacy API).

    Thin alias for ``batch_mlm_inference``. ``threshold`` and ``max_workers``
    are accepted for backward compatibility but are ignored.
    """
    return batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K=Top_K)