Spaces:
Sleeping
Sleeping
Commit
·
cfa4803
1
Parent(s):
5d50fba
Avoid CUDA init on Spaces: map tensors to device, no .cuda() in LatentEditor
Browse files
- editings/latent_editor.py +9 -8
- runners/base_runner.py +2 -1
editings/latent_editor.py
CHANGED
|
@@ -42,9 +42,10 @@ STYLESPACE_IDX = [
|
|
| 42 |
|
| 43 |
|
| 44 |
class LatentEditor:
|
| 45 |
-
def __init__(self, domain="human_faces"):
|
| 46 |
|
| 47 |
self.domain = domain
|
|
|
|
| 48 |
|
| 49 |
if self.domain == "human_faces":
|
| 50 |
self.interfacegan_directions = {
|
|
@@ -53,11 +54,11 @@ class LatentEditor:
|
|
| 53 |
"rotation": "editings/interfacegan_directions/rotation.pt",
|
| 54 |
}
|
| 55 |
self.interfacegan_tensors = {
|
| 56 |
-
name: torch.load(path)
|
| 57 |
for name, path in self.interfacegan_directions.items()
|
| 58 |
}
|
| 59 |
|
| 60 |
-
self.ganspace_pca = torch.load("editings/ganspace_pca/ffhq_pca.pt")
|
| 61 |
self.ganspace_directions = {
|
| 62 |
"eye_openness": (54, 7, 8, 5),
|
| 63 |
"trimmed_beard": (58, 7, 9, 7),
|
|
@@ -147,10 +148,10 @@ class LatentEditor:
|
|
| 147 |
|
| 148 |
|
| 149 |
def load_styleclip_global(self):
|
| 150 |
-
delta_i_c = torch.from_numpy(np.load("editings/styleclip/global_mapper_data/delta_i_c.npy")).float().cuda()
|
| 151 |
with open("editings/styleclip/global_mapper_data/S_mean_std", "rb") as channels_statistics:
|
| 152 |
_, s_std = pickle.load(channels_statistics)
|
| 153 |
-
s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std]
|
| 154 |
with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
|
| 155 |
text_prompt_templates = templates.readlines()
|
| 156 |
global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
|
|
@@ -179,7 +180,7 @@ class LatentEditor:
|
|
| 179 |
opts = argparse.Namespace(**opts)
|
| 180 |
style_clip_net = StyleCLIPMapper(opts)
|
| 181 |
style_clip_net.eval()
|
| 182 |
-
style_clip_net.cuda()
|
| 183 |
direction = style_clip_net.mapper(start_w)
|
| 184 |
for factor in factors:
|
| 185 |
edited_latent = start_w + factor * direction
|
|
@@ -195,7 +196,7 @@ class LatentEditor:
|
|
| 195 |
disentanglement = float(disentanglement)
|
| 196 |
|
| 197 |
directions = self.styleclip_global_editor.get_delta_s(neutral_text, target_text, disentanglement)
|
| 198 |
-
factors = torch.tensor(factors).cuda().view(-1, 1)
|
| 199 |
srart_ss, start_rgb = start_s
|
| 200 |
|
| 201 |
edits_rgb = []
|
|
@@ -222,7 +223,7 @@ class LatentEditor:
|
|
| 222 |
neutral_text, target_text, disentanglement = direction.split("_")
|
| 223 |
disentanglement = float(disentanglement)
|
| 224 |
|
| 225 |
-
factors = torch.tensor(factors).cuda().view(-1, 1)
|
| 226 |
srart_ss, edited_rgb = start_s
|
| 227 |
edits_ss = self.deltaedit_editor.get_delta_s(neutral_text, target_text, disentanglement, original_image, srart_ss)
|
| 228 |
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
class LatentEditor:
|
| 45 |
+
def __init__(self, domain="human_faces", device="cpu"):
|
| 46 |
|
| 47 |
self.domain = domain
|
| 48 |
+
self.device = torch.device(device)
|
| 49 |
|
| 50 |
if self.domain == "human_faces":
|
| 51 |
self.interfacegan_directions = {
|
|
|
|
| 54 |
"rotation": "editings/interfacegan_directions/rotation.pt",
|
| 55 |
}
|
| 56 |
self.interfacegan_tensors = {
|
| 57 |
+
name: torch.load(path, map_location=self.device)
|
| 58 |
for name, path in self.interfacegan_directions.items()
|
| 59 |
}
|
| 60 |
|
| 61 |
+
self.ganspace_pca = torch.load("editings/ganspace_pca/ffhq_pca.pt", map_location=self.device)
|
| 62 |
self.ganspace_directions = {
|
| 63 |
"eye_openness": (54, 7, 8, 5),
|
| 64 |
"trimmed_beard": (58, 7, 9, 7),
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
def load_styleclip_global(self):
|
| 151 |
+
delta_i_c = torch.from_numpy(np.load("editings/styleclip/global_mapper_data/delta_i_c.npy")).float().to(self.device)
|
| 152 |
with open("editings/styleclip/global_mapper_data/S_mean_std", "rb") as channels_statistics:
|
| 153 |
_, s_std = pickle.load(channels_statistics)
|
| 154 |
+
s_std = [torch.from_numpy(s_i).float().to(self.device) for s_i in s_std]
|
| 155 |
with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
|
| 156 |
text_prompt_templates = templates.readlines()
|
| 157 |
global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
|
|
|
|
| 180 |
opts = argparse.Namespace(**opts)
|
| 181 |
style_clip_net = StyleCLIPMapper(opts)
|
| 182 |
style_clip_net.eval()
|
| 183 |
+
style_clip_net.to(self.device)
|
| 184 |
direction = style_clip_net.mapper(start_w)
|
| 185 |
for factor in factors:
|
| 186 |
edited_latent = start_w + factor * direction
|
|
|
|
| 196 |
disentanglement = float(disentanglement)
|
| 197 |
|
| 198 |
directions = self.styleclip_global_editor.get_delta_s(neutral_text, target_text, disentanglement)
|
| 199 |
+
factors = torch.tensor(factors).to(self.device).view(-1, 1)
|
| 200 |
srart_ss, start_rgb = start_s
|
| 201 |
|
| 202 |
edits_rgb = []
|
|
|
|
| 223 |
neutral_text, target_text, disentanglement = direction.split("_")
|
| 224 |
disentanglement = float(disentanglement)
|
| 225 |
|
| 226 |
+
factors = torch.tensor(factors).to(self.device).view(-1, 1)
|
| 227 |
srart_ss, edited_rgb = start_s
|
| 228 |
edits_ss = self.deltaedit_editor.get_delta_s(neutral_text, target_text, disentanglement, original_image, srart_ss)
|
| 229 |
|
runners/base_runner.py
CHANGED
|
@@ -74,7 +74,8 @@ class BaseRunner:
|
|
| 74 |
return edited_latents
|
| 75 |
|
| 76 |
def _setup_latent_editor(self):
|
| 77 |
-
self.latent_editor = LatentEditor(self.config.exp.domain)
|
|
|
| 78 |
|
| 79 |
def _setup_device(self):
|
| 80 |
config_device = self.config.model["device"].lower()
|
|
|
|
| 74 |
return edited_latents
|
| 75 |
|
| 76 |
def _setup_latent_editor(self):
|
| 77 |
+
# Pass device to avoid unintended CUDA initialization on Spaces
|
| 78 |
+
self.latent_editor = LatentEditor(self.config.exp.domain, device=self.device)
|
| 79 |
|
| 80 |
def _setup_device(self):
|
| 81 |
config_device = self.config.model["device"].lower()
|