Commit: Update app.py
File: app.py (CHANGED)

@@ -163,9 +163,13 @@ def loss_style(prompt, style_embed, style_seed):
Before (lines 163–171):

    163
    164      # Get token embeddings
    165      token_embeddings = token_emb_layer(input_ids)
    166
    167      # The new embedding - our special birb word
    168      replacement_token_embedding = style_embed.to(torch_device)
    169
    170      # Insert this into the token embeddings
    171      token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
After (lines 163–175):

    163
    164      # Get token embeddings
    165      token_embeddings = token_emb_layer(input_ids)
    166  +
    167
    168      # The new embedding - our special birb word
    169      replacement_token_embedding = style_embed.to(torch_device)
    170  +   # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
    171  +   replacement_token_embedding = replacement_token_embedding[:768]  # Adjust the size
    172  +   replacement_token_embedding = replacement_token_embedding.unsqueeze(0)  # Make it [1, 768] if necessary
    173
    174      # Insert this into the token embeddings
    175      token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)