# offline_stores_try_on/util/garment_heatmap.py
# Ali Mohsin
# feat: Add virtual try-on system components including DensePose, SMPL, and pix2pixHD models, rendering, and utilities.
# 5db43ff
import argparse
import os
#import pprint
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import tools.init_paths
from lib.config import cfg
from lib.config import update_config
#from core.loss import JointsMSELoss
#from core.function import validate
#from core.inference import get_max_preds
#from utils.utils import create_logger
#import dataset
#from datasets.mannequin import MannequinDataset
#import models
import lib.models as models
def parse_args():
    """Build the argument parser and return its default values.

    The parser is evaluated against an empty argument list, so the real
    command line (``sys.argv``) is never consulted; the returned namespace
    only carries the defaults declared below.

    Returns:
        argparse.Namespace: namespace with ``cfg``, ``opts``, ``modelDir``,
        ``logDir``, ``dataDir`` and ``prevModelDir`` attributes.
    """
    arg_parser = argparse.ArgumentParser(description='Train keypoints network')

    # general
    arg_parser.add_argument(
        '--cfg',
        type=str,
        default='experiments/deepfashion2/hrnet/top1_only.yaml',
        help='experiment configure file name',
    )
    arg_parser.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    # The directory overrides share the same type and (empty) default; they
    # are registered in the same order as before so help output is unchanged.
    directory_flags = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in directory_flags:
        arg_parser.add_argument(flag, help=description, type=str, default='')

    # Deliberately parse an empty list instead of sys.argv.
    return arg_parser.parse_args(args=[])
class HeatmapGenerator():
    """Callable wrapper around a pretrained pose network.

    On construction the checkpoint is located (and downloaded from Google
    Drive when missing), the network named by ``cfg.MODEL.NAME`` is built,
    its weights are loaded and it is switched to eval mode. Calling the
    instance forwards the input to that network unchanged.

    Args:
        ckpt: Path to the checkpoint file. ``None`` selects
            ``DEFAULT_CKPT``.
        useCUDA: When true and CUDA is available, the model is moved to the
            GPU. Inputs are NOT moved automatically; the caller must place
            ``img`` on the same device as the model.
    """

    # Checkpoint path used when the caller does not supply one.
    DEFAULT_CKPT = 'pretrained_models/pose_hrnet/model_best.pth'
    # Google Drive file id of the weights fetched when the checkpoint is missing.
    GDRIVE_FILE_ID = "1voMopc-Cuq1Uls0ZAS-sDTCrHJ1M8zJJ"

    def __init__(self, ckpt=None, useCUDA=True):
        self.useCUDA = useCUDA
        self.model = self._load_model(ckpt=ckpt)

    def __call__(self, img):
        """Run the wrapped network on ``img`` and return its raw output."""
        # NOTE(review): presumably the output is a batch of keypoint
        # heatmaps -- confirm against the model definition.
        return self.model(img)

    def _ensure_checkpoint(self, ckpt):
        """Download the pretrained weights to ``ckpt`` if it does not exist.

        Raises:
            FileNotFoundError: if the file is still absent after the
                download attempt.
        """
        if os.path.exists(ckpt):
            return

        print("Model not found, downloading from google drive...")
        # Imported lazily so gdown is only required when a download happens.
        import gdown

        directory = os.path.dirname(ckpt)
        # A bare filename has an empty directory part; os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)

        url = (
            "https://drive.google.com/uc?export=download&confirm=pbef&id="
            + self.GDRIVE_FILE_ID
        )
        gdown.download(url, ckpt)

        # Fail here with a clear message rather than later inside torch.load.
        if not os.path.exists(ckpt):
            raise FileNotFoundError(
                "Failed to download model checkpoint to '{}'".format(ckpt)
            )

    def _load_model(self, ckpt):
        """Build the network, load weights from ``ckpt`` and return it.

        Args:
            ckpt: Checkpoint path, or ``None`` for ``DEFAULT_CKPT``.

        Returns:
            The model in eval mode, on the GPU when ``self.useCUDA`` is set
            and CUDA is available, otherwise on the CPU.
        """
        if ckpt is None:
            ckpt = self.DEFAULT_CKPT
        self._ensure_checkpoint(ckpt)

        # Populate the global cfg from the default experiment file; no
        # command-line overrides are applied.
        args = parse_args()
        args.opts = []
        update_config(cfg, args)

        # cudnn related setting
        cudnn.benchmark = cfg.CUDNN.BENCHMARK
        cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
        cudnn.enabled = cfg.CUDNN.ENABLED

        # Plain attribute lookup instead of eval() on a config string;
        # cfg.MODEL.NAME is expected to be a module name inside lib.models.
        model_module = getattr(models, cfg.MODEL.NAME)
        model = model_module.get_pose_net(cfg, is_train=False)

        # Map tensors to CPU first so a checkpoint saved on a GPU still
        # loads on CPU-only hosts or when useCUDA is False.
        state_dict = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        if self.useCUDA and torch.cuda.is_available():
            model = model.cuda()
        return model