# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from magi_compiler.tokenflow.sampler import exponential_aligned_sampler


def test_sm_sampling():
    """Check type, length, alignment, endpoints and ordering for the 8..132 range.

    With ``align=8`` the test expects the first sample to be 8 and the last to
    be 128 (the upper bound 132 aligned down), and the median of the samples to
    sit below the midpoint of the aligned range.
    """
    result = exponential_aligned_sampler(min_val=8, max_val=132, num_samples=5, align=8)

    assert isinstance(result, list), "返回值必须是列表"
    assert len(result) == 5, f"返回长度应为5,实际为{len(result)}"
    assert all(isinstance(x, int) for x in result), "所有元素必须是整数"
    assert all(x % 8 == 0 for x in result), "所有元素必须对齐到8的倍数"

    assert result[0] == 8, f"首元素应为8,实际为{result[0]}"
    assert result[-1] == 128, f"尾元素应为128(132对齐后),实际为{result[-1]}"

    # Compare each sample with its successor to verify strict monotonicity.
    assert all(prev < nxt for prev, nxt in zip(result, result[1:])), "采样结果应严格递增"

    # The median falling below the arithmetic midpoint means the samples
    # cluster toward the low end of the range.
    mid_value = (8 + 128) / 2
    mid_value_found = np.median(result)
    assert mid_value_found < mid_value, "中间值应小于范围中点,符合指数分布特性"


def test_seqlen_sampling():
    """Check alignment and endpoints for the 1..65536 range with ``align=32``.

    The test expects the lower bound 1 to be raised to the first multiple of
    32 and the upper bound 65536 to be kept as-is.
    """
    result = exponential_aligned_sampler(min_val=1, max_val=65536, num_samples=5, align=32)

    assert len(result) == 5
    assert all(x % 32 == 0 for x in result)
    assert result[0] == 32
    assert result[-1] == 65536

    mid_value = (32 + 65536) / 2
    mid_value_found = np.median(result)
    assert mid_value_found < mid_value, "中间值应小于范围中点,符合指数分布特性"


def test_min_max_aligned_exactly():
    """Check that already-aligned bounds are returned unchanged as endpoints."""
    result = exponential_aligned_sampler(min_val=16, max_val=64, num_samples=4, align=16)

    assert len(result) == 4
    assert result[0] == 16
    assert result[-1] == 64
    # Every sample must be one of the multiples of 16 inside [16, 64].
    assert all(x in [16, 32, 48, 64] for x in result)


def test_output_len_2():
    """Check that requesting two samples yields exactly the aligned endpoints."""
    result = exponential_aligned_sampler(min_val=8, max_val=132, num_samples=2, align=8)

    assert len(result) == 2
    assert result == [8, 128]


def test_large_range_sampling():
    """Check alignment and endpoints over a wide range (1..1000001, ``align=64``).

    1000000 is itself a multiple of 64, so the test expects it as the last
    sample once the upper bound 1000001 has been aligned down.
    """
    result = exponential_aligned_sampler(min_val=1, max_val=1000001, num_samples=10, align=64)

    assert len(result) == 10
    assert all(x % 64 == 0 for x in result)
    assert result[0] == 64
    assert result[-1] == 1000000


def test_large_output_len():
    """Check that asking for more samples than the range can supply raises ValueError.

    Only eight multiples of 8 exist in [8, 64], while ten samples are
    requested; presumably that shortfall is what triggers the error (confirm
    against the sampler implementation).
    """
    with pytest.raises(ValueError) as excinfo:
        exponential_aligned_sampler(min_val=8, max_val=64, num_samples=10, align=8)

    # Asserted after the ``with`` block so the check is actually reached;
    # any statement following the raising call inside the block never runs.
    assert issubclass(excinfo.type, ValueError)