| from torch import nn | |
| import transformers | |
| from .modeling_gpt2 import GPT2LMHeadModel | |
| from .configuration_gptvision import GPT2Config | |
| transformers.logging.set_verbosity_error() | |
class TextModel(nn.Module):
    """Wrapper around ``GPT2LMHeadModel`` that also exposes its input embeddings.

    Attributes set by ``__init__``:
        model: The ``GPT2LMHeadModel`` built from the GPT-2 configuration.
        text_emb: The object returned by ``model.get_input_embeddings()``.
    """

    def __init__(self, config) -> None:
        """Build the GPT-2 model described by ``config.gpt2_config``.

        Args:
            config: Object exposing a ``gpt2_config`` attribute. When that
                attribute is a ``dict`` (or a dict subclass) it is expanded
                as keyword arguments into ``GPT2Config``; any other value is
                passed to ``GPT2LMHeadModel`` unchanged.
        """
        super().__init__()

        gpt2_config = config.gpt2_config
        # isinstance (rather than an exact ``type(...) == dict`` comparison)
        # so that dict subclasses such as OrderedDict are also converted
        # instead of being handed to the model as a raw mapping.
        if isinstance(gpt2_config, dict):
            gpt2_config = GPT2Config(**gpt2_config)

        self.model = GPT2LMHeadModel(gpt2_config)
        self.text_emb = self.model.get_input_embeddings()