image-Captioning / src /train.py
Anki2004's picture
Upload 117 files
c90c141 verified
Raw
History Blame Contribute Delete
1.36 kB
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pickle
from src.data_loader import load_data
from src.model import build_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
def train_model(captions_file, images_dir, epochs=20, batch_size=32):
# Load and prepare data
train_data, val_data, tokenizer, max_length, vocab_size = load_data(captions_file, images_dir)
# Build model
model = build_model(vocab_size, max_length)
# Define callbacks
checkpoint = ModelCheckpoint('models/model.h5', save_best_only=True, monitor='val_loss', mode='min')
early_stopping = EarlyStopping(patience=5, monitor='val_loss', mode='min')
# Train model
history = model.fit(
train_data[0], train_data[1],
epochs=epochs,
batch_size=batch_size,
validation_data=val_data,
callbacks=[checkpoint, early_stopping]
)
# Save tokenizer
with open('models/tokenizer.pickle', 'wb') as handle:
pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
# Save max_length
with open('models/max_length.txt', 'w') as f:
f.write(str(max_length))
return history
if __name__ == "__main__":
train_model('data\caption\captions.txt', 'data/images')