Ali Mohsin
feat: Add virtual try-on system components including DensePose, SMPL, and pix2pixHD models, rendering, and utilities.
5db43ff
import torch
def create_model(opt):
    """Build, initialize and return the model selected by ``opt.model``.

    The module implementing the requested model is imported lazily inside
    the matching branch, so only the chosen implementation is loaded. Each
    recognized model name provides a training class (used when
    ``opt.isTrain`` is truthy) and an inference class (used otherwise).
    Any unrecognized name falls back to ``UIModel``.

    Args:
        opt: Options object. This function reads ``model``, ``isTrain``,
            ``verbose``, ``gpu_ids`` and ``fp16`` from it and hands the
            whole object to ``model.initialize``.

    Returns:
        The initialized model, wrapped in ``torch.nn.DataParallel`` when
        ``opt.isTrain`` is truthy, ``opt.gpu_ids`` is non-empty and
        ``opt.fp16`` is falsy.
    """
    model_name = opt.model

    if model_name == 'pix2pixHD':
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        model = Pix2PixHDModel() if opt.isTrain else InferenceModel()
    elif model_name == 'ASAPNet':
        from .ASAPNet_model import ASAPNetModel, InferenceModel
        model = ASAPNetModel() if opt.isTrain else InferenceModel()
    elif model_name == 'ASAPNet_RGBA':
        from .ASAPNet_rgba import ASAPNet_RGBA, InferenceModel
        model = ASAPNet_RGBA() if opt.isTrain else InferenceModel()
    elif model_name == 'pix2pixHD_DUAL_RGBA':
        from .pix2pixHD_dual_rgba import Pix2PixHD_DUAL_RGBA, InferenceModel
        model = Pix2PixHD_DUAL_RGBA() if opt.isTrain else InferenceModel()
    elif model_name == 'pix2pixHD_RGBA':
        from .pix2pixHD_rgba import Pix2PixHD_RGBA, InferenceModel
        model = Pix2PixHD_RGBA() if opt.isTrain else InferenceModel()
    elif model_name == 'pix2pixHD_RNN_RGBA':
        from .pix2pixHD_rnn_rgba import Pix2PixHD_RNN_RGBA, InferenceModel
        model = Pix2PixHD_RNN_RGBA() if opt.isTrain else InferenceModel()
    elif model_name == 'pix2pixHD_mask':
        from .pix2pixHD_mask import Pix2PixHDModel_Mask, InferenceModel_Mask
        model = Pix2PixHDModel_Mask() if opt.isTrain else InferenceModel_Mask()
    elif model_name == 'pix2pixHD_align':
        from .pix2pixHD_align import (
            Pix2PixHDModel_Align,
            InferenceModel_Align,
        )
        model = (
            Pix2PixHDModel_Align() if opt.isTrain else InferenceModel_Align()
        )
    elif model_name == 'pix2pixHD_Inpaint':
        from .pix2pixHD_inpaint import (
            Pix2PixHDModel_Inpaint,
            InferenceModel_Inpaint,
        )
        model = (
            Pix2PixHDModel_Inpaint()
            if opt.isTrain
            else InferenceModel_Inpaint()
        )
    elif model_name == 'UTransformerModel':
        from .utransformer_model import UTransformerModel, InferenceModel
        model = UTransformerModel() if opt.isTrain else InferenceModel()
    else:
        # Every other model name resolves to the UI model.
        from .ui_model import UIModel
        model = UIModel()

    model.initialize(opt)

    if opt.verbose:
        print("model [%s] was created" % model.name())

    # Only the non-fp16 training path with at least one GPU id is wrapped
    # for multi-GPU execution; all other configurations return the bare model.
    wrap_in_data_parallel = (
        opt.isTrain and len(opt.gpu_ids) > 0 and not opt.fp16
    )
    if wrap_in_data_parallel:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model