Spaces:
Runtime error
Runtime error
echen01
commited on
Commit
·
5b7158a
1
Parent(s):
2fec875
fix device
Browse files- criteria/id_loss.py +2 -2
- criteria/mask.py +2 -2
- training/coaches/base_coach.py +2 -1
criteria/id_loss.py
CHANGED
|
@@ -17,7 +17,7 @@ class IDLoss(nn.Module):
|
|
| 17 |
[4] https://github.com/eladrich/pixel2style2pixel
|
| 18 |
"""
|
| 19 |
|
| 20 |
-
def __init__(self, model_path, official=False):
|
| 21 |
"""
|
| 22 |
Arguments:
|
| 23 |
model_path (str): Path to IR-SE50 model.
|
|
@@ -32,7 +32,7 @@ class IDLoss(nn.Module):
|
|
| 32 |
input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
|
| 33 |
)
|
| 34 |
|
| 35 |
-
self.facenet.load_state_dict(torch.load(model_path))
|
| 36 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
| 37 |
self.facenet.eval()
|
| 38 |
|
|
|
|
| 17 |
[4] https://github.com/eladrich/pixel2style2pixel
|
| 18 |
"""
|
| 19 |
|
| 20 |
+
def __init__(self, model_path, official=False, device="cpu"):
|
| 21 |
"""
|
| 22 |
Arguments:
|
| 23 |
model_path (str): Path to IR-SE50 model.
|
|
|
|
| 32 |
input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
|
| 33 |
)
|
| 34 |
|
| 35 |
+
self.facenet.load_state_dict(torch.load(model_path, map_location=device))
|
| 36 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
| 37 |
self.facenet.eval()
|
| 38 |
|
criteria/mask.py
CHANGED
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
| 9 |
|
| 10 |
|
| 11 |
class Mask(nn.Module):
|
| 12 |
-
def __init__(self):
|
| 13 |
"""
|
| 14 |
|
| 15 |
| Class | Number | Class | Number |
|
|
@@ -41,7 +41,7 @@ class Mask(nn.Module):
|
|
| 41 |
.requires_grad_(False)
|
| 42 |
)
|
| 43 |
|
| 44 |
-
ckpt = torch.load(paths_config.deeplab, map_location=
|
| 45 |
state_dict = {
|
| 46 |
k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
|
| 47 |
}
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class Mask(nn.Module):
|
| 12 |
+
def __init__(self, device="cpu"):
|
| 13 |
"""
|
| 14 |
|
| 15 |
| Class | Number | Class | Number |
|
|
|
|
| 41 |
.requires_grad_(False)
|
| 42 |
)
|
| 43 |
|
| 44 |
+
ckpt = torch.load(paths_config.deeplab, map_location=device)
|
| 45 |
state_dict = {
|
| 46 |
k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
|
| 47 |
}
|
training/coaches/base_coach.py
CHANGED
|
@@ -51,13 +51,14 @@ class BaseCoach:
|
|
| 51 |
id_loss.IDLoss(
|
| 52 |
paths_config.ir_se50,
|
| 53 |
official=False,
|
|
|
|
| 54 |
)
|
| 55 |
.to(global_config.device)
|
| 56 |
.eval()
|
| 57 |
)
|
| 58 |
|
| 59 |
if hyperparameters.use_mask:
|
| 60 |
-
self.mask = mask.Mask()
|
| 61 |
|
| 62 |
self.restart_training()
|
| 63 |
|
|
|
|
| 51 |
id_loss.IDLoss(
|
| 52 |
paths_config.ir_se50,
|
| 53 |
official=False,
|
| 54 |
+
device=global_config.device
|
| 55 |
)
|
| 56 |
.to(global_config.device)
|
| 57 |
.eval()
|
| 58 |
)
|
| 59 |
|
| 60 |
if hyperparameters.use_mask:
|
| 61 |
+
self.mask = mask.Mask(device=global_config.device)
|
| 62 |
|
| 63 |
self.restart_training()
|
| 64 |
|