import logging from collections import OrderedDict from typing import List, Union import torch from torch import nn from .clip_wrapper import build_clip_text_encoder, get_clip_embeddings from .clip_wrapper_open import build_openclip_text_encoder, get_openclip_embeddings class TextModel(nn.Module): def __init__( self, model_type, model_name, model_path, ): super().__init__() self.model_type = model_type self.model_name = model_name self.model_path = model_path if self.model_type == "CLIP": self.model = build_clip_text_encoder(model_path, pretrain=True) if self.model_type == "OPENCLIP": self.model, self.tokenizer = build_openclip_text_encoder(model_name, model_path) self.model.eval() def forward_text(self, text, prompt="a "): if self.model_type == "CLIP": return get_clip_embeddings(self.model, text, prompt) if self.model_type == "OPENCLIP": return get_openclip_embeddings(self.model, self.tokenizer, text, prompt)