| """CLIP text encoder - same interface as FrozenCLIPEmbedder (forward(text) returns last_hidden_state).""" | |
| import torch | |
| import torch.nn as nn | |
| from transformers import CLIPTokenizer, CLIPTextModel | |
| class CLIPTextEncoder(nn.Module): | |
| """CLIP text encoder wrapping transformers CLIPTokenizer + CLIPTextModel. | |
| Same interface as FrozenCLIPEmbedder: forward(text) returns last_hidden_state. | |
| """ | |
    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        max_length: int = 77,
        freeze: bool = True,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.max_length = max_length
        if freeze:
            # Put the text model in eval mode and disable gradients so the
            # encoder stays fixed during training, matching FrozenCLIPEmbedder.
            self.transformer.eval()
            for param in self.parameters():
                param.requires_grad = False
    def forward(self, text) -> torch.Tensor:
        """Encode a string or list of strings; returns last_hidden_state (B, seq_len, dim)."""
        if isinstance(text, str):
            text = [text]
        # Tokenize, truncating and padding every prompt to exactly max_length tokens.
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Move the token ids onto the same device as the model weights.
        tokens = batch_encoding["input_ids"].to(next(self.parameters()).device)
        outputs = self.transformer(input_ids=tokens)
        return outputs.last_hidden_state
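

# Minimal usage sketch (an illustration, not part of the module's API). It
# assumes the pretrained weights can be fetched from the Hugging Face hub;
# for "openai/clip-vit-large-patch14" with max_length=77 the hidden size is
# 768, so a single prompt encodes to a (1, 77, 768) tensor.
if __name__ == "__main__":
    encoder = CLIPTextEncoder()
    embeddings = encoder("a photograph of an astronaut riding a horse")
    print(embeddings.shape)  # torch.Size([1, 77, 768])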