Update app.py
Browse files
app.py
CHANGED
|
@@ -6,10 +6,10 @@ import numpy as np
|
|
| 6 |
from sklearn.ensemble import RandomForestClassifier
|
| 7 |
from sklearn.preprocessing import LabelEncoder
|
| 8 |
|
| 9 |
-
|
| 10 |
DATASET_PATH = "PlantVillage"
|
| 11 |
|
| 12 |
-
#
|
| 13 |
@st.cache_resource
|
| 14 |
def load_data(data_dir, img_size=(64, 64)):
|
| 15 |
X, y = [], []
|
|
@@ -27,7 +27,7 @@ def load_data(data_dir, img_size=(64, 64)):
|
|
| 27 |
print(f"Error loading image {img_path}: {e}")
|
| 28 |
return np.array(X) / 255.0, np.array(y)
|
| 29 |
|
| 30 |
-
# β
|
| 31 |
@st.cache_resource
|
| 32 |
def train_model():
|
| 33 |
X, y = load_data(DATASET_PATH)
|
|
@@ -37,11 +37,10 @@ def train_model():
|
|
| 37 |
model.fit(X, y_encoded)
|
| 38 |
return model, le
|
| 39 |
|
| 40 |
-
# β
|
| 41 |
def get_all_images(data_dir):
|
| 42 |
image_paths = []
|
| 43 |
valid_extensions = (".jpg", ".jpeg", ".png")
|
| 44 |
-
|
| 45 |
for class_name in os.listdir(data_dir):
|
| 46 |
class_path = Path(data_dir) / class_name
|
| 47 |
if class_path.is_dir():
|
|
@@ -57,16 +56,23 @@ st.title("π₯ Potato Leaf Image Viewer + Classifier")
|
|
| 57 |
all_images = get_all_images(DATASET_PATH)
|
| 58 |
|
| 59 |
if all_images:
|
| 60 |
-
st.
|
|
|
|
|
|
|
| 61 |
image_display_names = [f"{cls} - {img.name}" for img, cls in all_images]
|
| 62 |
selected_display_name = st.selectbox("Select an image:", image_display_names)
|
| 63 |
|
|
|
|
| 64 |
selected_index = image_display_names.index(selected_display_name)
|
| 65 |
selected_image_path, actual_class = all_images[selected_index]
|
| 66 |
|
|
|
|
| 67 |
image = Image.open(selected_image_path).convert("RGB").resize((64, 64))
|
| 68 |
-
st.image(image, caption=f"Original Class: {actual_class}", use_container_width=True)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
model, le = train_model()
|
| 71 |
img_array = np.array(image).flatten().reshape(1, -1) / 255.0
|
| 72 |
prediction = model.predict(img_array)
|
|
@@ -75,4 +81,4 @@ if all_images:
|
|
| 75 |
st.markdown(f"### π§ Predicted Class: **{predicted_class}**")
|
| 76 |
|
| 77 |
else:
|
| 78 |
-
st.warning("No images found in the dataset.")
|
|
|
|
| 6 |
from sklearn.ensemble import RandomForestClassifier
|
| 7 |
from sklearn.preprocessing import LabelEncoder
|
| 8 |
|
| 9 |
+
|
| 10 |
DATASET_PATH = "PlantVillage"
|
| 11 |
|
| 12 |
+
# Loading image data from dataset
|
| 13 |
@st.cache_resource
|
| 14 |
def load_data(data_dir, img_size=(64, 64)):
|
| 15 |
X, y = [], []
|
|
|
|
| 27 |
print(f"Error loading image {img_path}: {e}")
|
| 28 |
return np.array(X) / 255.0, np.array(y)
|
| 29 |
|
| 30 |
+
# β
Training model using RandomForest
|
| 31 |
@st.cache_resource
|
| 32 |
def train_model():
|
| 33 |
X, y = load_data(DATASET_PATH)
|
|
|
|
| 37 |
model.fit(X, y_encoded)
|
| 38 |
return model, le
|
| 39 |
|
| 40 |
+
# β
Getting all images from dataset folders
|
| 41 |
def get_all_images(data_dir):
|
| 42 |
image_paths = []
|
| 43 |
valid_extensions = (".jpg", ".jpeg", ".png")
|
|
|
|
| 44 |
for class_name in os.listdir(data_dir):
|
| 45 |
class_path = Path(data_dir) / class_name
|
| 46 |
if class_path.is_dir():
|
|
|
|
| 56 |
all_images = get_all_images(DATASET_PATH)
|
| 57 |
|
| 58 |
if all_images:
|
| 59 |
+
st.success(f"β
Found {len(all_images)} images in dataset.")
|
| 60 |
+
|
| 61 |
+
# Show filenames for selection
|
| 62 |
image_display_names = [f"{cls} - {img.name}" for img, cls in all_images]
|
| 63 |
selected_display_name = st.selectbox("Select an image:", image_display_names)
|
| 64 |
|
| 65 |
+
# Find selected image
|
| 66 |
selected_index = image_display_names.index(selected_display_name)
|
| 67 |
selected_image_path, actual_class = all_images[selected_index]
|
| 68 |
|
| 69 |
+
# Load and show image (small size)
|
| 70 |
image = Image.open(selected_image_path).convert("RGB").resize((64, 64))
|
|
|
|
| 71 |
|
| 72 |
+
# Show image in smaller size
|
| 73 |
+
st.image(image, caption=f"π Original Class: {actual_class}", width=200)
|
| 74 |
+
|
| 75 |
+
# Predict
|
| 76 |
model, le = train_model()
|
| 77 |
img_array = np.array(image).flatten().reshape(1, -1) / 255.0
|
| 78 |
prediction = model.predict(img_array)
|
|
|
|
| 81 |
st.markdown(f"### π§ Predicted Class: **{predicted_class}**")
|
| 82 |
|
| 83 |
else:
|
| 84 |
+
st.warning("β οΈ No images found in the dataset. Please check the folder structure.")
|