"""Smoke tests for the optional CRF token-classification head.""" from __future__ import annotations import math import sys import tempfile from pathlib import Path import torch ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from anifilebert.config import Config from anifilebert.model import BertCrfForTokenClassification, LinearChainCRF, create_model, load_model def tiny_config() -> Config: cfg = Config() cfg.vocab_size = 32 cfg.max_position_embeddings = 16 cfg.hidden_size = 32 cfg.num_hidden_layers = 1 cfg.num_attention_heads = 4 cfg.intermediate_size = 64 cfg.max_seq_length = 8 return cfg def test_crf_forward_backward_and_decode() -> None: torch.manual_seed(7) cfg = tiny_config() model = create_model(cfg, model_head="crf") assert isinstance(model, BertCrfForTokenClassification) input_ids = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_length), dtype=torch.long) attention_mask = torch.tensor( [ [1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0, 0], ], dtype=torch.bool, ) labels = torch.tensor( [ [-100, 1, 2, 5, 6, 0, -100, -100], [-100, 7, 8, 0, 13, -100, -100, -100], ], dtype=torch.long, ) outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) assert outputs.loss is not None assert outputs.loss.ndim == 0 outputs.loss.backward() assert model.classifier.weight.grad is not None paths = model.decode(outputs.logits, attention_mask) assert [len(path) for path in paths] == [5, 4] def test_crf_can_bootstrap_from_linear_checkpoint() -> None: cfg = tiny_config() linear_model = create_model(cfg, model_head="linear") with tempfile.TemporaryDirectory() as tmp: path = Path(tmp) linear_model.save_pretrained(path) crf_model = load_model(str(path), model_head="crf") assert isinstance(crf_model, BertCrfForTokenClassification) assert crf_model.get_input_embeddings().weight.shape[0] == cfg.vocab_size assert crf_model.classifier.out_features == cfg.num_labels def test_crf_loss_allows_weak_bio_but_decode_constrains() -> None: crf = LinearChainCRF(3, {0: "O", 1: "B-TITLE", 2: "I-TITLE"}) emissions = torch.zeros((1, 1, 3), dtype=torch.float32) tags = torch.tensor([[2]], dtype=torch.long) mask = torch.tensor([[1]], dtype=torch.bool) loss = crf.neg_log_likelihood(emissions, tags, mask) assert torch.isfinite(loss) assert abs(float(loss.item()) - math.log(3.0)) < 1e-6 decoded = crf.decode(emissions, mask) assert decoded[0][0] != 2 def main() -> None: test_crf_forward_backward_and_decode() test_crf_can_bootstrap_from_linear_checkpoint() test_crf_loss_allows_weak_bio_but_decode_constrains() print("CRF model smoke tests passed.") if __name__ == "__main__": main()