"""CLIP text encoder - same interface as FrozenCLIPEmbedder (forward(text) returns last_hidden_state)."""

import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel


class CLIPTextEncoder(nn.Module):
    """CLIP text encoder wrapping transformers CLIPTokenizer + CLIPTextModel.

    Same interface as FrozenCLIPEmbedder: forward(text) returns last_hidden_state.

    Args:
        version: Model id or path passed to both ``CLIPTokenizer.from_pretrained``
            and ``CLIPTextModel.from_pretrained``.
        max_length: Token length every prompt is truncated / padded to.
        freeze: If True, disable gradients for all parameters and keep the
            transformer in eval mode, including across later ``.train()`` calls.
    """

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        max_length: int = 77,
        freeze: bool = True,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.max_length = max_length
        # Remembered so that train() can keep a frozen transformer in eval mode.
        self.frozen = freeze
        if freeze:
            self.transformer.eval()
            for param in self.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True) -> "CLIPTextEncoder":
        """Set training mode, but keep the transformer in eval mode when frozen.

        ``nn.Module.train()`` recurses into every child module, so without this
        override a ``.train()`` call on this encoder (or on any model containing
        it) would silently switch the frozen transformer back to training mode.
        ``eval()`` routes through here as ``train(False)``.
        """
        super().train(mode)
        if self.frozen:
            self.transformer.eval()
        return self

    def forward(self, text):
        """Encode text.

        Args:
            text: A single prompt string, or a batch (list) of prompt strings.

        Returns:
            ``last_hidden_state`` of the CLIP text model, shape
            (B, max_length, dim). The sequence length always equals
            ``self.max_length`` because prompts are truncated and padded to it.
        """
        if isinstance(text, str):
            text = [text]
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Token ids are created on CPU; move them to the device of the weights.
        device = next(self.parameters()).device
        tokens = batch_encoding["input_ids"].to(device)
        # No attention_mask is passed, so padding positions are encoded as well.
        outputs = self.transformer(input_ids=tokens)
        return outputs.last_hidden_state