## AutoCudaGraph Design
Author: ZhiyaoCen
## Overview
AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework, designed to automate CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers.

Core Goals:
* Automate CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
* Support dynamic shape adaptation (sequence length expansion)
* Optimize memory efficiency via global memory pool and static tensor reuse
* Ensure consistency between cached graphs and runtime inputs/outputs
## Key Components
### CudaGraphMgr (Core Manager)
Singleton class managing all CUDA Graph operations:
```python
from typing import Dict, Optional

import torch


class CudaGraphMgr:
    def __init__(self):
        # Static tensor buffers and graph entries, keyed by input signature
        self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
        # Shared memory pool handle reused across all captured graphs
        self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
```
**Core Methods**
| Method | Purpose |
|---------------------------------|----------------------------------------|
| run() | Main entry: Replay cached graph or warm up & capture new graph|
| wrapped_graph_capture() | Capture CUDA Graph with sliced static input/output tensors |
| wrapped_graph_replay() | Replay cached CUDA Graph with sliced static tensors and output template wrapping |
| get_expanded_static_tensors() | Expand static tensors, reuse buffers if dimensionally compatible|
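The dispatch performed by `run()` — replay on a cache hit, warm up and capture otherwise — can be sketched in plain Python. `MiniGraphMgr` and its signature tuple are hypothetical stand-ins for illustration; only the method names in the table above come from the real class.

```python
# Hypothetical sketch of the run() dispatch: replay on a cache hit,
# otherwise warm up and capture. Plain Python stands in for CUDA here.

from typing import Any, Callable, Dict, Tuple


class MiniGraphMgr:
    """Simplified stand-in for CudaGraphMgr's dispatch logic."""

    def __init__(self):
        # signature -> "captured graph" (here just the function itself)
        self.cache: Dict[Tuple, Callable] = {}
        self.captures = 0
        self.replays = 0

    def run(self, func: Callable, x) -> Any:
        # Signature combines the function name and the input "shape"
        sig = (func.__name__, tuple(getattr(x, "shape", (len(x),))))
        if sig in self.cache:
            self.replays += 1      # fast path: replay the cached graph
            return self.cache[sig](x)
        self.captures += 1         # slow path: warm up, then "capture"
        self.cache[sig] = func     # store the captured graph for reuse
        return func(x)


def double(xs):
    return [2 * v for v in xs]


mgr = MiniGraphMgr()
out1 = mgr.run(double, [1, 2, 3])  # first call: capture
out2 = mgr.run(double, [4, 5, 6])  # same signature: replay
```

The second call shares the first call's signature (same function, same shape), so it takes the fast path even though the values differ.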
### Signature System
**StaticSignature**
```python
@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
```
* Encodes fixed properties of input tensors (dtype, static dimensions)
* Used as primary key for static tensor buffer caching
**DynamicSignature**
```python
@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
```
* Tracks dynamic dimensions (sequence length) and literal parameters
* Secondary key for graph entry lookup
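How a runtime shape could split into the two signature kinds can be sketched with the `-1` placeholders visible in the cache dump in the example below; `split_shape` is a hypothetical helper, not part of the real API.

```python
# A sketch of deriving static vs. dynamic shape tuples from a runtime shape,
# following the rule that only the last dimension of an ND tensor is static.
# -1 marks dimensions that belong to the other signature kind.

from typing import Tuple


def split_shape(shape: Tuple[int, ...]) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (static_shape, dynamic_shape) for one input tensor."""
    if len(shape) == 1:
        # 1D tensors: the single dimension is dynamic
        return (-1,), (shape[0],)
    # ND tensors: only the last dimension is static, the rest are dynamic
    static = tuple([-1] * (len(shape) - 1) + [shape[-1]])
    dynamic = tuple(list(shape[:-1]) + [-1])
    return static, dynamic


static, dynamic = split_shape((2, 512, 1024))
# static  -> part of StaticSignature  (hidden size pinned)
# dynamic -> part of DynamicSignature (batch and sequence length recorded)
```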
### Tensor Management
```python
@dataclass
class StaticTensorEntry:
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
```
* Memory Reuse: Reuse existing tensor buffers when possible to avoid reallocation
* Dynamic Expansion: Only expand static tensors when new input dimensions exceed current buffer size
* Shape Validation: Ensure static dimensions (non-sequence) match between cached and new tensors
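The three rules above amount to a per-tensor decision: reuse, expand, or reject. A minimal sketch, assuming same-rank tensors with the last dimension static (`plan_buffer` is an illustrative name, not the real method):

```python
# Hypothetical sketch of the buffer-reuse decision in
# get_expanded_static_tensors(): keep the existing static buffer when the new
# shape fits, expand when a dynamic dimension grows, reject on a static
# (last-dimension) mismatch.

from typing import Optional, Tuple


def plan_buffer(cached: Optional[Tuple[int, ...]],
                new: Tuple[int, ...]) -> Tuple[str, Tuple[int, ...]]:
    """Return (action, buffer_shape) for a new input shape."""
    if cached is None:
        return "allocate", new
    if cached[-1] != new[-1]:
        # Static (non-sequence) dimension mismatch: not reusable
        raise ValueError("static dimension mismatch")
    if all(n <= c for n, c in zip(new, cached)):
        return "reuse", cached     # new input fits inside the old buffer
    # Expand each dynamic dimension to the max of old and new
    grown = tuple(max(n, c) for n, c in zip(new, cached))
    return "expand", grown


shrink = plan_buffer((2, 1024, 1024), (2, 512, 1024))   # fits: reuse
grow = plan_buffer((2, 512, 1024), (2, 1024, 1024))     # longer seq: expand
```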
### Graph Management
```python
@dataclass
class GraphEntry:
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None
    output_template: Any = None
```
* Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
* Layer-wise Organization: OutputTemplateEntry maps dynamic signatures to per-layer GraphEntry for layer-specific graph reuse.
* Output Consistency: output_template preserves output object structure to ensure consistent result wrapping during replay.
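The resulting lookup chain — StaticSignature selects the buffers, DynamicSignature selects the output template, the layer number selects the concrete `GraphEntry` — can be traced with simplified versions of the dataclasses (string keys replace the real hashable signature objects):

```python
# Illustrative walk of the two-level cache. String keys stand in for the
# real StaticSignature/DynamicSignature dataclasses.

from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class GraphEntry:
    graph: Any = None
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = field(default_factory=dict)
    output_template: Any = None


@dataclass
class StaticTensorEntry:
    template_entry_dict: Dict[str, OutputTemplateEntry] = field(default_factory=dict)


cache: Dict[str, StaticTensorEntry] = {
    "static_sig": StaticTensorEntry(
        template_entry_dict={
            "dyn_sig": OutputTemplateEntry(
                graph_entry_dict={0: GraphEntry(graph="captured-graph")}
            )
        }
    )
}


def lookup(static_sig: str, dyn_sig: str, layer: int) -> Optional[GraphEntry]:
    """Return a replayable GraphEntry, or None if capture is required."""
    entry = cache.get(static_sig)
    tmpl = entry.template_entry_dict.get(dyn_sig) if entry else None
    g = tmpl.graph_entry_dict.get(layer) if tmpl else None
    if g and not g.inconsistent and not g.invalid:
        return g
    return None
```

A miss at any of the three levels, or an `inconsistent`/`invalid` flag, falls through to the capture path.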
## Execution Flow
### Inline Replay (Fast Path)
* Extract input signatures from runtime arguments
* Look up cached CUDA Graph via StaticSignature + DynamicSignature + layer number
* Validate graph consistency (not inconsistent/invalid)
* Reuse static tensors with dynamic slicing
* Replay graph and return sliced output
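The slicing in the last two steps can be sketched with lists standing in for CUDA tensors: the runtime input is copied into a slice of the larger static buffer, and the output is sliced back to the runtime length. The doubling step merely emulates what `graph.replay()` would write into the static output buffer.

```python
# Minimal sketch of the fast path's static-tensor slicing (lists stand in
# for pre-allocated CUDA tensors; doubling emulates the replayed kernels).

def replay_with_slicing(static_in, static_out, runtime_in):
    n = len(runtime_in)
    static_in[:n] = runtime_in      # copy input into the static buffer slice
    # graph.replay() would run here, writing into static_out; emulate it:
    static_out[:n] = [2 * v for v in static_in[:n]]
    return static_out[:n]           # return only the valid output slice


static_in = [0] * 8                 # pre-allocated static buffers (capacity 8)
static_out = [0] * 8
out = replay_with_slicing(static_in, static_out, [1, 2, 3])
```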
### Graph Capture (Slow Path)
Triggered when no valid cached graph exists or tensor expansion is needed:
* Execute function to get output tensors
* Ensure input signatures match post-warmup
* Expand static buffers if new shapes require it
* Capture new CUDA Graph with static tensors
* Store new graph and update tensor entries
* Return warmup execution output as final result
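The ordering of these steps matters: the caller receives the warmup output, not a replayed one. A plain-Python sketch of that ordering (the real implementation presumably wraps PyTorch's `torch.cuda.graph` context manager with the shared `graph_mem_pool`; `capture_path` and its cache layout are illustrative only):

```python
# Hypothetical sketch of the slow path's ordering. Plain Python stands in
# for CUDA graph capture; the key point is step 6: the warmup output is
# what gets returned to the caller.

def capture_path(func, runtime_in, cache, key):
    warmup_out = func(runtime_in)   # 1./2. warmup execution + signature check
    static_in = list(runtime_in)    # 3. (re)build/expand static buffers
    static_out = list(warmup_out)
    cache[key] = {                  # 4./5. "capture" and store the graph
        "static_in": static_in,
        "static_out": static_out,
    }
    return warmup_out               # 6. return the warmup result


cache = {}
result = capture_path(lambda xs: [v + 1 for v in xs], [1, 2], cache, "sig")
```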
### Sequence Length Handling
* Only the last dimension is static for ND tensors (ND > 1)
* The single dimension is dynamic for 1D tensors (ND = 1)
* Automatic buffer expansion for increasing sequence lengths
* Invalidates old graphs when tensors are expanded
## Examples
```python
import torch
import torch.nn as nn

from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if


class SimpleTransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.layer_number = 0

    @cuda_graph_enable_if(lambda: torch.cuda.is_available())
    def forward(self, x: torch.Tensor):
        attn_out, _ = self.self_attn(x, x, x)
        out = self.linear(self.layer_norm(x + attn_out))
        return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
graph_mgr = cuda_graph_mgr()

with torch.no_grad():
    # First run: no cached graph yet, triggers warmup and capture
    input_1 = torch.randn(2, 512, 1024, device=device)
    output_1 = model(input_1)
    print(f"First run (graph capture): Output shape = {output_1.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Second run: identical shapes, replays the cached graph
    input_2 = torch.randn(2, 512, 1024, device=device)
    output_2 = model(input_2)
    print(f"Second run (graph replay): Output shape = {output_2.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Third run: longer sequence, expands static buffers and recaptures
    input_3 = torch.randn(2, 1024, 1024, device=device)
    output_3 = model(input_3)
    print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")
print("\nCUDA Graph Cache Details:")
print(graph_mgr.formatted_cache_str())

# StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
# Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
# Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
# Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
# Layer 0: Graph Status: Invalid
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
# Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
# Layer 0: Graph Status: Valid
```
## Limitations and Constraints
* No support for data-dependent control flow in captured functions
* Graph capture fails if the captured function contains CPU/GPU synchronization
* Only supports CUDA tensors (CPU tensors trigger fallback)
* Custom input classes must inherit from InplaceSubstituteFakeClass
* Assumes input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
* Relies on an identical function, input tensor shapes, and constants for valid graph reuse
* No support for multi-stream execution scenarios
## Best Practices
* Dynamic Dimensions: Place the sequence length in dimension 0 of tensors where possible
* Monitor Memory Usage: Track graph_mem_pool_size and tensor_mem_size to avoid OOM
* Specify Layer IDs: Use layer_number to distinguish graphs across different models/layers
* LRU Cache (Future): Implement cache eviction to limit total graph/tensor count
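The LRU eviction mentioned above could be sketched with an `OrderedDict`; capacity, keying, and eviction policy here are assumptions about a future feature, not current behavior.

```python
# Speculative sketch of LRU eviction for cached graphs. Capacity and the
# string keys are illustrative assumptions.

from collections import OrderedDict


class LRUGraphCache:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self._entries: "OrderedDict[str, object]" = OrderedDict()

    def get(self, key):
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)   # mark as most recently used
        return self._entries[key]

    def put(self, key, graph):
        self._entries[key] = graph
        self._entries.move_to_end(key)
        if len(self._entries) > self.capacity:
            # Evict the least recently used graph
            # (the real version would also free its static tensors)
            self._entries.popitem(last=False)


cache = LRUGraphCache(capacity=2)
cache.put("a", "graph-a")
cache.put("b", "graph-b")
cache.get("a")                           # "a" becomes most recent
cache.put("c", "graph-c")                # capacity exceeded: evicts "b"
```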