# AniFileBERT — tools/test_crf_model.py (2.99 kB)
# Last commit efb213a by ModerRAS:
#   "chore: checkpoint current training and manual relabel progress"
"""Smoke tests for the optional CRF token-classification head."""
from __future__ import annotations
import math
import sys
import tempfile
from pathlib import Path
import torch
# Repository root: two levels up from this file (this script lives in tools/).
# It is prepended to sys.path so the ``anifilebert`` package imports below
# resolve when the script is executed directly rather than as an installed package.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry
from anifilebert.config import Config
from anifilebert.model import BertCrfForTokenClassification, LinearChainCRF, create_model, load_model
def tiny_config() -> Config:
    """Build a deliberately small ``Config`` so the smoke tests run quickly.

    Starts from the project's default ``Config`` and overrides only the
    model-size fields (vocabulary, embeddings, layer/head counts, and sequence
    length). Every other field, such as the label count, keeps its default.
    """
    config = Config()
    small_overrides = {
        "vocab_size": 32,
        "max_position_embeddings": 16,
        "hidden_size": 32,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "intermediate_size": 64,
        "max_seq_length": 8,
    }
    # dicts keep insertion order, so fields are assigned in the same order as before.
    for field_name, field_value in small_overrides.items():
        setattr(config, field_name, field_value)
    return config
def test_crf_forward_backward_and_decode() -> None:
    """Check that the CRF head trains and decodes on a tiny random batch.

    Runs a labelled forward pass and verifies the loss is a finite scalar that
    back-propagates into the classifier. Then decodes the emissions with
    ``model.decode`` and checks the number and range of the predicted tags.
    """
    torch.manual_seed(7)
    cfg = tiny_config()
    model = create_model(cfg, model_head="crf")
    assert isinstance(model, BertCrfForTokenClassification)

    input_ids = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_length), dtype=torch.long)
    # Row 0 has 7 attended positions and row 1 has 6; the rest is padding.
    attention_mask = torch.tensor(
        [
            [1, 1, 1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1, 0, 0],
        ],
        dtype=torch.bool,
    )
    # -100 marks positions excluded from the loss: the first and last attended
    # positions (presumably special tokens — TODO confirm) plus padding.
    labels = torch.tensor(
        [
            [-100, 1, 2, 5, 6, 0, -100, -100],
            [-100, 7, 8, 0, 13, -100, -100, -100],
        ],
        dtype=torch.long,
    )

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    assert outputs.loss is not None
    assert outputs.loss.ndim == 0
    # A NaN/inf loss is still a 0-d tensor, so finiteness must be checked explicitly.
    assert torch.isfinite(outputs.loss)
    outputs.loss.backward()
    assert model.classifier.weight.grad is not None

    paths = model.decode(outputs.logits, attention_mask)
    # Expected lengths are two less than each row's attended count (7 -> 5, 6 -> 4).
    assert [len(path) for path in paths] == [5, 4]
    # Every decoded tag must be a valid index into the classifier's outputs.
    num_tags = model.classifier.out_features
    assert all(0 <= int(tag) < num_tags for path in paths for tag in path)
def test_crf_can_bootstrap_from_linear_checkpoint() -> None:
    """Check that a linear-head checkpoint can be loaded into the CRF head.

    Saves a freshly created linear-head model to a temporary directory and
    reloads it with ``model_head="crf"``. It then checks:

    * the result is a CRF model;
    * the vocabulary and label sizes match the config;
    * the input embeddings were actually taken from the checkpoint.
    """
    cfg = tiny_config()
    linear_model = create_model(cfg, model_head="linear")
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp)
        linear_model.save_pretrained(path)
        crf_model = load_model(str(path), model_head="crf")
        assert isinstance(crf_model, BertCrfForTokenClassification)
        assert crf_model.get_input_embeddings().weight.shape[0] == cfg.vocab_size
        assert crf_model.classifier.out_features == cfg.num_labels
        # Matching shapes would also hold for a freshly initialised model, so
        # compare the embedding weights to prove they came from the checkpoint.
        linear_embeddings = linear_model.get_input_embeddings().weight.detach().cpu()
        crf_embeddings = crf_model.get_input_embeddings().weight.detach().cpu()
        assert torch.equal(crf_embeddings, linear_embeddings)
def test_crf_loss_allows_weak_bio_but_decode_constrains() -> None:
    """Check that the CRF loss tolerates invalid BIO starts but decoding does not.

    With all-zero emissions over three tags, a one-token sequence labelled
    ``I-TITLE`` must score a finite loss of exactly ``log(3)``: every tag is
    equally likely, so the loss applies no BIO penalty. Decoding the same input
    must still never start a sequence with ``I-TITLE``.
    """
    id2label = {0: "O", 1: "B-TITLE", 2: "I-TITLE"}
    inside_title = 2
    crf = LinearChainCRF(len(id2label), id2label)

    emissions = torch.zeros((1, 1, len(id2label)), dtype=torch.float32)
    gold_tags = torch.full((1, 1), inside_title, dtype=torch.long)
    valid = torch.ones((1, 1), dtype=torch.bool)

    nll = crf.neg_log_likelihood(emissions, gold_tags, valid)
    assert torch.isfinite(nll)
    uniform_nll = math.log(3.0)
    assert abs(float(nll.item()) - uniform_nll) < 1e-6

    best_path = crf.decode(emissions, valid)
    assert best_path[0][0] != inside_title
def main() -> None:
    """Run each CRF smoke test in order, then report success.

    Any failing assertion propagates, so the success line is printed only
    when every test passes.
    """
    smoke_tests = (
        test_crf_forward_backward_and_decode,
        test_crf_can_bootstrap_from_linear_checkpoint,
        test_crf_loss_allows_weak_bio_but_decode_constrains,
    )
    for run_test in smoke_tests:
        run_test()
    print("CRF model smoke tests passed.")


if __name__ == "__main__":
    main()