|
|
import streamlit as st |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Browser-tab title and page layout for the app; this call precedes every
# other st.* call in the script.
st.set_page_config(layout="centered", page_title="Image Classifier")
|
|
|
|
|
import os
import tempfile

import cv2
import joblib
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_rgb_hist(image_path):
    """Compute the colour-histogram feature vector for an image file.

    The image is decoded with OpenCV, converted from BGR to RGB, and a
    256-bin histogram over the value range [0, 256) is computed for each
    channel.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.

    Returns:
        A 1-D numpy array of 768 bin counts: 256 for R, then 256 for G,
        then 256 for B.

    Raises:
        ValueError: If OpenCV cannot read or decode ``image_path``.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the problem surfaces later as an opaque cvtColor assertion.
    if img is None:
        raise ValueError(f"Could not read image file: {image_path!r}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # After the conversion, channel 0 is R, 1 is G and 2 is B. The vector is
    # fed straight to the classifiers, so keep this order (presumably it
    # mirrors the training-time feature layout).
    hists = [
        cv2.calcHist([img], [channel], None, [256], [0, 256]).flatten()
        for channel in range(3)
    ]
    return np.concatenate(hists)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load the three serialized classifiers from the current working directory.

    Wrapped in ``st.cache_resource`` so the files are deserialized once and
    the same objects are reused on later reruns.

    NOTE: ``joblib.load`` unpickles its input, so only point it at trusted
    model files.

    Returns:
        Tuple ``(svm, decision_tree, random_forest)`` of the objects returned
        by ``joblib.load``.
    """
    model_files = (
        "svm_model_1.pkl",
        "decision_tree_model.pkl",
        "random_forest_model.pkl",
    )
    # Loaded in order, so tuple positions line up with the file list above.
    return tuple(joblib.load(path) for path in model_files)
|
|
|
|
|
|
|
|
|
|
|
# load_models() is wrapped in @st.cache_resource, so this does not reload the
# model files on every rerun of the script.
svm_model, dt_model, rf_model = load_models()

# The emoji below replaces a mis-encoded glyph that rendered as mojibake.
st.title("🔍 Image Classification App by Delumi!")
st.write("Upload an image and select a model to classify it.")

# The selected label decides which classifier handles the upload below.
model_choice = st.selectbox(
    "Choose a Model:",
    ("SVM (Best Tuned)", "Decision Tree", "Random Forest"),
)

# None until the user uploads a file; extensions restricted to JPEG/PNG.
uploaded_file = st.file_uploader(
    "Upload Image",
    type=["jpg", "jpeg", "png"],
)
|
|
|
|
|
if uploaded_file is not None:
    # extract_rgb_hist() needs a filesystem path, so persist the upload to a
    # uniquely named temp file. A fixed name in the working directory would
    # be shared by every concurrent session running this script.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".jpg"
    fd, temp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            # getvalue() returns the whole buffer regardless of the current
            # stream position, unlike read().
            f.write(uploaded_file.getvalue())

        st.image(temp_path, caption="Uploaded Image", use_column_width=True)

        try:
            features = extract_rgb_hist(temp_path).reshape(1, -1)
        except (ValueError, cv2.error):
            # Unreadable or corrupt upload: report it instead of crashing.
            st.error("Could not read the uploaded file as an image.")
        else:
            # Any unrecognised label falls back to Random Forest.
            model = {
                "SVM (Best Tuned)": svm_model,
                "Decision Tree": dt_model,
                "Random Forest": rf_model,
            }.get(model_choice, rf_model)

            prediction = model.predict(features)[0]

            st.subheader("🔮 Prediction Result")
            st.write(f"**Model Used:** {model_choice}")
            st.success(f"**Predicted Class:** {prediction}")
    finally:
        # Always clean up, even if decoding, display or prediction raised.
        os.remove(temp_path)
|
|
|