Spaces:
Sleeping
Sleeping
latest output update
Browse files- app.py +3 -3
- generate.py +4 -3
- test_few_shot.py +5 -1
app.py
CHANGED
|
@@ -65,8 +65,8 @@ def generate_button(prefix, file_input, version, **kwargs):
|
|
| 65 |
)
|
| 66 |
|
| 67 |
if st.button("Generate image", key=f"{prefix}-btn"):
|
| 68 |
-
with st.spinner(f"⏳ Generating image
|
| 69 |
-
image = toggle_process( ttf_to_image(file_input, n_samples, ref_char_ids, version) )
|
| 70 |
set_img(OUTPUT_IMG_KEY, image.copy())
|
| 71 |
st.image(image)
|
| 72 |
|
|
@@ -110,7 +110,7 @@ def main():
|
|
| 110 |
generate_tab()
|
| 111 |
|
| 112 |
with st.sidebar:
|
| 113 |
-
st.header("Latest Output
|
| 114 |
output_image = get_img(OUTPUT_IMG_KEY)
|
| 115 |
if output_image:
|
| 116 |
st.image(output_image)
|
|
|
|
| 65 |
)
|
| 66 |
|
| 67 |
if st.button("Generate image", key=f"{prefix}-btn"):
|
| 68 |
+
with st.spinner(f"⏳ Generating image (5 minutes per n_sample estimated time)"):
|
| 69 |
+
image = toggle_process( ttf_to_image(file_input, OUTPUT_IMG_KEY, n_samples, ref_char_ids, version) )
|
| 70 |
set_img(OUTPUT_IMG_KEY, image.copy())
|
| 71 |
st.image(image)
|
| 72 |
|
|
|
|
| 110 |
generate_tab()
|
| 111 |
|
| 112 |
with st.sidebar:
|
| 113 |
+
st.header("Latest Output")
|
| 114 |
output_image = get_img(OUTPUT_IMG_KEY)
|
| 115 |
if output_image:
|
| 116 |
st.image(output_image)
|
generate.py
CHANGED
|
@@ -123,9 +123,10 @@ def preprocessing(ttf_file) -> str:
|
|
| 123 |
print("Saved at", output_path)
|
| 124 |
return output_path
|
| 125 |
|
| 126 |
-
def inference_model(n_samples, ref_char_ids, version):
|
| 127 |
opts.n_samples = n_samples
|
| 128 |
opts.ref_char_ids = ref_char_ids
|
|
|
|
| 129 |
|
| 130 |
# Select Model
|
| 131 |
if version == "TH2TH":
|
|
@@ -137,9 +138,9 @@ def inference_model(n_samples, ref_char_ids, version):
|
|
| 137 |
|
| 138 |
return test_main_model(opts)
|
| 139 |
|
| 140 |
-
def ttf_to_image(ttf_file, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
|
| 141 |
preprocessing(ttf_file) # Make Data
|
| 142 |
-
merge_svg_img = inference_model(n_samples, ref_char_ids, version) # Inference
|
| 143 |
return merge_svg_img
|
| 144 |
|
| 145 |
def main():
|
|
|
|
| 123 |
print("Saved at", output_path)
|
| 124 |
return output_path
|
| 125 |
|
| 126 |
+
def inference_model(OUTPUT_IMG_KEY, n_samples, ref_char_ids, version):
|
| 127 |
opts.n_samples = n_samples
|
| 128 |
opts.ref_char_ids = ref_char_ids
|
| 129 |
+
opts.OUTPUT_IMG_KEY = OUTPUT_IMG_KEY
|
| 130 |
|
| 131 |
# Select Model
|
| 132 |
if version == "TH2TH":
|
|
|
|
| 138 |
|
| 139 |
return test_main_model(opts)
|
| 140 |
|
| 141 |
+
def ttf_to_image(ttf_file, OUTPUT_IMG_KEY, n_samples=10, ref_char_ids="1,2,3,4,5,6,7,8", version="TH2TH"):
|
| 142 |
preprocessing(ttf_file) # Make Data
|
| 143 |
+
merge_svg_img = inference_model(OUTPUT_IMG_KEY, n_samples, ref_char_ids, version) # Inference
|
| 144 |
return merge_svg_img
|
| 145 |
|
| 146 |
def main():
|
test_few_shot.py
CHANGED
|
@@ -31,6 +31,9 @@ def test_main_model(opts):
|
|
| 31 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 32 |
print("Inference With Device:", device)
|
| 33 |
if opts.streamlit:
|
|
|
|
|
|
|
|
|
|
| 34 |
st.write("Loading Model Weight...")
|
| 35 |
st.write("Inference With Device:", device)
|
| 36 |
|
|
@@ -78,7 +81,8 @@ def test_main_model(opts):
|
|
| 78 |
if opts.streamlit:
|
| 79 |
st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
|
| 80 |
im = Image.open(save_file_merge)
|
| 81 |
-
|
|
|
|
| 82 |
|
| 83 |
for char_idx in tqdm(range(opts.char_num)):
|
| 84 |
img_gt = (1.0 - img_trg[char_idx,...]).data
|
|
|
|
| 31 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 32 |
print("Inference With Device:", device)
|
| 33 |
if opts.streamlit:
|
| 34 |
+
def set_img(key: str, img: Image.Image):
|
| 35 |
+
st.session_state[key] = img
|
| 36 |
+
|
| 37 |
st.write("Loading Model Weight...")
|
| 38 |
st.write("Inference With Device:", device)
|
| 39 |
|
|
|
|
| 81 |
if opts.streamlit:
|
| 82 |
st.progress((sample_idx+1)/opts.n_samples, f"Generating Font Sample {sample_idx+1} Please wait...")
|
| 83 |
im = Image.open(save_file_merge)
|
| 84 |
+
set_img(opts.OUTPUT_IMG_KEY, im)
|
| 85 |
+
st.image(im, caption=f"sample {sample_idx+1}")
|
| 86 |
|
| 87 |
for char_idx in tqdm(range(opts.char_num)):
|
| 88 |
img_gt = (1.0 - img_trg[char_idx,...]).data
|