Commit ad8ad02
Parent(s): cfa4803
Fix CLIP model device handling: pass device to StyleCLIPGlobalDirection and use .to(device) instead of .cuda()
editings/latent_editor.py CHANGED

@@ -154,7 +154,7 @@ class LatentEditor:
         s_std = [torch.from_numpy(s_i).float().to(self.device) for s_i in s_std]
         with open("editings/styleclip/global_mapper_data/templates.txt", "r") as templates:
             text_prompt_templates = templates.readlines()
-        global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
+        global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates, device=self.device)
         return global_direction_calculator
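The idiom this hunk adopts, shown standalone (a minimal sketch; the fallback logic is an assumption about how self.device would typically be chosen, not code from this commit):

import torch

# Pick whatever device the host actually has; a hard-coded .cuda() call
# raises on CPU-only machines such as a sleeping Space's hardware.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.zeros(512).to(device)  # runs on both CPU-only and GPU hosts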
editings/styleclip/mapper/gloabl_mapper.py CHANGED

@@ -11,7 +11,7 @@ STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128,
 TORGB_INDICES = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
 STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in TORGB_INDICES][:11]
 
-def features_channels_to_s(s_without_torgb, s_std):
+def features_channels_to_s(s_without_torgb, s_std, device="cpu"):
     s = []
     start_index_features = 0
     for c in range(len(STYLESPACE_DIMENSIONS)):

@@ -20,19 +20,20 @@ def features_channels_to_s(s_without_torgb, s_std):
             s_i = s_without_torgb[start_index_features:end_index_features] * s_std[c]
             start_index_features = end_index_features
         else:
-            s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).cuda()
+            s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).to(device)
         s_i = s_i.view(1, 1, -1, 1, 1)
         s.append(s_i)
     return s
 
 class StyleCLIPGlobalDirection:
 
-    def __init__(self, delta_i_c, s_std, text_prompts_templates):
+    def __init__(self, delta_i_c, s_std, text_prompts_templates, device="cpu"):
         super(StyleCLIPGlobalDirection, self).__init__()
         self.delta_i_c = delta_i_c
         self.s_std = s_std
         self.text_prompts_templates = text_prompts_templates
-        self.clip_model, _ = clip.load("ViT-B/32", device="cuda")
+        self.device = device
+        self.clip_model, _ = clip.load("ViT-B/32", device=device)
 
     def get_delta_s(self, neutral_text, target_text, beta):
         delta_i = self.get_delta_i([target_text, neutral_text]).float()

@@ -43,7 +44,7 @@ class StyleCLIPGlobalDirection:
         max_channel_value = torch.abs(delta_s).max()
         if max_channel_value > 0:
             delta_s /= max_channel_value
-        direction = features_channels_to_s(delta_s, self.s_std)
+        direction = features_channels_to_s(delta_s, self.s_std, self.device)
         return direction
 
     def get_delta_i(self, text_prompts):

@@ -61,11 +62,11 @@ class StyleCLIPGlobalDirection:
         text_features_list = []
         for text_prompt in text_prompts:
             formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates]  # format with class
-            formatted_text_prompts = clip.tokenize(formatted_text_prompts).cuda()  # tokenize
+            formatted_text_prompts = clip.tokenize(formatted_text_prompts).to(self.device)  # tokenize
             text_embeddings = self.clip_model.encode_text(formatted_text_prompts)  # embed with text encoder
             text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
             text_embedding = text_embeddings.mean(dim=0)
             text_embedding /= text_embedding.norm()
             text_features_list.append(text_embedding)
-        text_features = torch.stack(text_features_list, dim=1).cuda()
+        text_features = torch.stack(text_features_list, dim=1).to(self.device)
         return text_features.t()
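For reference, the same device handling applied to the public CLIP API in isolation (a hedged sketch; the prompts are illustrative, and only clip.load, clip.tokenize, and encode_text from the OpenAI CLIP package are assumed):

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the text encoder on the selected device instead of assuming CUDA.
model, _ = clip.load("ViT-B/32", device=device)
tokens = clip.tokenize(["a face", "a smiling face"]).to(device)

with torch.no_grad():
    text_features = model.encode_text(tokens)                   # embed with text encoder
    text_features /= text_features.norm(dim=-1, keepdim=True)   # unit-normalize

Because every tensor above follows the single device variable, the snippet behaves identically whether or not a GPU is present, which is the point of replacing the hard-coded .cuda() calls in this commit.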