| import numpy as np |
| import os |
| import sys |
| from constants import * |
| if __name__ == '__main__': |
| from threading import Thread |
|
|
# Number of worker threads used by the script section of this file.
# Defaults to 32; `--n_threads <int>` on the command line overrides it.
N_THREADS = (
    int(sys.argv[sys.argv.index('--n_threads') + 1])
    if '--n_threads' in sys.argv
    else 32
)
|
|
if __name__ == '__main__':
    # Script mode: build the vocabulary from the raw corpus.
    if not os.path.exists('lemmas'):
        os.mkdir('lemmas')

    # The context manager closes the corpus file even if read() raises.
    corpus_path = "inputs/join.txt" if not KAGGLE else "inputs/join-kaggle.txt"
    with open(corpus_path, "r") as file:
        text = file.read()

    # Split on single spaces only (newlines are separate tokens in the corpus
    # format, so str.split() without arguments is NOT equivalent here).
    tokens = [x for x in text.split(" ") if x != '']
    print("Total number of tokens:", len(tokens))
    print("Counting tokens")
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    # Most frequent first; the sort is stable so ties keep first-seen order.
    words = list(counts.keys())
    words.sort(reverse=True, key=lambda word: counts[word])

    # Push banned tokens to the very end so they fall out of the top-VOCAB_SIZE cut.
    for token in BANNED_TOKENS:
        if token in words:
            words.remove(token)
            words.append(token)

    # Everything past the cut is folded into '<unk>'.  `words` holds unique
    # entries, so iterating the tail is equivalent to testing each word for
    # membership in the head, without re-slicing the list on every iteration.
    counts['<unk>'] = 0
    for word in words[VOCAB_SIZE:]:
        counts['<unk>'] += counts[word]

    # Re-rank now that '<unk>' carries a count, then demote banned tokens again.
    words = list(counts.keys())
    words.sort(reverse=True, key=lambda word: counts[word])
    for token in BANNED_TOKENS:
        if token in words:
            words.remove(token)
            words.append(token)

    vocab = set(words[:VOCAB_SIZE])
else:
    # Imported as a module: reuse the vocabulary saved by a previous script run.
    print("Loading vocab")
    vocab = set(list(np.load('lemmas/lemmas.npy')))
|
|
def pretty_tokens(tokens, mask=True):
    """Join a token sequence back into a single display string.

    Newline and title marker tokens are turned into line breaks, tokens that
    start with '=' are treated as suffixes and merged into the preceding
    word, and spaces are inserted between the remaining pieces according to
    the rules in ``includeSpace``.

    Args:
        tokens: sequence of string tokens.
        mask: when true, any token not in the module-level ``vocab`` is
            rendered as ``<unk>``.

    Returns:
        The assembled text, with one leading space removed if present.
    """
    # Suffix lookup tables; they are re-read from disk on every call.
    # NOTE: allow_pickle=True unpickles the file contents — only load trusted files.
    # presumably each maps a base word to its irregular inflected form — TODO confirm
    s_dict = np.load('lemmas/s.npy', allow_pickle=True).item()
    ed_dict = np.load('lemmas/ed.npy', allow_pickle=True).item()
    er_dict = np.load('lemmas/er.npy', allow_pickle=True).item()
    est_dict = np.load('lemmas/est.npy', allow_pickle=True).item()
    ing_dict = np.load('lemmas/ing.npy', allow_pickle=True).item()
    dicts = {'=s': s_dict, '=ed': ed_dict, '=er': er_dict, '=est': est_dict, '=ing': ing_dict}
    res = []
    i = 0
    def includeSpace(this):
        # Decide whether `this` should be preceded by a space, based on the
        # last one or two pieces already appended to `res`.
        nonlocal res
        quote = set(["'", '"'])
        nospace = set(['\n','-'])
        prev = res[len(res)-1] if len(res) > 0 else None
        prev2 = res[len(res)-2] if len(res) > 1 else None
        # No space after a newline/hyphen, before punctuation or a newline,
        # or before an apostrophe-led piece such as "'s" (a bare "'" is exempt).
        space = not prev in nospace\
            and not this in PUNCT and not this == '\n'\
            and not (this.startswith("'") and this != "'")
        if prev in quote and not prev2 in PUNCT:
            space = False
        elif this in quote and prev in PUNCT:
            space = False
        return space
    while i < len(tokens):
        this = tokens[i]
        if this == NEWLINE.lower()[1:-1]:
            this = '\n'
        elif this == TITLE.lower()[1:-1]:
            this = '\n ༄༅༅ '
        elif mask and not this in vocab:
            # Out-of-vocabulary token: emit the placeholder and skip suffix handling.
            this = " <unk>"
            if not includeSpace(this):
                this = "<unk>"
            res.append(this)
            i += 1
            continue
        if i+1 < len(tokens):
            next = tokens[i+1]
            # Consume every following '=' suffix token and fold it into `this`.
            while next.startswith('='):
                if next == "=nt":
                    # Negative contraction: "can"+"=nt" -> "can't",
                    # "will" -> "won't", "shall" -> "shan't".
                    if tokens[i].endswith('n'):
                        this = this[:-1]
                    if tokens[i] == 'will':
                        this = 'wo'
                    elif tokens[i] == 'shall':
                        this = 'sha'
                    this = this+"n't"
                else:
                    # NOTE(review): dicts[next] raises KeyError for a '=' token
                    # other than the five keys above — confirm none exist.
                    if tokens[i] in dicts[next]:
                        this = dicts[next][this]
                    else:
                        # No table entry: apply spelling rules before appending the suffix.
                        if next[1] == 'e' or next[1] == 'i':
                            if this.endswith('e'):
                                this = this[:-1]
                            elif this.endswith('c'):
                                this = this+'k'
                        if this.endswith('y') and next[1] == 'e' and len(this) > 2 and not this[-2] in VOWELS:
                            this = this[:-1]+'i'
                        if next[1] == 's':
                            if this.endswith('s') or this.endswith('sh') or this.endswith('x') or this.endswith('ch'):
                                this = this+'e'
                            if this.endswith('y') and len(this) > 2 and not this[-2] in VOWELS:
                                this = this[:-1]+'ie'

                        this = this+next[1:]
                i += 1
                next = tokens[i+1] if i+1 < len(tokens) else ''
        if this.startswith('='):
            # A suffix with nothing to attach to: drop the '=' and glue it on without a space.
            this = this[1:]
        elif includeSpace(this):
            this = " "+this
        res.append(this)
        i += 1
    res = ''.join(res)
    res = res[1:] if res.startswith(' ') else res
    return res
|
|
def getRhyme(line):
    """Estimate the rhyme class of the final word of a tokenised line.

    Args:
        line: sequence of string tokens (entries may be None), or None.

    Returns:
        ``[vowel_type, consonant_type]``.  The codes are the indices used by
        ``pretty_rhyme`` (vowel 0..13, consonant 0..9); -1 means "none or
        unknown".  Empty lines, title lines and lines made up only of
        newline/punctuation/quote/None tokens give ``[-1, -1]``.
    """
    if line is None or len(line) == 0:
        return [-1, -1]
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    if line[0] == tl:
        return [-1, -1]
    # Strip trailing non-words.  The None test comes first so that a None
    # entry is never passed to the `in PUNCT` membership test.
    while line[-1] is None or line[-1] == nl or line[-1] in PUNCT or line[-1] == '"' or line[-1] == "'":
        line = line[:-1]
        if len(line) == 0:
            return [-1, -1]
    word = line[-1]+''
    long_vowel = False
    vowel_type = None
    vowel_map = {'a': 0, 'e': 1, 'i': 2, 'o': 3, 'u': 4, 'ow': 5, 'ou': 5, 'oi': 6, 'oy': 6,
                 'ay': 7, 'ai': 7, 'au': 3, 'aw': 3, 'ea': 8, 'ee': 8, 'eu': 11, 'ew': 11,
                 'oa': 10, 'oo': 11, 'y': 9, 'ey': 7, 'ei': 9}
    consonant_type = -1
    cons_map = {'r': 0, 'l': 1, 'n': 2, 'm': 2, 'ng': 2,
                'p': 3, 'b': 3, 't': 4, 'd': 4, 'f': 5,
                'v': 5, 's': 6, 'sh': 6, 'z': 6, 'zh': 6,
                'th': 9, 'k': 7, 'ch': 8, 'j': 8}

    def getVowel(code, isLong, beforeR):
        # Adjust a base vowel code for a following 'r' or a silent-e long vowel.
        if beforeR and not isLong:
            if code == 0:
                return 12
            if code == 1 or code == 2 or code == 4:
                return 13
            if code == 3:
                return 10
        if isLong and 0 <= code <= 4:
            return code+7
        return code

    # A trailing suffix/contraction token fixes the final consonant and moves
    # the analysis to the word before it.
    lock_consonant = -1
    if len(line) > 1:
        stem = line[-2]
        # A None stem (masked token) is treated as matching no ending
        # instead of raising AttributeError on .endswith.
        stem_text = stem if stem is not None else ''
        # "=nt" is the token spelling used by pretty_tokens; the stray
        # "=nt'" spelling is still accepted for backward compatibility.
        contraction_codes = {"'re": 0, "'ve": 5, "'ll": 1, "'d": 4, "'m": 2,
                             "=nt": 4, "=nt'": 4}
        if word == '=ed':
            if stem_text.endswith(('t', 'd')):
                return [4, 4]
            lock_consonant = 4
            word = stem
        if word == '=s' or word == "'s":
            if stem_text.endswith(('s', 'z', 'ch', 'sh', 'x')):
                return [4, 6]
            lock_consonant = 6
            word = stem
        elif word in contraction_codes:
            lock_consonant = contraction_codes[word]
            word = stem
    if word is None:
        # The stem was masked: only the suffix consonant is known.
        return [-1, lock_consonant]
    if word in DEFINED_RHYMES:
        vowel_type = DEFINED_RHYMES[word][0]
        consonant_type = DEFINED_RHYMES[word][1] if lock_consonant == -1 else lock_consonant
        return [vowel_type, consonant_type]

    def coda(default):
        # A suffix-locked consonant always wins over the pattern's own one.
        return default if lock_consonant == -1 else lock_consonant

    # Whole-ending patterns; the order matters because the first match returns.
    if word.endswith('o'):
        return [10, lock_consonant]
    if word.endswith(('bble', 'ggle')):
        return [4, coda(1)]
    if word.endswith('old'):
        return [10, coda(1)]
    if word.endswith('ance'):
        return [0, coda(6)]
    if word.endswith(('ense', 'ence')):
        return [1, coda(6)]
    if word.endswith('ince'):
        return [2, coda(6)]
    if word.endswith(('ture', 'sure')):
        return [13, coda(0)]
    if word.endswith('all'):
        return [3, coda(1)]
    if word.endswith(('row', 'low')):
        return [10, lock_consonant]
    if word.endswith('le') and len(word) >= 3 and not word[-3] in VOWELS:
        return [4, coda(1)]
    if word.endswith('on') and len(word) > 3 and not word.endswith('oon'):
        return [4, coda(2)]
    if word.endswith('al') and len(word) > 3 and not word.endswith('eal'):
        return [4, coda(1)]
    if word.endswith('ous'):
        return [4, coda(6)]
    if word.endswith('ly'):
        return [8, lock_consonant]
    if word.endswith('ward'):
        return [13, coda(4)]

    # A final 'e' is treated as silent and marks the preceding vowel long.
    if word.endswith('e'):
        long_vowel = True
        word = word[:-1]
        if word == '':
            # The word was just "e": nothing left to index.  Use the same
            # result the consonant-stripping loop below gives when a word
            # is exhausted.
            return [8, lock_consonant]
    if lock_consonant == -1:
        if word[-2:] in cons_map:
            consonant_type = cons_map[word[-2:]]
        elif word[-1:] in cons_map:
            consonant_type = cons_map[word[-1:]]
        elif word[-1] == 'c' and long_vowel:
            consonant_type = cons_map['s']
        elif word[-1] == 'g' and long_vowel:
            consonant_type = cons_map['j']
    else:
        consonant_type = lock_consonant

    # Strip trailing consonants to reach the last vowel.  lock_r ends up True
    # only when an 'r' sits directly after that vowel.
    lock_r = False
    while not word[-1] in SOMETIMES_VOWELS:
        if word.endswith('igh'):
            return [9, consonant_type]
        if word[-1] == 'r':
            lock_r = True
        elif lock_r:
            lock_r = False
        word = word[:-1]
        if word == '':
            return [8, lock_consonant]
    if word[-2:] in vowel_map:
        vowel_type = vowel_map[word[-2:]]
    elif word[-1:] in vowel_map:
        vowel_type = vowel_map[word[-1:]]
    if vowel_type is None:
        # Final letter counts as a vowel but has no mapping: report the vowel
        # as unknown instead of letting None reach the caller.
        return [-1, consonant_type]

    vowel_type = getVowel(vowel_type, long_vowel, consonant_type == 0 or lock_r)
    return [vowel_type, consonant_type]
def pretty_rhyme(rhyme):
    """Describe a ``[vowel_code, consonant_code]`` pair in readable form.

    The vowel is shown as an example word and the consonant as a label of the
    sound group; a code of -1 is shown as '--' (vowel) or 'ø' (consonant).
    """
    example_words = ['bat', 'bet', 'bit', 'bot', 'but', 'pout', 'boil', 'bait',
                     'beat', 'bite', 'boat', 'boot', 'bar', 'sir']
    sound_groups = ['R', 'L', 'N/M/NG', 'P/B', 'T/D', 'F/V', 'S/SH/Z/ZH',
                    'K/G', 'CH/J', 'TH']

    if rhyme[0] != -1:
        vowel_text = example_words[rhyme[0]]
    else:
        vowel_text = '--'

    if rhyme[1] != -1:
        consonant_text = sound_groups[rhyme[1]]
    else:
        consonant_text = 'ø'

    return "Rhyme is {} {}".format(vowel_text, consonant_text)
|
|
|
|
def getMeter(line):
    """Estimate the number of syllables in a tokenised line.

    Args:
        line: sequence of string tokens (entries may be None), or None.

    Returns:
        The summed syllable estimate (0 for a None line).  Newline, title and
        None tokens contribute nothing; words listed in ``DEFINED_METERS``
        use the stored value; everything else is estimated by counting vowel
        letters after collapsing common digraphs.
    """
    if line is None:
        return 0
    nl = NEWLINE.lower()[1:-1]
    tl = TITLE.lower()[1:-1]
    # Applied in this order; each digraph is collapsed to one vowel letter so
    # it is counted once.
    digraph_rewrites = (('ea', 'i'), ('ee', 'i'), ('ai', 'i'), ('au', 'o'),
                        ('eu', 'u'), ('ei', 'i'), ('ie', 'i'), ('oa', 'o'),
                        ('ou', 'o'), ('oi', 'o'), ('oo', 'u'))
    res = 0
    for i, word in enumerate(line):
        if word is None or word == nl or word == tl:
            continue
        if word in DEFINED_METERS:
            res += DEFINED_METERS[word]
            continue
        if i > 0 and (word == '=ed' or word == '=s'):
            # These suffixes add a syllable only after certain endings
            # ("want-ed", "box-es").  A None previous token (masked word) is
            # treated as matching no ending instead of raising
            # AttributeError on .endswith.
            prev = line[i-1] if line[i-1] is not None else ''
            if word == '=ed':
                adds_syllable = prev.endswith(('t', 'd', 'te', 'de'))
            else:
                adds_syllable = prev.endswith(('s', 'z', 'ch', 'sh', 'x'))
            if adds_syllable:
                res += 1
            continue
        # Consonant + "le" ("apple") carries its own syllable.
        if word.endswith('le') and len(word) >= 3 and not word[-3] in VOWELS:
            res += 1
        # A final 'e' is dropped before counting vowel letters.
        removed_e = word.endswith('e')
        if removed_e:
            word = word[:-1]
        if word.endswith('y') and len(word) > 2 and not word[-2] in VOWELS:
            word = word[:-1]+'i'
        for old, new in digraph_rewrites:
            word = word.replace(old, new)
        if word.endswith(('tion', 'sion', 'tian')):
            word = word[:-4]+'shun'
        this_count = sum(word.count(vowel) for vowel in VOWELS)
        # A word whose only vowel was the dropped 'e' still has one syllable.
        if removed_e and this_count == 0:
            this_count = 1
        res += this_count
    return res
|
|
def lastLine(tokens, endl):
    """Return the slice of ``tokens`` forming the line that ends just before ``endl``.

    Scans backwards from ``endl - 1`` to the nearest newline token (index 0 is
    never tested) and returns ``tokens[start:endl]``, so the slice starts at
    that newline token when one is found.  If the slice is empty,
    ``tokens[:endl]`` is returned instead.
    """
    newline = NEWLINE.lower()[1:-1]
    start = endl - 1
    while start > 0 and tokens[start] != newline:
        start -= 1
    segment = tokens[start:endl]
    if len(segment) == 0:
        segment = tokens[:endl]
    return segment
def processRhymeStack(rhyme_stack):
    """Turn a rhyme stack into one flat feature vector.

    The result is every row except the last, flattened column-major,
    followed by one match score per earlier row against the last row:
    0 = no vowel match (or the last row's vowel is -1), 1 = same vowel,
    2 = same vowel and same consonant.
    """
    newest = rhyme_stack[-1]
    n_earlier = RHYME_STACK_SIZE - 1
    scores = np.zeros(n_earlier)
    if newest[0] != -1:
        for slot in range(n_earlier):
            entry = rhyme_stack[slot]
            if entry[0] != newest[0]:
                continue
            scores[slot] = 2 if entry[1] == newest[1] else 1
    history = rhyme_stack[:-1].flatten(order='F')
    return np.concatenate([history, scores])
def processRhymeMeter(split):
    """Compute per-token rhyme and meter features for a token sequence.

    Walks ``split`` once, keeping a rolling stack of per-line syllable counts
    (``getMeter``) and per-line rhyme codes (``getRhyme``).  For every token
    one rhyme feature vector (``processRhymeStack``) and one copy of the
    meter stack are appended.

    Args:
        split: sequence of string tokens; entries may be None.

    Returns:
        ``[rhymes, meter]`` — two lists with one entry per token of ``split``.
    """
    in_title = False
    meter = []
    rhymes = []
    # Fresh stacks: meter slots start at 0, rhyme slots at -1 ("no rhyme").
    meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
    rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
    tl = TITLE.lower()[1:-1]
    nl = NEWLINE.lower()[1:-1]
    for i in range(len(split)):
        # Tokens of the current line seen so far (token i itself is excluded).
        line = lastLine(split, i)
        if split[i] == tl:
            # Title marker: a new poem starts, so both stacks are reset.
            in_title = True
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            meter.append(meter_stack.copy())
            rhymes.append(processRhymeStack(rhyme_stack))
            continue
        elif in_title and split[i] == nl:
            # Newline that ends the title line: emit the title's own features,
            # then reset again so the title does not count as a poem line.
            in_title = False
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            meter_stack[-1] = getMeter(line)
            meter.append(meter_stack.copy())
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
            rhymes.append(processRhymeStack(rhyme_stack))
            meter_stack = np.zeros(METER_STACK_SIZE, np.int8)
            rhyme_stack = np.zeros((RHYME_STACK_SIZE, 2), np.int8) - 1
            continue
        if not in_title and split[i] == nl:
            # End of a body line: emit the stacks as they stand, then shift them.
            rhymes.append(processRhymeStack(rhyme_stack))
            meter.append(meter_stack.copy())
            # NOTE(review): split[i-1] wraps to the last element when i == 0.
            # NOTE(review): the source's indentation was lost; all four
            # statements below are assumed to belong to this `if` (skip the
            # shift on consecutive newlines) — confirm against the original.
            if split[i-1] != nl:
                rhyme_stack = np.roll(rhyme_stack, -1, axis=0)
                rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
                meter_stack = np.roll(meter_stack, -1, axis=0)
                meter_stack[-1] = getMeter(line)
        else:
            # Ordinary token (or any token inside a title): refresh the last
            # slot with the values for the line so far.
            meter_stack[-1] = getMeter(line)
            rhyme_stack[-1] = np.array(getRhyme(line), np.int8)
            rhymes.append(processRhymeStack(rhyme_stack))
            meter.append(meter_stack.copy())
    return [rhymes, meter]
|
|
def rhymeMeterFromTokens(tokens, endl, tl, vocab=None):
    """Build the rhyme/meter feature matrix for the tokens leading up to ``endl``.

    Args:
        tokens: list of tokens (strings, or integer ids when ``vocab`` is given).
        endl: exclusive end index of the window.
        tl: title token, in the same representation as ``tokens``; the window
            starts at the nearest one before ``endl`` (or at index 0).
        vocab: optional id -> word lookup.  Ids that are None or outside
            ``0 <= id < VOCAB_SIZE`` become None.

    Returns:
        Array with the rhyme features and meter features of the last
        ``TRANSFORMER_N`` positions concatenated along axis 1.
    """
    begin = endl - 1
    if endl <= len(tokens):
        while begin > 0 and tokens[begin] != tl:
            begin -= 1

    # Slice copy, padded with None up to the model's context length.
    window = tokens[begin:endl]
    window.extend([None] * (TRANSFORMER_N - len(window)))

    if vocab is None:
        words_in = window
    else:
        words_in = []
        for tok in window:
            if tok is not None and 0 <= tok < VOCAB_SIZE:
                words_in.append(vocab[tok])
            else:
                words_in.append(None)

    rhymes, meter = processRhymeMeter(words_in)
    rhyme_rows = np.array(rhymes[-TRANSFORMER_N:])
    meter_rows = np.array(meter[-TRANSFORMER_N:])
    return np.concatenate([rhyme_rows, meter_rows], axis=1)
|
|
if __name__ == '__main__':
    # Script mode: turn the token stream into training arrays and save them.
    # N is the sample length: NGRAM_N for the n-gram model, otherwise
    # TRANSFORMER_N inputs plus one shifted target.
    N = NGRAM_N if MODEL_TYPE == 'n' else TRANSFORMER_N+1
    # Pad the end of the stream with None.
    for i in range(N-1):
        tokens.append(None)
    words.remove('<unk>')
    print({word: counts[word] for word in words[:VOCAB_SIZE]})
    title_token = words.index(TITLE.lower()[1:-1])
    newline_token = words.index(NEWLINE.lower()[1:-1])

    print("Splitting poems with masked dividers")
    mask_list = [-1]*N
    chunk_size = len(tokens)//N_THREADS
    # The last chunk absorbs the remainder of the integer division.
    splits = [
        tokens[i*chunk_size : (i+1)*chunk_size if i < N_THREADS-1 else len(tokens)]
        for i in range(N_THREADS)
    ]

    results = [None] * N_THREADS
    threads = []

    def add_dividers(thread_index, split):
        # Insert N mask entries (-1) in front of every title token of a chunk.
        # NOTE(review): `tokens` still holds strings here (ids are assigned
        # under "Masking unknown tokens" below) while title_token is an int
        # from words.index, so this equality looks like it can never be true.
        # Confirm the intent before changing it: the rhyme/meter pass below
        # expects string tokens.
        i = 1
        while i < len(split):
            if split[i] == title_token:
                split = split[:i] + mask_list + split[i:]
                i += N+5
            i += 1
        results[thread_index] = split
        return split

    for i in range(N_THREADS):
        t = Thread(target=add_dividers, args=(i, splits[i],))
        threads.append(t)
        t.start()
    tokens = []
    for i in range(N_THREADS):
        threads[i].join()
        tokens += results[i]

    if MODEL_TYPE == 'b':
        print("Computing rhyme and meter information")
        split_token_marks = []
        split_size = len(tokens)//N_THREADS
        for i in range(N_THREADS+1):
            split_token_marks.append(split_size*i)
        # Move every interior boundary forward to the next title token so that
        # each chunk starts at the beginning of a poem.
        for i in range(1, N_THREADS):
            while tokens[split_token_marks[i]] != TITLE.lower()[1:-1]:
                split_token_marks[i] += 1
                if split_token_marks[i] >= len(tokens):
                    break
        meter_data = []
        rhymes_data = []
        split_token_marks[-1] = len(tokens)
        split_tokens = [tokens[split_token_marks[i]:split_token_marks[i+1]] for i in range(N_THREADS)]
        rhyme_meter_res = [None] * N_THREADS
        threads = []

        def rhymeMeterThread(thread_index, split):
            rhyme_meter_res[thread_index] = processRhymeMeter(split)

        for i in range(N_THREADS):
            t = Thread(target=rhymeMeterThread, args=(i, split_tokens[i]))
            threads.append(t)
            t.start()
        # Join in order so the features stay aligned with the token stream.
        for i in range(N_THREADS):
            threads[i].join()
            rhymes_data += rhyme_meter_res[i][0]
            meter_data += rhyme_meter_res[i][1]

        print("Converting rhyme and meter information")
        rhymes_data = np.asarray(rhymes_data)
        meter_data = np.asarray(meter_data)
        rhyme_meter_data = np.concatenate([rhymes_data, meter_data], axis=1)

    print("Masking unknown tokens")
    # One dict lookup per token instead of a linear words.index() scan per
    # token; `words` holds unique entries, so the ids are identical.
    word_ids = {entry: idx for idx, entry in enumerate(words)}
    tokens = [(word_ids[x] if x in vocab else -1) for x in tokens]

    print("Creating sets of ngrams")
    ngrams = []
    rm_ngrams = []
    for i in range(0, len(tokens)-N, TOKEN_SKIP):
        ngrams.append(tokens[i:i+N])
        if MODEL_TYPE == 'b':
            rm_ngrams.append(rhyme_meter_data[i:i+N-1,:])
    train_x = []
    train_y = []
    train_rm = []
    for i in range(len(ngrams)):
        sample = ngrams[i][:N]
        train_x.append(sample[:N-1])
        if MODEL_TYPE == 'b':
            sample_rm = rm_ngrams[i]
            train_rm.append(sample_rm)
        if MODEL_TYPE != 'n':
            # Sequence models predict the input shifted by one position.
            train_y.append(sample[1:])
        else:
            # The n-gram model predicts only the final token.
            train_y.append(sample[N-1])
    print("Converting arrays")
    train_x = np.asarray(train_x)
    train_y = np.asarray(train_y)
    if MODEL_TYPE == 'b':
        train_rm = np.asarray(train_rm, np.int8)
    if MODEL_TYPE != 'n':
        # Shift input ids up by one, so the -1 mask id becomes 0.
        train_x += 1

    print("Saving data")
    fname = {'n': 'inputs/ngram_train.npz',
             't': 'inputs/transformer_train.npz',
             'b': 'inputs/bard_train.npz'
             }[MODEL_TYPE]
    if MODEL_TYPE != 'b':
        np.savez_compressed(fname, x=train_x, y=train_y)
    else:
        np.savez_compressed(fname, x=train_x, rm=train_rm, y=train_y)
    np.save('lemmas/lemmas.npy', words[:VOCAB_SIZE])
|
|