DenseLabelDev / third_parts /APE /ape /modeling /text /text_encoder.py
zhouyik's picture
Upload folder using huggingface_hub
032e687 verified
import logging
from collections import OrderedDict
from typing import List, Union
import torch
from torch import nn
from .clip_wrapper import build_clip_text_encoder, get_clip_embeddings
from .clip_wrapper_open import build_openclip_text_encoder, get_openclip_embeddings
class TextModel(nn.Module):
def __init__(
self,
model_type,
model_name,
model_path,
):
super().__init__()
self.model_type = model_type
self.model_name = model_name
self.model_path = model_path
if self.model_type == "CLIP":
self.model = build_clip_text_encoder(model_path, pretrain=True)
if self.model_type == "OPENCLIP":
self.model, self.tokenizer = build_openclip_text_encoder(model_name, model_path)
self.model.eval()
def forward_text(self, text, prompt="a "):
if self.model_type == "CLIP":
return get_clip_embeddings(self.model, text, prompt)
if self.model_type == "OPENCLIP":
return get_openclip_embeddings(self.model, self.tokenizer, text, prompt)