## AutoCudaGraph Design

Author: ZhiyaoCen

## Overview
AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework, designed to automate CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers.

Core Goals:
* Automate CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
* Support dynamic shape adaptation (sequence length expansion)
* Optimize memory efficiency via global memory pool and static tensor reuse
* Ensure consistency between cached graphs and runtime inputs/outputs
## Key Components


### CudaGraphMgr (Core Manager)
Singleton class managing all CUDA Graph operations:
```python
from typing import Dict, Optional
import torch

class CudaGraphMgr:
    def __init__(self):
        # Static tensor buffers and graph entries, keyed by the static input signature.
        self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
        # Shared memory pool handle for all captured graphs; None until first capture.
        self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
```

**Core Methods**
| Method | Purpose |
|---|---|
| run() | Main entry: replay a cached graph or warm up & capture a new graph |
| wrapped_graph_capture() | Capture CUDA Graph with sliced static input/output tensors |
| wrapped_graph_replay() | Replay cached CUDA Graph with sliced static tensors and output template wrapping |
| get_expanded_static_tensors() | Expand static tensors, reusing buffers if dimensionally compatible |


### Signature System

**StaticSignature**
```python
@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
```
* Encodes fixed properties of input tensors (dtype, static dimensions)
* Used as primary key for static tensor buffer caching

**DynamicSignature**
```python
@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
```
* Tracks dynamic dimensions (sequence length) and literal parameters
* Secondary key for graph entry lookup
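
To illustrate how the two signatures complement each other, here is a small sketch of how the shape tuples seen in the cache dump in the Examples section might be derived (helper names are illustrative, not part of the real API):
```python
import torch

def static_shape(t: torch.Tensor):
    # Only the last dimension of an ND tensor (ND > 1) is static; a 1D tensor is fully dynamic.
    # -1 marks a dynamic dimension.
    if t.dim() == 1:
        return (-1,)
    return tuple([-1] * (t.dim() - 1) + [t.shape[-1]])

def dynamic_shape(t: torch.Tensor):
    # Complement of static_shape: dynamic dims keep their runtime size, static dims become -1.
    if t.dim() == 1:
        return tuple(t.shape)
    return tuple(list(t.shape[:-1]) + [-1])

x = torch.empty(2, 512, 1024)
assert static_shape(x) == (-1, -1, 1024)  # cf. TensorStaticInfo.shapes in the example output
assert dynamic_shape(x) == (2, 512, -1)   # cf. TensorDynamicInfo.shapes in the example output
```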

### Tensor Management
```python
@dataclass
class StaticTensorEntry:
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
```
* Memory Reuse: Reuse existing tensor buffers when possible to avoid reallocation
* Dynamic Expansion: Only expand static tensors when new input dimensions exceed current buffer size
* Shape Validation: Ensure static dimensions (non-sequence) match between cached and new tensors
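
One plausible shape for the per-tensor reuse-or-expand decision behind get_expanded_static_tensors(), sketched under the rules above (not the actual implementation):
```python
import torch

def reuse_or_expand(cached: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    # The static (last) dimension and dtype must match, otherwise the cached buffer is unusable.
    assert cached.dtype == new.dtype and cached.shape[-1] == new.shape[-1]
    if all(c >= n for c, n in zip(cached.shape, new.shape)):
        return cached  # buffer is large enough: reuse it and let replay slice into it
    # Otherwise allocate a larger buffer; graphs captured against the old buffer must be
    # invalidated (see Graph Management and Sequence Length Handling).
    expanded = [max(c, n) for c, n in zip(cached.shape, new.shape)]
    return torch.empty(expanded, dtype=new.dtype, device=new.device)
```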


### Graph Management
```python
@dataclass
class GraphEntry:
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False

@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None
    output_template: Any = None
```
* Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
* Layer-wise Organization: OutputTemplateEntry maps dynamic signatures to per-layer GraphEntry for layer-specific graph reuse.
* Output Consistency: output_template preserves output object structure to ensure consistent result wrapping during replay.
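
Putting the three dataclasses together, a replayable graph is addressed by a three-level key. A sketch of the lookup (variable names are illustrative):
```python
static_entry = graph_mgr.cache[static_sig]                      # StaticTensorEntry: shared static buffers
template_entry = static_entry.template_entry_dict[dynamic_sig]  # OutputTemplateEntry: per dynamic shape
graph_entry = template_entry.graph_entry_dict[layer_number]     # GraphEntry: per layer
can_replay = (graph_entry.graph is not None
              and not graph_entry.inconsistent
              and not graph_entry.invalid)
```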

## Execution Flow
### Inline Replay (Fast Path)
* Extract input signatures from runtime arguments
* Look up cached CUDA Graph via StaticSignature + DynamicSignature + layer number
* Validate graph consistency (not inconsistent/invalid)
* Reuse static tensors with dynamic slicing
* Replay graph and return sliced output
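
A condensed sketch of the fast path once a valid GraphEntry has been found via the lookup above (slicing and argument structure are illustrative):
```python
def replay_fast_path(graph_entry, static_entry, runtime_inputs, runtime_output_shapes):
    # Copy runtime data into views of the (possibly larger) static input buffers.
    for static_t, t in zip(static_entry.input_tensors, runtime_inputs):
        static_t[tuple(slice(0, s) for s in t.shape)].copy_(t)
    # Launch the pre-captured kernels; they read and write the static buffers in place.
    graph_entry.graph.replay()
    # Return views of the static output buffers sliced back to the runtime shapes;
    # the cached output_template then restores the original output structure (omitted here).
    return [static_t[tuple(slice(0, s) for s in shape)]
            for static_t, shape in zip(static_entry.output_tensors, runtime_output_shapes)]
```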
### Graph Capture (Slow Path)
Triggered when no valid cached graph exists or tensor expansion is needed:
* Execute function to get output tensors
* Ensure input signatures match post-warmup
* Expand static buffers if new shapes require it
* Capture new CUDA Graph with static tensors
* Store new graph and update tensor entries
* Return warmup execution output as final result
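
The capture itself can lean on PyTorch's standard graph-capture API; a sketch assuming the warmup run has already happened and the static input buffers have been filled:
```python
import torch

def capture_graph(mgr, func, static_inputs):
    if mgr.graph_mem_pool is None:
        mgr.graph_mem_pool = torch.cuda.graph_pool_handle()  # one pool shared by all captured graphs
    graph = torch.cuda.CUDAGraph()
    # Kernels launched inside this context are recorded into `graph` instead of executed eagerly;
    # allocations made during capture are served from the shared memory pool.
    with torch.cuda.graph(graph, pool=mgr.graph_mem_pool):
        static_outputs = func(*static_inputs)
    return graph, static_outputs
```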
### Sequence Length Handling
* Only the last dimension is static for ND tensors (ND > 1)
* All dimensions are dynamic for 1D tensors (ND = 1)
* Automatic buffer expansion for increasing sequence lengths
* Invalidates old graphs when tensors are expanded


## Examples
```python
import torch
import torch.nn as nn
from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if

class SimpleTransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.layer_number = 0

    @cuda_graph_enable_if(lambda: torch.cuda.is_available())
    def forward(self, x: torch.Tensor):
        attn_out, _ = self.self_attn(x, x, x)
        out = self.linear(self.layer_norm(x + attn_out))
        return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
graph_mgr = cuda_graph_mgr()

with torch.no_grad():
    input_1 = torch.randn(2, 512, 1024, device=device)
    output_1 = model(input_1)
    print(f"First run (graph capture): Output shape = {output_1.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    input_2 = torch.randn(2, 512, 1024, device=device)
    output_2 = model(input_2)
    print(f"Second run (graph replay): Output shape = {output_2.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    input_3 = torch.randn(2, 1024, 1024, device=device)
    output_3 = model(input_3)
    print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")
    print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")

    print("\nCUDA Graph Cache Details:")
    print(graph_mgr.formatted_cache_str())

    # StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
    #   Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
    #   Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
    #   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
    #     Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
    #     Layer 0: Graph Status: Invalid
    #   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
    #     Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
    #     Layer 0: Graph Status: Valid
```

## Limitations and Constraints
* No support for data-dependent control flow in captured functions
* Graph capture fails if function contains CPU/GPU synchronization
* Only supports CUDA tensors (CPU tensors trigger fallback)
* Custom input classes must inherit from InplaceSubstituteFakeClass
* Assumes input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
* Relies on identical function, input tensor shapes, and constants for valid graph reuse
* No support for multi-stream execution scenarios

## Best Practices
* Dynamic Dimensions: Place the sequence length in dimension 0 of tensors where possible
* Monitor Memory Usage: Track graph_mem_pool_size and tensor_mem_size to avoid OOM
* Specify Layer IDs: Use layer_number to distinguish graphs across different models/layers
* LRU Cache (Future): Implement cache eviction to limit total graph/tensor count