WCNegentropy commited on
Commit
4f80a3a
·
verified ·
1 Parent(s): bc0d887

Remove nested directory: BitTransformerLM/tests/test_quantization.py

Browse files
BitTransformerLM/tests/test_quantization.py DELETED
@@ -1,19 +0,0 @@
1
- import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
2
- import torch
3
- from bit_transformer import BitTransformerLM, quantize_dynamic, prepare_qat_fx, convert_qat_fx
4
- from bit_transformer.training import train_loop
5
-
6
- def test_qat_matches_dynamic_quant():
7
- data = torch.randint(0, 2, (16, 8), dtype=torch.long)
8
- base = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
9
- train_loop(base, data, epochs=1, log=False)
10
- dyn = quantize_dynamic(base)
11
- qat_model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
12
- qat_model.load_state_dict(base.state_dict())
13
- prepare_qat_fx(qat_model)
14
- convert_qat_fx(qat_model)
15
- inp = torch.randint(0, 2, (10, 8), dtype=torch.long)
16
- out_dyn, _ = dyn(inp)
17
- out_qat, _ = qat_model(inp)
18
- diff = (out_dyn - out_qat).abs().max().item()
19
- assert diff < 0.6