File size: 751 Bytes
5ce8761 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | import torch
from torch import nn
import transformers
class ClipTokenizer:
    """Turn text instructions into CLIP token-id tensors.

    Thin wrapper around a Hugging Face ``CLIPTokenizer`` loaded from
    ``model_name``. Use the same checkpoint for the text encoder that
    consumes these ids.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        self.tokenizer = transformers.CLIPTokenizer.from_pretrained(model_name)

    def __call__(self, instructions):
        """Tokenize ``instructions`` and return the ``input_ids`` tensor.

        Args:
            instructions: text to tokenize, passed straight to the Hugging
                Face tokenizer (a string or a list of strings).

        Returns:
            A ``torch.Tensor`` of token ids, padded to the longest sequence in
            the batch and truncated to the tokenizer's ``model_max_length``.
        """
        # Deliberately NOT wrapped in torch.inference_mode(): tensors created
        # under inference mode are "inference tensors" that autograd cannot
        # save for backward, so passing them to an encoder whose weights
        # require grad would raise. Tokenization does no autograd-tracked
        # work, so no grad context is needed at all.
        return self.tokenizer(
            instructions,
            padding="longest",
            # The CLIP text model cannot embed sequences longer than its
            # position-embedding table; truncate instead of failing later.
            truncation=True,
            return_tensors="pt",
        )["input_ids"]
class ClipTextEncoder(nn.Module):
    """Encode CLIP token ids into per-token text features.

    Wraps a Hugging Face ``CLIPTextModel`` loaded from ``model_name``. The
    checkpoint should match the tokenizer that produced the input ids.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        super().__init__()
        self.model = transformers.CLIPTextModel.from_pretrained(model_name)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Run the text model on ``tokens`` and return its last hidden state.

        Args:
            tokens: token-id tensor, presumably ``(batch, seq_len)`` as produced
                by a CLIP tokenizer (TODO: confirm against callers).

        Returns:
            The model output's ``last_hidden_state`` (per-token features).
        """
        return self.model(tokens).last_hidden_state
|