Spaces:
Sleeping
Sleeping
| import drawing | |
| from rnn import rnn | |
| import numpy as np | |
| import svgwrite | |
| import logging | |
| import os | |
class Hand(object):
    """Render lines of text as handwriting and save them as an SVG file.

    ``__init__`` builds the ``rnn`` model and calls its ``restore()`` method.
    ``write`` validates the request, ``_sample`` runs the network to obtain
    pen-stroke offsets for every line, and ``_draw`` turns those offsets into
    SVG paths.
    """

    # Input limit enforced by ``write``.
    _MAX_LINE_LENGTH = 75

    # Layout constants used by ``_draw``.
    _LINE_HEIGHT = 80           # vertical distance between consecutive lines
    _VIEW_WIDTH = 600           # width of the SVG view box
    _STROKE_SCALE = 3.0         # factor applied to the sampled x/y offsets
    _DEFAULT_STROKE_WIDTH = 4   # used when no stroke widths are supplied

    def __init__(self):
        """Build the ``rnn`` model and restore its saved state."""
        # '2' asks TensorFlow's native logger to hide INFO and WARNING output.
        # NOTE(review): this runs after the module-level `from rnn import rnn`;
        # if that import already loads TensorFlow, the variable may be set too
        # late to take effect -- confirm.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        self.nn = rnn(
            log_dir='logs',
            checkpoint_dir='checkpoints',
            prediction_dir='predictions',
            learning_rates=[.0001, .00005, .00002],
            batch_sizes=[32, 64, 64],
            patiences=[1500, 1000, 500],
            beta1_decays=[.9, .9, .9],
            validation_batch_size=32,
            optimizer='rms',
            num_training_steps=100000,
            warm_start_init_step=17900,
            regularization_constant=0.0,
            keep_prob=1.0,
            enable_parameter_averaging=False,
            min_steps_to_checkpoint=2000,
            log_interval=20,
            logging_level=logging.CRITICAL,
            grad_clip=10,
            lstm_size=400,
            output_mixture_components=20,
            attention_mixture_components=10
        )
        self.nn.restore()

    def write(self, filename, lines, biases=None, styles=None,
              stroke_colors=None, stroke_widths=None):
        """Synthesize handwriting for ``lines`` and save it to ``filename``.

        Args:
            filename: Path of the SVG file to create.
            lines: Sequence of strings, one per output line. Each may hold at
                most 75 characters, all taken from ``drawing.alphabet``.
            biases: Optional per-line values fed to the network's ``bias``
                input. Defaults to 0.5 for every line.
            styles: Optional per-line style identifiers. When given, the files
                ``styles/style-<id>-strokes.npy`` and
                ``styles/style-<id>-chars.npy`` are loaded to prime the
                network.
            stroke_colors: Optional per-line SVG stroke colors. Defaults to
                ``'black'``.
            stroke_widths: Optional per-line SVG stroke widths. Defaults to 4.

        Raises:
            ValueError: If ``lines`` is empty, if a per-line option has fewer
                entries than ``lines``, if a line is longer than 75
                characters, or if a line contains an unsupported character.
        """
        if not lines:
            # Without this guard the failure surfaces later in _sample as
            # "max() arg is an empty sequence".
            raise ValueError("At least one line of text is required.")

        per_line_options = (
            ('biases', biases),
            ('styles', styles),
            ('stroke_colors', stroke_colors),
            ('stroke_widths', stroke_widths),
        )
        for option_name, option in per_line_options:
            # Only a too-short sequence is rejected: zip() would silently drop
            # the lines that have no matching entry. Empty sequences are left
            # to the downstream code (empty colors/widths become defaults) and
            # surplus entries are ignored exactly as before.
            if option is not None and 0 < len(option) < len(lines):
                raise ValueError(
                    (
                        "{} must provide a value for every line: "
                        "got {} values for {} lines"
                    ).format(option_name, len(option), len(lines))
                )

        valid_char_set = set(drawing.alphabet)
        for line_num, line in enumerate(lines):
            if len(line) > self._MAX_LINE_LENGTH:
                raise ValueError(
                    (
                        "Each line must be at most 75 characters. "
                        "Line {} contains {}"
                    ).format(line_num, len(line))
                )

            for char in line:
                if char not in valid_char_set:
                    raise ValueError(
                        (
                            "Invalid character {} detected in line {}. "
                            "Valid character set is {}"
                        ).format(char, line_num, valid_char_set)
                    )

        strokes = self._sample(lines, biases=biases, styles=styles)
        self._draw(strokes, lines, filename, stroke_colors=stroke_colors, stroke_widths=stroke_widths)

    def _sample(self, lines, biases=None, styles=None):
        """Run the network and return one stroke array per line.

        Args:
            lines: Non-empty sequence of strings to synthesize.
            biases: Optional per-line bias values; 0.5 each when omitted.
            styles: Optional per-line style identifiers used to load priming
                strokes and text from the ``styles/`` directory.

        Returns:
            A list with one array per line, taken from the network's
            ``sampled_sequence`` output with all-zero rows removed.
        """
        num_samples = len(lines)
        # Sampling budget: 40 time steps per character of the longest line.
        max_tsteps = 40 * max(len(line) for line in lines)
        biases = biases if biases is not None else [0.5] * num_samples

        # Fixed-size, zero-padded input buffers; the *_len arrays record how
        # much of each row is actually filled.
        x_prime = np.zeros([num_samples, 1200, 3])
        x_prime_len = np.zeros([num_samples])
        chars = np.zeros([num_samples, 120])
        chars_len = np.zeros([num_samples])

        if styles is not None:
            for i, (text, style) in enumerate(zip(lines, styles)):
                style_strokes = np.load('styles/style-{}-strokes.npy'.format(style))
                # tobytes() replaces the deprecated ndarray.tostring() alias.
                style_text = np.load(
                    'styles/style-{}-chars.npy'.format(style)
                ).tobytes().decode('utf-8')

                # The style's own text is prepended to the requested line so
                # the character input matches the priming strokes.
                primed_text = str(style_text) + " " + text
                encoded = np.array(drawing.encode_ascii(primed_text))

                x_prime[i, :len(style_strokes), :] = style_strokes
                x_prime_len[i] = len(style_strokes)
                chars[i, :len(encoded)] = encoded
                chars_len[i] = len(encoded)
        else:
            for i, text in enumerate(lines):
                encoded = drawing.encode_ascii(text)
                chars[i, :len(encoded)] = encoded
                chars_len[i] = len(encoded)

        [samples] = self.nn.session.run(
            [self.nn.sampled_sequence],
            feed_dict={
                self.nn.prime: styles is not None,
                self.nn.x_prime: x_prime,
                self.nn.x_prime_len: x_prime_len,
                self.nn.num_samples: num_samples,
                self.nn.sample_tsteps: max_tsteps,
                self.nn.c: chars,
                self.nn.c_len: chars_len,
                self.nn.bias: biases
            }
        )
        # Keep only the rows that are not entirely zero.
        return [sample[~np.all(sample == 0.0, axis=1)] for sample in samples]

    def _draw(self, strokes, lines, filename, stroke_colors=None, stroke_widths=None):
        """Convert sampled strokes into SVG paths and save the drawing.

        Args:
            strokes: One offset array per line, as returned by ``_sample``.
                The arrays are not modified.
            lines: The text lines; an empty string leaves a blank row.
            filename: Path of the SVG file to write.
            stroke_colors: Optional per-line colors; ``'black'`` when omitted
                or empty.
            stroke_widths: Optional per-line widths; 4 when omitted or empty.
        """
        stroke_colors = stroke_colors or ['black'] * len(lines)
        stroke_widths = stroke_widths or [self._DEFAULT_STROKE_WIDTH] * len(lines)

        line_height = self._LINE_HEIGHT
        view_width = self._VIEW_WIDTH
        view_height = line_height * (len(strokes) + 1)

        dwg = svgwrite.Drawing(filename=filename)
        dwg.viewbox(width=view_width, height=view_height)
        dwg.add(dwg.rect(insert=(0, 0), size=(view_width, view_height), fill='white'))

        # Offset subtracted from every line's coordinates. Its y component is
        # decreased by one line height per row, so each row is shifted by a
        # larger y amount than the previous one.
        initial_coord = np.array([0, -(3 * line_height / 4)])

        for offsets, line, color, width in zip(strokes, lines, stroke_colors, stroke_widths):
            if not line:
                # Blank line: advance the vertical position, draw nothing.
                initial_coord[1] -= line_height
                continue

            # Scale a copy so the caller's sample arrays stay untouched.
            offsets = np.array(offsets)
            offsets[:, :2] *= self._STROKE_SCALE

            coords = drawing.offsets_to_coords(offsets)
            coords = drawing.denoise(coords)
            coords[:, :2] = drawing.align(coords[:, :2])

            coords[:, 1] *= -1  # flip the y axis
            coords[:, :2] -= coords[:, :2].min() + initial_coord
            # Center the text horizontally on the canvas.
            coords[:, 0] += (view_width - coords[:, 0].max()) / 2

            # Build the path data: a point that follows an end-of-stroke flag
            # starts a new sub-path ('M'); every other point extends the
            # current one ('L').
            prev_eos = 1.0
            segments = ["M{},{} ".format(0, 0)]
            for x, y, eos in zip(*coords.T):
                segments.append('{}{},{} '.format('M' if prev_eos == 1.0 else 'L', x, y))
                prev_eos = eos

            path = svgwrite.path.Path(''.join(segments))
            path = path.stroke(color=color, width=width, linecap='round').fill("none")
            dwg.add(path)

            initial_coord[1] -= line_height

        dwg.save()