| import torch | |
| from torch import nn | |
| import transformers | |
| class ClipTokenizer: | |
| def __init__(self): | |
| super().__init__() | |
| self.tokenizer = transformers.CLIPTokenizer.from_pretrained( | |
| "openai/clip-vit-base-patch32" | |
| ) | |
| def __call__(self, instructions): | |
| return self.tokenizer( | |
| instructions, | |
| padding="longest", | |
| return_tensors="pt" | |
| )["input_ids"] | |
| class ClipTextEncoder(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.model = transformers.CLIPTextModel.from_pretrained( | |
| "openai/clip-vit-base-patch32" | |
| ) | |
| def forward(self, tokens): | |
| return self.model(tokens).last_hidden_state | |