pgatoula committed
Commit b79a585 · 1 Parent(s): 7eb7541

Minor corrections

Files changed (5):
  1. model/tidev2.py +16 -5
  2. model/tidev2_utils.py +7 -7
  3. train.py +8 -3
  4. utils/dataloader.py +11 -4
  5. utils/plots.py +8 -6
model/tidev2.py CHANGED
@@ -87,17 +87,28 @@ class ConvNeXtDecoderTiny(Model):
                  drop_path_rate=0.0,
                  layer_scale_init_value=1e-6,
                  model_name="convnext",
-                 latent_dim=None):
+                 latent_dim=None,
+                 image_dims=(320, 320),
+                 out_channels=3):
         super().__init__(name=model_name)
 
         if latent_dim is None:
             raise ValueError("latent_dim must be specified for decoder")
 
         # Intro layer (dense + reshape)
+        # self.intro = Sequential([
+        #     layers.Dense(10 * 10 * projection_dims[0], activation="relu"),
+        #     layers.Reshape((10, 10, projection_dims[0]))
+        # ], name=model_name + "_intro")
+        # TODO
+        downsample_factor = 4 * 2 * 2 * 2
+        input_height, input_width = image_dims
+        init_h = input_height // downsample_factor
+        init_w = input_width // downsample_factor
         self.intro = Sequential([
-            layers.Dense(10 * 10 * projection_dims[0], activation="relu"),
-            layers.Reshape((10, 10, projection_dims[0]))
-        ], name=model_name + "_intro")
+            layers.Dense(init_h * init_w * projection_dims[0], activation="relu"),
+            layers.Reshape((init_h, init_w, projection_dims[0]))
+        ])
 
         # Upsampling layers
         self.upsample_layers = [self.intro]
@@ -133,7 +144,7 @@ class ConvNeXtDecoderTiny(Model):
         ], name=model_name + "_top")
 
         self.top_layer = TopLayer(filters=96)
-        self.pred_layer = layers.Conv2DTranspose(3, kernel_size=1, activation="sigmoid",
+        self.pred_layer = layers.Conv2DTranspose(out_channels, kernel_size=1, activation="sigmoid",
                                                   padding="same", name="pred_layer")
 
     def call(self, inputs, training=False):
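The new intro block derives its starting grid from image_dims instead of hard-coding 10 x 10: the commit's downsample_factor = 4 * 2 * 2 * 2 reflects a stride-4 stem followed by three stride-2 stages, a total downsampling of 32, so the default 320 x 320 still lands on 10 x 10. A minimal standalone sketch of the same sizing (the 768-channel first_dim stands in for projection_dims[0] and is an assumption, not a value from the repo):

# Sketch of the decoder's intro sizing; first_dim is a hypothetical stand-in.
import tensorflow as tf
from tensorflow.keras import Sequential, layers

image_dims = (320, 320)            # decoder output resolution
downsample_factor = 4 * 2 * 2 * 2  # stem stride 4, then three stride-2 stages
init_h = image_dims[0] // downsample_factor  # 320 // 32 = 10
init_w = image_dims[1] // downsample_factor  # 320 // 32 = 10

first_dim = 768  # assumed projection_dims[0] for illustration
intro = Sequential([
    layers.Dense(init_h * init_w * first_dim, activation="relu"),
    layers.Reshape((init_h, init_w, first_dim)),
])
print(intro(tf.zeros((1, 8))).shape)  # (1, 10, 10, 768)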
model/tidev2_utils.py CHANGED
@@ -8,24 +8,24 @@ class TopLayer(layers.Layer):
         self.filters = filters
 
         self.conv_1x1 = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
-                                      name="_top_layer")
+                                      name="top_layer_1x1")
         self.conv_2x2 = layers.Conv2D(self.filters//3, (2, 2), activation='relu', strides=1, padding="same",
-                                      name="_top_layer")
+                                      name="top_layer_2x2")
         self.conv_4x4 = layers.Conv2D(self.filters//3, (4, 4), activation='relu', strides=1, padding="same",
-                                      name="_top_layer")
+                                      name="top_layer_4x4")
         self.conv_8x8 = layers.Conv2D(self.filters//3, (8, 8), activation='relu', strides=1, padding="same",
-                                      name="_top_layer")
+                                      name="top_layer_8x8")
 
         self.concat = layers.Concatenate(axis=-1)
         self.point_wise_conv = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
-                                             padding='same', name="_top_layer")
+                                             padding='same', name="top_layer_point_wise")
         self.feat_fusion = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
-                                         padding='same', name="_top_layer")
+                                         padding='same', name="top_layer_fusion")
 
         self.addition = layers.Add()
         self.gelu = layers.Activation('gelu')
         self.final_conv = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
-                                        name="_top_layer")
+                                        name="top_layer_out")
 
     def call(self, inputs, training=False):
         x = self.conv_1x1(inputs, training=training)
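Giving each Conv2D a distinct name matters beyond readability: in a subclassed layer like TopLayer, seven sublayers all named "_top_layer" make variable paths and weight inspection ambiguous, and the Functional API rejects duplicate names outright. A toy check, not repo code, illustrating the stricter case:

# Toy example: duplicate layer names are rejected by the Functional API.
import tensorflow as tf
from tensorflow.keras import layers

inp = layers.Input((32, 32, 3))
x = layers.Conv2D(8, 1, name="_top_layer")(inp)
x = layers.Conv2D(8, 1, name="_top_layer")(x)  # same name reused
try:
    tf.keras.Model(inp, x)
except ValueError as err:
    print(err)  # "... All layer names should be unique."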
train.py CHANGED
@@ -4,7 +4,6 @@ import tensorflow as tf
 from json import dump
 from argparse import ArgumentParser
 
-
 from model import tidev2
 from model.vae import VAE
 from utils.callbacks import VisualizeCallback, CheckpointCallback
@@ -14,10 +13,11 @@ from utils.plots import visualize_from_latent_space
 
 if __name__ == '__main__':
     parser = ArgumentParser()
+
     parser.add_argument("--model_name", required=True, type=str, choices=['tide', 'tidev2'], help='VAE model')
     parser.add_argument("--output_path", default='./results/', type=str, help='Path to store the results')
     # VAE model
-    parser.add_argument("--input_shape", default=(320, 320, 3), type=tuple, help='Image shape for training')
+    parser.add_argument("--input_shape", default=[320, 320, 3], nargs=3, help='Image shape for training')
     parser.add_argument("--dim_latent", default=8, type=int, help='Dimensionality of latent space')
     # Training
     parser.add_argument("--epochs", default=5000, type=int, help='Number of training epochs')
@@ -33,6 +33,7 @@ if __name__ == '__main__':
     parser.add_argument("--crop_dim", default=None, type=tuple,
                         help='Dimensions for cropping images. Ignore if images are already cropped')
     args = parser.parse_args()
+    args.input_shape = tuple(map(int, args.input_shape))
 
     # Create folders & Save training config
     os.makedirs(args.output_path, exist_ok=True)
@@ -59,8 +60,11 @@ if __name__ == '__main__':
     # Create Model
     if args.model_name == 'tidev2':
         vae = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=args.dim_latent),
-                  tidev2.ConvNeXtDecoderTiny(latent_dim=args.dim_latent)
+                  tidev2.ConvNeXtDecoderTiny(latent_dim=args.dim_latent,
+                                             image_dims=args.input_shape[:2],
+                                             out_channels=args.input_shape[-1])
                   )
+    vae.build((None, *args.input_shape))
     vae.compile(optimizer=tf.keras.optimizers.Adam(args.learning_rate))
 
     # Training
@@ -85,3 +89,4 @@ if __name__ == '__main__':
               shuffle=True,
               initial_epoch=0)
 
+    print('Training finished')
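The --input_shape change swaps type=tuple for nargs=3: argparse now collects three command-line tokens into a list of strings, and the post-parse tuple(map(int, ...)) converts them to an integer shape (with type=tuple, a single token like "320,320,3" would have been split into individual characters). A standalone sketch of the same pattern:

# Sketch: parsing an image shape the way train.py now does.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--input_shape", default=[320, 320, 3], nargs=3,
                    help="Image shape for training")
args = parser.parse_args(["--input_shape", "256", "256", "1"])
args.input_shape = tuple(map(int, args.input_shape))
print(args.input_shape)  # (256, 256, 1)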
utils/dataloader.py CHANGED
@@ -1,6 +1,6 @@
 import os
-import random
 import numpy as np
+
 from PIL import Image
 from re import split, compile
 from tensorflow.keras.utils import Sequence
@@ -19,7 +19,7 @@ def list_filenames(data_path, img_extension='png', filename_prefix=None):
 
 
 class Dataset(Sequence):
-    def __init__(self, file_list, batch_size=32, crop_dim=None, resize_dim=None, shuffle=True):
+    def __init__(self, file_list, batch_size=32, crop_dim=None, resize_dim=None, shuffle=True, mode='RGB'):
         self.files_list = file_list
         self.batch_size = batch_size
 
@@ -28,6 +28,8 @@ class Dataset(Sequence):
         self.shuffle = shuffle
         self.on_epoch_end()
 
+        self.mode=mode
+
     def __len__(self):
         return int(np.ceil(len(self.files_list) / self.batch_size))
 
@@ -53,7 +55,10 @@ class Dataset(Sequence):
         return image.crop((left, top, right, bottom))
 
     def load_images(self, filepath):
-        image = Image.open(filepath).convert('RGB')
+        if self.mode=='RGB':
+            image = Image.open(filepath).convert('RGB')
+        else:
+            image = Image.open(filepath)
         if self.crop_dim:
             image = self.center_crop(image, crop_dim=self.crop_dim)
         if self.resize_dim:
@@ -61,4 +66,6 @@ class Dataset(Sequence):
 
         image = np.array(image).astype(np.float32)
         image = image / 255.0
-        return image
+        if image.ndim == 2:
+            image = np.expand_dims(image, -1)
+        return image
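With mode set to anything other than 'RGB', the image keeps its native mode (e.g. 8-bit grayscale 'L'), np.array then yields a 2-D array, and the new expand_dims restores the trailing channel axis the model expects. A small standalone check of that path, using an in-memory image as a stand-in for a file on disk:

# Sketch: grayscale inputs come back as (H, W, 1) after the new expand_dims.
import numpy as np
from PIL import Image

image = Image.new("L", (64, 64))          # stand-in for a grayscale file
arr = np.array(image).astype(np.float32) / 255.0
print(arr.shape)                          # (64, 64)
if arr.ndim == 2:
    arr = np.expand_dims(arr, -1)
print(arr.shape)                          # (64, 64, 1)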
utils/plots.py CHANGED
@@ -1,11 +1,11 @@
-import imageio
 import numpy as np
+from PIL import Image
 
 
 def visualize_from_latent_space(latent_dim, input_shape, vae, output_path, epoch="final", num_items=10,):
 
     image_size, _, img_channels = input_shape
-    figure = np.zeros((image_size * num_items, image_size * num_items, 3))
+    figure = np.zeros((image_size * num_items, image_size * num_items, img_channels))
 
     scale = 1.0
     grid_x = np.linspace(-scale, scale, num_items)
@@ -18,8 +18,10 @@ def visualize_from_latent_space(latent_dim, input_shape, vae, output_path, epoch
             x_decoded = vae.decoder.predict(random_z)
             image = x_decoded[0].reshape(input_shape)
             figure[i * image_size: (i + 1) * image_size, j * image_size: (j + 1) * image_size, ] = image
-    print(f'Saving collage in {output_path}/decoding-noise-ep{epoch}.jpg')
-    imageio.imsave(f'{output_path}/decoding-noise-ep{epoch}.jpg', (figure * 255).astype('uint8'))
-
-
+    print(f'Saving collage in {output_path}/decoding-noise-ep{epoch}.png')
+    figure = (figure * 255).astype('uint8')
+    if img_channels == 1:
+        figure = np.squeeze(figure, axis=-1)
+    figure = Image.fromarray(figure)
+    figure.save(f"{output_path}/decoding-noise-ep{epoch}.jpg")
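Image.fromarray infers mode 'L' from a 2-D uint8 array, which is why the single-channel case squeezes the trailing axis before saving. A minimal sketch of that save path, with random data standing in for the decoded collage:

# Sketch: saving a single-channel collage via PIL, as plots.py now does.
import numpy as np
from PIL import Image

figure = np.random.rand(320, 320, 1)      # stand-in for the decoded grid
figure = (figure * 255).astype("uint8")
if figure.shape[-1] == 1:
    figure = np.squeeze(figure, axis=-1)  # 2-D uint8 array -> PIL mode 'L'
Image.fromarray(figure).save("decoding-noise-ep-final.png")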