# Image_Categorization / streamlit_app.py
# Last commit: 6d1d2b1 "Upload 2 files" by mahmudunnabi (verified)
import streamlit as st
from PIL import Image
import os
import shutil
# ------------------------------
# Page setup: this st.* call must run before any other Streamlit command
# ------------------------------
st.set_page_config(layout="wide", page_title="Image Categorization Demo")
# ------------------------------
# Load YOLO classification model
# ------------------------------
@st.cache_resource
def load_model(weights: str = "yolov8m-cls.pt"):
    """Load and cache an Ultralytics YOLO classification model.

    Args:
        weights: Path or name of the YOLO weights file. Defaults to
            ``"yolov8m-cls.pt"``; pass the path to your own trained
            classifier to use it instead.

    Returns:
        The loaded ``ultralytics.YOLO`` model. ``st.cache_resource`` caches
        the result per distinct ``weights`` value, so script reruns reuse the
        same instance instead of reloading it.
    """
    # Imported inside the function so the import only executes when the
    # cached loader actually runs.
    from ultralytics import YOLO

    return YOLO(weights)


model = load_model()
# ------------------------------
# Helper: manage temp folder
# ------------------------------
# Uploaded images are written here (relative to the working directory) so the
# model can read them by file path.
TEMP_FOLDER = "temfolder"


def prepare_temp_folder():
    """Recreate ``TEMP_FOLDER`` as a fresh directory.

    Any previous folder and its contents are deleted first.
    """
    cleanup_temp_folder()
    # exist_ok tolerates the folder having been recreated between the delete
    # above and this call (e.g. by an overlapping rerun) instead of raising.
    os.makedirs(TEMP_FOLDER, exist_ok=True)


def cleanup_temp_folder():
    """Delete ``TEMP_FOLDER`` and everything in it; do nothing if it is absent."""
    # EAFP rather than exists() + rmtree(): avoids the check-then-act race
    # where the folder vanishes between the check and the delete.
    try:
        shutil.rmtree(TEMP_FOLDER)
    except FileNotFoundError:
        pass
# ------------------------------
# Streamlit UI
# ------------------------------
st.title("Image Categorization Demo")

# The form batches the uploader and both buttons into a single rerun;
# clear_on_submit resets the uploader after either button is pressed.
with st.form("upload_form", clear_on_submit=True):
    uploaded_files = st.file_uploader(
        "Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    col1, col2 = st.columns([1, 1])
    submit = col1.form_submit_button("🚀 Submit for Classification")
    refresh = col2.form_submit_button("🔄 Refresh")

if refresh:
    cleanup_temp_folder()
    st.success("🧹 Uploads cleared!")

if submit:
    if not uploaded_files:
        st.warning("⚠️ Please upload at least one image before submitting.")
    else:
        total_files = len(uploaded_files)
        st.write(f"🔍 Classifying **{total_files}** images...")

        # NOTE(review): TEMP_FOLDER is one fixed directory shared by every
        # session served by this process, so concurrent submissions can delete
        # each other's files. A per-run directory (e.g. tempfile) would avoid
        # this — confirm whether concurrent users are expected.
        prepare_temp_folder()
        try:
            results_by_class = {}  # predicted class name -> saved image paths
            progress = st.progress(0)  # progress bar
            status_text = st.empty()  # placeholder rewritten every iteration

            for idx, file in enumerate(uploaded_files, start=1):
                # Prefix with the upload index so two uploads sharing a name
                # don't overwrite each other on disk, and keep only the base
                # name so a client-supplied name can't escape TEMP_FOLDER.
                safe_name = f"{idx}_{os.path.basename(file.name)}"
                img_path = os.path.join(TEMP_FOLDER, safe_name)
                with open(img_path, "wb") as f:
                    # getvalue() returns the whole upload regardless of the
                    # buffer's current read position (unlike read()).
                    f.write(file.getvalue())

                # Run YOLO classification; probs.top1 is the best class index,
                # mapped to its label through names.
                results = model(img_path)
                pred_class = results[0].names[results[0].probs.top1]
                results_by_class.setdefault(pred_class, []).append(img_path)

                # Integer arithmetic avoids float truncation
                # (int(29 / 100 * 100) == 28).
                percent = idx * 100 // total_files
                progress.progress(idx / total_files)
                status_text.text(f"Processing {idx}/{total_files} images ({percent}%)")

            st.success("✅ Classification complete!")

            # ------------------------------
            # Show gallery grouped by class
            # ------------------------------
            for cls, img_list in results_by_class.items():
                st.subheader(f"📂 Category: **{cls}** ({len(img_list)})")
                cols = st.columns(4)  # 4 per row; further images stack below
                for i, img_path in enumerate(img_list):
                    with cols[i % 4]:
                        # Context manager releases Pillow's file handle
                        # deterministically before the folder is deleted.
                        # NOTE(review): newer Streamlit versions deprecate
                        # use_column_width — check the installed version.
                        with Image.open(img_path) as img:
                            st.image(img, use_column_width=True)
        finally:
            # Remove the saved uploads even if classification or display fails.
            cleanup_temp_folder()