"""
RAE Trainer — Custom Training Loop
═══════════════════════════════════════════════════════════════
Extends HuggingFace's Trainer with RAE multi-objective loss.
This is the FULL CONTROL path.

The training loop:
  1. Loads base model with QLoRA
  2. Applies RAE-structured training data
  3. Computes multi-phase weighted loss
  4. Tracks per-phase metrics (saturation, abstraction, descent, integration)
  5. Pushes trained model to HuggingFace Hub

The key difference from standard SFT:
  - Loss is NOT uniform across tokens
  - Abstraction + Descent phases get higher loss weight
  - Coherence loss penalizes abstraction that diverges from saturation
  - Compression loss rewards shorter abstractions
═══════════════════════════════════════════════════════════════
"""

import inspect
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import Dataset
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)
from datasets import load_dataset

from rae_loss import RAELoss, RAEPhaseTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rae_trainer")


# ── Configuration ─────────────────────────────────────────────

@dataclass
class RAETrainingConfig:
    """Configuration for RAE training run.

    Every field has a default, so a JSON config only needs to list the
    values it overrides (see ``from_json``).
    """

    # Model
    base_model: str = "Qwen/Qwen2.5-7B-Instruct"
    quantization: str = "int4"
    torch_dtype: str = "bfloat16"
    attn_implementation: str = "flash_attention_2"
    trust_remote_code: bool = True

    # LoRA
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    lora_target_modules: list = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ])

    # Data
    train_path: str = "data/rae_training_data/train.jsonl"
    eval_path: str = "data/rae_training_data/validation.jsonl"
    max_seq_length: int = 4096

    # Training
    epochs: int = 3
    batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 5e-6
    lr_scheduler: str = "cosine"
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    bf16: bool = True
    logging_steps: int = 10
    eval_steps: int = 100
    save_steps: int = 200
    save_total_limit: int = 3

    # RAE Loss
    rae_loss_enabled: bool = True
    phase_weights: dict = field(default_factory=lambda: {
        "saturation": 1.0,
        "abstraction": 1.5,
        "descent": 1.5,
        "integration": 1.0,
    })
    coherence_weight: float = 0.3
    compression_weight: float = 0.2

    # Output
    output_dir: str = "outputs/rae-trained-model"
    push_to_hub: bool = True
    hub_model_id: str = "rae-cognitive-model"

    @classmethod
    def from_json(cls, path: str) -> "RAETrainingConfig":
        """Build a config from a sectioned JSON file.

        The file is expected to be an object whose values are themselves
        objects (sections). All sections are merged into one flat mapping;
        later sections win on duplicate keys. Top-level values that are not
        objects are ignored, as are keys that are not dataclass fields or
        that start with ``_``.

        Args:
            path: Path to the JSON config file.

        Returns:
            A config with defaults for every field the file does not set.
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        # Flatten nested config
        flat = {}
        for section in data.values():
            if isinstance(section, dict):
                flat.update(section)

        known_fields = cls.__dataclass_fields__
        return cls(**{
            k: v for k, v in flat.items()
            if k in known_fields and not k.startswith("_")
        })


# ── RAE Dataset ───────────────────────────────────────────────

class RAEDataset(Dataset):
    """Dataset that loads RAE-structured JSONL data.

    Each non-blank line of the file is one JSON object with a ``"messages"``
    key holding a chat-format conversation.
    """

    # NOTE(review): in the reviewed source the tag literal was an empty
    # string (it looks like the markup was stripped from the file), which made
    # the search below match at index 0 and disabled prompt masking entirely.
    # "<SATURATION>" is assumed from the phase names — confirm against the
    # training data, or pass `response_start_tag` explicitly.
    DEFAULT_RESPONSE_START_TAG = "<SATURATION>"

    # Share of the sequence treated as prompt when the tag is not found.
    FALLBACK_PROMPT_FRACTION = 0.3

    def __init__(self, path: str, tokenizer, max_length: int = 4096,
                 response_start_tag: Optional[str] = None):
        """Read every example from ``path`` into memory.

        Args:
            path: JSONL file with one example per line.
            tokenizer: Tokenizer providing ``apply_chat_template``, ``encode``
                and ``__call__``.
            max_length: Sequences are truncated to this many tokens.
            response_start_tag: Text marking the start of the assistant's RAE
                response. Defaults to ``DEFAULT_RESPONSE_START_TAG``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        tag = (self.DEFAULT_RESPONSE_START_TAG
               if response_start_tag is None else response_start_tag)
        # Tokenize the marker once instead of once per __getitem__ call.
        self._start_marker_ids = (
            tokenizer.encode(tag, add_special_tokens=False) if tag else []
        )

        self.examples = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                # A trailing newline or blank separator line is not an example.
                if not line.strip():
                    continue
                self.examples.append(json.loads(line))

        logger.info(f"Loaded {len(self.examples)} RAE examples from {path}")

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and build prompt-masked labels.

        Returns:
            Dict with 1-D tensors ``input_ids``, ``attention_mask`` and
            ``labels``; label positions before the detected response start
            are set to -100.
        """
        example = self.examples[idx]
        messages = example["messages"]

        # Apply chat template
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a leading batch dimension of 1; drop it.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Labels = input_ids (autoregressive), mask system + user tokens
        labels = input_ids.clone()

        # Find where assistant response starts and mask everything before
        # This ensures loss is only computed on the RAE response
        assistant_start = self._find_assistant_start(input_ids)
        if assistant_start > 0:
            labels[:assistant_start] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _find_assistant_start(self, input_ids: torch.Tensor) -> int:
        """Find where the assistant's RAE response begins.

        Returns the index of the first occurrence of the start-tag token
        sequence in ``input_ids``. If the tag has no tokens or does not occur,
        falls back to ``FALLBACK_PROMPT_FRACTION`` of the sequence length.
        """
        ids = input_ids.tolist()
        marker = self._start_marker_ids
        width = len(marker)

        # An empty marker would trivially match at index 0, so skip the search.
        if width:
            for i in range(len(ids) - width + 1):
                if ids[i:i + width] == marker:
                    return i

        # Fallback: use 30% of sequence as system/user tokens
        return int(len(ids) * self.FALLBACK_PROMPT_FRACTION)


# ── Custom RAE Trainer ────────────────────────────────────────

def _as_float(value) -> float:
    """Convert a scalar tensor or plain number to a Python float for logging."""
    return value.item() if hasattr(value, "item") else float(value)


class RAETrainer(Trainer):
    """
    Extended Trainer with RAE multi-objective loss.

    This is where the handwriting effect is implemented:
    - Phase-weighted loss forces differential encoding depth
    - Coherence loss creates cross-phase binding
    - Compression loss rewards information distillation
    """

    def __init__(self, rae_config: RAETrainingConfig, **kwargs):
        """Create the trainer.

        Args:
            rae_config: RAE settings; when ``rae_loss_enabled`` is false the
                stock ``Trainer`` loss is used.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(**kwargs)
        self.rae_config = rae_config
        # Global step at which RAE metrics were last logged (-1 = never).
        self._last_rae_log_step = -1

        if rae_config.rae_loss_enabled:
            self.rae_loss_fn = RAELoss(
                phase_weights=rae_config.phase_weights,
                coherence_weight=rae_config.coherence_weight,
                compression_weight=rae_config.compression_weight,
            )
            self.phase_tokenizer = RAEPhaseTokenizer(
                self._resolve_tokenizer(kwargs)
            )
            logger.info("RAE multi-objective loss enabled")
            logger.info(f"  Phase weights: {rae_config.phase_weights}")
            logger.info(f"  Coherence weight: {rae_config.coherence_weight}")
            logger.info(f"  Compression weight: {rae_config.compression_weight}")
        else:
            self.rae_loss_fn = None
            self.phase_tokenizer = None

    def _resolve_tokenizer(self, init_kwargs: dict):
        """Return the tokenizer handed to the trainer, whatever it was named.

        Newer transformers releases take it as ``processing_class`` and
        deprecate ``tokenizer``, so the constructor kwargs are checked first
        and the trainer attributes only as a fallback.
        """
        for key in ("processing_class", "tokenizer"):
            candidate = init_kwargs.get(key)
            if candidate is not None:
                return candidate
        candidate = getattr(self, "processing_class", None)
        if candidate is None:
            candidate = getattr(self, "tokenizer", None)
        return candidate

    def _should_log_rae_metrics(self, model) -> bool:
        """Decide whether this forward pass should emit RAE metrics.

        Metrics are logged only in training mode, only on steps that are a
        multiple of ``logging_steps``, and once per global step (compute_loss
        runs once per micro-batch, so a step with gradient accumulation would
        otherwise log several times).
        """
        if not getattr(model, "training", True):
            return False
        logging_steps = self.args.logging_steps
        if not logging_steps:  # avoid modulo by zero
            return False
        step = self.state.global_step
        return step % logging_steps == 0 and step != self._last_rae_log_step

    def _log_rae_metrics(self, loss_dict) -> None:
        """Emit all RAE loss components as a single log record."""
        metrics = {
            f"rae/{phase}_loss": _as_float(loss_val)
            for phase, loss_val in loss_dict["phase_losses"].items()
        }
        metrics["rae/coherence_loss"] = _as_float(loss_dict["coherence"])
        metrics["rae/compression_loss"] = _as_float(loss_dict["compression"])
        metrics["rae/weighted_ce"] = _as_float(loss_dict["weighted_ce"])
        self.log(metrics)
        self._last_rae_log_step = self.state.global_step

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Override compute_loss to use RAE multi-objective loss.

        Args:
            model: Model to run the forward pass on.
            inputs: Batch with ``input_ids``, ``attention_mask`` and ``labels``.
            return_outputs: When true, also return the model outputs.
            **kwargs: Passed to the parent implementation when RAE loss is
                disabled; unused otherwise.

        Returns:
            The total loss, or ``(total_loss, outputs)`` if ``return_outputs``.
        """
        if not self.rae_config.rae_loss_enabled:
            return super().compute_loss(model, inputs, return_outputs, **kwargs)

        # Forward pass with hidden states for coherence loss
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True,
        )

        logits = outputs.logits
        labels = inputs["labels"]

        # Get phase masks
        phase_masks = self.phase_tokenizer.get_phase_masks(inputs["input_ids"])

        # Get last hidden state for coherence loss
        hidden_states = outputs.hidden_states[-1] if outputs.hidden_states else None

        # Compute RAE loss
        loss_dict = self.rae_loss_fn(logits, labels, phase_masks, hidden_states)

        # Log per-phase metrics
        if self._should_log_rae_metrics(model):
            self._log_rae_metrics(loss_dict)

        total_loss = loss_dict["total"]
        return (total_loss, outputs) if return_outputs else total_loss


# ── Main Training Pipeline ────────────────────────────────────

def load_model_and_tokenizer(config: RAETrainingConfig):
    """Load and configure the base model with QLoRA.

    Args:
        config: Supplies the model id, quantization mode, dtype, attention
            implementation and LoRA hyper-parameters.

    Returns:
        ``(model, tokenizer)`` where ``model`` is wrapped by PEFT with LoRA
        adapters and ``tokenizer`` has a pad token and right-side padding.
    """
    logger.info(f"Loading base model: {config.base_model}")

    # Quantization config
    bnb_config = None
    if config.quantization == "int4":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=getattr(torch, config.torch_dtype),
            bnb_4bit_use_double_quant=True,
        )
    elif config.quantization == "int8":
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    elif config.quantization not in (None, "", "none"):
        logger.warning(
            f"Unrecognized quantization '{config.quantization}', "
            "loading model without quantization"
        )

    # Load model
    model_kwargs = {
        "quantization_config": bnb_config,
        "torch_dtype": getattr(torch, config.torch_dtype),
        "trust_remote_code": config.trust_remote_code,
        "device_map": "auto",
    }

    # Try flash attention, fall back gracefully
    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            attn_implementation=config.attn_implementation,
            **model_kwargs,
        )
    except (ImportError, ValueError) as exc:
        # Only the "attention backend unusable" errors are retried; anything
        # else (download failure, out of memory, ...) should surface as-is
        # rather than be reported as a Flash Attention problem.
        logger.warning("Flash Attention not available, using default attention")
        logger.warning(f"  Reason: {exc}")
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            **model_kwargs,
        )

    # Prepare for k-bit training. Skipped for unquantized models because the
    # helper upcasts half-precision weights to float32.
    if bnb_config is not None:
        model = prepare_model_for_kbit_training(model)

    # Apply LoRA
    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        task_type=TaskType.CAUSAL_LM,
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model,
        trust_remote_code=config.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer


def train(config_path: str = "configs/rae_training_config.json"):
    """Execute the full RAE training pipeline.

    Loads the config, model and datasets, trains with ``RAETrainer``, saves
    the model and tokenizer to ``config.output_dir`` and, if
    ``config.push_to_hub`` is set, pushes the result to the HuggingFace Hub.

    Args:
        config_path: Path to the JSON training config.
    """
    # Load config
    config = RAETrainingConfig.from_json(config_path)

    logger.info("=" * 60)
    logger.info("  RAE TRAINING METHODOLOGY")
    logger.info("  Recursive Abstraction Engine as Training-Time")
    logger.info("  Cognitive Installation")
    logger.info("=" * 60)
    logger.info(f"  Base model: {config.base_model}")
    logger.info(f"  RAE loss: {'ENABLED' if config.rae_loss_enabled else 'disabled'}")
    logger.info(f"  LoRA rank: {config.lora_r}")
    logger.info(f"  Epochs: {config.epochs}")
    logger.info(f"  Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
    logger.info("=" * 60)

    # Load model
    model, tokenizer = load_model_and_tokenizer(config)

    # Load datasets
    train_dataset = RAEDataset(config.train_path, tokenizer, config.max_seq_length)
    eval_dataset = RAEDataset(config.eval_path, tokenizer, config.max_seq_length)

    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        max_length=config.max_seq_length,
        pad_to_multiple_of=8,
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        lr_scheduler_type=config.lr_scheduler,
        warmup_ratio=config.warmup_ratio,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        bf16=config.bf16,
        logging_steps=config.logging_steps,
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        save_strategy="steps",
        save_steps=config.save_steps,
        save_total_limit=config.save_total_limit,
        load_best_model_at_end=True,
        report_to=["tensorboard", "wandb"],
        remove_unused_columns=False,
        push_to_hub=config.push_to_hub,
        hub_model_id=config.hub_model_id if config.push_to_hub else None,
        hub_token=os.environ.get("HF_TOKEN"),
    )

    # Newer transformers releases name the tokenizer argument
    # `processing_class` and deprecate `tokenizer`; use whichever the
    # installed Trainer accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = (
        "processing_class" if "processing_class" in trainer_params else "tokenizer"
    )

    # Initialize RAE Trainer
    trainer = RAETrainer(
        rae_config=config,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        **{tokenizer_kwarg: tokenizer},
    )

    # Train!
    logger.info("\n🧠 Beginning RAE Training...")
    logger.info("  The hand is slow so the mind can be fast later.\n")

    trainer.train()

    # Save final model
    logger.info("Saving final model...")
    trainer.save_model(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)

    # Push to hub
    if config.push_to_hub:
        logger.info(f"Pushing to HuggingFace Hub: {config.hub_model_id}")
        trainer.push_to_hub()

    logger.info("\n" + "=" * 60)
    logger.info("  RAE Training Complete")
    logger.info(f"  Model saved: {config.output_dir}")
    logger.info("=" * 60)


if __name__ == "__main__":
    config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/rae_training_config.json"
    train(config_path)