File size: 1,393 Bytes
b6acc0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Encode text prompts with a Hugging Face CLIP text encoder (frozen by default).

    Args:
        version: Model name or path passed to ``from_pretrained`` for both the
            tokenizer and the text model.
        device: Stored as ``self.device``. Note that ``forward`` does not read it;
            tokens are sent to the device of the transformer's parameters instead.
        max_length: Token length every prompt is padded / truncated to.
        freeze: If true, put the transformer in eval mode and set
            ``requires_grad = False`` on every parameter of this module.
        layer: Which encoder output ``forward`` returns: ``"last"`` (the last
            hidden state), ``"pooled"`` (the pooled output with a singleton axis
            inserted at dim 1), or ``"hidden"`` (``hidden_states[layer_idx]``).
        layer_idx: Index into the model's ``hidden_states`` tuple. Required when
            ``layer == "hidden"``; ignored otherwise.

    Raises:
        ValueError: If ``layer`` is not in ``LAYERS``, or if ``layer == "hidden"``
            and ``layer_idx`` is missing or out of range for the loaded model.
    """

    # Output selectors accepted by the ``layer`` argument.
    LAYERS = ("last", "pooled", "hidden")

    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None):
        super().__init__()
        # Reject bad selectors up front, before any weights are downloaded or
        # loaded. Previously these only surfaced as a TypeError inside forward().
        if layer not in self.LAYERS:
            raise ValueError(f"layer must be one of {self.LAYERS}, got {layer!r}")
        if layer == "hidden" and layer_idx is None:
            raise ValueError('layer_idx is required when layer == "hidden"')

        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)

        if layer == "hidden":
            # Per the transformers docs, hidden_states has one entry for the
            # embedding output plus one per encoder layer.
            n_states = self.transformer.config.num_hidden_layers + 1
            if not -n_states <= layer_idx < n_states:
                raise ValueError(
                    f"layer_idx must be in [{-n_states}, {n_states - 1}], got {layer_idx!r}"
                )

        self.device = device
        self.max_length = max_length
        self.layer = layer
        self.layer_idx = layer_idx
        if freeze:
            self.transformer = self.transformer.eval()
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoder output selected by ``self.layer``.

        Args:
            text: Prompt(s) accepted by the CLIP tokenizer (e.g. a string or a
                list of strings).

        Returns:
            ``last_hidden_state`` for ``"last"``; ``pooler_output[:, None, :]`` for
            ``"pooled"``; ``hidden_states[self.layer_idx]`` for ``"hidden"``.
        """
        enc = self.tokenizer(
            text, truncation=True, max_length=self.max_length,
            return_length=True, return_overflowing_tokens=False,
            padding="max_length", return_tensors="pt"
        )
        # Follow the model's actual device rather than self.device, so tokens
        # land wherever the weights were moved (e.g. via .to(...)).
        tokens = enc["input_ids"].to(next(self.transformer.parameters()).device)
        # Only request every intermediate state when it will actually be used.
        out = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            return out.last_hidden_state
        if self.layer == "pooled":
            # Insert a singleton axis at dim 1 so the result is 3-D like the other modes.
            return out.pooler_output[:, None, :]
        # __init__ guarantees the only remaining option is "hidden".
        return out.hidden_states[self.layer_idx]

    def encode(self, text):
        """Alias for calling the module: ``self(text)``."""
        return self(text)