WCNegentropy committed
Commit 9684e67 · verified · 1 parent: 8445c2b

Remove nested directory: BitTransformerLM/bit_transformer/bit_io.py

BitTransformerLM/bit_transformer/bit_io.py DELETED
@@ -1,97 +0,0 @@
- from typing import List, TYPE_CHECKING
- import torch
- import sys
-
- try:  # torch.compile may be unavailable or unsupported
-     if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
-         compile_fn = torch.compile
-     else:
-         raise RuntimeError
- except Exception:  # pragma: no cover
-
-     def compile_fn(fn=None, **kwargs):
-         if fn is None:
-             return lambda f: f
-         return fn
-
-
- if TYPE_CHECKING:  # pragma: no cover
-     from .model import BitTransformerLM
-
-
- @compile_fn
- def bytes_to_bits(data: bytes) -> List[int]:
-     """Convert bytes to bits with a per-byte parity bit."""
-     result: List[int] = []
-     for b in data:
-         # Eight data bits, most significant first, then an even-parity bit.
-         bits = [(b >> i) & 1 for i in reversed(range(8))]
-         parity = sum(bits) % 2
-         result.extend(bits + [parity])
-     return result
-
-
- @compile_fn
- def bits_to_bytes(bits: List[int]) -> bytes:
-     """Convert parity-protected bits back to bytes."""
-     if len(bits) % 9 != 0:
-         raise ValueError("Bit stream length must be a multiple of 9")
-     out = bytearray()
-     for i in range(0, len(bits), 9):
-         chunk = bits[i : i + 9]
-         payload = chunk[:8]
-         parity = chunk[8]
-         if parity != sum(payload) % 2:
-             raise ValueError("Parity check failed")
-         value = 0
-         for bit in payload:
-             value = (value << 1) | bit
-         out.append(value)
-     return bytes(out)
-
-
- def text_to_bits(text: str) -> List[int]:
-     """Encode UTF-8 text as parity-protected bits."""
-     return bytes_to_bits(text.encode("utf-8"))
-
-
- def bits_to_text(bits: List[int]) -> str:
-     """Decode parity-protected bits back to text, replacing invalid UTF-8."""
-     return bits_to_bytes(bits).decode("utf-8", errors="replace")
-
-
- def infer_text(
-     model: "BitTransformerLM",
-     text: str,
-     c_floor: float = 0.3,
-     s_floor: float = 0.5,
- ) -> str:
-     """Run text through the model using the safety gate."""
-     from .safety import hil_safe_inference
-     bits = text_to_bits(text)
-     tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
-     out_bits, _ = hil_safe_inference(model, tensor, c_floor=c_floor, s_floor=s_floor)
-     return bits_to_text(out_bits.squeeze(0).tolist())
-
-
- def sample_text(
-     model: "BitTransformerLM",
-     prompt: str,
-     max_new_tokens: int = 16,
-     temperature: float = 1.0,
-     top_p: float = 1.0,
- ) -> str:
-     """Generate text from the model using simple top-p sampling."""
-     model.eval()
-     bits = text_to_bits(prompt)
-     tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
-     # Each byte is framed as nine bits (eight data bits plus parity).
-     for _ in range(max_new_tokens * 9):
-         if tensor.size(1) >= model.pos_enc.pe.size(0):
-             break  # stop at the positional-encoding capacity
-         logits, _ = model(tensor, causal=True)
-         # Apply temperature to the logits before the softmax.
-         prob = (logits[0, -1] / temperature).softmax(-1)
-         sorted_prob, sorted_idx = prob.sort(descending=True)
-         cumulative = sorted_prob.cumsum(0)
-         # Mask tokens outside the nucleus while always keeping the top token.
-         mask = cumulative - sorted_prob > top_p
-         sorted_prob[mask] = 0
-         sorted_prob = sorted_prob / sorted_prob.sum()
-         next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
-         tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
-     return bits_to_text(tensor.squeeze(0).tolist())
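
For reference, the deleted module framed every byte as nine bits: eight data bits, most significant first, followed by one even-parity bit, which is what bits_to_bytes checked on decode. A minimal standalone sketch of that framing (byte_to_bits9 is an illustrative name, not part of the module):

```python
from typing import List

def byte_to_bits9(b: int) -> List[int]:
    # Eight data bits, most significant first, plus an even-parity bit,
    # mirroring bytes_to_bits in the deleted module.
    bits = [(b >> i) & 1 for i in reversed(range(8))]
    return bits + [sum(bits) % 2]

# 'A' is 0x41 = 0b01000001 -> two set bits -> parity 0.
assert byte_to_bits9(ord("A")) == [0, 1, 0, 0, 0, 0, 0, 1, 0]

# A single flipped bit breaks the parity check, which is what
# bits_to_bytes relied on to reject corrupted streams.
frame = byte_to_bits9(ord("A"))
frame[0] ^= 1
assert sum(frame[:8]) % 2 != frame[8]
```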
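The nucleus-sampling step inside sample_text can also be exercised in isolation. A small sketch of the same filter (top_p_filter is an illustrative name; the logits are made up for the demonstration):

```python
import torch

def top_p_filter(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0) -> torch.Tensor:
    # Temperature scales the logits before the softmax.
    prob = (logits / temperature).softmax(-1)
    sorted_prob, sorted_idx = prob.sort(descending=True)
    cumulative = sorted_prob.cumsum(0)
    # Mask tokens whose cumulative mass *before* them already exceeds
    # top_p, so the highest-probability token always survives.
    sorted_prob[cumulative - sorted_prob > top_p] = 0
    filtered = torch.zeros_like(prob).scatter(0, sorted_idx, sorted_prob)
    return filtered / filtered.sum()

# With a two-symbol (bit) vocabulary and top_p=0.5, only the more
# likely bit remains, so sampling becomes deterministic.
p = top_p_filter(torch.tensor([2.0, 0.0]), top_p=0.5)
assert torch.multinomial(p, 1).item() == 0
```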