|
|
import os |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow import keras |
|
|
from keras.layers import Dense, Flatten, Dropout, Embedding,\ |
|
|
Add, MultiHeadAttention, LayerNormalization, Input, Softmax |
|
|
import sys |
|
|
|
|
|
from constants import * |
|
|
from tokens import pretty_tokens, rhymeMeterFromTokens |
|
|
|
|
|
EPOCHS = 10 |
|
|
WARMUP_STEPS = 800 |
|
|
EMBED_DIM = 512 |
|
|
TRANSFORMER_LAYERS = 8 |
|
|
TRANSFORMER_DFF = 1024 |
|
|
RHYME_METER_DFF = 64 |
|
|
TRANSFORMER_HEADS = 4 |
|
|
VAL_SPLIT = 0.2 |
|
|
BATCH_SIZE = 256 |
|
|
SAVE_AT_END = False |
|
|
VERBOSE = False |
|
|
TRAINING = True |
|
|
|
|
|
if '--epochs' in sys.argv: |
|
|
EPOCHS = int(sys.argv[sys.argv.index('--epochs')+1]) |
|
|
if '--warmup-steps' in sys.argv: |
|
|
WARMUP_STEPS = int(sys.argv[sys.argv.index('--warmup-steps')+1]) |
|
|
if '--embed-dim' in sys.argv: |
|
|
EMBED_DIM = int(sys.argv[sys.argv.index('--embed-dim')+1]) |
|
|
if '--transformer-layers' in sys.argv: |
|
|
TRANSFORMER_LAYERS = int(sys.argv[sys.argv.index('--transformer-layers')+1]) |
|
|
if '--transformer-dff' in sys.argv: |
|
|
TRANSFORMER_DFF = int(sys.argv[sys.argv.index('--transformer-dff')+1]) |
|
|
if '--rhyme-meter-dff' in sys.argv: |
|
|
RHYME_METER_DFF = int(sys.argv[sys.argv.index('--rhyme-meter-dff')+1]) |
|
|
if '--transformer-heads' in sys.argv: |
|
|
TRANSFORMER_HEADS = int(sys.argv[sys.argv.index('--transformer-heads')+1]) |
|
|
if '--val-split' in sys.argv: |
|
|
VAL_SPLIT = float(sys.argv[sys.argv.index('--val-split')+1]) |
|
|
if '--batch-size' in sys.argv: |
|
|
BATCH_SIZE = int(sys.argv[sys.argv.index('--batch-size')+1]) |
|
|
if '--save-at-end' in sys.argv: |
|
|
SAVE_AT_END = True |
|
|
if '--verbose' in sys.argv: |
|
|
VERBOSE = True |
|
|
if '--load' in sys.argv: |
|
|
TRAINING = False |
|
|
|
|
|
# Context-window length depends on which model type is configured.
if MODEL_TYPE == 'n':
    N = NGRAM_N
else:
    N = TRANSFORMER_N

# Token vocabulary, precomputed by the lemma-extraction step.
VOCAB = list(np.load('lemmas/lemmas.npy'))

# Fixed prompt (opening of Frost's "Stopping by Woods...") used to sanity-check
# prompted generation after training.
TEST_PROMPT = (
    '<title> stop =ing by woods on a snowy evening <newline> '
    'whose woods these are i think i know <newline> '
    'his house is in the village though <newline> he'
)
|
|
|
|
|
def sampleVocab(dist, temperature):
    """Sample a token index from the probability distribution *dist*.

    Uses conventional temperature scaling: the distribution is raised to the
    power 1/temperature and renormalized, so low temperatures sharpen it
    (near-greedy as T -> 0) and high temperatures flatten it.

    :param dist: 1-D array of per-token probabilities (non-negative; length
        equals the vocabulary size).
    :param temperature: sampling temperature; 0 is clamped to a tiny positive
        value, making the draw effectively argmax.
    :return: sampled token index (int).
    """
    # Clamp to avoid division by zero in the 1/T exponent.
    temperature = 1e-8 if temperature == 0 else temperature
    # BUG FIX: the exponent was previously `temperature` itself, which
    # inverted the knob (T -> 0 produced a *uniform* distribution); the
    # zero-clamp above only makes sense for the 1/T form used here.
    dist = np.power(dist, 1.0 / temperature)
    dist /= np.sum(dist)
    # len(dist) generalizes the previous hard-coded VOCAB_SIZE; np.random.choice
    # requires the index range and `p` to match in length either way.
    return np.random.choice(len(dist), p=dist)
|
|
|
|
|
def genTokens(model, tokens, temperature=0.7, prompt=None):
    """Generate *tokens* additional tokens from *model* and return the full
    sequence as a list of token strings.

    With no prompt, generation is seeded with the title-marker token; with a
    prompt, it is seeded with the prompt's tokens (words not in the model's
    vocabulary are silently dropped).
    """
    seed = [model.vocab.index(TITLE.lower()[1:-1])]
    if prompt is not None:
        seed = [model.vocab.index(word) for word in prompt.split(' ') if word in model.vocab]
    for _ in range(tokens):
        nxt = model.generate(seed, temperature)
        assert nxt is not None
        seed.append(nxt)
    # Map token ids back to their string forms.
    return [model.vocab[tok] for tok in seed]
|
|
|
|
|
class LinearModel(keras.Model):
    """Feed-forward n-gram baseline: a one-hot context window is flattened and
    pushed through a small MLP to produce a next-token distribution."""

    def __init__(self):
        super().__init__()
        self.vocab = VOCAB
        self.seq = keras.Sequential([
            Input(shape=(NGRAM_N - 1, VOCAB_SIZE)),
            Flatten(),
            Dense(1024, activation='relu'),
            Dense(1024, activation='relu'),
            Dense(2048, activation='relu'),
            Dropout(0.2),
            Dense(VOCAB_SIZE, activation='softmax'),
        ])

    def call(self, input):
        # One-hot encode token ids; -1 padding becomes an all-zero row.
        return self.seq(tf.one_hot(input, VOCAB_SIZE))

    def generate(self, fullContext, temperature=0.7):
        """Sample the next token id from the trailing context window."""
        window = list(fullContext[-(N - 1):])
        # Trim the front down to the n-gram context size, then right-pad
        # with -1 to a fixed width.
        window = window[max(0, len(window) - (NGRAM_N - 1)):]
        window += [-1] * (NGRAM_N - 1 - len(window))
        dist = self.call(np.asarray([window]))[0]
        return sampleVocab(dist, temperature)
|
|
|
|
|
|
|
|
def positional_encoding(length, depth):
    """Build a (length, depth) sinusoidal positional-encoding matrix.

    The first depth/2 columns are sines, the last depth/2 cosines, with
    frequencies spaced geometrically down from 1 to 1/10000 (as in
    "Attention Is All You Need").
    """
    half = depth / 2
    positions = np.arange(length).reshape(-1, 1)          # (length, 1)
    frequencies = 1 / (10000 ** (np.arange(half).reshape(1, -1) / half))
    angles = positions * frequencies                      # (length, half)
    encoding = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return tf.cast(encoding, dtype=tf.float32)
|
|
|
|
|
class InputEmbedding(keras.layers.Layer):
    """Token embedding scaled by sqrt(d_model), plus sinusoidal positions,
    followed by dropout."""

    def __init__(self):
        super().__init__()
        # VOCAB_SIZE+1 slots: callers shift token ids by +1 so that the -1
        # padding value maps to embedding index 0.
        self.embed = Embedding(input_dim=VOCAB_SIZE + 1, output_dim=EMBED_DIM)
        self.pos = positional_encoding(length=TRANSFORMER_N, depth=EMBED_DIM)
        self.add = Add()
        self.dropout = Dropout(0.1)

    def call(self, input):
        seq_len = tf.shape(input)[1]
        # Scale embeddings so they don't get swamped by the positional signal.
        scaled = self.embed(input) * tf.math.sqrt(tf.cast(EMBED_DIM, tf.float32))
        positioned = self.add([scaled, self.pos[tf.newaxis, :seq_len, :]])
        return self.dropout(positioned)
|
|
|
|
|
class AttentionBlock(keras.layers.Layer):
    """Causal self-attention sublayer: MHA + dropout, residual add, layer norm."""

    def __init__(self, **kwargs):
        super().__init__()
        self.mha = MultiHeadAttention(**kwargs)
        self.dropout = Dropout(0.1)
        self.norm = LayerNormalization()
        self.add = Add()

    def call(self, input):
        # Self-attention (query = key = value) with a causal mask so each
        # position only attends to earlier positions.
        attended = self.mha(query=input, value=input, key=input, use_causal_mask=True)
        residual = self.add([input, self.dropout(attended)])
        return self.norm(residual)
|
|
|
|
|
class FeedForward(keras.layers.Layer):
    """Position-wise feed-forward sublayer with residual add and layer norm."""

    def __init__(self, dff):
        super().__init__()
        # Expand to dff, project back to the embedding width, then dropout.
        self.seq = keras.Sequential([
            Dense(dff, activation='relu'),
            Dense(EMBED_DIM),
            Dropout(0.1),
        ])
        self.add = Add()
        self.norm = LayerNormalization()

    def call(self, input):
        return self.norm(self.add([input, self.seq(input)]))
|
|
|
|
|
class Decoder(keras.layers.Layer):
    """Stack of causal attention blocks followed by one feed-forward block.

    NOTE(review): a standard transformer interleaves a feed-forward sublayer
    after every attention layer; here a single FFN follows the whole attention
    stack — confirm this is intentional before changing.
    """

    def __init__(self, *, num_layers, num_heads, dff):
        super().__init__()
        blocks = [
            AttentionBlock(num_heads=num_heads, key_dim=EMBED_DIM, dropout=0.1)
            for _ in range(num_layers)
        ]
        self.attn_seq = keras.Sequential(blocks)
        self.ffn = FeedForward(dff)

    def call(self, input):
        return self.ffn(self.attn_seq(input))
|
|
|
|
|
class TransformerModel(keras.Model):
    """Decoder-only transformer language model over the lemma vocabulary."""

    def __init__(self, *, num_layers=TRANSFORMER_LAYERS, num_heads=TRANSFORMER_HEADS, dff=TRANSFORMER_DFF):
        super().__init__()
        self.vocab = VOCAB
        self.embed = InputEmbedding()
        self.decoder = Decoder(num_layers=num_layers, num_heads=num_heads, dff=dff)
        self.out = Dense(VOCAB_SIZE, activation='softmax')

    def call(self, input):
        hidden = self.embed(input)
        hidden = self.decoder(hidden)
        probs = self.out(hidden)
        # Strip the Keras-propagated mask from the output tensor, if any —
        # presumably so downstream loss/metric code sees the raw tensor
        # (confirm against the loss pipeline).
        try:
            del probs._keras_mask
        except AttributeError:
            pass
        return probs

    def generate(self, fullContext, temperature=0.7):
        """Sample the next token id from the trailing context window."""
        window = list(fullContext[-N:])
        lastToken = len(window) - 1
        # Clip to the transformer's context length, then right-pad with -1.
        window = window[max(0, len(window) - TRANSFORMER_N):]
        window += [-1] * (TRANSFORMER_N - len(window))
        # +1 shifts ids so the -1 padding maps to embedding index 0.
        batch = np.asarray([window]) + 1
        dist = self.call(batch)[0][lastToken]
        return sampleVocab(dist, temperature)
|
|
|
|
|
|
|
|
def rhyme_meter_encoding(input):
    """Split the packed rhyme/meter feature tensor into (rhyme, meter) floats.

    Assumed last-axis layout (inferred from the slicing — confirm against the
    feature builder): three consecutive groups of width RHYME_STACK_SIZE-1
    (vowel ids, consonant ids, rhyme-match flags), with the final
    METER_STACK_SIZE entries being the meter features.
    """
    width = RHYME_STACK_SIZE - 1
    vowel_ids = input[:, :, :width]
    consonant_ids = input[:, :, width:2 * width]
    rhyme_match = input[:, :, 2 * width:3 * width]
    meter = input[:, :, -METER_STACK_SIZE:]

    # One-hot the phoneme class ids, then flatten the (slot, class) axes into
    # one feature axis per timestep.
    vowel_oh = tf.one_hot(tf.cast(vowel_ids, tf.int8), depth=VOWEL_TYPES)
    consonant_oh = tf.one_hot(tf.cast(consonant_ids, tf.int8), depth=CONSONANT_TYPES)
    vowel_flat = tf.reshape(
        vowel_oh, shape=(tf.shape(vowel_oh)[0], tf.shape(vowel_oh)[1], -1))
    consonant_flat = tf.reshape(
        consonant_oh, shape=(tf.shape(consonant_oh)[0], tf.shape(consonant_oh)[1], -1))

    rhyme = tf.concat([
        tf.cast(vowel_flat, tf.float32),
        tf.cast(consonant_flat, tf.float32),
        tf.cast(rhyme_match, tf.float32),
    ], axis=2)
    return rhyme, tf.cast(meter, tf.float32)
|
|
|
|
|
class RhymeMeterLayer(keras.layers.Layer):
    """Small MLP that turns rhyme/meter features into vocabulary-sized logits."""

    def __init__(self):
        super().__init__()
        # Separate branches: two dense layers for rhyme, one for meter.
        self.dense_r1 = Dense(RHYME_METER_DFF, activation='relu')
        self.dense_m1 = Dense(RHYME_METER_DFF // 2, activation='relu')
        self.dense_r2 = Dense(RHYME_METER_DFF, activation='relu')

        # Joint trunk applied after the branches are concatenated.
        self.dense_3 = Dense(RHYME_METER_DFF * 2, activation='relu')
        self.dense_final = Dense(VOCAB_SIZE)

    def call(self, input):
        rhyme, meter = rhyme_meter_encoding(input)
        rhyme_h = self.dense_r2(self.dense_r1(rhyme))
        meter_h = self.dense_m1(meter)

        joint = tf.concat([rhyme_h, meter_h], axis=2)
        return self.dense_final(self.dense_3(joint))
|
|
|
|
|
class BardModel(keras.Model):
    """Transformer language model combined additively with a rhyme/meter head.

    Takes two inputs: [token ids, rhyme/meter feature stack]; the two heads'
    logits are summed before the final softmax.
    """

    def __init__(self, *, num_layers=TRANSFORMER_LAYERS, num_heads=TRANSFORMER_HEADS, dff=TRANSFORMER_DFF):
        super().__init__()
        self.vocab = VOCAB
        # Vocabulary index of the title-marker token (marker stripped of its
        # angle brackets).
        self.tl = VOCAB.index(TITLE.lower()[1:-1])
        self.rhyme_types = max(VOWEL_TYPES, CONSONANT_TYPES)
        self.embed = InputEmbedding()
        self.decoder = Decoder(num_layers=num_layers, num_heads=num_heads, dff=dff)
        self.transformer_pred = Dense(VOCAB_SIZE)
        self.rhyme_meter_pred = RhymeMeterLayer()
        self.add = Add()
        self.softmax = Softmax()

    def call(self, input):
        # input[0]: token ids; input[1]: rhyme/meter feature stack.
        logits = self.embed(input[0])
        logits = self.decoder(logits)
        logits = self.transformer_pred(logits)
        # Strip the Keras-propagated mask, if present, before combining heads.
        try:
            del logits._keras_mask
        except AttributeError:
            pass

        rm_logits = self.rhyme_meter_pred(input[1])
        combined = self.add([logits, rm_logits])
        return self.softmax(combined)

    def generate(self, fullContext, temperature=0.7):
        """Sample the next token id from the trailing context window."""
        window = list(fullContext[-N:])
        lastToken = len(window) - 1
        # Clip to the transformer's context length, then right-pad with -1.
        window = window[max(0, len(window) - TRANSFORMER_N):]
        window += [-1] * (TRANSFORMER_N - len(window))
        # +1 shifts ids so the -1 padding maps to embedding index 0.
        tokens = np.asarray([window]) + 1
        # Rhyme/meter features are computed from the *full* (unclipped) context.
        features = np.asarray(
            [rhymeMeterFromTokens(fullContext, len(fullContext), self.tl, self.vocab)])
        dist = self.call([tokens, features])[0][lastToken]
        return sampleVocab(dist, temperature)
|
|
|
|
|
|
|
|
|
|
|
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Transformer learning-rate schedule: linear warmup, then inverse-sqrt
    decay, scaled by 1/sqrt(d_model) (as in "Attention Is All You Need")."""

    def __init__(self, d_model, warmup_steps=WARMUP_STEPS):
        super().__init__()

        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        decay = tf.math.rsqrt(step)
        warmup = step * (self.warmup_steps ** -1.5)

        # During warmup the linear term is smaller; afterwards the decay term is.
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)
|
|
|
|
|
|
|
|
def sparse_loss(y_true, y_pred):
    """Per-token sparse categorical cross-entropy; label -1 (padding) is ignored
    and no reduction is applied (one loss value per token)."""
    loss_fn = keras.losses.SparseCategoricalCrossentropy(
        ignore_class=-1, reduction='none')
    return loss_fn(y_true, y_pred)
|
|
def sparse_perplexity(y_true, y_pred):
    """Perplexity metric: exp of the mean per-token cross-entropy."""
    mean_loss = tf.math.reduce_mean(sparse_loss(y_true, y_pred))
    return tf.math.exp(mean_loss)
|
|
|
|
|
if __name__ == '__main__':
    # Select the preprocessed training archive for the configured model type.
    fname = {'n': 'inputs/ngram_train.npz',
             't': 'inputs/transformer_train.npz',
             'b': 'inputs/bard_train.npz'
             }[MODEL_TYPE]

    if TRAINING:
        print("Loading data from", fname)
        loaded = np.load(fname)
        train_x = loaded['x']
        train_y = loaded['y']
        if MODEL_TYPE == 'b':
            # Bard model trains on two inputs: token ids plus the precomputed
            # rhyme/meter feature stack.
            train_x = [tf.convert_to_tensor(train_x), tf.convert_to_tensor(loaded['rm'])]
        if MODEL_TYPE == 'n':
            train_x = tf.convert_to_tensor(train_x, tf.int32)
        del loaded  # release the npz archive before training

        if VERBOSE:
            if MODEL_TYPE != 'b':
                print("X:", train_x[10:14])
            else:
                print("X:", train_x[0][10:14])
                print("RM:", train_x[1][10:14][1])
            print("Y:", train_y[10:14])
            if MODEL_TYPE != 'b':
                print("X shape:", train_x.shape)
            print("Y shape:", train_y.shape)

    print("Initializing model")
    models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
    model = models[MODEL_TYPE]()
    # Push one dummy batch through the model so Keras builds all layer weights
    # (required before summary()/load_weights()).
    if MODEL_TYPE != 'b':
        x0 = np.zeros((1, NGRAM_N-1 if MODEL_TYPE == 'n' else TRANSFORMER_N))
        res = model(x0)
    else:
        x0 = np.zeros((1, TRANSFORMER_N))
        # NOTE(review): the feature width RHYME_STACK_SIZE*2+METER_STACK_SIZE
        # matches rhyme_meter_encoding's 3*(RHYME_STACK_SIZE-1) rhyme slices
        # only when RHYME_STACK_SIZE == 3 — confirm against the input builder.
        x1 = np.zeros((1, TRANSFORMER_N, RHYME_STACK_SIZE*2+METER_STACK_SIZE))
        res = model([x0, x1])
    if VERBOSE:
        print(model)
        print(res)
        print(model.summary())

    if TRAINING:
        print("Compiling model")
        # Adam with the transformer warmup/decay schedule.
        learning_rate = CustomSchedule(EMBED_DIM)
        model.compile(optimizer=keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
                      loss=sparse_loss, metrics=[sparse_perplexity])

        print("Generating sample from baseline")
        print(pretty_tokens(genTokens(model, 25)))

        print("Training model")
        min_perplexity = None
        if not os.path.exists('saved_models'):
            os.mkdir('saved_models')

        class TrainCallback(keras.callbacks.Callback):
            """Prints a sample each epoch and checkpoints on best perplexity."""

            def on_epoch_end(self, epoch, logs=None):
                global min_perplexity
                perplexity = logs['val_sparse_perplexity'] if VAL_SPLIT > 0 else logs['sparse_perplexity']
                print("\rGenerating sample from model in training: "+
                      "epoch "+str(epoch+1)+", perplexity "+str(round(perplexity, 2)), end='')
                print(pretty_tokens(genTokens(model, 75)))
                # Save whenever perplexity matches or beats the best so far,
                # unless the user asked to save only at the end.
                if (min_perplexity is None or perplexity <= min_perplexity) and not SAVE_AT_END:
                    min_perplexity = perplexity
                    print("Saving model weights")
                    model.save_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        model.fit(train_x, train_y,
                  batch_size=BATCH_SIZE, validation_split=VAL_SPLIT, epochs=EPOCHS,
                  callbacks=[TrainCallback()])

        if SAVE_AT_END:
            print("Saving final model weights")
            model.save_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        print("Generating samples from final model")
        if VERBOSE:
            for i in range(10):
                print(pretty_tokens(genTokens(model, 100)))
        print(pretty_tokens(genTokens(model, 150, prompt=TEST_PROMPT)))
        print(pretty_tokens(genTokens(model, 500)))
        print(pretty_tokens(genTokens(model, 500)))

    else:
        print("Loading weights")
        model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')

        # BUG FIX: the temperature was previously reset to 0.7 at the top of
        # every loop iteration, so the 't' command never had a lasting effect.
        # Initialize it once, before the command loop.
        temp = 0.7
        while True:
            print("Commands:\ng: generate sample with 250 tokens\nl: generate sample with custom length\np: generate sample with prompt\nt: set temperature\nq: quit")
            cmd = input("Enter command: ")
            try:
                if cmd == 'g':
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, 250, temperature=temp)))
                if cmd == 'l':
                    length = int(input("Enter length: "))
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, length, temperature=temp)))
                if cmd == 'p':
                    prompt = ""
                    print("Enter prompt as tokens separated by spaces and newlines.")
                    print("Example: <title> stop =ing by woods on a snowy evening\nwhose woods these are i think i know")
                    print("All tokens not in the vocabulary will be ignored.")
                    # Read lines until the user enters two blank lines in a
                    # row (the buffer then ends in three newlines).
                    while not prompt.endswith('\n\n\n'):
                        prompt += input("")+'\n'
                    # Strip leading/trailing whitespace, then encode the
                    # remaining line breaks as the model's newline token.
                    while prompt.startswith(' ') or prompt.startswith('\n'):
                        prompt = prompt[1:]
                    while prompt.endswith(' ') or prompt.endswith('\n'):
                        prompt = prompt[:-1]
                    prompt = prompt.replace('\n', NEWLINE.lower())
                    length = int(input("Enter length: "))
                    print("Generating sample...")
                    print(pretty_tokens(genTokens(model, length, temperature=temp, prompt=prompt)))
                if cmd == 't':
                    print("Current temperature:", temp)
                    temp = float(input("New temperature: "))
                    print("Temperature set to", temp)
                if cmd == 'q':
                    sys.exit(0)
            except Exception as e:
                # Keep the REPL alive on malformed input (e.g. non-numeric length).
                print("Error:", e)