import os
import time

import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import preprocess_input as vgg_preprocess
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_preprocess
from PIL import Image

# ─────────────────────────────────────────────────────────────────────────────
# CLASS NAMES
# Position i in this list is the label for output index i of every model.
# The training notebooks used flow_from_directory, which assigns class indices
# by sorting the class sub-directory names alphabetically; this list is kept
# in that same A→Z order, so no class_indices.json file is required.
# ─────────────────────────────────────────────────────────────────────────────
CLASS_NAMES = [
    "airplane", "bench", "bicycle", "bird", "bottle",
    "bowl", "bus", "cake", "car", "cat",
    "chair", "couch", "cow", "cup", "dog",
    "elephant", "horse", "motorcycle", "person", "pizza",
    "potted plant", "stop sign", "traffic light", "train", "truck",
]
# Emoji decoration for each class label; the keys mirror CLASS_NAMES.
CLASS_ICONS = {
    "airplane": "✈️",
    "bench": "🪑",
    "bicycle": "🚲",
    "bird": "🐦",
    "bottle": "🍶",
    "bowl": "🥣",
    "bus": "🚌",
    "cake": "🎂",
    "car": "🚗",
    "cat": "🐱",
    "chair": "🪑",
    "couch": "🛋️",
    "cow": "🐮",
    "cup": "☕",
    "dog": "🐶",
    "elephant": "🐘",
    "horse": "🐴",
    "motorcycle": "🏍️",
    "person": "🧍",
    "pizza": "🍕",
    "potted plant": "🪴",
    "stop sign": "🛑",
    "traffic light": "🚦",
    "train": "🚆",
    "truck": "🚛",
}

# ─────────────────────────────────────────────────────────────────────────────
# MODEL CONFIGS
# Each entry's preprocessing reproduces what its training notebook used.
# ─────────────────────────────────────────────────────────────────────────────
# Model files are resolved relative to this script (not the current working
# directory): <script dir>/models/result/<file>.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PLOT_DIR = os.path.join(BASE_DIR, "models", "result")

# Per-model settings (dict insertion order is the order the model cards are
# laid out in):
#   path        - saved Keras model file under PLOT_DIR
#   preprocess  - preprocess_input function to apply; None when rescale is used
#   rescale     - True -> divide raw pixels by 255 instead of calling preprocess
#   color       - hex colour string (presumably for UI styling — confirm)
#   icon        - badge emoji shown next to the model name
#   accuracy / speed / description - display-only strings for the status cards
MODEL_CONFIGS = {
    "EfficientNetB0": dict(
        path=os.path.join(PLOT_DIR, "efficientnetb0_best.keras"),
        preprocess=eff_preprocess,  # Phase 2.4: efficientnet.preprocess_input
        rescale=False,
        color="#f39c12",
        icon="🟨",
        accuracy="92%",
        speed="80 ms",
        description="Highest accuracy · Compound scaling",
    ),
    "ResNet50": dict(
        path=os.path.join(PLOT_DIR, "resnet50_best.keras"),
        preprocess=resnet_preprocess,  # Phase 2.2: resnet50.preprocess_input
        rescale=False,
        color="#e74c3c",
        icon="🟥",
        accuracy="88%",
        speed="100 ms",
        description="Strong all-rounder · Residual learning",
    ),
    "MobileNetV2": dict(
        path=os.path.join(PLOT_DIR, "mobilenetv2_final.h5"),
        preprocess=None,  # Phase 2.3: trained with ImageDataGenerator(rescale=1./255)
        rescale=True,
        color="#2ecc71",
        icon="🟩",
        accuracy="85%",
        speed="50 ms",
        description="Fastest · Lightweight · Edge-ready",
    ),
    "VGG16": dict(
        path=os.path.join(PLOT_DIR, "vgg16_best.keras"),
        preprocess=vgg_preprocess,  # Phase 2.1: vgg16.preprocess_input
        rescale=False,
        color="#3498db",
        icon="🟦",
        accuracy="83%",
        speed="150 ms",
        description="Classic CNN · Reliable baseline",
    ),
}
───────────────────────────────────────────────────────────────────────────── # HELPERS # ───────────────────────────────────────────────────────────────────────────── @st.cache_resource(show_spinner=False) def load_all_models(): loaded = {} for name, cfg in MODEL_CONFIGS.items(): if os.path.exists(cfg["path"]): try: loaded[name] = tf.keras.models.load_model(cfg["path"]) except Exception as e: st.warning(f"⚠️ Failed to load {name}: {e}") loaded[name] = None else: loaded[name] = None return loaded def prepare_image(pil_img, cfg): """Resize, apply the correct preprocessing, return (1, 224, 224, 3) array.""" img = pil_img.resize((224, 224)) arr = np.array(img, dtype=np.float32) if cfg["rescale"]: arr = arr / 255.0 # MobileNetV2 path else: arr = cfg["preprocess"](arr) # VGG16 / ResNet50 / EfficientNet path return np.expand_dims(arr, axis=0) # ───────────────────────────────────────────────────────────────────────────── # PAGE HEADER # ───────────────────────────────────────────────────────────────────────────── st.title("🖼️ Image Classification") st.caption("Upload any image — all 4 CNN models classify it simultaneously") # ── Model status cards ──────────────────────────────────────────────────────── st.markdown("### 🤖 Available Models") cols = st.columns(4) for col, (mname, cfg) in zip(cols, MODEL_CONFIGS.items()): exists = os.path.exists(cfg["path"]) status_label = "✅ Ready" if exists else "❌ File missing" status_color = "#2ecc71" if exists else "#e74c3c" with col: st.markdown( f"""
{cfg['icon']}
{mname}
{cfg['description']}
🎯 {cfg['accuracy']} ⚡ {cfg['speed']}
{status_label}
{icon}
{consensus}
{n_agree} / {len(all_results)} models agree
{cfg['icon']} {mname}
{res['ms']:.0f} ms