Update app.py
Browse files
app.py
CHANGED
|
@@ -159,7 +159,8 @@ def generate_with_prompt_style(prompt, style, seed = 42):
|
|
| 159 |
token_embeddings = token_emb_layer(input_ids)
|
| 160 |
# The new embedding - our special birb word
|
| 161 |
replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)
|
| 162 |
-
|
|
|
|
| 163 |
# Insert this into the token embeddings
|
| 164 |
token_embeddings[0, torch.where(input_ids[0]==338)] = replacement_token_embedding.to(torch_device)
|
| 165 |
|
|
|
|
| 159 |
token_embeddings = token_emb_layer(input_ids)
|
| 160 |
# The new embedding - our special birb word
|
| 161 |
replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)
|
| 162 |
+
replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
|
| 163 |
+
replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
|
| 164 |
# Insert this into the token embeddings
|
| 165 |
token_embeddings[0, torch.where(input_ids[0]==338)] = replacement_token_embedding.to(torch_device)
|
| 166 |
|