Update build/torch-universal/triton_kernels/target_info.py
#7
by
KernelMC - opened
build/torch-universal/triton_kernels/target_info.py
CHANGED
|
@@ -92,7 +92,7 @@ def has_native_mxfp():
|
|
| 92 |
|
| 93 |
|
| 94 |
def num_sms():
|
| 95 |
-
if is_cuda():
|
| 96 |
return torch.cuda.get_device_properties(0).multi_processor_count
|
| 97 |
if is_xpu():
|
| 98 |
return torch.xpu.get_device_properties(0).max_compute_units
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def num_sms():
|
| 95 |
+
if is_cuda() or is_hip():
|
| 96 |
return torch.cuda.get_device_properties(0).multi_processor_count
|
| 97 |
if is_xpu():
|
| 98 |
return torch.xpu.get_device_properties(0).max_compute_units
|