Search commited on
Commit
35e5b8a
·
1 Parent(s): c3488d4

fix: proper controls — loss masking, uniform baseline, multi-seed

Browse files
src/fog/config.py CHANGED
@@ -34,6 +34,17 @@ MOTIF_SMALL = FOGConfig(
34
  d_gate=32,
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Tiny configs for fast iteration
38
  BASELINE_TINY = FOGConfig(
39
  vocab_size=32,
 
34
  d_gate=32,
35
  )
36
 
37
+ # Param-matched uniform baseline for controlled comparison
38
+ # d_model=94, d_ff=376 → ~432K params to match MOTIF_TINY
39
+ UNIFORM_TINY = FOGConfig(
40
+ vocab_size=32,
41
+ d_model=94,
42
+ n_layers=4,
43
+ n_heads=2,
44
+ max_seq_len=32,
45
+ d_ff=376,
46
+ )
47
+
48
  # Tiny configs for fast iteration
49
  BASELINE_TINY = FOGConfig(
50
  vocab_size=32,
src/fog/data.py CHANGED
@@ -16,17 +16,18 @@ class CopyTask(Dataset):
16
  self.sep_token = vocab_size - 1
17
  rng = random.Random(seed)
18
  self.samples = []
19
- content_vocab = vocab_size - 1 # exclude SEP
 
20
  half = seq_len // 2 - 1
21
  for _ in range(n_samples):
22
  content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
23
- # input: content + SEP + content (teacher forcing)
24
  ids = content + [self.sep_token] + content
25
- # pad/truncate to seq_len
26
  ids = ids[:seq_len]
27
  while len(ids) < seq_len:
28
  ids.append(0)
29
  self.samples.append(ids)
 
30
 
31
  def __len__(self) -> int:
32
  return len(self.samples)
@@ -35,7 +36,12 @@ class CopyTask(Dataset):
35
  ids = self.samples[idx]
36
  x = torch.tensor(ids[:-1], dtype=torch.long)
37
  y = torch.tensor(ids[1:], dtype=torch.long)
38
- return {"input_ids": x, "targets": y}
 
 
 
 
 
39
 
40
 
41
  class ReverseTask(Dataset):
@@ -47,15 +53,18 @@ class ReverseTask(Dataset):
47
  self.sep_token = vocab_size - 1
48
  rng = random.Random(seed)
49
  self.samples = []
 
50
  content_vocab = vocab_size - 1
51
  half = seq_len // 2 - 1
52
  for _ in range(n_samples):
53
  content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
54
  ids = content + [self.sep_token] + list(reversed(content))
 
55
  ids = ids[:seq_len]
56
  while len(ids) < seq_len:
57
  ids.append(0)
58
  self.samples.append(ids)
 
59
 
60
  def __len__(self) -> int:
61
  return len(self.samples)
@@ -64,7 +73,11 @@ class ReverseTask(Dataset):
64
  ids = self.samples[idx]
65
  x = torch.tensor(ids[:-1], dtype=torch.long)
66
  y = torch.tensor(ids[1:], dtype=torch.long)
67
- return {"input_ids": x, "targets": y}
 
 
 
 
68
 
69
 
70
  class SelectiveRetrieval(Dataset):
@@ -79,6 +92,7 @@ class SelectiveRetrieval(Dataset):
79
  self.sep_token = vocab_size - 1
80
  rng = random.Random(seed)
81
  self.samples = []
 
82
  content_vocab = vocab_size - 2 # exclude SEP and padding
83
  for _ in range(n_samples):
84
  keys = rng.sample(range(content_vocab), min(n_pairs, content_vocab))
@@ -88,6 +102,7 @@ class SelectiveRetrieval(Dataset):
88
  ids = []
89
  for k, v in zip(keys, values):
90
  ids.extend([k, v])
 
91
  ids.append(self.sep_token)
92
  ids.append(keys[query_idx])
93
  ids.append(values[query_idx])
@@ -96,6 +111,7 @@ class SelectiveRetrieval(Dataset):
96
  while len(ids) < seq_len:
97
  ids.append(0)
98
  self.samples.append(ids)
 
99
 
100
  def __len__(self) -> int:
101
  return len(self.samples)
@@ -104,4 +120,8 @@ class SelectiveRetrieval(Dataset):
104
  ids = self.samples[idx]
105
  x = torch.tensor(ids[:-1], dtype=torch.long)
106
  y = torch.tensor(ids[1:], dtype=torch.long)
107
- return {"input_ids": x, "targets": y}
 
 
 
 
 
16
  self.sep_token = vocab_size - 1
17
  rng = random.Random(seed)
18
  self.samples = []
19
+ self.sep_positions = []
20
+ content_vocab = vocab_size - 1
21
  half = seq_len // 2 - 1
22
  for _ in range(n_samples):
23
  content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
 
24
  ids = content + [self.sep_token] + content
25
+ sep_pos = len(content)
26
  ids = ids[:seq_len]
27
  while len(ids) < seq_len:
28
  ids.append(0)
29
  self.samples.append(ids)
30
+ self.sep_positions.append(sep_pos)
31
 
32
  def __len__(self) -> int:
33
  return len(self.samples)
 
36
  ids = self.samples[idx]
37
  x = torch.tensor(ids[:-1], dtype=torch.long)
38
  y = torch.tensor(ids[1:], dtype=torch.long)
39
+ # loss_mask: 1 after SEP, 0 before (shifted by -1 for targets)
40
+ mask = torch.zeros_like(y)
41
+ sep = self.sep_positions[idx]
42
+ if sep < len(mask):
43
+ mask[sep:] = 1
44
+ return {"input_ids": x, "targets": y, "loss_mask": mask}
45
 
46
 
47
  class ReverseTask(Dataset):
 
53
  self.sep_token = vocab_size - 1
54
  rng = random.Random(seed)
55
  self.samples = []
56
+ self.sep_positions = []
57
  content_vocab = vocab_size - 1
58
  half = seq_len // 2 - 1
59
  for _ in range(n_samples):
60
  content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
61
  ids = content + [self.sep_token] + list(reversed(content))
62
+ sep_pos = len(content)
63
  ids = ids[:seq_len]
64
  while len(ids) < seq_len:
65
  ids.append(0)
66
  self.samples.append(ids)
67
+ self.sep_positions.append(sep_pos)
68
 
69
  def __len__(self) -> int:
70
  return len(self.samples)
 
73
  ids = self.samples[idx]
74
  x = torch.tensor(ids[:-1], dtype=torch.long)
75
  y = torch.tensor(ids[1:], dtype=torch.long)
76
+ mask = torch.zeros_like(y)
77
+ sep = self.sep_positions[idx]
78
+ if sep < len(mask):
79
+ mask[sep:] = 1
80
+ return {"input_ids": x, "targets": y, "loss_mask": mask}
81
 
82
 
83
  class SelectiveRetrieval(Dataset):
 
92
  self.sep_token = vocab_size - 1
93
  rng = random.Random(seed)
94
  self.samples = []
95
+ self.sep_positions = []
96
  content_vocab = vocab_size - 2 # exclude SEP and padding
97
  for _ in range(n_samples):
98
  keys = rng.sample(range(content_vocab), min(n_pairs, content_vocab))
 
102
  ids = []
103
  for k, v in zip(keys, values):
104
  ids.extend([k, v])
105
+ sep_pos = len(ids)
106
  ids.append(self.sep_token)
107
  ids.append(keys[query_idx])
108
  ids.append(values[query_idx])
 
111
  while len(ids) < seq_len:
112
  ids.append(0)
113
  self.samples.append(ids)
114
+ self.sep_positions.append(sep_pos)
115
 
116
  def __len__(self) -> int:
117
  return len(self.samples)
 
120
  ids = self.samples[idx]
121
  x = torch.tensor(ids[:-1], dtype=torch.long)
122
  y = torch.tensor(ids[1:], dtype=torch.long)
123
+ mask = torch.zeros_like(y)
124
+ sep = self.sep_positions[idx]
125
+ if sep < len(mask):
126
+ mask[sep:] = 1
127
+ return {"input_ids": x, "targets": y, "loss_mask": mask}
src/fog/model_baseline.py CHANGED
@@ -79,6 +79,7 @@ class BaselineTransformer(nn.Module):
79
  self,
80
  input_ids: torch.Tensor,
81
  targets: torch.Tensor | None = None,
 
82
  ) -> dict[str, torch.Tensor]:
83
  b, t = input_ids.shape
84
  pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
@@ -95,6 +96,16 @@ class BaselineTransformer(nn.Module):
95
 
96
  loss = None
97
  if targets is not None:
98
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
 
 
 
 
 
 
 
 
 
 
99
 
100
  return {"logits": logits, "loss": loss}
 
79
  self,
80
  input_ids: torch.Tensor,
81
  targets: torch.Tensor | None = None,
82
+ loss_mask: torch.Tensor | None = None,
83
  ) -> dict[str, torch.Tensor]:
84
  b, t = input_ids.shape
85
  pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
 
96
 
97
  loss = None
98
  if targets is not None:
99
+ if loss_mask is not None:
100
+ # only compute loss on target positions (after SEP)
101
+ flat_logits = logits.view(-1, logits.size(-1))
102
+ flat_targets = targets.view(-1)
103
+ flat_mask = loss_mask.view(-1).bool()
104
+ if flat_mask.any():
105
+ loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
106
+ else:
107
+ loss = torch.tensor(0.0, device=logits.device)
108
+ else:
109
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
110
 
111
  return {"logits": logits, "loss": loss}
src/fog/model_motif.py CHANGED
@@ -121,6 +121,7 @@ class MotifTransformer(nn.Module):
121
  self,
122
  input_ids: torch.Tensor,
123
  targets: torch.Tensor | None = None,
 
124
  ) -> dict[str, torch.Tensor]:
125
  b, t = input_ids.shape
126
  pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
@@ -136,6 +137,15 @@ class MotifTransformer(nn.Module):
136
 
137
  loss = None
138
  if targets is not None:
139
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
 
 
 
 
 
 
 
 
 
140
 
141
  return {"logits": logits, "loss": loss}
 
121
  self,
122
  input_ids: torch.Tensor,
123
  targets: torch.Tensor | None = None,
124
+ loss_mask: torch.Tensor | None = None,
125
  ) -> dict[str, torch.Tensor]:
126
  b, t = input_ids.shape
127
  pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
 
137
 
138
  loss = None
139
  if targets is not None:
140
+ if loss_mask is not None:
141
+ flat_logits = logits.view(-1, logits.size(-1))
142
+ flat_targets = targets.view(-1)
143
+ flat_mask = loss_mask.view(-1).bool()
144
+ if flat_mask.any():
145
+ loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
146
+ else:
147
+ loss = torch.tensor(0.0, device=logits.device)
148
+ else:
149
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
150
 
151
  return {"logits": logits, "loss": loss}
src/fog/train.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
  import torch
10
  from torch.utils.data import DataLoader
11
 
12
- from src.fog.config import FOGConfig, BASELINE_SMALL, MOTIF_SMALL, BASELINE_TINY, MOTIF_TINY
13
  from src.fog.model_baseline import BaselineTransformer
14
  from src.fog.model_motif import MotifTransformer
15
  from src.fog.data import CopyTask, ReverseTask, SelectiveRetrieval
@@ -31,7 +31,10 @@ def train_epoch(
31
  for batch in loader:
32
  input_ids = batch["input_ids"].to(device)
33
  targets = batch["targets"].to(device)
34
- out = model(input_ids, targets)
 
 
 
35
  loss = out["loss"]
36
  optimizer.zero_grad()
37
  loss.backward()
@@ -58,21 +61,29 @@ def eval_accuracy(
58
  for batch in loader:
59
  input_ids = batch["input_ids"].to(device)
60
  targets = batch["targets"].to(device)
61
- out = model(input_ids, targets)
 
 
 
62
  total_loss += out["loss"].item()
63
  n_batches += 1
64
 
65
  preds = out["logits"].argmax(dim=-1)
66
- # only count accuracy after SEP token
67
- for i in range(input_ids.size(0)):
68
- sep_positions = (input_ids[i] == sep_token).nonzero(as_tuple=True)[0]
69
- if len(sep_positions) == 0:
70
- continue
71
- start = sep_positions[0].item() + 1
72
- if start >= targets.size(1):
73
- continue
74
- correct += (preds[i, start:] == targets[i, start:]).sum().item()
75
- total += targets.size(1) - start
 
 
 
 
 
76
 
77
  return {
78
  "loss": total_loss / max(n_batches, 1),
@@ -89,17 +100,20 @@ def run_experiment(
89
  batch_size: int,
90
  lr: float,
91
  device: torch.device,
 
92
  ) -> dict:
93
- # Data
 
 
94
  n_train, n_eval = 5000, 500
95
  if task_name == "copy":
96
- train_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=42)
97
  eval_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
98
  elif task_name == "reverse":
99
- train_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=42)
100
  eval_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
101
  elif task_name == "retrieval":
102
- train_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_train, seed=42)
103
  eval_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
104
  else:
105
  raise ValueError(f"Unknown task: {task_name}")
@@ -141,6 +155,7 @@ def run_experiment(
141
  return {
142
  "model_type": model_type,
143
  "task": task_name,
 
144
  "n_params": n_params,
145
  "n_epochs": n_epochs,
146
  "elapsed_s": round(elapsed, 1),
@@ -159,37 +174,40 @@ def main() -> None:
159
  parser.add_argument("--lr", type=float, default=3e-4)
160
  parser.add_argument("--device", type=str, default="cpu")
161
  parser.add_argument("--size", type=str, default="tiny", choices=["tiny", "small"])
 
162
  parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
163
  args = parser.parse_args()
164
 
165
  device = torch.device(args.device)
166
 
167
  if args.size == "tiny":
168
- configs = [("baseline", BASELINE_TINY), ("motif", MOTIF_TINY)]
169
  else:
170
  configs = [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)]
171
 
172
  results = []
173
 
174
  for task in args.tasks:
175
- print(f"\n{'='*60}")
176
- print(f" Task: {task} (size={args.size})")
177
- print(f"{'='*60}")
178
-
179
- for model_type, cfg in configs:
180
- result = run_experiment(
181
- task_name=task,
182
- cfg=cfg,
183
- model_type=model_type,
184
- n_epochs=args.epochs,
185
- batch_size=args.batch_size,
186
- lr=args.lr,
187
- device=device,
188
- )
189
- results.append(result)
190
- print(f" → {model_type}: params={result['n_params']:,} "
191
- f"acc={result['final_accuracy']:.4f} "
192
- f"time={result['elapsed_s']}s")
 
 
193
 
194
  # Summary
195
  print(f"\n{'='*60}")
 
9
  import torch
10
  from torch.utils.data import DataLoader
11
 
12
+ from src.fog.config import FOGConfig, BASELINE_SMALL, MOTIF_SMALL, BASELINE_TINY, MOTIF_TINY, UNIFORM_TINY
13
  from src.fog.model_baseline import BaselineTransformer
14
  from src.fog.model_motif import MotifTransformer
15
  from src.fog.data import CopyTask, ReverseTask, SelectiveRetrieval
 
31
  for batch in loader:
32
  input_ids = batch["input_ids"].to(device)
33
  targets = batch["targets"].to(device)
34
+ loss_mask = batch.get("loss_mask")
35
+ if loss_mask is not None:
36
+ loss_mask = loss_mask.to(device)
37
+ out = model(input_ids, targets, loss_mask=loss_mask)
38
  loss = out["loss"]
39
  optimizer.zero_grad()
40
  loss.backward()
 
61
  for batch in loader:
62
  input_ids = batch["input_ids"].to(device)
63
  targets = batch["targets"].to(device)
64
+ loss_mask = batch.get("loss_mask")
65
+ if loss_mask is not None:
66
+ loss_mask = loss_mask.to(device)
67
+ out = model(input_ids, targets, loss_mask=loss_mask)
68
  total_loss += out["loss"].item()
69
  n_batches += 1
70
 
71
  preds = out["logits"].argmax(dim=-1)
72
+ # accuracy only on masked (target) positions
73
+ if loss_mask is not None:
74
+ m = loss_mask.bool()
75
+ correct += (preds[m] == targets[m]).sum().item()
76
+ total += m.sum().item()
77
+ else:
78
+ for i in range(input_ids.size(0)):
79
+ sep_positions = (input_ids[i] == sep_token).nonzero(as_tuple=True)[0]
80
+ if len(sep_positions) == 0:
81
+ continue
82
+ start = sep_positions[0].item() + 1
83
+ if start >= targets.size(1):
84
+ continue
85
+ correct += (preds[i, start:] == targets[i, start:]).sum().item()
86
+ total += targets.size(1) - start
87
 
88
  return {
89
  "loss": total_loss / max(n_batches, 1),
 
100
  batch_size: int,
101
  lr: float,
102
  device: torch.device,
103
+ seed: int = 42,
104
  ) -> dict:
105
+ torch.manual_seed(seed)
106
+
107
+ # Data — use fixed seeds for data, model seed varies
108
  n_train, n_eval = 5000, 500
109
  if task_name == "copy":
110
+ train_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0)
111
  eval_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
112
  elif task_name == "reverse":
113
+ train_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0)
114
  eval_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
115
  elif task_name == "retrieval":
116
+ train_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0)
117
  eval_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
118
  else:
119
  raise ValueError(f"Unknown task: {task_name}")
 
155
  return {
156
  "model_type": model_type,
157
  "task": task_name,
158
+ "seed": seed,
159
  "n_params": n_params,
160
  "n_epochs": n_epochs,
161
  "elapsed_s": round(elapsed, 1),
 
174
  parser.add_argument("--lr", type=float, default=3e-4)
175
  parser.add_argument("--device", type=str, default="cpu")
176
  parser.add_argument("--size", type=str, default="tiny", choices=["tiny", "small"])
177
+ parser.add_argument("--seeds", type=int, nargs="+", default=[42])
178
  parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
179
  args = parser.parse_args()
180
 
181
  device = torch.device(args.device)
182
 
183
  if args.size == "tiny":
184
+ configs = [("baseline", BASELINE_TINY), ("uniform_small", UNIFORM_TINY), ("motif", MOTIF_TINY)]
185
  else:
186
  configs = [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)]
187
 
188
  results = []
189
 
190
  for task in args.tasks:
191
+ for seed in args.seeds:
192
+ print(f"\n{'='*60}")
193
+ print(f" Task: {task} (size={args.size}, seed={seed})")
194
+ print(f"{'='*60}")
195
+
196
+ for model_type, cfg in configs:
197
+ result = run_experiment(
198
+ task_name=task,
199
+ cfg=cfg,
200
+ model_type=model_type,
201
+ n_epochs=args.epochs,
202
+ batch_size=args.batch_size,
203
+ lr=args.lr,
204
+ device=device,
205
+ seed=seed,
206
+ )
207
+ results.append(result)
208
+ print(f" -> {model_type}: params={result['n_params']:,} "
209
+ f"acc={result['final_accuracy']:.4f} "
210
+ f"time={result['elapsed_s']}s")
211
 
212
  # Summary
213
  print(f"\n{'='*60}")