import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Wraps the CLIP text transformer as a text embedder, frozen by default."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu",
                 max_length=77, freeze=True, layer="last", layer_idx=None):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.layer = layer  # one of "last", "pooled", or "hidden"
        self.layer_idx = layer_idx
        if layer == "hidden":
            # Guard against indexing hidden_states with a missing index in forward().
            assert layer_idx is not None, "layer_idx is required when layer='hidden'"
        if freeze:
            self.transformer = self.transformer.eval()
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, text):
        # Tokenize to a fixed length, truncating long prompts and padding short ones.
        enc = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # Move token ids to whichever device the transformer's weights live on.
        tokens = enc["input_ids"].to(next(self.transformer.parameters()).device)
        # Only request all hidden states when an intermediate layer is wanted.
        out = self.transformer(
            input_ids=tokens,
            output_hidden_states=(self.layer == "hidden"),
        )
        if self.layer == "last":
            # Per-token embeddings from the final layer: (batch, max_length, dim).
            return out.last_hidden_state
        if self.layer == "pooled":
            # Pooled sentence embedding with a singleton sequence axis: (batch, 1, dim).
            return out.pooler_output[:, None, :]
        # Per-token embeddings from the intermediate layer selected by layer_idx.
        return out.hidden_states[self.layer_idx]

    def encode(self, text):
        return self(text)
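

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of encoding a prompt with the frozen encoder. The sample
# prompt is an assumption; the expected shape assumes the default
# openai/clip-vit-large-patch14 checkpoint (77 tokens, 768-dim embeddings).
if __name__ == "__main__":
    embedder = FrozenCLIPEmbedder()  # downloads the checkpoint on first use
    with torch.no_grad():
        emb = embedder.encode(["a photograph of an astronaut riding a horse"])
    print(emb.shape)  # torch.Size([1, 77, 768])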