import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Frozen CLIP text encoder (Hugging Face transformers) used as a text conditioner."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        # Which representation to return: "last" (final hidden states),
        # "pooled" (pooled EOS embedding), or "hidden" (hidden_states[layer_idx]).
        self.layer = layer
        self.layer_idx = layer_idx
        if self.layer == "hidden":
            assert layer_idx is not None, "layer_idx must be set when layer='hidden'"
        if freeze:
            # Put the text encoder in eval mode and disable gradients so it stays frozen.
            self.transformer = self.transformer.eval()
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, text):
        # Tokenize to fixed-length token id tensors (padded/truncated to max_length).
        enc = self.tokenizer(
            text, truncation=True, max_length=self.max_length,
            return_length=True, return_overflowing_tokens=False,
            padding="max_length", return_tensors="pt"
        )
        # Move token ids to whatever device the transformer's weights live on.
        tokens = enc["input_ids"].to(next(self.transformer.parameters()).device)
        # Only request all hidden states when an intermediate layer is wanted.
        out = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            return out.last_hidden_state
        if self.layer == "pooled":
            # Insert a sequence dimension so the shape is (batch, 1, dim).
            return out.pooler_output[:, None, :]
        return out.hidden_states[self.layer_idx]

    def encode(self, text):
        return self(text)
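

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: it assumes the
    # "openai/clip-vit-large-patch14" weights can be fetched via transformers.
    embedder = FrozenCLIPEmbedder()
    with torch.no_grad():
        emb = embedder.encode(["a photograph of an astronaut riding a horse"])
    # With layer="last", the expected shape is (batch=1, max_length=77, hidden_dim=768).
    print(emb.shape)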