import os
import shutil
import tempfile

import streamlit as st
from PIL import Image

# ------------------------------
# Streamlit Config (must be first)
# ------------------------------
st.set_page_config(page_title="Image Categorization Demo", layout="wide")


# ------------------------------
# Load YOLO classification model
# ------------------------------
@st.cache_resource
def load_model():
    """Load the YOLO classification model.

    Wrapped in ``st.cache_resource`` so the model is built once instead of on
    every script rerun.
    """
    # Imported lazily so the heavy dependency is only pulled in on first load.
    from ultralytics import YOLO

    model = YOLO("yolov8m-cls.pt")  # replace with your trained model
    return model


model = load_model()

# ------------------------------
# Helper: manage temp folder
# ------------------------------
TEMP_FOLDER = "temfolder"
IMAGES_PER_ROW = 4


def prepare_temp_folder():
    """(Re)create an empty TEMP_FOLDER, deleting any previous contents.

    Kept for compatibility; the submit flow below now uses a private
    per-run temporary directory instead of this shared folder.
    """
    if os.path.exists(TEMP_FOLDER):
        shutil.rmtree(TEMP_FOLDER)
    os.makedirs(TEMP_FOLDER)


def cleanup_temp_folder():
    """Delete TEMP_FOLDER and everything in it, if it exists."""
    if os.path.exists(TEMP_FOLDER):
        shutil.rmtree(TEMP_FOLDER)


def _save_upload(file, idx, folder):
    """Write one uploaded file into *folder* and return the path it was saved to.

    ``os.path.basename`` strips any directory components from the
    client-supplied name (e.g. ``"../x.jpg"``) so the write cannot escape
    *folder*. The index prefix keeps two uploads that share a filename from
    overwriting each other.
    """
    safe_name = f"{idx:04d}_{os.path.basename(file.name)}"
    img_path = os.path.join(folder, safe_name)
    with open(img_path, "wb") as f:
        # getvalue() returns the full content regardless of the stream position.
        f.write(file.getvalue())
    return img_path


def _classify(img_path):
    """Return the top-1 class name the model predicts for the image at *img_path*."""
    results = model(img_path)
    return results[0].names[results[0].probs.top1]


def _show_gallery(results_by_class):
    """Render one section per predicted class, IMAGES_PER_ROW images per row.

    *results_by_class* maps class name -> list of image paths. The paths must
    still exist while this runs.
    """
    for cls, img_list in results_by_class.items():
        st.subheader(f"📂 Category: **{cls}** ({len(img_list)})")
        cols = st.columns(IMAGES_PER_ROW)
        for i, img_path in enumerate(img_list):
            with cols[i % IMAGES_PER_ROW]:
                # Close the file handle right after rendering so the temp
                # directory can be removed (open handles block deletion on Windows).
                with Image.open(img_path) as img:
                    st.image(img, use_column_width=True)


# ------------------------------
# Streamlit UI
# ------------------------------
st.title("Image Categorization Demo")

with st.form("upload_form", clear_on_submit=True):
    uploaded_files = st.file_uploader(
        "Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    col1, col2 = st.columns([1, 1])
    submit = col1.form_submit_button("🚀 Submit for Classification")
    refresh = col2.form_submit_button("🔄 Refresh")

if refresh:
    # Removes any leftover shared folder from older versions of this app.
    cleanup_temp_folder()
    st.success("🧹 Uploads cleared!")

if submit:
    if not uploaded_files:
        st.warning("⚠️ Please upload at least one image before submitting.")
    else:
        total_files = len(uploaded_files)
        st.write(f"🔍 Classifying **{total_files}** images...")

        results_by_class = {}
        progress = st.progress(0)  # progress bar
        status_text = st.empty()  # placeholder for progress text

        # A fresh private directory per run. Concurrent sessions (one
        # process serves all users) no longer delete each other's files, and
        # the directory is removed even if classification or display raises.
        with tempfile.TemporaryDirectory(prefix="temfolder_") as run_dir:
            for idx, file in enumerate(uploaded_files, start=1):
                img_path = _save_upload(file, idx, run_dir)
                pred_class = _classify(img_path)

                # Group images by predicted class
                results_by_class.setdefault(pred_class, []).append(img_path)

                # Update progress bar + text
                percent = int((idx / total_files) * 100)
                progress.progress(idx / total_files)
                status_text.text(f"Processing {idx}/{total_files} images ({percent}%)")

            st.success("✅ Classification complete!")

            # ------------------------------
            # Show gallery grouped by class (must run before run_dir is deleted)
            # ------------------------------
            _show_gallery(results_by_class)