Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision.utils import save_image | |
| from dataloader import get_loader | |
| from models.model_main import ModelMain | |
| from models.transformers import denumericalize | |
| from options import get_parser_main_model | |
| from data_utils.svg_utils import render | |
| from models.util_funcs import svg2img, cal_iou | |
| from tqdm import tqdm | |
| from PIL import Image | |
_GLYPHS_PER_ROW = 13  # glyphs per row in the merged HTML preview before a '<br>'


def _make_results_dir(opts):
    """Return (and create) the root directory that per-font results go under.

    When ``opts.dir_res`` is set, ``<dir_res>/results`` is wiped and recreated so
    each run starts empty. Otherwise results go under
    ``<exp_path>/experiments/<name_exp>/results``, created if missing.
    """
    if opts.dir_res:
        dir_res = os.path.join(opts.dir_res, "results")
        if os.path.exists(dir_res):
            shutil.rmtree(dir_res)
    else:
        dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
    # makedirs rather than mkdir: parent directories are not guaranteed to exist.
    os.makedirs(dir_res, exist_ok=True)
    return dir_res


def _write_svg_file(seq, path):
    """Render one decoded sequence with ``render`` and write it to ``path``.

    Rendering is best-effort: if ``render`` raises, an empty file is still
    created (so later lookups of ``path`` do not fail) and False is returned.

    Returns:
        True if the SVG was rendered and written, False otherwise.
    """
    try:
        svg = render(seq.cpu().numpy())
    except Exception:  # a malformed sampled sequence must not abort the whole run
        svg = None
    with open(path, 'w') as svg_file:
        if svg is not None:
            svg_file.write(svg)
    return svg is not None


def _write_merge_cell(merge_file, idx, svg):
    """Append one SVG to the merged HTML and break the row after every 13th glyph."""
    merge_file.write(svg)
    if idx % _GLYPHS_PER_ROW == _GLYPHS_PER_ROW - 1:
        merge_file.write('<br>')


def _render_gt_svgs(trg_seq_gt):
    """Yield rendered SVG strings for each ground-truth sequence in ``trg_seq_gt``.

    The first feature of each step is one-hot encoded into 4 classes. The
    remaining features are passed through ``denumericalize``, and the two parts
    are concatenated before rendering.
    """
    svg_target = trg_seq_gt.clone().detach()
    # Index column 0 directly instead of slicing [:1] and calling .squeeze(),
    # which also collapsed unrelated size-1 dims (e.g. a single character).
    tgt_commands_onehot = F.one_hot(svg_target[:, :, 0].long(), 4)
    tgt_args_denum = denumericalize(svg_target[:, :, 1:])
    svg_target = torch.cat([tgt_commands_onehot, tgt_args_denum], dim=-1)
    for one_gt_seq in svg_target:
        yield render(one_gt_seq.cpu().numpy())


def test_main_model(opts):
    """Sample glyphs from the trained main model and write images/SVGs to disk.

    For every batch of the 'test' loader (directory ``%04d`` under the results
    root), the model is run ``opts.n_samples`` times. Each sample writes:
      * ``imgs/merge_<img_size>.png``: target and output images concatenated
        along dim -2 (overwritten by every sample);
      * ``imgs/<cc>_gt.png`` and ``imgs/<cc>_<img_size>.png`` per character;
      * ``svgs_single/syn_<cc>_<sample>_wo_refine.svg`` (from 'sampled_1') and
        ``syn_<cc>_<sample>_refined.svg`` (from 'sampled_2').
    Each refined SVG is rasterized with ``svg2img`` and scored with ``cal_iou``
    against the saved model image for that character. The highest-IoU sample
    per character is concatenated, followed by the ground-truth SVGs, into
    ``svgs_merge/<name_ckpt>_syn_merge_<batch>.html``.

    Args:
        opts: parsed command-line options (see ``get_parser_main_model``).

    Returns:
        The last merged preview image (PIL image) when ``opts.streamlit`` is
        set, otherwise None.
    """
    if opts.streamlit:
        import streamlit as st

        def set_img(key: str, img: Image.Image):
            st.session_state[key] = img

    dir_res = _make_results_dir(opts)
    test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num,
                             opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Inference With Device:", device)
    if opts.streamlit:
        st.write("Loading Model Weight...")
        st.write("Inference With Device:", device)

    model_main = ModelMain(opts)
    path_ckpt = f"{opts.model_path}"
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints (consider weights_only=True where the torch version allows it).
    model_main.load_state_dict(torch.load(path_ckpt, map_location=device)['model'])
    model_main.to(device)
    model_main.eval()

    im = None  # only assigned in streamlit mode; previously unbound otherwise
    with torch.no_grad():
        for test_idx, test_data in enumerate(test_loader):
            for key in test_data:
                test_data[key] = test_data[key].to(device)
            print("testing font %04d ..." % test_idx)
            dir_save = os.path.join(dir_res, "%04d" % test_idx)
            dir_imgs = os.path.join(dir_save, "imgs")
            dir_single = os.path.join(dir_save, "svgs_single")
            svg_merge_dir = os.path.join(dir_save, "svgs_merge")
            for sub_dir in (dir_imgs, dir_single, svg_merge_dir):
                os.makedirs(sub_dir, exist_ok=True)

            # Best refined sample per character. Start at -inf so the first
            # successfully rendered sample always wins, even with IoU 0.
            iou_max = np.full(opts.char_num, -np.inf)
            idx_best_sample = np.zeros(opts.char_num, dtype=int)

            for sample_idx in tqdm(range(opts.n_samples)):
                ret_dict_test, _ = model_main(test_data, mode='test')
                svg_sampled = ret_dict_test['svg']['sampled_1']
                sampled_svg_2 = ret_dict_test['svg']['sampled_2']
                img_trg = ret_dict_test['img']['trg']
                img_output = ret_dict_test['img']['out']
                trg_seq_gt = ret_dict_test['svg']['trg']

                img_sample_merge = torch.cat((img_trg.data, img_output.data), -2)
                save_file_merge = os.path.join(dir_imgs, f"merge_{opts.img_size}.png")
                save_image(img_sample_merge, save_file_merge, nrow=8, normalize=True)
                if opts.streamlit:
                    st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
                    # Copy the pixels so the file handle can be closed right away.
                    with Image.open(save_file_merge) as merged:
                        im = merged.copy()
                    set_img(opts.OUTPUT_IMG_KEY, im.copy())
                    st.image(im, caption=f"sample {sample_idx+1}")

                for char_idx in tqdm(range(opts.char_num)):
                    img_gt = (1.0 - img_trg[char_idx, ...]).data
                    save_image(img_gt, os.path.join(dir_imgs, f"{char_idx:02d}_gt.png"), normalize=True)
                    img_sample = (1.0 - img_output[char_idx, ...]).data
                    save_image(img_sample, os.path.join(dir_imgs, f"{char_idx:02d}_{opts.img_size}.png"),
                               normalize=True)

                # Write results w/o parallel refinement. No '<br>' here: it was
                # appended after the SVG root and made those files invalid XML.
                for i, one_seq in tqdm(enumerate(svg_sampled.clone().detach())):
                    _write_svg_file(one_seq, os.path.join(dir_single, f"syn_{i:02d}_{sample_idx}_wo_refine.svg"))

                # Write results w/ parallel refinement and track the best sample by IoU.
                for i, one_seq in tqdm(enumerate(sampled_svg_2.clone().detach())):
                    syn_svg_outfile = os.path.join(dir_single, f"syn_{i:02d}_{sample_idx}_refined.svg")
                    if not _write_svg_file(one_seq, syn_svg_outfile):
                        continue
                    syn_img_outfile = os.path.splitext(syn_svg_outfile)[0] + '.png'
                    svg2img(syn_svg_outfile, syn_img_outfile, img_size=opts.img_size)
                    iou_tmp, _ = cal_iou(syn_img_outfile, os.path.join(dir_imgs, f"{i:02d}_{opts.img_size}.png"))
                    if iou_tmp > iou_max[i]:
                        iou_max[i] = iou_tmp
                        idx_best_sample[i] = sample_idx

            merge_path = os.path.join(svg_merge_dir, f"{opts.name_ckpt}_syn_merge_{test_idx}.html")
            with open(merge_path, 'w') as syn_svg_merge_f:
                for i in tqdm(range(opts.char_num)):
                    best_svg = os.path.join(dir_single, f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
                    with open(best_svg, 'r') as best_f:
                        _write_merge_cell(syn_svg_merge_f, i, best_f.read())
                # trg_seq_gt comes from the last sample of this batch.
                for i, gt_svg in enumerate(_render_gt_svgs(trg_seq_gt)):
                    _write_merge_cell(syn_svg_merge_f, i, gt_svg)
    return im
def main():
    """Parse command-line options and run ``test_main_model`` on them."""
    opts = get_parser_main_model().parse_args()
    # Results are looked up under "<name_exp>_<model_name>".
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    print(f"Testing on experiment {opts.name_exp}...")
    test_main_model(opts)


if __name__ == "__main__":
    main()