File size: 14,800 Bytes
2403e59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
"""
RAE Trainer β€” Custom Training Loop
═══════════════════════════════════════════════════════════════
Extends HuggingFace's Trainer with RAE multi-objective loss.

This is the FULL CONTROL path. The training loop:
1. Loads base model with QLoRA
2. Applies RAE-structured training data
3. Computes multi-phase weighted loss
4. Tracks per-phase metrics (saturation, abstraction, descent, integration)
5. Pushes trained model to HuggingFace Hub

The key difference from standard SFT:
- Loss is NOT uniform across tokens
- Abstraction + Descent phases get higher loss weight
- Coherence loss penalizes abstraction that diverges from saturation
- Compression loss rewards shorter abstractions
═══════════════════════════════════════════════════════════════
"""

import json
import os
import sys
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional

import torch
from torch.utils.data import Dataset
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)
from datasets import load_dataset

from rae_loss import RAELoss, RAEPhaseTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rae_trainer")


# ── Configuration ─────────────────────────────────────────────

@dataclass
class RAETrainingConfig:
    """Configuration for RAE training run.

    Field groups mirror the sections of the JSON config file; see
    `from_json` for how nested configs are flattened into these fields.
    """
    
    # Model
    base_model: str = "Qwen/Qwen2.5-7B-Instruct"
    quantization: str = "int4"           # "int4" enables 4-bit NF4 QLoRA loading
    torch_dtype: str = "bfloat16"        # name of a torch dtype attribute
    attn_implementation: str = "flash_attention_2"
    trust_remote_code: bool = True
    
    # LoRA
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    lora_target_modules: list = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ])
    
    # Data
    train_path: str = "data/rae_training_data/train.jsonl"
    eval_path: str = "data/rae_training_data/validation.jsonl"
    max_seq_length: int = 4096
    
    # Training
    epochs: int = 3
    batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 5e-6
    lr_scheduler: str = "cosine"
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    bf16: bool = True
    logging_steps: int = 10
    eval_steps: int = 100
    save_steps: int = 200
    save_total_limit: int = 3
    
    # RAE Loss
    rae_loss_enabled: bool = True
    phase_weights: dict = field(default_factory=lambda: {
        "saturation": 1.0,
        "abstraction": 1.5,
        "descent": 1.5,
        "integration": 1.0,
    })
    coherence_weight: float = 0.3
    compression_weight: float = 0.2
    
    # Output
    output_dir: str = "outputs/rae-trained-model"
    push_to_hub: bool = True
    hub_model_id: str = "rae-cognitive-model"
    
    @classmethod
    def from_json(cls, path: str) -> "RAETrainingConfig":
        """Build a config from a JSON file.

        Accepts both layouts:
        - nested: ``{"model": {...}, "training": {...}}`` — each dict-valued
          top-level section is flattened one level into the field namespace;
        - flat: scalar top-level keys are used directly.

        Previously, scalar top-level keys were silently dropped (only dict
        sections were read); both forms are now honored. Keys that do not
        match a dataclass field are ignored.
        """
        with open(path) as f:
            data = json.load(f)
        
        # Flatten nested config; keep flat top-level entries as-is.
        flat = {}
        for key, section in data.items():
            if isinstance(section, dict):
                flat.update(section)
            else:
                flat[key] = section
        
        return cls(**{k: v for k, v in flat.items() 
                     if k in cls.__dataclass_fields__ and not k.startswith("_")})


# ── RAE Dataset ───────────────────────────────────────────────

class RAEDataset(Dataset):
    """Dataset that loads RAE-structured JSONL data.

    Each record must carry a ``messages`` list compatible with the
    tokenizer's chat template. Loss labels are masked (-100) on every
    token before the assistant's RAE response.
    """
    
    def __init__(self, path: str, tokenizer, max_length: int = 4096):
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(path) as f:
            self.examples = [json.loads(line) for line in f]
        
        logger.info(f"Loaded {len(self.examples)} RAE examples from {path}")
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        record = self.examples[idx]
        
        # Render the conversation through the model's chat template.
        rendered = self.tokenizer.apply_chat_template(
            record["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        
        encoded = self.tokenizer(
            rendered,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        
        # Autoregressive labels; mask everything before the assistant's
        # RAE response so loss is computed only on the response itself.
        labels = input_ids.clone()
        boundary = self._find_assistant_start(input_ids)
        if boundary > 0:
            labels[:boundary] = -100
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
    
    def _find_assistant_start(self, input_ids: torch.Tensor) -> int:
        """Find where the assistant's RAE response begins."""
        # The RAE response opens with a literal <SATURATION> tag; locate
        # its token subsequence in the encoded ids.
        marker = self.tokenizer.encode("<SATURATION>", add_special_tokens=False)
        ids = input_ids.tolist()
        span = len(marker)
        
        hit = next(
            (i for i in range(len(ids) - span + 1) if ids[i:i + span] == marker),
            None,
        )
        if hit is not None:
            return hit
        
        # Heuristic fallback: treat the first 30% as system/user tokens.
        return int(len(ids) * 0.3)


# ── Custom RAE Trainer ────────────────────────────────────────

class RAETrainer(Trainer):
    """
    Extended Trainer with RAE multi-objective loss.
    
    This is where the handwriting effect is implemented:
    - Phase-weighted loss forces differential encoding depth
    - Coherence loss creates cross-phase binding
    - Compression loss rewards information distillation
    """
    
    def __init__(self, rae_config: RAETrainingConfig, **kwargs):
        super().__init__(**kwargs)
        self.rae_config = rae_config
        
        # Guard clause: without RAE loss, fall through to vanilla Trainer.
        if not rae_config.rae_loss_enabled:
            self.rae_loss_fn = None
            self.phase_tokenizer = None
            return
        
        self.rae_loss_fn = RAELoss(
            phase_weights=rae_config.phase_weights,
            coherence_weight=rae_config.coherence_weight,
            compression_weight=rae_config.compression_weight,
        )
        self.phase_tokenizer = RAEPhaseTokenizer(self.tokenizer)
        logger.info("RAE multi-objective loss enabled")
        logger.info(f"  Phase weights: {rae_config.phase_weights}")
        logger.info(f"  Coherence weight: {rae_config.coherence_weight}")
        logger.info(f"  Compression weight: {rae_config.compression_weight}")
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Override compute_loss to use RAE multi-objective loss."""
        
        if not self.rae_config.rae_loss_enabled:
            return super().compute_loss(model, inputs, return_outputs, **kwargs)
        
        # Request hidden states: the coherence term needs them.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True,
        )
        
        # Per-phase token masks drive the differential loss weighting.
        phase_masks = self.phase_tokenizer.get_phase_masks(inputs["input_ids"])
        
        # Last layer's hidden state feeds the coherence loss (None if absent).
        last_hidden = outputs.hidden_states[-1] if outputs.hidden_states else None
        
        loss_dict = self.rae_loss_fn(
            outputs.logits, inputs["labels"], phase_masks, last_hidden
        )
        
        # Surface per-phase diagnostics on the logging cadence.
        # NOTE(review): phase loss values are logged as-is while the scalar
        # terms call .item() — presumably RAELoss returns floats for
        # phase_losses; confirm against rae_loss.py.
        if self.state.global_step % self.args.logging_steps == 0:
            for name, value in loss_dict["phase_losses"].items():
                self.log({f"rae/{name}_loss": value})
            self.log({
                "rae/coherence_loss": loss_dict["coherence"].item(),
                "rae/compression_loss": loss_dict["compression"].item(),
                "rae/weighted_ce": loss_dict["weighted_ce"].item(),
            })
        
        total = loss_dict["total"]
        if return_outputs:
            return total, outputs
        return total


# ── Main Training Pipeline ────────────────────────────────────

def load_model_and_tokenizer(config: RAETrainingConfig):
    """Load and configure the base model with QLoRA.

    Returns a ``(model, tokenizer)`` pair: the base model is optionally
    4-bit quantized, prepared for k-bit training, and wrapped with LoRA
    adapters; the tokenizer is configured for right-padding.
    """
    
    logger.info(f"Loading base model: {config.base_model}")
    
    compute_dtype = getattr(torch, config.torch_dtype)
    
    # Optional 4-bit NF4 quantization (the "Q" in QLoRA).
    quant_config = None
    if config.quantization == "int4":
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
        )
    
    shared_kwargs = dict(
        quantization_config=quant_config,
        torch_dtype=compute_dtype,
        trust_remote_code=config.trust_remote_code,
        device_map="auto",
    )
    
    # Prefer Flash Attention; any failure falls back to default attention.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            attn_implementation=config.attn_implementation,
            **shared_kwargs,
        )
    except Exception:
        logger.warning("Flash Attention not available, using default attention")
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model,
            **shared_kwargs,
        )
    
    # Prepare the quantized model for training, then attach LoRA adapters.
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(
        model,
        LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.lora_target_modules,
            task_type=TaskType.CAUSAL_LM,
            bias="none",
        ),
    )
    model.print_trainable_parameters()
    
    # Tokenizer: ensure a pad token exists and pad on the right.
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model,
        trust_remote_code=config.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    return model, tokenizer


def train(config_path: str = "configs/rae_training_config.json"):
    """Execute the full RAE training pipeline.

    Loads the config, model, and datasets; runs RAETrainer; saves the
    final model locally and optionally pushes it to the HuggingFace Hub.
    """
    
    cfg = RAETrainingConfig.from_json(config_path)
    
    # Banner with the key knobs for this run.
    logger.info("=" * 60)
    logger.info("  RAE TRAINING METHODOLOGY")
    logger.info("  Recursive Abstraction Engine as Training-Time")
    logger.info("  Cognitive Installation")
    logger.info("=" * 60)
    logger.info(f"  Base model: {cfg.base_model}")
    logger.info(f"  RAE loss: {'ENABLED' if cfg.rae_loss_enabled else 'disabled'}")
    logger.info(f"  LoRA rank: {cfg.lora_r}")
    logger.info(f"  Epochs: {cfg.epochs}")
    logger.info(f"  Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps}")
    logger.info("=" * 60)
    
    model, tokenizer = load_model_and_tokenizer(cfg)
    
    train_dataset = RAEDataset(cfg.train_path, tokenizer, cfg.max_seq_length)
    eval_dataset = RAEDataset(cfg.eval_path, tokenizer, cfg.max_seq_length)
    
    # Dynamic padding per batch, aligned to multiples of 8 for tensor cores.
    collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        max_length=cfg.max_seq_length,
        pad_to_multiple_of=8,
    )
    
    args = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.epochs,
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        learning_rate=cfg.learning_rate,
        lr_scheduler_type=cfg.lr_scheduler,
        warmup_ratio=cfg.warmup_ratio,
        weight_decay=cfg.weight_decay,
        max_grad_norm=cfg.max_grad_norm,
        bf16=cfg.bf16,
        logging_steps=cfg.logging_steps,
        eval_strategy="steps",
        eval_steps=cfg.eval_steps,
        save_strategy="steps",
        save_steps=cfg.save_steps,
        save_total_limit=cfg.save_total_limit,
        load_best_model_at_end=True,
        report_to=["tensorboard", "wandb"],
        # Custom phase masks need the raw columns; don't strip them.
        remove_unused_columns=False,
        push_to_hub=cfg.push_to_hub,
        hub_model_id=cfg.hub_model_id if cfg.push_to_hub else None,
        hub_token=os.environ.get("HF_TOKEN"),
    )
    
    trainer = RAETrainer(
        rae_config=cfg,
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
        tokenizer=tokenizer,
    )
    
    logger.info("\n🧠 Beginning RAE Training...")
    logger.info("   The hand is slow so the mind can be fast later.\n")
    
    trainer.train()
    
    # Persist the final adapter weights and tokenizer.
    logger.info("Saving final model...")
    trainer.save_model(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)
    
    if cfg.push_to_hub:
        logger.info(f"Pushing to HuggingFace Hub: {cfg.hub_model_id}")
        trainer.push_to_hub()
    
    logger.info("\n" + "=" * 60)
    logger.info("  RAE Training Complete")
    logger.info(f"  Model saved: {cfg.output_dir}")
    logger.info("=" * 60)

if __name__ == "__main__":
    # First CLI argument overrides the default config path.
    cli_args = sys.argv[1:]
    train(cli_args[0] if cli_args else "configs/rae_training_config.json")