| import pytest | |
| import torch | |
| from src.data.esm.utils.sampling import sample_logits | |
def test_sample_logits():
    """Check ``sample_logits`` output shapes and its empty-``valid_ids`` error.

    Batched ``(64, 8, 4096)`` and non-batched ``(8, 4096)`` random logits are
    each sampled with ``temperature=0.8`` and ``temperature=0.0``; in every
    case the result's shape must equal the input shape without its last axis.
    Passing an empty ``valid_ids`` list must raise ``ValueError``.
    """
    vocab_size = 4096
    # Leading dimensions of the logits: batched first, then non-batched.
    leading_shapes = ((64, 8), (8,))
    # Both the temperature != 0.0 and the temperature == 0.0 paths.
    temperatures = (0.8, 0.0)

    for lead in leading_shapes:
        for temp in temperatures:
            drawn = sample_logits(
                logits=torch.randn((*lead, vocab_size)),
                temperature=temp,
                valid_ids=list(range(vocab_size)),
            )
            assert drawn.shape == lead

    # With no valid ids there is nothing to sample from.
    with pytest.raises(ValueError):
        sample_logits(
            logits=torch.randn((8, vocab_size)), temperature=0.0, valid_ids=[]
        )
# Run the test directly only when this file is executed as a script.
# Without the guard, importing the module (e.g. during pytest collection)
# would execute the test as an import side effect.
if __name__ == "__main__":
    test_sample_logits()