import os

import cv2
import mediapipe as mp
import numpy as np
import streamlit as st

# -------------------------------
# MediaPipe Classifier Setup
# -------------------------------
BaseOptions = mp.tasks.BaseOptions
ImageClassifier = mp.tasks.vision.ImageClassifier
ImageClassifierOptions = mp.tasks.vision.ImageClassifierOptions

model_path = "classifier.tflite"

options = ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path=model_path),
    max_results=5,
)
classifier = ImageClassifier.create_from_options(options)

# -------------------------------
# Streamlit UI Setup
# -------------------------------
st.set_page_config(page_title="Image Classifier", layout="wide", page_icon="🛒")

# Compact layout fix: title fully visible
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

st.title("E-Commerce Image Classifier")
st.write(
    "Try uploading an image or a folder to see automatic classification results. "
    "You can navigate between images using the arrow buttons below. "
    "This project is open source; check it out on "
    "[GitHub](https://github.com/travelmateen/image-classification-ecommerce). 🚀"
)
st.markdown("", unsafe_allow_html=True)

# Sidebar uploader and controls
with st.sidebar:
    st.title("User Configuration")

    num_classes = st.number_input(
        "Number of classes to display",
        min_value=1,
        max_value=5,
        value=3,
        help="Choose how many classification results to show (1-5)",
    )

    # Selection mode (Images or Directory)
    selection_mode = st.radio(
        "Choose upload type:",
        ["Directory", "Select Images"],
        index=0,
        horizontal=True,
    )

    st.header("Upload Your Files")
    if selection_mode == "Directory":
        uploaded_files = st.file_uploader(
            "Upload images from directory",
            accept_multiple_files="directory",
            type=["jpg", "jpeg", "png"],
        )
    else:
        uploaded_files = st.file_uploader(
            "Select individual images",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
        )

with st.sidebar.expander("⚠️ Limitations & Tips"):
    # The text is kept flush-left so Markdown does not render it as a code block.
    st.write(
        """
**Known Limitations:**
- Pre-trained MediaPipe general classifier
- 1000 ImageNet categories only
- Not customized for specific domains
- Max 10MB per image

**For Best Results:**
- Clear, single-subject images
- Common objects and scenes
- Good lighting and focus
- Avoid ambiguous or complex scenes
"""
    )

# -------------------------------
# Default folder handling
# -------------------------------
if not uploaded_files:
    default_folder = "images"
    if os.path.exists(default_folder):
        # Sort for a stable navigation order; os.listdir order is arbitrary.
        image_files = [
            os.path.join(default_folder, f)
            for f in sorted(os.listdir(default_folder))
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        ]
        if image_files:
            # Keep paths rather than open handles so no file is left
            # open across Streamlit reruns.
            uploaded_files = image_files

# -------------------------------
# Classification Logic
# -------------------------------
if uploaded_files:
    total_images = len(uploaded_files)

    # st.session_state['foo'] holds the index of the image being shown.
    if 'foo' not in st.session_state:
        st.session_state['foo'] = 0
    current_index = st.session_state['foo']

    # Prevent out-of-range errors
    if current_index >= len(uploaded_files):
        current_index = len(uploaded_files) - 1
        st.session_state['foo'] = current_index
    elif current_index < 0:
        current_index = 0
        st.session_state['foo'] = 0

    current_image = uploaded_files[current_index]

    # --- Read image ---
    # Default-folder entries are paths; uploads are in-memory file objects.
    # getvalue() returns the full buffer regardless of the read position.
    if isinstance(current_image, str):
        with open(current_image, "rb") as f:
            raw_bytes = f.read()
    else:
        raw_bytes = current_image.getvalue()

    file_bytes = np.frombuffer(raw_bytes, dtype=np.uint8)
    frame = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if frame is None:
        st.error("⚠️ Unable to read image.")
        st.stop()

    # --- Scale image to 50% ---
    frame = cv2.resize(frame, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)

    # --- Convert to RGB ---
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # --- Classify image ---
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
    result = classifier.classify(mp_image)

    # --- Layout: image + classification ---
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Original Image")
        st.image(rgb, use_container_width=True)

        nav_col1, nav_col2, nav_col3 = st.columns([3, 4, 1], gap="small")

        with nav_col1:
            st.markdown("", unsafe_allow_html=True)
            if st.button("⬅️", key="prev") and current_index > 0:
                st.session_state['foo'] = current_index - 1
                st.rerun()
            st.markdown("", unsafe_allow_html=True)

        with nav_col2:
            st.caption(f"🖼️ Image {current_index + 1} of {total_images}")

        with nav_col3:
            st.markdown("", unsafe_allow_html=True)
            if st.button("➡️", key="next") and current_index < total_images - 1:
                st.session_state['foo'] = current_index + 1
                st.rerun()
            st.markdown("", unsafe_allow_html=True)

    with col2:
        st.subheader("Classification Results")
        if result.classifications:
            categories = result.classifications[0].categories
            # Show only the top results requested in the sidebar.
            for cat in categories[:num_classes]:
                st.write(f"**{cat.category_name}** ({cat.score:.2f})")
                st.progress(float(cat.score))
        else:
            st.write("No classification detected.")
else:
    st.info(
        "📂 Please upload images using the sidebar to begin classification, "
        "or place images in the 'images' folder."
    )

# -------------------------------
# Footer
# -------------------------------
st.markdown(
    """

Made by Techtics.ai

""",
    unsafe_allow_html=True,
)