Minor corrections
Browse files- .gitignore +4 -0
- generate_images.py +11 -4
- utils/inference_utils.py +10 -7
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets_medmnist/
|
| 2 |
+
train.sh
|
| 3 |
+
train2.py
|
| 4 |
+
results_medmnist/
|
generate_images.py
CHANGED
|
@@ -11,20 +11,27 @@ if __name__ == "__main__":
|
|
| 11 |
parser.add_argument("--latent_dim", default=8, type=int, help='Dimensionality of latent space')
|
| 12 |
parser.add_argument("--save_dir", default="./fake_images", type=str, help='Path to save synthetic images')
|
| 13 |
parser.add_argument("--num_of_images", default=10, type=int, help='Number of images to generate')
|
|
|
|
|
|
|
| 14 |
args = parser.parse_args()
|
|
|
|
| 15 |
|
| 16 |
os.makedirs(args.save_dir, exist_ok=True)
|
| 17 |
|
| 18 |
if not os.path.exists(args.weights_path):
|
| 19 |
print("Not a valid path")
|
| 20 |
|
| 21 |
-
vae = init_vae_model(args.model_name, args.latent_dim)
|
| 22 |
-
noise_vector = get_noise_seeded((args.num_of_images, args.latent_dim))
|
|
|
|
| 23 |
|
| 24 |
# Load weights
|
| 25 |
vae = load_weights(vae, args.weights_path)
|
| 26 |
vae.trainable = False
|
| 27 |
|
| 28 |
# Generate & Save images
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
parser.add_argument("--latent_dim", default=8, type=int, help='Dimensionality of latent space')
|
| 12 |
parser.add_argument("--save_dir", default="./fake_images", type=str, help='Path to save synthetic images')
|
| 13 |
parser.add_argument("--num_of_images", default=10, type=int, help='Number of images to generate')
|
| 14 |
+
parser.add_argument("--input_shape", default=[320, 320, 3], nargs=3, help='Image shape for training')
|
| 15 |
+
|
| 16 |
args = parser.parse_args()
|
| 17 |
+
args.input_shape = tuple(map(int, args.input_shape))
|
| 18 |
|
| 19 |
os.makedirs(args.save_dir, exist_ok=True)
|
| 20 |
|
| 21 |
if not os.path.exists(args.weights_path):
|
| 22 |
print("Not a valid path")
|
| 23 |
|
| 24 |
+
vae = init_vae_model(args.model_name, args.latent_dim, args.input_shape)
|
| 25 |
+
# noise_vector = get_noise_seeded((args.num_of_images, args.latent_dim))
|
| 26 |
+
|
| 27 |
|
| 28 |
# Load weights
|
| 29 |
vae = load_weights(vae, args.weights_path)
|
| 30 |
vae.trainable = False
|
| 31 |
|
| 32 |
# Generate & Save images
|
| 33 |
+
for i in range(args.num_of_images):
|
| 34 |
+
print(f'Generating image for seed {i}/{args.num_of_images}, ')
|
| 35 |
+
noise_vector = get_noise_seeded((1, args.latent_dim), seed=i)
|
| 36 |
+
fake_images = decode_noise(vae, noise_vector, return_list=True)
|
| 37 |
+
save_images(args.save_dir, fake_images, seed=i)
|
utils/inference_utils.py
CHANGED
|
@@ -8,11 +8,12 @@ from model.vae import VAE
|
|
| 8 |
from model import tidev2
|
| 9 |
|
| 10 |
|
| 11 |
-
def init_vae_model(model_name, latent_dim):
|
| 12 |
if model_name == 'tidev2':
|
| 13 |
vae_model = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=latent_dim),
|
| 14 |
-
tidev2.ConvNeXtDecoderTiny(latent_dim=latent_dim)
|
| 15 |
)
|
|
|
|
| 16 |
return vae_model
|
| 17 |
|
| 18 |
|
|
@@ -27,12 +28,11 @@ def load_weights(vae, weights_path):
|
|
| 27 |
return vae
|
| 28 |
|
| 29 |
|
| 30 |
-
def get_noise_seeded(noise_shape):
|
| 31 |
-
np.random.seed(
|
| 32 |
random_z = np.random.normal(0, 1, noise_shape)
|
| 33 |
return random_z
|
| 34 |
|
| 35 |
-
|
| 36 |
def decode_noise(trained_vae, noise, return_list=False):
|
| 37 |
print("Generating synthetic images ...")
|
| 38 |
pred = trained_vae.decoder.predict(noise, batch_size=1)
|
|
@@ -44,9 +44,12 @@ def decode_noise(trained_vae, noise, return_list=False):
|
|
| 44 |
return pred
|
| 45 |
|
| 46 |
|
| 47 |
-
def save_images(save_folder, images):
|
| 48 |
print(f"Saving synthetic images into {save_folder}")
|
| 49 |
if isinstance(images, list):
|
| 50 |
for i, image in enumerate(images):
|
| 51 |
image = image.astype(np.uint8)
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from model import tidev2
|
| 9 |
|
| 10 |
|
| 11 |
+
def init_vae_model(model_name, latent_dim, input_shape):
|
| 12 |
if model_name == 'tidev2':
|
| 13 |
vae_model = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=latent_dim),
|
| 14 |
+
tidev2.ConvNeXtDecoderTiny(latent_dim=latent_dim, image_dims=input_shape[:2], out_channels=input_shape[-1])
|
| 15 |
)
|
| 16 |
+
vae_model.build((None, *input_shape))
|
| 17 |
return vae_model
|
| 18 |
|
| 19 |
|
|
|
|
| 28 |
return vae
|
| 29 |
|
| 30 |
|
| 31 |
+
def get_noise_seeded(noise_shape, seed=0):
|
| 32 |
+
np.random.seed(seed)
|
| 33 |
random_z = np.random.normal(0, 1, noise_shape)
|
| 34 |
return random_z
|
| 35 |
|
|
|
|
| 36 |
def decode_noise(trained_vae, noise, return_list=False):
|
| 37 |
print("Generating synthetic images ...")
|
| 38 |
pred = trained_vae.decoder.predict(noise, batch_size=1)
|
|
|
|
| 44 |
return pred
|
| 45 |
|
| 46 |
|
| 47 |
+
def save_images(save_folder, images, seed=None):
|
| 48 |
print(f"Saving synthetic images into {save_folder}")
|
| 49 |
if isinstance(images, list):
|
| 50 |
for i, image in enumerate(images):
|
| 51 |
image = image.astype(np.uint8)
|
| 52 |
+
if image.shape[-1] == 1:
|
| 53 |
+
image = np.squeeze(image, axis=-1)
|
| 54 |
+
save_filename = f"image-{i}.jpg" if seed is None else f"image-{seed}.jpg"
|
| 55 |
+
Image.fromarray(image).save(os.path.join(save_folder, save_filename))
|