robtacconelli committed on
Commit
8f1bdaa
·
verified ·
1 Parent(s): 5b8133e

Upload 11 files

Browse files
Files changed (1) hide show
  1. compressor.py +199 -0
compressor.py CHANGED
@@ -17,6 +17,8 @@ by ParallelNeuralCompressor (NC05/NC06 formats).
17
  """
18
 
19
  import gc
 
 
20
  import struct
21
  import sys
22
 
@@ -30,6 +32,14 @@ from lzp_model import LZPModel
30
  from context_mixer import ContextMixer
31
  from adaptive_head import AdaptiveHead
32
 
 
 
 
 
 
 
 
 
33
  # ---- CDF precision ----
34
 
35
  # Enhanced CDF: 2^24 instead of the original 2^16.
@@ -606,3 +616,192 @@ class NeuralCompressor:
606
  num_models=num_mix, lr=DEFAULT_MIXER_LR,
607
  ) if num_mix > 1 else None
608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  """
18
 
19
  import gc
20
+ import gzip
21
+ import lzma
22
  import struct
23
  import sys
24
 
 
32
  from context_mixer import ContextMixer
33
  from adaptive_head import AdaptiveHead
34
 
35
+ # ---- File format constants (NC05 text / NC06 hybrid binary) ----
36
+
37
+ MAGIC = b'NC05' # single-worker text format
38
+ MAGIC_BIN = b'NC06' # single-worker hybrid binary format
39
+ # Minimum bytes needed to identify a valid header: the NC05 header is 9 bytes (4-byte magic + '>BHH'); the NC06 fixed header is longer (15 bytes: 4-byte magic + '>BHII')
40
+ HEADER_SIZE = 9
41
+ NC06_VERSION = 1 # version field written into the NC06 header by compress_bytes
42
+
43
  # ---- CDF precision ----
44
 
45
  # Enhanced CDF: 2^24 instead of the original 2^16.
 
616
  num_models=num_mix, lr=DEFAULT_MIXER_LR,
617
  ) if num_mix > 1 else None
618
 
619
+ # ------------------------------------------------------------------
620
+ # Public compress / decompress (NC05 text, NC06 hybrid binary)
621
+ # ------------------------------------------------------------------
622
+
623
def compress(self, text: str) -> bytes:
    """Encode *text* as a single-chunk NC05 payload.

    Layout: ``MAGIC`` followed by ``>BHH`` (config flags, temperature
    scaled by 10000, chunk count).  Non-empty input adds one ``>III``
    entry (token count, compressed bit count, stream length) and then the
    stream bytes.  Empty input yields a header whose chunk count is zero
    and does not touch the model state.
    """
    config_flags = self._config_flags()
    scaled_temp = int(round(self.temperature * 10000))

    def build_header(chunk_count):
        # The 9-byte NC05 header; only the chunk count varies.
        return MAGIC + struct.pack(
            '>BHH', config_flags, scaled_temp, chunk_count,
        )

    if not text:
        return build_header(0)

    # The decoder performs the same resets before decoding each stream.
    self.model.reset_cache()
    self._reset_secondary_models()
    n_tokens, n_bits, payload = self._compress_text_to_stream(text)

    preamble = build_header(1)
    chunk_entry = struct.pack('>III', n_tokens, n_bits, len(payload))
    return preamble + chunk_entry + payload
638
+
639
def compress_bytes(self, data: bytes) -> bytes:
    """Compress raw bytes into the hybrid chunked NC06 container.

    The input is split by ``_segment_chunks`` into ``(type, offset,
    length)`` chunks.  The output is laid out as:

    1. ``MAGIC_BIN`` + ``>BHII`` (config flags, temperature scaled by
       10000, ``NC06_VERSION``, number of entries);
    2. one ``>BI`` (chunk type, length) record per chunk;
    3. if any binary bytes exist, a single blob holding all binary chunks
       concatenated: ``>BI`` (method, stored length) + payload, where the
       payload is LZMA, gzip, or raw, whichever applies;
    4. one section per text chunk: ``>H`` sub-chunk count (always 1 here),
       a ``>III`` record (token count, bit count, stream length) and the
       stream itself.

    When there are no chunks, only the file header is returned.
    """
    chunks = _segment_chunks(data)
    num_entries = len(chunks)

    flags = self._config_flags()
    temp_encoded = int(round(self.temperature * 10000))
    file_header = MAGIC_BIN + struct.pack(
        '>BHII', flags, temp_encoded, NC06_VERSION, num_entries,
    )
    if num_entries == 0:
        return file_header

    entry_table = []
    binary_parts = []
    text_spans = []
    total_binary = 0
    for chunk_type, offset, length in chunks:
        entry_table.append(struct.pack('>BI', chunk_type, length))
        if chunk_type == CHUNK_TYPE_BINARY:
            binary_parts.append(data[offset:offset + length])
            total_binary += length
        else:
            text_spans.append((offset, length))

    binary_section = b''
    if total_binary > 0:
        binary_blob = b''.join(binary_parts)
        if total_binary >= LZMA_THRESHOLD:
            method = BLOB_LZMA
            compressed = lzma.compress(binary_blob)
        else:
            method = BLOB_GZIP
            # mtime=0 keeps the wall-clock timestamp out of the gzip
            # header, so identical input always yields identical output.
            # Decoding is unaffected by the stored timestamp.
            compressed = gzip.compress(
                binary_blob, compresslevel=9, mtime=0,
            )
        # Store the blob raw when compression does not shrink it.
        if len(compressed) >= total_binary:
            method = BLOB_RAW
            compressed = binary_blob
        binary_section = (
            struct.pack('>BI', method, len(compressed)) + compressed
        )

    # NC06 text entry: n_sub_chunks(2) + sub-chunk table + streams.
    # Single worker: always 1 sub-chunk per text entry.
    text_sections = []
    for offset, length in text_spans:
        # latin-1 maps every byte to exactly one code point, so the text
        # path round-trips arbitrary bytes.
        text = data[offset:offset + length].decode('latin-1')
        self.model.reset_cache()
        self._reset_secondary_models()
        token_count, bit_count, stream = self._compress_text_to_stream(text)
        sub_entry = struct.pack('>III', token_count, bit_count, len(stream))
        text_sections.append(struct.pack('>H', 1) + sub_entry + stream)

    return (file_header
            + b''.join(entry_table)
            + binary_section
            + b''.join(text_sections))
697
+
698
def decompress(self, data: bytes) -> 'str | bytes':
    """Decode an NC05 or NC06 payload, selected by its 4-byte magic.

    Payloads starting with ``MAGIC`` are handed to ``_decompress_nc05``
    and payloads starting with ``MAGIC_BIN`` to ``_decompress_nc06``; the
    result of that call is returned unchanged.

    Raises:
        ValueError: if *data* is shorter than ``HEADER_SIZE`` bytes or
            does not begin with a recognised magic.
    """
    if len(data) < HEADER_SIZE:
        raise ValueError("Data too short to contain a valid header")

    signature = data[:4]
    if signature == MAGIC:
        return self._decompress_nc05(data)
    if signature == MAGIC_BIN:
        return self._decompress_nc06(data)

    raise ValueError(
        f"Invalid magic bytes: {signature!r} "
        f"(expected {MAGIC!r} or {MAGIC_BIN!r})"
    )
712
+
713
def _decompress_nc05(self, data: bytes) -> str:
    """Decode an NC05 (text) payload into a string.

    Expects the 9-byte header (magic, flags byte, ``>HH`` temperature and
    chunk count), then one ``>III`` record per chunk (token count,
    compressed bit count, stream length), then the chunk streams in the
    same order.  A chunk count of zero returns ``""`` without changing
    any instance state.

    The whole layout is bounds-checked before the instance is modified,
    so a truncated payload raises without altering the flags or the
    temperature.

    Raises:
        ValueError: if the chunk table or any chunk stream extends past
            the end of *data*.
    """
    flags = data[4]
    temp_encoded, n_chunks = struct.unpack('>HH', data[5:9])

    if n_chunks == 0:
        return ""

    entry_fmt = '>III'
    entry_size = struct.calcsize(entry_fmt)
    table_start = 9
    table_end = table_start + n_chunks * entry_size
    if table_end > len(data):
        raise ValueError("Truncated NC05 data: incomplete chunk table")
    entries = [
        struct.unpack_from(entry_fmt, data, off)
        for off in range(table_start, table_end, entry_size)
    ]

    # A short slice would otherwise be handed to the decoder silently.
    streams_end = table_end + sum(entry[2] for entry in entries)
    if streams_end > len(data):
        raise ValueError("Truncated NC05 data: incomplete chunk stream")

    self._apply_flags(flags)
    self.temperature = temp_encoded / 10000.0

    pos = table_end
    texts = []
    for num_tokens, _comp_bits, stream_len in entries:
        stream = data[pos:pos + stream_len]
        pos += stream_len
        # Each chunk was encoded from a freshly reset model state.
        self.model.reset_cache()
        self._reset_secondary_models()
        texts.append(self._decompress_text_stream(stream, num_tokens))

    return ''.join(texts)
742
+
743
def _decompress_nc06(self, data: bytes) -> bytes:
    """Decode an NC06 (hybrid binary) payload into bytes.

    Expects the 15-byte header (magic, flags byte, ``>HII`` temperature,
    version and entry count), then one ``>BI`` (type, length) record per
    entry.  If any binary bytes are declared, a single blob follows:
    ``>BI`` (method, stored length) + payload.  After that comes one
    section per text entry: ``>H`` sub-chunk count, that many ``>III``
    records (token count, bit count, stream length), and the streams.
    Output parts are emitted in entry order; text is re-encoded as
    latin-1.

    The flags and temperature from the header are applied to the
    instance as soon as the header has been read.  The version field is
    read but not validated.

    Raises:
        ValueError: if any section extends past the end of *data*, or if
            the binary blob uses an unrecognised storage method.
    """
    size = len(data)

    def _require(end: int, what: str) -> None:
        # Fail loudly rather than let a short slice reach struct or the
        # stream decoder.
        if end > size:
            raise ValueError(f"Truncated NC06 data: incomplete {what}")

    header_fmt = '>HII'
    header_end = 5 + struct.calcsize(header_fmt)
    # decompress() only guarantees HEADER_SIZE bytes; NC06 needs more.
    _require(header_end, "header")

    flags = data[4]
    temp_encoded, _version, num_entries = struct.unpack_from(
        header_fmt, data, 5,
    )

    self._apply_flags(flags)
    self.temperature = temp_encoded / 10000.0

    if num_entries == 0:
        return b""

    entry_fmt = '>BI'
    entry_size = struct.calcsize(entry_fmt)
    table_end = header_end + num_entries * entry_size
    _require(table_end, "entry table")
    entries = [
        struct.unpack_from(entry_fmt, data, off)
        for off in range(header_end, table_end, entry_size)
    ]
    total_binary = sum(
        elen for etype, elen in entries if etype == CHUNK_TYPE_BINARY
    )
    pos = table_end

    binary_data = b''
    if total_binary > 0:
        blob_header_end = pos + entry_size
        _require(blob_header_end, "binary section header")
        method, comp_len = struct.unpack_from(entry_fmt, data, pos)
        pos = blob_header_end
        blob_end = pos + comp_len
        _require(blob_end, "binary section")
        compressed = data[pos:blob_end]
        pos = blob_end
        if method == BLOB_RAW:
            binary_data = compressed
        elif method == BLOB_GZIP:
            binary_data = gzip.decompress(compressed)
        elif method == BLOB_LZMA:
            binary_data = lzma.decompress(compressed)
        else:
            # An unknown method used to leave binary_data empty and
            # produce corrupt output silently.
            raise ValueError(f"Unknown NC06 binary blob method: {method}")

    sub_fmt = '>III'
    sub_size = struct.calcsize(sub_fmt)
    binary_offset = 0
    output_parts = []
    for etype, elen in entries:
        if etype == CHUNK_TYPE_BINARY:
            # Binary chunks are consecutive slices of the shared blob.
            output_parts.append(
                binary_data[binary_offset:binary_offset + elen]
            )
            binary_offset += elen
            continue

        _require(pos + 2, "text entry header")
        (n_sub,) = struct.unpack_from('>H', data, pos)
        pos += 2
        sub_table_end = pos + n_sub * sub_size
        _require(sub_table_end, "sub-chunk table")
        sub_entries = [
            struct.unpack_from(sub_fmt, data, off)
            for off in range(pos, sub_table_end, sub_size)
        ]
        pos = sub_table_end

        texts = []
        for num_tokens, _comp_bits, stream_len in sub_entries:
            stream_end = pos + stream_len
            _require(stream_end, "text stream")
            stream = data[pos:stream_end]
            pos = stream_end
            # Each sub-chunk was encoded from a freshly reset model state.
            self.model.reset_cache()
            self._reset_secondary_models()
            texts.append(
                self._decompress_text_stream(stream, num_tokens)
            )
        output_parts.append(''.join(texts).encode('latin-1'))

    return b''.join(output_parts)
807
+