Spaces:
Paused
Paused
| """Tests for FOG baseline and motif-aware models.""" | |
| from __future__ import annotations | |
| import torch | |
| from src.fog.config import ( | |
| BASELINE_SMALL, MOTIF_SMALL, | |
| BASELINE_MICRO, MOTIF_MICRO, UNIFORM_MICRO, | |
| ) | |
| from src.fog.model_baseline import BaselineTransformer | |
| from src.fog.model_motif import MotifTransformer | |
| from src.fog.data import ( | |
| CopyTask, ReverseTask, SelectiveRetrieval, | |
| ChainedRetrieval, prebatch_dataset, TensorBatchIterator, | |
| ) | |
def test_baseline_forward_backward() -> None:
    """BaselineTransformer forward/backward smoke test on BASELINE_SMALL.

    Checks that the forward pass returns logits over the full vocabulary for
    every position, that it yields a finite loss, and that back-propagating
    that loss produces only finite gradients.
    """
    model = BaselineTransformer(BASELINE_SMALL)
    x = torch.randint(0, BASELINE_SMALL.vocab_size, (2, 32))
    y = torch.randint(0, BASELINE_SMALL.vocab_size, (2, 32))
    out = model(x, y)
    assert out["logits"].shape == (2, 32, BASELINE_SMALL.vocab_size)
    loss = out["loss"]
    assert loss is not None
    # A NaN/inf loss back-propagates without raising, so check it explicitly.
    assert torch.isfinite(loss).all()
    loss.backward()
    # Likewise, NaN/inf gradients would not make backward() fail on their own.
    for param in model.parameters():
        if param.grad is not None:
            assert torch.isfinite(param.grad).all()
def test_motif_forward_backward() -> None:
    """MotifTransformer forward/backward smoke test on MOTIF_SMALL.

    Checks that the forward pass returns logits over the full vocabulary for
    every position, that it yields a finite loss, and that back-propagating
    that loss produces only finite gradients.
    """
    model = MotifTransformer(MOTIF_SMALL)
    x = torch.randint(0, MOTIF_SMALL.vocab_size, (2, 32))
    y = torch.randint(0, MOTIF_SMALL.vocab_size, (2, 32))
    out = model(x, y)
    assert out["logits"].shape == (2, 32, MOTIF_SMALL.vocab_size)
    loss = out["loss"]
    assert loss is not None
    # A NaN/inf loss back-propagates without raising, so check it explicitly.
    assert torch.isfinite(loss).all()
    loss.backward()
    # Likewise, NaN/inf gradients would not make backward() fail on their own.
    for param in model.parameters():
        if param.grad is not None:
            assert torch.isfinite(param.grad).all()
def test_motif_has_separate_subspaces() -> None:
    """The motif model's first block uses distinct, differently sized subspaces.

    Attention: queries/keys project to ``d_compare`` and values to
    ``d_memory``, with ``d_compare < d_memory``.
    FFN: the gate projects to ``d_gate`` and the expansion to ``d_expand``,
    with ``d_gate < d_expand``.
    """
    cfg = MOTIF_SMALL
    first_block = MotifTransformer(cfg).blocks[0]

    # Attention sub-layer: comparison (Q/K) vs. memory (V) widths.
    attention = first_block.attn
    assert attention.q_proj.out_features == cfg.d_compare
    assert attention.k_proj.out_features == cfg.d_compare
    assert attention.v_proj.out_features == cfg.d_memory
    assert cfg.d_compare < cfg.d_memory

    # Feed-forward sub-layer: gate vs. expansion widths.
    feed_forward = first_block.ffn
    assert feed_forward.gate.out_features == cfg.d_gate
    assert feed_forward.expand.out_features == cfg.d_expand
    assert cfg.d_gate < cfg.d_expand
def test_copy_task() -> None:
    """CopyTask has the requested size and yields 31-token samples for seq_len=32.

    The sample is one token shorter than ``seq_len`` -- presumably because
    inputs and targets are shifted by one position (TODO confirm in data.py).
    """
    dataset = CopyTask(vocab_size=64, seq_len=32, n_samples=10)
    assert len(dataset) == 10

    first = dataset[0]
    for field in ("input_ids", "targets"):
        assert first[field].shape == (31,)
def test_reverse_task() -> None:
    """ReverseTask reports exactly the number of samples it was asked for."""
    n_samples = 10
    dataset = ReverseTask(vocab_size=64, seq_len=32, n_samples=n_samples)
    assert len(dataset) == n_samples
def test_selective_retrieval_task() -> None:
    """SelectiveRetrieval reports exactly the number of samples it was asked for."""
    n_samples = 10
    dataset = SelectiveRetrieval(
        vocab_size=64, seq_len=32, n_samples=n_samples, n_pairs=4,
    )
    assert len(dataset) == n_samples
def test_chained_retrieval_task() -> None:
    """ChainedRetrieval yields non-empty data with seq_len - 1 token samples.

    Only non-emptiness is required of the dataset length (not exactly
    ``n_samples``); presumably some generated samples may be discarded --
    TODO confirm in data.py. The first sample must have at least one position
    enabled in its loss mask.
    """
    seq_len = 64
    dataset = ChainedRetrieval(
        vocab_size=128, seq_len=seq_len, n_samples=50, n_pairs=6,
    )
    assert len(dataset) > 0

    sample = dataset[0]
    expected_len = seq_len - 1  # 63
    assert sample["input_ids"].shape[0] == expected_len
    assert sample["targets"].shape[0] == expected_len
    assert sample["loss_mask"].sum() > 0
def test_prebatch_and_iterator() -> None:
    """prebatch_dataset stacks samples; TensorBatchIterator splits them into batches.

    20 samples with batch_size=8 must produce 3 batches (the partial final
    batch is kept), and the first batch must be full.
    """
    n_samples, batch_size = 20, 8
    dataset = CopyTask(vocab_size=64, seq_len=32, n_samples=n_samples)

    stacked = prebatch_dataset(dataset, 32)
    assert stacked["input_ids"].shape == (n_samples, 31)

    batches = [*TensorBatchIterator(stacked, batch_size=batch_size, shuffle=True)]
    assert len(batches) == 3
    assert batches[0]["input_ids"].shape[0] == batch_size
def test_micro_configs_forward() -> None:
    """Every micro config builds a model with a working forward/backward pass.

    For each (config, model class) pair, checks the following:
    - the logits cover the full vocabulary at every position, matching the
      shape contract asserted for the SMALL configs;
    - the loss is present and finite;
    - the loss back-propagates.

    Failures are labelled with the config name because the loop folds several
    cases into one test, and BASELINE_MICRO and UNIFORM_MICRO share a class.
    """
    cases = [
        ("BASELINE_MICRO", BASELINE_MICRO, BaselineTransformer),
        ("UNIFORM_MICRO", UNIFORM_MICRO, BaselineTransformer),
        ("MOTIF_MICRO", MOTIF_MICRO, MotifTransformer),
    ]
    for label, cfg, cls in cases:
        model = cls(cfg)
        x = torch.randint(0, cfg.vocab_size, (2, 32))
        y = torch.randint(0, cfg.vocab_size, (2, 32))
        out = model(x, y)
        assert out["logits"].shape == (2, 32, cfg.vocab_size), label
        loss = out["loss"]
        assert loss is not None, label
        # A NaN/inf loss would still back-propagate without raising.
        assert torch.isfinite(loss).all(), label
        loss.backward()