Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import pickle | |
| from src.data_loader import load_data | |
| from src.model import build_model | |
| from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping | |
| def train_model(captions_file, images_dir, epochs=20, batch_size=32): | |
| # Load and prepare data | |
| train_data, val_data, tokenizer, max_length, vocab_size = load_data(captions_file, images_dir) | |
| # Build model | |
| model = build_model(vocab_size, max_length) | |
| # Define callbacks | |
| checkpoint = ModelCheckpoint('models/model.h5', save_best_only=True, monitor='val_loss', mode='min') | |
| early_stopping = EarlyStopping(patience=5, monitor='val_loss', mode='min') | |
| # Train model | |
| history = model.fit( | |
| train_data[0], train_data[1], | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| validation_data=val_data, | |
| callbacks=[checkpoint, early_stopping] | |
| ) | |
| # Save tokenizer | |
| with open('models/tokenizer.pickle', 'wb') as handle: | |
| pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL) | |
| # Save max_length | |
| with open('models/max_length.txt', 'w') as f: | |
| f.write(str(max_length)) | |
| return history | |
| if __name__ == "__main__": | |
| train_model('data\caption\captions.txt', 'data/images') |