WCNegentropy committed on
Commit
f5498bc
·
verified ·
1 Parent(s): f443b7f

Remove nested directory: BitTransformerLM/recursive_integration_flow.py

Browse files
BitTransformerLM/recursive_integration_flow.py DELETED
@@ -1,128 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch.profiler import profile
4
- from bit_transformer import (
5
- BitTransformerLM,
6
- quantize_dynamic,
7
- hil_safe_inference,
8
- collapse_submodel,
9
- )
10
- from bit_transformer.training import train_loop
11
- from bit_transformer.torch_utils import cpu_autocast
12
-
13
-
14
def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 1,
    compress_prob: float = 0.5,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train ``model`` via :func:`train_loop` with random compression.

    Direct-bit training is disabled here (``direct_prob=0.0``); every
    other option is forwarded unchanged to ``train_loop``.

    Args:
        model: The BitTransformerLM instance to train in place.
        data: Training bit tensor passed straight through to ``train_loop``.
        epochs: Number of training epochs.
        compress_prob: Probability of applying compression per batch.
        log: Whether ``train_loop`` should emit progress logs.
        forward_kwargs: Optional extra keyword arguments for the forward pass.

    Returns:
        The per-epoch metric dictionaries produced by ``train_loop``.
    """
    loop_options = dict(
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=0.0,  # random-compression-only regime
        log=log,
        forward_kwargs=forward_kwargs,
    )
    return train_loop(model, data, **loop_options)
32
-
33
-
34
def recursive_integration_flow(
    steps: int = 4, max_len: int = 64
) -> list[tuple[int, float, float, float, float]]:
    """Run a dynamic scale-up loop with telemetry-based gating.

    Trains a small ``BitTransformerLM`` on random bit data, alternately
    doubling model width and depth each step.  After every step the
    negentropy (K), LZ-complexity (C) and symbiosis (S) telemetry means are
    compared against running floors; a drop of more than 0.3 below the best
    value seen so far halts scaling.  The final model is then quantized,
    safety-gated via human-in-the-loop inference, distilled into a smaller
    student model, optionally ``torch.compile``d, and profiled.

    Args:
        steps: Number of scale-up iterations (the loop runs ``steps + 1``
            times; the final iteration trains without further scaling).
        max_len: Sequence length of the synthetic bit data.

    Returns:
        Per-step records ``(step, val_loss, K, C, S)``.  (Fix: these were
        previously collected into ``results`` but discarded.)
    """
    # Synthetic bit datasets; every element is 0 or 1.
    train_bits = torch.randint(0, 2, (64, max_len), dtype=torch.long)
    valid_bits = torch.randint(0, 2, (16, max_len), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, max_len), dtype=torch.long)
    bit_sequence_data = train_bits.tolist()

    # Running best telemetry values; used as regression floors below.
    best_K = best_C = best_S = 0.0

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
    )

    results: list[tuple[int, float, float, float, float]] = []
    for step in range(steps + 1):
        # Train slightly longer as the model grows, capped at 10 epochs.
        epochs = min(10, 2 + step // 2)
        train(model, train_bits, epochs=epochs, compress_prob=0.5, log=True)

        # Validation: next-bit cross-entropy plus telemetry means.
        with torch.no_grad():
            with cpu_autocast():
                logits, telemetry = model(valid_bits)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
            k = telemetry["negentropy_logits"].mean().item()
            c = telemetry["lz_complexity_logits"].mean().item()
            s = telemetry["symbiosis_score"].mean().item()

        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))

        # Gate: halt scaling if any metric falls >0.3 below its best so far.
        if step > 0:
            if k < best_K - 0.3 or c < best_C - 0.3 or s < best_S - 0.3:
                print(f"\u26a0\ufe0f Step {step} regressed below metric floor. Halting.")
                break
        best_K = max(best_K, k)
        best_C = max(best_C, c)
        best_S = max(best_S, s)

        # Alternate width/depth doubling on every step except the last.
        if step < steps:
            if step % 2 == 0:
                model = model.double_width()
            else:
                model = model.double_layers()

    # Post-scaling optimizations: warm-up forward pass under CPU autocast.
    with cpu_autocast():
        model(input_bits)

    # Dynamic quantization, then human-in-the-loop safety-gated inference.
    qmodel = quantize_dynamic(model)
    qmodel.eval()

    safe_output = hil_safe_inference(
        qmodel, input_bits, c_floor=0.5, s_floor=0.2
    )

    # Distill into a smaller student model subject to telemetry floors.
    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=max_len,
        ),
        floors={"negentropy": 0.2, "lz_complexity": 0.5, "symbiosis_score": 0.2},
    )

    # torch.compile is best-effort; fall back to the eager student model.
    if hasattr(torch, "compile"):
        try:
            compiled = torch.compile(student_model)
        except RuntimeError as exc:
            print(f"Compilation skipped: {exc}")
            compiled = student_model
    else:
        compiled = student_model
    compiled.eval()

    # Profile one inference pass and export a Chrome trace for inspection.
    with profile() as prof:
        compiled(input_bits)
    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output[0].tolist())
    return results
125
-
126
-
127
# Script entry point: run the full scale-up demo with default settings.
if __name__ == "__main__":
    recursive_integration_flow()