import argparse
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint
from datasets import load_dataset
# --- Command-line interface ----------------------------------------------
# Usage: script.py FILENAME [-s SAVE_PATH] [-e EPOCHS]
parser = argparse.ArgumentParser(description='Train a text generation model.')
parser.add_argument('filename', help='Name of the text file to train on')
parser.add_argument('-s', '--save', help='Name of the file to save the trained model')
parser.add_argument('-e', '--epochs', type=int, default=100, help='Number of training epochs')
args = parser.parse_args()
# NOTE(review): removed `dataset = load_dataset("code_search_net", "python")`.
# Its return value was never used anywhere in this script, and the call
# downloads a multi-gigabyte dataset on every run. Training reads only
# from `args.filename` below.
# Read the training corpus and build character-level vocabulary mappings.
with open(args.filename, 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
# Vocabulary: every distinct character, sorted so index assignment is
# deterministic across runs on the same corpus.
chars = sorted(set(text))
char_to_index = {c: i for i, c in enumerate(chars)}
index_to_char = dict(enumerate(chars))
# --- Prepare training data ------------------------------------------------
# Slide a window of `max_sequence_length` characters over the text with the
# given `step`; each window is an input sequence and the character that
# immediately follows it is the prediction target.
max_sequence_length = 100
step = 1
sequences = []
next_chars = []
for i in range(0, len(text) - max_sequence_length, step):
    sequences.append(text[i:i+max_sequence_length])
    next_chars.append(text[i+max_sequence_length])
# One-hot encode: X[i, t, c] = 1 iff character c occurs at position t of
# sequence i; y[i, c] = 1 iff character c is the target for sequence i.
# BUG FIX: the `np.bool` alias was removed in NumPy 1.24 and raises
# AttributeError there — the builtin `bool` is the supported dtype spelling.
X = np.zeros((len(sequences), max_sequence_length, len(chars)), dtype=bool)
y = np.zeros((len(sequences), len(chars)), dtype=bool)
for i, sequence in enumerate(sequences):
    for t, char in enumerate(sequence):
        X[i, t, char_to_index[char]] = 1
    y[i, char_to_index[next_chars[i]]] = 1
# --- Model definition -----------------------------------------------------
# A single 128-unit LSTM reads the one-hot character windows; a softmax
# layer over the vocabulary predicts the next character.
model = Sequential()
model.add(LSTM(128, input_shape=(max_sequence_length, len(chars))))
model.add(Dense(len(chars), activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
# --- Training -------------------------------------------------------------
# When --save is given, checkpoint the lowest-loss weights to that path;
# otherwise train with no callbacks. The fit call is shared between both
# cases so the two code paths cannot drift apart (the original duplicated
# `model.fit(...)` in each branch).
callbacks_list = []
if args.save:
    callbacks_list.append(
        ModelCheckpoint(args.save, monitor='loss', verbose=1,
                        save_best_only=True, mode='min')
    )
history = model.fit(X, y, batch_size=128, epochs=args.epochs,
                    callbacks=callbacks_list)
# Report the training loss from the final epoch.
print(f"Final loss: {history.history['loss'][-1]:.4f}")