Commit
·
a04f2f2
1
Parent(s):
17131ca
Fix token embedding size mismatch
Browse files- src/utils/style_generator.py +24 -14
src/utils/style_generator.py
CHANGED
|
@@ -77,6 +77,7 @@ class StyleTransfer:
|
|
| 77 |
|
| 78 |
# Get the expected dimension from the text encoder
|
| 79 |
expected_dim = self.pipeline.text_encoder.get_input_embeddings().weight.shape[1]
|
|
|
|
| 80 |
current_dim = embeds.shape[0]
|
| 81 |
|
| 82 |
# Resize embeddings if dimensions don't match
|
|
@@ -85,7 +86,8 @@ class StyleTransfer:
|
|
| 85 |
if current_dim > expected_dim:
|
| 86 |
embeds = embeds[:expected_dim]
|
| 87 |
else:
|
| 88 |
-
|
|
|
|
| 89 |
|
| 90 |
# Reshape to match expected dimensions
|
| 91 |
embeds = embeds.unsqueeze(0) # Add batch dimension
|
|
@@ -94,16 +96,21 @@ class StyleTransfer:
|
|
| 94 |
dtype = self.pipeline.text_encoder.get_input_embeddings().weight.dtype
|
| 95 |
embeds = embeds.to(dtype)
|
| 96 |
|
| 97 |
-
# Add the token in tokenizer
|
| 98 |
token = token if token is not None else trained_token
|
| 99 |
-
self.pipeline.tokenizer.add_tokens(token)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
# Get the id for the token and assign the embeds
|
| 105 |
-
token_id = self.pipeline.tokenizer.convert_tokens_to_ids(token)
|
| 106 |
-
self.pipeline.text_encoder.get_input_embeddings().weight.data[token_id] = embeds[0]
|
| 107 |
return token
|
| 108 |
|
| 109 |
def generate_artwork(self, prompt, selected_style):
|
|
@@ -157,20 +164,23 @@ class StyleTransfer:
|
|
| 157 |
loss = self._calculate_color_distance(latents_copy)
|
| 158 |
|
| 159 |
# Compute gradients
|
| 160 |
-
if loss.requires_grad:
|
| 161 |
grads = torch.autograd.grad(
|
| 162 |
outputs=loss,
|
| 163 |
inputs=latents_copy,
|
| 164 |
allow_unused=True,
|
| 165 |
retain_graph=False
|
| 166 |
-
)
|
| 167 |
|
| 168 |
-
if grads is not None:
|
| 169 |
-
# Apply gradients to original latents
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
except Exception as e:
|
| 173 |
print(f"Error in color enhancement: {e}")
|
|
|
|
| 174 |
|
| 175 |
return latents
|
| 176 |
|
|
|
|
| 77 |
|
| 78 |
# Get the expected dimension from the text encoder
|
| 79 |
expected_dim = self.pipeline.text_encoder.get_input_embeddings().weight.shape[1]
|
| 80 |
+
vocab_size = self.pipeline.text_encoder.get_input_embeddings().weight.shape[0]
|
| 81 |
current_dim = embeds.shape[0]
|
| 82 |
|
| 83 |
# Resize embeddings if dimensions don't match
|
|
|
|
| 86 |
if current_dim > expected_dim:
|
| 87 |
embeds = embeds[:expected_dim]
|
| 88 |
else:
|
| 89 |
+
padding = torch.zeros(expected_dim - current_dim, device=embeds.device, dtype=embeds.dtype)
|
| 90 |
+
embeds = torch.cat([embeds, padding], dim=0)
|
| 91 |
|
| 92 |
# Reshape to match expected dimensions
|
| 93 |
embeds = embeds.unsqueeze(0) # Add batch dimension
|
|
|
|
| 96 |
dtype = self.pipeline.text_encoder.get_input_embeddings().weight.dtype
|
| 97 |
embeds = embeds.to(dtype)
|
| 98 |
|
| 99 |
+
# Add the token in tokenizer and handle embedding resize
|
| 100 |
token = token if token is not None else trained_token
|
| 101 |
+
num_added_tokens = self.pipeline.tokenizer.add_tokens(token)
|
| 102 |
|
| 103 |
+
if num_added_tokens > 0:
|
| 104 |
+
# Safely resize token embeddings
|
| 105 |
+
self.pipeline.text_encoder.resize_token_embeddings(len(self.pipeline.tokenizer))
|
| 106 |
+
|
| 107 |
+
# Get the id for the token and assign the embeds
|
| 108 |
+
token_id = self.pipeline.tokenizer.convert_tokens_to_ids(token)
|
| 109 |
+
if token_id < self.pipeline.text_encoder.get_input_embeddings().weight.shape[0]:
|
| 110 |
+
self.pipeline.text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
| 111 |
+
else:
|
| 112 |
+
print(f"Warning: Token ID {token_id} is out of bounds. Skipping embedding assignment.")
|
| 113 |
|
|
|
|
|
|
|
|
|
|
| 114 |
return token
|
| 115 |
|
| 116 |
def generate_artwork(self, prompt, selected_style):
|
|
|
|
| 164 |
loss = self._calculate_color_distance(latents_copy)
|
| 165 |
|
| 166 |
# Compute gradients
|
| 167 |
+
if loss is not None and loss.requires_grad:
|
| 168 |
grads = torch.autograd.grad(
|
| 169 |
outputs=loss,
|
| 170 |
inputs=latents_copy,
|
| 171 |
allow_unused=True,
|
| 172 |
retain_graph=False
|
| 173 |
+
)
|
| 174 |
|
| 175 |
+
if grads and grads[0] is not None:
|
| 176 |
+
# Apply gradients to original latents with safety checks
|
| 177 |
+
grad_tensor = grads[0].detach()
|
| 178 |
+
if grad_tensor.shape == latents.shape:
|
| 179 |
+
return latents - 0.1 * grad_tensor
|
| 180 |
+
|
| 181 |
except Exception as e:
|
| 182 |
print(f"Error in color enhancement: {e}")
|
| 183 |
+
# Continue without enhancement on error
|
| 184 |
|
| 185 |
return latents
|
| 186 |
|