| """ |
| RAE Loss Functions |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| Multi-objective loss replicating the handwriting effect. |
| |
| Standard SFT loss = CrossEntropy over all tokens equally. |
| RAE loss = Weighted CrossEntropy per phase + coherence + compression. |
| |
| This forces the model to invest different amounts of "cognitive effort" |
| in different phases β just as handwriting forces different brain regions |
| to co-activate with different intensities during letter formation. |
| |
| The phase weights are the training-time equivalent of the motor cortex |
| being forced to plan each stroke: higher weight on Abstraction and |
| Descent means the model builds richer representations in those layers. |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
| import re |
|
|
|
|
class RAEPhaseTokenizer:
    """Identifies phase boundaries in tokenized sequences.

    Each RAE phase is delimited by an open/close tag pair (e.g.
    ``<SATURATION> ... </SATURATION>``). The tags are encoded once with
    the supplied tokenizer so masks can be computed directly on token IDs.
    """

    PHASE_TAGS = {
        "saturation": ("<SATURATION>", "</SATURATION>"),
        "abstraction": ("<ABSTRACTION>", "</ABSTRACTION>"),
        "descent": ("<DESCENT>", "</DESCENT>"),
        "integration": ("<INTEGRATION>", "</INTEGRATION>"),
    }

    def __init__(self, tokenizer):
        # NOTE(review): assumes `tokenizer` follows the HF interface,
        # i.e. encode(text, add_special_tokens=False) -> list[int].
        self.tokenizer = tokenizer

        # Pre-encode every tag once; these token-ID subsequences are what
        # get_phase_masks searches for inside input_ids.
        self.phase_token_ids = {}
        for phase, (open_tag, close_tag) in self.PHASE_TAGS.items():
            open_ids = tokenizer.encode(open_tag, add_special_tokens=False)
            close_ids = tokenizer.encode(close_tag, add_special_tokens=False)
            self.phase_token_ids[phase] = {
                "open": open_ids,
                "close": close_ids,
            }

    def get_phase_masks(self, input_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Return boolean masks for each RAE phase in the sequence.

        Args:
            input_ids: [batch_size, seq_len] token IDs

        Returns:
            dict mapping phase name -> [batch_size, seq_len] boolean mask
            covering the tokens strictly between the open and close tags.
            Only the first occurrence of each phase per sequence is masked.
        """
        batch_size, seq_len = input_ids.shape
        masks = {}

        for phase, tag_ids in self.phase_token_ids.items():
            mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=input_ids.device)

            for b in range(batch_size):
                ids = input_ids[b].tolist()

                open_pos = self._find_subsequence(ids, tag_ids["open"])
                if open_pos is None:
                    continue
                start = open_pos + len(tag_ids["open"])
                # BUGFIX: search for the close tag only AFTER the open tag.
                # Previously the close tag was searched from position 0, so a
                # matching subsequence occurring before the open tag made
                # close_pos < start and silently produced an empty mask.
                close_pos = self._find_subsequence(ids, tag_ids["close"], start)
                if close_pos is not None and start < close_pos:
                    mask[b, start:close_pos] = True

            masks[phase] = mask

        return masks

    @staticmethod
    def _find_subsequence(sequence: list, subsequence: list, start: int = 0) -> Optional[int]:
        """Return the first index >= `start` where `subsequence` occurs, else None."""
        sub_len = len(subsequence)
        for i in range(start, len(sequence) - sub_len + 1):
            if sequence[i:i + sub_len] == subsequence:
                return i
        return None
|
|
|
|
class RAELoss(nn.Module):
    """
    Multi-objective loss for RAE training.

    Components:
        1. Phase-weighted cross-entropy: Different loss weight per RAE phase
        2. Coherence penalty: Abstraction should be entailed by Saturation
        3. Compression reward: Abstraction should be shorter than Saturation

    This is the computational equivalent of handwriting's multi-circuit
    co-activation: the loss landscape has multiple interacting gradients
    that force richer weight updates than flat cross-entropy.
    """

    def __init__(
        self,
        phase_weights: Optional[dict[str, float]] = None,
        coherence_weight: float = 0.3,
        compression_weight: float = 0.2,
        base_weight: float = 1.0,
    ):
        """
        Args:
            phase_weights: per-phase CE multipliers; defaults emphasize
                abstraction/descent at 1.5x.
            coherence_weight: scale of the saturation/abstraction cosine penalty.
            compression_weight: scale of the abstraction-length penalty.
            base_weight: scale of the weighted cross-entropy term.
        """
        super().__init__()

        self.phase_weights = phase_weights or {
            "saturation": 1.0,
            "abstraction": 1.5,
            "descent": 1.5,
            "integration": 1.0,
        }
        self.coherence_weight = coherence_weight
        self.compression_weight = compression_weight
        self.base_weight = base_weight

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        phase_masks: dict[str, torch.Tensor],
        hidden_states: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute RAE multi-objective loss.

        Args:
            logits: [batch, seq_len, vocab_size] model output logits
            labels: [batch, seq_len] target token IDs (-100 = ignore)
            phase_masks: dict of [batch, seq_len] boolean masks per phase
            hidden_states: [batch, seq_len, hidden_dim] for coherence loss

        Returns:
            dict with 'total', 'weighted_ce', 'coherence', 'compression'
            (tensors) and 'phase_losses' (detached python floats).
        """
        batch_size, seq_len, vocab_size = logits.shape

        # --- Phase-weighted cross-entropy --------------------------------
        # Per-token weight: 1.0 outside any phase, phase weight inside.
        token_weights = torch.ones(batch_size, seq_len, device=logits.device)
        for phase, mask in phase_masks.items():
            weight = self.phase_weights.get(phase, 1.0)
            token_weights = torch.where(
                mask, torch.tensor(weight, device=logits.device), token_weights
            )

        # Standard causal shift: logits at position t predict token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_weights = token_weights[..., 1:].contiguous()

        loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
        per_token_loss = loss_fct(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
        ).view(batch_size, -1)

        # BUGFIX: normalize only over tokens that actually contribute loss.
        # Ignored positions (label == -100) yield zero loss from
        # CrossEntropyLoss but previously still added their weight to the
        # denominator, diluting the loss in proportion to padding length.
        valid = (shift_labels != -100).float()
        effective_weights = shift_weights * valid
        denom = effective_weights.sum().clamp_min(1e-8)  # guard: all-ignored batch
        weighted_loss = (per_token_loss * effective_weights).sum() / denom

        # Per-phase diagnostic losses (plain floats, for logging only).
        phase_losses = {}
        for phase, mask in phase_masks.items():
            phase_mask_shifted = mask[..., 1:]
            valid_mask = phase_mask_shifted & (shift_labels != -100)
            if valid_mask.any():
                phase_losses[phase] = per_token_loss[valid_mask].mean().item()

        # --- Coherence penalty -------------------------------------------
        # Pull the mean abstraction representation toward the mean
        # saturation representation: penalty = 1 - cosine similarity.
        coherence_loss = torch.tensor(0.0, device=logits.device)
        if hidden_states is not None and self.coherence_weight > 0:
            sat_mask = phase_masks.get("saturation", None)
            abs_mask = phase_masks.get("abstraction", None)

            if sat_mask is not None and abs_mask is not None:
                for b in range(batch_size):
                    sat_hidden = hidden_states[b][sat_mask[b]]
                    abs_hidden = hidden_states[b][abs_mask[b]]

                    if sat_hidden.shape[0] > 0 and abs_hidden.shape[0] > 0:
                        sat_repr = sat_hidden.mean(dim=0)
                        abs_repr = abs_hidden.mean(dim=0)

                        similarity = F.cosine_similarity(
                            sat_repr.unsqueeze(0), abs_repr.unsqueeze(0)
                        )
                        coherence_loss = coherence_loss + (1 - similarity.mean())

                coherence_loss = coherence_loss / batch_size

        # --- Compression reward ------------------------------------------
        # Penalize abstraction spans longer than 70% of the saturation span.
        compression_loss = torch.tensor(0.0, device=logits.device)
        if self.compression_weight > 0:
            sat_mask = phase_masks.get("saturation", None)
            abs_mask = phase_masks.get("abstraction", None)

            if sat_mask is not None and abs_mask is not None:
                for b in range(batch_size):
                    sat_len = sat_mask[b].sum().float()
                    abs_len = abs_mask[b].sum().float()

                    if sat_len > 0:
                        ratio = abs_len / sat_len
                        compression_loss = compression_loss + F.relu(ratio - 0.7)

                compression_loss = compression_loss / batch_size

        # Weighted sum of the three objectives.
        total = (
            self.base_weight * weighted_loss
            + self.coherence_weight * coherence_loss
            + self.compression_weight * compression_loss
        )

        return {
            "total": total,
            "weighted_ce": weighted_loss,
            "coherence": coherence_loss,
            "compression": compression_loss,
            "phase_losses": phase_losses,
        }
|
|
|
|
class RAELossSimple(nn.Module):
    """
    Reduced RAE loss intended for AutoTrain / off-the-shelf trainers.

    When the training loop cannot be customized, phase weighting is
    approximated through the data itself -- tokens belonging to the
    higher-weight phases are effectively presented more often.

    Think of it as "pragmatic handwriting": the full multi-objective
    setup is out of reach, but the model can still be pushed to spend
    more attention on the critical phases.
    """

    def __init__(self, phase_weights: Optional[dict[str, float]] = None):
        super().__init__()
        # Fall back to the standard RAE emphasis profile when the caller
        # does not supply explicit weights.
        default_weights = {
            "saturation": 1.0,
            "abstraction": 1.5,
            "descent": 1.5,
            "integration": 1.0,
        }
        self.phase_weights = phase_weights or default_weights

    def create_weighted_labels(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        phase_tokenizer: RAEPhaseTokenizer,
    ) -> torch.Tensor:
        """
        Produce labels for a trainer whose loss function is fixed.

        The stock HF Trainer loss cannot be swapped out, so the intended
        phase weighting is achieved upstream by oversampling tokens from
        the high-weight phases; the labels themselves pass through
        unchanged.

        NOTE(review): the phase masks are computed but never consumed --
        this method is currently a pass-through.
        """
        # Kept for parity with the full RAELoss path; currently unused.
        phase_masks = phase_tokenizer.get_phase_masks(input_ids)

        return labels
|
|