WCNegentropy commited on
Commit
00c7a97
·
verified ·
1 Parent(s): b08919a

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. bit_transformer/bit_io.py +40 -4
bit_transformer/bit_io.py CHANGED
@@ -1,6 +1,9 @@
1
- from typing import List, TYPE_CHECKING
2
- import torch
3
  import sys
 
 
 
4
 
5
  try: # torch.compile may be unavailable or unsupported
6
  if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
@@ -21,7 +24,14 @@ if TYPE_CHECKING: # pragma: no cover
21
 
22
  @compile_fn
23
  def bytes_to_bits(data: bytes) -> List[int]:
24
- """Convert bytes to bits with per-byte parity bit."""
 
 
 
 
 
 
 
25
  result: List[int] = []
26
  for b in data:
27
  bits = [(b >> i) & 1 for i in reversed(range(8))]
@@ -32,7 +42,17 @@ def bytes_to_bits(data: bytes) -> List[int]:
32
 
33
  @compile_fn
34
  def bits_to_bytes(bits: List[int]) -> bytes:
35
- """Convert parity-protected bits back to bytes."""
 
 
 
 
 
 
 
 
 
 
36
  if len(bits) % 9 != 0:
37
  raise ValueError("Bit stream length must be multiple of 9")
38
  out = bytearray()
@@ -50,10 +70,26 @@ def bits_to_bytes(bits: List[int]) -> bytes:
50
 
51
 
52
  def text_to_bits(text: str) -> List[int]:
 
 
 
 
 
 
 
 
53
  return bytes_to_bits(text.encode("utf-8"))
54
 
55
 
56
  def bits_to_text(bits: List[int]) -> str:
 
 
 
 
 
 
 
 
57
  return bits_to_bytes(bits).decode("utf-8", errors="replace")
58
 
59
 
 
1
+ """Text-to-bit conversion utilities with parity protection."""
2
+
3
  import sys
4
+ from typing import TYPE_CHECKING, List
5
+
6
+ import torch
7
 
8
  try: # torch.compile may be unavailable or unsupported
9
  if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
 
24
 
25
  @compile_fn
26
  def bytes_to_bits(data: bytes) -> List[int]:
27
+ """Convert bytes to bits with per-byte parity bit.
28
+
29
+ Args:
30
+ data: Input bytes to convert.
31
+
32
+ Returns:
33
+ List of bits with parity protection (9 bits per byte).
34
+ """
35
  result: List[int] = []
36
  for b in data:
37
  bits = [(b >> i) & 1 for i in reversed(range(8))]
 
42
 
43
  @compile_fn
44
  def bits_to_bytes(bits: List[int]) -> bytes:
45
+ """Convert parity-protected bits back to bytes.
46
+
47
+ Args:
48
+ bits: List of bits with parity protection (length must be multiple of 9).
49
+
50
+ Returns:
51
+ Decoded bytes.
52
+
53
+ Raises:
54
+ ValueError: If bit stream length is not multiple of 9 or parity check fails.
55
+ """
56
  if len(bits) % 9 != 0:
57
  raise ValueError("Bit stream length must be multiple of 9")
58
  out = bytearray()
 
70
 
71
 
72
  def text_to_bits(text: str) -> List[int]:
73
+ """Convert text to parity-protected bit sequence.
74
+
75
+ Args:
76
+ text: Input text string.
77
+
78
+ Returns:
79
+ List of bits with parity protection.
80
+ """
81
  return bytes_to_bits(text.encode("utf-8"))
82
 
83
 
84
  def bits_to_text(bits: List[int]) -> str:
85
+ """Convert parity-protected bit sequence to text.
86
+
87
+ Args:
88
+ bits: List of bits with parity protection.
89
+
90
+ Returns:
91
+ Decoded text string (with error replacement for invalid UTF-8).
92
+ """
93
  return bits_to_bytes(bits).decode("utf-8", errors="replace")
94
 
95