|
|
import argparse |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras.layers import LSTM, Dense |
|
|
from tensorflow.keras.models import Sequential |
|
|
from tensorflow.keras.callbacks import ModelCheckpoint |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
# Command-line interface: a required corpus filename plus optional
# save-path and epoch-count flags.
parser = argparse.ArgumentParser(description='Train a text generation model.')
arg_specs = [
    (('filename',), dict(help='Name of the text file to train on')),
    (('-s', '--save'), dict(help='Name of the file to save the trained model')),
    (('-e', '--epochs'), dict(type=int, default=100, help='Number of training epochs')),
]
for names, opts in arg_specs:
    parser.add_argument(*names, **opts)
args = parser.parse_args()
|
|
|
|
|
# NOTE(review): removed `dataset = load_dataset("code_search_net", "python")`.
# It downloaded the full CodeSearchNet corpus at startup, but the result was
# never referenced anywhere in this script — training reads only
# `args.filename` below. Restore the call if the corpus is actually needed.
|
|
|
|
|
# Read the training corpus and build the character-level vocabulary:
# a sorted list of distinct characters plus index<->character lookup tables.
with open(args.filename, 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))
index_to_char = dict(enumerate(chars))
char_to_index = {c: i for i, c in index_to_char.items()}
|
|
|
|
|
|
|
|
# Slice the corpus into overlapping fixed-length character windows; each
# window's training target is the single character that follows it.
max_sequence_length = 100
step = 1
sequences = []
next_chars = []
for i in range(0, len(text) - max_sequence_length, step):
    sequences.append(text[i:i+max_sequence_length])
    next_chars.append(text[i+max_sequence_length])

# One-hot encode inputs (samples, timesteps, vocab) and targets (samples, vocab).
# FIX: `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; using it
# raises AttributeError on current NumPy. The builtin `bool` is the
# documented replacement and produces the same boolean dtype.
X = np.zeros((len(sequences), max_sequence_length, len(chars)), dtype=bool)
y = np.zeros((len(sequences), len(chars)), dtype=bool)
for i, sequence in enumerate(sequences):
    for t, char in enumerate(sequence):
        X[i, t, char_to_index[char]] = 1
    y[i, char_to_index[next_chars[i]]] = 1
|
|
|
|
|
|
|
|
# Single-layer LSTM character model: 128 hidden units feeding a softmax
# distribution over the vocabulary, trained with categorical cross-entropy.
model = Sequential()
model.add(LSTM(128, input_shape=(max_sequence_length, len(chars))))
model.add(Dense(len(chars), activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
|
|
|
|
|
|
|
|
# Train the model, checkpointing the lowest-loss weights to the user-supplied
# path when one was given (an empty callbacks list is a no-op for fit), then
# report the final epoch's training loss.
callbacks_list = []
if args.save:
    callbacks_list.append(
        ModelCheckpoint(args.save, monitor='loss', verbose=1, save_best_only=True, mode='min')
    )
history = model.fit(X, y, batch_size=128, epochs=args.epochs, callbacks=callbacks_list)

print(f"Final loss: {history.history['loss'][-1]:.4f}")
|
|
|