LogicGoInfotechSpaces commited on
Commit
372a1fe
·
1 Parent(s): 87d55be

Add device detection and CPU fallback directly in FSEFull and FSEInverter constructors

Browse files
Files changed (1) hide show
  1. models/methods.py +18 -0
models/methods.py CHANGED
@@ -38,6 +38,15 @@ class FSEFull(nn.Module):
38
  self.opts.update(paths)
39
  self.opts = Namespace(**self.opts)
40
 
 
 
 
 
 
 
 
 
 
41
  self.device = torch.device(device)
42
  self.inverter_pth = inverter_pth
43
 
@@ -169,6 +178,15 @@ class FSEInverter(nn.Module):
169
  self.opts.update(paths)
170
  self.opts = Namespace(**self.opts)
171
 
 
 
 
 
 
 
 
 
 
172
  self.device = torch.device(device)
173
  self.encoder = self.set_encoder()
174
 
 
38
  self.opts.update(paths)
39
  self.opts = Namespace(**self.opts)
40
 
41
+ # Handle device detection and fallback to CPU if CUDA is not available
42
+ try:
43
+ torch.randn(1).to(device)
44
+ print("Device: {}".format(device))
45
+ except Exception as e:
46
+ print("Could not use device {}, {}".format(device, e))
47
+ print("Set device to CPU")
48
+ device = "cpu"
49
+
50
  self.device = torch.device(device)
51
  self.inverter_pth = inverter_pth
52
 
 
178
  self.opts.update(paths)
179
  self.opts = Namespace(**self.opts)
180
 
181
+ # Handle device detection and fallback to CPU if CUDA is not available
182
+ try:
183
+ torch.randn(1).to(device)
184
+ print("Device: {}".format(device))
185
+ except Exception as e:
186
+ print("Could not use device {}, {}".format(device, e))
187
+ print("Set device to CPU")
188
+ device = "cpu"
189
+
190
  self.device = torch.device(device)
191
  self.encoder = self.set_encoder()
192