File size: 1,098 Bytes
032e687 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | import logging
from collections import OrderedDict
from typing import List, Union
import torch
from torch import nn
from .clip_wrapper import build_clip_text_encoder, get_clip_embeddings
from .clip_wrapper_open import build_openclip_text_encoder, get_openclip_embeddings
class TextModel(nn.Module):
    """Wrap a CLIP or OpenCLIP text encoder behind one interface.

    The backend is chosen once, at construction time, from ``model_type``.
    ``forward_text`` then dispatches to the matching embedding helper.
    """

    # Values accepted for ``model_type``.
    SUPPORTED_MODEL_TYPES = ("CLIP", "OPENCLIP")

    def __init__(
        self,
        model_type: str,
        model_name,
        model_path,
    ):
        """Build the text encoder for the requested backend.

        Args:
            model_type: ``"CLIP"`` or ``"OPENCLIP"``.
            model_name: Passed as the first argument to
                ``build_openclip_text_encoder``. The ``"CLIP"`` backend does
                not use it (it is only stored on the instance).
            model_path: Passed to the backend builder: the first argument of
                ``build_clip_text_encoder`` and the second argument of
                ``build_openclip_text_encoder``.

        Raises:
            ValueError: If ``model_type`` is not in ``SUPPORTED_MODEL_TYPES``.
        """
        super().__init__()
        self.model_type = model_type
        self.model_name = model_name
        self.model_path = model_path
        if model_type == "CLIP":
            self.model = build_clip_text_encoder(model_path, pretrain=True)
        elif model_type == "OPENCLIP":
            # The OpenCLIP builder returns the encoder together with the
            # tokenizer that forward_text needs.
            self.model, self.tokenizer = build_openclip_text_encoder(
                model_name, model_path
            )
        else:
            # Fail fast. Otherwise ``self.model`` is never assigned and the
            # error only shows up as an opaque AttributeError below.
            raise ValueError(
                f"Unsupported model_type {model_type!r}; "
                f"expected one of {self.SUPPORTED_MODEL_TYPES}"
            )
        # Switch the wrapped encoder to inference mode (affects layers such as
        # dropout). Note that calling .train() on this wrapper would switch the
        # child module back to training mode.
        self.model.eval()

    def forward_text(self, text, prompt: str = "a "):
        """Embed ``text`` with the configured backend.

        ``text`` and ``prompt`` are forwarded unchanged to
        ``get_clip_embeddings`` / ``get_openclip_embeddings``. Presumably the
        prompt is prepended to each text; TODO confirm against those helpers.

        Returns:
            Whatever the backend helper returns (presumably a tensor of text
            embeddings; confirm against the helper).

        Raises:
            ValueError: If ``self.model_type`` is not a supported backend.
        """
        if self.model_type == "CLIP":
            return get_clip_embeddings(self.model, text, prompt)
        if self.model_type == "OPENCLIP":
            return get_openclip_embeddings(self.model, self.tokenizer, text, prompt)
        # Previously this fell through and silently returned None.
        raise ValueError(f"Unsupported model_type {self.model_type!r}")
|