# daVinci-MagiHuman / pkgs / MagiCompiler / tests / tokenflow / test_graph_executor.py
# Author: jiadisu — "Switch back to Docker SDK with local pkgs" (commit e6066e8)
# Copyright (c) 2026 SandAI. All Rights Reserved.
from typing import Dict, Tuple
import torch
from magi_compiler.tokenflow.graph_executor import (
GraphNormalExecutor,
GraphOptimizer,
GraphRawExecutor,
GraphStageExecutor,
LaneType,
)
from magi_compiler.tokenflow.green_ctx import GreenCtxManager
from magi_compiler.tokenflow.sampler import exponential_aligned_sampler
from magi_compiler.tokenflow.utils import ModelConfig, TransformerModel, benchmark_func
from torch import fx
# Target device for all tests: first CUDA GPU when available, CPU otherwise.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Absolute tolerance used by the torch.allclose comparisons below.
ATOL = 1e-5
def _setup_executors(model) -> Dict[str, object]:
    """Internal helper: symbolically trace *model* and build every executor variant.

    Fixes vs. previous revision: the return annotation claimed a tuple, but the
    function has always returned only the executor dict (and every caller uses
    it that way); two dead ``GreenCtxManager`` constructions whose result was
    never used have been removed.

    Returns a dict with keys:
      - ``"model"``: the original eager-mode model
      - ``"fx"``: the traced ``fx.GraphModule``
      - ``"raw"``: ``GraphRawExecutor`` (unoptimized)
      - ``"normal_default"`` / ``"normal_multi"`` / ``"normal_green"``:
        ``GraphNormalExecutor`` variants whose per-node streams are the default
        stream, fresh CUDA streams, or green-context streams respectively
      - ``"stage_default"`` / ``"stage_multi"`` / ``"stage_green"``:
        ``GraphStageExecutor`` variants configured analogously per stage lane
    """
    model.eval()
    gm = fx.symbolic_trace(model)

    # Generate the per-op stage configuration used by the stage executors.
    optimizer = GraphOptimizer()
    stage_configs = optimizer.generate_stages_per_op(gm.graph)

    executors = {}
    executors["model"] = model  # original eager model
    executors["fx"] = gm  # FX GraphModule
    executors["raw"] = GraphRawExecutor(gm, DEVICE)  # raw (unoptimized) executor

    # Normal executors: same graph, different stream assignment per node.
    executors["normal_default"] = GraphNormalExecutor(gm, DEVICE)
    executors["normal_multi"] = GraphNormalExecutor(gm, DEVICE)
    executors["normal_green"] = GraphNormalExecutor(gm, DEVICE)
    for node_name in executors["normal_default"].stream_map.keys():
        executors["normal_default"].stream_map[node_name] = torch.cuda.default_stream(DEVICE)
    for node_name in executors["normal_multi"].stream_map.keys():
        executors["normal_multi"].stream_map[node_name] = torch.cuda.Stream(DEVICE)
    for node_name in executors["normal_green"].stream_map.keys():
        # A fresh green-context manager per node, each providing a full-SM
        # green stream.  NOTE(review): one manager per node mirrors the
        # original code — confirm this is intentional rather than one shared
        # manager for the whole executor.
        gmgr = GreenCtxManager(DEVICE.index)
        executors["normal_green"].stream_map[node_name] = gmgr.create_stream(sm_count=gmgr.max_sm)

    # Stage executors.
    executors["stage_default"] = GraphStageExecutor(gm, stage_configs, DEVICE)
    executors["stage_multi"] = GraphStageExecutor(gm, stage_configs, DEVICE)
    executors["stage_green"] = GraphStageExecutor(gm, stage_configs, DEVICE)

    # Configure per-stage, per-lane streams for each stage executor.
    for stage_cfg in stage_configs:
        stage_name = stage_cfg.name
        for lane_type in stage_cfg.lane_node_dict.keys():
            executors["stage_default"].stage_lane_stream[stage_name][lane_type] = torch.cuda.default_stream(DEVICE)
        for lane_type in stage_cfg.lane_node_dict.keys():
            executors["stage_multi"].stage_lane_stream[stage_name][lane_type] = torch.cuda.Stream(DEVICE)
        gmgr = GreenCtxManager(DEVICE.index)
        for lane_type in stage_cfg.lane_node_dict.keys():
            # Only compute lanes get a dedicated green-context stream.
            if lane_type == LaneType.COMPUTE:
                executors["stage_green"].stage_lane_stream[stage_name][lane_type] = gmgr.create_stream(sm_count=gmgr.max_sm)
    return executors
def test_executor_correctness_basic():
    """Verify that every executor matches the eager model across sequence lengths.

    Improvement vs. previous revision: the nine near-identical ``run_*``
    closures are collapsed into a single ``_run`` helper plus a runner table
    (dict insertion order preserves the original execution order), and the
    ``raise e`` is replaced with a bare ``raise`` so the original traceback is
    kept intact.
    """
    model_config = ModelConfig(
        hidden_size=4096,
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    model = TransformerModel(model_config).to(DEVICE)
    model.eval()
    # Exponentially spaced sequence lengths, aligned to a multiple of 7.
    test_seq_lengths = exponential_aligned_sampler(min_val=16, max_val=2048, num_samples=10, align=7)
    executors = _setup_executors(model)

    def _run(fn, test_input):
        # One forward pass under no_grad, synchronized before returning.
        with torch.no_grad():
            res = fn(test_input)
            torch.cuda.synchronize(DEVICE)
        return res

    # Callables in the original execution order: "model" and "fx" are invoked
    # directly, all other executors via their .execute() method.
    runners = {"model": executors["model"], "fx": executors["fx"]}
    for name in (
        "raw",
        "normal_default",
        "normal_multi",
        "normal_green",
        "stage_default",
        "stage_multi",
        "stage_green",
    ):
        runners[name] = executors[name].execute

    for seq_len in test_seq_lengths:
        # Generate a fresh random input for this sequence length.
        test_input = torch.randn(seq_len, model_config.hidden_size).to(DEVICE)
        print(f"\n--- 正确性验证,序列长度={seq_len} ---")
        outputs = {name: _run(fn, test_input) for name, fn in runners.items()}
        out_orig = outputs.pop("model")
        try:
            # Every executor output must match the eager model's output.
            for name, out in outputs.items():
                torch.testing.assert_close(out, out_orig, rtol=1e-5, atol=1e-5)
            print(f"序列长度={seq_len} 正确性验证通过!")
        except Exception as e:
            print(f"序列长度={seq_len} 正确性验证失败!错误信息: {e}")
            raise
def test_executor_correctness_large_sequence():
    """Check executor correctness on a long (8K-token) sequence."""
    # Long-sequence configuration; smaller hidden size keeps memory in bounds.
    cfg = ModelConfig(
        hidden_size=1024,
        num_layers=1,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=128,
        intermediate_size=4096,
        activation_type="gelu",
    )
    net = TransformerModel(cfg).to(DEVICE)
    net.eval()
    executors = _setup_executors(net)

    # Build the long test input.
    seq_len = 8192
    inputs = torch.randn(seq_len, cfg.hidden_size).to(DEVICE)

    # Reference output from the eager model.
    with torch.no_grad():
        expected = net(inputs)

    # Validate the staged green-stream executor (the best-performing variant).
    with torch.no_grad():
        actual = executors["stage_green"].execute(inputs)
        executors["stage_green"].synchronize()
    assert torch.allclose(expected, actual, atol=ATOL), f"大序列长度 {seq_len} 阶段化绿色Stream结果不匹配"
def test_executor_efficiency():
    """Benchmark every executor and print per-call latency and speedup vs eager.

    Improvements vs. previous revision: the local ``DEVICE`` redefinition that
    shadowed the module-level constant is removed, the nine duplicated ``run_*``
    closures are collapsed into a ``_bench`` helper, the stale debug comment is
    dropped, and the local config constant is renamed to snake_case.
    (pytest captures the printed report.)
    """
    base_config = ModelConfig(
        hidden_size=4096,
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    seq_len = 1
    warmup_steps = 10
    run_steps = 10
    model = TransformerModel(base_config).to(DEVICE)
    executors = _setup_executors(model)
    test_input = torch.randn(seq_len, base_config.hidden_size).to(DEVICE)

    def _bench(fn):
        # Benchmark one forward pass (no_grad + device sync) of *fn*.
        def runner():
            with torch.no_grad():
                fn(test_input)
            torch.cuda.synchronize(DEVICE)

        return benchmark_func(runner, warmup_steps, run_steps)

    # "original" and "fx" are plain callables; the rest run via .execute().
    times = {
        "original": _bench(executors["model"]),
        "fx": _bench(executors["fx"]),
        "raw": _bench(executors["raw"].execute),
        "normal_default": _bench(executors["normal_default"].execute),
        "normal_multi": _bench(executors["normal_multi"].execute),
        "normal_green": _bench(executors["normal_green"].execute),
        "stage_default": _bench(executors["stage_default"].execute),
        "stage_multi": _bench(executors["stage_multi"].execute),
        "stage_green": _bench(executors["stage_green"].execute),
    }

    # Speedups relative to the eager model ("original" is 1.00x by definition).
    speedups = {k: times["original"] / v for k, v in times.items()}
    print(f"\n=== 执行器效率测试结果({seq_len=}) ===")
    for name, t in times.items():
        print(f"{name:15s}: {t:.6f} 秒/次 (加速比: {speedups[name]:.2f}x)")
def test_executor_edge_cases():
    """Exercise boundary conditions: a single-token input and a dependency-free model."""
    # Case 1: minimal sequence length (one token) on a tiny transformer.
    cfg = ModelConfig(
        hidden_size=128,
        num_layers=1,
        num_heads_q=4,
        num_heads_kv=2,
        head_dim=32,
        intermediate_size=512,
        activation_type="gelu",
    )
    net = TransformerModel(cfg).to(DEVICE)
    net.eval()
    executors = _setup_executors(net)
    single_token = torch.randn(1, cfg.hidden_size).to(DEVICE)
    with torch.no_grad():
        expected = net(single_token)
        got = executors["normal_default"].execute(single_token)
        executors["normal_default"].synchronize()
    assert torch.allclose(expected, got, atol=ATOL), "极小序列长度执行失败"

    # Case 2: a plain sequential model with no cross-branch dependencies.
    simple_model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.GELU(), torch.nn.Linear(256, 128)).to(DEVICE)
    simple_model.eval()
    traced = fx.symbolic_trace(simple_model)
    runner = GraphNormalExecutor(traced, DEVICE)
    batch = torch.randn(32, 128).to(DEVICE)
    with torch.no_grad():
        expected = simple_model(batch)
        got = runner.execute(batch)
        runner.synchronize()
    assert torch.allclose(expected, got, atol=ATOL), "空依赖节点执行失败"