| | import pytest |
| | import torch |
| | from transformers import AutoConfig |
| | from mentioned.model import SentenceEncoder, Detector, MentionDetectorCore |
| |
|
| | |
| |
|
| |
|
@pytest.fixture
def model_dims():
    """Provide the tensor dimensions shared by the tests in this module."""
    return dict(input_dim=128, hidden_dim=64, seq_len=10, batch_size=2)
| |
|
| |
|
@pytest.fixture
def mock_embeddings(model_dims):
    """Build a random ``(batch_size, seq_len, input_dim)`` tensor.

    The tensor is created with ``requires_grad=True`` so that tests can run
    a backward pass and inspect ``.grad`` on it.
    """
    shape = tuple(model_dims[key] for key in ("batch_size", "seq_len", "input_dim"))
    return torch.randn(shape, requires_grad=True)
| |
|
| |
|
@pytest.mark.parametrize("subwords_per_word", [1, 2, 4])
def test_variable_subword_pooling(subwords_per_word):
    """Check the encoder returns one embedding per word and no NaNs.

    Eight subword positions are grouped into words made of
    ``subwords_per_word`` consecutive subwords each; the output is expected
    to have shape ``(batch, num_words, encoder.dim)``.
    """
    encoder = SentenceEncoder(model_name="sshleifer/tiny-distilroberta-base")
    hidden = encoder.dim
    batch = 1
    total_subwords = 8
    num_words = total_subwords // subwords_per_word

    token_shape = (batch, total_subwords)
    input_ids = torch.randint(0, 100, token_shape)
    attention_mask = torch.ones(token_shape)

    # Subword i is assigned to word i // subwords_per_word,
    # e.g. [0, 0, 1, 1, ...] when subwords_per_word == 2.
    word_index = torch.arange(num_words)
    word_ids = torch.repeat_interleave(word_index, subwords_per_word)[None, :]

    word_embs = encoder(input_ids, attention_mask, word_ids)

    assert word_embs.shape == (batch, num_words, hidden)
    assert not torch.isnan(word_embs).any()
| |
|
| |
|
def test_detector_projections(model_dims, mock_embeddings):
    """Verify Detector handles both 3D (starts) and 4D (spans) tensors.

    All expected shapes are derived from ``model_dims`` so the test stays
    consistent with the fixtures if their dimensions are ever changed.
    """
    batch_size = model_dims["batch_size"]
    seq_len = model_dims["seq_len"]
    input_dim = model_dims["input_dim"]
    detector = Detector(input_dim, model_dims["hidden_dim"])

    # 3D input: one vector per token -> one output value per token.
    start_out = detector(mock_embeddings)
    assert start_out.shape == (batch_size, seq_len, 1)

    # 4D input: one vector per (token, token) pair -> one value per pair.
    # Sized from the fixture rather than hard-coded literals.
    pair_input = torch.randn(batch_size, seq_len, seq_len, input_dim)
    pair_out = detector(pair_input)
    assert pair_out.shape == (batch_size, seq_len, seq_len, 1)
| |
|
| |
|
def test_mention_detector_core_logic(model_dims, mock_embeddings):
    """Verify the N x N pair materialization and concatenation."""
    batch_size = model_dims["batch_size"]
    seq_len = model_dims["seq_len"]
    input_dim = model_dims["input_dim"]

    # The end detector takes twice the input width of the start detector
    # (presumably because it scores concatenated token pairs - confirm
    # against MentionDetectorCore).
    core = MentionDetectorCore(
        Detector(input_dim, 32),
        Detector(2 * input_dim, 32),
    )
    start_logits, end_logits = core(mock_embeddings)

    assert start_logits.shape == (batch_size, seq_len)
    assert end_logits.shape == (batch_size, seq_len, seq_len)

    # A backward pass through the summed logits must populate the gradient
    # of the input embeddings.
    (start_logits.sum() + end_logits.sum()).backward()
    assert mock_embeddings.grad is not None
| |
|