File size: 2,166 Bytes
90874dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint
from datasets import load_dataset

# Defaults preserved from the original script; all but the LSTM width are
# overridable from the command line.
DEFAULT_SEQUENCE_LENGTH = 100
DEFAULT_STEP = 1
DEFAULT_BATCH_SIZE = 128
LSTM_UNITS = 128


def _positive_int(value):
    """argparse ``type=`` callable: parse *value* as an int that is >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def build_parser():
    """Return the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description='Train a text generation model.')
    parser.add_argument('filename', help='Name of the text file to train on')
    parser.add_argument('-s', '--save', help='Name of the file to save the trained model')
    # Zero/negative epochs would leave history.history without a 'loss' key
    # and crash when the final loss is printed, so require at least one.
    parser.add_argument('-e', '--epochs', type=_positive_int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--seq-length', type=_positive_int, default=DEFAULT_SEQUENCE_LENGTH,
                        help='Characters of context per training example')
    parser.add_argument('--step', type=_positive_int, default=DEFAULT_STEP,
                        help='Stride between consecutive training windows')
    parser.add_argument('--batch-size', type=_positive_int, default=DEFAULT_BATCH_SIZE,
                        help='Training batch size')
    return parser


def build_vocabulary(text):
    """Return ``(chars, char_to_index)`` for the distinct characters of *text*.

    ``chars`` is the sorted list of distinct characters; ``char_to_index``
    maps each character to its position in ``chars``.
    """
    chars = sorted(set(text))
    char_to_index = {char: index for index, char in enumerate(chars)}
    return chars, char_to_index


def vectorize_text(text, char_to_index, max_sequence_length=DEFAULT_SEQUENCE_LENGTH,
                   step=DEFAULT_STEP):
    """One-hot encode sliding windows of *text* and their next characters.

    A window starts at every ``step``-th offset ``s`` with
    ``s < len(text) - max_sequence_length``; its input is
    ``text[s:s + max_sequence_length]`` and its target is the character right
    after it.

    Args:
        text: The training corpus.
        char_to_index: Maps every character of *text* to an index in
            ``range(len(char_to_index))`` (as built by ``build_vocabulary``).
        max_sequence_length: Window length in characters.
        step: Stride between window start offsets (must be >= 1).

    Returns:
        ``(X, y)`` boolean arrays of shapes
        ``(num_windows, max_sequence_length, vocab_size)`` and
        ``(num_windows, vocab_size)``. ``num_windows`` is 0 when *text* is not
        longer than ``max_sequence_length``.

    Raises:
        KeyError: if *text* contains a character missing from *char_to_index*.
    """
    vocab_size = len(char_to_index)
    # Encode the corpus once; all windows are then gathered with array indexing
    # instead of a per-character Python loop.
    encoded = np.fromiter((char_to_index[char] for char in text),
                          dtype=np.intp, count=len(text))
    starts = np.arange(0, len(text) - max_sequence_length, step)
    offsets = np.arange(max_sequence_length)
    windows = encoded[starts[:, np.newaxis] + offsets]  # (num_windows, L)
    rows = np.arange(len(starts))

    # dtype=bool: the np.bool alias was removed in NumPy 1.24.
    X = np.zeros((len(starts), max_sequence_length, vocab_size), dtype=bool)
    y = np.zeros((len(starts), vocab_size), dtype=bool)
    X[rows[:, np.newaxis], offsets, windows] = True
    y[rows, encoded[starts + max_sequence_length]] = True
    return X, y


def build_model(max_sequence_length, vocab_size):
    """Return a compiled single-LSTM next-character classifier."""
    model = Sequential([
        LSTM(LSTM_UNITS, input_shape=(max_sequence_length, vocab_size)),
        Dense(vocab_size, activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model


def main(argv=None):
    """Parse arguments, train the model on the given file, and print the final loss."""
    parser = build_parser()
    args = parser.parse_args(argv)

    try:
        with open(args.filename, 'r', encoding='utf-8') as f:
            text = f.read()
    except OSError as err:
        parser.error(f"cannot read {args.filename}: {err}")

    # Without at least one (window, next char) pair Keras fails obscurely on
    # an empty training set; report the real problem instead.
    if len(text) <= args.seq_length:
        parser.error(f"{args.filename} has {len(text)} characters; more than "
                     f"--seq-length ({args.seq_length}) are needed to train")

    chars, char_to_index = build_vocabulary(text)
    X, y = vectorize_text(text, char_to_index, args.seq_length, args.step)
    model = build_model(args.seq_length, len(chars))

    callbacks_list = []
    if args.save:
        callbacks_list.append(ModelCheckpoint(args.save, monitor='loss', verbose=1,
                                              save_best_only=True, mode='min'))
    history = model.fit(X, y, batch_size=args.batch_size, epochs=args.epochs,
                        callbacks=callbacks_list)

    print(f"Final loss: {history.history['loss'][-1]:.4f}")


if __name__ == "__main__":
    main()