Update app.py
Browse files
app.py
CHANGED
|
@@ -138,6 +138,10 @@ def embed_style(prompt, style_embed, style_seed):
|
|
| 138 |
token_embeddings = token_emb_layer(input_ids)
|
| 139 |
|
| 140 |
replacement_token_embedding = style_embed.to(torch_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
# Insert this into the token embeddings
|
| 143 |
token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
|
|
|
|
| 138 |
token_embeddings = token_emb_layer(input_ids)
|
| 139 |
|
| 140 |
replacement_token_embedding = style_embed.to(torch_device)
|
| 141 |
+
# replacement_token_embedding = birb_embed[embed_values[4]].to(torch_device)
|
| 142 |
+
# Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
|
| 143 |
+
replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
|
| 144 |
+
replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
|
| 145 |
|
| 146 |
# Insert this into the token embeddings
|
| 147 |
token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
|