# rae-training / src/train_rae.py
# Uploaded by TrueV1sion123 via huggingface_hub (commit 2403e59, verified)
"""
RAE Trainer β€” Custom Training Loop
═══════════════════════════════════════════════════════════════
Extends HuggingFace's Trainer with RAE multi-objective loss.
This is the FULL CONTROL path. The training loop:
1. Loads base model with QLoRA
2. Applies RAE-structured training data
3. Computes multi-phase weighted loss
4. Tracks per-phase metrics (saturation, abstraction, descent, integration)
5. Pushes trained model to HuggingFace Hub
The key difference from standard SFT:
- Loss is NOT uniform across tokens
- Abstraction + Descent phases get higher loss weight
- Coherence loss penalizes abstraction that diverges from saturation
- Compression loss rewards shorter abstractions
═══════════════════════════════════════════════════════════════
"""
import json
import os
import sys
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional
import torch
from torch.utils.data import Dataset
import transformers
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
)
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
TaskType,
)
from datasets import load_dataset
from rae_loss import RAELoss, RAEPhaseTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rae_trainer")
# ── Configuration ─────────────────────────────────────────────
@dataclass
class RAETrainingConfig:
    """Configuration for a single RAE training run.

    Loadable from JSON via :meth:`from_json`, which accepts both a flat
    layout and a sectioned layout (``{"model": {...}, "lora": {...}}``).
    """

    # ── Model ─────────────────────────────────────────────────────
    base_model: str = "Qwen/Qwen2.5-7B-Instruct"
    quantization: str = "int4"  # "int4" enables QLoRA 4-bit loading
    torch_dtype: str = "bfloat16"  # resolved via getattr(torch, ...)
    attn_implementation: str = "flash_attention_2"
    trust_remote_code: bool = True

    # ── LoRA ──────────────────────────────────────────────────────
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    lora_target_modules: list = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ])

    # ── Data ──────────────────────────────────────────────────────
    train_path: str = "data/rae_training_data/train.jsonl"
    eval_path: str = "data/rae_training_data/validation.jsonl"
    max_seq_length: int = 4096

    # ── Training ──────────────────────────────────────────────────
    epochs: int = 3
    batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 5e-6
    lr_scheduler: str = "cosine"
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    bf16: bool = True
    logging_steps: int = 10
    eval_steps: int = 100
    save_steps: int = 200
    save_total_limit: int = 3

    # ── RAE Loss ──────────────────────────────────────────────────
    rae_loss_enabled: bool = True
    # Per-phase CE weights; abstraction/descent are deliberately up-weighted.
    phase_weights: dict = field(default_factory=lambda: {
        "saturation": 1.0,
        "abstraction": 1.5,
        "descent": 1.5,
        "integration": 1.0,
    })
    coherence_weight: float = 0.3
    compression_weight: float = 0.2

    # ── Output ────────────────────────────────────────────────────
    output_dir: str = "outputs/rae-trained-model"
    push_to_hub: bool = True
    hub_model_id: str = "rae-cognitive-model"

    @classmethod
    def from_json(cls, path: str) -> "RAETrainingConfig":
        """Build a config from a JSON file.

        Accepts two layouts, which may be mixed in one file:
        - sectioned: ``{"model": {"base_model": ...}, "lora": {...}}``
        - flat: ``{"base_model": ..., "epochs": ...}``

        A top-level key that names a dataclass field is taken verbatim —
        even when its value is a dict (e.g. ``phase_weights``). Any other
        dict-valued key is treated as a section and its contents merged.
        Unknown keys are ignored.

        Raises:
            OSError / json.JSONDecodeError: if the file is missing or invalid.
        """
        with open(path) as f:
            data = json.load(f)
        known = cls.__dataclass_fields__
        flat: dict = {}
        for key, value in data.items():
            if key in known:
                # Direct field — must NOT be flattened, or dict-valued
                # fields like phase_weights would be lost.
                flat[key] = value
            elif isinstance(value, dict):
                # Section container: merge its leaves into the flat view.
                flat.update(value)
        return cls(**{k: v for k, v in flat.items() if k in known})
# ── RAE Dataset ───────────────────────────────────────────────
class RAEDataset(Dataset):
    """Torch dataset over RAE-structured JSONL conversations.

    Each line of the file is a JSON object carrying a ``messages`` list in
    chat format. Labels are masked with -100 so loss is only computed on
    the assistant's RAE response (from the <SATURATION> tag onward).
    """

    def __init__(self, path: str, tokenizer, max_length: int = 4096):
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(path) as fh:
            self.examples = [json.loads(row) for row in fh]
        logger.info(f"Loaded {len(self.examples)} RAE examples from {path}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]
        # Render the full conversation through the model's chat template.
        rendered = self.tokenizer.apply_chat_template(
            record["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        enc = self.tokenizer(
            rendered,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        token_ids = enc["input_ids"].squeeze(0)
        mask = enc["attention_mask"].squeeze(0)
        # Autoregressive targets; everything before the assistant's RAE
        # response is masked out so no loss is taken on prompt tokens.
        targets = token_ids.clone()
        start = self._find_assistant_start(token_ids)
        if start > 0:
            targets[:start] = -100
        return {
            "input_ids": token_ids,
            "attention_mask": mask,
            "labels": targets,
        }

    def _find_assistant_start(self, input_ids: torch.Tensor) -> int:
        """Return the index of the first token of the RAE response.

        Scans for the token sequence of "<SATURATION>"; if absent, falls
        back to treating the first ~30% of tokens as system/user prompt.
        """
        marker = self.tokenizer.encode("<SATURATION>", add_special_tokens=False)
        seq = input_ids.tolist()
        width = len(marker)
        for pos in range(len(seq) - width + 1):
            if seq[pos:pos + width] == marker:
                return pos
        # No tag found: heuristic prompt-length fallback.
        return int(len(seq) * 0.3)
# ── Custom RAE Trainer ────────────────────────────────────────
class RAETrainer(Trainer):
    """
    HuggingFace Trainer specialized with the RAE multi-objective loss.

    The "handwriting effect" is implemented here:
    - phase-weighted cross-entropy forces differential encoding depth
    - a coherence term binds abstraction back to saturation
    - a compression term rewards information distillation
    """

    def __init__(self, rae_config: RAETrainingConfig, **kwargs):
        super().__init__(**kwargs)
        self.rae_config = rae_config
        if not rae_config.rae_loss_enabled:
            # Plain SFT path: fall through to the stock Trainer loss.
            self.rae_loss_fn = None
            self.phase_tokenizer = None
            return
        self.rae_loss_fn = RAELoss(
            phase_weights=rae_config.phase_weights,
            coherence_weight=rae_config.coherence_weight,
            compression_weight=rae_config.compression_weight,
        )
        self.phase_tokenizer = RAEPhaseTokenizer(self.tokenizer)
        logger.info("RAE multi-objective loss enabled")
        logger.info(f" Phase weights: {rae_config.phase_weights}")
        logger.info(f" Coherence weight: {rae_config.coherence_weight}")
        logger.info(f" Compression weight: {rae_config.compression_weight}")

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute either the stock CE loss or the RAE multi-objective loss."""
        if not self.rae_config.rae_loss_enabled:
            return super().compute_loss(model, inputs, return_outputs, **kwargs)
        # Hidden states are required for the coherence term.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True,
        )
        phase_masks = self.phase_tokenizer.get_phase_masks(inputs["input_ids"])
        last_hidden = outputs.hidden_states[-1] if outputs.hidden_states else None
        loss_dict = self.rae_loss_fn(
            outputs.logits, inputs["labels"], phase_masks, last_hidden
        )
        # Emit per-phase metrics on the normal logging cadence.
        if self.state.global_step % self.args.logging_steps == 0:
            for phase, loss_val in loss_dict["phase_losses"].items():
                self.log({f"rae/{phase}_loss": loss_val})
            self.log({
                "rae/coherence_loss": loss_dict["coherence"].item(),
                "rae/compression_loss": loss_dict["compression"].item(),
                "rae/weighted_ce": loss_dict["weighted_ce"].item(),
            })
        total = loss_dict["total"]
        return (total, outputs) if return_outputs else total
# ── Main Training Pipeline ────────────────────────────────────
def load_model_and_tokenizer(config: RAETrainingConfig):
    """Load the base model with QLoRA (4-bit NF4 + LoRA adapters) and tokenizer.

    Args:
        config: training configuration; only model/quantization/LoRA and
            tokenizer-related fields are read here.

    Returns:
        (model, tokenizer): PEFT-wrapped causal LM and its tokenizer with a
        guaranteed pad token and right-side padding.
    """
    logger.info(f"Loading base model: {config.base_model}")
    # 4-bit NF4 with double quantization — the standard QLoRA recipe.
    bnb_config = None
    if config.quantization == "int4":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=getattr(torch, config.torch_dtype),
            bnb_4bit_use_double_quant=True,
        )
    model_kwargs = {
        "quantization_config": bnb_config,
        "torch_dtype": getattr(torch, config.torch_dtype),
        "trust_remote_code": config.trust_remote_code,
        "device_map": "auto",
    }
    # Prefer the configured attention backend (flash-attn 2); retry without
    # it on failure. Log the actual exception so a non-attention failure
    # (bad model id, OOM, auth) is not silently misattributed.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            attn_implementation=config.attn_implementation,
            **model_kwargs,
        )
    except Exception as exc:  # broad by design: flash-attn probing is best-effort
        logger.warning("Flash Attention not available, using default attention")
        logger.debug(
            "attn_implementation=%s failed: %s", config.attn_implementation, exc
        )
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            **model_kwargs,
        )
    # Casts norms / enables gradient checkpointing etc. for k-bit training.
    model = prepare_model_for_kbit_training(model)
    # Attach LoRA adapters on attention and MLP projections.
    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        task_type=TaskType.CAUSAL_LM,
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model,
        trust_remote_code=config.trust_remote_code,
    )
    # Some base models ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return model, tokenizer
def train(config_path: str = "configs/rae_training_config.json"):
    """Run the end-to-end RAE training pipeline from a JSON config.

    Loads the config, prepares the QLoRA model and RAE datasets, trains
    with RAETrainer, saves the result, and optionally pushes to the Hub.
    """
    cfg = RAETrainingConfig.from_json(config_path)

    rule = "=" * 60
    logger.info(rule)
    logger.info(" RAE TRAINING METHODOLOGY")
    logger.info(" Recursive Abstraction Engine as Training-Time")
    logger.info(" Cognitive Installation")
    logger.info(rule)
    logger.info(f" Base model: {cfg.base_model}")
    logger.info(f" RAE loss: {'ENABLED' if cfg.rae_loss_enabled else 'disabled'}")
    logger.info(f" LoRA rank: {cfg.lora_r}")
    logger.info(f" Epochs: {cfg.epochs}")
    logger.info(f" Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps}")
    logger.info(rule)

    model, tokenizer = load_model_and_tokenizer(cfg)

    train_ds = RAEDataset(cfg.train_path, tokenizer, cfg.max_seq_length)
    eval_ds = RAEDataset(cfg.eval_path, tokenizer, cfg.max_seq_length)

    # Pads inputs AND labels (labels with -100) to a multiple of 8.
    collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        max_length=cfg.max_seq_length,
        pad_to_multiple_of=8,
    )

    hf_args = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.epochs,
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        learning_rate=cfg.learning_rate,
        lr_scheduler_type=cfg.lr_scheduler,
        warmup_ratio=cfg.warmup_ratio,
        weight_decay=cfg.weight_decay,
        max_grad_norm=cfg.max_grad_norm,
        bf16=cfg.bf16,
        logging_steps=cfg.logging_steps,
        eval_strategy="steps",
        eval_steps=cfg.eval_steps,
        save_strategy="steps",
        save_steps=cfg.save_steps,
        save_total_limit=cfg.save_total_limit,
        load_best_model_at_end=True,
        report_to=["tensorboard", "wandb"],
        # Keep all dataset columns; compute_loss reads them directly.
        remove_unused_columns=False,
        push_to_hub=cfg.push_to_hub,
        hub_model_id=cfg.hub_model_id if cfg.push_to_hub else None,
        hub_token=os.environ.get("HF_TOKEN"),
    )

    trainer = RAETrainer(
        rae_config=cfg,
        model=model,
        args=hf_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=collator,
        tokenizer=tokenizer,
    )

    logger.info("\n🧠 Beginning RAE Training...")
    logger.info(" The hand is slow so the mind can be fast later.\n")
    trainer.train()

    logger.info("Saving final model...")
    trainer.save_model(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)

    if cfg.push_to_hub:
        logger.info(f"Pushing to HuggingFace Hub: {cfg.hub_model_id}")
        trainer.push_to_hub()

    logger.info("\n" + rule)
    logger.info(" RAE Training Complete")
    logger.info(f" Model saved: {cfg.output_dir}")
    logger.info(rule)
if __name__ == "__main__":
    # Optional CLI arg: path to the training config JSON.
    cli_args = sys.argv[1:]
    train(cli_args[0] if cli_args else "configs/rae_training_config.json")