Spaces:
Build error
Build error
File size: 3,424 Bytes
6d1d2b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import streamlit as st
from PIL import Image
import os
import shutil
# ------------------------------
# Page setup: this has to run before any other Streamlit command.
# ------------------------------
st.set_page_config(layout="wide", page_title="Image Categorization Demo")
# ------------------------------
# Load YOLO classification model
# ------------------------------
@st.cache_resource
def load_model(weights="yolov8m-cls.pt"):
    """Load an Ultralytics YOLO classification model and cache it.

    ``st.cache_resource`` keeps one model object per distinct ``weights``
    value, so the checkpoint is loaded once and reused on every rerun.

    Args:
        weights: Checkpoint name or path handed to ``ultralytics.YOLO``.
            Defaults to ``"yolov8m-cls.pt"``; pass the path of your own
            trained classifier to use it instead.

    Returns:
        The ``YOLO`` model built from ``weights``.
    """
    # Imported inside the function so the heavy import only runs on a cache miss.
    from ultralytics import YOLO

    return YOLO(weights)


model = load_model()
# ------------------------------
# Helper: manage temp folder
# ------------------------------
# Relative scratch directory that holds uploads while they are classified.
# NOTE(review): this single path is shared by every session of the app, so
# concurrent users can clear each other's files — consider a per-run
# tempfile.TemporaryDirectory if that matters.
TEMP_FOLDER = "temfolder"


def prepare_temp_folder(folder=TEMP_FOLDER):
    """Recreate ``folder`` as an empty directory, discarding any previous contents.

    Args:
        folder: Directory to reset. Defaults to ``TEMP_FOLDER``.
    """
    cleanup_temp_folder(folder)
    # exist_ok tolerates another session recreating the folder between the
    # removal above and this call, instead of raising FileExistsError.
    os.makedirs(folder, exist_ok=True)


def cleanup_temp_folder(folder=TEMP_FOLDER):
    """Delete ``folder`` and everything in it; do nothing if it does not exist.

    Args:
        folder: Directory to remove. Defaults to ``TEMP_FOLDER``.
    """
    # EAFP rather than exists() + rmtree(): the folder can vanish between the
    # check and the delete, which would raise FileNotFoundError.
    try:
        shutil.rmtree(folder)
    except FileNotFoundError:
        pass
# ------------------------------
# Streamlit UI
# ------------------------------
st.title("Image Categorization Demo")

with st.form("upload_form", clear_on_submit=True):
    uploaded_files = st.file_uploader(
        "Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    col1, col2 = st.columns([1, 1])
    submit = col1.form_submit_button("🚀 Submit for Classification")
    refresh = col2.form_submit_button("🔄 Refresh")

if refresh:
    cleanup_temp_folder()
    st.success("🧹 Uploads cleared!")

if submit:
    if not uploaded_files:
        st.warning("⚠️ Please upload at least one image before submitting.")
    else:
        total_files = len(uploaded_files)
        st.write(f"🔍 Classifying **{total_files}** images...")

        # Start from an empty folder so files from an earlier run never leak in.
        prepare_temp_folder()
        try:
            results_by_class = {}
            progress = st.progress(0)  # progress bar
            status_text = st.empty()  # placeholder for progress text

            for idx, file in enumerate(uploaded_files, start=1):
                # Prefix with the upload index: two uploads sharing a name would
                # otherwise overwrite each other on disk and the gallery would
                # show the last one twice. basename() keeps the write inside
                # TEMP_FOLDER even if the client sends a path-like name.
                safe_name = f"{idx:04d}_{os.path.basename(file.name)}"
                img_path = os.path.join(TEMP_FOLDER, safe_name)
                with open(img_path, "wb") as f:
                    # getvalue() returns the whole upload regardless of read position.
                    f.write(file.getvalue())

                # Run YOLO classification and keep the top-1 label.
                results = model(img_path)
                pred_class = results[0].names[results[0].probs.top1]
                results_by_class.setdefault(pred_class, []).append(img_path)

                # Update progress bar + text.
                percent = int((idx / total_files) * 100)
                progress.progress(idx / total_files)
                status_text.text(f"Processing {idx}/{total_files} images ({percent}%)")

            st.success("✅ Classification complete!")

            # ------------------------------
            # Show gallery grouped by class
            # ------------------------------
            for cls, img_list in results_by_class.items():
                st.subheader(f"📂 Category: **{cls}** ({len(img_list)})")
                cols = st.columns(4)  # show 4 images per row
                for i, img_path in enumerate(img_list):
                    with cols[i % 4]:
                        # Close the handle PIL opens; a lingering handle keeps
                        # rmtree from deleting the file on Windows.
                        # NOTE(review): newer Streamlit releases deprecate
                        # use_column_width in favour of use_container_width.
                        with Image.open(img_path) as img:
                            st.image(img, use_column_width=True)
        finally:
            # Runs even if classification raises, so temp files never pile up.
            cleanup_temp_folder()
|