Spaces:
Sleeping
Sleeping
cpu pls
Browse files- test_few_shot.py +6 -6
test_few_shot.py
CHANGED
|
@@ -23,7 +23,7 @@ def test_main_model(opts):
|
|
| 23 |
dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
|
| 24 |
|
| 25 |
test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
|
| 26 |
-
|
| 27 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 28 |
print("Inference With Device:", device)
|
| 29 |
if opts.streamlit:
|
|
@@ -75,8 +75,8 @@ def test_main_model(opts):
|
|
| 75 |
st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
|
| 76 |
im = Image.open(save_file_merge)
|
| 77 |
st.image(im, caption='img_sample_merge')
|
| 78 |
-
|
| 79 |
-
for char_idx in range(opts.char_num):
|
| 80 |
img_gt = (1.0 - img_trg[char_idx,...]).data
|
| 81 |
save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
|
| 82 |
save_image(img_gt, save_file_gt, normalize=True)
|
|
@@ -87,7 +87,7 @@ def test_main_model(opts):
|
|
| 87 |
|
| 88 |
# write results w/o parallel refinement
|
| 89 |
svg_dec_out = svg_sampled.clone().detach()
|
| 90 |
-
for i, one_seq in enumerate(svg_dec_out):
|
| 91 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
|
| 92 |
|
| 93 |
syn_svg_f_ = open(syn_svg_outfile, 'w')
|
|
@@ -105,7 +105,7 @@ def test_main_model(opts):
|
|
| 105 |
|
| 106 |
# write results w/ parallel refinement
|
| 107 |
svg_dec_out = sampled_svg_2.clone().detach()
|
| 108 |
-
for i, one_seq in enumerate(svg_dec_out):
|
| 109 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
|
| 110 |
|
| 111 |
syn_svg_f = open(syn_svg_outfile, 'w')
|
|
@@ -127,7 +127,7 @@ def test_main_model(opts):
|
|
| 127 |
iou_max[i] = iou_tmp
|
| 128 |
idx_best_sample[i] = sample_idx
|
| 129 |
|
| 130 |
-
for i in range(opts.char_num):
|
| 131 |
# print(idx_best_sample[i])
|
| 132 |
syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
|
| 133 |
syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())
|
|
|
|
| 23 |
dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
|
| 24 |
|
| 25 |
test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
|
| 26 |
+
|
| 27 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 28 |
print("Inference With Device:", device)
|
| 29 |
if opts.streamlit:
|
|
|
|
| 75 |
st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
|
| 76 |
im = Image.open(save_file_merge)
|
| 77 |
st.image(im, caption='img_sample_merge')
|
| 78 |
+
|
| 79 |
+
for char_idx in tqdm(range(opts.char_num)):
|
| 80 |
img_gt = (1.0 - img_trg[char_idx,...]).data
|
| 81 |
save_file_gt = os.path.join(dir_save,"imgs", f"{char_idx:02d}_gt.png")
|
| 82 |
save_image(img_gt, save_file_gt, normalize=True)
|
|
|
|
| 87 |
|
| 88 |
# write results w/o parallel refinement
|
| 89 |
svg_dec_out = svg_sampled.clone().detach()
|
| 90 |
+
for i, one_seq in tqdm(enumerate(svg_dec_out)):
|
| 91 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_wo_refine.svg")
|
| 92 |
|
| 93 |
syn_svg_f_ = open(syn_svg_outfile, 'w')
|
|
|
|
| 105 |
|
| 106 |
# write results w/ parallel refinement
|
| 107 |
svg_dec_out = sampled_svg_2.clone().detach()
|
| 108 |
+
for i, one_seq in tqdm(enumerate(svg_dec_out)):
|
| 109 |
syn_svg_outfile = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{sample_idx}_refined.svg")
|
| 110 |
|
| 111 |
syn_svg_f = open(syn_svg_outfile, 'w')
|
|
|
|
| 127 |
iou_max[i] = iou_tmp
|
| 128 |
idx_best_sample[i] = sample_idx
|
| 129 |
|
| 130 |
+
for i in tqdm(range(opts.char_num)):
|
| 131 |
# print(idx_best_sample[i])
|
| 132 |
syn_svg_outfile_best = os.path.join(os.path.join(dir_save, "svgs_single"), f"syn_{i:02d}_{int(idx_best_sample[i])}_refined.svg")
|
| 133 |
syn_svg_merge_f.write(open(syn_svg_outfile_best, 'r').read())
|