| | import random |
| | import numpy as np |
| | import torch |
| | from transformers import PretrainedConfig |
| |
|
| |
|
def set_seed(seed):
    """Seed every RNG in use (Python, NumPy, Torch CPU, Torch CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
| |
|
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Look up the text-encoder class declared in a pretrained model's config.

    Loads the config from ``subfolder`` of the given repo/path at ``revision``,
    reads the first entry of its ``architectures`` list, and returns the
    matching ``transformers`` class.

    Raises:
        ValueError: if the architecture is not one of the supported CLIP text
            encoder classes.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    architecture = config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if architecture == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection

    raise ValueError(f"{architecture} is not supported.")
| |
|
def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts=0):
    """Encode text prompts with one or more paired (tokenizer, text_encoder) CLIP stacks.

    Args:
        prompt_batch: A single prompt string or a list of prompt strings.
        text_encoders: Sequence of text-encoder models, aligned with ``tokenizers``.
        tokenizers: Sequence of tokenizers, one per text encoder.
        proportion_empty_prompts: Probability in [0, 1] of replacing each caption
            with the empty string (caption dropout for classifier-free guidance).

    Returns:
        Tuple ``(prompt_embeds, pooled_prompt_embeds)`` where ``prompt_embeds``
        concatenates each encoder's second-to-last hidden state along the last
        dim, and ``pooled_prompt_embeds`` comes from the LAST encoder's first
        output, flattened to ``(batch, -1)``.

    Raises:
        ValueError: if a caption is not a string (lists/arrays of captions are
            rejected explicitly).
    """
    prompt_embeds_list = []
    captions = []
    # Fix: isinstance instead of `type(...) == str` — idiomatic and also
    # accepts str subclasses.
    if isinstance(prompt_batch, str):
        prompt_batch = [prompt_batch]
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            # Caption dropout: train the model on the unconditional prompt.
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            raise ValueError("Multiple captions were passed in the wrong format.")
        else:
            raise ValueError("Prompt is in the wrong format.")

    # NOTE(review): if `text_encoders`/`tokenizers` are empty, `bs_embed` and
    # `pooled_prompt_embeds` below are unbound — callers are expected to pass
    # at least one encoder; confirm upstream.
    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(captions, padding="longest", return_tensors="pt").input_ids

            # Warn (don't fail) when truncation dropped tokens: compare the
            # capped ids against an uncapped tokenization of the same captions.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                print(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # First output is used as the pooled embedding; only the value from
            # the final encoder in the loop is returned.
            pooled_prompt_embeds = prompt_embeds[0]
            # Penultimate hidden state is the conventional CLIP conditioning
            # layer for diffusion models.
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    # Concatenate per-encoder embeddings along the feature dimension.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
| |
|