jiadisu
Switch back to Docker SDK with local pkgs
e6066e8
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from magi_compiler.tokenflow.sampler import exponential_aligned_sampler
def test_sm_sampling():
    """Sample 5 values in [8, 132] with align=8 and validate the result.

    132 is not a multiple of 8, so the expected last element is 128. The test
    checks type, length, alignment, endpoints, strict monotonicity, and that
    the samples are skewed toward the low end of the range.
    """
    result = exponential_aligned_sampler(min_val=8, max_val=132, num_samples=5, align=8)

    assert isinstance(result, list), "返回值必须是列表"
    assert len(result) == 5, f"返回长度应为5,实际为{len(result)}"
    assert all(isinstance(x, int) for x in result), "所有元素必须是整数"
    assert all(x % 8 == 0 for x in result), "所有元素必须对齐到8的倍数"
    assert result[0] == 8, f"首元素应为8,实际为{result[0]}"
    assert result[-1] == 128, f"尾元素应为128(132对齐后),实际为{result[-1]}"
    # Compare each element with its successor directly instead of by index.
    assert all(lo < hi for lo, hi in zip(result, result[1:])), "采样结果应严格递增"

    # Exponentially spaced samples cluster toward the low end, so their median
    # should fall below the arithmetic midpoint of the aligned range [8, 128].
    range_midpoint = (8 + 128) / 2
    sample_median = np.median(result)
    assert sample_median < range_midpoint, "中间值应小于范围中点,符合指数分布特性"
def test_seqlen_sampling():
    """Sample 5 values in [1, 65536] with align=32.

    min_val=1 is not a multiple of 32, so the first sample is expected to be 32.
    65536 is already a multiple of 32 and is expected as the last sample.
    """
    align = 32
    samples = exponential_aligned_sampler(min_val=1, max_val=65536, num_samples=5, align=align)

    assert len(samples) == 5
    # Every remainder must be zero (falsy) for the samples to be aligned.
    assert not any(value % align for value in samples)

    first, last = samples[0], samples[-1]
    assert first == 32
    assert last == 65536

    # Exponential spacing keeps the median well below the range midpoint.
    halfway = (32 + 65536) / 2
    assert np.median(samples) < halfway, "中间值应小于范围中点,符合指数分布特性"
def test_min_max_aligned_exactly():
    """Check already-aligned endpoints are preserved and samples stay on the grid.

    Both endpoints (16 and 64) are multiples of align=16. Every sample must be
    one of the aligned values 16, 32, 48, 64.
    """
    aligned_grid = {16, 32, 48, 64}
    picked = exponential_aligned_sampler(min_val=16, max_val=64, num_samples=4, align=16)

    assert len(picked) == 4
    lowest, highest = picked[0], picked[-1]
    assert lowest == 16
    assert highest == 64
    assert set(picked) <= aligned_grid
def test_output_len_2():
    """With two samples, only the aligned endpoints are expected: 8 and 128 (from 132)."""
    expected = [8, 128]
    got = exponential_aligned_sampler(min_val=8, max_val=132, num_samples=2, align=8)

    assert len(got) == 2
    assert got == expected
def test_large_range_sampling():
    """Sample 10 values over [1, 1000001] with align=64.

    The expected first sample is 64, the first multiple of 64 at or above 1.
    The expected last sample is 1000000 (= 15625 * 64), the last multiple of 64
    at or below 1000001.
    """
    step = 64
    out = exponential_aligned_sampler(min_val=1, max_val=1000001, num_samples=10, align=step)

    assert len(out) == 10
    assert not any(v % step for v in out)

    lowest, highest = out[0], out[-1]
    assert lowest == 64
    assert highest == 1000000
def test_large_output_len():
    """Requesting more samples than the aligned range can supply raises ValueError.

    With align=8, [8, 64] contains only 8 aligned values (8, 16, ..., 64),
    fewer than the 10 requested. This is presumably why the sampler rejects
    the call; confirm against the sampler implementation.
    """
    # pytest.raises itself fails the test if no ValueError is raised. The old
    # `assert excinfo is not None` could never fail, and the call's return
    # value was never used, so both are dropped.
    with pytest.raises(ValueError):
        exponential_aligned_sampler(min_val=8, max_val=64, num_samples=10, align=8)