
AutoCudaGraph Design

Author: ZhiyaoCen

Overview

AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework, designed to automate CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers.

Core Goals:

  • Automate CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
  • Support dynamic shape adaptation (sequence length expansion)
  • Optimize memory efficiency via global memory pool and static tensor reuse
  • Ensure consistency between cached graphs and runtime inputs/outputs

Key Components

CudaGraphMgr (Core Manager)

Singleton class managing all CUDA Graph operations:

class CudaGraphMgr:
    def __init__(self):
        self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
        self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None

Core Methods

| Method | Purpose |
| --- | --- |
| run() | Main entry: replay a cached graph, or warm up and capture a new graph |
| wrapped_graph_capture() | Capture a CUDA Graph with sliced static input/output tensors |
| wrapped_graph_replay() | Replay a cached CUDA Graph with sliced static tensors and output template wrapping |
| get_expanded_static_tensors() | Expand static tensors, reusing buffers if dimensionally compatible |

Signature System

StaticSignature

@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
  • Encodes fixed properties of input tensors (dtype, static dimensions)
  • Used as primary key for static tensor buffer caching

DynamicSignature

@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
  • Tracks dynamic dimensions (sequence length) and literal parameters
  • Secondary key for graph entry lookup
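
To make the split between the two signatures concrete, the sketch below masks tensor dimensions with -1 in the same way as the cache dump in the Examples section. It is a minimal illustration under assumptions: `split_shape` is a hypothetical helper, not part of MagiCompiler, and the treatment of 1D tensors (dynamic shape equal to the full shape) is assumed from the rule in Sequence Length Handling.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_shape(shape: Shape) -> Tuple[Shape, Shape]:
    """Hypothetical helper: return (static_shape, dynamic_shape).

    Dimensions that belong to the other signature are masked with -1.
    """
    if len(shape) <= 1:
        # 1D tensors: every dimension is treated as dynamic (assumption).
        static = tuple(-1 for _ in shape)
        dynamic = tuple(shape)
    else:
        # ND tensors: only the last dimension is static.
        static = tuple(-1 for _ in shape[:-1]) + (shape[-1],)
        dynamic = tuple(shape[:-1]) + (-1,)
    return static, dynamic


static_a, dynamic_a = split_shape((2, 512, 1024))
static_b, dynamic_b = split_shape((2, 1024, 1024))

print(static_a, dynamic_a)   # (-1, -1, 1024) (2, 512, -1)
print(static_a == static_b)  # True: both inputs share one static buffer entry
print(dynamic_a == dynamic_b)  # False: each sequence length gets its own graph entry
```

Two inputs that differ only in batch size or sequence length therefore map to the same primary key and to different secondary keys.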

Tensor Management

@dataclass
class StaticTensorEntry:
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
  • Memory Reuse: Reuse existing tensor buffers when possible to avoid reallocation
  • Dynamic Expansion: Only expand static tensors when new input dimensions exceed current buffer size
  • Shape Validation: Ensure static dimensions (non-sequence) match between cached and new tensors
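
The three rules above can be sketched on shapes alone. This is an assumption-laden illustration, not the logic of get_expanded_static_tensors(): `expanded_shape` is a hypothetical helper, and growing each dynamic dimension to the element-wise maximum is an assumed policy.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expanded_shape(buffer_shape: Shape, new_shape: Shape) -> Tuple[Shape, bool]:
    """Hypothetical helper: return (required buffer shape, needs_reallocation)."""
    if len(buffer_shape) != len(new_shape):
        raise ValueError("rank mismatch between cached buffer and new tensor")

    # For tensors with more than one dimension, the last dimension is static.
    rank = len(new_shape)
    n_dynamic = rank - 1 if rank > 1 else rank

    # Shape validation: static dimensions must match exactly.
    if buffer_shape[n_dynamic:] != new_shape[n_dynamic:]:
        raise ValueError("static dimension mismatch")

    # Dynamic expansion: grow only where the new input exceeds the buffer.
    grown = tuple(
        max(b, n) for b, n in zip(buffer_shape[:n_dynamic], new_shape[:n_dynamic])
    )
    target = grown + tuple(buffer_shape[n_dynamic:])

    # Memory reuse: an unchanged target shape means the buffer is kept as-is.
    return target, target != buffer_shape


print(expanded_shape((2, 512, 1024), (2, 256, 1024)))   # ((2, 512, 1024), False)
print(expanded_shape((2, 512, 1024), (2, 1024, 1024)))  # ((2, 1024, 1024), True)
```

A shorter sequence fits in the existing buffer and is served from a slice of it; only a longer one forces a reallocation.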

Graph Management

@dataclass
class GraphEntry:
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False

@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None
    output_template: Any = None
  • Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
  • Layer-wise Organization: Each DynamicSignature maps to an OutputTemplateEntry, whose graph_entry_dict maps layer numbers to per-layer GraphEntry objects for layer-specific graph reuse.
  • Output Consistency: output_template preserves output object structure to ensure consistent result wrapping during replay.
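
The nesting of these entries can be illustrated with a simplified lookup. The dataclasses below are reduced stand-ins for the ones shown above (tensor fields omitted, `default_factory` used instead of `None` defaults, strings used in place of real signatures and graphs), and `find_replayable` is a hypothetical helper rather than a CudaGraphMgr method.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Hashable, Optional


@dataclass
class GraphEntry:
    graph: Optional[Any] = None  # stands in for torch.cuda.CUDAGraph
    inconsistent: bool = False
    invalid: bool = False


@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = field(default_factory=dict)
    output_template: Any = None


@dataclass
class StaticTensorEntry:
    template_entry_dict: Dict[Hashable, OutputTemplateEntry] = field(default_factory=dict)


def find_replayable(
    cache: Dict[Hashable, StaticTensorEntry],
    static_sig: Hashable,
    dynamic_sig: Hashable,
    layer_number: int,
) -> Optional[GraphEntry]:
    """Walk static signature -> dynamic signature -> layer number."""
    static_entry = cache.get(static_sig)
    if static_entry is None:
        return None
    template_entry = static_entry.template_entry_dict.get(dynamic_sig)
    if template_entry is None:
        return None
    entry = template_entry.graph_entry_dict.get(layer_number)
    # Only a captured graph that is neither inconsistent nor invalid may be replayed.
    if entry is None or entry.graph is None or entry.inconsistent or entry.invalid:
        return None
    return entry


cache = {
    "static-sig": StaticTensorEntry(
        template_entry_dict={
            "dynamic-sig": OutputTemplateEntry(
                graph_entry_dict={
                    0: GraphEntry(graph="graph-0"),
                    1: GraphEntry(graph="graph-1", invalid=True),
                }
            )
        }
    )
}

hit = find_replayable(cache, "static-sig", "dynamic-sig", 0)
miss = find_replayable(cache, "static-sig", "dynamic-sig", 1)
print(hit.graph, miss)  # graph-0 None
```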

Execution Flow

Inline Replay (Fast Path)

  • Extract input signatures from runtime arguments
  • Look up cached CUDA Graph via StaticSignature + DynamicSignature + layer number
  • Validate graph consistency (not inconsistent/invalid)
  • Reuse static tensors with dynamic slicing
  • Replay graph and return sliced output
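
The last two steps rely on the fact that a replayed graph takes no arguments and always reads and writes the same buffers. The following pure-Python simulation illustrates that idea only: plain lists stand in for static CUDA tensors, and `FakeGraph` and `replay_with` are hypothetical names, not part of MagiCompiler or PyTorch.

```python
from typing import List

CAPACITY = 8  # number of slots the static buffers can hold
static_in: List[float] = [0.0] * CAPACITY
static_out: List[float] = [0.0] * CAPACITY


class FakeGraph:
    """Stand-in for a captured graph bound to the first n buffer slots."""

    def __init__(self, n: int) -> None:
        self.n = n

    def replay(self) -> None:
        # Fixed work on fixed buffers; no arguments, as with a real replay.
        for i in range(self.n):
            static_out[i] = static_in[i] * 2.0


def replay_with(graph: FakeGraph, x: List[float]) -> List[float]:
    n = len(x)
    if n != graph.n:
        raise ValueError("input length does not match the captured slice")
    static_in[:n] = x      # copy the runtime input into the static slice
    graph.replay()         # replay against the static buffers
    # A list slice is a copy; a real tensor slice would be a view of the buffer.
    return static_out[:n]


graph_4 = FakeGraph(4)
print(replay_with(graph_4, [1.0, 2.0, 3.0, 4.0]))      # [2.0, 4.0, 6.0, 8.0]
print(replay_with(graph_4, [10.0, 20.0, 30.0, 40.0]))  # [20.0, 40.0, 60.0, 80.0]
```

Because a real output slice would alias the static buffer, a later replay overwrites it, which is why the Limitations section warns against reusing these tensors externally.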

Graph Capture (Slow Path)

Triggered when no valid cached graph exists or tensor expansion is needed:

  • Execute the function once (warmup) to obtain output tensors
  • Verify that the input signatures still match after warmup
  • Expand static buffers if new shapes require it
  • Capture new CUDA Graph with static tensors
  • Store new graph and update tensor entries
  • Return warmup execution output as final result
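
The control flow of the two paths can be sketched without any CUDA dependency. `TinyGraphMgr` below is a hypothetical, heavily simplified stand-in for CudaGraphMgr.run(): the "captured graph" is just the Python function itself, the key is a flat signature, and buffer expansion is omitted.

```python
from typing import Any, Callable, Dict, Hashable, Tuple


class TinyGraphMgr:
    """Hypothetical stand-in that mimics the replay/capture dispatch."""

    def __init__(self) -> None:
        self.cache: Dict[Hashable, Callable[..., Any]] = {}
        self.captures = 0
        self.replays = 0

    def run(
        self,
        func: Callable[..., Any],
        args: Tuple[Any, ...],
        signature_of: Callable[[Tuple[Any, ...]], Hashable],
    ) -> Any:
        key = signature_of(args)
        graph = self.cache.get(key)
        if graph is not None:
            # Fast path: a valid cached entry exists.
            self.replays += 1
            return graph(*args)

        # Slow path: warm up by executing the function eagerly.
        warmup_output = func(*args)
        # The signature must be unchanged after warmup.
        if signature_of(args) != key:
            raise RuntimeError("input signature changed during warmup")
        # "Capture" and store; a real manager would record a CUDA Graph here.
        self.cache[key] = func
        self.captures += 1
        # The warmup output is returned as the final result.
        return warmup_output


def double_all(xs):
    return [2 * x for x in xs]


def lengths(args):
    return tuple(len(a) for a in args)


mgr = TinyGraphMgr()
print(mgr.run(double_all, ([1, 2, 3],), lengths), mgr.captures, mgr.replays)  # [2, 4, 6] 1 0
print(mgr.run(double_all, ([4, 5, 6],), lengths), mgr.captures, mgr.replays)  # [8, 10, 12] 1 1
```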

Sequence Length Handling

  • Only the last dimension is static for N-D tensors (N > 1); all other dimensions are dynamic
  • All dimensions are dynamic for 1D tensors (N = 1)
  • Automatic buffer expansion for increasing sequence lengths
  • Invalidates old graphs when tensors are expanded
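
A captured CUDA Graph replays against the memory addresses recorded at capture time, so graphs captured on a buffer that has since been reallocated can no longer be replayed safely. The sketch below shows one way the invalidation step could look. `invalidate_all` and the reduced `GraphEntry` are hypothetical simplifications, keyed here by dynamic shape and layer number.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class GraphEntry:
    invalid: bool = False


GraphTable = Dict[Tuple[int, ...], Dict[int, GraphEntry]]


def invalidate_all(graphs: GraphTable) -> int:
    """Mark every graph captured on the old buffers as invalid; return the count."""
    count = 0
    for per_layer in graphs.values():
        for entry in per_layer.values():
            if not entry.invalid:
                entry.invalid = True
                count += 1
    return count


# A graph captured for sequence length 512 on layer 0.
graphs: GraphTable = {(2, 512, -1): {0: GraphEntry()}}

# A sequence length of 1024 forces buffer expansion: old graphs become invalid,
# then a new graph is captured on the expanded buffers.
invalidated = invalidate_all(graphs)
graphs[(2, 1024, -1)] = {0: GraphEntry()}

for shape, per_layer in graphs.items():
    for layer, entry in per_layer.items():
        print(shape, layer, "Invalid" if entry.invalid else "Valid")
# (2, 512, -1) 0 Invalid
# (2, 1024, -1) 0 Valid
```

This mirrors the cache dump in the Examples section, where the 512-length graph is reported as Invalid after the third run.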

Examples

import torch
import torch.nn as nn
from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if

class SimpleTransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.layer_number = 0

    @cuda_graph_enable_if(lambda: torch.cuda.is_available())
    def forward(self, x: torch.Tensor):
        attn_out, _ = self.self_attn(x, x, x)
        out = self.linear(self.layer_norm(x + attn_out))
        return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
graph_mgr = cuda_graph_mgr()

with torch.no_grad():
    input_1 = torch.randn(2, 512, 1024, device=device)
    output_1 = model(input_1)
    print(f"First run (graph capture): Output shape = {output_1.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    input_2 = torch.randn(2, 512, 1024, device=device)
    output_2 = model(input_2)
    print(f"Second run (graph replay): Output shape = {output_2.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")

    input_3 = torch.randn(2, 1024, 1024, device=device)
    output_3 = model(input_3)
    print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
    print(f"Cached graphs count: {graph_mgr.graph_count}")
    print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")

    print("\nCUDA Graph Cache Details:")
    print(graph_mgr.formatted_cache_str())

    # StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
    #   Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
    #   Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
    #   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
    #     Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
    #     Layer 0: Graph Status: Invalid
    #   DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
    #     Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
    #     Layer 0: Graph Status: Valid

Limitations and Constraints

  • No support for data-dependent control flow in captured functions
  • Graph capture fails if the captured function performs CPU/GPU synchronization
  • Only supports CUDA tensors (CPU tensors trigger fallback)
  • Custom input classes must inherit from InplaceSubstituteFakeClass
  • Assumes input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
  • Relies on identical functions, input tensor shapes, and constants for valid graph reuse
  • No support for multi-stream execution scenarios

Best Practices

  • Dynamic Dimensions: Place the sequence length in dimension 0 of each tensor where possible
  • Monitor Memory Usage: Track graph_mem_pool_size and tensor_mem_size to avoid OOM
  • Specify Layer IDs: Use layer_number to distinguish graphs across different models/layers
  • LRU Cache (Future): Implement cache eviction to limit total graph/tensor count
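
Since cache eviction is listed as future work, the following is only one possible shape for it, not an existing MagiCompiler feature. `LRUGraphCache` is a hypothetical class built on the standard library's `OrderedDict`. A real implementation would also need to release the evicted graph and its static tensors.

```python
from collections import OrderedDict
from typing import Any, Hashable, Optional, Tuple


class LRUGraphCache:
    """Hypothetical least-recently-used cache with a fixed entry limit."""

    def __init__(self, max_entries: int) -> None:
        if max_entries < 1:
            raise ValueError("max_entries must be at least 1")
        self.max_entries = max_entries
        self._entries: "OrderedDict[Hashable, Any]" = OrderedDict()

    def get(self, key: Hashable) -> Optional[Any]:
        if key not in self._entries:
            return None
        # A hit makes the entry the most recently used one.
        self._entries.move_to_end(key)
        return self._entries[key]

    def put(self, key: Hashable, value: Any) -> Optional[Tuple[Hashable, Any]]:
        """Insert or refresh an entry; return the evicted (key, value), if any."""
        self._entries[key] = value
        self._entries.move_to_end(key)
        if len(self._entries) > self.max_entries:
            # Drop the least recently used entry (front of the ordering).
            return self._entries.popitem(last=False)
        return None


lru = LRUGraphCache(max_entries=2)
lru.put("seq-512", "graph-512")
lru.put("seq-1024", "graph-1024")
lru.get("seq-512")                            # touch: seq-1024 is now the oldest
evicted = lru.put("seq-2048", "graph-2048")
print(evicted)                                # ('seq-1024', 'graph-1024')
```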
