| """ |
| RAE Trainer β Custom Training Loop |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| Extends HuggingFace's Trainer with RAE multi-objective loss. |
| |
| This is the FULL CONTROL path. The training loop: |
| 1. Loads base model with QLoRA |
| 2. Applies RAE-structured training data |
| 3. Computes multi-phase weighted loss |
| 4. Tracks per-phase metrics (saturation, abstraction, descent, integration) |
| 5. Pushes trained model to HuggingFace Hub |
| |
| The key difference from standard SFT: |
| - Loss is NOT uniform across tokens |
| - Abstraction + Descent phases get higher loss weight |
| - Coherence loss penalizes abstraction that diverges from saturation |
| - Compression loss rewards shorter abstractions |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| """ |
|
|
| import json |
| import os |
| import sys |
| import logging |
| from pathlib import Path |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| import torch |
| from torch.utils.data import Dataset |
| import transformers |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| BitsAndBytesConfig, |
| TrainingArguments, |
| Trainer, |
| DataCollatorForSeq2Seq, |
| ) |
| from peft import ( |
| LoraConfig, |
| get_peft_model, |
| prepare_model_for_kbit_training, |
| TaskType, |
| ) |
| from datasets import load_dataset |
|
|
| from rae_loss import RAELoss, RAEPhaseTokenizer |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger("rae_trainer") |
|
|
|
|
| |
|
|
| @dataclass |
| class RAETrainingConfig: |
| """Configuration for RAE training run.""" |
| |
| |
| base_model: str = "Qwen/Qwen2.5-7B-Instruct" |
| quantization: str = "int4" |
| torch_dtype: str = "bfloat16" |
| attn_implementation: str = "flash_attention_2" |
| trust_remote_code: bool = True |
| |
| |
| lora_r: int = 32 |
| lora_alpha: int = 64 |
| lora_dropout: float = 0.05 |
| lora_target_modules: list = field(default_factory=lambda: [ |
| "q_proj", "k_proj", "v_proj", "o_proj", |
| "gate_proj", "up_proj", "down_proj", |
| ]) |
| |
| |
| train_path: str = "data/rae_training_data/train.jsonl" |
| eval_path: str = "data/rae_training_data/validation.jsonl" |
| max_seq_length: int = 4096 |
| |
| |
| epochs: int = 3 |
| batch_size: int = 1 |
| gradient_accumulation_steps: int = 8 |
| learning_rate: float = 5e-6 |
| lr_scheduler: str = "cosine" |
| warmup_ratio: float = 0.1 |
| weight_decay: float = 0.01 |
| max_grad_norm: float = 1.0 |
| bf16: bool = True |
| logging_steps: int = 10 |
| eval_steps: int = 100 |
| save_steps: int = 200 |
| save_total_limit: int = 3 |
| |
| |
| rae_loss_enabled: bool = True |
| phase_weights: dict = field(default_factory=lambda: { |
| "saturation": 1.0, |
| "abstraction": 1.5, |
| "descent": 1.5, |
| "integration": 1.0, |
| }) |
| coherence_weight: float = 0.3 |
| compression_weight: float = 0.2 |
| |
| |
| output_dir: str = "outputs/rae-trained-model" |
| push_to_hub: bool = True |
| hub_model_id: str = "rae-cognitive-model" |
| |
| @classmethod |
| def from_json(cls, path: str) -> "RAETrainingConfig": |
| with open(path) as f: |
| data = json.load(f) |
| |
| |
| flat = {} |
| for section in data.values(): |
| if isinstance(section, dict): |
| flat.update(section) |
| |
| return cls(**{k: v for k, v in flat.items() |
| if k in cls.__dataclass_fields__ and not k.startswith("_")}) |
|
|
|
|
| |
|
|
class RAEDataset(Dataset):
    """Dataset that loads RAE-structured JSONL data."""

    def __init__(self, path: str, tokenizer, max_length: int = 4096):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # One JSON object per line; each record carries a "messages" list.
        with open(path) as handle:
            self.examples = [json.loads(row) for row in handle]
        logger.info(f"Loaded {len(self.examples)} RAE examples from {path}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]

        # Render the conversation through the model's chat template.
        rendered = self.tokenizer.apply_chat_template(
            record["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

        enc = self.tokenizer(
            rendered,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        token_ids = enc["input_ids"].squeeze(0)
        mask = enc["attention_mask"].squeeze(0)

        # Supervise only the assistant's RAE response: every token before
        # it is excluded from the loss with the -100 ignore index.
        target = token_ids.clone()
        boundary = self._find_assistant_start(token_ids)
        if boundary > 0:
            target[:boundary] = -100

        return {
            "input_ids": token_ids,
            "attention_mask": mask,
            "labels": target,
        }

    def _find_assistant_start(self, input_ids: torch.Tensor) -> int:
        """Find where the assistant's RAE response begins."""
        # The response is assumed to open with the <SATURATION> phase tag.
        marker = self.tokenizer.encode("<SATURATION>", add_special_tokens=False)
        seq = input_ids.tolist()
        width = len(marker)

        for start in range(len(seq) - width + 1):
            if seq[start:start + width] == marker:
                return start

        # Marker absent (e.g. truncated away): fall back to masking
        # roughly the first 30% of tokens as a prompt-length heuristic.
        return int(len(seq) * 0.3)
|
|
|
|
| |
|
|
class RAETrainer(Trainer):
    """
    Extended Trainer with RAE multi-objective loss.

    This is where the handwriting effect is implemented:
    - Phase-weighted loss forces differential encoding depth
    - Coherence loss creates cross-phase binding
    - Compression loss rewards information distillation
    """

    def __init__(self, rae_config: RAETrainingConfig, **kwargs):
        # All remaining kwargs are forwarded unchanged to HF Trainer.
        super().__init__(**kwargs)
        self.rae_config = rae_config

        if rae_config.rae_loss_enabled:
            # RAELoss / RAEPhaseTokenizer come from the project-local
            # rae_loss module; their exact contracts are defined there.
            self.rae_loss_fn = RAELoss(
                phase_weights=rae_config.phase_weights,
                coherence_weight=rae_config.coherence_weight,
                compression_weight=rae_config.compression_weight,
            )
            self.phase_tokenizer = RAEPhaseTokenizer(self.tokenizer)
            logger.info("RAE multi-objective loss enabled")
            logger.info(f"  Phase weights: {rae_config.phase_weights}")
            logger.info(f"  Coherence weight: {rae_config.coherence_weight}")
            logger.info(f"  Compression weight: {rae_config.compression_weight}")
        else:
            # RAE loss disabled: compute_loss falls through to vanilla HF CE.
            self.rae_loss_fn = None
            self.phase_tokenizer = None

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Override compute_loss to use RAE multi-objective loss."""

        if not self.rae_config.rae_loss_enabled:
            # Standard HF cross-entropy path when RAE loss is off.
            return super().compute_loss(model, inputs, return_outputs, **kwargs)

        # Forward pass; hidden states are requested for the coherence term.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True,
        )

        logits = outputs.logits
        labels = inputs["labels"]

        # Per-token masks marking which RAE phase each position belongs to.
        phase_masks = self.phase_tokenizer.get_phase_masks(inputs["input_ids"])

        # Last-layer hidden states, or None if the model returned none.
        hidden_states = outputs.hidden_states[-1] if outputs.hidden_states else None

        # NOTE(review): loss_dict is expected to carry "total",
        # "weighted_ce", "coherence", "compression" (tensors) and a
        # "phase_losses" mapping — confirm against rae_loss.RAELoss.
        loss_dict = self.rae_loss_fn(logits, labels, phase_masks, hidden_states)

        # Log the individual loss components at the configured cadence.
        if self.state.global_step % self.args.logging_steps == 0:
            for phase, loss_val in loss_dict["phase_losses"].items():
                self.log({f"rae/{phase}_loss": loss_val})
            self.log({
                "rae/coherence_loss": loss_dict["coherence"].item(),
                "rae/compression_loss": loss_dict["compression"].item(),
                "rae/weighted_ce": loss_dict["weighted_ce"].item(),
            })

        total_loss = loss_dict["total"]
        return (total_loss, outputs) if return_outputs else total_loss
|
|
|
|
| |
|
|
def load_model_and_tokenizer(config: RAETrainingConfig):
    """Load and configure the base model with QLoRA.

    Steps:
      1. Build a 4-bit NF4 BitsAndBytes config when quantization == "int4".
      2. Load the causal-LM base model, preferring the configured attention
         implementation and falling back to the default when unavailable.
      3. Prepare the model for k-bit training and attach LoRA adapters.
      4. Load the tokenizer and ensure a pad token exists.

    Args:
        config: Parsed RAE training configuration.

    Returns:
        Tuple of (peft-wrapped model, tokenizer).
    """
    logger.info(f"Loading base model: {config.base_model}")

    # 4-bit QLoRA quantization; any other value loads in full precision.
    bnb_config = None
    if config.quantization == "int4":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=getattr(torch, config.torch_dtype),
            bnb_4bit_use_double_quant=True,
        )

    model_kwargs = {
        "quantization_config": bnb_config,
        "torch_dtype": getattr(torch, config.torch_dtype),
        "trust_remote_code": config.trust_remote_code,
        "device_map": "auto",
    }

    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            attn_implementation=config.attn_implementation,
            **model_kwargs,
        )
    except (ImportError, ValueError) as exc:
        # transformers signals an unavailable attention backend (e.g.
        # flash_attention_2 without the flash-attn package installed) with
        # ImportError/ValueError. Anything else — auth failure, bad repo
        # id, OOM — propagates instead of triggering a misleading retry.
        logger.warning(
            "Attention implementation %r unavailable (%s); "
            "falling back to default attention",
            config.attn_implementation,
            exc,
        )
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            **model_kwargs,
        )

    # Required for stable QLoRA training (norm casting, input grads, etc.).
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        task_type=TaskType.CAUSAL_LM,
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model,
        trust_remote_code=config.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        # Reuse EOS as pad so the data collator can pad batches.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # right-padding for causal-LM training

    return model, tokenizer
|
|
|
|
| def train(config_path: str = "configs/rae_training_config.json"): |
| """Execute the full RAE training pipeline.""" |
| |
| |
| config = RAETrainingConfig.from_json(config_path) |
| |
| logger.info("=" * 60) |
| logger.info(" RAE TRAINING METHODOLOGY") |
| logger.info(" Recursive Abstraction Engine as Training-Time") |
| logger.info(" Cognitive Installation") |
| logger.info("=" * 60) |
| logger.info(f" Base model: {config.base_model}") |
| logger.info(f" RAE loss: {'ENABLED' if config.rae_loss_enabled else 'disabled'}") |
| logger.info(f" LoRA rank: {config.lora_r}") |
| logger.info(f" Epochs: {config.epochs}") |
| logger.info(f" Effective batch size: {config.batch_size * config.gradient_accumulation_steps}") |
| logger.info("=" * 60) |
| |
| |
| model, tokenizer = load_model_and_tokenizer(config) |
| |
| |
| train_dataset = RAEDataset(config.train_path, tokenizer, config.max_seq_length) |
| eval_dataset = RAEDataset(config.eval_path, tokenizer, config.max_seq_length) |
| |
| |
| data_collator = DataCollatorForSeq2Seq( |
| tokenizer=tokenizer, |
| padding=True, |
| max_length=config.max_seq_length, |
| pad_to_multiple_of=8, |
| ) |
| |
| |
| training_args = TrainingArguments( |
| output_dir=config.output_dir, |
| num_train_epochs=config.epochs, |
| per_device_train_batch_size=config.batch_size, |
| per_device_eval_batch_size=config.batch_size, |
| gradient_accumulation_steps=config.gradient_accumulation_steps, |
| learning_rate=config.learning_rate, |
| lr_scheduler_type=config.lr_scheduler, |
| warmup_ratio=config.warmup_ratio, |
| weight_decay=config.weight_decay, |
| max_grad_norm=config.max_grad_norm, |
| bf16=config.bf16, |
| logging_steps=config.logging_steps, |
| eval_strategy="steps", |
| eval_steps=config.eval_steps, |
| save_strategy="steps", |
| save_steps=config.save_steps, |
| save_total_limit=config.save_total_limit, |
| load_best_model_at_end=True, |
| report_to=["tensorboard", "wandb"], |
| remove_unused_columns=False, |
| push_to_hub=config.push_to_hub, |
| hub_model_id=config.hub_model_id if config.push_to_hub else None, |
| hub_token=os.environ.get("HF_TOKEN"), |
| ) |
| |
| |
| trainer = RAETrainer( |
| rae_config=config, |
| model=model, |
| args=training_args, |
| train_dataset=train_dataset, |
| eval_dataset=eval_dataset, |
| data_collator=data_collator, |
| tokenizer=tokenizer, |
| ) |
| |
| |
| logger.info("\nπ§ Beginning RAE Training...") |
| logger.info(" The hand is slow so the mind can be fast later.\n") |
| |
| trainer.train() |
| |
| |
| logger.info("Saving final model...") |
| trainer.save_model(config.output_dir) |
| tokenizer.save_pretrained(config.output_dir) |
| |
| |
| if config.push_to_hub: |
| logger.info(f"Pushing to HuggingFace Hub: {config.hub_model_id}") |
| trainer.push_to_hub() |
| |
| logger.info("\n" + "=" * 60) |
| logger.info(" RAE Training Complete") |
| logger.info(f" Model saved: {config.output_dir}") |
| logger.info("=" * 60) |
|
|
|
|
| if __name__ == "__main__": |
| config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/rae_training_config.json" |
| train(config_path) |
|
|