import torch
from typing import Any, Callable, Dict, List, Optional, Union


def _get_t5_prompt_embeds(
    text_encoder,
    tokenizer,
    prompt: Optional[Union[str, List[str]]] = None,
    max_sequence_length: int = 226,
    num_videos_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Tokenize `prompt`, run it through `text_encoder`, and return zero-padded hidden states.

    Args:
        text_encoder:
            Callable invoked as `text_encoder(input_ids, attention_mask)`; its result must expose
            `last_hidden_state`.
        tokenizer:
            Callable invoked with the prompt list and padding/truncation keyword arguments; its result must
            expose `input_ids` and `attention_mask`.
        prompt (`str` or `List[str]`):
            Prompt or prompts to encode. A single string is wrapped into a one-element list.
        max_sequence_length (`int`, defaults to 226):
            Length every prompt is padded or truncated to by the tokenizer.
        num_videos_per_prompt (`int`, defaults to 1):
            Number of times each prompt's embedding is repeated in the output.
        device (`torch.device`, *optional*):
            Device the token ids, attention mask, and resulting embeddings are moved to.
        dtype (`torch.dtype`, *optional*):
            Dtype the resulting embeddings are cast to.

    Returns:
        `torch.Tensor` of shape `(batch_size * num_videos_per_prompt, max_sequence_length, hidden_dim)`, with
        the copies of each prompt stored consecutively.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    # Number of attended (non-padding) tokens for each prompt.
    seq_lens = mask.gt(0).sum(dim=1).long()

    prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Keep only the attended prefix of every sequence, then pad back to `max_sequence_length` with zeros so
    # that whatever the encoder produced at padding positions is replaced by exact zeros.
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
    )

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds


def encode_prompt(
    text_encoder,
    tokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 226,
    num_videos_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        text_encoder:
            Indexable container of text encoders; only `text_encoder[0]` is used.
        tokenizer:
            Indexable container of tokenizers; only `tokenizer[0]` is used.
        prompt (`str` or `List[str]`):
            Prompt or prompts to be encoded.
        max_sequence_length (`int`, defaults to 226):
            Length every prompt is padded or truncated to.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt; each prompt embedding is repeated this many
            times.
        device (`torch.device`, *optional*):
            Accepted for interface compatibility. The value is currently ignored: the device of
            `text_encoder[0]` is always used.
        dtype (`torch.dtype`, *optional*):
            Accepted for interface compatibility. The value is currently ignored: the dtype of
            `text_encoder[0]` is always used.

    Returns:
        `torch.Tensor` of shape `(batch_size * num_videos_per_prompt, max_sequence_length, hidden_dim)`.

    Raises:
        ValueError: If `prompt` is `None`.
    """
    # There is no pre-computed-embeddings path in this function, so a missing prompt cannot be recovered
    # from; fail with a clear message instead of an UnboundLocalError further down.
    if prompt is None:
        raise ValueError("`prompt` must be a string or a list of strings, but got `None`.")

    # The encoder's own placement takes precedence over the `device`/`dtype` arguments, which keeps the
    # token ids on the same device as the encoder weights.
    device = text_encoder[0].device
    dtype = text_encoder[0].dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt

    return _get_t5_prompt_embeds(
        text_encoder=text_encoder[0],
        tokenizer=tokenizer[0],
        prompt=prompt,
        max_sequence_length=max_sequence_length,
        num_videos_per_prompt=num_videos_per_prompt,
        device=device,
        dtype=dtype,
    )