concauu committed on
Commit
8686600
·
verified ·
1 Parent(s): 065d948

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -115,7 +115,7 @@ def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
115
  return prompt_embeds, pooled_prompt_embeds
116
 
117
  # Override the high-level encode_prompt to use T5 encoding and return three outputs.
118
- # --- KEY CHANGE: Return token_ids as a tuple (text_token_ids, dummy_img_token_ids)
119
  def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance=False,
120
  negative_prompt=None, prompt_embeds=None, prompt_2=None, **kwargs):
121
  text_inputs = self.tokenizer(
@@ -142,18 +142,13 @@ def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classif
142
  pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
143
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
144
  pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
145
- token_ids_text = text_inputs.input_ids
146
  else:
147
- token_ids_text = text_inputs.input_ids
148
  text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
149
  pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
150
- token_ids_text = token_ids_text.repeat_interleave(num_images_per_prompt, dim=0)
151
-
152
- # --- Create dummy image token ids with the same shape as text token ids.
153
- dummy_img_ids = torch.full_like(token_ids_text, fill_value=t5_tokenizer.pad_token_id)
154
-
155
- # Return a tuple of token id tensors.
156
- return text_embeddings, pooled_text_embeddings, (token_ids_text, dummy_img_ids)
157
 
158
  pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
159
  pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
@@ -171,7 +166,6 @@ def patched_time_embed(self, timestep, guidance, pooled_projections):
171
  text_out = self.fixed_text_proj(pooled_projections)
172
  return time_out + text_out
173
 
174
- # Patch the forward method.
175
  pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
176
 
177
  # ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
 
115
  return prompt_embeds, pooled_prompt_embeds
116
 
117
  # Override the high-level encode_prompt to use T5 encoding and return three outputs.
118
+ # --- KEY CHANGE: Return token_ids as a single tensor.
119
  def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance=False,
120
  negative_prompt=None, prompt_embeds=None, prompt_2=None, **kwargs):
121
  text_inputs = self.tokenizer(
 
142
  pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
143
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
144
  pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
145
+ token_ids = text_inputs.input_ids
146
  else:
147
+ token_ids = text_inputs.input_ids
148
  text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
149
  pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
150
+ token_ids = token_ids.repeat_interleave(num_images_per_prompt, dim=0)
151
+ return text_embeddings, pooled_text_embeddings, token_ids
 
 
 
 
 
152
 
153
  pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
154
  pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
 
166
  text_out = self.fixed_text_proj(pooled_projections)
167
  return time_out + text_out
168
 
 
169
  pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
170
 
171
  # ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----