|
|
import os |
|
|
import tensorflow as tf |
|
|
|
|
|
from json import dump |
|
|
from argparse import ArgumentParser |
|
|
|
|
|
from model import tidev2 |
|
|
from model.vae import VAE |
|
|
from utils.callbacks import VisualizeCallback, CheckpointCallback |
|
|
from utils.dataloader import list_filenames, Dataset |
|
|
from utils.plots import visualize_from_latent_space |
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    parser = ArgumentParser()

    # Model selection and output location.
    parser.add_argument("--model_name", required=True, type=str, choices=['tide', 'tidev2'], help='VAE model')
    parser.add_argument("--output_path", default='./results/', type=str, help='Path to store the results')

    # Architecture parameters.
    parser.add_argument("--input_shape", default=[320, 320, 3], nargs=3, type=int,
                        help='Image shape for training (height width channels)')
    parser.add_argument("--dim_latent", default=8, type=int, help='Dimensionality of latent space')

    # Training hyperparameters.
    parser.add_argument("--epochs", default=5000, type=int, help='Number of training epochs')
    parser.add_argument("--batch_size", default=4, type=int, help='Number of training batch size')
    parser.add_argument("--learning_rate", default=0.0002, type=float, help='Learning rate')
    parser.add_argument("--ckpt_interval", default=200, type=int, help='Epoch interval for saving checkpoints')
    parser.add_argument("--visualization_interval", default=25, type=int, help='Epoch interval for visualizing results')

    # Dataset location and file filtering.
    parser.add_argument("--datadir", default='./kid/inflammatory', type=str, help='Folder with images for training')
    parser.add_argument("--files_ext", default='png', type=str, help='Extension of training files')
    parser.add_argument("--files_prefix", default=None, type=str,
                        help='Prefix of training files. Ignore if datadir contains only the appropriate files')
    # BUG FIX: the original used type=tuple, which splits the argument string into
    # single characters (e.g. "160" -> ('1', '6', '0')). Accept two ints instead:
    # --crop_dim HEIGHT WIDTH.
    parser.add_argument("--crop_dim", default=None, nargs=2, type=int,
                        help='Height and width for cropping images. Ignore if images are already cropped')

    args = parser.parse_args()
    # Normalize list arguments to tuples so downstream code gets immutable shapes.
    args.input_shape = tuple(args.input_shape)
    if args.crop_dim is not None:
        args.crop_dim = tuple(args.crop_dim)

    # Output layout: <output_path>/{logs, checkpoints, visualize}.
    os.makedirs(args.output_path, exist_ok=True)
    log_dir = os.path.join(args.output_path, 'logs')
    ckpt_dir = os.path.join(args.output_path, 'checkpoints')
    visualize_dir = os.path.join(args.output_path, 'visualize')

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(visualize_dir, exist_ok=True)

    # Persist the full run configuration for reproducibility.
    with open(os.path.join(args.output_path, "training_config.json"), 'w') as fp:
        dump(vars(args), fp, indent=2)

    # Build the batched image loader from the files found in datadir.
    filenames = list_filenames(data_path=args.datadir,
                               img_extension=args.files_ext,
                               filename_prefix=args.files_prefix)
    images = Dataset(filenames,
                     batch_size=args.batch_size,
                     crop_dim=args.crop_dim,
                     resize_dim=args.input_shape[:2])

    if args.model_name == 'tidev2':
        vae = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=args.dim_latent),
                  tidev2.ConvNeXtDecoderTiny(latent_dim=args.dim_latent,
                                             image_dims=args.input_shape[:2],
                                             out_channels=args.input_shape[-1]))
    else:
        # BUG FIX: 'tide' is an accepted --model_name choice but had no construction
        # branch, so the original script crashed later with a NameError on the
        # undefined `vae`. Fail fast with a clear message instead.
        raise NotImplementedError(f"Model '{args.model_name}' is not implemented yet")

    vae.build((None, *args.input_shape))
    vae.compile(optimizer=tf.keras.optimizers.Adam(args.learning_rate))

    callbacks = [
        # Periodically decode samples from the latent space for visual inspection.
        VisualizeCallback(args.visualization_interval,
                          lambda model, epoch: visualize_from_latent_space(
                              latent_dim=args.dim_latent,
                              input_shape=args.input_shape,
                              vae=model,
                              output_path=visualize_dir,
                              epoch=epoch,
                              num_items=10)),
        CheckpointCallback(vae=vae,
                           path=ckpt_dir,
                           epoch_interval=args.ckpt_interval,
                           restore_training=False,
                           restore_path=None),
        tf.keras.callbacks.TensorBoard(log_dir=log_dir),
    ]

    # NOTE: `Dataset` already yields batches (it receives batch_size above), and
    # Keras rejects a `batch_size` argument for generator/Sequence inputs, so it
    # must not be passed to fit().
    vae.fit(x=images,
            epochs=args.epochs,
            callbacks=callbacks,
            shuffle=True,
            initial_epoch=0)

    print('Training finished')