# -*- coding: utf-8 -*-
"""
Streamlit web app that classifies insect images with models hosted on the
HuggingFace Hub.

Created on Tue Nov 18 09:07:10 2025

@author: THYAGHARAJAN
"""

import os

import numpy as np
import streamlit as st
import tensorflow as tf
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image

# NOTE(review): these disable Streamlit's CORS and XSRF protection. They are
# set after `streamlit` has been imported; presumably a workaround for hosted
# deployments where uploads are otherwise rejected -- confirm that they take
# effect here and that relaxing XSRF protection is acceptable.
os.environ["STREAMLIT_SERVER_ENABLE_CORS"] = "false"
os.environ["STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION"] = "false"

# ------------------------------
# CONFIGURATION
# ------------------------------
REPO_ID = "kkthyagharajan/KKT-HF-TransferLearning-Models"  # <<< CHANGE THIS
IMG_SIZE = (300, 300)  # (width, height) passed to PIL's Image.resize
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
REPO_LISTING_TTL_SECONDS = 600  # how long a cached repo file listing is reused

st.set_page_config(page_title="Insect Classifier", layout="wide")


# ----------------------------------
# Cached loaders
# ----------------------------------
@st.cache_resource
def load_tf_model(model_path):
    """Load the Keras model stored at *model_path* (without compiling it).

    The loaded model is cached by Streamlit, so each path is loaded once.
    """
    return tf.keras.models.load_model(model_path, compile=False)


@st.cache_resource
def load_class_names(model_dir):
    """Return the class names listed in ``<model_dir>/class_names.txt``.

    The file is downloaded from REPO_ID and parsed as one comma-separated
    list; surrounding whitespace is stripped from every name.
    """
    class_file = hf_hub_download(
        repo_id=REPO_ID, filename=f"{model_dir}/class_names.txt"
    )
    with open(class_file, "r", encoding="utf-8") as f:
        return [name.strip() for name in f.read().split(",")]


@st.cache_data(ttl=REPO_LISTING_TTL_SECONDS)
def _list_repo_files():
    """Return the file paths in REPO_ID.

    Streamlit re-runs the whole script on every interaction; caching the
    listing avoids one network round trip per rerun.
    """
    return list(list_repo_files(REPO_ID))


# ----------------------------------
# Helper Functions
# ----------------------------------
def get_available_models():
    """Return mapping: model_dir → model file (.h5 preferred over .keras)."""
    files = _list_repo_files()
    models = {}

    # Prefer .h5
    for path in files:
        # A file at the repo root has no model directory, so its
        # "<model_dir>/class_names.txt" could never be resolved; skip it.
        if path.endswith(".h5") and "/" in path:
            models[path.split("/")[0]] = path

    # Use .keras only if .h5 missing
    for path in files:
        if path.endswith(".keras") and "/" in path:
            models.setdefault(path.split("/")[0], path)

    return models


def get_sample_images(model_dir):
    """List sample images inside model_dir/sample_images/

    Returned names are relative to that folder.
    """
    prefix = f"{model_dir}/sample_images/"
    return [
        path[len(prefix):]  # strip only the leading prefix
        for path in _list_repo_files()
        if path.startswith(prefix) and path.lower().endswith(IMAGE_EXTENSIONS)
    ]


def load_sample_image(model_dir, image_name):
    """Download sample image and open it as a PIL image."""
    path = hf_hub_download(
        repo_id=REPO_ID,
        filename=f"{model_dir}/sample_images/{image_name}",
    )
    return Image.open(path)


def preprocess(img):
    """Convert a PIL image into a model input batch.

    The image is converted to RGB (so RGBA, palette and grayscale inputs
    also yield three channels), resized to IMG_SIZE and scaled to [0, 1].

    Returns:
        A float32 array of shape (1, height, width, 3).
    """
    rgb = img.convert("RGB").resize(IMG_SIZE)
    arr = np.asarray(rgb, dtype=np.float32) / 255.0
    # Add the batch axis instead of reshaping, so the (height, width) layout
    # produced by PIL is kept even if IMG_SIZE is not square.
    return arr[np.newaxis, ...]


def _render_prediction(image, model_dir, model_file):
    """Run the selected model on *image* and display the top predictions."""
    # Show image
    st.image(image, caption="Input Image", width=300)

    # Load model
    model_path = hf_hub_download(repo_id=REPO_ID, filename=model_file)
    model = load_tf_model(model_path)
    class_names = load_class_names(model_dir)

    # Predict
    preds = model.predict(preprocess(image), verbose=0)[0]
    if len(class_names) != len(preds):
        # Indexing class_names with a model output index would otherwise
        # fail or report the wrong label.
        st.error(
            f"Model outputs {len(preds)} classes but class_names.txt "
            f"lists {len(class_names)}."
        )
        return

    idx = int(np.argmax(preds))
    predicted = class_names[idx]
    st.success(f"### 🟩 Predicted: **{predicted}** ({preds[idx]*100:.2f}%)")

    # Top-3 Predictions
    st.subheader("Top 3 Predictions")
    top3 = preds.argsort()[-3:][::-1]
    for i in top3:
        st.write(f"**{class_names[i]}** — {preds[i]*100:.2f}%")


# ----------------------------------
# UI Layout
# ----------------------------------
st.title("🦋 Insect Classification System")
# "  \n" is a Markdown hard line break.
st.markdown(
    "### A Multi-Model Deep Learning Web App\n"
    "Developed by **Dr. Thyagharajan K K, Professor & Dean (Research)**  \n"
    "RMD Engineering College"
)

col1, col2 = st.columns([1, 1])

# ----------------------------------
# LEFT PANEL
# ----------------------------------
with col1:
    st.subheader("1️⃣ Select Model")
    models = get_available_models()
    if not models:
        st.error("No models found in HuggingFace repo.")
        st.stop()

    model_choice = st.selectbox("Choose a model", list(models.keys()))

    st.subheader("2️⃣ Choose Image Source")
    input_mode = st.radio(
        "Select input method:",
        ["Upload Image", "Use Sample Image", "Live Camera"],
    )

    input_image = None

    # Upload
    if input_mode == "Upload Image":
        uploaded = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
        if uploaded:
            input_image = Image.open(uploaded)

    # Sample Images
    elif input_mode == "Use Sample Image":
        sample_images = get_sample_images(model_choice)
        if sample_images:
            selected_sample = st.selectbox("Choose sample image", sample_images)
            if selected_sample:
                input_image = load_sample_image(model_choice, selected_sample)
                st.image(input_image, caption="Sample Image", width=250)
        else:
            st.warning("No sample images found for this model.")

    # Live Camera
    elif input_mode == "Live Camera":
        camera_image = st.camera_input("Take a picture using your webcam")
        if camera_image:
            input_image = Image.open(camera_image)
            st.image(input_image, caption="Live Camera Capture", width=250)

    st.markdown("---")
    predict_btn = st.button("🔍 Predict", use_container_width=True)

# ----------------------------------
# RIGHT PANEL
# ----------------------------------
with col2:
    st.subheader("📊 Prediction Results")

    if predict_btn:
        if input_image is None:
            st.error("Please upload or select an image.")
        else:
            _render_prediction(input_image, model_choice, models[model_choice])

# Footer
st.markdown("---")
st.markdown(
    "**Developed by:** Dr. Thyagharajan K K  \n"
    "**Professor & Dean (Research)**  \n"
    "RMD Engineering College  \n"
    "📧 **kkthyagharajan@yahoo.com**"
)