import os
import re
import time
import math
import torch
import string
import spacy
import pandas as pd
import numpy as np
import nltk
import sys
import subprocess
from nltk.tokenize import word_tokenize
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.corpus import wordnet as wn
import json
from filelock import FileLock
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from functools import lru_cache
from typing import List, Tuple, Dict, Any
import multiprocessing as mp

# Ensure the HF_HOME environment variable points to your desired cache location
# Token removed for security
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Handle potential import conflicts with sentence_transformers
try:
    # Try to import bert_score directly to avoid sentence_transformers conflicts
    from bert_score import score as bert_score
    SIMILARITY_AVAILABLE = True
    
    def calc_scores_bert(original_sentence, substitute_sentences):
        """BERTScore function using direct bert_score import."""
        try:
            # Safety check: truncate inputs if they're too long
            max_chars = 2000  # Roughly 500 tokens
            if len(original_sentence) > max_chars:
                original_sentence = original_sentence[:max_chars]
            
            truncated_substitutes = []
            for sub in substitute_sentences:
                if len(sub) > max_chars:
                    sub = sub[:max_chars]
                truncated_substitutes.append(sub)
            
            references = [original_sentence] * len(truncated_substitutes)
            P, R, F1 = bert_score(
                cands=truncated_substitutes, 
                refs=references, 
                model_type="bert-base-uncased", 
                verbose=False
            )
            return F1.tolist()
        except Exception as e:
            return [0.5] * len(substitute_sentences)
    
    def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
        """Similarity function using direct bert_score import."""
        if method == 'bert':
            return calc_scores_bert(original_sentence, substitute_sentences)
        else:
            return [0.5] * len(substitute_sentences)
            
except ImportError as e:
    print(f"Warning: bert_score import failed: {e}")
    print("Falling back to neutral similarity scores...")
    SIMILARITY_AVAILABLE = False
    
    def calc_scores_bert(original_sentence, substitute_sentences):
        """Fallback BERTScore function with neutral scores."""
        return [0.5] * len(substitute_sentences)
    
    def get_similarity_scores(original_sentence, substitute_sentences, method='bert'):
        """Fallback similarity function with neutral scores."""
        return [0.5] * len(substitute_sentences)

# Setup NLTK data
def setup_nltk_data():
    """Download the required NLTK resources, ignoring failures (e.g., offline runs)."""
    for resource in ('punkt_tab', 'averaged_perceptron_tagger_eng', 'wordnet', 'omw-1.4'):
        try:
            nltk.download(resource, quiet=True)
        except Exception:
            pass

setup_nltk_data()

lemmatizer = WordNetLemmatizer()

# Load spaCy model - download if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("Downloading spaCy model...")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")

# Define the detailed whitelist of POS tags (nouns, verbs, adjectives, and adverbs)
DETAILED_POS_WHITELIST = {
    'NN',  # Noun, singular or mass (e.g., dog, car)
    'NNS', # Noun, plural (e.g., dogs, cars)
    'VB',  # Verb, base form (e.g., run, eat)
    'VBD', # Verb, past tense (e.g., ran, ate)
    'VBG', # Verb, gerund or present participle (e.g., running, eating)
    'VBN', # Verb, past participle (e.g., run, eaten)
    'VBP', # Verb, non-3rd person singular present (e.g., run, eat)
    'VBZ', # Verb, 3rd person singular present (e.g., runs, eats)
    'JJ',  # Adjective (e.g., big, blue)
    'JJR', # Adjective, comparative (e.g., bigger, bluer)
    'JJS', # Adjective, superlative (e.g., biggest, bluest)
    'RB',  # Adverb (e.g., very, silently)
    'RBR', # Adverb, comparative (e.g., better)
    'RBS'  # Adverb, superlative (e.g., best)
}

# Global caches for better performance
_pos_cache = {}
_antonym_cache = {}
_word_validity_cache = {}

def extract_entities_and_pos(text):
    """
    Detect eligible tokens for replacement while skipping:
    - Named entities (e.g., names, locations, organizations).
    - Compound words (e.g., "Opteron-based").
    - Phrasal verbs (e.g., "make up", "focus on").
    - Punctuation and non-POS-whitelisted tokens.
    """
    doc = nlp(text)
    sentence_target_pairs = []  # List to hold (sentence, target word, token index)

    for sent in doc.sents:
        for token in sent:
            # Skip named entities using token.ent_type_ (more reliable than a text match)
            if token.ent_type_:
                continue

            # Skip standalone punctuation
            if token.is_punct:
                continue

            # Skip compound words (e.g., "Opteron-based")
            if "-" in token.text or token.dep_ in {"compound", "amod"}:
                continue

            # Skip phrasal verbs (e.g., "make up", "focus on")
            if token.pos_ == "VERB" and any(child.dep_ == "prt" for child in token.children):
                continue

            # Include regular tokens matching the POS whitelist
            if token.tag_ in DETAILED_POS_WHITELIST:
                sentence_target_pairs.append((sent.text, token.text, token.i))

    return sentence_target_pairs
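
# Illustrative behaviour of extract_entities_and_pos (a hedged example, not a test):
# for "Apple released a faster chip in California.", "Apple" and "California" are
# skipped as named entities and "faster" is skipped as an adjectival modifier, so
# typically only tokens such as "released" and "chip" are returned as
# (sentence, word, index) candidates.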

def preprocess_text(text):
    """
    Preprocesses the text to handle abbreviations, titles, and edge cases
    where a period or other punctuation does not signify a sentence end.
    Ensures figures, acronyms, and short names are left untouched.
    """
    # Protect common abbreviations like "U.S." and "Corp."
    text = re.sub(r'\b(U\.S|U\.K|Corp|Inc|Ltd)\.', r'\1<PERIOD>', text)
    
    # Protect floating-point numbers or ranges like "3.57" or "1.48–2.10"
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<PERIOD>\2', text)
    
    # Avoid modifying standalone single-letter initials in names (e.g., "J. Smith")
    text = re.sub(r'\b([A-Z])\.(?=\s[A-Z])', r'\1<PERIOD>', text)

    # Protect acronym-like patterns with dots, such as "F.B.I."
    text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<PERIOD>'), text)

    return text

def split_sentences(text):
    """
    Splits text into sentences while preserving original newlines exactly.
    - Protects abbreviations, acronyms, and floating-point numbers.
    - Only adds newlines where necessary without duplicating them.
    """
    # Step 1: Protect abbreviations, floating numbers, acronyms
    # Capture the abbreviation without its final period so the marker replaces the period
    # (mirrors preprocess_text); otherwise restoring <ABBR> would duplicate the period.
    text = re.sub(r'\b(U\.S|U\.K|Inc|Ltd|Corp|e\.g|i\.e|etc)\.', r'\1<ABBR>', text)
    text = re.sub(r'(\b\d+)\.(\d+)', r'\1<FLOAT>\2', text)
    text = re.sub(r'\b([A-Z]\.){2,}[A-Z]\.', lambda m: m.group(0).replace('.', '<ABBR>'), text)

    # Step 2: Identify sentence boundaries without duplicating newlines
    sentences = []
    for line in text.splitlines(keepends=True):  # Retain original newlines
        # Split only if punctuation marks end a sentence
        split_line = re.split(r'(?<=[.!?])\s+', line.strip())
        sentences.extend([segment + "\n" if line.endswith("\n") else segment for segment in split_line])

    # Step 3: Restore protected patterns
    return [sent.replace('<ABBR>', '.').replace('<FLOAT>', '.') for sent in sentences]
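
# Illustrative example (assuming the protections above): for
# "The U.S. economy grew 3.5 percent. Analysts were surprised."
# split_sentences should keep "U.S." and "3.5" intact and return roughly
# ["The U.S. economy grew 3.5 percent.", "Analysts were surprised."].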

@lru_cache(maxsize=10000)
def is_valid_word(word):
    """Check if a word is valid using WordNet (cached)."""
    return bool(wn.synsets(word))

@lru_cache(maxsize=5000)
def get_word_pos_tags(word):
    """Get POS tags for a word using both NLTK and spaCy (cached)."""
    nltk_pos = nltk.pos_tag([word])[0][1]
    spacy_pos = nlp(word)[0].pos_
    return nltk_pos, spacy_pos

@lru_cache(maxsize=5000)
def get_word_lemma(word):
    """Get lemmatized form of a word (cached)."""
    return lemmatizer.lemmatize(word)

@lru_cache(maxsize=2000)
def get_word_antonyms(word):
    """Get antonyms for a word (cached). Includes all lemmas from all synsets."""
    target_synsets = wn.synsets(word)
    antonyms = set()
    
    # Get antonyms from all synsets and all lemmas
    for syn in target_synsets:
        for lem in syn.lemmas():
            for ant in lem.antonyms():
                # Lemma.name() is already the plain word (no synset-style dots)
                ant_word = ant.name()
                antonyms.add(ant_word)
                # Also add lemmas from the antonym's own synsets for completeness
                for ant_syn in wn.synsets(ant_word):
                    for ant_syn_lemma in ant_syn.lemmas():
                        antonyms.add(ant_syn_lemma.name())
    
    return antonyms
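
# Illustrative example: get_word_antonyms("good") typically contains "bad" and "evil"
# (plus lemmas from their synsets), which is why candidates such as "bad" are later
# rejected as substitutes for "good".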

def _are_semantically_compatible(target, candidate):
    """
    Check if target and candidate are semantically compatible for replacement.
    Returns False if they are specific nouns in the same category (e.g., different crops, fruits, animals).
    """
    try:
        # Direct check: if target and candidate are both specific terms for crops, animals, etc.
        # check if they're NOT near-synonyms
        
        # Agricultural/crop terms that shouldn't be swapped
        agricultural_terms = ['soybean', 'corn', 'maize', 'wheat', 'rice', 'barley', 'oats', 'sorghum', 
                             'millet', 'grain', 'cereal', 'pulse', 'bean', 'legume']
        
        # If both are agricultural terms and different, block
        if (target.lower() in agricultural_terms and candidate.lower() in agricultural_terms and 
            target.lower() != candidate.lower()):
            return False
        
        target_synsets = wn.synsets(target)
        cand_synsets = wn.synsets(candidate)
        
        if not target_synsets or not cand_synsets:
            return True  # If no synsets, allow through
        
        # Check if they're near-synonyms (very similar) - if so, allow
        # We can use path similarity to check if they're similar enough
        max_similarity = 0.0
        for t_syn in target_synsets:
            for c_syn in cand_synsets:
                try:
                    similarity = t_syn.path_similarity(c_syn) or 0.0
                    max_similarity = max(max_similarity, similarity)
                except:
                    pass
        
        # If they have high path similarity (>0.5), they're similar enough to allow
        if max_similarity > 0.5:
            return True
        
        # Otherwise, check if they share common direct hypernyms
        target_hypernyms = set()
        for syn in target_synsets:
            # Get immediate hypernyms (parent concepts)
            for hypernym in syn.hypernyms():
                target_hypernyms.add(hypernym)
        
        cand_hypernyms = set()
        for syn in cand_synsets:
            for hypernym in syn.hypernyms():
                cand_hypernyms.add(hypernym)
        
        # If they share hypernyms, check if they're both specific instances (not general terms)
        common_hypernyms = target_hypernyms & cand_hypernyms
        
        if common_hypernyms:
            # Check if both words are specific instances of the same category
            # If so, they shouldn't be replaced with each other
            # We identify this by checking if their hypernym has many siblings
            for hypernym in common_hypernyms:
                siblings = hypernym.hyponyms()
                # If there are many specific instances (e.g., many crops, many fruits)
                # it's likely a category with specific instances that shouldn't be interchanged
                if len(siblings) > 3:
                    # Check if hypernym name suggests a specific category
                    hypernym_name = hypernym.name().split('.')[0]
                    category_keywords = [
                        'crop', 'grain', 'fruit', 'animal', 'bird', 'fish', 'company', 
                        'country', 'city', 'brand', 'product', 'food', 'vehicle'
                    ]
                    
                    # If the hypernym contains category keywords, these are likely
                    # specific instances that shouldn't be swapped
                    if any(keyword in hypernym_name for keyword in category_keywords):
                        return False
        
        return True
        
    except Exception as e:
        # On any error, allow the candidate through (conservative approach)
        return True
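
# Illustrative behaviour (hedged, since WordNet similarities vary): the pair
# ("soybean", "corn") is rejected by the hard-coded agricultural list, while a
# near-synonym pair such as ("big", "large") is allowed because it either has a
# high path similarity or shares no blocked hypernym category.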

def create_context_windows(full_text, target_sentence, target_word, tokenizer, max_tokens=400):
    """
    Create context windows around the target sentence for better MLM generation.
    Intelligently handles tokenizer length limits by preserving the most relevant context.
    
    Args:
        full_text: The complete document text
        target_sentence: The sentence containing the target word
        target_word: The word to be replaced
        tokenizer: The tokenizer to check length limits
        max_tokens: Maximum tokens to use for context (leave room for instruction + mask)
    
    Returns:
        List of context windows with different levels of context
    """
    # Split full text into sentences
    sentences = split_sentences(full_text)
    
    # Find the target sentence index
    target_sentence_idx = None
    for i, sent in enumerate(sentences):
        if target_sentence.strip() in sent.strip():
            target_sentence_idx = i
            break
    
    if target_sentence_idx is None:
        return [target_sentence]  # Fallback to original sentence
    
    # Create context windows with sentence-prioritized approach
    context_windows = []
    
    # Window 1: Just the target sentence (always include)
    context_windows.append(target_sentence)
    
    # Window 2: Target sentence + 1 sentence before and after (if fits)
    start_idx = max(0, target_sentence_idx - 1)
    end_idx = min(len(sentences), target_sentence_idx + 2)
    context_window = " ".join(sentences[start_idx:end_idx])
    
    try:
        encoded_len = len(tokenizer.encode(context_window))
        if encoded_len <= max_tokens:
            context_windows.append(context_window)
    except Exception as e:
        pass
    
    # Window 3: Target sentence + 2 sentences before and after (if fits)
    start_idx = max(0, target_sentence_idx - 2)
    end_idx = min(len(sentences), target_sentence_idx + 3)
    context_window = " ".join(sentences[start_idx:end_idx])
    
    try:
        encoded_len = len(tokenizer.encode(context_window))
        if encoded_len <= max_tokens:
            context_windows.append(context_window)
    except Exception as e:
        pass
    
    # Window 4: Target sentence + 3 sentences before and after (if fits)
    start_idx = max(0, target_sentence_idx - 3)
    end_idx = min(len(sentences), target_sentence_idx + 4)
    context_window = " ".join(sentences[start_idx:end_idx])
    
    try:
        encoded_len = len(tokenizer.encode(context_window))
        if encoded_len <= max_tokens:
            context_windows.append(context_window)
    except Exception as e:
        pass
    
    # Window 5: Intelligent context with sentence prioritization + word expansion
    intelligent_context = _create_intelligent_context(
        full_text, target_word, target_sentence_idx, tokenizer, max_tokens
    )
    context_windows.append(intelligent_context)
    
    return context_windows
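
# Minimal usage sketch (assumes a Hugging Face tokenizer such as
# AutoTokenizer.from_pretrained("bert-base-uncased"); names here are placeholders):
#   windows = create_context_windows(document_text, sentence, "growth", tokenizer, max_tokens=400)
#   windows[0]    -> always the bare target sentence
#   windows[1:-1] -> progressively wider sentence windows, included only if they fit max_tokens
#   windows[-1]   -> the "intelligent" context built by _create_intelligent_context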

def _create_intelligent_context(full_text, target_word, target_sentence_idx, tokenizer, max_tokens):
    """
    Create intelligent context that prioritizes sentence boundaries while respecting token limits.
    Strategy: Target sentence → Nearby sentences → Word-level expansion
    """
    sentences = split_sentences(full_text)
    
    # Strategy 1: Always start with the target sentence
    target_sentence = sentences[target_sentence_idx]
    try:
        target_sentence_tokens = len(tokenizer.encode(target_sentence))
    except Exception as e:
        target_sentence_tokens = 1000  # Fallback to assume it's too long
    
    if target_sentence_tokens > max_tokens:
        # If even target sentence is too long, truncate intelligently
        return _truncate_sentence_intelligently(target_sentence, target_word, tokenizer, max_tokens)
    
    # Strategy 2: Expand sentence-by-sentence around target sentence
    best_context = target_sentence
    best_token_count = target_sentence_tokens
    
    # Try adding sentences before and after the target sentence
    for sentence_radius in range(1, min(len(sentences), 20)):  # Max 20 sentences radius
        start_idx = max(0, target_sentence_idx - sentence_radius)
        end_idx = min(len(sentences), target_sentence_idx + sentence_radius + 1)
        
        # Create context with complete sentences
        context_sentences = sentences[start_idx:end_idx]
        context_window = " ".join(context_sentences)
        try:
            token_count = len(tokenizer.encode(context_window))
        except Exception as e:
            token_count = 1000  # Fallback to assume it's too long
        
        if token_count <= max_tokens:
            # This sentence expansion fits, keep it as our best option
            best_context = context_window
            best_token_count = token_count
        else:
            # This expansion is too big, stop here
            break
    
    # Strategy 3: If we have room left, try word-level expansion within the best sentence context
    remaining_tokens = max_tokens - best_token_count
    if remaining_tokens > 50:  # If we have significant room left
        enhanced_context = _enhance_with_word_expansion(
            full_text, target_word, best_context, tokenizer, remaining_tokens
        )
        if enhanced_context:
            return enhanced_context
    
    return best_context

def _enhance_with_word_expansion(full_text, target_word, current_context, tokenizer, remaining_tokens):
    """
    Enhance the current sentence-based context with word-level expansion if there's room.
    """
    words = full_text.split()
    target_word_idx = None
    
    # Find target word position in full text
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break
    
    if target_word_idx is None:
        return current_context
    
    # Try to expand word-by-word around the target word
    try:
        current_tokens = len(tokenizer.encode(current_context))
    except Exception as e:
        print(f"WARNING: Error encoding current context: {e}")
        current_tokens = 1000  # Fallback to assume it's too long
    
    # Keep growing the window while it still fits the budget and return the largest
    # expansion that fits; returning the first fit would shrink the context to only a
    # few words around the target word.
    best_expansion = None
    for expansion_size in range(1, min(len(words), 100)):  # Max 100 words expansion
        start_word = max(0, target_word_idx - expansion_size)
        end_word = min(len(words), target_word_idx + expansion_size + 1)
        
        expanded_context = " ".join(words[start_word:end_word])
        try:
            expanded_tokens = len(tokenizer.encode(expanded_context))
        except Exception:
            expanded_tokens = 1000  # Fallback to assume it's too long
        
        if expanded_tokens <= current_tokens + remaining_tokens:
            best_expansion = expanded_context
        else:
            # This expansion is too big, stop here
            break
    
    return best_expansion if best_expansion is not None else current_context

def _truncate_sentence_intelligently(sentence, target_word, tokenizer, max_tokens):
    """
    Intelligently truncate a sentence while preserving context around the target word.
    """
    words = sentence.split()
    target_word_idx = None
    
    # Find target word position
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break
    
    if target_word_idx is None:
        # If target word not found, truncate from the end
        truncated = " ".join(words)
        try:
            while len(tokenizer.encode(truncated)) > max_tokens and len(words) > 1:
                words = words[:-1]
                truncated = " ".join(words)
        except Exception as e:
            # Fallback: return first few words
            truncated = " ".join(words[:10]) if len(words) >= 10 else " ".join(words)
        return truncated
    
    # Truncate symmetrically around target word
    context_words = 10  # Start with 10 words before/after
    while context_words > 0:
        start_word = max(0, target_word_idx - context_words)
        end_word = min(len(words), target_word_idx + context_words + 1)
        truncated_sentence = " ".join(words[start_word:end_word])
        
        try:
            if len(tokenizer.encode(truncated_sentence)) <= max_tokens:
                return truncated_sentence
        except Exception as e:
            # Continue to next iteration
            pass
        
        context_words -= 1
    
    # Fallback: just the target word with minimal context
    return f"... {target_word} ..."

def _intelligent_token_slicing(input_text, tokenizer, max_length=512, mask_token_id=None):
    """
    Intelligently slice input text to fit within max_length tokens while preserving the mask token.
    Strategy: Preserve mask token and surrounding context, remove excess tokens from less important areas.
    
    Args:
        input_text: The full input text to be tokenized
        tokenizer: The tokenizer to use
        max_length: Maximum allowed sequence length (default 512)
        mask_token_id: The mask token ID to preserve
    
    Returns:
        Tuple of (sliced_input_ids, mask_position_in_sliced)
    """
    # First, tokenize the full input
    input_ids = tokenizer.encode(input_text, add_special_tokens=True)
    
    # If already within limits, return as is
    if len(input_ids) <= max_length:
        mask_pos = input_ids.index(mask_token_id) if mask_token_id in input_ids else None
        return input_ids, mask_pos
    
    # Find mask token position
    mask_positions = [i for i, token_id in enumerate(input_ids) if token_id == mask_token_id]
    
    if not mask_positions:
        # No mask token found, truncate from the end
        return input_ids[:max_length], None
    
    mask_pos = mask_positions[0]  # Use first mask token
    
    # Calculate how many tokens we need to remove
    excess_tokens = len(input_ids) - max_length
    
    # Strategy: Remove tokens from both ends while preserving mask context
    # Reserve some context around the mask token
    mask_context_size = min(50, max_length // 4)  # Reserve 25% of max_length or 50 tokens, whichever is smaller
    
    # Calculate available space for context around mask
    available_before = min(mask_pos, mask_context_size)
    available_after = min(len(input_ids) - mask_pos - 1, mask_context_size)
    
    # Calculate how much to remove from each end
    tokens_to_remove_before = max(0, mask_pos - available_before)
    tokens_to_remove_after = max(0, (len(input_ids) - mask_pos - 1) - available_after)
    
    # Initialize removal variables
    remove_before = 0
    remove_after = 0
    
    # Distribute excess tokens proportionally
    if excess_tokens > 0:
        if tokens_to_remove_before + tokens_to_remove_after >= excess_tokens:
            # We can remove enough from the ends
            if tokens_to_remove_before >= excess_tokens // 2:
                remove_before = excess_tokens // 2
                remove_after = excess_tokens - remove_before
            else:
                remove_before = tokens_to_remove_before
                remove_after = min(tokens_to_remove_after, excess_tokens - remove_before)
        else:
            # Need to remove more aggressively
            remove_before = tokens_to_remove_before
            remove_after = tokens_to_remove_after
            remaining_excess = excess_tokens - remove_before - remove_after
            
            # Remove remaining excess from the end
            if remaining_excess > 0:
                remove_after += remaining_excess
    
    # Calculate final indices
    start_idx = remove_before
    end_idx = len(input_ids) - remove_after
    
    # Ensure we don't exceed max_length
    if end_idx - start_idx > max_length:
        # Center around mask token
        half_length = max_length // 2
        start_idx = max(0, mask_pos - half_length)
        end_idx = min(len(input_ids), start_idx + max_length)
    
    # Slice the input_ids
    sliced_input_ids = input_ids[start_idx:end_idx]
    
    # Debug information
    if len(sliced_input_ids) > max_length:
        # Force truncation as final fallback
        sliced_input_ids = sliced_input_ids[:max_length]
    
    # Adjust mask position for the sliced sequence
    adjusted_mask_pos = mask_pos - start_idx
    
    return sliced_input_ids, adjusted_mask_pos
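
# Worked example (approximate): with max_length=512 and an input of ~700 tokens whose
# mask token sits near position 600, roughly 94 tokens are removed from each end while
# ~50 tokens of context are kept on both sides of the mask, and the returned mask index
# shifts to about 506 (the original mask position minus the tokens removed in front).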

def _create_word_level_context(full_text, target_word, tokenizer, max_tokens):
    """
    Create context by expanding word-by-word around the target word until reaching token limit.
    This maximizes context while respecting tokenizer limits.
    """
    words = full_text.split()
    target_word_idx = None
    
    # Find target word position in full text
    for i, word in enumerate(words):
        if word.lower() == target_word.lower():
            target_word_idx = i
            break
    
    if target_word_idx is None:
        # Fallback: expand from beginning until token limit
        return _expand_from_start(words, tokenizer, max_tokens)
    
    # Word-by-word expansion around target word
    return _expand_around_target(words, target_word_idx, tokenizer, max_tokens)

def _expand_around_target(words, target_idx, tokenizer, max_tokens):
    """
    Expand word-by-word around target word until reaching token limit.
    """
    best_context = ""
    best_token_count = 0
    
    # Try different expansion sizes
    for expansion_size in range(1, min(len(words), 200)):  # Max 200 words expansion
        start_word = max(0, target_idx - expansion_size)
        end_word = min(len(words), target_idx + expansion_size + 1)
        
        context_window = " ".join(words[start_word:end_word])
        try:
            token_count = len(tokenizer.encode(context_window))
        except Exception as e:
            token_count = 1000  # Fallback to assume it's too long
        
        if token_count <= max_tokens:
            # This expansion fits, keep it as our best option
            best_context = context_window
            best_token_count = token_count
        else:
            # This expansion is too big, stop here
            break
    
    # If we found a good context, return it
    if best_context:
        return best_context
    
    # Fallback: minimal context around target word
    start_word = max(0, target_idx - 5)
    end_word = min(len(words), target_idx + 6)
    return " ".join(words[start_word:end_word])

def _expand_from_start(words, tokenizer, max_tokens):
    """
    Expand from the start of the text until reaching token limit.
    """
    for end_idx in range(len(words), 0, -1):
        context_window = " ".join(words[:end_idx])
        try:
            if len(tokenizer.encode(context_window)) <= max_tokens:
                return context_window
        except Exception as e:
            # Continue to next iteration
            pass
    
    # Fallback: first few words
    return " ".join(words[:10]) if len(words) >= 10 else " ".join(words)

def whole_context_mlm_inference(full_text, sentence_target_pairs, tokenizer, lm_model, Top_K=20, batch_size=32, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Enhanced MLM inference using whole document context for better candidate generation.
    """
    results = {}
    
    # Group targets by sentence for batch processing
    sentence_groups = {}
    for sent, target, index in sentence_target_pairs:
        if sent not in sentence_groups:
            sentence_groups[sent] = []
        sentence_groups[sent].append((target, index))
    
    for sentence, targets in sentence_groups.items():
        # Process targets in batches
        for i in range(0, len(targets), batch_size):
            batch_targets = targets[i:i+batch_size]
            batch_results = _process_whole_context_mlm_batch(
                full_text, sentence, batch_targets, tokenizer, lm_model, Top_K, max_context_tokens, max_length, similarity_context_mode
            )
            results.update(batch_results)
    
    return results
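
# The returned dict maps (sentence, target_word, token_index) to a list of
# (candidate_word, similarity_score) pairs, sorted by score in descending order
# (see _batch_similarity_scoring).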

def _process_whole_context_mlm_batch(full_text, sentence, targets, tokenizer, lm_model, Top_K, max_context_tokens=400, max_length=512, similarity_context_mode='whole'):
    """
    Process a batch of targets using whole document context for MLM.
    """
    results = {}
    
    # Tokenize sentence once
    doc = nlp(sentence)
    tokens = [token.text for token in doc]
    
    # Create multiple masked versions for batch processing
    masked_inputs = []
    mask_positions = []
    contexts_for_targets = []
    kept_targets = []  # (target, index) pairs kept after pre-filtering, aligned with masked_inputs
    
    for target, index in targets:
        if index < len(tokens):
            # Create context windows with tokenizer length awareness
            context_windows = create_context_windows(full_text, sentence, target, tokenizer, max_tokens=max_context_tokens)
            
            # Use the most comprehensive context window that fits within token limits
            full_context = context_windows[-1]  # Built around the target sentence
            # Select context for similarity according to mode
            context = sentence if similarity_context_mode == 'sentence' else full_context
            
            # Create masked version of the FULL context (not just the sentence)
            masked_full_context = context.replace(target, tokenizer.mask_token, 1)
            
            instruction = "Given the full document context, replace the masked word with a word that fits grammatically, preserves the original meaning, and ensures natural flow in the document:"
            input_text = f"{instruction} {context} {tokenizer.sep_token} {masked_full_context}"
            
            # Coarse pre-truncation: if the prompt is clearly too long, drop text from the
            # FRONT so the masked context at the end (and its mask token) is preserved;
            # the precise, mask-aware trimming is done later by _intelligent_token_slicing.
            # Estimate token count (roughly 1 token per 4 characters for English)
            estimated_tokens = len(input_text) // 4
            if estimated_tokens > 500:  # Leave some buffer
                # Keep roughly the last 2000 characters (~500 tokens)
                input_text = input_text[-2000:]
            
            # Finer truncation: remove words from the FRONT while the encoding is too long
            try:
                temp_tokens = tokenizer.encode(input_text, add_special_tokens=True)
                if len(temp_tokens) > 512:
                    words = input_text.split()
                    while len(tokenizer.encode(" ".join(words), add_special_tokens=True)) > 512 and len(words) > 10:
                        words = words[1:]
                    input_text = " ".join(words)
            except Exception:
                # Emergency truncation - keep the last 200 words (the masked context)
                words = input_text.split()
                input_text = " ".join(words[-200:])
            
            masked_inputs.append(input_text)
            # Store the original sentence-level index for reference, but mask position will be calculated during tokenization
            mask_positions.append(index)
            contexts_for_targets.append(context)
            kept_targets.append((target, index))
    
    if not masked_inputs:
        return results
    
    # Batch tokenize
    MAX_LENGTH = max_length  # Use parameter for A100 optimization
    batch_inputs = []
    batch_mask_positions = []
    batch_targets = []
    batch_contexts = []
    
    for input_text, mask_pos, tgt_info, ctx in zip(masked_inputs, mask_positions, kept_targets, contexts_for_targets):
        # Use intelligent token slicing to ensure we stay within MAX_LENGTH
        try:
            input_ids, adjusted_mask_pos = _intelligent_token_slicing(
                input_text, tokenizer, max_length=MAX_LENGTH, mask_token_id=tokenizer.mask_token_id
            )
            
            if adjusted_mask_pos is not None:
                batch_inputs.append(input_ids)
                batch_mask_positions.append(adjusted_mask_pos)
                batch_targets.append(tgt_info)
                batch_contexts.append(ctx)
            else:
                # Mask token not found in sliced sequence, skip this input
                continue
                
        except Exception:
            # Fallback: simple truncation
            try:
                input_ids = tokenizer.encode(input_text, add_special_tokens=True)
                if len(input_ids) > MAX_LENGTH:
                    input_ids = input_ids[:MAX_LENGTH]
                
                masked_position = input_ids.index(tokenizer.mask_token_id)
                batch_inputs.append(input_ids)
                batch_mask_positions.append(masked_position)
                batch_targets.append(tgt_info)
                batch_contexts.append(ctx)
            except ValueError:
                # Mask token not found, skip this input
                continue
    
    if not batch_inputs:
        return results
    
    # Pad sequences to same length, but ensure we don't exceed MAX_LENGTH
    max_len = min(max(len(ids) for ids in batch_inputs), MAX_LENGTH)
    
    # Additional safety check: truncate any sequences that are still too long
    truncated_batch_inputs = []
    for input_ids in batch_inputs:
        if len(input_ids) > MAX_LENGTH:
            input_ids = input_ids[:MAX_LENGTH]
        truncated_batch_inputs.append(input_ids)
    
    padded_inputs = []
    attention_masks = []
    
    for input_ids in truncated_batch_inputs:
        attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
        padded_ids = input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
        padded_inputs.append(padded_ids)
        attention_masks.append(attention_mask)
    
    # Final safety check: ensure no sequence exceeds MAX_LENGTH
    for i, padded_ids in enumerate(padded_inputs):
        if len(padded_ids) > MAX_LENGTH:
            padded_inputs[i] = padded_ids[:MAX_LENGTH]
            attention_masks[i] = attention_masks[i][:MAX_LENGTH]
    
    # Batch inference - optimized for A100 with mixed precision
    with torch.no_grad():
        input_tensor = torch.tensor(padded_inputs, dtype=torch.long)
        attention_tensor = torch.tensor(attention_masks, dtype=torch.long)
        
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
            attention_tensor = attention_tensor.cuda()
            
            # Use mixed precision for A100 optimization
            with torch.amp.autocast('cuda'):
                outputs = lm_model(input_tensor, attention_mask=attention_tensor)
                logits = outputs.logits
        else:
            outputs = lm_model(input_tensor, attention_mask=attention_tensor)
            logits = outputs.logits
    
    # Process results - collect filtered candidates first.
    # Iterate over batch_targets so each row of `logits` stays aligned with its target,
    # even when some inputs were skipped during masking or tokenization.
    batch_filtered_results = {}
    for i, (target, index) in enumerate(batch_targets):
        mask_pos = batch_mask_positions[i]
        mask_logits = logits[i, mask_pos].squeeze()
        
        # Get top predictions
        top_tokens = torch.topk(mask_logits, k=Top_K, dim=0)[1]
        scores = torch.softmax(mask_logits, dim=0)[top_tokens].tolist()
        words = [tokenizer.decode(token.item()).strip() for token in top_tokens]
        
        # Filter candidates (without similarity scoring)
        filtered_candidates = _filter_candidates_batch(target, words, scores, tokens, index)
        if filtered_candidates:
            # Attach the exact context window used for this target
            batch_filtered_results[(sentence, target, index)] = {
                'filtered_words': filtered_candidates,
                'context': batch_contexts[i]
            }
    
    # Batch similarity scoring for all candidates
    if batch_filtered_results:
        similarity_results = _batch_similarity_scoring(batch_filtered_results, tokenizer)
        results.update(similarity_results)
    
    return results

def _filter_candidates_batch(target, words, scores, tokens, target_index):
    """
    Optimized batch filtering of candidates (no similarity scoring here - moved to batch level).
    """
    # Basic filtering
    filtered_words = []
    filtered_scores = []
    seen_words = set()
    
    for word, score in zip(words, scores):
        word_lower = word.lower()
        if word_lower in seen_words or word_lower == target.lower():
            continue
        seen_words.add(word_lower)
        
        if not is_valid_word(word):
            continue
        
        # Quick POS check
        target_nltk_pos, target_spacy_pos = get_word_pos_tags(target)
        cand_nltk_pos, cand_spacy_pos = get_word_pos_tags(word)
        
        if target_nltk_pos != cand_nltk_pos or target_spacy_pos != cand_spacy_pos:
            continue
        
        # Check antonyms (bidirectional and case-insensitive)
        antonyms = get_word_antonyms(target)
        if word.lower() in [ant.lower() for ant in antonyms]:
            continue
        
        # Also check if the candidate has the target as an antonym (reverse check)
        candidate_antonyms = get_word_antonyms(word)
        if target.lower() in [ant.lower() for ant in candidate_antonyms]:
            continue
        
        # Hardcoded common antonym pairs (for words not in WordNet or as additional safeguard)
        common_antonyms = {
            'big': ['small', 'tiny', 'little'],
            'small': ['big', 'large', 'huge'],
            'large': ['small', 'tiny', 'little'],
            'good': ['bad', 'evil', 'wrong'],
            'bad': ['good', 'great', 'excellent'],
            'high': ['low'],
            'low': ['high'],
            'new': ['old'],
            'old': ['new'],
            'fast': ['slow'],
            'slow': ['fast'],
            'rich': ['poor'],
            'poor': ['rich'],
            'hot': ['cold'],
            'cold': ['hot'],
            'happy': ['sad', 'unhappy'],
            'sad': ['happy', 'joyful'],
            'true': ['false', 'untrue'],
            'false': ['true'],
            'real': ['fake', 'unreal'],
            'fake': ['real'],
            'up': ['down'],
            'down': ['up'],
            'yes': ['no'],
            'no': ['yes'],
            'alive': ['dead'],
            'dead': ['alive'],
            'safe': ['unsafe', 'dangerous'],
            'dangerous': ['safe'],
            'clean': ['dirty'],
            'dirty': ['clean'],
            'full': ['empty'],
            'empty': ['full'],
            'open': ['closed', 'shut'],
            'closed': ['open'],
            'begin': ['end', 'finish'],
            'end': ['begin', 'start'],
            'start': ['end', 'finish'],
            'finish': ['start', 'begin'],
            'first': ['last'],
            'last': ['first']
        }
        
        # Check if word is a known antonym of target (case-insensitive)
        target_lower = target.lower()
        if target_lower in common_antonyms and word.lower() in common_antonyms[target_lower]:
            continue
        
        # Check if word and target are in the same specific noun category (e.g., crops, animals, companies)
        # If they are different specific terms in the same category, exclude the candidate
        if not _are_semantically_compatible(target, word):
            continue
        
        filtered_words.append(word)
        filtered_scores.append(score)
    
    if len(filtered_words) < 2:
        return None
    
    # Return filtered words without similarity scoring (done at batch level)
    return filtered_words

def _batch_similarity_scoring(batch_results, tokenizer):
    """
    Optimized batched similarity scoring across multiple sentences for full context.
    Processes all candidates from multiple sentences together for better efficiency.
    """
    # Collect all similarity scoring tasks
    similarity_tasks = []
    sentence_contexts = {}
    
    for (sentence, target, index), value in batch_results.items():
        if value is None:
            continue
        # Support both legacy list and new dict with context
        if isinstance(value, dict):
            filtered_words = value.get('filtered_words')
            context = value.get('context', sentence)
        else:
            filtered_words = value
            context = sentence
            
        # Tokenize the sentence with spaCy so `index` lines up with the indices produced
        # by extract_entities_and_pos (subword tokenization would misalign it)
        tokens = [tok.text for tok in nlp(sentence)]
        if index >= len(tokens):
            continue
            
        # Store sentence context for later use
        sentence_contexts[(sentence, target, index)] = {
            'tokens': tokens,
            'target_index': index,
            'filtered_words': filtered_words
        }
        
        # Create candidate sentences for this target
        for word in filtered_words:
            candidate_tokens = tokens.copy()
            candidate_tokens[index] = word
            candidate_sentence = " ".join(candidate_tokens)
            
            # Build full-context candidate by replacing the sentence inside the chosen context once
            candidate_full_context = context.replace(sentence, candidate_sentence, 1)
            similarity_tasks.append({
                'original_context': context,
                'candidate_full_context': candidate_full_context,
                'target_word': word,
                'context_key': (sentence, target, index)
            })
    
    if not similarity_tasks:
        return {}
    
    # Batch process all similarity scoring
    try:
        # Group by original full context for efficient BERTScore computation
        context_groups = {}
        for task in similarity_tasks:
            orig_ctx = task['original_context']
            if orig_ctx not in context_groups:
                context_groups[orig_ctx] = []
            context_groups[orig_ctx].append(task)
        
        # Process each context group
        final_results = {}
        for orig_context, tasks in context_groups.items():
            # Extract candidate full-contexts
            candidate_contexts = [task['candidate_full_context'] for task in tasks]
            
            # Batch BERTScore computation against the same full context
            try:
                similarity_scores = calc_scores_bert(orig_context, candidate_contexts)
            except Exception as e:
                # Fallback to neutral scores
                similarity_scores = [0.5] * len(candidate_contexts)
            
            if similarity_scores and not all(score == 0.5 for score in similarity_scores):
                # Group results by context key
                for task, score in zip(tasks, similarity_scores):
                    context_key = task['context_key']
                    if context_key not in final_results:
                        final_results[context_key] = []
                    final_results[context_key].append((task['target_word'], score))
        
        # Sort results by similarity score
        for context_key in final_results:
            final_results[context_key].sort(key=lambda x: x[1], reverse=True)
        
        return final_results
        
    except Exception as e:
        return {}

def parallel_tournament_sampling(target_results, secret_key, m, c, h, alpha):
    """
    Parallel tournament sampling for multiple targets.
    """
    results = {}
    
    if not target_results:
        return results
    
    def process_single_tournament(item):
        (sentence, target, index), candidates = item
        if not candidates:
            return (sentence, target, index), None
        
        alternatives = [alt[0] for alt in candidates]
        similarity = [alt[1] for alt in candidates]
        
        if not alternatives or not similarity:
            return (sentence, target, index), None
        
        # Get the left context; tokenize with spaCy so `index` (a spaCy token index)
        # points at the right position
        context_tokens = [tok.text for tok in nlp(sentence)]
        left_context = context_tokens[max(0, index - h):index]
        
        # Tournament selection
        from SynthID_randomization import tournament_select_word
        randomized_word = tournament_select_word(
            target, alternatives, similarity, 
            context=left_context, key=secret_key, m=m, c=c, alpha=alpha
        )
        
        return (sentence, target, index), randomized_word
    
    # Process in parallel
    max_workers = max(1, min(8, len(target_results)))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_item = {executor.submit(process_single_tournament, item): item for item in target_results.items()}
        
        for future in as_completed(future_to_item):
            key, result = future.result()
            results[key] = result
    
    return results

def whole_context_process_sentence(full_text, sentence, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, output_name, batch_size=32, max_length=512, max_context_tokens=400, similarity_context_mode='whole'):
    """
    Enhanced sentence processing using whole document context for better candidate generation.
    """
    replacements = []
    sampling_results = []
    doc = nlp(sentence)
    sentence_target_pairs = extract_entities_and_pos(sentence)

    if not sentence_target_pairs:
        return replacements, sampling_results

    # Filter valid target pairs
    valid_pairs = []
    spacy_tokens = [token.text for token in doc]
    
    for sent, target, position in sentence_target_pairs:
        if position < len(spacy_tokens) and spacy_tokens[position] == target:
            valid_pairs.append((sent, target, position))

    if not valid_pairs:
        return replacements, sampling_results

    # Enhanced MLM inference with whole document context
    batch_results = whole_context_mlm_inference(full_text, valid_pairs, tokenizer, lm_model, Top_K, batch_size, max_context_tokens, max_length, similarity_context_mode)
    
    # Filter by threshold (matching original logic)
    filtered_results = {}
    for key, candidates in batch_results.items():
        if candidates:
            # Apply threshold filtering (matching original logic)
            threshold_candidates = [(word, score) for word, score in candidates if score >= threshold]
            if len(threshold_candidates) >= 2:
                filtered_results[key] = threshold_candidates
    
    # Parallel tournament sampling
    tournament_results = parallel_tournament_sampling(filtered_results, secret_key, m, c, h, alpha)
    
    # Collect replacements and sampling results
    for (sent, target, position), randomized_word in tournament_results.items():
        if randomized_word:
            # Get the alternatives for this target from the filtered results
            alternatives = filtered_results.get((sent, target, position), [])
            alternatives_list = [alt[0] for alt in alternatives]
            # Include similarity scores for each alternative (preserve old 'alternatives' list for compatibility)
            alternatives_with_similarity = [
                {"word": alt[0], "similarity": float(alt[1])} for alt in alternatives
            ]
            
            # Track sampling results
            sampling_results.append({
                "word": target,
                "alternatives": alternatives_list,
                "alternatives_with_similarity": alternatives_with_similarity,
                "randomized_word": randomized_word
            })
            
            replacements.append((position, target, randomized_word))
    
    return replacements, sampling_results

# Legacy function for compatibility
def look_up(sentence, target, index, tokenizer, lm_model, Top_K=20, threshold=0.75):
    """
    Legacy single-target lookup function for compatibility.
    """
    results = batch_mlm_inference([(sentence, target, index)], tokenizer, lm_model, Top_K)
    return results.get((sentence, target, index), None)

def batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K=20):
    """
    Legacy batch MLM inference function for compatibility.
    """
    return whole_context_mlm_inference("", sentence_target_pairs, tokenizer, lm_model, Top_K)

def batch_look_up(sentence_target_pairs, tokenizer, lm_model, Top_K=20, threshold=0.75, max_workers=4):
    """
    Optimized batch lookup using the new batch MLM inference.
    """
    return batch_mlm_inference(sentence_target_pairs, tokenizer, lm_model, Top_K)
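
# Minimal end-to-end sketch (illustrative only): wires a Hugging Face masked LM into
# whole_context_process_sentence. The model name, secret key, and the m/c/h/alpha values
# below are placeholder assumptions, and SynthID_randomization must be importable for the
# tournament-selection step to run.
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    demo_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    demo_model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    if torch.cuda.is_available():
        demo_model = demo_model.cuda()
    demo_model.eval()

    demo_document = (
        "The U.S. economy grew 3.5 percent last year. "
        "Analysts expected slower growth in most sectors."
    )
    for demo_sentence in split_sentences(demo_document):
        demo_replacements, demo_sampling = whole_context_process_sentence(
            demo_document, demo_sentence, demo_tokenizer, demo_model,
            Top_K=20, threshold=0.75, secret_key="example-key",
            m=4, c=4, h=4, alpha=2.0, output_name="demo",
        )
        print(demo_sentence.strip(), demo_replacements)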