AutoCudaGraph Design
Author: ZhiyaoCen
Overview
AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework. It automates CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths and optimizes kernel execution by reusing pre-captured computation graphs and static tensor buffers.
Core Goals:
- Automate CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
- Support dynamic shape adaptation (sequence length expansion)
- Optimize memory efficiency via a global memory pool and static tensor reuse
- Ensure consistency between cached graphs and runtime inputs/outputs
Key Components
CudaGraphMgr (Core Manager)
Singleton class managing all CUDA Graph operations:
```python
class CudaGraphMgr:
    def __init__(self):
        self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
        self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
```
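The example later in this document obtains the manager through `cuda_graph_mgr()`, but how the singleton is enforced is not shown. A minimal sketch of one common approach, assuming a cached zero-argument factory; `_CudaGraphMgrSketch` and `cuda_graph_mgr_sketch` are hypothetical names that only mirror the two attributes above:

```python
from functools import lru_cache
from typing import Any, Dict, Optional


class _CudaGraphMgrSketch:
    """Hypothetical stand-in for CudaGraphMgr with the same two attributes."""

    def __init__(self) -> None:
        self.cache: Dict[Any, Any] = dict()
        # Shared pool handle for all captures; None until one is assigned
        self.graph_mem_pool: Optional[Any] = None


@lru_cache(maxsize=None)
def cuda_graph_mgr_sketch() -> _CudaGraphMgrSketch:
    # lru_cache on a zero-argument function builds the object once and then
    # returns that same instance on every later call
    return _CudaGraphMgrSketch()
```

Because every caller receives the same object, all captures share one `cache` and one `graph_mem_pool`, which is what the global memory pool goal relies on.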
Core Methods
| Method | Purpose |
|---|---|
| run() | Main entry: Replay cached graph or warm up & capture new graph |
| wrapped_graph_capture() | Capture CUDA Graph with sliced static input/output tensors |
| wrapped_graph_replay() | Replay cached CUDA Graph with sliced static tensors and output template wrapping |
| get_expanded_static_tensors() | Expand static tensors, reuse buffers if dimensionally compatible |
Signature System
StaticSignature
```python
@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
```
- Encodes fixed properties of input tensors (dtype, static dimensions)
- Serves as the primary key for static tensor buffer caching
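Assuming the masking rule described under Sequence Length Handling (only the last dimension of an ND tensor is static), a StaticSignature can be derived from runtime shapes as in the sketch below. `TensorStaticInfoSketch`, `static_shape`, and `static_signature` are hypothetical simplifications; the real classes also carry a cached hash and inherit from HashableDataclass.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class TensorStaticInfoSketch:
    # Hypothetical mirror of TensorStaticInfo; dynamic dims are masked as -1
    name: str
    shapes: Tuple[int, ...]
    dtype: str


def static_shape(shape: Sequence[int]) -> Tuple[int, ...]:
    """Keep only the last dim of an ND tensor (N > 1); a 1D tensor is fully dynamic."""
    if len(shape) <= 1:
        return (-1,) * len(shape)
    return (-1,) * (len(shape) - 1) + (shape[-1],)


def static_signature(func_name: str, shapes: Sequence[Sequence[int]], dtypes: Sequence[str]):
    """Return a hashable (func_name, infos) pair standing in for StaticSignature."""
    infos = tuple(
        TensorStaticInfoSketch(name="", shapes=static_shape(s), dtype=d)
        for s, d in zip(shapes, dtypes)
    )
    return (func_name, infos)
```

With this rule, inputs of shape (2, 512, 1024) and (4, 2048, 1024) with the same dtype map to the same key, which is why they can share static buffers.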
DynamicSignature
```python
@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
```
- Tracks dynamic dimensions (such as sequence length) and literal parameters
- Serves as the secondary key for graph entry lookup
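The DynamicSignature is the complement: under the same assumed rule it records the runtime sizes of the dynamic dimensions and masks the static one, matching the `shapes=(2, 512, -1)` entries in the example output. The sketch assumes literals are hashable values; `dynamic_shape`, `DynamicSignatureSketch`, and `dynamic_signature` are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Sequence, Tuple


def dynamic_shape(shape: Sequence[int]) -> Tuple[int, ...]:
    """Record runtime sizes of dynamic dims; mask the static last dim of ND tensors as -1."""
    if len(shape) <= 1:
        return tuple(shape)
    return tuple(shape[:-1]) + (-1,)


@dataclass(frozen=True)
class DynamicSignatureSketch:
    # Hypothetical mirror of DynamicSignature (tensor_dynamic_infos + literals_info)
    tensor_shapes: Tuple[Tuple[int, ...], ...]
    literals: Tuple[Any, ...] = ()


def dynamic_signature(shapes: Sequence[Sequence[int]], literals: Sequence[Any] = ()) -> DynamicSignatureSketch:
    return DynamicSignatureSketch(tuple(dynamic_shape(s) for s in shapes), tuple(literals))
```

A change in either a dynamic size or a literal yields a different key, so each combination gets its own graph.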
Tensor Management
```python
@dataclass
class StaticTensorEntry:
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
```
- Memory Reuse: Reuse existing tensor buffers when possible to avoid reallocation
- Dynamic Expansion: Only expand static tensors when new input dimensions exceed current buffer size
- Shape Validation: Ensure static dimensions (non-sequence) match between cached and new tensors
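The reuse, expansion, and validation rules above can be sketched for a single buffer slot as follows. This is an illustration under assumptions, not get_expanded_static_tensors itself: it works on shapes only, assumes the last dimension is the static one, and assumes expansion grows each dimension to the larger of the cached and requested sizes. `expanded_buffer_shape` is hypothetical.

```python
from typing import Optional, Sequence, Tuple


def expanded_buffer_shape(
    current: Optional[Sequence[int]], requested: Sequence[int]
) -> Tuple[Tuple[int, ...], bool]:
    """Return (buffer_shape, needs_alloc) for one static tensor slot."""
    if current is None:
        # First use: allocate exactly what the runtime input needs
        return tuple(requested), True
    if len(current) != len(requested):
        raise ValueError("rank mismatch between cached and new tensor")
    if len(current) > 1 and current[-1] != requested[-1]:
        # Shape validation: the static (last) dimension must match
        raise ValueError("static dimension mismatch")
    if all(r <= c for c, r in zip(current, requested)):
        # Memory reuse: the new input fits, so slice the existing buffer
        return tuple(current), False
    # Dynamic expansion: grow each dimension to cover both shapes
    return tuple(max(c, r) for c, r in zip(current, requested)), True
```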
Graph Management
```python
@dataclass
class GraphEntry:
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None
    output_template: Any = None
```
- Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
- Layer-wise Organization: StaticTensorEntry.template_entry_dict maps each DynamicSignature to an OutputTemplateEntry, whose graph_entry_dict maps layer numbers to GraphEntry objects, so each layer reuses its own graph.
- Output Consistency: output_template preserves output object structure to ensure consistent result wrapping during replay.
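Taken together, the structures form a three-level lookup: StaticSignature, then DynamicSignature, then layer number. The sketch below uses simplified mirrors of the dataclasses (tensor fields omitted, `field(default_factory=dict)` instead of `None` defaults, plain hashables standing in for signatures) and a hypothetical `find_replayable` helper that returns `None` whenever the slow path would be taken.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Hashable, Optional


@dataclass
class GraphEntry:
    graph: Optional[Any] = None  # a torch.cuda.CUDAGraph in the real module
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = field(default_factory=dict)
    output_template: Any = None


@dataclass
class StaticTensorEntry:
    template_entry_dict: Dict[Hashable, OutputTemplateEntry] = field(default_factory=dict)


def find_replayable(cache: Dict[Hashable, StaticTensorEntry], static_sig: Hashable,
                    dynamic_sig: Hashable, layer_number: int) -> Optional[GraphEntry]:
    """Walk StaticSignature -> DynamicSignature -> layer number; None means slow path."""
    tensor_entry = cache.get(static_sig)
    if tensor_entry is None:
        return None
    template_entry = tensor_entry.template_entry_dict.get(dynamic_sig)
    if template_entry is None:
        return None
    graph_entry = template_entry.graph_entry_dict.get(layer_number)
    if graph_entry is None or graph_entry.graph is None:
        return None
    # Replay eligibility: graphs flagged inconsistent or invalid are never replayed
    if graph_entry.inconsistent or graph_entry.invalid:
        return None
    return graph_entry
```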
Execution Flow
Inline Replay (Fast Path)
- Extract input signatures from the runtime arguments
- Look up the cached CUDA Graph by StaticSignature, DynamicSignature, and layer number
- Check that the graph is marked neither inconsistent nor invalid
- Reuse the static tensors, sliced to the runtime shapes
- Replay the graph and return the sliced output
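Replaying on a shorter sequence than a buffer was sized for needs a view of the leading region of each static buffer. A minimal sketch, assuming the slices start at index 0 in every dimension; `replay_view_index` is hypothetical:

```python
from typing import Sequence, Tuple


def replay_view_index(buffer_shape: Sequence[int], runtime_shape: Sequence[int]) -> Tuple[slice, ...]:
    """Index selecting the leading region of a static buffer that holds the runtime input."""
    if len(buffer_shape) != len(runtime_shape) or any(
        r > b for b, r in zip(buffer_shape, runtime_shape)
    ):
        raise ValueError("runtime shape does not fit in the static buffer")
    return tuple(slice(0, r) for r in runtime_shape)
```

On a `torch.Tensor`, `static_in[replay_view_index(static_in.shape, x.shape)]` is basic slicing and therefore returns a view of the same storage, so copying the runtime input into it with `copy_` before `graph.replay()` keeps the memory addresses the graph was captured with.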
Graph Capture (Slow Path)
Triggered when no valid cached graph exists or tensor expansion is needed:
- Run the function eagerly (warmup) to obtain the output tensors
- Verify that the input signatures are unchanged after warmup
- Expand the static buffers if the new shapes require it
- Capture a new CUDA Graph on the static tensors
- Store the new graph and update the tensor entries
- Return the warmup output as the final result
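The dispatch that run() performs across both paths can be modeled without a GPU. In this toy, hypothetical `ToyGraphCache`, a "graph" is just a stored callable keyed by (StaticSignature, DynamicSignature, layer number); the real module instead copies inputs into static buffers and replays a `torch.cuda.CUDAGraph`. Signature re-checks and buffer expansion are elided.

```python
from typing import Any, Callable, Dict, Hashable, List, Tuple

Key = Tuple[Hashable, Hashable, int]


class ToyGraphCache:
    """Hypothetical model of run(): replay on a cache hit, warm up and 'capture' on a miss."""

    def __init__(self) -> None:
        # Stand-in for captured graphs, keyed by (static sig, dynamic sig, layer number)
        self.graphs: Dict[Key, Callable[[tuple], Any]] = {}
        self.log: List[str] = []

    def run(self, fn: Callable[..., Any], args: tuple, static_sig: Hashable,
            dynamic_sig: Hashable, layer_number: int = 0) -> Any:
        key = (static_sig, dynamic_sig, layer_number)
        graph = self.graphs.get(key)
        if graph is not None:
            # Fast path: the real module copies args into static buffers and replays
            self.log.append("replay")
            return graph(args)
        # Slow path: eager warmup; its output is what the caller receives
        self.log.append("warmup")
        out = fn(*args)
        # Signature re-check and static buffer expansion would happen here
        self.graphs[key] = lambda a: fn(*a)  # stand-in for graph capture
        self.log.append("capture")
        return out
```

Returning the warmup output on the slow path means the first call for a new shape never waits on a replay of the graph it has just captured.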
Sequence Length Handling
- For ND tensors (N > 1), only the last dimension is static; all other dimensions are dynamic
- For 1D tensors, the single dimension is dynamic
- Static buffers expand automatically as sequence lengths grow
- Expanding tensors invalidates graphs captured on the old buffers
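Invalidation is needed because a CUDA Graph replays kernels against the exact memory addresses used during capture; once a buffer is reallocated at a larger size, graphs captured on the old buffer must not be replayed. A minimal sketch, assuming graphs are marked invalid rather than deleted (as the "Graph Status: Invalid" entry in the example output suggests); `invalidate_after_expansion` is hypothetical, and the nested dicts stand in for template_entry_dict and graph_entry_dict.

```python
from dataclasses import dataclass
from typing import Dict, Hashable


@dataclass
class GraphEntry:
    invalid: bool = False


def invalidate_after_expansion(template_entries: Dict[Hashable, Dict[int, GraphEntry]]) -> int:
    """Mark every graph under one StaticSignature invalid; return how many were flipped."""
    flipped = 0
    for per_layer in template_entries.values():  # one dict per DynamicSignature
        for entry in per_layer.values():  # one GraphEntry per layer number
            if not entry.invalid:
                entry.invalid = True
                flipped += 1
    return flipped
```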
Examples
```python
import torch
import torch.nn as nn
from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if


class SimpleTransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        # Distinguishes this layer's graphs from those of other layers
        self.layer_number = 0

    # Enable CUDA Graph handling only when the predicate returns True (a GPU is present)
    @cuda_graph_enable_if(lambda: torch.cuda.is_available())
    def forward(self, x: torch.Tensor):
        attn_out, _ = self.self_attn(x, x, x)
        out = self.linear(self.layer_norm(x + attn_out))
        return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
graph_mgr = cuda_graph_mgr()

with torch.no_grad():
    # No cached graph yet: warm up and capture
    input_1 = torch.randn(2, 512, 1024, device=device)
    output_1 = model(input_1)
    print(f"First run (graph capture): Output shape = {output_1.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Same shape: replay the cached graph
    input_2 = torch.randn(2, 512, 1024, device=device)
    output_2 = model(input_2)
    print(f"Second run (graph replay): Output shape = {output_2.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Longer sequence: expand static tensors and capture a new graph
    input_3 = torch.randn(2, 1024, 1024, device=device)
    output_3 = model(input_3)
    print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")
    print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")

    print("\nCUDA Graph Cache Details:")
    print(graph_mgr.formatted_cache_str())

# Output of formatted_cache_str() after the third run:
# StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
# Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
# Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
# Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
# Layer 0: Graph Status: Invalid
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
# Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
# Layer 0: Graph Status: Valid
```
Limitations and Constraints
- No support for data-dependent control flow in captured functions
- Graph capture fails if the function performs CPU/GPU synchronization
- Only supports CUDA tensors (CPU tensors trigger fallback)
- Custom input classes must inherit from InplaceSubstituteFakeClass
- Assumes the input tensors of captured graphs are not reused externally; because static tensors are reused across scenarios, external references to them risk seeing data overwritten by later runs
- Valid graph reuse requires an identical function, identical input tensor shapes, and identical constants
- No support for multi-stream execution scenarios
Best Practices
- Dynamic Dimensions: Where possible, place the sequence length in dimension 0 of tensors
- Monitor Memory Usage: Track graph_mem_pool_size and tensor_mem_size to avoid OOM
- Specify Layer IDs: Use layer_number to distinguish graphs across different models/layers
- LRU Cache (Future): Implement cache eviction to limit total graph/tensor count
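For monitoring, static tensor memory is simple arithmetic: element count times bytes per element. The sketch assumes tensor_mem_size sums this over all static buffers and reports MB as 2**20 bytes; the actual definition is not shown in this document, and `static_tensor_mb` is hypothetical.

```python
from math import prod
from typing import Iterable, Tuple

# Bytes per element for a few common dtypes (what torch.Tensor.element_size() reports)
ELEMENT_SIZE = {"torch.float32": 4, "torch.float16": 2, "torch.bfloat16": 2, "torch.int64": 8}


def static_tensor_mb(buffers: Iterable[Tuple[Tuple[int, ...], str]]) -> float:
    """Sum numel * element_size over (shape, dtype) pairs, reported in MB (2**20 bytes)."""
    total_bytes = sum(prod(shape) * ELEMENT_SIZE[dtype] for shape, dtype in buffers)
    return total_bytes / 2**20
```

Under that assumption, the final cache state in the example (one 2 x 1024 x 1024 float32 input buffer and one output buffer of the same shape) would account for 2 x 8 MB = 16 MB.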