File size: 2,993 Bytes
efb213a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Smoke tests for the optional CRF token-classification head."""

from __future__ import annotations

import math
import sys
import tempfile
from pathlib import Path

import torch

# Put the directory two levels above this file (presumably the repository
# root) at the front of ``sys.path`` so the ``anifilebert`` imports below
# resolve when this module is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

from anifilebert.config import Config
from anifilebert.model import BertCrfForTokenClassification, LinearChainCRF, create_model, load_model


def tiny_config() -> Config:
    """Return a default ``Config`` shrunk to a one-layer, 32-wide model.

    The sizes keep the smoke tests fast: a 32-token vocabulary, 8-token
    sequences, and 4 attention heads over a 32-dimensional hidden state.
    """
    overrides = {
        "vocab_size": 32,
        "max_position_embeddings": 16,
        "hidden_size": 32,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "intermediate_size": 64,
        "max_seq_length": 8,
    }
    cfg = Config()
    for field_name, value in overrides.items():
        setattr(cfg, field_name, value)
    return cfg


def test_crf_forward_backward_and_decode() -> None:
    """Run one train step and a decode through the CRF head.

    Builds a tiny CRF model, runs a forward pass with labels, back-propagates
    the loss and decodes the emissions. Checks three things:

    - the loss is a finite scalar;
    - gradients reach the classifier and are finite;
    - the decoded paths have the expected per-sequence lengths.
    """
    torch.manual_seed(7)
    cfg = tiny_config()
    model = create_model(cfg, model_head="crf")
    assert isinstance(model, BertCrfForTokenClassification)

    input_ids = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_length), dtype=torch.long)
    # Row 0 attends to 7 positions, row 1 to 6; the trailing zeros are padding.
    attention_mask = torch.tensor(
        [
            [1, 1, 1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1, 0, 0],
        ],
        dtype=torch.bool,
    )
    # -100 presumably marks positions the loss ignores (special tokens and
    # padding) -- confirm against BertCrfForTokenClassification.
    labels = torch.tensor(
        [
            [-100, 1, 2, 5, 6, 0, -100, -100],
            [-100, 7, 8, 0, 13, -100, -100, -100],
        ],
        dtype=torch.long,
    )

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    assert outputs.loss is not None
    assert outputs.loss.ndim == 0
    # A NaN/inf loss still back-propagates and leaves a non-None (but
    # non-finite) gradient, so finiteness must be asserted explicitly.
    assert torch.isfinite(outputs.loss)
    outputs.loss.backward()
    classifier_grad = model.classifier.weight.grad
    assert classifier_grad is not None
    assert torch.isfinite(classifier_grad).all()

    paths = model.decode(outputs.logits, attention_mask)
    # The expected lengths equal the number of labelled (non -100) positions
    # in each row of `labels`.
    assert [len(path) for path in paths] == [5, 4]


def test_crf_can_bootstrap_from_linear_checkpoint() -> None:
    """Load a checkpoint saved by the linear head as a CRF model.

    The linear model is saved to a temporary directory and reloaded with
    ``model_head="crf"``. The test then checks the resulting type, the
    embedding vocabulary size and the classifier's output width.
    """
    cfg = tiny_config()
    source_model = create_model(cfg, model_head="linear")

    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_dir = Path(tmp_dir)
        source_model.save_pretrained(checkpoint_dir)
        restored = load_model(str(checkpoint_dir), model_head="crf")

    assert isinstance(restored, BertCrfForTokenClassification)
    embedding_rows = restored.get_input_embeddings().weight.shape[0]
    assert embedding_rows == cfg.vocab_size
    assert restored.classifier.out_features == cfg.num_labels


def test_crf_loss_allows_weak_bio_but_decode_constrains() -> None:
    """Check that a leading I- tag is scored freely but never decoded.

    With uniform emissions over three tags, the negative log-likelihood of a
    lone ``I-TITLE`` token must equal ``log(3)``, i.e. the loss treats all
    three tags as equally likely. Viterbi decoding of the same input must
    still avoid emitting ``I-TITLE``.
    """
    id2label = {0: "O", 1: "B-TITLE", 2: "I-TITLE"}
    inside_tag = 2
    crf = LinearChainCRF(3, id2label)

    # One sequence of one token with all-zero (uniform) emission scores.
    emissions = torch.zeros((1, 1, 3), dtype=torch.float32)
    gold_tags = torch.tensor([[inside_tag]], dtype=torch.long)
    token_mask = torch.tensor([[1]], dtype=torch.bool)

    nll = crf.neg_log_likelihood(emissions, gold_tags, token_mask)
    assert torch.isfinite(nll)
    expected_nll = math.log(3.0)
    assert abs(float(nll.item()) - expected_nll) < 1e-6

    best_paths = crf.decode(emissions, token_mask)
    assert best_paths[0][0] != inside_tag


def main() -> None:
    """Run every CRF smoke test in order, then print a success line.

    Any failing assertion propagates out of this function, so the success
    message is printed only if all tests pass.
    """
    smoke_tests = (
        test_crf_forward_backward_and_decode,
        test_crf_can_bootstrap_from_linear_checkpoint,
        test_crf_loss_allows_weak_bio_but_decode_constrains,
    )
    for run_test in smoke_tests:
        run_test()
    print("CRF model smoke tests passed.")


# Allow running the smoke tests directly as a script, without pytest.
if __name__ == "__main__":
    main()