Spaces:
Sleeping
Sleeping
import os

import keras
import tensorflow as tf

from make_dataset import train_dataset, valid_dataset
from src.components.model import (
    get_cnn_model,
    TransformerEncoderBlock,
    TransformerDecoderBlock,
    ImageCaptioningModel,
    image_augmentation,
    LRSchedule,
)
# Training script: assembles the captioning model from its CNN, encoder and
# decoder parts, trains it with a warmup learning-rate schedule and early
# stopping, then saves the trained model to disk.

# --- Hyperparameters -------------------------------------------------------
EMBED_DIM = 512  # passed as embed_dim to both transformer blocks
FF_DIM = 512  # passed as dense_dim (encoder) and ff_dim (decoder)
EPOCHS = 30  # upper bound; early stopping may end training sooner

# Where the trained model is written once training finishes.
MODEL_SAVE_PATH = "./artifacts/caption_model1.keras"

# --- Model assembly --------------------------------------------------------
cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(
    embed_dim=EMBED_DIM,
    dense_dim=FF_DIM,
    num_heads=1,
)
decoder = TransformerDecoderBlock(
    embed_dim=EMBED_DIM,
    ff_dim=FF_DIM,
    num_heads=2,
)
caption_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)

# --- Training configuration ------------------------------------------------
# Stop once the monitored quantity (Keras' default for EarlyStopping) has not
# improved for 3 consecutive epochs, and roll back to the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(
    patience=3,
    restore_best_weights=True,
)

# Warmup spans the first 1/15 of the total optimizer steps.
# Assumes len(train_dataset) is the number of batches per epoch -- TODO confirm
# against make_dataset.
num_train_steps = len(train_dataset) * EPOCHS
num_warmup_steps = num_train_steps // 15
lr_schedule = LRSchedule(
    post_warmup_learning_rate=1e-4,
    warmup_steps=num_warmup_steps,
)

caption_model.compile(
    optimizer=keras.optimizers.Adam(lr_schedule),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Make sure the output directory exists *before* training starts: a missing
# or unwritable directory should fail here, not after all epochs have run.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

# --- Train and persist -----------------------------------------------------
caption_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=valid_dataset,
    callbacks=[early_stopping],
)
caption_model.save(MODEL_SAVE_PATH)