Search committed on
Commit
c3488d4
·
1 Parent(s): 019d725

fix: add tiny config (seq=32, d=128) for fast FOG ablation

Browse files
Files changed (2) hide show
  1. src/fog/config.py +23 -0
  2. src/fog/train.py +10 -3
src/fog/config.py CHANGED
@@ -33,3 +33,26 @@ MOTIF_SMALL = FOGConfig(
33
  d_expand=512,
34
  d_gate=32,
35
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  d_expand=512,
34
  d_gate=32,
35
  )
36
+
37
+ # Tiny configs for fast iteration
38
+ BASELINE_TINY = FOGConfig(
39
+ vocab_size=32,
40
+ d_model=128,
41
+ n_layers=4,
42
+ n_heads=4,
43
+ max_seq_len=32,
44
+ d_ff=512,
45
+ )
46
+
47
+ MOTIF_TINY = FOGConfig(
48
+ vocab_size=32,
49
+ d_model=128,
50
+ n_layers=4,
51
+ n_heads=4,
52
+ max_seq_len=32,
53
+ d_ff=512,
54
+ d_compare=32,
55
+ d_memory=96,
56
+ d_expand=256,
57
+ d_gate=16,
58
+ )
src/fog/train.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
  import torch
10
  from torch.utils.data import DataLoader
11
 
12
- from src.fog.config import FOGConfig, BASELINE_SMALL, MOTIF_SMALL
13
  from src.fog.model_baseline import BaselineTransformer
14
  from src.fog.model_motif import MotifTransformer
15
  from src.fog.data import CopyTask, ReverseTask, SelectiveRetrieval
@@ -158,18 +158,25 @@ def main() -> None:
158
  parser.add_argument("--batch_size", type=int, default=64)
159
  parser.add_argument("--lr", type=float, default=3e-4)
160
  parser.add_argument("--device", type=str, default="cpu")
 
161
  parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
162
  args = parser.parse_args()
163
 
164
  device = torch.device(args.device)
 
 
 
 
 
 
165
  results = []
166
 
167
  for task in args.tasks:
168
  print(f"\n{'='*60}")
169
- print(f" Task: {task}")
170
  print(f"{'='*60}")
171
 
172
- for model_type, cfg in [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)]:
173
  result = run_experiment(
174
  task_name=task,
175
  cfg=cfg,
 
9
  import torch
10
  from torch.utils.data import DataLoader
11
 
12
+ from src.fog.config import FOGConfig, BASELINE_SMALL, MOTIF_SMALL, BASELINE_TINY, MOTIF_TINY
13
  from src.fog.model_baseline import BaselineTransformer
14
  from src.fog.model_motif import MotifTransformer
15
  from src.fog.data import CopyTask, ReverseTask, SelectiveRetrieval
 
158
  parser.add_argument("--batch_size", type=int, default=64)
159
  parser.add_argument("--lr", type=float, default=3e-4)
160
  parser.add_argument("--device", type=str, default="cpu")
161
+ parser.add_argument("--size", type=str, default="tiny", choices=["tiny", "small"])
162
  parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
163
  args = parser.parse_args()
164
 
165
  device = torch.device(args.device)
166
+
167
+ if args.size == "tiny":
168
+ configs = [("baseline", BASELINE_TINY), ("motif", MOTIF_TINY)]
169
+ else:
170
+ configs = [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)]
171
+
172
  results = []
173
 
174
  for task in args.tasks:
175
  print(f"\n{'='*60}")
176
+ print(f" Task: {task} (size={args.size})")
177
  print(f"{'='*60}")
178
 
179
+ for model_type, cfg in configs:
180
  result = run_experiment(
181
  task_name=task,
182
  cfg=cfg,