"""CLIP text encoder for AeroGen. Uses transformers only (no ldm)."""
from typing import List, Union

import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel


class AeroGenCLIPTextEncoder(nn.Module):
    """CLIP text encoder compatible with FrozenCLIPEmbedder interface.

    Uses transformers CLIPTextModel + CLIPTokenizer. No ldm dependency.
    All parameters are frozen at construction time, and the wrapped
    transformer is kept in eval mode even if ``train()`` is later called on
    this module or on a module that contains it.
    """

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        device: str = "cuda",
        max_length: int = 77,
    ):
        """Load the tokenizer and text model for ``version`` and freeze them.

        Args:
            version: Model id or local path passed to ``from_pretrained`` for
                both the tokenizer and the text model.
            device: Stored as ``self.device`` for interface compatibility only.
                It is NOT used to move the model; ``forward`` places the token
                ids on whatever device the parameters currently live on.
            max_length: Every prompt is padded/truncated to this many tokens.
        """
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable gradients on all parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        """Set this module's training flag but keep the frozen transformer in eval mode.

        ``nn.Module.train`` recurses into child modules, so without this
        override a parent ``model.train()`` would silently flip the frozen
        CLIP transformer back into training mode (re-enabling any dropout).

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.transformer.eval()
        return self

    def forward(self, text: Union[str, List[str]]) -> torch.Tensor:
        """Tokenize and encode one prompt or a batch of prompts.

        Args:
            text: A single prompt string (wrapped into a batch of one) or a
                list of prompt strings.

        Returns:
            ``last_hidden_state`` of the CLIP text model for the token ids,
            each padded/truncated to ``self.max_length`` tokens.
        """
        if isinstance(text, str):
            text = [text]
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # Follow the parameters' actual device (not ``self.device``) so the
        # encoder keeps working after ``.to(...)`` / ``.cuda()`` on the module.
        device = next(self.parameters()).device
        tokens = batch_encoding["input_ids"].to(device)
        outputs = self.transformer(input_ids=tokens)
        return outputs.last_hidden_state

    def encode(self, text: Union[str, List[str]]) -> torch.Tensor:
        """Alias for calling the module (FrozenCLIPEmbedder-style entry point)."""
        return self(text)