WCNegentropy commited on
Commit
3523f82
·
verified ·
1 Parent(s): cb2d747

Remove nested directory: BitTransformerLM/tests/test_compression.py

Browse files
BitTransformerLM/tests/test_compression.py DELETED
@@ -1,63 +0,0 @@
1
- import os, sys
2
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
3
- import torch
4
- from bit_transformer import BitTransformerLM, compress_bits, decompress_bits, model_output_decompress
5
-
6
-
7
- def test_compress_roundtrip():
8
- bits = torch.randint(0, 2, (16,), dtype=torch.uint8)
9
- comp = compress_bits(bits)
10
- decomp = decompress_bits(comp)
11
- assert torch.equal(bits, decomp)
12
-
13
-
14
- def test_forward_compressed_equivalence():
15
- B, L = 2, 8
16
- model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L)
17
- model.eval()
18
- bits = torch.randint(0, 2, (B, L), dtype=torch.long)
19
- logits_a, tele_a = model(bits)
20
- compressed = [compress_bits(row.to(torch.uint8)) for row in bits]
21
- logits_b, tele_b = model.forward_compressed(compressed)
22
- assert torch.allclose(logits_a, logits_b)
23
- for key in tele_a:
24
- if isinstance(tele_a[key], list):
25
- continue
26
- assert torch.allclose(tele_a[key], tele_b[key])
27
-
28
-
29
- def test_model_output_decompress():
30
- bits = torch.randint(0, 2, (2, 8), dtype=torch.uint8)
31
- comp = [compress_bits(row) for row in bits]
32
- decomp = model_output_decompress(comp)
33
- assert torch.equal(decomp, bits)
34
-
35
-
36
- def test_metrics_on_compressed():
37
- model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
38
- bits = torch.randint(0, 2, (2, 8), dtype=torch.uint8)
39
- comps = [compress_bits(row) for row in bits]
40
- comp_batch = torch.nn.utils.rnn.pad_sequence(comps, batch_first=True)
41
- neg = model.negentropy_kpi(comp_batch)
42
- assert neg.shape[0] == bits.size(0)
43
-
44
-
45
- def test_compress_long_run_split():
46
- bits = torch.zeros(300, dtype=torch.uint8)
47
- comp = compress_bits(bits)
48
- expected = torch.tensor([0, 255, 0, 45], dtype=torch.uint8)
49
- assert torch.equal(comp, expected)
50
- decomp = decompress_bits(comp)
51
- assert torch.equal(decomp, bits)
52
-
53
-
54
- def test_compress_long_run_with_change():
55
- run1 = torch.ones(260, dtype=torch.uint8)
56
- run2 = torch.zeros(10, dtype=torch.uint8)
57
- bits = torch.cat([run1, run2])
58
- comp = compress_bits(bits)
59
- expected = torch.tensor([1, 255, 1, 5, 0, 10], dtype=torch.uint8)
60
- assert torch.equal(comp, expected)
61
- decomp = decompress_bits(comp)
62
- assert torch.equal(decomp, bits)
63
-