| | import math |
| | import contextlib |
| | import logging |
| | from typing import Dict, List, Tuple, Optional |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | import sys |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint as checkpoint |
| |
|
| | from .torch_utils import cpu_autocast |
| |
|
| | from .optimization import configure_optimizer |
| | from .compression import decompress_bits |
| | from .parity import enforce_parity |
| |
|
# Process-wide caches.  ``_mask_cache`` maps ``(seq_len, device)`` to the
# boolean causal mask built by ``get_tri_mask``.
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
# Emptied by ``clear_cache``; nothing in the code visible here writes to it.
_attention_cache: Dict[str, torch.Tensor] = {}
# Maximum number of masks retained in ``_mask_cache``.
_MAX_CACHE_SIZE = 50


def clear_cache():
    """Clear memory caches to prevent OOM in long sequences."""
    # Both dictionaries are mutated in place, so no ``global`` is required.
    _mask_cache.clear()
    _attention_cache.clear()


def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Return or create a cached upper-triangular mask with memory management.

    The mask is a ``(seq_len, seq_len)`` boolean tensor that is ``True``
    strictly above the diagonal.  Repeated calls with the same
    ``(seq_len, device)`` pair return the same tensor object.

    The cache never holds more than ``_MAX_CACHE_SIZE`` masks: when a new
    mask has to be inserted into a full cache, the oldest entries are dropped
    first.  A cache hit never evicts anything.
    """
    key = (seq_len, device)
    mask = _mask_cache.get(key)
    if mask is None:
        # Dicts preserve insertion order, so the first key is the oldest.
        while _mask_cache and len(_mask_cache) >= _MAX_CACHE_SIZE:
            _mask_cache.pop(next(iter(_mask_cache)))
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
        )
        _mask_cache[key] = mask
    return mask
| |
|
def compile_fn(fn=None, **kwargs):
    """Pass-through stand-in used when ``torch.compile`` is not selected.

    Usable both as ``@compile_fn`` and as ``@compile_fn(**options)``; in
    either form the decorated callable is returned unchanged and the options
    are ignored.
    """
    if fn is not None:
        return fn
    return lambda f: f


# Replace the stand-in with ``torch.compile`` only on PyTorch >= 2.0 running
# under Python < 3.11.  If probing fails for any reason (unparsable version
# string, missing attribute) the pass-through defined above stays in place.
try:
    _torch_major_minor = tuple(
        int(part) for part in torch.__version__.split(".")[:2]
    )
    if (
        torch.__version__
        and _torch_major_minor >= (2, 0)
        and sys.version_info < (3, 11)
    ):
        compile_fn = torch.compile
except Exception:
    pass
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding.

    The table is stored in the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature columns hold sines and odd
    columns hold cosines.  Both even and odd ``d_model`` are supported.
    """

    def __init__(self, d_model: int, max_len: int = 1024) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        inv = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = pos * inv
        pe[:, 0::2] = torch.sin(angles)
        # With an odd ``d_model`` there is one fewer cosine column than sine
        # column, so the frequency table is truncated to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input tensor.

        The first ``x.size(0)`` rows of the table are added, so ``x`` is
        indexed by position along dimension 0.
        """
        return x + self.pe[: x.size(0)]
| |
|
| |
|
class LoggingTransformerEncoderLayer(nn.Module):
    """Transformer encoder layer that exposes attention weights.

    It optionally performs chunked attention with a fixed window size.
    Inputs and outputs use the ``(T, B, D)`` layout; the attention module
    itself runs batch-first, so tensors are transposed around each call.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.chunk_size = chunk_size
        self.overlap = overlap
        # Default: log attention only when attention is not chunked.
        if full_attn_logging is None:
            full_attn_logging = chunk_size is None
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def _chunked_attn(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform memory-efficient chunked self attention with overlap.

        Sequences of at most 128 positions, or sequences that fit in a single
        chunk, are handled by ``_full_attn``.  Otherwise the sequence is cut
        into chunks of ``chunk_size`` positions.  Each chunk is attended
        inside a window that extends ``overlap`` positions to either side
        (clipped at the sequence ends) and is zero-padded on the right to
        exactly ``chunk_size + 2 * overlap`` positions.  Only the outputs of
        the chunk's own positions are kept.

        ``attn_mask`` is only checked for presence: when it is not ``None`` a
        causal mask of the window size is applied to every window.  The zero
        padding is not masked separately.

        Returns:
            ``(output, weights)`` where ``output`` has shape ``(T, B, D)``.
            ``weights`` has shape ``(B, heads, T, chunk_size + 2 * overlap)``
            with key columns indexed relative to each chunk's window, or is
            an empty tensor when ``full_attn_logging`` is off or ``T > 1024``.
        """
        T, B, D = src.shape

        if T <= 128 or self.chunk_size is None or self.chunk_size >= T:
            return self._full_attn(src, attn_mask)

        src_b = src.transpose(0, 1)
        C = self.chunk_size
        O = self.overlap
        n_chunks = (T + C - 1) // C
        window_len = C + 2 * O
        # Weights for very long sequences are discarded anyway, so they are
        # not accumulated in the first place.
        keep_weights = self.full_attn_logging and T <= 1024
        # Every window has the same length, so one mask serves all chunks.
        mask = get_tri_mask(window_len, src.device) if attn_mask is not None else None

        outputs = []
        weights_list = []

        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            for chunk_idx in range(n_chunks):
                start_idx = chunk_idx * C
                win_start = max(0, start_idx - O)
                win_end = min(T, start_idx + C + O)
                chunk = src_b[:, win_start:win_end]

                if chunk.size(1) < window_len:
                    chunk = F.pad(chunk, (0, 0, 0, window_len - chunk.size(1)))

                out, weights = self.self_attn(
                    chunk,
                    chunk,
                    chunk,
                    attn_mask=mask,
                    need_weights=self.full_attn_logging,
                    average_attn_weights=False,
                )

                # Offset of the chunk's own first position inside its window;
                # smaller than ``O`` when the window was clipped on the left.
                core_start = start_idx - win_start
                core_end = core_start + min(C, T - start_idx)
                outputs.append(out[:, core_start:core_end])

                if keep_weights and weights is not None:
                    weights_list.append(weights[:, :, core_start:core_end])

                del out, weights, chunk
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        seq = torch.cat(outputs, dim=1)

        if weights_list:
            attn_out = torch.cat(weights_list, dim=2)
        else:
            attn_out = torch.empty(0, device=src.device)

        return seq.transpose(0, 1), attn_out

    def _full_attn(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standard full attention for smaller sequences.

        Returns the attention output in ``(T, B, D)`` layout together with
        the per-head attention weights.
        """
        qkv = src.transpose(0, 1)
        attn_output, attn_weights = self.self_attn(
            qkv,
            qkv,
            qkv,
            attn_mask=attn_mask,
            need_weights=True,
            average_attn_weights=False,
        )
        return attn_output.transpose(0, 1), attn_weights

    def forward(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return output and attention map.

        Applies self attention (chunked when ``chunk_size`` is set), then the
        feed-forward block, each followed by a residual connection and layer
        normalisation.  The returned attention weights are detached.
        """
        if self.chunk_size is not None:
            attn_output, attn_weights = self._chunked_attn(src, attn_mask)
        else:
            attn_output, attn_weights = self._full_attn(src, attn_mask)
        src = self.norm1(src + self.dropout1(attn_output))
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff_out))
        return src, attn_weights.detach()
| |
|
| |
|
class ReversibleLoggingTransformerEncoderLayer(nn.Module):
    """Reversible transformer encoder layer with checkpointing.

    The layer operates on two streams ``(x1, x2)``: the attention block is
    applied to ``x2`` and added to ``x1``; the feed-forward block is applied
    to that sum and added to ``x2``.  Tensors use the ``(T, B, D)`` layout.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.chunk_size = chunk_size
        self.overlap = overlap
        # Default: log attention only when attention is not chunked.
        if full_attn_logging is None:
            full_attn_logging = chunk_size is None
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    @compile_fn
    def _sa_block(
        self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Self-attention sub-block with residual connection and layer norm.

        With ``chunk_size`` set, the sequence is padded with ``overlap``
        zeros on the left and enough zeros on the right to form whole chunks,
        and all overlapping windows of ``chunk_size + 2 * overlap`` positions
        are attended in one batched call.  ``attn_mask`` is only checked for
        presence in that mode: when it is not ``None`` a causal mask of the
        window size is used.

        Returns:
            The normalised output in ``(T, B, D)`` layout and detached
            attention weights.  In chunked mode the weights are a
            reconstructed ``(B, heads, T, T)`` matrix when
            ``full_attn_logging`` is on and the chunk is shorter than the
            sequence, and an empty tensor otherwise.
        """
        if self.chunk_size is not None:
            T, B, D = x.shape
            x_b = x.transpose(0, 1)
            C = self.chunk_size or T
            O = self.overlap
            n_heads = self.self_attn.num_heads
            n_chunks = (T + C - 1) // C
            pad_len = n_chunks * C - T
            src_pad = F.pad(x_b, (0, 0, O, pad_len + O))
            chunk_len = C + 2 * O
            # ``unfold`` appends the window axis last, giving
            # (B, n_chunks, D, chunk_len); swap the last two axes so each
            # window is laid out as (chunk_len, D) before flattening.
            chunks = src_pad.unfold(1, chunk_len, C).permute(0, 1, 3, 2)
            flat = chunks.reshape(B * n_chunks, chunk_len, D)
            mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None
            out, weights = self.self_attn(
                flat,
                flat,
                flat,
                attn_mask=mask,
                need_weights=True,
                average_attn_weights=False,
            )
            # Keep only each window's own chunk_size query positions.
            out = out.reshape(B, n_chunks, chunk_len, D)[:, :, O : O + C]
            weights = weights.reshape(B, n_chunks, n_heads, chunk_len, chunk_len)[
                :, :, :, O : O + C
            ]
            seq = out.reshape(B, n_chunks * C, D)[:, :T]
            if self.full_attn_logging and C < T:
                full_attn = torch.zeros(
                    B, n_heads, n_chunks * C, n_chunks * C, device=x.device
                )
                for idx in range(n_chunks):
                    s = idx * C
                    start = max(s - O, 0)
                    end = min(s + C, n_chunks * C)
                    # Window-local key index of absolute position ``start``.
                    src_start = O - (s - start)
                    src_end = src_start + (end - start)
                    # Rows are the chunk's queries; the slice selects keys.
                    full_attn[:, :, s : s + C, start:end] = weights[
                        :, idx, :, :, src_start:src_end
                    ]
                full_attn = full_attn[:, :, :T, :T]
                weights = full_attn.detach()
            else:
                weights = torch.empty(0, device=x.device)
            attn_out = seq.transpose(0, 1)
        else:
            qkv = x.transpose(0, 1)
            attn_out, weights = self.self_attn(
                qkv,
                qkv,
                qkv,
                attn_mask=attn_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            attn_out = attn_out.transpose(0, 1)
        x = self.norm1(x + self.dropout1(attn_out))
        return x, weights.detach()

    @compile_fn
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block with residual connection and layer norm."""
        out = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = self.norm2(x + self.dropout2(out))
        return x

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the updated streams ``(y1, y2)`` and the attention weights."""
        y1, weights = self._sa_block(x2, attn_mask)
        y1 = x1 + y1
        y2 = x2 + self._ff_block(y1)
        return y1, y2, weights
| |
|
| |
|
class BitTransformerLM(nn.Module):
    """Transformer language model that operates on raw bits (0/1) with telemetry."""

    def __init__(
        self,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 512,
        max_seq_len: int = 1024,
        lambda_K: float = 1.0,
        lambda_C: float = 1.0,
        lambda_S: float = 1.0,
        reversible: bool = False,
        use_checkpoint: bool = True,
        use_autocast: bool = False,
        use_act: bool = False,
        act_threshold: float = 0.9,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        """Create a BitTransformer language model.

        Args:
            d_model: Embedding / hidden width.
            nhead: Number of attention heads per layer.
            num_layers: Number of encoder layers.
            dim_feedforward: Width of each layer's feed-forward block.
            max_seq_len: Initial length of the positional-encoding table.
            lambda_K: Weight of the negentropy term in the symbiosis score.
            lambda_C: Weight of the LZ-complexity term in the symbiosis score.
            lambda_S: Weight of the KL term in the symbiosis score.
            reversible: Use the two-stream reversible layer class.
            use_checkpoint: Run every layer through
                ``torch.utils.checkpoint``.
            use_autocast: Run the network under ``cpu_autocast()``.
            use_act: Enable adaptive-computation-time halting.
            act_threshold: Halting probability at which ACT stops early.
            chunk_size: Chunk length for chunked attention, or ``None``.
            overlap: Overlap between neighbouring attention chunks.
            full_attn_logging: When ``False`` and ``chunk_size`` is
                smaller than the sequence length, the model skips
                reconstructing the full ``T×T`` attention matrices for
                telemetry to reduce memory use.  Defaults to ``True`` only
                when ``chunk_size`` is ``None``.
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S
        self.reversible = reversible
        self.use_checkpoint = use_checkpoint
        self.use_autocast = use_autocast
        self.use_act = use_act
        self.act_threshold = act_threshold
        self.chunk_size = chunk_size
        self.overlap = overlap
        if full_attn_logging is None:
            full_attn_logging = chunk_size is None
        self.full_attn_logging = full_attn_logging

        # Two-entry embedding table: one vector per bit value.
        self.embedding = nn.Embedding(2, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len)

        layer_cls = (
            ReversibleLoggingTransformerEncoderLayer
            if reversible
            else LoggingTransformerEncoderLayer
        )
        self.layers = nn.ModuleList(
            [
                layer_cls(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    chunk_size=chunk_size,
                    overlap=overlap,
                    full_attn_logging=full_attn_logging,
                )
                for _ in range(num_layers)
            ]
        )

        if self.use_act:
            # One halting projection per layer.
            self.halt_projs = nn.ModuleList(
                [nn.Linear(d_model, 1) for _ in range(num_layers)]
            )

        self.out_head = nn.Linear(d_model, 2)

    def expand_positional_encoding(self, new_len: int) -> None:
        """Expand positional encoding to at least ``new_len``.

        Existing rows are copied unchanged and only the new positions are
        computed.  The table keeps the dtype and device of the current
        buffer.  Does nothing when the table is already long enough.
        """
        old = self.pos_enc.pe
        cur_len = old.size(0)
        if new_len <= cur_len:
            return
        device = old.device
        d_model = self.d_model
        pe = torch.zeros(new_len, d_model, device=device)
        pe[:cur_len] = old.squeeze(1)
        pos = torch.arange(
            cur_len, new_len, dtype=torch.float32, device=device
        ).unsqueeze(1)
        inv = torch.exp(
            torch.arange(0, d_model, 2, device=device).float()
            * -(math.log(10000.0) / d_model)
        )
        angles = pos * inv
        pe[cur_len:, 0::2] = torch.sin(angles)
        # An odd ``d_model`` has one fewer cosine column than sine column.
        pe[cur_len:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.pos_enc.pe = pe.unsqueeze(1).to(old.dtype)

    def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None:
        """Update weighting coefficients for telemetry metrics."""
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S

    def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor:
        """Return raw bit sequences, decompressing if input appears run-length encoded.

        Input is treated as compressed when it contains a value above ``1``,
        or when it has an even number of columns and, in every row, each
        even-indexed entry differs from the previous even-indexed entry.
        Tensors with at most one dimension are returned unchanged.
        Decompressed rows are right-padded with zeros to a common length.
        """
        if codes.dim() <= 1:
            return codes
        needs_decompress = codes.max().item() > 1
        if not needs_decompress and codes.size(1) % 2 == 0:
            vals = codes[:, 0::2]
            if torch.all(vals[:, 1:] != vals[:, :-1]):
                needs_decompress = True
        if not needs_decompress:
            return codes
        seqs = [decompress_bits(row.to(torch.uint8)) for row in codes]
        max_len = max(seq.numel() for seq in seqs)
        padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs]
        return torch.stack(padded)

    def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor:
        """Approximate negentropy of bit sequences.

        Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered
        sequence (all zeros or ones) and ``0`` reflects maximal entropy.
        """
        codes = self._maybe_decompress(codes)
        p = codes.float().mean(dim=1)
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor:
        """Differentiable proxy for Lempel–Ziv complexity.

        Values near ``0`` indicate highly compressible sequences while values
        approaching ``1`` correspond to rapid bit alternation.  Sequences
        with fewer than two positions have no transitions and score ``0``.
        """
        codes = self._maybe_decompress(codes)
        if codes.size(1) < 2:
            # The mean over zero transitions would be NaN.
            return torch.zeros(codes.size(0), device=codes.device)
        diffs = torch.abs(codes[:, 1:] - codes[:, :-1])
        return diffs.float().mean(dim=1)

    def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """Negentropy computed from model logits.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        p = prob[..., 1].mean(dim=1)
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """LZ complexity proxy computed from logits.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.

        Sequences with fewer than two positions score ``0``.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        prob1 = prob[..., 1]
        if prob1.size(1) < 2:
            # The mean over zero transitions would be NaN.
            return prob1.new_zeros(prob1.size(0))
        diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1])
        return diffs.mean(dim=1)

    def symbiosis_kl_logits(
        self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True
    ) -> torch.Tensor:
        """Symbiosis score from KL divergence to a reference distribution.

        Returns a value in ``[0, 1]`` with ``1`` meaning perfect agreement with
        the reference distribution and ``0`` indicating maximal divergence.
        The normalising constant is ``log 2``.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        probs = logits.softmax(-1)
        if detach:
            probs = probs.detach()
        ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device)
        kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1)
        max_kl = math.log(2.0)
        return 1 - kl / max_kl

    def _act_step(
        self,
        hidden: torch.Tensor,
        idx: int,
        halt_prob: torch.Tensor,
        act_state: torch.Tensor,
        halt_history: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Apply one step of ACT halting logic.

        Returns the updated halting probability, the updated accumulated
        state, and whether the smallest halting probability has reached
        ``act_threshold``.  A detached copy of the halting probability is
        appended to ``halt_history``.
        """
        p = torch.sigmoid(self.halt_projs[idx](hidden))
        delta = (1 - halt_prob) * p
        halt_prob = halt_prob + delta
        act_state = act_state + hidden * delta
        halt_history.append(halt_prob.detach())
        min_prob = halt_prob.detach().min()
        if dist.is_initialized():
            # Use the minimum across ranks so all ranks take the same branch.
            dist.all_reduce(min_prob, op=dist.ReduceOp.MIN)
        return halt_prob, act_state, min_prob.item() >= self.act_threshold

    def forward(
        self, bit_seq: torch.Tensor, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass returning logits and telemetry from the same graph.

        By default the model uses causal masking and (optional) chunked
        attention. When ``causal`` is ``False`` the model operates in
        "Diffusion LM" mode. In this mode chunked attention is temporarily
        disabled so that every token can attend to the full sequence
        bidirectionally. The original chunking configuration is restored after
        the forward pass.

        Returns logits of shape ``(B, T, 2)`` and a telemetry dictionary.
        """
        # Temporarily switch chunking off for bidirectional passes.
        orig_chunks = None
        orig_model_chunk = None
        if not causal and self.chunk_size is not None:
            orig_model_chunk = self.chunk_size
            orig_chunks = [layer.chunk_size for layer in self.layers]
            self.chunk_size = None
            for layer in self.layers:
                layer.chunk_size = None

        try:
            ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext()
            with ctx:
                # Layers work in (T, B, D) layout.
                x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model)
                x = self.pos_enc(x)

                attn_mask = get_tri_mask(x.size(0), x.device) if causal else None

                activations: List[torch.Tensor] = []
                attn_maps: List[torch.Tensor] = []
                halt_history: List[torch.Tensor] = []
                if self.use_act:
                    halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device)
                    act_state = torch.zeros_like(x)
                if self.reversible:
                    x1, x2 = x, x
                    for idx, layer in enumerate(self.layers):
                        if self.use_checkpoint:
                            x1, x2, attn = checkpoint.checkpoint(
                                layer, x1, x2, attn_mask
                            )
                        else:
                            x1, x2, attn = layer(x1, x2, attn_mask)
                        combined = (x1 + x2) / 2
                        activations.append(combined)
                        # Layers return an empty tensor when they log nothing.
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                combined, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break
                    x = (x1 + x2) / 2
                else:
                    for idx, layer in enumerate(self.layers):
                        if self.use_checkpoint:
                            x, attn = checkpoint.checkpoint(layer, x, attn_mask)
                        else:
                            x, attn = layer(x, attn_mask)
                        activations.append(x)
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                x, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break
                if self.use_act:
                    # Assign the remaining probability mass to the last state.
                    act_state = act_state + x * (1 - halt_prob)
                    x = act_state
                logits = self.out_head(x)

            # Mean softmax entropy of each layer's activations.
            entropies = []
            for act in activations:
                prob = act.softmax(-1)
                ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean()
                entropies.append(ent)

            # Entropy of each attention row, averaged over heads and layers.
            attn_entropies = []
            for attn in attn_maps:
                ent = -(attn * attn.clamp_min(1e-9).log()).sum(-1)
                attn_entropies.append(ent.mean(1))
            if attn_entropies:
                attn_entropy_map = torch.stack(attn_entropies).mean(0)
            else:
                attn_entropy_map = torch.zeros(
                    bit_seq.size(0), bit_seq.size(1), device=bit_seq.device
                )
            n_positions = attn_entropy_map.size(-1)
            if n_positions > 1:
                # With a single position log(1) == 0; the entropy is already
                # zero there, so the normalisation is skipped.
                attn_entropy_map = attn_entropy_map / math.log(n_positions)
            attn_entropy = attn_entropy_map.mean(1)

            logits_bt = logits.transpose(0, 1)
            negentropy_in = self.negentropy_kpi(bit_seq)
            lz_in = self.lz_complexity(bit_seq.float())
            negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False)
            lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False)
            kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False)

            raw_sym = (
                (self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2
                + negentropy_logits_b * lz_logits_b
                - self.lambda_S * kl_div_b
                - 0.1 * attn_entropy
            )
            weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach()
            raw_sym = raw_sym - 0.01 * weight_norm
            sym_score = torch.sigmoid(raw_sym)

            B, T = bit_seq.shape
            assert logits_bt.shape[:2] == (B, T)
            assert attn_entropy_map.shape == (B, T)

            telemetry = {
                "activations": activations,
                "attention_maps": attn_maps,
                "attention_entropy": attn_entropy_map,
                "entropy": entropies,
                "attention_entropy_mean": attn_entropy,
                "negentropy_input": negentropy_in.detach(),
                "lz_complexity_input": lz_in.detach(),
                "negentropy_logits": negentropy_logits_b.detach(),
                "lz_complexity_logits": lz_logits_b.detach(),
                "symbiosis_kl": kl_div_b.detach(),
                "symbiosis_score": sym_score.detach(),
            }
            if self.use_act:
                telemetry["halt_probs"] = halt_history

            return logits_bt, telemetry
        finally:
            # Restore chunking even if the pass raised.
            if orig_chunks is not None:
                self.chunk_size = orig_model_chunk
                for layer, chunk in zip(self.layers, orig_chunks):
                    layer.chunk_size = chunk

    def forward_compressed(
        self, compressed_bits, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Decompress bit sequences then run the normal forward pass.

        Accepts a single 1-D tensor or an iterable of compressed sequences.

        Raises:
            ValueError: If the sequences decompress to different lengths.
        """
        if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1:
            sequences = [decompress_bits(compressed_bits).to(torch.long)]
        else:
            sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits]
        lengths = [seq.numel() for seq in sequences]
        if len(set(lengths)) != 1:
            raise ValueError("Sequences decompress to different lengths")
        bits = torch.stack(sequences)
        return self.forward(bits, causal=causal)

    def _current_params(self) -> Dict:
        """Return a dictionary with the current model hyperparameters.

        The keys match the keyword arguments of ``__init__``.
        """
        return {
            "d_model": self.d_model,
            "nhead": self.layers[0].self_attn.num_heads,
            "num_layers": self.num_layers,
            "dim_feedforward": self.layers[0].linear1.out_features,
            "max_seq_len": self.pos_enc.pe.size(0),
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "reversible": self.reversible,
            "use_checkpoint": self.use_checkpoint,
            "use_autocast": self.use_autocast,
            "use_act": self.use_act,
            "act_threshold": self.act_threshold,
            "chunk_size": self.chunk_size,
            "overlap": self.overlap,
            # Without this a rebuilt model would fall back to the
            # chunk-size-dependent default instead of the current setting.
            "full_attn_logging": self.full_attn_logging,
        }

    def double_width(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled hidden size."""
        from .scale import expand_model

        params = self._current_params()
        params["d_model"] *= 2
        params["dim_feedforward"] *= 2
        return expand_model(self, params)

    def double_layers(self) -> "BitTransformerLM":
        """Return a copy of the model with twice as many layers."""
        from .scale import expand_model

        params = self._current_params()
        params["num_layers"] *= 2
        return expand_model(self, params)

    def double_length(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled maximum sequence length.

        ``chunk_size`` is set to the new maximum length.
        """
        from .scale import expand_model

        params = self._current_params()
        params["max_seq_len"] *= 2
        params["chunk_size"] = params["max_seq_len"]
        return expand_model(self, params)

    def train_full_sequence(
        self,
        bits: torch.Tensor,
        *,
        ctx_bits: int = 4096,
        detach_every_n: int = 1_048_576,
    ) -> float:
        """Train on a long bit tensor using sliding windows.

        Parameters
        ----------
        bits: ``torch.Tensor``
            1D tensor containing the full bit sequence.
        ctx_bits: int
            Size of the training context window.
        detach_every_n: int
            Interval in bits for optimizer updates and graph detachment.
        Returns
        -------
        float
            Mean loss over all windows.
        """
        self.train()
        optimizer, scheduler = configure_optimizer(
            self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits)
        )

        def apply_update() -> None:
            """Clip the accumulated gradients and take one optimizer step."""
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        accum = 0
        total_loss = 0.0
        count = 0
        # Each window needs ``ctx_bits + 1`` bits (inputs plus the shifted
        # target), so a start is valid while ``start < numel - ctx_bits``.
        for start in range(0, bits.numel() - ctx_bits, ctx_bits):
            segment = bits[start : start + ctx_bits + 1].unsqueeze(0)
            logits, _ = self(segment)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = segment[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)
            loss.backward()
            accum += ctx_bits
            total_loss += loss.item()
            count += 1
            if accum >= detach_every_n:
                apply_update()
                accum = 0
        if accum > 0:
            # Flush gradients left over from the final partial interval.
            apply_update()
        return total_loss / max(1, count)
| |
|
| |
|
def infer_long_sequence(
    model: "BitTransformerLM",
    bits: torch.Tensor,
    *,
    ctx_bits: int = 4096,
    overlap: int = 256,
) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
    """Infer a long bit sequence using sliding windows with overlap.

    Windows of ``ctx_bits`` positions advance by ``ctx_bits - overlap``.  For
    every window after the first, the leading ``overlap`` predictions are
    dropped because those positions were already produced by the previous
    window; the kept predictions therefore line up one-to-one with ``bits``.

    Parameters
    ----------
    model: ``BitTransformerLM``
        Model to run; it is switched to eval mode.
    bits: ``torch.Tensor``
        1D tensor with the full bit sequence.
    ctx_bits: int
        Window length.
    overlap: int
        Number of positions shared by consecutive windows.

    Returns
    -------
    tuple
        The per-position argmax predictions (same length as ``bits``) and the
        list of telemetry dictionaries, one per evaluated window.

    Raises
    ------
    ValueError
        If ``ctx_bits`` is not positive or ``overlap`` is outside
        ``[0, ctx_bits)``.
    """
    if ctx_bits <= 0:
        raise ValueError("ctx_bits must be positive")
    if not 0 <= overlap < ctx_bits:
        raise ValueError("overlap must satisfy 0 <= overlap < ctx_bits")

    model.eval()
    device = next(model.parameters()).device
    bits = bits.to(device)
    total = bits.numel()
    step = ctx_bits - overlap
    outputs: List[torch.Tensor] = []
    logs: List[Dict[str, torch.Tensor]] = []
    # Inference only: avoid building autograd graphs for every window.
    with torch.no_grad():
        for start in range(0, total, step):
            window = bits[start : start + ctx_bits].unsqueeze(0)
            logits, tele = model(window, causal=True)
            pred = logits.argmax(-1).squeeze(0)
            outputs.append(pred if start == 0 else pred[overlap:])
            logs.append(tele)
            if start + ctx_bits >= total:
                # This window reached the end; later ones add nothing new.
                break
    if not outputs:
        return torch.empty(0, dtype=torch.long, device=device), logs
    out = torch.cat(outputs)[:total]
    return out, logs
| |
|
| |
|
def diffusion_inference(
    model: "BitTransformerLM",
    *,
    length: int,
    steps: int = 8,
    batch_size: int = 1,
    init_bits: Optional[torch.Tensor] = None,
    schedule: str = "linear",
) -> torch.Tensor:
    """Generate bit sequences using iterative denoising diffusion.

    Parameters
    ----------
    model: ``BitTransformerLM``
        The model used for denoising. It is run in non-causal mode with
        chunked attention disabled, enabling full-context bidirectional
        attention.
    length: int
        Length of the bit sequences to generate.
    steps: int, default ``8``
        Number of denoising iterations. More steps generally yield sharper
        samples at the cost of compute.
    batch_size: int, default ``1``
        Number of sequences to generate in parallel.
    init_bits: ``torch.Tensor`` | ``None``
        Optional initial noisy bits of shape ``(batch_size, length)``. When
        ``None`` random noise is used.
    schedule: str, default ``"linear"``
        Noise schedule for the denoising mask probability. Options are
        ``"linear"``, ``"cosine"``, and ``"exp"``.

    Returns
    -------
    ``torch.Tensor``
        A tensor of shape ``(batch_size, length)`` containing generated bits.

    Raises
    ------
    ValueError
        If ``init_bits`` has the wrong shape or ``schedule`` is unknown.
    """
    # Resampling probability as a function of progress ``t`` in (0, 1].
    schedules = {
        "linear": lambda t: 1.0 - t,
        "cosine": lambda t: math.cos(math.pi * t / 2),
        "exp": lambda t: math.exp(-5 * t),
    }

    model.eval()
    device = next(model.parameters()).device
    if init_bits is None:
        bits = torch.randint(0, 2, (batch_size, length), device=device)
    else:
        bits = init_bits.to(device=device, dtype=torch.long)
    if bits.shape != (batch_size, length):
        raise ValueError("init_bits must have shape (batch_size, length)")

    # Reject a bad schedule name before spending a model pass on it.
    mask_prob_fn = schedules.get(schedule)
    if mask_prob_fn is None:
        raise ValueError(f"unknown schedule: {schedule}")

    # Sampling only: no autograd graph is needed for the denoising passes.
    with torch.no_grad():
        for step in range(steps):
            logits, _ = model(bits, causal=False)
            prob = logits.softmax(-1)[..., 1]
            mask_prob = mask_prob_fn((step + 1) / steps)
            # Positions selected here are replaced by fresh samples.
            resample = torch.rand_like(bits.float()) < mask_prob
            sampled = torch.bernoulli(prob).long()
            bits = torch.where(resample, sampled, bits)

    if bits.shape[-1] % 9 == 0:
        bits, corrections = enforce_parity(bits)
        if corrections:
            logging.info("Parity corrections applied: %d", corrections)
    try:
        from .safety import hil_safe_inference

        hil_safe_inference(model, bits, causal=False, strict=False)
    except RuntimeError as exc:
        # The gate is advisory here: report the failure and return the bits.
        logging.warning("Safety gate warning: %s", exc)
    return bits
| |
|
| |
|
def example_usage() -> float:
    """Run the example from the README and return the loss.

    Builds a small model, feeds it one random batch of bits and returns the
    next-bit cross-entropy as a Python float.
    """
    batch, seq_len = 4, 16
    lm = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=seq_len,
    )
    sample = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, _ = lm(sample)
    # The prediction at position t is scored against the bit at t + 1.
    next_bit_logits = logits[:, :-1, :].reshape(-1, 2)
    next_bits = sample[:, 1:].reshape(-1)
    return F.cross_entropy(next_bit_logits, next_bits).item()
| |
|
| |
|
def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]:
    """Demonstrate a training step where metrics do not affect gradients.

    Only the cross-entropy loss is back-propagated; the telemetry returned by
    the model is passed through to the caller untouched.
    """
    batch, seq_len = 4, 16
    lm = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=seq_len,
    )
    optimizer, scheduler = configure_optimizer(lm, lr=1e-3, total_steps=1)

    sample = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, telemetry = lm(sample)

    # The prediction at position t is scored against the bit at t + 1.
    next_bit_logits = logits[:, :-1, :].reshape(-1, 2)
    next_bits = sample[:, 1:].reshape(-1)
    loss = F.cross_entropy(next_bit_logits, next_bits)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(lm.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    return loss.item(), telemetry
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: run one demo training step and report the result.
    demo_loss, demo_telemetry = example_training_step()
    print("Composite loss:", demo_loss)
    print("Telemetry keys:", list(demo_telemetry.keys()))
| |
|