Search commited on
Commit
02aaa6d
·
1 Parent(s): b65859f

auto: sync run_external_moe_family_wikitext_compare.py

Browse files
scripts/run_external_moe_family_wikitext_compare.py CHANGED
@@ -24,6 +24,7 @@ from src.data.wikitext_bpe import _collect_wikitext_text, load_wikitext_bpe
24
  from src.model import dense_transformer_baseline_external as local_dense_mod
25
  from src.model import motif_moe_external as local_motif_mod
26
  from src.model import plain_moe_transformer_external as local_flat_mod
 
27
 
28
 
29
  DEFAULT_DOWNLOADS_DIR = Path(r"C:\Users\Kharki\Downloads\Telegram Desktop")
@@ -123,6 +124,19 @@ def make_optimizer(model: torch.nn.Module, d_model: int) -> torch.optim.Optimize
123
  )
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def _dense_param_count(cfg: Any) -> int:
127
  hidden = max(16, int(round(cfg.d_model * cfg.ffn_hidden_ratio)))
128
  per_layer = (
@@ -280,6 +294,7 @@ def run_compare(
280
  motif_source: str,
281
  motif_profile: str,
282
  baseline_source: str,
 
283
  ) -> dict[str, Any]:
284
  modules = load_modules(downloads_dir, motif_source=motif_source, baseline_source=baseline_source)
285
  motif_mod = modules["motif"]
@@ -362,6 +377,39 @@ def run_compare(
362
  dense_model = dense_mod.DenseTransformerLM(dense_cfg).to(device)
363
  dense_params = int(dense_mod.count_parameters(dense_model))
364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  motif_history = train_model(
366
  name="MotifMoE",
367
  model=motif_model,
@@ -398,6 +446,19 @@ def run_compare(
398
  eval_every=eval_every,
399
  eval_batches=eval_batches,
400
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
  report = {
403
  "downloads_dir": str(downloads_dir),
@@ -414,6 +475,7 @@ def run_compare(
414
  "motif_source": motif_source,
415
  "motif_profile": motif_profile,
416
  "baseline_source": baseline_source,
 
417
  "models": {
418
  "motif_moe": {
419
  "params": motif_params,
@@ -457,6 +519,13 @@ def run_compare(
457
  },
458
  },
459
  }
 
 
 
 
 
 
 
460
  return report
461
 
462
 
@@ -475,6 +544,7 @@ def main() -> None:
475
  parser.add_argument("--motif-source", choices=("local", "external"), default="external")
476
  parser.add_argument("--motif-profile", choices=("scaled", "text-lm"), default="scaled")
477
  parser.add_argument("--baseline-source", choices=("local", "external"), default="external")
 
478
  parser.add_argument("--model-name", default=None)
479
  parser.add_argument("--model", default=None)
480
  args = parser.parse_args()
@@ -493,6 +563,7 @@ def main() -> None:
493
  motif_source=args.motif_source,
494
  motif_profile=args.motif_profile,
495
  baseline_source=args.baseline_source,
 
496
  )
497
 
498
  suffix = f"_densematch-{report['dense_match_target']}" if report["dense_match_target"] != "none" else ""
@@ -506,7 +577,8 @@ def main() -> None:
506
  f"_{report['motif_source']}-{report['motif_profile']}"
507
  f"_baseline-{report['baseline_source']}"
508
  )
509
- out_path = ARCHIVE_DIR / f"external_moe_family_wikitext_compare_vocab{report['vocab_size']}{suffix}{profile_suffix}.json"
 
510
  out_path.write_text(json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8")
511
  ranking = sorted(
512
  (
 
24
  from src.model import dense_transformer_baseline_external as local_dense_mod
25
  from src.model import motif_moe_external as local_motif_mod
26
  from src.model import plain_moe_transformer_external as local_flat_mod
27
+ from src.model import wikitext_motif_combined_ensemble_external as local_ensemble_mod
28
 
29
 
30
  DEFAULT_DOWNLOADS_DIR = Path(r"C:\Users\Kharki\Downloads\Telegram Desktop")
 
124
  )
125
 
126
 
127
+ def infer_model_width(model: torch.nn.Module, fallback: int) -> int:
128
+ config = getattr(model, "config", None)
129
+ if config is not None and hasattr(config, "d_model"):
130
+ return int(config.d_model)
131
+ members = getattr(model, "members", None)
132
+ if members:
133
+ first_member = next(iter(members.values()))
134
+ member_config = getattr(first_member, "config", None)
135
+ if member_config is not None and hasattr(member_config, "d_model"):
136
+ return int(member_config.d_model)
137
+ return int(fallback)
138
+
139
+
140
  def _dense_param_count(cfg: Any) -> int:
141
  hidden = max(16, int(round(cfg.d_model * cfg.ffn_hidden_ratio)))
142
  per_layer = (
 
294
  motif_source: str,
295
  motif_profile: str,
296
  baseline_source: str,
297
+ include_combined_ensemble: bool,
298
  ) -> dict[str, Any]:
299
  modules = load_modules(downloads_dir, motif_source=motif_source, baseline_source=baseline_source)
300
  motif_mod = modules["motif"]
 
377
  dense_model = dense_mod.DenseTransformerLM(dense_cfg).to(device)
378
  dense_params = int(dense_mod.count_parameters(dense_model))
379
 
380
+ combined_model: torch.nn.Module | None = None
381
+ combined_params: int | None = None
382
+ combined_config: dict[str, Any] | None = None
383
+ combined_history: list[dict[str, float]] | None = None
384
+ if include_combined_ensemble:
385
+ ensemble_build = local_ensemble_mod.build_wikitext_motif_ensemble(
386
+ local_ensemble_mod.EnsembleBuildConfig(
387
+ scale="0.125x",
388
+ block_size=seq_len,
389
+ vocab_size=actual_vocab_size,
390
+ include_text_motif_moe=True,
391
+ gate_hidden_dim=128,
392
+ gate_mode="contextual",
393
+ freeze_members=False,
394
+ member_dropout_p=0.05,
395
+ gate_temperature=1.0,
396
+ gate_entropy_bonus_coef=0.0,
397
+ )
398
+ )
399
+ combined_model = ensemble_build.model.to(device)
400
+ combined_params = int(local_ensemble_mod.count_parameters(combined_model))
401
+ combined_config = {
402
+ "scale": "0.125x",
403
+ "block_size": seq_len,
404
+ "gate_mode": "contextual",
405
+ "gate_hidden_dim": 128,
406
+ "freeze_members": False,
407
+ "member_dropout_p": 0.05,
408
+ "gate_temperature": 1.0,
409
+ "gate_entropy_bonus_coef": 0.0,
410
+ "member_names": [spec.name for spec in ensemble_build.specs],
411
+ }
412
+
413
  motif_history = train_model(
414
  name="MotifMoE",
415
  model=motif_model,
 
446
  eval_every=eval_every,
447
  eval_batches=eval_batches,
448
  )
449
+ if combined_model is not None:
450
+ combined_history = train_model(
451
+ name="CombinedEnsemble",
452
+ model=combined_model,
453
+ d_model=infer_model_width(combined_model, fallback=motif_cfg.d_model),
454
+ train_data=train_data,
455
+ val_data=val_data,
456
+ device=device,
457
+ steps=steps,
458
+ batch_size=batch_size,
459
+ eval_every=eval_every,
460
+ eval_batches=eval_batches,
461
+ )
462
 
463
  report = {
464
  "downloads_dir": str(downloads_dir),
 
475
  "motif_source": motif_source,
476
  "motif_profile": motif_profile,
477
  "baseline_source": baseline_source,
478
+ "include_combined_ensemble": include_combined_ensemble,
479
  "models": {
480
  "motif_moe": {
481
  "params": motif_params,
 
519
  },
520
  },
521
  }
522
+ if combined_model is not None and combined_history is not None and combined_params is not None and combined_config is not None:
523
+ report["models"]["combined_ensemble"] = {
524
+ "params": combined_params,
525
+ "config": combined_config,
526
+ "history": combined_history,
527
+ "final": combined_history[-1],
528
+ }
529
  return report
530
 
531
 
 
544
  parser.add_argument("--motif-source", choices=("local", "external"), default="external")
545
  parser.add_argument("--motif-profile", choices=("scaled", "text-lm"), default="scaled")
546
  parser.add_argument("--baseline-source", choices=("local", "external"), default="external")
547
+ parser.add_argument("--include-combined-ensemble", action="store_true")
548
  parser.add_argument("--model-name", default=None)
549
  parser.add_argument("--model", default=None)
550
  args = parser.parse_args()
 
563
  motif_source=args.motif_source,
564
  motif_profile=args.motif_profile,
565
  baseline_source=args.baseline_source,
566
+ include_combined_ensemble=args.include_combined_ensemble,
567
  )
568
 
569
  suffix = f"_densematch-{report['dense_match_target']}" if report["dense_match_target"] != "none" else ""
 
577
  f"_{report['motif_source']}-{report['motif_profile']}"
578
  f"_baseline-{report['baseline_source']}"
579
  )
580
+ ensemble_suffix = "_with-ensemble" if report["include_combined_ensemble"] else ""
581
+ out_path = ARCHIVE_DIR / f"external_moe_family_wikitext_compare_vocab{report['vocab_size']}{suffix}{profile_suffix}{ensemble_suffix}.json"
582
  out_path.write_text(json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8")
583
  ranking = sorted(
584
  (
src/model/wikitext_motif_combined_ensemble_external.py ADDED
@@ -0,0 +1,1088 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ from dataclasses import asdict, dataclass, field
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import DataLoader, Dataset
14
+
15
+ from src.model.dense_transformer_baseline_external import DenseTransformerConfig, DenseTransformerLM
16
+ from src.model.motif_moe_external import MotifMoEConfig, MotifMoETransformer, count_parameters, make_motif_moe_config
17
+
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Data utilities
21
+ # -----------------------------------------------------------------------------
22
+
23
+
24
+ class ByteTokenizer:
25
+ """Dependency-free byte tokenizer for quick WikiText experiments.
26
+
27
+ This is intentionally simple so the code runs without external tokenization
28
+ libraries. If you already have pretokenized ids, use `load_token_tensor` and
29
+ skip this tokenizer entirely.
30
+ """
31
+
32
+ PAD_ID = 256
33
+ EOS_ID = 257
34
+ VOCAB_SIZE = 258
35
+
36
+ def __init__(self) -> None:
37
+ self.pad_id = self.PAD_ID
38
+ self.eos_id = self.EOS_ID
39
+ self.vocab_size = self.VOCAB_SIZE
40
+
41
+ def encode(self, text: str, add_eos: bool = True) -> List[int]:
42
+ ids = list(text.encode("utf-8", errors="replace"))
43
+ if add_eos:
44
+ ids.append(self.eos_id)
45
+ return ids
46
+
47
+ def decode(self, ids: Sequence[int]) -> str:
48
+ buf = bytearray()
49
+ for idx in ids:
50
+ if idx < 256:
51
+ buf.append(int(idx))
52
+ elif idx == self.eos_id:
53
+ buf.extend(b"\n")
54
+ return buf.decode("utf-8", errors="replace")
55
+
56
+
57
+ def load_token_tensor(path: str | Path) -> torch.Tensor:
58
+ path = Path(path)
59
+ suffix = path.suffix.lower()
60
+ if suffix == ".pt":
61
+ obj = torch.load(path, map_location="cpu")
62
+ if isinstance(obj, torch.Tensor):
63
+ return obj.long().flatten()
64
+ raise TypeError(f"Expected a Tensor in {path}, got {type(obj)!r}")
65
+
66
+ if suffix == ".npy":
67
+ try:
68
+ import numpy as np
69
+ except Exception as exc: # pragma: no cover - optional dependency
70
+ raise ImportError("numpy is required to load .npy token arrays") from exc
71
+ arr = np.load(path)
72
+ return torch.from_numpy(arr).long().flatten()
73
+
74
+ if suffix in {".bin", ".uint16", ".u16"}:
75
+ try:
76
+ import numpy as np
77
+ except Exception as exc: # pragma: no cover - optional dependency
78
+ raise ImportError("numpy is required to load .bin token arrays") from exc
79
+ arr = np.fromfile(path, dtype=np.uint16)
80
+ return torch.from_numpy(arr.astype("int64"))
81
+
82
+ raise ValueError(f"Unsupported token file extension for {path}")
83
+
84
+
85
+ def text_file_to_tokens(path: str | Path, tokenizer: Optional[ByteTokenizer] = None) -> torch.Tensor:
86
+ tokenizer = tokenizer or ByteTokenizer()
87
+ text = Path(path).read_text(encoding="utf-8")
88
+ ids: List[int] = []
89
+ for line in text.splitlines():
90
+ ids.extend(tokenizer.encode(line, add_eos=True))
91
+ return torch.tensor(ids, dtype=torch.long)
92
+
93
+
94
+ class TokenBlockDataset(Dataset[torch.Tensor]):
95
+ def __init__(self, tokens: torch.Tensor, block_size: int, stride: Optional[int] = None) -> None:
96
+ if tokens.ndim != 1:
97
+ raise ValueError("tokens must be a 1D tensor")
98
+ if block_size < 2:
99
+ raise ValueError("block_size must be >= 2")
100
+ self.tokens = tokens.long().contiguous()
101
+ self.block_size = int(block_size)
102
+ self.stride = int(stride if stride is not None else block_size)
103
+ if self.stride < 1:
104
+ raise ValueError("stride must be >= 1")
105
+ self.max_start = max(0, self.tokens.numel() - self.block_size)
106
+ self.num_examples = 1 + (self.max_start // self.stride) if self.tokens.numel() >= self.block_size else 0
107
+
108
+ def __len__(self) -> int:
109
+ return self.num_examples
110
+
111
+ def __getitem__(self, index: int) -> torch.Tensor:
112
+ start = index * self.stride
113
+ chunk = self.tokens[start : start + self.block_size]
114
+ if chunk.numel() != self.block_size:
115
+ raise IndexError(index)
116
+ return chunk
117
+
118
+
119
+ # -----------------------------------------------------------------------------
120
+ # Branch specs and builders
121
+ # -----------------------------------------------------------------------------
122
+
123
+
124
+ @dataclass
125
+ class MotifBranchSpec:
126
+ name: str
127
+ model_type: str # 'dense' or 'motif_moe'
128
+ attn_qk_ratio: float = 1.0
129
+ attn_v_ratio: float = 1.0
130
+
131
+ # Dense-only fields
132
+ ffn_hidden_ratio: float = 4.0
133
+ ffn_activation: str = "gelu"
134
+
135
+ # Motif-MoE-only fields
136
+ motif_families: Tuple[str, ...] = ()
137
+ experts_per_family: int = 2
138
+ motif_hidden_ratios: Dict[str, float] = field(default_factory=dict)
139
+ top_k_motifs: int = 2
140
+ top_k_experts: int = 1
141
+ expert_pool_scale: float = 1.0
142
+ router_jitter_noise: float = 0.01
143
+
144
+ def validate(self) -> None:
145
+ if self.model_type not in {"dense", "motif_moe"}:
146
+ raise ValueError(f"Unsupported model_type={self.model_type!r}")
147
+ if self.attn_qk_ratio <= 0 or self.attn_v_ratio <= 0:
148
+ raise ValueError("Attention ratios must be > 0")
149
+ if self.model_type == "dense" and self.ffn_hidden_ratio <= 0:
150
+ raise ValueError("ffn_hidden_ratio must be > 0 for dense branches")
151
+ if self.model_type == "motif_moe":
152
+ if not self.motif_families:
153
+ raise ValueError("motif_moe branch requires motif_families")
154
+ if self.experts_per_family < 1:
155
+ raise ValueError("experts_per_family must be >= 1")
156
+
157
+
158
+ @dataclass
159
+ class EnsembleBuildConfig:
160
+ scale: str = "0.125x"
161
+ block_size: int = 256
162
+ vocab_size: int = ByteTokenizer.VOCAB_SIZE
163
+ include_text_motif_moe: bool = True
164
+ gate_hidden_dim: int = 128
165
+ gate_mode: str = "contextual" # contextual | static
166
+ freeze_members: bool = True
167
+ member_dropout_p: float = 0.0
168
+ gate_temperature: float = 1.0
169
+ gate_entropy_bonus_coef: float = 0.0
170
+
171
+
172
+ @dataclass
173
+ class TrainConfig:
174
+ batch_size: int = 8
175
+ lr: float = 3e-4
176
+ weight_decay: float = 0.01
177
+ epochs: int = 1
178
+ max_steps: Optional[int] = None
179
+ grad_clip: float = 1.0
180
+ device: str = "cpu"
181
+ log_every: int = 50
182
+
183
+
184
+ @dataclass
185
+ class EvalSummary:
186
+ split: str
187
+ lm_loss: float
188
+ perplexity: float
189
+ tokens: int
190
+
191
+
192
+ def make_wikitext_branch_specs(include_text_motif_moe: bool = True) -> List[MotifBranchSpec]:
193
+ """Default motif-profile ensemble for WikiText-like natural text.
194
+
195
+ Profiles reflect the observed pattern that text often benefits from stronger
196
+ expand/select/memory capacity and does not require the widest compare space.
197
+ """
198
+ specs = [
199
+ MotifBranchSpec(
200
+ name="uniform_dense",
201
+ model_type="dense",
202
+ attn_qk_ratio=1.0,
203
+ attn_v_ratio=1.0,
204
+ ffn_hidden_ratio=4.0,
205
+ ffn_activation="gelu",
206
+ ),
207
+ MotifBranchSpec(
208
+ name="narrow_compare_dense",
209
+ model_type="dense",
210
+ attn_qk_ratio=0.75,
211
+ attn_v_ratio=1.0,
212
+ ffn_hidden_ratio=4.5,
213
+ ffn_activation="gelu",
214
+ ),
215
+ MotifBranchSpec(
216
+ name="wide_memory_dense",
217
+ model_type="dense",
218
+ attn_qk_ratio=0.75,
219
+ attn_v_ratio=1.0,
220
+ ffn_hidden_ratio=6.0,
221
+ ffn_activation="swiglu",
222
+ ),
223
+ ]
224
+ if include_text_motif_moe:
225
+ specs.append(
226
+ MotifBranchSpec(
227
+ name="text_motif_moe",
228
+ model_type="motif_moe",
229
+ attn_qk_ratio=0.75,
230
+ attn_v_ratio=1.0,
231
+ motif_families=("expand", "select", "memory"),
232
+ experts_per_family=2,
233
+ motif_hidden_ratios={
234
+ "expand": 6.0,
235
+ "select": 2.0,
236
+ "memory": 4.0,
237
+ },
238
+ top_k_motifs=2,
239
+ top_k_experts=1,
240
+ expert_pool_scale=0.5,
241
+ router_jitter_noise=0.01,
242
+ )
243
+ )
244
+ for spec in specs:
245
+ spec.validate()
246
+ return specs
247
+
248
+
249
+ def _clone_motif_cfg(cfg: MotifMoEConfig) -> MotifMoEConfig:
250
+ return MotifMoEConfig(**asdict(cfg))
251
+
252
+
253
+ class DenseTransformerLMWithHidden(DenseTransformerLM):
254
+ """DenseTransformerLM that can optionally return final hidden states."""
255
+
256
+ def forward(
257
+ self,
258
+ input_ids: torch.Tensor,
259
+ labels: Optional[torch.Tensor] = None,
260
+ *,
261
+ return_hidden_states: bool = False,
262
+ return_router_stats: bool = False,
263
+ ) -> Dict[str, Any]:
264
+ del return_router_stats
265
+ x = self.emb_dropout(self.tok_emb(input_ids))
266
+ for block in self.blocks:
267
+ x = block(x)
268
+ hidden = self.ln_f(x)
269
+ logits = self.lm_head(hidden)
270
+
271
+ lm_loss: Optional[torch.Tensor] = None
272
+ loss: Optional[torch.Tensor] = None
273
+ if labels is not None:
274
+ shift_logits = logits[:, :-1, :].contiguous()
275
+ shift_labels = labels[:, 1:].contiguous()
276
+ lm_loss = F.cross_entropy(
277
+ shift_logits.view(-1, shift_logits.size(-1)),
278
+ shift_labels.view(-1),
279
+ ignore_index=-100,
280
+ )
281
+ loss = lm_loss
282
+
283
+ zero = logits.new_zeros(())
284
+ out: Dict[str, Any] = {
285
+ "logits": logits,
286
+ "loss": loss,
287
+ "lm_loss": lm_loss,
288
+ "router_aux_loss": zero,
289
+ "router_aux_loss_raw": zero,
290
+ "router_stats": None,
291
+ }
292
+ if return_hidden_states:
293
+ out["hidden_states"] = hidden
294
+ return out
295
+
296
+
297
+ class MotifMoETransformerWithHidden(MotifMoETransformer):
298
+ """MotifMoETransformer that can optionally return final hidden states."""
299
+
300
+ def forward(
301
+ self,
302
+ input_ids: torch.Tensor,
303
+ labels: Optional[torch.Tensor] = None,
304
+ *,
305
+ return_hidden_states: bool = False,
306
+ return_router_stats: bool = False,
307
+ ) -> Dict[str, Any]:
308
+ out = super().forward(input_ids, labels=labels, return_router_stats=return_router_stats)
309
+ if return_hidden_states:
310
+ x = self.emb_dropout(self.tok_emb(input_ids))
311
+ aux_losses: List[torch.Tensor] = []
312
+ router_stats: List[Dict[str, Any]] = []
313
+ for block in self.blocks:
314
+ x, aux_loss, stats = block(x, return_router_stats=return_router_stats)
315
+ aux_losses.append(aux_loss)
316
+ if return_router_stats:
317
+ router_stats.append(stats)
318
+ hidden = self.ln_f(x)
319
+ out["hidden_states"] = hidden
320
+ # Preserve the already-computed outputs, but keep router stats consistent.
321
+ if return_router_stats:
322
+ out["router_stats"] = router_stats
323
+ return out
324
+
325
+
326
+ def build_branch_model(
327
+ spec: MotifBranchSpec,
328
+ *,
329
+ scale: str,
330
+ vocab_size: int,
331
+ block_size: int,
332
+ ) -> nn.Module:
333
+ spec.validate()
334
+ base_cfg = make_motif_moe_config(
335
+ scale=scale,
336
+ vocab_size=vocab_size,
337
+ max_seq_len=block_size,
338
+ expert_pool_scale=max(0.5, spec.expert_pool_scale),
339
+ top_k_motifs=max(1, spec.top_k_motifs),
340
+ top_k_experts=max(1, spec.top_k_experts),
341
+ attn_qk_ratio=spec.attn_qk_ratio,
342
+ attn_v_ratio=spec.attn_v_ratio,
343
+ )
344
+
345
+ if spec.model_type == "dense":
346
+ dense_cfg = DenseTransformerConfig(
347
+ vocab_size=base_cfg.vocab_size,
348
+ max_seq_len=base_cfg.max_seq_len,
349
+ tie_word_embeddings=base_cfg.tie_word_embeddings,
350
+ n_layers=base_cfg.n_layers,
351
+ d_model=base_cfg.d_model,
352
+ n_heads=base_cfg.n_heads,
353
+ attn_qk_ratio=spec.attn_qk_ratio,
354
+ attn_v_ratio=spec.attn_v_ratio,
355
+ ffn_hidden_ratio=spec.ffn_hidden_ratio,
356
+ ffn_activation=spec.ffn_activation,
357
+ rope_base=base_cfg.rope_base,
358
+ bias=base_cfg.bias,
359
+ norm_eps=base_cfg.norm_eps,
360
+ resid_dropout=base_cfg.resid_dropout,
361
+ attn_dropout=base_cfg.attn_dropout,
362
+ emb_dropout=base_cfg.emb_dropout,
363
+ ffn_dropout=base_cfg.expert_dropout,
364
+ initializer_range=base_cfg.initializer_range,
365
+ )
366
+ return DenseTransformerLMWithHidden(dense_cfg)
367
+
368
+ motif_cfg = _clone_motif_cfg(base_cfg)
369
+ motif_cfg.attn_qk_ratio = spec.attn_qk_ratio
370
+ motif_cfg.attn_v_ratio = spec.attn_v_ratio
371
+ motif_cfg.motif_families = tuple(spec.motif_families)
372
+ motif_cfg.experts_per_family = {family: spec.experts_per_family for family in motif_cfg.motif_families}
373
+ motif_cfg.motif_hidden_ratios = {
374
+ family: spec.motif_hidden_ratios.get(family, base_cfg.motif_hidden_ratios.get(family, 4.0))
375
+ for family in motif_cfg.motif_families
376
+ }
377
+ motif_cfg.top_k_motifs = min(spec.top_k_motifs, len(motif_cfg.motif_families))
378
+ motif_cfg.top_k_experts = spec.top_k_experts
379
+ motif_cfg.router_jitter_noise = spec.router_jitter_noise
380
+ motif_cfg.validate()
381
+ return MotifMoETransformerWithHidden(motif_cfg)
382
+
383
+
384
+ # -----------------------------------------------------------------------------
385
+ # Ensemble model
386
+ # -----------------------------------------------------------------------------
387
+
388
+
389
+ class ContextualLogitGate(nn.Module):
390
+ """Learns per-token convex weights over branch logits.
391
+
392
+ Features are extracted from each branch's logits:
393
+ - entropy
394
+ - max probability
395
+ - top-1 / top-2 margin
396
+ - logit standard deviation
397
+ """
398
+
399
+ def __init__(
400
+ self,
401
+ num_branches: int,
402
+ *,
403
+ hidden_dim: int = 128,
404
+ mode: str = "contextual",
405
+ temperature: float = 1.0,
406
+ ) -> None:
407
+ super().__init__()
408
+ if mode not in {"contextual", "static"}:
409
+ raise ValueError("mode must be 'contextual' or 'static'")
410
+ self.num_branches = num_branches
411
+ self.hidden_dim = hidden_dim
412
+ self.mode = mode
413
+ self.temperature = temperature
414
+ self.branch_bias = nn.Parameter(torch.zeros(num_branches))
415
+ if mode == "contextual":
416
+ in_dim = num_branches * 4
417
+ self.mlp = nn.Sequential(
418
+ nn.Linear(in_dim, hidden_dim),
419
+ nn.GELU(),
420
+ nn.Linear(hidden_dim, hidden_dim),
421
+ nn.GELU(),
422
+ nn.Linear(hidden_dim, num_branches),
423
+ )
424
+ else:
425
+ self.mlp = None
426
+
427
+ def _extract_features(self, logits: torch.Tensor) -> torch.Tensor:
428
+ # logits: [M, B, T, V]
429
+ stats: List[torch.Tensor] = []
430
+ for branch_logits in logits.detach():
431
+ log_probs = F.log_softmax(branch_logits.float(), dim=-1)
432
+ probs = log_probs.exp()
433
+ entropy = -(probs * log_probs).sum(dim=-1)
434
+ top2 = torch.topk(branch_logits.float(), k=min(2, branch_logits.size(-1)), dim=-1).values
435
+ if top2.size(-1) == 1:
436
+ margin = top2[..., 0]
437
+ else:
438
+ margin = top2[..., 0] - top2[..., 1]
439
+ max_prob = probs.max(dim=-1).values
440
+ spread = branch_logits.float().std(dim=-1)
441
+ stats.append(torch.stack([entropy, max_prob, margin, spread], dim=-1))
442
+ return torch.cat(stats, dim=-1) # [B, T, 4M]
443
+
444
+ def forward(self, logits: torch.Tensor) -> torch.Tensor:
445
+ # logits: [M, B, T, V]
446
+ _, batch, seqlen, _ = logits.shape
447
+ if self.mode == "static":
448
+ weights = F.softmax(self.branch_bias / max(self.temperature, 1e-6), dim=-1)
449
+ return weights.view(1, 1, -1).expand(batch, seqlen, -1)
450
+ features = self._extract_features(logits)
451
+ gate_logits = self.mlp(features) + self.branch_bias
452
+ return F.softmax(gate_logits / max(self.temperature, 1e-6), dim=-1)
453
+
454
+
455
+ class WikiTextMotifEnsembleLM(nn.Module):
456
+ """Ensemble over motif-profile branches with a learned convex logit combiner."""
457
+
458
+ def __init__(
459
+ self,
460
+ members: Mapping[str, nn.Module],
461
+ *,
462
+ gate_hidden_dim: int = 128,
463
+ gate_mode: str = "contextual",
464
+ freeze_members: bool = True,
465
+ member_dropout_p: float = 0.0,
466
+ gate_temperature: float = 1.0,
467
+ gate_entropy_bonus_coef: float = 0.0,
468
+ ) -> None:
469
+ super().__init__()
470
+ if not members:
471
+ raise ValueError("At least one ensemble member is required")
472
+ self.members = nn.ModuleDict(members)
473
+ self.branch_names = list(members.keys())
474
+ self.freeze_members = freeze_members
475
+ self.member_dropout_p = member_dropout_p
476
+ self.gate_entropy_bonus_coef = gate_entropy_bonus_coef
477
+ self.gate = ContextualLogitGate(
478
+ len(self.members),
479
+ hidden_dim=gate_hidden_dim,
480
+ mode=gate_mode,
481
+ temperature=gate_temperature,
482
+ )
483
+ if self.freeze_members:
484
+ self.freeze_all_members()
485
+
486
+ def freeze_all_members(self) -> None:
487
+ for model in self.members.values():
488
+ for param in model.parameters():
489
+ param.requires_grad_(False)
490
+ model.eval()
491
+
492
+ def unfreeze_all_members(self) -> None:
493
+ for model in self.members.values():
494
+ for param in model.parameters():
495
+ param.requires_grad_(True)
496
+
497
+ def _forward_member(self, model: nn.Module, input_ids: torch.Tensor, return_router_stats: bool) -> Dict[str, Any]:
498
+ if self.freeze_members:
499
+ with torch.no_grad():
500
+ return model(input_ids, labels=None, return_hidden_states=False, return_router_stats=return_router_stats)
501
+ return model(input_ids, labels=None, return_hidden_states=False, return_router_stats=return_router_stats)
502
+
503
+ def _apply_member_dropout(self, weights: torch.Tensor) -> torch.Tensor:
504
+ if not self.training or self.member_dropout_p <= 0:
505
+ return weights
506
+ keep = torch.rand_like(weights) > self.member_dropout_p
507
+ # Ensure at least one branch stays alive per token.
508
+ all_zero = ~keep.any(dim=-1, keepdim=True)
509
+ if torch.any(all_zero):
510
+ fallback = torch.zeros_like(keep)
511
+ fallback[..., 0] = True
512
+ keep = torch.where(all_zero, fallback, keep)
513
+ dropped = weights * keep.float()
514
+ norm = dropped.sum(dim=-1, keepdim=True).clamp_min(1e-8)
515
+ return dropped / norm
516
+
517
+ def forward(
518
+ self,
519
+ input_ids: torch.Tensor,
520
+ labels: Optional[torch.Tensor] = None,
521
+ *,
522
+ return_member_outputs: bool = False,
523
+ return_router_stats: bool = False,
524
+ ) -> Dict[str, Any]:
525
+ member_outputs: Dict[str, Dict[str, Any]] = {}
526
+ member_logits: List[torch.Tensor] = []
527
+ member_aux_losses: List[torch.Tensor] = []
528
+
529
+ for name, model in self.members.items():
530
+ out = self._forward_member(model, input_ids, return_router_stats=return_router_stats)
531
+ member_outputs[name] = out
532
+ member_logits.append(out["logits"])
533
+ if out.get("router_aux_loss") is not None:
534
+ member_aux_losses.append(out["router_aux_loss"])
535
+
536
+ logits_stack = torch.stack(member_logits, dim=0) # [M, B, T, V]
537
+ gate_weights = self.gate(logits_stack) # [B, T, M]
538
+ gate_weights = self._apply_member_dropout(gate_weights)
539
+ weights_for_sum = gate_weights.permute(2, 0, 1).unsqueeze(-1)
540
+ ensemble_logits = (logits_stack * weights_for_sum).sum(dim=0)
541
+
542
+ lm_loss: Optional[torch.Tensor] = None
543
+ loss: Optional[torch.Tensor] = None
544
+ if labels is not None:
545
+ shift_logits = ensemble_logits[:, :-1, :].contiguous()
546
+ shift_labels = labels[:, 1:].contiguous()
547
+ lm_loss = F.cross_entropy(
548
+ shift_logits.view(-1, shift_logits.size(-1)),
549
+ shift_labels.view(-1),
550
+ ignore_index=-100,
551
+ )
552
+ loss = lm_loss
553
+ if not self.freeze_members and member_aux_losses:
554
+ loss = loss + torch.stack(member_aux_losses).mean()
555
+ if self.gate_entropy_bonus_coef != 0.0:
556
+ entropy = -(gate_weights * (gate_weights.clamp_min(1e-8).log())).sum(dim=-1).mean()
557
+ loss = loss - self.gate_entropy_bonus_coef * entropy
558
+
559
+ out: Dict[str, Any] = {
560
+ "logits": ensemble_logits,
561
+ "loss": loss,
562
+ "lm_loss": lm_loss,
563
+ "router_aux_loss": torch.stack(member_aux_losses).mean() if member_aux_losses else ensemble_logits.new_zeros(()),
564
+ "router_aux_loss_raw": torch.stack(member_aux_losses).mean() if member_aux_losses else ensemble_logits.new_zeros(()),
565
+ "ensemble_weights": gate_weights,
566
+ "branch_names": self.branch_names,
567
+ }
568
+ if return_member_outputs:
569
+ out["member_outputs"] = member_outputs
570
+ if return_router_stats:
571
+ out["member_router_stats"] = {name: member_outputs[name].get("router_stats") for name in self.branch_names}
572
+ return out
573
+
574
+ @torch.no_grad()
575
+ def summarize_weights(self, gate_weights: torch.Tensor) -> Dict[str, float]:
576
+ mean_weights = gate_weights.mean(dim=(0, 1)).cpu().tolist()
577
+ return {name: float(weight) for name, weight in zip(self.branch_names, mean_weights)}
578
+
579
+
580
+ @dataclass
581
+ class BuiltEnsemble:
582
+ model: WikiTextMotifEnsembleLM
583
+ specs: List[MotifBranchSpec]
584
+
585
+
586
+ def build_wikitext_motif_ensemble(config: EnsembleBuildConfig) -> BuiltEnsemble:
587
+ specs = make_wikitext_branch_specs(include_text_motif_moe=config.include_text_motif_moe)
588
+ members = {
589
+ spec.name: build_branch_model(spec, scale=config.scale, vocab_size=config.vocab_size, block_size=config.block_size)
590
+ for spec in specs
591
+ }
592
+ model = WikiTextMotifEnsembleLM(
593
+ members,
594
+ gate_hidden_dim=config.gate_hidden_dim,
595
+ gate_mode=config.gate_mode,
596
+ freeze_members=config.freeze_members,
597
+ member_dropout_p=config.member_dropout_p,
598
+ gate_temperature=config.gate_temperature,
599
+ gate_entropy_bonus_coef=config.gate_entropy_bonus_coef,
600
+ )
601
+ return BuiltEnsemble(model=model, specs=specs)
602
+
603
+
604
+ # -----------------------------------------------------------------------------
605
+ # Checkpoint helpers
606
+ # -----------------------------------------------------------------------------
607
+
608
+
609
+ @dataclass
610
+ class BranchCheckpoint:
611
+ spec: Dict[str, Any]
612
+ state_dict: Dict[str, Any]
613
+
614
+
615
+ @torch.no_grad()
616
+ def save_branch_checkpoint(model: nn.Module, spec: MotifBranchSpec, path: str | Path) -> None:
617
+ payload = {
618
+ "spec": asdict(spec),
619
+ "state_dict": model.state_dict(),
620
+ }
621
+ torch.save(payload, Path(path))
622
+
623
+
624
+ @torch.no_grad()
625
+ def load_branch_checkpoint(path: str | Path, *, scale: str, vocab_size: int, block_size: int, device: str = "cpu") -> Tuple[MotifBranchSpec, nn.Module]:
626
+ payload = torch.load(Path(path), map_location=device)
627
+ spec_dict = payload["spec"]
628
+ if isinstance(spec_dict.get("motif_families"), list):
629
+ spec_dict["motif_families"] = tuple(spec_dict["motif_families"])
630
+ spec = MotifBranchSpec(**spec_dict)
631
+ model = build_branch_model(spec, scale=scale, vocab_size=vocab_size, block_size=block_size)
632
+ model.load_state_dict(payload["state_dict"])
633
+ return spec, model
634
+
635
+
636
+ def save_ensemble_checkpoint(
637
+ ensemble: WikiTextMotifEnsembleLM,
638
+ specs: Sequence[MotifBranchSpec],
639
+ path: str | Path,
640
+ ) -> None:
641
+ payload = {
642
+ "specs": [asdict(spec) for spec in specs],
643
+ "gate_state_dict": ensemble.gate.state_dict(),
644
+ "branch_names": list(ensemble.branch_names),
645
+ }
646
+ torch.save(payload, Path(path))
647
+
648
+
649
+ def load_ensemble_gate(ensemble: WikiTextMotifEnsembleLM, path: str | Path, device: str = "cpu") -> Dict[str, Any]:
650
+ payload = torch.load(Path(path), map_location=device)
651
+ ensemble.gate.load_state_dict(payload["gate_state_dict"])
652
+ return payload
653
+
654
+
655
+ # -----------------------------------------------------------------------------
656
+ # Training and evaluation
657
+ # -----------------------------------------------------------------------------
658
+
659
+
660
+ @dataclass
661
+ class RunningAverage:
662
+ total: float = 0.0
663
+ count: int = 0
664
+
665
+ def update(self, value: float, n: int = 1) -> None:
666
+ self.total += value * n
667
+ self.count += n
668
+
669
+ @property
670
+ def mean(self) -> float:
671
+ return self.total / max(1, self.count)
672
+
673
+
674
+ def move_batch(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
675
+ return batch.to(device, non_blocking=True)
676
+
677
+
678
+ def make_loader(tokens: torch.Tensor, *, block_size: int, batch_size: int, shuffle: bool) -> DataLoader[torch.Tensor]:
679
+ dataset = TokenBlockDataset(tokens, block_size=block_size, stride=block_size)
680
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)
681
+
682
+
683
+ def train_lm_model(
684
+ model: nn.Module,
685
+ train_tokens: torch.Tensor,
686
+ valid_tokens: Optional[torch.Tensor],
687
+ *,
688
+ block_size: int,
689
+ cfg: TrainConfig,
690
+ ) -> Dict[str, Any]:
691
+ device = torch.device(cfg.device)
692
+ model.to(device)
693
+ model.train()
694
+ loader = make_loader(train_tokens, block_size=block_size, batch_size=cfg.batch_size, shuffle=True)
695
+ optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=cfg.weight_decay)
696
+
697
+ step = 0
698
+ train_loss = RunningAverage()
699
+ for epoch in range(cfg.epochs):
700
+ for batch in loader:
701
+ batch = move_batch(batch, device)
702
+ optimizer.zero_grad(set_to_none=True)
703
+ out = model(batch, labels=batch)
704
+ loss = out["loss"]
705
+ if loss is None:
706
+ raise RuntimeError("Model did not return a loss")
707
+ loss.backward()
708
+ if cfg.grad_clip > 0:
709
+ torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
710
+ optimizer.step()
711
+ step += 1
712
+ train_loss.update(float(loss.detach().cpu()), n=batch.size(0))
713
+ if cfg.log_every > 0 and step % cfg.log_every == 0:
714
+ print(f"[train] epoch={epoch+1} step={step} loss={train_loss.mean:.4f}")
715
+ if cfg.max_steps is not None and step >= cfg.max_steps:
716
+ break
717
+ if cfg.max_steps is not None and step >= cfg.max_steps:
718
+ break
719
+
720
+ result: Dict[str, Any] = {
721
+ "train_loss": train_loss.mean,
722
+ "train_steps": step,
723
+ }
724
+ if valid_tokens is not None:
725
+ result["valid"] = asdict(
726
+ evaluate_lm_model(model, valid_tokens, block_size=block_size, batch_size=cfg.batch_size, device=cfg.device)
727
+ )
728
+ return result
729
+
730
+
731
+ @torch.no_grad()
732
+ def evaluate_lm_model(
733
+ model: nn.Module,
734
+ tokens: torch.Tensor,
735
+ *,
736
+ block_size: int,
737
+ batch_size: int,
738
+ device: str,
739
+ split: str = "valid",
740
+ ) -> EvalSummary:
741
+ model.eval()
742
+ dev = torch.device(device)
743
+ model.to(dev)
744
+ loader = make_loader(tokens, block_size=block_size, batch_size=batch_size, shuffle=False)
745
+ loss_avg = RunningAverage()
746
+ token_count = 0
747
+ for batch in loader:
748
+ batch = move_batch(batch, dev)
749
+ out = model(batch, labels=batch)
750
+ lm_loss = out["lm_loss"]
751
+ if lm_loss is None:
752
+ raise RuntimeError("Model did not return lm_loss")
753
+ valid_tokens = int((batch[:, 1:] != -100).sum().item())
754
+ loss_avg.update(float(lm_loss.detach().cpu()), n=max(1, valid_tokens))
755
+ token_count += valid_tokens
756
+ lm_loss = loss_avg.mean
757
+ ppl = float(math.exp(min(20.0, lm_loss)))
758
+ return EvalSummary(split=split, lm_loss=lm_loss, perplexity=ppl, tokens=token_count)
759
+
760
+
761
+ @torch.no_grad()
762
+ def evaluate_members(
763
+ members: Mapping[str, nn.Module],
764
+ tokens: torch.Tensor,
765
+ *,
766
+ block_size: int,
767
+ batch_size: int,
768
+ device: str,
769
+ ) -> Dict[str, Dict[str, float]]:
770
+ results: Dict[str, Dict[str, float]] = {}
771
+ for name, model in members.items():
772
+ summary = evaluate_lm_model(model, tokens, block_size=block_size, batch_size=batch_size, device=device, split=name)
773
+ results[name] = asdict(summary)
774
+ return results
775
+
776
+
777
+ # -----------------------------------------------------------------------------
778
+ # CLI helpers
779
+ # -----------------------------------------------------------------------------
780
+
781
+
782
+ @dataclass
783
+ class DataBundle:
784
+ train_tokens: torch.Tensor
785
+ valid_tokens: torch.Tensor
786
+ test_tokens: Optional[torch.Tensor]
787
+ vocab_size: int
788
+
789
+
790
+ def load_wikitext_data(
791
+ *,
792
+ train_path: str,
793
+ valid_path: str,
794
+ test_path: Optional[str] = None,
795
+ tokenized: bool = False,
796
+ ) -> DataBundle:
797
+ if tokenized:
798
+ train_tokens = load_token_tensor(train_path)
799
+ valid_tokens = load_token_tensor(valid_path)
800
+ test_tokens = load_token_tensor(test_path) if test_path is not None else None
801
+ vocab_size = int(max(train_tokens.max().item(), valid_tokens.max().item(), test_tokens.max().item() if test_tokens is not None else 0) + 1)
802
+ else:
803
+ tokenizer = ByteTokenizer()
804
+ train_tokens = text_file_to_tokens(train_path, tokenizer)
805
+ valid_tokens = text_file_to_tokens(valid_path, tokenizer)
806
+ test_tokens = text_file_to_tokens(test_path, tokenizer) if test_path is not None else None
807
+ vocab_size = tokenizer.vocab_size
808
+ return DataBundle(train_tokens=train_tokens, valid_tokens=valid_tokens, test_tokens=test_tokens, vocab_size=vocab_size)
809
+
810
+
811
+ def train_members_cli(args: argparse.Namespace) -> None:
812
+ data = load_wikitext_data(
813
+ train_path=args.train_path,
814
+ valid_path=args.valid_path,
815
+ test_path=args.test_path,
816
+ tokenized=args.tokenized,
817
+ )
818
+ out_dir = Path(args.output_dir)
819
+ out_dir.mkdir(parents=True, exist_ok=True)
820
+ members_dir = out_dir / "members"
821
+ members_dir.mkdir(parents=True, exist_ok=True)
822
+
823
+ specs = make_wikitext_branch_specs(include_text_motif_moe=not args.no_text_motif_moe)
824
+ train_cfg = TrainConfig(
825
+ batch_size=args.batch_size,
826
+ lr=args.lr,
827
+ weight_decay=args.weight_decay,
828
+ epochs=args.epochs,
829
+ max_steps=args.max_steps,
830
+ grad_clip=args.grad_clip,
831
+ device=args.device,
832
+ log_every=args.log_every,
833
+ )
834
+
835
+ summary: Dict[str, Any] = {"members": {}}
836
+ for spec in specs:
837
+ print(f"\n=== Training member: {spec.name} ===")
838
+ model = build_branch_model(spec, scale=args.scale, vocab_size=data.vocab_size, block_size=args.block_size)
839
+ print(f"params[{spec.name}]={count_parameters(model):,}")
840
+ result = train_lm_model(
841
+ model,
842
+ data.train_tokens,
843
+ data.valid_tokens,
844
+ block_size=args.block_size,
845
+ cfg=train_cfg,
846
+ )
847
+ ckpt_path = members_dir / f"{spec.name}.pt"
848
+ save_branch_checkpoint(model, spec, ckpt_path)
849
+ summary["members"][spec.name] = {
850
+ "checkpoint": str(ckpt_path),
851
+ "params": count_parameters(model),
852
+ "result": result,
853
+ }
854
+ print(json.dumps(summary["members"][spec.name], indent=2))
855
+
856
+ (out_dir / "member_training_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
857
+
858
+
859
+
860
+ def train_ensemble_cli(args: argparse.Namespace) -> None:
861
+ data = load_wikitext_data(
862
+ train_path=args.train_path,
863
+ valid_path=args.valid_path,
864
+ test_path=args.test_path,
865
+ tokenized=args.tokenized,
866
+ )
867
+ out_dir = Path(args.output_dir)
868
+ out_dir.mkdir(parents=True, exist_ok=True)
869
+ members_dir = out_dir / "members"
870
+
871
+ specs: List[MotifBranchSpec] = []
872
+ members: Dict[str, nn.Module] = {}
873
+ for ckpt_path in sorted(members_dir.glob("*.pt")):
874
+ spec, model = load_branch_checkpoint(
875
+ ckpt_path,
876
+ scale=args.scale,
877
+ vocab_size=data.vocab_size,
878
+ block_size=args.block_size,
879
+ device=args.device,
880
+ )
881
+ specs.append(spec)
882
+ members[spec.name] = model
883
+
884
+ if not members:
885
+ raise FileNotFoundError(f"No member checkpoints found in {members_dir}")
886
+
887
+ ensemble = WikiTextMotifEnsembleLM(
888
+ members,
889
+ gate_hidden_dim=args.gate_hidden_dim,
890
+ gate_mode=args.gate_mode,
891
+ freeze_members=not args.joint_finetune,
892
+ member_dropout_p=args.member_dropout_p,
893
+ gate_temperature=args.gate_temperature,
894
+ gate_entropy_bonus_coef=args.gate_entropy_bonus_coef,
895
+ )
896
+ device = torch.device(args.device)
897
+ ensemble.to(device)
898
+ if args.joint_finetune:
899
+ ensemble.unfreeze_all_members()
900
+ for model in ensemble.members.values():
901
+ model.train()
902
+
903
+ train_cfg = TrainConfig(
904
+ batch_size=args.batch_size,
905
+ lr=args.lr,
906
+ weight_decay=args.weight_decay,
907
+ epochs=args.epochs,
908
+ max_steps=args.max_steps,
909
+ grad_clip=args.grad_clip,
910
+ device=args.device,
911
+ log_every=args.log_every,
912
+ )
913
+ result = train_lm_model(
914
+ ensemble,
915
+ data.train_tokens,
916
+ data.valid_tokens,
917
+ block_size=args.block_size,
918
+ cfg=train_cfg,
919
+ )
920
+ gate_path = out_dir / "motif_ensemble_gate.pt"
921
+ save_ensemble_checkpoint(ensemble, specs, gate_path)
922
+
923
+ valid_summary = evaluate_lm_model(
924
+ ensemble,
925
+ data.valid_tokens,
926
+ block_size=args.block_size,
927
+ batch_size=args.batch_size,
928
+ device=args.device,
929
+ split="valid_ensemble",
930
+ )
931
+ member_summaries = evaluate_members(
932
+ ensemble.members,
933
+ data.valid_tokens,
934
+ block_size=args.block_size,
935
+ batch_size=args.batch_size,
936
+ device=args.device,
937
+ )
938
+
939
+ summary = {
940
+ "ensemble_params_trainable": sum(p.numel() for p in ensemble.parameters() if p.requires_grad),
941
+ "ensemble_result": result,
942
+ "valid_ensemble": asdict(valid_summary),
943
+ "valid_members": member_summaries,
944
+ "gate_checkpoint": str(gate_path),
945
+ }
946
+ (out_dir / "ensemble_training_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
947
+ print(json.dumps(summary, indent=2))
948
+
949
+
950
+
951
+ def evaluate_cli(args: argparse.Namespace) -> None:
952
+ data = load_wikitext_data(
953
+ train_path=args.train_path,
954
+ valid_path=args.valid_path,
955
+ test_path=args.test_path,
956
+ tokenized=args.tokenized,
957
+ )
958
+ out_dir = Path(args.output_dir)
959
+ members_dir = out_dir / "members"
960
+ specs: List[MotifBranchSpec] = []
961
+ members: Dict[str, nn.Module] = {}
962
+ for ckpt_path in sorted(members_dir.glob("*.pt")):
963
+ spec, model = load_branch_checkpoint(
964
+ ckpt_path,
965
+ scale=args.scale,
966
+ vocab_size=data.vocab_size,
967
+ block_size=args.block_size,
968
+ device=args.device,
969
+ )
970
+ specs.append(spec)
971
+ members[spec.name] = model
972
+
973
+ if not members:
974
+ raise FileNotFoundError(f"No member checkpoints found in {members_dir}")
975
+
976
+ ensemble = WikiTextMotifEnsembleLM(
977
+ members,
978
+ gate_hidden_dim=args.gate_hidden_dim,
979
+ gate_mode=args.gate_mode,
980
+ freeze_members=True,
981
+ member_dropout_p=0.0,
982
+ gate_temperature=args.gate_temperature,
983
+ gate_entropy_bonus_coef=0.0,
984
+ )
985
+ gate_path = out_dir / "motif_ensemble_gate.pt"
986
+ if gate_path.exists():
987
+ load_ensemble_gate(ensemble, gate_path, device=args.device)
988
+ else:
989
+ print(f"Warning: {gate_path} not found. Evaluating with randomly initialized gate.")
990
+
991
+ splits = [("valid", data.valid_tokens)]
992
+ if data.test_tokens is not None:
993
+ splits.append(("test", data.test_tokens))
994
+
995
+ report: Dict[str, Any] = {"members": {}, "ensemble": {}}
996
+ for split_name, split_tokens in splits:
997
+ report["ensemble"][split_name] = asdict(
998
+ evaluate_lm_model(
999
+ ensemble,
1000
+ split_tokens,
1001
+ block_size=args.block_size,
1002
+ batch_size=args.batch_size,
1003
+ device=args.device,
1004
+ split=split_name,
1005
+ )
1006
+ )
1007
+ report["members"][split_name] = evaluate_members(
1008
+ members,
1009
+ split_tokens,
1010
+ block_size=args.block_size,
1011
+ batch_size=args.batch_size,
1012
+ device=args.device,
1013
+ )
1014
+
1015
+ sample_batch = next(iter(make_loader(data.valid_tokens, block_size=args.block_size, batch_size=min(2, args.batch_size), shuffle=False)))
1016
+ sample_batch = sample_batch.to(args.device)
1017
+ ensemble = ensemble.to(args.device)
1018
+ out = ensemble(sample_batch, labels=sample_batch)
1019
+ report["mean_gate_weights"] = ensemble.summarize_weights(out["ensemble_weights"])
1020
+
1021
+ report_path = out_dir / "ensemble_eval_report.json"
1022
+ report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
1023
+ print(json.dumps(report, indent=2))
1024
+
1025
+
1026
+ # -----------------------------------------------------------------------------
1027
+ # CLI
1028
+ # -----------------------------------------------------------------------------
1029
+
1030
+
1031
+ def build_arg_parser() -> argparse.ArgumentParser:
1032
+ parser = argparse.ArgumentParser(description="WikiText motif-profile ensemble")
1033
+ subparsers = parser.add_subparsers(dest="command", required=True)
1034
+
1035
+ def add_common(sub: argparse.ArgumentParser) -> None:
1036
+ sub.add_argument("--train-path", type=str, required=True)
1037
+ sub.add_argument("--valid-path", type=str, required=True)
1038
+ sub.add_argument("--test-path", type=str, default=None)
1039
+ sub.add_argument("--tokenized", action="store_true", help="Interpret input paths as pretokenized tensors (.pt/.npy/.bin)")
1040
+ sub.add_argument("--output-dir", type=str, required=True)
1041
+ sub.add_argument("--scale", type=str, default="0.125x")
1042
+ sub.add_argument("--block-size", type=int, default=256)
1043
+ sub.add_argument("--batch-size", type=int, default=8)
1044
+ sub.add_argument("--device", type=str, default="cpu")
1045
+ sub.add_argument("--lr", type=float, default=3e-4)
1046
+ sub.add_argument("--weight-decay", type=float, default=0.01)
1047
+ sub.add_argument("--epochs", type=int, default=1)
1048
+ sub.add_argument("--max-steps", type=int, default=None)
1049
+ sub.add_argument("--grad-clip", type=float, default=1.0)
1050
+ sub.add_argument("--log-every", type=int, default=50)
1051
+
1052
+ train_members = subparsers.add_parser("train-members", help="Train individual motif-profile members")
1053
+ add_common(train_members)
1054
+ train_members.add_argument("--no-text-motif-moe", action="store_true")
1055
+
1056
+ train_ensemble = subparsers.add_parser("train-ensemble", help="Train the ensemble gate over saved members")
1057
+ add_common(train_ensemble)
1058
+ train_ensemble.add_argument("--joint-finetune", action="store_true", help="Also fine-tune member models while training the gate")
1059
+ train_ensemble.add_argument("--gate-hidden-dim", type=int, default=128)
1060
+ train_ensemble.add_argument("--gate-mode", type=str, default="contextual", choices=["contextual", "static"])
1061
+ train_ensemble.add_argument("--member-dropout-p", type=float, default=0.05)
1062
+ train_ensemble.add_argument("--gate-temperature", type=float, default=1.0)
1063
+ train_ensemble.add_argument("--gate-entropy-bonus-coef", type=float, default=0.0)
1064
+
1065
+ evaluate = subparsers.add_parser("evaluate", help="Evaluate trained members and the ensemble")
1066
+ add_common(evaluate)
1067
+ evaluate.add_argument("--gate-hidden-dim", type=int, default=128)
1068
+ evaluate.add_argument("--gate-mode", type=str, default="contextual", choices=["contextual", "static"])
1069
+ evaluate.add_argument("--gate-temperature", type=float, default=1.0)
1070
+
1071
+ return parser
1072
+
1073
+
1074
+ def main() -> None:
1075
+ parser = build_arg_parser()
1076
+ args = parser.parse_args()
1077
+ if args.command == "train-members":
1078
+ train_members_cli(args)
1079
+ elif args.command == "train-ensemble":
1080
+ train_ensemble_cli(args)
1081
+ elif args.command == "evaluate":
1082
+ evaluate_cli(args)
1083
+ else: # pragma: no cover
1084
+ raise KeyError(args.command)
1085
+
1086
+
1087
+ if __name__ == "__main__":
1088
+ main()