i4ata commited on
Commit
fb9b166
·
1 Parent(s): 4fcc913
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -15,13 +15,13 @@ class GradioApp:
15
 
16
  def __init__(self) -> None:
17
 
18
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
 
20
- custom = CustomUnet().to(device).eval()
21
- custom.load_state_dict(torch.load('models/custom_unet.pt', map_location=device))
22
 
23
- pretrained = get_pretrained_unet().to(device).eval()
24
- pretrained.load_state_dict(torch.load('models/pretrained_unet.pt', map_location=device))
25
 
26
  self.models = {
27
  'Custom': custom,
 
15
 
16
  def __init__(self) -> None:
17
 
18
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
 
20
+ custom = CustomUnet().to(self.device).eval()
21
+ custom.load_state_dict(torch.load('models/custom_unet.pt', map_location=self.device))
22
 
23
+ pretrained = get_pretrained_unet().to(self.device).eval()
24
+ pretrained.load_state_dict(torch.load('models/pretrained_unet.pt', map_location=self.device))
25
 
26
  self.models = {
27
  'Custom': custom,