Spaces:
Sleeping
Sleeping
| from typing import Union | |
| import torch | |
def set_device(device : Union[str, torch.device]) -> torch.device:
    """
    Set the device to use for inference. Recommended to use GPU.

    A device given as a string is resolved to a ``torch.device``. If the
    string names an accelerator backend ('cuda' or 'mps', with or without
    an index such as 'cuda:0') that is not usable on this machine, the
    CPU device is returned instead so that inference can still run.

    Arguments:
        device Union[str, torch.device]
            The device to use for inference. Can be either a string or a
            torch.device object.

    Returns:
        torch.device
            The device to use for inference. A torch.device passed in is
            returned unchanged.

    Raises:
        RuntimeError
            If `device` is a string that torch.device cannot parse.
    """
    # An explicit torch.device is taken as the caller's deliberate choice.
    if not isinstance(device, str):
        return device

    # Parse once so that indexed forms ('cuda:0') are handled the same way
    # as the bare backend name.
    requested = torch.device(device)

    if requested.type == 'cuda' and not torch.cuda.is_available():
        return torch.device('cpu')

    # is_available() is the runtime check; is_built() only says PyTorch was
    # compiled with MPS support, not that an MPS device can be used.
    if requested.type == 'mps' and not torch.backends.mps.is_available():
        return torch.device('cpu')

    return requested