pookiefoof's picture
Public release: SkinTokens · TokenRig demo
9d7cf7f
raw
history blame contribute delete
670 Bytes
from typing import Optional

import torch

from .skin_cvae_model import SkinCVAEModel
from .skin_fsq_cvae_model import SkinFSQCVAEModel
def get_model_cvae(
    pretrained_path: Optional[str] = None,
    **kwargs,
) -> SkinCVAEModel:
    """Build a ``SkinCVAEModel`` and optionally load pretrained weights.

    Args:
        pretrained_path: Path to a file holding a saved ``state_dict``.
            When ``None`` (the default), the freshly constructed model is
            returned without loading anything.
        **kwargs: Forwarded unchanged to the ``SkinCVAEModel`` constructor.

    Returns:
        The constructed model, with weights loaded from ``pretrained_path``
        when one is given.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or shapes do not match the model (strict loading).
    """
    model = SkinCVAEModel(**kwargs)
    if pretrained_path is not None:
        # weights_only=True limits unpickling to tensors and plain containers.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts; load_state_dict then copies the tensors into the
        # model's own parameters wherever they live.
        state_dict = torch.load(
            pretrained_path, map_location="cpu", weights_only=True
        )
        model.load_state_dict(state_dict)
    return model
def get_model_fsq_cvae(
    pretrained_path: Optional[str] = None,
    **kwargs,
) -> SkinFSQCVAEModel:
    """Build a ``SkinFSQCVAEModel`` and optionally load pretrained weights.

    Args:
        pretrained_path: Path to a file holding a saved ``state_dict``.
            When ``None`` (the default), the freshly constructed model is
            returned without loading anything.
        **kwargs: Forwarded unchanged to the ``SkinFSQCVAEModel`` constructor.

    Returns:
        The constructed model, with weights loaded from ``pretrained_path``
        when one is given.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or shapes do not match the model (strict loading).
    """
    model = SkinFSQCVAEModel(**kwargs)
    if pretrained_path is not None:
        # weights_only=True limits unpickling to tensors and plain containers.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts; load_state_dict then copies the tensors into the
        # model's own parameters wherever they live.
        state_dict = torch.load(
            pretrained_path, map_location="cpu", weights_only=True
        )
        model.load_state_dict(state_dict)
    return model