import random

import numpy as np
import torch
import torch.nn.functional as F
from transformers import PretrainedConfig


def set_seed(seed):
    # Seed every RNG in play (Python, NumPy, CPU and all CUDA devices) so that
    # training and sampling runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
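
# Note (beyond the original code): seeding alone does not guarantee bit-exact
# results on GPU; full determinism may also require
# torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False.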


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    # Look up the architecture declared in the checkpoint's config and import
    # the matching transformers class lazily.
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")
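
# Usage sketch (illustrative, not part of the original file; SDXL checkpoints
# keep a CLIPTextModel under "text_encoder" and a CLIPTextModelWithProjection
# under "text_encoder_2"):
#
#   cls_one = import_model_class_from_model_name_or_path(model_id, revision=None)
#   cls_two = import_model_class_from_model_name_or_path(
#       model_id, revision=None, subfolder="text_encoder_2"
#   )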


def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts=0):
    prompt_embeds_list = []
    captions = []
    if isinstance(prompt_batch, str):
        prompt_batch = [prompt_batch]
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            # Randomly drop the caption to train for classifier-free guidance.
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            raise ValueError("Multiple captions were passed in the wrong format.")
        else:
            raise ValueError("Prompt is in the wrong format.")

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(captions, padding="longest", return_tensors="pt").input_ids

            # Warn when the prompt exceeds the tokenizer's window and got cut off.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                print(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # Only the pooled output of the last text encoder is kept; the loop
            # deliberately overwrites pooled_prompt_embeds on each iteration.
            pooled_prompt_embeds = prompt_embeds[0]
            # Use the penultimate hidden state, following the SDXL convention.
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    # Concatenate the per-encoder embeddings along the feature dimension.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
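
# Usage sketch (illustrative; the encoder and tokenizer names are assumptions,
# and the two lists must be in matching order):
#
#   prompt_embeds, pooled_prompt_embeds = encode_prompt(
#       ["a photo of a cat"],
#       [text_encoder_one, text_encoder_two],
#       [tokenizer_one, tokenizer_two],
#       proportion_empty_prompts=0.1,
#   )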


def is_torch2_available():
    # PyTorch 2.x added F.scaled_dot_product_attention; probing for the function
    # is more robust than parsing version strings.
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    if seed is not None:
        if isinstance(seed, list):
            # One generator per seed, e.g. for per-image seeding within a batch.
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None

    return generator
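
# Usage sketch (illustrative; `pipe` stands in for any diffusers pipeline that
# accepts a `generator` argument):
#
#   generator = get_generator(42, "cuda")
#   image = pipe(prompt, generator=generator).images[0]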