LogicGoInfotechSpaces committed on
Commit
ad8ad02
·
1 Parent(s): cfa4803

Fix CLIP model device handling: pass device to StyleCLIPGlobalDirection and use .to(device) instead of .cuda()

Browse files
editings/latent_editor.py CHANGED
@@ -154,7 +154,7 @@ class LatentEditor:
154
  s_std = [torch.from_numpy(s_i).float().to(self.device) for s_i in s_std]
155
  with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
156
  text_prompt_templates = templates.readlines()
157
- global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
158
  return global_direction_calculator
159
 
160
 
 
154
  s_std = [torch.from_numpy(s_i).float().to(self.device) for s_i in s_std]
155
  with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
156
  text_prompt_templates = templates.readlines()
157
+ global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates, device=self.device)
158
  return global_direction_calculator
159
 
160
 
editings/styleclip/mapper/gloabl_mapper.py CHANGED
@@ -11,7 +11,7 @@ STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128,
11
  TORGB_INDICES = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
12
  STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in TORGB_INDICES][:11]
13
 
14
- def features_channels_to_s(s_without_torgb, s_std):
15
  s = []
16
  start_index_features = 0
17
  for c in range(len(STYLESPACE_DIMENSIONS)):
@@ -20,19 +20,20 @@ def features_channels_to_s(s_without_torgb, s_std):
20
  s_i = s_without_torgb[start_index_features:end_index_features] * s_std[c]
21
  start_index_features = end_index_features
22
  else:
23
- s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).cuda()
24
  s_i = s_i.view(1, 1, -1, 1, 1)
25
  s.append(s_i)
26
  return s
27
 
28
  class StyleCLIPGlobalDirection:
29
 
30
- def __init__(self, delta_i_c, s_std, text_prompts_templates):
31
  super(StyleCLIPGlobalDirection, self).__init__()
32
  self.delta_i_c = delta_i_c
33
  self.s_std = s_std
34
  self.text_prompts_templates = text_prompts_templates
35
- self.clip_model, _ = clip.load("ViT-B/32", device="cuda")
 
36
 
37
  def get_delta_s(self, neutral_text, target_text, beta):
38
  delta_i = self.get_delta_i([target_text, neutral_text]).float()
@@ -43,7 +44,7 @@ class StyleCLIPGlobalDirection:
43
  max_channel_value = torch.abs(delta_s).max()
44
  if max_channel_value > 0:
45
  delta_s /= max_channel_value
46
- direction = features_channels_to_s(delta_s, self.s_std)
47
  return direction
48
 
49
  def get_delta_i(self, text_prompts):
@@ -61,11 +62,11 @@ class StyleCLIPGlobalDirection:
61
  text_features_list = []
62
  for text_prompt in text_prompts:
63
  formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates] # format with class
64
- formatted_text_prompts = clip.tokenize(formatted_text_prompts).cuda() # tokenize
65
  text_embeddings = self.clip_model.encode_text(formatted_text_prompts) # embed with text encoder
66
  text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
67
  text_embedding = text_embeddings.mean(dim=0)
68
  text_embedding /= text_embedding.norm()
69
  text_features_list.append(text_embedding)
70
- text_features = torch.stack(text_features_list, dim=1).cuda()
71
  return text_features.t()
 
11
  TORGB_INDICES = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
12
  STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in TORGB_INDICES][:11]
13
 
14
+ def features_channels_to_s(s_without_torgb, s_std, device="cpu"):
15
  s = []
16
  start_index_features = 0
17
  for c in range(len(STYLESPACE_DIMENSIONS)):
 
20
  s_i = s_without_torgb[start_index_features:end_index_features] * s_std[c]
21
  start_index_features = end_index_features
22
  else:
23
+ s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).to(device)
24
  s_i = s_i.view(1, 1, -1, 1, 1)
25
  s.append(s_i)
26
  return s
27
 
28
  class StyleCLIPGlobalDirection:
29
 
30
def __init__(self, delta_i_c, s_std, text_prompts_templates, device="cpu"):
    """Hold the precomputed StyleCLIP global-direction state.

    Args:
        delta_i_c: precomputed channel-relevance matrix (image-space deltas
            per style channel), as loaded by the caller.
        s_std: per-layer style-channel standard deviations, as loaded by
            the caller.
        text_prompts_templates: list of prompt-template strings used to
            build averaged CLIP text embeddings.
        device: torch device string the CLIP model (and downstream tensors)
            should live on; defaults to "cpu" so the class no longer
            assumes CUDA is available.
    """
    super().__init__()
    self.delta_i_c = delta_i_c
    self.s_std = s_std
    self.text_prompts_templates = text_prompts_templates
    # Remember the device so helper calls (tokenize / features_channels_to_s)
    # can place tensors consistently with the CLIP model.
    self.device = device
    # Load CLIP onto the requested device rather than hard-coding "cuda".
    self.clip_model, _ = clip.load("ViT-B/32", device=device)
38
  def get_delta_s(self, neutral_text, target_text, beta):
39
  delta_i = self.get_delta_i([target_text, neutral_text]).float()
 
44
  max_channel_value = torch.abs(delta_s).max()
45
  if max_channel_value > 0:
46
  delta_s /= max_channel_value
47
+ direction = features_channels_to_s(delta_s, self.s_std, self.device)
48
  return direction
49
 
50
  def get_delta_i(self, text_prompts):
 
62
  text_features_list = []
63
  for text_prompt in text_prompts:
64
  formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates] # format with class
65
+ formatted_text_prompts = clip.tokenize(formatted_text_prompts).to(self.device) # tokenize
66
  text_embeddings = self.clip_model.encode_text(formatted_text_prompts) # embed with text encoder
67
  text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
68
  text_embedding = text_embeddings.mean(dim=0)
69
  text_embedding /= text_embedding.norm()
70
  text_features_list.append(text_embedding)
71
+ text_features = torch.stack(text_features_list, dim=1).to(self.device)
72
  return text_features.t()