"""ARB — Any Relational Bit. Core model assembly."""
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import ceil as _ceil

from .config import (
    VOCAB, HIDDEN_DIM, SPECIAL_VOCAB, CTX, THRESHOLD, CODEBOOK_DIM, CODEBOOK_SIZE,
    KV_LEDGER_SIZE, KQ_CACHE_SIZE, MEMGRAM_STRUCT_PRIMES, MEMGRAM_CONV_PRIMES,
    MEMGRAM_EMBED_DIM, MEMGRAM_KEY_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM,
    K_MAX_COMPOSITES, MG_TOP_K,
)
from .kernel.ternary_scale import TScaleType, TernaryScaleTensor, TernaryRMSNorm, _HAS_TRITON
try:
    from .kernel.ternary_scale import _triton_apply_accumulated_flips
except ImportError:
    _triton_apply_accumulated_flips = None
from .converters.convert_to_ternary8 import pack_ternary
try:
    from .kernel.ternary_scale import _TritonTernaryEmbedFn
except ImportError:
    _TritonTernaryEmbedFn = None
from .sequencers import ByteEmbedding, MultimodalSequencer
from .vq import SharedVQ
from .components import (
    ByteHead, OutputRouter, MemGram, LossComponents, LossWeights,
    CompositeProposalHead, MoEGraph,
)
from .decoders import VideoHead, TalkerHead
from .components import _BOUNDARY_TOKEN_MAP as _BOUNDARY_MAP
from .attention import KVLedger, KQCache, ContextAttentionScheduler
from .kernel.flash_vq import FlashVQCodebook


def _ceil_div(a, b):
    """Return ceil(a / b) for integers, or 0 when `b` is not positive.

    Integer floor-division keeps the result exact for large ints, unlike a
    float-based ``ceil(a / b)``.
    """
    return -(-a // b) if b > 0 else 0


def _extract_boundary_from_input(x):
    """Return the first boundary token found in row 0 of a 2-D token tensor.

    Args:
        x: Token tensor. Only ``x[0]`` (the first row) is scanned.

    Returns:
        The first value of ``x[0]`` that is a key of ``_BOUNDARY_MAP``, or
        None when ``x`` is not 2-D, is empty, or contains no boundary token.
    """
    if x.dim() != 2 or x.numel() == 0:
        return None
    for tok in x[0].tolist():
        if tok in _BOUNDARY_MAP:
            return tok
    return None


class ARBModel(nn.Module):
    """Byte-level multimodal model assembled from the project's components.

    The forward pass runs:

    1. byte embedding;
    2. the multimodal sequencer (text, plus optional image/audio);
    3. an optional shared VQ bridge;
    4. optional MemGram injection;
    5. an optional MoEGraph stage conditioned on attention over a KV ledger;
    6. an output head — the byte head, the video head, or the talker head.

    Sub-modules that are disabled by constructor flags are stored as None.
    """

    def __init__(self, tscale_type=TScaleType.T32, threshold=THRESHOLD,
                 max_graph_hops=4, max_moe_iters=4, halt_threshold=0.99,
                 enable_image=False, enable_audio=False, enable_vq=True,
                 enable_graph=True, enable_memory_modules=False, enable_moe=True,
                 shared_vq_size=None, kgvq_codebook_size=None,
                 enable_attention=True, enable_output_router=True,
                 enable_video_output=True, enable_talker_output=True):
        """Build the sub-modules selected by the ``enable_*`` flags.

        Notes:
            ``max_graph_hops`` and ``enable_moe`` are accepted but not used by
            this constructor.

            The graph stage (MoEGraph and CompositeProposalHead) is only built
            when both ``enable_graph`` and ``enable_vq`` are true.
        """
        super().__init__()
        self.image_enabled = enable_image
        self.audio_enabled = enable_audio

        self.embedding = ByteEmbedding(tscale_type=tscale_type)
        self.multimodal_sequencer = MultimodalSequencer(
            tscale_type=tscale_type,
            enable_text=True,
            enable_image=enable_image,
            enable_audio=enable_audio,
        )
        self.text_sequencer = self.multimodal_sequencer.text
        self.image_sequencer = self.multimodal_sequencer.image
        self.audio_sequencer = self.multimodal_sequencer.audio

        self.vq_enabled = enable_vq
        self.bridge = SharedVQ(
            codebook_size=shared_vq_size,
            tscale_type=tscale_type,
            enable_image=enable_image,
            enable_audio=enable_audio,
        ) if enable_vq else None
        # Projects CODEBOOK_DIM-wide VQ output back to HIDDEN_DIM (used in forward).
        self.vq_to_trigram = TernaryScaleTensor(CODEBOOK_DIM, HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
        self.vq_to_trigram_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None

        self.graph_enabled = enable_graph and enable_vq
        graph_vocab_size = self.bridge.total_codebook_size if self.graph_enabled else None
        self.threshold = threshold
        self.moegraph = MoEGraph(
            trigram_dim=HIDDEN_DIM,
            codebook_size=graph_vocab_size or CODEBOOK_SIZE,
            max_iters=max_moe_iters,
            halt_threshold=halt_threshold,
            top_k=MG_TOP_K,
        ) if self.graph_enabled else None

        self.byte_head = ByteHead(tscale_type=tscale_type)
        # Composite motif generation (Phase 17)
        self.composite_head = CompositeProposalHead(
            dim=HIDDEN_DIM,
            codebook_dim=KGVQ_CODEBOOK_DIM,
            k_max=K_MAX_COMPOSITES,
            codebook_size=kgvq_codebook_size or KGVQ_CODEBOOK_SIZE,
            tscale_type=tscale_type,
        ) if self.graph_enabled else None
        self.output_router = OutputRouter(tscale_type=tscale_type, depth=3) if enable_output_router else None
        self.video_head = VideoHead(tscale_type=tscale_type) if enable_video_output else None
        self.talker_head = TalkerHead(tscale_type=tscale_type) if enable_talker_output else None

        self.memgram = MemGram(
            struct_primes=MEMGRAM_STRUCT_PRIMES,
            conv_primes=MEMGRAM_CONV_PRIMES,
            embed_dim=MEMGRAM_EMBED_DIM,
            key_dim=MEMGRAM_KEY_DIM,
            hidden_dim=HIDDEN_DIM,
        ) if enable_memory_modules else None
        self.memgram_enabled = self.memgram is not None

        # KV Ledger + Attention (Phase 16 — replaces LSTM)
        self.kv_ledger = KVLedger(max_size=KV_LEDGER_SIZE) if enable_attention else None
        self.kq_cache = KQCache(max_size=KQ_CACHE_SIZE) if enable_attention else None
        self.attention = ContextAttentionScheduler(dim=HIDDEN_DIM) if enable_attention else None
        self.attention_enabled = bool(enable_attention)

    def forward(self, x, targets=None, commitment_warmup_weight=1.0,
                act_warmup_mode=False, ponder_lambda=0.01, images=None,
                audio=None, timestep=0, loss_weights=None, output_mode=None):
        """Run the full pipeline and optionally compute the training losses.

        Args:
            x: Byte token ids. Assumed to be (batch, seq); TODO confirm.
            targets: Next-byte targets. These are compared against
                ``logits[:, :-1]``, so they presumably have length ``seq - 1``;
                verify against the training loop. When given, the byte head is
                always used.
            commitment_warmup_weight: Multiplier applied to the VQ loss.
            act_warmup_mode: Accepted but unused here.
            ponder_lambda: Accepted but unused here.
            images: Optional image input. Requires ``enable_image=True``.
            audio: Optional audio input. Requires ``enable_audio=True``.
            timestep: Forwarded to the VQ bridge.
            loss_weights: ``LossWeights`` for the returned ``LossComponents``.
                Defaults to ``LossWeights()``.
            output_mode: One of ``"text"``, ``"video"``, ``"audio"`` or
                ``"talker"``. When None, the output router (if any) picks the
                head.

        Returns:
            ``(logits, losses, all_indices, None)``. ``losses`` is None without
            targets. ``all_indices`` is None unless the graph stage ran.

        Raises:
            ValueError: A modality or output head was requested but is disabled.
        """
        has_image = images is not None
        has_audio = audio is not None
        if has_image and (not self.image_enabled or self.image_sequencer is None):
            raise ValueError("images provided but model has enable_image=False")
        if has_audio and (not self.audio_enabled or self.audio_sequencer is None):
            raise ValueError("audio provided but model has enable_audio=False")

        # --- Sequencing --------------------------------------------------
        embedded = self.embedding(x)
        seq_inputs = {'text': embedded}
        if has_image:
            seq_inputs['image'] = images
        if has_audio:
            seq_inputs['audio'] = audio
        seq_outputs = self.multimodal_sequencer(seq_inputs)
        relational = seq_outputs['text']

        # --- Shared VQ bridge --------------------------------------------
        indices_dict = {}
        if self.vq_enabled:
            bridge_inputs = {'text': relational}
            if 'image' in seq_outputs:
                bridge_inputs['image'] = seq_outputs['image']
            if 'audio' in seq_outputs:
                bridge_inputs['audio'] = seq_outputs['audio']
            combined, vq_losses, indices_dict = self.bridge(bridge_inputs, timestep=timestep)
            if combined is None:
                combined = relational
            elif combined.shape[-1] == CODEBOOK_DIM:
                # Codebook-width output is projected back to the hidden width.
                combined = self.vq_to_trigram_norm(self.vq_to_trigram(combined))
            vq_loss = vq_losses.get('text_vq', torch.zeros((), device=x.device))
            if 'image_vq' in vq_losses:
                vq_loss = vq_loss + vq_losses['image_vq']
            if 'audio_vq' in vq_losses:
                vq_loss = vq_loss + vq_losses['audio_vq']
        else:
            combined = relational
            vq_loss = torch.zeros((), device=x.device)

        # MemGram injection (after VQ, before Graph — D92)
        memgram_decay_reg = torch.tensor(0.0, device=x.device)
        if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
            vq_indices = indices_dict.get(
                'text',
                torch.zeros(combined.shape[0], combined.shape[1], dtype=torch.long, device=x.device),
            )
            combined = self.memgram(vq_indices=vq_indices, hidden_state=combined)

        # --- Graph stage -------------------------------------------------
        all_indices = None
        composite_ids = None
        composite_vq_loss = None
        processed = combined
        moegraph_ponder_loss = torch.tensor(0.0, device=x.device)
        if self.graph_enabled and self.moegraph is not None and self.vq_enabled and vq_loss is not None:
            # MoEGraph reads the live VQ codebook table each call.
            self.moegraph._codebook_table = self.bridge.vq.table
            self.moegraph._codebook_embed = None
            all_indices = indices_dict.get(
                'text',
                combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long),
            )
            if has_image and 'image' in indices_dict:
                all_indices = torch.cat([all_indices, indices_dict['image']], dim=1)
            if has_audio and 'audio' in indices_dict:
                all_indices = torch.cat([all_indices, indices_dict['audio']], dim=1)

            # MemGram retrieval for MoEGraph injection
            memgram_cb = None
            if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
                vq_idx = indices_dict.get(
                    'text',
                    combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long),
                )
                memgram_cb = self.memgram.retrieve_cb(vq_idx)

            # Attention output for KV conditioning
            attn_out = None
            if self.attention_enabled and self.attention is not None and self.kv_ledger is not None:
                attn_out = self.attention(combined, self.kv_ledger, kq_cache=self.kq_cache)

            # MoEGraph forward (unified ACT loop)
            processed, moegraph_ponder_loss = self.moegraph(
                combined,
                all_indices,
                attention_output=attn_out,
                memgram_cb_output=memgram_cb,
                threshold=self.threshold,
            )

            # Composite motif generation (Phase 17)
            if self.composite_head is not None:
                composite_ids, composite_vq_loss, _ = self.composite_head(processed.mean(dim=1))

            # Update bounded int-only KG co-occurrence state.
            self.moegraph.update_kg_edges(all_indices)

        # --- Output head selection ---------------------------------------
        if targets is not None or output_mode == "text":
            logits = self.byte_head(processed)
        elif output_mode == "video":
            if self.video_head is None:
                raise ValueError("output_mode='video' requested but video output is disabled")
            logits = self.video_head(processed)
        elif output_mode in {"audio", "talker"}:
            if self.talker_head is None:
                raise ValueError("audio/talker output requested but talker output is disabled")
            logits = self.talker_head(processed)
        elif self.training and self.output_router is not None:
            # The router runs in training mode, but its outputs are not consumed
            # here; the byte head is always used.
            _route_weights, _route_logits = self.output_router(processed, training=True)
            logits = self.byte_head(processed)
        elif self.output_router is not None:
            route = self.output_router(processed, training=False)
            if isinstance(route, torch.Tensor) and route.numel() > 0:
                # Route id 2 selects video and 3 selects talker; video wins if both appear.
                use_video = (route == 2).any() and self.video_head is not None
                use_talk = (route == 3).any() and self.talker_head is not None
                if use_video:
                    logits = self.video_head(processed)
                elif use_talk:
                    logits = self.talker_head(processed)
                else:
                    logits = self.byte_head(processed)
            else:
                logits = self.byte_head(processed)
        else:
            logits = self.byte_head(processed)

        # Byte logits are trimmed to the text length. The greedy predictions
        # (and any composites) are then recorded in the KV ledger.
        T_text = relational.shape[1]
        if logits.dim() == 3 and logits.shape[-1] == VOCAB:
            logits = logits[:, :T_text, :]
            with torch.no_grad():
                self._append_predictions_to_kv(logits.argmax(dim=-1), composite_ids=composite_ids)

        # --- Losses ------------------------------------------------------
        losses = None
        if targets is not None:
            next_byte_logits = logits[:, :-1, :].contiguous()
            lm_loss = F.cross_entropy(
                next_byte_logits.view(-1, VOCAB),
                targets.contiguous().view(-1),
                ignore_index=SPECIAL_VOCAB["PAD"],
            )
            vq_component = commitment_warmup_weight * vq_loss if self.vq_enabled else None
            losses = LossComponents(
                lm=lm_loss,
                vq_commitment=vq_component,
                graph_l1=None,
                moegraph_ponder=moegraph_ponder_loss,
                memgram_decay_reg=memgram_decay_reg if self.memgram_enabled else None,
                composite_vq=(
                    composite_vq_loss
                    if self.composite_head is not None and composite_ids is not None
                    else None
                ),
                weights=loss_weights if loss_weights is not None else LossWeights(),
            )
        return logits, losses, all_indices, None

    @torch.no_grad()
    def _append_predictions_to_kv(self, pred_ids, composite_ids=None):
        """Append predicted token ids (and composite ids) to the KV ledger.

        Every ``pred_ids[b, t]`` is appended to both the KV ledger and the KQ
        cache. After each batch row's tokens, that row's non-negative composite
        ids are appended to the KV ledger only, offset by the bridge's total
        codebook size when VQ is enabled.

        No-op when either the ledger or the cache is disabled.
        """
        if self.kv_ledger is None or self.kq_cache is None:
            return
        # One host transfer each, instead of a device sync per scalar element.
        pred_rows = pred_ids.tolist()
        comp_rows = composite_ids.tolist() if composite_ids is not None else None
        composite_offset = (
            self.bridge.total_codebook_size
            if self.vq_enabled and self.bridge is not None
            else 0
        )
        for b, row in enumerate(pred_rows):
            for token in row:
                token_id = int(token)
                self.kv_ledger.append(token_id)
                self.kq_cache.append(token_id)
            if comp_rows is None:
                continue
            for cid in comp_rows[b]:
                cid = int(cid)
                if cid >= 0:
                    self.kv_ledger.append(composite_offset + cid)

    def _ternary_update_memory(self, accum_threshold=8, update_scales=True,
                               loss_components=None, loss_signal=None):
        """Finalize one ternary update step after ``backward``.

        If the loss signal (``loss_components.total``, or ``loss_signal``) is
        non-finite, the step is skipped. In that case a RuntimeWarning is
        issued, hooks are cleared and gradients are zeroed.

        Otherwise, with ``loss_components`` the per-component backward path
        runs; without them the regular hook path runs. Either way, the hook
        attributes and streaming flags are cleared afterwards.

        ``loss_signal`` may be a tensor or a plain Python number.
        """
        signal = loss_components.total if loss_components is not None else loss_signal
        t_step = self._ternary_t_step(signal)
        if signal is not None and not torch.isfinite(torch.as_tensor(signal).detach()).all():
            warnings.warn("Non-finite loss detected — skipping ternary state update", RuntimeWarning, stacklevel=2)
            self._clear_ternary_hooks()
            self.zero_grad(set_to_none=True)
            return
        if loss_components is not None:
            self._componentwise_ternary_backward(loss_components, t_step, update_scales, accum_threshold)
        else:
            self._apply_regular_ternary_hooks(accum_threshold, update_scales, t_step, loss_signal)
        self._clear_ternary_hooks()
        self._clear_backward_update_flags()

    def prepare_ternary_backward(self, loss_signal=None, update_scales=True):
        """Configure streaming CUDA ternary updates before `loss.backward()`.

        BigInt-scaled dense linear backward accumulates directly into int64
        `corr_accum`, while legacy sparse tables still use int8 `T_accum`.
        Calling this before backward lets the streaming path use the same
        loss-scaled step that `_ternary_update_memory()` will finalize.
        """
        t_step = self._ternary_t_step(loss_signal)
        for module in self.modules():
            if hasattr(module, "T_accum") or hasattr(module, "corr_accum"):
                module._backward_t_accum_step = t_step
                module._backward_update_scales = bool(update_scales)
                module._stream_backward_updates = True

    def _clear_backward_update_flags(self):
        """Remove the per-module streaming flags set for a backward pass."""
        for module in self.modules():
            for attr in (
                "_backward_t_accum_step",
                "_backward_update_scales",
                "_stream_backward_updates",
                "_streamed_ternary_backward",
                "_streamed_bigint_backward",
            ):
                if hasattr(module, attr):
                    delattr(module, attr)

    @staticmethod
    def _ternary_t_step(loss_signal):
        """Return the accumulator step size. It is currently always 1; `loss_signal` is ignored."""
        return 1

    def _clear_ternary_hooks(self):
        """Delete captured hook tensors from every module.

        This covers both the base names and their per-component
        ``<base>_<name>`` variants.
        """
        base_names = (
            "_hook_grad_T_sign",
            "_hook_grad_2d",
            "_hook_x_2d",
            "_hook_T",
            "_hook_sparse_indices",
            "_hook_sparse_grad_sign",
            "_hook_sparse_T",
        )
        component_prefixes = tuple(f"{base}_" for base in base_names)
        for module in self.modules():
            if hasattr(module, "_T_accum_fp"):
                delattr(module, "_T_accum_fp")
            for hook_name in base_names:
                if hasattr(module, hook_name):
                    delattr(module, hook_name)
            for hook_name in list(vars(module).keys()):
                if hook_name.startswith(component_prefixes):
                    delattr(module, hook_name)

    def _componentwise_ternary_backward(self, loss_components, t_step, update_scales, accum_threshold):
        """Back-propagate each active loss component separately, then apply flips.

        Only scalar components that require grad and have a non-zero weight
        take part. ``_COMPONENT_CONTEXT`` is set for the duration of each
        backward, and the hooks that backward captured are consumed
        immediately.

        Afterwards, every module except large sparse embeddings gets an
        E-step (when ``update_scales``) and accumulated trit flips.
        """
        from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT

        self.prepare_ternary_backward(loss_components.total, update_scales=update_scales)
        active = [
            (n, t, w)
            for n, t, w in loss_components.active_fields
            if t is not None and t.dim() == 0 and t.requires_grad and float(w) != 0.0
        ]
        for idx, (name, comp_tensor, weight) in enumerate(active):
            # Keep the graph alive for every component except the last.
            retain = idx < len(active) - 1
            _COMPONENT_CONTEXT.set(name, weight)
            try:
                comp_tensor.backward(retain_graph=retain)
            finally:
                _COMPONENT_CONTEXT.clear()
            self._consume_component_hooks(name, weight, t_step, update_scales, accum_threshold)

        with torch.no_grad():
            for module in self.modules():
                if self._is_large_sparse_embedding(module):
                    continue
                if update_scales:
                    self._step_E_from_accum(module)
                self._apply_accumulated_flips(module, accum_threshold=accum_threshold)

    def _consume_component_hooks(self, name, weight, t_step, update_scales, accum_threshold):
        """Apply, then delete, the hooks captured for loss component `name`.

        Three hook flavours are handled per module:

        * Sparse (``_hook_sparse_*_<name>``): the hooks are copied to the base
          names, then ``update_E`` and ``ternary_step`` run when present.
          Modules handled this way skip the other two flavours.
        * Dense sign (``_hook_grad_T_sign_<name>``): the sign tensor is
          accumulated via ``_accumulate_component_grad_continuous``.
        * Dense 2-D (``_hook_grad_2d_<name>`` / ``_hook_x_2d_<name>``): the
          product ``grad.T @ x``, clamped to ±10, is accumulated when both
          inputs are finite.
        """
        sparse_idx_key = f"_hook_sparse_indices_{name}"
        sparse_grad_key = f"_hook_sparse_grad_sign_{name}"
        sparse_t_key = f"_hook_sparse_T_{name}"
        dense_key = f"_hook_grad_T_sign_{name}"
        dense_t_key = f"_hook_T_{name}"
        grad_key = f"_hook_grad_2d_{name}"
        x_key = f"_hook_x_2d_{name}"
        step_for_weight = max(1, int(round(abs(float(weight)) * t_step)))

        for module in self.modules():
            if hasattr(module, sparse_idx_key) and hasattr(module, sparse_grad_key):
                setattr(module, "_hook_sparse_indices", getattr(module, sparse_idx_key))
                setattr(module, "_hook_sparse_grad_sign", getattr(module, sparse_grad_key))
                if hasattr(module, sparse_t_key):
                    setattr(module, "_hook_sparse_T", getattr(module, sparse_t_key))
                if update_scales and hasattr(module, "update_E"):
                    module._e_accum_threshold = 8
                    module.update_E()
                if hasattr(module, "T_accum"):
                    module._t_accum_step = step_for_weight
                if hasattr(module, "ternary_step"):
                    module.ternary_step(accum_threshold=accum_threshold)
                for key in (sparse_idx_key, sparse_grad_key, sparse_t_key):
                    if hasattr(module, key):
                        delattr(module, key)
                continue

            if hasattr(module, dense_key):
                grad_sign = getattr(module, dense_key)
                self._accumulate_component_grad_continuous(module, grad_sign, weight, t_step)
                delattr(module, dense_key)
                if hasattr(module, dense_t_key):
                    delattr(module, dense_t_key)

            if not hasattr(module, grad_key) or not hasattr(module, x_key):
                continue
            comp_grad = getattr(module, grad_key)
            comp_x = getattr(module, x_key)
            if torch.isfinite(comp_grad).all() and torch.isfinite(comp_x).all():
                # NOTE(review): presumably grad is (N, out) and x is (N, in), giving an (out, in) result — confirm.
                raw_grad = torch.clamp(comp_grad.transpose(0, 1) @ comp_x, -10.0, 10.0)
                self._accumulate_component_grad_continuous(module, raw_grad, weight, t_step)
            delattr(module, grad_key)
            delattr(module, x_key)

    def _accumulate_component_grad_continuous(self, module, raw_grad, weight, t_step):
        """Component loss accumulation without persistent float optimizer state.

        Only ``sign(raw_grad)`` is used, scaled by
        ``max(1, round(|weight| * t_step))``, with the sign flipped when the
        weight is negative.

        * BigInt modules: the result is forwarded to
          ``_accumulate_corr_from_grad_sign``.
        * Legacy modules: the result is subtracted from int8 ``T_accum``,
          saturating at ±127.

        Nothing happens when ``raw_grad``'s shape differs from ``_T_shape``.
        """
        if not hasattr(module, "_T_shape"):
            return
        shape = tuple(int(x) for x in module._T_shape.tolist())
        if tuple(raw_grad.shape) != shape:
            return
        with torch.no_grad():
            step = max(1, int(round(abs(float(weight)) * t_step)))
            if float(weight) < 0:
                step = -step
            if hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign"):
                signed = raw_grad.sign().to(device=module.corr_accum.device, dtype=torch.int8)
                module._accumulate_corr_from_grad_sign(signed, corr_step=step)
                return
            if not hasattr(module, "T_accum") or tuple(module.T_accum.shape) != shape:
                return
            if hasattr(module, "_T_accum_fp"):
                delattr(module, "_T_accum_fp")
            signed = raw_grad.sign().to(device=module.T_accum.device, dtype=torch.int8)
            # Widen to int16 so the subtraction cannot wrap before clamping.
            module.T_accum.copy_(
                torch.clamp(
                    module.T_accum.to(torch.int16) - signed.to(torch.int16) * step,
                    -127,
                    127,
                ).to(torch.int8)
            )

    def _apply_regular_ternary_hooks(self, accum_threshold, update_scales, t_step, loss_signal):
        """Finalize ternary updates on the non-component (single loss) path.

        * Modules updated during backward (a streamed flag is set) that have no
          pending hook are finished here: E-step and flips for
          ternary-streamed modules, and nothing more for BigInt-streamed ones.
        * Modules with a pending hook feed any ``_hook_grad_T_sign`` into
          ``_accumulate_corr_from_grad_sign``, then run ``ternary_step``.

        ``t_step`` and ``loss_signal`` are accepted but unused here.
        """
        for module in self.modules():
            is_bigint = hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign")
            is_legacy = hasattr(module, "T_accum") or hasattr(module, "E_accum")
            if is_bigint or is_legacy:
                self._prepare_per_group_threshold(module)

            streamed = bool(getattr(module, "_streamed_ternary_backward", False))
            has_hook = (
                hasattr(module, "_hook_grad_T_sign")
                or (hasattr(module, "_hook_grad_2d") and hasattr(module, "_hook_x_2d"))
                or (hasattr(module, "_hook_sparse_indices") and hasattr(module, "_hook_sparse_grad_sign"))
            )
            bigint_streamed = bool(getattr(module, "_streamed_bigint_backward", False))

            if (streamed or bigint_streamed) and not has_hook:
                if streamed and update_scales:
                    self._step_E_from_accum(module)
                if streamed:
                    had_flip = self._apply_accumulated_flips(module, accum_threshold=accum_threshold)
                    self._record_flip_health(module, had_flip)
                if hasattr(module, "per_group_threshold"):
                    del module.per_group_threshold
                continue

            if has_hook:
                if hasattr(module, "_hook_grad_T_sign") and hasattr(module, "_accumulate_corr_from_grad_sign"):
                    module._accumulate_corr_from_grad_sign(module._hook_grad_T_sign)
                    del module._hook_grad_T_sign
                if hasattr(module, "ternary_step"):
                    module.ternary_step(accum_threshold=accum_threshold)
            if hasattr(module, "per_group_threshold"):
                del module.per_group_threshold

    def _prepare_per_group_threshold(self, module):
        """Set ``module.per_group_threshold`` from the module's group scales ``E``.

        The threshold for each group is
        ``int8(min(8 + 0.25 * min(|E|, 32), 16))``, flattened.

        The attribute is set to None for large sparse embeddings, for BigInt
        modules (``corr_accum`` without ``T_accum``), and for modules lacking
        ``E`` or ``_T_shape``.

        NOTE(review): this relies on ``module.group_size`` existing and on
        ``E`` holding ``out_dim * ceil(in_dim / group_size)`` elements — confirm.
        """
        if self._is_large_sparse_embedding(module):
            module.per_group_threshold = None
            return
        if hasattr(module, "corr_accum") and not hasattr(module, "T_accum"):
            module.per_group_threshold = None
            return
        if not hasattr(module, "E") or not hasattr(module, "_T_shape"):
            module.per_group_threshold = None
            return
        out_dim, in_dim = (int(x) for x in module._T_shape.tolist())
        gpr = _ceil_div(in_dim, module.group_size)
        E_view = module.E.view(out_dim, gpr).float()
        threshold_g = 8.0 + 0.25 * E_view.abs().clamp(max=32.0)
        module.per_group_threshold = torch.clamp(threshold_g, max=16.0).to(torch.int8).reshape(-1)

    @staticmethod
    def _is_large_sparse_embedding(module):
        """True when ``module.num_embeddings`` is at least ``module.sparse_threshold``."""
        return (
            hasattr(module, "num_embeddings")
            and hasattr(module, "sparse_threshold")
            and module.num_embeddings >= module.sparse_threshold
        )

    @staticmethod
    def _step_E_from_accum(module):
        """Move int8 scales ``E`` by ±1 wherever ``E_accum`` reaches ±threshold.

        The threshold is ``module._e_accum_threshold`` (default 8). Each step
        subtracts ``step * threshold`` from ``E_accum``, and ``E`` is clamped
        to the int8 range.

        BigInt modules (those with ``corr_accum``) and modules without
        ``E``/``E_accum`` are left untouched.
        """
        if hasattr(module, "corr_accum"):
            return  # BigInt modules don't use E_accum threshold flips
        if not hasattr(module, "E") or not hasattr(module, "E_accum"):
            return
        threshold = int(getattr(module, "_e_accum_threshold", 8))
        accum = module.E_accum.to(torch.int16)
        step = torch.where(
            accum >= threshold,
            torch.ones_like(accum, dtype=torch.int16),
            torch.where(
                accum <= -threshold,
                torch.full_like(accum, -1, dtype=torch.int16),
                torch.zeros_like(accum, dtype=torch.int16),
            ),
        )
        if step.any():
            module.E = torch.clamp(module.E.to(torch.int16) + step, -128, 127).to(torch.int8)
            module.E_accum = (accum - step * threshold).to(torch.int8)

    @staticmethod
    def _apply_accumulated_flips(module, accum_threshold=3):
        """Packed-byte carry: where ``T_accum`` exceeds ±1, move that trit by ±1.

        ``T_packed`` is treated as a flat byte array holding five base-3 digits
        per byte, with values in [0, 242]. Element ``i`` of the row-major
        ``_T_shape`` weight lives in byte ``i // 5``, at digit ``i % 5``
        (place value ``3 ** (i % 5)``).

        Digits are incremented where ``T_accum > 1`` and decremented where
        ``T_accum < -1``. They saturate at 0 and 2, so a change never carries
        into a neighbouring weight's digit. ``T_accum`` is reset to 0 at every
        carry position.

        NOTE(review): the layout above is assumed to match ``pack_ternary`` —
        confirm.

        ``accum_threshold`` is accepted for call-site compatibility but not
        used; the carry threshold is fixed at ±1.

        Returns:
            True if any position carried. False otherwise, including when the
            required attributes are missing or the shapes disagree.
        """
        if not hasattr(module, "T_accum") or not hasattr(module, "T_packed") or not hasattr(module, "_T_shape"):
            return False
        shape = tuple(int(x) for x in module._T_shape.tolist())
        if tuple(module.T_accum.shape) != shape:
            return False
        carry_up = module.T_accum > 1
        carry_down = module.T_accum < -1
        if not carry_up.any() and not carry_down.any():
            return False

        dev = module.T_packed.device
        packed_shape = module.T_packed.shape
        pk = module.T_packed.reshape(-1).to(torch.int16).clone()
        up_flat = carry_up.reshape(-1).to(dev)
        down_flat = carry_down.reshape(-1).to(dev)
        lin = torch.arange(up_flat.numel(), device=dev)
        for p in range(5):
            # All elements stored at digit p. Each byte appears at most once per p,
            # so the scatter below has no write conflicts.
            sel = lin[p::5]
            if sel.numel() == 0:
                continue
            is_up = up_flat[sel]
            is_dn = down_flat[sel]
            if not is_up.any() and not is_dn.any():
                continue
            place = 3 ** p
            byte_idx = sel // 5
            pv = pk[byte_idx]
            digit = (pv // place) % 3
            pk[byte_idx] = torch.where(
                is_up & (digit < 2),
                pv + place,
                torch.where(is_dn & (digit > 0), pv - place, pv),
            )
        module.T_packed = pk.to(torch.uint8).reshape(packed_shape)
        # Reset T_accum to 0 on carry so W = T_accum × T doesn't jump
        module.T_accum.masked_fill_(carry_up | carry_down, 0)
        return True

    @staticmethod
    def _record_flip_health(module, had_flip):
        """Update ``_steps_since_flip`` and always reset ``_had_flip`` to False.

        ``_steps_since_flip`` becomes 0 after a flip and is incremented
        otherwise. Modules without ``T_accum`` are skipped.
        """
        if not hasattr(module, "T_accum"):
            return
        steps_since = getattr(module, "_steps_since_flip", 0)
        module._steps_since_flip = 0 if had_flip else steps_since + 1
        module._had_flip = False

    def generate(self, idx, max_new_token, temperature=1.0, images=None, audio=None,
                 conversation_id=None, top_k=None, min_new_tokens=0, return_metadata=False):
        """Autoregressively sample ``max_new_token`` byte tokens after ``idx``.

        An empty KV ledger is first seeded with the prompt tokens. Each step
        then feeds the last ``CTX`` tokens through ``forward(output_mode="text")``
        and samples from the final position.

        Args:
            idx: Prompt token ids, (batch, seq).
            max_new_token: Number of tokens to generate; every call generates
                exactly this many.
            temperature: Softmax temperature. ``0`` selects greedy argmax
                decoding.
            images: Forwarded to ``forward`` at every step.
            audio: Forwarded to ``forward`` at every step.
            conversation_id: Accepted but unused.
            top_k: If positive, sampling is restricted to the ``top_k`` most
                likely tokens.
            min_new_tokens: Accepted but currently has no effect. There is no
                early stopping, so ``max_new_token`` tokens are always produced.
            return_metadata: If true, return a dict instead of a tensor.

        Returns:
            The extended token tensor, or, with ``return_metadata``,
            ``{"tokens", "n_generated", "temperature"}``. Here ``n_generated``
            is the number of tokens appended to the prompt.
        """
        prompt_len = idx.shape[1]
        # Inference only: no autograd graph is needed for sampling.
        with torch.no_grad():
            if self.kv_ledger is not None and self.kv_ledger.size == 0:
                for token_id in idx.reshape(-1).tolist():
                    self.kv_ledger.append(int(token_id))
                    if self.kq_cache is not None:
                        self.kq_cache.append(int(token_id))

            for i in range(max_new_token):
                idx_cond = idx[:, -CTX:]
                logits, _, _, _ = self(idx_cond, images=images, audio=audio, timestep=i, output_mode="text")
                last_logits = logits[:, -1, :]
                if temperature == 0:
                    # Dividing by zero would yield inf/NaN probabilities; use greedy decoding.
                    idx_next = last_logits.argmax(dim=-1, keepdim=True)
                else:
                    last_logits = last_logits / temperature
                    # top-k filtering
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
                        last_logits = last_logits.masked_fill(last_logits < v[:, -1:], float('-inf'))
                    probs = F.softmax(last_logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, idx_next], dim=1)

        if return_metadata:
            return {
                "tokens": idx,
                "n_generated": idx.shape[1] - prompt_len,
                "temperature": temperature,
            }
        return idx