| import torch |
| from torch import nn |
| from transformers import T5EncoderModel, CLIPModel, CLIPProcessor |
|
|
| from opensora.utils.utils import get_precision |
|
|
|
|
| class T5Wrapper(nn.Module): |
| def __init__(self, args): |
| super(T5Wrapper, self).__init__() |
| self.model_name = args.text_encoder_name |
| dtype = get_precision(args) |
| t5_model_kwargs = {'cache_dir': './cache_dir', 'low_cpu_mem_usage': True, 'torch_dtype': dtype} |
| self.text_enc = T5EncoderModel.from_pretrained(self.model_name, **t5_model_kwargs).eval() |
|
|
| def forward(self, input_ids, attention_mask): |
| text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'] |
| return text_encoder_embs.detach() |
|
|
| class CLIPWrapper(nn.Module): |
| def __init__(self, args): |
| super(CLIPWrapper, self).__init__() |
| self.model_name = args.text_encoder_name |
| dtype = get_precision(args) |
| model_kwargs = {'cache_dir': './cache_dir', 'low_cpu_mem_usage': True, 'torch_dtype': dtype} |
| self.text_enc = CLIPModel.from_pretrained(self.model_name, **model_kwargs).eval() |
|
|
| def forward(self, input_ids, attention_mask): |
| text_encoder_embs = self.text_enc.get_text_features(input_ids=input_ids, attention_mask=attention_mask) |
| return text_encoder_embs.detach() |
|
|
|
|
|
|
| text_encoder = { |
| 'DeepFloyd/t5-v1_1-xxl': T5Wrapper, |
| 'openai/clip-vit-large-patch14': CLIPWrapper |
| } |
|
|
|
|
| def get_text_enc(args): |
| """deprecation""" |
| text_enc = text_encoder.get(args.text_encoder_name, None) |
| assert text_enc is not None |
| return text_enc(args) |
|
|