# Copyright (c) 2026 SandAI. All Rights Reserved.
from typing import Dict, Tuple
import torch
from magi_compiler.tokenflow.graph_executor import (
GraphNormalExecutor,
GraphOptimizer,
GraphRawExecutor,
GraphStageExecutor,
LaneType,
)
from magi_compiler.tokenflow.green_ctx import GreenCtxManager
from magi_compiler.tokenflow.sampler import exponential_aligned_sampler
from magi_compiler.tokenflow.utils import ModelConfig, TransformerModel, benchmark_func
from torch import fx
# Device used by every test: the first CUDA GPU when one is visible, otherwise CPU.
# NOTE(review): the executor setup below uses torch.cuda stream APIs, so the CPU
# fallback presumably only lets this module import on GPU-less hosts — confirm.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Absolute tolerance for the output comparisons in this module.
ATOL = 1.0e-5
def _setup_executors(model) -> Dict[str, object]:
    """Trace ``model`` with torch.fx and build every executor variant under test.

    The model is switched to eval mode and symbolically traced. Per-op stage
    configs are derived with ``GraphOptimizer.generate_stages_per_op``, and one
    executor is created for each (executor kind, stream strategy) pair.

    Args:
        model: The ``torch.nn.Module`` to trace; it must be fx-traceable.

    Returns:
        A dict keyed by executor name. It was previously annotated as a 3-tuple,
        but only this dict is, and always was, returned.

        * ``"model"``: the original eager module.
        * ``"fx"``: the traced ``fx.GraphModule``.
        * ``"raw"``: an unoptimized ``GraphRawExecutor``.
        * ``"normal_default"`` / ``"normal_multi"`` / ``"normal_green"``: a
          ``GraphNormalExecutor`` whose per-node streams are the default
          stream, a fresh ``torch.cuda.Stream`` per node, or a green-context
          stream per node.
        * ``"stage_default"`` / ``"stage_multi"`` / ``"stage_green"``: a
          ``GraphStageExecutor`` with the same three strategies applied per
          (stage, lane).
    """
    model.eval()
    gm = fx.symbolic_trace(model)

    # Stage layout shared by every GraphStageExecutor variant.
    stage_configs = GraphOptimizer().generate_stages_per_op(gm.graph)

    executors: Dict[str, object] = {
        "model": model,  # original eager module
        "fx": gm,  # traced GraphModule
        "raw": GraphRawExecutor(gm, DEVICE),  # unoptimized executor
    }

    # --- Per-node executors -------------------------------------------------
    normal_default = GraphNormalExecutor(gm, DEVICE)
    normal_multi = GraphNormalExecutor(gm, DEVICE)
    normal_green = GraphNormalExecutor(gm, DEVICE)
    executors["normal_default"] = normal_default
    executors["normal_multi"] = normal_multi
    executors["normal_green"] = normal_green

    # stream_map values are replaced in place. The executors keep references
    # to these dicts, so they must not be rebound.
    for node_name in normal_default.stream_map:
        normal_default.stream_map[node_name] = torch.cuda.default_stream(DEVICE)
    for node_name in normal_multi.stream_map:
        normal_multi.stream_map[node_name] = torch.cuda.Stream(DEVICE)
    for node_name in normal_green.stream_map:
        # NOTE(review): one GreenCtxManager per node, as originally written.
        # Confirm whether a single shared manager would behave the same.
        gmgr = GreenCtxManager(DEVICE.index)
        normal_green.stream_map[node_name] = gmgr.create_stream(sm_count=gmgr.max_sm)

    # --- Stage-level executors ----------------------------------------------
    stage_default = GraphStageExecutor(gm, stage_configs, DEVICE)
    stage_multi = GraphStageExecutor(gm, stage_configs, DEVICE)
    stage_green = GraphStageExecutor(gm, stage_configs, DEVICE)
    executors["stage_default"] = stage_default
    executors["stage_multi"] = stage_multi
    executors["stage_green"] = stage_green

    for stage_cfg in stage_configs:
        stage_name = stage_cfg.name
        # One green-context manager per stage. The original also built two
        # further managers per stage that were never used; they are dropped.
        gmgr = GreenCtxManager(DEVICE.index)
        for lane_type in stage_cfg.lane_node_dict:
            stage_default.stage_lane_stream[stage_name][lane_type] = torch.cuda.default_stream(DEVICE)
            stage_multi.stage_lane_stream[stage_name][lane_type] = torch.cuda.Stream(DEVICE)
            # Only COMPUTE lanes get a green-context stream. Other lanes keep
            # whatever GraphStageExecutor assigned at construction
            # (NOTE(review): presumably a default stream — confirm).
            if lane_type == LaneType.COMPUTE:
                stage_green.stage_lane_stream[stage_name][lane_type] = gmgr.create_stream(sm_count=gmgr.max_sm)

    return executors
def test_executor_correctness_basic():
    """Check every executor against the eager model over a sweep of sequence lengths.

    Builds a single-layer ``TransformerModel`` (hidden size 4096, 32 query /
    8 KV heads). For each sampled sequence length it feeds the same random
    input to:

    * the eager module,
    * the traced GraphModule,
    * all seven graph executors.

    Each output must match the eager output within ``rtol=1e-5`` and
    ``atol=ATOL``. On mismatch, the failure message names the executor that
    diverged.
    """
    model_config = ModelConfig(
        hidden_size=4096,
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    model = TransformerModel(model_config).to(DEVICE)
    model.eval()
    test_seq_lengths = exponential_aligned_sampler(min_val=16, max_val=2048, num_samples=10, align=7)
    executors = _setup_executors(model)

    # These entries are nn.Modules called directly; all others expose ``.execute``.
    direct_call = {"model", "fx"}
    # Executors compared against the eager output, in the original assertion order.
    candidates = (
        "fx",
        "raw",
        "normal_default",
        "normal_multi",
        "normal_green",
        "stage_default",
        "stage_multi",
        "stage_green",
    )

    def run(name, x):
        """Run executor ``name`` on ``x`` without autograd, then wait for the device."""
        target = executors[name]
        with torch.no_grad():
            res = target(x) if name in direct_call else target.execute(x)
        torch.cuda.synchronize(DEVICE)
        return res

    for seq_len in test_seq_lengths:
        # Fresh random input for this sequence length.
        test_input = torch.randn(seq_len, model_config.hidden_size).to(DEVICE)
        print(f"\n--- 正确性验证,序列长度={seq_len} ---")
        out_orig = run("model", test_input)
        outputs = {name: run(name, test_input) for name in candidates}
        try:
            for name, out in outputs.items():
                torch.testing.assert_close(
                    out,
                    out_orig,
                    rtol=1e-5,
                    atol=ATOL,
                    # Prefix the generated mismatch report with the executor name;
                    # bind ``name`` as a default to avoid late binding.
                    msg=lambda generated, name=name: f"[{name}] {generated}",
                )
            print(f"序列长度={seq_len} 正确性验证通过!")
        except Exception as e:
            print(f"序列长度={seq_len} 正确性验证失败!错误信息: {e}")
            # Bare raise keeps the original traceback intact.
            raise
def test_executor_correctness_large_sequence():
    """Validate the ``stage_green`` executor on a long (8192-token) input.

    The hidden size is reduced to 1024 so that this sequence length does not
    run out of memory. Only the stage-level green-context executor is compared
    against the eager model, using ``torch.allclose`` with ``atol=ATOL``.
    """
    cfg = ModelConfig(
        hidden_size=1024,
        num_layers=1,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=128,
        intermediate_size=4096,
        activation_type="gelu",
    )
    net = TransformerModel(cfg).to(DEVICE).eval()
    all_execs = _setup_executors(net)

    seq_len = 8192
    x = torch.randn(seq_len, cfg.hidden_size).to(DEVICE)
    green_exec = all_execs["stage_green"]

    with torch.no_grad():
        # Eager reference first, then the candidate executor.
        reference = net(x)
        candidate = green_exec.execute(x)
        green_exec.synchronize()

    matches = torch.allclose(reference, candidate, atol=ATOL)
    assert matches, f"大序列长度 {seq_len} 阶段化绿色Stream结果不匹配"
def test_executor_efficiency():
    """Benchmark every executor variant and print per-call latency and speedup.

    Each variant runs once per timed call: autograd is disabled and the device
    is synchronized after the call. Speedups are relative to the eager
    ``"original"`` module. This test only reports numbers; it asserts nothing.

    Uses the module-level ``DEVICE``. The original re-declared an identical
    local ``DEVICE`` that shadowed it.
    """
    base_config = ModelConfig(
        hidden_size=4096,
        # NOTE(review): the original annotated this line with
        # "fx382->raw858 / fx3104->raw7409". These look like earlier benchmark
        # observations; their context is unknown.
        num_layers=1,
        num_heads_q=32,
        num_heads_kv=8,
        head_dim=128,
        intermediate_size=16384,
        activation_type="gelu",
    )
    seq_len = 1
    warmup_steps = 10
    run_steps = 10
    model = TransformerModel(base_config).to(DEVICE)
    executors = _setup_executors(model)
    test_input = torch.randn(seq_len, base_config.hidden_size).to(DEVICE)

    def make_runner(key):
        """Return a zero-arg callable that runs executor ``key`` once and syncs the device.

        A factory is used instead of closures built in a loop, so ``key`` is
        bound per runner rather than late-bound.
        """
        target = executors[key]
        # "model" / "fx" are nn.Modules called directly; the rest expose ``.execute``.
        invoke = target if key in ("model", "fx") else target.execute

        def runner():
            with torch.no_grad():
                invoke(test_input)
            torch.cuda.synchronize(DEVICE)

        return runner

    # Report label -> executor key, in benchmark order.
    bench_plan = {
        "original": "model",
        "fx": "fx",
        "raw": "raw",
        "normal_default": "normal_default",
        "normal_multi": "normal_multi",
        "normal_green": "normal_green",
        "stage_default": "stage_default",
        "stage_multi": "stage_multi",
        "stage_green": "stage_green",
    }
    times = {
        label: benchmark_func(make_runner(key), warmup_steps, run_steps)
        for label, key in bench_plan.items()
    }

    # Speedup of each variant relative to the eager module.
    speedups = {k: times["original"] / v for k, v in times.items()}

    # pytest captures this output (visible with -s).
    print(f"\n=== 执行器效率测试结果({seq_len=}) ===")
    for name, t in times.items():
        print(f"{name:15s}: {t:.6f} 秒/次 (加速比: {speedups[name]:.2f}x)")
def test_executor_edge_cases():
    """Exercise boundary inputs against the eager reference (``atol=ATOL``).

    Case 1 runs a small single-layer ``TransformerModel`` on a length-1
    sequence through the ``normal_default`` executor.

    Case 2 wraps a Linear-GELU-Linear ``nn.Sequential`` in a
    ``GraphNormalExecutor`` directly, without any custom stream setup.
    """
    # --- Case 1: smallest possible sequence length ---------------------------
    tiny_cfg = ModelConfig(
        hidden_size=128,
        num_layers=1,
        num_heads_q=4,
        num_heads_kv=2,
        head_dim=32,
        intermediate_size=512,
        activation_type="gelu",
    )
    tiny_model = TransformerModel(tiny_cfg).to(DEVICE).eval()
    # Keep the whole dict alive for the duration of the test.
    all_execs = _setup_executors(tiny_model)
    default_exec = all_execs["normal_default"]
    single_token = torch.randn(1, tiny_cfg.hidden_size).to(DEVICE)
    with torch.no_grad():
        expected = tiny_model(single_token)
        actual = default_exec.execute(single_token)
        default_exec.synchronize()
    assert torch.allclose(expected, actual, atol=ATOL), "极小序列长度执行失败"

    # --- Case 2: minimal sequential MLP, traced and executed directly ---------
    mlp = torch.nn.Sequential(
        torch.nn.Linear(128, 256),
        torch.nn.GELU(),
        torch.nn.Linear(256, 128),
    ).to(DEVICE).eval()
    mlp_graph = fx.symbolic_trace(mlp)
    mlp_exec = GraphNormalExecutor(mlp_graph, DEVICE)
    batch = torch.randn(32, 128).to(DEVICE)
    with torch.no_grad():
        expected = mlp(batch)
        actual = mlp_exec.execute(batch)
        mlp_exec.synchronize()
    assert torch.allclose(expected, actual, atol=ATOL), "空依赖节点执行失败"
|