# nas/PFMBench/src/data/esm/utils/sampling_test.py
# Repository page header (yuccaaa): "Add files using upload-large-folder tool",
# commit 9627ce0 (verified).
import pytest
import torch
from src.data.esm.utils.sampling import sample_logits
def test_sample_logits():
    """Check output shapes of ``sample_logits`` and its empty-``valid_ids`` error.

    For batched ``(64, 8, 4096)`` and non-batched ``(8, 4096)`` random logits,
    each with temperature 0.8 and 0.0, the result's shape must equal the
    logits shape with the last dimension removed. Passing an empty
    ``valid_ids`` list is expected to raise ``ValueError``.
    """
    vocab_size = 4096

    # (logits shape, temperature) pairs: batched then non-batched input,
    # each with a non-zero and a zero temperature.
    cases = [
        ((64, 8, vocab_size), 0.8),
        ((64, 8, vocab_size), 0.0),
        ((8, vocab_size), 0.8),
        ((8, vocab_size), 0.0),
    ]
    for logits_shape, temperature in cases:
        result = sample_logits(
            logits=torch.randn(logits_shape),
            temperature=temperature,
            valid_ids=list(range(vocab_size)),
        )
        # Sampling is expected to drop only the trailing dimension.
        assert result.shape == logits_shape[:-1]

    # An empty set of valid ids must be rejected.
    with pytest.raises(ValueError):
        sample_logits(
            logits=torch.randn((8, vocab_size)), temperature=0.0, valid_ids=[]
        )
# Run the test directly only when this file is executed as a script. The guard
# keeps a plain import (such as pytest collecting this module) from executing
# the test as an import-time side effect.
if __name__ == "__main__":
    test_sample_logits()