# File size: 98,769 Bytes
# 2f99d61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
# AraSpell — Arabic Spell Checker Pipeline
# Production-ready version

import re
import math
import logging
import torch
import os
from collections import Counter
from transformers import AutoTokenizer, EncoderDecoderModel
import Levenshtein
import jellyfish

# NOTE(review): basicConfig runs at import time, so importing this module
# configures the root logger (INFO level, "LEVEL: message" format) for the
# whole process; it is a no-op if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(message)s',
)
# Module logger used for the model download/load progress messages below.
logger = logging.getLogger(__name__)

# ═══════════════════════════════════════════════════════════════════════════════
# LOAD ARABERT SEQ2SEQ MODEL
# ═══════════════════════════════════════════════════════════════════════════════

from huggingface_hub import hf_hub_download

# Hugging Face Hub location of the fine-tuned AraSpell checkpoint.
MODEL_REPO = 'bayan10/AraSpell-Model'
MODEL_FILENAME = 'last_model.pt'

# hf_hub_download returns a local filesystem path to the file (fetching it
# from the Hub if needed).
logger.info("Downloading/loading model from Hugging Face: %s", MODEL_REPO)
try:
    MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
except Exception as e:
    # Chain the original exception so the underlying network/auth/404 cause
    # stays visible in the traceback instead of only its str() in the message.
    raise RuntimeError(f"Failed to download model from Hugging Face: {e}") from e

MODEL_NAME = 'aubmindlab/bert-base-arabertv02'

# Build a BERT2BERT encoder-decoder from AraBERT; the fine-tuned weights are
# loaded from the downloaded checkpoint (MODEL_PATH) further below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_NAME, MODEL_NAME)

# Decoding starts at [CLS], pads with [PAD] and stops at [SEP]; set on both the
# model config and the generation config so generate() picks them up.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.generation_config.max_length = 128
model.generation_config.decoder_start_token_id = tokenizer.cls_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config.eos_token_id = tokenizer.sep_token_id

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects from a file downloaded from MODEL_REPO, i.e. it trusts that repo to
# run code. Consider weights_only=True if the checkpoint only holds tensors
# and primitive values (e.g. 'model_state_dict', 'epoch') — confirm first.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

# strict=False silently skips parameter-name mismatches. Surface them so an
# architecture/checkpoint mismatch does not quietly leave layers untrained.
_load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if _load_result.missing_keys:
    logger.warning(
        "Checkpoint is missing %d model parameter(s); they keep their "
        "pre-checkpoint values. First few: %s",
        len(_load_result.missing_keys), _load_result.missing_keys[:5],
    )
if _load_result.unexpected_keys:
    logger.warning(
        "Checkpoint contains %d parameter(s) the model does not use; "
        "they were ignored. First few: %s",
        len(_load_result.unexpected_keys), _load_result.unexpected_keys[:5],
    )

model = model.to(device)
model.eval()  # inference only: disables dropout

logger.info("Model loaded on %s, epoch: %s", device, checkpoint.get('epoch', 'N/A'))

from enum import Enum
from typing import List, Tuple, Optional

# ─────────────────────────────────────────────────────────────────────────────
# ERROR TYPE ENUM
# ─────────────────────────────────────────────────────────────────────────────

class ErrorType(Enum):
    """Category of spelling problem detected in an input text.

    Every member's value is the lowercase form of its name.
    """

    # The same character repeated several times in a row.
    CHAR_REPETITION = "char_repetition"
    # Several words written together without spaces (overlong tokens).
    WORD_MERGE = "word_merge"
    # Letters from non-Arabic keyboard layouts used in place of Arabic ones.
    CHAR_SUBSTITUTION = "char_substitution"
    # Two or more of the above at once.
    MIXED = "mixed"
    # No error detected.
    CLEAN = "clean"

# ═══════════════════════════════════════════════════════════════════════════════
# POST PROCESSOR
# ═══════════════════════════════════════════════════════════════════════════════

class AraSpellPostProcessor:
    """Rule-based clean-up of Arabic spell-checker (model) output.

    Every method is a ``@staticmethod`` that takes and returns ``str``;
    ``full_postprocess`` chains the individual steps in pipeline order.
    """

    ARABIC_HARAKAT = 'ًٌٍَُِّْ'
    TATWEEL = 'ـ'
    # Presentation-form ligatures (lam-alef variants and the Allah ligature)
    # expanded into their plain letter sequences.
    NORMALIZER_MAP = {
        'ﻹ': 'لإ', 'ﻷ': 'لأ', 'ﻵ': 'لآ', 'ﻻ': 'لا', 'ﷲ': 'الله'
    }
    # Letters after which a word-final ه is a candidate for ة.
    ARABIC_CONSONANTS = set('بتثجحخدذرزسشصضطظعغفقكلمن')

    # Arabic-block code points EXCEPT the Arabic-Indic digits (U+0660-0669)
    # and Extended Arabic-Indic digits (U+06F0-06F9): a run of identical digits
    # is a legitimate number (١٠٠٠ must not collapse to ١٠).
    _ARABIC_REPEAT_RE = re.compile(r"([\u0600-\u065F\u066A-\u06EF\u06FA-\u06FF])\1{2,}")
    _LATIN_REPEAT_RE = re.compile(r"([a-zA-Z])\1+")

    # Alternative 1: a decimal point between two digits (group 1 unset, left
    # untouched so "3.5" stays intact). Alternative 2: a run of sentence
    # punctuation plus surrounding whitespace (so "..." stays one unit).
    _PUNCT_SPACING_RE = re.compile(r'(?<=\d)\.(?=\d)|\s*([،؛؟!.]+)\s*')

    # Standalone particles the model sometimes prepends to its output.
    _HALLUCINATED_PARTICLES = ('و', 'في')

    # Words ending like this are never touched by the ه -> ة fix (Allah-related).
    _PROTECTED_HA_ENDINGS = ('لله',)

    # Common function words that join_fragments must not glue to each other.
    _STANDALONE_WORDS = frozenset({
        'من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى', 'حتى', 'منذ', 'خلال', 
        'بعد', 'قبل', 'ب', 'ل', 'ك', 'و', 'أو', 'لا', 'ما', 'لم', 'لن',
        'هو', 'هي', 'هم', 'أن', 'إن', 'كل', 'كان', 'قد', 'قال', 'ذلك',
        'هذا', 'هذه', 'تلك', 'التي', 'الذي', 'التى', 'اللذي'
    })

    # --- Basic Normalization ---

    @staticmethod
    def remove_harakat(text: str) -> str:
        """Remove Arabic diacritics (tanween through sukun)."""
        return re.sub(r'[ً-ْ]', '', text)

    @staticmethod
    def remove_tatweel(text: str) -> str:
        """Remove the Arabic kashida/tatweel elongation character."""
        return text.replace(AraSpellPostProcessor.TATWEEL, '')

    @staticmethod
    def normalize_special_chars(text: str) -> str:
        """Expand ligatures listed in NORMALIZER_MAP into plain letters."""
        for old, new in AraSpellPostProcessor.NORMALIZER_MAP.items():
            text = text.replace(old, new)
        return text

    # --- Core Functions ---

    @staticmethod
    def unified_collapse_repeated(text: str) -> str:
        """Collapse runs of a repeated character to a single character.

        Arabic letters: 3+ consecutive -> 1 (Arabic digits are left alone).
        Latin letters: 2+ consecutive -> 1.

        NOTE(review): the Latin rule also collapses legitimate double letters
        in embedded English words (e.g. "book" -> "bok").
        """
        text = AraSpellPostProcessor._ARABIC_REPEAT_RE.sub(r"\1", text)
        text = AraSpellPostProcessor._LATIN_REPEAT_RE.sub(r"\1", text)
        return text

    @staticmethod
    def remove_duplicate_words(text: str) -> str:
        """Drop a word identical to the one before it: كتاب كتاب -> كتاب.

        Whitespace is re-joined with single spaces when there are 2+ words.
        """
        words = text.split()
        if len(words) < 2:
            return text
        kept = [w for k, w in enumerate(words) if k == 0 or w != words[k - 1]]
        return ' '.join(kept)

    @staticmethod
    def normalize_spaces(text: str) -> str:
        """Normalize whitespace and the spacing around punctuation.

        Steps: NBSP -> space and zero-width characters removed (done first
        so the collapse below also applies to them), runs of spaces collapsed,
        then each run of ، ؛ ؟ ! . is attached to the preceding word and
        followed by one space. A '.' between two digits (decimal point) is
        left untouched. The result is stripped.
        """
        text = text.replace('\u00A0', ' ')  # Non-breaking space
        text = text.replace('\u200B', '')   # Zero-width space
        text = text.replace('\u200C', '')   # Zero-width non-joiner
        text = text.replace('\u200D', '')   # Zero-width joiner

        text = re.sub(r' +', ' ', text).strip()

        def _respace(match):
            punct = match.group(1)
            # group(1) is None only for the protected decimal-point alternative.
            return match.group(0) if punct is None else punct + ' '

        text = AraSpellPostProcessor._PUNCT_SPACING_RE.sub(_respace, text)
        return text.strip()

    @staticmethod
    def remove_word_repetition_with_wa(text: str) -> str:
        """Collapse ``word و word`` (with a standalone و) into ``word``."""
        words = text.split()
        result = []
        i = 0
        while i < len(words):
            if i + 2 < len(words) and words[i] == words[i+2] and words[i+1] == 'و':
                result.append(words[i])
                i += 3
            else:
                result.append(words[i])
                i += 1
        return ' '.join(result)

    # --- Hamza & Ta Marbuta Handling ---

    @staticmethod
    def fix_hamza_conservative(text: str) -> str:
        """Replace a word-final أ or إ with bare ا (words of 3+ chars only).

        Hamza in the middle or start of a word is never touched.

        NOTE(review): some correct words end in أ (e.g. خطأ, مبدأ, بدأ);
        this rule rewrites them too.
        """
        result = []
        for word in text.split():
            if len(word) >= 3 and word[-1] in ('أ', 'إ'):
                word = word[:-1] + 'ا'
            result.append(word)
        return ' '.join(result)

    @staticmethod
    def fix_ha_ta_marbuta(text: str, vocab_manager=None) -> str:
        """Convert a word-final ه to ة when it is likely a ta marbuta.

        A final ه can be a ta marbuta (المدرسه -> المدرسة) or a possessive
        pronoun that must stay (تحقيقه = "his achievement"). Only words of 4+
        characters whose second-to-last letter is in ARABIC_CONSONANTS, and
        that do not end in لله, are considered.

        With ``vocab_manager`` (any object with ``is_iv(word) -> bool``):
        convert when the ة form is in-vocabulary; otherwise keep ه.
        Without it: always convert candidate words (pattern-based fallback).
        """
        result = []
        for word in text.split():
            if word.endswith(AraSpellPostProcessor._PROTECTED_HA_ENDINGS):
                result.append(word)
                continue

            if (len(word) >= 4 and word.endswith('ه')
                    and word[-2] in AraSpellPostProcessor.ARABIC_CONSONANTS):
                candidate_with_ta = word[:-1] + 'ة'
                if vocab_manager is None or not vocab_manager:
                    # Fallback: no vocabulary -> pattern-based conversion.
                    result.append(candidate_with_ta)
                    continue
                if vocab_manager.is_iv(candidate_with_ta):
                    result.append(candidate_with_ta)
                    continue
                # ه-form IV (possessive) or neither IV: keep the original,
                # which is safer than guessing.
                vocab_manager.is_iv(word)
            result.append(word)

        return ' '.join(result)

    # --- Hallucination Removal ---

    @staticmethod
    def remove_hallucinations(text: str) -> str:
        """Remove model artifacts: stray trailing 'و' and near-duplicate words.

        Two adjacent words are near-duplicates when they match after
        dropping a leading ال, mapping ة -> ه and أ/إ/آ -> ا; the form with
        the ال article is kept.
        """
        words = text.split()
        if not words:
            return text

        def normalize_word(w: str) -> str:
            """Comparison key: only a *leading* ال article is stripped.

            Stripping every internal "ال" (the old behavior) made unrelated
            words collide, e.g. طالب -> طب, so "طالب طب" lost a word.
            """
            if w.startswith('ال'):
                w = w[2:]
            w = w.replace('ة', 'ه')
            w = re.sub(r'[أإآ]', 'ا', w)
            return w

        result = []
        i = 0
        while i < len(words):
            word = words[i]

            # Trailing 'و' artifact glued after a word ending (الماضيةو -> الماضية).
            if len(word) > 4 and word.endswith('و'):
                if word[-2] in 'ةهاأإآء':
                    word = word[:-1]

            if i + 1 < len(words):
                next_word = words[i + 1]
                if normalize_word(word) == normalize_word(next_word):
                    # Prefer the variant carrying the definite article.
                    keep = next_word if next_word.startswith('ال') and not word.startswith('ال') else word
                    result.append(keep)
                    i += 2
                    continue

            result.append(word)
            i += 1

        return ' '.join(result)

    @staticmethod
    def remove_hallucinated_prefix(text: str, original: str) -> str:
        """Remove a standalone و/في the model added before the original text.

        The particle is removed only if ``original`` does not itself start
        with it AND the remaining text equals ``original`` after ligature
        normalization; otherwise ``text`` is returned unchanged.
        """
        if not original:
            return text

        normalize = AraSpellPostProcessor.normalize_special_chars
        for particle in AraSpellPostProcessor._HALLUCINATED_PARTICLES:
            if text.startswith(particle + ' ') and not original.startswith(particle):
                rest = text[len(particle) + 1:].strip()
                if normalize(rest) == normalize(original):
                    return rest

        return text

    # --- Word Splitting & Merging ---

    @staticmethod
    def merge_separated_al(text: str) -> str:
        """Merge a standalone 'ال' into the next word: ال كتاب -> الكتاب."""
        return re.sub(r'\bال\s+(\w+)', r'ال\1', text)

    @staticmethod
    def join_fragments(text: str) -> str:
        """Join adjacent fragments of a split word, e.g. الطا لب -> الطالب.

        A pair is never merged when both words are in _STANDALONE_WORDS.
        Otherwise, in priority order:
          1. next word is a single character -> concatenate;
          2. both 2+ chars and last char == next's first char -> overlap-merge;
          3. 2-4 chars + 1-2 chars (total 3-7) -> concatenate.

        NOTE(review): these heuristics are aggressive (e.g. a content word
        followed by a particle like في is merged); not used by full_postprocess.
        """
        words = text.split()
        if len(words) < 2:
            return text

        standalone = AraSpellPostProcessor._STANDALONE_WORDS
        result = []
        i = 0
        while i < len(words):
            word = words[i]
            if i + 1 < len(words):
                next_word = words[i + 1]
                if not (word in standalone and next_word in standalone):
                    merged = None
                    if len(next_word) == 1:
                        merged = word + next_word
                    elif len(word) >= 2 and len(next_word) >= 2 and word[-1] == next_word[0]:
                        merged = word[:-1] + next_word
                    elif (2 <= len(word) <= 4 and
                          1 <= len(next_word) <= 2 and
                          3 <= len(word) + len(next_word) <= 7):
                        merged = word + next_word
                    if merged is not None:
                        result.append(merged)
                        i += 2
                        continue
            result.append(word)
            i += 1

        return ' '.join(result)

    # --- Main Pipelines ---

    @staticmethod
    def full_postprocess(text: str, original: str = "", vocab_manager=None) -> str:
        """Apply all post-processing steps in pipeline order.

        Args:
            text: model output to clean.
            original: the input sentence; enables hallucinated-prefix removal.
            vocab_manager: optional object with ``is_iv(word)``; enables the
                vocabulary-aware ه/ة decision in ``fix_ha_ta_marbuta``.
        """
        # 1. Remove particles the model prepended (needs the original input).
        if original:
            text = AraSpellPostProcessor.remove_hallucinated_prefix(text, original)

        # 2. Ligature normalization.
        text = AraSpellPostProcessor.normalize_special_chars(text)

        # 3. Trailing-و artifacts and near-duplicate words.
        text = AraSpellPostProcessor.remove_hallucinations(text)

        # 4. Collapse repeated characters.
        text = AraSpellPostProcessor.unified_collapse_repeated(text)

        # 5. Word-final hamza.
        text = AraSpellPostProcessor.fix_hamza_conservative(text)

        # 6. Word-final ه -> ة.
        text = AraSpellPostProcessor.fix_ha_ta_marbuta(text, vocab_manager=vocab_manager)

        # 7. "word و word" -> "word".
        text = AraSpellPostProcessor.remove_word_repetition_with_wa(text)

        # 8. Exact consecutive duplicates.
        text = AraSpellPostProcessor.remove_duplicate_words(text)

        # 9. Final whitespace/punctuation spacing.
        text = AraSpellPostProcessor.normalize_spaces(text)

        return text


# ─────────────────────────────────────────────────────────────────────────────
# ERROR CLASSIFIER
# ─────────────────────────────────────────────────────────────────────────────

class ErrorClassifier:
    """Heuristically classify which kind of spelling error a text contains."""

    # Letters from non-Arabic keyboard layouts (Persian/Urdu and related).
    # Includes every key of RulesBasedCorrector.SUBSTITUTION_MAP ('ړ' was
    # missing here, so that letter was never flagged as a substitution).
    NON_ARABIC_KEYBOARD = set('پگچژکەڕڤڵڎےۀۃھیټډڼڑړ')

    @staticmethod
    def has_char_substitution(text: str) -> bool:
        """Return True if *text* contains any letter from NON_ARABIC_KEYBOARD."""
        return any(c in ErrorClassifier.NON_ARABIC_KEYBOARD for c in text)

    @staticmethod
    def has_char_repetition(text: str, threshold: int = 3) -> bool:
        """Return True if some character repeats *threshold*+ times in a row.

        Digits and whitespace are ignored: "1000" / "١٠٠٠" are valid numbers
        and runs of spaces are not spelling errors.
        """
        return bool(re.search(r"([^\d\s])\1{" + str(threshold - 1) + ",}", text))

    @staticmethod
    def has_word_merge(text: str, max_word_len: int = 8) -> bool:
        """Return True if the text looks like words written without spaces.

        That is: any token longer than *max_word_len*, or a single token
        longer than 6 characters (measured on the token, not on the raw text,
        so surrounding whitespace does not count).
        """
        words = text.split()
        if any(len(w) > max_word_len for w in words):
            return True
        return len(words) == 1 and len(words[0]) > 6

    @staticmethod
    def classify(text: str) -> ErrorType:
        """Return MIXED when 2+ error kinds are present, else the single kind.

        Single-kind priority: substitution, then repetition, then merge;
        CLEAN when none is detected.
        """
        has_rep = ErrorClassifier.has_char_repetition(text)
        has_merge = ErrorClassifier.has_word_merge(text)
        has_sub = ErrorClassifier.has_char_substitution(text)

        error_count = sum([has_rep, has_merge, has_sub])

        if error_count >= 2:
            return ErrorType.MIXED
        elif has_sub:
            return ErrorType.CHAR_SUBSTITUTION
        elif has_rep:
            return ErrorType.CHAR_REPETITION
        elif has_merge:
            return ErrorType.WORD_MERGE
        else:
            return ErrorType.CLEAN

# ═══════════════════════════════════════════════════════════════════════════════
# RULES-BASED CORRECTOR
# ═══════════════════════════════════════════════════════════════════════════════

class RulesBasedCorrector:
    """Rules-based correction with keyboard proximity mapping.

    Provides Persian/Urdu → Arabic character substitution, collapsing of
    runs of repeated characters, an Arabic keyboard-adjacency table, and a
    start-anchored heuristic splitter for preposition+word merges.
    """

    # Persian/Urdu → Arabic mapping (single character → single character).
    SUBSTITUTION_MAP = {
        'ک': 'ك', 'ی': 'ي', 'ے': 'ي',
        'پ': 'ب', 'چ': 'ج', 'ژ': 'ز',
        'گ': 'ك', 'ڤ': 'ف', 'ڵ': 'ل',
        'ڕ': 'ر', 'ڎ': 'د', 'ڼ': 'ن',
        'ټ': 'ت', 'ډ': 'د', 'ړ': 'ر',
        'ۀ': 'ه', 'ۃ': 'ة', 'ھ': 'ه',
        'ە': 'ه', 'ڑ': 'ر'
    }

    # One-pass translation table built from SUBSTITUTION_MAP. No replacement
    # value is itself a key, so this is equivalent to calling str.replace
    # once per entry, but it walks the text only once.
    _SUBSTITUTION_TABLE = str.maketrans(SUBSTITUTION_MAP)

    # 16 prepositions / prepositional prefixes.
    PREPOSITIONS = {
        'من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى',
        'حتى', 'منذ', 'خلال', 'بعد', 'قبل',
        'ب', 'ل', 'ك',
        'لل'
    }

    # Multi-letter prepositions that _recursive_split may detach from the
    # start of a word, longest first so that e.g. 'منذ' is tried before 'من'.
    # Computed once here instead of on every recursive call.
    _SEPARABLES = tuple(sorted(
        ['من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى', 'حتى', 'منذ', 'خلال', 'بعد', 'قبل'],
        key=len, reverse=True,
    ))

    # Keyboard Proximity Mapping
    # Arabic keyboard layout adjacency
    KEYBOARD_NEIGHBORS = {
        'ض': ['ص', 'ق'],
        'ص': ['ض', 'ث', 'ق'],
        'ث': ['ص', 'ق'],
        'ق': ['ض', 'ص', 'ث', 'ف', 'غ'],
        'ف': ['ق', 'غ', 'ع', 'ب'],
        'غ': ['ق', 'ف', 'ع', 'ه'],
        'ع': ['ف', 'غ', 'ه', 'خ'],
        'ه': ['غ', 'ع', 'خ', 'ح'],
        'خ': ['ع', 'ه', 'ح', 'ج'],
        'ح': ['ه', 'خ', 'ج'],
        'ج': ['خ', 'ح', 'د'],
        'د': ['ج', 'ذ'],
        'ذ': ['د'],
        'ش': ['س', 'ي', 'ئ'],
        'س': ['ش', 'ي', 'ب'],
        'ي': ['ش', 'س', 'ب', 'ت'],
        'ب': ['ي', 'س', 'ف', 'ل', 'ن'],
        'ل': ['ب', 'ا', 'ن', 'م'],
        'ا': ['ل', 'ت', 'م'],
        'ت': ['ي', 'ا', 'ن'],
        'ن': ['ب', 'ل', 'ت', 'م', 'ك'],
        'م': ['ل', 'ا', 'ن', 'ك'],
        'ك': ['ن', 'م', 'ط'],
        'ط': ['ك', 'ظ'],
        'ظ': ['ط'],
        'ئ': ['ش', 'ء', 'ر'],
        'ء': ['ئ', 'ؤ'],
        'ؤ': ['ء', 'ر'],
        'ر': ['ئ', 'ؤ', 'لا', 'ى', 'ز'],
        'لا': ['ر', 'ى'],
        'ى': ['ر', 'لا', 'ة', 'ز'],
        'ة': ['ى', 'و', 'ز'],
        'و': ['ة', 'ز'],
        'ز': ['ر', 'ى', 'ة', 'و'],
        # Alif variants
        'أ': ['ا', 'إ', 'آ'],
        'إ': ['ا', 'أ'],
        'آ': ['ا', 'أ'],
    }

    @staticmethod
    def is_keyboard_neighbor(char1: str, char2: str) -> bool:
        """Return True if ``char2`` is listed as a keyboard neighbour of ``char1``.

        The table is directional as written; unknown keys have no neighbours.
        """
        return char2 in RulesBasedCorrector.KEYBOARD_NEIGHBORS.get(char1, ())

    @staticmethod
    def fix_char_substitution(text: str) -> str:
        """Replace Persian/Urdu characters with their Arabic equivalents."""
        return text.translate(RulesBasedCorrector._SUBSTITUTION_TABLE)

    @staticmethod
    def fix_char_repetition(text: str) -> str:
        """Collapse runs of 3+ identical characters to one.

        Digits and whitespace are left untouched, and doubled letters are
        preserved.
        """
        return re.sub(r'([^\d\s])\1{2,}', r'\1', text)

    @staticmethod
    def advanced_heuristic_repair(text: str) -> str:
        """Apply aggressive heuristic repairs to build a strong baseline candidate.

        Steps:
            1. Character fixes (Persian/Urdu substitution, then repetition).
            2. Start-anchored recursive word splitting on every token.

        The output is whitespace-normalised (tokens re-joined with single spaces).
        """
        text = RulesBasedCorrector.fix_char_substitution(text)
        text = RulesBasedCorrector.fix_char_repetition(text)
        return ' '.join(RulesBasedCorrector._recursive_split(w) for w in text.split())

    @staticmethod
    def _recursive_split(word: str) -> str:
        """Recursively split merged words, anchored to the start of the word.

        Only a leading preposition (or the vocative 'يا') is ever detached, so
        'المنزل' is never split in the middle as 'ال من زل'.

        NOTE: any word that starts with a listed preposition and has at least
        three remaining letters is split, so genuine words can be split too
        (e.g. 'منطقة' → 'من طقة'). Treat the result as a candidate, not a
        final answer.
        """
        if len(word) < 4:
            return word

        # 1. Separable prepositions at the START: "فيالبيت" -> "في البيت".
        for sep in RulesBasedCorrector._SEPARABLES:
            if word == sep:
                return word
            if word.startswith(sep):
                remainder = word[len(sep):]
                # Require a substantial remainder to limit spurious splits.
                if len(remainder) >= 3:
                    return sep + " " + RulesBasedCorrector._recursive_split(remainder)

        # 2. Vocative particle merged with a name: "يامحمد" -> "يا محمد".
        if word.startswith('يا') and len(word) > 4:
            return 'يا ' + RulesBasedCorrector._recursive_split(word[2:])

        # 3. Attached particles such as 'و' / 'ف' are normally written joined
        #    to the next word, so they are deliberately not split here.
        return word
    



# ═══════════════════════════════════════════════════════════════════════════════
# OUTPUT VALIDATOR (Hallucination Prevention)
# ═══════════════════════════════════════════════════════════════════════════════

class OutputValidator:
    """Validate model outputs to prevent hallucinations.

    Every check returns ``(is_valid, reason)``. On success ``reason`` is
    ``"valid"`` (or ``"space_leniency_accept"``); on failure it is a short
    machine-readable code.
    """

    # Zero-width non-joiner: invisible, so it is ignored together with
    # whitespace when deciding whether two texts differ only in spacing.
    _ZWNJ = '\u200c'

    @staticmethod
    def calculate_edit_distance(s1: str, s2: str) -> int:
        """Return the Levenshtein edit distance between ``s1`` and ``s2``."""
        return Levenshtein.distance(s1, s2)

    @staticmethod
    def _strip_spacing(text: str) -> str:
        """Return ``text`` with all whitespace (any kind) and ZWNJ removed."""
        return ''.join(text.replace(OutputValidator._ZWNJ, '').split())

    @staticmethod
    def check_character_preservation(original: str, corrected: str) -> Tuple[bool, str]:
        """Check that the character inventories overlap enough.

        The overlap is the Jaccard similarity of the two character sets. Below
        0.35 the result is ``(False, "low_character_similarity")``. An empty
        ``original`` is always accepted.
        """
        chars_original = set(original)
        chars_corrected = set(corrected)

        if not chars_original:
            return True, "valid"

        union = chars_original | chars_corrected
        jaccard = len(chars_original & chars_corrected) / len(union) if union else 0

        if jaccard < 0.35:
            return False, "low_character_similarity"
        return True, "valid"

    @staticmethod
    def check_word_count(original: str, corrected: str) -> Tuple[bool, str]:
        """Check that the word count changed by a plausible amount.

        A single input word may expand to up to 3 words, or up to 6 when the
        input is longer than 12 characters, to allow splitting merged words.
        Otherwise the corrected/original word-count ratio must lie in [0.5, 2.0].
        """
        len_orig = len(original.split())
        len_corr = len(corrected.split())

        # Allow expanding 1 word to up to 3 (e.g. "فيالمدرسة" -> "في المدرسة").
        if len_orig == 1:
            if len_corr <= 3:
                return True, "valid"
            # Very long inputs may hide several words (e.g. "هذاالولدذهبإلىالمدرسة").
            if len(original) > 12 and len_corr <= 6:
                return True, "valid"

        ratio = len_corr / len_orig if len_orig > 0 else 0
        if ratio > 2.0 or ratio < 0.5:
            return False, "word_count_mismatch"
        return True, "valid"

    def validate(self, original: str, corrected: str, error_type: str) -> Tuple[bool, str]:
        """Run all checks in order and return the first failure, else success.

        Order: empty output, whitespace-only change (accepted immediately),
        length ratio, word count, character preservation.

        NOTE(review): ``error_type`` is annotated ``str`` but is compared with
        ``ErrorType.CHAR_REPETITION``. If ErrorType is a plain Enum, a string
        argument never matches. Confirm callers pass the enum member.
        """
        # 0. Sanity check
        if not corrected or not corrected.strip():
            return False, "empty_output"

        # Space leniency: if the ONLY difference is whitespace (spaces, tabs,
        # newlines, ...) or ZWNJ, accept immediately.
        if self._strip_spacing(original) == self._strip_spacing(corrected):
            return True, "space_leniency_accept"

        # 1. Length ratio check (characters)
        len_orig = len(original)
        len_corr = len(corrected)

        # Allow expansion for word splitting.
        if len_corr > len_orig * 2.5:
            return False, "too_long"

        # Shrinking below 50% is only plausible when removing repetition.
        if len_corr < len_orig * 0.5 and error_type != ErrorType.CHAR_REPETITION:
            return False, "too_short"

        # 2. Word count
        is_valid_count, reason = self.check_word_count(original, corrected)
        if not is_valid_count:
            return False, reason

        # 3. Character preservation: critical for catching hallucinated
        #    rewrites (e.g. "كتاب" -> "مكتبة" with little character overlap).
        is_valid_chars, reason = self.check_character_preservation(original, corrected)
        if not is_valid_chars:
            return False, reason

        return True, "valid"

# ═══════════════════════════════════════════════════════════════════════════════
# VOCABULARY MANAGER
# ═══════════════════════════════════════════════════════════════════════════════

class VocabularyManager:
    """
    Centralized vocabulary management for OOV/IV detection.

    Key for vocabulary-aware acceptance: OOV→IV = accept, IV→OOV = reject.
    """

    # Arabic character equivalence for normalization
    HAMZA_VARIANTS = {'أ', 'إ', 'آ', 'ء', 'ؤ', 'ئ', 'ا'}
    ALEF_NORMALIZED = 'ا'
    TA_MARBUTA = 'ة'
    HA = 'ه'
    YA_VARIANTS = {'ي', 'ى'}
    YA_NORMALIZED = 'ي'

    # Sentinel rank returned for words not found in the vocabulary.
    _OOV_RANK = 999999

    def __init__(self, tokenizer):
        """Build word sets and rank tables from ``tokenizer.get_vocab()``."""
        self.tokenizer = tokenizer

        # Fetch the token -> id mapping once and reuse it for both tables.
        token_to_id = tokenizer.get_vocab()

        # Whole alphabetic words only: no '##' subword pieces, no 1-char tokens.
        self.vocab = {
            w for w in token_to_id
            if w.isalpha() and not w.startswith('##') and len(w) > 1
        }

        # Token id used as a frequency proxy: lower id = more common (usually).
        self.vocab_rank = dict(token_to_id)

        # normalized form -> one representative vocabulary word
        self.normalized_vocab = {self.normalize_for_comparison(w): w for w in self.vocab}

        logger.info(f"VocabularyManager initialized: {len(self.vocab)} words")

    @classmethod
    def normalize_for_comparison(cls, word: str) -> str:
        """Normalize an Arabic word for equivalence checks (not for output).

        Every hamza/alef variant becomes 'ا', a word-final 'ة' becomes 'ه',
        and 'ى' becomes 'ي'.
        """
        last = len(word) - 1
        result = []
        for i, char in enumerate(word):
            if char in cls.HAMZA_VARIANTS:
                result.append(cls.ALEF_NORMALIZED)
            elif char == cls.TA_MARBUTA and i == last:
                result.append(cls.HA)
            elif char in cls.YA_VARIANTS:
                result.append(cls.YA_NORMALIZED)
            else:
                result.append(char)
        return ''.join(result)

    @staticmethod
    def _clean(word: str) -> str:
        """Strip every non-word character (punctuation, diacritics, ...)."""
        return re.sub(r'[^\w]', '', word)

    def is_iv(self, word: str) -> bool:
        """Return True if ``word`` is In-Vocabulary, directly or after normalization.

        Tokens that are empty after cleaning (pure punctuation) count as IV.
        """
        clean = self._clean(word)
        if not clean:
            return True
        if clean in self.vocab:
            return True
        return self.normalize_for_comparison(clean) in self.normalized_vocab

    def is_oov(self, word: str) -> bool:
        """Return True if ``word`` is Out-Of-Vocabulary."""
        return not self.is_iv(word)

    def get_frequency_rank(self, word: str) -> int:
        """Return the frequency rank of ``word`` (lower = more common).

        If the cleaned word is not a token itself but matches a vocabulary
        word after normalization (the same rule ``is_iv`` uses), that word's
        rank is returned. Words that are OOV get 999999.
        """
        clean = self._clean(word)
        rank = self.vocab_rank.get(clean)
        if rank is not None:
            return rank
        canonical = self.normalized_vocab.get(self.normalize_for_comparison(clean))
        if canonical is not None:
            return self.vocab_rank.get(canonical, self._OOV_RANK)
        return self._OOV_RANK

    def all_words_iv(self, text: str) -> bool:
        """Return True if every whitespace-separated word in ``text`` is IV."""
        return all(self.is_iv(w) for w in text.split())

    def count_oov_words(self, text: str) -> int:
        """Return the number of OOV words in ``text``."""
        return len(self.get_oov_words(text))

    def get_oov_words(self, text: str) -> List[str]:
        """Return the OOV words of ``text`` in order of appearance."""
        return [w for w in text.split() if self.is_oov(w)]

    def words_are_equivalent(self, word1: str, word2: str) -> bool:
        """Return True if the words differ only in normalized characters.

        Useful for accepting corrections that only touch hamza/ta marbuta/ya.
        """
        return self.normalize_for_comparison(word1) == self.normalize_for_comparison(word2)

    @staticmethod
    def damerau_levenshtein_distance(s1: str, s2: str) -> int:
        """Return the Damerau-Levenshtein distance (a transposition counts as 1 edit).

        This suits Arabic typos like اقصتاديا→اقتصاديا (swap صت→تص).
        """
        return jellyfish.damerau_levenshtein_distance(s1, s2)

    def calculate_similarity(self, original: str, corrected: str) -> float:
        """Return 1 - distance / max_len, i.e. 1.0 for identical strings."""
        dist = self.damerau_levenshtein_distance(original, corrected)
        max_len = max(len(original), len(corrected), 1)
        return 1.0 - (dist / max_len)

# ═══════════════════════════════════════════════════════════════════════════════
# WORD ALIGNER
# ═══════════════════════════════════════════════════════════════════════════════

class WordAligner:
    """
    Aligns input and output words to create hybrid corrections.

    Helps when the model fixes one word but breaks another (Raw Wins/Both Wrong cause).
    """

    def __init__(self, vocab_manager):
        """Store a vocabulary manager exposing ``is_iv`` and ``count_oov_words``."""
        self.vocab = vocab_manager

    def align_words(self, input_text: str, output_text: str) -> str:
        """Create a hybrid text by choosing the better version of each aligned span.

        If the word counts differ by more than 2, the whole text with fewer
        OOV words is returned (ties keep the input). Otherwise the word lists
        are aligned with difflib.SequenceMatcher instead of by position, so a
        split or merge does not shift every later word. Aligned spans are
        handled as follows:
            * equal:   kept as-is.
            * insert:  the model's added words are kept.
            * delete:  dropped input words are kept only if IV.
            * replace: 1:1 spans are resolved word by word with
              ``_select_best_word``; n:m spans (splits/merges) take the side
              with fewer OOV words (ties keep the input).
        """
        from difflib import SequenceMatcher

        input_words = input_text.split()
        output_words = output_text.split()

        # Large length differences make alignment unreliable: decide globally.
        if abs(len(input_words) - len(output_words)) > 2:
            input_oov = self.vocab.count_oov_words(input_text)
            output_oov = self.vocab.count_oov_words(output_text)
            return output_text if output_oov < input_oov else input_text

        matcher = SequenceMatcher(None, input_words, output_words, autojunk=False)
        result = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            in_seg = input_words[i1:i2]
            out_seg = output_words[j1:j2]
            if tag == 'equal':
                result.extend(in_seg)
            elif tag == 'insert':
                result.extend(out_seg)
            elif tag == 'delete':
                result.extend(w for w in in_seg if self.vocab.is_iv(w))
            elif len(in_seg) == len(out_seg):
                result.extend(
                    self._select_best_word(in_w, out_w)
                    for in_w, out_w in zip(in_seg, out_seg)
                )
            else:
                result.extend(self._select_best_segment(in_seg, out_seg))

        return ' '.join(result)

    def _select_best_segment(self, input_seg, output_seg):
        """Pick between an n:m replaced span: fewer OOV words wins, ties keep input."""
        in_oov = sum(1 for w in input_seg if not self.vocab.is_iv(w))
        out_oov = sum(1 for w in output_seg if not self.vocab.is_iv(w))
        return output_seg if out_oov < in_oov else input_seg

    def _select_best_word(self, input_word: str, output_word: str) -> str:
        """Select the better of an input word and its aligned output word.

        Logic:
            1. Input OOV + Output IV -> take output (model fixed it).
            2. Input IV + Output OOV -> keep input (model broke it).
            3. Both IV -> keep input (conservative; avoids semantic drift).
            4. Both OOV -> if the words have equal length (>= 3), try swapping
               one differing character between them to reach an IV word;
               otherwise take the output (usually closer to the target).
        """
        if input_word == output_word:
            return input_word

        in_iv = self.vocab.is_iv(input_word)
        out_iv = self.vocab.is_iv(output_word)

        if out_iv and not in_iv:
            return output_word
        if in_iv:
            # Covers both "IV -> OOV" and "IV -> IV" (conservative).
            return input_word

        # Both OOV: character-level blending.
        if len(input_word) == len(output_word) and len(input_word) >= 3:
            for i, (in_ch, out_ch) in enumerate(zip(input_word, output_word)):
                if in_ch == out_ch:
                    continue
                hybrid = input_word[:i] + out_ch + input_word[i + 1:]
                if self.vocab.is_iv(hybrid):
                    return hybrid
                hybrid2 = output_word[:i] + in_ch + output_word[i + 1:]
                if self.vocab.is_iv(hybrid2):
                    return hybrid2

        return output_word

# ═══════════════════════════════════════════════════════════════════════════════
# SPLIT/MERGE SPECIALIST
# ═══════════════════════════════════════════════════════════════════════════════

class SplitMergeSpecialist:
    """
    Handles word splitting and merging with vocabulary validation.

    Key patterns:

    1. SPLIT: OOV word that can be split into two IV words
       - فيالغالب → في الغالب
       - يقعبجماعة → يقع بجماعة

    2. MERGE: Adjacent OOV fragments that can merge to IV
       - السوري ة → السورية (ta-marbuta attachment)
       - ال كتاب → الكتاب
    """

    # Common Arabic prefixes that can be detached
    SEPARABLE_PREFIXES = [
        # Prepositions (sorted longest-first in __init__ for greedy matching)
        'من', 'في', 'على', 'عن', 'مع', 'إلى', 'الى', 'حتى', 'منذ', 'خلال',
        'بعد', 'قبل', 'بين', 'حول', 'تحت', 'فوق', 'أمام', 'وراء', 'دون',
        # Particles
        'أن', 'لن', 'لم', 'قد', 'سوف', 'كي', 'إذا', 'لو', 'مثل', 'غير',
        # Call particle
        'يا',
    ]

    # Protected short words that shouldn't be split
    PROTECTED_WORDS = {
        'في', 'من', 'على', 'عن', 'مع', 'إلى', 'الى', 'ان', 'أن', 'لا', 'ما', 'هو', 'هي',
        'لم', 'لن', 'قد', 'كل', 'كان', 'ذلك', 'هذا', 'هذه', 'التي', 'الذي', 'بين',
    }

    # Attached prefix patterns that should NOT be split (normal Arabic word formations)
    ATTACHED_PREFIXES = [
        'وال', 'بال', 'فال', 'كال', 'لل',   # Conjunction/Preposition + Article
        'وب', 'وف', 'ول', 'وك', 'وم', 'ون',  # Conjunction + Preposition
        'فب', 'فل', 'فك', 'فم',              # Conjunction + Preposition
    ]

    # Prefixes that contain the definite article, so the stem may appear in
    # the vocabulary only in its 'ال' form ('لل' = ل + ال per the list above).
    _ARTICLE_PREFIXES = {'وال', 'بال', 'فال', 'كال', 'لل'}

    # Common Arabic pronoun/possessive suffixes (2-3 chars)
    # These are often incorrectly split from their host word
    PRONOUN_SUFFIXES = {'كم', 'هم', 'ها', 'هن', 'كن', 'نا', 'هما', 'كما', 'تم', 'تن'}

    def __init__(self, vocab_manager):
        """Store a vocabulary manager exposing ``is_iv`` and ``is_oov``."""
        self.vocab = vocab_manager
        self.separable_prefixes = sorted(
            self.SEPARABLE_PREFIXES, key=len, reverse=True
        )

    def _is_prefixed_word(self, word: str) -> bool:
        """Return True if ``word`` is an attached prefix plus a valid stem.

        The stem is accepted if it is IV as-is or, for article-bearing
        prefixes, if 'ال' + stem is IV (e.g. والكتاب → check الكتاب).
        """
        for prefix in self.ATTACHED_PREFIXES:
            if not word.startswith(prefix):
                continue
            remainder = word[len(prefix):]
            if self.vocab.is_iv(remainder):
                return True
            if prefix in self._ARTICLE_PREFIXES and self.vocab.is_iv('ال' + remainder):
                return True
        return False

    def split_word(self, word: str) -> str:
        """Try to split an OOV word into IV components.

        Strict strategy:
            - Only split when BOTH parts are IV.
            - Protect attached prefix patterns (وال، بال، etc.).
            - Minimum part lengths to prevent micro-splits.

        Returns the word unchanged when no safe split is found.
        """
        # Short words: don't split.
        if len(word) < 5:
            return word

        # Already IV: no need to split.
        if self.vocab.is_iv(word):
            return word

        # Protected words (all current entries are shorter than 5 chars, so
        # this is a safeguard for future additions).
        if word in self.PROTECTED_WORDS:
            return word

        # Prefix+valid stem (وال، بال، فال، ...) is normal Arabic, not a merge.
        if self._is_prefixed_word(word):
            return word

        # 1. Separable prefixes first (higher priority); remainder > 2 chars.
        for prefix in self.separable_prefixes:
            if word.startswith(prefix) and len(word) > len(prefix) + 2:
                remainder = word[len(prefix):]
                if self.vocab.is_iv(remainder):
                    return f"{prefix} {remainder}"

        # 2. Any position - STRICT: both parts IV and both >= 3 chars.
        for i in range(3, len(word) - 2):
            left, right = word[:i], word[i:]
            if self.vocab.is_iv(left) and self.vocab.is_iv(right):
                return f"{left} {right}"

        return word

    def merge_fragments(self, text: str) -> str:
        """Merge adjacent fragments into IV words, scanning left to right.

        Patterns (the first that applies to a pair wins):
            1. Detached suffix letter (ة، ه، ا، ي): السوري ة → السورية
            2. Detached article: ال كتاب → الكتاب
            3. OOV + OOV whose concatenation is IV
            4. Short (1-2 char) OOV fragment + next word → IV
            5. Pronoun suffix reattachment onto an OOV host: X كم → Xكم
            6. Two short words (≤ 3 chars) forming an IV word of ≥ 5 chars:
               علي كم → عليكم, ويت أمل → ويتأمل

        Merged words are not merged again with the following word.
        """
        words = text.split()
        if len(words) < 2:
            return text

        result = []
        i = 0

        while i < len(words):
            word = words[i]

            if i + 1 < len(words):
                next_word = words[i + 1]
                merged = word + next_word

                # Pattern 1: a lone suffix letter is always a detachment error,
                # so merging is allowed even when 'word' is itself IV.
                if len(next_word) == 1 and next_word in 'ةهاي':
                    if self.vocab.is_iv(merged):
                        result.append(merged)
                        i += 2
                        continue

                # Pattern 2: Detached 'Al-' prefix (safe to merge).
                if word == 'ال' and len(next_word) >= 2:
                    if self.vocab.is_iv(merged):
                        result.append(merged)
                        i += 2
                        continue

                # Pattern 3: STRICT - both must be OOV to avoid merging valid words.
                if self.vocab.is_oov(word) and self.vocab.is_oov(next_word):
                    if self.vocab.is_iv(merged):
                        result.append(merged)
                        i += 2
                        continue

                # Pattern 4: Short OOV fragment (1-2 chars).
                if len(word) <= 2 and self.vocab.is_oov(word):
                    if self.vocab.is_iv(merged):
                        result.append(merged)
                        i += 2
                        continue

                # Pattern 5: Pronoun suffix reattachment onto an OOV host.
                if next_word in self.PRONOUN_SUFFIXES:
                    if self.vocab.is_iv(merged) and not self.vocab.is_iv(word):
                        result.append(merged)
                        i += 2
                        continue

                # Pattern 6: Two short words forming a longer IV word.
                if len(word) <= 3 and len(next_word) <= 3:
                    if len(merged) >= 5 and self.vocab.is_iv(merged):
                        result.append(merged)
                        i += 2
                        continue

            result.append(word)
            i += 1

        return ' '.join(result)

    def process_text(self, text: str) -> str:
        """Apply full split/merge processing: merge fragments first, then split OOV words."""
        text = self.merge_fragments(text)
        return ' '.join(
            self.split_word(word) if self.vocab.is_oov(word) and len(word) >= 4 else word
            for word in text.split()
        )

# ═══════════════════════════════════════════════════════════════════════════════
# EDIT DISTANCE CORRECTOR
# ═══════════════════════════════════════════════════════════════════════════════

class EditDistanceCorrector:
    """
    Generates candidates based on Levenshtein distance.

    Uses BERT Vocabulary to filter for valid words.
    """

    # Alphabet used to generate replacements/insertions. Includes 'إ' so that
    # common hamza typos such as 'الى' → 'إلى' are reachable.
    LETTERS = 'أابتثجحخدذرزسشصضطظعغفقكلمنهويءآىةئؤإ'

    # Splits a token into leading non-word chars, core, trailing non-word chars.
    _EDGE_PUNCT = re.compile(r'^(\W*)(.*?)(\W*)$', re.DOTALL)

    # Rank assigned to candidates missing from vocab_rank.
    _OOV_RANK = 999999

    def __init__(self, tokenizer):
        """Build the word vocabulary and rank table from ``tokenizer.get_vocab()``."""
        self.tokenizer = tokenizer
        token_to_id = tokenizer.get_vocab()
        # Strict vocabulary: alphabetic whole words only (no '##' subwords,
        # no punctuation, no single characters).
        self.vocab = {
            w for w in token_to_id
            if w.isalpha() and not w.startswith('##') and len(w) > 1
        }
        # Frequency rank heuristic: lower index = higher frequency (usually)
        self.vocab_rank = dict(token_to_id)

    def edits1(self, word):
        """Return all strings one edit (delete/transpose/replace/insert) away from ``word``."""
        letters = self.LETTERS
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(self, word):
        """Lazily yield all strings two edits away from ``word`` (may repeat)."""
        return (e2 for e1 in self.edits1(word) for e2 in self.edits1(e1))

    def known(self, words):
        """Return the subset of ``words`` that are in the vocabulary."""
        return {w for w in words if w in self.vocab}

    def generate_candidate(self, text: str) -> str:
        """Build a candidate sentence by replacing OOV words with their nearest known word.

        For each whitespace token:
            * Known words, numbers, pure punctuation and single letters are
              kept unchanged. The vocabulary holds only alphabetic words of
              2+ chars, so such tokens can never be "known", and correcting
              them would only produce unrelated short words.
            * Otherwise edits at distance 1 are tried, then distance 2 for
              words shorter than 7 chars. The most frequent (lowest-rank)
              match wins.
            * Leading/trailing punctuation of the original token is kept
              around the correction.
        """
        corrected_words = []

        for word in text.split():
            clean_word = re.sub(r'[^\w]', '', word)

            if (clean_word in self.vocab
                    or len(clean_word) < 2
                    or not clean_word.isalpha()):
                corrected_words.append(word)
                continue

            candidates = self.known(self.edits1(clean_word))
            # edits2 is expensive, so only try it for short words.
            if not candidates and len(clean_word) < 7:
                candidates = self.known(self.edits2(clean_word))

            if not candidates:
                corrected_words.append(word)
                continue

            best_candidate = min(candidates, key=lambda w: self.vocab_rank.get(w, self._OOV_RANK))
            lead, _, trail = self._EDGE_PUNCT.match(word).groups()
            corrected_words.append(f"{lead}{best_candidate}{trail}")

        return ' '.join(corrected_words)

    





# ═══════════════════════════════════════════════════════════════════════════════
# CONTEXTUAL CORRECTOR (MLM-based with Batch Scoring)
# ═══════════════════════════════════════════════════════════════════════════════

class ContextualCorrector:
    """MLM-based contextual correction for confusion pairs.

    Wraps a masked language model (AraBERT by default) to:

    * generate spelling candidates for a word (confusable-letter swaps,
      repeated-letter deletion, and vocabulary-filtered single edits);
    * score a word in context by masking its position and reading the MLM
      probability of the word's first sub-token;
    * refine a sentence by replacing low-confidence words with close,
      high-confidence MLM predictions;
    * compute an average per-word fluency score for a sentence.

    Per-word scores are memoised in a bounded least-recently-used cache.
    """

    # Common confusion pairs in Arabic (each pair is used in both directions).
    CONFUSION_PAIRS = [
        ('ض', 'ظ'), ('ذ', 'ز'), ('ث', 'س'), ('ص', 'س'),
        ('ط', 'ت'), ('ق', 'ك'), ('ه', 'ة'), ('ا', 'ى'),
        ('ت', 'د'), ('د', 'ض'), ('ك', 'ق'), ('غ', 'ق'),
        ('ج', 'ش'), ('س', 'ز'), ('ف', 'ب'),
        ('ؤ', 'و'), ('ئ', 'ي'), ('ء', 'أ'), ('إ', 'أ'),
    ]

    # Restricted alphabet for vocabulary-filtered insertions/substitutions,
    # kept small to avoid a candidate explosion.
    _COMMON_CHARS = 'ابتثجحخدذرزسشصضطظعغفقكلمنهويأإآءئؤةى'

    def __init__(self, model_name: str = 'aubmindlab/bert-base-arabertv02', cache_size: int = 10000):
        """Load the masked-LM and its tokenizer and prepare the score cache.

        Args:
            model_name: Hugging Face model id or local path of a masked-LM
                checkpoint.
            cache_size: Maximum number of (text, position, word) scores kept
                in the LRU cache; a value <= 0 disables caching.
        """
        from transformers import AutoTokenizer, AutoModelForMaskedLM

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

        # char -> list of confusable chars
        self.confusion_map = self._build_confusion_map()

        # Stats
        self.cache_hits = 0
        self.cache_misses = 0

        # Insertion-ordered dict used as an LRU cache (least recent first).
        self._score_cache = {}
        self.cache_size = cache_size

        # Tokenizer vocabulary (token -> id), used to filter candidates.
        self.vocab = self.tokenizer.get_vocab()

    # ------------------------------------------------------------------
    # Cache helpers
    # ------------------------------------------------------------------

    def _cache_get(self, key):
        """Return the cached score for `key` (marking it most recent), or None."""
        try:
            value = self._score_cache.pop(key)
        except KeyError:
            return None
        self._score_cache[key] = value  # re-insert at the "most recent" end
        return value

    def _cache_put(self, key, value):
        """Store `value`, evicting least-recently-used entries beyond ``cache_size``."""
        if self.cache_size <= 0:
            return
        self._score_cache.pop(key, None)
        while len(self._score_cache) >= self.cache_size:
            self._score_cache.pop(next(iter(self._score_cache)))
        self._score_cache[key] = value

    # ------------------------------------------------------------------
    # Candidate generation
    # ------------------------------------------------------------------

    def _build_confusion_map(self):
        """Build a bidirectional char -> [confusable chars] map.

        Self-pairs are ignored and duplicate targets are collapsed, keeping
        first-seen order so candidate generation order is stable.
        """
        confusion_map = {}
        for char1, char2 in self.CONFUSION_PAIRS:
            if char1 == char2:
                continue
            for src, dst in ((char1, char2), (char2, char1)):
                targets = confusion_map.setdefault(src, [])
                if dst not in targets:
                    targets.append(dst)
        return confusion_map

    def get_confusable_chars(self, char: str) -> List[str]:
        """Return the characters commonly confused with `char` (possibly empty)."""
        return self.confusion_map.get(char, [])

    def generate_candidates(self, word: str) -> List[str]:
        """Generate candidate corrections for `word`.

        The first element is always `word` itself; the rest are unique and in
        generation order:

        1. every single confusable-letter substitution (not vocab-filtered);
        2. removal of one letter from any adjacent repeated pair
           (مدررسة -> مدرسة, جميلل -> جميل; not vocab-filtered);
        3. single-letter insertions, substitutions (only for words shorter
           than 7 characters) and deletions (results longer than one
           character), each kept only if the result is a token in
           ``self.vocab``.
        """
        candidates = [word]
        seen = {word}

        def add(candidate, require_vocab=False):
            if candidate in seen:
                return
            if require_vocab and candidate not in self.vocab:
                return
            seen.add(candidate)
            candidates.append(candidate)

        # 1. Substitute confusable chars
        for i, char in enumerate(word):
            for conf_char in self.get_confusable_chars(char):
                add(word[:i] + conf_char + word[i + 1:])

        # 2. Remove one instance of an adjacent repeated char
        for i in range(len(word) - 1):
            if word[i] == word[i + 1]:
                add(word[:i] + word[i + 1:])

        # 3. Vocabulary-filtered single edits: keeping only whole vocab tokens
        #    prevents hallucinated words and keeps MLM scoring meaningful.
        # Insertions (missing char)
        for i in range(len(word) + 1):
            for char in self._COMMON_CHARS:
                add(word[:i] + char + word[i:], require_vocab=True)

        # Substitutions (wrong char)
        if len(word) < 7:
            for i in range(len(word)):
                for char in self._COMMON_CHARS:
                    if char != word[i]:
                        add(word[:i] + char + word[i + 1:], require_vocab=True)

        # Deletions (extra char)
        for i in range(len(word)):
            candidate = word[:i] + word[i + 1:]
            if len(candidate) > 1:
                add(candidate, require_vocab=True)

        return candidates

    # ------------------------------------------------------------------
    # MLM scoring
    # ------------------------------------------------------------------

    def _mask_token(self) -> str:
        """The tokenizer's mask token string (falls back to '[MASK]')."""
        return getattr(self.tokenizer, 'mask_token', None) or '[MASK]'

    def _masked_probs(self, text: str, position: int):
        """Mask word `position` of `text` and return the MLM distribution there.

        Returns a 1-D probability tensor over the tokenizer vocabulary, or
        None when `position` is out of range or the mask token did not
        survive tokenization (e.g. it was truncated away).
        """
        words = text.split()
        if position < 0 or position >= len(words):
            return None

        masked_words = list(words)
        masked_words[position] = self._mask_token()
        masked_text = ' '.join(masked_words)

        inputs = self.tokenizer(masked_text, return_tensors='pt', padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits

        mask_positions = (inputs['input_ids'] == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        if len(mask_positions) == 0:
            return None

        return torch.softmax(logits[0, mask_positions[0], :], dim=0)

    def _first_token_prob(self, probs, word: str):
        """Probability of `word`'s first sub-token under `probs`, or None if `word` has no tokens.

        NOTE: only the first sub-token is scored, so multi-token words get a
        prefix score (typically higher than the whole word would get).
        """
        word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
        if not word_tokens:
            return None
        return probs[word_tokens[0]].item()

    def score_with_mlm(self, text: str, position: int, word: str) -> float:
        """Score `word` at word index `position` of `text` using the MLM.

        Returns the probability of `word`'s first sub-token at the masked
        position, or 0.0 when the position is invalid, the mask is lost to
        truncation, or `word` tokenizes to nothing. Successful scores are
        cached (LRU).
        """
        cache_key = (text, position, word)
        cached = self._cache_get(cache_key)
        if cached is not None:
            self.cache_hits += 1
            return cached

        self.cache_misses += 1

        probs = self._masked_probs(text, position)
        if probs is None:
            return 0.0

        score = self._first_token_prob(probs, word)
        if score is None:
            return 0.0

        self._cache_put(cache_key, score)
        return score

    def score_candidates_batch(self, text: str, position: int, candidates: List[str]) -> dict:
        """Score several candidates for the same masked position.

        The masked context is identical for every candidate, so at most one
        forward pass is run for all cache misses.

        Returns: {candidate: score}, in the order of `candidates`.
        """
        scores = dict.fromkeys(candidates, 0.0)
        misses = []

        for candidate in candidates:
            cached = self._cache_get((text, position, candidate))
            if cached is not None:
                self.cache_hits += 1
                scores[candidate] = cached
            else:
                self.cache_misses += 1
                misses.append(candidate)

        if misses:
            probs = self._masked_probs(text, position)
            if probs is not None:
                for candidate in misses:
                    score = self._first_token_prob(probs, candidate)
                    if score is not None:
                        scores[candidate] = score
                        self._cache_put((text, position, candidate), score)

        return scores

    def predict_masked_token(self, text: str, position: int, top_k: int = 5) -> List[Tuple[str, float]]:
        """Predict words for a masked position.

        Returns up to `top_k` (word, probability) pairs, most probable first,
        after dropping '##' continuation pieces and special tokens (so fewer
        than `top_k` may be returned). `top_k` is clamped to the vocabulary
        size.
        """
        probs = self._masked_probs(text, position)
        if probs is None:
            return []

        k = min(top_k, probs.shape[0])
        if k <= 0:
            return []

        top_probs, top_ids = torch.topk(probs, k, sorted=True)
        special_tokens = set(self.tokenizer.all_special_tokens)

        results = []
        for token_id, score in zip(top_ids.tolist(), top_probs.tolist()):
            token = self.tokenizer.decode([token_id]).strip()
            if not token.startswith("##") and token not in special_tokens:
                results.append((token, score))

        return results

    def refine_sentence_with_mask(self, text: str, threshold: float = 0.001, vocab_manager=None, raw_model_output=None) -> str:
        """Refine a sentence by masking weak words and predicting replacements.

        A word is considered only if it is not IV (per `vocab_manager`), does
        not equal the word at the same index of `raw_model_output`
        ("kill switch"), is longer than 2 characters, and scores at or below
        `threshold`. A prediction replaces it only if it differs by at most
        one character in length, has similarity >= 0.90, is not OOV, has
        probability >= 0.12, and clears the score-improvement gate. Each word
        is scored against the original `text`, not the partially refined one.
        """
        words = text.split()
        refined_words = list(words)

        # Raw model words for the kill switch
        raw_words = raw_model_output.split() if raw_model_output else []

        for i, word in enumerate(words):
            # IV-Safe: never replace IV words
            if vocab_manager and vocab_manager.is_iv(word):
                continue

            # Kill switch: keep words the seq2seq model also produced here
            if i < len(raw_words) and word == raw_words[i]:
                continue

            # Skip very short words (prepositions etc.)
            if len(word) <= 2:
                continue

            current_score = self.score_with_mlm(text, i, word)
            if current_score > threshold:
                continue

            for pred_word, pred_score in self.predict_masked_token(text, i, top_k=10):
                if pred_word == word:
                    continue
                if abs(len(pred_word) - len(word)) > 1:
                    continue

                # With a 0.90 floor a single edit is only allowed once the
                # longer word has 10+ characters.
                dist = Levenshtein.distance(word, pred_word)
                similarity = 1.0 - dist / max(len(word), len(pred_word))
                if similarity < 0.90:
                    continue

                # Must be IV
                if vocab_manager and vocab_manager.is_oov(pred_word):
                    continue

                # Minimum absolute confidence gate (12%)
                if pred_score < 0.12:
                    continue

                # current_score <= threshold here, so with the default
                # threshold (0.001) the "common" branch is never taken.
                if current_score > 0.001:
                    improves = pred_score > current_score * 1000
                else:
                    improves = pred_score > current_score * 50 and pred_score > 0.2

                if improves:
                    refined_words[i] = pred_word
                    break

        return ' '.join(refined_words)

    def calculate_sentence_score(self, text: str) -> float:
        """Mean MLM score over every word of `text` (0.0 for empty text)."""
        words = text.split()
        if not words:
            return 0.0
        return sum(self.score_with_mlm(text, i, word) for i, word in enumerate(words)) / len(words)


# ═══════════════════════════════════════════════════════════════════════════════
# MAIN SPELL CHECKER CLASS
# ═══════════════════════════════════════════════════════════════════════════════

class ArabicSpellChecker:
    """Main Arabic Spell Checker class"""
    
    def __init__(self, model, tokenizer, device, use_contextual: bool = True):
        """Wire up the seq2seq model and every correction component.

        Args:
            model: seq2seq model used for beam-search inference.
            tokenizer: tokenizer paired with `model`; also handed to the
                vocabulary manager and the edit-distance corrector.
            device: device that model inputs are moved to.
            use_contextual: try to load the MLM-based ContextualCorrector.
                If loading raises, a warning is logged and contextual
                scoring is switched off (``self.contextual`` is None).
        """
        self.model, self.tokenizer, self.device = model, tokenizer, device

        # Rule / heuristic components
        self.postprocessor = AraSpellPostProcessor()
        self.classifier = ErrorClassifier()
        self.rules = RulesBasedCorrector()
        self.validator = OutputValidator()

        # Vocabulary-driven components (the last two share one VocabularyManager)
        self.vocab_manager = VocabularyManager(tokenizer)
        self.edit_corrector = EditDistanceCorrector(tokenizer)
        self.split_merge = SplitMergeSpecialist(self.vocab_manager)
        self.word_aligner = WordAligner(self.vocab_manager)  # word-level hybrid corrections

        # Optional MLM-based contextual corrector
        self.use_contextual = use_contextual
        self.contextual = None
        if use_contextual:
            try:
                self.contextual = ContextualCorrector()
                logger.info("Contextual correction enabled")
            except Exception as e:
                logger.warning(f"Contextual correction disabled: {e}")
                self.contextual = None
                self.use_contextual = False
    def _fix_repeated_end_chars(self, text: str) -> str:
        """

        🆕 Fix repeated characters at word endings

        

        Examples:

            اليومم → اليوم

            جميلل → جميل

            صباحح → صباح

        """
        # Remove repeated chars at word end (keep only one)
        text = re.sub(r'([ا-ي])\1+\b', r'\1', text)
        return text
    
    def _fix_merged_with_errors(self, text: str) -> str:
        """ Fix merged words that contain errors

        

        Examples:

            الممدرسة → المدرسة

            الكتابب → الكتاب

            الططالب → الطالب

        """
        # Pattern 1: ال + repeated char + word
        text = re.sub(r'ال([ا-ي])\1+([ا-ي]{2,})', r'ال\2', text)
        
        # Pattern 2: word + repeated char at end
        text = re.sub(r'\b([ا-ي]{3,})([ا-ي])\2+\b', r'\1\2', text)
        
        return text
    

    def _split_merged_words_linguistic(self, text: str) -> str:
        """ Split merged words using linguistic patterns

        

        Examples:

            كلصباح → كل صباح

            فيالطريق → في الطريق

            السلامعليكم → السلام عليكم

        """
        # Pattern 1: Prepositions + (article)? + word
        # Added: ك (like in كالكتاب) but careful not to split overlapping words
        text = re.sub(
            r'\b(في|من|إلى|الى|حتى|منذ|خلال|بعد|قبل)(ال)?([ا-ي]{3,})',
            r'\1 \2\3',
            text
        )
        
        # Pattern 2: كل + word
        text = re.sub(r'\b(كل)([ا-ي]{3,})', r'\1 \2', text)
        
        # Pattern 3: Article repetition
        text = re.sub(r'([ا-ي]{3,})(ال)([ا-ي]{3,})', r'\1 \2\3', text)
        
        # Pattern 4: Single-letter prepositions
        text = re.sub(r'\b([بلك])(ال)?([ا-ي]{3,})', r'\1 \2\3', text)
        
        # Pattern 5: Word + عليكم/عليك
        text = re.sub(r'([ا-ي]{4,})(عليكم|عليك|عليه|عليها)', r'\1 \2', text)
        
        # Pattern 6: على/عن in middle of (merged) words
        text = re.sub(r'([ا-ي]{3,})(على|عن)([ا-ي]{3,})', r'\1 \2 \3', text)
        
        return text
    
    def _split_long_words_heuristic(self, text: str, max_length: int = 15) -> str:
        """ Split suspiciously long words using heuristics

        """
        words = text.split()
        result = []
        
        for word in words:
            if len(word) <= max_length:
                result.append(word)
                continue
            
            # Check for embedded article
            if 'ال' in word[2:]:
                parts = word.split('ال', 1)
                if len(parts[0]) >= 2 and len(parts[1]) >= 3:
                    result.extend([parts[0], 'ال' + parts[1]])
                    continue
            
            # Check for common prefixes at start of long word
            if len(word) >= 8:
                split_found = False
                for split_pos in [2, 3]:
                    prefix = word[:split_pos]
                    suffix = word[split_pos:]
                    
                    if prefix in ['في', 'من', 'على', 'عن', 'مع', 'كل', 'ب', 'ل', 'ك']:
                        result.extend([prefix, suffix])
                        split_found = True
                        break
                
                if not split_found:
                    result.append(word)
            else:
                result.append(word)
        
        return ' '.join(result)
    
    def _normalize_tanween_patterns(self, text: str) -> str:
        """ Normalize tanween patterns

        

        Examples:

            جدأ → جداً

            كثيرأ → كثيراً

        """
        # أ at word end → اً
        text = re.sub(r'([ا-ي]{2,})أ\b', r'\1اً', text)
        
        # Remove standalone أ
        text = re.sub(r'\s+أ\s+', ' ', text)
        
        # Fix accidental splits (e.g. ب + space + word)
        text = re.sub(r'\b([بلك])\s+([ا-ي])', r'\1\2', text)
        
        return text
    

    

    
    def preprocess(self, text: str) -> str:
        """Preprocessing pipeline (مع التحسينات المدمجة)"""
        # Basic normalization
        text = self.postprocessor.remove_harakat(text)
        text = self.postprocessor.remove_tatweel(text)
        text = self.postprocessor.normalize_special_chars(text)
        
        # Integrated improvements
        # Fix repeated chars and merged words with errors FIRST
        text = self._fix_repeated_end_chars(text)
        text = self._fix_merged_with_errors(text)
        
        # Then split merged words
        text = self._split_merged_words_linguistic(text)
        text = self._split_long_words_heuristic(text)
        text = self._normalize_tanween_patterns(text)
        
        # Merge separated 'ال'
        text = self.postprocessor.merge_separated_al(text)
        
        # Collapse repetitions
        text = self.postprocessor.unified_collapse_repeated(text)
        
        # Rules-based fixes
        text = self.rules.fix_char_substitution(text)
        text = self.rules.fix_char_repetition(text)
        
        # Normalize spaces
        text = self.postprocessor.normalize_spaces(text)
        
        return text
    
    def postprocess(self, text: str, original: str = "") -> str:
        """Postprocessing pipeline — passes vocab_manager for smart ه/ة handling"""
        return self.postprocessor.full_postprocess(text, original, vocab_manager=self.vocab_manager)
    
    def model_inference(self, text: str, num_return_sequences: int = 5) -> List[str]:
        """Run seq2seq model inference and return top candidates.

        Also extracts beam scores (token-level probabilities) for diagnostics.

        """
        # Tokenize
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Generate with beam search
        # Keeping 5 beams as model was trained/optimized for this
        # Keeping 5 beams as model was trained/optimized for this
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                num_beams=5,
                num_return_sequences=num_return_sequences,
                early_stopping=True,
                return_dict_in_generate=True,
                output_scores=True
            )
        
        # Decode
        candidates = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
        
        # Store beam scores for potential use
        self._last_beam_scores = {}
        if hasattr(outputs, 'sequences_scores') and outputs.sequences_scores is not None:
            scores = outputs.sequences_scores.tolist()
            for cand, score in zip(candidates, scores):
                self._last_beam_scores[cand] = score
        
        return candidates
    
    def correct(self, text: str) -> str:
        """

        Main correction pipeline (RERANKING APPROACH)

        

        Steps:

        1. Preprocess

        2. Generate Candidates (Model Beams + Baseline)

        3. Rerank Candidates (Validator + Fluency)

        4. Select Best

        5. Postprocess

        """
        if not text or not text.strip():
            return text
        
        original = text
        
        # 1. Preprocess
        # This provides a strong baseline candidate
        preprocessed_text = self.preprocess(text)
        
        # 2. Classify error type
        error_type = self.classifier.classify(preprocessed_text)
        
        # 3. Generate Candidates
        candidates = []
        
        # A. Baseline (Preprocessed)
        candidates.append(preprocessed_text)
        
        # B. Smart Rules Candidate (Aggressive Heuristic)
        rules_candidate = self.rules.advanced_heuristic_repair(text)
        candidates.append(rules_candidate)
        
        # B2. Edit Distance Candidate
        edit_candidate = self.edit_corrector.generate_candidate(text)
        if edit_candidate != text and edit_candidate != rules_candidate:
            candidates.append(edit_candidate)
        
        # C. Model Beams
        raw_model_output = None  # Track for safety net
        try:
            model_candidates = self.model_inference(preprocessed_text, num_return_sequences=5)
            raw_model_output = model_candidates[0] if model_candidates else None
            candidates.extend(model_candidates)
            
            # D. Word-Aligned Hybrid Candidate
            # Creates a hybrid by selecting best word from each position
            if model_candidates:
                hybrid_candidate = self.word_aligner.align_words(preprocessed_text, model_candidates[0])
                if hybrid_candidate not in candidates:
                    candidates.append(hybrid_candidate)
                
                # E. Word-Aligned with ALL top beams (not just beam 0)
                for beam in model_candidates[1:3]:  # Top 3 beams
                    hybrid_beam = self.word_aligner.align_words(preprocessed_text, beam)
                    if hybrid_beam not in candidates:
                        candidates.append(hybrid_beam)
            
            # D2. Token-level Voting Candidate
            # Majority-vote each token across all beams
            if model_candidates and len(model_candidates) >= 3:
                try:
                    beam_word_lists = [c.split() for c in model_candidates]
                    max_words = max(len(wl) for wl in beam_word_lists)
                    voted_words = []
                    for pos in range(max_words):
                        words_at_pos = []
                        for wl in beam_word_lists:
                            if pos < len(wl):
                                words_at_pos.append(wl[pos])
                        if words_at_pos:
                            most_common = Counter(words_at_pos).most_common(1)[0][0]
                            voted_words.append(most_common)
                    voted_candidate = ' '.join(voted_words)
                    if voted_candidate not in candidates:
                        candidates.append(voted_candidate)
                except Exception:
                    pass
        except Exception as e:
            logger.warning(f"Model inference failed: {e}")
        
        # Remove duplicates while preserving order
        unique_candidates = []
        seen = set()
        for c in candidates:
            if c not in seen:
                unique_candidates.append(c)
                seen.add(c)
        candidates = unique_candidates
        

        
        # 4. Rerank Candidates
        best_candidate = preprocessed_text
        best_score = -1.0
        
        # Debug info
        candidate_scores = []
        
        for cand in candidates:
            # A. Validation Score (Hard Penalty)
            # Check validity against strict original
            is_valid, reason = self.validator.validate(original, cand, error_type.value)
            
            # Additional check: If candidate is suspiciously shorter than original (and not just harakat removal)
            if len(cand) < len(original) * 0.5:
                is_valid = False
                reason = "too_short"

            # ═══════════════════════════════════════════════════════════════════════════
            # VOCABULARY-AWARE ACCEPTANCE
            # ═══════════════════════════════════════════════════════════════════════════
            # Logic: OOV→IV = ACCEPT (boost), IV→OOV = REJECT (penalize)
            # This prevents over-conservative validation from rejecting correct corrections
            
            input_oov_count = self.vocab_manager.count_oov_words(original)
            cand_oov_count = self.vocab_manager.count_oov_words(cand)
            
            vocab_boost = 1.0
            
            # Case 1: OOV→IV (Correction fixed unknown words) → Accept more readily
            if input_oov_count > 0 and cand_oov_count < input_oov_count:
                # Significant boost for reducing OOV words
                oov_reduction = input_oov_count - cand_oov_count
                vocab_boost = 1.0 + (oov_reduction * 0.3)  # +30% per OOV fixed
                
                # If ALL words are now IV, accept even with higher edit distance
                if cand_oov_count == 0 and self.vocab_manager.all_words_iv(cand):
                    # Override validation rejection if OOV→IV
                    if not is_valid and reason not in ["empty_output"]:
                        is_valid = True
                        reason = "vocab_aware_accept"
            
            # Case 2: IV→OOV (Correction introduced unknown words) → Penalize
            elif cand_oov_count > input_oov_count:
                # Penalize for introducing new OOV words
                vocab_boost = 0.5  # 50% penalty
            
            # Case 3: All IV to begin with → Standard validation
            elif input_oov_count == 0 and cand_oov_count == 0:
                # Both are valid vocab, prefer minimal edits
                vocab_boost = 1.0
            
            # ═══════════════════════════════════════════════════════════════════════════

            
            # Penalty factor
            # Valid: 1.0
            # Invalid: 0.01 (Heavy penalty, essentially disqualified unless all are invalid)
            validity_factor = 1.0 if is_valid else 0.001
            
            # B. Fluency Score (BERT MLM)
            fluency_score = 0.0
            if self.use_contextual and self.contextual:
                try:
                    fluency_score = self.contextual.calculate_sentence_score(cand)
                except Exception as e:
                    logger.warning(f"Scoring failed: {e}")
                    fluency_score = 0.5 # Default fallback
            else:
                fluency_score = 1.0 
            
            # C. Similarity Score (Damerau-Levenshtein Distance)
            dist = VocabularyManager.damerau_levenshtein_distance(preprocessed_text, cand)
            max_len = max(len(preprocessed_text), len(cand), 1)
            similarity = 1.0 - (dist / max_len)
            
            # Boost exact matches
            if cand == preprocessed_text:
                similarity = 1.0
            
            # Keyboard Proximity Bonus
            # If changes between input and candidate are keyboard-adjacent,
            # it's more likely a typo fix (give bonus)
            keyboard_bonus = 1.0
            input_words = preprocessed_text.split()
            cand_words = cand.split()
            if len(input_words) == len(cand_words):
                for iw, cw in zip(input_words, cand_words):
                    if iw != cw and len(iw) == len(cw):
                        # Check char-by-char differences
                        for ic, cc in zip(iw, cw):
                            if ic != cc and RulesBasedCorrector.is_keyboard_neighbor(ic, cc):
                                keyboard_bonus *= 1.05  # 5% bonus per keyboard-adjacent fix
            
            # HIGH CONFIDENCE GATING
            # If model is extremely confident (high fluency) and words are valid, relax validation
            # This allows correcting severe corruptions that fail strict edit distance
            if fluency_score > 0.85 and cand_oov_count == 0:
                 if not is_valid and reason in ["too_short", "low_character_similarity", "word_count_mismatch"]:
                      # Check if it makes sense length-wise (don't allow completely empty or massive hallucinations)
                      if len(cand) >= len(original) * 0.4:
                          is_valid = True
                          reason = "high_confidence_override"
                          vocab_boost *= 1.2  # Bonus for high confidence
                          validity_factor = 1.0  # Reset validity factor
            
            # Final Score = (Fluency^0.3) * (Similarity^3.0) * Validity * VocabBoost * KeyboardBonus * BeamBoost
            fluency_exp = 0.3
            similarity_exp = 3.0
            
            # Beam 0 Boost — model's top beam gets 15% priority
            beam_boost = 1.0
            if raw_model_output and cand == raw_model_output:
                beam_boost = 1.15
            
            final_score = (fluency_score ** fluency_exp) * (similarity ** similarity_exp) * validity_factor * vocab_boost * keyboard_bonus * beam_boost
            
            candidate_scores.append({
                'text': cand,
                'is_valid': is_valid,
                'reason': reason,
                'fluency': fluency_score,
                'similarity': similarity,
                'vocab_boost': vocab_boost,
                'input_oov': input_oov_count,
                'cand_oov': cand_oov_count,
                'final_score': final_score
            })
            
            if final_score > best_score:
                best_score = final_score
                best_candidate = cand
        
        # ═══════════════════════════════════════════════════════════════════════════
        # --- Output Quality Scoring (Minimum Score Threshold) ---
        # If ALL candidates scored poorly, the correction is unreliable → keep input
        # ═══════════════════════════════════════════════════════════════════════════
        if best_candidate != preprocessed_text:
            # Check: did the best candidate actually get a decent score?
            # The preprocessed input (candidate 0) is always in the pool.
            # If the best candidate barely beats preprocessed_text, it might not be trustworthy.
            preprocessed_score = 0.0
            for cs in candidate_scores:
                if cs['text'] == preprocessed_text:
                    preprocessed_score = cs['final_score']
                    break
            
            # If best score is less than 1.05x the preprocessed score AND
            # the best candidate introduced OOV words → fall back to preprocessed
            if preprocessed_score > 0 and best_score < preprocessed_score * 1.05:
                best_oov = self.vocab_manager.count_oov_words(best_candidate)
                prep_oov = self.vocab_manager.count_oov_words(preprocessed_text)
                if best_oov > prep_oov:
                    best_candidate = preprocessed_text
                    best_score = preprocessed_score
        
        # ═══════════════════════════════════════════════════════════════════════════
        # --- Contextual Validation Layer ---
        # Compare fluency of input vs best candidate
        # If correction made text LESS fluent → reject the correction
        # ═══════════════════════════════════════════════════════════════════════════
        if best_candidate != preprocessed_text and self.use_contextual and self.contextual:
            try:
                input_fluency = self.contextual.calculate_sentence_score(preprocessed_text)
                best_fluency = 0.0
                for cs in candidate_scores:
                    if cs['text'] == best_candidate:
                        best_fluency = cs['fluency']
                        break
                
                # If input is significantly more fluent than best candidate
                # AND both have similar OOV counts → prefer input
                if input_fluency > 0 and best_fluency > 0:
                    if input_fluency > best_fluency * 1.5:  # Input 50% more fluent
                        input_oov = self.vocab_manager.count_oov_words(preprocessed_text)
                        best_oov = self.vocab_manager.count_oov_words(best_candidate)
                        if input_oov <= best_oov:
                            # Input is more fluent AND has fewer/equal OOV → keep input
                            best_candidate = preprocessed_text
            except Exception:
                pass  # Contextual validation is optional
        
        # 5. Postprocess Winner
        result = self.postprocess(best_candidate, original)
        
        # 5.5 IV-Safe Postprocessing Check
        # If postprocessing changed an IV word to OOV, revert that word
        if result != best_candidate:
            result_words = result.split()
            best_words = best_candidate.split()
            if len(result_words) == len(best_words):
                fixed_words = []
                input_words_pp = preprocessed_text.split()
                for idx_fw, (rw, bw) in enumerate(zip(result_words, best_words)):
                    if rw != bw:
                        # Postprocessor changed this word
                        bw_iv = self.vocab_manager.is_iv(bw)
                        rw_iv = self.vocab_manager.is_iv(rw)
                        if bw_iv and not rw_iv:
                            # IV → OOV: revert to pre-postprocess version
                            fixed_words.append(bw)
                        elif bw_iv and rw_iv:
                            # Postprocess Distance Guard
                            # DISABLED: Caused word-level regression. When both are IV,
                            # the postprocessor's choice (rw) is usually better because
                            # it applies Arabic-specific rules (hamza, ta marbuta).
                            fixed_words.append(rw)
                        else:
                            fixed_words.append(rw)
                    else:
                        fixed_words.append(rw)
                result = ' '.join(fixed_words)
        
        # 6. Contextual fine-tuning (BERT Masked Refinement)
        # IV-Safe mode - pass vocab_manager to protect IV words
        # BERT Kill Switch - also pass raw_model_output to protect model-confident words
        if self.use_contextual and self.contextual:
             if len(result) > 3:
                 result = self.contextual.refine_sentence_with_mask(
                     result, vocab_manager=self.vocab_manager,
                     raw_model_output=raw_model_output
                 )
        
        # 7. Safe Split/Merge Post-processing
        # Only apply merge_fragments (safe: only merges when result is IV)
        result = self.split_merge.merge_fragments(result)
        
        # ═══════════════════════════════════════════════════════════════════════════
        # VALIDATION & QUALITY CHECKS
        # ═══════════════════════════════════════════════════════════════════════════
        
        # 8. Output Stability Test (Solution 30)
        # If correcting the output again changes it → unstable correction → reject
        # Stable corrections are idempotent: correct(correct(x)) == correct(x)
        if result != preprocessed_text and raw_model_output:
            try:
                # Quick stability check: run the result through preprocessing only
                # (full model inference would be too slow)
                re_preprocessed = self.preprocess(result)
                
                # If re-preprocessing changes the result significantly, it was unstable
                stability_dist = VocabularyManager.damerau_levenshtein_distance(result, re_preprocessed)
                result_len = max(len(result), 1)
                
                if stability_dist > 0:
                    # Result is not stable under re-preprocessing
                    stability_ratio = stability_dist / result_len
                    
                    if stability_ratio > 0.15:  # More than 15% changed → very unstable
                        # Fall back to raw model output if it's more stable
                        raw_re = self.preprocess(raw_model_output)
                        raw_stability = VocabularyManager.damerau_levenshtein_distance(
                            raw_model_output, raw_re
                        ) / max(len(raw_model_output), 1)
                        
                        if raw_stability < stability_ratio:
                            # Raw is more stable → use it
                            raw_oov = self.vocab_manager.count_oov_words(raw_model_output)
                            our_oov = self.vocab_manager.count_oov_words(result)
                            if raw_oov <= our_oov:
                                result = raw_model_output
            except Exception:
                pass  # Stability check is optional, don't break pipeline
        
        # 9. Bidirectional Word-Level Validation (Solution 24)
        # Compare our result word-by-word with raw model output
        # If we corrupted a word that the model got right, revert that word
        if raw_model_output and result != raw_model_output:
            result_words = result.split()
            raw_words = raw_model_output.split()
            
            if len(result_words) == len(raw_words):
                corrected_words = []
                changed = False
                
                for rw, raw_w in zip(result_words, raw_words):
                    if rw != raw_w:
                        rw_iv = self.vocab_manager.is_iv(rw)
                        raw_iv = self.vocab_manager.is_iv(raw_w)
                        
                        # Case 1: Our word is OOV but raw word is IV → take raw
                        if not rw_iv and raw_iv:
                            corrected_words.append(raw_w)
                            changed = True
                        # Case 2: Both IV but our word is further from input
                        elif rw_iv and raw_iv:
                            # Find corresponding input word
                            input_words = preprocessed_text.split()
                            idx = len(corrected_words)
                            if idx < len(input_words):
                                input_w = input_words[idx]
                                rw_dist = Levenshtein.distance(input_w, rw)
                                raw_dist = Levenshtein.distance(input_w, raw_w)
                                # If raw is closer to input AND both are IV → prefer raw
                                # (our pipeline likely introduced unnecessary change)
                                if raw_dist < rw_dist:
                                    corrected_words.append(raw_w)
                                    changed = True
                                else:
                                    corrected_words.append(rw)
                            else:
                                corrected_words.append(rw)
                        else:
                            corrected_words.append(rw)
                    else:
                        corrected_words.append(rw)
                
                if changed:
                    new_result = ' '.join(corrected_words)
                    # Only accept if the new result doesn't increase OOV
                    new_oov = self.vocab_manager.count_oov_words(new_result)
                    old_oov = self.vocab_manager.count_oov_words(result)
                    if new_oov <= old_oov:
                        result = new_result
        
        # 10. SAFETY NET: Compare with raw model output (Conservative)
        # Only switch to raw if raw is CLEARLY better
        if raw_model_output and raw_model_output != result:
            raw_oov = self.vocab_manager.count_oov_words(raw_model_output)
            our_oov = self.vocab_manager.count_oov_words(result)
            
            # Case A: Raw all-IV, ours has OOV
            if raw_oov == 0 and our_oov > 0:
                is_valid, reason = self.validator.validate(original, raw_model_output, "mixed")
                if is_valid or reason == "space_leniency_accept":
                    result = raw_model_output
            
            # Case B: Both all-IV but raw is more similar to input
            # Catches BERT/postprocess damage (word substitutions up to 5 char distance)
            elif raw_oov == 0 and our_oov == 0:
                raw_dist = VocabularyManager.damerau_levenshtein_distance(original, raw_model_output)
                our_dist = VocabularyManager.damerau_levenshtein_distance(original, result)
                result_vs_raw_dist = VocabularyManager.damerau_levenshtein_distance(result, raw_model_output)
                # Threshold at 3 chars — covers single char edits and small substitutions
                # (widening to 5 caused regression by reverting valid hybrid corrections)
                if raw_dist < our_dist and result_vs_raw_dist <= 3:
                    raw_valid, _ = self.validator.validate(original, raw_model_output, "mixed")
                    if raw_valid:
                        result = raw_model_output
            
            # Case C: Word count differs — raw might have correct splitting
            # Catches: 'فيلق → في فيلق' (pipeline added word)
            # or 'بلاكبيرن روفرز → بلاكبيرن روفر' (pipeline lost word ending)
            elif raw_oov == 0:
                raw_wc = len(raw_model_output.split())
                our_wc = len(result.split())
                if raw_wc != our_wc:
                    raw_dist = VocabularyManager.damerau_levenshtein_distance(original, raw_model_output)
                    our_dist = VocabularyManager.damerau_levenshtein_distance(original, result)
                    if raw_dist < our_dist:
                        raw_valid, _ = self.validator.validate(original, raw_model_output, "mixed")
                        if raw_valid:
                            result = raw_model_output
        
        return result

# ═══════════════════════════════════════════════════════════════════════════════
# PUBLIC API
# ═══════════════════════════════════════════════════════════════════════════════

# Exported for use by benchmark.py and external consumers.
# Stays None until initialize() is called (by the __main__ demo below or by
# an external caller such as benchmark.py); importing the module alone does
# not create a checker.
spell_checker = None


def initialize(use_contextual=True):
    """Create the module-level spell checker and return it.

    Each call constructs a fresh ``ArabicSpellChecker`` from the module-level
    ``model``, ``tokenizer`` and ``device`` and rebinds the ``spell_checker``
    global to it, replacing any previous instance.

    Args:
        use_contextual: Forwarded unchanged to ``ArabicSpellChecker`` as its
            ``use_contextual`` keyword argument.

    Returns:
        The newly created checker (the same object now held in
        ``spell_checker``).
    """
    global spell_checker
    checker = ArabicSpellChecker(
        model,
        tokenizer,
        device,
        use_contextual=use_contextual,
    )
    spell_checker = checker
    logger.info("Spell checker initialized")
    return checker


if __name__ == "__main__":
    # Build the checker (with contextual refinement enabled) and run a tiny
    # demo over a few deliberately misspelled Arabic inputs.
    demo_checker = initialize(use_contextual=True)

    demo_inputs = (
        "السلام عليكممم",
        "فيالمدرسه",
        "الطقص جميل اليومم",
    )
    rule = "=" * 60

    print(f"\n{rule}")
    print("AraSpell Demo")
    print(rule)

    for sample in demo_inputs:
        # Correct first, then report the pair, so a failing correction
        # prints nothing for that sample.
        fixed = demo_checker.correct(sample)
        print(f"\n  Input:     {sample}")
        print(f"  Corrected: {fixed}")

    print(f"\n{rule}")
    print("For full benchmark, run: python benchmark.py")
    print(rule)