🚀 Refined BitTransformerLM: Organized codebase with best practices
Browse files
- bit_transformer/bit_io.py +40 -4
bit_transformer/bit_io.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
import sys
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
try: # torch.compile may be unavailable or unsupported
|
| 6 |
if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
|
|
@@ -21,7 +24,14 @@ if TYPE_CHECKING: # pragma: no cover
|
|
| 21 |
|
| 22 |
@compile_fn
|
| 23 |
def bytes_to_bits(data: bytes) -> List[int]:
|
| 24 |
-
"""Convert bytes to bits with per-byte parity bit.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
result: List[int] = []
|
| 26 |
for b in data:
|
| 27 |
bits = [(b >> i) & 1 for i in reversed(range(8))]
|
|
@@ -32,7 +42,17 @@ def bytes_to_bits(data: bytes) -> List[int]:
|
|
| 32 |
|
| 33 |
@compile_fn
|
| 34 |
def bits_to_bytes(bits: List[int]) -> bytes:
|
| 35 |
-
"""Convert parity-protected bits back to bytes.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
if len(bits) % 9 != 0:
|
| 37 |
raise ValueError("Bit stream length must be multiple of 9")
|
| 38 |
out = bytearray()
|
|
@@ -50,10 +70,26 @@ def bits_to_bytes(bits: List[int]) -> bytes:
|
|
| 50 |
|
| 51 |
|
| 52 |
def text_to_bits(text: str) -> List[int]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
return bytes_to_bits(text.encode("utf-8"))
|
| 54 |
|
| 55 |
|
| 56 |
def bits_to_text(bits: List[int]) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
return bits_to_bytes(bits).decode("utf-8", errors="replace")
|
| 58 |
|
| 59 |
|
|
|
|
| 1 |
+
"""Text-to-bit conversion utilities with parity protection."""
|
| 2 |
+
|
| 3 |
import sys
|
| 4 |
+
from typing import TYPE_CHECKING, List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
|
| 8 |
try: # torch.compile may be unavailable or unsupported
|
| 9 |
if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
|
|
|
|
| 24 |
|
| 25 |
@compile_fn
|
| 26 |
def bytes_to_bits(data: bytes) -> List[int]:
|
| 27 |
+
"""Convert bytes to bits with per-byte parity bit.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
data: Input bytes to convert.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
List of bits with parity protection (9 bits per byte).
|
| 34 |
+
"""
|
| 35 |
result: List[int] = []
|
| 36 |
for b in data:
|
| 37 |
bits = [(b >> i) & 1 for i in reversed(range(8))]
|
|
|
|
| 42 |
|
| 43 |
@compile_fn
|
| 44 |
def bits_to_bytes(bits: List[int]) -> bytes:
|
| 45 |
+
"""Convert parity-protected bits back to bytes.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
bits: List of bits with parity protection (length must be a multiple of 9).
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Decoded bytes.
|
| 52 |
+
|
| 53 |
+
Raises:
|
| 54 |
+
ValueError: If the bit stream length is not a multiple of 9 or a parity check fails.
|
| 55 |
+
"""
|
| 56 |
if len(bits) % 9 != 0:
|
| 57 |
raise ValueError("Bit stream length must be multiple of 9")
|
| 58 |
out = bytearray()
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
def text_to_bits(text: str) -> List[int]:
|
| 73 |
+
"""Convert text to parity-protected bit sequence.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
text: Input text string.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
List of bits with parity protection.
|
| 80 |
+
"""
|
| 81 |
return bytes_to_bits(text.encode("utf-8"))
|
| 82 |
|
| 83 |
|
| 84 |
def bits_to_text(bits: List[int]) -> str:
|
| 85 |
+
"""Convert parity-protected bit sequence to text.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
bits: List of bits with parity protection.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Decoded text string (with error replacement for invalid UTF-8).
|
| 92 |
+
"""
|
| 93 |
return bits_to_bytes(bits).decode("utf-8", errors="replace")
|
| 94 |
|
| 95 |
|