AYYasaswini commited on
Commit
364d47d
·
verified ·
1 Parent(s): 62846d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -163,9 +163,13 @@ def loss_style(prompt, style_embed, style_seed):
163
 
164
  # Get token embeddings
165
  token_embeddings = token_emb_layer(input_ids)
 
166
 
167
  # The new embedding - our special birb word
168
  replacement_token_embedding = style_embed.to(torch_device)
 
 
 
169
 
170
  # Insert this into the token embeddings
171
  token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
 
163
 
164
  # Get token embeddings
165
  token_embeddings = token_emb_layer(input_ids)
166
+
167
 
168
  # The new embedding - our special birb word
169
  replacement_token_embedding = style_embed.to(torch_device)
170
+ # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
171
+ replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
172
+ replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
173
 
174
  # Insert this into the token embeddings
175
  token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)