File size: 4,713 Bytes
b620cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79a585
b620cf3
 
 
b79a585
b620cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79a585
b620cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79a585
 
 
b620cf3
b79a585
b620cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79a585
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import tensorflow as tf

from json import dump
from argparse import ArgumentParser

from model import tidev2
from model.vae import VAE
from utils.callbacks import VisualizeCallback, CheckpointCallback
from utils.dataloader import list_filenames, Dataset
from utils.plots import visualize_from_latent_space


if __name__ == '__main__':
    # Train a VAE on a folder of images; saves the training config,
    # checkpoints, TensorBoard logs and periodic latent-space visualizations
    # under --output_path.
    parser = ArgumentParser()

    parser.add_argument("--model_name", required=True, type=str, choices=['tide', 'tidev2'], help='VAE model')
    parser.add_argument("--output_path", default='./results/', type=str, help='Path to store the results')
    # VAE model
    parser.add_argument("--input_shape", default=[320, 320, 3], nargs=3, type=int,
                        help='Image shape for training (H W C)')
    parser.add_argument("--dim_latent", default=8, type=int, help='Dimensionality of latent space')
    # Training
    parser.add_argument("--epochs", default=5000, type=int, help='Number of training epochs')
    parser.add_argument("--batch_size", default=4, type=int, help='Number of training batch size')
    parser.add_argument("--learning_rate", default=0.0002, type=float, help='Learning rate')
    parser.add_argument("--ckpt_interval", default=200, type=int, help='Epoch interval for saving checkpoints')
    parser.add_argument("--visualization_interval", default=25, type=int, help='Epoch interval for visualizing results')
    # Data
    parser.add_argument("--datadir", default='./kid/inflammatory', type=str, help='Folder with images for training')
    parser.add_argument("--files_ext", default='png', type=str, help='Extension of training files')
    parser.add_argument("--files_prefix", default=None, type=str,
                        help='Prefix of training files. Ignore if datadir contains only the appropriate files')
    # BUG FIX: the original used type=tuple, which splits a CLI string such as
    # "320,320" into a tuple of single characters ('3', '2', '0', ',', ...).
    # Accept two integers instead (e.g. --crop_dim 320 320).
    parser.add_argument("--crop_dim", default=None, nargs=2, type=int,
                        help='Dimensions for cropping images (H W). Ignore if images are already cropped')
    args = parser.parse_args()
    args.input_shape = tuple(args.input_shape)  # entries already int via type=int
    if args.crop_dim is not None:
        args.crop_dim = tuple(args.crop_dim)

    # Create folders & Save training config
    os.makedirs(args.output_path, exist_ok=True)
    log_dir = os.path.join(args.output_path, 'logs')
    ckpt_dir = os.path.join(args.output_path, 'checkpoints')
    visualize_dir = os.path.join(args.output_path, 'visualize')

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(visualize_dir, exist_ok=True)

    # Tuples serialize as JSON lists; indent for human readability.
    with open(os.path.join(args.output_path, "training_config.json"), 'w') as fp:
        dump(vars(args), fp, indent=2)

    # Setup training data
    filenames = list_filenames(data_path=args.datadir,
                               img_extension=args.files_ext,
                               filename_prefix=args.files_prefix)
    images = Dataset(filenames,
                     batch_size=args.batch_size,
                     crop_dim=args.crop_dim,
                     resize_dim=args.input_shape[:2])

    # Create Model
    if args.model_name == 'tidev2':
        vae = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=args.dim_latent),
                  tidev2.ConvNeXtDecoderTiny(latent_dim=args.dim_latent,
                                             image_dims=args.input_shape[:2],
                                             out_channels=args.input_shape[-1]))
    else:
        # BUG FIX: 'tide' is an accepted --model_name choice but had no
        # construction branch, so training crashed later with a NameError on
        # `vae`. Fail fast with an explicit message instead.
        raise NotImplementedError(f"Model '{args.model_name}' is not implemented in this script")

    vae.build((None, *args.input_shape))
    vae.compile(optimizer=tf.keras.optimizers.Adam(args.learning_rate))

    # Training
    callbacks = [
        VisualizeCallback(args.visualization_interval,
                          lambda model, epoch: visualize_from_latent_space(
                              latent_dim=args.dim_latent,
                              input_shape=args.input_shape,
                              vae=model,
                              output_path=visualize_dir,
                              epoch=epoch,
                              num_items=10)),
        CheckpointCallback(vae=vae,
                           path=ckpt_dir,
                           epoch_interval=args.ckpt_interval,
                           restore_training=False,
                           restore_path=None),
        tf.keras.callbacks.TensorBoard(log_dir=log_dir),
    ]

    # NOTE(review): `images` already yields batches (Dataset receives
    # batch_size above), so `batch_size` is not forwarded to fit() — Keras
    # disallows/ignores it for generator-style inputs.
    vae.fit(x=images,
            epochs=args.epochs,
            callbacks=callbacks,
            shuffle=True,
            initial_epoch=0)

    print('Training finished')