## AutoCudaGraph Design

Author: ZhiyaoCen

## Overview

AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework. It automates CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers.

Core Goals:

* Automate the CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
* Support dynamic shape adaptation (sequence-length expansion)
* Optimize memory efficiency via a global memory pool and static tensor reuse
* Ensure consistency between cached graphs and runtime inputs/outputs

## Key Components

### CudaGraphMgr (Core Manager)

A singleton class managing all CUDA Graph operations:

```python
class CudaGraphMgr:
    def __init__(self):
        # Primary cache: static signature -> static tensor buffers and graphs
        self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
        # Shared memory pool handle for all captured graphs
        self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
```

**Core Methods**

| Method | Purpose |
|--------|---------|
| run() | Main entry: replay a cached graph, or warm up and capture a new one |
| wrapped_graph_capture() | Capture a CUDA Graph with sliced static input/output tensors |
| wrapped_graph_replay() | Replay a cached CUDA Graph with sliced static tensors and output-template wrapping |
| get_expanded_static_tensors() | Expand static tensors, reusing existing buffers when dimensionally compatible |

### Signature System

**StaticSignature**

```python
@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
```

* Encodes fixed properties of input tensors (dtype, static dimensions)
* Serves as the primary key for static tensor buffer caching

**DynamicSignature**

```python
@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
```

* Tracks dynamic dimensions (sequence length) and literal parameters
* Serves as the secondary key for graph entry lookup (see the shape-splitting sketch at the end of this section)

### Tensor Management

```python
@dataclass
class StaticTensorEntry:
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
```

* Memory Reuse: existing tensor buffers are reused when possible to avoid reallocation
* Dynamic Expansion: static tensors are expanded only when new input dimensions exceed the current buffer size
* Shape Validation: static (non-sequence) dimensions must match between cached and new tensors

### Graph Management

```python
@dataclass
class GraphEntry:
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None
    output_template: Any = None
```

* Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
* Layer-wise Organization: OutputTemplateEntry maps layer numbers (the int keys of graph_entry_dict) to per-layer GraphEntry instances, enabling layer-specific graph reuse.
* Output Consistency: output_template preserves the output object structure so results are wrapped consistently during replay.
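To make the two-level cache key concrete, here is a minimal, self-contained sketch of the shape-splitting rule the signatures encode (ND > 1: only the last dimension is static; 1D: fully dynamic). The `split_shape` helper is hypothetical and exists only for illustration; the real extraction logic lives inside MagiCompiler's TensorStaticInfo/TensorDynamicInfo construction.

```python
from typing import Tuple

import torch


def split_shape(t: torch.Tensor) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical illustration: split a shape into static and dynamic parts,
    with -1 marking the dimensions that belong to the other part."""
    if t.dim() == 1:
        # 1D tensors are fully dynamic.
        return (-1,), tuple(t.shape)
    # For ND tensors, only the last dimension is static.
    static = tuple([-1] * (t.dim() - 1)) + (t.shape[-1],)
    dynamic = tuple(t.shape[:-1]) + (-1,)
    return static, dynamic


x = torch.randn(2, 512, 1024)
static_part, dynamic_part = split_shape(x)
print(static_part)   # (-1, -1, 1024) -> feeds TensorStaticInfo / StaticSignature
print(dynamic_part)  # (2, 512, -1)   -> feeds TensorDynamicInfo / DynamicSignature
```

These two tuples match the `shapes=(-1, -1, 1024)` and `shapes=(2, 512, -1)` entries visible in the cache dump of the example below.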
## Execution Flow

### Inline Replay (Fast Path)

* Extract input signatures from the runtime arguments
* Look up the cached CUDA Graph via StaticSignature + DynamicSignature + layer number
* Validate graph consistency (neither inconsistent nor invalid)
* Reuse static tensors with dynamic slicing (see the slicing sketch after the example below)
* Replay the graph and return the sliced output

### Graph Capture (Slow Path)

Triggered when no valid cached graph exists or tensor expansion is needed:

* Execute the function once (warmup) to obtain output tensors
* Ensure input signatures still match after warmup
* Expand static buffers if the new shapes require it
* Capture a new CUDA Graph on the static tensors
* Store the new graph and update the tensor entries
* Return the warmup execution output as the final result

### Sequence Length Handling

* For ND tensors (ND > 1), only the last dimension is static
* For 1D tensors (ND = 1), all dimensions are dynamic
* Buffers are expanded automatically as sequence lengths grow
* Old graphs are invalidated when their tensors are expanded

## Examples

```python
import torch
import torch.nn as nn

from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if


class SimpleTransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.layer_number = 0

    @cuda_graph_enable_if(lambda: torch.cuda.is_available())
    def forward(self, x: torch.Tensor):
        attn_out, _ = self.self_attn(x, x, x)
        out = self.linear(self.layer_norm(x + attn_out))
        return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
graph_mgr = cuda_graph_mgr()

with torch.no_grad():
    # First run: no cached graph yet, so AutoCudaGraph warms up and captures one.
    input_1 = torch.randn(2, 512, 1024, device=device)
    output_1 = model(input_1)
    print(f"First run (graph capture): Output shape = {output_1.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Second run: identical shapes, so the cached graph is replayed.
    input_2 = torch.randn(2, 512, 1024, device=device)
    output_2 = model(input_2)
    print(f"Second run (graph replay): Output shape = {output_2.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Third run: longer sequence, so static buffers expand, the old graph is
    # invalidated, and a new graph is captured.
    input_3 = torch.randn(2, 1024, 1024, device=device)
    output_3 = model(input_3)
    print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")
print("\nCUDA Graph Cache Details:")
print(graph_mgr.formatted_cache_str())

# StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
#   Input Static Tensors: [shape=[2, 1024, 1024], dtype=torch.float32]
#   Output Static Tensors: [shape=[2, 1024, 1024], dtype=torch.float32]
#   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
#     Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
#     Layer 0: Graph Status: Invalid
#   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
#     Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
#     Layer 0: Graph Status: Valid
```
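The third run above works because of the invariant behind both capture and replay: graphs are captured on views of oversized static buffers, and each replay copies runtime data into the same storage, so the captured kernel addresses stay valid. Below is a standalone sketch of that copy-into-slice pattern on plain tensors; the `copy_into_static` helper is hypothetical, not the MagiCompiler API.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A static buffer sized for the largest sequence length seen so far.
static_input = torch.zeros(2, 1024, 1024, device=device)


def copy_into_static(runtime_input: torch.Tensor) -> torch.Tensor:
    # Slice the static buffer down to the runtime shape and copy in place.
    # The underlying storage never moves, so a graph captured on this view
    # keeps reading valid data on every replay.
    view = static_input[:, : runtime_input.shape[1], :]
    view.copy_(runtime_input)
    return view


x = torch.randn(2, 512, 1024, device=device)
sliced = copy_into_static(x)  # shape (2, 512, 1024), backed by static_input
assert sliced.data_ptr() == static_input.data_ptr()
```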
## Limitations and Constraints

* No support for data-dependent control flow in captured functions
* Graph capture fails if the function performs CPU/GPU synchronization
* Only CUDA tensors are supported (CPU tensors trigger a fallback)
* Custom input classes must inherit from InplaceSubstituteFakeClass
* Assumes the input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
* Valid graph reuse relies on an identical function, input tensor shapes, and constants
* No support for multi-stream execution scenarios

## Best Practices

* Dynamic Dimensions: place the sequence length at dimension 0 of input tensors where possible
* Monitor Memory Usage: track graph_mem_pool_size and tensor_mem_size to avoid OOM (see the monitoring sketch below)
* Specify Layer IDs: use layer_number to distinguish graphs across different models/layers
* LRU Cache (Future): implement cache eviction to limit the total graph/tensor count
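As a starting point for the memory-monitoring practice above, here is a hedged sketch that polls the manager's counters between inference steps. `log_cuda_graph_memory` and `budget_mb` are hypothetical; the attribute names come from the example and list above, and `graph_mem_pool_size` is assumed to report megabytes like `tensor_mem_size`.

```python
from magi_compiler.cuda_graph_mgr import cuda_graph_mgr

mgr = cuda_graph_mgr()


def log_cuda_graph_memory(step: int, budget_mb: float = 4096.0) -> None:
    # budget_mb is a made-up knob for this sketch; graph_mem_pool_size is
    # assumed to be in MB, matching tensor_mem_size in the example above.
    used_mb = mgr.tensor_mem_size + mgr.graph_mem_pool_size
    print(f"[step {step}] graphs={mgr.graph_count} "
          f"static tensors={mgr.tensor_mem_size:.2f} MB "
          f"graph pool={mgr.graph_mem_pool_size:.2f} MB")
    if used_mb > budget_mb:
        # No eviction API exists yet (LRU caching is listed as future work),
        # so the best available option is to surface the growth early.
        print(f"[step {step}] WARNING: CUDA Graph memory {used_mb:.2f} MB "
              f"exceeds budget {budget_mb:.2f} MB")
```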