# roseai / train.py
# cross-hedgehog — "Create train.py" (commit 90874dd)
# NOTE: the lines above were Hugging Face Hub page residue (repo path, author,
# commit message, hash); they are kept here as comments so the file parses.
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint
from datasets import load_dataset
# ---- Command-line interface -------------------------------------------------
# Usage: train.py FILENAME [-s SAVE_PATH] [-e EPOCHS]
arg_parser = argparse.ArgumentParser(description='Train a text generation model.')
arg_parser.add_argument('filename', help='Name of the text file to train on')
arg_parser.add_argument('-s', '--save', help='Name of the file to save the trained model')
arg_parser.add_argument('-e', '--epochs', type=int, default=100, help='Number of training epochs')
args = arg_parser.parse_args()
# NOTE(review): removed `dataset = load_dataset("code_search_net", "python")`.
# The variable `dataset` was never referenced anywhere in this script, and the
# call downloads the entire CodeSearchNet corpus as a side effect before the
# actual training data (args.filename) is even opened.
# ---- Corpus loading and vocabulary ------------------------------------------
# Read the raw training corpus as UTF-8 text.
with open(args.filename, 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
# Distinct characters, sorted, form the model's vocabulary; build both
# directions of the character <-> index mapping.
chars = sorted(set(text))
char_to_index = {symbol: position for position, symbol in enumerate(chars)}
index_to_char = dict(enumerate(chars))
# ---- Training data preparation ----------------------------------------------
# Slide a window of `max_sequence_length` characters over the corpus with the
# given `step`; each window is one input sequence and the character that
# immediately follows it is the prediction target.
max_sequence_length = 100
step = 1
sequences = []
next_chars = []
for i in range(0, len(text) - max_sequence_length, step):
    sequences.append(text[i:i+max_sequence_length])
    next_chars.append(text[i+max_sequence_length])
# One-hot encode inputs (samples, timesteps, vocab) and targets (samples, vocab).
# BUG FIX: `np.bool` was deprecated in NumPy 1.20 and removed in NumPy 1.24, so
# this line raises AttributeError on current NumPy; the builtin `bool` is the
# documented replacement and yields the identical boolean dtype.
X = np.zeros((len(sequences), max_sequence_length, len(chars)), dtype=bool)
y = np.zeros((len(sequences), len(chars)), dtype=bool)
for i, sequence in enumerate(sequences):
    for t, char in enumerate(sequence):
        X[i, t, char_to_index[char]] = 1
    y[i, char_to_index[next_chars[i]]] = 1
# ---- Model ------------------------------------------------------------------
# Single LSTM layer over one-hot characters, then a softmax over the vocabulary
# to predict the next character; categorical cross-entropy with Adam.
model = Sequential()
model.add(LSTM(128, input_shape=(max_sequence_length, len(chars))))
model.add(Dense(len(chars), activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
# ---- Training ---------------------------------------------------------------
# If a save path was supplied, checkpoint the lowest-loss weights to it while
# training; otherwise just fit without callbacks.
if args.save:
    best_only_checkpoint = ModelCheckpoint(
        args.save, monitor='loss', verbose=1, save_best_only=True, mode='min')
    history = model.fit(X, y, batch_size=128, epochs=args.epochs,
                        callbacks=[best_only_checkpoint])
else:
    history = model.fit(X, y, batch_size=128, epochs=args.epochs)
# Report the loss from the final epoch.
print(f"Final loss: {history.history['loss'][-1]:.4f}")