"""
Tests for domain_tokenizer Phase 2B: Model Architecture.
33 tests covering config, model, PLR, DCNv2, joint fusion, and end-to-end integration.
Run: pytest tests/test_model.py -v
"""
import math
from datetime import datetime
import torch
import pytest
from domain_tokenizer.models.configuration import DomainTransformerConfig
from domain_tokenizer.models.modeling import (
DomainTransformerForCausalLM, DomainTransformerModel, DomainTransformerAttention, DomainTransformerBlock,
)
from domain_tokenizer.models.plr_embeddings import PeriodicLinearReLU
from domain_tokenizer.models.joint_fusion import DCNv2CrossLayer, DCNv2, JointFusionModel
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
def tiny_config(vocab_size=128):
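    """Sub-1M-parameter config for fast CPU tests; dropout disabled for determinism."""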
return DomainTransformerConfig(
vocab_size=vocab_size, hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
intermediate_size=128, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0,
max_position_embeddings=64,
)
class TestDomainTransformerConfig:
def test_default(self):
c = DomainTransformerConfig()
assert c.vocab_size == 32000 and c.hidden_size == 512 and c.model_type == "domain_transformer"
def test_preset_24m(self):
c = DomainTransformerConfig.from_preset("24m")
assert c.hidden_size == 512 and c.num_hidden_layers == 6
def test_preset_85m(self):
assert DomainTransformerConfig.from_preset("85m").hidden_size == 768
def test_preset_330m(self):
c = DomainTransformerConfig.from_preset("330m")
assert c.hidden_size == 1024 and c.num_hidden_layers == 24
def test_preset_override(self):
c = DomainTransformerConfig.from_preset("24m", vocab_size=500)
assert c.vocab_size == 500 and c.hidden_size == 512
def test_invalid_preset(self):
with pytest.raises(ValueError):
DomainTransformerConfig.from_preset("999m")
def test_serialization(self):
c = DomainTransformerConfig(vocab_size=1000, hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
c2 = DomainTransformerConfig(**c.to_dict())
assert c2.vocab_size == 1000
def test_head_dim(self):
with pytest.raises(AssertionError):
DomainTransformerConfig(hidden_size=100, num_attention_heads=7)
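    # intermediate_size presumably defaults to 4 * hidden_size (the usual transformer FFN ratio).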
def test_intermediate_default(self):
assert DomainTransformerConfig(hidden_size=256).intermediate_size == 1024
class TestDomainTransformerModel:
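    # The base model should accept either input_ids or precomputed inputs_embeds.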
def test_forward(self):
m = DomainTransformerModel(tiny_config())
assert m(input_ids=torch.randint(0, 128, (2, 16))).last_hidden_state.shape == (2, 16, 64)
def test_embeds(self):
m = DomainTransformerModel(tiny_config())
assert m(inputs_embeds=torch.randn(2, 16, 64)).last_hidden_state.shape == (2, 16, 64)
class TestDomainTransformerForCausalLM:
def test_no_labels(self):
m = DomainTransformerForCausalLM(tiny_config())
m.eval()
with torch.no_grad():
o = m(input_ids=torch.randint(0, 128, (2, 16)))
assert o.logits.shape == (2, 16, 128) and o.loss is None
def test_with_labels(self):
m = DomainTransformerForCausalLM(tiny_config())
ids = torch.randint(0, 128, (2, 16))
o = m(input_ids=ids, labels=ids)
assert o.loss is not None and o.loss.item() > 0
def test_backward(self):
m = DomainTransformerForCausalLM(tiny_config())
ids = torch.randint(0, 128, (2, 16))
m(input_ids=ids, labels=ids).loss.backward()
assert any(p.grad is not None for p in m.parameters() if p.requires_grad)
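    # The LM head must share its weight tensor with the input embedding (weight tying).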
def test_weight_tying(self):
m = DomainTransformerForCausalLM(tiny_config())
assert m.lm_head.weight is m.model.embed_tokens.weight
def test_user_embedding(self):
m = DomainTransformerForCausalLM(tiny_config())
m.eval()
with torch.no_grad():
assert m.get_user_embedding(torch.randint(0, 128, (3, 16))).shape == (3, 64)
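    # The user embedding is presumably pooled over non-padded positions only; this test checks only the output shape.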
def test_user_embedding_mask(self):
m = DomainTransformerForCausalLM(tiny_config())
m.eval()
mask = torch.ones(2, 16, dtype=torch.long)
mask[0, 10:] = 0
with torch.no_grad():
assert m.get_user_embedding(torch.randint(0, 128, (2, 16)), attention_mask=mask).shape == (2, 64)
def test_params_tiny(self):
n = sum(p.numel() for p in DomainTransformerForCausalLM(tiny_config()).parameters())
assert n < 1_000_000
def test_params_24m(self):
n = sum(p.numel() for p in DomainTransformerForCausalLM(DomainTransformerConfig.from_preset("24m")).parameters())
assert 15_000_000 < n < 40_000_000
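    # Backward must still succeed with activation (gradient) checkpointing enabled.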
def test_grad_checkpoint(self):
m = DomainTransformerForCausalLM(tiny_config())
m.gradient_checkpointing_enable()
m(input_ids=torch.randint(0, 128, (2, 16)), labels=torch.randint(0, 128, (2, 16))).loss.backward()
class TestAttention:
def test_shape(self):
assert DomainTransformerAttention(tiny_config())(torch.randn(2, 16, 64)).shape == (2, 16, 64)
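    # Causality check: only positions >= 4 carry large inputs, so a late
    # position's output norm should dwarf that of position 0, which can attend
    # only to zeros.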
def test_causal(self):
c = tiny_config()
c.attention_probs_dropout_prob = 0.0
a = DomainTransformerAttention(c)
a.eval()
x = torch.zeros(1, 8, 64)
x[0, 4:, :] = 100.0
with torch.no_grad():
o = a(x)
assert o[0, 7].norm() > o[0, 0].norm() * 2
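
# PeriodicLinearReLU presumably follows the periodic numeric-feature embeddings
# of Gorishniy et al. (2022): each scalar x is lifted to
# [sin(2*pi*f*x), cos(2*pi*f*x)] over learned frequencies f, then passed through
# Linear + ReLU. A rough single-feature sketch of that assumed formulation (not
# the library's implementation):
def _reference_plr(x, frequencies, linear):
    # x: (batch,), frequencies: (n_freq,), linear: nn.Linear(2 * n_freq, out_dim)
    angles = 2 * math.pi * x.unsqueeze(-1) * frequencies
    periodic = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    return torch.relu(linear(periodic))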
class TestPLR:
def test_shape(self):
assert PeriodicLinearReLU(10, 32, 64)(torch.randn(4, 10)).shape == (4, 10, 64)
def test_different(self):
p = PeriodicLinearReLU(5, 16, 32)
assert not torch.allclose(p(torch.ones(1, 5)), p(torch.ones(1, 5) * 10))
def test_grad(self):
p = PeriodicLinearReLU(5, 16, 32)
x = torch.randn(2, 5, requires_grad=True)
p(x).sum().backward()
assert x.grad is not None and p.frequencies.grad is not None
def test_single(self):
assert PeriodicLinearReLU(1, 8, 16)(torch.tensor([[3.14]])).shape == (1, 1, 16)
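
# DCNv2CrossLayer is assumed to implement the standard cross recurrence
# x_{l+1} = x_0 * (W @ x_l + b) + x_l: an elementwise product with the original
# input plus a residual connection. A plain-PyTorch sketch of that assumed math:
def _reference_cross_layer(x0, xl, weight, bias):
    # x0, xl: (batch, dim); weight: (dim, dim); bias: (dim,)
    return x0 * (xl @ weight.T + bias) + xl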
class TestDCNv2:
def test_cross(self):
assert DCNv2CrossLayer(64)(torch.randn(4, 64), torch.randn(4, 64)).shape == (4, 64)
def test_dcn(self):
d = DCNv2(128, 3, 2, 64)
assert d(torch.randn(4, 128)).shape == (4, 64) and d.output_dim == 64
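
# JointFusionModel presumably fuses the transformer's pooled user embedding with
# PLR-embedded tabular features through a DCNv2 stack and a prediction head.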
class TestJointFusion:
@pytest.fixture
def model(self):
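        # Positional args: 10 tabular features and 1 output class (both matched
        # by the tests below); the remaining sizes are presumed PLR/DCNv2 dims.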
return JointFusionModel(
DomainTransformerForCausalLM(tiny_config(128)), 10, 1, 8, 16, 2, 2, 32, 32,
)
def test_forward(self, model):
o = model(torch.randint(0, 128, (2, 16)), torch.ones(2, 16, dtype=torch.long), torch.randn(2, 10))
assert o["logits"].shape == (2, 1) and o["loss"] is None
def test_loss(self, model):
o = model(torch.randint(0, 128, (2, 16)), torch.ones(2, 16, dtype=torch.long), torch.randn(2, 10), torch.tensor([1.0, 0.0]))
assert o["loss"] is not None and o["loss"].dim() == 0
def test_backward(self, model):
o = model(torch.randint(0, 128, (2, 16)), torch.ones(2, 16, dtype=torch.long), torch.randn(2, 10), torch.tensor([1.0, 0.0]))
o["loss"].backward()
assert model.transformer.model.embed_tokens.weight.grad is not None
assert model.plr.frequencies.grad is not None
def test_multiclass(self):
m = JointFusionModel(DomainTransformerForCausalLM(tiny_config(128)), 5, 3, 4, 8, 2, 2, 16, 16)
o = m(torch.randint(0, 128, (2, 8)), tabular_features=torch.randn(2, 5), labels=torch.tensor([0, 2]))
assert o["logits"].shape == (2, 3) and o["loss"] is not None
class TestIntegration:
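    # End-to-end: fit the domain tokenizer on finance events, build a hybrid
    # BPE vocab, encode a sequence, and run one forward/backward pass.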
def test_finance(self):
events = [
{"amount_sign": 79.99, "amount": 79.99, "timestamp": datetime(2025, 3, 15, 14, 30), "description": "AMAZON"},
{"amount_sign": -200.0, "amount": -200.0, "timestamp": datetime(2025, 3, 16, 9, 0), "description": "SALARY"},
]
builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
builder.fit(events)
hf_tok = builder.build(text_corpus=["AMAZON", "SALARY", "UBER", "GROCERY"] * 20, bpe_vocab_size=300)
enc = builder.encode_sequence(events, hf_tok, max_length=64)
ids = torch.tensor([enc["input_ids"]])
mask = torch.tensor([enc["attention_mask"]])
model = DomainTransformerForCausalLM(tiny_config(hf_tok.vocab_size))
out = model(input_ids=ids, attention_mask=mask, labels=ids)
assert out.loss.item() > 0
out.loss.backward()
assert sum(p.grad.norm().item() for p in model.parameters() if p.grad is not None) > 0