|
|
|
|
|
import numbers |
|
|
import os |
|
|
from argparse import ArgumentParser, Namespace |
|
|
|
|
|
import mxnet as mx |
|
|
import numpy as np |
|
|
|
|
|
from ..app import MaskRenderer |
|
|
from ..data.rec_builder import RecBuilder |
|
|
from . import BaseInsightFaceCLICommand |
|
|
|
|
|
|
|
|
def rec_add_mask_param_command_factory(args: Namespace):
    """Create a ``RecAddMaskParamCommand`` from parsed command-line arguments.

    Registered as the ``func`` default of the ``rec.addmaskparam`` subparser.

    Args:
        args: Parsed namespace carrying ``input`` and ``output`` attributes.

    Returns:
        A command instance configured with the input and output paths.
    """
    source_path = args.input
    target_path = args.output
    return RecAddMaskParamCommand(source_path, target_path)
|
|
|
|
|
|
|
|
class RecAddMaskParamCommand(BaseInsightFaceCLICommand):
    """CLI command ``rec.addmaskparam``: copy a RecordIO dataset with mask params in each label.

    Reads ``train.rec`` / ``train.idx`` from the input directory and writes
    every image (raw bytes, unchanged) to a ``RecBuilder`` at the output path.
    Each written label holds ``LABEL_SIZE`` values:

    * ``[label, 0.0, *mask_params]`` when ``MaskRenderer.build_params``
      returns parameters for the image;
    * ``[label, -1.0, ..., -1.0]`` when it returns ``None``.
    """

    # Number of values in every written label: the original label, then either
    # a 0.0 marker plus the encoded mask params, or -1.0 padding on failure.
    LABEL_SIZE = 237

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``rec.addmaskparam`` subcommand and its positional arguments."""
        _parser = parser.add_parser("rec.addmaskparam")
        _parser.add_argument("input", type=str, help="input rec")
        _parser.add_argument("output", type=str, help="output rec, with mask param")
        _parser.set_defaults(func=rec_add_mask_param_command_factory)

    def __init__(
        self,
        input: str,
        output: str,
    ):
        """Store the paths used by ``run``.

        Args:
            input: Directory containing ``train.rec`` and ``train.idx``.
            output: Path passed to ``RecBuilder`` for the new dataset.
        """
        self._input = input
        self._output = output

    def run(self):
        """Convert the whole input dataset and print a summary.

        Prints progress every 1000 records. At the end it prints how many records
        had no mask parameters, meaning ``build_params`` returned ``None``.
        Both the reader and the writer are closed even if conversion fails.
        """
        # Local import keeps this change self-contained; stdlib only.
        from contextlib import closing

        tool = MaskRenderer()
        # NOTE(review): ctx_id=0 is hard-coded; presumably selects the first
        # device for the renderer/detector — confirm against MaskRenderer.
        tool.prepare(ctx_id=0, det_size=(128, 128))

        path_imgrec = os.path.join(self._input, 'train.rec')
        path_imgidx = os.path.join(self._input, 'train.idx')

        with closing(mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')) as imgrec, \
                closing(RecBuilder(path=self._output)) as wrec:
            failed = self._convert(tool, imgrec, wrec)

        print('finished on', self._output, ', failed:', failed)

    @staticmethod
    def _record_indices(imgrec):
        """Return the record indices to convert, as a numpy array.

        If record 0 has a flagged header with a two-value label, the images are
        records ``1 .. label[0] - 1``. Otherwise every key in the index file is
        used.
        """
        header, _ = mx.recordio.unpack(imgrec.read_idx(0))
        if header.flag > 0 and len(header.label) == 2:
            return np.array(range(1, int(header.label[0])))
        # Bug fix: the original read ``self.imgrec.keys``, an attribute that
        # never exists; the open reader is the local ``imgrec``.
        return np.array(list(imgrec.keys))

    def _convert(self, tool, imgrec, wrec):
        """Copy every indexed record into ``wrec`` with an extended label.

        Args:
            tool: A prepared ``MaskRenderer``.
            imgrec: Open ``MXIndexedRecordIO`` reader for the input dataset.
            wrec: ``RecBuilder`` receiving the output records.

        Returns:
            The number of records for which ``build_params`` returned ``None``.

        Raises:
            ValueError: If an assembled label does not have ``LABEL_SIZE`` values.
        """
        imgidx = self._record_indices(imgrec)
        print('total:', len(imgidx))

        failed = 0
        for iid, idx in enumerate(imgidx):
            if iid % 1000 == 0:
                print('processing:', iid)

            header, img = mx.recordio.unpack(imgrec.read_idx(idx))
            label = header.label
            if not isinstance(label, numbers.Number):
                # Multi-value label: only its first entry is carried over.
                label = label[0]

            # imdecode's array is flipped on the channel axis before being
            # handed to the renderer (named ``bgr`` by the original author).
            bgr = mx.image.imdecode(img).asnumpy()[:, :, ::-1]
            params = tool.build_params(bgr)

            if params is None:
                wlabel = [label] + [-1.0] * (self.LABEL_SIZE - 1)
                failed += 1
            else:
                mask_label = tool.encode_params(params)
                # Unpacking accepts any iterable (list, tuple, array) of params.
                wlabel = [label, 0.0, *mask_label]
                if iid == 0:
                    print('param size:', len(mask_label), len(wlabel), label)

            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if len(wlabel) != self.LABEL_SIZE:
                raise ValueError(
                    f"expected {self.LABEL_SIZE} label values for record {idx}, "
                    f"got {len(wlabel)}"
                )
            # The original encoded image bytes are written unchanged.
            wrec.add_image(img, wlabel)

        return failed
|
|
|
|
|
|