Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import time | |
| import math | |
| import torch | |
| import string | |
| import spacy | |
| import pandas as pd | |
| import numpy as np | |
| import nltk | |
| import sys | |
| import subprocess | |
| from nltk.tokenize import word_tokenize | |
| from nltk.stem.wordnet import WordNetLemmatizer | |
| from nltk.corpus import wordnet as wn | |
| import json | |
| from filelock import FileLock | |
| from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed | |
| from functools import lru_cache | |
| from typing import List, Tuple, Dict, Any | |
| import multiprocessing as mp | |
# Ensure the HF_HOME environment variable points to your desired cache location
# Token removed for security
# NOTE(review): hard-coded, cluster-specific path; cache_dir is not referenced
# anywhere in this chunk — confirm it is still consumed downstream.
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'
| # Handle potential import conflicts with sentence_transformers | |
try:
    # Import bert_score directly so we never pull in sentence_transformers,
    # which has caused import conflicts in this environment.
    from bert_score import score as bert_score
    SIMILARITY_AVAILABLE = True

    def calc_scores_bert(original_sentence, substitute_sentences):
        """BERTScore function using direct bert_score import."""
        try:
            # Clamp inputs to ~500 tokens (~4 chars per token) before scoring.
            max_chars = 2000  # Roughly 500 tokens
            reference_text = original_sentence[:max_chars]
            clipped_candidates = [candidate[:max_chars] for candidate in substitute_sentences]
            reference_list = [reference_text] * len(clipped_candidates)
            _, _, f1_scores = bert_score(
                cands=clipped_candidates,
                refs=reference_list,
                model_type="bert-base-uncased",
                verbose=False
            )
            return f1_scores.tolist()
        except Exception:
            # Any scoring failure degrades to neutral (0.5) scores.
            return [0.5] * len(substitute_sentences)

    def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
        """Similarity function using direct bert_score import."""
        if method != 'bert':
            # Unknown method: neutral scores for every candidate.
            return [0.5] * len(substitute_sentences)
        return calc_scores_bert(original_sentence, substitute_sentences)

except ImportError as e:
    print(f"Warning: bert_score import failed: {e}")
    print("Falling back to neutral similarity scores...")
    SIMILARITY_AVAILABLE = False

    def calc_scores_bert(original_sentence, substitute_sentences):
        """Fallback BERTScore function with neutral scores."""
        return [0.5] * len(substitute_sentences)

    def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
        """Fallback similarity function with neutral scores."""
        return [0.5] * len(substitute_sentences)
| # Setup NLTK data | |
def setup_nltk_data():
    """Download the NLTK resources this module relies on, best-effort.

    Each download is quiet and individually guarded: a failure (offline host,
    permission issue) is ignored so the module can still import; callers that
    need the data will fail later with a clearer NLTK LookupError.

    Fix: the original repeated four identical try blocks with bare ``except:``,
    which also swallowed KeyboardInterrupt/SystemExit. A single loop with
    ``except Exception`` keeps the best-effort behavior without trapping
    interpreter-exit signals.
    """
    for resource in ('punkt_tab', 'averaged_perceptron_tagger_eng', 'wordnet', 'omw-1.4'):
        try:
            nltk.download(resource, quiet=True)
        except Exception:
            # Best-effort: missing data surfaces later at point of use.
            pass

setup_nltk_data()
# Shared lemmatizer instance used by the word-lookup helpers below.
lemmatizer = WordNetLemmatizer()
# Load spaCy model - download if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # First run on a fresh environment: fetch the model with the *current*
    # interpreter (sys.executable) so it lands in the right site-packages,
    # then retry the load.
    print("Downloading spaCy model...")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")
# Define the detailed whitelist of POS tags (content words, including adverbs)
# NOTE(review): an earlier comment said "excluding adverbs", but RB/RBR/RBS
# are listed below — adverbs ARE part of this whitelist.
DETAILED_POS_WHITELIST = {
    'NN',   # Noun, singular or mass (e.g., dog, car)
    'NNS',  # Noun, plural (e.g., dogs, cars)
    'VB',   # Verb, base form (e.g., run, eat)
    'VBD',  # Verb, past tense (e.g., ran, ate)
    'VBG',  # Verb, gerund or present participle (e.g., running, eating)
    'VBN',  # Verb, past participle (e.g., run, eaten)
    'VBP',  # Verb, non-3rd person singular present (e.g., run, eat)
    'VBZ',  # Verb, 3rd person singular present (e.g., runs, eats)
    'JJ',   # Adjective (e.g., big, blue)
    'JJR',  # Adjective, comparative (e.g., bigger, bluer)
    'JJS',  # Adjective, superlative (e.g., biggest, bluest)
    'RB',   # Adverb (e.g., very, silently)
    'RBR',  # Adverb, comparative (e.g., better)
    'RBS'   # Adverb, superlative (e.g., best)
}
| # Global caches for better performance | |
| _pos_cache = {} | |
| _antonym_cache = {} | |
| _word_validity_cache = {} | |
def extract_entities_and_pos(text):
    """
    Detect eligible tokens for replacement while skipping:
    - Named entities (e.g., names, locations, organizations).
    - Compound words (e.g., "Opteron-based").
    - Phrasal verbs (e.g., "make up", "focus on").
    - Punctuation and non-POS-whitelisted tokens.

    Returns a list of (sentence text, token text, token index) triples.
    """
    parsed = nlp(text)
    eligible = []
    for sentence in parsed.sents:
        for tok in sentence:
            # Named entities (ent_type_ is more reliable than a text match)
            # and standalone punctuation are never replacement targets.
            if tok.ent_type_ or tok.is_punct:
                continue
            # Hyphenated compounds and compound/amod dependents stay intact.
            if "-" in tok.text or tok.dep_ in {"compound", "amod"}:
                continue
            # A verb with a particle child ("make up", "focus on") is phrasal.
            if tok.pos_ == "VERB" and any(child.dep_ == "prt" for child in tok.children):
                continue
            # Everything else must carry a whitelisted fine-grained POS tag.
            if tok.tag_ in DETAILED_POS_WHITELIST:
                eligible.append((sentence.text, tok.text, tok.i))
    return eligible
def preprocess_text(text):
    """
    Shield periods that do not end sentences by rewriting them as <PERIOD>.

    Handles abbreviations ("U.S.", "Corp."), decimal numbers ("3.57",
    "1.48–2.10"), single-letter name initials ("J. Smith"), and dotted
    acronyms ("F.B.I."), so a later sentence splitter leaves them intact.
    """
    # Simple pattern/replacement pairs, applied in the original order.
    shielding_rules = (
        # Trailing period of common abbreviations.
        (r'\b(U\.S|U\.K|Corp|Inc|Ltd)\.', r'\1<PERIOD>'),
        # Decimal point in numbers.
        (r'(\b\d+)\.(\d+)', r'\1<PERIOD>\2'),
        # Single-letter initials followed by a capitalized name.
        (r'\b([A-Z])\.(?=\s[A-Z])', r'\1<PERIOD>'),
    )
    for pattern, replacement in shielding_rules:
        text = re.sub(pattern, replacement, text)
    # Dotted acronyms need a callable replacement: every dot is shielded.
    return re.sub(
        r'\b([A-Z]\.){2,}[A-Z]\.',
        lambda match: match.group(0).replace('.', '<PERIOD>'),
        text,
    )
def split_sentences(text):
    """
    Split text into sentences while preserving original newlines exactly.

    - Protects abbreviations, acronyms, and floating-point numbers so their
      internal periods are not treated as sentence boundaries.
    - Only adds newlines where the original line ended with one.

    Bug fix: the original abbreviation rule kept the abbreviation's dots and
    merely *appended* "<ABBR>" (so the Step-3 restore produced "U.S.."), and
    its trailing ``\\b`` after ``\\.`` almost never matches before whitespace,
    so the protection rarely fired at all. We now replace the dots themselves,
    which is exactly what the restore step expects.
    """
    # Step 1: protect periods that do not end sentences.
    text = re.sub(
        r'\b(U\.S\.|U\.K\.|Inc\.|Ltd\.|Corp\.|e\.g\.|i\.e\.|etc\.)',
        lambda m: m.group(1).replace('.', '<ABBR>'),
        text,
    )
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<FLOAT>\2', text)
    text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.',
                  lambda m: m.group(0).replace('.', '<ABBR>'), text)
    # Step 2: split each original line on sentence-ending punctuation,
    # re-attaching the line's trailing newline (if any) to every segment.
    sentences = []
    for line in text.splitlines(keepends=True):  # Retain original newlines
        segments = re.split(r'(?<=[.!?])\s+', line.strip())
        newline = "\n" if line.endswith("\n") else ""
        sentences.extend(segment + newline for segment in segments)
    # Step 3: restore protected periods.
    return [sent.replace('<ABBR>', '.').replace('<FLOAT>', '.') for sent in sentences]
def is_valid_word(word):
    """Check if a word is valid using WordNet (cached).

    A word counts as valid when WordNet has at least one synset for it.
    Fix: the docstring claimed caching but the module-level
    _word_validity_cache was never consulted; results are now memoized there.
    """
    if word not in _word_validity_cache:
        _word_validity_cache[word] = bool(wn.synsets(word))
    return _word_validity_cache[word]
def get_word_pos_tags(word):
    """Get POS tags for a word using both NLTK and spaCy (cached).

    Returns a (nltk_tag, spacy_pos) tuple.
    Fix: the docstring claimed caching but the module-level _pos_cache was
    never consulted; both taggers are expensive, so results are memoized.
    """
    if word not in _pos_cache:
        nltk_pos = nltk.pos_tag([word])[0][1]
        spacy_pos = nlp(word)[0].pos_
        _pos_cache[word] = (nltk_pos, spacy_pos)
    return _pos_cache[word]
@lru_cache(maxsize=None)
def get_word_lemma(word):
    """Get lemmatized form of a word (cached).

    Fix: the docstring claimed caching but none existed; lru_cache (already
    imported at module top) now memoizes the lemmatizer call per word.
    """
    return lemmatizer.lemmatize(word)
def get_word_antonyms(word):
    """Get antonyms for a word (cached). Includes all lemmas from all synsets.

    For every antonym lemma found, the lemmas (synonyms) of that antonym's own
    synsets are added as well, mirroring the original "completeness" behavior.
    Fix: the docstring claimed caching but the module-level _antonym_cache was
    never consulted; results are now memoized there. The cached set object is
    returned directly — callers must not mutate it.
    """
    if word in _antonym_cache:
        return _antonym_cache[word]
    antonyms = set()
    for synset in wn.synsets(word):
        for lemma in synset.lemmas():
            for antonym in lemma.antonyms():
                # Base antonym word (portion before any dot in the lemma name).
                base = antonym.name().split('.')[0]
                antonyms.add(base)
                # Also add other lemmas of the antonym for completeness.
                for antonym_synset in wn.synsets(base):
                    for alt_lemma in antonym_synset.lemmas():
                        antonyms.add(alt_lemma.name().split('.')[0])
    _antonym_cache[word] = antonyms
    return antonyms
| def _are_semantically_compatible(target, candidate): | |
| """ | |
| Check if target and candidate are semantically compatible for replacement. | |
| Returns False if they are specific nouns in the same category (e.g., different crops, fruits, animals). | |
| """ | |
| try: | |
| # Direct check: if target and candidate are both specific terms for crops, animals, etc. | |
| # check if they're NOT near-synonyms | |
| # Agricultural/crop terms that shouldn't be swapped | |
| agricultural_terms = ['soybean', 'corn', 'maize', 'wheat', 'rice', 'barley', 'oats', 'sorghum', | |
| 'millet', 'grain', 'cereal', 'pulse', 'bean', 'legume'] | |
| # If both are agricultural terms and different, block | |
| if (target.lower() in agricultural_terms and candidate.lower() in agricultural_terms and | |
| target.lower() != candidate.lower()): | |
| return False | |
| target_synsets = wn.synsets(target) | |
| cand_synsets = wn.synsets(candidate) | |
| if not target_synsets or not cand_synsets: | |
| return True # If no synsets, allow through | |
| # Check if they're near-synonyms (very similar) - if so, allow | |
| # We can use path similarity to check if they're similar enough | |
| max_similarity = 0.0 | |
| for t_syn in target_synsets: | |
| for c_syn in cand_synsets: | |
| try: | |
| similarity = t_syn.path_similarity(c_syn) or 0.0 | |
| max_similarity = max(max_similarity, similarity) | |
| except: | |
| pass | |
| # If they have high path similarity (>0.5), they're similar enough to allow | |
| if max_similarity > 0.5: | |
| return True | |
| # Otherwise, check if they share common direct hypernyms | |
| target_hypernyms = set() | |
| for syn in target_synsets: | |
| # Get immediate hypernyms (parent concepts) | |
| for hypernym in syn.hypernyms(): | |
| target_hypernyms.add(hypernym) | |
| cand_hypernyms = set() | |
| for syn in cand_synsets: | |
| for hypernym in syn.hypernyms(): | |
| cand_hypernyms.add(hypernym) | |
| # If they share hypernyms, check if they're both specific instances (not general terms) | |
| common_hypernyms = target_hypernyms & cand_hypernyms | |
| if common_hypernyms: | |
| # Check if both words are specific instances of the same category | |
| # If so, they shouldn't be replaced with each other | |
| # We identify this by checking if their hypernym has many siblings | |
| for hypernym in common_hypernyms: | |
| siblings = hypernym.hyponyms() | |
| # If there are many specific instances (e.g., many crops, many fruits) | |
| # it's likely a category with specific instances that shouldn't be interchanged | |
| if len(siblings) > 3: | |
| # Check if hypernym name suggests a specific category | |
| hypernym_name = hypernym.name().split('.')[0] | |
| category_keywords = [ | |
| 'crop', 'grain', 'fruit', 'animal', 'bird', 'fish', 'company', | |
| 'country', 'city', 'brand', 'product', 'food', 'vehicle' | |
| ] | |
| # If the hypernym contains category keywords, these are likely | |
| # specific instances that shouldn't be swapped | |
| if any(keyword in hypernym_name for keyword in category_keywords): | |
| return False | |
| return True | |
| except Exception as e: | |
| # On any error, allow the candidate through (conservative approach) | |
| return True | |
def create_context_windows(full_text, target_sentence, target_word, tokenizer, max_tokens=400):
    """
    Create context windows around the target sentence for better MLM generation.
    Intelligently handles tokenizer length limits by preserving the most
    relevant context.

    Args:
        full_text: The complete document text
        target_sentence: The sentence containing the target word
        target_word: The word to be replaced
        tokenizer: The tokenizer used to check length limits
        max_tokens: Maximum tokens per context window (leave room for
            instruction + mask)

    Returns:
        List of context windows with increasing amounts of context. The first
        entry is always the bare target sentence; the last is the
        "intelligent" context from _create_intelligent_context. Falls back to
        [target_sentence] when the sentence cannot be located in full_text.

    Fix: windows 2-4 were three verbatim copies of the same expansion block;
    they are now a single loop over radii 1..3 with identical behavior.
    """
    sentences = split_sentences(full_text)
    # Locate the sentence containing the target (substring match on stripped text).
    target_sentence_idx = None
    for i, sent in enumerate(sentences):
        if target_sentence.strip() in sent.strip():
            target_sentence_idx = i
            break
    if target_sentence_idx is None:
        return [target_sentence]  # Fallback to original sentence

    # Window 1: just the target sentence (always included).
    context_windows = [target_sentence]
    # Windows 2-4: target sentence +/- 1..3 neighbouring sentences, each kept
    # only if it fits within the token budget.
    for radius in (1, 2, 3):
        start_idx = max(0, target_sentence_idx - radius)
        end_idx = min(len(sentences), target_sentence_idx + radius + 1)
        window = " ".join(sentences[start_idx:end_idx])
        try:
            if len(tokenizer.encode(window)) <= max_tokens:
                context_windows.append(window)
        except Exception:
            # Tokenizer failure: skip this window rather than abort.
            pass
    # Final window: sentence-prioritized context with word-level expansion.
    context_windows.append(_create_intelligent_context(
        full_text, target_word, target_sentence_idx, tokenizer, max_tokens
    ))
    return context_windows
def _create_intelligent_context(full_text, target_word, target_sentence_idx, tokenizer, max_tokens):
    """
    Build a context window that prefers whole sentences but respects token limits.

    Order of preference: the target sentence alone; progressively larger
    sentence neighbourhoods around it; and finally, if significant budget
    remains, word-level expansion around the target word.
    """
    def _count_tokens(chunk):
        # Treat tokenizer failures as "too long" (1000) so callers back off.
        try:
            return len(tokenizer.encode(chunk))
        except Exception:
            return 1000

    sentences = split_sentences(full_text)
    target_sentence = sentences[target_sentence_idx]
    target_tokens = _count_tokens(target_sentence)
    if target_tokens > max_tokens:
        # Even the target sentence alone is over budget: truncate around the word.
        return _truncate_sentence_intelligently(target_sentence, target_word, tokenizer, max_tokens)

    best_context = target_sentence
    best_token_count = target_tokens
    # Grow the neighbourhood one sentence at a time on both sides (radius < 20).
    for radius in range(1, min(len(sentences), 20)):
        lo = max(0, target_sentence_idx - radius)
        hi = min(len(sentences), target_sentence_idx + radius + 1)
        candidate = " ".join(sentences[lo:hi])
        candidate_tokens = _count_tokens(candidate)
        if candidate_tokens > max_tokens:
            break  # Over budget: stop growing.
        best_context, best_token_count = candidate, candidate_tokens

    # Spend any significant leftover budget on word-level expansion.
    leftover = max_tokens - best_token_count
    if leftover > 50:
        enhanced = _enhance_with_word_expansion(
            full_text, target_word, best_context, tokenizer, leftover
        )
        if enhanced:
            return enhanced
    return best_context
| def _enhance_with_word_expansion(full_text, target_word, current_context, tokenizer, remaining_tokens): | |
| """ | |
| Enhance the current sentence-based context with word-level expansion if there's room. | |
| """ | |
| words = full_text.split() | |
| target_word_idx = None | |
| # Find target word position in full text | |
| for i, word in enumerate(words): | |
| if word.lower() == target_word.lower(): | |
| target_word_idx = i | |
| break | |
| if target_word_idx is None: | |
| return current_context | |
| # Try to expand word-by-word around the target word | |
| try: | |
| current_tokens = len(tokenizer.encode(current_context)) | |
| except Exception as e: | |
| print(f"WARNING: Error encoding current context: {e}") | |
| current_tokens = 1000 # Fallback to assume it's too long | |
| for expansion_size in range(1, min(len(words), 100)): # Max 100 words expansion | |
| start_word = max(0, target_word_idx - expansion_size) | |
| end_word = min(len(words), target_word_idx + expansion_size + 1) | |
| expanded_context = " ".join(words[start_word:end_word]) | |
| try: | |
| expanded_tokens = len(tokenizer.encode(expanded_context)) | |
| except Exception as e: | |
| expanded_tokens = 1000 # Fallback to assume it's too long | |
| if expanded_tokens <= current_tokens + remaining_tokens: | |
| # This expansion fits within our remaining token budget | |
| return expanded_context | |
| else: | |
| # This expansion is too big, stop here | |
| break | |
| return current_context | |
| def _truncate_sentence_intelligently(sentence, target_word, tokenizer, max_tokens): | |
| """ | |
| Intelligently truncate a sentence while preserving context around the target word. | |
| """ | |
| words = sentence.split() | |
| target_word_idx = None | |
| # Find target word position | |
| for i, word in enumerate(words): | |
| if word.lower() == target_word.lower(): | |
| target_word_idx = i | |
| break | |
| if target_word_idx is None: | |
| # If target word not found, truncate from the end | |
| truncated = " ".join(words) | |
| try: | |
| while len(tokenizer.encode(truncated)) > max_tokens and len(words) > 1: | |
| words = words[:-1] | |
| truncated = " ".join(words) | |
| except Exception as e: | |
| # Fallback: return first few words | |
| truncated = " ".join(words[:10]) if len(words) >= 10 else " ".join(words) | |
| return truncated | |
| # Truncate symmetrically around target word | |
| context_words = 10 # Start with 10 words before/after | |
| while context_words > 0: | |
| start_word = max(0, target_word_idx - context_words) | |
| end_word = min(len(words), target_word_idx + context_words + 1) | |
| truncated_sentence = " ".join(words[start_word:end_word]) | |
| try: | |
| if len(tokenizer.encode(truncated_sentence)) <= max_tokens: | |
| return truncated_sentence | |
| except Exception as e: | |
| # Continue to next iteration | |
| pass | |
| context_words -= 1 | |
| # Fallback: just the target word with minimal context | |
| return f"... {target_word} ..." | |
def _intelligent_token_slicing(input_text, tokenizer, max_length=512, mask_token_id=None):
    """
    Intelligently slice input text to fit within max_length tokens while preserving the mask token.
    Strategy: Preserve mask token and surrounding context, remove excess tokens from less important areas.
    Args:
        input_text: The full input text to be tokenized
        tokenizer: The tokenizer to use
        max_length: Maximum allowed sequence length (default 512)
        mask_token_id: The mask token ID to preserve
    Returns:
        Tuple of (sliced_input_ids, mask_position_in_sliced); the position is
        None when no mask token is present.
    """
    # First, tokenize the full input
    input_ids = tokenizer.encode(input_text, add_special_tokens=True)
    # If already within limits, return as is
    if len(input_ids) <= max_length:
        mask_pos = input_ids.index(mask_token_id) if mask_token_id in input_ids else None
        return input_ids, mask_pos
    # Find mask token position
    mask_positions = [i for i, token_id in enumerate(input_ids) if token_id == mask_token_id]
    if not mask_positions:
        # No mask token found, truncate from the end
        return input_ids[:max_length], None
    mask_pos = mask_positions[0]  # Use first mask token
    # Calculate how many tokens we need to remove
    excess_tokens = len(input_ids) - max_length
    # Strategy: Remove tokens from both ends while preserving mask context
    # Reserve some context around the mask token
    mask_context_size = min(50, max_length // 4)  # Reserve 25% of max_length or 50 tokens, whichever is smaller
    # Calculate available space for context around mask
    available_before = min(mask_pos, mask_context_size)
    available_after = min(len(input_ids) - mask_pos - 1, mask_context_size)
    # Tokens outside the reserved mask context, i.e. "safe" to drop from each end.
    tokens_to_remove_before = max(0, mask_pos - available_before)
    tokens_to_remove_after = max(0, (len(input_ids) - mask_pos - 1) - available_after)
    # Initialize removal variables
    remove_before = 0
    remove_after = 0
    # Distribute excess tokens proportionally
    if excess_tokens > 0:
        if tokens_to_remove_before + tokens_to_remove_after >= excess_tokens:
            # We can remove enough from the ends
            if tokens_to_remove_before >= excess_tokens // 2:
                remove_before = excess_tokens // 2
                # NOTE(review): remove_after here may exceed tokens_to_remove_after,
                # cutting into the reserved mask context — confirm this is acceptable.
                remove_after = excess_tokens - remove_before
            else:
                remove_before = tokens_to_remove_before
                remove_after = min(tokens_to_remove_after, excess_tokens - remove_before)
        else:
            # Need to remove more aggressively
            remove_before = tokens_to_remove_before
            remove_after = tokens_to_remove_after
            remaining_excess = excess_tokens - remove_before - remove_after
            # Remove remaining excess from the end
            if remaining_excess > 0:
                remove_after += remaining_excess
    # Calculate final indices
    start_idx = remove_before
    end_idx = len(input_ids) - remove_after
    # Ensure we don't exceed max_length
    if end_idx - start_idx > max_length:
        # Center around mask token
        half_length = max_length // 2
        start_idx = max(0, mask_pos - half_length)
        end_idx = min(len(input_ids), start_idx + max_length)
    # Slice the input_ids
    sliced_input_ids = input_ids[start_idx:end_idx]
    # Debug information
    if len(sliced_input_ids) > max_length:
        # Force truncation as final fallback
        # NOTE(review): if this branch or an aggressive remove_before fires, the
        # adjusted mask position below could fall outside the sliced sequence —
        # callers should verify the mask token survives. TODO confirm.
        sliced_input_ids = sliced_input_ids[:max_length]
    # Adjust mask position for the sliced sequence
    adjusted_mask_pos = mask_pos - start_idx
    return sliced_input_ids, adjusted_mask_pos
def _create_word_level_context(full_text, target_word, tokenizer, max_tokens):
    """
    Build a context by growing a word window around target_word up to max_tokens.
    Falls back to the longest fitting prefix of the text when the target word
    is not found.
    """
    words = full_text.split()
    lowered_target = target_word.lower()
    # Find the first case-insensitive occurrence of the target word.
    for position, word in enumerate(words):
        if word.lower() == lowered_target:
            # Found it: expand symmetrically around this position.
            return _expand_around_target(words, position, tokenizer, max_tokens)
    # Target word absent: take as much of the beginning as fits.
    return _expand_from_start(words, tokenizer, max_tokens)
| def _expand_around_target(words, target_idx, tokenizer, max_tokens): | |
| """ | |
| Expand word-by-word around target word until reaching token limit. | |
| """ | |
| best_context = "" | |
| best_token_count = 0 | |
| # Try different expansion sizes | |
| for expansion_size in range(1, min(len(words), 200)): # Max 200 words expansion | |
| start_word = max(0, target_idx - expansion_size) | |
| end_word = min(len(words), target_idx + expansion_size + 1) | |
| context_window = " ".join(words[start_word:end_word]) | |
| try: | |
| token_count = len(tokenizer.encode(context_window)) | |
| except Exception as e: | |
| token_count = 1000 # Fallback to assume it's too long | |
| if token_count <= max_tokens: | |
| # This expansion fits, keep it as our best option | |
| best_context = context_window | |
| best_token_count = token_count | |
| else: | |
| # This expansion is too big, stop here | |
| break | |
| # If we found a good context, return it | |
| if best_context: | |
| return best_context | |
| # Fallback: minimal context around target word | |
| start_word = max(0, target_idx - 5) | |
| end_word = min(len(words), target_idx + 6) | |
| return " ".join(words[start_word:end_word]) | |
| def _expand_from_start(words, tokenizer, max_tokens): | |
| """ | |
| Expand from the start of the text until reaching token limit. | |
| """ | |
| for end_idx in range(len(words), 0, -1): | |
| context_window = " ".join(words[:end_idx]) | |
| try: | |
| if len(tokenizer.encode(context_window)) <= max_tokens: | |
| return context_window | |
| except Exception as e: | |
| # Continue to next iteration | |
| pass | |
| # Fallback: first few words | |
| return " ".join(words[:10]) if len(words) >= 10 else " ".join(words) | |
def whole_context_mlm_inference(full_text, sentence_target_pairs, tokenizer, lm_model, Top_K=20, batch_size=32, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Enhanced MLM inference using whole document context for better candidate generation.

    Buckets the (target, index) pairs by their containing sentence, then feeds
    each bucket to the batch helper in chunks of batch_size, merging all
    per-batch results into one dict.
    """
    results = {}
    # Bucket the (target, index) pairs by sentence.
    sentence_groups = {}
    for sent, target, index in sentence_target_pairs:
        sentence_groups.setdefault(sent, []).append((target, index))
    # Run each sentence's targets through the batch helper in fixed-size chunks.
    for sentence, targets in sentence_groups.items():
        for offset in range(0, len(targets), batch_size):
            chunk = targets[offset:offset + batch_size]
            results.update(_process_whole_context_mlm_batch(
                full_text, sentence, chunk, tokenizer, lm_model,
                Top_K, max_context_tokens, max_length, similarity_context_mode,
            ))
    return results
| def _process_whole_context_mlm_batch(full_text, sentence, targets, tokenizer, lm_model, Top_K, max_context_tokens=400, max_length=512, similarity_context_mode='whole'): | |
| """ | |
| Process a batch of targets using whole document context for MLM. | |
| """ | |
| results = {} | |
| # Tokenize sentence once | |
| doc = nlp(sentence) | |
| tokens = [token.text for token in doc] | |
| # Create multiple masked versions for batch processing | |
| masked_inputs = [] | |
| mask_positions = [] | |
| contexts_for_targets = [] | |
| for target, index in targets: | |
| if index < len(tokens): | |
| # Create context windows with tokenizer length awareness | |
| context_windows = create_context_windows(full_text, sentence, target, tokenizer, max_tokens=max_context_tokens) | |
| # Use the most comprehensive context window that fits within token limits | |
| full_context = context_windows[-1] # Built around the target sentence | |
| # Select context for similarity according to mode | |
| context = sentence if similarity_context_mode == 'sentence' else full_context | |
| # Create masked version of the FULL context (not just the sentence) | |
| masked_full_context = context.replace(target, tokenizer.mask_token, 1) | |
| instruction = "Given the full document context, replace the masked word with a word that fits grammatically, preserves the original meaning, and ensures natural flow in the document:" | |
| input_text = f"{instruction} {context} {tokenizer.sep_token} {masked_full_context}" | |
| # AGGRESSIVE FIX: Truncate input text BEFORE tokenization to prevent errors | |
| # Estimate token count (roughly 1 token per 4 characters for English) | |
| estimated_tokens = len(input_text) // 4 | |
| if estimated_tokens > 500: # Leave some buffer | |
| # Truncate to roughly 2000 characters (500 tokens) | |
| input_text = input_text[:2000] | |
| # SIMPLE FIX: Truncate input text if it's too long | |
| try: | |
| temp_tokens = tokenizer.encode(input_text, add_special_tokens=True) | |
| if len(temp_tokens) > 512: | |
| # Truncate the input text by removing words from the end | |
| words = input_text.split() | |
| while len(tokenizer.encode(" ".join(words), add_special_tokens=True)) > 512 and len(words) > 10: | |
| words = words[:-1] | |
| input_text = " ".join(words) | |
| except Exception as e: | |
| # Emergency truncation - just take first 200 words | |
| words = input_text.split() | |
| input_text = " ".join(words[:200]) | |
| masked_inputs.append(input_text) | |
| # Store the original sentence-level index for reference, but mask position will be calculated during tokenization | |
| mask_positions.append(index) | |
| contexts_for_targets.append(context) | |
| if not masked_inputs: | |
| return results | |
| # Batch tokenize | |
| MAX_LENGTH = max_length # Use parameter for A100 optimization | |
| batch_inputs = [] | |
| batch_mask_positions = [] | |
| batch_contexts = [] | |
| for input_text, mask_pos in zip(masked_inputs, mask_positions): | |
| # Use intelligent token slicing to ensure we stay within MAX_LENGTH | |
| try: | |
| input_ids, adjusted_mask_pos = _intelligent_token_slicing( | |
| input_text, tokenizer, max_length=MAX_LENGTH, mask_token_id=tokenizer.mask_token_id | |
| ) | |
| if adjusted_mask_pos is not None: | |
| batch_inputs.append(input_ids) | |
| batch_mask_positions.append(adjusted_mask_pos) | |
| else: | |
| # Mask token not found in sliced sequence, skip this input | |
| continue | |
| except Exception as e: | |
| # Fallback: simple truncation | |
| try: | |
| input_ids = tokenizer.encode(input_text, add_special_tokens=True) | |
| if len(input_ids) > MAX_LENGTH: | |
| input_ids = input_ids[:MAX_LENGTH] | |
| masked_position = input_ids.index(tokenizer.mask_token_id) | |
| batch_inputs.append(input_ids) | |
| batch_mask_positions.append(masked_position) | |
| except ValueError: | |
| # Mask token not found, skip this input | |
| continue | |
| if not batch_inputs: | |
| return results | |
| # Pad sequences to same length, but ensure we don't exceed MAX_LENGTH | |
| max_len = min(max(len(ids) for ids in batch_inputs), MAX_LENGTH) | |
| # Additional safety check: truncate any sequences that are still too long | |
| truncated_batch_inputs = [] | |
| for input_ids in batch_inputs: | |
| if len(input_ids) > MAX_LENGTH: | |
| input_ids = input_ids[:MAX_LENGTH] | |
| truncated_batch_inputs.append(input_ids) | |
| padded_inputs = [] | |
| attention_masks = [] | |
| for input_ids in truncated_batch_inputs: | |
| attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids)) | |
| padded_ids = input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids)) | |
| padded_inputs.append(padded_ids) | |
| attention_masks.append(attention_mask) | |
| # Final safety check: ensure no sequence exceeds MAX_LENGTH | |
| for i, padded_ids in enumerate(padded_inputs): | |
| if len(padded_ids) > MAX_LENGTH: | |
| padded_inputs[i] = padded_ids[:MAX_LENGTH] | |
| attention_masks[i] = attention_masks[i][:MAX_LENGTH] | |
| # Batch inference - optimized for A100 with mixed precision | |
| with torch.no_grad(): | |
| input_tensor = torch.tensor(padded_inputs, dtype=torch.long) | |
| attention_tensor = torch.tensor(attention_masks, dtype=torch.long) | |
| if torch.cuda.is_available(): | |
| input_tensor = input_tensor.cuda() | |
| attention_tensor = attention_tensor.cuda() | |
| # Use mixed precision for A100 optimization | |
| with torch.amp.autocast('cuda'): | |
| outputs = lm_model(input_tensor, attention_mask=attention_tensor) | |
| logits = outputs.logits | |
| else: | |
| outputs = lm_model(input_tensor, attention_mask=attention_tensor) | |
| logits = outputs.logits | |
| # Process results - collect filtered candidates first | |
| batch_filtered_results = {} | |
| for i, (target, index) in enumerate(targets): | |
| if i < len(batch_mask_positions): | |
| mask_pos = batch_mask_positions[i] | |
| mask_logits = logits[i, mask_pos].squeeze() | |
| # Get top predictions | |
| top_tokens = torch.topk(mask_logits, k=Top_K, dim=0)[1] | |
| scores = torch.softmax(mask_logits, dim=0)[top_tokens].tolist() | |
| words = [tokenizer.decode(token.item()).strip() for token in top_tokens] | |
| # Filter candidates (without similarity scoring) | |
| filtered_candidates = _filter_candidates_batch(target, words, scores, tokens, index) | |
| if filtered_candidates: | |
| # Attach the exact context window used for this target | |
| batch_filtered_results[(sentence, target, index)] = { | |
| 'filtered_words': filtered_candidates, | |
| 'context': contexts_for_targets[i] | |
| } | |
| # Batch similarity scoring for all candidates | |
| if batch_filtered_results: | |
| similarity_results = _batch_similarity_scoring(batch_filtered_results, tokenizer) | |
| results.update(similarity_results) | |
| return results | |
def _filter_candidates_batch(target, words, scores, tokens, target_index):
    """
    Filter MLM candidate words for a single target position.

    Similarity scoring is intentionally NOT done here; it is batched across
    all targets later by ``_batch_similarity_scoring``.

    Parameters
    ----------
    target : str
        The original word being considered for replacement.
    words : list[str]
        Candidate replacement words, parallel to ``scores``.
    scores : list[float]
        Softmax probabilities for each candidate.
    tokens : list
        Tokenized sentence (unused here; kept for interface compatibility).
    target_index : int
        Position of the target token (unused here; kept for interface
        compatibility).

    Returns
    -------
    list[str] or None
        Surviving candidate words, or None if fewer than two survive.
    """
    # Hardcoded common antonym pairs (for words not in WordNet, or as an
    # additional safeguard). Built once per call — it is loop-invariant, so
    # rebuilding the literal per candidate (as before) was wasted work.
    common_antonyms = {
        'big': ['small', 'tiny', 'little'],
        'small': ['big', 'large', 'huge'],
        'large': ['small', 'tiny', 'little'],
        'good': ['bad', 'evil', 'wrong'],
        'bad': ['good', 'great', 'excellent'],
        'high': ['low'],
        'low': ['high'],
        'new': ['old'],
        'old': ['new'],
        'fast': ['slow'],
        'slow': ['fast'],
        'rich': ['poor'],
        'poor': ['rich'],
        'hot': ['cold'],
        'cold': ['hot'],
        'happy': ['sad', 'unhappy'],
        'sad': ['happy', 'joyful'],
        'true': ['false', 'untrue'],
        'false': ['true'],
        'real': ['fake', 'unreal'],
        'fake': ['real'],
        'up': ['down'],
        'down': ['up'],
        'yes': ['no'],
        'no': ['yes'],
        'alive': ['dead'],
        'dead': ['alive'],
        'safe': ['unsafe', 'dangerous'],
        'dangerous': ['safe'],
        'clean': ['dirty'],
        'dirty': ['clean'],
        'full': ['empty'],
        'empty': ['full'],
        'open': ['closed', 'shut'],
        'closed': ['open'],
        'begin': ['end', 'finish'],
        'end': ['begin', 'start'],
        'start': ['end', 'finish'],
        'finish': ['start', 'begin'],
        'first': ['last'],
        'last': ['first']
    }
    target_lower = target.lower()
    # Target-side lookups are loop-invariant: compute them once per call
    # instead of once per candidate (get_word_pos_tags / get_word_antonyms
    # were previously re-evaluated for the same target on every iteration).
    target_nltk_pos, target_spacy_pos = get_word_pos_tags(target)
    target_antonyms = {ant.lower() for ant in get_word_antonyms(target)}
    known_antonyms = set(common_antonyms.get(target_lower, ()))
    filtered_words = []
    seen_words = set()
    for word, score in zip(words, scores):
        word_lower = word.lower()
        # Skip duplicates and the target word itself.
        if word_lower in seen_words or word_lower == target_lower:
            continue
        seen_words.add(word_lower)
        if not is_valid_word(word):
            continue
        # Quick POS check: candidate must match the target's NLTK and spaCy tags.
        cand_nltk_pos, cand_spacy_pos = get_word_pos_tags(word)
        if target_nltk_pos != cand_nltk_pos or target_spacy_pos != cand_spacy_pos:
            continue
        # WordNet antonym check, case-insensitive: candidate is an antonym
        # of the target ...
        if word_lower in target_antonyms:
            continue
        # ... or the target is an antonym of the candidate (reverse check).
        if target_lower in {ant.lower() for ant in get_word_antonyms(word)}:
            continue
        # Hardcoded antonym safeguard (case-insensitive).
        if word_lower in known_antonyms:
            continue
        # Exclude candidates that are different specific terms in the same
        # narrow noun category (e.g. crops, animals, companies).
        if not _are_semantically_compatible(target, word):
            continue
        filtered_words.append(word)
    if len(filtered_words) < 2:
        return None
    # Return filtered words without similarity scoring (done at batch level).
    # NOTE: the previous version also collected the scores into an unused
    # local list; that dead code has been removed.
    return filtered_words
def _batch_similarity_scoring(batch_results, tokenizer):
    """
    Batched BERTScore similarity scoring across multiple targets.

    Collects every (target, candidate) pair into one task list, groups the
    tasks by their original full context, and scores each group with a single
    ``calc_scores_bert`` call for efficiency.

    Parameters
    ----------
    batch_results : dict
        Maps (sentence, target, index) to either a list of filtered words
        (legacy format) or a dict with keys 'filtered_words' and 'context'.
    tokenizer :
        Tokenizer used to rebuild candidate sentences (assumed to provide
        ``tokenize`` and ``convert_tokens_to_string`` — transformers-style).

    Returns
    -------
    dict
        Maps (sentence, target, index) to a list of (word, score) tuples
        sorted by descending similarity. Empty dict on failure.
    """
    # Collect one scoring task per (target, candidate word).
    # NOTE: the previous version also built a `sentence_contexts` dict here
    # that was never read; that dead code has been removed.
    similarity_tasks = []
    for (sentence, target, index), value in batch_results.items():
        if value is None:
            continue
        # Support both legacy list and new dict with context.
        if isinstance(value, dict):
            filtered_words = value.get('filtered_words')
            context = value.get('context', sentence)
        else:
            filtered_words = value
            context = sentence
        # Tokenize the sentence once per target.
        tokens = tokenizer.tokenize(sentence)
        if index >= len(tokens):
            continue
        for word in filtered_words:
            candidate_tokens = tokens.copy()
            candidate_tokens[index] = word
            candidate_sentence = tokenizer.convert_tokens_to_string(candidate_tokens)
            # Build full-context candidate by replacing the sentence inside
            # the chosen context exactly once.
            candidate_full_context = context.replace(sentence, candidate_sentence, 1)
            similarity_tasks.append({
                'original_context': context,
                'candidate_full_context': candidate_full_context,
                'target_word': word,
                'context_key': (sentence, target, index)
            })
    if not similarity_tasks:
        return {}
    try:
        # Group by original full context so each BERTScore call compares many
        # candidates against a single shared reference.
        context_groups = {}
        for task in similarity_tasks:
            context_groups.setdefault(task['original_context'], []).append(task)
        final_results = {}
        for orig_context, tasks in context_groups.items():
            candidate_contexts = [task['candidate_full_context'] for task in tasks]
            try:
                similarity_scores = calc_scores_bert(orig_context, candidate_contexts)
            except Exception:
                # Neutral fallback scores; recognized and discarded below.
                similarity_scores = [0.5] * len(candidate_contexts)
            # An all-0.5 vector is the fallback signature — treat the whole
            # group as a scoring failure and drop it.
            if similarity_scores and not all(score == 0.5 for score in similarity_scores):
                for task, score in zip(tasks, similarity_scores):
                    final_results.setdefault(task['context_key'], []).append(
                        (task['target_word'], score)
                    )
        # Best candidates first.
        for context_key in final_results:
            final_results[context_key].sort(key=lambda x: x[1], reverse=True)
        return final_results
    except Exception:
        # Best-effort: any unexpected failure yields no similarity results.
        return {}
def parallel_tournament_sampling(target_results, secret_key, m, c, h, alpha):
    """
    Run keyed tournament word selection for multiple targets in parallel.

    Parameters
    ----------
    target_results : dict
        Maps (sentence, target, index) to a list of (word, similarity) tuples.
    secret_key :
        Watermarking key passed through to the tournament sampler.
    m, c, alpha :
        Tournament sampler hyperparameters (passed through unchanged).
    h : int
        Size of the left context window (in tokens) used by the sampler.

    Returns
    -------
    dict
        Maps (sentence, target, index) to the selected word, or None when no
        candidates were available for that target.
    """
    results = {}
    if not target_results:
        return results
    # Hoisted out of the per-item worker: previously this import statement
    # executed once per target instead of once per call.
    from SynthID_randomization import tournament_select_word

    def process_single_tournament(item):
        # One tournament for a single (sentence, target, index) entry.
        (sentence, target, index), candidates = item
        if not candidates:
            return (sentence, target, index), None
        alternatives = [alt[0] for alt in candidates]
        similarity = [alt[1] for alt in candidates]
        if not alternatives or not similarity:
            return (sentence, target, index), None
        # Left context of up to h tokens feeds the keyed tournament.
        context_tokens = word_tokenize(sentence)
        left_context = context_tokens[max(0, index - h):index]
        randomized_word = tournament_select_word(
            target, alternatives, similarity,
            context=left_context, key=secret_key, m=m, c=c, alpha=alpha
        )
        return (sentence, target, index), randomized_word

    # Cap workers at 8; never spawn more workers than there are targets.
    max_workers = max(1, min(8, len(target_results)))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_item = {
            executor.submit(process_single_tournament, item): item
            for item in target_results.items()
        }
        for future in as_completed(future_to_item):
            key, result = future.result()
            results[key] = result
    return results
def whole_context_process_sentence(full_text, sentence, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, output_name, batch_size=32, max_length=512, max_context_tokens=400, similarity_context_mode='whole'):
    """
    Process one sentence using the whole document as context for MLM
    candidate generation, then pick watermarked replacements via keyed
    tournament sampling.

    Returns
    -------
    tuple
        (replacements, sampling_results): ``replacements`` is a list of
        (position, original_word, replacement_word) triples; each entry of
        ``sampling_results`` records the alternatives considered for a word
        and which one the tournament selected.
    """
    replacements = []
    sampling_results = []
    doc = nlp(sentence)
    sentence_target_pairs = extract_entities_and_pos(sentence)
    if not sentence_target_pairs:
        return replacements, sampling_results
    # Keep only targets whose recorded position agrees with the spaCy tokens.
    spacy_tokens = [token.text for token in doc]
    valid_pairs = [
        (sent, target, position)
        for sent, target, position in sentence_target_pairs
        if position < len(spacy_tokens) and spacy_tokens[position] == target
    ]
    if not valid_pairs:
        return replacements, sampling_results
    # Candidate generation with whole-document context.
    batch_results = whole_context_mlm_inference(
        full_text, valid_pairs, tokenizer, lm_model, Top_K,
        batch_size, max_context_tokens, max_length, similarity_context_mode
    )
    # Keep targets with at least two candidates at or above the similarity
    # threshold (matching original logic).
    filtered_results = {}
    for key, candidates in batch_results.items():
        if not candidates:
            continue
        survivors = [(word, score) for word, score in candidates if score >= threshold]
        if len(survivors) >= 2:
            filtered_results[key] = survivors
    # Keyed tournament selection over the surviving candidates.
    tournament_results = parallel_tournament_sampling(
        filtered_results, secret_key, m, c, h, alpha
    )
    # Assemble replacements plus the bookkeeping records.
    for (sent, target, position), randomized_word in tournament_results.items():
        if not randomized_word:
            continue
        alternatives = filtered_results.get((sent, target, position), [])
        # Preserve the old plain 'alternatives' list for compatibility and
        # additionally attach per-alternative similarity scores.
        sampling_results.append({
            "word": target,
            "alternatives": [alt[0] for alt in alternatives],
            "alternatives_with_similarity": [
                {"word": alt[0], "similarity": float(alt[1])} for alt in alternatives
            ],
            "randomized_word": randomized_word
        })
        replacements.append((position, target, randomized_word))
    return replacements, sampling_results
# Legacy function for compatibility
def look_up(sentence, target, index, tokenizer, lm_model, Top_K=20, threshold=0.75):
    """
    Legacy single-target lookup; delegates to the batched MLM inference path
    with a one-element batch.
    """
    key = (sentence, target, index)
    batch_output = batch_mlm_inference([key], tokenizer, lm_model, Top_K)
    return batch_output.get(key, None)
def batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K=20):
    """
    Legacy batch MLM inference kept for compatibility; delegates to the
    whole-context path with an empty document context.
    """
    return whole_context_mlm_inference(
        "", sentence_target_pairs, tokenizer, lm_model, Top_K
    )
def batch_look_up(sentence_target_pairs, tokenizer, lm_model, Top_K=20, threshold=0.75, max_workers=4):
    """
    Optimized batch lookup; thin wrapper over the batched MLM inference.

    The ``threshold`` and ``max_workers`` parameters are accepted for
    backward compatibility with older callers but are not used here.
    """
    results = batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K)
    return results