YuanTang96's picture
1
b30c1d8
import open_clip
from .point_encoder import PointcloudEncoder
import timm
def build_text_encoder(
    clip_model_type="EVA02-E-14-plus",
    pretrained="./pretrained_weight/clip_used_in_Uni3D/open_clip_pytorch_model.bin",
):
    """Create the CLIP model used as the text encoder.

    Args:
        clip_model_type: Model name passed to open_clip as ``model_name``.
        pretrained: Value passed to open_clip as ``pretrained``. The default
            is a path relative to the current working directory, so pass an
            explicit path when running from a different directory.

    Returns:
        The first element of the 3-tuple returned by
        ``open_clip.create_model_and_transforms``; the other two elements
        are discarded.
    """
    clip_model, _, _ = open_clip.create_model_and_transforms(
        model_name=clip_model_type,
        pretrained=pretrained,
    )
    return clip_model
def build_pc_encoder(args):
    """Build the point cloud encoder around a timm transformer backbone.

    The backbone variant is selected by ``args.pc_encoder_type``; when that
    attribute is absent, ``"small"`` is used. As a side effect,
    ``args.pc_feat_dim`` is set to the value paired with the selected
    variant before ``args`` is passed on to ``PointcloudEncoder``.

    Args:
        args: Mutable configuration object. ``pc_encoder_type`` is read if
            present and ``pc_feat_dim`` is written.

    Returns:
        The ``PointcloudEncoder`` constructed from the timm model and
        ``args``.

    Raises:
        ValueError: If ``args.pc_encoder_type`` is not a supported variant
            name. Raised before any model is created or ``args`` is
            modified.
    """
    # Variant name -> (timm model name, value stored in args.pc_feat_dim).
    variants = {
        "giant": ("eva_giant_patch14_560", 1408),
        "large": ("eva02_large_patch14_448", 1024),
        "base": ("eva02_base_patch14_448", 768),
        "small": ("eva02_small_patch14_224", 384),
        "tiny": ("eva02_tiny_patch14_224", 192),
    }
    pretrained_pc = ''
    drop_path_rate = 0.0

    pc_encoder_type = getattr(args, 'pc_encoder_type', 'small')
    try:
        pc_model, pc_feat_dim = variants[pc_encoder_type]
    except (KeyError, TypeError):
        # Fail early with a clear message instead of hitting an unbound
        # ``pc_model`` at the timm call below.
        supported = ", ".join(repr(name) for name in variants)
        raise ValueError(
            f"Unsupported pc_encoder_type {pc_encoder_type!r}; "
            f"expected one of: {supported}"
        ) from None
    args.pc_feat_dim = pc_feat_dim

    # create transformer blocks for point cloud via timm
    # NOTE(review): the checkpoint path is empty, so presumably no checkpoint
    # is loaded at this point -- confirm against timm.create_model.
    point_transformer = timm.create_model(
        pc_model,
        checkpoint_path=pretrained_pc,
        drop_path_rate=drop_path_rate,
    )

    # create whole point cloud encoder
    point_encoder = PointcloudEncoder(point_transformer, args)
    return point_encoder