# Streamlit demo: interactively preview Keras' legacy ImageDataGenerator.
# The user uploads an image, tweaks augmentation parameters in the sidebar,
# and a grid of randomly augmented variants is re-rendered on every rerun.
# NOTE: documentation here is kept in `#` comments on purpose -- Streamlit
# "magic" would render bare string literals into the page.
import streamlit as st
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Every uploaded image is resized to this (width, height) before augmenting.
TARGET_SIZE = (224, 224)
# Layout of the preview grid; GRID_ROWS * GRID_COLS variants are drawn.
GRID_ROWS, GRID_COLS = 3, 3

st.set_page_config(page_title="Image Augmentation with Keras", layout="wide")
st.title("🧪 Keras Image Augmentation (Legacy - ImageDataGenerator)")
st.write("""
This interactive app allows you to apply classic image augmentation techniques using Keras' legacy `ImageDataGenerator`.
Upload your image and tweak augmentation parameters to visualize the results live.
""")

# --- Sidebar parameters ------------------------------------------------------
st.sidebar.header("⚙️ Augmentation Parameters")
rotation_range = st.sidebar.slider("Rotation Range (degrees)", 0, 90, 30)
# Shift sliders are in percent; Keras expects a fraction of the image size.
width_shift = st.sidebar.slider("Width Shift Range (%)", 0, 50, 10) / 100
height_shift = st.sidebar.slider("Height Shift Range (%)", 0, 50, 10) / 100
# Keras treats `shear_range` as an angle in DEGREES, so the former 0-1 slider
# produced at most ~1 degree of shear (visually a no-op). Expose degrees.
shear_range = st.sidebar.slider("Shear Range (degrees)", 0.0, 45.0, 10.0)
zoom_range = st.sidebar.slider("Zoom Range", 0.0, 1.0, 0.2)
horizontal_flip = st.sidebar.checkbox("Horizontal Flip", value=True)
vertical_flip = st.sidebar.checkbox("Vertical Flip", value=False)
brightness_low = st.sidebar.slider("Brightness Range - Low", 0.0, 2.0, 0.8)
brightness_high = st.sidebar.slider("Brightness Range - High", 0.0, 2.0, 1.2)
if brightness_low > brightness_high:
    # Keep the interval ordered so what the user sees matches what is sampled.
    st.sidebar.warning("Brightness Low exceeds High; the two values were swapped.")
    brightness_low, brightness_high = brightness_high, brightness_low
fill_mode = st.sidebar.selectbox("Fill Mode", ["nearest", "constant", "reflect", "wrap"])

# --- Upload ------------------------------------------------------------------
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is None:
    st.info("📷 Upload an image to see augmentations and experiment with parameters.")
    st.stop()

# Load and preprocess; a corrupt or non-image upload must not crash the app.
# (PIL's UnidentifiedImageError is a subclass of OSError.)
try:
    img = Image.open(uploaded_file).convert("RGB")
except OSError:
    st.error("The uploaded file could not be decoded as an image.")
    st.stop()
img = img.resize(TARGET_SIZE)
# ImageDataGenerator.flow expects a batch: add a leading batch axis of size 1.
x = np.expand_dims(img_to_array(img), axis=0)

datagen = ImageDataGenerator(
    rotation_range=rotation_range,
    width_shift_range=width_shift,
    height_shift_range=height_shift,
    shear_range=shear_range,
    zoom_range=zoom_range,
    horizontal_flip=horizontal_flip,
    vertical_flip=vertical_flip,
    brightness_range=[brightness_low, brightness_high],
    fill_mode=fill_mode,
)

# --- Augmented grid ----------------------------------------------------------
n_variants = GRID_ROWS * GRID_COLS
st.subheader(f"Augmented Variants (x{n_variants})")
# squeeze=False keeps `axs` a 2-D array, so `.flat` works for any grid size.
fig, axs = plt.subplots(
    GRID_ROWS, GRID_COLS, figsize=(4 * GRID_COLS, 3 * GRID_ROWS), squeeze=False
)
# flow() over one image never terminates; zip() stops once every axis is used.
augmented = datagen.flow(x, batch_size=1)
for idx, (ax, batch) in enumerate(zip(axs.flat, augmented), start=1):
    # Clip before casting so any out-of-range value cannot wrap around in uint8.
    ax.imshow(np.clip(batch[0], 0, 255).astype("uint8"))
    ax.axis("off")
    ax.set_title(f"Augmented {idx}", fontsize=10)
st.pyplot(fig)
# st.pyplot does not close the figure; without this, figures accumulate in
# pyplot's global registry across Streamlit reruns.
plt.close(fig)

with st.expander("📘 What Each Augmentation Means"):
    st.markdown("""
- **Rotation**: Rotates image by a random angle up to ± the specified value.
- **Width/Height Shift**: Moves image sideways or up/down by a percentage of its size.
- **Shear**: Applies a slant (up to ± the specified angle in degrees) to simulate a shifted viewpoint.
- **Zoom**: Randomly zooms in or out of the image.
- **Horizontal/Vertical Flip**: Flips the image along the X or Y axis.
- **Brightness Range**: Changes the brightness of the image randomly within the given range.
- **Fill Mode**: Fills in empty areas created by transformations using methods like nearest neighbor or reflection.
""")