"""
AAM Diffusion LLM — minimal CPU training script.

Generates synthetic data, trains a BPE tokenizer, builds a very small
AAM Diffusion LLM, runs a short training loop on CPU, and saves the
checkpoint.
"""

import sys
import json
import time
import logging
from pathlib import Path

# Make the package importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("train")


def main():
    # Imports are deferred so the sys.path tweak above is applied first.
    from diffusion_llm.config.model_config import (
        AamDiffusionConfig, ModelConfig, DiffusionConfig,
        GraphEncoderConfig, TokenizerConfig, TrainingConfig, InferenceConfig,
    )
    from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
    from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
    from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
    from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
    from torch.utils.data import DataLoader

    output_dir = Path("./aam-diffusion-v1")
    output_dir.mkdir(parents=True, exist_ok=True)
    data_dir = output_dir / "data"
    data_dir.mkdir(parents=True, exist_ok=True)
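
    # STEP 1 writes train/val JSONL splits of synthetic graph-narrative
    # examples; language="id" presumably selects Indonesian text.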
    logger.info("STEP 1: Generating synthetic data...")
    train_path, val_path = SyntheticDataGenerator.generate_training_split(
        output_dir=data_dir, n_train=200, n_val=20, language="id", seed=42,
    )
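
    # The tokenizer is a small BPE model trained on every text field found
    # in the synthetic training JSONL.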
    logger.info("STEP 2: Training tokenizer...")
    tokenizer = AamTokenizer()

    texts = []
    with open(train_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            for key in ["narrative", "trigger"]:
                if data.get(key):
                    texts.append(data[key])
            for key in ["evidence_nodes", "anomalies", "reasoning_steps"]:
                texts.extend(data.get(key, []))

    tokenizer.train(texts, vocab_size=2000)
    tokenizer.save(data_dir / "tokenizer.json")
    actual_vocab = tokenizer.vocab_size
    logger.info(f" Tokenizer: vocab_size={actual_vocab}, merges={len(tokenizer.merges)}")

    config = AamDiffusionConfig(
        model=ModelConfig(
            d_model=128,
            n_layers=2,
            n_heads=4,
            d_ff=256,
            vocab_size=actual_vocab,
            max_seq_len=64,
            pos_encoding_type="learned",
            use_flash_attention=False,
            norm_type="layernorm",
            init_std=0.02,
        ),
        diffusion=DiffusionConfig(
            n_timesteps=100,
            n_inference_steps=10,
            schedule_type="cosine",
            prediction_type="epsilon",
            loss_type="mse",
            loss_weighting="none",
        ),
        graph_encoder=GraphEncoderConfig(
            d_graph=64,
            n_graph_layers=1,
            n_graph_heads=2,
            max_evidence_nodes=5,
            max_compositions=3,
            max_anomalies=3,
            max_reasoning_steps=3,
            conditioning_method="cross_attention",
            embed_confidence=False,
            embed_temporal=False,
        ),
        tokenizer=TokenizerConfig(bpe_vocab_size=2000),
        training=TrainingConfig(
            batch_size=4,
            learning_rate=1e-3,
            max_steps=100,
            warmup_steps=10,
            use_amp=False,
            num_workers=0,
            grad_clip_norm=1.0,
        ),
        inference=InferenceConfig(n_steps=10),
        model_name="aam-diffusion-v1.0",
        output_dir=str(output_dir),
        seed=42,
    )
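
    # Everything above is deliberately tiny (2 layers, d_model=128, 100 train
    # steps) so the full run finishes quickly on a laptop CPU.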

    logger.info("STEP 3: Creating model...")
    model = AamDiffusionModel(config)
    n_params = model.get_num_params()
    logger.info(f" Parameters: {model._format_params(n_params)} ({n_params:,})")

    logger.info("STEP 4: Creating dataloaders...")
    train_dataset = GraphNarrativeDataset(
        data_path=train_path, tokenizer=tokenizer,
        max_seq_len=config.model.max_seq_len,
        max_evidence=config.graph_encoder.max_evidence_nodes,
        max_anomalies=config.graph_encoder.max_anomalies,
        max_reasoning=config.graph_encoder.max_reasoning_steps,
        augment=True,
    )
    val_dataset = GraphNarrativeDataset(
        data_path=val_path, tokenizer=tokenizer,
        max_seq_len=config.model.max_seq_len,
        max_evidence=config.graph_encoder.max_evidence_nodes,
        max_anomalies=config.graph_encoder.max_anomalies,
        max_reasoning=config.graph_encoder.max_reasoning_steps,
        augment=False,
    )

    # Read batch size and worker count from the config instead of repeating
    # the literals.
    train_loader = DataLoader(
        train_dataset, batch_size=config.training.batch_size, shuffle=True,
        num_workers=config.training.num_workers, collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.training.batch_size, shuffle=False,
        num_workers=config.training.num_workers, collate_fn=collate_fn,
    )

    logger.info("STEP 5: Training...")
    torch.manual_seed(config.seed)  # seed the torch RNG for reproducibility
    device = torch.device("cpu")
    model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.training.learning_rate, weight_decay=0.01,
    )
    max_steps = config.training.max_steps

    start_time = time.time()
    global_step = 0
    train_losses = []
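
    # The epoch count (50) is only an upper bound: the loop exits once
    # max_steps optimizer updates have been taken.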
    for epoch in range(50):
        model.train()
        for batch in train_loader:
            if global_step >= max_steps:
                break

            # Move tensors to the target device; leave other fields as-is.
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            # Sample one random diffusion timestep per example.
            batch_size = batch["token_ids"].shape[0]
            t = torch.randint(0, config.diffusion.n_timesteps, (batch_size,), device=device)

            predicted, target = model(
                token_ids=batch["token_ids"],
                timestep=t,
                evidence_ids=batch.get("evidence_ids"),
                evidence_confidence=batch.get("evidence_confidence"),
                anomaly_ids=batch.get("anomaly_ids"),
                anomaly_confidence=batch.get("anomaly_confidence"),
                reasoning_ids=batch.get("reasoning_ids"),
                reasoning_confidence=batch.get("reasoning_confidence"),
                source_trust=batch.get("source_trust"),
            )
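
            # With prediction_type="epsilon" the forward pass is assumed to
            # return the predicted noise and the true noise as (predicted,
            # target); compute_loss then applies the MSE from loss_type="mse".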
            loss = model.compute_loss(predicted, target, t)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip_norm)
            optimizer.step()

            train_losses.append(loss.item())
            global_step += 1

            if global_step % 10 == 0:
                avg = sum(train_losses[-10:]) / len(train_losses[-10:])
                elapsed = time.time() - start_time
                logger.info(f" Step {global_step}/{max_steps} | Loss: {avg:.4f} | Time: {elapsed:.1f}s")

        if global_step >= max_steps:
            break

    logger.info("STEP 6: Evaluating...")
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}
            batch_size = batch["token_ids"].shape[0]
            t = torch.randint(0, config.diffusion.n_timesteps, (batch_size,), device=device)

            predicted, target = model(
                token_ids=batch["token_ids"],
                timestep=t,
                evidence_ids=batch.get("evidence_ids"),
                evidence_confidence=batch.get("evidence_confidence"),
                anomaly_ids=batch.get("anomaly_ids"),
                anomaly_confidence=batch.get("anomaly_confidence"),
                reasoning_ids=batch.get("reasoning_ids"),
                reasoning_confidence=batch.get("reasoning_confidence"),
                source_trust=batch.get("source_trust"),
            )
            loss = model.compute_loss(predicted, target, t)
            val_losses.append(loss.item())

    avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else 0.0
    logger.info(f" Val loss: {avg_val_loss:.4f}")

    logger.info("STEP 7: Saving model...")

    # Save weights and config together so the checkpoint is self-describing.
    model_path = output_dir / "model.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": config.to_dict(),
    }, model_path)
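
    # Reload sketch (assumes AamDiffusionConfig.from_dict exists; adjust to
    # the real config API if it differs):
    #   ckpt = torch.load(model_path, map_location="cpu")
    #   model = AamDiffusionModel(AamDiffusionConfig.from_dict(ckpt["config"]))
    #   model.load_state_dict(ckpt["model_state_dict"])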

    # Also write the config as standalone JSON next to the checkpoint.
    config.to_json(output_dir / "config.json")

    elapsed = time.time() - start_time
    logger.info(f"\n DONE! {global_step} steps in {elapsed:.1f}s")
    logger.info(f" Final train loss: {train_losses[-1]:.4f}")
    logger.info(f" Val loss: {avg_val_loss:.4f}")
    logger.info(f" Parameters: {model._format_params(n_params)}")
    logger.info(f" Output: {output_dir}")

    return model, tokenizer, config, output_dir


if __name__ == "__main__":
    main()