AYYasaswini commited on
Commit
86ab9bb
·
verified ·
1 Parent(s): bc226e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -172,9 +172,14 @@ def loss_style(prompt, style_embed, style_seed):
172
  # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
173
  replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
174
  replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
175
-
 
 
 
 
 
176
  # Insert this into the token embeddings
177
- token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
178
 
179
  # Combine with pos embs
180
  input_embeddings = token_embeddings + position_embeddings
 
172
  # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
173
  replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
174
  replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
175
+ indices = torch.where(input_ids[0] == 6829)[0] # Extract indices where the condition is True
176
+ print(f"indices: {indices}") # Debug print
177
+ for index in indices:
178
+ print(f"index: {index}") # Debug print
179
+ token_embeddings[0, index] = replacement_token_embedding.to(torch_device) # Update each index
180
+
181
  # Insert this into the token embeddings
182
+ # token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
183
 
184
  # Combine with pos embs
185
  input_embeddings = token_embeddings + position_embeddings