Spaces:
Running
Running
File size: 5,265 Bytes
"""
MDLM Training Loop — Train on governed corpus, measure convergence.
Usage:
python -m pipeline.mdlm.train --corpus corpus/CORPUS-FINAL --schedule A --epochs 50
Convergence metric: per-epoch accuracy on masked token prediction.
PHI_CONTRACTION_RATE threshold: loss ratio between consecutive epochs.
"""
from __future__ import annotations
import argparse
import json
import math
import random
import sys
import time
from pathlib import Path
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
try:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
print("PyTorch not available. Install with: pip install torch")
from pipeline.mdlm.tokenizer import load_corpus, pad_sequence, VOCAB_SIZE, decode
from pipeline.mdlm.model import (
MaskingSchedule, StructureModel, compute_loss, generate,
)
PHI_CONTRACTION_RATE = 0.381966 # (3 - sqrt(5)) / 2
def train(
    corpus_dir: str,
    schedule: MaskingSchedule,
    epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-3,
    total_timesteps: int = 100,
    max_len: int = 40,
    device_name: str = "cpu",
):
    """Train the MDLM on a governed corpus and measure convergence.

    Args:
        corpus_dir: Directory containing the corpus to load.
        schedule: Masking schedule used for the diffusion noising.
        epochs: Number of passes over the corpus.
        batch_size: Minibatch size for the DataLoader.
        lr: Adam learning rate.
        total_timesteps: Total diffusion timesteps; one is drawn uniformly
            per batch.
        max_len: Fixed sequence length; shorter sequences are padded.
        device_name: torch device string ("cpu", "cuda", ...).

    Returns:
        Per-epoch history: a list of dicts with keys "epoch", "loss",
        "ratio", and "regime".
    """
    print("=== MDLM Training ===")
    print(f"Schedule: {schedule.value} ({schedule.name})")
    print(f"Corpus: {corpus_dir}")
    print(f"Epochs: {epochs}, Batch: {batch_size}, LR: {lr}")
    print()

    # Load, encode, and pad the corpus to a fixed length.
    sequences = load_corpus(corpus_dir)
    print(f"Loaded {len(sequences)} sequences")
    padded = [pad_sequence(seq, max_len) for seq in sequences]
    data = torch.tensor(padded, dtype=torch.long)
    dataset = TensorDataset(data)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model
    device = torch.device(device_name)
    model = StructureModel(
        vocab_size=VOCAB_SIZE, d_model=128, nhead=4,
        num_layers=4, max_len=max_len,
    ).to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")
    print()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    history = []
    prev_loss = float("inf")
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        epoch_batches = 0
        for (batch,) in loader:
            batch = batch.to(device)
            # Random diffusion timestep per batch.
            timestep = random.randint(0, total_timesteps)
            loss = compute_loss(model, batch, schedule, timestep, total_timesteps)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_batches += 1
        avg_loss = epoch_loss / max(epoch_batches, 1)

        # Convergence ratio: loss_t / loss_{t-1} (1.0 on the first epoch).
        if prev_loss > 0 and not math.isinf(prev_loss):
            ratio = avg_loss / prev_loss
        else:
            ratio = 1.0
        prev_loss = avg_loss

        # Regime classification against the phi contraction threshold.
        if ratio < PHI_CONTRACTION_RATE:
            regime = "CONTRACTING"
        elif ratio < 0.8:
            regime = "BOUNDARY"
        elif ratio < 1.0:
            regime = "CRITICAL"
        else:
            regime = "DIVERGENT"

        history.append({
            "epoch": epoch,
            "loss": avg_loss,
            "ratio": ratio,
            "regime": regime,
        })
        # Log the first 5 epochs, every 5th epoch, and the final one.
        if epoch <= 5 or epoch % 5 == 0 or epoch == epochs:
            print(f" Epoch {epoch:3d}: loss={avg_loss:.4f} ratio={ratio:.4f} [{regime}]")
    print()

    # Generate samples. Inference only — disable autograd so no graph
    # is built during the (potentially long) generation loop.
    print("Generating 5 samples...")
    model.eval()
    with torch.no_grad():
        samples = generate(model, 5, max_len, schedule, total_timesteps)
    for i in range(5):
        seq = samples[i].tolist()
        print(f" Sample {i}: {decode(seq)}")
    print()

    # Summary
    final = history[-1]
    contracting_epochs = sum(1 for h in history if h["regime"] == "CONTRACTING")
    print("=== SUMMARY ===")
    print(f"Final loss: {final['loss']:.4f}")
    print(f"Final ratio: {final['ratio']:.4f} ({final['regime']})")
    print(f"Contracting epochs: {contracting_epochs}/{len(history)}")
    print(f"Schedule: {schedule.value} ({schedule.name})")
    return history
def main():
    """CLI entry point: parse arguments, train, optionally save history to JSON."""
    if not HAS_TORCH:
        # The install hint was already printed at import time.
        sys.exit(1)
    parser = argparse.ArgumentParser(description="MDLM Training")
    parser.add_argument("--corpus", required=True, help="Corpus directory")
    parser.add_argument("--schedule", default="A", choices=["A", "B", "C", "D"])
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--output", default=None, help="Save history to JSON")
    args = parser.parse_args()

    schedule = MaskingSchedule(args.schedule)
    history = train(
        args.corpus, schedule, args.epochs, args.batch_size,
        args.lr, device_name=args.device,
    )
    if args.output:
        # Explicit encoding so the dump is identical across platforms.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
        print(f"History saved to {args.output}")


if __name__ == "__main__":
    main()