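"""
Prepare handwriting training data.

Walks the raw ASCII transcription files under data/raw/ascii/, pairs each
transcribed line with its per-line stroke XML under data/raw/lineStrokes/
(the directory layout used by the IAM On-Line Handwriting Database), reads
writer ids from the matching files under data/raw/original/, and dumps padded
numpy arrays to data/processed/. Relies on the local `drawing` module for
stroke alignment, denoising, offset conversion, normalization, and ASCII
encoding.
"""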
from __future__ import print_function
import os
from xml.etree import ElementTree

import numpy as np

import drawing
def get_stroke_sequence(filename):
    tree = ElementTree.parse(filename).getroot()
    strokes = [i for i in tree if i.tag == 'StrokeSet'][0]

    # each point becomes (x, -y, end_of_stroke flag); y is negated because the
    # raw XML coordinates have y increasing downward
    coords = []
    for stroke in strokes:
        for i, point in enumerate(stroke):
            coords.append([
                int(point.attrib['x']),
                -1*int(point.attrib['y']),
                int(i == len(stroke) - 1)
            ])
    coords = np.array(coords)

    # align, denoise, convert to offsets, truncate, and normalize
    coords = drawing.align(coords)
    coords = drawing.denoise(coords)
    offsets = drawing.coords_to_offsets(coords)
    offsets = offsets[:drawing.MAX_STROKE_LEN]
    offsets = drawing.normalize(offsets)
    return offsets
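# get_ascii_sequences below assumes the transcription file layout inferred from
# the parsing it performs: a 'CSR:' marker, one line to skip, then one
# transcription per handwritten line. Roughly (illustrative, not a verbatim file):
#
#   CSR:
#
#   first handwritten line of text
#   second handwritten line of text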
def get_ascii_sequences(filename):
    with open(filename, 'r') as f:
        sequences = f.read()
    sequences = sequences.replace(r'%%%%%%%%%%%', '\n')
    sequences = [i.strip() for i in sequences.split('\n')]
    # transcriptions start two lines after the 'CSR:' marker
    lines = sequences[sequences.index('CSR:') + 2:]
    lines = [line.strip() for line in lines if line.strip()]
    lines = [drawing.encode_ascii(line)[:drawing.MAX_CHAR_LEN] for line in lines]
    return lines
def collect_data():
    fnames = []
    for dirpath, dirnames, filenames in os.walk('data/raw/ascii/'):
        if dirnames:
            continue
        for filename in filenames:
            if filename.startswith('.'):
                continue
            fnames.append(os.path.join(dirpath, filename))

    # low quality samples (selected by collecting samples to which
    # the trained model assigned very low likelihood)
    blacklist = set(np.load('data/blacklist.npy'))

    stroke_fnames, transcriptions, writer_ids = [], [], []
    for i, fname in enumerate(fnames):
        print(i, fname)
        # skip a known problematic transcription file
        if fname == 'data/raw/ascii/z01/z01-000/z01-000z.txt':
            continue

        head, tail = os.path.split(fname)
        last_letter = os.path.splitext(fname)[0][-1]
        last_letter = last_letter if last_letter.isalpha() else ''

        line_stroke_dir = head.replace('ascii', 'lineStrokes')
        line_stroke_fname_prefix = os.path.split(head)[-1] + last_letter + '-'

        if not os.path.isdir(line_stroke_dir):
            continue
        line_stroke_fnames = sorted([f for f in os.listdir(line_stroke_dir)
                                     if f.startswith(line_stroke_fname_prefix)])
        if not line_stroke_fnames:
            continue

        original_dir = head.replace('ascii', 'original')
        original_xml = os.path.join(original_dir, 'strokes' + last_letter + '.xml')
        tree = ElementTree.parse(original_xml)
        root = tree.getroot()

        general = root.find('General')
        if general is not None:
            writer_id = int(general[0].attrib.get('writerID', '0'))
        else:
            writer_id = 0

        ascii_sequences = get_ascii_sequences(fname)
        assert len(ascii_sequences) == len(line_stroke_fnames)

        for ascii_seq, line_stroke_fname in zip(ascii_sequences, line_stroke_fnames):
            if line_stroke_fname in blacklist:
                continue
            stroke_fnames.append(os.path.join(line_stroke_dir, line_stroke_fname))
            transcriptions.append(ascii_seq)
            writer_ids.append(writer_id)

    return stroke_fnames, transcriptions, writer_ids
if __name__ == '__main__':
    print('traversing data directory...')
    stroke_fnames, transcriptions, writer_ids = collect_data()

    print('dumping to numpy arrays...')
    x = np.zeros([len(stroke_fnames), drawing.MAX_STROKE_LEN, 3], dtype=np.float32)
    x_len = np.zeros([len(stroke_fnames)], dtype=np.int16)
    c = np.zeros([len(stroke_fnames), drawing.MAX_CHAR_LEN], dtype=np.int8)
    c_len = np.zeros([len(stroke_fnames)], dtype=np.int8)
    w_id = np.zeros([len(stroke_fnames)], dtype=np.int16)
    # np.bool was removed in NumPy 1.24; the builtin bool is equivalent here
    valid_mask = np.zeros([len(stroke_fnames)], dtype=bool)

    for i, (stroke_fname, c_i, w_id_i) in enumerate(zip(stroke_fnames, transcriptions, writer_ids)):
        if i % 200 == 0:
            print(i, '\t', '/', len(stroke_fnames))
        x_i = get_stroke_sequence(stroke_fname)
        # mark sequences containing implausibly large offsets (norm > 60) as invalid
        valid_mask[i] = ~np.any(np.linalg.norm(x_i[:, :2], axis=1) > 60)

        x[i, :len(x_i), :] = x_i
        x_len[i] = len(x_i)

        c[i, :len(c_i)] = c_i
        c_len[i] = len(c_i)

        w_id[i] = w_id_i

    if not os.path.isdir('data/processed'):
        os.makedirs('data/processed')

    np.save('data/processed/x.npy', x[valid_mask])
    np.save('data/processed/x_len.npy', x_len[valid_mask])
    np.save('data/processed/c.npy', c[valid_mask])
    np.save('data/processed/c_len.npy', c_len[valid_mask])
    np.save('data/processed/w_id.npy', w_id[valid_mask])
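# A minimal downstream-loading sketch (not part of this script): the file names,
# dtypes, and shapes follow the np.save calls above; nothing else is assumed.
#
#   import numpy as np
#   x = np.load('data/processed/x.npy')          # (N, MAX_STROKE_LEN, 3) float32 stroke offsets
#   x_len = np.load('data/processed/x_len.npy')  # (N,) valid length of each stroke sequence
#   c = np.load('data/processed/c.npy')          # (N, MAX_CHAR_LEN) encoded transcriptions
#   c_len = np.load('data/processed/c_len.npy')  # (N,) valid length of each transcription
#   w_id = np.load('data/processed/w_id.npy')    # (N,) writer ids
#   assert x.shape[0] == x_len.shape[0] == c.shape[0] == c_len.shape[0] == w_id.shape[0]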