AYYasaswini commited on
Commit
ffaa8ff
·
verified ·
1 Parent(s): 26e2f18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -144,8 +144,10 @@ def embed_style(prompt, style_embed, style_seed):
144
  replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
145
 
146
  # Insert this into the token embeddings
147
- token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
148
-
 
 
149
  # Combine with pos embs
150
  input_embeddings = token_embeddings + position_embeddings
151
 
 
144
  replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
145
 
146
  # Insert this into the token embeddings
147
+ # token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
148
+ indices = torch.where(input_ids[0] == 6829)[0]
149
+ for index in indices:
150
+ token_embeddings[0, index] = replacement_token_embedding.to(torch_device)
151
  # Combine with pos embs
152
  input_embeddings = token_embeddings + position_embeddings
153