# Tests for magi_compiler tokenflow graph executors (correctness, efficiency, edge cases).
| # Copyright (c) 2026 SandAI. All Rights Reserved. | |
| from typing import Dict, Tuple | |
| import torch | |
| from magi_compiler.tokenflow.graph_executor import ( | |
| GraphNormalExecutor, | |
| GraphOptimizer, | |
| GraphRawExecutor, | |
| GraphStageExecutor, | |
| LaneType, | |
| ) | |
| from magi_compiler.tokenflow.green_ctx import GreenCtxManager | |
| from magi_compiler.tokenflow.sampler import exponential_aligned_sampler | |
| from magi_compiler.tokenflow.utils import ModelConfig, TransformerModel, benchmark_func | |
| from torch import fx | |
# Target device for all tests: first CUDA GPU when available, otherwise CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Absolute tolerance used by torch.allclose when comparing executor output
# against the baseline model output.
ATOL = 1e-5
def _setup_executors(model) -> Dict[str, object]:
    """Trace *model* with torch.fx and build every executor variant under test.

    Fixes vs. original: the return annotation claimed a 3-tuple but the
    function has always returned only the executors dict; two dead
    ``GreenCtxManager`` constructions (never used) are removed.

    Args:
        model: A ``torch.nn.Module``; it is switched to eval mode in place.

    Returns:
        Dict mapping executor name to executor. Keys: ``"model"`` (original
        module), ``"fx"`` (traced GraphModule), ``"raw"``,
        ``"normal_{default,multi,green}"`` and ``"stage_{default,multi,green}"``.
    """
    model.eval()
    gm = fx.symbolic_trace(model)

    # Generate per-op stage configuration for the staged executors.
    optimizer = GraphOptimizer()
    stage_configs = optimizer.generate_stages_per_op(gm.graph)

    executors = {
        "model": model,  # original module (baseline)
        "fx": gm,  # traced FX GraphModule
        "raw": GraphRawExecutor(gm, DEVICE),  # raw executor (unoptimized)
        "normal_default": GraphNormalExecutor(gm, DEVICE),
        "normal_multi": GraphNormalExecutor(gm, DEVICE),
        "normal_green": GraphNormalExecutor(gm, DEVICE),
        "stage_default": GraphStageExecutor(gm, stage_configs, DEVICE),
        "stage_multi": GraphStageExecutor(gm, stage_configs, DEVICE),
        "stage_green": GraphStageExecutor(gm, stage_configs, DEVICE),
    }

    # Per-node stream assignment for the "normal" executors.
    for node_name in executors["normal_default"].stream_map:
        executors["normal_default"].stream_map[node_name] = torch.cuda.default_stream(DEVICE)
    for node_name in executors["normal_multi"].stream_map:
        executors["normal_multi"].stream_map[node_name] = torch.cuda.Stream(DEVICE)
    for node_name in executors["normal_green"].stream_map:
        # NOTE(review): a fresh manager per node mirrors the original setup;
        # each green-context stream requests the full SM count.
        gmgr = GreenCtxManager(DEVICE.index)
        executors["normal_green"].stream_map[node_name] = gmgr.create_stream(sm_count=gmgr.max_sm)

    # Per-lane stream assignment for the staged executors.
    for stage_cfg in stage_configs:
        stage_name = stage_cfg.name
        for lane_type in stage_cfg.lane_node_dict:
            executors["stage_default"].stage_lane_stream[stage_name][lane_type] = torch.cuda.default_stream(DEVICE)
        for lane_type in stage_cfg.lane_node_dict:
            executors["stage_multi"].stage_lane_stream[stage_name][lane_type] = torch.cuda.Stream(DEVICE)
        # One green-context manager per stage; only COMPUTE lanes are routed to
        # green streams, other lanes keep the executor's default assignment.
        gmgr = GreenCtxManager(DEVICE.index)
        for lane_type in stage_cfg.lane_node_dict:
            if lane_type == LaneType.COMPUTE:
                executors["stage_green"].stage_lane_stream[stage_name][lane_type] = gmgr.create_stream(sm_count=gmgr.max_sm)
    return executors
def test_executor_correctness_basic():
    """Verify every executor matches the original model across sampled sequence lengths.

    Fixes vs. original: nine copy-pasted ``run_*`` closures are collapsed into
    one ``_run`` helper plus a runner table, and ``raise e`` is replaced by a
    bare ``raise`` (idiomatic re-raise, avoids rebinding the exception).
    """
    model_config = ModelConfig(
        hidden_size=4096,
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    model = TransformerModel(model_config).to(DEVICE)
    model.eval()
    # Exponentially spaced sequence lengths, aligned to 7 to hit odd shapes.
    test_seq_lengths = exponential_aligned_sampler(min_val=16, max_val=2048, num_samples=10, align=7)
    executors = _setup_executors(model)

    def _run(forward, test_input):
        # One no-grad forward pass, fully synchronized before returning.
        with torch.no_grad():
            res = forward(test_input)
        torch.cuda.synchronize(DEVICE)
        return res

    # How each executor is invoked: "model"/"fx" are callables themselves,
    # every other executor exposes .execute().
    runners = {
        "fx": executors["fx"],
        "raw": executors["raw"].execute,
        "normal_default": executors["normal_default"].execute,
        "normal_multi": executors["normal_multi"].execute,
        "normal_green": executors["normal_green"].execute,
        "stage_default": executors["stage_default"].execute,
        "stage_multi": executors["stage_multi"].execute,
        "stage_green": executors["stage_green"].execute,
    }

    for seq_len in test_seq_lengths:
        # Fresh random input for each sequence length.
        test_input = torch.randn(seq_len, model_config.hidden_size).to(DEVICE)
        print(f"\n--- 正确性验证,序列长度={seq_len} ---")
        out_orig = _run(executors["model"], test_input)
        try:
            for name, forward in runners.items():
                torch.testing.assert_close(_run(forward, test_input), out_orig, rtol=1e-5, atol=1e-5)
            print(f"序列长度={seq_len} 正确性验证通过!")
        except Exception as e:
            print(f"序列长度={seq_len} 正确性验证失败!错误信息: {e}")
            raise
def test_executor_correctness_large_sequence():
    """Check executor correctness on a long (8192-token) sequence."""
    # Shrink hidden size relative to the basic test to avoid OOM at this length.
    config = ModelConfig(
        hidden_size=1024,
        num_layers=1,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=128,
        intermediate_size=4096,
        activation_type="gelu",
    )
    model = TransformerModel(config).to(DEVICE)
    model.eval()
    executors = _setup_executors(model)

    # Build the large test input.
    seq_len = 8192
    test_input = torch.randn(seq_len, config.hidden_size).to(DEVICE)

    # Baseline output from the original module.
    with torch.no_grad():
        baseline = model(test_input)

    # Validate the staged green-context executor (the expected best performer).
    with torch.no_grad():
        stage_green_result = executors["stage_green"].execute(test_input)
    executors["stage_green"].synchronize()
    assert torch.allclose(baseline, stage_green_result, atol=ATOL), f"大序列长度 {seq_len} 阶段化绿色Stream结果不匹配"
def test_executor_efficiency():
    """Benchmark all executors and print per-call latency plus speedup vs. the original model.

    Fixes vs. original: the redundant local re-definition of ``DEVICE``
    (shadowing the identical module-level constant) is removed, and the nine
    copy-pasted benchmark closures are replaced by a runner factory.
    """
    base_config = ModelConfig(
        hidden_size=4096,
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    seq_len = 1
    warmup_steps = 10
    run_steps = 10
    model = TransformerModel(base_config).to(DEVICE)
    executors = _setup_executors(model)
    test_input = torch.randn(seq_len, base_config.hidden_size).to(DEVICE)

    def _make_runner(forward):
        # Wrap a forward callable so each timed call runs under no_grad and
        # synchronizes the device (so the timing includes GPU completion).
        def runner():
            with torch.no_grad():
                forward(test_input)
            torch.cuda.synchronize(DEVICE)
        return runner

    # How each executor is invoked: "original"/"fx" are callables themselves,
    # every other executor exposes .execute().
    targets = {
        "original": executors["model"],
        "fx": executors["fx"],
        "raw": executors["raw"].execute,
        "normal_default": executors["normal_default"].execute,
        "normal_multi": executors["normal_multi"].execute,
        "normal_green": executors["normal_green"].execute,
        "stage_default": executors["stage_default"].execute,
        "stage_multi": executors["stage_multi"].execute,
        "stage_green": executors["stage_green"].execute,
    }

    # Run the benchmarks.
    times = {name: benchmark_func(_make_runner(forward), warmup_steps, run_steps) for name, forward in targets.items()}

    # Speedup relative to the un-traced original model.
    speedups = {k: times["original"] / v for k, v in times.items()}

    # Print results (pytest captures stdout).
    print(f"\n=== 执行器效率测试结果({seq_len=}) ===")
    for name, t in times.items():
        print(f"{name:15s}: {t:.6f} 秒/次 (加速比: {speedups[name]:.2f}x)")
def test_executor_edge_cases():
    """Exercise boundary conditions: a single-token input and a dependency-free model."""
    # Case 1: minimal sequence length (one token) through the full pipeline.
    tiny_config = ModelConfig(
        hidden_size=128,
        num_layers=1,
        num_heads_q=4,
        num_heads_kv=2,
        head_dim=32,
        intermediate_size=512,
        activation_type="gelu",
    )
    model = TransformerModel(tiny_config).to(DEVICE)
    model.eval()
    executors = _setup_executors(model)
    test_input = torch.randn(1, tiny_config.hidden_size).to(DEVICE)
    with torch.no_grad():
        baseline = model(test_input)
        result = executors["normal_default"].execute(test_input)
    executors["normal_default"].synchronize()
    assert torch.allclose(baseline, result, atol=ATOL), "极小序列长度执行失败"

    # Case 2: a plain Sequential MLP traced directly (no attention-style
    # cross-node dependencies beyond the linear chain).
    simple_model = torch.nn.Sequential(
        torch.nn.Linear(128, 256),
        torch.nn.GELU(),
        torch.nn.Linear(256, 128),
    ).to(DEVICE)
    simple_model.eval()
    gm_simple = fx.symbolic_trace(simple_model)
    executor = GraphNormalExecutor(gm_simple, DEVICE)
    test_input = torch.randn(32, 128).to(DEVICE)
    with torch.no_grad():
        baseline = simple_model(test_input)
        executor_result = executor.execute(test_input)
    executor.synchronize()
    assert torch.allclose(baseline, executor_result, atol=ATOL), "空依赖节点执行失败"