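# Page header captured together with the file (not part of the program):
#   uploader avatar: ytrsoymr
#   commit message: "Create app.py" (b6d5ded, verified)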
import streamlit as st
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# --- Page chrome: browser-tab title, wide layout, headline and intro blurb ---
st.set_page_config(layout="wide", page_title="Image Augmentation with Keras")

st.title("🧪 Keras Image Augmentation (Legacy - ImageDataGenerator)")

# Intro paragraph shown directly under the headline.
intro_text = """
This interactive app allows you to apply classic image augmentation techniques using Keras' legacy `ImageDataGenerator`.
Upload your image and tweak augmentation parameters to visualize the results live.
"""
st.write(intro_text)
# --- Sidebar controls ---
# Every widget created inside this context is placed in the sidebar, in the
# order written here. The names assigned below are read later when the
# ImageDataGenerator is configured.
with st.sidebar:
    st.header("⚙️ Augmentation Parameters")

    # Geometric transforms.
    rotation_range = st.slider(
        "Rotation Range (degrees)", min_value=0, max_value=90, value=30
    )
    # The shift sliders show whole percentages; the stored value is a fraction.
    width_shift = (
        st.slider("Width Shift Range (%)", min_value=0, max_value=50, value=10) / 100
    )
    height_shift = (
        st.slider("Height Shift Range (%)", min_value=0, max_value=50, value=10) / 100
    )
    # NOTE(review): Keras documents shear_range as an angle in degrees, so a
    # 0-1 slider yields a very subtle shear -- confirm this range is intended.
    shear_range = st.slider("Shear Range", min_value=0.0, max_value=1.0, value=0.2)
    zoom_range = st.slider("Zoom Range", min_value=0.0, max_value=1.0, value=0.2)

    # Mirroring toggles.
    horizontal_flip = st.checkbox("Horizontal Flip", value=True)
    vertical_flip = st.checkbox("Vertical Flip", value=False)

    # Brightness bounds are chosen independently; nothing here forces low <= high.
    brightness_low = st.slider(
        "Brightness Range - Low", min_value=0.0, max_value=2.0, value=0.8
    )
    brightness_high = st.slider(
        "Brightness Range - High", min_value=0.0, max_value=2.0, value=1.2
    )

    # Strategy used to fill pixels left empty by a transform.
    fill_mode = st.selectbox("Fill Mode", ["nearest", "constant", "reflect", "wrap"])
# Explanatory text shown in the expander below the preview grid.
AUGMENTATION_HELP = """
- **Rotation**: Rotates image by a random angle up to ± the specified value.
- **Width/Height Shift**: Moves image sideways or up/down by a percentage of its size.
- **Shear**: Applies a slant or distortion to simulate a shifted viewpoint.
- **Zoom**: Randomly zooms in or out of the image.
- **Horizontal/Vertical Flip**: Flips the image along the X or Y axis.
- **Brightness Range**: Changes the brightness of the image randomly within the given range.
- **Fill Mode**: Fills in empty areas created by transformations using methods like nearest neighbor or reflection.
"""

# Upload image
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Load and preprocess image. Decoding happens in convert(), so both calls
    # are guarded: a corrupt or mislabeled upload is reported to the user
    # instead of crashing the script with a traceback.
    try:
        img = Image.open(uploaded_file).convert("RGB")
    except OSError as exc:
        st.error(f"Could not read the uploaded image: {exc}")
        st.stop()

    img = img.resize((224, 224))
    # Add a leading batch dimension: the generator is fed a batch of one image.
    x = img_to_array(img)[np.newaxis, ...]

    # Define datagen from the sidebar parameters.
    datagen = ImageDataGenerator(
        rotation_range=rotation_range,
        width_shift_range=width_shift,
        height_shift_range=height_shift,
        shear_range=shear_range,
        zoom_range=zoom_range,
        horizontal_flip=horizontal_flip,
        vertical_flip=vertical_flip,
        brightness_range=[brightness_low, brightness_high],
        fill_mode=fill_mode,
    )

    st.subheader("Augmented Variants (x9)")
    fig, axs = plt.subplots(3, 3, figsize=(12, 9))
    axs = axs.flatten()
    try:
        # The flow iterator keeps producing batches (the previous loop needed
        # an explicit break); zip() stops as soon as every subplot is filled.
        batches = datagen.flow(x, batch_size=1)
        for number, (ax, batch) in enumerate(zip(axs, batches), start=1):
            ax.imshow(batch[0].astype("uint8"))
            ax.axis("off")
            ax.set_title(f"Augmented {number}", fontsize=10)
        st.pyplot(fig)
    finally:
        # Streamlit re-executes the whole script on every widget change; close
        # the figure so pyplot does not keep one open figure per rerun.
        plt.close(fig)

    with st.expander("📘 What Each Augmentation Means"):
        st.markdown(AUGMENTATION_HELP)
else:
    st.info("📷 Upload an image to see augmentations and experiment with parameters.")