import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Callable, Tuple, Optional
import logging
import argparse
import re
import gc
from collections import defaultdict, Counter
from m1_compression.compressor import (
    load_m1_model_and_tokenizer,
    ALPHABET_SIZE,
)
import multiprocessing as mp
from offline_utils import (
    compress_windows_starts_lens,
    decompress_windows_starts_lens,
    unpack_windows,
    InterleavedJsonlDataset,
    batched_m1_compress_predict_fn,
    find_next_batch_range,
)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

MAX_LINE_LEN = 512
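
# Pipeline overview: read JSONL samples, score bytes with the M1 entropy
# model, mark split points where the cross-entropy (or its jump) crosses a
# dynamic quantile threshold, encode the resulting windows via
# compress_windows_starts_lens, and stream records to a writer process.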

def print_windows(text: str,
                  starts: list[int],
                  lens: list[int],
                  sample_idx: Optional[int] = None,
                  ):
    from rich.console import Console
    from rich.text import Text
    import io
    PALETTE = (
        "#c6f6d5", "#bee3f8", "#fbb6ce",
        "#faf089", "#fed7e2", "#b2f5ea",
    )
    string_io = io.StringIO()
    console = Console(record=True, force_terminal=True, color_system="truecolor", file=string_io)

    t = Text()
    last_end = 0
    colour_idx = 0

    for s, l in sorted(zip(starts, lens)):
        t.append(text[last_end:s])

        if s == last_end:
            colour_idx = (colour_idx + 1) % len(PALETTE)

        t.append(text[s:s + l],
                 style=f"on {PALETTE[colour_idx]} bold black")
        last_end = s + l

    t.append(text[last_end:])
    console.print(t)

    # keep at most 100 SVG files on disk by cycling the save index
    save_idx = (sample_idx or 0) % 100
    return console.save_svg(f"window_visualize_{save_idx}.svg")
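
# Illustrative usage (a sketch):
#   print_windows("hello world", starts=[0, 6], lens=[5, 5], sample_idx=3)
# highlights bytes [0, 5) and [6, 11) and saves window_visualize_3.svg; the
# palette index rotates only when consecutive windows touch, keeping adjacent
# windows visually distinct.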

def collect_lines(batched_bytes_data: List[bytes], max_len: int = 2048) -> Tuple[List[bytes], Dict[int, Tuple[int, int]]]:
    batched_lines = []
    line_id_to_sample_offsets = {}
    line_idx = 0
    
    for sample_idx, data_bytes in enumerate(batched_bytes_data):
        if len(data_bytes) == 0:
            continue
        
        # Find all lines with their consecutive newlines attached (handles \r\n, \r, \n)
        lines_with_positions = []
        for match in re.finditer(b'[^\r\n]*(?:\r\n|\r|\n)*', data_bytes):
            if match.group():  # Skip empty matches
                lines_with_positions.append((match.group(), match.start()))
        
        for line, byte_offset in lines_with_positions:
            if len(line) > max_len:
                logger.info("Line too long with {} bytes, splitting into chunks...".format(len(line)))
                # Split long line into chunks of max_len
                for chunk_start in range(0, len(line), max_len):
                    chunk_end = min(chunk_start + max_len, len(line))
                    batched_lines.append(line[chunk_start:chunk_end])
                    # Calculate the absolute byte offset for this chunk
                    chunk_byte_offset = byte_offset + chunk_start
                    line_id_to_sample_offsets[line_idx] = (sample_idx, chunk_byte_offset)
                    line_idx += 1
            else:
                batched_lines.append(line)
                line_id_to_sample_offsets[line_idx] = (sample_idx, byte_offset)
                line_idx += 1

    return batched_lines, line_id_to_sample_offsets
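
# Worked example: collect_lines([b"ab\ncd\n"]) returns
#   ([b"ab\n", b"cd\n"], {0: (0, 0), 1: (0, 3)})
# -- each line keeps its trailing newline(s) and maps back to
# (sample_idx, byte_offset); lines longer than max_len are emitted as
# max_len-byte chunks with per-chunk offsets.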

def calculate_skew(entropy: torch.Tensor) -> torch.Tensor:
    mean = torch.mean(entropy)
    diffs = entropy - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    if std == 0.0:
        return torch.tensor(0.0)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    return skews
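
# Intuition check: a right-skewed entropy profile, e.g.
#   calculate_skew(torch.tensor([0.1, 0.1, 0.1, 3.0]))  # ~= 1.15 (> 0)
# lowers the adaptive quantiles in get_split_points, so a heavy right tail of
# surprising bytes makes splitting more aggressive.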

def get_split_points(
        probs: torch.Tensor, 
        next_bytes: torch.Tensor, 
        lengths: torch.Tensor, 
        base_global_quantile: float, 
        base_monotonic_quantile: float,
        debug: bool = False,
    ):
    B, L = probs.shape[0], probs.shape[1]
    arange_ids = torch.arange(L, device=probs.device).unsqueeze(0)
    pad_mask = arange_ids < lengths.unsqueeze(1)
    padded_cross_entropy = F.cross_entropy(
        probs.transpose(1, 2), 
        next_bytes, 
        reduction="none"
    )

    flattened_cross_entropy = padded_cross_entropy[pad_mask]
    assert flattened_cross_entropy.dim() == 1

    skew_flattened_cross_entropy = calculate_skew(flattened_cross_entropy.float())
    if skew_flattened_cross_entropy > 0.0:
        base_global_quantile = base_global_quantile - 0.04 * skew_flattened_cross_entropy.item()
        base_global_quantile = min(max(base_global_quantile, 0.0), 1.0)

    # flattened_cross_entropy has shape (sum of valid lengths,)
    threshold = torch.quantile(flattened_cross_entropy, base_global_quantile).clamp(0.1, 10.0)

    padded_cross_entropy_diff = torch.diff(padded_cross_entropy, dim=1)
    padded_cross_entropy_diff = torch.cat(
        [
            torch.zeros(B, 1, device=padded_cross_entropy_diff.device), 
            padded_cross_entropy_diff
        ], 
        dim=1
    )
    flattened_cross_entropy_diff = padded_cross_entropy_diff[pad_mask]

    skew_flattened_cross_entropy_diff = calculate_skew(flattened_cross_entropy_diff.float())
    if skew_flattened_cross_entropy_diff > 0.0:
        base_monotonic_quantile = base_monotonic_quantile - 0.04 * skew_flattened_cross_entropy_diff.item()
        base_monotonic_quantile = min(max(base_monotonic_quantile, 0.0), 1.0)

    diff_threshold = torch.quantile(flattened_cross_entropy_diff, base_monotonic_quantile).clamp(0.01, 10.0)
    split_points_mask = ((padded_cross_entropy > threshold) | (padded_cross_entropy_diff > diff_threshold)) & pad_mask

    if debug:
        logger.info(f"skew_flattened_cross_entropy: {skew_flattened_cross_entropy}")
        logger.info(f"skew_flattened_cross_entropy_diff: {skew_flattened_cross_entropy_diff}")
        logger.info(f"base_global_quantile: {base_global_quantile}")
        logger.info(f"base_monotonic_quantile: {base_monotonic_quantile}")
        logger.info(f"threshold: {threshold}")
        logger.info(f"diff_threshold: {diff_threshold}")
    return split_points_mask
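
# Decision rule in short: a non-padding position is a split point when its
# cross-entropy exceeds the (skew-adjusted) base_global_quantile threshold,
# or when the jump from the previous position's cross-entropy exceeds the
# (skew-adjusted) base_monotonic_quantile threshold over jumps.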

def get_batch_size_for_length(window_len, max_batch_size):
    """Determines the batch size for a given window length."""
    BATCH_SIZE_TIERS = {
        512: max_batch_size,
        1024: max(max_batch_size // 2, 1),
        2048: max(max_batch_size // 4, 1),
    }
    for max_len, batch_size in BATCH_SIZE_TIERS.items():
        if window_len <= max_len:
            return batch_size
    return 1
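
# Example: with max_batch_size=2048, windows up to 512 bytes run at batch
# size 2048, up to 1024 bytes at 1024, up to 2048 bytes at 512, longer at 1.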

def calculate_entropy_and_split_points_fn(
    batch: List[Dict[str, Any]],  # e.g. [{"text": "Hello world", "id": "1"}]
    predict_fn: Callable,
    chunk_size: int = 512,
    base_global_quantile: float = 0.9,   # quantiles are fractions in [0, 1]
    base_monotonic_quantile: float = 0.9,
    unigram_probs: Optional[torch.Tensor] = None,
    max_m1_batch_size: int = 2048,
    line_split: bool = False,
    debug: bool = False,
) -> List[Dict[str, Any]]:

    batched_bytes_data = [item["text"].encode('utf-8') for item in batch]
    # List bytes: [bytes, bytes,...]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if unigram_probs is None:
        # Fall back to a uniform first-byte prior so the expand() below is safe.
        unigram_probs = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device=device) / ALPHABET_SIZE
    else:
        unigram_probs = unigram_probs.to(device)

    all_split_point_masks = []

    # 1. Preprocess all samples into chunks, recording each chunk's origin.
    if line_split:
        chunks, chunk_to_sample_and_offset = collect_lines(batched_bytes_data, max_len=MAX_LINE_LEN)

        sorted_chunks = sorted(enumerate(chunks), key=lambda x: len(x[1]))
        sorted_idx, sorted_chunks = zip(*sorted_chunks)
        sorted_chunks = list(sorted_chunks)  # Convert tuple to list
        chunk_idx_map = {
            orig_idx: new_idx 
            for new_idx, orig_idx in enumerate(sorted_idx)
        }

        chunks_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in sorted_chunks]
        num_chunks = len(sorted_chunks)
        start_idx = 0
        while start_idx < num_chunks:
            # Find the exact [start, end) range for the next length-safe batch.
            start_idx, end_idx = find_next_batch_range(chunks_np, start_idx, max_m1_batch_size, get_batch_size_for_length)

            batch_chunks_np = chunks_np[start_idx:end_idx]
            
            effective_batch_size = end_idx - start_idx
            
            lengths_pt = torch.tensor([len(chunk) for chunk in batch_chunks_np], dtype=torch.long, device=device)
            batch_chunks_pt = torch.zeros(
                (effective_batch_size, max(lengths_pt)), 
                dtype=torch.long, 
                device=device
            )
            for i, chunk_np in enumerate(batch_chunks_np):
                batch_chunks_pt[i, :len(chunk_np)] = torch.tensor(chunk_np, dtype=torch.long, device=device)

            cur_batch = batch_chunks_pt[:effective_batch_size]
            cur_lengths = lengths_pt[:effective_batch_size]
            with torch.no_grad():
                probs = predict_fn(cur_batch)
            
            # Prepend the first-byte prior and shift the model outputs by one
            # position so final_probs[:, t] predicts byte t from bytes < t.
            first_prob = unigram_probs.expand(effective_batch_size, 1, -1)
            final_probs = torch.cat([first_prob, probs[:, :-1, :]], dim=1)
            start_idx = end_idx  # advance past the batch we just consumed
            
            # calculate (cross) entropy,
            # calculate dynamic threshold,
            # calculate split points
            split_points_mask = get_split_points(
                final_probs,
                cur_batch,
                cur_lengths,
                base_global_quantile,
                base_monotonic_quantile,
                debug,
            )
            all_split_point_masks.append(split_points_mask)
        
        split_point_chunk_idx_lst = []
        split_point_position_idx_lst = []
        processed_chunks = 0
        for mask in all_split_point_masks:
            split_point_chunk_idx, split_point_position_idx = mask.cpu().nonzero(as_tuple=True)
            split_point_chunk_idx_lst.append(split_point_chunk_idx + processed_chunks)
            split_point_position_idx_lst.append(split_point_position_idx)
            processed_chunks = processed_chunks + mask.shape[0]
        split_point_chunk_idx = torch.cat(split_point_chunk_idx_lst)
        split_point_position_idx = torch.cat(split_point_position_idx_lst)
    else:
        chunk_idx_map = None
        # 1. preprocess all samples, record each chunk
        chunks = []
        chunk_to_sample_and_offset = {}
        chunk_idx = 0
        for sample_idx, data_bytes in enumerate(batched_bytes_data):
            logger.debug(f"Processing sample {sample_idx+1} (bytes: {len(data_bytes)})")
            
            if len(data_bytes) == 0:
                continue
                
            byte_len = len(data_bytes)

            for i in range(0, byte_len, chunk_size):
                chunk_start = i
                chunk_end = min(i + chunk_size, byte_len)
                chunk = data_bytes[chunk_start:chunk_end]
                chunks.append(chunk)
                chunk_to_sample_and_offset[chunk_idx] = (sample_idx, chunk_start)
                # chunk_idx -> (sample_idx, chunk_start): which sample this chunk came from and its starting byte offset there
                chunk_idx += 1

        # 2. batched segmentations
        all_split_point_masks = []

        batch_chunks_pt = torch.zeros(
            (max_m1_batch_size, chunk_size), 
            dtype=torch.long, 
            device=device
        )
        lengths_pt = torch.zeros(max_m1_batch_size, dtype=torch.long, device=device)
        num_chunks = len(chunks)
        # batched get all segmentations
        for start_idx in range(0, num_chunks, max_m1_batch_size):
            end_idx = min(start_idx + max_m1_batch_size, num_chunks)
            batch_chunks = chunks[start_idx:end_idx]
            batch_chunks_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in batch_chunks]

            effective_batch_size = end_idx - start_idx
            # Pad into the preallocated, reused buffer. Positions beyond each
            # chunk's length may hold stale bytes from a previous batch; they
            # are masked out via lengths_pt downstream (and, assuming a causal
            # entropy model, do not influence predictions in the valid region).
            for i, chunk_np in enumerate(batch_chunks_np):
                batch_chunks_pt[i, :len(chunk_np)] = torch.tensor(chunk_np, dtype=torch.long, device=device)
                lengths_pt[i] = len(chunk_np)

            cur_batch = batch_chunks_pt[:effective_batch_size]
            cur_lengths = lengths_pt[:effective_batch_size]
            with torch.no_grad():
                probs = predict_fn(cur_batch)
            
            # Prepend the first-byte prior and shift the model outputs by one
            # position so final_probs[:, t] predicts byte t from bytes < t.
            first_prob = unigram_probs.expand(effective_batch_size, 1, -1)
            final_probs = torch.cat([first_prob, probs[:, :-1, :]], dim=1)
            
            # calculate (cross) entropy,
            # calculate dynamic threshold,
            # calculate split points
            split_points_mask = get_split_points(
                final_probs,
                cur_batch,
                cur_lengths,
                base_global_quantile,
                base_monotonic_quantile,
                debug,
            )
            all_split_point_masks.append(split_points_mask)

        all_split_point_masks = torch.cat(all_split_point_masks, dim=0)
        # (num_chunks, chunk_size): split-point flags gathered across all chunks.
        all_split_points_tuple = all_split_point_masks.nonzero(as_tuple=True)
        # Two 1-D tensors giving, for every split point, its chunk index and
        # position within that chunk, e.g. (tensor([0, 0, 2]), tensor([15, 89, 412]))
        # means chunk 0 splits at positions 15 and 89, and chunk 2 at position 412.
        split_point_chunk_idx, split_point_position_idx = all_split_points_tuple[0].cpu(), all_split_points_tuple[1].cpu()

    sample_idx_to_split_positions = defaultdict(list)

    # Avoid re-scanning the index tensors once per chunk: bucket every split
    # point into its chunk's list first, then walk the chunks once.

    # 1. Move the index tensors to numpy.
    split_point_chunk_idx_np = split_point_chunk_idx.numpy()
    split_point_position_idx_np = split_point_position_idx.numpy()
    chunk_to_splits = defaultdict(list)

    # 2. One pass over the split points: O(num_split_points) rather than
    #    O(num_chunks * num_split_points) for per-chunk filtering.
    for i in range(len(split_point_chunk_idx_np)):
        chunk_idx = split_point_chunk_idx_np[i]
        position = split_point_position_idx_np[i]
        chunk_to_splits[chunk_idx].append(position)

    # 3. process all the chunk_to_split
    for chunk_idx in range(num_chunks):
        # chunk_idx: original index of chunk
        chunk = chunks[chunk_idx]
        sample_idx, chunk_start = chunk_to_sample_and_offset[chunk_idx]
        if line_split:
            sorted_chunk_idx = chunk_idx_map[chunk_idx]
            split_points = chunk_to_splits[sorted_chunk_idx]
        else:
            split_points = chunk_to_splits[chunk_idx]

        if len(split_points) == 0:
            split_points = [0]
        if split_points[0] != 0:
            split_points.insert(0, 0)
        split_points.append(len(chunk))
        
        offset_split_points = [s + chunk_start for s in split_points]
        sample_idx_to_split_positions[sample_idx].extend(offset_split_points)

    # Note: an earlier version filtered split_point_position_idx with a boolean
    # mask inside this loop, costing O(num_chunks * num_split_points); the dict
    # bucketing above replaces it. nonzero() also forces a GPU-CPU sync, which
    # is why it is deferred until all batches have finished.

    sample_idx_to_split_positions = {k: sorted(v) for k, v in sample_idx_to_split_positions.items()}

    write_results = []
    min_window_size = 3

    if debug:
        extreme_compression_results = []
    for sample_idx, item in enumerate(batch):
        # Empty-text samples produce no chunks and hence no entry in the dict.
        split_points = sample_idx_to_split_positions.get(sample_idx, [])
        split_windows_starts = []
        split_windows_lens = []
        cur_l = 0
        cur_r = 0
        for i in range(len(split_points) - 1):
            cur_l = split_points[i]
            cur_r = split_points[i+1]
            if cur_r - cur_l >= min_window_size:
                split_windows_starts.append(cur_l)
                split_windows_lens.append(cur_r - cur_l)
        # assert cur_r == len(item["text"].encode('utf-8')), f"last cur_r: {cur_r} != len(item['text']): {len(item['text'])}"
        compressed_windows_starts_lens_b64 = compress_windows_starts_lens(split_windows_starts, split_windows_lens)
        result = {
            **item,
            "windows_starts_lens_b64": compressed_windows_starts_lens_b64
        }
        if debug:
            print_windows(item["text"], split_windows_starts, split_windows_lens, sample_idx=sample_idx)
            _debug_starts_lens = decompress_windows_starts_lens(compressed_windows_starts_lens_b64)
            _debug_starts, _debug_lens = _debug_starts_lens
            assert len(_debug_starts) == len(_debug_lens), f"Window starts and lens have different lengths: {len(_debug_starts)} != {len(_debug_lens)}"
            assert _debug_starts == split_windows_starts, f"Window starts do not match: {_debug_starts} != {split_windows_starts}"
            assert _debug_lens == split_windows_lens, f"Window lens do not match: {_debug_lens} != {split_windows_lens}"

            # Lower-bound ("extreme") compression rate: assume each compressed window costs exactly 1 byte.
            debug_sample = item["text"].encode('utf-8')
            raw_bytes = len(debug_sample) - sum(_debug_lens)
            compressed_bytes = len(_debug_starts)
            extreme_compression_rate = (compressed_bytes + raw_bytes) / len(debug_sample)
            extreme_compression_results.append(extreme_compression_rate)
            logger.info(f"[Extreme compression rate] for sample idx {sample_idx}: {extreme_compression_rate:.4f}")
            debug_byte_windows = unpack_windows(debug_sample, compressed_windows_starts_lens_b64)
            debug_bytes_windows, debug_indicators = zip(*debug_byte_windows)
            assert b"".join(debug_bytes_windows) == debug_sample, f"Debug bytes windows do not match: {b''.join(debug_bytes_windows)} != {debug_sample}"

            debug_split_points = sample_idx_to_split_positions[sample_idx]
            logger.info(f"Original byte length: {len(debug_sample)}")
            logger.info(f"num split_points: {len(debug_split_points)}")
            
            _debug_compressed_windows = [x[0] for x in debug_byte_windows if x[1]]
            _debug_sorted_compressed_windows = sorted(_debug_compressed_windows, key=lambda x: len(x), reverse=True)
            _debug_raw_windows = [x[0] for x in debug_byte_windows if not x[1]]
            _debug_sorted_raw_windows = sorted(_debug_raw_windows, key=lambda x: len(x), reverse=True)
            for i, byte_window in enumerate(_debug_sorted_compressed_windows):
                logger.info(f"compressed byte_window[{i}]: {byte_window}")
                if i > 10:
                    break
            for i, byte_window in enumerate(_debug_sorted_raw_windows):
                logger.info(f"raw byte_window[{i}]: {byte_window}")
                if i > 10:
                    break
        write_results.append(result)
    if debug:
        logger.info(f"[Extreme compression rate] for all samples: {np.mean(extreme_compression_results):.4f}")
    return write_results
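
# Each returned record is the input JSONL object plus a
# "windows_starts_lens_b64" field; decompress_windows_starts_lens() recovers
# the (starts, lens) lists, as the debug assertions above exercise.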

def writer_consumer(write_queue, output_file, buffer_size=100):
    """
    Writer consumer: reads compressed results from write_queue and writes to file.
    Maintains its own buffer and writes when buffer is full or receives sentinel.
    """
    write_buf = []
    
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            while True:
                item = write_queue.get()
                if item is None:
                    break
                
                write_buf.extend(item)
                
                # Write buffer when it's full
                if len(write_buf) >= buffer_size:
                    logger.info(f"Writer: Dumping buffer of {len(write_buf)} items to {output_file}")
                    for buffered_item in write_buf:
                        f.write(json.dumps(buffered_item) + '\n')
                    f.flush()
                    write_buf = []

            # Write remaining items in buffer
            if write_buf:
                logger.info(f"Writer: Dumping remaining {len(write_buf)} items to {output_file}")
                for buffered_item in write_buf:
                    f.write(json.dumps(buffered_item) + '\n')
                f.flush()
                    
    except Exception as e:
        logger.error(f"Writer process error: {e}")
        raise
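
# Running the writer in its own process keeps JSON serialisation and disk
# flushes off the GPU loop; the bounded queue created in main() (maxsize=200)
# provides back-pressure if writing falls behind.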



def main():
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach')
    parser.add_argument('--input_file', type=str, required=True,
                      help='Path to the input JSONL file')
    parser.add_argument('--output_dir', type=str, required=True,
                      help='Directory to write compressed results')
    parser.add_argument('--entropy_model_path', type=str, required=True,
                      help='Path to the M1 entropy model checkpoint')
    parser.add_argument('--compression_model_path', type=str, required=True,
                      help='Path to the M1 compression model checkpoint')
    parser.add_argument('--data_batch_size', type=int, default=512,
                      help='Size of batches for processing (default: 512)')
    parser.add_argument('--output_window_size', type=int, default=16,
                      help='Size of window for compression (default: 16)')
    parser.add_argument('--max_window_size', type=int, default=1024,
                      help='Maximum window size for reading from each file (default: 1024)')
    parser.add_argument('--max_entropy_batch_size', type=int, default=4096,
                      help='Max batch size for the entropy model (default: 4096)')
    parser.add_argument('--max_compression_batch_size', type=int, default=4096,
                      help='Max batch size for the compression model (default: 4096)')
    parser.add_argument('--chunk_size', type=int, default=512,
                      help='Size of chunk for compression (default: 512)')
    parser.add_argument('--base_global_quantile', type=float, default=0.9,
                      help='Base global quantile for compression (default: 0.9)')
    parser.add_argument('--base_monotonic_quantile', type=float, default=0.9,
                      help='Base monotonic quantile for compression (default: 0.9)')
    parser.add_argument('--apply_line_split', action='store_true', default=False,
                      help='Split input on line boundaries before entropy scoring (default: False)')
    parser.add_argument('--debug', action='store_true', default=False,
                      help='Debug mode (default: False)')
    parser.add_argument('--firstbyte_prob_path', type=str, default=None,
                      help='Path to a JSON file with the probability distribution of the first byte of each window (default: None)')
    parser.add_argument('--num_workers', type=int, default=1,
                      help='Number of workers for CPU jobs (default: 1)')
    parser.add_argument('--process_id', type=int, default=0,
                      help='Process ID for distributed processing (default: 0)')
    parser.add_argument('--num_processes', type=int, default=1,
                      help='Number of processes for distributed processing (default: 1)')
    args = parser.parse_args()

    mp.set_start_method('spawn', force=True)
    gc_freq = 100
    dump_freq = 25

    # Create output directory if it doesn't exist
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model and tokenizer
    model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path)
    batched_predict_fn = batched_m1_compress_predict_fn(model)

    if args.firstbyte_prob_path is not None:
        with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f:
            first_byte_prob = json.load(f)
        logger.debug(f"Loaded first-byte probabilities: {first_byte_prob}")
        first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
    else:
        first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE

    # Create dataset and dataloader
    dataset = InterleavedJsonlDataset(
        file_path=args.input_file,
        rank=args.process_id,
        world_size=args.num_processes,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.data_batch_size,
        shuffle=False,
        collate_fn=lambda x: x
    )

    input_file = Path(args.input_file)
    logger.info(f"Processing file: {input_file}")
    
    output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl"
    
    logger.info("Data loaded. Start processing...")

    # Create queue and start writer process
    write_queue = mp.Queue(maxsize=200)
    writer_process = mp.Process(
        target=writer_consumer, 
        args=(write_queue, output_file, dump_freq)
    )
    writer_process.start()
    
    try:
        # Process each batch
        for batch_idx, batch in enumerate(dataloader):
            split_points_results = calculate_entropy_and_split_points_fn(
                batch,
                batched_predict_fn,
                chunk_size=args.chunk_size,
                base_global_quantile=args.base_global_quantile,
                base_monotonic_quantile=args.base_monotonic_quantile,
                unigram_probs=first_byte_prob,
                max_m1_batch_size=args.max_entropy_batch_size,
                line_split=args.apply_line_split,
                debug=args.debug,
            )
            logger.info(f"Processed batch {batch_idx}")
            write_queue.put(split_points_results)

            if batch_idx % gc_freq == 0:
                # Clean up GPU memory
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        
        # Signal completion to writer process
        write_queue.put(None)
        
    except Exception as e:
        logger.error(f"Error during processing: {e}")
        # Try to terminate writer process cleanly
        try:
            write_queue.put(None)
        except Exception:
            pass
        raise
    finally:
        # Wait for writer process to finish
        writer_process.join()
        if writer_process.exitcode != 0:
            logger.error(f"Writer process failed with exit code: {writer_process.exitcode}")

    logger.info(f"Completed processing successfully, output written to {output_file}")

if __name__ == "__main__":
    main()
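
# Example invocation (illustrative paths; the script filename is assumed):
#   python split_points.py \
#       --input_file data/corpus.jsonl \
#       --output_dir out/ \
#       --entropy_model_path ckpts/m1_entropy.pt \
#       --compression_model_path ckpts/m1_compress.pt \
#       --base_global_quantile 0.9 --apply_line_split --debug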