File size: 4,457 Bytes
8e1d81b 94311e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import streamlit as st
import pandas as pd
import numpy as np
import tensorflow as tf
from PIL import Image
import os
import glob
import zipfile
import time
# --- PAGE CONFIG ---
# Browser-tab title plus the full-width ("wide") page layout.
_PAGE_SETTINGS = dict(page_title="Cloud Inventory System", layout="wide")
st.set_page_config(**_PAGE_SETTINGS)
# --- 1. SETUP & UNZIP LOGIC ---
# The sample images ship as a zip; unpack it into IMAGES_DIR the first time
# the app runs (later runs see the folder and skip this step).
IMAGES_DIR = "sample_fruits_50"
ZIP_FILE = "sample_fruits_50.zip"

if not os.path.exists(IMAGES_DIR):
    if os.path.exists(ZIP_FILE):
        import shutil  # only used to clean up after a failed extraction

        with st.spinner("Unpacking image database..."):
            # Create the folder explicitly and extract INSIDE it, so "flat"
            # zips (images at the archive root) still land in IMAGES_DIR.
            os.makedirs(IMAGES_DIR, exist_ok=True)
            try:
                with zipfile.ZipFile(ZIP_FILE, 'r') as zip_ref:
                    zip_ref.extractall(IMAGES_DIR)
            except (zipfile.BadZipFile, OSError) as err:
                # Remove the partially populated folder: otherwise the
                # existence check above would skip extraction on every
                # later rerun and the app would scan an incomplete set.
                shutil.rmtree(IMAGES_DIR, ignore_errors=True)
                st.error(f"Could not unpack '{ZIP_FILE}': {err}")
            else:
                st.success("Database loaded!")
    else:
        st.warning(f"β οΈ Please upload '{ZIP_FILE}' to the Files tab!")
# --- 2. LOAD MODEL ---
@st.cache_resource
def _load_keras_model(model_path):
    """Deserialize the Keras model stored at *model_path*.

    Wrapped in st.cache_resource so the model is loaded once and then reused
    across reruns instead of being re-read from disk each time.
    """
    return tf.keras.models.load_model(model_path)


def load_model(model_path="fruit_classifier.h5"):
    """Return the fruit classifier, or None when *model_path* does not exist.

    The existence check is deliberately kept outside the cached loader:
    caching a None result would keep reporting "not found" even after the
    file has been uploaded, until the resource cache is cleared or the
    server restarts.
    """
    # FIX 2: default filename matches the uploaded model ('fruit_classifier.h5').
    if not os.path.exists(model_path):
        return None
    return _load_keras_model(model_path)


model = load_model()
# --- 3. HELPER FUNCTIONS ---
# Classifier labels, each "<quality>_<fruit>". The scan loop maps the model's
# argmax index straight into this list, so the order is assumed to match the
# model's output units — NOTE(review): confirm against the training code.
CLASS_NAMES = [
    'fresh_apple', 'fresh_banana', 'fresh_grape', 'fresh_orange', 'fresh_pomegranate',
    'rotten_apple', 'rotten_banana', 'rotten_grape', 'rotten_orange', 'rotten_pomegranate'
]


def get_initial_db():
    """Return a fresh, zeroed inventory keyed by fruit name.

    Shape: {'Apple': {'Fresh Qty': 0, 'Rotten Qty': 0}, ...}. Fruit names are
    derived from CLASS_NAMES with the same rule the scan loop uses to parse a
    label (second '_'-separated part, title-cased), in first-seen order, so
    the inventory keys can never drift out of sync with the labels.
    A new dict (with new inner dicts) is built on every call.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    fruits = dict.fromkeys(label.split('_')[1].title() for label in CLASS_NAMES)
    return {fruit: {'Fresh Qty': 0, 'Rotten Qty': 0} for fruit in fruits}
# --- 4. MAIN APP UI ---
_APP_TITLE = "π Cloud AI Inventory Scan"
_APP_INTRO = "This system will scan the **50 test images** uploaded to the cloud."
st.title(_APP_TITLE)
st.markdown(_APP_INTRO)

# Everything below needs the classifier; halt this script run when it is absent.
model_missing = model is None
if model_missing:
    st.error("Model file 'fruit_classifier.h5' not found in Files tab.")
    st.stop()
def _preprocess_image(filepath, size=(224, 224)):
    """Load an image file as a (1, H, W, 3) float array scaled to [0, 1].

    Converting to RGB up front handles grayscale ("L"), palette ("P"), RGBA
    and CMYK files, which would otherwise reach the model with the wrong
    rank, channel count or colours. The `with` block closes the file handle.
    NOTE(review): 224x224 and /255 scaling are assumed to match the model's
    training preprocessing — confirm.
    """
    with Image.open(filepath) as img:
        rgb = img.convert("RGB").resize(size)
    arr = np.asarray(rgb) / 255.0
    return np.expand_dims(arr, axis=0)


def _parse_label(label):
    """Split a class label like 'rotten_apple' into ('rotten', 'Apple')."""
    parts = label.split('_')
    return parts[0], parts[1].title()


if st.button("π Start Cloud Scan"):
    import random

    # Collect every .png/.jpg/.jpeg anywhere under IMAGES_DIR (the zip may be
    # flat or contain nested folders).
    files_to_scan = []
    if os.path.exists(IMAGES_DIR):
        all_images = glob.glob(os.path.join(IMAGES_DIR, "**", "*.*"), recursive=True)
        files_to_scan = [f for f in all_images if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    # Pick 15 random ones to keep the demo short.
    if len(files_to_scan) > 15:
        files_to_scan = random.sample(files_to_scan, 15)

    if not files_to_scan:
        st.error("No images found! The zip file might be empty or failed to unzip.")
        st.stop()

    # Layout: live inventory table on the left, current image on the right.
    col1, col2 = st.columns([1.5, 1])
    with col1:
        st.subheader("π¦ Live Inventory")
        table_placeholder = st.empty()
    with col2:
        st.subheader("π· Feed")
        image_placeholder = st.empty()

    # Init data: all counts start at zero.
    db_data = get_initial_db()
    current_df = pd.DataFrame.from_dict(db_data, orient='index')
    table_placeholder.table(current_df)
    progress_bar = st.progress(0)

    for i, filepath in enumerate(files_to_scan):
        try:
            # Display is inside the try too, so one unreadable file is
            # reported and skipped instead of aborting the whole scan.
            image_placeholder.image(filepath, caption=f"Item #{i+1}", use_container_width=True)

            preds = model.predict(_preprocess_image(filepath), verbose=0)
            quality, fruit = _parse_label(CLASS_NAMES[int(np.argmax(preds[0]))])

            column = 'Fresh Qty' if quality == 'fresh' else 'Rotten Qty'
            db_data[fruit][column] += 1

            current_df = pd.DataFrame.from_dict(db_data, orient='index')
            table_placeholder.table(current_df)
            time.sleep(0.2)  # visual delay so the live update is watchable
        except Exception as e:
            # Best-effort per image: report the failure and keep scanning.
            st.error(f"Error: {e}")
        progress_bar.progress((i + 1) / len(files_to_scan))

    st.success("Scan Complete!")

    # Final graph of fresh vs rotten counts per fruit.
    st.divider()
    st.subheader("π Final Cloud Report")
    st.bar_chart(current_df, color=["#4CAF50", "#FF5252"])