Spaces:
Sleeping
Sleeping
tiny fix
Browse files
app.py
CHANGED
|
@@ -15,13 +15,13 @@ class GradioApp:
|
|
| 15 |
|
| 16 |
def __init__(self) -> None:
|
| 17 |
|
| 18 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 19 |
|
| 20 |
-
custom = CustomUnet().to(device).eval()
|
| 21 |
-
custom.load_state_dict(torch.load('models/custom_unet.pt', map_location=device))
|
| 22 |
|
| 23 |
-
pretrained = get_pretrained_unet().to(device).eval()
|
| 24 |
-
pretrained.load_state_dict(torch.load('models/pretrained_unet.pt', map_location=device))
|
| 25 |
|
| 26 |
self.models = {
|
| 27 |
'Custom': custom,
|
|
|
|
| 15 |
|
| 16 |
def __init__(self) -> None:
|
| 17 |
|
| 18 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 19 |
|
| 20 |
+
custom = CustomUnet().to(self.device).eval()
|
| 21 |
+
custom.load_state_dict(torch.load('models/custom_unet.pt', map_location=self.device))
|
| 22 |
|
| 23 |
+
pretrained = get_pretrained_unet().to(self.device).eval()
|
| 24 |
+
pretrained.load_state_dict(torch.load('models/pretrained_unet.pt', map_location=self.device))
|
| 25 |
|
| 26 |
self.models = {
|
| 27 |
'Custom': custom,
|