| import pytest |
| import torch |
| from unittest.mock import MagicMock |
| from mentioned.model import LitMentionDetector, Detector, MentionDetectorCore |
|
|
| |
|
|
|
|
@pytest.fixture
def mock_components():
    """Create minimal components for the LightningModule."""
    # Fake encoder: mock only the surface the module touches
    # (max_length/dim attributes plus parameters() for freezing checks).
    fake_encoder = MagicMock(spec=torch.nn.Module)
    fake_encoder.max_length = 512
    fake_encoder.dim = 128
    fake_encoder.parameters.return_value = [
        torch.nn.Parameter(torch.randn(128, 128))
    ]

    # Real (tiny) detection head so parameter/grad behavior is genuine.
    detector_core = MentionDetectorCore(Detector(128, 64), Detector(256, 64))

    # Calling the tokenizer yields a fixed 5-token encoding;
    # word_ids() maps sub-tokens back to words (None = special token).
    fake_tokenizer = MagicMock()
    fake_tokenizer.return_value = {
        "input_ids": torch.zeros((1, 5)),
        "attention_mask": torch.ones((1, 5)),
    }
    fake_tokenizer.word_ids.return_value = [0, 0, 1, 1, None]

    return fake_tokenizer, fake_encoder, detector_core
|
|
|
|
@pytest.fixture
def dummy_batch():
    """A realistic batch as produced by your collate_fn."""
    batch_size, seq_len = 2, 4
    long_kw = {"dtype": torch.long}
    bool_kw = {"dtype": torch.bool}
    # Two sentences, padded to seq_len tokens; all targets zero,
    # all mask positions active.
    return {
        "sentences": [["Hello", "world"], ["Test"]],
        "starts": torch.zeros((batch_size, seq_len), **long_kw),
        "spans": torch.zeros((batch_size, seq_len, seq_len), **long_kw),
        "token_mask": torch.ones((batch_size, seq_len), **bool_kw),
        "span_loss_mask": torch.ones((batch_size, seq_len, seq_len), **bool_kw),
        "task_id": [0, 0],
    }
|
|
|
|
| |
|
|
|
|
def test_encoder_freezing(mock_components):
    """Verify the encoder parameters are set to requires_grad=False."""
    tokenizer, encoder, mention_det = mock_components
    model = LitMentionDetector(tokenizer, encoder, mention_det)

    # Frozen backbone: every encoder parameter must be excluded from training.
    assert all(p.requires_grad is False for p in model.encoder.parameters())

    # The detection head, by contrast, stays trainable.
    assert all(p.requires_grad for p in model.mention_detector.parameters())
|
|
|
|
def test_compute_loss_empty_mask(mock_components, dummy_batch):
    """An all-False span mask must produce a zero loss that still carries grad."""
    tokenizer, encoder, mention_det = mock_components
    model = LitMentionDetector(tokenizer, encoder, mention_det)

    # Mask out every span so nothing contributes to the loss.
    empty_mask = torch.zeros_like(dummy_batch["span_loss_mask"]).bool()
    dummy_batch["span_loss_mask"] = empty_mask

    end_logits = torch.randn(2, 4, 4, requires_grad=True)
    loss = model._compute_end_loss(end_logits, dummy_batch)

    # Zero value, but still attached to the graph so backward() is safe.
    assert loss.item() == 0.0
    assert loss.requires_grad
|
|
|
|
def test_predict_mentions_format(mock_components):
    """Verify predict_mentions returns the expected list of (start, end) tuples."""
    tokenizer, encoder, mention_det = mock_components
    model = LitMentionDetector(tokenizer, encoder, mention_det)

    # Bypass the real encoder entirely.
    model.encode = MagicMock(return_value=torch.randn(1, 3, 128))

    # Logits engineered so only token 0 is a start, with its end at token 1.
    start_logits = torch.tensor([[10.0, -10.0, -10.0]])
    end_logits = torch.tensor(
        [[[-10.0, 10.0, -10.0], [-10.0, -10.0, -10.0], [-10.0, -10.0, -10.0]]]
    )
    model.mention_detector.forward = MagicMock(
        return_value=(start_logits, end_logits)
    )

    results = model.predict_mentions([["Mock", "sentence", "."]])

    # One sentence in -> one list of spans out, containing exactly (0, 1).
    assert len(results) == 1
    assert results[0] == [(0, 1)]
|
|
|
|
def test_training_step_integration(mock_components, dummy_batch):
    """Verify training_step returns a valid loss tensor."""
    tokenizer, encoder, mention_det = mock_components
    model = LitMentionDetector(tokenizer, encoder, mention_det)

    # Stub the encoder output with embeddings matching the batch shape.
    model.encode = MagicMock(return_value=torch.randn(2, 4, 128))

    loss = model.training_step(dummy_batch, 0)

    # Lightning expects a finite scalar tensor back from training_step.
    assert isinstance(loss, torch.Tensor)
    assert loss.ndim == 0
    assert not torch.isnan(loss)
|
|
|
|
def test_validation_metrics_update(mock_components, dummy_batch):
    """Ensure F1 metrics are updated during validation."""
    tokenizer, encoder, mention_det = mock_components
    model = LitMentionDetector(tokenizer, encoder, mention_det)
    model.encode = MagicMock(return_value=torch.randn(2, 4, 128))

    model.validation_step(dummy_batch, 0)

    # NOTE(review): relies on torchmetrics' Metric.update_count property —
    # confirm the pinned torchmetrics version exposes it.
    for metric in (model.val_f1_start, model.val_f1_mention):
        assert metric.update_count > 0
|
|