cross-hedgehog committed on
Commit
90874dd
·
1 Parent(s): 50fb81e

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +57 -0
train.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Train a character-level LSTM text-generation model on a UTF-8 text file.

Usage:
    python train.py corpus.txt [-s MODEL_PATH] [-e EPOCHS]
"""

import argparse

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential


def _load_text(filename):
    """Read the whole training corpus as UTF-8 text."""
    with open(filename, 'r', encoding='utf-8') as f:
        return f.read()


def _vectorize(text, max_sequence_length=100, step=1):
    """One-hot encode overlapping character windows of *text*.

    Returns ``(X, y, chars)`` where ``X`` has shape
    ``(num_sequences, max_sequence_length, vocab)``, ``y`` has shape
    ``(num_sequences, vocab)``, and ``chars`` is the sorted vocabulary.
    """
    chars = sorted(set(text))
    char_to_index = {char: index for index, char in enumerate(chars)}

    sequences = []
    next_chars = []
    for i in range(0, len(text) - max_sequence_length, step):
        sequences.append(text[i:i + max_sequence_length])
        next_chars.append(text[i + max_sequence_length])

    # np.bool was removed in NumPy 1.24 — use the builtin bool dtype instead.
    X = np.zeros((len(sequences), max_sequence_length, len(chars)), dtype=bool)
    y = np.zeros((len(sequences), len(chars)), dtype=bool)
    for i, sequence in enumerate(sequences):
        for t, char in enumerate(sequence):
            X[i, t, char_to_index[char]] = 1
        y[i, char_to_index[next_chars[i]]] = 1
    return X, y, chars


def _build_model(max_sequence_length, vocab_size):
    """Single-layer LSTM followed by a softmax over the character vocabulary."""
    model = Sequential([
        LSTM(128, input_shape=(max_sequence_length, vocab_size)),
        Dense(vocab_size, activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model


def main():
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='Train a text generation model.')
    parser.add_argument('filename', help='Name of the text file to train on')
    parser.add_argument('-s', '--save', help='Name of the file to save the trained model')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='Number of training epochs')
    args = parser.parse_args()

    text = _load_text(args.filename)

    max_sequence_length = 100
    # Guard: a corpus no longer than the window yields zero training samples
    # and would make model.fit fail with an opaque error.
    if len(text) <= max_sequence_length:
        parser.error(
            f'Training file must contain more than {max_sequence_length} characters.'
        )

    X, y, chars = _vectorize(text, max_sequence_length)
    model = _build_model(max_sequence_length, len(chars))

    # Train, checkpointing the lowest-loss weights when --save is given.
    if args.save:
        checkpoint = ModelCheckpoint(args.save, monitor='loss', verbose=1,
                                     save_best_only=True, mode='min')
        history = model.fit(X, y, batch_size=128, epochs=args.epochs,
                            callbacks=[checkpoint])
    else:
        history = model.fit(X, y, batch_size=128, epochs=args.epochs)

    # Report the final training loss.
    print(f"Final loss: {history.history['loss'][-1]:.4f}")


if __name__ == '__main__':
    main()