AYYasaswini commited on
Commit
81f1128
·
verified ·
1 Parent(s): f07ce06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -159,7 +159,8 @@ def generate_with_prompt_style(prompt, style, seed = 42):
159
  token_embeddings = token_emb_layer(input_ids)
160
  # The new embedding - our special birb word
161
  replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)
162
-
 
163
  # Insert this into the token embeddings
164
  token_embeddings[0, torch.where(input_ids[0]==338)] = replacement_token_embedding.to(torch_device)
165
 
 
159
  token_embeddings = token_emb_layer(input_ids)
160
  # The new embedding - our special birb word
161
  replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)
162
+ replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
163
+ replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
164
  # Insert this into the token embeddings
165
  token_embeddings[0, torch.where(input_ids[0]==338)] = replacement_token_embedding.to(torch_device)
166