# deepgenopix/tests/test_model.py
# Last commit: "length-bucketed training and HF bootstrap" (b0ad1f7, verified)
# Author: vedatonuryilmaz
from __future__ import annotations
import torch
from deepgenopix.model import DeepGenomicClassifier
def test_classifier_pads_to_stride_before_compression():
    """Check that ``_pad_to_stride`` rounds the last dimension up to the stride.

    With ``compressor_stride=4`` an input whose last dimension is 5 is
    expected to come back with a last dimension of 8 (the next multiple of
    4), while the two leading dimensions stay untouched.
    """
    stride = 4
    raw_length = 5
    classifier = DeepGenomicClassifier(num_classes=2, compressor_stride=stride)

    # NOTE(review): presumably laid out as (batch, channels, length) — confirm
    # against DeepGenomicClassifier._pad_to_stride.
    inputs = torch.randn(1, 3, raw_length)
    seq_lengths = torch.tensor([raw_length])

    result = classifier._pad_to_stride(inputs, seq_lengths)

    expected_shape = (1, 3, 8)
    assert result.shape == expected_shape
def test_classifier_uses_ceil_token_lengths_for_masks():
    """Check that ``_compute_mask`` rounds compressed lengths up, not down.

    A raw length of 5 with ``compressor_stride=4`` must yield a compressed
    length of 2 (ceil(5 / 4)); floor division would give 1. The returned
    mask is expected to have one row of two entries, both ``False``.
    """
    classifier = DeepGenomicClassifier(num_classes=2, compressor_stride=4)

    # NOTE(review): the (1, 64, 2) tensor presumably stands in for already
    # compressed features with 2 tokens — confirm against _compute_mask.
    features = torch.zeros(1, 64, 2)
    raw_lengths = torch.tensor([5])

    pad_mask, compressed_lengths = classifier._compute_mask(features, raw_lengths)

    # ceil(5 / 4) == 2 compressed tokens.
    assert compressed_lengths.tolist() == [2]

    # One mask row with one entry per compressed token.
    assert pad_mask.shape == (1, 2)

    # NOTE(review): False presumably marks a valid (non-padded) token — confirm.
    assert pad_mask.tolist() == [[False, False]]