# Copyright (c) 2026 SandAI. All Rights Reserved.

from typing import Dict, Tuple

import torch
from torch import fx

from magi_compiler.tokenflow.graph_executor import (
    GraphNormalExecutor,
    GraphOptimizer,
    GraphRawExecutor,
    GraphStageExecutor,
    LaneType,
)
from magi_compiler.tokenflow.green_ctx import GreenCtxManager
from magi_compiler.tokenflow.sampler import exponential_aligned_sampler
from magi_compiler.tokenflow.utils import ModelConfig, TransformerModel, benchmark_func

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ATOL = 1e-5

# Executors that are plain callables (not .execute()-style executor objects).
_CALLABLE_KEYS = ("model", "fx")


def _setup_executors(model) -> Dict[str, object]:
    """Trace ``model`` with torch.fx and build every executor variant under test.

    The returned dict maps a short name to each variant:
      * ``"model"`` / ``"fx"`` — eager model and traced GraphModule (baselines);
      * ``"raw"`` — GraphRawExecutor, no optimization;
      * ``"normal_default" | "normal_multi" | "normal_green"`` — GraphNormalExecutor
        with per-node streams: CUDA default stream, one fresh stream per node, or
        green-context streams respectively;
      * ``"stage_default" | "stage_multi" | "stage_green"`` — GraphStageExecutor
        with per-(stage, lane) streams following the same three policies (the
        green variant only overrides COMPUTE lanes).

    Returns:
        Dict[str, object]: executor variants keyed by name.
    """
    model.eval()
    gm = fx.symbolic_trace(model)

    # Derive one stage per op from the traced graph.
    optimizer = GraphOptimizer()
    stage_configs = optimizer.generate_stages_per_op(gm.graph)

    executors: Dict[str, object] = {
        "model": model,  # eager baseline
        "fx": gm,  # traced GraphModule baseline
        "raw": GraphRawExecutor(gm, DEVICE),  # unoptimized executor
        "normal_default": GraphNormalExecutor(gm, DEVICE),
        "normal_multi": GraphNormalExecutor(gm, DEVICE),
        "normal_green": GraphNormalExecutor(gm, DEVICE),
        "stage_default": GraphStageExecutor(gm, stage_configs, DEVICE),
        "stage_multi": GraphStageExecutor(gm, stage_configs, DEVICE),
        "stage_green": GraphStageExecutor(gm, stage_configs, DEVICE),
    }

    # Per-node stream assignment for the normal executors.
    for node_name in executors["normal_default"].stream_map:
        executors["normal_default"].stream_map[node_name] = torch.cuda.default_stream(DEVICE)
    for node_name in executors["normal_multi"].stream_map:
        executors["normal_multi"].stream_map[node_name] = torch.cuda.Stream(DEVICE)
    for node_name in executors["normal_green"].stream_map:
        # NOTE(review): the original code creates a fresh GreenCtxManager per
        # node; confirm whether a single shared manager is intended instead.
        gmgr = GreenCtxManager(DEVICE.index)
        executors["normal_green"].stream_map[node_name] = gmgr.create_stream(sm_count=gmgr.max_sm)

    # Per-(stage, lane) stream assignment for the stage executors.
    for stage_cfg in stage_configs:
        stage_name = stage_cfg.name
        for lane_type in stage_cfg.lane_node_dict:
            executors["stage_default"].stage_lane_stream[stage_name][lane_type] = torch.cuda.default_stream(DEVICE)
        for lane_type in stage_cfg.lane_node_dict:
            executors["stage_multi"].stage_lane_stream[stage_name][lane_type] = torch.cuda.Stream(DEVICE)
        # One green-context manager per stage; only COMPUTE lanes are overridden,
        # other lanes keep whatever stream the executor assigned by default.
        gmgr = GreenCtxManager(DEVICE.index)
        for lane_type in stage_cfg.lane_node_dict:
            if lane_type == LaneType.COMPUTE:
                executors["stage_green"].stage_lane_stream[stage_name][lane_type] = gmgr.create_stream(
                    sm_count=gmgr.max_sm
                )
    return executors


def _run_executor(executors: Dict[str, object], key: str, inp: torch.Tensor) -> torch.Tensor:
    """Run one executor variant on ``inp`` under no_grad and synchronize the device."""
    target = executors[key]
    with torch.no_grad():
        out = target(inp) if key in _CALLABLE_KEYS else target.execute(inp)
    torch.cuda.synchronize(DEVICE)
    return out


def test_executor_correctness_basic():
    """Verify all executor variants match the eager model across sampled sequence lengths."""
    model_config = ModelConfig(
        hidden_size=4096,
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    model = TransformerModel(model_config).to(DEVICE)
    model.eval()

    test_seq_lengths = exponential_aligned_sampler(min_val=16, max_val=2048, num_samples=10, align=7)
    executors = _setup_executors(model)

    compare_keys = (
        "fx",
        "raw",
        "normal_default",
        "normal_multi",
        "normal_green",
        "stage_default",
        "stage_multi",
        "stage_green",
    )

    for seq_len in test_seq_lengths:
        test_input = torch.randn(seq_len, model_config.hidden_size).to(DEVICE)
        print(f"\n--- 正确性验证,序列长度={seq_len} ---")

        out_orig = _run_executor(executors, "model", test_input)
        outputs = {key: _run_executor(executors, key, test_input) for key in compare_keys}

        try:
            for key in compare_keys:
                torch.testing.assert_close(outputs[key], out_orig, rtol=1e-5, atol=1e-5)
            print(f"序列长度={seq_len} 正确性验证通过!")
        except Exception as e:
            print(f"序列长度={seq_len} 正确性验证失败!错误信息: {e}")
            raise e


def test_executor_correctness_large_sequence():
    """Verify the staged green-context executor at a large sequence length (8192)."""
    # Smaller hidden size to avoid OOM at long sequence lengths.
    model_config = ModelConfig(
        hidden_size=1024,
        num_layers=1,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=128,
        intermediate_size=4096,
        activation_type="gelu",
    )
    model = TransformerModel(model_config).to(DEVICE)
    model.eval()
    executors = _setup_executors(model)

    seq_len = 8192
    test_input = torch.randn(seq_len, model_config.hidden_size).to(DEVICE)

    # Eager baseline result.
    with torch.no_grad():
        baseline = model(test_input)

    # Only the staged green-context variant is checked here — it is the
    # best-performing executor and the focus of this large-input test.
    with torch.no_grad():
        stage_green_result = executors["stage_green"].execute(test_input)
        executors["stage_green"].synchronize()
    assert torch.allclose(baseline, stage_green_result, atol=ATOL), f"大序列长度 {seq_len} 阶段化绿色Stream结果不匹配"


def test_executor_efficiency():
    """Benchmark every executor variant and print per-call latency and speedup vs. eager."""
    base_config = ModelConfig(
        hidden_size=4096,
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    seq_len = 1
    warmup_steps = 10
    run_steps = 10

    model = TransformerModel(base_config).to(DEVICE)
    executors = _setup_executors(model)
    test_input = torch.randn(seq_len, base_config.hidden_size).to(DEVICE)

    def make_runner(key):
        """Build a zero-arg benchmark closure for one executor variant."""
        target = executors[key]
        call = target if key in _CALLABLE_KEYS else target.execute

        def runner():
            with torch.no_grad():
                call(test_input)
            torch.cuda.synchronize(DEVICE)

        return runner

    # Report name -> executor key (the eager model is reported as "original").
    bench_targets = {
        "original": "model",
        "fx": "fx",
        "raw": "raw",
        "normal_default": "normal_default",
        "normal_multi": "normal_multi",
        "normal_green": "normal_green",
        "stage_default": "stage_default",
        "stage_multi": "stage_multi",
        "stage_green": "stage_green",
    }

    times = {name: benchmark_func(make_runner(key), warmup_steps, run_steps) for name, key in bench_targets.items()}
    speedups = {k: times["original"] / v for k, v in times.items()}

    # pytest captures this output; run with -s to see it.
    print(f"\n=== 执行器效率测试结果({seq_len=}) ===")
    for name, t in times.items():
        print(f"{name:15s}: {t:.6f} 秒/次 (加速比: {speedups[name]:.2f}x)")


def test_executor_edge_cases():
    """Exercise boundary conditions: a one-token input and a dependency-free model."""
    # 1. Minimal sequence length (a single token).
    model_config = ModelConfig(
        hidden_size=128,
        num_layers=1,
        num_heads_q=4,
        num_heads_kv=2,
        head_dim=32,
        intermediate_size=512,
        activation_type="gelu",
    )
    model = TransformerModel(model_config).to(DEVICE)
    model.eval()
    executors = _setup_executors(model)
    test_input = torch.randn(1, model_config.hidden_size).to(DEVICE)
    with torch.no_grad():
        baseline = model(test_input)
        result = executors["normal_default"].execute(test_input)
        executors["normal_default"].synchronize()
    assert torch.allclose(baseline, result, atol=ATOL), "极小序列长度执行失败"

    # 2. A plain sequential model whose graph has no cross-branch dependencies.
    simple_model = torch.nn.Sequential(
        torch.nn.Linear(128, 256), torch.nn.GELU(), torch.nn.Linear(256, 128)
    ).to(DEVICE)
    simple_model.eval()
    gm_simple = fx.symbolic_trace(simple_model)
    executor = GraphNormalExecutor(gm_simple, DEVICE)
    test_input = torch.randn(32, 128).to(DEVICE)
    with torch.no_grad():
        baseline = simple_model(test_input)
        executor_result = executor.execute(test_input)
        executor.synchronize()
    assert torch.allclose(baseline, executor_result, atol=ATOL), "空依赖节点执行失败"