import logging
import os

from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast


def load_text_encoders(args, class_one, class_two):
    """Load both text encoders from the same checkpoint directory."""
    text_encoder_one = class_one.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
        variant=args.variant,
    )
    text_encoder_two = class_two.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder_2",
        revision=args.revision,
        variant=args.variant,
    )
    return text_encoder_one, text_encoder_two


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Resolve the encoder class named in the checkpoint's config without
    loading any weights."""
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


def create_logger(logging_dir, accelerator):
    """Create a logger that writes to a log file and the console on the main
    process, and stays silent on every other process."""
    if accelerator.is_main_process:
        # Real logger: stream to the console and append to <logging_dir>/log.txt.
        os.makedirs(logging_dir, exist_ok=True)  # FileHandler needs the directory to exist
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",  # blue timestamp via ANSI codes
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(f"{logging_dir}/log.txt"),
            ],
        )
        logger = logging.getLogger(__name__)
    else:
        # Dummy logger: swallow records so non-main ranks produce no output.
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
        logger.propagate = False  # keep records from reaching the root handlers
    return logger
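
# Usage sketch (illustrative only): resolving and loading a two-encoder
# checkpoint. The supported classes (CLIPTextModel + T5EncoderModel) match
# FLUX-style layouts; the repo id and the `args` fields below are assumptions
# for the example, not requirements of this module. Any diffusers-style
# checkpoint with "text_encoder" and "text_encoder_2" subfolders and these
# architectures works the same way.
#
#   import argparse
#
#   args = argparse.Namespace(
#       pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev",
#       revision=None,
#       variant=None,
#   )
#   cls_one = import_model_class_from_model_name_or_path(
#       args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder"
#   )
#   cls_two = import_model_class_from_model_name_or_path(
#       args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
#   )
#   text_encoder_one, text_encoder_two = load_text_encoders(args, cls_one, cls_two)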
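
if __name__ == "__main__":
    # Minimal smoke test for create_logger (illustrative; assumes `accelerate`
    # is installed). Under `accelerate launch`, only the main process logs to
    # the console and file; every other rank receives the silent logger.
    import tempfile

    from accelerate import Accelerator

    accelerator = Accelerator()
    with tempfile.TemporaryDirectory() as logging_dir:
        logger = create_logger(logging_dir, accelerator)
        logger.info("hello from the main process")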