import numbers
import os
from argparse import ArgumentParser, Namespace

import mxnet as mx
import numpy as np

from ..app import MaskRenderer
from ..data.rec_builder import RecBuilder
from . import BaseInsightFaceCLICommand

# Length of every label written to the output rec:
# [id_label, mask_flag, *encoded_mask_params] = 1 + 1 + 235 = 237 values.
# mask_flag is 0.0 when mask params were built and -1.0 (like every param
# slot) when MaskRenderer.build_params returned None.
_LABEL_SIZE = 237


def rec_add_mask_param_command_factory(args: Namespace):
    """Build a ``RecAddMaskParamCommand`` from parsed CLI arguments."""
    return RecAddMaskParamCommand(
        args.input,
        args.output,
    )


class RecAddMaskParamCommand(BaseInsightFaceCLICommand):
    """``rec.addmaskparam``: copy a RecordIO dataset, extending each label
    with the mask-rendering parameters computed by ``MaskRenderer``.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``rec.addmaskparam`` subcommand on *parser*."""
        _parser = parser.add_parser("rec.addmaskparam")
        _parser.add_argument("input", type=str, help="input rec")
        _parser.add_argument("output", type=str, help="output rec, with mask param")
        _parser.set_defaults(func=rec_add_mask_param_command_factory)

    def __init__(
        self,
        input: str,
        output: str,
    ):
        """
        Args:
            input: directory containing ``train.rec`` / ``train.idx``.
            output: path handed to ``RecBuilder`` for the augmented rec.
        """
        self._input = input
        self._output = output

    @staticmethod
    def _record_indices(imgrec, header):
        """Return the record indices to process.

        When the index-0 header carries a 2-element label, ``label[0]`` is
        used as the exclusive end of the image-record range starting at 1.
        Otherwise every key of the indexed rec is used.
        """
        if header.flag > 0 and len(header.label) == 2:
            return np.array(range(1, int(header.label[0])))
        # Previously this read ``self.imgrec`` (never assigned) and crashed.
        return np.array(list(imgrec.keys))

    def run(self):
        """Compute mask params for every record and write the augmented rec.

        Records for which ``build_params`` returns ``None`` are still
        written, with the mask flag and all param slots set to -1.0.

        Raises:
            ValueError: if an encoded label does not have ``_LABEL_SIZE``
                entries (guards the output format; not stripped by ``-O``).
        """
        tool = MaskRenderer()
        tool.prepare(ctx_id=0, det_size=(128, 128))

        root_dir = self._input
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        try:
            wrec = RecBuilder(path=self._output)
            header, _ = mx.recordio.unpack(imgrec.read_idx(0))
            imgidx = self._record_indices(imgrec, header)

            n_failed = 0
            print('total:', len(imgidx))
            for iid, idx in enumerate(imgidx):
                if iid % 1000 == 0:
                    print('processing:', iid)
                header, img = mx.recordio.unpack(imgrec.read_idx(idx))
                label = header.label
                if not isinstance(label, numbers.Number):
                    # Multi-valued label: keep only the leading id label.
                    label = label[0]
                # imdecode yields RGB; the renderer is fed channel-reversed (BGR).
                bgr = mx.image.imdecode(img).asnumpy()[:, :, ::-1]
                params = tool.build_params(bgr)
                if params is None:
                    wlabel = [label] + [-1.0] * (_LABEL_SIZE - 1)
                    n_failed += 1
                else:
                    mask_label = tool.encode_params(params)
                    wlabel = [label, 0.0] + mask_label
                    if iid == 0:
                        print('param size:', len(mask_label), len(wlabel), label)
                if len(wlabel) != _LABEL_SIZE:
                    raise ValueError(
                        'unexpected label size %d (expected %d) for record %s'
                        % (len(wlabel), _LABEL_SIZE, idx)
                    )
                # The original (unmodified) encoded image bytes are copied through.
                wrec.add_image(img, wlabel)
            # NOTE(review): the writer is closed only on success; what
            # RecBuilder.close() does with partial output is not visible here.
            wrec.close()
        finally:
            imgrec.close()
        print('finished on', self._output, ', failed:', n_failed)