abpt / tests / test_fog.py
Search
sync: FOG micro+medium configs, stress tasks, fast pipeline
6ef010e
"""Tests for FOG baseline and motif-aware models."""
from __future__ import annotations
import torch
from src.fog.config import (
BASELINE_SMALL, MOTIF_SMALL,
BASELINE_MICRO, MOTIF_MICRO, UNIFORM_MICRO,
)
from src.fog.model_baseline import BaselineTransformer
from src.fog.model_motif import MotifTransformer
from src.fog.data import (
CopyTask, ReverseTask, SelectiveRetrieval,
ChainedRetrieval, prebatch_dataset, TensorBatchIterator,
)
def test_baseline_forward_backward() -> None:
    """Run one forward and backward pass through the small baseline model.

    Random token ids are used for both inputs and targets. The test checks
    the shape of the returned logits, that a loss is present, and that the
    loss can be backpropagated without raising.
    """
    cfg = BASELINE_SMALL
    batch_size, seq_len = 2, 32

    net = BaselineTransformer(cfg)
    tokens = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))

    result = net(tokens, labels)

    # One logit vector over the vocabulary per (sample, position).
    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    assert result["logits"].shape == expected_shape

    loss = result["loss"]
    assert loss is not None
    loss.backward()
def test_motif_forward_backward() -> None:
    """Run one forward and backward pass through the small motif model.

    Mirrors the baseline smoke test: random ids in, logits of shape
    (batch, sequence, vocab) and a non-None loss out, followed by a
    backward pass that must not raise.
    """
    vocab = MOTIF_SMALL.vocab_size
    dims = (2, 32)  # (batch, sequence length)

    net = MotifTransformer(MOTIF_SMALL)
    tokens = torch.randint(0, vocab, dims)
    labels = torch.randint(0, vocab, dims)

    result = net(tokens, labels)

    assert result["logits"].shape == (*dims, MOTIF_SMALL.vocab_size)
    loss = result["loss"]
    assert loss is not None
    loss.backward()
def test_motif_has_separate_subspaces() -> None:
    """Check the layer widths of the motif model's first block.

    The attention query/key projections must output ``d_compare`` features
    and the value projection ``d_memory``; the feed-forward gate must output
    ``d_gate`` features and the expansion ``d_expand``. The small motif
    config is also required to satisfy ``d_compare < d_memory`` and
    ``d_gate < d_expand``.
    """
    cfg = MOTIF_SMALL
    first_block = MotifTransformer(cfg).blocks[0]

    # Attention: query and key share one width, values use another.
    attention = first_block.attn
    for proj_name in ("q_proj", "k_proj"):
        assert getattr(attention, proj_name).out_features == cfg.d_compare
    assert attention.v_proj.out_features == cfg.d_memory
    assert cfg.d_compare < cfg.d_memory

    # Feed-forward: the gate is narrower than the expansion.
    feed_forward = first_block.ffn
    assert feed_forward.gate.out_features == cfg.d_gate
    assert feed_forward.expand.out_features == cfg.d_expand
    assert cfg.d_gate < cfg.d_expand
def test_copy_task() -> None:
    """Check the size of a CopyTask dataset and the shape of its samples.

    With ``seq_len=32`` the test expects both ``input_ids`` and ``targets``
    of the first sample to be 1-D with 31 entries.
    """
    n_samples = 10
    dataset = CopyTask(vocab_size=64, seq_len=32, n_samples=n_samples)
    assert len(dataset) == n_samples

    first = dataset[0]
    # NOTE(review): 31 == seq_len - 1, presumably from a next-token shift;
    # confirm against CopyTask.
    for field in ("input_ids", "targets"):
        assert first[field].shape == (31,)
def test_reverse_task() -> None:
    """Check that ReverseTask reports the requested number of samples."""
    requested = 10
    dataset = ReverseTask(vocab_size=64, seq_len=32, n_samples=requested)
    assert len(dataset) == requested
def test_selective_retrieval_task() -> None:
    """Check that SelectiveRetrieval reports the requested number of samples."""
    requested = 10
    dataset = SelectiveRetrieval(
        vocab_size=64,
        seq_len=32,
        n_samples=requested,
        n_pairs=4,
    )
    assert len(dataset) == requested
def test_chained_retrieval_task() -> None:
    """Check basic properties of a ChainedRetrieval dataset and its samples.

    Only non-emptiness of the dataset is asserted, not the requested sample
    count. With ``seq_len=64`` the first sample's ``input_ids`` and
    ``targets`` must have 63 entries along their first dimension, and its
    ``loss_mask`` must have a positive sum.
    """
    dataset = ChainedRetrieval(
        vocab_size=128,
        seq_len=64,
        n_samples=50,
        n_pairs=6,
    )
    assert len(dataset) > 0

    first = dataset[0]
    for field in ("input_ids", "targets"):
        assert first[field].shape[0] == 63

    # A mask that sums to zero would mean no position contributes.
    mask_total = first["loss_mask"].sum()
    assert mask_total > 0
def test_prebatch_and_iterator() -> None:
    """Check prebatching of a CopyTask dataset and iteration over batches.

    Twenty samples are prebatched into a mapping whose ``input_ids`` entry
    must have shape (20, 31). Iterating it with ``batch_size=8`` must yield
    three batches, the first holding eight rows.
    """
    dataset = CopyTask(vocab_size=64, seq_len=32, n_samples=20)

    # NOTE(review): the second positional argument (32) equals seq_len here;
    # confirm its meaning against prebatch_dataset's signature.
    tensors = prebatch_dataset(dataset, 32)
    assert tensors["input_ids"].shape == (20, 31)

    iterator = TensorBatchIterator(tensors, batch_size=8, shuffle=True)
    all_batches = list(iterator)
    assert len(all_batches) == 3

    first_batch = all_batches[0]
    assert first_batch["input_ids"].shape[0] == 8
def test_micro_configs_forward() -> None:
    """Smoke-test a forward and backward pass for every micro config.

    Each case pairs a config with the model class that consumes it. Two of
    the three cases share ``BaselineTransformer``, so each case carries a
    label that is included in the assertion message; a failure then shows
    which config produced it.
    """
    cases = (
        ("BASELINE_MICRO", BASELINE_MICRO, BaselineTransformer),
        ("UNIFORM_MICRO", UNIFORM_MICRO, BaselineTransformer),
        ("MOTIF_MICRO", MOTIF_MICRO, MotifTransformer),
    )
    for label, cfg, cls in cases:
        model = cls(cfg)
        x = torch.randint(0, cfg.vocab_size, (2, 32))
        y = torch.randint(0, cfg.vocab_size, (2, 32))
        out = model(x, y)
        assert out["loss"] is not None, f"{label}: forward returned no loss"
        out["loss"].backward()