"""Tests for BitTransformerLM: forward passes, telemetry, safety gating,
quantization, distributed wrappers, and diffusion inference."""

import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from bit_transformer import (
    BitTransformerLM,
    hil_safe_inference,
    text_to_bits,
    bits_to_text,
    plot_telemetry,
    infer_long_sequence,
    diffusion_inference,
    compress_bits,
)
from bit_transformer.safety import SafetyGate


def test_forward_pass():
    B, L = 2, 8
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L)
    bits = torch.randint(0, 2, (B, L), dtype=torch.long)
    logits, telemetry = model(bits)
    assert logits.shape == (B, L, 2)
    required_keys = {
        "negentropy_input",
        "lz_complexity_input",
        "negentropy_logits",
        "lz_complexity_logits",
        "symbiosis_kl",
        "symbiosis_score",
        "attention_entropy",
        "attention_entropy_mean",
    }
    assert required_keys.issubset(telemetry.keys())
    # Next-bit prediction: logits at position t are scored against the bit
    # at t + 1, so the final logit and the first bit are dropped.
    pred = logits[:, :-1, :].reshape(-1, 2)
    target = bits[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)
    assert torch.isfinite(loss)


def test_autocast_forward():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        use_autocast=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = model(bits)
    assert logits.shape == (1, 8, 2)


def test_act_forward():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=2,
        dim_feedforward=64,
        max_seq_len=8,
        use_act=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, tele = model(bits)
    assert logits.shape == (1, 8, 2)
    assert "halt_probs" in tele


def test_act_skips_layers():
    model = BitTransformerLM(
        d_model=16,
        nhead=4,
        num_layers=3,
        dim_feedforward=32,
        max_seq_len=8,
        use_act=True,
        act_threshold=0.5,
    )
    # Force every halting projection to emit a large constant logit:
    # assuming the halt probability is a sigmoid of the projection output,
    # sigmoid(10) ~= 0.99995 clears the 0.5 threshold, so the model should
    # halt before running all three layers.
    for proj in model.halt_projs:
        nn.init.constant_(proj.weight, 0.0)
        nn.init.constant_(proj.bias, 10.0)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, tele = model(bits)
    assert len(tele["halt_probs"]) < model.num_layers


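# The arithmetic behind the constant-bias trick above, assuming the halt
# probability is a sigmoid over the halting projection's output: a logit of
# 10 maps to ~0.99995, comfortably above act_threshold=0.5, while a logit
# of -10 would map to ~0.00005 and never halt early.
def test_halt_logit_arithmetic():
    high = torch.sigmoid(torch.tensor(10.0))
    low = torch.sigmoid(torch.tensor(-10.0))
    assert high > 0.5 > low
    assert torch.isclose(high + low, torch.tensor(1.0))

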
def test_hil_safety_gate():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)

    # Floors of 1.0 are effectively unreachable for an untrained model, so
    # the safety gate must trip.
    with pytest.raises(RuntimeError):
        hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0)


def test_hil_safety_non_strict():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    # strict=False downgrades the gate failure from an exception, so
    # inference still returns an output of the input shape.
    out, _ = hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0, strict=False)
    assert out.shape == bits.shape


def test_safety_gate_burn_in():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    gate = SafetyGate(c_floor=1.0, s_floor=1.0, burn_in=1)
    # The first call falls inside the burn-in window and passes; the second
    # call is gated and should raise.
    hil_safe_inference(model, bits, gate=gate)
    with pytest.raises(RuntimeError):
        hil_safe_inference(model, bits, gate=gate)


def test_bit_io_roundtrip():
    text = "hello"
    bits = text_to_bits(text)
    assert bits_to_text(bits) == text


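# A standalone reference for the same roundtrip idea, assuming a plain
# 8-bits-per-byte encoding. The library's text_to_bits may use a different
# framing (e.g. parity bits), so this is an illustrative sketch, not a
# reimplementation of it; _bytes_to_bits and _bits_to_bytes are hypothetical
# helpers local to this file.
def _bytes_to_bits(data: bytes) -> list:
    return [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]


def _bits_to_bytes(bits: list) -> bytes:
    return bytes(
        sum(bit << shift for bit, shift in zip(bits[i:i + 8], range(7, -1, -1)))
        for i in range(0, len(bits), 8)
    )


def test_reference_bit_roundtrip():
    text = "hello"
    bits = _bytes_to_bits(text.encode("utf-8"))
    assert _bits_to_bytes(bits).decode("utf-8") == text

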
def test_plot_telemetry():
    log = {
        "negentropy": [0.6, 0.7, 0.4],
        "lz_complexity": [0.5, 0.45, 0.6],
        "symbiosis_score": [0.55, 0.6, 0.3],
        "clusters": [0, 0, 1],
    }
    fig, axes = plot_telemetry(log)
    assert len(axes) == 3
    fig.clf()


def test_metric_no_gradient_flow():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (2, 8), dtype=torch.long)
    logits, _ = model(bits)
    # Telemetry metrics are presumably computed under no_grad or on detached
    # tensors, so they carry no autograd history and backward() must fail.
    loss = model.negentropy_logits(logits).mean() + model.lz_complexity_logits(logits).mean()
    assert not loss.requires_grad
    with pytest.raises(RuntimeError):
        loss.backward()


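# A minimal sketch of the detachment pattern the test above checks for,
# assuming telemetry is cut off from autograd via detach() (or no_grad);
# toy_metric is a hypothetical stand-in, not a library function.
def test_detached_metric_pattern():
    x = torch.randn(4, requires_grad=True)

    def toy_metric(t: torch.Tensor) -> torch.Tensor:
        # detach() severs the autograd graph, so the metric cannot push
        # gradients back into the parameters it was computed from.
        return t.detach().pow(2).mean()

    m = toy_metric(x)
    assert not m.requires_grad
    with pytest.raises(RuntimeError):
        m.backward()

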
def test_negentropy_decompression_edge_case():
    bits = torch.tensor([0, 1] * 8, dtype=torch.uint8)
    comp = compress_bits(bits)
    model = BitTransformerLM(d_model=16, nhead=2, num_layers=1, dim_feedforward=32, max_seq_len=bits.numel())
    # The KPI should be identical whether it receives compressed or raw
    # bits, i.e. compression is transparent to the metric.
    neg_comp = model.negentropy_kpi(comp.unsqueeze(0))
    neg_raw = model.negentropy_kpi(bits.unsqueeze(0))
    assert torch.allclose(neg_comp, neg_raw, atol=1e-6)


def test_dynamic_quantization():
    from bit_transformer import quantize_dynamic

    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    qmodel = quantize_dynamic(model)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = qmodel(bits)
    assert logits.shape == (1, 8, 2)


def test_qat_fx_roundtrip():
    from bit_transformer import prepare_qat_fx, convert_qat_fx

    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    example_bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    qat_model = prepare_qat_fx(model)
    qat_model.eval()
    qmodel = convert_qat_fx(qat_model)

    logits, _ = qmodel(example_bits)
    assert logits.shape == (1, 8, 2)


def test_fsdp_wrap():
    import torch.distributed as dist
    from bit_transformer import wrap_fsdp

    # Skip before touching the process group so a CPU-only run does not
    # leave a stray "gloo" group behind for later tests.
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    fsdp_model = wrap_fsdp(model)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = fsdp_model(bits)
    assert logits.shape == (1, 8, 2)
    dist.destroy_process_group()


def test_make_pipeline():
    import torch.distributed.rpc as rpc
    from bit_transformer import make_pipeline

    if not rpc._is_current_rpc_agent_set():
        pytest.skip("RPC not initialized")

    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    pipe_model = make_pipeline(model, chunks=1)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = pipe_model(bits)
    assert logits.shape == (1, 8, 2)


def test_causal_attention():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, tele = model(bits, causal=True)
    assert logits.shape == (1, 8, 2)
    # Under a causal mask no position may attend to the future, so the
    # strictly upper-triangular part of the attention map must be zero.
    attn = tele["attention_maps"][0]
    upper = attn.triu(1)
    assert torch.allclose(upper, torch.zeros_like(upper))


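# The masking property asserted above, shown with plain PyTorch ops: a
# standard additive upper-triangular -inf mask makes softmax assign zero
# weight to every future position. This mirrors the convention the model
# presumably uses; it does not touch BitTransformerLM internals.
def test_causal_mask_property():
    L = 8
    scores = torch.randn(L, L)
    mask = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)
    attn = torch.softmax(scores + mask, dim=-1)
    upper = attn.triu(1)
    assert torch.allclose(upper, torch.zeros_like(upper))

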
def test_scaling_helpers():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    model = model.double_width()
    assert model.d_model == 64
    model = model.double_layers()
    assert model.num_layers == 2


def test_expand_positional_encoding():
    model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
    model.expand_positional_encoding(16)
    assert model.pos_enc.pe.size(0) == 16


def test_infer_long_sequence():
    model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
    bits = torch.randint(0, 2, (12,), dtype=torch.long)
    # With an 8-bit context, a 4-bit overlap, and 12 input bits, the sliding
    # window must run at least twice.
    preds, logs = infer_long_sequence(model, bits, ctx_bits=8, overlap=4)
    assert len(preds) == 12
    assert len(logs) >= 2


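# The window count asserted above, derived standalone: with context c and
# overlap o the stride is c - o, so n bits need ceil((n - c) / (c - o)) + 1
# windows; here (12 - 8) / (8 - 4) + 1 = 2. This mirrors the assumed
# chunking of infer_long_sequence, whose internals may differ.
def test_sliding_window_count():
    n, c, o = 12, 8, 4
    stride = c - o
    windows = max(0, -(-(n - c) // stride)) + 1  # ceiling division
    assert windows == 2

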
def test_chunking_disabled_when_non_causal():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        chunk_size=2,
        full_attn_logging=True,
    )

    # Zero the attention projections so attention weights are uniform and
    # the causal/non-causal comparison below is deterministic.
    nn.init.constant_(model.layers[0].self_attn.in_proj_weight, 0.0)
    nn.init.constant_(model.layers[0].self_attn.in_proj_bias, 0.0)

    # Disable all remaining stochasticity (dropout) for the comparison.
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0
    model.layers[0].self_attn.dropout = 0.0

    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, tele_causal = model(bits, causal=True)
    _, tele_noncausal = model(bits, causal=False)
    attn_causal = tele_causal["attention_maps"][0]
    attn_noncausal = tele_noncausal["attention_maps"][0]

    # Position 0 cannot see position 4 under the causal mask, but the
    # non-causal pass (with chunking disabled) attends across the full span.
    assert attn_causal[0, 0, 0, 4] == 0
    assert attn_noncausal[0, 0, 0, 4] > 0


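# Why zeroing in_proj above yields uniform attention: with all-zero queries
# and keys every score is 0, and softmax over L positions gives 1/L. Shown
# here with a plain nn.MultiheadAttention, assumed to behave like the layer
# used internally.
def test_zeroed_projection_uniform_attention():
    mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
    nn.init.constant_(mha.in_proj_weight, 0.0)
    nn.init.constant_(mha.in_proj_bias, 0.0)
    x = torch.randn(1, 5, 8)
    _, weights = mha(x, x, x, need_weights=True)
    assert torch.allclose(weights, torch.full_like(weights, 1.0 / 5))

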
def test_diffusion_inference_generates_bits():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    out = diffusion_inference(model, length=8, steps=2, batch_size=2)
    assert out.shape == (2, 8)
    assert set(out.unique().tolist()).issubset({0, 1})


def test_diffusion_inference_cosine_schedule():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    out = diffusion_inference(model, length=8, steps=2, schedule="cosine")
    assert out.shape == (1, 8)


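# A sketch of a cosine noise schedule of the kind the "cosine" option
# presumably selects (cf. Nichol & Dhariwal): the keep probability decays
# smoothly from 1 to 0. The exact formula inside diffusion_inference may
# differ; this only illustrates the shape.
def test_cosine_schedule_shape():
    t = torch.linspace(0, 1, 8)
    keep = torch.cos(t * torch.pi / 2) ** 2
    assert torch.isclose(keep[0], torch.tensor(1.0))
    assert keep[-1] < 1e-6
    assert (keep[:-1] >= keep[1:]).all()  # monotonically non-increasing

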
def test_chunking_restored_after_diffusion():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        chunk_size=2,
        full_attn_logging=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    # A non-causal (diffusion-style) pass temporarily disables chunked
    # attention; chunk_size must be restored afterwards and causal masking
    # must still hold on the next pass.
    _ = model(bits, causal=False)
    assert model.layers[0].chunk_size == 2
    _, tele = model(bits, causal=True)
    attn = tele["attention_maps"][0]
    assert attn[0, 0, 0, 4] == 0