LogicGoInfotechSpaces commited on
Commit
cfa4803
·
1 Parent(s): 5d50fba

Avoid CUDA init on Spaces: map tensors to device, no .cuda() in LatentEditor

Browse files
editings/latent_editor.py CHANGED
@@ -42,9 +42,10 @@ STYLESPACE_IDX = [
42
 
43
 
44
  class LatentEditor:
45
- def __init__(self, domain="human_faces"):
46
 
47
  self.domain = domain
 
48
 
49
  if self.domain == "human_faces":
50
  self.interfacegan_directions = {
@@ -53,11 +54,11 @@ class LatentEditor:
53
  "rotation": "editings/interfacegan_directions/rotation.pt",
54
  }
55
  self.interfacegan_tensors = {
56
- name: torch.load(path).cuda()
57
  for name, path in self.interfacegan_directions.items()
58
  }
59
 
60
- self.ganspace_pca = torch.load("editings/ganspace_pca/ffhq_pca.pt")
61
  self.ganspace_directions = {
62
  "eye_openness": (54, 7, 8, 5),
63
  "trimmed_beard": (58, 7, 9, 7),
@@ -147,10 +148,10 @@ class LatentEditor:
147
 
148
 
149
  def load_styleclip_global(self):
150
- delta_i_c = torch.from_numpy(np.load("editings/styleclip/global_mapper_data/delta_i_c.npy")).float().cuda()
151
  with open("editings/styleclip/global_mapper_data/S_mean_std", "rb") as channels_statistics:
152
  _, s_std = pickle.load(channels_statistics)
153
- s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std]
154
  with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
155
  text_prompt_templates = templates.readlines()
156
  global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
@@ -179,7 +180,7 @@ class LatentEditor:
179
  opts = argparse.Namespace(**opts)
180
  style_clip_net = StyleCLIPMapper(opts)
181
  style_clip_net.eval()
182
- style_clip_net.cuda()
183
  direction = style_clip_net.mapper(start_w)
184
  for factor in factors:
185
  edited_latent = start_w + factor * direction
@@ -195,7 +196,7 @@ class LatentEditor:
195
  disentanglement = float(disentanglement)
196
 
197
  directions = self.styleclip_global_editor.get_delta_s(neutral_text, target_text, disentanglement)
198
- factors = torch.tensor(factors).cuda().view(-1, 1)
199
  srart_ss, start_rgb = start_s
200
 
201
  edits_rgb = []
@@ -222,7 +223,7 @@ class LatentEditor:
222
  neutral_text, target_text, disentanglement = direction.split("_")
223
  disentanglement = float(disentanglement)
224
 
225
- factors = torch.tensor(factors).cuda().view(-1, 1)
226
  srart_ss, edited_rgb = start_s
227
  edits_ss = self.deltaedit_editor.get_delta_s(neutral_text, target_text, disentanglement, original_image, srart_ss)
228
 
 
42
 
43
 
44
  class LatentEditor:
45
+ def __init__(self, domain="human_faces", device="cpu"):
46
 
47
  self.domain = domain
48
+ self.device = torch.device(device)
49
 
50
  if self.domain == "human_faces":
51
  self.interfacegan_directions = {
 
54
  "rotation": "editings/interfacegan_directions/rotation.pt",
55
  }
56
  self.interfacegan_tensors = {
57
+ name: torch.load(path, map_location=self.device)
58
  for name, path in self.interfacegan_directions.items()
59
  }
60
 
61
+ self.ganspace_pca = torch.load("editings/ganspace_pca/ffhq_pca.pt", map_location=self.device)
62
  self.ganspace_directions = {
63
  "eye_openness": (54, 7, 8, 5),
64
  "trimmed_beard": (58, 7, 9, 7),
 
148
 
149
 
150
  def load_styleclip_global(self):
151
+ delta_i_c = torch.from_numpy(np.load("editings/styleclip/global_mapper_data/delta_i_c.npy")).float().to(self.device)
152
  with open("editings/styleclip/global_mapper_data/S_mean_std", "rb") as channels_statistics:
153
  _, s_std = pickle.load(channels_statistics)
154
+ s_std = [torch.from_numpy(s_i).float().to(self.device) for s_i in s_std]
155
  with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
156
  text_prompt_templates = templates.readlines()
157
  global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
 
180
  opts = argparse.Namespace(**opts)
181
  style_clip_net = StyleCLIPMapper(opts)
182
  style_clip_net.eval()
183
+ style_clip_net.to(self.device)
184
  direction = style_clip_net.mapper(start_w)
185
  for factor in factors:
186
  edited_latent = start_w + factor * direction
 
196
  disentanglement = float(disentanglement)
197
 
198
  directions = self.styleclip_global_editor.get_delta_s(neutral_text, target_text, disentanglement)
199
+ factors = torch.tensor(factors).to(self.device).view(-1, 1)
200
  srart_ss, start_rgb = start_s
201
 
202
  edits_rgb = []
 
223
  neutral_text, target_text, disentanglement = direction.split("_")
224
  disentanglement = float(disentanglement)
225
 
226
+ factors = torch.tensor(factors).to(self.device).view(-1, 1)
227
  srart_ss, edited_rgb = start_s
228
  edits_ss = self.deltaedit_editor.get_delta_s(neutral_text, target_text, disentanglement, original_image, srart_ss)
229
 
runners/base_runner.py CHANGED
@@ -74,7 +74,8 @@ class BaseRunner:
74
  return edited_latents
75
 
76
  def _setup_latent_editor(self):
77
- self.latent_editor = LatentEditor(self.config.exp.domain)
 
78
 
79
  def _setup_device(self):
80
  config_device = self.config.model["device"].lower()
 
74
  return edited_latents
75
 
76
  def _setup_latent_editor(self):
77
+ # Pass device to avoid unintended CUDA initialization on Spaces
78
+ self.latent_editor = LatentEditor(self.config.exp.domain, device=self.device)
79
 
80
  def _setup_device(self):
81
  config_device = self.config.model["device"].lower()