DasariHarshitha's picture
Upload 3001 files
63ca793 verified
raw
history blame
3.56 kB
import streamlit as st
import os
from PIL import Image
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
# Shared settings used by both training and prediction.
IMAGE_SIZE = (128, 128)  # target size passed to PIL's resize for every image
DATA_DIR = "Animals"  # Folder with subfolders; each subfolder name is a class label

# Page config (must be the first Streamlit command in the script)
st.set_page_config(
    page_title="Animal Classifier DL",
    layout="centered",
)
st.title("🧠 Animal Classifier (Deep Learning)")
st.markdown(
    "Train from folders like `Animals/cat`, `Animals/dog`, etc. and predict uploaded images."
)
# Load dataset and preprocess
@st.cache_data
def load_dataset():
    """Load every image found under DATA_DIR/<label>/ into memory.

    Each immediate subfolder of DATA_DIR is treated as one class label. Files
    ending in .png/.jpg/.jpeg (case-insensitive) are converted to RGB, resized
    to IMAGE_SIZE and scaled from [0, 255] to [0, 1]. Files that fail to load
    are skipped with a printed warning (best-effort loading).

    Returns:
        (X, y): X is a NumPy array of the preprocessed image arrays and y the
        matching NumPy array of label (folder) names. Both are empty when
        DATA_DIR does not exist or contains no matching images.
    """
    X = []
    y = []
    # A missing data folder yields empty arrays so the caller's "no images"
    # check can show a friendly error instead of crashing on FileNotFoundError.
    if not os.path.isdir(DATA_DIR):
        return np.array(X), np.array(y)
    # os.listdir order is arbitrary; sorting makes the sample order (and so the
    # seeded train/test split downstream) reproducible across machines.
    for label in sorted(os.listdir(DATA_DIR)):
        folder = os.path.join(DATA_DIR, label)
        if not os.path.isdir(folder):
            continue
        for file in sorted(os.listdir(folder)):
            if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            path = os.path.join(folder, file)
            try:
                # The context manager closes the file handle once the pixels
                # have been decoded and copied into `arr`.
                with Image.open(path) as img:
                    arr = img_to_array(img.convert("RGB").resize(IMAGE_SIZE)) / 255.0
            except Exception as e:
                print(f"⚠️ Skipped {file}: {e}")
                continue
            X.append(arr)
            y.append(label)
    return np.array(X), np.array(y)
X, y = load_dataset()
# Stop if no data
if len(X) == 0:
    st.error(f"❌ No images found in '{DATA_DIR}' folder.")
    st.stop()
# Encode labels: folder names -> integer ids -> one-hot vectors.
le = LabelEncoder()
y_encoded = le.fit_transform(y)
# With a single class the softmax head has one unit and every upload would be
# "predicted" as that class, so refuse to train a meaningless model.
if len(le.classes_) < 2:
    st.error(f"❌ Need at least two class subfolders in '{DATA_DIR}' to train a classifier.")
    st.stop()
y_cat = to_categorical(y_encoded)
# 🧠 Train/test split (fixed seed so the split is repeatable)
X_train, X_test, y_train, y_test = train_test_split(X, y_cat, test_size=0.2, random_state=42)
# 🧠 Build model
@st.cache_resource
def build_model(num_classes=None):
    """Return a compiled, untrained CNN classifier for IMAGE_SIZE RGB images.

    Args:
        num_classes: Number of output units (one per class). Defaults to the
            number of distinct values in the module-level ``y_encoded``.
            Passing it explicitly makes it part of the ``st.cache_resource``
            key, so a cached model is not reused for a different class count.

    Returns:
        A Keras ``Sequential`` model compiled with Adam, categorical
        cross-entropy loss and an accuracy metric.
    """
    if num_classes is None:
        num_classes = len(np.unique(y_encoded))
    model = Sequential([
        # Two conv/pool stages over the (*IMAGE_SIZE, 3) input.
        Conv2D(32, (3, 3), activation='relu', input_shape=(*IMAGE_SIZE, 3)),
        MaxPooling2D(2, 2),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.3),
        Dense(num_classes, activation='softmax'),
    ])
    # categorical_crossentropy matches the one-hot targets from to_categorical.
    model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
    return model
# 👨‍🏫 Train the model
# Streamlit re-executes the whole script on every widget interaction (e.g. each
# upload). build_model() is cached with st.cache_resource and returns the same
# model object every time, so calling model.fit() directly at top level would
# train that shared model for another 5 epochs on every rerun. Caching the
# build+fit step means training happens once and later reruns reuse the result.
@st.cache_resource(show_spinner=False)
def _train_model():
    """Build the classifier with build_model() and fit it on X_train / y_train.

    Reads the module-level training split; returns the fitted Keras model.
    """
    trained = build_model()
    trained.fit(X_train, y_train, epochs=5, batch_size=32, validation_split=0.1, verbose=0)
    return trained


with st.spinner("Training model..."):
    model = _train_model()
st.success("✅ Model trained!")
# 📤 Upload images for prediction
uploaded_files = st.file_uploader("Upload animal images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
if uploaded_files:
    st.markdown("### Predictions")
    COLS_PER_ROW = 3  # images shown side by side in one row
    # Lay uploads out in rows of COLS_PER_ROW columns so every uploaded file is
    # classified, rather than silently ignoring everything after the first row.
    for start in range(0, len(uploaded_files), COLS_PER_ROW):
        row_files = uploaded_files[start:start + COLS_PER_ROW]
        cols = st.columns(len(row_files))
        for col, file in zip(cols, row_files):
            with col:
                try:
                    img = Image.open(file).convert("RGB")
                except OSError as err:
                    # PIL raises UnidentifiedImageError (an OSError subclass) for
                    # undecodable data; report it for this file and keep going.
                    st.error(f"❌ Could not read {file.name}: {err}")
                    continue
                st.image(img, caption="Uploaded", use_container_width=True)
                # Same preprocessing as the training images: resize, scale to
                # [0, 1], then add a leading batch axis of size 1.
                arr = img_to_array(img.resize(IMAGE_SIZE)) / 255.0
                arr = np.expand_dims(arr, axis=0)
                pred = model.predict(arr, verbose=0)[0]
                top_idx = np.argmax(pred)
                label = le.inverse_transform([top_idx])[0]
                st.success(f"🔍 Prediction: **{label.capitalize()}**")