Spaces:
Build error
Build error
brayden-gg
committed on
Commit
·
904e4e5
1
Parent(s):
1a69cb2
switched to SVG rendering
Browse files- app.py +9 -9
- config/__pycache__/GlobalVariables.cpython-38.pyc +0 -0
- config/__pycache__/__init__.cpython-38.pyc +0 -0
- convenience.py +45 -9
- interpolation.py +11 -12
- output.svg +2 -0
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -76,8 +76,8 @@ def update_writer_slider(val):
|
|
| 76 |
weights = [1 - writer_weight, writer_weight]
|
| 77 |
|
| 78 |
net.clamp_mdn = 0
|
| 79 |
-
|
| 80 |
-
return
|
| 81 |
|
| 82 |
|
| 83 |
def update_chosen_writers(writer1, writer2):
|
|
@@ -109,9 +109,9 @@ def update_char_slider(weight):
|
|
| 109 |
|
| 110 |
all_W_c = convenience.get_character_blend_W_c(character_weights, char_Ws, char_Cs)
|
| 111 |
all_commands = convenience.get_commands(net, blend_chars[0], all_W_c)
|
| 112 |
-
|
| 113 |
|
| 114 |
-
return
|
| 115 |
|
| 116 |
|
| 117 |
def update_blend_chars(c1, c2):
|
|
@@ -145,8 +145,8 @@ def update_mdn_word(target_word):
|
|
| 145 |
def sample_mdn(maxs, maxr):
|
| 146 |
net.clamp_mdn = maxr
|
| 147 |
net.scale_sd = maxs
|
| 148 |
-
|
| 149 |
-
return
|
| 150 |
|
| 151 |
|
| 152 |
update_writer_word(" ".join(writer_words))
|
|
@@ -173,7 +173,7 @@ with gr.Blocks() as demo:
|
|
| 173 |
writer_submit = gr.Button("Submit")
|
| 174 |
with gr.Row():
|
| 175 |
writer_default_image = update_writer_slider(writer_weight)
|
| 176 |
-
writer_output = gr.
|
| 177 |
|
| 178 |
writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
|
| 179 |
writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
|
|
@@ -192,7 +192,7 @@ with gr.Blocks() as demo:
|
|
| 192 |
char_slider = gr.Slider(0, 1, value=char_weight, label=f"'{blend_chars[0]}' vs. '{blend_chars[1]}'")
|
| 193 |
with gr.Row():
|
| 194 |
char_default_image = update_char_slider(char_weight)
|
| 195 |
-
char_output = gr.
|
| 196 |
|
| 197 |
char_slider.change(fn=update_char_slider, inputs=[char_slider], outputs=[char_output], show_progress=False)
|
| 198 |
|
|
@@ -210,7 +210,7 @@ with gr.Blocks() as demo:
|
|
| 210 |
mdn_sample_button = gr.Button(value="Resample!")
|
| 211 |
with gr.Row():
|
| 212 |
default_im = sample_mdn(net.scale_sd, net.clamp_mdn)
|
| 213 |
-
mdn_output = gr.
|
| 214 |
|
| 215 |
max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
|
| 216 |
scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
|
|
|
|
| 76 |
weights = [1 - writer_weight, writer_weight]
|
| 77 |
|
| 78 |
net.clamp_mdn = 0
|
| 79 |
+
svg = convenience.draw_words_svg(writer_words, all_word_writer_Ws, all_word_writer_Cs, weights, net)
|
| 80 |
+
return svg
|
| 81 |
|
| 82 |
|
| 83 |
def update_chosen_writers(writer1, writer2):
|
|
|
|
| 109 |
|
| 110 |
all_W_c = convenience.get_character_blend_W_c(character_weights, char_Ws, char_Cs)
|
| 111 |
all_commands = convenience.get_commands(net, blend_chars[0], all_W_c)
|
| 112 |
+
svg = convenience.commands_to_svg(all_commands, 750, 160, 375)
|
| 113 |
|
| 114 |
+
return svg
|
| 115 |
|
| 116 |
|
| 117 |
def update_blend_chars(c1, c2):
|
|
|
|
| 145 |
def sample_mdn(maxs, maxr):
|
| 146 |
net.clamp_mdn = maxr
|
| 147 |
net.scale_sd = maxs
|
| 148 |
+
svg = convenience.draw_words_svg(mdn_words, all_word_mdn_Ws, all_word_mdn_Cs, [1], net)
|
| 149 |
+
return svg
|
| 150 |
|
| 151 |
|
| 152 |
update_writer_word(" ".join(writer_words))
|
|
|
|
| 173 |
writer_submit = gr.Button("Submit")
|
| 174 |
with gr.Row():
|
| 175 |
writer_default_image = update_writer_slider(writer_weight)
|
| 176 |
+
writer_output = gr.HTML(writer_default_image)
|
| 177 |
|
| 178 |
writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
|
| 179 |
writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
|
|
|
|
| 192 |
char_slider = gr.Slider(0, 1, value=char_weight, label=f"'{blend_chars[0]}' vs. '{blend_chars[1]}'")
|
| 193 |
with gr.Row():
|
| 194 |
char_default_image = update_char_slider(char_weight)
|
| 195 |
+
char_output = gr.HTML(char_default_image)
|
| 196 |
|
| 197 |
char_slider.change(fn=update_char_slider, inputs=[char_slider], outputs=[char_output], show_progress=False)
|
| 198 |
|
|
|
|
| 210 |
mdn_sample_button = gr.Button(value="Resample!")
|
| 211 |
with gr.Row():
|
| 212 |
default_im = sample_mdn(net.scale_sd, net.clamp_mdn)
|
| 213 |
+
mdn_output = gr.HTML(default_im)
|
| 214 |
|
| 215 |
max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
|
| 216 |
scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
|
config/__pycache__/GlobalVariables.cpython-38.pyc
CHANGED
|
Binary files a/config/__pycache__/GlobalVariables.cpython-38.pyc and b/config/__pycache__/GlobalVariables.cpython-38.pyc differ
|
|
|
config/__pycache__/__init__.cpython-38.pyc
CHANGED
|
Binary files a/config/__pycache__/__init__.cpython-38.pyc and b/config/__pycache__/__init__.cpython-38.pyc differ
|
|
|
convenience.py
CHANGED
|
@@ -14,10 +14,12 @@ from config.GlobalVariables import *
|
|
| 14 |
from tensorboardX import SummaryWriter
|
| 15 |
from SynthesisNetwork import SynthesisNetwork
|
| 16 |
from DataLoader import DataLoader
|
|
|
|
| 17 |
# import ffmpeg # for problems with ffmpeg uninstall ffmpeg and then install ffmpeg-python
|
| 18 |
|
| 19 |
L = 256
|
| 20 |
|
|
|
|
| 21 |
def get_mean_global_W(net, loaded_data, device):
|
| 22 |
"""gets the mean global style vector for a given writer"""
|
| 23 |
[_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out,
|
|
@@ -231,14 +233,14 @@ def get_character_blend_W_c(character_weights, all_Ws, all_Cs):
|
|
| 231 |
W_vector = all_Ws[0, 0, :].unsqueeze(-1)
|
| 232 |
|
| 233 |
weights_tensor = torch.tensor(character_weights).repeat_interleave(L * L).reshape(1, M, L, L) # repeat across remaining dimensions
|
| 234 |
-
char_matrix = (weights_tensor * all_Cs).sum(axis=1).squeeze()
|
| 235 |
|
| 236 |
W_c = char_matrix @ W_vector
|
| 237 |
|
| 238 |
return W_c.reshape(1, 1, L)
|
| 239 |
|
| 240 |
|
| 241 |
-
def get_commands(net, target_word, all_W_c):
|
| 242 |
"""converts character-dependent style-dependent DSDs to a list of commands for drawing"""
|
| 243 |
all_commands = []
|
| 244 |
current_id = 0
|
|
@@ -285,6 +287,7 @@ def get_commands(net, target_word, all_W_c): # seems like target_word is only us
|
|
| 285 |
|
| 286 |
return commands
|
| 287 |
|
|
|
|
| 288 |
def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_data, device):
|
| 289 |
'''
|
| 290 |
Method creating gif of mdn samples
|
|
@@ -306,7 +309,7 @@ def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_dat
|
|
| 306 |
writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device)
|
| 307 |
word_Ws.append(writer_Ws)
|
| 308 |
word_Cs.append(writer_Cs)
|
| 309 |
-
|
| 310 |
im = draw_words(words, word_Ws, word_Cs, [1], net)
|
| 311 |
im.convert("RGB").save(f'results/{us_target_word}_mdn_samples/sample_{i}.png')
|
| 312 |
# Convert frames to video using ffmpeg
|
|
@@ -314,6 +317,7 @@ def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_dat
|
|
| 314 |
videos = photos.output(f'results/{us_target_word}_video.mov', vcodec="libx264", pix_fmt="yuv420p")
|
| 315 |
videos.run(overwrite_output=True)
|
| 316 |
|
|
|
|
| 317 |
def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data, device="cpu"):
|
| 318 |
"""Generates an image of handwritten text based on target_sentence"""
|
| 319 |
words = target_sentence.split(' ')
|
|
@@ -329,7 +333,7 @@ def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data
|
|
| 329 |
writer_Ws, writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
|
| 330 |
word_Ws.append(writer_Ws)
|
| 331 |
word_Cs.append(writer_Cs)
|
| 332 |
-
|
| 333 |
return draw_words(words, word_Ws, word_Cs, writer_weights, net)
|
| 334 |
|
| 335 |
|
|
@@ -356,10 +360,10 @@ def sample_character_grid(letters, grid_size, net, all_loaded_data, device="cpu"
|
|
| 356 |
wx = i / (grid_size - 1)
|
| 357 |
wy = j / (grid_size - 1)
|
| 358 |
|
| 359 |
-
character_weights = [(1 - wx) * (1 - wy),
|
| 360 |
-
wx
|
| 361 |
(1 - wx) * wy, # bottom left is 1 at (0, 1)
|
| 362 |
-
wx
|
| 363 |
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
|
| 364 |
all_commands = get_commands(net, letters[0], all_W_c)
|
| 365 |
|
|
@@ -377,6 +381,7 @@ def sample_character_grid(letters, grid_size, net, all_loaded_data, device="cpu"
|
|
| 377 |
|
| 378 |
return im
|
| 379 |
|
|
|
|
| 380 |
def writer_interpolation_video(target_sentence, transition_time, net, all_loaded_data, device="cpu"):
|
| 381 |
"""
|
| 382 |
Generates a video of interpolating between each provided writer
|
|
@@ -416,6 +421,7 @@ def writer_interpolation_video(target_sentence, transition_time, net, all_loaded
|
|
| 416 |
videos = photos.output(f"results/{target_sentence}_blend_video.mov", vcodec="libx264", pix_fmt="yuv420p")
|
| 417 |
videos.run(overwrite_output=True)
|
| 418 |
|
|
|
|
| 419 |
def mdn_single_sample(target_word, scale_sd, clamp_mdn, net, all_loaded_data, device):
|
| 420 |
'''
|
| 421 |
Method creating gif of mdn samples
|
|
@@ -462,7 +468,7 @@ def sample_blended_chars(character_weights, letters, net, all_loaded_data, devic
|
|
| 462 |
def char_interpolation_video(letters, transition_time, net, all_loaded_data, device="cpu"):
|
| 463 |
"""Generates an image of handwritten text based on target_sentence"""
|
| 464 |
|
| 465 |
-
os.makedirs(f"./results/{''.join(letters)}_frames", exist_ok=True)
|
| 466 |
|
| 467 |
M = len(letters)
|
| 468 |
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
|
|
@@ -507,6 +513,25 @@ def draw_words(words, word_Ws, word_Cs, writer_weights, net):
|
|
| 507 |
|
| 508 |
return im
|
| 509 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
def commands_to_image(commands, imW, imH, xoff, yoff):
|
| 511 |
im = Image.fromarray(np.zeros([imW, imH]))
|
| 512 |
dr = ImageDraw.Draw(im)
|
|
@@ -519,4 +544,15 @@ def commands_to_image(commands, imW, imH, xoff, yoff):
|
|
| 519 |
y - yoff), 255, 1)
|
| 520 |
px, py = x, y
|
| 521 |
return im
|
| 522 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from tensorboardX import SummaryWriter
|
| 15 |
from SynthesisNetwork import SynthesisNetwork
|
| 16 |
from DataLoader import DataLoader
|
| 17 |
+
import svgwrite
|
| 18 |
# import ffmpeg # for problems with ffmpeg uninstall ffmpeg and then install ffmpeg-python
|
| 19 |
|
| 20 |
L = 256
|
| 21 |
|
| 22 |
+
|
| 23 |
def get_mean_global_W(net, loaded_data, device):
|
| 24 |
"""gets the mean global style vector for a given writer"""
|
| 25 |
[_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out,
|
|
|
|
| 233 |
W_vector = all_Ws[0, 0, :].unsqueeze(-1)
|
| 234 |
|
| 235 |
weights_tensor = torch.tensor(character_weights).repeat_interleave(L * L).reshape(1, M, L, L) # repeat across remaining dimensions
|
| 236 |
+
char_matrix = (weights_tensor * all_Cs).sum(axis=1).squeeze() # take weighted sum across characters axis
|
| 237 |
|
| 238 |
W_c = char_matrix @ W_vector
|
| 239 |
|
| 240 |
return W_c.reshape(1, 1, L)
|
| 241 |
|
| 242 |
|
| 243 |
+
def get_commands(net, target_word, all_W_c): # seems like target_word is only used for length
|
| 244 |
"""converts character-dependent style-dependent DSDs to a list of commands for drawing"""
|
| 245 |
all_commands = []
|
| 246 |
current_id = 0
|
|
|
|
| 287 |
|
| 288 |
return commands
|
| 289 |
|
| 290 |
+
|
| 291 |
def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_data, device):
|
| 292 |
'''
|
| 293 |
Method creating gif of mdn samples
|
|
|
|
| 309 |
writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device)
|
| 310 |
word_Ws.append(writer_Ws)
|
| 311 |
word_Cs.append(writer_Cs)
|
| 312 |
+
|
| 313 |
im = draw_words(words, word_Ws, word_Cs, [1], net)
|
| 314 |
im.convert("RGB").save(f'results/{us_target_word}_mdn_samples/sample_{i}.png')
|
| 315 |
# Convert frames to video using ffmpeg
|
|
|
|
| 317 |
videos = photos.output(f'results/{us_target_word}_video.mov', vcodec="libx264", pix_fmt="yuv420p")
|
| 318 |
videos.run(overwrite_output=True)
|
| 319 |
|
| 320 |
+
|
| 321 |
def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data, device="cpu"):
|
| 322 |
"""Generates an image of handwritten text based on target_sentence"""
|
| 323 |
words = target_sentence.split(' ')
|
|
|
|
| 333 |
writer_Ws, writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
|
| 334 |
word_Ws.append(writer_Ws)
|
| 335 |
word_Cs.append(writer_Cs)
|
| 336 |
+
|
| 337 |
return draw_words(words, word_Ws, word_Cs, writer_weights, net)
|
| 338 |
|
| 339 |
|
|
|
|
| 360 |
wx = i / (grid_size - 1)
|
| 361 |
wy = j / (grid_size - 1)
|
| 362 |
|
| 363 |
+
character_weights = [(1 - wx) * (1 - wy), # top left is 1 at (0, 0)
|
| 364 |
+
wx * (1 - wy), # top right is 1 at (1, 0)
|
| 365 |
(1 - wx) * wy, # bottom left is 1 at (0, 1)
|
| 366 |
+
wx * wy] # bottom right is 1 at (1, 1)
|
| 367 |
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
|
| 368 |
all_commands = get_commands(net, letters[0], all_W_c)
|
| 369 |
|
|
|
|
| 381 |
|
| 382 |
return im
|
| 383 |
|
| 384 |
+
|
| 385 |
def writer_interpolation_video(target_sentence, transition_time, net, all_loaded_data, device="cpu"):
|
| 386 |
"""
|
| 387 |
Generates a video of interpolating between each provided writer
|
|
|
|
| 421 |
videos = photos.output(f"results/{target_sentence}_blend_video.mov", vcodec="libx264", pix_fmt="yuv420p")
|
| 422 |
videos.run(overwrite_output=True)
|
| 423 |
|
| 424 |
+
|
| 425 |
def mdn_single_sample(target_word, scale_sd, clamp_mdn, net, all_loaded_data, device):
|
| 426 |
'''
|
| 427 |
Method creating gif of mdn samples
|
|
|
|
| 468 |
def char_interpolation_video(letters, transition_time, net, all_loaded_data, device="cpu"):
|
| 469 |
"""Generates an image of handwritten text based on target_sentence"""
|
| 470 |
|
| 471 |
+
os.makedirs(f"./results/{''.join(letters)}_frames", exist_ok=True) # make a folder for the frames
|
| 472 |
|
| 473 |
M = len(letters)
|
| 474 |
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
|
|
|
|
| 513 |
|
| 514 |
return im
|
| 515 |
|
| 516 |
+
|
| 517 |
+
def draw_words_svg(words, word_Ws, word_Cs, writer_weights, net):
    """Draw every word's blended strokes into one SVG and return the markup.

    For each ``(word, Ws, Cs)`` triple the writer styles are blended with
    ``get_writer_blend_W_c`` and turned into drawing commands with
    ``get_commands``. Each command row is ``[x, y, t]``: a row with
    ``t == 0`` extends the current path with a line-to, any other ``t``
    opens a new path with a move-to.

    Words are laid out left to right. The x offset starts at 50 and, after
    each word, grows by that word's largest x value plus 25.

    Args:
        words: Sequence of words, iterated in lockstep with ``word_Ws`` and
            ``word_Cs``.
        word_Ws: Per-word values forwarded to ``get_writer_blend_W_c``.
        word_Cs: Per-word values forwarded to ``get_writer_blend_W_c``.
        writer_weights: Blend weights forwarded to ``get_writer_blend_W_c``.
        net: Model object forwarded to ``get_commands``.

    Returns:
        The SVG document as a string (750x160 canvas, black background,
        white 1px strokes). The drawing is named "output.svg" but this
        function never calls ``save()``.
    """
    dwg = svgwrite.Drawing("output.svg", size=(750, 160), style="background-color: black;")
    x_offset = 50
    for word, all_writer_Ws, all_writer_Cs in zip(words, word_Ws, word_Cs):
        all_W_c = get_writer_blend_W_c(writer_weights, all_writer_Ws, all_writer_Cs)
        all_commands = get_commands(net, word, all_W_c)

        # Reset per word: a word whose first row has t == 0 must open its own
        # path rather than raise UnboundLocalError (first word) or extend the
        # previous word's last stroke across the gap (later words).
        path = None
        for [x, y, t] in all_commands:
            if t == 0 and path is not None:
                path.push("L", x + x_offset, y)
            else:
                path = svgwrite.path.Path(stroke="white", stroke_width="1")
                dwg.add(path)
                path.push("M", x + x_offset, y)

        # Advance past this word: its widest x plus a fixed 25-unit gap.
        x_offset += np.max(all_commands[:, 0]) + 25
    return dwg.tostring()
|
| 533 |
+
|
| 534 |
+
|
| 535 |
def commands_to_image(commands, imW, imH, xoff, yoff):
|
| 536 |
im = Image.fromarray(np.zeros([imW, imH]))
|
| 537 |
dr = ImageDraw.Draw(im)
|
|
|
|
| 544 |
y - yoff), 255, 1)
|
| 545 |
px, py = x, y
|
| 546 |
return im
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def commands_to_svg(commands, imW, imH, xoff):
    """Render a list of drawing commands as an SVG document string.

    Each command row is ``[x, y, t]``: a row with ``t == 0`` extends the
    current path with a line-to, any other ``t`` opens a new path with a
    move-to. Every x coordinate is shifted right by ``xoff``; y is used
    as-is.

    Args:
        commands: Iterable of ``[x, y, t]`` rows.
        imW: Width passed to the SVG ``size``.
        imH: Height passed to the SVG ``size``.
        xoff: Horizontal offset added to every x coordinate.

    Returns:
        The SVG document as a string (black background, white 1px strokes).
        The drawing is named "output.svg" but this function never calls
        ``save()``.
    """
    dwg = svgwrite.Drawing("output.svg", size=(imW, imH), style="background-color:black")
    # No path exists yet. If the first row has t == 0, open a path with a
    # move-to instead of raising UnboundLocalError on an unbound `path`.
    path = None
    for [x, y, t] in commands:
        if t == 0 and path is not None:
            path.push("L", x + xoff, y)
        else:
            path = svgwrite.path.Path(stroke="white", stroke_width="1")
            dwg.add(path)
            path.push("M", x + xoff, y)
    return dwg.tostring()
|
interpolation.py
CHANGED
|
@@ -20,11 +20,10 @@ def main(params):
|
|
| 20 |
net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
|
| 21 |
|
| 22 |
if not torch.cuda.is_available():
|
| 23 |
-
try:
|
| 24 |
net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu'))["model_state_dict"])
|
| 25 |
except:
|
| 26 |
net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu')))
|
| 27 |
-
|
| 28 |
|
| 29 |
dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
|
| 30 |
|
|
@@ -34,7 +33,6 @@ def main(params):
|
|
| 34 |
loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(params.num_samples)))
|
| 35 |
all_loaded_data.append(loaded_data)
|
| 36 |
|
| 37 |
-
|
| 38 |
if params.output == "image":
|
| 39 |
|
| 40 |
if params.interpolate == "writer":
|
|
@@ -78,6 +76,7 @@ def main(params):
|
|
| 78 |
else:
|
| 79 |
raise ValueError("Invalid output")
|
| 80 |
|
|
|
|
| 81 |
if __name__ == '__main__':
|
| 82 |
parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.')
|
| 83 |
|
|
@@ -89,25 +88,25 @@ if __name__ == '__main__':
|
|
| 89 |
parser.add_argument('--interpolate', type=str, default="randomness", choices=["writer", "character", "randomness"])
|
| 90 |
|
| 91 |
# PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION:
|
| 92 |
-
|
| 93 |
parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5])
|
| 94 |
-
|
| 95 |
parser.add_argument('--frames_per_step', type=int, default=10)
|
| 96 |
|
| 97 |
# PARAMS IF WRITER INTERPOLATION:
|
| 98 |
parser.add_argument('--target_word', type=str, default="hello world")
|
| 99 |
parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120])
|
| 100 |
-
|
| 101 |
# PARAMS IF CHARACTER INTERPOLATION:
|
| 102 |
-
|
| 103 |
-
parser.add_argument('--blend_chars', type=str, nargs="+", default
|
| 104 |
-
|
| 105 |
-
parser.add_argument('--grid_chars', type=str, nargs="+", default=
|
| 106 |
parser.add_argument('--grid_size', type=int, default=10)
|
| 107 |
|
| 108 |
# PARAMS IF RANDOMNESS INTERPOLATION (--output will be ignored):
|
| 109 |
-
parser.add_argument('--max_randomness', type=float, default=1)
|
| 110 |
-
parser.add_argument('--scale_randomness', type=float, default=0.5)
|
| 111 |
parser.add_argument('--num_random_samples', type=int, default=10)
|
| 112 |
|
| 113 |
main(parser.parse_args())
|
|
|
|
| 20 |
net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
|
| 21 |
|
| 22 |
if not torch.cuda.is_available():
|
| 23 |
+
try: # retrained model also contains loss in dict
|
| 24 |
net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu'))["model_state_dict"])
|
| 25 |
except:
|
| 26 |
net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu')))
|
|
|
|
| 27 |
|
| 28 |
dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
|
| 29 |
|
|
|
|
| 33 |
loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(params.num_samples)))
|
| 34 |
all_loaded_data.append(loaded_data)
|
| 35 |
|
|
|
|
| 36 |
if params.output == "image":
|
| 37 |
|
| 38 |
if params.interpolate == "writer":
|
|
|
|
| 76 |
else:
|
| 77 |
raise ValueError("Invalid output")
|
| 78 |
|
| 79 |
+
|
| 80 |
if __name__ == '__main__':
|
| 81 |
parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.')
|
| 82 |
|
|
|
|
| 88 |
parser.add_argument('--interpolate', type=str, default="randomness", choices=["writer", "character", "randomness"])
|
| 89 |
|
| 90 |
# PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION:
|
| 91 |
+
# IF IMAGE - weights to use for a single sample of interpolation
|
| 92 |
parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5])
|
| 93 |
+
# IF VIDEO - the number of frames for each character/writer
|
| 94 |
parser.add_argument('--frames_per_step', type=int, default=10)
|
| 95 |
|
| 96 |
# PARAMS IF WRITER INTERPOLATION:
|
| 97 |
parser.add_argument('--target_word', type=str, default="hello world")
|
| 98 |
parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120])
|
| 99 |
+
|
| 100 |
# PARAMS IF CHARACTER INTERPOLATION:
|
| 101 |
+
# IF VIDEO OR BLEND
|
| 102 |
+
parser.add_argument('--blend_chars', type=str, nargs="+", default=["a", "b", "c", "d", "e"])
|
| 103 |
+
# IF GRID
|
| 104 |
+
parser.add_argument('--grid_chars', type=str, nargs="+", default=["y", "s", "u", "n"])
|
| 105 |
parser.add_argument('--grid_size', type=int, default=10)
|
| 106 |
|
| 107 |
# PARAMS IF RANDOMNESS INTERPOLATION (--output will be ignored):
|
| 108 |
+
parser.add_argument('--max_randomness', type=float, default=1)
|
| 109 |
+
parser.add_argument('--scale_randomness', type=float, default=0.5)
|
| 110 |
parser.add_argument('--num_random_samples', type=int, default=10)
|
| 111 |
|
| 112 |
main(parser.parse_args())
|
output.svg
ADDED
|
|
requirements.txt
CHANGED
|
@@ -8,4 +8,5 @@ torch==1.11.0
|
|
| 8 |
typing_extensions==4.1.1
|
| 9 |
ffmpeg-python
|
| 10 |
gradio
|
|
|
|
| 11 |
|
|
|
|
| 8 |
typing_extensions==4.1.1
|
| 9 |
ffmpeg-python
|
| 10 |
gradio
|
| 11 |
+
svgwrite
|
| 12 |
|