Spaces:
Sleeping
Sleeping
| import os | |
| from PIL import Image | |
| from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
# Dataset locations: one folder per class inside the 'val' split.
# NOTE(review): augmentation is usually applied to the training split only;
# writing synthetic images into 'val' alters the data used for evaluation —
# confirm this is intended.
base_dir = 'data/chest_xray'
val_dir = os.path.join(base_dir, 'val')
normal_class_dir, pneumonia_class_dir = [
    os.path.join(val_dir, class_name) for class_name in ('NORMAL', 'PNEUMONIA')
]
def augment_images(class_directory, num_augmented_images, target_size=(150, 150),
                   output_directory=None):
    """Write randomly augmented copies of the images in one class folder.

    Images are loaded with Keras' ``ImageDataGenerator.flow_from_directory``,
    using the parent of *class_directory* as the root and the folder name as
    the only class. Each image is passed through random rotation, width/height
    shifts, shear, zoom and horizontal flips. The results are saved as
    ``augmented_<i>.png`` for ``i`` in ``range(num_augmented_images)``.

    The source images are visited in a fixed order (``shuffle=False``) and the
    iterator wraps around after the last one. If more images are requested
    than exist, each source image is therefore augmented several times.

    Args:
        class_directory: Folder containing the source images of one class.
        num_augmented_images: Number of augmented images to write. Values <= 0
            write nothing.
        target_size: ``(height, width)`` the images are resized to on load.
        output_directory: Folder the augmented images are written to. Defaults
            to *class_directory*, which was the original behaviour.

    Note:
        When writing into *class_directory*, a later run will treat the
        ``augmented_*.png`` files as source images. It will also overwrite
        them by name.
    """
    # Drop any trailing separator so dirname/basename split "<root>/<class>".
    class_directory = os.path.normpath(class_directory)
    if output_directory is None:
        output_directory = class_directory

    # No `rescale`: pixels stay in [0, 255], so no lossy /255 -> *255 round trip.
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    generator = datagen.flow_from_directory(
        directory=os.path.dirname(class_directory),  # parent directory
        target_size=target_size,
        batch_size=1,
        class_mode=None,  # yield image batches only, no labels
        shuffle=False,
        classes=[os.path.basename(class_directory)],  # restrict to this one class folder
    )
    print(f"Found {generator.samples} images in {class_directory}")
    if generator.samples == 0:
        print("No images found in the directory.")
        return

    os.makedirs(output_directory, exist_ok=True)

    count = 0
    while count < num_augmented_images:
        # Keras directory iterators never raise StopIteration: they wrap around
        # to the first image, so a plain next() is sufficient here.
        img_batch = next(generator)
        # Round and clip before casting. A bare astype('uint8') truncates
        # (254.9999 -> 254) and would wrap any out-of-range value.
        img = img_batch[0].round().clip(0, 255).astype('uint8')
        img_pil = Image.fromarray(img)
        img_path = os.path.join(output_directory, f"augmented_{count}.png")
        img_pil.save(img_path)
        count += 1
    print(f"Total augmented images created: {count}")
# Number of augmented images to generate per class. Each class is topped up
# towards a common target, so the amount is (target - current image count).
# NOTE(review): the current counts 3875 (used for NORMAL) and 1171 (used for
# PNEUMONIA) are hard-coded. Confirm they match the number of files actually in
# val/NORMAL and val/PNEUMONIA.
target_images_per_class = 2944
num_augmented_images_normal = target_images_per_class - 3875  # negative: already above target
num_augmented_images_pneumonia = target_images_per_class - 1171
# Clamp at zero so a class already at or above the target gets no new images.
# Generate augmented images for the NORMAL class
augment_images(normal_class_dir, max(num_augmented_images_normal, 0))
# Generate augmented images for the PNEUMONIA class
augment_images(pneumonia_class_dir, max(num_augmented_images_pneumonia, 0))