import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import preprocess_input as vgg_preprocess
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess
from PIL import Image
import os
import time

# ─────────────────────────────────────────────────────────────────────────────
# CLASS NAMES
# Hardcoded from training notebooks (flow_from_directory sorts alphabetically)
# This is the exact order your models learned — no class_indices.json needed.
# ─────────────────────────────────────────────────────────────────────────────
CLASS_NAMES = [
    'airplane', 'bench', 'bicycle', 'bird', 'bottle',
    'bowl', 'bus', 'cake', 'car', 'cat',
    'chair', 'couch', 'cow', 'cup', 'dog',
    'elephant', 'horse', 'motorcycle', 'person', 'pizza',
    'potted plant', 'stop sign', 'traffic light', 'train', 'truck',
]
# Note: flow_from_directory loads classes in alphabetical order.
# The list above is sorted A→Z to match exactly what your models output.
CLASS_ICONS = { 'airplane': '✈️', 'bench': '🪑', 'bicycle': '🚲', 'bird': '🐦', 'bottle': '🍶', 'bowl': '🥣', 'bus': '🚌', 'cake': '🎂', 'car': '🚗', 'cat': '🐱', 'chair': '🪑', 'couch': '🛋️', 'cow': '🐮', 'cup': '☕', 'dog': '🐶', 'elephant': '🐘', 'horse': '🐴', 'motorcycle': '🏍️', 'person': '🧍', 'pizza': '🍕', 'potted plant': '🪴', 'stop sign': '🛑', 'traffic light': '🚦', 'train': '🚆', 'truck': '🚛', } # ───────────────────────────────────────────────────────────────────────────── # MODEL CONFIGS # Preprocessing matches exactly what each training notebook used # ───────────────────────────────────────────────────────────────────────────── # PLOT_DIR = "models/result" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) PLOT_DIR = os.path.join(BASE_DIR, "models", "result") MODEL_CONFIGS = { "EfficientNetB0": { "path": os.path.join(PLOT_DIR, "efficientnetb0_best.keras"), "preprocess": eff_preprocess, # Phase 2.4: preprocess_input from efficientnet "rescale": False, "color": "#f39c12", "icon": "🟨", "accuracy": "92%", "speed": "80 ms", "description": "Highest accuracy · Compound scaling", }, "ResNet50": { "path": os.path.join(PLOT_DIR, "resnet50_best.keras"), "preprocess": resnet_preprocess, # Phase 2.2: preprocess_input from resnet50 "rescale": False, "color": "#e74c3c", "icon": "🟥", "accuracy": "88%", "speed": "100 ms", "description": "Strong all-rounder · Residual learning", }, "MobileNetV2": { "path": os.path.join(PLOT_DIR, "mobilenetv2_final.h5"), "preprocess": None, # Phase 2.3: ImageDataGenerator(rescale=1./255) "rescale": True, "color": "#2ecc71", "icon": "🟩", "accuracy": "85%", "speed": "50 ms", "description": "Fastest · Lightweight · Edge-ready", }, "VGG16": { "path": os.path.join(PLOT_DIR, "vgg16_best.keras"), "preprocess": vgg_preprocess, # Phase 2.1: preprocess_input from vgg16 "rescale": False, "color": "#3498db", "icon": "🟦", "accuracy": "83%", "speed": "150 ms", "description": "Classic CNN · Reliable baseline", }, } # 
# ─────────────────────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_all_models():
    """Load every model listed in MODEL_CONFIGS once (cached across reruns).

    Returns:
        dict: model name -> loaded Keras model, or None when the weight file
        is missing or fails to deserialize.
    """
    loaded = {}
    for name, cfg in MODEL_CONFIGS.items():
        if os.path.exists(cfg["path"]):
            try:
                loaded[name] = tf.keras.models.load_model(cfg["path"])
            except Exception as e:
                # Keep the app usable with the remaining models; just warn.
                st.warning(f"⚠️ Failed to load {name}: {e}")
                loaded[name] = None
        else:
            loaded[name] = None
    return loaded


def prepare_image(pil_img, cfg):
    """Resize, apply the correct preprocessing, return (1, 224, 224, 3) array."""
    img = pil_img.resize((224, 224))
    arr = np.array(img, dtype=np.float32)
    if cfg["rescale"]:
        arr = arr / 255.0             # MobileNetV2 path
    else:
        arr = cfg["preprocess"](arr)  # VGG16 / ResNet50 / EfficientNet path
    return np.expand_dims(arr, axis=0)


# ─────────────────────────────────────────────────────────────────────────────
# PAGE HEADER
# ─────────────────────────────────────────────────────────────────────────────
st.title("🖼️ Image Classification")
st.caption("Upload any image — all 4 CNN models classify it simultaneously")

# ── Model status cards ────────────────────────────────────────────────────────
st.markdown("### 🤖 Available Models")
cols = st.columns(4)
for col, (mname, cfg) in zip(cols, MODEL_CONFIGS.items()):
    exists = os.path.exists(cfg["path"])
    status_label = "✅ Ready" if exists else "❌ File missing"
    status_color = "#2ecc71" if exists else "#e74c3c"
    with col:
        # NOTE(review): the original inline HTML of this card was lost when the
        # file was mangled; the markup below is a reconstruction rendering the
        # same fields (icon, name, description, accuracy/speed, status) —
        # confirm styling against the intended design.
        st.markdown(
            f"""
            <div style="border:2px solid {cfg['color']}; border-radius:10px;
                        padding:12px; text-align:center;">
                <div style="font-size:2rem;">{cfg['icon']}</div>
                <div style="font-weight:700;">{mname}</div>
                <div style="font-size:0.8rem; color:#888;">{cfg['description']}</div>
                <div style="font-size:0.85rem;">🎯 {cfg['accuracy']} &nbsp; ⚡ {cfg['speed']}</div>
                <div style="color:{status_color}; font-weight:600;">{status_label}</div>
            </div>
            """,
            unsafe_allow_html=True,
        )

st.divider()

# ─────────────────────────────────────────────────────────────────────────────
# UPLOAD
# ─────────────────────────────────────────────────────────────────────────────
uploaded = st.file_uploader(
    "📤 Upload an image to classify",
    type=["jpg", "jpeg", "png"],
    help="Best results with a single clear object centred in the frame.",
)
if not uploaded:
    st.info("👆 Upload an image above to see predictions from all 4 models.")
    st.stop()

image = Image.open(uploaded).convert("RGB")

# ─────────────────────────────────────────────────────────────────────────────
# LOAD MODELS
# ─────────────────────────────────────────────────────────────────────────────
with st.spinner("⏳ Loading models (cached after first run)…"):
    models = load_all_models()

available = {k: v for k, v in models.items() if v is not None}
if not available:
    # BUGFIX: the message previously pointed at `model/result/`; the actual
    # weights directory (PLOT_DIR) is `models/result/`.
    st.error(
        "❌ No model files found in `models/result/`. "
        "Make sure `.keras` / `.h5` files are present."
    )
    st.stop()

# ─────────────────────────────────────────────────────────────────────────────
# RUN INFERENCE
# ─────────────────────────────────────────────────────────────────────────────
all_results = {}
bar = st.progress(0, text="Running inference…")
for i, (mname, model) in enumerate(available.items()):
    cfg = MODEL_CONFIGS[mname]
    arr = prepare_image(image, cfg)
    # perf_counter is monotonic — wall-clock time.time() can jump (e.g. NTP
    # adjustments) and skew the latency readout shown in the UI.
    t0 = time.perf_counter()
    preds = model.predict(arr, verbose=0)[0]
    ms = (time.perf_counter() - t0) * 1000
    top5_idx = np.argsort(preds)[::-1][:5]
    top5 = [(CLASS_NAMES[j], float(preds[j])) for j in top5_idx]
    all_results[mname] = {"top5": top5, "top1": top5[0], "ms": ms, "cfg": cfg}
    bar.progress((i + 1) / len(available), text=f"✅ {mname} complete")
bar.empty()

# ─────────────────────────────────────────────────────────────────────────────
# LAYOUT: image | consensus
# ─────────────────────────────────────────────────────────────────────────────
img_col, sum_col = st.columns([1, 1])
with img_col:
    st.image(image, caption="📷 Input Image", use_container_width=True)
# Re-enter the image column (Streamlit containers may be entered repeatedly)
# to append the size caption under the displayed image.
with img_col:
    st.caption(f"Size: {image.size[0]}×{image.size[1]} px · RGB")

with sum_col:
    st.markdown("### 🧠 Model Consensus")
    votes = [r["top1"][0] for r in all_results.values()]
    # BUGFIX: was max(set(votes), ...). Set iteration order for strings is
    # nondeterministic across interpreter runs (hash randomization), so a tie
    # could resolve differently on every rerun. Iterating the vote list keeps
    # tie-breaking stable (first-voting model wins).
    consensus = max(votes, key=votes.count)
    n_agree = votes.count(consensus)
    icon = CLASS_ICONS.get(consensus, "🔍")
    # Green for strong agreement (3+), amber for a split, red for no majority.
    c_color = "#2ecc71" if n_agree >= 3 else "#f39c12" if n_agree == 2 else "#e74c3c"
    # NOTE(review): the original card HTML was lost in the mangled source; the
    # markup below is a reconstruction rendering the same fields — confirm
    # styling against the intended design.
    st.markdown(
        f"""
        <div style="border:2px solid {c_color}; border-radius:12px;
                    padding:16px; text-align:center;">
            <div style="font-size:2.5rem;">{icon}</div>
            <div style="font-size:1.4rem; font-weight:700;">{consensus}</div>
            <div style="color:{c_color};">{n_agree} / {len(all_results)} models agree</div>
        </div>
        """,
        unsafe_allow_html=True,
    )

    st.markdown("**Top prediction per model:**")
    for mname, res in all_results.items():
        cfg = res["cfg"]
        label = res["top1"][0]
        conf = res["top1"][1]
        match = "✅" if label == consensus else "🔶"
        st.markdown(
            f"""
            <div style="display:flex; justify-content:space-between; padding:4px 8px;">
                <span>{cfg['icon']} {mname}</span>
                <span>{match} {label} <b>{conf:.0%}</b> · {res['ms']:.0f}ms</span>
            </div>
            """,
            unsafe_allow_html=True,
        )

st.divider()

# ─────────────────────────────────────────────────────────────────────────────
# DETAILED TOP-5 CARDS — 2-column grid
# ─────────────────────────────────────────────────────────────────────────────
st.markdown("### 🔍 Detailed Predictions — Top 5 Per Model")
names = list(all_results.keys())
for i in range(0, len(names), 2):
    row = st.columns(2)
    for j, col in enumerate(row):
        if i + j >= len(names):
            break  # odd number of models: leave the last cell empty
        mname = names[i + j]
        res = all_results[mname]
        cfg = res["cfg"]
        with col:
            # Card header (reconstructed markup — see NOTE above).
            st.markdown(
                f"""
                <div style="border:1px solid {cfg['color']}; border-bottom:none;
                            border-radius:8px 8px 0 0; padding:8px 12px;">
                    <b>{cfg['icon']} {mname}</b>
                    <span style="float:right;">⚡ {res['ms']:.0f} ms</span><br>
                    <span style="font-size:0.8rem; color:#888;">{cfg['description']}</span>
                </div>
                """,
                unsafe_allow_html=True,
            )
            # Top-5 confidence rows; rank 0 is highlighted in the model colour.
            for rank, (label, conf) in enumerate(res["top5"]):
                li = CLASS_ICONS.get(label, "•")
                bar_c = cfg["color"] if rank == 0 else "#bbb"
                bg = f"{cfg['color']}18" if rank == 0 else "transparent"
                fw = "700" if rank == 0 else "400"
                medal = "🥇" if rank == 0 else f"#{rank+1}"
                st.markdown(
                    f"""
                    <div style="background:{bg}; padding:2px 12px; font-weight:{fw};
                                border-left:1px solid {cfg['color']};
                                border-right:1px solid {cfg['color']};">
                        {medal} {li} {label}
                        <span style="float:right;">{conf:.1%}</span>
                        <div style="height:4px; width:{conf:.0%}; background:{bar_c};"></div>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )
            # Footer border closes the card box.
            st.markdown(
                f"<div style='border:1px solid {cfg['color']}; border-top:none; "
                f"border-radius:0 0 8px 8px; height:6px;'></div>",
                unsafe_allow_html=True,
            )
            st.markdown("")

st.divider()

# ─────────────────────────────────────────────────────────────────────────────
# LIVE INFERENCE TIME — vertical bars
# ─────────────────────────────────────────────────────────────────────────────
st.markdown("### ⚡ Live Inference Time")
max_ms = max(r["ms"] for r in all_results.values()) or 1  # avoid /0 on 0 ms
speed_cols = st.columns(len(all_results))
for col, (mname, res) in zip(speed_cols, all_results.items()):
    cfg = res["cfg"]
    pct = int((res["ms"] / max_ms) * 100)  # bar height relative to slowest
    with col:
        st.markdown(
            f"""
            <div style="text-align:center;">
                <div>{cfg['icon']} {mname}</div>
                <div style="background:#eee; border-radius:4px; height:120px;
                            display:flex; align-items:flex-end;">
                    <div style="width:100%; height:{pct}%;
                                background:{cfg['color']}; border-radius:4px;"></div>
                </div>
                <div><b>{res['ms']:.0f} ms</b></div>
            </div>
            """,
            unsafe_allow_html=True,
        )

st.caption("📝 Inference times measured live on this machine. GPU will be significantly faster.")