Update videoretalking/models/__init__.py
Browse files
videoretalking/models/__init__.py
CHANGED
|
@@ -1,37 +1,37 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from models.DNet import DNet
|
| 3 |
-
from models.LNet import LNet
|
| 4 |
-
from models.ENet import ENet
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def _load(checkpoint_path):
|
| 8 |
-
map_location=None if torch.cuda.is_available() else torch.device('cpu')
|
| 9 |
-
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
| 10 |
-
return checkpoint
|
| 11 |
-
|
| 12 |
-
def load_checkpoint(path, model):
|
| 13 |
-
print("Load checkpoint from: {}".format(path))
|
| 14 |
-
checkpoint = _load(path)
|
| 15 |
-
s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint
|
| 16 |
-
new_s = {}
|
| 17 |
-
for k, v in s.items():
|
| 18 |
-
if 'low_res' in k:
|
| 19 |
-
continue
|
| 20 |
-
else:
|
| 21 |
-
new_s[k.replace('module.', '')] = v
|
| 22 |
-
model.load_state_dict(new_s, strict=False)
|
| 23 |
-
return model
|
| 24 |
-
|
| 25 |
-
def load_network(LNet_path,ENet_path):
|
| 26 |
-
L_net = LNet()
|
| 27 |
-
L_net = load_checkpoint(LNet_path, L_net)
|
| 28 |
-
E_net = ENet(lnet=L_net)
|
| 29 |
-
model = load_checkpoint(ENet_path, E_net)
|
| 30 |
-
return model.eval()
|
| 31 |
-
|
| 32 |
-
def load_DNet(DNet_path):
|
| 33 |
-
D_Net = DNet()
|
| 34 |
-
print("Load checkpoint from: {}".format(DNet_path))
|
| 35 |
-
checkpoint = torch.load(DNet_path, map_location=lambda storage, loc: storage)
|
| 36 |
-
D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False)
|
| 37 |
return D_Net.eval()
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from videoretalking.models.DNet import DNet
|
| 3 |
+
from videoretalking.models.LNet import LNet
|
| 4 |
+
from videoretalking.models.ENet import ENet
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _load(checkpoint_path):
|
| 8 |
+
map_location=None if torch.cuda.is_available() else torch.device('cpu')
|
| 9 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
| 10 |
+
return checkpoint
|
| 11 |
+
|
| 12 |
+
def load_checkpoint(path, model):
|
| 13 |
+
print("Load checkpoint from: {}".format(path))
|
| 14 |
+
checkpoint = _load(path)
|
| 15 |
+
s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint
|
| 16 |
+
new_s = {}
|
| 17 |
+
for k, v in s.items():
|
| 18 |
+
if 'low_res' in k:
|
| 19 |
+
continue
|
| 20 |
+
else:
|
| 21 |
+
new_s[k.replace('module.', '')] = v
|
| 22 |
+
model.load_state_dict(new_s, strict=False)
|
| 23 |
+
return model
|
| 24 |
+
|
| 25 |
+
def load_network(LNet_path,ENet_path):
|
| 26 |
+
L_net = LNet()
|
| 27 |
+
L_net = load_checkpoint(LNet_path, L_net)
|
| 28 |
+
E_net = ENet(lnet=L_net)
|
| 29 |
+
model = load_checkpoint(ENet_path, E_net)
|
| 30 |
+
return model.eval()
|
| 31 |
+
|
| 32 |
+
def load_DNet(DNet_path):
|
| 33 |
+
D_Net = DNet()
|
| 34 |
+
print("Load checkpoint from: {}".format(DNet_path))
|
| 35 |
+
checkpoint = torch.load(DNet_path, map_location=lambda storage, loc: storage)
|
| 36 |
+
D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False)
|
| 37 |
return D_Net.eval()
|