Update app.py
Browse files
app.py
CHANGED
|
@@ -144,8 +144,10 @@ def embed_style(prompt, style_embed, style_seed):
|
|
| 144 |
replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
|
| 145 |
|
| 146 |
# Insert this into the token embeddings
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
# Combine with pos embs
|
| 150 |
input_embeddings = token_embeddings + position_embeddings
|
| 151 |
|
|
|
|
| 144 |
replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
|
| 145 |
|
| 146 |
# Insert this into the token embeddings
|
| 147 |
+
# token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
|
| 148 |
+
indices = torch.where(input_ids[0] == 6829)[0]
|
| 149 |
+
for index in indices:
|
| 150 |
+
token_embeddings[0, index] = replacement_token_embedding.to(torch_device)
|
| 151 |
# Combine with pos embs
|
| 152 |
input_embeddings = token_embeddings + position_embeddings
|
| 153 |
|