"""Command-line entry point for training a VAE.

Parses the training options, creates the output folder layout
(``logs/``, ``checkpoints/``, ``visualize/``), stores the effective
configuration as ``training_config.json`` and fits the model.
"""
import os
from argparse import ArgumentParser
from json import dump

import tensorflow as tf

from model import tidev2
from model.vae import VAE
from utils.callbacks import VisualizeCallback, CheckpointCallback
from utils.dataloader import list_filenames, Dataset
from utils.plots import visualize_from_latent_space


def _parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace, with ``input_shape`` as a tuple of ints and
        ``crop_dim`` as a tuple of ints or ``None``.
    """
    parser = ArgumentParser()
    parser.add_argument("--model_name", required=True, type=str, choices=['tide', 'tidev2'],
                        help='VAE model')
    parser.add_argument("--output_path", default='./results/', type=str,
                        help='Path to store the results')

    # VAE model
    parser.add_argument("--input_shape", default=[320, 320, 3], nargs=3, type=int,
                        help='Image shape for training')
    parser.add_argument("--dim_latent", default=8, type=int,
                        help='Dimensionality of latent space')

    # Training
    parser.add_argument("--epochs", default=5000, type=int,
                        help='Number of training epochs')
    parser.add_argument("--batch_size", default=4, type=int,
                        help='Number of training batch size')
    parser.add_argument("--learning_rate", default=0.0002, type=float,
                        help='Learning rate')
    parser.add_argument("--ckpt_interval", default=200, type=int,
                        help='Epoch interval for saving checkpoints')
    parser.add_argument("--visualization_interval", default=25, type=int,
                        help='Epoch interval for visualizing results')

    # Data
    parser.add_argument("--datadir", default='./kid/inflammatory', type=str,
                        help='Folder with images for training')
    parser.add_argument("--files_ext", default='png', type=str,
                        help='Extension of training files')
    parser.add_argument("--files_prefix", default=None, type=str,
                        help='Prefix of training files. Ignore if datadir contains only the appropriate files')
    # `type=tuple` would split the raw string into single characters
    # ("256" -> ('2', '5', '6')), so the values are parsed as a list of ints.
    # NOTE(review): the number of values Dataset expects for crop_dim is not
    # visible from this file - confirm against utils.dataloader.
    parser.add_argument("--crop_dim", default=None, nargs='+', type=int,
                        help='Dimensions for cropping images. Ignore if images are already cropped')

    args = parser.parse_args(argv)

    # Only the 'tidev2' architecture is wired up in this script. Fail before
    # any folder or file is created instead of hitting a NameError later.
    if args.model_name != 'tidev2':
        parser.error(f"model '{args.model_name}' is not supported by this training script")

    args.input_shape = tuple(args.input_shape)
    if args.crop_dim is not None:
        args.crop_dim = tuple(args.crop_dim)
    return args


def main(argv=None):
    """Run the training pipeline described by the command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    # Create folders & save training config
    log_dir = os.path.join(args.output_path, 'logs')
    ckpt_dir = os.path.join(args.output_path, 'checkpoints')
    visualize_dir = os.path.join(args.output_path, 'visualize')
    for folder in (args.output_path, log_dir, ckpt_dir, visualize_dir):
        os.makedirs(folder, exist_ok=True)

    with open(os.path.join(args.output_path, "training_config.json"), 'w') as fp:
        dump(vars(args), fp)

    # Setup training data
    filenames = list_filenames(data_path=args.datadir,
                               img_extension=args.files_ext,
                               filename_prefix=args.files_prefix)
    images = Dataset(filenames,
                     batch_size=args.batch_size,
                     crop_dim=args.crop_dim,
                     resize_dim=args.input_shape[:2])

    # Create model (model_name was validated in _parse_args)
    vae = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=args.dim_latent),
              tidev2.ConvNeXtDecoderTiny(latent_dim=args.dim_latent,
                                         image_dims=args.input_shape[:2],
                                         out_channels=args.input_shape[-1]))
    vae.build((None, *args.input_shape))
    vae.compile(optimizer=tf.keras.optimizers.Adam(args.learning_rate))

    def _visualize(model, epoch):
        # Forwarded to VisualizeCallback, which supplies the model and epoch.
        return visualize_from_latent_space(latent_dim=args.dim_latent,
                                           input_shape=args.input_shape,
                                           vae=model,
                                           output_path=visualize_dir,
                                           epoch=epoch,
                                           num_items=10)

    # Training
    callbacks = [
        VisualizeCallback(args.visualization_interval, _visualize),
        CheckpointCallback(vae=vae,
                           path=ckpt_dir,
                           epoch_interval=args.ckpt_interval,
                           restore_training=False,
                           restore_path=None),
        tf.keras.callbacks.TensorBoard(log_dir=log_dir),
    ]

    # NOTE(review): `images` is already built with batch_size; Keras documents
    # that batch_size should not be passed to fit() for inputs that yield
    # batches themselves - confirm what kind of object Dataset is.
    vae.fit(x=images,
            epochs=args.epochs,
            batch_size=args.batch_size,
            callbacks=callbacks,
            shuffle=True,
            initial_epoch=0)
    print('Training finished')


if __name__ == '__main__':
    main()