| | import tensorflow as tf |
| | from chatbot_model import RetrievalChatbot, ChatbotConfig |
| | from environment_setup import EnvironmentSetup |
| | from plotter import Plotter |
| |
|
| | from logger_config import config_logger |
| | logger = config_logger(__name__) |
| |
|
def inspect_tfrecord(tfrecord_file_path, num_examples=3, max_length=512, num_negatives=3):
    """Print the first few parsed examples from a chatbot-training TFRecord file.

    Args:
        tfrecord_file_path: Path to the TFRecord file to inspect.
        num_examples: Number of examples to print (default 3).
        max_length: Token-id sequence length of the 'query_ids' and
            'positive_ids' features (default 512 — presumably the writer's
            padding length; confirm against the data-generation code).
        num_negatives: Number of negative samples concatenated per example
            (default 3); 'negative_ids' holds num_negatives * max_length ids.
    """
    def parse_example(example_proto):
        # Schema must mirror exactly what the TFRecord writer serialized;
        # a mismatch makes tf.io.parse_single_example raise at iteration time.
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([num_negatives * max_length], tf.int64),
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i+1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)
| | |
| | def main(): |
| | tf.keras.backend.clear_session() |
| | |
| | |
| | |
| | |
| | |
| | env = EnvironmentSetup() |
| | env.initialize() |
| | |
| | |
| | EPOCHS = 20 |
| | TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord' |
| | CHECKPOINT_DIR = 'checkpoints/' |
| | |
| | batch_size = 32 |
| | |
| | |
| | config = ChatbotConfig() |
| | chatbot = RetrievalChatbot(config, mode='training') |
| | |
| | |
| | latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR) |
| | initial_epoch = 0 |
| | if latest_checkpoint: |
| | try: |
| | ckpt_number = int(latest_checkpoint.split('ckpt-')[-1]) |
| | initial_epoch = ckpt_number |
| | logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}") |
| | except (IndexError, ValueError): |
| | logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}") |
| | initial_epoch = 0 |
| | |
| | |
| | chatbot.train_model( |
| | tfrecord_file_path=TF_RECORD_FILE_PATH, |
| | epochs=EPOCHS, |
| | batch_size=batch_size, |
| | use_lr_schedule=True, |
| | test_mode=True, |
| | checkpoint_dir=CHECKPOINT_DIR, |
| | initial_epoch=initial_epoch |
| | ) |
| | |
| | |
| | model_save_path = env.training_dirs['base'] / 'final_model' |
| | chatbot.save_models(model_save_path) |
| | |
| | |
| | plotter = Plotter(save_dir=env.training_dirs['plots']) |
| | plotter.plot_training_history(chatbot.history) |
| | |
| | if __name__ == "__main__": |
| | main() |