Update app.py
Browse files
app.py
CHANGED
|
@@ -172,9 +172,14 @@ def loss_style(prompt, style_embed, style_seed):
|
|
| 172 |
# Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
|
| 173 |
replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
|
| 174 |
replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
# Insert this into the token embeddings
|
| 177 |
-
|
| 178 |
|
| 179 |
# Combine with pos embs
|
| 180 |
input_embeddings = token_embeddings + position_embeddings
|
|
|
|
| 172 |
# Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
|
| 173 |
replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
|
| 174 |
replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
|
| 175 |
+
indices = torch.where(input_ids[0] == 6829)[0] # Extract indices where the condition is True
|
| 176 |
+
print(f"indices: {indices}") # Debug print
|
| 177 |
+
for index in indices:
|
| 178 |
+
print(f"index: {index}") # Debug print
|
| 179 |
+
token_embeddings[0, index] = replacement_token_embedding.to(torch_device) # Update each index
|
| 180 |
+
|
| 181 |
# Insert this into the token embeddings
|
| 182 |
+
# token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
|
| 183 |
|
| 184 |
# Combine with pos embs
|
| 185 |
input_embeddings = token_embeddings + position_embeddings
|