## AutoCudaGraph Design
Author: ZhiyaoCen
## Overview
AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework, designed to automate CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers.
Core Goals:
* Automate CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
* Support dynamic shape adaptation (sequence length expansion)
* Optimize memory efficiency via global memory pool and static tensor reuse
* Ensure consistency between cached graphs and runtime inputs/outputs
## Key Components
### CudaGraphMgr (Core Manager)
Singleton class managing all CUDA Graph operations:
```python
class CudaGraphMgr:
    def __init__(self):
        self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
        self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
```
**Core Methods**
| Method | Purpose |
|--------|---------|
| `run()` | Main entry: replay a cached graph, or warm up and capture a new graph |
| `wrapped_graph_capture()` | Capture a CUDA Graph with sliced static input/output tensors |
| `wrapped_graph_replay()` | Replay a cached CUDA Graph with sliced static tensors and output template wrapping |
| `get_expanded_static_tensors()` | Expand static tensors, reusing buffers if dimensionally compatible |
### Signature System
**StaticSignature**
```python
@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
```
* Encodes fixed properties of input tensors (dtype, static dimensions)
* Used as primary key for static tensor buffer caching
**DynamicSignature**
```python
@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
```
* Tracks dynamic dimensions (sequence length) and literal parameters
* Secondary key for graph entry lookup
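The two signatures behave like nested dictionary keys. The following is a minimal sketch of that lookup, assuming simplified stand-in classes: `StaticKey`, `DynamicKey`, `Cache`, and `lookup` are hypothetical names used for illustration only, and the graph handle is a plain string rather than a real CUDA Graph.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


# Hypothetical stand-ins for StaticSignature / DynamicSignature.
@dataclass(unsafe_hash=True)
class StaticKey:
    func_name: str = ""
    static_shapes: Tuple[Tuple[int, ...], ...] = ()


@dataclass(unsafe_hash=True)
class DynamicKey:
    dynamic_shapes: Tuple[Tuple[int, ...], ...] = ()
    literals: Tuple = ()


# static key -> dynamic key -> layer number -> graph handle (a string here)
Cache = Dict[StaticKey, Dict[DynamicKey, Dict[int, str]]]


def lookup(cache: Cache, s: StaticKey, d: DynamicKey, layer: int) -> Optional[str]:
    """Return the cached graph handle, or None on a miss at any level."""
    return cache.get(s, {}).get(d, {}).get(layer)


cache: Cache = {}
s_key = StaticKey("Layer.forward", ((-1, -1, 1024),))
d_key = DynamicKey(((2, 512, -1),))
cache.setdefault(s_key, {}).setdefault(d_key, {})[0] = "graph-for-seq-512"

# An equal but separately constructed key pair hits the same entry.
hit = lookup(cache, StaticKey("Layer.forward", ((-1, -1, 1024),)), DynamicKey(((2, 512, -1),)), 0)
# A different sequence length misses at the dynamic level.
miss = lookup(cache, s_key, DynamicKey(((2, 1024, -1),)), 0)
```

Because `unsafe_hash=True` derives the hash from the field values, two signatures built from equal inputs map to the same cache slot.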
### Tensor Management
```python
@dataclass
class StaticTensorEntry:
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
```
* Memory Reuse: Reuse existing tensor buffers when possible to avoid reallocation
* Dynamic Expansion: Only expand static tensors when new input dimensions exceed current buffer size
* Shape Validation: Ensure static dimensions (non-sequence) match between cached and new tensors
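The three rules above can be sketched on plain shape tuples. This is an illustration under stated assumptions, not the MagiCompiler implementation: `plan_buffer_shape` is a hypothetical helper, and growing each dynamic dimension to the element-wise maximum is one plausible expansion policy.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def plan_buffer_shape(buffer: Shape, new: Shape) -> Tuple[Shape, bool]:
    """Return (target_shape, needs_expansion) for a static buffer.

    Assumption: for tensors with more than one dimension, the last dimension
    is static and must match exactly; every other dimension may grow.
    """
    if len(buffer) != len(new):
        raise ValueError(f"rank mismatch: {buffer} vs {new}")
    if len(new) > 1 and buffer[-1] != new[-1]:
        raise ValueError(f"static dimension mismatch: {buffer[-1]} vs {new[-1]}")
    # Grow only where the new input exceeds the current buffer.
    target = tuple(max(b, n) for b, n in zip(buffer, new))
    return target, target != buffer


# Longer sequence: the buffer must grow.
grown = plan_buffer_shape((2, 512, 1024), (2, 1024, 1024))
# Shorter sequence: the existing buffer is reused via slicing.
reused = plan_buffer_shape((2, 512, 1024), (2, 256, 1024))
```

When no expansion is needed, the smaller input can occupy a leading slice of the existing buffer, which is what makes reuse possible without reallocation.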
### Graph Management
```python
@dataclass
class GraphEntry:
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None
    output_template: Any = None
```
* Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
* Layer-wise Organization: OutputTemplateEntry maps dynamic signatures to per-layer GraphEntry for layer-specific graph reuse.
* Output Consistency: output_template preserves output object structure to ensure consistent result wrapping during replay.
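A replay-eligibility check based on these state flags might look like the sketch below. `GraphEntrySketch` and `can_replay` are hypothetical names; the exact conditions under which MagiCompiler sets `inconsistent` or `invalid` are not specified here, so the check simply treats either flag as blocking replay.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class GraphEntrySketch:
    graph: Optional[Any] = None  # stands in for a torch.cuda.CUDAGraph
    inconsistent: bool = False
    invalid: bool = False


def can_replay(entry: Optional[GraphEntrySketch]) -> bool:
    """A graph is replayable only if it exists and neither flag is set."""
    if entry is None or entry.graph is None:
        return False
    return not (entry.inconsistent or entry.invalid)


fresh = GraphEntrySketch(graph=object())
stale = GraphEntrySketch(graph=object(), invalid=True)
```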
## Execution Flow
### Inline Replay (Fast Path)
* Extract input signatures from runtime arguments
* Look up cached CUDA Graph via StaticSignature + DynamicSignature + layer number
* Check that the graph entry is flagged neither `inconsistent` nor `invalid`
* Reuse static tensors with dynamic slicing
* Replay graph and return sliced output
### Graph Capture (Slow Path)
Triggered when no valid cached graph exists or tensor expansion is needed:
* Execute the function once (warmup) to obtain output tensors
* Verify that the input signatures are unchanged after warmup
* Expand static buffers if new shapes require it
* Capture new CUDA Graph with static tensors
* Store new graph and update tensor entries
* Return warmup execution output as final result
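The split between the fast path and the slow path can be sketched as a small dispatcher. This is a toy model under loud assumptions: `graphs`, `trace`, `run`, and `double` are hypothetical names, "capture" just stores the callable, and no CUDA work takes place.

```python
from typing import Callable, Dict, Hashable, List

# Hypothetical cache: lookup key -> captured "graph" (a plain callable here).
graphs: Dict[Hashable, Callable[[int], int]] = {}
trace: List[str] = []  # records which path each call took


def run(key: Hashable, fn: Callable[[int], int], x: int) -> int:
    graph = graphs.get(key)
    if graph is not None:
        # Fast path: a usable graph is cached, so replay it.
        trace.append("replay")
        return graph(x)
    # Slow path: warm up, "capture", store, and return the warmup output.
    trace.append("capture")
    warmup_out = fn(x)
    graphs[key] = fn
    return warmup_out


def double(v: int) -> int:
    return v * 2


first = run(("forward", 512, 0), double, 3)   # miss: capture path
second = run(("forward", 512, 0), double, 4)  # hit: replay path
third = run(("forward", 1024, 0), double, 5)  # new key: capture path
```

Note that the slow path returns the warmup result directly, mirroring the last step of the list above, so the first call never pays for a replay in addition to the capture.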
### Sequence Length Handling
* For tensors with more than one dimension, only the last dimension is static
* For 1D tensors, the single dimension is dynamic
* Static buffers are expanded automatically as sequence lengths increase
* Old graphs are invalidated when tensors are expanded
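The static/dynamic split described above can be illustrated on shape tuples. `split_shape` is a hypothetical helper, not part of MagiCompiler. The `-1` placeholder follows the convention visible in the cache dump in the Examples section; applying it to the 1D case is an assumption.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_shape(shape: Shape) -> Tuple[Shape, Shape]:
    """Split a shape into (static, dynamic) views, masking the other part with -1."""
    if len(shape) > 1:
        # Only the last dimension is static; the rest are dynamic.
        static = (-1,) * (len(shape) - 1) + (shape[-1],)
        dynamic = tuple(shape[:-1]) + (-1,)
    else:
        # 1D tensors: every dimension is dynamic.
        static = (-1,) * len(shape)
        dynamic = tuple(shape)
    return static, dynamic


static_3d, dynamic_3d = split_shape((2, 512, 1024))
static_1d, dynamic_1d = split_shape((7,))
```

Under this split, inputs of shape `(2, 512, 1024)` and `(2, 1024, 1024)` share one static view and differ only in their dynamic view.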
## Examples
```python
import torch
import torch.nn as nn
from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if


class SimpleTransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.layer_number = 0

    @cuda_graph_enable_if(lambda: torch.cuda.is_available())
    def forward(self, x: torch.Tensor):
        attn_out, _ = self.self_attn(x, x, x)
        out = self.linear(self.layer_norm(x + attn_out))
        return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
graph_mgr = cuda_graph_mgr()

with torch.no_grad():
    # First call with a new shape: warm up and capture a graph.
    input_1 = torch.randn(2, 512, 1024, device=device)
    output_1 = model(input_1)
    print(f"First run (graph capture): Output shape = {output_1.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Same shape again: replay the cached graph.
    input_2 = torch.randn(2, 512, 1024, device=device)
    output_2 = model(input_2)
    print(f"Second run (graph replay): Output shape = {output_2.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    # Longer sequence: static tensors are expanded and a new graph is captured.
    input_3 = torch.randn(2, 1024, 1024, device=device)
    output_3 = model(input_3)
    print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")
print("\nCUDA Graph Cache Details:")
print(graph_mgr.formatted_cache_str())

# Example output:
# StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
# Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
# Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
# Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
# Layer 0: Graph Status: Invalid
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
# Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
# Layer 0: Graph Status: Valid
```
## Limitations and Constraints
* No support for data-dependent control flow in captured functions
* Graph capture fails if the function contains CPU/GPU synchronization
* Only supports CUDA tensors (CPU tensors trigger fallback)
* Custom input classes must inherit from InplaceSubstituteFakeClass
* Assumes input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
* Relies on an identical function, input tensor shapes, and constants for valid graph reuse
* No support for multi-stream execution scenarios
## Best Practices
* Dynamic Dimensions: Use the sequence length as dimension 0 of input tensors where possible
* Monitor Memory Usage: Track graph_mem_pool_size and tensor_mem_size to avoid OOM
* Specify Layer IDs: Use layer_number to distinguish graphs across different models/layers
* LRU Cache (Future): Implement cache eviction to limit total graph/tensor count
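
Since cache eviction is listed as future work, the following is only a hypothetical sketch of what an LRU policy could look like, built on `collections.OrderedDict`. `LRUGraphCache` is an invented name, and a real implementation would also need to release the evicted CUDA Graphs and static buffers, which this sketch does not attempt.

```python
from collections import OrderedDict
from typing import Any, Hashable, Optional


class LRUGraphCache:
    """Keep at most `capacity` entries, evicting the least recently used."""

    def __init__(self, capacity: int):
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.capacity = capacity
        self._entries: "OrderedDict[Hashable, Any]" = OrderedDict()

    def get(self, key: Hashable) -> Optional[Any]:
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

    def put(self, key: Hashable, value: Any) -> None:
        self._entries[key] = value
        self._entries.move_to_end(key)
        while len(self._entries) > self.capacity:
            self._entries.popitem(last=False)  # drop the oldest entry

    def __len__(self) -> int:
        return len(self._entries)


lru = LRUGraphCache(capacity=2)
lru.put("seq-512", "graph-a")
lru.put("seq-1024", "graph-b")
lru.get("seq-512")              # touch: "seq-1024" is now the oldest
lru.put("seq-2048", "graph-c")  # evicts "seq-1024"
```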
