| import torch |
| import torch.nn.functional as F |
| from torch.profiler import profile |
| from bit_transformer import ( |
| BitTransformerLM, |
| quantize_dynamic, |
| hil_safe_inference, |
| collapse_submodel, |
| ) |
| from bit_transformer.training import train_loop |
| from bit_transformer.torch_utils import cpu_autocast |
|
|
|
|
def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 1,
    compress_prob: float = 0.5,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train ``model`` via :func:`train_loop` with random compression.

    Direct (uncompressed) sampling is disabled (``direct_prob=0.0``);
    every other option is forwarded unchanged. Returns the per-epoch
    metric dictionaries produced by the loop.
    """
    loop_options = {
        "epochs": epochs,
        "compress_prob": compress_prob,
        "direct_prob": 0.0,
        "log": log,
        "forward_kwargs": forward_kwargs,
    }
    return train_loop(model, data, **loop_options)
|
|
|
|
def recursive_integration_flow(steps: int = 4, max_len: int = 64) -> None:
    """Run a dynamic scale-up loop with telemetry-based gating.

    Trains a tiny :class:`BitTransformerLM` on random bit data, doubling its
    width or depth between steps while gating on telemetry metrics.  The
    final model is then quantized, run through safe inference, distilled
    into a smaller student model, optionally compiled, and profiled.

    Args:
        steps: Number of scale-up transitions; the train/validate loop
            runs ``steps + 1`` iterations.
        max_seq_len / max_len: Length of every synthetic bit sequence and
            the model's maximum sequence length.
    """
    # Synthetic bit data: 64 training rows, 16 validation rows, and a
    # single probe sequence reused for inference and profiling below.
    train_bits = torch.randint(0, 2, (64, max_len), dtype=torch.long)
    valid_bits = torch.randint(0, 2, (16, max_len), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, max_len), dtype=torch.long)
    # Plain nested-list copy for collapse_submodel, which takes raw
    # bit sequences rather than tensors.
    bit_sequence_data = train_bits.tolist()


    # Best-so-far telemetry scores (K = negentropy, C = LZ complexity,
    # S = symbiosis); used as the regression floor during scale-up.
    best_K = best_C = best_S = 0.0


    # Start deliberately small — the model is doubled each step.
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
    )


    # (step, val_loss, k, c, s) per iteration; collected but not
    # otherwise consumed here — useful when debugging interactively.
    results = []
    for step in range(steps + 1):
        # Train slightly longer as the model grows, capped at 10 epochs.
        epochs = min(10, 2 + step // 2)
        train(model, train_bits, epochs=epochs, compress_prob=0.5, log=True)


        # Validation pass: next-bit cross-entropy plus mean telemetry.
        with torch.no_grad():
            with cpu_autocast():
                logits, telemetry = model(valid_bits)
            # Shift by one position: logits at t predict the bit at t+1.
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
            k = telemetry["negentropy_logits"].mean().item()
            c = telemetry["lz_complexity_logits"].mean().item()
            s = telemetry["symbiosis_score"].mean().item()


        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))


        # Gate: halt scale-up if any metric falls more than 0.3 below its
        # best observed value.  Skipped at step 0, where no best exists yet.
        if step > 0:
            if k < best_K - 0.3 or c < best_C - 0.3 or s < best_S - 0.3:
                print(f"\u26a0\ufe0f Step {step} regressed below metric floor. Halting.")
                break
        best_K = max(best_K, k)
        best_C = max(best_C, c)
        best_S = max(best_S, s)


        # Alternate widening and deepening on every non-final step.
        if step < steps:
            if step % 2 == 0:
                model = model.double_width()
            else:
                model = model.double_layers()


    # Warm-up forward pass under CPU autocast before quantization.
    with cpu_autocast():
        model(input_bits)


    # Dynamic quantization for CPU inference.
    qmodel = quantize_dynamic(model)
    qmodel.eval()


    # Safe inference gated by complexity/symbiosis floors.
    safe_output = hil_safe_inference(
        qmodel, input_bits, c_floor=0.5, s_floor=0.2
    )


    # Distill a smaller student model from the raw bit sequences,
    # subject to the given telemetry floors.
    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=max_len,
        ),
        floors={"negentropy": 0.2, "lz_complexity": 0.5, "symbiosis_score": 0.2},
    )


    # Compile the student when torch.compile exists (PyTorch >= 2.0);
    # fall back to the eager model if compilation raises.
    if hasattr(torch, "compile"):
        try:
            compiled = torch.compile(student_model)
        except RuntimeError as exc:
            print(f"Compilation skipped: {exc}")
            compiled = student_model
    else:
        compiled = student_model
    compiled.eval()


    # Profile a single forward pass and dump a Chrome trace for
    # inspection in chrome://tracing or Perfetto.
    with profile() as prof:
        compiled(input_bits)
    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output[0].tolist())
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the full scale-up demo with defaults.
    recursive_integration_flow()
|
|