pgatoula committed on
Commit
664026e
·
1 Parent(s): b79a585

Minor corrections

Browse files
Files changed (3) hide show
  1. .gitignore +4 -0
  2. generate_images.py +11 -4
  3. utils/inference_utils.py +10 -7
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets_medmnist/
2
+ train.sh
3
+ train2.py
4
+ results_medmnist/
generate_images.py CHANGED
@@ -11,20 +11,27 @@ if __name__ == "__main__":
11
  parser.add_argument("--latent_dim", default=8, type=int, help='Dimensionality of latent space')
12
  parser.add_argument("--save_dir", default="./fake_images", type=str, help='Path to save synthetic images')
13
  parser.add_argument("--num_of_images", default=10, type=int, help='Number of images to generate')
 
 
14
  args = parser.parse_args()
 
15
 
16
  os.makedirs(args.save_dir, exist_ok=True)
17
 
18
  if not os.path.exists(args.weights_path):
19
  print("Not a valid path")
20
 
21
- vae = init_vae_model(args.model_name, args.latent_dim)
22
- noise_vector = get_noise_seeded((args.num_of_images, args.latent_dim))
 
23
 
24
  # Load weights
25
  vae = load_weights(vae, args.weights_path)
26
  vae.trainable = False
27
 
28
  # Generate & Save images
29
- fake_images = decode_noise(vae, noise_vector, return_list=True)
30
- save_images(args.save_dir, fake_images)
 
 
 
 
11
  parser.add_argument("--latent_dim", default=8, type=int, help='Dimensionality of latent space')
12
  parser.add_argument("--save_dir", default="./fake_images", type=str, help='Path to save synthetic images')
13
  parser.add_argument("--num_of_images", default=10, type=int, help='Number of images to generate')
14
+ parser.add_argument("--input_shape", default=[320, 320, 3], nargs=3, help='Image shape for training')
15
+
16
  args = parser.parse_args()
17
+ args.input_shape = tuple(map(int, args.input_shape))
18
 
19
  os.makedirs(args.save_dir, exist_ok=True)
20
 
21
  if not os.path.exists(args.weights_path):
22
  print("Not a valid path")
23
 
24
+ vae = init_vae_model(args.model_name, args.latent_dim, args.input_shape)
25
+ # noise_vector = get_noise_seeded((args.num_of_images, args.latent_dim))
26
+
27
 
28
  # Load weights
29
  vae = load_weights(vae, args.weights_path)
30
  vae.trainable = False
31
 
32
  # Generate & Save images
33
+ for i in range(args.num_of_images):
34
+ print(f'Generating image for seed {i}/{args.num_of_images}, ')
35
+ noise_vector = get_noise_seeded((1, args.latent_dim), seed=i)
36
+ fake_images = decode_noise(vae, noise_vector, return_list=True)
37
+ save_images(args.save_dir, fake_images, seed=i)
utils/inference_utils.py CHANGED
@@ -8,11 +8,12 @@ from model.vae import VAE
8
  from model import tidev2
9
 
10
 
11
- def init_vae_model(model_name, latent_dim):
12
  if model_name == 'tidev2':
13
  vae_model = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=latent_dim),
14
- tidev2.ConvNeXtDecoderTiny(latent_dim=latent_dim)
15
  )
 
16
  return vae_model
17
 
18
 
@@ -27,12 +28,11 @@ def load_weights(vae, weights_path):
27
  return vae
28
 
29
 
30
- def get_noise_seeded(noise_shape):
31
- np.random.seed(0)
32
  random_z = np.random.normal(0, 1, noise_shape)
33
  return random_z
34
 
35
-
36
  def decode_noise(trained_vae, noise, return_list=False):
37
  print("Generating synthetic images ...")
38
  pred = trained_vae.decoder.predict(noise, batch_size=1)
@@ -44,9 +44,12 @@ def decode_noise(trained_vae, noise, return_list=False):
44
  return pred
45
 
46
 
47
- def save_images(save_folder, images):
48
  print(f"Saving synthetic images into {save_folder}")
49
  if isinstance(images, list):
50
  for i, image in enumerate(images):
51
  image = image.astype(np.uint8)
52
- Image.fromarray(image).save(os.path.join(save_folder, f"image-{i}.jpg"))
 
 
 
 
8
  from model import tidev2
9
 
10
 
11
def init_vae_model(model_name, latent_dim, input_shape):
    """Construct and build the VAE selected by ``model_name``.

    Args:
        model_name: Architecture identifier. Only ``'tidev2'`` is handled.
        latent_dim: Latent size forwarded to both the encoder and the decoder.
        input_shape: Image shape. Its first two entries are passed to the
            decoder as ``image_dims`` and its last entry as ``out_channels``
            (presumably (height, width, channels) -- TODO confirm).

    Returns:
        The VAE instance after ``build((None, *input_shape))`` was called.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    # Fail early with a clear message; previously an unknown name left
    # ``vae_model`` unassigned and the error surfaced on a later line.
    if model_name != 'tidev2':
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'tidev2'"
        )

    encoder = tidev2.ConvNeXtEncoderTiny(latent_dim=latent_dim)
    decoder = tidev2.ConvNeXtDecoderTiny(
        latent_dim=latent_dim,
        image_dims=input_shape[:2],
        out_channels=input_shape[-1],
    )
    vae_model = VAE(encoder, decoder)
    # Leading None leaves the first (batch) dimension unspecified.
    vae_model.build((None, *input_shape))
    return vae_model
18
 
19
 
 
28
  return vae
29
 
30
 
31
def get_noise_seeded(noise_shape, seed=0):
    """Draw reproducible standard-normal noise.

    Args:
        noise_shape: Shape of the array to sample.
        seed: Seed for the generator; the same seed always yields the same
            array. Defaults to 0.

    Returns:
        A NumPy array of shape ``noise_shape`` sampled from N(0, 1).
    """
    # A private RandomState produces exactly the samples that
    # ``np.random.seed(seed)`` + ``np.random.normal`` would, but without
    # resetting NumPy's global generator for the rest of the process.
    rng = np.random.RandomState(seed)
    return rng.normal(0, 1, noise_shape)
35
 
 
36
  def decode_noise(trained_vae, noise, return_list=False):
37
  print("Generating synthetic images ...")
38
  pred = trained_vae.decoder.predict(noise, batch_size=1)
 
44
  return pred
45
 
46
 
47
def save_images(save_folder, images, seed=None):
    """Write generated images into ``save_folder`` as JPEG files.

    Args:
        save_folder: Directory the files are written into.
        images: List of image arrays. Any other type is ignored.
        seed: Optional seed used in the file name instead of the list index.

    File names:
        * ``image-{i}.jpg`` when ``seed`` is None;
        * ``image-{seed}.jpg`` when ``seed`` is given and there is one image;
        * ``image-{seed}-{i}.jpg`` when ``seed`` is given and the list holds
          several images, so they do not overwrite one another.
    """
    print(f"Saving synthetic images into {save_folder}")
    if not isinstance(images, list):
        return

    seed_is_shared = seed is not None and len(images) > 1
    for i, image in enumerate(images):
        # Clamp first: the uint8 cast does not saturate out-of-range values.
        image = np.clip(image, 0, 255).astype(np.uint8)
        # Drop a trailing single-channel axis before handing the array to PIL.
        if image.shape[-1] == 1:
            image = np.squeeze(image, axis=-1)

        if seed is None:
            save_filename = f"image-{i}.jpg"
        elif seed_is_shared:
            save_filename = f"image-{seed}-{i}.jpg"
        else:
            save_filename = f"image-{seed}.jpg"
        Image.fromarray(image).save(os.path.join(save_folder, save_filename))