| import torch | |
def infer_device():
    """
    Get current device name based on available devices.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` reports a device,
        ``"xpu"`` when the installed torch exposes ``torch.xpu`` and it
        reports a device, otherwise ``None``.
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"

    # Not every torch build ships the ``torch.xpu`` submodule; a missing
    # attribute means "no XPU" rather than an AttributeError for the caller.
    xpu_backend = getattr(torch, "xpu", None)
    if xpu_backend is not None and xpu_backend.is_available():
        return "xpu"

    return None
def supports_bfloat16():
    """
    Report whether the device chosen by ``infer_device()`` is treated as
    bfloat16-capable.

    Returns:
        ``True`` for ``"xpu"``; for ``"cuda"``, whether the current device's
        compute capability is at least (8, 0); ``False`` for anything else.
    """
    backend = infer_device()

    # XPU is accepted unconditionally.
    if backend == "xpu":
        return True

    # Anything that is neither XPU nor CUDA (including None) is rejected.
    if backend != "cuda":
        return False

    # Tuple comparison against (major, minor): Ampere and newer.
    capability = torch.cuda.get_device_capability()
    return capability >= (8, 0)