# Insect_Streamlit_App / Insect_HFspace_Streamlit_App.py - uploaded by kkthyagharajan, commit 632bfaf (verified): "Update Insect_HFspace_Streamlit_App.py"
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 18 09:07:10 2025
@author: THYAGHARAJAN
"""
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download, list_repo_files
import os
# Set Streamlit's CORS and XSRF-protection server options to "false" through
# environment variables.
# NOTE(review): these are assigned after `import streamlit` and from inside the
# app script itself; confirm they actually take effect at this point, and that
# disabling XSRF protection is acceptable for this deployment.
os.environ["STREAMLIT_SERVER_ENABLE_CORS"] = "false"
os.environ["STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION"] = "false"
# ------------------------------
# CONFIGURATION
# ------------------------------
# Hugging Face Hub repository queried by list_repo_files() and hf_hub_download()
# below for model files, class_names.txt files and sample images.
REPO_ID = "kkthyagharajan/KKT-HF-TransferLearning-Models" # <<< CHANGE THIS
# Size every input image is resized to in preprocess() before prediction.
IMG_SIZE = (300, 300)
# Page title and wide layout for the Streamlit page.
st.set_page_config(page_title="Insect Classifier", layout="wide")
# Cached loaders (results are memoized by st.cache_resource)
@st.cache_resource
def load_tf_model(model_path):
    """Load the Keras model saved at ``model_path``.

    The model is loaded with ``compile=False``. The decorator makes Streamlit
    cache the returned object, so calling this again with the same path
    reuses the already-loaded model.
    """
    keras_model = tf.keras.models.load_model(model_path, compile=False)
    return keras_model
@st.cache_resource
def load_class_names(model_dir):
    """Return the class labels for the model stored under ``model_dir``.

    Downloads ``<model_dir>/class_names.txt`` from ``REPO_ID`` and parses it as
    one comma-separated list; surrounding whitespace is stripped from every
    label. The list order is used as-is to map prediction indices to labels.
    The result is cached by Streamlit per ``model_dir``.
    """
    class_file = hf_hub_download(
        repo_id=REPO_ID, filename=f"{model_dir}/class_names.txt"
    )
    # Be explicit about the encoding: the default depends on the host locale,
    # so non-ASCII labels could otherwise be decoded differently per machine.
    with open(class_file, "r", encoding="utf-8") as f:
        return [name.strip() for name in f.read().split(",")]
# ----------------------------------
# Helper Functions
# ----------------------------------
def get_available_models():
    """Return mapping: model_dir → model file (.h5 preferred over .keras).

    ``model_dir`` is the first path component of a file in ``REPO_ID``.
    For each directory the last ``.h5`` file listed wins; a ``.keras`` file
    (the first one listed) is used only when the directory has no ``.h5``.
    Files sitting at the repository root are skipped: they have no model
    directory, so no ``<model_dir>/class_names.txt`` could exist for them.
    """
    h5_models = {}
    keras_models = {}
    for repo_file in list_repo_files(REPO_ID):
        model_dir, slash, _ = repo_file.partition("/")
        if not slash:
            # Root-level file: there is no directory to key the model on.
            continue
        if repo_file.endswith(".h5"):
            h5_models[model_dir] = repo_file
        elif repo_file.endswith(".keras"):
            keras_models.setdefault(model_dir, repo_file)

    # .h5 entries first, then .keras-only directories, so the order shown in
    # the model selector stays the same as before.
    models = dict(h5_models)
    for model_dir, repo_file in keras_models.items():
        models.setdefault(model_dir, repo_file)
    return models
def get_sample_images(model_dir):
    """List sample images inside model_dir/sample_images/.

    Returns the paths relative to ``<model_dir>/sample_images/`` of every
    file in ``REPO_ID`` under that folder whose name ends in .jpg, .jpeg or
    .png (case-insensitive), in the order the Hub lists them.
    """
    prefix = f"{model_dir}/sample_images/"
    image_exts = (".jpg", ".jpeg", ".png")
    return [
        # Slice off only the leading prefix; str.replace would also delete
        # any later occurrence of the same text inside the path.
        repo_file[len(prefix):]
        for repo_file in list_repo_files(REPO_ID)
        if repo_file.startswith(prefix) and repo_file.lower().endswith(image_exts)
    ]
def load_sample_image(model_dir, image_name):
    """Fetch one sample image from the Hub and open it with PIL.

    The file ``<model_dir>/sample_images/<image_name>`` is downloaded from
    ``REPO_ID`` and the resulting local path is passed to ``Image.open``.
    """
    repo_filename = f"{model_dir}/sample_images/{image_name}"
    local_path = hf_hub_download(repo_id=REPO_ID, filename=repo_filename)
    return Image.open(local_path)
def preprocess(img):
    """Turn a PIL image into a single-item batch for ``model.predict``.

    The image is converted to 3-channel RGB, resized according to
    ``IMG_SIZE``, scaled from 0-255 to 0-1, and given a leading batch axis.

    Returns:
        Array of shape ``(1, IMG_SIZE[0], IMG_SIZE[1], 3)``.
    """
    rows, cols = IMG_SIZE
    # PNG uploads are often RGBA, palette or grayscale; without this the
    # array does not have 3 channels and cannot form the expected batch.
    rgb = img.convert("RGB")
    # PIL's resize takes (width, height) while the array comes out as
    # (rows, cols, channels), so pass the size swapped. For a square
    # IMG_SIZE this is identical to passing IMG_SIZE directly.
    rgb = rgb.resize((cols, rows))
    arr = np.array(rgb) / 255.0
    # Add the batch axis instead of reshaping, so a shape problem can never
    # be hidden by silently reinterpreting the pixel layout.
    return arr[np.newaxis, ...]
# ----------------------------------
# UI Layout
# ----------------------------------
# Page header. The markdown text stays flush-left so that Markdown does not
# treat indented lines as a code block.
st.title("🦋 Insect Classification System")
st.markdown("""
### A Multi-Model Deep Learning Web App
Developed by **Dr. Thyagharajan K K, Professor & Dean (Research)**
RMD Engineering College
""")
col1, col2 = st.columns([1, 1])
# ----------------------------------
# LEFT PANEL: model choice, image source and the Predict button
# ----------------------------------
with col1:
    st.subheader("1️⃣ Select Model")
    models = get_available_models()
    if not models:
        st.error("No models found in HuggingFace repo.")
        st.stop()
    model_choice = st.selectbox("Choose a model", list(models.keys()))
    st.subheader("2️⃣ Choose Image Source")
    input_mode = st.radio("Select input method:", ["Upload Image", "Use Sample Image"])
    # Stays None until the user has uploaded or picked an image.
    input_image = None
    # Upload
    if input_mode == "Upload Image":
        uploaded = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
        if uploaded:
            input_image = Image.open(uploaded)
    # Sample Images
    else:
        sample_images = get_sample_images(model_choice)
        if sample_images:
            selected_sample = st.selectbox("Choose sample image", sample_images)
            if selected_sample:
                input_image = load_sample_image(model_choice, selected_sample)
                st.image(input_image, caption="Sample Image", width=250)
        else:
            st.warning("No sample images found for this model.")
    st.markdown("---")
    predict_btn = st.button("🔍 Predict", use_container_width=True)
# ----------------------------------
# RIGHT PANEL: prediction output
# ----------------------------------
with col2:
    st.subheader("📊 Prediction Results")
    if predict_btn:
        if input_image is None:
            st.error("Please upload or select an image.")
        else:
            # Show image
            st.image(input_image, caption="Input Image", width=300)
            # Load model
            model_path = hf_hub_download(repo_id=REPO_ID, filename=models[model_choice])
            model = load_tf_model(model_path)
            class_names = load_class_names(model_choice)

            def label_for(class_idx):
                # Fall back to a generic label when class_names.txt lists
                # fewer names than the model has outputs, rather than
                # failing with IndexError.
                if class_idx < len(class_names):
                    return class_names[class_idx]
                return f"class {class_idx}"

            # Predict: the batch holds one image, so take the first row.
            arr = preprocess(input_image)
            preds = model.predict(arr, verbose=0)[0]
            idx = int(np.argmax(preds))
            predicted = label_for(idx)
            st.success(f"### 🟩 Predicted: **{predicted}** ({preds[idx]*100:.2f}%)")
            # Top-3 Predictions: highest score first.
            st.subheader("Top 3 Predictions")
            top3 = preds.argsort()[-3:][::-1]
            for i in top3:
                st.write(f"**{label_for(int(i))}** — {preds[i]*100:.2f}%")
# Footer
st.markdown("---")
st.markdown("""
**Developed by:** Dr. Thyagharajan K K
**Professor & Dean (Research)**
RMD Engineering College
📧 **kkthyagharajan@yahoo.com**
""")