HSinghHuggingFace committed on
Commit
a04f2f2
·
1 Parent(s): 17131ca

Fix token embedding size mismatch

Browse files
Files changed (1) hide show
  1. src/utils/style_generator.py +24 -14
src/utils/style_generator.py CHANGED
@@ -77,6 +77,7 @@ class StyleTransfer:
77
 
78
  # Get the expected dimension from the text encoder
79
  expected_dim = self.pipeline.text_encoder.get_input_embeddings().weight.shape[1]
 
80
  current_dim = embeds.shape[0]
81
 
82
  # Resize embeddings if dimensions don't match
@@ -85,7 +86,8 @@ class StyleTransfer:
85
  if current_dim > expected_dim:
86
  embeds = embeds[:expected_dim]
87
  else:
88
- embeds = torch.cat([embeds, torch.zeros(expected_dim - current_dim)], dim=0)
 
89
 
90
  # Reshape to match expected dimensions
91
  embeds = embeds.unsqueeze(0) # Add batch dimension
@@ -94,16 +96,21 @@ class StyleTransfer:
94
  dtype = self.pipeline.text_encoder.get_input_embeddings().weight.dtype
95
  embeds = embeds.to(dtype)
96
 
97
- # Add the token in tokenizer
98
  token = token if token is not None else trained_token
99
- self.pipeline.tokenizer.add_tokens(token)
100
 
101
- # Resize the token embeddings
102
- self.pipeline.text_encoder.resize_token_embeddings(len(self.pipeline.tokenizer))
 
 
 
 
 
 
 
 
103
 
104
- # Get the id for the token and assign the embeds
105
- token_id = self.pipeline.tokenizer.convert_tokens_to_ids(token)
106
- self.pipeline.text_encoder.get_input_embeddings().weight.data[token_id] = embeds[0]
107
  return token
108
 
109
  def generate_artwork(self, prompt, selected_style):
@@ -157,20 +164,23 @@ class StyleTransfer:
157
  loss = self._calculate_color_distance(latents_copy)
158
 
159
  # Compute gradients
160
- if loss.requires_grad:
161
  grads = torch.autograd.grad(
162
  outputs=loss,
163
  inputs=latents_copy,
164
  allow_unused=True,
165
  retain_graph=False
166
- )[0]
167
 
168
- if grads is not None:
169
- # Apply gradients to original latents
170
- return latents - 0.1 * grads.detach()
171
-
 
 
172
  except Exception as e:
173
  print(f"Error in color enhancement: {e}")
 
174
 
175
  return latents
176
 
 
77
 
78
  # Get the expected dimension from the text encoder
79
  expected_dim = self.pipeline.text_encoder.get_input_embeddings().weight.shape[1]
80
+ vocab_size = self.pipeline.text_encoder.get_input_embeddings().weight.shape[0]
81
  current_dim = embeds.shape[0]
82
 
83
  # Resize embeddings if dimensions don't match
 
86
  if current_dim > expected_dim:
87
  embeds = embeds[:expected_dim]
88
  else:
89
+ padding = torch.zeros(expected_dim - current_dim, device=embeds.device, dtype=embeds.dtype)
90
+ embeds = torch.cat([embeds, padding], dim=0)
91
 
92
  # Reshape to match expected dimensions
93
  embeds = embeds.unsqueeze(0) # Add batch dimension
 
96
  dtype = self.pipeline.text_encoder.get_input_embeddings().weight.dtype
97
  embeds = embeds.to(dtype)
98
 
99
+ # Add the token in tokenizer and handle embedding resize
100
  token = token if token is not None else trained_token
101
+ num_added_tokens = self.pipeline.tokenizer.add_tokens(token)
102
 
103
+ if num_added_tokens > 0:
104
+ # Safely resize token embeddings
105
+ self.pipeline.text_encoder.resize_token_embeddings(len(self.pipeline.tokenizer))
106
+
107
+ # Get the id for the token and assign the embeds
108
+ token_id = self.pipeline.tokenizer.convert_tokens_to_ids(token)
109
+ if token_id < self.pipeline.text_encoder.get_input_embeddings().weight.shape[0]:
110
+ self.pipeline.text_encoder.get_input_embeddings().weight.data[token_id] = embeds
111
+ else:
112
+ print(f"Warning: Token ID {token_id} is out of bounds. Skipping embedding assignment.")
113
 
 
 
 
114
  return token
115
 
116
  def generate_artwork(self, prompt, selected_style):
 
164
  loss = self._calculate_color_distance(latents_copy)
165
 
166
  # Compute gradients
167
+ if loss is not None and loss.requires_grad:
168
  grads = torch.autograd.grad(
169
  outputs=loss,
170
  inputs=latents_copy,
171
  allow_unused=True,
172
  retain_graph=False
173
+ )
174
 
175
+ if grads and grads[0] is not None:
176
+ # Apply gradients to original latents with safety checks
177
+ grad_tensor = grads[0].detach()
178
+ if grad_tensor.shape == latents.shape:
179
+ return latents - 0.1 * grad_tensor
180
+
181
  except Exception as e:
182
  print(f"Error in color enhancement: {e}")
183
+ # Continue without enhancement on error
184
 
185
  return latents
186