Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -115,7 +115,7 @@ def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
|
|
| 115 |
return prompt_embeds, pooled_prompt_embeds
|
| 116 |
|
| 117 |
# Override the high-level encode_prompt to use T5 encoding and return three outputs.
|
| 118 |
-
# --- KEY CHANGE: Return token_ids as a
|
| 119 |
def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance=False,
|
| 120 |
negative_prompt=None, prompt_embeds=None, prompt_2=None, **kwargs):
|
| 121 |
text_inputs = self.tokenizer(
|
|
@@ -142,18 +142,13 @@ def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classif
|
|
| 142 |
pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
|
| 143 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
|
| 144 |
pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
|
| 145 |
-
|
| 146 |
else:
|
| 147 |
-
|
| 148 |
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 149 |
pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
# --- Create dummy image token ids with the same shape as text token ids.
|
| 153 |
-
dummy_img_ids = torch.full_like(token_ids_text, fill_value=t5_tokenizer.pad_token_id)
|
| 154 |
-
|
| 155 |
-
# Return a tuple of token id tensors.
|
| 156 |
-
return text_embeddings, pooled_text_embeddings, (token_ids_text, dummy_img_ids)
|
| 157 |
|
| 158 |
pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
|
| 159 |
pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
|
|
@@ -171,7 +166,6 @@ def patched_time_embed(self, timestep, guidance, pooled_projections):
|
|
| 171 |
text_out = self.fixed_text_proj(pooled_projections)
|
| 172 |
return time_out + text_out
|
| 173 |
|
| 174 |
-
# Patch the forward method.
|
| 175 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|
| 176 |
|
| 177 |
# ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
|
|
|
|
| 115 |
return prompt_embeds, pooled_prompt_embeds
|
| 116 |
|
| 117 |
# Override the high-level encode_prompt to use T5 encoding and return three outputs.
|
| 118 |
+
# --- KEY CHANGE: Return token_ids as a single tensor.
|
| 119 |
def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance=False,
|
| 120 |
negative_prompt=None, prompt_embeds=None, prompt_2=None, **kwargs):
|
| 121 |
text_inputs = self.tokenizer(
|
|
|
|
| 142 |
pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
|
| 143 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
|
| 144 |
pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
|
| 145 |
+
token_ids = text_inputs.input_ids
|
| 146 |
else:
|
| 147 |
+
token_ids = text_inputs.input_ids
|
| 148 |
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 149 |
pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 150 |
+
token_ids = token_ids.repeat_interleave(num_images_per_prompt, dim=0)
|
| 151 |
+
return text_embeddings, pooled_text_embeddings, token_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
|
| 154 |
pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
|
|
|
|
| 166 |
text_out = self.fixed_text_proj(pooled_projections)
|
| 167 |
return time_out + text_out
|
| 168 |
|
|
|
|
| 169 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|
| 170 |
|
| 171 |
# ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
|