# Kalhar.Pandya
# Added images and model with Git LFS
# abd2768
import os
import cv2
import numpy as np
import pickle
import gradio as gr
# Import the feature extraction function from feature_extractor.py
from feature_extractor import extract_features_from_image
# Global variables for the models, class names, and training log.
# These are rebound by load_model() and read by classify_new_image().

# Maps a model name to a classifier object; classify_new_image() calls
# predict_proba() on the selected entry.
models = {} # This will be a dictionary with keys: 'svm', 'rf', 'combined'
# Class labels; classify_new_image() indexes this list with the argmax of
# the predicted probability vector.
class_names = []
# Running text log: load_model() appends to it and classify_new_image()
# prints it ahead of its own progress messages.
training_log = ""
# ---------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------
def load_model(model_filename):
    """
    Load the pickled classifiers and class names from disk into module globals.

    The pickle is expected to hold a mapping with two entries: 'models'
    (a dict such as {'svm': ..., 'rf': ..., 'combined': ...}) and
    'class_names'. On success the globals `models` and `class_names` are
    replaced and one line is appended to `training_log`. If the file does
    not exist, a message is printed and the globals are left untouched.

    Args:
        model_filename: Path to the pickle file.

    Raises:
        KeyError: If the unpickled mapping lacks 'models' or 'class_names'.
            No global is modified in that case.
    """
    global models, class_names, training_log

    if not os.path.exists(model_filename):
        print(f"Model file {model_filename} not found. Please train the model first.")
        return

    print("Found existing model file. Loading...")
    with open(model_filename, "rb") as f:
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load model files that come from a trusted source.
        model_data = pickle.load(f)

    # Read both entries before rebinding any global, so a malformed file
    # cannot leave `models` updated while `class_names` is still stale.
    loaded_models = model_data['models']  # Expecting a dict: {'svm': ..., 'rf': ..., 'combined': ...}
    loaded_class_names = model_data['class_names']

    models = loaded_models
    class_names = loaded_class_names
    training_log += "Loaded model from disk.\n"
    print("Loaded models from disk.")
# ---------------------------------------------------------------------
# Gradio Classification Function with Model Selection
# ---------------------------------------------------------------------
def _resize_to_width(image, target_width):
    """Resize `image` to `target_width` pixels wide, keeping its aspect ratio."""
    height, width = image.shape[:2]
    # Clamp to at least one row: for extremely wide, flat images the
    # truncated height would be 0, which cv2.resize rejects.
    new_height = max(1, int(target_width * (height / width)))
    return cv2.resize(image, (target_width, new_height))


def _draw_class_overlays(image_bgr, patch_labels, prob_dict, alpha=0.4,
                         min_class_probability=0.2):
    """Blend a translucent, class-colored rectangle over each labeled patch."""
    # Overlay colors in BGR for each class (adjust as desired)
    color_map = {
        'wood': (0, 255, 255),    # Yellow (BGR format)
        'brick': (0, 0, 255),     # Red (BGR format)
        'stone': (128, 128, 128)  # Gray (BGR format)
    }
    overlay = image_bgr.copy()
    for (x, y, w, h, label) in patch_labels:
        # Skip patches whose class has a low probability for the image as
        # a whole, so unlikely classes do not clutter the annotation.
        if prob_dict[label] < min_class_probability:
            continue
        color = color_map.get(label.lower(), (0, 255, 0))  # Default to green if unknown
        cv2.rectangle(overlay, (x, y), (x + w, y + h), color, thickness=-1)
    # Blend the filled rectangles with the untouched image.
    return cv2.addWeighted(overlay, alpha, image_bgr, 1 - alpha, 0)


def classify_new_image(input_image_path, model_choice):
    """
    Classify an image file with the selected model and annotate its patches.

    The image is resized to a width of 1000 px, split into non-overlapping
    100x100 patches, and every patch whose Canny edge density lies strictly
    between 0 and 0.5 is classified. The per-patch probabilities are
    averaged and normalized to produce the final prediction. If no patch
    qualifies, the whole resized image is classified instead. A progress
    log and the probability dictionary are printed to stdout.

    Args:
        input_image_path: Path of the image file to classify.
        model_choice: One of the keys in the global `models` dict.

    Returns:
        A tuple, in this order:
        final_prediction: The entry of `class_names` with the highest
            probability.
        prob_dict (dict): Maps each class name to its normalized probability.
        annotated_image_rgb (numpy array): The resized image, in RGB, with
            transparent overlays on the classified patches.

    Raises:
        ValueError: If `model_choice` is not a loaded model, if no image
            path was provided, or if the image cannot be read.
    """
    if model_choice not in models:
        raise ValueError(f"Model choice '{model_choice}' not found. Available choices: {list(models.keys())}")
    classifier = models[model_choice]

    # The UI passes None when nothing was uploaded; fail with a clear
    # message instead of handing a non-path to OpenCV.
    if not input_image_path:
        raise ValueError("Error: No image file path was provided.")

    # Load image using OpenCV from file path (cv2.imread yields None on failure).
    image = cv2.imread(input_image_path)
    if image is None:
        raise ValueError("Error: Could not load image from the provided file path.")

    # Collect log fragments and join once at the end; the per-patch loop
    # would otherwise rebuild an ever-growing string on every iteration.
    log_lines = [training_log, "\nStarting classification...\n"]

    # Resize the image to a fixed width (1000 px) while maintaining aspect ratio
    resized_image = _resize_to_width(image, 1000)
    log_lines.append("Resized image to fixed width of 1000 pixels.\n")
    # The image from cv2.imread is already in BGR format.
    log_lines.append("Image loaded in BGR format.\n")

    # Preprocessing: grayscale, Gaussian blur, then Canny edges
    gray = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (9, 9), 0)
    edges = cv2.Canny(blurred, threshold1=0, threshold2=100)
    log_lines.append("Computed edges using Canny edge detection.\n")

    # Patch extraction parameters
    patch_w, patch_h = 100, 100
    pixels_per_patch = patch_w * patch_h
    img_h, img_w = gray.shape

    valid_patch_count = 0
    summed_probabilities = None
    overlays_list = []  # (x, y, w, h, predicted_label) for each valid patch

    # Loop over non-overlapping patches
    for y in range(0, img_h - patch_h + 1, patch_h):
        for x in range(0, img_w - patch_w + 1, patch_w):
            patch_edges = edges[y:y + patch_h, x:x + patch_w]
            density = np.sum(patch_edges > 0) / pixels_per_patch
            log_lines.append(f"Patch at ({x}, {y}) - edge density: {density:.3f}\n")
            # Only patches with some, but not excessive, edge content are classified.
            if not 0.0 < density < 0.5:
                continue

            valid_patch_count += 1
            patch = resized_image[y:y + patch_h, x:x + patch_w]
            features = extract_features_from_image(patch)
            feature_vector = features['combined_features'].reshape(1, -1)
            patch_probabilities = classifier.predict_proba(feature_vector)[0]
            predicted_label = class_names[np.argmax(patch_probabilities)]
            log_lines.append(
                f"Patch at ({x}, {y}) predicted: {predicted_label} with probabilities {patch_probabilities}\n"
            )
            overlays_list.append((x, y, patch_w, patch_h, predicted_label))

            if summed_probabilities is None:
                # Copy so the running sum never aliases (and mutates) the
                # array returned by the classifier.
                summed_probabilities = np.array(patch_probabilities)
            else:
                summed_probabilities += patch_probabilities

    # Fallback: if no valid patches are found, classify the whole image.
    if valid_patch_count == 0:
        log_lines.append("No valid patches found. Falling back to whole image classification.\n")
        features = extract_features_from_image(resized_image)
        feature_vector = features['combined_features'].reshape(1, -1)
        summed_probabilities = classifier.predict_proba(feature_vector)[0]
        valid_patch_count = 1

    # Average the probabilities from all valid patches and normalize them
    averaged_probabilities = summed_probabilities / valid_patch_count
    normalized_probabilities = averaged_probabilities / np.sum(averaged_probabilities)
    final_prediction = class_names[np.argmax(normalized_probabilities)]
    prob_dict = {cls: float(normalized_probabilities[i]) for i, cls in enumerate(class_names)}
    log_lines.append("Classification completed.\n")
    print("".join(log_lines))
    print(prob_dict)

    # Create an annotated image with transparent overlays
    annotated_image = _draw_class_overlays(resized_image, overlays_list, prob_dict)
    # Convert annotated image from BGR to RGB for Gradio display
    annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return final_prediction, prob_dict, annotated_image_rgb
# ---------------------------------------------------------------------
# Gradio Interface Setup using file paths and model selection
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Populate the module-level models and class names from the pickle on disk.
    model_filename = "./svm_rf_combined.pkl"  # Adjust filename as needed
    load_model(model_filename)

    # Offer the loaded model names in the dropdown; if nothing was loaded,
    # fall back to the expected names so the interface can still be built.
    model_choices = list(models) or ['svm', 'rf', 'combined']

    # Input widgets: the image is handed to the classifier as a file path.
    image_input = gr.Image(type="filepath", label="Input Image")
    model_dropdown = gr.Dropdown(
        choices=model_choices,
        label="Select Model",
        value=model_choices[0],
    )

    # Output widgets, in the order classify_new_image returns its values.
    output_components = [
        gr.Label(label="Predicted Class"),
        gr.Label(label="Probabilities"),
        gr.Image(label="Annotated Image"),
    ]

    description_text = (
        "Upload an image and select a classifier model (svm, rf, combined) to classify it.\n\n"
        "The image is processed by subdividing it into patches and aggregating the predictions. "
        "Transparent overlays are drawn on detected objects. Progress logs are printed to the terminal."
    )

    iface = gr.Interface(
        fn=classify_new_image,
        inputs=[image_input, model_dropdown],
        outputs=output_components,
        title="Stone, Wood, Brick Classifier",
        description=description_text,
    )
    iface.launch(share=True)