Spaces:
Paused
Paused
Search committed on
Commit ·
35e5b8a
1
Parent(s): c3488d4
fix: proper controls — loss masking, uniform baseline, multi-seed
Browse files
- src/fog/config.py +11 -0
- src/fog/data.py +26 -6
- src/fog/model_baseline.py +12 -1
- src/fog/model_motif.py +11 -1
- src/fog/train.py +54 -36
src/fog/config.py
CHANGED
|
@@ -34,6 +34,17 @@ MOTIF_SMALL = FOGConfig(
|
|
| 34 |
d_gate=32,
|
| 35 |
)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# Tiny configs for fast iteration
|
| 38 |
BASELINE_TINY = FOGConfig(
|
| 39 |
vocab_size=32,
|
|
|
|
| 34 |
d_gate=32,
|
| 35 |
)
|
| 36 |
|
| 37 |
+
# Param-matched uniform baseline for controlled comparison
|
| 38 |
+
# d_model=94, d_ff=376 → ~432K params to match MOTIF_TINY
|
| 39 |
+
UNIFORM_TINY = FOGConfig(
|
| 40 |
+
vocab_size=32,
|
| 41 |
+
d_model=94,
|
| 42 |
+
n_layers=4,
|
| 43 |
+
n_heads=2,
|
| 44 |
+
max_seq_len=32,
|
| 45 |
+
d_ff=376,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
# Tiny configs for fast iteration
|
| 49 |
BASELINE_TINY = FOGConfig(
|
| 50 |
vocab_size=32,
|
src/fog/data.py
CHANGED
|
@@ -16,17 +16,18 @@ class CopyTask(Dataset):
|
|
| 16 |
self.sep_token = vocab_size - 1
|
| 17 |
rng = random.Random(seed)
|
| 18 |
self.samples = []
|
| 19 |
-
|
|
|
|
| 20 |
half = seq_len // 2 - 1
|
| 21 |
for _ in range(n_samples):
|
| 22 |
content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
|
| 23 |
-
# input: content + SEP + content (teacher forcing)
|
| 24 |
ids = content + [self.sep_token] + content
|
| 25 |
-
|
| 26 |
ids = ids[:seq_len]
|
| 27 |
while len(ids) < seq_len:
|
| 28 |
ids.append(0)
|
| 29 |
self.samples.append(ids)
|
|
|
|
| 30 |
|
| 31 |
def __len__(self) -> int:
|
| 32 |
return len(self.samples)
|
|
@@ -35,7 +36,12 @@ class CopyTask(Dataset):
|
|
| 35 |
ids = self.samples[idx]
|
| 36 |
x = torch.tensor(ids[:-1], dtype=torch.long)
|
| 37 |
y = torch.tensor(ids[1:], dtype=torch.long)
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
class ReverseTask(Dataset):
|
|
@@ -47,15 +53,18 @@ class ReverseTask(Dataset):
|
|
| 47 |
self.sep_token = vocab_size - 1
|
| 48 |
rng = random.Random(seed)
|
| 49 |
self.samples = []
|
|
|
|
| 50 |
content_vocab = vocab_size - 1
|
| 51 |
half = seq_len // 2 - 1
|
| 52 |
for _ in range(n_samples):
|
| 53 |
content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
|
| 54 |
ids = content + [self.sep_token] + list(reversed(content))
|
|
|
|
| 55 |
ids = ids[:seq_len]
|
| 56 |
while len(ids) < seq_len:
|
| 57 |
ids.append(0)
|
| 58 |
self.samples.append(ids)
|
|
|
|
| 59 |
|
| 60 |
def __len__(self) -> int:
|
| 61 |
return len(self.samples)
|
|
@@ -64,7 +73,11 @@ class ReverseTask(Dataset):
|
|
| 64 |
ids = self.samples[idx]
|
| 65 |
x = torch.tensor(ids[:-1], dtype=torch.long)
|
| 66 |
y = torch.tensor(ids[1:], dtype=torch.long)
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
class SelectiveRetrieval(Dataset):
|
|
@@ -79,6 +92,7 @@ class SelectiveRetrieval(Dataset):
|
|
| 79 |
self.sep_token = vocab_size - 1
|
| 80 |
rng = random.Random(seed)
|
| 81 |
self.samples = []
|
|
|
|
| 82 |
content_vocab = vocab_size - 2 # exclude SEP and padding
|
| 83 |
for _ in range(n_samples):
|
| 84 |
keys = rng.sample(range(content_vocab), min(n_pairs, content_vocab))
|
|
@@ -88,6 +102,7 @@ class SelectiveRetrieval(Dataset):
|
|
| 88 |
ids = []
|
| 89 |
for k, v in zip(keys, values):
|
| 90 |
ids.extend([k, v])
|
|
|
|
| 91 |
ids.append(self.sep_token)
|
| 92 |
ids.append(keys[query_idx])
|
| 93 |
ids.append(values[query_idx])
|
|
@@ -96,6 +111,7 @@ class SelectiveRetrieval(Dataset):
|
|
| 96 |
while len(ids) < seq_len:
|
| 97 |
ids.append(0)
|
| 98 |
self.samples.append(ids)
|
|
|
|
| 99 |
|
| 100 |
def __len__(self) -> int:
|
| 101 |
return len(self.samples)
|
|
@@ -104,4 +120,8 @@ class SelectiveRetrieval(Dataset):
|
|
| 104 |
ids = self.samples[idx]
|
| 105 |
x = torch.tensor(ids[:-1], dtype=torch.long)
|
| 106 |
y = torch.tensor(ids[1:], dtype=torch.long)
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
self.sep_token = vocab_size - 1
|
| 17 |
rng = random.Random(seed)
|
| 18 |
self.samples = []
|
| 19 |
+
self.sep_positions = []
|
| 20 |
+
content_vocab = vocab_size - 1
|
| 21 |
half = seq_len // 2 - 1
|
| 22 |
for _ in range(n_samples):
|
| 23 |
content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
|
|
|
|
| 24 |
ids = content + [self.sep_token] + content
|
| 25 |
+
sep_pos = len(content)
|
| 26 |
ids = ids[:seq_len]
|
| 27 |
while len(ids) < seq_len:
|
| 28 |
ids.append(0)
|
| 29 |
self.samples.append(ids)
|
| 30 |
+
self.sep_positions.append(sep_pos)
|
| 31 |
|
| 32 |
def __len__(self) -> int:
|
| 33 |
return len(self.samples)
|
|
|
|
| 36 |
ids = self.samples[idx]
|
| 37 |
x = torch.tensor(ids[:-1], dtype=torch.long)
|
| 38 |
y = torch.tensor(ids[1:], dtype=torch.long)
|
| 39 |
+
# loss_mask: 1 on targets after SEP (including any trailing padding), 0 on the rest; targets are ids[1:], so target index `sep` is the first post-SEP token
|
| 40 |
+
mask = torch.zeros_like(y)
|
| 41 |
+
sep = self.sep_positions[idx]
|
| 42 |
+
if sep < len(mask):
|
| 43 |
+
mask[sep:] = 1
|
| 44 |
+
return {"input_ids": x, "targets": y, "loss_mask": mask}
|
| 45 |
|
| 46 |
|
| 47 |
class ReverseTask(Dataset):
|
|
|
|
| 53 |
self.sep_token = vocab_size - 1
|
| 54 |
rng = random.Random(seed)
|
| 55 |
self.samples = []
|
| 56 |
+
self.sep_positions = []
|
| 57 |
content_vocab = vocab_size - 1
|
| 58 |
half = seq_len // 2 - 1
|
| 59 |
for _ in range(n_samples):
|
| 60 |
content = [rng.randint(0, content_vocab - 1) for _ in range(half)]
|
| 61 |
ids = content + [self.sep_token] + list(reversed(content))
|
| 62 |
+
sep_pos = len(content)
|
| 63 |
ids = ids[:seq_len]
|
| 64 |
while len(ids) < seq_len:
|
| 65 |
ids.append(0)
|
| 66 |
self.samples.append(ids)
|
| 67 |
+
self.sep_positions.append(sep_pos)
|
| 68 |
|
| 69 |
def __len__(self) -> int:
|
| 70 |
return len(self.samples)
|
|
|
|
| 73 |
ids = self.samples[idx]
|
| 74 |
x = torch.tensor(ids[:-1], dtype=torch.long)
|
| 75 |
y = torch.tensor(ids[1:], dtype=torch.long)
|
| 76 |
+
mask = torch.zeros_like(y)
|
| 77 |
+
sep = self.sep_positions[idx]
|
| 78 |
+
if sep < len(mask):
|
| 79 |
+
mask[sep:] = 1
|
| 80 |
+
return {"input_ids": x, "targets": y, "loss_mask": mask}
|
| 81 |
|
| 82 |
|
| 83 |
class SelectiveRetrieval(Dataset):
|
|
|
|
| 92 |
self.sep_token = vocab_size - 1
|
| 93 |
rng = random.Random(seed)
|
| 94 |
self.samples = []
|
| 95 |
+
self.sep_positions = []
|
| 96 |
content_vocab = vocab_size - 2 # exclude SEP and padding
|
| 97 |
for _ in range(n_samples):
|
| 98 |
keys = rng.sample(range(content_vocab), min(n_pairs, content_vocab))
|
|
|
|
| 102 |
ids = []
|
| 103 |
for k, v in zip(keys, values):
|
| 104 |
ids.extend([k, v])
|
| 105 |
+
sep_pos = len(ids)
|
| 106 |
ids.append(self.sep_token)
|
| 107 |
ids.append(keys[query_idx])
|
| 108 |
ids.append(values[query_idx])
|
|
|
|
| 111 |
while len(ids) < seq_len:
|
| 112 |
ids.append(0)
|
| 113 |
self.samples.append(ids)
|
| 114 |
+
self.sep_positions.append(sep_pos)
|
| 115 |
|
| 116 |
def __len__(self) -> int:
|
| 117 |
return len(self.samples)
|
|
|
|
| 120 |
ids = self.samples[idx]
|
| 121 |
x = torch.tensor(ids[:-1], dtype=torch.long)
|
| 122 |
y = torch.tensor(ids[1:], dtype=torch.long)
|
| 123 |
+
mask = torch.zeros_like(y)
|
| 124 |
+
sep = self.sep_positions[idx]
|
| 125 |
+
if sep < len(mask):
|
| 126 |
+
mask[sep:] = 1
|
| 127 |
+
return {"input_ids": x, "targets": y, "loss_mask": mask}
|
src/fog/model_baseline.py
CHANGED
|
@@ -79,6 +79,7 @@ class BaselineTransformer(nn.Module):
|
|
| 79 |
self,
|
| 80 |
input_ids: torch.Tensor,
|
| 81 |
targets: torch.Tensor | None = None,
|
|
|
|
| 82 |
) -> dict[str, torch.Tensor]:
|
| 83 |
b, t = input_ids.shape
|
| 84 |
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
|
|
@@ -95,6 +96,16 @@ class BaselineTransformer(nn.Module):
|
|
| 95 |
|
| 96 |
loss = None
|
| 97 |
if targets is not None:
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
return {"logits": logits, "loss": loss}
|
|
|
|
| 79 |
self,
|
| 80 |
input_ids: torch.Tensor,
|
| 81 |
targets: torch.Tensor | None = None,
|
| 82 |
+
loss_mask: torch.Tensor | None = None,
|
| 83 |
) -> dict[str, torch.Tensor]:
|
| 84 |
b, t = input_ids.shape
|
| 85 |
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
|
|
|
|
| 96 |
|
| 97 |
loss = None
|
| 98 |
if targets is not None:
|
| 99 |
+
if loss_mask is not None:
|
| 100 |
+
# only compute loss on target positions (after SEP)
|
| 101 |
+
flat_logits = logits.view(-1, logits.size(-1))
|
| 102 |
+
flat_targets = targets.view(-1)
|
| 103 |
+
flat_mask = loss_mask.view(-1).bool()
|
| 104 |
+
if flat_mask.any():
|
| 105 |
+
loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
|
| 106 |
+
else:
|
| 107 |
+
loss = torch.tensor(0.0, device=logits.device)
|
| 108 |
+
else:
|
| 109 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
| 110 |
|
| 111 |
return {"logits": logits, "loss": loss}
|
src/fog/model_motif.py
CHANGED
|
@@ -121,6 +121,7 @@ class MotifTransformer(nn.Module):
|
|
| 121 |
self,
|
| 122 |
input_ids: torch.Tensor,
|
| 123 |
targets: torch.Tensor | None = None,
|
|
|
|
| 124 |
) -> dict[str, torch.Tensor]:
|
| 125 |
b, t = input_ids.shape
|
| 126 |
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
|
|
@@ -136,6 +137,15 @@ class MotifTransformer(nn.Module):
|
|
| 136 |
|
| 137 |
loss = None
|
| 138 |
if targets is not None:
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
return {"logits": logits, "loss": loss}
|
|
|
|
| 121 |
self,
|
| 122 |
input_ids: torch.Tensor,
|
| 123 |
targets: torch.Tensor | None = None,
|
| 124 |
+
loss_mask: torch.Tensor | None = None,
|
| 125 |
) -> dict[str, torch.Tensor]:
|
| 126 |
b, t = input_ids.shape
|
| 127 |
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
|
|
|
|
| 137 |
|
| 138 |
loss = None
|
| 139 |
if targets is not None:
|
| 140 |
+
if loss_mask is not None:
|
| 141 |
+
flat_logits = logits.view(-1, logits.size(-1))
|
| 142 |
+
flat_targets = targets.view(-1)
|
| 143 |
+
flat_mask = loss_mask.view(-1).bool()
|
| 144 |
+
if flat_mask.any():
|
| 145 |
+
loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
|
| 146 |
+
else:
|
| 147 |
+
loss = torch.tensor(0.0, device=logits.device)
|
| 148 |
+
else:
|
| 149 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
| 150 |
|
| 151 |
return {"logits": logits, "loss": loss}
|
src/fog/train.py
CHANGED
|
@@ -9,7 +9,7 @@ from pathlib import Path
|
|
| 9 |
import torch
|
| 10 |
from torch.utils.data import DataLoader
|
| 11 |
|
| 12 |
-
from src.fog.config import FOGConfig, BASELINE_SMALL, MOTIF_SMALL, BASELINE_TINY, MOTIF_TINY
|
| 13 |
from src.fog.model_baseline import BaselineTransformer
|
| 14 |
from src.fog.model_motif import MotifTransformer
|
| 15 |
from src.fog.data import CopyTask, ReverseTask, SelectiveRetrieval
|
|
@@ -31,7 +31,10 @@ def train_epoch(
|
|
| 31 |
for batch in loader:
|
| 32 |
input_ids = batch["input_ids"].to(device)
|
| 33 |
targets = batch["targets"].to(device)
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
| 35 |
loss = out["loss"]
|
| 36 |
optimizer.zero_grad()
|
| 37 |
loss.backward()
|
|
@@ -58,21 +61,29 @@ def eval_accuracy(
|
|
| 58 |
for batch in loader:
|
| 59 |
input_ids = batch["input_ids"].to(device)
|
| 60 |
targets = batch["targets"].to(device)
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
| 62 |
total_loss += out["loss"].item()
|
| 63 |
n_batches += 1
|
| 64 |
|
| 65 |
preds = out["logits"].argmax(dim=-1)
|
| 66 |
-
# only
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
return {
|
| 78 |
"loss": total_loss / max(n_batches, 1),
|
|
@@ -89,17 +100,20 @@ def run_experiment(
|
|
| 89 |
batch_size: int,
|
| 90 |
lr: float,
|
| 91 |
device: torch.device,
|
|
|
|
| 92 |
) -> dict:
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
n_train, n_eval = 5000, 500
|
| 95 |
if task_name == "copy":
|
| 96 |
-
train_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=
|
| 97 |
eval_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
|
| 98 |
elif task_name == "reverse":
|
| 99 |
-
train_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=
|
| 100 |
eval_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
|
| 101 |
elif task_name == "retrieval":
|
| 102 |
-
train_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_train, seed=
|
| 103 |
eval_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
|
| 104 |
else:
|
| 105 |
raise ValueError(f"Unknown task: {task_name}")
|
|
@@ -141,6 +155,7 @@ def run_experiment(
|
|
| 141 |
return {
|
| 142 |
"model_type": model_type,
|
| 143 |
"task": task_name,
|
|
|
|
| 144 |
"n_params": n_params,
|
| 145 |
"n_epochs": n_epochs,
|
| 146 |
"elapsed_s": round(elapsed, 1),
|
|
@@ -159,37 +174,40 @@ def main() -> None:
|
|
| 159 |
parser.add_argument("--lr", type=float, default=3e-4)
|
| 160 |
parser.add_argument("--device", type=str, default="cpu")
|
| 161 |
parser.add_argument("--size", type=str, default="tiny", choices=["tiny", "small"])
|
|
|
|
| 162 |
parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
|
| 163 |
args = parser.parse_args()
|
| 164 |
|
| 165 |
device = torch.device(args.device)
|
| 166 |
|
| 167 |
if args.size == "tiny":
|
| 168 |
-
configs = [("baseline", BASELINE_TINY), ("motif", MOTIF_TINY)]
|
| 169 |
else:
|
| 170 |
configs = [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)]
|
| 171 |
|
| 172 |
results = []
|
| 173 |
|
| 174 |
for task in args.tasks:
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
| 193 |
|
| 194 |
# Summary
|
| 195 |
print(f"\n{'='*60}")
|
|
|
|
| 9 |
import torch
|
| 10 |
from torch.utils.data import DataLoader
|
| 11 |
|
| 12 |
+
from src.fog.config import FOGConfig, BASELINE_SMALL, MOTIF_SMALL, BASELINE_TINY, MOTIF_TINY, UNIFORM_TINY
|
| 13 |
from src.fog.model_baseline import BaselineTransformer
|
| 14 |
from src.fog.model_motif import MotifTransformer
|
| 15 |
from src.fog.data import CopyTask, ReverseTask, SelectiveRetrieval
|
|
|
|
| 31 |
for batch in loader:
|
| 32 |
input_ids = batch["input_ids"].to(device)
|
| 33 |
targets = batch["targets"].to(device)
|
| 34 |
+
loss_mask = batch.get("loss_mask")
|
| 35 |
+
if loss_mask is not None:
|
| 36 |
+
loss_mask = loss_mask.to(device)
|
| 37 |
+
out = model(input_ids, targets, loss_mask=loss_mask)
|
| 38 |
loss = out["loss"]
|
| 39 |
optimizer.zero_grad()
|
| 40 |
loss.backward()
|
|
|
|
| 61 |
for batch in loader:
|
| 62 |
input_ids = batch["input_ids"].to(device)
|
| 63 |
targets = batch["targets"].to(device)
|
| 64 |
+
loss_mask = batch.get("loss_mask")
|
| 65 |
+
if loss_mask is not None:
|
| 66 |
+
loss_mask = loss_mask.to(device)
|
| 67 |
+
out = model(input_ids, targets, loss_mask=loss_mask)
|
| 68 |
total_loss += out["loss"].item()
|
| 69 |
n_batches += 1
|
| 70 |
|
| 71 |
preds = out["logits"].argmax(dim=-1)
|
| 72 |
+
# accuracy only on masked (target) positions
|
| 73 |
+
if loss_mask is not None:
|
| 74 |
+
m = loss_mask.bool()
|
| 75 |
+
correct += (preds[m] == targets[m]).sum().item()
|
| 76 |
+
total += m.sum().item()
|
| 77 |
+
else:
|
| 78 |
+
for i in range(input_ids.size(0)):
|
| 79 |
+
sep_positions = (input_ids[i] == sep_token).nonzero(as_tuple=True)[0]
|
| 80 |
+
if len(sep_positions) == 0:
|
| 81 |
+
continue
|
| 82 |
+
start = sep_positions[0].item() + 1
|
| 83 |
+
if start >= targets.size(1):
|
| 84 |
+
continue
|
| 85 |
+
correct += (preds[i, start:] == targets[i, start:]).sum().item()
|
| 86 |
+
total += targets.size(1) - start
|
| 87 |
|
| 88 |
return {
|
| 89 |
"loss": total_loss / max(n_batches, 1),
|
|
|
|
| 100 |
batch_size: int,
|
| 101 |
lr: float,
|
| 102 |
device: torch.device,
|
| 103 |
+
seed: int = 42,
|
| 104 |
) -> dict:
|
| 105 |
+
torch.manual_seed(seed)
|
| 106 |
+
|
| 107 |
+
# Data — use fixed seeds for data, model seed varies
|
| 108 |
n_train, n_eval = 5000, 500
|
| 109 |
if task_name == "copy":
|
| 110 |
+
train_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0)
|
| 111 |
eval_ds = CopyTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
|
| 112 |
elif task_name == "reverse":
|
| 113 |
+
train_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0)
|
| 114 |
eval_ds = ReverseTask(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
|
| 115 |
elif task_name == "retrieval":
|
| 116 |
+
train_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0)
|
| 117 |
eval_ds = SelectiveRetrieval(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99)
|
| 118 |
else:
|
| 119 |
raise ValueError(f"Unknown task: {task_name}")
|
|
|
|
| 155 |
return {
|
| 156 |
"model_type": model_type,
|
| 157 |
"task": task_name,
|
| 158 |
+
"seed": seed,
|
| 159 |
"n_params": n_params,
|
| 160 |
"n_epochs": n_epochs,
|
| 161 |
"elapsed_s": round(elapsed, 1),
|
|
|
|
| 174 |
parser.add_argument("--lr", type=float, default=3e-4)
|
| 175 |
parser.add_argument("--device", type=str, default="cpu")
|
| 176 |
parser.add_argument("--size", type=str, default="tiny", choices=["tiny", "small"])
|
| 177 |
+
parser.add_argument("--seeds", type=int, nargs="+", default=[42])
|
| 178 |
parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
|
| 179 |
args = parser.parse_args()
|
| 180 |
|
| 181 |
device = torch.device(args.device)
|
| 182 |
|
| 183 |
if args.size == "tiny":
|
| 184 |
+
configs = [("baseline", BASELINE_TINY), ("uniform_small", UNIFORM_TINY), ("motif", MOTIF_TINY)]
|
| 185 |
else:
|
| 186 |
configs = [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)]
|
| 187 |
|
| 188 |
results = []
|
| 189 |
|
| 190 |
for task in args.tasks:
|
| 191 |
+
for seed in args.seeds:
|
| 192 |
+
print(f"\n{'='*60}")
|
| 193 |
+
print(f" Task: {task} (size={args.size}, seed={seed})")
|
| 194 |
+
print(f"{'='*60}")
|
| 195 |
+
|
| 196 |
+
for model_type, cfg in configs:
|
| 197 |
+
result = run_experiment(
|
| 198 |
+
task_name=task,
|
| 199 |
+
cfg=cfg,
|
| 200 |
+
model_type=model_type,
|
| 201 |
+
n_epochs=args.epochs,
|
| 202 |
+
batch_size=args.batch_size,
|
| 203 |
+
lr=args.lr,
|
| 204 |
+
device=device,
|
| 205 |
+
seed=seed,
|
| 206 |
+
)
|
| 207 |
+
results.append(result)
|
| 208 |
+
print(f" -> {model_type}: params={result['n_params']:,} "
|
| 209 |
+
f"acc={result['final_accuracy']:.4f} "
|
| 210 |
+
f"time={result['elapsed_s']}s")
|
| 211 |
|
| 212 |
# Summary
|
| 213 |
print(f"\n{'='*60}")
|