Spaces:
Build error
Build error
| import streamlit as st | |
| from PIL import Image | |
| import os | |
| import shutil | |
# ------------------------------
# Page configuration: Streamlit requires this call before any other st.* command.
# ------------------------------
st.set_page_config(layout="wide", page_title="Image Categorization Demo")
# ------------------------------
# Load YOLO classification model
# ------------------------------
def load_model():
    """Return the YOLO classification model for the current browser session.

    Streamlit re-executes this whole script on every widget interaction, so
    constructing the model unconditionally would reload the weights on each
    rerun. The instance is stored in ``st.session_state`` and built only on
    the session's first run.

    NOTE(review): ``st.cache_resource`` would share one instance across all
    sessions (each served on its own script thread); Ultralytics advises
    against sharing one YOLO object between threads, hence per-session state.
    """
    if "yolo_model" not in st.session_state:
        from ultralytics import YOLO

        # replace with your trained model
        st.session_state["yolo_model"] = YOLO("yolov8m-cls.pt")
    return st.session_state["yolo_model"]


model = load_model()
# ------------------------------
# Helper: manage temp folder
# ------------------------------
# Scratch directory that uploaded images are written into before classification.
TEMP_FOLDER = "temfolder"


def cleanup_temp_folder():
    """Delete TEMP_FOLDER and everything in it; do nothing if it is absent."""
    if not os.path.exists(TEMP_FOLDER):
        return
    shutil.rmtree(TEMP_FOLDER)


def prepare_temp_folder():
    """Make TEMP_FOLDER an empty directory by removing and recreating it."""
    cleanup_temp_folder()
    os.makedirs(TEMP_FOLDER)
# ------------------------------
# Streamlit UI
# ------------------------------
# NOTE(review): several labels below ("π", "β οΈ", "π§Ή", "β") appear to be
# mis-decoded emoji. Not all original characters are recoverable from this
# copy, so the strings are left unchanged -- restore them from upstream.
st.title("Image Categorization Demo")

with st.form("upload_form", clear_on_submit=True):
    uploaded_files = st.file_uploader(
        "Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    col1, col2 = st.columns([1, 1])
    submit = col1.form_submit_button("π Submit for Classification")
    refresh = col2.form_submit_button("π Refresh")

if refresh:
    cleanup_temp_folder()
    st.success("π§Ή Uploads cleared!")

if submit:
    if not uploaded_files:
        st.warning("β οΈ Please upload at least one image before submitting.")
    else:
        total_files = len(uploaded_files)
        st.write(f"π Classifying **{total_files}** images...")

        # NOTE(review): TEMP_FOLDER is one fixed directory shared by every
        # session on this server; concurrent submissions can delete each
        # other's files. A per-run directory would isolate them.
        prepare_temp_folder()
        try:
            results_by_class = {}
            progress = st.progress(0)  # progress bar
            status_text = st.empty()  # placeholder for progress text

            for idx, file in enumerate(uploaded_files, start=1):
                # The browser-supplied name is untrusted and not unique: keep
                # only its final path component and prefix the upload index so
                # two uploads with the same name do not overwrite each other.
                safe_name = f"{idx:04d}_{os.path.basename(file.name)}"
                img_path = os.path.join(TEMP_FOLDER, safe_name)
                with open(img_path, "wb") as f:
                    # getvalue() returns the whole buffer regardless of the
                    # current read position, unlike read().
                    f.write(file.getvalue())

                # Run YOLO classification; keep the top-1 class name.
                results = model(img_path)
                pred_class = results[0].names[results[0].probs.top1]

                # Group images by predicted class
                results_by_class.setdefault(pred_class, []).append(img_path)

                # Update progress bar + text
                percent = int((idx / total_files) * 100)
                progress.progress(idx / total_files)
                status_text.text(f"Processing {idx}/{total_files} images ({percent}%)")

            st.success("β Classification complete!")

            # ------------------------------
            # Show gallery grouped by class
            # ------------------------------
            for cls, img_list in results_by_class.items():
                st.subheader(f"π Category: **{cls}** ({len(img_list)})")
                cols = st.columns(4)  # show 4 images per row
                for i, img_path in enumerate(img_list):
                    with cols[i % 4]:
                        # Close the file handle before the folder is removed.
                        with Image.open(img_path) as img:
                            st.image(img, use_column_width=True)
        finally:
            # Remove the uploads even if classification or display fails.
            cleanup_temp_folder()