| from __future__ import annotations | |
| import torch | |
| from deepgenopix.model import DeepGenomicClassifier | |
def test_classifier_pads_to_stride_before_compression():
    """Check that ``_pad_to_stride`` grows the last axis to a stride multiple.

    With ``compressor_stride=4``, a (1, 3, 5) input is expected to come back
    as (1, 3, 8); 8 is the smallest multiple of 4 that is >= 5.
    """
    classifier = DeepGenomicClassifier(num_classes=2, compressor_stride=4)
    inputs = torch.randn(1, 3, 5)
    seq_lengths = torch.tensor([5])

    result = classifier._pad_to_stride(inputs, seq_lengths)

    # Only the last axis changes; batch and channel axes stay as given.
    assert result.shape == (1, 3, 8)
def test_classifier_uses_ceil_token_lengths_for_masks():
    """Check that ``_compute_mask`` derives compressed lengths by ceiling division.

    With ``compressor_stride=4``, a raw length of 5 should map to
    ceil(5 / 4) == 2 compressed tokens (flooring would yield 1). The returned
    mask is expected to be (1, 2) and entirely ``False``.
    """
    classifier = DeepGenomicClassifier(num_classes=2, compressor_stride=4)
    # NOTE(review): the last axis (size 2) presumably is the compressed token
    # axis; the tensor's contents are irrelevant here (all zeros) — confirm.
    features = torch.zeros(1, 64, 2)
    raw_lengths = torch.tensor([5])

    mask, compressed_lengths = classifier._compute_mask(features, raw_lengths)

    assert compressed_lengths.tolist() == [2]
    assert tuple(mask.shape) == (1, 2)
    assert mask.tolist() == [[False] * 2]