| | """ |
| | Here are some use cases: |
| | python main.py --config config/all.yaml --experiment experiment_8x1 --signature demo1 --target data/demo1.png |
| | """ |
| | import pydiffvg |
| | import torch |
| | import cv2 |
| | import matplotlib.pyplot as plt |
| | import random |
| | import argparse |
| | import math |
| | import errno |
| | from tqdm import tqdm |
| | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR |
| | from torch.nn.functional import adaptive_avg_pool2d |
| | import warnings |
| | warnings.filterwarnings("ignore") |
| |
|
| | import PIL |
| | import PIL.Image |
| | import os |
| | import os.path as osp |
| | import numpy as np |
| | import numpy.random as npr |
| | import shutil |
| | import copy |
| | |
| | from xing_loss import xing_loss |
| |
|
| | import yaml |
| | from easydict import EasyDict as edict |
| |
|
| |
|
# Silence diffvg's per-render timing printouts.
pydiffvg.set_print_timing(False)
# Gamma passed to pydiffvg.imwrite whenever rendered images are saved.
gamma = 1.0
| |
|
| | |
| | |
| | |
| |
|
| | from utils import \ |
| | get_experiment_id, \ |
| | get_path_schedule, \ |
| | edict_2_dict, \ |
| | check_and_create_dir |
| |
|
def get_bezier_circle(radius=1, segments=4, bias=None):
    """Return control points of a closed cubic-Bezier approximation of a circle.

    Generates ``segments * 3`` points evenly spaced on the unit circle,
    scales them by ``radius`` and shifts them by ``bias``.  When ``bias``
    is not given, a random offset in [0, 1)^2 is drawn.

    Returns a float32 tensor of shape ``(segments * 3, 2)``.
    """
    if bias is None:
        bias = (random.random(), random.random())
    n_points = segments * 3
    step_degree = 360 / n_points
    angles = [np.deg2rad(k * step_degree) for k in range(n_points)]
    pts = torch.tensor([(np.cos(a), np.sin(a)) for a in angles])
    pts = pts * radius + torch.tensor(bias).unsqueeze(dim=0)
    return pts.type(torch.FloatTensor)
| |
|
def get_sdf(phi, method='skfmm', **kwargs):
    """Turn a soft occupancy map ``phi`` into a normalized distance map.

    ``phi`` is expected in [0, 1]; it is recentered to [-1, 1] and fed to
    ``skfmm.distance``.  Keyword options (all optional):
      flip_negative (default True)  -- take |distance| so both sides are positive
      truncate      (default 10)    -- clamp distances to [-truncate, truncate]
      zero2max      (default True)  -- invert so the zero level-set gets max weight
                                       (only valid together with flip_negative)
      normalize     ('sum' | 'to1') -- divide by the sum or by the maximum

    Returns an array shaped like ``phi``; all zeros when ``phi`` has no zero
    crossing.  Only ``method == 'skfmm'`` is implemented; other values fall
    through and return None.
    """
    if method == 'skfmm':
        import skfmm
        flip_negative = kwargs.get('flip_negative', True)
        truncate = kwargs.get('truncate', 10)
        zero2max = kwargs.get('zero2max', True)
        normalize = kwargs.get('normalize', 'sum')

        phi = (phi - 0.5) * 2
        # No zero crossing anywhere: the distance transform is undefined.
        if (phi.max() <= 0) or (phi.min() >= 0):
            return np.zeros(phi.shape).astype(np.float32)
        dist = skfmm.distance(phi, dx=1)

        if flip_negative:
            dist = np.abs(dist)

        dist = np.clip(dist, -truncate, truncate)

        if zero2max and flip_negative:
            dist = dist.max() - dist
        elif zero2max:
            # Inverting signed (non-flipped) distances is not meaningful here.
            raise ValueError

        if normalize == 'sum':
            dist /= dist.sum()
        elif normalize == 'to1':
            dist /= dist.max()
        return dist
| |
|
def parse_args():
    """Parse command-line options into an EasyDict configuration fragment.

    The returned dict carries the raw CLI values; ``--seginit`` (a list of
    tokens, e.g. ``circle 5``) is expanded into a nested ``seginit`` dict
    with ``type`` and, for circles, a float ``radius``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument("--config", type=str)
    parser.add_argument("--experiment", type=str)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--target", type=str, help="target image path")
    parser.add_argument('--log_dir', metavar='DIR', default="log/debug")
    parser.add_argument('--initial', type=str, default="random", choices=['random', 'circle'])
    parser.add_argument('--signature', nargs='+', type=str)
    parser.add_argument('--seginit', nargs='+', type=str)
    parser.add_argument("--num_segments", type=int, default=4)

    args = parser.parse_args()
    cfg = edict()
    for key in ('debug', 'config', 'experiment', 'seed', 'target',
                'log_dir', 'initial', 'signature', 'num_segments'):
        cfg[key] = getattr(args, key)
    # '--seginit TYPE [RADIUS]' -> nested seginit config (only when provided).
    if args.seginit is not None:
        cfg.seginit = edict()
        cfg.seginit.type = args.seginit[0]
        if cfg.seginit.type == 'circle':
            cfg.seginit.radius = float(args.seginit[1])
    return cfg
| |
|
def ycrcb_conversion(im, format='[bs x 3 x 2D]', reverse=False):
    """Apply the (scaled) RGB<->YCrCb linear color transform.

    format='[bs x 3 x 2D]' expects a (B, 3, H, W) tensor and returns the
    same layout; format='[2D x 3]' applies the matrix along the last axis.
    ``reverse=True`` uses the matrix inverse (YCrCb -> RGB).  Any other
    format raises ValueError.
    """
    transform = torch.FloatTensor([
        [ 65.481/255, 128.553/255, 24.966/255],
        [-37.797/255, -74.203/255, 112.000/255],
        [112.000/255, -93.786/255, -18.214/255],
    ]).to(im.device)

    if reverse:
        transform = transform.inverse()

    if format == '[bs x 3 x 2D]':
        # Move channels last, transform, and restore channel-first layout.
        out = torch.matmul(im.permute(0, 2, 3, 1), transform.T)
        return out.permute(0, 3, 1, 2).contiguous()
    if format == '[2D x 3]':
        return torch.matmul(im, transform.T)
    raise ValueError
| |
|
class random_coord_init():
    """Sample a uniformly random (x, y) coordinate inside the canvas."""

    def __init__(self, canvas_size):
        # canvas_size: (height, width)
        self.canvas_size = canvas_size

    def __call__(self):
        height, width = self.canvas_size
        # Returned as [x, y], i.e. width-coordinate first.
        return [npr.uniform(0, 1) * width, npr.uniform(0, 1) * height]
| |
|
class naive_coord_init():
    """Pick initialization coordinates at the pixel of largest pred/gt error.

    Builds a per-pixel squared-error map between ``pred`` and ``gt``
    (tensors are detached to numpy first).  Each call returns ``[x, y]``
    for the current maximum-error pixel; with ``replace_sampling`` the
    chosen pixel is marked used so subsequent calls pick new locations.
    """

    def __init__(self, pred, gt, format='[bs x c x 2D]', replace_sampling=True):
        if isinstance(pred, torch.Tensor):
            pred = pred.detach().cpu().numpy()
        if isinstance(gt, torch.Tensor):
            gt = gt.detach().cpu().numpy()

        if format == '[bs x c x 2D]':
            # (B, C, H, W): use the first batch element, sum error over channels.
            self.map = ((pred[0] - gt[0])**2).sum(0)
        elif format == '[2D x c]':
            # BUG FIX: was `format == ['[2D x c]']` (a string compared against
            # a one-element list, always False), so this branch was
            # unreachable and valid callers hit ValueError instead.
            self.map = ((pred - gt)**2).sum(-1)
        else:
            raise ValueError
        self.replace_sampling = replace_sampling

    def __call__(self):
        # First pixel attaining the maximum error.
        coord = np.where(self.map == self.map.max())
        coord_h, coord_w = coord[0][0], coord[1][0]
        if self.replace_sampling:
            # Mark as consumed so the next call selects a different pixel.
            self.map[coord_h, coord_w] = -1
        return [coord_w, coord_h]
| |
|
| |
|
class sparse_coord_init():
    """Pick initialization coordinates from the largest connected error region.

    The per-pixel pred/gt error is quantized into discrete levels
    (``quantile_interval`` quantiles, errors below ``nodiff_thres``
    zeroed).  Each call selects the currently dominant error level, finds
    its largest 4-connected component, and returns the component pixel
    closest to the component centroid as ``[x, y]``.  Consumed components
    are removed so later calls spread over the image; when nothing is
    left, a uniformly random coordinate is returned.
    """

    def __init__(self, pred, gt, format='[bs x c x 2D]', quantile_interval=200, nodiff_thres=0.1):
        if isinstance(pred, torch.Tensor):
            pred = pred.detach().cpu().numpy()
        if isinstance(gt, torch.Tensor):
            gt = gt.detach().cpu().numpy()
        if format == '[bs x c x 2D]':
            self.map = ((pred[0] - gt[0])**2).sum(0)
            self.reference_gt = copy.deepcopy(
                np.transpose(gt[0], (1, 2, 0)))
        elif format == '[2D x c]':
            # BUG FIX: was `format == ['[2D x c]']` (string vs. list, always
            # False) so this branch was unreachable.  Also keep the full
            # (H, W, C) ground truth here rather than only its first row
            # (was `gt[0]`), mirroring the HxWxC reference of the other branch.
            self.map = (np.abs(pred - gt)).sum(-1)
            self.reference_gt = copy.deepcopy(gt)
        else:
            raise ValueError

        # Zero-out negligible differences, then quantize the error map so
        # regions of similar error share one discrete id.
        self.map[self.map < nodiff_thres] = 0
        quantile_interval = np.linspace(0., 1., quantile_interval)
        quantized_interval = np.quantile(self.map, quantile_interval)
        # Drop duplicated quantiles and both extremes before binning.
        quantized_interval = np.unique(quantized_interval)
        quantized_interval = sorted(quantized_interval[1:-1])
        self.map = np.digitize(self.map, quantized_interval, right=False)
        self.map = np.clip(self.map, 0, 255).astype(np.uint8)
        # Pixel count per error level; the lowest level (background) is ignored.
        self.idcnt = {}
        for idi in sorted(np.unique(self.map)):
            self.idcnt[idi] = (self.map == idi).sum()
        self.idcnt.pop(min(self.idcnt.keys()))

    def __call__(self):
        if len(self.idcnt) == 0:
            # All error regions consumed: fall back to a random coordinate.
            h, w = self.map.shape
            return [npr.uniform(0, 1)*w, npr.uniform(0, 1)*h]
        target_id = max(self.idcnt, key=self.idcnt.get)
        _, component, cstats, ccenter = cv2.connectedComponentsWithStats(
            (self.map == target_id).astype(np.uint8), connectivity=4)
        # Largest connected component of the dominant level; index 0 is the
        # background label and is skipped.
        csize = [ci[-1] for ci in cstats[1:]]
        target_cid = csize.index(max(csize)) + 1
        center = ccenter[target_cid][::-1]
        coord = np.stack(np.where(component == target_cid)).T
        dist = np.linalg.norm(coord - center, axis=1)
        target_coord_id = np.argmin(dist)
        coord_h, coord_w = coord[target_coord_id]
        # Remove the consumed component from future consideration.
        self.idcnt[target_id] -= max(csize)
        if self.idcnt[target_id] == 0:
            self.idcnt.pop(target_id)
        self.map[component == target_cid] = 0
        return [coord_w, coord_h]
| |
|
| |
|
def init_shapes(num_paths,
                num_segments,
                canvas_size,
                seginit_cfg,
                shape_cnt,
                pos_init_method=None,
                trainable_stroke=False,
                gt=None,
                **kwargs):
    """Create `num_paths` closed cubic-Bezier paths and their shape groups.

    Each path is placed at a coordinate drawn from `pos_init_method`
    (uniform random over the canvas when None) and shaped either by a
    random walk of control points ("random") or as a Bezier circle
    ("circle"), per `seginit_cfg.type`.  Fill colors are sampled from `gt`
    at the path's seed pixel when `gt` is given, otherwise random RGBA.

    Args:
        num_paths: number of new paths to create.
        num_segments: cubic segments per path (2 control points each).
        canvas_size: (height, width) of the canvas.
        seginit_cfg: edict with `type` ("random"/"circle") and `radius`.
        shape_cnt: number of shapes already in the scene; new shape ids
            continue from here.
        pos_init_method: callable returning an [x, y] seed coordinate.
        trainable_stroke: also return trainable stroke width/color lists.
        gt: optional (1, C, H, W) tensor used to seed fill colors.

    Returns:
        (shapes, shape_groups, point_var, color_var) and, when
        `trainable_stroke`, additionally (stroke_width_var, stroke_color_var).
    """
    shapes = []
    shape_groups = []
    h, w = canvas_size

    # Default to uniform-random placement when no smarter initializer given.
    if pos_init_method is None:
        pos_init_method = random_coord_init(canvas_size=canvas_size)

    for i in range(num_paths):
        # Cubic Bezier: 2 control points between every pair of anchors.
        num_control_points = [2] * num_segments

        if seginit_cfg.type=="random":
            # Random walk: each segment's points jitter around the previous
            # endpoint by up to +-radius/2 per axis.
            points = []
            p0 = pos_init_method()
            color_ref = copy.deepcopy(p0)
            points.append(p0)
            for j in range(num_segments):
                radius = seginit_cfg.radius
                p1 = (p0[0] + radius * npr.uniform(-0.5, 0.5),
                      p0[1] + radius * npr.uniform(-0.5, 0.5))
                p2 = (p1[0] + radius * npr.uniform(-0.5, 0.5),
                      p1[1] + radius * npr.uniform(-0.5, 0.5))
                p3 = (p2[0] + radius * npr.uniform(-0.5, 0.5),
                      p2[1] + radius * npr.uniform(-0.5, 0.5))
                points.append(p1)
                points.append(p2)
                # The final anchor is omitted: the path is closed, so the
                # last segment ends back at the first point.
                if j < num_segments - 1:
                    points.append(p3)
                p0 = p3
            points = torch.FloatTensor(points)

        elif seginit_cfg.type=="circle":
            radius = seginit_cfg.radius
            if radius is None:
                radius = npr.uniform(0.5, 1)
            center = pos_init_method()
            color_ref = copy.deepcopy(center)
            points = get_bezier_circle(
                radius=radius, segments=num_segments,
                bias=center)

        # NOTE(review): if seginit_cfg.type is neither "random" nor "circle",
        # `points`/`color_ref` are unbound and the line below raises
        # NameError — confirm allowed types are validated upstream.
        path = pydiffvg.Path(num_control_points = torch.LongTensor(num_control_points),
                             points = points,
                             stroke_width = torch.tensor(0.0),
                             is_closed = True)
        shapes.append(path)

        if gt is not None:
            # Seed the fill color from the ground-truth pixel at the path's
            # origin (clamped to the canvas), with full opacity.
            wref, href = color_ref
            wref = max(0, min(int(wref), w-1))
            href = max(0, min(int(href), h-1))
            fill_color_init = list(gt[0, :, href, wref]) + [1.]
            fill_color_init = torch.FloatTensor(fill_color_init)
            stroke_color_init = torch.FloatTensor(npr.uniform(size=[4]))
        else:
            fill_color_init = torch.FloatTensor(npr.uniform(size=[4]))
            stroke_color_init = torch.FloatTensor(npr.uniform(size=[4]))

        path_group = pydiffvg.ShapeGroup(
            shape_ids = torch.LongTensor([shape_cnt+i]),
            fill_color = fill_color_init,
            stroke_color = stroke_color_init,
        )
        shape_groups.append(path_group)

    # Collect the tensors that the optimizer will update.
    point_var = []
    color_var = []

    for path in shapes:
        path.points.requires_grad = True
        point_var.append(path.points)
    for group in shape_groups:
        group.fill_color.requires_grad = True
        color_var.append(group.fill_color)

    if trainable_stroke:
        stroke_width_var = []
        stroke_color_var = []
        for path in shapes:
            path.stroke_width.requires_grad = True
            stroke_width_var.append(path.stroke_width)
        for group in shape_groups:
            group.stroke_color.requires_grad = True
            stroke_color_var.append(group.stroke_color)
        return shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var
    else:
        return shapes, shape_groups, point_var, color_var
| |
|
class linear_decay_lrlambda_f(object):
    """Piecewise-linear learning-rate multiplier for ``LambdaLR``.

    Within the k-th window of ``decay_every`` steps the multiplier
    interpolates linearly from ``decay_ratio**k`` down to
    ``decay_ratio**(k+1)``.
    """

    def __init__(self, decay_every, decay_ratio):
        self.decay_every = decay_every
        self.decay_ratio = decay_ratio

    def __call__(self, n):
        window, offset = divmod(n, self.decay_every)
        start = self.decay_ratio ** window
        end = self.decay_ratio ** (window + 1)
        frac = offset / self.decay_every
        return start * (1 - frac) + end * frac
| |
|
def main_func(target, experiment, num_iter, cfg_arg):
    """Vectorize ``target`` (a PIL image / array) into an optimized SVG.

    Loads the YAML config selected by ``cfg_arg``, then progressively adds
    batches of Bezier paths per the path schedule, optimizing point and
    color parameters against the raster target with an optional
    SDF-weighted MSE/L1 loss and xing (self-intersection) regularizer.

    Args:
        target: image accepted by ``np.array`` (e.g. a PIL image).
        experiment: experiment name recorded into the config.
        num_iter: optimization iterations per path batch.
        cfg_arg: edict of CLI options (see ``parse_args``).

    Returns:
        (final rendered image as a numpy array, path of the saved SVG —
        empty string when ``cfg.save.output`` is disabled).
    """
    # ---- configuration -------------------------------------------------
    with open(cfg_arg.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg_default = edict(cfg['default'])
    cfg = edict(cfg[cfg_arg.experiment])
    cfg.update(cfg_default)
    cfg.update(cfg_arg)
    cfg.exid = get_experiment_id(cfg.debug)

    cfg.experiment_dir = \
        osp.join(cfg.log_dir, '{}_{}'.format(cfg.exid, '_'.join(cfg.signature)))
    cfg.target = target
    cfg.experiment = experiment
    cfg.num_iter = num_iter

    configfile = osp.join(cfg.experiment_dir, 'config.yaml')
    check_and_create_dir(configfile)
    with open(osp.join(configfile), 'w') as f:
        yaml.dump(edict_2_dict(cfg), f)

    pydiffvg.set_use_gpu(torch.cuda.is_available())
    device = pydiffvg.get_device()

    # ---- target image --------------------------------------------------
    gt = np.array(cfg.target)
    print(f"Input image shape is: {gt.shape}")
    if len(gt.shape) == 2:
        print("Converting the gray-scale image to RGB.")
        # BUG FIX: `gt` is a numpy array here; the previous
        # `gt.unsqueeze(dim=-1).repeat(1, 1, 3)` used the torch API and
        # raised AttributeError for gray-scale inputs.
        gt = np.expand_dims(gt, axis=-1).repeat(3, axis=-1)
    if gt.shape[2] == 4:
        print("Input image includes alpha channel, simply dropout alpha channel.")
        gt = gt[:, :, :3]
    gt = (gt/255).astype(np.float32)
    gt = torch.FloatTensor(gt).permute(2, 0, 1)[None].to(device)
    if cfg.use_ycrcb:
        gt = ycrcb_conversion(gt)
    h, w = gt.shape[2:]

    path_schedule = get_path_schedule(**cfg.path_schedule)

    if cfg.seed is not None:
        random.seed(cfg.seed)
        npr.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
    render = pydiffvg.RenderFunction.apply

    shapes_record, shape_groups_record = [], []

    region_loss = None
    loss_matrix = []

    para_point, para_color = {}, {}
    if cfg.trainable.stroke:
        para_stroke_width, para_stroke_color = {}, {}

    pathn_record = []
    # Background color: trainable, or fixed white (Y-channel white in YCrCb).
    if cfg.trainable.bg:
        para_bg = torch.tensor([1., 1., 1.], requires_grad=True, device=device)
    else:
        if cfg.use_ycrcb:
            para_bg = torch.tensor([219/255, 0, 0], requires_grad=False, device=device)
        else:
            para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=device)

    # ---- coordinate initializer ---------------------------------------
    loss_weight = None
    loss_weight_keep = 0
    if cfg.coord_init.type == 'naive':
        pos_init_method = naive_coord_init(
            para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
    elif cfg.coord_init.type == 'sparse':
        pos_init_method = sparse_coord_init(
            para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
    elif cfg.coord_init.type == 'random':
        pos_init_method = random_coord_init([h, w])
    else:
        raise ValueError

    lrlambda_f = linear_decay_lrlambda_f(cfg.num_iter, 0.4)
    optim_schedular_dict = {}

    # ---- progressive path optimization --------------------------------
    for path_idx, pathn in enumerate(path_schedule):
        loss_list = []
        print("=> Adding [{}] paths, [{}] ...".format(pathn, cfg.seginit.type))
        pathn_record.append(pathn)
        pathn_record_str = '-'.join([str(i) for i in pathn_record])

        if cfg.trainable.stroke:
            shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var = init_shapes(
                pathn, cfg.num_segments, (h, w),
                cfg.seginit, len(shapes_record),
                pos_init_method,
                trainable_stroke=True,
                gt=gt, )
            para_stroke_width[path_idx] = stroke_width_var
            para_stroke_color[path_idx] = stroke_color_var
        else:
            shapes, shape_groups, point_var, color_var = init_shapes(
                pathn, cfg.num_segments, (h, w),
                cfg.seginit, len(shapes_record),
                pos_init_method,
                trainable_stroke=False,
                gt=gt, )

        shapes_record += shapes
        shape_groups_record += shape_groups

        if cfg.save.init:
            filename = os.path.join(
                cfg.experiment_dir, "svg-init",
                "{}-init.svg".format(pathn_record_str))
            check_and_create_dir(filename)
            pydiffvg.save_svg(
                filename, w, h,
                shapes_record, shape_groups_record)

        # Parameter groups for this batch (bg only on the first batch).
        para = {}
        if (cfg.trainable.bg) and (path_idx == 0):
            para['bg'] = [para_bg]
        para['point'] = point_var
        para['color'] = color_var
        if cfg.trainable.stroke:
            para['stroke_width'] = stroke_width_var
            para['stroke_color'] = stroke_color_var

        pg = [{'params' : para[ki], 'lr' : cfg.lr_base[ki]} for ki in sorted(para.keys())]
        optim = torch.optim.Adam(pg)

        if cfg.trainable.record:
            scheduler = LambdaLR(
                optim, lr_lambda=lrlambda_f, last_epoch=-1)
        else:
            # Earlier batches frozen: start this scheduler fully decayed.
            scheduler = LambdaLR(
                optim, lr_lambda=lrlambda_f, last_epoch=cfg.num_iter)
        optim_schedular_dict[path_idx] = (optim, scheduler)

        t_range = tqdm(range(cfg.num_iter))
        for t in t_range:

            for _, (optim, _) in optim_schedular_dict.items():
                optim.zero_grad()

            scene_args = pydiffvg.RenderFunction.serialize_scene(
                w, h, shapes_record, shape_groups_record)
            img = render(w, h, 2, 2, t, None, *scene_args)

            # Composite over the background using the alpha channel.
            img = img[:, :, 3:4] * img[:, :, :3] + \
                para_bg * (1 - img[:, :, 3:4])

            if cfg.save.video:
                filename = os.path.join(
                    cfg.experiment_dir, "video-png",
                    "{}-iter{}.png".format(pathn_record_str, t))
                check_and_create_dir(filename)
                if cfg.use_ycrcb:
                    imshow = ycrcb_conversion(
                        img, format='[2D x 3]', reverse=True).detach().cpu()
                else:
                    imshow = img.detach().cpu()
                pydiffvg.imwrite(imshow, filename, gamma=gamma)

            x = img.unsqueeze(0).permute(0, 3, 1, 2)

            # Per-pixel reconstruction loss (MSE, optionally reweighted
            # for YCrCb channel ranges, or L1 when configured).
            if cfg.use_ycrcb:
                color_reweight = torch.FloatTensor([255/219, 255/224, 255/255]).to(device)
                loss = ((x-gt)*(color_reweight.view(1, -1, 1, 1)))**2
            else:
                loss = ((x-gt)**2)

            if cfg.loss.use_l1_loss:
                loss = abs(x-gt)

            if cfg.loss.use_distance_weighted_loss:
                if cfg.use_ycrcb:
                    raise ValueError
                # Render the current batch's shapes as a solid mask and turn
                # it into an SDF-based per-pixel weight map.
                shapes_forsdf = copy.deepcopy(shapes)
                shape_groups_forsdf = copy.deepcopy(shape_groups)
                for si in shapes_forsdf:
                    si.stroke_width = torch.FloatTensor([0]).to(device)
                for sg_idx, sgi in enumerate(shape_groups_forsdf):
                    sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(device)
                    sgi.shape_ids = torch.LongTensor([sg_idx]).to(device)

                sargs_forsdf = pydiffvg.RenderFunction.serialize_scene(
                    w, h, shapes_forsdf, shape_groups_forsdf)
                with torch.no_grad():
                    im_forsdf = render(w, h, 2, 2, 0, None, *sargs_forsdf)
                im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy()
                loss_weight = get_sdf(im_forsdf, normalize='to1')
                loss_weight += loss_weight_keep
                loss_weight = np.clip(loss_weight, 0, 1)
                loss_weight = torch.FloatTensor(loss_weight).to(device)

            if cfg.save.loss:
                # Dump normalized visualizations of the raw loss, the SDF
                # weight, and their product.  NOTE(review): this block reads
                # `loss_weight` and effectively requires
                # use_distance_weighted_loss to be enabled as well.
                save_loss = loss.squeeze(dim=0).mean(dim=0,keepdim=False).cpu().detach().numpy()
                save_weight = loss_weight.cpu().detach().numpy()
                save_weighted_loss = save_loss*save_weight
                save_loss = (save_loss - np.min(save_loss))/np.ptp(save_loss)
                save_weight = (save_weight - np.min(save_weight))/np.ptp(save_weight)
                save_weighted_loss = (save_weighted_loss - np.min(save_weighted_loss))/np.ptp(save_weighted_loss)

                plt.imshow(save_loss, cmap='Reds')
                plt.axis('off')
                filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-mseloss.png".format(pathn_record_str, t))
                check_and_create_dir(filename)
                plt.savefig(filename, dpi=800)
                plt.close()

                plt.imshow(save_weight, cmap='Greys')
                plt.axis('off')
                filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-sdfweight.png".format(pathn_record_str, t))
                plt.savefig(filename, dpi=800)
                plt.close()

                plt.imshow(save_weighted_loss, cmap='Reds')
                plt.axis('off')
                filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-weightedloss.png".format(pathn_record_str, t))
                plt.savefig(filename, dpi=800)
                plt.close()

            if loss_weight is None:
                loss = loss.sum(1).mean()
            else:
                loss = (loss.sum(1)*loss_weight).mean()

            if (cfg.loss.xing_loss_weight is not None) \
                    and (cfg.loss.xing_loss_weight > 0):
                loss_xing = xing_loss(point_var) * cfg.loss.xing_loss_weight
                loss = loss + loss_xing

            loss_list.append(loss.item())
            t_range.set_postfix({'loss': loss.item()})
            loss.backward()

            for _, (optim, scheduler) in optim_schedular_dict.items():
                optim.step()
                scheduler.step()

            # Keep colors in the valid [0, 1] range after the step.
            for group in shape_groups_record:
                group.fill_color.data.clamp_(0.0, 1.0)

        if cfg.loss.use_distance_weighted_loss:
            # Carry the accumulated weight map into the next path batch.
            loss_weight_keep = loss_weight.detach().cpu().numpy() * 1

        if not cfg.trainable.record:
            # Freeze this batch's parameters before adding the next batch.
            # BUG FIX: `pg` is a *list* of optimizer param-group dicts, so the
            # previous `pg.items()` raised AttributeError; additionally the
            # loop set a misspelled `require_grad` on the group instead of
            # `requires_grad` on each parameter tensor.
            for pgroup in pg:
                for ppi in pgroup['params']:
                    ppi.requires_grad = False
            optim_schedular_dict = {}

        if cfg.save.image:
            filename = os.path.join(
                cfg.experiment_dir, "demo-png", "{}.png".format(pathn_record_str))
            check_and_create_dir(filename)
            if cfg.use_ycrcb:
                imshow = ycrcb_conversion(
                    img, format='[2D x 3]', reverse=True).detach().cpu()
            else:
                imshow = img.detach().cpu()
            pydiffvg.imwrite(imshow, filename, gamma=gamma)

        svg_app_file_name = ""
        if cfg.save.output:
            filename = os.path.join(
                cfg.experiment_dir, "output-svg", "{}.svg".format(pathn_record_str))
            check_and_create_dir(filename)
            pydiffvg.save_svg(filename, w, h, shapes_record, shape_groups_record)
            svg_app_file_name = filename

        loss_matrix.append(loss_list)

        # Re-seed the coordinate initializer from the current residual so
        # the next batch targets remaining errors.
        pos_init_method = naive_coord_init(x, gt)

        if cfg.coord_init.type == 'naive':
            pos_init_method = naive_coord_init(x, gt)
        elif cfg.coord_init.type == 'sparse':
            pos_init_method = sparse_coord_init(x, gt)
        elif cfg.coord_init.type == 'random':
            pos_init_method = random_coord_init([h, w])
        else:
            raise ValueError

        if cfg.save.video:
            print("saving iteration video...")
            img_array = []
            for ii in range(0, cfg.num_iter):
                filename = os.path.join(
                    cfg.experiment_dir, "video-png",
                    "{}-iter{}.png".format(pathn_record_str, ii))
                img = cv2.imread(filename)
                img_array.append(img)

            videoname = os.path.join(
                cfg.experiment_dir, "video-avi",
                "{}.avi".format(pathn_record_str))
            check_and_create_dir(videoname)
            out = cv2.VideoWriter(
                videoname,
                cv2.VideoWriter_fourcc(*'FFV1'),
                20.0, (w, h))
            for iii in range(len(img_array)):
                out.write(img_array[iii])
            out.release()

    print("The last loss is: {}".format(loss.item()))
    return img.detach().cpu().numpy(), svg_app_file_name
| |
|
| |
|
if __name__ == "__main__":
    # Standalone script entry point: same pipeline as main_func, driven
    # directly by the CLI options (target image loaded from disk).

    # ---- configuration -------------------------------------------------
    cfg_arg = parse_args()
    with open(cfg_arg.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg_default = edict(cfg['default'])
    cfg = edict(cfg[cfg_arg.experiment])
    cfg.update(cfg_default)
    cfg.update(cfg_arg)
    cfg.exid = get_experiment_id(cfg.debug)

    cfg.experiment_dir = \
        osp.join(cfg.log_dir, '{}_{}'.format(cfg.exid, '_'.join(cfg.signature)))
    configfile = osp.join(cfg.experiment_dir, 'config.yaml')
    check_and_create_dir(configfile)
    with open(osp.join(configfile), 'w') as f:
        yaml.dump(edict_2_dict(cfg), f)

    pydiffvg.set_use_gpu(torch.cuda.is_available())
    device = pydiffvg.get_device()

    # ---- target image --------------------------------------------------
    gt = np.array(PIL.Image.open(cfg.target))
    print(f"Input image shape is: {gt.shape}")
    if len(gt.shape) == 2:
        print("Converting the gray-scale image to RGB.")
        # BUG FIX: `gt` is a numpy array here; the previous
        # `gt.unsqueeze(dim=-1).repeat(1, 1, 3)` used the torch API and
        # raised AttributeError for gray-scale inputs.
        gt = np.expand_dims(gt, axis=-1).repeat(3, axis=-1)
    if gt.shape[2] == 4:
        print("Input image includes alpha channel, simply dropout alpha channel.")
        gt = gt[:, :, :3]
    gt = (gt/255).astype(np.float32)
    gt = torch.FloatTensor(gt).permute(2, 0, 1)[None].to(device)
    if cfg.use_ycrcb:
        gt = ycrcb_conversion(gt)
    h, w = gt.shape[2:]

    path_schedule = get_path_schedule(**cfg.path_schedule)

    if cfg.seed is not None:
        random.seed(cfg.seed)
        npr.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
    render = pydiffvg.RenderFunction.apply

    shapes_record, shape_groups_record = [], []

    region_loss = None
    loss_matrix = []

    para_point, para_color = {}, {}
    if cfg.trainable.stroke:
        para_stroke_width, para_stroke_color = {}, {}

    pathn_record = []
    # Background color: trainable, or fixed white (Y-channel white in YCrCb).
    if cfg.trainable.bg:
        para_bg = torch.tensor([1., 1., 1.], requires_grad=True, device=device)
    else:
        if cfg.use_ycrcb:
            para_bg = torch.tensor([219/255, 0, 0], requires_grad=False, device=device)
        else:
            para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=device)

    # ---- coordinate initializer ---------------------------------------
    loss_weight = None
    loss_weight_keep = 0
    if cfg.coord_init.type == 'naive':
        pos_init_method = naive_coord_init(
            para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
    elif cfg.coord_init.type == 'sparse':
        pos_init_method = sparse_coord_init(
            para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
    elif cfg.coord_init.type == 'random':
        pos_init_method = random_coord_init([h, w])
    else:
        raise ValueError

    lrlambda_f = linear_decay_lrlambda_f(cfg.num_iter, 0.4)
    optim_schedular_dict = {}

    # ---- progressive path optimization --------------------------------
    for path_idx, pathn in enumerate(path_schedule):
        loss_list = []
        print("=> Adding [{}] paths, [{}] ...".format(pathn, cfg.seginit.type))
        pathn_record.append(pathn)
        pathn_record_str = '-'.join([str(i) for i in pathn_record])

        if cfg.trainable.stroke:
            shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var = init_shapes(
                pathn, cfg.num_segments, (h, w),
                cfg.seginit, len(shapes_record),
                pos_init_method,
                trainable_stroke=True,
                gt=gt, )
            para_stroke_width[path_idx] = stroke_width_var
            para_stroke_color[path_idx] = stroke_color_var
        else:
            shapes, shape_groups, point_var, color_var = init_shapes(
                pathn, cfg.num_segments, (h, w),
                cfg.seginit, len(shapes_record),
                pos_init_method,
                trainable_stroke=False,
                gt=gt, )

        shapes_record += shapes
        shape_groups_record += shape_groups

        if cfg.save.init:
            filename = os.path.join(
                cfg.experiment_dir, "svg-init",
                "{}-init.svg".format(pathn_record_str))
            check_and_create_dir(filename)
            pydiffvg.save_svg(
                filename, w, h,
                shapes_record, shape_groups_record)

        # Parameter groups for this batch (bg only on the first batch).
        para = {}
        if (cfg.trainable.bg) and (path_idx == 0):
            para['bg'] = [para_bg]
        para['point'] = point_var
        para['color'] = color_var
        if cfg.trainable.stroke:
            para['stroke_width'] = stroke_width_var
            para['stroke_color'] = stroke_color_var

        pg = [{'params' : para[ki], 'lr' : cfg.lr_base[ki]} for ki in sorted(para.keys())]
        optim = torch.optim.Adam(pg)

        if cfg.trainable.record:
            scheduler = LambdaLR(
                optim, lr_lambda=lrlambda_f, last_epoch=-1)
        else:
            # Earlier batches frozen: start this scheduler fully decayed.
            scheduler = LambdaLR(
                optim, lr_lambda=lrlambda_f, last_epoch=cfg.num_iter)
        optim_schedular_dict[path_idx] = (optim, scheduler)

        t_range = tqdm(range(cfg.num_iter))
        for t in t_range:

            for _, (optim, _) in optim_schedular_dict.items():
                optim.zero_grad()

            scene_args = pydiffvg.RenderFunction.serialize_scene(
                w, h, shapes_record, shape_groups_record)
            img = render(w, h, 2, 2, t, None, *scene_args)

            # Composite over the background using the alpha channel.
            img = img[:, :, 3:4] * img[:, :, :3] + \
                para_bg * (1 - img[:, :, 3:4])

            if cfg.save.video:
                filename = os.path.join(
                    cfg.experiment_dir, "video-png",
                    "{}-iter{}.png".format(pathn_record_str, t))
                check_and_create_dir(filename)
                if cfg.use_ycrcb:
                    imshow = ycrcb_conversion(
                        img, format='[2D x 3]', reverse=True).detach().cpu()
                else:
                    imshow = img.detach().cpu()
                pydiffvg.imwrite(imshow, filename, gamma=gamma)

            x = img.unsqueeze(0).permute(0, 3, 1, 2)

            # Per-pixel reconstruction loss (MSE, optionally reweighted
            # for YCrCb channel ranges, or L1 when configured).
            if cfg.use_ycrcb:
                color_reweight = torch.FloatTensor([255/219, 255/224, 255/255]).to(device)
                loss = ((x-gt)*(color_reweight.view(1, -1, 1, 1)))**2
            else:
                loss = ((x-gt)**2)

            if cfg.loss.use_l1_loss:
                loss = abs(x-gt)

            if cfg.loss.use_distance_weighted_loss:
                if cfg.use_ycrcb:
                    raise ValueError
                # Render the current batch's shapes as a solid mask and turn
                # it into an SDF-based per-pixel weight map.
                shapes_forsdf = copy.deepcopy(shapes)
                shape_groups_forsdf = copy.deepcopy(shape_groups)
                for si in shapes_forsdf:
                    si.stroke_width = torch.FloatTensor([0]).to(device)
                for sg_idx, sgi in enumerate(shape_groups_forsdf):
                    sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(device)
                    sgi.shape_ids = torch.LongTensor([sg_idx]).to(device)

                sargs_forsdf = pydiffvg.RenderFunction.serialize_scene(
                    w, h, shapes_forsdf, shape_groups_forsdf)
                with torch.no_grad():
                    im_forsdf = render(w, h, 2, 2, 0, None, *sargs_forsdf)
                im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy()
                loss_weight = get_sdf(im_forsdf, normalize='to1')
                loss_weight += loss_weight_keep
                loss_weight = np.clip(loss_weight, 0, 1)
                loss_weight = torch.FloatTensor(loss_weight).to(device)

            if cfg.save.loss:
                # Dump normalized visualizations of the raw loss, the SDF
                # weight, and their product.  NOTE(review): this block reads
                # `loss_weight` and effectively requires
                # use_distance_weighted_loss to be enabled as well.
                save_loss = loss.squeeze(dim=0).mean(dim=0,keepdim=False).cpu().detach().numpy()
                save_weight = loss_weight.cpu().detach().numpy()
                save_weighted_loss = save_loss*save_weight
                save_loss = (save_loss - np.min(save_loss))/np.ptp(save_loss)
                save_weight = (save_weight - np.min(save_weight))/np.ptp(save_weight)
                save_weighted_loss = (save_weighted_loss - np.min(save_weighted_loss))/np.ptp(save_weighted_loss)

                plt.imshow(save_loss, cmap='Reds')
                plt.axis('off')
                filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-mseloss.png".format(pathn_record_str, t))
                check_and_create_dir(filename)
                plt.savefig(filename, dpi=800)
                plt.close()

                plt.imshow(save_weight, cmap='Greys')
                plt.axis('off')
                filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-sdfweight.png".format(pathn_record_str, t))
                plt.savefig(filename, dpi=800)
                plt.close()

                plt.imshow(save_weighted_loss, cmap='Reds')
                plt.axis('off')
                filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-weightedloss.png".format(pathn_record_str, t))
                plt.savefig(filename, dpi=800)
                plt.close()

            if loss_weight is None:
                loss = loss.sum(1).mean()
            else:
                loss = (loss.sum(1)*loss_weight).mean()

            if (cfg.loss.xing_loss_weight is not None) \
                    and (cfg.loss.xing_loss_weight > 0):
                loss_xing = xing_loss(point_var) * cfg.loss.xing_loss_weight
                loss = loss + loss_xing

            loss_list.append(loss.item())
            t_range.set_postfix({'loss': loss.item()})
            loss.backward()

            for _, (optim, scheduler) in optim_schedular_dict.items():
                optim.step()
                scheduler.step()

            # Keep colors in the valid [0, 1] range after the step.
            for group in shape_groups_record:
                group.fill_color.data.clamp_(0.0, 1.0)

        if cfg.loss.use_distance_weighted_loss:
            # Carry the accumulated weight map into the next path batch.
            loss_weight_keep = loss_weight.detach().cpu().numpy() * 1

        if not cfg.trainable.record:
            # Freeze this batch's parameters before adding the next batch.
            # BUG FIX: `pg` is a *list* of optimizer param-group dicts, so the
            # previous `pg.items()` raised AttributeError; additionally the
            # loop set a misspelled `require_grad` on the group instead of
            # `requires_grad` on each parameter tensor.
            for pgroup in pg:
                for ppi in pgroup['params']:
                    ppi.requires_grad = False
            optim_schedular_dict = {}

        if cfg.save.image:
            filename = os.path.join(
                cfg.experiment_dir, "demo-png", "{}.png".format(pathn_record_str))
            check_and_create_dir(filename)
            if cfg.use_ycrcb:
                imshow = ycrcb_conversion(
                    img, format='[2D x 3]', reverse=True).detach().cpu()
            else:
                imshow = img.detach().cpu()
            pydiffvg.imwrite(imshow, filename, gamma=gamma)

        if cfg.save.output:
            filename = os.path.join(
                cfg.experiment_dir, "output-svg", "{}.svg".format(pathn_record_str))
            check_and_create_dir(filename)
            pydiffvg.save_svg(filename, w, h, shapes_record, shape_groups_record)

        loss_matrix.append(loss_list)

        # Re-seed the coordinate initializer from the current residual so
        # the next batch targets remaining errors.
        pos_init_method = naive_coord_init(x, gt)

        if cfg.coord_init.type == 'naive':
            pos_init_method = naive_coord_init(x, gt)
        elif cfg.coord_init.type == 'sparse':
            pos_init_method = sparse_coord_init(x, gt)
        elif cfg.coord_init.type == 'random':
            pos_init_method = random_coord_init([h, w])
        else:
            raise ValueError

        if cfg.save.video:
            print("saving iteration video...")
            img_array = []
            for ii in range(0, cfg.num_iter):
                filename = os.path.join(
                    cfg.experiment_dir, "video-png",
                    "{}-iter{}.png".format(pathn_record_str, ii))
                img = cv2.imread(filename)
                img_array.append(img)

            videoname = os.path.join(
                cfg.experiment_dir, "video-avi",
                "{}.avi".format(pathn_record_str))
            check_and_create_dir(videoname)
            out = cv2.VideoWriter(
                videoname,
                cv2.VideoWriter_fourcc(*'FFV1'),
                20.0, (w, h))
            for iii in range(len(img_array)):
                out.write(img_array[iii])
            out.release()

    print("The last loss is: {}".format(loss.item()))
| |
|