| # -*- coding: utf-8 -*- | |
| from fla.utils import device_platform | |
def fp32_to_tf32_asm() -> str:
    """
    Return the inline assembly string that converts an FP32 value to TF32.

    The instruction is chosen by the module-level ``device_platform``
    (imported from ``fla.utils``). Only ``'nvidia'`` has a known instruction,
    the PTX ``cvt.rna.tf32.f32`` conversion. For any other platform an empty
    string is returned.
    """
    asm_by_platform = {'nvidia': 'cvt.rna.tf32.f32 $0, $1;'}
    # Unsupported platforms fall back to an empty string instead of raising.
    return asm_by_platform.get(device_platform, "")