WCNegentropy committed on
Commit
cf9795f
·
verified ·
1 Parent(s): f6e5ceb

Remove nested directory: BitTransformerLM/bit_transformer/compression.py

Browse files
BitTransformerLM/bit_transformer/compression.py DELETED
@@ -1,82 +0,0 @@
1
- import torch
2
- from typing import List
3
-
4
-
5
def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a 1D tensor of bits.

    Args:
        bits: 1D tensor with values 0 or 1 (bool or uint8).

    Returns:
        1D uint8 tensor containing interleaved (value, run-length) pairs.
        Runs longer than 255 are split across several pairs so that each
        count fits in a uint8.

    Raises:
        ValueError: If ``bits`` is not a 1D tensor.
    """
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")
    b = bits.to(torch.uint8).flatten()
    if b.numel() == 0:
        # Nothing to encode: return the empty tensor as-is.
        return b
    # Positions where the bit value flips; runs are the half-open
    # intervals [start, end) between consecutive flips.
    changes = torch.nonzero(b[1:] != b[:-1]).flatten().to(torch.long) + 1
    starts = torch.cat([b.new_tensor([0], dtype=torch.long), changes])
    ends = torch.cat([changes, b.new_tensor([b.numel()], dtype=torch.long)])
    values = b[starts]  # starts is already long; no extra cast needed
    counts = ends - starts

    out_vals: List[int] = []
    out_counts: List[int] = []
    for v, c in zip(values.tolist(), counts.tolist()):
        # Chunk long runs so every count fits in a single uint8 (<= 255).
        while c > 255:
            out_vals.append(v)
            out_counts.append(255)
            c -= 255
        out_vals.append(v)
        out_counts.append(c)
    # Fix: build the output on the input's device so the result is
    # consistent with the empty-input fast path above (the original
    # always allocated on CPU regardless of where ``bits`` lived).
    values_tensor = torch.tensor(out_vals, dtype=torch.uint8, device=bits.device)
    counts_tensor = torch.tensor(out_counts, dtype=torch.uint8, device=bits.device)
    return torch.stack([values_tensor, counts_tensor], dim=1).flatten()
38
-
39
-
40
def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Decode a run-length encoded bit tensor.

    Args:
        compressed: 1D even-length tensor of interleaved (value, count)
            pairs, as produced by ``compress_bits``.

    Returns:
        1D uint8 tensor with each value repeated ``count`` times.

    Raises:
        ValueError: If ``compressed`` is not 1D with even length.
    """
    if compressed.dim() != 1 or compressed.numel() % 2 != 0:
        raise ValueError("compressed tensor must be 1D even-length")
    # View the flat stream as (n_runs, 2) rows of [value, count].
    pairs = compressed.to(torch.uint8).reshape(-1, 2)
    run_values = pairs[:, 0]
    run_lengths = pairs[:, 1].to(torch.long)
    return torch.repeat_interleave(run_values, run_lengths)
48
-
49
-
50
def model_output_decompress(compressed_batch) -> torch.Tensor:
    """Decompress a batch of compressed bit sequences.

    Args:
        compressed_batch: A single 1D compressed tensor, or an iterable
            of compressed rows (e.g. a 2D tensor or a list of tensors).

    Returns:
        2D tensor stacking every decompressed sequence.

    Raises:
        ValueError: If the rows decompress to different lengths (stacking
        would be impossible).
    """
    is_single = (
        isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1
    )
    rows = [compressed_batch] if is_single else list(compressed_batch)
    decoded = [decompress_bits(row) for row in rows]
    distinct_lengths = {seq.numel() for seq in decoded}
    if len(distinct_lengths) != 1:
        raise ValueError("Sequences decompress to different lengths")
    return torch.stack(decoded)
60
-
61
-
62
- import numpy as np
63
-
64
-
65
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack groups of 8 bits into uint8 values using numpy.packbits.

    The final group is zero-padded when the length is not a multiple of
    8; the original bit count is not stored (pass ``n_bits`` to
    ``unpack_bits`` to recover it).

    Args:
        bits: 1D tensor with values 0 or 1.

    Returns:
        1D uint8 tensor of packed bytes (CPU).

    Raises:
        ValueError: If ``bits`` is not a 1D tensor.
    """
    if bits.dim() != 1:
        raise ValueError("pack_bits expects a 1D tensor")
    bit_array = bits.to(torch.uint8).cpu().numpy()
    return torch.from_numpy(np.packbits(bit_array))
72
-
73
-
74
- def unpack_bits(packed: torch.Tensor, *, n_bits: int | None = None) -> torch.Tensor:
75
- """Unpack uint8 values back into a bit tensor."""
76
- if packed.dim() != 1:
77
- raise ValueError("unpack_bits expects a 1D tensor")
78
- arr = np.unpackbits(packed.to(torch.uint8).cpu().numpy())
79
- if n_bits is not None:
80
- arr = arr[:n_bits]
81
- return torch.from_numpy(arr)
82
-