import os import re import time import math import torch import string import spacy import pandas as pd import numpy as np import nltk import sys import subprocess from nltk.tokenize import word_tokenize from nltk.stem.wordnet import WordNetLemmatizer from nltk.corpus import wordnet as wn import json from filelock import FileLock from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed from functools import lru_cache from typing import List, Tuple, Dict, Any import multiprocessing as mp # Ensure the HF_HOME environment variable points to your desired cache location # Token removed for security cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm' # Handle potential import conflicts with sentence_transformers try: # Try to import bert_score directly to avoid sentence_transformers conflicts from bert_score import score as bert_score SIMILARITY_AVAILABLE = True def calc_scores_bert(original_sentence, substitute_sentences): """BERTScore function using direct bert_score import.""" try: # Safety check: truncate inputs if they're too long max_chars = 2000 # Roughly 500 tokens if len(original_sentence) > max_chars: original_sentence = original_sentence[:max_chars] truncated_substitutes = [] for sub in substitute_sentences: if len(sub) > max_chars: sub = sub[:max_chars] truncated_substitutes.append(sub) references = [original_sentence] * len(truncated_substitutes) P, R, F1 = bert_score( cands=truncated_substitutes, refs=references, model_type="bert-base-uncased", verbose=False ) return F1.tolist() except Exception as e: return [0.5] * len(substitute_sentences) def get_similarity_scores(original_sentence, substitute_sentences, method='bert'): """Similarity function using direct bert_score import.""" if method == 'bert': return calc_scores_bert(original_sentence, substitute_sentences) else: return [0.5] * len(substitute_sentences) except ImportError as e: print(f"Warning: bert_score import failed: {e}") print("Falling back to neutral similarity scores...") SIMILARITY_AVAILABLE = False def calc_scores_bert(original_sentence, substitute_sentences): """Fallback BERTScore function with neutral scores.""" return [0.5] * len(substitute_sentences) def get_similarity_scores(original_sentence, substitute_sentences, method='bert'): """Fallback similarity function with neutral scores.""" return [0.5] * len(substitute_sentences) # Setup NLTK data def setup_nltk_data(): """Setup NLTK data with error handling.""" try: nltk.download('punkt_tab', quiet=True) except: pass try: nltk.download('averaged_perceptron_tagger_eng', quiet=True) except: pass try: nltk.download('wordnet', quiet=True) except: pass try: nltk.download('omw-1.4', quiet=True) except: pass setup_nltk_data() lemmatizer = WordNetLemmatizer() # Load spaCy model - download if not available try: nlp = spacy.load("en_core_web_sm") except OSError: print("Downloading spaCy model...") subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"]) nlp = spacy.load("en_core_web_sm") # Define the detailed whitelist of POS tags (excluding adverbs) DETAILED_POS_WHITELIST = { 'NN', # Noun, singular or mass (e.g., dog, car) 'NNS', # Noun, plural (e.g., dogs, cars) 'VB', # Verb, base form (e.g., run, eat) 'VBD', # Verb, past tense (e.g., ran, ate) 'VBG', # Verb, gerund or present participle (e.g., running, eating) 'VBN', # Verb, past participle (e.g., run, eaten) 'VBP', # Verb, non-3rd person singular present (e.g., run, eat) 'VBZ', # Verb, 3rd person singular present (e.g., runs, eats) 'JJ', # Adjective (e.g., big, blue) 'JJR', # Adjective, comparative (e.g., bigger, bluer) 'JJS', # Adjective, superlative (e.g., biggest, bluest) 'RB', # Adverb (e.g., very, silently) 'RBR', # Adverb, comparative (e.g., better) 'RBS' # Adverb, superlative (e.g., best) } # Global caches for better performance _pos_cache = {} _antonym_cache = {} _word_validity_cache = {} def extract_entities_and_pos(text): """ Detect eligible tokens for replacement while skipping: - Named entities (e.g., names, locations, organizations). - Compound words (e.g., "Opteron-based"). - Phrasal verbs (e.g., "make up", "focus on"). - Punctuation and non-POS-whitelisted tokens. """ doc = nlp(text) sentence_target_pairs = [] # List to hold (sentence, target word, token index) for sent in doc.sents: for token in sent: # Skip named entities using token.ent_type_ (more reliable than a text match) if token.ent_type_: continue # Skip standalone punctuation if token.is_punct: continue # Skip compound words (e.g., "Opteron-based") if "-" in token.text or token.dep_ in {"compound", "amod"}: continue # Skip phrasal verbs (e.g., "make up", "focus on") if token.pos_ == "VERB" and any(child.dep_ == "prt" for child in token.children): continue # Include regular tokens matching the POS whitelist if token.tag_ in DETAILED_POS_WHITELIST: sentence_target_pairs.append((sent.text, token.text, token.i)) return sentence_target_pairs def preprocess_text(text): """ Preprocesses the text to handle abbreviations, titles, and edge cases where a period or other punctuation does not signify a sentence end. Ensures figures, acronyms, and short names are left untouched. """ # Protect common abbreviations like "U.S." and "Corp." text = re.sub(r'\b(U\.S|U\.K|Corp|Inc|Ltd)\.', r'\1', text) # Protect floating-point numbers or ranges like "3.57" or "1.48–2.10" text = re.sub(r'(\b\d+)\.(\d+)', r'\1\2', text) # Avoid modifying standalone single-letter initials in names (e.g., "J. Smith") text = re.sub(r'\b([A-Z])\.(?=\s[A-Z])', r'\1', text) # Protect acronym-like patterns with dots, such as "F.B.I." text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', ''), text) return text def split_sentences(text): """ Splits text into sentences while preserving original newlines exactly. - Protects abbreviations, acronyms, and floating-point numbers. - Only adds newlines where necessary without duplicating them. """ # Step 1: Protect abbreviations, floating numbers, acronyms text = re.sub(r'\b(U\.S\.|U\.K\.|Inc\.|Ltd\.|Corp\.|e\.g\.|i\.e\.|etc\.)\b', r'\1', text) text = re.sub(r'(\b\d+)\.(\d+)', r'\1\2', text) text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', ''), text) # Step 2: Identify sentence boundaries without duplicating newlines sentences = [] for line in text.splitlines(keepends=True): # Retain original newlines # Split only if punctuation marks end a sentence split_line = re.split(r'(?<=[.!?])\s+', line.strip()) sentences.extend([segment + "\n" if line.endswith("\n") else segment for segment in split_line]) # Step 3: Restore protected patterns return [sent.replace('', '.').replace('', '.') for sent in sentences] @lru_cache(maxsize=10000) def is_valid_word(word): """Check if a word is valid using WordNet (cached).""" return bool(wn.synsets(word)) @lru_cache(maxsize=5000) def get_word_pos_tags(word): """Get POS tags for a word using both NLTK and spaCy (cached).""" nltk_pos = nltk.pos_tag([word])[0][1] spacy_pos = nlp(word)[0].pos_ return nltk_pos, spacy_pos @lru_cache(maxsize=5000) def get_word_lemma(word): """Get lemmatized form of a word (cached).""" return lemmatizer.lemmatize(word) @lru_cache(maxsize=2000) def get_word_antonyms(word): """Get antonyms for a word (cached). Includes all lemmas from all synsets.""" target_synsets = wn.synsets(word) antonyms = set() # Get antonyms from all synsets and all lemmas for syn in target_synsets: for lem in syn.lemmas(): for ant in lem.antonyms(): # Add the antonym word (first part before the dot) antonyms.add(ant.name().split('.')[0]) # Also add other lemmas of the antonym for completeness for alt_lem in wn.synsets(ant.name().split('.')[0]): for alt_ant_lem in alt_lem.lemmas(): antonyms.add(alt_ant_lem.name().split('.')[0]) return antonyms def _are_semantically_compatible(target, candidate): """ Check if target and candidate are semantically compatible for replacement. Returns False if they are specific nouns in the same category (e.g., different crops, fruits, animals). """ try: # Direct check: if target and candidate are both specific terms for crops, animals, etc. # check if they're NOT near-synonyms # Agricultural/crop terms that shouldn't be swapped agricultural_terms = ['soybean', 'corn', 'maize', 'wheat', 'rice', 'barley', 'oats', 'sorghum', 'millet', 'grain', 'cereal', 'pulse', 'bean', 'legume'] # If both are agricultural terms and different, block if (target.lower() in agricultural_terms and candidate.lower() in agricultural_terms and target.lower() != candidate.lower()): return False target_synsets = wn.synsets(target) cand_synsets = wn.synsets(candidate) if not target_synsets or not cand_synsets: return True # If no synsets, allow through # Check if they're near-synonyms (very similar) - if so, allow # We can use path similarity to check if they're similar enough max_similarity = 0.0 for t_syn in target_synsets: for c_syn in cand_synsets: try: similarity = t_syn.path_similarity(c_syn) or 0.0 max_similarity = max(max_similarity, similarity) except: pass # If they have high path similarity (>0.5), they're similar enough to allow if max_similarity > 0.5: return True # Otherwise, check if they share common direct hypernyms target_hypernyms = set() for syn in target_synsets: # Get immediate hypernyms (parent concepts) for hypernym in syn.hypernyms(): target_hypernyms.add(hypernym) cand_hypernyms = set() for syn in cand_synsets: for hypernym in syn.hypernyms(): cand_hypernyms.add(hypernym) # If they share hypernyms, check if they're both specific instances (not general terms) common_hypernyms = target_hypernyms & cand_hypernyms if common_hypernyms: # Check if both words are specific instances of the same category # If so, they shouldn't be replaced with each other # We identify this by checking if their hypernym has many siblings for hypernym in common_hypernyms: siblings = hypernym.hyponyms() # If there are many specific instances (e.g., many crops, many fruits) # it's likely a category with specific instances that shouldn't be interchanged if len(siblings) > 3: # Check if hypernym name suggests a specific category hypernym_name = hypernym.name().split('.')[0] category_keywords = [ 'crop', 'grain', 'fruit', 'animal', 'bird', 'fish', 'company', 'country', 'city', 'brand', 'product', 'food', 'vehicle' ] # If the hypernym contains category keywords, these are likely # specific instances that shouldn't be swapped if any(keyword in hypernym_name for keyword in category_keywords): return False return True except Exception as e: # On any error, allow the candidate through (conservative approach) return True def create_context_windows(full_text, target_sentence, target_word, tokenizer, max_tokens=400): """ Create context windows around the target sentence for better MLM generation. Intelligently handles tokenizer length limits by preserving the most relevant context. Args: full_text: The complete document text target_sentence: The sentence containing the target word target_word: The word to be replaced tokenizer: The tokenizer to check length limits max_tokens: Maximum tokens to use for context (leave room for instruction + mask) Returns: List of context windows with different levels of context """ # Split full text into sentences sentences = split_sentences(full_text) # Find the target sentence index target_sentence_idx = None for i, sent in enumerate(sentences): if target_sentence.strip() in sent.strip(): target_sentence_idx = i break if target_sentence_idx is None: return [target_sentence] # Fallback to original sentence # Create context windows with sentence-prioritized approach context_windows = [] # Window 1: Just the target sentence (always include) context_windows.append(target_sentence) # Window 2: Target sentence + 1 sentence before and after (if fits) start_idx = max(0, target_sentence_idx - 1) end_idx = min(len(sentences), target_sentence_idx + 2) context_window = " ".join(sentences[start_idx:end_idx]) try: encoded_len = len(tokenizer.encode(context_window)) if encoded_len <= max_tokens: context_windows.append(context_window) except Exception as e: pass # Window 3: Target sentence + 2 sentences before and after (if fits) start_idx = max(0, target_sentence_idx - 2) end_idx = min(len(sentences), target_sentence_idx + 3) context_window = " ".join(sentences[start_idx:end_idx]) try: encoded_len = len(tokenizer.encode(context_window)) if encoded_len <= max_tokens: context_windows.append(context_window) except Exception as e: pass # Window 4: Target sentence + 3 sentences before and after (if fits) start_idx = max(0, target_sentence_idx - 3) end_idx = min(len(sentences), target_sentence_idx + 4) context_window = " ".join(sentences[start_idx:end_idx]) try: encoded_len = len(tokenizer.encode(context_window)) if encoded_len <= max_tokens: context_windows.append(context_window) except Exception as e: pass # Window 5: Intelligent context with sentence prioritization + word expansion intelligent_context = _create_intelligent_context( full_text, target_word, target_sentence_idx, tokenizer, max_tokens ) context_windows.append(intelligent_context) return context_windows def _create_intelligent_context(full_text, target_word, target_sentence_idx, tokenizer, max_tokens): """ Create intelligent context that prioritizes sentence boundaries while respecting token limits. Strategy: Target sentence → Nearby sentences → Word-level expansion """ sentences = split_sentences(full_text) # Strategy 1: Always start with the target sentence target_sentence = sentences[target_sentence_idx] try: target_sentence_tokens = len(tokenizer.encode(target_sentence)) except Exception as e: target_sentence_tokens = 1000 # Fallback to assume it's too long if target_sentence_tokens > max_tokens: # If even target sentence is too long, truncate intelligently return _truncate_sentence_intelligently(target_sentence, target_word, tokenizer, max_tokens) # Strategy 2: Expand sentence-by-sentence around target sentence best_context = target_sentence best_token_count = target_sentence_tokens # Try adding sentences before and after the target sentence for sentence_radius in range(1, min(len(sentences), 20)): # Max 20 sentences radius start_idx = max(0, target_sentence_idx - sentence_radius) end_idx = min(len(sentences), target_sentence_idx + sentence_radius + 1) # Create context with complete sentences context_sentences = sentences[start_idx:end_idx] context_window = " ".join(context_sentences) try: token_count = len(tokenizer.encode(context_window)) except Exception as e: token_count = 1000 # Fallback to assume it's too long if token_count <= max_tokens: # This sentence expansion fits, keep it as our best option best_context = context_window best_token_count = token_count else: # This expansion is too big, stop here break # Strategy 3: If we have room left, try word-level expansion within the best sentence context remaining_tokens = max_tokens - best_token_count if remaining_tokens > 50: # If we have significant room left enhanced_context = _enhance_with_word_expansion( full_text, target_word, best_context, tokenizer, remaining_tokens ) if enhanced_context: return enhanced_context return best_context def _enhance_with_word_expansion(full_text, target_word, current_context, tokenizer, remaining_tokens): """ Enhance the current sentence-based context with word-level expansion if there's room. """ words = full_text.split() target_word_idx = None # Find target word position in full text for i, word in enumerate(words): if word.lower() == target_word.lower(): target_word_idx = i break if target_word_idx is None: return current_context # Try to expand word-by-word around the target word try: current_tokens = len(tokenizer.encode(current_context)) except Exception as e: print(f"WARNING: Error encoding current context: {e}") current_tokens = 1000 # Fallback to assume it's too long for expansion_size in range(1, min(len(words), 100)): # Max 100 words expansion start_word = max(0, target_word_idx - expansion_size) end_word = min(len(words), target_word_idx + expansion_size + 1) expanded_context = " ".join(words[start_word:end_word]) try: expanded_tokens = len(tokenizer.encode(expanded_context)) except Exception as e: expanded_tokens = 1000 # Fallback to assume it's too long if expanded_tokens <= current_tokens + remaining_tokens: # This expansion fits within our remaining token budget return expanded_context else: # This expansion is too big, stop here break return current_context def _truncate_sentence_intelligently(sentence, target_word, tokenizer, max_tokens): """ Intelligently truncate a sentence while preserving context around the target word. """ words = sentence.split() target_word_idx = None # Find target word position for i, word in enumerate(words): if word.lower() == target_word.lower(): target_word_idx = i break if target_word_idx is None: # If target word not found, truncate from the end truncated = " ".join(words) try: while len(tokenizer.encode(truncated)) > max_tokens and len(words) > 1: words = words[:-1] truncated = " ".join(words) except Exception as e: # Fallback: return first few words truncated = " ".join(words[:10]) if len(words) >= 10 else " ".join(words) return truncated # Truncate symmetrically around target word context_words = 10 # Start with 10 words before/after while context_words > 0: start_word = max(0, target_word_idx - context_words) end_word = min(len(words), target_word_idx + context_words + 1) truncated_sentence = " ".join(words[start_word:end_word]) try: if len(tokenizer.encode(truncated_sentence)) <= max_tokens: return truncated_sentence except Exception as e: # Continue to next iteration pass context_words -= 1 # Fallback: just the target word with minimal context return f"... {target_word} ..." def _intelligent_token_slicing(input_text, tokenizer, max_length=512, mask_token_id=None): """ Intelligently slice input text to fit within max_length tokens while preserving the mask token. Strategy: Preserve mask token and surrounding context, remove excess tokens from less important areas. Args: input_text: The full input text to be tokenized tokenizer: The tokenizer to use max_length: Maximum allowed sequence length (default 512) mask_token_id: The mask token ID to preserve Returns: Tuple of (sliced_input_ids, mask_position_in_sliced) """ # First, tokenize the full input input_ids = tokenizer.encode(input_text, add_special_tokens=True) # If already within limits, return as is if len(input_ids) <= max_length: mask_pos = input_ids.index(mask_token_id) if mask_token_id in input_ids else None return input_ids, mask_pos # Find mask token position mask_positions = [i for i, token_id in enumerate(input_ids) if token_id == mask_token_id] if not mask_positions: # No mask token found, truncate from the end return input_ids[:max_length], None mask_pos = mask_positions[0] # Use first mask token # Calculate how many tokens we need to remove excess_tokens = len(input_ids) - max_length # Strategy: Remove tokens from both ends while preserving mask context # Reserve some context around the mask token mask_context_size = min(50, max_length // 4) # Reserve 25% of max_length or 50 tokens, whichever is smaller # Calculate available space for context around mask available_before = min(mask_pos, mask_context_size) available_after = min(len(input_ids) - mask_pos - 1, mask_context_size) # Calculate how much to remove from each end tokens_to_remove_before = max(0, mask_pos - available_before) tokens_to_remove_after = max(0, (len(input_ids) - mask_pos - 1) - available_after) # Initialize removal variables remove_before = 0 remove_after = 0 # Distribute excess tokens proportionally if excess_tokens > 0: if tokens_to_remove_before + tokens_to_remove_after >= excess_tokens: # We can remove enough from the ends if tokens_to_remove_before >= excess_tokens // 2: remove_before = excess_tokens // 2 remove_after = excess_tokens - remove_before else: remove_before = tokens_to_remove_before remove_after = min(tokens_to_remove_after, excess_tokens - remove_before) else: # Need to remove more aggressively remove_before = tokens_to_remove_before remove_after = tokens_to_remove_after remaining_excess = excess_tokens - remove_before - remove_after # Remove remaining excess from the end if remaining_excess > 0: remove_after += remaining_excess # Calculate final indices start_idx = remove_before end_idx = len(input_ids) - remove_after # Ensure we don't exceed max_length if end_idx - start_idx > max_length: # Center around mask token half_length = max_length // 2 start_idx = max(0, mask_pos - half_length) end_idx = min(len(input_ids), start_idx + max_length) # Slice the input_ids sliced_input_ids = input_ids[start_idx:end_idx] # Debug information if len(sliced_input_ids) > max_length: # Force truncation as final fallback sliced_input_ids = sliced_input_ids[:max_length] # Adjust mask position for the sliced sequence adjusted_mask_pos = mask_pos - start_idx return sliced_input_ids, adjusted_mask_pos def _create_word_level_context(full_text, target_word, tokenizer, max_tokens): """ Create context by expanding word-by-word around the target word until reaching token limit. This maximizes context while respecting tokenizer limits. """ words = full_text.split() target_word_idx = None # Find target word position in full text for i, word in enumerate(words): if word.lower() == target_word.lower(): target_word_idx = i break if target_word_idx is None: # Fallback: expand from beginning until token limit return _expand_from_start(words, tokenizer, max_tokens) # Word-by-word expansion around target word return _expand_around_target(words, target_word_idx, tokenizer, max_tokens) def _expand_around_target(words, target_idx, tokenizer, max_tokens): """ Expand word-by-word around target word until reaching token limit. """ best_context = "" best_token_count = 0 # Try different expansion sizes for expansion_size in range(1, min(len(words), 200)): # Max 200 words expansion start_word = max(0, target_idx - expansion_size) end_word = min(len(words), target_idx + expansion_size + 1) context_window = " ".join(words[start_word:end_word]) try: token_count = len(tokenizer.encode(context_window)) except Exception as e: token_count = 1000 # Fallback to assume it's too long if token_count <= max_tokens: # This expansion fits, keep it as our best option best_context = context_window best_token_count = token_count else: # This expansion is too big, stop here break # If we found a good context, return it if best_context: return best_context # Fallback: minimal context around target word start_word = max(0, target_idx - 5) end_word = min(len(words), target_idx + 6) return " ".join(words[start_word:end_word]) def _expand_from_start(words, tokenizer, max_tokens): """ Expand from the start of the text until reaching token limit. """ for end_idx in range(len(words), 0, -1): context_window = " ".join(words[:end_idx]) try: if len(tokenizer.encode(context_window)) <= max_tokens: return context_window except Exception as e: # Continue to next iteration pass # Fallback: first few words return " ".join(words[:10]) if len(words) >= 10 else " ".join(words) def whole_context_mlm_inference(full_text, sentence_target_pairs, tokenizer, lm_model, Top_K=20, batch_size=32, max_context_tokens=400, max_length=512, similarity_context_mode='whole'): """ Enhanced MLM inference using whole document context for better candidate generation. """ results = {} # Group targets by sentence for batch processing sentence_groups = {} for sent, target, index in sentence_target_pairs: if sent not in sentence_groups: sentence_groups[sent] = [] sentence_groups[sent].append((target, index)) for sentence, targets in sentence_groups.items(): # Process targets in batches for i in range(0, len(targets), batch_size): batch_targets = targets[i:i+batch_size] batch_results = _process_whole_context_mlm_batch( full_text, sentence, batch_targets, tokenizer, lm_model, Top_K, max_context_tokens, max_length, similarity_context_mode ) results.update(batch_results) return results def _process_whole_context_mlm_batch(full_text, sentence, targets, tokenizer, lm_model, Top_K, max_context_tokens=400, max_length=512, similarity_context_mode='whole'): """ Process a batch of targets using whole document context for MLM. """ results = {} # Tokenize sentence once doc = nlp(sentence) tokens = [token.text for token in doc] # Create multiple masked versions for batch processing masked_inputs = [] mask_positions = [] contexts_for_targets = [] for target, index in targets: if index < len(tokens): # Create context windows with tokenizer length awareness context_windows = create_context_windows(full_text, sentence, target, tokenizer, max_tokens=max_context_tokens) # Use the most comprehensive context window that fits within token limits full_context = context_windows[-1] # Built around the target sentence # Select context for similarity according to mode context = sentence if similarity_context_mode == 'sentence' else full_context # Create masked version of the FULL context (not just the sentence) masked_full_context = context.replace(target, tokenizer.mask_token, 1) instruction = "Given the full document context, replace the masked word with a word that fits grammatically, preserves the original meaning, and ensures natural flow in the document:" input_text = f"{instruction} {context} {tokenizer.sep_token} {masked_full_context}" # AGGRESSIVE FIX: Truncate input text BEFORE tokenization to prevent errors # Estimate token count (roughly 1 token per 4 characters for English) estimated_tokens = len(input_text) // 4 if estimated_tokens > 500: # Leave some buffer # Truncate to roughly 2000 characters (500 tokens) input_text = input_text[:2000] # SIMPLE FIX: Truncate input text if it's too long try: temp_tokens = tokenizer.encode(input_text, add_special_tokens=True) if len(temp_tokens) > 512: # Truncate the input text by removing words from the end words = input_text.split() while len(tokenizer.encode(" ".join(words), add_special_tokens=True)) > 512 and len(words) > 10: words = words[:-1] input_text = " ".join(words) except Exception as e: # Emergency truncation - just take first 200 words words = input_text.split() input_text = " ".join(words[:200]) masked_inputs.append(input_text) # Store the original sentence-level index for reference, but mask position will be calculated during tokenization mask_positions.append(index) contexts_for_targets.append(context) if not masked_inputs: return results # Batch tokenize MAX_LENGTH = max_length # Use parameter for A100 optimization batch_inputs = [] batch_mask_positions = [] batch_contexts = [] for input_text, mask_pos in zip(masked_inputs, mask_positions): # Use intelligent token slicing to ensure we stay within MAX_LENGTH try: input_ids, adjusted_mask_pos = _intelligent_token_slicing( input_text, tokenizer, max_length=MAX_LENGTH, mask_token_id=tokenizer.mask_token_id ) if adjusted_mask_pos is not None: batch_inputs.append(input_ids) batch_mask_positions.append(adjusted_mask_pos) else: # Mask token not found in sliced sequence, skip this input continue except Exception as e: # Fallback: simple truncation try: input_ids = tokenizer.encode(input_text, add_special_tokens=True) if len(input_ids) > MAX_LENGTH: input_ids = input_ids[:MAX_LENGTH] masked_position = input_ids.index(tokenizer.mask_token_id) batch_inputs.append(input_ids) batch_mask_positions.append(masked_position) except ValueError: # Mask token not found, skip this input continue if not batch_inputs: return results # Pad sequences to same length, but ensure we don't exceed MAX_LENGTH max_len = min(max(len(ids) for ids in batch_inputs), MAX_LENGTH) # Additional safety check: truncate any sequences that are still too long truncated_batch_inputs = [] for input_ids in batch_inputs: if len(input_ids) > MAX_LENGTH: input_ids = input_ids[:MAX_LENGTH] truncated_batch_inputs.append(input_ids) padded_inputs = [] attention_masks = [] for input_ids in truncated_batch_inputs: attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids)) padded_ids = input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids)) padded_inputs.append(padded_ids) attention_masks.append(attention_mask) # Final safety check: ensure no sequence exceeds MAX_LENGTH for i, padded_ids in enumerate(padded_inputs): if len(padded_ids) > MAX_LENGTH: padded_inputs[i] = padded_ids[:MAX_LENGTH] attention_masks[i] = attention_masks[i][:MAX_LENGTH] # Batch inference - optimized for A100 with mixed precision with torch.no_grad(): input_tensor = torch.tensor(padded_inputs, dtype=torch.long) attention_tensor = torch.tensor(attention_masks, dtype=torch.long) if torch.cuda.is_available(): input_tensor = input_tensor.cuda() attention_tensor = attention_tensor.cuda() # Use mixed precision for A100 optimization with torch.amp.autocast('cuda'): outputs = lm_model(input_tensor, attention_mask=attention_tensor) logits = outputs.logits else: outputs = lm_model(input_tensor, attention_mask=attention_tensor) logits = outputs.logits # Process results - collect filtered candidates first batch_filtered_results = {} for i, (target, index) in enumerate(targets): if i < len(batch_mask_positions): mask_pos = batch_mask_positions[i] mask_logits = logits[i, mask_pos].squeeze() # Get top predictions top_tokens = torch.topk(mask_logits, k=Top_K, dim=0)[1] scores = torch.softmax(mask_logits, dim=0)[top_tokens].tolist() words = [tokenizer.decode(token.item()).strip() for token in top_tokens] # Filter candidates (without similarity scoring) filtered_candidates = _filter_candidates_batch(target, words, scores, tokens, index) if filtered_candidates: # Attach the exact context window used for this target batch_filtered_results[(sentence, target, index)] = { 'filtered_words': filtered_candidates, 'context': contexts_for_targets[i] } # Batch similarity scoring for all candidates if batch_filtered_results: similarity_results = _batch_similarity_scoring(batch_filtered_results, tokenizer) results.update(similarity_results) return results def _filter_candidates_batch(target, words, scores, tokens, target_index): """ Optimized batch filtering of candidates (no similarity scoring here - moved to batch level). """ # Basic filtering filtered_words = [] filtered_scores = [] seen_words = set() for word, score in zip(words, scores): word_lower = word.lower() if word_lower in seen_words or word_lower == target.lower(): continue seen_words.add(word_lower) if not is_valid_word(word): continue # Quick POS check target_nltk_pos, target_spacy_pos = get_word_pos_tags(target) cand_nltk_pos, cand_spacy_pos = get_word_pos_tags(word) if target_nltk_pos != cand_nltk_pos or target_spacy_pos != cand_spacy_pos: continue # Check antonyms (bidirectional and case-insensitive) antonyms = get_word_antonyms(target) if word.lower() in [ant.lower() for ant in antonyms]: continue # Also check if the candidate has the target as an antonym (reverse check) candidate_antonyms = get_word_antonyms(word) if target.lower() in [ant.lower() for ant in candidate_antonyms]: continue # Hardcoded common antonym pairs (for words not in WordNet or as additional safeguard) common_antonyms = { 'big': ['small', 'tiny', 'little'], 'small': ['big', 'large', 'huge'], 'large': ['small', 'tiny', 'little'], 'good': ['bad', 'evil', 'wrong'], 'bad': ['good', 'great', 'excellent'], 'high': ['low'], 'low': ['high'], 'new': ['old'], 'old': ['new'], 'fast': ['slow'], 'slow': ['fast'], 'rich': ['poor'], 'poor': ['rich'], 'hot': ['cold'], 'cold': ['hot'], 'happy': ['sad', 'unhappy'], 'sad': ['happy', 'joyful'], 'true': ['false', 'untrue'], 'false': ['true'], 'real': ['fake', 'unreal'], 'fake': ['real'], 'up': ['down'], 'down': ['up'], 'yes': ['no'], 'no': ['yes'], 'alive': ['dead'], 'dead': ['alive'], 'safe': ['unsafe', 'dangerous'], 'dangerous': ['safe'], 'clean': ['dirty'], 'dirty': ['clean'], 'full': ['empty'], 'empty': ['full'], 'open': ['closed', 'shut'], 'closed': ['open'], 'begin': ['end', 'finish'], 'end': ['begin', 'start'], 'start': ['end', 'finish'], 'finish': ['start', 'begin'], 'first': ['last'], 'last': ['first'] } # Check if word is a known antonym of target (case-insensitive) target_lower = target.lower() if target_lower in common_antonyms and word.lower() in common_antonyms[target_lower]: continue # Check if word and target are in the same specific noun category (e.g., crops, animals, companies) # If they are different specific terms in the same category, exclude the candidate if not _are_semantically_compatible(target, word): continue filtered_words.append(word) filtered_scores.append(score) if len(filtered_words) < 2: return None # Return filtered words without similarity scoring (done at batch level) return filtered_words def _batch_similarity_scoring(batch_results, tokenizer): """ Optimized batched similarity scoring across multiple sentences for full context. Processes all candidates from multiple sentences together for better efficiency. """ # Collect all similarity scoring tasks similarity_tasks = [] sentence_contexts = {} for (sentence, target, index), value in batch_results.items(): if value is None: continue # Support both legacy list and new dict with context if isinstance(value, dict): filtered_words = value.get('filtered_words') context = value.get('context', sentence) else: filtered_words = value context = sentence # Tokenize the sentence once tokens = tokenizer.tokenize(sentence) if index >= len(tokens): continue # Store sentence context for later use sentence_contexts[(sentence, target, index)] = { 'tokens': tokens, 'target_index': index, 'filtered_words': filtered_words } # Create candidate sentences for this target for word in filtered_words: candidate_tokens = tokens.copy() candidate_tokens[index] = word candidate_sentence = tokenizer.convert_tokens_to_string(candidate_tokens) # Build full-context candidate by replacing the sentence inside the chosen context once candidate_full_context = context.replace(sentence, candidate_sentence, 1) similarity_tasks.append({ 'original_context': context, 'candidate_full_context': candidate_full_context, 'target_word': word, 'context_key': (sentence, target, index) }) if not similarity_tasks: return {} # Batch process all similarity scoring try: # Group by original full context for efficient BERTScore computation context_groups = {} for task in similarity_tasks: orig_ctx = task['original_context'] if orig_ctx not in context_groups: context_groups[orig_ctx] = [] context_groups[orig_ctx].append(task) # Process each context group final_results = {} for orig_context, tasks in context_groups.items(): # Extract candidate full-contexts candidate_contexts = [task['candidate_full_context'] for task in tasks] # Batch BERTScore computation against the same full context try: similarity_scores = calc_scores_bert(orig_context, candidate_contexts) except Exception as e: # Fallback to neutral scores similarity_scores = [0.5] * len(candidate_contexts) if similarity_scores and not all(score == 0.5 for score in similarity_scores): # Group results by context key for task, score in zip(tasks, similarity_scores): context_key = task['context_key'] if context_key not in final_results: final_results[context_key] = [] final_results[context_key].append((task['target_word'], score)) # Sort results by similarity score for context_key in final_results: final_results[context_key].sort(key=lambda x: x[1], reverse=True) return final_results except Exception as e: return {} def parallel_tournament_sampling(target_results, secret_key, m, c, h, alpha): """ Parallel tournament sampling for multiple targets. """ results = {} if not target_results: return results def process_single_tournament(item): (sentence, target, index), candidates = item if not candidates: return (sentence, target, index), None alternatives = [alt[0] for alt in candidates] similarity = [alt[1] for alt in candidates] if not alternatives or not similarity: return (sentence, target, index), None # Get context context_tokens = word_tokenize(sentence) left_context = context_tokens[max(0, index - h):index] # Tournament selection from SynthID_randomization import tournament_select_word randomized_word = tournament_select_word( target, alternatives, similarity, context=left_context, key=secret_key, m=m, c=c, alpha=alpha ) return (sentence, target, index), randomized_word # Process in parallel max_workers = max(1, min(8, len(target_results))) with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_item = {executor.submit(process_single_tournament, item): item for item in target_results.items()} for future in as_completed(future_to_item): key, result = future.result() results[key] = result return results def whole_context_process_sentence(full_text, sentence, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, output_name, batch_size=32, max_length=512, max_context_tokens=400, similarity_context_mode='whole'): """ Enhanced sentence processing using whole document context for better candidate generation. """ replacements = [] sampling_results = [] doc = nlp(sentence) sentence_target_pairs = extract_entities_and_pos(sentence) if not sentence_target_pairs: return replacements, sampling_results # Filter valid target pairs valid_pairs = [] spacy_tokens = [token.text for token in doc] for sent, target, position in sentence_target_pairs: if position < len(spacy_tokens) and spacy_tokens[position] == target: valid_pairs.append((sent, target, position)) if not valid_pairs: return replacements, sampling_results # Enhanced MLM inference with whole document context batch_results = whole_context_mlm_inference(full_text, valid_pairs, tokenizer, lm_model, Top_K, batch_size, max_context_tokens, max_length, similarity_context_mode) # Filter by threshold (matching original logic) filtered_results = {} for key, candidates in batch_results.items(): if candidates: # Apply threshold filtering (matching original logic) threshold_candidates = [(word, score) for word, score in candidates if score >= threshold] if len(threshold_candidates) >= 2: filtered_results[key] = threshold_candidates # Parallel tournament sampling tournament_results = parallel_tournament_sampling(filtered_results, secret_key, m, c, h, alpha) # Collect replacements and sampling results for (sent, target, position), randomized_word in tournament_results.items(): if randomized_word: # Get the alternatives for this target from the filtered results alternatives = filtered_results.get((sent, target, position), []) alternatives_list = [alt[0] for alt in alternatives] # Include similarity scores for each alternative (preserve old 'alternatives' list for compatibility) alternatives_with_similarity = [ {"word": alt[0], "similarity": float(alt[1])} for alt in alternatives ] # Track sampling results sampling_results.append({ "word": target, "alternatives": alternatives_list, "alternatives_with_similarity": alternatives_with_similarity, "randomized_word": randomized_word }) replacements.append((position, target, randomized_word)) return replacements, sampling_results # Legacy function for compatibility def look_up(sentence, target, index, tokenizer, lm_model, Top_K=20, threshold=0.75): """ Legacy single-target lookup function for compatibility. """ results = batch_mlm_inference([(sentence, target, index)], tokenizer, lm_model, Top_K) return results.get((sentence, target, index), None) def batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K=20): """ Legacy batch MLM inference function for compatibility. """ return whole_context_mlm_inference("", sentence_target_pairs, tokenizer, lm_model, Top_K) def batch_look_up(sentence_target_pairs, tokenizer, lm_model, Top_K=20, threshold=0.75, max_workers=4): """ Optimized batch lookup using the new batch MLM inference. """ return batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K)