File size: 1,482 Bytes
d295ca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""CLIP text encoder for AeroGen. Uses transformers only (no ldm)."""

import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel


class AeroGenCLIPTextEncoder(nn.Module):
    """Frozen CLIP text encoder exposing a FrozenCLIPEmbedder-style interface.

    Wraps a ``transformers`` ``CLIPTokenizer`` / ``CLIPTextModel`` pair so
    that no ``ldm`` import is required. All parameters are frozen at
    construction time and the text model is switched to eval mode.
    """

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        device: str = "cuda",
        max_length: int = 77,
    ):
        """Load the pretrained tokenizer and text model, then freeze them.

        Args:
            version: Identifier handed to ``from_pretrained`` for both the
                tokenizer and the text model.
            device: Stored on the instance as ``self.device``. The
                constructor does not move the model; ``forward`` follows
                the device the parameters actually live on.
            max_length: Length every prompt is padded / truncated to
                during tokenization.
        """
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        """Switch the text model to eval mode and disable all gradients."""
        # ``eval()`` mutates the module in place and returns it, so the
        # attribute already refers to the eval-mode model.
        self.transformer.eval()
        # Turns off ``requires_grad`` on every parameter of this module.
        self.requires_grad_(False)

    def forward(self, text):
        """Tokenize ``text`` and return the text model's last hidden state.

        Args:
            text: A single prompt string, or a batch of prompts that is
                passed to the tokenizer unchanged.

        Returns:
            ``last_hidden_state`` of the text model output for the
            tokenized batch (presumably ``(batch, max_length, hidden)``;
            the exact shape is determined by the wrapped model).
        """
        # A lone string is treated as a batch of one.
        prompts = [text] if isinstance(text, str) else text
        encoded = self.tokenizer(
            prompts,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # Follow wherever the weights currently are instead of trusting the
        # ``device`` string recorded at construction time.
        target_device = next(self.parameters()).device
        input_ids = encoded["input_ids"].to(target_device)
        model_output = self.transformer(input_ids=input_ids)
        return model_output.last_hidden_state

    def encode(self, text):
        """Alias for calling the module directly on ``text``."""
        return self(text)