File size: 63,868 Bytes
be99bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple

import torch
import torch.nn.functional as F
from transformers.cache_utils import DynamicCache

from .sliding_utils import drop_tokens_from_cache

# Module-level logger; the sliding-window code below tags its messages with "[SW]" / "[SW-CTX]".
logger = logging.getLogger(__name__)


@dataclass
class DuplexWindowConfig:
    """Sliding-window configuration for duplex streaming decoding.

    Modes (``sliding_window_mode``):
    - "off": sliding window disabled.
    - "basic": basic sliding window, triggered by the KV-cache length
      (see the high/low water marks below).
    - "context": sliding window that keeps context. It is triggered by the number of
      units, and the generated text of evicted units is kept in a "previous" section.
    """

    # Which sliding strategy to use.
    sliding_window_mode: Literal["off", "basic", "context"] = "off"

    # --- "basic" mode ---
    # High water mark: sliding starts once the cache holds more than this many tokens.
    basic_window_high_tokens: int = 4000
    # Low water mark: units are evicted until the cache is at or below this many tokens.
    basic_window_low_tokens: int = 3500

    # --- "context" mode ---
    # Maximum number of tokens kept in the "previous" section.
    context_previous_max_tokens: int = 500
    # Maximum number of units; sliding triggers once this is exceeded.
    context_max_units: int = 24

    # Emit extra verification logs (used for comparison testing).
    verify_mode: bool = False


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Mask logits outside the top-k and/or nucleus (top-p) set.

    Filtering is applied independently along the last dimension. 1-D logits and
    batched logits (any leading dimensions) are both handled row by row.

    Args:
        logits: Tensor of unnormalized scores whose last dimension indexes tokens.
        top_k: Keep only the ``top_k`` highest logits per row (0 disables). Entries equal
            to the k-th largest value are kept as well.
        top_p: Keep the smallest prefix of the probability-sorted tokens whose cumulative
            probability exceeds ``top_p`` (0.0 disables). The most likely token is always kept.
        filter_value: Value written into removed positions.

    Returns:
        A new tensor (the input is not modified) with filtered positions set to
        ``filter_value``.
    """
    logits = logits.clone()

    # Top-k filtering: drop everything strictly below the k-th largest value of each row.
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)

    # Top-p (nucleus) filtering.
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that first crosses top_p is kept; always keep the top token.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order, per row. Indexing
        # `sorted_indices[mask]` would flatten all rows together and apply them to row 0,
        # which is wrong for batch > 1 and fails for 1-D logits.
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        logits = logits.masked_fill(to_remove, filter_value)

    return logits


class StreamDecoder:
    def __init__(self, llm, tokenizer, special_token_ids=None, forbidden_token_ids=None):
        """Set up a streaming decoder around ``llm`` with sliding-window KV-cache bookkeeping.

        Args:
            llm: The model, kept as ``self.m``. ``self.m.config`` is read for ``rope_theta``.
            tokenizer: Tokenizer used to resolve control-token ids such as ``<|chunk_eos|>``
                and ``<|turn_eos|>``. Its ``eos_token_id`` is stored as ``listen_id``.
            special_token_ids: Optional list of special token ids; defaults to ``[]``.
            forbidden_token_ids: ``None``, a single token id, or an iterable of token ids.
                It is copied into a new list, so the caller's object is never mutated.
                The ``<|chunk_eos|>`` id is appended if it is not already present.
        """
        self.m = llm
        self.tokenizer = tokenizer
        self.listen_id = self.tokenizer.eos_token_id

        self.chunk_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>")
        self.chunk_tts_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>")
        self.turn_eos_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>")
        self.speak_id = self.tokenizer.convert_tokens_to_ids("<|speak|>")

        self.special_token_ids = special_token_ids if special_token_ids is not None else []

        # Cache special tokens (ids and text) so context-mode sliding can filter them out
        # of generated content.
        self._all_special_ids = set()
        self._all_special_tokens_text = set()
        if self.tokenizer:
            if hasattr(self.tokenizer, "all_special_ids"):
                self._all_special_ids = set(self.tokenizer.all_special_ids)
            if hasattr(self.tokenizer, "all_special_tokens"):
                self._all_special_tokens_text = set(self.tokenizer.all_special_tokens)

        # Model-specific control tokens that may not be registered as tokenizer special tokens.
        custom_special_tokens = [
            "<unit>",
            "</unit>",
            "<image>",
            "</image>",
            "<slice>",
            "</slice>",
            "<|listen|>",
            "<|speak|>",
            "<|tts_bos|>",
            "<|tts_eos|>",
            "<|audio_start|>",
            "<|audio_end|>",
            "<|chunk_eos|>",
            "<|chunk_tts_eos|>",
            "<|turn_eos|>",
        ]
        self._all_special_tokens_text.update(custom_special_tokens)
        for token in custom_special_tokens:
            token_id = self.tokenizer.convert_tokens_to_ids(token)
            # Skip tokens the vocabulary does not know (mapped to None or to unk).
            if token_id is not None and token_id != self.tokenizer.unk_token_id:
                self._all_special_ids.add(token_id)

        # Normalize the forbidden ids into a fresh list. An int must be wrapped using the
        # argument itself (the attribute does not exist yet), and the caller's container
        # must not be mutated by the append below.
        if forbidden_token_ids is None:
            forbidden = []
        elif isinstance(forbidden_token_ids, int):
            forbidden = [forbidden_token_ids]
        else:
            forbidden = list(forbidden_token_ids)
        if self.chunk_eos_id not in forbidden:
            forbidden.append(self.chunk_eos_id)
        self.forbidden_token_ids = forbidden

        self.cache = None
        self.context = ""
        self.generated_tokens = []  # track generated tokens
        self.generated_special_tokens = []  # track generated special tokens
        self.reset()
        self.embeds = None
        self.system_embeds = None

        # ========== Sliding-window state ==========
        # reset() above already initialized most of these; they are repeated here so that
        # every attribute and its type is declared in __init__.
        self._unit_history: List[Dict[str, Any]] = []
        self._next_unit_id: int = 0
        self._pending_unit_id: Optional[int] = None
        self._pending_unit_start_cache_len: int = 0
        self._system_preserve_length: int = 0
        self._position_offset: int = 0
        self._window_config = DuplexWindowConfig()
        self._window_enabled: bool = True
        self._rope_inv_freq_cache: Dict[Tuple, torch.Tensor] = {}

        # ========== Context-preserving sliding-window state ==========
        # Cache layout at registration: [prefix] [suffix] [units...]
        # Layout after the first slide:  [prefix] [previous_marker + content] [suffix] [units...]
        #                                 fixed    dynamic sliding region       fixed
        self._preserve_prefix_length: int = 0  # length of the original prefix (fixed)
        self._previous_content_length: int = 0  # length of the previous section incl. marker (changes)
        self._suffix_token_ids: List[int] = []  # suffix token ids (e.g. <|im_end|>)

        # "previous" marker, added dynamically on the first slide
        self._previous_marker: str = "\n\nprevious: "  # fixed marker prefix
        self._previous_marker_token_ids: List[int] = []  # marker token ids (set at registration)
        self._has_previous: bool = False  # whether the marker has been inserted yet

        # "previous" content
        self._previous_text: str = ""  # accumulated generated text (without marker)
        self._previous_token_ids: List[int] = []  # full previous token ids (with marker)

        # ========== Verification statistics ==========
        self._sliding_event_count: int = 0  # number of sliding events
        self._total_dropped_tokens: int = 0  # total tokens dropped
        self._total_dropped_units: int = 0  # total units dropped

    def sliding_embeds(self):
        # tmp = system_embeds
        # tmp +-》 embeds after 5s
        # reset
        # feed
        pass

    def reset(self):
        self.context = ""
        self.cache = None
        self.generated_tokens = []
        self.generated_special_tokens = []
        self.embeds = None
        self.system_embeds = None

        # 滑窗状态重置
        old_unit_count = len(self._unit_history) if hasattr(self, "_unit_history") else 0
        self._unit_history = []
        self._next_unit_id = 0
        self._pending_unit_id = None
        self._pending_unit_start_cache_len = 0
        self._system_preserve_length = 0
        self._position_offset = 0
        self._rope_inv_freq_cache = {}

        # Context 保留状态重置
        self._preserve_prefix_length = 0
        self._previous_content_length = 0
        self._suffix_token_ids = []
        self._previous_marker = "\n\nprevious: "
        self._previous_marker_token_ids = []
        self._has_previous = False
        self._previous_text = ""
        self._previous_token_ids = []

        # 验证统计
        self._sliding_event_count = 0  # 滑窗触发次数
        self._total_dropped_tokens = 0  # 总共丢弃的 token 数
        self._total_dropped_units = 0  # 总共丢弃的 unit 数

        if old_unit_count > 0:
            logger.info("[SW] reset: cleared %d units, all sliding window state reset", old_unit_count)

    def get_cache_length(self) -> int:
        if self.cache is None:
            return 0
        if isinstance(self.cache, DynamicCache):
            if len(self.cache.key_cache) > 0 and self.cache.key_cache[0].numel() > 0:
                return self.cache.key_cache[0].shape[2]
            return 0
        # Tuple cache format
        return self.cache[0][0].shape[2]

    def get_total_generated_tokens(self) -> int:
        return sum(len(u.get("generated_tokens", [])) for u in self._unit_history)

    def register_unit_start(self) -> int:
        """Open a new unit: reserve the next unit id and snapshot the current cache length.

        The id counter itself is only advanced by ``register_unit_end``.

        Returns:
            The id assigned to the pending unit.
        """
        unit_id = self._next_unit_id
        self._pending_unit_id = unit_id
        start_len = self.get_cache_length()
        self._pending_unit_start_cache_len = start_len
        logger.info(
            "[SW] unit_start: pending_unit_id=%d, cache_len=%d, preserve=%d, units=%d",
            unit_id,
            start_len,
            self._system_preserve_length,
            len(self._unit_history),
        )
        return unit_id

    def register_unit_end(
        self,
        input_type: str,
        generated_tokens: Optional[List[int]] = None,
        is_listen: bool = False,
        generated_text: Optional[str] = None,
    ):
        """Close the pending unit and record it in ``_unit_history``.

        Call this after the ``</unit>`` token has been fed. The unit length is the cache
        growth since ``register_unit_start``. Units with a non-positive length are logged
        and not recorded. When a unit was pending, the pending state is cleared and the
        next unit id is advanced whether or not the unit was recorded.

        Args:
            input_type: "audio" / "video" / "omni" / "system".
            generated_tokens: Token ids generated in this unit. The list is copied, so later
                changes by the caller do not affect the recorded history.
            is_listen: Whether the unit ended in the listen state.
            generated_text: Text generated in this unit (used by the context-preserving mode).
        """
        if self._pending_unit_id is None:
            logger.warning("register_unit_end called without register_unit_start")
            return

        # Length of this unit = cache growth since the unit started.
        current_cache_len = self.get_cache_length()
        unit_len = current_cache_len - self._pending_unit_start_cache_len

        if unit_len > 0:
            # Copy so a caller that reuses or clears its per-unit buffer cannot silently
            # rewrite the recorded history.
            tokens = list(generated_tokens) if generated_tokens else []
            text = generated_text or ""
            entry = {
                "unit_id": self._pending_unit_id,
                "length": unit_len,
                "type": input_type,
                "generated_tokens": tokens,
                "generated_text": text,  # used by the context-preserving mode
                "is_listen": is_listen,
            }
            self._unit_history.append(entry)
            gen_text_preview = (text[:30] + "...") if len(text) > 30 else text
            logger.info(
                "[SW] unit_end: unit_id=%d type=%s len=%d gen_tokens=%d is_listen=%s | "
                "cache=%d preserve=%d total_units=%d | text='%s'",
                self._pending_unit_id,
                input_type,
                unit_len,
                len(tokens),
                is_listen,
                current_cache_len,
                self._system_preserve_length,
                len(self._unit_history),
                gen_text_preview,
            )
        else:
            logger.warning(
                "[SW] unit_end: unit_id=%d has zero length (start=%d, current=%d), not recorded",
                self._pending_unit_id,
                self._pending_unit_start_cache_len,
                current_cache_len,
            )

        self._pending_unit_id = None
        self._pending_unit_start_cache_len = 0
        self._next_unit_id += 1

    def register_system_prompt(self):
        """Record the current cache length as the protected system-prompt span.

        Call this once the system prompt has been prefilled. The basic sliding window never
        drops tokens below this length.
        """
        protected_len = self.get_cache_length()
        self._system_preserve_length = protected_len
        logger.info(
            "[SW] system_prompt registered: preserve_length=%d (will be protected from sliding)",
            protected_len,
        )

    # ==================== Sliding-window core methods ====================

    def _get_rope_theta(self) -> float:
        """获取模型的 rope_theta 配置"""
        return float(getattr(self.m.config, "rope_theta", 10000.0))

    def _drop_tokens_from_cache(self, length: int) -> bool:
        """Remove ``length`` tokens right after the protected system-prompt span.

        The removed span is ``[preserve, preserve + length)`` with
        ``preserve = self._system_preserve_length``. The work is delegated to the
        ``drop_tokens_from_cache`` helper, which also receives the current position offset,
        ``rope_theta`` and the inv-freq cache (its success log says the RoPE positions are
        reindexed). On success the returned cache and position offset are stored; on
        failure the state is left untouched. Both DynamicCache and tuple caches are accepted.

        Returns:
            The helper's success flag (False immediately if there is no cache or ``length <= 0``).
        """
        if self.cache is None or length <= 0:
            logger.warning("[SW] _drop_tokens_from_cache: cache is None or length<=0 (length=%d)", length)
            return False

        preserve = self._system_preserve_length
        kind = "DynamicCache" if isinstance(self.cache, DynamicCache) else "TupleCache"
        len_before = self.get_cache_length()
        offset_before = self._position_offset

        logger.debug(
            "[SW] _drop_tokens_from_cache: type=%s, drop=%d tokens from [%d, %d), cache=%d, preserve=%d",
            kind,
            length,
            preserve,
            preserve + length,
            len_before,
            preserve,
        )

        new_cache, new_offset, success = drop_tokens_from_cache(
            cache=self.cache,
            length=length,
            preserve=preserve,
            position_offset=offset_before,
            rope_theta=self._get_rope_theta(),
            inv_freq_cache=self._rope_inv_freq_cache,
        )

        if not success:
            logger.error(
                "[SW] _drop_tokens_from_cache: FAILED to drop %d tokens (cache=%d, preserve=%d)",
                length,
                len_before,
                preserve,
            )
            return success

        # For DynamicCache the helper returns the same object (modified in place).
        self.cache = new_cache
        self._position_offset = new_offset
        logger.debug(
            "[SW] _drop_tokens_from_cache: SUCCESS cache %d -> %d, offset %d -> %d (RoPE reindexed)",
            len_before,
            self.get_cache_length(),
            offset_before,
            self._position_offset,
        )
        return success

    def _drop_unit(self, unit_id: int) -> bool:
        """Evict every history entry with ``unit_id`` and remove its tokens from the cache.

        NOTE(review): the tokens are removed with ``_drop_tokens_from_cache``, which always
        cuts the span starting at ``_system_preserve_length``. That is only the right span
        when ``unit_id``'s entries are the oldest units still in the cache. Confirm that
        callers guarantee this.

        Returns:
            True when the tokens were removed and the entries deleted from ``_unit_history``.
            False if the unit is unknown, if its total length is not positive (its entries
            are still removed from the history), or if the cache drop failed.
        """
        matches = [entry for entry in self._unit_history if entry["unit_id"] == unit_id]
        if not matches:
            logger.warning("[SW] _drop_unit: unit_id=%d not found", unit_id)
            return False

        span = sum(entry["length"] for entry in matches)
        if span <= 0:
            logger.warning("[SW] _drop_unit: unit_id=%d has zero total length, removing from history", unit_id)
            for entry in matches:
                self._unit_history.remove(entry)
            return False

        len_before = self.get_cache_length()
        dropped = self._drop_tokens_from_cache(span)
        if not dropped:
            logger.error(
                "[SW] _drop_unit: failed to drop %d tokens for unit_id=%d from cache (cache=%d, preserve=%d)",
                span,
                unit_id,
                len_before,
                self._system_preserve_length,
            )
            return False

        len_after = self.get_cache_length()
        for entry in matches:
            logger.info(
                "[SW] 🗑️ DROPPED unit_id=%d type=%s len=%d gen_tokens=%d | cache %d -> %d, offset=%d",
                entry["unit_id"],
                entry["type"],
                entry["length"],
                len(entry.get("generated_tokens", [])),
                len_before,
                len_after,
                self._position_offset,
            )
            self._unit_history.remove(entry)

        return True

    def _drop_next_unit(self) -> bool:
        """Drop the oldest droppable unit, i.e. one whose type is not "system".

        Entries without a ``unit_id`` are ignored. If dropping a candidate fails, the next
        candidate is tried.

        NOTE(review): ``_drop_unit`` removes tokens starting at the system-preserve boundary.
        If a "system" entry precedes a droppable unit, skipping it here would cut the system
        entry's tokens instead of the chosen unit's. Verify that "system" entries never
        precede droppable units in the history.

        Returns:
            True if a unit was dropped.
        """
        # Iterate over a snapshot: _drop_unit removes entries from _unit_history (also on its
        # zero-length failure path), and mutating a list while iterating it skips elements.
        for entry in list(self._unit_history):
            unit_id = entry.get("unit_id")
            if unit_id is None:
                continue
            # Skip system-type units.
            if entry.get("type") == "system":
                logger.debug("[SW] _drop_next_unit: skipping system unit_id=%d", unit_id)
                continue
            logger.debug("[SW] _drop_next_unit: attempting to drop unit_id=%d", unit_id)
            if self._drop_unit(unit_id):
                return True
        logger.debug("[SW] _drop_next_unit: no droppable unit found in %d units", len(self._unit_history))
        return False

    def enforce_window(self) -> bool:
        """Apply the basic sliding window using only the cache length (same policy as simplex).

        If the cache holds more than ``basic_window_high_tokens`` tokens, the oldest droppable
        units are evicted one at a time. Eviction stops once the cache is at or below
        ``basic_window_low_tokens``, or when nothing more can be dropped. After any eviction
        the statistics counters are updated. A consistency check then compares
        ``preserve + sum(unit lengths)`` with the actual cache length.

        Returns:
            True if at least one unit was dropped.
        """
        if not self._window_enabled:
            logger.info("[SW] enforce_window: window disabled, skip")
            return False

        cfg = self._window_config
        high_water = cfg.basic_window_high_tokens
        low_water = cfg.basic_window_low_tokens
        start_len = self.get_cache_length()

        # Below the high water mark: nothing to do.
        if start_len <= high_water:
            logger.debug(
                "[SW] enforce_window: cache=%d <= high_water=%d, no sliding needed",
                start_len,
                high_water,
            )
            return False

        logger.info(
            "[SW] ⚡ SLIDING TRIGGERED: cache=%d > high_water=%d, target=low_water=%d",
            start_len,
            high_water,
            low_water,
        )

        units_dropped = 0
        current_len = start_len
        while current_len > low_water:
            if not self._drop_next_unit():
                logger.warning("[SW] enforce_window: no more units to drop, stopping")
                break
            units_dropped += 1
            current_len = self.get_cache_length()

        if not units_dropped:
            return False

        # Update statistics counters.
        self._sliding_event_count += 1
        self._total_dropped_tokens += start_len - current_len
        self._total_dropped_units += units_dropped

        # Consistency check: protected span plus recorded units should equal the cache length.
        units_total = sum(u["length"] for u in self._unit_history)
        expected = self._system_preserve_length + units_total
        consistent = expected == current_len
        logger.info(
            "[SW] ✅ SLIDING DONE: cache %d -> %d, dropped %d units, remaining %d units | "
            "consistency: expected=%d actual=%d %s",
            start_len,
            current_len,
            units_dropped,
            len(self._unit_history),
            expected,
            current_len,
            "✓" if consistent else "✗ MISMATCH!",
        )
        if not consistent:
            logger.error(
                "[SW] ❌ CONSISTENCY ERROR! preserve=%d + sum(units)=%d != cache=%d, offset=%d",
                self._system_preserve_length,
                units_total,
                current_len,
                self._position_offset,
            )

        return True

    # ==================== Sliding-window methods with context retention ====================

    def register_system_prompt_with_context(
        self,
        suffix_token_ids: Optional[List[int]] = None,
        context_previous_marker: str = "\n\nprevious: ",
    ):
        """Register the system prompt for the context-preserving sliding mode.

        Cache layout at registration: [prefix] [suffix] [units...]
        Layout after the first slide:  [prefix] [context_previous_marker + content] [suffix] [units...]

        When this is called, the cache should contain only the prefix (no previous marker).
        The suffix is fed afterwards, but its length is already counted in
        ``_system_preserve_length``.

        Args:
            suffix_token_ids: Token ids of the suffix (e.g. the id of ``<|im_end|>``).
            context_previous_marker: Marker prefix for the previous section, e.g. "\\n\\nprevious: ".
        """
        # The prefix is whatever the cache holds now; it stays fixed from here on.
        prefix_len = self.get_cache_length()
        self._preserve_prefix_length = prefix_len
        self._previous_content_length = 0  # no previous content yet
        suffix_ids = suffix_token_ids or []
        self._suffix_token_ids = suffix_ids
        # Protected length = prefix + suffix (there is no previous section initially).
        self._system_preserve_length = prefix_len + len(suffix_ids)

        # Previous-section state.
        self._previous_marker = context_previous_marker
        if self.tokenizer:
            marker_ids = self.tokenizer.encode(context_previous_marker, add_special_tokens=False)
        else:
            marker_ids = []
        self._previous_marker_token_ids = marker_ids
        self._has_previous = False
        self._previous_text = ""
        self._previous_token_ids = []

        logger.info(
            "[SW-CTX] system_prompt registered: prefix_len=%d, suffix_len=%d, marker='%s' (%d tokens)",
            prefix_len,
            len(suffix_ids),
            context_previous_marker.replace("\n", "\\n"),
            len(marker_ids),
        )
        self.log_cache_layout("After register_system_prompt")

    def _extract_generated_text(self, units: List[Dict[str, Any]]) -> Tuple[str, List[int]]:
        """从 units 中提取生成的文本和 token ids

        Args:
            units: 要提取的 unit 列表

        Returns:
            (text, token_ids): 拼接后的文本和 token ids(过滤掉 special tokens)
        """
        text_parts = []
        token_ids = []

        for u in units:
            # 只保留非 listen 的 unit 的生成内容
            if u.get("is_listen", False):
                continue
            gen_text = u.get("generated_text", "")
            gen_tokens = u.get("generated_tokens", [])

            # 过滤文本中的 special tokens
            if gen_text:
                clean_text = gen_text
                for st in self._all_special_tokens_text:
                    clean_text = clean_text.replace(st, "")
                if clean_text.strip():
                    text_parts.append(clean_text)

            # 过滤掉 special tokens
            if gen_tokens:
                filtered_tokens = [t for t in gen_tokens if t not in self._all_special_ids]
                token_ids.extend(filtered_tokens)

        return "".join(text_parts), token_ids

    def _rebuild_cache_with_previous(
        self,
        new_previous_tokens: List[int],
        units_to_keep_len: Optional[int] = None,
    ) -> bool:
        """重建 cache,把新的 previous 内容插入到 prefix 和 suffix 之间

        Cache 布局变化:
        [prefix] [old_prev] [suffix] [old_units]  →  [prefix] [new_prev] [suffix] [remaining_units]

        Args:
            new_previous_tokens: 新的 previous token ids
            units_to_keep_len: 需要保留的 units 长度(从 cache 末尾往回算)
                               如果为 None,根据 unit_history 计算

        Returns:
            是否成功重建
        """
        if self.cache is None:
            logger.warning("[SW-CTX] _rebuild_cache_with_previous: cache is None")
            return False

        old_previous_len = self._previous_content_length
        new_previous_len = len(new_previous_tokens)
        suffix_len = len(self._suffix_token_ids)
        total_cache_len = self.get_cache_length()

        # 计算需要保留的 units 长度
        if units_to_keep_len is None:
            units_to_keep_len = sum(u["length"] for u in self._unit_history)

        # 特殊情况:如果 previous 没有变化(新旧都为空),不需要重建 cache 的 prefix+suffix 部分
        # 但仍需要对 units 做 RoPE reindex(因为删除了一个 unit,位置变了)
        if new_previous_len == 0 and old_previous_len == 0:
            # Cache 布局: [prefix(7)] [suffix(1)] [units...]
            # 只需保留 prefix + suffix + remaining_units
            preserve_len = self._preserve_prefix_length + suffix_len

            # 简单地截取 cache:[prefix+suffix] + [remaining_units]
            # remaining_units 在 cache 末尾
            if units_to_keep_len > 0:
                # [0:preserve_len] + [total-units_to_keep_len:total]
                prefix_suffix_cache = self._slice_cache(0, preserve_len)
                units_cache = self._slice_cache(total_cache_len - units_to_keep_len, None)

                # 计算被删除的 tokens 数量
                dropped_tokens = total_cache_len - preserve_len - units_to_keep_len

                # 对 units 做 RoPE reindex:位置从 (preserve_len + dropped_tokens) 移到 preserve_len
                # 注意:不加 position_offset,因为 cache 位置已经被压缩(从 0 开始)
                if dropped_tokens > 0:
                    old_start = preserve_len + dropped_tokens
                    new_start = preserve_len
                    logger.debug(
                        "[SW-CTX] RoPE reindex (no-op path): old_pos=[%d:%d] -> new_pos=[%d:%d], length=%d",
                        old_start,
                        old_start + units_to_keep_len,
                        new_start,
                        new_start + units_to_keep_len,
                        units_to_keep_len,
                    )
                    units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

                self.cache = self._concat_caches(prefix_suffix_cache, units_cache)
            else:
                self.cache = self._slice_cache(0, preserve_len)

            logger.info(
                "[SW-CTX] _rebuild_cache_with_previous (no-op): previous unchanged (0->0), "
                "just removed unit from cache, cache=%d, units_kept=%d",
                self.get_cache_length(),
                units_to_keep_len,
            )
            return True

        # 1. 获取 prefix cache(固定不变)
        prefix_end = self._preserve_prefix_length
        prefix_cache = self._slice_cache(0, prefix_end)

        # 2. 获取需要保留的 units cache(从末尾取)
        units_start_in_old_cache = total_cache_len - units_to_keep_len
        units_cache = None
        if units_to_keep_len > 0:
            units_cache = self._slice_cache(units_start_in_old_cache, None)

        # 3. 计算新 previous + suffix 的 cache(需要 forward)
        # 合并 previous tokens 和 suffix tokens
        prev_suffix_tokens = new_previous_tokens + self._suffix_token_ids
        prev_suffix_len = len(prev_suffix_tokens)

        new_prefix_prev_suffix_cache = prefix_cache
        if prev_suffix_len > 0:
            # Embed tokens
            prev_suffix_embeds = self.embed_tokens(prev_suffix_tokens)
            # 计算起始位置(在 prefix 之后)
            start_pos = self._preserve_prefix_length + self._position_offset

            # Forward 计算 KV cache
            with torch.no_grad():
                device = prev_suffix_embeds.device
                position_ids = torch.arange(
                    start_pos,
                    start_pos + prev_suffix_len,
                    device=device,
                ).unsqueeze(0)

                # 用 prefix cache 作为 past_key_values
                outputs = self.m(
                    inputs_embeds=(
                        prev_suffix_embeds.unsqueeze(0) if prev_suffix_embeds.dim() == 2 else prev_suffix_embeds
                    ),
                    position_ids=position_ids,
                    past_key_values=prefix_cache,
                    use_cache=True,
                    return_dict=True,
                )
                # 新 cache 包含 prefix + new_previous + suffix
                new_prefix_prev_suffix_cache = outputs.past_key_values

        # 4. 调整 units cache 的 RoPE
        # 新布局:[prefix] [new_prev] [suffix] [units]
        # 注意:不加 position_offset,因为 cache 位置已经被压缩(从 0 开始)
        new_system_total = prefix_end + new_previous_len + suffix_len
        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            old_start = units_start_in_old_cache
            new_start = new_system_total

            if old_start != new_start:
                units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

        # 5. 拼接新 cache
        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            self.cache = self._concat_caches(new_prefix_prev_suffix_cache, units_cache)
        else:
            self.cache = new_prefix_prev_suffix_cache

        # 6. 更新长度
        self._previous_content_length = new_previous_len
        # 总保护长度 = prefix + previous + suffix
        self._system_preserve_length = prefix_end + new_previous_len + suffix_len

        # 打印详细的 cache 布局信息
        prev_text_preview = self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text
        suffix_preview = self.tokenizer.decode(self._suffix_token_ids) if self._suffix_token_ids else ""
        logger.info(
            "[SW-CTX] _rebuild_cache_with_previous:\n"
            "  prefix_len=%d | previous: %d tokens '%s' | suffix: %d tokens '%s'\n"
            "  cache: %d -> %d, units_kept=%d, preserve=%d",
            self._preserve_prefix_length,
            new_previous_len,
            prev_text_preview,
            suffix_len,
            suffix_preview,
            old_previous_len + self._preserve_prefix_length + suffix_len + units_to_keep_len,
            self.get_cache_length(),
            units_to_keep_len,
            self._system_preserve_length,
        )
        return True

    def _slice_cache(self, start: int, end: Optional[int], clone: bool = True):
        """切片 cache

        Args:
            start: 起始位置
            end: 结束位置(None 表示到末尾)
            clone: 是否克隆(默认 True,防止共享内存问题)
        """
        if self.cache is None:
            return None
        if isinstance(self.cache, DynamicCache):
            # DynamicCache
            new_key_cache = [
                k[:, :, start:end, :].clone() if clone else k[:, :, start:end, :] for k in self.cache.key_cache
            ]
            new_value_cache = [
                v[:, :, start:end, :].clone() if clone else v[:, :, start:end, :] for v in self.cache.value_cache
            ]
            new_cache = DynamicCache()
            new_cache.key_cache = new_key_cache
            new_cache.value_cache = new_value_cache
            return new_cache
        else:
            # Tuple cache
            if clone:
                return tuple(
                    (layer[0][:, :, start:end, :].clone(), layer[1][:, :, start:end, :].clone()) for layer in self.cache
                )
            else:
                return tuple((layer[0][:, :, start:end, :], layer[1][:, :, start:end, :]) for layer in self.cache)

    def _get_cache_len(self, cache) -> int:
        """获取 cache 长度"""
        if cache is None:
            return 0
        if isinstance(cache, DynamicCache):
            if len(cache.key_cache) > 0 and cache.key_cache[0].numel() > 0:
                return cache.key_cache[0].shape[2]
            return 0
        # Tuple cache
        if cache and cache[0] and cache[0][0] is not None:
            return cache[0][0].shape[2]
        return 0

    def _concat_caches(self, cache1, cache2):
        """拼接两个 cache"""
        if cache1 is None:
            return cache2
        if cache2 is None:
            return cache1

        if isinstance(cache1, DynamicCache):
            new_cache = DynamicCache()
            new_cache.key_cache = [torch.cat([k1, k2], dim=2) for k1, k2 in zip(cache1.key_cache, cache2.key_cache)]
            new_cache.value_cache = [
                torch.cat([v1, v2], dim=2) for v1, v2 in zip(cache1.value_cache, cache2.value_cache)
            ]
            return new_cache
        else:
            # Tuple cache
            return tuple(
                (
                    torch.cat([layer1[0], layer2[0]], dim=2),
                    torch.cat([layer1[1], layer2[1]], dim=2),
                )
                for layer1, layer2 in zip(cache1, cache2)
            )

    def _reindex_rope_for_cache(self, cache, old_start: int, new_start: int, length: int):
        """对 cache 进行 RoPE 位置调整"""
        if cache is None or length <= 0:
            return cache

        device = None
        if isinstance(cache, DynamicCache):
            device = cache.key_cache[0].device if cache.key_cache else None
        else:
            device = cache[0][0].device if cache and cache[0] else None

        if device is None:
            return cache

        old_positions = torch.arange(old_start, old_start + length, device=device, dtype=torch.long)
        new_positions = torch.arange(new_start, new_start + length, device=device, dtype=torch.long)

        from .sliding_utils import realign_rotary_suffix

        rope_theta = self._get_rope_theta()

        if isinstance(cache, DynamicCache):
            new_key_cache = []
            for k in cache.key_cache:
                new_k = realign_rotary_suffix(k, old_positions, new_positions, rope_theta, self._rope_inv_freq_cache)
                new_key_cache.append(new_k)
            cache.key_cache = new_key_cache
            return cache
        else:
            new_cache = []
            for layer in cache:
                new_k = realign_rotary_suffix(
                    layer[0], old_positions, new_positions, rope_theta, self._rope_inv_freq_cache
                )
                new_cache.append((new_k, layer[1]))
            return tuple(new_cache)

    def _update_previous(
        self,
        new_text: str,
        new_tokens: List[int],
        max_tokens: int,
    ) -> None:
        """更新 previous 上下文(同时更新 cache)

        首次滑窗时动态添加 marker + 文本,后续滑窗追加文本
        超过 max_tokens 时截断内容(保留 marker)
        同时重建 cache 以保持一致

        Args:
            new_text: 新增的文本
            new_tokens: 新增的 token ids
            max_tokens: previous 内容的最大 token 数(不含 marker)
        """
        marker_len = len(self._previous_marker_token_ids)
        tokens_to_drop = 0

        # 如果没有新内容,不添加 marker,但仍需重建 cache
        if not new_tokens and not new_text:
            logger.info("[SW-CTX] _update_previous: no new content, skip adding to previous")
            # 仍然需要重建 cache(因为删除了 unit)
            self._rebuild_cache_with_previous(self._previous_token_ids)
            return

        if not self._has_previous:
            # 首次有实际内容时:添加 marker + 文本
            self._previous_text = new_text
            self._previous_token_ids = self._previous_marker_token_ids.copy() + new_tokens
            self._has_previous = True
            logger.info(
                "[SW-CTX] _update_previous: first slide with content, added marker + %d tokens",
                len(new_tokens),
            )
        else:
            # 后续滑窗:追加文本到 previous
            self._previous_text += new_text
            self._previous_token_ids.extend(new_tokens)

        # 计算内容部分的 token 数(不含 marker)
        content_token_count = len(self._previous_token_ids) - marker_len

        # 检查是否需要截断内容(保留 marker)
        if content_token_count > max_tokens:
            # 截断左侧内容,保留 marker + 最新的 max_tokens 内容
            tokens_to_drop = content_token_count - max_tokens
            old_text = self._previous_text
            # 保留 marker + 截断后的内容
            content_tokens = self._previous_token_ids[marker_len + tokens_to_drop :]
            self._previous_token_ids = self._previous_marker_token_ids.copy() + content_tokens
            # 重新 decode 文本(只 decode 内容部分)
            try:
                self._previous_text = self.tokenizer.decode(
                    content_tokens,
                    skip_special_tokens=True,
                )
            except Exception as e:
                logger.warning("[SW-CTX] _update_previous: decode failed: %s", e)

            # 左截断日志
            logger.info(
                "[SW-CTX] ⚠️ LEFT TRUNCATION: previous exceeded max_tokens=%d\n"
                "  before: %d content tokens, text='%s'\n"
                "  after:  %d content tokens, text='%s'\n"
                "  dropped %d tokens from left",
                max_tokens,
                content_token_count,
                old_text[:60] + "..." if len(old_text) > 60 else old_text,
                len(content_tokens),
                self._previous_text[:60] + "..." if len(self._previous_text) > 60 else self._previous_text,
                tokens_to_drop,
            )

        # 重建 cache
        self._rebuild_cache_with_previous(self._previous_token_ids)

        prev_preview = self._previous_text[:80] + "..." if len(self._previous_text) > 80 else self._previous_text
        content_len = len(self._previous_token_ids) - marker_len
        if tokens_to_drop > 0:
            logger.info(
                "[SW-CTX] _update_previous: +%d tokens, -%d truncated -> %d content tokens (marker=%d) | '%s'",
                len(new_tokens),
                tokens_to_drop,
                content_len,
                marker_len,
                prev_preview,
            )
        else:
            logger.info(
                "[SW-CTX] _update_previous: +%d tokens -> %d content tokens (marker=%d) | '%s'",
                len(new_tokens),
                content_len,
                marker_len,
                prev_preview,
            )

    def _drop_unit_with_context(
        self,
        unit_id: int,
        max_previous_tokens: int,
    ) -> Tuple[bool, str, List[int]]:
        """移除指定 unit 并返回其生成内容(用于 context 保留)

        流程:
        1. 提取 unit 的生成内容
        2. 先从 cache 移除 unit(不包括 prefix+previous)
        3. 追加生成内容到 previous
        4. 重建 cache(在 _update_previous 中完成)

        Args:
            unit_id: 要移除的 unit ID
            max_previous_tokens: previous 最大 token 数

        Returns:
            (success, extracted_text, extracted_tokens): 是否成功,提取的文本和 tokens
        """
        entries = [u for u in self._unit_history if u["unit_id"] == unit_id]
        if not entries:
            logger.warning("[SW-CTX] _drop_unit_with_context: unit_id=%d not found", unit_id)
            return False, "", []

        # 提取生成内容
        extracted_text, extracted_tokens = self._extract_generated_text(entries)

        # 计算总长度
        total_len = sum(e["length"] for e in entries)
        if total_len <= 0:
            logger.warning("[SW-CTX] _drop_unit_with_context: unit_id=%d has zero length", unit_id)
            for e in entries:
                self._unit_history.remove(e)
            return False, extracted_text, extracted_tokens

        cache_before = self.get_cache_length()

        # 从 unit_history 中移除(先记录,以便后续处理)
        for e in entries:
            self._unit_history.remove(e)

        # 注意:这里不再调用 _drop_tokens_from_cache
        # 因为 _update_previous 会重建整个 cache

        # 更新 previous(同时重建 cache)
        self._update_previous(extracted_text, extracted_tokens, max_previous_tokens)

        cache_after = self.get_cache_length()
        for e in entries:
            logger.info(
                "[SW-CTX] 🗑️ DROPPED unit_id=%d type=%s len=%d, extracted=%d chars | cache %d -> %d",
                e["unit_id"],
                e["type"],
                e["length"],
                len(extracted_text),
                cache_before,
                cache_after,
            )

        return True, extracted_text, extracted_tokens

    def _drop_next_unit_with_context(self, max_previous_tokens: int) -> bool:
        """移除最早的一个非 system unit(带 context 保留)"""
        for entry in self._unit_history:
            unit_id = entry.get("unit_id")
            if unit_id is None:
                continue
            if entry.get("type") == "system":
                continue
            success, _, _ = self._drop_unit_with_context(unit_id, max_previous_tokens)
            if success:
                return True
        return False

    def enforce_window_with_context(self) -> bool:
        """Slide the window in "context" mode, keeping evicted output as previous text.

        While ``len(self._unit_history)`` exceeds ``context_max_units``, the oldest
        droppable unit is evicted via ``_drop_next_unit_with_context``. Its generated
        content goes into the previous section and the cache is rebuilt there.
        An over-long previous section is truncated inside ``_update_previous``.

        Nothing happens when the window is disabled. When the configured mode is not
        ``"context"``, the call is delegated to ``enforce_window`` and its result returned.

        Returns:
            True if at least one unit was dropped.
        """
        if not self._window_enabled:
            logger.info("[SW-CTX] enforce_window_with_context: window disabled, skip")
            return False

        cfg = self._window_config
        if cfg.sliding_window_mode != "context":
            # Not context mode: fall back to the basic sliding window.
            return self.enforce_window()

        max_units = cfg.context_max_units
        cache_len_before = self.get_cache_length()
        units_before = len(self._unit_history)

        # Only the unit count triggers sliding; previous overflow is handled in _update_previous.
        if units_before <= max_units:
            logger.debug(
                "[SW-CTX] enforce_window_with_context: no sliding needed (units=%d/%d)",
                units_before,
                max_units,
            )
            self.log_cache_layout("No sliding (units=%d/%d)" % (units_before, max_units))
            return False

        slide_tag = "slide #%d" % (self._sliding_event_count + 1)
        logger.info(
            "[SW-CTX] ⚡ SLIDING TRIGGERED (%s): units=%d > max_units=%d, previous=%d tokens",
            slide_tag,
            units_before,
            max_units,
            len(self._previous_token_ids),
        )
        self.log_cache_layout("Before %s" % slide_tag)

        dropped_count = 0
        while len(self._unit_history) > max_units:
            dropped = self._drop_next_unit_with_context(cfg.context_previous_max_tokens)
            if not dropped:
                logger.warning("[SW-CTX] enforce_window_with_context: no more units to drop")
                break
            dropped_count += 1

        cache_len_after = self.get_cache_length()
        if dropped_count == 0:
            return False

        self._sliding_event_count += 1
        self._total_dropped_tokens += cache_len_before - cache_len_after
        self._total_dropped_units += dropped_count

        # Consistency check: protected length + remaining unit lengths should equal the cache.
        expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history)
        consistency = "✓" if expected == cache_len_after else "✗ MISMATCH!"
        logger.info(
            "[SW-CTX] ✅ SLIDING DONE: cache %d -> %d, dropped %d units, remaining %d units, "
            "previous=%d tokens | consistency: %s",
            cache_len_before,
            cache_len_after,
            dropped_count,
            len(self._unit_history),
            len(self._previous_token_ids),
            consistency,
        )
        self.log_cache_layout("After slide #%d" % self._sliding_event_count)
        return True

    def get_previous_context(self) -> Tuple[str, List[int]]:
        """Return the accumulated previous context.

        Returns:
            ``(previous_text, previous_token_ids)``. The token list is a shallow copy,
            so callers may modify it without affecting internal state.
        """
        token_snapshot = list(self._previous_token_ids)
        return self._previous_text, token_snapshot

    # ==================== 调试方法 ====================

    def log_cache_layout(self, tag: str = "") -> None:
        """Log the current cache layout (debug helper).

        In "context" mode the log shows ``[prefix] [previous] [suffix] [units]`` sizes
        plus the decoded previous and suffix token ids. In any other mode it shows
        ``[system] [units]``.

        Args:
            tag: Label inserted near the start of the log message.
        """
        cache_len = self.get_cache_length()
        units_len = sum(u["length"] for u in self._unit_history)

        if self._window_config.sliding_window_mode != "context":
            logger.info(
                "[SW] %s Cache Layout: [system: %d] [units: %d] | cache=%d",
                tag,
                self._system_preserve_length,
                units_len,
                cache_len,
            )
            return

        prev_ids = self._previous_token_ids
        suffix_ids = self._suffix_token_ids

        def _decoded(ids):
            # Decode only when there are ids and a tokenizer is available.
            text = ""
            if len(ids) > 0 and self.tokenizer:
                text = self.tokenizer.decode(ids)
            return repr(text) if text else "(empty)"

        logger.info(
            "[SW-CTX] %s Cache Layout:\n"
            "  [prefix: %d tokens] [previous: %d tokens] [suffix: %d tokens] [units: %d tokens]\n"
            "  preserve=%d | cache=%d | has_previous=%s\n"
            "  previous_full: %s\n"
            "  suffix: %s",
            tag,
            self._preserve_prefix_length,
            len(prev_ids),
            len(suffix_ids),
            units_len,
            self._system_preserve_length,
            cache_len,
            self._has_previous,
            _decoded(prev_ids),
            _decoded(suffix_ids),
        )

    def get_window_stats(self) -> Dict[str, Any]:
        """Collect sliding-window state and configuration into a plain dict.

        Returns:
            Dict with cache/unit sizes, preserve length, position offset, window flag,
            generation counters, the window config, and context-retention sizes
            (prefix, previous, suffix, template presence).
        """
        cfg = self._window_config
        lengths = [unit["length"] for unit in self._unit_history]
        config_view = {
            "sliding_window_mode": cfg.sliding_window_mode,
            "basic_window_high_tokens": cfg.basic_window_high_tokens,
            "basic_window_low_tokens": cfg.basic_window_low_tokens,
            "context_previous_max_tokens": cfg.context_previous_max_tokens,
            "context_max_units": cfg.context_max_units,
        }
        return {
            "cache_length": self.get_cache_length(),
            "unit_count": len(self._unit_history),
            "unit_lengths": lengths,
            "unit_total_length": sum(lengths),
            "system_preserve_length": self._system_preserve_length,
            "position_offset": self._position_offset,
            "window_enabled": self._window_enabled,
            "total_generated_tokens": self.get_total_generated_tokens(),
            "pending_unit_id": self._pending_unit_id,
            "next_unit_id": self._next_unit_id,
            "config": config_view,
            # Context-retention fields.
            "preserve_prefix_length": self._preserve_prefix_length,
            "previous_content_length": self._previous_content_length,
            "suffix_token_count": len(self._suffix_token_ids),
            "previous_text_length": len(self._previous_text),
            "previous_token_count": len(self._previous_token_ids),
            "has_system_template": self._system_prompt_template is not None,
        }

    def _verify_consistency(self) -> bool:
        """验证 unit 历史与 cache 长度一致"""
        expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history)
        actual = self.get_cache_length()
        return expected == actual

    def dump_unit_history(self, prefix: str = "") -> None:
        """Log every unit in the history plus a cache-consistency check (debug helper).

        Args:
            prefix: Optional label; when non-empty it is printed, followed by a space,
                at the start of the header line.
        """
        history = self._unit_history
        cache_len = self.get_cache_length()
        preserve = self._system_preserve_length
        unit_sum = sum(unit["length"] for unit in history)
        expected = preserve + unit_sum
        label = prefix + " " if prefix else ""

        logger.info(
            "[SW] %s=== UNIT HISTORY DUMP === cache=%d, preserve=%d, units=%d, offset=%d",
            label,
            cache_len,
            preserve,
            len(history),
            self._position_offset,
        )
        verdict = "✓ MATCH" if expected == cache_len else "✗ MISMATCH!"
        logger.info(
            "[SW] Consistency: preserve(%d) + sum(units)(%d) = %d, actual=%d, %s",
            preserve,
            unit_sum,
            expected,
            cache_len,
            verdict,
        )
        for index, unit in enumerate(history):
            generated_count = len(unit.get("generated_tokens", []))
            logger.info(
                "[SW]   [%d] unit_id=%d type=%-6s len=%4d gen=%3d listen=%s",
                index,
                unit["unit_id"],
                unit["type"],
                unit["length"],
                generated_count,
                unit.get("is_listen", False),
            )

    def print_verification_summary(self) -> Dict[str, Any]:
        """Log a verification summary and return it (for comparing off/basic/context modes).

        Generated text and tokens are gathered only from units still in
        ``_unit_history`` whose ``is_listen`` flag is falsy.

        Returns:
            Dict with the mode, final cache/unit counts, sliding statistics, generated
            text and token count, previous-context info, position offset and preserve length.
        """
        cfg = self._window_config
        mode = cfg.sliding_window_mode

        speaking_units = [u for u in self._unit_history if not u.get("is_listen", False)]
        combined_text = "".join(
            text for text in (u.get("generated_text", "") for u in speaking_units) if text
        )
        generated_token_count = sum(
            len(toks) for toks in (u.get("generated_tokens", []) for u in speaking_units) if toks
        )

        summary = {
            "mode": mode,
            "final_cache_length": self.get_cache_length(),
            "final_unit_count": len(self._unit_history),
            "sliding_event_count": self._sliding_event_count,
            "total_dropped_tokens": self._total_dropped_tokens,
            "total_dropped_units": self._total_dropped_units,
            "total_generated_tokens": generated_token_count,
            "generated_text": combined_text,
            "previous_text": self._previous_text,
            "previous_token_count": len(self._previous_token_ids),
            "position_offset": self._position_offset,
            "system_preserve_length": self._system_preserve_length,
        }

        def _clip(text, limit):
            return text[:limit] + "..." if len(text) > limit else text

        logger.info("=" * 70)
        logger.info("[VERIFY] === SLIDING WINDOW VERIFICATION SUMMARY ===")
        logger.info("[VERIFY] Mode: %s", mode)
        logger.info("[VERIFY] Final cache length: %d", summary["final_cache_length"])
        logger.info("[VERIFY] Final unit count: %d", summary["final_unit_count"])
        logger.info("[VERIFY] Sliding events: %d", summary["sliding_event_count"])
        logger.info(
            "[VERIFY] Total dropped: %d tokens, %d units",
            summary["total_dropped_tokens"],
            summary["total_dropped_units"],
        )
        logger.info("[VERIFY] Total generated tokens: %d", summary["total_generated_tokens"])
        logger.info("[VERIFY] Generated text: '%s'", _clip(combined_text, 100))
        if mode == "context":
            logger.info(
                "[VERIFY] Previous content: %d tokens, '%s'",
                summary["previous_token_count"],
                _clip(self._previous_text, 50),
            )
        logger.info("[VERIFY] Position offset: %d", summary["position_offset"])
        logger.info("[VERIFY] System preserve length: %d", summary["system_preserve_length"])
        logger.info("=" * 70)

        return summary

    def set_window_config(self, config: DuplexWindowConfig) -> None:
        """Replace the sliding-window configuration and log its basic-mode high/low water marks."""
        self._window_config = config
        high_water = config.basic_window_high_tokens
        low_water = config.basic_window_low_tokens
        logger.info("[SW] Window config set: high_water=%d, low_water=%d", high_water, low_water)

    def set_window_enabled(self, enabled: bool) -> None:
        """Enable or disable sliding; a log line is written only when the flag actually changes."""
        previous = self._window_enabled
        self._window_enabled = enabled
        if previous == enabled:
            return
        logger.info("[SW] Window enabled: %s -> %s", previous, enabled)

    def get_context(self):
        """Return ``self.context`` as-is (no copy)."""
        current = self.context
        return current

    def embed_token(self, tid):
        """Embed one token with the model's input embedding layer.

        A plain ``int`` is wrapped into a 1-element tensor on ``self.m.device``;
        any other value (e.g. an id tensor) is passed to the embedding layer unchanged.
        """
        ids = torch.tensor([tid], device=self.m.device) if isinstance(tid, int) else tid
        return self.m.model.embed_tokens(ids)

    def embed_tokens(self, token_ids: List[int]) -> torch.Tensor:
        """Embed a list of token ids with the model's input embedding layer.

        Args:
            token_ids: Token ids to embed.

        Returns:
            Embeddings of shape ``[L, H]``. For an empty list, an empty
            ``[0, self.m.config.hidden_size]`` tensor on ``self.m.device`` is returned.
            It uses the embedding weight's dtype, so it matches what non-empty inputs
            produce instead of defaulting to float32.
        """
        embedding = self.m.model.embed_tokens
        if not token_ids:
            weight = getattr(embedding, "weight", None)
            # Fall back to torch's default dtype only if the layer exposes no weight.
            dtype = weight.dtype if weight is not None else None
            return torch.empty(0, self.m.config.hidden_size, device=self.m.device, dtype=dtype)
        ids = torch.tensor(token_ids, device=self.m.device)
        return embedding(ids)

    @torch.no_grad()
    def feed(self, embeds: torch.Tensor, return_logits: bool = False):
        """Run a chunk of embeddings through the model and append it to the KV cache.

        Position ids continue from the current cache length, and ``self.cache`` is
        replaced by the model's returned ``past_key_values``.

        Args:
            embeds: ``[L, H]`` embeddings fed to the model in one forward pass.
            return_logits: When True, also return next-token logits for the last position.

        Returns:
            ``(logits, last_hidden)`` when ``return_logits`` is True, otherwise None.
            ``logits`` is ``[1, vocab]`` (last position only) and ``last_hidden`` is
            the last entry of ``hidden_states`` for the whole chunk (``[1, L, H]``).
        """
        seq_len = embeds.size(0)
        past_len = self.get_cache_length()
        pos_ids = torch.arange(past_len, past_len + seq_len, device=embeds.device).unsqueeze(0)  # [1, L]

        out = self.m(
            inputs_embeds=embeds.unsqueeze(0),  # [1, L, H]
            position_ids=pos_ids,
            past_key_values=self.cache,
            return_dict=True,
            output_hidden_states=True,
        )
        self.cache = out.past_key_values

        if return_logits:
            last_hidden = out.hidden_states[-1]  # [1, L, H]
            # Only the final position is needed: project that single row instead of
            # materialising a [1, L, vocab] tensor and discarding all but one row.
            logits = self.m.lm_head(last_hidden[:, -1])  # [1, vocab]
            return logits, last_hidden
        return None

    @torch.no_grad()
    def decode(
        self,
        logits,
        mode: Literal["sampling", "greedy"] = "sampling",
        temperature=0.7,
        top_k=20,
        top_p=0.8,
        listen_top_k=None,
        listen_prob_scale=1.0,
        text_repetition_penalty=1.05,
        text_repetition_window_size=512,
        debug_print_top5=False,
    ):
        """
        Args:
            logits:
            mode: sampling or greedy
            temperature:
            top_k:
            top_p:
            listen_top_k: force listen_id to be in top-k to keep
            listen_prob_scale: multiply listen_id probability by a weight (<1 means decrease, >1 means increase)
            text_repetition_penalty: repetition penalty coefficient, >1.0 means decrease repetition, <1.0 means increase repetition
            text_repetition_window_size: repetition penalty window size
            debug_print_top5: whether to print debug information for top 5 tokens

        Sampling strategy:
            1. first sample all tokens with original logits (apply temperature)
            2. if sampled chunk_eos, return directly (keep the original model's decision of when to stop)
            3. if not sampled chunk_eos, mask it (set logit to -inf), continue sampling text tokens
            4. apply repetition penalty, top-k, top-p, etc. to the text tokens for the final sampling
        """

        logits = logits.clone()

        # ======== 0. 提前对 chunk_eos 进行独立采样判断 ========
        eos_id = self.chunk_eos_id

        with torch.no_grad():
            if mode == "greedy":
                sampled_token = torch.argmax(logits[0]).item()
            else:
                original_probs = F.softmax(logits[0], dim=-1)
                sampled_token = torch.multinomial(original_probs, num_samples=1).item()

            # 如果采到 chunk_eos,直接返回
            if sampled_token == eos_id:
                next_token_id = torch.tensor([eos_id], device=logits.device)
                next_token_str = self.tokenizer.decode(next_token_id)

                return next_token_id

        # 如果没有采到 chunk_eos,把它的 logit 设为 -inf,不让后续采样
        if self.forbidden_token_ids:
            logits[:, self.forbidden_token_ids] = float("-inf")

        # 打印施加 repetition penalty 之前的 topk logits
        if debug_print_top5:
            print("🔵" * 30)
            print("【BEFORE repetition penalty】施加重复惩罚之前的 Top-k logits")
            logits_before_penalty = logits[0] / temperature if mode == "sampling" else logits[0]
            topk_logits_before, topk_indices_before = torch.topk(
                logits_before_penalty, k=min(5, logits_before_penalty.size(-1))
            )

            for i, (token_id, logit_val) in enumerate(zip(topk_indices_before.tolist(), topk_logits_before.tolist())):
                token_str = self.tokenizer.decode([token_id])
                # 特殊处理一些token的显示
                if token_str == "\n":
                    display_str = "\\n"
                elif token_str == " ":
                    display_str = "[SPACE]"
                elif token_str == "":
                    display_str = "[EMPTY]"
                elif token_str == "\t":
                    display_str = "\\t"
                else:
                    display_str = token_str

                # 标记特殊token
                special_mark = ""
                if token_id == self.listen_id:
                    special_mark = " 🎧[LISTEN]"
                elif token_id == self.tokenizer.eos_token_id:
                    special_mark = " 🛑[EOS]"

                print(f"  {i + 1:2d}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): logit={logit_val:.4f}")
            print("🔵" * 30)

        # ======== 1. 应用重复惩罚 ========
        if text_repetition_penalty != 1.0 and len(self.generated_tokens) > 0:
            # 获取最近的 tokens(在窗口大小内)考虑特殊token和普通token
            recent_tokens = self.generated_tokens[-text_repetition_window_size:]

            # make it unique
            recent_tokens = list(set(recent_tokens))

            # 对重复的 tokens 应用惩罚
            for token_id in recent_tokens:
                if token_id < logits.size(-1):  # 确保 token_id 在词汇表范围内
                    if text_repetition_penalty > 1.0:
                        # 惩罚重复:降低 logits
                        logits[0, token_id] /= text_repetition_penalty
                    else:
                        # 鼓励重复:增加 logits
                        logits[0, token_id] *= 1.0 / text_repetition_penalty

        if listen_prob_scale != 1.0:  # 对 listen token 单独修改其 logit
            logits[0, self.listen_id] *= listen_prob_scale

        listen_rank = (logits[0] > logits[0, self.listen_id]).sum().item()

        # 打印 top 5 tokens(如果启用)
        if debug_print_top5:
            # 先打印 softmax 之前的 top-k logits
            logits_before_softmax = logits[0] / temperature if mode == "sampling" else logits[0]
            top5_logits_before, top5_indices_before = torch.topk(
                logits_before_softmax, k=min(5, logits_before_softmax.size(-1))
            )

            print("=" * 20)

            print("\n📊 Top 5 tokens BEFORE softmax (temperature={:.2f}, mode={}):".format(temperature, mode))
            for i, (token_id, logit_val) in enumerate(zip(top5_indices_before.tolist(), top5_logits_before.tolist())):
                token_str = self.tokenizer.decode([token_id])
                # 特殊处理一些token的显示
                if token_str == "\n":
                    display_str = "\\n"
                elif token_str == " ":
                    display_str = "[SPACE]"
                elif token_str == "":
                    display_str = "[EMPTY]"
                elif token_str == "\t":
                    display_str = "\\t"
                else:
                    display_str = token_str

                # 标记特殊token
                special_mark = ""
                if token_id == self.listen_id:
                    special_mark = " 🎧[LISTEN]"
                elif token_id == self.tokenizer.eos_token_id:
                    special_mark = " 🛑[EOS]"

                print(f"  {i + 1}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): logit={logit_val:.4f}")

            # 再打印 softmax 之后的 top-k probs
            probs = F.softmax(logits[0] / temperature if mode == "sampling" else logits[0], dim=-1)
            top5_probs, top5_indices = torch.topk(probs, k=min(5, probs.size(-1)))

            print("\n📊 Top 5 tokens AFTER softmax (temperature={:.2f}, mode={}):".format(temperature, mode))
            for i, (token_id, prob) in enumerate(zip(top5_indices.tolist(), top5_probs.tolist())):
                token_str = self.tokenizer.decode([token_id])
                # 特殊处理一些token的显示
                if token_str == "\n":
                    display_str = "\\n"
                elif token_str == " ":
                    display_str = "[SPACE]"
                elif token_str == "":
                    display_str = "[EMPTY]"
                elif token_str == "\t":
                    display_str = "\\t"
                else:
                    display_str = token_str

                # 标记特殊token
                special_mark = ""
                if token_id == self.listen_id:
                    special_mark = " 🎧[LISTEN]"
                elif token_id == self.tokenizer.eos_token_id:
                    special_mark = " 🛑[EOS]"

                print(
                    f"  {i + 1}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): {prob:.4f} ({prob * 100:.2f}%)"
                )
            # 如果 listen token 不在 top 5,也显示它的概率
            if self.listen_id not in top5_indices.tolist():
                listen_prob = probs[self.listen_id].item()
                print(f"  ... <|listen|> 🎧 rank={listen_rank + 1}, prob={listen_prob:.6f} ({listen_prob * 100:.4f}%)")

        if listen_top_k is not None and listen_rank < listen_top_k:  # listen_id 在 top-k 里,直接返回
            next_token_id = torch.tensor([self.listen_id], device=logits.device)
            next_token_str = self.tokenizer.decode(next_token_id)

            if next_token_str == "<|listen|>":
                self.context += " "
            else:
                self.context += next_token_str

            return next_token_id

        if mode == "greedy":
            next_token_id = torch.argmax(logits, dim=-1)
        elif mode == "sampling":
            logits = logits / temperature
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            raise ValueError("Unsupported decode mode")

        if next_token_id.item() not in self.special_token_ids:
            self.generated_tokens.append(next_token_id.item())
        else:
            self.generated_special_tokens.append(next_token_id.item())

        return next_token_id