from __future__ import annotations import torch from deepgenopix.model import DeepGenomicClassifier def test_classifier_pads_to_stride_before_compression(): model = DeepGenomicClassifier(num_classes=2, compressor_stride=4) x = torch.randn(1, 3, 5) lengths = torch.tensor([5]) padded = model._pad_to_stride(x, lengths) assert padded.shape == (1, 3, 8) def test_classifier_uses_ceil_token_lengths_for_masks(): model = DeepGenomicClassifier(num_classes=2, compressor_stride=4) x = torch.zeros(1, 64, 2) lengths = torch.tensor([5]) mask, comp_lens = model._compute_mask(x, lengths) assert comp_lens.tolist() == [2] assert mask.shape == (1, 2) assert mask.tolist() == [[False, False]]