Commit ·
ceea12e
1
Parent(s): 1427339
Upload new model.py
Browse files
model.py
CHANGED
|
@@ -314,35 +314,37 @@ if __name__ == '__main__':
|
|
| 314 |
't': 'inputs/transformer_train.npz',
|
| 315 |
'b': 'inputs/bard_train.npz'
|
| 316 |
}[MODEL_TYPE]
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
|
|
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
|
| 338 |
print("Initializing model")
|
| 339 |
models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
|
| 340 |
model = models[MODEL_TYPE]()
|
| 341 |
if MODEL_TYPE != 'b':
|
| 342 |
-
|
|
|
|
| 343 |
else:
|
| 344 |
-
x0 =
|
| 345 |
-
x1 =
|
| 346 |
res = model([x0, x1])
|
| 347 |
if VERBOSE:
|
| 348 |
print(model)
|
|
@@ -391,8 +393,6 @@ if __name__ == '__main__':
|
|
| 391 |
print(pretty_tokens(genTokens(model, 500)))
|
| 392 |
|
| 393 |
else:
|
| 394 |
-
del train_x
|
| 395 |
-
del train_y
|
| 396 |
print("Loading weights")
|
| 397 |
model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')
|
| 398 |
|
|
|
|
| 314 |
't': 'inputs/transformer_train.npz',
|
| 315 |
'b': 'inputs/bard_train.npz'
|
| 316 |
}[MODEL_TYPE]
|
| 317 |
+
if TRAINING:
|
| 318 |
+
print("Loading data from", fname)
|
| 319 |
+
loaded = np.load(fname)
|
| 320 |
+
train_x = loaded['x']
|
| 321 |
+
train_y = loaded['y']
|
| 322 |
+
if MODEL_TYPE == 'b':
|
| 323 |
+
train_x = [tf.convert_to_tensor(train_x), tf.convert_to_tensor(loaded['rm'])] # rhyme and syllables
|
| 324 |
+
if MODEL_TYPE == 'n':
|
| 325 |
+
train_x = tf.convert_to_tensor(train_x, tf.int32)
|
| 326 |
+
del loaded
|
| 327 |
|
| 328 |
+
if VERBOSE:
|
| 329 |
+
if MODEL_TYPE != 'b':
|
| 330 |
+
print("X:", train_x[10:14])
|
| 331 |
+
else:
|
| 332 |
+
print("X:", train_x[0][10:14])
|
| 333 |
+
print("RM:", train_x[1][10:14][1])
|
| 334 |
+
print("Y:", train_y[10:14])
|
| 335 |
+
if MODEL_TYPE != 'b':
|
| 336 |
+
print("X shape:", train_x.shape)
|
| 337 |
+
print("Y shape:", train_y.shape)
|
| 338 |
|
| 339 |
print("Initializing model")
|
| 340 |
models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
|
| 341 |
model = models[MODEL_TYPE]()
|
| 342 |
if MODEL_TYPE != 'b':
|
| 343 |
+
x0 = np.zeros((1,NGRAM_N-1 if MODEL_TYPE=='n' else TRANSFORMER_N))
|
| 344 |
+
res = model(x0)
|
| 345 |
else:
|
| 346 |
+
x0 = np.zeros((1,TRANSFORMER_N))
|
| 347 |
+
x1 = np.zeros((1,TRANSFORMER_N,RHYME_STACK_SIZE*2+METER_STACK_SIZE))
|
| 348 |
res = model([x0, x1])
|
| 349 |
if VERBOSE:
|
| 350 |
print(model)
|
|
|
|
| 393 |
print(pretty_tokens(genTokens(model, 500)))
|
| 394 |
|
| 395 |
else:
|
|
|
|
|
|
|
| 396 |
print("Loading weights")
|
| 397 |
model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')
|
| 398 |
|