AYYasaswini commited on
Commit
62846d1
·
verified ·
1 Parent(s): 131e2fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -138,6 +138,10 @@ def embed_style(prompt, style_embed, style_seed):
138
  token_embeddings = token_emb_layer(input_ids)
139
 
140
  replacement_token_embedding = style_embed.to(torch_device)
 
 
 
 
141
 
142
  # Insert this into the token embeddings
143
  token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
 
138
  token_embeddings = token_emb_layer(input_ids)
139
 
140
  replacement_token_embedding = style_embed.to(torch_device)
141
+ # replacement_token_embedding = birb_embed[embed_values[4]].to(torch_device)
142
+ # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
143
+ replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
144
+ replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
145
 
146
  # Insert this into the token embeddings
147
  token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)