File size: 1,808 Bytes
d02d576 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | import torch
from torch.cuda.streams import ExternalStream
try:
from . import spatial_ops # triggers TORCH extension registration
except Exception as _e:
_spatial_import_error = _e
else:
_spatial_import_error = None
_IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)
def create_greenctx_stream_by_value(
SM_a: int, SM_b: int, device_id: int = None
) -> tuple[ExternalStream, ExternalStream]:
"""
Create two streams for greenctx.
Args:
sm_A (int): The SM of stream A.
sm_B (int): The weight of stream B.
device_id (int): The device id.
Returns:
tuple[ExternalStream, ExternalStream]: The two streams.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)
stream_a = ExternalStream(
stream_ptr=res[0], device=torch.device(f"cuda:{device_id}")
)
stream_b = ExternalStream(
stream_ptr=res[1], device=torch.device(f"cuda:{device_id}")
)
return stream_a, stream_b
def get_sm_available(device_id: int = None) -> int:
"""
Get the SMs available on the device.
Args:
device_id (int): The device id.
Returns:
int: The SMs available.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
device_props = torch.cuda.get_device_properties(device_id)
# Get the number of Streaming Multiprocessors (SMs)
sm_count = device_props.multi_processor_count
return sm_count
|