"""Train a character-level LSTM text generation model on a plain-text file.

The script reads the file named on the command line, one-hot encodes sliding
windows of ``max_sequence_length`` characters, and trains a single-layer LSTM
to predict the character that follows each window.

Command-line arguments:
    filename      Text file to train on (read as UTF-8).
    -s/--save     Optional path; when given, the model is checkpointed there
                  every time the training loss improves.
    -e/--epochs   Number of training epochs (default 100).
"""
import argparse

import numpy as np
import tensorflow as tf  # noqa: F401  (not referenced directly below)
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint
from datasets import load_dataset  # noqa: F401  (not referenced directly below)

# Parse command-line arguments
parser = argparse.ArgumentParser(description='Train a text generation model.')
parser.add_argument('filename', help='Name of the text file to train on')
parser.add_argument('-s', '--save', help='Name of the file to save the trained model')
parser.add_argument('-e', '--epochs', type=int, default=100, help='Number of training epochs')
args = parser.parse_args()

# Load the text file; it is the only source of training data.
with open(args.filename, 'r', encoding='utf-8') as f:
    text = f.read()

# Window length (characters fed to the model) and stride between windows.
max_sequence_length = 100
step = 1

# Every sample needs a full window plus the character that follows it, so a
# shorter text would yield zero samples and fail obscurely inside fit().
if len(text) <= max_sequence_length:
    parser.error(
        f"{args.filename!r} contains {len(text)} characters; "
        f"at least {max_sequence_length + 1} are required"
    )

# Create character mappings
chars = sorted(set(text))
char_to_index = {char: index for index, char in enumerate(chars)}
index_to_char = {index: char for index, char in enumerate(chars)}

# Prepare the training data.
# `encoded` holds the vocabulary index of every character in the text;
# `starts` holds the offset at which each training window begins.
encoded = np.array([char_to_index[char] for char in text], dtype=np.int64)
starts = np.arange(0, len(text) - max_sequence_length, step)
rows = np.arange(len(starts))

# Use the builtin `bool`: the `np.bool` alias is absent from NumPy 1.24-1.26.
X = np.zeros((len(starts), max_sequence_length, len(chars)), dtype=bool)
y = np.zeros((len(starts), len(chars)), dtype=bool)

# One-hot encode one time step at a time: for offset t, window i contains the
# character at text[starts[i] + t].
for t in range(max_sequence_length):
    X[rows, t, encoded[starts + t]] = True
# The target of window i is the character immediately after it.
y[rows, encoded[starts + max_sequence_length]] = True

# Define the model
model = Sequential([
    LSTM(128, input_shape=(max_sequence_length, len(chars))),
    Dense(len(chars), activation='softmax'),
])
model.compile(loss='categorical_crossentropy', optimizer='adam')

# Train the model, checkpointing on improved training loss only when a save
# path was requested.
# NOTE(review): newer Keras releases appear to require the checkpoint path to
# end in `.keras` or `.h5` -- confirm against the installed version.
callbacks_list = []
if args.save:
    checkpoint = ModelCheckpoint(
        args.save,
        monitor='loss',
        verbose=1,
        save_best_only=True,
        mode='min',
    )
    callbacks_list.append(checkpoint)

history = model.fit(
    X,
    y,
    batch_size=128,
    epochs=args.epochs,
    callbacks=callbacks_list,
)

# Print the final loss
print(f"Final loss: {history.history['loss'][-1]:.4f}")