WCNegentropy committed on
Commit
c1c18dc
·
verified ·
1 Parent(s): 68ec438

Remove nested directory: BitTransformerLM/integration_flow.py

Browse files
Files changed (1) hide show
  1. BitTransformerLM/integration_flow.py +0 -110
BitTransformerLM/integration_flow.py DELETED
@@ -1,110 +0,0 @@
1
- import torch
2
- from torch.profiler import profile
3
- from bit_transformer import (
4
- BitTransformerLM,
5
- quantize_dynamic,
6
- hil_safe_inference,
7
- collapse_submodel,
8
- )
9
- from bit_transformer.training import train_loop
10
- from bit_transformer.torch_utils import cpu_autocast
11
-
12
def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 3,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train on bit sequences with optional random compression.

    If ``direct_prob`` is positive, some batches are fed using their
    run-length encoded representation packed into bits. Loss on these
    direct-compressed batches is tracked separately.

    Returns a list of per-epoch metric dictionaries containing raw and
    compressed loss/accuracy statistics and the mean compression ratio.
    """
    # Thin convenience wrapper: collect the knobs and delegate to the
    # shared training loop.
    options = dict(
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=direct_prob,
        log=log,
        forward_kwargs=forward_kwargs,
    )
    return train_loop(model, data, **options)
39
-
40
-
41
def main() -> None:
    """Run the end-to-end integration flow on synthetic bit data.

    Stages: progressive model scale-up with telemetry floors, mixed-precision
    CPU inference, dynamic quantization, human-in-the-loop safe inference,
    submodel distillation (collapse), optional ``torch.compile``, and a
    profiled forward pass exported as a Chrome trace.
    """
    # Synthetic 0/1 sequences; shapes: train (64, 128), val (16, 128), probe (1, 128).
    data = torch.randint(0, 2, (64, 128), dtype=torch.long)
    validation_bits = torch.randint(0, 2, (16, 128), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, 128), dtype=torch.long)
    bit_sequence_data = data.tolist()

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=128,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=128,
    )

    # Alternate width/depth doubling for 12 steps, retraining after each
    # growth and validating telemetry floors on held-out bits.
    for step in range(1, 13):
        if step % 2 == 0:
            model = model.double_width()
        else:
            model = model.double_layers()
        train(model, data, epochs=3, compress_prob=0.5, log=True)
        _, telemetry = model(validation_bits)
        K = telemetry["negentropy_logits"].mean().item()
        C = telemetry["lz_complexity_logits"].mean().item()
        S = telemetry["symbiosis_score"].mean().item()
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently disable this safety floor.
        if not (K > 0.3 and C > 0.35 and S > 0.5):
            raise RuntimeError(f"Step {step} telemetry floor failure")

    # Exercise mixed-precision inference on CPU.
    with cpu_autocast():
        model(input_bits)

    quantized_model = quantize_dynamic(model)
    quantized_model.eval()

    # Gated inference: enforces complexity/symbiosis floors on the output.
    safe_output, _ = hil_safe_inference(
        quantized_model, input_bits, c_floor=0.35, s_floor=0.5
    )

    # Distill into a smaller student model subject to the same telemetry floors.
    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=128,
        ),
        floors={"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5},
    )

    # torch.compile is only available on PyTorch >= 2.0; fall back gracefully.
    compiled_model = (
        torch.compile(student_model)
        if hasattr(torch, "compile")
        else student_model
    )
    compiled_model.eval()

    # Profile a single forward pass and export a Chrome-viewable trace.
    with profile() as prof:
        compiled_model(input_bits)

    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output.squeeze(0).tolist())
107
-
108
-
109
# Script entry point: run the demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()