import random

import numpy as np
import torch
import torch.nn.functional as F
from transformers import PretrainedConfig


def set_seed(seed):
    # Seed all relevant RNGs (python, numpy, torch CPU and CUDA) for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    # Inspect the saved config to decide which text-encoder class to load.
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")


def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts=0):
    prompt_embeds_list = []

    captions = []
    if isinstance(prompt_batch, str):
        prompt_batch = [prompt_batch]
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            # Randomly replace some captions with empty ones (caption dropout).
            captions.append("")
        elif isinstance(caption, str):
            # Keep the caption as-is.
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # This happens when multiple captions are passed for the same image.
            raise ValueError("Multiple captions were passed in the wrong format.")
        else:
            raise ValueError("Prompt is in the wrong format.")

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids

            # Warn when part of the prompt was cut off by max_length truncation.
            untruncated_ids = tokenizer(captions, padding="longest", return_tensors="pt").input_ids
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                print(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final text encoder.
            pooled_prompt_embeds = prompt_embeds[0]
            # Use the penultimate hidden state as the per-token embeddings.
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    # Concatenate the per-encoder embeddings along the feature dimension.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds


def is_torch2_available():
    # torch >= 2.0 ships F.scaled_dot_product_attention.
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    # Build a torch.Generator (or a list of them) from one or more seeds.
    if seed is not None:
        if isinstance(seed, list):
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None

    return generator
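

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how the self-contained helpers above are typically
# combined when seeding batched generation: seed the global RNGs once, then
# create one torch.Generator per sample so each image in the batch is
# individually reproducible. The seeds and device choice here are arbitrary.
if __name__ == "__main__":
    set_seed(42)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # A single int yields one generator; a list of ints yields one generator
    # per entry, which a diffusers pipeline can consume for a batched call.
    generators = get_generator([0, 1, 2, 3], device)

    print(f"torch 2.x SDPA available: {is_torch2_available()}")
    print(f"created {len(generators)} generators on {device}")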