Spaces:
Runtime error
Runtime error
| # Copyright (c) 2026 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import numpy as np | |
| import pytest | |
| from magi_compiler.tokenflow.sampler import exponential_aligned_sampler | |
def test_sm_sampling():
    """Check sampling of 5 values from [8, 132] with alignment 8.

    The test expects a list of 5 ints, each a multiple of 8, strictly
    increasing, starting at 8 and ending at 128 (132 rounded down to the
    alignment). Its median should lie below the midpoint of [8, 128].
    """
    result = exponential_aligned_sampler(min_val=8, max_val=132, num_samples=5, align=8)
    assert isinstance(result, list), "返回值必须是列表"
    assert len(result) == 5, f"返回长度应为5,实际为{len(result)}"
    assert all(isinstance(x, int) for x in result), "所有元素必须是整数"
    assert all(x % 8 == 0 for x in result), "所有元素必须对齐到8的倍数"
    assert result[0] == 8, f"首元素应为8,实际为{result[0]}"
    assert result[-1] == 128, f"尾元素应为128(132对齐后),实际为{result[-1]}"
    # Compare each element with its successor by pairing the list with a
    # shifted copy, instead of indexing via range(len(...)).
    assert all(a < b for a, b in zip(result, result[1:])), "采样结果应严格递增"
    # Exponentially spaced samples cluster toward the low end, so the median
    # is expected to fall below the arithmetic midpoint of [8, 128].
    mid_value = (8 + 128) / 2
    mid_value_found = np.median(result)
    assert mid_value_found < mid_value, "中间值应小于范围中点,符合指数分布特性"
def test_seqlen_sampling():
    """Check sampling of 5 values from [1, 65536] with alignment 32.

    The test expects min_val=1 to be rounded up to 32. It expects
    max_val=65536, already a multiple of 32, to be kept as the last sample.
    """
    align = 32
    samples = exponential_aligned_sampler(min_val=1, max_val=65536, num_samples=5, align=align)
    assert len(samples) == 5
    assert all(s % align == 0 for s in samples)
    assert samples[0] == 32
    assert samples[-1] == 65536
    # Exponentially spaced samples should have a median below the range midpoint.
    midpoint = (32 + 65536) / 2
    assert np.median(samples) < midpoint, "中间值应小于范围中点,符合指数分布特性"
def test_min_max_aligned_exactly():
    """Bounds that are already multiples of the alignment are kept as endpoints.

    With [16, 64] and align=16, every sample must be one of 16, 32, 48, 64.
    """
    allowed = (16, 32, 48, 64)
    result = exponential_aligned_sampler(min_val=16, max_val=64, num_samples=4, align=16)
    assert len(result) == 4
    assert result[0] == 16
    assert result[-1] == 64
    assert all(value in allowed for value in result)
def test_output_len_2():
    """With num_samples=2 the test expects exactly the aligned endpoints [8, 128]."""
    expected = [8, 128]
    result = exponential_aligned_sampler(min_val=8, max_val=132, num_samples=2, align=8)
    assert len(result) == len(expected)
    assert result == expected
def test_large_range_sampling():
    """Check sampling of 10 values from [1, 1000001] with alignment 64.

    The test expects min_val to be rounded up to 64. It expects max_val to be
    rounded down to 1000000, the largest multiple of 64 not exceeding 1000001.
    """
    n_samples, align = 10, 64
    out = exponential_aligned_sampler(min_val=1, max_val=1000001, num_samples=n_samples, align=align)
    assert len(out) == n_samples
    assert all(v % align == 0 for v in out)
    assert out[0] == 64
    assert out[-1] == 1000000
def test_large_output_len():
    """Requesting more samples than available aligned values should raise ValueError.

    Only 8 multiples of 8 lie in [8, 64], but 10 samples are requested. The
    test expects the sampler to reject this with ValueError.
    """
    # pytest.raises fails the test on its own if no ValueError is raised.
    # The previous `assert excinfo is not None` was always true, and the
    # `result` binding was never used, so both are dropped.
    with pytest.raises(ValueError):
        exponential_aligned_sampler(min_val=8, max_val=64, num_samples=10, align=8)