amirali1985 commited on
Commit
fe8c900
·
verified ·
1 Parent(s): 38a78f2

Upload modular/code/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modular/code/train.py +274 -0
modular/code/train.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modular arithmetic SoRL training — matches Nanda et al. (2023) architecture.
3
+
4
+ Architecture: 1L / 4H / 128d / d_mlp=512 (Nanda's exact setup)
5
+ Dataset: all p²=12769 pairs, 30% train fixed (seed=42)
6
+
7
+ Usage:
8
+ python -m arithmetic.modular.training.train --mode baseline
9
+ python -m arithmetic.modular.training.train --mode sorl --K 1 --abs_vocab 30
10
+ """
11
+ import sys
12
+ import json
13
+ import argparse
14
+ from dataclasses import dataclass, asdict
15
+ from pathlib import Path
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.utils.data import DataLoader, TensorDataset
20
+ import matplotlib.pyplot as plt
21
+
22
+ try:
23
+ import wandb
24
+ WANDB_AVAILABLE = True
25
+ except ImportError:
26
+ WANDB_AVAILABLE = False
27
+
28
+ WANDB_PROJECT = "sorl-modular"
29
+ WANDB_ENTITY = "nlp_and_interpretability"
30
+
31
+ sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
32
+
33
+ from transformers import Qwen3Config
34
+ from sorl.sorl_wrapper import SorlModelWrapper
35
+ from sorl.sorl_trainer import sorl_search, SoRLLoss
36
+ from arithmetic.modular.data.modular import (
37
+ get_train_set, get_eval_set,
38
+ VOCAB_SIZE, PAD, PROMPT_LEN, P,
39
+ )
40
+ from arithmetic.modular.training.evaluate import ModularEvaluator
41
+
42
+
43
+ @dataclass
44
+ class ModularConfig:
45
+ # Task
46
+ p: int = P
47
+ mode: str = "sorl" # "baseline" or "sorl"
48
+
49
+ # Architecture — Nanda's exact setup
50
+ n_layer: int = 1
51
+ n_head: int = 4
52
+ n_embd: int = 128
53
+ d_mlp: int = 512
54
+
55
+ # SoRL
56
+ K: int = 1
57
+ abs_vocab: int = 30
58
+ alpha_info_gain: float = 10.0
59
+ alpha_abs: float = 0.1
60
+ alpha_soft_zipf: float = 1.0
61
+ n_rollouts: int = 2
62
+
63
+ # Optimizer
64
+ lr: float = 1e-3
65
+ weight_decay: float = 0.1
66
+ num_epochs: int = 500
67
+ batch_size: int = 256
68
+
69
+ # Eval / logging
70
+ eval_every: int = 50
71
+ log_every: int = 10
72
+ device: str = "cuda"
73
+ seed: int = 42
74
+ job_name: str = ""
75
+ out_dir: str = "" # if empty, defaults to arithmetic/modular/runs/<job_name>
76
+ no_wandb: bool = False
77
+
78
+
79
+ def make_model(cfg: ModularConfig) -> SorlModelWrapper:
80
+ config = Qwen3Config(
81
+ hidden_size=cfg.n_embd,
82
+ num_hidden_layers=cfg.n_layer,
83
+ num_attention_heads=cfg.n_head,
84
+ num_key_value_heads=cfg.n_head,
85
+ intermediate_size=cfg.d_mlp,
86
+ vocab_size=VOCAB_SIZE,
87
+ max_position_embeddings=32,
88
+ )
89
+ abs_v = cfg.abs_vocab if cfg.mode == "sorl" else 1
90
+ return SorlModelWrapper.from_scratch(config, [VOCAB_SIZE, abs_v], PAD)
91
+
92
+
93
+ def make_loader(examples, batch_size: int, shuffle: bool = True):
94
+ tokens = torch.tensor([ex.tokens for ex in examples], dtype=torch.long)
95
+ bs = len(examples) if batch_size == 0 else batch_size
96
+ return DataLoader(TensorDataset(tokens), batch_size=bs, shuffle=shuffle)
97
+
98
+
99
+ def compute_base_traj_loss(model, ids: torch.Tensor, attn: torch.Tensor) -> torch.Tensor:
100
+ """CE on result token only, no abstract tokens."""
101
+ out = model(input_ids=ids, attention_mask=attn, memory_span_abs=512, memory_span_traj=512)
102
+ base_v = int(model.vocab_sizes[0].item())
103
+ return nn.CrossEntropyLoss()(out.logits[:, PROMPT_LEN - 1, :base_v], ids[:, PROMPT_LEN])
104
+
105
+
106
+ def save_curves(history: dict, out_dir: Path):
107
+ fig, axes = plt.subplots(1, 2, figsize=(10, 4))
108
+
109
+ axes[0].plot(history["epoch"], history["train_loss"], label="train loss")
110
+ axes[0].set_xlabel("epoch"); axes[0].set_ylabel("loss"); axes[0].set_title("Training Loss")
111
+ axes[0].legend()
112
+
113
+ axes[1].plot(history["eval_epoch"], history["test_acc"], color="green", label="test acc")
114
+ axes[1].set_xlabel("epoch"); axes[1].set_ylabel("accuracy"); axes[1].set_title("Test Accuracy")
115
+ axes[1].set_ylim(0, 1); axes[1].legend()
116
+
117
+ plt.tight_layout()
118
+ plt.savefig(out_dir / "curves.png", dpi=100)
119
+ plt.close()
120
+
121
+
122
+ def train(cfg: ModularConfig):
123
+ torch.manual_seed(cfg.seed)
124
+ device = torch.device(cfg.device)
125
+
126
+ out_dir = Path(cfg.out_dir) if cfg.out_dir else (
127
+ Path(__file__).resolve().parents[2] / "runs" / (cfg.job_name or f"{cfg.mode}_K{cfg.K}")
128
+ )
129
+ out_dir.mkdir(parents=True, exist_ok=True)
130
+
131
+ train_examples = get_train_set(p=cfg.p, seed=cfg.seed)
132
+ test_examples = get_eval_set(p=cfg.p, seed=cfg.seed)
133
+ loader = make_loader(train_examples, cfg.batch_size)
134
+
135
+ model = make_model(cfg).to(device)
136
+ optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
137
+ evaluator = ModularEvaluator(model, device=cfg.device, K=cfg.K)
138
+
139
+ sorl_loss_fn = SoRLLoss(
140
+ abs_vocab_size=model.vocab_sizes[-1],
141
+ zipf_alpha=cfg.alpha_soft_zipf,
142
+ ).to(device) if cfg.mode == "sorl" else None
143
+
144
+ history = {"epoch": [], "train_loss": [], "eval_epoch": [], "test_acc": []}
145
+ best_acc = 0.0
146
+
147
+ use_wandb = WANDB_AVAILABLE and not cfg.no_wandb
148
+ if use_wandb:
149
+ wandb.init(
150
+ project=WANDB_PROJECT, entity=WANDB_ENTITY,
151
+ name=cfg.job_name or f"{cfg.mode}_K{cfg.K}_abs{cfg.abs_vocab}",
152
+ config=asdict(cfg),
153
+ )
154
+
155
+ print(f"Training {cfg.mode} | p={cfg.p} | {len(train_examples)} train | {len(test_examples)} test")
156
+ print(f"Model: {cfg.n_layer}L/{cfg.n_head}H/{cfg.n_embd}d | K={cfg.K} | abs_vocab={cfg.abs_vocab}")
157
+ print(f"Output: {out_dir}")
158
+
159
+ for epoch in range(1, cfg.num_epochs + 1):
160
+ model.train()
161
+ epoch_loss = 0.0
162
+
163
+ for (ids,) in loader:
164
+ ids = ids.to(device)
165
+ attn = torch.ones_like(ids)
166
+ pl = torch.full((ids.shape[0],), PROMPT_LEN, dtype=torch.long, device=device)
167
+ optimizer.zero_grad()
168
+
169
+ if cfg.mode == "baseline":
170
+ out = model(input_ids=ids, attention_mask=attn, memory_span_abs=512, memory_span_traj=512)
171
+ base_v = int(model.vocab_sizes[0].item())
172
+ loss = nn.CrossEntropyLoss()(out.logits[:, PROMPT_LEN - 1, :base_v], ids[:, PROMPT_LEN])
173
+ else:
174
+ btl = compute_base_traj_loss(model, ids, attn)
175
+ with torch.no_grad():
176
+ best_data, _, _, exp_mask, exp_pl = sorl_search(
177
+ model, ids, attn, pl, PAD,
178
+ n=cfg.n_rollouts, K=cfg.K,
179
+ max_iterations=2, memory_span_abs=512, memory_span_traj=512,
180
+ )
181
+ info_loss, abs_loss, zipf_loss = sorl_loss_fn(
182
+ best_data, model, btl.detach(), exp_mask, 512, 512, prompt_len=exp_pl,
183
+ )
184
+ loss = (btl
185
+ + cfg.alpha_info_gain * info_loss
186
+ + cfg.alpha_abs * abs_loss
187
+ + cfg.alpha_soft_zipf * zipf_loss)
188
+
189
+ loss.backward()
190
+ optimizer.step()
191
+ epoch_loss += loss.item()
192
+
193
+ avg_loss = epoch_loss / len(loader)
194
+ history["epoch"].append(epoch)
195
+ history["train_loss"].append(avg_loss)
196
+
197
+ if use_wandb:
198
+ wandb.log({"train/loss": avg_loss, "epoch": epoch})
199
+
200
+ if epoch % cfg.log_every == 0:
201
+ print(f" epoch {epoch:5d} | loss {avg_loss:.4f}")
202
+
203
+ if epoch % cfg.eval_every == 0:
204
+ acc = evaluator.run(test_examples, max_examples=1000)
205
+ history["eval_epoch"].append(epoch)
206
+ history["test_acc"].append(acc)
207
+ print(f" epoch {epoch:5d} | test_acc {acc:.3f}")
208
+ if use_wandb:
209
+ wandb.log({"eval/accuracy": acc, "epoch": epoch})
210
+ save_curves(history, out_dir)
211
+ with open(out_dir / "history.json", "w") as f:
212
+ json.dump(history, f, indent=2)
213
+ if acc > best_acc:
214
+ best_acc = acc
215
+ best_dir = out_dir / "best"
216
+ best_dir.mkdir(exist_ok=True)
217
+ torch.save(model.state_dict(), best_dir / "model_state_dict.pt")
218
+ with open(best_dir / "sorl_config.json", "w") as f:
219
+ json.dump({"K": cfg.K, "abs_vocab": cfg.abs_vocab, "p": cfg.p,
220
+ "n_layer": cfg.n_layer, "n_head": cfg.n_head,
221
+ "n_embd": cfg.n_embd, "d_mlp": cfg.d_mlp,
222
+ "best_epoch": epoch, "best_acc": acc}, f)
223
+
224
+ final_acc = evaluator.run(test_examples, max_examples=2000)
225
+ print(f"\nFinal test accuracy: {final_acc:.4f} ({int(final_acc * len(test_examples))}/{len(test_examples)})")
226
+
227
+ history["final_acc"] = final_acc
228
+ with open(out_dir / "history.json", "w") as f:
229
+ json.dump(history, f, indent=2)
230
+ save_curves(history, out_dir)
231
+ with open(out_dir / "config.json", "w") as f:
232
+ json.dump(asdict(cfg), f, indent=2)
233
+ (out_dir / "final").mkdir(exist_ok=True)
234
+ torch.save(model.state_dict(), out_dir / "final" / "model_state_dict.pt")
235
+ with open(out_dir / "final" / "sorl_config.json", "w") as f:
236
+ json.dump({"K": cfg.K, "abs_vocab": cfg.abs_vocab, "p": cfg.p,
237
+ "n_layer": cfg.n_layer, "n_head": cfg.n_head,
238
+ "n_embd": cfg.n_embd, "d_mlp": cfg.d_mlp}, f)
239
+ print(f"Model saved to {out_dir / 'final'}")
240
+
241
+ if use_wandb:
242
+ wandb.log({"eval/final_accuracy": final_acc})
243
+ wandb.finish()
244
+
245
+ return model, final_acc
246
+
247
+
248
+ def main():
249
+ p = argparse.ArgumentParser()
250
+ p.add_argument("--mode", default="sorl", choices=["baseline", "sorl"])
251
+ p.add_argument("--K", type=int, default=1)
252
+ p.add_argument("--abs_vocab", type=int, default=30)
253
+ p.add_argument("--num_epochs", type=int, default=500)
254
+ p.add_argument("--batch_size", type=int, default=256)
255
+ p.add_argument("--lr", type=float, default=1e-3)
256
+ p.add_argument("--weight_decay", type=float, default=0.1)
257
+ p.add_argument("--eval_every", type=int, default=50)
258
+ p.add_argument("--log_every", type=int, default=10)
259
+ p.add_argument("--n_layer", type=int, default=1)
260
+ p.add_argument("--n_head", type=int, default=4)
261
+ p.add_argument("--n_embd", type=int, default=128)
262
+ p.add_argument("--d_mlp", type=int, default=512)
263
+ p.add_argument("--device", default="cuda")
264
+ p.add_argument("--job_name", default="")
265
+ p.add_argument("--out_dir", default="")
266
+ p.add_argument("--no_wandb", action="store_true")
267
+ args = p.parse_args()
268
+
269
+ cfg = ModularConfig(**vars(args))
270
+ train(cfg)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ main()