Update src/facerender/modules/make_animation.py
Browse files
src/facerender/modules/make_animation.py
CHANGED
|
@@ -148,10 +148,15 @@ def make_animation(source_image, source_semantics, target_semantics,
|
|
| 148 |
yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
|
| 149 |
use_exp=True):
|
| 150 |
|
| 151 |
-
device='cuda'
|
| 152 |
-
generator =
|
| 153 |
-
kp_detector =
|
| 154 |
-
mapping =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
source_image = source_image.to(device)
|
| 157 |
source_semantics = source_semantics.to(device)
|
|
|
|
| 148 |
yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
|
| 149 |
use_exp=True):
|
| 150 |
|
| 151 |
+
device = 'cuda:0'
|
| 152 |
+
generator = generator.to(device)
|
| 153 |
+
kp_detector = kp_detector.to(device)
|
| 154 |
+
mapping = mapping.to(device)
|
| 155 |
+
|
| 156 |
+
# Wrap the models in DataParallel to use all available GPUs
|
| 157 |
+
generator = torch.nn.DataParallel(generator)
|
| 158 |
+
kp_detector = torch.nn.DataParallel(kp_detector)
|
| 159 |
+
mapping = torch.nn.DataParallel(mapping)
|
| 160 |
|
| 161 |
source_image = source_image.to(device)
|
| 162 |
source_semantics = source_semantics.to(device)
|