File size: 928 Bytes
b5a064f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import sys

import torch

from modules.cmd_opts import opts

# ROOT_DIR: parent of the directory that contains this module
# (presumably the project root), as an absolute, normalized path.
ROOT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
# MODELS_DIR: the "models" subdirectory directly under ROOT_DIR.
MODELS_DIR = os.path.join(ROOT_DIR, "models")


def has_mps():
    """Return True if torch's Apple MPS (Metal) backend is actually usable.

    Three checks must all pass:

    1. We are running on macOS (``sys.platform == "darwin"``).
    2. This torch build reports MPS support.
    3. A one-element tensor can really be moved to the ``mps`` device.
       The build flag alone does not prove the runtime works.

    Returns:
        bool: ``True`` only when all checks pass, otherwise ``False``.
    """
    if sys.platform != "darwin":
        return False

    # ``torch.has_mps`` is deprecated and absent in newer torch releases;
    # relying on it alone would silently report "no MPS" there. Prefer
    # ``torch.backends.mps.is_available()`` and fall back for old builds.
    mps_backend = getattr(getattr(torch, "backends", None), "mps", None)
    if mps_backend is not None and hasattr(mps_backend, "is_available"):
        available = bool(mps_backend.is_available())
    else:
        available = bool(getattr(torch, "has_mps", False))
    if not available:
        return False

    try:
        torch.zeros(1).to(torch.device("mps"))
    except Exception:
        # Deliberate best-effort probe: any failure to allocate on the
        # device means MPS is not usable here.
        return False
    return True


# Resolve numeric precision and the torch compute device once, at import time.

# Start from the user's requested precision; this is forced off below when
# CUDA FP16 is unavailable.
is_half = opts.precision == "fp16"

# get_device_capability() returns a (major, minor) tuple. FP16 needs compute
# capability >= 5.3, so compare whole tuples. Comparing only the integer
# major version against 5.3 wrongly rejected capability (5, 3) devices.
half_support = (
    torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
)

if is_half and not half_support:
    # Warn only when FP16 was actually requested; previously this fired on
    # every machine without a capable CUDA GPU, including CPU-only ones.
    print("WARNING: FP16 is not supported on this GPU")
    is_half = False

# Device preference: first CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif has_mps():
    print("Using MPS")
    device = "mps"
else:
    print("Using CPU")
    device = "cpu"

device = torch.device(device)