## AutoCudaGraph Design

Author: ZhiyaoCen

## Overview

AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework, designed to automate CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers.

Core Goals:

* Automate the CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
* Support dynamic shape adaptation (sequence length expansion)
* Optimize memory efficiency via a global memory pool and static tensor reuse
* Ensure consistency between cached graphs and runtime inputs/outputs
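
The lifecycle the module automates corresponds to PyTorch's standard capture/replay pattern. As a point of reference (this is the generic `torch.cuda.CUDAGraph` boilerplate, not AutoCudaGraph's actual code), a minimal sketch with an eager fallback when CUDA is unavailable:

```python
import torch


def run_with_graph(fn, static_input):
    """Sketch of the capture/replay boilerplate AutoCudaGraph automates.
    Falls back to eager execution when the input is not on CUDA."""
    if not static_input.is_cuda:
        return fn(static_input)  # eager fallback: graphs require CUDA
    # Warm up on a side stream so capture starts from a clean stream state.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(static_input)
    torch.cuda.current_stream().wait_stream(s)
    # Capture: record the kernels launched by fn into a replayable graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)
    # Later calls copy fresh data into static_input and simply replay:
    graph.replay()
    return static_output


x = torch.full((4,), 3.0)
y = run_with_graph(lambda t: t * 2.0, x)
```

Everything below (signatures, cached entries, buffer expansion) exists to manage many such graphs and their static tensors automatically.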
## Key Components

### CudaGraphMgr (Core Manager)

Singleton class managing all CUDA Graph operations:

```python
from typing import Dict, Optional

import torch


class CudaGraphMgr:
    def __init__(self):
        # Static tensor buffers, keyed by the static input signature
        self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
        # Shared handle (from torch.cuda.graph_pool_handle()) so all
        # captured graphs allocate from a single memory pool
        self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
```
**Core Methods**

| Method | Purpose |
|--------|---------|
| `run()` | Main entry: replay a cached graph, or warm up and capture a new graph |
| `wrapped_graph_capture()` | Capture a CUDA Graph with sliced static input/output tensors |
| `wrapped_graph_replay()` | Replay a cached CUDA Graph with sliced static tensors and output template wrapping |
| `get_expanded_static_tensors()` | Expand static tensors, reusing buffers if dimensionally compatible |
### Signature System

**StaticSignature**

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
```

* Encodes fixed properties of input tensors (dtype, static dimensions)
* Used as the primary key for static tensor buffer caching
**DynamicSignature**

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
```

* Tracks dynamic dimensions (sequence length) and literal parameters
* Secondary key for graph entry lookup
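
To make the two keys concrete, here is a hypothetical helper (not part of the module) that splits a runtime shape into the static and dynamic views, following the convention described under Sequence Length Handling below, where `-1` marks dimensions belonging to the other signature:

```python
def split_shape(shape):
    """Illustrative split of a runtime shape into its static part
    (StaticSignature) and dynamic part (DynamicSignature). -1 marks
    dimensions that belong to the other signature."""
    if len(shape) == 1:
        # 1D tensors: the whole shape is dynamic
        return (-1,), (shape[0],)
    # ND tensors: only the last dimension is static
    static = (-1,) * (len(shape) - 1) + (shape[-1],)
    dynamic = tuple(shape[:-1]) + (-1,)
    return static, dynamic


# Matches the shapes printed in the Examples section below:
static, dynamic = split_shape((2, 512, 1024))
# static  == (-1, -1, 1024)
# dynamic == (2, 512, -1)
```

Two inputs with the same static part can share buffers and graphs; a changed dynamic part only selects a different graph entry under the same static key.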
### Tensor Management

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass
class StaticTensorEntry:
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
```

* Memory reuse: existing tensor buffers are reused when possible to avoid reallocation
* Dynamic expansion: static tensors are expanded only when new input dimensions exceed the current buffer size
* Shape validation: static (non-sequence) dimensions must match between cached and new tensors
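
The reuse/expand decision implied by these bullets can be sketched as follows (an illustrative helper assuming the logic matches the description, not the module's actual code):

```python
import torch


def reuse_or_expand(buffer, new_shape):
    """Sketch of the reuse/expand decision: the static last dimension
    must match; larger dynamic dimensions force a fresh buffer."""
    if buffer.shape[-1] != new_shape[-1]:
        raise ValueError("static dimension mismatch")
    if all(b >= n for b, n in zip(buffer.shape, new_shape)):
        # Dynamic dims fit: slice a view out of the existing buffer.
        view = buffer[tuple(slice(0, n) for n in new_shape)]
        return view, False  # no reallocation; cached graphs stay usable
    # Dynamic dims grew: allocate a larger buffer (invalidates old graphs).
    expanded = torch.empty(new_shape, dtype=buffer.dtype, device=buffer.device)
    return expanded, True


buf = torch.zeros(2, 512, 1024)
_, expanded = reuse_or_expand(buf, (2, 256, 1024))    # fits: reuse a view
_, expanded2 = reuse_or_expand(buf, (2, 1024, 1024))  # grew: new buffer
```

Returning a view in the first case is what lets one oversized static buffer serve many shorter sequence lengths without recapturing.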
### Graph Management

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass
class GraphEntry:
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None  # keyed by layer number
    output_template: Any = None
```

* Graph state tracking: `GraphEntry` tracks CUDA Graph instances and validity states to control replay eligibility.
* Layer-wise organization: `OutputTemplateEntry` maps layer numbers to per-layer `GraphEntry` instances for layer-specific graph reuse.
* Output consistency: `output_template` preserves the output object structure so results are wrapped consistently during replay.
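
A graph qualifies for replay only when it has been captured and both flags are clear; a small sketch of that check (the dataclass mirrors the definition above, and the helper name is hypothetical):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class GraphEntry:
    graph: Any = None
    inconsistent: bool = False
    invalid: bool = False


def replay_eligible(entry: Optional[GraphEntry]) -> bool:
    # Replay only a captured graph that is neither inconsistent nor invalid;
    # anything else falls through to the warmup-and-capture slow path.
    return (entry is not None and entry.graph is not None
            and not entry.inconsistent and not entry.invalid)


ok = replay_eligible(GraphEntry(graph=object()))
stale = replay_eligible(GraphEntry(graph=object(), invalid=True))
```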
## Execution Flow

### Inline Replay (Fast Path)

* Extract input signatures from runtime arguments
* Look up the cached CUDA Graph via StaticSignature + DynamicSignature + layer number
* Validate graph consistency (not inconsistent/invalid)
* Reuse static tensors with dynamic slicing
* Replay the graph and return the sliced output

### Graph Capture (Slow Path)

Triggered when no valid cached graph exists or tensor expansion is needed:

* Execute the function to get output tensors (warmup)
* Ensure input signatures match post-warmup
* Expand static buffers if new shapes require it
* Capture a new CUDA Graph with static tensors
* Store the new graph and update tensor entries
* Return the warmup execution output as the final result
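
Put together, the two paths form a capture-then-replay dispatch. A deliberately simplified, CUDA-free toy version of that control flow (the real manager keys on signatures plus layer number and replays an actual `CUDAGraph`; here "capture" just caches the result):

```python
cache = {}


def run(fn, key):
    """Toy dispatch mirroring the fast/slow paths (no real CUDA graphs)."""
    entry = cache.get(key)
    if entry is not None and not entry["invalid"]:
        # Fast path: "replay" the cached graph
        entry["replays"] += 1
        return entry["output"]
    # Slow path: warmup execution, then "capture" into the cache
    output = fn()
    cache[key] = {"output": output, "invalid": False, "replays": 0}
    return output


first = run(lambda: 42, ("forward", 0))   # slow path: warmup + capture
second = run(lambda: 42, ("forward", 0))  # fast path: replay
```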
### Sequence Length Handling

* For ND tensors (ND > 1), only the last dimension is static; the leading dimensions are dynamic
* For 1D tensors (ND = 1), the single dimension is dynamic
* Buffers are expanded automatically as sequence lengths grow
* Old graphs are invalidated when tensors are expanded
## Examples

```python
import torch
import torch.nn as nn

from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if


class SimpleTransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.layer_number = 0

    @cuda_graph_enable_if(lambda: torch.cuda.is_available())
    def forward(self, x: torch.Tensor):
        attn_out, _ = self.self_attn(x, x, x)
        out = self.linear(self.layer_norm(x + attn_out))
        return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
graph_mgr = cuda_graph_mgr()

with torch.no_grad():
    # First run: no cached graph yet, triggers warmup + capture
    input_1 = torch.randn(2, 512, 1024, device=device)
    output_1 = model(input_1)
    print(f"First run (graph capture): Output shape = {output_1.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Second run, same shape: fast path, graph replay
    input_2 = torch.randn(2, 512, 1024, device=device)
    output_2 = model(input_2)
    print(f"Second run (graph replay): Output shape = {output_2.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Third run, longer sequence: static buffers expand, old graph invalidated
    input_3 = torch.randn(2, 1024, 1024, device=device)
    output_3 = model(input_3)
    print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")
    print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")

print("\nCUDA Graph Cache Details:")
print(graph_mgr.formatted_cache_str())
# StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
#   Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
#   Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
#   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
#     Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
#     Layer 0: Graph Status: Invalid
#   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
#     Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
#     Layer 0: Graph Status: Valid
```
## Limitations and Constraints

* No support for data-dependent control flow in captured functions
* Graph capture fails if the function contains CPU/GPU synchronization
* Only CUDA tensors are supported (CPU tensors trigger a fallback)
* Custom input classes must inherit from InplaceSubstituteFakeClass
* Assumes input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
* Relies on an identical function, input tensor shapes, and constants for valid graph reuse
* No support for multi-stream execution scenarios
## Best Practices

* Dynamic dimensions: use the sequence length as dimension 0 of tensors where possible
* Monitor memory usage: track `graph_mem_pool_size` and `tensor_mem_size` to avoid OOM
* Specify layer IDs: use `layer_number` to distinguish graphs across different models/layers
* LRU cache (future work): implement cache eviction to bound the total graph/tensor count