MetaCortex-Dynamics committed on
Commit
192065b
·
verified ·
1 Parent(s): 83b737d

Create pipeline/mdlm/train.py

Browse files
Files changed (1) hide show
  1. pipeline/mdlm/train.py +179 -0
pipeline/mdlm/train.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
MDLM Training Loop — Train on governed corpus, measure convergence.

Usage:
python -m pipeline.mdlm.train --corpus corpus/CORPUS-FINAL --schedule A --epochs 50

Convergence metric: per-epoch accuracy on masked token prediction.
PHI_CONTRACTION_RATE threshold: loss ratio between consecutive epochs.
"""

from __future__ import annotations

import argparse
import json
import math
import random
import sys
import time
from pathlib import Path

# NOTE(review): math, time, and Path are imported but not referenced anywhere
# in this module — confirm they are not needed before removing.

# Re-encode stdout as UTF-8 with replacement so printing never raises
# UnicodeEncodeError on consoles with a narrower default encoding.
sys.stdout.reconfigure(encoding="utf-8", errors="replace")

# PyTorch is optional at import time: the module still imports without it,
# and main() exits early when HAS_TORCH is False.
try:
    import torch
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
    print("PyTorch not available. Install with: pip install torch")

from pipeline.mdlm.tokenizer import load_corpus, pad_sequence, VOCAB_SIZE, decode
from pipeline.mdlm.model import (
    MaskingSchedule, StructureModel, compute_loss, generate,
)

# Golden-ratio-derived contraction threshold, (3 - sqrt(5)) / 2 ≈ 0.382.
# Epochs whose loss ratio drops below this are classified CONTRACTING in train().
PHI_CONTRACTION_RATE = 0.381966  # (3 - sqrt(5)) / 2
40
def train(
    corpus_dir: str,
    schedule: MaskingSchedule,
    epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-3,
    total_timesteps: int = 100,
    max_len: int = 40,
    device_name: str = "cpu",
):
    """Train MDLM on governed corpus.

    Loads and pads the corpus, trains a StructureModel with masked-token
    loss under the given masking schedule, classifies each epoch's
    loss-contraction regime, prints 5 generated samples and a summary.

    Args:
        corpus_dir: Directory passed to load_corpus().
        schedule: Masking schedule variant (project enum).
        epochs: Number of training epochs.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        total_timesteps: Upper bound for the per-batch random timestep.
        max_len: Fixed sequence length after padding.
        device_name: torch device string, e.g. "cpu" or "cuda".

    Returns:
        List of per-epoch dicts: {"epoch", "loss", "ratio", "regime"}.
    """
    print(f"=== MDLM Training ===")
    print(f"Schedule: {schedule.value} ({schedule.name})")
    print(f"Corpus: {corpus_dir}")
    print(f"Epochs: {epochs}, Batch: {batch_size}, LR: {lr}")
    print()

    # Load and encode corpus
    sequences = load_corpus(corpus_dir)
    print(f"Loaded {len(sequences)} sequences")

    # Pad to fixed length.
    # NOTE(review): sequences longer than max_len rely on pad_sequence()
    # truncating — confirm against the tokenizer module.
    padded = [pad_sequence(seq, max_len) for seq in sequences]
    data = torch.tensor(padded, dtype=torch.long)
    dataset = TensorDataset(data)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model
    device = torch.device(device_name)
    model = StructureModel(
        vocab_size=VOCAB_SIZE, d_model=128, nhead=4,
        num_layers=4, max_len=max_len,
    ).to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")
    print()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    history = []
    prev_loss = float("inf")  # sentinel: no previous epoch yet

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        epoch_batches = 0

        for (batch,) in loader:
            batch = batch.to(device)
            # One random timestep shared by the whole batch.
            # NOTE(review): randint is inclusive on both ends, so timestep
            # can be 0 or total_timesteps — confirm compute_loss() handles
            # both extremes (e.g. a fully-unmasked t=0).
            timestep = random.randint(0, total_timesteps)

            loss = compute_loss(model, batch, schedule, timestep, total_timesteps)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_batches += 1

        # max(..., 1) guards against an empty loader (zero-division).
        avg_loss = epoch_loss / max(epoch_batches, 1)

        # Convergence ratio: loss_t / loss_{t-1}
        ratio = avg_loss / prev_loss if prev_loss > 0 and prev_loss != float("inf") else 1.0
        prev_loss = avg_loss

        # Regime classification by how fast the loss is shrinking.
        # NOTE(review): epoch 1 always gets ratio == 1.0 and therefore lands
        # in the DIVERGENT branch below, although no comparison epoch exists
        # — consider whether a neutral label is intended.
        if ratio < PHI_CONTRACTION_RATE:
            regime = "CONTRACTING"
        elif ratio < 0.8:
            regime = "BOUNDARY"
        elif ratio < 1.0:
            regime = "CRITICAL"
        else:
            regime = "DIVERGENT"

        history.append({
            "epoch": epoch,
            "loss": avg_loss,
            "ratio": ratio,
            "regime": regime,
        })

        # Log the first 5 epochs, every 5th epoch, and the last one.
        if epoch <= 5 or epoch % 5 == 0 or epoch == epochs:
            print(f" Epoch {epoch:3d}: loss={avg_loss:.4f} ratio={ratio:.4f} [{regime}]")

    print()

    # Generate samples
    # NOTE(review): generate() is called outside torch.no_grad(); harmless if
    # generate() disables grad internally — confirm in the model module.
    print("Generating 5 samples...")
    model.eval()
    samples = generate(model, 5, max_len, schedule, total_timesteps)
    for i in range(5):
        seq = samples[i].tolist()
        print(f" Sample {i}: {decode(seq)}")
    print()

    # Summary
    final = history[-1]
    contracting_epochs = sum(1 for h in history if h["regime"] == "CONTRACTING")
    print(f"=== SUMMARY ===")
    print(f"Final loss: {final['loss']:.4f}")
    print(f"Final ratio: {final['ratio']:.4f} ({final['regime']})")
    print(f"Contracting epochs: {contracting_epochs}/{len(history)}")
    print(f"Schedule: {schedule.value} ({schedule.name})")

    return history
150
+
151
+
152
def main():
    """CLI entry point: parse arguments, run training, optionally save history.

    Exits with status 1 when PyTorch is unavailable (the install hint was
    already printed at import time).
    """
    if not HAS_TORCH:
        sys.exit(1)

    parser = argparse.ArgumentParser(description="MDLM Training")
    parser.add_argument("--corpus", required=True, help="Corpus directory")
    parser.add_argument("--schedule", default="A", choices=["A", "B", "C", "D"])
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    # Expose the remaining train() knobs; defaults mirror train()'s own
    # defaults, so all existing invocations behave identically.
    parser.add_argument("--total-timesteps", type=int, default=100,
                        help="Upper bound for the per-batch masking timestep")
    parser.add_argument("--max-len", type=int, default=40,
                        help="Fixed sequence length after padding")
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--output", default=None, help="Save history to JSON")
    args = parser.parse_args()

    schedule = MaskingSchedule(args.schedule)
    history = train(
        args.corpus, schedule, args.epochs, args.batch_size, args.lr,
        total_timesteps=args.total_timesteps, max_len=args.max_len,
        device_name=args.device,
    )

    if args.output:
        # Pin the encoding so the history file is identical across platforms
        # instead of depending on the locale's default.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
        print(f"History saved to {args.output}")


if __name__ == "__main__":
    main()